├── .cursorrules ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ └── python-publish.yml ├── .gitignore ├── .readthedocs.yaml ├── CITATION.cff ├── README-PyPI.md ├── README.md ├── README_zh.md ├── docs ├── .gitignore ├── docs │ ├── api │ │ ├── dataset.md │ │ ├── generators │ │ │ ├── generator.md │ │ │ ├── models │ │ │ │ ├── base_rag_model.md │ │ │ │ ├── fid_model.md │ │ │ │ ├── huggingface_model.md │ │ │ │ ├── litellm_model.md │ │ │ │ ├── model_factory.md │ │ │ │ ├── openai_model.md │ │ │ │ └── vllm_model.md │ │ │ ├── prompt_generator.md │ │ │ └── rag_methods │ │ │ │ ├── basic_rag.md │ │ │ │ ├── basic_rag_method.md │ │ │ │ ├── chain_of_thought_rag.md │ │ │ │ ├── fid_rag_method.md │ │ │ │ ├── in_context_ralm_rag.md │ │ │ │ └── zero_shot.md │ │ ├── index.md │ │ ├── metrics.md │ │ ├── rerankings │ │ │ ├── base.md │ │ │ ├── blender.md │ │ │ ├── colbert_ranker.md │ │ │ ├── echo_rank.md │ │ │ ├── first_ranker.md │ │ │ ├── flashrank.md │ │ │ ├── incontext_reranker.md │ │ │ ├── inrank.md │ │ │ ├── listt5.md │ │ │ ├── lit5.md │ │ │ ├── llm2vec_reranker.md │ │ │ ├── llm_layerwise_ranker.md │ │ │ ├── monobert.md │ │ │ ├── monot5.md │ │ │ ├── rank_fid.md │ │ │ ├── rankgpt.md │ │ │ ├── rankt5.md │ │ │ ├── reranking.md │ │ │ ├── sentence_transformer_reranker.md │ │ │ ├── splade_reranker.md │ │ │ ├── transformer_reranker.md │ │ │ ├── twolar.md │ │ │ ├── upr.md │ │ │ ├── vicuna_reranker.md │ │ │ └── zephyr_reranker.md │ │ └── retrievers │ │ │ ├── bge.md │ │ │ ├── bm25.md │ │ │ ├── colbert.md │ │ │ ├── contriever.md │ │ │ ├── dense.md │ │ │ └── retriever.md │ ├── assets │ │ ├── overview.jpg │ │ └── rankify-crop.png │ ├── contribution.md │ ├── getting-started.md │ ├── index.md │ ├── installation.md │ ├── js │ │ ├── clustrmaps.js │ │ └── runllm-widget.js │ ├── stylesheets │ │ └── extra.css │ └── tutorials │ │ └── index.md ├── mkdocs.yml ├── overrides │ ├── home.html │ ├── main.html │ └── partials │ │ ├── footer.html │ │ └── tabs.html └── 
requirements.txt ├── examples ├── demo.py ├── generator_fid.py ├── generator_huggingface.py ├── generator_litellm.py ├── generator_openai.py ├── generator_ralm.py ├── generator_vllm.py ├── reranking.py ├── reranking_example.py ├── retreiver.py └── retrieved_dataset.py ├── images ├── output.gif ├── overview.png ├── rankify-crop.png ├── rankify-logo-.png └── rankify-logo.png ├── pyproject.toml └── rankify ├── __init__.py ├── __version__.py ├── agent └── __init__.py ├── dataset ├── __init__.py └── dataset.py ├── generator ├── generator.py ├── models │ ├── base_rag_model.py │ ├── fid_model.py │ ├── huggingface_model.py │ ├── litellm_model.py │ ├── model_factory.py │ ├── openai_model.py │ └── vllm_model.py ├── prompt_generator.py └── rag_methods │ ├── base_rag_method.py │ ├── basic_rag.py │ ├── chain_of_thought_rag.py │ ├── fid_rag_method.py │ ├── in_context_ralm_rag.py │ └── zero_shot.py ├── indexing └── __init__.py ├── metrics ├── __init__.py └── metrics.py ├── models ├── __init__.py ├── apiranker.py ├── base.py ├── blender_reranker.py ├── colbert_ranker.py ├── echorank.py ├── first_reranker.py ├── flashrank.py ├── incontext_reranker.py ├── inranker.py ├── listt5.py ├── lit5_reranker.py ├── llm2vec_reranker.py ├── llm_layerwise_ranker.py ├── monobert.py ├── monot5.py ├── monot5_.py ├── rank_fid.py ├── rankgpt.py ├── rankt5.py ├── reranking.py ├── sentence_transformer_reranker.py ├── splade_reranker.py ├── transformer_ranker.py ├── twolar.py ├── upr.py ├── vicuna_reranker.py └── zephyr_reranker.py ├── requirements.txt ├── retrievers ├── BGERetriever.py ├── OnlineRetriever.py ├── __init__.py ├── bm25.py ├── colbert.py ├── contriever.py ├── dpr.py ├── hyde.py ├── retriever.py └── wikipedia_cleaner.py ├── tools ├── .env.example ├── Readme.md ├── Tools.py ├── __init__.py └── websearch │ ├── __init__.py │ ├── content_scraping │ ├── config.py │ ├── crawler.py │ ├── scrapedResult.py │ ├── strategyFactory.py │ └── utils.py │ ├── context │ ├── __init__.py │ └── 
build_search_context.py │ ├── models │ ├── SerpResults.py │ └── Source.py │ └── serp │ ├── SearchAPIClient.py │ ├── SerperApiClient.py │ ├── config.py │ └── errors.py └── utils ├── __init__.py ├── api ├── __init__.py ├── claudeclient.py ├── litellmclient.py └── openaiclient.py ├── dataset ├── __init__.py ├── download.py └── utils.py ├── generator ├── FiD │ ├── __init__.py │ ├── data.py │ ├── model.py │ └── util.py ├── __init__.py ├── download.py ├── generator_models.py └── huggingface_models │ ├── __init__.py │ └── model_utils.py ├── helper.py ├── models ├── __init__.py ├── colbert.py ├── fidt5.py ├── incontext_reranker │ ├── __init__.py │ ├── custom_cache.py │ ├── custom_modeling_llama.py │ └── custom_modeling_mistral.py ├── llm2vec.py ├── llm2vec_model │ ├── __init__.py │ ├── attn_mask_utils.py │ ├── bidirectional_gemma.py │ ├── bidirectional_llama.py │ ├── bidirectional_mistral.py │ ├── bidirectional_qwen2.py │ └── utils.py ├── rank_listwise_os_llm.py ├── rank_llm │ ├── __init__.py │ ├── data.py │ └── rerank │ │ ├── __init__.py │ │ ├── api_keys.py │ │ ├── identity_reranker.py │ │ ├── listwise │ │ ├── __init__.py │ │ ├── listwise_rankllm.py │ │ ├── lit5 │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── modeling_t5.py │ │ └── rank_listwise_os_llm.py │ │ ├── rankllm.py │ │ └── reranker.py └── twolar_utils.py ├── pre_defind_models.py ├── pre_defined_datasets.py ├── pre_defined_methods.py ├── pre_defined_methods_retrievers.py └── retrievers ├── __init__.py ├── colbert ├── __init__.py ├── colbert │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── collection.py │ │ ├── dataset.py │ │ ├── examples.py │ │ ├── queries.py │ │ └── ranking.py │ ├── distillation │ │ ├── __init__.py │ │ ├── ranking_scorer.py │ │ └── scorer.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── load_model.py │ │ ├── loaders.py │ │ └── metrics.py │ ├── index.py │ ├── index_updater.py │ ├── indexer.py │ ├── indexing │ │ ├── __init__.py │ │ ├── codecs │ │ │ ├── __init__.py │ │ │ ├── 
decompress_residuals.cpp │ │ │ ├── decompress_residuals.cu │ │ │ ├── packbits.cpp │ │ │ ├── packbits.cu │ │ │ ├── residual.py │ │ │ ├── residual_embeddings.py │ │ │ └── residual_embeddings_strided.py │ │ ├── collection_encoder.py │ │ ├── collection_indexer.py │ │ ├── index_manager.py │ │ ├── index_saver.py │ │ ├── loaders.py │ │ └── utils.py │ ├── infra │ │ ├── __init__.py │ │ ├── config │ │ │ ├── __init__.py │ │ │ ├── base_config.py │ │ │ ├── config.py │ │ │ ├── core_config.py │ │ │ └── settings.py │ │ ├── launcher.py │ │ ├── provenance.py │ │ ├── run.py │ │ └── utilities │ │ │ ├── __init__.py │ │ │ ├── annotate_em.py │ │ │ ├── create_triples.py │ │ │ └── minicorpus.py │ ├── modeling │ │ ├── __init__.py │ │ ├── base_colbert.py │ │ ├── checkpoint.py │ │ ├── colbert.py │ │ ├── hf_colbert.py │ │ ├── reranker │ │ │ ├── __init__.py │ │ │ ├── electra.py │ │ │ └── tokenizer.py │ │ ├── segmented_maxsim.cpp │ │ └── tokenization │ │ │ ├── __init__.py │ │ │ ├── doc_tokenization.py │ │ │ ├── query_tokenization.py │ │ │ └── utils.py │ ├── parameters.py │ ├── ranking │ │ └── __init__.py │ ├── search │ │ ├── __init__.py │ │ ├── candidate_generation.py │ │ ├── decompress_residuals.cpp │ │ ├── filter_pids.cpp │ │ ├── index_loader.py │ │ ├── index_storage.py │ │ ├── segmented_lookup.cpp │ │ ├── strided_tensor.py │ │ └── strided_tensor_core.py │ ├── searcher.py │ ├── tests │ │ ├── __init__.py │ │ ├── e2e_test.py │ │ ├── index_coalesce_test.py │ │ ├── index_updater_test.py │ │ └── tokenizers_test.py │ ├── trainer.py │ ├── training │ │ ├── __init__.py │ │ ├── eager_batcher.py │ │ ├── lazy_batcher.py │ │ ├── rerank_batcher.py │ │ ├── training.py │ │ └── utils.py │ ├── utilities │ │ ├── __init__.py │ │ ├── annotate_em.py │ │ ├── create_triples.py │ │ └── minicorpus.py │ └── utils │ │ ├── __init__.py │ │ ├── amp.py │ │ ├── coalesce.py │ │ ├── distributed.py │ │ ├── logging.py │ │ ├── parser.py │ │ ├── runs.py │ │ └── utils.py └── utility │ ├── __init__.py │ ├── evaluate │ ├── __init__.py 
│ ├── annotate_EM.py │ ├── annotate_EM_helpers.py │ ├── evaluate_lotte_rankings.py │ └── msmarco_passages.py │ ├── preprocess │ ├── __init__.py │ ├── docs2passages.py │ └── queries_split.py │ ├── rankings │ ├── __init__.py │ ├── dev_subsample.py │ ├── merge.py │ ├── split_by_offset.py │ ├── split_by_queries.py │ └── tune.py │ ├── supervision │ ├── __init__.py │ ├── self_training.py │ └── triples.py │ └── utils │ ├── __init__.py │ ├── dpr.py │ ├── qa_loaders.py │ └── save_metadata.py ├── contriever ├── __init__.py ├── contriever.py ├── data.py ├── dist_utils.py ├── index.py ├── normalize_text.py └── utils.py ├── hyde ├── __init__.py ├── generator.py └── promptor.py └── splade ├── __init__.py ├── datasets ├── __init__.py ├── dataloaders.py ├── datasets.py └── rerank.py ├── indexing ├── __init__.py └── inverted_index.py ├── losses ├── __init__.py ├── pairwise.py ├── pointwise.py └── regularization.py ├── models ├── __init__.py ├── models_utils.py ├── transformer_rank.py └── transformer_rep.py ├── tasks ├── __init__.py ├── amp.py ├── base │ ├── __init__.py │ ├── early_stopping.py │ ├── evaluator.py │ ├── saver.py │ └── trainer.py ├── transformer_evaluator.py └── transformer_trainer.py └── utils ├── __init__.py ├── hydra.py ├── index_figure.py ├── metrics.py ├── processing_trec_eval.py └── utils.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 
22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | deploy: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Checkout code 16 | uses: actions/checkout@v4 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v3 20 | with: 21 | python-version: '3.10' 22 | 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install --upgrade setuptools setuptools_scm[toml] build 27 | 28 | - name: Build package 29 | run: python -m build 30 | 31 | - name: Publish package 32 | uses: pypa/gh-action-pypi-publish@release/v1 33 | with: 34 | user: __token__ 35 | password: ${{ secrets.PYPI_API_TOKEN }} 36 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.10" 7 | 8 | python: 9 | install: 10 | - requirements: docs/requirements.txt # Use the lightweight docs requirements 11 | 12 | 13 | mkdocs: 14 | configuration: docs/mkdocs.yml -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | date-released: 2025-02 3 | message: "If you use this software, please cite it as below." 
4 | authors: 5 | - family-names: "Abdallah" 6 | given-names: "Abdelrahman" 7 | - family-names: "Mozafari" 8 | given-names: "Jamshid" 9 | - family-names: "Piryani" 10 | given-names: "Bhawna" 11 | - family-names: "Ali" 12 | given-names: "Mohammed" 13 | - family-names: "Jatowt" 14 | given-names: "Adam" 15 | title: "Rankify: A Comprehensive Python Toolkit for Retrieval, Re-Ranking, and Retrieval-Augmented Generation" 16 | url: "https://arxiv.org/abs/2502.02464" 17 | preferred-citation: 18 | type: article 19 | authors: 20 | - family-names: "Abdallah" 21 | given-names: "Abdelrahman" 22 | - family-names: "Mozafari" 23 | given-names: "Jamshid" 24 | - family-names: "Piryani" 25 | given-names: "Bhawna" 26 | - family-names: "Ali" 27 | given-names: "Mohammed" 28 | - family-names: "Jatowt" 29 | given-names: "Adam" 30 | title: "Rankify: A Comprehensive Python Toolkit for Retrieval, Re-Ranking, and Retrieval-Augmented Generation" 31 | journal: "CoRR" 32 | volume: "abs/2502.02464" 33 | year: 2025 34 | url: "https://arxiv.org/abs/2502.02464" 35 | eprinttype: "arXiv" 36 | eprint: "2502.02464" -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | site 2 | .cache -------------------------------------------------------------------------------- /docs/docs/api/dataset.md: -------------------------------------------------------------------------------- 1 | # Dataset 2 | 3 | ::: rankify.dataset.dataset 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/generators/generator.md: 
-------------------------------------------------------------------------------- 1 | # Generator 2 | 3 | ::: rankify.generator.generator 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/generators/models/base_rag_model.md: -------------------------------------------------------------------------------- 1 | # Base RAG Model 2 | 3 | ::: rankify.generator.models.base_rag_model 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/generators/models/fid_model.md: -------------------------------------------------------------------------------- 1 | # FiD Model 2 | 3 | ::: rankify.generator.models.fid_model 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/generators/models/huggingface_model.md: -------------------------------------------------------------------------------- 1 | # HuggingFace Model 2 | 3 | ::: rankify.generator.models.huggingface_model 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | 
show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/generators/models/litellm_model.md: -------------------------------------------------------------------------------- 1 | # LiteLLM Model 2 | 3 | ::: rankify.generator.models.litellm_model 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/generators/models/model_factory.md: -------------------------------------------------------------------------------- 1 | # Model Factory 2 | 3 | ::: rankify.generator.models.model_factory 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/generators/models/openai_model.md: -------------------------------------------------------------------------------- 1 | # OpenAI Model 2 | 3 | ::: rankify.generator.models.openai_model 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false 
-------------------------------------------------------------------------------- /docs/docs/api/generators/models/vllm_model.md: -------------------------------------------------------------------------------- 1 | # vLLM Model 2 | 3 | ::: rankify.generator.models.vllm_model 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/generators/prompt_generator.md: -------------------------------------------------------------------------------- 1 | # Prompt Generator 2 | 3 | ::: rankify.generator.prompt_generator 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/generators/rag_methods/basic_rag.md: -------------------------------------------------------------------------------- 1 | # Basic RAG 2 | 3 | ::: rankify.generator.rag_methods.basic_rag 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/generators/rag_methods/basic_rag_method.md: -------------------------------------------------------------------------------- 1 | # Basic RAG Method 2 | 3 | ::: 
rankify.generator.rag_methods.base_rag_method 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/generators/rag_methods/chain_of_thought_rag.md: -------------------------------------------------------------------------------- 1 | # Chain of Thought RAG 2 | 3 | ::: rankify.generator.rag_methods.chain_of_thought_rag 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/generators/rag_methods/fid_rag_method.md: -------------------------------------------------------------------------------- 1 | # FiD Generator 2 | 3 | ::: rankify.generator.rag_methods.fid_rag_method 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/generators/rag_methods/in_context_ralm_rag.md: -------------------------------------------------------------------------------- 1 | # In-Context RALM Generator 2 | 3 | ::: rankify.generator.rag_methods.in_context_ralm_rag 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | 
show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/generators/rag_methods/zero_shot.md: -------------------------------------------------------------------------------- 1 | # Zero shot Generation 2 | 3 | ::: rankify.generator.rag_methods.zero_shot 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/index.md: -------------------------------------------------------------------------------- 1 | # API Reference 2 | 3 | Below is an overview of the modules, classes, and functions available in **Rankify**. 
4 | 5 | ## Dataset Module 6 | - [Dataset](dataset.md) 7 | 8 | ## Metrics Module 9 | - [Metrics](metrics.md) 10 | 11 | ## Retrievers 12 | - [Retriever](retrievers/retriever.md) 13 | - [BM25 Retriever](retrievers/bm25.md) 14 | - [Dense Retriever](retrievers/dense.md) 15 | - [BGE Retriever](retrievers/bge.md) 16 | - [ColBERT Retriever](retrievers/colbert.md) 17 | - [Contriever Retriever](retrievers/contriever.md) 18 | 19 | ## Rerankers 20 | - [Base](rerankings/base.md) 21 | - [Reranking](rerankings/reranking.md) 22 | - [UPR](rerankings/upr.md) 23 | - [FlashRank](rerankings/flashrank.md) 24 | - [RankGPT](rerankings/rankgpt.md) 25 | - [Blender Reranker](rerankings/blender.md) 26 | - [ColBERT Reranker](rerankings/colbert_ranker.md) 27 | - [EchoRank](rerankings/echo_rank.md) 28 | - [First Reranker](rerankings/first_ranker.md) 29 | - [Incontext Reranker](rerankings/incontext_reranker.md) 30 | - [InRank Reranker](rerankings/inrank.md) 31 | - [ListT5 Reranker](rerankings/listt5.md) 32 | - [Lit5 Reranker](rerankings/lit5.md) 33 | - [LLM Layerwise Reranker](rerankings/llm_layerwise_ranker.md) 34 | - [LLM2vec Reranker](rerankings/llm2vec_reranker.md) 35 | - [MonoBERT Reranker](rerankings/monobert.md) 36 | - [MonoT5 Reranker](rerankings/monot5.md) 37 | - [Rank Fid Reranker](rerankings/rank_fid.md) 38 | - [RankT5 Reranker](rerankings/rankt5.md) 39 | - [Sentence Transformer Reranker](rerankings/sentence_transformer_reranker.md) 40 | - [SPLADE Reranker](rerankings/splade_reranker.md) 41 | - [Transformer Reranker](rerankings/transformer_reranker.md) 42 | - [TwoLAR Reranker](rerankings/twolar.md) 43 | - [Vicuna Reranker](rerankings/vicuna_reranker.md) 44 | - [Zephyr Reranker](rerankings/zephyr_reranker.md) 45 | 46 | ## Generators 47 | - [Base RAG Model](generators/models/base_rag_model.md) 48 | - [Generator](generators/generator.md) 49 | - [FiD Generator](generators/rag_methods/fid_rag_method.md) 50 | - [In-Context RALM Generator](generators/rag_methods/in_context_ralm_rag.md) 51 | 
-------------------------------------------------------------------------------- /docs/docs/api/metrics.md: -------------------------------------------------------------------------------- 1 | # Metrics 2 | 3 | ::: rankify.metrics.metrics 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/base.md: -------------------------------------------------------------------------------- 1 | # BaseRanking Model 2 | 3 | ::: rankify.models.base 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/blender.md: -------------------------------------------------------------------------------- 1 | # Blender Reranker 2 | 3 | ::: rankify.models.blender_reranker 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/colbert_ranker.md: -------------------------------------------------------------------------------- 1 | # ColBert Reranker 2 | 3 | ::: rankify.models.colbert_ranker 4 | handler: python 5 | options: 6 | show_source: true 7 | 
show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/echo_rank.md: -------------------------------------------------------------------------------- 1 | # Echo Reranker 2 | 3 | ::: rankify.models.echorank 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/first_ranker.md: -------------------------------------------------------------------------------- 1 | # First Reranker 2 | 3 | ::: rankify.models.first_reranker 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/flashrank.md: -------------------------------------------------------------------------------- 1 | # Flash Reranker 2 | 3 | ::: rankify.models.flashrank 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false 
-------------------------------------------------------------------------------- /docs/docs/api/rerankings/incontext_reranker.md: -------------------------------------------------------------------------------- 1 | # incontext Reranker 2 | 3 | ::: rankify.models.incontext_reranker 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/inrank.md: -------------------------------------------------------------------------------- 1 | # InRank Reranker 2 | 3 | ::: rankify.models.inranker 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/listt5.md: -------------------------------------------------------------------------------- 1 | # Listt5 Reranker 2 | 3 | ::: rankify.models.listt5 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/lit5.md: -------------------------------------------------------------------------------- 1 | # Lit5 Reranker 2 | 3 | ::: rankify.models.lit5_reranker 4 | handler: python 5 | options: 6 | show_source: 
true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/llm2vec_reranker.md: -------------------------------------------------------------------------------- 1 | # LLM2vec Reranker 2 | 3 | ::: rankify.models.llm2vec_reranker 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/llm_layerwise_ranker.md: -------------------------------------------------------------------------------- 1 | # LLM Layerwise Reranker 2 | 3 | ::: rankify.models.llm_layerwise_ranker 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/monobert.md: -------------------------------------------------------------------------------- 1 | # Monobert Reranker 2 | 3 | ::: rankify.models.monobert 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false 
-------------------------------------------------------------------------------- /docs/docs/api/rerankings/monot5.md: -------------------------------------------------------------------------------- 1 | # Monot5 Reranker 2 | 3 | ::: rankify.models.monot5 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/rank_fid.md: -------------------------------------------------------------------------------- 1 | # RankFiD Reranker 2 | 3 | ::: rankify.models.rank_fid 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/rankgpt.md: -------------------------------------------------------------------------------- 1 | # RankGPT Reranker 2 | 3 | ::: rankify.models.rankgpt 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/rankt5.md: -------------------------------------------------------------------------------- 1 | # RankT5 Reranker 2 | 3 | ::: rankify.models.rankt5 4 | handler: python 5 | options: 6 | show_source: true 7 | 
show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/reranking.md: -------------------------------------------------------------------------------- 1 | # Reranking Models 2 | 3 | ::: rankify.models.reranking 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/sentence_transformer_reranker.md: -------------------------------------------------------------------------------- 1 | # Sentence Transformer Reranker 2 | 3 | ::: rankify.models.sentence_transformer_reranker 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/splade_reranker.md: -------------------------------------------------------------------------------- 1 | # Splade Reranker 2 | 3 | ::: rankify.models.splade_reranker 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false
-------------------------------------------------------------------------------- /docs/docs/api/rerankings/transformer_reranker.md: -------------------------------------------------------------------------------- 1 | # Transformer Reranker 2 | 3 | ::: rankify.models.transformer_ranker 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/twolar.md: -------------------------------------------------------------------------------- 1 | # Twolar Reranker 2 | 3 | ::: rankify.models.twolar 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/upr.md: -------------------------------------------------------------------------------- 1 | # UPR Reranker 2 | 3 | ::: rankify.models.upr 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/vicuna_reranker.md: -------------------------------------------------------------------------------- 1 | # Vicuna Reranker 2 | 3 | ::: rankify.models.vicuna_reranker 4 | handler: python 5 | options: 6 | 
show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/rerankings/zephyr_reranker.md: -------------------------------------------------------------------------------- 1 | # Zephyr Reranker 2 | 3 | ::: rankify.models.zephyr_reranker 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/retrievers/bge.md: -------------------------------------------------------------------------------- 1 | # BGE Retriever 2 | 3 | ::: rankify.retrievers.BGERetriever 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/retrievers/bm25.md: -------------------------------------------------------------------------------- 1 | # BM25 Retriever 2 | 3 | ::: rankify.retrievers.bm25 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false 
-------------------------------------------------------------------------------- /docs/docs/api/retrievers/colbert.md: -------------------------------------------------------------------------------- 1 | # Colbert Retriever 2 | 3 | ::: rankify.retrievers.colbert 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/retrievers/contriever.md: -------------------------------------------------------------------------------- 1 | # Contriever Retriever 2 | 3 | ::: rankify.retrievers.contriever 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/retrievers/dense.md: -------------------------------------------------------------------------------- 1 | # Dense Retriever 2 | 3 | ::: rankify.retrievers.dpr 4 | handler: python 5 | options: 6 | show_source: true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/api/retrievers/retriever.md: -------------------------------------------------------------------------------- 1 | # Retriever 2 | 3 | ::: rankify.retrievers.retriever 4 | handler: python 5 | options: 6 | show_source: 
true 7 | show_undocumented_members: true 8 | show_root_heading: true 9 | show_inherited_members: true 10 | heading_level: 2 11 | docstring_style: google 12 | show_root_full_path: true 13 | show_object_full_path: false 14 | separate_signature: false -------------------------------------------------------------------------------- /docs/docs/assets/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/docs/docs/assets/overview.jpg -------------------------------------------------------------------------------- /docs/docs/assets/rankify-crop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/docs/docs/assets/rankify-crop.png -------------------------------------------------------------------------------- /docs/docs/contribution.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Contributing" 3 | sidebar_position: 5 4 | --- 5 | 6 | # 💡 Contributing 7 | 8 | Follow these steps to get involved: 9 | 10 | 1️⃣ **Fork this repository** to your GitHub account. 11 | 12 | 2️⃣ **Create a new branch** for your feature or fix: 13 | ```bash 14 | git checkout -b feature/YourFeatureName 15 | ``` 16 | 3️⃣ **Make your changes** and **commit them**: 17 | ```bash 18 | git commit -m "Add YourFeatureName" 19 | ``` 20 | 4️⃣ **Push the changes** to your branch: 21 | ```bash 22 | git push origin feature/YourFeatureName 23 | ``` 24 | 5️⃣ **Submit a Pull Request** to propose your changes. 25 | 26 | 🙏 Thank you for helping make this project better! 
27 | -------------------------------------------------------------------------------- /docs/docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | sidebar_position: 1 3 | hide: 4 | - navigation 5 | - toc 6 | template: home.html 7 | --- 8 | 9 | 10 | 11 | # Welcome to Rankify 12 | 13 | For the full documentation, please visit the [Rankify repository](https://github.com/DataScienceUIBK/Rankify). 14 | -------------------------------------------------------------------------------- /docs/docs/js/clustrmaps.js: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/docs/docs/js/clustrmaps.js -------------------------------------------------------------------------------- /docs/docs/js/runllm-widget.js: -------------------------------------------------------------------------------- 1 | document.addEventListener("DOMContentLoaded", function () { 2 | var script = document.createElement("script"); 3 | script.defer = true; 4 | script.type = "module"; 5 | script.id = "rankify-widget-script"; 6 | script.src = 7 | "https://widget.runllm.com"; 8 | script.setAttribute("rankify-name", "Rankify"); 9 | script.setAttribute("rankify-preset", "mkdocs"); 10 | script.setAttribute("rankify-server-address", "https://api.rankify.com"); 11 | script.setAttribute("rankify-assistant-id", "132"); 12 | script.setAttribute("rankify-position", "BOTTOM_RIGHT"); 13 | script.setAttribute("rankify-keyboard-shortcut", "Mod+j"); 14 | script.setAttribute( 15 | "rankify-slack-community-url", 16 | "" 17 | ); 18 | 19 | document.head.appendChild(script); 20 | }); -------------------------------------------------------------------------------- /docs/overrides/main.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block extrahead %} 4 | 5 | 12 | {% endblock %}
-------------------------------------------------------------------------------- /docs/overrides/partials/tabs.html: -------------------------------------------------------------------------------- 1 | 22 | 23 | {% import "partials/tabs-item.html" as item with context %} 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs 2 | mkdocs-material 3 | mkdocs-rtd-dropdown 4 | mkdocstrings[python] 5 | mkdocs-material>=9.5.41 6 | mkdocs-jupyter>=0.25.1 7 | mkdocs-material[imaging]>=9.5.41 8 | mkdocs-redirects>=1.2.1 9 | mkdocstrings>=0.26.1 10 | mkdocstrings-python>=1.12.2 11 | urllib3==1.26.6 12 | mistune==3.0.2 -------------------------------------------------------------------------------- /examples/generator_fid.py: -------------------------------------------------------------------------------- 1 | from rankify.dataset.dataset import Document, Question, Answer, Context 2 | from rankify.generator.generator import Generator 3 | 4 | # Define question and answer 5 | question = Question("What is the capital of France?") 6 | answers = Answer([""]) 7 | contexts = [ 8 | Context(id=1, title="France", text="The capital of France is Paris.", score=0.9), 9 | Context(id=2, title="Germany", text="Berlin is the capital of Germany.", score=0.5) 10 | ] 11 | 12 | # Construct document 13 | doc = Document(question=question, answers=answers, contexts=contexts) 14 | 15 | # Initialize Generator (e.g., Meta Llama) 16 | generator = Generator(method="fid", model_name='nq_reader_base', backend="fid") 17 | 18 | # Generate answer 19 | generated_answers = generator.generate([doc]) 20 | print(generated_answers) # Output: ["Paris"] 21 | -------------------------------------------------------------------------------- /examples/generator_huggingface.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | from rankify.dataset.dataset import Document, Question, Answer, Context 3 | from rankify.generator.generator import Generator 4 | 5 | # Define question and answer 6 | question = Question("What is the capital of Austria?") 7 | answers=Answer("") 8 | contexts = [ 9 | Context(id=1, title="France", text="The capital of France is Paris.", score=0.9), 10 | Context(id=2, title="Germany", text="Berlin is the capital of Germany.", score=0.5) 11 | ] 12 | 13 | # Construct document 14 | doc = Document(question=question, answers=answers, contexts=contexts) 15 | 16 | # Initialize Generator (e.g., Meta Llama) 17 | generator = Generator(method="chain-of-thought-rag", model_name='meta-llama/Meta-Llama-3.1-8B-Instruct', backend="huggingface", torch_dtype=torch.float16) 18 | 19 | # Generate answer 20 | generated_answers = generator.generate([doc]) 21 | print(generated_answers) # Output: ["Paris"] 22 | -------------------------------------------------------------------------------- /examples/generator_litellm.py: -------------------------------------------------------------------------------- 1 | from rankify.dataset.dataset import Document, Question, Answer, Context 2 | from rankify.generator.generator import Generator 3 | from rankify.utils.models.rank_llm.rerank.api_keys import get_litellm_api_key 4 | 5 | # Define question and answer 6 | question = Question("What is the capital of France?") 7 | #answers = Answer(["Paris"]) 8 | answers = Answer([""]) 9 | contexts = [ 10 | Context(id=1, title="France", text="The capital of France is Paris.", score=0.9), 11 | Context(id=2, title="Germany", text="Berlin is the capital of Germany.", score=0.5) 12 | ] 13 | 14 | # Construct document 15 | doc = Document(question=question, answers=answers, contexts=contexts) 16 | 17 | #load api-key 18 | api_key = get_litellm_api_key() 19 | 20 | # Initialize Generator (e.g., Meta Llama) 21 | generator = Generator(method="basic-rag", model_name='ollama/mistral', backend="litellm", api_key=api_key) 
22 | 23 | # Generate answer 24 | generated_answers = generator.generate([doc]) 25 | print(generated_answers) # Output: ["Paris"] 26 | -------------------------------------------------------------------------------- /examples/generator_openai.py: -------------------------------------------------------------------------------- 1 | from rankify.dataset.dataset import Document, Question, Answer, Context 2 | from rankify.generator.generator import Generator 3 | from rankify.utils.models.rank_llm.rerank.api_keys import get_openai_api_key 4 | 5 | # Define question and answer 6 | question = Question("What is the capital of France?") 7 | #answers = Answer(["Paris"]) 8 | answers = Answer([""]) 9 | contexts = [ 10 | Context(id=1, title="France", text="The capital of France is Paris.", score=0.9), 11 | Context(id=2, title="Germany", text="Berlin is the capital of Germany.", score=0.5) 12 | ] 13 | 14 | # Construct document 15 | doc = Document(question=question, answers=answers, contexts=contexts) 16 | 17 | #load api-key 18 | api_key = get_openai_api_key() 19 | 20 | # Initialize Generator (e.g., Meta Llama) 21 | generator = Generator(method="basic-rag", model_name='gpt-3.5-turbo', backend="openai", api_keys=[api_key]) 22 | 23 | # Generate answer 24 | generated_answers = generator.generate([doc], max_tokens=10) 25 | print(generated_answers) # Output: ["Paris"] 26 | -------------------------------------------------------------------------------- /examples/generator_ralm.py: -------------------------------------------------------------------------------- 1 | from rankify.dataset.dataset import Document, Question, Answer, Context 2 | from rankify.generator.generator import Generator 3 | 4 | # Sample question and contexts 5 | question = Question("What is the capital of France?") 6 | answers=Answer('') 7 | contexts = [ 8 | Context(id=1, title="France", text="The capital of France is Paris.", score=0.9), 9 | Context(id=2, title="Germany", text="Berlin is the capital of Germany.", 
score=0.5) 10 | ] 11 | 12 | # Create a Document 13 | doc = Document(question=question, answers= answers, contexts=contexts) 14 | 15 | # Initialize Generator (e.g., Meta Llama, with huggingface backend) 16 | generator = Generator(method="in-context-ralm", model_name='meta-llama/Meta-Llama-3.1-8B-Instruct', backend="huggingface") 17 | 18 | # Generate answer 19 | generated_answers = generator.generate([doc]) 20 | print(generated_answers) # Output: ["Paris"] 21 | -------------------------------------------------------------------------------- /examples/generator_vllm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from vllm import SamplingParams 3 | from rankify.dataset.dataset import Document, Question, Answer, Context 4 | from rankify.generator.generator import Generator 5 | 6 | # Define question and answer 7 | question = Question("What is the capital of France?") 8 | answers=Answer("") 9 | contexts = [ 10 | Context(id=1, title="France", text="The capital of France is Paris.", score=0.9), 11 | Context(id=2, title="Germany", text="Berlin is the capital of Germany.", score=0.5) 12 | ] 13 | 14 | # Construct document 15 | doc = Document(question=question, answers=answers, contexts=contexts) 16 | 17 | # Define sampling parameters for vllm 18 | sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=100) 19 | 20 | # Initialize Generator (e.g., Meta Llama) 21 | generator = Generator(method="basic-rag", model_name='mistralai/Mistral-7B-v0.1', backend="vllm", dtype="float16") 22 | 23 | # Generate answer 24 | generated_answers = generator.generate([doc],sampling_params=sampling_params) 25 | 26 | output = generated_answers[0][0] 27 | print(output.prompt.strip()) 28 | print(output.outputs[0].text.strip()) # Output: ["Paris"] 29 | -------------------------------------------------------------------------------- /examples/reranking.py: 
-------------------------------------------------------------------------------- 1 | from rankify.dataset.dataset import Document, Question, Answer, Context 2 | from rankify.models.reranking import Reranking 3 | from rankify.utils.pre_defind_models import HF_PRE_DEFIND_MODELS 4 | 5 | import os 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 7 | 8 | ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") 9 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 10 | # Sample document setup 11 | question = Question("When did Thomas Edison invent the light bulb?") 12 | answers = Answer(["1879"]) 13 | contexts = [ 14 | Context(text="Lightning strike at Seoul National University", id=1), 15 | Context(text="Thomas Edison tried to invent a device for cars but failed", id=2), 16 | Context(text="Coffee is good for diet", id=3), 17 | Context(text="Thomas Edison invented the light bulb in 1879", id=4), 18 | Context(text="Thomas Edison worked with electricity", id=5), 19 | ] 20 | document = Document(question=question, answers=answers, contexts=contexts) 21 | 22 | # Function to test a reranking model 23 | def test_reranking_model(model_category, model_name): 24 | reranker = None 25 | try: 26 | print(f"Testing {model_category}: {model_name} ...") 27 | if model_category =="apiranker" or model_category =="rankgpt-api": 28 | api_key = OPENAI_API_KEY 29 | else: 30 | api_key=ANTHROPIC_API_KEY 31 | 32 | reranker = Reranking(method=model_category, model_name=model_name, api_key=api_key) 33 | reranker.rank([document]) 34 | 35 | # Print reordered contexts 36 | print("Reordered Contexts:") 37 | for context in document.reorder_contexts: 38 | print(f" - {context.text}") 39 | print(f"✔ {model_name} passed!\n") 40 | 41 | except Exception as e: 42 | print(f"❌ {model_name} failed with error: {e}\n") 43 | 44 | # Iterate over all models and test each one 45 | for category, models in HF_PRE_DEFIND_MODELS.items(): 46 | print(category, "----------") 47 | if category == 'flashrank-model-file' or category
=="apiranker": 48 | continue 49 | for model_key, model_name in models.items(): 50 | test_reranking_model(category, model_key) 51 | break -------------------------------------------------------------------------------- /examples/reranking_example.py: -------------------------------------------------------------------------------- 1 | from rankify.dataset.dataset import Document, Question, Answer, Context 2 | from rankify.models.reranking import Reranking 3 | from rankify.utils.pre_defind_models import HF_PRE_DEFIND_MODELS 4 | 5 | import os 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 7 | 8 | ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") 9 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 10 | # Sample document setup 11 | question = Question("when was Barack Obama president?") 12 | answers = Answer(["2009 to 2017"]) 13 | contexts = [ 14 | Context(text="asd asd asd", id=1), 15 | Context(text="Barack Hussein Obama II (born August 4, 1961) is an American politician who was the 44th president of the United States from 2009 to 2017.
A member of the Democratic Party, he was the first African American president in American history.", id=2), 16 | Context(text="asdasda sadasdasd", id=3), 17 | # Context(text="Thomas Edison invented the light bulb in 1879", id=4), 18 | # Context(text="Thomas Edison worked with electricity", id=5), 19 | ] 20 | document = Document(question=question, answers=answers, contexts=contexts) 21 | 22 | # Function to test a reranking model 23 | def test_reranking_model(model_category, model_name): 24 | reranker = None 25 | try: 26 | print(f"Testing {model_category}: {model_name} ...") 27 | if model_category =="apiranker" or model_category =="rankgpt-api": 28 | api_key = OPENAI_API_KEY 29 | else: 30 | api_key=ANTHROPIC_API_KEY 31 | 32 | reranker = Reranking(method=model_category, model_name=model_name, api_key=api_key) 33 | reranker.rank([document]) 34 | 35 | # Print reordered contexts 36 | print("Reordered Contexts:") 37 | for context in document.reorder_contexts: 38 | print(f" - {context.text}") 39 | print(f"✔ {model_name} passed!\n") 40 | 41 | except Exception as e: 42 | print(f"❌ {model_name} failed with error: {e}\n") 43 | 44 | # Iterate over all models and test each one 45 | # for category, models in HF_PRE_DEFIND_MODELS.items(): 46 | # print(category, "----------") 47 | # if category == 'flashrank-model-file' or category =="apiranker": 48 | # continue 49 | # for model_key, model_name in models.items(): 50 | test_reranking_model("upr", "t5-base") 51 | -------------------------------------------------------------------------------- /examples/retreiver.py: -------------------------------------------------------------------------------- 1 | 2 | import rankify 3 | import os 4 | 5 | from rankify.dataset.dataset import Document, Question, Answer, Context 6 | from rankify.retrievers.retriever import Retriever 7 | 8 | # Sample Documents 9 | documents = [ 10 | Document(question=Question("the cast of a good day to die hard?"), answers=Answer([ 11 | "Jai Courtney", 12 | "Sebastian
Koch", 13 | "Radivoje Bukvi\u0107", 14 | "Yuliya Snigir", 15 | "Sergei Kolesnikov", 16 | "Mary Elizabeth Winstead", 17 | "Bruce Willis" 18 | ]), contexts=[]), 19 | # Document(question=Question("Who wrote Hamlet?"), answers=Answer(["Shakespeare"]), contexts=[]) 20 | ] 21 | 22 | 23 | 24 | 25 | """retriever = Retriever(method="colbert", n_docs=1 , index_type="wiki" ) 26 | retrieved_documents = retriever.retrieve(documents) 27 | 28 | # Print the first retrieved document 29 | for i, doc in enumerate(retrieved_documents): 30 | print(f"\nDocument {i+1}:") 31 | print(doc) 32 | 33 | retriever = Retriever(method="colbert", n_docs=1 , index_type="msmarco" ) 34 | retrieved_documents = retriever.retrieve(documents) 35 | 36 | 37 | # Print the first retrieved document 38 | for i, doc in enumerate(retrieved_documents): 39 | print(f"\nDocument {i+1}:") 40 | print(doc) 41 | """ 42 | 43 | # OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 44 | # retriever = Retriever(method="hyde", n_docs=2 , index_type="wiki", api_key=OPENAI_API_KEY ) 45 | # retrieved_documents = retriever.retrieve(documents) 46 | 47 | # # Print the first retrieved document 48 | # for i, doc in enumerate(retrieved_documents): 49 | # print(f"\nDocument {i+1}:") 50 | # print(doc) 51 | 52 | # retriever = Retriever(method="bge", n_docs=1 , index_type="msmarco" ) 53 | serpapi= "" 54 | 55 | serper = "" 56 | online_retriever = Retriever(method="online_retriever", n_docs=5 , api_key=serper) 57 | retrieved_documents = online_retriever.retrieve(documents) 58 | # Print the first retrieved document 59 | for i, doc in enumerate(retrieved_documents): 60 | print(f"\nDocument {i+1}:") 61 | print(doc) 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /examples/retrieved_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 3 | from rankify.dataset.dataset import Dataset ,Document, Context, 
Question,Answer 4 | from rankify.metrics.metrics import Metrics 5 | #Dataset.avaiable_dataset() 6 | 7 | 8 | datasets = ["web_questions-test"]#, "ChroniclingAmericaQA-test" , "ArchivialQA-test"]#["nq-dev", "nq-test" , "squad1-test", "trivia-dev", "trivia-test", "webq-test", "squad1-dev" ] # 9 | 10 | for name in datasets: 11 | print("*"*100) 12 | print(name) 13 | dataset= Dataset('bm25', name , 100) 14 | documents = dataset.download(force_download=False) 15 | 16 | print(len(documents[0].contexts),documents[0].answers ) 17 | 18 | metrics = Metrics(documents) 19 | 20 | before_ranking_metrics = metrics.calculate_retrieval_metrics(ks=[1,5,10,20,50,100],use_reordered=False) 21 | print(before_ranking_metrics) 22 | print("#"*100) 23 | dataset.save_dataset("webq-bm25-test.json", save_text=True) -------------------------------------------------------------------------------- /images/output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/images/output.gif -------------------------------------------------------------------------------- /images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/images/overview.png -------------------------------------------------------------------------------- /images/rankify-crop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/images/rankify-crop.png -------------------------------------------------------------------------------- /images/rankify-logo-.png: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/images/rankify-logo-.png
# --------------------------------------------------------------------------------
# /images/rankify-logo.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/images/rankify-logo.png
# --------------------------------------------------------------------------------
# /rankify/__init__.py:
# --------------------------------------------------------------------------------
import os

