├── .circleci └── config.yml ├── .gitignore ├── RAGflow overview.drawio ├── README.md ├── TODO ├── app ├── .dockerignore ├── .streamlit │ └── config.toml ├── Dockerfile ├── __init__.py ├── main.py ├── page_apikeys.py ├── page_chat.py ├── page_dashboard.py ├── page_documentstore.py ├── page_filemanager.py ├── page_home.py ├── page_login.py ├── page_parameters.py ├── requirements.txt └── utils.py ├── docker-compose.dev.yaml ├── docker-compose.integration.test.yaml ├── docker-compose.local.test.yaml ├── k8s ├── app-cluster-ip-service.yaml ├── app-deployment.yaml ├── backend-cluster-ip-service.yaml ├── backend-deployment.yaml ├── ingress-service.yaml ├── pgadmin-cluster-ip-service.yaml ├── pgadmin-deployment.yaml ├── postgres-cluster-ip-service.yaml ├── postgres-deployment.yaml ├── shared-persistent-volume-claim.yaml ├── vectorstore-cluster-ip-service.yaml └── vectorstore-deployment.yaml ├── ragflow ├── .dockerignore ├── Dockerfile ├── __init__.py ├── api │ ├── __init__.py │ ├── database.py │ ├── main.py │ ├── models.py │ ├── routers │ │ ├── __init__.py │ │ ├── auth_router.py │ │ ├── chats_router.py │ │ ├── configs_router.py │ │ ├── evals_router.py │ │ ├── gens_router.py │ │ └── user_router.py │ ├── schemas.py │ └── services │ │ ├── __init__.py │ │ ├── auth_service.py │ │ ├── common_service.py │ │ └── user_service.py ├── commons │ ├── __init__.py │ ├── chroma │ │ ├── ChromaClient.py │ │ └── __init__.py │ ├── configurations │ │ ├── BaseConfigurations.py │ │ ├── Hyperparameters.py │ │ ├── QAConfigurations.py │ │ └── __init__.py │ ├── prompts │ │ ├── __init__.py │ │ ├── grade_answers_prompts.py │ │ ├── grade_retriever_prompts.py │ │ ├── qa_answer_prompts.py │ │ └── qa_geneneration_prompts.py │ └── vectorstore │ │ ├── __init__.py │ │ └── pgvector_utils.py ├── evaluation │ ├── __init__.py │ ├── hp_evaluator.py │ ├── metrics │ │ ├── __init__.py │ │ ├── answer_embedding_similarity.py │ │ ├── predicted_answer_accuracy.py │ │ ├── retriever_mrr_accuracy.py │ │ ├── retriever_semantic_accuracy.py │ │ └── rouge_score.py │ └── utils.py ├── example.py ├── generation │ ├── __init__.py │ └── label_dataset_generator.py ├── requirements.txt └── utils │ ├── __init__.py │ ├── doc_processing.py │ ├── hyperparam_chats.py │ └── utils.py ├── resources ├── dev │ ├── document_store │ │ ├── msg_life-gb-2021-EN_final_1-15.pdf │ │ ├── msg_life-gb-2021-EN_final_16-30.pdf │ │ ├── msg_life-gb-2021-EN_final_31-45.pdf │ │ ├── msg_life-gb-2021-EN_final_46-59.pdf │ │ └── msg_life-gb-2021-EN_final_60-end.pdf │ ├── hyperparameters.json │ ├── hyperparameters_results.json │ ├── hyperparameters_results_data.csv │ ├── label_dataset.json │ └── label_dataset_gen_params.json └── tests │ ├── document_store │ ├── churchill_speech.docx │ ├── churchill_speech.pdf │ └── churchill_speech.txt │ ├── input_hyperparameters.json │ ├── input_label_dataset.json │ ├── input_label_dataset_gen_params_with_upsert.json │ └── input_label_dataset_gen_params_without_upsert.json ├── tests ├── Dockerfile ├── __init__.py ├── conftest.py ├── requirements.txt ├── test_backend_configs.py ├── test_backend_generator.py ├── test_backend_hp_evaluator.py ├── utils.py └── wait-for-it.sh └── vectorstore ├── .dockerignore ├── Dockerfile ├── requirements.txt └── server.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | executors: 4 | python-docker-executor: 5 | docker: 6 | - image: cimg/python:3.11 7 | 8 | jobs: 9 | # first we build and push the new Docker images to Docker Hub 10 | 
push_app_image: 11 | executor: python-docker-executor 12 | steps: 13 | - checkout 14 | - set_env_vars 15 | - docker_login 16 | - setup_remote_docker: 17 | docker_layer_caching: false 18 | - run: 19 | name: Inject build info into frontend 20 | command: | 21 | sed -i "s/\$BUILD_NUMBER/${PIPELINE_NUMBER}/g" ./app/main.py 22 | sed -i "s/\$BUILD_DATE/${NOW}/g" ./app/main.py 23 | sed -i "s/\$GIT_SHA/${TAG}/g" ./app/main.py 24 | - run: 25 | name: Push app image 26 | command: | 27 | docker build -t $DOCKER_USER/$IMAGE_NAME_BASE-app:$TAG -t $DOCKER_USER/$IMAGE_NAME_BASE-app:latest ./app 28 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-app:$TAG 29 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-app:latest 30 | push_ragflow_image: 31 | executor: python-docker-executor 32 | steps: 33 | - checkout 34 | - set_env_vars 35 | - docker_login 36 | - setup_remote_docker: 37 | docker_layer_caching: false 38 | - run: 39 | name: Push ragflow image 40 | command: | 41 | docker build -t $DOCKER_USER/$IMAGE_NAME_BASE-backend:$TAG -t $DOCKER_USER/$IMAGE_NAME_BASE-backend:latest ./ragflow 42 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-backend:$TAG 43 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-backend:latest 44 | push_vectorstore_image: 45 | executor: python-docker-executor 46 | steps: 47 | - checkout 48 | - set_env_vars 49 | - docker_login 50 | - setup_remote_docker: 51 | docker_layer_caching: false 52 | - run: 53 | name: Push vectorstore image 54 | command: | 55 | docker build -t $DOCKER_USER/$IMAGE_NAME_BASE-vectorstore:$TAG -t $DOCKER_USER/$IMAGE_NAME_BASE-vectorstore:latest ./vectorstore 56 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-vectorstore:$TAG 57 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-vectorstore:latest 58 | push_test_image: 59 | executor: python-docker-executor 60 | steps: 61 | - checkout 62 | - set_env_vars 63 | - docker_login 64 | - setup_remote_docker: 65 | docker_layer_caching: false 66 | - run: 67 | name: Push test image 68 | command: | 69 | docker build -t $DOCKER_USER/$IMAGE_NAME_BASE-test:$TAG -t $DOCKER_USER/$IMAGE_NAME_BASE-test:latest ./tests 70 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-test:$TAG 71 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-test:latest 72 | 73 | # then we run tests on the new Docker images 74 | run_integration_tests: 75 | machine: true 76 | steps: 77 | - checkout 78 | - run: 79 | name: Clear Docker cache 80 | command: docker system prune --all --force --volumes 81 | - run: 82 | name: Run Docker Compose to build, start and test 83 | command: | 84 | docker-compose -f docker-compose.integration.test.yaml up --exit-code-from test-suite 85 | - run: 86 | name: Shut services down 87 | command: docker-compose -f docker-compose.integration.test.yaml down 88 | - store_artifacts: 89 | path: /tests/test-reports 90 | destination: test-reports 91 | 92 | # if everything was successful we deploy the Docker images into the k8s cluster 93 | deploy_to_gke_k8s_cluster: 94 | docker: 95 | - image: google/cloud-sdk 96 | steps: 97 | - checkout 98 | - set_env_vars 99 | - setup_remote_docker: 100 | docker_layer_caching: false 101 | - run: 102 | name: Setup Google Cloud SDK 103 | command: | 104 | echo "$GOOGLE_SERVICE_KEY" > ${HOME}/gcloud-service-key.json 105 | gcloud auth activate-service-account --key-file=${HOME}/gcloud-service-key.json 106 | gcloud config set project "$GOOGLE_PROJECT_ID" 107 | gcloud config set compute/zone "$GOOGLE_COMPUTE_ZONE" 108 | gcloud container clusters get-credentials "$GKE_CLUSTER_NAME" 109 | - run: 110 | name: Deploy to GKE k8s cluster 111 | command: | 112 | kubectl 
apply -f ./k8s 113 | kubectl set image deployments/app-deployment app-frontend=$DOCKER_USER/$IMAGE_NAME_BASE-app:$TAG 114 | kubectl set image deployments/backend-deployment ragflow-backend=$DOCKER_USER/$IMAGE_NAME_BASE-backend:$TAG 115 | kubectl set image deployments/vectorstore-deployment chromadb-vectorstore=$DOCKER_USER/$IMAGE_NAME_BASE-vectorstore:$TAG 116 | 117 | workflows: 118 | version: 2 119 | build-deploy: 120 | jobs: 121 | - push_app_image 122 | - push_ragflow_image 123 | - push_vectorstore_image 124 | - push_test_image 125 | - run_integration_tests: 126 | requires: 127 | - push_app_image 128 | - push_ragflow_image 129 | - push_vectorstore_image 130 | - push_test_image 131 | - deploy_to_gke_k8s_cluster: 132 | requires: 133 | - run_integration_tests 134 | 135 | commands: 136 | set_env_vars: 137 | steps: 138 | - run: 139 | name: Setup tag and base image name 140 | command: | 141 | echo 'export TAG=${CIRCLE_SHA1:0:8}' >> $BASH_ENV 142 | echo 'export IMAGE_NAME_BASE=ragflow' >> $BASH_ENV 143 | echo 'export NOW=$(date --utc +"%Y-%m-%d %H:%M:%S")' >> $BASH_ENV 144 | echo 'export PIPELINE_NUMBER=<< pipeline.number >>' >> $BASH_ENV 145 | docker_login: 146 | steps: 147 | - run: 148 | name: Log in to Docker Hub 149 | command: | 150 | echo "$DOCKER_PWD" | docker login --username "$DOCKER_USER" --password-stdin -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # env variables 2 | .env 3 | 4 | # local stuff 5 | .venv/ 6 | .vscode/ 7 | 8 | **.bin 9 | **/.ipynb_checkpoints 10 | **/__pycache__ 11 | 12 | # db 13 | postgres/ 14 | vectorstore/chromadb/ 15 | *.sqlite3 16 | 17 | # jupyter notebooks 18 | *.ipynb 19 | **/notebooks 20 | 21 | # k8s configs 22 | k8s/dashboard.yaml 23 | 24 | # misc 25 | #resources/* 26 | #!resources/msg_life-gb-2021-EN_final.pdf 27 | -------------------------------------------------------------------------------- /RAGflow overview.drawio: (draw.io XML for the RAGflow overview diagram) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RAGflow: Build optimized and robust LLM applications 2 | 3 | [![CircleCI](https://dl.circleci.com/status-badge/img/circleci/6FfqBzs4fBDyTPvBNqnq5x/8HU8omXUEUaEgrpWMj271K/tree/main.svg?style=shield&circle-token=545d0058e25f4566f54a9282ef976f6a8a77b327)](https://app.circleci.com/pipelines/circleci/6FfqBzs4fBDyTPvBNqnq5x) 4 | 5 | RAGflow provides tools for constructing and evaluating Retrieval Augmented Generation (RAG) systems, empowering developers to craft efficient question-answer applications leveraging LLMs. 
The stack consists of 6 | 7 | `Language` [Python](https://www.python.org/)\ 8 | `Frameworks for LLMs` [LangChain](https://www.langchain.com/) [OpenAI](https://www.openai.com/) [Hugging Face](https://huggingface.co/)\ 9 | `Framework for API` [FastAPI](https://fastapi.tiangolo.com/)\ 10 | `Databases` [ChromaDB](https://www.trychroma.com/) [Postgres](https://www.postgresql.org/)\ 11 | `Frontend` [Streamlit](https://www.streamlit.io/)\ 12 | `CI/CD` [Docker](https://www.docker.com/) [Kubernetes](https://kubernetes.io/) [CircleCI](https://circleci.com/) [GKE](https://cloud.google.com/kubernetes-engine) 13 | 14 | # 🚀 Getting Started 15 | 16 | - CircleCI pushes the Docker images after each successful build to 17 | - https://hub.docker.com/u/andreasx42 18 | - The Google Kubernetes Engine cluster is currently not available 19 | - Check out the repository 20 | - Start the application with `docker-compose up --build` 21 | - The application should be available on localhost:8501 22 | - The backend API documentation is available on localhost:8080/docs 23 | - Use Kubernetes with `kubectl apply -f k8s` to deploy locally 24 | - The application should then be available directly on localhost/ 25 | - For backend API access, nginx routes requests under localhost/api/\* 26 | - Make sure to check the deployment configs for the image versions 27 | 28 | # 📖 What is Retrieval Augmented Generation (RAG)? 29 | 30 | Description 31 |

Source

32 | 33 | In RAG, when a user query is received, relevant documents or passages are retrieved from a massive corpus, i.e. a document store. These retrieved documents are then provided as context to a generative model, which synthesizes a coherent response using both the input query and the retrieved information. This approach leverages the strengths of both retrieval-based and generative systems, aiming to produce accurate and well-formed responses by drawing from vast amounts of textual data. 34 | 35 | # 🚀 Workflow of RAGflow 36 | 37 | - Automatic Generation of Question-Answer Pairs\ 38 | Begin with RAGflow's capability to generate relevant question-answer pairs from the provided documents, which are then used as an evaluation dataset for benchmarking RAG systems. 39 | 40 | - Evaluate provided hyperparameters\ 41 | After generating Q&A pairs, dive into hyperparameter evaluation. Provide your hyperparameters, let RAGflow evaluate their efficacy, and obtain insights for crafting robust RAG systems (an example parameter file is sketched further below). 42 | This approach allows you to select efficient document splitting strategies, language models and embedding models, which can be further fine-tuned with respect to your document store. 43 | 44 | Here is a schematic overview: 45 | 46 | ![schematics](https://github.com/AndreasX42/RAGflow/assets/141482745/8ea78a21-8224-4baf-a441-dc4aa8249762) 47 | 48 | # 🌟 Key Features & Functionalities 49 | 50 | - `Document Store Integration` Provide documents in formats like PDF and DOCX as the knowledge base. 51 | - `Dynamic Parameter Selection` Customize parameters such as document splitting strategies, embedding models, and question-answering LLMs for evaluations. 52 | - `Automated Dataset Generation` Automatically generate question-answer pairs from the provided documents as an evaluation dataset for each parameterized RAG system. 53 | - `Evaluation & Optimization` Optimize performance using grid searches across parameter spaces. 54 | - `Advanced Retrieval Mechanisms` Employ techniques like MMR and the SelfQueryRetriever for optimal data extraction. 55 | - `Integration with External Platforms` Integrate with platforms like Anyscale, MosaicML, and Replicate for enhanced functionalities and state-of-the-art LLM models. 56 | - `Interactive Feedback Loop` Refine and improve your RAG system with interactive feedback based on real-world results. 57 | 58 | # 🛠️ Development 59 | 60 | Directory Structure 61 | 62 | - `/.circleci` CircleCI config for the CI/CD pipeline. 63 | - `/app` Frontend components and resources in Streamlit. 64 | - `/ragflow` Backend services and APIs. 65 | - `/tests` Test scripts and test data. 66 | - `/resources` Data storage. 67 | - `/vectorstore` ChromaDB component.
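As a quick reference, a single hyperparameter entry of the kind the 'Parameters' page collects and appends to the user's hyperparameters JSON file (a checked-in example lives under `resources/dev/hyperparameters.json`) might look roughly like the sketch below. The field names mirror the settings form in `app/page_parameters.py` and the numeric values are that form's defaults; the model names, search type and similarity method shown here are placeholders only, since the valid options are loaded at runtime via `get_valid_params()`:

```json
{
  "chunk_size": 512,
  "chunk_overlap": 10,
  "length_function_name": "len",
  "num_retrieved_docs": 3,
  "similarity_method": "cosine",
  "search_type": "mmr",
  "embedding_model": "text-embedding-ada-002",
  "qa_llm": "gpt-3.5-turbo",
  "use_llm_grader": false
}
```

If `use_llm_grader` is set to `true`, the form additionally collects `grade_answer_prompt`, `grade_docs_prompt` and `grader_llm` for the LLM-based grading metrics.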
68 | 69 | # 🌐 Links & Resources 70 | 71 | - TBA 72 | -------------------------------------------------------------------------------- /TODO: -------------------------------------------------------------------------------- 1 | For k8s deployment: 2 | - use $TAG as variable in deployment config and use sed to change to build number 3 | - sed -i 's/\$TAG/{{IMAGE_TAG}}/g' client-deployment.yaml 4 | or: 5 | - imperative command 6 | - kubectl set image deployment/$name client=andreasx42/ragflow-$img_name:$tag 7 | Secrets: 8 | - kubectl create secret generic --from-literal key=value 9 | Nginx installation 10 | - kubectl apply -f https://raw.githubusercontent.com/kubernetes/ingress-nginx/controller-v1.8.2/deploy/static/provider/cloud/deploy.yaml 11 | k8s dashboard: 12 | - create admin: kubectl -n kubernetes-dashboard create token admin-user 13 | 14 | 15 | HuggingFaceEmbeddings( 16 | model_name=model_name, 17 | model_kwargs={"device": "cuda"}, 18 | encode_kwargs={"device": "cuda", "batch_size": 100}) -------------------------------------------------------------------------------- /app/.dockerignore: -------------------------------------------------------------------------------- 1 | # git 2 | .git 3 | .gitignore 4 | .cache 5 | **.log 6 | 7 | # Byte-compiled / optimized / DLL files 8 | **/__pycache__ 9 | **/*.pyc 10 | **.py[cod] 11 | *$py.class 12 | .vscode/ 13 | 14 | # Jupyter Notebook 15 | **/.ipynb_checkpoints 16 | notebooks/ 17 | 18 | # dotenv 19 | .env 20 | 21 | # virtualenv 22 | .venv* 23 | venv*/ 24 | ENV*/ -------------------------------------------------------------------------------- /app/.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [server] 2 | maxUploadSize = 10 3 | # this is needed for local development with docker 4 | # if you don't want to start the default browser: 5 | headless = true 6 | # you will need this for local development: 7 | runOnSave = true 8 | # you will need this if running docker on windows host: 9 | fileWatcherType = "poll" 10 | 11 | [theme] 12 | primaryColor = "#E3735E" 13 | backgroundColor = "#121212" # Dark gray background 14 | secondaryBackgroundColor = "#1E1E1E" # Slightly lighter gray for secondary elements 15 | textColor = "#E0E0E0" # Light gray text for better contrast against the dark background 16 | font = "sans serif" -------------------------------------------------------------------------------- /app/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11 2 | 3 | # Port the app is running on 4 | EXPOSE 8501 5 | 6 | # Install dependencies 7 | WORKDIR /app 8 | 9 | COPY ./requirements.txt ./ 10 | 11 | RUN pip install --no-cache-dir --upgrade -r ./requirements.txt 12 | 13 | # Copy all into image 14 | COPY ./ ./ 15 | 16 | HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health 17 | 18 | CMD ["streamlit", "run", "main.py", "--server.port=8501", "--server.address=0.0.0.0"] -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/app/__init__.py -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from streamlit_option_menu import option_menu 3 | 4 | from page_parameters 
import page_parameters 5 | from page_home import page_home 6 | from page_documentstore import page_documentstore 7 | from page_dashboard import page_dashboard 8 | from page_login import page_login 9 | from page_filemanager import page_filemanager 10 | from page_apikeys import page_apikeys 11 | from page_chat import page_chat 12 | from utils import get_auth_user 13 | 14 | def main(): 15 | # Display selected page with the respective function 16 | def sideBar(): 17 | with st.sidebar: 18 | selected = option_menu( 19 | menu_title="Main Menu", 20 | options=[ 21 | "Home", 22 | "Dashboard", 23 | "Parameters", 24 | "Q&A Chats", 25 | "Documents", 26 | "File Manager", 27 | "API Keys", 28 | "Login", 29 | ], 30 | icons=[ 31 | "house-fill", 32 | "clipboard2-pulse-fill", 33 | "toggles", 34 | "chat-text-fill", 35 | "cloud-upload-fill", 36 | "file-earmark-text-fill", 37 | "safe-fill", 38 | "door-open-fill", 39 | ], 40 | menu_icon="cast", 41 | default_index=0, 42 | styles={ 43 | "nav-link": { 44 | "--hover-color": "#1E1E1E", 45 | }, 46 | "nav-link-selected": {"background-color": "#880808"}, 47 | }, 48 | ) 49 | 50 | if selected == "Home": 51 | page_home() 52 | if selected == "Dashboard": 53 | page_dashboard() 54 | if selected == "Parameters": 55 | page_parameters() 56 | if selected == "Q&A Chats": 57 | page_chat() 58 | if selected == "Documents": 59 | page_documentstore() 60 | if selected == "File Manager": 61 | page_filemanager() 62 | if selected == "API Keys": 63 | page_apikeys() 64 | if selected == "Login": 65 | page_login() 66 | 67 | sideBar() 68 | 69 | 70 | if __name__ == "__main__": 71 | st.set_page_config( 72 | page_title="RAGflow", 73 | page_icon="🧊", 74 | layout="wide", 75 | initial_sidebar_state="expanded", 76 | ) 77 | 78 | # set user_id attribute in session state after browser refresh if user already authenticated 79 | if "user_id" not in st.session_state or not st.session_state.user_id: 80 | user_data, success = get_auth_user() 81 | if success: 82 | st.session_state.user_id = str(user_data.get("id")) 83 | 84 | main() 85 | 86 | with st.sidebar: 87 | if "user_id" in st.session_state and st.session_state.user_id: 88 | auth_button = '
Authenticated
' 89 | 90 | else: 91 | auth_button = '
Not Authenticated
' 92 | 93 | st.markdown(auth_button, unsafe_allow_html=True) 94 | 95 | st.markdown("<br>" * 0, unsafe_allow_html=True) 96 | st.markdown("---") 97 | st.markdown( 98 | '
Made in  Streamlit logo  by @AndreasX42
', 99 | unsafe_allow_html=True, 100 | ) 101 | 102 | version_info = """ 103 |
104 | Build Number: $BUILD_NUMBER
105 | Git Commit SHA: $GIT_SHA
106 | Build Time (UTC): $BUILD_DATE
107 |
108 | """ 109 | 110 | st.markdown(version_info, unsafe_allow_html=True) 111 | -------------------------------------------------------------------------------- /app/page_apikeys.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from utils import display_user_login_warning 3 | 4 | 5 | def page_apikeys(): 6 | st.title("Provide API Keys") 7 | st.subheader("Enter the API Keys for all services you want to use.") 8 | 9 | if display_user_login_warning(): 10 | return 11 | 12 | if "api_keys" not in st.session_state: 13 | st.session_state.api_keys = {} 14 | 15 | # Create text area to input API keys 16 | input_text = st.text_area( 17 | "Enter the required API keys for your intended use. The keys will not get stored outside of Streamlits session state.\n\nProvide one 'name=key' pair per line, for example:", 18 | "OPENAI_API_KEY = your_openai_key\nANYSCALE_API_KEY = your_anyscale_key", 19 | ) 20 | 21 | # Create a submit button 22 | if st.button("Submit"): 23 | store_in_cache(input_text) 24 | st.success("API keys stored successfully!") 25 | 26 | # Display stored API keys for testing purposes (you might want to remove this in a real application) 27 | if st.session_state.api_keys: 28 | keys = "" 29 | for key, value in st.session_state.api_keys.items(): 30 | keys += f"{key}: {value}\n" 31 | 32 | st.markdown("
" * 1, unsafe_allow_html=True) 33 | st.code(keys) 34 | 35 | 36 | def store_in_cache(api_keys: str) -> None: 37 | """Writes api keys to streamlit session state.""" 38 | st.session_state.api_keys = {} 39 | 40 | for line in api_keys.splitlines(): 41 | if "=" in line: 42 | key, value = line.split("=", 1) 43 | st.session_state.api_keys[key.strip()] = value.strip() 44 | -------------------------------------------------------------------------------- /app/page_chat.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import streamlit as st 3 | from utils import * 4 | import os 5 | 6 | 7 | def page_chat(): 8 | # Title of the page 9 | st.subheader("Chat with a RAG model from a hyperparameter evaluation") 10 | 11 | if display_user_login_warning(): 12 | return 13 | 14 | if not os.path.exists(get_hyperparameters_results_path()): 15 | st.warning("No hyperparameter results available. Run some evaluation.") 16 | return 17 | 18 | if "messages" not in st.session_state: 19 | st.session_state.messages = [ 20 | {"role": "assistant", "content": "How may I help you?"} 21 | ] 22 | 23 | with st.expander("View hyperparameter results"): 24 | df = load_hp_results() 25 | showData = st.multiselect( 26 | "Filter: ", 27 | df.columns, 28 | default=[ 29 | "id", 30 | "chunk_size", 31 | "chunk_overlap", 32 | "length_function_name", 33 | "num_retrieved_docs", 34 | "search_type", 35 | "similarity_method", 36 | "use_llm_grader", 37 | ], 38 | ) 39 | st.dataframe(df[showData], use_container_width=True) 40 | 41 | # select chat model from hyperparam run id 42 | hp_id = 0 43 | hp_id = st.selectbox( 44 | "Select chat model from hyperparameter evaluations", options=list(df.id) 45 | ) 46 | # Check if hp_id changed and clear chat history if it did 47 | if ( 48 | "hp_id" in st.session_state 49 | and hp_id != st.session_state.hp_id 50 | or "hp_id" not in st.session_state 51 | ): 52 | st.session_state.messages = [ 53 | {"role": "assistant", "content": "How may I help you?"} 54 | ] 55 | # Update the session state with the new hp_id 56 | st.session_state.hp_id = hp_id 57 | 58 | st.markdown("
" * 1, unsafe_allow_html=True) 59 | st.write("Chat with the chosen parametrized RAG") 60 | 61 | for message in st.session_state.messages: 62 | with st.chat_message(message["role"]): 63 | st.write(message["content"]) 64 | 65 | # User-provided query 66 | if query := st.chat_input(): 67 | st.session_state.messages.append({"role": "user", "content": query}) 68 | with st.chat_message("user"): 69 | st.write(query) 70 | 71 | # Generate a new response if last message is not from assistant 72 | if st.session_state.messages[-1]["role"] != "assistant": 73 | with st.chat_message("assistant"): 74 | # get and display answer 75 | answer, source_docs = get_rag_response_stream(hp_id, query) 76 | 77 | # display retrieved documents 78 | if "I don't know".lower() not in answer.lower() and source_docs is not None: 79 | display_documents(source_docs) 80 | 81 | message = {"role": "assistant", "content": answer} 82 | st.session_state.messages.append(message) 83 | 84 | 85 | def display_documents(documents: list[dict]): 86 | for idx, doc in enumerate(documents["source_documents"]): 87 | with st.expander(f"Document {idx + 1}"): 88 | st.text(f"Source: {os.path.basename(doc['metadata']['source'])}") 89 | st.text( 90 | f"Index location: {doc['metadata']['start_index']}-{doc['metadata']['end_index']}" 91 | ) 92 | st.text_area("Content", value=doc["page_content"], height=150) 93 | 94 | 95 | def retrieve_source_documents(hp_id: int, query: str): 96 | documents = get_docs_from_query(hp_id, query) 97 | display_documents(documents) 98 | 99 | 100 | def load_hp_results() -> pd.DataFrame: 101 | with open(get_hyperparameters_results_path(), encoding="utf-8") as file: 102 | hp_data = json.load(file) 103 | 104 | df = pd.DataFrame(hp_data) 105 | df["id"] = df["id"].astype(int) 106 | 107 | # Flatten the "scores" sub-dictionary 108 | scores_df = pd.json_normalize(df["scores"]) 109 | 110 | # Combine the flattened scores DataFrame with the original DataFrame 111 | df = pd.concat( 112 | [df[["id"]], df.drop(columns=["id", "scores"], axis=1), scores_df], axis=1 113 | ) 114 | 115 | # Print the resulting DataFrame 116 | return df 117 | 118 | 119 | def display_rag_response_stream(hp_id: int, query: str): 120 | # Start a separate thread for fetching stream data 121 | if "rag_response_stream_data" not in st.session_state: 122 | st.session_state["rag_response_stream_data"] = "Starting stream...\n" 123 | 124 | get_rag_response_stream(hp_id, query) 125 | -------------------------------------------------------------------------------- /app/page_dashboard.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import pandas as pd 3 | import os 4 | import pandas as pd 5 | import json 6 | import plotly.express as px 7 | from utils import * 8 | from utils import display_user_login_warning 9 | 10 | 11 | def page_dashboard(): 12 | st.title("Dashboard Page") 13 | st.subheader("Analyse hyperparameter metrics and evaluation dataset.") 14 | 15 | if display_user_login_warning(): 16 | return 17 | 18 | else: 19 | tab1, tab2, tab3, tab4 = st.tabs( 20 | [ 21 | "Charts", 22 | "Evaluation dataset", 23 | "Generated predictions", 24 | "Combined dataset", 25 | ] 26 | ) 27 | 28 | with tab1: 29 | if os.path.exists(get_hyperparameters_results_path()): 30 | plot_hyperparameters_results(get_hyperparameters_results_path()) 31 | else: 32 | st.warning("No hyperparameter results available. 
Run some evaluation.") 33 | 34 | with tab2: 35 | if os.path.exists(get_label_dataset_path()): 36 | df_label_dataset = get_df_label_dataset() 37 | 38 | showData = st.multiselect( 39 | "Filter: ", 40 | df_label_dataset.columns, 41 | default=["question", "answer", "context", "source", "id"], 42 | ) 43 | st.dataframe(df_label_dataset[showData], use_container_width=True) 44 | else: 45 | st.warning("No evaluation data available. Generate it.") 46 | 47 | with tab3: 48 | if os.path.exists(get_hyperparameters_results_data_path()): 49 | df_hp_runs = pd.read_csv(get_hyperparameters_results_data_path()) 50 | 51 | showData = st.multiselect( 52 | "Filter: ", 53 | df_hp_runs.columns, 54 | default=[ 55 | "hp_id", 56 | "predicted_answer", 57 | "retrieved_docs", 58 | "qa_id", 59 | ], 60 | ) 61 | st.dataframe(df_hp_runs[showData], use_container_width=True) 62 | else: 63 | st.warning("No generated dataset from hyperparameter runs available.") 64 | 65 | with tab4: 66 | if os.path.exists(get_label_dataset_path()) and os.path.exists( 67 | get_hyperparameters_results_data_path() 68 | ): 69 | df_label_dataset4 = get_df_label_dataset() 70 | df_hp_runs4 = get_df_hp_runs() 71 | 72 | merged_df = df_label_dataset4.merge( 73 | df_hp_runs4, left_on="id", right_on="qa_id" 74 | ) 75 | merged_df.drop(columns="id", inplace=True) 76 | 77 | showData = st.multiselect( 78 | "Filter: ", 79 | merged_df.columns, 80 | default=[ 81 | "question", 82 | "answer", 83 | "predicted_answer", 84 | "retrieved_docs", 85 | "hp_id", 86 | "qa_id", 87 | "source", 88 | ], 89 | ) 90 | st.dataframe(merged_df[showData], use_container_width=True) 91 | else: 92 | st.warning("Not sufficient data available.") 93 | 94 | 95 | def get_df_label_dataset() -> pd.DataFrame: 96 | df_label_dataset = pd.read_json(get_label_dataset_path()) 97 | df_label_dataset = pd.concat( 98 | [df_label_dataset, pd.json_normalize(df_label_dataset["metadata"])], 99 | axis=1, 100 | ) 101 | df_label_dataset = df_label_dataset.drop(columns=["metadata"]) 102 | 103 | return df_label_dataset 104 | 105 | 106 | def get_df_hp_runs() -> pd.DataFrame(): 107 | return pd.read_csv(get_hyperparameters_results_data_path()) 108 | 109 | 110 | def plot_hyperparameters_results(hyperparameters_results_path: str): 111 | with open(hyperparameters_results_path, "r", encoding="utf-8") as file: 112 | hyperparameters_results = json.load(file) 113 | 114 | # Convert the list of dictionaries to a DataFrame 115 | df = pd.DataFrame(hyperparameters_results) 116 | 117 | # Extract scores and timestamps into separate DataFrames 118 | scores_df = df["scores"].apply(pd.Series) 119 | scores_df["timestamp"] = pd.to_datetime(df["timestamp"]) 120 | # Sort DataFrame based on timestamps for chronological order 121 | scores_df = scores_df.sort_values(by="timestamp") 122 | df = df.loc[scores_df.index] 123 | # Create a combined x-axis with incrementing number + timestamp string 124 | df["x_values"] = range(1, len(scores_df) + 1) 125 | df["x_ticks"] = [f"{i}" for i, ts in enumerate(scores_df["timestamp"])] 126 | df["timestamp"] = scores_df["timestamp"] 127 | # Melt the scores_df DataFrame for Plotly plotting 128 | df_melted = pd.melt( 129 | scores_df.reset_index(), 130 | id_vars=["index"], 131 | value_vars=[ 132 | "answer_similarity_score", 133 | "retriever_mrr@3", 134 | "retriever_mrr@5", 135 | "retriever_mrr@10", 136 | "rouge1", 137 | "rouge2", 138 | "rougeLCS", 139 | "correctness_score", 140 | "comprehensiveness_score", 141 | "readability_score", 142 | "retriever_semantic_accuracy", 143 | ], 144 | ) 145 | # Merge on 'index' 
to get the correct x_values and x_ticks 146 | df_melted = df_melted.merge( 147 | df[["x_values", "x_ticks", "timestamp"]], 148 | left_on="index", 149 | right_index=True, 150 | how="left", 151 | ) 152 | df_melted = df_melted[df_melted.value >= 0] 153 | 154 | # Plot using Plotly Express 155 | fig = px.scatter( 156 | df_melted, 157 | x="x_values", 158 | y="value", 159 | color="variable", 160 | hover_data=["x_ticks", "variable", "value", "timestamp"], 161 | # color_discrete_sequence=px.colors.sequential.Viridis, 162 | labels={"x_values": "Hyperparameter run id", "value": "Scores"}, 163 | title="Hyperparameter chart", 164 | ) 165 | 166 | fig.update_layout( 167 | # xaxis_tickangle=-45, 168 | xaxis=dict(tickvals=df["x_values"], ticktext=df["x_ticks"]), 169 | yaxis=dict(tickvals=[i / 10.0 for i in range(11)]), 170 | plot_bgcolor="#F5F5DC", 171 | paper_bgcolor="#121212", 172 | height=600, # Set the height of the plot 173 | width=950, # Set the width of the plot 174 | ) 175 | st.plotly_chart(fig) 176 | -------------------------------------------------------------------------------- /app/page_documentstore.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from utils import * 3 | from utils import display_user_login_warning 4 | 5 | 6 | def page_documentstore(): 7 | # Title of the page 8 | st.title("Document Store Page") 9 | st.subheader("Provide the documents that the application should use.") 10 | 11 | if display_user_login_warning(): 12 | return 13 | 14 | else: 15 | tab1, tab2 = st.tabs(["Upload Documents", "Provide cloud storage"]) 16 | 17 | with tab1: 18 | upload_files( 19 | context="docs", 20 | dropdown_msg="Upload your documents", 21 | ext_list=["pdf", "txt", "docx"], 22 | file_path=get_document_store_path(), 23 | allow_multiple_files=True, 24 | ) 25 | 26 | with tab2: 27 | st.write("Not implemented yet.") 28 | # Provide link to cloud resource 29 | cloud_link = st.text_input("Provide a link to your cloud resource:") 30 | 31 | if cloud_link: 32 | st.write(f"You provided the link: {cloud_link}") 33 | 34 | st.markdown("
" * 1, unsafe_allow_html=True) 35 | 36 | # List all files in the directory 37 | st.subheader("Your Document Store") 38 | path = get_document_store_path() 39 | if path: 40 | structure = ptree(path) 41 | st.code(structure, language="plaintext") 42 | -------------------------------------------------------------------------------- /app/page_filemanager.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import os 3 | import zipfile 4 | import io 5 | from utils import ( 6 | list_files_in_directory, 7 | get_user_directory, 8 | ptree, 9 | display_user_login_warning, 10 | ) 11 | 12 | import time 13 | 14 | 15 | def page_filemanager(): 16 | st.title("File Manager") 17 | st.subheader("Manager the files in your directory.") 18 | 19 | if display_user_login_warning(): 20 | return 21 | 22 | else: 23 | tab1, tab2 = st.tabs(["Download files", "Delete files"]) 24 | 25 | with tab1: 26 | files = list_files_in_directory(get_user_directory()) 27 | selected_files = st.multiselect("Select files to download:", files) 28 | 29 | if len(selected_files) > 0 and st.button("Select Files"): 30 | # Create a zip archive in-memory 31 | zip_buffer = io.BytesIO() 32 | with zipfile.ZipFile( 33 | zip_buffer, "a", zipfile.ZIP_DEFLATED, False 34 | ) as zip_file: 35 | for selected in selected_files: 36 | file_path = os.path.join(get_user_directory(), selected) 37 | zip_file.write(file_path, selected) 38 | 39 | zip_buffer.seek(0) 40 | st.download_button( 41 | label="Download Files", 42 | data=zip_buffer, 43 | file_name="files.zip", 44 | mime="application/zip", 45 | ) 46 | 47 | with tab2: 48 | files = list_files_in_directory(get_user_directory()) 49 | selected_files = st.multiselect("Select files to delete:", files) 50 | 51 | if len(selected_files) > 0 and st.button("Delete Files"): 52 | action_success = True 53 | for selected in selected_files: 54 | file_path = os.path.join(get_user_directory(), selected) 55 | try: 56 | os.remove(file_path) 57 | except Exception as e: 58 | st.error(f"Error deleting {selected}: {e}") 59 | action_success = False 60 | 61 | if action_success: 62 | st.success(f"Deleted selected files successfully!") 63 | 64 | st.markdown("
" * 1, unsafe_allow_html=True) 65 | 66 | # List all files in the directory 67 | st.subheader("User directory") 68 | path = get_user_directory() 69 | if path or st.session_state.refresh_trigger: 70 | structure = ptree(path) 71 | st.code(structure, language="plaintext") 72 | -------------------------------------------------------------------------------- /app/page_home.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | 4 | # Define pages and their functions 5 | def page_home(): 6 | st.title("Welcome to RAGflow!") 7 | 8 | tab1, tab2, tab3 = st.tabs(["The App", "Getting Started", "About"]) 9 | 10 | with tab1: 11 | # Introductory text 12 | st.write( 13 | """ 14 | RAGflow is an advanced application framework tailored to streamline the construction and evaluation processes for Retrieval Augmented Generation (RAG) systems in question-answer contexts. Here's a brief rundown of its functionality: 15 | """ 16 | ) 17 | 18 | # Functionality points 19 | st.header("Key Features & Functionalities") 20 | 21 | # 1. Document Store Integration 22 | st.subheader("1. Document Store Integration") 23 | st.write( 24 | """ 25 | RAGflow starts by interfacing with a variety of document stores, enabling users to select from a diverse range of data sources. 26 | """ 27 | ) 28 | 29 | # 2. Dynamic Parameter Selection 30 | st.subheader("2. Dynamic Parameter Selection") 31 | st.write( 32 | """ 33 | Through an intuitive interface, users can customize various parameters, including document splitting strategies, embedding model choices, and question-answering LLMs. These parameters influence how the app splits, encodes, and processes data. 34 | """ 35 | ) 36 | 37 | # 3. Automated Dataset Generation 38 | st.subheader("3. Automated Dataset Generation") 39 | st.write( 40 | """ 41 | One of RAGflow's core strengths is its capability to auto-generate datasets. This feature allows for a seamless transition from raw data to a structured format ready for processing. 42 | """ 43 | ) 44 | 45 | # 4. Advanced Retrieval Mechanisms 46 | st.subheader("4. Advanced Retrieval Mechanisms") 47 | st.write( 48 | """ 49 | Incorporating state-of-the-art retrieval methods, RAGflow employs techniques like "MMR" for filtering and the SelfQueryRetriever for targeted data extraction. This ensures that the most relevant document chunks are presented for question-answering tasks. 50 | """ 51 | ) 52 | 53 | # 5. Integration with External Platforms 54 | st.subheader("5. Integration with External Platforms") 55 | st.write( 56 | """ 57 | RAGflow is designed to work in tandem with platforms like Anyscale, MosaicML, and Replicate, offering users extended functionalities and integration possibilities. 58 | """ 59 | ) 60 | 61 | # 6. Evaluation & Optimization 62 | st.subheader("6. Evaluation & Optimization") 63 | st.write( 64 | """ 65 | At its core, RAGflow is built to optimize performance. By performing grid searches across the selected parameter space, it ensures that users achieve the best possible results for their specific configurations. 66 | """ 67 | ) 68 | 69 | # 7. Interactive Feedback Loop 70 | st.subheader("7. Interactive Feedback Loop") 71 | st.write( 72 | """ 73 | As users interact with the generated question-answer systems, RAGflow offers feedback mechanisms, allowing for continuous improvement and refinement based on real-world results. 
74 | """ 75 | ) 76 | 77 | # Conclusion 78 | st.write( 79 | """ 80 | Dive into RAGflow, set your parameters, and watch as it automates and optimizes the intricate process of building robust, data-driven question-answering systems! 81 | """ 82 | ) 83 | 84 | with tab2: 85 | # Functional stages 86 | st.header("Using This App: A Step-by-Step Guide") 87 | 88 | st.write( 89 | """ 90 | 1. Account Setup: 91 | Log in or register to get access to all services. 92 | 93 | 2. API Key Submission: 94 | Supply all the required API keys, such as those for OpenAI. The keys needed will vary based on the LLMs and services you intend to use. Keys will only get stored locally until page reloads. 95 | 96 | 3. Document Upload: 97 | Upload the documents for which you aim to create a RAG application under 'Documents'. 98 | 99 | 4. Parameter Configuration: 100 | Input the necessary parameters directly or by uploading a JSON file in 'Parameters'. This is essential for both the evaluation dataset generation in the 'QA Generator settings' tab and for the hyperparameter assessment in the 'Hyperparameters settings' tab. 101 | 102 | 5. Dashboard Analysis: 103 | Navigate to the 'Dashboard' to inspect hyperparameter metrics and view the generated data. 104 | 105 | 6. File Management: 106 | The 'File Manager' allows you to delete or download files within your directory for added flexibility. 107 | """ 108 | ) 109 | 110 | with tab3: 111 | st.subheader("About Me") 112 | st.write( 113 | """ 114 | Hello! I'm a passionate developer and AI enthusiast working on cutting-edge projects to leverage the power of machine learning. 115 | My latest project, RAGflow, focuses on building and evaluating Retrieval Augmented Generation (RAG) systems. 116 | Curious about my work? Check out my [GitHub repository](https://github.com/AndreasX42/RAGflow) for more details! 117 | """ 118 | ) 119 | -------------------------------------------------------------------------------- /app/page_login.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import time 3 | 4 | from utils import * 5 | 6 | 7 | def page_login(): 8 | # Session State to track if user is logged in 9 | if "logged_in" not in st.session_state: 10 | st.session_state.logged_in = False 11 | 12 | if "user_id" not in st.session_state: 13 | st.session_state.user_id = "" 14 | 15 | # Tabs for Login and Registration 16 | tab1, tab2 = st.tabs(["Login", "Register"]) 17 | 18 | # Login Tab 19 | with tab1: 20 | # Check if user is already logged in 21 | response, success = get_auth_user() 22 | if success: 23 | st.session_state.user_id = str(response.get("id")) 24 | st.success(f"Already signed in as '{response.get('username')}'.") 25 | 26 | elif not st.session_state.logged_in: 27 | with st.form("login_form"): 28 | st.subheader("Login") 29 | # Input fields for username and password 30 | login_username = st.text_input("Username", key="login_username") 31 | login_password = st.text_input( 32 | "Password", type="password", key="login_password" 33 | ) 34 | 35 | # Submit button for the form 36 | submit_button = st.form_submit_button("Login") 37 | 38 | if submit_button: 39 | # Attempt to login 40 | response, success = user_login(login_username, login_password) 41 | if success: 42 | st.session_state.logged_in = True 43 | st.session_state.user_id = str(response.get("id")) 44 | st.success( 45 | f"Logged in successfully as {response.get('username')}." 
46 | ) 47 | else: 48 | st.error( 49 | f"Login failed: {response.get('detail', 'Unknown error')}" 50 | ) 51 | 52 | if get_cookie_value() or st.session_state.user_id: 53 | with st.form("logout_form"): 54 | st.subheader("Logout") 55 | logout_button = st.form_submit_button("Logout") 56 | if logout_button: 57 | success = user_logout() 58 | 59 | if success: 60 | st.session_state.logged_in = False 61 | st.session_state.user_id = "" 62 | st.success("You are logged out!") 63 | 64 | else: 65 | st.error("Error logging out") 66 | 67 | time.sleep(1) 68 | st.rerun() 69 | 70 | # Registration Tab 71 | with tab2: 72 | with st.form("register_form"): 73 | st.subheader("Register") 74 | reg_username = st.text_input( 75 | "Username", 76 | key="reg_username", 77 | help="Username must be between 4 and 64 characters.", 78 | ) 79 | 80 | reg_email = st.text_input("Email", key="reg_email") 81 | 82 | reg_password = st.text_input( 83 | "Password", 84 | type="password", 85 | key="reg_password", 86 | help="Password must be between 8 and 128 characters.", 87 | ) 88 | 89 | if st.form_submit_button("Register"): 90 | reg_response, success = user_register( 91 | reg_username, reg_email, reg_password 92 | ) 93 | 94 | if success: 95 | st.success( 96 | f"Registered successfully! Please log in, {reg_response.get('username')}." 97 | ) 98 | 99 | else: 100 | st.error(f"Registration failed! Info: {reg_response['detail']}") 101 | -------------------------------------------------------------------------------- /app/page_parameters.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import streamlit as st 4 | from utils import * 5 | 6 | from streamlit import chat_message 7 | 8 | 9 | def page_parameters(): 10 | st.title("Parameters Page") 11 | st.subheader("Provide parameters for building and evaluating the system.") 12 | 13 | if display_user_login_warning(): 14 | return 15 | 16 | tab1, tab2 = st.tabs(["QA Generator settings", "Hyperparameters settings"]) 17 | 18 | valid_data = get_valid_params() 19 | 20 | with tab1: 21 | st.write( 22 | "The QA generator provides the possibility to generate question-context-answer triples from the provided documents that are used to evaluate hyperparameters and the corresponding RAG model in a consecutive step. You can either provide parameters for the generator through the drop-down menu below or by uploading a JSON file." 23 | ) 24 | 25 | provide_qa_gen_form(valid_data) 26 | 27 | upload_files( 28 | context="qa_params", 29 | dropdown_msg="Upload JSON file", 30 | ext_list=["json"], 31 | file_path=get_label_dataset_gen_params_path(), 32 | ) 33 | 34 | st.markdown("
" * 1, unsafe_allow_html=True) 35 | 36 | submit_button = st.button("Start evaluation dataset generation", key="SubmitQA") 37 | 38 | if submit_button: 39 | with st.spinner("Running generator..."): 40 | result = start_qa_gen() 41 | 42 | if "Success" in result: 43 | st.success(result) 44 | else: 45 | st.error(result) 46 | 47 | with tab2: 48 | st.write( 49 | "The Hyperparameter evaluator provides functionality of benchmarking RAG models with the corresponding parameters. During evaluation a LLM predicts an answer with the provided query and retrieved document chunks. With that we can calculate embedding similarities of label and predicted answers and ROUGE scores to provoide some metrics. We can also provide a LLM that is used for grading the predicted answers and the retrieved documents to extract even more metrics." 50 | ) 51 | 52 | provide_hp_params_form(valid_data) 53 | 54 | upload_files( 55 | context="hp_params", 56 | dropdown_msg="Upload JSON file", 57 | ext_list=["json"], 58 | file_path=get_hyperparameters_path(), 59 | ) 60 | 61 | st.markdown("
" * 1, unsafe_allow_html=True) 62 | 63 | submit_button = st.button("Start hyperparameter evaluation", key="SubmitHP") 64 | 65 | if submit_button: 66 | with st.spinner("Running hyperparameter evaluation..."): 67 | result = start_hp_run() 68 | 69 | if "Success" in result: 70 | st.success(result) 71 | else: 72 | st.error(result) 73 | 74 | 75 | def provide_hp_params_form(valid_data: dict): 76 | with st.expander("Hyperparameters settings"): 77 | attributes2 = { 78 | "chunk_size": st.number_input("Chunk Size", value=512, key="chunk size"), 79 | "chunk_overlap": st.number_input( 80 | "Chunk Overlap", value=10, key="chunk overlap" 81 | ), 82 | "length_function_name": st.selectbox( 83 | "Length Function for splitting or embedding model for corresponding tokenizer", 84 | valid_data["embedding_models"] + ["len"], 85 | key="len_tab2", 86 | ), 87 | "num_retrieved_docs": st.number_input( 88 | "Number of Docs the retriever should return", 89 | value=3, 90 | key="number of docs to retrieve", 91 | ), 92 | "similarity_method": st.selectbox( 93 | "Retriever Similarity Method", valid_data["retr_sim_method"] 94 | ), 95 | "search_type": st.selectbox( 96 | "Retriever Search Type", valid_data["retr_search_types"] 97 | ), 98 | "embedding_model": st.selectbox( 99 | "Embedding Model", 100 | valid_data["embedding_models"], 101 | key="Name of embedding model.", 102 | ), 103 | "qa_llm": st.selectbox( 104 | "QA Language Model", 105 | valid_data["llm_models"], 106 | key="Name of LLM for QA task.", 107 | ), 108 | } 109 | 110 | use_llm_grader = st.checkbox( 111 | "Use LLM Grader", value=False, key="use_llm_grader_checkbox" 112 | ) 113 | 114 | if use_llm_grader: 115 | attributes2["grade_answer_prompt"] = st.selectbox( 116 | "Grade Answer Prompt", 117 | valid_data["grade_answer_prompts"], 118 | key="Type of prompt to use for answer grading.", 119 | ) 120 | attributes2["grade_docs_prompt"] = st.selectbox( 121 | "Grade Documents Prompt", 122 | valid_data["grade_documents_prompts"], 123 | key="Type of prompt to grade retrieved document chunks.", 124 | ) 125 | attributes2["grader_llm"] = st.selectbox( 126 | "Grading LLM", 127 | valid_data["llm_models"], 128 | key="Name of LLM for grading.", 129 | ) 130 | 131 | submit_button = st.button("Submit", key="Submit2") 132 | if submit_button: 133 | attributes2["use_llm_grader"] = use_llm_grader 134 | st.write("Saved to file. You've entered the following values:") 135 | # This is just a mockup to show you how to display the attributes. 136 | # You'll probably want to process or display them differently. 
137 | st.write(attributes2) 138 | write_json( 139 | attributes2, 140 | get_hyperparameters_path(), 141 | append=True, 142 | ) 143 | 144 | 145 | def provide_qa_gen_form(valid_data: dict): 146 | with st.expander("Provide QA Generator settings form"): 147 | attributes = { 148 | "chunk_size": st.number_input("Chunk Size", value=512), 149 | "chunk_overlap": st.number_input("Chunk Overlap", value=10), 150 | "length_function_name": st.selectbox( 151 | "Length Function for splitting or embedding model for corresponding tokenizer", 152 | valid_data["embedding_models"] + ["len"], 153 | ), 154 | "qa_generator_llm": st.selectbox( 155 | "QA Language Model", valid_data["llm_models"] 156 | ), 157 | } 158 | 159 | persist_to_vs = st.checkbox( 160 | "Cache answer embeddings in ChromaDB for each of the models listed below", 161 | value=False, 162 | key="persist_to_vs", 163 | ) 164 | 165 | # Get input from the user 166 | if persist_to_vs: 167 | input_text = st.text_area( 168 | "Enter a list of embedding model names for caching the generated answer embeddings in chromadb (one per line):" 169 | ) 170 | 171 | submit_button = st.button("Submit", key="submit1") 172 | if submit_button: 173 | # save bool in dict 174 | attributes["persist_to_vs"] = persist_to_vs 175 | 176 | if persist_to_vs: 177 | # Store the list in a dictionary 178 | # Split the text by newline to get a list 179 | text_list = input_text.split("\n") 180 | attributes["embedding_model_list"] = text_list 181 | else: 182 | attributes["embedding_model_list"] = [] 183 | 184 | with st.spinner("Saving to file."): 185 | st.write("Saved to file. You've entered the following values:") 186 | # This is just a mockup to show you how to display the attributes. 187 | # You'll probably want to process or display them differently. 
188 | st.write(attributes) 189 | # save json 190 | write_json( 191 | attributes, 192 | get_label_dataset_gen_params_path(), 193 | ) 194 | -------------------------------------------------------------------------------- /app/requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit==1.28.1 2 | streamlit_option_menu==0.3.6 3 | 4 | plotly==5.18.0 5 | pandas==2.1.2 6 | numerize==0.12 7 | matplotlib==3.8.1 8 | seaborn==0.13.0 9 | 10 | openpyxl==3.1.2 -------------------------------------------------------------------------------- /docker-compose.dev.yaml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | services: 3 | app: 4 | build: 5 | dockerfile: Dockerfile 6 | context: ./app 7 | container_name: app-frontend 8 | ports: 9 | - "8501:8501" 10 | volumes: 11 | - ./app:/app 12 | - ./tmp:/app/tmp 13 | environment: 14 | - RAGFLOW_HOST=ragflow-backend 15 | - RAGFLOW_PORT=8080 16 | restart: always 17 | ragflow-backend: 18 | build: 19 | dockerfile: Dockerfile 20 | context: ./ragflow 21 | container_name: ragflow-backend 22 | ports: 23 | - "8080:8080" 24 | volumes: 25 | - ./ragflow:/backend/ragflow 26 | - ./tmp:/backend/tmp 27 | environment: 28 | - JWT_SECRET_KEY=brjia5mOUlE3RN0CFy 29 | - POSTGRES_DATABASE=postgres_db 30 | - PGVECTOR_DATABASE=vector_db 31 | - POSTGRES_USER=admin 32 | - POSTGRES_PASSWORD=my_password 33 | - POSTGRES_HOST=postgres 34 | - POSTGRES_PORT=5432 35 | - POSTGRES_DRIVER=psycopg2 36 | - EXECUTION_CONTEXT=DEV 37 | - LOG_LEVEL=INFO 38 | - CHROMADB_HOST=chromadb 39 | - CHROMADB_PORT=8000 40 | restart: always 41 | chromadb: 42 | build: 43 | dockerfile: Dockerfile 44 | context: ./vectorstore 45 | container_name: chromadb 46 | ports: 47 | - "8000:8000" 48 | volumes: 49 | - ./vectorstore/chroma:/chroma/chroma 50 | - ./vectorstore:/vectorstore 51 | environment: 52 | - IS_PERSISTENT=TRUE 53 | - ALLOW_RESET=TRUE 54 | restart: always 55 | postgres: 56 | image: ankane/pgvector:v0.5.1 57 | container_name: postgres 58 | ports: 59 | - "5432:5432" 60 | volumes: 61 | - ./postgres/data:/var/lib/postgresql/data 62 | environment: 63 | - POSTGRES_USER=admin 64 | - POSTGRES_PASSWORD=my_password 65 | restart: always 66 | pgadmin: 67 | image: dpage/pgadmin4:8.0 68 | container_name: pgadmin 69 | environment: 70 | - PGADMIN_DEFAULT_EMAIL=admin@admin.com 71 | - PGADMIN_DEFAULT_PASSWORD=my_password 72 | ports: 73 | - "5050:80" 74 | restart: always 75 | -------------------------------------------------------------------------------- /docker-compose.integration.test.yaml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | services: 3 | ragflow-test: 4 | image: andreasx42/ragflow-backend:latest 5 | container_name: ragflow-test 6 | ports: 7 | - "8080:8080" 8 | environment: 9 | - JWT_SECRET_KEY=brjia5mOUlE3RN0CFy 10 | - POSTGRES_DATABASE=postgres 11 | - PGVECTOR_DATABASE=postgres 12 | - POSTGRES_USER=admin 13 | - POSTGRES_PASSWORD=my_password 14 | - POSTGRES_HOST=postgres-test 15 | - POSTGRES_PORT=5432 16 | - POSTGRES_DRIVER=psycopg2 17 | - EXECUTION_CONTEXT=TEST 18 | - LOG_LEVEL=INFO 19 | - CHROMADB_HOST=chromadb-test 20 | - CHROMADB_PORT=8000 21 | - INPUT_LABEL_DATASET=./resources/input_label_dataset.json 22 | volumes: 23 | - ./resources/tests:/backend/resources 24 | chromadb-test: 25 | image: andreasx42/ragflow-vectorstore:latest 26 | container_name: chromadb-test 27 | ports: 28 | - 8000:8000 29 | environment: 30 | - IS_PERSISTENT=TRUE 31 | - 
ALLOW_RESET=TRUE 32 | test-suite: 33 | image: andreasx42/ragflow-test:latest 34 | container_name: tester 35 | environment: 36 | - RAGFLOW_HOST=ragflow-test 37 | - RAGFLOW_PORT=8080 38 | - CHROMADB_HOST=chromadb-test 39 | - CHROMADB_PORT=8000 40 | depends_on: 41 | - ragflow-test 42 | - chromadb-test 43 | volumes: 44 | - ./resources/tests:/tests/resources 45 | postgres-test: 46 | image: ankane/pgvector:v0.5.1 47 | container_name: postgres-test 48 | ports: 49 | - "5432:5432" 50 | environment: 51 | - POSTGRES_USER=admin 52 | - POSTGRES_PASSWORD=my_password 53 | restart: always 54 | -------------------------------------------------------------------------------- /docker-compose.local.test.yaml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | services: 3 | ragflow-test: 4 | build: 5 | dockerfile: Dockerfile 6 | context: ./ragflow 7 | container_name: ragflow-test 8 | ports: 9 | - "8080:8080" 10 | environment: 11 | - JWT_SECRET_KEY=brjia5mOUlE3RN0CFy 12 | - POSTGRES_DATABASE=postgres 13 | - PGVECTOR_DATABASE=postgres 14 | - POSTGRES_USER=admin 15 | - POSTGRES_PASSWORD=my_password 16 | - POSTGRES_HOST=postgres-test 17 | - POSTGRES_PORT=5432 18 | - POSTGRES_DRIVER=psycopg2 19 | - EXECUTION_CONTEXT=TEST 20 | - LOG_LEVEL=INFO 21 | - CHROMADB_HOST=chromadb-test 22 | - CHROMADB_PORT=8000 23 | - INPUT_LABEL_DATASET=./resources/input_label_dataset.json 24 | volumes: 25 | - ./ragflow:/backend/ragflow 26 | - ./resources/tests:/backend/resources 27 | chromadb-test: 28 | build: 29 | dockerfile: Dockerfile 30 | context: ./vectorstore 31 | container_name: chromadb-test 32 | ports: 33 | - 8000:8000 34 | environment: 35 | - IS_PERSISTENT=TRUE 36 | - ALLOW_RESET=TRUE 37 | test-suite: 38 | build: 39 | dockerfile: Dockerfile 40 | context: ./tests 41 | container_name: tester 42 | environment: 43 | - RAGFLOW_HOST=ragflow-test 44 | - RAGFLOW_PORT=8080 45 | - CHROMADB_HOST=chromadb-test 46 | - CHROMADB_PORT=8000 47 | depends_on: 48 | - ragflow-test 49 | - chromadb-test 50 | volumes: 51 | - ./resources/tests:/tests/resources 52 | - ./tests:/tests 53 | postgres-test: 54 | image: ankane/pgvector:v0.5.1 55 | container_name: postgres-test 56 | ports: 57 | - "5432:5432" 58 | environment: 59 | - POSTGRES_USER=admin 60 | - POSTGRES_PASSWORD=my_password 61 | restart: always 62 | -------------------------------------------------------------------------------- /k8s/app-cluster-ip-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: app-cluster-ip-service 5 | spec: 6 | type: ClusterIP 7 | selector: 8 | component: frontend 9 | ports: 10 | - port: 8501 11 | targetPort: 8501 12 | -------------------------------------------------------------------------------- /k8s/app-deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: app-deployment 5 | spec: 6 | replicas: 1 7 | selector: 8 | matchLabels: 9 | component: frontend 10 | template: 11 | metadata: 12 | labels: 13 | component: frontend 14 | spec: 15 | volumes: 16 | - name: frontend-storage 17 | persistentVolumeClaim: 18 | claimName: shared-persistent-volume-claim 19 | containers: 20 | - name: app-frontend 21 | image: andreasx42/ragflow-app:latest 22 | env: 23 | - name: RAGFLOW_HOST 24 | value: backend-cluster-ip-service 25 | - name: RAGFLOW_PORT 26 | value: '8080' 27 | ports: 28 | - containerPort: 8501 29 | volumeMounts: 30 
| - name: frontend-storage 31 | mountPath: app/tmp 32 | subPath: tmp -------------------------------------------------------------------------------- /k8s/backend-cluster-ip-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: backend-cluster-ip-service 5 | spec: 6 | type: ClusterIP 7 | selector: 8 | component: backend 9 | ports: 10 | - port: 8080 11 | targetPort: 8080 12 | -------------------------------------------------------------------------------- /k8s/backend-deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: backend-deployment 5 | spec: 6 | replicas: 1 7 | selector: 8 | matchLabels: 9 | component: backend 10 | template: 11 | metadata: 12 | labels: 13 | component: backend 14 | spec: 15 | volumes: 16 | - name: backend-storage 17 | persistentVolumeClaim: 18 | claimName: shared-persistent-volume-claim 19 | containers: 20 | - name: ragflow-backend 21 | image: andreasx42/ragflow-backend:latest 22 | env: 23 | - name: CHROMADB_HOST 24 | value: vectorstore-cluster-ip-service 25 | - name: CHROMADB_PORT 26 | value: '8000' 27 | - name: POSTGRES_HOST 28 | value: postgres-cluster-ip-service 29 | - name: POSTGRES_PORT 30 | value: '5432' 31 | - name: POSTGRES_DRIVER 32 | value: psycopg2 33 | - name: POSTGRES_PASSWORD 34 | valueFrom: 35 | secretKeyRef: 36 | name: pgsecrets 37 | key: POSTGRES_PASSWORD 38 | - name: POSTGRES_USER 39 | valueFrom: 40 | secretKeyRef: 41 | name: pgsecrets 42 | key: POSTGRES_USER 43 | - name: POSTGRES_DATABASE 44 | valueFrom: 45 | secretKeyRef: 46 | name: pgsecrets 47 | key: POSTGRES_DATABASE 48 | - name: PGVECTOR_DATABASE 49 | valueFrom: 50 | secretKeyRef: 51 | name: pgsecrets 52 | key: PGVECTOR_DATABASE 53 | - name: JWT_SECRET_KEY 54 | valueFrom: 55 | secretKeyRef: 56 | name: pgsecrets 57 | key: JWT_SECRET_KEY 58 | ports: 59 | - containerPort: 8080 60 | volumeMounts: 61 | - name: backend-storage 62 | mountPath: backend/tmp 63 | subPath: tmp -------------------------------------------------------------------------------- /k8s/ingress-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: networking.k8s.io/v1 2 | kind: Ingress 3 | metadata: 4 | name: ingress-service 5 | annotations: 6 | nginx.ingress.kubernetes.io/use-regex: 'true' 7 | nginx.ingress.kubernetes.io/ssl-redirect: 'false' 8 | nginx.ingress.kubernetes.io/configuration-snippet: | 9 | proxy_set_header X-Script-Name /pgadmin; 10 | nginx.ingress.kubernetes.io/rewrite-target: /$1 11 | spec: 12 | ingressClassName: nginx 13 | rules: 14 | - http: 15 | paths: 16 | - path: /api/?(.*) 17 | pathType: ImplementationSpecific 18 | backend: 19 | service: 20 | name: backend-cluster-ip-service 21 | port: 22 | number: 8080 23 | - path: /pgadmin/?(.*) 24 | pathType: ImplementationSpecific 25 | backend: 26 | service: 27 | name: pgadmin-cluster-ip-service 28 | port: 29 | number: 5050 30 | - path: /?(.*) 31 | pathType: ImplementationSpecific 32 | backend: 33 | service: 34 | name: app-cluster-ip-service 35 | port: 36 | number: 8501 -------------------------------------------------------------------------------- /k8s/pgadmin-cluster-ip-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: pgadmin-cluster-ip-service 5 | spec: 6 | type: ClusterIP 7 | selector: 8 | 
component: pgadmin 9 | ports: 10 | - port: 5050 11 | targetPort: 80 12 | -------------------------------------------------------------------------------- /k8s/pgadmin-deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: pgadmin-deployment 5 | spec: 6 | replicas: 1 7 | selector: 8 | matchLabels: 9 | component: pgadmin 10 | template: 11 | metadata: 12 | labels: 13 | component: pgadmin 14 | spec: 15 | containers: 16 | - name: pgadmin 17 | image: dpage/pgadmin4:8.0 18 | ports: 19 | - containerPort: 80 20 | env: 21 | - name: PGADMIN_DEFAULT_EMAIL 22 | valueFrom: 23 | secretKeyRef: 24 | name: pgsecrets 25 | key: PGADMIN_DEFAULT_EMAIL 26 | - name: PGADMIN_DEFAULT_PASSWORD 27 | valueFrom: 28 | secretKeyRef: 29 | name: pgsecrets 30 | key: POSTGRES_PASSWORD 31 | -------------------------------------------------------------------------------- /k8s/postgres-cluster-ip-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: postgres-cluster-ip-service 5 | spec: 6 | type: ClusterIP 7 | selector: 8 | component: postgres 9 | ports: 10 | - port: 5432 11 | targetPort: 5432 12 | -------------------------------------------------------------------------------- /k8s/postgres-deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: postgres-deployment 5 | spec: 6 | replicas: 1 7 | selector: 8 | matchLabels: 9 | component: postgres 10 | template: 11 | metadata: 12 | labels: 13 | component: postgres 14 | spec: 15 | volumes: 16 | - name: postgres-storage 17 | persistentVolumeClaim: 18 | claimName: shared-persistent-volume-claim 19 | containers: 20 | - name: postgres 21 | image: ankane/pgvector:v0.5.1 22 | ports: 23 | - containerPort: 5432 24 | volumeMounts: 25 | - name: postgres-storage 26 | mountPath: /var/lib/postgresql/data 27 | subPath: dbs/postgres 28 | env: 29 | - name: POSTGRES_PASSWORD 30 | valueFrom: 31 | secretKeyRef: 32 | name: pgsecrets 33 | key: POSTGRES_PASSWORD 34 | - name: POSTGRES_USER 35 | valueFrom: 36 | secretKeyRef: 37 | name: pgsecrets 38 | key: POSTGRES_USER 39 | -------------------------------------------------------------------------------- /k8s/shared-persistent-volume-claim.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolumeClaim 3 | metadata: 4 | name: shared-persistent-volume-claim 5 | spec: 6 | accessModes: 7 | - ReadWriteOnce 8 | resources: 9 | requests: 10 | storage: 2Gi -------------------------------------------------------------------------------- /k8s/vectorstore-cluster-ip-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: vectorstore-cluster-ip-service 5 | spec: 6 | type: ClusterIP 7 | selector: 8 | component: vectorstore 9 | ports: 10 | - port: 8000 11 | targetPort: 8000 12 | -------------------------------------------------------------------------------- /k8s/vectorstore-deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: vectorstore-deployment 5 | spec: 6 | replicas: 1 7 | selector: 8 | matchLabels: 9 | component: vectorstore 10 | template: 11 | metadata: 12 | 
labels: 13 | component: vectorstore 14 | spec: 15 | volumes: 16 | - name: vectorstore-storage 17 | persistentVolumeClaim: 18 | claimName: shared-persistent-volume-claim 19 | containers: 20 | - name: chromadb-vectorstore 21 | image: andreasx42/ragflow-vectorstore:latest 22 | env: 23 | - name: IS_PERSISTENT 24 | value: "TRUE" 25 | - name: ALLOW_RESET 26 | value: "TRUE" 27 | ports: 28 | - containerPort: 8000 29 | volumeMounts: 30 | - name: vectorstore-storage 31 | mountPath: chroma/chroma 32 | subPath: dbs/chromadb 33 | -------------------------------------------------------------------------------- /ragflow/.dockerignore: -------------------------------------------------------------------------------- 1 | # git 2 | .git 3 | .gitignore 4 | .cache 5 | **.log 6 | 7 | # data/chromadb 8 | data/datasets.zip 9 | data/desc.docx 10 | chromadb/ 11 | 12 | # Byte-compiled / optimized / DLL files 13 | **/__pycache__ 14 | **/*.pyc 15 | **.py[cod] 16 | *$py.class 17 | .vscode/ 18 | 19 | # Jupyter Notebook 20 | **/.ipynb_checkpoints 21 | notebooks/ 22 | 23 | # dotenv 24 | .env 25 | 26 | # virtualenv 27 | .venv* 28 | venv*/ 29 | ENV*/ -------------------------------------------------------------------------------- /ragflow/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11 2 | 3 | # Port the app is running on 4 | EXPOSE 8080 5 | 6 | # Install dependencies 7 | WORKDIR /backend 8 | 9 | COPY ./requirements.txt ./ 10 | RUN pip install --no-cache-dir --upgrade -r ./requirements.txt 11 | 12 | # Copy all into image 13 | COPY ./ ./ragflow 14 | 15 | ENV PYTHONPATH "${PYTHONPATH}:/backend" 16 | 17 | CMD ["uvicorn", "ragflow.api:app", "--host", "0.0.0.0", "--port", "8080", "--reload"] -------------------------------------------------------------------------------- /ragflow/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | logging.basicConfig( 5 | level=os.environ.get("LOG_LEVEL", "INFO"), 6 | format="%(asctime)s - %(levelname)s - %(name)s:%(filename)s:%(lineno)d - %(message)s", 7 | ) 8 | -------------------------------------------------------------------------------- /ragflow/api/__init__.py: -------------------------------------------------------------------------------- 1 | from ragflow.api.main import app 2 | -------------------------------------------------------------------------------- /ragflow/api/database.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import create_engine 2 | from sqlalchemy.orm import declarative_base, sessionmaker 3 | 4 | import os 5 | 6 | DRIVER = os.environ.get("POSTGRES_DRIVER") 7 | HOST = os.environ.get("POSTGRES_HOST", "localhost") 8 | PORT = os.environ.get("POSTGRES_PORT") 9 | DB = os.environ.get("POSTGRES_DATABASE") 10 | USER = os.environ.get("POSTGRES_USER") 11 | PWD = os.environ.get("POSTGRES_PASSWORD") 12 | 13 | 14 | if HOST == "localhost": 15 | HOST += f":{PORT}" 16 | 17 | DATABASE_URL = f"postgresql+{DRIVER}://{USER}:{PWD}@{HOST}/{DB}" 18 | 19 | engine = create_engine(DATABASE_URL) 20 | 21 | Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) 22 | 23 | Base = declarative_base() 24 | -------------------------------------------------------------------------------- /ragflow/api/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from fastapi import FastAPI 3 | 4 | from ragflow.api.routers import ( 5 | 
configs_router, 6 | evals_router, 7 | gens_router, 8 | chats_router, 9 | auth_router, 10 | user_router, 11 | ) 12 | 13 | from ragflow.api.database import Base, engine 14 | 15 | app = FastAPI( 16 | title="FastAPI RAGflow Documentation", 17 | version="0.01", 18 | description="""API for the main component of RAGflow to generate evaluation datasets of QA pairs and to run hyperparameter evaluations. 19 | """, 20 | # root path should only be "/api" in production 21 | root_path="" if os.environ.get("EXECUTION_CONTEXT") in ["DEV", "TEST"] else "/api", 22 | ) 23 | 24 | app.include_router(configs_router.router) 25 | app.include_router(gens_router.router) 26 | app.include_router(evals_router.router) 27 | app.include_router(chats_router.router) 28 | app.include_router(auth_router.router) 29 | app.include_router(user_router.router) 30 | 31 | Base.metadata.create_all(bind=engine) 32 | -------------------------------------------------------------------------------- /ragflow/api/models.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | from sqlalchemy import Column, Integer, String, DateTime, Boolean 3 | 4 | from ragflow.api.database import Base 5 | 6 | 7 | class User(Base): 8 | __tablename__ = "users" 9 | id = Column(Integer, primary_key=True, index=True) 10 | username = Column(String, index=True, unique=True) 11 | hashed_password = Column(String, index=False, unique=False) 12 | email = Column(String, index=True, unique=True) 13 | date_created = Column(DateTime, default=lambda: dt.datetime.now(dt.UTC)) 14 | is_active = Column(Boolean, default=True) 15 | role = Column(String, default="user", index=True) 16 | -------------------------------------------------------------------------------- /ragflow/api/routers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/ragflow/api/routers/__init__.py -------------------------------------------------------------------------------- /ragflow/api/routers/auth_router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends, HTTPException, Response, Request 2 | from starlette import status 3 | import sqlalchemy.orm as orm 4 | from fastapi.security import OAuth2PasswordRequestForm 5 | from datetime import timedelta 6 | 7 | from ragflow.api.schemas import UserFromDB 8 | from ragflow.api.services import get_db 9 | from ragflow.api.services.auth_service import ( 10 | authenticate_user, 11 | create_access_token, 12 | get_current_active_user, 13 | ) 14 | 15 | 16 | router = APIRouter( 17 | prefix="/auth", 18 | tags=["Authentication endpoints"], 19 | ) 20 | 21 | ACCESS_TOKEN_EXPIRATION_IN_MINUTES = 60 * 24 * 7  # 7 days 22 | 23 | 24 | @router.get("/user", response_model=UserFromDB, status_code=status.HTTP_200_OK) 25 | async def get_authenticated_user(user: UserFromDB = Depends(get_current_active_user)): 26 | return user 27 | 28 | 29 | @router.post("/login", response_model=UserFromDB, status_code=status.HTTP_200_OK) 30 | async def login_for_access_token( 31 | response: Response, 32 | form_data: OAuth2PasswordRequestForm = Depends(), 33 | db_session: orm.Session = Depends(get_db), 34 | ): 35 | user = await authenticate_user( 36 | username=form_data.username, password=form_data.password, db_session=db_session 37 | ) 38 | 39 | if not user: 40 | raise HTTPException( 41 | status_code=status.HTTP_401_UNAUTHORIZED, 42 |
detail="Incorrect username or password", 43 | headers={"WWW-Authenticate": "Bearer"}, 44 | ) 45 | 46 | if not user.is_active: 47 | raise HTTPException( 48 | status_code=status.HTTP_401_UNAUTHORIZED, 49 | detail="User is disabled", 50 | headers={"WWW-Authenticate": "Bearer"}, 51 | ) 52 | 53 | jwt, exp = create_access_token( 54 | subject={"sub": user.username, "user_id": user.id, "user_role": user.role}, 55 | expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRATION_IN_MINUTES), 56 | ) 57 | 58 | response.set_cookie( 59 | key="access_token", 60 | value=jwt, 61 | expires=int(exp.timestamp()), 62 | httponly=True, 63 | ) 64 | 65 | return UserFromDB.model_validate(user) 66 | 67 | 68 | @router.get("/logout", response_model=dict, status_code=status.HTTP_200_OK) 69 | async def logout(response: Response): 70 | response.delete_cookie(key="access_token") 71 | 72 | return {"message": "Logged out successfully"} 73 | -------------------------------------------------------------------------------- /ragflow/api/routers/chats_router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException 2 | from fastapi.responses import StreamingResponse 3 | from starlette import status 4 | from pydantic import BaseModel, Field 5 | 6 | from ragflow.utils.hyperparam_chats import AsyncCallbackHandler 7 | from ragflow.utils.hyperparam_chats import query_chat, create_gen, get_docs 8 | 9 | router = APIRouter( 10 | prefix="/chats", 11 | tags=["Endpoints to chat with parameterized RAGs"], 12 | ) 13 | 14 | 15 | class ChatQueryRequest(BaseModel): 16 | hp_id: int = Field(ge=0, description="Hyperparameter run id") 17 | hyperparameters_results_path: str = ( 18 | Field(min_length=3, description="path to list of hp results"), 19 | ) 20 | user_id: str = Field(description="user id from db") 21 | api_keys: dict[str, str] = Field(description="Dictionary of API keys.") 22 | query: str = Field(min_length=5, description="User query") 23 | 24 | class Config: 25 | json_schema_extra = { 26 | "example": { 27 | "id": 0, 28 | "hyperparameters_results_path": "./tmp/hyperparameters_results.json", 29 | "user_id": "1", 30 | "api_keys": { 31 | "OPENAI_API_KEY": "your_api_key_here", 32 | "ANOTHER_API_KEY": "another_key_here", 33 | }, 34 | "query": "My query", 35 | } 36 | } 37 | 38 | 39 | @router.post("/query", status_code=status.HTTP_200_OK) 40 | async def query_chat_model_normal(request: ChatQueryRequest): 41 | try: 42 | return query_chat(**request.model_dump()) 43 | except Exception as ex: 44 | raise HTTPException(status_code=400, detail=str(ex)) 45 | 46 | 47 | @router.post("/get_docs", status_code=status.HTTP_200_OK) 48 | async def query_chat_model_normal(request: ChatQueryRequest): 49 | try: 50 | return await get_docs(**request.model_dump()) 51 | except Exception as ex: 52 | raise HTTPException(status_code=400, detail=str(ex)) 53 | 54 | 55 | @router.post("/query_stream", status_code=status.HTTP_200_OK) 56 | async def query_chat_model_stream(request: ChatQueryRequest): 57 | stream_it = AsyncCallbackHandler() 58 | gen = create_gen(**request.model_dump(), stream_it=stream_it) 59 | return StreamingResponse(gen, media_type="text/event-stream") 60 | -------------------------------------------------------------------------------- /ragflow/api/routers/configs_router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | from starlette import status 3 | 4 | from ragflow.commons.configurations import ( 5 | LLM_MODELS, 6 
| EMB_MODELS, 7 | CVGradeAnswerPrompt, 8 | CVGradeRetrieverPrompt, 9 | CVRetrieverSearchType, 10 | CVSimilarityMethod, 11 | ) 12 | 13 | router = APIRouter( 14 | prefix="/configs", 15 | tags=["Configurations"], 16 | ) 17 | 18 | 19 | @router.get("/llm_models", status_code=status.HTTP_200_OK) 20 | async def get_list_of_supported_llm_models(): 21 | return LLM_MODELS 22 | 23 | 24 | @router.get("/embedding_models", status_code=status.HTTP_200_OK) 25 | async def list_of_embedding_models(): 26 | return EMB_MODELS 27 | 28 | 29 | @router.get("/retriever_similarity_methods", status_code=status.HTTP_200_OK) 30 | async def list_of_similarity_methods_for_retriever(): 31 | return [e.value for e in CVSimilarityMethod] 32 | 33 | 34 | @router.get("/retriever_search_types", status_code=status.HTTP_200_OK) 35 | async def list_of_search_types_for_retriever(): 36 | return [e.value for e in CVRetrieverSearchType] 37 | 38 | 39 | @router.get("/grade_answer_prompts", status_code=status.HTTP_200_OK) 40 | async def list_of_prompts_for_grading_answers(): 41 | return [e.value for e in CVGradeAnswerPrompt] 42 | 43 | 44 | @router.get("/grade_documents_prompts", status_code=status.HTTP_200_OK) 45 | async def list_of_prompts_for_grading_documents_retrieved(): 46 | return [e.value for e in CVGradeRetrieverPrompt] 47 | -------------------------------------------------------------------------------- /ragflow/api/routers/evals_router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException 2 | from starlette import status 3 | from pydantic import BaseModel, Field 4 | 5 | from ragflow.evaluation import arun_evaluation 6 | 7 | 8 | router = APIRouter( 9 | tags=["Hyperparameter evaluation"], 10 | ) 11 | 12 | 13 | class EvaluationRequest(BaseModel): 14 | document_store_path: str = Field(min_length=3, description="path to documents") 15 | label_dataset_path: str = Field( 16 | min_length=3, description="path to generated evaluation dataset" 17 | ) 18 | hyperparameters_path: str = Field( 19 | min_length=3, description="path to list of hyperparameters" 20 | ) 21 | hyperparameters_results_path: str = Field( 22 | min_length=3, description="path to list of hp results" 23 | ) 24 | hyperparameters_results_data_path: str = Field( 25 | min_length=3, 26 | description="path to list of additional data generated during hp eval", 27 | ) 28 | user_id: str = Field(description="user id from db") 29 | api_keys: dict[str, str] = Field(description="Dictionary of API keys.") 30 | 31 | class Config: 32 | json_schema_extra = { 33 | "example": { 34 | "document_store_path": "./tmp/document_store/", # path to documents 35 | "hyperparameters_path": "./tmp/hyperparameters.json", # path to list of hyperparameters 36 | "label_dataset_path": "./tmp/label_dataset.json", # path to generated evaluation dataset 37 | "hyperparameters_results_path": "./tmp/hyperparameters_results.json", # path to list of eval results 38 | "hyperparameters_results_data_path": "./tmp/hyperparameters_results_data.csv", # path to list of generated predictions and retrieved docs 39 | "user_id": "1", # user id 40 | "api_keys": { 41 | "OPENAI_API_KEY": "your_api_key_here", 42 | "ANOTHER_API_KEY": "another_key_here", 43 | }, 44 | } 45 | } 46 | 47 | 48 | @router.post("/evaluation", status_code=status.HTTP_200_OK) 49 | async def start_evaluation_run(eval_request: EvaluationRequest): 50 | try: 51 | await arun_evaluation(**eval_request.model_dump()) 52 | except Exception as ex: 53 | print(ex) 54 | raise 
HTTPException(status_code=400, detail=str(ex)) 55 | -------------------------------------------------------------------------------- /ragflow/api/routers/gens_router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException 2 | from starlette import status 3 | from pydantic import BaseModel, Field 4 | 5 | from ragflow.generation import agenerate_evaluation_set 6 | 7 | router = APIRouter( 8 | tags=["Evaluation label set generation"], 9 | ) 10 | 11 | 12 | class GenerationRequest(BaseModel): 13 | document_store_path: str = Field(min_length=3, description="path to documents") 14 | label_dataset_gen_params_path: str = Field( 15 | min_length=3, description="path to configurations for qa generation" 16 | ) 17 | label_dataset_path: str = Field( 18 | min_length=3, 19 | description="path to where the generated qa pairs should be stored.", 20 | ) 21 | user_id: str = Field(description="user id from db") 22 | api_keys: dict[str, str] = Field(description="Dictionary of API keys.") 23 | 24 | class Config: 25 | json_schema_extra = { 26 | "example": { 27 | "document_store_path": "./tmp/document_store/", # path to documents 28 | "label_dataset_gen_params_path": "./tmp/label_dataset_gen_params.json", # path to list of hyperparameters 29 | "label_dataset_path": "./tmp/label_dataset.json", # path to generated evaluation dataset 30 | "user_id": "1", # user id 31 | "api_keys": { 32 | "OPENAI_API_KEY": "your_api_key_here", 33 | "ANOTHER_API_KEY": "another_key_here", 34 | }, 35 | } 36 | } 37 | 38 | 39 | @router.post("/generation", status_code=status.HTTP_200_OK) 40 | async def start_evalset_generation(gen_request: GenerationRequest): 41 | try: 42 | await agenerate_evaluation_set(**gen_request.model_dump()) 43 | except Exception as ex: 44 | print(ex) 45 | raise HTTPException(status_code=400, detail=str(ex)) 46 | -------------------------------------------------------------------------------- /ragflow/api/routers/user_router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends, HTTPException 2 | import sqlalchemy.orm as orm 3 | from starlette import status 4 | 5 | from ragflow.api.services import get_db 6 | import ragflow.api.services.user_service as user_service 7 | import ragflow.api.services.auth_service as auth_service 8 | 9 | from ragflow.api.schemas import CreateUserRequest, UpdateUserRequest, UserFromDB 10 | 11 | router = APIRouter( 12 | tags=["User endpoints"], 13 | ) 14 | 15 | import logging 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | @router.post( 21 | "/users", 22 | response_model=UserFromDB, 23 | status_code=status.HTTP_201_CREATED, 24 | ) 25 | async def create_user( 26 | user_data: CreateUserRequest, 27 | db_session: orm.Session = Depends(get_db), 28 | ): 29 | return await user_service.create_user(user_data=user_data, db_session=db_session) 30 | 31 | 32 | @router.get( 33 | "/users", 34 | response_model=list[UserFromDB], 35 | status_code=status.HTTP_200_OK, 36 | ) 37 | async def get_all_users( 38 | user: UserFromDB = Depends(auth_service.get_current_active_user), 39 | db_session: orm.Session = Depends(get_db), 40 | ): 41 | if user.role != "admin": 42 | raise HTTPException(status_code=401, detail="User not authorized") 43 | 44 | return await user_service.get_all_users(db_session=db_session) 45 | 46 | 47 | @router.get( 48 | "/users/{user_id}", 49 | response_model=UserFromDB, 50 | status_code=status.HTTP_200_OK, 51 | ) 52 | async def 
get_user_by_id( 53 | user_id: int, 54 | user: UserFromDB = Depends(auth_service.get_current_active_user), 55 | db_session: orm.Session = Depends(get_db), 56 | ): 57 | if user.id != user_id and user.role != "admin": 58 | raise HTTPException(status_code=401, detail="User not authorized") 59 | 60 | return await user_service.get_user_by_id(user_id=user_id, db_session=db_session) 61 | 62 | 63 | @router.delete( 64 | "/users/{user_id}", 65 | response_model=UserFromDB, 66 | status_code=status.HTTP_200_OK, 67 | ) 68 | async def delete_user_by_id( 69 | user_id: int, 70 | user: UserFromDB = Depends(auth_service.get_current_active_user), 71 | db_session: orm.Session = Depends(get_db), 72 | ): 73 | if user.id != user_id and user.role != "admin": 74 | raise HTTPException(status_code=401, detail="User not authorized") 75 | 76 | user_to_delete = await user_service.get_user_by_id( 77 | user_id=user_id, db_session=db_session 78 | ) 79 | 80 | await user_service.delete_user(user=user_to_delete, db_session=db_session) 81 | 82 | return user_to_delete 83 | 84 | 85 | @router.put( 86 | "/users/{user_id}", 87 | response_model=UserFromDB, 88 | status_code=status.HTTP_200_OK, 89 | ) 90 | async def update_user_by_id( 91 | user_id: int, 92 | user_data: UpdateUserRequest, 93 | user: UserFromDB = Depends(auth_service.get_current_active_user), 94 | db_session: orm.Session = Depends(get_db), 95 | ): 96 | if user.id != user_id and user.role != "admin": 97 | raise HTTPException(status_code=401, detail="User not authorized") 98 | 99 | user_to_change = await user_service.get_user_by_id( 100 | user_id=user_id, db_session=db_session 101 | ) 102 | 103 | return await user_service.update_user( 104 | user=user_to_change, user_data=user_data, db_session=db_session 105 | ) 106 | -------------------------------------------------------------------------------- /ragflow/api/schemas.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | from typing import Annotated 3 | from pydantic import BaseModel, EmailStr, StringConstraints 4 | 5 | from typing import Optional 6 | 7 | 8 | class UserBase(BaseModel): 9 | username: str 10 | email: str 11 | 12 | 13 | class UserFromDB(UserBase): 14 | id: int 15 | date_created: dt.datetime 16 | role: str 17 | is_active: bool 18 | 19 | class Config: 20 | from_attributes = True 21 | 22 | 23 | class UpdateUserRequest(UserBase): 24 | password: Optional[str] = None 25 | 26 | 27 | class CreateUserRequest(UserBase): 28 | username: Annotated[str, StringConstraints(min_length=4, max_length=64)] 29 | password: Annotated[str, StringConstraints(min_length=8, max_length=128)] 30 | email: EmailStr 31 | -------------------------------------------------------------------------------- /ragflow/api/services/__init__.py: -------------------------------------------------------------------------------- 1 | from ragflow.api.services.common_service import get_db 2 | import ragflow.api.services.user_service as user_service 3 | import ragflow.api.services.auth_service as auth_service 4 | -------------------------------------------------------------------------------- /ragflow/api/services/auth_service.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from fastapi import Depends, HTTPException, Request 4 | from fastapi.security import OAuth2PasswordBearer 5 | from starlette import status 6 | import sqlalchemy.orm as orm 7 | from jose import JWTError, jwt 8 | from passlib.context import CryptContext 9 | 
10 | from datetime import datetime, timedelta 11 | import os 12 | 13 | from ragflow.api.schemas import UserFromDB 14 | from ragflow.api.services import user_service 15 | from ragflow.api.services import get_db 16 | 17 | JWT_SECRET_KEY = os.environ.get("JWT_SECRET_KEY") 18 | HASH_ALGORITHM = "HS256" 19 | 20 | pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") 21 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") 22 | 23 | 24 | def verify_password(plain_password: str, hashed_password: str) -> bool: 25 | return pwd_context.verify(plain_password, hashed_password) 26 | 27 | 28 | def get_password_hash(password: str) -> str: 29 | return pwd_context.hash(password) 30 | 31 | 32 | async def authenticate_user( 33 | username: str, password: str, db_session: orm.Session 34 | ) -> Union[UserFromDB, bool]: 35 | user = await user_service.get_user_by_name(username=username, db_session=db_session) 36 | 37 | if user is None: 38 | return False 39 | 40 | if not verify_password(password, user.hashed_password): 41 | return False 42 | 43 | return user 44 | 45 | 46 | def create_access_token( 47 | subject: dict, expires_delta: timedelta = None 48 | ) -> tuple[str, datetime]: 49 | to_encode = subject.copy() 50 | 51 | if expires_delta is not None: 52 | exp = datetime.utcnow() + expires_delta 53 | 54 | else: 55 | exp = datetime.utcnow() + timedelta(minutes=15) 56 | 57 | to_encode |= {"exp": exp} 58 | encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, HASH_ALGORITHM) 59 | 60 | return encoded_jwt, exp 61 | 62 | 63 | async def get_current_user( 64 | request: Request, 65 | db_session: orm.Session = Depends(get_db), 66 | ) -> UserFromDB: 67 | credentials_exception = HTTPException( 68 | status_code=status.HTTP_401_UNAUTHORIZED, 69 | detail="Could not validate credentials", 70 | headers={"WWW-Authenticate": "Bearer"}, 71 | ) 72 | 73 | try: 74 | token = request.cookies.get("access_token") 75 | 76 | if token is None: 77 | raise HTTPException( 78 | status_code=status.HTTP_401_UNAUTHORIZED, 79 | detail="User not authenticated", 80 | ) 81 | 82 | payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[HASH_ALGORITHM]) 83 | username = payload.get("sub") 84 | user_id = payload.get("user_id") 85 | user_role = payload.get("user_role") 86 | 87 | if username is None or user_id is None or user_role is None: 88 | raise credentials_exception 89 | 90 | except JWTError: 91 | raise credentials_exception 92 | 93 | user = await user_service.get_user_by_id(user_id=user_id, db_session=db_session) 94 | 95 | if user is None: 96 | raise credentials_exception 97 | 98 | return UserFromDB.model_validate(user) 99 | 100 | 101 | async def get_current_active_user( 102 | current_user: UserFromDB = Depends(get_current_user), 103 | ) -> UserFromDB: 104 | if not current_user.is_active: 105 | raise HTTPException(status_code=400, detail="Inactive user") 106 | 107 | return current_user 108 | -------------------------------------------------------------------------------- /ragflow/api/services/common_service.py: -------------------------------------------------------------------------------- 1 | from ragflow.api.database import Session 2 | 3 | 4 | def get_db(): 5 | db_session = Session() 6 | try: 7 | yield db_session 8 | finally: 9 | db_session.close() 10 | -------------------------------------------------------------------------------- /ragflow/api/services/user_service.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from ragflow.api.services import auth_service 4 | from 
ragflow.api.schemas import UserFromDB, CreateUserRequest, UpdateUserRequest 5 | from ragflow.api.models import User 6 | from ragflow.api.database import Session 7 | 8 | from fastapi import HTTPException 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | async def create_user(user_data: CreateUserRequest, db_session: Session) -> UserFromDB: 14 | user_data_dict = user_data.model_dump() 15 | 16 | user_data_dict["hashed_password"] = auth_service.get_password_hash( 17 | user_data_dict.pop("password") 18 | ) 19 | 20 | # Check if a user with the same username or email already exists 21 | existing_user = ( 22 | db_session.query(User) 23 | .filter( 24 | (User.username == user_data_dict.get("username")) 25 | | (User.email == user_data_dict.get("email")) 26 | ) 27 | .first() 28 | ) 29 | 30 | if existing_user: 31 | # Determine which attribute (username or email) already exists in db 32 | duplicate_field = ( 33 | "username" 34 | if existing_user.username == user_data_dict.get("username") 35 | else "email" 36 | ) 37 | raise HTTPException( 38 | status_code=400, 39 | detail=f"User with provided {duplicate_field} already exists!", 40 | ) 41 | 42 | user = User(**user_data_dict) 43 | 44 | db_session.add(user) 45 | db_session.commit() 46 | db_session.refresh(user) 47 | 48 | return UserFromDB.model_validate(user) 49 | 50 | 51 | async def get_all_users(db_session: Session) -> list[UserFromDB]: 52 | users = db_session.query(User).all() 53 | return list(map(UserFromDB.model_validate, users)) 54 | 55 | 56 | async def get_user_by_id(user_id: int, db_session: Session) -> User: 57 | user = db_session.query(User).filter(User.id == user_id).first() 58 | 59 | return user 60 | 61 | 62 | async def get_user_by_name(username: str, db_session: Session) -> User: 63 | user = db_session.query(User).filter(User.username == username).first() 64 | 65 | return user 66 | 67 | 68 | async def delete_user(user: User, db_session: Session) -> None: 69 | db_session.delete(user) 70 | db_session.commit() 71 | 72 | 73 | async def update_user( 74 | user: User, 75 | user_data: UpdateUserRequest, 76 | db_session: Session, 77 | ) -> UserFromDB: 78 | user.username = user_data.username 79 | user.email = user_data.email 80 | 81 | if user_data.password: 82 | user.hashed_password = auth_service.get_password_hash(user_data.password) 83 | 84 | db_session.commit() 85 | db_session.refresh(user) 86 | 87 | return UserFromDB.model_validate(user) 88 | -------------------------------------------------------------------------------- /ragflow/commons/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/ragflow/commons/__init__.py -------------------------------------------------------------------------------- /ragflow/commons/chroma/ChromaClient.py: -------------------------------------------------------------------------------- 1 | import chromadb 2 | from chromadb.config import Settings 3 | import os 4 | 5 | API_HOST = os.environ.get("CHROMADB_HOST", "localhost") 6 | API_PORT = os.environ.get("CHROMADB_PORT", 8000) 7 | 8 | 9 | class ChromaClient: 10 | """Get a chromadb client as ContextManager.""" 11 | 12 | def __init__(self): 13 | self.chroma_client = chromadb.HttpClient( 14 | host=API_HOST, 15 | port=API_PORT, 16 | settings=Settings(anonymized_telemetry=False), 17 | ) 18 | 19 | def get_client(self): 20 | return self.chroma_client 21 | 22 | def __enter__(self): 23 | return self.chroma_client 24 | 25 | def 
__exit__(self, exc_type, exc_value, traceback): 26 | # TODO: self.chroma_client.stop() 27 | pass 28 | -------------------------------------------------------------------------------- /ragflow/commons/chroma/__init__.py: -------------------------------------------------------------------------------- 1 | from ragflow.commons.chroma.ChromaClient import ChromaClient 2 | -------------------------------------------------------------------------------- /ragflow/commons/configurations/BaseConfigurations.py: -------------------------------------------------------------------------------- 1 | from langchain.schema.embeddings import Embeddings 2 | from langchain.schema.language_model import BaseLanguageModel 3 | from langchain.llms.fake import FakeListLLM 4 | 5 | from langchain.embeddings import OpenAIEmbeddings, DeterministicFakeEmbedding 6 | from langchain.chat_models import ChatOpenAI, ChatAnyscale 7 | 8 | import tiktoken 9 | import builtins 10 | import logging 11 | 12 | from abc import abstractmethod 13 | from pydantic.v1 import BaseModel, Field, Extra, validator 14 | from typing import Any, Optional 15 | 16 | from enum import Enum 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | #################### 21 | # LLM Models 22 | #################### 23 | LLAMA_2_MODELS = [ 24 | "Llama-2-7b-chat-hf", 25 | "Llama-2-13b-chat-hf", 26 | "Llama-2-70b-chat-hf", 27 | ] 28 | 29 | OPENAI_LLM_MODELS = ["gpt-3.5-turbo", "gpt-4"] 30 | 31 | LLM_MODELS = [*OPENAI_LLM_MODELS, *LLAMA_2_MODELS] 32 | 33 | # Anyscale configs 34 | ANYSCALE_LLM_PREFIX = "meta-llama/" 35 | ANYSCALE_API_URL = "https://api.endpoints.anyscale.com/v1" 36 | 37 | #################### 38 | # Embedding Models 39 | #################### 40 | OPENAI_EMB_MODELS = ["text-embedding-ada-002"] 41 | 42 | EMB_MODELS = [*OPENAI_EMB_MODELS] 43 | 44 | 45 | #################### 46 | # Enumerations 47 | #################### 48 | class CVGradeAnswerPrompt(Enum): 49 | ZERO_SHOT = "zero_shot" 50 | FEW_SHOT = "few_shot" 51 | NONE = "none" 52 | 53 | 54 | class CVGradeRetrieverPrompt(Enum): 55 | DEFAULT = "default" 56 | NONE = "none" 57 | 58 | 59 | class CVRetrieverSearchType(Enum): 60 | BY_SIMILARITY = "similarity" 61 | MAXIMAL_MARGINAL_RELEVANCE = "mmr" 62 | 63 | 64 | class CVSimilarityMethod(Enum): 65 | COSINE = "cosine" 66 | L2_NORM = "l2" 67 | INNER_PRODUCT = "ip" 68 | 69 | 70 | #################### 71 | # Test LLM and embedding model for mocks 72 | #################### 73 | class TestDummyLLM(FakeListLLM): 74 | """Langchains FakeListLLM with model name included.""" 75 | 76 | model_name: str = "TestDummyLLM" 77 | 78 | def __init__(self): 79 | super().__init__(responses=["foo_response"]) 80 | 81 | def dict(self, *args, **kwargs): 82 | output = super().dict(*args, **kwargs) 83 | output["model_name"] = self.model_name 84 | return output 85 | 86 | 87 | class TestDummyEmbedding(DeterministicFakeEmbedding): 88 | """Langchains DeterministicFakeEmbedding with model name included.""" 89 | 90 | model: str = "TestDummyEmbedding" 91 | 92 | def __init__(self): 93 | super().__init__(size=2) 94 | 95 | 96 | class BaseConfigurations(BaseModel): 97 | """Base class for configuration objects.""" 98 | 99 | chunk_size: int = Field(ge=0) 100 | chunk_overlap: int = Field(ge=0) 101 | length_function_name: str 102 | length_function: Any 103 | 104 | class Config: 105 | allow_mutation = False 106 | arbitrary_types_allowed = True 107 | extra = Extra.forbid 108 | 109 | @validator("length_function", pre=False, always=True) 110 | def populate_length_function(cls, v: callable, values: 
dict[str, str]): 111 | """Convert provided length function name to actual function.""" 112 | 113 | return cls.set_length_function(values["length_function_name"]) 114 | 115 | @staticmethod 116 | def get_language_model(model_name: str, api_keys: dict) -> BaseLanguageModel: 117 | if model_name in OPENAI_LLM_MODELS: 118 | return ChatOpenAI( 119 | openai_api_key=api_keys["OPENAI_API_KEY"], 120 | model_name=model_name, 121 | temperature=0.0, 122 | ) 123 | 124 | elif model_name in LLAMA_2_MODELS: 125 | return ChatAnyscale( 126 | anyscale_api_key=api_keys["ANYSCALE_API_KEY"], 127 | model_name=f"{ANYSCALE_LLM_PREFIX}{model_name}", 128 | anyscale_api_base=ANYSCALE_API_URL, 129 | temperature=0.0, 130 | ) 131 | # only for testing purposes 132 | elif model_name == "TestDummyLLM": 133 | return TestDummyLLM() 134 | 135 | raise NotImplementedError(f"LLM model '{model_name}' not supported.") 136 | 137 | @staticmethod 138 | def get_embedding_model(model_name: str, api_keys: dict) -> Embeddings: 139 | if model_name in OPENAI_EMB_MODELS: 140 | return OpenAIEmbeddings( 141 | openai_api_key=api_keys["OPENAI_API_KEY"], model=model_name 142 | ) 143 | elif model_name == "TestDummyEmbedding": 144 | return TestDummyEmbedding() 145 | 146 | raise NotImplementedError(f"Embedding model '{model_name}' not supported.") 147 | 148 | @classmethod 149 | def set_length_function(cls, length_function_name: str) -> callable: 150 | # Extract the function name from the string 151 | func = length_function_name.strip("<>").split(" ")[-1] 152 | 153 | # Check if the function name exists in Python's built-ins 154 | if hasattr(builtins, func): 155 | return getattr(builtins, func) 156 | 157 | else: 158 | try: 159 | encoding = tiktoken.encoding_for_model(length_function_name) 160 | return lambda x: len(encoding.encode(x)) 161 | except Exception as ex: 162 | logger.error(f"Length function '{length_function_name}' not supported") 163 | raise NotImplementedError( 164 | f"Error setting length function, neither python built-in nor valid tiktoken name passed. 
{ex.args}" 165 | ) 166 | 167 | @staticmethod 168 | def get_language_model_name(llm: BaseLanguageModel) -> str: 169 | """Retrieve name of language model from object""" 170 | return llm.model_name 171 | 172 | @staticmethod 173 | def get_embedding_model_name(emb: Embeddings) -> str: 174 | """Retrieve name of embedding model name from object""" 175 | return emb.model 176 | 177 | def to_dict(self) -> dict: 178 | _data = self.dict() 179 | _data.pop("length_function", None) 180 | return _data 181 | 182 | @classmethod 183 | @abstractmethod 184 | def from_dict(cls, input_dict): 185 | pass 186 | -------------------------------------------------------------------------------- /ragflow/commons/configurations/Hyperparameters.py: -------------------------------------------------------------------------------- 1 | from langchain.schema.embeddings import Embeddings 2 | from langchain.schema.language_model import BaseLanguageModel 3 | 4 | from ragflow.commons.configurations.BaseConfigurations import ( 5 | BaseConfigurations, 6 | LLM_MODELS, 7 | EMB_MODELS, 8 | CVGradeAnswerPrompt, 9 | CVGradeRetrieverPrompt, 10 | CVRetrieverSearchType, 11 | CVSimilarityMethod, 12 | ) 13 | 14 | import logging 15 | 16 | from pydantic.v1 import Field, validator 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class Hyperparameters(BaseConfigurations): 22 | """Class to model hyperparameters.""" 23 | 24 | id: int = Field(ge=0) 25 | num_retrieved_docs: int = Field(ge=0) 26 | grade_answer_prompt: CVGradeAnswerPrompt 27 | grade_docs_prompt: CVGradeRetrieverPrompt 28 | search_type: CVRetrieverSearchType 29 | similarity_method: CVSimilarityMethod 30 | use_llm_grader: bool 31 | qa_llm: BaseLanguageModel 32 | grader_llm: BaseLanguageModel 33 | embedding_model: Embeddings 34 | 35 | @validator("qa_llm", "grader_llm", pre=True, always=True) 36 | def check_language_model_name(cls, v, field, values): 37 | """Validation to check if provided LLM model is supported.""" 38 | 39 | if cls.get_language_model_name(v) not in LLM_MODELS + ["TestDummyLLM"]: 40 | raise ValueError(f"{v} not in list of valid values {LLM_MODELS}.") 41 | return v 42 | 43 | @validator("embedding_model", pre=True, always=True) 44 | def check_embedding_model_name(cls, v): 45 | """Validation to check if provided embedding model is supported.""" 46 | 47 | if cls.get_embedding_model_name(v) not in EMB_MODELS + ["TestDummyEmbedding"]: 48 | raise ValueError(f"{v} not in list of valid values {EMB_MODELS}.") 49 | return v 50 | 51 | def to_dict(self): 52 | _data = super().to_dict() 53 | 54 | # Modify the dictionary for fields that need special handling 55 | _data["qa_llm"] = _data["qa_llm"]["model_name"] 56 | _data["grader_llm"] = _data["grader_llm"]["model_name"] 57 | _data["embedding_model"] = _data["embedding_model"]["model"] 58 | 59 | # remove unnecessary values again if we don't use a llm for grading 60 | if not _data["use_llm_grader"]: 61 | _data.pop("grade_answer_prompt", None) 62 | _data.pop("grade_docs_prompt", None) 63 | _data.pop("grader_llm", None) 64 | 65 | return _data 66 | 67 | @classmethod 68 | def from_dict( 69 | cls, input_dict: dict[str, str], hp_id: str, api_keys: dict[str, str] 70 | ): 71 | _input = dict(**input_dict) 72 | 73 | # set default values if use_grader_llm=False and additional values not set 74 | if not _input["use_llm_grader"]: 75 | _input["grade_answer_prompt"] = CVGradeAnswerPrompt.NONE.value 76 | _input["grade_docs_prompt"] = CVGradeRetrieverPrompt.NONE.value 77 | _input["grader_llm"] = "TestDummyLLM" 78 | 79 | _input["id"] = hp_id 80 
| 81 | # set the actual langchain objects 82 | _input["embedding_model"] = cls.get_embedding_model( 83 | _input["embedding_model"], api_keys 84 | ) 85 | _input["qa_llm"] = cls.get_language_model(_input["qa_llm"], api_keys) 86 | _input["grader_llm"] = cls.get_language_model(_input["grader_llm"], api_keys) 87 | 88 | return cls(**_input) 89 | -------------------------------------------------------------------------------- /ragflow/commons/configurations/QAConfigurations.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pydantic.v1 import validator 3 | 4 | from langchain.schema.language_model import BaseLanguageModel 5 | from langchain.schema.embeddings import Embeddings 6 | from ragflow.commons.configurations.BaseConfigurations import ( 7 | BaseConfigurations, 8 | LLM_MODELS, 9 | ) 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class QAConfigurations(BaseConfigurations): 16 | """Class to model qa generation configs.""" 17 | 18 | qa_generator_llm: BaseLanguageModel 19 | persist_to_vs: bool 20 | embedding_model_list: list[Embeddings] 21 | 22 | @validator("qa_generator_llm", pre=True, always=True) 23 | def check_language_model_name(cls, v): 24 | """Validation to check if provided LLM model is supported.""" 25 | 26 | if cls.get_language_model_name(v) not in LLM_MODELS + ["TestDummyLLM"]: 27 | raise ValueError(f"{v} not in list of valid values {LLM_MODELS}.") 28 | return v 29 | 30 | def to_dict(self): 31 | _data = super().to_dict() 32 | 33 | # Modify the dictionary for fields that need special handling 34 | _data["qa_generator_llm"] = _data["qa_generator_llm"]["model_name"] 35 | _data["embedding_model_list"] = [ 36 | model["model"] for model in _data["embedding_model_list"] 37 | ] 38 | 39 | return _data 40 | 41 | @classmethod 42 | def from_dict(cls, input_dict: dict[str, str], api_keys: dict[str, str]): 43 | _input = dict(**input_dict) 44 | 45 | _input["qa_generator_llm"] = cls.get_language_model( 46 | _input["qa_generator_llm"], api_keys 47 | ) 48 | 49 | # get the list of embedding models, filter the unique models and map them to LangChain objects 50 | embedding_models = set(_input["embedding_model_list"]) 51 | 52 | _input["embedding_model_list"] = [ 53 | cls.get_embedding_model(model, api_keys) for model in embedding_models 54 | ] 55 | 56 | return cls(**_input) 57 | -------------------------------------------------------------------------------- /ragflow/commons/configurations/__init__.py: -------------------------------------------------------------------------------- 1 | from ragflow.commons.configurations.BaseConfigurations import ( 2 | BaseConfigurations, 3 | CVGradeAnswerPrompt, 4 | CVGradeRetrieverPrompt, 5 | CVRetrieverSearchType, 6 | CVSimilarityMethod, 7 | LLM_MODELS, 8 | EMB_MODELS, 9 | ) 10 | from ragflow.commons.configurations.Hyperparameters import Hyperparameters 11 | from ragflow.commons.configurations.QAConfigurations import QAConfigurations 12 | 13 | import os 14 | import dotenv 15 | import logging 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | # Here we load .env file which contains the API keys to access LLMs, etc. 20 | if os.environ.get("ENV") == "DEV": 21 | try: 22 | dotenv.load_dotenv(dotenv.find_dotenv(), override=True) 23 | except Exception as ex: 24 | logger.error( 25 | "Failed to load .env file in 'DEV' environment containing API keys." 
26 | ) 27 | raise ex 28 | -------------------------------------------------------------------------------- /ragflow/commons/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | from ragflow.commons.prompts.qa_geneneration_prompts import ( 2 | QA_GENERATION_PROMPT_SELECTOR, 3 | ) 4 | 5 | from ragflow.commons.prompts.grade_answers_prompts import ( 6 | GRADE_ANSWER_PROMPT_FAST, 7 | GRADE_ANSWER_PROMPT_5_CATEGORIES_5_GRADES_ZERO_SHOT, 8 | GRADE_ANSWER_PROMPT_3_CATEGORIES_4_GRADES_FEW_SHOT, 9 | ) 10 | 11 | from ragflow.commons.prompts.grade_retriever_prompts import GRADE_RETRIEVER_PROMPT 12 | 13 | from ragflow.commons.prompts.qa_answer_prompts import QA_ANSWER_PROMPT 14 | -------------------------------------------------------------------------------- /ragflow/commons/prompts/grade_answers_prompts.py: -------------------------------------------------------------------------------- 1 | from langchain.prompts import PromptTemplate 2 | 3 | template = """Please act as an impartial judge and evaluate the quality of the provided answer which attempts to answer the provided question based on a provided context. 4 | 5 | Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. If the student answers that there is no specific information provided in the context, then the answer is incorrect. Begin! 6 | 7 | QUESTION: {query} 8 | STUDENT ANSWER: {result} 9 | TRUE ANSWER: {answer} 10 | 11 | Your response should be as follows: 12 | 13 | CORRECTNESS: (1,2,3,4 or 5) - grade 1 means the answer was completely incorrect, a higher grade towards 5 means the answer is more correct, clarifies more parts of the question and is more readable. The best grade is 5. 14 | """ 15 | 16 | GRADE_ANSWER_PROMPT_FAST = PromptTemplate( 17 | input_variables=["query", "result", "answer"], template=template 18 | ) 19 | 20 | template = """Please act as an impartial judge and evaluate the quality of the provided answer which attempts to answer the provided question based on a provided context. 21 | 22 | You'll be given a function grading_function which you'll call for each provided context, question and answer to submit your reasoning and score for the correctness, comprehensiveness and readability of the answer. 23 | 24 | Below is your grading rubric: 25 | 26 | - Correctness: Does the answer correctly answer the question? 27 | 28 | - Comprehensiveness: How comprehensive is the answer, does it fully answer all aspects of the question and provide comprehensive explanation and other necessary information. 29 | 30 | - Readability: How readable is the answer, does it have redundant information or incomplete information that hurts the readability of the answer. Rate from 0 (completely unreadable) to 1 (highly readable) 31 | 32 | Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. If the student answers that there is no specific information provided in the context, then the answer is incorrect. Begin!
33 | 34 | QUESTION: {query} 35 | STUDENT ANSWER: {result} 36 | TRUE ANSWER: {answer} 37 | 38 | Your response should be as follows, do not provide any additional information. 39 | 40 | CORRECTNESS: (1,2,3,4 or 5) - scale from 1 worst to 5 best 41 | COMPREHENSIVENESS: (1,2,3,4 or 5) - scale from 1 worst to 5 best 42 | READABILITY: (1,2,3,4 or 5) - scale from 1 worst to 5 best 43 | """ 44 | 45 | GRADE_ANSWER_PROMPT_5_CATEGORIES_5_GRADES_ZERO_SHOT = PromptTemplate( 46 | input_variables=["query", "result", "answer"], template=template 47 | ) 48 | 49 | 50 | template = """Please act as an impartial judge and evaluate the quality of the provided answer which attempts to answer the provided question based on a provided context. 51 | 52 | You'll be given a function grading_function which you'll call for each provided context, question and answer to submit your reasoning and score for the correctness, comprehensiveness and readability of the answer. 53 | 54 | Below is your grading rubric: 55 | 56 | - CORRECTNESS: If the answer correctly answer the question, below are the details for different scores: 57 | 58 | - Score 0: the answer is completely incorrect, doesn’t mention anything about the question or is completely contrary to the correct answer. 59 | 60 | - For example, when asked “How to terminate a databricks cluster”, the answer is empty string, or content that’s completely irrelevant, or sorry I don’t know the answer. 61 | 62 | - Score 1: the answer provides some relevance to the question and answers one aspect of the question correctly. 63 | 64 | - Example: 65 | 66 | - Question: How to terminate a databricks cluster 67 | 68 | - Answer: Databricks cluster is a cloud-based computing environment that allows users to process big data and run distributed data processing tasks efficiently. 69 | 70 | - Or answer: In the Databricks workspace, navigate to the "Clusters" tab. And then this is a hard question that I need to think more about it 71 | 72 | - Score 2: the answer mostly answer the question but is missing or hallucinating on one critical aspect. 73 | 74 | - Example: 75 | 76 | - Question: How to terminate a databricks cluster” 77 | 78 | - Answer: “In the Databricks workspace, navigate to the "Clusters" tab. 79 | 80 | Find the cluster you want to terminate from the list of active clusters. 81 | 82 | And then you’ll find a button to terminate all clusters at once” 83 | 84 | - Score 3: the answer correctly answer the question and not missing any major aspect 85 | 86 | - Example: 87 | 88 | - Question: How to terminate a databricks cluster 89 | 90 | - Answer: In the Databricks workspace, navigate to the "Clusters" tab. 91 | 92 | Find the cluster you want to terminate from the list of active clusters. 93 | 94 | Click on the down-arrow next to the cluster name to open the cluster details. 95 | 96 | Click on the "Terminate" button. A confirmation dialog will appear. Click "Terminate" again to confirm the action.” 97 | 98 | - COMPREHENSIVENESS: How comprehensive is the answer, does it fully answer all aspects of the question and provide comprehensive explanation and other necessary information. Below are the details for different scores: 99 | 100 | - Score 0: typically if the answer is completely incorrect, then the comprehensiveness is also zero score. 101 | 102 | - Score 1: if the answer is correct but too short to fully answer the question, then we can give score 1 for comprehensiveness. 103 | 104 | - Example: 105 | 106 | - Question: How to use databricks API to create a cluster? 
107 | 108 | - Answer: First, you will need a Databricks access token with the appropriate permissions. You can generate this token through the Databricks UI under the 'User Settings' option. And then (the rest is missing) 109 | 110 | - Score 2: the answer is correct and roughly answer the main aspects of the question, but it’s missing description about details. Or is completely missing details about one minor aspect. 111 | 112 | - Example: 113 | 114 | - Question: How to use databricks API to create a cluster? 115 | 116 | - Answer: You will need a Databricks access token with the appropriate permissions. Then you’ll need to set up the request URL, then you can make the HTTP Request. Then you can handle the request response. 117 | 118 | - Example: 119 | 120 | - Question: How to use databricks API to create a cluster? 121 | 122 | - Answer: You will need a Databricks access token with the appropriate permissions. Then you’ll need to set up the request URL, then you can make the HTTP Request. Then you can handle the request response. 123 | 124 | - Score 3: the answer is correct, and covers all the main aspects of the question 125 | 126 | - READABILITY: How readable is the answer, does it have redundant information or incomplete information that hurts the readability of the answer. 127 | 128 | - Score 0: the answer is completely unreadable, e.g. fully of symbols that’s hard to read; e.g. keeps repeating the words that it’s very hard to understand the meaning of the paragraph. No meaningful information can be extracted from the answer. 129 | 130 | - Score 1: the answer is slightly readable, there are irrelevant symbols or repeated words, but it can roughly form a meaningful sentence that cover some aspects of the answer. 131 | 132 | - Example: 133 | 134 | - Question: How to use databricks API to create a cluster? 135 | 136 | - Answer: You you you you you you will need a Databricks access token with the appropriate permissions. And then then you’ll need to set up the request URL, then you can make the HTTP Request. Then Then Then Then Then Then Then Then Then 137 | 138 | - Score 2: the answer is correct and mostly readable, but there is one obvious piece that’s affecting the readability (mentioning of irrelevant pieces, repeated words) 139 | 140 | - Example: 141 | 142 | - Question: How to terminate a databricks cluster 143 | 144 | - Answer: In the Databricks workspace, navigate to the "Clusters" tab. 145 | 146 | Find the cluster you want to terminate from the list of active clusters. 147 | 148 | Click on the down-arrow next to the cluster name to open the cluster details. 149 | 150 | Click on the "Terminate" button………………………………….. 151 | 152 | A confirmation dialog will appear. Click "Terminate" again to confirm the action. 153 | 154 | - Score 3: the answer is correct and reader friendly, no obvious piece that affect readability. 155 | 156 | Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. If the student answers that there is no specific information provided in the context, then the answer is Incorrect. Begin! 157 | 158 | QUESTION: {query} 159 | STUDENT ANSWER: {result} 160 | TRUE ANSWER: {answer} 161 | 162 | Your response should be as follows, do not provide any additional information. 
163 | 164 | CORRECTNESS: (0,1,2 or 3) 165 | COMPREHENSIVENESS: (0,1,2 or 3) 166 | READABILITY: (0,1,2 or 3) 167 | """ 168 | 169 | GRADE_ANSWER_PROMPT_3_CATEGORIES_4_GRADES_FEW_SHOT = PromptTemplate( 170 | input_variables=["query", "result", "answer"], template=template 171 | ) 172 | -------------------------------------------------------------------------------- /ragflow/commons/prompts/grade_retriever_prompts.py: -------------------------------------------------------------------------------- 1 | from langchain.prompts import PromptTemplate 2 | 3 | template = """ 4 | Given the question: 5 | {query} 6 | 7 | Here are some documents retrieved in response to the question: 8 | {result} 9 | 10 | And here is the answer to the question: 11 | {answer} 12 | 13 | GRADING CRITERIA: We want to know if the question can be directly answered with the provided documents and without any additional outside sources. Do the retrieved documents make it possible to answer the question? 14 | 15 | Your response should be as follows without providing any additional information: 16 | 17 | GRADE: (0 to 1) - grade '0' means it is impossible to answer the question with the documents in any way, grade '1' means the question can be fully answered with the provided documents. The more aspects of the question can be answered, the higher the grade should be, with a maximum grade of 1. 18 | """ 19 | 20 | GRADE_RETRIEVER_PROMPT = PromptTemplate( 21 | input_variables=["query", "result", "answer"], template=template 22 | ) 23 | -------------------------------------------------------------------------------- /ragflow/commons/prompts/qa_answer_prompts.py: -------------------------------------------------------------------------------- 1 | from langchain.prompts.prompt import PromptTemplate 2 | 3 | template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, say '''I don't know the answer to this question.'''. Don't try to make up an answer. Write a comprehensive and well-formatted answer that is readable for a broad and diverse audience of readers. Answer the question as fully as possible using only the provided documents. Don't reference the provided documents explicitly; write an informative yet concise answer of at most 8 sentences. 4 | {context} 5 | Question: {question} 6 | Helpful Answer:""" 7 | 8 | QA_ANSWER_PROMPT = PromptTemplate( 9 | input_variables=["context", "question"], 10 | template=template, 11 | ) 12 | -------------------------------------------------------------------------------- /ragflow/commons/prompts/qa_geneneration_prompts.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model 3 | from langchain.prompts.chat import ( 4 | ChatPromptTemplate, 5 | HumanMessagePromptTemplate, 6 | SystemMessagePromptTemplate, 7 | ) 8 | from langchain.prompts.prompt import PromptTemplate 9 | 10 | temp_chat_1 = """You are a smart assistant that can identify key information in a given text and extract a corresponding question and answer pair for this key information so that we can build a Frequently Asked Questions (FAQ) section. Both the question and the answer should be comprehensive, well formulated and readable for a broad and diverse audience of readers. 
11 | 12 | When coming up with this question/answer pair, you must respond in the following format: 13 | 14 | ``` 15 | {{ 16 | "question": "$YOUR_QUESTION_HERE", 17 | "answer": "$THE_ANSWER_HERE" 18 | }} 19 | ``` 20 | 21 | Everything between the ``` must be valid json. 22 | """ 23 | temp_chat_2 = """Please come up with a question and answer pair, in the specified JSON format, for the following text: 24 | ---------------- 25 | {text}""" 26 | 27 | CHAT_PROMPT = ChatPromptTemplate.from_messages( 28 | [ 29 | SystemMessagePromptTemplate.from_template(temp_chat_1), 30 | HumanMessagePromptTemplate.from_template(temp_chat_2), 31 | ] 32 | ) 33 | 34 | temp_dft = """You are a smart assistant that can identify key information in a given text and extract a corresponding question and answer pair for this key information so that we can build a Frequently Asked Questions (FAQ) section. Both the question and the answer should be comprehensive, well formulated and readable for a broad and diverse audience of readers. 35 | 36 | When coming up with this question/answer pair, you must respond in the following format: 37 | 38 | ``` 39 | {{ 40 | "question": "$YOUR_QUESTION_HERE", 41 | "answer": "$THE_ANSWER_HERE" 42 | }} 43 | ``` 44 | 45 | Everything between the ``` must be valid json. 46 | 47 | Please come up with a question/answer pair, in the specified JSON format, for the following text: 48 | ---------------- 49 | {text} 50 | """ 51 | 52 | PROMPT = PromptTemplate.from_template(temp_dft) 53 | 54 | QA_GENERATION_PROMPT_SELECTOR = ConditionalPromptSelector( 55 | default_prompt=PROMPT, conditionals=[(is_chat_model, CHAT_PROMPT)] 56 | ) 57 | -------------------------------------------------------------------------------- /ragflow/commons/vectorstore/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/ragflow/commons/vectorstore/__init__.py -------------------------------------------------------------------------------- /ragflow/commons/vectorstore/pgvector_utils.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import create_engine, Table, MetaData 2 | from sqlalchemy.sql import select 3 | from langchain.vectorstores.pgvector import PGVector 4 | from langchain.schema.embeddings import Embeddings 5 | 6 | from ragflow.api import PGVECTOR_URL 7 | 8 | TABLE = "langchain_pg_collection" 9 | COL_NAME = "name" 10 | 11 | 12 | def delete_collection(name: str) -> None: 13 | col = PGVector( 14 | collection_name=name, 15 | connection_string=PGVECTOR_URL, 16 | embedding_function=None, 17 | ) 18 | 19 | col.delete_collection() 20 | 21 | 22 | def create_collection(name: str, embedding: Embeddings) -> PGVector: 23 | col = PGVector( 24 | collection_name=name, 25 | connection_string=PGVECTOR_URL, 26 | embedding_function=embedding, 27 | pre_delete_collection=True, 28 | ) 29 | 30 | col.create_collection() 31 | 32 | return col 33 | 34 | 35 | def list_collections() -> list[str]: 36 | # Create an engine 37 | engine = create_engine(PGVECTOR_URL) 38 | 39 | # Reflect the specific table 40 | metadata = MetaData() 41 | table = Table(TABLE, metadata, autoload_with=engine) 42 | 43 | # Query the column 44 | query = select(table.c[COL_NAME]) 45 | with engine.connect() as connection: 46 | results = connection.execute(query).fetchall() 47 | 48 | return [row[0] for row in results]  # unpack the row tuples into plain collection names 49 | -------------------------------------------------------------------------------- 
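A minimal usage sketch of the pgvector helpers above; it assumes a reachable Postgres instance with the pgvector extension behind PGVECTOR_URL, and the embedding model and collection name are illustrative only, not part of the repository:

# Hypothetical usage of the pgvector helpers; model and collection name are assumptions.
from langchain.embeddings import OpenAIEmbeddings
from ragflow.commons.vectorstore.pgvector_utils import (
    create_collection,
    delete_collection,
    list_collections,
)

embedding = OpenAIEmbeddings()
col = create_collection("demo_collection", embedding)  # (re)creates the collection
col.add_texts(["pgvector stores embeddings inside Postgres."])
print(list_collections())  # plain names of all stored collections
delete_collection("demo_collection")  # clean up again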
/ragflow/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from ragflow.evaluation.hp_evaluator import arun_evaluation 2 | -------------------------------------------------------------------------------- /ragflow/evaluation/hp_evaluator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import asyncio 3 | from tqdm.asyncio import tqdm as tqdm_asyncio 4 | import glob 5 | import os 6 | 7 | from datetime import datetime 8 | 9 | from ragflow.utils import get_retriever, get_qa_llm, write_json, read_json 10 | from ragflow.utils.doc_processing import aload_and_chunk_docs 11 | from ragflow.evaluation.utils import process_retrieved_docs, write_generated_data_to_csv 12 | from ragflow.commons.configurations import Hyperparameters 13 | from ragflow.commons.chroma import ChromaClient 14 | 15 | from ragflow.evaluation.metrics import ( 16 | answer_embedding_similarity, 17 | predicted_answer_accuracy, 18 | retriever_mrr_accuracy, 19 | retriever_semantic_accuracy, 20 | rouge_score, 21 | ) 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | async def arun_eval_for_hp( 27 | label_dataset: list[dict[str, str]], 28 | hp: Hyperparameters, 29 | document_store: list[str], 30 | user_id: str, 31 | ) -> dict: 32 | """Entry point for initiating the evaluation based on the provided hyperparameters and documents. 33 | 34 | Args: 35 | label_dataset (list[dict[str, str]]): _description_ 36 | hp (BaseConfigurations): _description_ 37 | docs_path (str): _description_ 38 | 39 | Returns: 40 | dict: _description_ 41 | """ 42 | scores = { 43 | "answer_similarity_score": -1, 44 | "retriever_mrr@3": -1, 45 | "retriever_mrr@5": -1, 46 | "retriever_mrr@10": -1, 47 | "rouge1": -1, 48 | "rouge2": -1, 49 | "rougeLCS": -1, 50 | # metrics below use LLM for grading 51 | "correctness_score": -1, 52 | "comprehensiveness_score": -1, 53 | "readability_score": -1, 54 | "retriever_semantic_accuracy": -1, 55 | } 56 | 57 | # create chunks of all provided documents 58 | chunks = await aload_and_chunk_docs(hp, document_store) 59 | 60 | # get retriever from chunks 61 | retriever, retrieverForGrading = get_retriever(chunks, hp, user_id, for_eval=True) 62 | 63 | # chunks are no longer needed 64 | del chunks 65 | 66 | # llm for answering queries 67 | qa_llm = get_qa_llm(retriever, hp.qa_llm) 68 | 69 | # list(dict[question, result, source_documents]) 70 | predicted_answers = await asyncio.gather( 71 | *[qa_llm.acall(qa_pair) for qa_pair in label_dataset] 72 | ) 73 | 74 | # list of retrieved docs for each qa_pair 75 | process_retrieved_docs(predicted_answers, hp.id) 76 | 77 | # Calculate embedding similarities of label and predicted answers 78 | scores[ 79 | "answer_similarity_score" 80 | ] = answer_embedding_similarity.grade_embedding_similarity( 81 | label_dataset, predicted_answers, hp.embedding_model, user_id 82 | ) 83 | 84 | ( 85 | scores["retriever_mrr@3"], 86 | scores["retriever_mrr@5"], 87 | scores["retriever_mrr@10"], 88 | ) = await retriever_mrr_accuracy.grade_retriever(label_dataset, retrieverForGrading) 89 | 90 | # Calculate ROUGE scores 91 | scores["rouge1"], scores["rouge2"], scores["rougeLCS"] = rouge_score.grade_rouge( 92 | label_dataset, predicted_answers 93 | ) 94 | 95 | # if we want a llm to grade the predicted answers as well 96 | if hp.use_llm_grader: 97 | # grade predicted answers 98 | ( 99 | scores["correctness_score"], 100 | scores["comprehensiveness_score"], 101 | scores["readability_score"], 102 | ) = 
predicted_answer_accuracy.grade_predicted_answer( 103 | label_dataset, 104 | predicted_answers, 105 | hp.grader_llm, 106 | hp.grade_answer_prompt, 107 | ) 108 | 109 | # grade quality of retrieved documents used to answer the questions 110 | scores[ 111 | "retriever_semantic_accuracy" 112 | ] = retriever_semantic_accuracy.grade_retriever( 113 | label_dataset, 114 | predicted_answers, 115 | hp.grader_llm, 116 | hp.grade_docs_prompt, 117 | ) 118 | 119 | # preprocess dict before writing to json and add additional information like a timestamp 120 | result_dict = hp.to_dict() 121 | result_dict |= {"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]} 122 | result_dict |= {"scores": scores} 123 | 124 | # if in test mode store additional information about retriever and vectorstore objects 125 | if os.environ.get("EXECUTION_CONTEXT") == "TEST": 126 | result_dict |= {"retriever_params": retriever.dict()} 127 | result_dict |= {"vectorstore_params": retriever.vectorstore._collection.dict()} 128 | 129 | return result_dict, predicted_answers 130 | 131 | 132 | def prepare_evaluation_run( 133 | hyperparameters_path: str, user_id: str, api_keys: dict 134 | ) -> list[Hyperparameters]: 135 | """Preparation steps before running evaluation. Primarily we have to increment the hyperparameter ids correctly.""" 136 | 137 | def extract_hpid(s): 138 | import re 139 | 140 | match = re.search(r"hpid_(\d+)", s) 141 | return int(match.group(1)) if match else 0 142 | 143 | hyperparams_list = read_json(hyperparameters_path) 144 | 145 | # find all previous hp runs with corresponding ids 146 | with ChromaClient() as client: 147 | collections = client.list_collections() 148 | 149 | hp_ids = [ 150 | extract_hpid(col.name) 151 | for col in collections 152 | if col.name.startswith(f"userid_{user_id}_") and "_hpid_" in col.name 153 | ] 154 | 155 | if not hp_ids or all(id == -1 for id in hp_ids): 156 | next_id = 0 157 | else: 158 | next_id = max(hp_ids) + 1 159 | 160 | hp_list = [ 161 | Hyperparameters.from_dict(config, id, api_keys) 162 | for id, config in enumerate(hyperparams_list, start=next_id) 163 | ] 164 | 165 | return hp_list 166 | 167 | 168 | async def arun_evaluation( 169 | document_store_path: str, 170 | hyperparameters_path: str, 171 | label_dataset_path: str, 172 | hyperparameters_results_path: str, 173 | hyperparameters_results_data_path: str, 174 | user_id: str, 175 | api_keys: dict, 176 | ) -> None: 177 | """Entry point for the evaluation of a set of hyperparameters to benchmark the corresponding RAG model. 
178 | 179 | Args: 180 | document_store_path (str): Path to the documents 181 | hyperparameters_path (str): Path to the hyperparameters 182 | label_dataset_path (str): Path to the evaluation ground truth dataset 183 | hyperparameters_results_path (str): Path to the JSON file where results should get stored 184 | hyperparameters_results_data_path (str): Path to the CSV file where additional evaluation data should get stored 185 | user_id (str): The user id 186 | api_keys (dict): All required API keys for the LLM endpoints 187 | """ 188 | # load evaluation dataset 189 | label_dataset = read_json(label_dataset_path) 190 | 191 | # preprocess hyperparameter configs from json file 192 | hp_list = prepare_evaluation_run(hyperparameters_path, user_id, api_keys) 193 | 194 | document_store = glob.glob(f"{document_store_path}/*") 195 | 196 | tasks = [ 197 | arun_eval_for_hp(label_dataset, hp, document_store, user_id) for hp in hp_list 198 | ] 199 | 200 | # run evaluations for all hyperparams 201 | results = await tqdm_asyncio.gather(*tasks, total=len(tasks)) 202 | 203 | eval_scores, predicted_answers = list(zip(*results)) 204 | 205 | # delete processed hyperparameters from input json 206 | write_json( 207 | data=[], 208 | filename=hyperparameters_path, 209 | append=True if os.environ.get("EXECUTION_CONTEXT", "") == "TEST" else False, 210 | ) 211 | 212 | # write eval metrics to json 213 | write_json(eval_scores, hyperparameters_results_path) 214 | 215 | # write predicted answers and retrieved docs to csv 216 | write_generated_data_to_csv(predicted_answers, hyperparameters_results_data_path) 217 | -------------------------------------------------------------------------------- /ragflow/evaluation/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from ragflow.evaluation.metrics import ( 2 | answer_embedding_similarity, 3 | predicted_answer_accuracy, 4 | retriever_mrr_accuracy, 5 | retriever_semantic_accuracy, 6 | rouge_score, 7 | ) 8 | -------------------------------------------------------------------------------- /ragflow/evaluation/metrics/answer_embedding_similarity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from langchain.schema.embeddings import Embeddings 4 | from ragflow.commons.chroma import ChromaClient 5 | from ragflow.commons.configurations import BaseConfigurations 6 | 7 | import logging 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def grade_embedding_similarity( 13 | label_dataset: list[dict], 14 | predictions: list[dict], 15 | embedding_model: Embeddings, 16 | user_id: str, 17 | ) -> float: 18 | """Calculate similarities of label answers and generated answers using the corresponding embeddings. We multiply the matrices of the provided embeddings, take the diagonal values and divide them by the product of the embedding norms, which yields the cosine similarities that we then average. 19 | 20 | Args: 21 | label_dataset (list[dict]): The evaluation ground truth dataset of QA pairs 22 | predictions (list[dict]): A list of dicts containing the predicted answers from the queries and the retrieved document chunks 23 | embedding_model (Embeddings): The embedding model 24 | user_id (str): The user id, used to load the cached embeddings of the label 25 | answers from the evaluation dataset from ChromaDB. 26 | 27 | Returns: 28 | float: The average similarity score. 
29 | """ 30 | logger.info("Calculating embedding similarities.") 31 | 32 | num_qa_pairs = len(label_dataset) 33 | 34 | label_answers = [qa_pair["answer"] for qa_pair in label_dataset] 35 | predicted_answers = [qa_pair["result"] for qa_pair in predictions] 36 | 37 | # try to reuse the embeddings of the label answers that were stored in ChromaDB during the generation process, otherwise calculate them again 38 | try: 39 | with ChromaClient() as CHROMA_CLIENT: 40 | collection_name = f"userid_{user_id}_qaid_0_{BaseConfigurations.get_embedding_model_name(embedding_model)}" 41 | 42 | for col in CHROMA_CLIENT.list_collections(): 43 | if col.name == collection_name: 44 | collection = col 45 | break 46 | 47 | ids = [qa["metadata"]["id"] for qa in label_dataset] 48 | 49 | target_embeddings = np.array( 50 | collection.get(ids=ids, include=["embeddings"])["embeddings"] 51 | ).reshape(num_qa_pairs, -1) 52 | 53 | logger.info("Embeddings for label answers loaded successfully.") 54 | 55 | except Exception as ex: 56 | logger.info( 57 | f"Embeddings of {BaseConfigurations.get_embedding_model_name(embedding_model)} for label answers could not be loaded from vectorstore.\n\ 58 | Collections: {CHROMA_CLIENT.list_collections()}.\n\ 59 | Exception: {ex.args}" 60 | ) 61 | 62 | target_embeddings = np.array( 63 | embedding_model.embed_documents(label_answers) 64 | ).reshape(num_qa_pairs, -1) 65 | 66 | predicted_embeddings = np.array( 67 | embedding_model.embed_documents(predicted_answers) 68 | ).reshape(num_qa_pairs, -1) 69 | 70 | emb_norms = np.linalg.norm(target_embeddings, axis=1) * np.linalg.norm( 71 | predicted_embeddings, axis=1 72 | ) 73 | 74 | dot_prod = np.diag(np.dot(target_embeddings, predicted_embeddings.T)) 75 | 76 | similarities = dot_prod / emb_norms 77 | 78 | return np.nanmean(similarities) 79 | -------------------------------------------------------------------------------- /ragflow/evaluation/metrics/predicted_answer_accuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from langchain.evaluation.qa import QAEvalChain 4 | from langchain.schema.language_model import BaseLanguageModel 5 | 6 | from ragflow.evaluation.utils import extract_llm_metric 7 | from ragflow.commons.configurations import CVGradeAnswerPrompt 8 | from ragflow.commons.prompts import ( 9 | GRADE_ANSWER_PROMPT_5_CATEGORIES_5_GRADES_ZERO_SHOT, 10 | GRADE_ANSWER_PROMPT_3_CATEGORIES_4_GRADES_FEW_SHOT, 11 | ) 12 | 13 | import logging 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def grade_predicted_answer( 19 | label_dataset: list[dict], 20 | predicted_answers: list[dict], 21 | grader_llm: BaseLanguageModel, 22 | grade_answer_prompt: CVGradeAnswerPrompt, 23 | ) -> tuple[float, float, float]: 24 | """Using a QAEvalChain and an LLM for grading, we calculate a grade for the predicted answers from the query and the retrieved document chunks. We provide zero-shot and few-shot prompts and get the LLM to provide a score for the correctness, comprehensiveness and readability of the predicted answers. 
25 | 26 | Args: 27 | label_dataset (list[dict]): The evaluation ground truth dataset of QA pairs 28 | predicted_answers (list[dict]): The predicted answers from the QA LLM 29 | grader_llm (BaseLanguageModel): The LLM used for grading 30 | grade_answer_prompt (CVGradeAnswerPrompt): The prompt used in the QAEvalChain 31 | 32 | Returns: 33 | tuple[float, float, float]: Returns average scores for correctness, comprehensiveness and readability 34 | """ 35 | 36 | logger.info("Grading generated answers.") 37 | 38 | if grade_answer_prompt == CVGradeAnswerPrompt.ZERO_SHOT: 39 | prompt, MAX_GRADE = GRADE_ANSWER_PROMPT_5_CATEGORIES_5_GRADES_ZERO_SHOT, 5 40 | elif grade_answer_prompt == CVGradeAnswerPrompt.FEW_SHOT: 41 | prompt, MAX_GRADE = GRADE_ANSWER_PROMPT_3_CATEGORIES_4_GRADES_FEW_SHOT, 3 42 | else: 43 | prompt, MAX_GRADE = None, 1 44 | 45 | # Note: a GPT-4 grader (model_name="gpt-4") is advised by OpenAI 46 | eval_chain = QAEvalChain.from_llm(llm=grader_llm, prompt=prompt, verbose=False) 47 | 48 | outputs = eval_chain.evaluate( 49 | label_dataset, 50 | predicted_answers, 51 | question_key="question", 52 | prediction_key="result", 53 | ) 54 | 55 | # vectorize the function for efficiency 56 | v_extract_llm_metric = np.vectorize(extract_llm_metric) 57 | 58 | outputs = np.array([output["results"] for output in outputs]) 59 | 60 | correctness = v_extract_llm_metric(outputs, "CORRECTNESS") 61 | comprehensiveness = v_extract_llm_metric(outputs, "COMPREHENSIVENESS") 62 | readability = v_extract_llm_metric(outputs, "READABILITY") 63 | 64 | return ( 65 | np.nanmean(correctness) / MAX_GRADE, 66 | np.nanmean(comprehensiveness) / MAX_GRADE, 67 | np.nanmean(readability) / MAX_GRADE, 68 | ) 69 | -------------------------------------------------------------------------------- /ragflow/evaluation/metrics/retriever_mrr_accuracy.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import numpy as np 3 | from langchain.schema.vectorstore import VectorStoreRetriever 4 | from langchain.schema.document import Document 5 | 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | async def grade_retriever( 12 | label_dataset: list[dict], retrieverForGrading: VectorStoreRetriever 13 | ) -> tuple[float, float, float]: 14 | """Calculate metrics for the precision of the retriever using MRR (Mean Reciprocal Rank) scores. 15 | For every list of retrieved documents we calculate the rank of the chunks with respect to the reference chunk from 16 | the label dataset. 
17 | 18 | Args: 19 | label_dataset (list[dict]): _description_ 20 | retrieved_docs (list[str]): _description_ 21 | retrieverForGrading (VectorStoreRetriever): _description_ 22 | 23 | Returns: 24 | tuple[float,float,float]: MRR scores for top 3, 5 and 10 25 | """ 26 | 27 | logger.info("Grading retrieved document chunks using MRR.") 28 | 29 | # get predicted answers with top 10 retrieved document chunks 30 | retrieved_chunks = await asyncio.gather( 31 | *[ 32 | retrieverForGrading.aget_relevant_documents(qa_pair["question"]) 33 | for qa_pair in label_dataset 34 | ] 35 | ) 36 | 37 | ref_chunks = [ 38 | Document(page_content=label["metadata"]["context"], metadata=label["metadata"]) 39 | for label in label_dataset 40 | ] 41 | 42 | return calculate_mrr(ref_chunks, retrieved_chunks) 43 | 44 | 45 | def calculate_mrr(ref_chunks, retrieved_chunks): 46 | """Calculates mrr scores.""" 47 | 48 | top3, top5, top10 = [], [], [] 49 | 50 | # Check to ensure ref_chunks is not empty 51 | if not ref_chunks: 52 | return 0, 0, 0 53 | 54 | for ref_chunk, retr_chunks in zip(ref_chunks, retrieved_chunks): 55 | hit = False 56 | for idx, chunk in enumerate(retr_chunks, 1): 57 | if is_hit(ref_chunk, chunk): 58 | rank = 1 / idx 59 | if idx <= 3: 60 | top3.append(rank) 61 | if idx <= 5: 62 | top5.append(rank) 63 | if idx <= 10: 64 | top10.append(rank) 65 | hit = True 66 | break 67 | 68 | if not hit: # If there's no hit, the rank contribution is 0 69 | top3.append(0) 70 | top5.append(0) 71 | top10.append(0) 72 | 73 | # Calculate the MRR for the top 3, 5, and 10 documents 74 | mrr_3 = sum(top3) / len(ref_chunks) 75 | mrr_5 = sum(top5) / len(ref_chunks) 76 | mrr_10 = sum(top10) / len(ref_chunks) 77 | 78 | return mrr_3, mrr_5, mrr_10 79 | 80 | 81 | def is_hit(ref_chunk, retrieved_chunk): 82 | """Checks if retrieved chunk is close to reference chunk.""" 83 | 84 | # Check if both chunks are from same document 85 | if ref_chunk.metadata["source"] != retrieved_chunk.metadata["source"]: 86 | return False 87 | 88 | label_start, label_end = ( 89 | ref_chunk.metadata["start_index"], 90 | ref_chunk.metadata["end_index"], 91 | ) 92 | retr_start, retr_end = ( 93 | retrieved_chunk.metadata["start_index"], 94 | retrieved_chunk.metadata["end_index"], 95 | ) 96 | 97 | retr_center = retr_start + (retr_end - retr_start) // 2 98 | 99 | # Consider retrieved chunk to be a hit if it is near the reference chunk 100 | return retr_center in range(label_start, label_end + 1) 101 | -------------------------------------------------------------------------------- /ragflow/evaluation/metrics/retriever_semantic_accuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from langchain.evaluation.qa import QAEvalChain 4 | from langchain.schema.language_model import BaseLanguageModel 5 | 6 | from ragflow.commons.prompts import GRADE_RETRIEVER_PROMPT 7 | from ragflow.evaluation.utils import extract_llm_metric 8 | from ragflow.commons.configurations import CVGradeRetrieverPrompt 9 | 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def grade_retriever( 16 | label_dataset: list[dict], 17 | retrieved_docs: list[str], 18 | grader_llm: BaseLanguageModel, 19 | grade_docs_prompt: CVGradeRetrieverPrompt, 20 | ) -> float: 21 | """Using LangChains QAEvalChain we use a LLM as grader to get a metric how well the found document chunks from the vectorstore provide the answer for the label QA pairs. 
Per QA pair the LLM has to rate the retrieved document chunks with 0 or 1 depending on whether the question can be answered with the provided document chunks. 22 | 23 | Args: 24 | label_dataset (list[dict]): The evaluation ground truth dataset of QA pairs 25 | retrieved_docs (list[str]): The retrieved document chunks of each QA pair 26 | grader_llm (BaseLanguageModel): The LLM grading the document chunks 27 | grade_docs_prompt (CVGradeRetrieverPrompt): The type of prompt used in the QAEvalChain 28 | 29 | Returns: 30 | float: The average score of all QA pairs 31 | """ 32 | 33 | logger.info("Grading retrieved document chunks.") 34 | 35 | # TODO: Provide more Prompts and more detailed prompt engineering 36 | if grade_docs_prompt == CVGradeRetrieverPrompt.DEFAULT: 37 | prompt = GRADE_RETRIEVER_PROMPT 38 | else: 39 | prompt = GRADE_RETRIEVER_PROMPT 40 | 41 | # Note: a GPT-4 grader is advised by OpenAI 42 | eval_chain = QAEvalChain.from_llm(llm=grader_llm, prompt=prompt) 43 | 44 | outputs = eval_chain.evaluate( 45 | label_dataset, 46 | retrieved_docs, 47 | question_key="question", 48 | answer_key="answer", 49 | prediction_key="retrieved_docs", 50 | ) 51 | 52 | # vectorize the function for efficiency 53 | v_extract_llm_metric = np.vectorize(extract_llm_metric) 54 | 55 | outputs = np.array([output["results"] for output in outputs]) 56 | retrieval_accuracy = v_extract_llm_metric(outputs, "GRADE") 57 | 58 | return np.nanmean(retrieval_accuracy) 59 | -------------------------------------------------------------------------------- /ragflow/evaluation/metrics/rouge_score.py: -------------------------------------------------------------------------------- 1 | import evaluate 2 | 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | def grade_rouge( 9 | label_dataset: list[dict], predicted_answers: list[dict] 10 | ) -> tuple[float, float, float]: 11 | """Calculates rouge1, rouge2 and rougeL scores of the label and generated answers. 12 | 13 | Args: 14 | label_dataset (list[dict]): The evaluation ground truth dataset of QA pairs 15 | predicted_answers (list[dict]): The predicted answers 16 | 17 | Returns: 18 | tuple[float, float, float]: The rouge1, rouge2 and rougeL average scores 19 | """ 20 | logger.info("Calculating ROUGE scores.") 21 | 22 | label_answers = [qa_pair["answer"] for qa_pair in label_dataset] 23 | predicted_answers = [qa_pair["result"] for qa_pair in predicted_answers] 24 | 25 | rouge_score = evaluate.load("rouge") 26 | score = rouge_score.compute(references=label_answers, predictions=predicted_answers) 27 | 28 | return score["rouge1"], score["rouge2"], score["rougeL"] 29 | -------------------------------------------------------------------------------- /ragflow/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | from langchain.schema.retriever import BaseRetriever 2 | from langchain.docstore.document import Document 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import re 7 | import os 8 | 9 | 10 | def extract_llm_metric(text: str, metric: str) -> float: 11 | """Utility function for extracting scores from the LLM output produced when grading generated answers and retrieved document chunks. 
12 | 13 | Args: 14 | text (str): LLM result 15 | metric (str): name of metric 16 | 17 | Returns: 18 | float: the found score as an integer, or np.nan if the metric was not found 19 | """ 20 | match = re.search(rf"{metric}: (\d+)", text) 21 | if match: 22 | return int(match.group(1)) 23 | return np.nan 24 | 25 | 26 | async def aget_retrieved_documents( 27 | qa_pair: dict[str, str], retriever: BaseRetriever 28 | ) -> dict: 29 | """Retrieve the most similar documents to the query asynchronously, postprocess the document chunks and return the qa pair with the concatenated retrieved documents as result. 30 | 31 | Args: 32 | qa_pair (dict[str, str]): the question-answer pair 33 | retriever (BaseRetriever): the retriever to query 34 | 35 | Returns: 36 | dict: the qa pair with the retrieved documents string as result 37 | """ 38 | query = qa_pair["question"] 39 | docs_retrieved = await retriever.aget_relevant_documents(query) 40 | 41 | retrieved_doc_text = "\n\n".join( 42 | f"Retrieved document {i}: {doc.page_content}" 43 | for i, doc in enumerate(docs_retrieved) 44 | ) 45 | retrieved_dict = { 46 | "question": qa_pair["question"], 47 | "answer": qa_pair["answer"], 48 | "result": retrieved_doc_text, 49 | } 50 | 51 | return retrieved_dict 52 | 53 | 54 | def process_retrieved_docs(qa_results: list[dict], hp_id: int) -> None: 55 | """For each list of Documents retrieved in the QA LLM call, we merge the documents' page contents into a string. We store the id of the hyperparameters for later use. 56 | 57 | Args: 58 | qa_results (list[dict]): the QA results containing the retrieved source documents 59 | hp_id (int): the id of the hyperparameter run 60 | """ 61 | 62 | # Each qa_result is a dict containing in "source_documents" the list of retrieved documents used to answer the query 63 | for qa_result in qa_results: 64 | retrieved_docs_str = "\n\n---------------------------\n".join( 65 | f"///Retrieved document {i}: {clean_page_content(doc.page_content)}\n---Metadata: {doc.metadata}" 66 | for i, doc in enumerate(qa_result["source_documents"]) 67 | ) 68 | 69 | # key now stores the concatenated string of the page contents 70 | qa_result["retrieved_docs"] = retrieved_docs_str 71 | 72 | # id of hyperparameters needed later 73 | qa_result["hp_id"] = hp_id 74 | 75 | 76 | def convert_element_to_df(element: dict): 77 | """Function to convert each tuple element of result to dataframe.""" 78 | df = pd.DataFrame(element) 79 | df = pd.concat([df, pd.json_normalize(df["metadata"])], axis=1) 80 | df = df.drop(columns=["metadata", "question", "answer", "source"]) 81 | df = df.rename(columns={"result": "predicted_answer", "id": "qa_id"}) 82 | df["hp_id"] = df["hp_id"].astype(int) 83 | 84 | return df 85 | 86 | 87 | def clean_page_content(page_content: str): 88 | """Helper function to remove paragraph breaks.""" 89 | return re.sub("\n+", "\n", page_content) 90 | 91 | 92 | def write_generated_data_to_csv( 93 | predicted_answers: list[dict], 94 | hyperparameters_results_data_path: str, 95 | ) -> None: 96 | """Write predicted answers and retrieved documents to csv.""" 97 | # Create the dataframe 98 | df = pd.concat( 99 | [convert_element_to_df(element) for element in predicted_answers], axis=0 100 | ) 101 | 102 | if os.path.exists(hyperparameters_results_data_path): 103 | base_df = pd.read_csv(hyperparameters_results_data_path) 104 | else: 105 | base_df = pd.DataFrame() 106 | 107 | # Concatenate the DataFrames along rows while preserving shared columns 108 | result = pd.concat( 109 | [ 110 | base_df, 111 | df[["hp_id", "predicted_answer", "retrieved_docs", "qa_id"]], 112 | ], 113 | axis=0, 114 | ) 115 | 116 | # Write df to csv 117 | result.sort_values(by=["hp_id"]).to_csv( 
hyperparameters_results_data_path, index=False 119 | ) 120 | -------------------------------------------------------------------------------- /ragflow/example.py: -------------------------------------------------------------------------------- 1 | import dotenv 2 | import logging 3 | import asyncio 4 | 5 | from ragflow.evaluation import arun_evaluation 6 | from ragflow.generation import agenerate_evaluation_set 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | dotenv.load_dotenv(dotenv.find_dotenv(), override=True) 11 | 12 | 13 | async def main( 14 | document_store_path: str = "./resources/dev/document_store/", 15 | label_dataset_path: str = "./resources/dev/label_dataset.json", 16 | label_dataset_gen_params_path: str = "./resources/dev/label_dataset_gen_params.json", 17 | hyperparameters_path: str = "./resources/dev/hyperparameters.json", 18 | hyperparameters_results_path: str = "./resources/dev/hyperparameters_results.json", 19 | hyperparameters_results_data_path: str = "./resources/dev/hyperparameters_results_data.csv", 20 | user_id: str = "336e131c-d9f3-4185-a402-f5f8875e34f0", 21 | ): 22 | """After generating the label_dataset we need to calculate the metrics based on the chunking strategy, type of vectorstore, retriever (similarity search) and QA LLM. 23 | 24 | Dependencies of parameters and build order: 25 | generating chunks | depends on chunking strategy 26 | |__ vector database | depends on type of DB 27 | |__ retriever | depends on Embedding model 28 | |__ QA LLM | depends on LLM and consumes the retriever 29 | 30 | |__ grader: LLM that grades the generated answers 31 | 32 | TODO 1: State of the art retrieval with lots of additional LLM calls: 33 | - use "MMR" to filter out similar document chunks during similarity search 34 | - Use SelfQueryRetriever to find information in the query about a specific context, e.g. querying a specific document only 35 | - Use ContextualCompressionRetriever to compress or summarize the document chunks before serving them to the QA-LLM 36 | 37 | - the retriever object would look something like this; additional metadata info for the SelfQueryRetriever has to be provided as well 38 | retriever = ContextualCompressionRetriever( 39 | base_compressor=LLMChainExtractor.from_llm(llm), 40 | base_retriever=SelfQueryRetriever.from_llm(..., search_type="mmr") 41 | ) 42 | 43 | https://learn.deeplearning.ai/langchain-chat-with-your-data/lesson/5/retrieval 44 | https://learn.deeplearning.ai/langchain-chat-with-your-data/lesson/7/chat 45 | 46 | TODO 2: Integrate Anyscale (Llama2), MosaicML (MPT + EMBDS), Replicate 47 | 48 | Args: 49 | document_store_path (str, optional): Path to the documents. Defaults to "./resources/dev/document_store/". 50 | label_dataset_path (str, optional): Path to the evaluation ground truth dataset. Defaults to "./resources/dev/label_dataset.json". 51 | label_dataset_gen_params_path (str, optional): Path to the label dataset generation parameters. Defaults to "./resources/dev/label_dataset_gen_params.json". 52 | hyperparameters_path (str, optional): Path to the hyperparameters. Defaults to "./resources/dev/hyperparameters.json". 53 | hyperparameters_results_path (str, optional): Path to the JSON file where results get stored. Defaults to "./resources/dev/hyperparameters_results.json". 54 | hyperparameters_results_data_path (str, optional): Path to the CSV file where additional evaluation data gets stored. Defaults to "./resources/dev/hyperparameters_results_data.csv". 55 | user_id (str, optional): The user id. Defaults to a development UUID. 
56 | """ 57 | 58 | # provide API keys here if necessary 59 | import os 60 | 61 | api_keys = {"OPENAI_API_KEY": os.environ.get("OPENAI_API_KEY")} 62 | 63 | ################################################################ 64 | # First phase: Loading or generating evaluation dataset 65 | ################################################################ 66 | 67 | logger.info("Checking for evaluation dataset configs.") 68 | 69 | await agenerate_evaluation_set( 70 | label_dataset_gen_params_path, 71 | label_dataset_path, 72 | document_store_path, 73 | user_id, 74 | api_keys, 75 | ) 76 | 77 | ################################################################ 78 | # Second phase: Running evaluations 79 | ################################################################ 80 | 81 | logger.info("Starting evaluation for all provided hyperparameters.") 82 | 83 | await arun_evaluation( 84 | document_store_path, 85 | hyperparameters_path, 86 | label_dataset_path, 87 | hyperparameters_results_path, 88 | hyperparameters_results_data_path, 89 | user_id, 90 | api_keys, 91 | ) 92 | 93 | 94 | if __name__ == "__main__": 95 | asyncio.run(main()) 96 | -------------------------------------------------------------------------------- /ragflow/generation/__init__.py: -------------------------------------------------------------------------------- 1 | from ragflow.generation.label_dataset_generator import agenerate_evaluation_set 2 | -------------------------------------------------------------------------------- /ragflow/generation/label_dataset_generator.py: -------------------------------------------------------------------------------- 1 | from json import JSONDecodeError 2 | import itertools 3 | import asyncio 4 | import glob 5 | import os 6 | from datetime import datetime 7 | from tqdm.asyncio import tqdm as tqdm_asyncio 8 | 9 | from langchain.chains import QAGenerationChain 10 | from langchain.docstore.document import Document 11 | from langchain.schema.embeddings import Embeddings 12 | 13 | from ragflow.commons.prompts import QA_GENERATION_PROMPT_SELECTOR 14 | from ragflow.utils import aload_and_chunk_docs, write_json, read_json 15 | from ragflow.commons.configurations import QAConfigurations 16 | from ragflow.commons.chroma import ChromaClient 17 | 18 | import uuid 19 | import logging 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | async def get_qa_from_chunk( 25 | chunk: Document, 26 | qa_generator_chain: QAGenerationChain, 27 | ) -> list[dict]: 28 | """Generate QA from provided text document chunk.""" 29 | try: 30 | # return list of qa pairs 31 | qa_pairs = qa_generator_chain.run(chunk.page_content) 32 | 33 | # attach chunk metadata to qa_pair 34 | for qa_pair in qa_pairs: 35 | qa_pair["metadata"] = dict(**chunk.metadata) 36 | qa_pair["metadata"].update( 37 | {"id": str(uuid.uuid4()), "context": chunk.page_content} 38 | ) 39 | 40 | return qa_pairs 41 | except JSONDecodeError: 42 | return [] 43 | 44 | 45 | async def agenerate_label_dataset_from_doc( 46 | hp: QAConfigurations, 47 | doc_path: str, 48 | ) -> list[dict[str, str]]: 49 | """Generate a pairs of QAs that are used as ground truth in downstream tasks, i.e. 
RAG evaluations 50 | 51 | Args: 52 | llm (BaseLanguageModel): the language model used in the QAGenerationChain 53 | chunks (List[Document]): the document chunks used for QA generation 54 | 55 | Returns: 56 | List[Dict[str, str]]: returns a list of dictionary of question - answer pairs 57 | """ 58 | 59 | logger.debug(f"Starting QA generation process for {doc_path}.") 60 | 61 | # load data and chunk doc 62 | chunks = await aload_and_chunk_docs(hp, [doc_path]) 63 | 64 | qa_generator_chain = QAGenerationChain.from_llm( 65 | hp.qa_generator_llm, 66 | prompt=QA_GENERATION_PROMPT_SELECTOR.get_prompt(hp.qa_generator_llm), 67 | ) 68 | 69 | tasks = [get_qa_from_chunk(chunk, qa_generator_chain) for chunk in chunks] 70 | 71 | qa_pairs = await asyncio.gather(*tasks) 72 | qa_pairs = list(itertools.chain.from_iterable(qa_pairs)) 73 | 74 | return qa_pairs 75 | 76 | 77 | async def agenerate_label_dataset_from_docs( 78 | hp: QAConfigurations, 79 | docs_path: list[str], 80 | ) -> list[dict]: 81 | """Asynchronous wrapper around the agenerate_label_dataset function. 82 | 83 | Args: 84 | qa_gen_configs (dict): _description_ 85 | docs_path (list[str]): _description_ 86 | 87 | Returns: 88 | list[dict]: _description_ 89 | """ 90 | tasks = [agenerate_label_dataset_from_doc(hp, doc_path) for doc_path in docs_path] 91 | 92 | results = await tqdm_asyncio.gather(*tasks) 93 | 94 | qa_pairs = list(itertools.chain.from_iterable(results)) 95 | 96 | return qa_pairs 97 | 98 | 99 | async def aupsert_embeddings_for_model( 100 | qa_pairs: list[dict], 101 | embedding_model: Embeddings, 102 | user_id: str, 103 | ) -> None: 104 | """Embeds and upserts each generated answer into vectorstore. This is helpful if you want to run different hyperparameter runs with the same embedding model because you only have to embed these answers once. The embeddings are used during evaluation to check similarity of generated and predicted answers.""" 105 | with ChromaClient() as CHROMA_CLIENT: 106 | collection_name = f"userid_{user_id}_qaid_0_{QAConfigurations.get_embedding_model_name(embedding_model)}" 107 | 108 | # check if collection already exists, if not create a new one with the embeddings 109 | if [ 110 | collection 111 | for collection in CHROMA_CLIENT.list_collections() 112 | if collection.name.startswith(f"userid_{user_id}_") 113 | ]: 114 | logger.info(f"Collection {collection_name} already exists, skipping it.") 115 | return None 116 | 117 | collection = CHROMA_CLIENT.create_collection( 118 | name=collection_name, 119 | metadata={ 120 | "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3], 121 | }, 122 | ) 123 | 124 | ids = [qa_pair["metadata"]["id"] for qa_pair in qa_pairs] 125 | 126 | # maybe async function is not implemented for embedding model 127 | try: 128 | embeddings = await embedding_model.aembed_documents( 129 | [qa_pair["answer"] for qa_pair in qa_pairs] 130 | ) 131 | except NotImplementedError as ex: 132 | logger.error( 133 | f"Exception during eval set generation and upserting to vectorstore, {ex}" 134 | ) 135 | embeddings = embedding_model.embed_documents( 136 | [qa_pair["answer"] for qa_pair in qa_pairs] 137 | ) 138 | 139 | collection.upsert( 140 | ids=ids, 141 | embeddings=embeddings, 142 | metadatas=[ 143 | { 144 | "question": qa_pair["question"], 145 | "answer": qa_pair["answer"], 146 | **qa_pair["metadata"], 147 | } 148 | for qa_pair in qa_pairs 149 | ], 150 | ) 151 | 152 | logger.info( 153 | f"Upserted {QAConfigurations.get_embedding_model_name(embedding_model)} embeddings to vectorstore." 
154 | ) 155 | 156 | 157 | async def agenerate_and_save_dataset( 158 | hp: QAConfigurations, 159 | docs_path: str, 160 | label_dataset_path: str, 161 | user_id: str, 162 | ): 163 | """Generate a new evaluation dataset and save it to a JSON file.""" 164 | 165 | logger.info("Starting QA generation suite.") 166 | 167 | # generate label dataset 168 | label_dataset = await agenerate_label_dataset_from_docs(hp, docs_path) 169 | 170 | # During test execution: if label_dataset is empty because of test dummy LLM, we inject a real dataset for test 171 | if ( 172 | os.environ.get("EXECUTION_CONTEXT") == "TEST" 173 | and hp.persist_to_vs 174 | and not label_dataset 175 | ): 176 | label_dataset = read_json(os.environ.get("INPUT_LABEL_DATASET")) 177 | 178 | # write eval dataset to json 179 | write_json(label_dataset, label_dataset_path) 180 | 181 | # cache answers of qa pairs in vectorstore for each embedding model provided 182 | if hp.persist_to_vs: 183 | tasks = [ 184 | aupsert_embeddings_for_model(label_dataset, embedding_model, user_id) 185 | for embedding_model in hp.embedding_model_list 186 | ] 187 | 188 | await asyncio.gather(*tasks) 189 | 190 | 191 | async def agenerate_evaluation_set( 192 | label_dataset_gen_params_path: str, 193 | label_dataset_path: str, 194 | document_store_path: str, 195 | user_id: str, 196 | api_keys: dict[str, str], 197 | ): 198 | """Entry function to generate the evaluation dataset. 199 | 200 | Args: 201 | label_dataset_gen_params (dict): _description_ 202 | label_dataset_path (str): _description_ 203 | 204 | Returns: 205 | _type_: _description_ 206 | """ 207 | 208 | logger.info("Checking for evaluation dataset configs.") 209 | 210 | import nltk 211 | from nltk.data import find 212 | 213 | # Check if 'punkt' is already downloaded, and download if not 214 | try: 215 | find("tokenizers/punkt") 216 | except LookupError: 217 | nltk.download("punkt") 218 | 219 | label_dataset_gen_params = read_json(label_dataset_gen_params_path) 220 | 221 | # TODO: Only one single QA generation supported per user 222 | if isinstance(label_dataset_gen_params, list): 223 | label_dataset_gen_params = label_dataset_gen_params[-1] 224 | 225 | # set up QAConfiguration object at the beginning to evaluate inputs 226 | label_dataset_gen_params = QAConfigurations.from_dict( 227 | label_dataset_gen_params, api_keys 228 | ) 229 | 230 | # get list of all documents in document_store_path 231 | document_store = glob.glob(f"{document_store_path}/*") 232 | 233 | with ChromaClient() as client: 234 | # Filter collections specific to the user_id. 235 | user_collections = [ 236 | collection 237 | for collection in client.list_collections() 238 | if collection.name.startswith(f"userid_{user_id}_") 239 | ] 240 | 241 | # Check if there are any collections with the QA identifier and delete them. 
242 | if any(["_qaid_0_" in collection.name for collection in user_collections]): 243 | for collection in user_collections: 244 | client.delete_collection(name=collection.name) 245 | 246 | # start generation process 247 | await agenerate_and_save_dataset( 248 | label_dataset_gen_params, document_store, label_dataset_path, user_id 249 | ) 250 | -------------------------------------------------------------------------------- /ragflow/requirements.txt: -------------------------------------------------------------------------------- 1 | # Core dependencies for the AI and ML operations 2 | openai==0.28.1 3 | tiktoken==0.5.1 4 | langchain==0.0.330 5 | Cython==3.0.5 6 | 7 | pandas==2.1.2 8 | numpy==1.26.1 9 | scikit-learn==1.3.2 10 | 11 | # PDF and document processing 12 | simsimd==3.5.3 13 | docx2txt==0.8 14 | pypdf==3.17.0 15 | pdf2image==1.16.3 16 | pdfminer.six==20221105 17 | 18 | # OCR and unstructured data handling 19 | unstructured==0.10.27 20 | 21 | # Evaluation metrics 22 | evaluate==0.4.1 23 | rouge_score==0.1.2 24 | absl-py==2.0.0 25 | 26 | # Web framework and API 27 | fastapi==0.104.1 28 | pydantic[email]==2.4.2 29 | uvicorn==0.24.0.post1 30 | python-dotenv==1.0.0 31 | SQLAlchemy==2.0.23 32 | psycopg2-binary==2.9.9 33 | python-jose==3.3.0 34 | passlib==1.7.4 35 | python-multipart==0.0.6 36 | 37 | # Database connector 38 | chromadb==0.4.15 39 | 40 | # Notebooks and external APIs 41 | lark==1.1.8 # selfqueryretriever 42 | cohere==4.32 # for reranker -------------------------------------------------------------------------------- /ragflow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ragflow.utils.utils import ( 2 | get_retriever, 3 | get_qa_llm, 4 | read_json, 5 | write_json, 6 | ) 7 | 8 | from ragflow.utils.doc_processing import aload_and_chunk_docs, load_and_chunk_doc 9 | -------------------------------------------------------------------------------- /ragflow/utils/doc_processing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | import itertools 4 | 5 | from langchain.docstore.document import Document 6 | from langchain.text_splitter import RecursiveCharacterTextSplitter 7 | 8 | from langchain.document_loaders import ( 9 | TextLoader, 10 | Docx2txtLoader, 11 | # PyPDFLoader, # already splits data during loading 12 | UnstructuredPDFLoader, # returns only 1 Document 13 | # PyMuPDFLoader, # returns 1 Document per page 14 | ) 15 | 16 | from ragflow.commons.configurations import BaseConfigurations 17 | 18 | import logging 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def load_document(file: str) -> list[Document]: 24 | """Loads file from given path into a list of Documents, currently pdf, txt and docx are supported. 
25 | 26 | Args: 27 | file (str): file path 28 | 29 | Returns: 30 | List[Document]: loaded files as list of Documents 31 | """ 32 | _, extension = os.path.splitext(file) 33 | 34 | if extension == ".pdf": 35 | loader = UnstructuredPDFLoader(file) 36 | elif extension == ".txt": 37 | loader = TextLoader(file, encoding="utf-8") 38 | elif extension == ".docx": 39 | loader = Docx2txtLoader(file) 40 | else: 41 | logger.error(f"Unsupported file type detected, {file}!") 42 | raise NotImplementedError(f"Unsupported file type detected, {file}!") 43 | 44 | data = loader.load() 45 | return data 46 | 47 | 48 | def split_data( 49 | data: list[Document], 50 | chunk_size: int = 4096, 51 | chunk_overlap: int = 0, 52 | length_function: callable = len, 53 | ) -> list[Document]: 54 | """Function for splitting the provided data, i.e. List of documents loaded. 55 | 56 | Args: 57 | data (List[Document]): _description_ 58 | chunk_size (Optional[int], optional): _description_. Defaults to 4096. 59 | chunk_overlap (Optional[int], optional): _description_. Defaults to 0. 60 | length_function (_type_, optional): _description_. 61 | 62 | Returns: 63 | List[Document]: the splitted document chunks 64 | """ 65 | 66 | text_splitter = RecursiveCharacterTextSplitter( 67 | chunk_size=chunk_size, 68 | chunk_overlap=chunk_overlap, 69 | length_function=length_function, 70 | add_start_index=True, 71 | ) 72 | 73 | chunks = text_splitter.split_documents(data) 74 | 75 | # add end_index 76 | for chunk in chunks: 77 | chunk.metadata["end_index"] = chunk.metadata["start_index"] + len( 78 | chunk.page_content 79 | ) 80 | 81 | return chunks 82 | 83 | 84 | def load_and_chunk_doc( 85 | hp: BaseConfigurations, 86 | file: str, 87 | ) -> list[Document]: 88 | """Wrapper function for loading and splitting document in one call.""" 89 | 90 | logger.debug(f"Loading and splitting file {file}.") 91 | 92 | data = load_document(file) 93 | chunks = split_data(data, hp.chunk_size, hp.chunk_overlap, hp.length_function) 94 | return chunks 95 | 96 | 97 | # TODO: Check for better async implementation! 
98 | async def aload_and_chunk_docs( 99 | hp: BaseConfigurations, files: list[str] 100 | ) -> list[Document]: 101 | """Async implementation of load_and_chunk_doc function.""" 102 | loop = asyncio.get_event_loop() 103 | 104 | futures = [ 105 | loop.run_in_executor(None, load_and_chunk_doc, hp, file) for file in files 106 | ] 107 | 108 | results = await asyncio.gather(*futures) 109 | chunks = list(itertools.chain.from_iterable(results)) 110 | 111 | return chunks 112 | -------------------------------------------------------------------------------- /ragflow/utils/hyperparam_chats.py: -------------------------------------------------------------------------------- 1 | import json 2 | import asyncio 3 | from uuid import UUID 4 | from langchain.schema.messages import BaseMessage 5 | import pandas as pd 6 | 7 | from langchain.vectorstores.chroma import Chroma 8 | from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler 9 | from langchain.chains.llm import LLMChain 10 | from langchain.chains.question_answering import load_qa_chain 11 | from langchain.chains import ConversationalRetrievalChain 12 | from langchain.retrievers.multi_query import MultiQueryRetriever 13 | from langchain.memory import ConversationBufferWindowMemory 14 | from langchain.schema import LLMResult, Document 15 | from langchain.chat_models.base import BaseChatModel 16 | 17 | from langchain.chains.conversational_retrieval.prompts import ( 18 | CONDENSE_QUESTION_PROMPT, 19 | ) 20 | 21 | from ragflow.commons.chroma import ChromaClient 22 | from ragflow.commons.configurations import Hyperparameters 23 | from ragflow.commons.prompts import QA_ANSWER_PROMPT 24 | 25 | from typing import Any, Dict, List, Optional, Sequence 26 | 27 | import logging 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | chats_cache: Dict[str, Dict[int, ConversationalRetrievalChain]] = {} 32 | 33 | 34 | class AsyncCallbackHandler(AsyncIteratorCallbackHandler): 35 | # content: str = "" 36 | # final_answer: bool = True 37 | 38 | def __init__(self) -> None: 39 | super().__init__() 40 | self.source_documents = None 41 | 42 | async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: 43 | self.queue.put_nowait(token) 44 | 45 | async def on_chat_model_start( 46 | self, 47 | serialized: Dict[str, Any], 48 | messages: List[List[BaseMessage]], 49 | *, 50 | run_id: UUID, 51 | parent_run_id: UUID | None = None, 52 | tags: List[str] | None = None, 53 | metadata: Dict[str, Any] | None = None, 54 | **kwargs: Any, 55 | ) -> Any: 56 | pass 57 | 58 | async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: 59 | # if source documents were extracted 60 | if self.source_documents: 61 | self.queue.put_nowait(self.source_documents) 62 | 63 | self.source_documents = None 64 | self.done.set() 65 | 66 | 67 | class RetrieverCallbackHandler(AsyncIteratorCallbackHandler): 68 | def __init__(self, streaming_callback_handler: AsyncCallbackHandler) -> None: 69 | super().__init__() 70 | self.streaming_callback_handler = streaming_callback_handler 71 | 72 | async def on_retriever_end( 73 | self, source_docs, *, run_id, parent_run_id, tags, **kwargs 74 | ): 75 | source_docs_list = ( 76 | [ 77 | {"page_content": doc.page_content, "metadata": doc.metadata} 78 | for doc in source_docs 79 | ] 80 | if source_docs 81 | else None 82 | ) 83 | 84 | self.streaming_callback_handler.source_documents = json.dumps( 85 | {"source_documents": source_docs_list} 86 | ) 87 | 88 | 89 | async def aquery_chat( 90 | hp_id: int, 91 | hyperparameters_results_path: 
str, 92 | user_id: str, 93 | api_keys: dict, 94 | query: str, 95 | stream_it: AsyncCallbackHandler, 96 | ): 97 | llm = getOrCreateChatModel( 98 | hp_id, hyperparameters_results_path, user_id, api_keys, stream_it 99 | ) 100 | 101 | await llm.acall( 102 | query, 103 | callbacks=[RetrieverCallbackHandler(streaming_callback_handler=stream_it)], 104 | ) 105 | 106 | 107 | async def create_gen( 108 | hp_id: int, 109 | hyperparameters_results_path: str, 110 | user_id: str, 111 | api_keys: dict, 112 | query: str, 113 | stream_it: AsyncCallbackHandler, 114 | ): 115 | task = asyncio.create_task( 116 | aquery_chat( 117 | hp_id, hyperparameters_results_path, user_id, api_keys, query, stream_it 118 | ) 119 | ) 120 | async for token in stream_it.aiter(): 121 | yield token 122 | await task 123 | 124 | 125 | def query_chat( 126 | hp_id: int, 127 | hyperparameters_results_path: str, 128 | user_id: str, 129 | api_keys: dict, 130 | query: str, 131 | ) -> dict: 132 | # get or load llm model 133 | llm = getOrCreateChatModel(hp_id, hyperparameters_results_path, user_id, api_keys) 134 | 135 | return llm(query) 136 | 137 | 138 | async def get_docs( 139 | hp_id: int, 140 | hyperparameters_results_path: str, 141 | user_id: str, 142 | api_keys: dict, 143 | query: str, 144 | ) -> List[Document]: 145 | # get or load llm model 146 | llm = getOrCreateChatModel(hp_id, hyperparameters_results_path, user_id, api_keys) 147 | 148 | return await llm.retriever.retriever.aget_relevant_documents(query) 149 | 150 | 151 | def getOrCreateChatModel( 152 | hp_id: int, 153 | hyperparameters_results_path: str, 154 | user_id: str, 155 | api_keys: dict, 156 | streaming_callback: Optional[AsyncCallbackHandler] = None, 157 | ) -> None: 158 | # if model has not been loaded yet 159 | if ( 160 | user_id not in chats_cache 161 | or hp_id not in chats_cache[user_id] 162 | or not isinstance(chats_cache[user_id][hp_id], ConversationalRetrievalChain) 163 | ): 164 | # load hyperparameter results 165 | with open(hyperparameters_results_path, encoding="utf-8") as file: 166 | hp_data = json.load(file) 167 | 168 | df = pd.DataFrame(hp_data) 169 | 170 | # check that hp_id really exists in results 171 | if hp_id not in df.id.values: 172 | raise NotImplementedError("Could not find requested hyperparameter run id.") 173 | 174 | with ChromaClient() as client: 175 | for col in client.list_collections(): 176 | if col.name == f"userid_{user_id}_hpid_{hp_id}": 177 | collection = col 178 | break 179 | 180 | # check that vectorstore contains collection for hp id 181 | if not collection: 182 | raise NotImplementedError( 183 | "Could not find data in vectorstore for requested hyperparameter run id." 
184 | ) 185 | 186 | # create retriever and llm from collection 187 | hp_data = df[df.id == hp_id].iloc[0].to_dict() 188 | for key in ["id", "scores", "timestamp"]: 189 | hp_data.pop(key) 190 | 191 | hp = Hyperparameters.from_dict( 192 | input_dict=hp_data, 193 | hp_id=hp_id, 194 | api_keys=api_keys, 195 | ) 196 | 197 | index = Chroma( 198 | client=ChromaClient().get_client(), 199 | collection_name=collection.name, 200 | collection_metadata=collection.metadata, 201 | embedding_function=hp.embedding_model, 202 | ) 203 | 204 | # baseline retriever built from vectorstore collection 205 | retriever = index.as_retriever(search_type="similarity", search_kwargs={"k": 1}) 206 | 207 | # streaming and non-streaming models, create new instance for streaming model 208 | streaming_llm = Hyperparameters.get_language_model( 209 | model_name=Hyperparameters.get_language_model_name(hp.qa_llm), 210 | api_keys=api_keys, 211 | ) 212 | 213 | if isinstance(streaming_llm, BaseChatModel): 214 | streaming_llm.streaming = True 215 | streaming_llm.callbacks = [streaming_callback] 216 | 217 | # llm model from hp for non streaming chains 218 | non_streaming_llm = hp.qa_llm 219 | 220 | # LLM chain for generating new question from user query and chat history 221 | question_generator = LLMChain( 222 | llm=non_streaming_llm, prompt=CONDENSE_QUESTION_PROMPT 223 | ) 224 | 225 | # llm that answers the newly generated condensed question 226 | doc_qa_chain = load_qa_chain( 227 | streaming_llm, chain_type="stuff", prompt=QA_ANSWER_PROMPT 228 | ) 229 | 230 | # advanced retriever using multiple similar queries 231 | multi_query_retriever = MultiQueryRetriever.from_llm( 232 | retriever=retriever, llm=non_streaming_llm, include_original=True 233 | ) 234 | 235 | # memory object to store the chat history 236 | memory = ConversationBufferWindowMemory( 237 | memory_key="chat_history", k=5, return_messages=True, output_key="result" 238 | ) 239 | 240 | qa_llm = ConversationalRetrievalChain( 241 | retriever=multi_query_retriever, 242 | combine_docs_chain=doc_qa_chain, 243 | question_generator=question_generator, 244 | memory=memory, 245 | output_key="result", 246 | return_source_documents=True, 247 | return_generated_question=True, 248 | response_if_no_docs_found="I don't know the answer to this question.", 249 | ) 250 | 251 | # cache llm 252 | if user_id not in chats_cache: 253 | chats_cache[user_id] = {} 254 | 255 | chats_cache[user_id][hp_id] = qa_llm 256 | 257 | logger.info( 258 | f"\nCreated new ConversationalRetrievalChain for {user_id}:{hp_id}." 
259 | ) 260 | 261 | # model is now in the cache, whether it was just created or loaded earlier 262 | 263 | logger.info(f"\nUsing cached ConversationalRetrievalChain for {user_id}:{hp_id}.") 264 | 265 | # attach the streaming callback of the current request to the answer llm, if any 266 | qa_llm = chats_cache[user_id][hp_id] 267 | qa_llm.combine_docs_chain.llm_chain.llm.callbacks = [streaming_callback] if streaming_callback else [] 268 | 269 | return qa_llm 270 | -------------------------------------------------------------------------------- /ragflow/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from enum import Enum 4 | from datetime import datetime 5 | 6 | from langchain.chains import RetrievalQA 7 | from langchain.docstore.document import Document 8 | 9 | from langchain.schema.language_model import BaseLanguageModel 10 | from langchain.schema.vectorstore import VectorStoreRetriever 11 | 12 | # vector db 13 | from langchain.vectorstores.chroma import Chroma 14 | 15 | from ragflow.commons.prompts import QA_ANSWER_PROMPT 16 | from ragflow.commons.configurations import Hyperparameters 17 | from ragflow.commons.chroma import ChromaClient 18 | 19 | from typing import Any, Optional, Union 20 | 21 | import json 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def get_retriever( 28 | chunks: list[Document], hp: Hyperparameters, user_id: str, for_eval: bool = False 29 | ) -> Union[tuple[VectorStoreRetriever, VectorStoreRetriever], VectorStoreRetriever]: 30 | """Sets up a vector database based on the document chunks and the embedding model provided. 31 | Here we use Chroma for the vectorstore. 32 | 33 | Args: 34 | chunks (list[Document]): document chunks to embed and index 35 | hp (Hyperparameters): hyperparameter configuration providing the embedding model, search type and number of retrieved documents 36 | user_id (str): user id, used to name the Chroma collection 37 | for_eval (bool): for evaluation we return a second retriever with k=10 38 | 39 | Returns: 40 | tuple[VectorStoreRetriever, VectorStoreRetriever]: the retriever used for predicting answers and, if for_eval is True, 41 | a second retriever for grading the retrieval where we need the top 10 documents returned 42 | """ 43 | logger.info("Constructing vectorstore and retriever.") 44 | 45 | vectorstore = Chroma.from_documents( 46 | documents=chunks, 47 | embedding=hp.embedding_model, 48 | client=ChromaClient().get_client(), 49 | collection_name=f"userid_{user_id}_hpid_{hp.id}", 50 | collection_metadata={ 51 | "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3], 52 | "hnsw:space": hp.similarity_method.value, 53 | }, 54 | ) 55 | 56 | retriever = vectorstore.as_retriever( 57 | search_type=hp.search_type.value, search_kwargs={"k": hp.num_retrieved_docs} 58 | ) 59 | 60 | # when we also evaluate the semantic accuracy of the retriever, return a second retriever with different settings 61 | if not for_eval: 62 | return retriever 63 | 64 | retrieverForGrading = vectorstore.as_retriever( 65 | search_type=hp.search_type.value, search_kwargs={"k": 10} 66 | ) 67 | 68 | return retriever, retrieverForGrading 69 | 70 | 71 | def get_qa_llm( 72 | retriever: VectorStoreRetriever, 73 | qa_llm: BaseLanguageModel, 74 | return_source_documents: Optional[bool] = True, 75 | ) -> RetrievalQA: 76 | """Sets up a LangChain RetrievalQA model based on a retriever and language model that answers 77 | queries based on retrieved document chunks. 78 | 79 | Args: 80 | retriever (VectorStoreRetriever): retriever that supplies the context documents. qa_llm (BaseLanguageModel): language model that answers the question. return_source_documents (Optional[bool], optional): whether to include the retrieved documents in the output. 
81 | 82 | Returns: 83 | RetrievalQA: RetrievalQA object 84 | """ 85 | logger.debug("Setting up QA LLM with provided retriever.") 86 | 87 | qa_llm_r = RetrievalQA.from_chain_type( 88 | llm=qa_llm, 89 | chain_type="stuff", 90 | retriever=retriever, 91 | chain_type_kwargs={"prompt": QA_ANSWER_PROMPT}, 92 | input_key="question", 93 | return_source_documents=return_source_documents, 94 | ) 95 | 96 | return qa_llm_r 97 | 98 | 99 | def read_json(filename: str) -> Any: 100 | """Load dataset from a JSON file.""" 101 | 102 | with open(filename, "r", encoding="utf-8") as file: 103 | return json.load(file) 104 | 105 | 106 | def write_json(data: list[dict], filename: str, append: Optional[bool] = True) -> None: 107 | """Function used to store generated QA pairs, i.e. the ground truth. 108 | 109 | Args: 110 | data (list[dict]): QA pairs to persist. 111 | filename (str): path of the target JSON file; if append is True and the file exists, its content is extended instead of overwritten. 112 | """ 113 | 114 | logger.info(f"Writing JSON to {filename}.") 115 | 116 | # Check if file exists 117 | if os.path.exists(filename) and append: 118 | # File exists, read the data 119 | with open(filename, "r", encoding="utf-8") as file: 120 | json_data = json.load(file) 121 | # The stored data is expected to be a list, so we simply extend it 122 | json_data.extend(data) 123 | else: 124 | json_data = data 125 | 126 | # Write the combined data back to the file 127 | with open(filename, "w", encoding="utf-8") as file: 128 | json.dump(json_data, file, indent=4, default=convert_to_serializable) 129 | 130 | 131 | # Convert non-serializable types 132 | def convert_to_serializable(obj: object) -> Any: 133 | """Convert objects that json.dump cannot serialize natively into serializable representations. 134 | 135 | Args: 136 | obj (object): object that json.dump could not serialize 137 | 138 | Returns: 139 | object: a JSON-serializable representation of obj (string, list, etc.) 140 | """ 141 | if isinstance(obj, Enum): 142 | return obj.value 143 | elif isinstance(obj, np.ndarray): 144 | return obj.tolist() 145 | elif isinstance(obj, set): 146 | return list(obj) 147 | elif callable(obj): # For functions, lambdas and similar callables 148 | return str(obj) 149 | elif isinstance(obj, type): # For classes 150 | return str(obj) 151 | return f"WARNING: Type {type(obj).__name__} not serializable!" 
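Note: the two helpers above are typically used together, with get_retriever building the Chroma-backed retriever for a hyperparameter run and get_qa_llm wrapping it in a RetrievalQA chain. A minimal usage sketch, assuming a prepared chunks list and a populated Hyperparameters instance hp (the wrapper function and variable names are illustrative, not part of the module):

from ragflow.utils.utils import get_qa_llm, get_retriever

def answer_question(chunks, hp, user_id: str, question: str) -> dict:
    # index the chunks and build the retriever for this user/run
    retriever = get_retriever(chunks=chunks, hp=hp, user_id=user_id)
    # wrap retriever and answer model in a RetrievalQA chain
    qa = get_qa_llm(retriever=retriever, qa_llm=hp.qa_llm)
    # the chain was configured with input_key="question"
    return qa({"question": question})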
152 | -------------------------------------------------------------------------------- /resources/dev/document_store/msg_life-gb-2021-EN_final_1-15.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/resources/dev/document_store/msg_life-gb-2021-EN_final_1-15.pdf -------------------------------------------------------------------------------- /resources/dev/document_store/msg_life-gb-2021-EN_final_16-30.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/resources/dev/document_store/msg_life-gb-2021-EN_final_16-30.pdf -------------------------------------------------------------------------------- /resources/dev/document_store/msg_life-gb-2021-EN_final_31-45.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/resources/dev/document_store/msg_life-gb-2021-EN_final_31-45.pdf -------------------------------------------------------------------------------- /resources/dev/document_store/msg_life-gb-2021-EN_final_46-59.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/resources/dev/document_store/msg_life-gb-2021-EN_final_46-59.pdf -------------------------------------------------------------------------------- /resources/dev/document_store/msg_life-gb-2021-EN_final_60-end.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/resources/dev/document_store/msg_life-gb-2021-EN_final_60-end.pdf -------------------------------------------------------------------------------- /resources/dev/hyperparameters.json: -------------------------------------------------------------------------------- 1 | [] -------------------------------------------------------------------------------- /resources/dev/hyperparameters_results.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "chunk_size": 512, 4 | "chunk_overlap": 10, 5 | "length_function_name": "len", 6 | "id": 0, 7 | "num_retrieved_docs": 3, 8 | "search_type": "similarity", 9 | "similarity_method": "cosine", 10 | "use_llm_grader": false, 11 | "qa_llm": "gpt-3.5-turbo", 12 | "embedding_model": "text-embedding-ada-002", 13 | "timestamp": "2023-11-08 20:14:36,501", 14 | "scores": { 15 | "answer_similarity_score": 0.9455289276479025, 16 | "retriever_mrr@3": 0.654818325434439, 17 | "retriever_mrr@5": 0.6714060031595576, 18 | "retriever_mrr@10": 0.680941096817874, 19 | "rouge1": 0.5848481755355666, 20 | "rouge2": 0.4622595944607476, 21 | "rougeLCS": 0.524537744702827, 22 | "correctness_score": -1, 23 | "comprehensiveness_score": -1, 24 | "readability_score": -1, 25 | "retriever_semantic_accuracy": -1 26 | } 27 | } 28 | ] -------------------------------------------------------------------------------- /resources/dev/label_dataset_gen_params.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "chunk_size": 256, 4 | "chunk_overlap": 10, 5 | "length_function_name": "text-embedding-ada-002", 6 | "qa_generator_llm": "gpt-3.5-turbo", 7 | 
"persist_to_vs": true, 8 | "embedding_model_list": [ 9 | "text-embedding-ada-002" 10 | ] 11 | } 12 | ] -------------------------------------------------------------------------------- /resources/tests/document_store/churchill_speech.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/resources/tests/document_store/churchill_speech.docx -------------------------------------------------------------------------------- /resources/tests/document_store/churchill_speech.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/resources/tests/document_store/churchill_speech.pdf -------------------------------------------------------------------------------- /resources/tests/input_hyperparameters.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "chunk_size": 256, 4 | "chunk_overlap": 0, 5 | "length_function_name": "len", 6 | "num_retrieved_docs": 1, 7 | "similarity_method": "l2", 8 | "search_type": "similarity", 9 | "embedding_model": "TestDummyEmbedding", 10 | "qa_llm": "TestDummyLLM", 11 | "use_llm_grader": false 12 | }, 13 | { 14 | "chunk_size": 512, 15 | "chunk_overlap": 10, 16 | "length_function_name": "len", 17 | "num_retrieved_docs": 2, 18 | "similarity_method": "ip", 19 | "search_type": "mmr", 20 | "embedding_model": "TestDummyEmbedding", 21 | "qa_llm": "TestDummyLLM", 22 | "use_llm_grader": false 23 | }, 24 | { 25 | "chunk_size": 512, 26 | "chunk_overlap": 10, 27 | "length_function_name": "len", 28 | "num_retrieved_docs": 3, 29 | "similarity_method": "cosine", 30 | "search_type": "similarity", 31 | "embedding_model": "TestDummyEmbedding", 32 | "qa_llm": "TestDummyLLM", 33 | "use_llm_grader": false 34 | } 35 | ] -------------------------------------------------------------------------------- /resources/tests/input_label_dataset_gen_params_with_upsert.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "chunk_size": 512, 4 | "chunk_overlap": 10, 5 | "length_function_name": "len", 6 | "qa_generator_llm": "TestDummyLLM", 7 | "persist_to_vs": true, 8 | "embedding_model_list": [ 9 | "TestDummyEmbedding", 10 | "TestDummyEmbedding" 11 | ] 12 | } 13 | ] -------------------------------------------------------------------------------- /resources/tests/input_label_dataset_gen_params_without_upsert.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "chunk_size": 512, 4 | "chunk_overlap": 10, 5 | "length_function_name": "len", 6 | "qa_generator_llm": "TestDummyLLM", 7 | "persist_to_vs": false, 8 | "embedding_model_list": [ 9 | "TestDummyEmbedding", 10 | "TestDummyEmbedding" 11 | ] 12 | } 13 | ] -------------------------------------------------------------------------------- /tests/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11 2 | 3 | WORKDIR /tests 4 | 5 | COPY ./requirements.txt ./ 6 | 7 | RUN pip install --no-cache-dir --upgrade -r requirements.txt 8 | 9 | COPY ./ ./ 10 | 11 | RUN chmod +x wait-for-it.sh 12 | 13 | ENTRYPOINT ["/bin/bash", "-c", "./wait-for-it.sh chromadb-test:8000 --timeout=60 && ./wait-for-it.sh postgres-test:5432 --timeout=60 && ./wait-for-it.sh ragflow-test:8080 --timeout=60 -- pytest"] 
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | logging.basicConfig( 5 | level=os.environ.get("LOG_LEVEL", "INFO"), 6 | format="%(asctime)s - %(levelname)s - %(name)s:%(filename)s:%(lineno)d - %(message)s", 7 | ) 8 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import glob 3 | import os 4 | 5 | from tests.utils import HYPERPARAMETERS_RESULTS_PATH 6 | 7 | import logging 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | @pytest.fixture(scope="function", autouse=True) 13 | def cleanup_output_files(): 14 | # Setup: Anything before the yield is the setup. You can leave it empty if there's no setup. 15 | 16 | yield # This will allow the test to run. 17 | 18 | # Teardown: Anything after the yield is the teardown. 19 | for file in glob.glob("./resources/output_*"): 20 | try: 21 | os.remove(file) 22 | except Exception as e: 23 | logger.error(f"Error deleting {file}. Error: {e}") 24 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest==7.4.3 2 | requests==2.31.0 3 | chromadb==0.4.15 -------------------------------------------------------------------------------- /tests/test_backend_configs.py: -------------------------------------------------------------------------------- 1 | from tests.utils import RAGFLOW_HOST, RAGFLOW_PORT, fetch_data 2 | 3 | LLM_MODELS = [ 4 | "gpt-3.5-turbo", 5 | "gpt-4", 6 | "Llama-2-7b-chat-hf", 7 | "Llama-2-13b-chat-hf", 8 | "Llama-2-70b-chat-hf", 9 | ] 10 | 11 | EMB_MODELS = ["text-embedding-ada-002"] 12 | 13 | CVGradeAnswerPrompt = ["zero_shot", "few_shot", "none"] 14 | CVGradeRetrieverPrompt = ["default", "none"] 15 | CVRetrieverSearchType = ["similarity", "mmr"] 16 | CVSimilarityMethod = ["cosine", "l2", "ip"] 17 | 18 | 19 | def test_ragflow_config_endpoints(): 20 | """Fetch data for all endpoints.""" 21 | endpoints = { 22 | "/configs/llm_models": LLM_MODELS, 23 | "/configs/embedding_models": EMB_MODELS, 24 | "/configs/retriever_similarity_methods": CVSimilarityMethod, 25 | "/configs/retriever_search_types": CVRetrieverSearchType, 26 | "/configs/grade_answer_prompts": CVGradeAnswerPrompt, 27 | "/configs/grade_documents_prompts": CVGradeRetrieverPrompt, 28 | } 29 | 30 | for endpoint, expected_values in endpoints.items(): 31 | response = fetch_data( 32 | method="get", host=RAGFLOW_HOST, port=RAGFLOW_PORT, endpoint=endpoint 33 | ).json() 34 | 35 | # Assert that the fetched data matches the enum values 36 | if isinstance(expected_values, list): 37 | assert set(response) == set( 38 | expected_values 39 | ), f"Mismatch in endpoint {endpoint}" 40 | 41 | else: 42 | raise NotImplementedError( 43 | f"Type {type(expected_values)} not implemented for endpoint {endpoint}." 
44 | ) 45 | -------------------------------------------------------------------------------- /tests/test_backend_generator.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import pytest 3 | import json 4 | import chromadb 5 | import logging 6 | 7 | from tests.utils import ( 8 | RAGFLOW_HOST, 9 | RAGFLOW_PORT, 10 | CHROMADB_HOST, 11 | CHROMADB_PORT, 12 | LABEL_DATASET_ENDPOINT, 13 | DOCUMENT_STORE_PATH, 14 | LABEL_DATASET_PATH, 15 | fetch_data, 16 | first_user_id, 17 | second_user_id, 18 | user_id_without_upsert, 19 | ) 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def test_generator_without_upsert(user_id_without_upsert): 25 | """Test QA generator without upserting to ChromaDB. Since we use a fake LLM provided by LangChain for testing purposes to mock a real one, we are not able to generate real QA pairs and the generated label_dataset.json file is empty.""" 26 | 27 | json_payload = { 28 | "document_store_path": DOCUMENT_STORE_PATH, 29 | "label_dataset_path": "./resources/output_label_dataset_without_upsert.json", 30 | "label_dataset_gen_params_path": "./resources/input_label_dataset_gen_params_without_upsert.json", 31 | "user_id": user_id_without_upsert, 32 | "api_keys": {}, 33 | } 34 | 35 | response = fetch_data( 36 | method="post", 37 | host=RAGFLOW_HOST, 38 | port=RAGFLOW_PORT, 39 | endpoint=LABEL_DATASET_ENDPOINT, 40 | payload=json_payload, 41 | ) 42 | 43 | assert ( 44 | response.status_code == 200 45 | ), f"The response from the {LABEL_DATASET_ENDPOINT} endpoint should be ok." 46 | 47 | with open(json_payload["label_dataset_path"], encoding="utf-8") as file: 48 | data = json.load(file) 49 | 50 | assert ( 51 | not data 52 | ), f"Expected no content in {json_payload['label_dataset_path']}, but found it not empty" 53 | 54 | 55 | @pytest.mark.parametrize( 56 | "user_id_fixture, expected_num_collections, expected_error", 57 | [ 58 | ("first_user_id", 1, None), 59 | ("second_user_id", 2, None), 60 | ], 61 | ) 62 | def test_generator_with_upsert( 63 | user_id_fixture, expected_num_collections, expected_error, request 64 | ): 65 | """Helper function to run generator with upserting twice to test upsert functionality""" 66 | 67 | def get_fixture_value(fixture_name): 68 | """Get the value associated with a fixture name.""" 69 | return request.getfixturevalue(fixture_name) 70 | 71 | user_id = get_fixture_value(user_id_fixture) 72 | 73 | json_payload = { 74 | "document_store_path": DOCUMENT_STORE_PATH, 75 | "label_dataset_path": "./resources/output_label_dataset_with_upsert.json", 76 | "label_dataset_gen_params_path": "./resources/input_label_dataset_gen_params_with_upsert.json", 77 | "user_id": user_id, 78 | "api_keys": {}, 79 | } 80 | 81 | # case when a invalid UUID gets passed to endpoints 82 | if expected_error: 83 | with pytest.raises( 84 | expected_error, 85 | match=f"Unprocessable Entity for url: http://ragflow-test:8080{LABEL_DATASET_ENDPOINT}", 86 | ): 87 | # Assuming the error will arise here when pydantic throws an exception because user id is not valid 88 | response = fetch_data( 89 | method="post", 90 | host=RAGFLOW_HOST, 91 | port=RAGFLOW_PORT, 92 | endpoint=LABEL_DATASET_ENDPOINT, 93 | payload=json_payload, 94 | ) 95 | 96 | return 97 | 98 | # case for valid payload data 99 | response = fetch_data( 100 | method="post", 101 | host=RAGFLOW_HOST, 102 | port=RAGFLOW_PORT, 103 | endpoint=LABEL_DATASET_ENDPOINT, 104 | payload=json_payload, 105 | ) 106 | assert ( 107 | response.status_code == 200 108 | ), "The 
response from the {LABEL_DATASET_ENDPOINT} endpoint should be ok." 109 | 110 | with open(json_payload["label_dataset_path"], encoding="utf-8") as file: 111 | data = json.load(file) 112 | 113 | assert ( 114 | data 115 | ), f"Expected content in {json_payload['label_dataset_path']}, but found it empty or null." 116 | # connect to chromadb 117 | client = chromadb.HttpClient( 118 | host=CHROMADB_HOST, 119 | port=CHROMADB_PORT, 120 | ) 121 | 122 | collections_list = client.list_collections() 123 | 124 | assert ( 125 | len(collections_list) == expected_num_collections 126 | ), f"Expected {expected_num_collections} collections in ChromaDB but found {len(collections_list)}." 127 | 128 | # get most recently created collection 129 | sorted_collections = sorted( 130 | collections_list, key=lambda col: col.metadata["timestamp"], reverse=True 131 | ) 132 | 133 | collection = sorted_collections[0] 134 | 135 | assert ( 136 | collection.name == f"userid_{user_id}_qaid_0_TestDummyEmbedding" 137 | ), "Name of collection should contain first 8 chars of user id and embedding model name." 138 | 139 | assert ( 140 | collection.name == f"userid_{user_id}_qaid_0_TestDummyEmbedding" 141 | ), "The metadata dict of collection should also contain the entire user id" 142 | embeddings = collection.get(include=["embeddings"])["embeddings"] 143 | 144 | assert len(embeddings) == len(data) 145 | 146 | assert len(embeddings[0]) == 2 147 | -------------------------------------------------------------------------------- /tests/test_backend_hp_evaluator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import chromadb 4 | import pytest 5 | 6 | from tests.utils import ( 7 | RAGFLOW_HOST, 8 | RAGFLOW_PORT, 9 | CHROMADB_HOST, 10 | CHROMADB_PORT, 11 | HP_EVALUATION_ENDPOINT, 12 | DOCUMENT_STORE_PATH, 13 | LABEL_DATASET_PATH, 14 | HYPERPARAMETERS_PATH, 15 | HYPERPARAMETERS_RESULTS_PATH, 16 | HYPERPARAMETERS_RESULTS_DATA_PATH, 17 | fetch_data, 18 | first_user_id, 19 | second_user_id, 20 | user_id_without_upsert, 21 | ) 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | @pytest.mark.parametrize( 27 | "user_id_fixture, num_hp_run", 28 | [("first_user_id", 0), ("second_user_id", 1), ("first_user_id", 2)], 29 | ) 30 | def test_ragflow_evaluator(user_id_fixture, num_hp_run, request): 31 | """Test the evaluation of provided hyperparameter configurations.""" 32 | 33 | def get_fixture_value(fixture_name): 34 | """Get the value associated with a fixture name.""" 35 | return request.getfixturevalue(fixture_name) 36 | 37 | user_id = get_fixture_value(user_id_fixture) 38 | 39 | json_payload = { 40 | "document_store_path": DOCUMENT_STORE_PATH, 41 | "label_dataset_path": LABEL_DATASET_PATH, 42 | "hyperparameters_path": HYPERPARAMETERS_PATH, 43 | "hyperparameters_results_path": HYPERPARAMETERS_RESULTS_PATH, 44 | "hyperparameters_results_data_path": HYPERPARAMETERS_RESULTS_DATA_PATH, 45 | "user_id": user_id, 46 | "api_keys": {}, 47 | } 48 | 49 | response = fetch_data( 50 | method="post", 51 | host=RAGFLOW_HOST, 52 | port=RAGFLOW_PORT, 53 | endpoint=HP_EVALUATION_ENDPOINT, 54 | payload=json_payload, 55 | ) 56 | 57 | client = chromadb.HttpClient( 58 | host=CHROMADB_HOST, 59 | port=CHROMADB_PORT, 60 | ) 61 | 62 | collections_list = client.list_collections() 63 | collection_names = [col.name for col in collections_list] 64 | 65 | # Check if hp_id get incremented correctly for users, the second evaluation of first user should create a collection with hpid_1 as suffix 66 
| if len(collections_list) in range(3, 6): 67 | for i in range(3): 68 | assert ( 69 | f"userid_{user_id}_hpid_{i}" in collection_names 70 | ), "First user with 3 hp evals should have 3 corresponding collections now" 71 | elif len(collections_list) in range(6, 9): 72 | for i in range(3): 73 | assert ( 74 | f"userid_{user_id}_hpid_{i}" in collection_names 75 | ), "Second user with 3 hp evals should have 3 corresponding collections now" 76 | elif len(collections_list) in range(9, 12): 77 | for i in range(3, 6): 78 | assert ( 79 | f"userid_{user_id}_hpid_{i}" in collection_names 80 | ), "First user with next 3 hp evals should have 6 corresponding collections now" 81 | else: 82 | assert False, "Unexpected collections_list length" 83 | 84 | # Check if hyperparameters from json were set correctly 85 | assert ( 86 | response.status_code == 200 87 | ), f"The response from the {HP_EVALUATION_ENDPOINT} endpoint should be ok." 88 | 89 | # test the generated results of the evaluation run 90 | with open(json_payload["hyperparameters_results_path"], encoding="utf-8") as file: 91 | hyperparameters_results_list = json.load(file) 92 | 93 | with open(json_payload["hyperparameters_path"], encoding="utf-8") as file: 94 | hyperparameters_list = json.load(file) 95 | 96 | hyperparameters_results = hyperparameters_results_list[num_hp_run] 97 | hyperparameters = hyperparameters_list[num_hp_run] 98 | 99 | assert ( 100 | hyperparameters["similarity_method"] 101 | == hyperparameters_results["vectorstore_params"]["metadata"]["hnsw:space"] 102 | ), "Retriever similarity methods should match." 103 | 104 | assert ( 105 | hyperparameters["search_type"] 106 | == hyperparameters_results["retriever_params"]["search_type"] 107 | ), "Retriever search types should match." 108 | 109 | assert ( 110 | hyperparameters["num_retrieved_docs"] 111 | == hyperparameters_results["retriever_params"]["search_kwargs"]["k"] 112 | ), "Number of retrieved docs should match." 113 | 114 | assert ( 115 | hyperparameters["embedding_model"] 116 | == hyperparameters_results["retriever_params"]["tags"][1] 117 | ), "Embedding models should match." 
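Note: the assertions above only touch a few keys of each results entry. For orientation, a single entry of the generated output_hyperparameters_results.json is expected to look roughly like the following sketch; the values mirror the first entry of input_hyperparameters.json, while the exact tag ordering and any additional keys are assumptions not checked by this test:

example_result_entry = {
    "vectorstore_params": {"metadata": {"hnsw:space": "l2"}},
    "retriever_params": {
        "search_type": "similarity",
        "search_kwargs": {"k": 1},
        # tags[1] must name the embedding model; tags[0] is assumed to be the retriever/vectorstore class
        "tags": ["Chroma", "TestDummyEmbedding"],
    },
}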
118 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | import requests 4 | import pytest 5 | 6 | # Constants 7 | RAGFLOW_HOST = os.environ.get("RAGFLOW_HOST") 8 | RAGFLOW_PORT = os.environ.get("RAGFLOW_PORT") 9 | 10 | CHROMADB_HOST = os.environ.get("CHROMADB_HOST") 11 | CHROMADB_PORT = os.environ.get("CHROMADB_PORT") 12 | 13 | LABEL_DATASET_ENDPOINT = "/generation" 14 | HP_EVALUATION_ENDPOINT = "/evaluation" 15 | 16 | DOCUMENT_STORE_PATH = "./resources/document_store" 17 | LABEL_DATASET_PATH = "./resources/input_label_dataset.json" 18 | HYPERPARAMETERS_PATH = "./resources/input_hyperparameters.json" 19 | HYPERPARAMETERS_RESULTS_PATH = "./resources/output_hyperparameters_results.json" 20 | HYPERPARAMETERS_RESULTS_DATA_PATH = ( 21 | "./resources/output_hyperparameters_results_data.csv" 22 | ) 23 | # pytest fixtures 24 | 25 | 26 | @pytest.fixture 27 | def user_id_without_upsert() -> str: 28 | """Returns a test user id of 0.""" 29 | return "1" 30 | 31 | 32 | @pytest.fixture 33 | def first_user_id() -> str: 34 | """Returns a test user id of 0.""" 35 | return "2" 36 | 37 | 38 | @pytest.fixture 39 | def second_user_id() -> str: 40 | """Returns a test user id of 1.""" 41 | return "3" 42 | 43 | 44 | # helper functions 45 | def fetch_data( 46 | method: str, 47 | host: str, 48 | port: str, 49 | endpoint: str, 50 | payload: Optional[dict] = None, 51 | ): 52 | """Fetch data from the given endpoint.""" 53 | 54 | if method == "get": 55 | response = requests.get(f"http://{host}:{port}{endpoint}") 56 | else: 57 | response = requests.post(f"http://{host}:{port}{endpoint}", json=payload) 58 | 59 | # Handle response errors if needed 60 | response.raise_for_status() 61 | 62 | return response 63 | -------------------------------------------------------------------------------- /tests/wait-for-it.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Use this script to test if a given TCP host/port are available 3 | 4 | WAITFORIT_cmdname=${0##*/} 5 | 6 | echoerr() { if [[ $WAITFORIT_QUIET -ne 1 ]]; then echo "$@" 1>&2; fi } 7 | 8 | usage() 9 | { 10 | cat << USAGE >&2 11 | Usage: 12 | $WAITFORIT_cmdname host:port [-s] [-t timeout] [-- command args] 13 | -h HOST | --host=HOST Host or IP under test 14 | -p PORT | --port=PORT TCP port under test 15 | Alternatively, you specify the host and port as host:port 16 | -s | --strict Only execute subcommand if the test succeeds 17 | -q | --quiet Don't output any status messages 18 | -t TIMEOUT | --timeout=TIMEOUT 19 | Timeout in seconds, zero for no timeout 20 | -- COMMAND ARGS Execute command with args after the test finishes 21 | USAGE 22 | exit 1 23 | } 24 | 25 | wait_for() 26 | { 27 | if [[ $WAITFORIT_TIMEOUT -gt 0 ]]; then 28 | echoerr "$WAITFORIT_cmdname: waiting $WAITFORIT_TIMEOUT seconds for $WAITFORIT_HOST:$WAITFORIT_PORT" 29 | else 30 | echoerr "$WAITFORIT_cmdname: waiting for $WAITFORIT_HOST:$WAITFORIT_PORT without a timeout" 31 | fi 32 | WAITFORIT_start_ts=$(date +%s) 33 | while : 34 | do 35 | (echo > /dev/tcp/$WAITFORIT_HOST/$WAITFORIT_PORT) >/dev/null 2>&1 36 | result=$? 
37 | if [[ $result -eq 0 ]]; then 38 | WAITFORIT_end_ts=$(date +%s) 39 | echoerr "$WAITFORIT_cmdname: $WAITFORIT_HOST:$WAITFORIT_PORT is available after $((WAITFORIT_end_ts - WAITFORIT_start_ts)) seconds" 40 | break 41 | fi 42 | sleep 1 43 | done 44 | } 45 | 46 | wait_for_wrapper() 47 | { 48 | # In order to support SIGINT during timeout: http://unix.stackexchange.com/a/57692 49 | if [[ $WAITFORIT_QUIET -eq 1 ]]; then 50 | timeout $WAITFORIT_BUSYTIMEFLAG $WAITFORIT_TIMEOUT $0 --quiet --child --host=$WAITFORIT_HOST --port=$WAITFORIT_PORT --timeout=$WAITFORIT_TIMEOUT & 51 | else 52 | timeout $WAITFORIT_BUSYTIMEFLAG $WAITFORIT_TIMEOUT $0 --child --host=$WAITFORIT_HOST --port=$WAITFORIT_PORT --timeout=$WAITFORIT_TIMEOUT & 53 | fi 54 | WAITFORIT_PID=$! 55 | trap "kill -INT -$WAITFORIT_PID" INT 56 | wait $WAITFORIT_PID 57 | WAITFORIT_RESULT=$? 58 | if [[ $WAITFORIT_RESULT -ne 0 ]]; then 59 | echoerr "$WAITFORIT_cmdname: timeout occurred after waiting $WAITFORIT_TIMEOUT seconds for $WAITFORIT_HOST:$WAITFORIT_PORT" 60 | fi 61 | return $WAITFORIT_RESULT 62 | } 63 | 64 | # process arguments 65 | while [[ $# -gt 0 ]] 66 | do 67 | case "$1" in 68 | *:* ) 69 | WAITFORIT_hostport=(${1//:/ }) 70 | WAITFORIT_HOST=${WAITFORIT_hostport[0]} 71 | WAITFORIT_PORT=${WAITFORIT_hostport[1]} 72 | shift 1 73 | ;; 74 | --child) 75 | WAITFORIT_CHILD=1 76 | shift 1 77 | ;; 78 | -q | --quiet) 79 | WAITFORIT_QUIET=1 80 | shift 1 81 | ;; 82 | -s | --strict) 83 | WAITFORIT_STRICT=1 84 | shift 1 85 | ;; 86 | -h) 87 | WAITFORIT_HOST="$2" 88 | if [[ $WAITFORIT_HOST == "" ]]; then break; fi 89 | shift 2 90 | ;; 91 | --host=*) 92 | WAITFORIT_HOST="${1#*=}" 93 | shift 1 94 | ;; 95 | -p) 96 | WAITFORIT_PORT="$2" 97 | if [[ $WAITFORIT_PORT == "" ]]; then break; fi 98 | shift 2 99 | ;; 100 | --port=*) 101 | WAITFORIT_PORT="${1#*=}" 102 | shift 1 103 | ;; 104 | -t) 105 | WAITFORIT_TIMEOUT="$2" 106 | if [[ $WAITFORIT_TIMEOUT == "" ]]; then break; fi 107 | shift 2 108 | ;; 109 | --timeout=*) 110 | WAITFORIT_TIMEOUT="${1#*=}" 111 | shift 1 112 | ;; 113 | --) 114 | shift 115 | WAITFORIT_CLI=("$@") 116 | break 117 | ;; 118 | --help) 119 | usage 120 | ;; 121 | *) 122 | echoerr "Unknown argument: $1" 123 | usage 124 | ;; 125 | esac 126 | done 127 | 128 | if [[ "$WAITFORIT_HOST" == "" || "$WAITFORIT_PORT" == "" ]]; then 129 | echoerr "Error: you need to provide a host and port to test." 130 | usage 131 | fi 132 | 133 | WAITFORIT_TIMEOUT=${WAITFORIT_TIMEOUT:-15} 134 | WAITFORIT_STRICT=${WAITFORIT_STRICT:-0} 135 | WAITFORIT_CHILD=${WAITFORIT_CHILD:-0} 136 | WAITFORIT_QUIET=${WAITFORIT_QUIET:-0} 137 | 138 | # Check to see if timeout is from busybox? 139 | WAITFORIT_TIMEOUT_PATH=$(type -p timeout) 140 | WAITFORIT_TIMEOUT_PATH=$(realpath $WAITFORIT_TIMEOUT_PATH 2>/dev/null || readlink -f $WAITFORIT_TIMEOUT_PATH) 141 | 142 | WAITFORIT_BUSYTIMEFLAG="" 143 | if [[ $WAITFORIT_TIMEOUT_PATH =~ "busybox" ]]; then 144 | WAITFORIT_ISBUSY=1 145 | # Check if busybox timeout uses -t flag 146 | # (recent Alpine versions don't support -t anymore) 147 | if timeout &>/dev/stdout | grep -q -e '-t '; then 148 | WAITFORIT_BUSYTIMEFLAG="-t" 149 | fi 150 | else 151 | WAITFORIT_ISBUSY=0 152 | fi 153 | 154 | if [[ $WAITFORIT_CHILD -gt 0 ]]; then 155 | wait_for 156 | WAITFORIT_RESULT=$? 157 | exit $WAITFORIT_RESULT 158 | else 159 | if [[ $WAITFORIT_TIMEOUT -gt 0 ]]; then 160 | wait_for_wrapper 161 | WAITFORIT_RESULT=$? 162 | else 163 | wait_for 164 | WAITFORIT_RESULT=$? 
165 | fi 166 | fi 167 | 168 | if [[ $WAITFORIT_CLI != "" ]]; then 169 | if [[ $WAITFORIT_RESULT -ne 0 && $WAITFORIT_STRICT -eq 1 ]]; then 170 | echoerr "$WAITFORIT_cmdname: strict mode, refusing to execute subprocess" 171 | exit $WAITFORIT_RESULT 172 | fi 173 | exec "${WAITFORIT_CLI[@]}" 174 | else 175 | exit $WAITFORIT_RESULT 176 | fi 177 | -------------------------------------------------------------------------------- /vectorstore/.dockerignore: -------------------------------------------------------------------------------- 1 | # git 2 | .git 3 | .gitignore 4 | .cache 5 | **.log 6 | 7 | # data/chromadb 8 | data/datasets.zip 9 | data/desc.docx 10 | chromadb/ 11 | 12 | # Byte-compiled / optimized / DLL files 13 | **/__pycache__ 14 | **/*.pyc 15 | **.py[cod] 16 | *$py.class 17 | .vscode/ 18 | 19 | # Jupyter Notebook 20 | **/.ipynb_checkpoints 21 | notebooks/ 22 | 23 | # dotenv 24 | .env 25 | 26 | # virtualenv 27 | .venv* 28 | venv*/ 29 | ENV*/ -------------------------------------------------------------------------------- /vectorstore/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11 2 | 3 | EXPOSE 8000 4 | 5 | WORKDIR /vectorstore 6 | 7 | COPY ./requirements.txt ./ 8 | 9 | RUN pip install --no-cache-dir -r requirements.txt 10 | 11 | COPY ./server.py ./ 12 | 13 | CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"] -------------------------------------------------------------------------------- /vectorstore/requirements.txt: -------------------------------------------------------------------------------- 1 | chromadb==0.4.15 2 | uvicorn==0.24.0.post1 -------------------------------------------------------------------------------- /vectorstore/server.py: -------------------------------------------------------------------------------- 1 | import chromadb 2 | import chromadb.config 3 | from chromadb.server.fastapi import FastAPI 4 | 5 | settings = chromadb.config.Settings() 6 | 7 | server = FastAPI(settings) 8 | app = server.app 9 | --------------------------------------------------------------------------------
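Note: the vectorstore service simply mounts Chroma's FastAPI app, so any Chroma HTTP client can talk to it, as the test suite does. A minimal client-side sketch; the host name is an assumption based on the compose service naming, while the port matches the Dockerfile:

import chromadb

client = chromadb.HttpClient(host="vectorstore", port=8000)
# list the collections created by the generation and evaluation endpoints
print([col.name for col in client.list_collections()])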