├── hodja ├── search │ ├── embeddings │ │ ├── __init__.py │ │ ├── base.py │ │ └── openai.py │ ├── __init__.py │ ├── documents.py │ └── docstores.py ├── chains │ ├── __init__.py │ └── base.py ├── __init__.py ├── setup.cfg ├── agents │ ├── __init__.py │ ├── base.py │ └── openai.py ├── tools │ ├── __init__.py │ ├── base.py │ ├── search_tools.py │ └── math.py ├── tests │ ├── documents_test.py │ ├── search_tools_test.py │ └── docstores_test.py └── links │ ├── base.py │ └── react.py ├── static ├── nasreddin_hodja_chain.png └── nasreddin_hodja_chain_2.png ├── setup.py ├── LICENSE ├── .gitignore ├── README.md └── examples.ipynb /hodja/search/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hodja/chains/__init__.py: -------------------------------------------------------------------------------- 1 | from hodja.chains.base import Chain -------------------------------------------------------------------------------- /hodja/__init__.py: -------------------------------------------------------------------------------- 1 | from hodja.chains import Chain 2 | from hodja.links.base import Link -------------------------------------------------------------------------------- /static/nasreddin_hodja_chain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdagdelen/hodja/HEAD/static/nasreddin_hodja_chain.png -------------------------------------------------------------------------------- /static/nasreddin_hodja_chain_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jdagdelen/hodja/HEAD/static/nasreddin_hodja_chain_2.png -------------------------------------------------------------------------------- /hodja/search/__init__.py: -------------------------------------------------------------------------------- 1 | from hodja.search.documents import Document 2 | from hodja.search.docstores import FAISS, VectorStore -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup script for the package.""" 2 | from setuptools import setup 3 | 4 | if __name__ == "__main__": 5 | setup() -------------------------------------------------------------------------------- /hodja/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = hodja 3 | author = John Dagdelen 4 | author_email = jdagdelen@berkeley.edu 5 | version = 0.1 6 | [options] 7 | install_requires = 8 | openai -------------------------------------------------------------------------------- /hodja/search/embeddings/base.py: -------------------------------------------------------------------------------- 1 | """Interface for embedding models.""" 2 | from abc import ABC, abstractmethod 3 | import numpy as np 4 | 5 | class Embeddings(ABC): 6 | """Interface for embeddings.""" 7 | 8 | @abstractmethod 9 | def embed(self, documents, **kwargs): 10 | """Embed documents.""" 11 | raise NotImplementedError -------------------------------------------------------------------------------- /hodja/agents/__init__.py: -------------------------------------------------------------------------------- 1 | """Agents read prompts, make decisions, and then return output based on what they saw. 2 | 3 | Agents are the heard of the Hodja framework. Agents sit inside Links and process the input state, potentially making use of Tools, to create the output state. Agents can be prompt-engineered LLMs, fine-tuned models, or even simple rules-based systems.""" -------------------------------------------------------------------------------- /hodja/tools/__init__.py: -------------------------------------------------------------------------------- 1 | """Tools that can be used in Links to perform useful tasks. 2 | 3 | Tools provide Agents inside of Links with specialized functionality that the Agent may no 4 | have. For example, an Agent may need to perform a search over a database of documents, or 5 | perform a math operation. Tools are the way to give Agents this functionality.""" 6 | 7 | from hodja.tools.base import Tool -------------------------------------------------------------------------------- /hodja/agents/base.py: -------------------------------------------------------------------------------- 1 | """Base class for agents.""" 2 | 3 | from abc import ABC 4 | 5 | class Agent(ABC): 6 | """Base class for agents.""" 7 | 8 | def __init__(self, name): 9 | self.name = name 10 | 11 | def __call__(self, prompt, state, tools): 12 | """Run the agent.""" 13 | raise NotImplementedError 14 | 15 | def prepare_prompt_string(self, prompt, state, tools): 16 | """Formats the prompts with information from state""" 17 | return prompt.format(**state, tools=tools) -------------------------------------------------------------------------------- /hodja/tools/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class Tool(ABC): 4 | def __init__(self, name, description, instructions): 5 | self.name = name 6 | self.description = description 7 | self.instructions = instructions 8 | 9 | @abstractmethod 10 | def run(self, *args, **kwargs): 11 | raise NotImplementedError 12 | 13 | def __repr__(self): 14 | return f"{self.name}: {self.description}" 15 | 16 | def __str__(self): 17 | return self.__repr__() -------------------------------------------------------------------------------- /hodja/search/documents.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class DocumentBase(ABC): 4 | """A basic document, which is just a string.""" 5 | 6 | def __init__(self, text): 7 | self.text = text 8 | 9 | def __str__(self): 10 | return f"Document({self.text})" 11 | 12 | def __repr__(self): 13 | return self.text 14 | 15 | 16 | 17 | class Document(DocumentBase): 18 | 19 | def __init__(self, text, **kwargs): 20 | """A document.""" 21 | super().__init__(text) 22 | self.__dict__.update(kwargs) 23 | 24 | def __str__(self): 25 | return f"Document({self.__dict__})" 26 | 27 | def __repr__(self): 28 | return self.__str__() 29 | 30 | def __eq__(self, other): 31 | return self.__dict__ == other.__dict__ 32 | -------------------------------------------------------------------------------- /hodja/tests/documents_test.py: -------------------------------------------------------------------------------- 1 | """Unit test for documents.py""" 2 | 3 | import unittest 4 | from hodja.search import documents 5 | 6 | class TestDocument(unittest.TestCase): 7 | 8 | def test_init(self): 9 | """Test the init method.""" 10 | document = documents.Document(text="test", id=1) 11 | self.assertEqual(document.text, "test") 12 | self.assertEqual(document.id, 1) 13 | 14 | def test_eq(self): 15 | """Test the eq method.""" 16 | document1 = documents.Document(text="test1", id=1) 17 | document2 = documents.Document(text="test2", id=2) 18 | document3 = documents.Document(text="test1", id=1) 19 | self.assertEqual(document1, document3) 20 | self.assertNotEqual(document1, document2) 21 | 22 | 23 | if __name__ == "__main__": 24 | unittest.main() -------------------------------------------------------------------------------- /hodja/chains/base.py: -------------------------------------------------------------------------------- 1 | """Chains are a collection of Links that are executed in order, with state passing from link to link. 2 | 3 | Chains are the main components of Hodja. They are a collection of Links that are executed in order, with state passing from link to link. Chains are the main way to create a Hodja workflow. The user's input goes into the first Link in the Chain, and the output of the last Link in the Chain is the final output that returns to the user.""" 4 | 5 | from abc import ABC, abstractmethod 6 | 7 | 8 | class Chain(ABC): 9 | def __init__(self, name, links=[], intial_state={}): 10 | self.name = name 11 | self.state = intial_state 12 | self.links = links 13 | 14 | def run(self, input, debug=False): 15 | """Run the chain by executing each link in order.""" 16 | self.state["input"] = input 17 | for link in self.links: 18 | self.state = link.run(self.state, debug=debug) 19 | return self.state['output'] 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 John Dagdelen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /hodja/tools/search_tools.py: -------------------------------------------------------------------------------- 1 | from hodja.search.docstores import FAISS 2 | from hodja.tools.base import Tool 3 | 4 | 5 | class SearchTool(Tool): 6 | """Tool for wrapping a DocStore so an Agent can search for relevant documents.""" 7 | def __init__( 8 | self, 9 | docstore, 10 | name="Search", 11 | description="Search for documents based on a query. Returns a list of documents that best match the query.", 12 | instructions="Provide query as text."): 13 | super().__init__(name, description, instructions) 14 | self.docstore = docstore 15 | 16 | def run(self, query, top_k=3): 17 | """Search for documents similar to a query. 18 | 19 | Args: 20 | query: Query to search for. 21 | top_k: Number of top documents to return. 22 | 23 | Returns: 24 | List of top documents. 25 | """ 26 | results = self.docstore.search(query, min(top_k, len(self.docstore))) 27 | return results 28 | 29 | def add_docs(self, docs): 30 | """Add documents to the docstore.""" 31 | self.docstore.add(docs) 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /hodja/links/base.py: -------------------------------------------------------------------------------- 1 | """Links are the main components of Chains. 2 | 3 | Links take in input state, do work, and then pass output state to the next Link in the Chain. Links are usually composed of an Agent and zero or more Tools. A Link is responsible for processing the input state and creating the output state. Agents are intelligent agents that help do this task. The Tools are used by the Agent to perform the necessary tasks to create the ouput state. Links can enforce conditions on the input state they recieve. For example, a Link may require that there is a "date": exists in the state. Links can choose to ignore parts of the chain state that they don't need.""" 4 | 5 | from abc import ABC, abstractmethod 6 | 7 | class Link(ABC): 8 | def __init__(self, name): 9 | self.name = name 10 | 11 | @abstractmethod 12 | def validate_state(self, state): 13 | """Validate the input state. Raise an exception if the input state is invalid for this Link.""" 14 | raise NotImplementedError 15 | 16 | @abstractmethod 17 | def run(self, input_state, **kwargs): 18 | """Run the link. Return the output state.""" 19 | raise NotImplementedError -------------------------------------------------------------------------------- /hodja/tests/search_tools_test.py: -------------------------------------------------------------------------------- 1 | """Unit tests for search_tools.py""" 2 | 3 | import unittest 4 | from hodja.search.docstores import FAISS 5 | from hodja.tools.search_tools import SearchTool 6 | from hodja.tests.docstores_test import DummyEmbeddings 7 | from hodja.search.documents import Document 8 | 9 | class TestSearchTool(unittest.TestCase): 10 | 11 | def test_init(self): 12 | """Test the init method.""" 13 | docstore = FAISS(DummyEmbeddings()) 14 | search_tool = SearchTool(docstore) 15 | self.assertEqual(search_tool.docstore, docstore) 16 | 17 | def test_run(self): 18 | """Test the run method.""" 19 | docstore = FAISS(DummyEmbeddings()) 20 | search_tool = SearchTool(docstore) 21 | document = Document(text="test", id=1) 22 | search_tool.add_docs([document]) 23 | self.assertEqual(search_tool.run("test"), [document]) 24 | 25 | def test_add_docs(self): 26 | """Test the add_docs method.""" 27 | docstore = FAISS(DummyEmbeddings()) 28 | search_tool = SearchTool(docstore) 29 | document = Document(text="test", id=1) 30 | search_tool.add_docs([document]) 31 | self.assertEqual(search_tool.docstore.documents, [document]) 32 | 33 | 34 | if __name__ == "__main__": 35 | unittest.main() -------------------------------------------------------------------------------- /hodja/tools/math.py: -------------------------------------------------------------------------------- 1 | """Tools for math.""" 2 | 3 | from hodja.tools.base import Tool 4 | 5 | class MathTool(Tool): 6 | """Example tool that evaluates simple math expressions.""" 7 | 8 | def __init__(self, name="Math", description="Evaluate a math expression. Note: only numbers and basic operators (e.g. +-*/) are allowed.", instructions="Enter a math expression."): 9 | super().__init__(name=name, description=description, instructions=instructions) 10 | 11 | def run(self, input): 12 | # check if input is a math expression (e.g. numbers and math operators only) 13 | # if not, raise an error 14 | allowed_characters = "0123456789+-*/.," 15 | for character in input: 16 | if character not in allowed_characters: 17 | raise ValueError("Invalid input. Only numbers, decimals (.), commas (,), and math operators (+-*/) are allowed.") 18 | # evaluate the math expression 19 | return eval(input) 20 | 21 | 22 | class FibonacciTool(Tool): 23 | """Example tool that returns the nth Fibonacci number.""" 24 | 25 | def __init__(self, name="Fibonacci", description="Returns the nth Fibonacci number", instructions="Enter a number n."): 26 | super().__init__(name, description, instructions) 27 | 28 | def run(self, n): 29 | # check if input is a number 30 | # if not, raise an error 31 | try: 32 | n = int(n) 33 | except: 34 | raise ValueError("Invalid input. Only numbers are allowed.") 35 | 36 | def _fib(n): 37 | if n <= 1: 38 | return n 39 | else: 40 | return _fib(n-1) + _fib(n-2) 41 | 42 | return _fib(n+1) -------------------------------------------------------------------------------- /hodja/search/embeddings/openai.py: -------------------------------------------------------------------------------- 1 | """Wrapper around OpenAI embedding models.""" 2 | import os 3 | import numpy as np 4 | from hodja.search.embeddings.base import Embeddings 5 | import openai 6 | 7 | class OpenAIEmbeddings(Embeddings): 8 | """Wrapper around OpenAI embedding models.""" 9 | 10 | def __init__(self, model_name="text-embedding-ada-002", openai_api_key=None): 11 | """Initialize OpenAIEmbeddings. 12 | 13 | Args: 14 | model_name: The name of the model to use. The default is 15 | ``text-embedding-ada-002``. See 16 | https://beta.openai.com/docs/engines for a list of available 17 | models. 18 | openai_api_key: The API key to use. If not provided, will look for 19 | the environment variable ``OPENAI_API_KEY``. 20 | """ 21 | self.model_name = model_name 22 | if openai_api_key is None: 23 | openai_api_key = os.environ.get("OPENAI_API_KEY") 24 | self.openai_api_key = openai_api_key 25 | self.model_name = model_name 26 | self.client = openai.Embedding 27 | 28 | def embed(self, texts, batch_size=1000): 29 | """Call out to OpenAI's embedding endpoint for embedding search docs. 30 | 31 | Args: 32 | texts (List[str]): The texts to embed. 33 | batch_size (int): The maximum number of documents to send to OpenAI at once. 34 | 35 | Returns: 36 | List of embeddings, one for each document. 37 | """ 38 | results = [] 39 | for i in range(0, len(texts), batch_size): 40 | response = self.client.create( 41 | input=texts[i : i + batch_size], engine=self.model_name 42 | ) 43 | results += [r["embedding"] for r in response["data"]] 44 | return results 45 | -------------------------------------------------------------------------------- /hodja/agents/openai.py: -------------------------------------------------------------------------------- 1 | """Classes for interacting with OpenAI's API.""" 2 | import os 3 | from hodja.agents.base import Agent 4 | import openai 5 | from collections import defaultdict 6 | 7 | # create default dict of context sizes 8 | CONTEXT_SIZES = defaultdict(lambda: 2048) 9 | _defaults = { 10 | "text-davinci-003": 4097, 11 | "text-curie-001": 2048, 12 | "text-babbage-001": 2048, 13 | "text-ada-001": 2048, 14 | "code-davinci-002": 8000, 15 | "code-cushman-001": 2048, 16 | "gpt-3.5-turbo": 4097, 17 | } 18 | for k, v in _defaults.items(): 19 | CONTEXT_SIZES[k] = v 20 | 21 | class OpenAIAPIAgent(Agent): 22 | 23 | def __init__( 24 | self, 25 | name='OpenAI Agent', 26 | api_key=os.getenv('OPENAI_API_KEY'), 27 | engine='text-davinci-003', 28 | temperature=0.0, 29 | max_tokens=500, 30 | top_p=1, 31 | frequency_penalty=0, 32 | presence_penalty=0, 33 | stop=['\n\n',], 34 | streaming=False 35 | ): 36 | super().__init__(name=name) 37 | self.api_key = api_key 38 | self.engine = engine 39 | self.context_size = CONTEXT_SIZES[self.engine] 40 | self.temperature = temperature 41 | self.max_tokens = max_tokens 42 | self.top_p = top_p 43 | self.frequency_penalty = frequency_penalty 44 | self.presence_penalty = presence_penalty 45 | self.stop = stop 46 | self.streaming = streaming 47 | 48 | def __call__(self, prompt): 49 | if len(prompt) + self.max_tokens > self.context_size: 50 | raise ValueError( 51 | f"Prompt length ({len(prompt)}) + max_tokens ({self.max_tokens}) " 52 | f"exceeds maximum context size ({self.context_size})." 53 | ) 54 | results = openai.Completion.create( 55 | prompt=prompt, 56 | engine=self.engine, 57 | temperature=self.temperature, 58 | max_tokens=self.max_tokens, 59 | top_p=self.top_p, 60 | frequency_penalty=self.frequency_penalty, 61 | presence_penalty=self.presence_penalty, 62 | stop=self.stop, 63 | stream=self.streaming 64 | ) 65 | return results.choices[0].text -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Nasreddin Hodja inspects a chain 3 |

