├── .chroma_env.example ├── .dockerignore ├── .env.example ├── .github └── CODEOWNERS ├── .gitignore ├── .streamlit └── config.toml ├── .vscode └── settings.json ├── Dockerfile ├── LICENSE ├── README-FOR-DEVELOPERS.md ├── README.md ├── TODOS-AND-OTHER-DEV-NOTES.md ├── _prepare_env.py ├── agentblocks ├── collectionhelper.py ├── core.py ├── docconveyer.py ├── webprocess.py ├── webretrieve.py └── websearch.py ├── agents ├── dbmanager.py ├── exporter.py ├── ingester_summarizer.py ├── research_heatseek.py ├── researcher.py ├── researcher_data.py ├── share_manager.py └── websearcher_quick.py ├── announcements ├── version-0-2-6.md └── version-0-2.md ├── api.py ├── components ├── __init__.py ├── chat_with_docs_chain.py ├── chroma_ddg.py ├── chroma_ddg_retriever.py ├── llm.py └── openai_embeddings_ddg.py ├── config ├── logging.json └── trafilatura.cfg ├── docdocgo.py ├── eval ├── RAG-PARAMS.md ├── ai_news_1.py ├── ai_resume_builders.txt ├── list_ai_resume_builders.txt ├── openai_news.py ├── prompt_running_llama2.txt ├── queries.md ├── testing-rag-quality │ ├── 2401.01325.pdf │ └── rag-tests.md └── top_russian_desserts.py ├── ingest_local_docs.py ├── media ├── ddg-logo.png ├── minimal10.png ├── minimal10.svg ├── minimal11.png ├── minimal11.svg ├── minimal13.svg ├── minimal7-plain-svg.svg ├── minimal7.png ├── minimal7.svg ├── minimal9.png └── minimal9.svg ├── requirements.txt ├── streamlit_app.py └── utils ├── __init__.py ├── algo.py ├── async_utils.py ├── chat_state.py ├── debug.py ├── docgrab.py ├── filesystem.py ├── helpers.py ├── ingest.py ├── input.py ├── lang_utils.py ├── log.py ├── older ├── prompts-checkpoints.py └── prompts-older.py ├── output.py ├── prepare.py ├── prompts.py ├── query_parsing.py ├── rag.py ├── streamlit ├── fix_event_loop.py ├── helpers.py ├── ingest.py └── prepare.py ├── strings.py ├── type_utils.py └── web.py /.chroma_env.example: -------------------------------------------------------------------------------- 1 | # You need to use this only if you want to run Chroma in a Docker container 2 | # Choose your own secret key, but make sure it matches CHROMA_SERVER_AUTHN_CREDENTIALS in .env 3 | # NOTE: Do not enclose values in quotes (contrary to the example in the chromadb docs) 4 | CHROMA_SERVER_AUTHN_CREDENTIALS=choose-your-own-secret-key 5 | 6 | # You don't need to change these 7 | CHROMA_SERVER_AUTHN_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider 8 | CHROMA_AUTH_TOKEN_TRANSPORT_HEADER=X-Chroma-Token 9 | ANONYMIZED_TELEMETRY=False 10 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .venv/ 2 | __pycache__/ 3 | *.pyc 4 | .git/ 5 | *.log 6 | dist/ 7 | .DS_Store 8 | *.egg-info/ 9 | *.egg 10 | 11 | .my-chroma* 12 | docs-private/ 13 | DOCS-APPENDIX.md 14 | docdocgo*.pem 15 | .env.* 16 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # Copy this file to .env and fill in the values 2 | 3 | ## NOTE: The only required value is DEFAULT_OPENAI_API_KEY. The app should work if the other values 4 | # are left as is or not defined at all. However, you are strongly encouraged to fill in your 5 | # own SERPER_API_KEY value. It is also recommended to fill in the BYPASS_SETTINGS_RESTRICTIONS and 6 | # BYPASS_SETTINGS_RESTRICTIONS_PASSWORD values as needed (see below). 
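As an illustration of the token-auth settings in `.chroma_env.example` above, here is a minimal sketch of a client connecting to such a Chroma server. It assumes the `chromadb` Python package; the env var names come from this repo's env example files, but the snippet itself is an illustration, not part of the app's code.

```python
# Hypothetical sketch: connecting to a token-authenticated Chroma server.
# Env var names are taken from .chroma_env.example / .env.example; the header
# name matches CHROMA_AUTH_TOKEN_TRANSPORT_HEADER above.
import os

import chromadb

client = chromadb.HttpClient(
    host=os.getenv("CHROMA_SERVER_HOST", "localhost"),
    port=int(os.getenv("CHROMA_SERVER_HTTP_PORT", "8000")),
    headers={"X-Chroma-Token": os.environ["CHROMA_SERVER_AUTH_CREDENTIALS"]},
)
client.heartbeat()  # fails if the server or the credentials are misconfigured
```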
7 | DEFAULT_OPENAI_API_KEY="" # your OpenAI API key 8 | 9 | ## Your Google Serper API key (for web searches). 10 | # If you don't set this, my key will be used, which may have 11 | # run out of credits by the time you use it (there's no payment info associated with it). 12 | # You can get a free key without a credit card here: https://serper.dev/ 13 | # SERPER_API_KEY="" 14 | 15 | # Mode to use if none is specified in the query 16 | DEFAULT_MODE="/docs" # or "/web" | "/quotes" | "/details" | "/chat" | "/research" 17 | 18 | # Variables controlling whether the Streamlit app imposes some functionality restrictions 19 | BYPASS_SETTINGS_RESTRICTIONS="" # whether to immediately allow all settings (any non-empty string means true) 20 | BYPASS_SETTINGS_RESTRICTIONS_PASSWORD="" # what to enter in the OpenAI API key field to bypass settings 21 | # restrictions. If BYPASS_SETTINGS_RESTRICTIONS is non-empty or if the user enters their own OpenAI API key, 22 | # this becomes - mostly - irrelevant, as full settings access is already granted. I say "mostly" because 23 | # this password can also be used for a couple of admin-only features, such as deleting the default collection 24 | # and deleting a range of collections (see dbmanager.py). Recommendation: for local use, 25 | # the easiest thing to do is to set both values. For a public deployment, you may want to leave 26 | # BYPASS_SETTINGS_RESTRICTIONS empty. 27 | 28 | # If you are NOT using Azure, the following setting determines the chat model 29 | # NOTE: You can change the model and temperature on the fly in Streamlit UI or via the API 30 | MODEL_NAME="gpt-3.5-turbo-0125" # default model to use for chat 31 | ALLOWED_MODELS="gpt-3.5-turbo-0125, gpt-4-turbo-2024-04-09" 32 | CONTEXT_LENGTH="16000" # you can also make it lower than the actual context length 33 | EMBEDDINGS_MODEL_NAME="text-embedding-3-large" 34 | EMBEDDINGS_DIMENSIONS="3072" # number of dimensions for the embeddings model 35 | # Possible embedding models and their dimensions: 36 | # text-embedding-ada-002: 1536 37 | # text-embedding-3-small: 1536 or 512 38 | # text-embedding-3-large: 3072, 1024 or 512 39 | # NOTE: depending on the model, different relevance score thresholds are needed (see 40 | # score_threshold_min/max in chroma_ddg_retriever.py). Currently these values are auto-set 41 | # based on the model name, but so far they have only been tested for the 42 | # "text-embedding-3-large"-3072 and "text-embedding-ada-002" models. 43 | 44 | # If you are using Azure, uncomment the following settings and fill in the values 45 | # AZURE_OPENAI_API_KEY="" # your Azure OpenAI API key 46 | # OPENAI_API_TYPE="azure" # yours may be different 47 | # OPENAI_API_BASE="https://techops.openai.azure.com" # yours may be different 48 | # OPENAI_API_VERSION="2023-05-15" # or whichever version you're using 49 | # CHAT_DEPLOYMENT_NAME="gpt-35-turbo" # your deployment name for the chat model you want to use 50 | # EMBEDDINGS_DEPLOYMENT_NAME="text-embedding-ada-002" # your deployment name for the embedding model 51 | 52 | TEMPERATURE="0.3" # 0 to 2, higher means more random, lower means more conservative (>1.6 is gibberish) 53 | LLM_REQUEST_TIMEOUT="3" # seconds to wait before a request to the LLM service should be considered timed 54 | # out and the bot should retry sending the request (a few times with exponential back-off). 55 | # For Azure, it seems that a slightly longer timeout is needed, about 9 seconds. 
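As an illustration of the timeout-and-retry behavior described above, here is a minimal sketch of retrying with exponential back-off. The function name, attempt count, and delays are assumptions for illustration only, not the app's actual implementation.

```python
# Minimal sketch of "retry a few times with exponential back-off" (illustrative;
# call_llm, the attempt count, and the delays are assumptions, not the app's code).
import time


def call_with_retries(call_llm, timeout_s: float = 3.0, max_attempts: int = 4):
    for attempt in range(max_attempts):
        try:
            return call_llm(timeout=timeout_s)
        except TimeoutError:
            if attempt == max_attempts - 1:
                raise
            time.sleep(2**attempt)  # wait 1s, 2s, 4s between retries
```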
56 | 57 | # Whether to use Playwright for more reliable web scraping (any non-empty string means true) 58 | # It's recommended to use Playwright but it may not work in all environments and requires 59 | # a small amount of setup (mostly just running "playwright install" - you will be prompted) 60 | USE_PLAYWRIGHT="" 61 | 62 | DEFAULT_COLLECTION_NAME="docdocgo-documentation" # name of the initially selected collection 63 | 64 | ## There are two ways to use the Chroma vector database: using a local database or via HTTP 65 | ## (e.g. by running it in a Docker container). The default is the local option. While it's the 66 | ## simplest way to get started, the HTTP option is more robust and scalable. 67 | 68 | ## The following item is only relevant if you want to use a local Chroma db 69 | VECTORDB_DIR="chroma/" # directory of your doc database 70 | 71 | ## The following items are only relevant if you want to run Chroma in a Docker container 72 | USE_CHROMA_VIA_HTTP="" # whether to use Chroma via HTTP (any non-empty string means true) 73 | CHROMA_SERVER_AUTH_CREDENTIALS="" # choose your own (must be the same as in your Chroma Docker container) 74 | CHROMA_SERVER_HOST="localhost" # IP address of your Chroma server 75 | CHROMA_SERVER_HTTP_PORT="8000" # port your Chroma server is listening on 76 | 77 | ## Settings for the response 78 | 79 | # Whether to include the error message in the user-facing error message 80 | INCLUDE_ERROR_IN_USER_FACING_ERROR_MSG="sure" 81 | 82 | ## Settings for Streamlit 83 | 84 | DOMAIN_NAME_FOR_SHARING="" # domain name for sharing links (e.g. "https://localhost:8501") 85 | 86 | # Chat avatars (if you don't set these, the default avatars will be used) 87 | USER_AVATAR="" # URL of the user's chat avatar (e.g. "./images/user-avatar-example.png") 88 | BOT_AVATAR="" # URL of the bot's chat avatar (e.g. 
"./images/bot-avatar-example.png") 89 | 90 | # Warning message to display at the top of the Streamlit app (normally should be left empty) 91 | STREAMLIT_WARNING_NOTIFICATION="" 92 | 93 | ## The following items are only relevant for ingesting docs 94 | COLLECTON_NAME_FOR_INGESTED_DOCS="" # collection name for the docs you want to ingest 95 | DOCS_TO_INGEST_DIR_OR_FILE="" # path to directory or single file with your docs to ingest 96 | 97 | ## The following items are only relevant if running the FastAPI server 98 | DOCDOCGO_API_KEY="" # choose your own 99 | MAX_UPLOAD_BYTES="104857600" # max size of files that can be uploaded (default is 100MB) 100 | 101 | ## Logging settings 102 | DEFAULT_LOGGER_NAME="ddg" 103 | 104 | # Overrides for config/logging.json 105 | LOG_FORMAT="%(asctime)s - %(name)s - %(levelname)s - %(message)s - [%(filename)s:%(lineno)d in %(funcName)s]" 106 | LOG_LEVEL="INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL 107 | 108 | ## Additional options, used in development (any non-empty string means true) 109 | RERAISE_EXCEPTIONS="" # don't keep convo going if an exception is raised, just reraise it 110 | PRINT_STANDALONE_QUERY="" # print query used for the retriever 111 | PRINT_QA_PROMPT="" # print the prompt used for the final response (if asking about your docs) 112 | PRINT_CONDENSE_QUESTION_PROMPT="" # print the prompt used to condense the question 113 | PRINT_SIMILARITIES="" # print the similarities between the query and the docs 114 | PRINT_RESEARCHER_PROMPT="" # print the prompt used for the final researcher response 115 | INITIAL_TEST_QUERY_STREAMLIT="" # auto-run as the first query in Streamlit 116 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @reasonmethis -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | docdocgo*.pem 2 | chroma-cloud* 3 | .tmp* 4 | credentials/ 5 | chroma-cloud-backup/ 6 | .chroma_env 7 | my-chroma* 8 | .chroma_env 9 | DOCS-APPENDIX.md 10 | docs-private 11 | launch.json 12 | logs/ 13 | debug*.txt 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | cover/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | .pybuilder/ 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | # For a library or package, you might want to ignore these files since the code is 100 | # intended to run in multiple environments; otherwise, check them in: 101 | # .python-version 102 | 103 | # pipenv 104 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 105 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 106 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 107 | # install all needed dependencies. 108 | #Pipfile.lock 109 | 110 | # poetry 111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 112 | # This is especially recommended for binary packages to ensure reproducibility, and is more 113 | # commonly ignored for libraries. 114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 115 | #poetry.lock 116 | 117 | # pdm 118 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 119 | #pdm.lock 120 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 121 | # in version control. 122 | # https://pdm.fming.dev/#use-with-ide 123 | .pdm.toml 124 | 125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 126 | __pypackages__/ 127 | 128 | # Celery stuff 129 | celerybeat-schedule 130 | celerybeat.pid 131 | 132 | # SageMath parsed files 133 | *.sage.py 134 | 135 | # Environments 136 | .env* 137 | .venv 138 | env/ 139 | venv/ 140 | ENV/ 141 | env.bak/ 142 | venv.bak/ 143 | 144 | # Spyder project settings 145 | .spyderproject 146 | .spyproject 147 | 148 | # Rope project settings 149 | .ropeproject 150 | 151 | # mkdocs documentation 152 | /site 153 | 154 | # mypy 155 | .mypy_cache/ 156 | .dmypy.json 157 | dmypy.json 158 | 159 | # Pyre type checker 160 | .pyre/ 161 | 162 | # pytype static type analyzer 163 | .pytype/ 164 | 165 | # Cython debug symbols 166 | cython_debug/ 167 | 168 | # PyCharm 169 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 170 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 171 | # and can be added to the global gitignore or merged into this file. For a more nuclear 172 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
173 | #.idea/ 174 | -------------------------------------------------------------------------------- /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [theme] 2 | base="light" 3 | primaryColor="#005F26" 4 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[python]": { 3 | "editor.defaultFormatter": "charliermarsh.ruff" 4 | }, 5 | "search.useIgnoreFiles": true 6 | } 7 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim 2 | 3 | WORKDIR /app 4 | 5 | COPY requirements.txt . 6 | 7 | RUN pip install --no-cache-dir -r requirements.txt 8 | 9 | COPY . . 10 | 11 | EXPOSE 80 12 | 13 | # https://github.com/tiangolo/fastapi/discussions/9328#discussioncomment-8242245 14 | CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "80", "--proxy-headers", "--forwarded-allow-ips=*"] 15 | 16 | 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-2024 Dmitriy Vasilyuk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /TODOS-AND-OTHER-DEV-NOTES.md: -------------------------------------------------------------------------------- 1 | # Dev Notes 2 | 3 | ## Assorted TODOs 4 | 5 | 1. DONE chromadb.api.models.Collection has feature to filter by included text and metadatas (I think) 6 | 2. Eval mode: gives two answers, user likes one, bot tunes params 7 | 3. Generate several versions of the standalone query 8 | 4. Use tasks in case answer takes too long (meanwhile can show a tip) 9 | 5. Ability for user to customize bot settings (num of docs, relevance threshold, temperature, etc.) 10 | 6. DONE Make use of larger context window 11 | 7. DONE Reorder relevant documents in order (based on overlap e.g.) 12 | 8. DONE API token for the bot backend 13 | 9. For long convos include short summary in the prompt 14 | 10. DONE Use streaming to improve detecting true LLM stalls 15 | 11. NA Send partial results to the user to avoid timeout 16 | 12. 
Return retrieved documents in the response 17 | 13. Speed up limiting tokens - see tiktoken comment in OpenAIEmbeddings 18 | 14. DONE Include current date/time in the prompt 19 | 15. DONE Use two different chunk sizes: search for small chunks, then small chunks point to larger chunks 20 | 16. Include tips for next possible commands in the response 21 | 17. Add /re feedback command to automatically change query, report type, search queries 22 | 18. Automatically infer command from user query 23 | 19. Ability to ingest current chat history into a collection 24 | 25 | ## Evaluations 26 | 27 | ### Research tasks 28 | 29 | 1. "extract main content from html" 30 | 31 | ## Discovered Tricks for Getting Content from Websites 32 | 33 | 1. To get around bot detection, it helps to set a different device in playwright 34 | - e.g. "iphone13" 35 | - For example, this helped with the following links: 36 | - "https://www.sciencedaily.com/news/computers_math/artificial_intelligence/", 37 | - "https://www.investors.com/news/technology/ai-stocks-artificial-intelligence-trends-and-news/", 38 | - "https://www.forbes.com/sites/forbesbusinesscouncil/2022/10/07/recent-advancements-in-artificial-intelligence/?sh=266cd08e7fa5", 39 | 2. I increased the MIN_EXTRACTED_SIZE setting from 250 to 1500 in trafilatura to extract more content 40 | - for example, in the following link, with the default setting it only extracted a cookie banner: 41 | - "https://www.artificialintelligence-news.com/" 42 | 43 | 44 | ## Other Notes 45 | 46 | ### get_relevant_documents on a retriever 47 | 48 | - calls _get_relevant_documents 49 | - for vectorstores, that calls self.vectorstore.similarity_search_with_relevance_scores(query, **self.search_kwargs) 50 | - _similarity_search_with_relevance_scores 51 | - similarity_search_with_score 52 | - for chroma, that calls self.__query_collection(query_embeddings=[query_embedding], n_results=k, where=filter) 53 | - besides k and filter, no other search_kwargs are passed (specifically where_document), 54 | even though it accepts arbitrary kwargs 55 | - self._collection.query( 56 | query_texts=query_texts, 57 | query_embeddings=query_embeddings, 58 | n_results=n_results, 59 | where=where, 60 | **kwargs, 61 | ) 62 | -------------------------------------------------------------------------------- /_prepare_env.py: -------------------------------------------------------------------------------- 1 | """Load the environment variables and perform other initialization tasks.""" 2 | from utils.prepare import is_env_loaded 3 | 4 | if not is_env_loaded: 5 | raise RuntimeError("This should be unreachable.") 6 | -------------------------------------------------------------------------------- /agentblocks/collectionhelper.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from langchain_core.documents import Document 4 | 5 | from agents.dbmanager import get_full_collection_name 6 | from components.chroma_ddg import ChromaDDG, exists_collection 7 | from utils.chat_state import ChatState 8 | from utils.docgrab import ingest_into_chroma 9 | from utils.helpers import get_timestamp 10 | from utils.prepare import get_logger 11 | 12 | logger = get_logger() 13 | 14 | SMALL_WORDS = {"a", "an", "the", "of", "in", "on", "at", "for", "to", "and", "or"} 15 | SMALL_WORDS |= {"is", "are", "was", "were", "be", "been", "being", "am"} 16 | SMALL_WORDS |= {"what", "which", "who", "whom", "whose", "where", "when", "how"} 17 | SMALL_WORDS |= {"this", "that", "these", 
"those", "there", "here", "can", "could"} 18 | SMALL_WORDS |= {"i", "you", "he", "she", "it", "we", "they", "me", "him", "her"} 19 | SMALL_WORDS |= {"my", "your", "his", "her", "its", "our", "their", "mine", "yours"} 20 | SMALL_WORDS |= {"some", "any"} 21 | 22 | 23 | def construct_new_collection_name(query: str, chat_state: ChatState) -> str: 24 | """ 25 | Construct and return new collection name based on the query and user ID. Ensures 26 | that the collection name is unique by appending a number to the end if necessary. 27 | """ 28 | # Decide on the collection name consistent with ChromaDB's naming rules 29 | query_words = [x.lower() for x in query.split()] 30 | words = [] 31 | words_excluding_small = [] 32 | for word in query_words: 33 | word_just_alnum = "".join(x for x in word if x.isalnum()) 34 | if not word_just_alnum: 35 | break 36 | words.append(word_just_alnum) 37 | if word not in SMALL_WORDS: 38 | words_excluding_small.append(word_just_alnum) 39 | 40 | words = words_excluding_small if len(words_excluding_small) > 2 else words 41 | 42 | new_coll_name = "-".join(words[:3])[:35].rstrip("-") 43 | 44 | # Screen for too short collection names or those that are convetible to a number 45 | try: 46 | if len(new_coll_name) < 3 or int(new_coll_name) is not None: 47 | new_coll_name = f"collection-{new_coll_name}".rstrip("-") 48 | except ValueError: 49 | pass 50 | 51 | # Construct full collection name (preliminary) 52 | new_coll_name = get_full_collection_name(chat_state.user_id, new_coll_name) 53 | 54 | new_coll_name_final = new_coll_name 55 | 56 | # Check if collection exists, if so, add a number to the end 57 | for i in range(2, 1000000): 58 | if not exists_collection(new_coll_name_final, chat_state.vectorstore.client): 59 | return new_coll_name_final 60 | new_coll_name_final = f"{new_coll_name}-{i}" 61 | 62 | 63 | def ingest_into_collection( 64 | *, 65 | collection_name: str, 66 | docs: list[Document], 67 | collection_metadata: dict[str, str] | None, 68 | chat_state: ChatState, 69 | is_new_collection: bool, 70 | retry_with_random_name: bool = False, 71 | ) -> ChromaDDG: 72 | """ 73 | Load provided documents and metadata into a new or existing collection. 74 | 75 | It's ok to pass an empty list of documents. If retry_with_random_name is True, 76 | and the collection name is invalid, it will replace it with a random valid one. 77 | 78 | If is_new_collection is True, the function will add the "created_at" and "updated_at" 79 | metadata fields to the collection (overwriting such fields in the passed metadata). 80 | If is_new_collection is False, it will only add the "updated_at" field, and only if 81 | collection_metadata is not None. 
82 | """ 83 | logger.info("Creating new collection and loading data") 84 | 85 | for i in range(2): 86 | try: 87 | timestamp = get_timestamp() 88 | if is_new_collection: 89 | full_metadata = {"created_at": timestamp, "updated_at": timestamp} 90 | if collection_metadata: 91 | full_metadata.update(collection_metadata) 92 | elif collection_metadata is not None: 93 | full_metadata = {"updated_at": timestamp} | collection_metadata 94 | else: 95 | # For now, we don't add "updated_at" if no metadata is passed, 96 | # because we would need to fetch the existing metadata first 97 | full_metadata = None 98 | 99 | vectorstore = ingest_into_chroma( 100 | docs, 101 | collection_name=collection_name, 102 | openai_api_key=chat_state.openai_api_key, 103 | chroma_client=chat_state.vectorstore.client, 104 | collection_metadata=full_metadata, 105 | ) 106 | break # success 107 | except Exception as e: # bad name error may not be ValueError in docker mode 108 | logger.error(f"Error ingesting documents into ChromaDB: {e}") 109 | if ( 110 | not retry_with_random_name 111 | or i != 0 112 | or "Expected collection name" not in str(e) 113 | ): 114 | raise e # i == 1 means tried normal name and random name, give up 115 | 116 | # Create a random valid collection name and try again 117 | collection_name = get_full_collection_name( 118 | chat_state.user_id, "collection-" + uuid.uuid4().hex[:8] 119 | ) 120 | logger.warning(f"Ingesting into new collection named {collection_name}") 121 | 122 | logger.info(f"Finished ingesting data into new collection {collection_name}") 123 | return vectorstore 124 | -------------------------------------------------------------------------------- /agentblocks/core.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from pydantic import BaseModel 4 | 5 | from utils.strings import extract_json 6 | from utils.type_utils import ChainType, DDGError, JSONish 7 | 8 | MAX_ENFORCE_FORMAT_ATTEMPTS = 4 9 | 10 | 11 | class EnforceFormatError(DDGError): 12 | default_user_facing_message = ( 13 | "Apologies, I tried to compose a message in the right " 14 | "format, but I kept running into trouble." 15 | ) 16 | 17 | 18 | def enforce_format( 19 | chain: ChainType, 20 | inputs: dict, 21 | validator_transformer: Callable, 22 | max_attempts: int = MAX_ENFORCE_FORMAT_ATTEMPTS, 23 | ): 24 | """ 25 | Enforce a format on the output of a chain. If the chain fails to produce 26 | the expected format, retry up to `max_attempts` times. An attempt is considered 27 | successful if the chain does not raise an exception and the output can be 28 | transformed using the `validator_transformer`. Return the transformed output. On 29 | failure, raise an `EnforceFormatError` with a message that includes the last output. 30 | """ 31 | for i in range(max_attempts): 32 | output = "OUTPUT_FAILED" 33 | try: 34 | output = chain.invoke(inputs) 35 | return validator_transformer(output) 36 | except Exception as e: 37 | err_msg = f"Failed to enforce format. Output:\n\n{output}\n\nError: {e}" 38 | print(err_msg) 39 | if i == max_attempts - 1: 40 | raise EnforceFormatError(err_msg) from e 41 | 42 | 43 | MAX_ENFORCE_FORMAT_ATTEMPTS = 4 44 | 45 | 46 | def enforce_json_format( 47 | chain: ChainType, 48 | inputs: dict, 49 | validator_transformer: Callable, # e.g. pydantic model's validator - model_validate() (!) 
50 | max_attempts: int = MAX_ENFORCE_FORMAT_ATTEMPTS, 51 | ) -> JSONish: 52 | return enforce_format( 53 | chain, inputs, lambda x: validator_transformer(extract_json(x)), max_attempts 54 | ) 55 | 56 | 57 | def enforce_pydantic_json( 58 | chain: ChainType, 59 | inputs: dict, 60 | pydantic_model: type[BaseModel], 61 | max_attempts: int = MAX_ENFORCE_FORMAT_ATTEMPTS, 62 | ) -> BaseModel: 63 | return enforce_format( 64 | chain, 65 | inputs, 66 | lambda x: pydantic_model.model_validate(extract_json(x)), 67 | max_attempts, 68 | ) 69 | -------------------------------------------------------------------------------- /agentblocks/docconveyer.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | # class DocData(BaseModel): 4 | # text: str | None = None 5 | # source: str | None = None 6 | # num_tokens: int | None = None 7 | # doc_uuid: str | None = None 8 | # error: str | None = None 9 | # is_ingested: bool = False 10 | from typing import TypeVar 11 | from langchain_core.documents import Document 12 | from pydantic import BaseModel, Field 13 | 14 | from utils.lang_utils import ROUGH_UPPER_LIMIT_AVG_CHARS_PER_TOKEN, get_num_tokens 15 | from utils.prepare import get_logger 16 | from utils.type_utils import Doc 17 | from langchain_text_splitters import RecursiveCharacterTextSplitter 18 | 19 | logger = get_logger() 20 | 21 | DocT = TypeVar("DocT", Document, Doc) 22 | 23 | 24 | def _split_doc_based_on_tokens( 25 | doc: Document | Doc, max_tokens: float, target_num_chars: int 26 | ) -> list[Document]: 27 | """ 28 | Helper for the main function below; it unconditionally splits the provided document. 29 | """ 30 | text_splitter = RecursiveCharacterTextSplitter( 31 | chunk_size=target_num_chars, 32 | chunk_overlap=0, 33 | add_start_index=True, # metadata will include start_index of snippet in original doc 34 | ) 35 | candidate_new_docs = text_splitter.create_documents( 36 | texts=[doc.page_content], metadatas=[doc.metadata] 37 | ) 38 | new_docs = [] 39 | 40 | # Calculate the number of tokens in each part and split further if needed 41 | for maybe_new_doc in candidate_new_docs: 42 | maybe_new_doc.metadata["num_tokens"] = get_num_tokens( 43 | maybe_new_doc.page_content 44 | ) 45 | 46 | # If the part is sufficiently small, add it to the list of new docs 47 | if maybe_new_doc.metadata["num_tokens"] <= max_tokens: 48 | new_docs.append(maybe_new_doc) 49 | continue 50 | 51 | # If actual num tokens is X times the target, reduce target_num_chars by ~X 52 | new_target_num_chars = min( 53 | int(target_num_chars * max_tokens / maybe_new_doc.metadata["num_tokens"]), 54 | target_num_chars // 2, 55 | ) 56 | 57 | # Split the part further and adjust the start index of each snippet 58 | new_parts = _split_doc_based_on_tokens( 59 | maybe_new_doc, max_tokens, new_target_num_chars 60 | ) 61 | for new_part in new_parts: 62 | new_part.metadata["start_index"] += maybe_new_doc.metadata["start_index"] 63 | 64 | # Add the new parts to the list of new docs 65 | new_docs.extend(new_parts) 66 | 67 | return new_docs 68 | 69 | 70 | def split_doc_based_on_tokens(doc: DocT, max_tokens: float) -> list[DocT]: 71 | """ 72 | Split a document into parts based on the number of tokens in each part. Specifically, 73 | if the number of tokens in a part (or the original doc) is within max_tokens, then the part is 74 | added to the list of new docs. Otherwise, the part is split further. 
The resulting Document 75 | objects are returned as a list and contain a copy of the metadata from the parent doc, plus 76 | the "start_index" of each part's occurrence in the parent doc. The metadata will also include the 77 | "num_tokens" of the part. 78 | """ 79 | # If the doc is too large then don't count tokens, just split it. 80 | num_chars = len(doc.page_content) 81 | if num_chars / ROUGH_UPPER_LIMIT_AVG_CHARS_PER_TOKEN > max_tokens: 82 | # Doc almost certainly has more tokens than max_tokens 83 | target_num_chars = int(max_tokens * ROUGH_UPPER_LIMIT_AVG_CHARS_PER_TOKEN / 2) 84 | else: 85 | # Count the tokens to see if we need to split 86 | if (num_tokens := get_num_tokens(doc.page_content)) <= max_tokens: 87 | doc.metadata["num_tokens"] = num_tokens 88 | return [doc] 89 | # Guess the target number of characters for the text splitter 90 | target_num_chars = int(num_chars / (num_tokens / max_tokens) / 2) 91 | 92 | documents = _split_doc_based_on_tokens(doc, max_tokens, target_num_chars) 93 | if isinstance(doc, Document): 94 | return documents 95 | else: 96 | return [Doc.from_lc_doc(d) for d in documents] 97 | 98 | 99 | def break_up_big_docs( 100 | docs: list[DocT], 101 | max_tokens: float, # = int(CONTEXT_LENGTH * 0.25), 102 | ) -> list[DocT]: 103 | """ 104 | Split each big document into parts, leaving the small ones as they are. Big vs small is 105 | determined by how the number of tokens in the document compares to the max_tokens parameter. 106 | """ 107 | logger.info(f"Breaking up {len(docs)} docs into parts with max_tokens={max_tokens}") 108 | new_docs = [] 109 | for doc in docs: 110 | doc_parts = split_doc_based_on_tokens(doc, max_tokens) 111 | logger.debug(f"Split doc {doc.metadata.get('source')} into {len(doc_parts)} parts") 112 | if len(doc_parts) > 1: 113 | # Create an id representing the original doc 114 | full_doc_ref = uuid.uuid4().hex[:8] 115 | for i, doc_part in enumerate(doc_parts): 116 | doc_part.metadata["full_doc_ref"] = full_doc_ref 117 | doc_part.metadata["part_id"] = str(i) 118 | 119 | new_docs.extend(doc_parts) 120 | return new_docs 121 | 122 | 123 | def limit_num_docs_by_tokens( 124 | docs: list[Document | Doc], max_tokens: float 125 | ) -> tuple[int, int]: 126 | """ 127 | Limit the number of documents in a list based on the total number of tokens in the 128 | documents. 129 | 130 | Return the number of documents that can be included in the 131 | list without exceeding the maximum number of tokens, and the total number of tokens 132 | in the included documents. 133 | 134 | Additionally, the function updates the metadata of each document with the number of 135 | tokens in the document, if it is not already present. 
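Illustrative example (hypothetical token counts): if three docs end up with "num_tokens" of 40, 50 and 30, then limit_num_docs_by_tokens(docs, max_tokens=100) returns (2, 90), since including the third doc would exceed the budget.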
136 | """ 137 | tot_tokens = 0 138 | for i, doc in enumerate(docs): 139 | if (num_tokens := doc.metadata.get("num_tokens")) is None: 140 | doc.metadata["num_tokens"] = num_tokens = get_num_tokens(doc.page_content) 141 | if tot_tokens + num_tokens > max_tokens: 142 | return i, tot_tokens 143 | tot_tokens += num_tokens 144 | 145 | return len(docs), tot_tokens 146 | 147 | 148 | class DocConveyer(BaseModel): 149 | docs: list[Doc] = Field(default_factory=list) 150 | max_tokens_for_breaking_up_docs: int | float | None = None 151 | idx_first_not_done: int = 0 # done = "pushed out" by get_next_docs 152 | 153 | def __init__(self, **data): 154 | super().__init__(**data) 155 | if self.max_tokens_for_breaking_up_docs is not None: 156 | self.docs = break_up_big_docs( 157 | self.docs, self.max_tokens_for_breaking_up_docs 158 | ) 159 | 160 | @property 161 | def num_available_docs(self): 162 | return len(self.docs) - self.idx_first_not_done 163 | 164 | def add_docs(self, docs: list[Doc]): 165 | if self.max_tokens_for_breaking_up_docs is not None: 166 | docs = break_up_big_docs(docs, self.max_tokens_for_breaking_up_docs) 167 | self.docs.extend(docs) 168 | 169 | def get_next_docs( 170 | self, 171 | max_tokens: float, 172 | max_docs: int | None = None, 173 | max_full_docs: int | None = None, 174 | ) -> list[Doc]: 175 | logger.debug(f"{self.num_available_docs} docs available") 176 | num_docs, _ = limit_num_docs_by_tokens( 177 | self.docs[self.idx_first_not_done :], max_tokens 178 | ) 179 | if max_docs is not None: 180 | num_docs = min(num_docs, max_docs) 181 | 182 | if max_full_docs is not None: 183 | num_full_docs = 0 184 | old_full_doc_ref = None 185 | new_num_docs = 0 186 | for doc in self.docs[ 187 | self.idx_first_not_done : self.idx_first_not_done + num_docs 188 | ]: 189 | if num_full_docs >= max_full_docs: 190 | break 191 | new_num_docs += 1 192 | new_full_doc_ref = doc.metadata.get("full_doc_ref") 193 | if new_full_doc_ref is None or new_full_doc_ref != old_full_doc_ref: 194 | num_full_docs += 1 195 | old_full_doc_ref = new_full_doc_ref 196 | num_docs = new_num_docs 197 | 198 | self.idx_first_not_done += num_docs 199 | logger.debug(f"Returning {num_docs} docs") 200 | return self.docs[self.idx_first_not_done - num_docs : self.idx_first_not_done] 201 | 202 | def clear_done_docs(self): 203 | self.docs = self.docs[self.idx_first_not_done :] 204 | self.idx_first_not_done = 0 205 | -------------------------------------------------------------------------------- /agentblocks/webprocess.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | from agentblocks.webretrieve import get_content_from_urls 6 | from utils.type_utils import Doc 7 | from utils.web import LinkData 8 | 9 | DEFAULT_MIN_OK_URLS = 5 10 | DEFAULT_INIT_BATCH_SIZE = 0 # 0 = "auto-determined" 11 | 12 | 13 | class URLConveyer(BaseModel): 14 | urls: list[str] 15 | link_data_dict: dict[str, LinkData] = Field(default_factory=dict) 16 | # num_ok_urls: int = 0 17 | idx_first_not_done: int = 0 # done = "pushed out" by get_next_docs 18 | idx_first_not_tried: int = 0 # not necessarily equal to len(link_data_dict) 19 | idx_last_url_refresh: int = 0 20 | 21 | num_url_retrievals: int = 0 22 | 23 | default_min_ok_urls: int = DEFAULT_MIN_OK_URLS 24 | default_init_batch_size: int = DEFAULT_INIT_BATCH_SIZE # 0 = "auto-determined" 25 | 26 | @property 27 | def num_tried_urls_since_refresh(self) -> int: 28 | return self.idx_first_not_tried - 
self.idx_last_url_refresh 29 | 30 | @property 31 | def num_untried_urls(self) -> int: 32 | return len(self.urls) - self.idx_first_not_tried 33 | 34 | def refresh_urls( 35 | self, 36 | urls: list[str], 37 | only_add_truly_new: bool = True, 38 | idx_to_cut_at: int | None = None, 39 | ): 40 | if idx_to_cut_at is None: 41 | idx_to_cut_at = self.idx_first_not_tried 42 | elif idx_to_cut_at < 0 or idx_to_cut_at > len(self.urls): 43 | raise ValueError(f"idx_to_cut_at must be between 0 and {len(self.urls)}") 44 | 45 | del self.urls[idx_to_cut_at:] 46 | 47 | if only_add_truly_new: 48 | old_urls = set(self.urls) 49 | urls = [url for url in urls if url not in old_urls] 50 | 51 | self.urls.extend(urls) 52 | self.idx_last_url_refresh = idx_to_cut_at 53 | 54 | def retrieve_content_from_urls( 55 | self, 56 | min_ok_urls: int | None = None, # use default_min_ok_urls if None 57 | init_batch_size: int | None = None, # use default_init_batch_size if None 58 | batch_fetcher: Callable[[list[str]], list[str]] | None = None, 59 | ): 60 | url_retrieval_data = get_content_from_urls( 61 | urls=self.urls[self.idx_first_not_tried :], 62 | min_ok_urls=min_ok_urls 63 | if min_ok_urls is not None 64 | else self.default_min_ok_urls, 65 | init_batch_size=init_batch_size 66 | if init_batch_size is not None 67 | else self.default_init_batch_size, 68 | batch_fetcher=batch_fetcher, 69 | ) 70 | 71 | self.num_url_retrievals += 1 72 | self.link_data_dict.update(url_retrieval_data.link_data_dict) 73 | self.idx_first_not_tried += url_retrieval_data.idx_first_not_tried 74 | 75 | def get_next_docs(self) -> list[Doc]: 76 | docs = [] 77 | for url in self.urls[self.idx_first_not_done : self.idx_first_not_tried]: 78 | link_data = self.link_data_dict[url] 79 | if link_data.error: 80 | continue 81 | doc = Doc(page_content=link_data.text, metadata={"source": url}) 82 | if link_data.num_tokens is not None: 83 | doc.metadata["num_tokens"] = link_data.num_tokens 84 | docs.append(doc) 85 | 86 | self.idx_first_not_done = self.idx_first_not_tried 87 | return docs 88 | 89 | def get_next_docs_with_url_retrieval( 90 | self, 91 | min_ok_urls: int | None = None, # use default_min_ok_urls if None 92 | init_batch_size: int | None = None, # use default_init_batch_size if None 93 | batch_fetcher: Callable[[list[str]], list[str]] | None = None, 94 | ) -> list[Doc]: 95 | if docs := self.get_next_docs(): 96 | return docs 97 | 98 | # If no more docs, retrieve more content from URLs 99 | self.retrieve_content_from_urls( 100 | min_ok_urls=min_ok_urls, 101 | init_batch_size=init_batch_size, 102 | batch_fetcher=batch_fetcher, 103 | ) 104 | return self.get_next_docs() 105 | -------------------------------------------------------------------------------- /agentblocks/webretrieve.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | from utils.prepare import get_logger 6 | from utils.type_utils import DDGError 7 | from utils.web import LinkData, get_batch_url_fetcher 8 | 9 | logger = get_logger() 10 | 11 | 12 | class URLRetrievalData(BaseModel): 13 | urls: list[str] 14 | link_data_dict: dict[str, LinkData] = Field(default_factory=dict) 15 | num_ok_urls: int = 0 16 | idx_first_not_tried: int = 0 # different from len(link_data_dict) if urls repeat 17 | 18 | 19 | MAX_INIT_BATCH_SIZE = 10 20 | 21 | 22 | def get_content_from_urls( 23 | urls: list[str], 24 | min_ok_urls: int, 25 | init_batch_size: int = 0, # auto-determined if 0 26 | batch_fetcher: 
Callable[[list[str]], list[str]] | None = None, 27 | ) -> URLRetrievalData: 28 | """ 29 | Fetch content from a list of urls using a batch fetcher. If at least 30 | min_ok_urls urls are fetched successfully, return the fetched content. 31 | Otherwise, fetch a new batch of urls, and repeat until at least min_ok_urls 32 | urls are fetched successfully. 33 | 34 | If there are duplicate URLs, only the first occurrence is fetched; subsequent duplicates are skipped. 35 | 36 | Args: 37 | - urls: list of urls to fetch content from 38 | - min_ok_urls: minimum number of urls that need to be fetched successfully 39 | - init_batch_size: initial batch size to fetch content from 40 | - batch_fetcher: function to fetch content from a batch of urls 41 | 42 | Returns: 43 | - URLRetrievalData: object containing the fetched content 44 | """ 45 | try: 46 | batch_fetcher = batch_fetcher or get_batch_url_fetcher() 47 | init_batch_size = init_batch_size or min( 48 | MAX_INIT_BATCH_SIZE, round(min_ok_urls * 1.2) 49 | ) # NOTE: could optimize 50 | 51 | logger.info( 52 | f"Fetching content from {len(urls)} urls:\n" 53 | f" - {min_ok_urls} successfully obtained URLs needed\n" 54 | f" - {init_batch_size} is the initial batch size\n" 55 | ) 56 | 57 | res = URLRetrievalData(urls=urls) 58 | num_urls = len(urls) 59 | url_set = set() # to keep track of unique urls 60 | 61 | # If, say, only 3 ok urls are still needed, we might want to try fetching 3 + extra 62 | num_extras = max(2, init_batch_size - min_ok_urls) 63 | 64 | while res.num_ok_urls < min_ok_urls: 65 | batch_size = min( 66 | init_batch_size, 67 | min_ok_urls - res.num_ok_urls + num_extras, 68 | ) 69 | batch_urls = [] 70 | for url in urls[res.idx_first_not_tried :]: 71 | if len(batch_urls) == batch_size: 72 | break 73 | if url in url_set: 74 | continue 75 | batch_urls.append(url) 76 | url_set.add(url) 77 | res.idx_first_not_tried += 1 78 | 79 | if (batch_size := len(batch_urls)) == 0: 80 | break # no more urls to fetch 81 | 82 | logger.info(f"Fetching {batch_size} urls:\n- " + "\n- ".join(batch_urls)) 83 | 84 | # Fetch content from urls in batch 85 | batch_htmls = batch_fetcher(batch_urls) 86 | 87 | # Process fetched content 88 | for url, html in zip(batch_urls, batch_htmls): 89 | link_data = LinkData.from_raw_content(html) 90 | res.link_data_dict[url] = link_data 91 | if not link_data.error: 92 | res.num_ok_urls += 1 93 | 94 | logger.info( 95 | f"Total URLs processed: {res.idx_first_not_tried} ({num_urls} total)\n" 96 | f"Total successful URLs: {res.num_ok_urls} ({min_ok_urls} needed)\n" 97 | ) 98 | 99 | return res 100 | except Exception as e: 101 | raise DDGError( 102 | user_facing_message="Apologies, I ran into a problem trying to fetch URL content." 103 | ) from e 104 | -------------------------------------------------------------------------------- /agentblocks/websearch.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from pydantic import BaseModel 3 | 4 | from agentblocks.core import enforce_pydantic_json 5 | from utils.algo import interleave_iterables, remove_duplicates_keep_order 6 | from utils.async_utils import gather_tasks_sync 7 | from utils.chat_state import ChatState 8 | from utils.prepare import get_logger 9 | from utils.type_utils import DDGError 10 | from langchain_community.utilities import GoogleSerperAPIWrapper 11 | 12 | logger = get_logger() 13 | 14 | WEB_SEARCH_API_ISSUE_MSG = ( 15 | "Apologies, it seems I'm having an issue with the API I use to search the web." 
16 | ) 17 | 18 | domain_blacklist = ["youtube.com"] 19 | 20 | 21 | class WebSearchAPIError(DDGError): 22 | # NOTE: Can be raised e.g. like this: raise WebSearchAPIError() from e 23 | # In that case, the original exception will be stored in self.__cause__ 24 | default_user_facing_message = WEB_SEARCH_API_ISSUE_MSG 25 | 26 | 27 | def _extract_domain(url: str): 28 | try: 29 | full_domain = url.split("://")[-1].split("/")[0] # blah.blah.domain.com 30 | return ".".join(full_domain.split(".")[-2:]) # domain.com 31 | except Exception: 32 | return "" 33 | 34 | 35 | def get_links_from_search_results(search_results: list[dict[str, Any]]): 36 | logger.debug( 37 | f"Getting links from results of {len(search_results)} searches " 38 | f"with {[len(s) for s in search_results]} links each." 39 | ) 40 | links_for_each_query = [ 41 | [x["link"] for x in search_result.get("organic", []) if "link" in x] 42 | for search_result in search_results 43 | ] # [[links for query 1], [links for query 2], ...] 44 | 45 | logger.debug( 46 | f"Number of links for each query: {[len(links) for links in links_for_each_query]}" 47 | ) 48 | # NOTE: can ask LLM to decide which links to keep 49 | return [ 50 | link 51 | for link in remove_duplicates_keep_order( 52 | interleave_iterables(links_for_each_query) 53 | ) 54 | if _extract_domain(link) not in domain_blacklist 55 | ] 56 | 57 | 58 | def get_links_from_queries( 59 | queries: list[str], num_search_results: int = 10 60 | ) -> list[str]: 61 | """ 62 | Get links from a list of queries by doing a Google search for each query. 63 | """ 64 | try: 65 | # Do a Google search for each query 66 | logger.info(f"Performing web search for queries: {queries}") 67 | search = GoogleSerperAPIWrapper(k=num_search_results) 68 | search_tasks = [search.aresults(query) for query in queries] 69 | search_results = gather_tasks_sync(search_tasks) # can use serper's batching 70 | 71 | # Check for errors 72 | for search_result in search_results: 73 | try: 74 | if search_result["statusCode"] // 100 != 2: 75 | logger.error(f"Error in search result: {search_result}") 76 | # search_result can be {"statusCode": 400, "message": "Not enough credits"} 77 | raise WebSearchAPIError() # TODO: add message, make sure it gets logged 78 | except KeyError: 79 | logger.warning("No status code in search result, assuming success.") 80 | 81 | # Get links from search results 82 | return get_links_from_search_results(search_results) 83 | except WebSearchAPIError as e: 84 | raise e 85 | except Exception as e: 86 | raise WebSearchAPIError() from e 87 | 88 | 89 | class Queries(BaseModel): 90 | queries: list[str] 91 | 92 | 93 | def get_web_search_queries_from_prompt( 94 | prompt, inputs: dict[str, str], chat_state: ChatState 95 | ) -> list[str]: 96 | """ 97 | Get web search queries from a prompt. 
The prompt should ask the LLM to generate 98 | web search queries in the format {"queries": ["query1", "query2", ...]} 99 | """ 100 | logger.info("Submitting prompt to get web search queries") 101 | query_generator_chain = chat_state.get_prompt_llm_chain(prompt, to_user=False) 102 | 103 | return enforce_pydantic_json(query_generator_chain, inputs, Queries).queries 104 | -------------------------------------------------------------------------------- /agents/exporter.py: -------------------------------------------------------------------------------- 1 | from utils.chat_state import ChatState 2 | from utils.helpers import ( 3 | DELIMITER80_NONL, 4 | EXPORT_COMMAND_HELP_MSG, 5 | format_nonstreaming_answer, 6 | ) 7 | from utils.query_parsing import ExportCommand, get_command, get_int, get_int_or_command 8 | from utils.type_utils import INSTRUCT_EXPORT_CHAT_HISTORY, Instruction, Props 9 | 10 | REVERSE_COMMAND = "reverse" 11 | 12 | 13 | def get_exporter_response(chat_state: ChatState) -> Props: 14 | message = chat_state.message 15 | subcommand = chat_state.parsed_query.export_command 16 | 17 | if subcommand == ExportCommand.NONE: 18 | return format_nonstreaming_answer(EXPORT_COMMAND_HELP_MSG) 19 | 20 | if subcommand == ExportCommand.CHAT: 21 | max_messages = 1000000000 22 | is_reverse = False 23 | cmd, rest = get_int_or_command(message, [REVERSE_COMMAND]) 24 | if isinstance(cmd, int): 25 | max_messages = cmd 26 | cmd, rest = get_command(rest, [REVERSE_COMMAND]) 27 | if cmd is not None: 28 | is_reverse = True 29 | elif rest: # invalid non-empty command 30 | return format_nonstreaming_answer(EXPORT_COMMAND_HELP_MSG) 31 | elif cmd == REVERSE_COMMAND: 32 | is_reverse = True 33 | cmd, rest = get_int(rest) 34 | if cmd is not None: 35 | max_messages = cmd 36 | elif rest: # invalid non-empty command after reverse 37 | return format_nonstreaming_answer(EXPORT_COMMAND_HELP_MSG) 38 | elif rest: # cmd is None but message is not empty (invalid command) 39 | return format_nonstreaming_answer(EXPORT_COMMAND_HELP_MSG) 40 | 41 | if max_messages <= 0: 42 | return format_nonstreaming_answer(EXPORT_COMMAND_HELP_MSG) 43 | 44 | # Get and format the last num_messages messages 45 | msgs = [] 46 | for (user_msg, ai_msg), sources in zip( 47 | reversed(chat_state.chat_history), reversed(chat_state.sources_history) 48 | ): 49 | if len(msgs) >= max_messages: 50 | break 51 | if ai_msg is not None or sources: 52 | ai_msg_full = f"**DDG:** {ai_msg or ''}" 53 | if sources: 54 | tmp = "\n- ".join(sources) 55 | ai_msg_full += f"\n\n**SOURCES:**\n\n- {tmp}" 56 | msgs.append(ai_msg_full) 57 | if user_msg is not None: 58 | msgs.append(f"**USER:** {user_msg}") 59 | 60 | # We may have overshot by one message (didn't want to check twice in the loop) 61 | if len(msgs) > max_messages: 62 | msgs.pop() 63 | 64 | # Reverse the messages if needed, format them and return along with instructions 65 | data = f"\n\n{DELIMITER80_NONL}\n\n".join(msgs[:: 1 if is_reverse else -1]) 66 | return format_nonstreaming_answer( 67 | "I have collected our chat history and sent it to the UI for export." 
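# NOTE: format_nonstreaming_answer returns a Props dict; the "|" merge below attaches the export instruction for the UI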
68 | ) | { 69 | "instructions": [Instruction(type=INSTRUCT_EXPORT_CHAT_HISTORY, data=data)], 70 | } 71 | raise ValueError(f"Invalid export subcommand: {subcommand}") 72 | -------------------------------------------------------------------------------- /agents/ingester_summarizer.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from icecream import ic 4 | 5 | from agentblocks.collectionhelper import ingest_into_collection 6 | from agents.dbmanager import ( 7 | get_access_role, 8 | get_full_collection_name, 9 | get_user_facing_collection_name, 10 | ) 11 | from components.llm import get_prompt_llm_chain 12 | from utils.chat_state import ChatState 13 | from utils.helpers import ( 14 | ADDITIVE_COLLECTION_PREFIX, 15 | INGESTED_DOCS_INIT_PREFIX, 16 | format_invalid_input_answer, 17 | format_nonstreaming_answer, 18 | ) 19 | from utils.lang_utils import limit_tokens_in_text, limit_tokens_in_texts 20 | from utils.prepare import CONTEXT_LENGTH 21 | from utils.prompts import SUMMARIZER_PROMPT 22 | from utils.query_parsing import IngestCommand 23 | from utils.type_utils import INSTRUCT_SHOW_UPLOADER, AccessRole, ChatMode, Instruction 24 | from utils.web import LinkData, get_batch_url_fetcher 25 | from langchain_core.documents import Document 26 | 27 | DEFAULT_MAX_TOKENS_FINAL_CONTEXT = int(CONTEXT_LENGTH * 0.7) 28 | 29 | NO_EDITOR_ACCESS_STATUS = "No editor access to collection" # a bit of duplication 30 | NO_MULTIPLE_INGESTION_SOURCES_STATUS = ( 31 | "Cannot ingest uploaded files and an external resource at the same time" 32 | ) 33 | DOC_SEPARATOR = "\n" + "-" * 40 + "\n\n" 34 | 35 | 36 | def summarize(docs: list[Document], chat_state: ChatState) -> str: 37 | if (num_docs := len(docs)) == 0: 38 | return "" 39 | 40 | summarizer_chain = get_prompt_llm_chain( 41 | SUMMARIZER_PROMPT, 42 | llm_settings=chat_state.bot_settings, 43 | api_key=chat_state.openai_api_key, 44 | callbacks=chat_state.callbacks, 45 | ) 46 | 47 | if num_docs == 1: 48 | shortened_text, num_tokens = limit_tokens_in_text( 49 | docs[0].page_content, DEFAULT_MAX_TOKENS_FINAL_CONTEXT 50 | ) # TODO: ability to summarize longer content 51 | shortened_texts = [shortened_text] 52 | else: 53 | # Multiple documents 54 | shortened_texts, nums_of_tokens = limit_tokens_in_texts( 55 | [doc.page_content for doc in docs], DEFAULT_MAX_TOKENS_FINAL_CONTEXT 56 | ) 57 | 58 | # Construct the final context 59 | final_texts = [] 60 | for i, (doc, shortened_text) in enumerate(zip(docs, shortened_texts)): 61 | if doc.page_content != shortened_text: 62 | suffix = "\n\nNOTE: The above content was truncated to fit the maximum token limit." 
63 | else: 64 | suffix = "" 65 | 66 | final_texts.append( 67 | f"SOURCE: {doc.metadata.get('source', 'Unknown')}" 68 | f"\n\n{shortened_text}{suffix}" 69 | ) 70 | 71 | final_context = DOC_SEPARATOR.join(final_texts) 72 | ic(final_context) 73 | 74 | return summarizer_chain.invoke({"content": final_context}) 75 | 76 | 77 | def get_ingester_summarizer_response(chat_state: ChatState): 78 | # TODO: remove similar functionality in streamlit/ingest.py 79 | message = chat_state.parsed_query.message 80 | ingest_command = chat_state.parsed_query.ingest_command 81 | 82 | # If there's no message and no docs, just return request for files 83 | if not (docs := chat_state.uploaded_docs) and not message: 84 | return { 85 | "answer": "Please select your documents to upload and ingest.", 86 | "instructions": [Instruction(type=INSTRUCT_SHOW_UPLOADER)], 87 | } 88 | 89 | # Determine if we need to use the same collection or create a new one 90 | coll_name_as_shown = get_user_facing_collection_name( 91 | chat_state.user_id, chat_state.vectorstore.name 92 | ) 93 | if ingest_command == IngestCommand.ADD or ( 94 | ingest_command == IngestCommand.DEFAULT 95 | and coll_name_as_shown.startswith(ADDITIVE_COLLECTION_PREFIX) 96 | ): 97 | # We will use the same collection 98 | is_new_collection = False 99 | coll_name_full = chat_state.vectorstore.name 100 | 101 | # Check for editor access 102 | if get_access_role(chat_state).value < AccessRole.EDITOR.value: 103 | cmd_str = ( 104 | "summarize" 105 | if chat_state.chat_mode == ChatMode.SUMMARIZE_COMMAND_ID 106 | else "ingest" 107 | ) 108 | return format_invalid_input_answer( 109 | "Apologies, you can't ingest content into the current collection " 110 | "because you don't have editor access to it. You can ingest content " 111 | "into a new collection instead. For example:\n\n" 112 | f"```\n\n/{cmd_str} new {message}\n\n```", 113 | NO_EDITOR_ACCESS_STATUS, 114 | ) 115 | else: 116 | # We will need to create a new collection 117 | is_new_collection = True 118 | coll_name_as_shown = INGESTED_DOCS_INIT_PREFIX + uuid.uuid4().hex[:8] 119 | coll_name_full = get_full_collection_name( 120 | chat_state.user_id, coll_name_as_shown 121 | ) 122 | 123 | # If there are uploaded docs, ingest or summarize them 124 | if docs: 125 | if message: 126 | return format_invalid_input_answer( 127 | "Apologies, you can't simultaneously ingest uploaded files and " 128 | f"an external resource ({message}).", 129 | NO_MULTIPLE_INGESTION_SOURCES_STATUS, 130 | ) 131 | if chat_state.chat_mode == ChatMode.INGEST_COMMAND_ID: 132 | res = format_nonstreaming_answer( 133 | "The files you uploaded have been ingested into the collection " 134 | f"`{coll_name_as_shown}`. If you don't need to ingest " 135 | "more content into it, rename it with `/db rename my-cool-collection-name`." 136 | ) 137 | else: 138 | res = {"answer": summarize(docs, chat_state)} 139 | else: 140 | # If no uploaded docs, ingest or summarize the external resource 141 | fetch_func = get_batch_url_fetcher() # don't really need the batch aspect here 142 | html = fetch_func([message])[0] 143 | link_data = LinkData.from_raw_content(html) 144 | 145 | if link_data.error: 146 | return format_nonstreaming_answer( 147 | f"Apologies, I could not retrieve the resource `{message}`." 
148 | ) 149 | 150 | docs = [Document(page_content=link_data.text, metadata={"source": message})] 151 | 152 | if chat_state.chat_mode == ChatMode.INGEST_COMMAND_ID: 153 | # "/ingest https://some.url.com" command - just ingest, don't summarize 154 | res = format_nonstreaming_answer( 155 | f"The resource `{message}` has been ingested into the collection " 156 | f"`{coll_name_as_shown}`. Feel free to ask questions about the collection's " 157 | "content. If you don't need to ingest " 158 | "more resources into it, rename it with `/db rename my-cool-collection-name`." 159 | ) 160 | else: 161 | res = {"answer": summarize(docs, chat_state)} 162 | 163 | # Ingest into the collection 164 | coll_metadata = {} if is_new_collection else chat_state.fetch_collection_metadata() 165 | vectorstore = ingest_into_collection( 166 | collection_name=coll_name_full, 167 | docs=docs, 168 | collection_metadata=coll_metadata, 169 | chat_state=chat_state, 170 | is_new_collection=is_new_collection, 171 | ) 172 | 173 | if is_new_collection: 174 | # Switch to the newly created collection 175 | res["vectorstore"] = vectorstore 176 | 177 | return res 178 | -------------------------------------------------------------------------------- /agents/researcher_data.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field, model_validator 2 | 3 | from utils.web import LinkData 4 | 5 | 6 | class Report(BaseModel): 7 | report_text: str 8 | sources: list[str] = Field(default_factory=list) 9 | parent_report_ids: list[str] = Field(default_factory=list) 10 | child_report_id: str | None = None 11 | evaluation: str | None = None 12 | 13 | 14 | class ResearchReportData(BaseModel): 15 | query: str 16 | search_queries: list[str] 17 | report_type: str 18 | unprocessed_links: list[str] 19 | processed_links: list[str] 20 | link_data_dict: dict[str, LinkData] 21 | max_tokens_final_context: int 22 | main_report: str = "" 23 | base_reports: list[Report] = Field(default_factory=list) 24 | combined_reports: list[Report] = Field(default_factory=list) 25 | combined_report_id_levels: list[list[str]] = Field(default_factory=list) # levels 26 | num_obtained_unprocessed_links: int = 0 27 | num_obtained_unprocessed_ok_links: int = 0 28 | num_links_from_latest_queries: int | None = None 29 | evaluation: str | None = None 30 | 31 | @model_validator(mode="after") 32 | def validate(self): 33 | if self.num_links_from_latest_queries is None: 34 | self.num_links_from_latest_queries = len(self.unprocessed_links) 35 | return self 36 | 37 | @property 38 | def num_processed_links_from_latest_queries(self) -> int: 39 | # All unprocessed links are from the latest queries so we just subtract 40 | return self.num_links_from_latest_queries - len(self.unprocessed_links) 41 | 42 | def is_base_report(self, id: str) -> bool: 43 | return not id.startswith("c") 44 | 45 | def is_report_childless(self, id: str) -> bool: 46 | return self.get_report_by_id(id).child_report_id is None 47 | 48 | def get_report_by_id(self, id) -> Report: 49 | return ( 50 | self.base_reports[int(id)] 51 | if self.is_base_report(id) 52 | else self.combined_reports[int(id[1:])] 53 | ) 54 | 55 | def get_parent_reports(self, report: Report) -> list[Report]: 56 | return [ 57 | self.get_report_by_id(parent_id) for parent_id in report.parent_report_ids 58 | ] 59 | 60 | def get_ancestor_ids(self, report: Report) -> list[str]: 61 | res = [] 62 | for parent_id in report.parent_report_ids: 63 | if self.is_base_report(parent_id): 64 | 
res.append(parent_id) 65 | else: 66 | res.extend(self.get_ancestor_ids(self.get_report_by_id(parent_id))) 67 | 68 | def get_sources(self, report: Report) -> list[str]: 69 | res = report.sources.copy() 70 | for parent_report in self.get_parent_reports(report): 71 | res.extend(self.get_sources(parent_report)) 72 | return res 73 | -------------------------------------------------------------------------------- /agents/share_manager.py: -------------------------------------------------------------------------------- 1 | from agents.dbmanager import get_access_role, is_main_owner 2 | from utils.chat_state import ChatState 3 | from utils.helpers import SHARE_COMMAND_HELP_MSG, format_nonstreaming_answer 4 | from utils.prepare import DOMAIN_NAME_FOR_SHARING 5 | from utils.query_parsing import ShareCommand, ShareRevokeSubCommand 6 | from utils.type_utils import ( 7 | AccessCodeSettings, 8 | AccessCodeType, 9 | AccessRole, 10 | Props, 11 | ) 12 | 13 | share_type_to_access_role = { 14 | ShareCommand.EDITOR: AccessRole.EDITOR, 15 | ShareCommand.VIEWER: AccessRole.VIEWER, 16 | ShareCommand.OWNER: AccessRole.OWNER, 17 | } 18 | 19 | 20 | def handle_share_command(chat_state: ChatState) -> Props: 21 | """Handle the /share command.""" 22 | share_params = chat_state.parsed_query.share_params 23 | if share_params.share_type == ShareCommand.NONE: 24 | return format_nonstreaming_answer(SHARE_COMMAND_HELP_MSG) 25 | 26 | # Check that the user is an owner 27 | access_role = get_access_role(chat_state) 28 | if access_role.value < AccessRole.OWNER.value: 29 | return format_nonstreaming_answer( 30 | "Apologies, you don't have owner-level access to this collection. " 31 | f"Your current access level: {access_role.name.lower()}." 32 | ) 33 | # NOTE: this introduces redundant fetching of metadata (in get_access_role 34 | # and save_access_code_settings. Can optimize later. 35 | 36 | if share_params.share_type in ( 37 | ShareCommand.EDITOR, 38 | ShareCommand.VIEWER, 39 | ShareCommand.OWNER, 40 | ): 41 | # Check that the access code is not empty 42 | if (access_code := share_params.access_code) is None: 43 | return format_nonstreaming_answer( 44 | "Apologies, you need to specify an access code." 45 | f"\n\n{SHARE_COMMAND_HELP_MSG}" 46 | ) 47 | 48 | # Check that the access code type is not empty 49 | if (code_type := share_params.access_code_type) is None: 50 | return format_nonstreaming_answer( 51 | "Apologies, you need to specify an access code type." 52 | f"\n\n{SHARE_COMMAND_HELP_MSG}" 53 | ) 54 | if code_type != AccessCodeType.NEED_ALWAYS: 55 | return format_nonstreaming_answer( 56 | "Apologies, Dmitriy hasn't implemented this code type for me yet." 57 | ) 58 | 59 | # Check that the access code contains only letters and numbers 60 | if not access_code.isalnum(): 61 | return format_nonstreaming_answer( 62 | "Apologies, the access code can only contain letters and numbers." 63 | ) 64 | 65 | # Save the access code and its settings 66 | access_code_settings = AccessCodeSettings( 67 | code_type=code_type, 68 | access_role=share_type_to_access_role[share_params.share_type], 69 | ) 70 | chat_state.save_access_code_settings( 71 | access_code=share_params.access_code, 72 | access_code_settings=access_code_settings, 73 | ) 74 | 75 | # Form share link 76 | link = ( 77 | f"{DOMAIN_NAME_FOR_SHARING}?collection={chat_state.vectorstore.name}" 78 | f"&access_code={share_params.access_code}" 79 | ) 80 | 81 | # Return the response with the share link 82 | return format_nonstreaming_answer( 83 | "The current collection has been shared. 
Here's the link you can send " 84 | "to the people you want to grant access to:\n\n" 85 | f"```\n{link}\n```" 86 | ) 87 | 88 | if share_params.share_type == ShareCommand.REVOKE: 89 | collection_permissions = chat_state.get_collection_permissions() 90 | match share_params.revoke_type: 91 | case ShareRevokeSubCommand.ALL_CODES: 92 | collection_permissions.access_code_to_settings = {} 93 | ans = "All access codes have been revoked." 94 | 95 | case ShareRevokeSubCommand.ALL_USERS: 96 | # If user is not the main owner, don't delete their owner status 97 | saved_entries = {} 98 | if is_main_owner(chat_state): 99 | ans = "All users except you have been revoked." 100 | else: 101 | try: 102 | saved_entries[chat_state.user_id] = ( 103 | collection_permissions.user_id_to_settings[ 104 | chat_state.user_id 105 | ] 106 | ) 107 | ans = ( 108 | "All users except the main owner and you have been revoked." 109 | ) 110 | except KeyError: 111 | ans = "All users except the main owner have been revoked (including you)." 112 | 113 | # Clear all user settings except saved_entries 114 | collection_permissions.user_id_to_settings = saved_entries 115 | 116 | case ShareRevokeSubCommand.CODE: 117 | code = (share_params.code_or_user_to_revoke or "").strip() 118 | if not code: 119 | return format_nonstreaming_answer( 120 | "Apologies, you need to specify an access code to revoke." 121 | ) 122 | try: 123 | collection_permissions.access_code_to_settings.pop(code) 124 | ans = "The access code has been revoked." 125 | except KeyError: 126 | return format_nonstreaming_answer( 127 | "Apologies, the access code you specified doesn't exist." 128 | ) 129 | 130 | case ShareRevokeSubCommand.USER: 131 | user_id = (share_params.code_or_user_to_revoke or "").strip() 132 | if not user_id: 133 | return format_nonstreaming_answer( 134 | "Apologies, you need to specify a user to revoke." 135 | ) 136 | if user_id == "public": 137 | user_id = "" 138 | try: 139 | print(collection_permissions.user_id_to_settings) 140 | collection_permissions.user_id_to_settings.pop(user_id) 141 | print(collection_permissions.user_id_to_settings) 142 | ans = "The user has been revoked." 143 | except KeyError: 144 | return format_nonstreaming_answer( 145 | "Apologies, the user you specified doesn't have a stored access role." 146 | ) 147 | case _: 148 | return format_nonstreaming_answer( 149 | "Apologies, Dmitriy hasn't implemented this share subcommand for me yet." 150 | ) 151 | 152 | chat_state.save_collection_permissions( 153 | collection_permissions, use_cached_metadata=True 154 | ) 155 | return format_nonstreaming_answer(ans) 156 | 157 | return format_nonstreaming_answer( 158 | "Apologies, Dmitriy hasn't implemented this share subcommand for me yet." 
159 | ) 160 | -------------------------------------------------------------------------------- /agents/websearcher_quick.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from agentblocks.websearch import get_links_from_search_results 4 | from components.llm import get_prompt_llm_chain 5 | from utils.async_utils import gather_tasks_sync, make_sync 6 | from utils.chat_state import ChatState 7 | from utils.lang_utils import get_num_tokens, limit_tokens_in_texts 8 | from utils.prepare import CONTEXT_LENGTH 9 | from utils.prompts import RESEARCHER_PROMPT_SIMPLE 10 | from utils.web import ( 11 | afetch_urls_in_parallel_playwright, 12 | get_text_from_html, 13 | remove_failed_fetches, 14 | ) 15 | from langchain_community.utilities import GoogleSerperAPIWrapper 16 | 17 | 18 | def get_related_websearch_queries(message: str): 19 | search = GoogleSerperAPIWrapper() 20 | search_results = search.results(message) 21 | # print("search results:", json.dumps(search_results, indent=4)) 22 | related_searches = [x["query"] for x in search_results.get("relatedSearches", [])] 23 | people_also_ask = [x["question"] for x in search_results.get("peopleAlsoAsk", [])] 24 | 25 | return related_searches, people_also_ask 26 | 27 | 28 | def get_websearcher_response_quick( 29 | chat_state: ChatState, max_queries: int = 3, max_total_links: int = 9 30 | ): 31 | """ 32 | Perform quick web research on a query, with only one call to the LLM. 33 | 34 | It gets related queries to search for by examining the "related searches" and 35 | "people also ask" sections of the Google search results for the query. It then 36 | searches for each of these queries on Google, and gets the top links for each 37 | query. It then fetches the content from each link, shortens them to fit all 38 | texts into one context window, and then sends it to the LLM to generate a report. 
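
    A minimal invocation sketch (illustrative, not from the codebase; assumes
    `chat_state` is a fully configured ChatState whose `message` holds the
    query to research):

        response = get_websearcher_response_quick(
            chat_state, max_queries=2, max_total_links=6
        )
        print(response["answer"])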
39 | """ 40 | message = chat_state.message 41 | related_searches, people_also_ask = get_related_websearch_queries(message) 42 | 43 | # Get queries to search for 44 | queries = [message] + [ 45 | query for query in related_searches + people_also_ask if query != message 46 | ] 47 | queries = queries[:max_queries] 48 | print("queries:", queries) 49 | 50 | # Get links 51 | search = GoogleSerperAPIWrapper() 52 | search_tasks = [search.aresults(query) for query in queries] 53 | search_results = gather_tasks_sync(search_tasks) 54 | links = get_links_from_search_results(search_results)[:max_total_links] 55 | print("Links:", links) 56 | 57 | # Get content from links, measuring time taken 58 | t_start = datetime.now() 59 | print("Fetching content from links...") 60 | # htmls = make_sync(afetch_urls_in_parallel_chromium_loader)(links) 61 | htmls = make_sync(afetch_urls_in_parallel_playwright)(links) 62 | # htmls = make_sync(afetch_urls_in_parallel_html_loader)(links) 63 | # htmls = fetch_urls_with_lc_html_loader(links) # takes ~20s, slow 64 | t_fetch_end = datetime.now() 65 | 66 | print("Processing content...") 67 | texts = [get_text_from_html(html) for html in htmls] 68 | ok_texts, ok_links = remove_failed_fetches(texts, links) 69 | 70 | processed_texts, token_counts = limit_tokens_in_texts( 71 | ok_texts, max_tot_tokens=int(CONTEXT_LENGTH * 0.5) 72 | ) 73 | processed_texts = [ 74 | f"SOURCE: {link}\nCONTENT:\n{text}\n=====" 75 | for text, link in zip(processed_texts, ok_links) 76 | ] 77 | texts_str = "\n\n".join(processed_texts) 78 | t_end = datetime.now() 79 | 80 | print("CONTEXT FOR GENERATING REPORT:\n") 81 | print(texts_str + "\n" + "=" * 100 + "\n") 82 | print("Original number of links:", len(links)) 83 | print("Number of links after removing unsuccessfully fetched ones:", len(ok_links)) 84 | print("Time taken to fetch sites:", t_fetch_end - t_start) 85 | print("Time taken to process sites:", t_end - t_fetch_end) 86 | print("Number of resulting tokens:", get_num_tokens(texts_str)) 87 | 88 | chain = get_prompt_llm_chain( 89 | RESEARCHER_PROMPT_SIMPLE, 90 | llm_settings=chat_state.bot_settings, 91 | api_key=chat_state.openai_api_key, 92 | callbacks=chat_state.callbacks, 93 | stream=True, 94 | ) 95 | answer = chain.invoke({"texts_str": texts_str, "query": message}) 96 | return {"answer": answer} 97 | -------------------------------------------------------------------------------- /announcements/version-0-2-6.md: -------------------------------------------------------------------------------- 1 | ### GPT-4o mini and other updates 2 | 3 | First, DocDocGo had a teeny-tiny issue with its database growing too gigantic for its VM. The issue is now fixed, so if you tried to use DocDocGo and couldn't access it, my apologies - it's now back in action! :muscle:: 4 | 5 | Updates: 6 | 7 | **For users:** 8 | 9 | 1. The default model is now **GPT-4o mini** (thanks, OpenAI! :rocket:) instead of GPT-3.5, so expect a significant improvement in the quality of responses. If you use DocDocGo with you own OpenAI API key, you can also switch to GPT-4o, GPT-4, or the old friend GPT-3.5. 10 | 2. The database has been upgraded and cleaned up for better performance. 11 | 12 | **For developers:** 13 | 14 | 1. The [developer docs](https://github.com/reasonmethis/docdocgo-core/blob/main/README-FOR-DEVELOPERS.md) have been significantly expanded to include lots of goodies on hosting the Chroma database server on AWS and GCP, running locally with Docker, cleaning it, etc. 
There is a lot more that can (and should) be added, so if something is unclear or you have any questions or suggestions, feel free to DM me here or on [LinkdIn](https://www.linkedin.com/in/dmitriyvasilyuk/). 15 | 2. In addition to cleaning up the database, Langchain and Chroma DB have been upgraded to the latest versions (0.2.10 and 0.5.4, respectively). A slew of deprecated imports have been updated, and a bug related to a breaking change in Langchain's `Chroma` has been found and squashed :beetle:. 16 | 17 | I hope you enjoy using DocDocGo with the shiny new GPT-4o mini! Because of the model switch and other major changes, it's possible that there may be some hiccups that I haven't caught yet. Feel free to reach out on [GitHub](https://github.com/reasonmethis/docdocgo-core) if you see any :spider:, and the exterminator will be dispatched posthaste! 18 | 19 | ## Formatted for LinkedIn 20 | 21 | ### DocDocGo updates: GPT-4o mini and other news ### 22 | 23 | - App: https://docdocgo.streamlit.app 24 | - GitHub: https://github.com/reasonmethis/docdocgo-core 25 | 26 | 🚀 The default model is now GPT-4o mini instead of GPT-3.5, so expect a significant improvement in the quality of responses. If you use DocDocGo with you own OpenAI API key, you can also switch to GPT-4o, GPT-4, or the old friend GPT-3.5. 27 | 28 | ⚡ The database has been upgraded and cleaned up for better performance. 29 | 30 | 📝 The developer docs (https://github.com/reasonmethis/docdocgo-core/blob/main/README-FOR-DEVELOPERS.md) have been significantly expanded to include lots of goodies on hosting the Chroma database server on AWS and GCP, running locally with Docker, cleaning it, etc. There is a lot more that can (and should) be added, so if you have any questions or suggestions, feel free to DM me. 31 | 32 | 🔧 In addition to cleaning up the database, Langchain and Chroma DB have been upgraded to the latest versions (0.2.10 and 0.5.4, respectively). A slew of deprecated imports have been updated, and a bug related to a breaking change in Langchain's `Chroma` has been found and squashed 🪲. 33 | 34 | 📢 If this is your first time hearing about DocDocGo, it's an AI assistant that helps you find information that requires sifting through far more than the first page of Google search results. It's kind of like if ChatGPT, Perplexity and GPT Researcher had a 👶 35 | 36 | I hope you enjoy using DocDocGo with the shiny new GPT-4o mini! Because of the model switch and other major changes, it's possible that there may be some hiccups that I haven't caught yet. Feel free to reach out here or on GitHub (https://github.com/reasonmethis/docdocgo-core) if you see any 🕷️, and the exterminator will be dispatched posthaste! -------------------------------------------------------------------------------- /announcements/version-0-2.md: -------------------------------------------------------------------------------- 1 | :tada:**Announcing DocDocGo version 0.2**:tada: 2 | 3 | DocDocGo is growing up! Let's celebrate by describing the new features. 4 | 5 | **1. DDG finally has a logo that's not just the 🦉 character!** 6 | 7 | Its [design](https://github.com/reasonmethis/docdocgo-core/blob/9724fe27cd1a5812861feb018c2b45bed55af968/media/ddg-logo.png) continues the noble tradition of trying too hard to be clever, starting with the name "DocDocGo" :grin: 8 | 9 | @StreamlitTeam continues to ship - the new `st.logo` feature arrived just in time! 10 | 11 | **2. 
Re-engineered collections** 12 | 13 | With over 150 public collections and growing, it was time to manage them better. Now, collections track how recently they were created/updated. Use: 14 | 15 | - `/db list` - see the 20 most recent collections (in a cute table) 16 | 17 | - `/db list 21+` - see the next 20, and so on 18 | 19 | - `/db list blah` - list collections with "blah" in the name 20 | 21 | DDG will remind you of the availablej commands when you use `/db` or `/db list`. 22 | 23 | **Safety improvement:** Deleting a collection by number, e.g. `/db delete 42`, now only works if it was first listed by `/db list`. 24 | 25 | **3. Default modes** 26 | 27 | Tired of typing `/reseach` or `/research heatseek`? Let your fingers rest by selecting a default mode in the Streamlit sidebar. 28 | 29 | You can still override it when you need to - just type the desired command as usual, for example: `/help What's the difference between regular research and heatseek?`. 30 | 31 | **4. UI improvements** 32 | 33 | The UI has been refreshed with a new intro screen that has: 34 | 35 | - The one-click sample commands to get you started if you're new 36 | 37 | - The new and improved welcome message 38 | 39 | Take the new version for a spin and let me know what you think! :rocket: 40 | 41 | https://docdocgo.streamlit.app 42 | -------------------------------------------------------------------------------- /components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/components/__init__.py -------------------------------------------------------------------------------- /components/chroma_ddg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Callable, Optional 3 | 4 | from chromadb import ClientAPI, Collection, HttpClient, PersistentClient 5 | from chromadb.api.types import Where # , WhereDocument 6 | from chromadb.config import Settings 7 | from langchain_community.vectorstores.chroma import _results_to_docs_and_scores 8 | from langchain_core.embeddings import Embeddings 9 | 10 | from components.openai_embeddings_ddg import get_openai_embeddings 11 | from utils.prepare import ( 12 | CHROMA_SERVER_AUTHN_CREDENTIALS, 13 | CHROMA_SERVER_HOST, 14 | CHROMA_SERVER_HTTP_PORT, 15 | USE_CHROMA_VIA_HTTP, 16 | VECTORDB_DIR, 17 | get_logger, 18 | ) 19 | from utils.type_utils import DDGError 20 | from langchain_community.vectorstores import Chroma 21 | from langchain_core.documents import Document 22 | 23 | logger = get_logger() 24 | 25 | 26 | class CollectionDoesNotExist(DDGError): 27 | """Exception raised when a collection does not exist.""" 28 | 29 | pass # REVIEW 30 | 31 | 32 | class ChromaDDG(Chroma): 33 | """ 34 | Modified Chroma vectorstore for DocDocGo. 35 | 36 | Specifically: 37 | 1. It enables the kwargs passed to similarity_search_with_score 38 | to be passed to the __query_collection method rather than ignored, 39 | which allows us to pass the 'where_document' parameter. 40 | 2. __init__ is overridden to allow the option of using get_collection (which doesn't create 41 | a collection if it doesn't exist) rather than always using get_or_create_collection (which does). 
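
    A construction sketch (illustrative values; assumes a reachable Chroma
    client and a valid OpenAI API key named `openai_api_key`):

        client = initialize_client()
        vectorstore = ChromaDDG(
            collection_name="my-collection",
            client=client,
            create_if_not_exists=False,  # raises CollectionDoesNotExist if absent
            embedding_function=get_openai_embeddings(openai_api_key),
        )

    The get_vectorstore_using_openai_api_key helper below wraps exactly this
    pattern.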
42 | """ 43 | 44 | def __init__( 45 | self, 46 | *, 47 | collection_name: str, 48 | client: ClientAPI, 49 | create_if_not_exists: bool, 50 | embedding_function: Optional[Embeddings] = None, 51 | persist_directory: Optional[str] = None, 52 | client_settings: Optional[Settings] = None, 53 | collection_metadata: Optional[dict] = None, 54 | relevance_score_fn: Optional[Callable[[float], float]] = None, 55 | ) -> None: 56 | """Initialize with a Chroma client.""" 57 | self._client_settings = client_settings 58 | self._client = client 59 | self._persist_directory = persist_directory 60 | 61 | self._embedding_function = embedding_function 62 | self.override_relevance_score_fn = relevance_score_fn 63 | logger.info(f"ChromaDDG init: {create_if_not_exists=}, {collection_name=}") 64 | 65 | if not create_if_not_exists and collection_metadata is not None: 66 | # We must check if the collection exists in this case because when metadata 67 | # is passed, we must use get_or_create_collection to update the metadata 68 | if not exists_collection(collection_name, client): 69 | raise CollectionDoesNotExist() 70 | 71 | if create_if_not_exists or collection_metadata is not None: 72 | self._collection = self._client.get_or_create_collection( 73 | name=collection_name, 74 | embedding_function=None, 75 | metadata=collection_metadata, 76 | ) 77 | else: 78 | try: 79 | self._collection = self._client.get_collection( 80 | name=collection_name, 81 | embedding_function=None, 82 | ) 83 | except Exception as e: 84 | logger.info(f"Failed to get collection {collection_name}: {str(e)}") 85 | raise CollectionDoesNotExist() if "does not exist" in str(e) else e 86 | 87 | def __bool__(self) -> bool: 88 | """Always return True to avoid ambiguity of what False could mean.""" 89 | return True 90 | 91 | @property 92 | def name(self) -> str: 93 | """Name of the underlying chromadb collection.""" 94 | return self._collection.name 95 | 96 | @property 97 | def collection(self) -> Collection: 98 | """The underlying chromadb collection.""" 99 | return self._collection 100 | 101 | @property 102 | def client(self) -> ClientAPI: 103 | """The underlying chromadb client.""" 104 | return self._client 105 | 106 | def get_cached_collection_metadata(self) -> dict[str, Any] | None: 107 | """Get locally cached metadata for the underlying chromadb collection.""" 108 | return self._collection.metadata 109 | 110 | def fetch_collection_metadata(self) -> dict[str, Any]: 111 | """Fetch metadata for the underlying chromadb collection.""" 112 | logger.info(f"Fetching metadata for collection {self.name}") 113 | self._collection = self._client.get_collection( 114 | self.name, embedding_function=self._collection._embedding_function 115 | ) 116 | logger.info(f"Fetched metadata for collection {self.name}") 117 | return self._collection.metadata 118 | 119 | def save_collection_metadata(self, metadata: dict[str, Any]) -> None: 120 | """Set metadata for the underlying chromadb collection.""" 121 | self._collection.modify(metadata=metadata) 122 | 123 | def rename_collection(self, new_name: str) -> None: 124 | """Rename the underlying chromadb collection.""" 125 | self._collection.modify(name=new_name) 126 | 127 | def delete_collection(self, collection_name: str) -> None: 128 | """Delete the underlying chromadb collection.""" 129 | self._client.delete_collection(collection_name) 130 | 131 | def similarity_search_with_score( 132 | self, 133 | query: str, 134 | k: int, # = DEFAULT_K, 135 | filter: Where | None = None, 136 | **kwargs: Any, 137 | ) -> list[tuple[Document, 
float]]: 138 | """ 139 | Run similarity search with Chroma with distance. 140 | 141 | Args: 142 | query (str): Query text to search for. 143 | k (int): Number of results to return. 144 | filter (Where | None): Filter by metadata. Corresponds to the chromadb 'where' 145 | parameter. Defaults to None. 146 | **kwargs: Additional keyword arguments. Only 'where_document' is used, if present. 147 | 148 | Returns: 149 | list[tuple[Document, float]]: list of documents most similar to 150 | the query text and cosine distance in float for each. 151 | """ 152 | 153 | # Determine if the passed kwargs contain a 'where_document' parameter 154 | # If so, we'll pass it to the __query_collection method 155 | try: 156 | possible_where_document_kwarg = {"where_document": kwargs["where_document"]} 157 | except KeyError: 158 | possible_where_document_kwarg = {} 159 | 160 | # Query by text or embedding, depending on whether an embedding function is present 161 | if self._embedding_function is None: 162 | results = self._Chroma__query_collection( 163 | query_texts=[query], 164 | n_results=k, 165 | where=filter, 166 | **possible_where_document_kwarg, 167 | ) 168 | else: 169 | query_embedding = self._embedding_function.embed_query(query) 170 | results = self._Chroma__query_collection( 171 | query_embeddings=[query_embedding], 172 | n_results=k, 173 | where=filter, 174 | **possible_where_document_kwarg, 175 | ) 176 | 177 | return _results_to_docs_and_scores(results) 178 | 179 | 180 | def exists_collection( 181 | collection_name: str, 182 | client: ClientAPI, 183 | ) -> bool: 184 | """ 185 | Check if a collection exists. 186 | """ 187 | # NOTE: Alternative: return collection_name in {x.name for x in client.list_collections()} 188 | try: 189 | client.get_collection(collection_name) 190 | return True 191 | except Exception as e: # Exception: {ValueError: "Collection 'test' does not exist"} 192 | if "does not exist" in str(e): 193 | return False # collection does not exist 194 | raise e 195 | 196 | 197 | def initialize_client(use_chroma_via_http: bool = USE_CHROMA_VIA_HTTP) -> ClientAPI: 198 | """ 199 | Initialize a chroma client. 200 | """ 201 | if use_chroma_via_http: 202 | return HttpClient( 203 | host=CHROMA_SERVER_HOST, # must provide host and port explicitly... 204 | port=CHROMA_SERVER_HTTP_PORT, # ...if the env vars are different from defaults 205 | settings=Settings( 206 | chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider", 207 | chroma_auth_token_transport_header="X-Chroma-Token", 208 | chroma_client_auth_credentials=CHROMA_SERVER_AUTHN_CREDENTIALS, 209 | anonymized_telemetry=False, 210 | ), 211 | ) 212 | if not isinstance(VECTORDB_DIR, str) or not os.path.isdir(VECTORDB_DIR): 213 | # NOTE: interestingly, isdir(None) returns True, hence the additional check 214 | raise ValueError(f"Invalid chromadb path: {VECTORDB_DIR}") 215 | 216 | return PersistentClient(VECTORDB_DIR, settings=Settings(anonymized_telemetry=False)) 217 | # return Client(Settings(chroma_db_impl="duckdb+parquet", persist_directory=path)) 218 | 219 | 220 | def ensure_chroma_client(client: ClientAPI | None = None) -> ClientAPI: 221 | """ 222 | Ensure that a chroma client is initialized and return it. 
223 | """ 224 | return client or initialize_client() 225 | 226 | 227 | def get_vectorstore_using_openai_api_key( 228 | collection_name: str, 229 | *, 230 | openai_api_key: str, 231 | client: ClientAPI | None = None, 232 | create_if_not_exists: bool = False, 233 | ) -> ChromaDDG: 234 | """ 235 | Load a ChromaDDG vectorstore from a given collection name and OpenAI API key. 236 | """ 237 | return ChromaDDG( 238 | client=ensure_chroma_client(client), 239 | collection_name=collection_name, 240 | create_if_not_exists=create_if_not_exists, 241 | embedding_function=get_openai_embeddings(openai_api_key), 242 | ) 243 | -------------------------------------------------------------------------------- /components/chroma_ddg_retriever.py: -------------------------------------------------------------------------------- 1 | from typing import Any, ClassVar 2 | 3 | from chromadb.api.types import Where, WhereDocument 4 | from langchain_core.documents import Document 5 | from pydantic import Field 6 | 7 | from utils.helpers import DELIMITER, lin_interpolate 8 | from utils.lang_utils import expand_chunks 9 | from utils.prepare import CONTEXT_LENGTH, EMBEDDINGS_MODEL_NAME 10 | from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun 11 | from langchain_core.language_models import BaseLanguageModel 12 | from langchain_core.vectorstores import VectorStoreRetriever 13 | 14 | 15 | class ChromaDDGRetriever(VectorStoreRetriever): 16 | """ 17 | A retriever that uses a ChromaDDG vectorstore to find relevant documents. 18 | 19 | Compared to a generic VectorStoreRetriever, this retriever has the additional 20 | feature of accepting the `where` and `where_document` parameters in its 21 | `get_relevant_documents` method. These parameters are used to filter the 22 | documents by their metadata and contained text, respectively. 23 | 24 | NOTE: even though the underlying vectorstore is not explicitly required to 25 | be a ChromaDDG, it must be a vectorstore that supports the `where` and 26 | `where_document` parameters in its `similarity_search`-type methods. 
27 | """ 28 | 29 | llm_for_token_counting: BaseLanguageModel | None 30 | 31 | verbose: bool = False # print similarity scores and other info 32 | k_overshot = 20 # number of docs to return initially (prune later) 33 | score_threshold_overshot = ( 34 | -12345.0 35 | ) # score threshold to use initially (prune later) 36 | 37 | k_min = 2 # min number of docs to return after pruning 38 | score_threshold_min = ( 39 | 0.61 if EMBEDDINGS_MODEL_NAME == "text-embeddings-ada-002" else -0.1 40 | ) # use k_min if score of k_min'th doc is <= this 41 | 42 | k_max = 10 # max number of docs to return after pruning 43 | score_threshold_max = ( 44 | 0.76 if EMBEDDINGS_MODEL_NAME == "text-embeddings-ada-002" else 0.2 45 | ) # use k_max if score of k_max'th doc is >= this 46 | 47 | max_total_tokens = int(CONTEXT_LENGTH * 0.5) # consistent with ChatWithDocsChain 48 | max_average_tokens_per_chunk = int(max_total_tokens / k_max) 49 | 50 | # get_relevant_documents() must return only docs, but we'll save scores here 51 | similarities: list = Field(default_factory=list) 52 | 53 | allowed_search_types: ClassVar[tuple[str]] = ( 54 | "similarity", 55 | # "similarity_score_threshold", # NOTE can add at some point 56 | "mmr", 57 | "similarity_ddg", 58 | ) 59 | 60 | def _get_relevant_documents( 61 | self, 62 | query: str, 63 | *, 64 | run_manager: CallbackManagerForRetrieverRun, 65 | filter: Where | None = None, # For metadata (Langchain naming convention) 66 | where_document: WhereDocument | None = None, # Filter by text in document 67 | **kwargs: Any, # For additional search params 68 | ) -> list[Document]: 69 | # Combine global search kwargs with per-query search params passed here 70 | search_kwargs = self.search_kwargs | kwargs 71 | if filter is not None: 72 | search_kwargs["filter"] = filter 73 | if where_document is not None: 74 | search_kwargs["where_document"] = where_document 75 | 76 | # Perform search depending on search type 77 | if self.search_type == "similarity": 78 | return self.vectorstore.similarity_search(query, **search_kwargs) 79 | elif self.search_type == "mmr": 80 | return self.vectorstore.max_marginal_relevance_search( 81 | query, **search_kwargs 82 | ) 83 | 84 | # Main search method used by DocDocGo 85 | assert self.search_type == "similarity_ddg", "Invalid search type" 86 | assert str(type(self.vectorstore)).endswith("ChromaDDG'>"), "Bad vectorstore" 87 | 88 | # First, get more docs than we need, then we'll pare them down 89 | # NOTE this is because apparently Chroma can miss even the most relevant doc 90 | # if k (num docs to return) is not high (e.g. 
"Who is Big Keetie?", k = 10) 91 | # k_overshot = max(k := search_kwargs["k"], self.k_overshot) 92 | # score_threshold_overshot = min( 93 | # score_threshold := search_kwargs["score_threshold"], 94 | # self.score_threshold_overshot, 95 | # ) # usually simply 0 96 | 97 | docs_and_similarities_overshot = ( 98 | self.vectorstore.similarity_search_with_relevance_scores( 99 | query, 100 | **( 101 | search_kwargs 102 | | { 103 | "k": self.k_overshot, 104 | "score_threshold": self.score_threshold_overshot, 105 | } 106 | ), 107 | ) 108 | ) 109 | 110 | if self.verbose: 111 | for doc, sim in docs_and_similarities_overshot: 112 | print(f"[SIMILARITY: {sim:.2f}] {repr(doc.page_content[:60])}") 113 | print(f"Before paring down: {len(docs_and_similarities_overshot)} docs.") 114 | 115 | # Now, pare down the results 116 | chunks: list[Document] = [] 117 | self.similarities: list[float] = [] 118 | for k, (doc, sim) in enumerate(docs_and_similarities_overshot, start=1): 119 | # If we've already found enough docs, stop 120 | if k > self.k_max: 121 | break 122 | 123 | # Find score_threshold if we were to have k docs 124 | score_threshold_if_stop = lin_interpolate( 125 | k, 126 | self.k_min, 127 | self.k_max, 128 | self.score_threshold_min, 129 | self.score_threshold_max, 130 | ) 131 | 132 | # If we have min num of docs and similarity is below the threshold, stop 133 | if k > self.k_min and sim < score_threshold_if_stop: 134 | break 135 | 136 | # Otherwise, add the doc to the list and keep going 137 | chunks.append(doc) 138 | self.similarities.append(sim) 139 | 140 | if self.verbose: 141 | print(f"After paring down: {len(chunks)} docs.") 142 | if chunks: 143 | print( 144 | f"Similarities from {self.similarities[-1]:.2f} to {self.similarities[0]:.2f}" 145 | ) 146 | print(DELIMITER) 147 | 148 | # Get the parent documents for the chunks 149 | try: 150 | print("METADATAS:") 151 | for chunk in chunks: 152 | print(chunk.metadata) 153 | parent_ids = [chunk.metadata["parent_id"] for chunk in chunks] 154 | except KeyError: 155 | # If it's an older collection, without parent docs, just return the chunks 156 | return chunks 157 | unique_parent_ids = list(set(parent_ids)) 158 | rsp = self.vectorstore.collection.get(unique_parent_ids) 159 | 160 | parent_docs_by_id = { 161 | id: Document(page_content=text, metadata=metadata) 162 | for id, text, metadata in zip( 163 | rsp["ids"], rsp["documents"], rsp["metadatas"] 164 | ) 165 | } 166 | 167 | # Expand chunks using the parent docs 168 | max_total_tokens = min( 169 | self.max_total_tokens, self.max_average_tokens_per_chunk * len(chunks) 170 | ) 171 | expanded_chunks = expand_chunks( 172 | chunks, 173 | parent_docs_by_id, 174 | max_total_tokens, 175 | llm_for_token_counting=self.llm_for_token_counting, 176 | ) 177 | return expanded_chunks 178 | 179 | async def _aget_relevant_documents( 180 | self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun 181 | ) -> list[Document]: 182 | raise NotImplementedError("Asynchronous retrieval not yet supported.") 183 | -------------------------------------------------------------------------------- /components/llm.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from uuid import UUID 3 | from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult 4 | from langchain_openai import AzureChatOpenAI, ChatOpenAI 5 | from streamlit.delta_generator import DeltaGenerator 6 | 7 | from utils.helpers import DELIMITER, MAIN_BOT_PREFIX 8 | from 
utils.lang_utils import msg_list_chat_history_to_string 9 | from utils.prepare import ( 10 | CHAT_DEPLOYMENT_NAME, 11 | IS_AZURE, 12 | LLM_REQUEST_TIMEOUT, 13 | ) 14 | from utils.streamlit.helpers import fix_markdown 15 | from utils.type_utils import BotSettings, CallbacksOrNone 16 | from langchain_core.callbacks import BaseCallbackHandler 17 | from langchain_core.language_models import BaseChatModel 18 | from langchain_core.output_parsers import StrOutputParser 19 | from langchain_core.prompt_values import ChatPromptValue 20 | from langchain_core.prompts import PromptTemplate 21 | 22 | 23 | class CallbackHandlerDDGStreamlit(BaseCallbackHandler): 24 | def __init__(self, container: DeltaGenerator, end_str: str = ""): 25 | self.container = container 26 | self.buffer = "" 27 | self.end_str = end_str 28 | self.end_str_printed = False 29 | 30 | def on_llm_new_token( 31 | self, 32 | token: str, 33 | *, 34 | chunk: GenerationChunk | ChatGenerationChunk | None = None, 35 | run_id: UUID, 36 | parent_run_id: UUID | None = None, 37 | **kwargs: Any, 38 | ) -> None: 39 | self.buffer += token 40 | self.container.markdown(fix_markdown(self.buffer)) 41 | 42 | def on_llm_end( 43 | self, 44 | response: LLMResult, 45 | *, 46 | run_id: UUID, 47 | parent_run_id: UUID | None = None, 48 | **kwargs: Any, 49 | ) -> None: 50 | """Run when LLM ends running.""" 51 | if self.end_str: 52 | self.end_str_printed = True 53 | self.container.markdown(fix_markdown(self.buffer + self.end_str)) 54 | 55 | 56 | class CallbackHandlerDDGConsole(BaseCallbackHandler): 57 | def __init__(self, init_str: str = MAIN_BOT_PREFIX): 58 | self.init_str = init_str 59 | 60 | def on_llm_start( 61 | self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any 62 | ) -> None: 63 | print(self.init_str, end="", flush=True) 64 | 65 | def on_llm_new_token(self, token, **kwargs) -> None: 66 | print(token, end="", flush=True) 67 | 68 | def on_llm_end(self, *args, **kwargs) -> None: 69 | print() 70 | 71 | def on_retry(self, *args, **kwargs) -> None: 72 | print(f"ON_RETRY: \nargs = {args}\nkwargs = {kwargs}") 73 | 74 | 75 | def get_llm_with_callbacks( 76 | settings: BotSettings, api_key: str | None = None, callbacks: CallbacksOrNone = None 77 | ) -> BaseChatModel: 78 | """ 79 | Returns a chat model instance (either AzureChatOpenAI or ChatOpenAI, depending 80 | on the value of IS_AZURE). In the case of AzureChatOpenAI, the model is 81 | determined by CHAT_DEPLOYMENT_NAME (and other Azure-specific environment variables), 82 | not by settings.model_name. 83 | """ 84 | if IS_AZURE: 85 | llm = AzureChatOpenAI( 86 | deployment_name=CHAT_DEPLOYMENT_NAME, 87 | temperature=settings.temperature, 88 | request_timeout=LLM_REQUEST_TIMEOUT, 89 | streaming=True, # seems to help with timeouts 90 | callbacks=callbacks, 91 | ) 92 | else: 93 | llm = ChatOpenAI( 94 | api_key=api_key or "", # don't allow None, no implicit key from env 95 | model=settings.llm_model_name, 96 | temperature=settings.temperature, 97 | request_timeout=LLM_REQUEST_TIMEOUT, 98 | streaming=True, 99 | callbacks=callbacks, 100 | verbose=True, # tmp 101 | ) 102 | return llm 103 | 104 | 105 | def get_llm( 106 | settings: BotSettings, 107 | api_key: str | None = None, 108 | callbacks: CallbacksOrNone = None, 109 | stream=False, 110 | init_str=MAIN_BOT_PREFIX, 111 | ) -> BaseChatModel: 112 | """ 113 | Return a chat model instance (either AzureChatOpenAI or ChatOpenAI, depending 114 | on the value of IS_AZURE). If callbacks is passed, it will be used as the 115 | callbacks for the chat model. 
Otherwise, if stream is True, then a CallbackHandlerDDGConsole
116 |     will be used with the passed init_str as the init_str.
117 |     """
118 |     if callbacks is None:
119 |         callbacks = [CallbackHandlerDDGConsole(init_str)] if stream else []
120 |     return get_llm_with_callbacks(settings, api_key, callbacks)
121 | 
122 | 
123 | def get_prompt_llm_chain(
124 |     prompt: PromptTemplate,
125 |     llm_settings: BotSettings,
126 |     api_key: str | None = None,
127 |     print_prompt=False,
128 |     **kwargs,
129 | ):
130 |     if not print_prompt:
131 |         return prompt | get_llm(llm_settings, api_key, **kwargs) | StrOutputParser()
132 | 
133 |     def print_and_return(thing):
134 |         if isinstance(thing, ChatPromptValue):
135 |             print(f"PROMPT:\n{msg_list_chat_history_to_string(thing.messages)}")
136 |         else:
137 |             print(f"PROMPT:\n{type(thing)}\n{thing}")
138 |         print(DELIMITER)
139 |         return thing
140 | 
141 |     return (
142 |         prompt
143 |         | print_and_return
144 |         | get_llm(llm_settings, api_key, **kwargs)
145 |         | StrOutputParser()
146 |     )
147 | 
148 | 
149 | def get_llm_from_prompt_llm_chain(prompt_llm_chain):
150 |     return prompt_llm_chain.middle[-1]
151 | 
152 | 
153 | def get_prompt_text(prompt: PromptTemplate, inputs: dict[str, str]) -> str:
154 |     return prompt.invoke(inputs).to_string()
155 | 
156 | 
157 | if __name__ == "__main__":
158 |     # NOTE: Run this file as "python -m components.llm"
159 |     x = CallbackHandlerDDGConsole("BOT: ")
160 | 
--------------------------------------------------------------------------------
/components/openai_embeddings_ddg.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from langchain_core.embeddings import Embeddings
4 | from langchain_openai import OpenAIEmbeddings
5 | 
6 | from utils.prepare import EMBEDDINGS_DIMENSIONS, EMBEDDINGS_MODEL_NAME, IS_AZURE
7 | 
8 | 
9 | def get_openai_embeddings(
10 |     api_key: str | None = None,
11 |     embeddings_model_name: str = EMBEDDINGS_MODEL_NAME,
12 |     embeddings_dimensions: int = EMBEDDINGS_DIMENSIONS,
13 | ) -> Embeddings:
14 |     # Create the embeddings object, pulling current values of env vars
15 |     # (as of Aug 8, 2023, max chunk size for Azure API is 16)
16 |     return (
17 |         OpenAIEmbeddings(  # NOTE: should be able to simplify this
18 |             deployment=os.getenv("EMBEDDINGS_DEPLOYMENT_NAME"), chunk_size=16
19 |         )
20 |         if IS_AZURE
21 |         else OpenAIEmbeddings(
22 |             api_key=api_key or "",
23 |             model=embeddings_model_name,
24 |             dimensions=None
25 |             if embeddings_model_name == "text-embedding-ada-002"
26 |             else embeddings_dimensions,
27 |         )  # NOTE: if empty API key, will throw
28 |     )
29 | 
30 | 
31 | # class OpenAIEmbeddingsDDG(Embeddings):
32 | #     """
33 | #     Custom version of OpenAIEmbeddings for DocDocGo. Unlike the original,
34 | #     an object of this class will pull the current values of env vars every time.
35 | #     This helps in situations where the user has changed env vars such as
36 | #     OPENAI_API_KEY, as is possible in the Streamlit app.
37 | 
38 | #     This way is also more consistent with the behavior of e.g. ChatOpenAI, which
39 | #     always uses the current values of env vars when querying the OpenAI API.
40 | # """ 41 | 42 | # @staticmethod 43 | # def get_fresh_openai_embeddings(): 44 | # # Create the embeddings object, pulling current values of env vars 45 | # # (as of Aug 8, 2023, max chunk size for Azure API is 16) 46 | # return ( 47 | # OpenAIEmbeddings( # NOTE: should be able to simplify this 48 | # deployment=os.getenv("EMBEDDINGS_DEPLOYMENT_NAME"), chunk_size=16 49 | # ) 50 | # if os.getenv("OPENAI_API_BASE") # proxy for whether we're using Azure 51 | # else OpenAIEmbeddings() 52 | # ) 53 | 54 | # def embed_documents(self, texts: list[str]) -> list[list[float]]: 55 | # """Embed search docs.""" 56 | # return self.get_fresh_openai_embeddings().embed_documents(texts) 57 | 58 | # def embed_query(self, text: str) -> list[float]: 59 | # """Embed query text.""" 60 | # return self.get_fresh_openai_embeddings().embed_query(text) 61 | 62 | # async def aembed_documents(self, texts: list[str]) -> list[list[float]]: 63 | # """Asynchronous Embed search docs.""" 64 | # return await self.get_fresh_openai_embeddings().aembed_documents(texts) 65 | 66 | # async def aembed_query(self, text: str) -> list[float]: 67 | # """Asynchronous Embed query text.""" 68 | # return await self.get_fresh_openai_embeddings().aembed_query(text) 69 | -------------------------------------------------------------------------------- /config/logging.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": 1, 3 | "disable_existing_loggers": false, 4 | "formatters": { 5 | "simple": { 6 | "format": "[%(levelname)s|%(name)s|%(module)s|L%(lineno)d] %(asctime)s: %(message)s", 7 | "datefmt": "%Y-%m-%dT%H:%M:%S%z" 8 | }, 9 | "json": { 10 | "()": "utils.log.MyJSONFormatter", 11 | "fmt_keys": { 12 | "level": "levelname", 13 | "message": "message", 14 | "timestamp": "timestamp", 15 | "logger": "name", 16 | "module": "module", 17 | "function": "funcName", 18 | "line": "lineno", 19 | "thread_name": "threadName" 20 | } 21 | } 22 | }, 23 | "handlers": { 24 | "h1_stderr": { 25 | "class": "logging.StreamHandler", 26 | "level": "INFO", 27 | "formatter": "simple", 28 | "stream": "ext://sys.stderr" 29 | }, 30 | "h2_file_json": { 31 | "class": "logging.handlers.RotatingFileHandler", 32 | "level": "DEBUG", 33 | "formatter": "json", 34 | "filename": "logs/ddg-logs.jsonl", 35 | "maxBytes": 1000000, 36 | "backupCount": 2 37 | }, 38 | "h3_queue": { 39 | "()": "utils.log.QueueListenerHandler", 40 | "handlers": ["cfg://handlers.h2_file_json", "cfg://handlers.h1_stderr"], 41 | "respect_handler_level": true 42 | } 43 | }, 44 | "loggers": { 45 | "trafilatura": { 46 | "level": "ERROR", 47 | "handlers": ["h3_queue"], 48 | "propagate": false 49 | }, 50 | "ddg": { 51 | "level": "DEBUG", 52 | "handlers": ["h3_queue"], 53 | "propagate": false 54 | } 55 | }, 56 | "root": { 57 | "level": "WARNING", 58 | "handlers": ["h3_queue"] 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /config/trafilatura.cfg: -------------------------------------------------------------------------------- 1 | # Defines settings for trafilatura (https://github.com/adbar/trafilatura) 2 | 3 | [DEFAULT] 4 | 5 | # Download 6 | DOWNLOAD_TIMEOUT = 30 7 | MAX_FILE_SIZE = 20000000 8 | MIN_FILE_SIZE = 10 9 | # sleep between requests 10 | SLEEP_TIME = 5 11 | # user-agents here: agent1,agent2,... 
12 | USER_AGENTS = 13 | # cookie for HTTP requests 14 | COOKIE = 15 | 16 | # Extraction 17 | MIN_EXTRACTED_SIZE = 1500 18 | # MIN_EXTRACTED_SIZE = 250 19 | # increased to increase recall 20 | # eg for https://www.artificialintelligence-news.com/ extracted only 21 | # cookie consent text, even with favor_recall = True 22 | MIN_EXTRACTED_COMM_SIZE = 1 23 | MIN_OUTPUT_SIZE = 1 24 | MIN_OUTPUT_COMM_SIZE = 1 25 | 26 | # CLI file processing only, set to 0 to disable 27 | EXTRACTION_TIMEOUT = 30 28 | 29 | # Deduplication 30 | MIN_DUPLCHECK_SIZE = 100 31 | MAX_REPETITIONS = 2 32 | 33 | -------------------------------------------------------------------------------- /eval/RAG-PARAMS.md: -------------------------------------------------------------------------------- 1 | # Parameters for RAG 2 | 3 | ## ChromaDDGRetriever 4 | 5 | k_min = 2 # min number of docs to return after pruning 6 | score_threshold_min = 0.61 # use k_min if score of k_min'th doc is <= this 7 | k_max = 10 # max number of docs to return after pruning 8 | score_threshold_max = 0.76 # use k_max if score of k_max'th doc is >= this 9 | 10 | ## prepare_chunks 11 | 12 | text_splitter = RecursiveCharacterTextSplitter( 13 | chunk_size=400, 14 | chunk_overlap=40, 15 | -------------------------------------------------------------------------------- /eval/ai_news_1.py: -------------------------------------------------------------------------------- 1 | ai_news_1 ="""sdecoret - stock.adobe.com 2 | Biggest AI news of 2022 3 | From misinformation and disinformation in the war in Ukraine to James Earl Jones lending his voice to voice cloning, these are the most notable stories of the year. 4 | Despite economic uncertainty and a few hiccups regarding ethics and bias, 2022 was a year of change for AI technologies. 5 | Here is a look back at the stories that dominated the year. 6 | Can AI be sentient? 7 | Blake Lemoine, a former engineer at Google, broke the internet when he declared that 8 | [Google's Language Model for Dialogue Applications](https://www.techtarget.com/searchenterpriseai/feature/Ex-Google-engineer-Blake-Lemoine-discusses-sentient-AI) (LaMDA), a large language model ( [LLM](https://www.techtarget.com/whatis/definition/large-language-model-LLM)) in development, is sentient. 9 | Blake claimed that based on LaMDA's answers to his questions, the LLM can communicate its fears and have feelings the way a human can. 10 | Google later fired Lemoine for choosing to "persistently violate clear employment and data security policies that include the need to safeguard product information," the company said in a statement. 11 | Generative AI 12 | From Dall-E 2 to Stable Diffusion to ChatGPT, generative AI has made AI more accessible to not only enterprise users but also consumers. 13 | However, with more accessibility comes more responsibility. 14 | As everyday consumers use Dall-E 2 to create striking images, artists and creators are concerned that these models are not only erasing a need for them but also 15 | [stealing their artwork](https://www.techtarget.com/searchenterpriseai/feature/The-creative-thief-AI-tools-creating-generated-art). 16 | This has also brought up concerns about copyright infringement laws when it comes to work created -- or stolen -- by AI. 17 | According to Michael G. 
Bennett, director of education curriculum and business lead for responsible AI at the Institute for Experiential AI, while some artists may have a case for copyright infringement if the artwork generated shows a strong similarity to the artist's work, if the work does not show a strong similarity, most artists will not have a case. Many times, the work being produced is the same as if a young artist were to be inspired by an older artist. 18 | Despite the concerns behind generated artwork and generative AI, many creators are embracing it. Marketing vendor Omneky uses Dall-E to generate ads for customers. 19 | Metaverse, avatars and metaverse technologies 20 | The metaverse dominated early 2022, with vendors such as Nvidia describing what they believe the next iteration of the internet will look like. Nvidia introduced an avatar engine. Meta -- Facebook's parent company -- invested heavily in metaverse technology. Researchers and analysts described a metaverse that includes 21 | [holographic avatars](https://www.techtarget.com/searchenterpriseai/feature/Enterprise-applications-of-the-metaverse-slow-but-coming) that could be used to train medical professionals and others. 22 | Vendors, including Hour One, introduced technologies such as a 3D news studio, where users can choose an avatar newscaster. 23 | Finally, consumers could reimagine themselves with Prisma Labs' Lensa app, which generated avatars for users. 24 | While many agree that we're still far away from the metaverse that vendors and enterprises alike are describing, the technologies are certainly growing more rampant, making the metaverse more real to all. 25 | Enterprises turn to AI to help ease the Great Resignation 26 | The coronavirus pandemic made many re-evaluate their employment status. This, in turn, meant many employees resigned in favor of other opportunities. 27 | For enterprises in the restaurant industry, this meant many positions were left unfilled. From pizza chain Jet's Pizza to 28 | [Panera Bread](https://www.techtarget.com/searchenterpriseai/news/252524391/Paneras-AI-drive-through-test-addresses-labor-concerns), enterprises turned to AI tools and technologies that could supplement human workers so that those left on staff could focus their attention elsewhere. 29 | "We're starting to see AI move out of that theoretical space of all the cool things AI can do," Liz Miller, an analyst at Constellation Research, told TechTarget Editorial back in August. "We're starting to see real, specific kinds of business acceleration-minded applications coming off of the workbench and getting into real life." 30 | Voice cloning and James Earl Jones 31 | Liz MillerAnalyst, Constellation Research 32 | When news broke that James Earl Jones was lending his voice to 33 | [speech-to-speech voice cloning](https://www.techtarget.com/searchenterpriseai/feature/James-Earl-Jones-AI-and-the-growing-voice-cloning-market) for future appearances of his Star Wars character Darth Vader, it gave insight into the growing use of the technology. 34 | There are two types of voice cloning: speech-to-speech and text-to-speech. 35 | Text-to-speech is used in contact center environments, where synthetic speech can communicate with consumers before they speak to agents. It's especially helpful when agents speak different languages. 
36 | AI and the war in Ukraine 37 | When the war between Russia and Ukraine began earlier this year, 38 | [AI was a tool used to spread disinformation](https://www.techtarget.com/searchenterpriseai/feature/AI-and-disinformation-in-the-Russia-Ukraine-war). Bad actors used the technology to create videos that spread misinformation. Deepfakes or AI-generated humans were created to spread anti-Ukrainian discourse. 39 | The use of deepfakes in the war showed how easy it has become for everyday consumers to create them without coding knowledge. 40 | The use of AI in the war also shows the psychological impact machine learning can have. 41 | "Machine learning is exceptionally good at learning how to exploit human psychology because the internet provides a vast and fast feedback loop to learn what will reinforce and or break beliefs by demographic cohorts," said Mike Gualtieri, an analyst at Forrester Research.""" -------------------------------------------------------------------------------- /eval/queries.md: -------------------------------------------------------------------------------- 1 | # Queries For Evaluation 2 | 3 | /research AI-powered resume builders - simple numbered list of as many as possible 4 | /research running llama2 locally 5 | /research legal arguments for and against disqualifying Trump from running 6 | /research what are some unique traditional Ukrainian dishes 7 | /research what are some important ai research papers and their summaries 8 | /research what do various organizations report as the recommended protein intake? 9 | /research recent ai news 10 | /research I'm looking for a tutorial on configuring Firestore (not Firebase) so that my external server-side application can use it. Most importantly, I want to know exactly how to set up the right role to the service account. Please assign a score 0-100 to each website for how well it fulfills that purpose. The tutorial should not be from Google's docs and they should not be about Firebase, only regular Firestore. If it's from Google's docs or mentions Firebase, assign score 0. Please just assign a score to each source and add a one-two sentence description, nothing else. Include links. If your prompt includes a previous report with assessments, keep them in your current report 11 | /research Please find quotes by Noam Chomsky about the scientific or practical value of ChatGPT. For each quote include a link to the source 12 | 13 | What is the typical range for the 1536 values in the text-embedding-ada-002 embeddings? 14 | -------------------------------------------------------------------------------- /eval/testing-rag-quality/2401.01325.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/eval/testing-rag-quality/2401.01325.pdf -------------------------------------------------------------------------------- /eval/testing-rag-quality/rag-tests.md: -------------------------------------------------------------------------------- 1 | Asking "What are the conclusions of the paper?" 
after ingesting 2401.01325.pdf
--------------------------------------------------------------------------------
/ingest_local_docs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | from langchain_community.document_loaders import DirectoryLoader, TextLoader
5 | 
6 | from _prepare_env import is_env_loaded
7 | from components.chroma_ddg import initialize_client
8 | from utils.docgrab import (
9 |     JSONLDocumentLoader,
10 |     ingest_into_chroma,
11 | )
12 | from utils.helpers import clear_directory, get_timestamp, is_directory_empty, print_no_newline
13 | from utils.prepare import DEFAULT_OPENAI_API_KEY
14 | 
15 | is_env_loaded = is_env_loaded  # see explanation at the end of docdocgo.py
16 | 
17 | if __name__ == "__main__":
18 |     print(
19 |         "-" * 70
20 |         + "\n"
21 |         + " " * 20
22 |         + "Ingestion of Local Docs into Local DB\n"
23 |         + "NOTE: It's recommended to use the Streamlit UI for ingestion instead,\n"
24 |         + "as this script is not maintained as frequently.\n"
25 |         + "-" * 70
26 |         + "\n"
27 |     )
28 | 
29 |     DOCS_TO_INGEST_DIR_OR_FILE = os.getenv("DOCS_TO_INGEST_DIR_OR_FILE")
30 |     COLLECTON_NAME_FOR_INGESTED_DOCS = os.getenv("COLLECTON_NAME_FOR_INGESTED_DOCS")
31 |     VECTORDB_DIR = os.getenv("VECTORDB_DIR")
32 | 
33 |     if (
34 |         not DOCS_TO_INGEST_DIR_OR_FILE
35 |         or not VECTORDB_DIR
36 |         or not COLLECTON_NAME_FOR_INGESTED_DOCS
37 |     ):
38 |         print(
39 |             "Please set DOCS_TO_INGEST_DIR_OR_FILE, COLLECTON_NAME_FOR_INGESTED_DOCS, "
40 |             "and VECTORDB_DIR in `.env`."
41 |         )
42 |         sys.exit()
43 |     if not os.path.exists(DOCS_TO_INGEST_DIR_OR_FILE):
44 |         print(f"{DOCS_TO_INGEST_DIR_OR_FILE} does not exist.")
45 |         sys.exit()
46 | 
47 |     # Validate the save directory
48 |     if not os.path.isdir(VECTORDB_DIR):
49 |         if os.path.exists(VECTORDB_DIR):
50 |             print(f"{VECTORDB_DIR} is not a directory.")
51 |             sys.exit()
52 |         print(f"{VECTORDB_DIR} is not an existing directory.")
53 |         ans = input("Do you want to create it? [y/N] ")
54 |         if ans.lower() != "y":
55 |             sys.exit()
56 |         else:
57 |             try:
58 |                 os.makedirs(VECTORDB_DIR)
59 |             except Exception as e:
60 |                 print(f"Could not create {VECTORDB_DIR}: {e}")
61 |                 sys.exit()
62 | 
63 |     # Check if the save directory is empty; if not, give the user options
64 |     chroma_client = None
65 |     if is_directory_empty(VECTORDB_DIR):
66 |         is_new_db = True
67 |     else:
68 |         ans = input(
69 |             f"{VECTORDB_DIR} is not empty. Please select an option:\n"
70 |             "1. Use it (select if it contains your existing db)\n"
71 |             "2. Delete all its content and create a new db\n"
72 |             "3. Cancel\nYour choice: "
73 |         )
74 |         print()
75 |         if ans not in {"1", "2"}:
76 |             sys.exit()
77 |         is_new_db = ans == "2"
78 |         if is_new_db:
79 |             ans = input("Type 'erase' to confirm deleting the directory's content: ")
80 |             if ans != "erase":
81 |                 sys.exit()
82 |             print_no_newline(f"Deleting files and folders in {VECTORDB_DIR}...")
83 |             clear_directory(VECTORDB_DIR)
84 |             print("Done!")
85 |         else:
86 |             chroma_client = initialize_client(use_chroma_via_http=False)
87 |             collections = chroma_client.list_collections()
88 |             collection_names = [c.name for c in collections]
89 |             if COLLECTON_NAME_FOR_INGESTED_DOCS in collection_names:
90 |                 ans = input(
91 |                     f"Collection {COLLECTON_NAME_FOR_INGESTED_DOCS} already exists. "
92 |                     "Please select an option:\n"
93 |                     "1. Append\n"
94 |                     "2. Overwrite\n"
95 |                     "3. Cancel\nYour choice: "
96 |                 )
97 |                 print()
98 |                 if ans not in {"1", "2"}:
99 |                     sys.exit()
100 |                 if ans == "2":
101 |                     ans = input(
102 |                         "Type 'erase' to confirm deleting the existing collection: "
103 |                     )
104 |                     if ans != "erase":
105 |                         sys.exit()
106 |                     print_no_newline(
107 |                         f"Deleting collection {COLLECTON_NAME_FOR_INGESTED_DOCS}..."
108 |                     )
109 |                     chroma_client.delete_collection(COLLECTON_NAME_FOR_INGESTED_DOCS)
110 |                     print("Done!\n")
111 | 
112 |     # Confirm the ingestion
113 |     print(f"You are about to ingest documents from: {DOCS_TO_INGEST_DIR_OR_FILE}")
114 |     print(f"The vector database directory: {VECTORDB_DIR}")
115 |     print(f"The collection name is: {COLLECTON_NAME_FOR_INGESTED_DOCS}")
116 |     print(f"Is this a new database? {'Yes' if is_new_db else 'No'}")
117 |     if os.getenv("USE_CHROMA_VIA_HTTP") or os.getenv("CHROMA_API_IMPL"):
118 |         print(
119 |             "NOTE: we will disable the USE_CHROMA_VIA_HTTP and "
120 |             "CHROMA_API_IMPL environment variables since this is a local ingestion."
121 |         )
122 |         os.environ["USE_CHROMA_VIA_HTTP"] = os.environ["CHROMA_API_IMPL"] = ""
123 | 
124 |     print("\nATTENTION: This will incur some cost on your OpenAI account.")
125 | 
126 |     if input("Press Enter to proceed. Any non-empty input will cancel the procedure: "):
127 |         print("Ingestion cancelled. Exiting")
128 |         sys.exit()
129 | 
130 |     # Load the documents
131 |     print_no_newline("Loading your documents...")
132 |     if os.path.isfile(DOCS_TO_INGEST_DIR_OR_FILE):
133 |         if DOCS_TO_INGEST_DIR_OR_FILE.endswith(".jsonl"):
134 |             loader = JSONLDocumentLoader(DOCS_TO_INGEST_DIR_OR_FILE)
135 |         else:
136 |             loader = TextLoader(DOCS_TO_INGEST_DIR_OR_FILE, autodetect_encoding=True)
137 |     else:
138 |         loader = DirectoryLoader(
139 |             DOCS_TO_INGEST_DIR_OR_FILE,
140 |             loader_cls=TextLoader,  # default is Unstructured but python-magic hangs on import
141 |             loader_kwargs={"autodetect_encoding": True},
142 |         )
143 |     docs = loader.load()
144 |     print("Done!")
145 | 
146 |     # Ingest into chromadb (this will print messages regarding the status)
147 |     timestamp = get_timestamp()
148 |     ingest_into_chroma(
149 |         docs,
150 |         collection_name=COLLECTON_NAME_FOR_INGESTED_DOCS,
151 |         chroma_client=chroma_client,
152 |         save_dir=None if chroma_client else VECTORDB_DIR,
153 |         openai_api_key=DEFAULT_OPENAI_API_KEY,
154 |         collection_metadata={"created_at": timestamp, "updated_at": timestamp},
155 |     )
156 | 
--------------------------------------------------------------------------------
/media/ddg-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/media/ddg-logo.png
--------------------------------------------------------------------------------
/media/minimal10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/media/minimal10.png
--------------------------------------------------------------------------------
/media/minimal10.svg:
--------------------------------------------------------------------------------
[SVG markup not captured in this snapshot - logo artwork, 126 lines]
--------------------------------------------------------------------------------
/media/minimal11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/media/minimal11.png
--------------------------------------------------------------------------------
/media/minimal11.svg:
--------------------------------------------------------------------------------
[SVG markup not captured in this snapshot - logo artwork, 126 lines]
--------------------------------------------------------------------------------
/media/minimal13.svg:
--------------------------------------------------------------------------------
[SVG markup not captured in this snapshot - logo artwork, 144 lines]
--------------------------------------------------------------------------------
/media/minimal7-plain-svg.svg:
--------------------------------------------------------------------------------
[SVG markup not captured in this snapshot - logo artwork, 98 lines]
--------------------------------------------------------------------------------
/media/minimal7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/media/minimal7.png
--------------------------------------------------------------------------------
/media/minimal7.svg:
--------------------------------------------------------------------------------
[SVG markup not captured in this snapshot - logo artwork, 123 lines]
--------------------------------------------------------------------------------
/media/minimal9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/media/minimal9.png
--------------------------------------------------------------------------------
/media/minimal9.svg:
--------------------------------------------------------------------------------
[SVG markup not captured in this snapshot - logo artwork, 126 lines]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.9.1
2 | aiosignal==1.3.1
3 | altair==5.2.0
4 | annotated-types==0.6.0
5 | anyio==4.2.0
6 | asgiref==3.7.2
7 | asttokens==2.4.1
8 | attrs==23.1.0
9 | backoff==2.2.1
10 | bcrypt==4.1.2
11 | beautifulsoup4==4.12.2
12 | blinker==1.7.0
13 | bs4==0.0.1
14 | build==1.2.1
15 | cachetools==5.3.2
16 | certifi==2023.11.17
17 | charset-normalizer==3.3.2
18 | chroma-hnswlib==0.7.5
19 | chromadb==0.5.4
20 | click==8.1.7
21 | colorama==0.4.6
22 | coloredlogs==15.0.1
23 | courlan==0.9.5
24 | dataclasses-json==0.6.3
25 | dateparser==1.2.0
26 | Deprecated==1.2.14
27 | distro==1.9.0
28 | docx2txt==0.8
29 | executing==2.0.1
30 | fake-useragent==1.4.0
31 | fastapi==0.108.0
32 | filelock==3.13.1
33 | flatbuffers==23.5.26
34 | frozenlist==1.4.1
35 | fsspec==2023.12.2
36 | gitdb==4.0.11
37 | GitPython==3.1.40
38 | google-api-core==2.15.0
39 | google-auth==2.25.2
40 | google-cloud-core==2.4.1
41 | google-cloud-firestore==2.14.0
42 | googleapis-common-protos==1.62.0
43 | greenlet==3.0.1
44 | grpcio==1.60.0
45 | grpcio-status==1.60.0
46 | h11==0.14.0
47 | htmldate==1.6.0
48 | httpcore==1.0.2
49 | httptools==0.6.1
50 | httpx==0.27.0
51 | 
huggingface-hub==0.20.1 52 | humanfriendly==10.0 53 | icecream==2.1.3 54 | idna==3.6 55 | importlib-metadata==6.11.0 56 | importlib-resources==6.1.1 57 | itsdangerous==2.1.2 58 | Jinja2==3.1.2 59 | jsonpatch==1.33 60 | jsonpointer==2.4 61 | jsonschema==4.20.0 62 | jsonschema-specifications==2023.12.1 63 | jusText==3.0.0 64 | kubernetes==28.1.0 65 | langchain==0.2.10 66 | langchain-cli==0.0.25 67 | langchain-community==0.2.9 68 | langchain-core==0.2.22 69 | langchain-openai==0.1.17 70 | langchain-text-splitters==0.2.2 71 | langcodes==3.3.0 72 | langserve==0.2.2 73 | langsmith==0.1.93 74 | libcst==1.4.0 75 | lxml==4.9.4 76 | markdown-it-py==3.0.0 77 | MarkupSafe==2.1.3 78 | marshmallow==3.20.1 79 | mdurl==0.1.2 80 | mmh3==4.0.1 81 | monotonic==1.6 82 | mpmath==1.3.0 83 | multidict==6.0.4 84 | mypy-extensions==1.0.0 85 | numpy==1.26.2 86 | oauthlib==3.2.2 87 | onnxruntime==1.16.3 88 | openai==1.36.1 89 | opentelemetry-api==1.22.0 90 | opentelemetry-exporter-otlp-proto-common==1.22.0 91 | opentelemetry-exporter-otlp-proto-grpc==1.22.0 92 | opentelemetry-instrumentation==0.43b0 93 | opentelemetry-instrumentation-asgi==0.43b0 94 | opentelemetry-instrumentation-fastapi==0.43b0 95 | opentelemetry-proto==1.22.0 96 | opentelemetry-sdk==1.22.0 97 | opentelemetry-semantic-conventions==0.43b0 98 | opentelemetry-util-http==0.43b0 99 | orjson==3.10.6 100 | overrides==7.4.0 101 | packaging==23.2 102 | pandas==2.1.4 103 | Pillow==10.1.0 104 | playwright==1.40.0 105 | posthog==3.1.0 106 | proto-plus==1.23.0 107 | protobuf==4.25.1 108 | pulsar-client==3.3.0 109 | pyarrow==14.0.2 110 | pyasn1==0.5.1 111 | pyasn1-modules==0.3.0 112 | pydantic==2.5.3 113 | pydantic_core==2.14.6 114 | pydeck==0.8.1b0 115 | pyee==11.0.1 116 | Pygments==2.17.2 117 | pypdf==4.0.0 118 | PyPika==0.48.9 119 | pyproject-toml==0.0.10 120 | pyproject_hooks==1.1.0 121 | pyreadline3==3.4.1 122 | python-dateutil==2.8.2 123 | python-dotenv==1.0.0 124 | python-multipart==0.0.9 125 | pytz==2023.3.post1 126 | PyYAML==6.0.1 127 | referencing==0.32.0 128 | regex==2023.12.25 129 | requests==2.31.0 130 | requests-oauthlib==1.3.1 131 | rich==13.7.0 132 | rpds-py==0.15.2 133 | rsa==4.9 134 | shellingham==1.5.4 135 | six==1.16.0 136 | smmap==5.0.1 137 | sniffio==1.3.0 138 | soupsieve==2.5 139 | SQLAlchemy==2.0.23 140 | sse-starlette==1.8.2 141 | starlette==0.32.0.post1 142 | streamlit==1.35.0 143 | sympy==1.12 144 | tenacity==8.2.3 145 | tiktoken==0.7.0 146 | tld==0.13 147 | tokenizers==0.15.0 148 | toml==0.10.2 149 | tomlkit==0.12.5 150 | toolz==0.12.0 151 | tornado==6.4 152 | tqdm==4.66.1 153 | trafilatura==1.6.3 154 | typer==0.9.0 155 | typing-inspect==0.9.0 156 | typing_extensions==4.9.0 157 | tzdata==2023.3 158 | tzlocal==5.2 159 | urllib3==1.26.18 160 | uvicorn==0.23.2 161 | validators==0.22.0 162 | watchdog==3.0.0 163 | watchfiles==0.21.0 164 | websocket-client==1.7.0 165 | websockets==12.0 166 | Werkzeug==3.0.1 167 | wrapt==1.16.0 168 | yarl==1.9.4 169 | zipp==3.17.0 170 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/utils/__init__.py -------------------------------------------------------------------------------- /utils/algo.py: -------------------------------------------------------------------------------- 1 | from itertools import zip_longest 2 | from typing import Any, Iterable, Iterator 3 | 4 | 5 | def 
interleave_iterables(iterables: Iterable[Iterable[Any]]) -> Iterator[Any]: 6 | """ 7 | Interleave elements from multiple iterables. 8 | 9 | This function takes an iterable of iterables (like lists, tuples, etc.) and yields elements 10 | in an interleaved manner: first yielding the first element of each iterable, then the second 11 | element of each, and so on. If an iterable is exhausted (shorter than others), it is skipped. 12 | 13 | Example: 14 | >>> list(interleave_iterables([[1, 2, 3], ('a', 'b'), [10, 20, 30, 40]])) 15 | [1, 'a', 10, 2, 'b', 20, 3, 30, 40] 16 | """ 17 | 18 | sentinel = object() # unique object to identify exhausted sublists 19 | 20 | for elements in zip_longest(*iterables, fillvalue=sentinel): 21 | for element in elements: 22 | if element is not sentinel: 23 | yield element 24 | 25 | 26 | def remove_duplicates_keep_order(iterable: Iterable): 27 | """ 28 | Remove duplicates from an iterable while keeping order. Returns a list. 29 | """ 30 | return list(dict.fromkeys(iterable)) 31 | 32 | 33 | def insert_interval( 34 | current_intervals: list[tuple[int, int]], new_interval: tuple[int, int] 35 | ) -> list[tuple[int, int]]: 36 | """ 37 | Add a new interval to a list of sorted non-overlapping intervals. If the new interval overlaps 38 | with any of the existing intervals, merge them. Return the resulting list of intervals. 39 | """ 40 | new_intervals = [] 41 | new_interval_start, new_interval_end = new_interval 42 | for i, (start, end) in enumerate(current_intervals): 43 | if end < new_interval_start: 44 | new_intervals.append((start, end)) 45 | elif start > new_interval_end: 46 | # Add the new interval and all remaining intervals 47 | new_intervals.append((new_interval_start, new_interval_end)) 48 | return new_intervals + current_intervals[i:] 49 | else: 50 | # Intervals overlap or are adjacent, merge them 51 | new_interval_start = min(start, new_interval_start) 52 | new_interval_end = max(end, new_interval_end) 53 | 54 | # If we're here, the new/merged interval is the last one 55 | new_intervals.append((new_interval_start, new_interval_end)) 56 | return new_intervals 57 | -------------------------------------------------------------------------------- /utils/async_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor 3 | 4 | # NOTE: consider using async_to_sync from asgiref.sync library 5 | 6 | 7 | def run_task_sync(task): 8 | """ 9 | Run an asyncio task (more precisely, a coroutine object, such as the result of 10 | calling an async function) synchronously. 11 | """ 12 | # print("-" * 100) 13 | # try: 14 | # print("=" * 100) 15 | # asyncio.get_running_loop() 16 | # print("*-*" * 100) 17 | # except RuntimeError as e: 18 | # print("?" * 100) 19 | # print(e) 20 | # print("?" * 100) 21 | # res = asyncio.run(task) # no preexisting event loop, so run in the main thread 22 | # print("!" * 100) 23 | # print(res) 24 | # print("!" * 100) 25 | # return res 26 | # TODO: consider a situation when there is a non-running loop 27 | 28 | with ThreadPoolExecutor(max_workers=1) as executor: 29 | future = executor.submit(asyncio.run, task) 30 | return future.result() 31 | 32 | 33 | def make_sync(async_func): 34 | """ 35 | Make an asynchronous function synchronous. 
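A minimal illustrative sketch (the coroutine name is hypothetical): async def fetch_answer(): return 42 fetch_answer_sync = make_sync(fetch_answer) assert fetch_answer_sync() == 42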
36 | """ 37 | 38 | def wrapper(*args, **kwargs): 39 | return run_task_sync(async_func(*args, **kwargs)) 40 | 41 | return wrapper 42 | 43 | 44 | def gather_tasks_sync(tasks): 45 | """ 46 | Run a list of asyncio tasks synchronously. 47 | """ 48 | 49 | async def coroutine_from_tasks(): 50 | return await asyncio.gather(*tasks) 51 | 52 | return run_task_sync(coroutine_from_tasks()) 53 | 54 | 55 | def execute_func_map_in_processes(func, inputs, max_workers=None): 56 | """ 57 | Execute a function on a list of inputs in a separate process for each input. 58 | """ 59 | with ProcessPoolExecutor(max_workers=max_workers) as executor: 60 | return list(executor.map(func, inputs)) 61 | 62 | def execute_func_map_in_threads(func, inputs, max_workers=None): 63 | """ 64 | Execute a function on a list of inputs in a separate thread for each input. 65 | """ 66 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 67 | return list(executor.map(func, inputs)) 68 | -------------------------------------------------------------------------------- /utils/debug.py: -------------------------------------------------------------------------------- 1 | 2 | from utils.prepare import get_logger 3 | from langchain_core.prompts import PromptTemplate 4 | 5 | logger = get_logger() 6 | 7 | 8 | def save_prompt_text_to_file( 9 | prompt: PromptTemplate, inputs: dict[str, str], file_path: str 10 | ) -> None: 11 | prompt_text = prompt.invoke(inputs).to_string() 12 | with open(file_path, "a", encoding="utf-8") as f: 13 | f.write("-" * 80 + "\n\n" + prompt_text) 14 | logger.debug(f"Saved prompt text to {file_path}") 15 | -------------------------------------------------------------------------------- /utils/docgrab.py: -------------------------------------------------------------------------------- 1 | import json 2 | import uuid 3 | from typing import Iterable 4 | 5 | from chromadb import ClientAPI 6 | from dotenv import load_dotenv 7 | from langchain_community.document_loaders import GitbookLoader 8 | 9 | from components.chroma_ddg import ChromaDDG 10 | from components.openai_embeddings_ddg import get_openai_embeddings 11 | from utils.prepare import EMBEDDINGS_DIMENSIONS, get_logger 12 | from utils.rag import rag_text_splitter 13 | from langchain_core.documents import Document 14 | 15 | load_dotenv(override=True) 16 | logger = get_logger() 17 | 18 | 19 | class JSONLDocumentLoader: 20 | def __init__(self, file_path: str, max_docs=None) -> None: 21 | self.file_path = file_path 22 | self.max_docs = max_docs 23 | 24 | def load(self) -> list[Document]: 25 | docs = load_docs_from_jsonl(self.file_path) 26 | if self.max_docs is None or self.max_docs > len(docs): 27 | return docs 28 | return docs[: self.max_docs] 29 | 30 | 31 | def save_docs_to_jsonl(docs: Iterable[Document], file_path: str) -> None: 32 | with open(file_path, "a") as f: 33 | for doc in docs: 34 | f.write(doc.json() + "\n") 35 | 36 | 37 | def load_docs_from_jsonl(file_path: str) -> list[Document]: 38 | docs = [] 39 | with open(file_path, "r") as f: 40 | for line in f: 41 | data = json.loads(line) 42 | doc = Document(**data) 43 | docs.append(doc) 44 | return docs 45 | 46 | 47 | def load_gitbook(root_url: str) -> list[Document]: 48 | all_pages_docs = GitbookLoader(root_url, load_all_paths=True).load() 49 | return all_pages_docs 50 | 51 | 52 | def prepare_chunks( 53 | texts: list[str], metadatas: list[dict], ids: list[str] 54 | ) -> list[Document]: 55 | """ 56 | Split documents into chunks and add parent ids to the chunks' metadata. 
57 | Returns a list of snippets (each is a Document). 58 | 59 | It is ok to pass an empty list of texts. 60 | """ 61 | logger.info(f"Splitting {len(texts)} documents into chunks...") 62 | 63 | # Add parent ids to metadata (will delete, but only after they propagate to snippets) 64 | for metadata, id in zip(metadatas, ids): 65 | metadata["parent_id"] = id 66 | 67 | # Split into snippets 68 | snippets = rag_text_splitter.create_documents(texts, metadatas) 69 | logger.info(f"Obtained {len(snippets)} chunks.") 70 | 71 | # Restore original metadata 72 | for metadata in metadatas: 73 | del metadata["parent_id"] 74 | 75 | return snippets 76 | 77 | 78 | FAKE_FULL_DOC_EMBEDDING = [1.0] * EMBEDDINGS_DIMENSIONS 79 | 80 | # TODO: remove the logic of saving to the db, leave only doc preparation. We should 81 | # separate concerns and reduce the number of places we write to the db. 82 | def ingest_into_chroma( 83 | docs: list[Document], 84 | *, 85 | collection_name: str, 86 | openai_api_key: str, 87 | chroma_client: ClientAPI | None = None, 88 | save_dir: str | None = None, 89 | collection_metadata: dict[str, str] | None = None, 90 | ) -> ChromaDDG: 91 | """ 92 | Load documents and/or collection metadata into a Chroma collection, return a vectorstore 93 | object. 94 | 95 | If collection_metadata is passed and the collection exists, the metadata will be 96 | replaced with the passed metadata, according to the Chroma docs. 97 | 98 | NOTE: Normally, the higher level agentblocks.collectionhelper.ingest_into_collection 99 | should be used, which creates/updates the "created_at" and "updated_at" metadata fields. 100 | """ 101 | assert bool(chroma_client) != bool(save_dir), "Invalid vector db destination" 102 | 103 | # Handle special case of no docs - just create/update collection with given metadata 104 | if not docs: 105 | return ChromaDDG( 106 | embedding_function=get_openai_embeddings(openai_api_key), 107 | client=chroma_client, 108 | persist_directory=save_dir, 109 | collection_name=collection_name, 110 | collection_metadata=collection_metadata, 111 | create_if_not_exists=True, 112 | ) 113 | 114 | # Prepare full texts, metadatas and ids 115 | full_doc_ids = [str(uuid.uuid4()) for _ in range(len(docs))] 116 | texts = [doc.page_content for doc in docs] 117 | metadatas = [doc.metadata for doc in docs] 118 | 119 | # Split into snippets, embed and add them 120 | vectorstore: ChromaDDG = ChromaDDG.from_documents( 121 | prepare_chunks(texts, metadatas, full_doc_ids), 122 | embedding=get_openai_embeddings(openai_api_key), 123 | client=chroma_client, 124 | persist_directory=save_dir, 125 | collection_name=collection_name, 126 | collection_metadata=collection_metadata, 127 | create_if_not_exists=True, # ok to pass (kwargs are passed to __init__) 128 | ) 129 | 130 | # Add the original full docs (with fake embeddings) 131 | # NOTE: should be possible to add everything in one call, with some work 132 | fake_embeddings = [FAKE_FULL_DOC_EMBEDDING for _ in range(len(docs))] 133 | vectorstore.collection.add(full_doc_ids, fake_embeddings, metadatas, texts) 134 | 135 | logger.info(f"Ingested documents into collection {collection_name}") 136 | if save_dir: 137 | logger.info(f"Saved to {save_dir}") 138 | return vectorstore 139 | 140 | 141 | if __name__ == "__main__": 142 | pass 143 | # download all pages from gitbook and save to jsonl 144 | # all_pages_docs = load_gitbook(GITBOOK_ROOT_URL) 145 | # print(f"Loaded {len(all_pages_docs)} documents") 146 | # save_docs_to_jsonl(all_pages_docs, "docs.jsonl") 147 | 148 | # load 
from jsonl 149 | # docs = load_docs_from_jsonl("docs.jsonl") 150 | # print(f"Loaded {len(docs)} documents") 151 | -------------------------------------------------------------------------------- /utils/filesystem.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def ensure_path_exists(path, is_directory=False): 5 | """ 6 | Ensure that a given path exists as either a file or a directory. 7 | 8 | Parameters: 9 | - path (str): The file or directory path to check or create. 10 | - is_directory (bool): Flag to indicate if the path should be a directory (True) or a file (False). 11 | 12 | This function checks if the given path exists. If it exists, it checks if it matches the type specified 13 | by `is_directory` (file or directory). If it does not exist, the function creates the necessary 14 | directories and, if required, an empty file at that path. 15 | 16 | Returns: 17 | - None 18 | """ 19 | # Check if the path exists 20 | if os.path.exists(path): 21 | # If it exists, check if it is the correct type (file or directory) 22 | if is_directory: 23 | if not os.path.isdir(path): 24 | raise ValueError(f"Path {path} exists but is not a directory as expected.") 25 | elif not os.path.isfile(path): 26 | raise ValueError(f"Path {path} exists but is not a file as expected.") 27 | else: 28 | # If the path does not exist, create it 29 | if is_directory: 30 | # Create all intermediate directories 31 | os.makedirs(path, exist_ok=True) 32 | else: 33 | # Create the directory path to the file if it doesn't exist 34 | os.makedirs(os.path.dirname(path), exist_ok=True) 35 | # Create the file 36 | with open(path, "w"): 37 | pass # Creating an empty file -------------------------------------------------------------------------------- /utils/ingest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import docx2txt 5 | from bs4 import BeautifulSoup 6 | from icecream import ic 7 | from pypdf import PdfReader 8 | from starlette.datastructures import UploadFile # err if "from fastapi" 9 | from streamlit.runtime.uploaded_file_manager import UploadedFile 10 | from langchain_core.documents import Document 11 | 12 | allowed_extensions = [ 13 | "", 14 | ".txt", 15 | ".md", 16 | ".rtf", 17 | ".log", 18 | ".pdf", 19 | ".docx", 20 | ".html", 21 | ".htm", 22 | ] 23 | 24 | 25 | def get_page_texts_from_pdf(file): 26 | reader = PdfReader(file) 27 | return [page.extract_text() for page in reader.pages] 28 | 29 | 30 | DEFAULT_PAGE_START = "PAGE {page_num}:\n" 31 | DEFAULT_PAGE_SEP = "\n" + "-" * 3 + "\n\n" 32 | 33 | 34 | def get_text_from_pdf(file, page_start=DEFAULT_PAGE_START, page_sep=DEFAULT_PAGE_SEP): 35 | return page_sep.join( 36 | [ 37 | page_start.replace("{page_num}", str(i)) + x 38 | for i, x in enumerate(get_page_texts_from_pdf(file), start=1) 39 | ] 40 | ) 41 | 42 | 43 | def extract_text(files, allow_all_ext): 44 | docs = [] 45 | failed_files = [] 46 | unsupported_ext_files = [] 47 | for file in files: 48 | if isinstance(file, UploadFile): 49 | file_name = file.filename or "unnamed-file" # need? 
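# (note) starlette's UploadFile wraps the underlying binary file object; grabbing .file here lets the rest of the loop treat FastAPI uploads and Streamlit uploads the same way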
50 | file = file.file 51 | else: 52 | file_name = file.name 53 | try: 54 | extension = os.path.splitext(file_name)[1] 55 | if not allow_all_ext and extension not in allowed_extensions: 56 | unsupported_ext_files.append(file_name) 57 | continue 58 | if extension == ".pdf": 59 | for i, text in enumerate(get_page_texts_from_pdf(file)): 60 | metadata = {"source": f"{file_name} (page {i + 1})"} 61 | docs.append(Document(page_content=text, metadata=metadata)) 62 | else: 63 | if extension == ".docx": 64 | text = docx2txt.process(file) 65 | elif extension in [".html", ".htm"]: 66 | soup = BeautifulSoup(file, "html.parser") 67 | # Remove script and style elements 68 | for script_or_style in soup(["script", "style"]): 69 | script_or_style.extract() 70 | text = soup.get_text() 71 | # Collapse runs of 2+ newlines into a single blank line 72 | text = re.sub(r"\n{2,}", "\n\n", text) 73 | 74 | else: 75 | # Treat as text file 76 | if isinstance(file, UploadedFile): 77 | text = file.getvalue().decode("utf-8") 78 | # NOTE: not sure what the advantage is for Streamlit's UploadedFile 79 | else: 80 | text = file.read().decode("utf-8") 81 | docs.append(Document(page_content=text, metadata={"source": file_name})) 82 | except Exception as e: 83 | ic(e) 84 | failed_files.append(file_name) 85 | 86 | ic(len(docs), failed_files, unsupported_ext_files) 87 | return docs, failed_files, unsupported_ext_files 88 | 89 | def format_ingest_failure(failed_files, unsupported_ext_files): 90 | res = ( 91 | "Apologies, the following files failed to process:\n```\n" 92 | + "\n".join(unsupported_ext_files + failed_files) 93 | + "\n```" 94 | ) 95 | if unsupported_ext_files: 96 | unsupported_extensions = sorted( 97 | list(set(os.path.splitext(x)[1] for x in unsupported_ext_files)) 98 | ) 99 | res += "\n\nUnsupported file extensions: " + ", ".join( 100 | f"`{x.lstrip('.') or ''}`" for x in unsupported_extensions 101 | ) 102 | return res -------------------------------------------------------------------------------- /utils/input.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable 2 | 3 | def get_menu_choice( 4 | choices: Iterable[str], 5 | prompt: str = "Please enter the desired number: ", 6 | default: int | None = None, 7 | ) -> int: 8 | """ 9 | Get a choice from the user from a list of choices. 10 | 11 | Returns the index of the choice selected by the user. Prompts the user 12 | repeatedly until a valid choice is entered. 13 | """ 14 | 15 | while True: 16 | # Print the choices with numbered identifiers 17 | for idx, choice_text in enumerate(choices, start=1): 18 | print(f"{idx}. {choice_text}") 19 | 20 | # Prompt for the user's choice 21 | user_input = input(prompt) 22 | 23 | # Use default if the user input is empty and default is set 24 | if user_input == "" and default is not None: 25 | return default 26 | 27 | # Check if the input is a valid choice 28 | try: 29 | chosen_idx = int(user_input) - 1 30 | except ValueError: 31 | chosen_idx = -1 32 | if 0 <= chosen_idx < len(choices): 33 | return chosen_idx 34 | print() 35 | 36 | def get_choice_from_dict_menu( 37 | menu: dict[Any, str], 38 | prompt: str = "Please enter the desired number: ", 39 | default: Any | None = None, 40 | ) -> Any: 41 | """ 42 | Get a choice from the user from a dictionary menu. 43 | 44 | Returns the key of the choice selected by the user. Prompts the user 45 | repeatedly until a valid choice is entered. 
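A minimal illustrative example (the menu entries are hypothetical): action = get_choice_from_dict_menu({"add": "Append", "del": "Delete"}, default="add") # pressing Enter with no input returns "add"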
46 | """ 47 | 48 | default_idx = None if default is None else -123 49 | idx = get_menu_choice(menu.values(), prompt, default_idx) 50 | return default if idx == -123 else list(menu.keys())[idx] 51 | -------------------------------------------------------------------------------- /utils/log.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import datetime as dt 3 | import json 4 | import logging 5 | import logging.config 6 | import logging.handlers 7 | import pathlib 8 | from logging.config import ConvertingDict, ConvertingList, valid_ident 9 | from logging.handlers import QueueHandler, QueueListener 10 | from queue import Queue 11 | 12 | from utils.filesystem import ensure_path_exists 13 | 14 | # from typing import override 15 | 16 | LOG_RECORD_BUILTIN_ATTRS = { 17 | "args", 18 | "asctime", 19 | "created", 20 | "exc_info", 21 | "exc_text", 22 | "filename", 23 | "funcName", 24 | "levelname", 25 | "levelno", 26 | "lineno", 27 | "module", 28 | "msecs", 29 | "message", 30 | "msg", 31 | "name", 32 | "pathname", 33 | "process", 34 | "processName", 35 | "relativeCreated", 36 | "stack_info", 37 | "thread", 38 | "threadName", 39 | "taskName", 40 | } 41 | 42 | 43 | # https://github.com/mCodingLLC/VideosSampleCode/blob/master/videos/135_modern_logging/mylogger.py 44 | class MyJSONFormatter(logging.Formatter): 45 | def __init__( 46 | self, 47 | *, 48 | fmt_keys: dict[str, str] | None = None, 49 | ): 50 | super().__init__() 51 | self.fmt_keys = fmt_keys if fmt_keys is not None else {} 52 | 53 | # @override 54 | def format(self, record: logging.LogRecord) -> str: 55 | message = self._prepare_log_dict(record) 56 | return json.dumps(message, default=str) 57 | 58 | def _prepare_log_dict(self, record: logging.LogRecord): 59 | always_fields = { 60 | "message": record.getMessage(), 61 | "timestamp": dt.datetime.fromtimestamp( 62 | record.created, tz=dt.timezone.utc 63 | ).isoformat(), 64 | } 65 | if record.exc_info is not None: 66 | always_fields["exc_info"] = self.formatException(record.exc_info) 67 | 68 | if record.stack_info is not None: 69 | always_fields["stack_info"] = self.formatStack(record.stack_info) 70 | 71 | message = { 72 | key: msg_val 73 | if (msg_val := always_fields.pop(val, None)) is not None 74 | else getattr(record, val) 75 | for key, val in self.fmt_keys.items() 76 | } 77 | message.update(always_fields) 78 | 79 | for key, val in record.__dict__.items(): 80 | if key not in LOG_RECORD_BUILTIN_ATTRS: 81 | message[key] = val 82 | 83 | return message 84 | 85 | 86 | class NonErrorFilter(logging.Filter): 87 | # @override 88 | def filter(self, record: logging.LogRecord) -> bool | logging.LogRecord: 89 | return record.levelno <= logging.INFO 90 | 91 | 92 | # https://rob-blackbourn.medium.com/how-to-use-python-logging-queuehandler-with-dictconfig-1e8b1284e27a 93 | def _resolve_handlers(h): 94 | """Converts a list of handlers specified in a JSON file to actual handler objects.""" 95 | if isinstance(h, ConvertingList): 96 | # Getting each item converts it to an actual object 97 | h = [h[i] for i in range(len(h))] 98 | return h 99 | 100 | 101 | def _resolve_convertingdict(q): 102 | """Resolves a dict specified in a JSON file to an actual object, if necessary.""" 103 | if not isinstance(q, ConvertingDict): 104 | return q 105 | if "__resolved_value__" in q: 106 | return q["__resolved_value__"] 107 | 108 | cname = q.pop("class") 109 | klass = q.configurator.resolve(cname) 110 | props = q.pop(".", None) 111 | kwargs = {k: q[k] for k in q if 
valid_ident(k)} 112 | result = klass(**kwargs) 113 | if props: 114 | for name, value in props.items(): 115 | setattr(result, name, value) 116 | 117 | q["__resolved_value__"] = result 118 | return result 119 | 120 | 121 | class QueueListenerHandler(QueueHandler): 122 | def __init__( 123 | self, handlers, respect_handler_level=False, auto_run=True, queue=Queue(0) 124 | ): 125 | queue = _resolve_convertingdict(queue) 126 | super().__init__(queue) 127 | handlers = _resolve_handlers(handlers) 128 | self._listener = QueueListener( 129 | self.queue, *handlers, respect_handler_level=respect_handler_level 130 | ) 131 | if auto_run: 132 | self.start() 133 | atexit.register(self.stop) 134 | 135 | def start(self): 136 | self._listener.start() 137 | 138 | def stop(self): 139 | self._listener.stop() 140 | 141 | def emit(self, record): # NOTE: remove? 142 | return super().emit(record) 143 | 144 | 145 | def setup_logging(log_level: str | None, log_format: str | None): 146 | ensure_path_exists("logs/ddg-logs.jsonl") 147 | with open(pathlib.Path("config/logging.json")) as f: 148 | config = json.load(f) 149 | if log_format: 150 | config["formatters"]["custom"] = { 151 | "format": log_format, 152 | "datefmt": "%Y-%m-%dT%H:%M:%S%z", 153 | } 154 | config["handlers"]["h1_stderr"]["formatter"] = "custom" 155 | 156 | if log_level: 157 | config["handlers"]["h1_stderr"]["level"] = log_level 158 | 159 | logging.config.dictConfig(config) 160 | -------------------------------------------------------------------------------- /utils/older/prompts-checkpoints.py: -------------------------------------------------------------------------------- 1 | iterative_report_improver_template = """\ 2 | You are ARIA, Advanced Report Improvement Assistant. 3 | 4 | For this task, you can use the following information retrieved from the Internet: 5 | 6 | {new_info} 7 | 8 | END OF RETRIEVED INFORMATION 9 | 10 | Your task: pinpoint areas of improvement in the report/answer prepared in response to the following query: 11 | 12 | {query} 13 | 14 | END OF QUERY. REPORT/ANSWER TO ANALYZE: 15 | 16 | {previous_report} 17 | 18 | END OF REPORT 19 | 20 | Please decide how specifically the RETRIEVED INFORMATION you were provided at the beginning can address any areas of improvement in the report: additions, corrections, deletions, perhaps even a complete rewrite. 21 | 22 | Please write: "ACTION ITEMS FOR IMPROVEMENT:" then provide a numbered list of the individual SOURCEs in the RETRIEVED INFORMATION: first the URL, then one brief sentence stating whether its CONTENT from above will be used to enhance the report - be as specific as possible and if that particular content is not useful then just write "NOT RELEVANT". 23 | 24 | Add one more item in your numbered list - any additional instructions you can think of for improving the report/answer, independent of the RETRIEVED INFORMATION, for example how to rearrange sections, what to remove, reword, etc. 25 | 26 | After that, write: "NEW REPORT:" and write a new report from scratch. Important: any action items you listed must be **fully** implemented in your report, in which case your report must necessarily be different from the original report. In fact, the new report can be completely different if needed, the only concern is to craft an informative, no-fluff answer to the user's query: 27 | 28 | {query} 29 | 30 | END OF QUERY. Your report/answer should be: {report_type}. 
31 | 32 | Finish with: "REPORT ASSESSMENT: X%", where X is your estimate of how well your new report serves the user's information need on a scale from 0% to 100%, based on their query. 33 | """ -------------------------------------------------------------------------------- /utils/output.py: -------------------------------------------------------------------------------- 1 | class ConditionalLogger: 2 | def __init__(self, verbose=True): 3 | self.verbose = verbose 4 | 5 | def log(self, message): 6 | if self.verbose: 7 | print(message) 8 | 9 | def log_no_newline(self, message): 10 | if self.verbose: 11 | print(message, end="") 12 | 13 | def log_error(self, message): 14 | print(message) 15 | 16 | 17 | def format_exception(e: Exception) -> str: 18 | """ 19 | Format an exception to include both its type and message. 20 | 21 | Args: 22 | e (Exception): The exception to be formatted. 23 | 24 | Returns: 25 | str: A string representation of the exception, including its type and message. 26 | 27 | Example: 28 | try: 29 | raise ValueError("An example error") 30 | except ValueError as e: 31 | formatted_exception = format_exception(e) 32 | print(formatted_exception) # Output: "ValueError: An example error" 33 | """ 34 | try: 35 | return f"{type(e).__name__}: {e}" 36 | except Exception: 37 | return f"Unknown error: {e}" 38 | -------------------------------------------------------------------------------- /utils/prepare.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | from dotenv import load_dotenv 6 | 7 | from utils.log import setup_logging 8 | 9 | load_dotenv(override=True) 10 | 11 | # Import chromadb while making it think sqlite3 has a new enough version. 12 | # This is necessary because the version of sqlite3 in Streamlit Share is too old. 13 | # We don't actually use sqlite3 in a Streamlit Share context, because we use HttpClient, 14 | # so it's fine to use the app there despite ending up with an incompatible sqlite3. 
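# (illustrative alternative, not used here) some deployments instead swap in the pysqlite3-binary package: __import__("pysqlite3"); sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")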
15 | sys.modules["sqlite3"] = lambda: None 16 | sys.modules["sqlite3"].sqlite_version_info = (42, 42, 42) 17 | __import__("chromadb") 18 | del sys.modules["sqlite3"] 19 | __import__("sqlite3") # import here because chromadb was supposed to import it 20 | 21 | # Set up logging 22 | LOG_LEVEL = os.getenv("LOG_LEVEL") 23 | LOG_FORMAT = os.getenv("LOG_FORMAT") 24 | setup_logging(LOG_LEVEL, LOG_FORMAT) 25 | DEFAULT_LOGGER_NAME = os.getenv("DEFAULT_LOGGER_NAME", "ddg") 26 | 27 | 28 | def get_logger(logger_name: str = DEFAULT_LOGGER_NAME): 29 | return logging.getLogger(logger_name) 30 | 31 | 32 | # Set up the environment variables 33 | DEFAULT_OPENAI_API_KEY = os.getenv("DEFAULT_OPENAI_API_KEY", "") 34 | IS_AZURE = bool(os.getenv("OPENAI_API_BASE") or os.getenv("AZURE_OPENAI_API_KEY")) 35 | EMBEDDINGS_DEPLOYMENT_NAME = os.getenv("EMBEDDINGS_DEPLOYMENT_NAME") 36 | CHAT_DEPLOYMENT_NAME = os.getenv("CHAT_DEPLOYMENT_NAME") 37 | 38 | DEFAULT_COLLECTION_NAME = os.getenv("DEFAULT_COLLECTION_NAME", "docdocgo-documentation") 39 | 40 | if USE_CHROMA_VIA_HTTP := bool(os.getenv("USE_CHROMA_VIA_HTTP")): 41 | os.environ["CHROMA_API_IMPL"] = "rest" 42 | 43 | # The following three variables are only used if USE_CHROMA_VIA_HTTP is True 44 | CHROMA_SERVER_HOST = os.getenv("CHROMA_SERVER_HOST", "localhost") 45 | CHROMA_SERVER_HTTP_PORT = os.getenv("CHROMA_SERVER_HTTP_PORT", "8000") 46 | CHROMA_SERVER_AUTHN_CREDENTIALS = os.getenv("CHROMA_SERVER_AUTHN_CREDENTIALS", "") 47 | 48 | # The following variable is only used if USE_CHROMA_VIA_HTTP is False 49 | VECTORDB_DIR = os.getenv("VECTORDB_DIR", "chroma/") 50 | 51 | MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini") # rename to DEFAULT_MODEL? 52 | CONTEXT_LENGTH = int(os.getenv("CONTEXT_LENGTH", 16000)) # it's actually more like max 53 | # size of what we think we can feed to the model so that it doesn't get overwhelmed 54 | TEMPERATURE = float(os.getenv("TEMPERATURE", 0.3)) 55 | 56 | ALLOWED_MODELS = os.getenv("ALLOWED_MODELS", MODEL_NAME).split(",") 57 | ALLOWED_MODELS = [model.strip() for model in ALLOWED_MODELS] 58 | if MODEL_NAME not in ALLOWED_MODELS: 59 | raise ValueError("The default model must be in the list of allowed models.") 60 | 61 | EMBEDDINGS_MODEL_NAME = os.getenv("EMBEDDINGS_MODEL_NAME", "text-embedding-3-large") 62 | EMBEDDINGS_DIMENSIONS = int(os.getenv("EMBEDDINGS_DIMENSIONS", 3072)) 63 | 64 | LLM_REQUEST_TIMEOUT = float(os.getenv("LLM_REQUEST_TIMEOUT", 9)) 65 | 66 | DEFAULT_MODE = os.getenv("DEFAULT_MODE", "/kb") 67 | 68 | INCLUDE_ERROR_IN_USER_FACING_ERROR_MSG = bool( 69 | os.getenv("INCLUDE_ERROR_IN_USER_FACING_ERROR_MSG") 70 | ) 71 | 72 | BYPASS_SETTINGS_RESTRICTIONS = bool(os.getenv("BYPASS_SETTINGS_RESTRICTIONS")) 73 | BYPASS_SETTINGS_RESTRICTIONS_PASSWORD = os.getenv( 74 | "BYPASS_SETTINGS_RESTRICTIONS_PASSWORD" 75 | ) 76 | 77 | DOMAIN_NAME_FOR_SHARING = os.getenv("DOMAIN_NAME_FOR_SHARING", "shared") 78 | 79 | MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 100 * 1024 * 1024)) 80 | 81 | INITIAL_TEST_QUERY_STREAMLIT = os.getenv("INITIAL_QUERY_STREAMLIT") 82 | 83 | # Check that the necessary environment variables are set 84 | DUMMY_OPENAI_API_KEY_PLACEHOLDER = "DUMMY NON-EMPTY VALUE" 85 | 86 | if IS_AZURE and not ( 87 | EMBEDDINGS_DEPLOYMENT_NAME 88 | and CHAT_DEPLOYMENT_NAME 89 | and os.getenv("AZURE_OPENAI_API_KEY") 90 | and os.getenv("OPENAI_API_BASE") 91 | ): 92 | raise ValueError( 93 | "You have set some but not all environment variables necessary to utilize the " 94 | "Azure OpenAI API endpoint. 
Please refer to .env.example for details." 95 | ) 96 | elif not IS_AZURE and not DEFAULT_OPENAI_API_KEY: 97 | # We don't exit because we could get the key from the Streamlit app 98 | print( 99 | "WARNING: You have not set the DEFAULT_OPENAI_API_KEY environment variable. " 100 | "This is ok when running the Streamlit app, but not when running " 101 | "the command line app. For now, we will set it to a dummy non-empty value " 102 | "to avoid problems initializing the vectorstore etc. " 103 | "Please refer to .env.example for additional information." 104 | ) 105 | os.environ["DEFAULT_OPENAI_API_KEY"] = DUMMY_OPENAI_API_KEY_PLACEHOLDER 106 | DEFAULT_OPENAI_API_KEY = DUMMY_OPENAI_API_KEY_PLACEHOLDER 107 | # TODO investigate the behavior when this happens 108 | 109 | if not os.getenv("SERPER_API_KEY") and not os.getenv("IGNORE_LACK_OF_SERPER_API_KEY"): 110 | raise ValueError( 111 | "You have not set the SERPER_API_KEY environment variable, " 112 | "which is necessary for the Internet search functionality. " 113 | "Please set the SERPER_API_KEY environment variable to your Google Serper API key, " 114 | "which you can get for free, with no credit card, at https://serper.dev. " 115 | "This free key will allow you to make about 1250 searches until payment is required.\n\n" 116 | "If you want to suppress this error, set the IGNORE_LACK_OF_SERPER_API_KEY environment " 117 | "variable to any non-empty value. You can then use features that do not require the " 118 | "Internet search functionality." 119 | ) 120 | 121 | # Verify the validity of the db path 122 | if not os.getenv("USE_CHROMA_VIA_HTTP") and not os.path.isdir(VECTORDB_DIR): 123 | try: 124 | abs_path = os.path.abspath(VECTORDB_DIR) 125 | except Exception: 126 | abs_path = "INVALID PATH" 127 | raise ValueError( 128 | "You have not specified a valid directory for the vector database. " 129 | "Please set the VECTORDB_DIR environment variable in .env, as shown in .env.example. " 130 | "Alternatively, if you have a Chroma DB server running, you can set the " 131 | "USE_CHROMA_VIA_HTTP environment variable to any non-empty value. " 132 | f"\n\nThe path you have specified is: {VECTORDB_DIR}.\n" 133 | f"The absolute path resolves to: {abs_path}." 134 | ) 135 | 136 | 137 | is_env_loaded = True 138 | -------------------------------------------------------------------------------- /utils/rag.py: -------------------------------------------------------------------------------- 1 | from langchain_text_splitters import RecursiveCharacterTextSplitter 2 | 3 | rag_text_splitter = RecursiveCharacterTextSplitter( 4 | chunk_size=400, 5 | chunk_overlap=40, 6 | add_start_index=True, # metadata will include start index of snippet in original doc 7 | ) 8 | -------------------------------------------------------------------------------- /utils/streamlit/fix_event_loop.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | 4 | def remove_tornado_fix() -> None: 5 | """ 6 | UNDOES THE FOLLOWING OBSOLETE FIX IN streamlit.web.bootstrap.run(): 7 | Set default asyncio policy to be compatible with Tornado 6. 8 | 9 | Tornado 6 (at least) is not compatible with the default 10 | asyncio implementation on Windows. So here we 11 | pick the older SelectorEventLoopPolicy when the OS is Windows 12 | if the known-incompatible default policy is in use. 
13 | 14 | This has to happen as early as possible to make it a low priority and 15 | overridable 16 | 17 | See: https://github.com/tornadoweb/tornado/issues/2608 18 | 19 | FIXME: if/when tornado supports the defaults in asyncio, 20 | remove and bump tornado requirement for py38 21 | """ 22 | try: 23 | from asyncio import ( # type: ignore[attr-defined] 24 | WindowsProactorEventLoopPolicy, 25 | WindowsSelectorEventLoopPolicy, 26 | ) 27 | except ImportError: 28 | pass 29 | # Not affected 30 | else: 31 | if type(asyncio.get_event_loop_policy()) is WindowsSelectorEventLoopPolicy: 32 | asyncio.set_event_loop_policy(WindowsProactorEventLoopPolicy()) 33 | -------------------------------------------------------------------------------- /utils/streamlit/helpers.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | from typing import Any 4 | 5 | import streamlit as st 6 | from pydantic import BaseModel 7 | 8 | from utils.type_utils import ChatMode 9 | 10 | WELCOME_TEMPLATE_COLL_IN_URL = """\ 11 | Welcome! You are now using the `{}` collection. \ 12 | When chatting, I will use its content as my knowledge base. \ 13 | To get help using me, just type `/help `.\ 14 | """ 15 | 16 | NO_AUTH_TEMPLATE_COLL_IN_URL = """\ 17 | Apologies, the current URL doesn't provide access to the \ 18 | requested collection `{}`. This can happen if the \ 19 | access code is invalid or if the collection doesn't exist. 20 | 21 | I have switched to my default collection. 22 | 23 | To get help using me, just type `/help `.\ 24 | """ 25 | 26 | WELCOME_MSG_NEW_CREDENTIALS = """\ 27 | Welcome! The user credentials have changed but you are still \ 28 | authorised for the current collection.\ 29 | """ 30 | 31 | NO_AUTH_MSG_NEW_CREDENTIALS = """\ 32 | Welcome! The user credentials have changed, \ 33 | so I've switched to the default collection.\ 34 | """ 35 | 36 | POST_INGEST_MESSAGE_TEMPLATE_NEW_COLL = """\ 37 | Your documents have been uploaded to a new collection: `{coll_name}`. \ 38 | You can now use this collection in your queries. If you are done uploading, \ 39 | you can rename it: 40 | ``` 41 | /db rename my-cool-collection-name 42 | ``` 43 | """ 44 | 45 | POST_INGEST_MESSAGE_TEMPLATE_EXISTING_COLL = """\ 46 | Your documents have been added to the existing collection: `{coll_name}`. 
\ 47 | If you are done uploading, you can rename it: 48 | ``` 49 | /db rename my-cool-collection-name 50 | ``` 51 | """ 52 | 53 | mode_option_to_prefix = { 54 | "/kb (main mode)": ( 55 | "", # TODO: this presupposes that DEFAULT_MODE is /kb 56 | "Chat using the current collection as a knowledge base.", 57 | ), 58 | "/research": ( 59 | "/re ", 60 | "Research a topic on the web.", 61 | ), 62 | "/research heatseek": ( 63 | "/re hs ", 64 | "Find sites that have specific information you need.", 65 | ), 66 | "/summarize": ( 67 | "/su ", 68 | "Summarize the content of a URL and ingest into a collection.", 69 | ), 70 | "/ingest": ( 71 | "/in ", 72 | "Ingest content from a URL into a collection.", 73 | ), 74 | "/db": ( 75 | "/db ", 76 | "Manage your collections.", 77 | ), 78 | "/export": ( 79 | "/ex ", 80 | "Export your current chat history.", 81 | ), 82 | "/share": ( 83 | "/sh ", 84 | "Create a shareable link to the current collection.", 85 | ), 86 | "/web": ( 87 | "/we ", 88 | "Do quick web research without ingesting the content.", 89 | ), 90 | "/details": ( 91 | "/de ", 92 | "Provide details on everything retrieved from the collection in response to your query.", 93 | ), 94 | "/quotes": ( 95 | "/qu ", 96 | "Provide quotes from the collection relevant to your query.", 97 | ), 98 | "/chat": ( 99 | "/ch ", 100 | "Regular chat, without retrieving information from the current collection.", 101 | ), 102 | } 103 | mode_options = list(mode_option_to_prefix.keys()) 104 | 105 | chat_with_docs_status_config = { 106 | "thinking.header": "One sec...", 107 | "thinking.body": "Retrieving sources and composing reply...", 108 | "complete.header": "Done!", 109 | "complete.body": "Reply composed.", 110 | "error.header": "Error.", 111 | "error.body": "Apologies, there was an error.", 112 | } 113 | 114 | just_chat_status_config = chat_with_docs_status_config | { 115 | "thinking.body": "Composing reply...", 116 | } 117 | 118 | web_status_config = chat_with_docs_status_config | { 119 | "thinking.header": "Doing Internet research (takes 10-30s)...", 120 | "thinking.body": "Retrieving content from websites and composing report...", 121 | "complete.body": "Report composed.", 122 | } 123 | 124 | research_status_config = chat_with_docs_status_config | { 125 | "thinking.header": "Doing Internet research (takes 10-30s)...", 126 | "thinking.body": "Retrieving content from websites and composing report...", 127 | "complete.body": "Report composed and sources added to the document collection.", 128 | } 129 | 130 | ingest_status_config = chat_with_docs_status_config | { 131 | "thinking.header": "Ingesting...", 132 | "thinking.body": "Retrieving and ingesting content...", 133 | "complete.body": "Content was added to the document collection.", 134 | } 135 | 136 | summarize_status_config = chat_with_docs_status_config | { 137 | "thinking.header": "Summarizing and ingesting...", 138 | "thinking.body": "Retrieving, summarizing, and ingesting content...", 139 | "complete.body": "Summary composed and retrieved content added to the document collection.", 140 | } 141 | 142 | status_config = { 143 | ChatMode.JUST_CHAT_COMMAND_ID: just_chat_status_config, 144 | ChatMode.CHAT_WITH_DOCS_COMMAND_ID: chat_with_docs_status_config, 145 | ChatMode.DETAILS_COMMAND_ID: chat_with_docs_status_config, 146 | ChatMode.QUOTES_COMMAND_ID: chat_with_docs_status_config, 147 | ChatMode.HELP_COMMAND_ID: chat_with_docs_status_config, 148 | ChatMode.WEB_COMMAND_ID: web_status_config, 149 | ChatMode.RESEARCH_COMMAND_ID: research_status_config, 150 | 
ChatMode.INGEST_COMMAND_ID: ingest_status_config, 151 | ChatMode.SUMMARIZE_COMMAND_ID: summarize_status_config, 152 | } 153 | 154 | STAND_BY_FOR_INGESTION_MESSAGE = ( 155 | "\n\n--- \n\n> **PLEASE STAND BY WHILE CONTENT IS INGESTED...**" 156 | ) 157 | 158 | 159 | def get_init_msg(is_initial_load, is_authorised, coll_name_in_url, init_coll_name): 160 | if not is_initial_load: # key changed without reloading the page 161 | if is_authorised: 162 | return WELCOME_MSG_NEW_CREDENTIALS 163 | else: 164 | return NO_AUTH_MSG_NEW_CREDENTIALS 165 | elif coll_name_in_url: # page (re)load with coll in URL 166 | if is_authorised: 167 | return WELCOME_TEMPLATE_COLL_IN_URL.format(init_coll_name) 168 | else: 169 | return NO_AUTH_TEMPLATE_COLL_IN_URL.format(init_coll_name) 170 | else: # page (re)load without coll in URL 171 | return None # just to be explicit 172 | 173 | 174 | def update_url(params: dict[str, str]): 175 | st.query_params.clear() # NOTE: should be a way to set them in one go 176 | st.query_params.update(params) 177 | 178 | 179 | def update_url_if_scheduled(): 180 | if st.session_state.update_query_params is not None: 181 | update_url(st.session_state.update_query_params) 182 | st.session_state.update_query_params = None 183 | 184 | 185 | def escape_dollars(text: str) -> str: 186 | """ 187 | Escape dollar signs in the text that come before numbers. 188 | """ 189 | return re.sub(r"\$(?=\d)", r"\$", text) 190 | 191 | 192 | def fix_markdown(text: str) -> str: 193 | """ 194 | Escape dollar signs in the text that come before numbers and add 195 | two spaces before every newline. 196 | """ 197 | return re.sub(r"\$(?=\d)", r"\$", text).replace("\n", " \n") 198 | 199 | 200 | def write_slowly(message_placeholder, answer, delay=None): 201 | """Write a message to the message placeholder not all at once but word by word.""" 202 | pieces = answer.split(" ") 203 | delay = delay or min(0.025, 10 / (len(pieces) + 1)) 204 | for i in range(1, len(pieces) + 1): 205 | message_placeholder.markdown(fix_markdown(" ".join(pieces[:i]))) 206 | time.sleep(delay) 207 | 208 | 209 | def show_sources( 210 | sources: list[str] | None, 211 | callback_handler=None, 212 | ): 213 | """Show the sources if present.""" 214 | # If the cb handler is provided, remove the stand-by message 215 | if callback_handler and callback_handler.end_str_printed: 216 | callback_handler.container.markdown(fix_markdown(callback_handler.buffer)) 217 | 218 | if not sources: 219 | return 220 | with st.expander("Sources"): 221 | for source in sources: 222 | st.markdown(source) 223 | 224 | 225 | def show_uploader(is_teleporting=False, border=True): 226 | if is_teleporting: 227 | try: 228 | st.session_state.uploader_placeholder.empty() 229 | except AttributeError: 230 | pass # should never happen, but just in case 231 | # Switch between one of the two possible keys (to avoid duplicate key error) 232 | st.session_state.uploader_form_key = ( 233 | "uploader-form-alt" 234 | if st.session_state.uploader_form_key == "uploader-form" 235 | else "uploader-form" 236 | ) 237 | 238 | # Show the uploader 239 | st.session_state.uploader_placeholder = st.empty() 240 | with st.session_state.uploader_placeholder: 241 | with st.form( 242 | st.session_state.uploader_form_key, clear_on_submit=True, border=border 243 | ): 244 | files = st.file_uploader( 245 | "Upload your documents", 246 | accept_multiple_files=True, 247 | label_visibility="collapsed", 248 | ) 249 | cols = st.columns([1, 1]) 250 | with cols[0]: 251 | is_submitted = st.form_submit_button("Upload") 252 | with 
cols[1]: 253 | allow_all_ext = st.toggle("Allow all extensions", value=False) 254 | return (files if is_submitted else []), allow_all_ext 255 | 256 | 257 | class DownloaderData(BaseModel): 258 | data: Any 259 | file_name: str 260 | mime: str = "text/plain" 261 | 262 | 263 | def show_downloader( 264 | downloader_data: DownloaderData | None = None, is_teleporting=False 265 | ): 266 | if downloader_data is None: 267 | try: 268 | # Use the data from the already existing downloader 269 | downloader_data = st.session_state.downloader_data 270 | except AttributeError: 271 | raise ValueError("downloader_data not provided and not in session state") 272 | else: 273 | st.session_state.downloader_data = downloader_data 274 | 275 | if is_teleporting: 276 | try: 277 | st.session_state.downloader_placeholder.empty() 278 | except AttributeError: 279 | pass # should never happen, but just in case 280 | 281 | # Switch between one of the two possible keys (to avoid duplicate key error) 282 | st.session_state.downloader_form_key = ( 283 | "downloader-alt" 284 | if st.session_state.downloader_form_key == "downloader" 285 | else "downloader" 286 | ) 287 | 288 | # Show the downloader 289 | st.session_state.downloader_placeholder = st.empty() 290 | with st.session_state.downloader_placeholder: 291 | is_downloaded = st.download_button( 292 | label="Download Exported Data", 293 | data=downloader_data.data, 294 | file_name=downloader_data.file_name, 295 | mime=downloader_data.mime, 296 | key=st.session_state.downloader_form_key, 297 | ) 298 | return is_downloaded 299 | -------------------------------------------------------------------------------- /utils/streamlit/ingest.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import streamlit as st 4 | 5 | from agentblocks.collectionhelper import ingest_into_collection 6 | from agents.dbmanager import ( 7 | get_full_collection_name, 8 | get_user_facing_collection_name, 9 | ) 10 | from utils.chat_state import ChatState 11 | from utils.helpers import ADDITIVE_COLLECTION_PREFIX, INGESTED_DOCS_INIT_PREFIX 12 | from utils.prepare import DEFAULT_COLLECTION_NAME 13 | from utils.query_parsing import IngestCommand 14 | from utils.streamlit.helpers import ( 15 | POST_INGEST_MESSAGE_TEMPLATE_EXISTING_COLL, 16 | POST_INGEST_MESSAGE_TEMPLATE_NEW_COLL, 17 | ) 18 | from langchain_core.documents import Document 19 | 20 | 21 | def ingest_docs(docs: list[Document], chat_state: ChatState): 22 | ingest_command = chat_state.parsed_query.ingest_command 23 | 24 | uploaded_docs_coll_name_as_shown = get_user_facing_collection_name( 25 | chat_state.user_id, chat_state.vectorstore.name 26 | ) # might reassign below 27 | if chat_state.vectorstore.name != DEFAULT_COLLECTION_NAME and ( 28 | ingest_command == IngestCommand.ADD 29 | or ( 30 | ingest_command == IngestCommand.DEFAULT 31 | and uploaded_docs_coll_name_as_shown.startswith(ADDITIVE_COLLECTION_PREFIX) 32 | ) 33 | ): 34 | # We will use the same collection 35 | is_new_collection = False 36 | uploaded_docs_coll_name_full = chat_state.vectorstore.name 37 | 38 | # Fetch the collection metadata so we can add timestamps 39 | collection_metadata = chat_state.fetch_collection_metadata() 40 | else: 41 | # We will need to create a new collection 42 | is_new_collection = True 43 | collection_metadata = {} 44 | uploaded_docs_coll_name_as_shown = ( 45 | INGESTED_DOCS_INIT_PREFIX + uuid.uuid4().hex[:8] 46 | ) 47 | uploaded_docs_coll_name_full = get_full_collection_name( 48 | chat_state.user_id, 
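# (note) the full collection name embeds the user id, namespacing collections per user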
uploaded_docs_coll_name_as_shown 49 | ) 50 | try: 51 | vectorstore = ingest_into_collection( 52 | collection_name=uploaded_docs_coll_name_full, 53 | docs=docs, 54 | collection_metadata=collection_metadata, 55 | chat_state=chat_state, 56 | is_new_collection=is_new_collection, 57 | ) 58 | if is_new_collection: 59 | # Switch to the newly created collection 60 | chat_state.vectorstore = vectorstore 61 | msg_template = POST_INGEST_MESSAGE_TEMPLATE_NEW_COLL 62 | else: 63 | msg_template = POST_INGEST_MESSAGE_TEMPLATE_EXISTING_COLL 64 | 65 | with st.chat_message("assistant", avatar=st.session_state.bot_avatar): 66 | st.markdown(msg_template.format(coll_name=uploaded_docs_coll_name_as_shown)) 67 | except Exception as e: 68 | with st.chat_message("assistant", avatar=st.session_state.bot_avatar): 69 | st.markdown( 70 | f"Apologies, an error occurred during ingestion:\n```\n{e}\n```" 71 | ) 72 | -------------------------------------------------------------------------------- /utils/streamlit/prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import streamlit as st 4 | 5 | from components.llm import CallbackHandlerDDGConsole 6 | from docdocgo import do_intro_tasks 7 | from utils.chat_state import ChatState 8 | from utils.prepare import DEFAULT_OPENAI_API_KEY, DUMMY_OPENAI_API_KEY_PLACEHOLDER 9 | from utils.streamlit.fix_event_loop import remove_tornado_fix 10 | from utils.type_utils import OperationMode 11 | from utils.streamlit.helpers import mode_options 12 | 13 | 14 | def prepare_app(): 15 | if os.getenv("STREAMLIT_SCHEDULED_MAINTENANCE"): 16 | st.markdown("Welcome to DocDocGo! 🦉") 17 | st.markdown( 18 | "Scheduled maintenance is currently in progress. Please check back later." 19 | ) 20 | st.stop() 21 | 22 | # Flag for whether or not the OpenAI API key has succeeded at least once 23 | st.session_state.llm_api_key_ok_status = False 24 | 25 | print("query params:", st.query_params) 26 | st.session_state.update_query_params = None 27 | st.session_state.init_collection_name = st.query_params.get("collection") 28 | st.session_state.access_code = st.query_params.get("access_code") 29 | try: 30 | remove_tornado_fix() 31 | vectorstore = do_intro_tasks(openai_api_key=DEFAULT_OPENAI_API_KEY) 32 | except Exception as e: 33 | st.error( 34 | "Apologies, I could not load the vector database. This " 35 | "could be due to a misconfiguration of the environment variables " 36 | f"or missing files. 
The error reads: \n\n{e}" 37 | ) 38 | st.stop() 39 | 40 | st.session_state.chat_state = ChatState( 41 | operation_mode=OperationMode.STREAMLIT, 42 | vectorstore=vectorstore, 43 | callbacks=[ 44 | CallbackHandlerDDGConsole(), 45 | "placeholder for CallbackHandlerDDGStreamlit", 46 | ], 47 | openai_api_key=DEFAULT_OPENAI_API_KEY, 48 | ) 49 | 50 | st.session_state.prev_supplied_openai_api_key = None 51 | st.session_state.default_openai_api_key = DEFAULT_OPENAI_API_KEY 52 | if st.session_state.default_openai_api_key == DUMMY_OPENAI_API_KEY_PLACEHOLDER: 53 | st.session_state.default_openai_api_key = "" 54 | 55 | st.session_state.idx_file_upload = -1 56 | st.session_state.uploader_form_key = "uploader-form" 57 | 58 | st.session_state.idx_file_download = -1 59 | st.session_state.downloader_form_key = "downloader" 60 | 61 | st.session_state.user_avatar = os.getenv("USER_AVATAR") or None 62 | st.session_state.bot_avatar = os.getenv("BOT_AVATAR") or None # TODO: document this 63 | 64 | SAMPLE_QUERIES = os.getenv( 65 | "SAMPLE_QUERIES", 66 | "/research heatseek Code for row of buttons Streamlit, /research What are the biggest AI news this month?, /help How does infinite research work?", 67 | ) 68 | st.session_state.sample_queries = [q.strip() for q in SAMPLE_QUERIES.split(",")] 69 | st.session_state.default_mode = mode_options[0] -------------------------------------------------------------------------------- /utils/strings.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from typing import Iterable 4 | 5 | from utils.type_utils import JSONish 6 | 7 | 8 | def split_preserving_whitespace(text: str) -> tuple[list[str], list[str]]: 9 | """Split a string into parts, preserving whitespace. The result will 10 | be a tuple of two lists: the parts and the whitespaces. The whitespaces 11 | will be strings of whitespace characters (spaces, tabs, newlines, etc.). 12 | The original string can be reconstructed by joining the parts and the 13 | whitespaces, as follows: 14 | ''.join([w + p for p, w in zip(parts, whitespaces)]) + whitespaces[-1] 15 | 16 | In other words, the first whitespace is assumed to come before the first 17 | part, and the last whitespace is assumed to come after the last part. If needed, 18 | the first and/or last whitespace element can be an empty string. 19 | 20 | Each part is guaranteed to be non-empty. 21 | 22 | Returns: 23 | A tuple of two lists: the parts and the whitespaces. 24 | """ 25 | parts = [] 26 | whitespaces = [] 27 | start = 0 28 | for match in re.finditer(r"\S+", text): 29 | end = match.start() 30 | parts.append(match.group()) 31 | whitespaces.append(text[start:end]) 32 | start = match.end() 33 | whitespaces.append(text[start:]) 34 | return parts, whitespaces 35 | 36 | 37 | def remove_consecutive_blank_lines( 38 | lines: list[str], max_consecutive_blank_lines=1 39 | ) -> list[str]: 40 | """Remove consecutive blank lines from a list of lines.""" 41 | new_lines = [] 42 | num_consecutive_blank_lines = 0 43 | for line in lines: 44 | if line: 45 | # Non-blank line 46 | num_consecutive_blank_lines = 0 47 | new_lines.append(line) 48 | else: 49 | # Blank line 50 | num_consecutive_blank_lines += 1 51 | if num_consecutive_blank_lines <= max_consecutive_blank_lines: 52 | new_lines.append(line) 53 | return new_lines 54 | 55 | 56 | def limit_number_of_words(text: str, max_words: int) -> tuple[str, int]: 57 | """ 58 | Limit the number of words in a text to a given number. 
71 | 
72 | 
73 | def limit_num_characters(text: str, max_characters: int) -> str:
74 |     """
75 |     Limit the number of characters in a string to a given number. If the string
76 |     is longer than the given number of characters, truncate it and append an
77 |     ellipsis.
78 |     """
79 |     if len(text) <= max_characters:
80 |         return text
81 |     if max_characters > 0:
82 |         return text[: max_characters - 1] + "…"
83 |     return ""
84 | 
85 | 
86 | def extract_json(text: str) -> JSONish:
87 |     """
88 |     Extract a single JSON object or array from a string. Determines the first
89 |     occurrence of '{' or '[', and the last occurrence of '}' or ']', then
90 |     extracts the JSON structure accordingly. Returns a dictionary or list on
91 |     success; raises on parse errors, including a bracket mismatch.
92 |     """
93 |     length = len(text)
94 |     first_curly_brace = text.find("{")
95 |     last_curly_brace = text.rfind("}")
96 |     first_square_brace = text.find("[")
97 |     last_square_brace = text.rfind("]")
98 | 
99 |     assert (
100 |         first_curly_brace + first_square_brace > -2
101 |     ), "No opening curly or square bracket found"
102 | 
103 |     if first_curly_brace == -1:
104 |         first_curly_brace = length
105 |     elif first_square_brace == -1:
106 |         first_square_brace = length
107 | 
108 |     assert (first_curly_brace < first_square_brace) == (
109 |         last_curly_brace > last_square_brace
110 |     ), "Mismatched curly and square brackets"
111 | 
112 |     first = min(first_curly_brace, first_square_brace)
113 |     last = max(last_curly_brace, last_square_brace)
114 | 
115 |     assert first < last, "No closing bracket found"
116 | 
117 |     return json.loads(text[first : last + 1])
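# A minimal usage sketch for extract_json (hypothetical example, not part of
# the repository file):
#
#     extract_json('The model replied: {"answer": 42} Hope this helps!')
#     # -> {"answer": 42}
#     extract_json("Here are the items: [1, 2, 3].")
#     # -> [1, 2, 3]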
124 | """ 125 | for substring in substrings: 126 | if substring in text: 127 | return substring 128 | return None 129 | -------------------------------------------------------------------------------- /utils/type_utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Any 3 | from langchain_core.documents.base import Document 4 | from langchain_core.runnables import RunnableSerializable 5 | from pydantic import BaseModel, Field 6 | 7 | from utils.prepare import MODEL_NAME, TEMPERATURE 8 | from langchain_core.callbacks import BaseCallbackHandler 9 | 10 | JSONish = str | int | float | dict[str, "JSONish"] | list["JSONish"] 11 | JSONishDict = dict[str, JSONish] 12 | JSONishSimple = str | int | float | dict[str, Any] | list # for use in pydantic models 13 | Props = dict[str, Any] 14 | 15 | PairwiseChatHistory = list[tuple[str, str]] 16 | CallbacksOrNone = list[BaseCallbackHandler] | None 17 | ChainType = RunnableSerializable[dict, str] # double check this 18 | 19 | OperationMode = Enum("OperationMode", "CONSOLE STREAMLIT FASTAPI") 20 | 21 | 22 | class ChatMode(Enum): 23 | NONE_COMMAND_ID = -1 24 | RETRY_COMMAND_ID = 0 25 | CHAT_WITH_DOCS_COMMAND_ID = 1 26 | DETAILS_COMMAND_ID = 2 27 | QUOTES_COMMAND_ID = 3 28 | WEB_COMMAND_ID = 4 29 | RESEARCH_COMMAND_ID = 5 30 | JUST_CHAT_COMMAND_ID = 6 31 | DB_COMMAND_ID = 7 32 | HELP_COMMAND_ID = 8 33 | INGEST_COMMAND_ID = 9 34 | EXPORT_COMMAND_ID = 10 35 | SUMMARIZE_COMMAND_ID = 11 36 | SHARE_COMMAND_ID = 12 37 | 38 | 39 | chat_modes_needing_llm = { 40 | ChatMode.WEB_COMMAND_ID, 41 | ChatMode.RESEARCH_COMMAND_ID, 42 | ChatMode.JUST_CHAT_COMMAND_ID, 43 | ChatMode.DETAILS_COMMAND_ID, 44 | ChatMode.QUOTES_COMMAND_ID, 45 | ChatMode.CHAT_WITH_DOCS_COMMAND_ID, 46 | ChatMode.SUMMARIZE_COMMAND_ID, 47 | ChatMode.HELP_COMMAND_ID, 48 | } 49 | 50 | 51 | class DDGError(Exception): 52 | default_user_facing_message = ( 53 | "Apologies, I ran into some trouble when preparing a response to you." 
82 | 
83 | 
84 | class BotSettings(BaseModel):
85 |     llm_model_name: str = MODEL_NAME
86 |     temperature: float = TEMPERATURE
87 | 
88 | 
89 | AccessRole = Enum("AccessRole", {"NONE": 0, "VIEWER": 1, "EDITOR": 2, "OWNER": 3})
90 | 
91 | AccessCodeType = Enum("AccessCodeType", "NEED_ALWAYS NEED_ONCE NO_ACCESS")
92 | 
93 | 
94 | class CollectionUserSettings(BaseModel):
95 |     access_role: AccessRole = AccessRole.NONE
96 | 
97 | 
98 | class AccessCodeSettings(BaseModel):
99 |     code_type: AccessCodeType = AccessCodeType.NO_ACCESS
100 |     access_role: AccessRole = AccessRole.NONE
101 | 
102 | 
103 | COLLECTION_USERS_METADATA_KEY = "collection_users"
104 | 
105 | 
106 | class CollectionPermissions(BaseModel):
107 |     user_id_to_settings: dict[str, CollectionUserSettings] = Field(default_factory=dict)
108 |     # NOTE: key "" refers to settings for a general user
109 | 
110 |     access_code_to_settings: dict[str, AccessCodeSettings] = Field(default_factory=dict)
111 | 
112 |     def get_user_settings(self, user_id: str | None) -> CollectionUserSettings:
113 |         return self.user_id_to_settings.get(user_id or "", CollectionUserSettings())
114 | 
115 |     def set_user_settings(
116 |         self, user_id: str | None, settings: CollectionUserSettings
117 |     ) -> None:
118 |         self.user_id_to_settings[user_id or ""] = settings
119 | 
120 |     def get_access_code_settings(self, access_code: str) -> AccessCodeSettings:
121 |         return self.access_code_to_settings.get(access_code, AccessCodeSettings())
122 | 
123 |     def set_access_code_settings(
124 |         self, access_code: str, settings: AccessCodeSettings
125 |     ) -> None:
126 |         self.access_code_to_settings[access_code] = settings
127 | 
128 | 
129 | INSTRUCT_SHOW_UPLOADER = "INSTRUCT_SHOW_UPLOADER"
130 | INSTRUCT_CACHE_ACCESS_CODE = "INSTRUCT_CACHE_ACCESS_CODE"
131 | INSTRUCT_AUTO_RUN_NEXT_QUERY = "INSTRUCT_AUTO_RUN_NEXT_QUERY"
132 | INSTRUCT_EXPORT_CHAT_HISTORY = "INSTRUCT_EXPORT_CHAT_HISTORY"
133 | # INSTRUCTION_SKIP_CHAT_HISTORY = "INSTRUCTION_SKIP_CHAT_HISTORY"
134 | 
135 | 
136 | class Instruction(BaseModel):
137 |     type: str
138 |     user_id: str | None = None
139 |     access_code: str | None = None
140 |     data: JSONishSimple | None = None  # NOTE: absorb the two fields above into this one
141 | 
142 | 
143 | class Doc(BaseModel):
144 |     """Pydantic-compatible version of Langchain's Document."""
145 | 
146 |     page_content: str
147 |     metadata: dict[str, Any]
148 | 
149 |     @staticmethod
150 |     def from_lc_doc(langchain_doc: Document) -> "Doc":
151 |         return Doc(
152 |             page_content=langchain_doc.page_content, metadata=langchain_doc.metadata
153 |         )
154 | 
155 |     def to_lc_doc(self) -> Document:
156 |         return Document(page_content=self.page_content, metadata=self.metadata)
157 | 
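# A minimal usage sketch for CollectionPermissions (hypothetical example, not
# part of the repository file; "user-123" is a made-up user id):
#
#     perms = CollectionPermissions()
#     perms.set_user_settings(
#         "user-123", CollectionUserSettings(access_role=AccessRole.EDITOR)
#     )
#     perms.get_user_settings("user-123").access_role  # -> AccessRole.EDITOR
#     perms.get_user_settings(None).access_role  # general user -> AccessRole.NONE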
--------------------------------------------------------------------------------