# Default Rankify cache directory (~/.cache/rankify); a value already present
# in the environment is left untouched.
if 'RERANKING_CACHE_DIR' not in os.environ:
    os.environ['RERANKING_CACHE_DIR'] = os.path.join(os.path.expanduser('~'), '.cache', 'rankify')

# --------------------------------------------------------------------------------
# /rankify/__version__.py:
# --------------------------------------------------------------------------------
# file generated by setuptools_scm
# don't change, don't track in version control
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union
    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = '0.1.3'
__version_tuple__ = version_tuple = (0, 1, 3, )

# --------------------------------------------------------------------------------
# /rankify/agent/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/agent/__init__.py
# --------------------------------------------------------------------------------
# /rankify/dataset/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/dataset/__init__.py
# --------------------------------------------------------------------------------
# /rankify/generator/models/base_rag_model.py:
# --------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from typing import List


class BaseRAGModel(ABC):
    """
    **Base RAG Model** for Retrieval-Augmented Generation (RAG).

    This is an abstract base class for implementing RAG models.
    It defines the interface for generating responses and optional embedding generation.

    Methods:
        generate(prompt: str, **kwargs) -> str:
            Abstract method to generate a response based on the given prompt.
        embed(text: str, **kwargs) -> List[float]:
            Optional method to generate embeddings for the given text.

    Notes:
        - This class serves as a blueprint for RAG models like `OpenAIModel` and `HuggingFaceModel`.
        - The `embed` method is optional; the default implementation raises `NotImplementedError`.
    """

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        """Generate a response based on the given prompt."""
        pass

    def embed(self, text: str, **kwargs) -> List[float]:
        """Optional: Generate embeddings for the given text.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.
        """
        raise NotImplementedError("Embedding is not required for this implementation.")

# --------------------------------------------------------------------------------
# /rankify/generator/models/huggingface_model.py:
# --------------------------------------------------------------------------------
from rankify.generator.models.base_rag_model import BaseRAGModel
from rankify.generator.prompt_generator import PromptGenerator


class HuggingFaceModel(BaseRAGModel):
    """
    **Hugging Face Model** for Retrieval-Augmented Generation (RAG).

    This class integrates Hugging Face's pretrained models for text generation in a RAG pipeline.
    It uses the Hugging Face Transformers library for tokenization and model inference.

    Attributes:
        model_name (str): Name of the Hugging Face model.
        tokenizer: Tokenizer instance for encoding input text.
        model: Pretrained Hugging Face model for text generation.
        prompt_generator (PromptGenerator): Instance for generating prompts.

    Notes:
        - This model uses Hugging Face's Transformers library for text generation.
        - Default generation parameters (`max_new_tokens`, `temperature`) can be overridden
          through `generate(**kwargs)`.
    """

    def __init__(self, model_name: str, tokenizer, model, prompt_generator: PromptGenerator):
        self.model_name = model_name
        self.tokenizer = tokenizer
        self.model = model
        self.prompt_generator = prompt_generator

    def generate(self, prompt: str, **kwargs) -> str:
        """Generate a response using the wrapped Hugging Face model.

        Args:
            prompt: The fully formatted prompt text.
            **kwargs: Forwarded to ``model.generate``. Unless the caller sets ``max_length``
                or ``max_new_tokens``, at most 128 new tokens are generated.
                ``temperature`` defaults to 0.7.

        Returns:
            The decoded generation. For decoder-only models the echoed prompt tokens are
            removed, so only the continuation is returned.
        """
        # Keep the encoded prompt on the same device as the model.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # TODO: move generation defaults into a config.
        # `max_length` includes the prompt, so a long RAG prompt could leave no room for an
        # answer; bound the *new* tokens instead unless the caller chose a limit explicitly.
        if "max_length" not in kwargs:
            kwargs.setdefault("max_new_tokens", 128)
        kwargs.setdefault("temperature", 0.7)

        outputs = self.model.generate(**inputs, **kwargs)
        generated = outputs[0]
        # Decoder-only models return prompt + continuation; decode only the continuation.
        model_config = getattr(self.model, "config", None)
        if not getattr(model_config, "is_encoder_decoder", False):
            generated = generated[inputs["input_ids"].shape[-1]:]
        return self.tokenizer.decode(generated, skip_special_tokens=True)

# --------------------------------------------------------------------------------
# /rankify/generator/models/litellm_model.py:
# --------------------------------------------------------------------------------
from rankify.generator.models.base_rag_model import BaseRAGModel
from rankify.generator.prompt_generator import PromptGenerator
from rankify.utils.api.litellmclient import LitellmClient


class LitellmModel(BaseRAGModel):
    """
    **LiteLLM Model** for Retrieval-Augmented Generation (RAG).

    This class integrates LiteLLM's API for text generation in a RAG pipeline.
    It uses the LiteLLM API to generate responses based on input prompts.

    Attributes:
        model_name (str): Name of the LiteLLM model.
        prompt_generator (PromptGenerator): Instance for generating prompts.
        client (LitellmClient): Client for interacting with the LiteLLM API, built from the
            `api_key` passed to the constructor.

    Notes:
        - This model uses LiteLLM's API for text generation.
        - It supports additional parameters like `max_tokens` and `temperature`.
    """
    def __init__(self, model_name: str, api_key: str, prompt_generator: PromptGenerator):
        """
        Initialize the LitellmModel with the LitellmClient.

        :param model_name: Name of the LiteLLM model.
        :param api_key: API key for LiteLLM (passed to ``LitellmClient(keys=...)``).
        :param prompt_generator: Instance of PromptGenerator for generating prompts.
        """
        self.model_name = model_name
        self.prompt_generator = prompt_generator

        self.client = LitellmClient(keys=api_key)

    def generate(self, prompt: str, **kwargs) -> str:
        """
        Generate a response using LiteLLM's API.

        :param prompt: The input prompt for the model, sent as a single user message.
        :param kwargs: Additional parameters for the LiteLLM API call; ``model`` defaults to
            this instance's model, ``max_tokens`` to 128 and ``temperature`` to 0.7.
        :return: The generated response as a string.
        """
        # TODO: use this later -> Generate the prompt using the prompt generator
        # full_prompt = self.prompt_generator.generate_prompt(prompt)

        # Set default parameters for the LiteLLM API call
        kwargs.setdefault("model", self.model_name)
        kwargs.setdefault("max_tokens", 128)
        kwargs.setdefault("temperature", 0.7)

        # Call the LiteLLM API using the LitellmClient
        response = self.client.chat(messages=[{"role": "user", "content": prompt}], return_text=True, **kwargs)
        return response

