├── .gitignore
├── Dockerfile
├── LICENSE
├── Makefile
├── MassGPT.jpg
├── README.md
├── requirements.txt
└── src
    ├── ann_index.py
    ├── ann_search.py
    ├── bot.py
    ├── completion.py
    ├── copytest.py
    ├── copytest_chat.py
    ├── db.py
    ├── download_hf_models_at_buildtime.py
    ├── exceptions.py
    ├── extract.py
    ├── github_api.py
    ├── gpt3.py
    ├── hn_summary_db.py
    ├── jigit.py
    ├── massgpt.py
    ├── models.py
    ├── pdf_text.py
    ├── querytest.py
    ├── results.txt
    ├── s3.py
    ├── s3_bucket
    │   ├── README.md
    │   ├── __init__.py
    │   ├── bucket.py
    │   └── exceptions.py
    ├── subprompt.py
    └── tokenizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | *.py~
131 | *.txt~
132 | *.hnsf
133 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
2 |
3 | # set working directory
4 | WORKDIR /app
5 |
6 | ARG DEBIAN_FRONTEND=noninteractive
7 |
8 | RUN apt-get update && apt-get upgrade -y && apt-get install -y python3-pip ffmpeg git
9 |
10 | # update pip
11 | RUN pip3 install --upgrade pip
12 |
13 | RUN pip3 install git+https://github.com/openai/whisper.git
14 |
15 | # add requirements
16 | COPY ./requirements.txt /app/requirements.txt
17 |
18 | # install requirements
19 | RUN pip3 install -r requirements.txt
20 |
21 | COPY src/*.py .
22 |
23 | # force download of the hugging face models into the container
24 | RUN python3 download_hf_models_at_buildtime.py
25 |
26 |
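27 | # No CMD/ENTRYPOINT is defined; start the bot explicitly when launching the
28 | # container, e.g. (assumed invocation): docker run --env-file <envfile> jiggyai/massgpt:<tag> python3 bot.py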
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 |
2 | .PHONY: help build push all
3 |
4 | help:
5 | @echo "Makefile commands:"
6 | @echo "build"
7 | @echo "push"
8 | @echo "all"
9 |
10 | .DEFAULT_GOAL := all
11 |
12 | build:
13 | docker build -t jiggyai/massgpt:${TAG} .
14 |
15 | push:
16 | docker push jiggyai/massgpt:${TAG}
17 |
18 | all: build push
19 |
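20 | # TAG is assumed to be supplied by the caller, e.g.: TAG=latest make all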
--------------------------------------------------------------------------------
/MassGPT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiggy-ai/mass-gpt/d6bc22912c1f88a1db5894df49ca35980c095bda/MassGPT.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | **MassGPT** is an experimental open-source Telegram bot that responds to each user message with a GPT-3.5 (text-davinci-003) completion conditioned on a shared sub-prompt context assembled from all users' recent messages.
6 |
7 | Message the MassGPT Telegram bot directly and it will respond, taking into account the content of relevant messages it has received from others. Users are assigned a numeric user id that does not leak any Telegram user info. The messages you send are not private: they are accessible to other users via the bot or via the /context command.
8 |
9 | You can also send the bot a url to inject a summary of the web page text into the current context. Use the /context command to see the entire current context.
10 |
11 | Currently there is one global chat context of recent messages from other users, but I plan to scale this by using an embedding of the user message to dynamically assemble a relevant context via ANN query of past messages, message summaries, and url summaries.
12 |
13 | There are several motivations for creating this. One is to explore chatting with an LLM where the model presents a partial consensus of other users' current inputs instead of the model's background representations. Another is to explore dynamic prompt contexts, which should have many interesting applications given that GPT seems much less likely to hallucinate when summarizing from its current prompt context.
14 |
15 | Open to PRs if you are interested in hacking on this in any way, or feel free to message me @ wskish on Twitter or Telegram.
16 |
17 |
18 |
19 | **OpenAI**
20 |
21 | * OPENAI_API_KEY # your OpenAI API key
22 |
23 |
24 | **PostgreSQL**
25 |
26 | Database for persisting users, messages, url summaries, and model completions.
27 |
28 | - MASSGPT_POSTGRES_HOST # The database FQDN
29 | - MASSGPT_POSTGRES_USER # The database username
30 | - MASSGPT_POSTGRES_PASS # The database password
31 |
32 | **Telegram**
33 |
34 | * MASSGPT_TELEGRAM_API_TOKEN # The bot's telegram API token
35 |
36 |
37 |
38 |
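39 | **Prompt assembly (sketch)**
40 |
41 | A minimal sketch of the completion strategy, simplified from src/massgpt.py; `recent_messages`, `user`, and `text` are stand-ins for the bot's state, not the exact implementation:
42 |
43 | ```python
44 | import openai
45 |
46 | prompt = "You are MassGPT... Users have recently said the following to you:\n"  # abbreviated preprompt
47 | for msg in recent_messages:  # newest messages that fit within the token budget
48 |     prompt += f"user-{msg.user_id} wrote to MassGPT: {msg.text}\n"
49 | prompt += "Instruction: Respond to the following user message considering the above context and Instruction:\n"
50 | prompt += f"user-{user.id} wrote to MassGPT: {text}\n"
51 | prompt += "MassGPT responded:"
52 | reply = openai.Completion.create(engine="text-davinci-003", prompt=prompt,
53 |                                  temperature=0.4, max_tokens=400).choices[0].text
54 | ```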
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | loguru==0.6.0
2 | pydantic==1.10.2
3 | sqlmodel==0.0.8
4 | psycopg2-binary==2.9.5
5 | requests==2.28.1
6 | python-telegram-bot==20.0b0
7 | BeautifulSoup4==4.11.1
8 | openai==0.25.0
9 | readability-lxml==0.8.1
10 | transformers==4.25.1
11 | markdown==3.4.1
12 | pdfminer.six==20221105
13 | sentence_transformers==2.2.2
--------------------------------------------------------------------------------
/src/ann_index.py:
--------------------------------------------------------------------------------
1 | import hnswlib
2 | import psutil
3 |
4 | from s3 import bucket
5 |
6 | from loguru import logger
7 | from sqlmodel import Session, select
8 | from db import engine
9 | from models import *
10 |
11 | CPU_COUNT = psutil.cpu_count()
12 | ST_MODEL_NAME = 'multi-qa-mpnet-base-dot-v1'
13 |
14 |
15 | hnsw_index = hnswlib.Index(space='cosine', dim=768)
16 | hnsw_index.set_num_threads(int(CPU_COUNT/2))
17 |
18 | hnsw_index.init_index(max_elements = 20000,
19 | ef_construction = 100,
20 | M = 16)
21 |
22 | count = 0
23 | with Session(engine) as session:
24 | for embedding in session.exec(select(Embedding).where(Embedding.collection == ST_MODEL_NAME)):
25 | if embedding.vector is None:
26 | session.delete(embedding)
27 | session.commit()
28 | continue
29 | hnsw_index.add_items([embedding.vector], [embedding.id])
30 | count += 1
31 | print(embedding.id)
32 |
33 | filename = "index-%s-%d.hnsf" % (ST_MODEL_NAME, count)
34 | hnsw_index.save_index(filename)
35 |
36 | objkey = f"massgpt/ann-index/{ST_MODEL_NAME}-{count}.hnsw"
37 | bucket.upload_file(filename, objkey)
38 |
39 | with Session(engine) as session:
40 | ix = HnswIndex(collection = ST_MODEL_NAME,
41 | count = count,
42 | objkey = objkey)
43 | session.add(ix)
44 | session.commit()
45 |
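46 | # The saved index can later be reloaded for queries (see ann_search.py):
47 | #   hnsw_index = hnswlib.Index(space='cosine', dim=768)
48 | #   hnsw_index.load_index(filename, max_elements=20000)
49 | #   ids, distances = hnsw_index.knn_query([query_vector], k=10)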
--------------------------------------------------------------------------------
/src/ann_search.py:
--------------------------------------------------------------------------------
1 |
2 | import hnswlib
3 | from db import engine
4 | from models import *
5 | from sentence_transformers import SentenceTransformer
6 | from sqlmodel import Session, select
7 | import psutil
8 | import hn_summary_db
9 |
10 | CPU_COUNT = psutil.cpu_count()
11 |
12 | ## Embedding Config
13 | ST_MODEL_NAME = 'multi-qa-mpnet-base-dot-v1'
14 | st_model = SentenceTransformer(ST_MODEL_NAME)
15 |
16 | hnsw_ix = hnswlib.Index(space='cosine', dim=768)
17 | hnsw_ix.load_index('index-multi-qa-mpnet-base-dot-v1-0.hnsf', max_elements=20000)
18 | hnsw_ix.set_ef(1000)
19 |
20 | STORY_SOURCES = [EmbeddingSource.hn_story_summary, EmbeddingSource.hn_story_title]
21 |
22 |
23 | def search(query : str):
24 | query = query.rstrip()
25 | print(query)
26 | vector = st_model.encode(query)
27 |
28 | ids, distances = hnsw_ix.knn_query([vector], k=10)
29 | ids = [int(i) for i in ids[0]]
30 | distances = [float(i) for i in distances[0]]
31 | with Session(engine) as session:
32 | ann_embeddings = session.exec(select(Embedding).where(Embedding.id.in_(ids))).all()
33 | hn_story_ids = [e.source_id for e in ann_embeddings if e.source in STORY_SOURCES]
34 | stories = hn_summary_db.stories(hn_story_ids)
35 |
36 | vid_to_distance = {vid : distance for vid,distance in zip(ids, distances)}
37 | sid_to_vid = {emb.source_id : emb.id for emb in ann_embeddings}
38 |
39 | def distance(story):
40 | return vid_to_distance[sid_to_vid[story.id]]
41 |
42 | stories.sort(key=distance)
43 |
44 | for story in stories:
45 | print(f"{distance(story):.2f} {story.title}")
46 |
47 | return stories
48 |
49 | if __name__ == "__main__":
50 | search("SBF fraud")
51 | while(True):
52 | stories = search(input("Search: "))
53 | for s in stories:
54 | print()
55 | print(s.title)
56 | print(hn_summary_db.story_summary(s))
57 | hn_summary_db.story_text(s)
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
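67 | # Note: the index file loaded above is a local artifact produced by ann_index.py
68 | # (which also uploads a copy to the S3 bucket); regenerate it if it is missing.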
--------------------------------------------------------------------------------
/src/bot.py:
--------------------------------------------------------------------------------
1 | """
2 | MassGPT
3 |
4 | A destructive distillation of mass communication.
5 |
6 | Condition an LLM completion on a dynamically assembled subprompt context.
7 |
8 | Telegram: @MassGPTbot
9 | https://t.me/MassGPTbot
10 |
11 | Copyright (C) 2022 William S. Kish
12 |
13 | """
14 |
15 | import os
16 | from sqlmodel import Session, select
17 | from loguru import logger
18 | from pydantic import BaseModel, Field
19 | from telegram import Update
20 | from telegram.ext import ApplicationBuilder, MessageHandler, CommandHandler, ContextTypes, filters
21 | import re
22 |
23 | import openai
24 | from db import engine
25 | from models import *
26 | from exceptions import *
27 |
28 | import massgpt
29 |
30 |
31 | # the bot app
32 | bot = ApplicationBuilder().token(os.environ['MASSGPT_TELEGRAM_API_TOKEN']).build()
33 |
34 |
35 | def extract_url(text: str):
36 | try:
37 | return re.search(r"(?P<url>https?://[^\s]+)", text).group("url")
38 | except AttributeError:
39 | return None
40 |
41 |
42 |
43 | def get_telegram_user(update : Update) -> User:
44 | """
45 | return database User object of the sender of a telegram message
46 | Create the user if it didn't previously exist.
47 | """
48 | with Session(engine) as session:
49 | user = session.exec(select(User).where(User.telegram_id == update.message.from_user.id)).first()
50 | if not user:
51 | tuser = update.message.from_user
52 | user = User(username = tuser.username,
53 | first_name = tuser.first_name,
54 | last_name = tuser.last_name,
55 | telegram_id = tuser.id,
56 | telegram_is_bot = tuser.is_bot,
57 | telegram_is_premium = tuser.is_premium,
58 | telegram_lanuage_code = tuser.language_code)
59 | session.add(user)
60 | session.commit()
61 | session.refresh(user)
62 | return user
63 |
64 |
65 |
66 | async def message(update: Update, tgram_context: ContextTypes.DEFAULT_TYPE) -> None:
67 | """
68 | Handle message received from user.
69 | Send back to the user the response text from the model.
70 | Handle exceptions by sending an error message to the user.
71 | """
72 | user = get_telegram_user(update)
73 | text = update.message.text
74 | logger.info(f'{user.id} {user.first_name} {user.last_name} {user.username} {user.telegram_id}: "{text}"')
75 | try:
76 | url = extract_url(text)
77 | print("URL", url)
78 | if url:
79 | response = massgpt.summarize_url(user, url)
80 | else:
81 | response = massgpt.receive_message(user, text)
82 | await update.message.reply_text(response)
83 | except (openai.error.ServiceUnavailableError, openai.error.RateLimitError):
84 | await update.message.reply_text("The OpenAI server is overloaded.")
85 | except ExtractException:
86 | await update.message.reply_text("Unable to extract text from url.")
87 | except MinimumTokenLimit:
88 | await update.message.reply_text("Message too short; please send a longer message.")
89 | except MaximumTokenLimit:
90 | await update.message.reply_text("Message too large; please send a shorter message.")
91 | except Exception as e:
92 | logger.exception("error processing message")
93 | await update.message.reply_text("An exceptional condition occured.")
94 |
95 |
96 | bot.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, message))
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 | async def command(update: Update, tgram_context: ContextTypes.DEFAULT_TYPE) -> None:
108 | """
109 | Handle command from user
110 | context - Respond with the current chat context
111 | url - Summarize a url and add the summary to the chat context
112 | """
113 | user = get_telegram_user(update)
114 | text = update.message.text
115 |
116 | logger.info(f'{user.id} {user.first_name} {user.last_name} {user.username} {user.telegram_id}: "{text}"')
117 |
118 |
119 | if text == '/context':
120 | await update.message.reply_text("The current context:")
121 | for msg in massgpt.current_context():
122 | await update.message.reply_text(msg)
123 | return
124 | elif text == '/prompts':
125 | await update.message.reply_text(massgpt.current_prompts())
126 | return
127 | elif text[:5] == '/url ':
128 | try:
129 | url = extract_url(text)
130 | response = massgpt.summarize_url(user, url)
131 | await update.message.reply_text(response)
132 | except (openai.error.ServiceUnavailableError, openai.error.RateLimitError):
133 | await update.message.reply_text("The OpenAI server is overloaded.")
134 | except ExtractException:
135 | await update.message.reply_text("Unable to extract text from url.")
136 | except Exception as e:
137 | logger.exception("error processing message")
138 | await update.message.reply_text("An exceptional condition occured.")
139 | return
140 | await update.message.reply_text("Send me a message and I will assemble a collection of recent or related messages into a GPT prompt context and prompt your message against that dynamic context, sending you the GPT response. Send /context to see the current prompt context. Send '/url ' to add a summary of the url to the context. Consider this to be a public chat and please maintain a kind and curious standard.")
141 |
142 |
143 | bot.add_handler(MessageHandler(filters.COMMAND, command))
144 |
145 |
146 | logger.info("run_polling")
147 | bot.run_polling()
148 |
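149 | # Run directly with `python3 bot.py`; assumes MASSGPT_TELEGRAM_API_TOKEN,
150 | # OPENAI_API_KEY, and the MASSGPT_POSTGRES_* environment variables are set.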
--------------------------------------------------------------------------------
/src/completion.py:
--------------------------------------------------------------------------------
1 | # Completion Abstraction
2 | # Copyright(C) 2022 William S. Kish
3 |
4 |
5 | import os
6 | from pydantic import BaseModel, Field
7 | from sqlmodel import Session, select
8 |
9 | from models import Completion
10 | from subprompt import SubPrompt
11 |
12 | from exceptions import *
13 |
14 |
15 | class CompletionLimits(BaseModel):
16 | """
17 | specification for limits for:
18 | min prompt tokens
19 | min completion tokens
20 | max completion tokens
21 | given a max_context in tokens
22 | """
23 | max_context : int
24 | min_prompt : int
25 | min_completion : int
26 | max_completion : int
27 |
28 | def max_completion_tokens(self, prompt : SubPrompt) -> int:
29 | """
30 | returns the maximum completion tokens available given the max_context limit
31 | and the actual number of tokens in the prompt
32 | raises MinimumTokenLimit or MaximumTokenLimit exceptions if the prompt
33 | is too small or too big.
34 | """
35 | if prompt.tokens < self.min_prompt:
36 | raise MinimumTokenLimit
37 | if prompt.tokens > self.max_prompt_tokens():
38 | raise MaximumTokenLimit
39 | max_available_tokens = self.max_context - prompt.tokens
40 | if max_available_tokens > self.max_completion:
41 | return self.max_completion
42 | return max_available_tokens
43 |
44 | def max_prompt_tokens(self) -> int:
45 | """
46 | return the maximum prompt size in tokens
47 | """
48 | return self.max_context - self.min_completion
49 |
50 |
51 |
52 | class CompletionTask:
53 | """
54 | An LLM completion task that shares a particular LLM configuration and prompt/completion limit structure.
55 | This is a base class for a model-specific completion task. A model-api-specific implementation must
56 | at a minimum implement the _completion() method.
57 | See gpt3.GPT3CompletionTask for an example implementation.
58 | """
59 | def __init__(self,
60 | limits : CompletionLimits,
61 | model : str) -> "CompletionTask" :
62 |
63 | self.limits = limits
64 | self.model = model
65 |
66 | def max_prompt_tokens(self) -> int:
67 | return self.limits.max_prompt_tokens()
68 |
69 | def get_limits(self) -> CompletionLimits:
70 | return self.limits
71 |
72 | def _completion(self,
73 | prompt : str,
74 | max_completion_tokens : int) -> str :
75 | """
76 | perform the actual completion, returning the completion text string
77 | This should be implemented in a model-api-specific subclass.
78 | """
79 | raise NotImplementedError
80 |
81 | def completion(self, prompt : SubPrompt) -> Completion:
82 | """
83 | prompt the model with the specified prompt and return the resulting Completion
84 | """
85 | # check prompt limits and return max completion size to request
86 | # given the size of the prompt and the configured limits
87 | max_completion = self.limits.max_completion_tokens(prompt)
88 |
89 | # perform the completion inference
90 | response = self._completion(prompt = str(prompt),
91 | max_completion_tokens = max_completion)
92 |
93 | completion = Completion(model = self.model,
94 | prompt = str(prompt),
95 | temperature = 0, # XXX set this as model params?
96 | completion = response)
97 | return completion
98 |
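99 | # Sketch of the subclass contract (see gpt3.GPT3CompletionTask for a real implementation;
100 | # MyTask and my_model_api below are hypothetical):
101 | #   class MyTask(CompletionTask):
102 | #       def _completion(self, prompt: str, max_completion_tokens: int) -> str:
103 | #           return my_model_api(prompt, max_tokens=max_completion_tokens)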
--------------------------------------------------------------------------------
/src/copytest.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 | from gpt3 import GPT3CompletionTask, CompletionLimits
3 | from exceptions import *
4 | from models import Completion
5 |
6 | from subprompt import SubPrompt
7 |
8 | class MessageSubPrompt(SubPrompt):
9 | """
10 | SubPrompt Context for a user-generated message
11 | """
12 | MAX_TOKENS = 300
13 | @classmethod
14 | def from_user_str(cls, username : str, msg: str) -> "SubPrompt":
15 | # create user message specific subprompt
16 | text = f"'user-{username}' wrote to MassGPT: {msg}"
17 | return MessageSubPrompt(text=text, max_tokens=MessageSubPrompt.MAX_TOKENS)
18 |
19 |
20 | PREPROMPT = SubPrompt( \
21 | """You are MassGPT, and this is a fun experiment. \
22 | You were built by Jiggy AI using OpenAI text-davinci-003. \
23 | Instruction: Different users are sending you messages. \
24 | They can not communicate with each other directly. \
25 | Any user to user message must be relayed through you. \
26 | Pass along any interesting message. \
27 | Try not to repeat yourself. \
28 | Ask users questions if you are not sure what to say. \
29 | If a user expresses interest in a topic discussed here, \
30 | respond to them based on what you read here. \
31 | Users have recently said the following to you:""")
32 |
33 |
34 | class GPTCopyTestTask(GPT3CompletionTask):
35 | """
36 | Generate message response completions from a dynamic history of recent messages plus the most recent user message
37 | """
38 | TEMPERATURE = 0.0
39 |
40 | # General Prompt Strategy:
41 | # Upon reception of message from a user 999, compose the following prompt
42 | # based on recent messages received from other users:
43 |
44 | #PREPROMPT = SubPrompt("Prepare to copy some of the following messages that users sent to MassGPT:")
45 |
46 |
47 | # "User 123 wrote: ABC is the greatest thing ever"
48 | # "User 234 wrote: ABC is cool but i like GGG more"
49 | # "User 345 wrote: DDD is the best ever!"
50 | # etc
51 | #PENULTIMATE_PROMPT = SubPrompt("Instruction: Respond to the following user message considering the above context and Instruction:")
52 | # "User 999 wrote: What do folks think about ABC?" # End of Prompt
53 | #FINAL_PROMPT = SubPrompt("MassGPT responded:")
54 | # Then send resulting llm completion back to user 999 in response to his message
55 |
56 | def __init__(self) -> "GPTCopyTestTask":
57 | limits = CompletionLimits(min_prompt = 0,
58 | min_completion = 300,
59 | max_completion = 400)
60 |
61 | super().__init__(limits = limits,
62 | temperature = GPTCopyTestTask.TEMPERATURE,
63 | stop = ['###'],
64 | model = 'text-davinci-003')
65 |
66 |
67 | def completion(self,
68 | recent_msgs : list[MessageSubPrompt],
69 | user_msg : SubPrompt) -> Completion:
70 | """
71 | return completion for the provided subprompts
72 | """
73 | #prompt = GPTCopyTestTask.PREPROMPT
74 | #final_prompt = GPTCopyTestTask.PENULTIMATE_PROMPT
75 | #final_prompt += user_msg
76 | #final_prompt += GPTCopyTestTask.FINAL_PROMPT
77 |
78 | #logger.info(f"overhead tokens: {(prompt + final_prompt).tokens}")
79 |
80 | #available_tokens = self.max_prompt_tokens() - (prompt + final_prompt).tokens
81 |
82 | #logger.info(f"available_tokens: {available_tokens}")
83 | # assemble list of most recent_messages up to available token limit
84 | prompt = SubPrompt("")
85 | for sub in recent_msgs:
86 | prompt += sub
87 | prompt += user_msg
88 | #prompt += final_prompt
89 | # add most recent user message after penultimate prompt
90 | logger.info(f"final prompt tokens: {prompt.tokens} max{self.max_prompt_tokens()}")
91 |
92 | logger.info(f"final prompt token_count: {prompt.tokens} chars: {len(prompt.text)}")
93 |
94 | return super().completion(prompt)
95 |
96 |
97 | msg_response_task = GPTCopyTestTask()
98 |
99 |
100 | #users = ['cat', 'dog', 'pig', 'cow', 'rocket', 'ocean', 'mountain', 'tree']
101 | #users = ['cat', 'rocket']
102 | users = list(range(10))
103 |
104 | from random import choice
105 |
106 |
107 | NUM_MESSAGES = 220
108 | #NUM_MESSAGES = 150
109 |
110 | data = []
111 | messages = []
112 | for i in range(NUM_MESSAGES):
113 | user = choice(users)
114 | msg = f"This is message {i}"
115 | data.append((user, msg))
116 | msp = MessageSubPrompt.from_user_str(user, msg)
117 | messages.append(msp)
118 | print(msp)
119 |
120 | print("=======================")
121 | target_u = choice(users)
122 |
123 |
124 | # earlier variant, immediately superseded by the plain SubPrompt below:
125 | #final_p = MessageSubPrompt.from_user_str(choice(users), f"Instruction: Copy all of the messages that 'user-{target_u}' wrote to MassGPT, only one message per line:")
126 |
127 | final_p = SubPrompt(f"Instruction: Copy all of the messages that 'user-{target_u}' wrote to MassGPT, only one message per line:")
128 |
129 |
130 | comp = msg_response_task.completion(messages, final_p)
131 |
132 | print(comp.prompt)
133 | print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
134 | print(str(comp))
135 |
136 |
137 | answers = [msg for user,msg in data if user==target_u]
138 |
139 | correct = 0
140 | results = comp.completion.rstrip().lstrip().split('\n')
141 |
142 | answers = set([a.rstrip().lstrip().lower() for a in answers])
143 | comps = set([c.rstrip().lstrip().lower() for c in results])
144 |
145 | #print(answers)
146 | #print(comps)
147 | correct = answers.intersection(comps)
148 |
149 |
150 |
151 | """
152 | for c , a in zip(results.split("\n"), answer):
153 | a = a.rstrip().lower()
154 | c = c.rstrip().lower()
155 | print(f"{a} | {c}")
156 | if a in c:
157 | correct += 1
158 | """
159 |
160 | precision = len(correct)/len(comps)
161 | recall = len(correct)/len(answers)
162 | print(f"precision {precision} \t recall {recall}")
163 |
--------------------------------------------------------------------------------
/src/copytest_chat.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 | import openai
3 | from time import time
4 | from retry import retry
5 |
6 | users = list(range(10))
7 |
8 | from random import choice
9 |
10 | #NUM_MESSAGES = 220
11 | NUM_MESSAGES = 480
12 |
13 | @retry(tries=5)
14 | def test(model):
15 | data = []
16 | messages = []
17 | target_u = choice(users)
18 | prompt = f"Instruction: Copy all of the messages that 'user-{target_u}' wrote to MassGPT, only one message per line:\n"
19 | for i in range(NUM_MESSAGES):
20 | user = choice(users)
21 | msg = f"This is message {i}"
22 | data.append((user, msg))
23 | text = f"'user-{user}' wrote to MassGPT: {msg}\n"
24 | prompt += text
25 |
26 | prompt += f"Instruction: Copy all of the messages that 'user-{target_u}' wrote to MassGPT, only one message per line:"
27 | print(prompt)
28 | print("=======================")
29 |
30 | messages = [{"role": "user", "content": prompt}]
31 |
32 |
33 | t0 = time()
34 |
35 | if model in ['gpt-3.5-turbo', 'gpt-4']:
36 | response = openai.ChatCompletion.create(model=model,
37 | messages=messages,
38 | temperature=0)
39 | print(response)
40 | response_text = response['choices'][0]['message']['content']
41 | else:
42 |
43 | response = openai.Completion.create(engine=model,
44 | temperature=0,
45 | max_tokens=500,
46 | prompt=prompt)
47 | response_text = response.choices[0].text
48 | dt = time() - t0
49 | print("dt", dt)
50 |
51 | print(response_text)
52 | print("=======================")
53 |
54 | answers = [msg for user,msg in data if user==target_u]
55 |
56 | correct = 0
57 | results = response_text.strip().split('\n')
58 | print("RESULTS:")
59 | print(results)
60 |
61 | answers = set([a.strip().lower() for a in answers])
62 | comps = set([c.split(':')[-1].strip().lower() for c in results])
63 |
64 | print("ANSWERS")
65 | print(answers)
66 | print('comps')
67 | print(comps)
68 | correct = answers.intersection(comps)
69 |
70 |
71 | precision = len(correct)/len(comps)
72 | recall = len(correct)/len(answers)
73 | print(f"precision {precision} \t recall {recall}")
74 |
75 | return precision, recall, dt
76 |
77 |
78 | precision = {}
79 | recall = {}
80 | latency = {}
81 | #MODELS = ['gpt-3.5-turbo', 'text-davinci-003'] #, 'text-davinci-002']
82 |
83 | #MODELS = ['gpt-3.5-turbo', 'gpt-4'] #, 'text-davinci-002']
84 |
85 | MODELS = ['gpt-4'] #, 'text-davinci-002']
86 |
87 | for model in MODELS:
88 | precision[model] = []
89 | recall[model] = []
90 | latency[model] = []
91 |
92 | for model in MODELS:
93 | for i in range(10):
94 | p, r, dt = test(model)
95 | precision[model].append(p)
96 | recall[model].append(r)
97 | latency[model].append(dt)
98 | for model in MODELS:
99 | print(f"{model:15} precision {100*sum(precision[model])/len(precision[model]):.1f}% \trecall {100*sum(recall[model])/len(recall[model]):.1f}%\tlatency {sum(latency[model])/len(latency[model]):.1f} s")
100 |
101 |
--------------------------------------------------------------------------------
/src/db.py:
--------------------------------------------------------------------------------
1 | # database engine
2 | import os
3 | from sqlmodel import create_engine, SQLModel
4 |
5 |
6 |
7 | # DB Config
8 | db_host = os.environ['MASSGPT_POSTGRES_HOST']
9 | user = os.environ['MASSGPT_POSTGRES_USER']
10 | passwd = os.environ['MASSGPT_POSTGRES_PASS']
11 |
12 | DBURI = 'postgresql+psycopg2://%s:%s@%s:5432/massgpt' % (user, passwd, db_host)
13 |
14 | engine = create_engine(DBURI, pool_pre_ping=True, echo=False)
15 |
16 | if __name__ == "__main__":
17 | from models import *
18 | SQLModel.metadata.create_all(engine)
19 | print("create_all complete")
20 |
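21 | # Usage: run `python3 db.py` once to create all tables defined in models.py
22 | # (assumes the MASSGPT_POSTGRES_* environment variables point at a reachable database).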
--------------------------------------------------------------------------------
/src/download_hf_models_at_buildtime.py:
--------------------------------------------------------------------------------
1 | # used at docker build time to pull model cache into container
2 |
3 | from transformers import GPT2Tokenizer
4 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
5 |
6 | from sentence_transformers import SentenceTransformer
7 | st_model_name = 'multi-qa-mpnet-base-dot-v1'
8 | st_model = SentenceTransformer(st_model_name)
9 |
10 | #import whisper
11 | #whisper_model = whisper.load_model("large")
12 |
13 | print("HF done")
14 |
--------------------------------------------------------------------------------
/src/exceptions.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class MinimumTokenLimit(Exception):
4 | """
5 | The prompt is below the specified minimum token count
6 | """
7 |
8 | class MaximumTokenLimit(Exception):
9 | """
10 | The specified maximum token count has been exceeded
11 | """
12 |
13 |
14 |
15 | class ExtractException(Exception):
16 | """
17 | various extraction errors
18 | """
19 |
20 | class UnsupportedHostException(ExtractException):
21 | """
22 | The URL is for a host we know we can't access reliably
23 | """
24 |
25 | class UnsupportedContentType(ExtractException):
26 | """
27 | The http content type is unsupported.
28 | """
29 |
30 | class EmptyText(ExtractException):
31 | """
32 | Unable to extract any readable text from the URL.
33 | """
34 |
35 |
36 | class NetworkError(ExtractException):
37 | """
38 | Unable to access the content.
39 | """
40 |
41 |
42 |
--------------------------------------------------------------------------------
/src/extract.py:
--------------------------------------------------------------------------------
1 | """
2 | Extract readable text from a URL via various hacks
3 | """
4 |
5 | from loguru import logger
6 | from bs4 import BeautifulSoup, NavigableString, Tag
7 | from readability import Document # https://github.com/buriy/python-readability
8 | import requests
9 | import urllib.parse
10 |
11 | from github_api import github_readme_text
12 | from pdf_text import pdf_text
13 |
14 | from exceptions import *
15 |
16 |
17 | def extract_text_from_html(content):
18 | soup = BeautifulSoup(content, 'html.parser')
19 |
20 | output = ""
21 | title = soup.find('title')
22 | if title:
23 | output += "Title: " + title
24 |
25 | blacklist = ['[document]','noscript','header','html','meta','head','input','script', "style"]
26 | # there may be more elements we don't want
27 |
28 | for t in soup.find_all(text=True):
29 | if t.parent.name not in blacklist:
30 | output += '{} '.format(t)
31 | return output
32 |
33 |
34 | def get_url_text(url):
35 | """
36 | get url content and extract readable text
37 | returns the text
38 | """
39 | resp = requests.get(url, timeout=30)
40 |
41 | if resp.status_code != 200:
42 | logger.warning(url)
43 | raise NetworkError(f"Unable to get URL ({resp.status_code})")
44 |
45 | CONTENT_TYPE = resp.headers['Content-Type']
46 |
47 | if 'pdf' in CONTENT_TYPE:
48 | return pdf_text(resp.content)
49 |
50 | if "html" not in CONTENT_TYPE:
51 | logger.warning(url)
52 | raise UnsupportedContentType(f"Unsupported content type: {resp.headers['Content-Type']}")
53 |
54 | doc = Document(resp.text)
55 | text = extract_text_from_html(doc.summary())
56 |
57 | if not len(text) or text.isspace():
58 | logger.warning(url)
59 | raise EmptyText("Unable to extract text data from url")
60 | return text
61 |
62 |
63 |
64 | def url_to_text(url):
65 | HOPELESS = ["youtube.com",
66 | "www.youtube.com"]
67 | if urllib.parse.urlparse(url).netloc in HOPELESS:
68 | logger.warning(url)
69 | raise UnsupportedHostException(f"Unsupported host: {urllib.parse.urlparse(url).netloc}")
70 |
71 | if urllib.parse.urlparse(url).netloc == 'github.com':
72 | # for github repos use api to attempt to find a readme file
73 | text = github_readme_text(url)
74 | else:
75 | text = get_url_text(url)
76 |
77 | logger.info("url_to_text: "+text)
78 | return text
79 |
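80 | # Example: url_to_text("https://github.com/jiggy-ai/mass-gpt") routes github.com
81 | # links through the GitHub readme API; other URLs go through readability-based
82 | # HTML extraction, or pdfminer when the response is a PDF.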
--------------------------------------------------------------------------------
/src/github_api.py:
--------------------------------------------------------------------------------
1 | # wrangle text out of github repo readme via api
2 | # only works for repo readme, not other github pages like issues, discussions, etc
3 |
4 | from loguru import logger
5 | import requests
6 | import markdown
7 | from bs4 import BeautifulSoup
8 |
9 |
10 | def md_to_text(md):
11 | html = markdown.markdown(md)
12 | soup = BeautifulSoup(html, features='html.parser')
13 | return soup.get_text()
14 |
15 |
16 | def github_readme_text(github_repo_url):
17 | # split a github url into the owner and repo components
18 | # use the github api to try to find the readme text
19 | # ['https:', '', 'github.com', 'jiggy-ai', 'hn_summary']
20 | spliturl = github_repo_url.rstrip('/').split('/')
21 | if len(spliturl) != 5:
22 | logger.warning(github_repo_url)
23 | raise Exception(f"Unable to process github url {github_repo_url}")
24 | owner = spliturl[3]
25 | repo = spliturl[4]
26 | contenturl = f'https://api.github.com/repos/{owner}/{repo}/readme'
27 | resp = requests.get(contenturl)
28 | if resp.status_code != 200:
29 | logger.warning(f"{github_repo_url} {resp.content}")
30 | raise Exception(f"Unable to get readme for {github_repo_url}")
31 | item = resp.json()
32 | md = requests.get(item['download_url']).text
33 | return md_to_text(md)
34 |
35 |
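36 | # Example: github_readme_text("https://github.com/jiggy-ai/hn_summary")
37 | # returns that repo's readme rendered to plain text.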
--------------------------------------------------------------------------------
/src/gpt3.py:
--------------------------------------------------------------------------------
1 | # GPT3 specific Completions
2 | # Copyright (C) 2022 William S. Kish
3 |
4 | import os
5 | from loguru import logger
6 | from time import sleep
7 |
8 | import completion
9 | import openai
10 |
11 | openai.api_key = os.environ["OPENAI_API_KEY"]
12 |
13 | OPENAI_COMPLETION_MODELS = ["text-davinci-003", "text-davinci-002", "text-davinci-001"]
14 |
15 |
16 | def CompletionLimits(min_prompt:int,
17 | min_completion:int,
18 | max_completion:int,
19 | max_context: int = 4097) -> completion.CompletionLimits:
20 | """
21 | CompletionLimits specific to GPT3 models with max_context of 4097 tokens
22 | """
23 | assert(max_context <= 4097)
24 | return completion.CompletionLimits(max_context = max_context,
25 | min_prompt = min_prompt,
26 | min_completion = min_completion,
27 | max_completion = max_completion)
28 |
29 |
30 |
31 | RETRY_COUNT = 10
32 |
33 | class GPT3CompletionTask(completion.CompletionTask):
34 | """
35 | An OpenAI GPT-3-class completion task implemented using the OpenAI API
36 | """
37 |
38 |
39 | def __init__(self,
40 | limits : completion.CompletionLimits,
41 | temperature : float = 1,
42 | top_p : float = 1,
43 | stop : list[str] = None,
44 | model : str = 'text-davinci-003') -> "GPT3CompletionTask":
45 |
46 | assert(model in OPENAI_COMPLETION_MODELS)
47 |
48 | if model in ["text-davinci-003", "text-davinci-002"]:
49 | assert(limits.max_context <= 4097)
50 | else:
51 | assert(limits.max_context <= 2048)
52 |
53 | self.stop = stop
54 | self.top_p = top_p
55 | self.temperature = temperature
56 |
57 | super().__init__(limits = limits,
58 | model = model)
59 |
60 |
61 | def _completion(self,
62 | prompt : str,
63 | max_completion_tokens : int) -> str :
64 | """
65 | perform the actual completion via openai api
66 | returns the completion text string
67 | """
68 | def completion():
69 | resp = openai.Completion.create(engine = self.model,
70 | prompt = prompt,
71 | temperature = self.temperature,
72 | top_p = self.top_p,
73 | stop = self.stop,
74 | max_tokens = max_completion_tokens)
75 | return resp.choices[0].text
76 |
77 | for i in range(RETRY_COUNT):
78 | try:
79 | return completion()
80 | except (openai.error.RateLimitError, openai.error.ServiceUnavailableError):
81 | logger.warning("openai error")
82 | if i == RETRY_COUNT-1:
83 | raise
84 | sleep(i**1.3)
85 | except Exception as e:
86 | logger.exception("_completion")
87 | raise
88 |
89 |
90 |
91 |
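92 | # Example construction (mirroring massgpt.MassGPTMessageTask):
93 | #   limits = CompletionLimits(min_prompt=0, min_completion=300, max_completion=400)
94 | #   task = GPT3CompletionTask(limits=limits, temperature=0.4, model='text-davinci-003')
95 | #   completion = task.completion(prompt)  # prompt is a subprompt.SubPrompt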
--------------------------------------------------------------------------------
/src/hn_summary_db.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Optional
3 | from pydantic import condecimal
4 | from time import time
5 |
6 | import os
7 | from sqlmodel import create_engine, SQLModel, Field, Session, select
8 |
9 |
10 |
11 | # DB Config
12 | db_host = os.environ['HNSUM_POSTGRES_HOST']
13 | user = os.environ['HNSUM_POSTGRES_USER']
14 | passwd = os.environ['HNSUM_POSTGRES_PASS']
15 |
16 | DBURI = 'postgresql+psycopg2://%s:%s@%s:5432/hnsum' % (user, passwd, db_host)
17 |
18 | # Create DB Engine
19 | engine = create_engine(DBURI, pool_pre_ping=True, echo=False)
20 |
21 |
22 | timestamp = condecimal(max_digits=14, decimal_places=3)
23 |
24 |
25 | class HackerNewsStory(SQLModel, table=True):
26 | # partial state of HN Story item. See https://github.com/HackerNews/API#items
27 | # we dont include "type" since we are only recording type='story' here.
28 | id: int = Field(primary_key=True, description="The item's unique id.")
29 | by: str = Field(index=True, description="The username of the item's author.")
30 | time: int = Field(index=True, description="Creation date of the item, in Unix Time.")
31 | title: str = Field(description="The title of the story, poll or job. HTML.")
32 | text: Optional[str] = Field(description="The comment, story or poll text. HTML.")
33 | url: Optional[str] = Field(description="The url associated with the Item.")
34 |
35 | class StoryText(SQLModel, table=True):
36 | id: int = Field(primary_key=True, description="The summary unique id.")
37 | story_id: int = Field(index=True, description="The story id this text is associated with.")
38 | mechanism: str = Field(description="identifies which software mechanism extracted the text from the url")
39 | crawl_time: timestamp = Field(default_factory=time, description='The epoch timestamp when the url was crawled.')
40 | html: Optional[str] = Field(description="original html content")
41 | text: str = Field(max_length=65535, description="The readable text we managed to extract from the Story Url.")
42 |
43 | class StorySummary(SQLModel, table=True):
44 | id: int = Field(primary_key=True, description="The summary unique id.")
45 | story_id: int = Field(index=True, description="The story id.")
46 | model: str = Field(description="The model used to summarize a story")
47 | prompt: str = Field(max_length=65535, description="The prompt used to create the summary.")
48 | summary: str = Field(max_length=65535, description="The summary we got back from the model.")
49 | upvotes: Optional[int] = Field(default=0, description="The number of upvotes for this summary.")
50 | votes: Optional[int] = Field(default=0, description="The total number of votes for this summary.")
51 |
52 |
53 |
54 |
55 | def stories(story_ids : list[int]) -> list[HackerNewsStory]:
56 | with Session(engine) as session:
57 | return session.exec(select(HackerNewsStory).where(HackerNewsStory.id.in_(story_ids))).all()
58 |
59 |
60 |
61 | def story_text(story : HackerNewsStory) -> str:
62 | with Session(engine) as session:
63 | storytext = session.exec(select(StoryText).where(StoryText.story_id == story.id)).first()
64 | if not storytext:
65 | return ""
66 | return storytext.text
67 |
68 | def story_summary(story : HackerNewsStory) -> str:
69 | with Session(engine) as session:
70 | storysummary = session.exec(select(StorySummary).where(StorySummary.story_id == story.id)).first()
71 | if not storysummary:
72 | return ""
73 | return storysummary.summary
74 |
75 |
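76 | # Example (hypothetical story ids): print each story title with its summary.
77 | #   for story in stories([101, 102]):
78 | #       print(story.title, story_summary(story))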
--------------------------------------------------------------------------------
/src/jigit.py:
--------------------------------------------------------------------------------
1 | # Beginning of a framework for Jigits, which are widgets that allow LLMs to interact with code
2 | # Copyright (C) 2022 William S. Kish
3 |
4 | from pydantic import BaseModel
5 | from subprompt import SubPrompt
6 |
7 |
8 | class Jigit(BaseModel):
9 |
10 | id: str # unique ID for a particular Jigit
11 | prompt: SubPrompt # The subprompt to include to activate this Jigit
12 |
13 |
14 | def handler(self, completion: str):
15 | pass
16 |
17 |
18 | def process_completion(self, completion: str):
19 | if self.id in completion:
20 | self.handler(completion)
21 |
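22 | # Sketch of intended use (hypothetical example; the framework is unfinished):
23 | #   weather = Jigit(id="JIGIT-WEATHER", prompt=SubPrompt("Reply with JIGIT-WEATHER to request a weather report."))
24 | #   weather.process_completion(llm_output)  # calls weather.handler(...) if the id appears in the completion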
--------------------------------------------------------------------------------
/src/massgpt.py:
--------------------------------------------------------------------------------
1 | # MassGPT app
2 | #
3 | # Copyright (C) 2022 William S. Kish
4 |
5 | from sqlmodel import Session, select, delete
6 | from sentence_transformers import SentenceTransformer
7 | import urllib.parse
8 | from loguru import logger
9 | from db import engine
10 |
11 | import gpt3
12 |
13 | from exceptions import *
14 | from models import *
15 |
16 | from extract import url_to_text
17 | from subprompt import SubPrompt
18 |
19 |
20 | ###
21 | ### Various Specialized SubPrompts
22 | ###
23 |
24 | class MessageSubPrompt(SubPrompt):
25 | """
26 | SubPrompt Context for a user-generated message
27 | """
28 | MAX_TOKENS = 300
29 |
30 | @classmethod
31 | def from_msg(cls, msg: Message) -> "SubPrompt":
32 | # create user message specific subprompt
33 | text = f"user-{msg.user_id} wrote to MassGPT: {msg.text}"
34 | return MessageSubPrompt(text=text, max_tokens=MessageSubPrompt.MAX_TOKENS)
35 |
36 |
37 | class MessageResponseSubPrompt(SubPrompt):
38 | """
39 | SubPrompt Context for a user message and its model response (the response is currently omitted for experimentation)
40 | """
41 | @classmethod
42 | def from_msg_completion(cls, msp: MessageSubPrompt, comp: Completion) -> "SubPrompt":
43 | return msp # + f"MassGPT responded: {comp.completion}" # leave the model response out for experimentation
44 |
45 |
46 |
47 | class UrlSummarySubPrompt(SubPrompt):
48 | """
49 | SubPrompt Context for a user-requested URL Summary
50 | """
51 | @classmethod
52 | def from_summary(cls, user: User, text : str) -> "SubPrompt":
53 | text = f"User {user.id} posted a link with the following summary:\n{text}\n"
54 | # don't need to specify max _tokens here since the summary is a model output
55 | # that is regulated through the user_summary_limits
56 | return UrlSummarySubPrompt(text=text)
57 |
58 |
59 | class SummarySubPrompt(SubPrompt):
60 | """
61 | SubPrompt Context for a system-generated summary
62 | """
63 | @classmethod
64 | def from_msg(cls, text : str) -> "SubPrompt":
65 | text = f"Here is a summary of previous discussions for reference: {text}"
66 | # don't need to specify max _tokens here since the summary is a model output
67 | # that is regulated through the msg_summary_limits
68 | return SummarySubPrompt(text=text)
69 |
70 |
71 |
72 |
73 |
74 | class MassGPTMessageTask(gpt3.GPT3CompletionTask):
75 | """
76 | Generate message response completions from a dynamic history of recent messages plus the most recent user message
77 | """
78 | TEMPERATURE = 0.4
79 |
80 | # General Prompt Strategy:
81 | # Upon reception of message from a user 999, compose the following prompt
82 | # based on recent messages received from other users:
83 |
84 | PREPROMPT = SubPrompt( \
85 | """You are MassGPT, and this is a fun experiment. \
86 | You were built by Jiggy AI using OpenAI text-davinci-003. \
87 | Instruction: Different users are sending you messages. \
88 | They can not communicate with each other directly. \
89 | Any user to user message must be relayed through you. \
90 | Pass along any interesting message. \
91 | Try not to repeat yourself. \
92 | Ask users questions if you are not sure what to say. \
93 | If a user expresses interest in a topic discussed here, \
94 | respond to them based on what you read here. \
95 | Users have recently said the following to you:""")
96 |
97 |
98 | # "User 123 wrote: ABC is the greatest thing ever"
99 | # "User 234 wrote: ABC is cool but i like GGG more"
100 | # "User 345 wrote: DDD is the best ever!"
101 | # etc
102 | PENULTIMATE_PROMPT = SubPrompt("Instruction: Respond to the following user message considering the above context and Instruction:")
103 | # "User 999 wrote: What do folks think about ABC?" # End of Prompt
104 | FINAL_PROMPT = SubPrompt("MassGPT responded:")
105 | # Then send resulting llm completion back to user 999 in response to his message
106 |
107 | def __init__(self) -> "MassGPTMessageTask":
108 | limits = gpt3.CompletionLimits(min_prompt = 0,
109 | min_completion = 300,
110 | max_completion = 400)
111 |
112 | super().__init__(limits = limits,
113 | temperature = MassGPTMessageTask.TEMPERATURE,
114 | model = 'text-davinci-003')
115 |
116 |
117 | def completion(self,
118 | recent_msgs : MessageResponseSubPrompt,
119 | user_msg : MessageSubPrompt) -> Completion:
120 | """
121 | return completion for the provided subprompts
122 | """
123 | prompt = MassGPTMessageTask.PREPROMPT
124 | final_prompt = MassGPTMessageTask.PENULTIMATE_PROMPT
125 | final_prompt += user_msg
126 | final_prompt += MassGPTMessageTask.FINAL_PROMPT
127 |
128 | logger.info(f"overhead tokens: {(prompt + final_prompt).tokens}")
129 |
130 | available_tokens = self.max_prompt_tokens() - (prompt + final_prompt).tokens
131 | logger.info(f"available_tokens: {available_tokens}")
132 |         # assemble the list of most recent messages up to the available token limit
133 |         reversed_subs = []
134 |         # add previous message context
135 |         for sub in reversed(recent_msgs):
136 |             if available_tokens - sub.tokens - 1 < 0:
137 | break
138 | reversed_subs.append(sub)
139 | available_tokens -= sub.tokens + 1
140 |
141 | for sub in reversed(reversed_subs):
142 | prompt += sub
143 |
144 |         # final_prompt already contains the penultimate prompt, the user message, and FINAL_PROMPT
145 |         prompt += final_prompt
146 | 
147 |         logger.info(f"final prompt tokens: {prompt.tokens} max: {self.max_prompt_tokens()} chars: {len(prompt.text)}")
148 | 
149 |
150 | return super().completion(prompt)
151 |
152 |
153 |
154 |
155 | msg_response_task = MassGPTMessageTask()
156 |
157 |
158 |
159 |
160 | class UrlSummaryTask(gpt3.GPT3CompletionTask):
161 | """
162 | A Factory class to dynamically compose a prompt context for URL Summary task
163 | """
164 |
165 | # the PROMPT_PREFIX is prepended to the url content before sending to the language model
166 | SUMMARIZE_PROMPT_PREFIX = SubPrompt("Provide a detailed summary of the following web content, including what type of content it is (e.g. news article, essay, technical report, blog post, product documentation, content marketing, etc). If the content looks like an error message, respond 'content unavailable'. If there is anything controversial please highlight the controversy. If there is something surprising, unique, or clever, please highlight that as well:")
167 |
168 | # prompt prefix for Github Readme files
169 | GITHUB_PROMPT_PREFIX = SubPrompt("Provide a summary of the following github project readme file, including the purpose of the project, what problems it may be used to solve, and anything the author mentions that differentiates this project from others:")
170 |
171 | TEMPERATURE = 0.2
172 |
173 | def __init__(self) -> "UrlSummaryTask":
174 | limits = gpt3.CompletionLimits(min_prompt = 40,
175 | min_completion = 300,
176 | max_completion = 600)
177 |
178 | super().__init__(limits = limits,
179 | temperature = UrlSummaryTask.TEMPERATURE,
180 | model = 'text-davinci-003')
181 |
182 | def prefix(self, url):
183 | if urllib.parse.urlparse(url).netloc == 'github.com':
184 | return UrlSummaryTask.GITHUB_PROMPT_PREFIX
185 | else:
186 | return UrlSummaryTask.SUMMARIZE_PROMPT_PREFIX
187 |
188 |
189 | def completion(self, url: str, url_text : str) -> Completion:
190 | """
191 |         Return the summary completion for the given url and url_text.
192 |         The url is required to enable a host-specific prompt strategy.
193 |         For example, a different prompt is used to summarize github repos versus other web sites.
194 | """
195 | prompt = self.prefix(url) + url_text
196 | prompt.truncate(self.max_prompt_tokens())
197 | return super().completion(prompt)
198 |
199 |
200 |
201 |
202 | url_summary_task = UrlSummaryTask()
203 |
204 |
205 |
206 | ## Embedding Config
207 | ST_MODEL_NAME = 'multi-qa-mpnet-base-dot-v1'
208 | st_model = SentenceTransformer(ST_MODEL_NAME)
209 |
210 |
211 |
212 |
213 | class Context():
214 | """
215 | A context for assembling a large prompt context from recent user message subprompts
216 | """
217 | def __init__(self) -> "Context":
218 | self.tokens = 0
219 | self._sub_prompts = []
220 |
221 |     def push(self, sub_prompt : SubPrompt) -> None:
222 |         """
223 |         Push sub_prompt onto the beginning of the context.
224 |         Used to recreate the context in reverse order from a database select.
225 |         Raises MaximumTokenLimit when the prompt context limit is exceeded.
226 | """
227 | if self.tokens > msg_response_task.max_prompt_tokens():
228 | raise MaximumTokenLimit
229 | self._sub_prompts.insert(0, sub_prompt)
230 | self.tokens += sub_prompt.tokens
231 |
232 | def add(self, sub_prompt : SubPrompt) -> None:
233 | # add new prompt to end of sub_prompts
234 | self._sub_prompts.append(sub_prompt)
235 | self.tokens += sub_prompt.tokens
236 |         # remove the oldest subprompts if the prompt context limit is exceeded
237 | while self.tokens > msg_response_task.max_prompt_tokens():
238 | self.tokens -= self._sub_prompts.pop(0).tokens
239 |
240 | def sub_prompts(self) -> list[SubPrompt]:
241 | return self._sub_prompts
242 |
243 |
244 | ###
245 | ### maintain a single global context (for now)
246 | ###
247 | context = Context()
248 |
249 |
250 |
251 |
252 | def receive_message(user : User, text : str) -> str:
253 | """
254 | receive a message from the specified user.
255 | Return the message response
256 | """
257 | logger.info(f"message from {user.id} {user.first_name} {user.last_name}: {text}")
258 | with Session(engine) as session:
259 | # persist msg to database so we can regain recent msg context after pod restart
260 | msg = Message(text=text, user_id=user.id)
261 | session.add(msg)
262 | session.commit()
263 | session.refresh(msg)
264 |
265 | # embedding should move to background work queue
266 | t0 = time()
267 | vector = [float(x) for x in st_model.encode(msg.text)]
268 | embedding = Embedding(source = EmbeddingSource.message,
269 | source_id = msg.id,
270 | collection = ST_MODEL_NAME,
271 | model = ST_MODEL_NAME,
272 | vector = vector)
273 | logger.info(f"embedding dt: {time()-t0}")
274 | session.add(embedding)
275 | session.commit()
276 | session.refresh(msg)
277 |
278 | # build final aggregate prompt
279 | msg_subprompt = MessageSubPrompt.from_msg(msg)
280 |
281 | completion = msg_response_task.completion(context.sub_prompts(), msg_subprompt)
282 | logger.info(str(completion))
283 |
284 | # add the new user message to the global shared context
285 | rsp_subprompt = MessageResponseSubPrompt.from_msg_completion(msg_subprompt, completion)
286 | context.add(rsp_subprompt)
287 |
288 | # save response to database
289 | with Session(engine) as session:
290 | session.add(completion)
291 | session.commit()
292 | session.refresh(completion)
293 | session.add(Response(message_id=msg.id,
294 | completion_id=completion.id))
295 | session.commit()
296 | session.refresh(completion)
297 |
298 | return str(completion)
299 |
300 |
301 |
302 |
303 | def summarize_url(user : User, url : str) -> str:
304 | """
305 | Summarize a url for a user.
306 | Return the URL summary, adding the summary to the current context
307 | """
308 | # check if message contains a URL
309 | # if so extract and summarize the contents
310 | text = url_to_text(url)
311 | with Session(engine) as session:
312 | db_url = URL(url=url, user_id = user.id)
313 | session.add(db_url)
314 | session.commit()
315 | session.refresh(db_url)
316 | urltext = UrlText(url_id = db_url.id,
317 | mechanism = "url_to_text",
318 | text = text)
319 | session.add(urltext)
320 | session.commit()
321 | session.refresh(urltext)
322 |
323 | completion = url_summary_task.completion(url, text)
324 |
325 | summary_text = str(completion)
326 |
327 | with Session(engine) as session:
328 | session.add(completion)
329 | session.commit()
330 | session.refresh(completion)
331 |
332 | url_summary = UrlSummary(text_id = urltext.id,
333 | user_id = user.id,
334 | model = url_summary_task.model,
335 | prefix = url_summary_task.prefix(url).text,
336 | summary = summary_text)
337 | session.add(url_summary)
338 | session.commit()
339 | session.refresh(url_summary)
340 |
341 | # embedding should move to background work queue
342 | t0 = time()
343 | vector = [float(x) for x in st_model.encode(summary_text)]
344 | embedding = Embedding(source = EmbeddingSource.url_summary,
345 | source_id = url_summary.id,
346 | collection = ST_MODEL_NAME,
347 | model = ST_MODEL_NAME,
348 | vector = vector)
349 | logger.info(f"embedding dt: {time()-t0}")
350 | session.add(embedding)
351 |
352 | session.commit()
353 |
354 | # add the summary to recent context
355 | context.add(UrlSummarySubPrompt.from_summary(user=user, text=summary_text))
356 |
357 | logger.info(summary_text)
358 | # send the text summary to the user as FYI
359 | return summary_text
360 |
361 |
362 |
363 | def current_prompts() -> str:
364 |
365 | prompt = "Current prompt stack:\n\n"
366 | prompt += MassGPTMessageTask.PREPROMPT.text +"\n\n"
367 | prompt += "[Context as shown by /context]\n\n"
368 | prompt += MassGPTMessageTask.PENULTIMATE_PROMPT.text + "\n\n"
369 | prompt += "[Most recent message from user]"
370 |
371 | return prompt
372 |
373 |
374 |
375 | def current_context(max_len=4096):
376 |     """
377 |     generator yielding successive chunks of the current context text, each at most max_len characters
378 |     """
379 | size = 0
380 | text = ""
381 | for sub in context.sub_prompts():
382 | if len(sub.text) + size > max_len:
383 | yield text
384 | text = ""
385 | size = 0
386 | text += sub.text + "\n"
387 | size += len(sub.text) + 1
388 | if text:
389 | yield text
390 | yield f"{context.tokens} tokens"
391 |
392 |
393 |
394 | def load_context_from_db():
395 | last = ""
396 | logger.info('load_context_from_db')
397 | with Session(engine) as session:
398 | for msg in session.exec(select(Message).order_by(Message.id.desc())):
399 | try:
400 | msg_subprompt = MessageSubPrompt.from_msg(msg)
401 |             except MaximumTokenLimit:
402 |                 continue   # historic message too big for current limits
403 | if msg_subprompt.text == last: continue # basic dedup
404 | last = msg_subprompt.text
405 | resp = session.exec(select(Response).where(Response.message_id == msg.id)).first()
406 | if not resp:
407 | try:
408 | context.push(msg_subprompt)
409 | except MaximumTokenLimit:
410 | break
411 | continue
412 | comp = session.exec(select(Completion).where(Completion.id == resp.completion_id)).first()
413 | if not comp: continue
414 | rsp_subprompt = MessageResponseSubPrompt.from_msg_completion(msg_subprompt, comp)
415 | try:
416 | context.push(rsp_subprompt)
417 | except MaximumTokenLimit:
418 | break
419 |
420 | load_context_from_db()
421 |
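422 | # --- Usage sketch (illustrative only, not part of the deployed bot) ---
423 | # bot.py is assumed to wire incoming Telegram updates into receive_message()
424 | # and summarize_url() roughly like this; the user lookup below is hypothetical.
425 | #
426 | # if __name__ == "__main__":
427 | #     with Session(engine) as session:
428 | #         user = session.exec(select(User)).first()
429 | #     print(receive_message(user, "What are folks discussing today?"))
430 | #     print(summarize_url(user, "https://github.com/CodeLighthouse/s3-bucket"))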
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 |
2 | # MassGPT sqlmodel models
3 | # Copyright (C) 2022 William S. Kish
4 |
5 | from loguru import logger
6 | from typing import Optional, List
7 | from array import array
8 | import enum
9 | from sqlmodel import Field, SQLModel, Column, ARRAY, Float, Enum
10 | from sqlalchemy import BigInteger
11 | from pydantic import BaseModel, ValidationError, validator
12 | from pydantic import condecimal
13 | from time import time
14 |
15 | timestamp = condecimal(max_digits=14, decimal_places=3) # unix epoch timestamp decimal to millisecond precision
16 |
17 |
18 | class UserStatus(str, enum.Enum):
19 | enabled = "enabled"
20 | disabled = "disabled"
21 |
22 |
23 | class User(SQLModel, table=True):
24 | id: int = Field(primary_key=True, description='Unique user ID')
25 |     created_at: timestamp = Field(default_factory=time, description='The epoch timestamp when the User was created.')
26 | username: Optional[str] = Field(default=None, index=True, description="Username")
27 | first_name: Optional[str] = Field(description="User's first name")
28 | last_name: Optional[str] = Field(description="User's last name")
29 | auth0_id: Optional[str] = Field(index=True, description='Auth0 user_id')
30 | telegram_id: Optional[int] = Field(sa_column=Column(BigInteger()), index=True, description='Telegram User ID')
31 | telegram_is_bot: Optional[bool] = Field(description="is_bot from telegram")
32 | telegram_is_premium: Optional[bool] = Field(description="is_premium from telegram")
33 | telegram_language_code: Optional[str] = Field(description="language_code from telegram")
34 |
35 |
36 |
37 |
38 | class Message(SQLModel, table=True):
39 | id: int = Field(primary_key=True, description='Our unique message id')
40 | text: str = Field(max_length=4096, description='The message text')
41 | user_id: int = Field(index=True, foreign_key='user.id', description='The user who sent the Message')
42 | created_at: timestamp = Field(index=True, default_factory=time, description='The epoch timestamp when the Message was created.')
43 |
44 |
45 |
46 |
47 | class Response(SQLModel, table=True):
48 |     id: int = Field(primary_key=True, description='Our unique response id')
49 |     message_id: int = Field(index=True, foreign_key="message.id", description="The message to which this is a response.")
50 |     created_at: timestamp = Field(index=True, default_factory=time, description='The epoch timestamp when the Response was created.')
51 | completion_id: int = Field(foreign_key="completion.id", description="associated completion that provided the response text")
52 |
53 |
54 | class URL(SQLModel, table=True):
55 | id: int = Field(primary_key=True, description='Unique ID')
56 | url: str = Field(max_length=2048, description='The actual supplied URL')
57 | user_id: int = Field(index=True, foreign_key='user.id', description='The user who sent the URL')
58 | created_at: timestamp = Field(default_factory=time, description='The epoch timestamp when this was created.')
59 |
60 |
61 | class UrlText(SQLModel, table=True):
62 | id: int = Field(primary_key=True, description="The text unique id.")
63 |     url_id: int = Field(index=True, foreign_key="url.id", description="The url this text was extracted from.")
64 |     mechanism: str = Field(description="identifies which software mechanism extracted the text from the url")
65 | created_at: timestamp = Field(default_factory=time, description='The epoch timestamp when the url was crawled.')
66 | text: str = Field(max_length=65535, description="The readable text we managed to extract from the Url.")
67 | content: Optional[str] = Field(max_length=65535, description="original html content")
68 | content_type: Optional[str] = Field(description="content type from http")
69 |
70 |
71 | class UrlSummary(SQLModel, table=True):
72 | id: int = Field(primary_key=True, description="The URL Summary unique id.")
73 | text_id: int = Field(index=True, foreign_key="urltext.id", description="The UrlText used to create the summary.")
74 | model: str = Field(description="The model used to produce this summary.")
75 | prefix: str = Field(max_length=8192, description="The prompt prefix used to create the summary.")
76 | summary: str = Field(max_length=8192, description="The model summary of the UrlText.")
77 | created_at: timestamp = Field(default_factory=time, description='The epoch timestamp when this was created.')
78 | #completion_id: int = Field(foreign_key="completion.id", description="associated completion that provided the response text")
79 |
80 |
81 | class ChatSummary(SQLModel, table=True):
82 | id: int = Field(primary_key=True, description="The ChatSummary unique id.")
83 | created_at: timestamp = Field(index=True, default_factory=time, description='The epoch timestamp when this was created.')
84 | completion_id: int = Field(foreign_key="completion.id", description="associated completion that provided the response text")
85 |
86 |
87 |
88 |
89 |
90 | class Completion(SQLModel, table=True):
91 | """
92 | Model prompt + completion
93 |     A low-level record of model prompts and completions.
94 | """
95 | id: int = Field(primary_key=True, description="The completion unique id.")
96 | model: str = Field(description="model engine")
97 |     temperature: int = Field(description="configured temperature") # XXX convert to float
98 | prompt: str = Field(max_length=65535, description="The prompt used to generate the completion.")
99 | completion: str = Field(max_length=65535, description="The completion received from the model.")
100 | created_at: timestamp = Field(default_factory=time, description='The epoch timestamp when this was created.')
101 |
102 | def __str__(self):
103 | """
104 | str(Completion) returns the completion text for convenience
105 | """
106 | return self.completion
107 |
108 |
109 |
110 | class EmbeddingSource(str, enum.Enum):
111 | """
112 | The source of the text for embedding
113 | """
114 | message = "message"
115 | chat_summary = "chat_summary"
116 | url_summary = "url_summary"
117 | hn_story_summary = "hn_story_summary"
118 | hn_story_title = "hn_story_title"
119 |
120 |
121 |
122 |
123 | class Embedding(SQLModel, table=True):
124 | id: int = Field(default=None,
125 | primary_key=True,
126 | description='Unique database identifier for a given embedding vector.')
127 | collection: str = Field(index=True, description='The name of the collection that holds this vector.')
128 | source: EmbeddingSource = Field(sa_column=Column(Enum(EmbeddingSource)),
129 | description='The source of this embedding')
130 | source_id: int = Field(index=True, description='The message/chat/url_summary id that produced this embedding.')
131 | model: str = Field(description="The model used to produce this embedding.")
132 | vector: List[float] = Field(sa_column=Column(ARRAY(Float(24))),
133 | description='The embedding vector.')
134 |
135 |
136 | class HnswIndex(SQLModel, table=True):
137 |
138 | id: int = Field(default=None,
139 | primary_key=True,
140 | description='Unique database identifier for a given index')
141 | collection: str = Field(index=True, description='The name of the collection that holds this vector.')
142 |     count: int = Field(default=0, description="The number of vectors included in the index (the number of vectors in the collection at the time of the index build).")
143 | objkey: str = Field(description='The index key name in object store')
144 | created_at: timestamp = Field(default_factory=time, description='The epoch timestamp when the index was requested to be created.')
145 |
--------------------------------------------------------------------------------
/src/pdf_text.py:
--------------------------------------------------------------------------------
1 |
2 | from loguru import logger
3 | from io import BytesIO
4 | from pdfminer.high_level import extract_pages
5 | from pdfminer.layout import LTTextContainer
6 |
7 |
8 | def pdf_text(pdf_bytes):
9 | """
10 | extract text from the pdf_bytes and return it as a single string
11 | """
12 | text = ""
13 | for page_layout in extract_pages(BytesIO(pdf_bytes)):
14 | for element in page_layout:
15 | if isinstance(element, LTTextContainer):
16 | for text_line in element:
17 | text += text_line.get_text().rstrip() + " "
18 |
19 | logger.info("pdf_text: "+text)
20 | return text
21 |
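22 | # Minimal usage sketch (illustrative; assumes a local file "example.pdf" exists):
23 | # with open("example.pdf", "rb") as f:
24 | #     print(pdf_text(f.read()))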
--------------------------------------------------------------------------------
/src/querytest.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 | from gpt3 import GPT3CompletionTask, CompletionLimits
3 | from exceptions import *
4 | from models import Completion
5 |
6 | from subprompt import SubPrompt
7 |
8 | import wikipedia
9 |
10 |
11 | class SubjectQueryTask(GPT3CompletionTask):
12 | def __init__(self,
13 | min_completion=100,
14 | max_completion=100) -> "SubjectQueryTask":
15 |
16 | limits = CompletionLimits(min_prompt = 0,
17 | min_completion = min_completion,
18 | max_completion = max_completion)
19 |
20 |
21 | super().__init__(limits = limits,
22 | temperature = .1,
23 | model = 'text-davinci-003')
24 |
25 | def completion(self,
26 | query : str) -> Completion:
27 |
28 |         prompt = SubPrompt("What is the primary topic of this question:")
29 | prompt += query
30 | print(prompt)
31 | resp = super().completion(prompt)
32 | print(str(resp))
33 | return resp
34 |
35 |
36 |
37 | class WikipediaQueryTask(GPT3CompletionTask):
38 | """
39 |     Answer a query using only the content of a preloaded Wikipedia page
40 | """
41 | TEMPERATURE = 0.0
42 |
43 | def __init__(self,
44 | wikipedia_page:str,
45 | min_completion=100,
46 | max_completion=100) -> "WikipediaQueryTask":
47 |
48 | limits = CompletionLimits(min_prompt = 0,
49 | min_completion = min_completion,
50 | max_completion = max_completion)
51 |
52 | self.page = wikipedia_page
53 | page = wikipedia.page(self.page)
54 | self.content = page.content
55 |
56 | super().__init__(limits = limits,
57 | temperature = 0,
58 | model = 'text-davinci-003')
59 |
60 |
61 | def completion(self,
62 | query : str) -> Completion:
63 | """
64 |         return completion for the provided query
65 | """
66 |
67 | prompt = SubPrompt(f"Respond to the final question regarding this info on '{self.page}' using only the information provided here in the following text:")
68 | final = SubPrompt(query)
69 | available_tokens = self.max_prompt_tokens() - (prompt + final).tokens - 1
70 | article = SubPrompt(self.content, max_tokens=available_tokens, truncate=True)
71 | prompt += article
72 | prompt += final
73 | print(prompt)
74 | resp = super().completion(prompt)
75 | print(str(resp))
76 | return resp
77 |
78 | query_task = WikipediaQueryTask("rocket engine", 200, 400)
79 |
80 |
81 |
82 |
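83 | # Example run against the preloaded article (illustrative; requires an OpenAI
84 | # API key and network access to Wikipedia at import time):
85 | # query_task.completion("How does a staged combustion cycle improve efficiency?")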
--------------------------------------------------------------------------------
/src/results.txt:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | gpt-3.5-turbo precision 93.3% recall 71.3% latency 10.8 s
5 | gpt-4 precision 100.0% recall 100.0% latency 11.9 s
6 |
7 |
8 |
--------------------------------------------------------------------------------
/src/s3.py:
--------------------------------------------------------------------------------
1 | import s3_bucket as S3
2 | import boto3
3 | import os
4 | from botocore.exceptions import ClientError
5 | import logging
6 | BUCKET_NAME = 'jiggy-assets'
7 | ENDPOINT_URL = os.environ.get("JIGGY_STORAGE_ENDPOINT_URL", "https://us-southeast-1.linodeobjects.com")
8 |
9 | STORAGE_KEY_ID = os.environ['JIGGY_STORAGE_KEY_ID']
10 | STORAGE_SECRET_KEY = os.environ['JIGGY_STORAGE_KEY_SECRET']
11 |
12 |
13 | S3.Bucket.prepare(STORAGE_KEY_ID,
14 | STORAGE_SECRET_KEY,
15 | endpoint_url=ENDPOINT_URL)
16 |
17 |
18 | bucket = S3.Bucket(BUCKET_NAME)
19 |
20 |
21 | linode_obj_config = {
22 | "aws_access_key_id": STORAGE_KEY_ID,
23 | "aws_secret_access_key": STORAGE_SECRET_KEY,
24 | "endpoint_url": ENDPOINT_URL}
25 |
26 |
27 | s3_client = boto3.client('s3', **linode_obj_config)
28 |
29 |
30 | def create_presigned_url(object_name, expiration=300):
31 | """
32 | Generate a presigned URL to share an S3 object
33 | :param object_name: string
34 | :param expiration: Time in seconds for the presigned URL to remain valid
35 | :return: Presigned URL as string. If error, returns None.
36 | """
37 |
38 | # Generate a presigned URL for the S3 object
39 | try:
40 | response = s3_client.generate_presigned_url('get_object',
41 | Params={'Bucket': BUCKET_NAME,
42 | 'Key': object_name},
43 | ExpiresIn=expiration)
44 | except ClientError as e:
45 | logging.error(e)
46 | return None
47 |
48 | # The response contains the presigned URL
49 | return response
50 |
51 |
52 |
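53 | # Example (illustrative; the object key here is hypothetical):
54 | # url = create_presigned_url("indexes/example.hnsf", expiration=600)
55 | # if url:
56 | #     print(url)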
--------------------------------------------------------------------------------
/src/s3_bucket/README.md:
--------------------------------------------------------------------------------
1 | forked from https://github.com/CodeLighthouse/s3-bucket
2 |
3 | awaiting https://github.com/CodeLighthouse/s3-bucket/pull/4
4 |
--------------------------------------------------------------------------------
/src/s3_bucket/__init__.py:
--------------------------------------------------------------------------------
1 | from .bucket import Bucket
2 | from . import exceptions as Exceptions
3 |
4 | __all__ = [
5 | "Bucket",
6 | "Exceptions"
7 | ]
--------------------------------------------------------------------------------
/src/s3_bucket/bucket.py:
--------------------------------------------------------------------------------
1 | import boto3
2 | from botocore.exceptions import ClientError
3 | from typing import Union, Dict, Tuple
4 | from . import exceptions
5 |
6 |
7 | class Bucket:
8 | """
9 | CLASS THAT HANDLES S3 BUCKET TRANSACTIONS. ABSTRACTS AWAY BOTO3'S ARCANE BS.
10 | HANDLES BOTO3'S EXCEPTIONS WITH CUSTOM EXCEPTION CLASSES TO MAKE CODE USABLE
11 | """
12 | _AWS_ACCESS_KEY_ID = None
13 | _AWS_SECRET_ACCESS_KEY = None
14 | _ENDPOINT_URL = None
15 |
16 | def __init__(self, bucket_name: str):
17 |
18 | # ENSURE THE PACKAGE HAS BEEN CONFIGURED WITH THE APPROPRIATE ACCESS KEYS
19 | if not Bucket._AWS_ACCESS_KEY_ID or not Bucket._AWS_SECRET_ACCESS_KEY:
20 | raise TypeError("AWS access key ID and AWS Secret access key must be configured. They're class variables"
21 | "for the Bucket class. You can set them by calling Bucket.prepare(access_key, secret_key)")
22 | self.bucket_name = bucket_name
23 |
24 | @classmethod
25 | def prepare(cls, aws_access_key_id: str, aws_secret_access_key: str, aws_session_token=None, endpoint_url=None):
26 | cls._AWS_ACCESS_KEY_ID = aws_access_key_id
27 | cls._AWS_SECRET_ACCESS_KEY = aws_secret_access_key
28 | cls._AWS_SESSION_TOKEN = aws_session_token
29 | cls._ENDPOINT_URL = endpoint_url
30 |
31 | @staticmethod
32 | def _get_boto3_resource():
33 | """
34 | GET AND CONFIGURE THE BOTO3 S3 API RESOURCE. THIS IS A "PRIVATE" METHOD
35 | """
36 | # CREATE A "SESSION" WITH BOTO3
37 |
38 | _session = boto3.Session(
39 | aws_access_key_id=Bucket._AWS_ACCESS_KEY_ID,
40 | aws_secret_access_key=Bucket._AWS_SECRET_ACCESS_KEY,
41 | aws_session_token=Bucket._AWS_SESSION_TOKEN
42 | )
43 |
44 | # CREATE S3 RESOURCE
45 | resource = _session.resource('s3', endpoint_url=Bucket._ENDPOINT_URL)
46 |
47 | return resource
48 |
49 | def _handle_boto3_client_error(self, e: ClientError, key=None):
50 | """
51 | HANDLE BOTO3'S CLIENT ERROR. BOTO3 ONLY RETURNS ONE TYPE OF EXCEPTION, WITH DIFFERENT KEYS AND MESSAGES FOR
52 | DIFFERENT TYPES OF ERRORS. REFER TO EXCEPTIONS.PY FOR EXPLANATION
53 |
54 | :param e: THE CLIENTERROR TO HANDLE
55 | :param key: THE KEY OF THE OBJECT WE'RE DEALING WITH. OPTIONAL, DEFAULT IS None
56 | """
57 | error_code: str = e.response.get('Error').get('Code')
58 |
59 | print(e.response)
60 |
61 | if error_code == 'AccessDenied':
62 | raise exceptions.BucketAccessDenied(self.bucket_name)
63 | elif error_code == 'NoSuchBucket':
64 | raise exceptions.NoSuchBucket(self.bucket_name)
65 | elif error_code == 'NoSuchKey':
66 | raise exceptions.NoSuchKey(key, self.bucket_name)
67 | else:
68 | raise exceptions.UnknownBucketException(self.bucket_name, e)
69 |
70 |     def get(self, key: str, response_content_type: str = None) -> Tuple[bytes, Dict]:
71 | """
72 | GET AN OBJECT FROM THE BUCKET AND RETURN A BYTES TYPE THAT MUST BE DECODED ACCORDING TO THE ENCODING TYPE
73 |
74 | :param key: THE KEY IN S3 OF THE OBJECT TO GET
75 | :param response_content_type: THE CONTENT TYPE TO ENFORCE ON THE RESPONSE. MAY BE USEFUL IN SOME CASES
76 | :return: A TWO-TUPLE: (1) A BYTES OBJECT THAT MUST BE DECODED DEPENDING ON HOW IT WAS ENCODED.
77 | LEFT UP TO MIDDLEWARE TO DETERMINE AND (2) A DICT CONTAINING METADATA ON WHEN THE OBJECT WAS STORED
78 | """
79 |
80 | # GET S3 resource
81 | resource = Bucket._get_boto3_resource()
82 | s3_bucket = resource.Object(self.bucket_name, key)
83 |
84 | try:
85 | if response_content_type:
86 | response = s3_bucket.get(ResponseContentType=response_content_type)
87 | else:
88 | response = s3_bucket.get()
89 |
90 | data = response.get('Body').read() # THE OBJECT DATA STORED
91 | metadata: Dict = response.get('Metadata') # METADATA STORED WITH THE OBJECT
92 | return data, metadata
93 |
94 |         # BOTO RAISES ONLY ONE ERROR TYPE THAT THEN MUST BE PROCESSED TO GET THE CODE
95 | except ClientError as e:
96 | self._handle_boto3_client_error(e, key=key)
97 |
98 | def put(self, key: str, data: Union[str, bytes], content_type: str = None, metadata: Dict = {}) -> Dict:
99 | """
100 | PUT AN OBJECT INTO THE BUCKET
101 |
102 | :param key: THE KEY TO STORE THE OBJECT UNDER
103 | :param data: THE DATA TO STORE. CAN BE BYTES OR STRING
104 | :param content_type: THE MIME TYPE TO STORE THE DATA AS. MAY BE IMPORTANT FOR BINARY DATA
105 | :param metadata: A DICT CONTAINING METADATA TO STORE WITH THE OBJECT. EXAMPLES INCLUDE TIMESTAMP OR
106 | ORGANIZATION NAME. VALUES _MUST_ BE STRINGS.
107 | :return: A DICT CONTAINING THE RESPONSE FROM S3. IF AN EXCEPTION IS NOT THROWN, ASSUME PUT OPERATION WAS SUCCESSFUL.
108 | """
109 |
110 | # GET RESOURCE
111 | resource = Bucket._get_boto3_resource()
112 | s3_bucket = resource.Object(self.bucket_name, key)
113 |
114 | # PUT IT
115 | try:
116 | if content_type:
117 | response = s3_bucket.put(
118 | Body=data,
119 | ContentType=content_type,
120 | Key=key,
121 | Metadata=metadata
122 | )
123 | else:
124 | response = s3_bucket.put(
125 | Body=data,
126 | Key=key,
127 | Metadata=metadata
128 | )
129 | return response
130 |
131 |         # BOTO RAISES ONLY ONE ERROR TYPE THAT THEN MUST BE PROCESSED TO GET THE CODE
132 | except ClientError as e:
133 | self._handle_boto3_client_error(e, key=key)
134 |
135 | def delete(self, key: str) -> Dict:
136 | """
137 | DELETE A SPECIFIED OBJECT FROM THE BUCKET
138 |
139 | :param key: A STRING THAT IS THE OBJECT'S KEY IDENTIFIER IN S3
140 | :return: THE RESPONSE FROM S3. IF NO EXCEPTION WAS THROWN, ASSUME DELETE OPERATION WAS SUCCESSFUL
141 | """
142 | # GET S3 RESOURCE
143 | resource = Bucket._get_boto3_resource()
144 | s3_bucket = resource.Object(self.bucket_name, key)
145 |
146 | try:
147 | response = s3_bucket.delete()
148 | return response
149 |
150 |         # BOTO RAISES ONLY ONE ERROR TYPE THAT THEN MUST BE PROCESSED TO GET THE CODE
151 | except ClientError as e:
152 | self._handle_boto3_client_error(e, key=key)
153 |
154 | def upload_file(self, local_filepath: str, key: str) -> Dict:
155 | """
156 | UPLOAD A LOCAL FILE TO THE BUCKET. TRANSPARENTLY MANAGES MULTIPART UPLOADS.
157 |
158 | :param local_filepath: THE ABSOLUTE FILEPATH OF THE FILE TO STORE
159 | :param key: THE KEY TO STORE THE FILE UNDER IN THE BUCKET
160 | :return: A DICT CONTAINING THE RESPONSE FROM S3. IF NO EXCEPTION IS THROWN, ASSUME OPERATION
161 | COMPLETED SUCCESSFULLY
162 | """
163 |
164 | # GET S3 RESOURCE
165 | resource = Bucket._get_boto3_resource()
166 | s3_bucket = resource.Object(self.bucket_name, key)
167 |
168 | try:
169 | response = s3_bucket.upload_file(local_filepath)
170 | return response
171 |
172 |         # BOTO RAISES ONLY ONE ERROR TYPE THAT THEN MUST BE PROCESSED TO GET THE CODE
173 | except ClientError as e:
174 | self._handle_boto3_client_error(e, key=key)
175 |
176 | def download_file(self, key: str, local_filepath: str) -> Dict:
177 | """
178 | DOWNLOAD AN OBJECT FROM THE BUCKET TO A LOCAL FILE. TRANSPARENTLY MANAGES MULTIPART DOWNLOADS.
179 |
180 | :param key: THE KEY THAT IDENTIFIES THE OBJECT TO DOWNLOAD
181 | :param local_filepath: THE ABSOLUTE FILEPATH TO STORE THE OBJECT TO
182 | :return: A DICT CONTAINING THE RESPONSE FROM S3. IF NO EXCEPTION IS THROWN, ASSUME OPERATION
183 | COMPLETED SUCCESSFULLY
184 | """
185 | # GET S3 RESOURCE
186 | resource = Bucket._get_boto3_resource()
187 | s3_bucket = resource.Object(self.bucket_name, key)
188 |
189 | try:
190 | response = s3_bucket.download_file(local_filepath)
191 | return response
192 |
193 |         # BOTO RAISES ONLY ONE ERROR TYPE THAT THEN MUST BE PROCESSED TO GET THE CODE
194 | except ClientError as e:
195 | self._handle_boto3_client_error(e, key=key)
196 |
--------------------------------------------------------------------------------
/src/s3_bucket/exceptions.py:
--------------------------------------------------------------------------------
1 | from botocore.exceptions import ClientError
2 |
3 | """
4 | CUSTOM EXCEPTIONS - DESIGNED AS MORE USEFUL WRAPPERS TO BOTOCORE'S ARCANE BS EXCUSE-FOR-EXCEPTIONS
5 | BASICALLY ENCAPSULATE BOTO EXCEPTIONS W/MORE INFORMATION.
6 | CLIENT CODE UNLIKELY TO KNOW HOW TO CATCH BOTOCORE EXCEPTIONS, BUT THESE ARE EXPOSED THROUGH S3 CLASS SO EZ
7 | """
8 |
9 |
10 | class BucketException(Exception):
11 | """
12 | PARENT CLASS THAT ENSURES THAT ALL BUCKET ERRORS ARE CONSISTENT
13 | """
14 |
15 | def __init__(self, message, bucket):
16 | self.bucket = bucket
17 | self.message = f'{message}'
18 | super().__init__(self.message)
19 |
20 |
21 | class NoSuchKey(BucketException):
22 | """
23 | RAISED IF YOU TRY TO ACCESS A NON-EXISTENT OBJECT, SINCE IT HAS MOST LIKELY EXPIRED
24 | """
25 |
26 | def __init__(self, key, bucket):
27 | self.key = key
28 | self.bucket = bucket
29 | self.message = f'No object in bucket {bucket} matches {key}. Has it expired?'
30 | super().__init__(self.message, self.bucket)
31 |
32 |
33 | class NoSuchBucket(BucketException):
34 | """
35 | RAISED IF YOU TRY TO ACCESS A NONEXISTENT BUCKET
36 | """
37 |
38 | def __init__(self, bucket_name):
39 | self.bucket = bucket_name
40 | self.message = f'Bucket {bucket_name} does not exist!'
41 | super().__init__(self.message, self.bucket)
42 |
43 |
44 | class BucketAccessDenied(BucketException):
45 | """
46 | RAISED IF ACCESS TO A BUCKET IS DENIED - LIKELY BECAUSE IT DOESN'T EXIST
47 | """
48 |
49 | def __init__(self, bucket_name):
50 | self.bucket = bucket_name
51 | self.message = f'Unable to access bucket {self.bucket}. Does it exist?'
52 |
53 | super().__init__(self.message, self.bucket)
54 |
55 |
56 | class UnknownBucketException(BucketException):
57 | """
58 | RAISED IF AN UNKNOWN S3 EXCEPTION OCCURS
59 | """
60 |
61 | def __init__(self, bucket_name, e: ClientError):
62 | self.bucket = bucket_name
63 | error_code: str = e.response.get('Error').get('Code')
64 | error_message: str = e.response.get('Error').get('Message')
65 | self.message = f'Unknown Bucket Exception {error_code}: {error_message}'
66 | super().__init__(self.message, self.bucket)
67 |
--------------------------------------------------------------------------------
/src/subprompt.py:
--------------------------------------------------------------------------------
1 | #
2 | # SubPrompt class that assists with keeping track of token counts and
3 | # efficiently combining SubPrompts
4 | #
5 | # Copyright (C) 2022 William S. Kish
6 |
7 | import os
8 | from pydantic import BaseModel, Field
9 | from tokenizer import token_len
10 | from typing import Optional
11 | from exceptions import *
12 |
13 |
14 | # We have found that marking truncated text as truncated stops the model from trying
15 | # to complete the missing text instead of summarizing it as requested
16 | TRUNCATED = ""
17 | TRUNCATED_LEN = token_len(TRUNCATED)
18 |
19 | class SubPrompt:
20 | """
21 | A SubPrompt is a text string and associated token count for the string
22 |
23 | len(SubPrompt) returns the length of the SubPrompt in tokens
24 |
25 | SubPrompt1 + SubPrompt2 returns a new subprompt which contains the
26 | concatenated text of the 2 subprompts separated by "\n"
27 | and a mostly accurate token count.
28 |
29 | The combined token count is estimated (not computed) so can sometimes overestimate the
30 | actual token count by 1 token. Tests on random strings show this occurs less
31 | than 1% of the time.
32 | """
33 |
34 | def truncate(self, max_tokens, precise=False):
35 | if precise == True:
36 | raise Exception("precise truncation is not yet implemented")
37 | # crudely truncate longer texts to get it back down to approximately the target max_tokens
38 | # TODO: find precise truncation point using multiple calls to token_len()
39 | # TODO: consider option to truncating at sentence boundaries.
40 | if self.tokens <= max_tokens:
41 | return
42 | split_point = int(len(self.text) * (max_tokens-TRUNCATED_LEN) / self.tokens)
43 | while not self.text[split_point].isspace():
44 | split_point -= 1
45 | self.text = self.text[:split_point] + TRUNCATED
46 | self.tokens = token_len(self.text)
47 | if self.tokens > max_tokens:
48 | self.truncate(max_tokens*.95)
49 |
50 |
51 | def __init__(self, text: str, max_tokens=None, truncate=False, precise=False, tokens=None) -> "SubPrompt":
52 | """
53 | Create a subprompt from the specified string.
54 | If max_tokens is specified, then the SubPrompt will be limited to max_tokens.
55 | The behavior when max_tokens is exceeded is controlled by truncate.
56 | MaximumTokenLimit exception raised if the text exceeds the specified max_tokens and truncate is False.
57 | If truncate is true then the text will be truncated to meet the limit.
58 | If precise is False then the truncation will be very quick but only approximate.
59 | If precise is True then the truncation will be slower but guaranteed to meet the max_tokens limit.
60 | """
61 | if precise == True:
62 | raise Exception("precise truncation is not yet implemented")
63 | if tokens is None:
64 | tokens = token_len(text)
65 | self.text = text
66 | self.tokens = tokens
67 | if max_tokens is not None and tokens > max_tokens:
68 | if not truncate:
69 | raise MaximumTokenLimit
70 | self.truncate(max_tokens, precise=precise)
71 |
72 |
73 |
74 | def __len__(self) -> int:
75 | return self.tokens
76 |
77 |
78 | def __add__(self, o) -> "SubPrompt":
79 | """
80 | Combine the token strings and token counts with a newline character in between them.
81 | This will occasionally overestimate the combined token count by 1 token,
82 | which is acceptable for our intended use.
83 | """
84 | if isinstance(o, str):
85 | o = SubPrompt(o)
86 |
87 | return SubPrompt(text = self.text + "\n" + o.text,
88 | tokens = self.tokens + 1 + o.tokens)
89 |
90 | def __str__(self):
91 | return self.text
92 |
93 |
94 |
95 | if __name__ == "__main__":
96 | """
97 | test with random strings
98 | """
99 | from string import ascii_lowercase, whitespace, digits
100 | from random import sample, randint
101 |
102 | chars = ascii_lowercase + whitespace + digits
103 |
104 | def randstring(n):
105 | return "".join([sample(chars, 1)[0] for i in range(n)])
106 |
107 | count = 0
108 | for i in range(100000):
109 | c1 = randstring(randint(20, 100))
110 | c2 = randstring(randint(20, 100))
111 | print("c1", c1)
112 | print("c2", c2)
113 | sp1 = SubPrompt(c1)
114 | sp2 = SubPrompt(c2)
115 | print("sp1", sp1)
116 | print("sp2", sp2)
117 |
118 | print("len sp1:", len(sp1))
119 | print("len sp2:", len(sp2))
120 | #assert(len(sp1) == token_len(sp1.text))
121 | #assert(len(sp2) == token_len(sp2.text))
122 |
123 | sp3 = sp1 + sp2
124 | print("sp3", sp3)
125 |
126 | print("len sp3:", len(sp3))
127 | sp3len = token_len(sp3.text)
128 | print(sp3len)
129 | if len(sp3) != sp3len:
130 | count += 1
131 | assert(len(sp3) >= sp3len)
132 | print(count, "errors")
133 | # 651 errors out of 100000 on a typical run
134 |
135 |
136 |
137 |
138 |
139 |
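140 | # Quick truncate() illustration (a sketch; assumes the gpt2 tokenizer is available):
141 | # sp = SubPrompt("the quick brown fox jumps over the lazy dog " * 100)
142 | # sp.truncate(50)
143 | # assert sp.tokens <= 50   # approximate truncation converges below the limit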
--------------------------------------------------------------------------------
/src/tokenizer.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from transformers import GPT2Tokenizer
4 |
5 |
6 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
7 |
8 | def token_len(text : str) -> int:
9 | """
10 | return number of tokens in text per gpt2 tokenizer
11 | """
12 | return len(tokenizer(text)['input_ids'])
13 |
14 |
15 |
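16 | # Example (illustrative): token counts differ from character counts.
17 | # token_len("hello world") -> 2 tokens, while len("hello world") -> 11 chars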
--------------------------------------------------------------------------------