├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── app.py ├── chat_utils.py ├── chatrag.py ├── config.py ├── gradio_utils.py ├── model_utils.py ├── pics ├── RAG_Query.png ├── model_dropdown.png ├── query.png └── start_state.png ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | /data 2 | /Databases 3 | neo4j_info.txt 4 | /.idea 5 | /__pycache__ 6 | .env -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:25.02-py3 2 | WORKDIR /home/jake/Programming/Personal/Chat-RAG 3 | 4 | COPY . . 5 | RUN true 6 | RUN pip install --no-cache-dir -r requirements.txt 7 | EXPOSE 7860 8 | EXPOSE 5000 9 | 10 | RUN useradd Jake 11 | USER Jake 12 | 13 | CMD ["python", "app.py", "--host", "0.0.0.0", "--port", "8080"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chat RAG: Interactive Coding Assistant 2 | 3 | ## Overview 4 | Chat RAG is an advanced interactive coding assistant that leverages Retrieval-Augmented Generation (RAG) to provide 5 | informed responses to coding queries. 
Built with a user-friendly Gradio interface, it allows users to interact
6 | with various language models, customize model parameters, and upload context files from local directories or GitHub
7 | repositories for more accurate assistance.
8 | 
9 | 
10 | ## Features
11 | - **Multiple Model Providers**: Support for Ollama, HuggingFace, NVIDIA NIM, OpenAI, and Anthropic models.
12 | (**If you don't see all of these providers, make sure all the environment variables are set in your .env file!**)
13 | - **Wide Range of Language Models**: Choose from models such as Codestral, Mistral-Nemo, Llama 3.1, DeepSeek Coder V2, Qwen 3, Gemma 3, and CodeGemma.
14 | - **Dynamic Model Switching**: Seamlessly switch between different language models.
15 | - **Customizable Model Parameters**: Adjust temperature, max tokens, top-p, and context window size.
16 | - **Interactive Chat Interface**: Easy-to-use chat interface for asking coding questions.
17 | - **RAG-powered Responses**: Uses uploaded documents or a linked GitHub repository to provide context-aware answers.
18 | - **Chat With Files**: Support for uploading additional context files.
19 | - **Chat with a GitHub Repo:** Support for using a GitHub repository's files as context for the model.
20 | - **Chat With a Database:** Support for connecting a new or existing database. **(Coming Soon)**
21 | - **Custom Prompts**: Ability to set custom system prompts for the chat engine.
22 | - **Enhanced Memory Management**: Dynamically manage chat memory for different models.
23 | - **Streaming Responses**: Real-time response generation for a more interactive experience.
24 | - **Model Quantization**: Options for 2-bit (double 4-bit quant), 4-bit, and 8-bit quantization for HuggingFace models.
25 | - **Parsing Advanced File Types**: Parses .pdf, .csv, .xlsx, .docx, and .xml files with LlamaParse.
26 | 
27 | 
28 | ## Setup and Usage
29 | 1. Clone the repository.
30 | 2. Install the required dependencies.
31 | 3. Set up your .env file with the following:
32 | ```bash
33 | GRADIO_TEMP_DIR="YourPathTo/Chat-RAG/data"
34 | GRADIO_WATCH_DIRS="YourPathTo/Chat-RAG"
35 | HUGGINGFACE_HUB_TOKEN="YOUR HF TOKEN HERE"
36 | NVIDIA_API_KEY="YOUR NVIDIA API KEY HERE"
37 | OPENAI_API_KEY="YOUR OpenAI API KEY HERE"
38 | ANTHROPIC_API_KEY="YOUR Anthropic API KEY HERE"
39 | GITHUB_PAT="YOUR GITHUB PERSONAL ACCESS TOKEN HERE"
40 | LLAMA_CLOUD_API_KEY="YOUR LLAMA_CLOUD_API_KEY"
41 | ```
42 | 4. Run the application:
43 | ```bash
44 | gradio chatrag.py
45 | ```
46 | or
47 | ```bash
48 | python app.py
49 | ```
50 | 5. The app will automatically open in a new browser tab.
51 | 6. Select a Model Provider.
52 | 7. Select a language model from the dropdown menu.
53 | 8. (Optional) Upload relevant files for additional context.
54 | 9. Type your coding question in the text box and press Enter.
55 | 10. The model will stream its response back to you in the chat window.
56 | 
57 | 
58 | ## Project Structure
59 | - `app.py`: Entry point for running the app with plain Python instead of Gradio live reload.
60 | - `chatrag.py`: Main application file with Gradio UI setup.
61 | - `chat_utils.py`: Utilities for document loading and chat engine creation.
62 | - `gradio_utils.py`: Gradio-specific utility functions for UI interactions.
63 | - `model_utils.py`: Model management and configuration utilities.
64 | - `utils.py`: General utilities for embedding, LLM setup, and chat memory.
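
### Programmatic Usage
The Gradio front end drives everything through `ModelManager`, which in turn builds a chat engine with `create_chat_engine` from `chat_utils.py`. If you want to experiment outside the UI, a rough sketch of that same call is below. It is not an official entry point of the project; the argument values simply mirror the app's defaults, and it assumes Ollama is running locally with the Codestral model pulled.
```python
from chat_utils import create_chat_engine

# Build a chat engine the same way the Gradio layer does (Ollama defaults shown here).
# The None values just mean "no GitHub repo and no external vector store".
engine = create_chat_engine(
    model_provider="Ollama", model="codestral:latest",
    temperature=0.75, max_tokens=2048, custom_prompt=None,
    top_p=0.4, context_window=2048, quantization="4 Bit",
    owner=None, repo=None, branch=None,
    vector_store=None, username=None, password=None, url=None,
    collection_name=None,
)

# Stream a response token by token, the same way the chat window does.
streaming_response = engine.stream_chat("How do I reverse a list in Python?")
for token in streaming_response.response_gen:
    print(token, end="", flush=True)
```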
65 | 66 | 67 | ## Pictures 68 | ### Start State of the App 69 | ![Start State of the App](pics/start_state.png "Start State of the App") 70 | ### Dropdown Menu in Action 71 | ![Dropdown Menu](pics/model_dropdown.png "Dropdown Menu in Action") 72 | ### Query Example 73 | ![Query Example](pics/query.png "Query Example") 74 | ### RAG Query Example 75 | ![RAG Query Example](pics/RAG_Query.png "RAG Query Example") 76 | 77 | 78 | ### Contributing 79 | Contributions are welcome! Please feel free to submit a Pull Request or Fork the Repository. 80 | 81 | 82 | ### Coming in Future Updates 83 | - Video of the program in action. 84 | - Add the ability to load an existing Neo4j DB into the model 85 | - The ability to add models to the list for different model providers. 86 | 87 | 88 | ### Need Help or Have Feature Suggestions? 89 | Feel free to reach out to me through GitHub, LinkedIn, or through email. All of those are available on my website [JFCoded](https://www.jfcoded.com/contact). -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # This file is used to launch the program if the user wants to launch it using python argument versus gradio 2 | from chatrag import demo 3 | 4 | demo.launch(inbrowser=True, share=True) -------------------------------------------------------------------------------- /chat_utils.py: -------------------------------------------------------------------------------- 1 | from llama_index.core import SimpleDirectoryReader, StorageContext 2 | from llama_parse import LlamaParse 3 | from llama_index.vector_stores.neo4jvector import Neo4jVectorStore 4 | from llama_index.vector_stores.chroma import ChromaVectorStore 5 | from llama_index.vector_stores.milvus import MilvusVectorStore 6 | from llama_index.readers.github import GithubClient, GithubRepositoryReader 7 | from utils import (setup_index_and_chat_engine, set_embedding_model, set_chat_memory, 8 | set_ollama_llm, set_huggingface_llm, set_nvidia_model, set_openai_model, set_anth_model) 9 | import torch, os, glob, gc, dotenv, chromadb 10 | dotenv.load_dotenv() 11 | 12 | DIRECTORY_PATH = "data" 13 | Neo4j_DB_PATH = "Databases/Neo4j" 14 | Chroma_DB_PATH = "Databases/ChromaDB" 15 | Milvus_DB_PATH = "Databases/MilvusDB" 16 | EMBED_MODEL = set_embedding_model() 17 | 18 | # TODO Add free parsing options for advanced docs, Llama Parse only lets you parse 1000 free docs a day 19 | # TODO Figure out why multiprocessing of docs causes program to reload in a loop 20 | # Local Document Loading Function 21 | def load_local_docs(): 22 | parser = LlamaParse(api_key=os.getenv("LLAMA_CLOUD_API_KEY")) 23 | all_files = glob.glob(os.path.join(DIRECTORY_PATH, "**", "*"), recursive=True) 24 | all_files = [f for f in all_files if os.path.isfile(f)] 25 | documents = [] 26 | supported_extensions = [".pdf", ".docx", ".xlsx", ".csv", ".xml", ".html", ".json"] 27 | for file in all_files: 28 | file_extension = os.path.splitext(file)[1].lower() 29 | if "LLAMA_CLOUD_API_KEY" in os.environ and file_extension in supported_extensions: 30 | file_extractor = {file_extension: parser} 31 | documents.extend( 32 | SimpleDirectoryReader(input_files=[file], file_extractor=file_extractor).load_data()) 33 | else: 34 | documents.extend(SimpleDirectoryReader(input_files=[file]).load_data()) 35 | return documents 36 | 37 | # GitHub Repo Reader setup function. 
Sets all initial parameters and handles data load of the repository 38 | def load_github_repo(owner, repo, branch): 39 | if "GITHUB_PAT" in os.environ: 40 | github_client = GithubClient(github_token=os.getenv("GITHUB_PAT"), verbose=True) 41 | owner=owner 42 | repo=repo 43 | branch=branch 44 | documents= GithubRepositoryReader( 45 | github_client=github_client, 46 | owner=owner, 47 | repo=repo, 48 | use_parser=False, 49 | verbose=False, 50 | filter_file_extensions=([".png", ".jpg", ".jpeg", ".gif", ".svg"], 51 | GithubRepositoryReader.FilterType.EXCLUDE) 52 | ).load_data(branch=branch) 53 | return documents 54 | else: 55 | print("Couldn't find your GitHub Personal Access Token in the environment file. Make sure you enter your " 56 | "GitHub Personal Access Token in the .env file.") 57 | 58 | 59 | # TODO Finish and Test Vector Store implementation 60 | # Setting up different vector stores 61 | def setup_vector_store(vector_store, username, password, url, collection_name): 62 | if vector_store == "Neo4j": 63 | username = username 64 | password = password 65 | url = url 66 | embed_dim = 1536 67 | neo4j_vector_store = Neo4jVectorStore(username, 68 | password, 69 | url, 70 | embed_dim, 71 | database=collection_name, 72 | hybrid_search=True, 73 | distance_strategy="euclidean") 74 | storage_context = StorageContext.from_defaults(vector_store=neo4j_vector_store) 75 | return storage_context 76 | elif vector_store == "ChromaDB": 77 | chroma_client = chromadb.EphemeralClient() 78 | # Check to see if collection exists already 79 | chroma_collection = "" 80 | for c in chroma_client.list_collections(): 81 | if c == collection_name: 82 | chroma_collection = chroma_client.get_collection(collection_name) 83 | else: 84 | chroma_collection = chroma_client.create_collection(collection_name) 85 | chroma_vector_store = ChromaVectorStore(chroma_collection, 86 | persist_dir=Chroma_DB_PATH) 87 | storage_context = StorageContext.from_defaults(vector_store=chroma_vector_store) 88 | return storage_context 89 | elif vector_store == "Milvus": 90 | milvus_vector_store = MilvusVectorStore(collection_name=collection_name, 91 | dim=1536, 92 | overwrite=False) 93 | storage_context = StorageContext.from_defaults(vector_store=milvus_vector_store) 94 | return storage_context 95 | else: 96 | storage_context = None 97 | return storage_context 98 | 99 | # Calls setup chat engine function with model and data personalization's user inputs from front end 100 | def create_chat_engine(model_provider, model, temperature, max_tokens, custom_prompt, top_p, 101 | context_window, quantization, owner, repo, branch, vector_store, username, password, url, 102 | collection_name): 103 | # Clearing GPU Memory 104 | torch.cuda.empty_cache() 105 | gc.collect() 106 | # Loading local Documents and GitHub Repos if applicable 107 | documents = load_local_docs() 108 | if owner and repo and branch: 109 | documents.extend(load_github_repo(owner, repo, branch)) 110 | # Loading Storage Context if any is set by a vector store 111 | if vector_store is not None or "": 112 | storage_context = setup_vector_store(vector_store, username, password, url, collection_name) 113 | else: 114 | storage_context = None 115 | # Loading Embedding Model from global parameter 116 | embed_model = EMBED_MODEL 117 | # Loading LLM based off users input 118 | llm_setters = { 119 | "Ollama": lambda: set_ollama_llm(model, temperature, max_tokens), 120 | "HuggingFace": lambda: set_huggingface_llm(model, temperature, max_tokens, top_p, context_window, quantization), 121 | "NVIDIA NIM": 
lambda: set_nvidia_model(model, temperature, max_tokens, top_p), 122 | "OpenAI": lambda: set_openai_model(model, temperature, max_tokens, top_p), 123 | "Anthropic": lambda: set_anth_model(model, temperature, max_tokens) 124 | } 125 | try: 126 | llm = llm_setters[model_provider]() 127 | except KeyError: 128 | raise ValueError(f"Unsupported model provider: {model_provider}") 129 | # Setting model memory 130 | memory = set_chat_memory(model) 131 | return setup_index_and_chat_engine(docs=documents, llm=llm, embed_model=embed_model, 132 | memory=memory, custom_prompt=custom_prompt, storage_context=storage_context) 133 | -------------------------------------------------------------------------------- /chatrag.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gradio as gr 3 | from gradio_utils import GradioUtils 4 | from model_utils import ModelManager 5 | import dotenv 6 | dotenv.load_dotenv() 7 | gradioUtils = GradioUtils() 8 | modelUtils = ModelManager() 9 | 10 | Neo4j_DB_PATH = "Databases/Neo4j" 11 | Chroma_DB_PATH = "Databases/ChromaDB" 12 | Milvus_DB_PATH = "Databases/MilvusDB" 13 | 14 | css = """ 15 | .gradio-container{ 16 | background:radial-gradient(#416e8a, #000000); 17 | } 18 | #button{ 19 | background:#06354d 20 | } 21 | """ 22 | 23 | # --------------------------Gradio Layout----------------------------- 24 | with gr.Blocks(title="Chat RAG", fill_width=True, css=css) as demo: 25 | gr.Markdown("# Chat RAG: Interactive Coding Assistant" 26 | ) 27 | with gr.Row(): 28 | with gr.Column(scale=7, variant="compact"): # 29 | chatbot = gr.Chatbot(label="Chat RAG", height="80vh") 30 | msg = gr.Textbox(placeholder="Enter your message here and hit return when you're ready...", 31 | interactive=True, container=False, autoscroll=True) 32 | with gr.Row(): 33 | clear = gr.ClearButton([msg, chatbot], 34 | value="Clear Chat Window", 35 | elem_id="button") 36 | clear_chat_mem = gr.Button(value="Clear Chat Window and Chat Memory", 37 | elem_id="button") 38 | with gr.Column(scale=3): # 39 | with gr.Tab("Chat With Files"): 40 | files = gr.Files(interactive=True, 41 | file_count="multiple", 42 | file_types=["text", ".pdf", ".xlsx", ".py", ".txt", ".dart", ".c", ".jsx", ".xml", 43 | ".css", ".cpp", ".html", ".docx", ".doc", ".js", ".json", ".csv"]) 44 | with gr.Row(): 45 | upload = gr.Button(value="Upload Data to Knowledge Base", 46 | interactive=True, 47 | size="sm", 48 | elem_id="button") 49 | clear_db = gr.Button(value="Clear Knowledge Base", 50 | interactive=True, 51 | size="sm", 52 | elem_id="button") 53 | # TODO Finish Database Backend Implementation 54 | with gr.Tab("Chat With a Database(Coming Soon)"): 55 | db_selector = gr.Radio(label="Database", value="ChromaDB", 56 | choices=["ChromaDB", "Milvus", "Neo4j"]) 57 | @gr.render(inputs=db_selector) 58 | def render_db_components(provider): 59 | if provider == "ChromaDB": 60 | with gr.Row(): 61 | collection_name = gr.Textbox(label="Chroma Collection Name",interactive=True, 62 | placeholder="Enter Database Collection Name Here..") 63 | with gr.Row(): 64 | dbfiles = gr.Files(interactive=True, 65 | file_count="multiple", 66 | file_types=["text", ".pdf", ".xlsx", ".py", ".txt", ".dart", ".c", 67 | ".jsx",".xml",".css", ".cpp", ".html", ".docx", ".doc", 68 | ".js", ".json", ".csv"]) 69 | with gr.Row(): 70 | upload_db_files = gr.Button("Upload Data to Database", 71 | interactive=True, 72 | size="sm", 73 | elem_id="button") 74 | load_db = gr.Button("Load Database to Model", 75 | interactive=True, 76 | 
size="sm", 77 | elem_id="button") 78 | remove_db = gr.Button("Remove Database from Model", 79 | interactive=True, 80 | size="sm", 81 | elem_id="button") 82 | load_db.click(modelUtils.setup_database, 83 | inputs=[db_selector, None, None, None, collection_name]) 84 | remove_db.click(modelUtils.remove_database) 85 | elif provider == "Milvus": 86 | with gr.Row(): 87 | collection_name = gr.Textbox(label="Milvus Collection Name",interactive=True, 88 | placeholder="Enter Database Collection Name Here..") 89 | with gr.Row(): 90 | dbfiles = gr.Files(interactive=True, 91 | file_count="multiple", 92 | file_types=["text", ".pdf", ".xlsx", ".py", ".txt", ".dart", ".c", 93 | ".jsx", ".xml", ".css", ".cpp", ".html", ".docx", ".doc", 94 | ".js",".json", ".csv"]) 95 | with gr.Row(): 96 | upload_db_files = gr.Button("Upload Data to Database", 97 | interactive=True, 98 | size="sm", 99 | elem_id="button") 100 | load_db = gr.Button("Load Database to Model", 101 | interactive=True, 102 | size="sm", 103 | elem_id="button") 104 | remove_db = gr.Button("Remove Database from Model", 105 | interactive=True, 106 | size="sm", 107 | elem_id="button") 108 | load_db.click(modelUtils.setup_database, 109 | inputs=[db_selector, None, None, None, collection_name]) 110 | remove_db.click(modelUtils.remove_database) 111 | elif provider == "Neo4j": 112 | with gr.Row(): 113 | dbfiles = gr.Files(interactive=True, 114 | file_count="multiple", 115 | file_types=["text", ".pdf", ".xlsx", ".py", ".txt", ".dart", ".c", 116 | ".jsx", ".xml",".css", ".cpp", ".html", ".docx", ".doc", 117 | ".js", ".json", ".csv"]) 118 | with gr.Row(): 119 | neo_un = gr.Textbox(label="Neo4j Database Name", 120 | placeholder="Enter Database Name Here...", 121 | interactive=True) 122 | neo_pw = gr.Textbox(label="Neo4j Database Password", 123 | placeholder="Enter Database Password Here...", 124 | interactive=True) 125 | neo_url = gr.Textbox(label="Neo4j Database Link", 126 | placeholder="Enter Database Link Here...", 127 | interactive=True) 128 | with gr.Row(): 129 | upload_db_files = gr.Button("Upload Data to Database", 130 | interactive=True, 131 | size="sm", 132 | elem_id="button") 133 | load_db = gr.Button("Load Database to Model", 134 | interactive=True, 135 | size="sm", 136 | elem_id="button") 137 | remove_db = gr.Button("Remove Database from Model", 138 | interactive=True, 139 | size="sm", 140 | elem_id="button") 141 | load_db.click(modelUtils.setup_database, 142 | inputs=[db_selector, neo_un, neo_pw, neo_url, None]) 143 | remove_db.click(modelUtils.remove_database, outputs=[neo_un, neo_pw, neo_url]) 144 | with gr.Tab("Chat With a GitHub Repository"): 145 | repoOwnerUsername = gr.Textbox(label="GitHub Repository Owners Username:", 146 | placeholder="Enter GitHub Repository Owners Username Here....", 147 | interactive= True) 148 | repoName = gr.Textbox(label="GitHub Repository Name:", 149 | placeholder="Enter Repository Name Here....", 150 | interactive= True) 151 | repoBranch = gr.Textbox(label="GitHub Repository Branch Name:", 152 | placeholder="Enter Branch Name Here....", 153 | interactive=True) 154 | with gr.Row(): 155 | getRepo = gr.Button(value="Load Repository to Model", 156 | size="sm", 157 | interactive=True, 158 | elem_id="button") 159 | removeRepo = gr.Button(value="Reset Info and Remove Repository from Model", 160 | size="sm", 161 | interactive=True, 162 | elem_id="button") 163 | 164 | choices = ["Ollama"] 165 | if "HUGGINGFACE_HUB_TOKEN" in os.environ: 166 | choices.append("HuggingFace") 167 | if "NVIDIA_API_KEY" in os.environ: 168 | 
choices.append("NVIDIA NIM") 169 | if "OPENAI_API_KEY" in os.environ: 170 | choices.append("OpenAI") 171 | if "ANTHROPIC_API_KEY" in os.environ: 172 | choices.append("Anthropic") 173 | model_provider = gr.Radio(label="Select Model Provider", 174 | value="Ollama", 175 | choices=choices, 176 | interactive=True, 177 | info="Choose your model provider.") 178 | 179 | @gr.render(inputs=model_provider) 180 | def render_provider_components(provider): 181 | if provider == "Ollama": 182 | # --------------Ollama Components------------------------ 183 | selected_chat_model = gr.Dropdown(choices=list(modelUtils.model_display_names["Ollama"].keys()), 184 | interactive=True, 185 | label="Select a Chat Model", 186 | value="Codestral 22B", 187 | filterable=True, 188 | info="Choose the model you want to chat with from the list below.") 189 | temperature = gr.Slider(minimum=0, maximum=1, value=.75, step=.05, 190 | label="Model Temperature", 191 | info="Select a temperature between 0 and 1 for the model.", 192 | interactive=True) 193 | max_tokens = gr.Slider(minimum=100, maximum=5000, value=2048, step=1, 194 | label="Max Output Tokens", 195 | info="Set the maximum number of tokens the model can respond with.", 196 | interactive=True) 197 | custom_prompt = gr.Textbox(label="Enter a Custom Prompt", 198 | placeholder="Enter your custom prompt here...", 199 | interactive=True) 200 | # ---------Ollama Buttons----------------- 201 | selected_chat_model.change(gradioUtils.update_model, 202 | inputs=[selected_chat_model], 203 | outputs=[chatbot]) 204 | temperature.release(gradioUtils.update_model_temp, 205 | inputs=[temperature]) 206 | max_tokens.release(gradioUtils.update_max_tokens, 207 | inputs=[max_tokens]) 208 | custom_prompt.submit(gradioUtils.update_chat_prompt, 209 | inputs=[custom_prompt]) 210 | 211 | elif provider == "HuggingFace": 212 | # ------------------HuggingFace components------------------------------- 213 | hf_quantization = gr.Dropdown(choices=["Choose a Quantization","No Quantization", "2 Bit", "4 Bit", "8 Bit"], 214 | interactive=True, 215 | label="Model Quantization", 216 | value="Choose a Quantization", 217 | info="Choose Model Quantization.") 218 | hf_model = gr.Dropdown(choices=list(modelUtils.model_display_names["HuggingFace"].keys()), 219 | interactive=True, 220 | label="Select a Chat Model", 221 | value="Choose a Model", 222 | filterable=True, 223 | info="Choose a Hugging Face model.") 224 | hf_temperature = gr.Slider(minimum=0, maximum=1, value=.75, step=.05, 225 | label="Model Temperature", 226 | info="Select a Temperature between 0 and 1 for the model.", 227 | interactive=True) 228 | hf_top_p = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.05, 229 | label="Top P", 230 | info="Select a Top P value between 0 and 1 for the model.", 231 | interactive=True) 232 | hf_ctx_wnd = gr.Slider(minimum=100, maximum=10000, value=2048, step=1, 233 | label="Context Window", 234 | info="Select a Context Window value between 100 and 10000 for the model.", 235 | interactive=True) 236 | hf_max_tokens = gr.Slider(minimum=100, maximum=5000, value=2048, step=1, 237 | label="Max Output Tokens", 238 | info="Set the maximum number of tokens the model can respond with.", 239 | interactive=True) 240 | hf_custom_prompt = gr.Textbox(label="Enter a Custom Prompt", 241 | placeholder="Enter your custom prompt here...", 242 | interactive=True) 243 | # ---------HuggingFace Buttons----------------- 244 | hf_model.change(gradioUtils.update_model, 245 | inputs=[hf_model], 246 | outputs=[chatbot]) 247 | 
hf_quantization.change(gradioUtils.update_quant, 248 | inputs=[hf_quantization]) 249 | hf_temperature.release(gradioUtils.update_model_temp, 250 | inputs=[hf_temperature]) 251 | hf_top_p.release(gradioUtils.update_top_p, 252 | inputs=[hf_top_p]) 253 | hf_ctx_wnd.release(gradioUtils.update_context_window, 254 | inputs=[hf_ctx_wnd]) 255 | hf_max_tokens.release(gradioUtils.update_max_tokens, 256 | inputs=[hf_max_tokens]) 257 | hf_custom_prompt.submit(gradioUtils.update_chat_prompt, 258 | inputs=[hf_custom_prompt]) 259 | 260 | elif provider=="NVIDIA NIM": 261 | # ----------------------------NVIDIA NIM components--------------------------- 262 | nv_model = gr.Dropdown(choices=list(modelUtils.model_display_names["NVIDIA NIM"].keys()), 263 | interactive=True, 264 | label="Select a NVIDIA NIM", 265 | value="Codestral 22B", 266 | filterable=True, 267 | info="Choose a NVIDIA NIM.") 268 | nv_temperature = gr.Slider(minimum=0, maximum=1, value=.75, step=.05, 269 | label="Model Temperature", 270 | info="Select a temperature between .1 and 1 to set the model to.", 271 | interactive=True) 272 | nv_top_p = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.05, 273 | label="Top P", 274 | info="Set the top p value for the model.", 275 | interactive=True) 276 | nv_max_tokens = gr.Slider(minimum=100, maximum=5000, value=2048, step=1, 277 | label="Max Output Tokens", 278 | info="Set the maximum number of tokens the model can respond with.", 279 | interactive=True) 280 | # ---------NVIDIA Buttons----------------- 281 | nv_model.change(gradioUtils.update_model, 282 | inputs=[nv_model], 283 | outputs=[chatbot]) 284 | nv_temperature.release(gradioUtils.update_model_temp, 285 | inputs=[nv_temperature]) 286 | nv_top_p.release(gradioUtils.update_top_p, 287 | inputs=[nv_top_p]) 288 | nv_max_tokens.release(gradioUtils.update_max_tokens, 289 | inputs=[nv_max_tokens]) 290 | 291 | elif provider=="OpenAI": 292 | # ----------------------------OPEN AI components--------------------------- 293 | openai_model = gr.Dropdown(choices=list(modelUtils.model_display_names["OpenAI"].keys()), 294 | interactive=True, 295 | label="Select a OpenAI Model", 296 | value="GPT-4o", 297 | filterable=True, 298 | info="Choose a OpenAI model.") 299 | openai_temperature = gr.Slider(minimum=0, maximum=1, value=.75, step=.05, 300 | label="Model Temperature", 301 | info="Select a temperature between .1 and 1 to set the model to.", 302 | interactive=True) 303 | openai_top_p = gr.Slider(minimum=0, maximum=1, value=0.4, step=0.05, 304 | label="Top P", 305 | info="Set the top p value for the model.", 306 | interactive=True) 307 | openai_max_tokens = gr.Slider(minimum=100, maximum=5000, value=2048, step=1, 308 | label="Max Output Tokens", 309 | info="Set the maximum number of tokens the model can respond with.", 310 | interactive=True) 311 | # ---------OpenAI Buttons----------------- 312 | openai_model.change(gradioUtils.update_model, 313 | inputs=[openai_model], 314 | outputs=[chatbot]) 315 | openai_temperature.release(gradioUtils.update_model_temp, 316 | inputs=[openai_temperature]) 317 | openai_top_p.release(gradioUtils.update_top_p, 318 | inputs=[openai_top_p]) 319 | openai_max_tokens.release(gradioUtils.update_max_tokens, 320 | inputs=[openai_max_tokens]) 321 | elif provider=="Anthropic": 322 | # ----------------------------Anthropic components--------------------------- 323 | anth_model = gr.Dropdown(choices=list(modelUtils.model_display_names["Anthropic"].keys()), 324 | interactive=True, 325 | label="Select a Anthropic Model", 326 | value="Claude 
3.7 Sonnet", 327 | filterable=True, 328 | info="Choose a Anthropic model.") 329 | anth_temperature = gr.Slider(minimum=0, maximum=1, value=.75, step=.05, 330 | label="Model Temperature", 331 | info="Select a temperature between .1 and 1 to set the model to.", 332 | interactive=True) 333 | anth_max_tokens = gr.Slider(minimum=100, maximum=5000, value=2048, step=1, 334 | label="Max Output Tokens", 335 | info="Set the maximum number of tokens the model can respond with.", 336 | interactive=True) 337 | # ---------Anthropic Buttons----------------- 338 | anth_model.change(gradioUtils.update_model, 339 | inputs=[anth_model], 340 | outputs=[chatbot]) 341 | anth_temperature.release(gradioUtils.update_model_temp, 342 | inputs=[anth_temperature]) 343 | anth_max_tokens.release(gradioUtils.update_max_tokens, 344 | inputs=[anth_max_tokens]) 345 | gradioUtils.update_model_provider(provider) 346 | 347 | # ----------------------------------Button Functionality For RAG Chat----------------------------------------------- 348 | msg.submit(gradioUtils.stream_response, 349 | inputs=[msg], 350 | outputs=[msg, chatbot], 351 | show_progress="full", 352 | scroll_to_output=True) 353 | # --------------------Buttons in Left Column-------------------------------- 354 | clear.click(gradioUtils.clear_chat_history, 355 | outputs=chatbot) 356 | clear_chat_mem.click(gradioUtils.clear_his_and_mem, 357 | outputs=chatbot) 358 | # --------------------Buttons in Right Column-------------------------------- 359 | files.upload(gradioUtils.handle_doc_upload, inputs=files, 360 | show_progress="full") 361 | upload.click(lambda: gradioUtils.model_manager.reset_chat_engine()) 362 | clear_db.click(gradioUtils.delete_db, 363 | show_progress="full") 364 | getRepo.click(gradioUtils.set_github_info, inputs=[repoOwnerUsername, repoName, repoBranch]) 365 | removeRepo.click(modelUtils.reset_github_info, outputs=[repoOwnerUsername, repoName, repoBranch]) 366 | 367 | demo.launch(inbrowser=True) # , share=True -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Standard config file that stores lists so they don't take up room in the main files. 
3 | """ 4 | 5 | OLLAMA_MODEL_LIST = { 6 | "Codestral": "codestral:latest", 7 | "Qwen 3": "qwen3:latest", 8 | "Gemma 3": "gemma3:12b", 9 | "CodeGemma": "codegemma:latest", 10 | "Mistral-Nemo": "mistral-nemo:latest", 11 | "Llama3.1": "llama3.1:latest", 12 | "Deepseek R-1": "deepseek-r1:14b", 13 | "DeepSeek Coder V2": "deepseek-coder-v2:latest" 14 | } 15 | HF_MODEL_LIST = { 16 | "Choose a Model": "", 17 | "Codestral 22B": "mistralai/Codestral-22B-v0.1", 18 | "Qwen 3": "Qwen/Qwen3-14B", 19 | "Gemma 3":"google/gemma-3-12b-it", 20 | "CodeGemma 7B-Instruct": "google/codegemma-7b-it", 21 | "Mistral-Nemo 12B-Instruct": "mistralai/Mistral-Nemo-Instruct-2407", 22 | "Llama3.1 8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct", 23 | "DeepSeek Coder V2 16B": "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", 24 | 25 | } 26 | NV_MODEL_LIST = { 27 | "Codestral 22B": "mistralai/codestral-22b-instruct-v0.1", 28 | "Qwen 32B": "qwen/qwq-32b", 29 | "Gemma 3": "google/gemma-3-27b-it", 30 | "CodeGemma 7B": "google/codegemma-7b", 31 | "Mistral-Nemo 12B": "nv-mistralai/mistral-nemo-12b-instruct", 32 | "Llama 3.1 8B": "meta/llama-3.1-8b-instruct", 33 | 34 | } 35 | OA_MODEL_LIST = {"GPT-4o": "gpt-4o", 36 | "GPT-4o mini": "gpt-4o-mini", 37 | "GPT-4": "gpt-4", 38 | } 39 | ANTH_MODEL_LIST = {"Claude 3.7 Sonnet": "claude-3-7-sonnet-20250219", 40 | "Claude 3.5 Sonnet": "claude-3-5-sonnet-20241022", 41 | "Claude 3.5 Haiku": "claude-3-5-haiku-20241022", 42 | "Claude 3 Opus": "claude-3-opus-20240229", 43 | "Claude 3 Sonnet": "claude-3-sonnet-20240229", 44 | } 45 | -------------------------------------------------------------------------------- /gradio_utils.py: -------------------------------------------------------------------------------- 1 | import os, shutil 2 | import gradio as gr 3 | from model_utils import ModelManager, ModelParamUpdates 4 | 5 | """ 6 | Main gradio class that is the connector function between the front and backend. This function serves many purposes 7 | from updating model parameters and calling the appropriate function to handing the response streaming after calling the 8 | process input function. 9 | This class also handles chat memory, deleting the knowledge base, and handing the document uploads from the main gradio 10 | file asset. 11 | """ 12 | class GradioUtils: 13 | def __init__(self): 14 | self.model_manager = ModelManager() 15 | self.model_param_updater = ModelParamUpdates(self.model_manager) 16 | self.chat_history = [] 17 | 18 | # This function gets the users query, send it to the chat engine to be processed and then streams the response back 19 | def stream_response(self, message: str): 20 | streaming_response = self.model_manager.process_input(message) 21 | full_response = "" 22 | for tokens in streaming_response.response_gen: 23 | full_response += tokens 24 | yield "", self.chat_history + [(message, full_response)] 25 | self.chat_history.append((message, full_response)) 26 | 27 | # This function clears the chat history dictionary 28 | def clear_chat_history(self): 29 | self.chat_history.clear() 30 | 31 | """ 32 | This function clears the chat history and reset the chat engine to clear the model memory to give the user a 33 | fresh chat with no context 34 | """ 35 | def clear_his_and_mem(self): 36 | self.clear_chat_history() 37 | self.model_manager.reset_chat_engine() 38 | 39 | """ 40 | This function deletes the data file to remove the uploaded data and resets the chat engine to remove it from the 41 | models' context. 
It also sends a warning message to the front end to alert the user of the changes made.
42 |     """
43 |     def delete_db(self):
44 |         gr.Info("Wait about 10 seconds for the files to clear. After this message disappears you should "
45 |                 "be in the clear.", duration=15)
46 |         if os.path.exists("data"):
47 |             shutil.rmtree("data")
48 |             os.makedirs("data")
49 |         self.model_manager.reset_chat_engine()
50 | 
51 |     """
52 |     This function clears the chat history, updates the model provider based on the user's selection, and then sends a
53 |     warning message about model loading and download wait times if the user selected a HuggingFace model.
54 |     """
55 |     def update_model_provider(self, provider):
56 |         self.clear_chat_history()
57 |         self.model_manager.update_model_provider(provider)
58 |         if self.model_manager.provider == "HuggingFace":
59 |             gr.Warning(
60 |                 "If this is your first time using HuggingFace, the model may need to download. Please be patient.",
61 |                 duration=10)
62 | 
63 |     # This function sends the user's model selection through to the model manager
64 |     def update_model(self, display_name):
65 |         self.clear_chat_history()
66 |         self.model_manager.update_model(display_name)
67 | 
68 |     # This function sends the user's quantization selection through to the model parameter updater function
69 |     def update_quant(self, quantization):
70 |         self.model_param_updater.update_quant(quantization)
71 | 
72 |     # This function sends the user's temperature selection through to the model parameter updater function
73 |     def update_model_temp(self, temperature):
74 |         self.model_param_updater.update_model_temp(temperature)
75 | 
76 |     # This function sends the user's top p selection through to the model parameter updater function
77 |     def update_top_p(self, top_p):
78 |         self.model_param_updater.update_top_p(top_p)
79 | 
80 |     # This function sends the user's context window size selection through to the model parameter updater function
81 |     def update_context_window(self, context_window):
82 |         self.model_param_updater.update_context_window(context_window)
83 | 
84 |     # This function sends the user's max token selection through to the model parameter updater function
85 |     def update_max_tokens(self, max_tokens):
86 |         self.model_param_updater.update_max_tokens(max_tokens)
87 | 
88 |     # This function sends the user's custom prompt through to the model parameter updater function
89 |     def update_chat_prompt(self, custom_prompt):
90 |         self.model_param_updater.update_chat_prompt(custom_prompt)
91 | 
92 |     # This function sends the user's GitHub repo info through to the model manager
93 |     def set_github_info(self, owner, repo, branch):
94 |         self.model_manager.set_github_info(owner, repo, branch)
95 | 
96 |     """
97 |     This function handles document uploads and tells the user what they need to do so the model
98 |     can see the files.
99 |     """
100 |     @staticmethod
101 |     def handle_doc_upload(files):
102 |         gr.Warning("Make sure you hit the upload button or the model won't see your files!", duration=10)
103 |         return [file.name for file in files]
104 | 
--------------------------------------------------------------------------------
/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch, gc
2 | import gradio as gr
3 | from chat_utils import create_chat_engine
4 | from config import HF_MODEL_LIST, OLLAMA_MODEL_LIST, NV_MODEL_LIST, OA_MODEL_LIST, ANTH_MODEL_LIST
5 | 
6 | 
7 | # TODO Needs to be optimized; it doesn't fully do its job yet.
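# One possible direction for that TODO, sketched here only as a reference (nothing in this
# file calls it yet): CUDA memory is generally only released once the old chat engine / LLM
# object becomes unreachable, so dropping the reference before collecting matters more than
# the cache calls themselves. The torch calls below are standard PyTorch APIs; the
# `model_manager` argument is an assumed ModelManager instance from this module.
def reset_gpu_memory_thoroughly(model_manager):
    model_manager.chat_engine = None      # drop the old engine (and the LLM it holds) first
    gc.collect()                          # free the now-unreachable Python objects
    if torch.cuda.is_available():
        torch.cuda.synchronize()          # wait for any in-flight kernels to finish
        torch.cuda.empty_cache()          # return cached blocks to the driver
        torch.cuda.ipc_collect()          # clean up CUDA IPC handles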
8 | # Clears gpu memory so a new model can be load. 9 | def reset_gpu_memory(): 10 | torch.cuda.empty_cache() 11 | gc.collect() 12 | 13 | 14 | """ 15 | Main model class that deals with most of the functionality of the model and chat engine. This class handles the 16 | setting and resetting of the chat engine, model and model provider switching, database loading and resetting, and github 17 | repository loading and reset/ 18 | """ 19 | class ModelManager: 20 | def __init__(self): 21 | self.collection_name = None 22 | self.url = None 23 | self.password = None 24 | self.username = None 25 | self.vector_store = None 26 | self.model_param_updates = ModelParamUpdates(self) 27 | self.branch = None 28 | self.repo = None 29 | self.owner = None 30 | self.chat_engine = None 31 | self.provider = "Ollama" 32 | self.selected_model = "codestral:latest" 33 | self.model_display_names = { 34 | "Ollama": OLLAMA_MODEL_LIST, 35 | "HuggingFace": HF_MODEL_LIST, 36 | "NVIDIA NIM": NV_MODEL_LIST, 37 | "OpenAI": OA_MODEL_LIST, 38 | "Anthropic": ANTH_MODEL_LIST 39 | } 40 | 41 | # Creates the initial chat engine 42 | def create_initial_chat_engine(self): 43 | return create_chat_engine(self.provider, self.selected_model, 44 | self.model_param_updates.temperature, 45 | self.model_param_updates.max_tokens, 46 | self.model_param_updates.custom_prompt, 47 | self.model_param_updates.top_p, 48 | self.model_param_updates.context_window, 49 | self.model_param_updates.quantization, 50 | self.owner, self.repo, self.branch, self.vector_store, self.username, 51 | self.password, self.url, self.collection_name) 52 | 53 | # Processes the query from the user and sends it to the chat engine for processing 54 | def process_input(self, message): 55 | if self.chat_engine is None: 56 | self.chat_engine = self.create_initial_chat_engine() 57 | return self.chat_engine.stream_chat(message) 58 | 59 | # Updates the model provider and sends it to the chat engine based off the selection of the user 60 | def update_model_provider(self, provider): 61 | reset_gpu_memory() 62 | self.provider = provider 63 | default_models = { 64 | "Ollama": "codestral:latest", 65 | "HuggingFace": "", 66 | "NVIDIA NIM": "mistralai/codestral-22b-instruct-v0.1", 67 | "OpenAI": "gpt-4o", 68 | "Anthropic": "claude-3-5-sonnet-20240620" 69 | } 70 | self.selected_model = default_models.get(provider, "codestral:latest") 71 | gr.Info(f"Model provider updated to {provider}.", duration=10) 72 | self.reset_chat_engine() 73 | 74 | # Updates the model and sends it to the chat engine based off the selection of the user 75 | def update_model(self, display_name): 76 | reset_gpu_memory() 77 | self.selected_model = self.model_display_names[self.provider].get(display_name, self.selected_model) 78 | self.reset_chat_engine() 79 | gr.Info(f"Model updated to {display_name}.", duration=10) 80 | 81 | # Sets GitHub info to add its data to the context of the model 82 | def set_github_info(self, owner, repo, branch): 83 | self.owner, self.repo, self.branch = owner, repo, branch 84 | if all([owner, repo, branch]) != "": 85 | gr.Info( 86 | f"GitHub repository info set to Owners Username: {owner}, Repository Name: {repo}, and Branch Name: {branch}.") 87 | self.reset_chat_engine() 88 | 89 | # Resets GitHub info to remove the data from the context of the model 90 | def reset_github_info(self): 91 | self.owner = self.repo = self.branch = "" 92 | self.set_github_info(self.owner, self.repo, self.branch) 93 | gr.Info("GitHub repository info cleared and repository files from the models context!") 94 | 
self.reset_chat_engine() 95 | return self.owner, self.repo, self.branch 96 | 97 | # Sets database parameters and adds it to the models context 98 | def setup_database(self, vector_store, username, password, url, collection_name): 99 | self.vector_store, self.username, self.password, self.url, self.collection_name = vector_store, username, password, url, collection_name 100 | self.reset_chat_engine() 101 | gr.Info(f"Database connection established with {vector_store}.", duration=10) 102 | return self.vector_store 103 | 104 | # Resets database parameters and removes it from the models context 105 | def remove_database(self): 106 | self.vector_store = self.username = self.password = self.url = self.collection_name = None 107 | self.reset_chat_engine() 108 | gr.Info("Database connection removed.", duration=10) 109 | return self.username, self.password, self.url 110 | 111 | # Resets chat engine so the new parameters and new data can be loaded or removed into or from the model 112 | def reset_chat_engine(self): 113 | reset_gpu_memory() 114 | self.chat_engine = self.create_initial_chat_engine() 115 | 116 | 117 | """ 118 | Secondary class that sets initial model and chat engine parameters as well as updates model and chat engine parameters. 119 | It is also responsible for send gradio warning and info messages to the front end for the user to know their action was 120 | a success. 121 | """ 122 | class ModelParamUpdates: 123 | def __init__(self, model_manager): 124 | self.model_manager = model_manager 125 | self.max_tokens = 2048 126 | self.temperature = .75 127 | self.top_p = .4 128 | self.context_window = 2048 129 | self.quantization = "4 Bit" 130 | self.custom_prompt = None 131 | 132 | # Updates model quantization and sends a message to the user about the change. Also reset the gpu memory. 133 | def update_quant(self, quantization): 134 | reset_gpu_memory() 135 | self.quantization = quantization 136 | gr.Info(f"Quantization updated to {quantization}.", duration=10) 137 | self.model_manager.reset_chat_engine() 138 | 139 | # Updates model temperature parameter, send user a message about the change, and resets gpu memory. 140 | def update_model_temp(self, temperature): 141 | reset_gpu_memory() 142 | self.temperature = temperature 143 | gr.Info(f"Model temperature updated to {temperature}.", duration=10) 144 | gr.Warning("Changing this value can affect the randomness " 145 | "and diversity of generated responses. Use with caution!", 146 | duration=10) 147 | self.model_manager.reset_chat_engine() 148 | 149 | # Updates model top p parameter, send user a message about the change, and resets gpu memory. 150 | def update_top_p(self, top_p): 151 | reset_gpu_memory() 152 | self.top_p = top_p 153 | gr.Info(f"Top P updated to {top_p}.", duration=10) 154 | gr.Warning("Changing this value can affect the randomness " 155 | "and diversity of generated responses. Use with caution!", 156 | duration=10) 157 | self.model_manager.reset_chat_engine() 158 | 159 | # Updates model context window parameter, send user a message about the change, and resets gpu memory. 
160 | def update_context_window(self, context_window): 161 | reset_gpu_memory() 162 | self.context_window = context_window 163 | gr.Info(f"Context Window updated to {context_window}.", duration=10) 164 | gr.Warning("Changing this value can affect the amount of the context the model can see and use " 165 | "to answer your question.", 166 | duration=10) 167 | self.model_manager.reset_chat_engine() 168 | 169 | # Updates the max output tokens a model can respond with send user a message about the change, and resets gpu memory. 170 | def update_max_tokens(self, max_tokens): 171 | reset_gpu_memory() 172 | self.max_tokens = max_tokens 173 | gr.Info(f"Max Tokens set to {max_tokens}.", duration=10) 174 | gr.Warning("Please note that reducing the maximum number of tokens may" 175 | " cause incomplete or unexpected responses from the model if a user's question requires more tokens" 176 | " for an accurate answer.", 177 | duration=10) 178 | self.model_manager.reset_chat_engine() 179 | 180 | # Updates the chat engines system prompt, send user a message about the change, and resets gpu memory. 181 | def update_chat_prompt(self, custom_prompt): 182 | reset_gpu_memory() 183 | self.custom_prompt = custom_prompt 184 | gr.Warning("Caution: Changing the chat prompt may significantly alter the model's responses and could " 185 | "potentially cause misleading or incorrect information to be generated. Please ensure that " 186 | "the modified prompt is appropriate for your intended use case.", duration=10) 187 | self.model_manager.reset_chat_engine() 188 | -------------------------------------------------------------------------------- /pics/RAG_Query.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JakeFurtaw/Chat-RAG/648e6dccb0e7aa5b509fc927e51a8df2c75681f5/pics/RAG_Query.png -------------------------------------------------------------------------------- /pics/model_dropdown.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JakeFurtaw/Chat-RAG/648e6dccb0e7aa5b509fc927e51a8df2c75681f5/pics/model_dropdown.png -------------------------------------------------------------------------------- /pics/query.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JakeFurtaw/Chat-RAG/648e6dccb0e7aa5b509fc927e51a8df2c75681f5/pics/query.png -------------------------------------------------------------------------------- /pics/start_state.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JakeFurtaw/Chat-RAG/648e6dccb0e7aa5b509fc927e51a8df2c75681f5/pics/start_state.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.33.0 2 | aiofiles==23.2.1 3 | aiohappyeyeballs==2.4.0 4 | aiohttp==3.10.5 5 | aiosignal==1.3.1 6 | annotated-types==0.7.0 7 | anthropic==0.28.1 8 | anyio==4.4.0 9 | asgiref==3.8.1 10 | attrs==24.2.0 11 | backoff==2.2.1 12 | bcrypt==4.2.0 13 | beautifulsoup4==4.12.3 14 | bitsandbytes==0.43.3 15 | boto3==1.37.13 16 | botocore==1.37.13 17 | build==1.2.1 18 | cachetools==5.5.0 19 | certifi==2024.8.30 20 | charset-normalizer==3.3.2 21 | chroma-hnswlib==0.7.6 22 | chromadb==0.5.5 23 | click==8.1.7 24 | coloredlogs==15.0.1 25 | contourpy==1.3.0 26 | cycler==0.12.1 27 | dataclasses-json==0.6.7 28 | 
Deprecated==1.2.14 29 | dirtyjson==1.0.8 30 | distro==1.9.0 31 | einops==0.8.0 32 | environs==9.5.0 33 | fastapi==0.112.2 34 | ffmpy==0.4.0 35 | filelock==3.15.4 36 | filetype==1.2.0 37 | flash-attn==2.6.3 38 | flatbuffers==24.3.25 39 | fonttools==4.53.1 40 | frozenlist==1.4.1 41 | fsspec==2024.6.1 42 | google-auth==2.34.0 43 | googleapis-common-protos==1.65.0 44 | gradio==4.42.0 45 | gradio_client==1.3.0 46 | greenlet==3.0.3 47 | grpcio==1.66.1 48 | h11==0.14.0 49 | httpcore==1.0.5 50 | httptools==0.6.1 51 | httpx==0.27.2 52 | huggingface-hub==0.23.5 53 | humanfriendly==10.0 54 | idna==3.8 55 | importlib_metadata==8.4.0 56 | importlib_resources==6.4.4 57 | Jinja2==3.1.4 58 | jiter==0.5.0 59 | jmespath==1.0.1 60 | joblib==1.4.2 61 | jsonpatch==1.33 62 | jsonpointer==3.0.0 63 | kiwisolver==1.4.5 64 | kubernetes==30.1.0 65 | langchain==0.2.16 66 | langchain-community==0.2.16 67 | langchain-core==0.2.38 68 | langchain-huggingface==0.0.3 69 | langchain-ollama==0.1.3 70 | langchain-text-splitters==0.2.4 71 | langsmith==0.1.117 72 | llama-cloud==0.0.15 73 | llama-index==0.11.2 74 | llama-index-agent-openai==0.3.0 75 | llama-index-cli==0.3.0 76 | llama-index-core==0.11.3 77 | llama-index-embeddings-huggingface==0.3.1 78 | llama-index-embeddings-openai==0.2.3 79 | llama-index-indices-managed-llama-cloud==0.3.0 80 | llama-index-legacy==0.9.48.post3 81 | llama-index-llms-anthropic==0.3.0 82 | llama-index-llms-huggingface==0.3.1 83 | llama-index-llms-nvidia==0.2.1 84 | llama-index-llms-ollama==0.3.0 85 | llama-index-llms-openai==0.2.0 86 | llama-index-llms-openai-like==0.2.0 87 | llama-index-multi-modal-llms-openai==0.2.0 88 | llama-index-program-openai==0.2.0 89 | llama-index-question-gen-openai==0.2.0 90 | llama-index-readers-file==0.2.0 91 | llama-index-readers-github==0.2.0 92 | llama-index-readers-llama-parse==0.2.0 93 | llama-index-vector-stores-chroma==0.2.0 94 | llama-index-vector-stores-milvus==0.2.3 95 | llama-index-vector-stores-neo4jvector==0.2.1 96 | llama-parse==0.5.1 97 | markdown-it-py==3.0.0 98 | MarkupSafe==2.1.5 99 | marshmallow==3.22.0 100 | matplotlib==3.9.2 101 | mdurl==0.1.2 102 | milvus-lite==2.4.10 103 | minijinja==2.2.0 104 | mmh3==4.1.0 105 | monotonic==1.6 106 | mpmath==1.3.0 107 | multidict==6.0.5 108 | mypy-extensions==1.0.0 109 | neo4j==5.24.0 110 | nest-asyncio==1.6.0 111 | networkx==3.3 112 | nltk==3.9.1 113 | numpy==1.26.4 114 | nvidia-cublas-cu12==12.1.3.1 115 | nvidia-cuda-cupti-cu12==12.1.105 116 | nvidia-cuda-nvrtc-cu12==12.1.105 117 | nvidia-cuda-runtime-cu12==12.1.105 118 | nvidia-cudnn-cu12==9.1.0.70 119 | nvidia-cufft-cu12==11.0.2.54 120 | nvidia-curand-cu12==10.3.2.106 121 | nvidia-cusolver-cu12==11.4.5.107 122 | nvidia-cusparse-cu12==12.1.0.106 123 | nvidia-nccl-cu12==2.20.5 124 | nvidia-nvjitlink-cu12==12.9.41 125 | nvidia-nvtx-cu12==12.1.105 126 | oauthlib==3.2.2 127 | ollama==0.3.2 128 | onnxruntime==1.19.0 129 | openai==1.43.0 130 | opentelemetry-api==1.27.0 131 | opentelemetry-exporter-otlp-proto-common==1.27.0 132 | opentelemetry-exporter-otlp-proto-grpc==1.27.0 133 | opentelemetry-instrumentation==0.48b0 134 | opentelemetry-instrumentation-asgi==0.48b0 135 | opentelemetry-instrumentation-fastapi==0.48b0 136 | opentelemetry-proto==1.27.0 137 | opentelemetry-sdk==1.27.0 138 | opentelemetry-semantic-conventions==0.48b0 139 | opentelemetry-util-http==0.48b0 140 | orjson==3.10.7 141 | overrides==7.7.0 142 | packaging==24.1 143 | pandas==2.2.2 144 | pillow==10.4.0 145 | pip==25.1 146 | posthog==3.6.0 147 | protobuf==4.25.6 148 | psutil==6.0.0 149 | 
pyasn1==0.6.0 150 | pyasn1_modules==0.4.0 151 | pydantic==2.8.2 152 | pydantic_core==2.20.1 153 | pydub==0.25.1 154 | Pygments==2.18.0 155 | pymilvus==2.4.6 156 | pyparsing==3.1.4 157 | pypdf==4.3.1 158 | PyPika==0.48.9 159 | pyproject_hooks==1.1.0 160 | python-dateutil==2.9.0.post0 161 | python-dotenv==1.0.1 162 | python-multipart==0.0.9 163 | pytz==2024.1 164 | PyYAML==6.0.2 165 | regex==2024.7.24 166 | requests==2.32.3 167 | requests-oauthlib==2.0.0 168 | rich==13.8.0 169 | rsa==4.9 170 | ruff==0.6.3 171 | s3transfer==0.11.4 172 | safehttpx==0.1.6 173 | safetensors==0.4.4 174 | scikit-learn==1.5.1 175 | scipy==1.14.1 176 | semantic-version==2.10.0 177 | sentence-transformers==3.0.1 178 | sentencepiece==0.2.0 179 | setuptools==74.0.0 180 | shellingham==1.5.4 181 | six==1.16.0 182 | sniffio==1.3.1 183 | soupsieve==2.6 184 | SQLAlchemy==2.0.32 185 | starlette==0.38.2 186 | striprtf==0.0.26 187 | sympy==1.13.2 188 | tenacity==8.5.0 189 | text-generation==0.7.0 190 | threadpoolctl==3.5.0 191 | tiktoken==0.7.0 192 | tokenizers==0.19.1 193 | tomlkit==0.12.0 194 | torch==2.4.0 195 | torchaudio==2.4.0 196 | torchvision==0.19.0 197 | tqdm==4.66.5 198 | transformers==4.44.2 199 | triton==3.0.0 200 | typer==0.12.5 201 | typing_extensions==4.12.2 202 | typing-inspect==0.9.0 203 | tzdata==2024.1 204 | ujson==5.10.0 205 | urllib3==2.2.2 206 | uvicorn==0.30.6 207 | uvloop==0.20.0 208 | watchfiles==0.24.0 209 | websocket-client==1.8.0 210 | websockets==12.0 211 | wheel==0.44.0 212 | wrapt==1.16.0 213 | xformers==0.0.27.post2 214 | yarl==1.9.4 215 | zipp==3.20.1 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from llama_index.core.chat_engine.types import ChatMode 2 | from llama_index.embeddings.huggingface import HuggingFaceEmbedding 3 | from llama_index.llms.anthropic import Anthropic 4 | from llama_index.llms.ollama import Ollama 5 | from llama_index.llms.huggingface import HuggingFaceLLM 6 | from llama_index.llms.nvidia import NVIDIA 7 | from llama_index.llms.openai import OpenAI 8 | from llama_index.core import VectorStoreIndex 9 | from llama_index.core.memory import ChatMemoryBuffer 10 | from llama_index.core.llms import ChatMessage, MessageRole 11 | from transformers import BitsAndBytesConfig 12 | import torch, dotenv, os, gc 13 | from huggingface_hub import login 14 | 15 | dotenv.load_dotenv() 16 | 17 | # Used to determine what devices are available and set different gpus to different purposes 18 | def set_device(gpu: int = None) -> str: 19 | return f"cuda:{gpu}" if torch.cuda.is_available() and gpu is not None else "cpu" 20 | 21 | # Sets embedding model using a hugging face embedding model for local embeddings. 
22 | def set_embedding_model():
23 |     embed_model = HuggingFaceEmbedding(model_name="/home/jake/Programming/Models/embedding/multilingual-e5-large-instruct",
24 |                                        device=set_device(0), trust_remote_code=True)
25 |     return embed_model
26 | 
27 | # Function that configures Ollama models and sets some of the initial parameters
28 | def set_ollama_llm(model, temperature, max_tokens):
29 |     llm_models = {
30 |         "codestral:latest": {"model": "codestral:latest", "device": set_device(1)},
31 |         "qwen3:latest": {"model": "qwen3:latest", "device": set_device(1)},
32 |         "gemma3:12b": {"model": "gemma3:12b", "device": set_device(1)},
33 |         "codegemma:latest": {"model": "codegemma:latest", "device": set_device(1)},
34 |         "mistral-nemo:latest": {"model": "mistral-nemo:latest", "device": set_device(1)},
35 |         "llama3.1:latest": {"model": "llama3.1:latest", "device": set_device(1)},
36 |         "deepseek-r1:14b": {"model": "deepseek-r1:14b", "device": set_device(1)},
37 |         "deepseek-coder-v2:latest": {"model": "deepseek-coder-v2:latest", "device": set_device(1)},
38 | 
39 |     }
40 |     llm_config = llm_models.get(model, llm_models["codestral:latest"])
41 |     return Ollama(model=llm_config["model"], request_timeout=30.0, device=llm_config["device"],
42 |                   temperature=temperature, additional_kwargs={"num_predict": max_tokens})
43 | 
44 | # Sets the Hugging Face model and quantization based on the user's input
45 | def set_huggingface_llm(model, temperature, max_tokens, top_p, context_window, quantization):
46 |     torch.cuda.empty_cache()
47 |     gc.collect()
48 |     if model == "":
49 |         torch.cuda.empty_cache()
50 |         gc.collect()
51 |         pass
52 |     else:
53 |         if quantization == "2 Bit":  # bitsandbytes has no true 2-bit mode; this option uses 4-bit NF4 with double quantization
54 |             quantization_config = BitsAndBytesConfig(
55 |                 load_in_4bit=True,
56 |                 bnb_4bit_compute_dtype=torch.bfloat16,
57 |                 bnb_4bit_quant_type="nf4",
58 |                 bnb_4bit_use_double_quant=True
59 |             )
60 |         elif quantization == "4 Bit":
61 |             quantization_config = BitsAndBytesConfig(
62 |                 load_in_4bit=True,
63 |                 bnb_4bit_compute_dtype=torch.bfloat16,
64 |                 bnb_4bit_quant_type="nf4"
65 |             )
66 |         elif quantization == "8 Bit":
67 |             quantization_config = BitsAndBytesConfig(
68 |                 load_in_8bit=True,
69 |                 bnb_8bit_compute_dtype=torch.bfloat16,
70 |             )
71 |         elif quantization == "No Quantization":
72 |             quantization_config = None
73 |         else:
74 |             quantization_config = None
75 | 
76 |         model_kwargs = {"quantization_config": quantization_config,
77 |                         "trust_remote_code": True}
78 |         torch.cuda.empty_cache()
79 |         gc.collect()
80 |         login(token=os.getenv("HUGGINGFACE_HUB_TOKEN"))
81 |         return HuggingFaceLLM(
82 |             model_name=model,
83 |             tokenizer_name=model,
84 |             context_window=context_window,
85 |             max_new_tokens=max_tokens,
86 |             model_kwargs=model_kwargs,
87 |             is_chat_model=True,
88 |             device_map="cuda:0",
89 |             generate_kwargs={
90 |                 "temperature": temperature,
91 |                 "top_p": top_p,
92 |                 "do_sample": True,
93 |             },
94 |         )
95 | 
96 | # Sets the NVIDIA NIM model and parameters based on the user's input
97 | def set_nvidia_model(model, temperature, max_tokens, top_p):
98 |     return NVIDIA(
99 |         model=model,
100 |         max_tokens=max_tokens,
101 |         temperature=temperature,
102 |         top_p=top_p,
103 |         nvidia_api_key=os.getenv("NVIDIA_API_KEY")
104 |     )
105 | 
106 | # Sets the OpenAI model and parameters based on the user's input
107 | def set_openai_model(model, temperature, max_tokens, top_p):
108 |     return OpenAI(
109 |         model=model,
110 |         max_tokens=max_tokens,
111 |         temperature=temperature,
112 |         top_p=top_p,
113 |         api_key=os.getenv("OPENAI_API_KEY"),
114 |     )
115 | 
116 | # Sets the Anthropic model and parameters based on the user's input
117 | def set_anth_model(model, temperature, max_tokens):
118 |     return Anthropic(
119 |         model=model,
120 |         max_tokens=max_tokens,
121 |         temperature=temperature,
122 |         api_key=os.getenv("ANTHROPIC_API_KEY"),
123 | 
124 |     )
125 | 
126 | # Sets chat memory limits based on the model's default context length to ensure users don't exceed the model's limits
127 | def set_chat_memory(model):
128 |     memory_limits = {
129 |         "codestral:latest": 30000,
130 |         "mistralai/Codestral-22B-v0.1": 30000,
131 |         "qwen3:latest": 30000,
132 |         "Qwen/Qwen3-14B": 30000,
133 |         "gemma3:12b": 124000,
134 |         "google/gemma-3-12b-it": 124000,
135 |         "mistral-nemo:latest": 124000,
136 |         "mistralai/Mistral-Nemo-Instruct-2407": 124000,
137 |         "llama3.1:latest": 124000,
138 |         "meta-llama/Meta-Llama-3.1-8B-Instruct": 124000,
139 |         "deepseek-coder-v2:latest": 124000,
140 |         "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 124000,
141 |         "deepseek-r1:14b": 124000,
142 |         "codegemma:latest": 6000,
143 |         "google/codegemma-7b": 6000,
144 |     }
145 |     token_limit = memory_limits.get(model, 30000)
146 |     return ChatMemoryBuffer.from_defaults(token_limit=token_limit)
147 | 
148 | 
149 | # TODO Finish neo4j implementation
150 | """
151 | Sets up the initial chat engine. Loads the documents, model, embedding model, memory, default or custom prompt,
152 | and storage context. This pulls from the defaults set above and gets updated as users supply new parameters
153 | and data and the chat engine is recreated.
154 | """
155 | def setup_index_and_chat_engine(docs, embed_model, llm, memory, custom_prompt, storage_context):
156 |     if storage_context:
157 |         index = VectorStoreIndex.from_documents(docs, storage_context=storage_context, embed_model=embed_model)
158 |     else:
159 |         index = VectorStoreIndex.from_documents(docs, embed_model=embed_model)
160 |     chat_prompt = (
161 |         "You are an AI coding assistant; your primary function is to help users with coding-related questions \n"
162 |         "and tasks. You have access to a knowledge base of programming documentation and best practices. \n"
163 |         "When answering questions, please follow these guidelines. 1. Provide clear, concise, and \n"
164 |         "accurate code snippets when appropriate. 2. Explain your code and reasoning step by step. 3. Offer \n"
165 |         "suggestions for best practices and potential optimizations. 4. If the user's question is unclear, \n"
166 |         "ask for clarification; don't assume or guess the answer to any question. 5. When referencing external \n"
167 |         "libraries or frameworks, briefly explain their purpose. 6. If the question involves multiple possible \n"
168 |         "approaches, outline the pros and cons of each. Always remember to be friendly! \n"
169 |         "Response:"
170 |     )
171 |     system_message = ChatMessage(role=MessageRole.SYSTEM, content=chat_prompt if custom_prompt is None else custom_prompt)
172 |     chat_engine = index.as_chat_engine(
173 |         chat_mode=ChatMode.CONTEXT,
174 |         memory=memory,
175 |         stream=True,
176 |         system_prompt=system_message,
177 |         llm=llm,
178 |         verbose=True,
179 |         context_prompt=("Context information is below.\n"
180 |                         "---------------------\n"
181 |                         "{context_str}\n"
182 |                         "---------------------\n"
183 |                         "Given the context information above, I want you to think step by step to answer \n"
184 |                         "the query in a crisp manner; in case you don't know the answer, say 'I don't know!'. \n"
185 |                         "Query: {query_str} \n"
186 |                         "Answer: ")
187 |     )
188 |     return chat_engine
189 | 
--------------------------------------------------------------------------------
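Usage note (illustrative, not part of the repository): the helpers in utils.py are meant to be composed by the application layer, which builds the chat engine from whichever model backend, memory size, and storage context the user selects. The minimal sketch below shows one plausible way to wire them together directly, assuming a local ./data directory of documents, a locally pulled Ollama model, no external vector database, and that the hard-coded embedding model path in set_embedding_model() exists on the machine; none of these choices come from the repo itself.

# Hypothetical usage sketch of the utils.py helpers (assumptions noted above).
from llama_index.core import SimpleDirectoryReader
from utils import (set_embedding_model, set_ollama_llm, set_chat_memory,
                   setup_index_and_chat_engine)

# Load the documents that should be retrievable (the ./data path is an assumption).
docs = SimpleDirectoryReader("./data").load_data()

# Pick an Ollama model with generation settings similar to the defaults in ModelParamUpdates.
llm = set_ollama_llm("codestral:latest", temperature=0.75, max_tokens=2048)

# Size the chat memory to the model's context limit (30k tokens for Codestral in set_chat_memory).
memory = set_chat_memory("codestral:latest")

# No external vector store: passing storage_context=None builds an in-memory index.
chat_engine = setup_index_and_chat_engine(
    docs=docs,
    embed_model=set_embedding_model(),
    llm=llm,
    memory=memory,
    custom_prompt=None,        # fall back to the built-in coding-assistant prompt
    storage_context=None,
)

# Stream the reply token by token, the same way a front end would render it.
response = chat_engine.stream_chat("How do I read a CSV file with pandas?")
for token in response.response_gen:
    print(token, end="", flush=True)

Swapping set_ollama_llm for set_huggingface_llm, set_nvidia_model, set_openai_model, or set_anth_model changes only the llm argument; the index, memory, and prompt wiring stay the same.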