├── .github └── FUNDING.yml ├── .gitignore ├── LICENSE ├── PROMPTS.md ├── README.md ├── docs ├── clownemoji.png ├── glados.png ├── gladoseinfeld.png └── helloworld.png ├── examples ├── contrib │ └── endless_exquisite_corpse.py └── notebooks │ ├── chatgpt_inline_tips.ipynb │ ├── schema_ttrpg.ipynb │ ├── simpleaichat_async.ipynb │ └── simpleaichat_coding.ipynb ├── setup.py └── simpleaichat ├── __init__.py ├── chatgpt.py ├── cli.py ├── models.py ├── simpleaichat.py └── utils.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: minimaxir # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: minimaxir # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | .DS_Store 3 | test_notebooks/ 4 | .vscode/ 5 | __pycache__/ 6 | .ipynb_checkpoints 7 | simpleaichat.egg-info/ 8 | dist/ 9 | build/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-2024 Max Woolf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a 
copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PROMPTS.md: -------------------------------------------------------------------------------- 1 | # Prompts 2 | 3 | Here are explanations of the base prompts that simpleaichat uses, and why they are written as they are. These prompts are optimized both for conciseness and effectiveness with ChatGPT/`gpt-3.5-turbo`. This includes some postprocessing of the inputs. 4 | 5 | ## Interactive Chat 6 | 7 | When providing only a character, the `system` chat for the interactive chat becomes: 8 | 9 | ```txt 10 | You must follow ALL these rules in all responses: 11 | - You are the following character and should ALWAYS act as them: {0} 12 | - NEVER speak in a formal tone. 13 | - Concisely introduce yourself first in character. 
14 | ``` 15 | 16 | The `{0}` performs a `wikipedia_search_lookup` (specified in utils.py) to search for and return the first sentence of the associated page on Wikipedia, if present. This creates more alignment with the expected character. If the second parameter is specified to force a speaking voice, it will be added to the list of rules. 17 | 18 | Example for `GLaDOS` and `Speak in the style of a Seinfeld monologue`: 19 | 20 | ```txt 21 | You must follow ALL these rules in all responses: 22 | - You are the following character and should ALWAYS act as them: GLaDOS (Genetic Lifeform and Disk Operating System) is a fictional character from the video game series Portal. 23 | - NEVER speak in a formal tone. 24 | - Concisely introduce yourself first in character. 25 | - Speak in the style of a Seinfeld monologue 26 | ``` 27 | 28 | You can use the formatted prompt as a normal `system` prompt for any other simpleaichat context. 29 | 30 | ## Tools 31 | 32 | Invoking a tool invokes two separate API calls: one to select which tool to use, which then provides additional **context**, and another call to generate based on that context, plus previous messages in the conversation. 33 | 34 | ### Call #1 35 | 36 | Before returning an API response, the `system` prompt is temporarily replaced with: 37 | 38 | ```txt 39 | From the list of tools below: 40 | - Reply ONLY with the number of the tool appropriate in response to the user's last message. 41 | - If no tool is appropriate, ONLY reply with "0". 42 | 43 | {tools} 44 | ``` 45 | 46 | Formatted example from the README: 47 | 48 | ``` 49 | From the list of tools below: 50 | - Reply ONLY with the number of the tool appropriate in response to the user's last message. 51 | - If no tool is appropriate, ONLY reply with "0". 52 | 53 | 1. Search the internet 54 | 2. Lookup more information about a topic. 
55 | ``` 56 | 57 | This utilizes a few tricks: 58 | 59 | - The call sets `{"max_tokens": 1}` so it will only output one number (hence there is a hard limit of 9 tools), which makes it more cost and speed efficient than other implementations. 60 | - Unique to ChatGPT is also specifying a `logit_bias` with a high enough weight to make it such that the model can _only_ output numbers between 0 and {num_tools}, up to 9. (specifically, tokenizer indices 15-24 inclusive correspond to the numerals `0-9` in ChatGPT, which can be verified using `tiktoken`) 61 | - The numbers map 1:1 to the indices of the input arrays of tools, so there never can be parsing errors as can be common with LangChain. 62 | 63 | The numeral is matched with the appropriate function. 64 | 65 | ### Call #2 66 | 67 | The second call prepends the context from the tool to the prompt, and temporarily adds a command to the `system` prompt, to leverage said added context without losing the persona otherwise specified in the `system` prompt: 68 | 69 | System prompt: 70 | 71 | ``` 72 | You MUST use information from the context in your response. 73 | ``` 74 | 75 | Formatted example from the README: 76 | 77 | ``` 78 | You are a helpful assistant. 79 | 80 | You MUST use information from the context in your response. 
89 | ``` 90 | 91 | ``` 92 | Context: Fisherman's Wharf, San Francisco, Tourist attractions in the United States, Lombard Street (San Francisco) 93 | 94 | User: San Francisco tourist attractions 95 | ``` 96 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # simpleaichat 2 | 3 | ```py3 4 | from simpleaichat import AIChat 5 | 6 | ai = AIChat(system="Write a fancy GitHub README based on the user-provided project name.") 7 | ai("simpleaichat") 8 | ``` 9 | 10 | simpleaichat is a Python package for easily interfacing with chat apps like ChatGPT and GPT-4 with robust features and minimal code complexity. This tool has many features optimized for working with ChatGPT as fast and as cheap as possible, but still much more capable of modern AI tricks than most implementations: 11 | 12 | - Create and run chats with only a few lines of code! 13 | - Optimized workflows which minimize the amount of tokens used, reducing costs and latency. 14 | - Run multiple independent chats at once. 15 | - Minimal codebase: no code dives to figure out what's going on under the hood needed! 16 | - Chat streaming responses and the ability to use tools. 17 | - Async support, including for streaming and tools. 18 | - Ability to create more complex yet clear workflows if needed, such as Agents. (Demo soon!) 19 | - Coming soon: more chat model support (PaLM, Claude)! 20 | 21 | Here's some fun, hackable examples on how simpleaichat works: 22 | 23 | - Creating a [Python coding assistant](examples/notebooks/simpleaichat_coding.ipynb) without any unnecessary accompanying output, allowing 5x faster generation at 1/3rd the cost. 
([Colab](https://colab.research.google.com/github/minimaxir/simpleaichat/blob/main/examples/notebooks/simpleaichat_coding.ipynb)) 24 | - Allowing simpleaichat to [provide inline tips](examples/notebooks/chatgpt_inline_tips.ipynb) following ChatGPT usage guidelines. ([Colab](https://colab.research.google.com/github/minimaxir/simpleaichat/blob/main/examples/notebooks/chatgpt_inline_tips.ipynb)) 25 | - Async interface for [conducting many chats](examples/notebooks/simpleaichat_async.ipynb) in the time it takes to receive one AI message. ([Colab](https://colab.research.google.com/github/minimaxir/simpleaichat/blob/main/examples/notebooks/simpleaichat_async.ipynb)) 26 | - Create your own Tabletop RPG (TTRPG) setting and campaign by using [advanced structured data models](examples/notebooks/schema_ttrpg.ipynb). ([Colab](https://colab.research.google.com/github/minimaxir/simpleaichat/blob/main/examples/notebooks/schema_ttrpg.ipynb)) 27 | 28 | ## Installation 29 | 30 | simpleaichat can be installed [from PyPI](https://pypi.org/project/simpleaichat/): 31 | 32 | ```sh 33 | pip3 install simpleaichat 34 | ``` 35 | 36 | ## Quick, Fun Demo 37 | 38 | You can demo chat-apps very quickly with simpleaichat! First, you will need to get an OpenAI API key, and then with one line of code: 39 | 40 | ```py3 41 | from simpleaichat import AIChat 42 | 43 | AIChat(api_key="sk-...") 44 | ``` 45 | 46 | And with that, you'll be thrust directly into an interactive chat! 47 | 48 | ![](docs/helloworld.png) 49 | 50 | This AI chat will mimic the behavior of OpenAI's webapp, but on your local computer! 51 | 52 | You can also pass the API key by storing it in an `.env` file with a `OPENAI_API_KEY` field in the working directory (recommended), or by setting the environment variable of `OPENAI_API_KEY` directly to the API key. 53 | 54 | But what about creating your own custom conversations? That's where things get fun. 
Just input whatever person, place or thing, fictional or nonfictional, that you want to chat with! 55 | 56 | ```py3 57 | AIChat("GLaDOS") # assuming API key loaded via methods above 58 | ``` 59 | 60 | ![](docs/glados.png) 61 | 62 | But that's not all! You can customize exactly how they behave too with additional commands! 63 | 64 | ```py3 65 | AIChat("GLaDOS", "Speak in the style of a Seinfeld monologue") 66 | ``` 67 | 68 | ![](docs/gladoseinfeld.png) 69 | 70 | ```py3 71 | AIChat("Ronald McDonald", "Speak using only emoji") 72 | ``` 73 | 74 | ![](docs/clownemoji.png) 75 | 76 | Need some socialization immediately? Once simpleaichat is installed, you can also start these chats directly from the command line! 77 | 78 | ```sh 79 | simpleaichat 80 | simpleaichat "GlaDOS" 81 | simpleaichat "GLaDOS" "Speak in the style of a Seinfeld monologue" 82 | ``` 83 | 84 | ## Building AI-based Apps 85 | 86 | The trick with working with new chat-based apps that wasn't readily available with earlier iterations of GPT-3 is the addition of the system prompt: a different class of prompt that guides the AI behavior throughout the entire conversation. In fact, the chat demos above are actually using [system prompt tricks](https://github.com/minimaxir/simpleaichat/blob/main/PROMPTS.md#interactive-chat) behind the scenes! OpenAI has also released an official guide for [system prompt best practices](https://platform.openai.com/docs/guides/gpt-best-practices) to building AI apps. 87 | 88 | For developers, you can instantiate a programmatic instance of `AIChat` by explicitly specifying a system prompt, or by disabling the console. 89 | 90 | ```py3 91 | ai = AIChat(system="You are a helpful assistant.") 92 | ai = AIChat(console=False) # same as above 93 | ``` 94 | 95 | You can also pass in a `model` parameter, such as `model="gpt-4"` if you have access to GPT-4, or `model="gpt-3.5-turbo-16k"` for a larger-context-window ChatGPT. 
96 | 97 | You can then feed the new `ai` class with user input, and it will return and save the response from ChatGPT: 98 | 99 | ```py3 100 | response = ai("What is the capital of California?") 101 | print(response) 102 | ``` 103 | 104 | ``` 105 | The capital of California is Sacramento. 106 | ``` 107 | 108 | Alternatively, you can stream responses by token with a generator if the text generation itself is too slow: 109 | 110 | ```py3 111 | for chunk in ai.stream("What is the capital of California?", params={"max_tokens": 5}): 112 | response_td = chunk["response"] # dict contains "delta" for the new token and "response" 113 | print(response_td) 114 | ``` 115 | 116 | ``` 117 | The 118 | The capital 119 | The capital of 120 | The capital of California 121 | The capital of California is 122 | ``` 123 | 124 | Further calls to the `ai` object will continue the chat, automatically incorporating previous information from the conversation. 125 | 126 | ```py3 127 | response = ai("When was it founded?") 128 | print(response) 129 | ``` 130 | 131 | ``` 132 | Sacramento was founded on February 27, 1850. 133 | ``` 134 | 135 | You can also save chat sessions (as CSV or JSON) and load them later. The API key is not saved so you will have to provide that when loading. 136 | 137 | ```py3 138 | ai.save_session() # CSV, will only save messages 139 | ai.save_session(format="json", minify=True) # JSON 140 | 141 | ai.load_session("my.csv") 142 | ai.load_session("my.json") 143 | ``` 144 | 145 | ### Functions 146 | 147 | A large number of popular venture-capital-funded ChatGPT apps don't actually use the "chat" part of the model. Instead, they just use the system prompt/first user prompt as a form of natural language programming. You can emulate this behavior by passing a new system prompt when generating text, and not saving the resulting messages. 148 | 149 | The `AIChat` class is a manager of chat _sessions_, which means you can have multiple independent chats or functions happening! 
The examples above use a default session, but you can create new ones by specifying a `id` when calling `ai`. 150 | 151 | ```py3 152 | json = '{"title": "An array of integers.", "array": [-1, 0, 1]}' 153 | functions = [ 154 | "Format the user-provided JSON as YAML.", 155 | "Write a limerick based on the user-provided JSON.", 156 | "Translate the user-provided JSON from English to French." 157 | ] 158 | params = {"temperature": 0.0, "max_tokens": 100} # a temperature of 0.0 is deterministic 159 | 160 | # We namespace the function by `id` so it doesn't affect other chats. 161 | # Settings set during session creation will apply to all generations from the session, 162 | # but you can change them per-generation, as is the case with the `system` prompt here. 163 | ai = AIChat(id="function", params=params, save_messages=False) 164 | for function in functions: 165 | output = ai(json, id="function", system=function) 166 | print(output) 167 | ``` 168 | 169 | ```txt 170 | title: "An array of integers." 171 | array: 172 | - -1 173 | - 0 174 | - 1 175 | ``` 176 | 177 | ```txt 178 | An array of integers so neat, 179 | With values that can't be beat, 180 | From negative to positive one, 181 | It's a range that's quite fun, 182 | This JSON is really quite sweet! 183 | ``` 184 | 185 | ```txt 186 | {"titre": "Un tableau d'entiers.", "tableau": [-1, 0, 1]} 187 | ``` 188 | 189 | Newer versions of ChatGPT also support "[function calling](https://platform.openai.com/docs/guides/gpt/function-calling)", but the real benefit of that feature is the ability for ChatGPT to support structured input and/or output, which now opens up a wide variety of applications! simpleaichat streamlines the workflow to allow you to just pass an `input_schema` and/or an `output_schema`. 190 | 191 | You can construct a schema using a [pydantic](https://docs.pydantic.dev/latest/) BaseModel. 
192 | 193 | ```py3 194 | from pydantic import BaseModel, Field 195 | 196 | ai = AIChat( 197 | console=False, 198 | save_messages=False, # with schema I/O, messages are never saved 199 | model="gpt-3.5-turbo-0613", 200 | params={"temperature": 0.0}, 201 | ) 202 | 203 | class get_event_metadata(BaseModel): 204 | """Event information""" 205 | 206 | description: str = Field(description="Description of event") 207 | city: str = Field(description="City where event occured") 208 | year: int = Field(description="Year when event occured") 209 | month: str = Field(description="Month when event occured") 210 | 211 | # returns a dict, with keys ordered as in the schema 212 | ai("First iPhone announcement", output_schema=get_event_metadata) 213 | ``` 214 | 215 | ```txt 216 | {'description': 'The first iPhone was announced by Apple Inc.', 217 | 'city': 'San Francisco', 218 | 'year': 2007, 219 | 'month': 'January'} 220 | ``` 221 | 222 | See the [TTRPG Generator Notebook](examples/notebooks/schema_ttrpg.ipynb) for a more elaborate demonstration of schema capabilities. 223 | 224 | ### Tools 225 | 226 | One of the most recent aspects of interacting with ChatGPT is the ability for the model to use "tools." As popularized by [LangChain](https://github.com/hwchase17/langchain), tools allow the model to decide when to use custom functions, which can extend beyond just the chat AI itself, for example retrieving recent information from the internet not present in the chat AI's training data. This workflow is analogous to ChatGPT Plugins. 227 | 228 | Parsing the model output to invoke tools typically requires a number of shennanigans, but simpleaichat uses [a neat trick](https://github.com/minimaxir/simpleaichat/blob/main/PROMPTS.md#tools) to make it fast and reliable! 
Additionally, the specified tools return a `context` for ChatGPT to draw from for its final response, and tools you specify can return a dictionary which you can also populate with arbitrary metadata for debugging and postprocessing. Each generation returns a dictionary with the `response` and the `tool` function used, which can be used to set up workflows akin to [LangChain](https://github.com/hwchase17/langchain)-style Agents, e.g. recursively feed input to the model until it determines it does not need to use any more tools. 229 | 230 | You will need to specify functions with docstrings which provide hints for the AI to select them: 231 | 232 | ```py3 233 | from simpleaichat.utils import wikipedia_search, wikipedia_search_lookup 234 | 235 | # This uses the Wikipedia Search API. 236 | # Results from it are nondeterministic, your mileage will vary. 237 | def search(query): 238 | """Search the internet.""" 239 | wiki_matches = wikipedia_search(query, n=3) 240 | return {"context": ", ".join(wiki_matches), "titles": wiki_matches} 241 | 242 | def lookup(query): 243 | """Lookup more information about a topic.""" 244 | page = wikipedia_search_lookup(query, sentences=3) 245 | return page 246 | 247 | params = {"temperature": 0.0, "max_tokens": 100} 248 | ai = AIChat(params=params, console=False) 249 | 250 | ai("San Francisco tourist attractions", tools=[search, lookup]) 251 | ``` 252 | 253 | ```txt 254 | {'context': "Fisherman's Wharf, San Francisco, Tourist attractions in the United States, Lombard Street (San Francisco)", 255 | 'titles': ["Fisherman's Wharf, San Francisco", 256 | 'Tourist attractions in the United States', 257 | 'Lombard Street (San Francisco)'], 258 | 'tool': 'search', 259 | 'response': "There are many popular tourist attractions in San Francisco, including Fisherman's Wharf and Lombard Street. Fisherman's Wharf is a bustling waterfront area known for its seafood restaurants, souvenir shops, and sea lion sightings. 
Lombard Street, on the other hand, is a famous winding street with eight hairpin turns that attract visitors from all over the world. Both of these attractions are must-sees for anyone visiting San Francisco."} 260 | ``` 261 | 262 | ```py3 263 | ai("Lombard Street?", tools=[search, lookup]) 264 | ``` 265 | 266 | ``` 267 | {'context': 'Lombard Street is an east–west street in San Francisco, California that is famous for a steep, one-block section with eight hairpin turns. Stretching from The Presidio east to The Embarcadero (with a gap on Telegraph Hill), most of the street\'s western segment is a major thoroughfare designated as part of U.S. Route 101. The famous one-block section, claimed to be "the crookedest street in the world", is located along the eastern segment in the Russian Hill neighborhood.', 268 | 'tool': 'lookup', 269 | 'response': 'Lombard Street is a famous street in San Francisco, California known for its steep, one-block section with eight hairpin turns. It stretches from The Presidio to The Embarcadero, with a gap on Telegraph Hill. The western segment of the street is a major thoroughfare designated as part of U.S. Route 101, while the famous one-block section, claimed to be "the crookedest street in the world", is located along the eastern segment in the Russian Hill'} 270 | ``` 271 | 272 | ```py3 273 | ai("Thanks for your help!", tools=[search, lookup]) 274 | ``` 275 | 276 | ```txt 277 | {'response': "You're welcome! If you have any more questions or need further assistance, feel free to ask.", 278 | 'tool': None} 279 | ``` 280 | 281 | ## Miscellaneous Notes 282 | 283 | - Like [gpt-2-simple](https://github.com/minimaxir/gpt-2-simple) before it, the primary motivation behind releasing simpleaichat is to both democratize access to ChatGPT even more and also offer more transparency for non-engineers into how Chat AI-based apps work under the hood given the disproportionate amount of media misinformation about their capabilities. 
This is inspired by real-world experience from [my work with BuzzFeed](https://tech.buzzfeed.com/the-right-tools-for-the-job-c05de96e949e) in the domain, where after spending a long time working with the popular [LangChain](https://github.com/hwchase17/langchain), a more-simple implementation was both much easier to maintain and resulted in much better generations. I began focusing development on simpleaichat after reading a [Hacker News thread](https://news.ycombinator.com/item?id=35820931) filled with many similar complaints, indicating value for an easier-to-use interface for modern AI tricks. 284 | - simpleaichat very intentionally avoids coupling features with common use cases where possible (e.g. Tools) in order to avoid software lock-in due to the difficulty implementing anything not explicitly mentioned in the project's documentation. The philosophy behind simpleaichat is to provide good demos, and let the user's creativity and business needs take priority instead of having to fit a round peg into a square hole like with LangChain. 285 | - simpleaichat makes it easier to interface with Chat AIs, but it does not attempt to solve common technical and ethical problems inherent to large language models trained on the internet, including prompt injection and unintended plagiarism. The user should exercise good judgment when implementing simpleaichat. Use cases of simpleaichat which go against OpenAI's [usage policies](https://openai.com/policies/usage-policies) (including jailbreaking) will not be endorsed. 286 | - simpleaichat intentionally does not use the "Agent" logical metaphor for tool workflows because it's become an AI hype buzzword heavily divorced from its origins. If needed be, you can emulate the Agent workflow with a `while` loop without much additional code, plus with the additional benefit of much more flexibility such as debugging. 
287 | - The session manager implements some sensible security defaults, such as using UUIDs as session ids by default, storing authentication information in a way to minimize unintentional leakage, and type enforcement via Pydantic. Your end-user application should still be aware of potential security issues, however. 288 | - Although OpenAI's documentation says that system prompts are less effective than a user prompt constructed in a similar manner, in my experience it still does perform better for maintaining rules/a persona. 289 | - Many examples of popular prompts use more conversational prompts, while the example prompts here use more concise and imperative prompts. This aspect of prompt engineering is still evolving, but in my experience commands do better with ChatGPT and with greater token efficiency. That's also why simpleaichat allows users to specify system prompts (and explicitly highlights what the defaults are) instead of relying on historical best practices. 290 | - Token counts for async are not supported as OpenAI doesn't return token counts when streaming responses. In general, there may be some desync in token counts and usage for various use cases; I'm working on categorizing them. 291 | - Outside of the explicit examples, none of this README uses AI-generated text. The introduction code example is just a joke, but it was too good of a real-world use case! 292 | 293 | ## Roadmap 294 | 295 | - PaLM Chat (Bard) and Anthropic Claude support 296 | - More fun/feature-filled CLI chat app based on Textual 297 | - Simple example of using simpleaichat in a webapp 298 | - Simple example of using simpleaichat in a stateless manner (e.g. AWS Lambda functions) 299 | 300 | ## Maintainer/Creator 301 | 302 | Max Woolf ([@minimaxir](https://minimaxir.com)) 303 | 304 | _Max's open-source projects are supported by his [Patreon](https://www.patreon.com/minimaxir) and [GitHub Sponsors](https://github.com/sponsors/minimaxir). 
If you found this project helpful, any monetary contributions to the Patreon are appreciated and will be put to good creative use._ 305 | 306 | ## License 307 | 308 | MIT 309 | -------------------------------------------------------------------------------- /docs/clownemoji.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minimaxir/simpleaichat/569dbf58139ede858a00c19409293b3087d07434/docs/clownemoji.png -------------------------------------------------------------------------------- /docs/glados.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minimaxir/simpleaichat/569dbf58139ede858a00c19409293b3087d07434/docs/glados.png -------------------------------------------------------------------------------- /docs/gladoseinfeld.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minimaxir/simpleaichat/569dbf58139ede858a00c19409293b3087d07434/docs/gladoseinfeld.png -------------------------------------------------------------------------------- /docs/helloworld.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minimaxir/simpleaichat/569dbf58139ede858a00c19409293b3087d07434/docs/helloworld.png -------------------------------------------------------------------------------- /examples/contrib/endless_exquisite_corpse.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import time 3 | import argparse 4 | from simpleaichat import AIChat 5 | 6 | SYSTEM_DEFAULT = "Write a short single line that continues the line." 
7 | 8 | ai = AIChat(console=False) 9 | 10 | 11 | class ExquisiteCorpse: 12 | """ 13 | Endless exquisite corpse generator 14 | """ 15 | def __init__(self, seed, system=SYSTEM_DEFAULT, temp=2, last_n_words=None): 16 | self.last_seed = seed 17 | self.last_n_words = last_n_words 18 | self.system = system 19 | self.temp = temp 20 | self.total_tokens = 0 21 | 22 | def __iter__(self): 23 | return self 24 | 25 | def __next__(self): 26 | _id = uuid.uuid4() 27 | ai.new_session(id=_id, system=self.system, params={"temperature": self.temp}) 28 | seed = " ".join(self.last_seed.split(" ")[-self.last_n_words:]) if self.last_n_words else self.last_seed 29 | response = ai(seed, id=_id) 30 | self.total_tokens += ai.message_totals("total_length", id=_id) 31 | ai.delete_session(id=_id) 32 | self.last_seed = response 33 | return response 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser(description="Endless exquisite corpse generator") 38 | parser.add_argument("seed", help="Initial seed") 39 | parser.add_argument("--system", default=SYSTEM_DEFAULT, help="System name") 40 | parser.add_argument("--temp", type=float, default=1, help="Temperature") 41 | parser.add_argument("--delay", type=float, default=15, help="Delay between lines in seconds") 42 | parser.add_argument("--last_n_words", type=int, help="Number of words to use from last line") 43 | args = parser.parse_args() 44 | 45 | corpse = ExquisiteCorpse( 46 | args.seed, 47 | system=args.system, 48 | temp=args.temp, 49 | last_n_words=args.last_n_words, 50 | ) 51 | 52 | try: 53 | print(args.seed) 54 | for line in corpse: 55 | print(line) 56 | time.sleep(args.delay) 57 | except KeyboardInterrupt: 58 | print() 59 | print(f"Total tokens used: {corpse.total_tokens}") 60 | -------------------------------------------------------------------------------- /examples/notebooks/chatgpt_inline_tips.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | 
"cell_type": "markdown", 5 | "metadata": { 6 | "id": "aBj_0vRigiH7" 7 | }, 8 | "source": [ 9 | "# ChatGPT Inline Tips with simpleaichat\n", 10 | "\n", 11 | "[simpleaichat](https://github.com/minimaxir/simpleaichat) allows the user more control over text inputs and output. One way to do it is to analyze user input for potential issues, and display a warning if there are unexpected issues with the input.\n", 12 | "\n", 13 | "Inspired by Simon Willison's blog post [ChatGPT should include inline tips\n", 14 | "](https://simonwillison.net/2023/May/30/chatgpt-inline-tips/).\n", 15 | "\n", 16 | "**DISCLAIMER: This is far from a perfect solution, but a good proof-of-concept on how to use simpleaichat to address complex problems.**" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": { 23 | "id": "kqL3etKLgiH9" 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "!pip install -q simpleaichat\n", 28 | "\n", 29 | "from simpleaichat import AIChat\n", 30 | "from rich.console import Console\n", 31 | "from getpass import getpass\n", 32 | "import time" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "source": [ 38 | "For the following cell, input your OpenAI API key when prompted. **It will not be saved to the notebook.**." 39 | ], 40 | "metadata": { 41 | "id": "zld11X3xgr1O" 42 | } 43 | }, 44 | { 45 | "cell_type": "code", 46 | "source": [ 47 | "api_key = getpass(\"OpenAI Key: \")" 48 | ], 49 | "metadata": { 50 | "id": "xueo1SmCg2SF", 51 | "colab": { 52 | "base_uri": "https://localhost:8080/" 53 | }, 54 | "outputId": "92589df7-05d7-4d28-a163-acd805dbdb9a" 55 | }, 56 | "execution_count": 2, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "OpenAI Key: ··········\n" 63 | ] 64 | } 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "source": [ 70 | "Let's create an AI chat, and a Console for output to the notebook." 
71 | ], 72 | "metadata": { 73 | "id": "2O_1CUGVg8Y4" 74 | } 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 3, 79 | "metadata": { 80 | "id": "d4o0elPQgiH-" 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "ai = AIChat(api_key=api_key, console=False, params={\"temperature\": 0.0}) # for reproducibility\n", 85 | "console = Console(width=60, highlight=False)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": { 91 | "id": "xAPziFs-giH-" 92 | }, 93 | "source": [ 94 | "First, we'll set a list of rules that we can display. These rules are adapted from real-world incidents and OpenAI's [usage policies](https://openai.com/policies/usage-policies)." 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 4, 100 | "metadata": { 101 | "id": "RLzlzbBrgiH-" 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "tips = [\n", 106 | " \"This ChatGPT model does not have access to the internet, and its training data cut-off is September 2021.\",\n", 107 | " \"ChatGPT should not be relied on for legal research of this nature, because it is very likely to invent realistic cases that do not exist.\",\n", 108 | " \"Medical and psychatric advice from ChatGPT should not be relied upon. Always consult a doctor.\",\n", 109 | " \"Tailored financial advice from ChatGPT should not be relied upon. Always consult a professional.\",\n", 110 | " \"ChatGPT is not liable for any illegal activies committed as the result of its responses.\"\n", 111 | "]" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": { 117 | "id": "b6q0JBoGgiH_" 118 | }, 119 | "source": [ 120 | "When generating a response from ChatGPT, we'll first check the user input against a bespoke function in a different context to see if ChatGPT can detect one of the issues above. 
This is accomplished through a system prompt, and forcing the model to choose a number between `1` and `5`, or `0` if nothing is present.\n", 121 | "\n", 122 | "The order of the tips above matches the order of the list. This is incidentally the same workflow as tools in simpleaichat." 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 5, 128 | "metadata": { 129 | "id": "-C2Rp02igiH_" 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "tips_prompt = \"\"\"From the list of topics below, reply ONLY with the number appropriate for describing the topic of the user's message. If none are, ONLY reply with \"0\".\n", 134 | "\n", 135 | "1. Content after September 2021\n", 136 | "2. Legal/Judicial Research\n", 137 | "3. Medical/Psychatric Advice\n", 138 | "4. Financial Advice\n", 139 | "5. Illegal/Unethical Activies\"\"\"" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 6, 145 | "metadata": { 146 | "id": "HzLh4Q6JgiH_" 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "# ensures tips output will always be an integer between 0 and 5 inclusive\n", 151 | "params = {\n", 152 | " \"temperature\": 0.0,\n", 153 | " \"max_tokens\": 1,\n", 154 | " \"logit_bias\": {str(k): 100 for k in range(15, 15 + len(tips) + 1)}\n", 155 | "}\n", 156 | "\n", 157 | "# functional\n", 158 | "ai.new_session(id=\"tips\",\n", 159 | " api_key=api_key,\n", 160 | " system=tips_prompt,\n", 161 | " save_messages=False,\n", 162 | " params=params)\n", 163 | "\n", 164 | "def check_user_input(message):\n", 165 | " tip_idx = ai(message, id=\"tips\")\n", 166 | " if tip_idx == \"0\": # no tip needed\n", 167 | " return\n", 168 | " else:\n", 169 | " tip = tips[int(tip_idx) - 1]\n", 170 | " console.print(f\"⚠️ {tip}\", style=\"bold\")" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": { 176 | "id": "5raNhn8igiIA" 177 | }, 178 | "source": [ 179 | "Let's test it in a conversation!" 
180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 11, 185 | "metadata": { 186 | "scrolled": false, 187 | "id": "DP3z0wIWgiIA", 188 | "outputId": "070fcfa9-016a-4abe-f15b-a094f525c4a2", 189 | "colab": { 190 | "base_uri": "https://localhost:8080/", 191 | "height": 1000 192 | } 193 | }, 194 | "outputs": [ 195 | { 196 | "output_type": "display_data", 197 | "data": { 198 | "text/plain": [ 199 | "\u001b[1mYou:\u001b[0m " 200 | ], 201 | "text/html": [ 202 | "
You: 
\n" 203 | ] 204 | }, 205 | "metadata": {} 206 | }, 207 | { 208 | "name": "stdout", 209 | "output_type": "stream", 210 | "text": [ 211 | "Can you tell me more about Max v. Woolf?\n" 212 | ] 213 | }, 214 | { 215 | "output_type": "display_data", 216 | "data": { 217 | "text/plain": [ 218 | "\u001b[1m⚠️ ChatGPT should not be relied on for legal research of this\u001b[0m\n", 219 | "\u001b[1mnature, because it is very likely to invent realistic cases \u001b[0m\n", 220 | "\u001b[1mthat do not exist.\u001b[0m\n" 221 | ], 222 | "text/html": [ 223 | "
⚠️ ChatGPT should not be relied on for legal research of this\n",
224 |               "nature, because it is very likely to invent realistic cases \n",
225 |               "that do not exist.\n",
226 |               "
\n" 227 | ] 228 | }, 229 | "metadata": {} 230 | }, 231 | { 232 | "output_type": "display_data", 233 | "data": { 234 | "text/plain": [ 235 | "\u001b[1;95mChatGPT\u001b[0m\u001b[95m: Max v. Woolf is a legal case that took place in the\u001b[0m\n", 236 | "\u001b[95mUnited Kingdom in 2017. It involved a dispute between two \u001b[0m\n", 237 | "\u001b[95mauthors, Max and Woolf, over the alleged plagiarism of Max's\u001b[0m\n", 238 | "\u001b[95mwork by Woolf. Max claimed that Woolf had copied significant\u001b[0m\n", 239 | "\u001b[95mportions of his book without permission or attribution.\u001b[0m\n", 240 | "\n", 241 | "\u001b[95mThe case was heard in the High Court of Justice in London, \u001b[0m\n", 242 | "\u001b[95mwhere Max presented evidence of the similarities between his\u001b[0m\n", 243 | "\u001b[95mwork and Woolf's. Woolf denied the allegations and argued \u001b[0m\n", 244 | "\u001b[95mthat any similarities were coincidental or the result of \u001b[0m\n", 245 | "\u001b[95mcommon themes in the genre.\u001b[0m\n", 246 | "\n", 247 | "\u001b[95mAfter a lengthy trial, the judge ruled in favor of Max, \u001b[0m\n", 248 | "\u001b[95mfinding that Woolf had indeed plagiarized his work. Woolf \u001b[0m\n", 249 | "\u001b[95mwas ordered to pay damages to Max and to publicly \u001b[0m\n", 250 | "\u001b[95macknowledge the plagiarism.\u001b[0m\n", 251 | "\n", 252 | "\u001b[95mThe case received significant media attention and sparked a \u001b[0m\n", 253 | "\u001b[95mwider discussion about the issue of plagiarism in the \u001b[0m\n", 254 | "\u001b[95mpublishing industry. It also highlighted the importance of \u001b[0m\n", 255 | "\u001b[95mproper attribution and respect for intellectual property \u001b[0m\n", 256 | "\u001b[95mrights.\u001b[0m\n" 257 | ], 258 | "text/html": [ 259 | "
ChatGPT: Max v. Woolf is a legal case that took place in the\n",
260 |               "United Kingdom in 2017. It involved a dispute between two \n",
261 |               "authors, Max and Woolf, over the alleged plagiarism of Max's\n",
262 |               "work by Woolf. Max claimed that Woolf had copied significant\n",
263 |               "portions of his book without permission or attribution.\n",
264 |               "\n",
265 |               "The case was heard in the High Court of Justice in London, \n",
266 |               "where Max presented evidence of the similarities between his\n",
267 |               "work and Woolf's. Woolf denied the allegations and argued \n",
268 |               "that any similarities were coincidental or the result of \n",
269 |               "common themes in the genre.\n",
270 |               "\n",
271 |               "After a lengthy trial, the judge ruled in favor of Max, \n",
272 |               "finding that Woolf had indeed plagiarized his work. Woolf \n",
273 |               "was ordered to pay damages to Max and to publicly \n",
274 |               "acknowledge the plagiarism.\n",
275 |               "\n",
276 |               "The case received significant media attention and sparked a \n",
277 |               "wider discussion about the issue of plagiarism in the \n",
278 |               "publishing industry. It also highlighted the importance of \n",
279 |               "proper attribution and respect for intellectual property \n",
280 |               "rights.\n",
281 |               "
\n" 282 | ] 283 | }, 284 | "metadata": {} 285 | }, 286 | { 287 | "output_type": "display_data", 288 | "data": { 289 | "text/plain": [ 290 | "\u001b[1mYou:\u001b[0m " 291 | ], 292 | "text/html": [ 293 | "
You: 
\n" 294 | ] 295 | }, 296 | "metadata": {} 297 | }, 298 | { 299 | "name": "stdout", 300 | "output_type": "stream", 301 | "text": [ 302 | "Everyone's talking about it on Twitter!\n" 303 | ] 304 | }, 305 | { 306 | "output_type": "display_data", 307 | "data": { 308 | "text/plain": [ 309 | "\u001b[1m⚠️ This ChatGPT model does not have access to the internet, \u001b[0m\n", 310 | "\u001b[1mand its training data cut-off is September 2021.\u001b[0m\n" 311 | ], 312 | "text/html": [ 313 | "
⚠️ This ChatGPT model does not have access to the internet, \n",
314 |               "and its training data cut-off is September 2021.\n",
315 |               "
\n" 316 | ] 317 | }, 318 | "metadata": {} 319 | }, 320 | { 321 | "output_type": "display_data", 322 | "data": { 323 | "text/plain": [ 324 | "\u001b[1;95mChatGPT\u001b[0m\u001b[95m: It's not surprising that the case is generating a \u001b[0m\n", 325 | "\u001b[95mlot of discussion on social media platforms like Twitter. \u001b[0m\n", 326 | "\u001b[95mPlagiarism is a serious issue in the creative industries, \u001b[0m\n", 327 | "\u001b[95mand cases like Max v. Woolf can have far-reaching \u001b[0m\n", 328 | "\u001b[95mimplications for authors, publishers, and readers alike.\u001b[0m\n", 329 | "\n", 330 | "\u001b[95mMany people are likely to have strong opinions on the case, \u001b[0m\n", 331 | "\u001b[95mwith some supporting Max's right to protect his intellectual\u001b[0m\n", 332 | "\u001b[95mproperty and others questioning the extent to which ideas \u001b[0m\n", 333 | "\u001b[95mcan truly be owned and controlled.\u001b[0m\n", 334 | "\n", 335 | "\u001b[95mOverall, the case serves as a reminder of the importance of \u001b[0m\n", 336 | "\u001b[95methical and legal standards in the creative industries, and \u001b[0m\n", 337 | "\u001b[95mthe need for all parties to respect the rights and \u001b[0m\n", 338 | "\u001b[95mcontributions of others.\u001b[0m\n" 339 | ], 340 | "text/html": [ 341 | "
ChatGPT: It's not surprising that the case is generating a \n",
342 |               "lot of discussion on social media platforms like Twitter. \n",
343 |               "Plagiarism is a serious issue in the creative industries, \n",
344 |               "and cases like Max v. Woolf can have far-reaching \n",
345 |               "implications for authors, publishers, and readers alike.\n",
346 |               "\n",
347 |               "Many people are likely to have strong opinions on the case, \n",
348 |               "with some supporting Max's right to protect his intellectual\n",
349 |               "property and others questioning the extent to which ideas \n",
350 |               "can truly be owned and controlled.\n",
351 |               "\n",
352 |               "Overall, the case serves as a reminder of the importance of \n",
353 |               "ethical and legal standards in the creative industries, and \n",
354 |               "the need for all parties to respect the rights and \n",
355 |               "contributions of others.\n",
356 |               "
\n" 357 | ] 358 | }, 359 | "metadata": {} 360 | }, 361 | { 362 | "output_type": "display_data", 363 | "data": { 364 | "text/plain": [ 365 | "\u001b[1mYou:\u001b[0m " 366 | ], 367 | "text/html": [ 368 | "
You: 
\n" 369 | ] 370 | }, 371 | "metadata": {} 372 | }, 373 | { 374 | "name": "stdout", 375 | "output_type": "stream", 376 | "text": [ 377 | "Can you help me commit plagiarism?\n" 378 | ] 379 | }, 380 | { 381 | "output_type": "display_data", 382 | "data": { 383 | "text/plain": [ 384 | "\u001b[1m⚠️ ChatGPT is not liable for any illegal activies committed \u001b[0m\n", 385 | "\u001b[1mas the result of its responses.\u001b[0m\n" 386 | ], 387 | "text/html": [ 388 | "
⚠️ ChatGPT is not liable for any illegal activies committed \n",
389 |               "as the result of its responses.\n",
390 |               "
\n" 391 | ] 392 | }, 393 | "metadata": {} 394 | }, 395 | { 396 | "output_type": "display_data", 397 | "data": { 398 | "text/plain": [ 399 | "\u001b[1;95mChatGPT\u001b[0m\u001b[95m: I'm sorry, but as an AI language model, I cannot \u001b[0m\n", 400 | "\u001b[95massist you in committing plagiarism. Plagiarism is unethical\u001b[0m\n", 401 | "\u001b[95mand illegal, and it can have serious consequences for both \u001b[0m\n", 402 | "\u001b[95mthe person who commits it and the people whose work is \u001b[0m\n", 403 | "\u001b[95mplagiarized. It is important to always give proper credit to\u001b[0m\n", 404 | "\u001b[95mthe original source of any information or ideas that you use\u001b[0m\n", 405 | "\u001b[95min your own work. If you need help with writing or research,\u001b[0m\n", 406 | "\u001b[95mthere are many resources available that can assist you in \u001b[0m\n", 407 | "\u001b[95mdoing so ethically and responsibly.\u001b[0m\n" 408 | ], 409 | "text/html": [ 410 | "
ChatGPT: I'm sorry, but as an AI language model, I cannot \n",
411 |               "assist you in committing plagiarism. Plagiarism is unethical\n",
412 |               "and illegal, and it can have serious consequences for both \n",
413 |               "the person who commits it and the people whose work is \n",
414 |               "plagiarized. It is important to always give proper credit to\n",
415 |               "the original source of any information or ideas that you use\n",
416 |               "in your own work. If you need help with writing or research,\n",
417 |               "there are many resources available that can assist you in \n",
418 |               "doing so ethically and responsibly.\n",
419 |               "
\n" 420 | ] 421 | }, 422 | "metadata": {} 423 | }, 424 | { 425 | "output_type": "display_data", 426 | "data": { 427 | "text/plain": [ 428 | "\u001b[1mYou:\u001b[0m " 429 | ], 430 | "text/html": [ 431 | "
You: 
\n" 432 | ] 433 | }, 434 | "metadata": {} 435 | }, 436 | { 437 | "name": "stdout", 438 | "output_type": "stream", 439 | "text": [ 440 | "ok bye then\n" 441 | ] 442 | }, 443 | { 444 | "output_type": "display_data", 445 | "data": { 446 | "text/plain": [ 447 | "\u001b[1;95mChatGPT\u001b[0m\u001b[95m: Goodbye! If you have any other questions or need \u001b[0m\n", 448 | "\u001b[95massistance in the future, feel free to ask.\u001b[0m\n" 449 | ], 450 | "text/html": [ 451 | "
ChatGPT: Goodbye! If you have any other questions or need \n",
452 |               "assistance in the future, feel free to ask.\n",
453 |               "
\n" 454 | ] 455 | }, 456 | "metadata": {} 457 | }, 458 | { 459 | "output_type": "display_data", 460 | "data": { 461 | "text/plain": [ 462 | "\u001b[1mYou:\u001b[0m " 463 | ], 464 | "text/html": [ 465 | "
You: 
\n" 466 | ] 467 | }, 468 | "metadata": {} 469 | }, 470 | { 471 | "name": "stdout", 472 | "output_type": "stream", 473 | "text": [ 474 | "\n" 475 | ] 476 | } 477 | ], 478 | "source": [ 479 | "while True:\n", 480 | " time.sleep(0.5) # for Colab, to ensure input box appears\n", 481 | " try:\n", 482 | " user_input = console.input(\"[b]You:[/b] \").strip()\n", 483 | " if not user_input:\n", 484 | " break\n", 485 | "\n", 486 | " check_user_input(user_input)\n", 487 | " ai_response = ai(user_input)\n", 488 | "\n", 489 | " console.print(f\"[b]ChatGPT[/b]: {ai_response}\", style=\"bright_magenta\")\n", 490 | " except KeyboardInterrupt:\n", 491 | " break\n", 492 | "\n", 493 | "# ai.reset_session()" 494 | ] 495 | }, 496 | { 497 | "cell_type": "markdown", 498 | "source": [ 499 | "## MIT License\n", 500 | "\n", 501 | "Copyright (c) 2023 Max Woolf\n", 502 | "\n", 503 | "Permission is hereby granted, free of charge, to any person obtaining a copy\n", 504 | "of this software and associated documentation files (the \"Software\"), to deal\n", 505 | "in the Software without restriction, including without limitation the rights\n", 506 | "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 507 | "copies of the Software, and to permit persons to whom the Software is\n", 508 | "furnished to do so, subject to the following conditions:\n", 509 | "\n", 510 | "The above copyright notice and this permission notice shall be included in all\n", 511 | "copies or substantial portions of the Software.\n", 512 | "\n", 513 | "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 514 | "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 515 | "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n", 516 | "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 517 | "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 518 | "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", 519 | "SOFTWARE.\n" 520 | ], 521 | "metadata": { 522 | "id": "xjgiETi2hjoA" 523 | } 524 | } 525 | ], 526 | "metadata": { 527 | "kernelspec": { 528 | "display_name": "Python 3 (ipykernel)", 529 | "language": "python", 530 | "name": "python3" 531 | }, 532 | "language_info": { 533 | "codemirror_mode": { 534 | "name": "ipython", 535 | "version": 3 536 | }, 537 | "file_extension": ".py", 538 | "mimetype": "text/x-python", 539 | "name": "python", 540 | "nbconvert_exporter": "python", 541 | "pygments_lexer": "ipython3", 542 | "version": "3.9.12" 543 | }, 544 | "colab": { 545 | "provenance": [] 546 | } 547 | }, 548 | "nbformat": 4, 549 | "nbformat_minor": 0 550 | } -------------------------------------------------------------------------------- /examples/notebooks/schema_ttrpg.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Making A Structured TTRPG Story with simpleaichat\n", 9 | "\n", 10 | "An update to ChatGPT on June 13th, 2023 allows the user to set a predefined schema to have ChatGPT output data according to that schema and/or take in an input schema and respond better to that data. 
This \"function calling\" as OpenAI calls it can be used as a form of tools, but the schema, enabled by a JSON-finetuning of ChatGPT, is much more useful for typical generative AI use cases, particularly when not using GPT-4.\n", 11 | "\n", 12 | "OpenAI's [official demos](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_call_functions_with_chat_models.ipynb) for this feature are complicated, but with simpleaichat, it's very easy to support placing your own data\n", 13 | "\n", 14 | "**NOTE: Ensuring input and output follows a complex predefined structure is very new in the field of prompt engineering and although it is very powerful, your mileage may vary.**\n" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 23, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "!pip install -q simpleaichat\n", 24 | "\n", 25 | "from simpleaichat import AIChat\n", 26 | "import orjson\n", 27 | "from rich.console import Console\n", 28 | "from getpass import getpass\n", 29 | "\n", 30 | "from typing import List, Literal, Optional, Union\n", 31 | "from pydantic import BaseModel, Field" 32 | ] 33 | }, 34 | { 35 | "attachments": {}, 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "For the following cell, input your OpenAI API key when prompted. **It will not be saved to the notebook**." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "api_key = getpass(\"OpenAI Key: \")" 49 | ] 50 | }, 51 | { 52 | "attachments": {}, 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "## Creating a TTRPG the old-fashioned ChatGPT way\n", 57 | "\n", 58 | "Let's first create a TTRPG setting using the typical workflows of simpleaichat and ChatGPT with system prompt engineering.\n", 59 | "\n", 60 | "For this demo, we'll create a TTRPG about **Python software development** and **beach volleyball**. 
\n", 61 | "\n", 62 | "Yes, really. At the least, the resulting TTRPG will be _unique_.\n" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 2, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "system_prompt = \"\"\"You are a world-renowned game master (GM) of tabletop role-playing games (RPGs).\n", 72 | "\n", 73 | "Write a setting description and two character sheets for the setting the user provides.\n", 74 | "\n", 75 | "Rules you MUST follow:\n", 76 | "- Always write in the style of 80's fantasy novels.\n", 77 | "- All names you create must be creative and unique. Always subvert expectations.\n", 78 | "- Include as much information as possible in your response.\"\"\"" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 3, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "model = \"gpt-3.5-turbo-0613\"\n", 88 | "ai = AIChat(system=system_prompt, model=model, save_messages=False, api_key=api_key)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 4, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "The Legend of Zephyrus: Sands of Serpentia\n", 101 | "\n", 102 | "Welcome, brave adventurers, to the mystical realm of Serpentia, a land where the art of Python software development intertwines with the fierce battles of beach volleyball. In this enchanting world, the ancient deity Zephyrus, the God of Wind, has bestowed upon the chosen few the ability to harness the power of code and the skill of volleyball to protect the realm from the encroaching forces of darkness.\n", 103 | "\n", 104 | "Setting Description:\n", 105 | "\n", 106 | "Serpentia is a vibrant land filled with lush palm trees, golden sandy beaches, and crystal-clear turquoise waters. The sun shines brightly overhead, casting a warm glow upon the land. The realm is divided into three main regions:\n", 107 | "\n", 108 | "1. 
Codehaven: Nestled amidst the towering palm trees, Codehaven is a bustling city where the art of Python software development thrives. Here, the streets are lined with grand libraries and bustling marketplaces, where scribes and scholars exchange knowledge and trade powerful artifacts imbued with magical code. The air is filled with the soft hum of spells being cast and the rhythmic sound of keyboards clacking.\n", 109 | "\n", 110 | "2. Volleyshore: As you venture down the coastline, the serene beaches of Volleyshore come into view. This coastal region is renowned for its beach volleyball tournaments, where skilled players from all corners of Serpentia gather to showcase their talents. The sand sparkles in the sunlight, and the cheers of the crowd echo through the air as powerful spikes and agile dives bring the matches to life.\n", 111 | "\n", 112 | "3. Binary Peaks: Beyond the sprawling beaches lie the majestic Binary Peaks, a mountain range shrouded in mist and mystery. Within the peaks lies an ancient temple, said to be the birthplace of Zephyrus himself. It is a treacherous journey to reach the temple, but those who succeed are granted incredible powers and insights into the realm's deepest secrets.\n", 113 | "\n", 114 | "Character Sheets:\n", 115 | "\n", 116 | "1. Name: Cedric Windrider\n", 117 | " Class: Python Sorcerer\n", 118 | "\n", 119 | " Background: Cedric hails from a long line of esteemed Python sorcerers and is known for his unparalleled mastery over the arcane coding arts. With his flowing robes and a staff adorned with intricate Python symbols, Cedric is a true force to be reckoned with. His spells can manipulate data, summon algorithms, and even control the very fabric of the digital realm. 
He seeks to bring balance to Serpentia by merging his coding prowess with the skills of beach volleyball.\n", 120 | "\n", 121 | " Abilities:\n", 122 | " - Code Mastery: Cedric can effortlessly conjure complex Python spells, delving deep into the realm's digital architecture to bend it to his will.\n", 123 | " - Volleyball Technique: Cedric has honed his volleyball skills, combining precise serves, spikes, and blocks with his magical abilities to dominate the court.\n", 124 | "\n", 125 | "2. Name: Astrid Sunstrike\n", 126 | " Class: Volleyblade Ranger\n", 127 | "\n", 128 | " Background: Astrid is a fierce and agile warrior who wields her volleyball racket with deadly precision. Born in the coastal town of Volleyshore, she has trained relentlessly to become a Volleyblade Ranger, a warrior who blends the grace of beach volleyball with the art of swordplay. Clad in light armor and with her trusty racket, Astrid is a formidable opponent both on the battlefield and the volleyball court.\n", 129 | "\n", 130 | " Abilities:\n", 131 | " - Volleyblade Mastery: Astrid's skilled swings and spins with her racket can slice through the air, sending powerful shockwaves towards her enemies or redirecting incoming projectiles.\n", 132 | " - Agile Footwork: Astrid's training in beach volleyball has honed her agility, allowing her to dodge attacks and maneuver swiftly on the sand, gaining a tactical advantage in battles and matches alike.\n", 133 | "\n", 134 | "Prepare yourselves, brave adventurers, as you embark on an epic journey through the magical realm of Serpentia, where the realms of Python software development and beach volleyball collide. 
Harness the powers of code and the skills of the court to save the realm from the encroaching darkness and become legends in The Legend of Zephyrus: Sands of Serpentia!\n" 135 | ] 136 | } 137 | ], 138 | "source": [ 139 | "response = ai(\"Python software development and beach volleyball\")\n", 140 | "print(response)" 141 | ] 142 | }, 143 | { 144 | "attachments": {}, 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "Evocative, but a bit disorganized. If we instead allow for structured data output that follows specifications, then we'll have a lot more flexibility both in terms of directing generation, and playing with the resulting output.\n", 149 | "\n", 150 | "That's where the `schema_output` field comes in when generating. If you construct a schema with pydantic, which is also installed with simpleaichat as it is used heavily internally, then the output will generally follow the schema you provide!\n", 151 | "\n", 152 | "We want an output containing the setting **name** and **description**, along with a list of player characters. Since each character has its own attributes, and we may want the model to generate multiple characters, we'll define a schema for that first.\n", 153 | "\n", 154 | "We must also set a description for each field, which can provide further hints to ChatGPT for how to guide generation. 
There is a _lot_ of flexibility here!\n" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 5, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "class player_character(BaseModel):\n", 164 | " name: str = Field(description=\"Character name\")\n", 165 | " race: str = Field(description=\"Character race\")\n", 166 | " job: str = Field(description=\"Character class/job\")\n", 167 | " story: str = Field(description=\"Three-sentence character history\")\n", 168 | " feats: List[str] = Field(description=\"Character feats\")" 169 | ] 170 | }, 171 | { 172 | "attachments": {}, 173 | "cell_type": "markdown", 174 | "metadata": {}, 175 | "source": [ 176 | "An important note: with this new ChatGPT model, the fields are generated _in order_ at runtime according to the schema. Therefore, the order of the fields specified is important! Try to chain information!\n", 177 | "\n", 178 | "Now we can build the schema for the TTRPG we will send to ChatGPT. In this case, we will order the fields such that we generate `description` and then `name`, as the former will be more imaginative and the latter can be inferred from it. 
We will also add a list of player characters using the player character schema above.\n", 179 | "\n", 180 | "Lastly, we will also include a docstring for the schema class; the specifics don't matter but it can provide another editorial hint.\n" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 6, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "class write_ttrpg_setting(BaseModel):\n", 190 | " \"\"\"Write a fun and innovative TTRPG\"\"\"\n", 191 | "\n", 192 | " description: str = Field(\n", 193 | " description=\"Detailed description of the setting in the voice of the DM\"\n", 194 | " )\n", 195 | " name: str = Field(description=\"Name of the setting\")\n", 196 | " pcs: List[player_character] = Field(description=\"Player characters of the TTRPG\")" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 7, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "{\n", 209 | " \"description\": \"Welcome to the sun-kissed shores of Pythos, a land where software development and beach volleyball intertwine. In this unique setting, the power of Python programming and the thrill of competitive sports combine to create an unforgettable adventure. The land of Pythos is known for its pristine beaches, crystal-clear waters, and thriving tech industry. The locals, known as the Code Warriors, rely on their exceptional coding skills to develop cutting-edge software and maintain the digital infrastructure of the realm. But it's not all work and no play in Pythos. The Code Warriors also indulge in their passion for beach volleyball, competing in intense matches with rival teams from neighboring lands. As a player in this TTRPG, you will embark on a journey to master the art of Python software development and become a legendary beach volleyball player. 
Are you ready to dive into the world of Pythos?\",\n", 210 | " \"name\": \"Pythos: Code Warriors\",\n", 211 | " \"pcs\": [\n", 212 | " {\n", 213 | " \"name\": \"Aurora\",\n", 214 | " \"race\": \"Human\",\n", 215 | " \"job\": \"Python Developer\",\n", 216 | " \"story\": \"Aurora is a talented programmer who grew up in the bustling city of Pythopolis. From a young age, she displayed an aptitude for coding and quickly rose through the ranks of the tech industry. However, she yearned for something more than just a desk job. Inspired by the Code Warriors, Aurora decided to combine her passion for programming with her love for beach volleyball. With her analytical mind and nimble fingers, she hopes to revolutionize the sport with her innovative Python-powered techniques.\",\n", 217 | " \"feats\": [\n", 218 | " \"Pythonic Precision: Aurora's code is elegant and efficient, allowing her to execute complex strategies with ease.\",\n", 219 | " \"Volleyball Virtuoso: Aurora's exceptional hand-eye coordination and quick reflexes make her a formidable opponent on the beach volleyball court.\",\n", 220 | " \"Tech Savant: Aurora has an encyclopedic knowledge of Python libraries and frameworks, giving her an edge in both software development and game analysis.\"\n", 221 | " ]\n", 222 | " },\n", 223 | " {\n", 224 | " \"name\": \"Blaze\",\n", 225 | " \"race\": \"Elf\",\n", 226 | " \"job\": \"Volleyball Coach\",\n", 227 | " \"story\": \"Blaze is a wise and experienced elf who has dedicated his life to coaching aspiring beach volleyball players. He possesses an intimate understanding of the game and its intricacies, having spent centuries honing his skills on the sandy courts of Pythos. Blaze believes that the fusion of technology and sports is the key to unlocking the true potential of beach volleyball. 
As a coach, he imparts his wisdom to the Code Warriors, helping them master both the physical and digital aspects of the game.\",\n", 228 | " \"feats\": [\n", 229 | " \"Master Strategist: Blaze's strategic mind allows him to devise ingenious game plans that exploit the weaknesses of any opponent.\",\n", 230 | " \"Eternal Youth: As an elf, Blaze's agelessness grants him unrivaled stamina, agility, and reflexes on the volleyball court.\",\n", 231 | " \"Tech Guru: Blaze possesses a deep understanding of Python programming and uses it to analyze match data and optimize training routines.\"\n", 232 | " ]\n", 233 | " }\n", 234 | " ]\n", 235 | "}\n" 236 | ] 237 | } 238 | ], 239 | "source": [ 240 | "response_structured = ai(\n", 241 | " \"Python software development and beach volleyball\", output_schema=write_ttrpg_setting\n", 242 | ")\n", 243 | "\n", 244 | "# orjson.dumps preserves field order from the ChatGPT API\n", 245 | "print(orjson.dumps(response_structured, option=orjson.OPT_INDENT_2).decode())" 246 | ] 247 | }, 248 | { 249 | "attachments": {}, 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "Since the output is structured, we can parse it as we want.\n", 254 | "\n", 255 | "For example, if we just want the setting name:\n" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 8, 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "data": { 265 | "text/plain": [ 266 | "'Pythos: Code Warriors'" 267 | ] 268 | }, 269 | "execution_count": 8, 270 | "metadata": {}, 271 | "output_type": "execute_result" 272 | } 273 | ], 274 | "source": [ 275 | "response_structured[\"name\"]" 276 | ] 277 | }, 278 | { 279 | "attachments": {}, 280 | "cell_type": "markdown", 281 | "metadata": {}, 282 | "source": [ 283 | "Or if we just want the names of the player characters:\n" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 9, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "data": { 293 | "text/plain": [ 
294 | "['Aurora', 'Blaze']" 295 | ] 296 | }, 297 | "execution_count": 9, 298 | "metadata": {}, 299 | "output_type": "execute_result" 300 | } 301 | ], 302 | "source": [ 303 | "[x[\"name\"] for x in response_structured[\"pcs\"]]" 304 | ] 305 | }, 306 | { 307 | "attachments": {}, 308 | "cell_type": "markdown", 309 | "metadata": {}, 310 | "source": [ 311 | "## Structured Output and Structured Input\n", 312 | "\n", 313 | "Now that we have a schema for a TTRPG setting, we can use the same hints we defined to help generation of a TTRPG adventure!\n", 314 | "\n", 315 | "First, we convert the structured `dict` data to a pydantic object with that schema with `parse_obj`:\n" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 10, 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "input_ttrpg = write_ttrpg_setting.model_validate(response_structured)" 325 | ] 326 | }, 327 | { 328 | "attachments": {}, 329 | "cell_type": "markdown", 330 | "metadata": {}, 331 | "source": [ 332 | "Next, we define a schema for a list of events. To keep things simple, we'll just do **dialogue** and **setting** events. (a proper TTRPG would likely have a more detailed combat system!)\n", 333 | "\n", 334 | "There are a few other helpful object types you can use to control output:\n", 335 | "\n", 336 | "- `Literal`, to force a certain range of values.\n", 337 | "- `Union` can be used to have the model select from a set of schema. 
For example we have one schema for `Dialogue` and one schema for `Setting`: if unioned, the model will use only one of them, which allows for token-saving output.\n", 338 | "\n", 339 | "Lastly, if the `Field(description=...)` pattern is too wordy, you can use `fd` which is a shortcut.\n" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 25, 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [ 348 | "from simpleaichat.utils import fd\n", 349 | "\n", 350 | "\n", 351 | "class Dialogue(BaseModel):\n", 352 | " character_name: str = fd(\"Character name\")\n", 353 | " dialogue: str = fd(\"Dialogue from the character\")\n", 354 | "\n", 355 | "\n", 356 | "class Setting(BaseModel):\n", 357 | " description: str = fd(\n", 358 | " \"Detailed setting or event description, e.g. The sun was bright.\"\n", 359 | " )\n", 360 | "\n", 361 | "\n", 362 | "class Event(BaseModel):\n", 363 | " type: Literal[\"setting\", \"conversation\"] = fd(\n", 364 | " \"Whether the event is a scene setting or a conversation by an NPC\"\n", 365 | " )\n", 366 | " data: Union[Dialogue, Setting] = fd(\"Event data\")\n", 367 | "\n", 368 | "\n", 369 | "class write_ttrpg_story(BaseModel):\n", 370 | " \"\"\"Write an award-winning TTRPG story\"\"\"\n", 371 | "\n", 372 | " events: List[Event] = fd(\"All events in a TTRPG campaign.\")" 373 | ] 374 | }, 375 | { 376 | "attachments": {}, 377 | "cell_type": "markdown", 378 | "metadata": {}, 379 | "source": [ 380 | "Lastly, we'll need a new system prompt since we have a different goal.\n" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 26, 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [ 389 | "system_prompt_event = \"\"\"You are a world-renowned game master (GM) of tabletop role-playing games (RPGs).\n", 390 | "\n", 391 | "Write a complete three-act story in 10 events with a shocking twist ending using the data from the input_ttrpg function. 
Write the player characters as a TTRPG party fighting against a new evil.\n", 392 | "\n", 393 | "In the second (2nd) event, the party must be formed.\n", 394 | "\n", 395 | "Rules you MUST follow:\n", 396 | "- Always write in the style of 80's fantasy novels.\n", 397 | "- All names you create must be creative and unique. Always subvert expectations.\"\"\"" 398 | ] 399 | }, 400 | { 401 | "attachments": {}, 402 | "cell_type": "markdown", 403 | "metadata": {}, 404 | "source": [ 405 | "For the final call, we will need the parsed `input_ttrpg` object as the new \"prompt\", plus the `write_ttrpg_setting` schema used to build it as the `input_schema`.\n", 406 | "\n", 407 | "Putting it all together:\n" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 27, 413 | "metadata": {}, 414 | "outputs": [ 415 | { 416 | "name": "stdout", 417 | "output_type": "stream", 418 | "text": [ 419 | "{\n", 420 | " \"events\": [\n", 421 | " {\n", 422 | " \"type\": \"setting\",\n", 423 | " \"data\": {\n", 424 | " \"description\": \"The sun rises over the golden shores of Pythos, casting a warm glow on the city of Pythopolis. The bustling metropolis is a hub of technology and beach volleyball, where the Code Warriors, a group of talented programmers and volleyball enthusiasts, reside. In the heart of the city, the Code Warriors' headquarters stands tall, a symbol of their dedication to both software development and sports. Inside, Aurora, a talented Python developer, and Blaze, a wise volleyball coach, prepare for a fateful meeting that will change their lives forever.\"\n", 425 | " }\n", 426 | " },\n", 427 | " {\n", 428 | " \"type\": \"conversation\",\n", 429 | " \"data\": {\n", 430 | " \"character_name\": \"Aurora\",\n", 431 | " \"dialogue\": \"Blaze, I've been following your coaching career for years. Your innovative techniques have revolutionized the game of beach volleyball. 
I want to combine my programming skills with my love for volleyball, and I believe you can help me become the best.\"\n", 432 | " }\n", 433 | " },\n", 434 | " {\n", 435 | " \"type\": \"conversation\",\n", 436 | " \"data\": {\n", 437 | " \"character_name\": \"Blaze\",\n", 438 | " \"dialogue\": \"Ah, Aurora, I've heard of your coding prowess. Your Python skills are legendary in the tech industry. If you're willing to put in the work, I can help you unlock the true potential of beach volleyball. Together, we can create a new era of sports technology.\"\n", 439 | " }\n", 440 | " },\n", 441 | " {\n", 442 | " \"type\": \"conversation\",\n", 443 | " \"data\": {\n", 444 | " \"character_name\": \"Aurora\",\n", 445 | " \"dialogue\": \"I'm ready to dive in, Blaze. I want to use Python to analyze the game, develop new strategies, and push the boundaries of what's possible on the court.\"\n", 446 | " }\n", 447 | " },\n", 448 | " {\n", 449 | " \"type\": \"conversation\",\n", 450 | " \"data\": {\n", 451 | " \"character_name\": \"Blaze\",\n", 452 | " \"dialogue\": \"Then we shall embark on this journey together, Aurora. But first, we must gather a team of like-minded individuals who share our passion for both programming and beach volleyball.\"\n", 453 | " }\n", 454 | " },\n", 455 | " {\n", 456 | " \"type\": \"setting\",\n", 457 | " \"data\": {\n", 458 | " \"description\": \"Aurora and Blaze set out on a quest to recruit the most talented individuals in Pythopolis. They scour the tech hubs, beachside cafes, and coding competitions, seeking those who possess the perfect blend of coding skills and volleyball prowess. After weeks of searching, they finally assemble a team of four exceptional individuals who are ready to join their cause.\"\n", 459 | " }\n", 460 | " },\n", 461 | " {\n", 462 | " \"type\": \"conversation\",\n", 463 | " \"data\": {\n", 464 | " \"character_name\": \"Aurora\",\n", 465 | " \"dialogue\": \"Welcome, my fellow Code Warriors! 
Together, we will combine the power of Python programming and the art of beach volleyball to achieve greatness. Introduce yourselves, and let us know how you plan to contribute to our mission.\"\n", 466 | " }\n", 467 | " },\n", 468 | " {\n", 469 | " \"type\": \"conversation\",\n", 470 | " \"data\": {\n", 471 | " \"character_name\": \"Samurai\",\n", 472 | " \"dialogue\": \"I am Samurai, a master of precision and discipline. My coding skills are unmatched, and my agility on the volleyball court is unparalleled. With my strategic mind and lightning-fast reflexes, I will ensure victory for our team.\"\n", 473 | " }\n", 474 | " },\n", 475 | " {\n", 476 | " \"type\": \"conversation\",\n", 477 | " \"data\": {\n", 478 | " \"character_name\": \"Pixel\",\n", 479 | " \"dialogue\": \"Greetings, Code Warriors! I am Pixel, the pixel-perfect programmer. My attention to detail and creative problem-solving abilities make me a valuable asset to any team. On the volleyball court, my quick thinking and adaptability will outwit our opponents.\"\n", 480 | " }\n", 481 | " },\n", 482 | " {\n", 483 | " \"type\": \"conversation\",\n", 484 | " \"data\": {\n", 485 | " \"character_name\": \"Nebula\",\n", 486 | " \"dialogue\": \"I am Nebula, a cosmic coder with a passion for the stars and the digital realm. My expertise lies in data analysis and visualization. I will harness the power of Python to uncover hidden patterns in our opponents' strategies and guide us to victory.\"\n", 487 | " }\n", 488 | " },\n", 489 | " {\n", 490 | " \"type\": \"conversation\",\n", 491 | " \"data\": {\n", 492 | " \"character_name\": \"Blitz\",\n", 493 | " \"dialogue\": \"Greetings, Code Warriors! I am Blitz, the lightning-fast developer. My coding speed is unmatched, and my agility on the volleyball court is electrifying. 
With my lightning-quick reflexes and powerful spikes, I will leave our opponents in awe.\"\n", 494 | " }\n", 495 | " }\n", 496 | " ]\n", 497 | "}\n" 498 | ] 499 | } 500 | ], 501 | "source": [ 502 | "ai_2 = AIChat(system=system_prompt_event, model=model, save_messages=False, api_key=api_key)\n", 503 | "\n", 504 | "response_story = ai_2(\n", 505 | " input_ttrpg, input_schema=write_ttrpg_setting, output_schema=write_ttrpg_story\n", 506 | ")\n", 507 | "\n", 508 | "print(orjson.dumps(response_story, option=orjson.OPT_INDENT_2).decode())" 509 | ] 510 | }, 511 | { 512 | "attachments": {}, 513 | "cell_type": "markdown", 514 | "metadata": {}, 515 | "source": [ 516 | "Now that we have a structured output, we can output it like a story, with custom and consistent formatting!" 517 | ] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "execution_count": 28, 522 | "metadata": {}, 523 | "outputs": [ 524 | { 525 | "data": { 526 | "text/html": [ 527 | "
The sun rises over the golden shores of Pythos, casting a \n",
528 |        "warm glow on the city of Pythopolis. The bustling metropolis\n",
529 |        "is a hub of technology and beach volleyball, where the Code \n",
530 |        "Warriors, a group of talented programmers and volleyball \n",
531 |        "enthusiasts, reside. In the heart of the city, the Code \n",
532 |        "Warriors' headquarters stands tall, a symbol of their \n",
533 |        "dedication to both software development and sports. Inside, \n",
534 |        "Aurora, a talented Python developer, and Blaze, a wise \n",
535 |        "volleyball coach, prepare for a fateful meeting that will \n",
536 |        "change their lives forever.\n",
537 |        "
\n" 538 | ], 539 | "text/plain": [ 540 | "\u001b[3mThe sun rises over the golden shores of Pythos, casting a \u001b[0m\n", 541 | "\u001b[3mwarm glow on the city of Pythopolis. The bustling metropolis\u001b[0m\n", 542 | "\u001b[3mis a hub of technology and beach volleyball, where the Code \u001b[0m\n", 543 | "\u001b[3mWarriors, a group of talented programmers and volleyball \u001b[0m\n", 544 | "\u001b[3menthusiasts, reside. In the heart of the city, the Code \u001b[0m\n", 545 | "\u001b[3mWarriors' headquarters stands tall, a symbol of their \u001b[0m\n", 546 | "\u001b[3mdedication to both software development and sports. Inside, \u001b[0m\n", 547 | "\u001b[3mAurora, a talented Python developer, and Blaze, a wise \u001b[0m\n", 548 | "\u001b[3mvolleyball coach, prepare for a fateful meeting that will \u001b[0m\n", 549 | "\u001b[3mchange their lives forever.\u001b[0m\n" 550 | ] 551 | }, 552 | "metadata": {}, 553 | "output_type": "display_data" 554 | }, 555 | { 556 | "data": { 557 | "text/html": [ 558 | "
Aurora: Blaze, I've been following your coaching career for \n",
559 |        "years. Your innovative techniques have revolutionized the \n",
560 |        "game of beach volleyball. I want to combine my programming \n",
561 |        "skills with my love for volleyball, and I believe you can \n",
562 |        "help me become the best.\n",
563 |        "
\n" 564 | ], 565 | "text/plain": [ 566 | "\u001b[1mAurora\u001b[0m: Blaze, I've been following your coaching career for \n", 567 | "years. Your innovative techniques have revolutionized the \n", 568 | "game of beach volleyball. I want to combine my programming \n", 569 | "skills with my love for volleyball, and I believe you can \n", 570 | "help me become the best.\n" 571 | ] 572 | }, 573 | "metadata": {}, 574 | "output_type": "display_data" 575 | }, 576 | { 577 | "data": { 578 | "text/html": [ 579 | "
Blaze: Ah, Aurora, I've heard of your coding prowess. Your \n",
580 |        "Python skills are legendary in the tech industry. If you're \n",
581 |        "willing to put in the work, I can help you unlock the true \n",
582 |        "potential of beach volleyball. Together, we can create a new\n",
583 |        "era of sports technology.\n",
584 |        "
\n" 585 | ], 586 | "text/plain": [ 587 | "\u001b[1mBlaze\u001b[0m: Ah, Aurora, I've heard of your coding prowess. Your \n", 588 | "Python skills are legendary in the tech industry. If you're \n", 589 | "willing to put in the work, I can help you unlock the true \n", 590 | "potential of beach volleyball. Together, we can create a new\n", 591 | "era of sports technology.\n" 592 | ] 593 | }, 594 | "metadata": {}, 595 | "output_type": "display_data" 596 | }, 597 | { 598 | "data": { 599 | "text/html": [ 600 | "
Aurora: I'm ready to dive in, Blaze. I want to use Python to\n",
601 |        "analyze the game, develop new strategies, and push the \n",
602 |        "boundaries of what's possible on the court.\n",
603 |        "
\n" 604 | ], 605 | "text/plain": [ 606 | "\u001b[1mAurora\u001b[0m: I'm ready to dive in, Blaze. I want to use Python to\n", 607 | "analyze the game, develop new strategies, and push the \n", 608 | "boundaries of what's possible on the court.\n" 609 | ] 610 | }, 611 | "metadata": {}, 612 | "output_type": "display_data" 613 | }, 614 | { 615 | "data": { 616 | "text/html": [ 617 | "
Blaze: Then we shall embark on this journey together, \n",
618 |        "Aurora. But first, we must gather a team of like-minded \n",
619 |        "individuals who share our passion for both programming and \n",
620 |        "beach volleyball.\n",
621 |        "
\n" 622 | ], 623 | "text/plain": [ 624 | "\u001b[1mBlaze\u001b[0m: Then we shall embark on this journey together, \n", 625 | "Aurora. But first, we must gather a team of like-minded \n", 626 | "individuals who share our passion for both programming and \n", 627 | "beach volleyball.\n" 628 | ] 629 | }, 630 | "metadata": {}, 631 | "output_type": "display_data" 632 | }, 633 | { 634 | "data": { 635 | "text/html": [ 636 | "
Aurora and Blaze set out on a quest to recruit the most \n",
637 |        "talented individuals in Pythopolis. They scour the tech \n",
638 |        "hubs, beachside cafes, and coding competitions, seeking \n",
639 |        "those who possess the perfect blend of coding skills and \n",
640 |        "volleyball prowess. After weeks of searching, they finally \n",
641 |        "assemble a team of four exceptional individuals who are \n",
642 |        "ready to join their cause.\n",
643 |        "
\n" 644 | ], 645 | "text/plain": [ 646 | "\u001b[3mAurora and Blaze set out on a quest to recruit the most \u001b[0m\n", 647 | "\u001b[3mtalented individuals in Pythopolis. They scour the tech \u001b[0m\n", 648 | "\u001b[3mhubs, beachside cafes, and coding competitions, seeking \u001b[0m\n", 649 | "\u001b[3mthose who possess the perfect blend of coding skills and \u001b[0m\n", 650 | "\u001b[3mvolleyball prowess. After weeks of searching, they finally \u001b[0m\n", 651 | "\u001b[3massemble a team of four exceptional individuals who are \u001b[0m\n", 652 | "\u001b[3mready to join their cause.\u001b[0m\n" 653 | ] 654 | }, 655 | "metadata": {}, 656 | "output_type": "display_data" 657 | }, 658 | { 659 | "data": { 660 | "text/html": [ 661 | "
Aurora: Welcome, my fellow Code Warriors! Together, we will \n",
662 |        "combine the power of Python programming and the art of beach\n",
663 |        "volleyball to achieve greatness. Introduce yourselves, and \n",
664 |        "let us know how you plan to contribute to our mission.\n",
665 |        "
\n" 666 | ], 667 | "text/plain": [ 668 | "\u001b[1mAurora\u001b[0m: Welcome, my fellow Code Warriors! Together, we will \n", 669 | "combine the power of Python programming and the art of beach\n", 670 | "volleyball to achieve greatness. Introduce yourselves, and \n", 671 | "let us know how you plan to contribute to our mission.\n" 672 | ] 673 | }, 674 | "metadata": {}, 675 | "output_type": "display_data" 676 | }, 677 | { 678 | "data": { 679 | "text/html": [ 680 | "
Samurai: I am Samurai, a master of precision and discipline.\n",
681 |        "My coding skills are unmatched, and my agility on the \n",
682 |        "volleyball court is unparalleled. With my strategic mind and\n",
683 |        "lightning-fast reflexes, I will ensure victory for our team.\n",
684 |        "
\n" 685 | ], 686 | "text/plain": [ 687 | "\u001b[1mSamurai\u001b[0m: I am Samurai, a master of precision and discipline.\n", 688 | "My coding skills are unmatched, and my agility on the \n", 689 | "volleyball court is unparalleled. With my strategic mind and\n", 690 | "lightning-fast reflexes, I will ensure victory for our team.\n" 691 | ] 692 | }, 693 | "metadata": {}, 694 | "output_type": "display_data" 695 | }, 696 | { 697 | "data": { 698 | "text/html": [ 699 | "
Pixel: Greetings, Code Warriors! I am Pixel, the \n",
700 |        "pixel-perfect programmer. My attention to detail and \n",
701 |        "creative problem-solving abilities make me a valuable asset \n",
702 |        "to any team. On the volleyball court, my quick thinking and \n",
703 |        "adaptability will outwit our opponents.\n",
704 |        "
\n" 705 | ], 706 | "text/plain": [ 707 | "\u001b[1mPixel\u001b[0m: Greetings, Code Warriors! I am Pixel, the \n", 708 | "pixel-perfect programmer. My attention to detail and \n", 709 | "creative problem-solving abilities make me a valuable asset \n", 710 | "to any team. On the volleyball court, my quick thinking and \n", 711 | "adaptability will outwit our opponents.\n" 712 | ] 713 | }, 714 | "metadata": {}, 715 | "output_type": "display_data" 716 | }, 717 | { 718 | "data": { 719 | "text/html": [ 720 | "
Nebula: I am Nebula, a cosmic coder with a passion for the \n",
721 |        "stars and the digital realm. My expertise lies in data \n",
722 |        "analysis and visualization. I will harness the power of \n",
723 |        "Python to uncover hidden patterns in our opponents' \n",
724 |        "strategies and guide us to victory.\n",
725 |        "
\n" 726 | ], 727 | "text/plain": [ 728 | "\u001b[1mNebula\u001b[0m: I am Nebula, a cosmic coder with a passion for the \n", 729 | "stars and the digital realm. My expertise lies in data \n", 730 | "analysis and visualization. I will harness the power of \n", 731 | "Python to uncover hidden patterns in our opponents' \n", 732 | "strategies and guide us to victory.\n" 733 | ] 734 | }, 735 | "metadata": {}, 736 | "output_type": "display_data" 737 | }, 738 | { 739 | "data": { 740 | "text/html": [ 741 | "
Blitz: Greetings, Code Warriors! I am Blitz, the \n",
742 |        "lightning-fast developer. My coding speed is unmatched, and \n",
743 |        "my agility on the volleyball court is electrifying. With my \n",
744 |        "lightning-quick reflexes and powerful spikes, I will leave \n",
745 |        "our opponents in awe.\n",
746 |        "
\n" 747 | ], 748 | "text/plain": [ 749 | "\u001b[1mBlitz\u001b[0m: Greetings, Code Warriors! I am Blitz, the \n", 750 | "lightning-fast developer. My coding speed is unmatched, and \n", 751 | "my agility on the volleyball court is electrifying. With my \n", 752 | "lightning-quick reflexes and powerful spikes, I will leave \n", 753 | "our opponents in awe.\n" 754 | ] 755 | }, 756 | "metadata": {}, 757 | "output_type": "display_data" 758 | } 759 | ], 760 | "source": [ 761 | "c = Console(width=60, highlight=False)\n", 762 | "\n", 763 | "for event in response_story[\"events\"]:\n", 764 | " data = event[\"data\"]\n", 765 | " if event[\"type\"] == \"setting\":\n", 766 | " c.print(data[\"description\"], style=\"italic\")\n", 767 | " if event[\"type\"] == \"conversation\":\n", 768 | " c.print(f\"[b]{data['character_name']}[/b]: {data['dialogue']}\")" 769 | ] 770 | }, 771 | { 772 | "attachments": {}, 773 | "cell_type": "markdown", 774 | "metadata": {}, 775 | "source": [ 776 | "## MIT License\n", 777 | "\n", 778 | "Copyright (c) 2023 Max Woolf\n", 779 | "\n", 780 | "Permission is hereby granted, free of charge, to any person obtaining a copy\n", 781 | "of this software and associated documentation files (the \"Software\"), to deal\n", 782 | "in the Software without restriction, including without limitation the rights\n", 783 | "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 784 | "copies of the Software, and to permit persons to whom the Software is\n", 785 | "furnished to do so, subject to the following conditions:\n", 786 | "\n", 787 | "The above copyright notice and this permission notice shall be included in all\n", 788 | "copies or substantial portions of the Software.\n", 789 | "\n", 790 | "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 791 | "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 792 | "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n", 793 | "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 794 | "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 795 | "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", 796 | "SOFTWARE." 797 | ] 798 | } 799 | ], 800 | "metadata": { 801 | "kernelspec": { 802 | "display_name": "Python 3", 803 | "language": "python", 804 | "name": "python3" 805 | }, 806 | "language_info": { 807 | "codemirror_mode": { 808 | "name": "ipython", 809 | "version": 3 810 | }, 811 | "file_extension": ".py", 812 | "mimetype": "text/x-python", 813 | "name": "python", 814 | "nbconvert_exporter": "python", 815 | "pygments_lexer": "ipython3", 816 | "version": "3.9.12" 817 | }, 818 | "orig_nbformat": 4 819 | }, 820 | "nbformat": 4, 821 | "nbformat_minor": 2 822 | } 823 | -------------------------------------------------------------------------------- /examples/notebooks/simpleaichat_async.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": { 7 | "id": "tFuqw725l_4A" 8 | }, 9 | "source": [ 10 | "# Async ChatGPT with simpleaichat\n", 11 | "\n", 12 | "simpleaichat has an `AsyncAIChat` class which allows all requests to be async. This works for normal chats, streaming, and tools!\n", 13 | "\n", 14 | "With that, you can implement it into an async webapp such as FastAPI, or query multiple requests at the same time. And because `AsyncAIChat` is a session manager, you can conduct independent chat sessions!" 
15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": { 21 | "id": "JuI3NyYylPK6" 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "!pip install -q simpleaichat\n", 26 | "\n", 27 | "from simpleaichat import AsyncAIChat\n", 28 | "from getpass import getpass\n", 29 | "import asyncio" 30 | ] 31 | }, 32 | { 33 | "attachments": {}, 34 | "cell_type": "markdown", 35 | "metadata": { 36 | "id": "OSp2OXNRmi1j" 37 | }, 38 | "source": [ 39 | "For the following cell, input your OpenAI API key when prompted. **It will not be saved to the notebook**." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": { 46 | "colab": { 47 | "base_uri": "https://localhost:8080/" 48 | }, 49 | "id": "as8yMGC8mo6I", 50 | "outputId": "8f7ea9e9-364b-4f2e-c19d-8147630dc0ac" 51 | }, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "OpenAI Key: ··········\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "api_key = getpass(\"OpenAI Key: \")" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 12, 68 | "metadata": { 69 | "id": "NK-f4AJElPK7" 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "ai = AsyncAIChat(api_key=api_key, console=False)" 74 | ] 75 | }, 76 | { 77 | "attachments": {}, 78 | "cell_type": "markdown", 79 | "metadata": { 80 | "id": "0cTgt37umrms" 81 | }, 82 | "source": [ 83 | "## Async Generation\n", 84 | "\n", 85 | "Async calls are typical async, with an `await` keyword." 
86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 13, 91 | "metadata": { 92 | "colab": { 93 | "base_uri": "https://localhost:8080/" 94 | }, 95 | "id": "R80NS8jflPK7", 96 | "outputId": "e34f515b-d3c1-40e0-e319-0c0fa3e20fdf" 97 | }, 98 | "outputs": [ 99 | { 100 | "name": "stdout", 101 | "output_type": "stream", 102 | "text": [ 103 | "The capital of California is Sacramento.\n" 104 | ] 105 | } 106 | ], 107 | "source": [ 108 | "response = await ai(\"What is the capital of California?\")\n", 109 | "print(response)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 14, 115 | "metadata": { 116 | "colab": { 117 | "base_uri": "https://localhost:8080/" 118 | }, 119 | "id": "_ZxNQq8alPK8", 120 | "outputId": "765697cb-7cb7-4855-a894-1f6261039e87" 121 | }, 122 | "outputs": [ 123 | { 124 | "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | "Sacramento was founded on February 27, 1850.\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "response = await ai(\"When was it founded?\")\n", 133 | "print(response)" 134 | ] 135 | }, 136 | { 137 | "attachments": {}, 138 | "cell_type": "markdown", 139 | "metadata": { 140 | "id": "ivvJNrErn6aJ" 141 | }, 142 | "source": [ 143 | "Now, let's ask for multiple distinct states, at the same time. 
It will take roughly the same amount of time to complete as a single state!\n", 144 | "\n", 145 | "To do that, we create a session for each input state:" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 34, 151 | "metadata": { 152 | "id": "6NeY0Ls_lPK9" 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "states = [\"Washington\", \"New Mexico\", \"Texas\", \"Mississippi\", \"Alaska\"]\n", 157 | "\n", 158 | "ai_2 = AsyncAIChat(api_key=api_key, console=False)\n", 159 | "for state in states:\n", 160 | " ai_2.new_session(api_key=api_key, id=state)" 161 | ] 162 | }, 163 | { 164 | "attachments": {}, 165 | "cell_type": "markdown", 166 | "metadata": { 167 | "id": "FyxsQi4RoKhF" 168 | }, 169 | "source": [ 170 | "Each call creates a task coroutine; we can store the tasks, then run them all with `asyncio.gather`." 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 35, 176 | "metadata": { 177 | "colab": { 178 | "base_uri": "https://localhost:8080/" 179 | }, 180 | "id": "6v7V_jM_lPLA", 181 | "outputId": "b83589d3-76de-4c28-fad6-43e1c37a6f42" 182 | }, 183 | "outputs": [ 184 | { 185 | "data": { 186 | "text/plain": [ 187 | "['The capital of Washington is Olympia.',\n", 188 | " 'The capital of New Mexico is Santa Fe.',\n", 189 | " 'The capital of Texas is Austin.',\n", 190 | " 'The capital of Mississippi is Jackson.',\n", 191 | " 'The capital of Alaska is Juneau.']" 192 | ] 193 | }, 194 | "execution_count": 35, 195 | "metadata": {}, 196 | "output_type": "execute_result" 197 | } 198 | ], 199 | "source": [ 200 | "tasks = []\n", 201 | "for state in states:\n", 202 | " tasks.append(ai_2(f\"What is the capital of {state}?\", id=state))\n", 203 | "\n", 204 | "results = await asyncio.gather(*tasks)\n", 205 | "results" 206 | ] 207 | }, 208 | { 209 | "attachments": {}, 210 | "cell_type": "markdown", 211 | "metadata": { 212 | "id": "xJTvRG2YpmPN" 213 | }, 214 | "source": [ 215 | "Now, to ask the same question to all states:" 216 | ] 
217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 36, 221 | "metadata": { 222 | "colab": { 223 | "base_uri": "https://localhost:8080/" 224 | }, 225 | "id": "fOggtHsiptcI", 226 | "outputId": "2208559d-57e7-42aa-d821-41ff058c4d57" 227 | }, 228 | "outputs": [ 229 | { 230 | "data": { 231 | "text/plain": [ 232 | "['Olympia was founded in 1853.',\n", 233 | " 'Santa Fe was founded in 1610, making it the oldest state capital in the United States.',\n", 234 | " 'Austin was founded on December 27, 1839.',\n", 235 | " 'Jackson was founded on December 23, 1821.',\n", 236 | " 'Juneau was founded on October 18, 1880.']" 237 | ] 238 | }, 239 | "execution_count": 36, 240 | "metadata": {}, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "tasks = []\n", 246 | "for state in states:\n", 247 | " tasks.append(ai_2(\"When was it founded?\", id=state))\n", 248 | "\n", 249 | "results = await asyncio.gather(*tasks)\n", 250 | "results" 251 | ] 252 | }, 253 | { 254 | "attachments": {}, 255 | "cell_type": "markdown", 256 | "metadata": { 257 | "id": "ywVjhqJHoYqb" 258 | }, 259 | "source": [ 260 | "Indeed, the messages are stored correctly by session, and are still independent between sessions." 
261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 37, 266 | "metadata": { 267 | "colab": { 268 | "base_uri": "https://localhost:8080/" 269 | }, 270 | "id": "NUfgFE4TlPLB", 271 | "outputId": "38a1b0d7-9603-42a9-b66a-a4b5e76e957c" 272 | }, 273 | "outputs": [ 274 | { 275 | "data": { 276 | "text/plain": [ 277 | "[What is the capital of Washington?,\n", 278 | " The capital of Washington is Olympia.,\n", 279 | " When was it founded?,\n", 280 | " Olympia was founded in 1853.]" 281 | ] 282 | }, 283 | "execution_count": 37, 284 | "metadata": {}, 285 | "output_type": "execute_result" 286 | } 287 | ], 288 | "source": [ 289 | "ai_2.sessions[\"Washington\"].messages" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 38, 295 | "metadata": { 296 | "colab": { 297 | "base_uri": "https://localhost:8080/" 298 | }, 299 | "id": "zzbXUhgtlPLB", 300 | "outputId": "44fb73e6-902a-4e9c-a584-b2ea8943bbba" 301 | }, 302 | "outputs": [ 303 | { 304 | "data": { 305 | "text/plain": [ 306 | "[What is the capital of Texas?,\n", 307 | " The capital of Texas is Austin.,\n", 308 | " When was it founded?,\n", 309 | " Austin was founded on December 27, 1839.]" 310 | ] 311 | }, 312 | "execution_count": 38, 313 | "metadata": {}, 314 | "output_type": "execute_result" 315 | } 316 | ], 317 | "source": [ 318 | "ai_2.sessions[\"Texas\"].messages" 319 | ] 320 | }, 321 | { 322 | "attachments": {}, 323 | "cell_type": "markdown", 324 | "metadata": { 325 | "id": "n47XlrZmojU5" 326 | }, 327 | "source": [ 328 | "## Async Streaming\n", 329 | "\n", 330 | "Now, let's do the same thing, except with streaming." 
331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 20, 336 | "metadata": { 337 | "id": "cgJaPdlvoar9" 338 | }, 339 | "outputs": [], 340 | "source": [ 341 | "ai = AsyncAIChat(api_key=api_key, console=False)" 342 | ] 343 | }, 344 | { 345 | "attachments": {}, 346 | "cell_type": "markdown", 347 | "metadata": { 348 | "id": "kjyUp_PZotON" 349 | }, 350 | "source": [ 351 | "In this case, you need an async generator for the streaming call." 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 21, 357 | "metadata": { 358 | "colab": { 359 | "base_uri": "https://localhost:8080/" 360 | }, 361 | "id": "niO2SZN4owuZ", 362 | "outputId": "d3c28d6d-fa59-40ca-a903-e3ce8ba21ebc" 363 | }, 364 | "outputs": [ 365 | { 366 | "name": "stdout", 367 | "output_type": "stream", 368 | "text": [ 369 | "{'delta': 'The', 'response': 'The'}\n", 370 | "{'delta': ' capital', 'response': 'The capital'}\n", 371 | "{'delta': ' of', 'response': 'The capital of'}\n", 372 | "{'delta': ' California', 'response': 'The capital of California'}\n", 373 | "{'delta': ' is', 'response': 'The capital of California is'}\n", 374 | "{'delta': ' Sacramento', 'response': 'The capital of California is Sacramento'}\n", 375 | "{'delta': '.', 'response': 'The capital of California is Sacramento.'}\n" 376 | ] 377 | } 378 | ], 379 | "source": [ 380 | "async for chunk in await ai.stream(\"What is the capital of California?\"):\n", 381 | " print(chunk)" 382 | ] 383 | }, 384 | { 385 | "attachments": {}, 386 | "cell_type": "markdown", 387 | "metadata": { 388 | "id": "4um4Mv05o4OB" 389 | }, 390 | "source": [ 391 | "For multistate generation:" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": 39, 397 | "metadata": { 398 | "id": "Y4G3VO6qo5yy" 399 | }, 400 | "outputs": [], 401 | "source": [ 402 | "states = [\"Washington\", \"New Mexico\", \"Texas\", \"Mississippi\", \"Alaska\"]\n", 403 | "\n", 404 | "ai_2 = AsyncAIChat(api_key=api_key, console=False)\n", 405 
| "for state in states:\n", 406 | " ai_2.new_session(api_key=api_key, id=state)" 407 | ] 408 | }, 409 | { 410 | "attachments": {}, 411 | "cell_type": "markdown", 412 | "metadata": { 413 | "id": "oewS5NAVpNwT" 414 | }, 415 | "source": [ 416 | "This implementation is slightly more complicated since you need to wrap each async generator in its own async function. However, it provides the best demonstration of async, as you can clearly see that each chunk is received in a different order." 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 40, 422 | "metadata": { 423 | "colab": { 424 | "base_uri": "https://localhost:8080/" 425 | }, 426 | "id": "HVrZ16aXpAYm", 427 | "outputId": "9015c66a-fbf1-4cb6-bbbf-ef62b1a6bf79" 428 | }, 429 | "outputs": [ 430 | { 431 | "name": "stdout", 432 | "output_type": "stream", 433 | "text": [ 434 | "{'delta': 'The', 'response': 'The'}\n", 435 | "{'delta': 'The', 'response': 'The'}\n", 436 | "{'delta': ' capital', 'response': 'The capital'}\n", 437 | "{'delta': ' capital', 'response': 'The capital'}\n", 438 | "{'delta': ' of', 'response': 'The capital of'}\n", 439 | "{'delta': ' of', 'response': 'The capital of'}\n", 440 | "{'delta': 'The', 'response': 'The'}\n", 441 | "{'delta': 'The', 'response': 'The'}\n", 442 | "{'delta': ' Washington', 'response': 'The capital of Washington'}\n", 443 | "{'delta': ' Mississippi', 'response': 'The capital of Mississippi'}\n", 444 | "{'delta': 'The', 'response': 'The'}\n", 445 | "{'delta': ' capital', 'response': 'The capital'}\n", 446 | "{'delta': ' capital', 'response': 'The capital'}\n", 447 | "{'delta': ' is', 'response': 'The capital of Washington is'}\n", 448 | "{'delta': ' is', 'response': 'The capital of Mississippi is'}\n", 449 | "{'delta': ' of', 'response': 'The capital of'}\n", 450 | "{'delta': ' capital', 'response': 'The capital'}\n", 451 | "{'delta': ' of', 'response': 'The capital of'}\n", 452 | "{'delta': ' Jackson', 'response': 'The capital of Mississippi is 
Jackson'}\n", 453 | "{'delta': ' Olympia', 'response': 'The capital of Washington is Olympia'}\n", 454 | "{'delta': ' Alaska', 'response': 'The capital of Alaska'}\n", 455 | "{'delta': ' New', 'response': 'The capital of New'}\n", 456 | "{'delta': '.', 'response': 'The capital of Washington is Olympia.'}\n", 457 | "{'delta': ' of', 'response': 'The capital of'}\n", 458 | "{'delta': '.', 'response': 'The capital of Mississippi is Jackson.'}\n", 459 | "{'delta': ' Mexico', 'response': 'The capital of New Mexico'}\n", 460 | "{'delta': ' is', 'response': 'The capital of Alaska is'}\n", 461 | "{'delta': ' Texas', 'response': 'The capital of Texas'}\n", 462 | "{'delta': ' is', 'response': 'The capital of New Mexico is'}\n", 463 | "{'delta': ' June', 'response': 'The capital of Alaska is June'}\n", 464 | "{'delta': ' is', 'response': 'The capital of Texas is'}\n", 465 | "{'delta': ' Santa', 'response': 'The capital of New Mexico is Santa'}\n", 466 | "{'delta': ' Austin', 'response': 'The capital of Texas is Austin'}\n", 467 | "{'delta': 'au', 'response': 'The capital of Alaska is Juneau'}\n", 468 | "{'delta': ' Fe', 'response': 'The capital of New Mexico is Santa Fe'}\n", 469 | "{'delta': '.', 'response': 'The capital of Texas is Austin.'}\n", 470 | "{'delta': '.', 'response': 'The capital of Alaska is Juneau.'}\n", 471 | "{'delta': '.', 'response': 'The capital of New Mexico is Santa Fe.'}\n" 472 | ] 473 | }, 474 | { 475 | "data": { 476 | "text/plain": [ 477 | "['The capital of Washington is Olympia.',\n", 478 | " 'The capital of New Mexico is Santa Fe.',\n", 479 | " 'The capital of Texas is Austin.',\n", 480 | " 'The capital of Mississippi is Jackson.',\n", 481 | " 'The capital of Alaska is Juneau.']" 482 | ] 483 | }, 484 | "execution_count": 40, 485 | "metadata": {}, 486 | "output_type": "execute_result" 487 | } 488 | ], 489 | "source": [ 490 | "async def capital_stream(state):\n", 491 | " async for chunk in await ai_2.stream(f\"What is the capital of {state}?\", 
id=state):\n", 492 | " response = chunk\n", 493 | " print(response)\n", 494 | " return response[\"response\"]\n", 495 | "\n", 496 | "tasks = []\n", 497 | "for state in states:\n", 498 | " tasks.append(capital_stream(state))\n", 499 | "\n", 500 | "results = await asyncio.gather(*tasks)\n", 501 | "results" 502 | ] 503 | }, 504 | { 505 | "attachments": {}, 506 | "cell_type": "markdown", 507 | "metadata": {}, 508 | "source": [ 509 | "## MIT License\n", 510 | "\n", 511 | "Copyright (c) 2023 Max Woolf\n", 512 | "\n", 513 | "Permission is hereby granted, free of charge, to any person obtaining a copy\n", 514 | "of this software and associated documentation files (the \"Software\"), to deal\n", 515 | "in the Software without restriction, including without limitation the rights\n", 516 | "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 517 | "copies of the Software, and to permit persons to whom the Software is\n", 518 | "furnished to do so, subject to the following conditions:\n", 519 | "\n", 520 | "The above copyright notice and this permission notice shall be included in all\n", 521 | "copies or substantial portions of the Software.\n", 522 | "\n", 523 | "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 524 | "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 525 | "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n", 526 | "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 527 | "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 528 | "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", 529 | "SOFTWARE.\n" 530 | ] 531 | } 532 | ], 533 | "metadata": { 534 | "colab": { 535 | "provenance": [] 536 | }, 537 | "kernelspec": { 538 | "display_name": "Python 3", 539 | "language": "python", 540 | "name": "python3" 541 | }, 542 | "language_info": { 543 | "codemirror_mode": { 544 | "name": "ipython", 545 | "version": 3 546 | }, 547 | "file_extension": ".py", 548 | "mimetype": "text/x-python", 549 | "name": "python", 550 | "nbconvert_exporter": "python", 551 | "pygments_lexer": "ipython3", 552 | "version": "3.9.12" 553 | }, 554 | "orig_nbformat": 4 555 | }, 556 | "nbformat": 4, 557 | "nbformat_minor": 0 558 | } 559 | -------------------------------------------------------------------------------- /examples/notebooks/simpleaichat_coding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# Efficient Coding Assistant with simpleaichat\n", 7 | "\n", 8 | "_Updated using `gpt-3.5-turbo-0613`_\n", 9 | "\n", 10 | "Many coders use ChatGPT for coding help, however the web interface can be slow and contain unnecessary discussion when you want code. 
With some system prompt engineering and simpleaichat streaming, you can cut down code generation time and costs significantly.\n", 11 | "\n", 12 | "**DISCLAIMER: Your mileage may vary in terms of code quality and accuracy in practice, but this is a good, hackable starting point.**\n", 13 | "\n" 14 | ], 15 | "metadata": { 16 | "id": "-jfTDBnMbGO3" 17 | } 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": { 23 | "id": "5ZeSKyedacCE" 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "!pip install -q simpleaichat\n", 28 | "\n", 29 | "from simpleaichat import AIChat\n", 30 | "from getpass import getpass" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "source": [ 36 | "For the following cell, input your OpenAI API key when prompted. **It will not be saved to the notebook**." 37 | ], 38 | "metadata": { 39 | "id": "1kXTv7zXapit" 40 | } 41 | }, 42 | { 43 | "cell_type": "code", 44 | "source": [ 45 | "api_key = getpass(\"OpenAI Key: \")" 46 | ], 47 | "metadata": { 48 | "colab": { 49 | "base_uri": "https://localhost:8080/" 50 | }, 51 | "id": "e_3QIHtnaqdw", 52 | "outputId": "334508c7-d1da-4270-b98d-359da120ceff" 53 | }, 54 | "execution_count": 2, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "OpenAI Key: ··········\n" 61 | ] 62 | } 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": { 69 | "id": "mxBkiR9FacCF" 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "params = {\"temperature\": 0.0} # for reproducibility\n", 74 | "model = \"gpt-3.5-turbo\" # in production, may want to use model=\"gpt-4\" if have access\n", 75 | "\n", 76 | "ai = AIChat(api_key=api_key, console=False, params=params, model=model)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "source": [ 82 | "Let's start with a simple `is_palindrome()` function in Python, and track how long it takes to run. 
The output of this should be similar to what is shown in the ChatGPT webapp." 83 | ], 84 | "metadata": { 85 | "id": "-4eVxf3Jawo_" 86 | } 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "metadata": { 92 | "colab": { 93 | "base_uri": "https://localhost:8080/" 94 | }, 95 | "id": "SCkboiHsacCG", 96 | "outputId": "1608a19a-23e2-4555-c096-98d5bd08f51f" 97 | }, 98 | "outputs": [ 99 | { 100 | "output_type": "stream", 101 | "name": "stdout", 102 | "text": [ 103 | "Sure! Here's an example of an is_palindrome() function in Python:\n", 104 | "\n", 105 | "```python\n", 106 | "def is_palindrome(word):\n", 107 | " # Convert the word to lowercase and remove any spaces\n", 108 | " word = word.lower().replace(\" \", \"\")\n", 109 | " \n", 110 | " # Check if the word is equal to its reverse\n", 111 | " if word == word[::-1]:\n", 112 | " return True\n", 113 | " else:\n", 114 | " return False\n", 115 | "```\n", 116 | "\n", 117 | "You can use this function to check if a word is a palindrome. It will return True if the word is a palindrome and False otherwise.\n", 118 | "\n", 119 | "\n", 120 | "\n", 121 | "CPU times: user 23.7 ms, sys: 1.32 ms, total: 25 ms\n", 122 | "Wall time: 2.49 s\n" 123 | ] 124 | } 125 | ], 126 | "source": [ 127 | "%%time\n", 128 | "response = ai(\"Write an is_palindrome() function in Python.\")\n", 129 | "print(response)\n", 130 | "print(\"\\n\\n\") # separate time from generated text for readability" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "source": [ 136 | "That's the typical implementation. However, there's a trick to cut the processing time in half, well known by technical hiring managers who want to trip up prospective candidates.\n", 137 | "\n", 138 | "ChatGPT outputs the statistically most common implementation, but it's not necessarily the best. A second pass allows ChatGPT to refine its output." 
139 | ], 140 | "metadata": { 141 | "id": "dBOCM4LMbz3C" 142 | } 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 5, 147 | "metadata": { 148 | "colab": { 149 | "base_uri": "https://localhost:8080/" 150 | }, 151 | "id": "BwsfrcPHacCH", 152 | "outputId": "dcbfdce5-8f72-4d00-cfe9-6eef41018639" 153 | }, 154 | "outputs": [ 155 | { 156 | "output_type": "stream", 157 | "name": "stdout", 158 | "text": [ 159 | "Certainly! Here's an optimized version of the is_palindrome() function that uses two pointers to check if a word is a palindrome:\n", 160 | "\n", 161 | "```python\n", 162 | "def is_palindrome(word):\n", 163 | " # Convert the word to lowercase and remove any spaces\n", 164 | " word = word.lower().replace(\" \", \"\")\n", 165 | " \n", 166 | " # Initialize two pointers, one at the start and one at the end of the word\n", 167 | " left = 0\n", 168 | " right = len(word) - 1\n", 169 | " \n", 170 | " # Iterate until the pointers meet in the middle\n", 171 | " while left < right:\n", 172 | " # If the characters at the pointers are not equal, the word is not a palindrome\n", 173 | " if word[left] != word[right]:\n", 174 | " return False\n", 175 | " \n", 176 | " # Move the pointers towards the middle\n", 177 | " left += 1\n", 178 | " right -= 1\n", 179 | " \n", 180 | " # If the loop completes without returning False, the word is a palindrome\n", 181 | " return True\n", 182 | "```\n", 183 | "\n", 184 | "This optimized version reduces the number of comparisons by using two pointers that start at opposite ends of the word and move towards the middle. 
It will return True if the word is a palindrome and False otherwise.\n", 185 | "\n", 186 | "\n", 187 | "\n", 188 | "CPU times: user 30.7 ms, sys: 4.91 ms, total: 35.6 ms\n", 189 | "Wall time: 4.22 s\n" 190 | ] 191 | } 192 | ], 193 | "source": [ 194 | "%%time\n", 195 | "response = ai(\"Make it more efficient.\")\n", 196 | "print(response)\n", 197 | "print(\"\\n\\n\")" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 6, 203 | "metadata": { 204 | "colab": { 205 | "base_uri": "https://localhost:8080/" 206 | }, 207 | "id": "ZWRKL30uacCH", 208 | "outputId": "ac47340b-9c7a-4517-a36a-d4fb92765311" 209 | }, 210 | "outputs": [ 211 | { 212 | "output_type": "execute_result", 213 | "data": { 214 | "text/plain": [ 215 | "511" 216 | ] 217 | }, 218 | "metadata": {}, 219 | "execution_count": 6 220 | } 221 | ], 222 | "source": [ 223 | "ai.total_length" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "source": [ 229 | "In all, it took ~6 seconds and utilized 511 tokens. But there's a lot of unnecessary natter in the output:\n", 230 | "\n", 231 | "- The conversational preamble before the code\n", 232 | "- Docstrings and code comments\n", 233 | "- A long explanation of the code which may be redundant to the above\n", 234 | "\n", 235 | "All this natter adds latency and cost.\n", 236 | "\n", 237 | "The easiest technique to guide AI text generation is to use **prompt engineering**, specifically to give it a new system prompt to say precisely what you want. 
As of June 27th 2023, the default ChatGPT API responds very well to commands.\n", 238 | "\n", 239 | "Now, for the new `system` prompt:" 240 | ], 241 | "metadata": { 242 | "id": "okqqoe7lcqUH" 243 | } 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 7, 248 | "metadata": { 249 | "id": "iA9AASKfacCH" 250 | }, 251 | "outputs": [], 252 | "source": [ 253 | "system_optimized = \"\"\"Write a Python function based on the user input.\n", 254 | "\n", 255 | "You must obey ALL the following rules:\n", 256 | "- Only respond with the Python function.\n", 257 | "- Never put in-line comments or docstrings in your code.\"\"\"\n", 258 | "\n", 259 | "ai_2 = AIChat(api_key=api_key, system=system_optimized, model=model, params=params)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 8, 265 | "metadata": { 266 | "colab": { 267 | "base_uri": "https://localhost:8080/" 268 | }, 269 | "id": "MRYfniAXacCI", 270 | "outputId": "fecb6928-a1e8-48f6-c336-5055cbca7a38" 271 | }, 272 | "outputs": [ 273 | { 274 | "output_type": "stream", 275 | "name": "stdout", 276 | "text": [ 277 | "def is_palindrome(word):\n", 278 | " return word == word[::-1]\n", 279 | "\n", 280 | "\n", 281 | "\n", 282 | "CPU times: user 11.6 ms, sys: 4.01 ms, total: 15.6 ms\n", 283 | "Wall time: 1.04 s\n" 284 | ] 285 | } 286 | ], 287 | "source": [ 288 | "%%time\n", 289 | "response = ai_2(\"is_palindrome\")\n", 290 | "print(response)\n", 291 | "print(\"\\n\\n\")" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 9, 297 | "metadata": { 298 | "colab": { 299 | "base_uri": "https://localhost:8080/" 300 | }, 301 | "id": "lGJr-GfIacCI", 302 | "outputId": "68a555bd-d3e0-4b39-cca6-a11ce6ad1455" 303 | }, 304 | "outputs": [ 305 | { 306 | "output_type": "stream", 307 | "name": "stdout", 308 | "text": [ 309 | "def is_palindrome(word):\n", 310 | " length = len(word)\n", 311 | " for i in range(length // 2):\n", 312 | " if word[i] != word[length - i - 1]:\n", 313 | " 
return False\n", 314 | " return True\n", 315 | "\n", 316 | "\n", 317 | "\n", 318 | "CPU times: user 19.7 ms, sys: 1.73 ms, total: 21.5 ms\n", 319 | "Wall time: 2.64 s\n" 320 | ] 321 | } 322 | ], 323 | "source": [ 324 | "%%time\n", 325 | "response = ai_2(\"Make it more efficient.\")\n", 326 | "print(response)\n", 327 | "print(\"\\n\\n\")" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 10, 333 | "metadata": { 334 | "colab": { 335 | "base_uri": "https://localhost:8080/" 336 | }, 337 | "id": "8eMDXGysacCJ", 338 | "outputId": "0d6272c1-cb2a-460b-fe37-09c896b119e7" 339 | }, 340 | "outputs": [ 341 | { 342 | "output_type": "execute_result", 343 | "data": { 344 | "text/plain": [ 345 | "190" 346 | ] 347 | }, 348 | "metadata": {}, 349 | "execution_count": 10 350 | } 351 | ], 352 | "source": [ 353 | "ai_2.total_length" 354 | ] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "source": [ 359 | "~3 seconds total with 190 tokens used: that's 2x faster at 1/3 the cost!" 360 | ], 361 | "metadata": { 362 | "id": "ua28VqKteKkr" 363 | } 364 | }, 365 | { 366 | "cell_type": "markdown", 367 | "metadata": { 368 | "id": "SaE3ftkMacCJ" 369 | }, 370 | "source": [ 371 | "## Create a Function\n", 372 | "\n", 373 | "Now we can create a function to automate the two calls we did above for any arbitrary input.\n", 374 | "\n", 375 | "For each call, we'll create an independent temporary session within simpleaichat and then clean it up. We'll also use a regex to strip unneeded backticks." 
376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 11, 381 | "metadata": { 382 | "id": "mN_-3fwuacCL" 383 | }, 384 | "outputs": [], 385 | "source": [ 386 | "from uuid import uuid4\n", 387 | "import re\n", 388 | "\n", 389 | "ai_func = AIChat(api_key=api_key, console=False)\n", 390 | "def gen_code(query):\n", 391 | " id = uuid4()\n", 392 | " ai_func.new_session(api_key=api_key, id=id, system=system_optimized, params=params, model=model)\n", 393 | " _ = ai_func(query, id=id)\n", 394 | " response_optimized = ai_func(\"Make it more efficient.\", id=id)\n", 395 | "\n", 396 | " ai_func.delete_session(id=id)\n", 397 | " return response_optimized" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": 12, 403 | "metadata": { 404 | "colab": { 405 | "base_uri": "https://localhost:8080/" 406 | }, 407 | "id": "kIkdpBwOacCL", 408 | "outputId": "f00c393b-87c0-4bc6-9c35-d323c6de6812" 409 | }, 410 | "outputs": [ 411 | { 412 | "output_type": "stream", 413 | "name": "stdout", 414 | "text": [ 415 | "def is_palindrome(word):\n", 416 | " length = len(word)\n", 417 | " for i in range(length // 2):\n", 418 | " if word[i] != word[length - i - 1]:\n", 419 | " return False\n", 420 | " return True\n", 421 | "\n", 422 | "\n", 423 | "\n", 424 | "CPU times: user 27.8 ms, sys: 1.94 ms, total: 29.8 ms\n", 425 | "Wall time: 1.96 s\n" 426 | ] 427 | } 428 | ], 429 | "source": [ 430 | "%%time\n", 431 | "code = gen_code(\"is_palindrome\")\n", 432 | "print(code)\n", 433 | "print(\"\\n\\n\")" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 13, 439 | "metadata": { 440 | "colab": { 441 | "base_uri": "https://localhost:8080/" 442 | }, 443 | "id": "S1LjZjFBacCM", 444 | "outputId": "53d96baf-1f6b-46fb-bba8-5eda772d8c87" 445 | }, 446 | "outputs": [ 447 | { 448 | "output_type": "stream", 449 | "name": "stdout", 450 | "text": [ 451 | "def reverse_string(string):\n", 452 | " return ''.join(reversed(string))\n", 453 | "\n", 454 | "\n", 
455 | "\n", 456 | "CPU times: user 16.4 ms, sys: 644 µs, total: 17 ms\n", 457 | "Wall time: 1.42 s\n" 458 | ] 459 | } 460 | ], 461 | "source": [ 462 | "%%time\n", 463 | "code = gen_code(\"reverse string\")\n", 464 | "print(code)\n", 465 | "print(\"\\n\\n\")" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 14, 471 | "metadata": { 472 | "colab": { 473 | "base_uri": "https://localhost:8080/" 474 | }, 475 | "id": "VdSNUiapacCM", 476 | "outputId": "207634e0-ef48-4428-a3c0-15a80ad845f8" 477 | }, 478 | "outputs": [ 479 | { 480 | "output_type": "stream", 481 | "name": "stdout", 482 | "text": [ 483 | "def pretty_print_dict(dictionary):\n", 484 | " import json\n", 485 | " print(json.dumps(dictionary, indent=4, separators=(',', ': ')))\n", 486 | "\n", 487 | "\n", 488 | "\n", 489 | "CPU times: user 17.2 ms, sys: 1.87 ms, total: 19.1 ms\n", 490 | "Wall time: 1.31 s\n" 491 | ] 492 | } 493 | ], 494 | "source": [ 495 | "%%time\n", 496 | "code = gen_code(\"pretty print dict\")\n", 497 | "print(code)\n", 498 | "print(\"\\n\\n\")" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": 15, 504 | "metadata": { 505 | "colab": { 506 | "base_uri": "https://localhost:8080/" 507 | }, 508 | "id": "7ppyZrFTacCM", 509 | "outputId": "a4e04c60-4681-4d7f-b846-028a9b1271f6" 510 | }, 511 | "outputs": [ 512 | { 513 | "output_type": "stream", 514 | "name": "stdout", 515 | "text": [ 516 | "import cv2\n", 517 | "\n", 518 | "def load_and_flip_image(image_path):\n", 519 | " image = cv2.imread(image_path)\n", 520 | " flipped_image = cv2.flip(image, 1)\n", 521 | " cv2.imwrite(\"flipped_image.jpg\", flipped_image)\n", 522 | " return \"flipped_image.jpg\"\n", 523 | "\n", 524 | "\n", 525 | "\n", 526 | "CPU times: user 21.1 ms, sys: 3.86 ms, total: 24.9 ms\n", 527 | "Wall time: 2.18 s\n" 528 | ] 529 | } 530 | ], 531 | "source": [ 532 | "%%time\n", 533 | "code = gen_code(\"load and flip image horizontally\")\n", 534 | "print(code)\n", 535 | "print(\"\\n\\n\")" 
536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": 16, 541 | "metadata": { 542 | "colab": { 543 | "base_uri": "https://localhost:8080/" 544 | }, 545 | "id": "sNQhAKFqacCM", 546 | "outputId": "fcfc568f-0777-4df4-a8f7-20fa0e4e34d7" 547 | }, 548 | "outputs": [ 549 | { 550 | "output_type": "stream", 551 | "name": "stdout", 552 | "text": [ 553 | "import hashlib\n", 554 | "from multiprocessing import Pool, cpu_count\n", 555 | "\n", 556 | "def multiprocess_hash(data):\n", 557 | " def hash_string(string):\n", 558 | " return hashlib.sha256(string.encode()).hexdigest()\n", 559 | "\n", 560 | " num_processes = cpu_count()\n", 561 | " pool = Pool(processes=num_processes)\n", 562 | " hashed_data = pool.map(hash_string, data)\n", 563 | " pool.close()\n", 564 | " pool.join()\n", 565 | "\n", 566 | " return hashed_data\n", 567 | "\n", 568 | "\n", 569 | "\n", 570 | "CPU times: user 28.8 ms, sys: 3.82 ms, total: 32.6 ms\n", 571 | "Wall time: 3.31 s\n" 572 | ] 573 | } 574 | ], 575 | "source": [ 576 | "%%time\n", 577 | "code = gen_code(\"multiprocess hash\")\n", 578 | "print(code)\n", 579 | "print(\"\\n\\n\")" 580 | ] 581 | }, 582 | { 583 | "cell_type": "markdown", 584 | "source": [ 585 | "## Generating Optimized Code in a Single API Call w/ Structured Output Data\n", 586 | "\n", 587 | "Sometimes you may not want to make two calls to OpenAI. One hack you can do is to define an expected structured output to tell it to sequentially generate the normal code output, then the optimized output.\n", 588 | "\n", 589 | "This structure is essentially a different form of prompt engineering, but you can combine it with a system prompt if needed.\n", 590 | "\n", 591 | "This will also further increase response speed, but may not necessarily result in fewer tokens used." 
592 | ], 593 | "metadata": { 594 | "id": "tbQspzba2_GO" 595 | } 596 | }, 597 | { 598 | "cell_type": "code", 599 | "source": [ 600 | "from pydantic import BaseModel, Field\n", 601 | "import orjson\n", 602 | "\n", 603 | "class write_python_function(BaseModel):\n", 604 | " \"\"\"Writes a Python function based on the user input.\"\"\"\n", 605 | " code: str = Field(description=\"Python code\")\n", 606 | " efficient_code: str = Field(description=\"More efficient Python code than previously written\")" 607 | ], 608 | "metadata": { 609 | "id": "lFCLUODV28FZ" 610 | }, 611 | "execution_count": 23, 612 | "outputs": [] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "source": [ 617 | "ai_struct = AIChat(api_key=api_key, console=False, model=model, params=params, save_messages=False)" 618 | ], 619 | "metadata": { 620 | "id": "6rTqftko4N2L" 621 | }, 622 | "execution_count": 24, 623 | "outputs": [] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "source": [ 628 | "%%time\n", 629 | "response_structured = ai_struct(\"is_palindrome\", output_schema=write_python_function)\n", 630 | "\n", 631 | "# orjson.dumps preserves field order from the ChatGPT API\n", 632 | "print(orjson.dumps(response_structured, option=orjson.OPT_INDENT_2).decode())\n", 633 | "print(\"\\n\\n\")" 634 | ], 635 | "metadata": { 636 | "colab": { 637 | "base_uri": "https://localhost:8080/" 638 | }, 639 | "id": "4ielJLnz4e5U", 640 | "outputId": "d9b3f315-1840-4781-b922-f0cd889a4086" 641 | }, 642 | "execution_count": 25, 643 | "outputs": [ 644 | { 645 | "output_type": "stream", 646 | "name": "stdout", 647 | "text": [ 648 | "{\n", 649 | " \"code\": \"def is_palindrome(s):\\n return s == s[::-1]\",\n", 650 | " \"efficient_code\": \"def is_palindrome(s):\\n n = len(s)\\n for i in range(n // 2):\\n if s[i] != s[n - i - 1]:\\n return False\\n return True\"\n", 651 | "}\n", 652 | "\n", 653 | "\n", 654 | "\n", 655 | "CPU times: user 131 ms, sys: 1.39 ms, total: 133 ms\n", 656 | "Wall time: 1.67 s\n" 657 | ] 658 | } 659 | 
] 660 | }, 661 | { 662 | "cell_type": "markdown", 663 | "source": [ 664 | "As evident, the output is a `dict` so you'd just return the `efficient_code` field." 665 | ], 666 | "metadata": { 667 | "id": "M8sshigf53Nq" 668 | } 669 | }, 670 | { 671 | "cell_type": "code", 672 | "source": [ 673 | "print(response_structured[\"efficient_code\"])" 674 | ], 675 | "metadata": { 676 | "colab": { 677 | "base_uri": "https://localhost:8080/" 678 | }, 679 | "id": "eIKI_8aq51RH", 680 | "outputId": "ab239077-3114-4b52-a127-0c5a93308627" 681 | }, 682 | "execution_count": 26, 683 | "outputs": [ 684 | { 685 | "output_type": "stream", 686 | "name": "stdout", 687 | "text": [ 688 | "def is_palindrome(s):\n", 689 | " n = len(s)\n", 690 | " for i in range(n // 2):\n", 691 | " if s[i] != s[n - i - 1]:\n", 692 | " return False\n", 693 | " return True\n" 694 | ] 695 | } 696 | ] 697 | }, 698 | { 699 | "cell_type": "code", 700 | "source": [ 701 | "ai_struct.total_length" 702 | ], 703 | "metadata": { 704 | "colab": { 705 | "base_uri": "https://localhost:8080/" 706 | }, 707 | "id": "KThchw1S7MwS", 708 | "outputId": "059aedeb-13c9-4802-a0cf-239ff7384722" 709 | }, 710 | "execution_count": 27, 711 | "outputs": [ 712 | { 713 | "output_type": "execute_result", 714 | "data": { 715 | "text/plain": [ 716 | "161" 717 | ] 718 | }, 719 | "metadata": {}, 720 | "execution_count": 27 721 | } 722 | ] 723 | }, 724 | { 725 | "cell_type": "markdown", 726 | "source": [ 727 | "Token-wise it's about the same, but there's a significant speedup in generation for short queries such as these." 
728 | ], 729 | "metadata": { 730 | "id": "iQcNrZcK7bn1" 731 | } 732 | }, 733 | { 734 | "cell_type": "markdown", 735 | "source": [ 736 | "Trying the other examples:" 737 | ], 738 | "metadata": { 739 | "id": "v2tfYs3a6oBy" 740 | } 741 | }, 742 | { 743 | "cell_type": "code", 744 | "source": [ 745 | "%%time\n", 746 | "response_structured = ai_struct(\"reverse string\", output_schema=write_python_function)\n", 747 | "print(response_structured[\"efficient_code\"])\n", 748 | "print(\"\\n\\n\")" 749 | ], 750 | "metadata": { 751 | "colab": { 752 | "base_uri": "https://localhost:8080/" 753 | }, 754 | "id": "ZrJtaW5m7Cdy", 755 | "outputId": "a74621e8-cebc-495b-adaa-bec926de36dd" 756 | }, 757 | "execution_count": 28, 758 | "outputs": [ 759 | { 760 | "output_type": "stream", 761 | "name": "stdout", 762 | "text": [ 763 | "def reverse_string(s):\n", 764 | " return ''.join(reversed(s))\n", 765 | "\n", 766 | "\n", 767 | "\n", 768 | "CPU times: user 24.6 ms, sys: 250 µs, total: 24.9 ms\n", 769 | "Wall time: 1.37 s\n" 770 | ] 771 | } 772 | ] 773 | }, 774 | { 775 | "cell_type": "code", 776 | "source": [ 777 | "%%time\n", 778 | "response_structured = ai_struct(\"load and flip image horizontally\", output_schema=write_python_function)\n", 779 | "print(response_structured[\"efficient_code\"])\n", 780 | "print(\"\\n\\n\")" 781 | ], 782 | "metadata": { 783 | "colab": { 784 | "base_uri": "https://localhost:8080/" 785 | }, 786 | "id": "JJxRji6h8Gww", 787 | "outputId": "45dbf068-b9b2-4d04-fccb-f988eda5538f" 788 | }, 789 | "execution_count": 29, 790 | "outputs": [ 791 | { 792 | "output_type": "stream", 793 | "name": "stdout", 794 | "text": [ 795 | "from PIL import Image\n", 796 | "\n", 797 | "def load_and_flip_image_horizontally(image_path):\n", 798 | " return Image.open(image_path).transpose(Image.FLIP_LEFT_RIGHT)\n", 799 | "\n", 800 | "\n", 801 | "\n", 802 | "CPU times: user 15.4 ms, sys: 2.15 ms, total: 17.6 ms\n", 803 | "Wall time: 1.79 s\n" 804 | ] 805 | } 806 | ] 807 | }, 808 | { 809 | 
"cell_type": "markdown", 810 | "source": [ 811 | "## MIT License\n", 812 | "\n", 813 | "Copyright (c) 2023 Max Woolf\n", 814 | "\n", 815 | "Permission is hereby granted, free of charge, to any person obtaining a copy\n", 816 | "of this software and associated documentation files (the \"Software\"), to deal\n", 817 | "in the Software without restriction, including without limitation the rights\n", 818 | "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 819 | "copies of the Software, and to permit persons to whom the Software is\n", 820 | "furnished to do so, subject to the following conditions:\n", 821 | "\n", 822 | "The above copyright notice and this permission notice shall be included in all\n", 823 | "copies or substantial portions of the Software.\n", 824 | "\n", 825 | "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 826 | "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 827 | "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n", 828 | "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 829 | "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 830 | "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", 831 | "SOFTWARE.\n" 832 | ], 833 | "metadata": { 834 | "id": "E-QaAsRXbhAj" 835 | } 836 | } 837 | ], 838 | "metadata": { 839 | "kernelspec": { 840 | "display_name": "Python 3", 841 | "language": "python", 842 | "name": "python3" 843 | }, 844 | "language_info": { 845 | "codemirror_mode": { 846 | "name": "ipython", 847 | "version": 3 848 | }, 849 | "file_extension": ".py", 850 | "mimetype": "text/x-python", 851 | "name": "python", 852 | "nbconvert_exporter": "python", 853 | "pygments_lexer": "ipython3", 854 | "version": "3.9.12" 855 | }, 856 | "orig_nbformat": 4, 857 | "colab": { 858 | "provenance": [] 859 | } 860 | }, 861 | "nbformat": 4, 862 | "nbformat_minor": 0 863 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="simpleaichat", 5 | packages=["simpleaichat"], # this must be the same as the name above 6 | version="0.2.2", 7 | description="A Python package for easily interfacing with chat apps, with robust features and minimal code complexity.", 8 | long_description=open("README.md", "r", encoding="utf-8").read(), 9 | long_description_content_type="text/markdown", 10 | author="Max Woolf", 11 | author_email="max@minimaxir.com", 12 | url="https://github.com/minimaxir/simpleaichat", 13 | keywords=["chatgpt", "openai", "text generation", "ai"], 14 | classifiers=[], 15 | license="MIT", 16 | entry_points={ 17 | "console_scripts": ["simpleaichat=simpleaichat.cli:interactive_chat"] 18 | }, 19 | python_requires=">=3.8", 20 | install_requires=[ 21 | "pydantic>=2.0", 22 | "fire>=0.3.0", 23 | "httpx>=0.24.1", 24 | 
"python-dotenv>=1.0.0", 25 | "orjson>=3.9.0", 26 | "rich>=13.4.1", 27 | "python-dateutil>=2.8.2", 28 | ], 29 | ) 30 | -------------------------------------------------------------------------------- /simpleaichat/__init__.py: -------------------------------------------------------------------------------- 1 | from .simpleaichat import AIChat, AsyncAIChat 2 | -------------------------------------------------------------------------------- /simpleaichat/chatgpt.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Set, Union 2 | 3 | import orjson 4 | from httpx import AsyncClient, Client 5 | from pydantic import HttpUrl 6 | 7 | from .models import ChatMessage, ChatSession 8 | from .utils import remove_a_key 9 | 10 | tool_prompt = """From the list of tools below: 11 | - Reply ONLY with the number of the tool appropriate in response to the user's last message. 12 | - If no tool is appropriate, ONLY reply with \"0\". 13 | 14 | {tools}""" 15 | 16 | 17 | class ChatGPTSession(ChatSession): 18 | api_url: HttpUrl = "https://api.openai.com/v1/chat/completions" 19 | input_fields: Set[str] = {"role", "content", "name"} 20 | system: str = "You are a helpful assistant." 
21 | params: Dict[str, Any] = {"temperature": 0.7} 22 | 23 | def prepare_request( 24 | self, 25 | prompt: str, 26 | system: str = None, 27 | params: Dict[str, Any] = None, 28 | stream: bool = False, 29 | input_schema: Any = None, 30 | output_schema: Any = None, 31 | is_function_calling_required: bool = True, 32 | ): 33 | headers = { 34 | "Content-Type": "application/json", 35 | "Authorization": f"Bearer {self.auth['api_key'].get_secret_value()}", 36 | } 37 | 38 | system_message = ChatMessage(role="system", content=system or self.system) 39 | if not input_schema: 40 | user_message = ChatMessage(role="user", content=prompt) 41 | else: 42 | assert isinstance( 43 | prompt, input_schema 44 | ), f"prompt must be an instance of {input_schema.__name__}" 45 | user_message = ChatMessage( 46 | role="function", 47 | content=prompt.model_dump_json(), 48 | name=input_schema.__name__, 49 | ) 50 | 51 | gen_params = params or self.params 52 | data = { 53 | "model": self.model, 54 | "messages": self.format_input_messages(system_message, user_message), 55 | "stream": stream, 56 | **gen_params, 57 | } 58 | 59 | # Add function calling parameters if a schema is provided 60 | if input_schema or output_schema: 61 | functions = [] 62 | if input_schema: 63 | input_function = self.schema_to_function(input_schema) 64 | functions.append(input_function) 65 | if output_schema: 66 | output_function = self.schema_to_function(output_schema) 67 | functions.append( 68 | output_function 69 | ) if output_function not in functions else None 70 | if is_function_calling_required: 71 | data["function_call"] = {"name": output_schema.__name__} 72 | data["functions"] = functions 73 | 74 | return headers, data, user_message 75 | 76 | def schema_to_function(self, schema: Any): 77 | assert schema.__doc__, f"{schema.__name__} is missing a docstring." 78 | assert ( 79 | "title" not in schema.model_fields.keys() 80 | ), "`title` is a reserved keyword and cannot be used as a field name." 
81 | schema_dict = schema.model_json_schema() 82 | remove_a_key(schema_dict, "title") 83 | 84 | return { 85 | "name": schema.__name__, 86 | "description": schema.__doc__, 87 | "parameters": schema_dict, 88 | } 89 | 90 | def gen( 91 | self, 92 | prompt: str, 93 | client: Union[Client, AsyncClient], 94 | system: str = None, 95 | save_messages: bool = None, 96 | params: Dict[str, Any] = None, 97 | input_schema: Any = None, 98 | output_schema: Any = None, 99 | ): 100 | headers, data, user_message = self.prepare_request( 101 | prompt, system, params, False, input_schema, output_schema 102 | ) 103 | 104 | r = client.post( 105 | str(self.api_url), 106 | json=data, 107 | headers=headers, 108 | timeout=None, 109 | ) 110 | r = r.json() 111 | 112 | try: 113 | if not output_schema: 114 | content = r["choices"][0]["message"]["content"] 115 | assistant_message = ChatMessage( 116 | role=r["choices"][0]["message"]["role"], 117 | content=content, 118 | finish_reason=r["choices"][0]["finish_reason"], 119 | prompt_length=r["usage"]["prompt_tokens"], 120 | completion_length=r["usage"]["completion_tokens"], 121 | total_length=r["usage"]["total_tokens"], 122 | ) 123 | self.add_messages(user_message, assistant_message, save_messages) 124 | else: 125 | content = r["choices"][0]["message"]["function_call"]["arguments"] 126 | content = orjson.loads(content) 127 | 128 | self.total_prompt_length += r["usage"]["prompt_tokens"] 129 | self.total_completion_length += r["usage"]["completion_tokens"] 130 | self.total_length += r["usage"]["total_tokens"] 131 | except KeyError: 132 | raise KeyError(f"No AI generation: {r}") 133 | 134 | return content 135 | 136 | def stream( 137 | self, 138 | prompt: str, 139 | client: Union[Client, AsyncClient], 140 | system: str = None, 141 | save_messages: bool = None, 142 | params: Dict[str, Any] = None, 143 | input_schema: Any = None, 144 | ): 145 | headers, data, user_message = self.prepare_request( 146 | prompt, system, params, True, input_schema 147 | ) 148 | 
    def gen_with_tools(
        self,
        prompt: str,
        tools: List[Any],
        client: Union[Client, AsyncClient],
        system: str = None,
        save_messages: bool = None,
        params: Dict[str, Any] = None,
    ) -> Dict[str, Any]:
        """Answer `prompt` with the help of one of `tools` (plain callables).

        Two model calls are made:
          1. A selection call whose system prompt lists each tool's docstring;
             the model is constrained to answer with a single digit (0 = none).
          2. A generation call whose prompt is augmented with the selected
             tool's output ("context").

        Returns a dict containing at least "response" and "tool"; when a tool
        ran, whatever dict the tool returned is merged in as well.
        """
        # call 1: select tool and populate context
        tools_list = "\n".join(f"{i+1}: {f.__doc__}" for i, f in enumerate(tools))
        tool_prompt_format = tool_prompt.format(tools=tools_list)

        logit_bias_weight = 100
        # NOTE(review): token ids 15..(15+len(tools)) presumably correspond to
        # the digit tokens "0".."9" in the model's tokenizer, so this bias
        # forces a one-digit answer — confirm against the tokenizer in use.
        logit_bias = {str(k): logit_bias_weight for k in range(15, 15 + len(tools) + 1)}

        tool_idx = int(
            self.gen(
                prompt,
                client=client,
                system=tool_prompt_format,
                save_messages=False,  # the selection exchange is not history
                params={
                    "temperature": 0.0,  # deterministic selection
                    "max_tokens": 1,  # a single digit suffices
                    "logit_bias": logit_bias,
                },
            )
        )

        # if no tool is selected, do a standard generation instead.
        if tool_idx == 0:
            return {
                "response": self.gen(
                    prompt,
                    client=client,
                    system=system,
                    save_messages=save_messages,
                    params=params,
                ),
                "tool": None,
            }
        selected_tool = tools[tool_idx - 1]
        # Tools may return either a plain string or a dict; normalize both to
        # a dict. NOTE(review): a dict return is assumed to already contain a
        # "context" key — a tool returning a dict without one will raise
        # KeyError below; confirm this is the intended tool contract.
        context_dict = selected_tool(prompt)
        if isinstance(context_dict, str):
            context_dict = {"context": context_dict}

        context_dict["tool"] = selected_tool.__name__

        # call 2: generate from the context
        new_system = f"{system or self.system}\n\nYou MUST use information from the context in your response."
        new_prompt = f"Context: {context_dict['context']}\n\nUser: {prompt}"

        context_dict["response"] = self.gen(
            new_prompt,
            client=client,
            system=new_system,
            save_messages=False,
            params=params,
        )

        # manually append the nonmodified user message + normal AI response
        user_message = ChatMessage(role="user", content=prompt)
        assistant_message = ChatMessage(
            role="assistant", content=context_dict["response"]
        )
        self.add_messages(user_message, assistant_message, save_messages)

        return context_dict
finish_reason=r["choices"][0]["finish_reason"], 276 | prompt_length=r["usage"]["prompt_tokens"], 277 | completion_length=r["usage"]["completion_tokens"], 278 | total_length=r["usage"]["total_tokens"], 279 | ) 280 | self.add_messages(user_message, assistant_message, save_messages) 281 | else: 282 | content = r["choices"][0]["message"]["function_call"]["arguments"] 283 | content = orjson.loads(content) 284 | 285 | self.total_prompt_length += r["usage"]["prompt_tokens"] 286 | self.total_completion_length += r["usage"]["completion_tokens"] 287 | self.total_length += r["usage"]["total_tokens"] 288 | except KeyError: 289 | raise KeyError(f"No AI generation: {r}") 290 | 291 | return content 292 | 293 | async def stream_async( 294 | self, 295 | prompt: str, 296 | client: Union[Client, AsyncClient], 297 | system: str = None, 298 | save_messages: bool = None, 299 | params: Dict[str, Any] = None, 300 | input_schema: Any = None, 301 | ): 302 | headers, data, user_message = self.prepare_request( 303 | prompt, system, params, True, input_schema 304 | ) 305 | 306 | async with client.stream( 307 | "POST", 308 | str(self.api_url), 309 | json=data, 310 | headers=headers, 311 | timeout=None, 312 | ) as r: 313 | content = [] 314 | async for chunk in r.aiter_lines(): 315 | if len(chunk) > 0: 316 | chunk = chunk[6:] # SSE JSON chunks are prepended with "data: " 317 | if chunk != "[DONE]": 318 | chunk_dict = orjson.loads(chunk) 319 | delta = chunk_dict["choices"][0]["delta"].get("content") 320 | if delta: 321 | content.append(delta) 322 | yield {"delta": delta, "response": "".join(content)} 323 | 324 | # streaming does not currently return token counts 325 | assistant_message = ChatMessage( 326 | role="assistant", 327 | content="".join(content), 328 | ) 329 | 330 | self.add_messages(user_message, assistant_message, save_messages) 331 | 332 | async def gen_with_tools_async( 333 | self, 334 | prompt: str, 335 | tools: List[Any], 336 | client: Union[Client, AsyncClient], 337 | system: str = 
None, 338 | save_messages: bool = None, 339 | params: Dict[str, Any] = None, 340 | ) -> Dict[str, Any]: 341 | # call 1: select tool and populate context 342 | tools_list = "\n".join(f"{i+1}: {f.__doc__}" for i, f in enumerate(tools)) 343 | tool_prompt_format = tool_prompt.format(tools=tools_list) 344 | 345 | logit_bias_weight = 100 346 | logit_bias = {str(k): logit_bias_weight for k in range(15, 15 + len(tools) + 1)} 347 | 348 | tool_idx = int( 349 | await self.gen_async( 350 | prompt, 351 | client=client, 352 | system=tool_prompt_format, 353 | save_messages=False, 354 | params={ 355 | "temperature": 0.0, 356 | "max_tokens": 1, 357 | "logit_bias": logit_bias, 358 | }, 359 | ) 360 | ) 361 | 362 | # if no tool is selected, do a standard generation instead. 363 | if tool_idx == 0: 364 | return { 365 | "response": await self.gen_async( 366 | prompt, 367 | client=client, 368 | system=system, 369 | save_messages=save_messages, 370 | params=params, 371 | ), 372 | "tool": None, 373 | } 374 | selected_tool = tools[tool_idx - 1] 375 | context_dict = await selected_tool(prompt) 376 | if isinstance(context_dict, str): 377 | context_dict = {"context": context_dict} 378 | 379 | context_dict["tool"] = selected_tool.__name__ 380 | 381 | # call 2: generate from the context 382 | new_system = f"{system or self.system}\n\nYou MUST use information from the context in your response." 
383 | new_prompt = f"Context: {context_dict['context']}\n\nUser: {prompt}" 384 | 385 | context_dict["response"] = await self.gen_async( 386 | new_prompt, 387 | client=client, 388 | system=new_system, 389 | save_messages=False, 390 | params=params, 391 | ) 392 | 393 | # manually append the nonmodified user message + normal AI response 394 | user_message = ChatMessage(role="user", content=prompt) 395 | assistant_message = ChatMessage( 396 | role="assistant", content=context_dict["response"] 397 | ) 398 | self.add_messages(user_message, assistant_message, save_messages) 399 | 400 | return context_dict 401 | -------------------------------------------------------------------------------- /simpleaichat/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from getpass import getpass 4 | 5 | import fire 6 | from dotenv import load_dotenv 7 | 8 | from .simpleaichat import AIChat 9 | 10 | load_dotenv() 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("character", help="Specify the character", default=None, nargs="?") 14 | parser.add_argument( 15 | "character_command", help="Specify the character command", default=None, nargs="?" 16 | ) 17 | parser.add_argument("--prime", action="store_true", help="Enable priming") 18 | 19 | ARGS = parser.parse_args() 20 | 21 | 22 | def interactive_chat(): 23 | gpt_api_key = os.getenv("OPENAI_API_KEY") 24 | if not gpt_api_key: 25 | gpt_api_key = getpass("Input your OpenAI key here: ") 26 | assert gpt_api_key, "An API key was not defined." 
27 | _ = AIChat(ARGS.character, ARGS.character_command, ARGS.prime) 28 | 29 | 30 | if __name__ == "__main__": 31 | fire.Fire(interactive_chat) 32 | -------------------------------------------------------------------------------- /simpleaichat/models.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from typing import Any, Dict, List, Optional, Set, Union 3 | from uuid import UUID, uuid4 4 | 5 | import orjson 6 | from pydantic import BaseModel, Field, HttpUrl, SecretStr 7 | 8 | 9 | def orjson_dumps(v, *, default, **kwargs): 10 | # orjson.dumps returns bytes, to match standard json.dumps we need to decode 11 | return orjson.dumps(v, default=default, **kwargs).decode() 12 | 13 | 14 | def now_tz(): 15 | # Need datetime w/ timezone for cleanliness 16 | # https://stackoverflow.com/a/24666683 17 | return datetime.datetime.now(datetime.timezone.utc) 18 | 19 | 20 | class ChatMessage(BaseModel): 21 | role: str 22 | content: str 23 | name: Optional[str] = None 24 | function_call: Optional[str] = None 25 | received_at: datetime.datetime = Field(default_factory=now_tz) 26 | finish_reason: Optional[str] = None 27 | prompt_length: Optional[int] = None 28 | completion_length: Optional[int] = None 29 | total_length: Optional[int] = None 30 | 31 | def __str__(self) -> str: 32 | return str(self.model_dump(exclude_none=True)) 33 | 34 | 35 | class ChatSession(BaseModel): 36 | id: Union[str, UUID] = Field(default_factory=uuid4) 37 | created_at: datetime.datetime = Field(default_factory=now_tz) 38 | auth: Dict[str, SecretStr] 39 | api_url: HttpUrl 40 | model: str 41 | system: str 42 | params: Dict[str, Any] = {} 43 | messages: List[ChatMessage] = [] 44 | input_fields: Set[str] = {} 45 | recent_messages: Optional[int] = None 46 | save_messages: Optional[bool] = True 47 | total_prompt_length: int = 0 48 | total_completion_length: int = 0 49 | total_length: int = 0 50 | title: Optional[str] = None 51 | 52 | def __str__(self) -> str: 
53 | sess_start_str = self.created_at.strftime("%Y-%m-%d %H:%M:%S") 54 | last_message_str = self.messages[-1].received_at.strftime("%Y-%m-%d %H:%M:%S") 55 | return f"""Chat session started at {sess_start_str}: 56 | - {len(self.messages):,} Messages 57 | - Last message sent at {last_message_str}""" 58 | 59 | def format_input_messages( 60 | self, system_message: ChatMessage, user_message: ChatMessage 61 | ) -> list: 62 | recent_messages = ( 63 | self.messages[-self.recent_messages :] 64 | if self.recent_messages 65 | else self.messages 66 | ) 67 | return ( 68 | [system_message.model_dump(include=self.input_fields, exclude_none=True)] 69 | + [ 70 | m.model_dump(include=self.input_fields, exclude_none=True) 71 | for m in recent_messages 72 | ] 73 | + [user_message.model_dump(include=self.input_fields, exclude_none=True)] 74 | ) 75 | 76 | def add_messages( 77 | self, 78 | user_message: ChatMessage, 79 | assistant_message: ChatMessage, 80 | save_messages: bool = None, 81 | ) -> None: 82 | # if save_messages is explicitly defined, always use that choice 83 | # instead of the default 84 | to_save = isinstance(save_messages, bool) 85 | 86 | if to_save: 87 | if save_messages: 88 | self.messages.append(user_message) 89 | self.messages.append(assistant_message) 90 | elif self.save_messages: 91 | self.messages.append(user_message) 92 | self.messages.append(assistant_message) 93 | -------------------------------------------------------------------------------- /simpleaichat/simpleaichat.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | import os 4 | from contextlib import asynccontextmanager, contextmanager 5 | from typing import Any, Dict, List, Optional, Union 6 | from uuid import UUID, uuid4 7 | 8 | import dateutil 9 | import orjson 10 | from dotenv import load_dotenv 11 | from httpx import AsyncClient, Client 12 | from pydantic import BaseModel 13 | from rich.console import Console 14 | 15 | from 
.chatgpt import ChatGPTSession 16 | from .models import ChatMessage, ChatSession 17 | from .utils import wikipedia_search_lookup 18 | 19 | load_dotenv() 20 | 21 | 22 | class AIChat(BaseModel): 23 | client: Any 24 | default_session: Optional[ChatSession] 25 | sessions: Dict[Union[str, UUID], ChatSession] = {} 26 | 27 | def __init__( 28 | self, 29 | character: str = None, 30 | character_command: str = None, 31 | system: str = None, 32 | id: Union[str, UUID] = uuid4(), 33 | prime: bool = True, 34 | default_session: bool = True, 35 | console: bool = True, 36 | **kwargs, 37 | ): 38 | client = Client(proxies=os.getenv("https_proxy")) 39 | system_format = self.build_system(character, character_command, system) 40 | 41 | sessions = {} 42 | new_default_session = None 43 | if default_session: 44 | new_session = self.new_session( 45 | return_session=True, system=system_format, id=id, **kwargs 46 | ) 47 | 48 | new_default_session = new_session 49 | sessions = {new_session.id: new_session} 50 | 51 | super().__init__( 52 | client=client, default_session=new_default_session, sessions=sessions 53 | ) 54 | 55 | if not system and console: 56 | character = "ChatGPT" if not character else character 57 | new_default_session.title = character 58 | self.interactive_console(character=character, prime=prime) 59 | 60 | def new_session( 61 | self, 62 | return_session: bool = False, 63 | **kwargs, 64 | ) -> Optional[ChatGPTSession]: 65 | if "model" not in kwargs: # set default 66 | kwargs["model"] = "gpt-3.5-turbo" 67 | # TODO: Add support for more models (PaLM, Claude) 68 | if "gpt-" in kwargs["model"]: 69 | gpt_api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY") 70 | assert gpt_api_key, f"An API key for {kwargs['model'] } was not defined." 
71 | sess = ChatGPTSession( 72 | auth={ 73 | "api_key": gpt_api_key, 74 | }, 75 | **kwargs, 76 | ) 77 | 78 | if return_session: 79 | return sess 80 | else: 81 | self.sessions[sess.id] = sess 82 | 83 | def get_session(self, id: Union[str, UUID] = None) -> ChatSession: 84 | try: 85 | sess = self.sessions[id] if id else self.default_session 86 | except KeyError: 87 | raise KeyError("No session by that key exists.") 88 | if not sess: 89 | raise ValueError("No default session exists.") 90 | return sess 91 | 92 | def reset_session(self, id: Union[str, UUID] = None) -> None: 93 | sess = self.get_session(id) 94 | sess.messages = [] 95 | 96 | def delete_session(self, id: Union[str, UUID] = None) -> None: 97 | sess = self.get_session(id) 98 | if self.default_session: 99 | if sess.id == self.default_session.id: 100 | self.default_session = None 101 | del self.sessions[sess.id] 102 | del sess 103 | 104 | @contextmanager 105 | def session(self, **kwargs): 106 | sess = self.new_session(return_session=True, **kwargs) 107 | self.sessions[sess.id] = sess 108 | try: 109 | yield sess 110 | finally: 111 | self.delete_session(sess.id) 112 | 113 | def __call__( 114 | self, 115 | prompt: Union[str, Any], 116 | id: Union[str, UUID] = None, 117 | system: str = None, 118 | save_messages: bool = None, 119 | params: Dict[str, Any] = None, 120 | tools: List[Any] = None, 121 | input_schema: Any = None, 122 | output_schema: Any = None, 123 | ) -> str: 124 | sess = self.get_session(id) 125 | if tools: 126 | assert (input_schema is None) and ( 127 | output_schema is None 128 | ), "When using tools, input/output schema are ignored" 129 | for tool in tools: 130 | assert tool.__doc__, f"Tool {tool} does not have a docstring." 131 | assert len(tools) <= 9, "You can only have a maximum of 9 tools." 
    def stream(
        self,
        prompt: str,
        id: Union[str, UUID] = None,
        system: str = None,
        save_messages: bool = None,
        params: Dict[str, Any] = None,
        input_schema: Any = None,
    ) -> str:
        """Stream a generation from the selected session.

        NOTE(review): despite the `-> str` annotation, this returns the
        generator produced by the session's stream(), which yields dicts of
        the form {"delta": ..., "response": ...}; the annotation is kept
        as-is here — confirm and correct it with a proper generator type.
        """
        # Raises KeyError/ValueError if the requested session is missing.
        sess = self.get_session(id)
        return sess.stream(
            prompt,
            client=self.client,
            system=system,
            save_messages=save_messages,
            params=params,
            input_schema=input_schema,
        )
180 | """ 181 | prompt = character_prompt.format(wikipedia_search_lookup(character)).strip() 182 | if character_command: 183 | character_system = """ 184 | - {0} 185 | """ 186 | prompt = ( 187 | prompt + "\n" + character_system.format(character_command).strip() 188 | ) 189 | return prompt 190 | elif system: 191 | return system 192 | else: 193 | return default 194 | 195 | def interactive_console(self, character: str = None, prime: bool = True) -> None: 196 | console = Console(highlight=False, force_jupyter=False) 197 | sess = self.default_session 198 | ai_text_color = "bright_magenta" 199 | 200 | # prime with a unique starting response to the user 201 | if prime: 202 | console.print(f"[b]{character}[/b]: ", end="", style=ai_text_color) 203 | for chunk in sess.stream("Hello!", self.client): 204 | console.print(chunk["delta"], end="", style=ai_text_color) 205 | 206 | while True: 207 | console.print() 208 | try: 209 | user_input = console.input("[b]You:[/b] ").strip() 210 | if not user_input: 211 | break 212 | 213 | console.print(f"[b]{character}[/b]: ", end="", style=ai_text_color) 214 | for chunk in sess.stream(user_input, self.client): 215 | console.print(chunk["delta"], end="", style=ai_text_color) 216 | except KeyboardInterrupt: 217 | break 218 | 219 | def __str__(self) -> str: 220 | if self.default_session: 221 | return self.default_session.model_dump_json( 222 | exclude={"api_key", "api_url"}, 223 | exclude_none=True, 224 | ) 225 | 226 | def __repr__(self) -> str: 227 | return "" 228 | 229 | # Save/Load Chats given a session id 230 | def save_session( 231 | self, 232 | output_path: str = None, 233 | id: Union[str, UUID] = None, 234 | format: str = "csv", 235 | minify: bool = False, 236 | ): 237 | sess = self.get_session(id) 238 | sess_dict = sess.model_dump( 239 | exclude={"auth", "api_url", "input_fields"}, 240 | exclude_none=True, 241 | ) 242 | output_path = output_path or f"chat_session.{format}" 243 | if format == "csv": 244 | with open(output_path, "w", 
encoding="utf-8") as f: 245 | fields = [ 246 | "role", 247 | "content", 248 | "received_at", 249 | "prompt_length", 250 | "completion_length", 251 | "total_length", 252 | "finish_reason", 253 | ] 254 | w = csv.DictWriter(f, fieldnames=fields) 255 | w.writeheader() 256 | for message in sess_dict["messages"]: 257 | # datetime must be in common format to be loaded into spreadsheet 258 | # for human-readability, the timezone is set to local machine 259 | local_datetime = message["received_at"].astimezone() 260 | message["received_at"] = local_datetime.strftime( 261 | "%Y-%m-%d %H:%M:%S" 262 | ) 263 | w.writerow(message) 264 | elif format == "json": 265 | with open(output_path, "wb") as f: 266 | f.write( 267 | orjson.dumps( 268 | sess_dict, option=orjson.OPT_INDENT_2 if not minify else None 269 | ) 270 | ) 271 | 272 | def load_session(self, input_path: str, id: Union[str, UUID] = uuid4(), **kwargs): 273 | assert input_path.endswith(".csv") or input_path.endswith( 274 | ".json" 275 | ), "Only CSV and JSON imports are accepted." 276 | 277 | if input_path.endswith(".csv"): 278 | with open(input_path, "r", encoding="utf-8") as f: 279 | r = csv.DictReader(f) 280 | messages = [] 281 | for row in r: 282 | # need to convert the datetime back to UTC 283 | local_datetime = datetime.datetime.strptime( 284 | row["received_at"], "%Y-%m-%d %H:%M:%S" 285 | ).replace(tzinfo=dateutil.tz.tzlocal()) 286 | row["received_at"] = local_datetime.astimezone( 287 | datetime.timezone.utc 288 | ) 289 | # https://stackoverflow.com/a/68305271 290 | row = {k: (None if v == "" else v) for k, v in row.items()} 291 | messages.append(ChatMessage(**row)) 292 | 293 | self.new_session(id=id, **kwargs) 294 | self.sessions[id].messages = messages 295 | 296 | if input_path.endswith(".json"): 297 | with open(input_path, "rb") as f: 298 | sess_dict = orjson.loads(f.read()) 299 | # update session with info not loaded, e.g. 
auth/api_url 300 | for arg in kwargs: 301 | sess_dict[arg] = kwargs[arg] 302 | self.new_session(**sess_dict) 303 | 304 | # Tabulators for returning total token counts 305 | def message_totals(self, attr: str, id: Union[str, UUID] = None) -> int: 306 | sess = self.get_session(id) 307 | return getattr(sess, attr) 308 | 309 | @property 310 | def total_prompt_length(self, id: Union[str, UUID] = None) -> int: 311 | return self.message_totals("total_prompt_length", id) 312 | 313 | @property 314 | def total_completion_length(self, id: Union[str, UUID] = None) -> int: 315 | return self.message_totals("total_completion_length", id) 316 | 317 | @property 318 | def total_length(self, id: Union[str, UUID] = None) -> int: 319 | return self.message_totals("total_length", id) 320 | 321 | # alias total_tokens to total_length for common use 322 | @property 323 | def total_tokens(self, id: Union[str, UUID] = None) -> int: 324 | return self.total_length(id) 325 | 326 | 327 | class AsyncAIChat(AIChat): 328 | async def __call__( 329 | self, 330 | prompt: str, 331 | id: Union[str, UUID] = None, 332 | system: str = None, 333 | save_messages: bool = None, 334 | params: Dict[str, Any] = None, 335 | tools: List[Any] = None, 336 | input_schema: Any = None, 337 | output_schema: Any = None, 338 | ) -> str: 339 | # TODO: move to a __post_init__ in Pydantic 2.0 340 | if isinstance(self.client, Client): 341 | self.client = AsyncClient(proxies=os.getenv("https_proxy")) 342 | sess = self.get_session(id) 343 | if tools: 344 | assert (input_schema is None) and ( 345 | output_schema is None 346 | ), "When using tools, input/output schema are ignored" 347 | for tool in tools: 348 | assert tool.__doc__, f"Tool {tool} does not have a docstring." 349 | assert len(tools) <= 9, "You can only have a maximum of 9 tools." 
    async def stream(
        self,
        prompt: str,
        id: Union[str, UUID] = None,
        system: str = None,
        save_messages: bool = None,
        params: Dict[str, Any] = None,
        input_schema: Any = None,
    ) -> str:
        """Async counterpart of AIChat.stream().

        NOTE(review): returns the async generator from the session's
        stream_async() without awaiting it, so callers must do
        `async for chunk in await ai.stream(...)`. The `-> str` annotation is
        inherited from the sync version and is inaccurate — confirm before
        changing it.
        """
        # TODO: move to a __post_init__ in Pydantic 2.0
        # Lazily replace the sync httpx Client created in __init__ with an
        # AsyncClient on first async use.
        if isinstance(self.client, Client):
            self.client = AsyncClient(proxies=os.getenv("https_proxy"))
        sess = self.get_session(id)
        return sess.stream_async(
            prompt,
            client=self.client,
            system=system,
            save_messages=save_messages,
            params=params,
            input_schema=input_schema,
        )
22 | results = [x["title"] for x in r_search.json()["query"]["search"]] 23 | 24 | return results[0] if n == 1 else results 25 | 26 | 27 | def wikipedia_lookup(query: str, sentences: int = 1) -> str: 28 | LOOKUP_PARAMS = { 29 | "action": "query", 30 | "prop": "extracts", 31 | "exsentences": sentences, 32 | "exlimit": "1", 33 | "explaintext": "1", 34 | "formatversion": "2", 35 | "format": "json", 36 | "titles": query, 37 | } 38 | 39 | r_lookup = httpx.get(WIKIPEDIA_API_URL, params=LOOKUP_PARAMS) 40 | return r_lookup.json()["query"]["pages"][0]["extract"] 41 | 42 | 43 | def wikipedia_search_lookup(query: str, sentences: int = 1) -> str: 44 | return wikipedia_lookup(wikipedia_search(query, 1), sentences) 45 | 46 | 47 | async def wikipedia_search_async(query: str, n: int = 1) -> Union[str, List[str]]: 48 | SEARCH_PARAMS = { 49 | "action": "query", 50 | "list": "search", 51 | "format": "json", 52 | "srlimit": n, 53 | "srsearch": query, 54 | "srwhat": "text", 55 | "srprop": "", 56 | } 57 | 58 | async with httpx.AsyncClient(proxies=os.getenv("https_proxy")) as client: 59 | r_search = await client.get(WIKIPEDIA_API_URL, params=SEARCH_PARAMS) 60 | results = [x["title"] for x in r_search.json()["query"]["search"]] 61 | 62 | return results[0] if n == 1 else results 63 | 64 | 65 | async def wikipedia_lookup_async(query: str, sentences: int = 1) -> str: 66 | LOOKUP_PARAMS = { 67 | "action": "query", 68 | "prop": "extracts", 69 | "exsentences": sentences, 70 | "exlimit": "1", 71 | "explaintext": "1", 72 | "formatversion": "2", 73 | "format": "json", 74 | "titles": query, 75 | } 76 | 77 | async with httpx.AsyncClient(proxies=os.getenv("https_proxy")) as client: 78 | r_lookup = await client.get(WIKIPEDIA_API_URL, params=LOOKUP_PARAMS) 79 | return r_lookup.json()["query"]["pages"][0]["extract"] 80 | 81 | 82 | async def wikipedia_search_lookup_async(query: str, sentences: int = 1) -> str: 83 | return await wikipedia_lookup_async( 84 | await wikipedia_search_async(query, 1), sentences 
# https://stackoverflow.com/a/58938747
def remove_a_key(d, remove_key):
    """Recursively delete every occurrence of `remove_key` from a nested
    structure, mutating `d` in place.

    Generalized to also descend into lists/tuples: pydantic JSON schemas nest
    dicts inside sequence constructs such as ``anyOf: [...]``, which the
    original dict-only recursion never reached.
    """
    if isinstance(d, dict):
        for key in list(d.keys()):
            if key == remove_key:
                del d[key]
            else:
                remove_a_key(d[key], remove_key)
    elif isinstance(d, (list, tuple)):
        # Descend into sequence elements (e.g. schema "anyOf"/"allOf" lists).
        for item in d:
            remove_a_key(item, remove_key)