├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml ├── src └── promptrix │ ├── AssistantMessage.py │ ├── ConversationHistory.py │ ├── FunctionRegistry.py │ ├── GPT3Tokenizer.py │ ├── GroupSection.py │ ├── LayoutEngine.py │ ├── Prompt.py │ ├── PromptSectionBase.py │ ├── SystemMessage.py │ ├── TemplateSection.py │ ├── TextSection.py │ ├── UserMessage.py │ ├── Utilities.py │ ├── VolatileMemory.py │ ├── __init__.py │ └── promptrixTypes.py └── tests ├── ConversationHistoryTest.py ├── FunctionRegistryTest.py ├── PromptSectionBaseTest.py ├── TemplateSectionTest.py └── VolatileMemoryTest.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Steven Ickman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # promptrix-py 2 | Promptrix is a prompt layout engine for Large Language Models. 
 3 | 
 4 | # Here is a first trivial example: 
 5 | import asyncio 
 6 | from promptrix import promptrixTypes, VolatileMemory, FunctionRegistry, GPT3Tokenizer 
 7 | from promptrix.Prompt import Prompt 
 8 | from promptrix.SystemMessage import SystemMessage 
 9 | from promptrix.UserMessage import UserMessage 
10 | from promptrix.AssistantMessage import AssistantMessage 
11 | from promptrix.ConversationHistory import ConversationHistory 
12 | 
13 | functions = FunctionRegistry() 
14 | tokenizer = GPT3Tokenizer() 
15 | memory = VolatileMemory({'input':'', 'history':[]}) 
16 | max_tokens = 2000 
17 | 
18 | prompt_text = 'You are helpful, creative, clever, and very friendly. ' 
19 | PROMPT = Prompt([ 
20 |     UserMessage(prompt_text), 
21 |     ConversationHistory('history', .5), # allow history to use up 1/2 the remaining token budget left after the prompt and input 
22 |     UserMessage('{{$input}}') 
23 | ]) 
24 | 
25 | async def render_messages_completion(): 
26 |     as_msgs = await PROMPT.renderAsMessages(memory, functions, tokenizer, max_tokens) 
27 |     msgs = [] 
28 |     if not as_msgs.tooLong: 
29 |         msgs = as_msgs.output 
30 |     return msgs 
31 | 
32 | ### basic chat loop 
33 | while True: 
34 |     query = input('User: '); memory.set('input', query) 
35 |     msgs = asyncio.run(render_messages_completion()) 
36 |     response = ... your favorite llm api (model, msgs, ...)
37 | print(response) 38 | history = memory.get('history') 39 | history.append({'role':USER_PREFIX, 'content': query}) 40 | history.append({'role': ASSISTANT_PREFIX, 'content': response}) 41 | memory.set('history', history) 42 | 43 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "promptrix" 7 | version = "0.3.1" 8 | authors = [ 9 | { name="Steven Ickman", email="author@example.com" }, 10 | { name="Bruce DAmbrosio", email="bruce.dambrosio@gmail.com" }, 11 | ] 12 | 13 | description = "Promptrix. A prompt layout manager for LLMs" 14 | #documentation = "https://github.com/Stevenic/promptrix-py/README.md" 15 | 16 | readme = "README.md" 17 | 18 | requires-python = ">=3.8" 19 | 20 | classifiers = [ 21 | "Programming Language :: Python :: 3", 22 | "License :: OSI Approved :: MIT License", 23 | "Operating System :: OS Independent", 24 | ] 25 | 26 | dependencies = [ 27 | "requests", 28 | "tiktoken", 29 | "pyyaml", 30 | 'importlib-metadata; python_version<"3.9"', 31 | ] 32 | 33 | 34 | [project.urls] 35 | "Homepage" = "https://tuuyi.io/promptrix" 36 | "Bug Tracker" = "https://github.com/Stevenic/promptrix-py/issues" 37 | -------------------------------------------------------------------------------- /src/promptrix/AssistantMessage.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import promptrix.TemplateSection as TemplateSection 3 | 4 | class AssistantMessage(TemplateSection.TemplateSection): 5 | """ 6 | A message sent by the assistant. 7 | """ 8 | def __init__(self, template: str, tokens: Optional[int] = -1, assistant_prefix: Optional[str] = 'assistant'): 9 | """ 10 | Creates a new 'AssistantMessage' instance. 
from promptrix.promptrixTypes import Message, PromptFunctions, PromptMemory, RenderedPromptSection, Tokenizer
from promptrix.PromptSectionBase import PromptSectionBase
from promptrix.Utilities import Utilities