# --------------------------------------------------------------------------------
# /rankify/generator/models/model_factory.py:
# --------------------------------------------------------------------------------
import torch
from rankify.generator.models.fid_model import FiDModel
from rankify.generator.models.litellm_model import LitellmModel
from rankify.generator.models.openai_model import OpenAIModel
from rankify.generator.models.base_rag_model import BaseRAGModel
from rankify.generator.models.huggingface_model import HuggingFaceModel
from rankify.generator.models.vllm_model import VLLMModel
from rankify.generator.prompt_generator import PromptGenerator

from transformers import AutoTokenizer, AutoModelForCausalLM

from rankify.utils.generator.huggingface_models.model_utils import load_model, load_tokenizer


def model_factory(model_name: str, backend: str, method: str, use_litellm=False, **kwargs) -> BaseRAGModel:
    """Instantiate the generator model for ``backend``.

    Args:
        model_name: Model identifier handed to the backend.
        backend: One of ``"openai"``, ``"huggingface"``, ``"fid"``, ``"litellm"``, ``"vllm"``.
        method: RAG method name. It is used to build the ``PromptGenerator`` and, for
            ``"fid"``, passed to ``FiDModel``.
        use_litellm: Currently unused; kept for backward compatibility.
        **kwargs: Backend-specific options.
            - ``"openai"`` requires ``api_keys`` and accepts an optional ``base_url``.
            - ``"litellm"`` requires ``api_key``.
            - ``"huggingface"`` forwards kwargs to ``load_model``.
            - ``"fid"`` and ``"vllm"`` forward them to the model constructor.

    Returns:
        A ``BaseRAGModel`` implementation for the requested backend.

    Raises:
        ValueError: If ``backend`` is not supported.
        KeyError: If a required credential kwarg (``api_keys`` / ``api_key``) is missing.
    """
    prompt_generator = PromptGenerator(model_type=model_name, method=method)
    if backend == "openai":
        # OpenAIModel supports a custom endpoint; forward it when the caller supplies one.
        return OpenAIModel(model_name, kwargs["api_keys"], prompt_generator, base_url=kwargs.get("base_url"))
    elif backend == "huggingface":
        tokenizer = load_tokenizer(model_name)
        model = load_model(model_name, **kwargs)
        return HuggingFaceModel(model_name, tokenizer, model, prompt_generator)
    elif backend == "fid":
        return FiDModel(method, model_name, **kwargs)
    elif backend == "litellm":
        return LitellmModel(model_name, kwargs["api_key"], prompt_generator)
    elif backend == "vllm":
        return VLLMModel(model_name, prompt_generator, **kwargs)
    else:
        raise ValueError(f"Unsupported backend: {backend}")

# --------------------------------------------------------------------------------
# /rankify/generator/models/openai_model.py:
# --------------------------------------------------------------------------------
from rankify.generator.models.base_rag_model import BaseRAGModel
from rankify.generator.prompt_generator import PromptGenerator
from rankify.utils.api.litellmclient import LitellmClient
from rankify.utils.api.openaiclient import OpenaiClient


class OpenAIModel(BaseRAGModel):
    """
    **OpenAI Model** for Retrieval-Augmented Generation (RAG).

    This class integrates OpenAI's GPT models for text generation in a RAG pipeline.
    It uses the OpenAI API to generate responses based on input prompts.

    Attributes:
        model_name (str): Name of the OpenAI model (e.g., "gpt-3.5-turbo").
        prompt_generator (PromptGenerator): Instance for generating prompts.
        client (OpenaiClient): Client for interacting with the OpenAI API, built from the
            `api_keys` and optional `base_url` passed to the constructor.

    Notes:
        - This model uses OpenAI's GPT models for text generation.
        - It supports additional parameters like `max_tokens` and `temperature`.
    """
    def __init__(self, model_name: str, api_keys: list, prompt_generator: PromptGenerator, base_url: str = None):
        """
        Initialize the OpenAIModel with the OpenaiClient.

        :param model_name: Name of the OpenAI model (e.g., "gpt-3.5-turbo").
        :param api_keys: List of API keys for OpenAI.
        :param prompt_generator: Instance of PromptGenerator for generating prompts.
        :param base_url: Optional base URL for the OpenAI API.
        """
        self.model_name = model_name
        self.prompt_generator = prompt_generator

        self.client = OpenaiClient(keys=api_keys, base_url=base_url)

    def generate(self, prompt: str, **kwargs) -> str:
        """
        Generate a response using OpenAI's API.

        :param prompt: The input prompt for the model, sent as a single user message.
        :param kwargs: Additional parameters for the OpenAI API call; ``model`` defaults to
            this instance's model, ``max_tokens`` to 128 and ``temperature`` to 0.7.
        :return: The generated response as a string.
        """
        # TODO: use this later -> Generate the prompt using the prompt generator
        # full_prompt = self.prompt_generator.generate_prompt(prompt)

        # Set default parameters for the OpenAI API call
        kwargs.setdefault("model", self.model_name)
        kwargs.setdefault("max_tokens", 128)
        kwargs.setdefault("temperature", 0.7)

        # Call the OpenAI API using the OpenaiClient
        response = self.client.chat(messages=[{"role": "user", "content": prompt}], return_text=True, **kwargs)
        return response

# --------------------------------------------------------------------------------
# /rankify/generator/models/vllm_model.py:
# --------------------------------------------------------------------------------
from rankify.generator.models.base_rag_model import BaseRAGModel
from rankify.generator.prompt_generator import PromptGenerator
from vllm import LLM, SamplingParams


class VLLMModel(BaseRAGModel):
    """
    **vLLM Model** for Retrieval-Augmented Generation (RAG).

    This class integrates vLLM's API for text generation in a RAG pipeline.
    It uses the vLLM library to generate responses based on input prompts.

    Attributes:
        model_name (str): Name of the vLLM model.
        prompt_generator (PromptGenerator): Instance for generating prompts.
        llm (LLM): vLLM engine instance used for generation.

    Notes:
        - This model uses vLLM's library for text generation.
        - Generation is configured through a `SamplingParams` object; when none is passed,
          one is built from `max_tokens` (default 128) and `temperature` (default 0.7).
    """
    def __init__(self, model_name: str, prompt_generator: PromptGenerator, **kwargs):
        """
        Initialize the VLLMModel with the vLLM client.

        :param model_name: Name of the vLLM model.
        :param prompt_generator: Instance of PromptGenerator for generating prompts.
        :param kwargs: Forwarded unchanged to ``vllm.LLM``.
        """
        self.model_name = model_name
        self.prompt_generator = prompt_generator
        self.llm = LLM(model=model_name, **kwargs)

    def generate(self, prompt: str, **kwargs) -> str:
        """
        Generate a response using vLLM's API.

        :param prompt: The input prompt for the model.
        :param kwargs: ``sampling_params`` (a ``SamplingParams``) is used when given.
            Otherwise one is built from ``max_tokens`` (default 128) and ``temperature``
            (default 0.7), matching the other backends' defaults.
        :return: The value returned by ``LLM.generate``.
            NOTE(review): this appears to be vLLM's list of request outputs rather than the
            ``str`` in the annotation; confirm how callers unpack it.
        """
        # TODO: use this later -> Generate the prompt using the prompt generator
        # full_prompt = self.prompt_generator.generate_prompt(prompt)

        # Previously a missing `sampling_params` raised KeyError; fall back to defaults instead.
        sampling_params = kwargs.get("sampling_params")
        if sampling_params is None:
            sampling_params = SamplingParams(
                max_tokens=kwargs.get("max_tokens", 128),
                temperature=kwargs.get("temperature", 0.7),
            )

        # Call the vLLM API using the LLM client
        response = self.llm.generate(prompt, sampling_params)

        return response

# --------------------------------------------------------------------------------
# /rankify/generator/prompt_generator.py:
# --------------------------------------------------------------------------------
from typing import List


class PromptGenerator:
    """Build simple system/user prompts for a given model type and RAG method."""

    def __init__(self, model_type: str, method: str):
        self.model_type = model_type
        self.method = method

    def generate_system_prompt(self) -> str:
        """Generate a system-level prompt naming the model type and method."""
        return f"System prompt for model type: {self.model_type}, method: {self.method}"

    def generate_user_prompt(self, question: str, contexts: List[str]) -> str:
        """Generate a user-level prompt: the question followed by newline-separated contexts."""
        context_str = "\n".join(contexts)
        return f"Question: {question}\nContexts:\n{context_str}"

# --------------------------------------------------------------------------------
# /rankify/generator/rag_methods/base_rag_method.py:
# --------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from typing import List
from rankify.dataset.dataset import Document


class BaseRAGMethod(ABC):
    """Abstract interface for RAG answering strategies."""

    @abstractmethod
    def answer_questions(self, documents: List[Document], **kwargs) -> List[str]:
        """
        Abstract method to answer the question of each document based on its contexts.
        """
        pass

# --------------------------------------------------------------------------------
# /rankify/generator/rag_methods/basic_rag.py:
# --------------------------------------------------------------------------------
from typing import List
from rankify.generator.models.base_rag_model import BaseRAGModel

from typing import List
from rankify.dataset.dataset import Document
from rankify.generator.rag_methods.base_rag_method import BaseRAGMethod


class BasicRAG(BaseRAGMethod):
    """Answer each question by prompting the model with the question and all its contexts."""

    def __init__(self, model: BaseRAGModel, **kwargs):
        self.model = model

    def answer_questions(self, documents: List[Document], **kwargs) -> List[str]:
        """
        Answer the question of each document using the model.

        Args:
            documents (List[Document]): A list of Document objects containing questions and contexts.
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            List[str]: One answer per document, in input order.
        """
        answers = []

        for document in documents:
            # Extract question and contexts from the document
            question = document.question.question
            contexts = [context.text for context in document.contexts]

            # Construct the prompt: instruction, question, then the contexts one per line.
            prompt = (
                "Answer this question based on the given contexts, provide a concise answer. "
                "You only need to answer the question, not provide context.\n"
                f"Question: {question}\nContexts:\n"
            ) + "\n".join(contexts)

            # Generate the answer using the model
            answer = self.model.generate(prompt=prompt, **kwargs)

            # Append the answer to the list
            answers.append(answer)

        return answers

# --------------------------------------------------------------------------------
# /rankify/generator/rag_methods/chain_of_thought_rag.py:
# --------------------------------------------------------------------------------
from typing import List
from rankify.dataset.dataset import Document
from rankify.generator.models.base_rag_model import BaseRAGModel
from rankify.generator.rag_methods.base_rag_method import BaseRAGMethod


class ChainOfThoughtRAG(BaseRAGMethod):
    """
    **Chain-of-Thought RAG** for Open-Domain Question Answering.

    This class implements a chain-of-thought retrieval-augmented generation (RAG) method,
    where the model generates answers by reasoning step-by-step using the provided contexts.

    Attributes:
        model (BaseRAGModel): The underlying model used for text generation.

    Methods:
        answer_questions(documents: List[Document], **kwargs) -> List[str]:
            Generates answers for a list of documents using chain-of-thought reasoning.

    Notes:
        - This method uses chain-of-thought reasoning to generate more detailed and logical answers.
        - The model can use the provided contexts or rely on its own knowledge to generate answers.
    """
    def __init__(self, model: BaseRAGModel, **kwargs):
        self.model = model

    def answer_questions(self, documents: List[Document], **kwargs) -> List[str]:
        """Answer each document's question using chain-of-thought reasoning.

        Args:
            documents (List[Document]): Documents carrying a question and its contexts.
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            List[str]: One answer per document, in input order.
        """
        answers = []

        for document in documents:
            # Extract question and contexts from the document
            question = document.question.question
            contexts = [context.text for context in document.contexts]

            # Construct the prompt. The contexts are *appended* after the instruction and question.
            # Previously the prompt was used as the join separator, which interleaved it between
            # contexts and, for a single context, dropped the instruction and question entirely.
            prompt = (
                "Answer this question using internal chain of thought reasoning, think and\n"
                "lay out your logic in multiple steps. You may use the provided contexts, but you can also discard it and just\n"
                "reason by your own knowledge. :\n"
                f"Question: {question}\nContexts:\n"
            ) + "\n".join(contexts)

            # Generate the answer using the model
            answer = self.model.generate(prompt=prompt, **kwargs)

            # Append the answer to the list
            answers.append(answer)
        return answers

# --------------------------------------------------------------------------------
# /rankify/generator/rag_methods/fid_rag_method.py:
# --------------------------------------------------------------------------------
from typing import List
from rankify.generator.models.base_rag_model import BaseRAGModel

from typing import List
from rankify.dataset.dataset import Document
from rankify.generator.rag_methods.base_rag_method import BaseRAGMethod


class FiDRAGMethod(BaseRAGMethod):
    """
    **FiD RAG Method** for Open-Domain Question Answering.

    This class implements a retrieval-augmented generation (RAG) method using the
    Fusion-in-Decoder (FiD) approach. The FiD model aggregates information from
    multiple retrieved passages to generate context-aware answers.

    References:
        - **Izacard & Grave** *Leveraging Passage Retrieval with Generative Models for Open-Domain QA*
          [Paper](https://arxiv.org/abs/2007.01282)

    Attributes:
        model (BaseRAGModel): The underlying FiD model used for text generation.

    Methods:
        answer_questions(documents: List[Document], **kwargs) -> List[str]:
            Generates answers for a list of documents using the FiD model.

    Notes:
        - The FiD model combines multiple passages to generate better responses.
    """
    def __init__(self, model: BaseRAGModel):
        self.model = model

    def answer_questions(self, documents: List[Document], **kwargs) -> List[str]:
        """
        Answer questions for a list of documents using the model.

        Unlike the prompt-based methods, the whole document list is handed to
        ``model.generate`` in a single call.

        Args:
            documents (List[Document]): A list of Document objects containing questions and contexts.

        Returns:
            List[str]: Whatever ``model.generate`` returns for the batch.
        """

        # Generate the answer using the model
        answer = self.model.generate(documents, **kwargs)

        return answer

# --------------------------------------------------------------------------------
# /rankify/generator/rag_methods/zero_shot.py:
# --------------------------------------------------------------------------------
from typing import List
from rankify.generator.models.base_rag_model import BaseRAGModel
from rankify.dataset.dataset import Document
from rankify.generator.rag_methods.base_rag_method import BaseRAGMethod


class ZeroShotRAG(BaseRAGMethod):
    """
    **Zero-Shot RAG** for Open-Domain Question Answering.

    This class implements a zero-shot method in which the model answers each question
    from the question text alone; the documents' contexts are not included in the prompt.

    Attributes:
        model (BaseRAGModel): The underlying model used for text generation.
    """
    def __init__(self, model: BaseRAGModel, **kwargs):
        """
        Initialize the ZeroShotRAG method.

        Args:
            model (BaseRAGModel): A model instance for text generation.
            kwargs: Additional arguments for customization (currently unused).
        """
        self.model = model

    def answer_questions(self, documents: List[Document], **kwargs) -> List[str]:
        """
        Answer questions for a list of documents using the model in a zero-shot manner.

        Args:
            documents (List[Document]): A list of Document objects containing questions.

        Returns:
            List[str]: One answer per document, in input order.
        """
        answers = []

        for document in documents:
            # Extract the question from the document (contexts are intentionally ignored)
            question = document.question.question

            # Construct the prompt by adding question
            prompt = f"Question: {question}\n"

            # Generate the answer using the model
            answer = self.model.generate(prompt, **kwargs)

            # Append the answer to the list
            answers.append(answer)

        return answers

# --------------------------------------------------------------------------------
# /rankify/indexing/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/indexing/__init__.py
# --------------------------------------------------------------------------------
# /rankify/metrics/__init__.py:
# --------------------------------------------------------------------------------
from .metrics import Metrics
__all__ = ['Metrics']
# --------------------------------------------------------------------------------
# /rankify/models/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/models/__init__.py
# --------------------------------------------------------------------------------
# /rankify/models/base.py:
# --------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from rankify.dataset.dataset import Document



class BaseRanking(ABC):
    """
    An abstract base class for implementing different ranking models.

    This class defines the interface for all ranking models. Both ``__init__`` and ``rank``
    are declared with ``@abstractmethod``, so a subclass must override both before it can be
    instantiated (otherwise Python raises ``TypeError`` at construction time).

    Attributes:
        method (str): The name of the ranking method.
        model_name (str): The name of the model being used for ranking.
        api_key (str, optional): An optional API key for accessing remote models or services.

    Note:
        The base ``__init__`` does not store these attributes; concrete subclasses are
        responsible for setting them.
    """

    @abstractmethod
    def __init__(self, method: str = None, model_name: str = None, api_key: str = None, **kwargs) -> None:
        """
        Initializes the base ranking model.

        The base implementation does nothing; it exists so subclasses can call
        ``super().__init__(...)`` with the common arguments.

        Args:
            method (str, optional): The name of the ranking method. Defaults to None.
            model_name (str, optional): The name of the model being used for ranking. Defaults to None.
            api_key (str, optional): An optional API key for accessing remote models or services. Defaults to None.
            **kwargs: Extra subclass-specific options (ignored here).

        Example:
            ```python
            class MyRanking(BaseRanking):
                def __init__(self, method, model_name):
                    super().__init__(method, model_name)
            ```
        """
        pass

    @abstractmethod
    def rank(self, documents: list[Document]):
        """
        Abstract method to rank a list of documents.

        Args:
            documents (list[Document]): A list of Document instances that need to be ranked.

        Note:
            This method is abstract: a subclass that does not override it cannot be
            instantiated. The base implementation itself does nothing and returns ``None``;
            it does not raise ``NotImplementedError``.

        Example:
            ```python
            class MyRanking(BaseRanking):
                def __init__(self, method, model_name):
                    super().__init__(method, model_name)

                def rank(self, documents):
                    # Ranking implementation here
                    pass
            ```
        """
        pass

# --------------------------------------------------------------------------------
# /rankify/requirements.txt:
# --------------------------------------------------------------------------------
# pandas==2.2.3
# prettytable==3.11.0
# tqdm==4.66.5
# requests==2.32.3
# torch==2.5.0
# transformers==4.45.2
# sentencepiece==0.2.0
# openai==1.52.2
# anthropic==0.37.1
# together==1.3.3
# onnxruntime==1.19.2
# llama-cpp-python==0.2.76
# vllm==0.6.3
# ftfy==6.3.1
# dacite==1.8.1
# fschat[model_worker]>=0.2.36
# llm-blender==0.0.2
# sentence_transformers==3.3.0
# flash-attn==2.5.0
# pyserini==0.43.0
# faiss-cpu==1.9.0.post1
# omegaconf==2.3.0
# h5py==3.12.1
# py7zr==0.22.0
# ujson==5.10.0
# ninja==1.11.1.3
# cohere==5.14.0
# dotenv>=0.9.9
# litellm>=1.61.20
# langchain>=0.3.19
# fasttext-wheel>=0.9.2
# wikipedia-api>=0.8.1
# pillow>=10.4.0
# smolagents>=1.9.2
# Crawl4AI>=0.6.3
# --------------------------------------------------------------------------------
# /rankify/retrievers/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/retrievers/__init__.py
# --------------------------------------------------------------------------------
# /rankify/tools/.env.example:
# --------------------------------------------------------------------------------
# SERPER_API_KEY=
# SERPER_API_URL = https://google.serper.dev/search
# OPENROUTER_API_KEY=
#
# --------------------------------------------------------------------------------
# /rankify/tools/Readme.md:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/tools/Readme.md
# --------------------------------------------------------------------------------
# /rankify/tools/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/tools/__init__.py
# --------------------------------------------------------------------------------
# /rankify/tools/websearch/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/tools/websearch/__init__.py
# --------------------------------------------------------------------------------
# /rankify/tools/websearch/content_scraping/scrapedResult.py:
# --------------------------------------------------------------------------------
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ScrapedResult:
    """Outcome of one scraping/extraction attempt.

    Constructor signature is unchanged: ``ScrapedResult(name, success, content=None, error=None)``.
    The two length counters are not constructor arguments; they start at 0 and may be
    assigned afterwards.

    Previously the class declared ``@dataclass`` but defined no annotated fields. As a result
    every instance compared equal and ``repr`` printed ``ScrapedResult()``, and the hand-written
    ``__init__`` printed the whole scraped content to stdout on each construction.
    """

    # Label of the strategy/source that produced this result.
    name: str
    # Whether extraction succeeded; `content` is meaningful when True, `error` when False.
    success: bool
    content: Optional[str] = None
    error: Optional[str] = None
    raw_markdown_len: int = field(default=0, init=False)
    citations_markdown_len: int = field(default=0, init=False)


def print_extracted_result(result: ScrapedResult):
    """Print a summary of ``result``: content and lengths on success, otherwise the error."""
    if result.success:
        print(f"\n=={result.name} Results ===")
        print(f"Extracted content: {result.content}")
        print(f"Raw markdown length: {result.raw_markdown_len}")
        print(f"Citations markdown length: {result.citations_markdown_len}")
    else:
        print(f"Error in {result.name}: {result.error}")
# --------------------------------------------------------------------------------
# /rankify/tools/websearch/content_scraping/strategyFactory.py:
# --------------------------------------------------------------------------------
import os
from typing import Optional
from crawl4ai.extraction_strategy import (
    LLMExtractionStrategy,
    NoExtractionStrategy,
    CosineStrategy,
    create_llm_config
)
from rankify.tools.websearch.content_scraping.config import (
    DEFAULT_PROVIDER,
    DEFAULT_PROVIDER_API_KEY
)


def _default_llm_config():
    """Build an LLM config for DEFAULT_PROVIDER, reading its API token from the environment now."""
    token = os.environ.get(DEFAULT_PROVIDER_API_KEY)
    return create_llm_config(provider=DEFAULT_PROVIDER, api_token=token)


# Module-level config built once at import time.
# NOTE(review): StrategyFactory below builds its own config per call and does not use this
# name; kept in case other modules import it.
llm_config = _default_llm_config()


class StrategyFactory:
    """Factory-pattern-based class to create extraction strategy"""

    @staticmethod
    def create_llm_strategy(
            input_format: str = 'markdown',
            instruction: str = "Extract relevant content from the provided text, only return the text, no markdown formatting, remove all footnotes, citations, and other metadata and only keep the main content",
            verbose=True,
    ) -> LLMExtractionStrategy:
        """Return an LLM-backed extraction strategy for the default provider.

        A fresh config is built on every call, so the API key is read from the
        environment at call time rather than at import time.
        """
        per_call_config = _default_llm_config()
        return LLMExtractionStrategy(
            instruction=instruction,
            verbose=verbose,
            llm_config=per_call_config,
            input_format=input_format,
        )

    @staticmethod
    def create_no_extraction_strategy() -> NoExtractionStrategy:
        """Return a strategy that performs no extraction."""
        strategy = NoExtractionStrategy()
        return strategy

    @staticmethod
    def create_cosine_strategy(semantic_filter: Optional[str] = None,
                               word_count_threshold: int = 10,
                               max_dist: float = 0.2,
                               sim_threshold: float = 0.3,
                               debug: bool = False) -> CosineStrategy:
        """Return a cosine-similarity extraction strategy; ``debug`` maps to its ``verbose`` flag."""
        options = dict(
            semantic_filter=semantic_filter,
            word_count_threshold=word_count_threshold,
            max_dist=max_dist,
            sim_threshold=sim_threshold,
            verbose=debug,
        )
        return CosineStrategy(**options)

# --------------------------------------------------------------------------------
# /rankify/tools/websearch/context/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/tools/websearch/context/__init__.py
# ---- /rankify/tools/websearch/context/__init__.py ----
# (empty file; placeholder:
#  https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/tools/websearch/context/__init__.py)

# ---- /rankify/tools/websearch/context/build_search_context.py ----
import os
import json
import subprocess

from pyserini.search.lucene import LuceneSearcher

from rankify.tools.websearch.content_scraping.crawler import WebScraper
from rankify.tools.websearch.content_scraping.scrapedResult import print_extracted_result
from typing import List, TypeVar
import tempfile
T = TypeVar('T')


async def build_search_context(sources: T) -> T:
    """Scrape the Wikipedia hits of a search result and attach their HTML.

    Args:
        sources: A ``SerpResult``-like object whose ``data['organic']`` is a
            list of dicts with a ``'link'`` key, or ``None``.

    Returns:
        The organic hits whose link contains ``wikipedia.org`` and that the
        scraper returned a result for. Each is mutated in place to carry an
        ``'html'`` key. Returns an empty list when ``sources`` is ``None`` or
        carries no data (e.g. a failed search, where ``data`` is ``None``).
    """
    if sources is None:
        return []

    # A failed SerpResult has data=None; treat it as "no results" instead of
    # crashing with TypeError when subscripting None.
    data = getattr(sources, 'data', None) or {}
    organic = data.get('organic') or []

    wiki_sources = [s for s in organic if 'wikipedia.org' in s.get('link', '')]
    urls = [s['link'] for s in wiki_sources]
    print(f"Filtered Wikipedia URLs: {urls}")

    scraper = WebScraper(strategies=['no_extraction'])
    multi_results = await scraper.scrape_many(urls)

    sources_with_html: List[T] = []
    for source in wiki_sources:
        url = source['link']
        if url in multi_results:
            source['html'] = multi_results[url]['no_extraction'].content
            sources_with_html.append(source)

    return sources_with_html


