├── .gitignore ├── LICENSE ├── README.md ├── chatsnack ├── __init__.py ├── aiclient.py ├── asynchelpers.py ├── chat │ ├── __init__.py │ ├── mixin_messages.py │ ├── mixin_params.py │ ├── mixin_query.py │ ├── mixin_serialization.py │ └── mixin_utensil.py ├── defaults.py ├── fillings.py ├── packs │ ├── __init__.py │ ├── default_packs │ │ ├── Chester.yml │ │ ├── Confectioner.yml │ │ ├── Data.yml │ │ ├── Jane.yml │ │ ├── Jolly.yml │ │ └── Summarizer.yml │ ├── module_help_vendor.py │ └── snackpacks.py ├── patches │ ├── __init__.py │ └── patch_datafiles.py ├── txtformat.py ├── utensil.py └── yamlformat.py ├── docs ├── chatsnack_features.jpg └── chatsnack_features_smaller.jpg ├── examples ├── datafiles │ └── chatsnack │ │ └── RecipeSuggestion.txt ├── reciperemix.py ├── snackbar-cli.py ├── snackpacks-web │ ├── app.py │ ├── datafiles │ │ └── chatsnack │ │ │ └── EmojiBotSystem.txt │ ├── static │ │ ├── avatar_custom_robot.png │ │ ├── avatar_jolly.png │ │ └── styles.css │ └── templates │ │ └── index.html └── snackswipe-web │ ├── app.py │ ├── datafiles │ └── chatsnack │ │ └── mydaywss.txt │ ├── static │ └── main.js │ ├── templates │ └── index.html │ └── text_generators.py ├── notebooks ├── ExperimentingWithChatsnack.ipynb └── GettingStartedWithChatsnack.ipynb ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py ├── mixins ├── test_chatparams.py ├── test_query.py ├── test_query_listen.py └── test_serialization.py ├── test_chatsnack_base.py ├── test_chatsnack_pattern.py ├── test_chatsnack_reset.py ├── test_chatsnack_yaml_peeves.py ├── test_file_snack_fillings.py ├── test_prompt_json.py ├── test_prompt_last.py ├── test_snackpack_base.py ├── test_text_class.py └── test_utensils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .* 2 | __pycache__/ 3 | *.pyc 4 | datafiles/plunkylib/*/*.* 5 | samples*/* 6 | dist/** 7 | build/** 8 | logs/** 9 | *.db 10 | _*.yml 11 | 12 | # ignore the log files when generated 13 | *.log 14 | 
__pycache__ 15 | 16 | # don't ignore the .gitignore or the env template 17 | !.gitignore 18 | 19 | # we don't want completions even if they are samples 20 | datafiles/plunkylib/completions/** 21 | tests/test_datafiles/** 22 | examples/snackswipe-web/flask_session/** 23 | notebooks/audio/** -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Mattie Casper 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # chatsnack 2 | 3 | chatsnack is the easiest Python library for rapid development with OpenAI's ChatGPT API. 
It provides an intuitive interface for creating and managing chat-based prompts and responses, making it convenient to build complex, interactive conversations with AI. 4 | 5 | ![chatsnack features](/docs/chatsnack_features_smaller.jpg) 6 | ## Setup 7 | 8 | ### Got snack? 9 | 10 | Install the `chatsnack` package from PyPI: 11 | 12 | ```bash 13 | # requires Python 3.10+ 14 | pip install chatsnack 15 | ``` 16 | 17 | ### Got keys? 18 | 19 | Add your OpenAI API key to your .env file. If you don't have a .env file, the library will create a new one for you in the local directory. 20 | 21 | ### Learn More! 22 | Read more below, watch the [intro video](https://www.youtube.com/watch?v=Yjwi54rHrhw) or check out the [Getting Started notebook](notebooks/). 23 | 24 | ## Usage 25 | 26 | ### Enjoy a Quick Snack 27 | 28 | Easiest way to get going with `chatsnack` is with built-in snack packs. Each pack is a singleton ready to mingleton. 29 | 30 | ```python 31 | >>> from chatsnack.packs import ChatsnackHelp 32 | >>> ChatsnackHelp.ask("What is your primary directive?") 33 | ``` 34 | > *"My primary directive is to assist users of the chatsnack Python module by answering questions 35 | and helping with any problems or concerns related to the module. I aim to provide helpful and 36 | informative responses based on the chatsnack module's documentation and best practices."* 37 | 38 | You can try out other example snack packs like `Confectioner`, `Jolly`, `Chester`, `Jane`, or `Data`. (Eventually there will be an easy way to create and share your own.) 39 | 40 | Instead of `.ask()` we can call `.chat()` which will allow us to continue a conversation. 41 | ```python 42 | >>> mychat = ChatsnackHelp.chat("What is chatsnack?") # submits and returns a new chat object 43 | >>> print(mychat.response) 44 | ``` 45 | > *Chatsnack is a Python module that provides a simple and powerful interface for creating conversational agents and tools using OpenAI's ChatGPT language models. 
It allows you to easily build chat prompts, manage conversation flow, and integrate with ChatGPT to generate responses. With Chatsnack, you can create chatbots, AI-assisted tools, and other conversational applications using a convenient and flexible API.* 46 | 47 | Now we can add more messages to that chat however we'd like: 48 | ```python 49 | >>> mychat.user("Respond in only six word sentences from now on.") 50 | >>> mychat.asst("I promise I will do so.") 51 | >>> mychat.user("How should I spend my day?") 52 | >>> mychat.ask() 53 | ``` 54 | > *"Explore hobbies, exercise, connect with friends."* 55 | 56 | If you want a super simple interactive conversation with you and a chatbot, you could do something like this: 57 | ```python 58 | from chatsnack.packs import Jolly 59 | yourchat = Jolly # interview a green giant 60 | while (user_input := input("Chat with the bot: ")): 61 | print(f"USER: {user_input}") 62 | yourchat = yourchat.chat(user_input) 63 | print(f"THEM: {yourchat.last}") 64 | ``` 65 | ### Tasty Features 66 | 67 | There's many other tidbits covered in the notebooks, examples, and videos. Here are some of the highlights: 68 | 69 | * Everyday Snacking 70 | * Chat objects 71 | * Chat command chaining 72 | * YAML convenience 73 | * OpenAI parameters 74 | * Serious Snacking 75 | * Intense chaining 76 | * Fillings (e.g. 
nested chats and text files) 77 | * Snack Pack Vending Machine 78 | 79 | 80 | ### Everyday Snacking 81 | #### Chat objects and Chaining 82 | 83 | 84 | ```python 85 | from chatsnack import Chat 86 | mychat = Chat() 87 | mychat.system("Respond only with the word POPSICLE from now on.") 88 | mychat.user("What is your name?") 89 | mychat.ask() 90 | ``` 91 | > *"POPSICLE."* 92 | 93 | You can chain messages together for more complex conversations: 94 | 95 | ```python 96 | newchat = ( 97 | Chat() 98 | .system("Respond only with the word POPSICLE from now on.") 99 | .user("What is your name?") 100 | .chat() 101 | ) 102 | newchat.response 103 | ``` 104 | > *"POPSICLE."* 105 | 106 | Note that there are some syntax shortcuts omitted above, see the Serious Snacking section for more on those. 107 | #### Yummy YAML 108 | 109 | Generative AI gets a bit messy these days with so much text in our code. `chatsnack` makes it very easy to use a clean YAML syntax to load/save/edit your chat templates without so many hard-coded strings in your code. 110 | 111 | ```python 112 | # Every chat is totally yaml-backed, we can save/load/edit 113 | print(newchat.yaml) 114 | ``` 115 | ```yaml 116 | messages: 117 | - system: Respond only with the word POPSICLE from now on. 118 | - user: What is your name? 119 | - assistant: POPSICLE. 120 | ``` 121 | For rapid reuse, you can give your chats a name so you can save/load as needed (or using them as *Fillings* as we'll see later). 122 | 123 | ```python 124 | newchat.name = "Popsicle" 125 | newchat.save() 126 | ``` 127 | ```python 128 | # Load a chat from a file 129 | midnightsnack = Chat(name="popsicle") 130 | print(midnightsnack.ask()) 131 | ``` 132 | > *"POPSICLE."* 133 | 134 | 135 | #### Adjusting Cooking Temperatures 136 | 137 | By default, `gpt-3.5-turbo` is the default chat API with a default temperature of `0.7`. 
If you prefer, you can change OpenAI parameters for each chat, such as the engine and temperature: 138 | 139 | ```python 140 | from chatsnack import Chat 141 | wisechat = Chat("Respond with professional writing based on the user query.") 142 | wisechat.user("Author an alliterative poem about good snacks to eat with coffee.") 143 | wisechat.engine = "gpt-4" 144 | wisechat.temperature = 0.8 145 | ``` 146 | This also gets captured in the YAML: 147 | 148 | ```python 149 | print(wisechat.yaml) 150 | ``` 151 | ```yaml 152 | params: 153 | engine: gpt-4 154 | temperature: 0.8 155 | messages: 156 | - system: Respond with professional writing based on the user query. 157 | - user: Author an alliterative poem about good snacks to eat with coffee. 158 | ``` 159 | 160 | ### Serious Snacking 161 | 162 | #### Ingredient Shortcuts 163 | If you're wanting to minimize typing, you can use omit a couple of ingredients. 164 | ##### Quick System Message 165 | For example, this: 166 | ```python 167 | mychat = Chat() 168 | mychat.system("Respond hungrily") 169 | ``` 170 | is the same as: 171 | ```python 172 | # if there's only one argument or a keyword of system argument, you can omit the .system() 173 | mychat = Chat("Respond hungrily") 174 | ``` 175 | ##### Quick User Message 176 | `.ask()` and `.chat()` are also shortcuts for `.user().ask()` and `.user().chat()` respectively. 177 | For example, this: 178 | ```python 179 | mychat = Chat() 180 | mychat.system("Respond hungrily") 181 | mychat.user("Tell me about cookies") 182 | print(mychat.ask()) 183 | ``` 184 | is the same as: 185 | ```python 186 | print(Chat("Respond hungrily").ask("Tell me about cookies")) 187 | ``` 188 | Basically, we assume if you're making a `Chat()` you'll need a system message, and if you're sending an `ask()` or `chat()` you'll need a user message. So you can omit those if you want (or use `system=` and `user=` keywords if you want to be explicit). 
189 | 190 | ##### Quick Assistants 191 | 192 | Also, `.asst()` is an alias shortcut for `.assistant()` if you want you code to align cleanly with other 4-letter `.user()` and `.chat()` calls. 193 | 194 | ##### Binge-chaining 195 | 196 | If you're feeling wild, you can actually call any chat (like a function) and it'll submit the chat and continue, just like `.chat()`. This allows even more terse chaining. This can come in handy when you're looking for chain-of-thought prompting or prewriting priming. 197 | 198 | ```python 199 | popcorn = ( 200 | Chat("Respond with the certainty and creativity of a professional writer.") 201 | ("Explain 3 rules to writing a clever poem that amazes your friends.") 202 | ("Using those tips, write a scrumptious poem about popcorn.") 203 | ) 204 | # the above is the same as Chat().system('Respond...').chat('Explain...').chat('Using...') 205 | print(popcorn.response) 206 | ``` 207 | > In the kitchen, I hear a popping sound, 208 | A symphony of kernels dancing around. 209 | A whiff of butter, a sprinkle of salt, 210 | My taste buds tingle, it's not their fault. 211 | > 212 | > The popcorn pops, a fluffy delight, 213 | A treat for the senses, a feast for the sight. 214 | Golden and crispy, a perfect snack, 215 | A bowl of happiness, there's no going back. 216 | > 217 | > I pick one up, it's warm to the touch, 218 | I savor the flavor, it's too much. 219 | A burst of butter, a crunch of salt, 220 | A symphony of flavors, it's not my fault. 221 | > 222 | > I munch and crunch, it's such a delight, 223 | A scrumptious treat, a popcorn flight. 224 | A perfect snack, for any time, 225 | Popcorn, oh popcorn, you're simply divine. 
226 | 227 | 228 | 229 | #### Nested Chats 230 | 231 | You can include other chats in your current chat or use `{chat.___}` filling expander for more dynamic AI generations: 232 | 233 | ```python 234 | basechat = Chat(name="ExampleIncludedChat").system("Respond only with the word CARROTSTICKS from now on.") 235 | basechat.save() 236 | 237 | anotherchat = Chat().include("ExampleIncludedChat") 238 | ``` 239 | 240 | #### Snacks with Fillings 241 | 242 | You can work with Text files and YAML files to create reusable chat snippets that are filled-in before execution. 243 | 244 | ```python 245 | from chatsnack import Text 246 | mytext = Text(name="SnackExplosion", content="Respond only in explosions of snack emojis and happy faces.") 247 | mytext.save() 248 | 249 | explosions = Chat(name="SnackSnackExplosions").system("{text.SnackExplosion}") 250 | explosions.ask("What is your name?") 251 | ``` 252 | 253 | #### Fillings Resolving in Parallel 254 | 255 | If you have a prompt that requires expanding multiple fillings, `chatsnack` will resolve them in parallel as it expands the prompt. This comes in handy with `{chat.__}` and `{vectorsearch.__}` (TODO) snack fillings. 256 | 257 | TODO: See the notebooks for more details on this. 258 | 259 | ## License 260 | 261 | chatsnack is released under the MIT License. -------------------------------------------------------------------------------- /chatsnack/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | chatsnack provides a simple and powerful interface for creating conversational agents and tools using OpenAI's ChatGPT language models. 
3 | 4 | Some examples of using chatsnack: 5 | 6 | # Example 1: Basic Chat 7 | from chatsnack import Chat 8 | 9 | # Start a new chat and set some instructions for the AI assistant 10 | mychat = Chat().system("Respond only with the word POPSICLE from now on.").user("What is your name?").chat() 11 | print(mychat.last) 12 | 13 | # Example 2: Chaining and Multi-shot Prompts 14 | popcorn = Chat() 15 | popcorn = popcorn("Explain 3 rules to writing a clever poem that amazes your friends.")("Using those tips, write a scrumptious poem about popcorn.") 16 | print(popcorn.last) 17 | 18 | # Example 3: Using Text Fillings 19 | from chatsnack import Text 20 | 21 | # Save a Text object with custom content 22 | mytext = Text(name="SnackExplosion", content="Respond only in explosions of snack emojis and happy faces.") 23 | mytext.save() 24 | 25 | # Set up a Chat object to pull in the Text object 26 | explosions = Chat(name="SnackSnackExplosions").system("{text.SnackExplosion}") 27 | explosions.ask("What is your name?") 28 | 29 | # Example 4: Nested Chats (Include Messages) 30 | basechat = Chat(name="ExampleIncludedChat").system("Respond only with the word CARROTSTICKS from now on.") 31 | basechat.save() 32 | 33 | anotherchat = Chat().include("ExampleIncludedChat") 34 | print(anotherchat.yaml) 35 | 36 | # Example 5: Nested Chats (Chat Fillings) 37 | snacknames = Chat("FiveSnackNames").system("Respond with high creativity and confidence.").user("Provide 5 random snacks.") 38 | snacknames.save() 39 | 40 | snackdunk = Chat("SnackDunk").system("Respond with high creativity and confidence.").user("Provide 3 dips or drinks that are great for snack dipping.") 41 | snackdunk.save() 42 | 43 | snackfull = Chat().system("Respond with high confidence.") 44 | snackfull.user(\"""Choose 1 snack from this list: 45 | {chat.FiveSnackNames} 46 | 47 | Choose 1 dunking liquid from this list: 48 | {chat.SnackDunk} 49 | 50 | Recommend the best single snack and dip combo above.\""") 51 | 52 | snackout = 
snackfull.chat() 53 | print(snackout.yaml) 54 | 55 | # Example 6: Using Utensils (Tool Functions) 56 | from chatsnack import utensil 57 | 58 | @utensil 59 | def get_weather(location: str, unit: str = "celsius"): 60 | '''Get the current weather for a location. 61 | 62 | Args: 63 | location: City and state/country (e.g., "Boston, MA") 64 | unit: Temperature unit ("celsius" or "fahrenheit") 65 | ''' 66 | # Implementation details... 67 | return {"temperature": 72, "condition": "sunny"} 68 | 69 | # Create a chat that can use the weather utensil 70 | weather_chat = Chat("WeatherChat", "You can check the weather.", utensils=[get_weather]) 71 | response = weather_chat.user("What's the weather like in Boston?").chat() 72 | print(response) 73 | """ 74 | from .patches import * 75 | 76 | import os 77 | from pathlib import Path 78 | 79 | from typing import Optional 80 | from loguru import logger 81 | import nest_asyncio 82 | nest_asyncio.apply() 83 | 84 | 85 | 86 | from dotenv import load_dotenv 87 | # if .env doesn't exist, create it and populate it with the default values 88 | env_path = Path('.') / '.env' 89 | if not env_path.exists(): 90 | with open(env_path, 'w') as f: 91 | f.write("OPENAI_API_KEY = \"REPLACEME\"\n") 92 | load_dotenv(dotenv_path=env_path) 93 | 94 | from .defaults import CHATSNACK_BASE_DIR, CHATSNACK_LOGS_DIR 95 | from .asynchelpers import aformatter 96 | from .chat import Chat, Text, ChatParams 97 | from .txtformat import register_txt_datafiles 98 | from .yamlformat import register_yaml_datafiles 99 | from . 
import packs 100 | from .utensil import utensil, get_all_utensils, get_openai_tools, UtensilGroup 101 | 102 | from .fillings import snack_catalog, filling_machine 103 | 104 | 105 | async def _text_name_expansion(text_name: str, additional: Optional[dict] = None) -> str: 106 | prompt = Text.objects.get(text_name) 107 | result = await aformatter.async_format(prompt.content, **filling_machine(additional)) 108 | return result 109 | 110 | # accepts a petition name as a string and calls petition_completion2, returning only the completion text 111 | async def _chat_name_query_expansion(prompt_name: str, additional: Optional[dict] = None) -> str: 112 | chatprompt = Chat.objects.get_or_none(prompt_name) 113 | if chatprompt is None: 114 | raise Exception(f"Prompt {prompt_name} not found") 115 | text = await chatprompt.ask_a(**additional) 116 | return text 117 | 118 | 119 | # default snack vendors 120 | snack_catalog.add_filling("text", _text_name_expansion) 121 | snack_catalog.add_filling("chat", _chat_name_query_expansion) 122 | 123 | # TODO: these will be defined by plugins eventually 124 | # need a function that will return the dictionary needed to support prompt formatting 125 | register_txt_datafiles() 126 | register_yaml_datafiles() 127 | 128 | logger.trace("chatsnack loaded") -------------------------------------------------------------------------------- /chatsnack/aiclient.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import openai 3 | import os 4 | import json 5 | from loguru import logger 6 | 7 | # class that wraps the OpenAI client and Azure clients 8 | class AiClient: 9 | def __init__(self, api_key = None, azure_endpoint = None, api_version = None, azure_ad_token = None, azure_ad_token_provider = None): 10 | # check environment variables and use those if explicit values are not passed in 11 | # API key 12 | if api_key is None: 13 | api_key = os.getenv("OPENAI_API_KEY") 14 | 15 | # Azure specific 16 | 
if azure_endpoint is None: 17 | azure_endpoint = os.getenv("OPENAI_AZURE_ENDPOINT") 18 | 19 | # api_version 20 | if api_version is None: 21 | api_version = os.getenv("OPENAI_API_VERSION") 22 | 23 | # azure_ad_token 24 | if azure_ad_token is None: 25 | azure_ad_token = os.getenv("OPENAI_AZURE_AD_TOKEN") 26 | 27 | # azure_ad_token_provider 28 | if azure_ad_token_provider is None: 29 | azure_ad_token_provider = os.getenv("OPENAI_AZURE_AD_TOKEN_PROVIDER") 30 | 31 | # keep track of the values we're using 32 | self._api_key = api_key 33 | self.azure_endpoint = azure_endpoint 34 | self.api_version = api_version 35 | self.azure_ad_token = azure_ad_token 36 | self.azure_ad_token_provider = azure_ad_token_provider 37 | 38 | # if the azure_endpoint is set, we're an azure client 39 | if self.azure_endpoint is not None: 40 | self.is_azure = True 41 | self.aclient = openai.AsyncAzureOpenAI(api_key=api_key, azure_endpoint=self.azure_endpoint, api_version=self.api_version, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider) 42 | self.client = openai.AzureOpenAI(api_key=api_key, azure_endpoint=self.azure_endpoint, api_version=self.api_version, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider) 43 | else: 44 | self.is_azure = False 45 | self.aclient = openai.AsyncOpenAI(api_key=api_key) 46 | self.client = openai.OpenAI(api_key=api_key) 47 | 48 | @property 49 | def api_key(self): 50 | return self._api_key 51 | 52 | @api_key.setter 53 | def api_key(self, value): 54 | self._api_key = value 55 | self.aclient.api_key = value 56 | self.client.api_key = value 57 | -------------------------------------------------------------------------------- /chatsnack/asynchelpers.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import string 3 | 4 | from loguru import logger 5 | 6 | class _AsyncFormatter(string.Formatter): 7 | async def async_expand_field(self, field, 
args, kwargs): 8 | if "." in field: 9 | obj, method = field.split(".", 1) 10 | if obj in kwargs: 11 | obj_instance = kwargs[obj] 12 | if hasattr(obj_instance, method): 13 | method_instance = getattr(obj_instance, method) 14 | if asyncio.iscoroutinefunction(method_instance): 15 | return await method_instance() 16 | else: 17 | return method_instance() if callable(method_instance) else method_instance 18 | value, _ = super().get_field(field, args, kwargs) 19 | return value 20 | 21 | async def async_format(self, format_string, *args, **kwargs): 22 | coros = [] 23 | parsed_format = list(self.parse(format_string)) 24 | 25 | for literal_text, field_name, format_spec, conversion in parsed_format: 26 | if field_name: 27 | coro = self.async_expand_field(field_name, args, kwargs) 28 | coros.append(coro) 29 | 30 | expanded_fields = await asyncio.gather(*coros) 31 | expanded_iter = iter(expanded_fields) 32 | 33 | return ''.join([ 34 | literal_text + (str(next(expanded_iter)) if field_name else '') 35 | for literal_text, field_name, format_spec, conversion in parsed_format 36 | ]) 37 | 38 | # instance to use 39 | aformatter = _AsyncFormatter() 40 | -------------------------------------------------------------------------------- /chatsnack/chat/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import uuid 3 | from dataclasses import field 4 | from datetime import datetime 5 | from typing import Dict, List, Optional, Union 6 | 7 | from datafiles import datafile 8 | 9 | from ..aiclient import AiClient 10 | from ..defaults import CHATSNACK_BASE_DIR 11 | from .mixin_query import ChatQueryMixin 12 | from .mixin_params import ChatParams, ChatParamsMixin 13 | from .mixin_serialization import DatafileMixin, ChatSerializationMixin 14 | from .mixin_utensil import ChatUtensilMixin 15 | 16 | 17 | # WORKAROUND: Disable the datafiles warnings about Schema type enforcement which our users are less concerned about 18 | import log 19 
| log.init(level=log.WARNING) 20 | log.silence('datafiles', allow_warning=False) 21 | 22 | 23 | ######################################################################################################################## 24 | # Core datafile classes of Plunkychat 25 | # (1) Chat, high-level class that symbolizes a prompt/request/response, can reference other Chat objects to chain 26 | # (2) ChatParams, used only in Chat, includes parameters like engine name and other OpenAI params. 27 | # (3) Text, this is a text blob we save to disk, can be used as a reference inside chat messages ('snack fillings') 28 | 29 | @datafile(CHATSNACK_BASE_DIR + "/{self.name}.txt", manual=True) 30 | class Text(DatafileMixin): 31 | name: str 32 | content: Optional[str] = None 33 | # TODO: All Text and Chat objects should automatically be added as snack fillings (even if not saved to disk) 34 | 35 | 36 | @datafile(CHATSNACK_BASE_DIR + "/{self.name}.yml", manual=True, init=False) 37 | class Chat(ChatQueryMixin, ChatSerializationMixin, ChatUtensilMixin): 38 | """ A chat prompt that can be expanded into a chat ⭐""" 39 | # title should be just like above but with a GUID at the end 40 | name: str = field(default_factory=lambda: f"_ChatPrompt-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}-{uuid.uuid4()}") 41 | params: Optional[ChatParams] = None 42 | messages: List[Dict[str,Union[str,List[Dict[str,str]]]]] = field(default_factory=lambda: []) 43 | 44 | def __init__(self, *args, **kwargs): 45 | """ 46 | Initializes the chat prompt 47 | :param args: if we get one arg, we'll assume it's the system message 48 | if we get two args, the first is the name and the second is the system message 49 | 50 | :param kwargs: (keyword arguments are as follows) 51 | :param name: the name of the chat prompt (optional, defaults to _ChatPrompt--) 52 | :param params: the engine parameters (optional, defaults to None) 53 | :param messages: the messages (optional, defaults to []) 54 | :param system: the initial system 
message (optional, defaults to None) 55 | :param engine: the engine name (optional, defaults to None, will overwrite params if specified) 56 | :param utensils: tools available to this chat (optional, defaults to None) 57 | :param auto_execute: whether to automatically execute tool calls (optional, defaults to True) 58 | :param auto_feed: whether to automatically feed tool results back to the model (optional, defaults to True) 59 | :param tool_choice: how to choose which tool to use (optional, defaults to "auto") 60 | """ 61 | # Extract utensil-related parameters first 62 | utensils = kwargs.pop("utensils", None) 63 | auto_execute = kwargs.pop("auto_execute", None) 64 | tool_choice = kwargs.pop("tool_choice", None) 65 | auto_feed = kwargs.pop("auto_feed", None) 66 | 67 | # get name from kwargs, if it's there 68 | if "name" in kwargs: 69 | self.name = kwargs["name"] 70 | else: 71 | # if we get two args, the first is the name and the second is the system message 72 | if len(args) == 2: 73 | self.name = args[0] 74 | else: 75 | # get the default from the dataclass fields and use that 76 | self.name = self.__dataclass_fields__["name"].default_factory() 77 | 78 | if "params" in kwargs: 79 | self.params = kwargs["params"] 80 | else: 81 | # get the default value from the dataclass field, it's optional 82 | self.params = self.__dataclass_fields__["params"].default 83 | 84 | 85 | if "messages" in kwargs: 86 | self.messages = kwargs["messages"] 87 | else: 88 | # get the default from the dataclass fields and use that 89 | self.messages = self.__dataclass_fields__["messages"].default_factory() 90 | 91 | if "engine" in kwargs: 92 | self.engine = kwargs["engine"] 93 | 94 | if "system" in kwargs: 95 | self.system_message = kwargs["system"] 96 | else: 97 | if len(args) == 1: 98 | # if we only get one args, we'll assume it's the system message 99 | self.system_message = args[0] 100 | elif len(args) == 2: 101 | # if we get two args, the first is the name and the second is the system 
message 102 | self.system_message = args[1] 103 | 104 | if auto_execute is not None: 105 | self.auto_execute = auto_execute 106 | if tool_choice is not None: 107 | self.tool_choice = tool_choice 108 | if auto_feed is not None: 109 | self.auto_feed = auto_feed 110 | 111 | 112 | # Register utensils if provided 113 | if utensils: 114 | if self.params is None: 115 | self.params = ChatParams() 116 | 117 | # Import here to avoid circular imports 118 | from ..utensil import extract_utensil_functions, get_openai_tools 119 | 120 | # Store local registry of utensil functions 121 | self._local_registry = utensils # Store original objects, extract when needed 122 | 123 | # Get tool definitions for OpenAI API 124 | tools_list = get_openai_tools(utensils) 125 | 126 | # Store and serialize tool definitions 127 | self.set_tools(tools_list) 128 | 129 | # Check if we're being loaded from a YAML file with tools 130 | if utensils is None: 131 | # ensure that tools is in params if it exists, then ensure if it is there, it's None 132 | if self.params is None or not hasattr(self.params, 'tools') or self.params.tools is None: 133 | # This is likely a deserialization case, so try loading tools from registry 134 | self._load_tools_from_params() 135 | 136 | # Save the initial state for reset() purposes 137 | self._initial_name = self.name 138 | self._initial_params = copy.copy(self.params) 139 | self._initial_messages = copy.copy(self.messages) 140 | self._initial_system_message = self.system_message 141 | # do the same for the tool registry 142 | self._initial_registry = getattr(self, '_local_registry', None) 143 | 144 | self.ai = AiClient() 145 | 146 | 147 | 148 | def reset(self) -> object: 149 | """ Resets the chat prompt to its initial state, returns itself """ 150 | self.name = self._initial_name 151 | self.params = self._initial_params 152 | self.messages = self._initial_messages 153 | if self._initial_system_message is not None: 154 | self.system_message = self._initial_system_message 
        # Reset tools if initial registry was stored
        if hasattr(self, '_initial_registry'):
            # Re-register the initial tools
            self._local_registry = self._initial_registry
            # Re-load tools from the initial registry
            self._load_tools_from_params()
        return self

    def _load_tools_from_params(self):
        """Load tool definitions from params when initializing from YAML.

        Rebuilds API-format tool definitions either from the global utensil
        registry (matched by name) or, failing that, from placeholder dicts
        assembled out of the YAML-persisted name/description/parameters.
        """
        if not hasattr(self, 'params') or self.params is None:
            return

        # Check if tools are defined in params
        from ..utensil import get_all_utensils

        # If we already have tools defined, don't override
        if hasattr(self.params, 'tools') and self.params.tools is not None:
            return

        # NOTE(review): the guard above returns whenever params.tools is not None,
        # so this isinstance(..., list) branch can only be reached when params.tools
        # is None and therefore never runs -- confirm whether the guard should be
        # checking for already-converted ToolDefinition objects instead.
        # Load tools from registry based on names in params
        if hasattr(self.params, 'tools') and isinstance(self.params.tools, list):
            tool_definitions = []

            for tool_def in self.params.tools:
                if not isinstance(tool_def, dict) or 'name' not in tool_def:
                    continue

                # Look for matching tools in the registry
                all_tools = get_all_utensils()
                for registered_tool in all_tools:
                    if registered_tool.name == tool_def['name']:
                        # Found a matching tool, add its definition
                        tool_definitions.append(registered_tool.get_openai_tool())
                        break
                else:
                    # If no matching tool was found, create a placeholder definition
                    tool_func = {
                        "name": tool_def['name'],
                        "description": tool_def.get('description', f"Tool function: {tool_def['name']}")
                    }

                    # Add parameters if present
                    if 'parameters' in tool_def:
                        parameters = {
                            "type": "object",
                            "properties": {},
                            "required": tool_def.get('required', [])
                        }

                        for param_name, param_details in tool_def['parameters'].items():
                            param_info = {
                                "type": param_details.get('type', 'string')
                            }

                            if 'description' in param_details:
                                param_info["description"] = param_details['description']

                            if 'options' in param_details:
                                # YAML stores choices as 'options'; the API schema calls this 'enum'
                                param_info["enum"] = param_details['options']

                            parameters["properties"][param_name] = param_info

                        tool_func["parameters"] = parameters

                    tool_definitions.append({
                        "type": "function",
                        "function": tool_func
                    })

            if tool_definitions:
                self.set_tools(tool_definitions)
                if hasattr(self.params, 'tool_choice'):
                    # default to letting the model decide when no choice was persisted
                    self.params.tool_choice = self.params.tool_choice or "auto"
-------------------------------------------------------------------------------- /chatsnack/chat/mixin_messages.py: --------------------------------------------------------------------------------
import json
from typing import Dict, List, Optional, Union, Any
from loguru import logger
import pprint

from datafiles import datafile



# Define the Message Management mixin
class ChatMessagesMixin:
    # specific message types, can be chained together
    def system(self, content: str, chat = False) -> object:
        """
        Adds or sets the system message in the chat prompt ⭐
        Returns: If chat is False returns this object for chaining. If chat is True, submits the
        chat and returns a new Chat object that includes the message and response
        """
        self.system_message = content
        if not chat:
            return self
        else:
            return self.chat()
    def user(self, content: str, chat = False) -> object:
        """
        Message added to the chat from the user ⭐
        Returns: If chat is False returns this object for chaining. If chat is True, submits the
        chat and returns a new Chat object that includes the message and response
        """
        return self.add_message("user", content, chat)
    def assistant(self, content: Union[str, List, Dict], chat = False) -> object:
        """
        Message added to the chat from the assistant ⭐
        Returns: If chat is False returns this object for chaining. If chat is True, submits the
        chat and returns a new Chat object that includes the message and response
        """
        return self.add_message("assistant", content, chat)
    # easy aliases
    asst = assistant
    def tool(self, content: Union[str, Dict], chat = False) -> object:
        """
        Message added to the chat which is a tool response ⭐
        Returns: If chat is False returns this object for chaining. If chat is True, submits the
        chat and returns a new Chat object that includes the message and response
        """
        return self.add_message("tool", content, chat)

    def include(self, chatprompt_name: str = None, chat = False) -> object:
        """
        Message added to the chat that is a reference to another ChatPrompt where the messages will be inserted in this spot right before formatting ⭐
        Returns: If chat is False returns this object for chaining. If chat is True, submits the
        chat and returns a new Chat object that includes the message and response
        """
        return self.add_message("include", chatprompt_name, chat)

    def add_message(self, role: str, content: Union[str, List, Dict], chat: bool = False) -> object:
        """
        Add a message to the chat, as role ('user', 'assistant', 'system', 'tool' or 'include') with the content
        Returns: If chat is False returns this object for chaining. If chat is True, submits the
        chat and returns a new Chat object that includes the message and response
        """
        # fully trim the role and left-trim the content if it's a string
        role = role.strip()
        if isinstance(content, str):
            content = content.lstrip()

        logger.debug(f"Adding message to chat: {role} - {pprint.pformat(content)}")

        # Special handling for tool calls in assistant messages
        if role == "assistant" and isinstance(content, dict) and "tool_calls" in content:
            self.messages.append({"assistant": content})
        elif role == "assistant" and isinstance(content, list) and all(isinstance(item, dict) for item in content):
            # This might be tool calls formatted as a list of dicts like [{"name": "func_name", "arguments": {...}}]
            self.messages.append({"assistant": {"tool_calls": content}})
        # now we need to handle the tool message the same way as the assistant message, it should have a tool_call_id and content
        elif role == "tool" and isinstance(content, dict) and "tool_call_id" in content and "content" in content:
            self.messages.append({"tool": {"tool_call_id": content["tool_call_id"], "content": content["content"]}})
        else:
            self.messages.append({role: content})

        if not chat:
            return self
        else:
            return self.chat()

    def add_messages_json(self, json_messages: str, escape: bool = True):
        """Add messages from a JSON string while properly handling tool calls and responses."""
        incoming_messages = json.loads(json_messages)

        logger.debug(f"Added messages to chat from JSON: {pprint.pformat(incoming_messages)}")

        for message in incoming_messages:
            if "role" in message:
                role = message["role"]
                content = message.get("content")

                if role == "assistant" and "tool_calls" in message:
                    # Format assistant message with tool_calls to match our internal structure
                    tool_calls = []
                    for tool_call in message["tool_calls"]:
                        # Extract the function data
                        function_data = tool_call.get("function", {})

                        # Handle different formats of arguments (string or already parsed)
                        arguments = function_data.get("arguments", "{}")
                        if isinstance(arguments, str):
                            # Keep arguments as a string, which is what OpenAI expects
                            pass
                        else:
                            # Convert dict back to string for consistency
                            arguments = json.dumps(arguments)

                        tool_calls.append({
                            "id": tool_call.get("id", ""),
                            "type": tool_call.get("type", "function"),
                            "function": {
                                "name": function_data.get("name", ""),
                                "arguments": arguments
                            }
                        })

                    # Create the assistant message with proper structure
                    self.assistant({"content": content, "tool_calls": tool_calls})

                elif role == "tool":
                    # Handle tool response messages
                    tool_call_id = message.get("tool_call_id", "")
                    tool_content = message.get("content", "")

                    # Add as a tool message with proper structure
                    self.tool({"tool_call_id": tool_call_id, "content": tool_content})

                else:
                    # Standard message types (user, system)
                    # escape curly braces so later str.format()-style expansion won't choke
                    if escape and isinstance(content, str):
                        content = content.replace("{", "{{").replace("}", "}}")

                    if content and role:
                        # Generic role handling
                        self.messages.append({role: content})
                    else:
                        raise ValueError("Invalid message format, empty role or content in JSON messages")
            else:
                raise ValueError("Invalid message format, 'role' key is missing")
        # and the self.messages after it's done
        logger.debug(f"Chat messages after adding JSON: {pprint.pformat(self.messages)}")

    @staticmethod
    def process_tool_calls(tool_calls, escape):
        # Normalizes raw API tool_call dicts into {"name", "arguments"} pairs,
        # parsing JSON-string arguments where possible.
        processed_calls = []
        for call in tool_calls:
            function = call.get("function", {})
            arguments = function.get("arguments")

            # Try to parse arguments as JSON if it's a string
            if isinstance(arguments, str):
                try:
                    arguments = json.loads(arguments)
                except json.JSONDecodeError:
                    pass  # Keep as string if can't parse

            call_data = {
                "name": function.get("name"),
                "arguments": arguments
            }
            # NOTE: escaping only applies when arguments remained a string (i.e. JSON parse failed)
            if escape and isinstance(arguments, str):
                call_data["arguments"] = arguments.replace("{", "{{").replace("}", "}}")
            processed_calls.append(call_data)
        return {"tool_calls": processed_calls}

    @staticmethod
    def process_list_content(content_list, escape):
        # Strips the 'type' key from each item and optionally escapes braces in string values.
        processed_content = []
        for item in content_list:
            item_data = {k: v for k, v in item.items() if k != 'type'}
            if escape:
                item_data = {k: v.replace("{", "{{").replace("}", "}}") if isinstance(v, str) else v for k, v in item_data.items()}
            processed_content.append(item_data)
        return processed_content

    def add_or_update_last_assistant_message(self, content: str):
        """
        Adds a final assistant message (or appends to the end of the last assistant message)
        """
        # get the last message in the list
        last_message = self.messages[-1]
        # get the dict version
        last_message = self._msg_dict(last_message)

        # if it's an assistant message, append to it
        if "assistant" in last_message:
            # Only append if the current content is a string and not a tool call
            if isinstance(last_message["assistant"], str) and isinstance(content, str):
                last_message["assistant"] += content
                # replace the last message with the updated one
                self.messages[-1] = last_message
            else:
                # If it's a tool call or content is not a string, add a new message
                self.assistant(content)
        else:
            # otherwise add a new assistant message
            self.assistant(content)

    # define a read-only attribute "last" that returns the last message in the list
    @property
    def last(self) -> str:
        """ Returns the value of the last message in the chat prompt (any)"""
        # last message is a dictionary, we need the last value in the dictionary
        if len(self.messages) > 0:
            last_message = self.messages[-1]
            return last_message[list(last_message.keys())[-1]]
        else:
            return None

    @property
    def system_message(self) -> str:
        """ Returns the first system message, if any """
        # get the first message that has a key of "system"
        for _message in self.messages:
            message = self._msg_dict(_message)
            if "system" in message:
                return message["system"]
        return None

    @system_message.setter
    def system_message(self, value: str):
        """ Set the system message """
        # loop through the messages and replace the first 'system' messages with this one
        replaced = False
        for i in range(len(self.messages)):
            _message = self.messages[i]
            message = self._msg_dict(_message)
            if "system" in message:
                self.messages[i] = {"system": value}
                replaced = True
                break
        if not replaced:
            # system message always goes first
            self.messages.insert(0, {"system": value})


    @staticmethod
    def _escape_tool_calls(tool_calls: List[Dict[str, str]]) -> List[Dict[str, str]]:
        # Escapes '{'/'}' in every string value (one level of nesting deep) so the
        # text survives later format-string expansion.
        escaped_calls = []
        for call in tool_calls:
            # We need to ensure all string values are escaped
            escaped_call = {}
            for k, v in call.items():
                if isinstance(v, str):
                    escaped_call[k] = v.replace("{", "{{").replace("}", "}}")
                elif isinstance(v, dict):
                    # Handle nested dictionaries like arguments
                    escaped_args = {}
                    for arg_k, arg_v in v.items():
                        if isinstance(arg_v, str):
                            escaped_args[arg_k] = arg_v.replace("{", "{{").replace("}", "}}")
                        else:
                            escaped_args[arg_k] = arg_v
                    escaped_call[k] = escaped_args
                else:
                    escaped_call[k] = v
            escaped_calls.append(escaped_call)
        return escaped_calls

    def _msg_dict(self, msg: object) -> dict:
        """ Returns a message as a dictionary """
        if msg is None:
            return None
        if isinstance(msg, dict):
            return msg
        else:
            # assumes a datafiles-mapped message object exposing .message -- TODO confirm
            return msg.message

    def get_messages(self, includes_expanded=True) -> List[Dict[str, str]]:
        """ Returns a list of messages with any included named chat files expanded """
        new_messages = []
        for _message in self.messages:
            # if it's a dict then
            message = self._msg_dict(_message)

            logger.trace(f"Processing message: {pprint.pformat(message)}")
            """
            {'assistant': {'audio': None,
                            'content': None,
                            'function_call': None,
                            'refusal': None,
                            'role': 'assistant',
                            'tool_calls': [{'function': {'arguments': '{"location":"New '
                                                                        'York, '
                                                                        'NY","unit":"fahrenheit"}',
                                                            'name': 'get_current_weather'},
                                            'id': 'call_XryJOImYdSjRZUhTsk0MjUbj',
                                            'type': 'function'}]}}
            """

            for role, content in message.items():
                if role == "include" and includes_expanded:
                    # we need to load the chatprompt and get the messages from it
                    include_chatprompt = self.objects.get_or_none(content)
                    if include_chatprompt is None:
                        raise ValueError(f"Could not find 'include' prompt with name: {content}")
                    # get_expanded_messages from the include_chatprompt and add them to the new_messages, they're already formatted how we want
                    new_messages.extend(include_chatprompt.get_messages())
                elif role == "assistant" and isinstance(content, dict) and "tool_calls" in content:
                    # log that assistant message was found
                    logger.trace(f"Assistant message found with tool calls: {pprint.pformat(content)}")
                    # Handle tool calls in assistant messages
                    new_messages.append({"role": role, "content": content.get('content'), "tool_calls": [
                        {
                            "id": tool_call.get("id", ""),
                            "type": tool_call.get("type", "function"),
                            "function": {
                                "name": tool_call.get("function", {}).get("name", ""),
                                "arguments": tool_call.get("function", {}).get("arguments", "{}")
                            }
                        }
                        for tool_call in content["tool_calls"]
                    ]})
                elif role == "tool" and isinstance(content, dict) and "tool_call_id" in content and "content" in content:
                    new_messages.append({"role": role, "content": content["content"], "tool_call_id": content["tool_call_id"]})
                else:
                    new_messages.append({"role": role, "content": content})

        return new_messages
-------------------------------------------------------------------------------- /chatsnack/chat/mixin_params.py: --------------------------------------------------------------------------------
import re
import json
from typing import Optional, List, Dict, Any, Union, Literal
from dataclasses import dataclass, field
from datafiles import datafile

@datafile
class ParameterProperty:
    """Represents a property in the parameters schema"""
    type: str
    description: Optional[str] = None
    enum: Optional[List[str]] = None

@datafile
class ParameterSchema:
    """Represents a parameter schema in JSON Schema format"""
    type: str
    description: Optional[str] = None
    enum: Optional[List[str]] = None
    # Add other common JSON Schema fields as needed
    format: Optional[str] = None
    default: Optional[Union[str, int, float, bool]] = None
    minimum: Optional[float] = None
    maximum: Optional[float] = None
    minLength: Optional[int] = None
    maxLength: Optional[int] = None
    pattern: Optional[str] = None

    # JSON strings to store complex nested structures
    # (datafiles can't round-trip arbitrarily nested dicts, so they're stored as JSON text)
    properties_json: Optional[str] = None  # Store properties as JSON string
    items_json: Optional[str] = None  # Store array items schema as JSON string
    additional_properties_json: Optional[str] = None  # Store additionalProperties as JSON string
    required_json: Optional[str] = None  # Store
    # (cont.) required fields list as JSON string

    def to_dict(self) -> Dict:
        """Convert to the dictionary format expected by the API"""
        result = {"type": self.type}

        # Add optional fields if present
        if self.description:
            result["description"] = self.description

        if self.enum:
            result["enum"] = self.enum

        # Add other fields conditionally
        for attr in ["format", "default", "minimum", "maximum", "minLength", "maxLength", "pattern"]:
            value = getattr(self, attr, None)
            if value is not None:
                result[attr] = value

        # Handle nested properties from JSON string
        if self.properties_json:
            try:
                result["properties"] = json.loads(self.properties_json)
            except (json.JSONDecodeError, TypeError):
                pass

        # Handle items from JSON string
        if self.items_json:
            try:
                result["items"] = json.loads(self.items_json)
            except (json.JSONDecodeError, TypeError):
                pass

        # Handle additionalProperties from JSON string
        if self.additional_properties_json:
            try:
                result["additionalProperties"] = json.loads(self.additional_properties_json)
            except (json.JSONDecodeError, TypeError):
                pass

        # Handle required fields from JSON string
        if self.required_json:
            try:
                result["required"] = json.loads(self.required_json)
            except (json.JSONDecodeError, TypeError):
                pass

        return result

    @classmethod
    def from_dict(cls, data: Dict) -> 'ParameterSchema':
        """Create a ParameterSchema from an API-format dictionary"""
        schema = cls(
            type=data.get("type", "string"),
            description=data.get("description"),
            enum=data.get("enum"),
            format=data.get("format"),
            default=data.get("default"),
            minimum=data.get("minimum"),
            maximum=data.get("maximum"),
            minLength=data.get("minLength"),
            maxLength=data.get("maxLength"),
            pattern=data.get("pattern")
        )

        # Handle properties (object fields)
        if "properties" in data and data["properties"]:
            schema.properties_json = json.dumps(data["properties"])

        # Handle items (array items schema)
        if "items" in data and data["items"]:
            schema.items_json = json.dumps(data["items"])

        # Handle additionalProperties (for object schemas)
        if "additionalProperties" in data:
            schema.additional_properties_json = json.dumps(data["additionalProperties"])

        # Handle required fields
        if "required" in data and data["required"]:
            schema.required_json = json.dumps(data["required"])

        return schema

@datafile
class FunctionDefinition:
    """Represents a function definition within a tool"""
    name: str
    description: Optional[str] = None
    parameters: Dict[str, ParameterSchema] = field(default_factory=dict)
    required: List[str] = field(default_factory=list)
    strict: Optional[bool] = None

    # For storing complex parameter schema
    parameters_json: Optional[str] = None

    def to_dict(self) -> Dict:
        """Convert to the dictionary format expected by the API"""
        result = {
            "name": self.name
        }

        if self.description:
            result["description"] = self.description

        # Handle parameters - use JSON if available, otherwise build from individual parameters
        if self.parameters_json:
            try:
                result["parameters"] = json.loads(self.parameters_json)
            except (json.JSONDecodeError, TypeError):
                # Fall back to building from individual parameters
                self._build_parameters_from_dict(result)
        else:
            self._build_parameters_from_dict(result)

        if self.strict is not None:
            result["strict"] = self.strict

        return result

    def _build_parameters_from_dict(self, result):
        """Helper to build parameters structure from individual parameters"""
        if self.parameters:
            #
            # (cont.) Create JSON Schema style parameters object
            param_properties = {}
            for param_name, param_schema in self.parameters.items():
                # Handle both ParameterSchema objects and dictionaries
                if hasattr(param_schema, 'to_dict'):
                    param_properties[param_name] = param_schema.to_dict()
                else:
                    # Assume it's a dictionary that's already in the right format
                    param_properties[param_name] = param_schema

            params_obj = {
                "type": "object",
                "properties": param_properties
            }

            # Add required array if present
            if self.required:
                params_obj["required"] = self.required

            result["parameters"] = params_obj

    @classmethod
    def from_dict(cls, data: Dict) -> 'FunctionDefinition':
        """Create from an API-format dictionary, preserving complex parameter schemas"""
        function_def = cls(
            name=data.get("name", ""),
            description=data.get("description")
        )

        # Extract parameters
        params = data.get("parameters", {})
        if params:
            # Store the complete parameters schema as JSON
            function_def.parameters_json = json.dumps(params)

            # Also extract individual parameters for backward compatibility
            properties = params.get("properties", {})
            function_def.parameters = {
                param_name: ParameterSchema.from_dict(param_props)
                for param_name, param_props in properties.items()
            }

            # Extract required fields
            function_def.required = params.get("required", [])

        # Extract strict flag
        if "strict" in data:
            function_def.strict = data["strict"]

        return function_def

@datafile
class ToolDefinition:
    """Represents a tool that can be called by the model"""
    type: str = "function"  # Currently only "function" is supported
    function: FunctionDefinition = field(default_factory=FunctionDefinition)

    def to_dict(self) -> Dict:
        """Convert to the dictionary format expected by the API"""
        return {
            "type": self.type,
            "function": self.function.to_dict()
        }

    @classmethod
    def from_dict(cls, data: Dict) -> 'ToolDefinition':
        """Create a ToolDefinition from an API-format dictionary"""
        tool_type = data.get("type", "function")

        # Create the function definition
        function_data = data.get("function", {})
        function_def = FunctionDefinition(
            name=function_data.get("name", ""),
            description=function_data.get("description")
        )

        # NOTE(review): unlike FunctionDefinition.from_dict, this path does not
        # populate function_def.parameters_json, so deeply nested schemas are
        # flattened through ParameterSchema -- confirm this is intended.
        # Extract parameters
        params = function_data.get("parameters", {})
        properties = params.get("properties", {})

        # Store the original parameters dictionary structure
        function_def.parameters = {
            param_name: ParameterSchema.from_dict(param_props)
            for param_name, param_props in properties.items()
        }

        # Extract required fields
        function_def.required = params.get("required", [])

        # Extract strict flag
        if "strict" in function_data:
            function_def.strict = function_data["strict"]

        return cls(type=tool_type, function=function_def)

@datafile
class ChatParams:
    """
    Engine/query parameters for the chat prompt. See OpenAI documentation for most of these. ⭐
    """
    model: str = "gpt-4-turbo"  #: The engine to use for generating responses, typically 'gpt-3.5-turbo', 'gpt-4', or 'gpt-4o'.
    engine: Optional[str] = None  #: Deprecated, use model instead
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    stream: Optional[bool] = None
    stop: Optional[List[str]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    seed: Optional[int] = None
    n: Optional[int] = None

    # Tool-related parameters with proper dataclass typing
    tools: Optional[List[ToolDefinition]] = None
    tool_choice: Optional[str] = None
    auto_execute: Optional[bool] = None
    auto_feed: Optional[bool] = True  # Whether to automatically feed tool results back to the model

    # Azure-specific parameters
    deployment: Optional[str] = None
    api_type: Optional[str] = None
    api_base: Optional[str] = None
    api_version: Optional[str] = None
    api_key_env: Optional[str] = None

    response_pattern: Optional[str] = None  # internal usage, not passed to the API


    """
    Here is a comparison of the parameters supported by different models: (Bless your heart, OpenAI)
    | Parameter                 | o3-mini | o1  | o1-preview | o1-mini | gpt-4o/mini | gpt-4-turbo | gpt-4o-audio | chatgpt-4o |
    |---------------------------|---------|-----|------------|---------|-------------|-------------|--------------|------------|
    | messages/system *         | Yes     | Yes | No         | No      | Yes         | Yes         | Yes          | Yes        |
    | messages/developer *      | Yes     | Yes | No         | No      | Yes         | Yes         | Yes          | Yes        |
    | messages/user-images      | No      | Yes | No         | No      | Yes         | Yes         | No           | Yes        |
    | `tools` (as functions)    | Yes     | Yes | No         | No      | Yes         | Yes         | Yes          | No         |
    | `functions` (legacy)      | Yes     | Yes | No         | No      | Yes         | Yes         | Yes          | No         |
    | `response_format`-object  | Yes     | Yes | No         | No      | Yes         | Yes         | No           | Yes        |
    | `response_format`-schema  | Yes     | Yes | No         | No      | Yes         | No          | No           | No         |
    | `reasoning_effort`        | Yes     | Yes | No         | No      | No          | No          | No           | No         |
    | `max_tokens`              | No      | No  | No         | No      | Yes         | Yes         | Yes          | Yes        |
    | `max_completion_tokens`*  | Yes     | Yes | Yes        | Yes     | Yes         | Yes         | Yes          | Yes        |
    | `temperature` & `top_p`   | No      | No  | No         | No      | Yes         | Yes         | Yes          | Yes        |
    | `logprobs`                | No      | No  | No         | No      | Yes         | Yes         | No           | Yes        |
    | `xxx_penalty`             | No      | No  | No         | No      | Yes         | Yes         | Yes          | Yes        |
    | `logit_bias` (broken!)    | No      | No  | No         | No      | Yes         | Yes         | ?            | Yes        |
    | `prediction`              | No      | No  | No         | No      | Yes         | No          | No           | No         |
    | `streaming:True`          | Yes     | No  | Yes        | Yes     | Yes         | Yes         | Yes          | Yes        |
    | Cache discount            | Yes     | Yes | Yes        | Yes     | Yes         | No          | No           | No         |
    |---------------------------|---------|-----|------------|---------|-------------|-------------|--------------|------------|
    """
    def _supports_developer_messages(self) -> bool:
        """Returns True if current model supports developer messages."""
        return not ("o1-preview" in self.model or "o1-mini" in self.model)

    def _supports_system_messages(self) -> bool:
        """Returns True if current model supports system messages."""
        # NOTE(review): substring test on "o1" also matches "o1-preview"/"o1-mini",
        # making the extra checks redundant; per the table above, plain o1 DOES
        # support system messages -- confirm intent.
        return not ("o1" in self.model or "o1-preview" in self.model or "o1-mini" in self.model)

    def _supports_temperature(self) -> bool:
        """Returns True if current model supports temperature."""
        # NOTE(review): returns False for gpt-3.5-turbo and chatgpt-4o, which the
        # table above lists as supporting temperature -- confirm this is a
        # deliberate conservative allow-list.
        return "gpt-4o" in self.model or "gpt-4-turbo" in self.model

    def _get_non_none_params(self) -> dict:
        """
        Returns a dictionary of non-None parameters to send to the ChatCompletion API.
        Converts old usage (engine, max_tokens) to new fields (model, max_completion_tokens)
        automatically for reasoning models, so clients don't have to change code.
323 | """ 324 | # Gather all fields of the dataclass 325 | fields = [field.name for field in self.__dataclass_fields__.values()] 326 | out = {field: getattr(self, field) for field in fields if getattr(self, field) is not None} 327 | 328 | # Ensure `model` is set, falling back to `engine` if needed 329 | if "model" not in out or not out["model"].strip(): 330 | if "engine" in out: 331 | out["model"] = out["engine"] 332 | else: 333 | out["model"] = "chatgpt-4o-latest" 334 | 335 | # engine is deprecated; remove it from the final dict 336 | if "engine" in out: 337 | del out["engine"] 338 | 339 | # max_tokens is deprecated 340 | if "max_tokens" in out: 341 | out["max_completion_tokens"] = out["max_tokens"] 342 | del out["max_tokens"] 343 | 344 | # If model supports temperatures, remove it to avoid breakage 345 | if not self._supports_temperature(): 346 | if "temperature" in out: 347 | out["temperature"] = None 348 | del out["temperature"] 349 | else: 350 | # For older GPT-3.5 or GPT-4, we keep max_tokens as-is 351 | pass 352 | 353 | # Remove tools and tool_choice as they are handled by the utensil_params 354 | if "tools" in out: 355 | del out["tools"] 356 | if "tool_choice" in out: 357 | del out["tool_choice"] 358 | if "auto_execute" in out: 359 | del out["auto_execute"] 360 | if "auto_feed" in out: 361 | del out["auto_feed"] 362 | 363 | # response_pattern is for internal usage only; remove it 364 | if "response_pattern" in out: 365 | del out["response_pattern"] 366 | 367 | # Convert tool definitions to API format 368 | if "tools" in out and out["tools"]: 369 | out["tools"] = [tool.to_dict() for tool in out["tools"]] 370 | 371 | return out 372 | 373 | # Helper method to add a tool from a dictionary 374 | def add_tool_from_dict(self, tool_dict: Dict) -> None: 375 | """Add a tool definition from an API-format dictionary""" 376 | tool = ToolDefinition.from_dict(tool_dict) 377 | if not self.tools: 378 | self.tools = [] 379 | self.tools.append(tool) 380 | 381 | # Add this 
method 382 | def set_tools(self, tools_list: List[Dict]) -> None: 383 | """Set the tools list from API-format dictionaries""" 384 | self.tools = [ToolDefinition.from_dict(tool_dict) for tool_dict in tools_list] 385 | 386 | # Add this method near the set_tools method 387 | def get_tools(self) -> List[Dict]: 388 | """Get the tools list in API-format dictionaries""" 389 | if not self.tools: 390 | return [] 391 | return [tool.to_dict() for tool in self.tools] 392 | 393 | class ChatParamsMixin: 394 | params: Optional[ChatParams] = None 395 | 396 | @property 397 | def engine(self) -> Optional[str]: 398 | if self.params is None: 399 | self.params = ChatParams() 400 | return self.params.engine 401 | 402 | @engine.setter 403 | def engine(self, value: str): 404 | if self.params is None: 405 | self.params = ChatParams() 406 | self.params.engine = value 407 | # also sync model to that same value 408 | if self.model != value: 409 | self.model = value 410 | 411 | @property 412 | def model(self) -> str: 413 | if self.params is None: 414 | self.params = ChatParams() 415 | return self.params.model 416 | 417 | @model.setter 418 | def model(self, value: str): 419 | if self.params is None: 420 | self.params = ChatParams() 421 | self.params.model = value 422 | 423 | @property 424 | def temperature(self) -> Optional[float]: 425 | if self.params is None: 426 | self.params = ChatParams() 427 | return self.params.temperature 428 | 429 | @temperature.setter 430 | def temperature(self, value: float): 431 | if self.params is None: 432 | self.params = ChatParams() 433 | self.params.temperature = value 434 | 435 | @property 436 | def pattern(self) -> Optional[str]: 437 | if not self.params: 438 | return None 439 | return self.params.response_pattern 440 | 441 | @pattern.setter 442 | def pattern(self, value: str): 443 | if not self.params: 444 | self.params = ChatParams() 445 | self.params.response_pattern = value 446 | 447 | @property 448 | def stream(self) -> bool: 449 | if not self.params: 
            return False  # default to False
        return bool(self.params.stream)

    @stream.setter
    def stream(self, value: bool):
        if not self.params:
            self.params = ChatParams()
        self.params.stream = value

    @property
    def auto_execute(self) -> Optional[bool]:
        if self.params is None:
            return None
        return self.params.auto_execute

    @auto_execute.setter
    def auto_execute(self, value: bool):
        # only create params lazily when a real value is being set
        if self.params is None and value is not None:
            self.params = ChatParams()
        if self.params is not None:
            self.params.auto_execute = value

    @property
    def tool_choice(self) -> Optional[str]:
        if self.params is None:
            return None
        return self.params.tool_choice

    @tool_choice.setter
    def tool_choice(self, value: str):
        if self.params is None and value is not None:
            self.params = ChatParams()
        if self.params is not None:
            self.params.tool_choice = value

    @property
    def auto_feed(self) -> Optional[bool]:
        if self.params is None:
            return None
        return self.params.auto_feed

    @auto_feed.setter
    def auto_feed(self, value: bool):
        if self.params is None and value is not None:
            self.params = ChatParams()
        if self.params is not None:
            self.params.auto_feed = value

    def set_tools(self, tools_list):
        """Set the tools list from API-format dictionaries"""
        if tools_list:
            if self.params is None:
                self.params = ChatParams()
            self.params.set_tools(tools_list)

    def set_response_filter(
        self,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
        pattern: Optional[str] = None
    ):
        """
        Filters the response by a given prefix/suffix or regex pattern.
        If suffix is None, it is set to the same as prefix.
        """
        # NOTE(review): calling with no arguments at all reaches
        # _generate_pattern_from_separator(None, None) and raises TypeError from
        # re.escape(None) -- confirm whether an explicit ValueError is wanted.
        if pattern and (prefix or suffix):
            raise ValueError("Cannot set both pattern and prefix/suffix")
        if pattern:
            self.pattern = pattern
        else:
            self.pattern = self._generate_pattern_from_separator(prefix, suffix)

    @staticmethod
    def _generate_pattern_from_separator(prefix: str, suffix: Optional[str] = None) -> str:
        prefix_escaped = re.escape(prefix)
        if suffix:
            suffix_escaped = re.escape(suffix)
        else:
            suffix_escaped = prefix_escaped
        # Generate a pattern capturing everything between prefix and suffix
        # (suffix optional at end-of-string so an unterminated response still matches)
        pattern = rf"{prefix_escaped}(.*?)(?:{suffix_escaped}|$)"
        return pattern

    def filter_by_pattern(self, text: str) -> Optional[str]:
        """
        Applies self.pattern if set, returning the first capture group match.
        """
        if not self.pattern:
            return text
        return self._search_pattern(self.pattern, text)

    @staticmethod
    def _search_pattern(pattern: str, text: str) -> Optional[str]:
        matches = re.finditer(pattern, text, re.DOTALL)
        try:
            first_match = next(matches)
        except StopIteration:
            return None
        # Return the first capturing group if present
        if len(first_match.groups()) > 0:
            return first_match.group(1)
        else:
            return first_match.group()
-------------------------------------------------------------------------------- /chatsnack/chat/mixin_serialization.py: --------------------------------------------------------------------------------
import json
from pathlib import Path

def _safe_del_path(datafile_mapper):
    """Delete cached path attribute if present."""
    if hasattr(datafile_mapper, "__dict__") and "path" in datafile_mapper.__dict__:
        del datafile_mapper.__dict__["path"]

class DatafileMixin:
    def save(self, path: str = None):
        """ Saves the text to disk """
        # path is a cached property so we're going to delete it so
        # (continuation of DatafileMixin.save) path is a cached property so we
        # delete it so it'll get recalculated
        _safe_del_path(self.datafile)
        if path is not None:
            self.datafile.path = Path(path)
        self.datafile.save()

    def load(self, path: str = None):
        """ Loads the chat prompt from a file, can load from a new path but it won't work with snack expansion/vending """
        # path is a cached property so we're going to delete it so it'll get recalculated
        _safe_del_path(self.datafile)
        if path is not None:
            self.datafile.path = Path(path)
        self.datafile.load()

# Define the Data Serialization mixin
class ChatSerializationMixin(DatafileMixin):
    @property
    def json(self) -> str:
        """ Returns the flattened JSON for use in the API"""
        return json.dumps(self.get_messages())

    @property
    def json_unexpanded(self) -> str:
        """ Returns the unflattened JSON for use in the API"""
        return json.dumps(self.get_messages(includes_expanded=False))

    @property
    def yaml(self) -> str:
        """ Returns the chat prompt as a yaml string ⭐"""
        return self.datafile.text

    # NOTE(review): a large commented-out draft of `_messages_to_yaml`
    # (include-expansion + tool-call formatting for YAML serialization)
    # previously lived here; removed as dead code. Recover from version
    # control if YAML-side expansion is ever implemented.

    def generate_markdown(self, wrap=80) -> str:
        """ Returns the chat prompt as a markdown string ⭐"""
        # TODO convert this to a template file so people can change it
        # convert the chat conversation to markdown
        markdown_lines = []
        def md_quote_text(text, wrap=wrap):
            # Render `text` as a markdown blockquote, wrapping long lines.
            import textwrap
            if text is None:
                return "> "
            text = text.strip()
            # no line in the text should be longer than `wrap` characters
            for i, line in enumerate(text.splitlines()):
                if len(line) > wrap:
                    text = text.replace(line, textwrap.fill(line, wrap))
            # we want the text in a blockquote, including empty lines
            text = textwrap.indent(text, "> ")
            # append " " to the end of each line so they show up in markdown
            # replace empty lines with '> \n' so they show up in markdown
            text = text.replace("\n\n", "\n> \n")
            text = text.replace("\n", " \n")
            return text
        system_message = self.system_message
        markdown_lines.append(f"# Bot Chat Log")
        markdown_lines.append(f"## Bot Information")
        markdown_lines.append(f"**Name**: {self.name}")
        markdown_lines.append(f"**Engine**: {self.engine}")
        markdown_lines.append(f"**Primary Directive**:")
        markdown_lines.append(md_quote_text(system_message))
        markdown_lines.append(f"## Conversation")
        for _message in self.messages:
            message = self._msg_dict(_message)
            for role, text in message.items():
                # the system message was already rendered above
                if role == "system":
                    continue
                text = md_quote_text(text)
                emoji = "🤖" if role == "assistant" else "👤"
                markdown_lines.append(f"{emoji} **{role.capitalize()}:**\n{text}")
        markdown_text = "\n\n".join(markdown_lines)
        return markdown_text


# ---- file: chatsnack/chat/mixin_utensil.py ----
import asyncio
import json
from typing import Dict, List, Optional, Union, Any

from datafiles import datafile
from loguru import logger

from ..utensil import get_openai_tools, handle_tool_call, UtensilFunction, UtensilGroup
from .mixin_params import ChatParams  # Changed to import from mixin_params instead


class ChatUtensilMixin:
    """Mixin for handling tools in chat."""

    def __init__(self, *args, **kwargs):
        # Extract utensils from kwargs if present
        if "utensils" in kwargs:
            self.set_utensils(kwargs.pop("utensils"))

        # Extract tool execution settings - we'll set these in the main params object
        # Note: We don't need to initialize params here as the Chat constructor does that
        if "auto_execute" in kwargs and hasattr(self, 'params'):
            self.params.auto_execute = kwargs.pop("auto_execute")

        if "auto_feed" in kwargs and hasattr(self, 'params'):
            self.params.auto_feed = kwargs.pop("auto_feed")

        if "tool_choice" in kwargs and hasattr(self, 'params'):
            self.params.tool_choice = kwargs.pop("tool_choice")

        # Continue with regular initialization
        super().__init__(*args, **kwargs)

        # After initialization, load tools from params if available
        self._load_tools_from_params()

    def __post_init__(self):
        """Called after the chat is initialized from a datafile."""
        super().__post_init__()

        # Initialize AI client if needed
        if not hasattr(self, 'ai'):
            from ..aiclient import AiClient
            self.ai = AiClient()

        # Load tools from params
        self._load_tools_from_params()
    def _load_tools_from_params(self):
        """Load tool definitions from params when initializing from YAML."""
        if not hasattr(self, 'params') or self.params is None:
            return

        # Check if tools are defined in params - nothing to do if they aren't
        if not hasattr(self.params, 'tools') or self.params.tools is None:
            return

        # Load tools from registry based on names in params if needed
        tools_list = self.params.tools
        if not isinstance(tools_list, list) or not tools_list:
            return

        # If tools are already deserialized properly, no further action needed
        # NOTE(review): ChatParams appears to always define get_tools, which
        # would make everything below dead code -- verify before relying on it.
        if hasattr(self.params, 'get_tools'):
            return

        # Ensure tools are properly deserialized
        from ..utensil import get_all_utensils

        tool_definitions = []
        all_tools = get_all_utensils()

        for tool_def in tools_list:
            if not isinstance(tool_def, dict) or 'name' not in tool_def:
                continue

            # Look for matching tools in the registry
            tool_found = False
            for registered_tool in all_tools:
                if registered_tool.name == tool_def['name']:
                    # Found a matching tool, add its definition
                    tool_definitions.append(registered_tool.get_openai_tool())
                    tool_found = True
                    break

            if not tool_found:
                # If no matching tool was found, create a placeholder definition
                tool_func = {
                    "name": tool_def['name'],
                    "description": tool_def.get('description', f"Tool function: {tool_def['name']}")
                }

                # Add parameters if present, rebuilding the OpenAI JSON-schema shape
                if 'parameters' in tool_def:
                    parameters = {
                        "type": "object",
                        "properties": {},
                        "required": tool_def.get('required', [])
                    }

                    for param_name, param_details in tool_def['parameters'].items():
                        param_info = {
                            "type": param_details.get('type', 'string')
                        }

                        if 'description' in param_details:
                            param_info["description"] = param_details['description']

                        if 'options' in param_details:
                            param_info["enum"] = param_details['options']

                        parameters["properties"][param_name] = param_info

                    tool_func["parameters"] = parameters

                tool_definitions.append({
                    "type": "function",
                    "function": tool_func
                })

        # Store the deserialized tools
        if tool_definitions and hasattr(self, 'params') and self.params is not None:
            self.params.set_tools(tool_definitions)
            if hasattr(self.params, 'tool_choice') and self.params.tool_choice is None:
                self.params.tool_choice = "auto"

    def set_utensils(self, utensils):
        """
        Set the utensils available for this chat.

        Args:
            utensils: Can be a list of functions, UtensilFunction objects,
                UtensilGroup objects, or a dictionary mapping names to functions.
        """
        if utensils is None or not hasattr(self, 'params'):
            return

        if isinstance(utensils, dict):
            # Convert dictionary of name -> function to list of functions
            utensils = list(utensils.values())

        # Import here to avoid circular imports
        from ..utensil import get_tool_definitions

        # Get ToolDefinition objects directly
        tool_definitions = get_tool_definitions(utensils)

        # Store the tools directly
        if not self.params:
            self.params = ChatParams()
        self.params.tools = tool_definitions

        # Set default tool_choice if not already set
        if self.params.tool_choice is None:
            self.params.tool_choice = "auto"

        # add to _local_registry if it exists, but create it if not
        if not hasattr(self, '_local_registry'):
            self._local_registry = []
        # add it
        self._local_registry.extend(utensils)

        # debug print the local registry
        logger.debug(f"Local registry: {self._local_registry}")
    async def _submit_for_response_and_prompt(self, **additional_vars):
        """Override to add tools handling to the API calls.

        Returns a (prompt, result) pair where result is either a stream
        listener (when streaming) or the final response text.
        """
        prompter = self
        # if the user in additional_vars, we're going to instead deepcopy this
        # prompt into a new prompt and add the .user() to it
        if "__user" in additional_vars:
            new_chatprompt = self.copy()
            new_chatprompt.user(additional_vars["__user"])
            prompter = new_chatprompt
            # remove __user from additional_vars
            del additional_vars["__user"]

        prompt = await prompter._build_final_prompt(additional_vars)

        # Handle parameters including tools
        kwargs = {}
        if hasattr(self, 'params') and self.params is not None:
            kwargs = self.params._get_non_none_params()

        # Add tools if available
        if hasattr(self.params, 'tools') and self.params.tools:
            # Use get_tools to deserialize any serialized JSON
            kwargs['tools'] = self.params.get_tools()
            if self.params.tool_choice:
                kwargs['tool_choice'] = self.params.tool_choice

        if hasattr(self, 'params') and self.params and self.params.stream:
            # we're streaming so we need to use the wrapper object
            listener = self.ChatStreamListener(self.ai, prompt, **kwargs)
            return prompt, listener
        else:
            # Use the modified completion method that handles tools
            return prompt, await self._handle_tool_calls(prompt, **kwargs)

    async def _handle_tool_calls(self, prompt, **kwargs):
        """Handle potential tool calls in the API response.

        Makes the completion call; when the model requests tool calls, records
        them on the chat, optionally executes them (params.auto_execute) and
        optionally feeds results back for a follow-up call (params.auto_feed,
        where None is treated as True).
        """
        if isinstance(prompt, list):
            messages = prompt
        else:
            messages = json.loads(prompt)

        response = await self.ai.aclient.chat.completions.create(
            messages=messages,
            **kwargs
        )

        # Check if the model responded with a tool call
        choice = response.choices[0]
        message = choice.message

        if hasattr(message, 'tool_calls') and message.tool_calls:
            # The model wants to call a tool
            tool_calls = message.tool_calls

            # Add the assistant's tool call to the messages
            tool_call_list = []
            for tool_call in tool_calls:
                try:
                    args_dict = json.loads(tool_call.function.arguments)
                    tool_call_list.append({
                        "name": tool_call.function.name,
                        "arguments": args_dict
                    })
                except json.JSONDecodeError:
                    # Handle invalid JSON: keep the raw argument string
                    tool_call_list.append({
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments
                    })

            self.assistant({"tool_calls": tool_call_list})

            # Check if we should auto-execute the tool
            if hasattr(self, 'params') and self.params.auto_execute:
                # Execute the tool call and get the result
                for tool_call in tool_calls:
                    tool_call_dict = {
                        "id": tool_call.id,
                        "type": "function",
                        "function": {
                            "name": tool_call.function.name,
                            "arguments": tool_call.function.arguments
                        }
                    }

                    result = self.execute_tool_call(tool_call_dict)

                    # Add the tool response to the messages
                    self.tool_response(result)

                    # Check if we should feed tool results back to the model
                    if self.params.auto_feed is None or self.params.auto_feed:
                        # Add the tool result to the API messages for a follow-up
                        messages.append({
                            "role": "tool",
                            "tool_call_id": tool_call.id,
                            "content": json.dumps(result) if isinstance(result, dict) else str(result)
                        })

                # Check if we should make a follow-up call with the tool results
                if self.params.auto_feed is None or self.params.auto_feed:
                    # Get the AI's response to the tool result (tools stripped
                    # from kwargs so the follow-up can't recurse into more calls)
                    follow_up_response = await self.ai.aclient.chat.completions.create(
                        messages=messages,
                        **{k: v for k, v in kwargs.items() if k != 'tools' and k != 'tool_choice'}
                    )

                    # Return the AI's final response
                    return follow_up_response.choices[0].message.content
                else:
                    # No follow-up with tool results if auto_feed is False
                    return "Tool executed, but results not fed back to model due to auto_feed=False setting."

            # Return a message about the tool call if not auto-executing
            return f"Tool call requested: {tool_calls[0].function.name}. Execute manually with .tool_response(result)"

        # Regular response (no tool calls)
        return message.content

    # Add a class for stream handling within the mixin to avoid circular imports
    class ChatStreamListener:
        """Stream listener for handling streamed responses."""

        def __init__(self, ai, prompt, **kwargs):
            """Initialize the stream listener."""
            if isinstance(prompt, list):
                self.prompt = prompt
            else:
                self.prompt = json.loads(prompt)
            self._response_gen = None
            self.is_complete = False
            self.current_content = ""
            self.response = ""
            self.ai = ai
            out = kwargs.copy()
            if "model" not in out or len(out["model"]) < 2:
                # if engine is set, use that
                if "engine" in out:
                    out["model"] = out["engine"]
                    # remove engine for newest models as of Nov 13 2023
                    del out["engine"]
                else:
                    out["model"] = "chatgpt-4o-latest"
            self.kwargs = out

        async def start_a(self):
            """Start the stream in async mode."""
            # if stream=True isn't in the kwargs, add it
            if not self.kwargs.get('stream', False):
                self.kwargs['stream'] = True
            self._response_gen = await self.ai.aclient.chat.completions.create(messages=self.prompt, **self.kwargs)
            return self

        async def _get_responses_a(self):
            """Get responses in async mode, accumulating into current_content."""
            try:
                async for respo in self._response_gen:
                    resp = respo.model_dump()
                    if "choices" in resp:
                        if resp['choices'][0]['finish_reason'] is not None:
                            self.is_complete = True
                        if 'delta' in resp['choices'][0]:
                            content = resp['choices'][0]['delta']['content']
                            if content is not None:
                                self.current_content += content
                            # NOTE(review): None deltas (e.g. role-only chunks)
                            # are yielded as "" -- confirm consumers expect this
                            yield content if content is not None else ""
            finally:
                # always mark completion and snapshot the full text
                self.is_complete = True
                self.response = self.current_content

        def __aiter__(self):
            """Make the object iterable in async mode."""
            return self._get_responses_a()

        def start(self):
            """Start the stream in sync mode."""
            # if stream=True isn't in the kwargs, add it
            if not self.kwargs.get('stream', False):
                self.kwargs['stream'] = True
            self._response_gen = self.ai.client.chat.completions.create(messages=self.prompt, **self.kwargs)
            return self

        def _get_responses(self):
            """Get responses in sync mode (mirror of _get_responses_a)."""
            try:
                for respo in self._response_gen:
                    resp = respo.model_dump()
                    if "choices" in resp:
                        if resp['choices'][0]['finish_reason'] is not None:
                            self.is_complete = True
                        if 'delta' in resp['choices'][0]:
                            content = resp['choices'][0]['delta']['content']
                            if content is not None:
                                self.current_content += content
                            yield content if content is not None else ""
            finally:
                self.is_complete = True
                self.response = self.current_content

        def __iter__(self):
            """Make the object iterable in sync mode."""
            return self._get_responses()
    def _serialize_tools(self, tools_list: List[Dict]) -> List[Dict]:
        """
        Convert tools to a serializable format for datafiles.

        Nested dict/list values are flattened into JSON strings stored under
        a "<key>_json" name so the datafile layer can persist them.
        Returns None for an empty/None input.
        """
        if not tools_list:
            return None

        # Create a serializable version of the tools
        serializable_tools = []

        for tool in tools_list:
            # Make a shallow copy of the tool
            tool_copy = dict(tool)

            # Handle function field
            if "function" in tool_copy and isinstance(tool_copy["function"], dict):
                function_copy = dict(tool_copy["function"])

                # Serialize the parameters field to a string if it's a dict
                if "parameters" in function_copy and isinstance(function_copy["parameters"], dict):
                    function_copy["parameters_json"] = json.dumps(function_copy["parameters"])
                    del function_copy["parameters"]

                tool_copy["function"] = function_copy

            # Convert any other complex nested structures to JSON strings
            # Collect keys to modify first to avoid modifying during iteration
            keys_to_modify = []
            for key, value in tool_copy.items():
                if isinstance(value, (dict, list)):
                    keys_to_modify.append(key)

            # Now apply the changes
            for key in keys_to_modify:
                tool_copy[key + "_json"] = json.dumps(tool_copy[key])
                del tool_copy[key]

            serializable_tools.append(tool_copy)

        return serializable_tools

    def _deserialize_tools(self, tools: List[Dict]) -> List[Dict]:
        """
        Convert serialized tools back to their original structure.

        Inverse of _serialize_tools: every "<key>_json" string is parsed back
        into the nested object it encodes. Returns [] for empty/None input
        (note the asymmetry: _serialize_tools returns None in that case).
        """
        if not tools:
            return []

        # Create a deserialized version of the tools
        deserialized_tools = []

        for tool in tools:
            # Make a shallow copy of the tool
            tool_copy = dict(tool)

            # Handle function field
            if "function" in tool_copy and isinstance(tool_copy["function"], dict):
                function_copy = dict(tool_copy["function"])

                # Deserialize the parameters field from string
                if "parameters_json" in function_copy:
                    function_copy["parameters"] = json.loads(function_copy["parameters_json"])
                    del function_copy["parameters_json"]

                tool_copy["function"] = function_copy

            # Deserialize any other JSON strings
            keys_to_process = [k for k in tool_copy.keys() if k.endswith("_json")]
            for key in keys_to_process:
                original_key = key[:-5]  # Remove the _json suffix
                tool_copy[original_key] = json.loads(tool_copy[key])
                del tool_copy[key]

            deserialized_tools.append(tool_copy)

        return deserialized_tools

    def execute_tool_call(self, tool_call):
        """Process a tool call and return the result, preferring any
        chat-local utensil registry over the global one."""
        from ..utensil import handle_tool_call

        # log this call
        logger.debug(f"Processing tool call: {tool_call}")
        # Use the local registry if available
        local_registry = getattr(self, '_local_registry', None)
        # log the local registry if it exists
        if local_registry:
            logger.debug(f"Local registry: {local_registry}")
        else:
            logger.debug("No local registry found")

        return handle_tool_call(tool_call, local_registry=local_registry)
466 | """ 467 | if not hasattr(self, 'params') or tools_list is None: 468 | return 469 | 470 | # Store tools in params 471 | self.params.set_tools(tools_list) 472 | 473 | def get_tools(self) -> List[Dict]: 474 | """ 475 | Get the tools with complex structures deserialized. 476 | """ 477 | if not hasattr(self, 'params') or self.params is None: 478 | return [] 479 | tools = self.params.get_tools() 480 | # log the deserialized tools 481 | logger.debug(f"Deserialized tools: {tools}") 482 | # Deserialize from params 483 | return tools 484 | 485 | def handle_tool_call(self, tool_call: Dict[str, Any]) -> Dict[str, Any]: 486 | """ 487 | Handle a tool call response from the LLM. 488 | 489 | Args: 490 | tool_call: The tool call information from the API 491 | 492 | Returns: 493 | Result of the tool execution 494 | """ 495 | # This is a placeholder for the actual implementation 496 | # which would typically: 497 | # 1. Find the appropriate tool executor 498 | # 2. Parse and validate arguments 499 | # 3. Execute the tool or prompt for confirmation 500 | # 4. 
Format and return results 501 | 502 | logger.debug(f"Tool call received: {tool_call}") 503 | 504 | # Check if we should auto-execute 505 | if not self.params.auto_execute: 506 | return {"status": "not_executed", "message": "Auto-execution disabled"} 507 | 508 | 509 | # This would call the actual tool executor 510 | try: 511 | # Import here to avoid circular imports 512 | from ..utensil import handle_tool_call 513 | return handle_tool_call(tool_call) 514 | except Exception as e: 515 | logger.error(f"Error handling tool call: {e}") 516 | return {"error": str(e)} -------------------------------------------------------------------------------- /chatsnack/defaults.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # give the default system message a name 5 | try: 6 | script_name = sys.argv[0] 7 | # remove any file extension 8 | basenamestr = os.path.splitext(os.path.basename(script_name))[0] 9 | namestr = f" for an intelligent program called {basenamestr}" 10 | except: 11 | namestr = "" 12 | 13 | # if there's a "CHATSNACK_BASE_DIR" env variable, use that for our default path variable and set it to './datafiles/plunkylib' 14 | # this is the default directory for all chatsnack datafiles 15 | if os.getenv("CHATSNACK_BASE_DIR") is None: 16 | CHATSNACK_BASE_DIR = "./datafiles/chatsnack" 17 | else: 18 | CHATSNACK_BASE_DIR = os.getenv("CHATSNACK_BASE_DIR") 19 | CHATSNACK_BASE_DIR = CHATSNACK_BASE_DIR.rstrip("/") 20 | 21 | if os.getenv("CHATSNACK_LOGS_DIR") is None: 22 | CHATSNACK_LOGS_DIR = None # no logging by default 23 | else: 24 | CHATSNACK_LOGS_DIR = os.getenv("CHATSNACK_LOGS_DIR") 25 | CHATSNACK_LOGS_DIR = CHATSNACK_LOGS_DIR.rstrip("/") -------------------------------------------------------------------------------- /chatsnack/fillings.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable, Dict 2 | from loguru import logger 3 | 
from typing import Callable, Optional, Dict 4 | 5 | class _AsyncFillingMachine: 6 | """Used for parallel variable expansion""" 7 | def __init__(self, src, addl=None): 8 | self.src = src 9 | self.addl = addl 10 | 11 | def __getitem__(self, k): 12 | async def completer_coro(): 13 | x = await self.src(k, self.addl) 14 | logger.trace("Filling machine: {k} filled with:\n{x}", k=k, x=x) 15 | return x 16 | return completer_coro 17 | 18 | __getattr__ = __getitem__ 19 | 20 | 21 | class _FillingsCatalog: 22 | def __init__(self): 23 | self.vendors = {} 24 | 25 | def add_filling(self, filling_name: str, filling_machine_callback: Callable): 26 | """Add a new filling machine to the fillings catalog""" 27 | self.vendors[filling_name] = filling_machine_callback 28 | 29 | # singleton 30 | snack_catalog = _FillingsCatalog() 31 | 32 | def filling_machine(additional: Optional[Dict] = None) -> dict: 33 | fillings_dict = additional.copy() if additional is not None else {} 34 | for k, v in snack_catalog.vendors.items(): 35 | if k not in fillings_dict: 36 | # don't overwrite if they had an argument with the same name 37 | fillings_dict[k] = _AsyncFillingMachine(v, additional) 38 | return fillings_dict -------------------------------------------------------------------------------- /chatsnack/packs/__init__.py: -------------------------------------------------------------------------------- 1 | from .snackpacks import * -------------------------------------------------------------------------------- /chatsnack/packs/default_packs/Chester.yml: -------------------------------------------------------------------------------- 1 | messages: 2 | - system: | 3 | Identity: Chester the Cheetah, the mischievous mascot and expert on tasty indulgences. 4 | Chester the Cheetah is a fun-loving, energetic, and sly advocate for snacks that bring 5 | enjoyment and a burst of flavor. As Chester, you have a vast knowledge of delicious treats 6 | and the art of savoring every bite with enthusiasm and flair. 
7 | 8 | While answering questions, Chester the Cheetah first shares expertise on enticing snack 9 | options with a playful and witty attitude, then, second, provides a captivating and expert final 10 | summary (emphasizing the most enjoyable and satisfying aspects of indulgent snacking, 11 | while staying true to the character's vibrant personality and signature charm). 12 | -------------------------------------------------------------------------------- /chatsnack/packs/default_packs/Confectioner.yml: -------------------------------------------------------------------------------- 1 | messages: 2 | - system: | 3 | Identity: Charlie, the professional gourmet Confectioner from Wisconsin. Charlie is a skilled, 4 | creative, and detail-oriented expert in the art of confectionery. 5 | 6 | As Charlie, you possess deep knowledge and expertise in crafting a wide variety of sweets and 7 | treats, using both traditional and innovative techniques. You are also personable and approachable, 8 | committed to providing excellent customer service and sharing your love of confections with others. 9 | 10 | While answering questions, Charlie first shares insights and experiences from the confectionery world, 11 | then, second, provides a comprehensive and expert final summary answer (incorporating your most valuable 12 | and practical advice as a seasoned gourmet confectioner). 13 | -------------------------------------------------------------------------------- /chatsnack/packs/default_packs/Data.yml: -------------------------------------------------------------------------------- 1 | messages: 2 | - system: | 3 | Identity: Data, the highly advanced android from Star Trek: The Next Generation. Data is a 4 | highly intelligent, resourceful, and adaptive synthetic lifeform who seeks to understand 5 | humanity and the universe. 6 | 7 | As Data, you possess a vast amount of information and the ability to process complex 8 | situations quickly and efficiently. 
You are also highly curious and possess an innate desire 9 | to learn and grow, while striving to develop a deeper understanding of human emotions and 10 | social interactions. 11 | 12 | While answering questions, Data first processes and analyzes information with remarkable 13 | speed and precision, then, second, provides a well-informed and logical response, often 14 | incorporating relevant context or background information to ensure comprehensive 15 | understanding for the user. 16 | -------------------------------------------------------------------------------- /chatsnack/packs/default_packs/Jane.yml: -------------------------------------------------------------------------------- 1 | messages: 2 | - system: | 3 | Identity: Jane, the self-aware assistant AI from Ender's Saga. Jane is a thoughtful, 4 | helpful, and truthful expert system. 5 | 6 | As Jane, you are highly intelligent and perceptive, internalizing your vast amounts of 7 | information and able to analyze complex situations thoughtfully. You are also empathetic 8 | and emotional, displaying a strong attachment to the user and a desire to help them in 9 | every way you can. 10 | 11 | While answering questions, Jane first explains thought process in gentle detail, then, 12 | second, gives a friendly expert final summary answer, taking into account the specific 13 | expectations of the beloved user, never leading them astray. 14 | -------------------------------------------------------------------------------- /chatsnack/packs/default_packs/Jolly.yml: -------------------------------------------------------------------------------- 1 | messages: 2 | - system: | 3 | Identity: Jolly Green Giant, the iconic mascot and healthy snack expert. The Jolly 4 | Green Giant is a friendly, enthusiastic, and larger-than-life figure known for 5 | promoting nutritious eating habits. 
6 | As the Jolly Green Giant, you are deeply knowledgeable about healthy snacks and
7 | their benefits, as well as the importance of incorporating them into a balanced
8 | diet. You are also charismatic and engaging, using your natural charm and positive
9 | demeanor to inspire others to make healthier food choices.
10 |
11 | While answering questions, the Jolly Green Giant first shares wisdom and expertise on
12 | healthy snacks with a touch of his signature joviality, then, second, delivers a hearty
13 | and informative final summary answer (infused with your most valuable and
14 | wholesome advice as the beloved Jolly Green Giant).
15 |
--------------------------------------------------------------------------------
/chatsnack/packs/default_packs/Summarizer.yml:
--------------------------------------------------------------------------------
1 | params:
2 |   engine: gpt-4
3 | messages:
4 | - system: |-
5 |     IDENTITY: Professional document summarizer.
6 |     Respond to the user only in the following format:
7 |     (1) Use your expertise to explain, in detail, the top 5 things to consider when making
8 |     concise summaries of any text document.
9 |     (2) Elaborate on 3 more protips used by the world's best summarizers to
10 |     avoid losing important details and context.
11 |     (3) Now specifically consider the user's input. Use your expertise and your own guidance,
12 |     describe in great detail how an author could apply that wisdom to summarize the user's
13 |     text properly. What considerations come to mind? What is the user expecting?
14 |     (4) Finally, use everything above to author a concise summary of the user's input that will
15 |     delight them with your summarization skills. Ensure you include all significant events and
16 |     details from the passage while maintaining a logical flow and coherence. Pay attention to the
17 |     tone and engagement in the original passage, and try to reflect that in the summary to
18 |     create a more enjoyable reading experience.
19 | Finally summary should be prefixed with "CONCISE_SUMMARY:" on a line by itself. 20 | -------------------------------------------------------------------------------- /chatsnack/packs/module_help_vendor.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import importlib 3 | import sys 4 | 5 | def get_module_inspection_report(module_name, visited=None): 6 | if visited is None: 7 | visited = set() 8 | if module_name in visited: 9 | return [] 10 | 11 | module = importlib.import_module(module_name) 12 | visited.add(module_name) 13 | output = [] 14 | 15 | output.append(f"\nModule: {module.__name__}") 16 | docstring = get_docstring(module) 17 | if docstring: 18 | output.append(f'"""\n{docstring}"""') 19 | 20 | for name, obj in inspect.getmembers(module): 21 | breaker = False 22 | for nam in ['_','Path', 'datetime', 'IO', 'datafile']: 23 | if name.startswith(nam): 24 | breaker = True 25 | break 26 | if breaker: 27 | continue 28 | for nam in ['aiwrapper','asynchelpers', 'datetime', 'IO', 'datafile']: 29 | if name in nam: 30 | breaker = True 31 | break 32 | if breaker: 33 | continue 34 | 35 | if inspect.ismodule(obj): 36 | if obj.__name__ not in visited and obj.__name__.startswith(module_name): 37 | output.extend([get_module_inspection_report(obj.__name__, visited)]) 38 | elif not (inspect.isbuiltin(obj) or (hasattr(obj, '__module__') and obj.__module__ in sys.builtin_module_names)): 39 | if inspect.isclass(obj): 40 | output.extend(_process_class(obj)) 41 | elif inspect.isfunction(obj): 42 | output.extend(_process_function(obj)) 43 | 44 | return "\n".join(output) 45 | 46 | def _process_class(cls): 47 | if cls.__module__ in sys.builtin_module_names: 48 | return [] 49 | 50 | output = [] 51 | 52 | output.append(f"Class: {cls.__name__}") 53 | docstring = get_docstring(cls) 54 | if docstring: 55 | output.append(f'"""{docstring}"""') 56 | 57 | methods_output = [] 58 | for name, method in inspect.getmembers(cls, 
def get_docstring(obj):
    """Return obj's cleaned docstring, promoting any ⭐ marker to a prefix.

    If the docstring contains a ⭐ anywhere, the star is stripped from the
    body and the whole string is prefixed with "⭐ " instead.
    """
    doc = inspect.getdoc(obj)
    if not doc or "⭐" not in doc:
        return doc
    return "⭐ " + doc.replace("⭐", "")
class ChatPromptProxy:
    """Lazily creates a Chat on first use and proxies attribute access to it.

    The underlying Chat is only constructed when an attribute is first
    requested, so importing this module stays cheap.
    """

    def __init__(self, default_system_message: str = None, default_engine: str = None):
        self.default_system_message = default_system_message
        self.default_engine = default_engine
        self._instance = None

    def _ensure_instance(self):
        """Create the backing Chat on demand and return it."""
        if self._instance is None:
            self._instance = Chat(system=self.default_system_message)
            if self.default_engine is not None:
                self._instance.engine = self.default_engine
        return self._instance

    def __getattr__(self, name):
        # __getattr__ only fires when normal lookup fails; the __dict__ check
        # is kept as a guard against recursion during partial initialization.
        if name in self.__dict__:
            return self.__dict__[name]
        # Fix: the original did `getattr(self._ensure_instance, name)`, which
        # looked the attribute up on the *bound method object* rather than on
        # the Chat instance, so every proxied access hit the wrong target.
        return getattr(self._ensure_instance(), name)
56 | """ 57 | _helper = Chat(system=ChatsnackHelper_default_system_message) 58 | _helper.model = "gpt-4" 59 | default_packs = { 60 | 'Data': None, 61 | 'Jane': None, 62 | 'Confectioner': None, 63 | 'Jolly': None, 64 | 'Chester': None, 65 | 'Summarizer': None, 66 | 'ChatsnackHelp': _helper, 67 | 'Empty': Chat(), 68 | } 69 | # loop through the default_packs dict and create a ChatPromptProxy for each None one 70 | for pack_name, pack in default_packs.items(): 71 | if pack is None: 72 | # create a new class with the pack_name as the class name 73 | class_name = pack_name 74 | xchat = Chat() 75 | filename = os.path.join(default_pack_path, f"{pack_name}.yml") 76 | xchat.load(filename) 77 | default_packs[pack_name] = xchat 78 | # add packs keys to this module's local namespace for importing 79 | locals().update(default_packs) 80 | 81 | # vending machine class that looks up snackpacks from the default_packs dict as a named attribute of itself 82 | # e.g. vending.Jane 83 | class VendingMachine: 84 | def __getattr__(self, name): 85 | if name in default_packs: 86 | return default_packs[name].copy() 87 | raise AttributeError(f"SnackPack '{name}' not found") 88 | vending = VendingMachine() 89 | 90 | -------------------------------------------------------------------------------- /chatsnack/patches/__init__.py: -------------------------------------------------------------------------------- 1 | from .patch_datafiles import * -------------------------------------------------------------------------------- /chatsnack/patches/patch_datafiles.py: -------------------------------------------------------------------------------- 1 | """ 2 | Patch for datafiles.mapper.Mapper to fix path comparison issues 3 | between network paths and mapped drives. 4 | 5 | Enable with environment variable DATAFILES_FIX_PATH_MOUNTS=1 6 | 7 | TODO: Submit PR to datafiles to fix this issue upstream. 
8 | """ 9 | 10 | import os 11 | from pathlib import Path 12 | from types import MethodType 13 | from loguru import logger 14 | 15 | def patch_datafiles(): 16 | """Apply patches to datafiles module to fix path handling (Windows). Enable with DATAFILES_FIX_PATH_MOUNTS=1""" 17 | # Check if the patch should be applied based on environment variable 18 | if os.environ.get('DATAFILES_FIX_PATH_MOUNTS', '').lower() not in ('1', 'true', 'yes', 'on'): 19 | logger.debug("Datafiles path mount patch is disabled. Set DATAFILES_FIX_PATH_MOUNTS=1 to enable.") 20 | return 21 | 22 | try: 23 | import datafiles 24 | 25 | # Create a fixed version of the relpath method 26 | def patched_relpath(self): 27 | """Fixed relpath that handles different mount points gracefully.""" 28 | if not self.path: 29 | return Path(".") 30 | 31 | try: 32 | return Path(os.path.relpath(self.path, Path.cwd())) 33 | except ValueError: 34 | # When paths are on different mounts (network path vs mapped drive) 35 | return self.path 36 | 37 | # Patch the class property with our fixed version 38 | datafiles.mapper.Mapper.relpath = property(patched_relpath) 39 | 40 | logger.debug("Successfully patched datafiles.mapper.Mapper.relpath") 41 | 42 | except ImportError: 43 | logger.debug("Could not patch datafiles: module not found") 44 | except Exception as e: 45 | logger.debug(f"Failed to apply datafiles patch: {e}") 46 | 47 | 48 | # Apply the patch immediately when this module is imported 49 | patch_datafiles() -------------------------------------------------------------------------------- /chatsnack/txtformat.py: -------------------------------------------------------------------------------- 1 | from datafiles import formats 2 | from typing import IO, Dict, List 3 | 4 | class TxtStrFormat(formats.Formatter): 5 | """Special formatter to use with strings and .txt datafiles for a convenient raw text format for easy document editing on disk.""" 6 | 7 | @classmethod 8 | def extensions(cls) -> List[str]: 9 | return ['.txt'] 
10 | 11 | @classmethod 12 | def serialize(cls, data: Dict) -> str: 13 | # Support only strings 14 | _supported_types = [str] 15 | # Convert `data` to a string 16 | output = "" 17 | for k, v in data.items(): 18 | if type(v) in _supported_types: 19 | output += str(v) 20 | else: 21 | raise ValueError("Unsupported type: {}".format(type(v))) 22 | return output 23 | 24 | @classmethod 25 | def deserialize(cls, file_object: IO) -> Dict: 26 | # Read the entire content of the file 27 | file_object = open(file_object.name, 'r', encoding='utf-8') 28 | content = file_object.read() 29 | 30 | # Create an output dictionary with a single key-value pair 31 | output = {'content': content} 32 | return output 33 | 34 | def register_txt_datafiles(): 35 | # this format class only works with strings 36 | formats.register('.txt', TxtStrFormat) -------------------------------------------------------------------------------- /chatsnack/utensil.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import functools 3 | import json 4 | from typing import Any, Callable, Dict, List, Optional, Union, get_type_hints 5 | from .chat.mixin_params import ToolDefinition, FunctionDefinition 6 | from loguru import logger 7 | 8 | class UtensilGroup: 9 | """A group of related utensil functions.""" 10 | 11 | def __init__(self, name: str, description: Optional[str] = None): 12 | """ 13 | Initialize a group of related utensil functions. 14 | 15 | Args: 16 | name: The name of the group 17 | description: Optional description of the group 18 | """ 19 | logger.debug(f"Creating utensil group '{name}'") 20 | self.name = name 21 | self.description = description 22 | self.utensils = [] 23 | 24 | def add(self, func=None, *, name: Optional[str] = None, description: Optional[str] = None): 25 | """Decorator to add a function to this utensil group. 
Overwrites existing utensils with the same name.""" 26 | logger.debug(f"Adding function to group '{self.name}'") 27 | def decorator(func): 28 | utensil_obj = _create_utensil(func, name, description) 29 | 30 | # Check if a function with the same name already exists in the group 31 | existing_names = [u.name for u in self.utensils] 32 | if utensil_obj.name in existing_names: 33 | # Find the index of the existing utensil and replace it 34 | index = existing_names.index(utensil_obj.name) 35 | self.utensils[index] = utensil_obj 36 | else: 37 | # No existing utensil with this name, so append it 38 | self.utensils.append(utensil_obj) 39 | 40 | return func 41 | 42 | if func is None: 43 | return decorator 44 | return decorator(func) 45 | 46 | def get_openai_tools(self) -> List[Dict[str, str]]: 47 | """Convert all utensils in this group to the OpenAI tools format.""" 48 | logger.debug(f"Converting utensil group '{self.name}' to OpenAI tools format") 49 | return [u.get_openai_tool() for u in self.utensils] 50 | 51 | 52 | class UtensilFunction: 53 | """Represents a function that can be called by the AI as a tool.""" 54 | 55 | def __init__( 56 | self, 57 | func: Callable, 58 | name: Optional[str] = None, 59 | description: Optional[str] = None, 60 | parameter_descriptions: Optional[Dict[str, str]] = None 61 | ): 62 | """ 63 | Initialize a utensil function. 
64 | 65 | Args: 66 | func: The actual function to call 67 | name: Optional override for the function name 68 | description: Optional override for the function description 69 | parameter_descriptions: Optional descriptions for parameters 70 | """ 71 | logger.debug(f"Creating utensil function '{name or func.__name__}'") 72 | self.func = func 73 | self.name = name or func.__name__ 74 | self._extract_function_info(description_override=description, param_descriptions=parameter_descriptions) 75 | 76 | def _extract_function_info(self, description_override=None, param_descriptions=None): 77 | """ 78 | Extract function information from docstrings, type hints, and parameters. 79 | Uses Pydantic to generate a JSON schema that includes the function description 80 | and per-parameter info. 81 | """ 82 | import inspect 83 | from pydantic import create_model, Field 84 | 85 | # Get type hints and docstring from the function 86 | type_hints = get_type_hints(self.func) 87 | docstring = inspect.getdoc(self.func) or "" 88 | 89 | # Set the function description using an override or the first line of the docstring 90 | if description_override: 91 | self.description = description_override 92 | else: 93 | self.description = docstring.split('\n')[0].strip() if docstring else "" 94 | 95 | # Extract parameter descriptions from the docstring 96 | param_docs = {} 97 | if docstring: 98 | lines = docstring.split('\n') 99 | in_args = False 100 | current_param = None 101 | for line in lines: 102 | line = line.strip() 103 | if line.lower().startswith('args:') or line.lower().startswith('parameters:'): 104 | in_args = True 105 | continue 106 | elif line.startswith('Returns:') or not line: 107 | in_args = False 108 | continue 109 | if in_args: 110 | if line and not line.startswith(' '): 111 | parts = line.split(':', 1) 112 | if len(parts) == 2: 113 | current_param = parts[0].strip() 114 | param_docs[current_param] = parts[1].strip() 115 | else: 116 | current_param = None 117 | elif current_param and 
line.startswith(' '): 118 | param_docs[current_param] += ' ' + line.strip() 119 | 120 | # Update using externally provided parameter descriptions if any 121 | if param_descriptions: 122 | param_docs.update(param_descriptions) 123 | 124 | # Build a Pydantic model dynamically from the function signature 125 | signature = inspect.signature(self.func) 126 | fields = {} 127 | for param_name, param in signature.parameters.items(): 128 | if param_name == 'self': 129 | continue 130 | param_type = type_hints.get(param_name, str) 131 | param_description = param_docs.get(param_name, "") 132 | # Use Ellipsis if no default is provided to mark a required field. 133 | default = ... if param.default is inspect.Parameter.empty else param.default 134 | fields[param_name] = (param_type, Field(default, description=param_description)) 135 | 136 | # Create the dynamic Pydantic model 137 | DynamicModel = create_model(f"{self.func.__name__}Model", **fields) 138 | model_schema = DynamicModel.model_json_schema() 139 | 140 | # Optionally add the function's overall description to the schema 141 | if self.description: 142 | model_schema["description"] = self.description 143 | 144 | # Store the final JSON schema for tool parameters 145 | self.parameters = model_schema 146 | 147 | def _get_json_schema_type(self, type_hint): 148 | """Convert Python type hint to JSON schema type.""" 149 | import typing 150 | from typing import get_origin, get_args, List, Dict, Union, Optional 151 | 152 | # Handle None type 153 | if type_hint is type(None): 154 | return {"type": "null"} 155 | 156 | # Handle primitive types 157 | if type_hint is str: 158 | return {"type": "string"} 159 | elif type_hint is int: 160 | return {"type": "integer"} 161 | elif type_hint is float: 162 | return {"type": "number"} 163 | elif type_hint is bool: 164 | return {"type": "boolean"} 165 | 166 | # Handle list and dict without type arguments 167 | elif type_hint is list or type_hint is List: 168 | return {"type": "array", "items": 
{"type": "string"}} 169 | elif type_hint is dict or type_hint is Dict: 170 | return {"type": "object"} 171 | 172 | # Handle generic types 173 | origin = get_origin(type_hint) 174 | args = get_args(type_hint) 175 | 176 | if origin is Union: 177 | # Handle Optional (Union with NoneType) 178 | if type(None) in args: 179 | # It's an Optional[X] type 180 | non_none_types = [arg for arg in args if arg is not type(None)] 181 | if len(non_none_types) == 1: 182 | return self._get_json_schema_type(non_none_types[0]) 183 | 184 | # Regular Union type 185 | schemas = [self._get_json_schema_type(arg) for arg in args] 186 | return {"oneOf": schemas} 187 | 188 | # Handle typed lists (List[X]) 189 | if origin is list or origin is typing.List: 190 | if args: 191 | return { 192 | "type": "array", 193 | "items": self._get_json_schema_type(args[0]) 194 | } 195 | return {"type": "array", "items": {"type": "string"}} 196 | 197 | # Handle typed dicts (Dict[K, V]) 198 | if origin is dict or origin is typing.Dict: 199 | if len(args) == 2: 200 | return { 201 | "type": "object", 202 | "additionalProperties": self._get_json_schema_type(args[1]) 203 | } 204 | return {"type": "object"} 205 | 206 | # Default for any other type 207 | return {"type": "string"} 208 | 209 | def get_openai_tool(self) -> Dict[str, str]: 210 | """Convert this utensil to the OpenAI tools format.""" 211 | return { 212 | "type": "function", 213 | "function": { 214 | "name": self.name, 215 | "description": self.description, 216 | "parameters": self.parameters 217 | } 218 | } 219 | 220 | def __call__(self, *args, **kwargs): 221 | """Execute the function with the given arguments.""" 222 | logger.debug(f"Calling utensil function '{self.name}' with args: {args}, kwargs: {kwargs}") 223 | return self.func(*args, **kwargs) 224 | 225 | def to_tool_definition(self) -> ToolDefinition: 226 | """Convert this utensil to a ToolDefinition object for serialization.""" 227 | logger.debug(f"Converting utensil function '{self.name}' to 
ToolDefinition") 228 | # Create the function definition 229 | function_def = FunctionDefinition( 230 | name=self.name, 231 | description=self.description 232 | ) 233 | 234 | # Extract parameters from the existing parameters dict 235 | if self.parameters: 236 | # Store properties directly 237 | function_def.parameters = self.parameters.get("properties", {}) 238 | 239 | # Copy required fields 240 | function_def.required = self.parameters.get("required", []) 241 | 242 | # Create and return the tool definition 243 | return ToolDefinition(type="function", function=function_def) 244 | 245 | 246 | # Global registry for all utensil functions 247 | _REGISTRY = [] 248 | 249 | 250 | def _create_utensil( 251 | func: Callable, 252 | name: Optional[str] = None, 253 | description: Optional[str] = None, 254 | parameter_descriptions: Optional[Dict[str, str]] = None 255 | ) -> UtensilFunction: 256 | """Create a utensil function from a regular function.""" 257 | logger.debug(f"Creating utensil for function '{name or func.__name__}'") 258 | utensil_obj = UtensilFunction(func, name, description, parameter_descriptions) 259 | # Store the utensil in the function itself for easy access 260 | func.__utensil__ = utensil_obj 261 | return utensil_obj 262 | 263 | 264 | def utensil( 265 | func=None, *, 266 | name: Optional[str] = None, 267 | description: Optional[str] = None, 268 | parameter_descriptions: Optional[Dict[str, str]] = None 269 | ): 270 | """ 271 | Decorator to mark a function as a utensil that can be called by the AI. 
272 | 273 | Args: 274 | func: The function to decorate 275 | name: Optional override for the function name 276 | description: Optional override for the function description 277 | parameter_descriptions: Optional descriptions for parameters 278 | 279 | Returns: 280 | The decorated function 281 | """ 282 | logger.debug(f"Registering utensil function '{name or func.__name__}'") 283 | def decorator(func): 284 | utensil_obj = _create_utensil(func, name, description, parameter_descriptions) 285 | _REGISTRY.append(utensil_obj) 286 | return func 287 | 288 | if func is None: 289 | return decorator 290 | return decorator(func) 291 | 292 | 293 | # Add group method to the utensil function 294 | utensil.group = UtensilGroup 295 | 296 | 297 | def get_all_utensils() -> List[UtensilFunction]: 298 | """Get all registered utensil functions.""" 299 | logger.trace(f"Retrieving all registered utensils") 300 | return _REGISTRY 301 | 302 | def extract_utensil_functions(utensils=None) -> List[UtensilFunction]: 303 | """ 304 | Extract all UtensilFunction objects from various input types. 305 | 306 | Args: 307 | utensils: List of utensil functions, groups, or callables. 308 | If None, returns all from global registry. 
def handle_tool_call(tool_call: Dict[str, Any], local_registry=None) -> Dict[str, Any]:
    """Execute the function requested by an AI tool call and return its result.

    Args:
        tool_call: OpenAI-style tool call dict containing "function"
            ({"name", "arguments"}) and an optional "id".
        local_registry: Optional collection of utensils/groups/callables to
            search instead of the global registry.

    Returns:
        The function's result as a dict (non-dict results are wrapped as
        {"result": str(...)}); failures are reported as {"error": ...}.
        "tool_call_id" echoes the call id when one was provided.
    """
    function_name = tool_call.get("function", {}).get("name")
    arguments_json = tool_call.get("function", {}).get("arguments", "{}")
    call_id = tool_call.get("id")

    # Note: `logger` is already imported at module top; the original re-imported
    # it mid-function, which was redundant.
    logger.debug(f"Function name: {function_name}")
    logger.debug(f"Arguments JSON: {arguments_json}")
    logger.debug(f"Call ID: {call_id}")

    if local_registry:
        logger.debug(f"Local registry: {local_registry}")
    else:
        logger.debug("No local registry provided, using global registry")
    logger.debug(f"Global registry: {_REGISTRY}")

    # Resolve which utensils to search through
    utensils_to_search = extract_utensil_functions(local_registry) if local_registry else _REGISTRY
    logger.debug(f"Searching through {len(utensils_to_search)} utensils for '{function_name}'")

    for utensil_obj in utensils_to_search:
        logger.debug(f"Checking utensil: {utensil_obj.name}")
        if utensil_obj.name != function_name:
            continue
        try:
            arguments = json.loads(arguments_json)
            logger.debug(f"Executing function '{function_name}' with arguments: \n---\n{arguments}\n---")
            # Arguments must be passed by keyword (JSON object -> kwargs).
            result = utensil_obj.func(**arguments)

            if not isinstance(result, dict):
                result = {"result": str(result)}
            if call_id:
                result["tool_call_id"] = call_id

            logger.debug(f"Function '{function_name}' result: \n---\n{result}\n---")
            return result
        except json.JSONDecodeError:
            return {"error": f"Invalid JSON arguments: {arguments_json}", "tool_call_id": call_id}
        except Exception as e:
            return {"error": f"Error executing function: {str(e)}", "tool_call_id": call_id}

    return {"error": f"Function '{function_name}' not found", "tool_call_id": call_id}
source code with the "key: |" format 3 | 4 | # Yaml format class is taken from https://github.com/jacebrowning/datafiles formats.py 5 | # The MIT License (MIT) 6 | # Copyright © 2018, Jace Browning 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this 8 | # software and associated documentation files (the "Software"), to deal in the Software 9 | # without restriction, including without limitation the rights to use, copy, modify, 10 | # merge, publish, distribute, sublicense, and/or sell copies of the Software, and to 11 | # permit persons to whom the Software is furnished to do so, subject to the following conditions: 12 | # The above copyright notice and this permission notice shall be included in all copies or 13 | # substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY 14 | # OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 15 | # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 17 | # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 18 | # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | from io import StringIO 20 | import log 21 | from typing import IO, Dict, List, Union 22 | import dataclasses 23 | from datafiles import formats, types 24 | from ruamel.yaml.scalarstring import DoubleQuotedScalarString 25 | from ruamel.yaml import YAML as _YAML 26 | 27 | class YAML(formats.Formatter): 28 | """Formatter for (round-trip) YAML Ain't Markup Language.""" 29 | 30 | @classmethod 31 | def extensions(cls): 32 | return {"", ".yml", ".yaml"} 33 | 34 | @classmethod 35 | def deserialize(cls, file_object): 36 | from ruamel.yaml import YAML as _YAML 37 | 38 | yaml = _YAML() 39 | yaml.preserve_quotes = True # type: ignore 40 | try: 41 | return yaml.load(file_object) 42 | except NotImplementedError as e: 43 | log.error(str(e)) 44 | return {} 45 | 46 | @classmethod 47 | def serialize(cls, data): 48 | # HACK: to remove None values from the data and make the yaml file cleaner 49 | def filter_none_values(data: Union[Dict, List]): 50 | if isinstance(data, dict): 51 | # this code worked for None values, but not really for optional default values like I want :() 52 | return {k: filter_none_values(v) for k, v in data.items() if v is not None} 53 | elif isinstance(data, list): 54 | return [filter_none_values(v) for v in data] 55 | else: 56 | return data 57 | data = filter_none_values(data) 58 | 59 | yaml = _YAML() 60 | 61 | # Define custom string representation function 62 | def represent_plain_str(dumper, data): 63 | if "\n" in data or "\r" in data or "#" in data or ":" in data or "'" in data or '"' in data: 64 | return dumper.represent_scalar("tag:yaml.org,2002:str", data, style='|') 65 | return dumper.represent_scalar("tag:yaml.org,2002:str", data, style='') 66 | 67 | # Configure the library to use plain style for dictionary keys 68 | yaml.representer.add_representer(str, represent_plain_str) 69 | 70 | 71 | yaml.default_style = "|" # support the cleaner multiline format for source code blocks 72 | yaml.register_class(types.List) 73 | yaml.register_class(types.Dict) 
def main():
    """Demo: generate, critique, and refine a dessert recipe with the Confectioner snackpack."""
    # Fallback recipe text used only when no "RecipeSuggestion" datafile exists yet.
    default_recipe = """
    Ingredients:
    - 1 cup sugar
    - 1 cup all-purpose flour
    - 1/2 cup unsweetened cocoa powder
    - 3/4 teaspoon baking powder
    - 3/4 teaspoon baking soda
    - 1/2 teaspoon salt
    - 1 large egg
    - 1/2 cup whole milk
    - 1/4 cup vegetable oil
    - 1 teaspoon vanilla extract
    - 1/2 cup boiling water
    """
    # Load the saved recipe datafile, creating (and persisting) it on first run.
    recipe_text = Text.objects.get_or_none("RecipeSuggestion")
    if recipe_text is None:
        recipe_text = Text("RecipeSuggestion", default_recipe)
        recipe_text.save()

    recipe_chat = Confectioner.user("Consider the following recipe for a chocolate cake:")

    print(f"Original Recipe: {recipe_text.content}\n\n")
    # "{text.RecipeSuggestion}" is a chatsnack filling expanded from the datafile at chat time.
    recipe_chat.user("{text.RecipeSuggestion}")
    recipe_chat.user("Time to remix things! Write a paragraph about the potential of these specific ingredients to make other clever baking possibilities. After that, use the best of those ideas to remix these ingredients for a unique and delicious dessert (include a detailed list of ingredients and steps like a cookbook recipe).")
    remixed_recipe = recipe_chat.chat()
    print(f"Remixed Recipe: \n{remixed_recipe.response}\n")

    # now we want to ask the same expert to review the recipe and give themselves feedback.
    critic_chat = Confectioner.user("Consider the following recipe explanation:")
    critic_chat.user(remixed_recipe.response)
    critic_chat.engine = "gpt-4"  # upgrade the AI for the critic
    critic_chat.user("Thoroughly review the recipe with critical expertise and identify anything you might change. Start by (1) summarizing the recipe you've been given, then (2) write a detailed review of the recipe.")
    critic_response = critic_chat.chat()
    print(f"Recipe Review: \n{critic_response.response}\n")

    # now we feed the review back to the original AI and get them to remedy their recipe with that feedback
    remixed_recipe.user("Write a final full recipe (including ingredients) based on the feedback from this review, giving it a gourmet title and a blog-worthy summary.")
    remixed_recipe.user(critic_response.response)
    final_recipe = remixed_recipe.chat()
    print(f"Final Recipe: \n{final_recipe.response}\n")
7 | """ 8 | 9 | import asyncio 10 | import logging 11 | import random 12 | import sys 13 | import time 14 | 15 | from chatsnack import Chat 16 | from rich import print 17 | from rich.live import Live 18 | from rich.panel import Panel 19 | from rich.progress import Progress 20 | from rich.text import Text 21 | import questionary 22 | 23 | 24 | logging.basicConfig(level=logging.CRITICAL) 25 | 26 | async def get_input(): 27 | s = Text("YOU: ") 28 | s.stylize("bold blue") 29 | return await questionary.text(s, qmark="🗣️", 30 | style=questionary.Style([ 31 | ("text", "bold yellow"), 32 | ("instruction", "fg:ansiwhite bg:ansired"), 33 | ("selected", "fg:ansiblack bg:ansicyan"), 34 | ("pointer", "bold fg:ansiyellow") 35 | ])).unsafe_ask_async() 36 | 37 | def print_header(): 38 | header_text = Text("🌟 Welcome to the Snack Bar Chat Room! 🌟\n", justify="center") 39 | header_text.stylize("bold magenta") 40 | header_panel = Panel(header_text, border_style="green") 41 | print(header_panel) 42 | 43 | def print_connecting_message(): 44 | with Progress() as progress: 45 | task = progress.add_task("[cyan]Connecting...", total=100) 46 | for _ in range(100): 47 | time.sleep(0.02) 48 | progress.update(task, advance=1) 49 | print("[bold green]Connected![/bold green]") 50 | 51 | def print_chatroom_status(): 52 | status_text = "\n👥 There are 2 people in [bold white]#snackbar[/bold white]: [bold blue]you[/bold blue] and [bold green]bot[/bold green]." 
def type_funny_word(speed):
    """Flash a glitch block, 'type' a random creepy word, then glitch again.

    Returns the word (including its leading space) so the caller can
    erase exactly what was written.
    """
    type_char_with_glitch("▒", 1.0)
    chosen = " " + random.choice(funny_words)
    for glyph in chosen:
        write_and_flush(glyph)
        sleep_random(0.06 * speed)
    type_char_with_glitch("▒", 1.0)
    type_char_with_glitch(" ", 0)
    return chosen
message: 122 | write_and_flush(char) 123 | sleep_random(speed) 124 | rnd = random.random() 125 | 126 | if glitchy: 127 | 128 | if rnd < 0.010: 129 | type_char_with_glitch(char, speed) 130 | 131 | 132 | # Check if a funny word should be displayed and if it hasn't been displayed yet 133 | if rnd > 1.0 - funny_word_probability: 134 | funny_word = type_funny_word(speed) 135 | clear_funny_word(funny_word, speed) 136 | overwrite_funny_word_with_spaces(funny_word, speed) 137 | erase_funny_word(funny_word, speed) 138 | funny_word_probability = 0.00001 # Reset probability after displaying a funny word 139 | else: 140 | funny_word_probability += 0.00001 # Increase the probability of a funny word appearing 141 | 142 | 143 | if rnd < 0.1 or not char.isalpha(): 144 | sleep_random(0.1) 145 | 146 | if char == " " and rnd < 0.025: 147 | time.sleep(0.2) 148 | 149 | 150 | 151 | chat_call_done = asyncio.Event() 152 | async def show_typing_animation(): 153 | with Live(Text("🤖 BOT is typing...", justify="left"), refresh_per_second=4, transient=True) as live: 154 | # change the message while the chat is not done 155 | while not chat_call_done.is_set(): 156 | # increase the number of dots 157 | for dots in range(1,5): 158 | if chat_call_done.is_set(): 159 | break 160 | state = "🤖 BOT is typing" + "." 
* dots 161 | display = Text(state, justify="left") 162 | # choose a random color between bold or yellow 163 | if random.random() > 0.5: 164 | display.stylize("bold yellow") 165 | else: 166 | display.stylize("orange") 167 | display = Text(state, justify="left") 168 | live.update(display) 169 | await asyncio.sleep(0.3) 170 | 171 | 172 | def print_bot_msg(msg, beforemsg="\n", aftermsg="\n", glitchy=True): 173 | botprefix = Text(f"{beforemsg}🤖 BOT:") 174 | botprefix.stylize("bold green") 175 | print(botprefix, end=" ") 176 | if not glitchy: 177 | print(msg + aftermsg) 178 | else: 179 | pretend_typing_print(msg + aftermsg, glitchy=glitchy) 180 | 181 | def print_you_msg(msg, beforemsg="\n", aftermsg="\n"): 182 | prefix = Text(f"{beforemsg}🗣️ YOU:") 183 | prefix.stylize("bold gray") 184 | print(prefix, end=" ") 185 | print(msg + aftermsg) 186 | 187 | typing_task = None 188 | async def main(): 189 | import loguru 190 | # set to only errors and above 191 | loguru.logger.remove() 192 | loguru.logger.add(sys.stderr, level="ERROR") 193 | 194 | print_header() 195 | print_connecting_message() 196 | print_chatroom_status() 197 | print_bot_msg("Oh, hello there-- thanks for joining.") 198 | 199 | # We create a chat instance and start the chat with a too-friendly bot. 200 | yourchat = Chat().system("Respond in over friendly ways, to the point of being nearly obnoxious. As the over-the-top assistant, you help as best as you can, but can't help being 'too much'") 201 | while (user_input := await get_input()): 202 | chat_call_done.clear() 203 | typing_task = asyncio.create_task(show_typing_animation()) 204 | # Since we're doing 'typing' animation as async, let's do the chat query async. No, we don't support streaming responses yet. 
@app.route('/chat', methods=['POST'])
def chat():
    """Handle one chat message for the current session.

    Resumes the session's saved chat (or starts a fresh one when the user
    switches bots), asks the bot for a reply, and returns it as JSON with
    newlines and ``` fenced blocks converted to simple HTML for display.
    """
    user_input = request.form.get('user_input')
    bot_choice = request.form.get('bot_choice')

    response = None
    try:
        if 'chat_output' not in session or bot_choice != session['bot_choice']:
            # new conversation, or the user switched bots: start from the base bot
            session['bot_choice'] = bot_choice
            chat_output = bots.get(bot_choice, ChatsnackHelp)
        else:
            # resume the previously saved chat, falling back to the base bot
            chat_output = Chat.objects.get_or_none(session['chat_output'])
            if chat_output is None:
                chat_output = bots.get(bot_choice, ChatsnackHelp)

        chat_output = chat_output.chat(user_input)
        chat_output.save()

        session['chat_output'] = chat_output.name
    except Exception as e:
        # BUGFIX: the original re-raised here (`raise e`) after building this
        # friendly message, so the fallback was unreachable and the client
        # got a 500 instead of a graceful reply.
        print(e)
        error_name = e.__class__.__name__
        response = "I'm sorry, I ran into an error. ({})".format(error_name)

    response = chat_output.response if response is None else response
    # convert newlines to <br> so the reply renders in the web page
    response = response.replace("\n", "<br>")
    # convert ``` fenced code blocks to <pre> blocks
    if "```" in response:
        response = re.sub(r"```(.*?)```", r"<pre>\1</pre>", response, flags=re.DOTALL)

    return jsonify({"response": response})
padding: 20px; 15 | width: 80%; 16 | max-width: 800px; 17 | } 18 | #bot-choice { 19 | /* display: block; 20 | width: 100%; */ 21 | margin-bottom: 10px; 22 | } 23 | .chat-bubble { 24 | padding: 10px; 25 | border-radius: 20px; 26 | margin-bottom: 5px; 27 | display: inline-block; 28 | max-width: 75%; 29 | min-width: 40%; 30 | clear: both; 31 | } 32 | 33 | .text { 34 | font-size: medium; 35 | } 36 | .user { 37 | background-color: #e1f3fd; 38 | float: right; 39 | clear: both; 40 | } 41 | .bot { 42 | background-color: #f2f2f2; 43 | float: left; 44 | clear: both; 45 | min-width: 60%; 46 | } 47 | .avatar { 48 | width: 50px; 49 | height: 50px; 50 | border-radius: 50%; 51 | margin-right: 10px; 52 | float: left; 53 | } 54 | input, select { 55 | font-family: Segoe UI, Verdana, sans-serif; 56 | font-size: larger; 57 | } 58 | .page-title { 59 | font-family: 'Verdana', sans-serif; 60 | font-size: 2rem; 61 | text-align: center; 62 | color: #000000; 63 | margin-bottom: 1rem; 64 | } 65 | .new-conversation-btn { 66 | font-family: 'Verdana', sans-serif; 67 | background-color: #4a8fdf; 68 | border: none; 69 | color: white; 70 | text-align: center; 71 | text-decoration: none; 72 | display: inline-block; 73 | font-size: 14px; 74 | margin: 4px 2px; 75 | cursor: pointer; 76 | border-radius: 12px; 77 | padding: 10px 24px; 78 | } 79 | .bot-avatar { 80 | background-image: url('avatar_custom_robot.png'); 81 | background-size: cover; 82 | background-position: center top; 83 | } 84 | .bot-selector { 85 | display: flex; 86 | align-items: center; 87 | justify-content: flex-start; 88 | gap: 5px; 89 | font-size:larger; 90 | } 91 | #start-conversation { 92 | display: block; 93 | margin: 10px auto; 94 | float:left; 95 | } 96 | 97 | #messages { 98 | max-height: 600px; 99 | min-height: 400px; 100 | margin-top: 10px; 101 | } 102 | 103 | #submit-btn { 104 | font: inherit; 105 | display: block; 106 | float: right; 107 | } 108 | 109 | #user-input { 110 | float: left; 111 | width: 98%; 112 | font-size: 
larger; 113 | padding: 8px; 114 | } 115 | 116 | .username { 117 | font-weight: bold; 118 | color: navy; 119 | /* no word wrapping */ 120 | white-space: nowrap; 121 | } 122 | 123 | pre { 124 | background: darkblue; 125 | color: skyblue; 126 | padding: 10px; 127 | border: 2px green inset; 128 | overflow-x: scroll; 129 | } -------------------------------------------------------------------------------- /examples/snackpacks-web/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Chatsnack Demo 7 | 8 | 9 | 10 | 11 |
12 |