class ConversationHistory(PromptSectionBase):
    """Renders a conversation history stored in memory, newest messages first.

    Messages are taken from the end of the history list and kept until the
    token budget is exhausted, so the most recent exchanges always survive.
    """

    def __init__(self, variable, tokens=1.0, required=False, userPrefix='user', assistantPrefix='assistant', separator='\n'):
        """
        Creates a new 'ConversationHistory' instance.

        :param variable: Memory key holding the history (a list of {'role','content'} dicts).
        :param tokens: Sizing strategy; values > 1.0 act as a fixed token cap,
            otherwise the section may use the full budget passed by the caller.
        :param required: When True the newest message is always kept, even over budget.
        :param userPrefix: Prefix for user turns when rendering as text.
        :param assistantPrefix: Prefix for assistant turns when rendering as text.
        :param separator: Separator between rendered turns.
        """
        super().__init__(tokens, required, separator)
        self.variable = variable
        self.userPrefix = userPrefix
        self.assistantPrefix = assistantPrefix

    def renderAsText(self, memory, functions, tokenizer, maxTokens):
        """Render the history as prefixed text lines joined by the separator."""
        history = memory.get(self.variable)
        if history is None:
            history = []
        tokens = 0
        # Fixed cap when tokens > 1.0; otherwise spend whatever the caller allows.
        budget = min(self.tokens, maxTokens) if self.tokens > 1.0 else maxTokens
        separatorLength = len(tokenizer.encode(self.separator))
        lines = []
        for i in range(len(history) - 1, -1, -1):
            msg = history[i]
            message = Utilities.to_string(tokenizer, msg['content'])
            prefix = self.userPrefix if msg['role'] == 'user' else self.assistantPrefix
            # BUG FIX: Utilities.to_string() returns a plain string, but the old
            # code appended `message.content`, which raised AttributeError.
            line = prefix + message
            length = len(tokenizer.encode(line)) + (separatorLength if len(lines) > 0 else 0)
            # A required section always keeps the newest message, even over budget.
            if len(lines) == 0 and self.required:
                tokens += length
                lines.insert(0, line)
                continue
            if tokens + length > budget:
                break
            tokens += length
            lines.insert(0, line)
        return RenderedPromptSection(output=self.separator.join(lines), length=tokens, tooLong=tokens > maxTokens)

    def renderAsMessages(self, memory, functions, tokenizer, maxTokens):
        """Render the history as a list of {'role','content'} message dicts."""
        history = memory.get(self.variable)
        if history is None:
            history = []
        tokens = 0
        budget = min(self.tokens, maxTokens) if self.tokens > 1.0 else maxTokens
        messages = []
        for i in range(len(history) - 1, -1, -1):
            msg = history[i]
            message = {'role': msg['role'], 'content': Utilities.to_string(tokenizer, msg['content'])}
            length = len(tokenizer.encode(message['content']))
            # A required section always keeps the newest message, even over budget.
            if len(messages) == 0 and self.required:
                tokens += length
                messages.insert(0, message)
                continue
            if tokens + length > budget:
                break
            tokens += length
            messages.insert(0, message)

        return RenderedPromptSection(output=messages, length=tokens, tooLong=tokens > maxTokens)