# ---- /rankify/tools/websearch/models/SerpResults.py ----
from typing import Generic, TypeVar, Optional
from dataclasses import dataclass, field
T = TypeVar('T')


@dataclass
class SerpResult(Generic[T]):
    """Outcome of a search-API call: a ``data`` payload or an ``error``.

    ``data`` and ``error`` are declared as dataclass fields so that the
    generated ``__eq__`` compares payloads. With the former hand-written
    ``__init__``, the dataclass had no fields and any two instances
    compared equal.
    """
    data: Optional[T] = None
    error: Optional[T] = None
    # Derived in __post_init__: True when no error was supplied.
    success: bool = field(init=False)

    def __post_init__(self):
        self.success = self.error is None

    @property
    def is_success(self):
        """True when the call produced no error."""
        return self.success

    def __repr__(self):
        return f'SerpResult(data={self.data!r}, error={self.error!r})'


# ---- /rankify/tools/websearch/models/Source.py ----
"""
Source model to retain the data of the retrieved documents, with runtime validation.
It captures various attributes of the sources that are needed to enhance the reranker.

"""

from pydantic import BaseModel
from dataclasses import dataclass
from typing import Optional


class Source(BaseModel):
    """A retrieved web document plus metadata used by the reranker.

    This is a plain pydantic model. It was previously also decorated with the
    stdlib ``@dataclass``, whose generated ``__init__``/``__eq__`` replace
    pydantic's own and conflict with pydantic's field handling.
    """
    link: str
    html: str
    author: Optional[str] = None
    published_date: Optional[str] = None
    credibility_score: float = 0.0
    ref_len: int = 0
    html_content_len: int = 0

    def __repr__(self):
        return f"Source : ({self.link} \n{self.html} \n{self.author} \n{self.published_date} \n{self.credibility_score} \n{self.ref_len} \n{self.html_content_len} \n)"


# ---- /rankify/tools/websearch/serp/SearchAPIClient.py ----
import asyncio
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any
from loguru import logger
from rankify.tools.websearch.models.SerpResults import SerpResult


class SearchAPIClient(ABC):
    """
    Abstract Search API client.
    """
    @abstractmethod
    def search_web(self, query: str, num_results: int = 10, file_path: Optional[str] = None) -> SerpResult[Dict[str, Any]]:
        """
        Search the web and wrap the outcome in a ``SerpResult``.

        NOTE(review): the concrete ``SerperApiClient`` names the third
        parameter ``search_location``; the name here is kept as-is for
        compatibility, so pass it positionally.
        """
        pass


# ---- /rankify/tools/websearch/serp/SerperApiClient.py ----
import requests
from rankify.tools.websearch.models.SerpResults import SerpResult
from rankify.tools.websearch.serp.SearchAPIClient import SearchAPIClient
from typing import Optional, Dict, Any, List

from rankify.tools.websearch.serp.config import SerpConfig
from rankify.tools.websearch.serp.errors import SerperAPIException


class SerperApiClient(SearchAPIClient):
    """Search client that POSTs queries to the configured Serper endpoint."""

    def __init__(self, api_key: Optional[str] = None, config: Optional[SerpConfig] = None):
        # Precedence: explicit api_key > explicit config > environment (.env).
        if api_key:
            self.config = SerpConfig(api_key=api_key)
        elif config:
            self.config = config
        else:
            self.config = SerpConfig.load_env_vars()
        self.header = {
            'Accept': 'application/json',
            'X-API-Key': self.config.api_key,
        }

    def search_web(self, query: str, num_results: int = 10, search_location: Optional[str] = None) -> SerpResult[
        Dict[str, Any]]:
        """Run ``query`` and return a ``SerpResult``.

        Errors are never raised. They are returned as
        ``SerpResult(error=...)``, including an empty or ``None`` query,
        which previously returned a bare exception object that callers
        could not call ``.is_success`` on.
        """
        if not query or not query.strip():
            return SerpResult(error="query is required")
        try:
            payload = {
                'q': query,
                'numResults': num_results,
                'gl': search_location or self.config.default_location
            }
            response = requests.post(
                url=self.config.api_url,
                headers=self.header,
                json=payload,
                timeout=self.config.timeout,
            )
            response.raise_for_status()

            data = response.json()

            results = {
                'organic': self.extract_fields(data.get('organic', []), ['title', 'link', 'snippet', 'date']),
                'topStories': self.extract_fields(data.get('topStories', []), ['title', 'imageUrl']),
                # Only the first six images are kept.
                'images': self.extract_fields(data.get('images', [])[:6], ['title', 'imageUrl']),
                'answerBox': data.get('answerBox'),
                'peopleAlsoAsk': data.get('peopleAlsoAsk'),
                'relatedSearches': data.get('relatedSearches')
            }

            return SerpResult(data=results)

        except requests.RequestException as e:
            return SerpResult(error=f"Serper API request failed {str(e)}")
        except Exception as e:
            return SerpResult(error=f" Unexcpected error {str(e)}")

    @staticmethod
    def extract_fields(items: List[Dict[str, Any]], field: List[str]) -> List[Dict[str, Any]]:
        """Project each item onto the keys in ``field`` that it actually has."""
        return [{key: item.get(key, "") for key in field if key in item} for item in items]


# ---- /rankify/tools/websearch/serp/config.py ----
"""
modular search results page API configuration.
Allows configuration of various SERPER API clients, thorough multiple constructor methods e.g. .env file.

"""

import os
from dotenv import load_dotenv
from dataclasses import dataclass, field
from loguru import logger
from .errors import SearchAPIException, SerperAPIException
load_dotenv()


@dataclass
class SerpConfig:
    """
    Configures the SERP API, which handles general configurations api_key, ...etc.
    Allowing multiple constructors of Serp configurations e.g., loading env variables.

    Declared with real dataclass fields (same constructor signature as the
    former hand-written ``__init__``) so equality compares settings. The key
    is excluded from ``repr`` to keep it out of logs.
    """
    api_key: str = field(repr=False)
    api_url: str = "https://google.serper.dev/search"
    serp_api_client: str = 'SERPER'
    default_location: str = "us"
    timeout: int = 10

    @classmethod
    def load_env_vars(cls, serp_api_client: str = 'SERPER') -> 'SerpConfig':
        """
        Build a config from ``<serp_api_client>_API_KEY`` and
        ``<serp_api_client>_API_URL`` environment variables.

        Raises:
            SerperAPIException: if either variable is missing or empty.
        """
        logger.info('Loading SERP API Configuring from .env file')
        key_var = serp_api_client + '_API_KEY'
        api_key = os.getenv(key_var)
        if not api_key:
            raise SerperAPIException(f'⛔ {key_var} not found in env file.')
        url_var = serp_api_client + '_API_URL'
        api_url = os.getenv(url_var)
        if not api_url:
            raise SerperAPIException(f'⛔ {url_var} not found in env file.')
        # Forward serp_api_client; it was previously dropped (always 'SERPER').
        return cls(api_key=api_key, api_url=api_url, serp_api_client=serp_api_client)


# ---- /rankify/tools/websearch/serp/errors.py ----
from typing import Dict, Any

from loguru import logger

from rankify.tools.websearch.models.SerpResults import SerpResult


class SearchAPIException(Exception):
    """
    Exception raised for errors in search API.
    """
    def __init__(self, error):
        # Log once, at construction. logger.error rather than
        # logger.exception: these are typically raised outside an `except`
        # block, where there is no active traceback to attach.
        logger.error(str(error))
        super().__init__(error)


class SerperAPIException(SearchAPIException):
    """
    Exception raised for errors in serper API client.

    Inherits ``__init__``; the former override logged a second time before
    delegating to the base class, which logs as well.
    """


# ---- /rankify/utils/__init__.py ----
# (empty file; placeholder:
#  https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/__init__.py)

# ---- /rankify/utils/api/__init__.py ----
# (empty file; placeholder:
#  https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/api/__init__.py)

# ---- /rankify/utils/api/claudeclient.py ----
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT


class ClaudeClient:
    """Thin wrapper around an ``anthropic.Anthropic`` client."""

    def __init__(self, keys, url):
        self.anthropic = Anthropic(api_key=keys, base_url=url)

    def chat(self, messages, return_text=True, max_tokens=300, *args, **kwargs):
        """Send OpenAI-style ``messages`` through the Messages API.

        Turns with role ``'system'`` are joined with spaces and sent as the
        ``system`` parameter. The remaining turns are sent as ``messages``.
        Returns the first content block's text when ``return_text`` is true,
        otherwise the raw response.
        """
        system = ' '.join([turn['content'] for turn in messages if turn['role'] == 'system'])
        messages = [turn for turn in messages if turn['role'] != 'system']
        if system:
            # Only pass `system` when present; an explicit None would
            # presumably be sent as null rather than omitted.
            kwargs['system'] = system
        completion = self.anthropic.beta.messages.create(messages=messages, max_tokens=max_tokens, *args, **kwargs)
        if return_text:
            completion = completion.content[0].text
        return completion

    def text(self, max_tokens=None, return_text=True, *args, **kwargs):
        """Call the legacy Text Completions API.

        ``max_tokens_to_sample`` and the ``.completion`` result attribute
        belong to ``completions.create``; they were previously sent to the
        Messages API, which accepts neither.
        """
        completion = self.anthropic.completions.create(max_tokens_to_sample=max_tokens, *args, **kwargs)
        if return_text:
            completion = completion.completion
        return completion
# ---- /rankify/utils/api/litellmclient.py ----

from litellm import completion


class LitellmClient:
    """Minimal wrapper over ``litellm.completion`` with a fixed API key."""
    # https://github.com/BerriAI/litellm

    def __init__(self, keys=None):
        self.api_key = keys

    def chat(self, return_text=True, *args, **kwargs):
        """Forward all arguments to ``litellm.completion``.

        Returns the first choice's message content when ``return_text`` is
        true, otherwise the raw response.
        """
        response = completion(api_key=self.api_key, *args, **kwargs)
        if return_text:
            response = response.choices[0].message.content
        return response


# ---- /rankify/utils/api/openaiclient.py ----

from openai import OpenAI
import openai
import time

# Sentinel returned when the request overflowed the model context window.
_REDUCE_LENGTH = 'ERROR::reduce_length'


class OpenaiClient:
    """OpenAI client wrapper that retries failed requests until they succeed."""

    def __init__(self, keys=None, base_url=None, start_id=None, proxy=None):
        """
        Args:
            keys: One API key or a list of keys. The key at index
                ``start_id % len(keys)`` is used.
            base_url: Optional API base URL.
            start_id: Index of the key to use (default 0).
            proxy: Accepted for compatibility; not used.

        Raises:
            ValueError: if no key is given. Previously this path executed
                ``raise "<str>"``, which is itself a TypeError. An empty list
                crashed later with ZeroDivisionError.
        """
        if isinstance(keys, str):
            keys = [keys]
        if not keys:
            raise ValueError("Please provide OpenAI Key.")

        self.key = keys
        self.base_url = base_url
        self.key_id = (start_id or 0) % len(self.key)
        self.api_key = self.key[self.key_id]
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    @staticmethod
    def _call_with_retry(fn, *args, **kwargs):
        """Call ``fn`` until it succeeds, sleeping 0.1 s between attempts.

        Returns ``'ERROR::reduce_length'`` instead of retrying when the error
        message reports an exceeded context length.
        NOTE(review): any other persistent error (bad key, unknown model)
        retries forever.
        """
        while True:
            try:
                return fn(*args, **kwargs)
            except Exception as e:
                print(str(e))
                if "This model's maximum context length is" in str(e):
                    print('reduce_length')
                    return _REDUCE_LENGTH
                time.sleep(0.1)

    def chat(self, *args, return_text=False, reduce_length=False, **kwargs):
        """Chat completion with a 30 s per-request timeout.

        Returns the first choice's content when ``return_text`` is true,
        otherwise the raw response. Returns ``'ERROR::reduce_length'`` on
        context overflow.
        """
        completion = self._call_with_retry(self.client.chat.completions.create, *args, **kwargs, timeout=30)
        if isinstance(completion, str):
            return completion
        if return_text:
            completion = completion.choices[0].message.content
        return completion

    def text(self, *args, return_text=False, reduce_length=False, **kwargs):
        """Legacy text completion.

        Returns the first choice's text when ``return_text`` is true,
        otherwise the raw response. Returns ``'ERROR::reduce_length'`` on
        context overflow.
        """
        completion = self._call_with_retry(self.client.completions.create, *args, **kwargs)
        if isinstance(completion, str):
            return completion
        if return_text:
            completion = completion.choices[0].text
        return completion


# ---- /rankify/utils/dataset/__init__.py ----
# (empty file; placeholder:
#  https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/dataset/__init__.py)

# ---- /rankify/utils/dataset/download.py ----
import os
import requests
from rankify.utils.pre_defined_datasets import HF_PRE_DEFIND_DATASET
from tqdm import tqdm


class DownloadManger:
    @staticmethod
    def download(retriever: str, dataset: str, force_download: bool = True) -> str:
        """Download a pre-defined retrieval dataset into the cache.

        The file is stored at
        ``$RERANKING_CACHE_DIR/dataset/<retriever>/<name>/<filename>``, where
        ``<name>`` is ``dataset`` up to its first ``'-'``. Each configured
        URL is tried in order until one answers HTTP 200.

        Args:
            retriever: Key into ``HF_PRE_DEFIND_DATASET``.
            dataset: Dataset key under that retriever.
            force_download: When false, an existing file is reused.

        Returns:
            Path of the downloaded (or already present) file.

        Raises:
            FileNotFoundError: unknown retriever or dataset.
            KeyError: ``RERANKING_CACHE_DIR`` is not set.
            Exception: no URL answered with HTTP 200.
        """
        if retriever not in HF_PRE_DEFIND_DATASET:
            raise FileNotFoundError(f"Retriever {retriever} Not Supported yet. Please choose another retriever.\nCheck Dataset.available_dataset()")
        if dataset not in HF_PRE_DEFIND_DATASET[retriever]:
            raise FileNotFoundError(f"Dataset {dataset} Not Supported yet. Please choose another dataset.\nCheck Dataset.available_dataset()")

        entry = HF_PRE_DEFIND_DATASET[retriever][dataset]
        filename = entry['filename']
        # "nq-dev" is cached under "nq"; the split suffix is not part of the path.
        dataset_name = dataset.split('-', 1)[0]
        urls = entry['url']
        if isinstance(urls, str):
            # Guard: iterating a bare string would try each character as a URL.
            urls = [urls]
        path = os.path.join(os.environ['RERANKING_CACHE_DIR'], 'dataset', retriever, dataset_name)
        file_path = os.path.join(path, filename)

        # If force_download is False and file already exists, skip downloading
        if not force_download and os.path.exists(file_path):
            print(f"File {file_path} already exists. Skipping download.")
            return file_path

        os.makedirs(path, exist_ok=True)

        last_url = None
        for url in urls:
            last_url = url
            # `with` closes the streamed connection on every path.
            with requests.get(url, stream=True) as response:
                if response.status_code != 200:
                    # Fall through to the next mirror instead of giving up.
                    continue
                total_size = int(response.headers.get('content-length', 0))
                with open(file_path, 'wb') as file, tqdm(
                    desc=f"Downloading {retriever} {dataset_name} {filename}",
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                    total=total_size,
                ) as bar:
                    # Update progress bar while streaming chunks
                    for chunk in response.iter_content(chunk_size=1024):
                        if chunk:  # Filter out keep-alive chunks
                            file.write(chunk)
                            bar.update(len(chunk))
            return file_path

        raise Exception(f'Failed to download the file from {last_url}')


# ---- /rankify/utils/dataset/utils.py ----
from rankify.utils.pre_defined_datasets import HF_PRE_DEFIND_DATASET
import pandas as pd
from prettytable import PrettyTable


def get_datasets_info():
    """Print a table with one row per pre-defined (retriever, dataset) pair."""
    table = PrettyTable(['Retriever', 'Dataset', 'Original ext', 'Compressed', 'Desc', 'URL'])
    for retriever, datasets in HF_PRE_DEFIND_DATASET.items():
        for dataset_name, dataset_info in datasets.items():
            flattened_entry = {
                'retriever': retriever,
                'dataset': dataset_name,
                'original_ext': dataset_info.get('original_ext'),
                'compressed': dataset_info.get('compressed'),
                'desc': dataset_info.get('desc'),
                'url': dataset_info.get('url')
            }
            table.add_row(flattened_entry.values())

    print(table)


# ---- /rankify/utils/generator/FiD/__init__.py ----
# (FiD/__init__.py is empty; placeholder:
#  https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/generator/FiD/__init__.py)

# ---- /rankify/utils/generator/__init__.py ----
# (empty file; placeholder:
#  https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/generator/__init__.py)

# ---- /rankify/utils/generator/download.py ----
import os
import posixpath
import requests
import tarfile
from tqdm import tqdm


class ModelDownloader:
    """
    Utility class for downloading and extracting model files.
    """

    @staticmethod
    def download_and_extract(url, output_dir):
        """
        Downloads and extracts a model from a given URL.

        The archive is streamed to ``<output_dir>/model.tar.gz`` (left in
        place) and extracted into ``output_dir``. If the first archive member
        sits under a top-level directory, that directory's contents are
        moved up into ``output_dir`` and the directory is removed.

        Args:
            url: HTTP(S) URL of a gzip-compressed tarball.
            output_dir: Destination directory (created if missing).

        Raises:
            requests.HTTPError: the server answered with an error status.
        """
        os.makedirs(output_dir, exist_ok=True)
        tar_path = os.path.join(output_dir, "model.tar.gz")

        with requests.get(url, stream=True) as response:
            # Fail fast instead of saving an error page as model.tar.gz.
            response.raise_for_status()
            total_size = int(response.headers.get("content-length", 0))
            with open(tar_path, "wb") as file, tqdm(
                desc="Downloading Model", total=total_size, unit="B", unit_scale=True
            ) as pbar:
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:
                        file.write(chunk)
                        pbar.update(len(chunk))

        with tarfile.open(tar_path, "r:gz") as tar:
            names = tar.getnames()
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: reject absolute paths,
                # ".." members and links that escape output_dir.
                tar.extractall(output_dir, filter="data")
            else:
                # NOTE(review): interpreter predates extraction filters;
                # members are extracted unsanitised.
                tar.extractall(output_dir)

        if not names:
            return

        # Tar names always use "/". Take the first path component so that both
        # "model/config.json" and "./model" resolve to "model".
        subdir = posixpath.normpath(names[0]).split("/")[0]
        if subdir in ("", ".", ".."):
            return

        # Move contents up if they were extracted into a subdirectory
        extracted_path = os.path.join(output_dir, subdir)
        if os.path.isdir(extracted_path):
            for name in os.listdir(extracted_path):
                os.rename(os.path.join(extracted_path, name), os.path.join(output_dir, name))
            os.rmdir(extracted_path)  # Remove the now-empty subdirectory


# ---- /rankify/utils/generator/generator_models.py ----
from rankify.generator.rag_methods.basic_rag import BasicRAG
from rankify.generator.rag_methods.chain_of_thought_rag import ChainOfThoughtRAG
from rankify.generator.rag_methods.fid_rag_method import FiDRAGMethod
from rankify.generator.rag_methods.in_context_ralm_rag import InContextRALMRAG
from rankify.generator.rag_methods.zero_shot import ZeroShotRAG


# Registry: method name -> RAG method class.
RAG_METHODS = {
    "in-context-ralm": InContextRALMRAG,
    "fid": FiDRAGMethod,
    "zero-shot": ZeroShotRAG,
    "basic-rag": BasicRAG,
    "chain-of-thought-rag": ChainOfThoughtRAG,
}


# ---- /rankify/utils/generator/huggingface_models/__init__.py ----
# (empty file; placeholder:
#  https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/generator/huggingface_models/__init__.py)

# ---- /rankify/utils/generator/huggingface_models/model_utils.py ----
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
from huggingface_hub import login


def load_tokenizer(model_name):
    """Load the ``AutoTokenizer`` for ``model_name``."""
    return AutoTokenizer.from_pretrained(model_name)


def load_model(model_name, **kwargs):
    """Load a causal LM in eval mode.

    Keyword Args:
        model_parallelism: If true, load with ``device_map="auto"`` and do not
            move the model afterwards. Otherwise move it to CUDA when
            available, else CPU.
        cache_dir: Hugging Face cache directory.
        auth_token: Token forwarded as ``use_auth_token``.
        torch_dtype: Overrides the dtype from the model config.
    """
    model_parallelism = kwargs.get("model_parallelism", False)
    cache_dir = kwargs.get("cache_dir", None)
    auth_token = kwargs.get("auth_token", None)
    torch_dtype = kwargs.get("torch_dtype", None)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Options shared by the config and weight downloads, so gated or
    # custom-cached models resolve identically for both.
    hub_args = {}
    if cache_dir is not None:
        hub_args["cache_dir"] = cache_dir
    if auth_token is not None:
        hub_args["use_auth_token"] = auth_token

    config = AutoConfig.from_pretrained(model_name, **hub_args)
    model_args = dict(hub_args)
    if model_parallelism:
        model_args["device_map"] = "auto"
        model_args["low_cpu_mem_usage"] = True
    if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
        model_args["torch_dtype"] = config.torch_dtype
    if torch_dtype is not None:
        model_args["torch_dtype"] = torch_dtype  # overload dtype if user specifies

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_args).eval()
    if not model_parallelism:
        model = model.to(device)

    return model


# ---- /rankify/utils/helper.py ----

from typing import Union, List, Optional, Tuple
import torch


def get_device(
    device: Optional[Union[str, torch.device]],
    no_mps: bool = False,
) -> Union[str, torch.device]:
    """Return ``device`` if truthy, else the first available of cuda, mps (unless ``no_mps``), cpu."""
    if not device:
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available() and not no_mps:
            device = "mps"
        else:
            device = "cpu"
    return device


def get_dtype(
    dtype: Optional[Union[str, torch.dtype]],
    device: Optional[Union[str, torch.device]],
    verbose: int = 1,
) -> torch.dtype:
    """Resolve ``dtype`` to a ``torch.dtype``.

    A ``torch.dtype`` is returned unchanged. ``"fp16"``/``"float16"`` give
    float16 and ``"bf16"``/``"bfloat16"`` give bfloat16. Anything else,
    including ``None``, gives float32.

    The previous ``dtype == "fp16" or "float16"`` was always truthy, so
    every string and ``None`` on non-CPU devices became float16.
    """
    if dtype is None:
        if verbose:
            print("No dtype set")
        if device == "cpu":
            dtype = torch.float32
    if not isinstance(dtype, torch.dtype):
        if dtype in ("fp16", "float16"):
            dtype = torch.float16
        elif dtype in ("bf16", "bfloat16"):
            dtype = torch.bfloat16
        else:
            dtype = torch.float32
    return dtype
# ---- /rankify/utils/models/__init__.py ----
# (empty file; placeholder:
#  https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/models/__init__.py)

# ---- /rankify/utils/models/incontext_reranker/__init__.py ----
# (empty file; placeholder:
#  https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/models/incontext_reranker/__init__.py)

# ---- /rankify/utils/models/llm2vec_model/__init__.py ----
# Re-export the bidirectional model variants.
from .bidirectional_mistral import MistralBiModel, MistralBiForMNTP
from .bidirectional_llama import LlamaBiModel, LlamaBiForMNTP
from .bidirectional_gemma import GemmaBiModel, GemmaBiForMNTP
from .bidirectional_qwen2 import Qwen2BiModel, Qwen2BiForMNTP

# ---- /rankify/utils/models/llm2vec_model/utils.py ----
import importlib.metadata
from packaging import version
from transformers.utils.import_utils import _is_package_available


def is_transformers_attn_greater_or_equal_4_43_1():
    """Report whether the installed ``transformers`` is at least 4.43.1.

    Returns ``False`` when ``transformers`` is not available.
    """
    if _is_package_available("transformers"):
        installed = version.parse(importlib.metadata.version("transformers"))
        return installed >= version.parse("4.43.1")
    return False

# ---- /rankify/utils/models/rank_llm/__init__.py ----
# (rank_llm/__init__.py is empty; placeholder:
#  https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/models/rank_llm/__init__.py)

# ---- /rankify/utils/models/rank_llm/rerank/__init__.py ----
import logging



from .api_keys import get_azure_openai_args, get_openai_api_key
from .identity_reranker import IdentityReranker
from .rankllm import PromptMode, RankLLM
from .reranker import Reranker

# NOTE(review): configuring the root logger on package import affects the
# whole host application; consider leaving this to the caller.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