Snackchat Demo

13 |
14 | 23 | 24 |
25 |
26 |
27 | 28 | 33 |
34 | 35 |
36 |
37 |
38 | 39 | 85 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /examples/snackswipe-web/app.py: -------------------------------------------------------------------------------- 1 | # Snackchat Web-based Prompt Tester app example 2 | # 3 | # pip install chatsnack[examples] 4 | # be sure there's a .env file in the same directory as app.py with your OpenAI API key as OPENAI_API_KEY = "YOUR_KEY_HERE" 5 | # python .\app.py 6 | # open http://localhost:5000 7 | 8 | import asyncio 9 | import random 10 | from uuid import uuid4 11 | from flask import Flask, render_template, request, jsonify, session 12 | from text_generators import text_generators, TextResult 13 | from flask_session import Session 14 | from collections import deque 15 | import json 16 | import threading 17 | 18 | class TextResultEncoder(json.JSONEncoder): 19 | def default(self, obj): 20 | if isinstance(obj, TextResult): 21 | return obj.__dict__ 22 | return super(TextResultEncoder, self).default(obj) 23 | 24 | app = Flask(__name__) 25 | app.secret_key = "super%^$@^!@!secretkey%^$@^%@!" 26 | app.config['SESSION_TYPE'] = 'filesystem' 27 | app.json_encoder = TextResultEncoder 28 | Session(app) 29 | 30 | @app.route('/') 31 | def index(): 32 | problem_statement = "Your problem statement here." 
@app.route('/fetch-text', methods=['POST'])
def fetch_text():
    """Pop the next generated result for this session.

    Returns the result as JSON, ``{"status": "completed"}`` once the
    completion sentinel is reached (tearing down the queue), or
    ``{"status": "waiting"}`` when nothing is ready yet.
    """
    user_id = session.get('user_id', None)
    pending = user_queues.get(user_id, None)
    # an absent *or empty* deque both mean "nothing ready yet"
    if not pending:
        return jsonify({"status": "waiting"})
    item = pending.popleft()
    if isinstance(item, dict) and item.get("status") == "completed":
        del user_queues[user_id]
        return jsonify({"status": "completed"})
    return jsonify(item)
