├── ragfit ├── __init__.py ├── models │ ├── __init__.py │ ├── openai_executor.py │ ├── vllm.py │ └── hf.py ├── evaluation │ ├── __init__.py │ ├── base.py │ ├── deep.py │ └── metrics.py ├── processing │ ├── __init__.py │ ├── global_steps │ │ ├── __init__.py │ │ ├── filters.py │ │ ├── output.py │ │ ├── aggregation.py │ │ └── sampling.py │ ├── local_steps │ │ ├── __init__.py │ │ ├── api │ │ │ ├── __init__.py │ │ │ └── openai.py │ │ ├── retrievers │ │ │ ├── __init__.py │ │ │ └── haystack.py │ │ ├── inference.py │ │ ├── context.py │ │ ├── formatting.py │ │ ├── raft.py │ │ ├── common_datasets.py │ │ └── prompter.py │ ├── answer_processors │ │ ├── __init__.py │ │ └── regex.py │ ├── dataset_loaders │ │ ├── __init__.py │ │ └── loaders.py │ ├── prompts │ │ ├── qa-short.txt │ │ ├── fact_statement.txt │ │ ├── qa.txt │ │ ├── reason_before_answer.txt │ │ ├── qa-with-answer.txt │ │ ├── qa-fewshot.txt │ │ ├── prompt_instructions │ │ │ ├── qa.txt │ │ │ ├── qa-yes-no.txt │ │ │ ├── qa-short.txt │ │ │ └── qa-yes-no-maybe.txt │ │ ├── cot.txt │ │ └── cot-fewshot.txt │ ├── utils.py │ ├── step.py │ └── pipeline.py └── utils.py ├── docs ├── reference │ ├── utils.md │ ├── models │ │ ├── hf.md │ │ ├── vllm.md │ │ └── openai_executor.md │ ├── evaluation │ │ ├── base.md │ │ ├── deep.md │ │ └── metrics.md │ └── processing │ │ ├── step.md │ │ ├── utils.md │ │ ├── pipeline.md │ │ ├── local_steps │ │ ├── raft.md │ │ ├── context.md │ │ ├── inference.md │ │ ├── prompter.md │ │ ├── api │ │ │ └── openai.md │ │ ├── formatting.md │ │ ├── common_datasets.md │ │ └── retrievers │ │ │ └── haystack.md │ │ ├── global_steps │ │ ├── filters.md │ │ ├── output.md │ │ ├── sampling.md │ │ └── aggregation.md │ │ ├── answer_processors │ │ └── regex.md │ │ └── dataset_loaders │ │ └── loaders.md ├── assets │ ├── rag_fit.png │ └── rag_fit_white.png ├── requirements.txt ├── stylesheets │ └── extra.css ├── README.md ├── scripts │ └── generate_docstrings.py ├── training.md ├── inference.md ├── index.md ├── evaluation.md ├── processing.md └── pubmed.md ├── assets └── rag_fit.png ├── .gitignore ├── ruff.toml ├── configs ├── inference-vllm.yaml ├── paper │ ├── evaluation-asqa-short.yaml │ ├── processing-asqa-baseline.yaml │ ├── evaluation-pubmed.yaml │ ├── inference-asqa.yaml │ ├── processing-asqa-cot-dev.yaml │ ├── inference-pubmed.yaml │ ├── evaluation-asqa-long.yaml │ ├── processing-asqa-context.yaml │ ├── processing-asqa-retrieval.yaml │ ├── processing-pubmed-context.yaml │ ├── processing-asqa-cot-train.yaml │ ├── training-asqa.yaml │ ├── training-pubmed.yaml │ └── training-pubmed-alternative.yaml ├── inference.yaml ├── external │ └── haystack │ │ └── qdrant.yaml ├── processing-nq.yaml ├── evaluation.yaml ├── training.yaml └── processing.yaml ├── processing.py ├── .github └── workflows │ └── docs.yml ├── .pre-commit-config.yaml ├── pyproject.toml ├── inference.py ├── evaluation.py ├── training.py ├── README.md ├── mkdocs.yml └── LICENSE /ragfit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ragfit/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ragfit/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ragfit/processing/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/reference/utils.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.utils -------------------------------------------------------------------------------- /ragfit/processing/global_steps/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ragfit/processing/local_steps/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/reference/models/hf.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.models.hf -------------------------------------------------------------------------------- /ragfit/processing/answer_processors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ragfit/processing/dataset_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ragfit/processing/local_steps/api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/reference/models/vllm.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.models.vllm -------------------------------------------------------------------------------- /ragfit/processing/local_steps/retrievers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ragfit/processing/prompts/qa-short.txt: -------------------------------------------------------------------------------- 1 | Question: {query} -------------------------------------------------------------------------------- /docs/reference/evaluation/base.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.evaluation.base -------------------------------------------------------------------------------- /docs/reference/evaluation/deep.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.evaluation.deep -------------------------------------------------------------------------------- /docs/reference/processing/step.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.step -------------------------------------------------------------------------------- /docs/reference/processing/utils.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.utils -------------------------------------------------------------------------------- /docs/reference/evaluation/metrics.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.evaluation.metrics -------------------------------------------------------------------------------- /docs/reference/processing/pipeline.md: 
-------------------------------------------------------------------------------- 1 | ::: ragfit.processing.pipeline -------------------------------------------------------------------------------- /docs/reference/models/openai_executor.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.models.openai_executor -------------------------------------------------------------------------------- /docs/reference/processing/local_steps/raft.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.local_steps.raft -------------------------------------------------------------------------------- /ragfit/processing/prompts/fact_statement.txt: -------------------------------------------------------------------------------- 1 | From [{i}], we know that: {content} -------------------------------------------------------------------------------- /ragfit/processing/prompts/qa.txt: -------------------------------------------------------------------------------- 1 | Question: {question} 2 | Context: 3 | {context} -------------------------------------------------------------------------------- /assets/rag_fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/RAG-FiT/HEAD/assets/rag_fit.png -------------------------------------------------------------------------------- /docs/reference/processing/global_steps/filters.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.global_steps.filters -------------------------------------------------------------------------------- /docs/reference/processing/global_steps/output.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.global_steps.output -------------------------------------------------------------------------------- /docs/reference/processing/global_steps/sampling.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.global_steps.sampling -------------------------------------------------------------------------------- /docs/reference/processing/local_steps/context.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.local_steps.context -------------------------------------------------------------------------------- /docs/reference/processing/local_steps/inference.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.local_steps.inference -------------------------------------------------------------------------------- /docs/reference/processing/local_steps/prompter.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.local_steps.prompter -------------------------------------------------------------------------------- /ragfit/processing/prompts/reason_before_answer.txt: -------------------------------------------------------------------------------- 1 | {reasoning} 2 | Thus, the final answer is: -------------------------------------------------------------------------------- /docs/reference/processing/answer_processors/regex.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.answer_processors.regex -------------------------------------------------------------------------------- 
/docs/reference/processing/dataset_loaders/loaders.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.dataset_loaders.loaders -------------------------------------------------------------------------------- /docs/reference/processing/local_steps/api/openai.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.local_steps.api.openai -------------------------------------------------------------------------------- /docs/reference/processing/local_steps/formatting.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.local_steps.formatting -------------------------------------------------------------------------------- /docs/assets/rag_fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/RAG-FiT/HEAD/docs/assets/rag_fit.png -------------------------------------------------------------------------------- /docs/reference/processing/global_steps/aggregation.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.global_steps.aggregation -------------------------------------------------------------------------------- /docs/reference/processing/local_steps/common_datasets.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.local_steps.common_datasets -------------------------------------------------------------------------------- /docs/assets/rag_fit_white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/RAG-FiT/HEAD/docs/assets/rag_fit_white.png -------------------------------------------------------------------------------- /docs/reference/processing/local_steps/retrievers/haystack.md: -------------------------------------------------------------------------------- 1 | ::: ragfit.processing.local_steps.retrievers.haystack -------------------------------------------------------------------------------- /ragfit/processing/prompts/qa-with-answer.txt: -------------------------------------------------------------------------------- 1 | Question: {question} 2 | Context: 3 | {context} 4 | Answer: {answer} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.python-version 2 | /outputs/ 3 | __pycache__/ 4 | /site/ 5 | /multirun/ 6 | wandb 7 | .ipynb_checkpoints 8 | -------------------------------------------------------------------------------- /ragfit/processing/prompts/qa-fewshot.txt: -------------------------------------------------------------------------------- 1 | Examples: 2 | {fewshot} 3 | 4 | Question: {question} 5 | Context: 6 | {context} -------------------------------------------------------------------------------- /ragfit/processing/prompts/prompt_instructions/qa.txt: -------------------------------------------------------------------------------- 1 | You are a helpful question answerer who can provide an answer given a question and relevant context. 
-------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | line-length = 90 2 | 3 | [lint] 4 | select = ["E", "F", "W", "I", "N", "Q"] 5 | ignore = ["E203", "F841", "E501", "F821"] 6 | exclude = ["*.ipynb"] -------------------------------------------------------------------------------- /ragfit/processing/prompts/prompt_instructions/qa-yes-no.txt: -------------------------------------------------------------------------------- 1 | You are a helpful question answerer who can provide an answer given a question and relevant context. Please answer with yes or no. -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs-material>=9.5.30 2 | mkdocstrings>=0.25.2 3 | mkdocs-gen-files 4 | mkdocstrings-python-legacy 5 | mkdocstrings-python>=1.10.7 6 | pymdown-extensions>=10.9 7 | mkdocs-material-extensions>=1.3.1 -------------------------------------------------------------------------------- /ragfit/processing/prompts/prompt_instructions/qa-short.txt: -------------------------------------------------------------------------------- 1 | You are a helpful question answerer who can provide an answer given a question and relevant context. Answer the following question with a short span. The answer needs to be just in a few words. 2 | -------------------------------------------------------------------------------- /ragfit/processing/prompts/prompt_instructions/qa-yes-no-maybe.txt: -------------------------------------------------------------------------------- 1 | You are a helpful question answerer who can provide an answer given a question and relevant context. Please answer with "yes", "no" or "maybe", if there is not enough information to answer the question. -------------------------------------------------------------------------------- /ragfit/processing/global_steps/filters.py: -------------------------------------------------------------------------------- 1 | """Module containing filters.""" 2 | 3 | 4 | def msmarco_positive_filter(x): 5 | """ 6 | Identify the positive passages in MSMARCO dataset. 7 | """ 8 | return 1 in x["passages"]["is_selected"] 9 | 10 | 11 | filters = {"MSMARCO": msmarco_positive_filter} 12 | -------------------------------------------------------------------------------- /ragfit/utils.py: -------------------------------------------------------------------------------- 1 | def check_package_installed(package_name: str, optional_msg: str = ""): 2 | """ 3 | Check if a package is installed. 4 | """ 5 | 6 | import importlib.util 7 | 8 | if importlib.util.find_spec(package_name) is None: 9 | raise ImportError(f"{package_name} package is not installed; {optional_msg}") 10 | -------------------------------------------------------------------------------- /docs/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | [data-md-color-scheme="ragfit"] { 2 | --md-primary-fg-color: #2B66AF; 3 | --md-primary-fg-color--light: #f2f2f2; 4 | --md-primary-fg-color--dark: #333333; 5 | } 6 | 7 | /* Indentation. 
*/ 8 | div.doc-contents:not(.first) { 9 | padding-left: 20px; 10 | border-left: .05rem solid var(--md-typeset-table-color); 11 | } 12 | 13 | -------------------------------------------------------------------------------- /configs/inference-vllm.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _target_: ragfit.models.vllm.VLLMInference 3 | model_name_or_path: "facebook/opt-125m" 4 | llm_params: 5 | dtype: auto 6 | generation: 7 | temperature: 0.5 8 | top_p: 0.95 9 | seed: 1911 10 | num_gpus: 1 11 | 12 | data_file: my-processed-data.jsonl 13 | generated_file: model-predictions.jsonl 14 | input_key: prompt 15 | generation_key: output 16 | target_key: answers 17 | limit: -------------------------------------------------------------------------------- /processing.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import hydra 4 | from omegaconf import OmegaConf 5 | 6 | from ragfit.processing.pipeline import DataPipeline 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | @hydra.main(version_base=None, config_path="./configs", config_name="processing") 12 | def main(args): 13 | logger.info(OmegaConf.to_yaml(args)) 14 | 15 | pipeline = DataPipeline(**args) 16 | pipeline.process() 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Publish docs via GitHub Pages 2 | on: 3 | push: 4 | branches: 5 | - main 6 | 7 | jobs: 8 | build: 9 | name: Deploy docs 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout main 13 | uses: actions/checkout@v4 14 | 15 | - name: Deploy docs 16 | uses: mhausenblas/mkdocs-deploy-gh-pages@1.26 17 | env: 18 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 19 | REQUIREMENTS: docs/requirements.txt 20 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: check-docstring-first 7 | - id: check-added-large-files 8 | - id: check-toml 9 | - id: check-yaml 10 | args: ["--unsafe", "mkdocs.yml"] 11 | - id: check-merge-conflict 12 | - repo: https://github.com/astral-sh/ruff-pre-commit 13 | rev: v0.6.3 14 | hooks: 15 | - id: ruff 16 | args: ["--fix"] 17 | - id: ruff-format 18 | -------------------------------------------------------------------------------- /configs/paper/evaluation-asqa-short.yaml: -------------------------------------------------------------------------------- 1 | answer_processor: 2 | _target_: ragfit.processing.answer_processors.regex.RegexAnswer 3 | capture_pattern: ": (.*)" 4 | stopping_pattern: 5 | 6 | metrics: 7 | - _target_: ragfit.evaluation.metrics.StringEM 8 | 9 | key_names: 10 | generated: text 11 | label: answer-short 12 | query: query 13 | 14 | results_file: evaluation-asqa-baseline.yaml 15 | generated_file: asqa-baseline-dev-generated.jsonl 16 | data_file: asqa-baseline-dev.jsonl 17 | limit: 18 | use_wandb: 19 | experiment: 20 | wandb_entity: 21 | -------------------------------------------------------------------------------- /ragfit/processing/prompts/cot.txt: -------------------------------------------------------------------------------- 1 | Question:
{question} 2 | Context: 3 | {context} 4 | 5 | Answer this question using the information given in the context above. Here is things to pay attention to: 6 | - First provide step-by-step reasoning on how to answer the question. 7 | - In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy paste from the context. 8 | - End your response with final answer in the form : $answer, the answer should be succinct. -------------------------------------------------------------------------------- /configs/paper/processing-asqa-baseline.yaml: -------------------------------------------------------------------------------- 1 | name: asqa_baseline 2 | cache: false 3 | output_path: . 4 | steps: 5 | - _target_: ragfit.processing.dataset_loaders.loaders.LocalLoader 6 | inputs: dev 7 | filename: asqa-dev.jsonl 8 | 9 | - _target_: ragfit.processing.local_steps.prompter.TextPrompter 10 | inputs: dev 11 | prompt_file: ragfit/processing/prompts/qa-short.txt 12 | output_key: prompt 13 | mapping: 14 | query: query 15 | 16 | - _target_: ragfit.processing.global_steps.output.OutputData 17 | inputs: dev 18 | prefix: asqa-baseline 19 | 20 | -------------------------------------------------------------------------------- /configs/paper/evaluation-pubmed.yaml: -------------------------------------------------------------------------------- 1 | answer_processor: 2 | _target_: ragfit.processing.answer_processors.regex.RegexAnswer 3 | capture_pattern: 4 | stopping_pattern: 5 | 6 | metrics: 7 | - _target_: ragfit.evaluation.metrics.Classification 8 | mapping: 9 | "yes": 1 10 | "no": 0 11 | "maybe": 2 12 | else_value: 2 13 | 14 | key_names: 15 | generated: text 16 | label: answers 17 | query: query 18 | 19 | results_file: evaluation-pubmed-rag.yaml 20 | generated_file: pubmed-rag-test-generated.jsonl 21 | data_file: pubmed-rag-test.jsonl 22 | limit: 23 | use_wandb: 24 | experiment: 25 | wandb_entity: 26 | -------------------------------------------------------------------------------- /ragfit/processing/prompts/cot-fewshot.txt: -------------------------------------------------------------------------------- 1 | Some examples: 2 | {fewshot} 3 | 4 | Question: {question} 5 | Context: 6 | {context} 7 | 8 | Answer this question using the information given in the context above. Here is things to pay attention to: 9 | - First provide step-by-step reasoning on how to answer the question. 10 | - In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy paste from the context. 11 | - End your response with final answer in the form : $answer, the answer should be succinct. -------------------------------------------------------------------------------- /ragfit/processing/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | from typing import Any, Dict 4 | 5 | 6 | def dict_hash(dictionary: Dict[str, Any]) -> str: 7 | """ 8 | Hash dictionary using MD5. Used in step caching; steps are cached based on the signature. 
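    Example (illustrative): dict_hash({"a": 1, "b": 2}) returns a deterministic 32-character hex digest; key order does not affect the result, since keys are sorted before hashing.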
9 | """ 10 | dhash = hashlib.md5() 11 | encoded = json.dumps(dictionary, sort_keys=True).encode() 12 | dhash.update(encoded) 13 | return dhash.hexdigest() 14 | 15 | 16 | def is_jsonable(x): 17 | """ 18 | Test if input is JSON-serializable. 19 | """ 20 | try: 21 | json.dumps(x) 22 | return True 23 | except (TypeError, OverflowError): 24 | return False 25 | -------------------------------------------------------------------------------- /configs/inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _target_: ragfit.models.hf.HFInference 3 | model_name_or_path: microsoft/Phi-3-mini-128k-instruct 4 | load_in_4bit: false 5 | load_in_8bit: true 6 | device_map: auto 7 | torch_dtype: 8 | trust_remote_code: true 9 | instruction: ragfit/processing/prompts/prompt_instructions/qa.txt 10 | instruct_in_prompt: false 11 | lora_path: 12 | generation: 13 | do_sample: false 14 | max_new_tokens: 50 15 | max_length: 16 | temperature: 17 | top_k: 18 | top_p: 19 | return_full_text: false 20 | 21 | data_file: my-processed-data.jsonl 22 | generated_file: model-predictions.jsonl 23 | input_key: prompt 24 | generation_key: output 25 | target_key: answer 26 | limit: 27 | -------------------------------------------------------------------------------- /configs/paper/inference-asqa.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _target_: ragfit.models.hf.HFInference 3 | model_name_or_path: microsoft/Phi-3-mini-128k-instruct 4 | load_in_4bit: false 5 | load_in_8bit: true 6 | device_map: auto 7 | torch_dtype: 8 | trust_remote_code: true 9 | instruction: ragfit/processing/prompts/prompt_instructions/qa.txt 10 | instruct_in_prompt: false 11 | lora_path: 12 | generation: 13 | do_sample: false 14 | max_new_tokens: 50 15 | max_length: 16 | temperature: 17 | top_k: 18 | top_p: 19 | return_full_text: false 20 | 21 | data_file: asqa-baseline-dev.jsonl 22 | generated_file: asqa-baseline-dev-generated.jsonl 23 | input_key: prompt 24 | generation_key: output 25 | target_key: answers 26 | limit: 27 | 28 | -------------------------------------------------------------------------------- /configs/paper/processing-asqa-cot-dev.yaml: -------------------------------------------------------------------------------- 1 | name: asqa_cot_dev 2 | cache: false 3 | output_path: .
4 | steps: 5 | - _target_: ragfit.processing.dataset_loaders.loaders.LocalLoader 6 | inputs: dev 7 | filename: asqa-dev.jsonl 8 | 9 | - _target_: ragfit.processing.local_steps.context.DocumentsJoiner 10 | inputs: dev 11 | docs_key: positive_passages 12 | k: 5 13 | 14 | - _target_: ragfit.processing.local_steps.prompter.TextPrompter 15 | inputs: dev 16 | prompt_file: ragfit/processing/prompts/cot.txt 17 | output_key: prompt 18 | mapping: 19 | question: query 20 | context: positive_passages 21 | 22 | - _target_: ragfit.processing.global_steps.output.OutputData 23 | inputs: dev 24 | prefix: asqa-cot 25 | 26 | -------------------------------------------------------------------------------- /configs/paper/inference-pubmed.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _target_: ragfit.models.hf.HFInference 3 | model_name_or_path: microsoft/Phi-3-mini-128k-instruct 4 | load_in_4bit: false 5 | load_in_8bit: true 6 | device_map: auto 7 | torch_dtype: 8 | trust_remote_code: true 9 | instruction: ragfit/processing/prompts/prompt_instructions/qa-yes-no.txt 10 | instruct_in_prompt: false 11 | lora_path: ./trained_model/checkpoint 12 | generation: 13 | do_sample: false 14 | max_new_tokens: 50 15 | max_length: 16 | temperature: 17 | top_k: 18 | top_p: 19 | return_full_text: false 20 | 21 | data_file: pubmed-rag-test.jsonl 22 | generated_file: pubmed-rag-test-generated.jsonl 23 | input_key: prompt 24 | generation_key: output 25 | target_key: answers 26 | limit: 27 | 28 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Generating RAG-FiT Documentation 2 | 3 | ## Installation 4 | 5 | Install the Python packages required for building the MkDocs documentation website. 6 | 7 | ``` sh 8 | pip install -r docs/requirements.txt 9 | ``` 10 | 11 | ## Adding new content 12 | 13 | - Generate Python docstrings using: 14 | 15 | ``` sh 16 | cd 17 | python docs/scripts/generate_docstrings.py 18 | ``` 19 | 20 | - Add new API documentation to `mkdocs.yml` in the appropriate section under the API main section. 21 | - Add any new content as a markdown file and reference it in the appropriate section under the `nav` section in `mkdocs.yml`. 22 | - Check that the website functions correctly. 23 | 24 | ``` sh 25 | mkdocs serve 26 | ``` 27 | 28 | Go to `http://127.0.0.1:8000/` for a live preview. 29 | 30 | - Build the website: 31 | 32 | ``` sh 33 | mkdocs build 34 | ``` -------------------------------------------------------------------------------- /ragfit/evaluation/base.py: -------------------------------------------------------------------------------- 1 | class MetricBase: 2 | """ 3 | Base class for metrics. 4 | 5 | Metrics can be local or global; local means scores are calculated per example. 6 | Global means the score is calculated by looking at the entire dataset, e.g. fluency. 7 | """ 8 | 9 | def __init__(self, key_names, **kwargs): 10 | self.key_names = key_names 11 | self.kwargs = kwargs 12 | self.field = self.key_names["generated"] 13 | self.target = self.key_names["label"] 14 | 15 | def measure(self, example: dict) -> dict[str, float]: 16 | """ 17 | Measure the performance of the model on a given example. 18 | 19 | Parameters: 20 | example (dict): The example to evaluate the model on. 21 | 22 | Returns: 23 | dict[str, float]: A dictionary containing the performance metrics.
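            For example, an exact-match metric might return {"em": 1.0} for a correct prediction.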
24 | 25 | """ 26 | pass 27 | -------------------------------------------------------------------------------- /configs/paper/evaluation-asqa-long.yaml: -------------------------------------------------------------------------------- 1 | answer_processor: 2 | _target_: ragfit.processing.answer_processors.regex.RegexAnswer 3 | capture_pattern: ": (.*)" 4 | stopping_pattern: 5 | 6 | metrics: 7 | - _target_: ragfit.evaluation.deep.Faithfulness 8 | azure_endpoint: azure.endpoint.com 9 | azure_deployment: GPT-4-32k-Bot 10 | api_version: 2024-05-01-preview 11 | - _target_: ragfit.evaluation.deep.Relevancy 12 | azure_endpoint: azure.endpoint.com 13 | azure_deployment: GPT-4-32k-Bot 14 | api_version: 2024-05-01-preview 15 | embeddings: BAAI/bge-small-en-v1.5 16 | 17 | key_names: 18 | generated: text 19 | label: answers 20 | query: query 21 | context: positive_passages 22 | 23 | results_file: asqa-context-dev-generated-results.yaml 24 | generated_file: asqa-context-dev-generated.jsonl 25 | data_file: asqa-context-dev.jsonl 26 | limit: 27 | use_wandb: 28 | experiment: 29 | wandb_entity: 30 | -------------------------------------------------------------------------------- /configs/paper/processing-asqa-context.yaml: -------------------------------------------------------------------------------- 1 | name: asqa_context 2 | cache: false 3 | output_path: . 4 | steps: 5 | - _target_: ragfit.processing.dataset_loaders.loaders.LocalLoader 6 | inputs: train 7 | filename: asqa-train.jsonl 8 | 9 | - _target_: ragfit.processing.dataset_loaders.loaders.LocalLoader 10 | inputs: dev 11 | filename: asqa-dev.jsonl 12 | 13 | - _target_: ragfit.processing.local_steps.context.DocumentsJoiner 14 | inputs: [train, dev] 15 | docs_key: positive_passages 16 | k: 5 17 | 18 | - _target_: ragfit.processing.local_steps.prompter.TextPrompter 19 | inputs: [train, dev] 20 | prompt_file: ragfit/processing/prompts/qa.txt 21 | output_key: prompt 22 | mapping: 23 | question: query 24 | context: positive_passages 25 | 26 | - _target_: ragfit.processing.global_steps.output.OutputData 27 | inputs: [train, dev] 28 | prefix: asqa-context 29 | 30 | -------------------------------------------------------------------------------- /ragfit/processing/answer_processors/regex.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | class RegexAnswer: 5 | """ 6 | Extract answers from the text using regular expressions. 7 | 8 | Pattern is the regular expression used to extract the answer. 9 | Stopping pattern is a string used to split the answer. 10 | 11 | Example: 12 | `r = RegexAnswer(": (.*)", "[,.;]")` 13 | """ 14 | 15 | def __init__(self, capture_pattern=None, stopping_pattern=None): 16 | self.capture_pattern = capture_pattern 17 | self.stopping_pattern = stopping_pattern 18 | 19 | def __call__(self, text: str): 20 | """ 21 | Extract the answer from the text. 22 | """ 23 | if (capture := self.capture_pattern) and capture != "": 24 | match = re.search(capture, text, re.MULTILINE | re.DOTALL) 25 | if match: 26 | text = match.group(1) 27 | 28 | if (stopping := self.stopping_pattern) and stopping != "": 29 | text = re.split(stopping, text)[0] 30 | 31 | return text 32 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "ragfit" 3 | version = "1.2.0" 4 | description = "Framework for enhancing LLMs for RAG tasks using fine-tuning." 
5 | readme = "README.md" 6 | license = {file = "LICENSE"} 7 | requires-python = ">=3.10" 8 | dependencies = [ 9 | "bert-score>=0.3.13", 10 | "bitsandbytes==0.42.0", 11 | "datasets==2.16.1", 12 | "evaluate==0.4.1", 13 | "hydra-core==1.3.2", 14 | "nltk==3.9", 15 | "openai==1.23.3", 16 | "peft==0.11.1", 17 | "pyyaml==6.0.1", 18 | "rouge-score==0.1.2", 19 | "sentence-transformers==2.4.0", 20 | "sentencepiece==0.2.0", 21 | "torch==2.6.0", 22 | "transformers>=4.50.0", 23 | "trl==0.8.6", 24 | "wandb==0.16.4", 25 | ] 26 | 27 | [project.urls] 28 | Homepage = "https://github.com/IntelLabs/RAG-FiT" 29 | Documentation = "https://intellabs.github.io/RAG-FiT/" 30 | 31 | [tool.setuptools] 32 | packages = ["ragfit"] 33 | 34 | [project.optional-dependencies] 35 | deepeval = [ 36 | "deepeval==0.21.73", 37 | ] 38 | haystack = [ 39 | "haystack-ai==2.3.1", 40 | "qdrant-haystack>=5.0.0", 41 | ] 42 | 43 | -------------------------------------------------------------------------------- /configs/external/haystack/qdrant.yaml: -------------------------------------------------------------------------------- 1 | components: 2 | retriever: 3 | init_parameters: 4 | document_store: 5 | type: haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore 6 | init_parameters: 7 | url: qdrant.url.com 8 | port: 6333 9 | index: wikipedia 10 | embedding_dim: 768 11 | similarity: dot_product 12 | write_batch_size: 50 13 | top_k: 10 14 | type: haystack_integrations.components.retrievers.qdrant.retriever.QdrantEmbeddingRetriever 15 | text_embedder: 16 | init_parameters: 17 | batch_size: 64 18 | model: BAAI/llm-embedder 19 | prefix: "Represent this query for retrieving relevant documents: " 20 | device: 21 | type: haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder 22 | connections: 23 | - receiver: retriever.query_embedding 24 | sender: text_embedder.embedding 25 | 26 | -------------------------------------------------------------------------------- /ragfit/processing/local_steps/api/openai.py: -------------------------------------------------------------------------------- 1 | from ragfit.models.openai_executor import OpenAIExecutor 2 | 3 | from ...step import LocalStep 4 | 5 | 6 | class OpenAIChat(LocalStep): 7 | """ 8 | Interaction with OpenAI service. 9 | 10 | Model is represented by the `OpenAIExecutor`. 11 | 12 | This step is a wrapper that extracts the prompt from the item, interacts with the API, and 13 | saves the response under the configured `answer_key` in the item. 14 | """ 15 | 16 | def __init__(self, model, instruction, prompt_key, answer_key, **kwargs): 17 | """ 18 | Args: 19 | model (dict): Configuration for the OpenAIExecutor. 20 | instruction (str): Path to the system instruction file. 21 | prompt_key (str): Key to the prompt in the item. 22 | answer_key (str): Key to store the response.
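            Example (illustrative): with prompt_key="prompt" and answer_key="generated_answer", the text in item["prompt"] is sent to the API and the reply is stored in item["generated_answer"].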
23 | """ 24 | super().__init__(**kwargs) 25 | 26 | self.model = OpenAIExecutor(**model) 27 | self.prompt_key = prompt_key 28 | self.answer_key = answer_key 29 | self.instruction = open(instruction).read() 30 | 31 | def process_item(self, item, index, datasets, **kwargs): 32 | answer = self.model.chat(item[self.prompt_key], self.instruction) 33 | item[self.answer_key] = answer 34 | return item 35 | -------------------------------------------------------------------------------- /ragfit/processing/local_steps/inference.py: -------------------------------------------------------------------------------- 1 | """Module for inference steps, which can use LLM output to augment the data.""" 2 | 3 | from ragfit.models.vllm import VLLMInference 4 | 5 | from ..step import LocalStep 6 | 7 | 8 | class HFStep(LocalStep): 9 | """ 10 | Class for running inference with a Hugging Face model based on the vLLM engine. 11 | """ 12 | 13 | def __init__(self, input_key, output_key, model_kwargs, **kwargs): 14 | """ 15 | Initialize the HFStep class. 16 | 17 | Args: 18 | input_key (str): The key for the input text to be used as the prompt. 19 | output_key (str): The key for saving the generated text. 20 | model_kwargs (dict): The keyword arguments to pass to the vLLM model. 21 | **kwargs: Additional keyword arguments to pass to the LocalStep. 22 | """ 23 | super().__init__(**kwargs) 24 | self.input_key = input_key 25 | self.output_key = output_key 26 | self.model_kwargs = model_kwargs 27 | self.model = VLLMInference(**model_kwargs) 28 | 29 | def process_item(self, item, index, datasets, **kwargs): 30 | prompt = item[self.input_key] 31 | response = self.model.generate(prompt) 32 | item[self.output_key] = response 33 | return item 34 | -------------------------------------------------------------------------------- /configs/processing-nq.yaml: -------------------------------------------------------------------------------- 1 | name: nq 2 | cache: false 3 | output_path: . 4 | steps: 5 | - _target_: ragfit.processing.dataset_loaders.loaders.HFLoader 6 | inputs: train 7 | dataset_config: 8 | path: Tevatron/wikipedia-nq 9 | split: train 10 | 11 | - _target_: ragfit.processing.global_steps.sampling.ShuffleSelect 12 | inputs: train 13 | shuffle: 42 14 | limit: 10000 15 | 16 | - _target_: ragfit.processing.local_steps.prompter.TextPrompter 17 | inputs: train 18 | prompt_file: ragfit/processing/prompts/qa-short.txt 19 | output_key: prompt 20 | mapping: 21 | query: query 22 | 23 | - _target_: ragfit.processing.local_steps.inference.HFStep 24 | inputs: train 25 | input_key: prompt 26 | output_key: generated 27 | model_kwargs: 28 | model_name_or_path: meta-llama/Meta-Llama-3.1-8B-Instruct 29 | instruction: ragfit/processing/prompts/prompt_instructions/qa-short.txt 30 | num_gpus: 2 31 | llm_params: 32 | dtype: auto 33 | max_model_len: 4096 34 | generation: 35 | temperature: 0 36 | max_tokens: 50 37 | 38 | - _target_: ragfit.processing.global_steps.output.OutputData 39 | inputs: train 40 | prefix: nq-with-answers 41 | -------------------------------------------------------------------------------- /configs/paper/processing-asqa-retrieval.yaml: -------------------------------------------------------------------------------- 1 | name: asqa_retrieval 2 | cache: false 3 | output_path: .
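# Pipeline summary: load the ASQA train and dev splits from the HF Hub, normalize fields, retrieve passages with a Haystack/Qdrant pipeline, format the retrieved contexts, sample negative passages, and write the result to JSONL.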
4 | steps: 5 | - _target_: ragfit.processing.dataset_loaders.loaders.HFLoader 6 | inputs: train 7 | dataset_config: 8 | path: din0s/asqa 9 | split: train 10 | 11 | - _target_: ragfit.processing.dataset_loaders.loaders.HFLoader 12 | inputs: dev 13 | dataset_config: 14 | path: din0s/asqa 15 | split: dev 16 | 17 | - _target_: ragfit.processing.local_steps.common_datasets.ASQA 18 | inputs: [train, dev] 19 | 20 | - _target_: 21 | ragfit.processing.local_steps.retrievers.haystack.HaystackRetriever 22 | inputs: [train, dev] 23 | pipeline_or_yaml_path: ./configs/external/haystack/qdrant.yaml 24 | docs_key: positive_passages 25 | query_key: query 26 | 27 | - _target_: ragfit.processing.local_steps.context.ContextHandler 28 | inputs: [train, dev] 29 | docs_key: positive_passages 30 | 31 | - _target_: ragfit.processing.global_steps.sampling.Sampler 32 | inputs: [train, dev] 33 | k: 1 34 | input_key: positive_passages 35 | output_key: negative_passages 36 | 37 | - _target_: ragfit.processing.global_steps.output.OutputData 38 | inputs: [train, dev] 39 | prefix: asqa 40 | 41 | -------------------------------------------------------------------------------- /configs/evaluation.yaml: -------------------------------------------------------------------------------- 1 | answer_processor: 2 | _target_: ragfit.processing.answer_processors.regex.RegexAnswer 3 | capture_pattern: # ": (.*)" 4 | stopping_pattern: # "[,.;]" 5 | 6 | metrics: 7 | - _target_: ragfit.evaluation.metrics.HFEvaluate 8 | metric_names: [rouge] 9 | - _target_: ragfit.evaluation.metrics.EM 10 | - _target_: ragfit.evaluation.metrics.StringEM 11 | - _target_: ragfit.evaluation.metrics.F1 12 | - _target_: ragfit.evaluation.metrics.BERTScore 13 | model: microsoft/deberta-large-mnli 14 | - _target_: ragfit.evaluation.metrics.Semantic 15 | model: vectara/hallucination_evaluation_model 16 | - _target_: ragfit.evaluation.metrics.Classification 17 | mapping: {"yes": 1, "no": 0, "maybe": 2} 18 | else_value: 2 19 | - _target_: ragfit.evaluation.deep.Faithfulness 20 | azure_endpoint: 21 | azure_deployment: 22 | api_version: 23 | - _target_: ragfit.evaluation.deep.Relevancy 24 | azure_endpoint: 25 | azure_deployment: 26 | api_version: 27 | embeddings: BAAI/bge-small-en-v1.5 28 | 29 | 30 | key_names: 31 | generated: generated 32 | label: answer 33 | query: query 34 | context: context 35 | 36 | results_file: my-evaluation.yaml 37 | generated_file: inference.jsonl 38 | data_file: my-processed-data.jsonl 39 | limit: 40 | use_wandb: 41 | experiment: 42 | wandb_entity: -------------------------------------------------------------------------------- /configs/paper/processing-pubmed-context.yaml: -------------------------------------------------------------------------------- 1 | name: pubmed_rag 2 | cache: true 3 | output_path: .
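# Pipeline summary: load the PubMedQA train and test splits, subsample the train split to 50k examples, normalize fields, join the top-5 contexts, build QA prompts, and write the result to JSONL.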
4 | steps: 5 | - _target_: ragfit.processing.dataset_loaders.loaders.HFLoader 6 | inputs: train 7 | dataset_config: 8 | path: bigbio/pubmed_qa 9 | split: train 10 | 11 | - _target_: ragfit.processing.dataset_loaders.loaders.HFLoader 12 | inputs: test 13 | dataset_config: 14 | path: bigbio/pubmed_qa 15 | name: pubmed_qa_labeled_fold0_source 16 | split: test 17 | 18 | - _target_: ragfit.processing.global_steps.sampling.ShuffleSelect 19 | inputs: train 20 | limit: 50000 21 | 22 | - _target_: ragfit.processing.local_steps.common_datasets.PubMed 23 | inputs: [train, test] 24 | 25 | - _target_: ragfit.processing.local_steps.context.DocumentsJoiner 26 | inputs: [train, test] 27 | docs_key: positive_passages 28 | k: 5 29 | 30 | - _target_: ragfit.processing.local_steps.prompter.TextPrompter 31 | inputs: [train, test] 32 | prompt_file: ragfit/processing/prompts/qa.txt 33 | output_key: prompt 34 | mapping: 35 | question: query 36 | context: positive_passages 37 | 38 | - _target_: ragfit.processing.global_steps.output.OutputData 39 | inputs: [train, test] 40 | prefix: pubmed-rag 41 | 42 | -------------------------------------------------------------------------------- /ragfit/processing/dataset_loaders/loaders.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | from ..step import GlobalStep 4 | 5 | 6 | class HFLoader(GlobalStep): 7 | """ 8 | Class to load a dataset using the Hugging Face datasets library. 9 | Can use either a HuggingFace tag or a local file. 10 | 11 | Caching is disabled as this step does not manipulate the dataset hence no need for caching. 12 | """ 13 | 14 | def __init__(self, dataset_config: dict, **kwargs): 15 | super().__init__(**kwargs) 16 | self.dataset_config = dataset_config 17 | self.cache_step = False 18 | 19 | def process(self, dataset_name, datasets, **kwargs): 20 | datasets[dataset_name] = load_dataset(**self.dataset_config) 21 | 22 | 23 | class LocalLoader(GlobalStep): 24 | """ 25 | Class to load a local dataset, in a JSON format. 26 | 27 | Caching is disabled as this step does not manipulate the dataset hence no need for caching. 28 | """ 29 | 30 | def __init__(self, filename, **kwargs): 31 | super().__init__(**kwargs) 32 | self.filename = filename 33 | self.cache_step = False 34 | 35 | def process(self, dataset_name, datasets, **kwargs): 36 | datasets[dataset_name] = load_dataset( 37 | "json", 38 | data_files=self.filename, 39 | split="train", 40 | ) 41 | -------------------------------------------------------------------------------- /configs/paper/processing-asqa-cot-train.yaml: -------------------------------------------------------------------------------- 1 | name: asqa_cot_raft_train 2 | cache: false 3 | output_path: . 
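# Pipeline summary: load the ASQA train split, build RAFT-style document mixes, join the documents, build chain-of-thought prompts, generate answers with an Azure OpenAI model, and write the result to JSONL.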
4 | steps: 5 | - _target_: ragfit.processing.dataset_loaders.loaders.LocalLoader 6 | inputs: train 7 | filename: asqa-train.jsonl 8 | 9 | - _target_: ragfit.processing.local_steps.raft.RAFTStep 10 | inputs: train 11 | k: 5 12 | raft_p: 0.5 13 | neg_docs_num: 2 14 | output_key: raft_docs 15 | 16 | - _target_: ragfit.processing.local_steps.context.DocumentsJoiner 17 | inputs: train 18 | docs_key: raft_docs 19 | k: 20 | 21 | - _target_: ragfit.processing.local_steps.prompter.TextPrompter 22 | inputs: train 23 | prompt_file: ragfit/processing/prompts/cot.txt 24 | output_key: prompt 25 | mapping: 26 | question: query 27 | context: raft_docs 28 | 29 | - _target_: ragfit.processing.local_steps.api.openai.OpenAIChat 30 | inputs: train 31 | prompt_key: prompt 32 | answer_key: generated_answer 33 | instruction: ragfit/processing/prompts/prompt_instructions/qa.txt 34 | model: 35 | azure_endpoint: azure.endpoint.com 36 | api_version: 2024-05-01-preview 37 | model: GPT-4-32k-Bot 38 | 39 | - _target_: ragfit.processing.global_steps.output.OutputData 40 | inputs: train 41 | prefix: asqa-raft-cot 42 | 43 | -------------------------------------------------------------------------------- /configs/paper/training-asqa.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _target_: ragfit.models.hf.HFTrain 3 | model_name_or_path: microsoft/Phi-3-mini-128k-instruct 4 | load_in_4bit: false 5 | load_in_8bit: true 6 | torch_dtype: 7 | device_map: 8 | trust_remote_code: true 9 | lora: 10 | bias: none 11 | fan_in_fan_out: false 12 | layers_pattern: 13 | layers_to_transform: 14 | lora_alpha: 16 15 | lora_dropout: 0.1 16 | peft_type: LORA 17 | r: 16 18 | target_modules: 19 | - qkv_proj 20 | task_type: CAUSAL_LM 21 | use_rslora: true 22 | completion_start: <|assistant|> 23 | instruction_in_prompt: 24 | max_sequence_len: 4000 25 | 26 | train: 27 | output_dir: ./trained_models/ 28 | bf16: false 29 | fp16: false 30 | gradient_accumulation_steps: 2 31 | group_by_length: 32 | learning_rate: 1e-4 33 | logging_steps: 10 34 | lr_scheduler_type: cosine 35 | max_steps: -1 36 | num_train_epochs: 1 37 | per_device_train_batch_size: 1 38 | optim: paged_adamw_8bit 39 | remove_unused_columns: true 40 | save_steps: 20000 41 | save_total_limit: 1 42 | warmup_ratio: 0.03 43 | weight_decay: 0.001 44 | report_to: 45 | 46 | instruction: ragfit/processing/prompts/prompt_instructions/qa.txt 47 | template: 48 | data_file: 49 | input_key: prompt 50 | output_key: 51 | resume_checkpoint: 52 | limit: 53 | shuffle: 54 | hfhub_tag: 55 | use_wandb: 56 | experiment: 57 | wandb_entity: 58 | 59 | -------------------------------------------------------------------------------- /configs/paper/training-pubmed.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _target_: ragfit.models.hf.HFTrain 3 | model_name_or_path: microsoft/Phi-3-mini-128k-instruct 4 | load_in_4bit: false 5 | load_in_8bit: true 6 | torch_dtype: 7 | device_map: 8 | trust_remote_code: true 9 | lora: 10 | bias: none 11 | fan_in_fan_out: false 12 | layers_pattern: 13 | layers_to_transform: 14 | lora_alpha: 16 15 | lora_dropout: 0.1 16 | peft_type: LORA 17 | r: 16 18 | target_modules: 19 | - qkv_proj 20 | task_type: CAUSAL_LM 21 | use_rslora: true 22 | completion_start: <|assistant|> 23 | instruction_in_prompt: 24 | max_sequence_len: 2000 25 | 26 | train: 27 | output_dir: ./trained_model/ 28 | bf16: false 29 | fp16: false 30 | gradient_accumulation_steps: 2 31 | group_by_length: 32 | 
learning_rate: 1e-4 33 | logging_steps: 10 34 | lr_scheduler_type: cosine 35 | max_steps: -1 36 | num_train_epochs: 1 37 | per_device_train_batch_size: 1 38 | optim: paged_adamw_8bit 39 | remove_unused_columns: true 40 | save_steps: 20000 41 | save_total_limit: 1 42 | warmup_ratio: 0.03 43 | weight_decay: 0.001 44 | report_to: 45 | 46 | instruction: ragfit/processing/prompts/prompt_instructions/qa-yes-no.txt 47 | template: 48 | data_file: pubmed-rag-train.jsonl 49 | input_key: prompt 50 | output_key: answers 51 | resume_checkpoint: 52 | limit: 53 | shuffle: 54 | hfhub_tag: 55 | use_wandb: 56 | experiment: 57 | wandb_entity: 58 | 59 | -------------------------------------------------------------------------------- /docs/scripts/generate_docstrings.py: -------------------------------------------------------------------------------- 1 | # Ported from https://mkdocstrings.github.io/recipes/#automatic-code-reference-pages 2 | 3 | from pathlib import Path 4 | 5 | import mkdocs_gen_files 6 | 7 | nav = mkdocs_gen_files.Nav() 8 | 9 | src = Path(__file__).parent.parent.parent / "ragfit" 10 | excluded = ["__init__.py", "__pycache__"] 11 | 12 | for path in sorted(src.rglob("*.py")): 13 | if any(path.name.__contains__(exclude) for exclude in excluded): 14 | print(f"Skipping {path}") 15 | continue 16 | print(f"Processing file: {path}") 17 | module_path = path.relative_to(src).with_suffix("") 18 | doc_path = path.relative_to(src).with_suffix(".md") 19 | full_doc_path = Path("reference", doc_path) 20 | parts = tuple(module_path.parts) 21 | print(f"{module_path} -> {doc_path} -> {full_doc_path} ; {module_path.parts}") 22 | 23 | # if parts[-1] == "__init__": 24 | # parts = parts[:-1] 25 | # doc_path = doc_path.with_name("index.md") 26 | # full_doc_path = full_doc_path.with_name("index.md") 27 | # elif parts[-1] == "__main__": 28 | # continue 29 | 30 | nav[parts] = doc_path.as_posix() 31 | print(f"{doc_path.as_posix()}") 32 | 33 | print(f"Writing to {full_doc_path}") 34 | with mkdocs_gen_files.open(full_doc_path, "w") as fd: 35 | ident = ".".join(parts) 36 | fd.write(f"::: ragfit.{ident}") 37 | 38 | mkdocs_gen_files.set_edit_path(full_doc_path, path) 39 | 40 | print("---\n") 41 | # with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file: 42 | # nav_file.writelines(nav.build_literate_nav()) 43 | -------------------------------------------------------------------------------- /configs/training.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _target_: ragfit.models.hf.HFTrain 3 | model_name_or_path: microsoft/Phi-3-mini-128k-instruct 4 | load_in_4bit: false 5 | load_in_8bit: true 6 | torch_dtype: 7 | device_map: 8 | trust_remote_code: true 9 | lora: 10 | bias: none 11 | fan_in_fan_out: false 12 | layers_pattern: 13 | layers_to_transform: 14 | lora_alpha: 16 15 | lora_dropout: 0.1 16 | peft_type: LORA 17 | r: 16 18 | target_modules: 19 | - qkv_proj 20 | task_type: CAUSAL_LM 21 | use_rslora: true 22 | completion_start: <|assistant|> 23 | instruction_in_prompt: 24 | max_sequence_len: 4000 25 | 26 | train: 27 | output_dir: ./trained_models 28 | bf16: false 29 | fp16: false 30 | gradient_accumulation_steps: 4 31 | group_by_length: 32 | learning_rate: 2e-5 33 | logging_steps: 10 34 | lr_scheduler_type: cosine 35 | max_steps: -1 36 | num_train_epochs: 1 37 | per_device_train_batch_size: 1 38 | optim: adamw_torch_fused 39 | remove_unused_columns: true 40 | save_steps: 20000 41 | save_total_limit: 2 42 | warmup_ratio: 0.03 43 | weight_decay: 0.001 44 | 
report_to: wandb 45 | 46 | instruction: ragfit/processing/prompts/prompt_instructions/qa.txt 47 | template: # specify a template file or use chatML format with tokenizer's chat template 48 | data_file: my-processed-data.jsonl 49 | input_key: my_prompt 50 | output_key: answer 51 | resume_checkpoint: 52 | limit: 53 | shuffle: 54 | use_wandb: 55 | hfhub_tag: 56 | experiment: phi-training 57 | wandb_entity: 58 | -------------------------------------------------------------------------------- /configs/paper/training-pubmed-alternative.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _target_: ragfit.models.hf.HFTrain 3 | model_name_or_path: microsoft/Phi-3-mini-128k-instruct 4 | # load_in_4bit: false 5 | # load_in_8bit: true 6 | quantization_config: 7 | _target_: transformers.BitsAndBytesConfig 8 | load_in_8bit: true 9 | 10 | torch_dtype: 11 | device_map: auto 12 | trust_remote_code: true 13 | lora: 14 | bias: none 15 | fan_in_fan_out: false 16 | layers_pattern: 17 | layers_to_transform: 18 | lora_alpha: 16 19 | lora_dropout: 0.1 20 | peft_type: LORA 21 | r: 16 22 | target_modules: 23 | - qkv_proj 24 | task_type: CAUSAL_LM 25 | use_rslora: true 26 | completion_start: <|assistant|> 27 | instruction_in_prompt: 28 | max_sequence_len: 2000 29 | 30 | train: 31 | output_dir: ./trained_model/ 32 | bf16: false 33 | fp16: false 34 | gradient_accumulation_steps: 2 35 | group_by_length: 36 | learning_rate: 1e-4 37 | logging_steps: 10 38 | lr_scheduler_type: cosine 39 | max_steps: -1 40 | num_train_epochs: 1 41 | per_device_train_batch_size: 1 42 | optim: paged_adamw_8bit 43 | remove_unused_columns: true 44 | save_steps: 20000 45 | save_total_limit: 1 46 | warmup_ratio: 0.03 47 | weight_decay: 0.001 48 | report_to: 49 | 50 | instruction: ragfit/processing/prompts/prompt_instructions/qa-yes-no.txt 51 | template: 52 | data_file: pubmed-rag-train.jsonl 53 | input_key: prompt 54 | output_key: answers 55 | resume_checkpoint: 56 | limit: 57 | shuffle: 58 | hfhub_tag: 59 | use_wandb: 60 | experiment: 61 | wandb_entity: 62 | dev_split: 0.1 -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from functools import partial 4 | from pathlib import Path 5 | 6 | import hydra 7 | from datasets import load_dataset 8 | from hydra.utils import to_absolute_path 9 | from omegaconf import OmegaConf 10 | 11 | 12 | @hydra.main(version_base=None, config_path="./configs", config_name="inference") 13 | def main(args): 14 | logging.info(OmegaConf.to_yaml(args)) 15 | 16 | logging.info(f"Loading data file: {args.data_file}") 17 | data = load_dataset( 18 | "json", data_files=to_absolute_path(args.data_file), split="train" 19 | ) 20 | 21 | model = hydra.utils.instantiate(args.model, _convert_="object") 22 | logging.info(f"Loaded model: {model}") 23 | 24 | logging.info(f"Generated (opt. 
cache) file: {args.generated_file}") 25 | 26 | if Path(args.generated_file).exists(): 27 | saved_data = load_dataset("json", data_files=args.generated_file, split="train") 28 | else: 29 | saved_data = [] 30 | 31 | if args.limit: 32 | data = data.select(range(args.limit)) 33 | 34 | def map_generate(model, example, idx): 35 | if idx >= len(saved_data): 36 | out = model.generate(example[args.input_key]) 37 | example[args.generation_key] = out 38 | 39 | with open(args.generated_file, "a") as f: 40 | f.write( 41 | json.dumps({"text": out, "target": example[args.target_key]}) + "\n" 42 | ) 43 | 44 | else: 45 | example[args.generation_key] = saved_data[idx]["text"] 46 | 47 | return example 48 | 49 | data = data.map(partial(map_generate, model), with_indices=True) 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /configs/processing.yaml: -------------------------------------------------------------------------------- 1 | name: my_processing_pipeline 2 | cache: true 3 | output_path: . 4 | steps: 5 | - _target_: ragfit.processing.dataset_loaders.loaders.HFLoader 6 | inputs: main 7 | dataset_config: 8 | path: Tevatron/wikipedia-nq 9 | split: train 10 | 11 | - _target_: ragfit.processing.global_steps.sampling.ShuffleSelect 12 | inputs: main 13 | shuffle: 42 14 | limit: 10000 15 | 16 | - _target_: ragfit.processing.local_steps.context.ContextHandler 17 | inputs: main 18 | docs_key: positive_passages 19 | text_key: text 20 | 21 | - _target_: ragfit.processing.local_steps.context.DocumentsJoiner 22 | inputs: main 23 | docs_key: positive_passages 24 | 25 | - _target_: ragfit.processing.local_steps.formatting.FlattenList 26 | inputs: main 27 | input_key: answers 28 | output_key: answers 29 | 30 | - _target_: ragfit.processing.global_steps.sampling.FewShot 31 | inputs: main 32 | k: 3 33 | output_key: fewshot_examples 34 | 35 | - _target_: ragfit.processing.local_steps.prompter.FewshotPrompter 36 | inputs: main 37 | prompt_file: ragfit/processing/prompts/qa-with-answer.txt 38 | fewshot_key: fewshot_examples 39 | output_key: fewshot_examples 40 | mapping: 41 | question: query 42 | context: positive_passages 43 | answer: answers 44 | 45 | - _target_: ragfit.processing.local_steps.prompter.TextPrompter 46 | inputs: main 47 | prompt_file: ragfit/processing/prompts/qa-fewshot.txt 48 | output_key: prompt 49 | mapping: 50 | question: query 51 | context: positive_passages 52 | fewshot: fewshot_examples 53 | answer: answers 54 | 55 | - _target_: ragfit.processing.global_steps.output.OutputData 56 | inputs: main 57 | prefix: nq_processed 58 | 59 | - _target_: ragfit.processing.global_steps.output.HFHubOutput 60 | inputs: main 61 | hfhub_tag: username/dataset_tag_name -------------------------------------------------------------------------------- /ragfit/processing/local_steps/context.py: -------------------------------------------------------------------------------- 1 | from ..step import LocalStep 2 | 3 | 4 | class ContextHandler(LocalStep): 5 | """ 6 | Example class for processing retrieved documents. 7 | 8 | In this simple example, the text is combined with the title. 9 | """ 10 | 11 | def __init__(self, docs_key, title_key="title", text_key="content", **kwargs): 12 | """ 13 | Args: 14 | docs_key (str): Key to the documents in the item. 15 | title_key (str): Key to the title in the document. 16 | text_key (str): Key to the text in the document. 
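        Example (illustrative): with the default keys, a document {"title": "Paris", "content": "Capital of France."} is rendered as the string "Paris: Capital of France.".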
17 | """ 18 | super().__init__(**kwargs) 19 | self.docs_key = docs_key 20 | self.title_key = title_key 21 | self.text_key = text_key 22 | 23 | def process_item(self, item, index, datasets, **kwargs): 24 | docs = item[self.docs_key] 25 | docs = [f"{doc[self.title_key]}: {doc[self.text_key]}" for doc in docs] 26 | item[self.docs_key] = docs 27 | return item 28 | 29 | 30 | class DocumentsJoiner(LocalStep): 31 | """ 32 | Class to select top-K and join the documents into a string. 33 | """ 34 | 35 | def __init__(self, docs_key, k=None, join_string="\n", **kwargs): 36 | """ 37 | Join `k` documents in `docs_key` into a string using `join_string` and save back to `docs_key`. 38 | 39 | Args: 40 | docs_key (str): Key to the documents in the item. 41 | k (int, optional): Number of documents to select or take all. Defaults to None. 42 | join_string (str): String to join the documents. Defaults to "\\n". 43 | """ 44 | super().__init__(**kwargs) 45 | self.docs_key = docs_key 46 | self.k = k 47 | self.join_string = join_string 48 | 49 | def process_item(self, item, index, datasets, **kwargs): 50 | docs = item[self.docs_key] 51 | if self.k is not None: 52 | docs = docs[: self.k] 53 | docs = self.join_string.join(docs) 54 | item[self.docs_key] = docs 55 | return item 56 | -------------------------------------------------------------------------------- /ragfit/processing/local_steps/formatting.py: -------------------------------------------------------------------------------- 1 | from ..step import LocalStep 2 | 3 | 4 | class ColumnUpdater(LocalStep): 5 | """ 6 | Simple class to create new columns from existing columns in a dataset. 7 | Existing columns are not modified. 8 | 9 | Args: 10 | keys_mapping (dict): Dictionary with "from:to" mapping. 11 | """ 12 | 13 | def __init__(self, keys_mapping: dict, **kwargs): 14 | super().__init__(**kwargs) 15 | self.keys_mapping = keys_mapping 16 | 17 | def process_item(self, item, index, datasets, **kwargs): 18 | for from_key, to_key in self.keys_mapping.items(): 19 | item[to_key] = item[from_key] 20 | return item 21 | 22 | 23 | class FlattenList(LocalStep): 24 | """ 25 | Class to join a list of strings into a single string. 26 | """ 27 | 28 | def __init__(self, input_key, output_key, string_join=", ", **kwargs): 29 | """ 30 | Args: 31 | input_key (str): Key to the list of strings. 32 | output_key (str): Key to store the joined string. 33 | string_join (str): String to join the list of strings. Defaults to ", ". 34 | """ 35 | super().__init__(**kwargs) 36 | self.input_key = input_key 37 | self.output_key = output_key 38 | self.string_join = string_join 39 | 40 | def process_item(self, item, index, datasets, **kwargs): 41 | item[self.output_key] = self.string_join.join(item[self.input_key]) 42 | return item 43 | 44 | 45 | class UpdateField(LocalStep): 46 | """ 47 | Class to update a field in the dataset with a new value. 48 | """ 49 | 50 | def __init__(self, input_key: str, value, **kwargs): 51 | """ 52 | Args: 53 | input_key (str): example key to change. 54 | value: New value to set for the field. 
55 | """ 56 | super().__init__(**kwargs) 57 | self.input_key = input_key 58 | self.value = value 59 | 60 | def process_item(self, item, index, datasets, **kwargs): 61 | item[self.input_key] = self.value 62 | return item 63 | -------------------------------------------------------------------------------- /ragfit/processing/global_steps/output.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from ..step import GlobalStep 4 | 5 | 6 | class OutputData(GlobalStep): 7 | """ 8 | Simple class to output the dataset to a jsonl file. 9 | 10 | Caching is disabled as this step does not manipulate the dataset hence no need for caching. 11 | """ 12 | 13 | def __init__(self, prefix, filename=None, directory=None, **kwargs): 14 | """ 15 | Args: 16 | prefix (str): Prefix for the output. 17 | filename (str, optional): Name of the output file. If not provided, the output file name will be generated based on the prefix and dataset name. 18 | directory (str, optional): Directory to save the output file. If not provided, the output file will be saved in the current directory. 19 | 20 | The output name is `{prefix}-{dataset_keyname/filename}.jsonl` if `filename` is not provided. 21 | """ 22 | super().__init__(**kwargs) 23 | self.prefix = prefix 24 | self.filename = filename 25 | self.dir = directory 26 | self.cache_step = False 27 | 28 | def process(self, dataset_name, datasets, **kwargs): 29 | if self.filename: 30 | name = self.filename 31 | else: 32 | name = dataset_name 33 | fname = f"{self.prefix}-{name}.jsonl" 34 | if self.dir is not None: 35 | fname = os.path.join(self.dir, fname) if self.dir else fname 36 | datasets[dataset_name].to_json(fname, lines=True) 37 | 38 | 39 | class HFHubOutput(GlobalStep): 40 | """ 41 | Simple class to output the dataset to Hugging Face Hub. 42 | 43 | Caching is disabled as this step does not manipulate the dataset hence no need for caching. 44 | """ 45 | 46 | def __init__(self, hfhub_tag, private=True, **kwargs): 47 | """ 48 | Args: 49 | hfhub_tag (str): Tag for the Hugging Face Hub. 50 | private (bool): Whether the dataset should be private or not. Default is True. 51 | """ 52 | super().__init__(**kwargs) 53 | self.hfhub_tag = hfhub_tag 54 | self.private = private 55 | self.cache_step = False 56 | 57 | def process(self, dataset_name, datasets, **kwargs): 58 | datasets[dataset_name].push_to_hub(self.hfhub_tag, private=self.private) 59 | -------------------------------------------------------------------------------- /ragfit/processing/local_steps/raft.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from ..step import LocalStep 4 | 5 | 6 | class RAFTStep(LocalStep): 7 | """ 8 | Implementation of RAFT: Adapting Language Model to Domain Specific RAG. 9 | 10 | This class compiles a list of negative documents with probability `raft_p`, 11 | and a combination of positive and negative documents with probability 1 - `raft_p`. 12 | 13 | Zhang, Tianjun, Shishir G. Patil, Naman Jain, Sheng Shen, Matei Zaharia, 14 | Ion Stoica, and Joseph E. Gonzalez. 2024. “RAFT: Adapting Language Model 15 | to Domain Specific RAG.” arXiv. http://arxiv.org/abs/2403.10131. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | k: int = 5, 21 | raft_p=0.5, 22 | neg_docs_num=5, 23 | positive_key="positive_passages", 24 | negative_key="negative_passages", 25 | output_key="docs", 26 | **kwargs, 27 | ): 28 | """ 29 | Args: 30 | k (int): The number of positive passages to consider. 
31 | raft_p (float, optional): The probability of using only negative (distractor) passages; with probability 1 - `raft_p`, the top positive passages are included as well. Defaults to 0.5. 32 | neg_docs_num (int, optional): The number of negative passages to consider. Defaults to 5. 33 | positive_key (str, optional): The key containing the positive passages. Defaults to "positive_passages". 34 | negative_key (str, optional): The key containing the negative passages. Defaults to "negative_passages". 35 | output_key (str, optional): The key to store the output. Defaults to "docs". 36 | """ 37 | super().__init__(**kwargs) 38 | self.k = k 39 | self.raft_p = raft_p 40 | self.neg_docs_num = neg_docs_num 41 | self.positive_key = positive_key 42 | self.negative_key = negative_key 43 | self.output_key = output_key 44 | 45 | def process_item(self, item: dict, index, datasets, **kwargs): 46 | docs_pos = item[self.positive_key] 47 | docs_neg = item.get(self.negative_key, []) 48 | 49 | p = random.random() # nosec 50 | 51 | 52 | if p > self.raft_p: 53 | docs = docs_pos[: self.k] + docs_neg[: self.neg_docs_num] 54 | else: 55 | docs = docs_neg[: self.neg_docs_num] 56 | 57 | item[self.output_key] = docs 58 | 59 | return item 60 | -------------------------------------------------------------------------------- /ragfit/processing/local_steps/common_datasets.py: -------------------------------------------------------------------------------- 1 | from ragfit.evaluation.metrics import normalize_text 2 | 3 | from ..step import LocalStep 4 | 5 | 6 | class ASQA(LocalStep): 7 | """ 8 | Normalizes the ASQA dataset. 9 | 10 | It has a long answer, to be measured with ROUGE-L, and multiple short answers, to be 11 | measured with string-EM. The long answer is saved in the `answers` field, while the 12 | short answers (list of lists) are saved in the `answer-short` field. 13 | """ 14 | 15 | def process_item(self, item, index, datasets, **kwargs): 16 | item["answer-long"] = [ann["long_answer"] for ann in item["annotations"]] 17 | short = [] 18 | for qa_pair in item["qa_pairs"]: 19 | normalize = [normalize_text(ans) for ans in qa_pair["short_answers"]] 20 | short.append(normalize) 21 | item["answer-short"] = short 22 | 23 | item["answers"] = item["answer-long"] 24 | item["query"] = item["ambiguous_question"] 25 | 26 | return item 27 | 28 | 29 | class HotPot(LocalStep): 30 | """ 31 | Normalizes the HotPotQA dataset to look like NQ, TQA. 32 | """ 33 | 34 | def process_item(self, item, index, datasets, **kwargs): 35 | item["answers"] = [item["answer"]] 36 | item["query"] = item["question"] 37 | 38 | # Contexts are converted into a list of "title: text" strings 39 | titles = item["context"]["title"] 40 | 41 | sentences = item["context"]["sentences"] 42 | sentences = ["".join(lines) for lines in sentences] 43 | 44 | docs = [f"{title}: {text}" for title, text in zip(titles, sentences)] 45 | item["positive_passages"] = docs 46 | 47 | return item 48 | 49 | 50 | class ARCC(LocalStep): 51 | """ 52 | Prepare dataset for RAG augmentation. 53 | """ 54 | 55 | def process_item(self, item, index, datasets, **kwargs): 56 | item["query"] = item["question"] 57 | item["options"] = item["choices"]["text"] 58 | item["answers"] = item["answerKey"] 59 | 60 | return item 61 | 62 | 63 | class PubMed(LocalStep): 64 | """ 65 | Prepare dataset for RAG augmentation. 
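    Maps the PubMedQA fields: `QUESTION` becomes `query`, `final_decision` becomes a
    single-element `answers` list, and `CONTEXTS` becomes `positive_passages`.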
66 | """ 67 | 68 | def process_item(self, item, index, datasets, **kwargs): 69 | item["query"] = item["QUESTION"] 70 | item["answers"] = [item["final_decision"]] 71 | docs = item["CONTEXTS"] 72 | item["positive_passages"] = docs 73 | 74 | return item 75 | -------------------------------------------------------------------------------- /ragfit/processing/local_steps/prompter.py: -------------------------------------------------------------------------------- 1 | from ..step import LocalStep 2 | 3 | 4 | class TextPrompter(LocalStep): 5 | """ 6 | Class for creating prompts. The input is a prompt file with placeholders, a mapping of the placeholders to the item keys, and the key to store the result. 7 | """ 8 | 9 | def __init__(self, prompt_file: str, mapping: dict, output_key, **kwargs): 10 | """ 11 | Args: 12 | prompt_file (str): Path to the prompt file. 13 | mapping (dict): Mapping of the placeholders in the prompt to the item keys. 14 | output_key (str): Key to store the formatted prompt. 15 | """ 16 | super().__init__(**kwargs) 17 | self.prompt = open(prompt_file).read() 18 | self.mapping = mapping 19 | self.output_key = output_key 20 | 21 | def process_item(self, item, index, datasets, **kwargs): 22 | prompt = self.prompt.format( 23 | **{k: item.get(v, "") for k, v in self.mapping.items()} 24 | ) 25 | item[self.output_key] = prompt 26 | return item 27 | 28 | 29 | class FewshotPrompter(LocalStep): 30 | """ 31 | Class for formatting fewshot examples into a string, to be used in a prompt. 32 | 33 | The prompt template contains a placeholder for the fewshot examples; this 34 | class is used to format the examples into a string. 35 | """ 36 | 37 | def __init__( 38 | self, prompt_file: str, fewshot_key: str, mapping: dict, output_key: str, **kwargs 39 | ): 40 | """ 41 | Args: 42 | prompt_file (str): Path to the prompt file for the individual fewshot examples. 43 | fewshot_key (str): Key to the fewshot examples in the item. 44 | mapping (dict): Mapping of the placeholders in the prompt to the item keys. 45 | output_key (str): Key to store the formatted fewshot examples. 46 | """ 47 | super().__init__(**kwargs) 48 | self.prompt = open(prompt_file).read() 49 | self.fewshot_key = fewshot_key 50 | self.mapping = mapping 51 | self.output_key = output_key 52 | 53 | def process_item(self, item, index, datasets, **kwargs): 54 | texts = [] 55 | for ex in item[self.fewshot_key]: 56 | text = self.prompt.format( 57 | **{k: ex.get(v, "") for k, v in self.mapping.items()} 58 | ) 59 | texts.append(text) 60 | 61 | item[self.output_key] = "\n\n".join(texts) 62 | return item 63 | -------------------------------------------------------------------------------- /ragfit/processing/local_steps/retrievers/haystack.py: -------------------------------------------------------------------------------- 1 | from ...step import LocalStep 2 | 3 | 4 | class HaystackRetriever(LocalStep): 5 | """ 6 | Class for document retrieval using Haystack v2 pipelines. 7 | """ 8 | 9 | def __init__(self, pipeline_or_yaml_path, docs_key, query_key, **kwargs): 10 | super().__init__(**kwargs) 11 | from haystack import Pipeline 12 | 13 | if isinstance(pipeline_or_yaml_path, str): 14 | self.pipe = Pipeline.load(open(pipeline_or_yaml_path)) 15 | else: 16 | self.pipe = pipeline_or_yaml_path 17 | 18 | self.docs_key = docs_key 19 | self.query_key = query_key 20 | 21 | def default_query_function(self, query): 22 | """ 23 | Create the default querying of the pipeline, by inserting the input query into all mandatory fields. 
24 | """ 25 | pipe_inputs = self.pipe.inputs() 26 | query_dict = {} 27 | for inp_node_name, inp_node_params in pipe_inputs.items(): 28 | for param_name, param_values in inp_node_params.items(): 29 | if param_values["is_mandatory"]: 30 | if inp_node_name not in query_dict: 31 | query_dict[inp_node_name] = {} 32 | 33 | query_dict[inp_node_name][param_name] = query 34 | 35 | return query_dict 36 | 37 | def query(self, query, structure=None): 38 | """ 39 | Haystack v2 pipelines can have multiple inputs; structure specify how to call `pipe.run`. 40 | 41 | For example, structure could look like this: 42 | { 43 | "Retriever": {"query": "query",}, 44 | "Reranker": {"query": "query"}, 45 | } 46 | and we replace the **value** of each key with the query. 47 | """ 48 | 49 | if structure is None: 50 | structure = self.default_query_function(query) 51 | else: 52 | for key, value in structure.items(): 53 | structure[key] = {k: query for k in value.keys()} 54 | 55 | response = self.pipe.run(structure) 56 | all_documents = [] 57 | for v in response.values(): 58 | if "documents" in v: 59 | # has documents, add to list 60 | all_documents += v["documents"] 61 | 62 | all_documents = [ 63 | {"content": d.content, "title": d.meta.get("title")} for d in all_documents 64 | ] 65 | 66 | return all_documents 67 | 68 | def process_item(self, item, index, datasets, **kwargs): 69 | """ 70 | Query the `query_key` in the item and store the results in the `docs_key`. 71 | Retrieved documents are stored as a list of dictionaries with keys `content` and `title`. 72 | """ 73 | item[self.docs_key] = self.query(item[self.query_key]) 74 | return item 75 | -------------------------------------------------------------------------------- /ragfit/processing/global_steps/aggregation.py: -------------------------------------------------------------------------------- 1 | from datasets import concatenate_datasets 2 | 3 | from ..step import GlobalStep 4 | from .filters import filters 5 | 6 | 7 | class FilterDataset(GlobalStep): 8 | """ 9 | Step for filtering a dataset. 10 | """ 11 | 12 | def __init__(self, filter_fn, **kwargs): 13 | """ 14 | Args: 15 | filter_fn (function): Function to filter the dataset. 16 | """ 17 | super().__init__(**kwargs) 18 | self.filter_fn = filters[filter_fn] 19 | 20 | def process(self, dataset_name, datasets, **kwargs): 21 | datasets[dataset_name] = datasets[dataset_name].filter(self.filter_fn) 22 | 23 | 24 | class SelectColumns(GlobalStep): 25 | """ 26 | Step for selecting specified columns in a dataset. 27 | """ 28 | 29 | def __init__(self, columns: list[str], **kwargs): 30 | """ 31 | Args: 32 | columns (list): List of keys to keep in the dataset. 33 | """ 34 | super().__init__(**kwargs) 35 | assert isinstance(columns, list), "columns should be a list of strings." 36 | self.columns = columns 37 | 38 | def process(self, dataset_name, datasets, **kwargs): 39 | datasets[dataset_name] = datasets[dataset_name].select_columns(self.columns) 40 | 41 | 42 | class MergeDatasets(GlobalStep): 43 | """ 44 | Step for merging datasets. 45 | 46 | Merge is done using concatenation. Optional shuffling by providing a seed. 47 | """ 48 | 49 | def __init__(self, output, shuffle=None, **kwargs): 50 | """ 51 | Args: 52 | output (str): Name of the output dataset. Should be unique. 53 | shuffle (int, optional): seed for shuffling. Default is None. 
54 | """ 55 | super().__init__(**kwargs) 56 | self.output = output 57 | self.shuffle = shuffle 58 | self.completed = False 59 | self.cache_step = False 60 | 61 | def process(self, dataset_name, datasets, **kwargs): 62 | if not self.completed: 63 | data = concatenate_datasets([datasets[name] for name in self.inputs]) 64 | if self.shuffle: 65 | data = data.shuffle(self.shuffle) 66 | datasets[self.output] = data 67 | self.completed = True 68 | 69 | 70 | class DatasetTagger(GlobalStep): 71 | """ 72 | Class to tag each example with the dataset name. Useful when running aggregations. 73 | """ 74 | 75 | def __init__(self, keyword="source", **kwargs): 76 | """ 77 | Args: 78 | keyword (str): The key to use for tagging. Default is "source". 79 | """ 80 | super().__init__(**kwargs) 81 | self.keyword = keyword 82 | 83 | def tag(self, item, dataset_name): 84 | item[self.keyword] = dataset_name 85 | return item 86 | 87 | def process(self, dataset_name, datasets, **kwargs): 88 | datasets[dataset_name] = datasets[dataset_name].map( 89 | lambda item: self.tag(item, dataset_name), 90 | load_from_cache_file=False, 91 | ) 92 | -------------------------------------------------------------------------------- /ragfit/models/openai_executor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from typing import List, Union 5 | 6 | from openai import AzureOpenAI 7 | 8 | 9 | class OpenAIExecutor: 10 | """ 11 | Class representing an interface to the Azure OpenAI API. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | azure_endpoint: str, 17 | api_key: str = None, 18 | api_version: str = "2024-02-15-preview", 19 | model: str = "GPT-4-32k-Bot", 20 | chat_parameters: dict = None, 21 | delay: int = 1, 22 | ): 23 | """ 24 | Initialize the OpenAIExecutor. 25 | 26 | Args: 27 | azure_endpoint (str): The Azure endpoint. 28 | api_key (str): The API key, can also read of ENV variable. 29 | api_version (str): The API version. 30 | model (str): The model to use, sometimes called deployment or engine. 31 | chat_parameters (dict): The chat parameters. 32 | delay (int): delay between calls. 33 | """ 34 | self.delay = delay 35 | self.model = model 36 | self.chat_parameters = dict( 37 | temperature=0.7, 38 | max_tokens=200, 39 | top_p=0.95, 40 | frequency_penalty=0, 41 | presence_penalty=0, 42 | stop=None, 43 | ) 44 | if chat_parameters: 45 | self.chat_parameters.update(chat_parameters) 46 | 47 | self.client = AzureOpenAI( 48 | azure_endpoint=azure_endpoint, 49 | api_key=api_key or os.getenv("AZURE_OPENAI_API_KEY"), 50 | api_version=api_version, 51 | ) 52 | 53 | def chat(self, prompt: Union[List, str], instruction: str = None) -> str: 54 | """ 55 | Chat with the OpenAI API. 56 | 57 | Args: 58 | prompt (Union[List, str]): The prompt to chat. 59 | instruction (str): The instruction to use. 60 | 61 | Returns: 62 | str: The response. Empty string if error. 63 | """ 64 | if isinstance(prompt, str): 65 | prompt = [ 66 | { 67 | "role": "system", 68 | "content": ( 69 | instruction 70 | or "You are an AI assistant that helps people find information." 
71 | ), 72 | }, 73 | {"role": "user", "content": prompt}, 74 | ] 75 | 76 | if self.delay: 77 | time.sleep(self.delay) 78 | 79 | try: 80 | completion = self.client.chat.completions.create( 81 | model=self.model, 82 | messages=prompt, 83 | **self.chat_parameters, 84 | ) 85 | message_obj = completion.choices[0].message 86 | 87 | if hasattr(message_obj, "content"): 88 | answer = message_obj.content 89 | return answer or "" 90 | else: 91 | return "" 92 | 93 | except Exception as e: 94 | logging.info(f"OPENAI error:\n{e}") 95 | return "" 96 | -------------------------------------------------------------------------------- /ragfit/models/vllm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from typing import Dict 4 | 5 | from transformers import AutoConfig, AutoTokenizer 6 | 7 | from ragfit.utils import check_package_installed 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class VLLMInference: 13 | """ 14 | Initializes a vLLM-based inference engine. 15 | 16 | Args: 17 | model_name_or_path (str): The name or path of the model. 18 | instruction (Path): path to the instruction file. 19 | instruct_in_prompt (bool): whether to include the instruction in the prompt for models without system role. 20 | template (Path): path to a prompt template file if tokenizer does not include chat template. Optional. 21 | num_gpus (int, optional): The number of GPUs to use. Defaults to 1. 22 | llm_params (Dict, optional): Additional parameters for the LLM model. Supports all parameters define by vLLM LLM engine. Defaults to an empty dictionary. 23 | generation (Dict, optional): Additional parameters for text generation. Supports all the keywords of `SamplingParams` of vLLM. Defaults to an empty dictionary. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | model_name_or_path: str, 29 | instruction: Path, 30 | instruct_in_prompt: bool = False, 31 | template: Path = None, 32 | num_gpus: int = 1, 33 | llm_params: Dict = {}, 34 | generation: Dict = {}, 35 | ): 36 | check_package_installed( 37 | "vllm", 38 | "please refer to vLLM website for installation instructions, or run: pip install vllm", 39 | ) 40 | from vllm import LLM, SamplingParams 41 | 42 | self.model_name = model_name_or_path 43 | self.instruct_in_prompt = instruct_in_prompt 44 | self.template = open(template).read() if template else None 45 | self.instruction = open(instruction).read() 46 | logger.info(f"Using the following instruction: {self.instruction}") 47 | 48 | self.sampling_params = SamplingParams(**generation) 49 | self.llm = LLM(model=self.model_name, tensor_parallel_size=num_gpus, **llm_params) 50 | 51 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) 52 | self.config = AutoConfig.from_pretrained(self.model_name) 53 | 54 | def generate(self, prompt: str) -> str: 55 | """ 56 | Generates text based on the given prompt. 
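        Args:
            prompt (str): The user prompt; the instruction is applied via the
                tokenizer's chat template, or via the custom template if one was provided.

        Returns:
            str: The generated completion text.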
57 | """ 58 | if self.template: 59 | prompt = self.template.format(instruction=self.instruction, query=prompt) 60 | else: 61 | if self.instruct_in_prompt: 62 | prompt = self.instruction + "\n" + prompt 63 | messages = [ 64 | {"role": "system", "content": self.instruction}, 65 | {"role": "user", "content": prompt}, 66 | ] 67 | 68 | prompt = self.tokenizer.apply_chat_template( 69 | messages, 70 | tokenize=False, 71 | add_generation_prompt=True, 72 | truncation=True, 73 | max_length=( 74 | self.config.max_position_embeddings - self.sampling_params.max_tokens 75 | ), 76 | ) 77 | 78 | output = self.llm.generate(prompt, self.sampling_params) 79 | return output[0].outputs[0].text 80 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from collections import defaultdict 4 | from pathlib import Path 5 | 6 | import hydra 7 | import torch 8 | import yaml 9 | from datasets import load_dataset 10 | from hydra.utils import to_absolute_path 11 | from omegaconf import OmegaConf 12 | from tqdm import tqdm 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def setup_wandb(args: dict): 18 | """ 19 | WANDB integration for tracking evaluations. 20 | """ 21 | import wandb 22 | from wandb.wandb_run import Run 23 | 24 | env = {key: os.getenv(key) for key in os.environ} 25 | run: Run = wandb.init( 26 | job_type="eval", 27 | project=args["experiment"], 28 | entity=args["wandb_entity"], 29 | config={**args, **env}, 30 | tags=["eval"], 31 | ) 32 | return run 33 | 34 | 35 | @hydra.main(version_base=None, config_path="./configs", config_name="evaluation") 36 | def main(args): 37 | logger.info(OmegaConf.to_yaml(args)) 38 | 39 | if args.use_wandb: 40 | run = setup_wandb(OmegaConf.to_container(args)) 41 | 42 | logger.info(f"Loading dataset: {args.data_file}") 43 | data = load_dataset( 44 | "json", data_files=to_absolute_path(args.data_file), split="train" 45 | ) 46 | 47 | generated_data = load_dataset("json", data_files=args.generated_file, split="train") 48 | logging.info(f"Loaded {len(generated_data)} examples from {args.generated_file}") 49 | 50 | if args.limit: 51 | data = data.select(range(args.limit)) 52 | 53 | if args.answer_processor: 54 | answer_processor = hydra.utils.instantiate( 55 | args.answer_processor, _convert_="object" 56 | ) 57 | else: 58 | 59 | def answer_processor(x): 60 | return x 61 | 62 | def map_load(example, idx): 63 | example[args.key_names["generated"]] = answer_processor( 64 | generated_data[idx][args.key_names["generated"]] 65 | ) 66 | return example 67 | 68 | data = data.map(map_load, with_indices=True) 69 | size = len(data) 70 | 71 | results = {"local": defaultdict(list), "global": {}} 72 | for metric in args.metrics: 73 | obj = hydra.utils.instantiate( 74 | metric, key_names=args.key_names, _convert_="object" 75 | ) 76 | if obj.local: 77 | for example in tqdm(data): 78 | calculation = obj.measure(example) 79 | for key, val in calculation.items(): 80 | results["local"][key].append(val) 81 | else: 82 | calculation = obj.measure(data) 83 | for key, val in calculation.items(): 84 | results["global"][key] = val 85 | del obj 86 | torch.cuda.empty_cache() 87 | 88 | logging.info(f"Normalizing by size {size}") 89 | for key in results["local"].keys(): 90 | results["local"][key] = float(sum(results["local"][key]) / size) 91 | results["local"] = dict(results["local"]) 92 | 93 | logging.info(f"Results: {results}") 94 | if 
args.use_wandb: 95 | run.log(results, step=0) 96 | 97 | if args.results_file is None: 98 | args.results_file = Path(args.generated_file).stem + "-results.yaml" 99 | 100 | with open(args.results_file, "w") as f: 101 | yaml.dump(results, f, sort_keys=True) 102 | logging.info(f"Results saved to {args.results_file}") 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /ragfit/processing/step.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .utils import dict_hash, is_jsonable 4 | 5 | 6 | class BaseStep: 7 | """ 8 | Class representing a step in a processing pipeline. 9 | Entry point is `__call__`. 10 | Users would inherit either LocalStep or GlobalStep. 11 | 12 | Step can be cached (on by default: `cache_step=True`) to prevent re-computation. 13 | 14 | Individual steps can disable caching if and only if they do not manipulate the dataset, as 15 | re-computation of later steps is conditioned on the necessity of caching. 16 | """ 17 | 18 | def __init__(self, **kwargs): 19 | self.kwargs = kwargs 20 | self.inputs: list[str] = kwargs.get("inputs", ["main_dataset"]) 21 | self.step_hash = None 22 | self.cache_step = True 23 | 24 | if isinstance(self.inputs, str): 25 | self.inputs = [self.inputs] 26 | 27 | assert ( 28 | not isinstance(self.inputs, str) and len(self.inputs) > 0 29 | ), f"`inputs` should be a list, got {type(self.inputs)}" 30 | 31 | def calc_hash(self): 32 | """ 33 | Calculate hash for a step based on its properties. 34 | Updates the `step_hash` property. 35 | """ 36 | args_to_hash = {} 37 | for property, value in vars(self).items(): 38 | if is_jsonable(value): 39 | args_to_hash[property] = value 40 | self.step_hash = dict_hash(args_to_hash) 41 | 42 | def get_hash(self): 43 | """ 44 | Step hash getter. If hash is not calculated, it calculates it first. 45 | """ 46 | if self.step_hash is None: 47 | self.calc_hash() 48 | return self.step_hash 49 | 50 | def __call__(self, datasets, **kwargs): 51 | """ 52 | Pipeline is running these steps using `__call__`. 53 | """ 54 | logging.info(f"Running processing step: {type(self).__name__}") 55 | self.process_inputs(datasets, **kwargs) 56 | 57 | def process_inputs(self, datasets, **kwargs): 58 | """ 59 | Run the step `process` function for each dataset in `inputs`. 60 | """ 61 | for dataset_name in self.inputs: 62 | self.process(dataset_name, datasets, **kwargs) 63 | 64 | def process(self, dataset_name, datasets, **kwargs): 65 | """ 66 | General processing of `dataset_name` in `datasets`, in place. 67 | """ 68 | pass 69 | 70 | 71 | class LocalStep(BaseStep): 72 | """ 73 | Class representing a step in a processing pipeline, processing individual examples. 74 | 75 | The function to overwrite is `process_item`; the function accepts an item, index, and all the other datasets, if needed. 76 | """ 77 | 78 | def __init__(self, **kwargs): 79 | super().__init__(**kwargs) 80 | 81 | def process(self, dataset_name, datasets, **kwargs): 82 | datasets[dataset_name] = datasets[dataset_name].map( 83 | lambda item, index: self.process_item(item, index, datasets, **kwargs), 84 | with_indices=True, 85 | load_from_cache_file=False, 86 | ) 87 | 88 | def process_item(self, item, index, datasets, **kwargs): 89 | return item 90 | 91 | 92 | class GlobalStep(BaseStep): 93 | """ 94 | Class representing a step in a processing pipeline, processing the entire dataset. 
95 | 96 | The function to overwrite is `process_all`; the function accepts the dataset and all the other datasets, if needed. 97 | """ 98 | 99 | def __init__(self, **kwargs): 100 | super().__init__(**kwargs) 101 | 102 | def process(self, dataset_name, datasets, **kwargs): 103 | datasets[dataset_name] = self.process_all( 104 | datasets[dataset_name], datasets, **kwargs 105 | ) 106 | 107 | def process_all(self, dataset, datasets, **kwargs): 108 | return dataset 109 | -------------------------------------------------------------------------------- /ragfit/processing/global_steps/sampling.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from ..step import GlobalStep 4 | 5 | 6 | class ShuffleSelect(GlobalStep): 7 | """ 8 | Class to optionally shuffle and select a subset of the dataset. 9 | 10 | Based on the `shuffle` and `select` methods of HF Dataset. 11 | """ 12 | 13 | def __init__(self, shuffle=None, limit=None, **kwargs): 14 | """ 15 | Args: 16 | shuffle (int): Seed for shuffling the dataset. 17 | limit (int): Number of items to select from the dataset. 18 | """ 19 | super().__init__(**kwargs) 20 | self.shuffle = shuffle 21 | self.limit = limit 22 | 23 | def process_all(self, dataset, datasets, **kwargs): 24 | if self.shuffle: 25 | dataset = dataset.shuffle(seed=self.shuffle) 26 | if self.limit: 27 | dataset = dataset.select(range(min(len(dataset), self.limit))) 28 | return dataset 29 | 30 | 31 | class Sampler(GlobalStep): 32 | """ 33 | Class to augment a dataset with sampled examples from the same or another dataset. 34 | 35 | Full examples can be collected, as well as an individual example keys like `query`, `documents`, etc. 36 | 37 | The step can be used to collect negative documents, negative queries and collect fewshot examples. 38 | For fewshot examples, use the dedicated `FewShot` class. 39 | """ 40 | 41 | def __init__( 42 | self, k, input_key=None, output_key="fewshot", input_dataset=None, **kwargs 43 | ): 44 | """ 45 | Args: 46 | k (int): Number of examples to collect. 47 | input_key (str): a key to collect from the collected examples, or None to take entire example. 48 | output_key (str): output key to use for the examples. 49 | input_dataset (str): Name of the dataset to take the examples from. To use the same dataset, use None. 50 | """ 51 | super().__init__(**kwargs) 52 | self.k = k 53 | self.input_key = input_key 54 | self.input_dataset = input_dataset 55 | self.output_key = output_key 56 | 57 | def process(self, dataset_name, datasets, **kwargs): 58 | input_dataset = datasets[self.input_dataset or dataset_name] 59 | 60 | def find_examples(item, idx): 61 | ids = [] 62 | while len(ids) < self.k: 63 | rand_idx = random.randint(0, len(input_dataset) - 1) 64 | if self.input_dataset is None and rand_idx == idx: 65 | continue 66 | if rand_idx in ids: 67 | continue 68 | ids.append(rand_idx) 69 | examples = [ 70 | ( 71 | input_dataset[id_] 72 | if self.input_key is None 73 | else input_dataset[id_][self.input_key] 74 | ) 75 | for id_ in ids 76 | ] 77 | item[self.output_key] = examples if self.k > 1 else examples[0] 78 | return item 79 | 80 | datasets[dataset_name] = datasets[dataset_name].map( 81 | lambda item, index: find_examples(item, index), 82 | with_indices=True, 83 | load_from_cache_file=False, 84 | ) 85 | 86 | 87 | class FewShot(Sampler): 88 | """ 89 | Class to collect fewshot examples from the same or another dataset. 
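    Collected examples are stored as complete items (the `input_key` filtering of `Sampler`
    is disabled); they are typically formatted into a prompt string afterwards with
    `FewshotPrompter`, as in `configs/processing.yaml`.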
90 | """ 91 | 92 | def __init__(self, k, output_key="fewshot", input_dataset=None, **kwargs): 93 | """ 94 | Args: 95 | k (int): Number of examples to collect. 96 | output_key (str): output key to use for the collected examples. 97 | input_dataset (str): Name of the dataset to take the examples from. To use the same dataset, use None. 98 | """ 99 | super().__init__( 100 | k=k, 101 | output_key=output_key, 102 | input_key=None, 103 | input_dataset=input_dataset, 104 | **kwargs, 105 | ) 106 | -------------------------------------------------------------------------------- /docs/training.md: -------------------------------------------------------------------------------- 1 | # Training 2 | 3 | Training is done on the processed files. The training configuration has 3 parts: model, training arguments and data. 4 | 5 | ```yaml 6 | model: 7 | _target_: ragfit.models.hf.HFTrain 8 | model_name_or_path: microsoft/Phi-3-mini-128k-instruct 9 | load_in_4bit: false 10 | load_in_8bit: true 11 | lora: 12 | bias: none 13 | fan_in_fan_out: false 14 | lora_alpha: 16 15 | lora_dropout: 0.1 16 | peft_type: LORA 17 | r: 16 18 | target_modules: 19 | - qkv_proj 20 | task_type: CAUSAL_LM 21 | use_rslora: true 22 | completion_start: <|assistant|> 23 | instruction_in_prompt: 24 | max_sequence_len: 4000 25 | ``` 26 | Model loading is done in the `HFTrain` class, which loads models from HuggingFace hub and uses PEFT adapters. Other 27 | classes can be implemented. The important keys here are: `completion_start` which indicates the beginning of the text 28 | where loss is to be calculated. This is model/tokenizer specific. Additionally, there is the `instruction_in_prompt` 29 | key, which if set to *True*, inserts the system instruction in the prompt, for models which do not support a dedicated 30 | system role. 31 | 32 | Next is the training arguments: 33 | ```yaml 34 | train: 35 | output_dir: ./trained_models/ 36 | bf16: false 37 | fp16: false 38 | gradient_accumulation_steps: 2 39 | group_by_length: 40 | learning_rate: 1e-4 41 | logging_steps: 10 42 | lr_scheduler_type: cosine 43 | max_steps: -1 44 | num_train_epochs: 1 45 | per_device_train_batch_size: 1 46 | optim: paged_adamw_8bit 47 | remove_unused_columns: true 48 | save_steps: 20000 49 | save_total_limit: 1 50 | warmup_ratio: 0.03 51 | weight_decay: 0.001 52 | report_to: 53 | ``` 54 | 55 | Training is done using the `SFTTrainer` in `TRL`. Training arguments are based on HuggingFace `Trainer`. 56 | 57 | Finally, data and other options: 58 | ```yaml 59 | instruction: ragfit/processing/prompts/prompt_instructions/qa.txt 60 | template: 61 | data_file: 62 | input_key: prompt 63 | output_key: 64 | resume_checkpoint: 65 | limit: 66 | shuffle: 67 | hfhub_tag: 68 | use_wandb: 69 | experiment: 70 | wandb_entity: 71 | ``` 72 | 73 | Here are they important keys: 74 | 75 | - The instruction file to use for training (should later be used for inference as well). 76 | - If the model/tokenizer do not support a chat template, the user needs to provided a custom template; they placeholders to 77 | fill are `query` and `output`. 78 | - Data file is the processed file to train on. 79 | - Input key is the prompt. 80 | - Output key is completion text to learn. 81 | - Limit and shuffle can be used to filter the dataset for debugging purposes. 82 | - The framework can push the trained model to `hfhub_tab`. 83 | - The last three keys related to experiment tracking using WANDB. Other services can be used by modifying the 84 | `report_to` key. 
85 | 86 | ## Sending Runs 87 | 88 | As we mentioned in the Data Augmentation page, we demonstrate the framework functionality using the ASQA dataset and the 89 | Phi-3 model, experimenting with 5 different configurations. Only 2 configurations require fine-tuning. One can send the 90 | training job like this: 91 | 92 | ```sh 93 | python training.py -cp configs/paper -cn training-asqa \ 94 | data_file=asqa-context-train.jsonl \ 95 | output_key=answers \ 96 | train.output_dir=./trained_models_context/ 97 | ``` 98 | 99 | The `-cp` and `-cn` are overrides for the default configuration, which is `./configs/training.yaml`. Then there are 100 | overrides for the processed data file to use, the name of the label key and where to save the trained model. Overrides 101 | are based on the [Hydra](https://hydra.cc/) vocabulary. 102 | 103 | For the CoT model with RAFT contexts, we run: 104 | ```sh 105 | python training.py -cp configs/paper -cn training-asqa \ 106 | data_file=asqa-raft-cot-train.jsonl \ 107 | output_key=generated_answer \ 108 | train.output_dir=./trained_models_cot/ 109 | ``` -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from operator import itemgetter 4 | from pathlib import Path 5 | 6 | import hydra 7 | import wandb 8 | from datasets import load_dataset 9 | from hydra.utils import to_absolute_path 10 | from omegaconf import OmegaConf 11 | from transformers import TrainingArguments 12 | from trl import DataCollatorForCompletionOnlyLM, SFTTrainer 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def setup_wandb(args: dict): 18 | """ 19 | WANDB integration for tracking training runs. 
20 | """ 21 | env = {key: os.getenv(key) for key in os.environ} 22 | run = wandb.init( 23 | job_type="train", 24 | project=args["project"], 25 | group=args["experiment"], 26 | entity=args["wandb_entity"], 27 | config={**args, **env}, 28 | tags=["train"], 29 | ) 30 | return run 31 | 32 | 33 | @hydra.main(version_base=None, config_path="./configs", config_name="training") 34 | def main(args): 35 | logger.info(OmegaConf.to_yaml(args)) 36 | OmegaConf.set_struct(args, False) 37 | 38 | logger.info(f"Experiment name: {args.experiment}") 39 | logger.info(f"Output path: {args.train.output_dir}") 40 | 41 | if args.use_wandb: 42 | run = setup_wandb(OmegaConf.to_container(args)) 43 | 44 | logger.info(f"Loading dataset: {args.data_file}") 45 | dataset = load_dataset( 46 | "json", data_files=to_absolute_path(args.data_file), split="train" 47 | ) 48 | 49 | logger.info(f"Loading instruction from file {args.instruction}...") 50 | instruction = open(args.instruction).read() 51 | logger.info(f"Loaded instruction: {instruction}") 52 | 53 | if args.shuffle: 54 | dataset = dataset.shuffle(seed=args.shuffle) 55 | 56 | if args.limit: 57 | dataset = dataset.select(range(min(args.limit, len(dataset)))) 58 | 59 | model_class = hydra.utils.instantiate(args.model, _convert_="object") 60 | logger.info("Model was loaded.") 61 | 62 | def format_answer(example): 63 | query = example[args.input_key] 64 | if args.model.instruction_in_prompt: 65 | query = instruction + "\n" + query 66 | 67 | output = ( 68 | out[0] if isinstance(out := example[args.output_key], list) else out 69 | ) or "" 70 | 71 | if args.template: 72 | return open(args.template).read().format(query=query, output=output) 73 | else: 74 | messages = [ 75 | { 76 | "role": "system", 77 | "content": instruction, 78 | }, 79 | {"role": "user", "content": query}, 80 | { 81 | "role": "assistant", 82 | "content": output, 83 | }, 84 | ] 85 | 86 | return dict(messages=messages) 87 | 88 | dataset = dataset.map(format_answer) 89 | 90 | # Split the dataset into train and dev 91 | train, dev = itemgetter("train", "test")(dataset.train_test_split(args.dev_split)) 92 | 93 | collator = DataCollatorForCompletionOnlyLM( 94 | model_class.tokenizer.encode( 95 | args.model.completion_start, add_special_tokens=False 96 | ), 97 | tokenizer=model_class.tokenizer, 98 | ) 99 | 100 | logger.info("Initializing training arguments...") 101 | training_args = TrainingArguments(**args.train) 102 | 103 | logger.info("Starting to train...") 104 | trainer = SFTTrainer( 105 | model=model_class.model, 106 | args=training_args, 107 | data_collator=collator, 108 | train_dataset=train, 109 | eval_dataset=dev, 110 | dataset_batch_size=1, 111 | packing=False, 112 | max_seq_length=args.model.max_sequence_len, 113 | dataset_kwargs=dict(add_special_tokens=False), 114 | ) 115 | trainer.train(resume_from_checkpoint=args.resume_checkpoint) 116 | 117 | logger.info( 118 | f"Finished training; saving model to {args.train.output_dir}/checkpoint..." 
119 | ) 120 | 121 | trainer.model.save_pretrained(Path(args.train.output_dir) / "checkpoint/") 122 | 123 | if args.hfhub_tag: 124 | trainer.model.push_to_hub(args.hfhub_tag, private=True) 125 | 126 | 127 | if __name__ == "__main__": 128 | main() 129 | -------------------------------------------------------------------------------- /ragfit/evaluation/deep.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | from .base import MetricBase 5 | 6 | 7 | class DeepEvalBase(MetricBase): 8 | """ 9 | Base class for DeepEval metrics. 10 | 11 | Here we use AzureChatOpenAI interface; replace if needed. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | key_names: dict, 17 | api_version, 18 | azure_endpoint, 19 | azure_deployment, 20 | **kwargs, 21 | ): 22 | super().__init__(key_names, **kwargs) 23 | from deepeval.test_case import LLMTestCase 24 | from langchain_openai import AzureChatOpenAI 25 | 26 | self.local = True 27 | self.query = self.key_names["query"] 28 | self.context = self.key_names["context"] 29 | self.test_case = LLMTestCase 30 | 31 | self.model = AzureChatOpenAI( 32 | api_version=api_version, 33 | azure_endpoint=azure_endpoint, 34 | azure_deployment=azure_deployment, 35 | request_timeout=600, 36 | max_retries=10, 37 | ) 38 | 39 | 40 | class Faithfulness(DeepEvalBase): 41 | """ 42 | Faithfulness metric from DeepEval, based on RAGAS. 43 | 44 | Measures faithfulness of generated text by comparing it to the target text. 45 | """ 46 | 47 | def __init__(self, key_names: dict, threshold=0.3, **kwargs): 48 | super().__init__(key_names, **kwargs) 49 | from deepeval.metrics.ragas import RAGASFaithfulnessMetric 50 | 51 | self.metric = RAGASFaithfulnessMetric(threshold=threshold, model=self.model) 52 | 53 | def measure(self, example): 54 | query = example[self.query] 55 | output = example[self.field] 56 | context = example[self.context] 57 | 58 | test_case = self.test_case( 59 | input=query, 60 | actual_output=output or "No answer.", 61 | retrieval_context=[context] if isinstance(context, str) else context, 62 | ) 63 | try: 64 | self.metric.measure(test_case) 65 | score = self.metric.score 66 | except Exception as e: 67 | logging.error(f"OpenAI exception: {e}") 68 | score = 0 69 | 70 | return {"faithfulness": score if not math.isnan(score) else 0} 71 | 72 | 73 | class Relevancy(DeepEvalBase): 74 | """ 75 | Answer relevancy metric from DeepEval, based on RAGAS. 76 | 77 | Measures relevancy of generated text by comparing it to the retrieved documents. 
78 | """ 79 | 80 | def __init__(self, key_names: dict, embeddings, threshold=0.3, **kwargs): 81 | super().__init__(key_names, **kwargs) 82 | from deepeval.metrics.ragas import RAGASAnswerRelevancyMetric 83 | from ragas.embeddings import HuggingfaceEmbeddings 84 | 85 | self.metric = RAGASAnswerRelevancyMetric( 86 | threshold=threshold, 87 | embeddings=HuggingfaceEmbeddings(model_name=embeddings), 88 | model=self.model, 89 | ) 90 | 91 | def measure(self, example): 92 | query = example[self.query] 93 | output = example[self.field] 94 | context = example[self.context] 95 | 96 | test_case = self.test_case( 97 | input=query, 98 | actual_output=output or "No answer.", 99 | retrieval_context=[context] if isinstance(context, str) else context, 100 | ) 101 | try: 102 | self.metric.measure(test_case) 103 | score = self.metric.score 104 | except Exception as e: 105 | logging.error(f"OpenAI exception: {e}") 106 | score = 0 107 | 108 | return {"relevancy": score} 109 | 110 | 111 | class Hallucination(DeepEvalBase): 112 | """ 113 | Hallucination metric from DeepEval. 114 | 115 | Measures hallucination of generated text by comparing it to the retrieved documents. 116 | """ 117 | 118 | def __init__(self, key_names: dict, threshold=0.5, **kwargs): 119 | super().__init__(key_names, **kwargs) 120 | from deepeval.metrics import HallucinationMetric 121 | 122 | self.metric = HallucinationMetric( 123 | threshold=threshold, include_reason=False, model=self.model 124 | ) 125 | 126 | def measure(self, example): 127 | output = example[self.field] 128 | context = example[self.context] 129 | 130 | test_case = self.test_case( 131 | input="", 132 | actual_output=output, 133 | context=[context] if isinstance(context, str) else context, 134 | ) 135 | 136 | try: 137 | self.metric.measure(test_case) 138 | score = self.metric.score 139 | except Exception as e: 140 | logging.error(f"OpenAI exception: {e}") 141 | score = 0 142 | 143 | return {"hallucination": score} 144 | -------------------------------------------------------------------------------- /docs/inference.md: -------------------------------------------------------------------------------- 1 | # Inference 2 | 3 | In the inference stage, we take the processed dataset and LLM and make predictions. The LLM can be fine-tuned. The 4 | processed data encapsulates the RAG interactions: pre-processing, retrieval, ranking, prompt-creation, and possibly 5 | other types of transformations. So this step deals with producing the predictions to be evaluated. 6 | 7 | It is simple in nature, described by the following configuration: 8 | 9 | ```yaml 10 | model: 11 | _target_: ragfit.models.hf.HFInference 12 | model_name_or_path: microsoft/Phi-3-mini-128k-instruct 13 | load_in_4bit: false 14 | load_in_8bit: true 15 | device_map: auto 16 | torch_dtype: 17 | trust_remote_code: true 18 | instruction: ragfit/processing/prompts/prompt_instructions/qa.txt 19 | instruct_in_prompt: false 20 | lora_path: 21 | generation: 22 | do_sample: false 23 | max_new_tokens: 50 24 | max_length: 25 | temperature: 26 | top_k: 27 | top_p: 28 | return_full_text: false 29 | 30 | data_file: asqa-baseline-dev.jsonl 31 | generated_file: asqa-baseline-dev-generated.jsonl 32 | input_key: prompt 33 | generation_key: output 34 | target_key: answers 35 | limit: 36 | ``` 37 | 38 | The model section deals with details regarding the model loading and generation options. 
System instruction can be 39 | provided, as we mentioned previously: the datasets are model independent, and all model details (system instruction, 40 | custom chat template) are needed only during training and inference. Similarly, `instruct_in_prompt` inserts the system 41 | instruction inside the prompt, for models which don't support a system role. 42 | 43 | Other parameters: 44 | - Data file is the processed file. 45 | - Generated file is the file that will be created with the completions (and labels, for easy debugging). 46 | - Target key is the label keyword. 47 | - Limit: restrict the run to a number of examples, for debugging. 48 | 49 | ## Running Inference 50 | In order to run evaluations for ASQA, as in the paper, there are 5 configurations to run: baseline, context, context 51 | with fine-tuned model, CoT reasoning, and CoT reasoning with a model that was fine-tuned with distractor documents. 52 | 53 | The baseline inference uses the configuration as is; the other calls use the same configuration and just override the value 54 | of the processed data file to use and, optionally, the LoRA path for the model. 55 | 56 | 57 | **Baseline**: 58 | ```sh 59 | python inference.py -cp configs/paper -cn inference-asqa 60 | ``` 61 | 62 | **Context**: 63 | ```sh 64 | python inference.py -cp configs/paper -cn inference-asqa \ 65 | data_file=asqa-context-dev.jsonl \ 66 | generated_file=asqa-context-dev-generated.jsonl 67 | ``` 68 | 69 | **Context with fine-tuned model**: 70 | ```sh 71 | python inference.py -cp configs/paper -cn inference-asqa \ 72 | data_file=asqa-context-dev.jsonl \ 73 | generated_file=asqa-context-ft-dev-generated.jsonl \ 74 | model.lora_path=./path/to/lora/checkpoint 75 | ``` 76 | 77 | **Chain-of-Thought**: 78 | ```sh 79 | python inference.py -cp configs/paper -cn inference-asqa \ 80 | data_file=asqa-cot-dev.jsonl \ 81 | generated_file=asqa-cot-dev-generated.jsonl 82 | ``` 83 | 84 | **Chain-of-Thought with fine-tuned model**: 85 | ```sh 86 | python inference.py -cp configs/paper -cn inference-asqa \ 87 | data_file=asqa-cot-dev.jsonl \ 88 | generated_file=asqa-cot-ft-dev-generated.jsonl \ 89 | model.lora_path=./path/to/lora/checkpoint 90 | ``` 91 | 92 | ## Running Inference with vLLM Backend 93 | 94 | To achieve potentially faster inference speeds, you can run inference using the vLLM backend. The functionality of the inference process remains similar to the previously defined process, with the addition of extra arguments that can be used with the vLLM engine. 95 | 96 | Here is an example of an inference configuration using the vLLM engine: 97 | 98 | ```yaml 99 | model: 100 | _target_: ragfit.models.vllm.VLLMInference 101 | model_name_or_path: "facebook/opt-125m" 102 | llm_params: 103 | dtype: auto 104 | generation: 105 | temperature: 0.5 106 | top_p: 0.95 107 | seed: 1911 108 | num_gpus: 1 109 | 110 | data_file: my-processed-data.jsonl 111 | generated_file: model-predictions.jsonl 112 | input_key: prompt 113 | generation_key: output 114 | target_key: answers 115 | limit: 116 | ``` 117 | 118 | The main differences in this configuration are as follows: 119 | 120 | - `ragfit.models.vllm.VLLMInference`: This class is used to utilize the vLLM-based engine. 121 | - `llm_params`: These are optional vLLM arguments that can be passed to the LLM class. 122 | - `generation`: These are optional arguments that define the generation policy. The supported arguments are compatible with vLLM's `SamplingParams`. 123 | - `num_gpus`: This specifies the number of GPUs to use during inference. 
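Assuming the configuration above is saved as `configs/inference-vllm.yaml` (a file with this name ships with the repository), a run could look like this, with the usual Hydra overrides (file names here are illustrative):

```sh
python inference.py -cn inference-vllm \
    data_file=my-processed-data.jsonl \
    generated_file=model-predictions.jsonl
```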
124 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | ---------- 6 | 7 | **RAG-FiT** is a library designed to improve LLMs ability to use external information by fine-tuning models on 8 | specially created RAG-augmented datasets. The library helps create the data for training, given a RAG technique, helps 9 | easily train models using parameter-efficient finetuning (PEFT), and finally can help users measure the improved 10 | performance using various, RAG-specific metrics. The library is modular, workflows are customizable using configuration 11 | files. Formerly called *RAG Foundry*. 12 | 13 | Comments, suggestions, issues and pull-requests are welcomed! ❤️ 14 | 15 | ### Installation 16 | Clone and run: 17 | 18 | ```sh 19 | pip install -e . 20 | ``` 21 | 22 | Optional packages can be installed: 23 | ```sh 24 | pip install -e .[haystack] 25 | pip install -e .[deepeval] 26 | ``` 27 | 28 | ### Quick Start 29 | 30 | For a simple, end-to-end example, see the [PubmedQA Tutorial](pubmed.md). 31 | 32 | ## Overview 33 | 34 | The RAG-FiT framework facilitates fast prototyping and experimentation with various RAG settings and configurations, 35 | including data selection and filtering, processing, retrieval, ranking, query manipulation, prompt generation, training, 36 | inference, output processing and evaluation. The library is comprised of 4 modules: dataset creation, training, 37 | inference and evaluation. 38 | 39 | * **Dataset Creation**: The processing module creates datasets, persisting RAG interactions, to be used for RAG training 40 | and inference. RAG interactions include dataset loading, columns normalization, data aggregation (fewshot creation), 41 | information retrieval using external tools and frameworks, API integration, template-based prompt creation and any other 42 | form of pre-processing. The data is saved in a consistent, model-independent, input-output format, along with all other 43 | fields and metadata. See [Processing](processing.md). 44 | 45 | * **Training**: using PEFT for efficient training and TRL (e.g. supervised FT) users can train any model on the augmented 46 | datasets. Training is done on the completions. Models can be pushed to HF Hub. See [Training](training.md). 47 | 48 | * **Inference**: generating predictions using the augmented datasets with trained or untrained LLMs. See [Inference](inference.md). 49 | 50 | * **Evaluation**: running evaluation on the generated output from the inference module. Users can provide a list of 51 | metrics to run; custom metrics can be implemented easily. Current metrics include EM, F1, ROUGE, BERTScore, Deepeval, 52 | RAGAS, HF `evaluate` and classification. Metrics can be *local*—run on each example, or *global*—run on the entire 53 | dataset, e.g. recall. Metrics can utilize any feature in the dataset, like retrieval results, reasoning, 54 | citations and attributions, not just the input and output texts. See [Evaluation](evaluation.md). 55 | 56 | 57 | ## Running 58 | The 4 modules are represented as scripts: `processing.py`, `training.py`, `inference.py` and `evaluation.py` at the top 59 | level. Every call has the form `python SCRIPT options...`. 60 | 61 | The library utilizes the [Hydra](https://hydra.cc/docs/intro/) configuration tool; it enables the use of hierarchical 62 | configurations, easily overridden of values in the CLI and the ability to run multiple jobs remotely (e.g. integrations with 63 | SLURM and Ray). 
It represents a *configuration-as-code* approach, as it can instantiate python classes according to 64 | configuration (the `_target_` keyword indicates the python class to use in a given context). 65 | 66 | There are default configurations for each module in the [configs](./configs/) folder. A configuration file can be 67 | overridden like so: 68 | 69 | ```sh 70 | python processing -cp configs/paper -cn processing-asqa-retrieval 71 | ``` 72 | 73 | Individual keywords can be overridden as well: 74 | ```sh 75 | python processing -cp configs/paper -cn processing-asqa-retrieval \ 76 | output_path=/store/data/here \ 77 | cache=true 78 | ``` 79 | 80 | For a complete set of configurations, **reproducing the experimentation in the paper with the ASQA dataset**, see the 81 | configurations in the [Paper](./configs/paper) folder. 82 | 83 | ## Citation 84 | 85 | Please cite our paper if it helps your research: [RAG Foundry: A Framework for Enhancing LLMs for Retrieval Augmented Generation](https://arxiv.org/abs/2408.02545). 86 | 87 | ```BibTex 88 | @article{fleischerRAGFoundryFramework2024, 89 | title = {{RAG} {Foundry}: {A} {Framework} for {Enhancing} {LLMs} for {Retrieval} {Augmented} {Generation}}, 90 | author = {Fleischer, Daniel and Berchansky, Moshe and Wasserblat, Moshe and Izsak, Peter}, 91 | year = 2024, 92 | note = {arXiv:2408.02545 [cs]}, 93 | annote = {Comment: 10 pages}, 94 | url = {http://arxiv.org/abs/2408.02545}, 95 | publisher = {arXiv}, 96 | } 97 | ``` 98 | 99 | ## License 100 | 101 | The code is licensed under the [Apache 2.0 License](LICENSE). 102 | 103 | ## Disclaimer 104 | 105 | This is not an official Intel product. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | ---------- 6 | 7 | **RAG-FiT** is a library designed to improve LLMs ability to use external information by fine-tuning models on 8 | specially created RAG-augmented datasets. The library helps create the data for training, given a RAG technique, helps 9 | easily train models using parameter-efficient finetuning (PEFT), and finally can help users measure the improved 10 | performance using various, RAG-specific metrics. The library is modular, workflows are customizable using configuration 11 | files. Formerly called *RAG Foundry*. 12 | 13 | Comments, suggestions, issues and pull-requests are welcomed! ❤️ 14 | 15 | ### Installation 16 | Clone and run: 17 | 18 | ```sh 19 | pip install -e . 20 | ``` 21 | 22 | Optional packages can be installed: 23 | ```sh 24 | pip install -e .[haystack] 25 | pip install -e .[deepeval] 26 | ``` 27 | 28 | ### Quick Start 29 | 30 | For a simple, end-to-end example, see the [PubmedQA Tutorial](./docs/pubmed.md). 31 | 32 | ## Overview 33 | 34 | The RAG-FiT framework facilitates fast prototyping and experimentation with various RAG settings and configurations, 35 | including data selection and filtering, processing, retrieval, ranking, query manipulation, prompt generation, training, 36 | inference, output processing and evaluation. The library is comprised of 4 modules: dataset creation, training, 37 | inference and evaluation. 38 | 39 | * **Dataset Creation**: The processing module creates datasets, persisting RAG interactions, to be used for RAG training 40 | and inference. RAG interactions include dataset loading, columns normalization, data aggregation (fewshot creation), 41 | information retrieval using external tools and frameworks, API integration, template-based prompt creation and any other 42 | form of pre-processing. The data is saved in a consistent, model-independent, input-output format, along with all other 43 | fields and metadata. See [Processing.md](docs/processing.md). 44 | 45 | * **Training**: using PEFT for efficient training and TRL (e.g. supervised FT) users can train any model on the augmented 46 | datasets. Training is done on the completions. Models can be pushed to HF Hub. See [Training.md](docs/training.md). 47 | 48 | * **Inference**: generating predictions using the augmented datasets with trained or untrained LLMs. See [Inference.md](docs/inference.md). 49 | 50 | * **Evaluation**: running evaluation on the generated output from the inference module. Users can provide a list of 51 | metrics to run; custom metrics can be implemented easily. Current metrics include EM, F1, ROUGE, BERTScore, Deepeval, 52 | RAGAS, HF `evaluate` and classification. Metrics can be *local*—run on each example, or *global*—run on the entire 53 | dataset, e.g. recall. Metrics can utilize any feature in the dataset, like retrieval results, reasoning, 54 | citations and attributions, not just the input and output texts. See [Evaluation.md](docs/evaluation.md). 55 | 56 | 57 | ## Running 58 | The 4 modules are represented as scripts: `processing.py`, `training.py`, `inference.py` and `evaluation.py` at the top 59 | level. Every call has the form `python SCRIPT options...`. 60 | 61 | The library utilizes the [Hydra](https://hydra.cc/docs/intro/) configuration tool; it enables the use of hierarchical 62 | configurations, easily overridden of values in the CLI and the ability to run multiple jobs remotely (e.g. integrations with 63 | SLURM and Ray). 
It represents a *configuration-as-code* approach, as it can instantiate python classes according to 64 | configuration (the `_target_` keyword indicates the python class to use in a given context). 65 | 66 | There are default configurations for each module in the [configs](./configs/) folder. A configuration file can be 67 | overridden like so: 68 | 69 | ```sh 70 | python processing -cp configs/paper -cn processing-asqa-retrieval 71 | ``` 72 | 73 | Individual keywords can be overridden as well: 74 | ```sh 75 | python processing -cp configs/paper -cn processing-asqa-retrieval \ 76 | output_path=/store/data/here \ 77 | cache=true 78 | ``` 79 | 80 | For a complete set of configurations, **reproducing the experimentation in the paper with the ASQA dataset**, see the 81 | configurations in the [Paper](./configs/paper) folder. 82 | 83 | ## Citation 84 | 85 | Please cite our paper if it helps your research: [RAG Foundry: A Framework for Enhancing LLMs for Retrieval Augmented Generation](https://arxiv.org/abs/2408.02545). 86 | 87 | ```BibTex 88 | @article{fleischerRAGFoundryFramework2024, 89 | title = {{RAG} {Foundry}: {A} {Framework} for {Enhancing} {LLMs} for {Retrieval} {Augmented} {Generation}}, 90 | author = {Fleischer, Daniel and Berchansky, Moshe and Wasserblat, Moshe and Izsak, Peter}, 91 | year = 2024, 92 | note = {arXiv:2408.02545 [cs]}, 93 | annote = {Comment: 10 pages}, 94 | url = {http://arxiv.org/abs/2408.02545}, 95 | publisher = {arXiv}, 96 | } 97 | ``` 98 | 99 | ## License 100 | 101 | The code is licensed under the [Apache 2.0 License](LICENSE). 102 | 103 | ## Disclaimer 104 | 105 | This is not an official Intel product. -------------------------------------------------------------------------------- /ragfit/processing/pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | from typing import List 5 | 6 | import hydra 7 | from datasets import load_dataset 8 | from tqdm import tqdm 9 | 10 | from .step import BaseStep 11 | 12 | 13 | class DataPipeline: 14 | """Class for creating a data pipeline. 15 | 16 | The pipeline holds the list of steps and run them one after the other. The 17 | datasets are stored in a global dictionary, where datasets are referred by a 18 | key name, as indicated in the `inputs` parameter for each step. The pipeline 19 | manages the cache lookup and creation. 20 | 21 | Args: 22 | name (str): Name of the pipeline. 23 | output_path (str, optional): Path to store the cache files. Defaults to ".". 24 | cache (bool, optional): Whether to cache the datasets. Defaults to True. 25 | steps (List[BaseStep], optional): List of steps in the pipeline. Defaults to []. 26 | inputs (str, optional): Name of the main dataset. Defaults to "main_dataset". 
27 | 28 | """ 29 | 30 | def __init__( 31 | self, 32 | name, 33 | output_path=".", 34 | cache=True, 35 | steps: List[BaseStep] = [], 36 | inputs: str = "main_dataset", 37 | **kwargs, 38 | ) -> None: 39 | self.name = name 40 | self.output_path = output_path 41 | self.cache = cache 42 | logging.info(f"Caching state: {self.cache}") 43 | self.last_update = math.inf 44 | 45 | self.steps = [ 46 | hydra.utils.instantiate(step, _convert_="object") for step in steps 47 | ] # TODO: do it lazily to prevent OOM 48 | 49 | self.inputs = inputs if isinstance(inputs, list) else [inputs] 50 | self.datasets = {} 51 | 52 | def gen_cache_fn(self, step, index, dataset_name): 53 | """ 54 | Create a unique cache filename for a given dataset, at a given step, in a given index. 55 | Uses the step name, inputs, hash and pipeline's path and name and dataset name. 56 | 57 | Returns a string. 58 | """ 59 | return ( 60 | f"{self.output_path}/cache" 61 | f"_{self.name}_{index}" 62 | f"_{type(step).__name__}" 63 | f"_{dataset_name}_{step.get_hash()}.json" 64 | ) 65 | 66 | def get_cache_mapping(self, step: BaseStep, index: int): 67 | """ 68 | Returns a mapping between input datasets and cache filenames, for a given step. 69 | """ 70 | if self.cache: 71 | datasets_caches = { 72 | dataset_name: self.gen_cache_fn(step, index, dataset_name) 73 | for dataset_name in step.inputs 74 | } 75 | return datasets_caches 76 | 77 | return None 78 | 79 | def cache_step(self, step, step_index): 80 | """ 81 | Write to cache-files the current state of the global datasets dictionary for the given inputs. 82 | """ 83 | if self.cache: 84 | for dataset_name in step.inputs: 85 | dataset = self.datasets[dataset_name] 86 | saved_path = self.gen_cache_fn(step, step_index, dataset_name) 87 | dataset.to_json(saved_path, lines=True) 88 | 89 | def load_from_cache(self, caches_map): 90 | """ 91 | Load datasets from cache using a cache_map. 92 | Updates the global datasets dictionary. 93 | 94 | Internal function, shouldn't be used by the user. 95 | """ 96 | logging.info(f"Loading dataset from checkpoints {caches_map}") 97 | for dataset_name, saved_path in caches_map.items(): 98 | self.datasets[dataset_name] = load_dataset( 99 | "json", data_files=[saved_path], split="train" 100 | ) 101 | 102 | def delete_cache(self): 103 | """ 104 | Removing cache files for all steps, cleaning the pipeline. 105 | """ 106 | logging.info("Removing cache files for entire pipeline.") 107 | if self.cache: 108 | for i, step in enumerate(self.steps): 109 | cache_map = self.get_cache_mapping(step, i) 110 | if cache_map is not None: 111 | for dataset_name, cache_path in cache_map.items(): 112 | if os.path.exists(cache_path): 113 | os.remove(cache_path) 114 | 115 | def process(self): 116 | """ 117 | Run pipeline, step after step. 118 | 119 | Caching is handled here. A step is calculated either if there was a change in the pipeline at a previous 120 | step OR the current step has no cache file. 121 | 122 | When a step is calculated, it is cached and self.last_update is updated to the current step index. 
123 | """ 124 | for i, step in tqdm(enumerate(self.steps)): 125 | logging.info(f"Processing step {i}") 126 | 127 | cache_map = self.get_cache_mapping(step, i) 128 | if ( 129 | (cache_map is not None) 130 | and (all(os.path.exists(v) for v in cache_map.values())) 131 | and (i < self.last_update) 132 | ): 133 | logging.info(f"Loading cached datasets for {type(step).__name__}") 134 | self.load_from_cache(cache_map) 135 | else: 136 | step(self.datasets) 137 | if step.cache_step: 138 | self.cache_step(step, i) 139 | self.last_update = i 140 | -------------------------------------------------------------------------------- /ragfit/models/hf.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | from peft import LoraConfig, get_peft_model 5 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class HFInference: 11 | """ 12 | Class for running HF model inference locally. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | model_name_or_path: str, 18 | torch_dtype, 19 | device_map, 20 | instruction: Path, 21 | instruct_in_prompt: False, 22 | template: Path = None, 23 | lora_path=None, 24 | generation=None, 25 | task="text-generation", 26 | **kwargs, 27 | ): 28 | """ 29 | Initialize a HF model, with optional LORA adapter. 30 | 31 | Args: 32 | model_name_or_path (str): HF model name or path. 33 | torch_dtype (str): torch dtype for the model. 34 | device_map: device map for the model. 35 | instruction (Path): path to the instruction file. 36 | instruct_in_prompt (bool): whether to include the instruction in the prompt for models without system role. 37 | template (Path): path to a prompt template file if tokenizer does not include chat template. Optional. 38 | lora_path (Path): path to the LORA adapter. 39 | generation (dict): generation kwargs. 40 | task (str): task for the pipeline. 41 | """ 42 | 43 | self.model_name = model_name_or_path 44 | self.generation_kwargs = generation 45 | self.instruction = open(instruction).read() 46 | logger.info(f"Using the following instruction: {self.instruction}") 47 | 48 | self.instruct_in_prompt = instruct_in_prompt 49 | self.template = open(template).read() if template else None 50 | 51 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, **kwargs) 52 | 53 | self.config = AutoConfig.from_pretrained(self.model_name, **kwargs) 54 | self.config.torch_dtype = torch_dtype or "auto" 55 | 56 | self.model = AutoModelForCausalLM.from_pretrained( 57 | self.model_name, config=self.config, device_map=device_map, **kwargs 58 | ) 59 | if lora_path: 60 | logger.info(f"Loading LORA: {lora_path}") 61 | self.model.load_adapter(lora_path) 62 | 63 | self.pipe = pipeline( 64 | task=task, 65 | model=self.model, 66 | tokenizer=self.tokenizer, 67 | ) 68 | 69 | def generate(self, prompt: str) -> str: 70 | """ 71 | Given an input, generate a response. 
72 | """ 73 | 74 | if self.template: 75 | prompt = self.template.format(instruction=self.instruction, query=prompt) 76 | 77 | else: 78 | if self.instruct_in_prompt: 79 | prompt = self.instruction + "\n" + prompt 80 | 81 | messages = [ 82 | {"role": "system", "content": self.instruction}, 83 | {"role": "user", "content": prompt}, 84 | ] 85 | 86 | prompt = self.tokenizer.apply_chat_template( 87 | messages, 88 | tokenize=False, 89 | add_generation_prompt=True, 90 | truncation=True, 91 | max_length=( 92 | self.config.max_position_embeddings 93 | - self.generation_kwargs["max_new_tokens"] 94 | ), 95 | ) 96 | 97 | output = self.pipe(prompt, **self.generation_kwargs) 98 | return output[0]["generated_text"] 99 | 100 | 101 | class HFTrain: 102 | """ 103 | Class for training HF models locally. 104 | """ 105 | 106 | def __init__( 107 | self, 108 | model_name_or_path, 109 | torch_dtype, 110 | device_map, 111 | lora: LoraConfig = None, 112 | generation=None, 113 | completion_start: str = "", 114 | instruction_in_prompt=None, 115 | max_sequence_len=None, 116 | **kwargs, 117 | ): 118 | """ 119 | Args: 120 | model_name_or_path: str - HF model name or path. 121 | torch_dtype: str - torch dtype for the model. 122 | device_map: dict - device map for the model. 123 | lora: dict - LoRA adapter config. 124 | generation: dict - generation kwargs. 125 | completion_start: str - used to find the start of the completion in the prompt. 126 | instruction_in_prompt: bool - whether to include the instruction in the prompt for models without system role. 127 | """ 128 | self.model_name = model_name_or_path 129 | self.complete_start = completion_start 130 | self.instructions_in_prompt = instruction_in_prompt 131 | self.generation_kwargs = generation 132 | 133 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) 134 | self.tokenizer.pad_token = self.tokenizer.eos_token 135 | 136 | self.config = AutoConfig.from_pretrained(self.model_name, **kwargs) 137 | self.config.torch_dtype = torch_dtype or "auto" 138 | 139 | self.model = AutoModelForCausalLM.from_pretrained( 140 | self.model_name, 141 | config=self.config, 142 | device_map=device_map, 143 | **kwargs, 144 | ) 145 | 146 | self.model.config.use_cache = False 147 | logger.info(f"Loaded model: {self.model}") 148 | 149 | logger.info(f"Initializing LORA based on {lora}") 150 | self.model = get_peft_model(self.model, LoraConfig(**lora)) 151 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | # site info 2 | site_name: RAG-FiT Documentation 3 | # site_url: http:// 4 | site_author: Intel Labs 5 | site_description: 6 | Small, minimalistic and modular library to improve and measure LLMs RAG ability, using prompt-engineering and fine-tuning. 
7 | 8 | # Repository 9 | repo_url: https://github.com/IntelLabs/RAG-FiT 10 | repo_name: IntelLabs/RAG-FiT 11 | 12 | # theme 13 | theme: 14 | name: material 15 | palette: 16 | - scheme: ragfit 17 | toggle: 18 | icon: material/toggle-switch-off-outline 19 | name: Switch to dark mode 20 | - scheme: slate 21 | toggle: 22 | icon: material/toggle-switch 23 | name: Switch to light mode 24 | primary: black 25 | accent: light-blue 26 | font: 27 | text: Open sans 28 | code: inconsolata 29 | language: en 30 | logo: assets/rag_fit_white.png 31 | features: 32 | - announce.dismiss 33 | # - content.action.edit 34 | # - content.action.view 35 | - content.code.annotate 36 | - content.code.copy 37 | # - content.code.select 38 | # - content.footnote.tooltips 39 | # - content.tabs.link 40 | - content.tooltips 41 | # - header.autohide 42 | # - navigation.expand 43 | - navigation.footer 44 | - navigation.indexes 45 | # - navigation.instant 46 | # - navigation.instant.prefetch 47 | # - navigation.instant.progress 48 | # - navigation.prune 49 | - navigation.sections 50 | - navigation.tabs 51 | # - navigation.tabs.sticky 52 | - navigation.top 53 | - navigation.tracking 54 | - search.highlight 55 | - search.share 56 | - search.suggest 57 | - toc.follow 58 | # - toc.integrate 59 | icon: 60 | repo: simple/intel 61 | extra_css: 62 | - stylesheets/extra.css 63 | plugins: 64 | # - blog 65 | - search: 66 | separator: 67 | '[\s\u200b\-_,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])' 68 | # - gen-files: 69 | # scripts: 70 | # - docs/scripts/generate_docstrings.py 71 | - mkdocstrings: 72 | handlers: 73 | python: 74 | # selection: 75 | # inherited_members: true # Allow looking up inherited methods 76 | options: 77 | # show_protected_members: true 78 | # show_private_members: true 79 | # docstring_style: google 80 | docstring_section_style: list 81 | show_source: true # don't include source code 82 | show_docstring_functions: true 83 | show_signature_annotations: true 84 | show_docstring_description: true 85 | show_docstring_examples: true 86 | # show_docstring_attributes: true 87 | # merge_init_into_class: false 88 | # rendering: 89 | # docstring_section_style: list 90 | # show_root_heading: true # actually display anything at all... 
91 | # # show_root_full_path: true # display "diffrax.asdf" not just "asdf" 92 | # show_if_no_docstring: true 93 | # show_signature_annotations: true 94 | # show_source: false # don't include source code 95 | # show_docstring_attributes: true 96 | # merge_init_into_class: false 97 | # # members_order: source # order methods according to their order of definition in the source code, not alphabetical order 98 | # # heading_level: 4 99 | # Extensions 100 | markdown_extensions: 101 | - abbr 102 | - admonition 103 | - attr_list 104 | - def_list 105 | - footnotes 106 | - md_in_html 107 | - toc: 108 | permalink: true 109 | - pymdownx.arithmatex: 110 | generic: true 111 | - pymdownx.betterem: 112 | smart_enable: all 113 | - pymdownx.caret 114 | - pymdownx.details 115 | - pymdownx.emoji: 116 | emoji_generator: !!python/name:material.extensions.emoji.to_svg 117 | emoji_index: !!python/name:material.extensions.emoji.twemoji 118 | - pymdownx.highlight: 119 | anchor_linenums: true 120 | line_spans: __span 121 | pygments_lang_class: true 122 | - pymdownx.inlinehilite 123 | - pymdownx.keys 124 | - pymdownx.magiclink: 125 | normalize_issue_symbols: true 126 | repo_url_shorthand: true 127 | user: squidfunk 128 | repo: mkdocs-material 129 | - pymdownx.mark 130 | - pymdownx.smartsymbols 131 | - pymdownx.snippets: 132 | auto_append: 133 | - includes/mkdocs.md 134 | - pymdownx.superfences: 135 | custom_fences: 136 | - name: mermaid 137 | class: mermaid 138 | format: !!python/name:pymdownx.superfences.fence_code_format 139 | - pymdownx.tabbed: 140 | alternate_style: true 141 | combine_header_slug: true 142 | slugify: !!python/object/apply:pymdownx.slugs.slugify 143 | kwds: 144 | case: lower 145 | - pymdownx.tasklist: 146 | custom_checkbox: true 147 | - pymdownx.tilde 148 | 149 | extra: 150 | generator: false 151 | 152 | nav: 153 | - Home: 154 | - Home: "index.md" 155 | - Tutorial: "pubmed.md" 156 | - Guide: 157 | - Data Augmentation: "processing.md" 158 | - Training: "training.md" 159 | - Inference: "inference.md" 160 | - Evaluation: "evaluation.md" 161 | - API: 162 | - Processing: 163 | - Step: "reference/processing/step.md" 164 | - Pipeline: "reference/processing/pipeline.md" 165 | - Dataset Loaders: 166 | - loaders: "reference/processing/dataset_loaders/loaders.md" 167 | - Local Steps: 168 | - Common Datasets: "reference/processing/local_steps/common_datasets.md" 169 | - Formatting: "reference/processing/local_steps/formatting.md" 170 | - Retrievers: 171 | - Haystack: 172 | "reference/processing/local_steps/retrievers/haystack.md" 173 | - API: 174 | - OpenAI Chat: 175 | "reference/processing/local_steps/api/openai.md" 176 | - Context: "reference/processing/local_steps/context.md" 177 | - Prompt Creation: "reference/processing/local_steps/prompter.md" 178 | - Inference: "reference/processing/local_steps/inference.md" 179 | - RAFT: "reference/processing/local_steps/raft.md" 180 | - Global Steps: 181 | - Aggregation and merging: "reference/processing/global_steps/aggregation.md" 182 | - Sampling and Fewshot: "reference/processing/global_steps/sampling.md" 183 | - Filters: "reference/processing/global_steps/filters.md" 184 | - Output: "reference/processing/global_steps/output.md" 185 | - Answer Processors: 186 | - regex: "reference/processing/answer_processors/regex.md" 187 | - Utils: "reference/processing/utils.md" 188 | - Models: 189 | - Transformers: "reference/models/hf.md" 190 | - OpenAI: "reference/models/openai_executor.md" 191 | - vLLM: "reference/models/vllm.md" 192 | - Evaluation: 193 | - Base: 
"reference/evaluation/base.md" 194 | - Metrics: "reference/evaluation/metrics.md" 195 | - DeepEval: "reference/evaluation/deep.md" 196 | - Utils: "reference/utils.md" 197 | 198 | -------------------------------------------------------------------------------- /docs/evaluation.md: -------------------------------------------------------------------------------- 1 | # Evaluations 2 | 3 | The evaluation module takes the produced inference file and the original processed dataset and runs a list of 4 | evaluations, producing a final results file, in a YAML format. The evaluations are represented as metric classes. 5 | 6 | We implement several metrics including: a wrapper for HuggingFace `evaluate` class, which can accept a list of metrics, 7 | EM, F1, classification (accuracy, precision, recall, F1), BERTScore, Semantic similarity (using a customizable 8 | cross-encoder). The module can also run metrics from [DeepEval](https://docs.confident-ai.com/docs/getting-started), 9 | which offers a large collection of LLM evaluations. 10 | 11 | Metrics can be either local or global; a local metric runs over each example individually, scores are collected and 12 | averaged. A global metric runs on the entire dataset at once, for example: classification F1. 13 | 14 | 15 | The configuration contains the following section: 16 | 17 | ```yaml 18 | answer_processor: 19 | _target_: ragfit.processing.answer_processors.regex.RegexAnswer 20 | capture_pattern: # ": (.*)" 21 | stopping_pattern: # "[,.;]" 22 | ``` 23 | The evaluation module introduces the concept of an Answer Processor. This class can run post-processing on the 24 | generated text, preparing it for evaluations or the specific format some metrics require. 25 | 26 | There is a default processor, called `RegexAnswer`; it can filter text, based on a python regex capture pattern. It can 27 | also split text using a stopping pattern. For example, in the Chain-of-Thought reasoning we used in the paper, the model 28 | is instruction to explain its answer, cite if needed and finally print the final results in the following format 29 | `: ...`. We can use this format as a capture pattern; thus models that learn to answer using this pattern (obey the 30 | instruction) will score higher. 31 | 32 | Next is a list of metrics; each one is a python class: 33 | ```yaml 34 | metrics: 35 | - _target_: ragfit.evaluation.metrics.HFEvaluate 36 | metric_names: [rouge] 37 | - _target_: ragfit.evaluation.metrics.EM 38 | - _target_: ragfit.evaluation.metrics.F1 39 | - _target_: ragfit.evaluation.metrics.BERTScore 40 | model: microsoft/deberta-large-mnli 41 | ``` 42 | 43 | Some metrics require additional parameters, for example HuggingFace `evaluate` requires the metrics' names, BERTScore 44 | requires an embedding model. 45 | 46 | ```yaml 47 | key_names: 48 | generated: generated 49 | label: answer 50 | query: query 51 | context: context 52 | ``` 53 | A mapping of keys and values: the values should represent the names of the corresponding fields in the processed 54 | dataset. 55 | 56 | Finally: 57 | ```yaml 58 | results_file: my-evaluation.yaml 59 | generated_file: inference.jsonl 60 | data_file: my-processed-data.jsonl 61 | limit: 62 | ``` 63 | 64 | One needs to provide the generated inference file, the processed dataset and a filename for the results summary. A limit 65 | number of rows can be provided for debugging purposes. 
66 | 67 | ## Running Evaluations on ASQA 68 | 69 | As the final part of the demonstration of the framework with the ASQA dataset and Phi-3 models, we will evaluate the 70 | different RAG configurations, with and without the use of fine-tuning. 71 | 72 | As a reminder, ASQA has 2 types of answers: a long answer and short answers. We will evaluate the generated answers using 73 | the long answer with RAGAS metrics (faithfulness and relevancy) and use the short answers with the ASQA-defined STR-EM. 74 | 75 | ### Short 76 | Starting with the short answers, the label keyword is `answer-short` (recall the processing) and a representative 77 | configuration looks like this: 78 | 79 | ```yaml 80 | answer_processor: 81 | _target_: ragfit.processing.answer_processors.regex.RegexAnswer 82 | capture_pattern: ": (.*)" 83 | stopping_pattern: 84 | 85 | metrics: 86 | - _target_: ragfit.evaluation.metrics.StringEM 87 | 88 | key_names: 89 | generated: text 90 | label: answer-short 91 | query: query 92 | 93 | results_file: evaluation-asqa-baseline.yaml 94 | generated_file: asqa-baseline-dev-generated.jsonl 95 | data_file: asqa-baseline-dev.jsonl 96 | ``` 97 | 98 | Here are the calls to evaluate the different configurations: 99 | 100 | **Baseline**: 101 | ```sh 102 | python evaluation.py -cp configs/paper -cn evaluation-asqa-short 103 | ``` 104 | 105 | **Context**: 106 | ```sh 107 | python evaluation.py -cp configs/paper -cn evaluation-asqa-short \ 108 | results_file=asqa-context-dev-generated-results.yaml \ 109 | data_file=asqa-context-dev.jsonl \ 110 | generated_file=asqa-context-dev-generated.jsonl 111 | ``` 112 | 113 | **Context with fine-tuned model**: 114 | ```sh 115 | python evaluation.py -cp configs/paper -cn evaluation-asqa-short \ 116 | results_file=asqa-context-ft-dev-generated-results.yaml \ 117 | data_file=asqa-context-dev.jsonl \ 118 | generated_file=asqa-context-ft-dev-generated.jsonl 119 | ``` 120 | 121 | **Chain-of-Thought**: 122 | ```sh 123 | python evaluation.py -cp configs/paper -cn evaluation-asqa-short \ 124 | results_file=asqa-cot-dev-generated-results.yaml \ 125 | data_file=asqa-cot-dev.jsonl \ 126 | generated_file=asqa-cot-dev-generated.jsonl 127 | ``` 128 | 129 | **Chain-of-Thought with fine-tuned model**: 130 | ```sh 131 | python evaluation.py -cp configs/paper -cn evaluation-asqa-short \ 132 | results_file=asqa-cot-ft-dev-generated-results.yaml \ 133 | data_file=asqa-cot-dev.jsonl \ 134 | generated_file=asqa-cot-ft-dev-generated.jsonl 135 | ``` 136 | 137 | 138 | ### Long 139 | To evaluate the generated output with respect to the full answer, we use two RAGAS metrics, namely faithfulness and 140 | relevancy. The RAGAS metrics require a context for the critic to make a judgment, so these are not relevant for the 141 | baseline configuration.
142 | 143 | The difference in configuration is in the list of metrics and keywords: 144 | 145 | ```yaml 146 | metrics: 147 | - _target_: ragfit.evaluation.deep.Faithfulness 148 | azure_endpoint: azure.endpoint.com 149 | azure_deployment: GPT-4-32k-Bot 150 | api_version: 2024-05-01-preview 151 | - _target_: ragfit.evaluation.deep.Relevancy 152 | azure_endpoint: azure.endpoint.com 153 | azure_deployment: GPT-4-32k-Bot 154 | api_version: 2024-05-01-preview 155 | embeddings: BAAI/bge-small-en-v1.5 156 | 157 | key_names: 158 | generated: text 159 | label: answers 160 | query: query 161 | context: positive_passages 162 | ``` 163 | 164 | The relevancy metric requires an embedder—it generates probable questions based on the generated answer (and the context) and 165 | then measures semantic similarity to the original question. 166 | 167 | **Context**: 168 | ```sh 169 | python evaluation.py -cp configs/paper -cn evaluation-asqa-long \ 170 | results_file=asqa-context-dev-generated-results-ragas.yaml \ 171 | data_file=asqa-context-dev.jsonl \ 172 | generated_file=asqa-context-dev-generated.jsonl 173 | ``` 174 | 175 | **Context with fine-tuned model**: 176 | ```sh 177 | python evaluation.py -cp configs/paper -cn evaluation-asqa-long \ 178 | results_file=asqa-context-ft-dev-generated-results-ragas.yaml \ 179 | data_file=asqa-context-dev.jsonl \ 180 | generated_file=asqa-context-ft-dev-generated.jsonl 181 | ``` 182 | 183 | **Chain-of-Thought**: 184 | ```sh 185 | python evaluation.py -cp configs/paper -cn evaluation-asqa-long \ 186 | results_file=asqa-cot-dev-generated-results-ragas.yaml \ 187 | data_file=asqa-cot-dev.jsonl \ 188 | generated_file=asqa-cot-dev-generated.jsonl 189 | ``` 190 | 191 | **Chain-of-Thought with fine-tuned model**: 192 | ```sh 193 | python evaluation.py -cp configs/paper -cn evaluation-asqa-long \ 194 | results_file=asqa-cot-ft-dev-generated-results-ragas.yaml \ 195 | data_file=asqa-cot-dev.jsonl \ 196 | generated_file=asqa-cot-ft-dev-generated.jsonl 197 | ``` 198 | 199 | 200 | 201 | 202 | -------------------------------------------------------------------------------- /docs/processing.md: -------------------------------------------------------------------------------- 1 | # Data Augmentation 2 | 3 | To demonstrate the usage of RAG-FiT data augmentation, we will follow the experimentation presented in the paper, 4 | choosing the ASQA Q&A dataset and the Phi-3 model. We compare a baseline configuration with 4 other configurations: 5 | 6 | 1. Retrieval augmentation using a corpus and inserting the documents in the prompt after the question. 7 | 2. Similar to (1) but having the model fine-tune on the completions. 8 | 3. Similar to (1) and adding a Chain-of-Thought instruction for the model to explain its reasoning and format its 9 | answer. 10 | 4. Similar to (3) but having the model fine-tune on the completions while implementing a technique from RAFT where 11 | distracting documents are used. 12 | 13 | The [ASQA dataset](https://huggingface.co/datasets/din0s/asqa) has two types of answers: a long answer and lists of short 14 | answers (actually a list of lists). Additionally, it has some minimal amount of context in the data, so we augment it 15 | using a corpus, stored as a vector DB; we use [Qdrant](https://qdrant.tech/). 16 | 17 | In order to train configuration (4), we need well-reasoned CoT responses as labels, so we use an OpenAI GPT-4 model to augment a 18 | dataset with these synthetic labels.
19 | 20 | **Notice**: all the configurations mentioned here, implementing the experiments done in the paper, are saved in 21 | `configs/paper/`. They don't run by default, they need to be specified by running: 22 | 23 | ```sh 24 | python module-name.py -cp configs/paper -cn config-name-without-extension 25 | ``` 26 | 27 | ## Retrieval 28 | 29 | The first step would be to augment the entire dataset (train, dev) with relevant documents, based on the questions, see 30 | [processing-asqa-retrieval.yaml](../configs/paper/processing-asqa-retrieval.yaml). Let's focus on the different steps: 31 | 32 | ```yaml 33 | - _target_: ragfit.processing.dataset_loaders.loaders.HFLoader 34 | inputs: train 35 | dataset_config: 36 | path: din0s/asqa 37 | split: train 38 | 39 | - _target_: ragfit.processing.dataset_loaders.loaders.HFLoader 40 | inputs: dev 41 | dataset_config: 42 | path: din0s/asqa 43 | split: dev 44 | ``` 45 | 46 | We load the train and dev splits, to be used in the pipeline; they will be referred using the `inputs` keyword used in this 47 | step. 48 | 49 | ```yaml 50 | - _target_: ragfit.processing.local_steps.common_datasets.ASQA 51 | inputs: [train, dev] 52 | ``` 53 | We do some minimal processing, related to ASQA, namely column renaming, collecting the short and long answers and 54 | having a consistent scheme, for example: `query`, `answers`, `positive_passages`, etc. Feel free to add your own types 55 | of pre-processing. 56 | 57 | Notice the `inputs` keyword can accept a list of strings, meaning the step will run over the datasets specified. 58 | 59 | ```yaml 60 | - _target_: 61 | ragfit.processing.local_steps.retrievers.haystack.HaystackRetriever 62 | inputs: [train, dev] 63 | pipeline_or_yaml_path: ./configs/external/haystack/qdrant.yaml 64 | docs_key: positive_passages 65 | query_key: query 66 | ``` 67 | This is the retrieval step. We use the [Haystack](https://haystack.deepset.ai/) framework for building RAG pipelines; in 68 | this example, the Haystack pipeline is comprised of an embedder and a retriever, connecting the Qdrant using a 69 | Qdrant-Haystack integration (all defined in the requirements file). The Haystack pipeline is initialized from the 70 | [Qdrant.yaml](../configs/external/haystack/qdrant.yaml) configuration. One can use other frameworks for retrieval, like 71 | LangChain, LlamaIndex, or others. 72 | 73 | The retrieval step will store the most relevant documents (k=5) in the `docs_key` and the query will be defined by the 74 | `query_key`. 75 | 76 | ```yaml 77 | - _target_: ragfit.processing.local_steps.context.ContextHandler 78 | inputs: [train, dev] 79 | docs_key: positive_passages 80 | ``` 81 | In this simple step, the documents retrieved are processed; they have a title and content fields and this step combine 82 | these into a single string for every document. This step may be unnecessary, depending on the retrieval mechanism and 83 | format. 84 | 85 | ```yaml 86 | - _target_: ragfit.processing.global_steps.sampling.Sampler 87 | inputs: [train, dev] 88 | k: 1 89 | input_key: positive_passages 90 | output_key: negative_passages 91 | ``` 92 | The `Sampler` class deals with sampling examples from the same dataset or others. In order to train the RAFT-based 93 | model on a combination of relevant and distracting documents, we need to collect these distracting documents. Here we 94 | chose to collect positive documents from other examples, to be used as negative documents. 
The `Sampler` is then run 95 | with k=1; it collects only the `positive_passages` from the examples it samples and stores them in a new keyword called 96 | `negative_passages`. 97 | 98 | ```yaml 99 | - _target_: ragfit.processing.global_steps.output.OutputData 100 | inputs: [train, dev] 101 | prefix: asqa 102 | ``` 103 | Finally, we write the two resulting datasets to disk. They represent the retrieval-augmented datasets, ready to be 104 | processed for the different tasks. 105 | 106 | To run this process: 107 | ```sh 108 | python processing.py -cp configs/paper -cn processing-asqa-retrieval 109 | ``` 110 | 111 | ## Baseline Configuration 112 | 113 | For the baseline, there is no context; only the question is presented to the model. We use 114 | instruction-following models that have a chat template built-in. The framework populates the chat template using the 115 | inputs and outputs we generate, so we don't need to worry about roles and special tokens. Additionally, the system 116 | instruction is specified only during training and inference: it needn't be part of the dataset, so these next steps mainly 117 | deal with prompt generation. 118 | 119 | These are the interesting steps: 120 | 121 | ```yaml 122 | - _target_: ragfit.processing.dataset_loaders.loaders.LocalLoader 123 | inputs: dev 124 | filename: asqa-dev.jsonl 125 | 126 | - _target_: ragfit.processing.local_steps.prompter.TextPrompter 127 | inputs: dev 128 | prompt_file: ragfit/processing/prompts/qa-short.txt 129 | output_key: prompt 130 | mapping: 131 | query: query 132 | ``` 133 | 134 | We load the locally retrieval-augmented files we generated in the previous section. 135 | 136 | The `TextPrompter` populates a template file containing placeholders in Python format; see the [short 137 | template](../ragfit/processing/prompts/qa-short.txt). The step replaces the placeholders with variables using a provided 138 | mapping. The result is a string, saved in the keyword specified by `output_key`. 139 | 140 | To run this process: 141 | ```sh 142 | python processing.py -cp configs/paper -cn processing-asqa-baseline 143 | ``` 144 | 145 | ## Context 146 | 147 | Preparing for configurations (1) and (2), we want to augment the examples with the top 5 documents we collected in the 148 | first step. 149 | 150 | ```yaml 151 | - _target_: ragfit.processing.local_steps.context.DocumentsJoiner 152 | inputs: [train, dev] 153 | docs_key: positive_passages 154 | k: 5 155 | 156 | - _target_: ragfit.processing.local_steps.prompter.TextPrompter 157 | inputs: [train, dev] 158 | prompt_file: ragfit/processing/prompts/qa.txt 159 | output_key: prompt 160 | mapping: 161 | question: query 162 | context: positive_passages 163 | ``` 164 | The `DocumentsJoiner` joins a list of strings and is needed before the `TextPrompter` we've seen in the previous 165 | section. We prepare a dev file—for testing the model with retrieved documents—and also a training file, in order 166 | to run fine-tuning. Both configurations will be evaluated on the dev dataset. 167 | 168 | To run this process: 169 | ```sh 170 | python processing.py -cp configs/paper -cn processing-asqa-context 171 | ``` 172 | 173 | ## Chain-of-Thought 174 | 175 | We prepare a dev set with a CoT reasoning prompt.
The configuration is similar to the *Context* configuration; 176 | however, here we use a different prompt template: 177 | 178 | ```yaml 179 | - _target_: ragfit.processing.local_steps.prompter.TextPrompter 180 | inputs: dev 181 | prompt_file: ragfit/processing/prompts/cot.txt 182 | output_key: prompt 183 | mapping: 184 | question: query 185 | context: positive_passages 186 | ``` 187 | 188 | To run this process: 189 | ```sh 190 | python processing.py -cp configs/paper -cn processing-asqa-cot-dev 191 | ``` 192 | 193 | ## Chain-of-Thought Training Dataset 194 | 195 | In order to train a model on a CoT-based prompt, we need to collect well-reasoned responses; we use GPT-4 for that. 196 | Additionally, we implement a technique from RAFT where some percentage of the examples have purely distractor documents, 197 | so that the model learns to filter noise. Here are the relevant steps: 198 | 199 | ```yaml 200 | - _target_: ragfit.processing.local_steps.raft.RAFTStep 201 | inputs: train 202 | k: 5 203 | raft_p: 0.5 204 | neg_docs_num: 2 205 | output_key: raft_docs 206 | ``` 207 | The `RAFTStep` implements the logic presented in the paper; the percentage of purely-distractor documents is defined by 208 | `raft_p`. The list of documents, some relevant, some distracting, is saved in the keyword specified by `output_key`. 209 | 210 | ```yaml 211 | - _target_: ragfit.processing.local_steps.context.DocumentsJoiner 212 | inputs: train 213 | docs_key: raft_docs 214 | k: 215 | 216 | - _target_: ragfit.processing.local_steps.prompter.TextPrompter 217 | inputs: train 218 | prompt_file: ragfit/processing/prompts/cot.txt 219 | output_key: prompt 220 | mapping: 221 | question: query 222 | context: raft_docs 223 | ``` 224 | The documents are joined into strings; when `k` is left empty, all documents are used. The prompt used is the same as when building the dev dataset. 225 | 226 | Next is interacting with OpenAI; we implemented an [OpenAI class](../ragfit/models/openai_executor.py) using Azure, but 227 | one can implement other abstractions. The step itself needs the `prompt_key` and an instruction file; the results are 228 | saved in the `answer_key`. 229 | ```yaml 230 | - _target_: ragfit.processing.local_steps.api.openai.OpenAIChat 231 | inputs: train 232 | prompt_key: prompt 233 | answer_key: generated_answer 234 | instruction: ragfit/processing/prompts/prompt_instructions/qa.txt 235 | model: 236 | azure_endpoint: azure.endpoint.com 237 | api_version: 2024-05-01-preview 238 | model: GPT-4-32k-Bot 239 | ``` 240 | 241 | To run this process: 242 | ```sh 243 | python processing.py -cp configs/paper -cn processing-asqa-cot-train 244 | ``` 245 | -------------------------------------------------------------------------------- /docs/pubmed.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning Phi-3 for PubmedQA 2 | 3 | We will demonstrate the RAG-FiT framework by creating a RAG-augmented dataset, fine-tuning a model and running an evaluation on the [PubmedQA](https://huggingface.co/datasets/bigbio/pubmed_qa) dataset. We will follow the experimentation in the paper, implementing the **RAG-sft** configuration, which consists of creating prompts with relevant context and fine-tuning a model on the completions. 4 | 5 | The [PubmedQA](https://huggingface.co/datasets/bigbio/pubmed_qa) dataset contains relevant context for each question, so there's no need for retrieval—for an example with a retrieval step, see the ASQA processing [tutorial](./processing.md).
6 | 7 | **Notice**: all the configurations mentioned in this guide, implementing the experiments done in the paper, are saved in 8 | `configs/paper/`. They don't run by default, they need to be specified by running: 9 | 10 | ```sh 11 | python module-name.py -cp configs/paper -cn config-name-without-extension 12 | ``` 13 | 14 | 15 | ## RAG Dataset Creation 16 | 17 | We use the 1st module, called `processing.py` to generate the RAG-augmented dataset. To run it: 18 | 19 | ```sh 20 | python processing.py -cp configs/paper -cn processing-pubmed-context 21 | ``` 22 | 23 | Let's analyze the [configuration file](../configs/paper/processing-pubmed-context.yaml) used for the dataset creation: 24 | 25 | ```yaml 26 | name: pubmed_rag 27 | cache: true 28 | output_path: . 29 | ``` 30 | 31 | Start by defining a pipeline name, turning caching on, and specifying the current folder for the output files. 32 | 33 | ```yaml 34 | steps: 35 | - _target_: ragfit.processing.dataset_loaders.loaders.HFLoader 36 | inputs: train 37 | dataset_config: 38 | path: bigbio/pubmed_qa 39 | split: train 40 | 41 | - _target_: ragfit.processing.dataset_loaders.loaders.HFLoader 42 | inputs: test 43 | dataset_config: 44 | path: bigbio/pubmed_qa 45 | name: pubmed_qa_labeled_fold0_source 46 | split: test 47 | ``` 48 | 49 | Next we load a training and test sets from the Hugging Face hub. The `inputs` keyword is used to denote the datasets to be used on the subsequent steps. 50 | 51 | ```yaml 52 | - _target_: ragfit.processing.global_steps.sampling.ShuffleSelect 53 | inputs: train 54 | limit: 50000 55 | 56 | - _target_: ragfit.processing.local_steps.common_datasets.PubMed 57 | inputs: [train, test] 58 | 59 | - _target_: ragfit.processing.local_steps.context.DocumentsJoiner 60 | inputs: [train, test] 61 | docs_key: positive_passages 62 | k: 5 63 | ``` 64 | 65 | Next are 3 technical steps: we limit the size of the training dataset to 50k examples (optional). We do minimal processing of features: namely creating a `query`, `answers` and `positive_passages` features. Finally, we combine `k=5` relevant documents for each example into a string, to be used later in a prompt. 66 | 67 | ```yaml 68 | - _target_: ragfit.processing.local_steps.prompter.TextPrompter 69 | inputs: [train, test] 70 | prompt_file: ragfit/processing/prompts/qa.txt 71 | output_key: prompt 72 | mapping: 73 | question: query 74 | context: positive_passages 75 | ``` 76 | 77 | Next is the prompt generation step; we used a QA prompt with `question` and `context` placeholders. We map the values using the `mapping` keyword. 78 | 79 | > [!IMPORTANT] 80 | > There is no model-dependency in the prompt building. For models/tokenizers supporting a chat format, the prompt is going to be uttered by the *user* role, where the chat, including a system instruction, is constructed only in the training and inference stages. For models/tokenizers not supporting a chat format, a template can be provided by the users, to be used in the training and inference stages. 81 | 82 | Finally we write the results to files. 83 | 84 | ## Training 85 | 86 | Training is done on the generated files. The training configuration has 3 parts: model, training arguments and data. 
87 | 88 | ```yaml 89 | model: 90 | _target_: ragfit.models.hf.HFTrain 91 | model_name_or_path: microsoft/Phi-3-mini-128k-instruct 92 | load_in_4bit: false 93 | load_in_8bit: true 94 | lora: 95 | lora_alpha: 16 96 | lora_dropout: 0.1 97 | peft_type: LORA 98 | r: 16 99 | target_modules: 100 | - qkv_proj 101 | task_type: CAUSAL_LM 102 | completion_start: <|assistant|> 103 | instruction_in_prompt: 104 | max_sequence_len: 2000 105 | ``` 106 | 107 | Model loading is implemented using the `HFTrain` class, which loads models from the HuggingFace hub and uses PEFT adapters. Other classes can be implemented. The important keys here are: `completion_start`, which indicates the beginning of the text on which the loss is calculated. This is model/tokenizer specific. Additionally, there is the `instruction_in_prompt` key, which, if set to *True*, inserts the system instruction in the prompt, for models which do not support a dedicated system role. 108 | 109 | ```yaml 110 | train: 111 | output_dir: ./trained_model/ 112 | gradient_accumulation_steps: 2 113 | learning_rate: 1e-4 114 | logging_steps: 10 115 | lr_scheduler_type: cosine 116 | num_train_epochs: 1 117 | per_device_train_batch_size: 1 118 | optim: paged_adamw_8bit 119 | warmup_ratio: 0.03 120 | weight_decay: 0.001 121 | ``` 122 | 123 | Training is done using the `SFTTrainer` in `TRL`. Training arguments are based on HuggingFace `Trainer`. 124 | 125 | ```yaml 126 | instruction: ragfit/processing/prompts/prompt_instructions/qa-yes-no.txt 127 | template: 128 | data_file: pubmed-rag-train.jsonl 129 | input_key: prompt 130 | output_key: answers 131 | limit: 132 | shuffle: 133 | hfhub_tag: 134 | ``` 135 | 136 | Here are the important keys: 137 | 138 | - The instruction file to use for training (should later be used for inference as well). In the case of PubmedQA, the answers are either Yes or No, so we specify this in the system instruction. 139 | - If the model/tokenizer does not support a chat template, the user needs to provide a custom template; the placeholders to fill are `query` and `output`. 140 | - Data file is the processed file to train on. 141 | - Input key is the prompt. 142 | - Output key is the completion text to learn. 143 | - Limit and shuffle can be used to filter the dataset for debugging purposes. 144 | - The framework can push the trained model to the HF Hub repository specified in `hfhub_tag`. 145 | 146 | We create a training job by running: 147 | 148 | ```sh 149 | python training.py -cp configs/paper -cn training-pubmed 150 | ``` 151 | 152 | ## Inference 153 | 154 | In the inference stage, we take the processed dataset and an LLM and make predictions. The LLM can be fine-tuned. The processed data encapsulates the RAG interactions: pre-processing, retrieval, ranking, prompt creation, and possibly other types of transformations. So this step deals with producing the predictions to be evaluated.
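Conceptually, the inference step iterates over the processed file, generates a completion per prompt, and writes the generations alongside the labels. The following is a rough sketch under assumed names (`generate_fn` is a stand-in for whatever model wrapper is configured), not the framework's actual code:

```python
import json

# Rough sketch of an inference loop; generate_fn is a hypothetical stand-in for
# the configured model wrapper, not a RAG-FiT API.
def run_inference(data_file, generated_file, generate_fn,
                  input_key="prompt", output_key="output", target_key="answers"):
    with open(data_file) as src, open(generated_file, "w") as dst:
        for line in src:
            example = json.loads(line)
            record = {
                output_key: generate_fn(example[input_key]),  # model completion
                target_key: example[target_key],              # label kept for easy debugging
            }
            dst.write(json.dumps(record) + "\n")
```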
155 | 156 | It is simple in nature, described by the following configuration: 157 | 158 | ```yaml 159 | model: 160 | _target_: ragfit.models.hf.HFInference 161 | model_name_or_path: microsoft/Phi-3-mini-128k-instruct 162 | load_in_4bit: false 163 | load_in_8bit: true 164 | device_map: auto 165 | trust_remote_code: true 166 | instruction: ragfit/processing/prompts/prompt_instructions/qa-yes-no.txt 167 | lora_path: ./trained_model/checkpoint 168 | generation: 169 | do_sample: false 170 | max_new_tokens: 50 171 | return_full_text: false 172 | 173 | data_file: pubmed-rag-test.jsonl 174 | generated_file: pubmed-rag-test-generated.jsonl 175 | input_key: prompt 176 | generation_key: output 177 | target_key: answers 178 | limit: 179 | ``` 180 | 181 | The model section deals with details regarding the model loading and generation options. System instruction can be provided, as we mentioned previously: the datasets are model independent, and all model details (system instruction, custom chat template) are needed only during training and inference. Similarly, `instruct_in_prompt` inserts the system instruction inside the prompt, for models which don't support a *system* role. 182 | 183 | Other parameters: 184 | 185 | - Data file is the processed file. 186 | - Generated file is the file that will be created with the completions (and labels, for easy debugging). 187 | - Target key is the label keyword. 188 | - Limit: to a number of examples, for debugging. 189 | 190 | In order to run inference: 191 | 192 | ```sh 193 | python inference.py -cp configs/paper -cn inference-pubmed 194 | ``` 195 | 196 | ## Evaluations 197 | 198 | The evaluation module takes the produced inference file and the original processed dataset and runs a list of evaluations, producing a final results file, in a YAML format. The evaluations are represented as metric classes. 199 | 200 | We implement several metrics including: a wrapper for HuggingFace `evaluate` class, which can accept a list of metrics, EM, F1, classification (accuracy, precision, recall, F1), BERTScore, Semantic similarity (using a customizable cross-encoder). The module can also run metrics from [DeepEval](https://docs.confident-ai.com/docs/getting-started), which offers a large collection of LLM evaluations. 201 | 202 | The configuration for the evaluation looks like this: 203 | 204 | ```yaml 205 | answer_processor: 206 | _target_: ragfit.processing.answer_processors.regex.RegexAnswer 207 | capture_pattern: 208 | stopping_pattern: 209 | 210 | metrics: 211 | - _target_: ragfit.evaluation.metrics.Classification 212 | mapping: 213 | "yes": 1 214 | "no": 0 215 | "maybe": 2 216 | else_value: 2 217 | 218 | key_names: 219 | generated: text 220 | label: answers 221 | query: query 222 | 223 | results_file: evaluation-pubmed-rag.yaml 224 | generated_file: pubmed-rag-test-generated.jsonl 225 | data_file: pubmed-rag-test.jsonl 226 | limit: 227 | ``` 228 | 229 | The evaluation module introduces the concept of an **Answer Processor**. This class can run post-processing on the generated text, including: aligning text with the expect output, implement evaluation-specific formatting, extracting the specific sections, processing meta-data like citations, etc. 230 | 231 | The default processor is called `RegexAnswer`; it can filter text, based on a python regex capture pattern. It can also split text using a stopping pattern. 
For example, in the Chain-of-Thought reasoning we used in the paper, the model is instruction to explain its answer, cite if needed and finally print the final results in the following format `: ...`. We can use this format as a capture pattern; thus models that learn to answer using this pattern (obey the instruction) will score higher. 232 | 233 | For PubmedQA we use a **classification metric**; we provide a mapping of keys and a default key, since the PubmedQA expert annotated test set can contain Yes, No or Maybe, as answers. 234 | 235 | The rest of the arguments are straightforward: 236 | 237 | - Keyword names for input, output and target. 238 | - Name of inference file, name of the processed data. 239 | - Name for the results summary report. 240 | - Limit, for debugging purposes. 241 | 242 | Running the evaluation: 243 | 244 | ```sh 245 | python evaluation.py -cp configs/paper -cn evaluation-pubmed 246 | ``` 247 | 248 | ## Summary 249 | 250 | In this tutorial, we enhanced an LLM to better perform Q&A on the PubmedQA task, by generating a training dataset containing relevant context, fine-tuning and evaluating the model on the testset. By modifying the configurations presented here, one can run an evaluation on an untrained model and see the benefit of RAG. One can implement other RAG techniques; for example, see the ASQA tutorial for a more advanced usecase (as well as more thorough explanations), including external retrieval, OpenAI integration and Chain-of-thought prompting: [data creation](./processing.md), [training](./training.md), [inference](./inference.md) and [evaluation](./evaluation.md). 251 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Intel Corporation 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /ragfit/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | import unicodedata 4 | from collections import Counter, defaultdict 5 | 6 | import regex 7 | 8 | from .base import MetricBase 9 | 10 | 11 | class HFEvaluate(MetricBase): 12 | """ 13 | Wrapper class around `evaluate` metrics; easy to use, only need metric names. 
14 | """ 15 | 16 | def __init__(self, key_names, metric_names: list[str], **kwargs): 17 | """ 18 | Args: 19 | key_names (dict): A dictionary containing the field names. 20 | metric_names (list[str]): A list of metric names. 21 | """ 22 | import evaluate 23 | 24 | super().__init__(key_names, **kwargs) 25 | self.metric_names = metric_names 26 | self.metric = evaluate.combine(metric_names) 27 | self.local = True 28 | 29 | def measure(self, example): 30 | """ 31 | Measure the performance of the model on a given example. 32 | 33 | Args: 34 | example (dict): The example containing input and target values. 35 | 36 | Returns: 37 | dict: The performance metric(s) computed for the example. 38 | """ 39 | input = example[self.field] 40 | target = example[self.target] 41 | 42 | if isinstance(target, list): 43 | results = defaultdict(int) 44 | for tar in target: 45 | results = { 46 | k: max(v, results[k]) 47 | for k, v in self.metric.compute( 48 | predictions=[input], references=[tar] 49 | ).items() 50 | } 51 | return results 52 | else: 53 | return self.metric.compute(predictions=[input], references=[target]) 54 | 55 | 56 | class Classification(MetricBase): 57 | """ 58 | Metrics for classification answers: accuracy, precision, recall, F1; macro-averaged. 59 | 60 | mapping: dict - mapping of labels to integers. 61 | Example: {"true": 1, "false": 0, "maybe": 2} 62 | else_value: int - value to assign to labels not in the mapping. 63 | """ 64 | 65 | def __init__( 66 | self, key_names: dict, mapping: dict, else_value: int = 2, **kwargs 67 | ) -> None: 68 | from sklearn.metrics import accuracy_score, precision_recall_fscore_support 69 | 70 | super().__init__(key_names, **kwargs) 71 | self.local = False 72 | self.mapping = mapping 73 | self.else_value = else_value 74 | self.precision_recall_fn = precision_recall_fscore_support 75 | self.accuracy_fn = accuracy_score 76 | 77 | def in_text(self, text): 78 | if "yes" in text: 79 | return 1 80 | if "no" in text: 81 | return 0 82 | return 2 83 | 84 | def measure(self, example: dict): 85 | inputs = example[self.field] 86 | targets = example[self.target] 87 | 88 | if isinstance(targets[0], list): 89 | targets = [t[0] for t in targets] 90 | 91 | inputs = [self.in_text(normalize_text(i).strip()) for i in inputs] 92 | 93 | targets = [ 94 | self.mapping.get(normalize_text(t).strip(), self.else_value) for t in targets 95 | ] 96 | 97 | precision, recall, f1, _ = self.precision_recall_fn( 98 | targets, inputs, average="macro" 99 | ) 100 | accuracy = self.accuracy_fn(targets, inputs) 101 | 102 | return { 103 | "accuracy": float(accuracy), 104 | "precision": float(precision), 105 | "recall": float(recall), 106 | "f1": float(f1), 107 | } 108 | 109 | 110 | def normalize_text(s): 111 | """ 112 | Normalize the given text by lowercasing it, removing punctuation, articles, and extra whitespace. 113 | 114 | Args: 115 | s (str): The text to be normalized. 116 | 117 | Returns: 118 | str: The normalized text. 119 | """ 120 | 121 | def remove_articles(text): 122 | return re.sub(r"\b(a|an|the)\b", " ", text) 123 | 124 | def white_space_fix(text): 125 | return " ".join(text.split()) 126 | 127 | def remove_punc(text): 128 | exclude = set(string.punctuation) 129 | return "".join(ch for ch in text if ch not in exclude) 130 | 131 | def lower(text): 132 | return text.lower() 133 | 134 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 135 | 136 | 137 | class F1(MetricBase): 138 | """ 139 | Implementing F1 based on code from Kilt. 
140 | """ 141 | 142 | def __init__(self, key_names, **kwargs) -> None: 143 | """Initialize the Metrics class. 144 | 145 | Args: 146 | key_names (dict): A dictionary containing the field names. 147 | """ 148 | super().__init__(key_names, **kwargs) 149 | self.local = True 150 | 151 | @staticmethod 152 | def _f1(prediction, ground_truth): 153 | prediction_tokens = normalize_text(prediction).split() 154 | ground_truth_tokens = normalize_text(ground_truth).split() 155 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 156 | num_same = sum(common.values()) 157 | if num_same == 0: 158 | return 0 159 | precision = 1.0 * num_same / len(prediction_tokens) 160 | recall = 1.0 * num_same / len(ground_truth_tokens) 161 | f1 = (2 * precision * recall) / (precision + recall) 162 | return f1 163 | 164 | def measure(self, example: dict): 165 | input = example[self.field] 166 | target = example[self.target] 167 | 168 | assert isinstance(input, str), f"Generated text should be a string: {input}" 169 | if not isinstance(target, list): 170 | target = [target] 171 | 172 | scores = [self._f1(input, t) for t in target] 173 | return {"F1": max(scores)} 174 | 175 | 176 | class EM(MetricBase): 177 | """ 178 | Implementing Exact Match based on code from Kilt. 179 | """ 180 | 181 | def __init__(self, key_names, **kwargs) -> None: 182 | """Initialize the Metrics class. 183 | 184 | Args: 185 | key_names (dict): A dictionary containing the field names. 186 | """ 187 | super().__init__(key_names, **kwargs) 188 | self.local = True 189 | 190 | def measure(self, example: dict): 191 | input = example[self.field] 192 | target = example[self.target] 193 | 194 | assert isinstance(input, str), f"Generated text should be a string: {input}" 195 | if not isinstance(target, list): 196 | target = [target] 197 | 198 | scores = [normalize_text(input) == normalize_text(t) for t in target] 199 | return {"EM": int(max(scores))} 200 | 201 | 202 | class StringEM(MetricBase): 203 | """ 204 | Implementing String Exact Match. 205 | 206 | Used in ASQA to evaluate whether the annotated short answers appear in the 207 | generated answer as sub-strings. 208 | """ 209 | 210 | def __init__(self, key_names: dict, **kwargs) -> None: 211 | """ 212 | Initialize the Metrics class. 213 | 214 | Args: 215 | key_names (dict): A dictionary containing the field names. 216 | """ 217 | super().__init__(key_names, **kwargs) 218 | self.local = True 219 | 220 | def measure(self, example: dict): 221 | input = example[self.field] 222 | target = example[self.target] 223 | 224 | assert isinstance(input, str), f"Generated text should be a string: {input}" 225 | assert isinstance(target[0], list), f"Target should be a list of lists: {target}" 226 | 227 | input = normalize_text(input) 228 | scores = [any(cand in input for cand in item) for item in target] 229 | 230 | return {"StringEM": sum(scores) / len(scores)} 231 | 232 | 233 | class SimpleTokenizer(object): 234 | ALPHA_NUM = r"[\p{L}\p{N}\p{M}]+" 235 | NON_WS = r"[^\p{Z}\p{C}]" 236 | 237 | def __init__(self): 238 | """ 239 | Args: 240 | annotators: None or empty set (only tokenizes).
241 | """ 242 | self._regexp = regex.compile( 243 | "(%s)|(%s)" % (self.ALPHA_NUM, self.NON_WS), 244 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE, 245 | ) 246 | 247 | def tokenize(self, text, uncased=False): 248 | matches = [m for m in self._regexp.finditer(text)] 249 | if uncased: 250 | tokens = [m.group().lower() for m in matches] 251 | else: 252 | tokens = [m.group() for m in matches] 253 | return tokens 254 | 255 | 256 | class RecallEM(MetricBase): 257 | """ 258 | Implementing EM as in XRAG. 259 | """ 260 | 261 | def __init__(self, key_names, **kwargs) -> None: 262 | """Initialize the Metrics class. 263 | 264 | Args: 265 | key_names (dict): A dictionary containing the field names. 266 | """ 267 | super().__init__(key_names, **kwargs) 268 | self.local = True 269 | 270 | @staticmethod 271 | def _normalize(text): 272 | return unicodedata.normalize("NFD", text) 273 | 274 | def has_answer(self, answers, text, tokenizer=SimpleTokenizer()): 275 | """Check if a document contains an answer string.""" 276 | text = self._normalize(text) 277 | text = tokenizer.tokenize(text, uncased=True) 278 | 279 | for answer in answers: 280 | answer = self._normalize(answer) 281 | answer = tokenizer.tokenize(answer, uncased=True) 282 | for i in range(0, len(text) - len(answer) + 1): 283 | if answer == text[i : i + len(answer)]: 284 | return True 285 | return False 286 | 287 | def measure(self, example: dict): 288 | input = example[self.field] 289 | target = example[self.target] 290 | 291 | assert isinstance(input, str), f"Generated text should be a string: {input}" 292 | 293 | if not isinstance(target, list): 294 | target = [target] 295 | 296 | scores = self.has_answer(target, input) 297 | return {"recallEM": int(scores)} 298 | 299 | 300 | class BERTScore(MetricBase): 301 | """ 302 | BERTScore metric, based on the BERTScore library. 303 | """ 304 | 305 | def __init__(self, key_names: dict, model="microsoft/deberta-large-mnli", **kwargs): 306 | """Initialize the Metrics class. 307 | 308 | Args: 309 | key_names (dict): A dictionary containing the field names. 310 | model (str, optional): The name of the BERT model to use. Defaults to "microsoft/deberta-large-mnli". 311 | """ 312 | super().__init__(key_names, **kwargs) 313 | from bert_score import BERTScorer 314 | 315 | self.scorer = BERTScorer(model, lang="en", rescale_with_baseline=True) 316 | self.local = True 317 | 318 | def measure(self, example): 319 | input = example[self.field] 320 | target = example[self.target] 321 | 322 | if not isinstance(target, list): 323 | target = [target] 324 | 325 | scores = [self.scorer.score([input], [t])[2].item() for t in target] 326 | 327 | return {"BERTScore-F1": max(scores)} 328 | 329 | 330 | class Semantic(MetricBase): 331 | """ 332 | Semantic similarity between label and answer using a cross-encoder. 333 | """ 334 | 335 | def __init__( 336 | self, 337 | key_names: dict, 338 | model: str = "vectara/hallucination_evaluation_model", 339 | **kwargs, 340 | ) -> None: 341 | """ 342 | Initializes an instance of the class. 343 | 344 | Args: 345 | key_names (dict): A dictionary containing the field names. 346 | model (str, optional): The name of the BERT model to use. 
347 | """ 348 | super().__init__(key_names, **kwargs) 349 | 350 | from sentence_transformers import CrossEncoder 351 | 352 | self.model = CrossEncoder(model) 353 | self.local = True 354 | 355 | def measure(self, example): 356 | input = example[self.field] 357 | target = example[self.target] 358 | if not isinstance(target, list): 359 | target = [target] 360 | 361 | scores = self.model.predict([[input, t] for t in target]) 362 | 363 | return {"Semantic": max(scores)} 364 | --------------------------------------------------------------------------------
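The following is a minimal, hypothetical usage sketch of the per-example metrics defined above (F1 and EM); it is not part of the repository. The key_names mapping and the field names "generated" and "answers" are assumptions: the keys that MetricBase actually expects are defined in ragfit/evaluation/base.py and are normally supplied through the evaluation YAML configuration, so check those before relying on this exact shape.

# Illustrative sketch only (not part of the repository). The key_names
# mapping below is an assumption about how MetricBase resolves the
# prediction and target fields; see ragfit/evaluation/base.py and the
# evaluation configs for the actual keys.
from ragfit.evaluation.metrics import EM, F1

key_names = {"generated": "generated", "label": "answers"}  # assumed keys

example = {
    "generated": "Paris is the capital of France.",
    "answers": ["Paris", "the capital is Paris"],
}

f1 = F1(key_names)
em = EM(key_names)

print(f1.measure(example))  # {"F1": best token-level F1 over the answer list}
print(em.measure(example))  # {"EM": 0 or 1}

In the repository these metrics are typically instantiated from the evaluation configuration (e.g., configs/evaluation.yaml) via evaluation.py rather than constructed by hand as above.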