├── .chroma_env.example
├── .dockerignore
├── .env.example
├── .github
│   └── CODEOWNERS
├── .gitignore
├── .streamlit
│   └── config.toml
├── .vscode
│   └── settings.json
├── Dockerfile
├── LICENSE
├── README-FOR-DEVELOPERS.md
├── README.md
├── TODOS-AND-OTHER-DEV-NOTES.md
├── _prepare_env.py
├── agentblocks
│   ├── collectionhelper.py
│   ├── core.py
│   ├── docconveyer.py
│   ├── webprocess.py
│   ├── webretrieve.py
│   └── websearch.py
├── agents
│   ├── dbmanager.py
│   ├── exporter.py
│   ├── ingester_summarizer.py
│   ├── research_heatseek.py
│   ├── researcher.py
│   ├── researcher_data.py
│   ├── share_manager.py
│   └── websearcher_quick.py
├── announcements
│   ├── version-0-2-6.md
│   └── version-0-2.md
├── api.py
├── components
│   ├── __init__.py
│   ├── chat_with_docs_chain.py
│   ├── chroma_ddg.py
│   ├── chroma_ddg_retriever.py
│   ├── llm.py
│   └── openai_embeddings_ddg.py
├── config
│   ├── logging.json
│   └── trafilatura.cfg
├── docdocgo.py
├── eval
│   ├── RAG-PARAMS.md
│   ├── ai_news_1.py
│   ├── ai_resume_builders.txt
│   ├── list_ai_resume_builders.txt
│   ├── openai_news.py
│   ├── prompt_running_llama2.txt
│   ├── queries.md
│   ├── testing-rag-quality
│   │   ├── 2401.01325.pdf
│   │   └── rag-tests.md
│   └── top_russian_desserts.py
├── ingest_local_docs.py
├── media
│   ├── ddg-logo.png
│   ├── minimal10.png
│   ├── minimal10.svg
│   ├── minimal11.png
│   ├── minimal11.svg
│   ├── minimal13.svg
│   ├── minimal7-plain-svg.svg
│   ├── minimal7.png
│   ├── minimal7.svg
│   ├── minimal9.png
│   └── minimal9.svg
├── requirements.txt
├── streamlit_app.py
└── utils
    ├── __init__.py
    ├── algo.py
    ├── async_utils.py
    ├── chat_state.py
    ├── debug.py
    ├── docgrab.py
    ├── filesystem.py
    ├── helpers.py
    ├── ingest.py
    ├── input.py
    ├── lang_utils.py
    ├── log.py
    ├── older
    │   ├── prompts-checkpoints.py
    │   └── prompts-older.py
    ├── output.py
    ├── prepare.py
    ├── prompts.py
    ├── query_parsing.py
    ├── rag.py
    ├── streamlit
    │   ├── fix_event_loop.py
    │   ├── helpers.py
    │   ├── ingest.py
    │   └── prepare.py
    ├── strings.py
    ├── type_utils.py
    └── web.py
/.chroma_env.example:
--------------------------------------------------------------------------------
1 | # You need to use this only if you want to run Chroma in a Docker container
2 | # Choose your own secret key, but make sure it matches CHROMA_SERVER_AUTHN_CREDENTIALS in .env
3 | # NOTE: Do not enclose values in quotes (contrary to the example in the chromadb docs)
4 | CHROMA_SERVER_AUTHN_CREDENTIALS=choose-your-own-secret-key
5 |
6 | # You don't need to change these
7 | CHROMA_SERVER_AUTHN_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
8 | CHROMA_AUTH_TOKEN_TRANSPORT_HEADER=X-Chroma-Token
9 | ANONYMIZED_TELEMETRY=False
10 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | .venv/
2 | __pycache__/
3 | *.pyc
4 | .git/
5 | *.log
6 | dist/
7 | .DS_Store
8 | *.egg-info/
9 | *.egg
10 |
11 | .my-chroma*
12 | docs-private/
13 | DOCS-APPENDIX.md
14 | docdocgo*.pem
15 | .env.*
16 |
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | # Copy this file to .env and fill in the values
2 |
3 | ## NOTE: The only required value is DEFAULT_OPENAI_API_KEY. The app should work if the other values
4 | # are left as is or not defined at all. However, you are strongly encouraged to fill in your
5 | # own SERPER_API_KEY value. It is also recommended to fill in the BYPASS_SETTINGS_RESTRICTIONS and
6 | # BYPASS_SETTINGS_RESTRICTIONS_PASSWORD values as needed (see below).
7 | DEFAULT_OPENAI_API_KEY="" # your OpenAI API key
8 |
9 | ## Your Google Serper API key (for web searches).
10 | # If you don't set this, my key will be used, which may have
11 | # run out of credits by the time you use it (there's no payment info associated with it).
12 | # You can get a free key without a credit card here: https://serper.dev/
13 | # SERPER_API_KEY=""
14 |
15 | # Mode to use if none is specified in the query
16 | DEFAULT_MODE="/docs" # or "/web" | "/quotes" | "/details" | "/chat" | "/research"
17 |
18 | # Variables controlling whether the Streamlit app imposes some functionality restrictions
19 | BYPASS_SETTINGS_RESTRICTIONS="" # whether to immediately allow all settings (any non-empty string means true)
20 | BYPASS_SETTINGS_RESTRICTIONS_PASSWORD="" # what to enter in the OpenAI API key field to bypass settings
21 | # restrictions. If BYPASS_SETTINGS_RESTRICTIONS is non-empty or if the user enters their own OpenAI API key,
22 | # this becomes - mostly - irrelevant, as full settings access is already granted. I say "mostly" because
23 | # this password can also be used for a couple of admin-only features, such as deleting the default collection
24 | # and deleting a range of collections (see dbmanager.py). Recommendation: for local use,
25 | # the easiest thing to do is to set both values. For a public deployment, you may want to leave
26 | # BYPASS_SETTINGS_RESTRICTIONS empty.
27 |
28 | # If you are NOT using Azure, the following setting determines the chat model
29 | # NOTE: You can change the model and temperature on the fly in Streamlit UI or via the API
30 | MODEL_NAME="gpt-3.5-turbo-0125" # default model to use for chat
31 | ALLOWED_MODELS="gpt-3.5-turbo-0125, gpt-4-turbo-2024-04-09"
32 | CONTEXT_LENGTH="16000" # you can also make it lower than the actual context length
33 | EMBEDDINGS_MODEL_NAME="text-embedding-3-large"
34 | EMBEDDINGS_DIMENSIONS="3072" # number of dimensions for the embeddings model
35 | # Possible embedding models and their dimensions:
36 | # text-embedding-ada-002: 1536
37 | # text-embedding-3-small: 1536 or 512
38 | # text-embedding-3-large: 3072, 1024 or 512
39 | # NOTE: depending on the model, different relevance score thresholds are needed (see
40 | # score_threshold_min/max in chroma_ddg_retriever.py). Currently these values are auto-set
41 | # based on the model name, but they have only been tested with the
42 | # "text-embedding-3-large"-3072 and "text-embedding-ada-002" models.
43 |
44 | # If you are using Azure, uncomment the following settings and fill in the values
45 | # AZURE_OPENAI_API_KEY="" # your OpenAI API key
46 | # OPENAI_API_TYPE="azure" # yours may be different
47 | # OPENAI_API_BASE="https://techops.openai.azure.com" # yours may be different
48 | # OPENAI_API_VERSION="2023-05-15" # or whichever version you're using
49 | # CHAT_DEPLOYMENT_NAME="gpt-35-turbo" # your deployment name for the chat model you want to use
50 | # EMBEDDINGS_DEPLOYMENT_NAME="text-embedding-ada-002" # your deployment name for the embedding model
51 |
52 | TEMPERATURE="0.3" # 0 to 2, higher means more random, lower means more conservative (>1.6 is gibberish)
53 | LLM_REQUEST_TIMEOUT="3" # seconds to wait before a request to the LLM service is considered timed
54 | # out and the bot retries sending the request (a few times, with exponential backoff).
55 | # For Azure, a slightly longer timeout seems to be needed, about 9 seconds.
56 |
57 | # Whether to use Playwright for more reliable web scraping (any non-empty string means true)
58 | # It's recommended to use Playwright but it may not work in all environments and requires
59 | # a small amount of setup (mostly just running "playwright install" - you will be prompted)
60 | USE_PLAYWRIGHT=""
61 |
62 | DEFAULT_COLLECTION_NAME="docdocgo-documentation" # name of the initially selected collection
63 |
64 | ## There are two ways to use the Chroma vector database: using a local database or via HTTP
65 | ## (e.g. by running it in a Docker container). The default is the local option. While it's the
66 | ## simplest way to get started, the HTTP option is more robust and scalable.
67 |
68 | ## The following item is only relevant if you want to use a local Chroma db
69 | VECTORDB_DIR="chroma/" # directory of your doc database
70 |
71 | ## The following items are only relevant if you want to run Chroma in a Docker container
72 | USE_CHROMA_VIA_HTTP="" # whether to use Chroma via HTTP (any non-empty string means true)
73 | CHROMA_SERVER_AUTHN_CREDENTIALS="" # choose your own (must match CHROMA_SERVER_AUTHN_CREDENTIALS in your Chroma Docker container)
74 | CHROMA_SERVER_HOST="localhost" # IP address of your Chroma server
75 | CHROMA_SERVER_HTTP_PORT="8000" # port your Chroma server is listening on
76 |
77 | ## Settings for the response
78 |
79 | # Whether to include the error message in the user-facing error message
80 | INCLUDE_ERROR_IN_USER_FACING_ERROR_MSG="sure"
81 |
82 | ## Settings for Streamlit
83 |
84 | DOMAIN_NAME_FOR_SHARING="" # domain name for sharing links (e.g. "https://localhost:8501")
85 |
86 | # Chat avatars (if you don't set these, the default avatars will be used)
87 | USER_AVATAR="" # URL of the user's chat avatar (e.g. "./images/user-avatar-example.png")
88 | BOT_AVATAR="" # URL of the bot's chat avatar (e.g. "./images/bot-avatar-example.png")
89 |
90 | # Warning message to display at the top of the Streamlit app (normally should be left empty)
91 | STREAMLIT_WARNING_NOTIFICATION=""
92 |
93 | ## The following items are only relevant for ingesting docs
94 | COLLECTION_NAME_FOR_INGESTED_DOCS="" # collection name for the docs you want to ingest
95 | DOCS_TO_INGEST_DIR_OR_FILE="" # path to directory or single file with your docs to ingest
96 |
97 | ## The following items are only relevant if running the FastAPI server
98 | DOCDOCGO_API_KEY="" # choose your own
99 | MAX_UPLOAD_BYTES="104857600" # max size of files that can be uploaded (default is 100MB)
100 |
101 | ## Logging settings
102 | DEFAULT_LOGGER_NAME="ddg"
103 |
104 | # Overrides for config/logging.json
105 | LOG_FORMAT="%(asctime)s - %(name)s - %(levelname)s - %(message)s - [%(filename)s:%(lineno)d in %(funcName)s]"
106 | LOG_LEVEL="INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
107 |
108 | ## Additional options, used in development (any non-empty string means true)
109 | RERAISE_EXCEPTIONS="" # don't keep convo going if an exception is raised, just reraise it
110 | PRINT_STANDALONE_QUERY="" # print query used for the retriever
111 | PRINT_QA_PROMPT="" # print the prompt used for the final response (if asking about your docs)
112 | PRINT_CONDENSE_QUESTION_PROMPT="" # print the prompt used to condense the question
113 | PRINT_SIMILARITIES="" # print the similarities between the query and the docs
114 | PRINT_RESEARCHER_PROMPT="" # print the prompt used for the final researcher response
115 | INITIAL_TEST_QUERY_STREAMLIT="" # auto-run as the first query in Streamlit
116 |
--------------------------------------------------------------------------------
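
A minimal sketch of how these values can be consumed in Python (a hypothetical illustration only; the app's actual loading logic appears to live in `utils/prepare.py`, and the `python-dotenv` package is assumed):

```python
import os

from dotenv import load_dotenv  # assumes python-dotenv is installed

load_dotenv()  # reads .env from the working directory

# The "any non-empty string means true" convention used throughout this file:
USE_PLAYWRIGHT = bool(os.getenv("USE_PLAYWRIGHT", ""))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.3"))
CONTEXT_LENGTH = int(os.getenv("CONTEXT_LENGTH", "16000"))
```
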
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @reasonmethis
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | docdocgo*.pem
2 | chroma-cloud*
3 | .tmp*
4 | credentials/
5 | chroma-cloud-backup/
6 | .chroma_env
7 | my-chroma*
8 | .chroma_env
9 | DOCS-APPENDIX.md
10 | docs-private
11 | launch.json
12 | logs/
13 | debug*.txt
14 | # Byte-compiled / optimized / DLL files
15 | __pycache__/
16 | *.py[cod]
17 | *$py.class
18 |
19 | # C extensions
20 | *.so
21 |
22 | # Distribution / packaging
23 | .Python
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | share/python-wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | MANIFEST
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .nox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | *.py,cover
63 | .hypothesis/
64 | .pytest_cache/
65 | cover/
66 |
67 | # Translations
68 | *.mo
69 | *.pot
70 |
71 | # Django stuff:
72 | *.log
73 | local_settings.py
74 | db.sqlite3
75 | db.sqlite3-journal
76 |
77 | # Flask stuff:
78 | instance/
79 | .webassets-cache
80 |
81 | # Scrapy stuff:
82 | .scrapy
83 |
84 | # Sphinx documentation
85 | docs/_build/
86 |
87 | # PyBuilder
88 | .pybuilder/
89 | target/
90 |
91 | # Jupyter Notebook
92 | .ipynb_checkpoints
93 |
94 | # IPython
95 | profile_default/
96 | ipython_config.py
97 |
98 | # pyenv
99 | # For a library or package, you might want to ignore these files since the code is
100 | # intended to run in multiple environments; otherwise, check them in:
101 | # .python-version
102 |
103 | # pipenv
104 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
105 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
106 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
107 | # install all needed dependencies.
108 | #Pipfile.lock
109 |
110 | # poetry
111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
112 | # This is especially recommended for binary packages to ensure reproducibility, and is more
113 | # commonly ignored for libraries.
114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
115 | #poetry.lock
116 |
117 | # pdm
118 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
119 | #pdm.lock
120 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
121 | # in version control.
122 | # https://pdm.fming.dev/#use-with-ide
123 | .pdm.toml
124 |
125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
126 | __pypackages__/
127 |
128 | # Celery stuff
129 | celerybeat-schedule
130 | celerybeat.pid
131 |
132 | # SageMath parsed files
133 | *.sage.py
134 |
135 | # Environments
136 | .env*
137 | .venv
138 | env/
139 | venv/
140 | ENV/
141 | env.bak/
142 | venv.bak/
143 |
144 | # Spyder project settings
145 | .spyderproject
146 | .spyproject
147 |
148 | # Rope project settings
149 | .ropeproject
150 |
151 | # mkdocs documentation
152 | /site
153 |
154 | # mypy
155 | .mypy_cache/
156 | .dmypy.json
157 | dmypy.json
158 |
159 | # Pyre type checker
160 | .pyre/
161 |
162 | # pytype static type analyzer
163 | .pytype/
164 |
165 | # Cython debug symbols
166 | cython_debug/
167 |
168 | # PyCharm
169 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
170 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
171 | # and can be added to the global gitignore or merged into this file. For a more nuclear
172 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
173 | #.idea/
174 |
--------------------------------------------------------------------------------
/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | [theme]
2 | base="light"
3 | primaryColor="#005F26"
4 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "[python]": {
3 | "editor.defaultFormatter": "charliermarsh.ruff"
4 | },
5 | "search.useIgnoreFiles": true
6 | }
7 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11-slim
2 |
3 | WORKDIR /app
4 |
5 | COPY requirements.txt .
6 |
7 | RUN pip install --no-cache-dir -r requirements.txt
8 |
9 | COPY . .
10 |
11 | EXPOSE 80
12 |
13 | # https://github.com/tiangolo/fastapi/discussions/9328#discussioncomment-8242245
14 | CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "80", "--proxy-headers", "--forwarded-allow-ips=*"]
15 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023-2024 Dmitriy Vasilyuk
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/TODOS-AND-OTHER-DEV-NOTES.md:
--------------------------------------------------------------------------------
1 | # Dev Notes
2 |
3 | ## Assorted TODOs
4 |
5 | 1. DONE chromadb.api.models.Collection has a feature to filter by included text and metadata (I think)
6 | 2. Eval mode: gives two answers, user likes one, bot tunes params
7 | 3. Generate several versions of the standalone query
8 | 4. Use tasks in case answer takes too long (meanwhile can show a tip)
9 | 5. Ability for user to customize bot settings (num of docs, relevance threshold, temperature, etc.)
10 | 6. DONE Make use of larger context window
11 | 7. DONE Reorder relevant documents (e.g. based on overlap)
12 | 8. DONE API token for the bot backend
13 | 9. For long convos include short summary in the prompt
14 | 10. DONE Use streaming to better detect true LLM stalls
15 | 11. NA Send partial results to the user to avoid timeout
16 | 12. Return retrieved documents in the response
17 | 13. Speed up limiting tokens - see tiktoken comment in OpenAIEmbeddings
18 | 14. DONE Include current date/time in the prompt
19 | 15. DONE Use two different chunk sizes: search for small chunks, then small chunks point to larger chunks
20 | 16. Include tips for next possible commands in the response
21 | 17. Add /re feedback command to automatically change query, report type, search queries
22 | 18. Automatically infer command from user query
23 | 19. Ability to ingest current chat history into a collection
24 |
25 | ## Evaluations
26 |
27 | ### Research tasks
28 |
29 | 1. "extract main content from html"
30 |
31 | ## Discovered Tricks for Getting Content from Websites
32 |
33 | 1. To get around bot detection, it helps to set a different device in Playwright
34 | - e.g. "iphone13"
35 | - For example, this helped with the following links:
36 | - "https://www.sciencedaily.com/news/computers_math/artificial_intelligence/",
37 | - "https://www.investors.com/news/technology/ai-stocks-artificial-intelligence-trends-and-news/",
38 | - "https://www.forbes.com/sites/forbesbusinesscouncil/2022/10/07/recent-advancements-in-artificial-intelligence/?sh=266cd08e7fa5",
39 | 2. I increased the MIN_EXTRACTED_SIZE setting from 250 to 1500 in trafilatura to extract more content
40 | - for example, in the following link, with the default setting it only extracted a cookie banner:
41 | - "https://www.artificialintelligence-news.com/"
42 |
43 |
44 | ## Other Notes
45 |
46 | ### get_relevant_documents on a retriever
47 |
48 | - calls `_get_relevant_documents`
49 |   - for vectorstores, that calls `self.vectorstore.similarity_search_with_relevance_scores(query, **self.search_kwargs)`
50 |     - which calls `_similarity_search_with_relevance_scores`, which calls `similarity_search_with_score`
51 |       - for Chroma, that calls `self.__query_collection(query_embeddings=[query_embedding], n_results=k, where=filter)`
52 |         - besides `k` and `filter`, no other `search_kwargs` are passed (specifically `where_document`),
53 |           even though `__query_collection` accepts arbitrary kwargs and forwards them:
54 |           self._collection.query(
55 |               query_texts=query_texts,
56 |               query_embeddings=query_embeddings,
57 |               n_results=n_results,
58 |               where=where,
59 |               **kwargs,
60 |           )
61 |
62 |
--------------------------------------------------------------------------------
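
A minimal sketch of the `where_document` limitation described in the notes above, assuming the `chromadb` package (APIs vary between versions); this is plausibly part of the motivation for the custom `components/chroma_ddg.py` and `components/chroma_ddg_retriever.py` wrappers:

```python
import chromadb

client = chromadb.Client()  # in-memory client, for illustration only
collection = client.get_or_create_collection("demo")
collection.add(ids=["1", "2"], documents=["cats and dogs", "dogs only"])

# chromadb itself supports full-text filtering via where_document...
res = collection.query(
    query_texts=["pets"],
    n_results=2,
    where_document={"$contains": "cats"},
)

# ...but LangChain's similarity_search_with_relevance_scores only forwards
# k and filter (the `where` clause), so where_document never reaches here.
```
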
/_prepare_env.py:
--------------------------------------------------------------------------------
1 | """Load the environment variables and perform other initialization tasks."""
2 | from utils.prepare import is_env_loaded
3 |
4 | if not is_env_loaded:
5 | raise RuntimeError("This should be unreachable.")
6 |
--------------------------------------------------------------------------------
/agentblocks/collectionhelper.py:
--------------------------------------------------------------------------------
1 | import uuid
2 |
3 | from langchain_core.documents import Document
4 |
5 | from agents.dbmanager import get_full_collection_name
6 | from components.chroma_ddg import ChromaDDG, exists_collection
7 | from utils.chat_state import ChatState
8 | from utils.docgrab import ingest_into_chroma
9 | from utils.helpers import get_timestamp
10 | from utils.prepare import get_logger
11 |
12 | logger = get_logger()
13 |
14 | SMALL_WORDS = {"a", "an", "the", "of", "in", "on", "at", "for", "to", "and", "or"}
15 | SMALL_WORDS |= {"is", "are", "was", "were", "be", "been", "being", "am", "what"}
16 | SMALL_WORDS |= {"what", "which", "who", "whom", "whose", "where", "when", "how"}
17 | SMALL_WORDS |= {"this", "that", "these", "those", "there", "here", "can", "could"}
18 | SMALL_WORDS |= {"i", "you", "he", "she", "it", "we", "they", "me", "him", "her"}
19 | SMALL_WORDS |= {"my", "your", "his", "her", "its", "our", "their", "mine", "yours"}
20 | SMALL_WORDS |= {"some", "any"}
21 |
22 |
23 | def construct_new_collection_name(query: str, chat_state: ChatState) -> str:
24 | """
25 | Construct and return new collection name based on the query and user ID. Ensures
26 | that the collection name is unique by appending a number to the end if necessary.
27 | """
28 | # Decide on the collection name consistent with ChromaDB's naming rules
29 | query_words = [x.lower() for x in query.split()]
30 | words = []
31 | words_excluding_small = []
32 | for word in query_words:
33 | word_just_alnum = "".join(x for x in word if x.isalnum())
34 | if not word_just_alnum:
35 | break
36 | words.append(word_just_alnum)
37 | if word not in SMALL_WORDS:
38 | words_excluding_small.append(word_just_alnum)
39 |
40 | words = words_excluding_small if len(words_excluding_small) > 2 else words
41 |
42 | new_coll_name = "-".join(words[:3])[:35].rstrip("-")
43 |
44 |     # Screen for too-short collection names and those convertible to a number
45 | try:
46 | if len(new_coll_name) < 3 or int(new_coll_name) is not None:
47 | new_coll_name = f"collection-{new_coll_name}".rstrip("-")
48 | except ValueError:
49 | pass
50 |
51 | # Construct full collection name (preliminary)
52 | new_coll_name = get_full_collection_name(chat_state.user_id, new_coll_name)
53 |
54 | new_coll_name_final = new_coll_name
55 |
56 | # Check if collection exists, if so, add a number to the end
57 | for i in range(2, 1000000):
58 | if not exists_collection(new_coll_name_final, chat_state.vectorstore.client):
59 | return new_coll_name_final
60 | new_coll_name_final = f"{new_coll_name}-{i}"
61 |
62 |
63 | def ingest_into_collection(
64 | *,
65 | collection_name: str,
66 | docs: list[Document],
67 | collection_metadata: dict[str, str] | None,
68 | chat_state: ChatState,
69 | is_new_collection: bool,
70 | retry_with_random_name: bool = False,
71 | ) -> ChromaDDG:
72 | """
73 | Load provided documents and metadata into a new or existing collection.
74 |
75 | It's ok to pass an empty list of documents. If retry_with_random_name is True,
76 | and the collection name is invalid, it will replace it with a random valid one.
77 |
78 | If is_new_collection is True, the function will add the "created_at" and "updated_at"
79 | metadata fields to the collection (overwriting such fields in the passed metadata).
80 | If is_new_collection is False, it will only add the "updated_at" field, and only if
81 | collection_metadata is not None.
82 | """
83 |     logger.info("Ingesting documents into collection")
84 |
85 | for i in range(2):
86 | try:
87 | timestamp = get_timestamp()
88 | if is_new_collection:
89 | full_metadata = {"created_at": timestamp, "updated_at": timestamp}
90 | if collection_metadata:
91 | full_metadata.update(collection_metadata)
92 | elif collection_metadata is not None:
93 | full_metadata = {"updated_at": timestamp} | collection_metadata
94 | else:
95 | # For now, we don't add "updated_at" if no metadata is passed,
96 | # because we would need to fetch the existing metadata first
97 | full_metadata = None
98 |
99 | vectorstore = ingest_into_chroma(
100 | docs,
101 | collection_name=collection_name,
102 | openai_api_key=chat_state.openai_api_key,
103 | chroma_client=chat_state.vectorstore.client,
104 | collection_metadata=full_metadata,
105 | )
106 | break # success
107 | except Exception as e: # bad name error may not be ValueError in docker mode
108 | logger.error(f"Error ingesting documents into ChromaDB: {e}")
109 | if (
110 | not retry_with_random_name
111 | or i != 0
112 | or "Expected collection name" not in str(e)
113 | ):
114 | raise e # i == 1 means tried normal name and random name, give up
115 |
116 | # Create a random valid collection name and try again
117 | collection_name = get_full_collection_name(
118 | chat_state.user_id, "collection-" + uuid.uuid4().hex[:8]
119 | )
120 | logger.warning(f"Ingesting into new collection named {collection_name}")
121 |
122 |     logger.info(f"Finished ingesting data into collection {collection_name}")
123 | return vectorstore
124 |
--------------------------------------------------------------------------------
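
A hypothetical usage sketch of the helpers above (assumes an initialized `ChatState` instance named `chat_state`, as constructed elsewhere in the app):

```python
from langchain_core.documents import Document

from agentblocks.collectionhelper import (
    construct_new_collection_name,
    ingest_into_collection,
)

coll_name = construct_new_collection_name("what is docdocgo", chat_state)
vectorstore = ingest_into_collection(
    collection_name=coll_name,
    docs=[Document(page_content="Some text", metadata={"source": "notes.txt"})],
    collection_metadata=None,
    chat_state=chat_state,
    is_new_collection=True,
    retry_with_random_name=True,  # fall back to a random valid name on a bad-name error
)
```
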
/agentblocks/core.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | from pydantic import BaseModel
4 |
5 | from utils.strings import extract_json
6 | from utils.type_utils import ChainType, DDGError, JSONish
7 |
8 | MAX_ENFORCE_FORMAT_ATTEMPTS = 4
9 |
10 |
11 | class EnforceFormatError(DDGError):
12 | default_user_facing_message = (
13 | "Apologies, I tried to compose a message in the right "
14 | "format, but I kept running into trouble."
15 | )
16 |
17 |
18 | def enforce_format(
19 | chain: ChainType,
20 | inputs: dict,
21 | validator_transformer: Callable,
22 | max_attempts: int = MAX_ENFORCE_FORMAT_ATTEMPTS,
23 | ):
24 | """
25 | Enforce a format on the output of a chain. If the chain fails to produce
26 | the expected format, retry up to `max_attempts` times. An attempt is considered
27 | successful if the chain does not raise an exception and the output can be
28 | transformed using the `validator_transformer`. Return the transformed output. On
29 | failure, raise an `EnforceFormatError` with a message that includes the last output.
30 | """
31 | for i in range(max_attempts):
32 | output = "OUTPUT_FAILED"
33 | try:
34 | output = chain.invoke(inputs)
35 | return validator_transformer(output)
36 | except Exception as e:
37 | err_msg = f"Failed to enforce format. Output:\n\n{output}\n\nError: {e}"
38 | print(err_msg)
39 | if i == max_attempts - 1:
40 | raise EnforceFormatError(err_msg) from e
41 |
42 |
43 |
44 | def enforce_json_format(
45 |     chain: ChainType,
46 |     inputs: dict,
47 |     validator_transformer: Callable,  # e.g. a pydantic model's model_validate
48 |     max_attempts: int = MAX_ENFORCE_FORMAT_ATTEMPTS,
49 | ) -> JSONish:
50 |     return enforce_format(
51 |         chain, inputs, lambda x: validator_transformer(extract_json(x)), max_attempts
52 |     )
53 |
54 |
55 | def enforce_pydantic_json(
56 |     chain: ChainType,
57 |     inputs: dict,
58 |     pydantic_model: BaseModel,
59 |     max_attempts: int = MAX_ENFORCE_FORMAT_ATTEMPTS,
60 | ) -> BaseModel:
61 |     return enforce_format(
62 |         chain,
63 |         inputs,
64 |         lambda x: pydantic_model.model_validate(extract_json(x)),
65 |         max_attempts,
66 |     )
67 |
--------------------------------------------------------------------------------
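
A minimal usage sketch of `enforce_pydantic_json`, mirroring how it is used in `agentblocks/websearch.py`; the chain below is a stand-in, since anything with an `.invoke(inputs)` method that returns a string will do:

```python
from pydantic import BaseModel

from agentblocks.core import enforce_pydantic_json


class Queries(BaseModel):
    queries: list[str]


class FakeChain:  # a real chain would come from chat_state.get_prompt_llm_chain(...)
    def invoke(self, inputs: dict) -> str:
        return '{"queries": ["query one", "query two"]}'


print(enforce_pydantic_json(FakeChain(), {"query": "anything"}, Queries).queries)
# ['query one', 'query two']
```
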
/agentblocks/docconveyer.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from typing import TypeVar
3 |
4 | from langchain_core.documents import Document
5 | from langchain_text_splitters import RecursiveCharacterTextSplitter
6 | from pydantic import BaseModel, Field
7 |
8 | from utils.lang_utils import ROUGH_UPPER_LIMIT_AVG_CHARS_PER_TOKEN, get_num_tokens
9 | from utils.prepare import get_logger
10 | from utils.type_utils import Doc
11 |
12 | # class DocData(BaseModel):
13 | #     text: str | None = None
14 | #     source: str | None = None
15 | #     num_tokens: int | None = None
16 | #     doc_uuid: str | None = None
17 | #     error: str | None = None
18 | #     is_ingested: bool = False
19 | logger = get_logger()
20 |
21 | DocT = TypeVar("DocT", Document, Doc)
22 |
23 |
24 | def _split_doc_based_on_tokens(
25 | doc: Document | Doc, max_tokens: float, target_num_chars: int
26 | ) -> list[Document]:
27 | """
28 | Helper function for the main function below that definitely splits the provided document.
29 | """
30 | text_splitter = RecursiveCharacterTextSplitter(
31 | chunk_size=target_num_chars,
32 | chunk_overlap=0,
33 | add_start_index=True, # metadata will include start_index of snippet in original doc
34 | )
35 | candidate_new_docs = text_splitter.create_documents(
36 | texts=[doc.page_content], metadatas=[doc.metadata]
37 | )
38 | new_docs = []
39 |
40 | # Calculate the number of tokens in each part and split further if needed
41 | for maybe_new_doc in candidate_new_docs:
42 | maybe_new_doc.metadata["num_tokens"] = get_num_tokens(
43 | maybe_new_doc.page_content
44 | )
45 |
46 | # If the part is sufficiently small, add it to the list of new docs
47 | if maybe_new_doc.metadata["num_tokens"] <= max_tokens:
48 | new_docs.append(maybe_new_doc)
49 | continue
50 |
51 | # If actual num tokens is X times the target, reduce target_num_chars by ~X
52 | new_target_num_chars = min(
53 | int(target_num_chars * max_tokens / maybe_new_doc.metadata["num_tokens"]),
54 | target_num_chars // 2,
55 | )
56 |
57 | # Split the part further and adjust the start index of each snippet
58 | new_parts = _split_doc_based_on_tokens(
59 | maybe_new_doc, max_tokens, new_target_num_chars
60 | )
61 | for new_part in new_parts:
62 | new_part.metadata["start_index"] += maybe_new_doc.metadata["start_index"]
63 |
64 | # Add the new parts to the list of new docs
65 | new_docs.extend(new_parts)
66 |
67 | return new_docs
68 |
69 |
70 | def split_doc_based_on_tokens(doc: DocT, max_tokens: float) -> list[DocT]:
71 | """
72 | Split a document into parts based on the number of tokens in each part. Specifically,
73 | if the number of tokens in a part (or the original doc) is within max_tokens, then the part is
74 | added to the list of new docs. Otherwise, the part is split further. The resulting Document
75 | objects are returned as a list and contain the copy of the metadata from the parent doc, plus
76 | the "start_index" of its occurrence in the parent doc. The metadata will also include the
77 | "num_tokens" of the part.
78 | """
79 | # If the doc is too large then don't count tokens, just split it.
80 | num_chars = len(doc.page_content)
81 | if num_chars / ROUGH_UPPER_LIMIT_AVG_CHARS_PER_TOKEN > max_tokens:
82 | # Doc almost certainly has more tokens than max_tokens
83 | target_num_chars = int(max_tokens * ROUGH_UPPER_LIMIT_AVG_CHARS_PER_TOKEN / 2)
84 | else:
85 | # Count the tokens to see if we need to split
86 | if (num_tokens := get_num_tokens(doc.page_content)) <= max_tokens:
87 | doc.metadata["num_tokens"] = num_tokens
88 | return [doc]
89 | # Guess the target number of characters for the text splitter
90 | target_num_chars = int(num_chars / (num_tokens / max_tokens) / 2)
91 |
92 | documents = _split_doc_based_on_tokens(doc, max_tokens, target_num_chars)
93 | if isinstance(doc, Document):
94 | return documents
95 | else:
96 | return [Doc.from_lc_doc(d) for d in documents]
97 |
98 |
99 | def break_up_big_docs(
100 | docs: list[DocT],
101 | max_tokens: float, # = int(CONTEXT_LENGTH * 0.25),
102 | ) -> list[DocT]:
103 | """
104 | Split each big document into parts, leaving the small ones as they are. Big vs small is
105 | determined by how the number of tokens in the document compares to the max_tokens parameter.
106 | """
107 | logger.info(f"Breaking up {len(docs)} docs into parts with max_tokens={max_tokens}")
108 | new_docs = []
109 | for doc in docs:
110 | doc_parts = split_doc_based_on_tokens(doc, max_tokens)
111 | logger.debug(f"Split doc {doc.metadata.get('source')} into {len(doc_parts)} parts")
112 | if len(doc_parts) > 1:
113 | # Create an id representing the original doc
114 | full_doc_ref = uuid.uuid4().hex[:8]
115 | for i, doc_part in enumerate(doc_parts):
116 | doc_part.metadata["full_doc_ref"] = full_doc_ref
117 | doc_part.metadata["part_id"] = str(i)
118 |
119 | new_docs.extend(doc_parts)
120 | return new_docs
121 |
122 |
123 | def limit_num_docs_by_tokens(
124 | docs: list[Document | Doc], max_tokens: float
125 | ) -> tuple[int, int]:
126 | """
127 | Limit the number of documents in a list based on the total number of tokens in the
128 | documents.
129 |
130 | Return the number of documents that can be included in the
131 | list without exceeding the maximum number of tokens, and the total number of tokens
132 | in the included documents.
133 |
134 | Additionally, the function updates the metadata of each document with the number of
135 | tokens in the document, if it is not already present.
136 | """
137 | tot_tokens = 0
138 | for i, doc in enumerate(docs):
139 | if (num_tokens := doc.metadata.get("num_tokens")) is None:
140 | doc.metadata["num_tokens"] = num_tokens = get_num_tokens(doc.page_content)
141 | if tot_tokens + num_tokens > max_tokens:
142 | return i, tot_tokens
143 | tot_tokens += num_tokens
144 |
145 | return len(docs), tot_tokens
146 |
147 |
148 | class DocConveyer(BaseModel):
149 | docs: list[Doc] = Field(default_factory=list)
150 | max_tokens_for_breaking_up_docs: int | float | None = None
151 | idx_first_not_done: int = 0 # done = "pushed out" by get_next_docs
152 |
153 | def __init__(self, **data):
154 | super().__init__(**data)
155 | if self.max_tokens_for_breaking_up_docs is not None:
156 | self.docs = break_up_big_docs(
157 | self.docs, self.max_tokens_for_breaking_up_docs
158 | )
159 |
160 | @property
161 | def num_available_docs(self):
162 | return len(self.docs) - self.idx_first_not_done
163 |
164 | def add_docs(self, docs: list[Doc]):
165 | if self.max_tokens_for_breaking_up_docs is not None:
166 | docs = break_up_big_docs(docs, self.max_tokens_for_breaking_up_docs)
167 | self.docs.extend(docs)
168 |
169 | def get_next_docs(
170 | self,
171 | max_tokens: float,
172 | max_docs: int | None = None,
173 | max_full_docs: int | None = None,
174 | ) -> list[Doc]:
175 | logger.debug(f"{self.num_available_docs} docs available")
176 | num_docs, _ = limit_num_docs_by_tokens(
177 | self.docs[self.idx_first_not_done :], max_tokens
178 | )
179 | if max_docs is not None:
180 | num_docs = min(num_docs, max_docs)
181 |
182 | if max_full_docs is not None:
183 | num_full_docs = 0
184 | old_full_doc_ref = None
185 | new_num_docs = 0
186 | for doc in self.docs[
187 | self.idx_first_not_done : self.idx_first_not_done + num_docs
188 | ]:
189 | if num_full_docs >= max_full_docs:
190 | break
191 | new_num_docs += 1
192 | new_full_doc_ref = doc.metadata.get("full_doc_ref")
193 | if new_full_doc_ref is None or new_full_doc_ref != old_full_doc_ref:
194 | num_full_docs += 1
195 | old_full_doc_ref = new_full_doc_ref
196 | num_docs = new_num_docs
197 |
198 | self.idx_first_not_done += num_docs
199 | logger.debug(f"Returning {num_docs} docs")
200 | return self.docs[self.idx_first_not_done - num_docs : self.idx_first_not_done]
201 |
202 | def clear_done_docs(self):
203 | self.docs = self.docs[self.idx_first_not_done :]
204 | self.idx_first_not_done = 0
205 |
--------------------------------------------------------------------------------
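
A hypothetical usage sketch of `DocConveyer`: load oversized docs once, then pull out token-limited batches until the conveyer runs dry:

```python
from agentblocks.docconveyer import DocConveyer
from utils.type_utils import Doc

conveyer = DocConveyer(
    docs=[Doc(page_content="some long text " * 5000, metadata={"source": "a.txt"})],
    max_tokens_for_breaking_up_docs=2000,  # big docs are split into parts on init
)
while batch := conveyer.get_next_docs(max_tokens=4000, max_full_docs=2):
    ...  # process one context-window-sized batch of doc parts at a time
```
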
/agentblocks/webprocess.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | from pydantic import BaseModel, Field
4 |
5 | from agentblocks.webretrieve import get_content_from_urls
6 | from utils.type_utils import Doc
7 | from utils.web import LinkData
8 |
9 | DEFAULT_MIN_OK_URLS = 5
10 | DEFAULT_INIT_BATCH_SIZE = 0 # 0 = "auto-determined"
11 |
12 |
13 | class URLConveyer(BaseModel):
14 | urls: list[str]
15 | link_data_dict: dict[str, LinkData] = Field(default_factory=dict)
16 | # num_ok_urls: int = 0
17 | idx_first_not_done: int = 0 # done = "pushed out" by get_next_docs
18 | idx_first_not_tried: int = 0 # not necessarily equal to len(link_data_dict)
19 | idx_last_url_refresh: int = 0
20 |
21 | num_url_retrievals: int = 0
22 |
23 | default_min_ok_urls: int = DEFAULT_MIN_OK_URLS
24 | default_init_batch_size: int = DEFAULT_INIT_BATCH_SIZE # 0 = "auto-determined"
25 |
26 | @property
27 | def num_tried_urls_since_refresh(self) -> int:
28 | return self.idx_first_not_tried - self.idx_last_url_refresh
29 |
30 | @property
31 | def num_untried_urls(self) -> int:
32 | return len(self.urls) - self.idx_first_not_tried
33 |
34 | def refresh_urls(
35 | self,
36 | urls: list[str],
37 | only_add_truly_new: bool = True,
38 | idx_to_cut_at: int | None = None,
39 | ):
40 | if idx_to_cut_at is None:
41 | idx_to_cut_at = self.idx_first_not_tried
42 | elif idx_to_cut_at < 0 or idx_to_cut_at > len(self.urls):
43 | raise ValueError(f"idx_to_cut_at must be between 0 and {len(self.urls)}")
44 |
45 | del self.urls[idx_to_cut_at:]
46 |
47 | if only_add_truly_new:
48 | old_urls = set(self.urls)
49 | urls = [url for url in urls if url not in old_urls]
50 |
51 | self.urls.extend(urls)
52 | self.idx_last_url_refresh = idx_to_cut_at
53 |
54 | def retrieve_content_from_urls(
55 | self,
56 | min_ok_urls: int | None = None, # use default_min_ok_urls if None
57 | init_batch_size: int | None = None, # use default_init_batch_size if None
58 | batch_fetcher: Callable[[list[str]], list[str]] | None = None,
59 | ):
60 | url_retrieval_data = get_content_from_urls(
61 | urls=self.urls[self.idx_first_not_tried :],
62 | min_ok_urls=min_ok_urls
63 | if min_ok_urls is not None
64 | else self.default_min_ok_urls,
65 | init_batch_size=init_batch_size
66 | if init_batch_size is not None
67 | else self.default_init_batch_size,
68 | batch_fetcher=batch_fetcher,
69 | )
70 |
71 | self.num_url_retrievals += 1
72 | self.link_data_dict.update(url_retrieval_data.link_data_dict)
73 | self.idx_first_not_tried += url_retrieval_data.idx_first_not_tried
74 |
75 | def get_next_docs(self) -> list[Doc]:
76 | docs = []
77 | for url in self.urls[self.idx_first_not_done : self.idx_first_not_tried]:
78 | link_data = self.link_data_dict[url]
79 | if link_data.error:
80 | continue
81 | doc = Doc(page_content=link_data.text, metadata={"source": url})
82 | if link_data.num_tokens is not None:
83 | doc.metadata["num_tokens"] = link_data.num_tokens
84 | docs.append(doc)
85 |
86 | self.idx_first_not_done = self.idx_first_not_tried
87 | return docs
88 |
89 | def get_next_docs_with_url_retrieval(
90 | self,
91 | min_ok_urls: int | None = None, # use default_min_ok_urls if None
92 | init_batch_size: int | None = None, # use default_init_batch_size if None
93 | batch_fetcher: Callable[[list[str]], list[str]] | None = None,
94 | ) -> list[Doc]:
95 | if docs := self.get_next_docs():
96 | return docs
97 |
98 | # If no more docs, retrieve more content from URLs
99 | self.retrieve_content_from_urls(
100 | min_ok_urls=min_ok_urls,
101 | init_batch_size=init_batch_size,
102 | batch_fetcher=batch_fetcher,
103 | )
104 | return self.get_next_docs()
105 |
--------------------------------------------------------------------------------
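
A hypothetical usage sketch of `URLConveyer` (this performs real network fetches via the default batch fetcher from `utils/web.py`):

```python
from agentblocks.webprocess import URLConveyer

conveyer = URLConveyer(urls=["https://example.com", "https://example.org"])
docs = conveyer.get_next_docs_with_url_retrieval(min_ok_urls=1)
for doc in docs:
    print(doc.metadata["source"], doc.metadata.get("num_tokens"))
```
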
/agentblocks/webretrieve.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | from pydantic import BaseModel, Field
4 |
5 | from utils.prepare import get_logger
6 | from utils.type_utils import DDGError
7 | from utils.web import LinkData, get_batch_url_fetcher
8 |
9 | logger = get_logger()
10 |
11 |
12 | class URLRetrievalData(BaseModel):
13 | urls: list[str]
14 | link_data_dict: dict[str, LinkData] = Field(default_factory=dict)
15 | num_ok_urls: int = 0
16 | idx_first_not_tried: int = 0 # different from len(link_data_dict) if urls repeat
17 |
18 |
19 | MAX_INIT_BATCH_SIZE = 10
20 |
21 |
22 | def get_content_from_urls(
23 | urls: list[str],
24 | min_ok_urls: int,
25 | init_batch_size: int = 0, # auto-determined if 0
26 | batch_fetcher: Callable[[list[str]], list[str]] | None = None,
27 | ) -> URLRetrievalData:
28 | """
29 | Fetch content from a list of urls using a batch fetcher. If at least
30 | min_ok_urls urls are fetched successfully, return the fetched content.
31 | Otherwise, fetch a new batch of urls, and repeat until at least min_ok_urls
32 | urls are fetched successfully.
33 |
34 |     If there are duplicate URLs, only the first occurrence is fetched; repeats are skipped.
35 |
36 | Args:
37 | - urls: list of urls to fetch content from
38 | - min_ok_urls: minimum number of urls that need to be fetched successfully
39 | - init_batch_size: initial batch size to fetch content from
40 | - batch_fetcher: function to fetch content from a batch of urls
41 |
42 | Returns:
43 | - URLRetrievalData: object containing the fetched content
44 | """
45 | try:
46 | batch_fetcher = batch_fetcher or get_batch_url_fetcher()
47 | init_batch_size = init_batch_size or min(
48 | MAX_INIT_BATCH_SIZE, round(min_ok_urls * 1.2)
49 | ) # NOTE: could optimize
50 |
51 | logger.info(
52 | f"Fetching content from {len(urls)} urls:\n"
53 | f" - {min_ok_urls} successfully obtained URLs needed\n"
54 | f" - {init_batch_size} is the initial batch size\n"
55 | )
56 |
57 | res = URLRetrievalData(urls=urls)
58 | num_urls = len(urls)
59 | url_set = set() # to keep track of unique urls
60 |
61 | # If, say, only 3 ok urls are still needed, we might want to try fetching 3 + extra
62 | num_extras = max(2, init_batch_size - min_ok_urls)
63 |
64 | while res.num_ok_urls < min_ok_urls:
65 | batch_size = min(
66 | init_batch_size,
67 | min_ok_urls - res.num_ok_urls + num_extras,
68 | )
69 | batch_urls = []
70 | for url in urls[res.idx_first_not_tried :]:
71 | if len(batch_urls) == batch_size:
72 | break
73 | if url in url_set:
74 | continue
75 | batch_urls.append(url)
76 | url_set.add(url)
77 | res.idx_first_not_tried += 1
78 |
79 | if (batch_size := len(batch_urls)) == 0:
80 | break # no more urls to fetch
81 |
82 | logger.info(f"Fetching {batch_size} urls:\n- " + "\n- ".join(batch_urls))
83 |
84 | # Fetch content from urls in batch
85 | batch_htmls = batch_fetcher(batch_urls)
86 |
87 | # Process fetched content
88 | for url, html in zip(batch_urls, batch_htmls):
89 | link_data = LinkData.from_raw_content(html)
90 | res.link_data_dict[url] = link_data
91 | if not link_data.error:
92 | res.num_ok_urls += 1
93 |
94 | logger.info(
95 | f"Total URLs processed: {res.idx_first_not_tried} ({num_urls} total)\n"
96 | f"Total successful URLs: {res.num_ok_urls} ({min_ok_urls} needed)\n"
97 | )
98 |
99 | return res
100 | except Exception as e:
101 | raise DDGError(
102 | user_facing_message="Apologies, I ran into a problem trying to fetch URL content."
103 | ) from e
104 |
--------------------------------------------------------------------------------
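
A hypothetical usage sketch of `get_content_from_urls` (it uses the default fetcher from `utils/web.py`, so it performs real network requests):

```python
from agentblocks.webretrieve import get_content_from_urls

res = get_content_from_urls(
    urls=["https://example.com", "https://example.org", "https://example.net"],
    min_ok_urls=2,  # keep fetching batches until at least 2 URLs succeed
)
ok_texts = [ld.text for ld in res.link_data_dict.values() if not ld.error]
```
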
/agentblocks/websearch.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from pydantic import BaseModel
3 |
4 | from agentblocks.core import enforce_pydantic_json
5 | from utils.algo import interleave_iterables, remove_duplicates_keep_order
6 | from utils.async_utils import gather_tasks_sync
7 | from utils.chat_state import ChatState
8 | from utils.prepare import get_logger
9 | from utils.type_utils import DDGError
10 | from langchain_community.utilities import GoogleSerperAPIWrapper
11 |
12 | logger = get_logger()
13 |
14 | WEB_SEARCH_API_ISSUE_MSG = (
15 | "Apologies, it seems I'm having an issue with the API I use to search the web."
16 | )
17 |
18 | domain_blacklist = ["youtube.com"]
19 |
20 |
21 | class WebSearchAPIError(DDGError):
22 | # NOTE: Can be raised e.g. like this: raise WebSearchAPIError() from e
23 | # In that case, the original exception will be stored in self.__cause__
24 | default_user_facing_message = WEB_SEARCH_API_ISSUE_MSG
25 |
26 |
27 | def _extract_domain(url: str):
28 | try:
29 | full_domain = url.split("://")[-1].split("/")[0] # blah.blah.domain.com
30 | return ".".join(full_domain.split(".")[-2:]) # domain.com
31 | except Exception:
32 | return ""
33 |
34 |
35 | def get_links_from_search_results(search_results: list[dict[str, Any]]):
36 | logger.debug(
37 | f"Getting links from results of {len(search_results)} searches "
38 |         f"with {[len(s.get('organic', [])) for s in search_results]} organic results each."
39 | )
40 | links_for_each_query = [
41 | [x["link"] for x in search_result.get("organic", []) if "link" in x]
42 | for search_result in search_results
43 | ] # [[links for query 1], [links for query 2], ...]
44 |
45 | logger.debug(
46 | f"Number of links for each query: {[len(links) for links in links_for_each_query]}"
47 | )
48 | # NOTE: can ask LLM to decide which links to keep
49 | return [
50 | link
51 | for link in remove_duplicates_keep_order(
52 | interleave_iterables(links_for_each_query)
53 | )
54 | if _extract_domain(link) not in domain_blacklist
55 | ]
56 |
57 |
58 | def get_links_from_queries(
59 | queries: list[str], num_search_results: int = 10
60 | ) -> list[str]:
61 | """
62 | Get links from a list of queries by doing a Google search for each query.
63 | """
64 | try:
65 | # Do a Google search for each query
66 | logger.info(f"Performing web search for queries: {queries}")
67 | search = GoogleSerperAPIWrapper(k=num_search_results)
68 | search_tasks = [search.aresults(query) for query in queries]
69 | search_results = gather_tasks_sync(search_tasks) # can use serper's batching
70 |
71 | # Check for errors
72 | for search_result in search_results:
73 | try:
74 | if search_result["statusCode"] // 100 != 2:
75 | logger.error(f"Error in search result: {search_result}")
76 | # search_result can be {"statusCode": 400, "message": "Not enough credits"}
77 | raise WebSearchAPIError() # TODO: add message, make sure it gets logged
78 | except KeyError:
79 | logger.warning("No status code in search result, assuming success.")
80 |
81 | # Get links from search results
82 | return get_links_from_search_results(search_results)
83 | except WebSearchAPIError as e:
84 | raise e
85 | except Exception as e:
86 | raise WebSearchAPIError() from e
87 |
88 |
89 | class Queries(BaseModel):
90 | queries: list[str]
91 |
92 |
93 | def get_web_search_queries_from_prompt(
94 | prompt, inputs: dict[str, str], chat_state: ChatState
95 | ) -> list[str]:
96 | """
97 | Get web search queries from a prompt. The prompt should ask the LLM to generate
98 | web search queries in the format {"queries": ["query1", "query2", ...]}
99 | """
100 | logger.info("Submitting prompt to get web search queries")
101 | query_generator_chain = chat_state.get_prompt_llm_chain(prompt, to_user=False)
102 |
103 | return enforce_pydantic_json(query_generator_chain, inputs, Queries).queries
104 |
--------------------------------------------------------------------------------
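
A hypothetical usage sketch of `get_links_from_queries` (requires a valid SERPER_API_KEY in the environment):

```python
from agentblocks.websearch import get_links_from_queries

links = get_links_from_queries(
    ["best open-source RAG frameworks", "RAG evaluation benchmarks"],
    num_search_results=10,
)
print(links[:5])  # interleaved across queries, deduplicated, blacklist-filtered
```
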
/agents/exporter.py:
--------------------------------------------------------------------------------
1 | from utils.chat_state import ChatState
2 | from utils.helpers import (
3 | DELIMITER80_NONL,
4 | EXPORT_COMMAND_HELP_MSG,
5 | format_nonstreaming_answer,
6 | )
7 | from utils.query_parsing import ExportCommand, get_command, get_int, get_int_or_command
8 | from utils.type_utils import INSTRUCT_EXPORT_CHAT_HISTORY, Instruction, Props
9 |
10 | REVERSE_COMMAND = "reverse"
11 |
12 |
13 | def get_exporter_response(chat_state: ChatState) -> Props:
14 | message = chat_state.message
15 | subcommand = chat_state.parsed_query.export_command
16 |
17 | if subcommand == ExportCommand.NONE:
18 | return format_nonstreaming_answer(EXPORT_COMMAND_HELP_MSG)
19 |
20 | if subcommand == ExportCommand.CHAT:
21 | max_messages = 1000000000
22 | is_reverse = False
23 | cmd, rest = get_int_or_command(message, [REVERSE_COMMAND])
24 | if isinstance(cmd, int):
25 | max_messages = cmd
26 | cmd, rest = get_command(rest, [REVERSE_COMMAND])
27 | if cmd is not None:
28 | is_reverse = True
29 | elif rest: # invalid non-empty command
30 | return format_nonstreaming_answer(EXPORT_COMMAND_HELP_MSG)
31 | elif cmd == REVERSE_COMMAND:
32 | is_reverse = True
33 | cmd, rest = get_int(rest)
34 | if cmd is not None:
35 | max_messages = cmd
36 | elif rest: # invalid non-empty command after reverse
37 | return format_nonstreaming_answer(EXPORT_COMMAND_HELP_MSG)
38 | elif rest: # cmd is None but message is not empty (invalid command)
39 | return format_nonstreaming_answer(EXPORT_COMMAND_HELP_MSG)
40 |
41 | if max_messages <= 0:
42 | return format_nonstreaming_answer(EXPORT_COMMAND_HELP_MSG)
43 |
44 |     # Get and format the last max_messages messages
45 | msgs = []
46 | for (user_msg, ai_msg), sources in zip(
47 | reversed(chat_state.chat_history), reversed(chat_state.sources_history)
48 | ):
49 | if len(msgs) >= max_messages:
50 | break
51 | if ai_msg is not None or sources:
52 | ai_msg_full = f"**DDG:** {ai_msg or ''}"
53 | if sources:
54 | tmp = "\n- ".join(sources)
55 | ai_msg_full += f"\n\n**SOURCES:**\n\n- {tmp}"
56 | msgs.append(ai_msg_full)
57 | if user_msg is not None:
58 | msgs.append(f"**USER:** {user_msg}")
59 |
60 | # We may have overshot by one message (didn't want to check twice in the loop)
61 | if len(msgs) > max_messages:
62 | msgs.pop()
63 |
64 | # Reverse the messages if needed, format them and return along with instructions
65 | data = f"\n\n{DELIMITER80_NONL}\n\n".join(msgs[:: 1 if is_reverse else -1])
66 | return format_nonstreaming_answer(
67 | "I have collected our chat history and sent it to the UI for export."
68 | ) | {
69 | "instructions": [Instruction(type=INSTRUCT_EXPORT_CHAT_HISTORY, data=data)],
70 | }
71 | raise ValueError(f"Invalid export subcommand: {subcommand}")
72 |
--------------------------------------------------------------------------------
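
For reference, the parsing above accepts an optional message count and an optional `reverse` flag, in either order. Assuming the app exposes this via an `/export chat` command prefix (the prefix itself is defined in the query parser, not in this file), hypothetical accepted forms would be:

```
/export chat              -> entire chat history, oldest message first
/export chat 10           -> last 10 messages, oldest first
/export chat reverse 10   -> last 10 messages, newest first
/export chat 10 reverse   -> same as above
```
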
/agents/ingester_summarizer.py:
--------------------------------------------------------------------------------
1 | import uuid
2 |
3 | from icecream import ic
4 |
5 | from agentblocks.collectionhelper import ingest_into_collection
6 | from agents.dbmanager import (
7 | get_access_role,
8 | get_full_collection_name,
9 | get_user_facing_collection_name,
10 | )
11 | from components.llm import get_prompt_llm_chain
12 | from utils.chat_state import ChatState
13 | from utils.helpers import (
14 | ADDITIVE_COLLECTION_PREFIX,
15 | INGESTED_DOCS_INIT_PREFIX,
16 | format_invalid_input_answer,
17 | format_nonstreaming_answer,
18 | )
19 | from utils.lang_utils import limit_tokens_in_text, limit_tokens_in_texts
20 | from utils.prepare import CONTEXT_LENGTH
21 | from utils.prompts import SUMMARIZER_PROMPT
22 | from utils.query_parsing import IngestCommand
23 | from utils.type_utils import INSTRUCT_SHOW_UPLOADER, AccessRole, ChatMode, Instruction
24 | from utils.web import LinkData, get_batch_url_fetcher
25 | from langchain_core.documents import Document
26 |
27 | DEFAULT_MAX_TOKENS_FINAL_CONTEXT = int(CONTEXT_LENGTH * 0.7)
28 |
29 | NO_EDITOR_ACCESS_STATUS = "No editor access to collection" # a bit of duplication
30 | NO_MULTIPLE_INGESTION_SOURCES_STATUS = (
31 | "Cannot ingest uploaded files and an external resource at the same time"
32 | )
33 | DOC_SEPARATOR = "\n" + "-" * 40 + "\n\n"
34 |
35 |
36 | def summarize(docs: list[Document], chat_state: ChatState) -> str:
37 | if (num_docs := len(docs)) == 0:
38 | return ""
39 |
40 | summarizer_chain = get_prompt_llm_chain(
41 | SUMMARIZER_PROMPT,
42 | llm_settings=chat_state.bot_settings,
43 | api_key=chat_state.openai_api_key,
44 | callbacks=chat_state.callbacks,
45 | )
46 |
47 | if num_docs == 1:
48 | shortened_text, num_tokens = limit_tokens_in_text(
49 | docs[0].page_content, DEFAULT_MAX_TOKENS_FINAL_CONTEXT
50 | ) # TODO: ability to summarize longer content
51 | shortened_texts = [shortened_text]
52 | else:
53 | # Multiple documents
54 | shortened_texts, nums_of_tokens = limit_tokens_in_texts(
55 | [doc.page_content for doc in docs], DEFAULT_MAX_TOKENS_FINAL_CONTEXT
56 | )
57 |
58 | # Construct the final context
59 | final_texts = []
60 | for i, (doc, shortened_text) in enumerate(zip(docs, shortened_texts)):
61 | if doc.page_content != shortened_text:
62 | suffix = "\n\nNOTE: The above content was truncated to fit the maximum token limit."
63 | else:
64 | suffix = ""
65 |
66 | final_texts.append(
67 | f"SOURCE: {doc.metadata.get('source', 'Unknown')}"
68 | f"\n\n{shortened_text}{suffix}"
69 | )
70 |
71 | final_context = DOC_SEPARATOR.join(final_texts)
72 | ic(final_context)
73 |
74 | return summarizer_chain.invoke({"content": final_context})
75 |
76 |
77 | def get_ingester_summarizer_response(chat_state: ChatState):
78 | # TODO: remove similar functionality in streamlit/ingest.py
79 | message = chat_state.parsed_query.message
80 | ingest_command = chat_state.parsed_query.ingest_command
81 |
82 | # If there's no message and no docs, just return request for files
83 | if not (docs := chat_state.uploaded_docs) and not message:
84 | return {
85 | "answer": "Please select your documents to upload and ingest.",
86 | "instructions": [Instruction(type=INSTRUCT_SHOW_UPLOADER)],
87 | }
88 |
89 | # Determine if we need to use the same collection or create a new one
90 | coll_name_as_shown = get_user_facing_collection_name(
91 | chat_state.user_id, chat_state.vectorstore.name
92 | )
93 | if ingest_command == IngestCommand.ADD or (
94 | ingest_command == IngestCommand.DEFAULT
95 | and coll_name_as_shown.startswith(ADDITIVE_COLLECTION_PREFIX)
96 | ):
97 | # We will use the same collection
98 | is_new_collection = False
99 | coll_name_full = chat_state.vectorstore.name
100 |
101 | # Check for editor access
102 | if get_access_role(chat_state).value < AccessRole.EDITOR.value:
103 | cmd_str = (
104 | "summarize"
105 | if chat_state.chat_mode == ChatMode.SUMMARIZE_COMMAND_ID
106 | else "ingest"
107 | )
108 | return format_invalid_input_answer(
109 | "Apologies, you can't ingest content into the current collection "
110 | "because you don't have editor access to it. You can ingest content "
111 | "into a new collection instead. For example:\n\n"
112 | f"```\n\n/{cmd_str} new {message}\n\n```",
113 | NO_EDITOR_ACCESS_STATUS,
114 | )
115 | else:
116 | # We will need to create a new collection
117 | is_new_collection = True
118 | coll_name_as_shown = INGESTED_DOCS_INIT_PREFIX + uuid.uuid4().hex[:8]
119 | coll_name_full = get_full_collection_name(
120 | chat_state.user_id, coll_name_as_shown
121 | )
122 |
123 | # If there are uploaded docs, ingest or summarize them
124 | if docs:
125 | if message:
126 | return format_invalid_input_answer(
127 | "Apologies, you can't simultaneously ingest uploaded files and "
128 | f"an external resource ({message}).",
129 | NO_MULTIPLE_INGESTION_SOURCES_STATUS,
130 | )
131 | if chat_state.chat_mode == ChatMode.INGEST_COMMAND_ID:
132 | res = format_nonstreaming_answer(
133 | "The files you uploaded have been ingested into the collection "
134 | f"`{coll_name_as_shown}`. If you don't need to ingest "
135 | "more content into it, rename it with `/db rename my-cool-collection-name`."
136 | )
137 | else:
138 | res = {"answer": summarize(docs, chat_state)}
139 | else:
140 | # If no uploaded docs, ingest or summarize the external resource
141 | fetch_func = get_batch_url_fetcher() # don't really need the batch aspect here
142 | html = fetch_func([message])[0]
143 | link_data = LinkData.from_raw_content(html)
144 |
145 | if link_data.error:
146 | return format_nonstreaming_answer(
147 | f"Apologies, I could not retrieve the resource `{message}`."
148 | )
149 |
150 | docs = [Document(page_content=link_data.text, metadata={"source": message})]
151 |
152 | if chat_state.chat_mode == ChatMode.INGEST_COMMAND_ID:
153 | # "/ingest https://some.url.com" command - just ingest, don't summarize
154 | res = format_nonstreaming_answer(
155 | f"The resource `{message}` has been ingested into the collection "
156 | f"`{coll_name_as_shown}`. Feel free to ask questions about the collection's "
157 | "content. If you don't need to ingest "
158 | "more resources into it, rename it with `/db rename my-cool-collection-name`."
159 | )
160 | else:
161 | res = {"answer": summarize(docs, chat_state)}
162 |
163 | # Ingest into the collection
164 | coll_metadata = {} if is_new_collection else chat_state.fetch_collection_metadata()
165 | vectorstore = ingest_into_collection(
166 | collection_name=coll_name_full,
167 | docs=docs,
168 | collection_metadata=coll_metadata,
169 | chat_state=chat_state,
170 | is_new_collection=is_new_collection,
171 | )
172 |
173 | if is_new_collection:
174 | # Switch to the newly created collection
175 | res["vectorstore"] = vectorstore
176 |
177 | return res
178 |
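179 | # Example flows (an illustrative sketch; the exact command keywords are parsed
180 | # in utils/query_parsing.py):
181 | # "/ingest https://some.url.com"     -> fetch and ingest the page (into the
182 | #                                       current collection if its name marks it
183 | #                                       as additive, otherwise into a new one)
184 | # "/ingest add https://some.url.com" -> ingest into the current collection
185 | # "/summarize https://some.url.com"  -> same ingestion logic, but the answer is
186 | #                                       a summary of the fetched content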
--------------------------------------------------------------------------------
/agents/researcher_data.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field, model_validator
2 |
3 | from utils.web import LinkData
4 |
5 |
6 | class Report(BaseModel):
7 | report_text: str
8 | sources: list[str] = Field(default_factory=list)
9 | parent_report_ids: list[str] = Field(default_factory=list)
10 | child_report_id: str | None = None
11 | evaluation: str | None = None
12 |
13 |
14 | class ResearchReportData(BaseModel):
15 | query: str
16 | search_queries: list[str]
17 | report_type: str
18 | unprocessed_links: list[str]
19 | processed_links: list[str]
20 | link_data_dict: dict[str, LinkData]
21 | max_tokens_final_context: int
22 | main_report: str = ""
23 | base_reports: list[Report] = Field(default_factory=list)
24 | combined_reports: list[Report] = Field(default_factory=list)
25 | combined_report_id_levels: list[list[str]] = Field(default_factory=list) # levels
26 | num_obtained_unprocessed_links: int = 0
27 | num_obtained_unprocessed_ok_links: int = 0
28 | num_links_from_latest_queries: int | None = None
29 | evaluation: str | None = None
30 |
31 | @model_validator(mode="after")
32 | def validate(self):
33 | if self.num_links_from_latest_queries is None:
34 | self.num_links_from_latest_queries = len(self.unprocessed_links)
35 | return self
36 |
37 | @property
38 | def num_processed_links_from_latest_queries(self) -> int:
39 | # All unprocessed links are from the latest queries so we just subtract
40 | return self.num_links_from_latest_queries - len(self.unprocessed_links)
41 |
42 | def is_base_report(self, id: str) -> bool:
43 | return not id.startswith("c")
44 |
45 | def is_report_childless(self, id: str) -> bool:
46 | return self.get_report_by_id(id).child_report_id is None
47 |
48 | def get_report_by_id(self, id) -> Report:
49 | return (
50 | self.base_reports[int(id)]
51 | if self.is_base_report(id)
52 | else self.combined_reports[int(id[1:])]
53 | )
54 |
55 | def get_parent_reports(self, report: Report) -> list[Report]:
56 | return [
57 | self.get_report_by_id(parent_id) for parent_id in report.parent_report_ids
58 | ]
59 |
60 |     def get_ancestor_ids(self, report: Report) -> list[str]:
61 |         res = []
62 |         for parent_id in report.parent_report_ids:
63 |             if self.is_base_report(parent_id):
64 |                 res.append(parent_id)
65 |             else:
66 |                 res.extend(self.get_ancestor_ids(self.get_report_by_id(parent_id)))
67 |         return res
68 | 
69 |     def get_sources(self, report: Report) -> list[str]:
70 |         res = report.sources.copy()
71 |         for parent_report in self.get_parent_reports(report):
72 |             res.extend(self.get_sources(parent_report))
73 |         return res
74 | 
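75 | # Example of the report-id scheme (an illustrative sketch): base reports are
76 | # addressed by their index as a string ("0", "1", ...), combined reports by a
77 | # "c"-prefixed index ("c0", "c1", ...). Assuming `data` is a ResearchReportData:
78 | # data.get_report_by_id("1")   # -> data.base_reports[1]
79 | # data.get_report_by_id("c0")  # -> data.combined_reports[0]
80 | # data.is_base_report("c0")    # -> False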
--------------------------------------------------------------------------------
/agents/share_manager.py:
--------------------------------------------------------------------------------
1 | from agents.dbmanager import get_access_role, is_main_owner
2 | from utils.chat_state import ChatState
3 | from utils.helpers import SHARE_COMMAND_HELP_MSG, format_nonstreaming_answer
4 | from utils.prepare import DOMAIN_NAME_FOR_SHARING
5 | from utils.query_parsing import ShareCommand, ShareRevokeSubCommand
6 | from utils.type_utils import (
7 | AccessCodeSettings,
8 | AccessCodeType,
9 | AccessRole,
10 | Props,
11 | )
12 |
13 | share_type_to_access_role = {
14 | ShareCommand.EDITOR: AccessRole.EDITOR,
15 | ShareCommand.VIEWER: AccessRole.VIEWER,
16 | ShareCommand.OWNER: AccessRole.OWNER,
17 | }
18 |
19 |
20 | def handle_share_command(chat_state: ChatState) -> Props:
21 | """Handle the /share command."""
22 | share_params = chat_state.parsed_query.share_params
23 | if share_params.share_type == ShareCommand.NONE:
24 | return format_nonstreaming_answer(SHARE_COMMAND_HELP_MSG)
25 |
26 | # Check that the user is an owner
27 | access_role = get_access_role(chat_state)
28 | if access_role.value < AccessRole.OWNER.value:
29 | return format_nonstreaming_answer(
30 | "Apologies, you don't have owner-level access to this collection. "
31 | f"Your current access level: {access_role.name.lower()}."
32 | )
33 |     # NOTE: this introduces redundant fetching of metadata (in get_access_role
34 |     # and save_access_code_settings). Can optimize later.
35 |
36 | if share_params.share_type in (
37 | ShareCommand.EDITOR,
38 | ShareCommand.VIEWER,
39 | ShareCommand.OWNER,
40 | ):
41 | # Check that the access code is not empty
42 | if (access_code := share_params.access_code) is None:
43 | return format_nonstreaming_answer(
44 | "Apologies, you need to specify an access code."
45 | f"\n\n{SHARE_COMMAND_HELP_MSG}"
46 | )
47 |
48 | # Check that the access code type is not empty
49 | if (code_type := share_params.access_code_type) is None:
50 | return format_nonstreaming_answer(
51 | "Apologies, you need to specify an access code type."
52 | f"\n\n{SHARE_COMMAND_HELP_MSG}"
53 | )
54 | if code_type != AccessCodeType.NEED_ALWAYS:
55 | return format_nonstreaming_answer(
56 | "Apologies, Dmitriy hasn't implemented this code type for me yet."
57 | )
58 |
59 | # Check that the access code contains only letters and numbers
60 | if not access_code.isalnum():
61 | return format_nonstreaming_answer(
62 | "Apologies, the access code can only contain letters and numbers."
63 | )
64 |
65 | # Save the access code and its settings
66 | access_code_settings = AccessCodeSettings(
67 | code_type=code_type,
68 | access_role=share_type_to_access_role[share_params.share_type],
69 | )
70 | chat_state.save_access_code_settings(
71 | access_code=share_params.access_code,
72 | access_code_settings=access_code_settings,
73 | )
74 |
75 | # Form share link
76 | link = (
77 | f"{DOMAIN_NAME_FOR_SHARING}?collection={chat_state.vectorstore.name}"
78 | f"&access_code={share_params.access_code}"
79 | )
80 |
81 | # Return the response with the share link
82 | return format_nonstreaming_answer(
83 | "The current collection has been shared. Here's the link you can send "
84 | "to the people you want to grant access to:\n\n"
85 | f"```\n{link}\n```"
86 | )
87 |
88 | if share_params.share_type == ShareCommand.REVOKE:
89 | collection_permissions = chat_state.get_collection_permissions()
90 | match share_params.revoke_type:
91 | case ShareRevokeSubCommand.ALL_CODES:
92 | collection_permissions.access_code_to_settings = {}
93 | ans = "All access codes have been revoked."
94 |
95 | case ShareRevokeSubCommand.ALL_USERS:
96 | # If user is not the main owner, don't delete their owner status
97 | saved_entries = {}
98 | if is_main_owner(chat_state):
99 | ans = "All users except you have been revoked."
100 | else:
101 | try:
102 | saved_entries[chat_state.user_id] = (
103 | collection_permissions.user_id_to_settings[
104 | chat_state.user_id
105 | ]
106 | )
107 | ans = (
108 | "All users except the main owner and you have been revoked."
109 | )
110 | except KeyError:
111 | ans = "All users except the main owner have been revoked (including you)."
112 |
113 | # Clear all user settings except saved_entries
114 | collection_permissions.user_id_to_settings = saved_entries
115 |
116 | case ShareRevokeSubCommand.CODE:
117 | code = (share_params.code_or_user_to_revoke or "").strip()
118 | if not code:
119 | return format_nonstreaming_answer(
120 | "Apologies, you need to specify an access code to revoke."
121 | )
122 | try:
123 | collection_permissions.access_code_to_settings.pop(code)
124 | ans = "The access code has been revoked."
125 | except KeyError:
126 | return format_nonstreaming_answer(
127 | "Apologies, the access code you specified doesn't exist."
128 | )
129 |
130 | case ShareRevokeSubCommand.USER:
131 | user_id = (share_params.code_or_user_to_revoke or "").strip()
132 | if not user_id:
133 | return format_nonstreaming_answer(
134 | "Apologies, you need to specify a user to revoke."
135 | )
136 | if user_id == "public":
137 | user_id = ""
138 | try:
139 |                     collection_permissions.user_id_to_settings.pop(user_id)
140 |                     ans = "The user has been revoked."
141 |                 except KeyError:
142 |                     return format_nonstreaming_answer(
143 |                         "Apologies, the user you specified doesn't have a stored access role."
144 |                     )
145 |             case _:
146 |                 return format_nonstreaming_answer(
147 |                     "Apologies, Dmitriy hasn't implemented this share subcommand for me yet."
148 |                 )
149 | 
150 |         chat_state.save_collection_permissions(
151 |             collection_permissions, use_cached_metadata=True
152 |         )
153 |         return format_nonstreaming_answer(ans)
154 | 
155 |     return format_nonstreaming_answer(
156 |         "Apologies, Dmitriy hasn't implemented this share subcommand for me yet."
157 |     )
158 | 
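159 | # Example outcome (an illustrative sketch; the exact command syntax is defined
160 | # by utils/query_parsing.py and SHARE_COMMAND_HELP_MSG): sharing the current
161 | # collection with viewer access under access code "my-code" yields a link of
162 | # the form
163 | #   <DOMAIN_NAME_FOR_SHARING>?collection=<collection-name>&access_code=my-code
164 | # and revoking that code later removes it from the collection's permissions.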
--------------------------------------------------------------------------------
/agents/websearcher_quick.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | from agentblocks.websearch import get_links_from_search_results
4 | from components.llm import get_prompt_llm_chain
5 | from utils.async_utils import gather_tasks_sync, make_sync
6 | from utils.chat_state import ChatState
7 | from utils.lang_utils import get_num_tokens, limit_tokens_in_texts
8 | from utils.prepare import CONTEXT_LENGTH
9 | from utils.prompts import RESEARCHER_PROMPT_SIMPLE
10 | from utils.web import (
11 | afetch_urls_in_parallel_playwright,
12 | get_text_from_html,
13 | remove_failed_fetches,
14 | )
15 | from langchain_community.utilities import GoogleSerperAPIWrapper
16 |
17 |
18 | def get_related_websearch_queries(message: str):
19 | search = GoogleSerperAPIWrapper()
20 | search_results = search.results(message)
21 | # print("search results:", json.dumps(search_results, indent=4))
22 | related_searches = [x["query"] for x in search_results.get("relatedSearches", [])]
23 | people_also_ask = [x["question"] for x in search_results.get("peopleAlsoAsk", [])]
24 |
25 | return related_searches, people_also_ask
26 |
27 |
28 | def get_websearcher_response_quick(
29 | chat_state: ChatState, max_queries: int = 3, max_total_links: int = 9
30 | ):
31 | """
32 | Perform quick web research on a query, with only one call to the LLM.
33 |
34 | It gets related queries to search for by examining the "related searches" and
35 | "people also ask" sections of the Google search results for the query. It then
36 | searches for each of these queries on Google, and gets the top links for each
37 |     query. It then fetches the content from each link, shortens the texts to fit
38 |     them all into one context window, and sends the result to the LLM to generate a report.
39 | """
40 | message = chat_state.message
41 | related_searches, people_also_ask = get_related_websearch_queries(message)
42 |
43 | # Get queries to search for
44 | queries = [message] + [
45 | query for query in related_searches + people_also_ask if query != message
46 | ]
47 | queries = queries[:max_queries]
48 | print("queries:", queries)
49 |
50 | # Get links
51 | search = GoogleSerperAPIWrapper()
52 | search_tasks = [search.aresults(query) for query in queries]
53 | search_results = gather_tasks_sync(search_tasks)
54 | links = get_links_from_search_results(search_results)[:max_total_links]
55 | print("Links:", links)
56 |
57 | # Get content from links, measuring time taken
58 | t_start = datetime.now()
59 | print("Fetching content from links...")
60 | # htmls = make_sync(afetch_urls_in_parallel_chromium_loader)(links)
61 | htmls = make_sync(afetch_urls_in_parallel_playwright)(links)
62 | # htmls = make_sync(afetch_urls_in_parallel_html_loader)(links)
63 | # htmls = fetch_urls_with_lc_html_loader(links) # takes ~20s, slow
64 | t_fetch_end = datetime.now()
65 |
66 | print("Processing content...")
67 | texts = [get_text_from_html(html) for html in htmls]
68 | ok_texts, ok_links = remove_failed_fetches(texts, links)
69 |
70 | processed_texts, token_counts = limit_tokens_in_texts(
71 | ok_texts, max_tot_tokens=int(CONTEXT_LENGTH * 0.5)
72 | )
73 | processed_texts = [
74 | f"SOURCE: {link}\nCONTENT:\n{text}\n====="
75 | for text, link in zip(processed_texts, ok_links)
76 | ]
77 | texts_str = "\n\n".join(processed_texts)
78 | t_end = datetime.now()
79 |
80 | print("CONTEXT FOR GENERATING REPORT:\n")
81 | print(texts_str + "\n" + "=" * 100 + "\n")
82 | print("Original number of links:", len(links))
83 | print("Number of links after removing unsuccessfully fetched ones:", len(ok_links))
84 | print("Time taken to fetch sites:", t_fetch_end - t_start)
85 | print("Time taken to process sites:", t_end - t_fetch_end)
86 | print("Number of resulting tokens:", get_num_tokens(texts_str))
87 |
88 | chain = get_prompt_llm_chain(
89 | RESEARCHER_PROMPT_SIMPLE,
90 | llm_settings=chat_state.bot_settings,
91 | api_key=chat_state.openai_api_key,
92 | callbacks=chat_state.callbacks,
93 | stream=True,
94 | )
95 | answer = chain.invoke({"texts_str": texts_str, "query": message})
96 | return {"answer": answer}
97 |
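98 | # Example (an illustrative sketch; requires SERPER_API_KEY in the environment):
99 | # related, also_asked = get_related_websearch_queries("running llama2 locally")
100 | # print(related[:3], also_asked[:3])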
--------------------------------------------------------------------------------
/announcements/version-0-2-6.md:
--------------------------------------------------------------------------------
1 | ### GPT-4o mini and other updates
2 |
3 | First, DocDocGo had a teeny-tiny issue with its database growing too gigantic for its VM. The issue is now fixed, so if you tried to use DocDocGo and couldn't access it, my apologies - it's now back in action! :muscle:
4 |
5 | Updates:
6 |
7 | **For users:**
8 |
9 | 1. The default model is now **GPT-4o mini** (thanks, OpenAI! :rocket:) instead of GPT-3.5, so expect a significant improvement in the quality of responses. If you use DocDocGo with your own OpenAI API key, you can also switch to GPT-4o, GPT-4, or the old friend GPT-3.5.
10 | 2. The database has been upgraded and cleaned up for better performance.
11 |
12 | **For developers:**
13 |
14 | 1. The [developer docs](https://github.com/reasonmethis/docdocgo-core/blob/main/README-FOR-DEVELOPERS.md) have been significantly expanded to include lots of goodies on hosting the Chroma database server on AWS and GCP, running locally with Docker, cleaning it, etc. There is a lot more that can (and should) be added, so if something is unclear or you have any questions or suggestions, feel free to DM me here or on [LinkedIn](https://www.linkedin.com/in/dmitriyvasilyuk/).
15 | 2. In addition to cleaning up the database, Langchain and Chroma DB have been upgraded to the latest versions (0.2.10 and 0.5.4, respectively). A slew of deprecated imports have been updated, and a bug related to a breaking change in Langchain's `Chroma` has been found and squashed :beetle:.
16 |
17 | I hope you enjoy using DocDocGo with the shiny new GPT-4o mini! Because of the model switch and other major changes, it's possible that there may be some hiccups that I haven't caught yet. Feel free to reach out on [GitHub](https://github.com/reasonmethis/docdocgo-core) if you see any :spider:, and the exterminator will be dispatched posthaste!
18 |
19 | ## Formatted for LinkedIn
20 |
21 | ### DocDocGo updates: GPT-4o mini and other news ###
22 |
23 | - App: https://docdocgo.streamlit.app
24 | - GitHub: https://github.com/reasonmethis/docdocgo-core
25 |
26 | 🚀 The default model is now GPT-4o mini instead of GPT-3.5, so expect a significant improvement in the quality of responses. If you use DocDocGo with your own OpenAI API key, you can also switch to GPT-4o, GPT-4, or the old friend GPT-3.5.
27 |
28 | ⚡ The database has been upgraded and cleaned up for better performance.
29 |
30 | 📝 The developer docs (https://github.com/reasonmethis/docdocgo-core/blob/main/README-FOR-DEVELOPERS.md) have been significantly expanded to include lots of goodies on hosting the Chroma database server on AWS and GCP, running locally with Docker, cleaning it, etc. There is a lot more that can (and should) be added, so if you have any questions or suggestions, feel free to DM me.
31 |
32 | 🔧 In addition to cleaning up the database, Langchain and Chroma DB have been upgraded to the latest versions (0.2.10 and 0.5.4, respectively). A slew of deprecated imports have been updated, and a bug related to a breaking change in Langchain's `Chroma` has been found and squashed 🪲.
33 |
34 | 📢 If this is your first time hearing about DocDocGo, it's an AI assistant that helps you find information that requires sifting through far more than the first page of Google search results. It's kind of like if ChatGPT, Perplexity and GPT Researcher had a 👶
35 |
36 | I hope you enjoy using DocDocGo with the shiny new GPT-4o mini! Because of the model switch and other major changes, it's possible that there may be some hiccups that I haven't caught yet. Feel free to reach out here or on GitHub (https://github.com/reasonmethis/docdocgo-core) if you see any 🕷️, and the exterminator will be dispatched posthaste!
--------------------------------------------------------------------------------
/announcements/version-0-2.md:
--------------------------------------------------------------------------------
1 | :tada:**Announcing DocDocGo version 0.2**:tada:
2 |
3 | DocDocGo is growing up! Let's celebrate by describing the new features.
4 |
5 | **1. DDG finally has a logo that's not just the 🦉 character!**
6 |
7 | Its [design](https://github.com/reasonmethis/docdocgo-core/blob/9724fe27cd1a5812861feb018c2b45bed55af968/media/ddg-logo.png) continues the noble tradition of trying too hard to be clever, starting with the name "DocDocGo" :grin:
8 |
9 | @StreamlitTeam continues to ship - the new `st.logo` feature arrived just in time!
10 |
11 | **2. Re-engineered collections**
12 |
13 | With over 150 public collections and growing, it was time to manage them better. Now, collections track how recently they were created/updated. Use:
14 |
15 | - `/db list` - see the 20 most recent collections (in a cute table)
16 |
17 | - `/db list 21+` - see the next 20, and so on
18 |
19 | - `/db list blah` - list collections with "blah" in the name
20 |
21 | DDG will remind you of the available commands when you use `/db` or `/db list`.
22 |
23 | **Safety improvement:** Deleting a collection by number, e.g. `/db delete 42`, now only works if it was first listed by `/db list`.
24 |
25 | **3. Default modes**
26 |
27 | Tired of typing `/research` or `/research heatseek`? Let your fingers rest by selecting a default mode in the Streamlit sidebar.
28 |
29 | You can still override it when you need to - just type the desired command as usual, for example: `/help What's the difference between regular research and heatseek?`.
30 |
31 | **4. UI improvements**
32 |
33 | The UI has been refreshed with a new intro screen that has:
34 |
35 | - The one-click sample commands to get you started if you're new
36 |
37 | - The new and improved welcome message
38 |
39 | Take the new version for a spin and let me know what you think! :rocket:
40 |
41 | https://docdocgo.streamlit.app
42 |
--------------------------------------------------------------------------------
/components/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/components/__init__.py
--------------------------------------------------------------------------------
/components/chroma_ddg.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Callable, Optional
3 |
4 | from chromadb import ClientAPI, Collection, HttpClient, PersistentClient
5 | from chromadb.api.types import Where # , WhereDocument
6 | from chromadb.config import Settings
7 | from langchain_community.vectorstores.chroma import _results_to_docs_and_scores
8 | from langchain_core.embeddings import Embeddings
9 |
10 | from components.openai_embeddings_ddg import get_openai_embeddings
11 | from utils.prepare import (
12 | CHROMA_SERVER_AUTHN_CREDENTIALS,
13 | CHROMA_SERVER_HOST,
14 | CHROMA_SERVER_HTTP_PORT,
15 | USE_CHROMA_VIA_HTTP,
16 | VECTORDB_DIR,
17 | get_logger,
18 | )
19 | from utils.type_utils import DDGError
20 | from langchain_community.vectorstores import Chroma
21 | from langchain_core.documents import Document
22 |
23 | logger = get_logger()
24 |
25 |
26 | class CollectionDoesNotExist(DDGError):
27 | """Exception raised when a collection does not exist."""
28 |
29 | pass # REVIEW
30 |
31 |
32 | class ChromaDDG(Chroma):
33 | """
34 | Modified Chroma vectorstore for DocDocGo.
35 |
36 | Specifically:
37 | 1. It enables the kwargs passed to similarity_search_with_score
38 | to be passed to the __query_collection method rather than ignored,
39 | which allows us to pass the 'where_document' parameter.
40 | 2. __init__ is overridden to allow the option of using get_collection (which doesn't create
41 | a collection if it doesn't exist) rather than always using get_or_create_collection (which does).
42 | """
43 |
44 | def __init__(
45 | self,
46 | *,
47 | collection_name: str,
48 | client: ClientAPI,
49 | create_if_not_exists: bool,
50 | embedding_function: Optional[Embeddings] = None,
51 | persist_directory: Optional[str] = None,
52 | client_settings: Optional[Settings] = None,
53 | collection_metadata: Optional[dict] = None,
54 | relevance_score_fn: Optional[Callable[[float], float]] = None,
55 | ) -> None:
56 | """Initialize with a Chroma client."""
57 | self._client_settings = client_settings
58 | self._client = client
59 | self._persist_directory = persist_directory
60 |
61 | self._embedding_function = embedding_function
62 | self.override_relevance_score_fn = relevance_score_fn
63 | logger.info(f"ChromaDDG init: {create_if_not_exists=}, {collection_name=}")
64 |
65 | if not create_if_not_exists and collection_metadata is not None:
66 | # We must check if the collection exists in this case because when metadata
67 | # is passed, we must use get_or_create_collection to update the metadata
68 | if not exists_collection(collection_name, client):
69 | raise CollectionDoesNotExist()
70 |
71 | if create_if_not_exists or collection_metadata is not None:
72 | self._collection = self._client.get_or_create_collection(
73 | name=collection_name,
74 | embedding_function=None,
75 | metadata=collection_metadata,
76 | )
77 | else:
78 | try:
79 | self._collection = self._client.get_collection(
80 | name=collection_name,
81 | embedding_function=None,
82 | )
83 | except Exception as e:
84 | logger.info(f"Failed to get collection {collection_name}: {str(e)}")
85 | raise CollectionDoesNotExist() if "does not exist" in str(e) else e
86 |
87 | def __bool__(self) -> bool:
88 | """Always return True to avoid ambiguity of what False could mean."""
89 | return True
90 |
91 | @property
92 | def name(self) -> str:
93 | """Name of the underlying chromadb collection."""
94 | return self._collection.name
95 |
96 | @property
97 | def collection(self) -> Collection:
98 | """The underlying chromadb collection."""
99 | return self._collection
100 |
101 | @property
102 | def client(self) -> ClientAPI:
103 | """The underlying chromadb client."""
104 | return self._client
105 |
106 | def get_cached_collection_metadata(self) -> dict[str, Any] | None:
107 | """Get locally cached metadata for the underlying chromadb collection."""
108 | return self._collection.metadata
109 |
110 | def fetch_collection_metadata(self) -> dict[str, Any]:
111 | """Fetch metadata for the underlying chromadb collection."""
112 | logger.info(f"Fetching metadata for collection {self.name}")
113 | self._collection = self._client.get_collection(
114 | self.name, embedding_function=self._collection._embedding_function
115 | )
116 | logger.info(f"Fetched metadata for collection {self.name}")
117 | return self._collection.metadata
118 |
119 | def save_collection_metadata(self, metadata: dict[str, Any]) -> None:
120 | """Set metadata for the underlying chromadb collection."""
121 | self._collection.modify(metadata=metadata)
122 |
123 | def rename_collection(self, new_name: str) -> None:
124 | """Rename the underlying chromadb collection."""
125 | self._collection.modify(name=new_name)
126 |
127 | def delete_collection(self, collection_name: str) -> None:
128 | """Delete the underlying chromadb collection."""
129 | self._client.delete_collection(collection_name)
130 |
131 | def similarity_search_with_score(
132 | self,
133 | query: str,
134 | k: int, # = DEFAULT_K,
135 | filter: Where | None = None,
136 | **kwargs: Any,
137 | ) -> list[tuple[Document, float]]:
138 | """
139 | Run similarity search with Chroma with distance.
140 |
141 | Args:
142 | query (str): Query text to search for.
143 | k (int): Number of results to return.
144 | filter (Where | None): Filter by metadata. Corresponds to the chromadb 'where'
145 | parameter. Defaults to None.
146 | **kwargs: Additional keyword arguments. Only 'where_document' is used, if present.
147 |
148 | Returns:
149 | list[tuple[Document, float]]: list of documents most similar to
150 | the query text and cosine distance in float for each.
151 | """
152 |
153 | # Determine if the passed kwargs contain a 'where_document' parameter
154 | # If so, we'll pass it to the __query_collection method
155 | try:
156 | possible_where_document_kwarg = {"where_document": kwargs["where_document"]}
157 | except KeyError:
158 | possible_where_document_kwarg = {}
159 |
160 | # Query by text or embedding, depending on whether an embedding function is present
161 | if self._embedding_function is None:
162 | results = self._Chroma__query_collection(
163 | query_texts=[query],
164 | n_results=k,
165 | where=filter,
166 | **possible_where_document_kwarg,
167 | )
168 | else:
169 | query_embedding = self._embedding_function.embed_query(query)
170 | results = self._Chroma__query_collection(
171 | query_embeddings=[query_embedding],
172 | n_results=k,
173 | where=filter,
174 | **possible_where_document_kwarg,
175 | )
176 |
177 | return _results_to_docs_and_scores(results)
178 |
179 |
180 | def exists_collection(
181 | collection_name: str,
182 | client: ClientAPI,
183 | ) -> bool:
184 | """
185 | Check if a collection exists.
186 | """
187 | # NOTE: Alternative: return collection_name in {x.name for x in client.list_collections()}
188 | try:
189 | client.get_collection(collection_name)
190 | return True
191 | except Exception as e: # Exception: {ValueError: "Collection 'test' does not exist"}
192 | if "does not exist" in str(e):
193 | return False # collection does not exist
194 | raise e
195 |
196 |
197 | def initialize_client(use_chroma_via_http: bool = USE_CHROMA_VIA_HTTP) -> ClientAPI:
198 | """
199 | Initialize a chroma client.
200 | """
201 | if use_chroma_via_http:
202 | return HttpClient(
203 | host=CHROMA_SERVER_HOST, # must provide host and port explicitly...
204 | port=CHROMA_SERVER_HTTP_PORT, # ...if the env vars are different from defaults
205 | settings=Settings(
206 | chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
207 | chroma_auth_token_transport_header="X-Chroma-Token",
208 | chroma_client_auth_credentials=CHROMA_SERVER_AUTHN_CREDENTIALS,
209 | anonymized_telemetry=False,
210 | ),
211 | )
212 | if not isinstance(VECTORDB_DIR, str) or not os.path.isdir(VECTORDB_DIR):
213 | # NOTE: interestingly, isdir(None) returns True, hence the additional check
214 | raise ValueError(f"Invalid chromadb path: {VECTORDB_DIR}")
215 |
216 | return PersistentClient(VECTORDB_DIR, settings=Settings(anonymized_telemetry=False))
217 | # return Client(Settings(chroma_db_impl="duckdb+parquet", persist_directory=path))
218 |
219 |
220 | def ensure_chroma_client(client: ClientAPI | None = None) -> ClientAPI:
221 | """
222 | Ensure that a chroma client is initialized and return it.
223 | """
224 | return client or initialize_client()
225 |
226 |
227 | def get_vectorstore_using_openai_api_key(
228 | collection_name: str,
229 | *,
230 | openai_api_key: str,
231 | client: ClientAPI | None = None,
232 | create_if_not_exists: bool = False,
233 | ) -> ChromaDDG:
234 | """
235 | Load a ChromaDDG vectorstore from a given collection name and OpenAI API key.
236 | """
237 | return ChromaDDG(
238 | client=ensure_chroma_client(client),
239 | collection_name=collection_name,
240 | create_if_not_exists=create_if_not_exists,
241 | embedding_function=get_openai_embeddings(openai_api_key),
242 | )
243 |
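244 | # Example usage (an illustrative sketch; assumes a valid OpenAI API key and,
245 | # for the local setup, an existing VECTORDB_DIR):
246 | # vs = get_vectorstore_using_openai_api_key(
247 | #     "my-collection", openai_api_key="sk-...", create_if_not_exists=True
248 | # )
249 | # docs_and_scores = vs.similarity_search_with_score(
250 | #     "what is X?", k=4, where_document={"$contains": "keyword"}
251 | # )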
--------------------------------------------------------------------------------
/components/chroma_ddg_retriever.py:
--------------------------------------------------------------------------------
1 | from typing import Any, ClassVar
2 |
3 | from chromadb.api.types import Where, WhereDocument
4 | from langchain_core.documents import Document
5 | from pydantic import Field
6 |
7 | from utils.helpers import DELIMITER, lin_interpolate
8 | from utils.lang_utils import expand_chunks
9 | from utils.prepare import CONTEXT_LENGTH, EMBEDDINGS_MODEL_NAME
10 | from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun
11 | from langchain_core.language_models import BaseLanguageModel
12 | from langchain_core.vectorstores import VectorStoreRetriever
13 |
14 |
15 | class ChromaDDGRetriever(VectorStoreRetriever):
16 | """
17 | A retriever that uses a ChromaDDG vectorstore to find relevant documents.
18 |
19 | Compared to a generic VectorStoreRetriever, this retriever has the additional
20 | feature of accepting the `where` and `where_document` parameters in its
21 | `get_relevant_documents` method. These parameters are used to filter the
22 | documents by their metadata and contained text, respectively.
23 |
24 | NOTE: even though the underlying vectorstore is not explicitly required to
25 | be a ChromaDDG, it must be a vectorstore that supports the `where` and
26 | `where_document` parameters in its `similarity_search`-type methods.
27 | """
28 |
29 | llm_for_token_counting: BaseLanguageModel | None
30 |
31 | verbose: bool = False # print similarity scores and other info
32 | k_overshot = 20 # number of docs to return initially (prune later)
33 | score_threshold_overshot = (
34 | -12345.0
35 | ) # score threshold to use initially (prune later)
36 |
37 | k_min = 2 # min number of docs to return after pruning
38 | score_threshold_min = (
39 |         0.61 if EMBEDDINGS_MODEL_NAME == "text-embedding-ada-002" else -0.1
40 | ) # use k_min if score of k_min'th doc is <= this
41 |
42 | k_max = 10 # max number of docs to return after pruning
43 | score_threshold_max = (
44 |         0.76 if EMBEDDINGS_MODEL_NAME == "text-embedding-ada-002" else 0.2
45 | ) # use k_max if score of k_max'th doc is >= this
46 |
47 | max_total_tokens = int(CONTEXT_LENGTH * 0.5) # consistent with ChatWithDocsChain
48 | max_average_tokens_per_chunk = int(max_total_tokens / k_max)
49 |
50 | # get_relevant_documents() must return only docs, but we'll save scores here
51 | similarities: list = Field(default_factory=list)
52 |
53 |     allowed_search_types: ClassVar[tuple[str, ...]] = (
54 | "similarity",
55 | # "similarity_score_threshold", # NOTE can add at some point
56 | "mmr",
57 | "similarity_ddg",
58 | )
59 |
60 | def _get_relevant_documents(
61 | self,
62 | query: str,
63 | *,
64 | run_manager: CallbackManagerForRetrieverRun,
65 | filter: Where | None = None, # For metadata (Langchain naming convention)
66 | where_document: WhereDocument | None = None, # Filter by text in document
67 | **kwargs: Any, # For additional search params
68 | ) -> list[Document]:
69 | # Combine global search kwargs with per-query search params passed here
70 | search_kwargs = self.search_kwargs | kwargs
71 | if filter is not None:
72 | search_kwargs["filter"] = filter
73 | if where_document is not None:
74 | search_kwargs["where_document"] = where_document
75 |
76 | # Perform search depending on search type
77 | if self.search_type == "similarity":
78 | return self.vectorstore.similarity_search(query, **search_kwargs)
79 | elif self.search_type == "mmr":
80 | return self.vectorstore.max_marginal_relevance_search(
81 | query, **search_kwargs
82 | )
83 |
84 | # Main search method used by DocDocGo
85 | assert self.search_type == "similarity_ddg", "Invalid search type"
86 | assert str(type(self.vectorstore)).endswith("ChromaDDG'>"), "Bad vectorstore"
87 |
88 | # First, get more docs than we need, then we'll pare them down
89 | # NOTE this is because apparently Chroma can miss even the most relevant doc
90 | # if k (num docs to return) is not high (e.g. "Who is Big Keetie?", k = 10)
91 | # k_overshot = max(k := search_kwargs["k"], self.k_overshot)
92 | # score_threshold_overshot = min(
93 | # score_threshold := search_kwargs["score_threshold"],
94 | # self.score_threshold_overshot,
95 | # ) # usually simply 0
96 |
97 | docs_and_similarities_overshot = (
98 | self.vectorstore.similarity_search_with_relevance_scores(
99 | query,
100 | **(
101 | search_kwargs
102 | | {
103 | "k": self.k_overshot,
104 | "score_threshold": self.score_threshold_overshot,
105 | }
106 | ),
107 | )
108 | )
109 |
110 | if self.verbose:
111 | for doc, sim in docs_and_similarities_overshot:
112 | print(f"[SIMILARITY: {sim:.2f}] {repr(doc.page_content[:60])}")
113 | print(f"Before paring down: {len(docs_and_similarities_overshot)} docs.")
114 |
115 | # Now, pare down the results
116 | chunks: list[Document] = []
117 | self.similarities: list[float] = []
118 | for k, (doc, sim) in enumerate(docs_and_similarities_overshot, start=1):
119 | # If we've already found enough docs, stop
120 | if k > self.k_max:
121 | break
122 |
123 | # Find score_threshold if we were to have k docs
124 | score_threshold_if_stop = lin_interpolate(
125 | k,
126 | self.k_min,
127 | self.k_max,
128 | self.score_threshold_min,
129 | self.score_threshold_max,
130 | )
131 |
132 | # If we have min num of docs and similarity is below the threshold, stop
133 | if k > self.k_min and sim < score_threshold_if_stop:
134 | break
135 |
136 | # Otherwise, add the doc to the list and keep going
137 | chunks.append(doc)
138 | self.similarities.append(sim)
139 |
140 | if self.verbose:
141 | print(f"After paring down: {len(chunks)} docs.")
142 | if chunks:
143 | print(
144 | f"Similarities from {self.similarities[-1]:.2f} to {self.similarities[0]:.2f}"
145 | )
146 | print(DELIMITER)
147 |
148 | # Get the parent documents for the chunks
149 |         try:
150 |             parent_ids = [chunk.metadata["parent_id"] for chunk in chunks]
151 |         except KeyError:
152 |             # If it's an older collection, without parent docs, just return the chunks
153 |             return chunks
154 |         unique_parent_ids = list(set(parent_ids))
155 |         rsp = self.vectorstore.collection.get(unique_parent_ids)
156 | 
157 |         parent_docs_by_id = {
158 |             id: Document(page_content=text, metadata=metadata)
159 |             for id, text, metadata in zip(
160 |                 rsp["ids"], rsp["documents"], rsp["metadatas"]
161 |             )
162 |         }
163 | 
164 |         # Expand chunks using the parent docs
165 |         max_total_tokens = min(
166 |             self.max_total_tokens, self.max_average_tokens_per_chunk * len(chunks)
167 |         )
168 |         expanded_chunks = expand_chunks(
169 |             chunks,
170 |             parent_docs_by_id,
171 |             max_total_tokens,
172 |             llm_for_token_counting=self.llm_for_token_counting,
173 |         )
174 |         return expanded_chunks
175 | 
176 |     async def _aget_relevant_documents(
177 |         self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
178 |     ) -> list[Document]:
179 |         raise NotImplementedError("Asynchronous retrieval not yet supported.")
180 | 
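181 | # Example (an illustrative sketch; assumes `vectorstore` is an existing
182 | # ChromaDDG instance):
183 | # retriever = ChromaDDGRetriever(
184 | #     vectorstore=vectorstore,
185 | #     search_type="similarity_ddg",
186 | #     llm_for_token_counting=None,
187 | # )
188 | # docs = retriever.get_relevant_documents(
189 | #     "Who is Big Keetie?", where_document={"$contains": "Keetie"}
190 | # )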
--------------------------------------------------------------------------------
/components/llm.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from uuid import UUID
3 | from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
4 | from langchain_openai import AzureChatOpenAI, ChatOpenAI
5 | from streamlit.delta_generator import DeltaGenerator
6 |
7 | from utils.helpers import DELIMITER, MAIN_BOT_PREFIX
8 | from utils.lang_utils import msg_list_chat_history_to_string
9 | from utils.prepare import (
10 | CHAT_DEPLOYMENT_NAME,
11 | IS_AZURE,
12 | LLM_REQUEST_TIMEOUT,
13 | )
14 | from utils.streamlit.helpers import fix_markdown
15 | from utils.type_utils import BotSettings, CallbacksOrNone
16 | from langchain_core.callbacks import BaseCallbackHandler
17 | from langchain_core.language_models import BaseChatModel
18 | from langchain_core.output_parsers import StrOutputParser
19 | from langchain_core.prompt_values import ChatPromptValue
20 | from langchain_core.prompts import PromptTemplate
21 |
22 |
23 | class CallbackHandlerDDGStreamlit(BaseCallbackHandler):
24 | def __init__(self, container: DeltaGenerator, end_str: str = ""):
25 | self.container = container
26 | self.buffer = ""
27 | self.end_str = end_str
28 | self.end_str_printed = False
29 |
30 | def on_llm_new_token(
31 | self,
32 | token: str,
33 | *,
34 | chunk: GenerationChunk | ChatGenerationChunk | None = None,
35 | run_id: UUID,
36 | parent_run_id: UUID | None = None,
37 | **kwargs: Any,
38 | ) -> None:
39 | self.buffer += token
40 | self.container.markdown(fix_markdown(self.buffer))
41 |
42 | def on_llm_end(
43 | self,
44 | response: LLMResult,
45 | *,
46 | run_id: UUID,
47 | parent_run_id: UUID | None = None,
48 | **kwargs: Any,
49 | ) -> None:
50 | """Run when LLM ends running."""
51 | if self.end_str:
52 | self.end_str_printed = True
53 | self.container.markdown(fix_markdown(self.buffer + self.end_str))
54 |
55 |
56 | class CallbackHandlerDDGConsole(BaseCallbackHandler):
57 | def __init__(self, init_str: str = MAIN_BOT_PREFIX):
58 | self.init_str = init_str
59 |
60 | def on_llm_start(
61 | self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
62 | ) -> None:
63 | print(self.init_str, end="", flush=True)
64 |
65 | def on_llm_new_token(self, token, **kwargs) -> None:
66 | print(token, end="", flush=True)
67 |
68 | def on_llm_end(self, *args, **kwargs) -> None:
69 | print()
70 |
71 | def on_retry(self, *args, **kwargs) -> None:
72 | print(f"ON_RETRY: \nargs = {args}\nkwargs = {kwargs}")
73 |
74 |
75 | def get_llm_with_callbacks(
76 | settings: BotSettings, api_key: str | None = None, callbacks: CallbacksOrNone = None
77 | ) -> BaseChatModel:
78 | """
79 | Returns a chat model instance (either AzureChatOpenAI or ChatOpenAI, depending
80 | on the value of IS_AZURE). In the case of AzureChatOpenAI, the model is
81 | determined by CHAT_DEPLOYMENT_NAME (and other Azure-specific environment variables),
82 |     not by settings.llm_model_name.
83 | """
84 | if IS_AZURE:
85 | llm = AzureChatOpenAI(
86 | deployment_name=CHAT_DEPLOYMENT_NAME,
87 | temperature=settings.temperature,
88 | request_timeout=LLM_REQUEST_TIMEOUT,
89 | streaming=True, # seems to help with timeouts
90 | callbacks=callbacks,
91 | )
92 | else:
93 | llm = ChatOpenAI(
94 | api_key=api_key or "", # don't allow None, no implicit key from env
95 | model=settings.llm_model_name,
96 | temperature=settings.temperature,
97 | request_timeout=LLM_REQUEST_TIMEOUT,
98 | streaming=True,
99 | callbacks=callbacks,
100 | verbose=True, # tmp
101 | )
102 | return llm
103 |
104 |
105 | def get_llm(
106 | settings: BotSettings,
107 | api_key: str | None = None,
108 | callbacks: CallbacksOrNone = None,
109 | stream=False,
110 | init_str=MAIN_BOT_PREFIX,
111 | ) -> BaseChatModel:
112 | """
113 | Return a chat model instance (either AzureChatOpenAI or ChatOpenAI, depending
114 | on the value of IS_AZURE). If callbacks is passed, it will be used as the
115 |     callbacks for the chat model. Otherwise, if stream is True, a CallbackHandlerDDGConsole
116 |     will be used with the passed init_str.
117 | """
118 | if callbacks is None:
119 | callbacks = [CallbackHandlerDDGConsole(init_str)] if stream else []
120 | return get_llm_with_callbacks(settings, api_key, callbacks)
121 |
122 |
123 | def get_prompt_llm_chain(
124 | prompt: PromptTemplate,
125 | llm_settings: BotSettings,
126 | api_key: str | None = None,
127 | print_prompt=False,
128 | **kwargs,
129 | ):
130 | if not print_prompt:
131 | return prompt | get_llm(llm_settings, api_key, **kwargs) | StrOutputParser()
132 |
133 | def print_and_return(thing):
134 | if isinstance(thing, ChatPromptValue):
135 | print(f"PROMPT:\n{msg_list_chat_history_to_string(thing.messages)}")
136 | else:
137 | print(f"PROMPT:\n{type(thing)}\n{thing}")
138 | print(DELIMITER)
139 | return thing
140 |
141 | return (
142 | prompt
143 | | print_and_return
144 | | get_llm(llm_settings, api_key, **kwargs)
145 | | StrOutputParser()
146 | )
147 |
148 |
149 | def get_llm_from_prompt_llm_chain(prompt_llm_chain):
150 | return prompt_llm_chain.middle[-1]
151 |
152 |
153 | def get_prompt_text(prompt: PromptTemplate, inputs: dict[str, str]) -> str:
154 | return prompt.invoke(inputs).to_string()
155 |
156 |
157 | if __name__ == "__main__":
158 | # NOTE: Run this file as "python -m components.llm"
159 | x = CallbackHandlerDDGConsole("BOT: ")
160 |
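161 | # Example (an illustrative sketch; assumes default BotSettings and a valid key):
162 | # chain = get_prompt_llm_chain(
163 | #     PromptTemplate.from_template("Say hello to {name}."),
164 | #     llm_settings=BotSettings(),
165 | #     api_key="sk-...",
166 | #     stream=True,
167 | # )
168 | # print(chain.invoke({"name": "DocDocGo"}))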
--------------------------------------------------------------------------------
/components/openai_embeddings_ddg.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from langchain_core.embeddings import Embeddings
4 | from langchain_openai import OpenAIEmbeddings
5 |
6 | from utils.prepare import EMBEDDINGS_DIMENSIONS, EMBEDDINGS_MODEL_NAME, IS_AZURE
7 |
8 |
9 | def get_openai_embeddings(
10 | api_key: str | None = None,
11 | embeddings_model_name: str = EMBEDDINGS_MODEL_NAME,
12 | embeddings_dimensions: int = EMBEDDINGS_DIMENSIONS,
13 | ) -> Embeddings:
14 | # Create the embeddings object, pulling current values of env vars
15 | # (as of Aug 8, 2023, max chunk size for Azure API is 16)
16 | return (
17 | OpenAIEmbeddings( # NOTE: should be able to simplify this
18 | deployment=os.getenv("EMBEDDINGS_DEPLOYMENT_NAME"), chunk_size=16
19 | )
20 | if IS_AZURE
21 | else OpenAIEmbeddings(
22 | api_key=api_key or "",
23 | model=embeddings_model_name,
24 | dimensions=None
25 |             if embeddings_model_name == "text-embedding-ada-002"
26 | else embeddings_dimensions,
27 | ) # NOTE: if empty API key, will throw
28 | )
29 |
30 |
31 | # class OpenAIEmbeddingsDDG(Embeddings):
32 | # """
33 | # Custom version of OpenAIEmbeddings for DocDocGo. Unlike the original,
34 | # an object of this class will pull the current values of env vars every time.
35 | # This helps in situations where the user has changed env vars such as
36 | # OPENAI_API_KEY, as is possible in the Streamlit app.
37 |
38 | # This way is also more consistent with the behavior of e.g. ChatOpenAI, which
39 | # always uses the current values of env vars when querying the OpenAI API.
40 | # """
41 |
42 | # @staticmethod
43 | # def get_fresh_openai_embeddings():
44 | # # Create the embeddings object, pulling current values of env vars
45 | # # (as of Aug 8, 2023, max chunk size for Azure API is 16)
46 | # return (
47 | # OpenAIEmbeddings( # NOTE: should be able to simplify this
48 | # deployment=os.getenv("EMBEDDINGS_DEPLOYMENT_NAME"), chunk_size=16
49 | # )
50 | # if os.getenv("OPENAI_API_BASE") # proxy for whether we're using Azure
51 | # else OpenAIEmbeddings()
52 | # )
53 |
54 | # def embed_documents(self, texts: list[str]) -> list[list[float]]:
55 | # """Embed search docs."""
56 | # return self.get_fresh_openai_embeddings().embed_documents(texts)
57 |
58 | # def embed_query(self, text: str) -> list[float]:
59 | # """Embed query text."""
60 | # return self.get_fresh_openai_embeddings().embed_query(text)
61 |
62 | # async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
63 | # """Asynchronous Embed search docs."""
64 | # return await self.get_fresh_openai_embeddings().aembed_documents(texts)
65 |
66 | # async def aembed_query(self, text: str) -> list[float]:
67 | # """Asynchronous Embed query text."""
68 | # return await self.get_fresh_openai_embeddings().aembed_query(text)
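69 | 
70 | # Example (an illustrative sketch; assumes a valid OpenAI API key):
71 | # embeddings = get_openai_embeddings(api_key="sk-...")
72 | # vector = embeddings.embed_query("hello world")  # -> list[float]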
69 |
--------------------------------------------------------------------------------
/config/logging.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": 1,
3 | "disable_existing_loggers": false,
4 | "formatters": {
5 | "simple": {
6 | "format": "[%(levelname)s|%(name)s|%(module)s|L%(lineno)d] %(asctime)s: %(message)s",
7 | "datefmt": "%Y-%m-%dT%H:%M:%S%z"
8 | },
9 | "json": {
10 | "()": "utils.log.MyJSONFormatter",
11 | "fmt_keys": {
12 | "level": "levelname",
13 | "message": "message",
14 | "timestamp": "timestamp",
15 | "logger": "name",
16 | "module": "module",
17 | "function": "funcName",
18 | "line": "lineno",
19 | "thread_name": "threadName"
20 | }
21 | }
22 | },
23 | "handlers": {
24 | "h1_stderr": {
25 | "class": "logging.StreamHandler",
26 | "level": "INFO",
27 | "formatter": "simple",
28 | "stream": "ext://sys.stderr"
29 | },
30 | "h2_file_json": {
31 | "class": "logging.handlers.RotatingFileHandler",
32 | "level": "DEBUG",
33 | "formatter": "json",
34 | "filename": "logs/ddg-logs.jsonl",
35 | "maxBytes": 1000000,
36 | "backupCount": 2
37 | },
38 | "h3_queue": {
39 | "()": "utils.log.QueueListenerHandler",
40 | "handlers": ["cfg://handlers.h2_file_json", "cfg://handlers.h1_stderr"],
41 | "respect_handler_level": true
42 | }
43 | },
44 | "loggers": {
45 | "trafilatura": {
46 | "level": "ERROR",
47 | "handlers": ["h3_queue"],
48 | "propagate": false
49 | },
50 | "ddg": {
51 | "level": "DEBUG",
52 | "handlers": ["h3_queue"],
53 | "propagate": false
54 | }
55 | },
56 | "root": {
57 | "level": "WARNING",
58 | "handlers": ["h3_queue"]
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
/config/trafilatura.cfg:
--------------------------------------------------------------------------------
1 | # Defines settings for trafilatura (https://github.com/adbar/trafilatura)
2 |
3 | [DEFAULT]
4 |
5 | # Download
6 | DOWNLOAD_TIMEOUT = 30
7 | MAX_FILE_SIZE = 20000000
8 | MIN_FILE_SIZE = 10
9 | # sleep between requests
10 | SLEEP_TIME = 5
11 | # user-agents here: agent1,agent2,...
12 | USER_AGENTS =
13 | # cookie for HTTP requests
14 | COOKIE =
15 |
16 | # Extraction
17 | MIN_EXTRACTED_SIZE = 1500
18 | # MIN_EXTRACTED_SIZE = 250
19 | # increased to improve recall: e.g., for
20 | # https://www.artificialintelligence-news.com/ only the cookie consent text
21 | # was extracted, even with favor_recall = True
22 | MIN_EXTRACTED_COMM_SIZE = 1
23 | MIN_OUTPUT_SIZE = 1
24 | MIN_OUTPUT_COMM_SIZE = 1
25 |
26 | # CLI file processing only, set to 0 to disable
27 | EXTRACTION_TIMEOUT = 30
28 |
29 | # Deduplication
30 | MIN_DUPLCHECK_SIZE = 100
31 | MAX_REPETITIONS = 2
32 |
33 |
--------------------------------------------------------------------------------
/eval/RAG-PARAMS.md:
--------------------------------------------------------------------------------
1 | # Parameters for RAG
2 | 
3 | ## ChromaDDGRetriever
4 | 
5 | ```python
6 | k_min = 2  # min number of docs to return after pruning
7 | score_threshold_min = 0.61  # use k_min if score of k_min'th doc is <= this
8 | k_max = 10  # max number of docs to return after pruning
9 | score_threshold_max = 0.76  # use k_max if score of k_max'th doc is >= this
10 | ```
11 | 
12 | ## prepare_chunks
13 | 
14 | ```python
15 | text_splitter = RecursiveCharacterTextSplitter(
16 |     chunk_size=400,
17 |     chunk_overlap=40,
18 | )
19 | ```
20 | 
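21 | ## Pruning rule (sketch)
22 | 
23 | A minimal sketch of how `ChromaDDGRetriever` uses these values: the score
24 | threshold for keeping the k'th doc is linearly interpolated between
25 | (k_min, score_threshold_min) and (k_max, score_threshold_max). The
26 | `lin_interpolate` below is a stand-in for the helper in `utils/helpers.py`:
27 | 
28 | ```python
29 | def lin_interpolate(x, x_min, x_max, y_min, y_max):
30 |     # unclamped linear interpolation through (x_min, y_min) and (x_max, y_max)
31 |     return y_min + (y_max - y_min) * (x - x_min) / (x_max - x_min)
32 | 
33 | # a doc at rank k (for k > k_min) is kept only if its score clears this value:
34 | for k in range(2, 11):
35 |     print(k, round(lin_interpolate(k, 2, 10, 0.61, 0.76), 3))
36 | ```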
--------------------------------------------------------------------------------
/eval/ai_news_1.py:
--------------------------------------------------------------------------------
1 | ai_news_1 = """sdecoret - stock.adobe.com
2 | Biggest AI news of 2022
3 | From misinformation and disinformation in the war in Ukraine to James Earl Jones lending his voice to voice cloning, these are the most notable stories of the year.
4 | Despite economic uncertainty and a few hiccups regarding ethics and bias, 2022 was a year of change for AI technologies.
5 | Here is a look back at the stories that dominated the year.
6 | Can AI be sentient?
7 | Blake Lemoine, a former engineer at Google, broke the internet when he declared that
8 | [Google's Language Model for Dialogue Applications](https://www.techtarget.com/searchenterpriseai/feature/Ex-Google-engineer-Blake-Lemoine-discusses-sentient-AI) (LaMDA), a large language model ( [LLM](https://www.techtarget.com/whatis/definition/large-language-model-LLM)) in development, is sentient.
9 | Blake claimed that based on LaMDA's answers to his questions, the LLM can communicate its fears and have feelings the way a human can.
10 | Google later fired Lemoine for choosing to "persistently violate clear employment and data security policies that include the need to safeguard product information," the company said in a statement.
11 | Generative AI
12 | From Dall-E 2 to Stable Diffusion to ChatGPT, generative AI has made AI more accessible to not only enterprise users but also consumers.
13 | However, with more accessibility comes more responsibility.
14 | As everyday consumers use Dall-E 2 to create striking images, artists and creators are concerned that these models are not only erasing a need for them but also
15 | [stealing their artwork](https://www.techtarget.com/searchenterpriseai/feature/The-creative-thief-AI-tools-creating-generated-art).
16 | This has also brought up concerns about copyright infringement laws when it comes to work created -- or stolen -- by AI.
17 | According to Michael G. Bennett, director of education curriculum and business lead for responsible AI at the Institute for Experiential AI, while some artists may have a case for copyright infringement if the artwork generated shows a strong similarity to the artist's work, if the work does not show a strong similarity, most artists will not have a case. Many times, the work being produced is the same as if a young artist were to be inspired by an older artist.
18 | Despite the concerns behind generated artwork and generative AI, many creators are embracing it. Marketing vendor Omneky uses Dall-E to generate ads for customers.
19 | Metaverse, avatars and metaverse technologies
20 | The metaverse dominated early 2022, with vendors such as Nvidia describing what they believe the next iteration of the internet will look like. Nvidia introduced an avatar engine. Meta -- Facebook's parent company -- invested heavily in metaverse technology. Researchers and analysts described a metaverse that includes
21 | [holographic avatars](https://www.techtarget.com/searchenterpriseai/feature/Enterprise-applications-of-the-metaverse-slow-but-coming) that could be used to train medical professionals and others.
22 | Vendors, including Hour One, introduced technologies such as a 3D news studio, where users can choose an avatar newscaster.
23 | Finally, consumers could reimagine themselves with Prisma Labs' Lensa app, which generated avatars for users.
24 | While many agree that we're still far away from the metaverse that vendors and enterprises alike are describing, the technologies are certainly growing more rampant, making the metaverse more real to all.
25 | Enterprises turn to AI to help ease the Great Resignation
26 | The coronavirus pandemic made many re-evaluate their employment status. This, in turn, meant many employees resigned in favor of other opportunities.
27 | For enterprises in the restaurant industry, this meant many positions were left unfilled. From pizza chain Jet's Pizza to
28 | [Panera Bread](https://www.techtarget.com/searchenterpriseai/news/252524391/Paneras-AI-drive-through-test-addresses-labor-concerns), enterprises turned to AI tools and technologies that could supplement human workers so that those left on staff could focus their attention elsewhere.
29 | "We're starting to see AI move out of that theoretical space of all the cool things AI can do," Liz Miller, an analyst at Constellation Research, told TechTarget Editorial back in August. "We're starting to see real, specific kinds of business acceleration-minded applications coming off of the workbench and getting into real life."
30 | Voice cloning and James Earl Jones
31 | Liz MillerAnalyst, Constellation Research
32 | When news broke that James Earl Jones was lending his voice to
33 | [speech-to-speech voice cloning](https://www.techtarget.com/searchenterpriseai/feature/James-Earl-Jones-AI-and-the-growing-voice-cloning-market) for future appearances of his Star Wars character Darth Vader, it gave insight into the growing use of the technology.
34 | There are two types of voice cloning: speech-to-speech and text-to-speech.
35 | Text-to-speech is used in contact center environments, where synthetic speech can communicate with consumers before they speak to agents. It's especially helpful when agents speak different languages.
36 | AI and the war in Ukraine
37 | When the war between Russia and Ukraine began earlier this year,
38 | [AI was a tool used to spread disinformation](https://www.techtarget.com/searchenterpriseai/feature/AI-and-disinformation-in-the-Russia-Ukraine-war). Bad actors used the technology to create videos that spread misinformation. Deepfakes or AI-generated humans were created to spread anti-Ukrainian discourse.
39 | The use of deepfakes in the war showed how easy it has become for everyday consumers to create them without coding knowledge.
40 | The use of AI in the war also shows the psychological impact machine learning can have.
41 | "Machine learning is exceptionally good at learning how to exploit human psychology because the internet provides a vast and fast feedback loop to learn what will reinforce and or break beliefs by demographic cohorts," said Mike Gualtieri, an analyst at Forrester Research."""
--------------------------------------------------------------------------------
/eval/queries.md:
--------------------------------------------------------------------------------
1 | # Queries For Evaluation
2 |
3 | /research AI-powered resume builders - simple numbered list of as many as possible
4 | /research running llama2 locally
5 | /research legal arguments for and against disqualifying Trump from running
6 | /research what are some unique traditional Ukrainian dishes
7 | /research what are some important ai research papers and their summaries
8 | /research what do various organizations report as the recommended protein intake?
9 | /research recent ai news
10 | /research I'm looking for a tutorial on configuring Firestore (not Firebase) so that my external server-side application can use it. Most importantly, I want to know exactly how to set up the right role to the service account. Please assign a score 0-100 to each website for how well it fulfills that purpose. The tutorial should not be from Google's docs and they should not be about Firebase, only regular Firestore. If it's from Google's docs or mentions Firebase, assign score 0. Please just assign a score to each source and add a one-two sentence description, nothing else. Include links. If your prompt includes a previous report with assessments, keep them in your current report
11 | /research Please find quotes by Noam Chomsky about the scientific or practical value of ChatGPT. For each quote include a link to the source
12 |
13 | What is the typical range for the 1536 values in the text-embedding-ada-002 embeddings?
14 |
--------------------------------------------------------------------------------
/eval/testing-rag-quality/2401.01325.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/eval/testing-rag-quality/2401.01325.pdf
--------------------------------------------------------------------------------
/eval/testing-rag-quality/rag-tests.md:
--------------------------------------------------------------------------------
1 | Asking "What are the conclusions of the paper?" after ingesting 2401.01325.pdf
--------------------------------------------------------------------------------
/ingest_local_docs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | from langchain_community.document_loaders import DirectoryLoader, TextLoader
5 |
6 | from _prepare_env import is_env_loaded
7 | from components.chroma_ddg import initialize_client
8 | from utils.docgrab import (
9 | JSONLDocumentLoader,
10 | ingest_into_chroma,
11 | )
12 | from utils.helpers import clear_directory, get_timestamp, is_directory_empty, print_no_newline
13 | from utils.prepare import DEFAULT_OPENAI_API_KEY
14 |
15 | is_env_loaded = is_env_loaded # see explanation at the end of docdocgo.py
16 |
17 | if __name__ == "__main__":
18 | print(
19 | "-" * 70
20 | + "\n"
21 | + " " * 20
22 | + "Ingestion of Local Docs into Local DB\n"
23 | + "NOTE: It's recommended to use the Streamlit UI for ingestion instead,\n"
24 | + "as this script is not maintained as frequently.\n"
25 | + "-" * 70
26 | + "\n"
27 | )
28 |
29 | DOCS_TO_INGEST_DIR_OR_FILE = os.getenv("DOCS_TO_INGEST_DIR_OR_FILE")
30 | COLLECTON_NAME_FOR_INGESTED_DOCS = os.getenv("COLLECTON_NAME_FOR_INGESTED_DOCS")
31 | VECTORDB_DIR = os.getenv("VECTORDB_DIR")
32 |
33 | if (
34 | not DOCS_TO_INGEST_DIR_OR_FILE
35 | or not VECTORDB_DIR
36 | or not COLLECTON_NAME_FOR_INGESTED_DOCS
37 | ):
38 | print(
39 | "Please set DOCS_TO_INGEST_DIR_OR_FILE, COLLECTON_NAME_FOR_INGESTED_DOC, "
40 | "and VECTORDB_DIR in `.env`."
41 | )
42 | sys.exit()
43 | if not os.path.exists(DOCS_TO_INGEST_DIR_OR_FILE):
44 | print(f"{DOCS_TO_INGEST_DIR_OR_FILE} does not exist.")
45 | sys.exit()
46 |
47 | # Validate the save directory
48 | if not os.path.isdir(VECTORDB_DIR):
49 | if os.path.exists(VECTORDB_DIR):
50 | print(f"{VECTORDB_DIR} is not a directory.")
51 | sys.exit()
52 | print(f"{VECTORDB_DIR} is not an existing directory.")
53 | ans = input("Do you want to create it? [y/N] ")
54 | if ans.lower() != "y":
55 | sys.exit()
56 | else:
57 | try:
58 | os.makedirs(VECTORDB_DIR)
59 | except Exception as e:
60 | print(f"Could not create {VECTORDB_DIR}: {e}")
61 | sys.exit()
62 |
63 | # Check if the save directory is empty; if not, give the user options
64 | chroma_client = None
65 | if is_directory_empty(VECTORDB_DIR):
66 | is_new_db = True
67 | else:
68 | ans = input(
69 | f"{VECTORDB_DIR} is not empty. Please select an option:\n"
70 | "1. Use it (select if it contains your existing db)\n"
71 | "2. Delete all its content and create a new db\n"
72 | "3. Cancel\nYour choice: "
73 | )
74 | print()
75 | if ans not in {"1", "2"}:
76 | sys.exit()
77 | is_new_db = ans == "2"
78 | if is_new_db:
79 | ans = input("Type 'erase' to confirm deleting the directory's content: ")
80 | if ans != "erase":
81 | sys.exit()
82 | print_no_newline(f"Deleting files and folders in {VECTORDB_DIR}...")
83 | clear_directory(VECTORDB_DIR)
84 | print("Done!")
85 | else:
86 | chroma_client = initialize_client(use_chroma_via_http=False)
87 | collections = chroma_client.list_collections()
88 | collection_names = [c.name for c in collections]
89 | if COLLECTON_NAME_FOR_INGESTED_DOCS in collection_names:
90 | ans = input(
91 | f"Collection {COLLECTON_NAME_FOR_INGESTED_DOCS} already exists. "
92 | "Please select an option:\n"
93 | "1. Append\n"
94 | "2. Overwrite\n"
95 | "3. Cancel\nYour choice: "
96 | )
97 | print()
98 | if ans not in {"1", "2"}:
99 | sys.exit()
100 | if ans == "2":
101 | ans = input(
102 | "Type 'erase' to confirm deleting the existing collection: "
103 | )
104 | if ans != "erase":
105 | sys.exit()
106 | print_no_newline(
107 | f"Deleting collection {COLLECTON_NAME_FOR_INGESTED_DOCS}..."
108 | )
109 | chroma_client.delete_collection(COLLECTON_NAME_FOR_INGESTED_DOCS)
110 | print("Done!\n")
111 |
112 | # Confirm the ingestion
113 | print(f"You are about to ingest documents from: {DOCS_TO_INGEST_DIR_OR_FILE}")
114 | print(f"The vector database directory: {VECTORDB_DIR}")
115 | print(f"The collection name is: {COLLECTON_NAME_FOR_INGESTED_DOCS}")
116 | print(f"Is this a new database? {'Yes' if is_new_db else 'No'}")
117 | if os.getenv("USE_CHROMA_VIA_HTTP") or os.getenv("CHROMA_API_IMPL"):
118 | print(
119 | "NOTE: we will disable the USE_CHROMA_VIA_HTTP and "
120 | "CHROMA_API_IMPL environment variables since this is a local ingestion."
121 | )
122 | os.environ["USE_CHROMA_VIA_HTTP"] = os.environ["CHROMA_API_IMPL"] = ""
123 |
124 | print("\nATTENTION: This will incur some cost on your OpenAI account.")
125 |
126 | if input("Press Enter to proceed. Any non-empty input will cancel the procedure: "):
127 | print("Ingestion cancelled. Exiting")
128 | sys.exit()
129 |
130 | # Load the documents
131 | print_no_newline("Loading your documents...")
132 | if os.path.isfile(DOCS_TO_INGEST_DIR_OR_FILE):
133 | if DOCS_TO_INGEST_DIR_OR_FILE.endswith(".jsonl"):
134 | loader = JSONLDocumentLoader(DOCS_TO_INGEST_DIR_OR_FILE)
135 | else:
136 | loader = TextLoader(DOCS_TO_INGEST_DIR_OR_FILE, autodetect_encoding=True)
137 | else:
138 | loader = DirectoryLoader(
139 | DOCS_TO_INGEST_DIR_OR_FILE,
140 | loader_cls=TextLoader, # default is Unstructured but python-magic hangs on import
141 | loader_kwargs={"autodetect_encoding": True},
142 | )
143 | docs = loader.load()
144 | print("Done!")
145 |
146 | # Ingest into chromadb (this will print messages regarding the status)
147 | timestamp = get_timestamp()
148 | ingest_into_chroma(
149 | docs,
150 | collection_name=COLLECTON_NAME_FOR_INGESTED_DOCS,
151 | chroma_client=chroma_client,
152 | save_dir=None if chroma_client else VECTORDB_DIR,
153 | openai_api_key=DEFAULT_OPENAI_API_KEY,
154 | collection_metadata={"created_at": timestamp, "updated_at": timestamp},
155 | )
156 |
--------------------------------------------------------------------------------
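For convenience, here is a minimal sketch of the `.env` entries this script reads (the values are illustrative placeholders, not ones from the repo; note that `COLLECTON_NAME_FOR_INGESTED_DOCS` is the variable's actual spelling in the code):

```
DOCS_TO_INGEST_DIR_OR_FILE="path/to/docs-or-file.txt"
COLLECTON_NAME_FOR_INGESTED_DOCS="my-ingested-docs"
VECTORDB_DIR="chroma/"
```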
/media/ddg-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/media/ddg-logo.png
--------------------------------------------------------------------------------
/media/minimal10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/media/minimal10.png
--------------------------------------------------------------------------------
/media/minimal10.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/media/minimal11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/media/minimal11.png
--------------------------------------------------------------------------------
/media/minimal11.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/media/minimal13.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/media/minimal7-plain-svg.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/media/minimal7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/media/minimal7.png
--------------------------------------------------------------------------------
/media/minimal7.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/media/minimal9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/media/minimal9.png
--------------------------------------------------------------------------------
/media/minimal9.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.9.1
2 | aiosignal==1.3.1
3 | altair==5.2.0
4 | annotated-types==0.6.0
5 | anyio==4.2.0
6 | asgiref==3.7.2
7 | asttokens==2.4.1
8 | attrs==23.1.0
9 | backoff==2.2.1
10 | bcrypt==4.1.2
11 | beautifulsoup4==4.12.2
12 | blinker==1.7.0
13 | bs4==0.0.1
14 | build==1.2.1
15 | cachetools==5.3.2
16 | certifi==2023.11.17
17 | charset-normalizer==3.3.2
18 | chroma-hnswlib==0.7.5
19 | chromadb==0.5.4
20 | click==8.1.7
21 | colorama==0.4.6
22 | coloredlogs==15.0.1
23 | courlan==0.9.5
24 | dataclasses-json==0.6.3
25 | dateparser==1.2.0
26 | Deprecated==1.2.14
27 | distro==1.9.0
28 | docx2txt==0.8
29 | executing==2.0.1
30 | fake-useragent==1.4.0
31 | fastapi==0.108.0
32 | filelock==3.13.1
33 | flatbuffers==23.5.26
34 | frozenlist==1.4.1
35 | fsspec==2023.12.2
36 | gitdb==4.0.11
37 | GitPython==3.1.40
38 | google-api-core==2.15.0
39 | google-auth==2.25.2
40 | google-cloud-core==2.4.1
41 | google-cloud-firestore==2.14.0
42 | googleapis-common-protos==1.62.0
43 | greenlet==3.0.1
44 | grpcio==1.60.0
45 | grpcio-status==1.60.0
46 | h11==0.14.0
47 | htmldate==1.6.0
48 | httpcore==1.0.2
49 | httptools==0.6.1
50 | httpx==0.27.0
51 | huggingface-hub==0.20.1
52 | humanfriendly==10.0
53 | icecream==2.1.3
54 | idna==3.6
55 | importlib-metadata==6.11.0
56 | importlib-resources==6.1.1
57 | itsdangerous==2.1.2
58 | Jinja2==3.1.2
59 | jsonpatch==1.33
60 | jsonpointer==2.4
61 | jsonschema==4.20.0
62 | jsonschema-specifications==2023.12.1
63 | jusText==3.0.0
64 | kubernetes==28.1.0
65 | langchain==0.2.10
66 | langchain-cli==0.0.25
67 | langchain-community==0.2.9
68 | langchain-core==0.2.22
69 | langchain-openai==0.1.17
70 | langchain-text-splitters==0.2.2
71 | langcodes==3.3.0
72 | langserve==0.2.2
73 | langsmith==0.1.93
74 | libcst==1.4.0
75 | lxml==4.9.4
76 | markdown-it-py==3.0.0
77 | MarkupSafe==2.1.3
78 | marshmallow==3.20.1
79 | mdurl==0.1.2
80 | mmh3==4.0.1
81 | monotonic==1.6
82 | mpmath==1.3.0
83 | multidict==6.0.4
84 | mypy-extensions==1.0.0
85 | numpy==1.26.2
86 | oauthlib==3.2.2
87 | onnxruntime==1.16.3
88 | openai==1.36.1
89 | opentelemetry-api==1.22.0
90 | opentelemetry-exporter-otlp-proto-common==1.22.0
91 | opentelemetry-exporter-otlp-proto-grpc==1.22.0
92 | opentelemetry-instrumentation==0.43b0
93 | opentelemetry-instrumentation-asgi==0.43b0
94 | opentelemetry-instrumentation-fastapi==0.43b0
95 | opentelemetry-proto==1.22.0
96 | opentelemetry-sdk==1.22.0
97 | opentelemetry-semantic-conventions==0.43b0
98 | opentelemetry-util-http==0.43b0
99 | orjson==3.10.6
100 | overrides==7.4.0
101 | packaging==23.2
102 | pandas==2.1.4
103 | Pillow==10.1.0
104 | playwright==1.40.0
105 | posthog==3.1.0
106 | proto-plus==1.23.0
107 | protobuf==4.25.1
108 | pulsar-client==3.3.0
109 | pyarrow==14.0.2
110 | pyasn1==0.5.1
111 | pyasn1-modules==0.3.0
112 | pydantic==2.5.3
113 | pydantic_core==2.14.6
114 | pydeck==0.8.1b0
115 | pyee==11.0.1
116 | Pygments==2.17.2
117 | pypdf==4.0.0
118 | PyPika==0.48.9
119 | pyproject-toml==0.0.10
120 | pyproject_hooks==1.1.0
121 | pyreadline3==3.4.1
122 | python-dateutil==2.8.2
123 | python-dotenv==1.0.0
124 | python-multipart==0.0.9
125 | pytz==2023.3.post1
126 | PyYAML==6.0.1
127 | referencing==0.32.0
128 | regex==2023.12.25
129 | requests==2.31.0
130 | requests-oauthlib==1.3.1
131 | rich==13.7.0
132 | rpds-py==0.15.2
133 | rsa==4.9
134 | shellingham==1.5.4
135 | six==1.16.0
136 | smmap==5.0.1
137 | sniffio==1.3.0
138 | soupsieve==2.5
139 | SQLAlchemy==2.0.23
140 | sse-starlette==1.8.2
141 | starlette==0.32.0.post1
142 | streamlit==1.35.0
143 | sympy==1.12
144 | tenacity==8.2.3
145 | tiktoken==0.7.0
146 | tld==0.13
147 | tokenizers==0.15.0
148 | toml==0.10.2
149 | tomlkit==0.12.5
150 | toolz==0.12.0
151 | tornado==6.4
152 | tqdm==4.66.1
153 | trafilatura==1.6.3
154 | typer==0.9.0
155 | typing-inspect==0.9.0
156 | typing_extensions==4.9.0
157 | tzdata==2023.3
158 | tzlocal==5.2
159 | urllib3==1.26.18
160 | uvicorn==0.23.2
161 | validators==0.22.0
162 | watchdog==3.0.0
163 | watchfiles==0.21.0
164 | websocket-client==1.7.0
165 | websockets==12.0
166 | Werkzeug==3.0.1
167 | wrapt==1.16.0
168 | yarl==1.9.4
169 | zipp==3.17.0
170 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/reasonmethis/docdocgo-core/e6408ea9c7f4657525715730f893e8a3d27296a5/utils/__init__.py
--------------------------------------------------------------------------------
/utils/algo.py:
--------------------------------------------------------------------------------
1 | from itertools import zip_longest
2 | from typing import Any, Iterable, Iterator
3 |
4 |
5 | def interleave_iterables(iterables: Iterable[Iterable[Any]]) -> Iterator[Any]:
6 | """
7 | Interleave elements from multiple iterables.
8 |
9 | This function takes an iterable of iterables (like lists, tuples, etc.) and yields elements
10 | in an interleaved manner: first yielding the first element of each iterable, then the second
11 | element of each, and so on. If an iterable is exhausted (shorter than others), it is skipped.
12 |
13 | Example:
14 | >>> list(interleave_iterables([[1, 2, 3], ('a', 'b'), [10, 20, 30, 40]]))
15 | [1, 'a', 10, 2, 'b', 20, 3, 30, 40]
16 | """
17 |
18 | sentinel = object() # unique object to identify exhausted sublists
19 |
20 | for elements in zip_longest(*iterables, fillvalue=sentinel):
21 | for element in elements:
22 | if element is not sentinel:
23 | yield element
24 |
25 |
26 | def remove_duplicates_keep_order(iterable: Iterable):
27 | """
28 | Remove duplicates from an iterable while keeping order. Returns a list.
29 | """
30 | return list(dict.fromkeys(iterable))
31 |
32 |
33 | def insert_interval(
34 | current_intervals: list[tuple[int, int]], new_interval: tuple[int, int]
35 | ) -> list[tuple[int, int]]:
36 | """
37 | Add a new interval to a list of sorted non-overlapping intervals. If the new interval overlaps
38 | with any of the existing intervals, merge them. Return the resulting list of intervals.
39 | """
40 | new_intervals = []
41 | new_interval_start, new_interval_end = new_interval
42 | for i, (start, end) in enumerate(current_intervals):
43 | if end < new_interval_start:
44 | new_intervals.append((start, end))
45 | elif start > new_interval_end:
46 | # Add the new interval and all remaining intervals
47 | new_intervals.append((new_interval_start, new_interval_end))
48 | return new_intervals + current_intervals[i:]
49 | else:
50 | # Intervals overlap or are adjacent, merge them
51 | new_interval_start = min(start, new_interval_start)
52 | new_interval_end = max(end, new_interval_end)
53 |
54 | # If we're here, the new/merged interval is the last one
55 | new_intervals.append((new_interval_start, new_interval_end))
56 | return new_intervals
57 |
--------------------------------------------------------------------------------
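A quick sketch of how `insert_interval` behaves (the interval values are illustrative):

```python
from utils.algo import insert_interval

intervals = [(1, 3), (6, 9)]
print(insert_interval(intervals, (2, 7)))    # [(1, 9)] - overlapping intervals are merged
print(insert_interval(intervals, (10, 12)))  # [(1, 3), (6, 9), (10, 12)] - appended at the end
```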
/utils/async_utils.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
3 |
4 | # NOTE: consider using async_to_sync from asgiref.sync library
5 |
6 |
7 | def run_task_sync(task):
8 | """
9 | Run an asyncio task (more precisely, a coroutine object, such as the result of
10 | calling an async function) synchronously.
11 | """
12 | # print("-" * 100)
13 | # try:
14 | # print("=" * 100)
15 | # asyncio.get_running_loop()
16 | # print("*-*" * 100)
17 | # except RuntimeError as e:
18 | # print("?" * 100)
19 | # print(e)
20 | # print("?" * 100)
21 | # res = asyncio.run(task) # no preexisting event loop, so run in the main thread
22 | # print("!" * 100)
23 | # print(res)
24 | # print("!" * 100)
25 | # return res
26 | # TODO: consider a situation when there is a non-running loop
27 |
28 | with ThreadPoolExecutor(max_workers=1) as executor:
29 | future = executor.submit(asyncio.run, task)
30 | return future.result()
31 |
32 |
33 | def make_sync(async_func):
34 | """
35 | Make an asynchronous function synchronous.
36 | """
37 |
38 | def wrapper(*args, **kwargs):
39 | return run_task_sync(async_func(*args, **kwargs))
40 |
41 | return wrapper
42 |
43 |
44 | def gather_tasks_sync(tasks):
45 | """
46 | Run a list of asyncio tasks synchronously.
47 | """
48 |
49 | async def coroutine_from_tasks():
50 | return await asyncio.gather(*tasks)
51 |
52 | return run_task_sync(coroutine_from_tasks())
53 |
54 |
55 | def execute_func_map_in_processes(func, inputs, max_workers=None):
56 | """
57 | Execute a function on a list of inputs in a separate process for each input.
58 | """
59 | with ProcessPoolExecutor(max_workers=max_workers) as executor:
60 | return list(executor.map(func, inputs))
61 |
62 | def execute_func_map_in_threads(func, inputs, max_workers=None):
63 | """
64 | Execute a function on a list of inputs in a separate thread for each input.
65 | """
66 | with ThreadPoolExecutor(max_workers=max_workers) as executor:
67 | return list(executor.map(func, inputs))
68 |
--------------------------------------------------------------------------------
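As a usage sketch (the coroutine here is hypothetical), `make_sync` and `gather_tasks_sync` can be exercised like this:

```python
import asyncio

from utils.async_utils import gather_tasks_sync, make_sync

async def double(x: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for real async work
    return x * 2

print(make_sync(double)(21))                      # 42
print(gather_tasks_sync([double(1), double(2)]))  # [2, 4]
```

Running `asyncio.run` in a dedicated worker thread is what lets these helpers work even when the caller is already inside a running event loop (e.g. under Streamlit), since the worker thread has no loop of its own.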
/utils/debug.py:
--------------------------------------------------------------------------------
1 |
2 | from utils.prepare import get_logger
3 | from langchain_core.prompts import PromptTemplate
4 |
5 | logger = get_logger()
6 |
7 |
8 | def save_prompt_text_to_file(
9 | prompt: PromptTemplate, inputs: dict[str, str], file_path: str
10 | ) -> None:
11 | prompt_text = prompt.invoke(inputs).to_string()
12 | with open(file_path, "a", encoding="utf-8") as f:
13 | f.write("-" * 80 + "\n\n" + prompt_text)
14 | logger.debug(f"Saved prompt text to {file_path}")
15 |
--------------------------------------------------------------------------------
/utils/docgrab.py:
--------------------------------------------------------------------------------
1 | import json
2 | import uuid
3 | from typing import Iterable
4 |
5 | from chromadb import ClientAPI
6 | from dotenv import load_dotenv
7 | from langchain_community.document_loaders import GitbookLoader
8 |
9 | from components.chroma_ddg import ChromaDDG
10 | from components.openai_embeddings_ddg import get_openai_embeddings
11 | from utils.prepare import EMBEDDINGS_DIMENSIONS, get_logger
12 | from utils.rag import rag_text_splitter
13 | from langchain_core.documents import Document
14 |
15 | load_dotenv(override=True)
16 | logger = get_logger()
17 |
18 |
19 | class JSONLDocumentLoader:
20 | def __init__(self, file_path: str, max_docs=None) -> None:
21 | self.file_path = file_path
22 | self.max_docs = max_docs
23 |
24 | def load(self) -> list[Document]:
25 | docs = load_docs_from_jsonl(self.file_path)
26 | if self.max_docs is None or self.max_docs > len(docs):
27 | return docs
28 | return docs[: self.max_docs]
29 |
30 |
31 | def save_docs_to_jsonl(docs: Iterable[Document], file_path: str) -> None:
32 | with open(file_path, "a") as f:
33 | for doc in docs:
34 | f.write(doc.json() + "\n")
35 |
36 |
37 | def load_docs_from_jsonl(file_path: str) -> list[Document]:
38 | docs = []
39 | with open(file_path, "r") as f:
40 | for line in f:
41 | data = json.loads(line)
42 | doc = Document(**data)
43 | docs.append(doc)
44 | return docs
45 |
46 |
47 | def load_gitbook(root_url: str) -> list[Document]:
48 | all_pages_docs = GitbookLoader(root_url, load_all_paths=True).load()
49 | return all_pages_docs
50 |
51 |
52 | def prepare_chunks(
53 | texts: list[str], metadatas: list[dict], ids: list[str]
54 | ) -> list[Document]:
55 | """
56 | Split documents into chunks and add parent ids to the chunks' metadata.
57 | Returns a list of snippets (each is a Document).
58 |
59 | It is ok to pass an empty list of texts.
60 | """
61 | logger.info(f"Splitting {len(texts)} documents into chunks...")
62 |
63 | # Add parent ids to metadata (will delete them, but only after they propagate to snippets)
64 | for metadata, id in zip(metadatas, ids):
65 | metadata["parent_id"] = id
66 |
67 | # Split into snippets
68 | snippets = rag_text_splitter.create_documents(texts, metadatas)
69 | logger.info(f"Obtained {len(snippets)} chunks.")
70 |
71 | # Restore original metadata
72 | for metadata in metadatas:
73 | del metadata["parent_id"]
74 |
75 | return snippets
76 |
77 |
78 | FAKE_FULL_DOC_EMBEDDING = [1.0] * EMBEDDINGS_DIMENSIONS
79 |
80 | # TODO: remove the logic of saving to the db, leave only doc preparation. We should
81 | # separate concerns and reduce the number of places we write to the db.
82 | def ingest_into_chroma(
83 | docs: list[Document],
84 | *,
85 | collection_name: str,
86 | openai_api_key: str,
87 | chroma_client: ClientAPI | None = None,
88 | save_dir: str | None = None,
89 | collection_metadata: dict[str, str] | None = None,
90 | ) -> ChromaDDG:
91 | """
92 | Load documents and/or collection metadata into a Chroma collection, return a vectorstore
93 | object.
94 |
95 | If collection_metadata is passed and the collection exists, the metadata will be
96 | replaced with the passed metadata, according to the Chroma docs.
97 |
98 | NOTE: Normally, the higher level agentblocks.collectionhelper.ingest_into_collection
99 | should be used, which creates/updates the "created_at" and "updated_at" metadata fields.
100 | """
101 | assert bool(chroma_client) != bool(save_dir), "Invalid vector db destination"
102 |
103 | # Handle special case of no docs - just create/update collection with given metadata
104 | if not docs:
105 | return ChromaDDG(
106 | embedding_function=get_openai_embeddings(openai_api_key),
107 | client=chroma_client,
108 | persist_directory=save_dir,
109 | collection_name=collection_name,
110 | collection_metadata=collection_metadata,
111 | create_if_not_exists=True,
112 | )
113 |
114 | # Prepare full texts, metadatas and ids
115 | full_doc_ids = [str(uuid.uuid4()) for _ in range(len(docs))]
116 | texts = [doc.page_content for doc in docs]
117 | metadatas = [doc.metadata for doc in docs]
118 |
119 | # Split into snippets, embed and add them
120 | vectorstore: ChromaDDG = ChromaDDG.from_documents(
121 | prepare_chunks(texts, metadatas, full_doc_ids),
122 | embedding=get_openai_embeddings(openai_api_key),
123 | client=chroma_client,
124 | persist_directory=save_dir,
125 | collection_name=collection_name,
126 | collection_metadata=collection_metadata,
127 | create_if_not_exists=True, # ok to pass (kwargs are passed to __init__)
128 | )
129 |
130 | # Add the original full docs (with fake embeddings)
131 | # NOTE: should be possible to add everything in one call, with some work
132 | fake_embeddings = [FAKE_FULL_DOC_EMBEDDING for _ in range(len(docs))]
133 | vectorstore.collection.add(full_doc_ids, fake_embeddings, metadatas, texts)
134 |
135 | logger.info(f"Ingested documents into collection {collection_name}")
136 | if save_dir:
137 | logger.info(f"Saved to {save_dir}")
138 | return vectorstore
139 |
140 |
141 | if __name__ == "__main__":
142 | pass
143 | # download all pages from gitbook and save to jsonl
144 | # all_pages_docs = load_gitbook(GITBOOK_ROOT_URL)
145 | # print(f"Loaded {len(all_pages_docs)} documents")
146 | # save_docs_to_jsonl(all_pages_docs, "docs.jsonl")
147 |
148 | # load from jsonl
149 | # docs = load_docs_from_jsonl("docs.jsonl")
150 | # print(f"Loaded {len(docs)} documents")
151 |
--------------------------------------------------------------------------------
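A minimal round-trip sketch for the JSONL helpers above (file name is illustrative):

```python
from langchain_core.documents import Document

from utils.docgrab import JSONLDocumentLoader, save_docs_to_jsonl

docs = [Document(page_content="hello world", metadata={"source": "example"})]
save_docs_to_jsonl(docs, "docs.jsonl")  # note: opens in append mode, does not overwrite

loaded = JSONLDocumentLoader("docs.jsonl", max_docs=10).load()
assert loaded[0].page_content == "hello world"
```

Since `save_docs_to_jsonl` appends, repeated runs accumulate documents in the file rather than replacing them.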
/utils/filesystem.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def ensure_path_exists(path, is_directory=False):
5 | """
6 | Ensure that a given path exists as either a file or a directory.
7 |
8 | Parameters:
9 | - path (str): The file or directory path to check or create.
10 | - is_directory (bool): Flag to indicate if the path should be a directory (True) or a file (False).
11 |
12 | This function checks if the given path exists. If it exists, it checks if it matches the type specified
13 | by `is_directory` (file or directory). If it does not exist, the function creates the necessary
14 | directories and, if required, an empty file at that path.
15 |
16 | Returns:
17 | - None
18 | """
19 | # Check if the path exists
20 | if os.path.exists(path):
21 | # If it exists, check if it is the correct type (file or directory)
22 | if is_directory:
23 | if not os.path.isdir(path):
24 | raise ValueError(f"Path {path} exists but is not a directory as expected.")
25 | elif not os.path.isfile(path):
26 | raise ValueError(f"Path {path} exists but is not a file as expected.")
27 | else:
28 | # If the path does not exist, create it
29 | if is_directory:
30 | # Create all intermediate directories
31 | os.makedirs(path, exist_ok=True)
32 | else:
33 | # Create the directory path to the file if it doesn't exist
34 | os.makedirs(os.path.dirname(path), exist_ok=True)
35 | # Create the file
36 | with open(path, "w"):
37 | pass # Creating an empty file
--------------------------------------------------------------------------------
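A brief usage sketch (the paths are illustrative):

```python
from utils.filesystem import ensure_path_exists

ensure_path_exists("logs/app.jsonl")                 # creates logs/ and an empty file
ensure_path_exists("data/cache", is_directory=True)  # creates the directory tree
```

One caveat worth knowing: for a bare file name with no directory component, `os.path.dirname` returns an empty string, so the file-creation branch assumes the path includes at least one directory.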
/utils/ingest.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 |
4 | import docx2txt
5 | from bs4 import BeautifulSoup
6 | from icecream import ic
7 | from pypdf import PdfReader
8 | from starlette.datastructures import UploadFile # err if "from fastapi"
9 | from streamlit.runtime.uploaded_file_manager import UploadedFile
10 | from langchain_core.documents import Document
11 |
12 | allowed_extensions = [
13 | "",
14 | ".txt",
15 | ".md",
16 | ".rtf",
17 | ".log",
18 | ".pdf",
19 | ".docx",
20 | ".html",
21 | ".htm",
22 | ]
23 |
24 |
25 | def get_page_texts_from_pdf(file):
26 | reader = PdfReader(file)
27 | return [page.extract_text() for page in reader.pages]
28 |
29 |
30 | DEFAULT_PAGE_START = "PAGE {page_num}:\n"
31 | DEFAULT_PAGE_SEP = "\n" + "-" * 3 + "\n\n"
32 |
33 |
34 | def get_text_from_pdf(file, page_start=DEFAULT_PAGE_START, page_sep=DEFAULT_PAGE_SEP):
35 | return page_sep.join(
36 | [
37 | page_start.replace("{page_num}", str(i)) + x
38 | for i, x in enumerate(get_page_texts_from_pdf(file), start=1)
39 | ]
40 | )
41 |
42 |
43 | def extract_text(files, allow_all_ext):
44 | docs = []
45 | failed_files = []
46 | unsupported_ext_files = []
47 | for file in files:
48 | if isinstance(file, UploadFile):
49 | file_name = file.filename or "unnamed-file" # need?
50 | file = file.file
51 | else:
52 | file_name = file.name
53 | try:
54 | extension = os.path.splitext(file_name)[1]
55 | if not allow_all_ext and extension not in allowed_extensions:
56 | unsupported_ext_files.append(file_name)
57 | continue
58 | if extension == ".pdf":
59 | for i, text in enumerate(get_page_texts_from_pdf(file)):
60 | metadata = {"source": f"{file_name} (page {i + 1})"}
61 | docs.append(Document(page_content=text, metadata=metadata))
62 | else:
63 | if extension == ".docx":
64 | text = docx2txt.process(file)
65 | elif extension in [".html", ".htm"]:
66 | soup = BeautifulSoup(file, "html.parser")
67 | # Remove script and style elements
68 | for script_or_style in soup(["script", "style"]):
69 | script_or_style.extract()
70 | text = soup.get_text()
71 | # Replace multiple newlines with single newlines
72 | text = re.sub(r"\n{2,}", "\n\n", text)
73 | print(text)
74 | else:
75 | # Treat as text file
76 | if isinstance(file, UploadedFile):
77 | text = file.getvalue().decode("utf-8")
78 | # NOTE: not sure what the advantage is for Streamlit's UploadedFile
79 | else:
80 | text = file.read().decode("utf-8")
81 | docs.append(Document(page_content=text, metadata={"source": file_name}))
82 | except Exception as e:
83 | ic(e)
84 | failed_files.append(file_name)
85 |
86 | ic(len(docs), failed_files, unsupported_ext_files)
87 | return docs, failed_files, unsupported_ext_files
88 |
89 | def format_ingest_failure(failed_files, unsupported_ext_files):
90 | res = (
91 | "Apologies, the following files failed to process:\n```\n"
92 | + "\n".join(unsupported_ext_files + failed_files)
93 | + "\n```"
94 | )
95 | if unsupported_ext_files:
96 | unsupported_extensions = sorted(
97 | list(set(os.path.splitext(x)[1] for x in unsupported_ext_files))
98 | )
99 | res += "\n\nUnsupported file extensions: " + ", ".join(
100 | f"`{x.lstrip('.') or ''}`" for x in unsupported_extensions
101 | )
102 | return res
--------------------------------------------------------------------------------
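A usage sketch for `extract_text` with a local PDF (the file name is hypothetical):

```python
from utils.ingest import extract_text, format_ingest_failure

with open("report.pdf", "rb") as f:
    docs, failed_files, unsupported_ext_files = extract_text([f], allow_all_ext=False)

print(len(docs))  # one Document per PDF page, with "(page N)" in the source metadata
if failed_files or unsupported_ext_files:
    print(format_ingest_failure(failed_files, unsupported_ext_files))
```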
/utils/input.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Iterable
2 |
3 | def get_menu_choice(
4 | choices: Iterable[str],
5 | prompt: str = "Please enter the desired number: ",
6 | default: int | None = None,
7 | ) -> int:
8 | """
9 | Get a choice from the user from a list of choices.
10 |
11 | Returns the index of the choice selected by the user. Prompts the user
12 | repeatedly until a valid choice is entered.
13 | """
14 |
15 | while True:
16 | # Print the choices with numbered identifiers
17 | for idx, choice_text in enumerate(choices, start=1):
18 | print(f"{idx}. {choice_text}")
19 |
20 | # Prompt for the user's choice
21 | user_input = input(prompt)
22 |
23 | # Use default if the user input is empty and default is set
24 | if user_input == "" and default is not None:
25 | return default
26 |
27 | # Check if the input is a valid choice
28 | try:
29 | chosen_idx = int(user_input) - 1
30 | except ValueError:
31 | chosen_idx = -1
32 | if 0 <= chosen_idx < len(choices):
33 | return chosen_idx
34 | print()
35 |
36 | def get_choice_from_dict_menu(
37 | menu: dict[Any, str],
38 | prompt: str = "Please enter the desired number: ",
39 | default: Any | None = None,
40 | ) -> Any:
41 | """
42 | Get a choice from the user from a dictionary menu.
43 |
44 | Returns the key of the choice selected by the user. Prompts the user
45 | repeatedly until a valid choice is entered.
46 | """
47 |
48 | default_idx = None if default is None else -123
49 | idx = get_menu_choice(menu.values(), prompt, default_idx)
50 | return default if idx == -123 else list(menu.keys())[idx]
51 |
--------------------------------------------------------------------------------
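For illustration, a sketch of how the two menu helpers compose (menu entries are made up):

```python
from utils.input import get_choice_from_dict_menu, get_menu_choice

idx = get_menu_choice(["Create new db", "Use existing db", "Cancel"], default=2)
print(f"Selected index: {idx}")

action = get_choice_from_dict_menu(
    {"new": "Create new db", "use": "Use existing db"}, default="use"
)
print(f"Selected key: {action}")
```

Note that `get_choice_from_dict_menu` signals "use the default" via the sentinel index -123, which is why an empty input returns the provided default key rather than an index into the menu.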
/utils/log.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import datetime as dt
3 | import json
4 | import logging
5 | import logging.config
6 | import logging.handlers
7 | import pathlib
8 | from logging.config import ConvertingDict, ConvertingList, valid_ident
9 | from logging.handlers import QueueHandler, QueueListener
10 | from queue import Queue
11 |
12 | from utils.filesystem import ensure_path_exists
13 |
14 | # from typing import override
15 |
16 | LOG_RECORD_BUILTIN_ATTRS = {
17 | "args",
18 | "asctime",
19 | "created",
20 | "exc_info",
21 | "exc_text",
22 | "filename",
23 | "funcName",
24 | "levelname",
25 | "levelno",
26 | "lineno",
27 | "module",
28 | "msecs",
29 | "message",
30 | "msg",
31 | "name",
32 | "pathname",
33 | "process",
34 | "processName",
35 | "relativeCreated",
36 | "stack_info",
37 | "thread",
38 | "threadName",
39 | "taskName",
40 | }
41 |
42 |
43 | # https://github.com/mCodingLLC/VideosSampleCode/blob/master/videos/135_modern_logging/mylogger.py
44 | class MyJSONFormatter(logging.Formatter):
45 | def __init__(
46 | self,
47 | *,
48 | fmt_keys: dict[str, str] | None = None,
49 | ):
50 | super().__init__()
51 | self.fmt_keys = fmt_keys if fmt_keys is not None else {}
52 |
53 | # @override
54 | def format(self, record: logging.LogRecord) -> str:
55 | message = self._prepare_log_dict(record)
56 | return json.dumps(message, default=str)
57 |
58 | def _prepare_log_dict(self, record: logging.LogRecord):
59 | always_fields = {
60 | "message": record.getMessage(),
61 | "timestamp": dt.datetime.fromtimestamp(
62 | record.created, tz=dt.timezone.utc
63 | ).isoformat(),
64 | }
65 | if record.exc_info is not None:
66 | always_fields["exc_info"] = self.formatException(record.exc_info)
67 |
68 | if record.stack_info is not None:
69 | always_fields["stack_info"] = self.formatStack(record.stack_info)
70 |
71 | message = {
72 | key: msg_val
73 | if (msg_val := always_fields.pop(val, None)) is not None
74 | else getattr(record, val)
75 | for key, val in self.fmt_keys.items()
76 | }
77 | message.update(always_fields)
78 |
79 | for key, val in record.__dict__.items():
80 | if key not in LOG_RECORD_BUILTIN_ATTRS:
81 | message[key] = val
82 |
83 | return message
84 |
85 |
86 | class NonErrorFilter(logging.Filter):
87 | # @override
88 | def filter(self, record: logging.LogRecord) -> bool | logging.LogRecord:
89 | return record.levelno <= logging.INFO
90 |
91 |
92 | # https://rob-blackbourn.medium.com/how-to-use-python-logging-queuehandler-with-dictconfig-1e8b1284e27a
93 | def _resolve_handlers(h):
94 | """Converts a list of handlers specified in a JSON file to actual handler objects."""
95 | if isinstance(h, ConvertingList):
96 | # Getting each item converts it to an actual object
97 | h = [h[i] for i in range(len(h))]
98 | return h
99 |
100 |
101 | def _resolve_convertingdict(q):
102 | """Resolves a dict specified in a JSON file to an actual object, if necessary."""
103 | if not isinstance(q, ConvertingDict):
104 | return q
105 | if "__resolved_value__" in q:
106 | return q["__resolved_value__"]
107 |
108 | cname = q.pop("class")
109 | klass = q.configurator.resolve(cname)
110 | props = q.pop(".", None)
111 | kwargs = {k: q[k] for k in q if valid_ident(k)}
112 | result = klass(**kwargs)
113 | if props:
114 | for name, value in props.items():
115 | setattr(result, name, value)
116 |
117 | q["__resolved_value__"] = result
118 | return result
119 |
120 |
121 | class QueueListenerHandler(QueueHandler):
122 | def __init__(
123 | self, handlers, respect_handler_level=False, auto_run=True, queue=Queue(0)
124 | ):
125 | queue = _resolve_convertingdict(queue)
126 | super().__init__(queue)
127 | handlers = _resolve_handlers(handlers)
128 | self._listener = QueueListener(
129 | self.queue, *handlers, respect_handler_level=respect_handler_level
130 | )
131 | if auto_run:
132 | self.start()
133 | atexit.register(self.stop)
134 |
135 | def start(self):
136 | self._listener.start()
137 |
138 | def stop(self):
139 | self._listener.stop()
140 |
141 | def emit(self, record): # NOTE: remove?
142 | return super().emit(record)
143 |
144 |
145 | def setup_logging(log_level: str | None, log_format: str | None):
146 | ensure_path_exists("logs/ddg-logs.jsonl")
147 | with open(pathlib.Path("config/logging.json")) as f:
148 | config = json.load(f)
149 | if log_format:
150 | config["formatters"]["custom"] = {
151 | "format": log_format,
152 | "datefmt": "%Y-%m-%dT%H:%M:%S%z",
153 | }
154 | config["handlers"]["h1_stderr"]["formatter"] = "custom"
155 |
156 | if log_level:
157 | config["handlers"]["h1_stderr"]["level"] = log_level
158 |
159 | logging.config.dictConfig(config)
160 |
--------------------------------------------------------------------------------
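A minimal `dictConfig` sketch wiring `MyJSONFormatter` and `QueueListenerHandler` together. The keys and file names here are assumptions for illustration; the real configuration lives in `config/logging.json`, which is not shown in this dump:

```python
import logging
import logging.config

from utils.filesystem import ensure_path_exists

ensure_path_exists("logs/ddg-logs.jsonl")  # the FileHandler needs the directory to exist

config = {
    "version": 1,
    "formatters": {
        "json": {
            "()": "utils.log.MyJSONFormatter",  # custom factory, as dictConfig allows
            "fmt_keys": {"level": "levelname", "logger": "name"},
        }
    },
    "handlers": {
        "file": {
            "class": "logging.FileHandler",
            "filename": "logs/ddg-logs.jsonl",
            "formatter": "json",
        },
        "queue": {
            "()": "utils.log.QueueListenerHandler",
            "handlers": ["cfg://handlers.file"],  # resolved by _resolve_handlers
        },
    },
    "loggers": {"ddg": {"level": "DEBUG", "handlers": ["queue"]}},
}
logging.config.dictConfig(config)
logging.getLogger("ddg").info("ingestion started", extra={"user_id": "u-123"})
```

Any `extra` fields not among `LOG_RECORD_BUILTIN_ATTRS` end up as additional keys in the JSON line, which is what makes the formatter handy for structured logs.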
/utils/older/prompts-checkpoints.py:
--------------------------------------------------------------------------------
1 | iterative_report_improver_template = """\
2 | You are ARIA, Advanced Report Improvement Assistant.
3 |
4 | For this task, you can use the following information retrieved from the Internet:
5 |
6 | {new_info}
7 |
8 | END OF RETRIEVED INFORMATION
9 |
10 | Your task: pinpoint areas of improvement in the report/answer prepared in response to the following query:
11 |
12 | {query}
13 |
14 | END OF QUERY. REPORT/ANSWER TO ANALYZE:
15 |
16 | {previous_report}
17 |
18 | END OF REPORT
19 |
20 | Please decide how specifically the RETRIEVED INFORMATION you were provided at the beginning can address any areas of improvement in the report: additions, corrections, deletions, perhaps even a complete rewrite.
21 |
22 | Please write: "ACTION ITEMS FOR IMPROVEMENT:" then provide a numbered list of the individual SOURCEs in the RETRIEVED INFORMATION: first the URL, then one brief sentence stating whether its CONTENT from above will be used to enhance the report - be as specific as possible and if that particular content is not useful then just write "NOT RELEVANT".
23 |
24 | Add one more item in your numbered list - any additional instructions you can think of for improving the report/answer, independent of the RETRIEVED INFORMATION, for example how to rearrange sections, what to remove, reword, etc.
25 |
26 | After that, write: "NEW REPORT:" and write a new report from scratch. Important: any action items you listed must be **fully** implemented in your report, in which case your report must necessarily be different from the original report. In fact, the new report can be completely different if needed, the only concern is to craft an informative, no-fluff answer to the user's query:
27 |
28 | {query}
29 |
30 | END OF QUERY. Your report/answer should be: {report_type}.
31 |
32 | Finish with: "REPORT ASSESSMENT: X%", where X is your estimate of how well your new report serves the user's information need on a scale from 0% to 100%, based on their query.
33 | """
--------------------------------------------------------------------------------
/utils/output.py:
--------------------------------------------------------------------------------
1 | class ConditionalLogger:
2 | def __init__(self, verbose=True):
3 | self.verbose = verbose
4 |
5 | def log(self, message):
6 | if self.verbose:
7 | print(message)
8 |
9 | def log_no_newline(self, message):
10 | if self.verbose:
11 | print(message, end="")
12 |
13 | def log_error(self, message):
14 | print(message)
15 |
16 |
17 | def format_exception(e: Exception) -> str:
18 | """
19 | Format an exception to include both its type and message.
20 |
21 | Args:
22 | e (Exception): The exception to be formatted.
23 |
24 | Returns:
25 | str: A string representation of the exception, including its type and message.
26 |
27 | Example:
28 | try:
29 | raise ValueError("An example error")
30 | except ValueError as e:
31 | formatted_exception = format_exception(e)
32 | print(formatted_exception) # Output: "ValueError: An example error"
33 | """
34 | try:
35 | return f"{type(e).__name__}: {e}"
36 | except Exception:
37 | return f"Unknown error: {e}"
38 |
39 |
--------------------------------------------------------------------------------
/utils/prepare.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 | from dotenv import load_dotenv
6 |
7 | from utils.log import setup_logging
8 |
9 | load_dotenv(override=True)
10 |
11 | # Import chromadb while making it think sqlite3 has a new enough version.
12 | # This is necessary because the version of sqlite3 in Streamlit Share is too old.
13 | # We don't actually use sqlite3 in a Streamlit Share context, because we use HttpClient,
14 | # so it's fine to use the app there despite ending up with an incompatible sqlite3.
15 | sys.modules["sqlite3"] = lambda: None
16 | sys.modules["sqlite3"].sqlite_version_info = (42, 42, 42)
17 | __import__("chromadb")
18 | del sys.modules["sqlite3"]
19 | __import__("sqlite3") # import here because chromadb was supposed to import it
20 |
21 | # Set up logging
22 | LOG_LEVEL = os.getenv("LOG_LEVEL")
23 | LOG_FORMAT = os.getenv("LOG_FORMAT")
24 | setup_logging(LOG_LEVEL, LOG_FORMAT)
25 | DEFAULT_LOGGER_NAME = os.getenv("DEFAULT_LOGGER_NAME", "ddg")
26 |
27 |
28 | def get_logger(logger_name: str = DEFAULT_LOGGER_NAME):
29 | return logging.getLogger(logger_name)
30 |
31 |
32 | # Set up the environment variables
33 | DEFAULT_OPENAI_API_KEY = os.getenv("DEFAULT_OPENAI_API_KEY", "")
34 | IS_AZURE = bool(os.getenv("OPENAI_API_BASE") or os.getenv("AZURE_OPENAI_API_KEY"))
35 | EMBEDDINGS_DEPLOYMENT_NAME = os.getenv("EMBEDDINGS_DEPLOYMENT_NAME")
36 | CHAT_DEPLOYMENT_NAME = os.getenv("CHAT_DEPLOYMENT_NAME")
37 |
38 | DEFAULT_COLLECTION_NAME = os.getenv("DEFAULT_COLLECTION_NAME", "docdocgo-documentation")
39 |
40 | if USE_CHROMA_VIA_HTTP := bool(os.getenv("USE_CHROMA_VIA_HTTP")):
41 | os.environ["CHROMA_API_IMPL"] = "rest"
42 |
43 | # The following three variables are only used if USE_CHROMA_VIA_HTTP is True
44 | CHROMA_SERVER_HOST = os.getenv("CHROMA_SERVER_HOST", "localhost")
45 | CHROMA_SERVER_HTTP_PORT = os.getenv("CHROMA_SERVER_HTTP_PORT", "8000")
46 | CHROMA_SERVER_AUTHN_CREDENTIALS = os.getenv("CHROMA_SERVER_AUTHN_CREDENTIALS", "")
47 |
48 | # The following variable is only used if USE_CHROMA_VIA_HTTP is False
49 | VECTORDB_DIR = os.getenv("VECTORDB_DIR", "chroma/")
50 |
51 | MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini") # rename to DEFAULT_MODEL?
52 | CONTEXT_LENGTH = int(os.getenv("CONTEXT_LENGTH", 16000)) # effectively the max
53 | # size of what we think we can feed to the model without overwhelming it
54 | TEMPERATURE = float(os.getenv("TEMPERATURE", 0.3))
55 |
56 | ALLOWED_MODELS = os.getenv("ALLOWED_MODELS", MODEL_NAME).split(",")
57 | ALLOWED_MODELS = [model.strip() for model in ALLOWED_MODELS]
58 | if MODEL_NAME not in ALLOWED_MODELS:
59 | raise ValueError("The default model must be in the list of allowed models.")
60 |
61 | EMBEDDINGS_MODEL_NAME = os.getenv("EMBEDDINGS_MODEL_NAME", "text-embedding-3-large")
62 | EMBEDDINGS_DIMENSIONS = int(os.getenv("EMBEDDINGS_DIMENSIONS", 3072))
63 |
64 | LLM_REQUEST_TIMEOUT = float(os.getenv("LLM_REQUEST_TIMEOUT", 9))
65 |
66 | DEFAULT_MODE = os.getenv("DEFAULT_MODE", "/kb")
67 |
68 | INCLUDE_ERROR_IN_USER_FACING_ERROR_MSG = bool(
69 | os.getenv("INCLUDE_ERROR_IN_USER_FACING_ERROR_MSG")
70 | )
71 |
72 | BYPASS_SETTINGS_RESTRICTIONS = bool(os.getenv("BYPASS_SETTINGS_RESTRICTIONS"))
73 | BYPASS_SETTINGS_RESTRICTIONS_PASSWORD = os.getenv(
74 | "BYPASS_SETTINGS_RESTRICTIONS_PASSWORD"
75 | )
76 |
77 | DOMAIN_NAME_FOR_SHARING = os.getenv("DOMAIN_NAME_FOR_SHARING", "shared")
78 |
79 | MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 100 * 1024 * 1024))
80 |
81 | INITIAL_TEST_QUERY_STREAMLIT = os.getenv("INITIAL_QUERY_STREAMLIT")
82 |
83 | # Check that the necessary environment variables are set
84 | DUMMY_OPENAI_API_KEY_PLACEHOLDER = "DUMMY NON-EMPTY VALUE"
85 |
86 | if IS_AZURE and not (
87 | EMBEDDINGS_DEPLOYMENT_NAME
88 | and CHAT_DEPLOYMENT_NAME
89 | and os.getenv("AZURE_OPENAI_API_KEY")
90 | and os.getenv("OPENAI_API_BASE")
91 | ):
92 | raise ValueError(
93 | "You have set some but not all environment variables necessary to utilize the "
94 | "Azure OpenAI API endpoint. Please refer to .env.example for details."
95 | )
96 | elif not IS_AZURE and not DEFAULT_OPENAI_API_KEY:
97 | # We don't exit because we could get the key from the Streamlit app
98 | print(
99 | "WARNING: You have not set the DEFAULT_OPENAI_API_KEY environment variable. "
100 | "This is ok when running the Streamlit app, but not when running "
101 | "the command line app. For now, we will set it to a dummy non-empty value "
102 | "to avoid problems initializing the vectorstore etc. "
103 | "Please refer to .env.example for additional information."
104 | )
105 | os.environ["DEFAULT_OPENAI_API_KEY"] = DUMMY_OPENAI_API_KEY_PLACEHOLDER
106 | DEFAULT_OPENAI_API_KEY = DUMMY_OPENAI_API_KEY_PLACEHOLDER
107 | # TODO investigate the behavior when this happens
108 |
109 | if not os.getenv("SERPER_API_KEY") and not os.getenv("IGNORE_LACK_OF_SERPER_API_KEY"):
110 | raise ValueError(
111 | "You have not set the SERPER_API_KEY environment variable, "
112 | "which is necessary for the Internet search functionality."
113 | "Pease set the SERPER_API_KEY environment variable to your Google Serper API key, "
114 | "which you can get for free, with no credit card, at https://serper.dev. "
115 | "This free key will allow you to make about 1250 searches until payment is required.\n\n"
116 | "If you want to supress this error, set the IGNORE_LACK_OF_SERPER_API_KEY environment "
117 | "variable to any non-empty value. You can then use features that do not require the "
118 | "Internet search functionality."
119 | )
120 |
121 | # Verify the validity of the db path
122 | if not os.getenv("USE_CHROMA_VIA_HTTP") and not os.path.isdir(VECTORDB_DIR):
123 | try:
124 | abs_path = os.path.abspath(VECTORDB_DIR)
125 | except Exception:
126 | abs_path = "INVALID PATH"
127 | raise ValueError(
128 | "You have not specified a valid directory for the vector database. "
129 | "Please set the VECTORDB_DIR environment variable in .env, as shown in .env.example. "
130 | "Alternatively, if you have a Chroma DB server running, you can set the "
131 | "USE_CHROMA_VIA_HTTP environment variable to any non-empty value. "
132 | f"\n\nThe path you have specified is: {VECTORDB_DIR}.\n"
133 | f"The absolute path resolves to: {abs_path}."
134 | )
135 |
136 |
137 | is_env_loaded = True
138 |
--------------------------------------------------------------------------------
/utils/rag.py:
--------------------------------------------------------------------------------
1 | from langchain_text_splitters import RecursiveCharacterTextSplitter
2 |
3 | rag_text_splitter = RecursiveCharacterTextSplitter(
4 | chunk_size=400,
5 | chunk_overlap=40,
6 | add_start_index=True, # metadata will include start index of snippet in original doc
7 | )
8 |
--------------------------------------------------------------------------------
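A quick sketch of what the splitter produces (the text is illustrative):

```python
from utils.rag import rag_text_splitter

chunks = rag_text_splitter.create_documents(
    ["A long document body... " * 50], [{"source": "example.txt"}]
)
print(len(chunks))         # several ~400-character chunks
print(chunks[0].metadata)  # {'source': 'example.txt', 'start_index': 0}
```

The `start_index` field is what lets downstream code locate each snippet inside its parent document.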
/utils/streamlit/fix_event_loop.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 |
4 | def remove_tornado_fix() -> None:
5 | """
6 | UNDOES THE FOLLOWING OBSOLETE FIX IN streamlit.web.bootstrap.run():
7 | Set default asyncio policy to be compatible with Tornado 6.
8 |
9 | Tornado 6 (at least) is not compatible with the default
10 | asyncio implementation on Windows. So here we
11 | pick the older SelectorEventLoopPolicy when the OS is Windows
12 | if the known-incompatible default policy is in use.
13 |
14 | This has to happen as early as possible to make it a low priority and
15 | overridable
16 |
17 | See: https://github.com/tornadoweb/tornado/issues/2608
18 |
19 | FIXME: if/when tornado supports the defaults in asyncio,
20 | remove and bump tornado requirement for py38
21 | """
22 | try:
23 | from asyncio import ( # type: ignore[attr-defined]
24 | WindowsProactorEventLoopPolicy,
25 | WindowsSelectorEventLoopPolicy,
26 | )
27 | except ImportError:
28 | pass
29 | # Not affected
30 | else:
31 | if type(asyncio.get_event_loop_policy()) is WindowsSelectorEventLoopPolicy:
32 | asyncio.set_event_loop_policy(WindowsProactorEventLoopPolicy())
33 |
--------------------------------------------------------------------------------
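A usage sketch: call the undo-helper early, before any event loops are created (the coroutine is hypothetical):

```python
import asyncio

from utils.streamlit.fix_event_loop import remove_tornado_fix

remove_tornado_fix()  # restore the Proactor policy on Windows if Streamlit switched it

async def main():
    await asyncio.sleep(0)

asyncio.run(main())
```

On Windows, the Proactor policy is required for things like asyncio subprocess support, which is presumably why the Streamlit-era selector policy needs to be undone here.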
/utils/streamlit/helpers.py:
--------------------------------------------------------------------------------
1 | import re
2 | import time
3 | from typing import Any
4 |
5 | import streamlit as st
6 | from pydantic import BaseModel
7 |
8 | from utils.type_utils import ChatMode
9 |
10 | WELCOME_TEMPLATE_COLL_IN_URL = """\
11 | Welcome! You are now using the `{}` collection. \
12 | When chatting, I will use its content as my knowledge base. \
13 | To get help using me, just type `/help`.\
14 | """
15 |
16 | NO_AUTH_TEMPLATE_COLL_IN_URL = """\
17 | Apologies, the current URL doesn't provide access to the \
18 | requested collection `{}`. This can happen if the \
19 | access code is invalid or if the collection doesn't exist.
20 |
21 | I have switched to my default collection.
22 |
23 | To get help using me, just type `/help`.\
24 | """
25 |
26 | WELCOME_MSG_NEW_CREDENTIALS = """\
27 | Welcome! The user credentials have changed but you are still \
28 | authorised for the current collection.\
29 | """
30 |
31 | NO_AUTH_MSG_NEW_CREDENTIALS = """\
32 | Welcome! The user credentials have changed, \
33 | so I've switched to the default collection.\
34 | """
35 |
36 | POST_INGEST_MESSAGE_TEMPLATE_NEW_COLL = """\
37 | Your documents have been uploaded to a new collection: `{coll_name}`. \
38 | You can now use this collection in your queries. If you are done uploading, \
39 | you can rename it:
40 | ```
41 | /db rename my-cool-collection-name
42 | ```
43 | """
44 |
45 | POST_INGEST_MESSAGE_TEMPLATE_EXISTING_COLL = """\
46 | Your documents have been added to the existing collection: `{coll_name}`. \
47 | If you are done uploading, you can rename it:
48 | ```
49 | /db rename my-cool-collection-name
50 | ```
51 | """
52 |
53 | mode_option_to_prefix = {
54 | "/kb (main mode)": (
55 | "", # TODO: this presupposes that DEFAULT_MODE is /kb
56 | "Chat using the current collection as a knowledge base.",
57 | ),
58 | "/research": (
59 | "/re ",
60 | "Research a topic on the web.",
61 | ),
62 | "/research heatseek": (
63 | "/re hs ",
64 | "Find sites that have specific information you need.",
65 | ),
66 | "/summarize": (
67 | "/su ",
68 | "Summarize the content of a URL and ingest into a collection.",
69 | ),
70 | "/ingest": (
71 | "/in ",
72 | "Ingest content from a URL into a collection.",
73 | ),
74 | "/db": (
75 | "/db ",
76 | "Manage your collections.",
77 | ),
78 | "/export": (
79 | "/ex ",
80 | "Export your current chat history.",
81 | ),
82 | "/share": (
83 | "/sh ",
84 | "Create a shareable link to the current collection.",
85 | ),
86 | "/web": (
87 | "/we ",
88 | "Do a quick web research without ingesting the content.",
89 | ),
90 | "/details": (
91 | "/de ",
92 | "Provide details on everything retrieved from the collection in response to your query.",
93 | ),
94 | "/quotes": (
95 | "/qu ",
96 | "Provide quotes from the collection in relevant to your query.",
97 | ),
98 | "/chat": (
99 | "/ch ",
100 | "Regular chat, without retrieving information from the current collection.",
101 | ),
102 | }
103 | mode_options = list(mode_option_to_prefix.keys())
104 |
105 | chat_with_docs_status_config = {
106 | "thinking.header": "One sec...",
107 | "thinking.body": "Retrieving sources and composing reply...",
108 | "complete.header": "Done!",
109 | "complete.body": "Reply composed.",
110 | "error.header": "Error.",
111 | "error.body": "Apologies, there was an error.",
112 | }
113 |
114 | just_chat_status_config = chat_with_docs_status_config | {
115 | "thinking.body": "Composing reply...",
116 | }
117 |
118 | web_status_config = chat_with_docs_status_config | {
119 | "thinking.header": "Doing Internet research (takes 10-30s)...",
120 | "thinking.body": "Retrieving content from websites and composing report...",
121 | "complete.body": "Report composed.",
122 | }
123 |
124 | research_status_config = chat_with_docs_status_config | {
125 | "thinking.header": "Doing Internet research (takes 10-30s)...",
126 | "thinking.body": "Retrieving content from websites and composing report...",
127 | "complete.body": "Report composed and sources added to the document collection.",
128 | }
129 |
130 | ingest_status_config = chat_with_docs_status_config | {
131 | "thinking.header": "Ingesting...",
132 | "thinking.body": "Retrieving and ingesting content...",
133 | "complete.body": "Content was added to the document collection.",
134 | }
135 |
136 | summarize_status_config = chat_with_docs_status_config | {
137 | "thinking.header": "Summarizing and ingesting...",
138 | "thinking.body": "Retrieving, summarizing, and ingesting content...",
139 | "complete.body": "Summary composed and retrieved content added to the document collection.",
140 | }
141 |
142 | status_config = {
143 | ChatMode.JUST_CHAT_COMMAND_ID: just_chat_status_config,
144 | ChatMode.CHAT_WITH_DOCS_COMMAND_ID: chat_with_docs_status_config,
145 | ChatMode.DETAILS_COMMAND_ID: chat_with_docs_status_config,
146 | ChatMode.QUOTES_COMMAND_ID: chat_with_docs_status_config,
147 | ChatMode.HELP_COMMAND_ID: chat_with_docs_status_config,
148 | ChatMode.WEB_COMMAND_ID: web_status_config,
149 | ChatMode.RESEARCH_COMMAND_ID: research_status_config,
150 | ChatMode.INGEST_COMMAND_ID: ingest_status_config,
151 | ChatMode.SUMMARIZE_COMMAND_ID: summarize_status_config,
152 | }
153 |
154 | STAND_BY_FOR_INGESTION_MESSAGE = (
155 | "\n\n--- \n\n> **PLEASE STAND BY WHILE CONTENT IS INGESTED...**"
156 | )
157 |
158 |
159 | def get_init_msg(is_initial_load, is_authorised, coll_name_in_url, init_coll_name):
160 | if not is_initial_load: # key changed without reloading the page
161 | if is_authorised:
162 | return WELCOME_MSG_NEW_CREDENTIALS
163 | else:
164 | return NO_AUTH_MSG_NEW_CREDENTIALS
165 | elif coll_name_in_url: # page (re)load with coll in URL
166 | if is_authorised:
167 | return WELCOME_TEMPLATE_COLL_IN_URL.format(init_coll_name)
168 | else:
169 | return NO_AUTH_TEMPLATE_COLL_IN_URL.format(init_coll_name)
170 | else: # page (re)load without coll in URL
171 | return None # just to be explicit
172 |
173 |
174 | def update_url(params: dict[str, str]):
175 | st.query_params.clear() # NOTE: should be a way to set them in one go
176 | st.query_params.update(params)
177 |
178 |
179 | def update_url_if_scheduled():
180 | if st.session_state.update_query_params is not None:
181 | update_url(st.session_state.update_query_params)
182 | st.session_state.update_query_params = None
183 |
184 |
185 | def escape_dollars(text: str) -> str:
186 | """
187 | Escape dollar signs in the text that come before numbers.
188 | """
189 | return re.sub(r"\$(?=\d)", r"\$", text)
190 |
191 |
192 | def fix_markdown(text: str) -> str:
193 | """
194 | Escape dollar signs in the text that come before numbers and add
195 | two spaces before every newline.
196 | """
197 | return re.sub(r"\$(?=\d)", r"\$", text).replace("\n", " \n")
198 |
199 |
200 | def write_slowly(message_placeholder, answer, delay=None):
201 | """Write a message to the message placeholder not all at once but word by word."""
202 | pieces = answer.split(" ")
203 | delay = delay or min(0.025, 10 / (len(pieces) + 1))
204 | for i in range(1, len(pieces) + 1):
205 | message_placeholder.markdown(fix_markdown(" ".join(pieces[:i])))
206 | time.sleep(delay)
207 |
208 |
209 | def show_sources(
210 | sources: list[str] | None,
211 | callback_handler=None,
212 | ):
213 | """Show the sources if present."""
214 | # If the cb handler is provided, remove the stand-by message
215 | if callback_handler and callback_handler.end_str_printed:
216 | callback_handler.container.markdown(fix_markdown(callback_handler.buffer))
217 |
218 | if not sources:
219 | return
220 | with st.expander("Sources"):
221 | for source in sources:
222 | st.markdown(source)
223 |
224 |
225 | def show_uploader(is_teleporting=False, border=True):
226 | if is_teleporting:
227 | try:
228 | st.session_state.uploader_placeholder.empty()
229 | except AttributeError:
230 | pass # should never happen, but just in case
231 | # Switch between one of the two possible keys (to avoid duplicate key error)
232 | st.session_state.uploader_form_key = (
233 | "uploader-form-alt"
234 | if st.session_state.uploader_form_key == "uploader-form"
235 | else "uploader-form"
236 | )
237 |
238 | # Show the uploader
239 | st.session_state.uploader_placeholder = st.empty()
240 | with st.session_state.uploader_placeholder:
241 | with st.form(
242 | st.session_state.uploader_form_key, clear_on_submit=True, border=border
243 | ):
244 | files = st.file_uploader(
245 | "Upload your documents",
246 | accept_multiple_files=True,
247 | label_visibility="collapsed",
248 | )
249 | cols = st.columns([1, 1])
250 | with cols[0]:
251 | is_submitted = st.form_submit_button("Upload")
252 | with cols[1]:
253 | allow_all_ext = st.toggle("Allow all extensions", value=False)
254 | return (files if is_submitted else []), allow_all_ext
255 |
256 |
257 | class DownloaderData(BaseModel):
258 | data: Any
259 | file_name: str
260 | mime: str = "text/plain"
261 |
262 |
263 | def show_downloader(
264 | downloader_data: DownloaderData | None = None, is_teleporting=False
265 | ):
266 | if downloader_data is None:
267 | try:
268 | # Use the data from the already existing downloader
269 | downloader_data = st.session_state.downloader_data
270 | except AttributeError:
271 | raise ValueError("downloader_data not provided and not in session state")
272 | else:
273 | st.session_state.downloader_data = downloader_data
274 |
275 | if is_teleporting:
276 | try:
277 | st.session_state.downloader_placeholder.empty()
278 | except AttributeError:
279 | pass # should never happen, but just in case
280 |
281 | # Switch between one of the two possible keys (to avoid duplicate key error)
282 | st.session_state.downloader_form_key = (
283 | "downloader-alt"
284 | if st.session_state.downloader_form_key == "downloader"
285 | else "downloader"
286 | )
287 |
288 | # Show the downloader
289 | st.session_state.downloader_placeholder = st.empty()
290 | with st.session_state.downloader_placeholder:
291 | is_downloaded = st.download_button(
292 | label="Download Exported Data",
293 | data=downloader_data.data,
294 | file_name=downloader_data.file_name,
295 | mime=downloader_data.mime,
296 | key=st.session_state.downloader_form_key,
297 | )
298 | return is_downloaded
299 |
--------------------------------------------------------------------------------
/utils/streamlit/ingest.py:
--------------------------------------------------------------------------------
1 | import uuid
2 |
3 | import streamlit as st
4 | from langchain_core.documents import Document
5 |
6 | from agentblocks.collectionhelper import ingest_into_collection
7 | from agents.dbmanager import (
8 |     get_full_collection_name,
9 |     get_user_facing_collection_name,
10 | )
11 | from utils.chat_state import ChatState
12 | from utils.helpers import ADDITIVE_COLLECTION_PREFIX, INGESTED_DOCS_INIT_PREFIX
13 | from utils.prepare import DEFAULT_COLLECTION_NAME
14 | from utils.query_parsing import IngestCommand
15 | from utils.streamlit.helpers import (
16 |     POST_INGEST_MESSAGE_TEMPLATE_EXISTING_COLL,
17 |     POST_INGEST_MESSAGE_TEMPLATE_NEW_COLL,
18 | )
19 |
20 |
21 | def ingest_docs(docs: list[Document], chat_state: ChatState):
22 | ingest_command = chat_state.parsed_query.ingest_command
23 |
24 | uploaded_docs_coll_name_as_shown = get_user_facing_collection_name(
25 | chat_state.user_id, chat_state.vectorstore.name
26 | ) # might reassign below
27 | if chat_state.vectorstore.name != DEFAULT_COLLECTION_NAME and (
28 | ingest_command == IngestCommand.ADD
29 | or (
30 | ingest_command == IngestCommand.DEFAULT
31 | and uploaded_docs_coll_name_as_shown.startswith(ADDITIVE_COLLECTION_PREFIX)
32 | )
33 | ):
34 | # We will use the same collection
35 | is_new_collection = False
36 | uploaded_docs_coll_name_full = chat_state.vectorstore.name
37 |
38 | # Fetch the collection metadata so we can add timestamps
39 | collection_metadata = chat_state.fetch_collection_metadata()
40 | else:
41 | # We will need to create a new collection
42 | is_new_collection = True
43 | collection_metadata = {}
44 | uploaded_docs_coll_name_as_shown = (
45 | INGESTED_DOCS_INIT_PREFIX + uuid.uuid4().hex[:8]
46 | )
47 | uploaded_docs_coll_name_full = get_full_collection_name(
48 | chat_state.user_id, uploaded_docs_coll_name_as_shown
49 | )
50 | try:
51 | vectorstore = ingest_into_collection(
52 | collection_name=uploaded_docs_coll_name_full,
53 | docs=docs,
54 | collection_metadata=collection_metadata,
55 | chat_state=chat_state,
56 | is_new_collection=is_new_collection,
57 | )
58 | if is_new_collection:
59 | # Switch to the newly created collection
60 | chat_state.vectorstore = vectorstore
61 | msg_template = POST_INGEST_MESSAGE_TEMPLATE_NEW_COLL
62 | else:
63 | msg_template = POST_INGEST_MESSAGE_TEMPLATE_EXISTING_COLL
64 |
65 | with st.chat_message("assistant", avatar=st.session_state.bot_avatar):
66 | st.markdown(msg_template.format(coll_name=uploaded_docs_coll_name_as_shown))
67 | except Exception as e:
68 | with st.chat_message("assistant", avatar=st.session_state.bot_avatar):
69 | st.markdown(
70 | f"Apologies, an error occurred during ingestion:\n```\n{e}\n```"
71 | )
72 |
--------------------------------------------------------------------------------
/utils/streamlit/prepare.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import streamlit as st
4 |
5 | from components.llm import CallbackHandlerDDGConsole
6 | from docdocgo import do_intro_tasks
7 | from utils.chat_state import ChatState
8 | from utils.prepare import DEFAULT_OPENAI_API_KEY, DUMMY_OPENAI_API_KEY_PLACEHOLDER
9 | from utils.streamlit.fix_event_loop import remove_tornado_fix
10 | from utils.streamlit.helpers import mode_options
11 | from utils.type_utils import OperationMode
12 |
13 |
14 | def prepare_app():
15 | if os.getenv("STREAMLIT_SCHEDULED_MAINTENANCE"):
16 | st.markdown("Welcome to DocDocGo! 🦉")
17 | st.markdown(
18 | "Scheduled maintenance is currently in progress. Please check back later."
19 | )
20 | st.stop()
21 |
22 |     # Flag for whether the OpenAI API key has succeeded at least once
23 | st.session_state.llm_api_key_ok_status = False
24 |
25 | print("query params:", st.query_params)
26 | st.session_state.update_query_params = None
27 | st.session_state.init_collection_name = st.query_params.get("collection")
28 | st.session_state.access_code = st.query_params.get("access_code")
29 | try:
30 | remove_tornado_fix()
31 | vectorstore = do_intro_tasks(openai_api_key=DEFAULT_OPENAI_API_KEY)
32 | except Exception as e:
33 | st.error(
34 | "Apologies, I could not load the vector database. This "
35 | "could be due to a misconfiguration of the environment variables "
36 | f"or missing files. The error reads: \n\n{e}"
37 | )
38 | st.stop()
39 |
40 | st.session_state.chat_state = ChatState(
41 | operation_mode=OperationMode.STREAMLIT,
42 | vectorstore=vectorstore,
43 | callbacks=[
44 | CallbackHandlerDDGConsole(),
45 | "placeholder for CallbackHandlerDDGStreamlit",
46 | ],
47 | openai_api_key=DEFAULT_OPENAI_API_KEY,
48 | )
49 |
50 | st.session_state.prev_supplied_openai_api_key = None
51 | st.session_state.default_openai_api_key = DEFAULT_OPENAI_API_KEY
52 | if st.session_state.default_openai_api_key == DUMMY_OPENAI_API_KEY_PLACEHOLDER:
53 | st.session_state.default_openai_api_key = ""
54 |
55 | st.session_state.idx_file_upload = -1
56 | st.session_state.uploader_form_key = "uploader-form"
57 |
58 | st.session_state.idx_file_download = -1
59 | st.session_state.downloader_form_key = "downloader"
60 |
61 | st.session_state.user_avatar = os.getenv("USER_AVATAR") or None
62 | st.session_state.bot_avatar = os.getenv("BOT_AVATAR") or None # TODO: document this
63 |
64 | SAMPLE_QUERIES = os.getenv(
65 | "SAMPLE_QUERIES",
66 | "/research heatseek Code for row of buttons Streamlit, /research What are the biggest AI news this month?, /help How does infinite research work?",
67 | )
68 |     st.session_state.sample_queries = [q.strip() for q in SAMPLE_QUERIES.split(",")]  # NOTE: queries must not contain commas
69 | st.session_state.default_mode = mode_options[0]
--------------------------------------------------------------------------------
/utils/strings.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from typing import Iterable
4 |
5 | from utils.type_utils import JSONish
6 |
7 |
8 | def split_preserving_whitespace(text: str) -> tuple[list[str], list[str]]:
9 | """Split a string into parts, preserving whitespace. The result will
10 | be a tuple of two lists: the parts and the whitespaces. The whitespaces
11 | will be strings of whitespace characters (spaces, tabs, newlines, etc.).
12 | The original string can be reconstructed by joining the parts and the
13 | whitespaces, as follows:
14 | ''.join([w + p for p, w in zip(parts, whitespaces)]) + whitespaces[-1]
15 |
16 | In other words, the first whitespace is assumed to come before the first
17 | part, and the last whitespace is assumed to come after the last part. If needed,
18 | the first and/or last whitespace element can be an empty string.
19 |
20 | Each part is guaranteed to be non-empty.
21 |
22 | Returns:
23 | A tuple of two lists: the parts and the whitespaces.
24 | """
25 | parts = []
26 | whitespaces = []
27 | start = 0
28 | for match in re.finditer(r"\S+", text):
29 | end = match.start()
30 | parts.append(match.group())
31 | whitespaces.append(text[start:end])
32 | start = match.end()
33 | whitespaces.append(text[start:])
34 | return parts, whitespaces
35 |
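# Worked example of the round-trip property documented above (illustrative):
#   >>> parts, ws = split_preserving_whitespace("  foo  bar ")
#   >>> parts, ws
#   (['foo', 'bar'], ['  ', '  ', ' '])
#   >>> "".join(w + p for p, w in zip(parts, ws)) + ws[-1]
#   '  foo  bar '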
36 |
37 | def remove_consecutive_blank_lines(
38 | lines: list[str], max_consecutive_blank_lines=1
39 | ) -> list[str]:
40 | """Remove consecutive blank lines from a list of lines."""
41 | new_lines = []
42 | num_consecutive_blank_lines = 0
43 | for line in lines:
44 | if line:
45 | # Non-blank line
46 | num_consecutive_blank_lines = 0
47 | new_lines.append(line)
48 | else:
49 | # Blank line
50 | num_consecutive_blank_lines += 1
51 | if num_consecutive_blank_lines <= max_consecutive_blank_lines:
52 | new_lines.append(line)
53 | return new_lines
54 |
55 |
56 | def limit_number_of_words(text: str, max_words: int) -> tuple[str, int]:
57 | """
58 |     Limit the number of words in a text to a given maximum. Return the
59 |     (possibly truncated) text and the number of words in it, with leading
60 |     and trailing whitespace stripped.
61 | """
62 | parts, whitespaces = split_preserving_whitespace(text)
63 | num_words = min(max_words, len(parts))
64 | if num_words == 0:
65 | return "", 0
66 | text = (
67 | "".join(p + w for p, w in zip(parts[: num_words - 1], whitespaces[1:num_words]))
68 | + parts[num_words - 1]
69 | )
70 | return text, num_words
71 |
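# e.g. limit_number_of_words("  one two  three ", 2) == ("one two", 2);
# internal whitespace between the kept words is preserved verbatim.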
72 |
73 | def limit_num_characters(text: str, max_characters: int) -> str:
74 | """
75 | Limit the number of characters in a string to a given number. If the string
76 | is longer than the given number of characters, truncate it and append an
77 |     ellipsis, so the result is exactly max_characters long.
78 | """
79 | if len(text) <= max_characters:
80 | return text
81 | if max_characters > 0:
82 | return text[: max_characters - 1] + "…"
83 | return ""
84 |
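# e.g. limit_num_characters("abcdef", 4) == "abc…" (exactly 4 characters).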
85 |
86 | def extract_json(text: str) -> JSONish:
87 | """
88 | Extract a single JSON object or array from a string. Determines the first
89 | occurrence of '{' or '[', and the last occurrence of '}' or ']', then
90 | extracts the JSON structure accordingly. Returns a dictionary or list on
91 | success, throws on parse errors, including a bracket mismatch.
92 | """
93 | length = len(text)
94 | first_curly_brace = text.find("{")
95 | last_curly_brace = text.rfind("}")
96 | first_square_brace = text.find("[")
97 | last_square_brace = text.rfind("]")
98 |
99 | assert (
100 |         first_curly_brace + first_square_brace > -2  # at least one was found (find returns -1 if absent)
101 | ), "No opening curly or square bracket found"
102 |
103 | if first_curly_brace == -1:
104 | first_curly_brace = length
105 | elif first_square_brace == -1:
106 | first_square_brace = length
107 |
108 | assert (first_curly_brace < first_square_brace) == (
109 | last_curly_brace > last_square_brace
110 | ), "Mismatched curly and square brackets"
111 |
112 | first = min(first_curly_brace, first_square_brace)
113 | last = max(last_curly_brace, last_square_brace)
114 |
115 | assert first < last, "No closing bracket found"
116 |
117 | return json.loads(text[first : last + 1])
118 |
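# e.g. extract_json('Sure! Here is the JSON: {"a": [1, 2]} Hope that helps.')
# returns {'a': [1, 2]}, while a mismatched input such as '{"a": 1]' trips the
# "Mismatched curly and square brackets" assertion.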
119 |
120 | def has_which_substring(text: str, substrings: Iterable[str]) -> str | None:
121 | """
122 | Return the first substring in the list that is a substring of the text.
123 | If no such substring is found, return None.
124 | """
125 | for substring in substrings:
126 | if substring in text:
127 | return substring
128 | return None
129 |
--------------------------------------------------------------------------------
/utils/type_utils.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Any
3 | from langchain_core.documents.base import Document
4 | from langchain_core.runnables import RunnableSerializable
5 | from pydantic import BaseModel, Field
6 |
7 | from utils.prepare import MODEL_NAME, TEMPERATURE
8 | from langchain_core.callbacks import BaseCallbackHandler
9 |
10 | JSONish = str | int | float | dict[str, "JSONish"] | list["JSONish"]
11 | JSONishDict = dict[str, JSONish]
12 | JSONishSimple = str | int | float | dict[str, Any] | list # for use in pydantic models
13 | Props = dict[str, Any]
14 |
15 | PairwiseChatHistory = list[tuple[str, str]]
16 | CallbacksOrNone = list[BaseCallbackHandler] | None
17 | ChainType = RunnableSerializable[dict, str] # double check this
18 |
19 | OperationMode = Enum("OperationMode", "CONSOLE STREAMLIT FASTAPI")
20 |
21 |
22 | class ChatMode(Enum):
23 | NONE_COMMAND_ID = -1
24 | RETRY_COMMAND_ID = 0
25 | CHAT_WITH_DOCS_COMMAND_ID = 1
26 | DETAILS_COMMAND_ID = 2
27 | QUOTES_COMMAND_ID = 3
28 | WEB_COMMAND_ID = 4
29 | RESEARCH_COMMAND_ID = 5
30 | JUST_CHAT_COMMAND_ID = 6
31 | DB_COMMAND_ID = 7
32 | HELP_COMMAND_ID = 8
33 | INGEST_COMMAND_ID = 9
34 | EXPORT_COMMAND_ID = 10
35 | SUMMARIZE_COMMAND_ID = 11
36 | SHARE_COMMAND_ID = 12
37 |
38 |
39 | chat_modes_needing_llm = {
40 | ChatMode.WEB_COMMAND_ID,
41 | ChatMode.RESEARCH_COMMAND_ID,
42 | ChatMode.JUST_CHAT_COMMAND_ID,
43 | ChatMode.DETAILS_COMMAND_ID,
44 | ChatMode.QUOTES_COMMAND_ID,
45 | ChatMode.CHAT_WITH_DOCS_COMMAND_ID,
46 | ChatMode.SUMMARIZE_COMMAND_ID,
47 | ChatMode.HELP_COMMAND_ID,
48 | }
49 |
50 |
51 | class DDGError(Exception):
52 | default_user_facing_message = (
53 | "Apologies, I ran into some trouble when preparing a response to you."
54 | )
55 | default_http_status_code = 500
56 |
57 | def __init__(
58 | self,
59 | message: str | None = None,
60 | user_facing_message: str | None = None,
61 | http_status_code: int | None = None,
62 | ):
63 | super().__init__(message) # could include user_facing_message
64 |
65 | if user_facing_message is not None:
66 | self.user_facing_message = user_facing_message
67 | else:
68 | self.user_facing_message = self.default_user_facing_message
69 |
70 | if http_status_code is not None:
71 | self.http_status_code = http_status_code
72 | else:
73 | self.http_status_code = self.default_http_status_code
74 |
75 | @property
76 | def user_facing_message_full(self):
77 | if self.__cause__ is None:
78 | return self.user_facing_message
79 | return (
80 | f"{self.user_facing_message} The error reads:\n```\n{self.__cause__}\n```"
81 | )
82 |
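# A minimal sketch of how DDGError is meant to be subclassed (the subclass
# name and values here are hypothetical, not defined in the codebase):
#
#     class WebFetchDDGError(DDGError):
#         default_user_facing_message = "Apologies, I couldn't fetch that page."
#         default_http_status_code = 502
#
# Raising it with `raise WebFetchDDGError() from exc` makes
# user_facing_message_full append the original error in a code fence.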
83 |
84 | class BotSettings(BaseModel):
85 | llm_model_name: str = MODEL_NAME
86 | temperature: float = TEMPERATURE
87 |
88 |
89 | AccessRole = Enum("AccessRole", {"NONE": 0, "VIEWER": 1, "EDITOR": 2, "OWNER": 3})
90 |
91 | AccessCodeType = Enum("AccessCodeType", "NEED_ALWAYS NEED_ONCE NO_ACCESS")
92 |
93 |
94 | class CollectionUserSettings(BaseModel):
95 | access_role: AccessRole = AccessRole.NONE
96 |
97 |
98 | class AccessCodeSettings(BaseModel):
99 | code_type: AccessCodeType = AccessCodeType.NO_ACCESS
100 | access_role: AccessRole = AccessRole.NONE
101 |
102 |
103 | COLLECTION_USERS_METADATA_KEY = "collection_users"
104 |
105 |
106 | class CollectionPermissions(BaseModel):
107 | user_id_to_settings: dict[str, CollectionUserSettings] = Field(default_factory=dict)
108 | # NOTE: key "" refers to settings for a general user
109 |
110 | access_code_to_settings: dict[str, AccessCodeSettings] = Field(default_factory=dict)
111 |
112 | def get_user_settings(self, user_id: str | None) -> CollectionUserSettings:
113 | return self.user_id_to_settings.get(user_id or "", CollectionUserSettings())
114 |
115 | def set_user_settings(
116 | self, user_id: str | None, settings: CollectionUserSettings
117 | ) -> None:
118 | self.user_id_to_settings[user_id or ""] = settings
119 |
120 | def get_access_code_settings(self, access_code: str) -> AccessCodeSettings:
121 | return self.access_code_to_settings.get(access_code, AccessCodeSettings())
122 |
123 | def set_access_code_settings(
124 | self, access_code: str, settings: AccessCodeSettings
125 | ) -> None:
126 | self.access_code_to_settings[access_code] = settings
127 |
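# Example usage (user IDs and roles are illustrative):
#
#     perms = CollectionPermissions()
#     perms.set_user_settings("alice", CollectionUserSettings(access_role=AccessRole.EDITOR))
#     perms.set_user_settings(None, CollectionUserSettings(access_role=AccessRole.VIEWER))
#     perms.get_user_settings("alice").access_role  # AccessRole.EDITOR
#     perms.get_user_settings(None).access_role     # AccessRole.VIEWER (general user, key "")
#     perms.get_user_settings("bob").access_role    # AccessRole.NONE (no explicit entry)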
128 |
129 | INSTRUCT_SHOW_UPLOADER = "INSTRUCT_SHOW_UPLOADER"
130 | INSTRUCT_CACHE_ACCESS_CODE = "INSTRUCT_CACHE_ACCESS_CODE"
131 | INSTRUCT_AUTO_RUN_NEXT_QUERY = "INSTRUCT_AUTO_RUN_NEXT_QUERY"
132 | INSTRUCT_EXPORT_CHAT_HISTORY = "INSTRUCT_EXPORT_CHAT_HISTORY"
133 | # INSTRUCTION_SKIP_CHAT_HISTORY = "INSTRUCTION_SKIP_CHAT_HISTORY"
134 |
135 |
136 | class Instruction(BaseModel):
137 | type: str
138 | user_id: str | None = None
139 | access_code: str | None = None
140 | data: JSONishSimple | None = None # NOTE: absorb the two fields above into this one
141 |
142 |
143 | class Doc(BaseModel):
144 | """Pydantic-compatible version of Langchain's Document."""
145 |
146 | page_content: str
147 | metadata: dict[str, Any]
148 |
149 | @staticmethod
150 | def from_lc_doc(langchain_doc: Document) -> "Doc":
151 | return Doc(
152 | page_content=langchain_doc.page_content, metadata=langchain_doc.metadata
153 | )
154 |
155 | def to_lc_doc(self) -> Document:
156 | return Document(page_content=self.page_content, metadata=self.metadata)
157 |
--------------------------------------------------------------------------------