11 | """ 12 | self._functions = {} 13 | if functions: 14 | for key, value in functions.items(): 15 | self._functions[key] = value 16 | 17 | def has(self, name: str) -> bool: 18 | return name in self._functions 19 | 20 | def get(self, name: str) -> Callable: 21 | fn = self._functions.get(name) 22 | if not fn: 23 | raise Exception(f"Function {name} not found.") 24 | return fn 25 | 26 | def addFunction(self, name: str, value: Callable) -> None: 27 | if self.has(name): 28 | raise Exception(f"Function '{name}' already exists.") 29 | self._functions[name] = value 30 | 31 | def invoke(self, key: str, memory: Any, functions: Any, tokenizer: Any, args: List[str]) -> Any: 32 | fn = self.get(key) 33 | return fn(memory, functions, tokenizer, args) 34 | -------------------------------------------------------------------------------- /src/promptrix/GPT3Tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | #from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer 3 | import tiktoken 4 | enc = tiktoken.get_encoding("cl100k_base") 5 | assert enc.decode(enc.encode("hello world")) == "hello world" 6 | 7 | class GPT3Tokenizer: 8 | def __init__(self): 9 | self.ttk = tiktoken.get_encoding("cl100k_base") 10 | #self.ttk = tiktoken.encoding_for_model("gpt4") 11 | 12 | def decode(self, tokens) -> str: 13 | return self.ttk.decode(tokens) 14 | 15 | def encode(self, text) -> List[int]: 16 | return self.ttk.encode(text) 17 | -------------------------------------------------------------------------------- /src/promptrix/GroupSection.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from promptrix.promptrixTypes import Message, PromptFunctions, PromptMemory, PromptSection, RenderedPromptSection, Tokenizer 3 | from promptrix.PromptSectionBase import PromptSectionBase 4 | from 
from typing import List, TypeVar, Optional, Callable, Union
from types import FunctionType
import asyncio

T = TypeVar('T')

class RenderedPromptSection:
    """Result of rendering a section: output, its token length, and an overflow flag."""
    def __init__(self, output: T, length: int, tooLong: bool):
        self.output = output
        self.length = length
        self.tooLong = tooLong

class PromptSectionLayout:
    """Pairs a section with its rendered result (layout is None until rendered)."""
    def __init__(self, section: 'PromptSection', layout=None):
        self.section = section
        self.layout = layout

class PromptSection:
    """Base data holder for a renderable prompt section."""
    def __init__(self, sections, tokens: int, required: bool, separator: str):
        self.sections = sections
        self.required = required
        self.tokens = tokens
        self.separator = separator

class LayoutEngine(PromptSection):
    """Sizes and renders a list of sections into a fixed token budget.

    Fixed-size sections (tokens < 0 or > 1.0) are rendered first; proportional
    sections (0.0 <= tokens <= 1.0) then share whatever budget remains.
    Optional sections are dropped, last first, when the budget is exceeded.
    """

    def __init__(self, sections: List[PromptSection], tokens: int, required: bool, separator: str):
        super().__init__(sections, tokens, required, separator)

    def renderAsText(self, memory, functions, tokenizer, maxTokens):
        """Render all sections and join their text output with the separator."""
        layout = []
        self.addSectionsToLayout(self.sections, layout)

        remaining = self.layoutSections(
            layout,
            maxTokens,
            lambda section: section.renderAsText(memory, functions, tokenizer, maxTokens),
            lambda section, remaining: section.renderAsText(memory, functions, tokenizer, remaining),
            True,
            tokenizer
        )

        output = [section.layout.output for section in layout if section.layout]
        text = self.separator.join(output)
        return RenderedPromptSection(text, len(tokenizer.encode(text)), remaining < 0)

    def renderAsMessages(self, memory: 'PromptMemory', functions: 'PromptFunctions', tokenizer: 'Tokenizer', maxTokens: int) -> RenderedPromptSection:
        """Render all sections and concatenate their message lists."""
        layout = []
        self.addSectionsToLayout(self.sections, layout)

        remaining = self.layoutSections(
            layout,
            maxTokens,
            lambda section: section.renderAsMessages(memory, functions, tokenizer, maxTokens),
            lambda section, remaining: section.renderAsMessages(memory, functions, tokenizer, remaining)
        )

        output = [message for section in layout if section.layout for message in section.layout.output]
        return RenderedPromptSection(output, self.getLayoutLength(layout), remaining < 0)

    def addSectionsToLayout(self, sections: List[PromptSection], layout: List):
        """Flatten nested LayoutEngines into a single flat layout list."""
        for section in sections:
            if isinstance(section, LayoutEngine):
                self.addSectionsToLayout(section.sections, layout)
            else:
                layout.append(PromptSectionLayout(section))

    def layoutSections(self, layout, maxTokens, cbFixed, cbProportional, textLayout=False, tokenizer=None):
        """Two-pass layout: fixed-size sections first, proportional ones second.

        After each pass, optional sections are dropped (last first) while the
        budget is exceeded. Returns the remaining token budget; a negative
        value means the required sections alone overflow maxTokens.
        """
        self.layoutFixedSections(layout, cbFixed)

        remaining = maxTokens - self.getLayoutLength(layout, textLayout, tokenizer)
        while remaining < 0 and self.dropLastOptionalSection(layout):
            remaining = maxTokens - self.getLayoutLength(layout, textLayout, tokenizer)

        if self.needsMoreLayout(layout) and remaining > 0:
            self.layoutProportionalSections(layout, lambda section: cbProportional(section, remaining))

            remaining = maxTokens - self.getLayoutLength(layout, textLayout, tokenizer)
            while remaining < 0 and self.dropLastOptionalSection(layout):
                remaining = maxTokens - self.getLayoutLength(layout, textLayout, tokenizer)

        return remaining

    def layoutFixedSections(self, layout, callback):
        """Render sections with a fixed sizing strategy (tokens < 0 or > 1.0)."""
        # FIX: the original built a throwaway list comprehension purely for its
        # side effects (a leftover of the async JS promise code it was ported
        # from); a plain loop does the same work without the dead list.
        for section in layout:
            if section.section.tokens < 0 or section.section.tokens > 1.0:
                section.layout = callback(section.section)

    def layoutProportionalSections(self, layout, callback):
        """Render sections sized as a proportion of the budget (0.0 <= tokens <= 1.0)."""
        # FIX: same side-effect-comprehension cleanup as layoutFixedSections.
        for section in layout:
            if 0.0 <= section.section.tokens <= 1.0:
                section.layout = callback(section.section)

    def getLayoutLength(self, layout, textLayout=False, tokenizer=None) -> int:
        """Total token length of the current layout.

        For text layout the joined text is re-encoded so separator tokens are
        counted; for message layout the per-section lengths are summed.
        """
        if textLayout and tokenizer:
            output = [section.layout.output for section in layout if section.layout]
            return len(tokenizer.encode(self.separator.join(output)))
        else:
            return sum(section.layout.length for section in layout if section.layout)

    def dropLastOptionalSection(self, layout) -> bool:
        """Remove the last non-required section; returns False if none remain."""
        for i in range(len(layout) - 1, -1, -1):
            if not layout[i].section.required:
                layout.pop(i)
                return True
        return False

    def needsMoreLayout(self, layout) -> bool:
        """True while any section has not been rendered yet."""
        return any(not section.layout for section in layout)
from abc import ABC, abstractmethod
from typing import List, Tuple, Any
#from promptrixTypes import Message, PromptFunctions, PromptMemory, PromptSection, RenderedPromptSection
from promptrix.promptrixTypes import RenderedPromptSection, Message
import promptrix.GPT3Tokenizer as Tokenizer
import traceback

class PromptSectionBase():
    """Base class for prompt sections.

    Subclasses implement renderAsMessages(); renderAsText() is derived from it.
    `tokens` is the sizing strategy: -1 = auto, values <= 1.0 = proportional
    share of the budget, values > 1.0 = fixed token cap.

    NOTE(review): the class does not inherit from ABC, so @abstractmethod below
    is not actually enforced at instantiation time — confirm this is intended.
    """
    def __init__(self, tokens = -1, required = True, separator = '\n', text_prefix = ''):
        # required: if True the layout engine may not drop this section.
        # separator: used for length accounting when rendering as text.
        # text_prefix: prepended to the rendered text output.
        self.required = required
        self.tokens = tokens
        self.separator = separator
        self.text_prefix = text_prefix
        # Guard against an explicit None prefix (empty string is the valid default).
        if text_prefix is None:
            raise Exception

    @abstractmethod
    def renderAsMessages(self, memory, functions, tokenizer, max_tokens):
        """Render the section as a list of messages (must be supplied by subclasses)."""
        pass

    def renderAsText(self, memory, functions, tokenizer, max_tokens):
        """Render the section as text by first rendering it as messages.

        NOTE(review): message bodies are joined with a hard-coded newline while
        the length accounting below uses self.separator — these disagree when a
        custom separator is supplied; confirm which is intended.
        """
        as_messages = self.renderAsMessages(memory, functions, tokenizer, max_tokens)
        messages = as_messages.output
        text = ''
        for message in messages:
            text += message['content']+'\n'
        #text = self.separator.join([message['content'] for message in messages])
        prefix_length = len(tokenizer.encode(self.text_prefix))
        separator_length = len(tokenizer.encode(self.separator))
        # Length = prefix + rendered content + one separator between each pair of messages.
        length = prefix_length + as_messages.length + ((len(as_messages.output) - 1) * separator_length)
        text = self.text_prefix + text
        # Fixed-size sections (tokens > 1.0) are hard-truncated to their cap.
        # (Assumes self.tokens is an int when > 1.0 — it is used as a slice bound.)
        if self.tokens > 1.0 and length > self.tokens:
            encoded = tokenizer.encode(text)
            text = tokenizer.decode(encoded[:self.tokens])
            length = self.tokens
        # Drop the single trailing newline introduced by the join above.
        if text.endswith('\n'):
            text = text[:-1]
        return RenderedPromptSection(output=text, length=length, tooLong=length > max_tokens)

    def return_messages(self, output, length, tokenizer, max_tokens):
        """Trim a message list to this section's fixed token cap, if it has one.

        Messages are popped from the end until under the cap; the last popped
        message is re-appended truncated so the total lands exactly on the cap.
        """
        if self.tokens > 1.0:
            while length > self.tokens:
                msg = output.pop()
                encoded = tokenizer.encode(msg['content'])
                length -= len(encoded)
                if length < self.tokens:
                    # Re-add the truncated head of the popped message.
                    delta = self.tokens - length
                    truncated = tokenizer.decode(encoded[:delta])
                    # Messages may be dicts or objects with a .role attribute.
                    role = msg['role'] if type(msg) == dict else msg.role
                    output.append({'role':role, 'content':truncated})
                    length += delta
        #print(f'PromptSectionBase return_messages {output}')
        return RenderedPromptSection(output=output, length=length, tooLong=length > max_tokens)
12 | """ 13 | super().__init__(template, 'system', tokens, True, '\n', '') 14 | -------------------------------------------------------------------------------- /src/promptrix/TemplateSection.py: -------------------------------------------------------------------------------- 1 | #from promptrixTypes import * 2 | from promptrix.PromptSectionBase import PromptSectionBase 3 | from promptrix.Utilities import Utilities 4 | from typing import List, Callable, Any 5 | from enum import Enum 6 | import asyncio 7 | 8 | def get_mem_str(memory, value): 9 | #print (f'***** TemplateSection create_variable_renderer memory {memory}, value {value}') 10 | return value 11 | 12 | class ParseState(Enum): 13 | IN_TEXT = 1 14 | IN_PARAMETER = 2 15 | IN_STRING = 3 16 | 17 | class TemplateSection(PromptSectionBase): 18 | def __init__(self, template, role, tokens = -1, required = True, separator='\n', text_prefix = ''): 19 | super().__init__(tokens, required, separator, text_prefix) 20 | self.template = template 21 | self.role = role 22 | self._parts = [] 23 | self.parse_template() 24 | #print(f'***** TemplateSection init template {self._parts}') 25 | 26 | def renderAsMessages(self, memory: 'PromptMemory', functions: 'PromptFunctions', tokenizer: 'Tokenizer', max_tokens: int) -> 'RenderedPromptSection[List[Message]]': 27 | #print(f'***** TemplateSection entry {self._parts}') 28 | rendered_parts = [part(memory, functions, tokenizer, max_tokens) for part in self._parts] 29 | text = ''.join(rendered_parts) 30 | #print(f'***** TemplateSection rendered parts {rendered_parts}') 31 | length = len(tokenizer.encode(text)) 32 | #print(f'***** TemplateSection rendered parts {text}') 33 | return self.return_messages([{'role': self.role, 'content': text}], length, tokenizer, max_tokens) 34 | 35 | def parse_template(self): 36 | part = '' 37 | state = ParseState.IN_TEXT 38 | string_delim = '' 39 | skip_next = False 40 | for i in range(len(self.template)): 41 | if skip_next: 42 | skip_next = False 43 | 
continue 44 | char = self.template[i] 45 | if state == ParseState.IN_TEXT: 46 | if char == '{' and self.template[i + 1] == '{': 47 | if len(part) > 0: 48 | self._parts.append(self.create_text_renderer(part)) 49 | part = '' 50 | state = ParseState.IN_PARAMETER 51 | skip_next = True 52 | else: 53 | part += char 54 | elif state == ParseState.IN_PARAMETER: 55 | if char == '}' and self.template[i + 1] == '}': 56 | if len(part) > 0: 57 | if part[0] == '$': 58 | self._parts.append(self.create_variable_renderer(part[1:])) 59 | else: 60 | self._parts.append(self.create_function_renderer(part)) 61 | part = '' 62 | state = ParseState.IN_TEXT 63 | skip_next = True 64 | elif char in ["'", '"', '`']: 65 | string_delim = char 66 | state = ParseState.IN_STRING 67 | part += char 68 | else: 69 | part += char 70 | elif state == ParseState.IN_STRING: 71 | part += char 72 | if char == string_delim: 73 | state = ParseState.IN_PARAMETER 74 | if state != ParseState.IN_TEXT: 75 | raise ValueError(f"Invalid template: {self.template}") 76 | if len(part) > 0: 77 | self._parts.append(self.create_text_renderer(part)) 78 | 79 | 80 | def create_text_renderer(self, text: str) -> Callable[['PromptMemory', 'PromptFunctions', 'Tokenizer', int], 'Promise[str]']: 81 | return lambda memory, functions, tokenizer, max_tokens: text 82 | 83 | def create_variable_renderer(self, name: str) -> Callable[['PromptMemory', 'PromptFunctions', 'Tokenizer', int], 'Promise[str]']: 84 | #print (f'***** TemplateSection create_variable_renderer name {name}') 85 | return lambda memory, functions, tokenizer, max_tokens: get_mem_str(memory, Utilities.to_string(tokenizer, memory.get(name))) 86 | 87 | def create_function_renderer(self, param: str) -> Callable[['PromptMemory', 'PromptFunctions', 'Tokenizer', int], 'Promise[str]']: 88 | name = None 89 | args = [] 90 | part = '' 91 | def save_part(): 92 | nonlocal part, name, args 93 | if len(part) > 0: 94 | if not name: 95 | name = part 96 | else: 97 | args.append(part) 98 | 
part = '' 99 | 100 | state = ParseState.IN_TEXT 101 | string_delim = '' 102 | for i in range(len(param)): 103 | char = param[i] 104 | if state == ParseState.IN_TEXT: 105 | if char in ["'", '"', '`']: 106 | save_part() 107 | string_delim = char 108 | state = ParseState.IN_STRING 109 | elif char == ' ': 110 | save_part() 111 | else: 112 | part += char 113 | elif state == ParseState.IN_STRING: 114 | if char == string_delim: 115 | save_part() 116 | state = ParseState.IN_TEXT 117 | else: 118 | part += char 119 | save_part() 120 | 121 | return lambda memory, functions, tokenizer, max_tokens: Utilities.to_string(tokenizer, functions.invoke(name, memory, functions, tokenizer, args)) 122 | 123 | -------------------------------------------------------------------------------- /src/promptrix/TextSection.py: -------------------------------------------------------------------------------- 1 | from promptrix.promptrixTypes import PromptMemory, PromptFunctions, Tokenizer, RenderedPromptSection, Message 2 | from promptrix.PromptSectionBase import PromptSectionBase 3 | 4 | class TextSection(PromptSectionBase): 5 | def __init__(self, text: str, role: str, tokens: int = -1, required: bool = True, separator: str = '\n', text_prefix: str = None): 6 | super().__init__(tokens, required, separator, text_prefix) 7 | self.text = text 8 | self.role = role 9 | self._length = -1 10 | 11 | def renderAsMessages(self, memory: PromptMemory, functions: PromptFunctions, tokenizer: Tokenizer, max_tokens: int): 12 | if self._length < 0: 13 | self._length = len(tokenizer.encode(self.text)) 14 | 15 | return self.return_messages([{'role': self.role, 'content': self.text}], self._length, tokenizer, max_tokens) 16 | -------------------------------------------------------------------------------- /src/promptrix/UserMessage.py: -------------------------------------------------------------------------------- 1 | from promptrix.TemplateSection import TemplateSection 2 | 3 | class UserMessage(TemplateSection): 
4 | """ 5 | A user message. 6 | """ 7 | def __init__(self, template: str, tokens: int = -1, user_prefix: str = 'user'): 8 | """ 9 | Creates a new 'UserMessage' instance. 10 | :param template: Template to use for this section. 11 | :param tokens: Optional. Sizing strategy for this section. Defaults to `auto`. 12 | :param user_prefix: Optional. Prefix to use for user messages when rendering as text. Defaults to `user`. 13 | """ 14 | super().__init__(template, user_prefix, tokens, True, '\n', text_prefix = user_prefix) 15 | -------------------------------------------------------------------------------- /src/promptrix/Utilities.py: -------------------------------------------------------------------------------- 1 | import json 2 | import yaml 3 | 4 | class Utilities: 5 | """ 6 | Utility functions. 7 | """ 8 | @staticmethod 9 | def to_string(tokenizer, value): 10 | """ 11 | Converts a value to a string. 12 | Dates are converted to ISO strings and Objects are converted to JSON or YAML, whichever is shorter. 13 | :param tokenizer: Tokenizer to use for encoding. 14 | :param value: Value to convert. 15 | :returns: Converted value. 
16 | """ 17 | if value is None: 18 | return '' 19 | elif isinstance(value, dict): 20 | if hasattr(value, 'isoformat'): 21 | return value.isoformat() 22 | else: 23 | as_json = json.dumps(value) 24 | return as_json 25 | else: 26 | return str(value) 27 | -------------------------------------------------------------------------------- /src/promptrix/VolatileMemory.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict 3 | 4 | class VolatileMemory: 5 | def __init__(self, memory: Dict[str, Any] = None): 6 | self._memory = {} 7 | if memory: 8 | self._memory = {key: memory[key] for key in memory} 9 | 10 | def has(self, key: str) -> bool: 11 | return key in self._memory 12 | 13 | def get(self, key: str) -> Any: 14 | value = self._memory.get(key) 15 | if value is not None and isinstance(value, dict): 16 | return json.loads(json.dumps(value)) 17 | else: 18 | return value 19 | 20 | def set(self, key: str, value: Any) -> None: 21 | if value is not None and isinstance(value, dict): 22 | clone = json.loads(json.dumps(value)) 23 | self._memory[key] = clone 24 | else: 25 | self._memory[key] = value 26 | 27 | def delete(self, key: str) -> None: 28 | if key in self._memory: 29 | del self._memory[key] 30 | 31 | def clear(self) -> None: 32 | self._memory.clear() 33 | -------------------------------------------------------------------------------- /src/promptrix/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Stevenic/promptrix-py/6ed7059cb35dcf3bf4d383a27cab9f88b358f007/src/promptrix/__init__.py -------------------------------------------------------------------------------- /src/promptrix/promptrixTypes.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, TypeVar, Callable 2 | from abc import ABC, abstractmethod 3 | from dataclasses import dataclass 4 | 5 | T = 
TypeVar('T') 6 | 7 | @dataclass 8 | class RenderedPromptSection: 9 | output: T 10 | length: int 11 | tooLong: bool 12 | 13 | @dataclass 14 | class Message: 15 | role: str 16 | content: T 17 | 18 | class PromptMemory(ABC): 19 | @abstractmethod 20 | def has(self, key: str) -> bool: 21 | pass 22 | 23 | @abstractmethod 24 | def get(self, key: str) -> Any: 25 | pass 26 | 27 | @abstractmethod 28 | def set(self, key: str, value: Any) -> None: 29 | pass 30 | 31 | @abstractmethod 32 | def delete(self, key: str) -> None: 33 | pass 34 | 35 | @abstractmethod 36 | def clear(self) -> None: 37 | pass 38 | 39 | class PromptFunctions(ABC): 40 | @abstractmethod 41 | def has(self, name: str) -> bool: 42 | pass 43 | 44 | @abstractmethod 45 | def get(self, name: str) -> Callable: 46 | pass 47 | 48 | @abstractmethod 49 | def invoke(self, name: str, memory, functions, tokenizer, args) -> Any: 50 | pass 51 | 52 | class Tokenizer(ABC): 53 | @abstractmethod 54 | def decode(self, tokens: List[int]) -> str: 55 | pass 56 | 57 | @abstractmethod 58 | def encode(self, text: str) -> List[int]: 59 | pass 60 | 61 | PromptFunction = Callable[['PromptMemory', 'PromptFunctions', 'Tokenizer', T], Any] 62 | 63 | class PromptSection(ABC): 64 | required: bool 65 | tokens: int 66 | 67 | @abstractmethod 68 | def renderAsText(self, memory, functions, tokenizer, maxTokens): 69 | pass 70 | 71 | @abstractmethod 72 | def renderAsMessages(self, memory, functions, tokenizer, maxTokens): 73 | pass 74 | -------------------------------------------------------------------------------- /tests/ConversationHistoryTest.py: -------------------------------------------------------------------------------- 1 | import aiounittest, unittest 2 | from promptrix.ConversationHistory import ConversationHistory 3 | from promptrix.VolatileMemory import VolatileMemory 4 | from promptrix.FunctionRegistry import FunctionRegistry 5 | from promptrix.GPT3Tokenizer import GPT3Tokenizer 6 | import asyncio 7 | 8 | class 
TestConversationHistory(aiounittest.AsyncTestCase): 9 | def setUp(self): 10 | self.memory = VolatileMemory({ 11 | "history": [ 12 | { "role": "user", "content": "Hello" }, 13 | { "role": "assistant", "content": "Hi" }, 14 | ], 15 | "longHistory": [ 16 | { "role": "user", "content": "Hello" }, 17 | { "role": "assistant", "content": "Hi! How can I help you?" }, 18 | { "role": "user", "content": "I'd like to book a flight" }, 19 | { "role": "assistant", "content": "Sure, where would you like to go?" }, 20 | ] 21 | }) 22 | self.functions = FunctionRegistry() 23 | self.tokenizer = GPT3Tokenizer() 24 | 25 | def test_constructor(self): 26 | section = ConversationHistory('history') 27 | self.assertEqual(section.variable, 'history') 28 | self.assertEqual(section.tokens, 1.0) 29 | self.assertEqual(section.required, False) 30 | self.assertEqual(section.separator, "\n") 31 | self.assertEqual(section.userPrefix, "user") 32 | self.assertEqual(section.assistantPrefix, "assistant") 33 | self.assertEqual(section.text_prefix, "") 34 | 35 | async def test_renderAsMessages(self): 36 | section = ConversationHistory('history', 100) 37 | rendered = await section.renderAsMessages(self.memory, self.functions, self.tokenizer, 100) 38 | self.assertEqual(rendered.output, [ 39 | { "role": "user", "content": "Hello" }, 40 | { "role": "assistant", "content": "Hi" }, 41 | ]) 42 | self.assertEqual(rendered.length, 2) 43 | self.assertEqual(rendered.tooLong, False) 44 | 45 | # Add other test cases... 
if __name__ == '__main__':
    unittest.main()

