├── requirements.txt
├── text_preprocessor.py
├── keyword_tally.py
├── turn_templates.py
├── metadata.py
├── embeddings.py
├── .gitignore
├── README.md
├── helpers.py
├── annoy_manager.py
└── script.py

/requirements.txt:
--------------------------------------------------------------------------------
1 | annoy
2 | 
--------------------------------------------------------------------------------
/text_preprocessor.py:
--------------------------------------------------------------------------------
1 | # ./text_preprocessor.py
2 | import spacy
3 | from extensions.annoy_ltm.helpers import *
4 | 
5 | class TextPreprocessor:
6 |     def __init__(self) -> None:
7 |         self.nlp = spacy.load("en_core_web_sm", disable=["parser"])
8 | 
9 |     def preprocess_and_extract_named_entities(self, text):
10 |         # Named Entity Recognition
11 |         doc = self.nlp(text)
12 |         named_entities = [ent.text for ent in doc.ents]
13 | 
14 |         return named_entities
15 | 
16 |     def preprocess_and_extract_keywords(self, text):
17 |         # Tokenization, lowercasing, and stopword removal
18 |         tokens = [token.text.lower() for token in self.nlp(text) if not token.is_stop]
19 | 
20 |         # Lemmatization (runs the pipeline a second time, over the filtered tokens)
21 |         lemmatized_tokens = [token.lemma_ for token in self.nlp(" ".join(tokens))]
22 | 
23 |         keywords = lemmatized_tokens
24 | 
25 |         return keywords
26 | 
27 |     def trim_and_preprocess_text(self, text, state):
28 |         text_to_process = remove_username_and_timestamp(text, state)
29 |         keywords = self.preprocess_and_extract_keywords(text_to_process)
30 |         named_entities = self.preprocess_and_extract_named_entities(text_to_process)
31 | 
32 |         return keywords, named_entities
33 | 
--------------------------------------------------------------------------------
/keyword_tally.py:
--------------------------------------------------------------------------------
1 | # ./keyword_tally.py
2 | 
3 | class KeywordTally:
4 |     def __init__(self):
5 |         self.keyword_tally_count = {}
6 |         self.total_keywords = 0
7 |         self.most_common_count = 0
8 | 
9 |     def tally(self, keywords):
10 |         for keyword in keywords:
11 |             self.total_keywords += 1
12 |             if keyword in self.keyword_tally_count:
13 |                 self.keyword_tally_count[keyword] += 1
14 |             else:
15 |                 self.keyword_tally_count[keyword] = 1
16 | 
17 |             if self.keyword_tally_count[keyword] > self.most_common_count:
18 |                 self.most_common_count = self.keyword_tally_count[keyword]
19 | 
20 |     def get_significance(self, keywords):
21 |         significance = 0
22 |         keywords_len = len(keywords)
23 |         if keywords_len == 0:
24 |             return 0
25 | 
26 |         for keyword in keywords:
27 |             if keyword in self.keyword_tally_count:
28 |                 ratio = self.keyword_tally_count[keyword] / self.most_common_count
29 |                 significance += 1 - ratio
30 |         return significance / keywords_len
31 | 
32 |     def exportKeywordTally(self):
33 |         return self.keyword_tally_count
34 | 
35 |     def importKeywordTally(self, keyword_tally_data):
36 |         self.keyword_tally_count = keyword_tally_data
37 |         self.total_keywords = sum(keyword_tally_data.values())
38 |         self.most_common_count = max(keyword_tally_data.values(), default=0)
--------------------------------------------------------------------------------
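A quick sketch of the tally/significance scoring implemented above, with made-up keywords (not part of the repository):

    tally = KeywordTally()
    tally.tally(["dragon", "dragon", "dragon", "castle"])
    # "castle" is rare relative to "dragon" (most_common_count == 3):
    tally.get_significance(["castle"])  # (1 - 1/3) / 1 ~= 0.67 -> rare, significant
    tally.get_significance(["dragon"])  # (1 - 3/3) / 1 == 0.0  -> common, insignificant
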
/turn_templates.py:
--------------------------------------------------------------------------------
1 | # ./turn_templates.py
2 | 
3 | from extensions.annoy_ltm.helpers import replace_all
4 | 
5 | def get_turn_templates(state, is_instruct, logger):
6 | 
7 |     logger(f"state['turn_template']: {state.get('turn_template')}", 5)
8 | 
9 |     # Building the turn templates
10 |     if 'turn_template' not in state or state['turn_template'] == '':
11 |         if is_instruct:
12 |             template = '<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n'
13 |         else:
14 |             template = '<|user|>: <|user-message|>\n<|bot|>: <|bot-message|>\n'
15 |     else:
16 |         template = state['turn_template'].replace(r'\n', '\n')
17 | 
18 |     replacements = {
19 |         '<|user|>': state['name1'].strip(),
20 |         '<|bot|>': state['name2'].strip(),
21 |     }
22 |     logger(f"turn_template replacements: {replacements}", 5)
23 | 
24 |     user_turn = replace_all(template.split('<|bot|>')[0], replacements)
25 |     bot_turn = replace_all('<|bot|>' + template.split('<|bot|>')[1], replacements)
26 |     user_turn_stripped = replace_all(user_turn.split('<|user-message|>')[0], replacements)
27 |     bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements)
28 | 
29 |     logger(f"turn_templates:\nuser_turn:{user_turn}\nbot_turn:{bot_turn}\nuser_turn_stripped:{user_turn_stripped}\nbot_turn_stripped:{bot_turn_stripped}", 5)
30 | 
31 |     return user_turn, bot_turn, user_turn_stripped, bot_turn_stripped
32 | 
33 | def apply_turn_templates_to_rows(rows, state, logger):
34 |     is_instruct = state['mode'] == 'instruct'
35 |     user_turn, bot_turn, user_turn_stripped, bot_turn_stripped = get_turn_templates(state, is_instruct, logger=logger)
36 |     output_rows = []
37 |     for i, row in enumerate(rows):
38 |         if row[0] not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
39 |             user_row = replace_all(user_turn, {'<|user-message|>': row[0].strip(), '<|round|>': str(i)})
40 |         else:
41 |             user_row = row[0]
42 |         bot_row = bot_turn.replace('<|bot-message|>', row[1].strip())
43 |         output_rows.append((user_row, bot_row))
44 | 
45 |     return output_rows
--------------------------------------------------------------------------------
/metadata.py:
--------------------------------------------------------------------------------
1 | # ./metadata.py
2 | 
3 | from modules import shared
4 | 
5 | import hashlib
6 | import json
7 | import os
8 | import glob
9 | 
10 | def compute_file_hash(filepath):
11 |     hasher = hashlib.md5()
12 |     with open(filepath, 'rb') as f:
13 |         buf = f.read()
14 |         hasher.update(buf)
15 |     return hasher.hexdigest()
16 | 
17 | def save_metadata(metadata, filepath):
18 |     dir_path = os.path.dirname(filepath)
19 | 
20 |     # Check if the directory exists, and create it if necessary
21 |     if not os.path.exists(dir_path):
22 |         os.makedirs(dir_path)
23 | 
24 |     with open(filepath, 'w') as f:
25 |         json.dump(metadata, f)
26 | 
27 | def load_metadata(filepath):
28 |     if os.path.exists(filepath):
29 |         with open(filepath, 'r') as f:
30 |             return json.load(f)
31 |     return None
32 | 
33 | def compute_hashes(history, remove_last_user_message=False):
34 |     # Compute a hash for each Python file in the same directory
35 |     python_files = glob.glob(os.path.join(os.path.dirname(__file__), '*.py'))
36 |     code_hash = ''.join(sorted([compute_file_hash(file) for file in python_files]))
37 | 
38 |     if remove_last_user_message:
39 |         messages_hash = [hashlib.md5(str(msg).encode()).hexdigest() for msg in history['internal'][:-1]]
40 |     else:
41 |         messages_hash = [hashlib.md5(str(msg).encode()).hexdigest() for msg in history['internal']]
42 | 
43 |     return code_hash, messages_hash
44 | 
45 | 
46 | def check_hashes(metadata, history, logger):
47 |     if metadata is None:
48 |         return False
49 | 
50 |     code_hash, messages_hash = compute_hashes(history, remove_last_user_message=True)
51 | 
52 |     logger(f"Metadata code hash: {metadata['code_hash']}", 5)
53 |     logger(f"Computed code hash: {code_hash}", 5)
54 |     logger(f"Metadata messages hash: {metadata['messages_hash']}", 5)
55 |     logger(f"Computed messages hash: {messages_hash}", 5)
56 | 
57 |     if metadata['code_hash'] != code_hash:
58 |         return False
59 | 
60 |     if metadata['messages_hash'] != messages_hash:
61 |         return False
62 | 
63 |     if metadata.get('model_name') != shared.model_name:
64 |         return False
65 | 
66 |     return True
--------------------------------------------------------------------------------
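The message-hashing scheme above, in isolation (a standalone sketch; the real compute_hashes also hashes every .py file in this directory, and check_hashes additionally compares the loaded model name):

    import hashlib

    history = {'internal': [["hi there", "hello!"], ["how are you?", "doing well"]]}
    # One md5 digest per [user, bot] pair, mirroring compute_hashes(history).
    messages_hash = [hashlib.md5(str(msg).encode()).hexdigest() for msg in history['internal']]
    # Editing any stored pair changes its digest, which makes check_hashes() return
    # False and triggers a rebuild of the annoy database.
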
/embeddings.py:
--------------------------------------------------------------------------------
1 | # ./embeddings.py
2 | 
3 | from modules import shared
4 | from extensions.annoy_ltm.helpers import get_device
5 | 
6 | import torch
7 | 
8 | def generate_embeddings(text, logger):
9 |     """
10 |     Generates embeddings for a given text.
11 | 
12 |     Parameters:
13 |     text (str): The input text to generate embeddings for.
14 |     logger (callable): The extension's logger function, called as logger(message, level).
15 | 
16 |     Returns:
17 |     np.ndarray: The generated embeddings.
18 |     """
19 | 
20 |     input_ids = shared.tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)
21 |     input_ids = input_ids.long() # ensure the values are not floats
22 | 
23 |     with torch.no_grad():
24 |         if hasattr(shared.model, 'ex_model'):
25 |             input_ids = input_ids.to(torch.device("cpu")) # ExLlama embeds token ids on the CPU
26 |             input_embeds = shared.model.ex_model.embed_tokens(input_ids)
27 |         elif hasattr(shared.model.model, 'embed_tokens'):
28 |             input_ids = input_ids.to(get_device()) # Move input_ids to the model's device
29 |             input_embeds = shared.model.model.embed_tokens(input_ids)
30 |         elif hasattr(shared.model.model, 'get_input_embeddings'):
31 |             input_ids = input_ids.to(get_device()) # Move input_ids to the model's device
32 |             input_embeds = shared.model.model.get_input_embeddings()(input_ids)
33 |         elif hasattr(shared.model.model, 'model') and hasattr(shared.model.model.model, 'embed_tokens'): # Reported in issue #17
34 |             input_ids = input_ids.to(get_device()) # Move input_ids to the model's device
35 |             input_embeds = shared.model.model.model.embed_tokens(input_ids)
36 |         else:
37 |             raise AttributeError("The model doesn't have an 'embed_tokens' or 'get_input_embeddings' method.")
38 | 
39 |     input_embeds = input_embeds.mean(dim=1).squeeze(0) # Average over the sequence dimension, then drop the batch dimension
40 |     result = input_embeds.cpu().numpy().flatten() # Convert to NumPy array and flatten
41 |     logger(f"generating embeddings for text: {text}\n{result}", 5)
42 |     return result
43 | 
44 | class Embeddings:
45 |     def __init__(self) -> None:
46 |         self.embeddings = {}
47 | 
48 |     def add_embedding(self, unique_index, embeddings):
49 |         self.embeddings[unique_index] = embeddings
50 | 
51 |     def export_embeddings(self):
52 |         return self.embeddings
53 | 
54 |     def import_embeddings(self, embeddings):
55 |         self.embeddings = embeddings
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/
161 | 
162 | outputs
163 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # annoy_ltm
2 | 
3 | This repository contains an extension for the oobabooga-text-generation-webui application, introducing long-term memory to chat bots using the Annoy (Approximate Nearest Neighbors Oh Yeah) nearest-neighbor vector database.
4 | 
5 | ## Features
6 | 
7 | The `annoy_ltm` extension provides chat bots with a form of long-term memory. It leverages Annoy's efficient approximate nearest-neighbor search to retrieve similar vector representations from the chat history, allowing the bot to reference past interactions.
8 | 
9 | ## Installation
10 | 
11 | This extension can be installed like any other extension to the oobabooga-text-generation-webui, with the additional requirement of a Spacy language model. Follow the instructions below:
12 | 
13 | 1. Download and install the Spacy `en_core_web_sm` model. On Windows you can do this by running `cmd_windows.bat` and then executing the following commands in the resulting cmd shell:
14 | 
15 |    Windows / WSL:
16 | 
17 |    ```bash
18 |    pip install -U pip setuptools wheel
19 |    pip install -U spacy
20 |    python -m spacy download en_core_web_sm
21 |    ```
22 |    Linux:
23 |    In the environment you are using for oobabooga-text-generation-webui, run the following command:
24 | 
25 |    ```bash
26 |    python -m spacy download en_core_web_sm
27 |    ```
28 | 2. Follow the regular installation process for extensions to the oobabooga-text-generation-webui application.
29 | 
30 | 3. Navigate to the annoy_ltm extension folder and run the following command to install the dependencies:
31 |    ```bash
32 |    pip install -r requirements.txt
33 |    ```
34 | 
35 | 
36 | ## Usage
37 | 
38 | Once the extension is enabled, it works automatically with no additional steps needed. You can configure its behavior by modifying the following parameters in the `settings.json` of the webui:
39 | 
40 | | Parameter | Description | Default Value |
41 | | --------------------------- | --------------- | ------------- |
42 | | `annoy_output_dir` | Directory where outputs are stored. | `"extensions/annoy_ltm/outputs/"` |
43 | | `logger_level` | Logging verbosity; higher numbers produce more verbose logging. 3 is the practical maximum for normal debugging. | `1` |
44 | | `memory_retention_threshold` | Retention threshold for memories. Lower values keep memories on the stack longer, at the risk of overflowing the stack and retaining irrelevant memories. Ranges from 0-1. | `0.7` |
45 | | `full_memory_additional_weight` | Additional weight for the full-message match. Smaller values mean more weight. Ranges from 0-1. | `0.3` |
46 | | `keyword_match_weight` | Weight applied to memories matched through keyword groups. Smaller values mean more weight. Ranges from 0-1. | `0.6` |
47 | | `ner_character_len_weight` | Number of characters at which the named-entity weighting maxes out. | `100.0` |
48 | | `named_entity_match_clamp_min_dist` | Clamps the named-entity match distance to this minimum, preventing an exact NER match from overriding all other memories. Ranges from 0-1. | `0.6` |
49 | | `keyword_rarity_weight` | Throttles the extra weight given to memories containing unique phrases and vocabulary. | `1` |
50 | | `num_memories_to_retrieve` | Number of related memories to retrieve for the full message and for every keyword group generated from the message. Higher values can cause significant slowdowns. | `5` |
51 | | `keyword_grouping` | Number of keywords to group together. Higher values make an exact match harder to find, which can improve context relevance, but set too high no memories will be returned. | `4` |
52 | | `maximum_memory_stack_size` | Maximum size of the memory stack, preventing overflow. | `50` |
53 | | `prompt_memory_ratio` | The ratio of the prompt (after the character context is applied) that is dedicated to memories. | `0.4` |
54 | | `vector_dim_override` | Override for the hidden-layer dimension of your loaded model. Use this if the generated embeddings do not match the dimensionality of the annoy index. `-1` disables the override. | `-1` |
55 | 
56 | These parameters allow you to tune the operation of `annoy_ltm` to best suit your specific use-case.
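57 | 
58 | For example, a minimal `settings.json` override might look like the sketch below. The values shown simply mirror the defaults in `script.py`; note that, depending on your webui version, extension settings may need to be namespaced with an `annoy_ltm-` prefix as shown here.
59 | 
60 | ```json
61 | {
62 |     "annoy_ltm-logger_level": 1,
63 |     "annoy_ltm-memory_retention_threshold": 0.7,
64 |     "annoy_ltm-num_memories_to_retrieve": 5,
65 |     "annoy_ltm-keyword_grouping": 4,
66 |     "annoy_ltm-prompt_memory_ratio": 0.4
67 | }
68 | ```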
69 | 
70 | ## Support
71 | 
72 | For any issues or queries, please use the Issues tab on this repository.
73 | 
74 | ## Docker
75 | Hey you! Yeah, you, about to install some random project's extension code into your non-dockerized oobabooga instance! Don't you know that's dangerous? I highly recommend you check out the docker setup for oobabooga-text-generation-webui before randomly installing anything, and do your due diligence by reading through the extension code! You've got that kind of time.
76 | https://github.com/oobabooga/text-generation-webui#alternative-docker
77 | 
--------------------------------------------------------------------------------
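helpers.py below provides the shared utilities; as a quick illustration, filter_keywords and generate_keyword_groups compose like this (inputs invented for the example):

    keywords = ["we", "hiked", "the", "misty", "mountain", "trail"]
    filtered = filter_keywords(keywords)   # drops "we" (shorter than 3 characters)
    generate_keyword_groups(filtered, 2)
    # -> ['hiked the', 'the misty', 'misty mountain', 'mountain trail']
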
/helpers.py:
--------------------------------------------------------------------------------
1 | # ./helpers.py
2 | 
3 | import re
4 | import time
5 | from typing import List
6 | import numpy as np
7 | import torch
8 | 
9 | def remove_username(message: str, state) -> str:
10 |     """
11 |     Removes the username prefix from a message string. Returns the message without the username.
12 |     """
13 |     return re.sub(rf'^({re.escape(state["name1"].strip())}|{re.escape(state["name2"].strip())})[:,\s]*', '', message)
14 | 
15 | def remove_timestamp(message: str) -> str:
16 |     """
17 |     Removes the timestamp from a message string. Assumes a timestamp in the format 'YYYY-MM-DD HH:MM:SS'. Returns the message without the timestamp.
18 |     """
19 |     return re.sub(r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}', '', message)
20 | 
21 | def remove_username_and_timestamp(message: str, state) -> str:
22 |     """
23 |     Removes both the username and timestamp from a message string. Returns the message without the username and timestamp.
24 |     """
25 |     return remove_username(remove_timestamp(message), state)
26 | 
27 | def filter_keywords(keywords, min_length=3):
28 |     """
29 |     Filters a list of keywords, removing any keywords shorter than min_length. Returns the filtered list of keywords.
30 |     """
31 |     filtered_keywords = [keyword for keyword in keywords if len(keyword) >= min_length]
32 |     return filtered_keywords
33 | 
34 | def generate_keyword_groups(keywords: List[str], n: int = 2) -> List[str]:
35 |     """
36 |     Generates groups of n consecutive keywords from a list. Returns a list of keyword groups.
37 |     """
38 |     return [" ".join(keywords[i:i + n]) for i in range(len(keywords) - n + 1)]
39 | 
40 | def merge_memory_lists_by_distance(list1, list2, max_new_list_length=500):
41 |     """
42 |     Merges two memory lists by their distance values. Returns a merged list up to a maximum length of max_new_list_length.
43 |     """
44 |     merged_list = []
45 |     i = j = 0
46 |     while i < len(list1) and j < len(list2) and len(merged_list) < max_new_list_length:
47 |         if list1[i][2] < list2[j][2]:
48 |             merged_list.append(list1[i])
49 |             i += 1
50 |         else:
51 |             merged_list.append(list2[j])
52 |             j += 1
53 |     while i < len(list1) and len(merged_list) < max_new_list_length:
54 |         merged_list.append(list1[i])
55 |         i += 1
56 |     while j < len(list2) and len(merged_list) < max_new_list_length:
57 |         merged_list.append(list2[j])
58 |         j += 1
59 |     return merged_list
60 | 
61 | def remove_duplicates(memory_stack):
62 |     """
63 |     Removes duplicate memories from a memory stack, keeping the one with the shortest distance. Returns a list of unique memories sorted by distance.
64 |     """
65 |     memory_dict = {}
66 |     for memory in memory_stack:
67 |         index, _, distance = memory
68 |         if index not in memory_dict or distance < memory_dict[index][2]:
69 |             memory_dict[index] = memory
70 |     return sorted(memory_dict.values(), key=lambda x: x[2])
71 | 
72 | def cosine_similarity(a, b):
73 |     """
74 |     Computes the cosine similarity between two vectors a and b.
75 |     """
76 |     norm_a = np.linalg.norm(a)
77 |     norm_b = np.linalg.norm(b)
78 | 
79 |     if norm_a == 0 or norm_b == 0:
80 |         # Handle the case where a or b is a zero vector.
81 |         # Here we return None, but you could also return 0 or another value.
82 |         return None
83 | 
84 |     dot_product = np.dot(a, b)
85 |     return dot_product / (norm_a * norm_b)
86 | 
87 | # Replace multiple string pairs in a string
88 | def replace_all(text, dic):
89 |     """
90 |     Replaces all instances of certain substrings in a text string.
91 |     """
92 |     for i, j in dic.items():
93 |         text = text.replace(i, j)
94 | 
95 |     return text
96 | 
97 | 
98 | #--------------- Annoy helpers ---------------
99 | def copy_items(src_index, dest_index, num_items, logger):
100 |     """
101 |     Copies all items from one annoy index to another.
102 |     """
103 |     start_time = time.time()
104 |     for i in range(num_items):
105 |         item = src_index.get_item_vector(i)
106 |         dest_index.add_item(i, item)
107 | 
108 |     end_time = time.time()
109 |     logger(f"copying annoy index took {end_time-start_time} seconds...", 1)
110 | 
111 | #--------------- PyTorch helpers ---------------
112 | 
113 | def get_device():
114 |     return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
115 | 
--------------------------------------------------------------------------------
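For reference, cosine_similarity above reduces to the usual dot-product formula; a tiny standalone check with toy vectors:

    import numpy as np

    a = np.array([1.0, 0.0])
    b = np.array([1.0, 1.0])
    sim = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))  # = 1/sqrt(2) ~= 0.707
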
/annoy_manager.py:
--------------------------------------------------------------------------------
1 | # ./annoy_manager.py
2 | from math import floor
3 | import time
4 | 
5 | from modules import shared
6 | from annoy import AnnoyIndex
7 | import queue
8 | import copy
9 | import threading
10 | 
11 | from extensions.annoy_ltm.helpers import *
12 | from extensions.annoy_ltm.metadata import check_hashes, compute_hashes, load_metadata, save_metadata
13 | from extensions.annoy_ltm.embeddings import generate_embeddings
14 | from extensions.annoy_ltm.keyword_tally import KeywordTally
15 | from extensions.annoy_ltm.turn_templates import apply_turn_templates_to_rows
16 | 
17 | class AnnoyManager:
18 |     def __init__(self, text_preprocessor) -> None:
19 |         self.results_queue = queue.Queue()
20 |         self.text_preprocessor = text_preprocessor
21 |         self.metadata = None
22 |         self.annoy_index = None
23 |         self.metadata_file = None
24 |         self.annoy_index_file = None
25 |         self.hidden_size = None
26 |         self.loaded_history_last_index = 0
27 |         # Create dictionary for annoy indices
28 |         self.index_to_history_position = {}
29 |         self.lock = threading.Lock()
30 | 
31 |     def _get_hidden_size(self, params, logger):
32 |         if params['vector_dim_override'] != -1:
33 |             return params['vector_dim_override']
34 |         try:
35 |             if hasattr(shared.model, 'ex_config'):
36 |                 return shared.model.ex_config.hidden_size
37 |             if hasattr(shared.model, 'config'):
38 |                 return shared.model.config.hidden_size
39 | 
40 |             return shared.model.model.config.hidden_size
41 |         except AttributeError:
42 |             return len(generate_embeddings('generate a set of embeddings to determine size of result list', logger=logger))
43 | 
44 |     def save_files_to_disk(self, logger):
45 |         with self.lock:
46 |             try:
47 |                 logger(f"Cloning data before saving...", 3)
48 |                 metadata_to_save = copy.deepcopy(self.metadata)
49 |                 annoy_index_to_save = AnnoyIndex(self.hidden_size, 'angular')
50 |                 copy_items(self.annoy_index, annoy_index_to_save, self.annoy_index.get_n_items(), logger)
51 | 
52 |                 logger(f"Saving metadata...", 3)
53 |                 save_metadata(metadata_to_save, self.metadata_file)
54 |                 logger(f"Metadata saved.", 3)
55 |                 logger(f"Saving annoy_index...", 3)
56 |                 annoy_index_to_save.build(10)
57 |                 annoy_index_to_save.save(self.annoy_index_file)
58 |                 logger(f"annoy_index saved.", 3)
59 |             except Exception as e:
60 |                 logger(f"An error occurred while saving files to disk:\n{e}", level=1)
61 | 
62 | 
63 |     def generate_annoy_db(self, params, state, history, keyword_tally, logger):
64 |         with self.lock:
65 |             try:
66 |                 # Generate annoy database for LTM
67 |                 start_time = time.time()
68 | 
69 |                 self.metadata_file = f"{params['annoy_output_dir']}{state['name2']}-annoy-metadata.json"
70 |                 self.annoy_index_file = f"{params['annoy_output_dir']}{state['name2']}-annoy_index.ann"
71 | 
72 |                 if self.metadata is None:
73 |                     logger(f"Loading metadata file...", 5)
74 |                     self.metadata = load_metadata(self.metadata_file)
75 |                     logger(f"Loaded metadata.", 5)
76 |                     if self.metadata is None:
77 |                         logger(f"failed to load character annoy metadata, generating from scratch...", 1)
78 |                     else:
79 |                         logger(f"loaded metadata file ({len(self.metadata['messages_hash'])})", 2)
80 | 
81 | 
82 |                 hidden_size = self._get_hidden_size(params, logger)
83 |                 if self.annoy_index is None or self.hidden_size != hidden_size:
84 |                     self.hidden_size = hidden_size
85 |                     loaded_annoy_index = AnnoyIndex(self.hidden_size, 'angular')
86 |                     self.annoy_index = AnnoyIndex(self.hidden_size, 'angular')
87 | 
88 |                     if check_hashes(self.metadata, history, logger):
89 |                         logger(f"Loading annoy database...", 5)
90 |                         loaded_annoy_index.load(self.annoy_index_file)
91 |                         logger(f"Loaded database.", 5)
92 |                         loaded_history_items = loaded_annoy_index.get_n_items()
93 |                         if loaded_history_items < 1:
94 |                             logger(f"hashes check passed but no items found in annoy db. rebuilding annoy db...", 2)
95 |                         else:
96 |                             logger(f"hashes check passed, proceeding to load existing memory db...", 2)
97 |                             keyword_tally.importKeywordTally(self.metadata['keyword_tally'])
98 |                             self.index_to_history_position = {int(k): v for k, v in self.metadata['index_to_history_position'].items()}
99 |                             self.loaded_history_last_index = self.index_to_history_position[loaded_history_items-1]
100 |                             logger(f"loaded {self.loaded_history_last_index} items from existing memory db", 3)
101 |                             copy_items(loaded_annoy_index, self.annoy_index, loaded_history_items, logger)
102 |                             loaded_annoy_index.unload()
103 |                     else:
104 |                         logger(f"hashes check failed, either an existing message changed unexpectedly or the extension code has changed. Rebuilding annoy db...", 2)
105 |                         keyword_tally = KeywordTally()
106 |                         self.loaded_history_last_index = 0
107 | 
108 |                 formated_history_rows = apply_turn_templates_to_rows(history['internal'][self.loaded_history_last_index:], state, logger=logger)
109 |                 logger(f"found {len(formated_history_rows)} rows of chat history to be added to memory db. 
adding items...", 3) 110 | unique_index = len(self.index_to_history_position) 111 | for i, row in enumerate(formated_history_rows): 112 | for msg in row: 113 | trimmed_msg = remove_username_and_timestamp(msg, state) 114 | if trimmed_msg and len(trimmed_msg) > 0: 115 | # Add the full message 116 | logger(f"HISTORY_{i+1}_MSG: {msg}", 4) 117 | embeddings = generate_embeddings(trimmed_msg, logger=logger) 118 | self.annoy_index.add_item(unique_index, embeddings) 119 | self.index_to_history_position[unique_index] = i+self.loaded_history_last_index 120 | unique_index += 1 121 | 122 | # Add keywords and named entities 123 | keywords, named_entities = self.text_preprocessor.trim_and_preprocess_text(msg, state) 124 | keyword_tally.tally(keywords + named_entities) # Keep a tally of all keywords and named_entities 125 | filtered_keywords = filter_keywords(keywords) 126 | keyword_groups = generate_keyword_groups(filtered_keywords, params['keyword_grouping']) 127 | logger(f"HISTORY_{i+1}_KEYWORDS: {','.join(filtered_keywords)}", 4) 128 | for keyword in keyword_groups: 129 | embeddings = generate_embeddings(keyword, logger=logger) 130 | logger(f"storing keyword \"{keyword}\" with embeddings {embeddings}", 5) 131 | self.annoy_index.add_item(unique_index, embeddings) 132 | self.index_to_history_position[unique_index] = i+self.loaded_history_last_index 133 | unique_index += 1 134 | 135 | if len(named_entities) > 0: 136 | named_entities = " ".join(named_entities) 137 | embeddings = generate_embeddings(named_entities, logger=logger) 138 | logger(f"storing named_entities \"{named_entities}\" with embeddings {embeddings}", 5) 139 | self.annoy_index.add_item(unique_index, embeddings) 140 | self.index_to_history_position[unique_index] = i+self.loaded_history_last_index 141 | unique_index += 1 142 | 143 | self.loaded_history_last_index += len(formated_history_rows) 144 | 145 | # Save the annoy index and metadata 146 | code_hash, messages_hash = compute_hashes(history) 147 | self.metadata = { 148 | 'code_hash': code_hash, 149 | 'messages_hash': messages_hash, 150 | 'model_name': shared.model_name, 151 | 'index_to_history_position': self.index_to_history_position, 152 | 'keyword_tally': keyword_tally.exportKeywordTally() 153 | } 154 | 155 | 156 | # Put the result in the queue. 
157 |                 return_index = AnnoyIndex(self.hidden_size, 'angular')
158 |                 copy_items(self.annoy_index, return_index, self.annoy_index.get_n_items(), logger=logger)
159 |                 return_index_to_history_position = copy.copy(self.index_to_history_position)
160 | 
161 |                 return_index.build(10)
162 | 
163 |                 end_time = time.time()
164 |                 logger(f"building annoy index took {end_time-start_time} seconds...", 1)
165 | 
166 |                 self.results_queue.put((return_index_to_history_position, return_index, keyword_tally))
167 |                 return return_index_to_history_position, return_index, keyword_tally
168 | 
169 |             except Exception as e:
170 |                 logger(f"An error occurred while generating annoy database:\n{e}", level=1)
171 | 
172 |     def generate_and_save(self, params, state, history, keyword_tally, logger):
173 |         self.generate_annoy_db(params, state, history, keyword_tally, logger)
174 |         self.save_files_to_disk(logger)
175 | 
--------------------------------------------------------------------------------
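Stripped of the extension plumbing above, the annoy index lifecycle the manager relies on looks like this (standalone sketch with toy 4-dimensional vectors; the real code uses the model's hidden size):

    from annoy import AnnoyIndex

    index = AnnoyIndex(4, 'angular')              # same metric the manager uses
    index.add_item(0, [1.0, 0.0, 0.0, 0.0])
    index.add_item(1, [0.0, 1.0, 0.0, 0.0])
    index.add_item(2, [0.9, 0.1, 0.0, 0.0])
    index.build(10)                               # 10 trees, as in save_files_to_disk
    ids, dists = index.get_nns_by_vector([1.0, 0.0, 0.0, 0.0], 2, include_distances=True)
    # ids == [0, 2]: item 0 is an exact match, item 2 is the next nearest
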
/script.py:
--------------------------------------------------------------------------------
1 | from math import floor
2 | import time
3 | 
4 | from modules import shared
5 | from modules.extensions import apply_extensions
6 | from modules.text_generation import encode, get_max_prompt_length
7 | from annoy import AnnoyIndex
8 | from collections import deque
9 | import queue
10 | import concurrent.futures
11 | 
12 | from extensions.annoy_ltm.helpers import *
13 | from extensions.annoy_ltm.text_preprocessor import TextPreprocessor
14 | from extensions.annoy_ltm.embeddings import generate_embeddings
15 | from extensions.annoy_ltm.keyword_tally import KeywordTally
16 | from extensions.annoy_ltm.annoy_manager import AnnoyManager
17 | from extensions.annoy_ltm.turn_templates import get_turn_templates, apply_turn_templates_to_rows
18 | 
19 | # parameters which can be customized in settings.json of webui
20 | params = {
21 |     'annoy_output_dir': "extensions/annoy_ltm/outputs/",
22 |     'logger_level': 1, # higher number is more verbose logging. 3 is really as high as any reasonable person should go for normal debugging
23 |     'vector_dim_override': -1, # determined by your loaded model's hidden size. This parameter exists as a workaround in case some future style of model does not include hidden_size in its config.
24 |     'memory_retention_threshold': 0.7, # 0-1, lower values make memories be retained longer but can overflow the stack and hold onto irrelevant memories
25 |     'full_memory_additional_weight': 0.3, # 0-1, smaller value is more weight here.
26 |     'keyword_match_weight': 0.6, # 0-1, smaller value is more weight here.
27 |     'ner_character_len_weight': 100.0, # number of characters at which the weight of named entities maxes out.
28 |     'named_entity_match_clamp_min_dist': 0.6, # 0-1, clamp weight to this value. Prevents an exact NER match from overriding all other memories.
29 |     'num_memories_to_retrieve': 5, # the number of related memories to retrieve for the full message and every keyword group and named entity generated from the message. Can cause significant slowdowns.
30 |     'keyword_grouping': 4, # the number of keywords to group together. Higher means harder to find an exact match, which makes matches more useful to context, but too high and no memories will be returned.
31 |     'keyword_rarity_weight': 1, # Throttles the weight applied to memories favoring unique phrases and vocabulary.
32 |     'maximum_memory_stack_size': 50, # a cap on the stack so it doesn't blow up.
33 |     'prompt_memory_ratio': 0.4 # the ratio of the prompt after the character context is applied that will be dedicated to memories.
34 | }
35 | 
36 | #--------------- Logger ---------------
37 | def logger(message: str, level=5):
38 |     if params['logger_level'] >= level:
39 |         print(message)
40 | 
41 | #--------------- Custom Prompt Generator ---------------
42 | 
43 | class ChatGenerator:
44 |     def __init__(self):
45 |         self.memory_stack = deque()
46 |         self.loaded_character = None
47 |         self.loaded_model = None
48 |         self.keyword_tally = KeywordTally()
49 |         self.text_preprocessor = TextPreprocessor()
50 |         self.annoy_manager = AnnoyManager(self.text_preprocessor)
51 |         self.annoy_index = None
52 | 
53 |         # Create dictionary for annoy indices
54 |         self.index_to_history_position = {}
55 | 
56 |     #--------------- Memory ---------------
57 |     def compare_text_embeddings(self, text1, text2):
58 |         if len(text1) == 0 or len(text2) == 0:
59 |             return 1
60 | 
61 |         logger(f"comparing text {text1}\nagainst {text2}", 5)
62 |         text1_embeddings = generate_embeddings(text1, logger=logger)
63 |         text2_embeddings = generate_embeddings(text2, logger=logger)
64 |         logger(f"text1_embeddings: {text1_embeddings}", 6)
65 |         logger(f"text2_embeddings: {text2_embeddings}", 6)
66 |         logger(f"len text1_embeddings: {len(text1_embeddings)}", 6)
67 |         logger(f"len text2_embeddings: {len(text2_embeddings)}", 6)
68 |         cosine_similarity_value = cosine_similarity(text1_embeddings, text2_embeddings)
69 |         logger(f"manually computed cosine similarity: {cosine_similarity_value}", 5)
70 |         if cosine_similarity_value is None:
71 |             return 1
72 | 
73 |         return cosine_similarity_value
74 | 
75 | 
76 |     def evaluate_memory_relevance(self, state, memory, conversation, min_relevance_threshold=0.2) -> bool:
77 |         memory_text = ''.join([user_mem + '\n' + bot_mem for user_mem, bot_mem in memory])
78 |         conversation_text = ''.join(conversation)
79 |         logger(f"\nevaluating memory relevance for memory: {memory}", 4)
80 |         memory_keywords, memory_named_entities = self.text_preprocessor.trim_and_preprocess_text(memory_text, state)
81 |         conversation_keywords, conversation_named_entities = self.text_preprocessor.trim_and_preprocess_text(conversation_text, state)
82 | 
83 |         memory_keywords = " ".join(filter_keywords(memory_keywords))
84 |         conversation_keywords = " ".join(filter_keywords(conversation_keywords))
85 | 
86 |         memory_named_entities = " ".join(memory_named_entities)
87 |         conversation_named_entities = " ".join(conversation_named_entities)
88 | 
89 |         logger(f"comparing memory_keywords against conversation_keywords", 5)
90 |         keyword_similarity_value = self.compare_text_embeddings(memory_keywords, conversation_keywords)
91 |         logger(f"keyword_similarity_value: {keyword_similarity_value}", 6)
92 |         logger(f"comparing memory_named_entities against conversation_named_entities", 5)
93 |         named_entity_similarity_value = self.compare_text_embeddings(memory_named_entities, conversation_named_entities)
94 |         logger(f"named_entity_similarity_value: {named_entity_similarity_value}", 6)
95 |         named_entity_weight_calc = (named_entity_similarity_value + (1.0 - float((min(len(memory_named_entities), params['ner_character_len_weight']) / params['ner_character_len_weight'])))) / 2 # This is a bit of a hack to give more NEs a higher weight
96 |         logger(f"named_entity_weight_calc: {named_entity_weight_calc}", 6)
97 |         value_sum = keyword_similarity_value + named_entity_weight_calc
98 |         logger(f"value_sum: {value_sum}", 6)
99 |         similarity_value = 0.0
100 |         if value_sum > 0.0:
101 |             similarity_value = 
value_sum / 2.0 102 | logger(f"similarity_value: {similarity_value}", 5) 103 | relevance_value = 1.0 - similarity_value 104 | logger(f"calculated relevance: {relevance_value}", 3) 105 | return relevance_value >= min_relevance_threshold 106 | 107 | 108 | def retrieve_related_memories(self, state, history, annoy_index, input_messages, history_rows, index_to_history_position, keyword_tally, num_related_memories=3, weight=0.5): 109 | return_memories = set() 110 | for input_str in input_messages: 111 | logger(f"retrieving memories for {input_str} ", 3) 112 | if num_related_memories == 0: 113 | num_related_memories = annoy_index.get_n_items() 114 | input_embedding = generate_embeddings(remove_username_and_timestamp(input_str, state), logger=logger) 115 | results_indices = [] 116 | results_distances = [] 117 | 118 | # Query for the original input_embedding 119 | indices, distances = annoy_index.get_nns_by_vector(input_embedding, num_related_memories, include_distances=True) 120 | results_indices.extend(indices) 121 | results_distances.extend(map(lambda x: x * weight, distances)) 122 | original_input_results_count = len(results_distances) 123 | 124 | # Get keywords and named entities 125 | keywords, named_entities = self.text_preprocessor.trim_and_preprocess_text(input_str, state) 126 | filtered_keywords = filter_keywords(keywords) 127 | keyword_groups = generate_keyword_groups(filtered_keywords, params['keyword_grouping']) 128 | logger(f"INPUT_KEYWORDS: {','.join(filtered_keywords)}", 4) 129 | 130 | # Query for each keyword_embedding 131 | for keyword in keyword_groups: 132 | keyword_embedding = generate_embeddings(keyword, logger=logger) 133 | logger(f"looking up keyword \"{keyword}\" embeddings {keyword_embedding}", 3) 134 | indices, distances = annoy_index.get_nns_by_vector(keyword_embedding, num_related_memories, include_distances=True) 135 | logger(f"keyword matches: {keyword}\n{indices}\n{distances}", 3) 136 | results_indices.extend(indices) 137 | results_distances.extend(map(lambda x: x*params['keyword_match_weight'], distances)) 138 | 139 | # Query for each named entity 140 | if len(named_entities) > 0: 141 | named_entities = " ".join(named_entities) 142 | named_entity_embedding = generate_embeddings(named_entities, logger=logger) 143 | logger(f"looking up named entity \"{named_entities}\" embeddings {named_entity_embedding}", 3) 144 | indices, distances = annoy_index.get_nns_by_vector(named_entity_embedding, num_related_memories, include_distances=True) 145 | logger(f"named_entities matches: {named_entities}\n{indices}\n{distances}", 3) 146 | results_indices.extend(indices) 147 | results_distances.extend(map(lambda x: x * (1-params['named_entity_match_clamp_min_dist']) + params['named_entity_match_clamp_min_dist'] , distances)) 148 | 149 | if len(results_indices) == 0: 150 | return [] # If we don't have any results, not much point in progressing. 151 | 152 | # 1. Combine the results 153 | indices_distances = list(zip(results_indices, results_distances)) 154 | 155 | # 3. 
Create a new list of unique history positions tupled with their distance while applying weights for duplicates 156 | history_positions_distances = {} 157 | for index, distance in indices_distances: 158 | history_position = index_to_history_position[index] 159 | if history_position in history_positions_distances: 160 | history_positions_distances[history_position].append(distance) 161 | else: 162 | history_positions_distances[history_position] = [distance] 163 | 164 | 165 | weighted_history_positions = [(pos, min(distances) / len(distances)) for pos, distances in history_positions_distances.items()] 166 | 167 | return_memories.update(set(weighted_history_positions)) 168 | # return_memories.extend(weighted_history_positions) 169 | 170 | # 4. Get the related memories using the new sorted list 171 | related_memories = [(pos, history['internal'][max(0, pos - 1):pos + 1], distance) for pos, distance in list(return_memories)] 172 | 173 | # Get keywords for each memory and calculate their significance 174 | for i in range(len(related_memories)): 175 | index, memory, distance = related_memories[i] 176 | memory_keywords = [] 177 | for user_msg, bot_reply in memory: 178 | usr_keywords, usr_ne = self.text_preprocessor.trim_and_preprocess_text(user_msg, state) 179 | bot_keywords, bot_ne = self.text_preprocessor.trim_and_preprocess_text(bot_reply, state) 180 | memory_keywords.extend(filter_keywords(usr_keywords + usr_ne)) 181 | memory_keywords.extend(filter_keywords(bot_keywords + bot_ne)) 182 | 183 | significance = params['keyword_rarity_weight'] * keyword_tally.get_significance(memory_keywords) 184 | logger(f"keywords [{','.join(memory_keywords)}] significance calculated at {significance}", 4) 185 | 186 | # Apply the significance ratio to the memory's distance value 187 | related_memories[i] = (index, memory, distance * significance) 188 | 189 | # 5. Sort the new list 190 | sorted_weighted_related_memories = sorted(related_memories, key=lambda x: (x[2], x[0])) 191 | logger(f"RESULTS: {sorted_weighted_related_memories}", 4) 192 | 193 | # 6. 
Filter out memories that are already present in the history added to the prompt
194 |         non_duplicate_memories = [
195 |             (index, memory, distance) for index, memory, distance in sorted_weighted_related_memories
196 |             if all(msg not in history_rows for msg in memory)
197 |         ]
198 | 
199 |         return non_duplicate_memories
200 | 
201 | 
202 | 
203 |     def build_memory_rows(self, state, history_rows, user_input, max_memory_length, turn_templates, relevance_threshold=0.2):
204 |         user_turn, bot_turn = turn_templates
205 | 
206 |         # Filter out irrelevant memories
207 |         logger(f"HISTORY_ROWS:{history_rows}", 5)
208 |         conversation = [remove_username_and_timestamp(row, state) for row in history_rows] + [remove_timestamp(user_input)]
209 |         logger(f"CONVERSATION:{conversation}", 5)
210 | 
211 |         def log_and_check_relevance(memory_tuple, conversation, relevance_threshold):
212 |             relevance_check = self.evaluate_memory_relevance(state, memory_tuple[1], conversation, relevance_threshold)
213 |             logger(f"\nrelevance_check: {relevance_check}\nmemory_tuple: {memory_tuple}", 4)
214 |             return relevance_check
215 | 
216 |         # Use the log_and_check_relevance function in the list comprehension
217 |         new_memory_stack = [memory_tuple for memory_tuple in self.memory_stack if log_and_check_relevance(memory_tuple, conversation, relevance_threshold)]
218 |         new_memory_stack = new_memory_stack[:params['maximum_memory_stack_size']]
219 | 
220 |         logger(f"MEMORY_STACK:{new_memory_stack}", 5)
221 |         logger(f"MEMORY_STACK SIZE: {len(new_memory_stack)}", 3)
222 | 
223 |         # Create memory_rows
224 | 
225 |         memory_len = 0
226 |         memory_index = 0
227 |         returned_memories = 0
228 |         memory_rows = []
229 |         last_index = 0
230 |         last_memory_rows_count = 0
231 | 
232 |         while memory_index < len(new_memory_stack) and memory_len < max_memory_length:
233 |             index, memory, _ = new_memory_stack[memory_index]
234 |             new_memory_len = memory_len
235 | 
236 |             i = len(memory) - 1
237 |             stop_i = 0
238 |             new_memory_rows_count = 0
239 |             if last_index == index-1:
240 |                 i -= 1 # should this happen to be a continuation, then we will skip adding the duplicate memory
241 |             if last_index == index+1:
242 |                 stop_i = 1
243 |             while i >= stop_i and memory_len < max_memory_length:
244 | 
245 |                 turn = memory[i]
246 |                 proposed_user_turn = ''
247 |                 proposed_bot_turn = ''
248 | 
249 |                 if len(turn) > 0:
250 |                     user_memory, ai_memory = turn
251 |                     logger(f"user_memory:{user_memory}\nai_memory:{ai_memory}", 5)
252 |                     proposed_user_turn = replace_all(user_turn, {'<|user-message|>': user_memory.strip(), '<|round|>': str(index)})
253 |                     proposed_bot_turn = bot_turn.replace('<|bot-message|>', ai_memory.strip())
254 | 
255 |                     new_memory_len = new_memory_len + len(encode(proposed_user_turn)[0]) + len(encode(proposed_bot_turn)[0])
256 | 
257 |                 if new_memory_len <= max_memory_length:
258 |                     if last_index == index+1:
259 |                         memory_rows.insert(last_memory_rows_count, proposed_bot_turn)
260 |                         memory_rows.insert(last_memory_rows_count, proposed_user_turn)
261 |                     else:
262 |                         memory_rows.insert(0, proposed_bot_turn)
263 |                         memory_rows.insert(0, proposed_user_turn)
264 | 
265 |                     logger(f"adding memory rows from stack for index {index}...\n{proposed_user_turn}\n{proposed_bot_turn}\n", 3)
266 |                     new_memory_rows_count += 2
267 | 
268 |                 else:
269 |                     break
270 | 
271 |                 i -= 1
272 | 
273 |             memory_len = new_memory_len
274 |             returned_memories += 1
275 |             memory_index += 1
276 |             last_index = index
277 |             last_memory_rows_count = new_memory_rows_count
278 | 
279 |         non_relevant_memories = [(index, memory) for index, memory, _ in self.memory_stack 
if index not in [i for i, _, _ in new_memory_stack]]
280 |         memory_index = 0
281 |         while memory_index < len(non_relevant_memories) and memory_len < max_memory_length:
282 |             index, memory = non_relevant_memories[memory_index]
283 |             new_memory_len = memory_len
284 | 
285 |             i = len(memory) - 1
286 |             stop_i = 0
287 |             new_memory_rows_count = 0
288 |             if last_index == index-1:
289 |                 i -= 1 # should this happen to be a continuation, then we will skip adding the duplicate memory
290 |             if last_index == index+1:
291 |                 stop_i = 1
292 |             while i >= stop_i and memory_len < max_memory_length:
293 |                 turn = memory[i]
294 |                 proposed_user_turn = ''
295 |                 proposed_bot_turn = ''
296 | 
297 |                 if len(turn) > 0:
298 |                     user_memory, ai_memory = turn
299 |                     logger(f"user_memory:{user_memory}\nai_memory:{ai_memory}", 5)
300 |                     proposed_user_turn = replace_all(user_turn, {'<|user-message|>': user_memory.strip(), '<|round|>': str(index)})
301 |                     proposed_bot_turn = bot_turn.replace('<|bot-message|>', ai_memory.strip())
302 | 
303 |                     new_memory_len = new_memory_len + len(encode(proposed_user_turn)[0]) + len(encode(proposed_bot_turn)[0])
304 | 
305 |                 if new_memory_len <= max_memory_length:
306 |                     if last_index == index+1:
307 |                         memory_rows.insert(last_memory_rows_count, proposed_bot_turn)
308 |                         memory_rows.insert(last_memory_rows_count, proposed_user_turn)
309 |                     else:
310 |                         memory_rows.insert(0, proposed_bot_turn)
311 |                         memory_rows.insert(0, proposed_user_turn)
312 | 
313 |                     logger(f"adding memory rows from non_relevant_memories for index {index}...\n{proposed_user_turn}\n{proposed_bot_turn}\n", 3)
314 |                     new_memory_rows_count += 2
315 | 
316 |                 else:
317 |                     break
318 | 
319 | 
320 |                 i -= 1
321 | 
322 |             memory_len = new_memory_len
323 |             returned_memories += 1
324 |             memory_index += 1
325 |             last_index = index
326 |             last_memory_rows_count = new_memory_rows_count
327 | 
328 |         self.memory_stack = new_memory_stack
329 | 
330 |         return memory_rows, returned_memories
331 | 
332 |     def custom_generate_chat_prompt(self, user_input, state, **kwargs):
333 |         impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
334 |         _continue = kwargs['_continue'] if '_continue' in kwargs else False
335 |         also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
336 |         is_instruct = state['mode'] == 'instruct'
337 |         rows = [state['context'] if is_instruct else f"{state['context'].strip()}\n"]
338 |         min_rows = 3
339 | 
340 |         generate_annoy_db_executor = None
341 |         save_files_to_disk_executor = None
342 | 
343 | 
344 |         # Check if the memory_stack needs to be refreshed.
345 |         if state['name2'] != self.loaded_character or shared.model_name != self.loaded_model:
346 |             self.memory_stack = deque()
347 |             self.annoy_index = None
348 |             self.loaded_character = state['name2']
349 |             self.loaded_model = shared.model_name # record the loaded model so the staleness check above works
350 |         # Generate annoy database for LTM
351 |         if self.annoy_index is None:
352 |             self.index_to_history_position, self.annoy_index, self.keyword_tally = self.annoy_manager.generate_annoy_db(params, state, kwargs['history'], self.keyword_tally, logger)
353 |             self.annoy_manager.save_files_to_disk(logger)
354 |         else:
355 |             generate_annoy_db_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
356 |             generate_annoy_db_executor.submit(
357 |                 self.annoy_manager.generate_annoy_db,
358 |                 params,
359 |                 state,
360 |                 kwargs['history'],
361 |                 self.keyword_tally,
362 |                 logger
363 |             )
364 |             result = None
365 |             while not self.annoy_manager.results_queue.empty():
366 |                 try:
367 |                     result = self.annoy_manager.results_queue.get_nowait()
368 |                 except queue.Empty:
369 |                     continue # in case the queue was emptied between the check and get_nowait()
370 | 
371 |             # Check if a result was actually fetched before trying to unpack it
372 |             if result is not None:
373 |                 self.index_to_history_position, self.annoy_index, self.keyword_tally = result
374 |                 # Save files to disk
375 |                 save_files_to_disk_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
376 |                 save_files_to_disk_executor.submit(
377 |                     self.annoy_manager.save_files_to_disk,
378 |                     logger
379 |                 )
380 |                 save_files_to_disk_executor.shutdown(wait=False)
381 |             generate_annoy_db_executor.shutdown(wait=False)
382 | 
383 |         logger(f"Annoy database has length {self.annoy_index.get_n_items()}", 3)
384 | 
385 |         # Finding the maximum prompt size
386 |         max_length = get_max_prompt_length(state)
387 |         # Calc the max length for the memory block
388 |         max_memory_length = floor(max_length * params['prompt_memory_ratio']) - len(encode("Memories:\n\n\nChat:\n")[0])
389 | 
390 |         # Get turn templates
391 |         user_turn, bot_turn, user_turn_stripped, bot_turn_stripped = get_turn_templates(state, is_instruct, logger=logger)
392 | 
393 |         # Building the prompt
394 |         memories_header = "Memories:\n"
395 |         chat_header = "\nChat:\n"
396 |         mem_head_len = len(encode(memories_header)[0])
397 |         chat_head_len = len(encode(chat_header)[0])
398 |         history_partial = []
399 |         history_rows = []
400 |         i = len(kwargs['history']['internal']) - 1
401 |         max_history_length = max_length - len(encode(''.join(rows))[0]) - max_memory_length - mem_head_len - chat_head_len
402 |         while i >= 0 and len(encode(''.join(history_rows))[0]) < max_history_length:
403 |             if _continue and i == len(kwargs['history']['internal']) - 1:
404 |                 history_rows.insert(0, bot_turn_stripped + kwargs['history']['internal'][i][1].strip())
405 |             else:
406 |                 history_rows.insert(0, bot_turn.replace('<|bot-message|>', kwargs['history']['internal'][i][1].strip()))
407 | 
408 |             string = kwargs['history']['internal'][i][0]
409 |             if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
410 |                 history_rows.insert(0, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
411 | 
412 |             history_partial.append(kwargs['history']['internal'][i])
413 |             i -= 1
414 | 
415 |         # Adding related memories to the prompt
416 |         rows.insert(0, memories_header)
417 |         memory_trigger = []
418 |         if len(kwargs['history']['internal']) > 0 and len(kwargs['history']['internal'][-1]) > 1:
419 |             memory_trigger.append(kwargs['history']['internal'][-1][1])
420 |         memory_trigger.append(user_input)
421 |         related_memories = 
self.retrieve_related_memories( 422 | state, 423 | kwargs['history'], 424 | self.annoy_index, 425 | memory_trigger, 426 | history_partial, 427 | self.index_to_history_position, 428 | self.keyword_tally, 429 | num_related_memories=params['num_memories_to_retrieve'], 430 | weight=params['full_memory_additional_weight'] 431 | ) 432 | 433 | # self.annoy_index.unload() # Unload the index so the next one can save properly 434 | 435 | # Merge new memories into memory stack by distance. 436 | self.memory_stack = remove_duplicates(merge_memory_lists_by_distance(self.memory_stack, related_memories, max_new_list_length=params['maximum_memory_stack_size']*params['num_memories_to_retrieve'])) 437 | logger(f"merged {len(related_memories)} memories into stack. Stack size:{len(self.memory_stack)}", 3) 438 | logger(f"MEMORY_STACK:\n{self.memory_stack}", 5) 439 | 440 | memory_rows, num_memories = self.build_memory_rows(state, history_rows[-2:], user_input, max_memory_length, (user_turn, bot_turn), relevance_threshold=params['memory_retention_threshold']) 441 | logger(f"memory_rows:\n{memory_rows}", 5) 442 | # Remove the least relevant memory row from the memory stack so that the stack will be worked through one memory at a time with each prompt. 443 | if num_memories > 0 and len(self.memory_stack) > 0: 444 | self.memory_stack.pop(min(len(self.memory_stack)-1, num_memories-1)) 445 | 446 | 447 | # Insert memory_rows to the prompt 448 | rows.insert(1, chat_header) 449 | rows[1:1] = memory_rows 450 | 451 | 452 | # Insert the history_rows 453 | rows.extend(history_rows) 454 | 455 | if impersonate: 456 | min_rows = 2 457 | rows.append(user_turn_stripped.rstrip(' ')) 458 | elif not _continue: 459 | # Adding the user message 460 | if len(user_input) > 0: 461 | rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(kwargs['history']["internal"]))})) 462 | 463 | # Adding the Character prefix 464 | rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' '), state=state)) 465 | 466 | while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length: 467 | if len(rows) > 3 + len(memory_rows) + min_rows: 468 | rows.pop(3 + len(memory_rows)) 469 | elif len(rows) > 3 + min_rows: 470 | rows.pop(2) 471 | else: 472 | rows.pop(1) 473 | 474 | prompt = ''.join(rows) 475 | logger(f"custom_generated_prompt:\n\n{prompt}\n\n", 2) 476 | logger(f"prompt_len:{len(encode(prompt)[0])}\nmax_length:{max_length}\nmax_memory_length:{max_memory_length}\nmax_history_length:{max_history_length}\nmax_content_length:{max_history_length+max_memory_length}\ntotal_content_length:{len(encode(rows[0])[0]) + max_history_length + max_memory_length}", 2) 477 | 478 | if also_return_rows: 479 | return prompt, rows 480 | else: 481 | return prompt 482 | 483 | generator = ChatGenerator() 484 | 485 | def custom_generate_chat_prompt(user_input, state, **kwargs): 486 | return generator.custom_generate_chat_prompt(user_input, state, **kwargs) 487 | --------------------------------------------------------------------------------
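For reference, the webui picks up the module-level custom_generate_chat_prompt hook defined at the bottom of script.py. Conceptually it is invoked like this (a sketch only; the real state and history objects come from the webui and carry many more keys):

    state = {'mode': 'chat', 'name1': 'You', 'name2': 'Assistant', 'context': '...'}  # hypothetical
    history = {'internal': [["hello", "hi, how can I help?"]]}                        # hypothetical
    prompt = custom_generate_chat_prompt("what did we talk about?", state, history=history)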