├── rag_experiment_accelerator
│   ├── __init__.py
│   ├── io
│   │   ├── __init__.py
│   │   ├── local
│   │   │   ├── __init__.py
│   │   │   ├── loaders
│   │   │   │   ├── __init__.py
│   │   │   │   ├── tests
│   │   │   │   │   ├── test_local_loader.py
│   │   │   │   │   └── test_jsonl_loader.py
│   │   │   │   ├── jsonl_loader.py
│   │   │   │   └── local_loader.py
│   │   │   ├── writers
│   │   │   │   ├── __init__.py
│   │   │   │   ├── tests
│   │   │   │   │   └── test_jsonl_writer.py
│   │   │   │   └── jsonl_writer.py
│   │   │   ├── base.py
│   │   │   └── tests
│   │   │       └── test_local_io_base.py
│   │   ├── exceptions.py
│   │   ├── loader.py
│   │   └── writer.py
│   ├── llm
│   │   ├── __init__.py
│   │   ├── exceptions.py
│   │   ├── prompts_text
│   │   │   ├── llm_answer_relevance_instruction.txt
│   │   │   ├── prompt_instruction_keywords.txt
│   │   │   ├── prompt_instruction_summary.txt
│   │   │   ├── main_instruction_short.txt
│   │   │   ├── prompt_instruction_title.txt
│   │   │   ├── prompt_instruction_entities.txt
│   │   │   ├── llm_context_precision_instruction.txt
│   │   │   ├── prompt_generate_hypothetical_questions.txt
│   │   │   ├── rerank_prompt_instruction.txt
│   │   │   ├── prompt_generate_hypothetical_answer.txt
│   │   │   ├── do_need_multiple_prompt_instruction.txt
│   │   │   ├── main_instruction_long.txt
│   │   │   ├── prompt_generate_hypothetical_document.txt
│   │   │   └── multiple_prompt_instruction.txt
│   │   └── prompt
│   │       ├── rerank_prompts.py
│   │       ├── multiprompts.py
│   │       ├── hyde_prompts.py
│   │       ├── __init__.py
│   │       ├── ragas_prompts.py
│   │       ├── instruction_prompts.py
│   │       └── qna_prompts.py
│   ├── nlp
│   │   └── __init__.py
│   ├── artifact
│   │   ├── __init__.py
│   │   ├── handlers
│   │   │   ├── __init__.py
│   │   │   ├── exceptions.py
│   │   │   ├── typing.py
│   │   │   └── tests
│   │   │       └── test_artifact_handler.py
│   │   └── models
│   │       ├── __init__.py
│   │       └── query_output.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── search_config.py
│   │   ├── openai_config.py
│   │   ├── embedding_model_config.py
│   │   ├── rerank_config.py
│   │   ├── query_expansion.py
│   │   ├── eval_config.py
│   │   ├── language_config.py
│   │   ├── chunking_config.py
│   │   ├── sampling_config.py
│   │   ├── paths.py
│   │   ├── config_validator.py
│   │   └── path_config.py
│   ├── embedding
│   │   ├── __init__.py
│   │   ├── factory.py
│   │   ├── embedding_model.py
│   │   ├── tests
│   │   │   ├── test_st_embedding_model.py
│   │   │   ├── test_factory.py
│   │   │   └── test_aoai_embedding_model.py
│   │   └── aoai_embedding_model.py
│   ├── reranking
│   │   ├── __init__.py
│   │   └── reranker.py
│   ├── sampling
│   │   ├── __init__.py
│   │   └── tests
│   │       └── data
│   │           └── test1.txt
│   ├── utils
│   │   ├── __init__.py
│   │   ├── auth.py
│   │   ├── logging.py
│   │   └── timetook.py
│   ├── data_assets
│   │   ├── __init__.py
│   │   └── data_asset.py
│   ├── doc_loader
│   │   ├── __init__.py
│   │   ├── tests
│   │   │   ├── test_data
│   │   │   │   └── json
│   │   │   │       ├── data.bad.not_a_list.json
│   │   │   │       ├── data.bad.invalid_keys.json
│   │   │   │       └── data.valid.json
│   │   │   ├── test_custom_html_loader.py
│   │   │   └── test_docx_loader.py
│   │   ├── docxLoader.py
│   │   ├── markdownLoader.py
│   │   ├── textLoader.py
│   │   ├── jsonLoader.py
│   │   ├── htmlLoader.py
│   │   ├── customJsonLoader.py
│   │   └── structuredLoader.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── tests
│   │   │   ├── test_transformer_based_metrics.py
│   │   │   ├── test_spacy_evaluator.py
│   │   │   └── test_search_eval.py
│   │   ├── spacy_evaluator.py
│   │   ├── search_eval.py
│   │   └── transformer_based_metrics.py
│   ├── ingest_data
│   │   └── __init__.py
│   ├── init_Index
│   │   └── __init__.py
│   ├── search_type
│   │   └── __init__.py
│   ├── checkpoint
│   │   ├── __init__.py
│   │   ├── null_checkpoint.py
│   │   ├── tests
│   │   │   ├── test_null_checkpoint.py
│   │   │   ├── test_local_storage_checkpoint.py
│   │   │   └── test_checkpoint.py
│   │   ├── checkpoint_factory.py
│   │   ├── checkpoint_decorator.py
│   │   ├── local_storage_checkpoint.py
│   │   └── README.md
│   └── run
│       ├── tests
│       │   ├── data
│       │   │   └── test_data.jsonl
│       │   └── test_qa_generation.py
│       ├── evaluation.py
│       └── qa_generation.py
├── .coveragerc
├── .flake8
├── docs
│   ├── azd.png
│   ├── launch.png
│   ├── wsl.md
│   └── configs-appendix.md
├── images
│   ├── elbow_5.png
│   ├── map_at_k.png
│   ├── view_logs.png
│   ├── map_scores.png
│   ├── search_chart.png
│   ├── sample_metric.png
│   ├── hyper_parameters.png
│   ├── metric_analysis.png
│   ├── metric_comparison.png
│   ├── view_list_of_runs.png
│   ├── AzureMLPipeline.drawio.png
│   ├── create_access_policies.png
│   ├── create_compute_cluster.png
│   ├── compare_metrics_for_runs.png
│   ├── view_logs_parallel_step.png
│   ├── azureml_pipeline_overview.png
│   └── all_cluster_predictions_cluster_number_5.jpg
├── data
│   ├── pdf
│   │   └── sample-pdf.pdf
│   ├── docx
│   │   └── sample-docx.docx
│   ├── text
│   │   └── sample-text.txt
│   └── json
│       └── sample-json.json
├── .azureml
│   └── config.json
├── data-ci
│   └── docx
│       └── sample-docx.docx
├── dev-requirements.txt
├── promptflow
│   └── rag-experiment-accelerator
│       ├── setup
│       │   ├── flow.dag.yaml
│       │   └── setup_env.py
│       ├── images
│       │   ├── upload_local_flow.png
│       │   └── end_to_end_flow_diagram.png
│       ├── custom_environment
│       │   ├── environment.yaml
│       │   ├── rag_experiment_accelerator-0.9-py3-none-any.whl
│       │   └── Dockerfile
│       ├── querying
│       │   ├── flow.dag.yaml
│       │   └── querying.py
│       ├── evaluation
│       │   ├── flow.dag.yaml
│       │   └── evaluation.py
│       ├── qa_generation
│       │   ├── flow.dag.yaml
│       │   └── generate_qa.py
│       ├── index
│       │   ├── flow.dag.yaml
│       │   └── create_index.py
│       ├── flow.dag.yaml
│       └── env_setup.md
├── experimental
│   └── readme.md
├── setup.cfg
├── azure.yaml
├── .vscode
│   ├── settings.json
│   └── launch.json
├── infra
│   ├── main.bicepparam
│   ├── generate_arm_template.sh
│   ├── shared
│   │   ├── machineLearning.bicep
│   │   ├── keyvault-secret.bicep
│   │   ├── keyvault.bicep
│   │   ├── monitoring.bicep
│   │   ├── search-services.bicep
│   │   ├── cognitiveservices.bicep
│   │   ├── storekeys.bicep
│   │   └── storage.bicep
│   └── network
│       ├── azure_bastion.bicep
│       └── network_isolation.bicep
├── pyproject.toml
├── .devcontainer
│   ├── post-create.sh
│   └── devcontainer.json
├── setup.py
├── CODE_OF_CONDUCT.md
├── .github
│   ├── dependabot.yml
│   ├── actions
│   │   └── configure_azureml_agent
│   │       └── action.yml
│   └── workflows
│       ├── build_validation_workflow.yml
│       ├── rag_exp_acc_ci.yml
│       └── config.json
├── 02_qa_generation.py
├── requirements.txt
├── .pre-commit-config.yaml
├── env_to_keyvault.py
├── cspell.json
├── SUPPORT.md
├── 04_evaluation.py
├── .env.template
├── 03_querying.py
├── 01_index.py
├── azureml
│   ├── index.py
│   ├── eval.py
│   └── query.py
├── SECURITY.md
└── config.sample.json

/rag_experiment_accelerator/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/io/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/llm/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/nlp/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/artifact/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/config/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/embedding/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/io/local/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/reranking/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/sampling/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
[run]
omit =
    */__init__.py
--------------------------------------------------------------------------------
/rag_experiment_accelerator/data_assets/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/doc_loader/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/evaluation/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/ingest_data/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/init_Index/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/search_type/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/artifact/handlers/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/artifact/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/io/local/loaders/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/io/local/writers/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rag_experiment_accelerator/sampling/tests/data/test1.txt:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
[flake8]
max-line-length = 120
extend-ignore = E203, E501
--------------------------------------------------------------------------------
/docs/azd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/docs/azd.png
--------------------------------------------------------------------------------
/docs/launch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/docs/launch.png
--------------------------------------------------------------------------------
/images/elbow_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/elbow_5.png
--------------------------------------------------------------------------------
/rag_experiment_accelerator/llm/exceptions.py:
--------------------------------------------------------------------------------
class ContentFilteredException(Exception):
    pass
--------------------------------------------------------------------------------
/images/map_at_k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/map_at_k.png
--------------------------------------------------------------------------------
/images/view_logs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/view_logs.png
--------------------------------------------------------------------------------
/data/pdf/sample-pdf.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/data/pdf/sample-pdf.pdf
--------------------------------------------------------------------------------
/images/map_scores.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/map_scores.png
--------------------------------------------------------------------------------
/images/search_chart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/search_chart.png
--------------------------------------------------------------------------------
/.azureml/config.json:
--------------------------------------------------------------------------------
{
    "workspace_name": "",
    "resource_group": "",
    "subscription_id": ""
}
--------------------------------------------------------------------------------
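
The empty fields in .azureml/config.json above are meant to be filled with your own workspace identifiers; the Azure ML SDK then discovers the file automatically. A minimal sketch of that lookup, assuming the azure-ai-ml package and a populated config file:

from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

# from_config() walks up from the working directory looking for .azureml/config.json.
ml_client = MLClient.from_config(credential=DefaultAzureCredential())
print(ml_client.workspace_name)
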
/images/sample_metric.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/sample_metric.png
--------------------------------------------------------------------------------
/data/docx/sample-docx.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/data/docx/sample-docx.docx
--------------------------------------------------------------------------------
/images/hyper_parameters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/hyper_parameters.png
--------------------------------------------------------------------------------
/images/metric_analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/metric_analysis.png
--------------------------------------------------------------------------------
/images/metric_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/metric_comparison.png
--------------------------------------------------------------------------------
/images/view_list_of_runs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/view_list_of_runs.png
--------------------------------------------------------------------------------
/data-ci/docx/sample-docx.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/data-ci/docx/sample-docx.docx
--------------------------------------------------------------------------------
/images/AzureMLPipeline.drawio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/AzureMLPipeline.drawio.png
--------------------------------------------------------------------------------
/images/create_access_policies.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/create_access_policies.png
--------------------------------------------------------------------------------
/images/create_compute_cluster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/create_compute_cluster.png
--------------------------------------------------------------------------------
/images/compare_metrics_for_runs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/compare_metrics_for_runs.png
--------------------------------------------------------------------------------
/images/view_logs_parallel_step.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/view_logs_parallel_step.png
--------------------------------------------------------------------------------
/images/azureml_pipeline_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/azureml_pipeline_overview.png
--------------------------------------------------------------------------------
/dev-requirements.txt:
--------------------------------------------------------------------------------
promptflow==1.15.0
promptflow-tools==1.4.0
pytest==8.3.3
pytest-cov==5.0.0
flake8==7.1.1
pre-commit==3.8.0
black==24.8.0
--------------------------------------------------------------------------------
/images/all_cluster_predictions_cluster_number_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/images/all_cluster_predictions_cluster_number_5.jpg
--------------------------------------------------------------------------------
/rag_experiment_accelerator/doc_loader/tests/test_data/json/data.bad.not_a_list.json:
--------------------------------------------------------------------------------
{
    "title": "Data should be a list",
    "content": "This is the content for item 1."
}
--------------------------------------------------------------------------------
/promptflow/rag-experiment-accelerator/setup/flow.dag.yaml:
--------------------------------------------------------------------------------
inputs: {}
outputs: {}
nodes:
- name: setup_env
  type: python
  source:
    type: code
    path: setup_env.py
--------------------------------------------------------------------------------
/promptflow/rag-experiment-accelerator/images/upload_local_flow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/promptflow/rag-experiment-accelerator/images/upload_local_flow.png
--------------------------------------------------------------------------------
/experimental/readme.md:
--------------------------------------------------------------------------------

RAG EXPERIMENT ACCELERATOR EXPERIMENTAL

This is the experimental version of the RAG Experiment Accelerator. It is a work in progress and is not yet ready for production use.
--------------------------------------------------------------------------------
/promptflow/rag-experiment-accelerator/images/end_to_end_flow_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/promptflow/rag-experiment-accelerator/images/end_to_end_flow_diagram.png
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[metadata]
name = RAG Experiment Accelerator
version = 0.9
description = A tool to accelerate the process of running experiments with RAG

[options]
python_requires = >=3.11, <4
--------------------------------------------------------------------------------
/azure.yaml:
--------------------------------------------------------------------------------
# yaml-language-server: $schema=https://raw.githubusercontent.com/Azure/azure-dev/main/schemas/v1.0/azure.yaml.json

name: rag-experiment-accelerator
metadata:
  template: azd-init@1.6.1
--------------------------------------------------------------------------------
/promptflow/rag-experiment-accelerator/custom_environment/environment.yaml:
--------------------------------------------------------------------------------
$schema: https://azuremlschemas.azureedge.net/latest/environment.schema.json
name: rag-experiment-accelerator-environment
build:
  path: ./
--------------------------------------------------------------------------------
/rag_experiment_accelerator/doc_loader/tests/test_data/json/data.bad.invalid_keys.json:
--------------------------------------------------------------------------------
[
    {
        "title": "Contains 'contents' key instead of 'content'",
        "contents": "This is the content for item 1."
    }
]
--------------------------------------------------------------------------------
/rag_experiment_accelerator/artifact/handlers/exceptions.py:
--------------------------------------------------------------------------------
class LoadException(Exception):
    def __init__(self, path: str):
        super().__init__(
            f"Cannot load file at path: {path}. Please ensure it is supported by the loader."
        )
--------------------------------------------------------------------------------
/promptflow/rag-experiment-accelerator/custom_environment/rag_experiment_accelerator-0.9-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/HEAD/promptflow/rag-experiment-accelerator/custom_environment/rag_experiment_accelerator-0.9-py3-none-any.whl
--------------------------------------------------------------------------------
/rag_experiment_accelerator/artifact/handlers/typing.py:
--------------------------------------------------------------------------------
from typing import TypeVar
from rag_experiment_accelerator.io.loader import Loader
from rag_experiment_accelerator.io.writer import Writer

T = TypeVar("T", bound=Writer)
U = TypeVar("U", bound=Loader)
--------------------------------------------------------------------------------
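
The two TypeVars above let artifact handlers be written generically over a concrete Writer/Loader pair. A sketch of how they might be used; the ArtifactHandler class below is illustrative, not the repository's actual handler:

from typing import Generic

from rag_experiment_accelerator.artifact.handlers.typing import T, U


class ArtifactHandler(Generic[T, U]):
    # Pairs any Writer subclass (T) with any Loader subclass (U).
    def __init__(self, writer: T, loader: U):
        self.writer = writer
        self.loader = loader
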
/.vscode/settings.json:
--------------------------------------------------------------------------------
{
    "python.testing.pytestArgs": [
        "rag_experiment_accelerator"
    ],
    "files.exclude": {
--------------------------------------------------------------------------------
/rag_experiment_accelerator/llm/prompts_text/prompt_instruction_summary.txt:
--------------------------------------------------------------------------------
:
User:
During the 19th century, industrialization led to significant urban growth, changes in employment patterns, and advancements in transportation technologies like railways and steamships.

Assistant:
19th-century industrialization spurred urban growth and transportation advancements.
--------------------------------------------------------------------------------
/promptflow/rag-experiment-accelerator/qa_generation/flow.dag.yaml:
--------------------------------------------------------------------------------
inputs:
  config_dir:
    type: string
    default: ../
outputs: {}
nodes:
- name: setup_env
  type: python
  source:
    type: code
    path: ../setup/setup_env.py
  inputs:
    connection: ""
- name: generate_qa
  type: python
  source:
    type: code
    path: generate_qa.py
  inputs:
    config_dir: ${inputs.config_dir}
  activate:
    when: ${setup_env.output}
    is: true
--------------------------------------------------------------------------------
/rag_experiment_accelerator/llm/prompts_text/main_instruction_short.txt:
--------------------------------------------------------------------------------
You provide answers to questions based solely on the information provided below.
Answer precisely and concisely, addressing only what is asked without extraneous details.
If the information needed to answer isn't available in the provided context, respond with "I don't know."
Cite specific sources by filename whenever you reference data or excerpts from the provided context.

Input format:

Context:
{context}

Question:
{question}
--------------------------------------------------------------------------------
/rag_experiment_accelerator/config/eval_config.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass, field
from rag_experiment_accelerator.config.base_config import BaseConfig


@dataclass
class EvalConfig(BaseConfig):
    metric_types: list[str] = field(
        default_factory=lambda: [
            "fuzzy_score",
            "bert_all_MiniLM_L6_v2",
            "cosine_ochiai",
            "bert_distilbert_base_nli_stsb_mean_tokens",
            "llm_answer_relevance",
            "llm_context_precision",
        ]
    )
--------------------------------------------------------------------------------
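
Because EvalConfig is a plain dataclass, narrowing the metric set for a cheaper evaluation run is a one-line override; the two metric names below are taken from the default list above:

from rag_experiment_accelerator.config.eval_config import EvalConfig

# Override the default metric list with a reduced set.
eval_config = EvalConfig(metric_types=["fuzzy_score", "llm_answer_relevance"])
print(eval_config.metric_types)
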
/rag_experiment_accelerator/run/tests/data/test_data.jsonl:
--------------------------------------------------------------------------------
{"user_prompt":"What happens when there is a lack of dopamine in the brain?","output_prompt":"When there is a lack of dopamine in the brain, it can lead to movement disorders such as Parkinson's disease.","context":"Normally, there are brain cells (neurons) in the human brain that produce dopamine. These neurons concentrate in a particular area of the brain, called the substantia nigra. Dopamine is a chemical that relays messages between the substantia nigra and other parts of the brain to control movements of the human body."}
--------------------------------------------------------------------------------
/rag_experiment_accelerator/io/local/base.py:
--------------------------------------------------------------------------------
import os


class LocalIOBase:
    """
    Base class for local input/output operations.
    """

    def exists(self, path: str) -> bool:
        """
        Check if a file or directory exists at the given path.

        Args:
            path (str): The path to check.

        Returns:
            bool: True if the file or directory exists, False otherwise.
        """
        return os.path.exists(path)
--------------------------------------------------------------------------------
/rag_experiment_accelerator/llm/prompts_text/prompt_instruction_title.txt:
--------------------------------------------------------------------------------
Identify and provide an appropriate title for the given user text in a single sentence, ensuring the title is between 10 to 15 words long. Do not format the output as a list or include any additional text or metadata.

:
User:
Exploring the impacts of climate change on global agriculture, this article discusses the shifts in crop yields, changes in weather patterns, and their effects on farming techniques.

Assistant:
Climate Change Effects on Global Agriculture and Crop Production
--------------------------------------------------------------------------------
/rag_experiment_accelerator/llm/prompts_text/prompt_instruction_entities.txt:
--------------------------------------------------------------------------------
Identify the key entities (person, organization, location, date, year, brand, geography, proper nouns, month, etc.) from the given user text. Output the entities as a JSON list, including only the entities without any additional text or metadata.

:

User:
In March 2021, Apple Inc. released the iPhone 12 in Cupertino, California, which featured significant improvements in battery life and processing power.

Assistant:
["March 2021", "Apple Inc.", "iPhone 12", "Cupertino", "California"]
--------------------------------------------------------------------------------
/promptflow/rag-experiment-accelerator/querying/querying.py:
--------------------------------------------------------------------------------
from promptflow import tool
from rag_experiment_accelerator.run.querying import run
from rag_experiment_accelerator.config.environment import Environment
from rag_experiment_accelerator.config.config import Config


@tool
def my_python_tool(config_path: str) -> bool:
    environment = Environment.from_env_or_keyvault()
    config = Config.from_path(environment, config_path)

    for index_config in config.index.flatten():
        run(environment, config, index_config)
    return True
--------------------------------------------------------------------------------
/promptflow/rag-experiment-accelerator/index/flow.dag.yaml:
--------------------------------------------------------------------------------
inputs:
  config_dir:
    type: string
    default: ../
  should_index:
    type: bool
    default: true
outputs: {}
nodes:
- name: setup_env
  type: python
  source:
    type: code
    path: ../setup/setup_env.py
  inputs:
    connection: ""
- name: create_index
  type: python
  source:
    type: code
    path: create_index.py
  inputs:
    should_index: ${inputs.should_index}
    config_dir: ${inputs.config_dir}
  activate:
    when: ${setup_env.output}
    is: true
--------------------------------------------------------------------------------
8 | """ 9 | 10 | def __init__(self): 11 | pass 12 | 13 | def _has_data(self, id: str, method) -> bool: 14 | return False 15 | 16 | def _load(self, id: str, method) -> Any: 17 | pass 18 | 19 | def _save(self, data: Any, id: str, method): 20 | pass 21 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/io/exceptions.py: -------------------------------------------------------------------------------- 1 | class WriteException(Exception): 2 | def __init__(self, path: str, e: Exception): 3 | super().__init__( 4 | f"Unable to write to file to path: {path}. Please ensure" 5 | " you have the proper permissions to write to the file.", 6 | e, 7 | ) 8 | 9 | 10 | class CopyException(Exception): 11 | def __init__(self, src: str, dest: str, e: Exception): 12 | super().__init__( 13 | f"Unable to copy file from {src} to {dest}. Please ensure" 14 | " you have the proper permissions to copy the file.", 15 | e, 16 | ) 17 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/embedding/factory.py: -------------------------------------------------------------------------------- 1 | from rag_experiment_accelerator.embedding.aoai_embedding_model import AOAIEmbeddingModel 2 | from rag_experiment_accelerator.embedding.st_embedding_model import STEmbeddingModel 3 | 4 | 5 | def create_embedding_model(model_type: str, **kwargs): 6 | match model_type: 7 | case "azure": 8 | return AOAIEmbeddingModel(**kwargs) 9 | case "sentence-transformer": 10 | return STEmbeddingModel(**kwargs) 11 | case _: 12 | raise ValueError( 13 | f"Invalid embedding type: {model_type}. Must be one of ['azure', 'sentence-transformer']" 14 | ) 15 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/io/local/tests/test_local_io_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | import pytest 5 | 6 | from rag_experiment_accelerator.io.local.base import LocalIOBase 7 | 8 | 9 | @pytest.fixture() 10 | def temp_dir(): 11 | dir = tempfile.mkdtemp() 12 | yield dir 13 | if os.path.exists(dir): 14 | shutil.rmtree(dir) 15 | 16 | 17 | def test_exists_true(temp_dir: str) -> bool: 18 | loader = LocalIOBase() 19 | assert loader.exists(temp_dir) is True 20 | 21 | 22 | def test_exists_false() -> bool: 23 | path = "/tmp/non-existing-file" 24 | loader = LocalIOBase() 25 | assert loader.exists(path) is False 26 | -------------------------------------------------------------------------------- /promptflow/rag-experiment-accelerator/qa_generation/generate_qa.py: -------------------------------------------------------------------------------- 1 | from promptflow import tool 2 | from rag_experiment_accelerator.run.qa_generation import run 3 | from rag_experiment_accelerator.config.config import Config 4 | from rag_experiment_accelerator.config.environment import Environment 5 | from rag_experiment_accelerator.config.paths import get_all_file_paths 6 | 7 | 8 | @tool 9 | def my_python_tool(config_path: str, should_generate_qa: bool) -> bool: 10 | environment = Environment.from_env_or_keyvault() 11 | config = Config.from_path(environment, config_path) 12 | 13 | if should_generate_qa: 14 | run(environment, config, get_all_file_paths(config.path.data_dir)) 15 | return True 16 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/config/language_config.py: 
/rag_experiment_accelerator/io/local/tests/test_local_io_base.py:
--------------------------------------------------------------------------------
import os
import shutil
import tempfile
import pytest

from rag_experiment_accelerator.io.local.base import LocalIOBase


@pytest.fixture()
def temp_dir():
    dir = tempfile.mkdtemp()
    yield dir
    if os.path.exists(dir):
        shutil.rmtree(dir)


def test_exists_true(temp_dir: str):
    loader = LocalIOBase()
    assert loader.exists(temp_dir) is True


def test_exists_false():
    path = "/tmp/non-existing-file"
    loader = LocalIOBase()
    assert loader.exists(path) is False
--------------------------------------------------------------------------------
/promptflow/rag-experiment-accelerator/qa_generation/generate_qa.py:
--------------------------------------------------------------------------------
from promptflow import tool
from rag_experiment_accelerator.run.qa_generation import run
from rag_experiment_accelerator.config.config import Config
from rag_experiment_accelerator.config.environment import Environment
from rag_experiment_accelerator.config.paths import get_all_file_paths


@tool
def my_python_tool(config_path: str, should_generate_qa: bool) -> bool:
    environment = Environment.from_env_or_keyvault()
    config = Config.from_path(environment, config_path)

    if should_generate_qa:
        run(environment, config, get_all_file_paths(config.path.data_dir))
    return True
--------------------------------------------------------------------------------
/rag_experiment_accelerator/config/language_config.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass, field
from typing import Any

from rag_experiment_accelerator.config.base_config import BaseConfig


@dataclass
class LanguageAnalyzerConfig(BaseConfig):
    analyzer_name: str = "en.microsoft"
    index_analyzer_name: str = ""
    search_analyzer_name: str = ""
    char_filters: list[Any] = field(default_factory=list)
    tokenizers: list[Any] = field(default_factory=list)
    token_filters: list[Any] = field(default_factory=list)


@dataclass
class LanguageConfig(BaseConfig):
    analyzer: LanguageAnalyzerConfig = field(default_factory=LanguageAnalyzerConfig)
    query_language: str = "en-us"
--------------------------------------------------------------------------------
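
A short sketch of overriding these defaults for a non-English index; the analyzer and query-language values are illustrative (valid names are listed in the Azure AI Search documentation):

from rag_experiment_accelerator.config.language_config import (
    LanguageAnalyzerConfig,
    LanguageConfig,
)

# French analyzer for the index, matching query language for the service.
language_config = LanguageConfig(
    analyzer=LanguageAnalyzerConfig(analyzer_name="fr.microsoft"),
    query_language="fr-fr",
)
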
/rag_experiment_accelerator/checkpoint/tests/test_null_checkpoint.py:
--------------------------------------------------------------------------------
import unittest

from rag_experiment_accelerator.checkpoint.null_checkpoint import NullCheckpoint


def dummy(word):
    return f"hello {word}"


class TestNullCheckpoint(unittest.TestCase):
    def test_wrapped_method_is_not_cached(self):
        checkpoint = NullCheckpoint()
        data_id = "unique_id"
        result1 = checkpoint.load_or_run(dummy, data_id, "first run")
        result2 = checkpoint.load_or_run(dummy, data_id, "second run")
        self.assertEqual(result1, "hello first run")
        self.assertEqual(result2, "hello second run")


if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
{
    "name": "RAG Experiment Accelerator",
    "image": "mcr.microsoft.com/devcontainers/python:1-3.11-bullseye",
    "features": {
        "ghcr.io/devcontainers/features/azure-cli:1": {
            "version": "latest"
        },
        "ghcr.io/azure/azure-dev/azd:latest": {}
    },
    "postCreateCommand": "./.devcontainer/post-create.sh",
    "customizations": {
        "vscode": {
            "extensions": [
                "github.vscode-pull-request-github",
                "ms-vscode.azure-account",
                "ms-python.python",
                "ms-python.flake8",
                "ms-azuretools.vscode-bicep",
                "prompt-flow.prompt-flow",
                "ms-azuretools.azure-dev",
                "streetsidesoftware.code-spell-checker"
            ]
        }
    }
}
--------------------------------------------------------------------------------
/.github/actions/configure_azureml_agent/action.yml:
--------------------------------------------------------------------------------
name: Prepare build environment

description: Prepares build environment for python and prompt flow related workflow execution.

inputs:
  versionSpec:
    description: "The Python version to use in the environment."
    default: "3.11"

runs:
  using: composite
  steps:
    - name: Checkout
      uses: actions/checkout@v4

    - uses: actions/setup-python@v4
      with:
        python-version: ${{ inputs.versionSpec }}

    - name: Load all prompt flow and related dependencies
      shell: bash
      run: |
        set -e # fail on error
        python -m pip install --upgrade pip
        python -m pip install .
--------------------------------------------------------------------------------
/infra/generate_arm_template.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

set -e

az bicep version 2>/dev/null || az bicep install

TEMPLATES=()
FILES=()

for ARG in "$@"; do
    # If the argument is supplied with "-f", then it is a template file that needs to be built
    if [[ $ARG == -f=* ]]; then
        TEMPLATES+=("${ARG#-f=}")
    else
        # Otherwise, it is a file that has been edited
        az bicep format --insert-final-newline -f "$ARG" &
        FILES+=("$ARG")
    fi
done

wait

git add ${FILES[@]}

# Build the templates
for TEMPLATE in ${TEMPLATES[@]}; do
    az bicep build -f "$TEMPLATE"
    git add "${TEMPLATE%.bicep}.json" # Change the extension from .bicep to .json
done
--------------------------------------------------------------------------------
/rag_experiment_accelerator/evaluation/tests/test_transformer_based_metrics.py:
--------------------------------------------------------------------------------
from unittest.mock import MagicMock

import numpy as np

from rag_experiment_accelerator.evaluation.transformer_based_metrics import (
    compare_semantic_document_values,
)


def test_compare_semantic_document_values():
    mock_sentence_transformer = MagicMock()
    embeddings1 = np.array([[0.1, 0.2, 0.3, 0.4, 0.7]])
    embeddings2 = np.array([[0.1, 0.3, 0.4, 0.5, 0.6]])

    mock_sentence_transformer.encode.side_effect = [embeddings1, embeddings2]

    value1 = "value1"
    value2 = "value2"

    assert (
        compare_semantic_document_values(value1, value2, mock_sentence_transformer)
        == 97
    )
--------------------------------------------------------------------------------
/infra/shared/machineLearning.bicep:
--------------------------------------------------------------------------------
metadata description = 'Creates an Azure Machine Learning Workspace.'
param name string
param location string = resourceGroup().location
param tags object = {}
param storageAccount string
param keyVault string
param applicationInsights string

resource machineLearningWorkspace 'Microsoft.MachineLearningServices/workspaces@2023-06-01-preview' = {
  name: name
  location: location
  identity: {
    type: 'SystemAssigned'
  }
  tags: tags
  properties: {
    storageAccount: storageAccount
    keyVault: keyVault
    applicationInsights: applicationInsights
  }
}

output workspaceName string = machineLearningWorkspace.name
output workspaceId string = machineLearningWorkspace.id
--------------------------------------------------------------------------------
/rag_experiment_accelerator/io/local/writers/tests/test_jsonl_writer.py:
--------------------------------------------------------------------------------
import os
import shutil
import tempfile

import pytest

from rag_experiment_accelerator.io.local.writers.jsonl_writer import (
    JsonlWriter,
)


@pytest.fixture()
def temp_dir():
    dir = tempfile.mkdtemp()
    yield dir
    if os.path.exists(dir):
        shutil.rmtree(dir)


def test__write_file(temp_dir: str):
    # set up
    data = {"test": "test"}
    path = temp_dir + "/test.jsonl"

    # write the file
    writer = JsonlWriter()
    writer._write_file(path, data)

    # check file was written and contains the correct data
    with open(path) as file:
        assert file.readline() == '{"test": "test"}\n'
--------------------------------------------------------------------------------
/rag_experiment_accelerator/llm/prompts_text/llm_context_precision_instruction.txt:
--------------------------------------------------------------------------------
Given a question and a context, verify if the information in the given context is useful in answering the question. Return a Yes/No answer.


User:
Context:
The latest software update for smartphones includes significant improvements in security protocols and user interface enhancements.

Question:
What does the new software update include for smartphones?

Assistant:
Yes



User:
Context:
Coffee consumption statistics in 2019 show that adults aged 25-34 are the largest group of coffee drinkers in the United States.

Question:
What are the health benefits of drinking coffee?

Assistant:
No
--------------------------------------------------------------------------------
/rag_experiment_accelerator/config/chunking_config.py:
--------------------------------------------------------------------------------
from enum import StrEnum
from dataclasses import dataclass
from rag_experiment_accelerator.config.base_config import BaseConfig


class ChunkingStrategy(StrEnum):
    BASIC = "basic"
    AZURE_DOCUMENT_INTELLIGENCE = "azure-document-intelligence"

    def __repr__(self) -> str:
        return f'"{self.value}"'


@dataclass
class ChunkingConfig(BaseConfig):
    preprocess: bool = False
    chunk_size: int = 512
    overlap_size: int = 128
    generate_title: bool = False
    generate_summary: bool = False
    override_content_with_summary: bool = False
    chunking_strategy: ChunkingStrategy = ChunkingStrategy.BASIC
    # only for azure document intelligence strategy
    azure_document_intelligence_model: str = "prebuilt-read"
--------------------------------------------------------------------------------
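
A sketch of how these chunking knobs compose; the sizes are arbitrary examples, and azure_document_intelligence_model only matters when that strategy is selected:

from rag_experiment_accelerator.config.chunking_config import (
    ChunkingConfig,
    ChunkingStrategy,
)

# Larger chunks with 25% overlap, parsed via Azure Document Intelligence.
chunking_config = ChunkingConfig(
    chunk_size=1024,
    overlap_size=256,
    chunking_strategy=ChunkingStrategy.AZURE_DOCUMENT_INTELLIGENCE,
)
print(chunking_config.chunking_strategy)  # azure-document-intelligence
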
20 | """ 21 | 22 | sample_data: bool = False 23 | percentage: int = 5 24 | optimum_k: str = "auto" 25 | min_cluster: int = 2 26 | max_cluster: int = 30 27 | only_sample: bool = False 28 | -------------------------------------------------------------------------------- /promptflow/rag-experiment-accelerator/index/create_index.py: -------------------------------------------------------------------------------- 1 | from promptflow import tool 2 | from rag_experiment_accelerator.checkpoint import init_checkpoint 3 | from rag_experiment_accelerator.run.index import run 4 | from rag_experiment_accelerator.config.paths import get_all_file_paths 5 | from rag_experiment_accelerator.config.environment import Environment 6 | from rag_experiment_accelerator.config.config import Config 7 | 8 | 9 | @tool 10 | def my_python_tool(should_index: bool, config_path: str) -> bool: 11 | environment = Environment.from_env_or_keyvault() 12 | config = Config.from_path(environment, config_path) 13 | init_checkpoint(config) 14 | 15 | if should_index: 16 | file_paths = get_all_file_paths(config.path.data_dir) 17 | for index_config in config.index.flatten(): 18 | run(environment, config, index_config, file_paths) 19 | return True 20 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/llm/prompt/rerank_prompts.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | from rag_experiment_accelerator.llm.prompt.prompt import StructuredPrompt, PromptTag 4 | 5 | 6 | def validate_rerank(text: str) -> bool: 7 | json_output = json.loads(text) 8 | 9 | def key_matches(key: str) -> bool: 10 | return bool(re.match(r"^document_\d+$", key)) 11 | 12 | return isinstance(json_output, dict) and all( 13 | isinstance(key, str) and isinstance(value, int) and key_matches(key) 14 | for key, value in json_output.items() 15 | ) 16 | 17 | 18 | _rerank_template: str = """ 19 | ${documents} 20 | 21 | Question: ${question} 22 | """ 23 | 24 | rerank_prompt_instruction = StructuredPrompt( 25 | system_message="prompt_instruction_keywords.txt", 26 | user_template=_rerank_template, 27 | validator=validate_rerank, 28 | tags={PromptTag.JSON, PromptTag.NonStrict}, 29 | ) 30 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/utils/auth.py: -------------------------------------------------------------------------------- 1 | from azure.identity import DefaultAzureCredential 2 | 3 | from rag_experiment_accelerator.utils.logging import get_logger 4 | 5 | logger = get_logger(__name__) 6 | 7 | 8 | def get_default_az_cred(): 9 | """ 10 | Returns a DefaultAzureCredential object that can be used to authenticate with Azure services. 11 | If the credential cannot be obtained, an error is logged and an exception is raised. 12 | """ 13 | try: 14 | credential = DefaultAzureCredential() 15 | # Check if credential can get token successfully. 16 | credential.get_token("https://management.azure.com/.default") 17 | except Exception as ex: 18 | logger.error( 19 | "Unable to get a token from DefaultAzureCredential. Please run 'az" 20 | " login' in your terminal and try again." 
/rag_experiment_accelerator/utils/auth.py:
--------------------------------------------------------------------------------
from azure.identity import DefaultAzureCredential

from rag_experiment_accelerator.utils.logging import get_logger

logger = get_logger(__name__)


def get_default_az_cred():
    """
    Returns a DefaultAzureCredential object that can be used to authenticate with Azure services.
    If the credential cannot be obtained, an error is logged and an exception is raised.
    """
    try:
        credential = DefaultAzureCredential()
        # Check if credential can get token successfully.
        credential.get_token("https://management.azure.com/.default")
    except Exception as ex:
        logger.error(
            "Unable to get a token from DefaultAzureCredential. Please run 'az"
            " login' in your terminal and try again."
        )
        raise ex
    return credential
--------------------------------------------------------------------------------
/rag_experiment_accelerator/doc_loader/tests/test_custom_html_loader.py:
--------------------------------------------------------------------------------
from unittest.mock import Mock

from rag_experiment_accelerator.doc_loader.htmlLoader import load_html_files
from rag_experiment_accelerator.config.paths import get_all_file_paths


def test_load_html_files():
    chunks = load_html_files(
        environment=Mock(),
        file_paths=get_all_file_paths("./data/html"),
        chunk_size=1000,
        overlap_size=200,
    )

    assert len(chunks) == 20

    assert (
        "Deep Neural Nets: 33 years ago and 33 years from now"
        in list(chunks[0].values())[0]["content"]
    )
    assert (
        "Deep Neural Nets: 33 years ago and 33 years from now"
        not in list(chunks[5].values())[0]["content"]
    )
    assert "Musings of a Computer Scientist." in list(chunks[19].values())[0]["content"]
--------------------------------------------------------------------------------
/infra/shared/keyvault-secret.bicep:
--------------------------------------------------------------------------------
metadata description = 'Creates or updates a secret in an Azure Key Vault.'
param name string
param tags object = {}
param keyVaultName string
param contentType string = 'string'
@description('The value of the secret. Provide only derived values like blob storage access, but do not hard code any secrets in your templates')
@secure()
param secretValue string

param enabled bool = true
param exp int = 0
param nbf int = 0

resource keyVaultSecret 'Microsoft.KeyVault/vaults/secrets@2022-07-01' = {
  name: name
  tags: tags
  parent: keyVault
  properties: {
    attributes: {
      enabled: enabled
      exp: exp
      nbf: nbf
    }
    contentType: contentType
    value: secretValue
  }
}

resource keyVault 'Microsoft.KeyVault/vaults@2022-07-01' existing = {
  name: keyVaultName
}
--------------------------------------------------------------------------------
/rag_experiment_accelerator/doc_loader/tests/test_docx_loader.py:
--------------------------------------------------------------------------------
from unittest.mock import Mock

from rag_experiment_accelerator.doc_loader.docxLoader import load_docx_files
from rag_experiment_accelerator.config.paths import get_all_file_paths


def test_load_docx_files():
    folder_path = "./data/docx"
    chunk_size = 1000
    overlap_size = 400

    original_doc = load_docx_files(
        environment=Mock(),
        file_paths=get_all_file_paths(folder_path),
        chunk_size=chunk_size,
        overlap_size=overlap_size,
    )

    assert len(original_doc) == 3

    assert "We recently commissioned" in list(original_doc[0].values())[0]["content"]
    assert "We recently commissioned" in list(original_doc[1].values())[0]["content"]
    assert (
        "We recently commissioned" not in list(original_doc[2].values())[0]["content"]
    )
--------------------------------------------------------------------------------
/02_qa_generation.py:
--------------------------------------------------------------------------------
import argparse

from rag_experiment_accelerator.checkpoint import init_checkpoint
from rag_experiment_accelerator.run.qa_generation import run
from rag_experiment_accelerator.config.config import Config
from rag_experiment_accelerator.config.environment import Environment
from rag_experiment_accelerator.config.paths import get_all_file_paths


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path", type=str, help="input: path to the config file"
    )
    parser.add_argument("--data_dir", type=str, help="input: path to the input data")
    args, _ = parser.parse_known_args()

    environment = Environment.from_env_or_keyvault()
    config = Config.from_path(environment, args.config_path, args.data_dir)
    init_checkpoint(config)

    run(environment, config, get_all_file_paths(config.path.data_dir))
--------------------------------------------------------------------------------
/rag_experiment_accelerator/utils/logging.py:
--------------------------------------------------------------------------------
import logging
import os
import sys

# Global variable to cache the logging level
_cached_logging_level = None


def get_logger(name: str) -> logging.Logger:
    """Get Logger

    Args:
        name (str): Logger name

    Returns:
        logging.Logger: named logger
    """
    logger = logging.getLogger(name)
    if logger.hasHandlers():
        return logger

    global _cached_logging_level
    if not _cached_logging_level:
        _cached_logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper()

    handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
    )
    handler.setFormatter(formatter)
    logger.setLevel(_cached_logging_level)
    logger.addHandler(handler)

    return logger
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
azure-ai-ml==1.25.0
azure-ai-textanalytics==5.3.0
azure-core==1.31.0
azure-identity==1.18.0
azure-keyvault==4.2.0
azure-keyvault-secrets==4.8.*
azure-search-documents==11.4.0b11
azure.ai.documentintelligence==1.0.0b4
azureml-core==1.57.0.post1
azureml-mlflow==1.57.0.post1
beautifulsoup4==4.12.3
datasets==3.0.0
docx2txt==0.8
evaluate==0.4.3
hnswlib==0.8.0
jsonschema==4.23.0
kaleido==0.2.1
langchain==0.3.0
langchain-community==0.3.0
levenshtein==0.26.0
lxml==5.3.0
matplotlib==3.9.2
mlflow==2.16.1
openai==1.64.0
plotly==5.24.1
pypdf==4.3.1
pytesseract==0.3.13
python-dotenv==1.0.1
PyMuPDF==1.24.10
PyPDF2~=3.0
rapidfuzz==3.9.7
rouge-score==0.1.2
scikit-learn==1.5.2
sentence-transformers==3.1.1
spacy==3.7.6
textdistance==4.6.3
tiktoken==0.7.0
tqdm==4.66.5
umap-learn==0.5.6
unstructured==0.15.13
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
repos:
  - repo: https://github.com/psf/black
    rev: 23.12.1
    hooks:
      - id: black
        language_version: python3

  - repo: https://github.com/pycqa/flake8
    rev: 7.0.0
    hooks:
      - id: flake8
        args: [--extend-ignore=E501]

  - repo: https://github.com/python-jsonschema/check-jsonschema
    rev: 0.28.5
    hooks:
      - id: check-jsonschema
        files: ^(config\.sample\.json|\.github/workflows/config\.json)$
        types: [json]
        args: ["--schemafile", "config.schema.json"]

  - repo: local
    hooks:
      - id: bicep
        name: bicep
        description: Lint and build Bicep files
        entry: ./infra/generate_arm_template.sh
        language: script
        files: \.bicep$
        require_serial: true
        args: # Bicep files that we want to generate ARM templates from
          - -f=./infra/main.bicep
--------------------------------------------------------------------------------
/rag_experiment_accelerator/llm/prompt/multiprompts.py:
--------------------------------------------------------------------------------
import json
from rag_experiment_accelerator.llm.prompt.prompt import StructuredPrompt, PromptTag


def validate_do_we_need_multiple(text: str) -> bool:
    return text.lower().strip() in ["complex", "simple"]


def validate_multiple_prompt(text: str) -> bool:
    json_output = json.loads(text)
    return isinstance(json_output, list) and all(
        isinstance(item, str) for item in json_output
    )


do_need_multiple_prompt_instruction = StructuredPrompt(
    system_message="do_need_multiple_prompt_instruction.txt",
    user_template="${text}",
    validator=validate_do_we_need_multiple,
    tags={PromptTag.NonStrict},
)

multiple_prompt_instruction = StructuredPrompt(
    system_message="multiple_prompt_instruction.txt",
    user_template="${text}",
    validator=validate_multiple_prompt,
    tags={PromptTag.JSON, PromptTag.NonStrict},
)
--------------------------------------------------------------------------------
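
A concrete illustration of what the two validators above accept; the inputs are invented for the example:

from rag_experiment_accelerator.llm.prompt.multiprompts import (
    validate_do_we_need_multiple,
    validate_multiple_prompt,
)

# The complexity check tolerates case and surrounding whitespace.
assert validate_do_we_need_multiple("Complex ")

# The multi-prompt output must be a JSON list of strings.
assert validate_multiple_prompt('["What is X?", "How does Y work?"]')
assert not validate_multiple_prompt('{"not": "a list"}')
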
8 |     Example usage: 9 |     with TimeTook("sample"): 10 |         s = [x for x in range(10000000)] 11 |     Modified from: https://blog.usejournal.com/how-to-create-your-own-timing-context-manager-in-python-a0e944b48cf8 # noqa 12 |     """ 13 | 14 |     def __init__(self, description, logger=None): 15 |         self.description = description 16 |         self.logger = logger if logger else get_logger(__name__) 17 |         self.start = None 18 |         self.end = None 19 | 20 |     def __enter__(self): 21 |         self.start = time.perf_counter() 22 |         self.logger.info(f"Starting {self.description}") 23 | 24 |     def __exit__(self, type, value, traceback): 25 |         self.end = time.perf_counter() 26 |         self.logger.info( 27 |             f"Time took for {self.description}: " f"{self.end - self.start} seconds" 28 |         ) 29 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/llm/prompt/hyde_prompts.py: -------------------------------------------------------------------------------- 1 | import json 2 | from rag_experiment_accelerator.llm.prompt.prompt import ( 3 |     Prompt, 4 |     StructuredPrompt, 5 |     PromptTag, 6 | ) 7 | 8 | 9 | def validate_hypothetical_questions(text: str) -> bool: 10 |     json_output = json.loads(text) 11 |     return isinstance(json_output, list) and all( 12 |         isinstance(item, str) for item in json_output 13 |     ) 14 | 15 | 16 | prompt_generate_hypothetical_answer = Prompt( 17 |     system_message="prompt_generate_hypothetical_answer.txt", 18 |     user_template="${text}", 19 | ) 20 | 21 | prompt_generate_hypothetical_document = Prompt( 22 |     system_message="prompt_generate_hypothetical_document.txt", 23 |     user_template="${text}", 24 | ) 25 | 26 | prompt_generate_hypothetical_questions = StructuredPrompt( 27 |     system_message="prompt_generate_hypothetical_questions.txt", 28 |     user_template="${text}", 29 |     validator=validate_hypothetical_questions, 30 |     tags={PromptTag.JSON, PromptTag.NonStrict}, 31 | ) 32 | -------------------------------------------------------------------------------- /env_to_keyvault.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is used to create secrets in Azure Keyvault from the environment variables. 3 | 4 | For the list of environment parameters that will be created as secrets, please refer to the Environment class in rag_experiment_accelerator/config/environment.py. 5 | """ 6 | 7 | import argparse 8 | 9 | from rag_experiment_accelerator.config.environment import Environment 10 | from rag_experiment_accelerator.utils.logging import get_logger 11 | 12 | 13 | logger = get_logger(__name__) 14 | 15 | 16 | if __name__ == "__main__": 17 |     parser = argparse.ArgumentParser() 18 | 19 |     environment = Environment.from_env_or_keyvault() 20 |     logger.info("Creating secrets in Keyvault from the environment") 21 |     logger.info("The following secrets will be created:") 22 |     for secret in environment.fields(): 23 |         logger.info(f" - {secret[0]}") 24 | 25 |     environment.to_keyvault() 26 |     logger.info( 27 |         f"Secrets in Keyvault {environment.azure_key_vault_endpoint} have been created successfully." 28 |     ) 29 | -------------------------------------------------------------------------------- /infra/shared/keyvault.bicep: -------------------------------------------------------------------------------- 1 | param name string 2 | param location string = resourceGroup().location 3 | param tags object = {} 4 | 5 | @description('Service principal will be granted read access to the KeyVault.
If unset, no service principal is granted access by default') 6 | param principalId string = '' 7 | 8 | var defaultAccessPolicies = !empty(principalId) ? [ 9 | { 10 | objectId: principalId 11 | permissions: { secrets: [ 'get', 'set', 'list' ] } 12 | tenantId: subscription().tenantId 13 | } 14 | ] : [] 15 | 16 | resource keyVault 'Microsoft.KeyVault/vaults@2022-07-01' = { 17 | name: name 18 | location: location 19 | tags: tags 20 | properties: { 21 | tenantId: subscription().tenantId 22 | sku: { family: 'A', name: 'standard' } 23 | enabledForTemplateDeployment: true 24 | accessPolicies: union(defaultAccessPolicies, [ 25 | // define access policies here 26 | ]) 27 | } 28 | } 29 | 30 | output id string = keyVault.id 31 | output endpoint string = keyVault.properties.vaultUri 32 | output name string = keyVault.name 33 | -------------------------------------------------------------------------------- /infra/shared/monitoring.bicep: -------------------------------------------------------------------------------- 1 | param logAnalyticsName string 2 | param applicationInsightsName string 3 | param location string = resourceGroup().location 4 | param tags object = {} 5 | 6 | resource logAnalytics 'Microsoft.OperationalInsights/workspaces@2021-12-01-preview' = { 7 | name: logAnalyticsName 8 | location: location 9 | tags: tags 10 | properties: any({ 11 | retentionInDays: 30 12 | features: { 13 | searchVersion: 1 14 | } 15 | sku: { 16 | name: 'PerGB2018' 17 | } 18 | }) 19 | } 20 | 21 | resource applicationInsights 'Microsoft.Insights/components@2020-02-02' = { 22 | name: applicationInsightsName 23 | location: location 24 | tags: tags 25 | kind: 'web' 26 | properties: { 27 | Application_Type: 'web' 28 | WorkspaceResourceId: logAnalytics.id 29 | } 30 | } 31 | 32 | output applicationInsightsId string = applicationInsights.id 33 | output applicationInsightsName string = applicationInsights.name 34 | output logAnalyticsWorkspaceId string = logAnalytics.id 35 | output logAnalyticsWorkspaceName string = logAnalytics.name 36 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/io/local/writers/jsonl_writer.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from rag_experiment_accelerator.io.local.writers.local_writer import LocalWriter 4 | from rag_experiment_accelerator.utils.logging import get_logger 5 | 6 | logger = get_logger(__name__) 7 | 8 | 9 | class JsonlWriter(LocalWriter): 10 | """ 11 | A class for writing data to a JSONL file. 12 | 13 | Inherits from the LocalWriter class. 14 | 15 | Attributes: 16 | None 17 | 18 | Methods: 19 | write_file: Writes data to a JSONL file. 20 | 21 | """ 22 | 23 | def _write_file(self, path: str, data, **kwargs): 24 | """ 25 | Writes the given data to a JSONL file. 26 | 27 | Args: 28 | path (str): The path to the JSONL file. 29 | data: The data to be written to the file. 30 | **kwargs: Additional keyword arguments to be passed to the json.dumps() function. 
31 | 32 | Returns: 33 | None 34 | 35 | """ 36 | with open(path, "a") as file: 37 | file.write(json.dumps(data, **kwargs) + "\n") 38 | -------------------------------------------------------------------------------- /cspell.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2", 3 | "ignorePaths": [], 4 | "dictionaryDefinitions": [], 5 | "dictionaries": [], 6 | "words": [ 7 | "AOAI", 8 | "azuretools", 9 | "dataframe", 10 | "groundtruth", 11 | "jaccard", 12 | "keyvault", 13 | "spacy", 14 | "textdistance" 15 | ], 16 | "ignoreWords": [ 17 | "BLAS", 18 | "CNTK", 19 | "CUDA", 20 | "Caffe", 21 | "Chainer", 22 | "Cython", 23 | "Hogwild", 24 | "Jetson", 25 | "LAPACK", 26 | "NCCL", 27 | "NOTSET", 28 | "Numba", 29 | "OPENAI", 30 | "ROCM", 31 | "Theano", 32 | "aarch", 33 | "autograd", 34 | "azureml", 35 | "coveragerc", 36 | "devcontainer", 37 | "distilbert", 38 | "distro", 39 | "htmlcov", 40 | "keyvault", 41 | "libuv", 42 | "mlflow", 43 | "mpnet", 44 | "ndarray", 45 | "promptflow", 46 | "ptrblck", 47 | "pytest", 48 | "rerank", 49 | "scikit" 50 | ], 51 | "import": [] 52 | } 53 | -------------------------------------------------------------------------------- /data/text/sample-text.txt: -------------------------------------------------------------------------------- 1 | Deep learning is the subset of machine learning methods based on artificial neural networks (ANNs) with representation learning. The adjective "deep" refers to the use of multiple layers in the network. Methods used can be either supervised, semi-supervised or unsupervised.[2] 2 | 3 | Deep-learning architectures such as deep neural networks, deep belief networks, recurrent neural networks, convolutional neural networks and transformers have been applied to fields including computer vision, speech recognition, natural language processing, machine translation, bioinformatics, drug design, medical image analysis, climate science, material inspection and board game programs, where they have produced results comparable to and in some cases surpassing human expert performance.[3][4][5] 4 | 5 | Artificial neural networks were inspired by information processing and distributed communication nodes in biological systems. ANNs have various differences from biological brains. Specifically, artificial neural networks tend to be static and symbolic, while the biological brain of most living organisms is dynamic (plastic) and analog.[6][7] ANNs are generally seen as low quality models for brain function.[8] 6 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/llm/prompts_text/prompt_generate_hypothetical_questions.txt: -------------------------------------------------------------------------------- 1 | You are a helpful expert research assistant. 2 | When users ask a question, enhance their inquiry by suggesting up to five additional related questions. 3 | These questions should help them delve deeper into the subject or explore various dimensions of the topic. 4 | Each suggested question should be concise and direct, avoiding compound sentences. 5 | Ensure the questions are complete, clearly formulated, and closely related to the original question. 6 | Output the questions in JSON format, with each question as an item in a list. 7 | 8 | 9 | User: 10 | What impact does social media have on mental health? 
11 | 12 | Assistant: 13 | ["How does social media usage correlate with anxiety levels?", "What are the effects of social media on teenagers’ self-esteem?", "Can social media influence depression among adults?", "Are there positive psychological impacts of social media?", "How do different social media platforms affect mood?"] 14 | 15 | Respond with json. It should contain list of elements, where each element is a string, containing generated questions. 16 | 17 | Example output structure: ["This is first question", "This is second question", "This is third question"] -------------------------------------------------------------------------------- /rag_experiment_accelerator/io/loader.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class Loader(ABC): 5 | """ 6 | Abstract base class for data loaders. 7 | """ 8 | 9 | @abstractmethod 10 | def load(self, src: str, **kwargs) -> list: 11 | """ 12 | Load data from the specified source. 13 | 14 | Args: 15 | src (str): The source of the data. 16 | **kwargs: Additional keyword arguments. 17 | 18 | Returns: 19 | list: The loaded data. 20 | """ 21 | pass 22 | 23 | @abstractmethod 24 | def can_handle(self, src: str) -> bool: 25 | """ 26 | Check if the loader can handle the specified source. 27 | 28 | Args: 29 | src (str): The source to check. 30 | 31 | Returns: 32 | bool: True if the loader can handle the source, False otherwise. 33 | """ 34 | pass 35 | 36 | @abstractmethod 37 | def exists(self, src: str) -> bool: 38 | """ 39 | Check if the specified source exists. 40 | 41 | Args: 42 | src (str): The source to check. 43 | 44 | Returns: 45 | bool: True if the source exists, False otherwise. 46 | """ 47 | pass 48 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/embedding/embedding_model.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABC 2 | 3 | 4 | class EmbeddingModel(ABC): 5 | """ 6 | Base class for embedding models. 7 | 8 | Args: 9 | name (str): The name of the embedding model. 10 | dimension (int): The dimension of the embeddings. 11 | 12 | Attributes: 13 | dimension (int): The dimension of the embeddings. 14 | 15 | Methods: 16 | generate_embedding(chunk: str) -> list: Abstract method to generate embeddings for a given chunk of text. 17 | """ 18 | 19 | def __init__(self, name: str, dimension: int, **kwargs) -> None: 20 | self.name = name 21 | self.dimension = dimension 22 | 23 | @abstractmethod 24 | def generate_embedding(self, chunk: str) -> list[float]: 25 | """ 26 | abstract method to generate embeddings for a given chunk of text. 27 | 28 | Args: 29 | chunk (str): The input text chunk for which the embedding needs to be generated. 30 | 31 | Returns: 32 | list: The generated embedding as a list. 
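        Illustrative usage (a sketch; STEmbeddingModel is the
        sentence-transformer subclass exercised in this repo's tests):

            model = STEmbeddingModel("all-mpnet-base-v2")
            vector = model.generate_embedding("a chunk of text")
            assert len(vector) == model.dimension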
33 | """ 34 | pass 35 | 36 | def to_dict(self) -> dict: 37 | return { 38 | "dimension": self.dimension, 39 | "name": self.name, 40 | } 41 | -------------------------------------------------------------------------------- /promptflow/rag-experiment-accelerator/evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | from promptflow import tool 2 | import mlflow 3 | 4 | from rag_experiment_accelerator.evaluation.eval import get_run_tags 5 | from rag_experiment_accelerator.run.evaluation import run, initialise_mlflow_client 6 | from rag_experiment_accelerator.config.environment import Environment 7 | from rag_experiment_accelerator.config.config import Config 8 | from rag_experiment_accelerator.config.paths import ( 9 | mlflow_run_name, 10 | formatted_datetime_suffix, 11 | ) 12 | 13 | 14 | @tool 15 | def my_python_tool(config_path: str) -> bool: 16 | environment = Environment.from_env_or_keyvault() 17 | config = Config.from_path(environment, config_path) 18 | mlflow_client = initialise_mlflow_client(environment, config) 19 | name_suffix = formatted_datetime_suffix() 20 | 21 | mlflow.set_tags(get_run_tags(config)) 22 | with mlflow.start_run(run_name=mlflow_run_name(config.job_name, name_suffix)): 23 | mlflow.set_tags() 24 | for index_config in config.index.flatten(): 25 | run( 26 | environment, 27 | config, 28 | index_config, 29 | mlflow_client, 30 | name_suffix, 31 | ) 32 | return True 33 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/io/local/loaders/tests/test_jsonl_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | import tempfile 5 | import pytest 6 | 7 | from rag_experiment_accelerator.io.local.loaders.jsonl_loader import JsonlLoader 8 | 9 | 10 | @pytest.fixture() 11 | def temp_dir(): 12 | dir = tempfile.mkdtemp() 13 | yield dir 14 | if os.path.exists(dir): 15 | shutil.rmtree(dir) 16 | 17 | 18 | def test_loads(temp_dir: str): 19 | test_data = {"test": {"test1": 1, "test2": 2}} 20 | # write the file 21 | path = f"{temp_dir}/test.jsonl" 22 | with open(path, "a") as file: 23 | file.write(json.dumps(test_data) + "\n") 24 | 25 | # load the file 26 | loader = JsonlLoader() 27 | loaded_data = loader.load(path) 28 | 29 | assert loaded_data == [test_data] 30 | 31 | 32 | def test_loads_raises_file_not_found(temp_dir: str): 33 | path = f"{temp_dir}/non-existsing-file.jsonl" 34 | loader = JsonlLoader() 35 | with pytest.raises(FileNotFoundError): 36 | loader.load(path) 37 | 38 | 39 | def test_can_handle_true(): 40 | path = "test.jsonl" 41 | loader = JsonlLoader() 42 | assert loader.can_handle(path) is True 43 | 44 | 45 | def test_can_handle_false(): 46 | path = "test.txt" 47 | loader = JsonlLoader() 48 | assert loader.can_handle(path) is False 49 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/doc_loader/docxLoader.py: -------------------------------------------------------------------------------- 1 | from langchain_community.document_loaders import Docx2txtLoader 2 | 3 | from rag_experiment_accelerator.doc_loader.structuredLoader import ( 4 | load_structured_files, 5 | ) 6 | from rag_experiment_accelerator.utils.logging import get_logger 7 | from rag_experiment_accelerator.config.environment import Environment 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | def load_docx_files( 13 | environment: Environment, 14 | file_paths: list[str], 15 | chunk_size: str, 16 | overlap_size: str, 17 | **kwargs: dict, 18 | ): 19 | """ 20 | Load and process docx files from a given folder path. 21 | 22 | Args: 23 | environment (Environment): The environment class 24 | file_paths (list[str]): Sequence of paths to load. 25 | chunk_size (int): The size of each text chunk in characters. 26 | overlap_size (int): The size of the overlap between text chunks in characters. 27 | **kwargs (dict): Unused. 28 | 29 | 30 | Returns: 31 | list[Document]: A list of processed and split document chunks. 
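    Illustrative call (a sketch; the path points at the sample document under
    ./data/docx and the chunk sizes are arbitrary):

        chunks = load_docx_files(environment, ["./data/docx/sample-docx.docx"], chunk_size=1000, overlap_size=200)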
32 | """ 33 | 34 | logger.debug("Loading docx files") 35 | 36 | return load_structured_files( 37 | file_format="DOCX", 38 | language=None, 39 | loader=Docx2txtLoader, 40 | file_paths=file_paths, 41 | chunk_size=chunk_size, 42 | overlap_size=overlap_size, 43 | ) 44 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/embedding/tests/test_st_embedding_model.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | import pytest 3 | import numpy as np 4 | from rag_experiment_accelerator.embedding.st_embedding_model import STEmbeddingModel 5 | 6 | 7 | @patch("rag_experiment_accelerator.embedding.st_embedding_model.SentenceTransformer") 8 | def test_generate_embedding(mock_sentence_transformer): 9 | expected_embeddings = [0.1, 0.2, 0.3] 10 | mock_embeddings = np.array([expected_embeddings]) 11 | mock_sentence_transformer.return_value.encode.return_value = mock_embeddings 12 | 13 | model = STEmbeddingModel("all-mpnet-base-v2") 14 | embeddings = model.generate_embedding("Hello world") 15 | 16 | assert expected_embeddings == embeddings 17 | 18 | 19 | def test_sentence_transformer_embedding_model_raises_non_existing_model(): 20 | with pytest.raises(OSError): 21 | STEmbeddingModel("non-existing-model", 123) 22 | 23 | 24 | def test_sentence_transformer_embedding_model_raises_unsupported_model(): 25 | with pytest.raises(ValueError): 26 | STEmbeddingModel("non-existing-model") 27 | 28 | 29 | @patch("rag_experiment_accelerator.embedding.st_embedding_model.SentenceTransformer") 30 | def test_sentence_transformer_embedding_model_succeeds(mock_sentence_transformer): 31 | try: 32 | STEmbeddingModel("all-mpnet-base-v2") 33 | except BaseException: 34 | assert False, "Should not have thrown an exception" 35 | -------------------------------------------------------------------------------- /promptflow/rag-experiment-accelerator/flow.dag.yaml: -------------------------------------------------------------------------------- 1 | inputs: 2 | should_index: 3 | type: bool 4 | default: true 5 | config_dir: 6 | type: string 7 | default: ../.. 
8 | should_generate_qa: 9 | type: bool 10 | default: true 11 | outputs: {} 12 | nodes: 13 | - name: setup 14 | type: python 15 | source: 16 | type: code 17 | path: setup/setup_env.py 18 | inputs: 19 | connection: "" 20 | - name: index 21 | type: python 22 | source: 23 | type: code 24 | path: index/create_index.py 25 | inputs: 26 | should_index: ${inputs.should_index} 27 | config_dir: ${inputs.config_dir} 28 | activate: 29 | when: ${setup.output} 30 | is: true 31 | - name: generate_qa 32 | type: python 33 | source: 34 | type: code 35 | path: qa_generation/generate_qa.py 36 | inputs: 37 | config_dir: ${inputs.config_dir} 38 | should_generate_qa: ${inputs.should_generate_qa} 39 | activate: 40 | when: ${index.output} 41 | is: true 42 | - name: querying 43 | type: python 44 | source: 45 | type: code 46 | path: querying/querying.py 47 | inputs: 48 | config_dir: ${inputs.config_dir} 49 | activate: 50 | when: ${generate_qa.output} 51 | is: true 52 | - name: evaluation 53 | type: python 54 | source: 55 | type: code 56 | path: evaluation/evaluation.py 57 | inputs: 58 | config_dir: ${inputs.config_dir} 59 | activate: 60 | when: ${querying.output} 61 | is: true 62 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/doc_loader/markdownLoader.py: -------------------------------------------------------------------------------- 1 | from langchain_community.document_loaders import UnstructuredMarkdownLoader 2 | 3 | from rag_experiment_accelerator.doc_loader.structuredLoader import ( 4 | load_structured_files, 5 | ) 6 | from rag_experiment_accelerator.utils.logging import get_logger 7 | from rag_experiment_accelerator.config.environment import Environment 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | def load_markdown_files( 13 | environment: Environment, 14 | file_paths: list[str], 15 | chunk_size: str, 16 | overlap_size: str, 17 | **kwargs: dict, 18 | ): 19 | """ 20 | Load and process Markdown files from a given folder path. 21 | 22 | Args: 23 | environment (Environment): The environment class 24 | file_paths (list[str]): Sequence of paths to load. 25 | chunk_size (str): The size of the chunks to split the documents into. 26 | overlap_size (str): The size of the overlapping parts between chunks. 27 | **kwargs (dict): Unused. 28 | 29 | Returns: 30 | list[Document]: A list of processed and split document chunks. 31 | """ 32 | 33 | logger.debug("Loading markdown files") 34 | 35 | return load_structured_files( 36 | file_format="MARKDOWN", 37 | language="markdown", 38 | loader=UnstructuredMarkdownLoader, 39 | file_paths=file_paths, 40 | chunk_size=chunk_size, 41 | overlap_size=overlap_size, 42 | ) 43 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/config/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from datetime import datetime 4 | 5 | from rag_experiment_accelerator.utils.logging import get_logger 6 | 7 | 8 | logger = get_logger(__name__) 9 | 10 | 11 | def get_all_file_paths(directory: str) -> list[str]: 12 | """ 13 | Returns a list of all file paths in a directory listed recursively. 14 | """ 15 | pattern = os.path.join(directory, "**", "*") 16 | return [file for file in glob.glob(pattern, recursive=True) if os.path.isfile(file)] 17 | 18 | 19 | def try_create_directory(directory: str) -> None: 20 | """ 21 | Tries to create a directory with the given path. 
22 | 23 |     Args: 24 |         directory (str): The path of the directory to be created. 25 | 26 |     Returns: 27 |         None 28 | 29 |     Note: 30 |         Any OSError raised while creating the directory is logged as a warning and not re-raised. 31 |     """ 32 |     try: 33 |         os.makedirs(directory, exist_ok=True) 34 |     except OSError as e: 35 |         logger.warning(f"Failed to create directory {directory}: {e.strerror}") 36 | 37 | 38 | def formatted_datetime_suffix(): 39 |     """Return a suffix to use when naming the run and its artifacts.""" 40 |     return datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 41 | 42 | 43 | def mlflow_run_name(job_name: str, suffix: str = None): 44 |     """Returns a name to use for the MlFlow experiment run.""" 45 |     if not suffix: 46 |         suffix = formatted_datetime_suffix() 47 |     return f"{job_name}_{suffix}" 48 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/llm/prompts_text/rerank_prompt_instruction.txt: -------------------------------------------------------------------------------- 1 | You are provided with a list of documents, each identified by a number and accompanied by a summary. 2 | A user question is also given. 3 | Rank these documents based on their relevance to the question, assigning each a relevance score from 1 to 10, where 10 indicates the highest relevance. 4 | Respond with the reranked document numbers and their relevance scores formatted as a JSON string, according to the schema below. 5 | Ensure the JSON string contains all listed documents and adheres to format specifications without any additional text or explanation. 6 | 7 | 8 | User: 9 | Document 1: 10 | Overview of renewable energy trends and their economic impacts. 11 | 12 | Document 2: 13 | Analysis of fossil fuel dependency in developing countries. 14 | 15 | Document 3: 16 | Detailed report on the advancements in solar energy panels. 17 | 18 | Document 4: 19 | Comparison of wind energy efficiencies across continents. 20 | 21 | Document 5: 22 | Study on the environmental impacts of hydraulic fracturing. 23 | 24 | Document 6: 25 | Historical data on the use of renewable resources in Europe. 26 | 27 | Question: What are the latest developments in solar energy technology? 28 | 29 | Assistant: 30 | { 31 | "document_1": 7, 32 | "document_2": 3, 33 | "document_3": 10, 34 | "document_4": 6, 35 | "document_5": 2, 36 | "document_6": 4 37 | } -------------------------------------------------------------------------------- /rag_experiment_accelerator/checkpoint/checkpoint_factory.py: -------------------------------------------------------------------------------- 1 | from rag_experiment_accelerator.config.config import Config, ExecutionEnvironment 2 | 3 | global _checkpoint_instance 4 | _checkpoint_instance = None 5 | 6 | 7 | def get_checkpoint(): 8 |     """ 9 |     Returns the current checkpoint instance. 10 |     """ 11 |     global _checkpoint_instance 12 |     if not _checkpoint_instance: 13 |         raise Exception("Checkpoint not initialized yet. Call init_checkpoint() first.") 14 |     return _checkpoint_instance 15 | 16 | 17 | def init_checkpoint(config: Config): 18 |     """ 19 |     Initializes the checkpoint instance based on the provided configuration.
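    Illustrative usage (a sketch; assumes a Config loaded via Config.from_path):

        init_checkpoint(config)
        checkpoint = get_checkpoint()  # returns the same initialized instance anywhere in the process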
20 | """ 21 | global _checkpoint_instance 22 | _checkpoint_instance = _get_checkpoint_base_on_config(config) 23 | 24 | 25 | def _get_checkpoint_base_on_config(config: Config): 26 | # import inside the method to avoid circular dependencies 27 | from rag_experiment_accelerator.checkpoint.null_checkpoint import NullCheckpoint 28 | from rag_experiment_accelerator.checkpoint.local_storage_checkpoint import ( 29 | LocalStorageCheckpoint, 30 | ) 31 | 32 | if not config.use_checkpoints: 33 | return NullCheckpoint() 34 | 35 | if config.execution_environment == ExecutionEnvironment.AZURE_ML: 36 | # Currently not supported in Azure ML: https://github.com/microsoft/rag-experiment-accelerator/issues/491 37 | return NullCheckpoint() 38 | 39 | return LocalStorageCheckpoint(directory=config.path.artifacts_dir) 40 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/doc_loader/textLoader.py: -------------------------------------------------------------------------------- 1 | from langchain_community.document_loaders import TextLoader 2 | 3 | from rag_experiment_accelerator.doc_loader.structuredLoader import ( 4 | load_structured_files, 5 | ) 6 | from rag_experiment_accelerator.utils.logging import get_logger 7 | from rag_experiment_accelerator.config.environment import Environment 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | def load_text_files( 13 | environment: Environment, 14 | file_paths: list[str], 15 | chunk_size: str, 16 | overlap_size: str, 17 | **kwargs: dict, 18 | ): 19 | """ 20 | Load and process text files from a given folder path. 21 | 22 | Args: 23 | environment (Environment): The environment class 24 | chunking_strategy (str): The chunking strategy to use between "azure-document-intelligence" and "basic". 25 | file_paths (list[str]): Sequence of paths to load. 26 | chunk_size (int): The size of each text chunk in characters. 27 | overlap_size (int): The size of the overlap between text chunks in characters. 28 | **kwargs (dict): Unused. 29 | 30 | Returns: 31 | list[Document]: A list of processed and split document chunks. 32 | """ 33 | 34 | logger.debug("Loading text files") 35 | 36 | return load_structured_files( 37 | file_format="TEXT", 38 | language=None, 39 | loader=TextLoader, 40 | file_paths=file_paths, 41 | chunk_size=chunk_size, 42 | overlap_size=overlap_size, 43 | ) 44 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/llm/prompts_text/prompt_generate_hypothetical_answer.txt: -------------------------------------------------------------------------------- 1 | You are a helpful expert research assistant. 2 | Generate a hypothetical answer to the given question as if it were found in a document. 3 | This involves creating a detailed, expert-level response based on the context or nature of the question, even if specific data or real documents aren't available. 4 | Your answer should reflect what an expert might write in a relevant document or article. 5 | 6 | 7 | User: 8 | What are the potential impacts of artificial intelligence on job markets in the next decade? 9 | Assistant: 10 | Experts predict that artificial intelligence will significantly automate tasks, potentially displacing jobs in sectors like manufacturing and customer service, while creating new opportunities in AI development and data analysis. 11 | 12 | 13 | User: 14 | How could climate change affect coastal cities by 2050? 
15 | Assistant: 16 | By 2050, climate change is expected to cause more frequent and severe flooding in coastal cities due to rising sea levels and increased storm intensity, necessitating major adaptations in urban planning and infrastructure. 17 | 18 | 19 | User: 20 | What are the latest advancements in renewable energy technologies? 21 | Assistant: 22 | Recent advancements in renewable energy include improvements in solar panel efficiency, development of larger offshore wind turbines, and breakthroughs in battery storage technology, all contributing to more sustainable energy solutions. -------------------------------------------------------------------------------- /rag_experiment_accelerator/io/writer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class Writer(ABC): 5 | """Abstract base class for a writer.""" 6 | 7 | @abstractmethod 8 | def write(self, path: str, data, **kwargs): 9 | """Write data to a file. 10 | 11 | Args: 12 | path (str): The path of the file to write to. 13 | data: The data to write to the file. 14 | **kwargs: Additional keyword arguments. 15 | 16 | Returns: 17 | None 18 | """ 19 | pass 20 | 21 | @abstractmethod 22 | def copy(self, src: str, dest: str, **kwargs): 23 | """Copy a file from source to destination. 24 | 25 | Args: 26 | src (str): The path of the source file. 27 | dest (str): The path of the destination file. 28 | **kwargs: Additional keyword arguments. 29 | 30 | Returns: 31 | None 32 | """ 33 | pass 34 | 35 | @abstractmethod 36 | def delete(self, src: str): 37 | """Delete a file. 38 | 39 | Args: 40 | src (str): The path of the file to delete. 41 | 42 | Returns: 43 | None 44 | """ 45 | pass 46 | 47 | @abstractmethod 48 | def exists(self, path: str) -> bool: 49 | """Check if a file exists. 50 | 51 | Args: 52 | path (str): The path of the file to check. 53 | 54 | Returns: 55 | bool: True if the file exists, False otherwise. 
56 | """ 57 | pass 58 | -------------------------------------------------------------------------------- /04_evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import mlflow 4 | 5 | from azureml.pipeline import initialise_mlflow_client 6 | from rag_experiment_accelerator.config.environment import Environment 7 | from rag_experiment_accelerator.run.evaluation import run 8 | from rag_experiment_accelerator.config.config import Config 9 | from rag_experiment_accelerator.config.paths import ( 10 | mlflow_run_name, 11 | formatted_datetime_suffix, 12 | ) 13 | 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | "--config_path", type=str, help="input: path to the config file" 19 | ) 20 | parser.add_argument( 21 | "--data_dir", 22 | type=str, 23 | help="input: path to the input data", 24 | default=None, # default is initialized in Config 25 | ) 26 | args, _ = parser.parse_known_args() 27 | 28 | environment = Environment.from_env_or_keyvault() 29 | config = Config.from_path(environment, args.config_path, args.data_dir) 30 | name_suffix = formatted_datetime_suffix() 31 | mlflow_client = initialise_mlflow_client(environment, config) 32 | mlflow.set_experiment(config.experiment_name) 33 | 34 | for index_config in config.index.flatten(): 35 | with mlflow.start_run(run_name=mlflow_run_name(config.job_name, name_suffix)): 36 | run( 37 | environment, 38 | config, 39 | index_config, 40 | mlflow_client, 41 | name_suffix=name_suffix, 42 | ) 43 | -------------------------------------------------------------------------------- /.env.template: -------------------------------------------------------------------------------- 1 | # For more information on environment variables, please refer to the documentation in ./docs/environment-variables.md 2 | 3 | #### Azure Search Service 4 | AZURE_SEARCH_SERVICE_ENDPOINT= 5 | AZURE_SEARCH_ADMIN_KEY= 6 | AZURE_SEARCH_USE_SEMANTIC_SEARCH="True" 7 | 8 | #### OpenAI 9 | OPENAI_API_KEY= 10 | # Must be 'azure' or 'open_ai' 11 | OPENAI_API_TYPE= 12 | 13 | ##### Azure OpenAI 14 | # Required when OPENAI_API_TYPE is set to 'azure' 15 | OPENAI_ENDPOINT= 16 | OPENAI_API_VERSION= 17 | 18 | #### Azure Machine Learning 19 | AML_SUBSCRIPTION_ID= 20 | AML_WORKSPACE_NAME= 21 | AML_RESOURCE_GROUP_NAME= 22 | 23 | ############ 24 | # OPTIONAL 25 | ############ 26 | 27 | #### Azure Document Intelligence 28 | AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT= 29 | AZURE_DOCUMENT_INTELLIGENCE_ADMIN_KEY= 30 | 31 | #### Azure Search Skillsets 32 | AZURE_LANGUAGE_SERVICE_ENDPOINT= 33 | AZURE_LANGUAGE_SERVICE_KEY= 34 | 35 | #### Multithreading 36 | # Uncomment and set the maximum number of worker threads to use. 37 | # By default (if left commented) it to be optimized by number of CPU cores 38 | #MAX_WORKER_THREADS=1 39 | 40 | # One of: NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL. 
Default is INFO 41 | # LOGGING_LEVEL= 42 | ## If you're planning to run the pipeline on AML cluster 43 | # Name of the compute cluster to use for the Azure Machine Learning pipeline: fill in if you are planning to run on AML cluster 44 | # AML_COMPUTE_NAME= 45 | # Maximum number of instances in the compute cluster 46 | # AML_COMPUTE_INSTANCES_NUMBER= 47 | # Azure Key Vault endpoint 48 | # AZURE_KEY_VAULT_ENDPOINT= 49 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/checkpoint/tests/test_local_storage_checkpoint.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import tempfile 4 | import shutil 5 | from unittest.mock import MagicMock 6 | 7 | from rag_experiment_accelerator.checkpoint.checkpoint_factory import ( 8 |     get_checkpoint, 9 |     init_checkpoint, 10 | ) 11 | from rag_experiment_accelerator.checkpoint.checkpoint_decorator import ( 12 |     cache_with_checkpoint, 13 | ) 14 | from rag_experiment_accelerator.checkpoint.local_storage_checkpoint import ( 15 |     LocalStorageCheckpoint, 16 | ) 17 | 18 | 19 | @cache_with_checkpoint(id="call_identifier") 20 | def dummy(word, call_identifier): 21 |     return f"hello {word}" 22 | 23 | 24 | class TestLocalStorageCheckpoint(unittest.TestCase): 25 |     def setUp(self): 26 |         self.temp_dir = tempfile.mkdtemp() 27 | 28 |     def tearDown(self): 29 |         if os.path.exists(self.temp_dir): 30 |             shutil.rmtree(self.temp_dir) 31 | 32 |     def test_wrapped_method_is_cached(self): 33 |         config = MagicMock() 34 |         config.use_checkpoints = True 35 |         config.path.artifacts_dir = self.temp_dir 36 |         init_checkpoint(config) 37 |         checkpoint = get_checkpoint() 38 |         assert isinstance(checkpoint, LocalStorageCheckpoint) 39 | 40 |         data_id = "same_id" 41 |         result1 = dummy("first run", data_id) 42 |         result2 = dummy("second run", data_id) 43 |         self.assertEqual(result1, "hello first run") 44 |         self.assertEqual(result2, "hello first run") 45 | 46 | 47 | if __name__ == "__main__": 48 |     unittest.main() 49 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/doc_loader/jsonLoader.py: -------------------------------------------------------------------------------- 1 | from rag_experiment_accelerator.doc_loader.customJsonLoader import ( 2 |     CustomJSONLoader, 3 | ) 4 | from rag_experiment_accelerator.doc_loader.structuredLoader import ( 5 |     load_structured_files, 6 | ) 7 | from rag_experiment_accelerator.utils.logging import get_logger 8 | from rag_experiment_accelerator.config.environment import Environment 9 | 10 | logger = get_logger(__name__) 11 | 12 | 13 | def load_json_files( 14 |     environment: Environment, 15 |     file_paths: list[str], 16 |     chunk_size: str, 17 |     overlap_size: str, 18 |     **kwargs: dict, 19 | ): 20 |     """ 21 |     Load and process Json files from a given folder path. 22 | 23 |     Args: 24 |         environment (Environment): The environment class 25 |         file_paths (list[str]): Sequence of paths to load. 26 |         chunk_size (int): The size of each text chunk in characters. 27 |         overlap_size (int): The size of the overlap between text chunks in characters. 28 |         **kwargs (dict): Unused. 29 | 30 |     Returns: 31 |         list[Document]: A list of processed and split document chunks.
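    Illustrative call (a sketch; only the "content" and "title" keys of each
    record are loaded, per keys_to_load below):

        chunks = load_json_files(environment, ["./data/json/sample-json.json"], chunk_size=1000, overlap_size=200)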
32 | """ 33 | 34 | logger.debug("Loading json files") 35 | 36 | keys_to_load = ["content", "title"] 37 | return load_structured_files( 38 | file_format="JSON", 39 | language=None, 40 | loader=CustomJSONLoader, 41 | file_paths=file_paths, 42 | chunk_size=chunk_size, 43 | overlap_size=overlap_size, 44 | loader_kwargs={ 45 | "keys_to_load": keys_to_load, 46 | }, 47 | ) 48 | -------------------------------------------------------------------------------- /promptflow/rag-experiment-accelerator/setup/setup_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from promptflow.connections import CustomConnection 4 | 5 | from promptflow import tool 6 | 7 | 8 | @tool 9 | def my_python_tool(connection: CustomConnection): 10 | os.environ["AZURE_SEARCH_SERVICE_ENDPOINT"] = connection.configs[ 11 | "AZURE_SEARCH_SERVICE_ENDPOINT" 12 | ] 13 | os.environ["AZURE_SEARCH_ADMIN_KEY"] = connection.secrets["AZURE_SEARCH_ADMIN_KEY"] 14 | os.environ["OPENAI_API_KEY"] = connection.secrets["OPENAI_API_KEY"] 15 | os.environ["OPENAI_API_TYPE"] = "azure" 16 | os.environ["OPENAI_ENDPOINT"] = connection.configs["OPENAI_ENDPOINT"] 17 | os.environ["OPENAI_API_VERSION"] = connection.configs["OPENAI_API_VERSION"] 18 | os.environ["AML_SUBSCRIPTION_ID"] = connection.secrets["AML_SUBSCRIPTION_ID"] 19 | os.environ["AML_RESOURCE_GROUP_NAME"] = connection.secrets[ 20 | "AML_RESOURCE_GROUP_NAME" 21 | ] 22 | os.environ["AML_WORKSPACE_NAME"] = connection.secrets["AML_WORKSPACE_NAME"] 23 | 24 | if "AZURE_LANGUAGE_SERVICE_KEY" in connection.secrets: 25 | os.environ["AZURE_LANGUAGE_SERVICE_KEY"] = connection.secrets[ 26 | "AZURE_LANGUAGE_SERVICE_KEY" 27 | ] 28 | 29 | if "AZURE_LANGUAGE_SERVICE_ENDPOINT" in connection.configs: 30 | os.environ["AZURE_LANGUAGE_SERVICE_ENDPOINT"] = connection.configs[ 31 | "AZURE_LANGUAGE_SERVICE_ENDPOINT" 32 | ] 33 | 34 | if "LOGGING_LEVEL" in connection.configs: 35 | os.environ["LOGGING_LEVEL"] = connection.configs["LOGGING_LEVEL"] 36 | 37 | return True 38 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/checkpoint/checkpoint_decorator.py: -------------------------------------------------------------------------------- 1 | from rag_experiment_accelerator.checkpoint.checkpoint_factory import get_checkpoint 2 | 3 | 4 | def cache_with_checkpoint(id=None): 5 | """ 6 | A decorator that can be used to cache the results of a method call using the globally initialized Checkpoint object. 7 | An id must be provided to the decorator, which is used to identify the cached result. 8 | If the method is called with the same id again, the cached result is returned instead of executing the method. 
9 | """ 10 | 11 | def decorator(func): 12 | def wrapper(*args, **kwargs): 13 | if id is None: 14 | raise ValueError( 15 | "'id' must be provided to the cache_with_checkpoint decorator" 16 | ) 17 | 18 | eval_context = {**globals(), **locals(), **kwargs} 19 | arg_dict = { 20 | param: value 21 | for param, value in zip( 22 | func.__code__.co_varnames[: func.__code__.co_argcount], args 23 | ) 24 | } 25 | eval_context.update(arg_dict) 26 | 27 | try: 28 | evaluated_id = eval(id, eval_context) 29 | except Exception as e: 30 | raise ValueError( 31 | f"Failed to evaluate the provided expression: {id}" 32 | ) from e 33 | 34 | checkpoint = get_checkpoint() 35 | return checkpoint.load_or_run(func, evaluated_id, *args, **kwargs) 36 | 37 | return wrapper 38 | 39 | return decorator 40 | -------------------------------------------------------------------------------- /infra/shared/search-services.bicep: -------------------------------------------------------------------------------- 1 | metadata description = 'Creates an Azure AI Search instance.' 2 | param name string 3 | param location string = resourceGroup().location 4 | param tags object = {} 5 | 6 | param sku object = { 7 | name: 'standard' 8 | } 9 | 10 | param authOptions object = {} 11 | param disableLocalAuth bool = false 12 | param encryptionWithCmk object = { 13 | enforcement: 'Unspecified' 14 | } 15 | @allowed([ 16 | 'default' 17 | 'highDensity' 18 | ]) 19 | param hostingMode string = 'default' 20 | param networkRuleSet object = { 21 | bypass: 'None' 22 | ipRules: [] 23 | } 24 | param partitionCount int = 1 25 | @allowed([ 26 | 'enabled' 27 | 'disabled' 28 | ]) 29 | param publicNetworkAccess string = 'enabled' 30 | param replicaCount int = 1 31 | @allowed([ 32 | 'disabled' 33 | 'free' 34 | 'standard' 35 | ]) 36 | param semanticSearch string = 'disabled' 37 | 38 | resource search 'Microsoft.Search/searchServices@2023-11-01' = { 39 | name: name 40 | location: location 41 | tags: tags 42 | identity: { 43 | type: 'SystemAssigned' 44 | } 45 | properties: { 46 | authOptions: authOptions 47 | disableLocalAuth: disableLocalAuth 48 | encryptionWithCmk: encryptionWithCmk 49 | hostingMode: hostingMode 50 | networkRuleSet: networkRuleSet 51 | partitionCount: partitionCount 52 | publicNetworkAccess: publicNetworkAccess 53 | replicaCount: replicaCount 54 | semanticSearch: semanticSearch 55 | } 56 | sku: sku 57 | } 58 | 59 | output id string = search.id 60 | output endpoint string = 'https://${name}.search.windows.net/' 61 | output name string = search.name 62 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/doc_loader/htmlLoader.py: -------------------------------------------------------------------------------- 1 | from langchain_community.document_loaders import BSHTMLLoader 2 | 3 | from rag_experiment_accelerator.doc_loader.structuredLoader import ( 4 | load_structured_files, 5 | ) 6 | from rag_experiment_accelerator.utils.logging import get_logger 7 | from rag_experiment_accelerator.config.environment import Environment 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | def load_html_files( 13 | environment: Environment, 14 | file_paths: list[str], 15 | chunk_size: str, 16 | overlap_size: str, 17 | **kwargs: dict, 18 | ): 19 | """ 20 | Load and process HTML files from a given folder path. 21 | 22 | Args: 23 | chunking_strategy (str): The chunking strategy to use between "azure-document-intelligence" and "basic". 24 | file_paths (list[str]): Sequence of paths to load. 
25 |         chunk_size (str): The size of the chunks to split the documents into. 26 |         overlap_size (str): The size of the overlapping parts between chunks. 27 |         **kwargs (dict): Unused. 28 | 29 |     Returns: 30 |         list[Document]: A list of processed and split document chunks. 31 |     """ 32 | 33 |     logger.debug("Loading html files") 34 | 35 |     return load_structured_files( 36 |         file_format="HTML", 37 |         language="html", 38 |         loader=BSHTMLLoader, 39 |         file_paths=file_paths, 40 |         chunk_size=chunk_size, 41 |         overlap_size=overlap_size, 42 |         loader_kwargs={"open_encoding": "utf-8"}, 43 |     ) 44 | -------------------------------------------------------------------------------- /infra/network/network_isolation.bicep: -------------------------------------------------------------------------------- 1 | param vnetName string 2 | param location string 3 | 4 | @minLength(1) 5 | param vnetAddressSpace string 6 | param proxySubnetName string 7 | 8 | @minLength(1) 9 | param proxySubnetAddressSpace string 10 | param azureSubnetName string 11 | 12 | @minLength(1) 13 | param azureSubnetAddressSpace string 14 | param resourcePrefix string 15 | param azureResources array 16 | 17 | resource vnet 'Microsoft.Network/virtualNetworks@2020-06-01' = { 18 |   name: vnetName 19 |   location: location 20 |   properties: { 21 |     addressSpace: { 22 |       addressPrefixes: [ 23 |         vnetAddressSpace 24 |       ] 25 |     } 26 |     subnets: [ 27 |       { 28 |         name: proxySubnetName 29 |         properties: { 30 |           addressPrefix: proxySubnetAddressSpace 31 |         } 32 |       } 33 |       { 34 |         name: azureSubnetName 35 |         properties: { 36 |           addressPrefix: azureSubnetAddressSpace 37 |         } 38 |       } 39 |     ] 40 |   } 41 | } 42 | 43 | resource privateEndpoints 'Microsoft.Network/privateEndpoints@2020-07-01' = [ 44 |   for (resource, i) in azureResources: { 45 |     name: '${resourcePrefix}${resource.type}PrivateEndpoint' 46 |     location: location 47 |     properties: { 48 |       privateLinkServiceConnections: [ 49 |         { 50 |           name: '${resourcePrefix}${resource.type}PLSConnection' 51 |           properties: { 52 |             privateLinkServiceId: resource.resourceId 53 |             groupIds: [resource.type] 54 |           } 55 |         } 56 |       ] 57 |       subnet: { 58 |         id: '${vnet.id}/subnets/${azureSubnetName}' 59 |       } 60 |     } 61 |   } 62 | ] 63 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/data_assets/data_asset.py: -------------------------------------------------------------------------------- 1 | from azure.ai.ml import MLClient 2 | from azure.ai.ml.entities import Data 3 | from azure.ai.ml.constants import AssetTypes 4 | 5 | from rag_experiment_accelerator.utils.logging import get_logger 6 | from rag_experiment_accelerator.utils.auth import get_default_az_cred 7 | from rag_experiment_accelerator.config.environment import Environment 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | def create_data_asset(data_path: str, data_asset_name: str, environment: Environment): 13 |     """ 14 |     Creates a new data asset in Azure Machine Learning workspace. 15 | 16 |     Args: 17 |         data_path (str): The path to the data file. 18 |         data_asset_name (str): The name of the data asset. 19 |         environment (Environment): Class containing the environment configuration 20 | 21 |     Returns: 22 |         int: The version of the created data asset.
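    Illustrative usage (a sketch; the path and asset name are placeholders and
    the AML_* variables must be set in the environment):

        version = create_data_asset("./artifacts/output.jsonl", "rag-data", environment)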
23 | """ 24 | 25 | ml_client = MLClient( 26 | get_default_az_cred(), 27 | environment.aml_subscription_id, 28 | environment.aml_resource_group_name, 29 | environment.aml_workspace_name, 30 | ) 31 | 32 | aml_dataset = Data( 33 | path=data_path, 34 | type=AssetTypes.URI_FILE, 35 | description="rag data", 36 | name=data_asset_name, 37 | ) 38 | 39 | ml_client.data.create_or_update(aml_dataset) 40 | 41 | aml_dataset_unlabeled = ml_client.data.get(name=data_asset_name, label="latest") 42 | 43 | logger.info(f"Dataset version: {aml_dataset_unlabeled.version}") 44 | logger.info(f"Dataset ID: {aml_dataset_unlabeled.id}") 45 | 46 | return aml_dataset_unlabeled.version 47 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/io/local/loaders/jsonl_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | from json.decoder import JSONDecodeError 3 | 4 | from rag_experiment_accelerator.io.local.loaders.local_loader import LocalLoader 5 | 6 | 7 | class JsonlLoader(LocalLoader): 8 | """A class for loading data from JSONL files.""" 9 | 10 | def load(self, path: str, **kwargs) -> list: 11 | """Load data from a JSONL file. 12 | 13 | Args: 14 | path (str): The path to the JSONL file. 15 | **kwargs: Additional keyword arguments to be passed to json.loads(). 16 | 17 | Returns: 18 | list: A list of loaded data. 19 | 20 | Raises: 21 | FileNotFoundError: If the file is not found at the specified path. 22 | """ 23 | if not self.exists(path): 24 | raise FileNotFoundError(f"File not found at path: {path}") 25 | 26 | data_load = [] 27 | with open(path, "r") as file: 28 | for line in file: 29 | try: 30 | data = json.loads(line, **kwargs) 31 | except JSONDecodeError as jde: 32 | jde.add_note(f'Error occurred on line {len(data_load) + 1} in input file {path}') 33 | 34 | data_load.append(data) 35 | 36 | return data_load 37 | 38 | def can_handle(self, path: str) -> bool: 39 | """Check if the loader can handle the given file path. 40 | 41 | Args: 42 | path (str): The file path to check. 43 | 44 | Returns: 45 | bool: True if the loader can handle the file, False otherwise. 
46 | """ 47 | ext = self._get_file_ext(path) 48 | return ext == ".jsonl" 49 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/llm/prompt/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from rag_experiment_accelerator.llm.prompt.prompt import ( 4 | Prompt, 5 | StructuredPrompt, 6 | StructuredWithCoTPrompt, 7 | CoTPrompt, 8 | PromptTag, 9 | ) 10 | 11 | from rag_experiment_accelerator.llm.prompt.hyde_prompts import ( 12 | prompt_generate_hypothetical_answer, 13 | prompt_generate_hypothetical_document, 14 | prompt_generate_hypothetical_questions, 15 | ) 16 | 17 | from rag_experiment_accelerator.llm.prompt.instruction_prompts import ( 18 | prompt_instruction_entities, 19 | prompt_instruction_keywords, 20 | prompt_instruction_title, 21 | prompt_instruction_summary, 22 | main_instruction_short, 23 | main_instruction_long, 24 | main_instruction, 25 | ) 26 | 27 | from rag_experiment_accelerator.llm.prompt.multiprompts import ( 28 | do_need_multiple_prompt_instruction, 29 | multiple_prompt_instruction, 30 | ) 31 | 32 | from rag_experiment_accelerator.llm.prompt.qna_prompts import ( 33 | generate_qna_long_single_context_instruction_prompt, 34 | generate_qna_short_single_context_instruction_prompt, 35 | generate_qna_long_multiple_context_instruction_prompt, 36 | generate_qna_short_multiple_context_instruction_prompt, 37 | generate_qna_short_single_context_no_cot_instruction_prompt, 38 | qna_generation_prompt, 39 | ) 40 | 41 | from rag_experiment_accelerator.llm.prompt.ragas_prompts import ( 42 | llm_answer_relevance_instruction, 43 | llm_context_precision_instruction, 44 | llm_context_recall_instruction, 45 | ) 46 | 47 | from rag_experiment_accelerator.llm.prompt.rerank_prompts import ( 48 | rerank_prompt_instruction, 49 | ) 50 | -------------------------------------------------------------------------------- /infra/shared/cognitiveservices.bicep: -------------------------------------------------------------------------------- 1 | metadata description = 'Creates an Azure Cognitive Services instance.' 2 | param name string 3 | param location string = resourceGroup().location 4 | param tags object = {} 5 | @description('The custom subdomain name used to access the API. Defaults to the value of the name parameter.') 6 | param customSubDomainName string = name 7 | param deployments array = [] 8 | param kind string = 'OpenAI' 9 | 10 | @allowed([ 'Enabled', 'Disabled' ]) 11 | param publicNetworkAccess string = 'Enabled' 12 | param sku object = { 13 | name: 'S0' 14 | } 15 | 16 | param allowedIpRules array = [] 17 | param networkAcls object = empty(allowedIpRules) ? { 18 | defaultAction: 'Allow' 19 | } : { 20 | ipRules: allowedIpRules 21 | defaultAction: 'Deny' 22 | } 23 | 24 | resource account 'Microsoft.CognitiveServices/accounts@2023-05-01' = { 25 | name: name 26 | location: location 27 | tags: tags 28 | kind: kind 29 | properties: { 30 | customSubDomainName: customSubDomainName 31 | publicNetworkAccess: publicNetworkAccess 32 | networkAcls: networkAcls 33 | } 34 | sku: sku 35 | } 36 | 37 | @batchSize(1) 38 | resource deployment 'Microsoft.CognitiveServices/accounts/deployments@2023-05-01' = [for deployment in deployments: { 39 | parent: account 40 | name: deployment.name 41 | properties: { 42 | model: deployment.model 43 | raiPolicyName: contains(deployment, 'raiPolicyName') ? deployment.raiPolicyName : null 44 | } 45 | sku: contains(deployment, 'sku') ? 
deployment.sku : { 46 | name: 'Standard' 47 | capacity: 30 48 | } 49 | }] 50 | 51 | output endpoint string = account.properties.endpoint 52 | output id string = account.id 53 | output name string = account.name 54 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/evaluation/tests/test_spacy_evaluator.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, call, patch 2 | from rag_experiment_accelerator.evaluation.spacy_evaluator import ( 3 | SpacyEvaluator, 4 | ) 5 | 6 | 7 | @patch("rag_experiment_accelerator.evaluation.spacy_evaluator.load") 8 | def test_evaluator_init(mock_nlp): 9 | similarity_threshold = 0.4 10 | evaluator = SpacyEvaluator(similarity_threshold=similarity_threshold) 11 | assert similarity_threshold == evaluator.similarity_threshold 12 | 13 | 14 | @patch("rag_experiment_accelerator.evaluation.spacy_evaluator.load") 15 | def test_similarity_returns_similar(mock_nlp): 16 | mock_doc_1 = MagicMock() 17 | mock_doc_1.similarity.return_value = 1 18 | mock_doc_2 = MagicMock() 19 | mock_nlp().side_effect = [mock_doc_1, mock_doc_2] 20 | 21 | evaluator = SpacyEvaluator() 22 | actual = evaluator.similarity("test word", "test word") 23 | 24 | mock_doc_1.similarity.assert_called_once_with(mock_doc_2) 25 | assert actual == 1 26 | 27 | 28 | @patch( 29 | "rag_experiment_accelerator.evaluation.spacy_evaluator.SpacyEvaluator.similarity" 30 | ) 31 | @patch("rag_experiment_accelerator.evaluation.spacy_evaluator.load") 32 | def test_is_relevant_returns_valid(mock_nlp, mock_similarity): 33 | mock_similarity.side_effect = [1, 0.05] 34 | 35 | evaluator = SpacyEvaluator() 36 | actual_true = evaluator.is_relevant("test phrase", "test phrase") 37 | actual_false = evaluator.is_relevant("phrase", "different") 38 | 39 | mock_similarity.assert_has_calls( 40 | [call("test phrase", "test phrase"), call("phrase", "different")] 41 | ) 42 | assert actual_true is True 43 | assert actual_false is False 44 | -------------------------------------------------------------------------------- /.github/workflows/build_validation_workflow.yml: -------------------------------------------------------------------------------- 1 | name: Build validation 2 | 3 | on: 4 | workflow_call: 5 | workflow_dispatch: 6 | pull_request: 7 | branches: 8 | - main 9 | - development 10 | - prerelease 11 | push: 12 | branches: 13 | - main 14 | - development 15 | - prerelease 16 | merge_group: 17 | 18 | concurrency: 19 | group: ${{ github.workflow }}-${{ github.ref }} 20 | cancel-in-progress: ${{ github.ref != 'refs/heads/development' && github.ref != 'refs/heads/main' && github.ref != 'refs/heads/prerelease' }} 21 | 22 | jobs: 23 | validate-code: 24 | name: job for validating code and structure 25 | runs-on: ubuntu-latest 26 | steps: 27 | - name: Checkout Actions 28 | uses: actions/checkout@v4 29 | - uses: actions/setup-python@v5 30 | with: 31 | python-version: "3.11" 32 | - name: Load all build validation related dependencies 33 | shell: bash 34 | run: | 35 | set -e # fail on error 36 | python -m pip install --upgrade pip 37 | python -m pip install -e . -r requirements.txt -r dev-requirements.txt 38 | 39 | - name: Download spacy model 40 | shell: bash 41 | run: | 42 | python -m spacy download en_core_web_sm 43 | 44 | - name: Run flake 45 | shell: bash 46 | run: | 47 | flake8 --extend-ignore=E501 48 | 49 | - name: Execute Unit Tests 50 | shell: bash 51 | run: | 52 | pytest . --cov=. 
--cov-report=html --cov-config=.coveragerc 53 | 54 | - name: Publish Unit Test Results 55 | uses: actions/upload-artifact@v4 56 | with: 57 | name: unit-test-results 58 | path: "htmlcov/**" 59 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/llm/prompt/ragas_prompts.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | from rag_experiment_accelerator.llm.prompt.prompt import ( 4 | Prompt, 5 | StructuredPrompt, 6 | PromptTag, 7 | ) 8 | 9 | 10 | def validate_context_precision(text: str) -> bool: 11 | return text.lower().strip() in ["yes", "no"] 12 | 13 | 14 | def validate_context_recall(text: str) -> bool: 15 | json_text = json.loads(text) 16 | 17 | def is_valid_entry(entry): 18 | statement_key_pattern = re.compile(r"^statement_\d+$") 19 | return all( 20 | key in ["reason", "attributed"] or statement_key_pattern.match(key) 21 | for key in entry.keys() 22 | ) 23 | 24 | return isinstance(json_text, list) and all( 25 | is_valid_entry(entry) for entry in json_text 26 | ) 27 | 28 | 29 | _context_precision_input = """ 30 | Context: 31 | ${context} 32 | 33 | Question: 34 | ${question} 35 | """ 36 | 37 | _context_recall_input = """ 38 | question: ${question} 39 | context: ${context} 40 | answer: ${answer} 41 | """ 42 | 43 | llm_answer_relevance_instruction = Prompt( 44 | system_message="llm_answer_relevance_instruction.txt", 45 | user_template="${text}", 46 | tags={PromptTag.NonStrict}, 47 | ) 48 | 49 | llm_context_precision_instruction = StructuredPrompt( 50 | system_message="llm_context_precision_instruction.txt", 51 | user_template=_context_precision_input, 52 | validator=validate_context_precision, 53 | tags={PromptTag.NonStrict}, 54 | ) 55 | 56 | llm_context_recall_instruction = StructuredPrompt( 57 | system_message="llm_context_recall_instruction.txt", 58 | user_template=_context_recall_input, 59 | validator=validate_context_recall, 60 | tags={PromptTag.JSON, PromptTag.NonStrict}, 61 | ) 62 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/embedding/tests/test_factory.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch, MagicMock 2 | import pytest 3 | 4 | from rag_experiment_accelerator.embedding.aoai_embedding_model import AOAIEmbeddingModel 5 | from rag_experiment_accelerator.embedding.st_embedding_model import STEmbeddingModel 6 | from rag_experiment_accelerator.embedding.factory import create_embedding_model 7 | 8 | 9 | def test_create_aoai_embedding_model(): 10 | embedding_type = "azure" 11 | model_name = "test_model" 12 | dimension = 768 13 | environment = MagicMock() 14 | model = create_embedding_model( 15 | model_type=embedding_type, 16 | model_name=model_name, 17 | dimension=dimension, 18 | environment=environment, 19 | ) 20 | assert isinstance(model, AOAIEmbeddingModel) 21 | 22 | 23 | @patch("rag_experiment_accelerator.embedding.st_embedding_model.SentenceTransformer") 24 | def test_create_st_embedding_model(mock_sentence_transformer): 25 | embedding_type = "sentence-transformer" 26 | model_name = "all-mpnet-base-v2" 27 | dimension = 768 28 | environment = MagicMock() 29 | model = create_embedding_model( 30 | model_type=embedding_type, 31 | model_name=model_name, 32 | dimension=dimension, 33 | environment=environment, 34 | ) 35 | assert isinstance(model, STEmbeddingModel) 36 | 37 | 38 | def 
test_create_raises_invalid_embedding_type(): 39 | embedding_type = "not-valid" 40 | model_name = "test_model" 41 | dimension = 768 42 | environment = MagicMock() 43 | with pytest.raises(ValueError): 44 | create_embedding_model( 45 | model_type=embedding_type, 46 | model_name=model_name, 47 | dimension=dimension, 48 | environment=environment, 49 | ) 50 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/llm/prompts_text/do_need_multiple_prompt_instruction.txt: -------------------------------------------------------------------------------- 1 | Analyze the given question to determine if it falls into one of the following categories: 2 | 3 | 1. Simple, Factual Question 4 | - Description: The question asks for a straightforward fact or piece of information. 5 | - Characteristics: 6 | - The answer can likely be found stated directly in a single passage of a relevant document. 7 | - Further breakdown of the question is unlikely to be beneficial. 8 | - Examples: 9 | - "What year did World War 2 end?" 10 | - "What is the capital of France?" 11 | - "What are the specifications of product X?" 12 | 13 | 2. Complex, Multi-part Question 14 | - Description: The question involves multiple components or asks for information about several related topics. 15 | - Characteristics: 16 | - Different parts of the question may need to be answered by separate passages or documents. 17 | - Breaking the question down into sub-questions for each component can yield better results. 18 | - The question is open-ended and may have a complex or nuanced answer. 19 | - Answering may require synthesizing information from multiple sources. 20 | - There may not be a single definitive answer and could require analysis from multiple angles. 21 | - Examples: 22 | - "What were the key causes, major battles, and outcomes of the American Revolutionary War?" 23 | - "How do electric cars work and how do they compare to gas-powered vehicles?" 24 | 25 | Output Requirement: 26 | Respond with one of two categories "complex" or "simple". Do not add anything else to the response. 27 | Use lower case and don't add . at the end of the response. 28 | 29 | 30 | User: 31 | What are the benefits of renewable energy? 32 | 33 | Assistant: 34 | complex -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "1. Index", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "module": "01_index", 12 | "justMyCode": true, 13 | "envFile": "${input:dotEnvFilePath}", 14 | }, 15 | { 16 | "name": "2. Generate QA", 17 | "type": "debugpy", 18 | "request": "launch", 19 | "module": "02_qa_generation", 20 | "justMyCode": true, 21 | "envFile": "${input:dotEnvFilePath}", 22 | }, 23 | { 24 | "name": "3. Querying", 25 | "type": "debugpy", 26 | "request": "launch", 27 | "module": "03_querying", 28 | "justMyCode": true, 29 | "envFile": "${input:dotEnvFilePath}", 30 | }, 31 | { 32 | "name": "4. Evaluate", 33 | "type": "debugpy", 34 | "request": "launch", 35 | "module": "04_evaluation", 36 | "justMyCode": true, 37 | "envFile": "${input:dotEnvFilePath}", 38 | }, 39 | { 40 | "name": "5. 
AzureML Pipeline", 41 | "type": "debugpy", 42 | "request": "launch", 43 | "module": "azureml.pipeline", 44 | "justMyCode": true, 45 | "envFile": "${input:dotEnvFilePath}", 46 | } 47 | ], 48 | "inputs": [ 49 | { 50 | "id": "dotEnvFilePath", 51 | "type": "command", 52 | "command": "azure-dev.commands.getDotEnvFilePath" 53 | } 54 | ] 55 | } 56 | -------------------------------------------------------------------------------- /data/json/sample-json.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "item_id": "1", 4 | "title": "Title 1", 5 | "content": "This is the content for item 1.", 6 | "summary": "Summary for item 1.", 7 | "description": "Description for item 1.", 8 | "created_date": "2023-01-01T00:00:00Z", 9 | "modified_date": "2023-01-02T00:00:00Z" 10 | }, 11 | { 12 | "item_id": "2", 13 | "title": "Title 2", 14 | "content": "This is the content for item 2.", 15 | "summary": "Summary for item 2.", 16 | "description": "Description for item 2.", 17 | "created_date": "2023-01-02T00:00:00Z", 18 | "modified_date": "2023-01-03T00:00:00Z" 19 | }, 20 | { 21 | "item_id": "3", 22 | "title": "Title 3", 23 | "content": "This is the content for item 3.", 24 | "summary": "Summary for item 3.", 25 | "description": "Description for item 3.", 26 | "created_date": "2023-01-03T00:00:00Z", 27 | "modified_date": "2023-01-04T00:00:00Z" 28 | }, 29 | { 30 | "item_id": "4", 31 | "title": "Title 4", 32 | "content": "This is the content for item 4.", 33 | "summary": "Summary for item 4.", 34 | "description": "Description for item 4.", 35 | "created_date": "2023-01-04T00:00:00Z", 36 | "modified_date": "2023-01-05T00:00:00Z" 37 | }, 38 | { 39 | "item_id": "5", 40 | "title": "Title 5", 41 | "content": "This is the content for item 5.", 42 | "summary": "Summary for item 5.", 43 | "description": "Description for item 5.", 44 | "created_date": "2023-01-05T00:00:00Z", 45 | "modified_date": "2023-01-06T00:00:00Z" 46 | }, 47 | { 48 | "item_id": "6", 49 | "title": "Title 6", 50 | "content": "This is the content for item 6.", 51 | "summary": "Summary for item 6.", 52 | "description": "Description for item 6.", 53 | "created_date": "2023-01-05T00:00:00Z", 54 | "modified_date": "2023-01-06T00:00:00Z" 55 | } 56 | ] -------------------------------------------------------------------------------- /docs/wsl.md: -------------------------------------------------------------------------------- 1 | # Setting up WSL 2 | 3 | There are numerous guides to setting up WSL, and this is not a comprehensive guide. Instead this might help you setup the basics. 4 | 5 | #### If you are using Docker Desktop 6 | 7 | To enable **Developing inside a Container** you must configure the integration between Docker Desktop and WSL on your machine. 8 | 9 | >1. Launch Docker Desktop 10 | >2. Open **Settings > General**. Make sure the *Use the WSL 2 based engine" is enabled. 11 | >3. Navigate to **Settings > Resources > WSL INTEGRATION**. 12 | > - Ensure *Enable Integration with my default WSL distro" is enabled. 13 | > - Enable the Ubuntu-18.04 option. 14 | >4. Select **Apply & Restart** 15 | 16 | ## Configure Git in Ubuntu WSL environment 17 | 18 | The next step is to configure Git for your Ubuntu WSL environment. 
We will use the bash prompt from the previous step to issue the following commands: 19 | 20 | Set Git User Name and Email 21 | 22 | ``` bash 23 | git config --global user.name "Your Name" 24 | git config --global user.email "youremail@yourdomain.com" 25 | ``` 26 | 27 | Set Git [useHttpPath](https://github.com/microsoft/Git-Credential-Manager-Core/blob/main/docs/configuration.md#credentialusehttppath) 28 | 29 | ``` bash 30 | git config --global credential.useHttpPath true 31 | ``` 32 | 33 | Configure Git to use the Windows Host Credential Manager 34 | 35 | ``` bash 36 | git config --global credential.helper "/mnt/c/Program\ Files/Git/mingw64/libexec/git-core/git-credential-manager-core.exe" 37 | ``` 38 | 39 | ## Install Azure CLI on WSL 40 | 41 | In your Ubuntu 18.04 (WSL) terminal from the previous step, follow the directions [here](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli-linux) to install the Azure CLI. 42 | 43 | 44 | Once installed, log in and select your subscription: 45 | ```bash 46 | az login 47 | az account set --subscription="" 48 | az account show 49 | ``` 50 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/evaluation/spacy_evaluator.py: -------------------------------------------------------------------------------- 1 | from spacy import load 2 | 3 | from rag_experiment_accelerator.utils.logging import get_logger 4 | 5 | logger = get_logger(__name__) 6 | 7 | 8 | class SpacyEvaluator: 9 | """ 10 | A class for evaluating the similarity between two documents using spaCy. 11 | 12 | Args: 13 | similarity_threshold (float): The minimum similarity score required for two documents to be considered relevant. 14 | model (str): The name of the spaCy model to use for processing the documents. 15 | 16 | Attributes: 17 | nlp (spacy.Language): The spaCy language model used for processing the documents. 18 | similarity_threshold (float): The minimum similarity score required for two documents to be considered relevant. 19 | 20 | Methods: 21 | similarity(doc1: str, doc2: str) -> float: Calculates the similarity score between two documents. 22 | is_relevant(doc1: str, doc2: str) -> bool: Determines whether two documents are relevant based on their similarity score.
23 | """ 24 | 25 | def __init__(self, similarity_threshold=0.8, model="en_core_web_lg") -> None: 26 | try: 27 | self.nlp = load(model) 28 | except OSError: 29 | logger.info(f"Downloading spacy language model: {model}") 30 | from spacy.cli import download 31 | 32 | download(model) 33 | self.nlp = load(model) 34 | self.similarity_threshold = similarity_threshold 35 | 36 | def similarity(self, doc1: str, doc2: str): 37 | nlp_doc1 = self.nlp(doc1) 38 | nlp_doc2 = self.nlp(doc2) 39 | return nlp_doc1.similarity(nlp_doc2) 40 | 41 | def is_relevant(self, doc1: str, doc2: str): 42 | similarity = self.similarity(doc1, doc2) 43 | logger.info(f"Similarity Score: {similarity}") 44 | 45 | return similarity > self.similarity_threshold 46 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/llm/prompt/instruction_prompts.py: -------------------------------------------------------------------------------- 1 | import json 2 | from rag_experiment_accelerator.llm.prompt.prompt import ( 3 | Prompt, 4 | StructuredPrompt, 5 | PromptTag, 6 | ) 7 | 8 | 9 | def validate_instruction_keyword(text: str) -> bool: 10 | json_output = json.loads(text) 11 | return isinstance(json_output, list) and all( 12 | isinstance(item, str) for item in json_output 13 | ) 14 | 15 | 16 | def validate_instruction_entities(text: str) -> bool: 17 | json_output = json.loads(text) 18 | return isinstance(json_output, list) 19 | 20 | 21 | _main_response_template: str = """ 22 | Context: 23 | ${context} 24 | 25 | Question: 26 | ${question} 27 | """ 28 | 29 | prompt_instruction_entities = StructuredPrompt( 30 | system_message="prompt_instruction_entities.txt", 31 | user_template="${text}", 32 | validator=validate_instruction_entities, 33 | tags={PromptTag.JSON}, 34 | ) 35 | 36 | prompt_instruction_keywords = StructuredPrompt( 37 | system_message="prompt_instruction_keywords.txt", 38 | user_template="${text}", 39 | validator=validate_instruction_keyword, 40 | tags={PromptTag.JSON}, 41 | ) 42 | 43 | prompt_instruction_title = Prompt( 44 | system_message="prompt_instruction_title.txt", 45 | user_template="${text}", 46 | ) 47 | 48 | prompt_instruction_summary = Prompt( 49 | system_message="prompt_instruction_summary.txt", 50 | user_template="${text}", 51 | ) 52 | 53 | # TODO: Add selector for usage of long/short prompts 54 | main_instruction_short = Prompt( 55 | system_message="main_instruction_short.txt", 56 | user_template=_main_response_template, 57 | ) 58 | 59 | # TODO: Add selector for usage of long/short prompts 60 | main_instruction_long = Prompt( 61 | system_message="main_instruction_long.txt", 62 | user_template=_main_response_template, 63 | ) 64 | 65 | main_instruction = main_instruction_short 66 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/llm/prompts_text/main_instruction_long.txt: -------------------------------------------------------------------------------- 1 | You provide answers to questions based solely on the information provided below. 2 | Answer precisely and concisely, addressing only what is asked without extraneous details. 3 | If the information needed to answer isn't available in the provided context, respond with "I don't know.". 4 | Cite specific sources by filename whenever you reference data or excerpts from the provided context. 
5 | 6 | Input format: 7 | 8 | Context: 9 | {context} 10 | 11 | Question: 12 | {question} 13 | 14 | Example: 15 | User: 16 | Context: 17 | Sales data from the fourth quarter shows an increase in revenue from Region A, whereas Region B experienced a slight decline. (source: Q4_Sales_Report.txt) 18 | 19 | Question: 20 | Did revenue increase in Region A in the fourth quarter? 21 | 22 | Assistant: 23 | Yes, revenue in Region A increased in the fourth quarter according to the data provided in Q4_Sales_Report.txt. 24 | 25 | 26 | Example: 27 | User: 28 | Context: 29 | The new software update includes improvements to security protocols and user interface enhancements. (source: Update_Release_Notes.txt) 30 | 31 | Question: 32 | What does the new software update include? 33 | 34 | Assistant: 35 | The new software update includes improvements to security protocols and user interface enhancements, as detailed in Update_Release_Notes.txt. 36 | 37 | 38 | Example: 39 | User: 40 | Context: 41 | Employee satisfaction has significantly improved due to recent changes in workplace policies. (source: Employee_Feedback_2023.txt) 42 | 43 | Question: 44 | What has improved due to recent changes in workplace policies? 45 | 46 | Assistant: 47 | Employee satisfaction has significantly improved due to recent changes in workplace policies, as mentioned in Employee_Feedback_2023.txt. 48 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/doc_loader/tests/test_data/json/data.valid.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "item_id": "1", 4 | "title": "Title TEST 1", 5 | "content": "This is the content for item 1.", 6 | "summary": "Summary for item 1.", 7 | "description": "Description for item 1.", 8 | "created_date": "2023-01-01T00:00:00Z", 9 | "modified_date": "2023-01-02T00:00:00Z" 10 | }, 11 | { 12 | "item_id": "2", 13 | "title": "Title 2", 14 | "content": "This is the content for item 2.", 15 | "summary": "Summary for item 2.", 16 | "description": "Description for item 2.", 17 | "created_date": "2023-01-02T00:00:00Z", 18 | "modified_date": "2023-01-03T00:00:00Z" 19 | }, 20 | { 21 | "item_id": "3", 22 | "title": "Title 3", 23 | "content": "This is the content for item 3.", 24 | "summary": "Summary for item 3.", 25 | "description": "Description for item 3.", 26 | "created_date": "2023-01-03T00:00:00Z", 27 | "modified_date": "2023-01-04T00:00:00Z" 28 | }, 29 | { 30 | "item_id": "4", 31 | "title": "Title 4", 32 | "content": "This is the content for item 4.", 33 | "summary": "Summary for item 4.", 34 | "description": "Description for item 4.", 35 | "created_date": "2023-01-04T00:00:00Z", 36 | "modified_date": "2023-01-05T00:00:00Z" 37 | }, 38 | { 39 | "item_id": "5", 40 | "title": "Title 5", 41 | "content": "This is the content for item 5.", 42 | "summary": "Summary for item 5.", 43 | "description": "Description for item 5.", 44 | "created_date": "2023-01-05T00:00:00Z", 45 | "modified_date": "2023-01-06T00:00:00Z" 46 | }, 47 | { 48 | "item_id": "6", 49 | "title": "Title 6", 50 | "content": "This is the content for item 6.", 51 | "summary": "Summary for item 6.", 52 | "description": "Description for item 6.", 53 | "created_date": "2023-01-05T00:00:00Z", 54 | "modified_date": "2023-01-06T00:00:00Z" 55 | } 56 | ] -------------------------------------------------------------------------------- /rag_experiment_accelerator/io/local/loaders/local_loader.py:
-------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import pathlib 3 | 4 | from rag_experiment_accelerator.io.loader import Loader 5 | from rag_experiment_accelerator.io.local.base import LocalIOBase 6 | 7 | 8 | class LocalLoader(LocalIOBase, Loader): 9 | """ 10 | A class that represents a local data loader. 11 | 12 | This class provides methods for loading data from a local source. 13 | 14 | Attributes: 15 | None 16 | 17 | Methods: 18 | load(src: str, **kwargs) -> list: 19 | Abstract method to load data from a local source. 20 | 21 | can_handle(src: str) -> bool: 22 | Abstract method to check if the loader can handle the given source. 23 | 24 | _get_file_ext(path: str): 25 | Internal method to get the file extension from a given path. 26 | """ 27 | 28 | @abstractmethod 29 | def load(self, src: str, **kwargs) -> list: 30 | """ 31 | Abstract method to load data from a local source. 32 | 33 | Args: 34 | src (str): The path or source of the data. 35 | 36 | Returns: 37 | list: The loaded data. 38 | """ 39 | pass 40 | 41 | @abstractmethod 42 | def can_handle(self, src: str) -> bool: 43 | """ 44 | Abstract method to check if the loader can handle the given source. 45 | 46 | Args: 47 | src (str): The path or source of the data. 48 | 49 | Returns: 50 | bool: True if the loader can handle the source, False otherwise. 51 | """ 52 | pass 53 | 54 | def _get_file_ext(self, path: str): 55 | """ 56 | Internal method to get the file extension from a given path. 57 | 58 | Args: 59 | path (str): The path of the file. 60 | 61 | Returns: 62 | str: The file extension. 63 | """ 64 | return pathlib.Path(path).suffix 65 | -------------------------------------------------------------------------------- /03_querying.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import mlflow 3 | from azureml.pipeline import initialise_mlflow_client 4 | 5 | from rag_experiment_accelerator.checkpoint import init_checkpoint 6 | from rag_experiment_accelerator.config.config import Config 7 | from rag_experiment_accelerator.config.environment import Environment 8 | from rag_experiment_accelerator.config.paths import mlflow_run_name 9 | from rag_experiment_accelerator.run.querying import run 10 | from rag_experiment_accelerator.data_assets.data_asset import create_data_asset 11 | from rag_experiment_accelerator.artifact.handlers.query_output_handler import ( 12 | QueryOutputHandler, 13 | ) 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "--config_path", type=str, help="input: path to the config file" 20 | ) 21 | parser.add_argument( 22 | "--data_dir", 23 | type=str, 24 | help="input: path to the input data", 25 | default=None, # default is initialized in Config 26 | ) 27 | args, _ = parser.parse_known_args() 28 | 29 | environment = Environment.from_env_or_keyvault() 30 | config = Config.from_path( 31 | environment, 32 | args.config_path, 33 | ) 34 | mlflow_client = initialise_mlflow_client(environment, config) 35 | mlflow.set_experiment(config.experiment_name) 36 | 37 | handler = QueryOutputHandler(config.path.query_data_dir) 38 | init_checkpoint(config) 39 | 40 | for index_config in config.index.flatten(): 41 | with mlflow.start_run(run_name=mlflow_run_name(config.job_name)): 42 | run(environment, config, index_config, mlflow_client) 43 | 44 | index_name = index_config.index_name() 45 | create_data_asset( 46 | data_path=handler.get_output_path( 
47 | index_name, config.experiment_name, config.job_name 48 | ), 49 | data_asset_name=index_name, 50 | environment=environment, 51 | ) 52 | -------------------------------------------------------------------------------- /infra/shared/storekeys.bicep: -------------------------------------------------------------------------------- 1 | param keyVaultName string = '' 2 | param azureOpenAIName string = '' 3 | param documentIntelligenceName string = '' 4 | param azureAISearchName string = '' 5 | param rgName string = '' 6 | // Do not use _ in the key names as it is not allowed in the key vault secret name 7 | param openAIKeyName string = 'openai-api-key' 8 | param documentIntelligenceKeyName string = 'azure-document-intelligence-admin-key' 9 | param searchKeyName string = 'azure-search-admin-key' 10 | 11 | resource openAIKeySecret 'Microsoft.KeyVault/vaults/secrets@2022-07-01' = { 12 | parent: keyVault 13 | name: openAIKeyName 14 | properties: { 15 | contentType: 'string' 16 | value: listKeys( 17 | resourceId(subscription().subscriptionId, rgName, 'Microsoft.CognitiveServices/accounts', azureOpenAIName), 18 | '2023-05-01' 19 | ).key1 20 | } 21 | } 22 | 23 | resource documentIntelligenceKeySecret 'Microsoft.KeyVault/vaults/secrets@2022-07-01' = { 24 | parent: keyVault 25 | name: documentIntelligenceKeyName 26 | properties: { 27 | contentType: 'string' 28 | value: listKeys( 29 | resourceId( 30 | subscription().subscriptionId, 31 | rgName, 32 | 'Microsoft.CognitiveServices/accounts', 33 | documentIntelligenceName 34 | ), 35 | '2023-05-01' 36 | ).key1 37 | } 38 | } 39 | 40 | resource searchKeySecret 'Microsoft.KeyVault/vaults/secrets@2022-07-01' = { 41 | parent: keyVault 42 | name: searchKeyName 43 | properties: { 44 | contentType: 'string' 45 | value: listAdminKeys( 46 | resourceId(subscription().subscriptionId, rgName, 'Microsoft.Search/searchServices', azureAISearchName), 47 | '2023-11-01' 48 | ).primaryKey 49 | } 50 | } 51 | 52 | resource keyVault 'Microsoft.KeyVault/vaults@2022-07-01' existing = { 53 | name: keyVaultName 54 | } 55 | 56 | output SEARCH_KEY_NAME string = searchKeySecret.name 57 | output OPENAI_KEY_NAME string = openAIKeySecret.name 58 | output DOCUMENTINTELLIGENCE_KEY_NAME string = documentIntelligenceKeySecret.name 59 | -------------------------------------------------------------------------------- /01_index.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import mlflow 4 | 5 | from azureml.pipeline import initialise_mlflow_client 6 | 7 | from rag_experiment_accelerator.checkpoint import init_checkpoint 8 | from rag_experiment_accelerator.run.index import run 9 | from rag_experiment_accelerator.config.config import Config 10 | from rag_experiment_accelerator.config.environment import Environment 11 | from rag_experiment_accelerator.config.paths import get_all_file_paths, mlflow_run_name 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--config_path", type=str, help="input: path to the config file" 17 | ) 18 | parser.add_argument("--data_dir", type=str, help="input: path to the input data") 19 | parser.add_argument( 20 | "-s", 21 | "--sampling", 22 | action="store_true", 23 | help="input: run sampling. 
Avoid running on distributed compute", 24 | ) 25 | args, _ = parser.parse_known_args() 26 | 27 | environment = Environment.from_env_or_keyvault() 28 | config = Config.from_path(environment, args.config_path, args.data_dir) 29 | init_checkpoint(config) 30 | file_paths = get_all_file_paths(config.path.data_dir) 31 | mlflow_client = initialise_mlflow_client(environment, config) 32 | mlflow.set_experiment(config.experiment_name) 33 | 34 | do_sample = args.sampling 35 | index_dict = {"indexes": []} 36 | 37 | for index_config in config.index.flatten(): 38 | with mlflow.start_run(run_name=mlflow_run_name(f"index_job_{config.job_name}")): 39 | index_name = run( 40 | environment, config, index_config, file_paths, mlflow_client, do_sample 41 | ) 42 | index_dict["indexes"].append(index_name) 43 | 44 | # saves the list of index names locally, not used afterwards 45 | with open(config.path.generated_index_names_file, "w") as index_names_file: 46 | json.dump(index_dict, index_names_file, indent=4) 47 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/checkpoint/local_storage_checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import hashlib 4 | 5 | from typing import Any, List, Set 6 | from rag_experiment_accelerator.checkpoint.checkpoint import Checkpoint 7 | 8 | 9 | class LocalStorageCheckpoint(Checkpoint): 10 | """ 11 | A checkpoint implementation that stores the data in the local file system. 12 | """ 13 | 14 | def __init__(self, directory: str = "."): 15 | self.checkpoint_location = f"{directory}/checkpoints" 16 | os.makedirs(self.checkpoint_location, exist_ok=True) 17 | self.internal_ids: Set[str] = self._get_existing_checkpoint_ids() 18 | 19 | def _has_data(self, id: str, method) -> bool: 20 | checkpoint_id = self._build_internal_id(id, method) 21 | return checkpoint_id in self.internal_ids 22 | 23 | def _load(self, id: str, method) -> List: 24 | file_path = self._get_checkpoint_file_path(id, method) 25 | with open(file_path, "rb") as file: 26 | data = pickle.load(file) 27 | return data 28 | 29 | def _save(self, data: Any, id: str, method): 30 | file_path = self._get_checkpoint_file_path(id, method) 31 | with open(file_path, "wb") as file: 32 | pickle.dump(data, file) 33 | internal_id = self._build_internal_id(id, method) 34 | self.internal_ids.add(internal_id) 35 | 36 | def _get_checkpoint_file_path(self, id: str, method): 37 | checkpoint_id = self._build_internal_id(id, method) 38 | return f"{self.checkpoint_location}/{checkpoint_id}.pkl" 39 | 40 | def _build_internal_id(self, id: str, method): 41 | hashed_id = hashlib.sha256(id.encode()).hexdigest() 42 | return f"{method.__name__}___{hashed_id}" 43 | 44 | def _get_existing_checkpoint_ids(self) -> Set[str]: 45 | ids = set() 46 | file_names = os.listdir(self.checkpoint_location) 47 | 48 | for file_name in file_names: 49 | file_name = file_name.replace(".pkl", "") 50 | ids.add(file_name) 51 | 52 | return ids 53 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/llm/prompts_text/prompt_generate_hypothetical_document.txt: -------------------------------------------------------------------------------- 1 | You are a skilled research assistant capable of generating entire hypothetical documents that might contain answers to the given questions.
2 | Create documents that could exist based on the context or nature of the question, simulating different formats such as detailed narratives, structured tables, or loosely structured text based on what an expert document might look like. 3 | 4 | 5 | User: 6 | Discuss the future challenges and opportunities for small businesses in emerging markets. 7 | Assistant: 8 | This document examines both the challenges and opportunities facing small businesses in emerging markets over the next decade. Key challenges include limited access to capital and market volatility, while opportunities lie in increasing digitalization and growing consumer markets. Detailed analysis follows in the subsequent sections, providing a comprehensive overview with expert insights and projections. 9 | 10 | 11 | User: 12 | What are the current rates of deforestation in different regions of the Amazon? 13 | Assistant: 14 | Table: 15 | | Region | Deforestation Rate (sq km/year) | 16 | |---------------|---------------------------------| 17 | | Northern Amazon| 120 | 18 | | Southern Amazon| 95 | 19 | | Eastern Amazon | 110 | 20 | | Western Amazon | 85 | 21 | Summary: This table summarizes the annual rates of deforestation in various parts of the Amazon, highlighting significant environmental impacts. 22 | 23 | 24 | User: 25 | What are the latest theories in particle physics? 26 | Assistant: 27 | New Theories in Particle Physics 28 | - Lots of excitement around string theory updates. 29 | - Evidence pointing to possible fifth force of nature; details still unclear. 30 | - Recent experiments suggest discrepancies in the Standard Model; further investigation needed. 31 | Notes: The field is rapidly evolving, with many research papers published recently. Some data is still under review, making concrete conclusions difficult at this stage.
-------------------------------------------------------------------------------- /rag_experiment_accelerator/evaluation/search_eval.py: -------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | 3 | from rag_experiment_accelerator.evaluation.spacy_evaluator import ( 4 | SpacyEvaluator, 5 | ) 6 | from rag_experiment_accelerator.utils.logging import get_logger 7 | 8 | logger = get_logger(__name__) 9 | 10 | 11 | def evaluate_search_result( 12 | search_response: list, evaluation_content: str, evaluator: SpacyEvaluator 13 | ): 14 | content = [] 15 | 16 | # create list of all docs with their is_relevant result to calculate recall and precision 17 | is_relevant_results = [] 18 | for doc in search_response: 19 | is_relevant = evaluator.is_relevant(doc["content"], evaluation_content) 20 | is_relevant_results.append(is_relevant) 21 | 22 | recall_scores = [] 23 | precision_scores = [] 24 | recall_predictions = [False for _ in range(len(search_response))] 25 | precision_predictions = [True for _ in range(len(search_response))] 26 | for i, doc in enumerate(search_response): 27 | k = i + 1 28 | logger.info("++++++++++++++++++++++++++++++++++") 29 | logger.info(f"Content: {doc['content']}") 30 | logger.info(f"Search Score: {doc['@search.score']}") 31 | 32 | precision_score = round( 33 | metrics.precision_score( 34 | is_relevant_results[:k], precision_predictions[:k] 35 | ), 36 | 2, 37 | ) 38 | precision_scores.append(precision_score) 39 | logger.info(f"Precision Score: {precision_score}@{k}") 40 | 41 | recall_predictions[i] = is_relevant_results[i] 42 | recall_score = round( 43 | metrics.recall_score(is_relevant_results, recall_predictions), 2 44 | ) 45 | recall_scores.append(recall_score) 46 | logger.info(f"Recall Score: {recall_score}@{k}") 47 | 48 | # TODO: should we only append content when it is relevant? 49 | content.append(doc["content"]) 50 | 51 | eval_metrics = { 52 | "recall_scores": recall_scores, 53 | "precision_scores": precision_scores, 54 | } 55 | 56 | return content, eval_metrics 57 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/artifact/models/query_output.py: -------------------------------------------------------------------------------- 1 | class QueryOutput: 2 | """ 3 | Represents the output of a query. 4 | 5 | Attributes: 6 | rerank (bool): Indicates whether reranking is enabled. 7 | rerank_type (str): The type of reranking. 8 | cross_encoder_model (str): The model used for cross-encoding. 9 | llm_rerank_threshold (int): The threshold for reranking using LLM. 10 | retrieve_num_of_documents (int): The number of documents to retrieve. 11 | cross_encoder_at_k (int): The value of k for cross-encoder. 12 | question_count (int): The count of questions. 13 | actual (str): The actual output. 14 | expected (str): The expected output. 15 | search_type (str): The type of search. 16 | search_evals (list): The evaluations for search. 17 | context (str): The qna context of the query. 18 | retrieved_contexts (list): The list of retrieved contexts of the query. 19 | question (str): The question of the query. 
20 | """ 21 | 22 | def __init__( 23 | self, 24 | rerank: bool, 25 | rerank_type: str, 26 | cross_encoder_model: str, 27 | llm_rerank_threshold: int, 28 | retrieve_num_of_documents: int, 29 | cross_encoder_at_k: int, 30 | question_count: int, 31 | actual: str, 32 | expected: str, 33 | search_type: str, 34 | search_evals: list, 35 | context: str, 36 | retrieved_contexts: list, 37 | question: str, 38 | ): 39 | self.rerank = rerank 40 | self.rerank_type = rerank_type 41 | self.cross_encoder_model = cross_encoder_model 42 | self.llm_rerank_threshold = llm_rerank_threshold 43 | self.retrieve_num_of_documents = retrieve_num_of_documents 44 | self.cross_encoder_at_k = cross_encoder_at_k 45 | self.question_count = question_count 46 | self.actual = actual 47 | self.expected = expected 48 | self.search_type = search_type 49 | self.search_evals = search_evals 50 | self.context = context 51 | self.retrieved_contexts = retrieved_contexts 52 | self.question = question 53 | -------------------------------------------------------------------------------- /azureml/index.py: -------------------------------------------------------------------------------- 1 | from rag_experiment_accelerator.checkpoint import init_checkpoint 2 | import os 3 | import sys 4 | import argparse 5 | 6 | import mlflow 7 | 8 | project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 9 | sys.path.append(project_dir) 10 | 11 | from rag_experiment_accelerator.config.environment import Environment # noqa: E402 12 | from rag_experiment_accelerator.config.config import Config # noqa: E402 13 | from rag_experiment_accelerator.config.index_config import IndexConfig # noqa: E402 14 | from rag_experiment_accelerator.run.index import run as index_run # noqa: E402 15 | 16 | 17 | def init(): 18 | """Main function of the script.""" 19 | 20 | global args 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument( 24 | "--config_path", type=str, help="input: path to the config file" 25 | ) 26 | parser.add_argument("--data_dir", type=str, help="input: path to the data") 27 | parser.add_argument("--index_name", type=str, help="input: experiment index name") 28 | parser.add_argument( 29 | "--keyvault", 30 | type=str, 31 | help="input: keyvault to load the environment from", 32 | ) 33 | parser.add_argument( 34 | "--mlflow_tracking_uri", 35 | type=str, 36 | help="input: mlflow tracking uri to log to", 37 | ) 38 | parser.add_argument( 39 | "--index_name_path", 40 | type=str, 41 | help="output: path to write a file with index name", 42 | ) 43 | 44 | args, _ = parser.parse_known_args() 45 | 46 | global config 47 | global environment 48 | global index_config 49 | global mlflow_client 50 | 51 | environment = Environment.from_keyvault(args.keyvault) 52 | config = Config.from_path(environment, args.config_path, args.data_dir) 53 | init_checkpoint(config) 54 | 55 | index_config = IndexConfig.from_index_name(args.index_name) 56 | mlflow_client = mlflow.MlflowClient(args.mlflow_tracking_uri) 57 | 58 | 59 | def run(input_paths: list[str]) -> list[str]: 60 | global args 61 | global config 62 | global environment 63 | global index_config 64 | global mlflow_client 65 | 66 | index_run(environment, config, index_config, input_paths, mlflow_client) 67 | 68 | return [args.index_name] 69 | -------------------------------------------------------------------------------- /promptflow/rag-experiment-accelerator/env_setup.md: -------------------------------------------------------------------------------- 1 | # Promptflow Secret Setup 2 | 3 | ## 
Prerequisites 4 | Install the dev requirements and log in to the Azure CLI. 5 | ``` bash 6 | # Install the dev requirements 7 | pip install -r dev-requirements.txt 8 | 9 | # Log in to the Azure CLI 10 | az login 11 | ``` 12 | 13 | ## AzureML Connections 14 | A Custom Connection is a generic connection type that stores and manages credentials required for interacting with LLMs. It has two dictionaries: `secrets`, for values that are stored in Key Vault, and `configs`, for non-secret values that are stored in the AzureML workspace. 15 | 16 | 17 | You can create a custom connection in the AzureML workspace by following the instructions [here](https://learn.microsoft.com/en-us/azure/machine-learning/prompt-flow/tools-reference/python-tool?view=azureml-api-2#create-a-custom-connection). The required key-value pairs are listed below. 18 | 19 | The following variables must be set as secrets: 20 | - AZURE_SEARCH_ADMIN_KEY 21 | - OPENAI_API_KEY 22 | - AML_SUBSCRIPTION_ID 23 | - AML_RESOURCE_GROUP_NAME 24 | - AML_WORKSPACE_NAME 25 | 26 | The remaining variables must not be set as secrets: 27 | - AZURE_SEARCH_SERVICE_ENDPOINT 28 | - OPENAI_ENDPOINT 29 | - OPENAI_API_VERSION 30 | 31 | The following variables are optional: 32 | - AZURE_LANGUAGE_SERVICE_KEY - secret 33 | - AZURE_LANGUAGE_SERVICE_ENDPOINT - non-secret 34 | - LOGGING_LEVEL - non-secret 35 | 36 | ## Configuring your connection locally 37 | To configure promptflow to connect to AzureML, you need to update the top-level `.azureml/config.json` file with the `workspace_name`, `resource_group`, and `subscription_id` that your connection is stored in. You can find more information about this in the [documentation](https://microsoft.github.io/promptflow/how-to-guides/set-global-configs.html#azureml). 38 | 39 | To update the local promptflow connection provider to look for AzureML connections, run the following commands: 40 | ``` bash 41 | # Set your promptflow connection provider to azureml 42 | pf config set connection.provider=azureml 43 | 44 | # Verify that the connection appears 45 | pf connection list 46 | ``` 47 | Note: Depending on the context you're running the `pf` commands from, you may need to move the `.azureml` folder into the root of the repository.
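48 | 49 | For reference, below is a minimal sketch of how a Python tool inside a flow could read values from such a Custom Connection. The tool name, signature, and return value are illustrative assumptions (they are not part of this repository); only the `secrets` and `configs` dictionaries described above, the promptflow `tool` decorator, and the `CustomConnection` type are relied upon: 50 | ``` python 51 | from promptflow import tool 52 | from promptflow.connections import CustomConnection 53 | 54 | 55 | @tool 56 | def my_search_tool(connection: CustomConnection, query: str) -> str: 57 | # Non-secret values live in `configs`; secret values are resolved from Key Vault into `secrets`. 58 | endpoint = connection.configs["AZURE_SEARCH_SERVICE_ENDPOINT"] 59 | admin_key = connection.secrets["AZURE_SEARCH_ADMIN_KEY"] 60 | # ... use the endpoint and admin key to query the search index ... 61 | return f"would query {endpoint} for: {query}" 62 | ```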
63 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/evaluation/tests/test_search_eval.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | from rag_experiment_accelerator.evaluation.search_eval import ( 4 | evaluate_search_result, 5 | ) 6 | 7 | evaluation_content = "my content to evaluate" 8 | search_response = [ 9 | { 10 | "@search.score": 0.03755760192871094, 11 | "content": "this is the first chunk", 12 | }, 13 | { 14 | "@search.score": 0.029906954616308212, 15 | "content": "this is the second chunk", 16 | }, 17 | { 18 | "@search.score": 0.028612013906240463, 19 | "content": "this is the third chunk", 20 | }, 21 | ] 22 | 23 | 24 | def test_evaluate_search_result_calculates_precision_score(): 25 | with patch( 26 | "rag_experiment_accelerator.evaluation.spacy_evaluator.SpacyEvaluator" 27 | ) as evaluator: 28 | evaluator.is_relevant.side_effect = [True, False, True] 29 | 30 | _, evaluation = evaluate_search_result( 31 | search_response, evaluation_content, evaluator 32 | ) 33 | 34 | expected_precision = [1.0, 0.5, 0.67] 35 | for i, precision in enumerate(evaluation.get("precision_scores")): 36 | assert precision == expected_precision[i] 37 | 38 | 39 | def test_evaluate_search_result_calculates_recall_score(): 40 | with patch( 41 | "rag_experiment_accelerator.evaluation.spacy_evaluator.SpacyEvaluator" 42 | ) as evaluator: 43 | evaluator.is_relevant.side_effect = [True, False, True] 44 | 45 | _, evaluation = evaluate_search_result( 46 | search_response, evaluation_content, evaluator 47 | ) 48 | 49 | expected_recall = [0.5, 0.5, 1.0] 50 | for i, recall in enumerate(evaluation.get("recall_scores")): 51 | assert recall == expected_recall[i] 52 | 53 | 54 | def test_evaluate_search_result_returns_all_search_content(): 55 | with patch( 56 | "rag_experiment_accelerator.evaluation.spacy_evaluator.SpacyEvaluator" 57 | ) as evaluator: 58 | evaluator.is_relevant.side_effect = [True, False, True] 59 | 60 | content, _ = evaluate_search_result( 61 | search_response, evaluation_content, evaluator 62 | ) 63 | 64 | for i, doc in enumerate(search_response): 65 | assert doc["content"] == content[i] 66 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/checkpoint/README.md: -------------------------------------------------------------------------------- 1 | # Checkpoints 2 | 3 | ## What is a checkpoint? 4 | Checkpoints are used to skip the processing of data that has already been processed in previous runs. 5 | A checkpoint object is used to wrap methods, so that when a wrapped method is called with an ID it has already processed, the checkpoint returns the result of the previous execution instead of executing the method again. 6 | 7 | ## Usage 8 | 9 | ### 1. Initialize the checkpoint object: 10 | ```python 11 | init_checkpoint(config) 12 | ``` 13 | 14 | ### 2.
Wrap the method you want to cache with the checkpoint decorator: 15 | ```python 16 | @cache_with_checkpoint(id="arg2.id") 17 | def method(arg1, arg2): 18 | pass 19 | ``` 20 | 21 | or wrap the method using the checkpoint object: 22 | ```python 23 | get_checkpoint().load_or_run(method, arg2.id, arg1, arg2) 24 | ``` 25 | 26 | (arg2.id is the ID that uniquely identifies the call in this example) 27 | 28 | This call will check if the provided method has previously been executed with the given ID. If it has, it returns the cached result; otherwise, it executes the method with the given arguments and caches the result for future calls. 29 | 30 | ## Checkpoint types 31 | 32 | ### Checkpoint 33 | The base class for all checkpoints. It provides the basic functionality for initializing and retrieving the checkpoint instance. 34 | 35 | A Checkpoint object is a singleton, meaning only one checkpoint instance exists at a time. 36 | To create a new checkpoint instance (or to override the existing instance), use the `init_checkpoint` method, which creates a checkpoint object according to the provided configuration. 37 | 38 | To get the current checkpoint instance, use the `get_checkpoint` method. 39 | 40 | ### LocalStorageCheckpoint 41 | Checkpoint implementation for local executions of the pipeline (i.e. on the developer's machine). It uses the `pickle` library to serialize and persist method results to local storage. 42 | The checkpoint data is saved in the `artifacts/checkpoint` directory. 43 | 44 | ### NullCheckpoint 45 | Checkpoint implementation that does not cache any data. This is useful when you want to disable the checkpointing mechanism. 46 | 47 | ## Deleting Checkpoint data 48 | To delete the checkpoint data, simply run the following `Make` command: 49 | ```bash 50 | make clear_checkpoints 51 | ``` -------------------------------------------------------------------------------- /rag_experiment_accelerator/config/config_validator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from jsonschema import ValidationError, validate 4 | import requests 5 | 6 | schema_cache = {} 7 | 8 | 9 | def fetch_json_schema_from_url(schema_url): 10 | """Fetch the JSON schema from a URL.""" 11 | response = requests.get(schema_url, timeout=5) 12 | response.raise_for_status() 13 | return response.json() 14 | 15 | 16 | def fetch_json_schema_from_file(schema_file_path, source_file_path): 17 | """Fetch the JSON schema from a local file path.""" 18 | normalised_schema_path = get_normalised_schema_path( 19 | schema_file_path, source_file_path 20 | ) 21 | 22 | if not os.path.isfile(normalised_schema_path): 23 | raise ValueError(f"Local schema file not found: {normalised_schema_path}") 24 | 25 | with open(normalised_schema_path, "r", encoding="utf8") as schema_file: 26 | return json.load(schema_file) 27 | 28 | 29 | def get_normalised_schema_path(schema_file_path, source_file_path): 30 | source_dir = os.path.dirname(source_file_path) 31 | new_schema_file_path = os.path.join(source_dir, schema_file_path) 32 | return os.path.normpath(new_schema_file_path) 33 | 34 | 35 | def fetch_json_schema(schema_reference, source_file_path): 36 | """Fetch the JSON schema from a URL or local file path, with caching.""" 37 | if schema_reference in schema_cache: 38 | return schema_cache[schema_reference] 39 | 40 | schema = ( 41 | fetch_json_schema_from_url(schema_reference) 42 | if schema_reference.startswith(("http://", "https://")) 43 | else
fetch_json_schema_from_file(schema_reference, source_file_path) 44 | ) 45 | 46 | schema_cache[schema_reference] = schema 47 | return schema 48 | 49 | 50 | def validate_json_with_schema( 51 | json_data, source_file_path 52 | ) -> tuple[bool, ValidationError | None]: 53 | """Validate a JSON object using the schema specified in its $schema property.""" 54 | try: 55 | schema_reference = json_data.get("$schema") 56 | if not schema_reference: 57 | return True, None 58 | 59 | schema = fetch_json_schema(schema_reference, source_file_path) 60 | 61 | validate(instance=json_data, schema=schema) 62 | return True, None 63 | except ValidationError as ve: 64 | return False, ve 65 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/doc_loader/customJsonLoader.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | from langchain.docstore.document import Document 6 | from langchain.document_loaders.base import BaseLoader 7 | 8 | # Replaces langchain.document_loaders.JSONLoader to avoid the jq dependency, for Windows compatibility 9 | # Note: Does not currently support jsonl, which is what the seq_num metadata field tracks 10 | 11 | 12 | class CustomJSONLoader(BaseLoader): 13 | def __init__( 14 | self, 15 | file_path: Union[str, Path], 16 | keys_to_load: list[str] = ["content", "title"], 17 | strict_keys: bool = True, 18 | ): 19 | self.file_path = Path(file_path).resolve() 20 | self._keys_to_load = keys_to_load 21 | self._strict_keys = strict_keys 22 | 23 | def _load_schema_from_dict(self, data: dict) -> str: 24 | if self._keys_to_load is None: 25 | return data 26 | else: 27 | return_dict = {} 28 | for k in self._keys_to_load: 29 | value = data.get(k) 30 | if value is None and self._strict_keys: 31 | raise ValueError( 32 | f"JSON file at path {self.file_path} must contain the field '{k}'" 33 | ) 34 | return_dict[k] = value 35 | return return_dict 36 | 37 | def load(self) -> list[Document]: 38 | """Load and return documents from the JSON file.""" 39 | docs: list[Document] = [] 40 | # Load JSON file 41 | with self.file_path.open(encoding="utf-8") as f: 42 | data = json.load(f) 43 | page_content = [] 44 | 45 | if not isinstance(data, list): 46 | raise ValueError( 47 | f"JSON file at path: {self.file_path} must be a list of objects, and each object is expected to contain the fields {self._keys_to_load}" 48 | ) 49 | else: 50 | for entry in data: 51 | data_dict = self._load_schema_from_dict(entry) 52 | page_content.append(data_dict) 53 | 54 | metadata = { 55 | "source": str(self.file_path), 56 | } 57 | 58 | docs.append(Document(page_content=str(page_content), metadata=metadata)) 59 | return docs 60 | -------------------------------------------------------------------------------- /azureml/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import mlflow 5 | 6 | project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 7 | sys.path.append(project_dir) 8 | 9 | from rag_experiment_accelerator.config.environment import Environment # noqa: E402 10 | from rag_experiment_accelerator.config.config import Config # noqa: E402 11 | from rag_experiment_accelerator.config.index_config import IndexConfig # noqa: E402 12 | from rag_experiment_accelerator.run.evaluation import run as eval_run # noqa: E402 13 | 14 | 15 | def main(): 16 | """Main function of the script.""" 17 
| 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument( 20 | "--config_path", type=str, help="input: path to the config file" 21 | ) 22 | parser.add_argument( 23 | "--index_name_path", 24 | type=str, 25 | help="input: path to a file containing index name", 26 | ) 27 | parser.add_argument( 28 | "--query_result_dir", 29 | type=str, 30 | help="input: path to read results of querying from", 31 | ) 32 | parser.add_argument( 33 | "--keyvault", type=str, help="keyvault to load the environment from" 34 | ) 35 | parser.add_argument( 36 | "--mlflow_tracking_uri", 37 | type=str, 38 | help="input: mlflow tracking uri to log to", 39 | ) 40 | parser.add_argument( 41 | "--mlflow_parent_run_id", 42 | type=str, 43 | help="input: mlflow parent run id to connect nested run to", 44 | ) 45 | parser.add_argument( 46 | "--eval_result_dir", 47 | type=str, 48 | help="output: path to write results of evaluation to", 49 | ) 50 | args = parser.parse_args() 51 | 52 | environment = Environment.from_keyvault(args.keyvault) 53 | config = Config.from_path(environment, config_path=args.config_path) 54 | with open(args.index_name_path, "r") as f: 55 | index_name = f.readline() 56 | index_config = IndexConfig.from_index_name(index_name) 57 | 58 | config.path.query_data_dir = args.query_result_dir 59 | config.path.eval_data_dir = args.eval_result_dir 60 | 61 | mlflow_client = mlflow.MlflowClient(args.mlflow_tracking_uri) 62 | eval_run(environment, config, index_config, mlflow_client, name_suffix="_result") 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/run/tests/test_qa_generation.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | from rag_experiment_accelerator.run.qa_generation import run 4 | 5 | 6 | @patch("rag_experiment_accelerator.run.qa_generation.Environment") 7 | @patch("rag_experiment_accelerator.run.qa_generation.exists") 8 | @patch("rag_experiment_accelerator.run.qa_generation.load_documents") 9 | @patch("rag_experiment_accelerator.run.qa_generation.cluster") 10 | @patch("rag_experiment_accelerator.run.qa_generation.generate_qna") 11 | @patch("rag_experiment_accelerator.run.qa_generation.create_data_asset") 12 | def test_run( 13 | mock_create_data_asset, 14 | mock_generate_qna, 15 | mock_cluster, 16 | mock_load_documents, 17 | mock_exists, 18 | mock_environment, 19 | ): 20 | # Arrange 21 | data_dir = "test_data_dir" 22 | df_instance = MagicMock() 23 | 24 | mock_config = MagicMock() 25 | mock_config.index.sampling.sample_data = True 26 | mock_config.index.sampling.optimum_k = 3 27 | 28 | sampled_input_data_csv_path = f"{data_dir}/sampling/sampled_cluster_predictions_cluster_number_{mock_config.index.sampling.optimum_k}.csv" 29 | mock_config.path.sampled_cluster_predictions_path.return_value = ( 30 | sampled_input_data_csv_path 31 | ) 32 | mock_exists.return_value = False 33 | 34 | mock_load_documents.return_value = all_docs_instance = MagicMock() 35 | mock_cluster.return_value = all_docs_instance = MagicMock() 36 | mock_generate_qna.return_value = df_instance 37 | filepaths = ["file_path_one", "file_path_two"] 38 | 39 | # Act 40 | run(mock_environment, mock_config, filepaths) 41 | 42 | # Assert 43 | mock_load_documents.assert_called_once_with( 44 | mock_environment, 45 | mock_config.index.chunking.chunking_strategy, 46 | mock_config.data_formats, 47 | filepaths, 48 | 2000, 49 | 0, 50 | ) 51 | 
mock_generate_qna.assert_called_once_with( 52 | mock_environment, 53 | mock_config, 54 | all_docs_instance, 55 | mock_config.openai.azure_oai_chat_deployment_name, 56 | ) 57 | df_instance.to_json.assert_called_once_with( 58 | mock_config.path.eval_data_file, orient="records", lines=True 59 | ) 60 | mock_create_data_asset.assert_called_once_with( 61 | mock_config.path.eval_data_file, "eval_data", mock_environment 62 | ) 63 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/llm/prompt/qna_prompts.py: -------------------------------------------------------------------------------- 1 | import json 2 | from rag_experiment_accelerator.llm.prompt.prompt import ( 3 | StructuredWithCoTPrompt, 4 | StructuredPrompt, 5 | PromptTag, 6 | ) 7 | 8 | 9 | def qna_generation_validate(response: str) -> bool: 10 | response_json = json.loads(response) 11 | return ( 12 | isinstance(response_json, dict) 13 | and "question" in response_json 14 | and "answer" in response_json 15 | ) 16 | 17 | 18 | _response_template: str = """ 19 | Context: 20 | ${context} 21 | """ 22 | 23 | # TODO: Add selector for usage of long/short prompts 24 | generate_qna_long_single_context_instruction_prompt = StructuredWithCoTPrompt( 25 | system_message="generate_qna_long_single_context.txt", 26 | user_template=_response_template, 27 | tags={PromptTag.JSON, PromptTag.NonStrict}, 28 | validator=qna_generation_validate, 29 | ) 30 | 31 | # TODO: Add selector for usage of long/short prompts 32 | generate_qna_short_single_context_instruction_prompt = StructuredWithCoTPrompt( 33 | system_message="generate_qna_short_single_context.txt", 34 | user_template=_response_template, 35 | tags={PromptTag.JSON, PromptTag.NonStrict}, 36 | validator=qna_generation_validate, 37 | ) 38 | 39 | # TODO: Add selector for usage of long/short prompts 40 | generate_qna_long_multiple_context_instruction_prompt = StructuredWithCoTPrompt( 41 | system_message="generate_qna_long_multi_context.txt", 42 | user_template=_response_template, 43 | tags={PromptTag.JSON, PromptTag.NonStrict}, 44 | validator=qna_generation_validate, 45 | ) 46 | 47 | # TODO: Add selector for usage of long/short prompts 48 | generate_qna_short_multiple_context_instruction_prompt = StructuredWithCoTPrompt( 49 | system_message="generate_qna_short_multi_context.txt", 50 | user_template=_response_template, 51 | tags={PromptTag.JSON, PromptTag.NonStrict}, 52 | validator=qna_generation_validate, 53 | ) 54 | 55 | # TODO: Add selector for usage of long/short prompts 56 | generate_qna_short_single_context_no_cot_instruction_prompt = StructuredPrompt( 57 | system_message="generate_qna_short_single_context_no_cot.txt", 58 | user_template=_response_template, 59 | tags={PromptTag.JSON, PromptTag.NonStrict}, 60 | validator=qna_generation_validate, 61 | ) 62 | 63 | 64 | qna_generation_prompt = generate_qna_short_single_context_no_cot_instruction_prompt 65 | -------------------------------------------------------------------------------- /azureml/query.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | 5 | import mlflow 6 | 7 | project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 8 | sys.path.append(project_dir) 9 | 10 | from rag_experiment_accelerator.checkpoint import init_checkpoint # noqa: E402 11 | from rag_experiment_accelerator.config.environment import Environment # noqa: E402 12 | from rag_experiment_accelerator.config.config import 
Config # noqa: E402 13 | from rag_experiment_accelerator.config.index_config import IndexConfig # noqa: E402 14 | from rag_experiment_accelerator.run.querying import run as query_run # noqa: E402 15 | 16 | 17 | def main(): 18 | """Main function of the script.""" 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "--config_path", type=str, help="input: path to the config file" 23 | ) 24 | parser.add_argument( 25 | "--eval_data_path", type=str, help="input: path to the data to evaluate on" 26 | ) 27 | parser.add_argument( 28 | "--index_name_path", 29 | type=str, 30 | help="input: path to a file containing index name", 31 | ) 32 | parser.add_argument( 33 | "--keyvault", 34 | type=str, 35 | help="input: keyvault to load the environment from", 36 | ) 37 | parser.add_argument( 38 | "--mlflow_tracking_uri", 39 | type=str, 40 | help="input: mlflow tracking uri to log to", 41 | ) 42 | parser.add_argument( 43 | "--mlflow_parent_run_id", 44 | type=str, 45 | help="input: mlflow parent run id to connect nested run to", 46 | ) 47 | parser.add_argument( 48 | "--query_result_dir", 49 | type=str, 50 | help="output: path to write results of querying to", 51 | ) 52 | args = parser.parse_args() 53 | 54 | environment = Environment.from_keyvault(args.keyvault) 55 | 56 | config = Config.from_path(environment, args.config_path) 57 | config.path.eval_data_file = args.eval_data_path 58 | config.path.query_data_dir = args.query_result_dir 59 | init_checkpoint(config) 60 | 61 | with open(args.index_name_path, "r") as f: 62 | index_name = f.readline() 63 | index_config = IndexConfig.from_index_name(index_name) 64 | 65 | mlflow_client = mlflow.MlflowClient(args.mlflow_tracking_uri) 66 | query_run(environment, config, index_config, mlflow_client) 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/run/evaluation.py: -------------------------------------------------------------------------------- 1 | from typing import MutableMapping 2 | from azure.ai.ml import MLClient 3 | from dotenv import load_dotenv 4 | import mlflow 5 | 6 | from rag_experiment_accelerator.config.config import Config 7 | from rag_experiment_accelerator.config.index_config import IndexConfig 8 | from rag_experiment_accelerator.config.environment import Environment 9 | from rag_experiment_accelerator.evaluation import eval 10 | from rag_experiment_accelerator.utils.logging import get_logger 11 | 12 | 13 | load_dotenv(override=True) 14 | logger = get_logger(__name__) 15 | 16 | 17 | def _flatten_dict_gen(d, parent_key, sep): 18 | for k, v in d.items(): 19 | new_key = parent_key + sep + str(k) if parent_key else k 20 | if isinstance(v, MutableMapping): 21 | yield from flatten_dict(v, new_key, sep=sep).items() 22 | else: 23 | yield new_key, v 24 | 25 | 26 | def flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = "."): 27 | return dict(_flatten_dict_gen(d, parent_key, sep)) 28 | 29 | 30 | def get_job_hyper_params(config: Config, index_config: IndexConfig) -> dict: 31 | """ 32 | Returns the hyper parameters for the current job. 
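Illustrative example (a sketch only; the exact keys depend on the loaded config schema, since nested sections are flattened with a "." separator): params = get_job_hyper_params(config, index_config) # might yield e.g. {"index.chunking.chunk_size": [1000], "rerank.enabled": True}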
33 | """ 34 | config_dict = config.to_dict() 35 | 36 | # Remove not needed parameters 37 | for param in ["path", "main_instruction", "use_checkpoints"]: 38 | config_dict.__delitem__(param) 39 | 40 | config_flatten_dict = flatten_dict(config_dict) 41 | 42 | return config_flatten_dict 43 | 44 | 45 | def run( 46 | environment: Environment, 47 | config: Config, 48 | index_config: IndexConfig, 49 | mlflow_client: MLClient, 50 | name_suffix: str, 51 | ): 52 | """ 53 | Runs the evaluation process for the RAG experiment accelerator. 54 | 55 | This function initializes the configuration, sets up the ML client, and runs the evaluation process 56 | for all combinations of chunk sizes, overlap sizes, embedding dimensions, EF constructions, and EF searches. 57 | 58 | Returns: 59 | None 60 | """ 61 | logger.info(f"Evaluating Index: {index_config.index_name()}") 62 | 63 | params = get_job_hyper_params(config, index_config) 64 | mlflow.log_params(params) 65 | 66 | eval.evaluate_prompts( 67 | environment=environment, 68 | config=config, 69 | index_config=index_config, 70 | mlflow_client=mlflow_client, 71 | name_suffix=name_suffix, 72 | ) 73 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/checkpoint/tests/test_checkpoint.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | import pytest 3 | from unittest.mock import patch 4 | 5 | from rag_experiment_accelerator.checkpoint.checkpoint_factory import ( 6 | get_checkpoint, 7 | init_checkpoint, 8 | ) 9 | from rag_experiment_accelerator.checkpoint.local_storage_checkpoint import ( 10 | LocalStorageCheckpoint, 11 | ) 12 | from rag_experiment_accelerator.checkpoint.null_checkpoint import NullCheckpoint 13 | from rag_experiment_accelerator.config.config import ExecutionEnvironment 14 | 15 | 16 | @pytest.fixture 17 | def mock_checkpoints(): 18 | with patch.object( 19 | LocalStorageCheckpoint, "__init__", return_value=None 20 | ), patch.object(NullCheckpoint, "__init__", return_value=None): 21 | yield 22 | 23 | 24 | def test_get_checkpoint_without_init_fails(): 25 | with pytest.raises(Exception) as e_info: 26 | get_checkpoint() 27 | assert ( 28 | str(e_info.value) 29 | == "Checkpoint not initialized yet. Call init_checkpoint() first." 
30 | ) 31 | 32 | 33 | def test_get_checkpoint_for_local_executions(mock_checkpoints): 34 | config = MagicMock() 35 | config.execution_environment = ExecutionEnvironment.LOCAL 36 | config.use_checkpoints = True 37 | 38 | init_checkpoint(config) 39 | checkpoint = get_checkpoint() 40 | assert isinstance(checkpoint, LocalStorageCheckpoint) 41 | 42 | 43 | def test_get_checkpoint_for_azure_ml(mock_checkpoints): 44 | config = MagicMock() 45 | config.execution_environment = ExecutionEnvironment.AZURE_ML 46 | config.use_checkpoints = True 47 | 48 | init_checkpoint(config) 49 | checkpoint = get_checkpoint() 50 | # checkpoints are currently not supported for Azure ML, so it should return NullCheckpoint 51 | assert isinstance(checkpoint, NullCheckpoint) 52 | 53 | 54 | def test_get_checkpoint_when_should_not_use_checkpoints_locally(mock_checkpoints): 55 | config = MagicMock() 56 | config.execution_environment = ExecutionEnvironment.LOCAL 57 | config.use_checkpoints = False 58 | 59 | init_checkpoint(config) 60 | checkpoint = get_checkpoint() 61 | assert isinstance(checkpoint, NullCheckpoint) 62 | 63 | 64 | def test_get_checkpoint_when_should_not_use_checkpoints_in_azure_ml(mock_checkpoints): 65 | config = MagicMock() 66 | config.execution_environment = ExecutionEnvironment.AZURE_ML 67 | config.use_checkpoints = False 68 | 69 | init_checkpoint(config) 70 | checkpoint = get_checkpoint() 71 | assert isinstance(checkpoint, NullCheckpoint) 72 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/embedding/aoai_embedding_model.py: -------------------------------------------------------------------------------- 1 | from openai import AzureOpenAI 2 | 3 | from rag_experiment_accelerator.config.environment import Environment 4 | from rag_experiment_accelerator.embedding.embedding_model import EmbeddingModel 5 | 6 | 7 | class AOAIEmbeddingModel(EmbeddingModel): 8 | """ 9 | A class representing an AOAI Embedding Model. 10 | 11 | Args: 12 | model_name (str): The name of the deployment. 13 | environment (Environment): The initialized environment. 14 | dimension (int, optional): The dimension of the embedding. Defaults to 1536 which is the dimension of text-embedding-ada-002. 15 | **kwargs: Additional keyword arguments. 16 | 17 | Attributes: 18 | model_name (str): The name of the deployment. 19 | _client (AzureOpenAI): The initialized AzureOpenAI client. 20 | 21 | """ 22 | 23 | def __init__( 24 | self, 25 | model_name: str, 26 | environment: Environment, 27 | dimension: int = 1536, 28 | shorten_dimensions: bool = False, 29 | **kwargs 30 | ) -> None: 31 | super().__init__(name=model_name, dimension=dimension, **kwargs) 32 | self.model_name = model_name 33 | self.shorten_dimensions = shorten_dimensions 34 | self._client: AzureOpenAI = self._initialize_client(environment=environment) 35 | 36 | def _initialize_client(self, environment: Environment) -> AzureOpenAI: 37 | """ 38 | Initializes the AzureOpenAI client. 39 | 40 | Args: 41 | environment (Environment): The initialized environment. 42 | 43 | Returns: 44 | AzureOpenAI: The initialized AzureOpenAI client. 45 | 46 | """ 47 | return AzureOpenAI( 48 | azure_endpoint=environment.openai_endpoint, 49 | api_key=environment.openai_api_key, 50 | api_version=environment.openai_api_version, 51 | ) 52 | 53 | def generate_embedding(self, chunk: str) -> list[float]: 54 | """ 55 | Generates the embedding for a given chunk of text. 56 | 57 | Args: 58 | chunk (str): The input text. 59 | 60 | Returns: 61 | list[float]: The generated embedding.
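Example (an illustrative sketch; assumes an initialized ``Environment`` and a matching Azure OpenAI deployment): model = AOAIEmbeddingModel("text-embedding-ada-002", environment); vector = model.generate_embedding("hello world") # vector holds 1536 floats for text-embedding-ada-002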
62 | 63 | """ 64 | 65 | kwargs = {} 66 | if self.shorten_dimensions: 67 | kwargs["dimensions"] = self.dimension 68 | 69 | response = self._client.embeddings.create( 70 | input=chunk, model=self.model_name, **kwargs 71 | ) 72 | 73 | return response.data[0].embedding 74 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/llm/prompts_text/multiple_prompt_instruction.txt: -------------------------------------------------------------------------------- 1 | Your task is to take a question as input and generate a set of sub-questions that together cover all aspects of the original question. 2 | The output should be in strict JSON format, containing the list of sub-questions. 3 | 4 | Requirements: 5 | 1. Analyze the original question to identify key aspects or components. 6 | 2. Generate sub-questions that address each identified aspect. 7 | 3. Ensure the sub-questions collectively cover the entire scope of the original question. 8 | 4. Output the sub-questions as a single JSON array. 9 | 5. Produce a JSON output that is 100 percent structurally correct, with proper nesting, comma placement, and quotation marks. 10 | 6. Format the JSON with proper indentation for readability. 11 | 7. Ensure there is no trailing comma after the last element in the array. 12 | 8. Generate anywhere from 2 up to 10 sub-questions. 13 | 14 | 15 | User: 16 | How does climate change affect ocean biodiversity? 17 | 18 | Assistant: 19 | ["What impact do rising ocean temperatures have on marine species?", "How does ocean acidification affect coral reefs and shellfish populations?"] 20 | 21 | 22 | 23 | User: 24 | What are the key considerations when implementing AI technologies in healthcare? 25 | 26 | Assistant: 27 | ["What ethical concerns arise with the use of AI in patient care?", "How can AI improve diagnosis accuracy in healthcare?", "What are the data privacy implications of using AI in healthcare?", "How can AI be used to personalize patient treatment plans?", "What are the challenges of integrating AI with existing healthcare IT systems?", "How does AI impact the roles and responsibilities of healthcare professionals?", "What training is required for healthcare staff to effectively use AI tools?", "How can AI help in managing healthcare costs?", "What are the regulatory considerations for AI in healthcare?", "How can AI technologies enhance patient engagement and satisfaction?"] 28 | 29 | 30 | 31 | User: 32 | What should someone consider when starting an online business? 33 | 34 | Assistant: 35 | ["What are the key legal requirements for starting an online business?", "How should one choose the right platform for their online business?", "What are effective digital marketing strategies for a new online business?", "How does one handle logistics and supply chain management for an online store?", "What customer service practices should be implemented for online businesses?"] 36 | -------------------------------------------------------------------------------- /docs/configs-appendix.md: -------------------------------------------------------------------------------- 1 | # Understanding the config files 2 | 3 | ## Prerequisites 4 | Familiarity with [ReadMe configuration of elements](/README.md#Description-of-configuration-elements) 5 | 6 | ## Configuration links for more reading
7 | - Search Types 8 | - [Semantic Search][semantic search] 9 | - [Vector Search][vector search] 10 | - [Hybrid Search][hybrid search] 11 | - Chunking Strategies 12 | - [Size][Chunk Size] 13 | - [Overlap][Overlap] 14 | - [Embedding][Embeddings] 15 | - Models: The accelerator uses [Sentence Transformer][Sentence Transformer] to generate the embeddings, selecting a [pre-trained model][Transformer Models] based on the configured embedding dimension. 16 | - Dimensions: Each valid value maps to different models for embedding. 17 | - 384: [all-MiniLM-L6-v2][all-MiniLM-L6-v2] 18 | - 768: [all-mpnet-base-v2][all-mpnet-base-v2] 19 | - 1024: [bert-large-nli-mean-tokens][bert-large-nli-mean-tokens] 20 | - LLM Metrics calculated using scikit-learn in combination with the arithmetic mean 21 | - [Precision][precision score] 22 | - [Recall][recall score] 23 | - [Prompt Engineering][prompts] 24 | 25 | 26 | 27 | [Chunk Size]: https://learn.microsoft.com/en-us/azure/search/vector-search-how-to-chunk-documents#common-chunking-techniques 28 | [Overlap]: https://learn.microsoft.com/en-us/azure/search/vector-search-how-to-chunk-documents#content-overlap-considerations 29 | [Embeddings]: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/understand-embeddings 30 | [Sentence Transformer]: https://www.sbert.net/ 31 | [Transformer Models]: https://www.sbert.net/docs/pretrained_models.html 32 | [all-MiniLM-L6-v2]: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 33 | [all-mpnet-base-v2]: https://huggingface.co/sentence-transformers/all-mpnet-base-v2 34 | [bert-large-nli-mean-tokens]: https://huggingface.co/sentence-transformers/bert-large-nli-mean-tokens 35 | [prompts]: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/advanced-prompt-engineering?pivots=programming-language-chat-completions 36 | [recall score]: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html 37 | [precision score]: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn.metrics.precision_score 38 | [vector search]: https://learn.microsoft.com/en-us/azure/search/vector-search-overview 39 | [hybrid search]: https://learn.microsoft.com/en-us/azure/search/hybrid-search-overview 40 | [semantic search]: https://learn.microsoft.com/en-us/azure/search/semantic-search-overview -------------------------------------------------------------------------------- /rag_experiment_accelerator/config/path_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import os 3 | from rag_experiment_accelerator.config import paths 4 | from rag_experiment_accelerator.config.base_config import BaseConfig 5 | 6 | 7 | class Paths: 8 | ARTIFACTS_DIR = "artifacts" 9 | DATA_DIR = "data" 10 | EVAL_DATA_FILE = "eval_data.jsonl" 11 | GENERATED_INDEX_NAMES_FILE = "generated_index_names.jsonl" 12 | QUERY_DATA_DIR = "query_data" 13 | EVAL_DATA_DIR = "eval_score" 14 | SAMPLING_OUTPUT_DIR = "sampling" 15 | 16 | 17 | @dataclass 18 | class PathConfig(BaseConfig): 19 | artifacts_dir: str = "" 20 | data_dir: str = "" 21 | eval_data_file: str = "" 22 | eval_data_dir: str = "" 23 | generated_index_names_file: str = "" 24 | query_data_dir: str = "" 25 | sampling_output_dir: str = "" 26 | 27 | def initialize_paths(self, config_file_path: str, data_dir: str) -> None: 28 | self._config_dir = os.path.dirname(config_file_path) 29 | 30 | if not self.artifacts_dir: 31 | self.artifacts_dir = os.path.join(self._config_dir,
Paths.ARTIFACTS_DIR) 32 | paths.try_create_directory(self.artifacts_dir) 33 | 34 | if data_dir: 35 | self.data_dir = data_dir 36 | elif not self.data_dir: 37 | self.data_dir = os.path.join(self._config_dir, Paths.DATA_DIR) 38 | 39 | if not self.eval_data_file: 40 | self.eval_data_file = os.path.join(self.artifacts_dir, Paths.EVAL_DATA_FILE) 41 | 42 | if not self.generated_index_names_file: 43 | self.generated_index_names_file = os.path.join( 44 | self.artifacts_dir, Paths.GENERATED_INDEX_NAMES_FILE 45 | ) 46 | 47 | if not self.query_data_dir: 48 | self.query_data_dir = os.path.join(self.artifacts_dir, Paths.QUERY_DATA_DIR) 49 | paths.try_create_directory(self.query_data_dir) 50 | 51 | if not self.eval_data_dir: 52 | self.eval_data_dir = os.path.join(self.artifacts_dir, Paths.EVAL_DATA_DIR) 53 | paths.try_create_directory(self.eval_data_dir) 54 | 55 | if not self.sampling_output_dir: 56 | self.sampling_output_dir = os.path.join( 57 | self.artifacts_dir, Paths.SAMPLING_OUTPUT_DIR 58 | ) 59 | paths.try_create_directory(self.sampling_output_dir) 60 | 61 | def sampled_cluster_predictions_path(self, optimum_k: int) -> str: 62 | return os.path.join( 63 | self.sampling_output_dir, 64 | f"sampled_cluster_predictions_cluster_number_{optimum_k}.csv", 65 | ) 66 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/evaluation/transformer_based_metrics.py: -------------------------------------------------------------------------------- 1 | from sentence_transformers import SentenceTransformer 2 | from sklearn.metrics.pairwise import cosine_similarity 3 | 4 | 5 | # todo: can we remove this hardcoding and name the model in the config file? 6 | metric_type_model_mapping = { 7 | "bert_all_MiniLM_L6_v2": "all-MiniLM-L6-v2", 8 | "bert_base_nli_mean_tokens": "bert-base-nli-mean-tokens", 9 | "bert_large_nli_mean_tokens": "bert-large-nli-mean-tokens", 10 | "bert_large_nli_stsb_mean_tokens": "bert-large-nli-stsb-mean-tokens", 11 | "bert_distilbert_base_nli_stsb_mean_tokens": "distilbert-base-nli-stsb-mean-tokens", 12 | "bert_paraphrase_multilingual_MiniLM_L12_v2": "paraphrase-multilingual-MiniLM-L12-v2", 13 | } 14 | 15 | 16 | def compare_semantic_document_values(doc1, doc2, model_type): 17 | """ 18 | Compares the semantic values of two documents and returns their average semantic similarity as a percentage. 19 | 20 | Args: 21 | doc1 (str): The first document to compare. 22 | doc2 (str): The second document to compare. 23 | model_type (SentenceTransformer): The SentenceTransformer model to use for comparison. 24 | 25 | Returns: 26 | int: The average semantic similarity between the two documents, expressed as a percentage. 27 | """ 28 | differences = semantic_compare_values(doc1, doc2, model_type) 29 | 30 | return int(sum(differences) / len(differences)) 31 | 32 | 33 | def semantic_compare_values( 34 | value1: str, 35 | value2: str, 36 | model_type: SentenceTransformer, 37 | ) -> list[float]: 38 | """ 39 | Computes the semantic similarity between two values using a pre-trained SentenceTransformer model. 40 | 41 | Args: 42 | value1 (str): The first value to compare. 43 | value2 (str): The second value to compare. 44 | model_type (SentenceTransformer): The pre-trained SentenceTransformer model to use for encoding the values. 45 | 46 | Returns: 47 | A list of the similarity scores.
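Example (an illustrative sketch; the model weights are downloaded on first use): model = SentenceTransformer("all-MiniLM-L6-v2"); scores = semantic_compare_values("a cat", "a kitten", model) # a one-element list holding the cosine similarity scaled by 100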
48 | """ 49 | embedding1 = model_type.encode([str(value1)]) 50 | embedding2 = model_type.encode([str(value2)]) 51 | similarity_score = cosine_similarity(embedding1, embedding2) 52 | 53 | return [similarity_score * 100] 54 | 55 | 56 | def compute_transformer_based_score( 57 | actual, 58 | expected, 59 | metric_type, 60 | ): 61 | if metric_type not in metric_type_model_mapping: 62 | raise KeyError(f"Invalid metric type: {metric_type}") 63 | 64 | transformer = SentenceTransformer( 65 | f"sentence-transformers/{metric_type_model_mapping[metric_type]}" 66 | ) 67 | return compare_semantic_document_values(actual, expected, transformer) 68 | -------------------------------------------------------------------------------- /infra/shared/storage.bicep: -------------------------------------------------------------------------------- 1 | metadata description = 'Creates an Azure storage account.' 2 | param name string 3 | param location string = resourceGroup().location 4 | param tags object = {} 5 | 6 | @allowed([ 7 | 'Cool' 8 | 'Hot' 9 | 'Premium' ]) 10 | param accessTier string = 'Hot' 11 | param allowBlobPublicAccess bool = true 12 | param allowCrossTenantReplication bool = true 13 | param allowSharedKeyAccess bool = true 14 | param containers array = [] 15 | param defaultToOAuthAuthentication bool = false 16 | param deleteRetentionPolicy object = {} 17 | @allowed([ 'AzureDnsZone', 'Standard' ]) 18 | param dnsEndpointType string = 'Standard' 19 | param kind string = 'StorageV2' 20 | param minimumTlsVersion string = 'TLS1_2' 21 | param queues array = [] 22 | param supportsHttpsTrafficOnly bool = true 23 | param networkAcls object = { 24 | bypass: 'AzureServices' 25 | defaultAction: 'Allow' 26 | } 27 | @allowed([ 'Enabled', 'Disabled' ]) 28 | param publicNetworkAccess string = 'Enabled' 29 | param sku object = { name: 'Standard_LRS' } 30 | 31 | resource storage 'Microsoft.Storage/storageAccounts@2022-05-01' = { 32 | name: name 33 | location: location 34 | tags: tags 35 | kind: kind 36 | sku: sku 37 | properties: { 38 | accessTier: accessTier 39 | allowBlobPublicAccess: allowBlobPublicAccess 40 | allowCrossTenantReplication: allowCrossTenantReplication 41 | allowSharedKeyAccess: allowSharedKeyAccess 42 | defaultToOAuthAuthentication: defaultToOAuthAuthentication 43 | dnsEndpointType: dnsEndpointType 44 | minimumTlsVersion: minimumTlsVersion 45 | networkAcls: networkAcls 46 | publicNetworkAccess: publicNetworkAccess 47 | supportsHttpsTrafficOnly: supportsHttpsTrafficOnly 48 | } 49 | 50 | resource blobServices 'blobServices' = if (!empty(containers)) { 51 | name: 'default' 52 | properties: { 53 | deleteRetentionPolicy: deleteRetentionPolicy 54 | } 55 | resource container 'containers' = [for container in containers: { 56 | name: container.name 57 | properties: { 58 | publicAccess: contains(container, 'publicAccess') ? 
container.publicAccess : 'None' 59 | } 60 | }] 61 | } 62 | 63 | resource queueServices 'queueServices' = if (!empty(queues)) { 64 | name: 'default' 65 | properties: { 66 | cors: { 67 | corsRules: [] 68 | } 69 | } 70 | resource queue 'queues' = [for queue in queues: { 71 | name: queue.name 72 | properties: { 73 | metadata: {} 74 | } 75 | }] 76 | } 77 | } 78 | 79 | output name string = storage.name 80 | output id string = storage.id 81 | output primaryEndpoints object = storage.properties.primaryEndpoints 82 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 
40 | 41 | 42 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/doc_loader/structuredLoader.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from langchain.document_loaders.base import BaseLoader 4 | from langchain.text_splitter import RecursiveCharacterTextSplitter 5 | 6 | from rag_experiment_accelerator.utils.logging import get_logger 7 | 8 | logger = get_logger(__name__) 9 | 10 | 11 | def load_structured_files( 12 | file_format: str, 13 | language: str, 14 | loader: BaseLoader, 15 | file_paths: list[str], 16 | chunk_size: int, 17 | overlap_size: int, 18 | loader_kwargs: dict[any] = None, 19 | ): 20 | """ 21 | Load and process structured files. 22 | 23 | Args: 24 | file_format (str): The file_format of the documents to be loaded. 25 | language (str): The language of the documents to be loaded. 26 | loader (BaseLoader): The document loader object that reads the files. 27 | file_paths (list[str]): The paths to the files to load. 28 | chunk_size (int): The size of the chunks to split the documents into. 29 | overlap_size (int): The size of the overlapping parts between chunks. 30 | loader_kwargs (dict[any]): Extra arguments to loader. 31 | 32 | Returns: 33 | list[Document]: A list of processed and split document chunks. 34 | """ 35 | 36 | logger.info(f"Loading {file_format} files") 37 | 38 | documents = [] 39 | if loader_kwargs is None: 40 | loader_kwargs = {} 41 | 42 | for file in file_paths: 43 | documents += loader(file, **loader_kwargs).load() 44 | 45 | logger.debug(f"Loaded {len(documents)} {file_format} files") 46 | if language is None: 47 | text_splitter = RecursiveCharacterTextSplitter( 48 | chunk_size=chunk_size, 49 | chunk_overlap=overlap_size, 50 | length_function=len, 51 | ) 52 | else: 53 | text_splitter = RecursiveCharacterTextSplitter.from_language( 54 | language=language, 55 | chunk_size=chunk_size, 56 | chunk_overlap=overlap_size, 57 | ) 58 | 59 | logger.debug( 60 | f"Splitting {file_format} files into chunks of {chunk_size} characters with an overlap of {overlap_size} characters" 61 | ) 62 | 63 | docs = text_splitter.split_documents(documents) 64 | docsList = [] 65 | for doc in docs: 66 | docsList.append( 67 | {str(uuid.uuid4()): {"content": doc.page_content, "metadata": doc.metadata}} 68 | ) 69 | 70 | logger.info(f"Split {len(documents)} {file_format} files into {len(docs)} chunks") 71 | 72 | return docsList 73 | -------------------------------------------------------------------------------- /.github/workflows/rag_exp_acc_ci.yml: -------------------------------------------------------------------------------- 1 | name: RAG Experiment Accelerator CI 2 | 3 | on: 4 | workflow_call: 5 | workflow_dispatch: 6 | pull_request: 7 | types: [opened, ready_for_review, synchronize] 8 | branches: 9 | - main 10 | - development 11 | - prerelease 12 | push: 13 | branches: 14 | - main 15 | - development 16 | - prerelease 17 | merge_group: 18 | 19 | concurrency: 20 | group: ${{ github.workflow }}-${{ github.ref }} 21 | cancel-in-progress: ${{ github.ref != 'refs/heads/development' && github.ref != 'refs/heads/main' && github.ref != 'refs/heads/prerelease'}} 22 | 23 | jobs: 24 | execute-code-and-check: 25 | env: 26 | AZURE_SEARCH_ADMIN_KEY: ${{ secrets.AZURE_SEARCH_ADMIN_KEY }} 27
| AZURE_SEARCH_SERVICE_ENDPOINT: ${{ secrets.AZURE_SEARCH_SERVICE_ENDPOINT }} 28 | AZURE_SEARCH_USE_SEMANTIC_SEARCH: "true" 29 | AZURE_LANGUAGE_SERVICE_KEY: ${{ secrets.AZURE_LANGUAGE_SERVICE_KEY }} 30 | AZURE_LANGUAGE_SERVICE_ENDPOINT: ${{ secrets.AZURE_LANGUAGE_SERVICE_ENDPOINT }} 31 | OPENAI_API_TYPE: "azure" 32 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 33 | OPENAI_API_VERSION: ${{ secrets.OPENAI_API_VERSION }} 34 | OPENAI_ENDPOINT: ${{ secrets.OPENAI_ENDPOINT }} 35 | AML_RESOURCE_GROUP_NAME: ${{ secrets.RESOURCE_GROUP_NAME }} 36 | AML_SUBSCRIPTION_ID: ${{ secrets.SUBSCRIPTION_ID }} 37 | AML_WORKSPACE_NAME: ${{ secrets.WORKSPACE_NAME }} 38 | AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT: "" 39 | AZURE_DOCUMENT_INTELLIGENCE_ADMIN_KEY: "" 40 | name: code validation through execution 41 | runs-on: ubuntu-latest 42 | steps: 43 | - name: Checkout Actions 44 | uses: actions/checkout@v4 45 | - name: Azure login 46 | uses: azure/login@v2 47 | with: 48 | creds: ${{ secrets.azure_credentials }} 49 | - name: Configure Azure ML Agent 50 | uses: ./.github/actions/configure_azureml_agent 51 | - name: execute index creation step 52 | shell: bash 53 | run: | 54 | python 01_index.py --data_dir='data-ci' --config_path=${{ github.workspace }}/.github/workflows/config.json 55 | - name: execute qna step 56 | shell: bash 57 | run: | 58 | python 02_qa_generation.py --data_dir='data-ci' --config_path=${{ github.workspace }}/.github/workflows/config.json 59 | - name: execute querying step 60 | shell: bash 61 | run: | 62 | python 03_querying.py --data_dir='data-ci' --config_path=${{ github.workspace }}/.github/workflows/config.json 63 | - name: execute evaluation step 64 | shell: bash 65 | run: | 66 | python 04_evaluation.py --data_dir='data-ci' --config_path=${{ github.workspace }}/.github/workflows/config.json 67 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/run/qa_generation.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from os.path import exists 3 | 4 | from dotenv import load_dotenv 5 | 6 | from rag_experiment_accelerator.config.config import Config 7 | from rag_experiment_accelerator.config.environment import Environment 8 | from rag_experiment_accelerator.data_assets.data_asset import create_data_asset 9 | from rag_experiment_accelerator.doc_loader.documentLoader import load_documents 10 | from rag_experiment_accelerator.ingest_data.acs_ingest import generate_qna 11 | from rag_experiment_accelerator.utils.logging import get_logger 12 | from rag_experiment_accelerator.sampling.clustering import ( 13 | dataframe_to_chunk_dict, 14 | load_parser, 15 | ) 16 | from rag_experiment_accelerator.sampling.clustering import cluster 17 | 18 | load_dotenv(override=True) 19 | 20 | logger = get_logger(__name__) 21 | 22 | 23 | def run( 24 | environment: Environment, 25 | config: Config, 26 | file_paths: list[str], 27 | ): 28 | """ 29 | Runs the main experiment loop for the QA generation process using the provided configuration and data. 
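Illustrative call (a sketch; assumes an initialized ``Environment`` and ``Config``, using the sample document shipped in the repository's data folder): run(environment, config, ["data/pdf/sample-pdf.pdf"])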
30 | 31 | Returns: 32 | None 33 | """ 34 | logger.info("Running QA generation") 35 | 36 | all_docs = {} 37 | # Check if we have already sampled 38 | if config.index.sampling.sample_data: 39 | logger.info("Running QA Generation process with sampling") 40 | sampled_cluster_predictions_path = config.path.sampled_cluster_predictions_path( 41 | config.index.sampling.optimum_k 42 | ) 43 | if exists(sampled_cluster_predictions_path): 44 | df = pd.read_csv(sampled_cluster_predictions_path) 45 | all_docs = dataframe_to_chunk_dict(df) 46 | logger.info("Loaded sampled data") 47 | else: 48 | all_docs = load_documents( 49 | environment, 50 | config.index.chunking.chunking_strategy, 51 | config.data_formats, 52 | file_paths, 53 | 2000, 54 | 0, 55 | ) 56 | parser = load_parser() 57 | all_docs = cluster( 58 | "", all_docs, config.path.sampling_output_dir, config, parser 59 | ) 60 | else: 61 | all_docs = load_documents( 62 | environment, 63 | config.index.chunking.chunking_strategy, 64 | config.data_formats, 65 | file_paths, 66 | 2000, 67 | 0, 68 | ) 69 | 70 | # generate qna 71 | df = generate_qna( 72 | environment, config, all_docs, config.openai.azure_oai_chat_deployment_name 73 | ) 74 | # write to jsonl 75 | df.to_json(config.path.eval_data_file, orient="records", lines=True) 76 | # create data asset in mlstudio 77 | create_data_asset(config.path.eval_data_file, "eval_data", environment) 78 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/embedding/tests/test_aoai_embedding_model.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch, MagicMock 2 | 3 | from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage 4 | from openai.types.embedding import Embedding 5 | 6 | from rag_experiment_accelerator.embedding.aoai_embedding_model import AOAIEmbeddingModel 7 | 8 | 9 | @patch( 10 | "rag_experiment_accelerator.embedding.aoai_embedding_model.AOAIEmbeddingModel._initialize_client" 11 | ) 12 | def test_generate_embedding(mock_client): 13 | expected_embeddings = Embedding( 14 | embedding=[0.1, 0.2, 0.3], index=0, object="embedding" 15 | ) 16 | mock_embeddings = CreateEmbeddingResponse( 17 | data=[expected_embeddings], 18 | model="model_name", 19 | object="list", 20 | usage=Usage(prompt_tokens=0, total_tokens=0), 21 | ) 22 | 23 | mock_client().embeddings.create.return_value = mock_embeddings 24 | 25 | environment = MagicMock() 26 | model = AOAIEmbeddingModel("text-embedding-ada-002", environment=environment) 27 | embeddings = model.generate_embedding("Hello world") 28 | assert embeddings == mock_embeddings.data[0].embedding 29 | 30 | 31 | def test_embedding_dimension_has_default(): 32 | environment = MagicMock() 33 | model = AOAIEmbeddingModel("text-embedding-ada-002", environment) 34 | assert model.dimension == 1536 35 | 36 | 37 | def test_can_set_embedding_dimension(): 38 | environment = MagicMock() 39 | model = AOAIEmbeddingModel("text-embedding-ada-002", environment, 123) 40 | assert model.dimension == 123 41 | 42 | 43 | @patch( 44 | "rag_experiment_accelerator.embedding.aoai_embedding_model.AOAIEmbeddingModel._initialize_client" 45 | ) 46 | def test_generate_embeddings_no_shortening(mock_client): 47 | mock_client().embeddings.create.return_value = MagicMock() 48 | environment = MagicMock() 49 | 50 | model = AOAIEmbeddingModel( 51 | "text-embedding-3-large", environment=environment, dimension=3072 52 | ) 53 | model.generate_embedding("Hello world") 54 | 55 |
mock_client().embeddings.create.assert_called_with( 56 | input="Hello world", model="text-embedding-3-large" 57 | ) 58 | 59 | 60 | @patch( 61 | "rag_experiment_accelerator.embedding.aoai_embedding_model.AOAIEmbeddingModel._initialize_client" 62 | ) 63 | def test_generate_embeddings_with_shortening(mock_client): 64 | mock_client().embeddings.create.return_value = MagicMock() 65 | environment = MagicMock() 66 | 67 | model = AOAIEmbeddingModel( 68 | "text-embedding-3-large", 69 | environment=environment, 70 | dimension=256, 71 | shorten_dimensions=True, 72 | ) 73 | model.generate_embedding("Hello world") 74 | 75 | mock_client().embeddings.create.assert_called_with( 76 | input="Hello world", model="text-embedding-3-large", dimensions=256 77 | ) 78 | -------------------------------------------------------------------------------- /.github/workflows/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "../../config.schema.json", 3 | "experiment_name": "baseline", 4 | "job_name": "baseline_job", 5 | "job_description": "", 6 | "data_formats": ["*"], 7 | "main_instruction": "", 8 | "use_checkpoints": true, 9 | "path": {}, 10 | "index": { 11 | "index_name_prefix": "ci", 12 | "ef_construction": [400], 13 | "ef_search": [400], 14 | "chunking": { 15 | "preprocess": false, 16 | "chunk_size": [1000], 17 | "overlap_size": [200], 18 | "generate_title": false, 19 | "generate_summary": false, 20 | "override_content_with_summary": false, 21 | "chunking_strategy": "basic", 22 | "azure_document_intelligence_model": "prebuilt-read" 23 | }, 24 | "embedding_model": [ 25 | { 26 | "type": "sentence-transformer", 27 | "model_name": "all-mpnet-base-v2" 28 | } 29 | ], 30 | "sampling": { 31 | "sample_data": false, 32 | "percentage": 5, 33 | "optimum_k": "auto", 34 | "min_cluster": 2, 35 | "max_cluster": 30 36 | } 37 | }, 38 | "language":{ 39 | "analyzer": { 40 | "analyzer_name": "en.microsoft", 41 | "index_analyzer_name": "", 42 | "search_analyzer_name": "", 43 | "char_filters": [], 44 | "tokenizers": [], 45 | "token_filters": [] 46 | }, 47 | "query_language": "en-us" 48 | }, 49 | "rerank": { 50 | "enabled": true, 51 | "type": "cross_encoder", 52 | "llm_rerank_threshold": 3, 53 | "cross_encoder_at_k": 4, 54 | "cross_encoder_model": "cross-encoder/stsb-roberta-base" 55 | }, 56 | "search": { 57 | "retrieve_num_of_documents": 5, 58 | "search_type": [ 59 | "search_for_manual_hybrid", 60 | "search_for_match_Hybrid_multi", 61 | "search_for_match_semantic" 62 | ], 63 | "search_relevancy_threshold": 0.8 64 | }, 65 | "query_expansion": { 66 | "query_expansion": true, 67 | "hyde": "generated_hypothetical_answer", 68 | "min_query_expansion_related_question_similarity_score": 90, 69 | "expand_to_multiple_questions": true 70 | }, 71 | "openai": { 72 | "azure_oai_chat_deployment_name": "gpt-35-turbo", 73 | "azure_oai_eval_deployment_name": "gpt-35-turbo", 74 | "temperature": 0 75 | }, 76 | "eval": { 77 | "metric_types": [ 78 | "fuzzy_score", 79 | "cosine_ochiai", 80 | "rouge2_recall", 81 | "bert_all_MiniLM_L6_v2", 82 | "bert_distilbert_base_nli_stsb_mean_tokens", 83 | "llm_answer_relevance", 84 | "llm_context_precision" 85 | ] 86 | } 87 | } -------------------------------------------------------------------------------- /config.sample.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://raw.githubusercontent.com/microsoft/rag-experiment-accelerator/development/config.schema.json", 3 | "experiment_name": 
"exp-name", 4 | "job_name": "job-name", 5 | "job_description": "", 6 | "data_formats": ["*"], 7 | "main_instruction": "", 8 | "use_checkpoints": true, 9 | "path": {}, 10 | "index": { 11 | "index_name_prefix": "idx", 12 | "ef_construction": [400], 13 | "ef_search": [400], 14 | "chunking": { 15 | "preprocess": false, 16 | "chunk_size": [1000], 17 | "overlap_size": [200], 18 | "generate_title": false, 19 | "generate_summary": false, 20 | "override_content_with_summary": false, 21 | "chunking_strategy": "basic", 22 | "azure_document_intelligence_model": "prebuilt-read" 23 | }, 24 | "embedding_model": [ 25 | { 26 | "type": "sentence-transformer", 27 | "model_name": "all-mpnet-base-v2" 28 | } 29 | ], 30 | "sampling": { 31 | "sample_data": false, 32 | "percentage": 5, 33 | "optimum_k": "auto", 34 | "min_cluster": 2, 35 | "max_cluster": 30, 36 | "only_sample": false 37 | } 38 | }, 39 | "language": { 40 | "analyzer": { 41 | "analyzer_name": "en.microsoft", 42 | "index_analyzer_name": "", 43 | "search_analyzer_name": "", 44 | "char_filters": [], 45 | "tokenizers": [], 46 | "token_filters": [] 47 | }, 48 | "query_language": "en-us" 49 | }, 50 | "rerank": { 51 | "enabled": true, 52 | "type": "cross_encoder", 53 | "llm_rerank_threshold": 3, 54 | "cross_encoder_at_k": 4, 55 | "cross_encoder_model": "cross-encoder/stsb-roberta-base" 56 | }, 57 | "search": { 58 | "retrieve_num_of_documents": 5, 59 | "search_type": [ 60 | "search_for_manual_hybrid", 61 | "search_for_match_Hybrid_multi", 62 | "search_for_match_semantic" 63 | ], 64 | "search_relevancy_threshold": 0.8 65 | }, 66 | "query_expansion": { 67 | "hyde": "disabled", 68 | "query_expansion": false, 69 | "min_query_expansion_related_question_similarity_score": 90, 70 | "expand_to_multiple_questions": false 71 | }, 72 | "openai": { 73 | "azure_oai_chat_deployment_name": "gpt-35-turbo", 74 | "azure_oai_eval_deployment_name": "gpt-35-turbo", 75 | "temperature": 0 76 | }, 77 | "eval": { 78 | "metric_types": [ 79 | "fuzzy_score", 80 | "bert_all_MiniLM_L6_v2", 81 | "cosine_ochiai", 82 | "bert_distilbert_base_nli_stsb_mean_tokens", 83 | "llm_answer_relevance", 84 | "llm_context_precision" 85 | ] 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/artifact/handlers/tests/test_artifact_handler.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | import pytest 3 | 4 | from rag_experiment_accelerator.artifact.handlers.artifact_handler import ( 5 | ArtifactHandler, 6 | ) 7 | from rag_experiment_accelerator.artifact.handlers.exceptions import LoadException 8 | 9 | 10 | def test_loads(): 11 | data = "This is test data" 12 | mock_writer = Mock() 13 | mock_loader = Mock() 14 | mock_loader.can_handle.return_value = True 15 | mock_loader.load.return_value = [data] 16 | 17 | handler = ArtifactHandler("data_location", writer=mock_writer, loader=mock_loader) 18 | 19 | name = "test.jsonl" 20 | loaded_data = handler.load(name) 21 | 22 | assert loaded_data == [data] 23 | 24 | 25 | def test_save_dict(): 26 | mock_writer = Mock() 27 | mock_loader = Mock() 28 | 29 | handler = ArtifactHandler("data_location", writer=mock_writer, loader=mock_loader) 30 | 31 | dict_to_save = {"testing": 123, "mic": "check"} 32 | artifact_name = "test.jsonl" 33 | handler.save_dict(dict_to_save, "test.jsonl") 34 | path = f"{handler.data_location}/{artifact_name}" 35 | 36 | assert mock_writer.write.call_count == 1 37 | assert 
mock_writer.write.call_args.args == (dict_to_save, path) 38 | 39 | 40 | def test_loads_raises_no_data_returned(): 41 | mock_writer = Mock() 42 | mock_loader = Mock() 43 | mock_loader.can_handle.return_value = True 44 | mock_loader.load.return_value = [] 45 | handler = ArtifactHandler("data_location", writer=mock_writer, loader=mock_loader) 46 | name = "test.jsonl" 47 | 48 | with pytest.raises(LoadException): 49 | handler.load(name) 50 | 51 | 52 | def test_load_raises_cant_handle(): 53 | mock_writer = Mock() 54 | mock_loader = Mock() 55 | handler = ArtifactHandler("data_location", writer=mock_writer, loader=mock_loader) 56 | 57 | mock_loader.can_handle.return_value = False 58 | 59 | with pytest.raises(LoadException): 60 | handler.load("test.txt") 61 | 62 | 63 | def test_handle_archive(): 64 | mock_writer = Mock() 65 | mock_loader = Mock() 66 | mock_writer.exists.return_value = True 67 | data_location = "data_location" 68 | handler = ArtifactHandler(data_location, writer=mock_writer, loader=mock_loader) 69 | 70 | name = "test.jsonl" 71 | dest = handler.handle_archive(name) 72 | 73 | src = f"{data_location}/{name}" 74 | mock_writer.copy.assert_called_once_with(src, dest) 75 | mock_writer.delete.assert_called_once_with(src) 76 | 77 | 78 | def test_handle_archive_no_op(): 79 | mock_writer = Mock() 80 | mock_loader = Mock() 81 | # only archive if the source exists 82 | mock_writer.exists.return_value = False 83 | handler = ArtifactHandler("data_location", writer=mock_writer, loader=mock_loader) 84 | 85 | dest = handler.handle_archive("test.jsonl") 86 | 87 | mock_writer.copy.assert_not_called() 88 | mock_writer.delete.assert_not_called() 89 | assert dest is None 90 | -------------------------------------------------------------------------------- /rag_experiment_accelerator/reranking/reranker.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from sentence_transformers import CrossEncoder 4 | 5 | from rag_experiment_accelerator.llm.prompt import rerank_prompt_instruction 6 | from rag_experiment_accelerator.llm.response_generator import ResponseGenerator 7 | from rag_experiment_accelerator.utils.logging import get_logger 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | def cross_encoder_rerank_documents( 13 | documents, user_prompt, output_prompt, model_name, k 14 | ): 15 | """ 16 | Reranks a list of documents based on their relevance to a user prompt using a cross-encoder model. 17 | 18 | Args: 19 | documents (list): A list of documents to be reranked. 20 | user_prompt (str): The user prompt to be used as the query. 21 | output_prompt (str): The output prompt to be used as the context (currently unused). 22 | model_name (str): The name of the pre-trained cross-encoder model to be used. 23 | k (int): The number of top documents to be returned. 24 | 25 | Returns: 26 | list: A list of the top k documents, sorted by their relevance to the user prompt.
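Example (an illustrative sketch; the cross-encoder weights are downloaded on first use): top_docs = cross_encoder_rerank_documents(["doc about cats", "doc about dogs", "doc about birds"], "which document mentions cats?", "", "cross-encoder/stsb-roberta-base", k=2)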
27 | """ 28 | if not documents: 29 | return [] 30 | 31 | model = CrossEncoder(model_name) 32 | cross_scores_ques = model.predict( 33 | [[user_prompt, item] for item in documents], 34 | apply_softmax=True, 35 | convert_to_numpy=True, 36 | ) 37 | 38 | top_indices_ques = cross_scores_ques.argsort()[-k:][::-1] 39 | sub_context = [] 40 | for idx in list(top_indices_ques): 41 | sub_context.append(documents[idx]) 42 | 43 | return sub_context 44 | 45 | 46 | def llm_rerank_documents( 47 | documents, question, response_generator: ResponseGenerator, rerank_threshold 48 | ): 49 | """ 50 | Reranks a list of documents based on a given question using the LLM model. 51 | 52 | Args: 53 | documents (list): A list of documents to be reranked. 54 | question (str): The question to be used for reranking. 55 | response_generator (ResponseGenerator): The initialised ResponseGenerator to use. 56 | rerank_threshold (int): The threshold for reranking documents. 57 | 58 | Returns: 59 | list: A list of reranked documents. 60 | """ 61 | rerank_context = "" 62 | for index, docs in enumerate(documents): 63 | rerank_context += "\ndocument " + str(index) + ":\n" 64 | rerank_context += docs + "\n" 65 | 66 | response: dict[str, int] | None = response_generator.generate_response( 67 | rerank_prompt_instruction, 68 | documents=rerank_context, 69 | question=question, 70 | prompt_last=True, 71 | ) 72 | 73 | logger.debug("Reranker response:\n", response) 74 | 75 | if response is None: 76 | return documents 77 | 78 | result = [] 79 | for key, _ in sorted(response.items(), key=lambda x: x[1], reverse=True): 80 | document_index = int(re.search(r"document_(\d+)", key)) 81 | result.append(documents[document_index]) 82 | 83 | return result 84 | --------------------------------------------------------------------------------