# --- tests/FunctionRegistryTest.py ---
import unittest
# BUG FIX: import through the installed `promptrix` package like every other
# test module in this directory (the bare `from FunctionRegistry import ...`
# form only works when src/promptrix itself is on sys.path).
from promptrix.FunctionRegistry import FunctionRegistry
from promptrix.VolatileMemory import VolatileMemory
from promptrix.GPT3Tokenizer import GPT3Tokenizer

class TestFunctionRegistry(unittest.TestCase):
    """Tests for FunctionRegistry registration, lookup and invocation."""

    def test_constructor(self):
        # An empty registry has no functions; constructor seeding works.
        registry = FunctionRegistry()
        self.assertIsNotNone(registry)
        self.assertFalse(registry.has("test"))

        registry = FunctionRegistry({
            "test": lambda memory, functions, tokenizer, args: None
        })
        self.assertIsNotNone(registry)
        self.assertTrue(registry.has("test"))

    def test_addFunction(self):
        registry = FunctionRegistry()
        registry.addFunction("test", lambda memory, functions, tokenizer, args: None)
        self.assertTrue(registry.has("test"))

        # Re-registering an existing name must raise.
        with self.assertRaises(Exception):
            registry = FunctionRegistry({
                "test": lambda memory, functions, tokenizer, args: None
            })
            registry.addFunction("test", lambda memory, functions, tokenizer, args: None)

    def test_get(self):
        registry = FunctionRegistry({
            "test": lambda memory, functions, tokenizer, args: None
        })
        fn = registry.get("test")
        self.assertIsNotNone(fn)

        # Fetching an unknown function must raise.
        with self.assertRaises(Exception):
            registry = FunctionRegistry()
            registry.get("test")

    def test_has(self):
        registry = FunctionRegistry()
        self.assertFalse(registry.has("test"))

        registry = FunctionRegistry({
            "test": lambda memory, functions, tokenizer, args: None
        })
        self.assertTrue(registry.has("test"))

    def test_invoke(self):
        memory = VolatileMemory()
        tokenizer = GPT3Tokenizer()

        called = False
        def test_func(memory, functions, tokenizer, args):
            # Verify the argument list is forwarded unchanged.
            nonlocal called
            self.assertEqual(len(args), 1)
            self.assertEqual(args[0], "Hello World")
            called = True

        registry = FunctionRegistry({
            "test": test_func
        })
        registry.invoke("test", memory, registry, tokenizer, ["Hello World"])
        self.assertTrue(called)

        # Invoking an unknown function must raise.
        with self.assertRaises(Exception):
            registry = FunctionRegistry()
            registry.invoke("test", memory, registry, tokenizer, ["Hello World"])

