├── .env.example
├── .gitignore
├── LICENSE
├── README.md
├── add_your_files_here
│   └── .example
├── anthropic_generation.py
├── app.py
├── dummyfunctions.py
├── example.env
├── ragutils.py
└── requirements.txt

/.env.example:
--------------------------------------------------------------------------------
1 | ANTHROPIC_API_KEY=
2 | OPENAI_API_KEY=
3 | FIELDPROMPT=
4 | EXAMPLEPROMPT=
5 | DESCRIPTIONPROMPT=
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 | 
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 | 
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Tonic
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adapt-a-RAG: Adaptable Retrieval Augmented Generation
2 | 
3 | Adapt-a-RAG is an adaptable retrieval-augmented generation application that provides question answering over documents, GitHub repositories, and websites. It takes your data, generates synthetic data from it, and uses that synthetic data to optimize the application's prompts. The application recompiles itself on every run, adapting to the specific user query.
4 | 
5 | ## Table of Contents
6 | 
7 | - [Introduction](#introduction)
8 | - [Setup](#setup)
9 | - [How It Works](#how-it-works)
10 | - [Contributing](#contributing)
11 | - [Programmatic Usage](#programmatic-usage)
12 | - [License](#license)
13 | 
14 | ## Introduction
15 | 
16 | Adapt-a-RAG is an innovative application that leverages the power of retrieval augmented generation to provide accurate and relevant answers to user queries. By adapting itself to each query, Adapt-a-RAG ensures that the generated responses are tailored to the specific needs of the user.
17 | 
18 | The application utilizes various data sources, including documents, GitHub repositories, and websites, to gather information and generate synthetic data. This synthetic data is then used to optimize the prompts of the Adapt-a-RAG application, enabling it to provide more accurate and contextually relevant answers.
19 | 
20 | ## Setup
21 | 
22 | To set up Adapt-a-RAG, follow these steps:
23 | 
24 | 1. Clone the repository:
25 | ```bash
26 | git clone https://github.com/Josephrp/adapt-a-rag.git
27 | ```
28 | 
29 | 2. Install the required dependencies:
30 | ```bash
31 | pip install -r requirements.txt
32 | ```
33 | 
34 | 3. Configure the necessary API keys and environment variables, using `.env.example` as a template.
35 | 
36 | 4. Add your keys and `seed prompts` to a `.env` file:
37 |    - Open a file editor and add the following text exactly:
38 | ```
39 | ANTHROPIC_API_KEY=
40 | OPENAI_API_KEY=
41 | FIELDPROMPT=
42 | EXAMPLEPROMPT=
43 | DESCRIPTIONPROMPT=
44 | ```
45 |    - Save the file as `.env` in the same folder as `app.py`.
46 | 
47 | 5. Add your files to the folder `add_your_files_here`. Supported formats include CSV, DOCX, EPUB, HTML, HWP, IPYNB, Markdown, MBOX, PDF, PPTX, RTF, XML, and common image formats (see `choose_reader` in `app.py`).
48 | 
49 | 6. Run the application:
50 | ```bash
51 | python app.py
52 | ```
53 | 
54 | ## How It Works
55 | 
56 | Adapt-a-RAG works by following these key steps:
57 | 
58 | 1. **Data Collection**: The application collects data from various sources, including documents, GitHub repositories, and websites. It utilizes different reader classes such as `CSVReader`, `DocxReader`, `PDFReader`, `ChromaReader`, and `SimpleWebPageReader` to extract information from these sources.
59 | 
60 | 2. **Synthetic Data Generation**: Adapt-a-RAG generates synthetic data from the collected data. It employs techniques such as data augmentation and synthesis to create additional training examples that help improve the performance of the application.
61 | 
62 | 3. **Prompt Optimization**: The synthetic data is used to optimize the prompts of the Adapt-a-RAG application. By fine-tuning the prompts on the generated data, the application can produce more accurate and contextually relevant responses to user queries.
63 | 
64 | 4. **Recompilation**: Adapt-a-RAG recompiles itself on every run based on the optimized prompts and the specific user query. This dynamic recompilation allows the application to adapt and provide tailored responses to each query.
65 | 
66 | 5. **Question Answering**: Once recompiled, Adapt-a-RAG takes the user query and retrieves relevant information from the collected data sources. It then generates a response using the optimized prompts and the retrieved information, providing accurate and contextually relevant answers to the user.
67 | 
68 | ## Contributing
69 | 
70 | We welcome contributions to Adapt-a-RAG! If you'd like to contribute, please follow these steps:
71 | 
72 | 1. Fork the repository on GitHub.
73 | 
74 | 2. Create a new branch from `devbranch`:
75 | ```
76 | git checkout -b feature/your-feature-name devbranch
77 | ```
78 | 
79 | 3. Make your changes and commit them with descriptive commit messages.
80 | 
81 | 4. Push your changes to your forked repository:
82 | ```
83 | git push origin feature/your-feature-name
84 | ```
85 | 
86 | 5. Open a pull request against the `devbranch` of the main repository.
87 | 
88 | Please ensure that your contributions adhere to the project's coding conventions and include appropriate tests and documentation.
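89 | 
90 | ## Programmatic Usage
91 | 
92 | The Gradio UI in `app.py` is the main entry point, but the same classes can be driven directly from Python. The sketch below is illustrative rather than a supported API: it assumes the `Application` class and handler methods as they currently exist in `app.py`, a populated `.env`, and a sample PDF placed in `add_your_files_here/`.
93 | 
94 | ```python
95 | from types import SimpleNamespace
96 | from app import Application
97 | 
98 | app = Application()
99 | 
100 | # Persist keys to .env (same effect as the "API Keys" accordion in the UI)
101 | app.set_api_keys("your-anthropic-key", "your-openai-key")
102 | 
103 | # handle_file_upload expects a Gradio-style object exposing a .name attribute
104 | upload = SimpleNamespace(name="add_your_files_here/sample.pdf")
105 | data = app.handle_file_upload(upload)
106 | 
107 | # Synthetic data generation currently requires seed examples in
108 | # SyntheticDataHandler; the schema name argument is not yet wired through:
109 | # examples = app.handle_synthetic_data("InferredSchema", sample_size=10)
110 | ```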
111 | 
112 | ## License
113 | 
114 | Adapt-a-RAG is released under the MIT License. See the [LICENSE](LICENSE) file for more details.
--------------------------------------------------------------------------------
/add_your_files_here/.example:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TeamTonic/adapt-a-rag/285a05965a80274dad822f8ee6a79662dd966413/add_your_files_here/.example
--------------------------------------------------------------------------------
/anthropic_generation.py:
--------------------------------------------------------------------------------
1 | # source of this code
2 | # https://github.com/stanfordnlp/dspy/blob/main/dspy/experimental/synthetic_data.py
3 | 
4 | import dspy
5 | from dspy.functional import TypedPredictor
6 | from dspy.teleprompt import LabeledFewShot
7 | from dsp.modules.anthropic import Claude
8 | 
9 | from pydantic import BaseModel, Field
10 | from dotenv import load_dotenv
11 | import os
12 | 
13 | load_dotenv()
14 | 
15 | claude_llm = Claude(api_key=os.getenv("ANTHROPIC_API_KEY"))
16 | 
17 | colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')
18 | dspy.settings.configure(lm=claude_llm, rm=colbertv2_wiki17_abstracts)
19 | 
20 | class SyntheticFact(BaseModel):
21 |     fact: str = Field(..., description="a statement")
22 |     veracity: bool = Field(..., description="is the statement true or false")
23 | 
24 | class ExampleSignature(dspy.Signature):
25 |     """Generate an example of a synthetic fact."""
26 |     fact: SyntheticFact = dspy.OutputField()
27 | 
28 | generator = TypedPredictor(ExampleSignature)
29 | examples = generator(config=dict(n=10))
30 | 
31 | # If you already have examples and want more
32 | existing_examples = [
33 |     dspy.Example(fact="The sky is blue", veracity=True),
34 |     dspy.Example(fact="The sky is green", veracity=False),
35 | ]
36 | trained = LabeledFewShot().compile(student=generator, trainset=existing_examples)
37 | 
38 | augmented_examples = trained(config=dict(n=10))
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import llama_index
2 | from llama_index.readers.file import CSVReader
3 | from llama_index.readers.file import DocxReader
4 | from llama_index.readers.file import EpubReader
5 | from llama_index.readers.file import FlatReader
6 | from llama_index.readers.file import HTMLTagReader
7 | from llama_index.readers.file import HWPReader
8 | from llama_index.readers.file import IPYNBReader
9 | from llama_index.readers.file import ImageCaptionReader
10 | from llama_index.readers.file import ImageReader
11 | from llama_index.readers.file import ImageTabularChartReader
12 | from llama_index.readers.file import ImageVisionLLMReader
13 | from llama_index.readers.file import MarkdownReader
14 | from llama_index.readers.file import MboxReader
15 | from llama_index.readers.file import PDFReader
16 | from llama_index.readers.file import PagedCSVReader
17 | from llama_index.readers.file import PandasCSVReader
18 | from llama_index.readers.file import PptxReader
19 | from llama_index.readers.file import PyMuPDFReader
20 | from llama_index.readers.file import RTFReader
21 | from llama_index.readers.file import UnstructuredReader
22 | from llama_index.readers.file import VideoAudioReader
23 | from llama_index.readers.file import XMLReader
24 | from llama_index.readers.chroma import ChromaReader
25 | from llama_index.readers.web import AsyncWebPageReader
26 | from llama_index.readers.web import BeautifulSoupWebReader
27 | from llama_index.readers.web import KnowledgeBaseWebReader
28 | from llama_index.readers.web import MainContentExtractorReader
29 | from llama_index.readers.web import NewsArticleReader
30 | from llama_index.readers.web import ReadabilityWebPageReader
31 | from llama_index.readers.web import RssNewsReader
32 | from llama_index.readers.web import RssReader
33 | from llama_index.readers.web import SimpleWebPageReader
34 | from llama_index.readers.web import SitemapReader
35 | from llama_index.readers.web import TrafilaturaWebReader
36 | from llama_index.readers.web import UnstructuredURLLoader
37 | from llama_index.readers.web import WholeSiteReader
38 | 
39 | from langchain_core.documents.base import Document
40 | ####LlamaParse
41 | import llama_parse
42 | from llama_parse import LlamaParse
43 | from llama_index.core import SimpleDirectoryReader
44 | import random
45 | from typing import List, Optional
46 | from pydantic import BaseModel
47 | import dspy
48 | import gradio as gr
49 | from dsp.modules.anthropic import Claude  # used by ClaudeModelConfig and ModelCompilationAndEnsemble below
50 | from dspy.retrieve.chromadb_rm import ChromadbRM
51 | from dspy.evaluate import Evaluate
52 | from dspy.datasets.hotpotqa import HotPotQA
53 | from dspy.teleprompt import BootstrapFewShotWithRandomSearch, BootstrapFinetune
54 | from dsp.modules.lm import LM
55 | from dsp.utils.utils import deduplicate
56 | import os
57 | import dotenv
58 | from dotenv import load_dotenv, set_key
59 | from pathlib import Path
60 | 
61 | from typing import Any, List, Dict
62 | import base64
63 | from io import BytesIO
64 | import chromadb
65 | import logging
66 | 
67 | # Define constants and configurations
68 | NUM_THREADS = 4  # Example constant, adjust according to your actual configuration
69 | RECOMPILE_INTO_MODEL_FROM_SCRATCH = False  # Example flag
70 | logger = logging.getLogger(__name__)
71 | 
72 | # ## LOADING DATA
73 | # %load_ext autoreload
74 | # %autoreload 2
75 | 
76 | # %set_env CUDA_VISIBLE_DEVICES=7
77 | # import sys; sys.path.append('/future/u/okhattab/repos/public/stanfordnlp/dspy')
78 | 
79 | # Assume all necessary imports for llama_index readers are correctly done at the beginning
80 | 
81 | ports = [7140, 7141, 7142, 7143, 7144, 7145]
82 | #llamaChat = dspy.HFClientTGI(model="meta-llama/Llama-2-13b-chat-hf", port=ports, max_tokens=150) (DELETED)
83 | colbertv2 = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')
84 | 
85 | class APIKeyManager:
86 | 
87 |     @staticmethod
88 |     def set_api_keys(anthropic_api_key: str, openai_api_key: str):
89 |         """
90 |         Function to securely set API keys by updating the .env file in the application's directory.
91 |         This approach ensures that sensitive information is not hard-coded into the application.
92 |         """
93 |         print("Setting API keys...")
94 |         # Define the path to the .env file
95 |         env_path = Path('.') / '.env'
96 | 
97 |         print(f"Loading existing .env file from: {env_path}")
98 |         # Load existing .env file or create one if it doesn't exist
99 |         load_dotenv(dotenv_path=env_path, override=True)
100 | 
101 |         print("Updating .env file with new API keys...")
102 |         # Update the .env file with the new values
103 |         set_key(env_path, "ANTHROPIC_API_KEY", anthropic_api_key)
104 |         set_key(env_path, "OPENAI_API_KEY", openai_api_key)
105 | 
106 |         print("API keys updated successfully.")
107 |         # Returns a confirmation without exposing the keys
108 |         return "API keys updated successfully in .env file. Please proceed with your operations."
106 | 107 | @staticmethod 108 | def load_api_keys_and_prompts(): 109 | """ 110 | Loads API keys and prompts from an existing .env file into the application's environment. 111 | This function assumes the .env file is located in the same directory as the script. 112 | """ 113 | print("Loading API keys and prompts...") 114 | # Define the path to the .env file 115 | env_path = Path('.') / '.env' 116 | 117 | print(f"Loading .env file from: {env_path}") 118 | # Load the .env file 119 | load_dotenv(dotenv_path=env_path) 120 | 121 | print("Accessing variables from the environment...") 122 | # Access the variables from the environment 123 | anthropic_api_key = os.getenv("ANTHROPIC_API_KEY") 124 | openai_api_key = os.getenv("OPENAI_API_KEY") 125 | field_prompt = os.getenv("FIELDPROMPT") 126 | example_prompt = os.getenv("EXAMPLEPROMPT") 127 | description_prompt = os.getenv("DESCRIPTIONPROMPT") 128 | 129 | print("API keys and prompts loaded successfully.") 130 | # Optionally, print a confirmation or return the loaded values 131 | return { 132 | "ANTHROPIC_API_KEY": anthropic_api_key, 133 | "OPENAI_API_KEY": openai_api_key, 134 | "FIELDPROMPT": field_prompt, 135 | "EXAMPLEPROMPT": example_prompt, 136 | "DESCRIPTIONPROMPT": description_prompt 137 | } 138 | 139 | class DataProcessor: 140 | def __init__(self, source_file: str, collection_name: str, persist_directory: str): 141 | self.source_file = source_file 142 | self.collection_name = collection_name 143 | self.persist_directory = persist_directory 144 | 145 | def load_data_from_source_and_store(self) -> Any: 146 | # def load_data_from_source_and_store(source: Union[str, dict], collection_name: str, persist_directory: str) -> Any: 147 | """ 148 | Loads data from various sources and stores the data in ChromaDB. 149 | 150 | :param source: A string representing a file path or a URL, or a dictionary specifying web content to fetch. 151 | :param collection_name: Name of the ChromaDB collection to store the data. 152 | :param persist_directory: Path to the directory where ChromaDB data will be persisted. 153 | :return: Loaded data. 
154 | """ 155 | # Determine the file extension 156 | if isinstance(self.source_file, str): 157 | ext = os.path.splitext(self.source_file)[-1].lower() 158 | else: 159 | raise TypeError("Source must be a string (file path or URL).") 160 | 161 | # Load data using appropriate reader 162 | if ext == '.csv': 163 | reader = CSVReader() 164 | elif ext == '.docx': 165 | reader = DocxReader() 166 | elif ext == '.epub': 167 | reader = EpubReader() 168 | elif ext == '.html': 169 | reader = HTMLTagReader() 170 | elif ext == '.hwp': 171 | reader = HWPReader() 172 | elif ext == '.ipynb': 173 | reader = IPYNBReader() 174 | elif ext in ['.png', '.jpg', '.jpeg']: 175 | reader = ImageReader() # Assuming ImageReader can handle common image formats 176 | elif ext == '.md': 177 | reader = MarkdownReader() 178 | elif ext == '.mbox': 179 | reader = MboxReader() 180 | elif ext == '.pdf': 181 | reader = PDFReader() 182 | elif ext == '.pptx': 183 | reader = PptxReader() 184 | elif ext == '.rtf': 185 | reader = RTFReader() 186 | elif ext == '.xml': 187 | reader = XMLReader() 188 | elif self.source_file.startswith('http'): 189 | reader = AsyncWebPageReader() # Simplified assumption for URLs 190 | else: 191 | raise ValueError(f"Unsupported source type: {self.source_file}") 192 | 193 | # Use the reader to load data 194 | # data = reader.read(self.source_file) # Adjust method name as necessary 195 | data = reader.load_data(self.source_file) # Adjust method name as necessary 196 | 197 | chroma_client = chromadb.Client() 198 | collection = chroma_client.create_collection(name=self.collection_name) 199 | 200 | collection.add( 201 | documents=[i.text for i in data], # the text fields 202 | metadatas=[i.extra_info for i in data], # the metadata 203 | ids=[i.doc_id for i in data], # the generated ids 204 | ) 205 | 206 | 207 | # Store the data in ChromaDB 208 | # retriever_model = ChromadbRM(self.collection_name, self.persist_directory) 209 | 210 | # retriever_model(data) 211 | 212 | return data 213 | 214 | def choose_reader(full_path:str): 215 | """ 216 | Loads data from various sources and stores the data in ChromaDB. 217 | 218 | :param source: A string representing a file path or a URL, or a dictionary specifying web content to fetch. 
219 |     """
220 |     # Determine the file extension
221 |     if isinstance(full_path, str):
222 |         ext = os.path.splitext(full_path)[-1].lower()
223 |     else:
224 |         raise TypeError("Source must be a string (file path or URL).")
225 | 
226 |     # Load data using appropriate reader
227 |     if ext == '.csv':
228 |         reader = CSVReader()
229 |     elif ext == '.docx':
230 |         reader = DocxReader()
231 |     elif ext == '.epub':
232 |         reader = EpubReader()
233 |     elif ext == '.html':
234 |         reader = HTMLTagReader()
235 |     elif ext == '.hwp':
236 |         reader = HWPReader()
237 |     elif ext == '.ipynb':
238 |         reader = IPYNBReader()
239 |     elif ext in ['.png', '.jpg', '.jpeg']:
240 |         reader = ImageReader()  # Assuming ImageReader can handle common image formats
241 |     elif ext == '.md':
242 |         reader = MarkdownReader()
243 |     elif ext == '.mbox':
244 |         reader = MboxReader()
245 |     elif ext == '.pdf':
246 |         reader = PDFReader()
247 |     elif ext == '.pptx':
248 |         reader = PptxReader()
249 |     elif ext == '.rtf':
250 |         reader = RTFReader()
251 |     elif ext == '.xml':
252 |         reader = XMLReader()
253 |     elif full_path.startswith('http'):
254 |         reader = AsyncWebPageReader()  # Simplified assumption for URLs
255 |     else:
256 |         raise ValueError(f"Unsupported source type: {full_path}")
257 | 
258 |     # Return the chosen reader itself; DocumentLoader is responsible for
259 |     # calling reader.load_data(...) on it.
260 |     return reader
261 | 
262 | 
263 | class DocumentLoader:
264 | 
265 |     @staticmethod
266 |     def load_documents_from_folder(folder_path: str) -> List[Document]:
267 |         """Loads documents from files within a specified folder"""
268 |         folder_path = folder_path or "./add_your_files_here"  # fall back to the default drop folder
269 |         documents = []
270 |         for root, _, filenames in os.walk(folder_path):
271 |             for filename in filenames:
272 |                 full_path = os.path.join(root, filename)
273 | 
274 |                 reader = choose_reader(full_path)
275 | 
276 |                 if reader:
277 |                     print(f"Loading document from '{filename}' with {type(reader).__name__}")
278 | 
279 |                     try:
280 |                         docs = list(reader.load_data(input_files=[full_path]))
281 |                         documents.extend(docs)
282 | 
283 |                     except Exception as e:
284 |                         print(f"Failed to load document from '{filename}'. Error: {e}")
285 | 
286 |         # Convert to langchain format
287 |         documents = [doc.to_langchain_format()
288 |                      for doc in documents]
289 |         return documents
290 | 
291 | ### DSPY DATA GENERATOR
292 | 
293 | # class descriptionSignature(dspy.Signature):
294 | #     load_dotenv()
295 | #     field_prompt = os.getenv('FIELDPROMPT', 'Default field prompt if not set')
296 | #     example_prompt = os.getenv('EXAMPLEPROMPT', 'Default example prompt if not set')
297 | #     description_prompt = os.getenv('DESCRIPTIONPROMPT', 'Default description prompt if not set')
298 | #     field_name = dspy.InputField(desc=field_prompt)
299 | #     example = dspy.InputField(desc=example_prompt)
300 | #     description = dspy.OutputField(desc=description_prompt)
301 | 
302 | load_dotenv()
303 | 
304 | # https://github.com/stanfordnlp/dspy?tab=readme-ov-file#4-two-powerful-concepts-signatures--teleprompters
305 | # Field-description signature, matching how _define_or_infer_fields calls it below
306 | # (the seed prompts come from the .env file).
307 | class DescriptionSignature(dspy.Signature):
308 |     """Describe a data field given its name and an example value."""
309 | 
310 |     field_name = dspy.InputField(desc=os.getenv('FIELDPROMPT', 'the name of a field'))
311 |     example = dspy.InputField(desc=os.getenv('EXAMPLEPROMPT', 'an example value for the field'))
312 |     description = dspy.OutputField(desc=os.getenv('DESCRIPTIONPROMPT', 'a description of the field'))
313 | 
314 | 
315 | class SyntheticDataGenerator:
316 |     def __init__(self, schema_class: Optional[BaseModel] = None, examples: Optional[List[dspy.Example]] = None):
317 |         self.schema_class = schema_class
318 |         self.examples = examples
319 |         print("SyntheticDataGenerator initialized.")
320 | 
321 |     def generate(self, sample_size: int) -> List[dspy.Example]:
322 |         print(f"Starting data generation for sample size: {sample_size}")
323 |         if not self.schema_class and not self.examples:
324 |             raise ValueError("Either a schema_class or examples must be provided.")
325 |         if self.examples and len(self.examples) >= sample_size:
326 |             print("No additional data generation needed.")
327 |             return self.examples[:sample_size]
328 | 
329 |         additional_samples_needed = sample_size - (len(self.examples) if self.examples else 0)
330 |         print(f"Generating {additional_samples_needed} additional samples.")
331 |         generated_examples = self._generate_additional_examples(additional_samples_needed)
332 | 
333 |         return self.examples + generated_examples if self.examples else generated_examples
334 | 
335 |     def _define_or_infer_fields(self):
336 |         print("Defining or inferring fields for data generation.")
337 |         if self.schema_class:
338 |             data_schema = self.schema_class.model_json_schema()
339 |             properties = data_schema['properties']
340 |         elif self.examples:
341 |             inferred_schema = self.examples[0].__dict__['_store']
342 |             descriptor = dspy.Predict(DescriptionSignature)
343 |             properties = {field: {'description': str((descriptor(field_name=field, example=str(inferred_schema[field]))).description)}
344 |                           for field in inferred_schema.keys()}
345 |         else:
346 |             properties = {}
347 |         return properties
348 | 
349 |     def _generate_additional_examples(self, additional_samples_needed: int) -> List[dspy.Example]:
350 |         print(f"Generating {additional_samples_needed} additional examples.")
351 |         properties = self._define_or_infer_fields()
352 |         class_name = f"{self.schema_class.__name__ if self.schema_class else 'Inferred'}Signature"
353 |         fields = self._prepare_fields(properties)
354 | 
355 |         signature_class = type(class_name, (dspy.Signature,), fields)
356 |         generator = dspy.Predict(signature_class, n=additional_samples_needed)
357 |         response = generator(sindex=str(random.randint(1, additional_samples_needed)))
358 | 
359 |         return [dspy.Example({field_name: getattr(completion, field_name) for field_name in properties.keys()})
360 |                 for completion in response.completions]
361 | 
362 |     def _prepare_fields(self, properties) -> dict:
363 |         print("Preparing fields for the signature class.")
364 |         return {
365 |             '__doc__': f"Generates the following outputs: {{{', '.join(properties.keys())}}}.",
366 |             'sindex': dspy.InputField(desc="a random string"),
367 |             **{field_name: dspy.OutputField(desc=properties[field_name].get('description', 'No description'))
368 |                for field_name in properties.keys()},
369 |         }
370 | 
371 | # # Generating synthetic data via existing examples
372 | # generator = SyntheticDataGenerator(examples=existing_examples)
373 | # dataframe = generator.generate(sample_size=5)
374 | 
375 | class ClaudeModelManager:
376 |     """Wrapper around Anthropic's API. Supports both the Anthropic and Azure APIs."""
377 | 
378 |     def __init__(self, model: str = "claude-3-opus-20240229", api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs):
379 |         self.model = model
380 |         self.api_key = api_key
381 |         self.api_base = api_base
382 |         self.initialize_claude(**kwargs)
383 | 
384 |     def initialize_claude(self, **kwargs):
385 |         """Import the anthropic package, resolve credentials, and build the client."""
386 |         print("Initializing Claude...")
387 | 
388 |         try:
389 |             from anthropic import Anthropic
390 |             print("Successfully imported Anthropic's API client.")
391 |         except ImportError as err:
392 |             print("Failed to import Anthropic's API client.")
393 |             raise ImportError("Claude requires `pip install anthropic`.") from err
394 | 
395 |         self.provider = "anthropic"
396 |         # Fall back to the environment variable when no key is passed explicitly.
397 |         self.api_key = os.environ.get("ANTHROPIC_API_KEY") if self.api_key is None else self.api_key
398 |         if self.api_key:
399 |             print("API key is set.")
400 |         else:
401 |             print("API key is not set. Please ensure it's provided or set in the environment variables.")
402 | 
403 |         # Default to Anthropic's public endpoint when no base URL is supplied.
404 |         self.api_base = "https://api.anthropic.com" if self.api_base is None else self.api_base
405 |         print(f"API base URL is set to: {self.api_base}")
406 | 
407 |         # Default sampling parameters; values passed via **kwargs take precedence.
408 |         self.kwargs = {
409 |             "temperature": kwargs.get("temperature", 0.0),
410 |             "max_tokens": min(kwargs.get("max_tokens", 4096), 4096),
411 |             "top_p": kwargs.get("top_p", 1.0),
412 |             "top_k": kwargs.get("top_k", 1),
413 |             "n": kwargs.pop("n", kwargs.pop("num_generations", 1)),
414 |             **kwargs,
415 |         }
416 |         self.kwargs["model"] = self.model
417 |         print(f"Model parameters set: {self.kwargs}")
418 | 
419 |         # self.history: List[dict[str, Any]] = []
420 |         self.history = []  # plain list, to stay compatible with older versions
421 |         self.client = Anthropic(api_key=self.api_key)
422 |         print("Anthropic client initialized.")
423 | 
424 | class SyntheticDataHandler:
425 |     def __init__(self, examples: Optional[List[dspy.Example]] = None):
426 |         self.generator = SyntheticDataGenerator(examples=examples)
427 | 
428 |     def generate_data(self, sample_size: int):
429 |         return self.generator.generate(sample_size=sample_size)
430 | 
431 | 
432 | class ClaudeModelConfig:
433 |     def __init__(self, model_name):
434 |         self.model = model_name
435 | 
436 |     def get_model(self):
437 |         return Claude(model=self.model)
438 | 
439 | def configure_dspy_settings(lm_model):
440 |     dspy.settings.configure(rm=colbertv2, lm=lm_model)
441 | 
442 | class DatasetPreparation:
443 |     @staticmethod
444 |     def prepare_datasets(dataset):
445 |         trainset = [x.with_inputs('question') for x in dataset.train]
446 |         devset = [x.with_inputs('question') for x in dataset.dev]
447 |         testset = [x.with_inputs('question') for x in dataset.test]
448 |         return trainset, devset, testset
449 | 
450 | # class BasicMH(dspy.Module):
451 | #     def __init__(self, claude_model, passages_per_hop=3):
452 | #         super().__init__()
453 | #         self.claude_model = claude_model
454 | #         self.passages_per_hop = passages_per_hop
455 | 
456 | #     def forward(self, question):
457 | #         context = []
458 | #         for hop in range(2):
459 | #             search_results = self.claude_model.search(question, context=context, k=self.passages_per_hop)
460 | #             passages = [result.passage for result in search_results]
461 | #             context = self.deduplicate(context + passages)
462 | #         answer = self.claude_model.generate(context=context, question=question)
463 | #         return answer
464 | 
465 | #     @staticmethod
466 | #     def deduplicate(passages):
467 | #         return list(dict.fromkeys(passages))
468 | 
469 | class ModelCompilationAndEnsemble:
470 |     @staticmethod
471 |     def compile_or_load_models(recompile, trainset, num_models=4):
472 |         ensemble = []
473 |         if recompile:
474 |             metric_EM = dspy.evaluate.answer_exact_match
475 |             tp = BootstrapFewShotWithRandomSearch(metric=metric_EM, max_bootstrapped_demos=2, num_threads=NUM_THREADS)
476 |             claude_bs = tp.compile(Claude(), trainset=trainset[:50], valset=trainset[50:200])
477 |             ensemble = [prog for *_, prog in claude_bs.candidate_programs[:num_models]]
478 |         else:
479 |             for idx in range(num_models):
480 |                 claude_model = Claude(model=f'multihop_claude3opus_{idx}.json')
481 |                 ensemble.append(claude_model)
482 |         return ensemble
483 | 
484 | # # # Instantiate Claude with desired parameters
485 | # # claude_model = Claude(model="claude-3-opus-20240229")
486 | 
487 | # # # Configure dspy settings with Claude as the language model
488 | # # dspy.settings.configure(rm=colbertv2, lm=claude_model)
489 | # # #dspy.settings.configure(rm=colbertv2, lm=llamaChat) #Llama change into model based on line 166
490 | 
491 | # # dataset = dataframe
492 | # # trainset = [x.with_inputs('question') for x in dataset.train]
493 | # # devset = [x.with_inputs('question') for x in dataset.dev]
494 | # # testset = [x.with_inputs('question') for x in dataset.test]
495 | 
496 | # # #len(trainset), len(devset), len(testset)
497 | # # #trainset[0]
498 | 
499 | # class BasicMH(dspy.Module):
500 | #     def __init__(self, claude_model, passages_per_hop=3):
501 | #         super().__init__()
502 | 
503 | #         self.claude_model = claude_model
504 | #         self.passages_per_hop = passages_per_hop
505 | 
506 | #     def forward(self, question):
507 | #         context = []
508 | 
509 | #         for hop in range(2):
510 | #             # Retrieval using Claude model
511 | #             search_results = self.claude_model.search(question, context=context, k=self.passages_per_hop)
512 | #             passages = [result.passage for result in search_results]
513 | #             context = deduplicate(context + passages)
514 | 
515 | #         # Generation using Claude model
516 | #         answer = self.claude_model.generate(context=context, question=question)
517 | 
518 | #         return answer
519 | 
520 | # metric_EM = dspy.evaluate.answer_exact_match
521 | 
522 | # if RECOMPILE_INTO_MODEL_FROM_SCRATCH:
523 | #     tp = BootstrapFewShotWithRandomSearch(metric=metric_EM, max_bootstrapped_demos=2, num_threads=NUM_THREADS)
524 | #     # Compile the Claude model using BootstrapFewShotWithRandomSearch
525 | #     claude_bs = tp.compile(Claude(), trainset=trainset[:50], valset=trainset[50:200])
526 | 
527 | #     # Get the compiled programs
528 | #     ensemble = [prog for *_, prog in claude_bs.candidate_programs[:4]]
529 | 
530 | #     for idx, prog in enumerate(ensemble):
531 | #         # Save the compiled Claude models if needed
532 | #         # prog.save(f'multihop_llama213b_{idx}.json')
533 | #         pass
534 | # else:
535 | #     ensemble = []
536 | 
537 | #     for idx in range(4):
538 | #         # Load the previously trained Claude models
539 | #         claude_model = Claude(model=f'multihop_claude3opus_{idx}.json') #need to prepare this .json file
540 | #         ensemble.append(claude_model)
541 | 
542 | #     # Select the first Claude model from the ensemble
543 | #     claude_program = ensemble[0]
544 | 
545 | # Add this class definition to your app.py
546 | 
547 | class ChatbotManager:
548 |     def __init__(self):
549 |         self.models = self.load_models()
550 |         self.history = []
551 | 
552 |     def load_models(self):
553 |         pass
554 |         # return models
555 | 
556 |     def generate_response(self, text, image, model_select_dropdown, top_p, temperature, repetition_penalty, max_length_tokens, max_context_length_tokens):
557 |         # Placeholder: model loading and generation are not wired up yet.
558 |         raise NotImplementedError("ChatbotManager.generate_response is not implemented yet.")
559 | 
560 | def generate_prompt_with_history(text, history, max_length=2048):
561 |     """
562 |     Generate a prompt with history for the deepseek application.
563 |     Args:
564 |         text (str): The text prompt.
565 |         history (list): List of previous conversation messages.
566 |         max_length (int): The maximum length of the prompt.
567 |     Returns:
568 |         tuple: A tuple containing the generated prompt, conversation, and conversation copy. If the prompt could not be generated within the max_length limit, returns None.
569 |     """
570 |     user_role_ind = 0
571 |     bot_role_ind = 1
572 | 
573 |     # Initialize conversation
574 |     conversation = ""  # ADD DSPY HERE: replace with a conversation template, e.g. vl_chat_processor.new_chat_template()
575 | 
576 |     if history:
577 |         conversation.messages = history
578 | 
579 |     # if image is not None:
580 |     #     if "<image_placeholder>" not in text:
581 |     #         text = (
582 |     #             "<image_placeholder>" + "\n" + text
583 |     #         )  # append the <image_placeholder> in a new line after the text prompt
584 |     #     text = (text, image)
585 | 
586 |     conversation.append_message(conversation.roles[user_role_ind], text)
587 |     conversation.append_message(conversation.roles[bot_role_ind], "")
588 | 
589 |     # Create a copy of the conversation to avoid history truncation in the UI
590 |     conversation_copy = conversation.copy()
591 |     logger.info("=" * 80)
592 |     logger.info(get_prompt(conversation))
593 | 
594 |     rounds = len(conversation.messages) // 2
595 | 
596 |     for _ in range(rounds):
597 |         current_prompt = get_prompt(conversation)
598 |         # current_prompt = (
599 |         #     current_prompt.replace("", "")
600 |         #     if sft_format == "deepseek"
601 |         #     else current_prompt
602 |         # )
603 | 
604 |         # if current_prompt.count("") > 2:
605 |         #     for _ in range(len(conversation_copy.messages) - 2):
606 |         #         conversation_copy.messages.pop(0)
607 |         #     return conversation_copy
608 | 
609 |         # if torch.tensor(tokenizer.encode(current_prompt)).size(-1) <= max_length:
610 |         #     return conversation_copy
611 | 
612 |         if len(conversation.messages) % 2 != 0:
613 |             gr.Error("The messages between user and assistant are not paired.")
614 |             return
615 | 
616 |         try:
617 |             for _ in range(2):  # pop out two messages in a row
618 |                 conversation.messages.pop(0)
619 |         except IndexError:
620 |             gr.Error("Input text processing failed, unable to respond in this round.")
621 |             return None
622 | 
623 |     gr.Error("Prompt could not be generated within max_length limit.")
624 |     return None
625 | 
626 | def to_gradio_chatbot(conv):
627 |     """Convert the conversation to gradio chatbot format."""
628 |     ret = []
629 |     for i, (role, msg) in enumerate(conv.messages[conv.offset:]):
630 |         if i % 2 == 0:
631 |             if type(msg) is tuple:
632 |                 msg, image = msg
633 |                 if isinstance(image, str):
634 |                     with open(image, "rb") as f:
635 |                         data = f.read()
636 |                     img_b64_str = base64.b64encode(data).decode()
637 |                     image_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
638 |                     # Swap the model's image placeholder tokens for the inline <img> tag
639 |                     # ("<image_placeholder>" is assumed as the placeholder token).
640 |                     msg = msg.replace("\n".join(["<image_placeholder>"] * 4), image_str)
641 |                 else:
642 |                     max_hw, min_hw = max(image.size), min(image.size)
643 |                     aspect_ratio = max_hw / min_hw
644 |                     max_len, min_len = 800, 400
645 |                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
646 |                     longest_edge = int(shortest_edge * aspect_ratio)
647 |                     W, H = image.size
648 |                     if H > W:
649 |                         H, W = longest_edge, shortest_edge
650 |                     else:
651 |                         H, W = shortest_edge, longest_edge
652 |                     image = image.resize((W, H))
653 |                     buffered = BytesIO()
654 |                     image.save(buffered, format="JPEG")
655 |                     img_b64_str = base64.b64encode(buffered.getvalue()).decode()
656 |                     img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
657 |                     msg = msg.replace("<image_placeholder>", img_str)
658 |             ret.append([msg, None])
659 |         else:
660 |             ret[-1][-1] = msg
661 |     return ret
662 | 
663 | def to_gradio_history(conv):
664 |     """Convert the conversation to gradio history state."""
665 |     return conv.messages[conv.offset:]
666 | 
667 | 
668 | def get_prompt(conv) -> str:
669 |     """Get the prompt for generation."""
670 |     system_prompt = conv.system_template.format(system_message=conv.system_message)
671 |     # SeparatorStyle comes from the deepseek chat-template utilities (not yet ported here).
672 |     if conv.sep_style == SeparatorStyle.DeepSeek:
673 |         seps = [conv.sep, conv.sep2]
674 |         if system_prompt == "" or system_prompt is None:
675 |             ret = ""
676 |         else:
677 |             ret = system_prompt + seps[0]
678 |         for i, (role, message) in enumerate(conv.messages):
679 |             if message:
680 |                 if type(message) is tuple:  # multimodal message
681 |                     message, _ = message
682 |                 ret += role + ": " + message + seps[i % 2]
683 |             else:
684 |                 ret += role + ":"
685 |         return ret
686 |     else:
687 |         return conv.get_prompt()
688 | 
689 | def predict(text, chatbot, history, top_p, temperature, repetition_penalty, max_length_tokens, max_context_length_tokens, model_select_dropdown,):
690 |     """
691 |     Function to predict the response based on the user's input and selected model.
692 |     Parameters:
693 |         text (str): The input text from the user.
694 |         chatbot (str): The chatbot's name.
695 |         history (str): The history of the chat.
696 |         top_p (float): The top-p parameter for the model.
697 |         temperature (float): The temperature parameter for the model.
698 |         repetition_penalty (float): The repetition penalty for the model.
699 |         max_length_tokens (int): The maximum length of tokens for the model.
700 |         max_context_length_tokens (int): The maximum length of context tokens for the model.
701 |         model_select_dropdown (str): The selected model from the dropdown.
702 |     Returns:
703 |         generator: A generator that yields the chatbot outputs, history, and status.
704 |     """
705 |     print("running the prediction function")
706 |     # try:
707 |     #     tokenizer, vl_gpt, vl_chat_processor = models[model_select_dropdown]
708 | 
709 |     #     if text == "":
710 |     #         yield chatbot, history, "Empty context."
711 |     #         return
712 |     # except KeyError:
713 |     #     yield [[text, "No Model Found"]], [], "No Model Found"
714 |     #     return
715 | 
716 |     conversation = generate_prompt_with_history(
717 |         text,
718 |         history,
719 |         max_length=max_context_length_tokens,
720 |     )
721 |     # prompts = convert_conversation_to_prompts(conversation)  # consumed by the generation loop below (currently disabled)
722 |     gradio_chatbot_output = to_gradio_chatbot(conversation)
723 | 
724 |     # full_response = ""
725 |     # with torch.no_grad():
726 |     #     for x in deepseek_generate(
727 |     #         prompts=prompts,
728 |     #         vl_gpt=vl_gpt,
729 |     #         vl_chat_processor=vl_chat_processor,
730 |     #         tokenizer=tokenizer,
731 |     #         stop_words=stop_words,
732 |     #         max_length=max_length_tokens,
733 |     #         temperature=temperature,
734 |     #         repetition_penalty=repetition_penalty,
735 |     #         top_p=top_p,
736 |     #     ):
737 |     #         full_response += x
738 |     #         response = strip_stop_words(full_response, stop_words)
739 |     #         conversation.update_last_message(response)
740 |     #         gradio_chatbot_output[-1][1] = response
741 |     #         yield gradio_chatbot_output, to_gradio_history(
742 |     #             conversation
743 |     #         ), "Generating..."
744 | 
745 |     print("flushed result to gradio")
746 |     # torch.cuda.empty_cache()
747 | 
748 |     # if is_variable_assigned("x"):
749 |     #     print(f"{model_select_dropdown}:\n{text}\n{'-' * 80}\n{x}\n{'=' * 80}")
750 |     #     print(
751 |     #         f"temperature: {temperature}, top_p: {top_p}, repetition_penalty: {repetition_penalty}, max_length_tokens: {max_length_tokens}"
752 |     #     )
753 | 
754 |     yield gradio_chatbot_output, to_gradio_history(conversation), "Generate: Success"
755 | 
756 | 
757 | def retry(
758 |     text,
759 |     image,
760 |     chatbot,
761 |     history,
762 |     top_p,
763 |     temperature,
764 |     repetition_penalty,
765 |     max_length_tokens,
766 |     max_context_length_tokens,
767 |     model_select_dropdown,
768 | ):
769 |     if len(history) == 0:
770 |         yield (chatbot, history, "Empty context")
771 |         return
772 | 
773 |     chatbot.pop()
774 |     history.pop()
775 |     text = history.pop()[-1]
776 |     if type(text) is tuple:
777 |         text, image = text
778 | 
779 |     yield from predict(
780 |         text,
781 |         chatbot,
782 |         history,
783 |         top_p,
784 |         temperature,
785 |         repetition_penalty,
786 |         max_length_tokens,
787 |         max_context_length_tokens,
788 |         model_select_dropdown,
789 |     )
790 | 
791 | 
792 | class Application:
793 |     def __init__(self):
794 |         self.api_key_manager = APIKeyManager()
795 |         # self.data_processor = DataProcessor(source_file="", collection_name="adapt-a-rag", persist_directory="/your_files_here")
796 |         self.data_processor = DataProcessor(source_file="", collection_name="adapt-a-rag", persist_directory="your_files_here")
797 |         self.claude_model_manager = ClaudeModelManager()
798 |         self.synthetic_data_handler = SyntheticDataHandler()
799 |         self.chatbot_manager = ChatbotManager()
800 | 
801 |     def set_api_keys(self, anthropic_api_key, openai_api_key):
802 |         return self.api_key_manager.set_api_keys(anthropic_api_key, openai_api_key)
803 | 
804 |     def handle_file_upload(self, uploaded_file):
805 |         self.data_processor.source_file = uploaded_file.name
806 |         loaded_data = self.data_processor.load_data_from_source_and_store()
807 |         print(f"Data from {uploaded_file.name} loaded and stored successfully.")
808 |         return loaded_data
809 | 
810 |     def handle_synthetic_data(self, schema_class_name, sample_size):
811 |         synthetic_data = self.synthetic_data_handler.generate_data(sample_size=int(sample_size))
812 |         synthetic_data_str = "\n".join([str(data) for data in synthetic_data])
813 |         print(f"Generated {sample_size} synthetic data items:\n{synthetic_data_str}")
814 |         return synthetic_data
815 | 
816 |     def handle_chatbot_interaction(self, text, model_select, top_p, temperature, repetition_penalty, max_length_tokens, max_context_length_tokens):
817 |         chatbot_response, history, status = self.chatbot_manager.generate_response(text, None, model_select, top_p, temperature, repetition_penalty, max_length_tokens, max_context_length_tokens)
818 |         return chatbot_response
819 | 
820 |     def main(self):
821 |         with gr.Blocks() as demo:
822 |             with gr.Accordion("API Keys", open=True) as api_keys_accordion:
823 |                 with gr.Row():
824 |                     anthropic_api_key_input = gr.Textbox(label="Anthropic API Key", type="password")
825 |                     openai_api_key_input = gr.Textbox(label="OpenAI API Key", type="password")
826 |                 submit_button = gr.Button("Submit")
827 |                 confirmation_output = gr.Textbox(label="Confirmation", visible=False)
828 | 
829 |                 submit_button.click(
830 |                     fn=self.set_api_keys,
831 |                     inputs=[anthropic_api_key_input, openai_api_key_input],
832 |                     outputs=confirmation_output
833 |                 )
834 | 
835 |             with gr.Accordion("Upload Data") as upload_data_accordion:
836 |                 file_upload = gr.File(label="Upload Data File")
837 |                 file_upload_button = gr.Button("Process Uploaded File")
838 |                 file_upload_output = gr.Textbox()
839 | 
840 |                 file_upload_button.click(
841 |                     fn=self.handle_file_upload,
842 |                     inputs=[file_upload],
843 |                     outputs=file_upload_output
844 |                 )
845 | 
846 |             with gr.Accordion("Generate Synthetic Data") as generate_data_accordion:
847 |                 schema_input = gr.Textbox(label="Schema Class Name")
848 |                 sample_size_input = gr.Number(label="Sample Size", value=100)
849 |                 synthetic_data_button = gr.Button("Generate Synthetic Data")
850 |                 synthetic_data_output = gr.Textbox()
851 | 
852 |                 synthetic_data_button.click(
853 |                     fn=self.handle_synthetic_data,
854 |                     inputs=[schema_input, sample_size_input],
855 |                     outputs=synthetic_data_output
856 |                 )
857 | 
858 |             with gr.Accordion("Chatbot") as chatbot_accordion:
859 |                 text_input = gr.Textbox(label="Enter your question")
860 |                 # model_select = gr.Dropdown(label="Select Model", choices=list(self.chatbot_manager.models.keys()))
861 |                 # Dropdown choices must be strings (model names), not model objects.
862 |                 model_select = gr.Dropdown(label="Select Model", choices=["claude-3-opus-20240229"])
863 |                 top_p_input = gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, value=0.95, step=0.01)
864 |                 temperature_input = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.7, step=0.01)
865 |                 repetition_penalty_input = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.1, step=0.1)
866 |                 max_length_tokens_input = gr.Number(label="Max Length Tokens", value=2048)
867 |                 max_context_length_tokens_input = gr.Number(label="Max Context Length Tokens", value=2048)
868 |                 chatbot_output = gr.Chatbot(label="Chatbot Conversation")
869 |                 submit_button = gr.Button("Submit")
870 | 
871 |                 submit_button.click(
872 |                     fn=self.handle_chatbot_interaction,
873 |                     inputs=[text_input, model_select, top_p_input, temperature_input, repetition_penalty_input, max_length_tokens_input, max_context_length_tokens_input],
874 |                     outputs=chatbot_output
875 |                 )
876 | 
877 |         demo.launch()
878 | 
879 | if __name__ == "__main__":
880 |     app = Application()
881 |     app.main()
882 | 
883 | # Example usage
884 | # source_file = "example.txt"  # Replace with your source file path
885 | # collection_name = "adapt-a-rag"  # Need to be defined
886 | # persist_directory = "/your_files_here"  # Need to be defined
887 | 
888 | # loaded_data = load_data_from_source_and_store(source_file, collection_name="adapt-a-rag", persist_directory="/your_files_here")
889 | # print("Data loaded and stored successfully in ChromaDB.")
--------------------------------------------------------------------------------
/dummyfunctions.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Dict
3 | def choose_reader(file_path: str) -> Any:
4 |     ext = os.path.splitext(file_path)[1].lower()
5 | 
6 |     readers = {
7 |         ".json": JSONFileReader(),
8 |         ".csv": CSVReader(),
9 |         ".docx": DocxReader(),
10 |         ".epub": EpubReader(),
11 |         ".flat": FlatReader(),  # This is an assumption; adjust based on actual reader
12 |         ".html": HTMLTagReader(),
13 |         ".hwp": HWPReader(),
14 |         ".ipynb": IPYNBReader(),
15 |         ".png": ImageReader(),  # Assuming generic image handling
16 |         ".jpg": ImageReader(),
17 |         ".jpeg": ImageReader(),
18 |         # Continue for all file types...
19 |     }
20 | 
21 | 
22 | def choose_reader(file_path: str) -> Any:
23 |     """Choose the appropriate reader based on the file extension."""
24 | 
25 |     ext = os.path.splitext(file_path)[1].lower()
26 | 
27 |     readers: Dict[str, Any] = {
28 |         ".json": JSONFileReader(),
29 |         ".csv": CSVFileReader(),
30 |         ".xlsx": ExcelSheetReader(),
31 |         ".xls": ExcelSheetReader(),
32 |         ".html": HTMLFileReader(),
33 |         ".pdf": PDFMinerReader(),
34 |         # Add more extensions and their corresponding readers as needed...
35 |     }
36 | 
37 | # Dummy backend function for handling user query
38 | def handle_query(user_query: str) -> str:
39 |     # Placeholder for processing user query
40 |     return f"Processed query: {user_query}"
41 | 
42 | # Dummy backend function for handling repository input
43 | def handle_repository(repository_link: str) -> str:
44 |     # Placeholder for processing repository input
45 |     return f"Processed repository link: {repository_link}"
46 | 
47 | # New dummy function for handling synthetic data generation
48 | def handle_synthetic_data(schema_class_name: str, sample_size: int) -> str:
49 |     # Placeholder for generating synthetic data based on the schema class name and sample size
50 |     return f"Synthetic data for schema '{schema_class_name}' with {sample_size} samples has been generated."
51 | 
52 | # New dummy function for handling file uploads
53 | def handle_file_upload(uploaded_file):
54 |     # Placeholder for processing the uploaded file
55 |     if uploaded_file is not None:
56 |         return f"Uploaded file '{uploaded_file.name}' has been processed."
57 |     return "No file was uploaded."
58 | 
59 | from dsp.modules.anthropic import Claude
60 | anthropicChat = Claude(model="claude-3-opus-20240229", max_tokens=150)  # Claude takes no port argument
61 | 
62 | 
63 | 
64 | 
65 | #class BasicMH(dspy.Module):
66 | #    def __init__(self, passages_per_hop=3):
67 | #        super().__init__()
68 | 
69 | #        self.retrieve = dspy.Retrieve(k=passages_per_hop)
70 | #        self.generate_query = [dspy.ChainOfThought("context, question -> search_query") for _ in range(2)]
71 | #        self.generate_answer = dspy.ChainOfThought("context, question -> answer")
72 | 
73 | #    def forward(self, question):
74 | #        context = []
75 | 
76 | #        for hop in range(2):
77 | #            search_query = self.generate_query[hop](context=context, question=question).search_query
78 | #            passages = self.retrieve(search_query).passages
79 | #            context = deduplicate(context + passages)
80 | 
81 | #        return self.generate_answer(context=context, question=question).copy(context=context)
82 | 
83 | 
84 | 
85 | # parser = LlamaParse(
86 | #     api_key="llx-...",  # can also be set in your env as LLAMA_CLOUD_API_KEY
87 | #     result_type="markdown",  # "markdown" and "text" are available
88 | #     num_workers=4,  # if multiple files passed, split in `num_workers` API calls
89 | #     verbose=True,
90 | #     language="en"  # Optionally you can define a language, default=en
91 | # )
92 | # # sync
93 | # documents = parser.load_data("./my_file.pdf")
94 | 
95 | # # sync batch
96 | # documents = parser.load_data(["./my_file1.pdf", "./my_file2.pdf"])
97 | 
98 | # # async
99 | # documents = await parser.aload_data("./my_file.pdf")
100 | 
101 | # # async batch
102 | # documents = await parser.aload_data(["./my_file1.pdf", "./my_file2.pdf"])
103 | # LlamaPack example
104 | # from llama_index.core.llama_pack import download_llama_pack
105 | 
106 | #Ragas : https://colab.research.google.com/gist/virattt/6a91d2a9dcf99604637e400d48d2a918/ragas-first-look.ipynb
107 | #from ragas.testset.generator import TestsetGenerator
108 | #from ragas.testset.evolutions import simple,
reasoning, multi_context 109 | 110 | # generator with openai models 111 | # generator = TestsetGenerator.with_openai() 112 | 113 | # generate testset 114 | #testset = generator.generate_with_langchain_docs(documents, test_size=10, distributions={simple: 0.5, reasoning: 0.25, multi_context: 0.25}) 115 | 116 | # visualize the dataset as a pandas DataFrame 117 | #dataframe = testset.to_pandas() 118 | #dataframe.head(10) 119 | 120 | 121 | #### DSPY APPLICATION LOGIC GOES HERE 122 | 123 | 124 | # We will show you how to import the agent from these files! 125 | 126 | # from llama_index.core.llama_pack import download_llama_pack 127 | 128 | # # download and install dependencies 129 | # download_llama_pack("LLMCompilerAgentPack", "./llm_compiler_agent_pack") 130 | # From here, you can use the pack. You can import the relevant modules from the download folder (in the example below we assume it's a relative import or the directory has been added to your system path). 131 | 132 | # # setup pack arguments 133 | 134 | # from llama_index.core.agent import AgentRunner 135 | # from llm_compiler_agent_pack.step import LLMCompilerAgentWorker 136 | 137 | # agent_worker = LLMCompilerAgentWorker.from_tools( 138 | # tools, llm=llm, verbose=True, callback_manager=callback_manager 139 | # ) 140 | # agent = AgentRunner(agent_worker, callback_manager=callback_manager) 141 | 142 | # # start using the agent 143 | # response = agent.chat("What is (121 * 3) + 42?") 144 | # You can also use/initialize the pack directly. 145 | 146 | # from llm_compiler_agent_pack.base import LLMCompilerAgentPack 147 | 148 | # agent_pack = LLMCompilerAgentPack(tools, llm=llm) 149 | # The run() function is a light wrapper around agent.chat(). 150 | 151 | # response = pack.run("Tell me about the population of Boston") 152 | # You can also directly get modules from the pack. 
153 | 
154 | # # use the agent
155 | # agent = pack.agent
156 | # response = agent.chat("task")
157 | 
158 | 
159 | # from llama_parse import LlamaParse
160 | # from llama_index.core import SimpleDirectoryReader
161 | 
162 | # parser = LlamaParse(
163 | #     api_key="llx-...",  # can also be set in your env as LLAMA_CLOUD_API_KEY
164 | #     result_type="markdown",  # "markdown" and "text" are available
165 | #     verbose=True
166 | # )
167 | 
168 | # file_extractor = {".pdf": parser}
169 | # documents = SimpleDirectoryReader("./data", file_extractor=file_extractor).load_data()
170 | 
171 | 
172 | ## Compiling using meta-llama/Llama-2-13b-chat-hf
173 | #RECOMPILE_INTO_MODEL_FROM_SCRATCH = False
174 | #NUM_THREADS = 24
175 | 
176 | #metric_EM = dspy.evaluate.answer_exact_match
177 | 
178 | #if RECOMPILE_INTO_MODEL_FROM_SCRATCH:
179 | #    tp = BootstrapFewShotWithRandomSearch(metric=metric_EM, max_bootstrapped_demos=2, num_threads=NUM_THREADS)
180 | #    basicmh_bs = tp.compile(BasicMH(), trainset=trainset[:50], valset=trainset[50:200])
181 | 
182 | #    ensemble = [prog for *_, prog in basicmh_bs.candidate_programs[:4]]
183 | 
184 | #    for idx, prog in enumerate(ensemble):
185 | #        # prog.save(f'multihop_llama213b_{idx}.json')
186 | #        pass
187 | #if not RECOMPILE_INTO_MODEL_FROM_SCRATCH:
188 | #    ensemble = []
189 | 
190 | #    for idx in range(4):
191 | #        prog = BasicMH()
192 | #        prog.load(f'multihop_llama213b_{idx}.json')
193 | #        ensemble.append(prog)
194 | #llama_program = ensemble[0]
195 | #RECOMPILE_INTO_MODEL_FROM_SCRATCH = False
196 | #NUM_THREADS = 24
--------------------------------------------------------------------------------
/example.env:
--------------------------------------------------------------------------------
1 | ANTHROPIC_API_KEY=
2 | OPENAI_API_KEY=
--------------------------------------------------------------------------------
/ragutils.py:
--------------------------------------------------------------------------------
1 | ## from weaviate/recipes
2 | import dspy
3 | from dsp.utils import deduplicate
4 | 
5 | class RAG(dspy.Module):
6 |     def __init__(self, num_passages=3):
7 |         super().__init__()
8 | 
9 |         self.retrieve = dspy.Retrieve(k=num_passages)
10 |         self.generate_answer = dspy.ChainOfThought("question, contexts -> answer")
11 | 
12 |     def forward(self, question):
13 |         contexts = self.retrieve(question).passages
14 |         prediction = self.generate_answer(question=question, contexts=contexts)
15 |         return dspy.Prediction(answer=prediction.answer)
16 | 
17 | class GenerateAnswer(dspy.Signature):
18 |     """Answer the question based on the context."""
19 | 
20 |     context = dspy.InputField(desc="documents determined to be relevant to the question.")
21 |     question = dspy.InputField()
22 |     best_answer = dspy.OutputField(desc="an answer to the question.")
23 | 
24 | class Reranker(dspy.Signature):
25 |     """Please rerank these documents."""
26 | 
27 |     context = dspy.InputField(desc="documents coarsely determined to be relevant to the question.")
28 |     question = dspy.InputField()
29 |     ranked_context = dspy.OutputField(desc="A ranking of documents by relevance to the question.")
30 | 
31 | class RAGwithReranker(dspy.Module):
32 |     def __init__(self):
33 |         super().__init__()
34 | 
35 |         self.retrieve = dspy.Retrieve(k=5)
36 |         self.reranker = dspy.ChainOfThought(Reranker)
37 |         self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
38 | 
39 |     def forward(self, question):
40 |         context = self.retrieve(question).passages
41 |         context = self.reranker(context=context, question=question).ranked_context
42 |         pred = self.generate_answer(context=context, question=question).best_answer
43 |         return dspy.Prediction(answer=pred)
44 | 
45 | class Summarizer(dspy.Signature):
46 |     """Please summarize all relevant information in the context."""
47 | 
48 |     context = dspy.InputField(desc="documents determined to be relevant to the question.")
49 |     question = dspy.InputField()
50 |     summarized_context = dspy.OutputField(desc="A summarization of information in the documents that will help answer the question.")
51 | 
52 | class RAGwithSummarizer(dspy.Module):
53 |     def __init__(self):
54 |         super().__init__()
55 | 
56 |         self.retrieve = dspy.Retrieve(k=5)
57 |         self.summarizer = dspy.ChainOfThought(Summarizer)
58 |         self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
59 | 
60 |     def forward(self, question):
61 |         context = self.retrieve(question).passages
62 |         context = self.summarizer(context=context, question=question).summarized_context
63 |         pred = self.generate_answer(context=context, question=question).best_answer
64 |         return dspy.Prediction(answer=pred)
65 | 
66 | class GenerateSearchQuery(dspy.Signature):
67 |     """Write a simple search query that will help answer a complex question."""
68 | 
69 |     context = dspy.InputField(desc="may contain relevant facts")
70 |     question = dspy.InputField()
71 |     query = dspy.OutputField()
72 | 
73 | class MultiHopRAG(dspy.Module):
74 |     def __init__(self, passages_per_hop=3, max_hops=2):
75 |         super().__init__()
76 | 
77 |         self.generate_question = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]
78 |         self.retrieve = dspy.Retrieve(k=passages_per_hop)
79 |         self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
80 |         self.max_hops = max_hops
81 | 
82 |     def forward(self, question):
83 |         context = []
84 | 
85 |         for hop in range(self.max_hops):
86 |             query = self.generate_question[hop](context=context, question=question).query
87 |             passages = self.retrieve(query).passages
88 |             context = deduplicate(context + passages)
89 | 
90 |         pred = self.generate_answer(context=context, question=question)
91 |         return dspy.Prediction(context=context, answer=pred.best_answer)
92 | 
93 | class MultiHopRAGwithSummarization(dspy.Module):
94 |     def __init__(self, passages_per_hop=3, max_hops=2):
95 |         super().__init__()
96 | 
97 |         self.generate_question = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]
98 |         self.retrieve = dspy.Retrieve(k=passages_per_hop)
99 |         self.summarizer = dspy.ChainOfThought(Summarizer)
100 |         self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
101 |         self.max_hops = max_hops
102 | 
103 |     def forward(self, question):
104 |         context = []
105 | 
106 |         for hop in range(self.max_hops):
107 |             query = self.generate_question[hop](context=context, question=question).query
108 |             passages = self.retrieve(query).passages
109 |             summarized_passages = self.summarizer(question=query, context=passages).summarized_context
110 |             context.append(summarized_passages)
111 | 
112 |         pred = self.generate_answer(context=context, question=question)
113 |         return dspy.Prediction(context=context, answer=pred.best_answer)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | langchain
2 | llama-index
3 | llama-index-readers-file
4 | llama-index-readers-chroma
5 | llama-index-readers-web
6 | llama-parse
7 | dspy-ai
8 | anthropic
9 | ragas
10 | chromadb
11 | unstructured
12 | tiktoken
13 | openai
14 | # Alternatively, install dspy from source instead of the dspy-ai release:
15 | # git+https://github.com/stanfordnlp/dspy.git
16 | python-dotenv
--------------------------------------------------------------------------------