├── .circleci
│   └── config.yml
├── .gitignore
├── RAGflow overview.drawio
├── README.md
├── TODO
├── app
│   ├── .dockerignore
│   ├── .streamlit
│   │   └── config.toml
│   ├── Dockerfile
│   ├── __init__.py
│   ├── main.py
│   ├── page_apikeys.py
│   ├── page_chat.py
│   ├── page_dashboard.py
│   ├── page_documentstore.py
│   ├── page_filemanager.py
│   ├── page_home.py
│   ├── page_login.py
│   ├── page_parameters.py
│   ├── requirements.txt
│   └── utils.py
├── docker-compose.dev.yaml
├── docker-compose.integration.test.yaml
├── docker-compose.local.test.yaml
├── k8s
│   ├── app-cluster-ip-service.yaml
│   ├── app-deployment.yaml
│   ├── backend-cluster-ip-service.yaml
│   ├── backend-deployment.yaml
│   ├── ingress-service.yaml
│   ├── pgadmin-cluster-ip-service.yaml
│   ├── pgadmin-deployment.yaml
│   ├── postgres-cluster-ip-service.yaml
│   ├── postgres-deployment.yaml
│   ├── shared-persistent-volume-claim.yaml
│   ├── vectorstore-cluster-ip-service.yaml
│   └── vectorstore-deployment.yaml
├── ragflow
│   ├── .dockerignore
│   ├── Dockerfile
│   ├── __init__.py
│   ├── api
│   │   ├── __init__.py
│   │   ├── database.py
│   │   ├── main.py
│   │   ├── models.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── auth_router.py
│   │   │   ├── chats_router.py
│   │   │   ├── configs_router.py
│   │   │   ├── evals_router.py
│   │   │   ├── gens_router.py
│   │   │   └── user_router.py
│   │   ├── schemas.py
│   │   └── services
│   │       ├── __init__.py
│   │       ├── auth_service.py
│   │       ├── common_service.py
│   │       └── user_service.py
│   ├── commons
│   │   ├── __init__.py
│   │   ├── chroma
│   │   │   ├── ChromaClient.py
│   │   │   └── __init__.py
│   │   ├── configurations
│   │   │   ├── BaseConfigurations.py
│   │   │   ├── Hyperparameters.py
│   │   │   ├── QAConfigurations.py
│   │   │   └── __init__.py
│   │   ├── prompts
│   │   │   ├── __init__.py
│   │   │   ├── grade_answers_prompts.py
│   │   │   ├── grade_retriever_prompts.py
│   │   │   ├── qa_answer_prompts.py
│   │   │   └── qa_geneneration_prompts.py
│   │   └── vectorstore
│   │       ├── __init__.py
│   │       └── pgvector_utils.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── hp_evaluator.py
│   │   ├── metrics
│   │   │   ├── __init__.py
│   │   │   ├── answer_embedding_similarity.py
│   │   │   ├── predicted_answer_accuracy.py
│   │   │   ├── retriever_mrr_accuracy.py
│   │   │   ├── retriever_semantic_accuracy.py
│   │   │   └── rouge_score.py
│   │   └── utils.py
│   ├── example.py
│   ├── generation
│   │   ├── __init__.py
│   │   └── label_dataset_generator.py
│   ├── requirements.txt
│   └── utils
│       ├── __init__.py
│       ├── doc_processing.py
│       ├── hyperparam_chats.py
│       └── utils.py
├── resources
│   ├── dev
│   │   ├── document_store
│   │   │   ├── msg_life-gb-2021-EN_final_1-15.pdf
│   │   │   ├── msg_life-gb-2021-EN_final_16-30.pdf
│   │   │   ├── msg_life-gb-2021-EN_final_31-45.pdf
│   │   │   ├── msg_life-gb-2021-EN_final_46-59.pdf
│   │   │   └── msg_life-gb-2021-EN_final_60-end.pdf
│   │   ├── hyperparameters.json
│   │   ├── hyperparameters_results.json
│   │   ├── hyperparameters_results_data.csv
│   │   ├── label_dataset.json
│   │   └── label_dataset_gen_params.json
│   └── tests
│       ├── document_store
│       │   ├── churchill_speech.docx
│       │   ├── churchill_speech.pdf
│       │   └── churchill_speech.txt
│       ├── input_hyperparameters.json
│       ├── input_label_dataset.json
│       ├── input_label_dataset_gen_params_with_upsert.json
│       └── input_label_dataset_gen_params_without_upsert.json
├── tests
│   ├── Dockerfile
│   ├── __init__.py
│   ├── conftest.py
│   ├── requirements.txt
│   ├── test_backend_configs.py
│   ├── test_backend_generator.py
│   ├── test_backend_hp_evaluator.py
│   ├── utils.py
│   └── wait-for-it.sh
└── vectorstore
    ├── .dockerignore
    ├── Dockerfile
    ├── requirements.txt
    └── server.py
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2.1
2 |
3 | executors:
4 | python-docker-executor:
5 | docker:
6 | - image: cimg/python:3.11
7 |
8 | jobs:
9 | # first we build and push the new Docker images to Docker Hub
10 | push_app_image:
11 | executor: python-docker-executor
12 | steps:
13 | - checkout
14 | - set_env_vars
15 | - docker_login
16 | - setup_remote_docker:
17 | docker_layer_caching: false
18 | - run:
19 | name: Inject build info into frontend
20 | command: |
21 | sed -i "s/\$BUILD_NUMBER/${PIPELINE_NUMBER}/g" ./app/main.py
22 | sed -i "s/\$BUILD_DATE/${NOW}/g" ./app/main.py
23 | sed -i "s/\$GIT_SHA/${TAG}/g" ./app/main.py
24 | - run:
25 | name: Push app image
26 | command: |
27 | docker build -t $DOCKER_USER/$IMAGE_NAME_BASE-app:$TAG -t $DOCKER_USER/$IMAGE_NAME_BASE-app:latest ./app
28 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-app:$TAG
29 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-app:latest
30 | push_ragflow_image:
31 | executor: python-docker-executor
32 | steps:
33 | - checkout
34 | - set_env_vars
35 | - docker_login
36 | - setup_remote_docker:
37 | docker_layer_caching: false
38 | - run:
39 | name: Push ragflow image
40 | command: |
41 | docker build -t $DOCKER_USER/$IMAGE_NAME_BASE-backend:$TAG -t $DOCKER_USER/$IMAGE_NAME_BASE-backend:latest ./ragflow
42 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-backend:$TAG
43 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-backend:latest
44 | push_vectorstore_image:
45 | executor: python-docker-executor
46 | steps:
47 | - checkout
48 | - set_env_vars
49 | - docker_login
50 | - setup_remote_docker:
51 | docker_layer_caching: false
52 | - run:
53 | name: Push vectorstore image
54 | command: |
55 | docker build -t $DOCKER_USER/$IMAGE_NAME_BASE-vectorstore:$TAG -t $DOCKER_USER/$IMAGE_NAME_BASE-vectorstore:latest ./vectorstore
56 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-vectorstore:$TAG
57 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-vectorstore:latest
58 | push_test_image:
59 | executor: python-docker-executor
60 | steps:
61 | - checkout
62 | - set_env_vars
63 | - docker_login
64 | - setup_remote_docker:
65 | docker_layer_caching: false
66 | - run:
67 | name: Push test image
68 | command: |
69 | docker build -t $DOCKER_USER/$IMAGE_NAME_BASE-test:$TAG -t $DOCKER_USER/$IMAGE_NAME_BASE-test:latest ./tests
70 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-test:$TAG
71 | docker push $DOCKER_USER/$IMAGE_NAME_BASE-test:latest
72 |
73 | # then we run tests on the new Docker images
74 | run_integration_tests:
75 | machine: true
76 | steps:
77 | - checkout
78 | - run:
79 | name: Clear Docker cache
80 | command: docker system prune --all --force --volumes
81 | - run:
82 | name: Run Docker Compose to build, start and test
83 | command: |
84 | docker-compose -f docker-compose.integration.test.yaml up --exit-code-from test-suite
85 | - run:
86 | name: Shut services down
87 | command: docker-compose -f docker-compose.integration.test.yaml down
88 | - store_artifacts:
89 | path: /tests/test-reports
90 | destination: test-reports
91 |
92 | # if everything was successful we deploy the Docker images into the k8s cluster
93 | deploy_to_gke_k8s_cluster:
94 | docker:
95 | - image: google/cloud-sdk
96 | steps:
97 | - checkout
98 | - set_env_vars
99 | - setup_remote_docker:
100 | docker_layer_caching: false
101 | - run:
102 | name: Setup Google Cloud SDK
103 | command: |
104 | echo "$GOOGLE_SERVICE_KEY" > ${HOME}/gcloud-service-key.json
105 | gcloud auth activate-service-account --key-file=${HOME}/gcloud-service-key.json
106 | gcloud config set project "$GOOGLE_PROJECT_ID"
107 | gcloud config set compute/zone "$GOOGLE_COMPUTE_ZONE"
108 | gcloud container clusters get-credentials "$GKE_CLUSTER_NAME"
109 | - run:
110 | name: Deploy to GKE k8s cluster
111 | command: |
112 | kubectl apply -f ./k8s
113 | kubectl set image deployments/app-deployment app-frontend=$DOCKER_USER/$IMAGE_NAME_BASE-app:$TAG
114 | kubectl set image deployments/backend-deployment ragflow-backend=$DOCKER_USER/$IMAGE_NAME_BASE-backend:$TAG
115 | kubectl set image deployments/vectorstore-deployment chromadb-vectorstore=$DOCKER_USER/$IMAGE_NAME_BASE-vectorstore:$TAG
116 |
117 | workflows:
118 | version: 2
119 | build-deploy:
120 | jobs:
121 | - push_app_image
122 | - push_ragflow_image
123 | - push_vectorstore_image
124 | - push_test_image
125 | - run_integration_tests:
126 | requires:
127 | - push_app_image
128 | - push_ragflow_image
129 | - push_vectorstore_image
130 | - push_test_image
131 | - deploy_to_gke_k8s_cluster:
132 | requires:
133 | - run_integration_tests
134 |
135 | commands:
136 | set_env_vars:
137 | steps:
138 | - run:
139 | name: Setup tag and base image name
140 | command: |
141 | echo 'export TAG=${CIRCLE_SHA1:0:8}' >> $BASH_ENV
142 | echo 'export IMAGE_NAME_BASE=ragflow' >> $BASH_ENV
143 | echo 'export NOW=$(date --utc +"%Y-%m-%d %H:%M:%S")' >> $BASH_ENV
144 | echo 'export PIPELINE_NUMBER=<< pipeline.number >>' >> $BASH_ENV
145 | docker_login:
146 | steps:
147 | - run:
148 | name: Login into Docker Hub
149 | command: |
150 | echo "$DOCKER_PWD" | docker login --username "$DOCKER_USER" --password-stdin
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # env variables
2 | .env
3 |
4 | # local stuff
5 | .venv/
6 | .vscode/
7 |
8 | **.bin
9 | **/.ipynb_checkpoints
10 | **/__pycache__
11 |
12 | # db
13 | postgres/
14 | vectorstore/chromadb/
15 | *.sqlite3
16 |
17 | # jupyter notebooks
18 | *.ipynb
19 | **/notebooks
20 |
21 | # k8s configs
22 | k8s/dashboard.yaml
23 |
24 | # misc
25 | #resources/*
26 | #!resources/msg_life-gb-2021-EN_final.pdf
27 |
--------------------------------------------------------------------------------
/RAGflow overview.drawio:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RAGflow: Build optimized and robust LLM applications
2 |
3 | [](https://app.circleci.com/pipelines/circleci/6FfqBzs4fBDyTPvBNqnq5x)
4 |
5 | RAGflow provides tools for constructing and evaluating Retrieval Augmented Generation (RAG) systems, empowering developers to craft efficient question-answer applications leveraging LLMs. The stack consists of
6 |
7 | `Language` [Python](https://www.python.org/)\
8 | `Frameworks for LLMs` [LangChain](https://www.langchain.com/) [OpenAI](https://www.openai.com/) [Hugging Face](https://huggingface.co/)\
9 | `Framework for API` [FastAPI](https://fastapi.tiangolo.com/)\
10 | `Databases` [ChromaDB](https://www.trychroma.com/) [Postgres](https://www.postgresql.org/)\
11 | `Frontend` [Streamlit](https://www.streamlit.io/)\
12 | `CI/CD` [Docker](https://www.docker.com/) [Kubernetes](https://kubernetes.io/) [CircleCI](https://circleci.com/) [GKE](https://cloud.google.com/kubernetes-engine)
13 |
14 | # 🚀 Getting Started
15 |
16 | - CircleCI pushes the Docker images after each successful build to
17 |   - https://hub.docker.com/u/andreasx42
18 |   - The Google Kubernetes Engine cluster is currently not available.
19 | - Check out the repository.
20 | - Start the application with `docker-compose up --build`.
21 |   - The application should be available on localhost:8501.
22 |   - The backend API documentation is available on localhost:8080/docs (a quick smoke test is sketched below).
23 | - Deploy locally to Kubernetes with `kubectl apply -f k8s`.
24 |   - The application should be available directly on localhost/.
25 |   - For backend API access we use nginx routing with localhost/api/*.
26 |   - Be sure to check the deployment configs for image versions.
27 |
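After `docker-compose up --build`, a quick smoke test could look like the following (a minimal sketch assuming the default ports above; the `_stcore/health` endpoint is the Streamlit health check used in `app/Dockerfile`):

```python
# Smoke test for a local docker-compose start (assumes default ports).
import urllib.request

endpoints = {
    "frontend": "http://localhost:8501/_stcore/health",
    "backend API docs": "http://localhost:8080/docs",
}

for name, url in endpoints.items():
    # Raises urllib.error.URLError if the service is not reachable.
    with urllib.request.urlopen(url, timeout=5) as resp:
        print(f"{name}: HTTP {resp.status}")
```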
28 | # 📖 What is Retrieval Augmented Generation (RAG)?
29 |
30 | 
31 | 
32 |
33 | In RAG, when a user query is received, relevant documents or passages are retrieved from a massive corpus, i.e. a document store. These retrieved documents are then provided as context to a generative model, which synthesizes a coherent response or answer using both the input query and the retrieved information. This approach leverages the strengths of both retrieval-based and generative systems, aiming to produce accurate and well-formed responses by drawing from vast amounts of textual data.
34 |
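To make the retrieve-then-generate loop concrete, here is a minimal, self-contained Python sketch. It is illustrative only: a toy bag-of-words similarity stands in for the embedding store, whereas RAGflow itself uses ChromaDB with LLM embeddings and a real LLM for generation.

```python
# Toy retrieve-then-generate loop (illustrative, not the app's pipeline).
import math
from collections import Counter

corpus = [
    "The capital of France is Paris.",
    "Streamlit is a Python framework for building data apps.",
]

def embed(text: str) -> Counter:
    # Toy "embedding": bag-of-words term counts.
    return Counter(text.lower().split())

def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[t] * b[t] for t in a)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def retrieve(query: str, k: int = 1) -> list[str]:
    # Rank the corpus by similarity to the query and keep the top k chunks.
    q = embed(query)
    return sorted(corpus, key=lambda doc: cosine(q, embed(doc)), reverse=True)[:k]

def rag_prompt(query: str) -> str:
    # The retrieved chunks become the context a generative LLM answers from.
    context = "\n".join(retrieve(query))
    return f"Answer using only this context:\n{context}\n\nQuestion: {query}"

print(rag_prompt("What is the capital of France?"))
```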
35 | # 🚀 Workflow of RAGflow
36 |
37 | - Automatic Generation of Question-Answer Pairs\
38 |   Begin with RAGflow's capability to generate relevant question-answer pairs from the provided documents; these pairs serve as the evaluation dataset for benchmarking RAG systems.
39 | - Hyperparameter Evaluation\
40 |   After generating Q&A pairs, dive into hyperparameter evaluation: provide your hyperparameters (an example configuration is sketched below), let RAGflow evaluate their efficacy, and obtain insights for crafting robust RAG systems.
41 |   This approach lets you select efficient document splitting strategies as well as language and embedding models, which can be further fine-tuned with respect to your document store.
43 |
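The example configuration mentioned above might look like this (illustrative values only; the field names mirror the hyperparameter form in `app/page_parameters.py`, while the concrete model and method names are assumptions, not defaults of this repo):

```python
# Illustrative hyperparameter set; the model/method values are assumptions.
hyperparameters = {
    "chunk_size": 512,
    "chunk_overlap": 10,
    "length_function_name": "len",
    "num_retrieved_docs": 3,
    "similarity_method": "cosine",  # assumed value
    "search_type": "mmr",  # assumed value
    "embedding_model": "text-embedding-ada-002",  # assumed value
    "qa_llm": "gpt-3.5-turbo",  # assumed value
    "use_llm_grader": False,
}
```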
44 | Here is a schematic overview:
45 |
46 | 
47 |
48 | # 🌟 Key Features & Functionalities
49 |
50 | - `Document Store Integration` Provide documents in formats like PDF and DOCX as the knowledge base.
51 | - `Dynamic Parameter Selection` Customize parameters such as document splitting strategies, embedding models, and question-answering LLMs for evaluations.
52 | - `Automated Dataset Generation` Automatically generate question-answer pairs from the provided documents as an evaluation dataset for each parameterized RAG system.
53 | - `Evaluation & Optimization` Optimize performance using grid searches across parameter spaces.
54 | - `Advanced Retrieval Mechanisms` Employ techniques like "MMR" and the SelfQueryRetriever for optimal data extraction.
55 | - `Integration with External Platforms` Collaborate with platforms like Anyscale, MosaicML, and Replicate for enhanced functionality and state-of-the-art LLM models.
56 | - `Interactive Feedback Loop` Refine and improve your RAG system with interactive feedback based on real-world results.
57 |
58 | # 🛠️ Development
59 |
60 | Directory Structure
61 |
62 | - `/.circleci` CircleCI integration config for CI/CD pipeline.
63 | - `/app` Frontend components and resources in Streamlit.
64 | - `/ragflow` Backend services and APIs.
65 | - `/tests` Test scripts and test data.
66 | - `/resources` Data storage.
67 | - `/vectorstore` ChromaDB component.
68 |
69 | # 🌐 Links & Resources
70 |
71 | - TBA
72 |
--------------------------------------------------------------------------------
/TODO:
--------------------------------------------------------------------------------
1 | For k8s deployment:
2 | - use $TAG as variable in deployment config and use sed to change to build number
3 | - sed -i 's/\$TAG/{{IMAGE_TAG}}/g' client-deployment.yaml
4 | or:
5 | - imperative command
6 | - kubectl set image deployment/$name client=andreasx42/ragflow-$img_name:$tag
7 | Secrets:
8 | - kubectl create secret generic <secret-name> --from-literal=key=value
9 | Nginx installation
10 | - kubectl apply -f https://raw.githubusercontent.com/kubernetes/ingress-nginx/controller-v1.8.2/deploy/static/provider/cloud/deploy.yaml
11 | k8s dashboard:
12 | - create admin: kubectl -n kubernetes-dashboard create token admin-user
13 |
14 |
15 | HuggingFaceEmbeddings(
16 | model_name=model_name,
17 | model_kwargs={"device": "cuda"},
18 | encode_kwargs={"device": "cuda", "batch_size": 100})
--------------------------------------------------------------------------------
/app/.dockerignore:
--------------------------------------------------------------------------------
1 | # git
2 | .git
3 | .gitignore
4 | .cache
5 | **.log
6 |
7 | # Byte-compiled / optimized / DLL files
8 | **/__pycache__
9 | **/*.pyc
10 | **.py[cod]
11 | *$py.class
12 | .vscode/
13 |
14 | # Jupyter Notebook
15 | **/.ipynb_checkpoints
16 | notebooks/
17 |
18 | # dotenv
19 | .env
20 |
21 | # virtualenv
22 | .venv*
23 | venv*/
24 | ENV*/
--------------------------------------------------------------------------------
/app/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | [server]
2 | maxUploadSize = 10
3 | # this is needed for local development with docker
4 | # if you don't want to start the default browser:
5 | headless = true
6 | # you will need this for local development:
7 | runOnSave = true
8 | # you will need this if running docker on windows host:
9 | fileWatcherType = "poll"
10 |
11 | [theme]
12 | primaryColor = "#E3735E"
13 | backgroundColor = "#121212" # Dark gray background
14 | secondaryBackgroundColor = "#1E1E1E" # Slightly lighter gray for secondary elements
15 | textColor = "#E0E0E0" # Light gray text for better contrast against the dark background
16 | font = "sans serif"
--------------------------------------------------------------------------------
/app/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11
2 |
3 | # Port the app is running on
4 | EXPOSE 8501
5 |
6 | # Install dependencies
7 | WORKDIR /app
8 |
9 | COPY ./requirements.txt ./
10 |
11 | RUN pip install --no-cache-dir --upgrade -r ./requirements.txt
12 |
13 | # Copy all into image
14 | COPY ./ ./
15 |
16 | HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
17 |
18 | CMD ["streamlit", "run", "main.py", "--server.port=8501", "--server.address=0.0.0.0"]
--------------------------------------------------------------------------------
/app/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/app/__init__.py
--------------------------------------------------------------------------------
/app/main.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from streamlit_option_menu import option_menu
3 |
4 | from page_parameters import page_parameters
5 | from page_home import page_home
6 | from page_documentstore import page_documentstore
7 | from page_dashboard import page_dashboard
8 | from page_login import page_login
9 | from page_filemanager import page_filemanager
10 | from page_apikeys import page_apikeys
11 | from page_chat import page_chat
12 | from utils import get_auth_user
13 |
14 | def main():
15 | # Display selected page with the respective function
16 | def sideBar():
17 | with st.sidebar:
18 | selected = option_menu(
19 | menu_title="Main Menu",
20 | options=[
21 | "Home",
22 | "Dashboard",
23 | "Parameters",
24 | "Q&A Chats",
25 | "Documents",
26 | "File Manager",
27 | "API Keys",
28 | "Login",
29 | ],
30 | icons=[
31 | "house-fill",
32 | "clipboard2-pulse-fill",
33 | "toggles",
34 | "chat-text-fill",
35 | "cloud-upload-fill",
36 | "file-earmark-text-fill",
37 | "safe-fill",
38 | "door-open-fill",
39 | ],
40 | menu_icon="cast",
41 | default_index=0,
42 | styles={
43 | "nav-link": {
44 | "--hover-color": "#1E1E1E",
45 | },
46 | "nav-link-selected": {"background-color": "#880808"},
47 | },
48 | )
49 |
50 | if selected == "Home":
51 | page_home()
52 | if selected == "Dashboard":
53 | page_dashboard()
54 | if selected == "Parameters":
55 | page_parameters()
56 | if selected == "Q&A Chats":
57 | page_chat()
58 | if selected == "Documents":
59 | page_documentstore()
60 | if selected == "File Manager":
61 | page_filemanager()
62 | if selected == "API Keys":
63 | page_apikeys()
64 | if selected == "Login":
65 | page_login()
66 |
67 | sideBar()
68 |
69 |
70 | if __name__ == "__main__":
71 | st.set_page_config(
72 | page_title="RAGflow",
73 | page_icon="🧊",
74 | layout="wide",
75 | initial_sidebar_state="expanded",
76 | )
77 |
78 | # set user_id attribute in session state after browser refresh if user already authenticated
79 | if "user_id" not in st.session_state or not st.session_state.user_id:
80 | user_data, success = get_auth_user()
81 | if success:
82 | st.session_state.user_id = str(user_data.get("id"))
83 |
84 | main()
85 |
86 | with st.sidebar:
87 | if "user_id" in st.session_state and st.session_state.user_id:
 88 |             auth_button = 'Authenticated'
 89 | 
 90 |         else:
 91 |             auth_button = 'Not Authenticated'
92 |
93 | st.markdown(auth_button, unsafe_allow_html=True)
94 |
 95 |         st.markdown("<br>" * 0, unsafe_allow_html=True)
96 | st.markdown("---")
97 | st.markdown(
98 | '',
99 | unsafe_allow_html=True,
100 | )
101 |
102 |         version_info = """
103 | 
104 |         Build Number: $BUILD_NUMBER
105 |         Git Commit SHA: $GIT_SHA
106 |         Build Time (UTC): $BUILD_DATE
107 | 
108 |         """
109 |
110 | st.markdown(version_info, unsafe_allow_html=True)
111 |
--------------------------------------------------------------------------------
/app/page_apikeys.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from utils import display_user_login_warning
3 |
4 |
5 | def page_apikeys():
6 | st.title("Provide API Keys")
7 | st.subheader("Enter the API Keys for all services you want to use.")
8 |
9 | if display_user_login_warning():
10 | return
11 |
12 | if "api_keys" not in st.session_state:
13 | st.session_state.api_keys = {}
14 |
15 | # Create text area to input API keys
16 | input_text = st.text_area(
17 | "Enter the required API keys for your intended use. The keys will not get stored outside of Streamlits session state.\n\nProvide one 'name=key' pair per line, for example:",
18 | "OPENAI_API_KEY = your_openai_key\nANYSCALE_API_KEY = your_anyscale_key",
19 | )
20 |
21 | # Create a submit button
22 | if st.button("Submit"):
23 | store_in_cache(input_text)
24 | st.success("API keys stored successfully!")
25 |
26 | # Display stored API keys for testing purposes (you might want to remove this in a real application)
27 | if st.session_state.api_keys:
28 | keys = ""
29 | for key, value in st.session_state.api_keys.items():
30 | keys += f"{key}: {value}\n"
31 |
 32 |         st.markdown("<br>" * 1, unsafe_allow_html=True)
33 | st.code(keys)
34 |
35 |
36 | def store_in_cache(api_keys: str) -> None:
37 | """Writes api keys to streamlit session state."""
38 | st.session_state.api_keys = {}
39 |
40 | for line in api_keys.splitlines():
41 | if "=" in line:
42 | key, value = line.split("=", 1)
43 | st.session_state.api_keys[key.strip()] = value.strip()
44 |
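# Illustrative behavior (hypothetical input, not part of the app):
# store_in_cache("OPENAI_API_KEY = my_key") leaves
# st.session_state.api_keys == {"OPENAI_API_KEY": "my_key"}.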
--------------------------------------------------------------------------------
/app/page_chat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pandas as pd
4 | import streamlit as st
5 | from utils import *
6 |
7 | def page_chat():
8 | # Title of the page
9 | st.subheader("Chat with a RAG model from a hyperparameter evaluation")
10 |
11 | if display_user_login_warning():
12 | return
13 |
14 | if not os.path.exists(get_hyperparameters_results_path()):
15 | st.warning("No hyperparameter results available. Run some evaluation.")
16 | return
17 |
18 | if "messages" not in st.session_state:
19 | st.session_state.messages = [
20 | {"role": "assistant", "content": "How may I help you?"}
21 | ]
22 |
23 | with st.expander("View hyperparameter results"):
24 | df = load_hp_results()
25 | showData = st.multiselect(
26 | "Filter: ",
27 | df.columns,
28 | default=[
29 | "id",
30 | "chunk_size",
31 | "chunk_overlap",
32 | "length_function_name",
33 | "num_retrieved_docs",
34 | "search_type",
35 | "similarity_method",
36 | "use_llm_grader",
37 | ],
38 | )
39 | st.dataframe(df[showData], use_container_width=True)
40 |
41 | # select chat model from hyperparam run id
42 |     hp_id = st.selectbox(
43 |         "Select chat model from hyperparameter evaluations", options=list(df.id)
44 |     )
46 | # Check if hp_id changed and clear chat history if it did
47 | if (
48 | "hp_id" in st.session_state
49 | and hp_id != st.session_state.hp_id
50 | or "hp_id" not in st.session_state
51 | ):
52 | st.session_state.messages = [
53 | {"role": "assistant", "content": "How may I help you?"}
54 | ]
55 | # Update the session state with the new hp_id
56 | st.session_state.hp_id = hp_id
57 |
58 |     st.markdown("<br>" * 1, unsafe_allow_html=True)
59 | st.write("Chat with the chosen parametrized RAG")
60 |
61 | for message in st.session_state.messages:
62 | with st.chat_message(message["role"]):
63 | st.write(message["content"])
64 |
65 | # User-provided query
66 | if query := st.chat_input():
67 | st.session_state.messages.append({"role": "user", "content": query})
68 | with st.chat_message("user"):
69 | st.write(query)
70 |
71 | # Generate a new response if last message is not from assistant
72 | if st.session_state.messages[-1]["role"] != "assistant":
73 | with st.chat_message("assistant"):
74 | # get and display answer
75 | answer, source_docs = get_rag_response_stream(hp_id, query)
76 |
77 | # display retrieved documents
78 |             if "i don't know" not in answer.lower() and source_docs is not None:
79 | display_documents(source_docs)
80 |
81 | message = {"role": "assistant", "content": answer}
82 | st.session_state.messages.append(message)
83 |
84 |
 85 | def display_documents(documents: dict):
86 | for idx, doc in enumerate(documents["source_documents"]):
87 | with st.expander(f"Document {idx + 1}"):
88 | st.text(f"Source: {os.path.basename(doc['metadata']['source'])}")
89 | st.text(
90 | f"Index location: {doc['metadata']['start_index']}-{doc['metadata']['end_index']}"
91 | )
92 | st.text_area("Content", value=doc["page_content"], height=150)
93 |
94 |
95 | def retrieve_source_documents(hp_id: int, query: str):
96 | documents = get_docs_from_query(hp_id, query)
97 | display_documents(documents)
98 |
99 |
100 | def load_hp_results() -> pd.DataFrame:
101 | with open(get_hyperparameters_results_path(), encoding="utf-8") as file:
102 | hp_data = json.load(file)
103 |
104 | df = pd.DataFrame(hp_data)
105 | df["id"] = df["id"].astype(int)
106 |
107 | # Flatten the "scores" sub-dictionary
108 | scores_df = pd.json_normalize(df["scores"])
109 |
110 | # Combine the flattened scores DataFrame with the original DataFrame
111 | df = pd.concat(
112 | [df[["id"]], df.drop(columns=["id", "scores"], axis=1), scores_df], axis=1
113 | )
114 |
115 | # Print the resulting DataFrame
116 | return df
117 |
118 |
119 | def display_rag_response_stream(hp_id: int, query: str):
120 | # Start a separate thread for fetching stream data
121 | if "rag_response_stream_data" not in st.session_state:
122 | st.session_state["rag_response_stream_data"] = "Starting stream...\n"
123 |
124 | get_rag_response_stream(hp_id, query)
125 |
--------------------------------------------------------------------------------
/app/page_dashboard.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import pandas as pd
3 | import os
5 | import json
6 | import plotly.express as px
7 | from utils import *
8 | from utils import display_user_login_warning
9 |
10 |
11 | def page_dashboard():
12 | st.title("Dashboard Page")
13 | st.subheader("Analyse hyperparameter metrics and evaluation dataset.")
14 |
15 | if display_user_login_warning():
16 | return
17 |
18 | else:
19 | tab1, tab2, tab3, tab4 = st.tabs(
20 | [
21 | "Charts",
22 | "Evaluation dataset",
23 | "Generated predictions",
24 | "Combined dataset",
25 | ]
26 | )
27 |
28 | with tab1:
29 | if os.path.exists(get_hyperparameters_results_path()):
30 | plot_hyperparameters_results(get_hyperparameters_results_path())
31 | else:
32 | st.warning("No hyperparameter results available. Run some evaluation.")
33 |
34 | with tab2:
35 | if os.path.exists(get_label_dataset_path()):
36 | df_label_dataset = get_df_label_dataset()
37 |
38 | showData = st.multiselect(
39 | "Filter: ",
40 | df_label_dataset.columns,
41 | default=["question", "answer", "context", "source", "id"],
42 | )
43 | st.dataframe(df_label_dataset[showData], use_container_width=True)
44 | else:
45 | st.warning("No evaluation data available. Generate it.")
46 |
47 | with tab3:
48 | if os.path.exists(get_hyperparameters_results_data_path()):
49 | df_hp_runs = pd.read_csv(get_hyperparameters_results_data_path())
50 |
51 | showData = st.multiselect(
52 | "Filter: ",
53 | df_hp_runs.columns,
54 | default=[
55 | "hp_id",
56 | "predicted_answer",
57 | "retrieved_docs",
58 | "qa_id",
59 | ],
60 | )
61 | st.dataframe(df_hp_runs[showData], use_container_width=True)
62 | else:
63 | st.warning("No generated dataset from hyperparameter runs available.")
64 |
65 | with tab4:
66 | if os.path.exists(get_label_dataset_path()) and os.path.exists(
67 | get_hyperparameters_results_data_path()
68 | ):
69 | df_label_dataset4 = get_df_label_dataset()
70 | df_hp_runs4 = get_df_hp_runs()
71 |
72 | merged_df = df_label_dataset4.merge(
73 | df_hp_runs4, left_on="id", right_on="qa_id"
74 | )
75 | merged_df.drop(columns="id", inplace=True)
76 |
77 | showData = st.multiselect(
78 | "Filter: ",
79 | merged_df.columns,
80 | default=[
81 | "question",
82 | "answer",
83 | "predicted_answer",
84 | "retrieved_docs",
85 | "hp_id",
86 | "qa_id",
87 | "source",
88 | ],
89 | )
90 | st.dataframe(merged_df[showData], use_container_width=True)
91 | else:
92 | st.warning("Not sufficient data available.")
93 |
94 |
95 | def get_df_label_dataset() -> pd.DataFrame:
96 | df_label_dataset = pd.read_json(get_label_dataset_path())
97 | df_label_dataset = pd.concat(
98 | [df_label_dataset, pd.json_normalize(df_label_dataset["metadata"])],
99 | axis=1,
100 | )
101 | df_label_dataset = df_label_dataset.drop(columns=["metadata"])
102 |
103 | return df_label_dataset
104 |
105 |
106 | def get_df_hp_runs() -> pd.DataFrame:
107 | return pd.read_csv(get_hyperparameters_results_data_path())
108 |
109 |
110 | def plot_hyperparameters_results(hyperparameters_results_path: str):
111 | with open(hyperparameters_results_path, "r", encoding="utf-8") as file:
112 | hyperparameters_results = json.load(file)
113 |
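    # Expected shape of hyperparameters_results.json, inferred from the
    # code below (illustrative): a list of runs such as
    # [{"id": 0, "timestamp": "...", "scores": {"rouge1": 0.42, ...}}, ...]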
114 | # Convert the list of dictionaries to a DataFrame
115 | df = pd.DataFrame(hyperparameters_results)
116 |
117 | # Extract scores and timestamps into separate DataFrames
118 | scores_df = df["scores"].apply(pd.Series)
119 | scores_df["timestamp"] = pd.to_datetime(df["timestamp"])
120 | # Sort DataFrame based on timestamps for chronological order
121 | scores_df = scores_df.sort_values(by="timestamp")
122 | df = df.loc[scores_df.index]
123 | # Create a combined x-axis with incrementing number + timestamp string
124 | df["x_values"] = range(1, len(scores_df) + 1)
125 |     df["x_ticks"] = [f"{i}" for i in range(1, len(scores_df) + 1)]
126 | df["timestamp"] = scores_df["timestamp"]
127 | # Melt the scores_df DataFrame for Plotly plotting
128 | df_melted = pd.melt(
129 | scores_df.reset_index(),
130 | id_vars=["index"],
131 | value_vars=[
132 | "answer_similarity_score",
133 | "retriever_mrr@3",
134 | "retriever_mrr@5",
135 | "retriever_mrr@10",
136 | "rouge1",
137 | "rouge2",
138 | "rougeLCS",
139 | "correctness_score",
140 | "comprehensiveness_score",
141 | "readability_score",
142 | "retriever_semantic_accuracy",
143 | ],
144 | )
145 | # Merge on 'index' to get the correct x_values and x_ticks
146 | df_melted = df_melted.merge(
147 | df[["x_values", "x_ticks", "timestamp"]],
148 | left_on="index",
149 | right_index=True,
150 | how="left",
151 | )
152 | df_melted = df_melted[df_melted.value >= 0]
153 |
154 | # Plot using Plotly Express
155 | fig = px.scatter(
156 | df_melted,
157 | x="x_values",
158 | y="value",
159 | color="variable",
160 | hover_data=["x_ticks", "variable", "value", "timestamp"],
161 | # color_discrete_sequence=px.colors.sequential.Viridis,
162 | labels={"x_values": "Hyperparameter run id", "value": "Scores"},
163 | title="Hyperparameter chart",
164 | )
165 |
166 | fig.update_layout(
167 | # xaxis_tickangle=-45,
168 | xaxis=dict(tickvals=df["x_values"], ticktext=df["x_ticks"]),
169 | yaxis=dict(tickvals=[i / 10.0 for i in range(11)]),
170 | plot_bgcolor="#F5F5DC",
171 | paper_bgcolor="#121212",
172 | height=600, # Set the height of the plot
173 | width=950, # Set the width of the plot
174 | )
175 | st.plotly_chart(fig)
176 |
--------------------------------------------------------------------------------
/app/page_documentstore.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from utils import *
3 | from utils import display_user_login_warning
4 |
5 |
6 | def page_documentstore():
7 | # Title of the page
8 | st.title("Document Store Page")
9 | st.subheader("Provide the documents that the application should use.")
10 |
11 | if display_user_login_warning():
12 | return
13 |
14 | else:
15 | tab1, tab2 = st.tabs(["Upload Documents", "Provide cloud storage"])
16 |
17 | with tab1:
18 | upload_files(
19 | context="docs",
20 | dropdown_msg="Upload your documents",
21 | ext_list=["pdf", "txt", "docx"],
22 | file_path=get_document_store_path(),
23 | allow_multiple_files=True,
24 | )
25 |
26 | with tab2:
27 | st.write("Not implemented yet.")
28 | # Provide link to cloud resource
29 | cloud_link = st.text_input("Provide a link to your cloud resource:")
30 |
31 | if cloud_link:
32 | st.write(f"You provided the link: {cloud_link}")
33 |
 34 |         st.markdown("<br>" * 1, unsafe_allow_html=True)
35 |
36 | # List all files in the directory
37 | st.subheader("Your Document Store")
38 | path = get_document_store_path()
39 | if path:
40 | structure = ptree(path)
41 | st.code(structure, language="plaintext")
42 |
--------------------------------------------------------------------------------
/app/page_filemanager.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import os
3 | import zipfile
4 | import io
5 | from utils import (
6 | list_files_in_directory,
7 | get_user_directory,
8 | ptree,
9 | display_user_login_warning,
10 | )
11 |
12 | import time
13 |
14 |
15 | def page_filemanager():
16 | st.title("File Manager")
 17 |     st.subheader("Manage the files in your directory.")
18 |
19 | if display_user_login_warning():
20 | return
21 |
22 | else:
23 | tab1, tab2 = st.tabs(["Download files", "Delete files"])
24 |
25 | with tab1:
26 | files = list_files_in_directory(get_user_directory())
27 | selected_files = st.multiselect("Select files to download:", files)
28 |
29 | if len(selected_files) > 0 and st.button("Select Files"):
30 | # Create a zip archive in-memory
31 | zip_buffer = io.BytesIO()
32 | with zipfile.ZipFile(
33 | zip_buffer, "a", zipfile.ZIP_DEFLATED, False
34 | ) as zip_file:
35 | for selected in selected_files:
36 | file_path = os.path.join(get_user_directory(), selected)
37 | zip_file.write(file_path, selected)
38 |
39 | zip_buffer.seek(0)
40 | st.download_button(
41 | label="Download Files",
42 | data=zip_buffer,
43 | file_name="files.zip",
44 | mime="application/zip",
45 | )
46 |
47 | with tab2:
48 | files = list_files_in_directory(get_user_directory())
49 | selected_files = st.multiselect("Select files to delete:", files)
50 |
51 | if len(selected_files) > 0 and st.button("Delete Files"):
52 | action_success = True
53 | for selected in selected_files:
54 | file_path = os.path.join(get_user_directory(), selected)
55 | try:
56 | os.remove(file_path)
57 | except Exception as e:
58 | st.error(f"Error deleting {selected}: {e}")
59 | action_success = False
60 |
61 | if action_success:
 62 |                     st.success("Deleted selected files successfully!")
63 |
 64 |         st.markdown("<br>" * 1, unsafe_allow_html=True)
65 |
66 | # List all files in the directory
67 | st.subheader("User directory")
68 | path = get_user_directory()
 69 |         if path or st.session_state.get("refresh_trigger", False):
70 | structure = ptree(path)
71 | st.code(structure, language="plaintext")
72 |
--------------------------------------------------------------------------------
/app/page_home.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 |
4 | # Define pages and their functions
5 | def page_home():
6 | st.title("Welcome to RAGflow!")
7 |
8 | tab1, tab2, tab3 = st.tabs(["The App", "Getting Started", "About"])
9 |
10 | with tab1:
11 | # Introductory text
12 | st.write(
13 | """
14 | RAGflow is an advanced application framework tailored to streamline the construction and evaluation processes for Retrieval Augmented Generation (RAG) systems in question-answer contexts. Here's a brief rundown of its functionality:
15 | """
16 | )
17 |
18 | # Functionality points
19 | st.header("Key Features & Functionalities")
20 |
21 | # 1. Document Store Integration
22 | st.subheader("1. Document Store Integration")
23 | st.write(
24 | """
25 | RAGflow starts by interfacing with a variety of document stores, enabling users to select from a diverse range of data sources.
26 | """
27 | )
28 |
29 | # 2. Dynamic Parameter Selection
30 | st.subheader("2. Dynamic Parameter Selection")
31 | st.write(
32 | """
33 | Through an intuitive interface, users can customize various parameters, including document splitting strategies, embedding model choices, and question-answering LLMs. These parameters influence how the app splits, encodes, and processes data.
34 | """
35 | )
36 |
37 | # 3. Automated Dataset Generation
38 | st.subheader("3. Automated Dataset Generation")
39 | st.write(
40 | """
41 | One of RAGflow's core strengths is its capability to auto-generate datasets. This feature allows for a seamless transition from raw data to a structured format ready for processing.
42 | """
43 | )
44 |
45 | # 4. Advanced Retrieval Mechanisms
46 | st.subheader("4. Advanced Retrieval Mechanisms")
47 | st.write(
48 | """
49 | Incorporating state-of-the-art retrieval methods, RAGflow employs techniques like "MMR" for filtering and the SelfQueryRetriever for targeted data extraction. This ensures that the most relevant document chunks are presented for question-answering tasks.
50 | """
51 | )
52 |
53 | # 5. Integration with External Platforms
54 | st.subheader("5. Integration with External Platforms")
55 | st.write(
56 | """
57 | RAGflow is designed to work in tandem with platforms like Anyscale, MosaicML, and Replicate, offering users extended functionalities and integration possibilities.
58 | """
59 | )
60 |
61 | # 6. Evaluation & Optimization
62 | st.subheader("6. Evaluation & Optimization")
63 | st.write(
64 | """
65 | At its core, RAGflow is built to optimize performance. By performing grid searches across the selected parameter space, it ensures that users achieve the best possible results for their specific configurations.
66 | """
67 | )
68 |
69 | # 7. Interactive Feedback Loop
70 | st.subheader("7. Interactive Feedback Loop")
71 | st.write(
72 | """
73 | As users interact with the generated question-answer systems, RAGflow offers feedback mechanisms, allowing for continuous improvement and refinement based on real-world results.
74 | """
75 | )
76 |
77 | # Conclusion
78 | st.write(
79 | """
80 | Dive into RAGflow, set your parameters, and watch as it automates and optimizes the intricate process of building robust, data-driven question-answering systems!
81 | """
82 | )
83 |
84 | with tab2:
85 | # Functional stages
86 | st.header("Using This App: A Step-by-Step Guide")
87 |
88 | st.write(
89 | """
90 | 1. Account Setup:
91 | Log in or register to get access to all services.
92 |
93 | 2. API Key Submission:
94 | Supply all the required API keys, such as those for OpenAI. The keys needed will vary based on the LLMs and services you intend to use. Keys will only get stored locally until page reloads.
95 |
96 | 3. Document Upload:
97 | Upload the documents for which you aim to create a RAG application under 'Documents'.
98 |
99 | 4. Parameter Configuration:
100 | Input the necessary parameters directly or by uploading a JSON file in 'Parameters'. This is essential for both the evaluation dataset generation in the 'QA Generator settings' tab and for the hyperparameter assessment in the 'Hyperparameters settings' tab.
101 |
102 | 5. Dashboard Analysis:
103 | Navigate to the 'Dashboard' to inspect hyperparameter metrics and view the generated data.
104 |
105 | 6. File Management:
106 | The 'File Manager' allows you to delete or download files within your directory for added flexibility.
107 | """
108 | )
109 |
110 | with tab3:
111 | st.subheader("About Me")
112 | st.write(
113 | """
114 | Hello! I'm a passionate developer and AI enthusiast working on cutting-edge projects to leverage the power of machine learning.
115 | My latest project, RAGflow, focuses on building and evaluating Retrieval Augmented Generation (RAG) systems.
116 | Curious about my work? Check out my [GitHub repository](https://github.com/AndreasX42/RAGflow) for more details!
117 | """
118 | )
119 |
--------------------------------------------------------------------------------
/app/page_login.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import time
3 |
4 | from utils import *
5 |
6 |
7 | def page_login():
8 | # Session State to track if user is logged in
9 | if "logged_in" not in st.session_state:
10 | st.session_state.logged_in = False
11 |
12 | if "user_id" not in st.session_state:
13 | st.session_state.user_id = ""
14 |
15 | # Tabs for Login and Registration
16 | tab1, tab2 = st.tabs(["Login", "Register"])
17 |
18 | # Login Tab
19 | with tab1:
20 | # Check if user is already logged in
21 | response, success = get_auth_user()
22 | if success:
23 | st.session_state.user_id = str(response.get("id"))
24 | st.success(f"Already signed in as '{response.get('username')}'.")
25 |
26 | elif not st.session_state.logged_in:
27 | with st.form("login_form"):
28 | st.subheader("Login")
29 | # Input fields for username and password
30 | login_username = st.text_input("Username", key="login_username")
31 | login_password = st.text_input(
32 | "Password", type="password", key="login_password"
33 | )
34 |
35 | # Submit button for the form
36 | submit_button = st.form_submit_button("Login")
37 |
38 | if submit_button:
39 | # Attempt to login
40 | response, success = user_login(login_username, login_password)
41 | if success:
42 | st.session_state.logged_in = True
43 | st.session_state.user_id = str(response.get("id"))
44 | st.success(
45 | f"Logged in successfully as {response.get('username')}."
46 | )
47 | else:
48 | st.error(
49 | f"Login failed: {response.get('detail', 'Unknown error')}"
50 | )
51 |
52 | if get_cookie_value() or st.session_state.user_id:
53 | with st.form("logout_form"):
54 | st.subheader("Logout")
55 | logout_button = st.form_submit_button("Logout")
56 | if logout_button:
57 | success = user_logout()
58 |
59 | if success:
60 | st.session_state.logged_in = False
61 | st.session_state.user_id = ""
62 | st.success("You are logged out!")
63 |
64 | else:
65 | st.error("Error logging out")
66 |
67 | time.sleep(1)
68 | st.rerun()
69 |
70 | # Registration Tab
71 | with tab2:
72 | with st.form("register_form"):
73 | st.subheader("Register")
74 | reg_username = st.text_input(
75 | "Username",
76 | key="reg_username",
77 | help="Username must be between 4 and 64 characters.",
78 | )
79 |
80 | reg_email = st.text_input("Email", key="reg_email")
81 |
82 | reg_password = st.text_input(
83 | "Password",
84 | type="password",
85 | key="reg_password",
86 | help="Password must be between 8 and 128 characters.",
87 | )
88 |
89 | if st.form_submit_button("Register"):
90 | reg_response, success = user_register(
91 | reg_username, reg_email, reg_password
92 | )
93 |
94 | if success:
95 | st.success(
96 | f"Registered successfully! Please log in, {reg_response.get('username')}."
97 | )
98 |
99 | else:
100 | st.error(f"Registration failed! Info: {reg_response['detail']}")
101 |
--------------------------------------------------------------------------------
/app/page_parameters.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import streamlit as st
4 | from utils import *
5 |
7 |
8 |
9 | def page_parameters():
10 | st.title("Parameters Page")
11 | st.subheader("Provide parameters for building and evaluating the system.")
12 |
13 | if display_user_login_warning():
14 | return
15 |
16 | tab1, tab2 = st.tabs(["QA Generator settings", "Hyperparameters settings"])
17 |
18 | valid_data = get_valid_params()
19 |
20 | with tab1:
21 | st.write(
22 | "The QA generator provides the possibility to generate question-context-answer triples from the provided documents that are used to evaluate hyperparameters and the corresponding RAG model in a consecutive step. You can either provide parameters for the generator through the drop-down menu below or by uploading a JSON file."
23 | )
24 |
25 | provide_qa_gen_form(valid_data)
26 |
27 | upload_files(
28 | context="qa_params",
29 | dropdown_msg="Upload JSON file",
30 | ext_list=["json"],
31 | file_path=get_label_dataset_gen_params_path(),
32 | )
33 |
 34 |         st.markdown("<br>" * 1, unsafe_allow_html=True)
35 |
36 | submit_button = st.button("Start evaluation dataset generation", key="SubmitQA")
37 |
38 | if submit_button:
39 | with st.spinner("Running generator..."):
40 | result = start_qa_gen()
41 |
42 | if "Success" in result:
43 | st.success(result)
44 | else:
45 | st.error(result)
46 |
47 | with tab2:
 48 |         st.write(
 49 |             "The hyperparameter evaluator provides functionality for benchmarking RAG models built with the corresponding parameters. During evaluation an LLM predicts an answer from the provided query and the retrieved document chunks. With that we can calculate embedding similarities between label and predicted answers as well as ROUGE scores to provide some metrics. We can also provide an LLM that is used for grading the predicted answers and the retrieved documents to extract even more metrics."
 50 |         )
51 |
52 | provide_hp_params_form(valid_data)
53 |
54 | upload_files(
55 | context="hp_params",
56 | dropdown_msg="Upload JSON file",
57 | ext_list=["json"],
58 | file_path=get_hyperparameters_path(),
59 | )
60 |
 61 |         st.markdown("<br>" * 1, unsafe_allow_html=True)
62 |
63 | submit_button = st.button("Start hyperparameter evaluation", key="SubmitHP")
64 |
65 | if submit_button:
66 | with st.spinner("Running hyperparameter evaluation..."):
67 | result = start_hp_run()
68 |
69 | if "Success" in result:
70 | st.success(result)
71 | else:
72 | st.error(result)
73 |
74 |
75 | def provide_hp_params_form(valid_data: dict):
76 | with st.expander("Hyperparameters settings"):
77 | attributes2 = {
78 | "chunk_size": st.number_input("Chunk Size", value=512, key="chunk size"),
79 | "chunk_overlap": st.number_input(
80 | "Chunk Overlap", value=10, key="chunk overlap"
81 | ),
82 | "length_function_name": st.selectbox(
83 | "Length Function for splitting or embedding model for corresponding tokenizer",
84 | valid_data["embedding_models"] + ["len"],
85 | key="len_tab2",
86 | ),
87 | "num_retrieved_docs": st.number_input(
88 | "Number of Docs the retriever should return",
89 | value=3,
90 | key="number of docs to retrieve",
91 | ),
92 | "similarity_method": st.selectbox(
93 | "Retriever Similarity Method", valid_data["retr_sim_method"]
94 | ),
95 | "search_type": st.selectbox(
96 | "Retriever Search Type", valid_data["retr_search_types"]
97 | ),
98 | "embedding_model": st.selectbox(
99 | "Embedding Model",
100 | valid_data["embedding_models"],
101 | key="Name of embedding model.",
102 | ),
103 | "qa_llm": st.selectbox(
104 | "QA Language Model",
105 | valid_data["llm_models"],
106 | key="Name of LLM for QA task.",
107 | ),
108 | }
109 |
110 | use_llm_grader = st.checkbox(
111 | "Use LLM Grader", value=False, key="use_llm_grader_checkbox"
112 | )
113 |
114 | if use_llm_grader:
115 | attributes2["grade_answer_prompt"] = st.selectbox(
116 | "Grade Answer Prompt",
117 | valid_data["grade_answer_prompts"],
118 | key="Type of prompt to use for answer grading.",
119 | )
120 | attributes2["grade_docs_prompt"] = st.selectbox(
121 | "Grade Documents Prompt",
122 | valid_data["grade_documents_prompts"],
123 | key="Type of prompt to grade retrieved document chunks.",
124 | )
125 | attributes2["grader_llm"] = st.selectbox(
126 | "Grading LLM",
127 | valid_data["llm_models"],
128 | key="Name of LLM for grading.",
129 | )
130 |
131 | submit_button = st.button("Submit", key="Submit2")
132 | if submit_button:
133 | attributes2["use_llm_grader"] = use_llm_grader
134 | st.write("Saved to file. You've entered the following values:")
135 | # This is just a mockup to show you how to display the attributes.
136 | # You'll probably want to process or display them differently.
137 | st.write(attributes2)
138 | write_json(
139 | attributes2,
140 | get_hyperparameters_path(),
141 | append=True,
142 | )
143 |
144 |
145 | def provide_qa_gen_form(valid_data: dict):
146 | with st.expander("Provide QA Generator settings form"):
147 | attributes = {
148 | "chunk_size": st.number_input("Chunk Size", value=512),
149 | "chunk_overlap": st.number_input("Chunk Overlap", value=10),
150 | "length_function_name": st.selectbox(
151 | "Length Function for splitting or embedding model for corresponding tokenizer",
152 | valid_data["embedding_models"] + ["len"],
153 | ),
154 | "qa_generator_llm": st.selectbox(
155 | "QA Language Model", valid_data["llm_models"]
156 | ),
157 | }
158 |
159 | persist_to_vs = st.checkbox(
160 | "Cache answer embeddings in ChromaDB for each of the models listed below",
161 | value=False,
162 | key="persist_to_vs",
163 | )
164 |
165 | # Get input from the user
166 | if persist_to_vs:
167 | input_text = st.text_area(
168 | "Enter a list of embedding model names for caching the generated answer embeddings in chromadb (one per line):"
169 | )
170 |
171 | submit_button = st.button("Submit", key="submit1")
172 | if submit_button:
173 | # save bool in dict
174 | attributes["persist_to_vs"] = persist_to_vs
175 |
176 | if persist_to_vs:
177 | # Store the list in a dictionary
178 | # Split the text by newline to get a list
179 | text_list = input_text.split("\n")
180 | attributes["embedding_model_list"] = text_list
181 | else:
182 | attributes["embedding_model_list"] = []
183 |
184 | with st.spinner("Saving to file."):
185 | st.write("Saved to file. You've entered the following values:")
186 | # This is just a mockup to show you how to display the attributes.
187 | # You'll probably want to process or display them differently.
188 | st.write(attributes)
189 | # save json
190 | write_json(
191 | attributes,
192 | get_label_dataset_gen_params_path(),
193 | )
194 |
--------------------------------------------------------------------------------
/app/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit==1.28.1
2 | streamlit_option_menu==0.3.6
3 |
4 | plotly==5.18.0
5 | pandas==2.1.2
6 | numerize==0.12
7 | matplotlib==3.8.1
8 | seaborn==0.13.0
9 |
10 | openpyxl==3.1.2
--------------------------------------------------------------------------------
/docker-compose.dev.yaml:
--------------------------------------------------------------------------------
1 | version: "3"
2 | services:
3 | app:
4 | build:
5 | dockerfile: Dockerfile
6 | context: ./app
7 | container_name: app-frontend
8 | ports:
9 | - "8501:8501"
10 | volumes:
11 | - ./app:/app
12 | - ./tmp:/app/tmp
13 | environment:
14 | - RAGFLOW_HOST=ragflow-backend
15 | - RAGFLOW_PORT=8080
16 | restart: always
17 | ragflow-backend:
18 | build:
19 | dockerfile: Dockerfile
20 | context: ./ragflow
21 | container_name: ragflow-backend
22 | ports:
23 | - "8080:8080"
24 | volumes:
25 | - ./ragflow:/backend/ragflow
26 | - ./tmp:/backend/tmp
27 | environment:
28 | - JWT_SECRET_KEY=brjia5mOUlE3RN0CFy
29 | - POSTGRES_DATABASE=postgres_db
30 | - PGVECTOR_DATABASE=vector_db
31 | - POSTGRES_USER=admin
32 | - POSTGRES_PASSWORD=my_password
33 | - POSTGRES_HOST=postgres
34 | - POSTGRES_PORT=5432
35 | - POSTGRES_DRIVER=psycopg2
36 | - EXECUTION_CONTEXT=DEV
37 | - LOG_LEVEL=INFO
38 | - CHROMADB_HOST=chromadb
39 | - CHROMADB_PORT=8000
40 | restart: always
41 | chromadb:
42 | build:
43 | dockerfile: Dockerfile
44 | context: ./vectorstore
45 | container_name: chromadb
46 | ports:
47 | - "8000:8000"
48 | volumes:
49 | - ./vectorstore/chroma:/chroma/chroma
50 | - ./vectorstore:/vectorstore
51 | environment:
52 | - IS_PERSISTENT=TRUE
53 | - ALLOW_RESET=TRUE
54 | restart: always
55 | postgres:
56 | image: ankane/pgvector:v0.5.1
57 | container_name: postgres
58 | ports:
59 | - "5432:5432"
60 | volumes:
61 | - ./postgres/data:/var/lib/postgresql/data
62 | environment:
63 | - POSTGRES_USER=admin
64 | - POSTGRES_PASSWORD=my_password
65 | restart: always
66 | pgadmin:
67 | image: dpage/pgadmin4:8.0
68 | container_name: pgadmin
69 | environment:
70 | - PGADMIN_DEFAULT_EMAIL=admin@admin.com
71 | - PGADMIN_DEFAULT_PASSWORD=my_password
72 | ports:
73 | - "5050:80"
74 | restart: always
75 |
--------------------------------------------------------------------------------
/docker-compose.integration.test.yaml:
--------------------------------------------------------------------------------
1 | version: "3"
2 | services:
3 | ragflow-test:
4 | image: andreasx42/ragflow-backend:latest
5 | container_name: ragflow-test
6 | ports:
7 | - "8080:8080"
8 | environment:
9 | - JWT_SECRET_KEY=brjia5mOUlE3RN0CFy
10 | - POSTGRES_DATABASE=postgres
11 | - PGVECTOR_DATABASE=postgres
12 | - POSTGRES_USER=admin
13 | - POSTGRES_PASSWORD=my_password
14 | - POSTGRES_HOST=postgres-test
15 | - POSTGRES_PORT=5432
16 | - POSTGRES_DRIVER=psycopg2
17 | - EXECUTION_CONTEXT=TEST
18 | - LOG_LEVEL=INFO
19 | - CHROMADB_HOST=chromadb-test
20 | - CHROMADB_PORT=8000
21 | - INPUT_LABEL_DATASET=./resources/input_label_dataset.json
22 | volumes:
23 | - ./resources/tests:/backend/resources
24 | chromadb-test:
25 | image: andreasx42/ragflow-vectorstore:latest
26 | container_name: chromadb-test
27 | ports:
28 | - 8000:8000
29 | environment:
30 | - IS_PERSISTENT=TRUE
31 | - ALLOW_RESET=TRUE
32 | test-suite:
33 | image: andreasx42/ragflow-test:latest
34 | container_name: tester
35 | environment:
36 | - RAGFLOW_HOST=ragflow-test
37 | - RAGFLOW_PORT=8080
38 | - CHROMADB_HOST=chromadb-test
39 | - CHROMADB_PORT=8000
40 | depends_on:
41 | - ragflow-test
42 | - chromadb-test
43 | volumes:
44 | - ./resources/tests:/tests/resources
45 | postgres-test:
46 | image: ankane/pgvector:v0.5.1
47 | container_name: postgres-test
48 | ports:
49 | - "5432:5432"
50 | environment:
51 | - POSTGRES_USER=admin
52 | - POSTGRES_PASSWORD=my_password
53 | restart: always
54 |
--------------------------------------------------------------------------------
/docker-compose.local.test.yaml:
--------------------------------------------------------------------------------
1 | version: "3"
2 | services:
3 | ragflow-test:
4 | build:
5 | dockerfile: Dockerfile
6 | context: ./ragflow
7 | container_name: ragflow-test
8 | ports:
9 | - "8080:8080"
10 | environment:
11 | - JWT_SECRET_KEY=brjia5mOUlE3RN0CFy
12 | - POSTGRES_DATABASE=postgres
13 | - PGVECTOR_DATABASE=postgres
14 | - POSTGRES_USER=admin
15 | - POSTGRES_PASSWORD=my_password
16 | - POSTGRES_HOST=postgres-test
17 | - POSTGRES_PORT=5432
18 | - POSTGRES_DRIVER=psycopg2
19 | - EXECUTION_CONTEXT=TEST
20 | - LOG_LEVEL=INFO
21 | - CHROMADB_HOST=chromadb-test
22 | - CHROMADB_PORT=8000
23 | - INPUT_LABEL_DATASET=./resources/input_label_dataset.json
24 | volumes:
25 | - ./ragflow:/backend/ragflow
26 | - ./resources/tests:/backend/resources
27 | chromadb-test:
28 | build:
29 | dockerfile: Dockerfile
30 | context: ./vectorstore
31 | container_name: chromadb-test
32 | ports:
33 | - 8000:8000
34 | environment:
35 | - IS_PERSISTENT=TRUE
36 | - ALLOW_RESET=TRUE
37 | test-suite:
38 | build:
39 | dockerfile: Dockerfile
40 | context: ./tests
41 | container_name: tester
42 | environment:
43 | - RAGFLOW_HOST=ragflow-test
44 | - RAGFLOW_PORT=8080
45 | - CHROMADB_HOST=chromadb-test
46 | - CHROMADB_PORT=8000
47 | depends_on:
48 | - ragflow-test
49 | - chromadb-test
50 | volumes:
51 | - ./resources/tests:/tests/resources
52 | - ./tests:/tests
53 | postgres-test:
54 | image: ankane/pgvector:v0.5.1
55 | container_name: postgres-test
56 | ports:
57 | - "5432:5432"
58 | environment:
59 | - POSTGRES_USER=admin
60 | - POSTGRES_PASSWORD=my_password
61 | restart: always
62 |
--------------------------------------------------------------------------------
/k8s/app-cluster-ip-service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: app-cluster-ip-service
5 | spec:
6 | type: ClusterIP
7 | selector:
8 | component: frontend
9 | ports:
10 | - port: 8501
11 | targetPort: 8501
12 |
--------------------------------------------------------------------------------
/k8s/app-deployment.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: app-deployment
5 | spec:
6 | replicas: 1
7 | selector:
8 | matchLabels:
9 | component: frontend
10 | template:
11 | metadata:
12 | labels:
13 | component: frontend
14 | spec:
15 | volumes:
16 | - name: frontend-storage
17 | persistentVolumeClaim:
18 | claimName: shared-persistent-volume-claim
19 | containers:
20 | - name: app-frontend
21 | image: andreasx42/ragflow-app:latest
22 | env:
23 | - name: RAGFLOW_HOST
24 | value: backend-cluster-ip-service
25 | - name: RAGFLOW_PORT
26 | value: '8080'
27 | ports:
28 | - containerPort: 8501
29 | volumeMounts:
30 | - name: frontend-storage
31 |               mountPath: /app/tmp
32 | subPath: tmp
--------------------------------------------------------------------------------
/k8s/backend-cluster-ip-service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: backend-cluster-ip-service
5 | spec:
6 | type: ClusterIP
7 | selector:
8 | component: backend
9 | ports:
10 | - port: 8080
11 | targetPort: 8080
12 |
--------------------------------------------------------------------------------
/k8s/backend-deployment.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: backend-deployment
5 | spec:
6 | replicas: 1
7 | selector:
8 | matchLabels:
9 | component: backend
10 | template:
11 | metadata:
12 | labels:
13 | component: backend
14 | spec:
15 | volumes:
16 | - name: backend-storage
17 | persistentVolumeClaim:
18 | claimName: shared-persistent-volume-claim
19 | containers:
20 | - name: ragflow-backend
21 | image: andreasx42/ragflow-backend:latest
22 | env:
23 | - name: CHROMADB_HOST
24 | value: vectorstore-cluster-ip-service
25 | - name: CHROMADB_PORT
26 | value: '8000'
27 | - name: POSTGRES_HOST
28 | value: postgres-cluster-ip-service
29 | - name: POSTGRES_PORT
30 | value: '5432'
31 | - name: POSTGRES_DRIVER
32 | value: psycopg2
33 | - name: POSTGRES_PASSWORD
34 | valueFrom:
35 | secretKeyRef:
36 | name: pgsecrets
37 | key: POSTGRES_PASSWORD
38 | - name: POSTGRES_USER
39 | valueFrom:
40 | secretKeyRef:
41 | name: pgsecrets
42 | key: POSTGRES_USER
43 | - name: POSTGRES_DATABASE
44 | valueFrom:
45 | secretKeyRef:
46 | name: pgsecrets
47 | key: POSTGRES_DATABASE
48 | - name: PGVECTOR_DATABASE
49 | valueFrom:
50 | secretKeyRef:
51 | name: pgsecrets
52 | key: PGVECTOR_DATABASE
53 | - name: JWT_SECRET_KEY
54 | valueFrom:
55 | secretKeyRef:
56 | name: pgsecrets
57 | key: JWT_SECRET_KEY
58 | ports:
59 | - containerPort: 8080
60 | volumeMounts:
61 | - name: backend-storage
62 |               mountPath: /backend/tmp
63 | subPath: tmp
--------------------------------------------------------------------------------
/k8s/ingress-service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: networking.k8s.io/v1
2 | kind: Ingress
3 | metadata:
4 | name: ingress-service
5 | annotations:
6 | nginx.ingress.kubernetes.io/use-regex: 'true'
7 | nginx.ingress.kubernetes.io/ssl-redirect: 'false'
8 | nginx.ingress.kubernetes.io/configuration-snippet: |
9 | proxy_set_header X-Script-Name /pgadmin;
10 | nginx.ingress.kubernetes.io/rewrite-target: /$1
11 | spec:
12 | ingressClassName: nginx
13 | rules:
14 | - http:
15 | paths:
16 | - path: /api/?(.*)
17 | pathType: ImplementationSpecific
18 | backend:
19 | service:
20 | name: backend-cluster-ip-service
21 | port:
22 | number: 8080
23 | - path: /pgadmin/?(.*)
24 | pathType: ImplementationSpecific
25 | backend:
26 | service:
27 | name: pgadmin-cluster-ip-service
28 | port:
29 | number: 5050
30 | - path: /?(.*)
31 | pathType: ImplementationSpecific
32 | backend:
33 | service:
34 | name: app-cluster-ip-service
35 | port:
36 | number: 8501
--------------------------------------------------------------------------------
/k8s/pgadmin-cluster-ip-service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: pgadmin-cluster-ip-service
5 | spec:
6 | type: ClusterIP
7 | selector:
8 | component: pgadmin
9 | ports:
10 | - port: 5050
11 | targetPort: 80
12 |
--------------------------------------------------------------------------------
/k8s/pgadmin-deployment.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: pgadmin-deployment
5 | spec:
6 | replicas: 1
7 | selector:
8 | matchLabels:
9 | component: pgadmin
10 | template:
11 | metadata:
12 | labels:
13 | component: pgadmin
14 | spec:
15 | containers:
16 | - name: pgadmin
17 | image: dpage/pgadmin4:8.0
18 | ports:
19 | - containerPort: 80
20 | env:
21 | - name: PGADMIN_DEFAULT_EMAIL
22 | valueFrom:
23 | secretKeyRef:
24 | name: pgsecrets
25 | key: PGADMIN_DEFAULT_EMAIL
26 | - name: PGADMIN_DEFAULT_PASSWORD
27 | valueFrom:
28 | secretKeyRef:
29 | name: pgsecrets
30 | key: POSTGRES_PASSWORD
31 |
--------------------------------------------------------------------------------
/k8s/postgres-cluster-ip-service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: postgres-cluster-ip-service
5 | spec:
6 | type: ClusterIP
7 | selector:
8 | component: postgres
9 | ports:
10 | - port: 5432
11 | targetPort: 5432
12 |
--------------------------------------------------------------------------------
/k8s/postgres-deployment.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: postgres-deployment
5 | spec:
6 | replicas: 1
7 | selector:
8 | matchLabels:
9 | component: postgres
10 | template:
11 | metadata:
12 | labels:
13 | component: postgres
14 | spec:
15 | volumes:
16 | - name: postgres-storage
17 | persistentVolumeClaim:
18 | claimName: shared-persistent-volume-claim
19 | containers:
20 | - name: postgres
21 | image: ankane/pgvector:v0.5.1
22 | ports:
23 | - containerPort: 5432
24 | volumeMounts:
25 | - name: postgres-storage
26 | mountPath: /var/lib/postgresql/data
27 | subPath: dbs/postgres
28 | env:
29 | - name: POSTGRES_PASSWORD
30 | valueFrom:
31 | secretKeyRef:
32 | name: pgsecrets
33 | key: POSTGRES_PASSWORD
34 | - name: POSTGRES_USER
35 | valueFrom:
36 | secretKeyRef:
37 | name: pgsecrets
38 | key: POSTGRES_USER
39 |
--------------------------------------------------------------------------------
/k8s/shared-persistent-volume-claim.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: PersistentVolumeClaim
3 | metadata:
4 | name: shared-persistent-volume-claim
5 | spec:
6 | accessModes:
7 | - ReadWriteOnce
8 | resources:
9 | requests:
10 | storage: 2Gi
--------------------------------------------------------------------------------
/k8s/vectorstore-cluster-ip-service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: vectorstore-cluster-ip-service
5 | spec:
6 | type: ClusterIP
7 | selector:
8 | component: vectorstore
9 | ports:
10 | - port: 8000
11 | targetPort: 8000
12 |
--------------------------------------------------------------------------------
/k8s/vectorstore-deployment.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: vectorstore-deployment
5 | spec:
6 | replicas: 1
7 | selector:
8 | matchLabels:
9 | component: vectorstore
10 | template:
11 | metadata:
12 | labels:
13 | component: vectorstore
14 | spec:
15 | volumes:
16 | - name: vectorstore-storage
17 | persistentVolumeClaim:
18 | claimName: shared-persistent-volume-claim
19 | containers:
20 | - name: chromadb-vectorstore
21 | image: andreasx42/ragflow-vectorstore:latest
22 | env:
23 | - name: IS_PERSISTENT
24 | value: "TRUE"
25 | - name: ALLOW_RESET
26 | value: "TRUE"
27 | ports:
28 | - containerPort: 8000
29 | volumeMounts:
30 | - name: vectorstore-storage
31 |               mountPath: /chroma/chroma
32 | subPath: dbs/chromadb
33 |
--------------------------------------------------------------------------------
/ragflow/.dockerignore:
--------------------------------------------------------------------------------
1 | # git
2 | .git
3 | .gitignore
4 | .cache
5 | **.log
6 |
7 | # data/chromadb
8 | data/datasets.zip
9 | data/desc.docx
10 | chromadb/
11 |
12 | # Byte-compiled / optimized / DLL files
13 | **/__pycache__
14 | **/*.pyc
15 | **.py[cod]
16 | *$py.class
17 | .vscode/
18 |
19 | # Jupyter Notebook
20 | **/.ipynb_checkpoints
21 | notebooks/
22 |
23 | # dotenv
24 | .env
25 |
26 | # virtualenv
27 | .venv*
28 | venv*/
29 | ENV*/
--------------------------------------------------------------------------------
/ragflow/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11
2 |
3 | # Port the app is running on
4 | EXPOSE 8080
5 |
6 | # Install dependencies
7 | WORKDIR /backend
8 |
9 | COPY ./requirements.txt ./
10 | RUN pip install --no-cache-dir --upgrade -r ./requirements.txt
11 |
12 | # Copy all into image
13 | COPY ./ ./ragflow
14 |
15 | ENV PYTHONPATH="${PYTHONPATH}:/backend"
16 |
17 | CMD ["uvicorn", "ragflow.api:app", "--host", "0.0.0.0", "--port", "8080", "--reload"]
--------------------------------------------------------------------------------
/ragflow/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 | logging.basicConfig(
5 | level=os.environ.get("LOG_LEVEL", "INFO"),
6 | format="%(asctime)s - %(levelname)s - %(name)s:%(filename)s:%(lineno)d - %(message)s",
7 | )
8 |
--------------------------------------------------------------------------------
/ragflow/api/__init__.py:
--------------------------------------------------------------------------------
1 | from ragflow.api.main import app
2 |
--------------------------------------------------------------------------------
/ragflow/api/database.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import create_engine
2 | from sqlalchemy.orm import declarative_base, sessionmaker
3 |
4 | import os
5 |
6 | DRIVER = os.environ.get("POSTGRES_DRIVER")
7 | HOST = os.environ.get("POSTGRES_HOST", "localhost")
8 | PORT = os.environ.get("POSTGRES_PORT")
9 | DB = os.environ.get("POSTGRES_DATABASE")
10 | USER = os.environ.get("POSTGRES_USER")
11 | PWD = os.environ.get("POSTGRES_PASSWORD")
12 |
13 |
14 | if PORT:
15 |     HOST += f":{PORT}"
16 |
17 | DATABASE_URL = f"postgresql+{DRIVER}://{USER}:{PWD}@{HOST}/{DB}"
18 |
19 | engine = create_engine(DATABASE_URL)
20 |
21 | Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
22 |
23 | Base = declarative_base()
24 |
--------------------------------------------------------------------------------
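Illustrative usage of the session factory above — a sketch only, assuming the POSTGRES_* variables point at a reachable database and SQLAlchemy 1.4+ (sessions usable as context managers):

from ragflow.api.database import Session
from ragflow.api.models import User

# Session() returns a regular SQLAlchemy session; the context manager
# closes it on exit, mirroring get_db() in common_service.py.
with Session() as session:
    print(session.query(User).count(), "users in the database")
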
/ragflow/api/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from fastapi import FastAPI
3 |
4 | from ragflow.api.routers import (
5 | configs_router,
6 | evals_router,
7 | gens_router,
8 | chats_router,
9 | auth_router,
10 | user_router,
11 | )
12 |
13 | from ragflow.api.database import Base, engine
14 |
15 | app = FastAPI(
16 | title="FastAPI RAGflow Documentation",
17 | version="0.01",
18 | description="""API for the main component of RAGflow to generate evaluation datasets of QA pairs and to run hyperparameter evaluations.
19 | """,
20 | # root path should only be "/api" in production
21 | root_path="" if os.environ.get("EXECUTION_CONTEXT") in ["DEV", "TEST"] else "/api",
22 | )
23 |
24 | app.include_router(configs_router.router)
25 | app.include_router(gens_router.router)
26 | app.include_router(evals_router.router)
27 | app.include_router(chats_router.router)
28 | app.include_router(auth_router.router)
29 | app.include_router(user_router.router)
30 |
31 | Base.metadata.create_all(bind=engine)
32 |
--------------------------------------------------------------------------------
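For local development outside the container, the same app can be served directly; a hedged sketch assuming the POSTGRES_* and CHROMADB_* variables from the compose files are exported:

import uvicorn

# Same target as the Dockerfile CMD: "ragflow.api:app" resolves through
# ragflow/api/__init__.py, which re-exports the FastAPI instance.
if __name__ == "__main__":
    uvicorn.run("ragflow.api:app", host="0.0.0.0", port=8080, reload=True)
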
/ragflow/api/models.py:
--------------------------------------------------------------------------------
1 | import datetime as dt
2 | from sqlalchemy import Column, Integer, String, DateTime, Boolean
3 |
4 | from ragflow.api.database import Base
5 |
6 |
7 | class User(Base):
8 | __tablename__ = "users"
9 | id = Column(Integer, primary_key=True, index=True)
10 | username = Column(String, index=True, unique=True)
11 | hashed_password = Column(String, index=False, unique=False)
12 | email = Column(String, index=True, unique=True)
13 |     date_created = Column(DateTime, default=lambda: dt.datetime.now(dt.UTC))
14 | is_active = Column(Boolean, default=True)
15 | role = Column(String, default="user", index=True)
16 |
--------------------------------------------------------------------------------
/ragflow/api/routers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/ragflow/api/routers/__init__.py
--------------------------------------------------------------------------------
/ragflow/api/routers/auth_router.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, Depends, HTTPException, Response, Request
2 | from starlette import status
3 | import sqlalchemy.orm as orm
4 | from fastapi.security import OAuth2PasswordRequestForm
5 | from datetime import timedelta
6 |
7 | from ragflow.api.schemas import UserFromDB
8 | from ragflow.api.services import get_db
9 | from ragflow.api.services.auth_service import (
10 | authenticate_user,
11 | create_access_token,
12 | get_current_active_user,
13 | )
14 |
15 |
16 | router = APIRouter(
17 | prefix="/auth",
18 | tags=["Authentication endpoints"],
19 | )
20 |
21 | ACCESS_TOKEN_EXPIRATION_IN_MINUTES = 60 * 24 * 7 # 7 days
22 |
23 |
24 | @router.get("/user", response_model=UserFromDB, status_code=status.HTTP_200_OK)
25 | async def get_authenticated_user(user: UserFromDB = Depends(get_current_active_user)):
26 | return user
27 |
28 |
29 | @router.post("/login", response_model=UserFromDB, status_code=status.HTTP_200_OK)
30 | async def login_for_access_token(
31 | response: Response,
32 | form_data: OAuth2PasswordRequestForm = Depends(),
33 | db_session: orm.Session = Depends(get_db),
34 | ):
35 | user = await authenticate_user(
36 | username=form_data.username, password=form_data.password, db_session=db_session
37 | )
38 |
39 | if not user:
40 | raise HTTPException(
41 | status_code=status.HTTP_401_UNAUTHORIZED,
42 | detail="Incorrect username or password",
43 | headers={"WWW-Authenticate": "Bearer"},
44 | )
45 |
46 | if not user.is_active:
47 | raise HTTPException(
48 | status_code=status.HTTP_401_UNAUTHORIZED,
49 | detail="User is disabled",
50 | headers={"WWW-Authenticate": "Bearer"},
51 | )
52 |
53 | jwt, exp = create_access_token(
54 | subject={"sub": user.username, "user_id": user.id, "user_role": user.role},
55 | expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRATION_IN_MINUTES),
56 | )
57 |
58 | response.set_cookie(
59 | key="access_token",
60 | value=jwt,
61 | expires=int(exp.timestamp()),
62 | httponly=True,
63 | )
64 |
65 | return UserFromDB.model_validate(user)
66 |
67 |
68 | @router.get("/logout", response_model=dict, status_code=status.HTTP_200_OK)
69 | async def logout(response: Response):
70 | response.delete_cookie(key="access_token")
71 |
72 | return {"message": "Logged out successfully"}
73 |
--------------------------------------------------------------------------------
/ragflow/api/routers/chats_router.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, HTTPException
2 | from fastapi.responses import StreamingResponse
3 | from starlette import status
4 | from pydantic import BaseModel, Field
5 |
6 | from ragflow.utils.hyperparam_chats import AsyncCallbackHandler
7 | from ragflow.utils.hyperparam_chats import query_chat, create_gen, get_docs
8 |
9 | router = APIRouter(
10 | prefix="/chats",
11 | tags=["Endpoints to chat with parameterized RAGs"],
12 | )
13 |
14 |
15 | class ChatQueryRequest(BaseModel):
16 | hp_id: int = Field(ge=0, description="Hyperparameter run id")
17 |     hyperparameters_results_path: str = Field(
18 |         min_length=3, description="path to list of hp results"
19 |     )
20 | user_id: str = Field(description="user id from db")
21 | api_keys: dict[str, str] = Field(description="Dictionary of API keys.")
22 | query: str = Field(min_length=5, description="User query")
23 |
24 | class Config:
25 | json_schema_extra = {
26 | "example": {
27 |                 "hp_id": 0,
28 | "hyperparameters_results_path": "./tmp/hyperparameters_results.json",
29 | "user_id": "1",
30 | "api_keys": {
31 | "OPENAI_API_KEY": "your_api_key_here",
32 | "ANOTHER_API_KEY": "another_key_here",
33 | },
34 | "query": "My query",
35 | }
36 | }
37 |
38 |
39 | @router.post("/query", status_code=status.HTTP_200_OK)
40 | async def query_chat_model_normal(request: ChatQueryRequest):
41 | try:
42 | return query_chat(**request.model_dump())
43 | except Exception as ex:
44 | raise HTTPException(status_code=400, detail=str(ex))
45 |
46 |
47 | @router.post("/get_docs", status_code=status.HTTP_200_OK)
48 | async def query_chat_model_get_docs(request: ChatQueryRequest):
49 | try:
50 | return await get_docs(**request.model_dump())
51 | except Exception as ex:
52 | raise HTTPException(status_code=400, detail=str(ex))
53 |
54 |
55 | @router.post("/query_stream", status_code=status.HTTP_200_OK)
56 | async def query_chat_model_stream(request: ChatQueryRequest):
57 | stream_it = AsyncCallbackHandler()
58 | gen = create_gen(**request.model_dump(), stream_it=stream_it)
59 | return StreamingResponse(gen, media_type="text/event-stream")
60 |
--------------------------------------------------------------------------------
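A client-side sketch against /chats/query, reusing the example payload from ChatQueryRequest; host and port are assumptions matching the test compose files:

import requests

payload = {
    "hp_id": 0,
    "hyperparameters_results_path": "./tmp/hyperparameters_results.json",
    "user_id": "1",
    "api_keys": {"OPENAI_API_KEY": "your_api_key_here"},
    "query": "What does the annual report say about revenue?",
}

# /chats/query returns the full answer at once; /chats/query_stream
# would instead yield an event stream for the same payload.
response = requests.post("http://localhost:8080/chats/query", json=payload)
print(response.json())
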
/ragflow/api/routers/configs_router.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter
2 | from starlette import status
3 |
4 | from ragflow.commons.configurations import (
5 | LLM_MODELS,
6 | EMB_MODELS,
7 | CVGradeAnswerPrompt,
8 | CVGradeRetrieverPrompt,
9 | CVRetrieverSearchType,
10 | CVSimilarityMethod,
11 | )
12 |
13 | router = APIRouter(
14 | prefix="/configs",
15 | tags=["Configurations"],
16 | )
17 |
18 |
19 | @router.get("/llm_models", status_code=status.HTTP_200_OK)
20 | async def get_list_of_supported_llm_models():
21 | return LLM_MODELS
22 |
23 |
24 | @router.get("/embedding_models", status_code=status.HTTP_200_OK)
25 | async def list_of_embedding_models():
26 | return EMB_MODELS
27 |
28 |
29 | @router.get("/retriever_similarity_methods", status_code=status.HTTP_200_OK)
30 | async def list_of_similarity_methods_for_retriever():
31 | return [e.value for e in CVSimilarityMethod]
32 |
33 |
34 | @router.get("/retriever_search_types", status_code=status.HTTP_200_OK)
35 | async def list_of_search_types_for_retriever():
36 | return [e.value for e in CVRetrieverSearchType]
37 |
38 |
39 | @router.get("/grade_answer_prompts", status_code=status.HTTP_200_OK)
40 | async def list_of_prompts_for_grading_answers():
41 | return [e.value for e in CVGradeAnswerPrompt]
42 |
43 |
44 | @router.get("/grade_documents_prompts", status_code=status.HTTP_200_OK)
45 | async def list_of_prompts_for_grading_documents_retrieved():
46 | return [e.value for e in CVGradeRetrieverPrompt]
47 |
--------------------------------------------------------------------------------
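The config endpoints are read-only lookups, so they are easy to probe; a small sketch assuming the backend listens on localhost:8080:

import requests

for cfg in ("llm_models", "embedding_models", "retriever_search_types"):
    print(cfg, requests.get(f"http://localhost:8080/configs/{cfg}").json())
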
/ragflow/api/routers/evals_router.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, HTTPException
2 | from starlette import status
3 | from pydantic import BaseModel, Field
4 |
5 | from ragflow.evaluation import arun_evaluation
6 |
7 |
8 | router = APIRouter(
9 | tags=["Hyperparameter evaluation"],
10 | )
11 |
12 |
13 | class EvaluationRequest(BaseModel):
14 | document_store_path: str = Field(min_length=3, description="path to documents")
15 | label_dataset_path: str = Field(
16 | min_length=3, description="path to generated evaluation dataset"
17 | )
18 | hyperparameters_path: str = Field(
19 | min_length=3, description="path to list of hyperparameters"
20 | )
21 | hyperparameters_results_path: str = Field(
22 | min_length=3, description="path to list of hp results"
23 | )
24 | hyperparameters_results_data_path: str = Field(
25 | min_length=3,
26 | description="path to list of additional data generated during hp eval",
27 | )
28 | user_id: str = Field(description="user id from db")
29 | api_keys: dict[str, str] = Field(description="Dictionary of API keys.")
30 |
31 | class Config:
32 | json_schema_extra = {
33 | "example": {
34 | "document_store_path": "./tmp/document_store/", # path to documents
35 | "hyperparameters_path": "./tmp/hyperparameters.json", # path to list of hyperparameters
36 | "label_dataset_path": "./tmp/label_dataset.json", # path to generated evaluation dataset
37 | "hyperparameters_results_path": "./tmp/hyperparameters_results.json", # path to list of eval results
38 | "hyperparameters_results_data_path": "./tmp/hyperparameters_results_data.csv", # path to list of generated predictions and retrieved docs
39 | "user_id": "1", # user id
40 | "api_keys": {
41 | "OPENAI_API_KEY": "your_api_key_here",
42 | "ANOTHER_API_KEY": "another_key_here",
43 | },
44 | }
45 | }
46 |
47 |
48 | @router.post("/evaluation", status_code=status.HTTP_200_OK)
49 | async def start_evaluation_run(eval_request: EvaluationRequest):
50 | try:
51 | await arun_evaluation(**eval_request.model_dump())
52 | except Exception as ex:
53 | print(ex)
54 | raise HTTPException(status_code=400, detail=str(ex))
55 |
--------------------------------------------------------------------------------
/ragflow/api/routers/gens_router.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, HTTPException
2 | from starlette import status
3 | from pydantic import BaseModel, Field
4 |
5 | from ragflow.generation import agenerate_evaluation_set
6 |
7 | router = APIRouter(
8 | tags=["Evaluation label set generation"],
9 | )
10 |
11 |
12 | class GenerationRequest(BaseModel):
13 | document_store_path: str = Field(min_length=3, description="path to documents")
14 | label_dataset_gen_params_path: str = Field(
15 | min_length=3, description="path to configurations for qa generation"
16 | )
17 | label_dataset_path: str = Field(
18 | min_length=3,
19 | description="path to where the generated qa pairs should be stored.",
20 | )
21 | user_id: str = Field(description="user id from db")
22 | api_keys: dict[str, str] = Field(description="Dictionary of API keys.")
23 |
24 | class Config:
25 | json_schema_extra = {
26 | "example": {
27 | "document_store_path": "./tmp/document_store/", # path to documents
28 | "label_dataset_gen_params_path": "./tmp/label_dataset_gen_params.json", # path to list of hyperparameters
29 | "label_dataset_path": "./tmp/label_dataset.json", # path to generated evaluation dataset
30 | "user_id": "1", # user id
31 | "api_keys": {
32 | "OPENAI_API_KEY": "your_api_key_here",
33 | "ANOTHER_API_KEY": "another_key_here",
34 | },
35 | }
36 | }
37 |
38 |
39 | @router.post("/generation", status_code=status.HTTP_200_OK)
40 | async def start_evalset_generation(gen_request: GenerationRequest):
41 | try:
42 | await agenerate_evaluation_set(**gen_request.model_dump())
43 | except Exception as ex:
44 | print(ex)
45 | raise HTTPException(status_code=400, detail=str(ex))
46 |
--------------------------------------------------------------------------------
/ragflow/api/routers/user_router.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, Depends, HTTPException
2 | import sqlalchemy.orm as orm
3 | from starlette import status
4 |
5 | from ragflow.api.services import get_db
6 | import ragflow.api.services.user_service as user_service
7 | import ragflow.api.services.auth_service as auth_service
8 |
9 | from ragflow.api.schemas import CreateUserRequest, UpdateUserRequest, UserFromDB
10 |
11 | router = APIRouter(
12 | tags=["User endpoints"],
13 | )
14 |
15 | import logging
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | @router.post(
21 | "/users",
22 | response_model=UserFromDB,
23 | status_code=status.HTTP_201_CREATED,
24 | )
25 | async def create_user(
26 | user_data: CreateUserRequest,
27 | db_session: orm.Session = Depends(get_db),
28 | ):
29 | return await user_service.create_user(user_data=user_data, db_session=db_session)
30 |
31 |
32 | @router.get(
33 | "/users",
34 | response_model=list[UserFromDB],
35 | status_code=status.HTTP_200_OK,
36 | )
37 | async def get_all_users(
38 | user: UserFromDB = Depends(auth_service.get_current_active_user),
39 | db_session: orm.Session = Depends(get_db),
40 | ):
41 | if user.role != "admin":
42 | raise HTTPException(status_code=401, detail="User not authorized")
43 |
44 | return await user_service.get_all_users(db_session=db_session)
45 |
46 |
47 | @router.get(
48 | "/users/{user_id}",
49 | response_model=UserFromDB,
50 | status_code=status.HTTP_200_OK,
51 | )
52 | async def get_user_by_id(
53 | user_id: int,
54 | user: UserFromDB = Depends(auth_service.get_current_active_user),
55 | db_session: orm.Session = Depends(get_db),
56 | ):
57 | if user.id != user_id and user.role != "admin":
58 | raise HTTPException(status_code=401, detail="User not authorized")
59 |
60 | return await user_service.get_user_by_id(user_id=user_id, db_session=db_session)
61 |
62 |
63 | @router.delete(
64 | "/users/{user_id}",
65 | response_model=UserFromDB,
66 | status_code=status.HTTP_200_OK,
67 | )
68 | async def delete_user_by_id(
69 | user_id: int,
70 | user: UserFromDB = Depends(auth_service.get_current_active_user),
71 | db_session: orm.Session = Depends(get_db),
72 | ):
73 | if user.id != user_id and user.role != "admin":
74 | raise HTTPException(status_code=401, detail="User not authorized")
75 |
76 | user_to_delete = await user_service.get_user_by_id(
77 | user_id=user_id, db_session=db_session
78 | )
79 |
80 | await user_service.delete_user(user=user_to_delete, db_session=db_session)
81 |
82 | return user_to_delete
83 |
84 |
85 | @router.put(
86 | "/users/{user_id}",
87 | response_model=UserFromDB,
88 | status_code=status.HTTP_200_OK,
89 | )
90 | async def update_user_by_id(
91 | user_id: int,
92 | user_data: UpdateUserRequest,
93 | user: UserFromDB = Depends(auth_service.get_current_active_user),
94 | db_session: orm.Session = Depends(get_db),
95 | ):
96 | if user.id != user_id and user.role != "admin":
97 | raise HTTPException(status_code=401, detail="User not authorized")
98 |
99 | user_to_change = await user_service.get_user_by_id(
100 | user_id=user_id, db_session=db_session
101 | )
102 |
103 | return await user_service.update_user(
104 | user=user_to_change, user_data=user_data, db_session=db_session
105 | )
106 |
--------------------------------------------------------------------------------
/ragflow/api/schemas.py:
--------------------------------------------------------------------------------
1 | import datetime as dt
2 | from typing import Annotated
3 | from pydantic import BaseModel, EmailStr, StringConstraints
4 |
5 | from typing import Optional
6 |
7 |
8 | class UserBase(BaseModel):
9 | username: str
10 | email: str
11 |
12 |
13 | class UserFromDB(UserBase):
14 | id: int
15 | date_created: dt.datetime
16 | role: str
17 | is_active: bool
18 |
19 | class Config:
20 | from_attributes = True
21 |
22 |
23 | class UpdateUserRequest(UserBase):
24 | password: Optional[str] = None
25 |
26 |
27 | class CreateUserRequest(UserBase):
28 | username: Annotated[str, StringConstraints(min_length=4, max_length=64)]
29 | password: Annotated[str, StringConstraints(min_length=8, max_length=128)]
30 | email: EmailStr
31 |
--------------------------------------------------------------------------------
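A validation sketch for CreateUserRequest; the constraints above reject short usernames and passwords, and EmailStr needs the email-validator extra installed:

from pydantic import ValidationError

from ragflow.api.schemas import CreateUserRequest

try:
    CreateUserRequest(username="bob", email="bob@example.com", password="short")
except ValidationError as ex:
    # "bob" misses the 4-character username minimum and "short" the
    # 8-character password minimum, so two violations are reported.
    print(ex.error_count(), "validation errors")
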
/ragflow/api/services/__init__.py:
--------------------------------------------------------------------------------
1 | from ragflow.api.services.common_service import get_db
2 | import ragflow.api.services.user_service as user_service
3 | import ragflow.api.services.auth_service as auth_service
4 |
--------------------------------------------------------------------------------
/ragflow/api/services/auth_service.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from fastapi import Depends, HTTPException, Request
4 | from fastapi.security import OAuth2PasswordBearer
5 | from starlette import status
6 | import sqlalchemy.orm as orm
7 | from jose import JWTError, jwt
8 | from passlib.context import CryptContext
9 |
10 | from datetime import datetime, timedelta, timezone
11 | import os
12 |
13 | from ragflow.api.schemas import UserFromDB
14 | from ragflow.api.services import user_service
15 | from ragflow.api.services import get_db
16 |
17 | JWT_SECRET_KEY = os.environ.get("JWT_SECRET_KEY")
18 | HASH_ALGORITHM = "HS256"
19 |
20 | pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
21 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")
22 |
23 |
24 | def verify_password(plain_password: str, hashed_password: str) -> bool:
25 | return pwd_context.verify(plain_password, hashed_password)
26 |
27 |
28 | def get_password_hash(password: str) -> str:
29 | return pwd_context.hash(password)
30 |
31 |
32 | async def authenticate_user(
33 | username: str, password: str, db_session: orm.Session
34 | ) -> Union[UserFromDB, bool]:
35 | user = await user_service.get_user_by_name(username=username, db_session=db_session)
36 |
37 | if user is None:
38 | return False
39 |
40 | if not verify_password(password, user.hashed_password):
41 | return False
42 |
43 | return user
44 |
45 |
46 | def create_access_token(
47 | subject: dict, expires_delta: timedelta = None
48 | ) -> tuple[str, datetime]:
49 | to_encode = subject.copy()
50 |
51 |     if expires_delta is not None:
52 |         exp = datetime.now(timezone.utc) + expires_delta
53 |
54 |     else:
55 |         exp = datetime.now(timezone.utc) + timedelta(minutes=15)
56 |
57 | to_encode |= {"exp": exp}
58 | encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, HASH_ALGORITHM)
59 |
60 | return encoded_jwt, exp
61 |
62 |
63 | async def get_current_user(
64 | request: Request,
65 | db_session: orm.Session = Depends(get_db),
66 | ) -> UserFromDB:
67 | credentials_exception = HTTPException(
68 | status_code=status.HTTP_401_UNAUTHORIZED,
69 | detail="Could not validate credentials",
70 | headers={"WWW-Authenticate": "Bearer"},
71 | )
72 |
73 | try:
74 | token = request.cookies.get("access_token")
75 |
76 | if token is None:
77 | raise HTTPException(
78 | status_code=status.HTTP_401_UNAUTHORIZED,
79 | detail="User not authenticated",
80 | )
81 |
82 | payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[HASH_ALGORITHM])
83 | username = payload.get("sub")
84 | user_id = payload.get("user_id")
85 | user_role = payload.get("user_role")
86 |
87 | if username is None or user_id is None or user_role is None:
88 | raise credentials_exception
89 |
90 | except JWTError:
91 | raise credentials_exception
92 |
93 | user = await user_service.get_user_by_id(user_id=user_id, db_session=db_session)
94 |
95 | if user is None:
96 | raise credentials_exception
97 |
98 | return UserFromDB.model_validate(user)
99 |
100 |
101 | async def get_current_active_user(
102 | current_user: UserFromDB = Depends(get_current_user),
103 | ) -> UserFromDB:
104 | if not current_user.is_active:
105 | raise HTTPException(status_code=400, detail="Inactive user")
106 |
107 | return current_user
108 |
--------------------------------------------------------------------------------
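A token round trip through the helpers above — a sketch only, assuming JWT_SECRET_KEY is set before the module is imported:

from datetime import timedelta

from jose import jwt

from ragflow.api.services.auth_service import (
    HASH_ALGORITHM,
    JWT_SECRET_KEY,
    create_access_token,
)

token, exp = create_access_token(
    subject={"sub": "alice", "user_id": 1, "user_role": "user"},
    expires_delta=timedelta(minutes=30),
)

# Decoding restores the claims that get_current_user() later reads
# from the "access_token" cookie.
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[HASH_ALGORITHM])
assert payload["sub"] == "alice" and payload["user_id"] == 1
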
/ragflow/api/services/common_service.py:
--------------------------------------------------------------------------------
1 | from ragflow.api.database import Session
2 |
3 |
4 | def get_db():
5 | db_session = Session()
6 | try:
7 | yield db_session
8 | finally:
9 | db_session.close()
10 |
--------------------------------------------------------------------------------
/ragflow/api/services/user_service.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from ragflow.api.services import auth_service
4 | from ragflow.api.schemas import UserFromDB, CreateUserRequest, UpdateUserRequest
5 | from ragflow.api.models import User
6 | from ragflow.api.database import Session
7 |
8 | from fastapi import HTTPException
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | async def create_user(user_data: CreateUserRequest, db_session: Session) -> UserFromDB:
14 | user_data_dict = user_data.model_dump()
15 |
16 | user_data_dict["hashed_password"] = auth_service.get_password_hash(
17 | user_data_dict.pop("password")
18 | )
19 |
20 | # Check if a user with the same username or email already exists
21 | existing_user = (
22 | db_session.query(User)
23 | .filter(
24 | (User.username == user_data_dict.get("username"))
25 | | (User.email == user_data_dict.get("email"))
26 | )
27 | .first()
28 | )
29 |
30 | if existing_user:
31 | # Determine which attribute (username or email) already exists in db
32 | duplicate_field = (
33 | "username"
34 | if existing_user.username == user_data_dict.get("username")
35 | else "email"
36 | )
37 | raise HTTPException(
38 | status_code=400,
39 | detail=f"User with provided {duplicate_field} already exists!",
40 | )
41 |
42 | user = User(**user_data_dict)
43 |
44 | db_session.add(user)
45 | db_session.commit()
46 | db_session.refresh(user)
47 |
48 | return UserFromDB.model_validate(user)
49 |
50 |
51 | async def get_all_users(db_session: Session) -> list[UserFromDB]:
52 | users = db_session.query(User).all()
53 | return list(map(UserFromDB.model_validate, users))
54 |
55 |
56 | async def get_user_by_id(user_id: int, db_session: Session) -> User:
57 | user = db_session.query(User).filter(User.id == user_id).first()
58 |
59 | return user
60 |
61 |
62 | async def get_user_by_name(username: str, db_session: Session) -> User:
63 | user = db_session.query(User).filter(User.username == username).first()
64 |
65 | return user
66 |
67 |
68 | async def delete_user(user: User, db_session: Session) -> None:
69 | db_session.delete(user)
70 | db_session.commit()
71 |
72 |
73 | async def update_user(
74 | user: User,
75 | user_data: UpdateUserRequest,
76 | db_session: Session,
77 | ) -> UserFromDB:
78 | user.username = user_data.username
79 | user.email = user_data.email
80 |
81 | if user_data.password:
82 | user.hashed_password = auth_service.get_password_hash(user_data.password)
83 |
84 | db_session.commit()
85 | db_session.refresh(user)
86 |
87 | return UserFromDB.model_validate(user)
88 |
--------------------------------------------------------------------------------
/ragflow/commons/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/ragflow/commons/__init__.py
--------------------------------------------------------------------------------
/ragflow/commons/chroma/ChromaClient.py:
--------------------------------------------------------------------------------
1 | import chromadb
2 | from chromadb.config import Settings
3 | import os
4 |
5 | API_HOST = os.environ.get("CHROMADB_HOST", "localhost")
6 | API_PORT = os.environ.get("CHROMADB_PORT", 8000)
7 |
8 |
9 | class ChromaClient:
10 | """Get a chromadb client as ContextManager."""
11 |
12 | def __init__(self):
13 | self.chroma_client = chromadb.HttpClient(
14 | host=API_HOST,
15 | port=API_PORT,
16 | settings=Settings(anonymized_telemetry=False),
17 | )
18 |
19 | def get_client(self):
20 | return self.chroma_client
21 |
22 | def __enter__(self):
23 | return self.chroma_client
24 |
25 | def __exit__(self, exc_type, exc_value, traceback):
26 | # TODO: self.chroma_client.stop()
27 | pass
28 |
--------------------------------------------------------------------------------
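Usage sketch for the context manager, assuming a Chroma server is reachable under CHROMADB_HOST:CHROMADB_PORT:

from ragflow.commons.chroma import ChromaClient

with ChromaClient() as client:
    # list_collections() is part of the standard chromadb client API.
    for collection in client.list_collections():
        print(collection.name)
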
/ragflow/commons/chroma/__init__.py:
--------------------------------------------------------------------------------
1 | from ragflow.commons.chroma.ChromaClient import ChromaClient
2 |
--------------------------------------------------------------------------------
/ragflow/commons/configurations/BaseConfigurations.py:
--------------------------------------------------------------------------------
1 | from langchain.schema.embeddings import Embeddings
2 | from langchain.schema.language_model import BaseLanguageModel
3 | from langchain.llms.fake import FakeListLLM
4 |
5 | from langchain.embeddings import OpenAIEmbeddings, DeterministicFakeEmbedding
6 | from langchain.chat_models import ChatOpenAI, ChatAnyscale
7 |
8 | import tiktoken
9 | import builtins
10 | import logging
11 |
12 | from abc import abstractmethod
13 | from pydantic.v1 import BaseModel, Field, Extra, validator
14 | from typing import Any, Optional
15 |
16 | from enum import Enum
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 | ####################
21 | # LLM Models
22 | ####################
23 | LLAMA_2_MODELS = [
24 | "Llama-2-7b-chat-hf",
25 | "Llama-2-13b-chat-hf",
26 | "Llama-2-70b-chat-hf",
27 | ]
28 |
29 | OPENAI_LLM_MODELS = ["gpt-3.5-turbo", "gpt-4"]
30 |
31 | LLM_MODELS = [*OPENAI_LLM_MODELS, *LLAMA_2_MODELS]
32 |
33 | # Anyscale configs
34 | ANYSCALE_LLM_PREFIX = "meta-llama/"
35 | ANYSCALE_API_URL = "https://api.endpoints.anyscale.com/v1"
36 |
37 | ####################
38 | # Embedding Models
39 | ####################
40 | OPENAI_EMB_MODELS = ["text-embedding-ada-002"]
41 |
42 | EMB_MODELS = [*OPENAI_EMB_MODELS]
43 |
44 |
45 | ####################
46 | # Enumerations
47 | ####################
48 | class CVGradeAnswerPrompt(Enum):
49 | ZERO_SHOT = "zero_shot"
50 | FEW_SHOT = "few_shot"
51 | NONE = "none"
52 |
53 |
54 | class CVGradeRetrieverPrompt(Enum):
55 | DEFAULT = "default"
56 | NONE = "none"
57 |
58 |
59 | class CVRetrieverSearchType(Enum):
60 | BY_SIMILARITY = "similarity"
61 | MAXIMAL_MARGINAL_RELEVANCE = "mmr"
62 |
63 |
64 | class CVSimilarityMethod(Enum):
65 | COSINE = "cosine"
66 | L2_NORM = "l2"
67 | INNER_PRODUCT = "ip"
68 |
69 |
70 | ####################
71 | # Test LLM and embedding model for mocks
72 | ####################
73 | class TestDummyLLM(FakeListLLM):
74 | """Langchains FakeListLLM with model name included."""
75 |
76 | model_name: str = "TestDummyLLM"
77 |
78 | def __init__(self):
79 | super().__init__(responses=["foo_response"])
80 |
81 | def dict(self, *args, **kwargs):
82 | output = super().dict(*args, **kwargs)
83 | output["model_name"] = self.model_name
84 | return output
85 |
86 |
87 | class TestDummyEmbedding(DeterministicFakeEmbedding):
88 | """Langchains DeterministicFakeEmbedding with model name included."""
89 |
90 | model: str = "TestDummyEmbedding"
91 |
92 | def __init__(self):
93 | super().__init__(size=2)
94 |
95 |
96 | class BaseConfigurations(BaseModel):
97 | """Base class for configuration objects."""
98 |
99 | chunk_size: int = Field(ge=0)
100 | chunk_overlap: int = Field(ge=0)
101 | length_function_name: str
102 | length_function: Any
103 |
104 | class Config:
105 | allow_mutation = False
106 | arbitrary_types_allowed = True
107 | extra = Extra.forbid
108 |
109 | @validator("length_function", pre=False, always=True)
110 | def populate_length_function(cls, v: callable, values: dict[str, str]):
111 | """Convert provided length function name to actual function."""
112 |
113 | return cls.set_length_function(values["length_function_name"])
114 |
115 | @staticmethod
116 | def get_language_model(model_name: str, api_keys: dict) -> BaseLanguageModel:
117 | if model_name in OPENAI_LLM_MODELS:
118 | return ChatOpenAI(
119 | openai_api_key=api_keys["OPENAI_API_KEY"],
120 | model_name=model_name,
121 | temperature=0.0,
122 | )
123 |
124 | elif model_name in LLAMA_2_MODELS:
125 | return ChatAnyscale(
126 | anyscale_api_key=api_keys["ANYSCALE_API_KEY"],
127 | model_name=f"{ANYSCALE_LLM_PREFIX}{model_name}",
128 | anyscale_api_base=ANYSCALE_API_URL,
129 | temperature=0.0,
130 | )
131 | # only for testing purposes
132 | elif model_name == "TestDummyLLM":
133 | return TestDummyLLM()
134 |
135 | raise NotImplementedError(f"LLM model '{model_name}' not supported.")
136 |
137 | @staticmethod
138 | def get_embedding_model(model_name: str, api_keys: dict) -> Embeddings:
139 | if model_name in OPENAI_EMB_MODELS:
140 | return OpenAIEmbeddings(
141 | openai_api_key=api_keys["OPENAI_API_KEY"], model=model_name
142 | )
143 | elif model_name == "TestDummyEmbedding":
144 | return TestDummyEmbedding()
145 |
146 | raise NotImplementedError(f"Embedding model '{model_name}' not supported.")
147 |
148 | @classmethod
149 | def set_length_function(cls, length_function_name: str) -> callable:
150 | # Extract the function name from the string
151 | func = length_function_name.strip("<>").split(" ")[-1]
152 |
153 | # Check if the function name exists in Python's built-ins
154 | if hasattr(builtins, func):
155 | return getattr(builtins, func)
156 |
157 | else:
158 | try:
159 | encoding = tiktoken.encoding_for_model(length_function_name)
160 | return lambda x: len(encoding.encode(x))
161 | except Exception as ex:
162 | logger.error(f"Length function '{length_function_name}' not supported")
163 | raise NotImplementedError(
164 | f"Error setting length function, neither python built-in nor valid tiktoken name passed. {ex.args}"
165 | )
166 |
167 | @staticmethod
168 | def get_language_model_name(llm: BaseLanguageModel) -> str:
169 | """Retrieve name of language model from object"""
170 | return llm.model_name
171 |
172 | @staticmethod
173 | def get_embedding_model_name(emb: Embeddings) -> str:
174 | """Retrieve name of embedding model name from object"""
175 | return emb.model
176 |
177 | def to_dict(self) -> dict:
178 | _data = self.dict()
179 | _data.pop("length_function", None)
180 | return _data
181 |
182 | @classmethod
183 | @abstractmethod
184 | def from_dict(cls, input_dict):
185 | pass
186 |
--------------------------------------------------------------------------------
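How length-function resolution in set_length_function behaves — built-in names resolve to the Python builtin, anything else is treated as a tiktoken model name (which tiktoken must recognize); an illustrative sketch:

from ragflow.commons.configurations import BaseConfigurations

count_chars = BaseConfigurations.set_length_function("len")
count_tokens = BaseConfigurations.set_length_function("gpt-3.5-turbo")

text = "How long is this string?"
# Character count vs. token count for the same text.
print(count_chars(text), count_tokens(text))
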
/ragflow/commons/configurations/Hyperparameters.py:
--------------------------------------------------------------------------------
1 | from langchain.schema.embeddings import Embeddings
2 | from langchain.schema.language_model import BaseLanguageModel
3 |
4 | from ragflow.commons.configurations.BaseConfigurations import (
5 | BaseConfigurations,
6 | LLM_MODELS,
7 | EMB_MODELS,
8 | CVGradeAnswerPrompt,
9 | CVGradeRetrieverPrompt,
10 | CVRetrieverSearchType,
11 | CVSimilarityMethod,
12 | )
13 |
14 | import logging
15 |
16 | from pydantic.v1 import Field, validator
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class Hyperparameters(BaseConfigurations):
22 | """Class to model hyperparameters."""
23 |
24 | id: int = Field(ge=0)
25 | num_retrieved_docs: int = Field(ge=0)
26 | grade_answer_prompt: CVGradeAnswerPrompt
27 | grade_docs_prompt: CVGradeRetrieverPrompt
28 | search_type: CVRetrieverSearchType
29 | similarity_method: CVSimilarityMethod
30 | use_llm_grader: bool
31 | qa_llm: BaseLanguageModel
32 | grader_llm: BaseLanguageModel
33 | embedding_model: Embeddings
34 |
35 | @validator("qa_llm", "grader_llm", pre=True, always=True)
36 | def check_language_model_name(cls, v, field, values):
37 | """Validation to check if provided LLM model is supported."""
38 |
39 | if cls.get_language_model_name(v) not in LLM_MODELS + ["TestDummyLLM"]:
40 | raise ValueError(f"{v} not in list of valid values {LLM_MODELS}.")
41 | return v
42 |
43 | @validator("embedding_model", pre=True, always=True)
44 | def check_embedding_model_name(cls, v):
45 | """Validation to check if provided embedding model is supported."""
46 |
47 | if cls.get_embedding_model_name(v) not in EMB_MODELS + ["TestDummyEmbedding"]:
48 | raise ValueError(f"{v} not in list of valid values {EMB_MODELS}.")
49 | return v
50 |
51 | def to_dict(self):
52 | _data = super().to_dict()
53 |
54 | # Modify the dictionary for fields that need special handling
55 | _data["qa_llm"] = _data["qa_llm"]["model_name"]
56 | _data["grader_llm"] = _data["grader_llm"]["model_name"]
57 | _data["embedding_model"] = _data["embedding_model"]["model"]
58 |
59 | # remove unnecessary values again if we don't use a llm for grading
60 | if not _data["use_llm_grader"]:
61 | _data.pop("grade_answer_prompt", None)
62 | _data.pop("grade_docs_prompt", None)
63 | _data.pop("grader_llm", None)
64 |
65 | return _data
66 |
67 | @classmethod
68 | def from_dict(
69 | cls, input_dict: dict[str, str], hp_id: str, api_keys: dict[str, str]
70 | ):
71 | _input = dict(**input_dict)
72 |
73 | # set default values if use_grader_llm=False and additional values not set
74 | if not _input["use_llm_grader"]:
75 | _input["grade_answer_prompt"] = CVGradeAnswerPrompt.NONE.value
76 | _input["grade_docs_prompt"] = CVGradeRetrieverPrompt.NONE.value
77 | _input["grader_llm"] = "TestDummyLLM"
78 |
79 | _input["id"] = hp_id
80 |
81 | # set the actual langchain objects
82 | _input["embedding_model"] = cls.get_embedding_model(
83 | _input["embedding_model"], api_keys
84 | )
85 | _input["qa_llm"] = cls.get_language_model(_input["qa_llm"], api_keys)
86 | _input["grader_llm"] = cls.get_language_model(_input["grader_llm"], api_keys)
87 |
88 | return cls(**_input)
89 |
--------------------------------------------------------------------------------
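A construction sketch via from_dict; the hyperparameter values are illustrative and the API key a placeholder (the LangChain models are only instantiated here, no request is made):

from ragflow.commons.configurations import Hyperparameters

hp = Hyperparameters.from_dict(
    input_dict={
        "chunk_size": 512,
        "chunk_overlap": 16,
        "length_function_name": "len",
        "num_retrieved_docs": 3,
        "search_type": "similarity",
        "similarity_method": "cosine",
        "use_llm_grader": False,  # grader fields get defaulted internally
        "qa_llm": "gpt-3.5-turbo",
        "embedding_model": "text-embedding-ada-002",
    },
    hp_id=0,
    api_keys={"OPENAI_API_KEY": "sk-placeholder"},
)

# to_dict() drops the grader fields again because use_llm_grader=False.
print(hp.to_dict())
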
/ragflow/commons/configurations/QAConfigurations.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pydantic.v1 import validator
3 |
4 | from langchain.schema.language_model import BaseLanguageModel
5 | from langchain.schema.embeddings import Embeddings
6 | from ragflow.commons.configurations.BaseConfigurations import (
7 | BaseConfigurations,
8 | LLM_MODELS,
9 | )
10 |
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | class QAConfigurations(BaseConfigurations):
16 | """Class to model qa generation configs."""
17 |
18 | qa_generator_llm: BaseLanguageModel
19 | persist_to_vs: bool
20 | embedding_model_list: list[Embeddings]
21 |
22 | @validator("qa_generator_llm", pre=True, always=True)
23 | def check_language_model_name(cls, v):
24 | """Validation to check if provided LLM model is supported."""
25 |
26 | if cls.get_language_model_name(v) not in LLM_MODELS + ["TestDummyLLM"]:
27 | raise ValueError(f"{v} not in list of valid values {LLM_MODELS}.")
28 | return v
29 |
30 | def to_dict(self):
31 | _data = super().to_dict()
32 |
33 | # Modify the dictionary for fields that need special handling
34 | _data["qa_generator_llm"] = _data["qa_generator_llm"]["model_name"]
35 | _data["embedding_model_list"] = [
36 | model["model"] for model in _data["embedding_model_list"]
37 | ]
38 |
39 | return _data
40 |
41 | @classmethod
42 | def from_dict(cls, input_dict: dict[str, str], api_keys: dict[str, str]):
43 | _input = dict(**input_dict)
44 |
45 | _input["qa_generator_llm"] = cls.get_language_model(
46 | _input["qa_generator_llm"], api_keys
47 | )
48 |
49 | # get the list of embedding models, filter the unique models and map them to LangChain objects
50 | embedding_models = set(_input["embedding_model_list"])
51 |
52 | _input["embedding_model_list"] = [
53 | cls.get_embedding_model(model, api_keys) for model in embedding_models
54 | ]
55 |
56 | return cls(**_input)
57 |
--------------------------------------------------------------------------------
/ragflow/commons/configurations/__init__.py:
--------------------------------------------------------------------------------
1 | from ragflow.commons.configurations.BaseConfigurations import (
2 | BaseConfigurations,
3 | CVGradeAnswerPrompt,
4 | CVGradeRetrieverPrompt,
5 | CVRetrieverSearchType,
6 | CVSimilarityMethod,
7 | LLM_MODELS,
8 | EMB_MODELS,
9 | )
10 | from ragflow.commons.configurations.Hyperparameters import Hyperparameters
11 | from ragflow.commons.configurations.QAConfigurations import QAConfigurations
12 |
13 | import os
14 | import dotenv
15 | import logging
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | # Here we load .env file which contains the API keys to access LLMs, etc.
20 | if os.environ.get("ENV") == "DEV":
21 | try:
22 | dotenv.load_dotenv(dotenv.find_dotenv(), override=True)
23 | except Exception as ex:
24 | logger.error(
25 | "Failed to load .env file in 'DEV' environment containing API keys."
26 | )
27 | raise ex
28 |
--------------------------------------------------------------------------------
/ragflow/commons/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | from ragflow.commons.prompts.qa_geneneration_prompts import (
2 | QA_GENERATION_PROMPT_SELECTOR,
3 | )
4 |
5 | from ragflow.commons.prompts.grade_answers_prompts import (
6 | GRADE_ANSWER_PROMPT_FAST,
7 | GRADE_ANSWER_PROMPT_5_CATEGORIES_5_GRADES_ZERO_SHOT,
8 | GRADE_ANSWER_PROMPT_3_CATEGORIES_4_GRADES_FEW_SHOT,
9 | )
10 |
11 | from ragflow.commons.prompts.grade_retriever_prompts import GRADE_RETRIEVER_PROMPT
12 |
13 | from ragflow.commons.prompts.qa_answer_prompts import QA_ANSWER_PROMPT
14 |
--------------------------------------------------------------------------------
/ragflow/commons/prompts/grade_answers_prompts.py:
--------------------------------------------------------------------------------
1 | from langchain.prompts import PromptTemplate
2 |
3 | template = """Please act as an impartial judge and evaluate the quality of the provided answer which attempts to answer the provided question based on a provided context.
4 |
5 | Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. If the student answers that there is no specific information provided in the context, then the answer is incorrect. Begin!
6 |
7 | QUESTION: {query}
8 | STUDENT ANSWER: {result}
9 | TRUE ANSWER: {answer}
10 |
11 | Your response should be as follows:
12 |
13 | CORRECTNESS: (1,2,3,4 or 5) - grade 1 means the answer was completely incorrect; a higher grade towards 5 means the answer is more correct, clarifies more parts of the question and is more readable. The best grade is 5.
14 | """
15 |
16 | GRADE_ANSWER_PROMPT_FAST = PromptTemplate(
17 | input_variables=["query", "result", "answer"], template=template
18 | )
19 |
20 | template = """Please act as an impartial judge and evaluate the quality of the provided answer which attempts to answer the provided question based on a provided context.
21 |
22 | You'll be given a function grading_function which you'll call for each provided context, question and answer to submit your reasoning and score for the correctness, comprehensiveness and readability of the answer.
23 |
24 | Below is your grading rubric:
25 |
26 | - Correctness: Does the answer correctly answer the question?
27 |
28 | - Comprehensiveness: How comprehensive is the answer, does it fully answer all aspects of the question and provide comprehensive explanation and other necessary information.
29 |
30 | - Readability: How readable is the answer, does it have redundant information or incomplete information that hurts the readability of the answer. Rate from 1 (completely unreadable) to 5 (highly readable)
31 |
32 | Grade the student answers based ONLY on their factual accuracy. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. If the student answers that there is no specific information provided in the context, then the answer is incorrect. Begin!
33 |
34 | QUESTION: {query}
35 | STUDENT ANSWER: {result}
36 | TRUE ANSWER: {answer}
37 |
38 | Your response should be as follows, do not provide any additional information.
39 |
40 | CORRECTNESS: (1,2,3,4 or 5) - scale from 1 worst to 5 best
41 | COMPREHENSIVENESS: (1,2,3,4 or 5) - scale from 1 worst to 5 best
42 | READABILITY: (1,2,3,4 or 5) - scale from 1 worst to 5 best
43 | """
44 |
45 | GRADE_ANSWER_PROMPT_5_CATEGORIES_5_GRADES_ZERO_SHOT = PromptTemplate(
46 | input_variables=["query", "result", "answer"], template=template
47 | )
48 |
49 |
50 | template = """Please act as an impartial judge and evaluate the quality of the provided answer which attempts to answer the provided question based on a provided context.
51 |
52 | You'll be given a function grading_function which you'll call for each provided context, question and answer to submit your reasoning and score for the correctness, comprehensiveness and readability of the answer.
53 |
54 | Below is your grading rubric:
55 |
56 | - CORRECTNESS: If the answer correctly answers the question, below are the details for different scores:
57 |
58 | - Score 0: the answer is completely incorrect, doesn’t mention anything about the question or is completely contrary to the correct answer.
59 |
60 | - For example, when asked “How to terminate a databricks cluster”, the answer is empty string, or content that’s completely irrelevant, or sorry I don’t know the answer.
61 |
62 | - Score 1: the answer provides some relevance to the question and answers one aspect of the question correctly.
63 |
64 | - Example:
65 |
66 | - Question: How to terminate a databricks cluster
67 |
68 | - Answer: Databricks cluster is a cloud-based computing environment that allows users to process big data and run distributed data processing tasks efficiently.
69 |
70 | - Or answer: In the Databricks workspace, navigate to the "Clusters" tab. And then this is a hard question that I need to think more about it
71 |
72 | - Score 2: the answer mostly answers the question but is missing or hallucinating on one critical aspect.
73 |
74 | - Example:
75 |
76 | - Question: How to terminate a databricks cluster”
77 |
78 | - Answer: “In the Databricks workspace, navigate to the "Clusters" tab.
79 |
80 | Find the cluster you want to terminate from the list of active clusters.
81 |
82 | And then you’ll find a button to terminate all clusters at once”
83 |
84 | - Score 3: the answer correctly answers the question and is not missing any major aspect
85 |
86 | - Example:
87 |
88 | - Question: How to terminate a databricks cluster
89 |
90 | - Answer: In the Databricks workspace, navigate to the "Clusters" tab.
91 |
92 | Find the cluster you want to terminate from the list of active clusters.
93 |
94 | Click on the down-arrow next to the cluster name to open the cluster details.
95 |
96 | Click on the "Terminate" button. A confirmation dialog will appear. Click "Terminate" again to confirm the action.”
97 |
98 | - COMPREHENSIVENESS: How comprehensive is the answer, does it fully answer all aspects of the question and provide comprehensive explanation and other necessary information. Below are the details for different scores:
99 |
100 | - Score 0: typically if the answer is completely incorrect, then the comprehensiveness also scores zero.
101 |
102 | - Score 1: if the answer is correct but too short to fully answer the question, then we can give score 1 for comprehensiveness.
103 |
104 | - Example:
105 |
106 | - Question: How to use databricks API to create a cluster?
107 |
108 | - Answer: First, you will need a Databricks access token with the appropriate permissions. You can generate this token through the Databricks UI under the 'User Settings' option. And then (the rest is missing)
109 |
110 | - Score 2: the answer is correct and roughly answers the main aspects of the question, but it’s missing descriptions of details, or is completely missing details about one minor aspect.
111 |
112 | - Example:
113 |
114 | - Question: How to use databricks API to create a cluster?
115 |
116 | - Answer: You will need a Databricks access token with the appropriate permissions. Then you’ll need to set up the request URL, then you can make the HTTP Request. Then you can handle the request response.
123 |
124 | - Score 3: the answer is correct, and covers all the main aspects of the question
125 |
126 | - READABILITY: How readable is the answer, does it have redundant information or incomplete information that hurts the readability of the answer.
127 |
128 | - Score 0: the answer is completely unreadable, e.g. full of symbols that are hard to read; e.g. keeps repeating words so that it’s very hard to understand the meaning of the paragraph. No meaningful information can be extracted from the answer.
129 |
130 | - Score 1: the answer is slightly readable, there are irrelevant symbols or repeated words, but it can roughly form a meaningful sentence that covers some aspects of the answer.
131 |
132 | - Example:
133 |
134 | - Question: How to use databricks API to create a cluster?
135 |
136 | - Answer: You you you you you you will need a Databricks access token with the appropriate permissions. And then then you’ll need to set up the request URL, then you can make the HTTP Request. Then Then Then Then Then Then Then Then Then
137 |
138 | - Score 2: the answer is correct and mostly readable, but there is one obvious piece that affects the readability (mentions of irrelevant pieces, repeated words).
139 |
140 | - Example:
141 |
142 | - Question: How to terminate a databricks cluster
143 |
144 | - Answer: In the Databricks workspace, navigate to the "Clusters" tab.
145 |
146 | Find the cluster you want to terminate from the list of active clusters.
147 |
148 | Click on the down-arrow next to the cluster name to open the cluster details.
149 |
150 | Click on the "Terminate" button…………………………………..
151 |
152 | A confirmation dialog will appear. Click "Terminate" again to confirm the action.
153 |
154 | - Score 3: the answer is correct and reader friendly, with no obvious piece that affects readability.
155 |
156 | Grade the student answers according to the criteria above. Ignore differences in punctuation and phrasing between the student answer and true answer. It is OK if the student answer contains more information than the true answer, as long as it does not contain any conflicting statements. If the student answers that there is no specific information provided in the context, grade the answer as incorrect (score 0). Begin!
157 |
158 | QUESTION: {query}
159 | STUDENT ANSWER: {result}
160 | TRUE ANSWER: {answer}
161 |
162 | Your response should be as follows, do not provide any additional information.
163 |
164 | CORRECTNESS: (0,1,2 or 3)
165 | COMPREHENSIVENESS: (0,1,2 or 3)
166 | READABILITY: (0,1,2 or 3)
167 | """
168 |
169 | GRADE_ANSWER_PROMPT_3_CATEGORIES_4_GRADES_FEW_SHOT = PromptTemplate(
170 | input_variables=["query", "result", "answer"], template=template
171 | )
172 |
--------------------------------------------------------------------------------
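A minimal usage sketch for the grading prompt above, showing how the template is rendered before it is sent to the grader LLM; the QA strings are hypothetical, and the import mirrors the one used in predicted_answer_accuracy.py:

```python
# Hypothetical usage sketch: render the few-shot grading prompt for one QA pair.
from ragflow.commons.prompts import (
    GRADE_ANSWER_PROMPT_3_CATEGORIES_4_GRADES_FEW_SHOT,
)

rendered = GRADE_ANSWER_PROMPT_3_CATEGORIES_4_GRADES_FEW_SHOT.format(
    query="How to terminate a Databricks cluster?",  # hypothetical QA pair
    result="Open the Clusters tab and click Terminate.",
    answer="Navigate to the Clusters tab, open the cluster details "
    "and click Terminate twice to confirm.",
)

# The grader LLM is expected to reply with the three scored categories:
# CORRECTNESS: ..., COMPREHENSIVENESS: ..., READABILITY: ...
print(rendered)
```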
/ragflow/commons/prompts/grade_retriever_prompts.py:
--------------------------------------------------------------------------------
1 | from langchain.prompts import PromptTemplate
2 |
3 | template = """
4 | Given the question:
5 | {query}
6 |
7 | Here are some documents retrieved in response to the question:
8 | {result}
9 |
10 | And here is the answer to the question:
11 | {answer}
12 |
13 | GRADING CRITERIA: We want to know if the question can be directly answered with the provided documents, without any additional outside sources. Do the retrieved documents make it possible to answer the question?
14 |
15 | Your response should be as follows without providing any additional information:
16 |
17 | GRADE: (0 to 1) - grade '0' means it is impossible to answer the question with the documents in any way, grade '1' means the question can be fully answered with the provided documents. The more aspects of the question can be answered, the higher the grade should be, with a maximum grade of 1.
18 | """
19 |
20 | GRADE_RETRIEVER_PROMPT = PromptTemplate(
21 | input_variables=["query", "result", "answer"], template=template
22 | )
23 |
--------------------------------------------------------------------------------
/ragflow/commons/prompts/qa_answer_prompts.py:
--------------------------------------------------------------------------------
1 | from langchain.prompts.prompt import PromptTemplate
2 |
3 | template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, say '''I don't know the answer to this question.'''. Don't try to make up an answer. Write a comprehensive and well formatted answer that is readable for a broad and diverse audience of readers. Answer the question as fully as possible, using only the provided documents. Don't reference the provided documents; write an extensive but concise answer of at most 8 sentences.
4 | {context}
5 | Question: {question}
6 | Helpful Answer:"""
7 |
8 | QA_ANSWER_PROMPT = PromptTemplate(
9 | input_variables=["context", "question"],
10 | template=template,
11 | )
12 |
--------------------------------------------------------------------------------
/ragflow/commons/prompts/qa_geneneration_prompts.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
3 | from langchain.prompts.chat import (
4 | ChatPromptTemplate,
5 | HumanMessagePromptTemplate,
6 | SystemMessagePromptTemplate,
7 | )
8 | from langchain.prompts.prompt import PromptTemplate
9 |
10 | temp_chat_1 = """You are a smart assistant that can identify key information in a given text and extract a corresponding question and answer pair for this key information so that we can build a Frequently Asked Questions (FAQ) section. Both the question and the answer should be comprehensive, well formulated and readable for a broad and diverse audience of readers.
11 |
12 | When coming up with this question/answer pair, you must respond in the following format:
13 |
14 | ```
15 | {{
16 | "question": "$YOUR_QUESTION_HERE",
17 | "answer": "$THE_ANSWER_HERE"
18 | }}
19 | ```
20 |
21 | Everything between the ``` must be valid json.
22 | """
23 | temp_chat_2 = """Please come up with a question and answer pair, in the specified JSON format, for the following text:
24 | ----------------
25 | {text}"""
26 |
27 | CHAT_PROMPT = ChatPromptTemplate.from_messages(
28 | [
29 | SystemMessagePromptTemplate.from_template(temp_chat_1),
30 | HumanMessagePromptTemplate.from_template(temp_chat_2),
31 | ]
32 | )
33 |
34 | temp_dft = """You are a smart assistant that can identify key information in a given text and extract a corresponding question and answer pair for this key information so that we can build a Frequently Asked Questions (FAQ) section. Both the question and the answer should be comprehensive, well formulated and readable for a broad and diverse audience of readers.
35 |
36 | When coming up with this question/answer pair, you must respond in the following format:
37 |
38 | ```
39 | {{
40 | "question": "$YOUR_QUESTION_HERE",
41 | "answer": "$THE_ANSWER_HERE"
42 | }}
43 | ```
44 |
45 | Everything between the ``` must be valid json.
46 |
47 | Please come up with a question/answer pair, in the specified JSON format, for the following text:
48 | ----------------
49 | {text}
50 | """
51 |
52 | PROMPT = PromptTemplate.from_template(temp_dft)
53 |
54 | QA_GENERATION_PROMPT_SELECTOR = ConditionalPromptSelector(
55 | default_prompt=PROMPT, conditionals=[(is_chat_model, CHAT_PROMPT)]
56 | )
57 |
--------------------------------------------------------------------------------
/ragflow/commons/vectorstore/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/ragflow/commons/vectorstore/__init__.py
--------------------------------------------------------------------------------
/ragflow/commons/vectorstore/pgvector_utils.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import create_engine, Table, MetaData
2 | from sqlalchemy.sql import select
3 | from langchain.vectorstores.pgvector import PGVector
4 | from langchain.schema.embeddings import Embeddings
5 |
6 | from ragflow.api import PGVECTOR_URL
7 |
8 | TABLE = "langchain_pg_collection"
9 | COL_NAME = "name"
10 |
11 |
12 | def delete_collection(name: str) -> None:
13 | col = PGVector(
14 | collection_name=name,
15 | connection_string=PGVECTOR_URL,
16 | embedding_function=None,
17 | )
18 |
19 | col.delete_collection()
20 |
21 |
22 | def create_collection(name: str, embedding: Embeddings) -> PGVector:
23 | col = PGVector(
24 | collection_name=name,
25 | connection_string=PGVECTOR_URL,
26 | embedding_function=embedding,
27 | pre_delete_collection=True,
28 | )
29 |
30 | col.create_collection()
31 |
32 | return col
33 |
34 |
35 | def list_collections() -> list[str]:
36 | # Create an engine
37 | engine = create_engine(PGVECTOR_URL)
38 |
39 | # Reflect the specific table
40 | metadata = MetaData()
41 | table = Table(TABLE, metadata, autoload_with=engine)
42 |
43 | # Query the column
44 | query = select(table.c[COL_NAME])
45 | with engine.connect() as connection:
46 | results = connection.execute(query).fetchall()
47 |
48 |     return [row[0] for row in results]
49 |
--------------------------------------------------------------------------------
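A hedged usage sketch for the helpers above, assuming a running Postgres instance with the pgvector extension behind PGVECTOR_URL and configured OpenAI credentials; the collection name is hypothetical:

```python
from langchain.embeddings.openai import OpenAIEmbeddings

from ragflow.commons.vectorstore.pgvector_utils import (
    create_collection,
    delete_collection,
    list_collections,
)

# create a fresh pgvector-backed collection (name is hypothetical)
col = create_collection("demo_collection", OpenAIEmbeddings())

# list_collections reads the collection names from langchain_pg_collection
print(list_collections())

# drop the collection again
delete_collection("demo_collection")
```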
/ragflow/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from ragflow.evaluation.hp_evaluator import arun_evaluation
2 |
--------------------------------------------------------------------------------
/ragflow/evaluation/hp_evaluator.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import asyncio
3 | from tqdm.asyncio import tqdm as tqdm_asyncio
4 | import glob
5 | import os
6 |
7 | from datetime import datetime
8 |
9 | from ragflow.utils import get_retriever, get_qa_llm, write_json, read_json
10 | from ragflow.utils.doc_processing import aload_and_chunk_docs
11 | from ragflow.evaluation.utils import process_retrieved_docs, write_generated_data_to_csv
12 | from ragflow.commons.configurations import Hyperparameters
13 | from ragflow.commons.chroma import ChromaClient
14 |
15 | from ragflow.evaluation.metrics import (
16 | answer_embedding_similarity,
17 | predicted_answer_accuracy,
18 | retriever_mrr_accuracy,
19 | retriever_semantic_accuracy,
20 | rouge_score,
21 | )
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | async def arun_eval_for_hp(
27 | label_dataset: list[dict[str, str]],
28 | hp: Hyperparameters,
29 | document_store: list[str],
30 | user_id: str,
31 | ) -> tuple[dict, list[dict]]:
32 | """Run the evaluation for a single hyperparameter configuration on the provided documents.
33 |
34 |     Args:
35 |         label_dataset (list[dict[str, str]]): The evaluation ground truth dataset of QA pairs
36 |         hp (Hyperparameters): The hyperparameter configuration under evaluation
37 |         document_store (list[str]): The paths of the documents to load and chunk
38 |         user_id (str): The user id
39 |     Returns:
40 |         tuple[dict, list[dict]]: The result dict with all scores and the predicted answers
41 |     """
42 | scores = {
43 | "answer_similarity_score": -1,
44 | "retriever_mrr@3": -1,
45 | "retriever_mrr@5": -1,
46 | "retriever_mrr@10": -1,
47 | "rouge1": -1,
48 | "rouge2": -1,
49 | "rougeLCS": -1,
50 | # metrics below use LLM for grading
51 | "correctness_score": -1,
52 | "comprehensiveness_score": -1,
53 | "readability_score": -1,
54 | "retriever_semantic_accuracy": -1,
55 | }
56 |
57 | # create chunks of all provided documents
58 | chunks = await aload_and_chunk_docs(hp, document_store)
59 |
60 | # get retriever from chunks
61 | retriever, retrieverForGrading = get_retriever(chunks, hp, user_id, for_eval=True)
62 |
63 | # chunks are no longer needed
64 | del chunks
65 |
66 | # llm for answering queries
67 | qa_llm = get_qa_llm(retriever, hp.qa_llm)
68 |
69 | # list(dict[question, result, source_documents])
70 | predicted_answers = await asyncio.gather(
71 | *[qa_llm.acall(qa_pair) for qa_pair in label_dataset]
72 | )
73 |
74 | # list of retrieved docs for each qa_pair
75 | process_retrieved_docs(predicted_answers, hp.id)
76 |
77 | # Calculate embedding similarities of label and predicted answers
78 | scores[
79 | "answer_similarity_score"
80 | ] = answer_embedding_similarity.grade_embedding_similarity(
81 | label_dataset, predicted_answers, hp.embedding_model, user_id
82 | )
83 |
84 | (
85 | scores["retriever_mrr@3"],
86 | scores["retriever_mrr@5"],
87 | scores["retriever_mrr@10"],
88 | ) = await retriever_mrr_accuracy.grade_retriever(label_dataset, retrieverForGrading)
89 |
90 | # Calculate ROUGE scores
91 | scores["rouge1"], scores["rouge2"], scores["rougeLCS"] = rouge_score.grade_rouge(
92 | label_dataset, predicted_answers
93 | )
94 |
95 | # if we want a llm to grade the predicted answers as well
96 | if hp.use_llm_grader:
97 | # grade predicted answers
98 | (
99 | scores["correctness_score"],
100 | scores["comprehensiveness_score"],
101 | scores["readability_score"],
102 | ) = predicted_answer_accuracy.grade_predicted_answer(
103 | label_dataset,
104 | predicted_answers,
105 | hp.grader_llm,
106 | hp.grade_answer_prompt,
107 | )
108 |
109 | # grade quality of retrieved documents used to answer the questions
110 | scores[
111 | "retriever_semantic_accuracy"
112 | ] = retriever_semantic_accuracy.grade_retriever(
113 | label_dataset,
114 | predicted_answers,
115 | hp.grader_llm,
116 | hp.grade_docs_prompt,
117 | )
118 |
119 | # preprocess dict before writing to json and add additional information like a timestamp
120 | result_dict = hp.to_dict()
121 | result_dict |= {"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]}
122 | result_dict |= {"scores": scores}
123 |
124 | # if in test mode store additional information about retriever and vectorstore objects
125 | if os.environ.get("EXECUTION_CONTEXT") == "TEST":
126 | result_dict |= {"retriever_params": retriever.dict()}
127 | result_dict |= {"vectorstore_params": retriever.vectorstore._collection.dict()}
128 |
129 | return result_dict, predicted_answers
130 |
131 |
132 | def prepare_evaluation_run(
133 | hyperparameters_path: str, user_id: str, api_keys: dict
134 | ) -> list[Hyperparameters]:
135 | """Preparation steps before running evaluation. Primarily we have to increment the hyperparameter ids correctly."""
136 |
137 | def extract_hpid(s):
138 | import re
139 |
140 | match = re.search(r"hpid_(\d+)", s)
141 | return int(match.group(1)) if match else 0
142 |
143 | hyperparams_list = read_json(hyperparameters_path)
144 |
145 | # find all previous hp runs with corresponding ids
146 | with ChromaClient() as client:
147 | collections = client.list_collections()
148 |
149 | hp_ids = [
150 | extract_hpid(col.name)
151 | for col in collections
152 | if col.name.startswith(f"userid_{user_id}_") and "_hpid_" in col.name
153 | ]
154 |
155 | if not hp_ids or all(id == -1 for id in hp_ids):
156 | next_id = 0
157 | else:
158 | next_id = max(hp_ids) + 1
159 |
160 | hp_list = [
161 | Hyperparameters.from_dict(config, id, api_keys)
162 | for id, config in enumerate(hyperparams_list, start=next_id)
163 | ]
164 |
165 | return hp_list
166 |
167 |
168 | async def arun_evaluation(
169 | document_store_path: str,
170 | hyperparameters_path: str,
171 | label_dataset_path: str,
172 | hyperparameters_results_path: str,
173 | hyperparameters_results_data_path: str,
174 | user_id: str,
175 | api_keys: dict,
176 | ) -> None:
177 | """Entry point for the evaluation of a set of hyperparameters to benchmark the corresponding RAG model.
178 |
179 | Args:
180 | document_store_path (str): Path to the documents
181 | hyperparameters_path (str): Path to the hyperparameters
182 | label_dataset_path (str): Path to the evaluation ground truth dataset
183 | hyperparameters_results_path (str): Path to the JSON file where results should get stored
184 | hyperparameters_results_data_path (str): Path to the CSV file where additional evaluation data should get stored
185 |         user_id (str): The user id
186 | api_keys (dict): All required API keys for the LLM endpoints
187 | """
188 | # load evaluation dataset
189 | label_dataset = read_json(label_dataset_path)
190 |
191 | # preprocess hyperparameter configs from json file
192 | hp_list = prepare_evaluation_run(hyperparameters_path, user_id, api_keys)
193 |
194 | document_store = glob.glob(f"{document_store_path}/*")
195 |
196 | tasks = [
197 | arun_eval_for_hp(label_dataset, hp, document_store, user_id) for hp in hp_list
198 | ]
199 |
200 | # run evaluations for all hyperparams
201 | results = await tqdm_asyncio.gather(*tasks, total=len(tasks))
202 |
203 | eval_scores, predicted_answers = list(zip(*results))
204 |
205 | # delete processed hyperparameters from input json
206 | write_json(
207 | data=[],
208 | filename=hyperparameters_path,
209 | append=True if os.environ.get("EXECUTION_CONTEXT", "") == "TEST" else False,
210 | )
211 |
212 | # write eval metrics to json
213 | write_json(eval_scores, hyperparameters_results_path)
214 |
215 | # write predicted answers and retrieved docs to csv
216 | write_generated_data_to_csv(predicted_answers, hyperparameters_results_data_path)
217 |
--------------------------------------------------------------------------------
/ragflow/evaluation/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from ragflow.evaluation.metrics import (
2 | answer_embedding_similarity,
3 | predicted_answer_accuracy,
4 | retriever_mrr_accuracy,
5 | retriever_semantic_accuracy,
6 | rouge_score,
7 | )
8 |
--------------------------------------------------------------------------------
/ragflow/evaluation/metrics/answer_embedding_similarity.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from langchain.schema.embeddings import Embeddings
4 | from ragflow.commons.chroma import ChromaClient
5 | from ragflow.commons.configurations import BaseConfigurations
6 |
7 | import logging
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def grade_embedding_similarity(
13 | label_dataset: list[dict],
14 | predictions: list[dict],
15 | embedding_model: Embeddings,
16 | user_id: str,
17 | ) -> float:
18 | """Calculate similarities of label answers and generated answers using the corresponding embeddings. We multiply the embedding matrices and take the diagonal values, divided by the products of the embedding norms, which yields the cosine similarities.
19 |
20 |     Args:
21 |         label_dataset (list[dict]): The evaluation ground truth dataset of QA pairs
22 |         predictions (list[dict]): A list of dicts containing the predicted answers and the retrieved document chunks
23 |         embedding_model (Embeddings): The embedding model
24 |         user_id (str): The user id, used to look up the cached embeddings of the label
25 |             answers from the evaluation dataset in ChromaDB.
26 |
27 | Returns:
28 | float: The average similarity score.
29 | """
30 | logger.info("Calculating embedding similarities.")
31 |
32 | num_qa_pairs = len(label_dataset)
33 |
34 | label_answers = [qa_pair["answer"] for qa_pair in label_dataset]
35 | predicted_answers = [qa_pair["result"] for qa_pair in predictions]
36 |
37 | # try using embeddings of answers of evaluation set from vectorstore that were stored in ChromaDB in generation process, otherwise we calculate them again
38 | try:
39 | with ChromaClient() as CHROMA_CLIENT:
40 | collection_name = f"userid_{user_id}_qaid_0_{BaseConfigurations.get_embedding_model_name(embedding_model)}"
41 |
42 | for col in CHROMA_CLIENT.list_collections():
43 | if col.name == collection_name:
44 | collection = col
45 | break
46 |
47 | ids = [qa["metadata"]["id"] for qa in label_dataset]
48 |
49 | target_embeddings = np.array(
50 | collection.get(ids=ids, include=["embeddings"])["embeddings"]
51 | ).reshape(num_qa_pairs, -1)
52 |
53 | logger.info("Embeddings for label answers loaded successfully.")
54 |
55 | except Exception as ex:
56 | logger.info(
57 | f"Embeddings of {BaseConfigurations.get_embedding_model_name(embedding_model)} for label answers could not be loaded from vectorstore.\n\
58 | Collections: {CHROMA_CLIENT.list_collections()}.\n\
59 | Exception: {ex.args}"
60 | )
61 |
62 | target_embeddings = np.array(
63 | embedding_model.embed_documents(label_answers)
64 | ).reshape(num_qa_pairs, -1)
65 |
66 | predicted_embeddings = np.array(
67 | embedding_model.embed_documents(predicted_answers)
68 | ).reshape(num_qa_pairs, -1)
69 |
70 | emb_norms = np.linalg.norm(target_embeddings, axis=1) * np.linalg.norm(
71 | predicted_embeddings, axis=1
72 | )
73 |
74 | dot_prod = np.diag(np.dot(target_embeddings, predicted_embeddings.T))
75 |
76 | similarities = dot_prod / emb_norms
77 |
78 | return np.nanmean(similarities)
79 |
--------------------------------------------------------------------------------
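A tiny worked example of the cosine-similarity computation used in grade_embedding_similarity, with toy 3-dimensional vectors in place of real embeddings:

```python
import numpy as np

# two label-answer embeddings and two predicted-answer embeddings (toy values)
target = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
predicted = np.array([[1.0, 0.0, 0.0], [0.0, 0.6, 0.8]])

# same computation as above: pairwise dot products on the diagonal,
# normalized by the product of the vector norms
norms = np.linalg.norm(target, axis=1) * np.linalg.norm(predicted, axis=1)
similarities = np.diag(np.dot(target, predicted.T)) / norms

print(similarities)              # [1.0, 0.6]
print(np.nanmean(similarities))  # 0.8
```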
/ragflow/evaluation/metrics/predicted_answer_accuracy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from langchain.evaluation.qa import QAEvalChain
4 | from langchain.schema.language_model import BaseLanguageModel
5 |
6 | from ragflow.evaluation.utils import extract_llm_metric
7 | from ragflow.commons.configurations import CVGradeAnswerPrompt
8 | from ragflow.commons.prompts import (
9 | GRADE_ANSWER_PROMPT_5_CATEGORIES_5_GRADES_ZERO_SHOT,
10 | GRADE_ANSWER_PROMPT_3_CATEGORIES_4_GRADES_FEW_SHOT,
11 | )
12 |
13 | import logging
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | def grade_predicted_answer(
19 | label_dataset: list[dict],
20 | predicted_answers: list[str],
21 | grader_llm: BaseLanguageModel,
22 | grade_answer_prompt: CVGradeAnswerPrompt,
23 | ) -> tuple[float, float, float]:
24 | """Using a QAEvalChain and an LLM for grading, we calculate a grade for the predicted answers from the query and the retrieved document chunks. We provide zero-shot and few-shot prompts and get the LLM to provide a score for correctness, comprehensiveness and readability of the predicted answers.
25 |
26 | Args:
27 | label_dataset (list[dict]): The evaluation ground truth dataset of QA pairs
28 | predicted_answers (list[str]): The predicted answers
29 | grader_llm (BaseLanguageModel): The LLM used for grading
30 | grade_answer_prompt (CVGradeAnswerPrompt): The prompt used in the QAEvalChain
31 |
32 | Returns:
33 | tuple[float, float, float]: Returns average scores for correctness, comprehensiveness and readability
34 | """
35 |
36 | logger.info("Grading generated answers.")
37 |
38 | if grade_answer_prompt == CVGradeAnswerPrompt.ZERO_SHOT:
39 | prompt, MAX_GRADE = GRADE_ANSWER_PROMPT_5_CATEGORIES_5_GRADES_ZERO_SHOT, 5
40 | elif grade_answer_prompt == CVGradeAnswerPrompt.FEW_SHOT:
41 | prompt, MAX_GRADE = GRADE_ANSWER_PROMPT_3_CATEGORIES_4_GRADES_FEW_SHOT, 3
42 | else:
43 | prompt, MAX_GRADE = None, 1
44 |
45 | # Note: GPT-4 grader is advised by OAI model_name="gpt-4"
46 | eval_chain = QAEvalChain.from_llm(llm=grader_llm, prompt=prompt, verbose=False)
47 |
48 | outputs = eval_chain.evaluate(
49 | label_dataset,
50 | predicted_answers,
51 | question_key="question",
52 | prediction_key="result",
53 | )
54 |
55 | # vectorize the function for efficiency
56 | v_extract_llm_metric = np.vectorize(extract_llm_metric)
57 |
58 | outputs = np.array([output["results"] for output in outputs])
59 |
60 | correctness = v_extract_llm_metric(outputs, "CORRECTNESS")
61 | comprehensiveness = v_extract_llm_metric(outputs, "COMPREHENSIVENESS")
62 | readability = v_extract_llm_metric(outputs, "READABILITY")
63 |
64 | return (
65 | np.nanmean(correctness) / MAX_GRADE,
66 | np.nanmean(comprehensiveness) / MAX_GRADE,
67 | np.nanmean(readability) / MAX_GRADE,
68 | )
69 |
--------------------------------------------------------------------------------
/ragflow/evaluation/metrics/retriever_mrr_accuracy.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import numpy as np
3 | from langchain.schema.vectorstore import VectorStoreRetriever
4 | from langchain.schema.document import Document
5 |
6 | import logging
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | async def grade_retriever(
12 | label_dataset: list[dict], retrieverForGrading: VectorStoreRetriever
13 | ) -> tuple[float, float, float]:
14 | """Calculate metrics for the precision of the retriever using MRR (Mean Reciprocal Rank) scores.
15 | For every list of retrieved documents we calculate the rank of the retrieved chunks with respect
16 | to the reference chunk from the label dataset.
17 |
18 |     Args:
19 |         label_dataset (list[dict]): The evaluation ground truth dataset of QA pairs
20 |         retrieverForGrading (VectorStoreRetriever): The retriever used to fetch the top 10
21 |             document chunks for each question
22 |
23 | Returns:
24 |         tuple[float, float, float]: MRR scores for top 3, 5 and 10
25 | """
26 |
27 | logger.info("Grading retrieved document chunks using MRR.")
28 |
29 | # get predicted answers with top 10 retrieved document chunks
30 | retrieved_chunks = await asyncio.gather(
31 | *[
32 | retrieverForGrading.aget_relevant_documents(qa_pair["question"])
33 | for qa_pair in label_dataset
34 | ]
35 | )
36 |
37 | ref_chunks = [
38 | Document(page_content=label["metadata"]["context"], metadata=label["metadata"])
39 | for label in label_dataset
40 | ]
41 |
42 | return calculate_mrr(ref_chunks, retrieved_chunks)
43 |
44 |
45 | def calculate_mrr(ref_chunks, retrieved_chunks):
46 | """Calculates mrr scores."""
47 |
48 | top3, top5, top10 = [], [], []
49 |
50 | # Check to ensure ref_chunks is not empty
51 | if not ref_chunks:
52 | return 0, 0, 0
53 |
54 | for ref_chunk, retr_chunks in zip(ref_chunks, retrieved_chunks):
55 | hit = False
56 | for idx, chunk in enumerate(retr_chunks, 1):
57 | if is_hit(ref_chunk, chunk):
58 | rank = 1 / idx
59 | if idx <= 3:
60 | top3.append(rank)
61 | if idx <= 5:
62 | top5.append(rank)
63 | if idx <= 10:
64 | top10.append(rank)
65 | hit = True
66 | break
67 |
68 | if not hit: # If there's no hit, the rank contribution is 0
69 | top3.append(0)
70 | top5.append(0)
71 | top10.append(0)
72 |
73 | # Calculate the MRR for the top 3, 5, and 10 documents
74 | mrr_3 = sum(top3) / len(ref_chunks)
75 | mrr_5 = sum(top5) / len(ref_chunks)
76 | mrr_10 = sum(top10) / len(ref_chunks)
77 |
78 | return mrr_3, mrr_5, mrr_10
79 |
80 |
81 | def is_hit(ref_chunk, retrieved_chunk):
82 | """Checks if retrieved chunk is close to reference chunk."""
83 |
84 | # Check if both chunks are from same document
85 | if ref_chunk.metadata["source"] != retrieved_chunk.metadata["source"]:
86 | return False
87 |
88 | label_start, label_end = (
89 | ref_chunk.metadata["start_index"],
90 | ref_chunk.metadata["end_index"],
91 | )
92 | retr_start, retr_end = (
93 | retrieved_chunk.metadata["start_index"],
94 | retrieved_chunk.metadata["end_index"],
95 | )
96 |
97 | retr_center = retr_start + (retr_end - retr_start) // 2
98 |
99 | # Consider retrieved chunk to be a hit if it is near the reference chunk
100 | return retr_center in range(label_start, label_end + 1)
101 |
--------------------------------------------------------------------------------
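To make the bookkeeping in calculate_mrr concrete, a small worked example with hypothetical hit ranks: a question whose reference chunk first appears at rank r contributes 1/r, and a miss contributes 0:

```python
# Hypothetical first-hit ranks of the reference chunk for three questions:
# rank 1, rank 2, and no hit within the retrieved list.
ranks = [1, 2, None]

mrr_at_3 = sum(1 / r for r in ranks if r is not None and r <= 3) / len(ranks)
print(mrr_at_3)  # (1/1 + 1/2 + 0) / 3 = 0.5
```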
/ragflow/evaluation/metrics/retriever_semantic_accuracy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from langchain.evaluation.qa import QAEvalChain
4 | from langchain.schema.language_model import BaseLanguageModel
5 |
6 | from ragflow.commons.prompts import GRADE_RETRIEVER_PROMPT
7 | from ragflow.evaluation.utils import extract_llm_metric
8 | from ragflow.commons.configurations import CVGradeRetrieverPrompt
9 |
10 | import logging
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | def grade_retriever(
16 | label_dataset: list[dict],
17 |     retrieved_docs: list[dict],
18 | grader_llm: BaseLanguageModel,
19 | grade_docs_prompt: CVGradeRetrieverPrompt,
20 | ) -> float:
21 | """Using LangChain's QAEvalChain, we use an LLM as grader to measure how well the retrieved document chunks from the vectorstore provide the answer for the label QA pairs. Per QA pair, the LLM has to rate the retrieved document chunks between 0 and 1 depending on whether the question can be answered with the provided document chunks.
22 |
23 | Args:
24 | label_dataset (list[dict]): The evaluation ground truth dataset of QA pairs
25 |         retrieved_docs (list[dict]): The QA results holding the retrieved document chunks of each QA pair
26 | grader_llm (BaseLanguageModel): The LLM grading the document chunks
27 | grade_docs_prompt (CVGradeRetrieverPrompt): The type of prompt used in the QAEvalChain
28 |
29 | Returns:
30 | float: The average score of all QA pairs
31 | """
32 |
33 | logger.info("Grading retrieved document chunks.")
34 |
35 | # TODO: Provide more Prompts and more detailed prompt engineering
36 | if grade_docs_prompt == CVGradeRetrieverPrompt.DEFAULT:
37 | prompt = GRADE_RETRIEVER_PROMPT
38 | else:
39 | prompt = GRADE_RETRIEVER_PROMPT
40 |
41 | # Note: GPT-4 grader is advised by OAI
42 | eval_chain = QAEvalChain.from_llm(llm=grader_llm, prompt=prompt)
43 |
44 | outputs = eval_chain.evaluate(
45 | label_dataset,
46 | retrieved_docs,
47 | question_key="question",
48 | answer_key="answer",
49 | prediction_key="retrieved_docs",
50 | )
51 |
52 | # vectorize the function for efficiency
53 | v_extract_llm_metric = np.vectorize(extract_llm_metric)
54 |
55 | outputs = np.array([output["results"] for output in outputs])
56 | retrieval_accuracy = v_extract_llm_metric(outputs, "GRADE")
57 |
58 | return np.nanmean(retrieval_accuracy)
59 |
--------------------------------------------------------------------------------
/ragflow/evaluation/metrics/rouge_score.py:
--------------------------------------------------------------------------------
1 | import evaluate
2 |
3 | import logging
4 |
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 | def grade_rouge(
9 | label_dataset: list[dict], predicted_answers: list[dict]
10 | ) -> tuple[float, float, float]:
11 | """Calculates rouge1, rouge2 and rougeL scores of the label and generated answers.
12 |
13 |     Args:
14 |         label_dataset (list[dict]): The evaluation ground truth dataset of QA pairs
15 |         predicted_answers (list[dict]): The predicted answers
16 |
17 |     Returns:
18 |         tuple[float, float, float]: The rouge1, rouge2 and rougeL average scores
19 | """
20 | logger.info("Calculating ROUGE scores.")
21 |
22 | label_answers = [qa_pair["answer"] for qa_pair in label_dataset]
23 | predicted_answers = [qa_pair["result"] for qa_pair in predicted_answers]
24 |
25 | rouge_score = evaluate.load("rouge")
26 | score = rouge_score.compute(references=label_answers, predictions=predicted_answers)
27 |
28 | return score["rouge1"], score["rouge2"], score["rougeL"]
29 |
--------------------------------------------------------------------------------
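A minimal sketch of the underlying Hugging Face evaluate call with toy strings; grade_rouge applies the same computation over the full answer lists:

```python
import evaluate

rouge = evaluate.load("rouge")

score = rouge.compute(
    references=["Navigate to the Clusters tab and click Terminate."],
    predictions=["Open the Clusters tab, then click the Terminate button."],
)

# grade_rouge returns exactly these three fields
print(score["rouge1"], score["rouge2"], score["rougeL"])
```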
/ragflow/evaluation/utils.py:
--------------------------------------------------------------------------------
1 | from langchain.schema.retriever import BaseRetriever
2 | from langchain.docstore.document import Document
3 |
4 | import pandas as pd
5 | import numpy as np
6 | import re
7 | import os
8 |
9 |
10 | def extract_llm_metric(text: str, metric: str) -> float:
11 | """Utility function for extracting scores from the LLM output when grading generated answers and retrieved document chunks.
12 |
13 | Args:
14 | text (str): LLM result
15 | metric (str): name of metric
16 |
17 | Returns:
18 | number: the found score as integer or np.nan
19 | """
20 |     match = re.search(rf"{metric}: (\d+)", text)
21 | if match:
22 | return int(match.group(1))
23 | return np.nan
24 |
25 |
26 | async def aget_retrieved_documents(
27 | qa_pair: dict[str, str], retriever: BaseRetriever
28 | ) -> dict:
29 | """Retrieve the most similar documents to the query asynchronously, postprocess the document chunks and return the qa pair with the retrieved documents string as result.
30 |
31 |     Args:
32 |         qa_pair (dict[str, str]): A single question-answer pair from the label dataset
33 |         retriever (BaseRetriever): The retriever used to fetch the relevant documents
34 |
35 |     Returns:
36 |         dict: The qa pair with the concatenated retrieved documents as "result"
37 | """
38 | query = qa_pair["question"]
39 | docs_retrieved = await retriever.aget_relevant_documents(query)
40 |
41 | retrieved_doc_text = "\n\n".join(
42 | f"Retrieved document {i}: {doc.page_content}"
43 | for i, doc in enumerate(docs_retrieved)
44 | )
45 | retrieved_dict = {
46 | "question": qa_pair["question"],
47 | "answer": qa_pair["answer"],
48 | "result": retrieved_doc_text,
49 | }
50 |
51 | return retrieved_dict
52 |
53 |
54 | def process_retrieved_docs(qa_results: list[dict], hp_id: int) -> None:
55 | """For each list of Documents retrieved in QA_LLM call, we merge the documents page contents into a string. We store the id of the hyperparameters for later.
56 |
57 | Args:
58 |         qa_results (list[dict]): The QA results returned by the QA LLM calls
59 |         hp_id (int): The id of the hyperparameter run
60 | """
61 |
62 | # Each qa_result is a dict containing in "source_documents" the list of retrieved documents used to answer the query
63 | for qa_result in qa_results:
64 | retrieved_docs_str = "\n\n---------------------------\n".join(
65 | f"///Retrieved document {i}: {clean_page_content(doc.page_content)}\n---Metadata: {doc.metadata}"
66 | for i, doc in enumerate(qa_result["source_documents"])
67 | )
68 |
69 | # key now stores the concatenated string of the page contents
70 | qa_result["retrieved_docs"] = retrieved_docs_str
71 |
72 | # id of hyperparameters needed later
73 | qa_result["hp_id"] = hp_id
74 |
75 |
76 | def convert_element_to_df(element: dict):
77 | """Function to convert each tuple element of result to dataframe."""
78 | df = pd.DataFrame(element)
79 | df = pd.concat([df, pd.json_normalize(df["metadata"])], axis=1)
80 | df = df.drop(columns=["metadata", "question", "answer", "source"])
81 | df = df.rename(columns={"result": "predicted_answer", "id": "qa_id"})
82 | df["hp_id"] = df["hp_id"].astype(int)
83 |
84 | return df
85 |
86 |
87 | def clean_page_content(page_content: str):
88 | """Helper function to remove paragraph breaks."""
89 | return re.sub("\n+", "\n", page_content)
90 |
91 |
92 | def write_generated_data_to_csv(
93 | predicted_answers: list[dict],
94 | hyperparameters_results_data_path: str,
95 | ) -> None:
96 | """Write predicted answers and retrieved documents to csv."""
97 | # Create the dataframe
98 | df = pd.concat(
99 | [convert_element_to_df(element) for element in predicted_answers], axis=0
100 | )
101 |
102 | if os.path.exists(hyperparameters_results_data_path):
103 | base_df = pd.read_csv(hyperparameters_results_data_path)
104 | else:
105 | base_df = pd.DataFrame()
106 |
107 | # Concatenate the DataFrames along rows while preserving shared columns
108 | result = pd.concat(
109 | [
110 | base_df,
111 | df[["hp_id", "predicted_answer", "retrieved_docs", "qa_id"]],
112 | ],
113 | axis=0,
114 | )
115 |
116 | # Write df to csv
117 | result.sort_values(by=["hp_id"]).to_csv(
118 | hyperparameters_results_data_path, index=False
119 | )
120 |
--------------------------------------------------------------------------------
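A short sketch of how extract_llm_metric parses the plain-text grader output whose format is fixed by the grading prompts; the response string below is hypothetical:

```python
from ragflow.evaluation.utils import extract_llm_metric

response = "CORRECTNESS: 3\nCOMPREHENSIVENESS: 2\nREADABILITY: 3"

print(extract_llm_metric(response, "CORRECTNESS"))        # 3
print(extract_llm_metric(response, "COMPREHENSIVENESS"))  # 2

# a malformed response yields np.nan, which np.nanmean later ignores
print(extract_llm_metric("no scores here", "READABILITY"))  # nan
```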
/ragflow/example.py:
--------------------------------------------------------------------------------
1 | import dotenv
2 | import logging
3 | import asyncio
4 |
5 | from ragflow.evaluation import arun_evaluation
6 | from ragflow.generation import agenerate_evaluation_set
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 | dotenv.load_dotenv(dotenv.find_dotenv(), override=True)
11 |
12 |
13 | async def main(
14 | document_store_path: str = "./resources/dev/document_store/",
15 | label_dataset_path: str = "./resources/dev/label_dataset.json",
16 | label_dataset_gen_params_path: str = "./resources/dev/label_dataset_gen_params.json",
17 | hyperparameters_path: str = "./resources/dev/hyperparameters.json",
18 | hyperparameters_results_path: str = "./resources/dev/hyperparameters_results.json",
19 | hyperparameters_results_data_path: str = "./resources/dev/hyperparameters_results_data.csv",
20 | user_id: str = "336e131c-d9f3-4185-a402-f5f8875e34f0",
21 | ):
22 | """After generating the label_dataset we need to calculate the metrics based on chunking strategy, type of vectorstore, retriever (similarity search) and QA LLM.
23 |
24 | Dependencies of parameters and build order:
25 | generating chunks | depends on chunking strategy
26 | |__ vector database | depends on type of DB
27 | |__ retriever | depends on Embedding model
28 | |__ QA LLM | depends on LLM and digests retriever
29 |
30 | |__ grader: llm who grades generated answers
31 |
32 | TODO 1: State of the art retrieval with lots of additional LLM calls:
33 | - use "MMR" to filter out similar document chunks during similarity search
34 | - Use SelfQueryRetriever to find information in query about specific context, e.g. querying specific document only
35 | - Use ContextualCompressionRetriever to compress or summarize the document chunks before serving them to QA-LLM
36 |
37 | - retriever object would look something like this, additional metadata info for SelfQueryRetriever has to be provided as well
38 | retriever = ContextualCompressionRetriever(
39 | base_compressor=LLMChainExtractor.from_llm(llm),
40 | base_retriever=SelfQueryRetriever.from_llm(..., search_type="mmr")
41 | )
42 |
43 | https://learn.deeplearning.ai/langchain-chat-with-your-data/lesson/5/retrieval
44 | https://learn.deeplearning.ai/langchain-chat-with-your-data/lesson/7/chat
45 |
46 | TODO 2: Integrate Anyscale (Llama2), MosaicML (MPT + EMBDS), Replicate
47 |
48 |     Args:
49 |         document_store_path (str, optional): Path to the documents. Defaults to "./resources/dev/document_store/".
50 |         label_dataset_path (str, optional): Path to the evaluation ground truth dataset. Defaults to "./resources/dev/label_dataset.json".
51 |         label_dataset_gen_params_path (str, optional): Path to the dataset generation parameters. Defaults to "./resources/dev/label_dataset_gen_params.json".
52 |         hyperparameters_path (str, optional): Path to the hyperparameters. Defaults to "./resources/dev/hyperparameters.json".
53 |         hyperparameters_results_path (str, optional): Path to the JSON file for the results. Defaults to "./resources/dev/hyperparameters_results.json".
54 |         hyperparameters_results_data_path (str, optional): Path to the CSV file for additional evaluation data. Defaults to "./resources/dev/hyperparameters_results_data.csv".
55 |         user_id (str, optional): The user id.
56 |     """
57 |
58 | # provide API keys here if necessary
59 | import os
60 |
61 | api_keys = {"OPENAI_API_KEY": os.environ.get("OPENAI_API_KEY")}
62 |
63 | ################################################################
64 | # First phase: Loading or generating evaluation dataset
65 | ################################################################
66 |
67 | logger.info("Checking for evaluation dataset configs.")
68 |
69 | await agenerate_evaluation_set(
70 | label_dataset_gen_params_path,
71 | label_dataset_path,
72 | document_store_path,
73 | user_id,
74 | api_keys,
75 | )
76 |
77 | ################################################################
78 | # Second phase: Running evaluations
79 | ################################################################
80 |
81 | logger.info("Starting evaluation for all provided hyperparameters.")
82 |
83 | await arun_evaluation(
84 | document_store_path,
85 | hyperparameters_path,
86 | label_dataset_path,
87 | hyperparameters_results_path,
88 | hyperparameters_results_data_path,
89 | user_id,
90 | api_keys,
91 | )
92 |
93 |
94 | if __name__ == "__main__":
95 | asyncio.run(main())
96 |
--------------------------------------------------------------------------------
/ragflow/generation/__init__.py:
--------------------------------------------------------------------------------
1 | from ragflow.generation.label_dataset_generator import agenerate_evaluation_set
2 |
--------------------------------------------------------------------------------
/ragflow/generation/label_dataset_generator.py:
--------------------------------------------------------------------------------
1 | from json import JSONDecodeError
2 | import itertools
3 | import asyncio
4 | import glob
5 | import os
6 | from datetime import datetime
7 | from tqdm.asyncio import tqdm as tqdm_asyncio
8 |
9 | from langchain.chains import QAGenerationChain
10 | from langchain.docstore.document import Document
11 | from langchain.schema.embeddings import Embeddings
12 |
13 | from ragflow.commons.prompts import QA_GENERATION_PROMPT_SELECTOR
14 | from ragflow.utils import aload_and_chunk_docs, write_json, read_json
15 | from ragflow.commons.configurations import QAConfigurations
16 | from ragflow.commons.chroma import ChromaClient
17 |
18 | import uuid
19 | import logging
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | async def get_qa_from_chunk(
25 | chunk: Document,
26 | qa_generator_chain: QAGenerationChain,
27 | ) -> list[dict]:
28 | """Generate QA from provided text document chunk."""
29 | try:
30 | # return list of qa pairs
31 | qa_pairs = qa_generator_chain.run(chunk.page_content)
32 |
33 | # attach chunk metadata to qa_pair
34 | for qa_pair in qa_pairs:
35 | qa_pair["metadata"] = dict(**chunk.metadata)
36 | qa_pair["metadata"].update(
37 | {"id": str(uuid.uuid4()), "context": chunk.page_content}
38 | )
39 |
40 | return qa_pairs
41 | except JSONDecodeError:
42 | return []
43 |
44 |
45 | async def agenerate_label_dataset_from_doc(
46 | hp: QAConfigurations,
47 | doc_path: str,
48 | ) -> list[dict[str, str]]:
49 | """Generate pairs of QAs that are used as ground truth in downstream tasks, i.e. RAG evaluations.
50 |
51 |     Args:
52 |         hp (QAConfigurations): the configuration providing the LLM used in the QAGenerationChain
53 |         doc_path (str): the path of the document used for QA generation
54 |
55 |     Returns:
56 |         list[dict[str, str]]: a list of dictionaries of question-answer pairs
57 | """
58 |
59 | logger.debug(f"Starting QA generation process for {doc_path}.")
60 |
61 | # load data and chunk doc
62 | chunks = await aload_and_chunk_docs(hp, [doc_path])
63 |
64 | qa_generator_chain = QAGenerationChain.from_llm(
65 | hp.qa_generator_llm,
66 | prompt=QA_GENERATION_PROMPT_SELECTOR.get_prompt(hp.qa_generator_llm),
67 | )
68 |
69 | tasks = [get_qa_from_chunk(chunk, qa_generator_chain) for chunk in chunks]
70 |
71 | qa_pairs = await asyncio.gather(*tasks)
72 | qa_pairs = list(itertools.chain.from_iterable(qa_pairs))
73 |
74 | return qa_pairs
75 |
76 |
77 | async def agenerate_label_dataset_from_docs(
78 | hp: QAConfigurations,
79 | docs_path: list[str],
80 | ) -> list[dict]:
81 | """Asynchronous wrapper around the agenerate_label_dataset function.
82 |
83 | Args:
84 |         hp (QAConfigurations): the QA generation configuration
85 |         docs_path (list[str]): the paths of the documents to process
86 |
87 |     Returns:
88 |         list[dict]: the combined list of generated QA pairs
89 | """
90 | tasks = [agenerate_label_dataset_from_doc(hp, doc_path) for doc_path in docs_path]
91 |
92 | results = await tqdm_asyncio.gather(*tasks)
93 |
94 | qa_pairs = list(itertools.chain.from_iterable(results))
95 |
96 | return qa_pairs
97 |
98 |
99 | async def aupsert_embeddings_for_model(
100 | qa_pairs: list[dict],
101 | embedding_model: Embeddings,
102 | user_id: str,
103 | ) -> None:
104 | """Embeds and upserts each generated answer into the vectorstore. This is helpful if you run multiple hyperparameter evaluations with the same embedding model, because the answers only have to be embedded once. The embeddings are used during evaluation to check the similarity of generated and predicted answers."""
105 | with ChromaClient() as CHROMA_CLIENT:
106 | collection_name = f"userid_{user_id}_qaid_0_{QAConfigurations.get_embedding_model_name(embedding_model)}"
107 |
108 | # check if collection already exists, if not create a new one with the embeddings
109 | if [
110 | collection
111 | for collection in CHROMA_CLIENT.list_collections()
112 |             if collection.name == collection_name
113 | ]:
114 | logger.info(f"Collection {collection_name} already exists, skipping it.")
115 | return None
116 |
117 | collection = CHROMA_CLIENT.create_collection(
118 | name=collection_name,
119 | metadata={
120 | "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3],
121 | },
122 | )
123 |
124 | ids = [qa_pair["metadata"]["id"] for qa_pair in qa_pairs]
125 |
126 | # maybe async function is not implemented for embedding model
127 | try:
128 | embeddings = await embedding_model.aembed_documents(
129 | [qa_pair["answer"] for qa_pair in qa_pairs]
130 | )
131 | except NotImplementedError as ex:
132 | logger.error(
133 | f"Exception during eval set generation and upserting to vectorstore, {ex}"
134 | )
135 | embeddings = embedding_model.embed_documents(
136 | [qa_pair["answer"] for qa_pair in qa_pairs]
137 | )
138 |
139 | collection.upsert(
140 | ids=ids,
141 | embeddings=embeddings,
142 | metadatas=[
143 | {
144 | "question": qa_pair["question"],
145 | "answer": qa_pair["answer"],
146 | **qa_pair["metadata"],
147 | }
148 | for qa_pair in qa_pairs
149 | ],
150 | )
151 |
152 | logger.info(
153 | f"Upserted {QAConfigurations.get_embedding_model_name(embedding_model)} embeddings to vectorstore."
154 | )
155 |
156 |
157 | async def agenerate_and_save_dataset(
158 | hp: QAConfigurations,
159 | docs_path: str,
160 | label_dataset_path: str,
161 | user_id: str,
162 | ):
163 | """Generate a new evaluation dataset and save it to a JSON file."""
164 |
165 | logger.info("Starting QA generation suite.")
166 |
167 | # generate label dataset
168 | label_dataset = await agenerate_label_dataset_from_docs(hp, docs_path)
169 |
170 | # During test execution: if label_dataset is empty because of test dummy LLM, we inject a real dataset for test
171 | if (
172 | os.environ.get("EXECUTION_CONTEXT") == "TEST"
173 | and hp.persist_to_vs
174 | and not label_dataset
175 | ):
176 | label_dataset = read_json(os.environ.get("INPUT_LABEL_DATASET"))
177 |
178 | # write eval dataset to json
179 | write_json(label_dataset, label_dataset_path)
180 |
181 | # cache answers of qa pairs in vectorstore for each embedding model provided
182 | if hp.persist_to_vs:
183 | tasks = [
184 | aupsert_embeddings_for_model(label_dataset, embedding_model, user_id)
185 | for embedding_model in hp.embedding_model_list
186 | ]
187 |
188 | await asyncio.gather(*tasks)
189 |
190 |
191 | async def agenerate_evaluation_set(
192 | label_dataset_gen_params_path: str,
193 | label_dataset_path: str,
194 | document_store_path: str,
195 | user_id: str,
196 | api_keys: dict[str, str],
197 | ):
198 | """Entry function to generate the evaluation dataset.
199 |
200 | Args:
201 |         label_dataset_gen_params_path (str): Path to the dataset generation parameters
202 |         label_dataset_path (str): Path where the generated evaluation dataset gets stored
203 |         document_store_path (str): Path to the documents
204 |         user_id (str): The user id
205 |         api_keys (dict[str, str]): All required API keys for the LLM endpoints
206 | """
207 |
208 | logger.info("Checking for evaluation dataset configs.")
209 |
210 | import nltk
211 | from nltk.data import find
212 |
213 | # Check if 'punkt' is already downloaded, and download if not
214 | try:
215 | find("tokenizers/punkt")
216 | except LookupError:
217 | nltk.download("punkt")
218 |
219 | label_dataset_gen_params = read_json(label_dataset_gen_params_path)
220 |
221 | # TODO: Only one single QA generation supported per user
222 | if isinstance(label_dataset_gen_params, list):
223 | label_dataset_gen_params = label_dataset_gen_params[-1]
224 |
225 | # set up QAConfiguration object at the beginning to evaluate inputs
226 | label_dataset_gen_params = QAConfigurations.from_dict(
227 | label_dataset_gen_params, api_keys
228 | )
229 |
230 | # get list of all documents in document_store_path
231 | document_store = glob.glob(f"{document_store_path}/*")
232 |
233 | with ChromaClient() as client:
234 | # Filter collections specific to the user_id.
235 | user_collections = [
236 | collection
237 | for collection in client.list_collections()
238 | if collection.name.startswith(f"userid_{user_id}_")
239 | ]
240 |
241 | # Check if there are any collections with the QA identifier and delete them.
242 | if any(["_qaid_0_" in collection.name for collection in user_collections]):
243 | for collection in user_collections:
244 | client.delete_collection(name=collection.name)
245 |
246 | # start generation process
247 | await agenerate_and_save_dataset(
248 | label_dataset_gen_params, document_store, label_dataset_path, user_id
249 | )
250 |
--------------------------------------------------------------------------------
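For orientation, the shape of a single record written to label_dataset.json, as assembled in get_qa_from_chunk; all field values below are illustrative placeholders:

```python
qa_pair = {
    "question": "What topics does the report cover?",
    "answer": "It covers the company's results for the fiscal year.",
    "metadata": {
        "id": "a-uuid4-string",  # assigned via uuid.uuid4() at generation time
        "context": "the chunk text the pair was generated from",
        "source": "path/to/source_document.pdf",  # set by the document loader
        "start_index": 0,    # chunk offsets added by the splitter
        "end_index": 4096,
    },
}
```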
/ragflow/requirements.txt:
--------------------------------------------------------------------------------
1 | # Core dependencies for the AI and ML operations
2 | openai==0.28.1
3 | tiktoken==0.5.1
4 | langchain==0.0.330
5 | Cython==3.0.5
6 |
7 | pandas==2.1.2
8 | numpy==1.26.1
9 | scikit-learn==1.3.2
10 |
11 | # PDF and document processing
12 | simsimd==3.5.3
13 | docx2txt==0.8
14 | pypdf==3.17.0
15 | pdf2image==1.16.3
16 | pdfminer.six==20221105
17 |
18 | # OCR and unstructured data handling
19 | unstructured==0.10.27
20 |
21 | # Evaluation metrics
22 | evaluate==0.4.1
23 | rouge_score==0.1.2
24 | absl-py==2.0.0
25 |
26 | # Web framework and API
27 | fastapi==0.104.1
28 | pydantic[email]==2.4.2
29 | uvicorn==0.24.0.post1
30 | python-dotenv==1.0.0
31 | SQLAlchemy==2.0.23
32 | psycopg2-binary==2.9.9
33 | python-jose==3.3.0
34 | passlib==1.7.4
35 | python-multipart==0.0.6
36 |
37 | # Database connector
38 | chromadb==0.4.15
39 |
40 | # Retriever add-ons and external APIs
41 | lark==1.1.8 # selfqueryretriever
42 | cohere==4.32 # for reranker
--------------------------------------------------------------------------------
/ragflow/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from ragflow.utils.utils import (
2 | get_retriever,
3 | get_qa_llm,
4 | read_json,
5 | write_json,
6 | )
7 |
8 | from ragflow.utils.doc_processing import aload_and_chunk_docs, load_and_chunk_doc
9 |
--------------------------------------------------------------------------------
/ragflow/utils/doc_processing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import asyncio
3 | import itertools
4 |
5 | from langchain.docstore.document import Document
6 | from langchain.text_splitter import RecursiveCharacterTextSplitter
7 |
8 | from langchain.document_loaders import (
9 | TextLoader,
10 | Docx2txtLoader,
11 | # PyPDFLoader, # already splits data during loading
12 | UnstructuredPDFLoader, # returns only 1 Document
13 | # PyMuPDFLoader, # returns 1 Document per page
14 | )
15 |
16 | from ragflow.commons.configurations import BaseConfigurations
17 |
18 | import logging
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | def load_document(file: str) -> list[Document]:
24 | """Loads file from given path into a list of Documents, currently pdf, txt and docx are supported.
25 |
26 | Args:
27 | file (str): file path
28 |
29 | Returns:
30 | List[Document]: loaded files as list of Documents
31 | """
32 | _, extension = os.path.splitext(file)
33 |
34 | if extension == ".pdf":
35 | loader = UnstructuredPDFLoader(file)
36 | elif extension == ".txt":
37 | loader = TextLoader(file, encoding="utf-8")
38 | elif extension == ".docx":
39 | loader = Docx2txtLoader(file)
40 | else:
41 | logger.error(f"Unsupported file type detected, {file}!")
42 | raise NotImplementedError(f"Unsupported file type detected, {file}!")
43 |
44 | data = loader.load()
45 | return data
46 |
47 |
48 | def split_data(
49 | data: list[Document],
50 | chunk_size: int = 4096,
51 | chunk_overlap: int = 0,
52 | length_function: callable = len,
53 | ) -> list[Document]:
54 | """Function for splitting the provided data, i.e. List of documents loaded.
55 |
56 | Args:
57 |         data (list[Document]): the loaded documents to split
58 |         chunk_size (int, optional): the maximum chunk size. Defaults to 4096.
59 |         chunk_overlap (int, optional): the overlap between consecutive chunks. Defaults to 0.
60 |         length_function (callable, optional): the function used to measure chunk length. Defaults to len.
61 |
62 |     Returns:
63 |         List[Document]: the split document chunks
64 | """
65 |
66 | text_splitter = RecursiveCharacterTextSplitter(
67 | chunk_size=chunk_size,
68 | chunk_overlap=chunk_overlap,
69 | length_function=length_function,
70 | add_start_index=True,
71 | )
72 |
73 | chunks = text_splitter.split_documents(data)
74 |
75 | # add end_index
76 | for chunk in chunks:
77 | chunk.metadata["end_index"] = chunk.metadata["start_index"] + len(
78 | chunk.page_content
79 | )
80 |
81 | return chunks
82 |
83 |
84 | def load_and_chunk_doc(
85 | hp: BaseConfigurations,
86 | file: str,
87 | ) -> list[Document]:
88 | """Wrapper function for loading and splitting document in one call."""
89 |
90 | logger.debug(f"Loading and splitting file {file}.")
91 |
92 | data = load_document(file)
93 | chunks = split_data(data, hp.chunk_size, hp.chunk_overlap, hp.length_function)
94 | return chunks
95 |
96 |
97 | # TODO: Check for better async implementation!
98 | async def aload_and_chunk_docs(
99 | hp: BaseConfigurations, files: list[str]
100 | ) -> list[Document]:
101 | """Async implementation of load_and_chunk_doc function."""
102 | loop = asyncio.get_event_loop()
103 |
104 | futures = [
105 | loop.run_in_executor(None, load_and_chunk_doc, hp, file) for file in files
106 | ]
107 |
108 | results = await asyncio.gather(*futures)
109 | chunks = list(itertools.chain.from_iterable(results))
110 |
111 | return chunks
112 |
--------------------------------------------------------------------------------
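A hedged usage sketch of the loading and splitting helpers above; the file path is hypothetical and must point to a .pdf, .txt or .docx file:

```python
from ragflow.utils.doc_processing import load_document, split_data

# load a document (hypothetical path) and split it into overlapping chunks
data = load_document("path/to/some_document.txt")
chunks = split_data(data, chunk_size=512, chunk_overlap=64)

# every chunk carries start_index (set by the splitter) and end_index (added above)
first = chunks[0]
print(first.metadata["start_index"], first.metadata["end_index"])
```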
/ragflow/utils/hyperparam_chats.py:
--------------------------------------------------------------------------------
1 | import json
2 | import asyncio
3 | from uuid import UUID
4 | from langchain.schema.messages import BaseMessage
5 | import pandas as pd
6 |
7 | from langchain.vectorstores.chroma import Chroma
8 | from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
9 | from langchain.chains.llm import LLMChain
10 | from langchain.chains.question_answering import load_qa_chain
11 | from langchain.chains import ConversationalRetrievalChain
12 | from langchain.retrievers.multi_query import MultiQueryRetriever
13 | from langchain.memory import ConversationBufferWindowMemory
14 | from langchain.schema import LLMResult, Document
15 | from langchain.chat_models.base import BaseChatModel
16 |
17 | from langchain.chains.conversational_retrieval.prompts import (
18 | CONDENSE_QUESTION_PROMPT,
19 | )
20 |
21 | from ragflow.commons.chroma import ChromaClient
22 | from ragflow.commons.configurations import Hyperparameters
23 | from ragflow.commons.prompts import QA_ANSWER_PROMPT
24 |
25 | from typing import Any, Dict, List, Optional, Sequence
26 |
27 | import logging
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 | chats_cache: Dict[str, Dict[int, ConversationalRetrievalChain]] = {}
32 |
33 |
34 | class AsyncCallbackHandler(AsyncIteratorCallbackHandler):
35 | # content: str = ""
36 | # final_answer: bool = True
37 |
38 | def __init__(self) -> None:
39 | super().__init__()
40 | self.source_documents = None
41 |
42 | async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
43 | self.queue.put_nowait(token)
44 |
45 | async def on_chat_model_start(
46 | self,
47 | serialized: Dict[str, Any],
48 | messages: List[List[BaseMessage]],
49 | *,
50 | run_id: UUID,
51 | parent_run_id: UUID | None = None,
52 | tags: List[str] | None = None,
53 | metadata: Dict[str, Any] | None = None,
54 | **kwargs: Any,
55 | ) -> Any:
56 | pass
57 |
58 | async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
59 | # if source documents were extracted
60 | if self.source_documents:
61 | self.queue.put_nowait(self.source_documents)
62 |
63 | self.source_documents = None
64 | self.done.set()
65 |
66 |
67 | class RetrieverCallbackHandler(AsyncIteratorCallbackHandler):
68 | def __init__(self, streaming_callback_handler: AsyncCallbackHandler) -> None:
69 | super().__init__()
70 | self.streaming_callback_handler = streaming_callback_handler
71 |
72 | async def on_retriever_end(
73 | self, source_docs, *, run_id, parent_run_id, tags, **kwargs
74 | ):
75 | source_docs_list = (
76 | [
77 | {"page_content": doc.page_content, "metadata": doc.metadata}
78 | for doc in source_docs
79 | ]
80 | if source_docs
81 | else None
82 | )
83 |
84 | self.streaming_callback_handler.source_documents = json.dumps(
85 | {"source_documents": source_docs_list}
86 | )
87 |
88 |
89 | async def aquery_chat(
90 | hp_id: int,
91 | hyperparameters_results_path: str,
92 | user_id: str,
93 | api_keys: dict,
94 | query: str,
95 | stream_it: AsyncCallbackHandler,
96 | ):
97 | llm = getOrCreateChatModel(
98 | hp_id, hyperparameters_results_path, user_id, api_keys, stream_it
99 | )
100 |
101 | await llm.acall(
102 | query,
103 | callbacks=[RetrieverCallbackHandler(streaming_callback_handler=stream_it)],
104 | )
105 |
106 |
107 | async def create_gen(
108 | hp_id: int,
109 | hyperparameters_results_path: str,
110 | user_id: str,
111 | api_keys: dict,
112 | query: str,
113 | stream_it: AsyncCallbackHandler,
114 | ):
115 | task = asyncio.create_task(
116 | aquery_chat(
117 | hp_id, hyperparameters_results_path, user_id, api_keys, query, stream_it
118 | )
119 | )
120 | async for token in stream_it.aiter():
121 | yield token
122 | await task
123 |
124 |
125 | def query_chat(
126 | hp_id: int,
127 | hyperparameters_results_path: str,
128 | user_id: str,
129 | api_keys: dict,
130 | query: str,
131 | ) -> dict:
132 | # get or load llm model
133 | llm = getOrCreateChatModel(hp_id, hyperparameters_results_path, user_id, api_keys)
134 |
135 | return llm(query)
136 |
137 |
138 | async def get_docs(
139 | hp_id: int,
140 | hyperparameters_results_path: str,
141 | user_id: str,
142 | api_keys: dict,
143 | query: str,
144 | ) -> List[Document]:
145 | # get or load llm model
146 | llm = getOrCreateChatModel(hp_id, hyperparameters_results_path, user_id, api_keys)
147 |
148 | return await llm.retriever.retriever.aget_relevant_documents(query)
149 |
150 |
151 | def getOrCreateChatModel(
152 | hp_id: int,
153 | hyperparameters_results_path: str,
154 | user_id: str,
155 | api_keys: dict,
156 | streaming_callback: Optional[AsyncCallbackHandler] = None,
157 | ) -> ConversationalRetrievalChain:
158 | # if model has not been loaded yet
159 | if (
160 | user_id not in chats_cache
161 | or hp_id not in chats_cache[user_id]
162 | or not isinstance(chats_cache[user_id][hp_id], ConversationalRetrievalChain)
163 | ):
164 | # load hyperparameter results
165 | with open(hyperparameters_results_path, encoding="utf-8") as file:
166 | hp_data = json.load(file)
167 |
168 | df = pd.DataFrame(hp_data)
169 |
170 | # check that hp_id really exists in results
171 | if hp_id not in df.id.values:
172 |             raise ValueError("Could not find requested hyperparameter run id.")
173 |
174 |         with ChromaClient() as client:
175 |             # default to None so the missing-collection check below fails cleanly instead of raising a NameError
176 |             collection = next(
177 |                 (col for col in client.list_collections()
178 |                  if col.name == f"userid_{user_id}_hpid_{hp_id}"), None)
179 |
180 | # check that vectorstore contains collection for hp id
181 | if not collection:
182 |                 raise ValueError(
183 |                     "Could not find data in vectorstore for requested hyperparameter run id."
184 |                 )
185 |
186 | # create retriever and llm from collection
187 | hp_data = df[df.id == hp_id].iloc[0].to_dict()
188 | for key in ["id", "scores", "timestamp"]:
189 | hp_data.pop(key)
190 |
191 | hp = Hyperparameters.from_dict(
192 | input_dict=hp_data,
193 | hp_id=hp_id,
194 | api_keys=api_keys,
195 | )
196 |
197 | index = Chroma(
198 | client=ChromaClient().get_client(),
199 | collection_name=collection.name,
200 | collection_metadata=collection.metadata,
201 | embedding_function=hp.embedding_model,
202 | )
203 |
204 | # baseline retriever built from vectorstore collection
205 | retriever = index.as_retriever(search_type="similarity", search_kwargs={"k": 1})
206 |
207 | # streaming and non-streaming models, create new instance for streaming model
208 | streaming_llm = Hyperparameters.get_language_model(
209 | model_name=Hyperparameters.get_language_model_name(hp.qa_llm),
210 | api_keys=api_keys,
211 | )
212 |
213 | if isinstance(streaming_llm, BaseChatModel):
214 | streaming_llm.streaming = True
215 | streaming_llm.callbacks = [streaming_callback]
216 |
217 | # llm model from hp for non streaming chains
218 | non_streaming_llm = hp.qa_llm
219 |
220 | # LLM chain for generating new question from user query and chat history
221 | question_generator = LLMChain(
222 | llm=non_streaming_llm, prompt=CONDENSE_QUESTION_PROMPT
223 | )
224 |
225 | # llm that answers the newly generated condensed question
226 | doc_qa_chain = load_qa_chain(
227 | streaming_llm, chain_type="stuff", prompt=QA_ANSWER_PROMPT
228 | )
229 |
230 | # advanced retriever using multiple similar queries
231 | multi_query_retriever = MultiQueryRetriever.from_llm(
232 | retriever=retriever, llm=non_streaming_llm, include_original=True
233 | )
234 |
235 | # memory object to store the chat history
236 | memory = ConversationBufferWindowMemory(
237 | memory_key="chat_history", k=5, return_messages=True, output_key="result"
238 | )
239 |
240 | qa_llm = ConversationalRetrievalChain(
241 | retriever=multi_query_retriever,
242 | combine_docs_chain=doc_qa_chain,
243 | question_generator=question_generator,
244 | memory=memory,
245 | output_key="result",
246 | return_source_documents=True,
247 | return_generated_question=True,
248 | response_if_no_docs_found="I don't know the answer to this question.",
249 | )
250 |
251 | # cache llm
252 | if user_id not in chats_cache:
253 | chats_cache[user_id] = {}
254 |
255 | chats_cache[user_id][hp_id] = qa_llm
256 |
257 | logger.info(
258 | f"\nCreated new ConversationalRetrievalChain for {user_id}:{hp_id}."
259 | )
260 |
261 |     else:
262 |         # model is already loaded, reuse the cached chain
263 |         logger.info(f"\nRetrieved ConversationalRetrievalChain for {user_id}:{hp_id}.")
264 |
265 |     # add new streaming callback to qa llm, if one was provided
266 |     qa_llm = chats_cache[user_id][hp_id]
267 |     if streaming_callback is not None:
268 |         qa_llm.combine_docs_chain.llm_chain.llm.callbacks = [streaming_callback]
269 | 
270 |     return qa_llm
271 | 
--------------------------------------------------------------------------------
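
Note (usage sketch): one way the streaming pieces above could be wired into a FastAPI
endpoint. The route path and request fields are assumptions for illustration; only
AsyncCallbackHandler and create_gen come from the file itself.

    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse

    from ragflow.utils.hyperparam_chats import AsyncCallbackHandler, create_gen

    app = FastAPI()

    @app.post("/gens/stream")  # hypothetical route
    async def stream_chat(req: dict):
        # one handler per request so tokens of concurrent chats don't mix
        stream_it = AsyncCallbackHandler()
        gen = create_gen(
            hp_id=req["hp_id"],
            hyperparameters_results_path=req["hyperparameters_results_path"],
            user_id=req["user_id"],
            api_keys=req.get("api_keys", {}),
            query=req["query"],
            stream_it=stream_it,
        )
        # create_gen yields LLM tokens first and the source-documents JSON last
        return StreamingResponse(gen, media_type="text/event-stream")
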
/ragflow/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from enum import Enum
4 | from datetime import datetime
5 |
6 | from langchain.chains import RetrievalQA
7 | from langchain.docstore.document import Document
8 |
9 | from langchain.schema.language_model import BaseLanguageModel
10 | from langchain.schema.vectorstore import VectorStoreRetriever
11 |
12 | # vector db
13 | from langchain.vectorstores.chroma import Chroma
14 |
15 | from ragflow.commons.prompts import QA_ANSWER_PROMPT
16 | from ragflow.commons.configurations import Hyperparameters
17 | from ragflow.commons.chroma import ChromaClient
18 |
19 | from typing import Any, Optional, Union
20 |
21 | import json
22 | import logging
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | def get_retriever(
28 | chunks: list[Document], hp: Hyperparameters, user_id: str, for_eval: bool = False
29 | ) -> Union[tuple[VectorStoreRetriever, VectorStoreRetriever], VectorStoreRetriever]:
30 | """Sets up a vector database based on the document chunks and the embedding model provided.
31 | Here we use Chroma for the vectorstore.
32 |
33 | Args:
34 | chunks (list[Document]): _description_
35 | hp (Hyperparameters): _description_
36 | user_id (str): _description_
37 | for_eval (bool): for evaluation we return a second retriever with k=10
38 |
39 | Returns:
40 | tuple[VectorStoreRetriever, VectorStoreRetriever]: Returns two retrievers, one for predicting the
41 | answers and one for grading the retrieval where we need to have the top 10 documents returned
42 | """
43 | logger.info("Constructing vectorstore and retriever.")
44 |
45 | vectorstore = Chroma.from_documents(
46 | documents=chunks,
47 | embedding=hp.embedding_model,
48 | client=ChromaClient().get_client(),
49 | collection_name=f"userid_{user_id}_hpid_{hp.id}",
50 | collection_metadata={
51 | "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3],
52 | "hnsw:space": hp.similarity_method.value,
53 | },
54 | )
55 |
56 | retriever = vectorstore.as_retriever(
57 | search_type=hp.search_type.value, search_kwargs={"k": hp.num_retrieved_docs}
58 | )
59 |
60 | # if we are evaluating the semantic accuracy of the retriever as well, we have to return a second version of the retriever with different settings
61 | if not for_eval:
62 | return retriever
63 |
64 | retrieverForGrading = vectorstore.as_retriever(
65 | search_type=hp.search_type.value, search_kwargs={"k": 10}
66 | )
67 |
68 | return retriever, retrieverForGrading
69 |
70 |
71 | def get_qa_llm(
72 | retriever: VectorStoreRetriever,
73 | qa_llm: BaseLanguageModel,
74 | return_source_documents: Optional[bool] = True,
75 | ) -> RetrievalQA:
76 | """Sets up a LangChain RetrievalQA model based on a retriever and language model that answers
77 | queries based on retrieved document chunks.
78 |
79 | Args:
80 | qa_llm (Optional[BaseLanguageModel], optional): language model.
81 |
82 | Returns:
83 | RetrievalQA: RetrievalQA object
84 | """
85 | logger.debug("Setting up QA LLM with provided retriever.")
86 |
87 | qa_llm_r = RetrievalQA.from_chain_type(
88 | llm=qa_llm,
89 | chain_type="stuff",
90 | retriever=retriever,
91 | chain_type_kwargs={"prompt": QA_ANSWER_PROMPT},
92 | input_key="question",
93 | return_source_documents=return_source_documents,
94 | )
95 |
96 | return qa_llm_r
97 |
98 |
99 | def read_json(filename: str) -> Any:
100 | """Load dataset from a JSON file."""
101 |
102 | with open(filename, "r", encoding="utf-8") as file:
103 | return json.load(file)
104 |
105 |
106 | def write_json(data: list[dict], filename: str, append: Optional[bool] = True) -> None:
107 |     """Function used to store generated QA pairs, i.e. the ground truth.
108 | 
109 |     Args:
110 |         data (list[dict]): the QA pairs to store
111 |         filename (str): path of the JSON file to write to
112 |         append (Optional[bool], optional): if True, extend an existing file. Defaults to True.
113 |     """
114 |     logger.info(f"Writing JSON to {filename}.")
115 |
116 | # Check if file exists
117 | if os.path.exists(filename) and append:
118 | # File exists, read the data
119 | with open(filename, "r", encoding="utf-8") as file:
120 |             json_data = json.load(file)
121 |             # assume the stored data is a list and extend it with the new entries
122 |             json_data.extend(data)
123 | else:
124 | json_data = data
125 |
126 | # Write the combined data back to the file
127 | with open(filename, "w", encoding="utf-8") as file:
128 | json.dump(json_data, file, indent=4, default=convert_to_serializable)
129 |
130 |
131 | # Convert non-serializable types
132 | def convert_to_serializable(obj: object) -> Any:
133 |     """Preprocessing step before writing to a JSON file.
134 | 
135 |     Args:
136 |         obj (object): a non-serializable object encountered by json.dump
137 | 
138 |     Returns:
139 |         Any: a JSON-serializable representation of the object
140 |     """
141 | if isinstance(obj, Enum):
142 | return obj.value
143 | elif isinstance(obj, np.ndarray):
144 | return obj.tolist()
145 | elif isinstance(obj, set):
146 | return list(obj)
147 |     elif callable(obj):  # for functions, lambdas and similar callables
148 | return str(obj)
149 |     elif isinstance(obj, type):  # for classes
150 | return str(obj)
151 | return f"WARNING: Type {type(obj).__name__} not serializable!"
152 |
--------------------------------------------------------------------------------
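
Note (usage sketch): a small round trip through the JSON helpers above; the path and
payload are made up for illustration. numpy arrays pass through convert_to_serializable,
so the calls below should not raise.

    import numpy as np

    from ragflow.utils.utils import read_json, write_json

    write_json([{"scores": np.array([0.1, 0.2])}], "/tmp/example.json", append=False)
    write_json([{"scores": np.array([0.3])}], "/tmp/example.json")  # append extends the list
    print(read_json("/tmp/example.json"))
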
/resources/dev/document_store/msg_life-gb-2021-EN_final_1-15.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/resources/dev/document_store/msg_life-gb-2021-EN_final_1-15.pdf
--------------------------------------------------------------------------------
/resources/dev/document_store/msg_life-gb-2021-EN_final_16-30.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/resources/dev/document_store/msg_life-gb-2021-EN_final_16-30.pdf
--------------------------------------------------------------------------------
/resources/dev/document_store/msg_life-gb-2021-EN_final_31-45.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/resources/dev/document_store/msg_life-gb-2021-EN_final_31-45.pdf
--------------------------------------------------------------------------------
/resources/dev/document_store/msg_life-gb-2021-EN_final_46-59.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/resources/dev/document_store/msg_life-gb-2021-EN_final_46-59.pdf
--------------------------------------------------------------------------------
/resources/dev/document_store/msg_life-gb-2021-EN_final_60-end.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/resources/dev/document_store/msg_life-gb-2021-EN_final_60-end.pdf
--------------------------------------------------------------------------------
/resources/dev/hyperparameters.json:
--------------------------------------------------------------------------------
1 | []
--------------------------------------------------------------------------------
/resources/dev/hyperparameters_results.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "chunk_size": 512,
4 | "chunk_overlap": 10,
5 | "length_function_name": "len",
6 | "id": 0,
7 | "num_retrieved_docs": 3,
8 | "search_type": "similarity",
9 | "similarity_method": "cosine",
10 | "use_llm_grader": false,
11 | "qa_llm": "gpt-3.5-turbo",
12 | "embedding_model": "text-embedding-ada-002",
13 | "timestamp": "2023-11-08 20:14:36,501",
14 | "scores": {
15 | "answer_similarity_score": 0.9455289276479025,
16 | "retriever_mrr@3": 0.654818325434439,
17 | "retriever_mrr@5": 0.6714060031595576,
18 | "retriever_mrr@10": 0.680941096817874,
19 | "rouge1": 0.5848481755355666,
20 | "rouge2": 0.4622595944607476,
21 | "rougeLCS": 0.524537744702827,
22 | "correctness_score": -1,
23 | "comprehensiveness_score": -1,
24 | "readability_score": -1,
25 | "retriever_semantic_accuracy": -1
26 | }
27 | }
28 | ]
--------------------------------------------------------------------------------
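
Note (usage sketch): how a results file like the one above could be ranked with pandas;
the column names follow from the JSON structure, the path is an assumption.

    import json

    import pandas as pd

    with open("resources/dev/hyperparameters_results.json", encoding="utf-8") as f:
        df = pd.json_normalize(json.load(f))

    # higher answer similarity is better; -1 marks scores the optional llm grader skipped
    ranked = df.sort_values("scores.answer_similarity_score", ascending=False)
    print(ranked[["id", "qa_llm", "embedding_model", "scores.answer_similarity_score"]])
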
/resources/dev/label_dataset_gen_params.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "chunk_size": 256,
4 | "chunk_overlap": 10,
5 | "length_function_name": "text-embedding-ada-002",
6 | "qa_generator_llm": "gpt-3.5-turbo",
7 | "persist_to_vs": true,
8 | "embedding_model_list": [
9 | "text-embedding-ada-002"
10 | ]
11 | }
12 | ]
--------------------------------------------------------------------------------
/resources/tests/document_store/churchill_speech.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/resources/tests/document_store/churchill_speech.docx
--------------------------------------------------------------------------------
/resources/tests/document_store/churchill_speech.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AndreasX42/RAGflow/cec31b91af8ab7faa05a3d0376338c7d294d29aa/resources/tests/document_store/churchill_speech.pdf
--------------------------------------------------------------------------------
/resources/tests/input_hyperparameters.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "chunk_size": 256,
4 | "chunk_overlap": 0,
5 | "length_function_name": "len",
6 | "num_retrieved_docs": 1,
7 | "similarity_method": "l2",
8 | "search_type": "similarity",
9 | "embedding_model": "TestDummyEmbedding",
10 | "qa_llm": "TestDummyLLM",
11 | "use_llm_grader": false
12 | },
13 | {
14 | "chunk_size": 512,
15 | "chunk_overlap": 10,
16 | "length_function_name": "len",
17 | "num_retrieved_docs": 2,
18 | "similarity_method": "ip",
19 | "search_type": "mmr",
20 | "embedding_model": "TestDummyEmbedding",
21 | "qa_llm": "TestDummyLLM",
22 | "use_llm_grader": false
23 | },
24 | {
25 | "chunk_size": 512,
26 | "chunk_overlap": 10,
27 | "length_function_name": "len",
28 | "num_retrieved_docs": 3,
29 | "similarity_method": "cosine",
30 | "search_type": "similarity",
31 | "embedding_model": "TestDummyEmbedding",
32 | "qa_llm": "TestDummyLLM",
33 | "use_llm_grader": false
34 | }
35 | ]
--------------------------------------------------------------------------------
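
Note (usage sketch): one of the test configurations above could be loaded into a
Hyperparameters object via the from_dict signature used in hyperparam_chats.py; the
path and the empty api_keys dict are assumptions (the dummy test models need no keys).

    import json

    from ragflow.commons.configurations import Hyperparameters

    with open("resources/tests/input_hyperparameters.json", encoding="utf-8") as f:
        hp = Hyperparameters.from_dict(input_dict=json.load(f)[0], hp_id=0, api_keys={})
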
/resources/tests/input_label_dataset_gen_params_with_upsert.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "chunk_size": 512,
4 | "chunk_overlap": 10,
5 | "length_function_name": "len",
6 | "qa_generator_llm": "TestDummyLLM",
7 | "persist_to_vs": true,
8 | "embedding_model_list": [
9 | "TestDummyEmbedding",
10 | "TestDummyEmbedding"
11 | ]
12 | }
13 | ]
--------------------------------------------------------------------------------
/resources/tests/input_label_dataset_gen_params_without_upsert.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "chunk_size": 512,
4 | "chunk_overlap": 10,
5 | "length_function_name": "len",
6 | "qa_generator_llm": "TestDummyLLM",
7 | "persist_to_vs": false,
8 | "embedding_model_list": [
9 | "TestDummyEmbedding",
10 | "TestDummyEmbedding"
11 | ]
12 | }
13 | ]
--------------------------------------------------------------------------------
/tests/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11
2 |
3 | WORKDIR /tests
4 |
5 | COPY ./requirements.txt ./
6 |
7 | RUN pip install --no-cache-dir --upgrade -r requirements.txt
8 |
9 | COPY ./ ./
10 |
11 | RUN chmod +x wait-for-it.sh
12 |
13 | ENTRYPOINT ["/bin/bash", "-c", "./wait-for-it.sh chromadb-test:8000 --timeout=60 && ./wait-for-it.sh postgres-test:5432 --timeout=60 && ./wait-for-it.sh ragflow-test:8080 --timeout=60 -- pytest"]
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 | logging.basicConfig(
5 | level=os.environ.get("LOG_LEVEL", "INFO"),
6 | format="%(asctime)s - %(levelname)s - %(name)s:%(filename)s:%(lineno)d - %(message)s",
7 | )
8 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import glob
3 | import os
4 |
5 | from tests.utils import HYPERPARAMETERS_RESULTS_PATH
6 |
7 | import logging
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | @pytest.fixture(scope="function", autouse=True)
13 | def cleanup_output_files():
14 |     # setup: nothing to do before the test runs
15 | 
16 |     yield  # run the test
17 | 
18 |     # teardown: delete any generated output files
19 | for file in glob.glob("./resources/output_*"):
20 | try:
21 | os.remove(file)
22 | except Exception as e:
23 | logger.error(f"Error deleting {file}. Error: {e}")
24 |
--------------------------------------------------------------------------------
/tests/requirements.txt:
--------------------------------------------------------------------------------
1 | pytest==7.4.3
2 | requests==2.31.0
3 | chromadb==0.4.15
--------------------------------------------------------------------------------
/tests/test_backend_configs.py:
--------------------------------------------------------------------------------
1 | from tests.utils import RAGFLOW_HOST, RAGFLOW_PORT, fetch_data
2 |
3 | LLM_MODELS = [
4 | "gpt-3.5-turbo",
5 | "gpt-4",
6 | "Llama-2-7b-chat-hf",
7 | "Llama-2-13b-chat-hf",
8 | "Llama-2-70b-chat-hf",
9 | ]
10 |
11 | EMB_MODELS = ["text-embedding-ada-002"]
12 |
13 | CVGradeAnswerPrompt = ["zero_shot", "few_shot", "none"]
14 | CVGradeRetrieverPrompt = ["default", "none"]
15 | CVRetrieverSearchType = ["similarity", "mmr"]
16 | CVSimilarityMethod = ["cosine", "l2", "ip"]
17 |
18 |
19 | def test_ragflow_config_endpoints():
20 | """Fetch data for all endpoints."""
21 | endpoints = {
22 | "/configs/llm_models": LLM_MODELS,
23 | "/configs/embedding_models": EMB_MODELS,
24 | "/configs/retriever_similarity_methods": CVSimilarityMethod,
25 | "/configs/retriever_search_types": CVRetrieverSearchType,
26 | "/configs/grade_answer_prompts": CVGradeAnswerPrompt,
27 | "/configs/grade_documents_prompts": CVGradeRetrieverPrompt,
28 | }
29 |
30 | for endpoint, expected_values in endpoints.items():
31 | response = fetch_data(
32 | method="get", host=RAGFLOW_HOST, port=RAGFLOW_PORT, endpoint=endpoint
33 | ).json()
34 |
35 | # Assert that the fetched data matches the enum values
36 | if isinstance(expected_values, list):
37 | assert set(response) == set(
38 | expected_values
39 | ), f"Mismatch in endpoint {endpoint}"
40 |
41 | else:
42 | raise NotImplementedError(
43 | f"Type {type(expected_values)} not implemented for endpoint {endpoint}."
44 | )
45 |
--------------------------------------------------------------------------------
/tests/test_backend_generator.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import pytest
3 | import json
4 | import chromadb
5 | import logging
6 |
7 | from tests.utils import (
8 | RAGFLOW_HOST,
9 | RAGFLOW_PORT,
10 | CHROMADB_HOST,
11 | CHROMADB_PORT,
12 | LABEL_DATASET_ENDPOINT,
13 | DOCUMENT_STORE_PATH,
14 | LABEL_DATASET_PATH,
15 | fetch_data,
16 | first_user_id,
17 | second_user_id,
18 | user_id_without_upsert,
19 | )
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | def test_generator_without_upsert(user_id_without_upsert):
25 | """Test QA generator without upserting to ChromaDB. Since we use a fake LLM provided by LangChain for testing purposes to mock a real one, we are not able to generate real QA pairs and the generated label_dataset.json file is empty."""
26 |
27 | json_payload = {
28 | "document_store_path": DOCUMENT_STORE_PATH,
29 | "label_dataset_path": "./resources/output_label_dataset_without_upsert.json",
30 | "label_dataset_gen_params_path": "./resources/input_label_dataset_gen_params_without_upsert.json",
31 | "user_id": user_id_without_upsert,
32 | "api_keys": {},
33 | }
34 |
35 | response = fetch_data(
36 | method="post",
37 | host=RAGFLOW_HOST,
38 | port=RAGFLOW_PORT,
39 | endpoint=LABEL_DATASET_ENDPOINT,
40 | payload=json_payload,
41 | )
42 |
43 | assert (
44 | response.status_code == 200
45 | ), f"The response from the {LABEL_DATASET_ENDPOINT} endpoint should be ok."
46 |
47 | with open(json_payload["label_dataset_path"], encoding="utf-8") as file:
48 | data = json.load(file)
49 |
50 | assert (
51 | not data
52 | ), f"Expected no content in {json_payload['label_dataset_path']}, but found it not empty"
53 |
54 |
55 | @pytest.mark.parametrize(
56 | "user_id_fixture, expected_num_collections, expected_error",
57 | [
58 | ("first_user_id", 1, None),
59 | ("second_user_id", 2, None),
60 | ],
61 | )
62 | def test_generator_with_upsert(
63 | user_id_fixture, expected_num_collections, expected_error, request
64 | ):
65 | """Helper function to run generator with upserting twice to test upsert functionality"""
66 |
67 | def get_fixture_value(fixture_name):
68 | """Get the value associated with a fixture name."""
69 | return request.getfixturevalue(fixture_name)
70 |
71 | user_id = get_fixture_value(user_id_fixture)
72 |
73 | json_payload = {
74 | "document_store_path": DOCUMENT_STORE_PATH,
75 | "label_dataset_path": "./resources/output_label_dataset_with_upsert.json",
76 | "label_dataset_gen_params_path": "./resources/input_label_dataset_gen_params_with_upsert.json",
77 | "user_id": user_id,
78 | "api_keys": {},
79 | }
80 |
81 |     # case when an invalid UUID gets passed to the endpoint
82 | if expected_error:
83 | with pytest.raises(
84 | expected_error,
85 | match=f"Unprocessable Entity for url: http://ragflow-test:8080{LABEL_DATASET_ENDPOINT}",
86 | ):
87 | # Assuming the error will arise here when pydantic throws an exception because user id is not valid
88 | response = fetch_data(
89 | method="post",
90 | host=RAGFLOW_HOST,
91 | port=RAGFLOW_PORT,
92 | endpoint=LABEL_DATASET_ENDPOINT,
93 | payload=json_payload,
94 | )
95 |
96 | return
97 |
98 | # case for valid payload data
99 | response = fetch_data(
100 | method="post",
101 | host=RAGFLOW_HOST,
102 | port=RAGFLOW_PORT,
103 | endpoint=LABEL_DATASET_ENDPOINT,
104 | payload=json_payload,
105 | )
106 | assert (
107 | response.status_code == 200
108 | ), "The response from the {LABEL_DATASET_ENDPOINT} endpoint should be ok."
109 |
110 | with open(json_payload["label_dataset_path"], encoding="utf-8") as file:
111 | data = json.load(file)
112 |
113 | assert (
114 | data
115 | ), f"Expected content in {json_payload['label_dataset_path']}, but found it empty or null."
116 | # connect to chromadb
117 | client = chromadb.HttpClient(
118 | host=CHROMADB_HOST,
119 | port=CHROMADB_PORT,
120 | )
121 |
122 | collections_list = client.list_collections()
123 |
124 | assert (
125 | len(collections_list) == expected_num_collections
126 | ), f"Expected {expected_num_collections} collections in ChromaDB but found {len(collections_list)}."
127 |
128 | # get most recently created collection
129 | sorted_collections = sorted(
130 | collections_list, key=lambda col: col.metadata["timestamp"], reverse=True
131 | )
132 |
133 | collection = sorted_collections[0]
134 |
135 |     assert (
136 |         collection.name == f"userid_{user_id}_qaid_0_TestDummyEmbedding"
137 |     ), "Collection name should contain the user id and the embedding model name."
138 | 
139 |     embeddings = collection.get(include=["embeddings"])["embeddings"]
140 | 
141 |     assert len(embeddings) == len(data), "There should be one embedding per QA pair."
142 | 
143 |     assert len(embeddings[0]) == 2, "TestDummyEmbedding should produce 2-dimensional vectors."
144 | 
--------------------------------------------------------------------------------
/tests/test_backend_hp_evaluator.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import chromadb
4 | import pytest
5 |
6 | from tests.utils import (
7 | RAGFLOW_HOST,
8 | RAGFLOW_PORT,
9 | CHROMADB_HOST,
10 | CHROMADB_PORT,
11 | HP_EVALUATION_ENDPOINT,
12 | DOCUMENT_STORE_PATH,
13 | LABEL_DATASET_PATH,
14 | HYPERPARAMETERS_PATH,
15 | HYPERPARAMETERS_RESULTS_PATH,
16 | HYPERPARAMETERS_RESULTS_DATA_PATH,
17 | fetch_data,
18 | first_user_id,
19 | second_user_id,
20 | user_id_without_upsert,
21 | )
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | @pytest.mark.parametrize(
27 | "user_id_fixture, num_hp_run",
28 | [("first_user_id", 0), ("second_user_id", 1), ("first_user_id", 2)],
29 | )
30 | def test_ragflow_evaluator(user_id_fixture, num_hp_run, request):
31 | """Test the evaluation of provided hyperparameter configurations."""
32 |
33 | def get_fixture_value(fixture_name):
34 | """Get the value associated with a fixture name."""
35 | return request.getfixturevalue(fixture_name)
36 |
37 | user_id = get_fixture_value(user_id_fixture)
38 |
39 | json_payload = {
40 | "document_store_path": DOCUMENT_STORE_PATH,
41 | "label_dataset_path": LABEL_DATASET_PATH,
42 | "hyperparameters_path": HYPERPARAMETERS_PATH,
43 | "hyperparameters_results_path": HYPERPARAMETERS_RESULTS_PATH,
44 | "hyperparameters_results_data_path": HYPERPARAMETERS_RESULTS_DATA_PATH,
45 | "user_id": user_id,
46 | "api_keys": {},
47 | }
48 |
49 | response = fetch_data(
50 | method="post",
51 | host=RAGFLOW_HOST,
52 | port=RAGFLOW_PORT,
53 | endpoint=HP_EVALUATION_ENDPOINT,
54 | payload=json_payload,
55 | )
56 |
57 | client = chromadb.HttpClient(
58 | host=CHROMADB_HOST,
59 | port=CHROMADB_PORT,
60 | )
61 |
62 | collections_list = client.list_collections()
63 | collection_names = [col.name for col in collections_list]
64 |
65 |     # Check that hp_id gets incremented correctly per user: repeated evaluations of the first user should create collections with incremented hpid suffixes
66 | if len(collections_list) in range(3, 6):
67 | for i in range(3):
68 | assert (
69 | f"userid_{user_id}_hpid_{i}" in collection_names
70 | ), "First user with 3 hp evals should have 3 corresponding collections now"
71 | elif len(collections_list) in range(6, 9):
72 | for i in range(3):
73 | assert (
74 | f"userid_{user_id}_hpid_{i}" in collection_names
75 | ), "Second user with 3 hp evals should have 3 corresponding collections now"
76 | elif len(collections_list) in range(9, 12):
77 | for i in range(3, 6):
78 | assert (
79 | f"userid_{user_id}_hpid_{i}" in collection_names
80 | ), "First user with next 3 hp evals should have 6 corresponding collections now"
81 | else:
82 | assert False, "Unexpected collections_list length"
83 |
84 | # Check if hyperparameters from json were set correctly
85 | assert (
86 | response.status_code == 200
87 | ), f"The response from the {HP_EVALUATION_ENDPOINT} endpoint should be ok."
88 |
89 | # test the generated results of the evaluation run
90 | with open(json_payload["hyperparameters_results_path"], encoding="utf-8") as file:
91 | hyperparameters_results_list = json.load(file)
92 |
93 | with open(json_payload["hyperparameters_path"], encoding="utf-8") as file:
94 | hyperparameters_list = json.load(file)
95 |
96 | hyperparameters_results = hyperparameters_results_list[num_hp_run]
97 | hyperparameters = hyperparameters_list[num_hp_run]
98 |
99 | assert (
100 | hyperparameters["similarity_method"]
101 | == hyperparameters_results["vectorstore_params"]["metadata"]["hnsw:space"]
102 | ), "Retriever similarity methods should match."
103 |
104 | assert (
105 | hyperparameters["search_type"]
106 | == hyperparameters_results["retriever_params"]["search_type"]
107 | ), "Retriever search types should match."
108 |
109 | assert (
110 | hyperparameters["num_retrieved_docs"]
111 | == hyperparameters_results["retriever_params"]["search_kwargs"]["k"]
112 | ), "Number of retrieved docs should match."
113 |
114 | assert (
115 | hyperparameters["embedding_model"]
116 | == hyperparameters_results["retriever_params"]["tags"][1]
117 | ), "Embedding models should match."
118 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 | import requests
4 | import pytest
5 |
6 | # Constants
7 | RAGFLOW_HOST = os.environ.get("RAGFLOW_HOST")
8 | RAGFLOW_PORT = os.environ.get("RAGFLOW_PORT")
9 |
10 | CHROMADB_HOST = os.environ.get("CHROMADB_HOST")
11 | CHROMADB_PORT = os.environ.get("CHROMADB_PORT")
12 |
13 | LABEL_DATASET_ENDPOINT = "/generation"
14 | HP_EVALUATION_ENDPOINT = "/evaluation"
15 |
16 | DOCUMENT_STORE_PATH = "./resources/document_store"
17 | LABEL_DATASET_PATH = "./resources/input_label_dataset.json"
18 | HYPERPARAMETERS_PATH = "./resources/input_hyperparameters.json"
19 | HYPERPARAMETERS_RESULTS_PATH = "./resources/output_hyperparameters_results.json"
20 | HYPERPARAMETERS_RESULTS_DATA_PATH = (
21 | "./resources/output_hyperparameters_results_data.csv"
22 | )
23 | # pytest fixtures
24 |
25 |
26 | @pytest.fixture
27 | def user_id_without_upsert() -> str:
28 | """Returns a test user id of 0."""
29 | return "1"
30 |
31 |
32 | @pytest.fixture
33 | def first_user_id() -> str:
34 | """Returns a test user id of 0."""
35 | return "2"
36 |
37 |
38 | @pytest.fixture
39 | def second_user_id() -> str:
40 | """Returns a test user id of 1."""
41 | return "3"
42 |
43 |
44 | # helper functions
45 | def fetch_data(
46 | method: str,
47 | host: str,
48 | port: str,
49 | endpoint: str,
50 | payload: Optional[dict] = None,
51 | ):
52 | """Fetch data from the given endpoint."""
53 |
54 | if method == "get":
55 | response = requests.get(f"http://{host}:{port}{endpoint}")
56 | else:
57 | response = requests.post(f"http://{host}:{port}{endpoint}", json=payload)
58 |
59 | # Handle response errors if needed
60 | response.raise_for_status()
61 |
62 | return response
63 |
--------------------------------------------------------------------------------
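
Note (usage sketch): fetch_data above wraps requests for the test suite; this mirrors
the call pattern in test_backend_configs.py and assumes the backend is reachable via
the RAGFLOW_HOST/RAGFLOW_PORT environment variables.

    from tests.utils import RAGFLOW_HOST, RAGFLOW_PORT, fetch_data

    response = fetch_data(
        method="get", host=RAGFLOW_HOST, port=RAGFLOW_PORT, endpoint="/configs/llm_models"
    )
    print(response.json())
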
/tests/wait-for-it.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Use this script to test if a given TCP host/port are available
3 |
4 | WAITFORIT_cmdname=${0##*/}
5 |
6 | echoerr() { if [[ $WAITFORIT_QUIET -ne 1 ]]; then echo "$@" 1>&2; fi }
7 |
8 | usage()
9 | {
10 | cat << USAGE >&2
11 | Usage:
12 | $WAITFORIT_cmdname host:port [-s] [-t timeout] [-- command args]
13 | -h HOST | --host=HOST Host or IP under test
14 | -p PORT | --port=PORT TCP port under test
15 | Alternatively, you specify the host and port as host:port
16 | -s | --strict Only execute subcommand if the test succeeds
17 | -q | --quiet Don't output any status messages
18 | -t TIMEOUT | --timeout=TIMEOUT
19 | Timeout in seconds, zero for no timeout
20 | -- COMMAND ARGS Execute command with args after the test finishes
21 | USAGE
22 | exit 1
23 | }
24 |
25 | wait_for()
26 | {
27 | if [[ $WAITFORIT_TIMEOUT -gt 0 ]]; then
28 | echoerr "$WAITFORIT_cmdname: waiting $WAITFORIT_TIMEOUT seconds for $WAITFORIT_HOST:$WAITFORIT_PORT"
29 | else
30 | echoerr "$WAITFORIT_cmdname: waiting for $WAITFORIT_HOST:$WAITFORIT_PORT without a timeout"
31 | fi
32 | WAITFORIT_start_ts=$(date +%s)
33 | while :
34 | do
35 | (echo > /dev/tcp/$WAITFORIT_HOST/$WAITFORIT_PORT) >/dev/null 2>&1
36 | result=$?
37 | if [[ $result -eq 0 ]]; then
38 | WAITFORIT_end_ts=$(date +%s)
39 | echoerr "$WAITFORIT_cmdname: $WAITFORIT_HOST:$WAITFORIT_PORT is available after $((WAITFORIT_end_ts - WAITFORIT_start_ts)) seconds"
40 | break
41 | fi
42 | sleep 1
43 | done
44 | }
45 |
46 | wait_for_wrapper()
47 | {
48 | # In order to support SIGINT during timeout: http://unix.stackexchange.com/a/57692
49 | if [[ $WAITFORIT_QUIET -eq 1 ]]; then
50 | timeout $WAITFORIT_BUSYTIMEFLAG $WAITFORIT_TIMEOUT $0 --quiet --child --host=$WAITFORIT_HOST --port=$WAITFORIT_PORT --timeout=$WAITFORIT_TIMEOUT &
51 | else
52 | timeout $WAITFORIT_BUSYTIMEFLAG $WAITFORIT_TIMEOUT $0 --child --host=$WAITFORIT_HOST --port=$WAITFORIT_PORT --timeout=$WAITFORIT_TIMEOUT &
53 | fi
54 | WAITFORIT_PID=$!
55 | trap "kill -INT -$WAITFORIT_PID" INT
56 | wait $WAITFORIT_PID
57 | WAITFORIT_RESULT=$?
58 | if [[ $WAITFORIT_RESULT -ne 0 ]]; then
59 | echoerr "$WAITFORIT_cmdname: timeout occurred after waiting $WAITFORIT_TIMEOUT seconds for $WAITFORIT_HOST:$WAITFORIT_PORT"
60 | fi
61 | return $WAITFORIT_RESULT
62 | }
63 |
64 | # process arguments
65 | while [[ $# -gt 0 ]]
66 | do
67 | case "$1" in
68 | *:* )
69 | WAITFORIT_hostport=(${1//:/ })
70 | WAITFORIT_HOST=${WAITFORIT_hostport[0]}
71 | WAITFORIT_PORT=${WAITFORIT_hostport[1]}
72 | shift 1
73 | ;;
74 | --child)
75 | WAITFORIT_CHILD=1
76 | shift 1
77 | ;;
78 | -q | --quiet)
79 | WAITFORIT_QUIET=1
80 | shift 1
81 | ;;
82 | -s | --strict)
83 | WAITFORIT_STRICT=1
84 | shift 1
85 | ;;
86 | -h)
87 | WAITFORIT_HOST="$2"
88 | if [[ $WAITFORIT_HOST == "" ]]; then break; fi
89 | shift 2
90 | ;;
91 | --host=*)
92 | WAITFORIT_HOST="${1#*=}"
93 | shift 1
94 | ;;
95 | -p)
96 | WAITFORIT_PORT="$2"
97 | if [[ $WAITFORIT_PORT == "" ]]; then break; fi
98 | shift 2
99 | ;;
100 | --port=*)
101 | WAITFORIT_PORT="${1#*=}"
102 | shift 1
103 | ;;
104 | -t)
105 | WAITFORIT_TIMEOUT="$2"
106 | if [[ $WAITFORIT_TIMEOUT == "" ]]; then break; fi
107 | shift 2
108 | ;;
109 | --timeout=*)
110 | WAITFORIT_TIMEOUT="${1#*=}"
111 | shift 1
112 | ;;
113 | --)
114 | shift
115 | WAITFORIT_CLI=("$@")
116 | break
117 | ;;
118 | --help)
119 | usage
120 | ;;
121 | *)
122 | echoerr "Unknown argument: $1"
123 | usage
124 | ;;
125 | esac
126 | done
127 |
128 | if [[ "$WAITFORIT_HOST" == "" || "$WAITFORIT_PORT" == "" ]]; then
129 | echoerr "Error: you need to provide a host and port to test."
130 | usage
131 | fi
132 |
133 | WAITFORIT_TIMEOUT=${WAITFORIT_TIMEOUT:-15}
134 | WAITFORIT_STRICT=${WAITFORIT_STRICT:-0}
135 | WAITFORIT_CHILD=${WAITFORIT_CHILD:-0}
136 | WAITFORIT_QUIET=${WAITFORIT_QUIET:-0}
137 |
138 | # Check to see if timeout is from busybox?
139 | WAITFORIT_TIMEOUT_PATH=$(type -p timeout)
140 | WAITFORIT_TIMEOUT_PATH=$(realpath $WAITFORIT_TIMEOUT_PATH 2>/dev/null || readlink -f $WAITFORIT_TIMEOUT_PATH)
141 |
142 | WAITFORIT_BUSYTIMEFLAG=""
143 | if [[ $WAITFORIT_TIMEOUT_PATH =~ "busybox" ]]; then
144 | WAITFORIT_ISBUSY=1
145 | # Check if busybox timeout uses -t flag
146 | # (recent Alpine versions don't support -t anymore)
147 | if timeout &>/dev/stdout | grep -q -e '-t '; then
148 | WAITFORIT_BUSYTIMEFLAG="-t"
149 | fi
150 | else
151 | WAITFORIT_ISBUSY=0
152 | fi
153 |
154 | if [[ $WAITFORIT_CHILD -gt 0 ]]; then
155 | wait_for
156 | WAITFORIT_RESULT=$?
157 | exit $WAITFORIT_RESULT
158 | else
159 | if [[ $WAITFORIT_TIMEOUT -gt 0 ]]; then
160 | wait_for_wrapper
161 | WAITFORIT_RESULT=$?
162 | else
163 | wait_for
164 | WAITFORIT_RESULT=$?
165 | fi
166 | fi
167 |
168 | if [[ $WAITFORIT_CLI != "" ]]; then
169 | if [[ $WAITFORIT_RESULT -ne 0 && $WAITFORIT_STRICT -eq 1 ]]; then
170 | echoerr "$WAITFORIT_cmdname: strict mode, refusing to execute subprocess"
171 | exit $WAITFORIT_RESULT
172 | fi
173 | exec "${WAITFORIT_CLI[@]}"
174 | else
175 | exit $WAITFORIT_RESULT
176 | fi
177 |
--------------------------------------------------------------------------------
/vectorstore/.dockerignore:
--------------------------------------------------------------------------------
1 | # git
2 | .git
3 | .gitignore
4 | .cache
5 | **.log
6 |
7 | # data/chromadb
8 | data/datasets.zip
9 | data/desc.docx
10 | chromadb/
11 |
12 | # Byte-compiled / optimized / DLL files
13 | **/__pycache__
14 | **/*.pyc
15 | **.py[cod]
16 | *$py.class
17 | .vscode/
18 |
19 | # Jupyter Notebook
20 | **/.ipynb_checkpoints
21 | notebooks/
22 |
23 | # dotenv
24 | .env
25 |
26 | # virtualenv
27 | .venv*
28 | venv*/
29 | ENV*/
--------------------------------------------------------------------------------
/vectorstore/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11
2 |
3 | EXPOSE 8000
4 |
5 | WORKDIR /vectorstore
6 |
7 | COPY ./requirements.txt ./
8 |
9 | RUN pip install --no-cache-dir -r requirements.txt
10 |
11 | COPY ./server.py ./
12 |
13 | CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
--------------------------------------------------------------------------------
/vectorstore/requirements.txt:
--------------------------------------------------------------------------------
1 | chromadb==0.4.15
2 | uvicorn==0.24.0.post1
--------------------------------------------------------------------------------
/vectorstore/server.py:
--------------------------------------------------------------------------------
1 | import chromadb
2 | import chromadb.config
3 | from chromadb.server.fastapi import FastAPI
4 |
5 | settings = chromadb.config.Settings()
6 |
7 | server = FastAPI(settings)
8 | app = server.app
9 |
--------------------------------------------------------------------------------
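
Note (usage sketch): a client could talk to the ChromaDB server above as follows; host
and port are assumptions based on the EXPOSE 8000 in the vectorstore Dockerfile.

    import chromadb

    client = chromadb.HttpClient(host="localhost", port=8000)
    print(client.heartbeat())  # returns a timestamp once the server is reachable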