if __name__ == '__main__':
    unittest.main()

# --- tests/PromptSectionBaseTest.py ---
import aiounittest, unittest
from promptrix.promptrixTypes import *
from promptrix.PromptSectionBase import PromptSectionBase
from promptrix.VolatileMemory import VolatileMemory
from promptrix.FunctionRegistry import FunctionRegistry
from promptrix.GPT3Tokenizer import GPT3Tokenizer

class TestSection(PromptSectionBase):
    # Fixed single-message section used to exercise the base class.
    async def renderAsMessages(self, memory: PromptMemory, functions: PromptFunctions, tokenizer: Tokenizer, max_tokens: int):
        return self.return_messages([{'role': 'test', 'content': 'Hello Big World'}], 3, tokenizer, max_tokens)


class MultiTestSection(PromptSectionBase):
    # Two-message section used to exercise budget-based truncation.
    async def renderAsMessages(self, memory: PromptMemory, functions: PromptFunctions, tokenizer: Tokenizer, max_tokens: int):
        return self.return_messages([{'role': 'test', 'content': 'Hello Big'}, {'role': 'test', 'content': 'World'}], 3, tokenizer, max_tokens)

class TestPromptSectionBase(aiounittest.AsyncTestCase):
    def setUp(self):
        self.memory = VolatileMemory()
        self.functions = FunctionRegistry()
        self.tokenizer = GPT3Tokenizer()

    def test_constructor(self):
        section = TestSection()
        self.assertEqual(section.tokens, -1)
        self.assertEqual(section.required, True)
        # (test_constructor continues on the next original line)