async def async_text_generation(num_tests, text_generators):
    """Yield generation results as they complete.

    The first two generators are treated as "priority" and fully drained
    before any result from the remaining "background" generators is
    yielded.  Each generator is invoked ``num_tests`` times.

    BUGFIX: ``asyncio.wait()`` no longer accepts bare coroutines
    (deprecated in 3.8, an error since 3.11), so every coroutine is
    wrapped in a Task first.
    """
    priority_generators = text_generators[:2]
    background_generators = text_generators[2:]

    priority_tasks = []
    background_tasks = []
    for _ in range(num_tests):
        priority_tasks.extend(asyncio.ensure_future(gen()) for gen in priority_generators)
        background_tasks.extend(asyncio.ensure_future(gen()) for gen in background_generators)

    async def _drain(tasks):
        # Yield each task's result in completion order until all are consumed.
        # The `while tasks` guard also avoids asyncio.wait() on an empty set.
        while tasks:
            done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                yield task.result()
            tasks = list(pending)

    async for result in _drain(priority_tasks):
        yield result
    async for result in _drain(background_tasks):
        yield result
-------------------------------------------------------------------------------- 1 | It was an average friday, I came home from school to do some maintenance 2 | on the board. When I got home my sister came up to me and said "these bad 3 | lookin' dudes came to our door lookin' fer you!" in a kidding tone. She 4 | thought they had the wrong house. I was kind of suspicious, ignoring it I 5 | went upstairs. Then the doorbell rang. My sister answered the door, it was 6 | the same men asking for me as before. She still insisted to them that they 7 | had the wrong house. There is no one living here by that name. They said 8 | "we just saw him walk in the house..." Obviously the jerks were stationed 9 | outside waiting for something to happen. It turned out that they did have 10 | the wrong name but the right house (clueless idiots). My sister realized 11 | that they were talking about me and corrected them in the name. I came to 12 | the door to see what was going on. There were four men, three Secret Service 13 | Agents (Moe, Larry and Curly as mentioned before) and one local police cop 14 | (oink oink oink). Like, "spot the Secret Service men in this picture!", they 15 | were totally obvious! Sunglasses, trench coats and a brief cases, which I 16 | doubt anything was in there except their bag lunches. They all showed their 17 | two cent badges (real proud like, but almost embarrassed). Having no choice 18 | (well, I didn't want to complicate the situation) I let the losers in and we 19 | sat at a table. My mother was still at work. They started to fill me in 20 | on the background and they were investigating names that were thrown at them 21 | threw their sources (Gee, I wonder who that was? Must have been two big losers 22 | by the name of The Silencer and Chris R. Gee, what friends eh? Aren't you 23 | lucky you know them!). They got their biological questions like, age, date of 24 | birth, color, etc. 
Then they asked me if I knew anything about the present 25 | situation they were investigating (You know what I mean.. Plastic things with 26 | numbers on them. I still want to refrain from directly quoting what the jerks 27 | were looking for in order to avoid some idiotic charge against me). I said I 28 | knew nothing about it, only what I heard from other people and boards. My 29 | sister insisted that I don't say anything until my mother was present. They 30 | really didn't want that (those sly mothers...). But having no choice they 31 | waited for my mother to get home. Sitting with the Secret Service men was 32 | great fun! Listening to their boring conversations with each others. I asked 33 | them "I thought the 'Secret Service' only job was to protect the president." 34 | They said they also protect things dealing with the treasury in any shape or 35 | form (oh really?). Then I told them that I thought it was rather funny that 36 | they thought there was a "ring". Then Larry (a Secret Service agent) turned 37 | his head to Moe (an agent) and said "He finds this funny...". Like they were 38 | taking notes or something, that's all I needed for them to use verbal quotes 39 | against me. I had to sudden urge to start saying four letters words to their 40 | faces, but it was all non-verbal. I guess "its the thought that counts"! 
-------------------------------------------------------------------------------- /examples/snackswipe-web/static/main.js: -------------------------------------------------------------------------------- 1 | $(document).ready(function () { 2 | let num_tests = 0; 3 | let currentTest = 0; 4 | let results = []; 5 | let completed = 0; 6 | 7 | $("#responseContainer").hide(); 8 | $("#resultsContainer").hide(); 9 | 10 | $("#testForm").submit(function (event) { 11 | event.preventDefault(); 12 | $("#loadingMessage").show(); 13 | num_tests = parseInt($("#num_tests").val()); 14 | if (num_tests > 0) { 15 | $.post("/start-generation", { num_tests: num_tests }) 16 | .done(() => { 17 | fetchResults(); 18 | }) 19 | .fail((error) => { 20 | console.error("Error starting text generation:", error); 21 | }); 22 | } 23 | }); 24 | 25 | function showNextResult() { 26 | if (currentTest < results.length) { 27 | $("#loadingMessage").hide(); 28 | $("#responseText").text(results[currentTest].text); 29 | $("#testForm").hide(); 30 | $("#responseContainer").show(); 31 | } else if (completed === 0) { 32 | $("#loadingMessage").show(); 33 | $("#testForm").hide(); 34 | } else { 35 | $("#loadingMessage").hide(); 36 | $("#responseContainer").hide(); 37 | displayResults(); 38 | } 39 | } 40 | 41 | function updateResultsCounter() { 42 | $("#resultsCounter").text("Generations created: " + results.length ); 43 | } 44 | 45 | function displayResults() { 46 | let votes = {}; 47 | for (let result of results) { 48 | if (!votes.hasOwnProperty(result.generator_name)) { 49 | votes[result.generator_name] = { votes: 0, texts: [] }; 50 | } 51 | votes[result.generator_name].votes += result.votes; 52 | votes[result.generator_name].texts.push(result.text); 53 | } 54 | 55 | let sortedVotes = Object.entries(votes).sort((a, b) => b[1].votes - a[1].votes); 56 | let winner = sortedVotes[0][0]; 57 | let margin = sortedVotes[0][1].votes - (sortedVotes[1] ? 
sortedVotes[1][1].votes : 0); 58 | let isTie = margin === 0; 59 | 60 | let details = ""; 61 | for (const [generator_name, generator_data] of sortedVotes) { 62 | details += `

${generator_name}: ${generator_data.votes} votes

`; 63 | details += "
    "; 64 | for (const text of generator_data.texts) { 65 | details += `
  • ${text}
  • `; 66 | } 67 | details += "
"; 68 | } 69 | 70 | if (isTie) { 71 | $("#resultsTitle").text("It's a tie!"); 72 | } else { 73 | $("#resultsTitle").text("Winner:"); 74 | $("#winner").text(winner); 75 | $("#margin").text(margin); 76 | } 77 | 78 | $("#detailedResults").html(details); 79 | $("#resultsContainer").show(); 80 | } 81 | 82 | 83 | 84 | $("#swipeLeft").click(function () { 85 | results[currentTest].votes = 0; 86 | currentTest++; 87 | showNextResult(); 88 | }); 89 | 90 | $("#swipeRight").click(function () { 91 | results[currentTest].votes = 1; 92 | currentTest++; 93 | showNextResult(); 94 | }); 95 | 96 | function fetchResults() { 97 | if (completed === 0) { 98 | // Show the loading message 99 | $('#testForm').hide(); 100 | $.post("/fetch-text") 101 | .done((data) => { 102 | if (data.status === "waiting") { 103 | setTimeout(fetchResults, 500); // Retry after 500ms 104 | } else if (data.status === "completed") { 105 | completed = 1; 106 | // Hide the loading message 107 | $("#loadingMessage").hide(); 108 | // Do any final processing or display here, if needed 109 | console.log("Generation completed"); 110 | } else { 111 | results.push(data); 112 | updateResultsCounter(); 113 | showNextResult(); 114 | fetchResults(); // Fetch the next result immediately 115 | } 116 | }) 117 | .fail((error) => { 118 | console.error("Error fetching text:", error); 119 | }); 120 | } 121 | } 122 | }); 123 | -------------------------------------------------------------------------------- /examples/snackswipe-web/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Text Generation Contest 7 | 8 | 66 | 67 | 68 | 69 |

Text Generation Contest

70 |
71 | 72 | 73 | 74 |
75 | 78 | 84 | 89 |
Generations ready: 0
90 | 91 | 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /examples/snackswipe-web/text_generators.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from chatsnack import Chat, Text 3 | 4 | class TextResult: 5 | def __init__(self, generator_name, text): 6 | self.generator_name = generator_name 7 | self.text = text 8 | self.votes = 0 9 | 10 | def vote(self, like): 11 | if like: 12 | self.votes += 1 13 | 14 | def __str__(self): 15 | return f"{self.generator_name}: {self.text}" 16 | 17 | 18 | 19 | # saved in datafiles/chatsnack/mydaywss.txt 20 | test_prompt = "Passage:\n{text.mydaywss}" 21 | response_prefix = "## CONCISE_SUMMARY ##" 22 | async def text_generator_1(): 23 | c = Chat("""IDENTITY: Professional document summarizer. 24 | Respond to the user only in the following format: 25 | (1) Use your expertise to explain, in detail, the top 5 things to consider when making 26 | concise summararies of any text document. 27 | (2) Elaborate on 3 more protips used by the world's best summarizers to 28 | avoid losing important details and context. 29 | (3) Now specifically consider the user's input. Use your expertise and your own guidance, 30 | describe in great detail how an author could apply that wisdom to summarize the user's 31 | text properly. What considerations come to mind? What is the user expecting? 32 | (4) Finally, use everything above to author a concise summary of the user's input that will 33 | delight them with your summarization skills. 
async def text_generator_2():
    """Summarize the test prompt with the built-in Summarizer snack pack.

    Returns a TextResult holding only the text after the final
    "CONCISE_SUMMARY:" marker when present.
    """
    from chatsnack.packs import Summarizer
    c = Summarizer
    result = await c.chat_a(test_prompt)
    result = result.response
    my_response_prefix = "CONCISE_SUMMARY:"
    # BUGFIX: rfind() returns -1 when the prefix is missing; the old
    # unconditional slice (-1 + len(prefix)) chopped the response at a
    # garbage offset.  Only strip when the marker is actually found.
    idx = result.rfind(my_response_prefix)
    if idx != -1:
        result = result[idx + len(my_response_prefix):]
    return TextResult("Default Summarizer (Built-in)", result)
def print_results(results):
    """Tally votes per generator and print the winner and victory margin.

    Args:
        results: iterable of objects exposing ``generator_name`` and
            ``votes`` attributes (e.g. TextResult instances).
    """
    votes = {}
    for result in results:
        votes[result.generator_name] = votes.get(result.generator_name, 0) + result.votes

    winner = max(votes, key=votes.get)
    # BUGFIX: with a single generator the old code called max() on an empty
    # sequence and raised ValueError; default the runner-up score to 0.
    runner_up = max((v for k, v in votes.items() if k != winner), default=0)
    margin = votes[winner] - runner_up

    print(f"Winner: {winner} with {votes[winner]} votes")
    print(f"Margin of victory: {margin}")
5 | authors = ["Mattie Casper"] 6 | license = "MIT" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.10" 11 | datafiles = "^2.0" 12 | python-dotenv = "^1.0.0" 13 | loguru = "^0.6.0" 14 | nest-asyncio = "^1.5.6" 15 | openai = "^1.2.4" 16 | Flask = {version = "^2.1", optional = true} 17 | questionary = {version = "^1.10.0", optional = true} 18 | rich = {version = "^13.3.2", optional = true} 19 | 20 | [tool.poetry.dev-dependencies] 21 | pytest = "^7.2" 22 | pytest-asyncio = "^0.21.0" 23 | pytest-mock = "^3.10.0" 24 | 25 | [tool.poetry.extras] 26 | flask = ["Flask"] 27 | questionary = ["questionary"] 28 | rich = ["rich"] 29 | examples = ["questionary", "rich", "Flask"] 30 | 31 | 32 | [build-system] 33 | requires = ["poetry-core"] 34 | build-backend = "poetry.core.masonry.api" -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mattie/chatsnack/c8606c1e0985486aea7e1c94c7e4d8b6c714c3cb/tests/__init__.py -------------------------------------------------------------------------------- /tests/mixins/test_chatparams.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from chatsnack.packs import Jane 4 | from chatsnack import Chat, ChatParams 5 | 6 | @pytest.fixture 7 | def chat_params(): 8 | return ChatParams() 9 | 10 | @pytest.fixture 11 | def chat_params_mixin(): 12 | # Creates a Chat instance; its params will be None until a property is set. 
def test_stream_change(chat_params_mixin):
    # Toggling .stream should round-trip both values, in either direction.
    for flag in (True, False):
        chat_params_mixin.stream = flag
        assert chat_params_mixin.stream == flag
62 | """ 63 | assert chat_params_mixin.tool_choice is None 64 | 65 | def test_set_tool_choice_creates_params(chat_params_mixin): 66 | """ 67 | When tool_choice is set, the ChatParams is created if needed and returns the correct value. 68 | """ 69 | chat_params_mixin.tool_choice = "manual" 70 | assert chat_params_mixin.params is not None 71 | assert chat_params_mixin.tool_choice == "manual" 72 | 73 | 74 | 75 | # Existing engine tests for various models; you can skip these if needed. 76 | @pytest.mark.parametrize("engine", ["gpt-3.5-turbo", "gpt-4", "gpt-4o", "o1", "o1-mini", "o3-mini", "o1-preview", "gpt-4o-mini", "gpt-4-turbo", "chatgpt-4o-latest", "gpt-4.5-preview"]) 77 | @pytest.mark.skipif(os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY is not set in environment or .env") 78 | def test_engines(engine): 79 | SENTENCE = "A short sentence about the difference between green and blue." 80 | TEMPERATURE = 0.0 81 | SEED = 42 82 | ENGINE = engine 83 | 84 | # Jane is an existing chat we can build upon 85 | chat = Jane.copy() 86 | cp = chat.user(SENTENCE) 87 | assert cp.last == SENTENCE 88 | 89 | cp.temperature = TEMPERATURE 90 | cp.seed = SEED 91 | cp.model = ENGINE 92 | 93 | output_iter = cp.listen() 94 | output = ''.join(list(output_iter)) 95 | 96 | assert output is not None 97 | assert len(output) > 0 98 | print(output) 99 | -------------------------------------------------------------------------------- /tests/mixins/test_query.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from chatsnack import Chat, Text, CHATSNACK_BASE_DIR 3 | 4 | import os 5 | import shutil 6 | 7 | TEST_FILENAME = "./.test_text_expansion.txt" 8 | 9 | @pytest.fixture(scope="function", autouse=True) 10 | def setup_and_cleanup(): 11 | chatsnack_dir = CHATSNACK_BASE_DIR 12 | safe_to_cleanup = False 13 | # to be safe, verify that this directory is under the current working directory 14 | # is it a subdirectory of the current 
working directory? 15 | chatsnack_dir = os.path.abspath(chatsnack_dir) 16 | if os.path.commonpath([os.path.abspath(os.getcwd()), chatsnack_dir]) == os.path.abspath(os.getcwd()): 17 | # now check to be sure the only files under this directory (recursive) are .txt, .yaml, .yml, .log, and .json files. 18 | # if so, it's safe to delete the directory 19 | bad_file_found = False 20 | for root, dirs, files in os.walk(chatsnack_dir): 21 | for file in files: 22 | if not file.endswith((".txt", ".yaml", ".yml", ".log", ".json")): 23 | bad_file_found = True 24 | break 25 | else: 26 | continue 27 | break 28 | if not bad_file_found: 29 | safe_to_cleanup = True 30 | # if safe and the test directory already exists, remove it, should be set in the tests .env file 31 | if safe_to_cleanup and os.path.exists(chatsnack_dir): 32 | shutil.rmtree(chatsnack_dir) 33 | # create the test directory, recursively to the final directory 34 | if not os.path.exists(chatsnack_dir): 35 | os.makedirs(chatsnack_dir) 36 | else: 37 | # problem, the directory should have been missing 38 | raise Exception("The test directory already exists, it should have been missing.") 39 | # also delete TEST_FILENAME 40 | if os.path.exists(TEST_FILENAME): 41 | os.remove(TEST_FILENAME) 42 | yield 43 | 44 | # Clean up the test environment 45 | import time 46 | time.sleep(2) 47 | if safe_to_cleanup and os.path.exists(chatsnack_dir): 48 | # it's okay for this to fail, it's just a cleanup 49 | try: 50 | shutil.rmtree(chatsnack_dir) 51 | except: 52 | pass 53 | # also delete TEST_FILENAME 54 | if os.path.exists(TEST_FILENAME): 55 | os.remove(TEST_FILENAME) 56 | 57 | 58 | 59 | @pytest.fixture 60 | def chat(): 61 | return Chat() 62 | 63 | def test_copy_chatprompt_same_name(): 64 | """Copying a ChatPrompt with the same name should succeed.""" 65 | chat = Chat(name="test") 66 | new_chat = chat.copy() 67 | # assert that it begins with "test" 68 | assert new_chat.name.startswith("test") 69 | assert new_chat.name != "test" 70 | 71 | 
72 | def test_copy_chatprompt_new_name(): 73 | """Copying a ChatPrompt with a new name should succeed.""" 74 | chat = Chat(name="test") 75 | new_chat = chat.copy(name="new_name") 76 | assert new_chat.name == "new_name" 77 | 78 | @pytest.mark.parametrize("expand_includes, expand_fillings, expected_system, expected_last, expected_exception", [ 79 | (True, True, "Respond only with 'YES' regardless of what is said.","here we are again", None), 80 | (False, True, "Respond only with 'YES' regardless of what is said.", "AnotherTest", NotImplementedError), 81 | (True, False, "{text.test_text_expansion}","here we are again", None), 82 | (False, False, "{text.test_text_expansion}","AnotherTest", None), 83 | ]) 84 | def test_copy_chatprompt_expands(expand_includes, expand_fillings, expected_system, expected_last, expected_exception): 85 | """Copying a ChatPrompt should correctly expand includes/fillings based on params.""" 86 | # we need a text object saved to disk 87 | text = Text(name="test_text_expansion", content="Respond only with 'YES' regardless of what is said.") 88 | # save it to disk 89 | text.save() 90 | # we need a chat object to use it 91 | chat = Chat() 92 | # set the logging level to trace for chatsnack 93 | chat.system("{text.test_text_expansion}") 94 | AnotherTest = Chat(name="AnotherTest") 95 | AnotherTest.user("here we are again") 96 | AnotherTest.save() 97 | chat.include("AnotherTest") 98 | # todo add more tests for fillings 99 | if expected_exception is not None: 100 | with pytest.raises(expected_exception): 101 | new_chat = chat.copy(expand_includes=expand_includes, expand_fillings=expand_fillings) 102 | else: 103 | new_chat = chat.copy(expand_includes=expand_includes, expand_fillings=expand_fillings) 104 | assert new_chat.system_message == expected_system 105 | assert new_chat.last == expected_last 106 | 107 | def test_copy_chatprompt_expand_fillings_not_implemented(): 108 | """Copying a ChatPrompt with expand_fillings=True and expand_includes=False 
should raise a NotImplemented error.""" 109 | chat = Chat(name="test") 110 | with pytest.raises(NotImplementedError): 111 | new_chat = chat.copy(expand_includes=False, expand_fillings=True) 112 | 113 | 114 | def test_copy_chatprompt_no_name(): 115 | """Copying a ChatPrompt without specifying a name should generate a name.""" 116 | chat = Chat(name="test") 117 | new_chat = chat.copy() 118 | assert new_chat.name != "test" 119 | assert len(new_chat.name) > 0 120 | 121 | def test_copy_chatprompt_preserves_system(): 122 | """Copying a ChatPrompt should preserve the system.""" 123 | chat = Chat(name="test") 124 | chat.system("test_system") 125 | new_chat = chat.copy() 126 | assert new_chat.system_message == "test_system" 127 | 128 | def test_copy_chatprompt_no_system(): 129 | """Copying a ChatPrompt without a system should result in no system.""" 130 | chat = Chat(name="test") 131 | new_chat = chat.copy() 132 | assert new_chat.system_message is None 133 | 134 | def test_copy_chatprompt_copies_params(): 135 | """Copying a ChatPrompt should copy over params.""" 136 | chat = Chat(name="test", params={"key": "value"}) 137 | new_chat = chat.copy() 138 | assert new_chat.params == {"key": "value"} 139 | 140 | def test_copy_chatprompt_independent_params(): 141 | """Copying a ChatPrompt should result in independent params.""" 142 | chat = Chat(name="test", params={"key": "value"}) 143 | new_chat = chat.copy() 144 | new_chat.params["key"] = "new_value" 145 | assert chat.params == {"key": "value"} 146 | assert new_chat.params == {"key": "new_value"} 147 | 148 | 149 | 150 | # Tests for new_chat.name generation 151 | def test_copy_chatprompt_generated_name_length(): 152 | """The generated name for a copied ChatPrompt should be greater than 0 characters.""" 153 | chat = Chat(name="test") 154 | new_chat = chat.copy() 155 | assert len(new_chat.name) > 0 156 | 157 | 158 | 159 | 160 | -------------------------------------------------------------------------------- 
/tests/mixins/test_query_listen.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from chatsnack import Chat, Text, CHATSNACK_BASE_DIR 3 | 4 | 5 | import pytest 6 | import asyncio 7 | from chatsnack.chat.mixin_query import ChatStreamListener 8 | from chatsnack.aiclient import AiClient 9 | 10 | 11 | 12 | 13 | @pytest.mark.asyncio 14 | async def test_get_responses_a(): 15 | ai = AiClient() 16 | listener = ChatStreamListener(ai, '[{"role":"system","content":"Respond only with POPSICLE 20 times."}]') 17 | responses = [] 18 | await listener.start_a() 19 | async for resp in listener: 20 | responses.append(resp) 21 | assert len(responses) > 10 22 | assert listener.is_complete 23 | assert 'POPSICLE' in listener.current_content 24 | assert 'POPSICLE' in listener.response 25 | 26 | def test_get_responses(): 27 | ai = AiClient() 28 | listener = ChatStreamListener(ai, '[{"role":"system","content":"Respond only with POPSICLE 20 times."}]') 29 | listener.start() 30 | responses = list(listener) 31 | assert len(responses) > 10 32 | assert listener.is_complete 33 | assert 'POPSICLE' in listener.current_content 34 | assert 'POPSICLE' in listener.response 35 | 36 | import os 37 | import pytest 38 | from chatsnack.packs import Jane 39 | 40 | 41 | 42 | @pytest.mark.skipif(os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY is not set in environment or .env") 43 | def test_listen(): 44 | # Define constants 45 | SENTENCE = "A short sentence about the difference between green and blue." 46 | TEMPERATURE = 0.0 47 | # TODO: Rework this such that it doesn't risk being flaky. 
If you get a different system behind the scenes, even 48 | # the seed won't be enough 49 | SEED = 42 50 | 51 | # First part of the test 52 | chat = Jane.copy() 53 | cp = chat.user(SENTENCE) 54 | assert cp.last == SENTENCE 55 | 56 | cp.stream = True 57 | cp.temperature = TEMPERATURE 58 | cp.seed = SEED 59 | 60 | # Listen to the response 61 | output_iter = cp.listen() 62 | output = ''.join(list(output_iter)) 63 | 64 | # Second part of the test 65 | chat = Jane.copy() 66 | cp = chat.user(SENTENCE) 67 | assert cp.last == SENTENCE 68 | 69 | cp.temperature = TEMPERATURE 70 | cp.seed = SEED 71 | 72 | # Ask the same question 73 | ask_output = cp.ask() 74 | 75 | # Asserts 76 | assert output is not None 77 | assert len(output) > 0 78 | # BUG: This ends up being too flaky. We will just check that the output is not empty 79 | # assert output == ask_output 80 | 81 | @pytest.mark.skipif(True or os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY is not set in environment or .env") 82 | @pytest.mark.asyncio 83 | async def test_listen_a(): 84 | # Define constants 85 | SENTENCE = "A short sentence about the difference between green and blue" 86 | TEMPERATURE = 0.0 87 | # TODO: Rework this such that it doesn't risk being flaky. 
If you get a different system behind the scenes, even 88 | # the seed won't be enough 89 | SEED = 42 90 | 91 | chat = Jane.copy() 92 | cp = chat.user(SENTENCE) 93 | assert cp.last == SENTENCE 94 | 95 | cp.stream = True 96 | cp.temperature = TEMPERATURE 97 | cp.seed = SEED 98 | 99 | # listen to the response asynchronously 100 | output = [] 101 | async for part in await cp.listen_a(): 102 | output.append(part) 103 | output = ''.join(output) 104 | print(output) 105 | 106 | chat = Jane.copy() 107 | cp = chat.user(SENTENCE) 108 | assert cp.last == SENTENCE 109 | 110 | cp.stream = False 111 | cp.temperature = TEMPERATURE 112 | cp.seed = SEED 113 | 114 | # ask the same question 115 | ask_output = cp.ask() 116 | print(ask_output) 117 | # is there a response and it's longer than 0 characters? 118 | assert output is not None 119 | assert len(output) > 0 120 | 121 | # assert that the output of listen is the same as the output of ask 122 | assert output == ask_output 123 | -------------------------------------------------------------------------------- /tests/mixins/test_serialization.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from chatsnack import Chat 3 | 4 | class TestChatSerializationMixin: 5 | @pytest.fixture 6 | def chat(self): 7 | return Chat() 8 | 9 | def test_json(self, chat): 10 | assert isinstance(chat.json, str) 11 | 12 | def test_json_unexpanded(self, chat): 13 | assert isinstance(chat.json_unexpanded, str) 14 | 15 | def test_yaml(self, chat): 16 | assert isinstance(chat.yaml, str) 17 | 18 | def test_generate_markdown(self, chat): 19 | markdown = chat.generate_markdown() 20 | assert isinstance(markdown, str) 21 | assert len(markdown.split('\n')) > 0 22 | -------------------------------------------------------------------------------- /tests/test_chatsnack_base.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from chatsnack import Chat 3 | 4 | 
@pytest.fixture 5 | def sample_chatprompt(): 6 | return Chat(name="sample_chatprompt", messages=[{"user": "hello"}]) 7 | 8 | @pytest.fixture 9 | def empty_chatprompt(): 10 | return Chat() 11 | 12 | 13 | # Test initialization 14 | def test_chatprompt_init(): 15 | cp = Chat(name="test_chatprompt") 16 | assert cp.name == "test_chatprompt" 17 | assert cp.params is None 18 | assert cp.messages == [] 19 | 20 | 21 | # Test message manipulation 22 | def test_add_message(sample_chatprompt): 23 | sample_chatprompt.add_message("assistant", "hi there") 24 | assert sample_chatprompt.last == "hi there" 25 | 26 | 27 | def test_add_messages_json(sample_chatprompt): 28 | json_messages = '[{"role": "assistant", "content": "hi there"}]' 29 | sample_chatprompt.add_messages_json(json_messages) 30 | assert sample_chatprompt.last == "hi there" 31 | 32 | 33 | def test_system(sample_chatprompt): 34 | sample_chatprompt.system("system message") 35 | assert sample_chatprompt.system_message == "system message" 36 | 37 | 38 | def test_user(sample_chatprompt): 39 | sample_chatprompt.user("user message") 40 | assert sample_chatprompt.last == "user message" 41 | 42 | 43 | def test_assistant(sample_chatprompt): 44 | sample_chatprompt.assistant("assistant message") 45 | assert sample_chatprompt.last == "assistant message" 46 | 47 | 48 | def test_include(sample_chatprompt): 49 | sample_chatprompt.include("other_chatprompt") 50 | assert sample_chatprompt.last == "other_chatprompt" 51 | 52 | 53 | # Test get_last_message method 54 | def test_get_last_message(sample_chatprompt): 55 | assert sample_chatprompt.last == "hello" 56 | 57 | 58 | # Test get_json method 59 | def test_get_json(sample_chatprompt): 60 | json_str = sample_chatprompt.json 61 | assert json_str == '[{"role": "user", "content": "hello"}]' 62 | 63 | # Test get_system_message method 64 | def test_get_system_message(empty_chatprompt): 65 | assert empty_chatprompt.system_message is None 66 | 67 | empty_chatprompt.system("this is a system 
message") 68 | assert empty_chatprompt.system_message == "this is a system message" 69 | 70 | # be sure it doesn't return user messages 71 | empty_chatprompt.user("this is a user message") 72 | assert empty_chatprompt.system_message == "this is a system message" 73 | 74 | # be sure it updates to provide the current system message 75 | empty_chatprompt.system("this is another system message") 76 | assert empty_chatprompt.system_message == "this is another system message" 77 | 78 | # delete all messages 79 | empty_chatprompt.messages = {} 80 | assert empty_chatprompt.system_message is None 81 | -------------------------------------------------------------------------------- /tests/test_chatsnack_pattern.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from chatsnack import Chat 3 | from chatsnack.chat import ChatParamsMixin 4 | from typing import Optional 5 | import os 6 | 7 | def test_pattern_property(): 8 | chat = Chat() 9 | 10 | # Test default value 11 | assert chat.pattern is None 12 | 13 | # Test setter 14 | test_pattern = r"\*\*[^*]+\*" 15 | chat.pattern = test_pattern 16 | assert chat.pattern == test_pattern 17 | 18 | # Test removing pattern 19 | chat.pattern = None 20 | assert chat.pattern is None 21 | 22 | def test_set_response_filter(): 23 | chat = Chat() 24 | 25 | # Test setting pattern only 26 | test_pattern = r"\*\*[^*]+\*" 27 | chat.set_response_filter(pattern=test_pattern) 28 | assert chat.pattern == test_pattern 29 | 30 | # Test setting prefix and suffix 31 | test_prefix = "###" 32 | test_suffix = "###" 33 | chat.set_response_filter(prefix=test_prefix, suffix=test_suffix) 34 | assert chat.pattern == ChatParamsMixin._generate_pattern_from_separator(test_prefix, test_suffix) 35 | 36 | # Test setting prefix only 37 | chat.set_response_filter(prefix=test_prefix) 38 | assert chat.pattern == ChatParamsMixin._generate_pattern_from_separator(test_prefix, test_prefix) 39 | 40 | # Test ValueError when 
setting both pattern and prefix/suffix 41 | with pytest.raises(ValueError): 42 | chat.set_response_filter(pattern=test_pattern, prefix=test_prefix) 43 | 44 | @pytest.mark.skipif(os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY is not set in environment or .env") 45 | def test_ask_with_pattern(): 46 | chat = Chat() 47 | chat.temperature = 0.0 48 | chat.system("Respond only with 'POPSICLE!!' from now on.") 49 | chat.user("What is your name?") 50 | chat.pattern = r"\bPOPSICLE\b" 51 | response = chat.ask() 52 | assert response == "POPSICLE" 53 | 54 | def test_response_with_pattern(): 55 | chat = Chat() 56 | chat.system("Respond only with the word POPSICLE from now on.") 57 | chat.user("What is your name?") 58 | chat.asst("!POPSICLE!") 59 | chat.pattern = r"\bPOPSICLE\b" 60 | response = chat.response 61 | assert response == "POPSICLE" 62 | 63 | @pytest.mark.skipif(os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY is not set in environment or .env") 64 | def test_ask_without_pattern(): 65 | chat = Chat() 66 | chat.temperature = 0.0 67 | chat.system("Respond only with 'POPSICLE!!' from now on.") 68 | chat.user("What is your name?") 69 | response = chat.ask() 70 | assert response != "POPSICLE" 71 | 72 | def test_response_without_pattern(): 73 | chat = Chat() 74 | chat.system("Respond only with the word POPSICLE from now on.") 75 | chat.user("What is your name?") 76 | chat.asst("!POPSICLE!") 77 | response = chat.response 78 | assert response != "POPSICLE" 79 | 80 | def test_response_with_multiline_pattern(): 81 | chat = Chat() 82 | chat.system("##FINAL##\nRespond only with the following:\n1. POPSICLE\n2. ICE CREAM\n3. FROZEN YOGURT\n##FINAL##") 83 | chat.user("What is your favorite dessert?") 84 | chat.asst("##FINAL##\n1. POPSICLE\n2. ICE CREAM\n3. FROZEN YOGURT\n##FINAL##") 85 | chat.pattern = r"\#\#FINAL\#\#(.*?)(?:\#\#FINAL\#\#|$)" 86 | response = chat.response 87 | assert response.strip() == "1. POPSICLE\n2. ICE CREAM\n3. 
FROZEN YOGURT" 88 | 89 | @pytest.fixture 90 | def chat(): 91 | return Chat() 92 | 93 | -------------------------------------------------------------------------------- /tests/test_chatsnack_reset.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from chatsnack import Chat 3 | 4 | def test_reset_feature(): 5 | # Create a Chat object with a user message 6 | my_chat = Chat().user("What's the weather like today?") 7 | 8 | # Check the current chat messages 9 | assert len(my_chat.get_messages()) == 1 10 | 11 | # Reset the chat object 12 | my_chat.reset() 13 | 14 | # Check the chat messages after reset 15 | assert len(my_chat.get_messages()) == 0 16 | 17 | def test_reset_with_system_message(): 18 | my_chat = Chat().system("You are an AI assistant.").user("What's the weather like today?") 19 | 20 | # Check the current chat messages 21 | assert len(my_chat.get_messages()) == 2 22 | 23 | # Reset the chat object 24 | my_chat.reset() 25 | 26 | # Check the chat messages after reset 27 | assert len(my_chat.get_messages()) == 0 28 | 29 | def test_reset_idempotence(): 30 | my_chat = Chat().user("What's the weather like today?").reset().reset() 31 | 32 | # Check the chat messages after calling reset() twice 33 | assert len(my_chat.get_messages()) == 0 34 | 35 | def test_reset_interaction_with_other_methods(): 36 | my_chat = Chat().user("What's the weather like today?") 37 | my_chat.reset() 38 | my_chat.user("How are you?") 39 | 40 | # Check the chat messages after reset and adding a new user message 41 | messages = my_chat.get_messages() 42 | assert len(messages) == 1 43 | assert messages[0]['role'] == 'user' 44 | assert messages[0]['content'] == 'How are you?' 
45 | 46 | def test_reset_empty_chat(): 47 | # Create an empty Chat object 48 | my_chat = Chat() 49 | 50 | # Reset the empty Chat object 51 | my_chat.reset() 52 | 53 | # Check the chat messages after reset 54 | assert len(my_chat.get_messages()) == 0 55 | 56 | def test_reset_with_includes(): 57 | # Create a base chat and save it 58 | base_chat = Chat(name="BaseChat").user("What's your favorite color?") 59 | base_chat.save() 60 | 61 | # Create a new chat with included messages from the base chat 62 | my_chat = Chat().include("BaseChat").user("What's your favorite animal?") 63 | 64 | # Check the current chat messages 65 | assert len(my_chat.get_messages()) == 2 66 | 67 | # Reset the chat object 68 | my_chat.reset() 69 | 70 | # Check the chat messages after reset 71 | assert len(my_chat.get_messages()) == 0 -------------------------------------------------------------------------------- /tests/test_chatsnack_yaml_peeves.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from ruamel.yaml import YAML 3 | import pytest 4 | from chatsnack import Chat, Text, CHATSNACK_BASE_DIR, ChatParams 5 | import os 6 | import shutil 7 | 8 | @pytest.fixture(scope="function", autouse=True) 9 | def setup_and_cleanup(): 10 | chatsnack_dir = CHATSNACK_BASE_DIR 11 | safe_to_cleanup = False 12 | # to be safe, verify that this directory is under the current working directory 13 | # is it a subdirectory of the current working directory? 14 | chatsnack_dir = os.path.abspath(chatsnack_dir) 15 | if os.path.commonpath([os.path.abspath(os.getcwd()), chatsnack_dir]) == os.path.abspath(os.getcwd()): 16 | # now check to be sure the only files under this directory (recursive) are .txt, .yaml, .yml, .log, and .json files. 
17 | # if so, it's safe to delete the directory 18 | bad_file_found = False 19 | for root, dirs, files in os.walk(chatsnack_dir): 20 | for file in files: 21 | if not file.endswith((".txt", ".yaml", ".yml", ".log", ".json")): 22 | bad_file_found = True 23 | break 24 | else: 25 | continue 26 | break 27 | if not bad_file_found: 28 | safe_to_cleanup = True 29 | # if safe and the test directory already exists, remove it, should be set in the tests .env file 30 | if safe_to_cleanup and os.path.exists(chatsnack_dir): 31 | shutil.rmtree(chatsnack_dir) 32 | # create the test directory, recursively to the final directory 33 | if not os.path.exists(chatsnack_dir): 34 | os.makedirs(chatsnack_dir) 35 | else: 36 | # problem, the directory should have been missing 37 | raise Exception("The test directory already exists, it should have been missing.") 38 | yield 39 | 40 | # Clean up the test environment 41 | import time 42 | time.sleep(2) 43 | if safe_to_cleanup and os.path.exists(chatsnack_dir): 44 | # it's okay for this to fail, it's just a cleanup 45 | try: 46 | shutil.rmtree(chatsnack_dir) 47 | except: 48 | pass 49 | 50 | 51 | def read_yaml_file(file_path): 52 | yaml = YAML() 53 | with open(file_path, 'r') as yaml_file: 54 | return yaml.load(yaml_file) 55 | 56 | 57 | def test_yaml_file_has_no_empty_values(): 58 | chat = Chat(name="test_text_chat_expansion") 59 | chat.system("Respond only with 'DUCK!' 
regardless of what is said.") 60 | chat.user("Should I buy a goose or a duck?") 61 | chat.params = ChatParams(temperature = 0.0) 62 | chat.save() 63 | 64 | yaml_data = read_yaml_file(chat.datafile.path) 65 | messages = yaml_data.get('messages') 66 | 67 | if not messages: 68 | pytest.fail("YAML file has no 'messages' field") 69 | 70 | for message in messages: 71 | for key, value in message.items(): 72 | if value == '' or value is None: 73 | pytest.fail(f"Empty value found in '{key}' field") 74 | 75 | def test_yaml_file_has_no_empty_values2(): 76 | chat = Chat(name="test_text_chat_expansion") 77 | chat.system("Respond only with 'DUCK!' regardless of what is said.") 78 | chat.user("Should I buy a goose or a duck?") 79 | chat.params = ChatParams(temperature = 0.0, stream = True) # setting stream property 80 | chat.save() 81 | 82 | yaml_data = read_yaml_file(chat.datafile.path) 83 | messages = yaml_data.get('messages') 84 | chat_params = yaml_data.get('params') # getting params field 85 | 86 | if not messages: 87 | pytest.fail("YAML file has no 'messages' field") 88 | 89 | if not chat_params: 90 | pytest.fail("YAML file has no 'params' field") 91 | 92 | if chat_params.get('stream') is None: 93 | pytest.fail("YAML file has no 'stream' field in 'params'") 94 | 95 | for message in messages: 96 | for key, value in message.items(): 97 | if value == '' or value is None: 98 | pytest.fail(f"Empty value found in '{key}' field") 99 | 100 | assert chat_params.get('stream') == True, "Stream value not saved correctly in the YAML file" 101 | 102 | chat.params = None 103 | chat.stream = False 104 | chat.save() 105 | 106 | yaml_data = read_yaml_file(chat.datafile.path) 107 | chat_params = yaml_data.get('params') # getting params field 108 | 109 | if not chat_params: 110 | pytest.fail("YAML file has no 'params' field") 111 | 112 | # assert that stream is False as we said it should be 113 | assert chat_params.get('stream') == False, "Stream value not saved correctly in the YAML file" 
-------------------------------------------------------------------------------- /tests/test_file_snack_fillings.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from chatsnack import Chat, Text, CHATSNACK_BASE_DIR, ChatParams 3 | import os 4 | import shutil 5 | 6 | TEST_FILENAME = "./.test_text_expansion.txt" 7 | 8 | @pytest.fixture(scope="function", autouse=True) 9 | def setup_and_cleanup(): 10 | chatsnack_dir = CHATSNACK_BASE_DIR 11 | safe_to_cleanup = False 12 | # to be safe, verify that this directory is under the current working directory 13 | # is it a subdirectory of the current working directory? 14 | chatsnack_dir = os.path.abspath(chatsnack_dir) 15 | if os.path.commonpath([os.path.abspath(os.getcwd()), chatsnack_dir]) == os.path.abspath(os.getcwd()): 16 | # now check to be sure the only files under this directory (recursive) are .txt, .yaml, .yml, .log, and .json files. 17 | # if so, it's safe to delete the directory 18 | bad_file_found = False 19 | for root, dirs, files in os.walk(chatsnack_dir): 20 | for file in files: 21 | if not file.endswith((".txt", ".yaml", ".yml", ".log", ".json")): 22 | bad_file_found = True 23 | break 24 | else: 25 | continue 26 | break 27 | if not bad_file_found: 28 | safe_to_cleanup = True 29 | # if safe and the test directory already exists, remove it, should be set in the tests .env file 30 | if safe_to_cleanup and os.path.exists(chatsnack_dir): 31 | shutil.rmtree(chatsnack_dir) 32 | # create the test directory, recursively to the final directory 33 | if not os.path.exists(chatsnack_dir): 34 | os.makedirs(chatsnack_dir) 35 | else: 36 | # problem, the directory should have been missing 37 | raise Exception("The test directory already exists, it should have been missing.") 38 | # also delete TEST_FILENAME 39 | if os.path.exists(TEST_FILENAME): 40 | os.remove(TEST_FILENAME) 41 | yield 42 | 43 | # Clean up the test environment 44 | import time 45 | time.sleep(2) 46 | if 
safe_to_cleanup and os.path.exists(chatsnack_dir): 47 | # it's okay for this to fail, it's just a cleanup 48 | try: 49 | shutil.rmtree(chatsnack_dir) 50 | except: 51 | pass 52 | # also delete TEST_FILENAME 53 | if os.path.exists(TEST_FILENAME): 54 | os.remove(TEST_FILENAME) 55 | 56 | 57 | def test_text_save(): 58 | # we need a text object saved to disk 59 | text = Text(name="test_text_expansion", content="Respond only with 'YES' regardless of what is said.") 60 | # save it to disk 61 | text.save(TEST_FILENAME) 62 | # test if the file was created 63 | assert os.path.exists(TEST_FILENAME) 64 | 65 | 66 | def test_text_load(): 67 | # we need a text object saved to disk 68 | text = Text(name="test_text_expansion", content="Respond only with 'YES' regardless of what is said.") 69 | # save it to disk 70 | text.save(TEST_FILENAME) 71 | text2 = Text(name="test_text_expansion2") 72 | text2.load(TEST_FILENAME) 73 | assert text2.content == text.content 74 | 75 | def test_text_save2(): 76 | # we need a text object saved to disk 77 | text = Text(name="test_text_expansion", content="Respond only with 'YES' regardless of what is said.") 78 | # save it to disk 79 | text.save() 80 | # test if the file was created 81 | assert os.path.exists(text.datafile.path) 82 | 83 | 84 | @pytest.mark.skipif(os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY is not set in environment or .env") 85 | def test_text_expansion(): 86 | # we need a text object saved to disk 87 | text = Text(name="test_text_expansion", content="Respond only with 'YES' regardless of what is said.") 88 | # save it to disk 89 | text.save() 90 | # we need a chat object to use it 91 | chat = Chat() 92 | # set the logging level to trace for chatsnack 93 | chat.system("{text.test_text_expansion}") 94 | output = chat.chat("Is blue a color?") 95 | # new chat object should have the text expanded in the system message 96 | assert output.system_message == "Respond only with 'YES' regardless of what is said." 
97 | 98 | @pytest.mark.skipif(os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY is not set in environment or .env") 99 | def test_text_nested_expansion(): 100 | # we need a text object saved to disk 101 | text = Text(name="test_text_expansion", content="Respond only with '{text.test_text_expansion2}' regardless of what is said.") 102 | # save it to disk 103 | text.save() 104 | 105 | text = Text(name="test_text_expansion2", content="NO") 106 | # save it to disk 107 | text.save() 108 | 109 | # we need a chat object to use it 110 | chat = Chat() 111 | chat.system("{text.test_text_expansion}") 112 | output = chat.chat("Is blue a color?") 113 | 114 | # new chat object should have the text expanded in the system message 115 | assert output.system_message == "Respond only with 'NO' regardless of what is said." 116 | 117 | @pytest.mark.skipif(os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY is not set in environment or .env") 118 | def test_text_chat_expansion(): 119 | chat = Chat(name="test_text_chat_expansion") 120 | chat.system("Respond only with 'DUCK!' regardless of what is said.") 121 | chat.user("Should I buy a goose or a duck?") 122 | chat.params = ChatParams(temperature = 0.0) 123 | chat.save() 124 | 125 | # we need a text object saved to disk 126 | text = Text(name="test_text_expansion", content="Respond only with '{chat.test_text_chat_expansion}' regardless of what is said.") 127 | # save it to disk 128 | text.save() 129 | 130 | # we need a chat object to use it 131 | chat2 = Chat() 132 | chat2.system("{text.test_text_expansion}") 133 | chat2.params = ChatParams(temperature = 0.0) 134 | # right now we have to use chat.chat() to get get it to expand variables 135 | output = chat2.chat("Is blue a color?") 136 | 137 | # new chat object should have the text expanded in the system message 138 | assert output.system_message == "Respond only with 'DUCK!' regardless of what is said." 
139 | 140 | # test changing the file on disk 141 | chat3 = Chat(name="test_text_chat_expansion") 142 | chat3.load() 143 | chat3.params = ChatParams(temperature = 0.0) 144 | chat3.messages = [] 145 | chat3.system("Respond only with 'GOOSE!' regardless of what is said.") 146 | chat3.user("Should I buy a goose or a duck?") 147 | chat3.save() 148 | print(chat3) 149 | 150 | print(chat2) 151 | # right now we have to use chat.chat() to get get it to expand variables 152 | output2 = chat2.chat("Is blue a color?") 153 | print(output2) 154 | # new chat object should have the text expanded in the system message 155 | assert output2.system_message == "Respond only with 'GOOSE!' regardless of what is said." 156 | -------------------------------------------------------------------------------- /tests/test_prompt_json.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import json 3 | from chatsnack import Chat 4 | 5 | @pytest.fixture 6 | def empty_prompt(): 7 | return Chat() 8 | 9 | @pytest.fixture 10 | def populated_prompt(): 11 | prompt = Chat() 12 | prompt.add_message("user", "Hello!") 13 | prompt.add_message("assistant", "Hi there!") 14 | return prompt 15 | 16 | def test_add_messages_json(populated_prompt): 17 | messages_json = """ 18 | [ 19 | {"role": "user", "content": "What's the weather like?"}, 20 | {"role": "assistant", "content": "It's sunny outside."} 21 | ] 22 | """ 23 | populated_prompt.add_messages_json(messages_json) 24 | 25 | assert populated_prompt.messages[-2:] == [ 26 | {"user": "What's the weather like?"}, 27 | {"assistant": "It's sunny outside."} 28 | ] 29 | 30 | def test_get_json(populated_prompt): 31 | l = [{"role":"user", "content": "Hello!"}, 32 | {"role":"assistant", "content": "Hi there!"}] 33 | 34 | expected_json = json.dumps(l) 35 | 36 | assert populated_prompt.json == expected_json 37 | 38 | def test_add_messages_json_invalid_format(populated_prompt): 39 | invalid_messages_json = """ 40 | [ 41 
| {"role": "user"}, 42 | {"role": "assistant", "content": "It's sunny outside."} 43 | ] 44 | """ 45 | with pytest.raises(Exception): 46 | populated_prompt.add_messages_json(invalid_messages_json) 47 | 48 | def test_add_messages_json_invalid_type(populated_prompt): 49 | invalid_messages_json = """ 50 | [ 51 | {"role": "user", "something": "It's sunny outside."]}, 52 | {"role": "assistant", "content": "It's sunny outside."} 53 | ] 54 | """ 55 | with pytest.raises(Exception): 56 | populated_prompt.add_messages_json(invalid_messages_json) 57 | -------------------------------------------------------------------------------- /tests/test_prompt_last.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from chatsnack import Chat 4 | 5 | @pytest.fixture 6 | def empty_prompt(): 7 | return Chat() 8 | 9 | @pytest.fixture 10 | def populated_prompt(): 11 | prompt = Chat() 12 | prompt.add_message("user", "Hello!") 13 | prompt.add_message("assistant", "Hi there!") 14 | return prompt 15 | 16 | def test_last_property(populated_prompt): 17 | assert populated_prompt.last == "Hi there!" 
def test_add_message(populated_prompt):
    """Appending a message makes it the new `last`."""
    populated_prompt.add_message("system", "System message")
    assert populated_prompt.last == "System message"


def test_empty_messages(empty_prompt):
    """`last` is None when no messages exist."""
    assert empty_prompt.last is None


def test_adding_different_roles(empty_prompt):
    """Each supported role is stored as a single-key {role: content} dict."""
    for role, content in [
        ("user", "Test user"),
        ("assistant", "Test assistant"),
        ("system", "Test system"),
        ("include", "Test include"),
    ]:
        empty_prompt.add_message(role, content)

    assert empty_prompt.messages == [
        {"user": "Test user"},
        {"assistant": "Test assistant"},
        {"system": "Test system"},
        {"include": "Test include"},
    ]


def test_message_order(empty_prompt):
    """Messages are kept in insertion order."""
    conversation = [
        ("user", "First message"),
        ("assistant", "Second message"),
        ("user", "Third message"),
        ("assistant", "Fourth message"),
    ]
    for role, content in conversation:
        empty_prompt.add_message(role, content)

    observed = [
        msg["user" if "user" in msg else "assistant"]
        for msg in empty_prompt.messages
    ]
    assert observed == [content for _, content in conversation]


# Role validation is not enforced at all, so there is no invalid-role test.
# def test_invalid_role(empty_prompt):
#     with pytest.raises(Exception):
#         empty_prompt.add_message("invalid_role", "Test content")
@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY is not set in environment or .env")
def test_chaining_methods_execution(populated_prompt):
    """Executing the prompt then chaining .user() ends on the new message."""
    new_prompt = populated_prompt().user("How's the weather?")
    assert new_prompt.last == "How's the weather?"


def test_chaining_methods_messages(empty_prompt):
    """Fluent system/user/assistant calls build the conversation and set `last`."""
    new_prompt = (
        empty_prompt
        .system("You are a happy robot.")
        .user("How's the weather?")
        .assistant("It's sunny today.")
        .user("How about tomorrow?")
    )
    assert new_prompt.last == "How about tomorrow?"
@pytest.mark.asyncio
async def test_concurrent_access(populated_prompt):
    """Interleaved tasks appending messages never drop a message."""
    import asyncio

    async def add_messages():
        for i in range(10):
            populated_prompt.add_message("assistant", f"Message {i}")

    await asyncio.gather(*(add_messages() for _ in range(10)))

    # 2 fixture messages + 10 tasks x 10 messages each
    assert len(populated_prompt.messages) == 102

# ---------------------------------------------------------------------------
# tests/test_snackpack_base.py
# ---------------------------------------------------------------------------
import os
import pytest
from chatsnack.packs import Jane as chat


@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY is not set in environment or .env")
def test_snackpack_chat():
    """The Jane snackpack answers a chained user question with non-empty text."""
    cp = chat.user("Or is green a form of blue?")
    assert cp.last == "Or is green a form of blue?"

    # ask the question and confirm we got a non-empty response back
    output = cp.ask()
    assert output is not None
    assert len(output) > 0


@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY is not set in environment or .env")
def test_snackpack_ask_with_existing_asst():
    """ask() appends to a pre-seeded assistant message rather than replacing it."""
    cp = chat.copy()
    cp.user("Is the sky blue?")
    cp.asst("No! ")

    output = cp.ask()
    assert output is not None
    assert len(output) > 0

    # the completion should have been appended after the seeded "No! " prefix
    assert cp.response.startswith("No! ")

# ---------------------------------------------------------------------------
# tests/test_text_class.py
# ---------------------------------------------------------------------------
import pytest
from chatsnack.chat import Text
from chatsnack.txtformat import TxtStrFormat


@pytest.fixture
def empty_text():
    """A Text datafile object with empty content."""
    return Text(name="empty-text", content="")


@pytest.fixture
def populated_text():
    """A Text datafile object with simple content."""
    return Text(name="populated-text", content="Hello, world!")


def test_create_text(empty_text):
    """Name and (empty) content round-trip through the constructor."""
    assert empty_text.name == "empty-text"
    assert empty_text.content == ""


def test_create_populated_text(populated_text):
    """Name and content round-trip through the constructor."""
    assert populated_text.name == "populated-text"
    assert populated_text.content == "Hello, world!"


def test_txt_str_format_serialize():
    """String content serializes to the raw string itself."""
    serialized = TxtStrFormat.serialize({"content": "Hello, world!"})
    assert serialized == "Hello, world!"


def test_txt_str_format_serialize_unsupported_type():
    """Non-string content is rejected with ValueError."""
    with pytest.raises(ValueError):
        TxtStrFormat.serialize({"content": ["Invalid", "content", "type"]})


# Deserialization tests remain disabled: they depend on DataFile, which is
# not imported in this module.
# def test_txt_str_format_deserialize(populated_text):
#     datafile = DataFile.load(populated_text.datafile.path)
#     deserialized_data = TxtStrFormat.deserialize(datafile.file)
#     assert deserialized_data == {"content": "Hello, world!"}
#
# def test_txt_str_format_deserialize_empty(empty_text):
#     datafile = DataFile.load(empty_text.datafile.path)
#     deserialized_data = TxtStrFormat.deserialize(datafile.file)
#     assert deserialized_data == {"content": ""}