__all__ = [
    "IdentityReranker",
    "RankLLM",
    "get_azure_openai_args",
    "get_openai_api_key",
    "PromptMode",
    "Reranker",
]


# ---- /rankify/utils/models/rank_llm/rerank/api_keys.py ----
import os
from typing import Dict, Optional

from dotenv import load_dotenv
from tqdm import tqdm

# Intended to disable tqdm progress bars globally.
# NOTE(review): this only sets a class attribute; tqdm instances receive their
# own `disable` argument, so this likely has no effect. Verify.
tqdm.disable = True
# Common OpenAI API key paths
paths = [
    "OPENAI_API_KEY",
    "OPEN_AI_API_KEY",
]


def get_openai_api_key() -> Optional[str]:
    """Return the first OpenAI key variable that is set, or ``None``.

    ``.env.local`` is loaded first. The names in ``paths`` are checked in order.
    """
    load_dotenv(dotenv_path=".env.local")

    for name in paths:
        value = os.getenv(name)
        if value is not None:
            return value
    return None


def get_litellm_api_key() -> str:
    """Return ``LITELLM_API_KEY`` after loading ``.env.local``.

    Raises:
        ValueError: the variable is not set.
    """
    load_dotenv(dotenv_path=".env.local")
    litellm_api_key = os.getenv("LITELLM_API_KEY")

    if litellm_api_key is None:
        raise ValueError(
            "LITELLM_API_KEY not found in environment variables. Please set it."
        )
    return litellm_api_key


def get_azure_openai_args() -> Dict[str, str]:
    """Return Azure OpenAI connection settings from the environment.

    Raises:
        AssertionError: ``AZURE_OPENAI_API_VERSION`` or
            ``AZURE_OPENAI_API_BASE`` is missing or empty. This is now an
            explicit raise, so the check survives ``python -O``; the
            exception type is unchanged.
    """
    load_dotenv(dotenv_path=".env.local")
    azure_args = {
        "api_type": "azure",
        "api_version": os.getenv("AZURE_OPENAI_API_VERSION"),
        "api_base": os.getenv("AZURE_OPENAI_API_BASE"),
    }

    # Sanity check
    if not all(azure_args.values()):
        raise AssertionError(
            "Ensure that `AZURE_OPENAI_API_BASE`, `AZURE_OPENAI_API_VERSION` are set"
        )
    return azure_args


# ---- /rankify/utils/models/rank_llm/rerank/identity_reranker.py ----
import copy
import random
from typing import List

from rankify.utils.models.rank_llm.data import Request, Result
from datetime import datetime
from typing import Any, List


from tqdm import tqdm

# Intended to disable tqdm progress bars globally.
# NOTE(review): class-attribute assignment; likely ineffective, see api_keys.py.
tqdm.disable = True


class IdentityReranker:
    def rerank_batch(
        self,
        requests: List[Request],
        rank_start: int = 0,
        rank_end: int = 100,
        shuffle_candidates: bool = False,
        logging: bool = False,
        **kwargs: Any,
    ) -> List[Result]:
        """
        A trivial reranker that returns the retrieved candidates as-is, or
        with the window ``[rank_start:rank_end]`` randomly shuffled.

        Args:
            requests (List[Request]): The list of requests. Each request has a query and a candidates list.
            rank_start (int, optional): Start of the shuffle window. Defaults to 0.
            rank_end (int, optional): End (exclusive) of the shuffle window. Defaults to 100.
            shuffle_candidates (bool, optional): Whether to shuffle the window. Defaults to False.

        Returns:
            List[Result]: One result per request, holding deep copies of the
            query and the full candidates list. Only the window is shuffled;
            candidates outside it are kept.
        """
        results = []
        for request in requests:
            rerank_result = Result(
                query=copy.deepcopy(request.query),
                candidates=copy.deepcopy(request.candidates),
                ranking_exec_summary=[],
            )
            if shuffle_candidates:
                # Randomly shuffle rerank_result between rank_start and rank_end
                window = rerank_result.candidates[rank_start:rank_end]
                rerank_result.candidates[rank_start:rank_end] = random.sample(window, len(window))
            results.append(rerank_result)
        return results

    def get_name(self) -> str:
        return "identity_reranker"

    def get_output_filename(
        self,
        top_k_candidates: int,
        dataset_name: str,
        shuffle_candidates: bool,
        **kwargs: Any,
    ) -> str:
        """Return ``identity_<ISO timestamp of now>``; the arguments are ignored."""
        return f"identity_{datetime.isoformat(datetime.now())}"


# ---- /rankify/utils/models/rank_llm/rerank/listwise/__init__.py ----
#from .rank_gpt import SafeOpenai
from .rank_listwise_os_llm import RankListwiseOSLLM
#from .vicuna_reranker import VicunaReranker
#from .zephyr_reranker import ZephyrReranker

# Export only what is imported above. Listing the commented-out classes made
# `from ...listwise import *` raise AttributeError.
__all__ = [
    "RankListwiseOSLLM",
]


# ---- /rankify/utils/models/rank_llm/rerank/listwise/lit5/__init__.py ----
# (empty file; placeholder:
#  https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/models/rank_llm/rerank/listwise/lit5/__init__.py)

# ---- /rankify/utils/pre_defined_methods_retrievers.py ----
from rankify.retrievers.OnlineRetriever import OnlineRetriever
from rankify.retrievers.dpr import DenseRetriever
from rankify.retrievers.bm25 import BM25Retriever
from rankify.retrievers.contriever import ContrieverRetriever
from rankify.retrievers.BGERetriever import BGERetriever
from rankify.retrievers.colbert import ColBERTRetriever
from rankify.retrievers.hyde import HydeRetreiver


# Registry: retriever method name -> retriever class
# (dpr, bpr and ance all share DenseRetriever).
METHOD_MAP = {
    'bm25': BM25Retriever,
    'dpr': DenseRetriever,
    'bpr': DenseRetriever,
    'contriever': ContrieverRetriever,
    'ance': DenseRetriever,
    'bge': BGERetriever,
    'colbert': ColBERTRetriever,
    'hyde': HydeRetreiver,
    'online_retriever': OnlineRetriever,
}


# ---- /rankify/utils/retrievers/__init__.py ----
# (empty file; placeholder:
#  https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/__init__.py)

# ---- /rankify/utils/retrievers/colbert/__init__.py ----
# (empty file; placeholder:
#  https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/__init__.py)

# ---- /rankify/utils/retrievers/colbert/colbert/__init__.py ----
from .indexer import Indexer
from .searcher import Searcher
from .index_updater import IndexUpdater

from .modeling.checkpoint import Checkpoint

# ---- /rankify/utils/retrievers/colbert/colbert/data/__init__.py ----
# (body of /rankify/utils/retrievers/colbert/colbert/data/__init__.py)
from .collection import *
from .queries import *

from .ranking import *
from .examples import *


# ---- /rankify/utils/retrievers/colbert/colbert/data/dataset.py ----


# Not just the corpus, but also an arbitrary number of query sets, indexed by name in a dictionary/dotdict.
# And also query sets with top-k PIDs.
# QAs too? TripleSets too?


class Dataset:
    """Placeholder; not implemented yet."""

    def __init__(self):
        """No state yet."""

    def select(self, key):
        """Not implemented; returns ``None``."""
        # Select the {corpus, queryset, tripleset, rankingset} determined by uniqueness or by key and return a "unique" dataset (e.g., for key=train)


# ---- /rankify/utils/retrievers/colbert/colbert/distillation/__init__.py ----
# (empty file; placeholder:
#  https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/colbert/distillation/__init__.py)

# ---- /rankify/utils/retrievers/colbert/colbert/distillation/ranking_scorer.py ----
import tqdm
import ujson

from collections import defaultdict

from colbert.utils.utils import print_message, zipstar
from utility.utils.save_metadata import get_metadata_only

from colbert.infra import Run
from colbert.data import Ranking
from colbert.infra.provenance import Provenance
from colbert.distillation.scorer import Scorer


class RankingScorer:
    """Score every (qid, pid) pair of a ranking and write the scores to disk."""

    def __init__(self, scorer: Scorer, ranking: Ranking):
        self.scorer = scorer
        self.ranking = ranking.tolist()
        self.__provenance = Provenance()

        print_message(f"#> Loaded ranking with {len(self.ranking)} qid--pid pairs!")

    def provenance(self):
        return self.__provenance

    def run(self):
        """Score all pairs and write them out.

        Writes one JSON line ``[qid, [[score, pid], ...]]`` per query to
        ``distillation_scores.json``, plus a ``.meta`` file. Returns the path
        of the scores file.
        """
        print_message(f"#> Starting..")

        qids, pids, *_ = zipstar(self.ranking)
        distillation_scores = self.scorer.launch(qids, pids)

        grouped = defaultdict(list)
        for qid, pid, score in tqdm.tqdm(zip(qids, pids, distillation_scores)):
            grouped[qid].append((score, pid))

        with Run().open('distillation_scores.json', 'w') as scores_file:
            for qid in tqdm.tqdm(grouped):
                scores_file.write(ujson.dumps((qid, grouped[qid])) + '\n')
            output_path = scores_file.name

        print_message(f'#> Saved the distillation_scores to {output_path}')

        with Run().open(f'{output_path}.meta', 'w') as meta_file:
            meta = {'metadata': get_metadata_only(), 'provenance': self.provenance()}
            meta_file.write(ujson.dumps(meta, indent=4))

        return output_path


# ---- /rankify/utils/retrievers/colbert/colbert/distillation/scorer.py ----
import torch
import tqdm

from transformers import AutoTokenizer, AutoModelForSequenceClassification

from colbert.infra.launcher import Launcher
from colbert.infra import Run, RunConfig
from colbert.modeling.reranker.electra import ElectraReranker
from colbert.utils.utils import flatten


DEFAULT_MODEL = 'cross-encoder/ms-marco-MiniLM-L-6-v2'


class Scorer:
    """Score (query, passage) pairs with a sequence-classification model."""

    def __init__(self, queries, collection, model=DEFAULT_MODEL, maxlen=180, bsize=256):
        self.queries = queries
        self.collection = collection
        self.model = model

        self.maxlen = maxlen
        self.bsize = bsize

    def launch(self, qids, pids):
        """Score aligned ``qids``/``pids`` across ranks; return one flat list of scores."""
        launcher = Launcher(self._score_pairs_process, return_all=True)
        return flatten(launcher.launch(Run().config, qids, pids))

    def _score_pairs_process(self, config, qids, pids):
        """Score this rank's contiguous slice of the pairs (rank 0 shows progress)."""
        assert len(qids) == len(pids), (len(qids), len(pids))
        share = 1 + len(qids) // config.nranks
        start, stop = config.rank * share, (config.rank + 1) * share

        return self._score_pairs(qids[start:stop], pids[start:stop], show_progress=(config.rank < 1))

    def _score_pairs(self, qids, pids, show_progress=False):
        """Score pairs in batches of ``bsize`` on CUDA and return one float per pair."""
        tokenizer = AutoTokenizer.from_pretrained(self.model)
        model = AutoModelForSequenceClassification.from_pretrained(self.model).cuda()

        assert len(qids) == len(pids), (len(qids), len(pids))

        batch_scores = []

        model.eval()
        with torch.inference_mode(), torch.cuda.amp.autocast():
            for start in tqdm.tqdm(range(0, len(qids), self.bsize), disable=not show_progress):
                stop = start + self.bsize

                queries_ = [self.queries[qid] for qid in qids[start:stop]]
                passages_ = [self.collection[pid] for pid in pids[start:stop]]

                features = tokenizer(queries_, passages_, padding='longest', truncation=True,
                                     return_tensors='pt', max_length=self.maxlen).to(model.device)

                batch_scores.append(model(**features).logits.flatten())

        scores = torch.cat(batch_scores).tolist()

        Run().print(f'Returning with {len(scores)} scores')

        return scores


