├── .gitignore ├── Dockerfile ├── LICENSE ├── Makefile ├── MassGPT.jpg ├── README.md ├── requirements.txt └── src ├── ann_index.py ├── ann_search.py ├── bot.py ├── completion.py ├── copytest.py ├── copytest_chat.py ├── db.py ├── download_hf_models_at_buildtime.py ├── exceptions.py ├── extract.py ├── github_api.py ├── gpt3.py ├── hn_summary_db.py ├── jigit.py ├── massgpt.py ├── models.py ├── pdf_text.py ├── querytest.py ├── results.txt ├── s3.py ├── s3_bucket ├── README.md ├── __init__.py ├── bucket.py └── exceptions.py ├── subprompt.py └── tokenizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | *.py~ 131 | *.txt~ 132 | *.hnsf 133 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04 2 | 3 | # set working directory 4 | WORKDIR /app 5 | 6 | ARG DEBIAN_FRONTEND=noninteractive 7 | 8 | RUN apt-get update && apt-get upgrade -y && apt-get install -y python3-pip ffmpeg git 9 | 10 | # update pip 11 | RUN pip3 install --upgrade pip 12 | 13 | RUN pip3 install git+https://github.com/openai/whisper.git 14 | 15 | # add requirements 16 | COPY ./requirements.txt /app/requirements.txt 17 | 18 | # install requirements 19 | RUN pip3 install -r requirements.txt 20 | 21 | COPY src/*.py . 22 | 23 | # force download of the hugging face models into the container 24 | RUN python3 download_hf_models_at_buildtime.py 25 | 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | .PHONY: help build push all 3 | 4 | help: 5 | @echo "Makefile commands:" 6 | @echo "build" 7 | @echo "push" 8 | @echo "all" 9 | 10 | .DEFAULT_GOAL := all 11 | 12 | build: 13 | docker build -t jiggyai/massgpt:${TAG} . 
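# Usage sketch: TAG has no default here, so it is assumed the caller supplies it,
# e.g. `TAG=v0.1 make all` to build and push jiggyai/massgpt:v0.1.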
14 | 15 | push: 16 | docker push jiggyai/massgpt:${TAG} 17 | 18 | all: build push 19 | -------------------------------------------------------------------------------- /MassGPT.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiggy-ai/mass-gpt/d6bc22912c1f88a1db5894df49ca35980c095bda/MassGPT.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | ![MassGPTBot](MassGPT.jpg) 3 |

4 | 5 | **MassGPT** is an experimental open source Telegram bot that interacts with users by responding to messages via a GPT-3.5 (text-davinci-003) completion that is conditioned by a shared sub-prompt context of all recent user messages. 6 | 7 | Message the MassGPT Telegram bot directly, and it will respond to you, taking into account the content of relevant messages it has received from others. Users are assigned a numeric user ID that does not leak any Telegram user info. The messages you send are not private, as they are accessible to other users via the bot or via the /context command. 8 | 9 | You can also send the bot a URL to inject a summary of the web page text into the current context. Use the /context command to see the entire current context. 10 | 11 | Currently there is one global chat context of recent messages from other users, but I plan to scale this by using an embedding of the user message to dynamically assemble a relevant context via ANN query of past messages, message summaries, and URL summaries. 12 | 13 | There are several motivations for creating this. One is to explore chatting with an LLM where the model is presenting some partial consensus of other users' current inputs instead of the model's background representations. Another is to explore dynamic prompt contexts, which should have a lot of interesting applications given that GPT seems much less likely to hallucinate when summarizing from its current prompt context. 14 | 15 | Open to PRs if you are interested in hacking on this in any way, or feel free to message me @wskish on Twitter or Telegram. 16 | 17 | The bot requires the following environment variables: 18 | 19 | **OpenAI** 20 | 21 | * OPENAI_API_KEY # your OpenAI API key 22 | 23 | 24 | **PostgreSQL** 25 | 26 | Database for persisting users, messages, completions, URL summaries, and embeddings.
27 | 28 | - MASSGPT_POSTGRES_HOST # The database FQDN 29 | - MASSGPT_POSTGRES_USER # The database username 30 | - MASSGPT_POSTGRES_PASS # The database password 31 | 32 | **Telegram** 33 | 34 | * MASSGPT_TELEGRAM_API_TOKEN # The bot's telegram API token 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | loguru==0.6.0 2 | pydantic==1.10.2 3 | sqlmodel==0.0.8 4 | psycopg2-binary==2.9.5 5 | requests==2.28.1 6 | python-telegram-bot==20.0b0 7 | BeautifulSoup4==4.11.1 8 | openai==0.25.0 9 | readability-lxml==0.8.1 10 | transformers==4.25.1 11 | markdown==3.4.1 12 | pdfminer.six==20221105 13 | sentence_transformers==2.2.2 14 | -------------------------------------------------------------------------------- /src/ann_index.py: -------------------------------------------------------------------------------- 1 | import hnswlib 2 | import psutil 3 | 4 | from s3 import bucket 5 | 6 | from loguru import logger 7 | from sqlmodel import Session, select 8 | from db import engine 9 | from models import * 10 | 11 | CPU_COUNT = psutil.cpu_count() 12 | ST_MODEL_NAME = 'multi-qa-mpnet-base-dot-v1' 13 | 14 | 15 | hnsw_index = hnswlib.Index(space='cosine', dim=768) 16 | hnsw_index.set_num_threads(int(CPU_COUNT/2)) 17 | 18 | hnsw_index.init_index(max_elements = 20000, 19 | ef_construction = 100, 20 | M = 16) 21 | 22 | count = 0 23 | with Session(engine) as session: 24 | for embedding in session.exec(select(Embedding).where(Embedding.collection == ST_MODEL_NAME)): 25 | if embedding.vector is None: 26 | session.delete(embedding) 27 | session.commit() 28 | continue 29 | hnsw_index.add_items([embedding.vector], [embedding.id]) 30 | count += 1 31 | print(embedding.id) 32 | 33 | filename = "index-%s-%d.hnsf" % (ST_MODEL_NAME, count) 34 | hnsw_index.save_index(filename) 35 | 36 | objkey = f"massgpt/ann-index/{ST_MODEL_NAME}-{count}.hnsw" 37 | bucket.upload_file(filename, objkey) 38 | 39 | with Session(engine) as session: 40 | ix = HnswIndex(collection = ST_MODEL_NAME, 41 | count = count, 42 | objkey = objkey) 43 | session.add(ix) 44 | session.commit() 45 | -------------------------------------------------------------------------------- /src/ann_search.py: -------------------------------------------------------------------------------- 1 | 2 | import hnswlib 3 | from db import engine 4 | from models import * 5 | from sentence_transformers import SentenceTransformer 6 | from sqlmodel import Session, select 7 | import psutil 8 | import hn_summary_db 9 | 10 | CPU_COUNT = psutil.cpu_count() 11 | 12 | ## Embedding Config 13 | ST_MODEL_NAME = 'multi-qa-mpnet-base-dot-v1' 14 | st_model = SentenceTransformer(ST_MODEL_NAME) 15 | 16 | hnsw_ix = hnswlib.Index(space='cosine', dim=768) 17 | hnsw_ix.load_index('index-multi-qa-mpnet-base-dot-v1-0.hnsf', max_elements=20000) 18 | hnsw_ix.set_ef(1000) 19 | 20 | STORY_SOURCES = [EmbeddingSource.hn_story_summary, EmbeddingSource.hn_story_title] 21 | 22 | 23 | def search(query : str): 24 | query = query.rstrip() 25 | print(query) 26 | vector = st_model.encode(query) 27 | 28 | ids, distances = hnsw_ix.knn_query([vector], k=10) 29 | ids = [int(i) for i in ids[0]] 30 | distances = [float(i) for i in distances[0]] 31 | with Session(engine) as session: 32 | ann_embeddings =
session.exec(select(Embedding).where(Embedding.id.in_(ids))).all() 33 | hn_story_ids = [e.source_id for e in ann_embeddings if e.source in STORY_SOURCES] 34 | stories = hn_summary_db.stories(hn_story_ids) 35 | 36 | vid_to_distance = {vid : distance for vid,distance in zip(ids, distances)} 37 | sid_to_vid = {emb.source_id : emb.id for emb in ann_embeddings} 38 | 39 | def distance(story): 40 | return vid_to_distance[sid_to_vid[story.id]] 41 | 42 | stories.sort(key=distance) 43 | 44 | for story in stories: 45 | print(f"{distance(story):.2f} {story.title}") 46 | 47 | return stories 48 | 49 | if __name__ == "__main__": 50 | search("SBF fraud") 51 | while True: 52 | stories = search(input("Search: ")) 53 | for s in stories: 54 | print() 55 | print(s.title) 56 | print(hn_summary_db.story_summary(s)) 57 | hn_summary_db.story_text(s) 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /src/bot.py: -------------------------------------------------------------------------------- 1 | """ 2 | MassGPT 3 | 4 | A destructive distillation of mass communication. 5 | 6 | Condition an LLM completion on a dynamically assembled subprompt context. 7 | 8 | Telegram: @MassGPTbot 9 | https://t.me/MassGPTbot 10 | 11 | Copyright (C) 2022 William S. Kish 12 | 13 | """ 14 | 15 | import os 16 | from sqlmodel import Session, select 17 | from loguru import logger 18 | from pydantic import BaseModel, Field 19 | from telegram import Update 20 | from telegram.ext import ApplicationBuilder, MessageHandler, CommandHandler, ContextTypes, filters 21 | import re 22 | 23 | import openai 24 | from db import engine 25 | from models import * 26 | from exceptions import * 27 | 28 | import massgpt 29 | 30 | 31 | # the bot app 32 | bot = ApplicationBuilder().token(os.environ['MASSGPT_TELEGRAM_API_TOKEN']).build() 33 | 34 | 35 | def extract_url(text: str): 36 | try: 37 | return re.search(r"(?P<url>https?://[^\s]+)", text).group("url") 38 | except AttributeError: 39 | return None 40 | 41 | 42 | 43 | def get_telegram_user(update : Update) -> User: 44 | """ 45 | return database User object of the sender of a telegram message 46 | Create the user if it didn't previously exist. 47 | """ 48 | with Session(engine) as session: 49 | user = session.exec(select(User).where(User.telegram_id == update.message.from_user.id)).first() 50 | if not user: 51 | tuser = update.message.from_user 52 | user = User(username = tuser.username, 53 | first_name = tuser.first_name, 54 | last_name = tuser.last_name, 55 | telegram_id = tuser.id, 56 | telegram_is_bot = tuser.is_bot, 57 | telegram_is_premium = tuser.is_premium, 58 | telegram_lanuage_code = tuser.language_code) 59 | session.add(user) 60 | session.commit() 61 | session.refresh(user) 62 | return user 63 | 64 | 65 | 66 | async def message(update: Update, tgram_context: ContextTypes.DEFAULT_TYPE) -> None: 67 | """ 68 | Handle message received from user. 69 | Send back to the user the response text from the model. 70 | Handle exceptions by sending an error message to the user.
71 | """ 72 | user = get_telegram_user(update) 73 | text = update.message.text 74 | logger.info(f'{user.id} {user.first_name} {user.last_name} {user.username} {user.telegram_id}: "{text}"') 75 | try: 76 | url = extract_url(text) 77 | print("URL", url) 78 | if url: 79 | response = massgpt.summarize_url(user, url) 80 | else: 81 | response = massgpt.receive_message(user, text) 82 | await update.message.reply_text(response) 83 | except (openai.error.ServiceUnavailableError, openai.error.RateLimitError): 84 | await update.message.reply_text("The OpenAI server is overloaded.") 85 | except ExtractException: 86 | await update.message.reply_text("Unable to extract text from url.") 87 | except MinimumTokenLimit: 88 | await update.message.reply_text("Message too short; please send a longer message.") 89 | except MaximumTokenLimit: 90 | await update.message.reply_text("Message too large; please send a shorter message.") 91 | except Exception as e: 92 | logger.exception("error processing message") 93 | await update.message.reply_text("An exceptional condition occurred.") 94 | 95 | 96 | bot.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, message)) 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | async def command(update: Update, tgram_context: ContextTypes.DEFAULT_TYPE) -> None: 108 | """ 109 | Handle command from user 110 | context - Respond with the current chat context 111 | url - Summarize a url and add the summary to the chat context 112 | """ 113 | user = get_telegram_user(update) 114 | text = update.message.text 115 | 116 | logger.info(f'{user.id} {user.first_name} {user.last_name} {user.username} {user.telegram_id}: "{text}"') 117 | 118 | 119 | if text == '/context': 120 | await update.message.reply_text("The current context:") 121 | for msg in massgpt.current_context(): 122 | await update.message.reply_text(msg) 123 | return 124 | elif text == '/prompts': 125 | await update.message.reply_text(massgpt.current_prompts()) 126 | return 127 | elif text[:5] == '/url ': 128 | try: 129 | url = extract_url(text) 130 | response = massgpt.summarize_url(user, url) 131 | await update.message.reply_text(response) 132 | except (openai.error.ServiceUnavailableError, openai.error.RateLimitError): 133 | await update.message.reply_text("The OpenAI server is overloaded.") 134 | except ExtractException: 135 | await update.message.reply_text("Unable to extract text from url.") 136 | except Exception as e: 137 | logger.exception("error processing message") 138 | await update.message.reply_text("An exceptional condition occurred.") 139 | return 140 | await update.message.reply_text("Send me a message and I will assemble a collection of recent or related messages into a GPT prompt context and prompt your message against that dynamic context, sending you the GPT response. Send /context to see the current prompt context. Send '/url <url>' to add a summary of the url to the context. Consider this to be a public chat and please maintain a kind and curious standard.") 141 | 142 | 143 | bot.add_handler(MessageHandler(filters.COMMAND, command)) 144 | 145 | 146 | logger.info("run_polling") 147 | bot.run_polling() 148 |
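A quick sanity check of the extract_url() helper above, as a doctest-style sketch (example inputs invented):

    >>> extract_url("please summarize https://example.com/post?id=1 for me")
    'https://example.com/post?id=1'
    >>> extract_url("no link in this message") is None
    True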
-------------------------------------------------------------------------------- /src/completion.py: -------------------------------------------------------------------------------- 1 | # Completion Abstraction 2 | # Copyright (C) 2022 William S. Kish 3 | 4 | 5 | import os 6 | from pydantic import BaseModel, Field 7 | from sqlmodel import Session, select 8 | 9 | from models import Completion 10 | from subprompt import SubPrompt 11 | 12 | from exceptions import * 13 | 14 | 15 | class CompletionLimits(BaseModel): 16 | """ 17 | specification for limits for: 18 | min prompt tokens 19 | min completion tokens 20 | max completion tokens 21 | given max_context tokens 22 | """ 23 | max_context : int 24 | min_prompt : int 25 | min_completion : int 26 | max_completion : int 27 | 28 | def max_completion_tokens(self, prompt : SubPrompt) -> int: 29 | """ 30 | returns the maximum completion tokens available given the max_context limit 31 | and the actual number of tokens in the prompt 32 | raises MinimumTokenLimit or MaximumTokenLimit exceptions if the prompt 33 | is too small or too big. 34 | """ 35 | if prompt.tokens < self.min_prompt: 36 | raise MinimumTokenLimit 37 | if prompt.tokens > self.max_prompt_tokens(): 38 | raise MaximumTokenLimit 39 | max_available_tokens = self.max_context - prompt.tokens 40 | if max_available_tokens > self.max_completion: 41 | return self.max_completion 42 | return max_available_tokens 43 | 44 | def max_prompt_tokens(self) -> int: 45 | """ 46 | return the maximum prompt size in tokens 47 | """ 48 | return self.max_context - self.min_completion 49 | 50 | 51 | 52 | class CompletionTask: 53 | """ 54 | An LLM completion task that shares a particular llm configuration and prompt/completion limit structure. 55 | This is a base class for a model-specific completion task. A model-api-specific implementation must 56 | at a minimum implement the _completion() method. 57 | See gpt3.GPT3CompletionTask for an example implementation. 58 | """ 59 | def __init__(self, 60 | limits : CompletionLimits, 61 | model : str) -> "CompletionTask" : 62 | 63 | self.limits = limits 64 | self.model = model 65 | 66 | def max_prompt_tokens(self) -> int: 67 | return self.limits.max_prompt_tokens() 68 | 69 | def get_limits(self) -> CompletionLimits: 70 | return self.limits 71 | 72 | def _completion(self, 73 | prompt : str, 74 | max_completion_tokens : int) -> str : 75 | """ 76 | perform the actual completion, returning the completion text string 77 | This should be implemented in a model-api-specific subclass. 78 | """ 79 | pass 80 | 81 | def completion(self, prompt : SubPrompt) -> Completion: 82 | """ 83 | prompt the model with the specified prompt and return the resulting Completion 84 | """ 85 | # check prompt limits and return max completion size to request 86 | # given the size of the prompt and the configured limits 87 | max_completion = self.limits.max_completion_tokens(prompt) 88 | 89 | # perform the completion inference 90 | response = self._completion(prompt = str(prompt), 91 | max_completion_tokens = max_completion) 92 | 93 | completion = Completion(model = self.model, 94 | prompt = str(prompt), 95 | temperature = 0, # XXX set this as model params? 96 | completion = response) 97 | return completion 98 |
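To make the token budgeting in CompletionLimits concrete, a minimal worked sketch (the prompt sizes in the comments are hypothetical; the limit values match the GPT-3 defaults used elsewhere in this repo):

    from completion import CompletionLimits

    limits = CompletionLimits(max_context = 4097,
                              min_prompt = 40,
                              min_completion = 300,
                              max_completion = 600)

    # A prompt may use at most max_context - min_completion tokens:
    assert limits.max_prompt_tokens() == 4097 - 300   # 3797

    # For a hypothetical 3600-token prompt, 4097 - 3600 = 497 tokens remain,
    # which is under max_completion, so 497 completion tokens are granted.
    # For a 500-token prompt, 3597 remain, capped to max_completion = 600.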
-------------------------------------------------------------------------------- /src/copytest.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | from gpt3 import GPT3CompletionTask, CompletionLimits 3 | from exceptions import * 4 | from models import Completion 5 | 6 | from subprompt import SubPrompt 7 | 8 | class MessageSubPrompt(SubPrompt): 9 | """ 10 | SubPrompt Context for a user-generated message 11 | """ 12 | MAX_TOKENS = 300 13 | @classmethod 14 | def from_user_str(cls, username : str, msg: str) -> "SubPrompt": 15 | # create user message specific subprompt 16 | text = f"'user-{username}' wrote to MassGPT: {msg}" 17 | return MessageSubPrompt(text=text, max_tokens=MessageSubPrompt.MAX_TOKENS) 18 | 19 | 20 | PREPROMPT = SubPrompt( \ 21 | """You are MassGPT, and this is a fun experiment. \ 22 | You were built by Jiggy AI using OpenAI text-davinci-003. \ 23 | Instruction: Different users are sending you messages. \ 24 | They can not communicate with each other directly. \ 25 | Any user to user message must be relayed through you. \ 26 | Pass along any interesting message. \ 27 | Try not to repeat yourself. \ 28 | Ask users questions if you are not sure what to say. \ 29 | If a user expresses interest in a topic discussed here, \ 30 | respond to them based on what you read here. \ 31 | Users have recently said the following to you:""") 32 | 33 | 34 | class GPTCopyTestTask(GPT3CompletionTask): 35 | """ 36 | Generates message response completions based on a dynamic history of recent messages and the most recent user message 37 | """ 38 | TEMPERATURE = 0.0 39 | 40 | # General Prompt Strategy: 41 | # Upon reception of message from a user 999, compose the following prompt 42 | # based on recent messages received from other users: 43 | 44 | #PREPROMPT = SubPrompt("Prepare to copy some of the following messages that users sent to MassGPT:") 45 | 46 | 47 | # "User 123 wrote: ABC is the greatest thing ever" 48 | # "User 234 wrote: ABC is cool but i like GGG more" 49 | # "User 345 wrote: DDD is the best ever!" 50 | # etc 51 | #PENULTIMATE_PROMPT = SubPrompt("Instruction: Respond to the following user message considering the above context and Instruction:") 52 | # "User 999 wrote: What do folks think about ABC?"
# End of Prompt 53 | #FINAL_PROMPT = SubPrompt("MassGPT responded:") 54 | # Then send resulting llm completion back to user 999 in response to his message 55 | 56 | def __init__(self) -> "GPTCopyTestTask": 57 | limits = CompletionLimits(min_prompt = 0, 58 | min_completion = 300, 59 | max_completion = 400) 60 | 61 | super().__init__(limits = limits, 62 | temperature = GPTCopyTestTask.TEMPERATURE, 63 | stop = ['###'], 64 | model = 'text-davinci-003') 65 | 66 | 67 | def completion(self, 68 | recent_msgs : MessageSubPrompt, 69 | user_msg : MessageSubPrompt) -> Completion: 70 | """ 71 | return completion for the provided subprompts 72 | """ 73 | #prompt = GPTCopyTestTask.PREPROMPT 74 | #final_prompt = GPTCopyTestTask.PENULTIMATE_PROMPT 75 | #final_prompt += user_msg 76 | #final_prompt += GPTCopyTestTask.FINAL_PROMPT 77 | 78 | #logger.info(f"overhead tokens: {(prompt + final_prompt).tokens}") 79 | 80 | #available_tokens = self.max_prompt_tokens() - (prompt + final_prompt).tokens 81 | 82 | #logger.info(f"available_tokens: {available_tokens}") 83 | # assemble list of most recent_messages up to available token limit 84 | prompt = SubPrompt("") 85 | for sub in recent_msgs: 86 | prompt += sub 87 | prompt += user_msg 88 | #prompt += final_prompt 89 | # add most recent user message after penultimate prompt 90 | logger.info(f"final prompt tokens: {prompt.tokens} max {self.max_prompt_tokens()}") 91 | 92 | logger.info(f"final prompt token_count: {prompt.tokens} chars: {len(prompt.text)}") 93 | 94 | return super().completion(prompt) 95 | 96 | 97 | msg_response_task = GPTCopyTestTask() 98 | 99 | 100 | #users = ['cat', 'dog', 'pig', 'cow', 'rocket', 'ocean', 'mountain', 'tree'] 101 | #users = ['cat', 'rocket'] 102 | users = list(range(10)) 103 | 104 | from random import choice 105 | 106 | 107 | NUM_MESSAGES = 220 108 | #NUM_MESSAGES = 150 109 | 110 | data = [] 111 | messages = [] 112 | for i in range(NUM_MESSAGES): 113 | user = choice(users) 114 | msg = f"This is message {i}" 115 | data.append((user, msg)) 116 | msp = MessageSubPrompt.from_user_str(user, msg) 117 | messages.append(msp) 118 | print(msp) 119 | 120 | print("=======================") 121 | target_u = choice(users) 122 | 123 | 124 | final_p = MessageSubPrompt.from_user_str(choice(users), 125 | f"Instruction: Copy all of the messages that 'user-{target_u}' wrote to MassGPT, only one message per line:") 126 | 127 | final_p = SubPrompt(f"Instruction: Copy all of the messages that 'user-{target_u}' wrote to MassGPT, only one message per line:") 128 | 129 | 130 | comp = msg_response_task.completion(messages, final_p) 131 | 132 | print(comp.prompt) 133 | print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") 134 | print(str(comp)) 135 | 136 | 137 | answers = [msg for user,msg in data if user==target_u] 138 | 139 | correct = 0 140 | results = comp.completion.strip().split('\n') 141 | 142 | answers = set([a.strip().lower() for a in answers]) 143 | comps = set([c.strip().lower() for c in results]) 144 | 145 | #print(answers) 146 | #print(comps) 147 | correct = answers.intersection(comps) 148 | 149 | 150 | 151 | """ 152 | for c , a in zip(results.split("\n"), answer): 153 | a = a.rstrip().lower() 154 | c = c.rstrip().lower() 155 | print(f"{a} | {c}") 156 | if a in c: 157 | correct += 1 158 | """ 159 | 160 | precision = len(correct)/len(comps) 161 | recall = len(correct)/len(answers) 162 | print(f"precision {precision} \t recall {recall}") 163 |
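The scoring at the end of copytest.py (and of copytest_chat.py below) reduces to set intersection over normalized lines; a self-contained sketch with invented data:

    expected = {"this is message 3", "this is message 7"}    # messages the target user actually sent
    returned = {"this is message 3", "this is message 9"}    # normalized lines the model echoed back

    correct = expected & returned              # true positives
    precision = len(correct) / len(returned)   # 0.5: half of what the model copied was right
    recall    = len(correct) / len(expected)   # 0.5: half of the target messages were recovered
    print(f"precision {precision} \t recall {recall}")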
-------------------------------------------------------------------------------- /src/copytest_chat.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import openai 3 | from time import time 4 | from retry import retry 5 | 6 | users = list(range(10)) 7 | 8 | from random import choice 9 | 10 | #NUM_MESSAGES = 220 11 | NUM_MESSAGES = 480 12 | 13 | @retry(tries=5) 14 | def test(model): 15 | data = [] 16 | messages = [] 17 | target_u = choice(users) 18 | prompt = f"Instruction: Copy all of the messages that 'user-{target_u}' wrote to MassGPT, only one message per line:\n" 19 | for i in range(NUM_MESSAGES): 20 | user = choice(users) 21 | msg = f"This is message {i}" 22 | data.append((user, msg)) 23 | text = f"'user-{user}' wrote to MassGPT: {msg}\n" 24 | prompt += text 25 | 26 | prompt += f"Instruction: Copy all of the messages that 'user-{target_u}' wrote to MassGPT, only one message per line:" 27 | print(prompt) 28 | print("=======================") 29 | 30 | messages = [{"role": "user", "content": prompt}] 31 | 32 | 33 | t0 = time() 34 | 35 | if model in ['gpt-3.5-turbo', 'gpt-4']: 36 | response = openai.ChatCompletion.create(model=model, 37 | messages=messages, 38 | temperature=0) 39 | print(response) 40 | response_text = response['choices'][0]['message']['content'] 41 | else: 42 | 43 | response = openai.Completion.create(engine=model, 44 | temperature=0, 45 | max_tokens=500, 46 | prompt=prompt) 47 | response_text = response.choices[0].text 48 | dt = time() - t0 49 | print("dt", dt) 50 | 51 | print(response_text) 52 | print("=======================") 53 | 54 | answers = [msg for user,msg in data if user==target_u] 55 | 56 | correct = 0 57 | results = response_text.strip().split('\n') 58 | print("RESULTS:") 59 | print(results) 60 | 61 | answers = set([a.strip().lower() for a in answers]) 62 | comps = set([c.split(':')[-1].strip().lower() for c in results]) 63 | 64 | print("ANSWERS") 65 | print(answers) 66 | print('comps') 67 | print(comps) 68 | correct = answers.intersection(comps) 69 | 70 | 71 | precision = len(correct)/len(comps) 72 | recall = len(correct)/len(answers) 73 | print(f"precision {precision} \t recall {recall}") 74 | 75 | return precision, recall, dt 76 | 77 | 78 | precision = {} 79 | recall = {} 80 | latency = {} 81 | #MODELS = ['gpt-3.5-turbo', 'text-davinci-003'] #, 'text-davinci-002'] 82 | 83 | #MODELS = ['gpt-3.5-turbo', 'gpt-4'] #, 'text-davinci-002'] 84 | 85 | MODELS = ['gpt-4'] #, 'text-davinci-002'] 86 | 87 | for model in MODELS: 88 | precision[model] = [] 89 | recall[model] = [] 90 | latency[model] = [] 91 | 92 | for model in MODELS: 93 | for i in range(10): 94 | p, r, dt = test(model) 95 | precision[model].append(p) 96 | recall[model].append(r) 97 | latency[model].append(dt) 98 | for model in MODELS: 99 | print(f"{model:15} precision {100*sum(precision[model])/len(precision[model]):.1f}% \trecall {100*sum(recall[model])/len(recall[model]):.1f}%\tlatency {sum(latency[model])/len(latency[model]):.1f} s") 100 | 101 | -------------------------------------------------------------------------------- /src/db.py: -------------------------------------------------------------------------------- 1 | # database engine 2 | import os 3 | from sqlmodel import create_engine, SQLModel 4 | 5 | 6 | 7 | # DB Config 8 | db_host = os.environ['MASSGPT_POSTGRES_HOST'] 9 | user = os.environ['MASSGPT_POSTGRES_USER'] 10 | passwd = os.environ['MASSGPT_POSTGRES_PASS'] 11 | 12 | DBURI = 'postgresql+psycopg2://%s:%s@%s:5432/massgpt' % (user, passwd, db_host) 13 | 14 | engine = create_engine(DBURI,
pool_pre_ping=True, echo=False) 15 | 16 | if __name__ == "__main__": 17 | from models import * 18 | SQLModel.metadata.create_all(engine) 19 | print("create_all complete") 20 | -------------------------------------------------------------------------------- /src/download_hf_models_at_buildtime.py: -------------------------------------------------------------------------------- 1 | # used at docker build time to pull model cache into container 2 | 3 | from transformers import GPT2Tokenizer 4 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 5 | 6 | from sentence_transformers import SentenceTransformer 7 | st_model_name = 'multi-qa-mpnet-base-dot-v1' 8 | st_model = SentenceTransformer(st_model_name) 9 | 10 | #import whisper 11 | #whisper_model = whisper.load_model("large") 12 | 13 | print("HF done") 14 | -------------------------------------------------------------------------------- /src/exceptions.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class MinimumTokenLimit(Exception): 4 | """ 5 | The prompt is below the specified minimum token count 6 | """ 7 | 8 | class MaximumTokenLimit(Exception): 9 | """ 10 | The prompt exceeds the specified maximum token count 11 | """ 12 | 13 | 14 | 15 | class ExtractException(Exception): 16 | """ 17 | various extraction errors 18 | """ 19 | 20 | class UnsupportedHostException(ExtractException): 21 | """ 22 | The URL is for a host we know we can't access reliably 23 | """ 24 | 25 | class UnsupportedContentType(ExtractException): 26 | """ 27 | The http content type is unsupported. 28 | """ 29 | 30 | class EmptyText(ExtractException): 31 | """ 32 | Unable to extract any readable text from the URL. 33 | """ 34 | 35 | 36 | class NetworkError(ExtractException): 37 | """ 38 | Unable to access the content.
39 | """ 40 | 41 | 42 | -------------------------------------------------------------------------------- /src/extract.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extract readable text from a URL via various hacks 3 | """ 4 | 5 | from loguru import logger 6 | from bs4 import BeautifulSoup, NavigableString, Tag 7 | from readability import Document # https://github.com/buriy/python-readability 8 | import requests 9 | import urllib.parse 10 | 11 | from github_api import github_readme_text 12 | from pdf_text import pdf_text 13 | 14 | from exceptions import * 15 | 16 | 17 | def extract_text_from_html(content): 18 | soup = BeautifulSoup(content, 'html.parser') 19 | 20 | output = "" 21 | title = soup.find('title') 22 | if title: 23 | output += "Title: " + title.get_text() 24 | 25 | blacklist = ['[document]','noscript','header','html','meta','head','input','script', "style"] 26 | # there may be more elements we don't want 27 | 28 | for t in soup.find_all(text=True): 29 | if t.parent.name not in blacklist: 30 | output += '{} '.format(t) 31 | return output 32 | 33 | 34 | def get_url_text(url): 35 | """ 36 | get url content and extract readable text 37 | returns the text 38 | """ 39 | resp = requests.get(url, timeout=30) 40 | 41 | if resp.status_code != 200: 42 | logger.warning(url) 43 | raise NetworkError(f"Unable to get URL ({resp.status_code})") 44 | 45 | CONTENT_TYPE = resp.headers['Content-Type'] 46 | 47 | if 'pdf' in CONTENT_TYPE: 48 | return pdf_text(resp.content) 49 | 50 | if "html" not in CONTENT_TYPE: 51 | logger.warning(url) 52 | raise UnsupportedContentType(f"Unsupported content type: {resp.headers['Content-Type']}") 53 | 54 | doc = Document(resp.text) 55 | text = extract_text_from_html(doc.summary()) 56 | 57 | if not len(text) or text.isspace(): 58 | logger.warning(url) 59 | raise EmptyText("Unable to extract text data from url") 60 | return text 61 | 62 | 63 | 64 | def url_to_text(url): 65 | HOPELESS = ["youtube.com", 66 | "www.youtube.com"] 67 | if urllib.parse.urlparse(url).netloc in HOPELESS: 68 | logger.warning(url) 69 | raise UnsupportedHostException(f"Unsupported host: {urllib.parse.urlparse(url).netloc}") 70 | 71 | if urllib.parse.urlparse(url).netloc == 'github.com': 72 | # for github repos use api to attempt to find a readme file 73 | text = github_readme_text(url) 74 | else: 75 | text = get_url_text(url) 76 | 77 | logger.info("url_to_text: "+text) 78 | return text 79 | -------------------------------------------------------------------------------- /src/github_api.py: -------------------------------------------------------------------------------- 1 | # wrangle text out of github repo readme via api 2 | # only works for repo readme, not other github pages like issues, discussions, etc 3 | 4 | from loguru import logger 5 | import requests 6 | import markdown 7 | from bs4 import BeautifulSoup 8 | 9 | 10 | def md_to_text(md): 11 | html = markdown.markdown(md) 12 | soup = BeautifulSoup(html, features='html.parser') 13 | return soup.get_text() 14 | 15 | 16 | def github_readme_text(github_repo_url): 17 | # split a github url into the owner and repo components 18 | # use the github api to try to find the readme text 19 | # ['https:', '', 'github.com', 'jiggy-ai', 'hn_summary'] 20 | spliturl = github_repo_url.rstrip('/').split('/') 21 | if len(spliturl) != 5: 22 | logger.warning(github_repo_url) 23 | raise Exception(f"Unable to process github url {github_repo_url}") 24 | owner = spliturl[3] 25 | repo = spliturl[4] 26 | contenturl =
f'https://api.github.com/repos/{owner}/{repo}/readme' 27 | resp = requests.get(contenturl) 28 | if resp.status_code != 200: 29 | logger.warning(f"{github_repo_url} {resp.content}") 30 | raise Exception(f"Unable to get readme for {github_repo_url}") 31 | item = resp.json() 32 | md = requests.get(item['download_url']).text 33 | return md_to_text(md) 34 | 35 | -------------------------------------------------------------------------------- /src/gpt3.py: -------------------------------------------------------------------------------- 1 | # GPT3 specific Completions 2 | # Copyright (C) 2022 William S. Kish 3 | 4 | import os 5 | from loguru import logger 6 | from time import sleep 7 | 8 | import completion 9 | import openai 10 | 11 | openai.api_key = os.environ["OPENAI_API_KEY"] 12 | 13 | OPENAI_COMPLETION_MODELS = ["text-davinci-003", "text-davinci-002", "text-davinci-001"] 14 | 15 | 16 | def CompletionLimits(min_prompt:int, 17 | min_completion:int, 18 | max_completion:int, 19 | max_context: int = 4097) -> completion.CompletionLimits: 20 | """ 21 | CompletionLimits specific to GPT3 models with max_context of 4097 tokens 22 | """ 23 | assert(max_context <= 4097) 24 | return completion.CompletionLimits(max_context = max_context, 25 | min_prompt = min_prompt, 26 | min_completion = min_completion, 27 | max_completion = max_completion) 28 | 29 | 30 | 31 | RETRY_COUNT = 10 32 | 33 | class GPT3CompletionTask(completion.CompletionTask): 34 | """ 35 | An OpenAI GPT-3-class completion task implemented using the OpenAI API 36 | """ 37 | 38 | 39 | def __init__(self, 40 | limits : completion.CompletionLimits, 41 | temperature : float = 1, 42 | top_p : float = 1, 43 | stop : list[str] = None, 44 | model : str = 'text-davinci-003') -> "GPT3CompletionTask": 45 | 46 | assert(model in OPENAI_COMPLETION_MODELS) 47 | 48 | if model in ["text-davinci-003", "text-davinci-002"]: 49 | assert(limits.max_context <= 4097) 50 | else: 51 | assert(limits.max_context <= 2048) 52 | 53 | self.stop = stop 54 | self.top_p = top_p 55 | self.temperature = temperature 56 | 57 | super().__init__(limits = limits, 58 | model = model) 59 | 60 | 61 | def _completion(self, 62 | prompt : str, 63 | max_completion_tokens : int) -> str : 64 | """ 65 | perform the actual completion via openai api 66 | returns the completion text string 67 | """ 68 | def completion(): 69 | resp = openai.Completion.create(engine = self.model, 70 | prompt = prompt, 71 | temperature = self.temperature, 72 | top_p = self.top_p, 73 | stop = self.stop, 74 | max_tokens = max_completion_tokens) 75 | return resp.choices[0].text 76 | 77 | for i in range(RETRY_COUNT): 78 | try: 79 | return completion() 80 | except (openai.error.RateLimitError, openai.error.ServiceUnavailableError): 81 | logger.warning("openai error") 82 | if i == RETRY_COUNT-1: 83 | raise 84 | sleep(i**1.3) 85 | except Exception as e: 86 | logger.exception("_completion") 87 | raise 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /src/hn_summary_db.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Optional 3 | from pydantic import condecimal 4 | from time import time 5 | 6 | import os 7 | from sqlmodel import create_engine, SQLModel, Field, Session, select 8 | 9 | 10 | 11 | # DB Config 12 | db_host = os.environ['HNSUM_POSTGRES_HOST'] 13 | user = os.environ['HNSUM_POSTGRES_USER'] 14 | passwd = os.environ['HNSUM_POSTGRES_PASS'] 15 | 16 | DBURI = 'postgresql+psycopg2://%s:%s@%s:5432/hnsum' % (user,
passwd, db_host) 17 | 18 | # Create DB Engine 19 | engine = create_engine(DBURI, pool_pre_ping=True, echo=False) 20 | 21 | 22 | timestamp = condecimal(max_digits=14, decimal_places=3) 23 | 24 | 25 | class HackerNewsStory(SQLModel, table=True): 26 | # partial state of HN Story item. See https://github.com/HackerNews/API#items 27 | # we don't include "type" since we are only recording type='story' here. 28 | id: int = Field(primary_key=True, description="The item's unique id.") 29 | by: str = Field(index=True, description="The username of the item's author.") 30 | time: int = Field(index=True, description="Creation date of the item, in Unix Time.") 31 | title: str = Field(description="The title of the story, poll or job. HTML.") 32 | text: Optional[str] = Field(description="The comment, story or poll text. HTML.") 33 | url: Optional[str] = Field(description="The url associated with the Item.") 34 | 35 | class StoryText(SQLModel, table=True): 36 | id: int = Field(primary_key=True, description="The summary unique id.") 37 | story_id: int = Field(index=True, description="The story id this text is associated with.") 38 | mechanism: str = Field(description="identifies which software mechanism extracted the text from the url") 39 | crawl_time: timestamp = Field(default_factory=time, description='The epoch timestamp when the url was crawled.') 40 | html: Optional[str] = Field(description="original html content") 41 | text: str = Field(max_length=65535, description="The readable text we managed to extract from the Story Url.") 42 | 43 | class StorySummary(SQLModel, table=True): 44 | id: int = Field(primary_key=True, description="The summary unique id.") 45 | story_id: int = Field(index=True, description="The story id.") 46 | model: str = Field(description="The model used to summarize a story") 47 | prompt: str = Field(max_length=65535, description="The prompt used to create the summary.") 48 | summary: str = Field(max_length=65535, description="The summary we got back from the model.") 49 | upvotes: Optional[int] = Field(default=0, description="The number of upvotes for this summary.") 50 | votes: Optional[int] = Field(default=0, description="The total number of votes for this summary.") 51 | 52 | 53 | 54 | 55 | def stories(story_ids : list[int]) -> list[HackerNewsStory]: 56 | with Session(engine) as session: 57 | return session.exec(select(HackerNewsStory).where(HackerNewsStory.id.in_(story_ids))).all() 58 | 59 | 60 | 61 | def story_text(story : HackerNewsStory) -> str: 62 | with Session(engine) as session: 63 | storytext = session.exec(select(StoryText).where(StoryText.story_id == story.id)).first() 64 | if not storytext: 65 | return "" 66 | return storytext.text 67 | 68 | def story_summary(story : HackerNewsStory) -> str: 69 | with Session(engine) as session: 70 | storysummary = session.exec(select(StorySummary).where(StorySummary.story_id == story.id)).first() 71 | if not storysummary: 72 | return "" 73 | return storysummary.summary 74 | 75 | -------------------------------------------------------------------------------- /src/jigit.py: -------------------------------------------------------------------------------- 1 | # Beginning of a framework for Jigits, which are widgets that allow LLMs to interact with code 2 | # Copyright (C) 2022 William S.
Kish 3 | 4 | from pydantic import BaseModel 5 | from subprompt import SubPrompt 6 | 7 | 8 | class Jigit(BaseModel): 9 | 10 | id: str # unique ID for a particular Jigit 11 | prompt: SubPrompt # The subprompt to include to activate this jigit 12 | 13 | 14 | def handler(self, completion: str): 15 | pass 16 | 17 | 18 | def process_completion(self, completion: str): 19 | if self.id in completion: 20 | self.handler(completion) 21 | -------------------------------------------------------------------------------- /src/massgpt.py: -------------------------------------------------------------------------------- 1 | # MassGPT app 2 | # 3 | # Copyright (C) 2022 William S. Kish 4 | from loguru import logger 5 | from sqlmodel import Session, select, delete 6 | from sentence_transformers import SentenceTransformer 7 | import urllib.parse 8 | 9 | from db import engine 10 | 11 | import gpt3 12 | from time import time 13 | from exceptions import * 14 | from models import * 15 | 16 | from extract import url_to_text 17 | from subprompt import SubPrompt 18 | 19 | 20 | ### 21 | ### Various Specialized SubPrompts 22 | ### 23 | 24 | class MessageSubPrompt(SubPrompt): 25 | """ 26 | SubPrompt Context for a user-generated message 27 | """ 28 | MAX_TOKENS = 300 29 | 30 | @classmethod 31 | def from_msg(cls, msg: Message) -> "SubPrompt": 32 | # create user message specific subprompt 33 | text = f"user-{msg.user_id} wrote to MassGPT: {msg.text}" 34 | return MessageSubPrompt(text=text, max_tokens=MessageSubPrompt.MAX_TOKENS) 35 | 36 | 37 | class MessageResponseSubPrompt(SubPrompt): 38 | """ 39 | SubPrompt Context for a user message and its model response 40 | """ 41 | @classmethod 42 | def from_msg_completion(cls, msp: MessageSubPrompt, comp: Completion) -> "SubPrompt": 43 | return msp # + f"MassGPT responded: {comp.completion}" # leave the model response out for experimentation 44 | 45 | 46 | 47 | class UrlSummarySubPrompt(SubPrompt): 48 | """ 49 | SubPrompt Context for a user-requested URL Summary 50 | """ 51 | @classmethod 52 | def from_summary(cls, user: User, text : str) -> "SubPrompt": 53 | text = f"User {user.id} posted a link with the following summary:\n{text}\n" 54 | # don't need to specify max_tokens here since the summary is a model output 55 | # that is regulated through the user_summary_limits 56 | return UrlSummarySubPrompt(text=text) 57 | 58 | 59 | class SummarySubPrompt(SubPrompt): 60 | """ 61 | SubPrompt Context for a system-generated summary 62 | """ 63 | @classmethod 64 | def from_msg(cls, text : str) -> "SubPrompt": 65 | text = f"Here is a summary of previous discussions for reference: {text}" 66 | # don't need to specify max_tokens here since the summary is a model output 67 | # that is regulated through the msg_summary_limits 68 | return SummarySubPrompt(text=text) 69 | 70 | 71 | 72 | 73 | 74 | class MassGPTMessageTask(gpt3.GPT3CompletionTask): 75 | """ 76 | Generates message response completions based on a dynamic history of recent messages and the most recent user message 77 | """ 78 | TEMPERATURE = 0.4 79 | 80 | # General Prompt Strategy: 81 | # Upon reception of message from a user 999, compose the following prompt 82 | # based on recent messages received from other users: 83 | 84 | PREPROMPT = SubPrompt( \ 85 | """You are MassGPT, and this is a fun experiment. \ 86 | You were built by Jiggy AI using OpenAI text-davinci-003. \ 87 | Instruction: Different users are sending you messages. \ 88 | They can not communicate with each other directly. \ 89 | Any user to user message must be relayed through you. \ 90 | Pass along any interesting message. \ 91 | Try not to repeat yourself.
\ 92 | Ask users questions if you are not sure what to say. \ 93 | If a user expresses interest in a topic discussed here, \ 94 | respond to them based on what you read here. \ 95 | Users have recently said the following to you:""") 96 | 97 | 98 | # "User 123 wrote: ABC is the greatest thing ever" 99 | # "User 234 wrote: ABC is cool but i like GGG more" 100 | # "User 345 wrote: DDD is the best ever!" 101 | # etc 102 | PENULTIMATE_PROMPT = SubPrompt("Instruction: Respond to the following user message considering the above context and Instruction:") 103 | # "User 999 wrote: What do folks think about ABC?" # End of Prompt 104 | FINAL_PROMPT = SubPrompt("MassGPT responded:") 105 | # Then send resulting llm completion back to user 999 in response to his message 106 | 107 | def __init__(self) -> "MassGPTMessageTask": 108 | limits = gpt3.CompletionLimits(min_prompt = 0, 109 | min_completion = 300, 110 | max_completion = 400) 111 | 112 | super().__init__(limits = limits, 113 | temperature = MassGPTMessageTask.TEMPERATURE, 114 | model = 'text-davinci-003') 115 | 116 | 117 | def completion(self, 118 | recent_msgs : MessageResponseSubPrompt, 119 | user_msg : MessageSubPrompt) -> Completion: 120 | """ 121 | return completion for the provided subprompts 122 | """ 123 | prompt = MassGPTMessageTask.PREPROMPT 124 | final_prompt = MassGPTMessageTask.PENULTIMATE_PROMPT 125 | final_prompt += user_msg 126 | final_prompt += MassGPTMessageTask.FINAL_PROMPT 127 | 128 | logger.info(f"overhead tokens: {(prompt + final_prompt).tokens}") 129 | 130 | available_tokens = self.max_prompt_tokens() - (prompt + final_prompt).tokens 131 | logger.info(f"available_tokens: {available_tokens}") 132 | # assemble list of most recent_messages up to available token limit 133 | reversed_subs = [] 134 | # add previous message context 135 | for sub in reversed(recent_msgs): 136 | if available_tokens - sub.tokens -1 < 0: 137 | break 138 | reversed_subs.append(sub) 139 | available_tokens -= sub.tokens + 1 140 | 141 | for sub in reversed(reversed_subs): 142 | prompt += sub 143 | 144 | prompt += final_prompt 145 | # add most recent user message after penultimate prompt 146 | logger.info(f"final prompt tokens: {prompt.tokens} max{self.max_prompt_tokens()}") 147 | 148 | logger.info(f"final prompt token_count: {prompt.tokens} chars: {len(prompt.text)}") 149 | 150 | return super().completion(prompt) 151 | 152 | 153 | 154 | 155 | msg_response_task = MassGPTMessageTask() 156 | 157 | 158 | 159 | 160 | class UrlSummaryTask(gpt3.GPT3CompletionTask): 161 | """ 162 | A Factory class to dynamically compose a prompt context for URL Summary task 163 | """ 164 | 165 | # the PROMPT_PREFIX is prepended to the url content before sending to the language model 166 | SUMMARIZE_PROMPT_PREFIX = SubPrompt("Provide a detailed summary of the following web content, including what type of content it is (e.g. news article, essay, technical report, blog post, product documentation, content marketing, etc). If the content looks like an error message, respond 'content unavailable'. If there is anything controversial please highlight the controversy. 
If there is something surprising, unique, or clever, please highlight that as well:") 167 | 168 | # prompt prefix for Github Readme files 169 | GITHUB_PROMPT_PREFIX = SubPrompt("Provide a summary of the following github project readme file, including the purpose of the project, what problems it may be used to solve, and anything the author mentions that differentiates this project from others:") 170 | 171 | TEMPERATURE = 0.2 172 | 173 | def __init__(self) -> "UrlSummaryTask": 174 | limits = gpt3.CompletionLimits(min_prompt = 40, 175 | min_completion = 300, 176 | max_completion = 600) 177 | 178 | super().__init__(limits = limits, 179 | temperature = UrlSummaryTask.TEMPERATURE, 180 | model = 'text-davinci-003') 181 | 182 | def prefix(self, url): 183 | if urllib.parse.urlparse(url).netloc == 'github.com': 184 | return UrlSummaryTask.GITHUB_PROMPT_PREFIX 185 | else: 186 | return UrlSummaryTask.SUMMARIZE_PROMPT_PREFIX 187 | 188 | 189 | def completion(self, url: str, url_text : str) -> Completion: 190 | """ 191 | Return the summary completion for the given url and url_text. 192 | The url is required to enable a host-specific prompt strategy. 193 | For example, a different prompt is used to summarize github repos versus other web sites. 194 | """ 195 | prompt = self.prefix(url) + url_text 196 | prompt.truncate(self.max_prompt_tokens()) 197 | return super().completion(prompt) 198 | 199 | 200 | 201 | 202 | url_summary_task = UrlSummaryTask() 203 | 204 | 205 | 206 | ## Embedding Config 207 | ST_MODEL_NAME = 'multi-qa-mpnet-base-dot-v1' 208 | st_model = SentenceTransformer(ST_MODEL_NAME) 209 | 210 | 211 | 212 | 213 | class Context: 214 | """ 215 | A context for assembling a large prompt context from recent user message subprompts 216 | """ 217 | def __init__(self) -> "Context": 218 | self.tokens = 0 219 | self._sub_prompts = [] 220 | 221 | def push(self, sub_prompt : SubPrompt) -> None: 222 | """ 223 | Push sub_prompt onto the beginning of the context. 224 | Used to recreate the context in reverse order from a database select. 225 | Raises MaximumTokenLimit when the prompt context limit is exceeded. 226 | """ 227 | if self.tokens > msg_response_task.max_prompt_tokens(): 228 | raise MaximumTokenLimit 229 | self._sub_prompts.insert(0, sub_prompt) 230 | self.tokens += sub_prompt.tokens 231 | 232 | def add(self, sub_prompt : SubPrompt) -> None: 233 | # add new prompt to end of sub_prompts 234 | self._sub_prompts.append(sub_prompt) 235 | self.tokens += sub_prompt.tokens 236 | # remove oldest subprompts if the prompt context limit is exceeded 237 | while self.tokens > msg_response_task.max_prompt_tokens(): 238 | self.tokens -= self._sub_prompts.pop(0).tokens 239 | 240 | def sub_prompts(self) -> list[SubPrompt]: 241 | return self._sub_prompts 242 | 243 | 244 | ## 245 | ## maintain a single global context (for now) 246 | ## 247 | context = Context() 248 | 249 | 250 | 251 |
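# A sketch of the context eviction behavior (illustrative messages; assumes this
# module's globals are initialized):
#
#   ctx = Context()
#   for i in range(1000):
#       ctx.add(SubPrompt(f"user-{i} wrote to MassGPT: hello"))
#   # the oldest subprompts were evicted to keep the total within the model prompt budget:
#   assert ctx.tokens <= msg_response_task.max_prompt_tokens()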
252 | def receive_message(user : User, text : str) -> str: 253 | """ 254 | Receive a message from the specified user. 255 | Return the message response. 256 | """ 257 | logger.info(f"message from {user.id} {user.first_name} {user.last_name}: {text}") 258 | with Session(engine) as session: 259 | # persist msg to database so we can regain recent msg context after pod restart 260 | msg = Message(text=text, user_id=user.id) 261 | session.add(msg) 262 | session.commit() 263 | session.refresh(msg) 264 | 265 | # embedding should move to background work queue 266 | t0 = time() 267 | vector = [float(x) for x in st_model.encode(msg.text)] 268 | embedding = Embedding(source = EmbeddingSource.message, 269 | source_id = msg.id, 270 | collection = ST_MODEL_NAME, 271 | model = ST_MODEL_NAME, 272 | vector = vector) 273 | logger.info(f"embedding dt: {time()-t0}") 274 | session.add(embedding) 275 | session.commit() 276 | session.refresh(msg) 277 | 278 | # build final aggregate prompt 279 | msg_subprompt = MessageSubPrompt.from_msg(msg) 280 | 281 | completion = msg_response_task.completion(context.sub_prompts(), msg_subprompt) 282 | logger.info(str(completion)) 283 | 284 | # add the new user message to the global shared context 285 | rsp_subprompt = MessageResponseSubPrompt.from_msg_completion(msg_subprompt, completion) 286 | context.add(rsp_subprompt) 287 | 288 | # save response to database 289 | with Session(engine) as session: 290 | session.add(completion) 291 | session.commit() 292 | session.refresh(completion) 293 | session.add(Response(message_id=msg.id, 294 | completion_id=completion.id)) 295 | session.commit() 296 | session.refresh(completion) 297 | 298 | return str(completion) 299 | 300 | 301 | 302 | 303 | def summarize_url(user : User, url : str) -> str: 304 | """ 305 | Summarize a url for a user. 306 | Return the URL summary, adding the summary to the current context. 307 | """ 308 | # extract the readable text from the url contents 309 | # and persist it before summarizing 310 | text = url_to_text(url) 311 | with Session(engine) as session: 312 | db_url = URL(url=url, user_id = user.id) 313 | session.add(db_url) 314 | session.commit() 315 | session.refresh(db_url) 316 | urltext = UrlText(url_id = db_url.id, 317 | mechanism = "url_to_text", 318 | text = text) 319 | session.add(urltext) 320 | session.commit() 321 | session.refresh(urltext) 322 | 323 | completion = url_summary_task.completion(url, text) 324 | 325 | summary_text = str(completion) 326 | 327 | with Session(engine) as session: 328 | session.add(completion) 329 | session.commit() 330 | session.refresh(completion) 331 | 332 | url_summary = UrlSummary(text_id = urltext.id, 333 | user_id = user.id, 334 | model = url_summary_task.model, 335 | prefix = url_summary_task.prefix(url).text, 336 | summary = summary_text) 337 | session.add(url_summary) 338 | session.commit() 339 | session.refresh(url_summary) 340 | 341 | # embedding should move to background work queue 342 | t0 = time() 343 | vector = [float(x) for x in st_model.encode(summary_text)] 344 | embedding = Embedding(source = EmbeddingSource.url_summary, 345 | source_id = url_summary.id, 346 | collection = ST_MODEL_NAME, 347 | model = ST_MODEL_NAME, 348 | vector = vector) 349 | logger.info(f"embedding dt: {time()-t0}") 350 | session.add(embedding) 351 | 352 | session.commit() 353 | 354 | # add the summary to recent context 355 | context.add(UrlSummarySubPrompt.from_summary(user=user, text=summary_text)) 356 | 357 | logger.info(summary_text) 358 | # send the text summary to the user as FYI 359 | return summary_text 360 | 361 | 362 | 363 | def current_prompts() -> str: 364 |
365 | prompt = "Current prompt stack:\n\n" 366 | prompt += MassGPTMessageTask.PREPROMPT.text + "\n\n" 367 | prompt += "[Context as shown by /context]\n\n" 368 | prompt += MassGPTMessageTask.PENULTIMATE_PROMPT.text + "\n\n" 369 | prompt += "[Most recent message from user]" 370 | 371 | return prompt 372 | 373 | 374 | 375 | def current_context(max_len=4096): 376 | """ 377 | generator yielding successive chunks of the current context, each at most max_len characters 378 | """ 379 | size = 0 380 | text = "" 381 | for sub in context.sub_prompts(): 382 | if len(sub.text) + size > max_len: 383 | yield text 384 | text = "" 385 | size = 0 386 | text += sub.text + "\n" 387 | size += len(sub.text) + 1 388 | if text: 389 | yield text 390 | yield f"{context.tokens} tokens" 391 | 392 | 393 | 394 | def load_context_from_db(): 395 | last = "" 396 | logger.info('load_context_from_db') 397 | with Session(engine) as session: 398 | for msg in session.exec(select(Message).order_by(Message.id.desc())): 399 | try: 400 | msg_subprompt = MessageSubPrompt.from_msg(msg) 401 | except MaximumTokenLimit: 402 | continue # historic message too big for current limits 403 | if msg_subprompt.text == last: continue # basic dedup 404 | last = msg_subprompt.text 405 | resp = session.exec(select(Response).where(Response.message_id == msg.id)).first() 406 | if not resp: 407 | try: 408 | context.push(msg_subprompt) 409 | except MaximumTokenLimit: 410 | break 411 | continue 412 | comp = session.exec(select(Completion).where(Completion.id == resp.completion_id)).first() 413 | if not comp: continue 414 | rsp_subprompt = MessageResponseSubPrompt.from_msg_completion(msg_subprompt, comp) 415 | try: 416 | context.push(rsp_subprompt) 417 | except MaximumTokenLimit: 418 | break 419 | 420 | load_context_from_db() 421 |
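A minimal sketch of how a frontend such as src/bot.py can drive this module (assumes the database engine, sentence-transformer download, and OpenAI credentials are configured; the user values are illustrative):

from models import User
import massgpt

user = User(id=1, first_name="Ada", last_name="Lovelace")
print(massgpt.receive_message(user, "What do folks think about ABC?"))
print(massgpt.summarize_url(user, "https://github.com/openai/whisper"))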
-------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | 2 | # MassGPT sqlmodel models 3 | # Copyright (C) 2022 William S. Kish 4 | 5 | from loguru import logger 6 | from typing import Optional, List 7 | from array import array 8 | import enum 9 | from sqlmodel import Field, SQLModel, Column, ARRAY, Float, Enum 10 | from sqlalchemy import BigInteger 11 | from pydantic import BaseModel, ValidationError, validator 12 | from pydantic import condecimal 13 | from time import time 14 | 15 | timestamp = condecimal(max_digits=14, decimal_places=3) # unix epoch timestamp decimal to millisecond precision 16 | 17 | 18 | class UserStatus(str, enum.Enum): 19 | enabled = "enabled" 20 | disabled = "disabled" 21 | 22 | 23 | class User(SQLModel, table=True): 24 | id: int = Field(primary_key=True, description='Unique user ID') 25 | created_at: timestamp = Field(default_factory=time, description='The epoch timestamp when the User was created.') 26 | username: Optional[str] = Field(default=None, index=True, description="Username") 27 | first_name: Optional[str] = Field(description="User's first name") 28 | last_name: Optional[str] = Field(description="User's last name") 29 | auth0_id: Optional[str] = Field(index=True, description='Auth0 user_id') 30 | telegram_id: Optional[int] = Field(sa_column=Column(BigInteger()), index=True, description='Telegram User ID') 31 | telegram_is_bot: Optional[bool] = Field(description="is_bot from telegram") 32 | telegram_is_premium: Optional[bool] = Field(description="is_premium from telegram") 33 | telegram_language_code: Optional[str] = Field(description="language_code from telegram") 34 | 35 | 36 | 37 | 38 | class Message(SQLModel, table=True): 39 | id: int = Field(primary_key=True, description='Our unique message id') 40 | text: str = Field(max_length=4096, description='The message text') 41 | user_id: int = Field(index=True, foreign_key='user.id', description='The user who sent the Message') 42 | created_at: timestamp = Field(index=True, default_factory=time, description='The epoch timestamp when the Message was created.') 43 | 44 | 45 | 46 | 47 | class Response(SQLModel, table=True): 48 | id: int = Field(primary_key=True, description='Our unique response id') 49 | message_id: int = Field(index=True, foreign_key="message.id", description="The message to which this is a response.") 50 | created_at: timestamp = Field(index=True, default_factory=time, description='The epoch timestamp when the Response was created.') 51 | completion_id: int = Field(foreign_key="completion.id", description="associated completion that provided the response text") 52 | 53 | 54 | class URL(SQLModel, table=True): 55 | id: int = Field(primary_key=True, description='Unique ID') 56 | url: str = Field(max_length=2048, description='The actual supplied URL') 57 | user_id: int = Field(index=True, foreign_key='user.id', description='The user who sent the URL') 58 | created_at: timestamp = Field(default_factory=time, description='The epoch timestamp when this was created.') 59 | 60 | 61 | class UrlText(SQLModel, table=True): 62 | id: int = Field(primary_key=True, description="The text unique id.") 63 | url_id: int = Field(index=True, foreign_key="url.id", description="The url this text was extracted from.") 64 | mechanism: str = Field(description="identifies which software mechanism extracted the text from the url") 65 | created_at: timestamp = Field(default_factory=time, description='The epoch timestamp when the url was crawled.') 66 | text: str = Field(max_length=65535, description="The readable text we managed to extract from the Url.")
67 | content: Optional[str] = Field(max_length=65535, description="original html content") 68 | content_type: Optional[str] = Field(description="content type from http") 69 | 70 | 71 | class UrlSummary(SQLModel, table=True): 72 | id: int = Field(primary_key=True, description="The URL Summary unique id.") 73 | text_id: int = Field(index=True, foreign_key="urltext.id", description="The UrlText used to create the summary.") 74 | model: str = Field(description="The model used to produce this summary.") 75 | prefix: str = Field(max_length=8192, description="The prompt prefix used to create the summary.") 76 | summary: str = Field(max_length=8192, description="The model summary of the UrlText.") 77 | created_at: timestamp = Field(default_factory=time, description='The epoch timestamp when this was created.') 78 | #completion_id: int = Field(foreign_key="completion.id", description="associated completion that provided the response text") 79 | 80 | 81 | class ChatSummary(SQLModel, table=True): 82 | id: int = Field(primary_key=True, description="The ChatSummary unique id.") 83 | created_at: timestamp = Field(index=True, default_factory=time, description='The epoch timestamp when this was created.') 84 | completion_id: int = Field(foreign_key="completion.id", description="associated completion that provided the response text") 85 | 86 | 87 | 88 | 89 | 90 | class Completion(SQLModel, table=True): 91 | """ 92 | Model prompt + completion 93 | A low-level record of model prompts and completions. 94 | """ 95 | id: int = Field(primary_key=True, description="The completion unique id.") 96 | model: str = Field(description="model engine") 97 | temperature: int = Field(description="configured temperature") # XXX convert to float 98 | prompt: str = Field(max_length=65535, description="The prompt used to generate the completion.") 99 | completion: str = Field(max_length=65535, description="The completion received from the model.") 100 | created_at: timestamp = Field(default_factory=time, description='The epoch timestamp when this was created.') 101 | 102 | def __str__(self): 103 | """ 104 | str(Completion) returns the completion text for convenience 105 | """ 106 | return self.completion 107 | 108 | 109 | 110 | class EmbeddingSource(str, enum.Enum): 111 | """ 112 | The source of the text for embedding 113 | """ 114 | message = "message" 115 | chat_summary = "chat_summary" 116 | url_summary = "url_summary" 117 | hn_story_summary = "hn_story_summary" 118 | hn_story_title = "hn_story_title" 119 | 120 | 121 | 122 | 123 | class Embedding(SQLModel, table=True): 124 | id: int = Field(default=None, 125 | primary_key=True, 126 | description='Unique database identifier for a given embedding vector.') 127 | collection: str = Field(index=True, description='The name of the collection that holds this vector.') 128 | source: EmbeddingSource = Field(sa_column=Column(Enum(EmbeddingSource)), 129 | description='The source of this embedding') 130 | source_id: int = Field(index=True, description='The message/chat/url_summary id that produced this embedding.') 131 | model: str = Field(description="The model used to produce this embedding.") 132 | vector: List[float] = Field(sa_column=Column(ARRAY(Float(24))), 133 | description='The embedding vector.') 134 | 135 | 136 | class HnswIndex(SQLModel, table=True): 137 | 138 | id: int = Field(default=None, 139 | primary_key=True, 140 | description='Unique database identifier for a given index') 141 | collection: str = Field(index=True, description='The name of the collection this index was built from.') 142 | count: int = Field(default=0, description="The number of vectors included in the index, i.e. the number of vectors in the collection at the time of the index build.") 143 | objkey: str = Field(description='The index key name in object store') 144 | created_at: timestamp = Field(default_factory=time, description='The epoch timestamp when the index was requested to be created.') 145 |
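A minimal sketch of working with these models (assumes a throwaway SQLite URL in place of the engine configured in src/db.py; the Embedding table's ARRAY column requires PostgreSQL, so only the plain tables are created here):

from sqlmodel import Session, SQLModel, create_engine, select
from models import User, Message

engine = create_engine("sqlite:///massgpt-dev.db")
SQLModel.metadata.create_all(engine, tables=[User.__table__, Message.__table__])

with Session(engine) as session:
    session.add(User(id=1, username="ada", first_name="Ada"))
    session.add(Message(text="hello MassGPT", user_id=1))
    session.commit()
    print(session.exec(select(Message).where(Message.user_id == 1)).all())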
-------------------------------------------------------------------------------- /src/pdf_text.py: -------------------------------------------------------------------------------- 1 | 2 | from loguru import logger 3 | from io import BytesIO 4 | from pdfminer.high_level import extract_pages 5 | from pdfminer.layout import LTTextContainer 6 | 7 | 8 | def pdf_text(pdf_bytes): 9 | """ 10 | extract text from the pdf_bytes and return it as a single string 11 | """ 12 | text = "" 13 | for page_layout in extract_pages(BytesIO(pdf_bytes)): 14 | for element in page_layout: 15 | if isinstance(element, LTTextContainer): 16 | for text_line in element: 17 | text += text_line.get_text().rstrip() + " " 18 | 19 | logger.info("pdf_text: " + text) 20 | return text 21 | -------------------------------------------------------------------------------- /src/querytest.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | from gpt3 import GPT3CompletionTask, CompletionLimits 3 | from exceptions import * 4 | from models import Completion 5 | 6 | from subprompt import SubPrompt 7 | 8 | import wikipedia 9 | 10 | 11 | class SubjectQueryTask(GPT3CompletionTask): 12 | def __init__(self, 13 | min_completion=100, 14 | max_completion=100) -> "SubjectQueryTask": 15 | 16 | limits = CompletionLimits(min_prompt = 0, 17 | min_completion = min_completion, 18 | max_completion = max_completion) 19 | 20 | 21 | super().__init__(limits = limits, 22 | temperature = 0.1, 23 | model = 'text-davinci-003') 24 | 25 | def completion(self, 26 | query : str) -> Completion: 27 | 28 | prompt = SubPrompt("What is the primary topic of this question:") 29 | prompt += query 30 | print(prompt) 31 | resp = super().completion(prompt) 32 | print(str(resp)) 33 | return resp 34 | 35 | 36 | 37 | class WikipediaQueryTask(GPT3CompletionTask): 38 | """ 39 | Answer a query using only the content of a specified Wikipedia page as reference context 40 | """ 41 | TEMPERATURE = 0.0 42 | 43 | def __init__(self, 44 | wikipedia_page:str, 45 | min_completion=100, 46 | max_completion=100) -> "WikipediaQueryTask": 47 | 48 | limits = CompletionLimits(min_prompt = 0, 49 | min_completion = min_completion, 50 | max_completion = max_completion) 51 | 52 | self.page = wikipedia_page 53 | page = wikipedia.page(self.page) 54 | self.content = page.content 55 | 56 | super().__init__(limits = limits, 57 | temperature = 0, 58 | model = 'text-davinci-003') 59 | 60 | 61 | def completion(self, 62 | query : str) -> Completion: 63 | """ 64 | return completion for the provided query 65 | """ 66 | 67 | prompt = SubPrompt(f"Respond to the final question regarding this info on '{self.page}' using only the information provided here in the following text:") 68 | final = SubPrompt(query) 69 | available_tokens = self.max_prompt_tokens() - (prompt + final).tokens - 1 70 | article = SubPrompt(self.content, max_tokens=available_tokens, truncate=True) 71 | prompt += article 72 | prompt += final 73 | print(prompt) 74 | resp = super().completion(prompt) 75 | print(str(resp)) 76 | return resp 77 | 78 | query_task = WikipediaQueryTask("rocket engine", 200, 400) 79 | 80 | 81 | 82 |
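The pattern above reserves prompt budget for the instruction and the question first, then fills the remainder with truncated article text. A usage sketch (assumes OPENAI_API_KEY is set for the gpt3 module and that the wikipedia package can reach the network; the queries are illustrative):

from querytest import SubjectQueryTask, WikipediaQueryTask

topic_task = SubjectQueryTask(min_completion=50, max_completion=100)
topic_task.completion("How does regenerative cooling work in a rocket engine?")

wiki_task = WikipediaQueryTask("rocket engine", 200, 400)
wiki_task.completion("Which propellant combinations give the highest specific impulse?")

Note that importing querytest itself constructs a WikipediaQueryTask at module load, which fetches the page immediately.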
-------------------------------------------------------------------------------- /src/results.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | gpt-3.5-turbo precision 93.3% recall 71.3% latency 10.8 s 5 | gpt-4 precision 100.0% recall 100.0% latency 11.9 s 6 | 7 | 8 | -------------------------------------------------------------------------------- /src/s3.py: -------------------------------------------------------------------------------- 1 | import s3_bucket as S3 2 | import boto3 3 | import os 4 | from botocore.exceptions import ClientError 5 | import logging 6 | 7 | BUCKET_NAME = 'jiggy-assets' 8 | ENDPOINT_URL = os.environ.get("JIGGY_STORAGE_ENDPOINT_URL", "https://us-southeast-1.linodeobjects.com") 9 | 10 | STORAGE_KEY_ID = os.environ['JIGGY_STORAGE_KEY_ID'] 11 | STORAGE_SECRET_KEY = os.environ['JIGGY_STORAGE_KEY_SECRET'] 12 | 13 | 14 | S3.Bucket.prepare(STORAGE_KEY_ID, 15 | STORAGE_SECRET_KEY, 16 | endpoint_url=ENDPOINT_URL) 17 | 18 | 19 | bucket = S3.Bucket(BUCKET_NAME) 20 | 21 | 22 | linode_obj_config = { 23 | "aws_access_key_id": STORAGE_KEY_ID, 24 | "aws_secret_access_key": STORAGE_SECRET_KEY, 25 | "endpoint_url": ENDPOINT_URL} 26 | 27 | 28 | s3_client = boto3.client('s3', **linode_obj_config) 29 | 30 | 31 | def create_presigned_url(object_name, expiration=300): 32 | """ 33 | Generate a presigned URL to share an S3 object 34 | :param object_name: string 35 | :param expiration: Time in seconds for the presigned URL to remain valid 36 | :return: Presigned URL as string. If error, returns None. 37 | """ 38 | 39 | # Generate a presigned URL for the S3 object 40 | try: 41 | response = s3_client.generate_presigned_url('get_object', 42 | Params={'Bucket': BUCKET_NAME, 43 | 'Key': object_name}, 44 | ExpiresIn=expiration) 45 | except ClientError as e: 46 | logging.error(e) 47 | return None 48 | 49 | # The response contains the presigned URL 50 | return response 51 | 52 | -------------------------------------------------------------------------------- /src/s3_bucket/README.md: -------------------------------------------------------------------------------- 1 | forked from https://github.com/CodeLighthouse/s3-bucket 2 | 3 | awaiting https://github.com/CodeLighthouse/s3-bucket/pull/4 4 | -------------------------------------------------------------------------------- /src/s3_bucket/__init__.py: -------------------------------------------------------------------------------- 1 | from .bucket import Bucket 2 | from . import exceptions as Exceptions 3 | 4 | __all__ = [ 5 | "Bucket", 6 | "Exceptions" 7 | ] -------------------------------------------------------------------------------- /src/s3_bucket/bucket.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from botocore.exceptions import ClientError 3 | from typing import Union, Dict 4 | from . import exceptions 5 | 6 | 7 | class Bucket: 8 | """ 9 | CLASS THAT HANDLES S3 BUCKET TRANSACTIONS. ABSTRACTS AWAY BOTO3'S ARCANE BS. 10 | HANDLES BOTO3'S EXCEPTIONS WITH CUSTOM EXCEPTION CLASSES TO MAKE CODE USABLE 11 | """ 12 | _AWS_ACCESS_KEY_ID = None 13 | _AWS_SECRET_ACCESS_KEY = None 14 | _ENDPOINT_URL = None 15 | 16 | def __init__(self, bucket_name: str): 17 | 18 | # ENSURE THE PACKAGE HAS BEEN CONFIGURED WITH THE APPROPRIATE ACCESS KEYS 19 | if not Bucket._AWS_ACCESS_KEY_ID or not Bucket._AWS_SECRET_ACCESS_KEY: 20 | raise TypeError("AWS access key ID and AWS Secret access key must be configured. They're class variables " 21 | "for the Bucket class. 
You can set them by calling Bucket.prepare(access_key, secret_key)") 22 | self.bucket_name = bucket_name 23 | 24 | @classmethod 25 | def prepare(cls, aws_access_key_id: str, aws_secret_access_key: str, aws_session_token=None, endpoint_url=None): 26 | cls._AWS_ACCESS_KEY_ID = aws_access_key_id 27 | cls._AWS_SECRET_ACCESS_KEY = aws_secret_access_key 28 | cls._AWS_SESSION_TOKEN = aws_session_token 29 | cls._ENDPOINT_URL = endpoint_url 30 | 31 | @staticmethod 32 | def _get_boto3_resource(): 33 | """ 34 | GET AND CONFIGURE THE BOTO3 S3 API RESOURCE. THIS IS A "PRIVATE" METHOD 35 | """ 36 | # CREATE A "SESSION" WITH BOTO3 37 | 38 | _session = boto3.Session( 39 | aws_access_key_id=Bucket._AWS_ACCESS_KEY_ID, 40 | aws_secret_access_key=Bucket._AWS_SECRET_ACCESS_KEY, 41 | aws_session_token=Bucket._AWS_SESSION_TOKEN 42 | ) 43 | 44 | # CREATE S3 RESOURCE 45 | resource = _session.resource('s3', endpoint_url=Bucket._ENDPOINT_URL) 46 | 47 | return resource 48 | 49 | def _handle_boto3_client_error(self, e: ClientError, key=None): 50 | """ 51 | HANDLE BOTO3'S CLIENT ERROR. BOTO3 ONLY RETURNS ONE TYPE OF EXCEPTION, WITH DIFFERENT KEYS AND MESSAGES FOR 52 | DIFFERENT TYPES OF ERRORS. REFER TO EXCEPTIONS.PY FOR EXPLANATION 53 | 54 | :param e: THE CLIENTERROR TO HANDLE 55 | :param key: THE KEY OF THE OBJECT WE'RE DEALING WITH. OPTIONAL, DEFAULT IS None 56 | """ 57 | error_code: str = e.response.get('Error').get('Code') 58 | 59 | print(e.response) 60 | 61 | if error_code == 'AccessDenied': 62 | raise exceptions.BucketAccessDenied(self.bucket_name) 63 | elif error_code == 'NoSuchBucket': 64 | raise exceptions.NoSuchBucket(self.bucket_name) 65 | elif error_code == 'NoSuchKey': 66 | raise exceptions.NoSuchKey(key, self.bucket_name) 67 | else: 68 | raise exceptions.UnknownBucketException(self.bucket_name, e) 69 | 70 | def get(self, key: str, response_content_type: str = None) -> (bytes, Dict): 71 | """ 72 | GET AN OBJECT FROM THE BUCKET AND RETURN A BYTES TYPE THAT MUST BE DECODED ACCORDING TO THE ENCODING TYPE 73 | 74 | :param key: THE KEY IN S3 OF THE OBJECT TO GET 75 | :param response_content_type: THE CONTENT TYPE TO ENFORCE ON THE RESPONSE. MAY BE USEFUL IN SOME CASES 76 | :return: A TWO-TUPLE: (1) A BYTES OBJECT THAT MUST BE DECODED DEPENDING ON HOW IT WAS ENCODED. 77 | LEFT UP TO MIDDLEWARE TO DETERMINE AND (2) A DICT CONTAINING METADATA ON WHEN THE OBJECT WAS STORED 78 | """ 79 | 80 | # GET S3 RESOURCE 81 | resource = Bucket._get_boto3_resource() 82 | s3_bucket = resource.Object(self.bucket_name, key) 83 | 84 | try: 85 | if response_content_type: 86 | response = s3_bucket.get(ResponseContentType=response_content_type) 87 | else: 88 | response = s3_bucket.get() 89 | 90 | data = response.get('Body').read() # THE OBJECT DATA STORED 91 | metadata: Dict = response.get('Metadata') # METADATA STORED WITH THE OBJECT 92 | return data, metadata 93 | 94 | # BOTO RAISES ONLY ONE ERROR TYPE THAT THEN MUST BE PROCESSED TO GET THE CODE 95 | except ClientError as e: 96 | self._handle_boto3_client_error(e, key=key) 97 | 98 | def put(self, key: str, data: Union[str, bytes], content_type: str = None, metadata: Dict = {}) -> Dict: 99 | """ 100 | PUT AN OBJECT INTO THE BUCKET 101 | 102 | :param key: THE KEY TO STORE THE OBJECT UNDER 103 | :param data: THE DATA TO STORE. CAN BE BYTES OR STRING 104 | :param content_type: THE MIME TYPE TO STORE THE DATA AS. MAY BE IMPORTANT FOR BINARY DATA 105 | :param metadata: A DICT CONTAINING METADATA TO STORE WITH THE OBJECT. EXAMPLES INCLUDE TIMESTAMP OR 106 | ORGANIZATION NAME. 
VALUES _MUST_ BE STRINGS. 107 | :return: A DICT CONTAINING THE RESPONSE FROM S3. IF AN EXCEPTION IS NOT THROWN, ASSUME PUT OPERATION WAS SUCCESSFUL. 108 | """ 109 | 110 | # GET RESOURCE 111 | resource = Bucket._get_boto3_resource() 112 | s3_bucket = resource.Object(self.bucket_name, key) 113 | 114 | # PUT IT 115 | try: 116 | if content_type: 117 | response = s3_bucket.put( 118 | Body=data, 119 | ContentType=content_type, 120 | Key=key, 121 | Metadata=metadata 122 | ) 123 | else: 124 | response = s3_bucket.put( 125 | Body=data, 126 | Key=key, 127 | Metadata=metadata 128 | ) 129 | return response 130 | 131 | # BOTO RAISES ONLY ONE ERROR TYPE THAT THEN MUST BE PROCESSED TO GET THE CODE 132 | except ClientError as e: 133 | self._handle_boto3_client_error(e, key=key) 134 | 135 | def delete(self, key: str) -> Dict: 136 | """ 137 | DELETE A SPECIFIED OBJECT FROM THE BUCKET 138 | 139 | :param key: A STRING THAT IS THE OBJECT'S KEY IDENTIFIER IN S3 140 | :return: THE RESPONSE FROM S3. IF NO EXCEPTION WAS THROWN, ASSUME DELETE OPERATION WAS SUCCESSFUL 141 | """ 142 | # GET S3 RESOURCE 143 | resource = Bucket._get_boto3_resource() 144 | s3_bucket = resource.Object(self.bucket_name, key) 145 | 146 | try: 147 | response = s3_bucket.delete() 148 | return response 149 | 150 | # BOTO RAISES ONLY ONE ERROR TYPE THAT THEN MUST BE PROCESSED TO GET THE CODE 151 | except ClientError as e: 152 | self._handle_boto3_client_error(e, key=key) 153 | 154 | def upload_file(self, local_filepath: str, key: str) -> Dict: 155 | """ 156 | UPLOAD A LOCAL FILE TO THE BUCKET. TRANSPARENTLY MANAGES MULTIPART UPLOADS. 157 | 158 | :param local_filepath: THE ABSOLUTE FILEPATH OF THE FILE TO STORE 159 | :param key: THE KEY TO STORE THE FILE UNDER IN THE BUCKET 160 | :return: A DICT CONTAINING THE RESPONSE FROM S3. IF NO EXCEPTION IS THROWN, ASSUME OPERATION 161 | COMPLETED SUCCESSFULLY 162 | """ 163 | 164 | # GET S3 RESOURCE 165 | resource = Bucket._get_boto3_resource() 166 | s3_bucket = resource.Object(self.bucket_name, key) 167 | 168 | try: 169 | response = s3_bucket.upload_file(local_filepath) 170 | return response 171 | 172 | # BOTO RAISES ONLY ONE ERROR TYPE THAT THEN MUST BE PROCESSED TO GET THE CODE 173 | except ClientError as e: 174 | self._handle_boto3_client_error(e, key=key) 175 | 176 | def download_file(self, key: str, local_filepath: str) -> Dict: 177 | """ 178 | DOWNLOAD AN OBJECT FROM THE BUCKET TO A LOCAL FILE. TRANSPARENTLY MANAGES MULTIPART DOWNLOADS. 179 | 180 | :param key: THE KEY THAT IDENTIFIES THE OBJECT TO DOWNLOAD 181 | :param local_filepath: THE ABSOLUTE FILEPATH TO STORE THE OBJECT TO 182 | :return: A DICT CONTAINING THE RESPONSE FROM S3. 
IF NO EXCEPTION IS THROWN, ASSUME OPERATION 183 | COMPLETED SUCCESSFULLY 184 | """ 185 | # GET S3 RESOURCE 186 | resource = Bucket._get_boto3_resource() 187 | s3_bucket = resource.Object(self.bucket_name, key) 188 | 189 | try: 190 | response = s3_bucket.download_file(local_filepath) 191 | return response 192 | 193 | # BOTO RAISES ONLY ONE ERROR TYPE THAT THEN MUST BE PROCESSED TO GET THE CODE 194 | except ClientError as e: 195 | self._handle_boto3_client_error(e, key=key) 196 | -------------------------------------------------------------------------------- /src/s3_bucket/exceptions.py: -------------------------------------------------------------------------------- 1 | from botocore.exceptions import ClientError 2 | 3 | """ 4 | CUSTOM EXCEPTIONS - DESIGNED AS MORE USEFUL WRAPPERS TO BOTOCORE'S ARCANE BS EXCUSE-FOR-EXCEPTIONS 5 | BASICALLY ENCAPSULATE BOTO EXCEPTIONS W/MORE INFORMATION. 6 | CLIENT CODE UNLIKELY TO KNOW HOW TO CATCH BOTOCORE EXCEPTIONS, BUT THESE ARE EXPOSED THROUGH S3 CLASS SO EZ 7 | """ 8 | 9 | 10 | class BucketException(Exception): 11 | """ 12 | PARENT CLASS THAT ENSURES THAT ALL BUCKET ERRORS ARE CONSISTENT 13 | """ 14 | 15 | def __init__(self, message, bucket): 16 | self.bucket = bucket 17 | self.message = f'{message}' 18 | super().__init__(self.message) 19 | 20 | 21 | class NoSuchKey(BucketException): 22 | """ 23 | RAISED IF YOU TRY TO ACCESS A NON-EXISTENT OBJECT, SINCE IT HAS MOST LIKELY EXPIRED 24 | """ 25 | 26 | def __init__(self, key, bucket): 27 | self.key = key 28 | self.bucket = bucket 29 | self.message = f'No object in bucket {bucket} matches {key}. Has it expired?' 30 | super().__init__(self.message, self.bucket) 31 | 32 | 33 | class NoSuchBucket(BucketException): 34 | """ 35 | RAISED IF YOU TRY TO ACCESS A NONEXISTENT BUCKET 36 | """ 37 | 38 | def __init__(self, bucket_name): 39 | self.bucket = bucket_name 40 | self.message = f'Bucket {bucket_name} does not exist!' 41 | super().__init__(self.message, self.bucket) 42 | 43 | 44 | class BucketAccessDenied(BucketException): 45 | """ 46 | RAISED IF ACCESS TO A BUCKET IS DENIED - LIKELY BECAUSE IT DOESN'T EXIST 47 | """ 48 | 49 | def __init__(self, bucket_name): 50 | self.bucket = bucket_name 51 | self.message = f'Unable to access bucket {self.bucket}. Does it exist?' 52 | 53 | super().__init__(self.message, self.bucket) 54 | 55 | 56 | class UnknownBucketException(BucketException): 57 | """ 58 | RAISED IF AN UNKNOWN S3 EXCEPTION OCCURS 59 | """ 60 | 61 | def __init__(self, bucket_name, e: ClientError): 62 | self.bucket = bucket_name 63 | error_code: str = e.response.get('Error').get('Code') 64 | error_message: str = e.response.get('Error').get('Message') 65 | self.message = f'Unknown Bucket Exception {error_code}: {error_message}' 66 | super().__init__(self.message, self.bucket) 67 |
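A minimal sketch of the calling pattern this wrapper is designed for (illustrative credentials and key; mirrors how src/s3.py configures the package):

import s3_bucket as S3

S3.Bucket.prepare("KEY_ID", "SECRET_KEY", endpoint_url="https://us-southeast-1.linodeobjects.com")
bucket = S3.Bucket("jiggy-assets")
bucket.put("hello.txt", "hello world", content_type="text/plain")
try:
    data, metadata = bucket.get("hello.txt")
    print(data.decode())
except S3.Exceptions.NoSuchKey as e:
    print(e.message)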
-------------------------------------------------------------------------------- /src/subprompt.py: -------------------------------------------------------------------------------- 1 | # 2 | # SubPrompt class that assists with keeping track of token counts and 3 | # efficiently combining SubPrompts 4 | # 5 | # Copyright (C) 2022 William S. Kish 6 | 7 | import os 8 | from pydantic import BaseModel, Field 9 | from tokenizer import token_len 10 | from typing import Optional 11 | from exceptions import * 12 | 13 | 14 | # have found that marking truncated text as truncated stops the model from trying 15 | # to complete the missing text instead of summarizing it as requested 16 | TRUNCATED = " <truncated>" 17 | TRUNCATED_LEN = token_len(TRUNCATED) 18 | 19 | class SubPrompt: 20 | """ 21 | A SubPrompt is a text string and associated token count for the string 22 | 23 | len(SubPrompt) returns the length of the SubPrompt in tokens 24 | 25 | SubPrompt1 + SubPrompt2 returns a new subprompt which contains the 26 | concatenated text of the 2 subprompts separated by "\n" 27 | and a mostly accurate token count. 28 | 29 | The combined token count is estimated (not computed) so it can sometimes overestimate the 30 | actual token count by 1 token. Tests on random strings show this occurs less 31 | than 1% of the time. 32 | """ 33 | 34 | def truncate(self, max_tokens, precise=False): 35 | if precise: 36 | raise Exception("precise truncation is not yet implemented") 37 | # crudely truncate longer texts to get it back down to approximately the target max_tokens 38 | # TODO: find precise truncation point using multiple calls to token_len() 39 | # TODO: consider option to truncate at sentence boundaries. 40 | if self.tokens <= max_tokens: 41 | return 42 | split_point = int(len(self.text) * (max_tokens-TRUNCATED_LEN) / self.tokens) 43 | while not self.text[split_point].isspace(): 44 | split_point -= 1 45 | self.text = self.text[:split_point] + TRUNCATED 46 | self.tokens = token_len(self.text) 47 | if self.tokens > max_tokens: 48 | self.truncate(max_tokens * 0.95) 49 | 50 | 51 | def __init__(self, text: str, max_tokens=None, truncate=False, precise=False, tokens=None) -> "SubPrompt": 52 | """ 53 | Create a subprompt from the specified string. 54 | If max_tokens is specified, then the SubPrompt will be limited to max_tokens. 55 | The behavior when max_tokens is exceeded is controlled by truncate. 56 | MaximumTokenLimit exception raised if the text exceeds the specified max_tokens and truncate is False. 57 | If truncate is True then the text will be truncated to meet the limit. 58 | If precise is False then the truncation will be very quick but only approximate. 59 | If precise is True then the truncation will be slower but guaranteed to meet the max_tokens limit. 60 | """ 61 | if precise: 62 | raise Exception("precise truncation is not yet implemented") 63 | if tokens is None: 64 | tokens = token_len(text) 65 | self.text = text 66 | self.tokens = tokens 67 | if max_tokens is not None and tokens > max_tokens: 68 | if not truncate: 69 | raise MaximumTokenLimit 70 | self.truncate(max_tokens, precise=precise) 71 | 72 | 73 | 74 | def __len__(self) -> int: 75 | return self.tokens 76 | 77 | 78 | def __add__(self, o) -> "SubPrompt": 79 | """ 80 | Combine the subprompt texts and token counts with a newline character in between them. 81 | This will occasionally overestimate the combined token count by 1 token, 82 | which is acceptable for our intended use. 
83 | """ 84 | if isinstance(o, str): 85 | o = SubPrompt(o) 86 | 87 | return SubPrompt(text = self.text + "\n" + o.text, 88 | tokens = self.tokens + 1 + o.tokens) 89 | 90 | def __str__(self): 91 | return self.text 92 | 93 | 94 | 95 | if __name__ == "__main__": 96 | """ 97 | test with random strings 98 | """ 99 | from string import ascii_lowercase, whitespace, digits 100 | from random import sample, randint 101 | 102 | chars = ascii_lowercase + whitespace + digits 103 | 104 | def randstring(n): 105 | return "".join([sample(chars, 1)[0] for i in range(n)]) 106 | 107 | count = 0 108 | for i in range(100000): 109 | c1 = randstring(randint(20, 100)) 110 | c2 = randstring(randint(20, 100)) 111 | print("c1", c1) 112 | print("c2", c2) 113 | sp1 = SubPrompt(c1) 114 | sp2 = SubPrompt(c2) 115 | print("sp1", sp1) 116 | print("sp2", sp2) 117 | 118 | print("len sp1:", len(sp1)) 119 | print("len sp2:", len(sp2)) 120 | #assert(len(sp1) == token_len(sp1.text)) 121 | #assert(len(sp2) == token_len(sp2.text)) 122 | 123 | sp3 = sp1 + sp2 124 | print("sp3", sp3) 125 | 126 | print("len sp3:", len(sp3)) 127 | sp3len = token_len(sp3.text) 128 | print(sp3len) 129 | if len(sp3) != sp3len: 130 | count += 1 131 | assert(len(sp3) >= sp3len) 132 | print(count, "errors") 133 | # 651 errors out of 100000 on a typical run 134 | 135 | 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /src/tokenizer.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from transformers import GPT2Tokenizer 4 | 5 | 6 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 7 | 8 | def token_len(text : str) -> int: 9 | """ 10 | return number of tokens in text per gpt2 tokenizer 11 | """ 12 | return len(tokenizer(text)['input_ids']) 13 | 14 | 15 | --------------------------------------------------------------------------------
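One quick end-to-end sanity check of the token accounting (a sketch; assumes the transformers package is installed and downloads the gpt2 vocabulary on first use):

from tokenizer import token_len
from subprompt import SubPrompt

combined = SubPrompt("hello world") + SubPrompt("how are you?")
# the estimated count may exceed the exact recount by at most one token
assert combined.tokens >= token_len(combined.text)
print(len(combined), combined.text)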