├── .gitignore
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── allrank
│   ├── __init__.py
│   ├── click_models
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── cascade_models.py
│   │   ├── click_utils.py
│   │   └── duplicate_aware.py
│   ├── config.py
│   ├── config_template.json
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset_loading.py
│   │   ├── dataset_saving.py
│   │   └── generate_dummy_data.py
│   ├── inference
│   │   ├── __init__.py
│   │   └── inference_utils.py
│   ├── main.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── approxNDCG.py
│   │   │   ├── bce.py
│   │   │   ├── binary_listNet.py
│   │   │   ├── lambdaLoss.py
│   │   │   ├── listMLE.py
│   │   │   ├── listNet.py
│   │   │   ├── loss_utils.py
│   │   │   ├── neuralNDCG.py
│   │   │   ├── ordinal.py
│   │   │   ├── pointwise.py
│   │   │   └── rankNet.py
│   │   ├── metrics.py
│   │   ├── model.py
│   │   ├── model_utils.py
│   │   ├── positional.py
│   │   └── transformer.py
│   ├── rank_and_click.py
│   ├── training
│   │   ├── __init__.py
│   │   ├── early_stop.py
│   │   └── train_utils.py
│   └── utils
│       ├── __init__.py
│       ├── args_utils.py
│       ├── command_executor.py
│       ├── config_utils.py
│       ├── experiments.py
│       ├── file_utils.py
│       ├── ltr_logging.py
│       ├── python_utils.py
│       └── tensorboard_utils.py
├── reproducibility
│   ├── HOWTO.md
│   ├── configs
│   │   ├── contextaware_web30k
│   │   │   ├── ndcgloss2pp.json
│   │   │   ├── ndcgloss2pp_mlp.json
│   │   │   ├── ordinal.json
│   │   │   └── ordinal_mlp.json
│   │   └── neuralndcg_web30k
│   │       ├── approxndcg.json
│   │       ├── lambdarank_atmax.json
│   │       └── neuralndcg_atmax.json
│   └── normalize_features.py
├── requirements.txt
├── scripts
│   ├── ci.sh
│   ├── local_config.json
│   ├── local_config_click_model.json
│   ├── run_example.sh
│   ├── run_in_docker.sh
│   ├── run_in_docker_click.sh
│   └── run_tests.sh
├── setup.cfg
├── setup.py
└── tests
    ├── __init__.py
    ├── click_models
    │   ├── __init__.py
    │   ├── test_alternative_click_models.py
    │   ├── test_apply_click_model.py
    │   ├── test_base_cascade_model.py
    │   ├── test_diverse_clicks_model.py
    │   ├── test_duplicate_click_model.py
    │   ├── test_feature_click_model.py
    │   ├── test_fixed_click_model.py
    │   ├── test_masked_click_model.py
    │   └── test_random_click_model.py
    ├── losses
    │   ├── __init__.py
    │   ├── test_approxndcg.py
    │   ├── test_binary_listnet.py
    │   ├── test_lambdaloss.py
    │   ├── test_listmle.py
    │   ├── test_listnet.py
    │   ├── test_loss_ordinal.py
    │   ├── test_loss_pointwise.py
    │   ├── test_mrr.py
    │   ├── test_ndcg.py
    │   ├── test_neuralndcg.py
    │   ├── test_ranknet.py
    │   └── utils.py
    └── test_rank_slates.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | .DS_Store
5 | # C extensions
6 | *.so
7 | 
8 | # Distribution / packaging
9 | .Python
10 | .DS_Store
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | parts/
18 | sdist/
19 | var/
20 | *.egg-info/
21 | .installed.cfg
22 | *.egg
23 | 
24 | # PyInstaller
25 | # Usually these files are written by a python script from a template
26 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
27 | *.manifest 28 | *.spec 29 | 30 | # Installer logs 31 | pip-log.txt 32 | pip-delete-this-directory.txt 33 | 34 | # Unit test / coverage reports 35 | htmlcov/ 36 | .tox/ 37 | .coverage 38 | .cache 39 | nosetests.xml 40 | coverage.xml 41 | .mypy_cache 42 | 43 | # Translations 44 | *.mo 45 | *.pot 46 | 47 | # Django stuff: 48 | *.log 49 | 50 | # Sphinx documentation 51 | docs/_build/ 52 | 53 | *.sqlite3 54 | 55 | # Vagrant 56 | vagrant/.vagrant 57 | vagrant/Vagrantfile.local 58 | 59 | # MKDocs 60 | site/ 61 | 62 | # Static files 63 | bower_components/ 64 | node_modules/ 65 | 66 | # Editors 67 | .idea/ 68 | *~ 69 | 70 | # Project-specific files 71 | model_output 72 | allrank/config.json 73 | task-data 74 | dummy_data 75 | 76 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG arch_version 2 | 3 | FROM python:3.10 as base 4 | 5 | MAINTAINER MLR 6 | 7 | RUN mkdir /allrank 8 | COPY requirements.txt setup.py Makefile README.md /allrank/ 9 | 10 | RUN make -C /allrank install-reqs 11 | 12 | WORKDIR /allrank 13 | 14 | FROM base as CPU 15 | RUN python3 -m pip install torchvision==0.14.1 torch==1.13.1 --extra-index-url https://download.pytorch.org/whl/cpu 16 | 17 | FROM base as GPU 18 | RUN python3 -m pip install torchvision==0.14.1 torch==1.13.1 19 | 20 | FROM ${arch_version} as FINAL 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019 Allegro.pl sp. z o.o. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: ci
2 | ci: lint tests wheel egg
3 | 
4 | .PHONY: lint
5 | lint:
6 | 	flake8 allrank
7 | 	flake8 tests
8 | 	mypy allrank --ignore-missing-imports --check-untyped-defs
9 | 	mypy tests --ignore-missing-imports --check-untyped-defs
10 | 
11 | .PHONY: install-reqs
12 | install-reqs:
13 | 	pip install -r requirements.txt
14 | 	python setup.py install
15 | 
16 | .PHONY: tests
17 | tests: install-reqs unittests
18 | 
19 | .PHONY: unittests
20 | unittests:
21 | 	python -m pytest
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # allRank : Learning to Rank in PyTorch
2 | 
3 | ## About
4 | 
5 | allRank is a PyTorch-based framework for training neural Learning-to-Rank (LTR) models, featuring implementations of:
6 | * common pointwise, pairwise and listwise loss functions
7 | * fully connected and Transformer-like scoring functions
8 | * commonly used evaluation metrics like Normalized Discounted Cumulative Gain (NDCG) and Mean Reciprocal Rank (MRR)
9 | * click-models for experiments on simulated click-through data
10 | 
11 | ### Motivation
12 | 
13 | allRank provides an easy and flexible way to experiment with various LTR neural network models and loss functions.
14 | It is easy to add a custom loss, and to configure the model and the training procedure.
15 | We hope that allRank will facilitate both research in neural LTR and its industrial applications.
16 | 
17 | ## Features
18 | 
19 | ### Implemented loss functions:
20 | 1. ListNet (for binary and graded relevance)
21 | 2. ListMLE
22 | 3. RankNet
23 | 4. Ordinal loss
24 | 5. LambdaRank
25 | 6. LambdaLoss
26 | 7. ApproxNDCG
27 | 8. RMSE
28 | 9. NeuralNDCG (introduced in https://arxiv.org/pdf/2102.07831)
29 | 
30 | ### Getting started guide
31 | 
32 | To help you get started, we provide a ```run_example.sh``` script which generates dummy ranking data in libsvm format and trains
33 | a Transformer model on the data using the provided example ```config.json``` config file. Once you run the script, the dummy data can be found in the `dummy_data` directory
34 | and the results of the experiment in the `test_run` directory. To run the example, Docker is required.
35 | 
36 | ### Getting the right architecture version (GPU vs CPU-only)
37 | 
38 | Since torch binaries are different for GPU and CPU, and the GPU version doesn't work on CPU, one must select and build the appropriate Docker image version.
39 | 
40 | To do so, pass `gpu` or `cpu` as the `arch_version` build-arg in
41 | 
42 | ```docker build --build-arg arch_version=${ARCH_VERSION}```
43 | 
44 | When calling `run_example.sh`, you can select the proper version with the first command-line argument, e.g.
45 | 
46 | ```run_example.sh gpu ...```
47 | 
48 | with `cpu` being the default if not specified.
49 | 
50 | ### Configuring your model & training
51 | 
52 | To train your own model, configure your experiment in the ```config.json``` file and run
53 | 
54 | ```python allrank/main.py --config_file_name allrank/config.json --run_id <the_name_of_your_experiment> --job_dir <the_place_to_save_results>```
55 | 
56 | All the hyperparameters of the training procedure, i.e. model definition, data location, loss and metrics used, training hyperparameters, etc., are controlled
57 | by the ```config.json``` file. We provide a template file ```config_template.json``` where supported attributes, their meaning and possible values are explained.
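
For a quick sanity check, a parsed config can also be loaded and inspected programmatically via `allrank.config.Config` (a minimal sketch, assuming the package is installed and a config file exists at the given path):

```
from allrank.config import Config

config = Config.from_json("allrank/config.json")
print(config.loss.name)        # the configured loss, e.g. a name from the losses package
print(config.training.epochs)  # training hyperparameters are plain attributes
```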
58 | Note that following MSLR-WEB30K convention, your libsvm file with training data should be named `train.txt`. You can specify the name of the validation dataset
59 | (e.g. valid or test) in the config. Results will be saved under the path ```<job_dir>/results/<run_id>```
60 | 
61 | Google Cloud Storage is supported in allRank as a place for data and job results.
62 | 
63 | 
64 | ### Implementing custom loss functions
65 | 
66 | To experiment with your own custom loss, you need to implement a function that takes two tensors (model prediction and ground truth) as input
67 | and put it in the `losses` package, making sure it is exposed on a package level.
68 | To use it in training, simply pass the name (and args, if your loss method has some hyperparameters) of your function in the correct place in the config file:
69 | 
70 | ```
71 | "loss": {
72 |     "name": "yourLoss",
73 |     "args": {
74 |         "arg1": val1,
75 |         "arg2": val2
76 |     }
77 | }
78 | ```
79 | 
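For illustration, a minimal sketch of such a function (the name `yourLoss` and its args are hypothetical, matching the snippet above; real losses in allRank also mask padded documents via `PADDED_Y_VALUE`, omitted here for brevity):

```
import torch

def yourLoss(y_pred: torch.Tensor, y_true: torch.Tensor, arg1: float = 1.0, arg2: float = 0.0) -> torch.Tensor:
    # a toy pointwise squared-error surrogate; arg1/arg2 arrive from the config "args"
    return arg1 * torch.mean((y_pred - y_true) ** 2) + arg2
```
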
80 | ### Applying click-model
81 | 
82 | To apply a click model you need to first have an allRank model trained.
83 | Next, run:
84 | 
85 | ```python allrank/rank_and_click.py --input-model-path <path_to_the_model_weights_file> --roles <comma_separated_list_of_ds_roles_to_process e.g. train,valid> --config_file_name allrank/config.json --run_id <the_name_of_your_experiment> --job_dir <the_place_to_save_results>```
86 | 
87 | The model will be used to rank all slates from the dataset specified in the config. Next, a click model configured in the config will be applied and the resulting click-through dataset will be written under ```<job_dir>/results/<run_id>``` in libSVM format.
88 | The path to the results directory may then be used as an input for another allRank model training.
89 | 
90 | ## Continuous integration
91 | 
92 | You should run `scripts/ci.sh` to verify that the code passes style guidelines and unit tests.
93 | 
94 | ## Research
95 | 
96 | This framework was developed to support the research project [Context-Aware Learning to Rank with Self-Attention](https://arxiv.org/abs/2005.10084). If you use allRank in your research, please cite:
97 | ```
98 | @article{Pobrotyn2020ContextAwareLT,
99 |   title={Context-Aware Learning to Rank with Self-Attention},
100 |   author={Przemyslaw Pobrotyn and Tomasz Bartczak and Mikolaj Synowiec and Radoslaw Bialobrzeski and Jaroslaw Bojar},
101 |   journal={ArXiv},
102 |   year={2020},
103 |   volume={abs/2005.10084}
104 | }
105 | ```
106 | Additionally, if you use the NeuralNDCG loss function, please cite the corresponding work, [NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable Relaxation of Sorting](https://arxiv.org/abs/2102.07831):
107 | ```
108 | @article{Pobrotyn2021NeuralNDCG,
109 |   title={NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable Relaxation of Sorting},
110 |   author={Przemyslaw Pobrotyn and Radoslaw Bialobrzeski},
111 |   journal={ArXiv},
112 |   year={2021},
113 |   volume={abs/2102.07831}
114 | }
115 | ```
116 | 
117 | ## License
118 | 
119 | Apache 2 License
120 | 
--------------------------------------------------------------------------------
/allrank/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allegro/allRank/c88475661cb72db292d13283fdbc4f2ae6498ee4/allrank/__init__.py
--------------------------------------------------------------------------------
/allrank/click_models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allegro/allRank/c88475661cb72db292d13283fdbc4f2ae6498ee4/allrank/click_models/__init__.py
--------------------------------------------------------------------------------
/allrank/click_models/base.py:
--------------------------------------------------------------------------------
1 | import math
2 | from abc import ABC, abstractmethod
3 | from typing import List, Tuple, Callable, Optional
4 | 
5 | import numpy as np
6 | import torch
7 | 
8 | 
9 | class ClickModel(ABC):
10 |     """
11 |     Base class for all click models. Specifies the click model contract.
12 |     """
13 | 
14 |     @abstractmethod
15 |     def click(self, documents: Tuple[torch.Tensor, torch.Tensor]) -> np.ndarray:
16 |         """
17 |         Applies a click model and returns the click mask for documents.
18 | 
19 |         :rtype: np.ndarray [ number_of_documents ] -> a mask of the same length as the documents -
20 |             defining whether a document was clicked (1), not clicked (0) or is a padded element (-1)
21 | 
22 |         :param documents: Tuple of :
23 |             torch.Tensor [ number_of_documents, dimensionality_of_latent_vector ], representing features of documents
24 |             torch.Tensor [ number_of_documents ] representing relevancy of documents
25 |         """
26 |         pass
27 | 
28 | 
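# --- Editor's sketch (illustrative; not part of the original base.py). The
# contract above is satisfied by implementing `click` alone; e.g. a hypothetical
# model that clicks every document could look like this:
class ClickEverythingModel(ClickModel):
    def click(self, documents: Tuple[torch.Tensor, torch.Tensor]) -> np.ndarray:
        X, y = documents
        return np.ones(len(y), dtype=bool)  # click mask: every document clicked
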
29 | class RandomClickModel(ClickModel):
30 |     """
31 |     This ClickModel clicks a configured number of times on random documents
32 |     """
33 | 
34 |     def __init__(self, n_clicks: int):
35 |         """
36 | 
37 |         :param n_clicks: number of documents that will be clicked
38 |         """
39 |         self.n_clicks = n_clicks
40 | 
41 |     def click(self, documents: Tuple[torch.Tensor, torch.Tensor]) -> np.ndarray:
42 |         X, y = documents
43 |         clicks = np.random.choice(range(len(y)), size=self.n_clicks, replace=False)
44 |         mask = np.zeros(len(y), dtype=bool)
45 |         mask[clicks] = 1
46 |         return mask
47 | 
48 | 
49 | class FixedClickModel(ClickModel):
50 |     """
51 |     This ClickModel clicks on documents at fixed positions
52 |     """
53 | 
54 |     def __init__(self, click_positions: List[int]):
55 |         """
56 | 
57 |         :param click_positions: list of indices of documents that will be clicked
58 |         """
59 |         self.click_positions = click_positions
60 | 
61 |     def click(self, documents: Tuple[torch.Tensor, torch.Tensor]) -> np.ndarray:
62 |         X, y = documents
63 |         clicks = np.zeros(len(y), dtype=bool)
64 |         clicks[self.click_positions] = 1
65 |         return clicks
66 | 
67 | 
68 | class MultipleClickModel(ClickModel):
69 |     """
70 |     This click model uses one of the given click models with the given probability
71 |     """
72 | 
73 |     def __init__(self, inner_click_models: List[ClickModel], probabilities: List[float]):
74 |         """
75 | 
76 |         :param inner_click_models: list of click models to choose from
77 |         :param probabilities: list of probabilities - must be of the same length as the list of click models and sum to 1.0
78 |         """
79 |         self.inner_click_models = inner_click_models
80 |         assert math.isclose(np.sum(probabilities), 1.0, abs_tol=1e-5), \
81 |             f"probabilities should sum to one, but got {probabilities} which sums to {np.sum(probabilities)}"
82 |         self.probabilities = np.array(probabilities).cumsum()
83 | 
84 |     def click(self, documents: Tuple[torch.Tensor, torch.Tensor]) -> np.ndarray:
85 |         index = np.argmax(np.random.rand() < self.probabilities)
86 |         result = self.inner_click_models[index].click(documents)  # type: ignore
87 |         return result
88 | 
89 | 
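# --- Editor's sketch (illustrative; not part of the original base.py): composing
# the models above, e.g. clicking the top two positions on ~70% of slates and
# two random documents otherwise:
def _example_mixed_click_model() -> ClickModel:
    return MultipleClickModel(
        [FixedClickModel(click_positions=[0, 1]), RandomClickModel(n_clicks=2)],
        probabilities=[0.7, 0.3],
    )
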
90 | class ConditionedClickModel(ClickModel):
91 |     """
92 |     This click model allows combining multiple click models with a logical function
93 |     """
94 | 
95 |     def __init__(self, inner_click_models: List[ClickModel], combiner: Callable):
96 |         """
97 | 
98 |         :param inner_click_models: list of click models to combine
99 |         :param combiner: a function applied to the result of clicks from click models - e.g. np.all or np.any
100 |         """
101 |         self.inner_click_models = inner_click_models
102 |         self.combiner = combiner
103 | 
104 |     def click(self, documents: Tuple[torch.Tensor, torch.Tensor]) -> np.ndarray:
105 |         clicks_from_click_models = [click_model.click(documents) for click_model in self.inner_click_models]
106 |         return self.combiner(clicks_from_click_models, 0)
107 | 
108 | 
109 | class MaxClicksModel(ClickModel):
110 |     """
111 |     This click model takes another click model and limits the number of clicks to a given value,
112 |     effectively keeping the top `max_clicks` clicks
113 |     """
114 | 
115 |     def __init__(self, inner_click_model: ClickModel, max_clicks: Optional[int]):
116 |         """
117 | 
118 |         :param inner_click_model: a click model to generate clicks
119 |         :param max_clicks: number of clicks that should be preserved
120 |         """
121 |         self.inner_click_model = inner_click_model
122 |         self.max_clicks = max_clicks
123 | 
124 |     def click(self, documents: Tuple[torch.Tensor, torch.Tensor]) -> np.ndarray:
125 |         underlying_clicks = self.inner_click_model.click(documents)
126 |         if self.max_clicks is not None:
127 |             max_clicks_mask = underlying_clicks.cumsum() <= self.max_clicks
128 |             return underlying_clicks * max_clicks_mask
129 |         return underlying_clicks
130 | 
131 | 
132 | class OnlyRelevantClickModel(ClickModel):
133 |     """
134 |     This ClickModel clicks on a document when its relevancy is greater than or equal to a predefined value
135 | 
136 |     """
137 | 
138 |     def __init__(self, relevancy_threshold: float):
139 |         """
140 |         :param relevancy_threshold: a minimum value of relevancy of a document to be clicked (inclusive)
141 |         """
142 |         self.relevancy_threshold = relevancy_threshold
143 | 
144 |     def click(self, documents: Tuple[torch.Tensor, torch.Tensor]) -> np.ndarray:
145 |         X, y = documents
146 |         return np.array(y) >= self.relevancy_threshold
147 | 
--------------------------------------------------------------------------------
/allrank/click_models/cascade_models.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | 
3 | import numpy as np
4 | import torch
5 | from scipy.spatial.distance import cdist
6 | 
7 | from allrank.click_models.base import ClickModel
8 | from allrank.click_models.duplicate_aware import EverythingButDuplicatesClickModel
9 | from allrank.data.dataset_loading import PADDED_Y_VALUE
10 | 
11 | 
12 | class BaseCascadeModel(ClickModel):
13 |     """
14 |     This ClickModel simulates a decaying probability of observing an item
15 |     and clicks on an observed item given its relevance is greater than or equal to a given threshold
16 | 
17 |     """
18 | 
19 |     def __init__(self, eta: float, threshold: float):
20 |         """
21 | 
22 |         :param eta: the power to be applied over a result of a decay function (specified as 1/position)
23 |             to decide whether a document was observed
24 |         :param threshold: a minimum value of relevancy of an observed document to be clicked (inclusive)
25 |         """
26 |         self.eta = eta
27 |         self.threshold = threshold
28 | 
29 |     def click(self, documents: Tuple[torch.Tensor, torch.Tensor]) -> np.ndarray:
30 |         X, y = documents
31 |         observed_mask = (1 / np.arange(1, len(y) + 1) ** self.eta) >= np.random.rand(len(y))
32 |         return (y * observed_mask >= self.threshold).numpy()
33 | 
34 | 
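# --- Editor's sketch (illustrative; not part of the original file): with eta=1.0
# the observation probability decays as 1/position (1.0, 0.5, 0.33, ...), and an
# observed document is clicked when its relevance reaches the threshold:
def _example_cascade_clicks() -> np.ndarray:
    model = BaseCascadeModel(eta=1.0, threshold=1.0)
    X = torch.rand(5, 8)                         # 5 documents, 8 features each
    y = torch.tensor([2.0, 0.0, 1.0, 2.0, 0.0])  # graded relevance labels
    return model.click((X, y))                   # boolean click mask per document
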
35 | class DiverseClicksModel(ClickModel):
36 |     """
37 |     A 'diverse-clicks' model from the Seq2Slate paper https://arxiv.org/abs/1810.02019
38 |     It clicks on documents from top to bottom if:
39 |     1. a delegate click model decides to click on the document (in the original paper - CascadeModel)
40 |     2. it is no closer than a defined percentile of distances to a previously clicked document
41 |     """
42 | 
43 |     def __init__(self, inner_click_model, q_percentile=0.5):
44 |         """
45 | 
46 |         :param inner_click_model: original, non-diversified click model
47 |         :param q_percentile: a percentile of pairwise distances that will be used as a distance threshold to tell if a pair is a duplicate
48 |         """
49 |         self.inner_click_model = inner_click_model
50 |         self.q_percentile = q_percentile
51 | 
52 |     def __pairwise_distances_list(self, X):
53 |         dist = cdist(X, X, metric='euclidean')
54 |         triu_indices = np.triu_indices(dist.shape[0] - 1)
55 |         return dist[:-1, 1:][triu_indices]
56 | 
57 |     def click(self, documents: Tuple[torch.Tensor, torch.Tensor]) -> np.ndarray:
58 |         X, y = documents
59 | 
60 |         real_docs_mask = (y != PADDED_Y_VALUE)
61 |         real_X = X[real_docs_mask, :]
62 | 
63 |         distances = self.__pairwise_distances_list(real_X)
64 |         if len(distances) == 0:
65 |             duplicate_margin = 0
66 |         else:
67 |             duplicate_margin = np.quantile(distances, q=self.q_percentile)
68 | 
69 |         def not_similar(x_vec, clicked_X):
70 |             cX = clicked_X.copy()
71 |             cX.append(x_vec)
72 |             cX = torch.stack(cX, dim=0)
73 |             cm = EverythingButDuplicatesClickModel(duplicate_margin)
74 |             clicks = cm.click((cX, np.ones(len(cX))))
75 |             last_element_clicked = clicks[-1]
76 |             return last_element_clicked == 1
77 | 
78 |         relevant_for_click = self.inner_click_model.click(documents)
79 | 
80 |         clicked_Xs = []  # type: ignore
81 |         indices_to_click = np.argwhere(relevant_for_click == 1)
82 |         for idx_to_click in indices_to_click:
83 |             idx_to_click = idx_to_click[0]
84 |             X_to_click = X[idx_to_click]
85 |             if not_similar(X_to_click, clicked_Xs):
86 |                 clicked_Xs.append(X_to_click)
87 |             else:
88 |                 relevant_for_click[idx_to_click] = 0
89 | 
90 |         return relevant_for_click
91 | 
--------------------------------------------------------------------------------
/allrank/click_models/click_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Union
2 | 
3 | import numpy as np
4 | import torch
5 | 
6 | from allrank.click_models.base import ClickModel
7 | from allrank.data.dataset_loading import PADDED_Y_VALUE
8 | 
9 | 
10 | def click_on_slates(slates: Union[Tuple[np.ndarray, np.ndarray], Tuple[torch.Tensor, torch.Tensor]],
11 |                     click_model: ClickModel, include_empty: bool) -> Tuple[List[Union[np.ndarray, torch.Tensor]], List[List[int]]]:
12 |     """
13 |     This method runs a click model on a list of slates and returns new slates with `y` taken from clicks
14 | 
15 |     :param slates: a Tuple of X, y:
16 |         X being a list of slates represented by document vectors
17 |         y being a list of slates represented by document relevancies
18 |     :param click_model: a click model to be applied to every slate
19 |     :param include_empty: if True - will return even slates that didn't get any click
20 |     :return: Tuple of X, clicks: X representing the same document vectors as the input X, clicks representing the click mask for every slate
21 |     """
22 |     X, y = slates
23 |     clicks = [MaskedRemainMasked(click_model).click(slate) for slate in zip(X, y)]
24 |     X_with_clicks = [[X, slate_clicks] for X, slate_clicks in list(zip(X, clicks)) if
25 |                      (np.sum(slate_clicks > 0) > 0 or include_empty)]
26 |     return_X, clicks = map(list, zip(*X_with_clicks))  # type: ignore
27 |     return return_X, clicks  # type: ignore
28 | 
29 | 
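# --- Editor's sketch (illustrative; not part of the original file): clicking two
# toy slates with a relevance-threshold model; slates without clicks are dropped.
def _example_click_on_slates():
    from allrank.click_models.base import OnlyRelevantClickModel  # sketch-only import
    X = np.random.rand(2, 3, 8)                       # 2 slates, 3 docs, 8 features
    y = np.array([[2.0, 0.0, 1.0], [0.0, 0.0, 2.0]])  # graded relevance per slate
    return click_on_slates((X, y), OnlyRelevantClickModel(2.0), include_empty=False)
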
30 | class MaskedRemainMasked(ClickModel):
31 |     """
32 |     This click model wraps another click model and:
33 |     1. ensures the inner click model does not get documents that were padded
34 |     2. ensures padded documents get '-1' in the 'clicked' vector
35 |     """
36 | 
37 |     def __init__(self, inner_click_model: ClickModel):
38 |         """
39 | 
40 |         :param inner_click_model: a click model that is run on the list of non-padded documents
41 |         """
42 |         self.inner_click_model = inner_click_model
43 | 
44 |     def click(self, documents: Union[Tuple[np.ndarray, np.ndarray], Tuple[torch.Tensor, torch.Tensor]]) -> np.ndarray:
45 |         X, y = documents
46 |         padded_values_mask = y == PADDED_Y_VALUE
47 |         real_X = X[~padded_values_mask]
48 |         real_y = y[~padded_values_mask]
49 |         clicks = self.inner_click_model.click((real_X, real_y))
50 |         final_clicks = np.zeros_like(y)
51 |         final_clicks[padded_values_mask] = PADDED_Y_VALUE
52 |         final_clicks[~padded_values_mask] = clicks
53 |         return final_clicks
54 | 
--------------------------------------------------------------------------------
/allrank/click_models/duplicate_aware.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union
2 | 
3 | import numpy as np
4 | import torch
5 | from scipy.spatial.distance import cdist
6 | 
7 | from allrank.click_models.base import ClickModel
8 | 
9 | 
10 | class EverythingButDuplicatesClickModel(ClickModel):
11 |     """
12 |     This ClickModel clicks on every document that was not previously clicked,
13 |     if its distance to every previous document is larger than a given margin in a given metric
14 |     """
15 | 
16 |     def __init__(self, duplicate_margin: float = 0, metric: str = "euclidean"):
17 |         """
18 | 
19 |         :param duplicate_margin: a margin to tell whether a pair of documents is treated as a duplicate.
20 |             If the distance is less than or equal to this value, the pair is marked a duplicate
21 |         :param metric: a metric in which pairwise distances are calculated
22 |             (metric must be supported by `scipy.spatial.distance.cdist`)
23 |         """
24 |         self.duplicate_margin = duplicate_margin
25 |         self.metric = metric
26 | 
27 |     def click(self, documents: Tuple[torch.Tensor, Union[torch.Tensor, np.ndarray]]) -> np.ndarray:
28 |         X, y = documents
29 |         dist = cdist(X, X, metric=self.metric)
30 |         dist = np.triu(dist, k=1)
31 |         np.fill_diagonal(dist, np.inf)
32 |         indices = np.tril_indices(dist.shape[0])
33 |         dist[indices] = np.inf
34 |         return 1 * (dist > self.duplicate_margin).min(0)
35 | 
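
# --- Editor's sketch (illustrative; not part of the original file): documents 0
# and 1 below are identical, so only the first of the pair gets clicked.
def _example_duplicate_aware_clicks() -> np.ndarray:
    model = EverythingButDuplicatesClickModel(duplicate_margin=0.1)
    X = torch.tensor([[0.0, 0.0], [0.0, 0.0], [1.0, 1.0]])  # doc 1 duplicates doc 0
    y = np.ones(3)
    return model.click((X, y))  # expected: array([1, 0, 1])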
--------------------------------------------------------------------------------
/allrank/config.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 | from typing import Dict, List, Optional
4 | 
5 | from attr import attrib, attrs
6 | 
7 | 
8 | @attrs
9 | class TransformerConfig:
10 |     N = attrib(type=int)
11 |     d_ff = attrib(type=int)
12 |     h = attrib(type=int)
13 |     positional_encoding = attrib(type=dict)
14 |     dropout = attrib(type=float)
15 | 
16 | 
17 | @attrs
18 | class FCConfig:
19 |     sizes = attrib(type=List[int])
20 |     input_norm = attrib(type=bool)
21 |     activation = attrib(type=str)
22 |     dropout = attrib(type=float)
23 | 
24 | 
25 | @attrs
26 | class PostModelConfig:
27 |     d_output = attrib(type=int)
28 |     output_activation = attrib(type=str)
29 | 
30 | 
31 | @attrs
32 | class ModelConfig:
33 |     fc_model = attrib(type=FCConfig)
34 |     transformer = attrib(type=TransformerConfig)
35 |     post_model = attrib(type=PostModelConfig)
36 | 
37 | 
38 | @attrs
39 | class PositionalEncoding:
40 |     strategy = attrib(type=str)
41 |     max_indices = attrib(type=int)
42 | 
43 | 
44 | @attrs
45 | class DataConfig:
46 |     path = attrib(type=str)
47 |     num_workers = attrib(type=int)
48 |     batch_size = attrib(type=int)
49 |     slate_length = attrib(type=int)
50 |     validation_ds_role = attrib(type=str)
51 | 
52 | 
53 | @attrs
54 | class TrainingConfig:
55 |     epochs = attrib(type=int)
56 |     gradient_clipping_norm = attrib(type=float)
57 |     early_stopping_patience = attrib(type=int, default=0)
58 | 
59 | 
60 | @attrs
61 | class NameArgsConfig:
62 |     name = attrib(type=str)
63 |     args = attrib(type=dict)
64 | 
65 | 
66 | @attrs
67 | class Config:
68 |     model = attrib(type=ModelConfig)
69 |     data = attrib(type=DataConfig)
70 |     optimizer = attrib(type=NameArgsConfig)
71 |     training = attrib(type=TrainingConfig)
72 |     loss = attrib(type=NameArgsConfig)
73 |     metrics = attrib(type=Dict[str, List[int]])
74 |     lr_scheduler = attrib(type=NameArgsConfig)
75 |     val_metric = attrib(type=str, default=None)
76 |     expected_metrics = attrib(type=Dict[str, Dict[str, float]], default={})
77 |     detect_anomaly = attrib(type=bool, default=False)
78 |     click_model = attrib(type=Optional[NameArgsConfig], default=None)
79 | 
80 |     @classmethod
81 |     def from_json(cls, config_path):
82 |         with open(config_path) as config_file:
83 |             config = json.load(config_file)
84 |             return Config.from_dict(config)
85 | 
86 |     @classmethod
87 |     def from_dict(cls, config):
88 |         config["model"] = ModelConfig(**config["model"])
89 |         if config["model"].transformer:
90 |             config["model"].transformer = TransformerConfig(**config["model"].transformer)
91 |             if config["model"].transformer.positional_encoding:
92 |                 config["model"].transformer.positional_encoding = PositionalEncoding(
93 |                     **config["model"].transformer.positional_encoding)
94 |         config["data"] = DataConfig(**config["data"])
95 |         config["optimizer"] = NameArgsConfig(**config["optimizer"])
96 |         config["training"] = TrainingConfig(**config["training"])
97 |         config["metrics"] = cls._parse_metrics(config["metrics"])
98 |         config["lr_scheduler"] = NameArgsConfig(**config["lr_scheduler"])
99 |         config["loss"] = NameArgsConfig(**config["loss"])
100 |         if "click_model" in config.keys():
101 |             config["click_model"] = NameArgsConfig(**config["click_model"])
102 |         return cls(**config)
103 | 
104 |     @staticmethod
105 |     def _parse_metrics(metrics):
106 |         metrics_dict = defaultdict(list)  # type: Dict[str, list]
107 |         for metric_string in metrics:
108 |             try:
109 |                 name, at = metric_string.split("_")
110 |                 metrics_dict[name].append(int(at))
111 |             except (ValueError, TypeError):
112 |                 raise MetricConfigError(
113 |                     metric_string,
114 |                     "Wrong formatting of metric in config. Expected format: <name>_<at> where <name> is a valid metric name and <at> is an int")
115 |         return metrics_dict
116 | 
117 | 
118 | class MetricConfigError(Exception):
119 |     pass
120 | 
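# --- Editor's sketch (illustrative; not part of the original file): metric
# strings of the form <name>_<at> are grouped by name.
def _example_parse_metrics() -> dict:
    # e.g. {"ndcg": [5, 10], "mrr": [10]} (as a defaultdict)
    return Config._parse_metrics(["ndcg_5", "ndcg_10", "mrr_10"])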
DataConfig(**config["data"]) 95 | config["optimizer"] = NameArgsConfig(**config["optimizer"]) 96 | config["training"] = TrainingConfig(**config["training"]) 97 | config["metrics"] = cls._parse_metrics(config["metrics"]) 98 | config["lr_scheduler"] = NameArgsConfig(**config["lr_scheduler"]) 99 | config["loss"] = NameArgsConfig(**config["loss"]) 100 | if "click_model" in config.keys(): 101 | config["click_model"] = NameArgsConfig(**config["click_model"]) 102 | return cls(**config) 103 | 104 | @staticmethod 105 | def _parse_metrics(metrics): 106 | metrics_dict = defaultdict(list) # type: Dict[str, list] 107 | for metric_string in metrics: 108 | try: 109 | name, at = metric_string.split("_") 110 | metrics_dict[name].append(int(at)) 111 | except (ValueError, TypeError): 112 | raise MetricConfigError( 113 | metric_string, 114 | "Wrong formatting of metric in config. Expected format: _ where name is valid metric name and at is and int") 115 | return metrics_dict 116 | 117 | 118 | class MetricConfigError(Exception): 119 | pass 120 | -------------------------------------------------------------------------------- /allrank/config_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": 3 | { 4 | "fc_model": 5 | { 6 | "sizes":