# LONG-TERM TODO: This can be sped up by sorting by length in advance.
69 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/colbert/evaluation/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/evaluation/load_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ujson 3 | import torch 4 | import random 5 | 6 | from collections import defaultdict, OrderedDict 7 | 8 | from rankify.utils.retrievers.colbert.colbert.parameters import DEVICE 9 | from rankify.utils.retrievers.colbert.colbert.modeling.colbert import ColBERT 10 | from rankify.utils.retrievers.colbert.colbert.utils.utils import print_message, load_checkpoint 11 | 12 | 13 | def load_model(args, do_print=True):  # Build ColBERT on 'bert-base-uncased' with the sizes/options in args, load args.checkpoint, switch to eval mode; returns (model, checkpoint). 14 | colbert = ColBERT.from_pretrained('bert-base-uncased',  # NOTE(review): confirm these keyword arguments match ColBERT.from_pretrained in modeling/colbert.py 15 | query_maxlen=args.query_maxlen, 16 | doc_maxlen=args.doc_maxlen, 17 | dim=args.dim, 18 | similarity_metric=args.similarity, 19 | mask_punctuation=args.mask_punctuation) 20 | colbert = colbert.to(DEVICE) 21 | 22 | print_message("#> Loading model checkpoint.", condition=do_print) 23 | 24 | checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)  # presumably restores the saved weights into colbert in place 25 | 26 | colbert.eval() 27 | 28 | return colbert, checkpoint 29 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/index.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # TODO: This is the loaded index, underneath a searcher.
4 | 5 | 6 | """ 7 | ## Operations: 8 | 9 | index = Index(index='/path/to/index') 10 | index.load_to_memory() 11 | 12 | batch_of_pids = [2324,32432,98743,23432] 13 | index.lookup(batch_of_pids, device='cuda:0') -> (N, doc_maxlen, dim) 14 | 15 | index.iterate_over_parts() 16 | 17 | """  # design sketch only: this module defines no Index class 18 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/indexing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/colbert/indexing/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/indexing/codecs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/colbert/indexing/codecs/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/indexing/codecs/decompress_residuals.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | torch::Tensor decompress_residuals_cuda(  // declaration only; defined in a separate CUDA source (presumably decompress_residuals.cu) 4 | const torch::Tensor binary_residuals, const torch::Tensor bucket_weights, 5 | const torch::Tensor reversed_bit_map, 6 | const torch::Tensor bucket_weight_combinations, const torch::Tensor codes, 7 | const torch::Tensor centroids, const int dim, const int nbits); 8 | 9 | torch::Tensor decompress_residuals(  // host entry point: forwards every argument unchanged to decompress_residuals_cuda 10 | const torch::Tensor binary_residuals, const torch::Tensor bucket_weights, 11 | const torch::Tensor reversed_bit_map, 12 | const torch::Tensor bucket_weight_combinations, const torch::Tensor codes, 13 | const torch::Tensor centroids, const int dim, const int nbits) { 14 | // TODO: Add input verification (no shape, dtype or device checks are done before the CUDA call)
15 | return decompress_residuals_cuda( 16 | binary_residuals, bucket_weights, reversed_bit_map, 17 | bucket_weight_combinations, codes, centroids, dim, nbits); 18 | } 19 | 20 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {  // exposes decompress_residuals to Python as decompress_residuals_cpp 21 | m.def("decompress_residuals_cpp", &decompress_residuals, 22 | "Decompress residuals"); 23 | } 24 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/indexing/codecs/packbits.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | torch::Tensor packbits_cuda(const torch::Tensor residuals); 4 | 5 | torch::Tensor packbits(const torch::Tensor residuals) {  // forwards to packbits_cuda, defined in packbits.cu 6 | return packbits_cuda(residuals); 7 | } 8 | 9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 10 | m.def("packbits_cpp", &packbits, "Pack bits"); 11 | } 12 | 13 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/indexing/codecs/packbits.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #define FULL_MASK 0xffffffff 9 | 10 | __global__ void packbits_kernel(  // each 32-thread block packs 32 consecutive residual bytes (zero vs non-zero) into 4 output bytes 11 | const uint8_t* residuals, 12 | uint8_t* packed_residuals, 13 | const int residuals_size) { 14 | const int i = blockIdx.x; 15 | const int j = threadIdx.x; 16 | 17 | assert(blockDim.x == 32); 18 | 19 | const int residuals_idx = i * blockDim.x + j; 20 | if (residuals_idx >= residuals_size) {  // not expected to be taken when launched by packbits_cuda (size is a multiple of 32, size / 32 blocks); __ballot_sync(FULL_MASK) below needs all 32 lanes 21 | return; 22 | } 23 | 24 | const int packed_residuals_idx = residuals_idx / 8; 25 | 26 | 27 | uint32_t mask = __ballot_sync(FULL_MASK, residuals[residuals_idx]);  // bit k is set iff lane k's residual byte is non-zero 28 | 29 | mask = __brev(mask);  // reverse bits so lane 0 becomes the most significant bit 30 | 31 | if (residuals_idx % 32 == 0) {  // lane 0 writes the warp's 4 bytes, most significant byte first 32 | for (int k = 0; k < 4; k++) { 33 | packed_residuals[packed_residuals_idx + k] = 34 | (mask >> (8 * (4 - k - 1))) & 0xff; 35 | } 36 | } 37 | } 38 | 39 | torch::Tensor packbits_cuda(const torch::Tensor
residuals) {  // Packs residual bytes (non-zero counts as 1; presumably a 1-D uint8 tensor) into size(0) / 8 bytes on the same CUDA device. 40 | auto options = torch::TensorOptions() 41 | .dtype(torch::kUInt8) 42 | .device(torch::kCUDA, residuals.device().index()) 43 | .requires_grad(false); 44 | assert(residuals.size(0) % 32 == 0);  // host-side assert (compiled out under NDEBUG); the kernel assumes whole groups of 32 45 | torch::Tensor packed_residuals = torch::zeros({int(residuals.size(0) / 8)}, options); 46 | 47 | const int threads = 32; 48 | const int blocks = std::ceil(residuals.size(0) / (float) threads); 49 | 50 | packbits_kernel<<>>( 51 | residuals.data(), 52 | packed_residuals.data(), 53 | residuals.size(0) 54 | ); 55 | 56 | return packed_residuals; 57 | } 58 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/indexing/codecs/residual_embeddings_strided.py: -------------------------------------------------------------------------------- 1 | # from colbert.indexing.codecs.residual import ResidualCodec 2 | import rankify.utils.retrievers.colbert.colbert.indexing.codecs.residual_embeddings as residual_embeddings 3 | 4 | from rankify.utils.retrievers.colbert.colbert.search.strided_tensor import StridedTensor 5 | 6 | class ResidualEmbeddingsStrided:  # Wraps the codec's codes and residuals in StridedTensors split by doclens, for lookup by passage id. 7 | def __init__(self, codec, embeddings, doclens): 8 | self.codec = codec 9 | self.codes = embeddings.codes 10 | self.residuals = embeddings.residuals 11 | self.use_gpu = self.codec.use_gpu 12 | 13 | self.codes_strided = StridedTensor(self.codes, doclens, use_gpu=self.use_gpu) 14 | self.residuals_strided = StridedTensor(self.residuals, doclens, use_gpu=self.use_gpu) 15 | 16 | def lookup_pids(self, passage_ids, out_device='cuda'):  # returns (decompressed embeddings, codes_lengths from the strided lookup); NOTE: out_device is currently unused 17 | codes_packed, codes_lengths = self.codes_strided.lookup(passage_ids)#.as_packed_tensor() 18 | residuals_packed, _ = self.residuals_strided.lookup(passage_ids)#.as_packed_tensor() 19 | 20 | embeddings_packed = self.codec.decompress(residual_embeddings.ResidualEmbeddings(codes_packed, residuals_packed)) 21 | 22 | return embeddings_packed, codes_lengths 23 | 24 | def lookup_codes(self, passage_ids): 25 | return
self.codes_strided.lookup(passage_ids)#.as_packed_tensor() 26 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/indexing/collection_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from rankify.utils.retrievers.colbert.colbert.infra.run import Run 4 | from rankify.utils.retrievers.colbert.colbert.utils.utils import print_message, batch 5 | 6 | 7 | class CollectionEncoder:  # Encodes raw passages into one flat tensor of token embeddings using the given checkpoint. 8 | def __init__(self, config, checkpoint): 9 | self.config = config 10 | self.checkpoint = checkpoint 11 | self.use_gpu = self.config.total_visible_gpus > 0 12 | 13 | def encode_passages(self, passages):  # Returns (embs, doclens), or (None, None) for empty input; doclens presumably holds one token count per passage. 14 | Run().print(f"#> Encoding {len(passages)} passages..") 15 | 16 | if len(passages) == 0: 17 | return None, None 18 | 19 | with torch.inference_mode(): 20 | embs, doclens = [], [] 21 | 22 | # Batch here to avoid OOM from storing intermediate embeddings on GPU. 23 | # Storing on the GPU helps with speed of masking, etc. 24 | # But ideally this batching happens internally inside docFromText.
25 | for passages_batch in batch(passages, self.config.index_bsize * 50): 26 | embs_, doclens_ = self.checkpoint.docFromText( 27 | passages_batch, 28 | bsize=self.config.index_bsize, 29 | keep_dims="flatten", 30 | showprogress=(not self.use_gpu), 31 | pool_factor=self.config.pool_factor, 32 | clustering_mode=self.config.clustering_mode, 33 | protected_tokens=self.config.protected_tokens, 34 | ) 35 | embs.append(embs_) 36 | doclens.extend(doclens_) 37 | 38 | embs = torch.cat(embs) 39 | 40 | # embs, doclens = self.checkpoint.docFromText(passages, bsize=self.config.index_bsize, 41 | # keep_dims='flatten', showprogress=(self.config.rank < 1)) 42 | 43 | # with torch.inference_mode(): 44 | # embs = self.checkpoint.docFromText(passages, bsize=self.config.index_bsize, 45 | # keep_dims=False, showprogress=(self.config.rank < 1)) 46 | # assert type(embs) is list 47 | # assert len(embs) == len(passages) 48 | 49 | # doclens = [d.size(0) for d in embs] 50 | # embs = torch.cat(embs) 51 | 52 | return embs, doclens 53 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/indexing/index_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from bitarray import bitarray 5 | 6 | 7 | class IndexManager(): 8 | def __init__(self, dim): 9 | self.dim = dim 10 | 11 | def save(self, tensor, path_prefix): 12 | torch.save(tensor, path_prefix) 13 | 14 | def save_bitarray(self, bitarray, path_prefix): 15 | with open(path_prefix, "wb") as f: 16 | bitarray.tofile(f) 17 | 18 | 19 | def load_index_part(filename, verbose=True): 20 | part = torch.load(filename) 21 | 22 | if type(part) == list: # for backward compatibility 23 | part = torch.cat(part) 24 | 25 | return part 26 | 27 | 28 | def load_compressed_index_part(filename, dim, bits): 29 | a = bitarray() 30 | 31 | with open(filename, "rb") as f: 32 | a.fromfile(f) 33 | 34 | n = len(a) // dim 
// bits 35 | part = torch.tensor(np.frombuffer(a.tobytes(), dtype=np.uint8)) # TODO: isn't from_numpy(.) faster? 36 | part = part.reshape((n, int(np.ceil(dim * bits / 8)))) 37 | 38 | return part 39 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/indexing/loaders.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import ujson 4 | 5 | 6 | def get_parts(directory): 7 | extension = '.pt' 8 | 9 | parts = sorted([int(filename[: -1 * len(extension)]) for filename in os.listdir(directory) 10 | if filename.endswith(extension)]) 11 | 12 | assert list(range(len(parts))) == parts, parts 13 | 14 | # Integer-sortedness matters. 15 | parts_paths = [os.path.join(directory, '{}{}'.format(filename, extension)) for filename in parts] 16 | samples_paths = [os.path.join(directory, '{}.sample'.format(filename)) for filename in parts] 17 | 18 | return parts, parts_paths, samples_paths 19 | 20 | 21 | def load_doclens(directory, flatten=True): 22 | doclens_filenames = {} 23 | 24 | for filename in os.listdir(directory): 25 | match = re.match(r"doclens.(\d+).json", filename) 26 | 27 | if match is not None: 28 | doclens_filenames[int(match.group(1))] = filename 29 | 30 | doclens_filenames = [os.path.join(directory, doclens_filenames[i]) for i in sorted(doclens_filenames.keys())] 31 | 32 | all_doclens = [ujson.load(open(filename)) for filename in doclens_filenames] 33 | 34 | if flatten: 35 | all_doclens = [x for sub_doclens in all_doclens for x in sub_doclens] 36 | 37 | if len(all_doclens) == 0: 38 | raise ValueError("Could not load doclens") 39 | 40 | return all_doclens 41 | 42 | 43 | def get_deltas(directory): 44 | extension = '.residuals.pt' 45 | 46 | parts = sorted([int(filename[: -1 * len(extension)]) for filename in os.listdir(directory) 47 | if filename.endswith(extension)]) 48 | 49 | assert list(range(len(parts))) == parts, parts 50 | 51 | # 
Integer-sortedness matters. 52 | parts_paths = [os.path.join(directory, '{}{}'.format(filename, extension)) for filename in parts] 53 | 54 | return parts, parts_paths 55 | 56 | 57 | # def load_compression_data(level, path): 58 | # with open(path, "r") as f: 59 | # for line in f: 60 | # line = line.split(',') 61 | # bits = int(line[0]) 62 | 63 | # if bits == level: 64 | # return [float(v) for v in line[1:]] 65 | 66 | # raise ValueError(f"No data found for {level}-bit compression") 67 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/indexing/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import tqdm 4 | 5 | from rankify.utils.retrievers.colbert.colbert.indexing.loaders import load_doclens 6 | from rankify.utils.retrievers.colbert.colbert.utils.utils import print_message, flatten 7 | 8 | def optimize_ivf(orig_ivf, orig_ivf_lengths, index_path, verbose:int=3): 9 | if verbose > 1: 10 | print_message("#> Optimizing IVF to store map from centroids to list of pids..") 11 | 12 | print_message("#> Building the emb2pid mapping..") 13 | all_doclens = load_doclens(index_path, flatten=False) 14 | 15 | # assert self.num_embeddings == sum(flatten(all_doclens)) 16 | 17 | all_doclens = flatten(all_doclens) 18 | total_num_embeddings = sum(all_doclens) 19 | 20 | emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int) 21 | 22 | """ 23 | EVENTUALLY: Use two tensors. emb2pid_offsets will have every 256th element. 
24 | emb2pid_delta will have the delta from the corresponding offset, 25 | """ 26 | 27 | offset_doclens = 0 28 | for pid, dlength in enumerate(all_doclens): 29 | emb2pid[offset_doclens: offset_doclens + dlength] = pid 30 | offset_doclens += dlength 31 | 32 | if verbose > 1: 33 | print_message("len(emb2pid) =", len(emb2pid)) 34 | 35 | ivf = emb2pid[orig_ivf] 36 | unique_pids_per_centroid = [] 37 | ivf_lengths = [] 38 | 39 | offset = 0 40 | for length in tqdm.tqdm(orig_ivf_lengths.tolist()): 41 | pids = torch.unique(ivf[offset:offset+length]) 42 | unique_pids_per_centroid.append(pids) 43 | ivf_lengths.append(pids.shape[0]) 44 | offset += length 45 | ivf = torch.cat(unique_pids_per_centroid) 46 | ivf_lengths = torch.tensor(ivf_lengths) 47 | 48 | max_stride = ivf_lengths.max().item() 49 | zero = torch.zeros(1, dtype=torch.long, device=ivf_lengths.device) 50 | offsets = torch.cat((zero, torch.cumsum(ivf_lengths, dim=0))) 51 | inner_dims = ivf.size()[1:] 52 | 53 | if offsets[-2] + max_stride > ivf.size(0): 54 | padding = torch.zeros(max_stride, *inner_dims, dtype=ivf.dtype, device=ivf.device) 55 | ivf = torch.cat((ivf, padding)) 56 | 57 | original_ivf_path = os.path.join(index_path, 'ivf.pt') 58 | optimized_ivf_path = os.path.join(index_path, 'ivf.pid.pt') 59 | torch.save((ivf, ivf_lengths), optimized_ivf_path) 60 | if verbose > 1: 61 | print_message(f"#> Saved optimized IVF to {optimized_ivf_path}") 62 | if os.path.exists(original_ivf_path): 63 | print_message(f"#> Original IVF at path \"{original_ivf_path}\" can now be removed") 64 | 65 | return ivf, ivf_lengths 66 | 67 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/infra/__init__.py: -------------------------------------------------------------------------------- 1 | from .run import * 2 | from .config import * -------------------------------------------------------------------------------- 
/rankify/utils/retrievers/colbert/colbert/infra/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .settings import * -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/infra/config/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from .base_config import BaseConfig 4 | from .settings import * 5 | 6 | 7 | @dataclass 8 | class RunConfig(BaseConfig, RunSettings): 9 | pass 10 | 11 | 12 | @dataclass 13 | class ColBERTConfig(RunSettings, ResourceSettings, DocSettings, QuerySettings, TrainingSettings, 14 | IndexingSettings, SearchSettings, BaseConfig, TokenizerSettings): 15 | pass 16 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/infra/config/core_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import ujson 4 | import dataclasses 5 | 6 | from typing import Any 7 | from collections import defaultdict 8 | from dataclasses import dataclass, fields 9 | from rankify.utils.retrievers.colbert.colbert.utils.utils import timestamp, torch_load_dnn 10 | 11 | from rankify.utils.retrievers.colbert.utility.utils.save_metadata import get_metadata_only 12 | 13 | 14 | @dataclass 15 | class DefaultVal: 16 | val: Any 17 | 18 | def __hash__(self): 19 | return hash(repr(self.val)) 20 | 21 | def __eq__(self, other): 22 | self.val == other.val 23 | 24 | @dataclass 25 | class CoreConfig: 26 | def __post_init__(self): 27 | """ 28 | Source: https://stackoverflow.com/a/58081120/1493011 29 | """ 30 | 31 | self.assigned = {} 32 | 33 | for field in fields(self): 34 | field_val = getattr(self, field.name) 35 | 36 | if isinstance(field_val, DefaultVal) or field_val is None: 37 | setattr(self, field.name, 
field.default.val) 38 | 39 | if not isinstance(field_val, DefaultVal): 40 | self.assigned[field.name] = True 41 | 42 | def assign_defaults(self): 43 | for field in fields(self): 44 | setattr(self, field.name, field.default.val) 45 | self.assigned[field.name] = True 46 | 47 | def configure(self, ignore_unrecognized=True, **kw_args): 48 | ignored = set() 49 | 50 | for key, value in kw_args.items(): 51 | self.set(key, value, ignore_unrecognized) or ignored.update({key}) 52 | 53 | return ignored 54 | 55 | """ 56 | # TODO: Take a config object, not kw_args. 57 | 58 | for key in config.assigned: 59 | value = getattr(config, key) 60 | """ 61 | 62 | def set(self, key, value, ignore_unrecognized=False): 63 | if hasattr(self, key): 64 | setattr(self, key, value) 65 | self.assigned[key] = True 66 | return True 67 | 68 | if not ignore_unrecognized: 69 | raise Exception(f"Unrecognized key `{key}` for {type(self)}") 70 | 71 | def help(self): 72 | print(ujson.dumps(self.export(), indent=4)) 73 | 74 | def __export_value(self, v): 75 | v = v.provenance() if hasattr(v, 'provenance') else v 76 | 77 | if isinstance(v, list) and len(v) > 100: 78 | v = (f"list with {len(v)} elements starting with...", v[:3]) 79 | 80 | if isinstance(v, dict) and len(v) > 100: 81 | v = (f"dict with {len(v)} keys starting with...", list(v.keys())[:3]) 82 | 83 | return v 84 | 85 | def export(self): 86 | d = dataclasses.asdict(self) 87 | 88 | for k, v in d.items(): 89 | d[k] = self.__export_value(v) 90 | 91 | return d 92 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/infra/provenance.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import traceback 3 | import inspect 4 | 5 | 6 | class Provenance: 7 | def __init__(self) -> None: 8 | self.initial_stacktrace = self.stacktrace() 9 | 10 | def stacktrace(self): 11 | trace = inspect.stack() 12 | output = [] 13 | 14 | for frame in 
trace[2:-1]: 15 | try: 16 | frame = f'{frame.filename}:{frame.lineno}:{frame.function}: {frame.code_context[0].strip()}' 17 | output.append(frame) 18 | except: 19 | output.append(None) 20 | 21 | return output 22 | 23 | def toDict(self): # for ujson 24 | self.serialization_stacktrace = self.stacktrace() 25 | return dict(self.__dict__) 26 | 27 | 28 | """if __name__ == '__main__': 29 | p = Provenance() 30 | print(p.toDict().keys()) 31 | 32 | import ujson 33 | print(ujson.dumps(p, indent=4)) 34 | 35 | 36 | class X: 37 | def __init__(self) -> None: 38 | pass 39 | 40 | def toDict(self): 41 | return {'key': 1} 42 | 43 | print(ujson.dumps(X()))""" -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/infra/utilities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/colbert/infra/utilities/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/infra/utilities/create_triples.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from colbert.utils.utils import print_message 4 | from utility.utils.save_metadata import save_metadata 5 | from utility.supervision.triples import sample_for_query 6 | 7 | from colbert.data.ranking import Ranking 8 | from colbert.data.examples import Examples 9 | 10 | MAX_NUM_TRIPLES = 40_000_000 11 | 12 | 13 | class Triples: 14 | def __init__(self, ranking, seed=12345): 15 | random.seed(seed) # TODO: Use internal RNG instead.. 
16 | self.qid2rankings = Ranking.cast(ranking).todict() 17 | 18 | def create(self, positives, depth): 19 | assert all(len(x) == 2 for x in positives) 20 | assert all(maxBest <= maxDepth for maxBest, maxDepth in positives), positives 21 | 22 | Triples = [] 23 | NonEmptyQIDs = 0 24 | 25 | for processing_idx, qid in enumerate(self.qid2rankings): 26 | l = sample_for_query(qid, self.qid2rankings[qid], positives, depth, False, None) 27 | NonEmptyQIDs += (len(l) > 0) 28 | Triples.extend(l) 29 | 30 | if processing_idx % (10_000) == 0: 31 | print_message(f"#> Done with {processing_idx+1} questions!\t\t " 32 | f"{str(len(Triples) / 1000)}k triples for {NonEmptyQIDs} unqiue QIDs.") 33 | 34 | print_message(f"#> Sub-sample the triples (if > {MAX_NUM_TRIPLES})..") 35 | print_message(f"#> len(Triples) = {len(Triples)}") 36 | 37 | if len(Triples) > MAX_NUM_TRIPLES: 38 | Triples = random.sample(Triples, MAX_NUM_TRIPLES) 39 | 40 | ### Prepare the triples ### 41 | print_message("#> Shuffling the triples...") 42 | random.shuffle(Triples) 43 | 44 | self.Triples = Examples(data=Triples) 45 | 46 | return Triples 47 | 48 | def save(self, new_path): 49 | Examples(data=self.Triples).save(new_path) 50 | 51 | # save_metadata(f'{output}.meta', args) # TODO: What args to save?? 
{seed, positives, depth, rankings if path or else whatever provenance the rankings object shares} 52 | 53 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/infra/utilities/minicorpus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from colbert.utils.utils import create_directory 5 | 6 | from colbert.data import Collection, Queries, Ranking 7 | 8 | 9 | def sample_minicorpus(name, factor, topk=30, maxdev=3000): 10 | """ 11 | Factor: 12 | * nano=1 13 | * micro=10 14 | * mini=100 15 | * small=100 with topk=100 16 | * medium=150 with topk=300 17 | """ 18 | 19 | random.seed(12345) 20 | 21 | # Load collection 22 | collection = Collection(path='/dfs/scratch0/okhattab/OpenQA/collection.tsv') 23 | 24 | # Load train and dev queries 25 | qas_train = Queries(path='/dfs/scratch0/okhattab/OpenQA/NQ/train/qas.json').qas() 26 | qas_dev = Queries(path='/dfs/scratch0/okhattab/OpenQA/NQ/dev/qas.json').qas() 27 | 28 | # Load train and dev C3 rankings 29 | ranking_train = Ranking(path='/dfs/scratch0/okhattab/OpenQA/NQ/train/rankings/C3.tsv.annotated').todict() 30 | ranking_dev = Ranking(path='/dfs/scratch0/okhattab/OpenQA/NQ/dev/rankings/C3.tsv.annotated').todict() 31 | 32 | # Sample NT and ND queries from each, keep only the top-k passages for those 33 | sample_train = random.sample(list(qas_train.keys()), min(len(qas_train.keys()), 300*factor)) 34 | sample_dev = random.sample(list(qas_dev.keys()), min(len(qas_dev.keys()), maxdev, 30*factor)) 35 | 36 | train_pids = [pid for qid in sample_train for qpids in ranking_train[qid][:topk] for pid in qpids] 37 | dev_pids = [pid for qid in sample_dev for qpids in ranking_dev[qid][:topk] for pid in qpids] 38 | 39 | sample_pids = sorted(list(set(train_pids + dev_pids))) 40 | print(f'len(sample_pids) = {len(sample_pids)}') 41 | 42 | # Save the new query sets: train and dev 43 | ROOT = 
f'/future/u/okhattab/root/unit/data/NQ-{name}' 44 | 45 | create_directory(os.path.join(ROOT, 'train')) 46 | create_directory(os.path.join(ROOT, 'dev')) 47 | 48 | new_train = Queries(data={qid: qas_train[qid] for qid in sample_train}) 49 | new_train.save(os.path.join(ROOT, 'train/questions.tsv')) 50 | new_train.save_qas(os.path.join(ROOT, 'train/qas.json')) 51 | 52 | new_dev = Queries(data={qid: qas_dev[qid] for qid in sample_dev}) 53 | new_dev.save(os.path.join(ROOT, 'dev/questions.tsv')) 54 | new_dev.save_qas(os.path.join(ROOT, 'dev/qas.json')) 55 | 56 | # Save the new collection 57 | print(f"Saving to {os.path.join(ROOT, 'collection.tsv')}") 58 | Collection(data=[collection[pid] for pid in sample_pids]).save(os.path.join(ROOT, 'collection.tsv')) 59 | 60 | print('#> Done!') 61 | 62 | 63 | """if __name__ == '__main__': 64 | sample_minicorpus('medium', 150, topk=300)""" 65 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/colbert/modeling/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/modeling/reranker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/colbert/modeling/reranker/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/modeling/reranker/electra.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from transformers import ElectraPreTrainedModel, 
ElectraModel, AutoTokenizer 4 | 5 | class ElectraReranker(ElectraPreTrainedModel): 6 | """ 7 | Shallow wrapper around HuggingFace transformers. All new parameters should be defined at this level. 8 | 9 | This makes sure `{from,save}_pretrained` and `init_weights` are applied to new parameters correctly. 10 | """ 11 | _keys_to_ignore_on_load_unexpected = [r"cls"] 12 | 13 | def __init__(self, config): 14 | super().__init__(config) 15 | 16 | self.electra = ElectraModel(config) 17 | self.linear = nn.Linear(config.hidden_size, 1) 18 | self.raw_tokenizer = AutoTokenizer.from_pretrained('google/electra-large-discriminator') 19 | 20 | self.init_weights() 21 | 22 | def forward(self, encoding): 23 | outputs = self.electra(encoding.input_ids, 24 | attention_mask=encoding.attention_mask, 25 | token_type_ids=encoding.token_type_ids)[0] 26 | 27 | scores = self.linear(outputs[:, 0]).squeeze(-1) 28 | 29 | return scores 30 | 31 | def save(self, path): 32 | assert not path.endswith('.dnn'), f"{path}: We reserve *.dnn names for the deprecated checkpoint format." 
33 | 34 | self.save_pretrained(path) 35 | self.raw_tokenizer.save_pretrained(path) -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/modeling/reranker/tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | 3 | class RerankerTokenizer(): 4 | def __init__(self, total_maxlen, base): 5 | self.total_maxlen = total_maxlen 6 | self.tok = AutoTokenizer.from_pretrained(base) 7 | 8 | def tensorize(self, questions, passages): 9 | assert type(questions) in [list, tuple], type(questions) 10 | assert type(passages) in [list, tuple], type(passages) 11 | 12 | encoding = self.tok(questions, passages, padding='longest', truncation='longest_first', 13 | return_tensors='pt', max_length=self.total_maxlen, add_special_tokens=True) 14 | 15 | return encoding 16 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/modeling/tokenization/__init__.py: -------------------------------------------------------------------------------- 1 | from rankify.utils.retrievers.colbert.colbert.modeling.tokenization.query_tokenization import * 2 | from rankify.utils.retrievers.colbert.colbert.modeling.tokenization.doc_tokenization import * 3 | from rankify.utils.retrievers.colbert.colbert.modeling.tokenization.utils import tensorize_triples 4 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/modeling/tokenization/doc_tokenization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # from transformers import BertTokenizerFast 4 | 5 | from rankify.utils.retrievers.colbert.colbert.modeling.hf_colbert import class_factory 6 | from rankify.utils.retrievers.colbert.colbert.infra import ColBERTConfig 7 | from 
rankify.utils.retrievers.colbert.colbert.modeling.tokenization.utils import _split_into_batches, _sort_by_length, _insert_prefix_token 8 | from rankify.utils.retrievers.colbert.colbert.parameters import DEVICE 9 | 10 | class DocTokenizer(): 11 | def __init__(self, config: ColBERTConfig): 12 | HF_ColBERT = class_factory(config.checkpoint) 13 | self.tok = HF_ColBERT.raw_tokenizer_from_pretrained(config.checkpoint) 14 | 15 | self.config = config 16 | self.doc_maxlen = config.doc_maxlen 17 | 18 | self.D_marker_token, self.D_marker_token_id = self.config.doc_token, self.tok.convert_tokens_to_ids(self.config.doc_token_id) 19 | self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id 20 | self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id 21 | 22 | def tokenize(self, batch_text, add_special_tokens=False): 23 | assert type(batch_text) in [list, tuple], (type(batch_text)) 24 | 25 | tokens = [self.tok.tokenize(x, add_special_tokens=False).to(DEVICE) for x in batch_text] 26 | 27 | if not add_special_tokens: 28 | return tokens 29 | 30 | prefix, suffix = [self.cls_token, self.D_marker_token], [self.sep_token] 31 | tokens = [prefix + lst + suffix for lst in tokens] 32 | 33 | return tokens 34 | 35 | def encode(self, batch_text, add_special_tokens=False): 36 | assert type(batch_text) in [list, tuple], (type(batch_text)) 37 | 38 | ids = self.tok(batch_text, add_special_tokens=False).to(DEVICE)['input_ids'] 39 | 40 | if not add_special_tokens: 41 | return ids 42 | 43 | prefix, suffix = [self.cls_token_id, self.D_marker_token_id], [self.sep_token_id] 44 | ids = [prefix + lst + suffix for lst in ids] 45 | 46 | return ids 47 | 48 | def tensorize(self, batch_text, bsize=None): 49 | assert type(batch_text) in [list, tuple], (type(batch_text)) 50 | 51 | obj = self.tok(batch_text, padding='longest', truncation='longest_first', 52 | return_tensors='pt', max_length=(self.doc_maxlen - 1)).to(DEVICE) 53 | 54 | ids = 
_insert_prefix_token(obj['input_ids'], self.D_marker_token_id) 55 | mask = _insert_prefix_token(obj['attention_mask'], 1) 56 | 57 | if bsize: 58 | ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize) 59 | batches = _split_into_batches(ids, mask, bsize) 60 | return batches, reverse_indices 61 | 62 | return ids, mask 63 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/modeling/tokenization/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def tensorize_triples(query_tokenizer, doc_tokenizer, queries, passages, scores, bsize, nway): 5 | # assert len(passages) == len(scores) == bsize * nway 6 | # assert bsize is None or len(queries) % bsize == 0 7 | 8 | # N = len(queries) 9 | Q_ids, Q_mask = query_tokenizer.tensorize(queries) 10 | D_ids, D_mask = doc_tokenizer.tensorize(passages) 11 | # D_ids, D_mask = D_ids.view(2, N, -1), D_mask.view(2, N, -1) 12 | 13 | # # Compute max among {length of i^th positive, length of i^th negative} for i \in N 14 | # maxlens = D_mask.sum(-1).max(0).values 15 | 16 | # # Sort by maxlens 17 | # indices = maxlens.sort().indices 18 | # Q_ids, Q_mask = Q_ids[indices], Q_mask[indices] 19 | # D_ids, D_mask = D_ids[:, indices], D_mask[:, indices] 20 | 21 | # (positive_ids, negative_ids), (positive_mask, negative_mask) = D_ids, D_mask 22 | 23 | query_batches = _split_into_batches(Q_ids, Q_mask, bsize) 24 | doc_batches = _split_into_batches(D_ids, D_mask, bsize * nway) 25 | # positive_batches = _split_into_batches(positive_ids, positive_mask, bsize) 26 | # negative_batches = _split_into_batches(negative_ids, negative_mask, bsize) 27 | 28 | if len(scores): 29 | score_batches = _split_into_batches2(scores, bsize * nway) 30 | else: 31 | score_batches = [[] for _ in doc_batches] 32 | 33 | batches = [] 34 | for Q, D, S in zip(query_batches, doc_batches, score_batches): 35 | batches.append((Q, D, S)) 36 
| 37 | return batches 38 | 39 | 40 | def _sort_by_length(ids, mask, bsize): 41 | if ids.size(0) <= bsize: 42 | return ids, mask, torch.arange(ids.size(0)) 43 | 44 | indices = mask.sum(-1).sort().indices 45 | reverse_indices = indices.sort().indices 46 | 47 | return ids[indices], mask[indices], reverse_indices 48 | 49 | 50 | def _split_into_batches(ids, mask, bsize): 51 | batches = [] 52 | for offset in range(0, ids.size(0), bsize): 53 | batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize])) 54 | 55 | return batches 56 | 57 | 58 | def _split_into_batches2(scores, bsize): 59 | batches = [] 60 | for offset in range(0, len(scores), bsize): 61 | batches.append(scores[offset:offset+bsize]) 62 | 63 | return batches 64 | 65 | def _insert_prefix_token(tensor: torch.Tensor, prefix_id: int): 66 | prefix_tensor = torch.full( 67 | (tensor.size(0), 1), 68 | prefix_id, 69 | dtype=tensor.dtype, 70 | device=tensor.device, 71 | ) 72 | return torch.cat([tensor[:, :1], prefix_tensor, tensor[:, 1:]], dim=1) 73 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/parameters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 4 | 5 | SAVED_CHECKPOINTS = [32*1000, 100*1000, 150*1000, 200*1000, 250*1000, 300*1000, 400*1000] 6 | SAVED_CHECKPOINTS += [10*1000, 20*1000, 30*1000, 40*1000, 50*1000, 60*1000, 70*1000, 80*1000, 90*1000] 7 | SAVED_CHECKPOINTS += [25*1000, 50*1000, 75*1000] 8 | 9 | SAVED_CHECKPOINTS = set(SAVED_CHECKPOINTS) 10 | 11 | 12 | # TODO: final_ckpt 2k, 5k, 10k 20k, 50k, 100k 150k 200k, 500k, 1M 2M, 5M, 10M -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/ranking/__init__.py: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/colbert/ranking/__init__.py

# ===== /rankify/utils/retrievers/colbert/colbert/search/__init__.py (empty) =====
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/colbert/search/__init__.py

# ===== /rankify/utils/retrievers/colbert/colbert/search/candidate_generation.py =====
import torch

from rankify.utils.retrievers.colbert.colbert.search.strided_tensor import StridedTensor
from .strided_tensor_core import _create_mask, _create_view


class CandidateGeneration:
    """Centroid-based candidate generation for ColBERT search.

    NOTE(review): ``self.codec``, ``self.ivf`` and ``self.lookup_eids`` are not
    defined here; they are presumably provided by a subclass or mixin.
    """

    def __init__(self, use_gpu=True):
        self.use_gpu = use_gpu

    def get_cells(self, Q, ncells):
        """Return the unique top-``ncells`` centroid ids per query vector, plus the full centroid score matrix."""
        scores = (self.codec.centroids @ Q.T)
        if ncells == 1:
            cells = scores.argmax(dim=0, keepdim=True).permute(1, 0)
        else:
            cells = scores.topk(ncells, dim=0, sorted=False).indices.permute(1, 0)  # (32, ncells)
        cells = cells.flatten().contiguous()  # (32 * ncells,)
        cells = cells.unique(sorted=False)
        return cells, scores

    def generate_candidate_eids(self, Q, ncells):
        """Look up the embedding ids stored in the selected cells (as int64), plus centroid scores."""
        cells, scores = self.get_cells(Q, ncells)

        eids, _cell_lengths = self.ivf.lookup(cells)  # eids = (packedlen,)  lengths = (32 * ncells,)
        eids = eids.long()
        if self.use_gpu:
            eids = eids.cuda()
        return eids, scores

    def generate_candidate_pids(self, Q, ncells):
        """Look up the passage ids stored in the selected cells, plus centroid scores."""
        cells, scores = self.get_cells(Q, ncells)

        pids, _cell_lengths = self.ivf.lookup(cells)
        if self.use_gpu:
            pids = pids.cuda()
        return pids, scores

    def generate_candidate_scores(self, Q, eids):
        """Score the embeddings for ``eids`` against every query vector in ``Q``."""
        E = self.lookup_eids(eids)
        if self.use_gpu:
            E = E.cuda()
        return (Q.unsqueeze(0) @ E.unsqueeze(2)).squeeze(-1).T

    def generate_candidates(self, config, Q):
        """Return the sorted, de-duplicated candidate pids and the centroid scores for query ``Q``."""
        ncells = config.ncells

        assert isinstance(self.ivf, StridedTensor)

        Q = Q.squeeze(0)
        if self.use_gpu:
            Q = Q.cuda().half()
        assert Q.dim() == 2

        pids, centroid_scores = self.generate_candidate_pids(Q, ncells)

        # Sorting first lets unique_consecutive de-duplicate. Both ops keep the
        # device, so the result is already on the GPU when use_gpu is set.
        pids = torch.unique_consecutive(pids.sort().values)

        return pids, centroid_scores