# --- tests/PromptSectionBaseTest.py, test_constructor (continued) ---
        self.assertEqual(section.separator, "\n")
        self.assertEqual(section.text_prefix, "")

    async def test_renderAsMessages(self):
        # Unlimited budget: the full three-token message survives.
        section = TestSection()
        rendered = await section.renderAsMessages(self.memory, self.functions, self.tokenizer, 100)
        self.assertEqual(rendered.output, [{'role': 'test', 'content': 'Hello Big World'}])
        self.assertEqual(rendered.length, 3)
        self.assertEqual(rendered.tooLong, False)

        # Section capped at 2 tokens: content is truncated to fit.
        section = TestSection(2)
        rendered = await section.renderAsMessages(self.memory, self.functions, self.tokenizer, 100)
        self.assertEqual(rendered.output, [{'role': 'test', 'content': 'Hello Big'}])
        self.assertEqual(rendered.length, 2)
        self.assertEqual(rendered.tooLong, False)

        # Render budget smaller than the section: flagged as too long.
        section = TestSection(2)
        rendered = await section.renderAsMessages(self.memory, self.functions, self.tokenizer, 1)
        self.assertEqual(rendered.output, [{'role': 'test', 'content': 'Hello Big'}])
        self.assertEqual(rendered.length, 2)
        self.assertEqual(rendered.tooLong, True)

        # Multi-message section: trailing messages are dropped to meet the cap.
        section = MultiTestSection(2)
        rendered = await section.renderAsMessages(self.memory, self.functions, self.tokenizer, 100)
        self.assertEqual(rendered.output, [{'role': 'test', 'content': 'Hello Big'}])
        self.assertEqual(rendered.length, 2)
        self.assertEqual(rendered.tooLong, False)

    async def test_renderAsText(self):
        section = TestSection()
        rendered = await section.renderAsText(self.memory, self.functions, self.tokenizer, 100)
        self.assertEqual(rendered.output, "Hello Big World")
        self.assertEqual(rendered.length, 3)
        self.assertEqual(rendered.tooLong, False)

        # A text prefix counts toward the section's token total.
        section = TestSection(4, True, "\n", "user: ")
        rendered = await section.renderAsText(self.memory, self.functions, self.tokenizer, 100)
        self.assertEqual(rendered.output, "user: Hello Big")
        self.assertEqual(rendered.length, 4)
        self.assertEqual(rendered.tooLong, False)

        section = TestSection(4, True, "\n", "user: ")
        rendered = await section.renderAsText(self.memory, self.functions, self.tokenizer, 1)
        self.assertEqual(rendered.output, "user: Hello Big")
        self.assertEqual(rendered.length, 4)
        self.assertEqual(rendered.tooLong, True)