4 | 5 | # Hodja 6 | Hodja is a framework for building [augmented language model](https://arxiv.org/abs/2302.07842)-based applications. 7 | 8 | ## About 9 | Hodja is an attempt at creating a streamlined and focused framework for building chains of transformations that use natural language as their interface. I found other projects were hard to use and debug in practice, so I decided to create my own. 10 | 11 | The main concepts in Hodja are: 12 | 13 | ### Chains 14 | A Chain is made up of a sequence of Links. Chains take in an intial input and pass it through each link in the Chain. The output of each Link is passed to the next subsequent Link in the Chain. The output of the final Link in the Chain is the output of the Chain. Information gets passed through the Chain in the form of a state. 15 | 16 | ### Links 17 | A Link is responsible for taking in a state, making decisions about what to do, transforming the state based on those decisions, and then returning the new state. Links are the building blocks of Chains and are the main way to extend Hodja's functionality. Links contain Agents that are reponsible for making decisions about what to do. Complex functionality is built by combining Links together in Chains. 18 | 19 | ### Agents 20 | An Agent lives inside a Link and is responsible for making decisions about what to do with the state. Agents can be simple, such as a function that uses regular expressions to extract information from a string, but they are usually more complex and based on large language models (LLMs). Agents need instructions to know what to do with the input state. We give Agents instructions in the form of **prompt strings**, which are just strings that are filled in with instructions and information from the state. Hodja contains a Prompt class for dynamically constructing prompt strings. Agents can also use Tools that augment them with new capabilities. To create a new Agent, we combine a LLM with a custom Prompt and one or more Tools. 21 | 22 | ### Prompts 23 | Prompts are objects that dynamically construct prompt strings. Prompts are read by Agents to determine what to do with the state. Prompts are constructed using a simple templating language that allows you to insert information from the state into the prompt string. Prompts can also dynamically decide what to put in the prompt string based on the state, for example providing more or less context to the Agent. 24 | 25 | ### Tools 26 | A Tool is a object that can be used by an Agent to augment its capabilities, for example doing simple math expressions or performing a Google search. Tools are used by Agents to perform tasks that they are not able to do by themselves. Tools expose their functionality through a specific interface that allows any Agent to interact with them, and they need to come with "usage instructions" so that Agents will know how to use them at runtime. Almost anything can be wrapped into a Tool. The most common types of Tools expose an external API to Agents (WolframAlphaTool, etc.), give them additional functionality like math or regex operations (MathTool, RegexTool, etc.) or give Agents a way to work with external document stores (e.g. a SearchTool wrapped around a VectorStore.) Entire Chains can be wrapped into Tools, allowing Agents to use the functionality of other Chains. 27 | 28 | 29 | ## Installation 30 | Clone this repository and in the root directory run: 31 | ```python setup.py install``` 32 | 33 | ## Name 34 | The name "Hodja" comes from "[Nasreddin Hodja](https://en.wikipedia.org/wiki/Nasreddin_)" who was a philosopher from what is now modern-day Turkey. His is known for his wit and wisdom, and is the main character of many humerous folk stories. 35 | 36 | -------------------------------------------------------------------------------- /hodja/tests/docstores_test.py: -------------------------------------------------------------------------------- 1 | """Unit tests for the docstores module.""" 2 | 3 | import unittest 4 | 5 | from hodja.search import docstores 6 | from hodja.search.documents import Document 7 | 8 | class TestDocstore(unittest.TestCase): 9 | """Unit tests for the DocStore class.""" 10 | 11 | def test_add(self): 12 | """Test the add method.""" 13 | store = docstores.DocStore() 14 | document = Document(text="test", id=1) 15 | store.add([document]) 16 | self.assertEqual(store.documents, {hash(document.id): document}) 17 | 18 | def test_remove(self): 19 | """Test the remove method.""" 20 | store = docstores.DocStore() 21 | document1 = Document(text="test1", id=1) 22 | document2 = Document(text="test2", id=2) 23 | documents = [document1, document2] 24 | store.add(documents) 25 | store.remove([document1.id]) 26 | self.assertEqual(store.documents, {document2.id: document2}) 27 | 28 | def test_get(self): 29 | """Test the get method.""" 30 | store = docstores.DocStore() 31 | document1 = Document(text="test1", id=1) 32 | document2 = Document(text="test2", id=2) 33 | documents = [document1, document2] 34 | store.add(documents) 35 | self.assertEqual(store.get(document1.id), document1) 36 | 37 | def cond(text): 38 | if "3" in text: 39 | return [0, 0, 0] 40 | else: 41 | return [1, 1, 1] 42 | 43 | def dummy_embedding_function(text): 44 | if not isinstance(text, list): 45 | text = [text] 46 | embeddings = [cond(t) for t in text] 47 | return embeddings 48 | 49 | class DummyEmbeddings: 50 | def __init__(self): 51 | self.embedding_function = dummy_embedding_function 52 | 53 | def embed(self, docs): 54 | return dummy_embedding_function(docs) 55 | 56 | class TestFAISS(unittest.TestCase): 57 | 58 | def test_add(self): 59 | """Test the add method.""" 60 | store = docstores.FAISS(DummyEmbeddings()) 61 | document = Document(text="test", id=1) 62 | store.add([document]) 63 | self.assertEqual(store.documents, [document]) 64 | 65 | def test_remove(self): 66 | """Test the remove method.""" 67 | store = docstores.FAISS(DummyEmbeddings()) 68 | document1 = Document(text="test1", id=1) 69 | document2 = Document(text="test2", id=2) 70 | documents = [document1, document2] 71 | store.add(documents) 72 | store.remove([document1.id]) 73 | self.assertEqual(store.documents, [document2]) 74 | 75 | def test_get(self): 76 | """Test the get method.""" 77 | store = docstores.FAISS(DummyEmbeddings()) 78 | document1 = Document(text="test1", id=1) 79 | document2 = Document(text="test2", id=2) 80 | documents = [document1, document2] 81 | store.add(documents) 82 | self.assertEqual(store.get(document1.id), document1) 83 | 84 | def test_get_all(self): 85 | """Test the get_all method.""" 86 | store = docstores.FAISS(DummyEmbeddings()) 87 | document1 = Document(text="test1", id=1) 88 | document2 = Document(text="test2", id=2) 89 | documents = [document1, document2] 90 | store.add(documents) 91 | self.assertEqual(store.get_all(), [document1, document2]) 92 | 93 | def test_search(self): 94 | """Test the search method.""" 95 | store = docstores.FAISS(DummyEmbeddings()) 96 | document1 = Document(text="test1", id=1) 97 | document2 = Document(text="test2", id=2) 98 | document3 = Document(text="test3", id=3) 99 | documents = [document1, document2, document3] 100 | store.add(documents) 101 | self.assertEqual(store.search(document3.text, 1), [document3]) 102 | 103 | if __name__ == '__main__': 104 | unittest.main() -------------------------------------------------------------------------------- /hodja/links/react.py: -------------------------------------------------------------------------------- 1 | """Link implementing ReACT (https://arxiv.org/abs/2210.03629)""" 2 | from hodja.chains.base import Chain 3 | from hodja.links.base import Link 4 | from hodja.agents.openai import OpenAIAPIAgent 5 | 6 | REACT_PROMPT = """You are a polite, thoughtful, and resourceful general purpose AI. 7 | 8 | Tools - You can use the following tools. The tool name is listed first and a description is listed after the colon. 9 | {tool_summary} 10 | 11 | To employ a tool, use following format: 12 | Action: [] 13 | 14 | Note: You MUST use the tool's name (as provided above) in order to use a tool. 15 | 16 | Instructions 17 | ------------ 18 | You work in a thought-action-observation cycle. In each cycle, you: 19 | 1. Think about the problem and how to solve it. 20 | 2. Perform an action. You can use one of the tools or return a final response with `RETURN[]. You can only use one tool at a time. 21 | 3. Observe the result of your action. This will be the output from any tool you used, or the final answer if you returned a final answer. 22 | 4. Repeat steps 1-3 until you have a final answer in the observation. 23 | 24 | Use the following format: 25 | 26 | Thought: 27 | Action: [] (or RETURN[]) 28 | Observation: 29 | Thought: 30 | Action: [] (or RETURN[]) 31 | Observation: 32 | ... 33 | 34 | Example 1 (no tool needed): 35 | ``` 36 | User Input: What is captial of France? 37 | Tools Available: () 38 | Thought: I know that Paris is the capital of France. 39 | Action: RETURN[Paris] 40 | Observation: Paris 41 | ``` 42 | 43 | Example 2 (tool needed): 44 | In this example, assume that a Weather tool is available. 45 | ``` 46 | User Input: What is the weather in Paris? 47 | Tools Available: (Weather) 48 | Thought: I need to use the Weather tool to get the weather in Paris. 49 | Action: Weather[Paris] 50 | Observation: ('city': 'Paris', 'country': 'France', 'temperature': 20, 'weather': 'sunny') 51 | Thought: I now know that it is 20 degrees and sunny in Paris. 52 | Action: RETURN[It is 20 degrees and sunny in Paris] 53 | Observation: It is 20 degrees and sunny in Paris. 54 | ``` 55 | Note: the above is just an example. The Weather tool may not actually be available. Only use tools listed in the Tools Avaliable section. 56 | 57 | Extra notes: 58 | If a math or code tool is available, do not do any math yourself. Use Tools to evaluate math expressions. 59 | 60 | Begin! 61 | 62 | {workspace}""" 63 | 64 | class ReACTLink(Link): 65 | """ A Link that uses the ReACT algorithm to generate reasoning traces and task-specific actions in an interleaved manner. 66 | 67 | Each step involves: 68 | 1. Thought 69 | 2. Action 70 | 3. Observation 71 | 72 | """ 73 | def __init__(self, agent=OpenAIAPIAgent(stop=['\n'], max_tokens=250), prompt=REACT_PROMPT, tools=[]): 74 | super().__init__(name="ReACTLink") 75 | self.agent = agent 76 | self.prompt = prompt 77 | self.tools = tools 78 | self.tool_names = str([tool.name for tool in self.tools]).replace("[", "(").replace("]", ")") 79 | self.tool_summary = "\n".join(["* " + str(tool) for tool in self.tools]) 80 | 81 | def _parse_action(self, action): 82 | """Parse the output of the agent into a state dictionary.""" 83 | # check if we have a final answer 84 | if "RETURN" in action: 85 | return {"final_answer": action.split("RETURN[")[1].split("]")[0]} 86 | else: 87 | # parse tool output 88 | tool_name = action.split("[")[0].strip() 89 | tool_input = action.split("[")[1].split("]")[0].strip() 90 | return {tool_name: tool_input} 91 | 92 | def validate_state(self, state): 93 | return True 94 | 95 | def _format_prompt(self, workspace): 96 | """Update the prompt with configuration and the current state of the workspace.""" 97 | return self.prompt.format( 98 | tool_summary=self.tool_summary, 99 | workspace=workspace 100 | ) 101 | 102 | def run(self, state, max_calls=5, debug=False): 103 | """Run the ReACT loop until a final answer is found or max_calls is reached.""" 104 | 105 | # get input from state 106 | input = state["input"] 107 | workspace = f"User Input: {input}\nTools Available: {self.tool_names}" 108 | TAOs = [] 109 | 110 | if debug: 111 | print(self._format_prompt(workspace)) 112 | 113 | # main loop 114 | terminate = False 115 | calls = 0 116 | while not terminate and calls < max_calls: 117 | # update prompt for next step 118 | TAO = {} 119 | # Think 120 | agent_input = self._format_prompt(workspace) + "Thought:" 121 | thought = self.agent(agent_input) 122 | TAO['thought'] = thought 123 | workspace += f"Thought: {thought}\n" 124 | if debug: 125 | print(f"Thought: {thought}") 126 | 127 | # Act 128 | agent_input = self._format_prompt(workspace) + "Action:" 129 | action = self.agent(agent_input) 130 | TAO['action'] = action 131 | workspace += f"Action: {action}\n" 132 | if debug: 133 | print(f"Action: {action}") 134 | 135 | # Observe 136 | parsed_action = self._parse_action(action) 137 | if "final_answer" in parsed_action: 138 | observation = parsed_action["final_answer"] 139 | TAO['observation'] = observation 140 | workspace += f"Observation: {observation}\n" 141 | if debug: 142 | print(f"Observation: {observation}") 143 | state["output"] = parsed_action["final_answer"] 144 | terminate = True 145 | else: 146 | # check if we need to run a tool and run it 147 | for tool in self.tools: 148 | if tool.name in parsed_action: 149 | observation = tool.run(parsed_action[tool.name]) 150 | TAO['observation'] = observation 151 | workspace += f"Observation: {observation}\n" 152 | if debug: 153 | print(f"Observation: {observation}") 154 | break 155 | TAOs.append(TAO) 156 | calls += 1 157 | 158 | return state -------------------------------------------------------------------------------- /examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "You are a polite, thoughtful, and resourceful general purpose AI. \n", 13 | "\n", 14 | "Tools - You can use the following tools. The tool name is listed first and a description is listed after the colon.\n", 15 | "* Search: Search for documents based on a query. Returns a list of documents that best match the query.\n", 16 | "\n", 17 | "To employ a tool, use following format:\n", 18 | " Action: []\n", 19 | "\n", 20 | "Note: You MUST use the tool's name (as provided above) in order to use a tool.\n", 21 | " \n", 22 | "Instructions\n", 23 | "------------\n", 24 | "You work in a thought-action-observation cycle. In each cycle, you:\n", 25 | " 1. Think about the problem and how to solve it.\n", 26 | " 2. Perform an action. You can use one of the tools or return a final response with `RETURN[]. You can only use one tool at a time.\n", 27 | " 3. Observe the result of your action. This will be the output from any tool you used, or the final answer if you returned a final answer.\n", 28 | " 4. Repeat steps 1-3 until you have a final answer in the observation.\n", 29 | "\n", 30 | "Use the following format:\n", 31 | "\n", 32 | " Thought: \n", 33 | " Action: [] (or RETURN[])\n", 34 | " Observation: \n", 35 | " Thought: \n", 36 | " Action: [] (or RETURN[])\n", 37 | " Observation: \n", 38 | " ...\n", 39 | "\n", 40 | "Example 1 (no tool needed):\n", 41 | " ```\n", 42 | " User Input: What is captial of France?\n", 43 | " Tools Available: ()\n", 44 | " Thought: I know that Paris is the capital of France.\n", 45 | " Action: RETURN[Paris]\n", 46 | " Observation: Paris\n", 47 | " ```\n", 48 | "\n", 49 | "Example 2 (tool needed):\n", 50 | " In this example, assume that a Weather tool is available.\n", 51 | " ```\n", 52 | " User Input: What is the weather in Paris?\n", 53 | " Tools Available: (Weather)\n", 54 | " Thought: I need to use the Weather tool to get the weather in Paris.\n", 55 | " Action: Weather[Paris]\n", 56 | " Observation: ('city': 'Paris', 'country': 'France', 'temperature': 20, 'weather': 'sunny')\n", 57 | " Thought: I now know that it is 20 degrees and sunny in Paris.\n", 58 | " Action: RETURN[It is 20 degrees and sunny in Paris]\n", 59 | " Observation: It is 20 degrees and sunny in Paris.\n", 60 | " ```\n", 61 | " Note: the above is just an example. The Weather tool may not actually be available. Only use tools listed in the Tools Avaliable section.\n", 62 | "\n", 63 | "Extra notes:\n", 64 | "If a math or code tool is available, do not do any math yourself. Use Tools to evaluate math expressions.\n", 65 | "\n", 66 | "Begin!\n", 67 | "\n", 68 | "User Input: Sumamrize entries on Nasreddin Hodja.\n", 69 | "Tools Available: ('Search')\n", 70 | "Thought: I need to use the Search tool to find documents related to Nasreddin Hodja.\n", 71 | "Action: Search[Nasreddin Hodja]\n", 72 | "Observation: [Document({'text': 'Nasreddin hodja once lived in what is modern day Turkey.', 'id': 1}), Document({'text': 'Nasreddin hodja may have lived in the 13th century, but he is now the subject of many (frequently humerous) stories in Turkish folklore.', 'id': 8}), Document({'text': 'Nasreddin hodja was a very wise old man whom many people counted on for advice.', 'id': 3})]\n", 73 | "Thought: I now have a list of documents related to Nasreddin Hodja.\n", 74 | "Action: RETURN[Nasreddin Hodja was a wise old man who lived in what is now modern day Turkey. He is the subject of many stories in Turkish folklore and was often sought out for advice.]\n", 75 | "Observation: Nasreddin Hodja was a wise old man who lived in what is now modern day Turkey. He is the subject of many stories in Turkish folklore and was often sought out for advice.\n", 76 | "Nasreddin Hodja was a wise old man who lived in what is now modern day Turkey. He is the subject of many stories in Turkish folklore and was often sought out for advice.\n" 77 | ] 78 | } 79 | ], 80 | "source": [ 81 | "from hodja import Chain\n", 82 | "from hodja.links.react import ReACTLink\n", 83 | "from hodja.search import Document, FAISS\n", 84 | "from hodja.search.embeddings.openai import OpenAIEmbeddings\n", 85 | "from hodja.tools.search_tools import SearchTool\n", 86 | "\n", 87 | "test_documents = [\n", 88 | " Document(text=\"Nasreddin hodja once lived in what is modern day Turkey.\", id=1),\n", 89 | " Document(text=\"A man once asked for a glass of water.\", id=2),\n", 90 | " Document(text=\"Nasreddin hodja was a very wise old man whom many people counted on for advice.\", id=3),\n", 91 | " Document(text=\"Python is a great programming language.\", id=4),\n", 92 | " Document(text=\"The most important thing when making pasta is to use the right amount of water.\", id=5),\n", 93 | " Document(text=\"I wonder if I should go to the gym today.\", id=6),\n", 94 | " Document(text=\"My favorite city in Italy to visit is Rome.\", id=7),\n", 95 | " Document(text=\"Nasreddin hodja may have lived in the 13th century, but he is now the subject of many (frequently humerous) stories in Turkish folklore.\", id=8),\n", 96 | "]\n", 97 | "\n", 98 | "embeddings = OpenAIEmbeddings()\n", 99 | "db = FAISS(embeddings=embeddings)\n", 100 | "db.add(test_documents)\n", 101 | "search_tool = SearchTool(db)\n", 102 | "\n", 103 | "tools = [search_tool]\n", 104 | "links = [ReACTLink(tools=tools)]\n", 105 | "chain = Chain(\"ReACT\", links=links)\n", 106 | "chain.run(\"Sumamrize entries on Nasreddin Hodja.\", debug=True)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [] 122 | } 123 | ], 124 | "metadata": { 125 | "kernelspec": { 126 | "display_name": "figtree", 127 | "language": "python", 128 | "name": "python3" 129 | }, 130 | "language_info": { 131 | "codemirror_mode": { 132 | "name": "ipython", 133 | "version": 3 134 | }, 135 | "file_extension": ".py", 136 | "mimetype": "text/x-python", 137 | "name": "python", 138 | "nbconvert_exporter": "python", 139 | "pygments_lexer": "ipython3", 140 | "version": "3.9.16" 141 | }, 142 | "orig_nbformat": 4 143 | }, 144 | "nbformat": 4, 145 | "nbformat_minor": 2 146 | } 147 | -------------------------------------------------------------------------------- /hodja/search/docstores.py: -------------------------------------------------------------------------------- 1 | """Classes for storing and retrieving documents.""" 2 | import os 3 | from abc import ABC, abstractmethod 4 | import faiss 5 | import pickle 6 | import json 7 | import numpy as np 8 | 9 | class DocStoreBase(ABC): 10 | 11 | @abstractmethod 12 | def add(self, document): 13 | """Add a document to the store.""" 14 | pass 15 | 16 | @abstractmethod 17 | def remove(self, document_id): 18 | """Remove a document from the store.""" 19 | pass 20 | 21 | @abstractmethod 22 | def get(self, document_id): 23 | """Get a document from the store.""" 24 | pass 25 | 26 | @abstractmethod 27 | def get_all(self): 28 | """Get all documents from the store.""" 29 | pass 30 | 31 | @property 32 | @abstractmethod 33 | def __len__(self): 34 | """Get the number of documents in the store.""" 35 | pass 36 | 37 | 38 | class DocStore(DocStoreBase): 39 | """A basic DocStore that stores documents in a dict.""" 40 | 41 | def __init__(self): 42 | self.documents = {} 43 | 44 | def add(self, documents): 45 | """Add documents to the store. 46 | 47 | Args: 48 | documents (list): Documents to add to the store. 49 | """ 50 | for document in documents: 51 | # check if document has an id, if not, assign one via text hash 52 | if hasattr(document, "id"): 53 | # if document has an id, check if it's already in the store 54 | if document.id in self.documents: 55 | raise ValueError(f"Document with id {document.id} already in store.") 56 | else: 57 | self.documents[document.id] = document 58 | else: 59 | document.id = hash(document.text) 60 | self.documents[document.id] = document 61 | 62 | def remove(self, document_ids): 63 | """Remove documents from the store. 64 | 65 | Args: 66 | document_ids (list): Document ids to remove from the store. 67 | """ 68 | for document_id in document_ids: 69 | del self.documents[document_id] 70 | 71 | def get(self, document_id): 72 | """Get a document from the store. 73 | 74 | Args: 75 | document_id: Document id to get from the store. 76 | """ 77 | return self.documents[document_id] 78 | 79 | def get_all(self): 80 | """Get all documents from the store.""" 81 | return self.documents.values() 82 | 83 | def __len__(self): 84 | """Get the number of documents in the store.""" 85 | return len(self.documents) 86 | 87 | 88 | class VectorStore(DocStoreBase): 89 | """A DocStore that embeds documents.""" 90 | 91 | def __init__(self, embeddings): 92 | self.embeddings = embeddings 93 | self.documents = [] 94 | self.document_embeddings = [] 95 | self._ids = [] 96 | 97 | def add(self, documents): 98 | """Get embeddings for documents and add to the vectorstore. 99 | 100 | Args: 101 | documents: Documents to add to the vectorstore. 102 | """ 103 | for document in documents: 104 | # check if document has an id, if not, assign one via text hash 105 | if hasattr(document, "id"): 106 | # if document has an id, check if it's already in the vectorstore 107 | if document.id in self._ids: 108 | raise ValueError(f"Document with id {document.id} already in vectorstore.") 109 | else: 110 | self._ids.append(document.id) 111 | else: 112 | self._ids.append(hash(document.text)) 113 | self.documents.append(document) 114 | texts = [document.text for document in documents] 115 | new_document_embeddings = self.embeddings.embed(texts) 116 | self.document_embeddings.extend(new_document_embeddings) 117 | 118 | def remove(self, document_ids): 119 | """Remove documents from the vectorstore. 120 | 121 | Args: 122 | document_ids (list): Document ids to remove from the vectorstore. 123 | """ 124 | for document_id in document_ids: 125 | index = self._ids.index(document_id) 126 | self.documents = self.documents[:index] + self.documents[index+1:] 127 | self._ids = self._ids[:index] + self._ids[index+1:] 128 | self.document_embeddings = self.document_embeddings[:index] + self.document_embeddings[index+1:] 129 | 130 | def get(self, document_id): 131 | """Get a document from the vectorstore. 132 | 133 | Args: 134 | document_id: Document id to get from the vectorstore. 135 | """ 136 | index = self._ids.index(document_id) 137 | return self.documents[index] 138 | 139 | def get_all(self): 140 | """Get all documents from the vectorstore.""" 141 | return self.documents 142 | 143 | def __len__(self): 144 | """Get the number of documents in the vectorstore.""" 145 | return len(self.documents) 146 | 147 | 148 | class FAISS(VectorStore): 149 | """Vector database that uses FAISS for fast semantic similarity search over documents.""" 150 | 151 | def __init__(self, embeddings, index=None, documents=None): 152 | super().__init__(embeddings) 153 | if index is None: 154 | self.index = faiss.IndexFlatL2(self._embedding_size) 155 | else: 156 | self.index = index 157 | if documents is None: 158 | self.documents = [] 159 | else: 160 | self.documents = documents 161 | 162 | @property 163 | def _embedding_size(self): 164 | return len(self.embeddings.embed(["dummy"])[0]) 165 | 166 | def save( self, save_directory): 167 | """Save to files.""" 168 | faiss.write_index(self.index, "index.faiss") 169 | with open(os.path.join(save_directory, "documents.json"), "w") as f: 170 | json.dump(self.documents, f) 171 | with open(os.path.join(save_directory, "embeddings.pickle"), "wb") as f: 172 | pickle.dump(self.embeddings, f) 173 | 174 | @classmethod 175 | def load(cls, save_directory): 176 | """Load from files.""" 177 | index = faiss.read_index("index.faiss") 178 | with open(os.path.join(save_directory, "documents.json"), "r") as f: 179 | documents = json.load(f) 180 | with open(os.path.join(save_directory, "embeddings.pickle"), "rb") as f: 181 | embeddings = pickle.load(f) 182 | return cls( 183 | embeddings=embeddings, 184 | index=index, 185 | documents=documents, 186 | ) 187 | 188 | def add(self, documents): 189 | """Add documents to the vectorstore. 190 | 191 | Args: 192 | document: Document to add to the vectorstore. 193 | """ 194 | super().add(documents) 195 | document_embeddings = self.document_embeddings[-len(documents):] 196 | document_embeddings = np.array(document_embeddings, dtype=np.float32) 197 | document_embeddings = document_embeddings.reshape(-1, self._embedding_size) 198 | self.index.add(document_embeddings) 199 | 200 | def remove(self, document_ids): 201 | """Remove documents from the vectorstore. 202 | 203 | Args: 204 | document_ids (list): Document ids to remove from the vectorstore. 205 | """ 206 | super().remove(document_ids) 207 | self.index = faiss.IndexFlatL2(self._embedding_size) 208 | self.document_embeddings = np.array(self.document_embeddings, dtype=np.float32) 209 | if len(self.document_embeddings): 210 | self.index.add(self.document_embeddings) 211 | 212 | def search(self, query, k=4): 213 | """Return docs most similar to query.""" 214 | query_embedding = self.embeddings.embed([query])[0] 215 | query_embedding = np.array(query_embedding, dtype=np.float32) 216 | D, I = self.index.search(query_embedding.reshape(1, -1), k) 217 | return [self.documents[i] for i in I[0]] 218 | 219 | 220 | --------------------------------------------------------------------------------