# ===== /rankify/utils/retrievers/colbert/colbert/tests/__init__.py (empty) =====
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/colbert/tests/__init__.py

# ===== /rankify/utils/retrievers/colbert/colbert/trainer.py =====
from colbert.infra.run import Run
from colbert.infra.launcher import Launcher
from colbert.infra.config import ColBERTConfig, RunConfig

from colbert.training.training import train


class Trainer:
    """Launch ColBERT training on the given triples, queries and collection."""

    def __init__(self, triples, queries, collection, config=None):
        self.config = ColBERTConfig.from_existing(config, Run().config)

        self.triples = triples
        self.queries = queries
        self.collection = collection

        # Filled in by train(); None until training has been launched.
        self._best_checkpoint_path = None

    def configure(self, **kw_args):
        """Forward keyword settings to the underlying config's ``configure``."""
        self.config.configure(**kw_args)

    def train(self, checkpoint='bert-base-uncased'):
        """
        Note that config.checkpoint is ignored. Only the supplied checkpoint here is used.

        Returns the value produced by the launcher (also available afterwards
        through ``best_checkpoint_path()``).
        """

        # Resources don't come from the config object. They come from the input parameters.
        # TODO: After the API stabilizes, make this "self.config.assign()" to emphasize this distinction.
        self.configure(triples=self.triples, queries=self.queries, collection=self.collection)
        self.configure(checkpoint=checkpoint)

        launcher = Launcher(train)

        self._best_checkpoint_path = launcher.launch(self.config, self.triples, self.queries, self.collection)
        return self._best_checkpoint_path

    def best_checkpoint_path(self):
        """Return the path recorded by the last ``train()`` call, or None if training has not run."""
        return self._best_checkpoint_path

# ===== /rankify/utils/retrievers/colbert/colbert/training/__init__.py (empty) =====
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/colbert/training/__init__.py

# ===== /rankify/utils/retrievers/colbert/colbert/training/eager_batcher.py =====
import os
import ujson

from functools import partial
from colbert.utils.utils import print_message
from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer, tensorize_triples

from colbert.utils.runs import Run