if __name__ == '__main__':
    unittest.main()

# --- tests/TemplateSectionTest.py ---
import unittest
from promptrix.TemplateSection import TemplateSection
from promptrix.VolatileMemory import VolatileMemory
from promptrix.FunctionRegistry import FunctionRegistry
from promptrix.GPT3Tokenizer import GPT3Tokenizer
import asyncio

class TestTemplateSection(unittest.TestCase):
    """Tests for TemplateSection construction and {{...}} template rendering."""

    def setUp(self):
        self.memory = VolatileMemory({
            'foo': 'bar'
        })
        # Template-callable fixtures: fixed text, echo first arg, join args.
        self.functions = FunctionRegistry({
            'test': lambda memory, functions, tokenizer, args: 'Hello World',
            'test2': lambda memory, functions, tokenizer, args: args[0],
            'test3': lambda memory, functions, tokenizer, args: ' '.join(args),
        })
        self.tokenizer = GPT3Tokenizer()

    def test_constructor(self):
        section = TemplateSection("Hello World", "user")
        self.assertEqual(section.template, "Hello World")
        self.assertEqual(section.role, "user")
        self.assertEqual(section.tokens, -1)
        self.assertEqual(section.required, True)
        self.assertEqual(section.separator, "\n")

        section = TemplateSection("Hello World", "system", 2.0, False)
        self.assertEqual(section.template, "Hello World")
        self.assertEqual(section.role, "system")
        self.assertEqual(section.tokens, 2.0)
        self.assertEqual(section.required, False)
        self.assertEqual(section.separator, "\n")

    async def test_renderAsMessages(self):
        section = TemplateSection("Hello World", "user")
        # (the render statement continues on the next original line)
