├── lifelike ├── __init__.py ├── StateManager │ ├── __init__.py │ ├── knowledge_tree.py │ ├── sequence_tree.py │ └── base_game_tree.py ├── state.py └── brain.py ├── setup.cfg ├── README.md ├── setup.py ├── LICENSE.txt └── .gitignore /lifelike/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lifelike/StateManager/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # lifelike 2 | A toolkit that allows for the creation of "lifelike" characters that you can interact with and change how they behave towards you 3 | 4 | # Documentation & Guide 5 | https://lifelike-toolkit.github.io/documentation/ 6 | 7 | # Example & Demo 8 | - Interrogation Demo: https://github.com/lifelike-toolkit/interrogation-demo 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='lifelike', 5 | version='1.0.7', 6 | description='A A toolkit that allows for the creation of "lifelike" characters that you can interact with and change how they behave towards you', 7 | author='Mustafa Tariq and Khoa Nguyen', 8 | license='MIT', 9 | packages=find_packages(), 10 | install_requires=[ 11 | 'langchain' 12 | ] 13 | ) 14 | -------------------------------------------------------------------------------- /lifelike/state.py: -------------------------------------------------------------------------------- 1 | """Defines the current game state, including player state, character states as well as world state""" 2 | class Context: 3 | """ 4 | Interface class for contextual embeddings used to interface with VectorDB. 5 | Part of Sequence Tree, sent to Brain for processing there, or bypass Brain altogether. 6 | """ 7 | def __init__(self) -> None: 8 | pass 9 | 10 | class State: 11 | """Interface class to define states. Keeps it consistent among different components.""" 12 | def __init__(self) -> None: 13 | pass -------------------------------------------------------------------------------- /lifelike/StateManager/knowledge_tree.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inherits from BaseGameTree. Demonstrates a GameTree for knowledge-based games 3 | """ 4 | import uuid 5 | 6 | from langchain.vectorstores import Chroma 7 | from langchain.embeddings.base import Embeddings 8 | 9 | from lifelike.StateManager.base_game_tree import BaseGameTree 10 | 11 | class KnowledgeTree(BaseGameTree): 12 | """Allows for loading the tree from a pre-defined list of contextual texts""" 13 | @classmethod 14 | def from_texts(cls: BaseGameTree, name:str, texts: list[str], embedding_function: Embeddings, metadatas: list[dict]=None, ids: list[str]=None) -> 'KnowledgeTree': 15 | """For now, must know the embedding dimension.""" 16 | if ids is None: # If id not provided, generate it 17 | ids = [str(uuid.uuid1()) for _ in texts] 18 | 19 | tree = cls(name, embedding_function) 20 | # TODO: Currently bypasses building tree. Tree is essentially non-functional 21 | tree.vectorstore = Chroma.from_texts(texts, embedding_function, metadatas, ids) 22 | return tree -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 lifelike-toolkit 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /lifelike/StateManager/sequence_tree.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inherits from GameNode. Meant for linear games with branching story paths. 3 | Good example of simple GameTree object. 4 | """ 5 | from lifelike.StateManager.base_game_tree import GameNodeRetriever, BaseGameTree, GameNode, EdgeEmbedding 6 | 7 | PathEmbedding = EdgeEmbedding 8 | 9 | class SequenceEvent(GameNode): 10 | """Wrapper for an event document in Database. The game can only see 1 at a time.""" 11 | def __init__(self, id: str, context: str, metadata: dict={}, reachable: list[str]=[]) -> None: 12 | """ 13 | Constructor. To build the event from dict, use .from_dict() 14 | Params: 15 | - id: self-explanatory 16 | - context: also self-explanatory 17 | - metadata: the text prompts given as response to player speech. 18 | - reachable: list of ids for sequence event that this specific event can reach. This behaviour can be customized via SequenceEventRetriever 19 | """ 20 | super().__init__(id, context, metadata) 21 | self.metadata["reachable"] = reachable # Internal metadata 22 | 23 | 24 | class SequenceTree(BaseGameTree): 25 | """ 26 | Technically a graph. Only used during Database setup. 27 | Provides methods that supports building out a sequence tree and acts as an interface for Database. 28 | Only Constructor can exit to support retries. 29 | """ 30 | def validate_edge(self, start_id: str, end_id: str, embedding_name: str, embedding_template: str = "default") -> bool: 31 | """Validates new edge_dict entry. All node in an edge must exist.""" 32 | if not super().validate_edge(start_id, end_id, embedding_name, embedding_template): 33 | return False 34 | elif start_id not in self.node_dict or end_id not in self.node_dict: 35 | print("Either start_id or end_id does not exist in event_dict. Must be 2 of {}.".format(self.node_dict.keys())) 36 | return False 37 | else: 38 | return True 39 | 40 | def add_edge(self, start_id: str, end_id: str, embedding_name: str, embedding_template: str = "default") -> bool: 41 | if super().add_edge(start_id, end_id, embedding_name, embedding_template): 42 | self.node_dict[start_id].metadata["reachable"].append(end_id) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ -------------------------------------------------------------------------------- /lifelike/brain.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the interface to manage characters and conversations 3 | """ 4 | import json 5 | import os 6 | import random 7 | from typing import List, Dict, Set 8 | from langchain import PromptTemplate, LLMChain 9 | 10 | 11 | class Characters: 12 | """ 13 | This class is an interface to manage characters. 14 | """ 15 | def __init__(self, path: str) -> None: 16 | """ 17 | @param path: path to the json file 18 | @return: None, initializes Characters 19 | """ 20 | self.path = path 21 | self.characters = {} 22 | if os.path.exists(path): 23 | self.characters = json.load(open(path, 'r', encoding='utf-8')) 24 | 25 | def is_out(self, name: str) -> ValueError: 26 | """ 27 | @param name: unique name of the character 28 | @return: ValueError if character does not exist 29 | """ 30 | if name not in self.characters: 31 | raise ValueError(f"Character {name} does not exist.") 32 | 33 | def is_in(self, name: str) -> ValueError: 34 | """ 35 | @param name: unique name of the character 36 | @return: ValueError if character exists 37 | """ 38 | if name in self.characters: 39 | raise ValueError(f"Character {name} already exists.") 40 | 41 | def get(self, name: str) -> str: 42 | """ 43 | @param name: unique name of the character 44 | @return: background of the character 45 | """ 46 | self.is_out(name) 47 | return self.characters[name] 48 | 49 | def add(self, name: str, background: str) -> None: 50 | """ 51 | @param name: unique name of the character 52 | @param background: background of the character 53 | @return: None, adds character to Characters 54 | """ 55 | self.is_in(name) 56 | self.characters[name] = background 57 | 58 | def update(self, name: str, background: str) -> None: 59 | """ 60 | @param new_name: name of the character 61 | @param background: new background of the character 62 | @return: None, updates character in Characters 63 | """ 64 | self.is_out(name) 65 | self.characters[name] = background 66 | 67 | def delete(self, name) -> None: 68 | """ 69 | @param name: unique name of the character 70 | @return: None, deletes character from Characters 71 | """ 72 | self.is_out(name) 73 | self.characters.pop(name) 74 | 75 | def __str__(self) -> str: 76 | """ 77 | @return: string representation of Characters 78 | """ 79 | return str(self.characters) 80 | 81 | def save(self) -> None: 82 | """ 83 | @return: None, saves Characters to json file 84 | """ 85 | json.dump(self.characters, open(self.path, 'w', encoding='utf-8')) 86 | 87 | 88 | class Conversations: 89 | """ 90 | This class is an interface to manage conversations 91 | """ 92 | def __init__(self, path: str, characters: Characters, llm) -> None: 93 | """ 94 | @param path: path to the json file 95 | @param characters: Characters object 96 | @param llm: langchain llm object 97 | @return: None, initializes Conversations 98 | """ 99 | # TODO: Allow for custom prompt template 100 | self.path = path 101 | self.valid = characters 102 | self.llm = llm 103 | self.conversations = {} 104 | if os.path.exists(path): 105 | self.conversations = json.load(open(path, 'r', encoding='utf-8')) 106 | 107 | def context_out(self, context: str) -> ValueError: 108 | """ 109 | @param context: unique context of the conversation 110 | @return: ValueError if context does not exist 111 | """ 112 | if context not in self.conversations: 113 | raise ValueError(f"Conversation {context} does not exist.") 114 | 115 | def context_in(self, context: str) -> ValueError: 116 | """ 117 | @param context: unique context of the conversation 118 | @return: ValueError if context exists 119 | """ 120 | if context in self.conversations: 121 | raise ValueError(f"Conversation {context} already exists.") 122 | 123 | def valid_participants(self, participants: Set[str]) -> ValueError: 124 | """ 125 | @param context: unique context of the conversation 126 | @param participants: list of character names 127 | @return: ValueError if participants are invalid 128 | """ 129 | invalid = participants - set(self.valid.characters) 130 | if invalid: 131 | raise ValueError(f"The participants in {invalid} do not exist.") 132 | 133 | def get(self, context: str) -> Dict[str, any]: 134 | """ 135 | @param context: unique context of the conversation 136 | @return: participants and log of the conversation 137 | """ 138 | self.context_out(context) 139 | return self.conversations[context] 140 | 141 | def new(self, context: str, participants: Set[str]) -> None: 142 | """ 143 | @param context: unique context of the conversation 144 | @param participants: list of character names 145 | @return: None, creates new conversation 146 | """ 147 | self.valid_participants(participants) 148 | self.context_in(context) 149 | self.conversations[context] = {"participants": participants, "log": []} 150 | 151 | def update(self, context: str, participants: Set[str], log: List[List[str]]) -> None: 152 | """ 153 | @param context: unique context of the conversation 154 | @param participants: list of character names 155 | @param log: list of [speaker, utterance] 156 | @return: None, updates conversation 157 | """ 158 | self.valid_participants(participants) 159 | self.context_out(context) 160 | self.conversations[context] = {"participants": participants, "log": log} 161 | 162 | def delete(self, context: str) -> None: 163 | """ 164 | @param context: unique context of the conversation 165 | @return: None, deletes conversation 166 | """ 167 | self.context_out(context) 168 | self.conversations.pop(context) 169 | 170 | def append(self, context: str, speaker: str, utterance: str) -> None: 171 | """ 172 | @param context: unique context of the conversation 173 | @param speaker: name of the speaker 174 | @param utterance: utterance of the speaker 175 | @return: None, appends utterance to the conversation 176 | """ 177 | self.valid_participants({speaker}) 178 | self.context_out(context) 179 | self.conversations[context]["log"].append([speaker, utterance]) 180 | 181 | def generate(self, context: str, history: str, muted: Set[str]) -> List[str]: 182 | """ 183 | @param context: unique context of the conversation 184 | @param muted: list of muted characters 185 | @return: speaker and generated utterance 186 | """ 187 | #TODO: find a smarter way to choose next character 188 | #TODO: Find a smarter way to get a single response. 189 | #TODO: memory for conversation log overflow 190 | #TODO: memory for the context 191 | #TODO: memory for the character background 192 | self.valid_participants(muted) 193 | 194 | convo = self.get(context) 195 | 196 | convo_speakers = convo["participants"] 197 | unmuted = convo_speakers.difference(muted) 198 | next_speaker = random.sample(unmuted, 1)[0] 199 | bg = self.valid.get(next_speaker) 200 | 201 | convo_log = convo["log"] 202 | log_str = '\n'.join([f"{speaker}: {utterance}" for speaker, utterance in convo_log]) 203 | # get last 3 lines of log 204 | try: 205 | log_str = '\n'.join(log_str.split('\n')[-3:]) 206 | except IndexError: 207 | pass 208 | 209 | template = "Context:\n"\ 210 | "{context}\n"\ 211 | "\n"\ 212 | "Background:\n"\ 213 | "{background}\n"\ 214 | "\n"\ 215 | "Relevant pieces of information:\n"\ 216 | "{history}\n"\ 217 | "(Only use if relevant to the conversation)\n"\ 218 | "\n"\ 219 | "Conversation:\n"\ 220 | "{log}\n"\ 221 | "{speaker}:" 222 | 223 | prompt = PromptTemplate(template=template, 224 | input_variables=["context", "background", 225 | "history", "log", "speaker"]) 226 | chain = LLMChain(prompt=prompt, llm=self.llm) 227 | 228 | output = chain.run({"context": context, "background": bg, "history": history, "log": log_str, "speaker": next_speaker}) 229 | if output != "": 230 | output = output.split('\n')[0].lstrip() 231 | self.append(context, next_speaker, output) 232 | return [next_speaker, output] 233 | 234 | def __str__(self) -> str: 235 | """ 236 | @return: string representation of Conversations 237 | """ 238 | return str(self.conversations) 239 | 240 | def save(self) -> None: 241 | """ 242 | @return: None, saves Conversations to json file 243 | """ 244 | json.dump(self.conversations, open(self.path, 'w', encoding='utf-8')) 245 | -------------------------------------------------------------------------------- /lifelike/StateManager/base_game_tree.py: -------------------------------------------------------------------------------- 1 | """ 2 | Meant to be used alongside chromadb, but any vectordb that uses KNN will work too. 3 | Tunes sequence embeddings to allow the game to more accurately predict player's intentions. 4 | For now, requires chromadb 5 | """ 6 | import numpy 7 | import json 8 | import uuid 9 | 10 | from langchain.schema import BaseRetriever 11 | from langchain.vectorstores import Chroma 12 | from langchain.embeddings.base import Embeddings # TODO: Make all embedding_function Embeddings interface 13 | 14 | class EdgeEmbedding: 15 | """Edge embedding dictionary that stores, calculate and allows for retrieval of different preset embeddings""" 16 | def __init__(self, name: str, embedding_function: Embeddings=None, current_embedding: list=None, current_weight: int=0) -> None: 17 | """ 18 | Constructor. If loading from dict, use from_dict() instead. 19 | Params: 20 | - name: identifier 21 | - embedding_function: the function that takes a batch of responses and returns the corresponding batch of embeddings. If not provided, the embedding is marked as Final (no tuning allowed). 22 | - current_embedding: the current embedding loaded from json, or pre-determined to bypass tuning. Defaulted to None. 23 | - current_weight: the current weight loaded from json (use 1 to bypass tuning). If current_weight is 0, the current_embedding will be ignored. Defaulted to 0. 24 | """ 25 | if not name: 26 | raise Exception("edge embedding name cannot be None") 27 | 28 | self.name = name 29 | self._final = False # Whether tuning is disabled on this embedding 30 | 31 | if embedding_function is None: 32 | if current_embedding is None: 33 | raise Exception("EdgeEmbedding {} was initialized with no embedding".format(name)) 34 | self._final = True # Activate Final flag 35 | 36 | self.embed = embedding_function 37 | 38 | self.embedding = current_embedding 39 | 40 | self.weight = current_weight # The number of responses this embedding represents 41 | 42 | @classmethod 43 | def from_dict(cls:'EdgeEmbedding', embedding_dict: dict, embedding_function: Embeddings=None) -> 'EdgeEmbedding': 44 | """ 45 | Generate up a edge Embedding instance. DO NOT USE TO MAKE DEEP COPY, use copy() instead 46 | Params: 47 | - embedding_dict: the result of edgeEmbedding.to_dict() instance method, or a dict with the same format 48 | - embedding_function: the function that takes a batch of responses and returns the corresponding batch of embeddings. Set to None if no tuning is required 49 | """ 50 | name = embedding_dict["name"] 51 | embedding = embedding_dict["embedding"] 52 | weight = embedding_dict["weight"] 53 | return cls(name, embedding_function, embedding, weight) 54 | 55 | def get_embedding(self) -> list: 56 | """ 57 | Getter function for embedding attribute. Defaulted to all 0s if no responses have been added. 58 | """ 59 | return self.embedding 60 | 61 | def tune_prompts(self, prompts: list) -> list: 62 | """ 63 | Add prompts to the embedding and calculates new embedding using weighted average. Tunes embedding of SequenceEvent. 64 | Potentially suffers from density bias. Good response choices will depend on choice of embedding function. 65 | Params: 66 | - prompts: the list of prompts that should be matched to this edge_embedding instance 67 | 68 | Returns: The new embedding 69 | """ 70 | if self._final: 71 | print("The Embedding is marked as Final. Cannot be tuned.") 72 | return None 73 | 74 | new_embeddings = self.embed.embed_documents(prompts) 75 | 76 | # Update weighted average 77 | self.embedding = numpy.average([self.embedding] + new_embeddings, 0, [self.weight]+[1]*len(prompts)).tolist() 78 | 79 | # Update weights 80 | self.weight += len(prompts) 81 | return self.embedding 82 | 83 | def to_dict(self) -> dict: 84 | """ 85 | Used to save to json as part of configuration, as well as make a copy 86 | Returns: a dict of the form {'name': ..., 'embedding': ..., 'weight': ...} 87 | """ 88 | return { 89 | 'name': self.name, 90 | 'embedding': self.embedding, 91 | 'weight': self.weight 92 | } 93 | 94 | def copy(self, new_name: str) -> 'EdgeEmbedding': 95 | """ 96 | Return a deep copy of this edgeEmbedding instance. 97 | Params: 98 | - new_name: Give new name. Ensure it is unique in the game to avoid unexpected behaviours. 99 | """ 100 | embedding_dict = self.to_dict() 101 | embedding_dict["name"] = new_name 102 | return EdgeEmbedding.from_dict(embedding_dict, self.embed) 103 | 104 | 105 | class GameNode: 106 | """Wrapper for a document in Database""" 107 | def __init__(self, id: str, context: str, metadata: str) -> None: 108 | """ 109 | Constructor. To build the event from dict, use .from_dict() 110 | Params: 111 | - id: self-explanatory 112 | - context: also self-explanatory 113 | - metadata: the text prompts given as response to player speech. 114 | """ 115 | self.id = id 116 | self.context = context 117 | self.metadata = metadata # Future proofing 118 | 119 | def to_dict(self) -> dict: 120 | """Saves tree to a dictionary that can be serialized with json""" 121 | return { 122 | 'id': self.id, 123 | 'context': self.context, 124 | 'metadata': self.metadata 125 | } 126 | 127 | @staticmethod 128 | def from_dict(node_dict: dict) -> 'GameNode': 129 | """Rebuild tree using a dictionary to resume progress (expects json deserialization higher up in the process)""" 130 | return GameNode(node_dict["id"], node_dict["context"], node_dict["metadata"]) 131 | 132 | 133 | class BaseGameTree: 134 | """ 135 | Provides methods that supports building out a Game Tree and acts as an interface for Database. 136 | Only Constructor can exit to support retries. 137 | Technically a graph, not a tree. 138 | """ 139 | def __init__(self, name: str, embedding_function: Embeddings=None) -> None: 140 | """ 141 | Constructor. 142 | Params: 143 | - name: Name of the story (Must be chromadb friendly) 144 | - embedding_function: Takes input and returns embedding. If not provided, the Tree is marked as Final (cannot be changed) 145 | """ 146 | self.name = name 147 | self._final = False # Final flag, determines if any change can be made to the tree 148 | # if flag is False, embedding_function will be used to determine embedding using document text 149 | 150 | if embedding_function is None: 151 | print('Tree {} was initialized in final mode'.format(name)) # TODO Final Mode might be excessive in constructor 152 | self._final = True # Activate Final flag 153 | 154 | self.embed = embedding_function # TODO Rename this, Embeddings is not a function 155 | 156 | # TODO add persistent option and metadata preset 157 | self.vectorstore = Chroma(name, self.embed) # Preset to Chroma TODO make this work with all vectorstore 158 | 159 | # Currently, if there are 2 ways to reach an event, a copy with a new unique id must be made 160 | self.node_dict = {} # id - SequenceEvent. Mostly for lookup 161 | 162 | self.embedding_template_dict = {} # All Embedding template must be copied using .copy() to be used or risk unexpected behaviour 163 | 164 | # Provides edge look up for custom edge embedding. Format: {(SequenceEvent left, SequenceEvent right): edgeEmbedding embedding} 165 | # Where left is start event, right is end event, and embedding is the required embedding to go from left to right 166 | self.edge_dict = {} 167 | 168 | def add_texts(self, texts, metadatas = None, ids = None, custom_embeddings = None): 169 | """ 170 | Create GameNode and add to GameTree using add_node(). 171 | Must be overriden if a child of GameNode is used. 172 | """ 173 | if ids is None: # If id not provided, generate it 174 | ids = [str(uuid.uuid1()) for _ in texts] 175 | 176 | if custom_embeddings is None: 177 | embeddings = self.embed.embed_documents(texts) 178 | 179 | # Text to tree. Does not allow for custom edges 180 | for i in range(len(texts)): 181 | node = GameNode(ids[i], texts[i], metadatas[i]) 182 | edge = EdgeEmbedding(ids[i], self.embed, embeddings[i], 20) # Tunable embedding with default weight of 20 183 | self.add_node(node) 184 | self.add_edge('_', node.id, ids[i], edge) 185 | 186 | return ids 187 | 188 | def add_node(self, node: GameNode) -> bool: 189 | """ 190 | Add a GameNode to the GameTree. 191 | Params: 192 | - node: a GameNode instance 193 | """ 194 | if self._final: 195 | print("Game Tree was marked as Final. No change can be made to it") 196 | return False 197 | 198 | event_id = node.id 199 | 200 | if event_id in self.node_dict: 201 | print("Sequence Event id {} already exists. If there is a second way to reach this event, create a copy with a unique id and retry") 202 | return False 203 | else: 204 | self.node_dict[event_id] = node 205 | return True 206 | 207 | def add_embedding_template(self, edge_embedding: EdgeEmbedding) -> bool: 208 | """ 209 | Add embedding template. 210 | Params: 211 | - name: the name for the template 212 | - prompts: the initial prompts to tune the template. If None, consider using the "default" template instead. 213 | """ 214 | if self._final: 215 | print("Game Tree was marked as Final. No change can be made to it") 216 | return False 217 | 218 | name = edge_embedding.name 219 | 220 | if name in self.embedding_template_dict: 221 | print("Embedding template {} already exists. Use a new unique name") 222 | return False 223 | else: 224 | self.embedding_template_dict[name] = edge_embedding 225 | return True 226 | 227 | def add_edge(self, start_id: str, end_id: str, embedding_name: str, embedding_template: EdgeEmbedding) -> bool: 228 | """ 229 | Add edge to the tree and assign a premade embedding template. 230 | Most of the time, you only have to override validate_edge() to modify behaviour. 231 | Params: 232 | - start_id: The id of the start event. Must exists in event_dict. 233 | - end_id: The id of the end event. Must exists in event_dict. 234 | - embedding_name: Rename the embedding class. Ensure it is unique to avoid unexpected behaviours. 235 | - embedding_template: The EdgeEmbedding object to use as a template. For saved templates, use get_template(). 236 | """ 237 | if self.validate_edge(start_id, end_id, embedding_template.name): 238 | # Assigns to edge_dict 239 | self.edge_dict[(start_id, end_id)] = embedding_template.copy(embedding_name) 240 | return True 241 | 242 | def validate_edge(self, start_id: str, end_id: str, embedding_template_name: str="default") -> bool: 243 | """Validates new edge_dict entry. Only contains basic validation, should be overidden to prevent undesireable behaviour""" 244 | if self._final: 245 | print("Game Tree was marked as Final. No change can be made to it") 246 | return False 247 | elif end_id not in self.node_dict: # Allows for a setup like (_, end_id) where any node can reach end_id 248 | print("Either start_id or end_id does not exist in event_dict. Must be 2 of {}.".format(self.node_dict.keys())) 249 | return False 250 | # elif embedding_template_name not in self.embedding_template_dict: 251 | # print("Invalid template: Embedding Template must be one of {}. Create one with add_embedding_template() or use the default template".format(self.embedding_template_dict.keys())) 252 | # return False 253 | elif (start_id, end_id) in self.edge_dict: 254 | print("Invalid edge: edge {} already exists".format(start_id + '-' + end_id)) 255 | return False 256 | else: 257 | return True 258 | 259 | def tune_edge(self, edge_id: tuple, prompts: list) -> bool: 260 | """ 261 | Tune the chosen embedding with extra prompts. 262 | Params: 263 | - edge_id: Identifier for edge, is a tuple of form (start_event, end_event) 264 | - prompts: List of prompts for tuning 265 | """ 266 | if self._final: 267 | print("Game Tree was marked as Final. No change can be made to it") 268 | return False 269 | 270 | if edge_id not in self.edge_dict: 271 | print("edge {} does not exists".format(edge_id)) 272 | return False 273 | self.edge_dict[edge_id].tune_prompts(prompts) 274 | return True 275 | 276 | def get_template_options(self) -> list: 277 | """Get the name of all stored templates""" 278 | return self.embedding_template_dict.keys() 279 | 280 | def get_template(self, name) -> EdgeEmbedding: 281 | if name not in self.get_template_options(): 282 | print("No embedding with name {} found".format(name)) 283 | return None 284 | 285 | return self.embedding_template_dict[name] 286 | 287 | @classmethod 288 | def build_from_json(cls:'BaseGameTree', edge_to_json: str, embedding_function: Embeddings=None) -> 'BaseGameTree': 289 | """ 290 | Rebuild tree from JSON file. May cause unexpected behaviour if the JSON file was built using derived GameNode and EdgeEmbedding classes. 291 | Params: 292 | - edge_to_json: The string that signifies the edge to the jsonified tree 293 | - embedding_function: Takes input and returns embedding. If not provided, the Tree is marked as Final (cannot be changed) 294 | """ 295 | tree_dict = {} 296 | with open(edge_to_json, "r") as f: 297 | tree_dict = json.load(f) 298 | 299 | tree = cls(tree_dict["name"], embedding_function) 300 | for (template_name, template_dict) in tree_dict["embedding_template_dict"].items(): 301 | tree.embedding_template_dict[template_name] = EdgeEmbedding.from_dict(template_dict, embedding_function) 302 | 303 | for (event_id, event_dict) in tree_dict["event_dict"].items(): 304 | tree.node_dict[event_id] = GameNode.from_dict(event_dict) 305 | 306 | for (edge_string, embedding_dict) in tree_dict["edge_dict"].items(): 307 | edge = edge_string.split(" ") # Split by " " 308 | tree.edge_dict[edge] = EdgeEmbedding.from_dict(embedding_dict, embedding_function) 309 | 310 | return tree 311 | 312 | def to_dict(self) -> dict: 313 | """Return the instance in dictionary form to be saved to json""" 314 | return { 315 | "name": self.name, 316 | "embedding_template_dict": {template_id: template.to_dict() for (template_id, template) in self.embedding_template_dict.items()}, # Kinda optional here 317 | "event_dict": {event_id: event.to_dict() for (event_id, event) in self.node_dict.items()}, 318 | "edge_dict": {edge[0]+" "+edge[1]: embedding.to_dict() for (edge, embedding) in self.edge_dict.items()} 319 | } 320 | 321 | def to_json(self, edge_to_json: str) -> None: 322 | """Saves current Tree configuration to json""" 323 | with open(edge_to_json, "w") as f: 324 | json.dump(self.to_dict(), f, indent=4) 325 | 326 | def write_db(self): 327 | """Write current tree to a ChromaDB collection.""" 328 | # Only need to add the following nodes to chroma, as the starting state does not need to be defined 329 | embeddings = [] 330 | documents = [] 331 | metadatas=[] 332 | ids=[] 333 | 334 | for (edge, embedding) in self.edge_dict.items(): 335 | target_node_id = edge[1] 336 | target_node = self.node_dict[target_node_id] 337 | 338 | embeddings.append(embedding.get_embedding()) 339 | documents.append(target_node.context) 340 | ids.append(target_node_id) 341 | metadatas.append(target_node.metadata) 342 | 343 | self.vectorstore._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents) 344 | 345 | def get_retriever(self, **kwargs) -> BaseRetriever: 346 | """ 347 | Some possible arguments: 348 | - search_type 349 | - search_kwargs: {"k": Number of returned results} 350 | """ 351 | return self.vectorstore.as_retriever(**kwargs) # Adding some options 352 | --------------------------------------------------------------------------------