├── images
│   └── llm-agent.png
├── .env.example
├── pyproject.toml
├── src
│   ├── utils.py
│   ├── chat.py
│   ├── db.py
│   ├── embedder.py
│   ├── search_engine.py
│   ├── summarizer.py
│   └── llm_agent.py
├── README.md
├── .gitignore
└── requirements.txt

/images/llm-agent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AkiRusProd/llm-agent/HEAD/images/llm-agent.png
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
LLM_PATH='your llm path'
HUGGINGFACE_HUB_CACHE='your huggingface hub cache path to store summarizer models'
DB_PATH='your vector db path'
GOOGLE_CLOUD_API_KEY='your google cloud api key'
GOOGLE_SEARCH_ENGINE_ID='your google search engine id'
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[tool.poetry]
name = "llm_agent"
version = "0.1.1"
description = "LLM using long-term memory through a vector database"
authors = ["Rustam Akimov"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.10"
chromadb = "^0.4.24"
gpt4all = "^2.2.1.post1"
llama-cpp-python = "^0.2.53"
bs4 = "^0.0.2"
torch = "^2.2.1"
transformers = "^4.38.1"
tokenizers = "^0.15.2"
termcolor = "^2.4.0"
guidance = "^0.1.13"
# python-dotenv is imported directly by every module in src/, so declare it
# explicitly instead of relying on it as a transitive dependency.
python-dotenv = "^1.0.1"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
import numpy as np
from functools import wraps
from termcolor import colored


def cosine_similarity(a: np.ndarray, b: np.ndarray):
    """Cosine similarity between two vectors (or batches of row vectors)."""
    return np.dot(a, b.T) / (np.linalg.norm(a) * np.linalg.norm(b))


def logging(enabled=True, message="", color="yellow"):
    """Decorator factory: prints a colored log line before calling the wrapped function."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if enabled:
                print(f"LOG: {colored(message, color=color)}")
            return func(*args, **kwargs)
        return wrapper
    return decorator
--------------------------------------------------------------------------------
/src/chat.py:
--------------------------------------------------------------------------------
from llm_agent import LLMAgent
from embedder import HFEmbedder
from search_engine import SearchEngine
from summarizer import Summarizer
from db import DBInstance
from dotenv import dotenv_values

env = dotenv_values(".env")


def chat(llm_agent: LLMAgent):
    while True:
        user_text_request = input("You > ")

        bot_text_response = llm_agent.generate(user_text_request)
        print(f"Bot < {bot_text_response}")


if __name__ == "__main__":
    embedder = HFEmbedder()
    search_engine = SearchEngine()
    summarizer = Summarizer()

    db_instance = DBInstance("long-term-memory", embedder=embedder)

    llm_agent = LLMAgent(
        env["LLM_PATH"], db_instance, summarizer, search_engine, use_summarizer=False
    )

    chat(llm_agent)
--------------------------------------------------------------------------------
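Before the database layer, a minimal sketch of how the two helpers in `src/utils.py` compose (the toy vectors and the demo function are made up for illustration):

```python
import numpy as np
from utils import cosine_similarity, logging


@logging(enabled=True, message="[Demo]", color="cyan")
def compare():
    # cosine_similarity of these two vectors is ~0.707.
    a = np.array([1.0, 0.0])
    b = np.array([1.0, 1.0])
    return cosine_similarity(a, b)


print(compare())  # prints "LOG: [Demo]" first, then the similarity value
```

This is the same decorator pattern `llm_agent.py` uses for its `[Adding to memory]`, `[Searching]`, and `[Response]` log lines.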
/src/db.py:
--------------------------------------------------------------------------------
import chromadb
import uuid
import datetime
from embedder import BaseEmbedder
from dotenv import dotenv_values

env = dotenv_values(".env")
DB_PATH = env["DB_PATH"]


class DBInstance():
    def __init__(self, collection_name, db_path=DB_PATH, embedder: BaseEmbedder = None):
        self.embedder = embedder
        self.client = chromadb.PersistentClient(path=db_path)
        self.collection = self.client.get_or_create_collection(name=collection_name, embedding_function=self.embedder)

    def add(self, text, metadata=None):
        # None instead of a mutable {} default: each call gets a fresh dict.
        metadata = dict(metadata or {})
        metadata['timestamp'] = str(datetime.datetime.now())

        self.collection.add(
            documents=[text],
            metadatas=[metadata],
            ids=[str(uuid.uuid4())]
        )

    def delete(self, id):
        # Chroma expects a list of ids, not a bare string.
        self.collection.delete(ids=[id])

    def query(self, query, n_results, return_text=True):
        results = self.collection.query(
            query_texts=query,
            n_results=n_results,
        )

        if return_text:
            return results['documents'][0]
        else:
            return results
--------------------------------------------------------------------------------
/src/embedder.py:
--------------------------------------------------------------------------------
import torch
import os
import numpy as np
from transformers import AutoModel, AutoTokenizer
from chromadb import EmbeddingFunction
from gpt4all import Embed4All
from dotenv import dotenv_values

env = dotenv_values(".env")
os.environ['HUGGINGFACE_HUB_CACHE'] = env['HUGGINGFACE_HUB_CACHE']


class BaseEmbedder(EmbeddingFunction):
    def __init__(self):
        pass

    def get_embeddings(self, texts):
        raise NotImplementedError("Subclasses should implement this!")

    def __call__(self, input):
        return self.get_embeddings(input)


class GPT4AllEmbedder(BaseEmbedder):
    def __init__(self):
        self.embedder = Embed4All()  # default: all-MiniLM-L6-v2

    def get_embeddings(self, texts):
        if isinstance(texts, str):
            texts = [texts]

        embeddings = []
        for text in texts:
            embeddings.append(self.embedder.embed(text))

        return embeddings


# https://huggingface.co/princeton-nlp/sup-simcse-roberta-large
class HFEmbedder(BaseEmbedder):
    def __init__(self, model='princeton-nlp/sup-simcse-roberta-large'):  # alternative: sentence-transformers/all-MiniLM-L6-v2
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = AutoModel.from_pretrained(model).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model)

    def get_embeddings(self, texts):
        if isinstance(texts, str):
            texts = [texts]

        inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(self.device)

        with torch.no_grad():
            embeddings = self.model(**inputs, output_hidden_states=True, return_dict=True).pooler_output.detach().cpu().numpy()

        # L2-normalize so that dot products equal cosine similarities.
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        normalized_embeddings = embeddings / norms

        return normalized_embeddings.tolist()
--------------------------------------------------------------------------------
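As a quick round-trip check of `db.py` and `embedder.py` together, a sketch (it assumes a `.env` with `DB_PATH` set; `GPT4AllEmbedder` is used here because it runs on CPU without downloading a large HF model; the collection name is arbitrary):

```python
from db import DBInstance
from embedder import GPT4AllEmbedder

db = DBInstance("demo-memory", embedder=GPT4AllEmbedder())
db.add("The user name is Rustam Akimov", metadata={"source": "demo"})

# With return_text=True (the default), query() returns the matched documents as plain text.
print(db.query("What is the user's name?", n_results=1))
```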
/README.md:
--------------------------------------------------------------------------------
# long-term-memory-llm
RAG-based LLM using long-term memory through a vector database

## Description
This repository gives a large language model long-term memory backed by a vector database, a technique known as RAG (Retrieval-Augmented Generation): the LLM retrieves facts from an external store to ground its answers. The application is built with [mistral-7b-instruct-v0.2.Q4_K_M.gguf](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF) (through the [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) binding) and [chromadb](https://github.com/chroma-core/chroma). In natural language, the user can ask the agent to add information to the database, retrieve information from it, or search the Internet; the request is routed to the right action with [guidance](https://github.com/guidance-ai/guidance).


### Current features:
- add new memory: add information (in quotes) to the database in natural language
- query memory: request information from the database in natural language
- web search (experimental): find information on the Internet in natural language

### Diagram:
![Diagram](images/llm-agent.png)

### Example:
```
You > Hi
LOG: [Response]
Bot < Hello! How can I assist you today?
You > Please add information to db "The user name is Rustam Akimov"
LOG: [Adding to memory]
Bot < Done!
You > Can you find on the Internet who is Pavel Durov
LOG: [Extracting question]
LOG: [Searching]
LOG: [Summarizing]
Bot < According to the search results provided, Pavel Durov is a Russian entrepreneur who co-founded Telegram Messenger Inc.
You > Please find information in db who is Rustam Akimov
LOG: [Extracting question]
LOG: [Querying memory]
Bot < According to the input memories, your name is Rustam Akimov.
```

### Usage:
- Install the dependencies from requirements.txt
- Download [mistral-7b-instruct-v0.2.Q4_K_M.gguf](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF) (note: other models can be used)
- Get a [Google API key](https://developers.google.com/webmaster-tools/search-console-api/v1/configure) and a [Search Engine ID](https://programmablesearchengine.google.com/controlpanel/create)
- Copy .env.example to .env and fill in the variables
- Run [chat.py](src/chat.py), or embed the agent in your own code (see the sketch after this file)
--------------------------------------------------------------------------------
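To embed the agent in your own script instead of running the REPL in `chat.py`, a minimal sketch (it assumes a populated `.env`; the wiring mirrors `src/chat.py`, here with the summarizer switched on):

```python
from llm_agent import LLMAgent
from embedder import HFEmbedder
from search_engine import SearchEngine
from summarizer import Summarizer
from db import DBInstance
from dotenv import dotenv_values

env = dotenv_values(".env")

agent = LLMAgent(
    env["LLM_PATH"],
    DBInstance("long-term-memory", embedder=HFEmbedder()),
    Summarizer(),
    SearchEngine(),
    use_summarizer=True,  # condense scraped pages before they reach the LLM context
)

print(agent.generate('Please add information to db "I like chess"'))
```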
/src/search_engine.py:
--------------------------------------------------------------------------------
import requests
from bs4 import BeautifulSoup
from dotenv import dotenv_values

env = dotenv_values(".env")
# from googlesearch import search


API_KEY = env['GOOGLE_CLOUD_API_KEY']
SEARCH_ENGINE_ID = env['GOOGLE_SEARCH_ENGINE_ID']


class SearchEngine():
    def __init__(self) -> None:
        # self.url = lambda query: f'https://www.googleapis.com/customsearch/v1?key={API_KEY}&cx={SEARCH_ENGINE_ID}&q={query}'
        self.url = 'https://www.googleapis.com/customsearch/v1'

    def payload(self, query, start=1, n_results=3, date_restrict='m1', **params):
        payload = {
            'key': API_KEY,
            'q': query,  # Query string
            'start': start,  # Index of the first search result
            'cx': SEARCH_ENGINE_ID,
            'num': n_results,  # Number of search results to return
            # 'dateRestrict': date_restrict,  # Date restriction for search results
        }
        payload.update(params)

        return payload

    def scrape(self, url):
        try:
            response = requests.get(url)
        except requests.RequestException:
            return None

        html = response.text
        soup = BeautifulSoup(html, 'html.parser')

        # Drop script/style tags so only visible page text remains.
        for script in soup(['script', 'style']):
            script.extract()

        text = soup.get_text()

        clean_text = ' '.join(text.split())
        return clean_text

    def search(self, query: str, n_results: int = 3):
        # response = requests.get(self.url(query))
        response = requests.get(self.url, params=self.payload(query, n_results=n_results))

        results = []

        if response.status_code == 200:
            data = response.json()

            items = data.get('items', [])
            for item in items:
                title = item.get('title')
                link = item.get('link')
                content = self.scrape(link)

                if content is not None:
                    results.append({'title': title, 'link': link, 'content': content})
        else:
            print(f'Error occurred during the search (status code {response.status_code}).')

        return results


# search_engine = SearchEngine()
# response = search_engine.search('python')
--------------------------------------------------------------------------------
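To exercise the search engine on its own, a sketch (it assumes `GOOGLE_CLOUD_API_KEY` and `GOOGLE_SEARCH_ENGINE_ID` are set in `.env`; the query is arbitrary):

```python
from search_engine import SearchEngine

engine = SearchEngine()
for result in engine.search("Pavel Durov", n_results=3):
    print(result["title"], "->", result["link"])
    print(result["content"][:200])  # scraped page text, whitespace-normalized
```

Pages that fail to download are skipped, so fewer than `n_results` entries may come back.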
/src/summarizer.py:
--------------------------------------------------------------------------------
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from dotenv import dotenv_values

env = dotenv_values(".env")
os.environ['HUGGINGFACE_HUB_CACHE'] = env['HUGGINGFACE_HUB_CACHE']

# checkpoint = "t5-small"
# checkpoint = "google/mt5-small"
# checkpoint = "facebook/bart-large-cnn"
checkpoint = "sshleifer/distilbart-cnn-12-6"

# Earlier pipeline-based implementation (needs `from transformers import pipeline`), kept for reference:
# class Summarizer():
#     def __init__(self, model = checkpoint) -> None:
#         self.summarizer = pipeline("summarization", model = model)  # min_length = 30, max_length = 300
#
#     def summarize(self, text: str, min_length_ratio = 0.3, max_length_ratio = 1.):
#         if len(text) < 5:
#             return text
#
#         prompt = f"summarize: {text}"
#
#         return self.summarizer(prompt, min_length = int(min_length_ratio * len(prompt.split(" "))), max_length = int(max_length_ratio * len(prompt.split(" "))))[0]['summary_text']
#
#     def __call__(self, text, min_length_ratio = 0.3, max_length_ratio = 1.):
#         return self.summarize(text, min_length_ratio, max_length_ratio)


# https://discuss.huggingface.co/t/summarization-on-long-documents/920/23
# https://www.width.ai/post/4-long-text-summarization-methods

class Summarizer():
    def __init__(self, model=checkpoint) -> None:
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model)

    def summarize(self, text: str, min_length=30, max_length=100):
        """Fixed-size chunking: split the input into model-sized chunks, summarize each, and join the summaries."""
        inputs_no_trunc = self.tokenizer(text, max_length=None, return_tensors='pt', truncation=False)
        if len(inputs_no_trunc['input_ids'][0]) < 30:
            return text

        inputs_batch_lst = []
        chunk_start = 0
        chunk_end = self.tokenizer.model_max_length  # == 1024 for BART
        while chunk_start < len(inputs_no_trunc['input_ids'][0]):  # strict < avoids an empty final chunk
            inputs_batch = inputs_no_trunc['input_ids'][0][chunk_start:chunk_end]  # get batch of n tokens
            inputs_batch = torch.unsqueeze(inputs_batch, 0)
            inputs_batch_lst.append(inputs_batch)
            chunk_start += self.tokenizer.model_max_length
            chunk_end += self.tokenizer.model_max_length
        summary_ids_lst = [self.model.generate(inputs.to(self.device), num_beams=4, min_length=min_length, max_length=max_length, early_stopping=True) for inputs in inputs_batch_lst]

        summary_batch_lst = []
        for summary_id in summary_ids_lst:
            summary_batch = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_id]
            summary_batch_lst.append(summary_batch[0])
        summary_all = '\n'.join(summary_batch_lst)

        return summary_all

    def __call__(self, text, min_length=30, max_length=100):
        return self.summarize(text, min_length, max_length)
--------------------------------------------------------------------------------
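To try the chunked summarizer directly, a sketch (it assumes `HUGGINGFACE_HUB_CACHE` is set in `.env`; the first run downloads `sshleifer/distilbart-cnn-12-6`; `article.txt` is a hypothetical input file):

```python
from summarizer import Summarizer

summarizer = Summarizer()

# Inputs shorter than ~30 tokens are returned unchanged; longer inputs are
# split into 1024-token chunks and each chunk is summarized separately.
long_text = open("article.txt", encoding="utf-8").read()
print(summarizer(long_text, min_length=30, max_length=100))
```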
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.venv/
db/
models/
hf-models/
test.py
.google-cookie
--------------------------------------------------------------------------------
/src/llm_agent.py:
--------------------------------------------------------------------------------
import re
from search_engine import SearchEngine
from summarizer import Summarizer
from db import DBInstance
from guidance.models import LlamaCpp
from guidance import gen, select
from utils import logging


enable_logging = True


class LLMAgent:
    def __init__(
        self,
        model_name: str = None,
        db_instance: DBInstance = None,
        summarizer: Summarizer = None,
        search_engine: SearchEngine = None,
        use_summarizer=True,
    ) -> None:

        self.llm = LlamaCpp(model=model_name, n_ctx=8192, verbose=True)
        self.db_instance = db_instance
        self.memory_access_threshold = 1.5
        # self.similarity_threshold = 0.5  # [0; 1]
        self.db_n_results = 3
        self.se_n_results = 1
        self.use_summarizer = use_summarizer

        self.summarizer = summarizer
        self.search_engine = search_engine
        self.chat_prompt_template = "[INST] {prompt} [/INST]"

    @logging(enable_logging, message="[Adding to memory]")
    def add(self, request):
        if request != "":
            self.db_instance.add(request)

    @logging(enable_logging, message="[Querying memory]")
    def memory_response(self, request):
        memory_queries_data = self.db_instance.query(
            request, n_results=self.db_n_results, return_text=False
        )
        memory_queries = memory_queries_data["documents"][0]
        memory_queries_distances = memory_queries_data["distances"][0]

        acceptable_memory_queries = []

        for query, distance in zip(memory_queries, memory_queries_distances):
            if distance < self.memory_access_threshold:
                # if (1 - distance) >= self.similarity_threshold:
                acceptable_memory_queries.append(query)

        if len(acceptable_memory_queries) == 0:
            # return self.llm.response(request)
            return None
        prompt_template = """\
By considering the input memories below from me, answer the question if it is provided in memory, else just answer without memory:
QUESTION:
`{text}`
MEMORY CHUNKS:
{context}
"""

        context = ""
        for i, query in enumerate(acceptable_memory_queries):  # only memories below the distance threshold
            context += f"MEMORY CHUNK {i}: {query}\n"

        queries = prompt_template.format(text=request, context=context)

        out = (
            self.llm
            + self.chat_prompt_template.format(prompt=queries)
            + " "
            + gen(name="response", temperature=1)
        )
        return out["response"]

    @logging(enable_logging, message="[Searching]")
    def search(self, request):
        search_response = self.search_engine.search(
            request, n_results=self.se_n_results
        )

        if self.use_summarizer:
            # Compress each scraped page before it enters the LLM context window.
            for response in search_response:
                response["content"] = self._summarize(response["content"])

        prompt_template = """\
You have been given access to the Internet.
By considering the search results below, summarize the information if it is provided in the search results, else just answer without the search results:
QUESTION:
`{text}`
SEARCH RESULTS:
{context}
"""

        context = ""
        for i, query in enumerate(search_response):
            context += f"SEARCH TITLE: {query['title']}\nSEARCH LINK: {query['link']}\nSEARCH CONTENT: {query['content']}\n"

        queries = prompt_template.format(text=request, context=context)
        out = (
            self.llm
            + self.chat_prompt_template.format(prompt=queries)
            + " "
            + gen(name="response", temperature=1)
        )
        return out["response"]

    @logging(enable_logging, message="[Summarizing]", color="green")
    def _summarize(self, text, min_length=30, max_length=100):
        return self.summarizer(text, min_length, max_length)

    @logging(enable_logging, message="[Extracting question]", color="green")
    def _extract_query(self, request):
        prompt_template = """\
Extract the question from the following text:
"{request}"
"""

        prompt = prompt_template.format(request=request)

        out = (
            self.llm
            + self.chat_prompt_template.format(prompt=prompt)
            + " "
            + f"""Extracted question: "{gen(name='question', temperature=1, stop='"')}"""
        )

        return out["question"]

    @logging(enable_logging, message="[Response]")
    def response(self, request):
        out = (
            self.llm
            + self.chat_prompt_template.format(prompt=request)
            + " "
            + gen(name="response", temperature=1)
        )
        return out["response"]

    def generate(self, request: str):
        choices = ["ANSWER", "WEB_SEARCH", "DB_SEARCH", "ADD_MEMORY"]

        prompt_template = """\
{request}
Please choose an option from below:
ANSWER - Answer the question
WEB_SEARCH - Search on the web
DB_SEARCH - Find information in the database
ADD_MEMORY - Add information to the database
Default option: ANSWER
"""

        prompt = prompt_template.format(request=request)

        out = (
            self.llm
            + self.chat_prompt_template.format(prompt=prompt)
            + " "
            + f"Choice: {select(choices, name='choice')}"
        )

        if out["choice"] == "ANSWER":
            return self.response(request)

        elif out["choice"] == "WEB_SEARCH":
            return self.search(self._extract_query(request))

        elif out["choice"] == "DB_SEARCH":
            return self.memory_response(self._extract_query(request))

        elif out["choice"] == "ADD_MEMORY":
            # Only text in double quotes is stored, e.g. ... "The user name is Rustam Akimov"
            matches = re.findall(r'"(.*?)"', request)

            for item in matches:
                self.add(item)

            return "Done!"
--------------------------------------------------------------------------------
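As a sketch of how `generate` routes requests (the option names and the quoted-string convention come from the class above; `agent` is assumed to be built as in `src/chat.py`):

```python
agent.generate('Please add information to db "The user name is Rustam Akimov"')  # ADD_MEMORY: stores the quoted text
agent.generate("Please find information in db who is Rustam Akimov")             # DB_SEARCH: memory_response()
agent.generate("Can you find on the Internet who is Pavel Durov")                # WEB_SEARCH: search()
agent.generate("Hi")                                                             # ANSWER: plain LLM reply
```

The choice is made by the model itself through guidance's `select`, so routing depends on phrasing and is not guaranteed to be deterministic.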
/requirements.txt:
--------------------------------------------------------------------------------
annotated-types==0.6.0 ; python_version >= "3.10" and python_version < "4.0"
anyio==4.3.0 ; python_version >= "3.10" and python_version < "4.0"
asgiref==3.7.2 ; python_version >= "3.10" and python_version < "4.0"
backoff==2.2.1 ; python_version >= "3.10" and python_version < "4.0"
bcrypt==4.1.2 ; python_version >= "3.10" and python_version < "4.0"
beautifulsoup4==4.12.3 ; python_version >= "3.10" and python_version < "4.0"
bs4==0.0.2 ; python_version >= "3.10" and python_version < "4.0"
build==1.1.1 ; python_version >= "3.10" and python_version < "4.0"
cachetools==5.3.3 ; python_version >= "3.10" and python_version < "4.0"
certifi==2024.2.2 ; python_version >= "3.10" and python_version < "4.0"
charset-normalizer==3.3.2 ; python_version >= "3.10" and python_version < "4.0"
chroma-hnswlib==0.7.3 ; python_version >= "3.10" and python_version < "4.0"
chromadb==0.4.24 ; python_version >= "3.10" and python_version < "4.0"
click==8.1.7 ; python_version >= "3.10" and python_version < "4.0"
colorama==0.4.6 ; python_version >= "3.10" and python_version < "4.0" and (os_name == "nt" or platform_system == "Windows" or sys_platform == "win32")
coloredlogs==15.0.1 ; python_version >= "3.10" and python_version < "4.0"
deprecated==1.2.14 ; python_version >= "3.10" and python_version < "4.0"
diskcache==5.6.3 ; python_version >= "3.10" and python_version < "4.0"
distro==1.9.0 ; python_version >= "3.10" and python_version < "4.0"
exceptiongroup==1.2.0 ; python_version >= "3.10" and python_version < "3.11"
fastapi==0.110.0 ; python_version >= "3.10" and python_version < "4.0"
filelock==3.13.1 ; python_version >= "3.10" and python_version < "4.0"
flatbuffers==23.5.26 ; python_version >= "3.10" and python_version < "4.0"
fsspec==2024.2.0 ; python_version >= "3.10" and python_version < "4.0"
google-auth==2.28.1 ; python_version >= "3.10" and python_version < "4.0"
googleapis-common-protos==1.62.0 ; python_version >= "3.10" and python_version < "4.0"
gpt4all==2.2.1.post1 ; python_version >= "3.10" and python_version < "4.0"
grpcio==1.62.0 ; python_version >= "3.10" and python_version < "4.0"
guidance==0.1.13 ; python_version >= "3.10" and python_version < "4.0"
h11==0.14.0 ; python_version >= "3.10" and python_version < "4.0"
httpcore==1.0.5 ; python_version >= "3.10" and python_version < "4.0"
httptools==0.6.1 ; python_version >= "3.10" and python_version < "4.0"
httpx==0.27.0 ; python_version >= "3.10" and python_version < "4.0"
huggingface-hub==0.21.3 ; python_version >= "3.10" and python_version < "4.0"
humanfriendly==10.0 ; python_version >= "3.10" and python_version < "4.0"
idna==3.6 ; python_version >= "3.10" and python_version < "4.0"
importlib-metadata==6.11.0 ; python_version >= "3.10" and python_version < "4.0"
importlib-resources==6.1.2 ; python_version >= "3.10" and python_version < "4.0"
jinja2==3.1.3 ; python_version >= "3.10" and python_version < "4.0"
kubernetes==29.0.0 ; python_version >= "3.10" and python_version < "4.0"
llama-cpp-python==0.2.53 ; python_version >= "3.10" and python_version < "4.0"
markupsafe==2.1.5 ; python_version >= "3.10" and python_version < "4.0"
mmh3==4.1.0 ; python_version >= "3.10" and python_version < "4.0"
monotonic==1.6 ; python_version >= "3.10" and python_version < "4.0"
mpmath==1.3.0 ; python_version >= "3.10" and python_version < "4.0"
networkx==3.2.1 ; python_version >= "3.10" and python_version < "4.0"
numpy==1.26.4 ; python_version >= "3.10" and python_version < "4.0"
nvidia-cublas-cu12==12.1.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.10" and python_version < "4.0"
nvidia-cuda-cupti-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.10" and python_version < "4.0"
nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.10" and python_version < "4.0"
nvidia-cuda-runtime-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.10" and python_version < "4.0"
nvidia-cudnn-cu12==8.9.2.26 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.10" and python_version < "4.0"
nvidia-cufft-cu12==11.0.2.54 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.10" and python_version < "4.0"
nvidia-curand-cu12==10.3.2.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.10" and python_version < "4.0"
nvidia-cusolver-cu12==11.4.5.107 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.10" and python_version < "4.0"
nvidia-cusparse-cu12==12.1.0.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.10" and python_version < "4.0"
nvidia-nccl-cu12==2.19.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.10" and python_version < "4.0"
nvidia-nvjitlink-cu12==12.3.101 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.10" and python_version < "4.0"
nvidia-nvtx-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.10" and python_version < "4.0"
oauthlib==3.2.2 ; python_version >= "3.10" and python_version < "4.0"
onnxruntime==1.17.1 ; python_version >= "3.10" and python_version < "4.0"
openai==1.16.1 ; python_version >= "3.10" and python_version < "4.0"
opentelemetry-api==1.23.0 ; python_version >= "3.10" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-common==1.23.0 ; python_version >= "3.10" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-grpc==1.23.0 ; python_version >= "3.10" and python_version < "4.0"
opentelemetry-instrumentation-asgi==0.44b0 ; python_version >= "3.10" and python_version < "4.0"
opentelemetry-instrumentation-fastapi==0.44b0 ; python_version >= "3.10" and python_version < "4.0"
opentelemetry-instrumentation==0.44b0 ; python_version >= "3.10" and python_version < "4.0"
opentelemetry-proto==1.23.0 ; python_version >= "3.10" and python_version < "4.0"
opentelemetry-sdk==1.23.0 ; python_version >= "3.10" and python_version < "4.0"
opentelemetry-semantic-conventions==0.44b0 ; python_version >= "3.10" and python_version < "4.0"
opentelemetry-util-http==0.44b0 ; python_version >= "3.10" and python_version < "4.0"
ordered-set==4.1.0 ; python_version >= "3.10" and python_version < "4.0"
orjson==3.9.15 ; python_version >= "3.10" and python_version < "4.0"
overrides==7.7.0 ; python_version >= "3.10" and python_version < "4.0"
packaging==23.2 ; python_version >= "3.10" and python_version < "4.0"
platformdirs==4.2.0 ; python_version >= "3.10" and python_version < "4.0"
posthog==3.4.2 ; python_version >= "3.10" and python_version < "4.0"
protobuf==4.25.3 ; python_version >= "3.10" and python_version < "4.0"
pulsar-client==3.4.0 ; python_version >= "3.10" and python_version < "4.0"
pyasn1-modules==0.3.0 ; python_version >= "3.10" and python_version < "4.0"
pyasn1==0.5.1 ; python_version >= "3.10" and python_version < "4.0"
pydantic-core==2.16.3 ; python_version >= "3.10" and python_version < "4.0"
pydantic==2.6.3 ; python_version >= "3.10" and python_version < "4.0"
pydot==2.0.0 ; python_version >= "3.10" and python_version < "4.0"
pyformlang==1.0.9 ; python_version >= "3.10" and python_version < "4.0"
pyparsing==3.1.2 ; python_version >= "3.10" and python_version < "4.0"
pypika==0.48.9 ; python_version >= "3.10" and python_version < "4.0"
pyproject-hooks==1.0.0 ; python_version >= "3.10" and python_version < "4.0"
pyreadline3==3.4.1 ; sys_platform == "win32" and python_version >= "3.10" and python_version < "4.0"
python-dateutil==2.8.2 ; python_version >= "3.10" and python_version < "4.0"
python-dotenv==1.0.1 ; python_version >= "3.10" and python_version < "4.0"
pyyaml==6.0.1 ; python_version >= "3.10" and python_version < "4.0"
regex==2023.12.25 ; python_version >= "3.10" and python_version < "4.0"
requests-oauthlib==1.3.1 ; python_version >= "3.10" and python_version < "4.0"
requests==2.31.0 ; python_version >= "3.10" and python_version < "4.0"
rsa==4.9 ; python_version >= "3.10" and python_version < "4"
safetensors==0.4.2 ; python_version >= "3.10" and python_version < "4.0"
setuptools==69.1.1 ; python_version >= "3.10" and python_version < "4.0"
six==1.16.0 ; python_version >= "3.10" and python_version < "4.0"
sniffio==1.3.1 ; python_version >= "3.10" and python_version < "4.0"
soupsieve==2.5 ; python_version >= "3.10" and python_version < "4.0"
starlette==0.36.3 ; python_version >= "3.10" and python_version < "4.0"
sympy==1.12 ; python_version >= "3.10" and python_version < "4.0"
tenacity==8.2.3 ; python_version >= "3.10" and python_version < "4.0"
termcolor==2.4.0 ; python_version >= "3.10" and python_version < "4.0"
tiktoken==0.6.0 ; python_version >= "3.10" and python_version < "4.0"
tokenizers==0.15.2 ; python_version >= "3.10" and python_version < "4.0"
tomli==2.0.1 ; python_version >= "3.10" and python_version < "3.11"
torch==2.2.1 ; python_version >= "3.10" and python_version < "4.0"
tqdm==4.66.2 ; python_version >= "3.10" and python_version < "4.0"
transformers==4.38.1 ; python_version >= "3.10" and python_version < "4.0"
triton==2.2.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.12" and python_version >= "3.10"
typer==0.9.0 ; python_version >= "3.10" and python_version < "4.0"
typing-extensions==4.10.0 ; python_version >= "3.10" and python_version < "4.0"
urllib3==2.2.1 ; python_version >= "3.10" and python_version < "4.0"
uvicorn==0.27.1 ; python_version >= "3.10" and python_version < "4.0"
uvicorn[standard]==0.27.1 ; python_version >= "3.10" and python_version < "4.0"
uvloop==0.19.0 ; (sys_platform != "win32" and sys_platform != "cygwin") and platform_python_implementation != "PyPy" and python_version >= "3.10" and python_version < "4.0"
watchfiles==0.21.0 ; python_version >= "3.10" and python_version < "4.0"
websocket-client==1.7.0 ; python_version >= "3.10" and python_version < "4.0"
websockets==12.0 ; python_version >= "3.10" and python_version < "4.0"
wrapt==1.16.0 ; python_version >= "3.10" and python_version < "4.0"
zipp==3.17.0 ; python_version >= "3.10" and python_version < "4.0"
--------------------------------------------------------------------------------