class EagerBatcher():
    """Stream tab-separated (query, positive, negative) triples in batches of ``bsize``.

    Lines are dealt round-robin across ``nranks`` workers: this instance keeps
    only lines whose global index modulo ``nranks`` equals ``rank``.
    """

    def __init__(self, args, rank=0, nranks=1):
        self.rank, self.nranks = rank, nranks
        self.bsize, self.accumsteps = args.bsize, args.accumsteps

        # NOTE(review): these constructors and the tensorize_triples call in
        # collate() follow an older tokenizer API (int maxlens, three text lists);
        # confirm against the installed colbert package before relying on them.
        self.query_tokenizer = QueryTokenizer(args.query_maxlen)
        self.doc_tokenizer = DocTokenizer(args.doc_maxlen)
        self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer)

        self.triples_path = args.triples
        self._reset_triples()

    def _reset_triples(self):
        """(Re)open the triples file from the start and reset the line counter."""
        # Close any previously opened handle so repeated resets do not leak descriptors.
        previous = getattr(self, 'reader', None)
        if previous is not None:
            previous.close()

        self.reader = open(self.triples_path, mode='r', encoding="utf-8")
        self.position = 0

    def __iter__(self):
        return self

    def __next__(self):
        """Read up to ``bsize * nranks`` lines and return this rank's collated batch.

        Raises StopIteration when fewer than ``bsize`` triples remain for this rank.
        """
        queries, positives, negatives = [], [], []
        consumed = 0  # lines read from the file in this call, across all ranks

        for line_idx, line in zip(range(self.bsize * self.nranks), self.reader):
            consumed = line_idx + 1

            if (self.position + line_idx) % self.nranks != self.rank:
                continue

            query, pos, neg = line.strip().split('\t')

            queries.append(query)
            positives.append(pos)
            negatives.append(neg)

        # Use the counter rather than `line_idx`, which is unbound when the file is exhausted.
        self.position += consumed

        if len(queries) < self.bsize:
            raise StopIteration

        return self.collate(queries, positives, negatives)

    def collate(self, queries, positives, negatives):
        """Tokenize a full batch, split for ``accumsteps`` gradient-accumulation steps."""
        assert len(queries) == len(positives) == len(negatives) == self.bsize

        return self.tensorize_triples(queries, positives, negatives, self.bsize // self.accumsteps)

    def skip_to_batch(self, batch_idx, intended_batch_size):
        """Restart the file and discard ``batch_idx * intended_batch_size`` lines."""
        self._reset_triples()

        Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.')

        for _ in range(batch_idx * intended_batch_size):
            self.reader.readline()

        return None

# ===== /rankify/utils/retrievers/colbert/colbert/training/utils.py =====
import os
import torch

# from colbert.utils.runs import Run
from colbert.utils.utils import print_message, save_checkpoint
from colbert.parameters import SAVED_CHECKPOINTS
from colbert.infra.run import Run


def print_progress(scores):
    """Print the mean of column 0, the mean of column 1 (both rounded to 2 places) and their difference.

    NOTE(review): columns 0/1 presumably hold positive/negative scores -- confirm with callers.
    """
    positive_avg, negative_avg = round(scores[:, 0].mean().item(), 2), round(scores[:, 1].mean().item(), 2)
    print("#>>> ", positive_avg, negative_avg, '\t\t|\t\t', positive_avg - negative_avg)


def manage_checkpoints(args, colbert, optimizer, batch_idx, savepath=None, consumed_all_triples=False):
    """Save ``colbert`` when ``batch_idx`` hits a save point; return the path used, or None.

    A rolling "colbert" checkpoint is written every 2000 batches or when all
    triples are consumed. Batches listed in SAVED_CHECKPOINTS get their own
    "colbert-<batch_idx>" directory instead. ``args`` and ``optimizer`` are
    currently unused.
    """
    # TODO: Call provenance() on the values that support it??

    checkpoints_path = savepath or os.path.join(Run().path_, 'checkpoints')

    # Wrapped models (e.g. DistributedDataParallel) expose the real model as `.module`.
    try:
        save = colbert.save
    except AttributeError:
        save = colbert.module.save

    os.makedirs(checkpoints_path, exist_ok=True)

    path_save = None

    if consumed_all_triples or (batch_idx % 2000 == 0):
        path_save = os.path.join(checkpoints_path, "colbert")

    # Milestones take precedence over the rolling checkpoint name.
    if batch_idx in SAVED_CHECKPOINTS:
        path_save = os.path.join(checkpoints_path, f"colbert-{batch_idx}")

    if path_save:
        print(f"#> Saving a checkpoint to {path_save} ..")
        save(path_save)

    return path_save

# ===== /rankify/utils/retrievers/colbert/colbert/utilities/__init__.py (empty) =====
# https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/colbert/utilities/__init__.py

# ===== /rankify/utils/retrievers/colbert/colbert/utilities/create_triples.py =====
-------------------------------------------------------------------------------- 1 | import random 2 | from colbert.infra.provenance import Provenance 3 | 4 | from utility.utils.save_metadata import save_metadata 5 | from utility.supervision.triples import sample_for_query 6 | 7 | from colbert.utils.utils import print_message 8 | 9 | from colbert.data.ranking import Ranking 10 | from colbert.data.examples import Examples 11 | 12 | MAX_NUM_TRIPLES = 40_000_000 13 | 14 | 15 | class Triples: 16 | def __init__(self, ranking, seed=12345): 17 | random.seed(seed) # TODO: Use internal RNG instead.. 18 | self.seed = seed 19 | 20 | ranking = Ranking.cast(ranking) 21 | self.ranking_provenance = ranking.provenance() 22 | self.qid2rankings = ranking.todict() 23 | 24 | def create(self, positives, depth): 25 | assert all(len(x) == 2 for x in positives) 26 | assert all(maxBest <= maxDepth for maxBest, maxDepth in positives), positives 27 | 28 | self.positives = positives 29 | self.depth = depth 30 | 31 | Triples = [] 32 | NonEmptyQIDs = 0 33 | 34 | for processing_idx, qid in enumerate(self.qid2rankings): 35 | l = sample_for_query(qid, self.qid2rankings[qid], positives, depth, False, None) 36 | NonEmptyQIDs += (len(l) > 0) 37 | Triples.extend(l) 38 | 39 | if processing_idx % (10_000) == 0: 40 | print_message(f"#> Done with {processing_idx+1} questions!\t\t " 41 | f"{str(len(Triples) / 1000)}k triples for {NonEmptyQIDs} unqiue QIDs.") 42 | 43 | print_message(f"#> Sub-sample the triples (if > {MAX_NUM_TRIPLES})..") 44 | print_message(f"#> len(Triples) = {len(Triples)}") 45 | 46 | if len(Triples) > MAX_NUM_TRIPLES: 47 | Triples = random.sample(Triples, MAX_NUM_TRIPLES) 48 | 49 | ### Prepare the triples ### 50 | print_message("#> Shuffling the triples...") 51 | random.shuffle(Triples) 52 | 53 | self.Triples = Examples(data=Triples) 54 | 55 | return Triples 56 | 57 | def save(self, new_path): 58 | provenance = Provenance() 59 | provenance.source = 'Triples::create' 60 | provenance.seed 
= self.seed 61 | provenance.positives = self.positives 62 | provenance.depth = self.depth 63 | provenance.ranking = self.ranking_provenance 64 | 65 | Examples(data=self.Triples, provenance=provenance).save(new_path) 66 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/utilities/minicorpus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from colbert.utils.utils import create_directory 5 | 6 | from colbert.data.collection import Collection 7 | from colbert.data.queries import Queries 8 | from colbert.data.ranking import Ranking 9 | 10 | 11 | def sample_minicorpus(name, factor, topk=30, maxdev=3000): 12 | """ 13 | Factor: 14 | * nano=1 15 | * micro=10 16 | * mini=100 17 | * small=100 with topk=100 18 | * medium=150 with topk=300 19 | """ 20 | 21 | random.seed(12345) 22 | 23 | # Load collection 24 | collection = Collection(path='/dfs/scratch0/okhattab/OpenQA/collection.tsv') 25 | 26 | # Load train and dev queries 27 | qas_train = Queries(path='/dfs/scratch0/okhattab/OpenQA/NQ/train/qas.json').qas() 28 | qas_dev = Queries(path='/dfs/scratch0/okhattab/OpenQA/NQ/dev/qas.json').qas() 29 | 30 | # Load train and dev C3 rankings 31 | ranking_train = Ranking(path='/dfs/scratch0/okhattab/OpenQA/NQ/train/rankings/C3.tsv.annotated').todict() 32 | ranking_dev = Ranking(path='/dfs/scratch0/okhattab/OpenQA/NQ/dev/rankings/C3.tsv.annotated').todict() 33 | 34 | # Sample NT and ND queries from each, keep only the top-k passages for those 35 | sample_train = random.sample(list(qas_train.keys()), min(len(qas_train.keys()), 300*factor)) 36 | sample_dev = random.sample(list(qas_dev.keys()), min(len(qas_dev.keys()), maxdev, 30*factor)) 37 | 38 | train_pids = [pid for qid in sample_train for qpids in ranking_train[qid][:topk] for pid in qpids] 39 | dev_pids = [pid for qid in sample_dev for qpids in ranking_dev[qid][:topk] for pid in qpids] 40 
| 41 | sample_pids = sorted(list(set(train_pids + dev_pids))) 42 | print(f'len(sample_pids) = {len(sample_pids)}') 43 | 44 | # Save the new query sets: train and dev 45 | ROOT = f'/future/u/okhattab/root/unit/data/NQ-{name}' 46 | 47 | create_directory(os.path.join(ROOT, 'train')) 48 | create_directory(os.path.join(ROOT, 'dev')) 49 | 50 | new_train = Queries(data={qid: qas_train[qid] for qid in sample_train}) 51 | new_train.save(os.path.join(ROOT, 'train/questions.tsv')) 52 | new_train.save_qas(os.path.join(ROOT, 'train/qas.json')) 53 | 54 | new_dev = Queries(data={qid: qas_dev[qid] for qid in sample_dev}) 55 | new_dev.save(os.path.join(ROOT, 'dev/questions.tsv')) 56 | new_dev.save_qas(os.path.join(ROOT, 'dev/qas.json')) 57 | 58 | # Save the new collection 59 | print(f"Saving to {os.path.join(ROOT, 'collection.tsv')}") 60 | Collection(data=[collection[pid] for pid in sample_pids]).save(os.path.join(ROOT, 'collection.tsv')) 61 | 62 | print('#> Done!') 63 | 64 | """ 65 | if __name__ == '__main__': 66 | sample_minicorpus('medium', 150, topk=300) 67 | """ -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/colbert/utils/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/utils/amp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from contextlib import contextmanager 4 | from rankify.utils.retrievers.colbert.colbert.utils.utils import NullContextManager 5 | 6 | 7 | class MixedPrecisionManager(): 8 | def __init__(self, activated): 9 | self.activated = activated 10 | 11 | if self.activated: 12 | self.scaler = 
torch.cuda.amp.GradScaler() 13 | 14 | def context(self): 15 | return torch.cuda.amp.autocast() if self.activated else NullContextManager() 16 | 17 | def backward(self, loss): 18 | if self.activated: 19 | self.scaler.scale(loss).backward() 20 | else: 21 | loss.backward() 22 | 23 | def step(self, colbert, optimizer, scheduler=None): 24 | if self.activated: 25 | self.scaler.unscale_(optimizer) 26 | torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0, error_if_nonfinite=False) 27 | 28 | self.scaler.step(optimizer) 29 | self.scaler.update() 30 | else: 31 | torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0) 32 | optimizer.step() 33 | 34 | if scheduler is not None: 35 | scheduler.step() 36 | 37 | optimizer.zero_grad() 38 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/colbert/utils/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import numpy as np 5 | 6 | ALREADY_INITALIZED = False 7 | 8 | # TODO: Consider torch.distributed.is_initialized() instead 9 | 10 | 11 | def init(rank): 12 | nranks = 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) 13 | nranks = max(1, nranks) 14 | is_distributed = (nranks > 1) or ('WORLD_SIZE' in os.environ) 15 | 16 | global ALREADY_INITALIZED 17 | if ALREADY_INITALIZED: 18 | return nranks, is_distributed 19 | 20 | ALREADY_INITALIZED = True 21 | 22 | if is_distributed and torch.cuda.is_available(): 23 | num_gpus = torch.cuda.device_count() 24 | print(f'nranks = {nranks} \t num_gpus = {num_gpus} \t device={rank % num_gpus}') 25 | 26 | torch.cuda.set_device(rank % num_gpus) 27 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 28 | 29 | return nranks, is_distributed 30 | 31 | 32 | def barrier(rank): 33 | nranks = 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) 34 | nranks = max(1, nranks) 35 | 36 | if rank >= 0 and 
nranks > 1: 37 | torch.distributed.barrier(device_ids=[rank % torch.cuda.device_count()]) 38 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/utility/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/evaluate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/utility/evaluate/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/evaluate/annotate_EM_helpers.py: -------------------------------------------------------------------------------- 1 | from colbert.utils.utils import print_message 2 | from utility.utils.dpr import DPR_normalize, has_answer 3 | 4 | 5 | def tokenize_all_answers(args): 6 | qid, question, answers = args 7 | return qid, question, [DPR_normalize(ans) for ans in answers] 8 | 9 | 10 | def assign_label_to_passage(args): 11 | idx, (qid, pid, rank, passage, tokenized_answers) = args 12 | 13 | if idx % (1*1000*1000) == 0: 14 | print(idx) 15 | 16 | return qid, pid, rank, has_answer(tokenized_answers, passage) 17 | 18 | 19 | def check_sizes(qid2answers, qid2rankings): 20 | num_judged_queries = len(qid2answers) 21 | num_ranked_queries = len(qid2rankings) 22 | 23 | print_message('num_judged_queries =', num_judged_queries) 24 | print_message('num_ranked_queries =', num_ranked_queries) 25 | 26 | if num_judged_queries != num_ranked_queries: 27 | assert num_ranked_queries <= num_judged_queries 28 | 29 | print('\n\n') 30 | 
print_message('[WARNING] num_judged_queries != num_ranked_queries') 31 | print('\n\n') 32 | 33 | return num_judged_queries, num_ranked_queries 34 | 35 | 36 | def compute_and_write_labels(output_path, qid2answers, qid2rankings): 37 | cutoffs = [1, 5, 10, 20, 30, 50, 100, 1000, 'all'] 38 | success = {cutoff: 0.0 for cutoff in cutoffs} 39 | counts = {cutoff: 0.0 for cutoff in cutoffs} 40 | 41 | with open(output_path, 'w') as f: 42 | for qid in qid2answers: 43 | if qid not in qid2rankings: 44 | continue 45 | 46 | prev_rank = 0 # ranks should start at one (i.e., and not zero) 47 | labels = [] 48 | 49 | for pid, rank, label in qid2rankings[qid]: 50 | assert rank == prev_rank+1, (qid, pid, (prev_rank, rank)) 51 | prev_rank = rank 52 | 53 | labels.append(label) 54 | line = '\t'.join(map(str, [qid, pid, rank, int(label)])) + '\n' 55 | f.write(line) 56 | 57 | for cutoff in cutoffs: 58 | if cutoff != 'all': 59 | success[cutoff] += sum(labels[:cutoff]) > 0 60 | counts[cutoff] += sum(labels[:cutoff]) 61 | else: 62 | success[cutoff] += sum(labels) > 0 63 | counts[cutoff] += sum(labels) 64 | 65 | return success, counts 66 | 67 | 68 | # def dump_metrics(f, nqueries, cutoffs, success, counts): 69 | # for cutoff in cutoffs: 70 | # success_log = "#> P@{} = {}".format(cutoff, success[cutoff] / nqueries) 71 | # counts_log = "#> D@{} = {}".format(cutoff, counts[cutoff] / nqueries) 72 | # print('\n'.join([success_log, counts_log]) + '\n') 73 | 74 | # f.write('\n'.join([success_log, counts_log]) + '\n\n') 75 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/evaluate/evaluate_lotte_rankings.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | import jsonlines 4 | import os 5 | import sys 6 | 7 | 8 | def evaluate_dataset(query_type, dataset, split, k, data_rootdir, rankings_rootdir): 9 | data_path = 
os.path.join(data_rootdir, dataset, split) 10 | rankings_path = os.path.join( 11 | rankings_rootdir, split, f"{dataset}.{query_type}.ranking.tsv" 12 | ) 13 | if not os.path.exists(rankings_path): 14 | print(f"[query_type={query_type}, dataset={dataset}] Success@{k}: ???") 15 | return 16 | rankings = defaultdict(list) 17 | with open(rankings_path, "r") as f: 18 | for line in f: 19 | items = line.strip().split("\t") 20 | qid, pid, rank = items[:3] 21 | qid = int(qid) 22 | pid = int(pid) 23 | rank = int(rank) 24 | rankings[qid].append(pid) 25 | assert rank == len(rankings[qid]) 26 | 27 | success = 0 28 | qas_path = os.path.join(data_path, f"qas.{query_type}.jsonl") 29 | 30 | num_total_qids = 0 31 | with jsonlines.open(qas_path, mode="r") as f: 32 | for line in f: 33 | qid = int(line["qid"]) 34 | num_total_qids += 1 35 | if qid not in rankings: 36 | print(f"WARNING: qid {qid} not found in {rankings_path}!", file=sys.stderr) 37 | continue 38 | answer_pids = set(line["answer_pids"]) 39 | if len(set(rankings[qid][:k]).intersection(answer_pids)) > 0: 40 | success += 1 41 | print( 42 | f"[query_type={query_type}, dataset={dataset}] " 43 | f"Success@{k}: {success / num_total_qids * 100:.1f}" 44 | ) 45 | 46 | 47 | def main(args): 48 | for query_type in ["search", "forum"]: 49 | for dataset in [ 50 | "writing", 51 | "recreation", 52 | "science", 53 | "technology", 54 | "lifestyle", 55 | "pooled", 56 | ]: 57 | evaluate_dataset( 58 | query_type, 59 | dataset, 60 | args.split, 61 | args.k, 62 | args.data_dir, 63 | args.rankings_dir, 64 | ) 65 | print() 66 | 67 | 68 | if __name__ == "__main__": 69 | parser = argparse.ArgumentParser(description="LoTTE evaluation script") 70 | parser.add_argument("--k", type=int, default=5, help="Success@k") 71 | parser.add_argument( 72 | "-s", "--split", choices=["dev", "test"], required=True, help="Split" 73 | ) 74 | parser.add_argument( 75 | "-d", "--data_dir", type=str, required=True, help="Path to LoTTE data directory" 76 | ) 77 | 
parser.add_argument( 78 | "-r", 79 | "--rankings_dir", 80 | type=str, 81 | required=True, 82 | help="Path to LoTTE rankings directory", 83 | ) 84 | args = parser.parse_args() 85 | main(args) 86 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/utility/preprocess/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/preprocess/queries_split.py: -------------------------------------------------------------------------------- 1 | """ 2 | Divide a query set into two. 3 | """ 4 | 5 | import os 6 | import math 7 | import ujson 8 | import random 9 | 10 | from argparse import ArgumentParser 11 | from collections import OrderedDict 12 | from colbert.utils.utils import print_message 13 | 14 | 15 | def main(args): 16 | random.seed(12345) 17 | 18 | """ 19 | Load the queries 20 | """ 21 | Queries = OrderedDict() 22 | 23 | print_message(f"#> Loading queries from {args.input}..") 24 | with open(args.input) as f: 25 | for line in f: 26 | qid, query = line.strip().split('\t') 27 | 28 | assert qid not in Queries 29 | Queries[qid] = query 30 | 31 | """ 32 | Apply the splitting 33 | """ 34 | size_a = len(Queries) - args.holdout 35 | size_b = args.holdout 36 | size_a, size_b = max(size_a, size_b), min(size_a, size_b) 37 | 38 | assert size_a > 0 and size_b > 0, (len(Queries), size_a, size_b) 39 | 40 | print_message(f"#> Deterministically splitting the queries into ({size_a}, {size_b})-sized splits.") 41 | 42 | keys = list(Queries.keys()) 43 | sample_b_indices = sorted(list(random.sample(range(len(keys)), size_b))) 44 | sample_a_indices = sorted(list(set.difference(set(list(range(len(keys)))), 
set(sample_b_indices)))) 45 | 46 | assert len(sample_a_indices) == size_a 47 | assert len(sample_b_indices) == size_b 48 | 49 | sample_a = [keys[idx] for idx in sample_a_indices] 50 | sample_b = [keys[idx] for idx in sample_b_indices] 51 | 52 | """ 53 | Write the output 54 | """ 55 | 56 | output_path_a = f'{args.input}.a' 57 | output_path_b = f'{args.input}.b' 58 | 59 | assert not os.path.exists(output_path_a), output_path_a 60 | assert not os.path.exists(output_path_b), output_path_b 61 | 62 | print_message(f"#> Writing the splits out to {output_path_a} and {output_path_b} ...") 63 | 64 | for output_path, sample in [(output_path_a, sample_a), (output_path_b, sample_b)]: 65 | with open(output_path, 'w') as f: 66 | for qid in sample: 67 | query = Queries[qid] 68 | line = '\t'.join([qid, query]) + '\n' 69 | f.write(line) 70 | 71 | 72 | if __name__ == "__main__": 73 | parser = ArgumentParser(description="queries_split.") 74 | 75 | # Input Arguments. 76 | parser.add_argument('--input', dest='input', required=True) 77 | parser.add_argument('--holdout', dest='holdout', required=True, type=int) 78 | 79 | args = parser.parse_args() 80 | 81 | main(args) 82 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/rankings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/utility/rankings/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/rankings/dev_subsample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ujson 3 | import random 4 | 5 | from argparse import ArgumentParser 6 | 7 | from colbert.utils.utils import print_message, create_directory, load_ranking, groupby_first_item 8 | from 
utility.utils.qa_loaders import load_qas_ 9 | 10 | 11 | def main(args): 12 | print_message("#> Loading all..") 13 | qas = load_qas_(args.qas) 14 | rankings = load_ranking(args.ranking) 15 | qid2rankings = groupby_first_item(rankings) 16 | 17 | print_message("#> Subsampling all..") 18 | qas_sample = random.sample(qas, args.sample) 19 | 20 | with open(args.output, 'w') as f: 21 | for qid, *_ in qas_sample: 22 | for items in qid2rankings[qid]: 23 | items = [qid] + items 24 | line = '\t'.join(map(str, items)) + '\n' 25 | f.write(line) 26 | 27 | print('\n\n') 28 | print(args.output) 29 | print("#> Done.") 30 | 31 | 32 | if __name__ == "__main__": 33 | random.seed(12345) 34 | 35 | parser = ArgumentParser(description='Subsample the dev set.') 36 | parser.add_argument('--qas', dest='qas', required=True, type=str) 37 | parser.add_argument('--ranking', dest='ranking', required=True) 38 | parser.add_argument('--output', dest='output', required=True) 39 | 40 | parser.add_argument('--sample', dest='sample', default=1500, type=int) 41 | 42 | args = parser.parse_args() 43 | 44 | assert not os.path.exists(args.output), args.output 45 | create_directory(os.path.dirname(args.output)) 46 | 47 | main(args) 48 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/rankings/merge.py: -------------------------------------------------------------------------------- 1 | """ 2 | Divide two or more ranking files, by score. 
3 | """ 4 | 5 | import os 6 | import tqdm 7 | 8 | from argparse import ArgumentParser 9 | from collections import defaultdict 10 | from colbert.utils.utils import print_message, file_tqdm 11 | 12 | 13 | def main(args): 14 | Rankings = defaultdict(list) 15 | 16 | for path in args.input: 17 | print_message(f"#> Loading the rankings in {path} ..") 18 | 19 | with open(path) as f: 20 | for line in file_tqdm(f): 21 | qid, pid, rank, score = line.strip().split('\t') 22 | qid, pid, rank = map(int, [qid, pid, rank]) 23 | score = float(score) 24 | 25 | Rankings[qid].append((score, rank, pid)) 26 | 27 | with open(args.output, 'w') as f: 28 | print_message(f"#> Writing the output rankings to {args.output} ..") 29 | 30 | for qid in tqdm.tqdm(Rankings): 31 | ranking = sorted(Rankings[qid], reverse=True) 32 | 33 | for rank, (score, original_rank, pid) in enumerate(ranking): 34 | rank = rank + 1 # 1-indexed 35 | 36 | if (args.depth > 0) and (rank > args.depth): 37 | break 38 | 39 | line = [qid, pid, rank, score] 40 | line = '\t'.join(map(str, line)) + '\n' 41 | f.write(line) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = ArgumentParser(description="merge_rankings.") 46 | 47 | # Input Arguments. 48 | parser.add_argument('--input', dest='input', required=True, nargs='+') 49 | parser.add_argument('--output', dest='output', required=True, type=str) 50 | 51 | parser.add_argument('--depth', dest='depth', required=True, type=int) 52 | 53 | args = parser.parse_args() 54 | 55 | assert not os.path.exists(args.output), args.output 56 | 57 | main(args) 58 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/rankings/split_by_offset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Split the ranked lists after retrieval with a merged query set. 
3 | """ 4 | 5 | import os 6 | import random 7 | 8 | from argparse import ArgumentParser 9 | 10 | 11 | def main(args): 12 | output_paths = ['{}.{}'.format(args.ranking, split) for split in args.names] 13 | assert all(not os.path.exists(path) for path in output_paths), output_paths 14 | 15 | output_files = [open(path, 'w') for path in output_paths] 16 | 17 | with open(args.ranking) as f: 18 | for line in f: 19 | qid, pid, rank, *other = line.strip().split('\t') 20 | qid = int(qid) 21 | split_output_path = output_files[qid // args.gap - 1] 22 | qid = qid % args.gap 23 | 24 | split_output_path.write('\t'.join([str(x) for x in [qid, pid, rank, *other]]) + '\n') 25 | 26 | print(f.name) 27 | 28 | _ = [f.close() for f in output_files] 29 | 30 | print("#> Done!") 31 | 32 | 33 | if __name__ == "__main__": 34 | random.seed(12345) 35 | 36 | parser = ArgumentParser(description='Subsample the dev set.') 37 | parser.add_argument('--ranking', dest='ranking', required=True) 38 | 39 | parser.add_argument('--names', dest='names', required=False, default=['train', 'dev', 'test'], type=str, nargs='+') # order matters! 
40 | parser.add_argument('--gap', dest='gap', required=False, default=1_000_000_000, type=int) # larger than any individual query set 41 | 42 | args = parser.parse_args() 43 | 44 | main(args) 45 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/rankings/split_by_queries.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tqdm 4 | import ujson 5 | import random 6 | 7 | from argparse import ArgumentParser 8 | from collections import OrderedDict 9 | from colbert.utils.utils import print_message, file_tqdm 10 | 11 | 12 | def main(args): 13 | qid_to_file_idx = {} 14 | 15 | for qrels_idx, qrels in enumerate(args.all_queries): 16 | with open(qrels) as f: 17 | for line in f: 18 | qid, *_ = line.strip().split('\t') 19 | qid = int(qid) 20 | 21 | assert qid_to_file_idx.get(qid, qrels_idx) == qrels_idx, (qid, qrels_idx) 22 | qid_to_file_idx[qid] = qrels_idx 23 | 24 | all_outputs_paths = [f'{args.ranking}.{idx}' for idx in range(len(args.all_queries))] 25 | 26 | assert all(not os.path.exists(path) for path in all_outputs_paths) 27 | 28 | all_outputs = [open(path, 'w') for path in all_outputs_paths] 29 | 30 | with open(args.ranking) as f: 31 | print_message(f"#> Loading ranked lists from {f.name} ..") 32 | 33 | last_file_idx = -1 34 | 35 | for line in file_tqdm(f): 36 | qid, *_ = line.strip().split('\t') 37 | 38 | file_idx = qid_to_file_idx[int(qid)] 39 | 40 | if file_idx != last_file_idx: 41 | print_message(f"#> Switched to file #{file_idx} at {all_outputs[file_idx].name}") 42 | 43 | last_file_idx = file_idx 44 | 45 | all_outputs[file_idx].write(line) 46 | 47 | print() 48 | 49 | for f in all_outputs: 50 | print(f.name) 51 | f.close() 52 | 53 | print("#> Done!") 54 | 55 | 56 | if __name__ == "__main__": 57 | random.seed(12345) 58 | 59 | parser = ArgumentParser(description='.') 60 | 61 | # Input Arguments 62 | 
parser.add_argument('--ranking', dest='ranking', required=True, type=str) 63 | parser.add_argument('--all-queries', dest='all_queries', required=True, type=str, nargs='+') 64 | 65 | args = parser.parse_args() 66 | 67 | main(args) 68 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/rankings/tune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ujson 3 | import random 4 | 5 | from argparse import ArgumentParser 6 | from colbert.utils.utils import print_message, create_directory 7 | from utility.utils.save_metadata import save_metadata 8 | 9 | 10 | def main(args): 11 | AllMetrics = {} 12 | Scores = {} 13 | 14 | for path in args.paths: 15 | with open(path) as f: 16 | metric = ujson.load(f) 17 | AllMetrics[path] = metric 18 | 19 | for k in args.metric: 20 | metric = metric[k] 21 | 22 | assert type(metric) is float 23 | Scores[path] = metric 24 | 25 | MaxKey = max(Scores, key=Scores.get) 26 | 27 | MaxCKPT = int(MaxKey.split('/')[-2].split('.')[-1]) 28 | MaxARGS = os.path.join(os.path.dirname(MaxKey), 'logs', 'args.json') 29 | 30 | with open(MaxARGS) as f: 31 | logs = ujson.load(f) 32 | MaxCHECKPOINT = logs['checkpoint'] 33 | 34 | assert MaxCHECKPOINT.endswith(f'colbert-{MaxCKPT}.dnn'), (MaxCHECKPOINT, MaxCKPT) 35 | 36 | with open(args.output, 'w') as f: 37 | f.write(MaxCHECKPOINT) 38 | 39 | args.Scores = Scores 40 | args.AllMetrics = AllMetrics 41 | 42 | save_metadata(f'{args.output}.meta', args) 43 | 44 | print('\n\n', args, '\n\n') 45 | print(args.output) 46 | print_message("#> Done.") 47 | 48 | 49 | if __name__ == "__main__": 50 | random.seed(12345) 51 | 52 | parser = ArgumentParser(description='.') 53 | 54 | # Input / Output Arguments 55 | parser.add_argument('--metric', dest='metric', required=True, type=str) # e.g., success.20 56 | parser.add_argument('--paths', dest='paths', required=True, type=str, nargs='+') 57 | 
parser.add_argument('--output', dest='output', required=True, type=str) 58 | 59 | args = parser.parse_args() 60 | 61 | args.metric = args.metric.split('.') 62 | 63 | assert not os.path.exists(args.output), args.output 64 | create_directory(os.path.dirname(args.output)) 65 | 66 | main(args) 67 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/supervision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/utility/supervision/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/colbert/utility/utils/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/utils/qa_loaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ujson 3 | 4 | from collections import defaultdict 5 | from colbert.utils.utils import print_message, file_tqdm 6 | 7 | 8 | def load_collection_(path, retain_titles): 9 | with open(path) as f: 10 | collection = [] 11 | 12 | for line in file_tqdm(f): 13 | _, passage, title = line.strip().split('\t') 14 | 15 | if retain_titles: 16 | passage = title + ' | ' + passage 17 | 18 | collection.append(passage) 19 | 20 | return collection 21 | 22 | 23 | def load_qas_(path): 24 | print_message("#> Loading the reference QAs from", path) 25 | 26 | triples = [] 27 | 28 | with open(path) as f: 29 | for line in f: 30 | qa = ujson.loads(line) 31 | triples.append((qa['qid'], qa['question'], 
qa['answers'])) 32 | 33 | return triples 34 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/colbert/utility/utils/save_metadata.py: -------------------------------------------------------------------------------- 1 | from rankify.utils.retrievers.colbert.colbert.utils.utils import dotdict 2 | import os 3 | import sys 4 | import git 5 | import time 6 | import copy 7 | import ujson 8 | import socket 9 | 10 | 11 | def get_metadata_only(): 12 | args = dotdict() 13 | 14 | args.hostname = socket.gethostname() 15 | try: 16 | args.git_branch = git.Repo(search_parent_directories=True).active_branch.name 17 | args.git_hash = git.Repo(search_parent_directories=True).head.object.hexsha 18 | args.git_commit_datetime = str(git.Repo(search_parent_directories=True).head.object.committed_datetime) 19 | except git.exc.InvalidGitRepositoryError as e: 20 | pass 21 | args.current_datetime = time.strftime('%b %d, %Y ; %l:%M%p %Z (%z)') 22 | args.cmd = ' '.join(sys.argv) 23 | 24 | return args 25 | 26 | 27 | def get_metadata(args): 28 | args = copy.deepcopy(args) 29 | 30 | args.hostname = socket.gethostname() 31 | args.git_branch = git.Repo(search_parent_directories=True).active_branch.name 32 | args.git_hash = git.Repo(search_parent_directories=True).head.object.hexsha 33 | args.git_commit_datetime = str(git.Repo(search_parent_directories=True).head.object.committed_datetime) 34 | args.current_datetime = time.strftime('%b %d, %Y ; %l:%M%p %Z (%z)') 35 | args.cmd = ' '.join(sys.argv) 36 | 37 | try: 38 | args.input_arguments = copy.deepcopy(args.input_arguments.__dict__) 39 | except: 40 | args.input_arguments = None 41 | 42 | return dict(args.__dict__) 43 | 44 | # TODO: No reason for deepcopy. But: (a) Call provenance() on objects that can, (b) Only save simple, small objects. No massive lists or models or weird stuff! 45 | # With that, I think we don't even need (necessarily) to restrict things to input_arguments. 
46 | 47 | def format_metadata(metadata): 48 | assert type(metadata) == dict 49 | 50 | return ujson.dumps(metadata, indent=4) 51 | 52 | 53 | def save_metadata(path, args): 54 | assert not os.path.exists(path), path 55 | 56 | with open(path, 'w') as output_metadata: 57 | data = get_metadata(args) 58 | output_metadata.write(format_metadata(data) + '\n') 59 | 60 | return data 61 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/contriever/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/contriever/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/hyde/__init__.py: -------------------------------------------------------------------------------- 1 | from .generator import OpenAIGenerator, CohereGenerator 2 | from .promptor import Promptor 3 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/hyde/promptor.py: -------------------------------------------------------------------------------- 1 | WEB_SEARCH = """Please write a passage to answer the question. 2 | Question: {} 3 | Passage:""" 4 | 5 | 6 | SCIFACT = """Please write a scientific paper passage to support/refute the claim. 7 | Claim: {} 8 | Passage:""" 9 | 10 | 11 | ARGUANA = """Please write a counter argument for the passage. 12 | Passage: {} 13 | Counter Argument:""" 14 | 15 | 16 | TREC_COVID = """Please write a scientific paper passage to answer the question. 17 | Question: {} 18 | Passage:""" 19 | 20 | 21 | FIQA = """Please write a financial article passage to answer the question. 22 | Question: {} 23 | Passage:""" 24 | 25 | 26 | DBPEDIA_ENTITY = """Please write a passage to answer the question. 
27 | Question: {} 28 | Passage:""" 29 | 30 | 31 | TREC_NEWS = """Please write a news passage about the topic. 32 | Topic: {} 33 | Passage:""" 34 | 35 | 36 | MR_TYDI = """Please write a passage in {} to answer the question in detail. 37 | Question: {} 38 | Passage:""" 39 | 40 | 41 | class Promptor: 42 | def __init__(self, task: str, language: str = 'en'): 43 | self.task = task 44 | self.language = language 45 | 46 | def build_prompt(self, query: str): 47 | if self.task == 'web search': 48 | return WEB_SEARCH.format(query) 49 | elif self.task == 'scifact': 50 | return SCIFACT.format(query) 51 | elif self.task == 'arguana': 52 | return ARGUANA.format(query) 53 | elif self.task == 'trec-covid': 54 | return TREC_COVID.format(query) 55 | elif self.task == 'fiqa': 56 | return FIQA.format(query) 57 | elif self.task == 'dbpedia-entity': 58 | return DBPEDIA_ENTITY.format(query) 59 | elif self.task == 'trec-news': 60 | return TREC_NEWS.format(query) 61 | elif self.task == 'mr-tydi': 62 | return MR_TYDI.format(self.language, query) 63 | else: 64 | raise ValueError('Task not supported') -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/splade/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/splade/datasets/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/indexing/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/splade/indexing/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/splade/losses/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/losses/pointwise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BCEWithLogitsLoss: 5 | def __init__(self): 6 | self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 7 | self.loss = torch.nn.BCEWithLogitsLoss(reduction="mean") 8 | 9 | def __call__(self, out_d): 10 | pos_scores, neg_scores = out_d["pos_score"], out_d["neg_score"] 11 | p = pos_scores.squeeze() 12 | n = neg_scores.squeeze() 13 | labels = torch.cat([torch.ones(p.shape[0]), torch.zeros(n.shape[0])]).to(self.device) 14 | return self.loss(torch.cat([p, n]), labels) 15 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/losses/regularization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class L1: 5 | 6 | def __call__(self, batch_rep): 7 | return torch.sum(torch.abs(batch_rep), dim=-1).mean() 8 | 9 | 10 | class L0: 11 | """non-differentiable 12 | """ 13 | 14 | def __call__(self, batch_rep): 15 | return torch.count_nonzero(batch_rep, dim=-1).float().mean() 16 | 17 | 18 | class FLOPS: 19 | """constraint from Minimizing FLOPs to Learn Efficient Sparse Representations 20 | 
https://arxiv.org/abs/2004.05665 21 | """ 22 | 23 | def __call__(self, batch_rep): 24 | return torch.sum(torch.mean(torch.abs(batch_rep), dim=0) ** 2) 25 | 26 | 27 | class RegWeightScheduler: 28 | """same scheduling as in: Minimizing FLOPs to Learn Efficient Sparse Representations 29 | https://arxiv.org/abs/2004.05665 30 | """ 31 | 32 | def __init__(self, lambda_, T): 33 | self.lambda_ = lambda_ 34 | self.T = T 35 | self.t = 0 36 | self.lambda_t = 0 37 | 38 | def step(self): 39 | """quadratic increase until time T 40 | """ 41 | if self.t >= self.T: 42 | pass 43 | else: 44 | self.t += 1 45 | self.lambda_t = self.lambda_ * (self.t / self.T) ** 2 46 | return self.lambda_t 47 | 48 | def get_lambda(self): 49 | return self.lambda_t 50 | 51 | 52 | class SparsityRatio: 53 | """non-differentiable 54 | """ 55 | 56 | def __init__(self, output_dim): 57 | self.output_dim = output_dim 58 | 59 | def __call__(self, batch_rep): 60 | return 1 - torch.sum(batch_rep != 0, dim=-1).float().mean() / self.output_dim 61 | 62 | 63 | def init_regularizer(reg, **kwargs): 64 | if reg == "L0": 65 | return L0() 66 | elif reg == "sparsity_ratio": 67 | return SparsityRatio(output_dim=kwargs["output_dim"]) 68 | elif reg == "L1": 69 | return L1() 70 | elif reg == "FLOPS": 71 | return FLOPS() 72 | else: 73 | raise NotImplementedError("provide valid regularizer") 74 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/splade/models/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/models/models_utils.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig 2 | 3 | from 
..models.transformer_rep import Splade, SpladeDoc, SpladeTopK, SpladeLexical 4 | 5 | 6 | def get_model(config: DictConfig, init_dict: DictConfig): 7 | # no need to reload model here, it will be done later 8 | # (either in train.py or in Evaluator.__init__() 9 | 10 | model_map = { 11 | "splade": Splade, 12 | "splade_doc": SpladeDoc, 13 | "splade_topk": SpladeTopK, 14 | "splade_lexical": SpladeLexical 15 | } 16 | try: 17 | model_class = model_map[config["matching_type"]] 18 | except KeyError: 19 | raise NotImplementedError("provide valid matching type ({})".format(config["matching_type"])) 20 | return model_class(**init_dict) 21 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/splade/tasks/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/tasks/amp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # inspired from Colbert repo: https://github.com/stanford-futuredata/ColBERT 4 | 5 | # replace this with contextlib.nullcontext if python >3.7 6 | # see https://stackoverflow.com/a/45187287 7 | class NullContextManager(object): 8 | def __init__(self, dummy_resource=None): 9 | self.dummy_resource = dummy_resource 10 | 11 | def __enter__(self): 12 | return self.dummy_resource 13 | 14 | def __exit__(self, *args): 15 | pass 16 | 17 | 18 | class MixedPrecisionManager: 19 | def __init__(self, activated): 20 | 21 | print("Using FP16:", activated) 22 | self.activated = activated 23 | if self.activated: 24 | self.scaler = torch.cuda.amp.GradScaler() 25 | 26 | def context(self): 27 | return torch.cuda.amp.autocast() if self.activated else NullContextManager() 28 | 
29 | def backward(self, loss): 30 | if self.activated: 31 | self.scaler.scale(loss).backward() 32 | else: 33 | loss.backward() 34 | 35 | def step(self, optimizer): 36 | if self.activated: 37 | self.scaler.unscale_(optimizer) 38 | self.scaler.step(optimizer) 39 | self.scaler.update() 40 | optimizer.zero_grad() 41 | else: 42 | optimizer.step() 43 | optimizer.zero_grad() 44 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/tasks/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/splade/tasks/base/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/tasks/base/early_stopping.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class EarlyStopping: 5 | 6 | def __init__(self, patience, mode): 7 | """mode: early stopping on loss or metrics ? 
8 | """ 9 | self.patience = patience 10 | self.counter = 0 11 | self.best = np.Inf if mode == "loss" else 0 12 | self.fn = lambda x, y: x < y if mode == "loss" else lambda a, b: a > b 13 | self.stop = False 14 | print("-- initialize early stopping with {}, patience={}".format(mode, patience)) 15 | 16 | def __call__(self, val_perf, trainer, step): 17 | if self.fn(val_perf, self.best): 18 | # => improvement 19 | self.best = val_perf 20 | self.counter = 0 21 | trainer.save_checkpoint(step=step, perf=val_perf, is_best=True) 22 | else: 23 | # => no improvement 24 | self.counter += 1 25 | if self.counter > self.patience: 26 | self.stop = True 27 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/tasks/base/saver.py: -------------------------------------------------------------------------------- 1 | class ValidationSaver: 2 | 3 | def __init__(self, loss): 4 | """loss: boolean indicating if we monitor loss (True) or metric (False)""" 5 | self.loss = loss 6 | self.best = 10e9 if loss else 0 7 | self.fn = lambda x, y: x < y if loss else x > y 8 | 9 | def __call__(self, val_perf, trainer, step): 10 | if self.fn(val_perf, self.best): 11 | # => improvement 12 | self.best = val_perf 13 | trainer.save_checkpoint(step=step, perf=val_perf, is_best=True) 14 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataScienceUIBK/Rankify/c3d521c2d1a17c7d82f2a9025aa000d9ad41b7d2/rankify/utils/retrievers/splade/utils/__init__.py -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/utils/hydra.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from hydra.utils import get_original_cwd 4 | from omegaconf import 
OmegaConf 5 | 6 | 7 | def hydra_chdir(exp_dict): 8 | print(OmegaConf.to_yaml(exp_dict)) 9 | try: 10 | os.chdir(get_original_cwd()) 11 | except ValueError: 12 | # hydra manual init, nothing to do 13 | pass 14 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/utils/index_figure.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | from omegaconf import DictConfig 6 | 7 | import hydra 8 | from conf.CONFIG_CHOICE import CONFIG_NAME, CONFIG_PATH 9 | from .utils import get_initialize_config 10 | 11 | 12 | @hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME) 13 | def index_figure(exp_dict: DictConfig): 14 | import matplotlib.pyplot as plt 15 | exp_dict, config, init_dict, model_training_config = get_initialize_config(exp_dict) 16 | index_folder = config.index_dir 17 | index_file = os.path.join(index_folder, "index_dist.json") 18 | figure_file = os.path.join(index_folder, "index_dist.png") 19 | index_dist = json.load(open(index_file)) 20 | 21 | sorted_dist = -np.array(sorted(-np.array(list(index_dist.values())))) 22 | 23 | fig = plt.figure() 24 | ax = plt.gca() 25 | ax.plot(sorted_dist) 26 | ax.set_yscale("log") 27 | ax.set_title("Index distribution (size of posting list)") 28 | ax.set_xlabel("Token in decreasing number of documents") 29 | ax.set_ylabel("Number of documents") 30 | plt.savefig(figure_file) 31 | print(figure_file) 32 | 33 | 34 | if __name__ == "__main__": 35 | index_figure() 36 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | import pytrec_eval 4 | from pytrec_eval import RelevanceEvaluator 5 | 6 | 7 | def truncate_run(run, k): 8 | """truncates run file to only contain top-k results for 
each query""" 9 | temp_d = {} 10 | for q_id in run: 11 | sorted_run = {k: v for k, v in sorted(run[q_id].items(), key=lambda item: item[1], reverse=True)} 12 | temp_d[q_id] = {k: sorted_run[k] for k in list(sorted_run.keys())[:k]} 13 | return temp_d 14 | 15 | 16 | def mrr_k(run, qrel, k, agg=True): 17 | evaluator = RelevanceEvaluator(qrel, {"recip_rank"}) 18 | truncated = truncate_run(run, k) 19 | mrr = evaluator.evaluate(truncated) 20 | if agg: 21 | mrr = sum([d["recip_rank"] for d in mrr.values()]) / max(1, len(mrr)) 22 | return mrr 23 | 24 | 25 | def evaluate(run, qrel, metric, agg=True, select=None): 26 | assert metric in pytrec_eval.supported_measures, print("provide valid pytrec_eval metric") 27 | evaluator = RelevanceEvaluator(qrel, {metric}) 28 | out_eval = evaluator.evaluate(run) 29 | res = Counter({}) 30 | if agg: 31 | for d in out_eval.values(): # when there are several results provided (e.g. several cut values) 32 | res += Counter(d) 33 | res = {k: v / len(out_eval) for k, v in res.items()} 34 | if select is not None: 35 | string_dict = "{}_{}".format(metric, select) 36 | if string_dict in res: 37 | return res[string_dict] 38 | else: # If the metric is not on the dict, say that it was 0 39 | return 0 40 | else: 41 | return res 42 | else: 43 | return out_eval 44 | 45 | 46 | def init_eval(metric): 47 | if metric not in ["MRR@10", "recall@10", "recall@50", "recall@100", "recall@200", "recall@500", "recall@1000"]: 48 | raise NotImplementedError("provide valid metric") 49 | if metric == "MRR@10": 50 | return lambda x, y: mrr_k(x, y, k=10, agg=True) 51 | else: 52 | return lambda x, y: evaluate(x, y, metric="recall", agg=True, select=metric.split('@')[1]) 53 | -------------------------------------------------------------------------------- /rankify/utils/retrievers/splade/utils/processing_trec_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | 
"""bunch of methods to convert from trec_eval format to pytrec_eval (and vice versa) 6 | """ 7 | 8 | 9 | def build_json_qrel(qrel_file_path): 10 | """ 11 | input file has format: 186154 0 1160 1 12 | """ 13 | temp_d = defaultdict(dict) 14 | with open(qrel_file_path) as reader: 15 | for line in reader: 16 | q_id, _, d_id, rel = line.split("\t") 17 | temp_d[q_id][d_id] = int(rel) 18 | print("built qrel file, contains {} queries...", len(temp_d)) 19 | json.dump(dict(temp_d), open(os.path.join(os.path.dirname(qrel_file_path), "qrel.json"), "w")) 20 | --------------------------------------------------------------------------------