section.renderAsMessages(self.memory, self.functions, self.tokenizer, 100) 38 | self.assertEqual(rendered.output, [{'role': 'user', 'content': 'Hello World'}]) 39 | self.assertEqual(rendered.length, 2) 40 | self.assertEqual(rendered.tooLong, False) 41 | 42 | section = TemplateSection("Hello World", "user") 43 | rendered = await section.renderAsMessages(self.memory, self.functions, self.tokenizer, 1) 44 | self.assertEqual(rendered.output, [{'role': 'user', 'content': 'Hello World'}]) 45 | self.assertEqual(rendered.length, 2) 46 | self.assertEqual(rendered.tooLong, True) 47 | 48 | async def test_renderAsText(self): 49 | section = TemplateSection("Hello World", "user") 50 | rendered = await section.renderAsText(self.memory, self.functions, self.tokenizer, 100) 51 | self.assertEqual(rendered.output, "Hello World") 52 | self.assertEqual(rendered.length, 2) 53 | self.assertEqual(rendered.tooLong, False) 54 | 55 | section = TemplateSection("Hello World", "user") 56 | rendered = await section.renderAsText(self.memory, self.functions, self.tokenizer, 1) 57 | self.assertEqual(rendered.output, "Hello World") 58 | self.assertEqual(rendered.length, 2) 59 | self.assertEqual(rendered.tooLong, True) 60 | 61 | async def test_template_syntax(self): 62 | section = TemplateSection("Hello {{$foo}}", "user") 63 | rendered = await section.renderAsText(self.memory, self.functions, self.tokenizer, 100) 64 | self.assertEqual(rendered.output, "Hello bar") 65 | self.assertEqual(rendered.length, 2) 66 | self.assertEqual(rendered.tooLong, False) 67 | 68 | section = TemplateSection("Hello {{$foo}} {{test}}", "user") 69 | rendered = await section.renderAsText(self.memory, self.functions, self.tokenizer, 100) 70 | self.assertEqual(rendered.output, "Hello bar Hello World") 71 | self.assertEqual(rendered.length, 4 ) 72 | self.assertEqual(rendered.tooLong, False) 73 | 74 | section = TemplateSection("Hello {{test2 World}}", "user") 75 | rendered = await section.renderAsText(self.memory, 
self.functions, self.tokenizer, 100) 76 | self.assertEqual(rendered.output, "Hello World") 77 | self.assertEqual(rendered.length, 2) 78 | self.assertEqual(rendered.tooLong, False) 79 | 80 | section = TemplateSection("Hello {{test2 'Big World'}}", "user") 81 | rendered = await section.renderAsText(self.memory, self.functions, self.tokenizer, 100) 82 | self.assertEqual(rendered.output, "Hello Big World") 83 | self.assertEqual(rendered.length, 3) 84 | self.assertEqual(rendered.tooLong, False) 85 | 86 | section = TemplateSection("Hello {{test2 `Big World`}}", "user") 87 | rendered = await section.renderAsText(self.memory, self.functions, self.tokenizer, 100) 88 | self.assertEqual(rendered.output, "Hello Big World") 89 | self.assertEqual(rendered.length, 3) 90 | self.assertEqual(rendered.tooLong, False) 91 | 92 | section = TemplateSection("Hello {{test3 'Big' World}}", "user") 93 | rendered = await section.renderAsText(self.memory, self.functions, self.tokenizer, 100) 94 | self.assertEqual(rendered.output, "Hello Big World") 95 | self.assertEqual(rendered.length, 3) 96 | self.assertEqual(rendered.tooLong, False) 97 | 98 | section = TemplateSection("{{}}", "user") 99 | rendered = await section.renderAsText(self.memory, self.functions, self.tokenizer, 100) 100 | self.assertEqual(rendered.output, "") 101 | self.assertEqual(rendered.length, 0) 102 | self.assertEqual(rendered.tooLong, False) 103 | 104 | with self.assertRaises(Exception) as context: 105 | section = TemplateSection("Hello {{test3 'Big' World}", "user") 106 | self.assertTrue('Invalid template: Hello {{test3 \'Big\' World}' in str(context.exception)) 107 | 108 | with self.assertRaises(Exception) as context: 109 | section = TemplateSection("Hello {{test3 'Big}}", "user") 110 | self.assertTrue('Invalid template: Hello {{test3 \'Big}}' in str(context.exception)) 111 | 112 | ts = TestTemplateSection() 113 | ts.setUp() 114 | ts.test_constructor() 115 | 116 | if __name__ == '__main__': 117 | 
asyncio.run(ts.test_renderAsMessages()) 118 | asyncio.run(ts.test_renderAsText()) 119 | asyncio.run(ts.test_template_syntax()) 120 | 121 | -------------------------------------------------------------------------------- /tests/VolatileMemoryTest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from promptrix.VolatileMemory import VolatileMemory 3 | 4 | class TestVolatileMemory(unittest.TestCase): 5 | def setUp(self): 6 | self.memory = VolatileMemory() 7 | self.obj = {'foo': 'bar'} 8 | 9 | def test_constructor(self): 10 | self.assertIsNotNone(self.memory) 11 | 12 | def test_constructor_with_initial_values(self): 13 | memory = VolatileMemory({"test": 123}) 14 | self.assertIsNotNone(memory) 15 | self.assertTrue(memory.has("test")) 16 | 17 | def test_set_primitive_value(self): 18 | self.memory.set("test", 123) 19 | self.assertTrue(self.memory.has("test")) 20 | 21 | def test_set_object(self): 22 | self.memory.set("test2", self.obj) 23 | self.assertTrue(self.memory.has("test2")) 24 | 25 | def test_get_primitive_value(self): 26 | self.memory.set("test", 123) 27 | value = self.memory.get("test") 28 | self.assertEqual(value, 123) 29 | 30 | def test_get_object_clone(self): 31 | self.memory.set("test2", self.obj) 32 | value = self.memory.get("test2") 33 | self.assertEqual(value, {'foo': 'bar'}) 34 | self.assertIsNot(value, self.obj) 35 | 36 | def test_get_undefined(self): 37 | value = self.memory.get("test3") 38 | self.assertIsNone(value) 39 | 40 | def test_has_value(self): 41 | self.memory.set("test", 123) 42 | self.assertTrue(self.memory.has("test")) 43 | 44 | def test_has_no_value(self): 45 | self.assertFalse(self.memory.has("test3")) 46 | 47 | def test_delete_value(self): 48 | self.memory.set("test", 123) 49 | self.memory.set("test2", 123) 50 | self.memory.delete("test") 51 | self.assertFalse(self.memory.has("test")) 52 | self.assertTrue(self.memory.has("test2")) 53 | 54 | def test_clear_values(self): 55 | 
self.memory.set("test", 123) 56 | self.memory.clear() 57 | self.assertFalse(self.memory.has("test")) 58 | self.assertFalse(self.memory.has("test2")) 59 | 60 | if __name__ == '__main__': 61 | unittest.main() 62 | --------------------------------------------------------------------------------