├── .gitignore ├── README.md ├── db_api.py ├── function_calling_demo.ipynb ├── requirements.txt └── streamlit ├── app.py └── utils ├── callback.py └── funcs ├── db_interactions.py └── rag_pipeline.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock 
in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Service account 132 | service-accounts 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Function Calling demo 2 | ## What does this application wants to demonstrate 3 | This application is built as an extension to [this](https://haystack.deepset.ai/tutorials/40_building_chat_application_with_function_calling) 4 | 1. **Data retrieval**: With both RAG and DB search (via API created from Flask) 5 | 2. **Routing**: Use Function Call for autonomous tool choice & invocation 6 | 3. 
**UI** Via Streamlit 7 | 8 | ## Tech stack 9 | - **Embedding model**: [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) 10 | - **Vector Database**: [Haystack's InMemoryDocumentStore](https://docs.haystack.deepset.ai/docs/inmemorydocumentstore) 11 | - **LLM**: [GPT-4 Turbo accessed via OpenRouter](https://openrouter.ai/models/openai/gpt-4-1106-preview). The flow can be adapted to use other LLMs 12 | - **LLM Framework**: [Haystack](https://haystack.deepset.ai/) for its great documentation and its transparency in pipeline construction. This tutorial is actually an extension of their [fantastic tutorial](https://haystack.deepset.ai/tutorials/40_building_chat_application_with_function_calling) on the same topic 13 | 14 | ## Running this tool 15 | 1. Create and activate a virtual environment, then `pip install -r requirements.txt` to install the required packages 16 | 2. Spin up the API server with `python db_api.py` 17 | 3. If you are looking for an introductory tutorial on the underlying concept, run `function_calling_demo.ipynb`. Otherwise proceed to step 4 directly 18 | 4. 
Run the Streamlit application with the commands below 19 | ``` 20 | export OPENROUTER_API_KEY='@REPLACE WITH YOUR API KEY' 21 | cd streamlit 22 | streamlit run app.py 23 | ``` -------------------------------------------------------------------------------- /db_api.py: -------------------------------------------------------------------------------- 1 | # Create a flask application to interact with the database with CRUD operations 2 | # This file contains the API for the database operations 3 | 4 | from flask import Flask, request, jsonify 5 | from flask_sqlalchemy import SQLAlchemy 6 | from flask_marshmallow import Marshmallow 7 | import os 8 | 9 | # Initialize the flask application 10 | app = Flask(__name__) 11 | basedir = os.path.abspath(os.path.dirname(__file__)) 12 | # Database configuration 13 | app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///:memory:' 14 | app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False 15 | # Initialize the database 16 | db = SQLAlchemy(app) 17 | # Initialize the marshmallow 18 | ma = Marshmallow(app) 19 | 20 | # Create a class for the database model 21 | 22 | class Item(db.Model): 23 | id = db.Column(db.Integer, primary_key=True) 24 | name_en = db.Column(db.String(255), nullable=False) 25 | name_cn = db.Column(db.String(255), nullable=False) 26 | category = db.Column(db.String(255), nullable=False) 27 | price = db.Column(db.Float, nullable=False) 28 | quantity = db.Column(db.Integer, nullable=False) 29 | 30 | def __init__(self, name_en, name_cn, category, price, quantity): 31 | self.name_en = name_en 32 | self.name_cn = name_cn 33 | self.category = category 34 | self.price = price 35 | self.quantity = quantity 36 | 37 | # Create a schema for the database model 38 | class ItemSchema(ma.Schema): 39 | class Meta: 40 | fields = ('id', 'name_en', 'name_cn', 'category', 'price', 'quantity') 41 | 42 | # Initialize the schema 43 | item_schema = ItemSchema() 44 | items_schema = ItemSchema(many=True) 45 | 46 | # Add sample data to the database 47 | 
app.app_context().push() 48 | db.create_all() 49 | initial_items = [ 50 | Item('Water', '水', 'Food and beverages', 1.0, 100), 51 | Item('Coca-cola', '可樂', 'Food and beverages', 2.0, 200), 52 | Item('Hamburger', '漢堡包', 'Food and beverages', 13.0, 0), 53 | Item('Fried rice', '炒飯', 'Food and beverages', 9.0, 300), 54 | Item('Newspaper', '報紙', 'Miscellaneous', 2.0, 100), 55 | Item('Cigarettes', '煙', 'Miscellaneous', 5.0, 100), 56 | ] 57 | for item in initial_items: 58 | db.session.add(item) 59 | db.session.commit() 60 | 61 | # Create a route to add a item to the database 62 | # @app.route('/item', methods=['POST']) 63 | # def add_item(): 64 | # name_en = request.json['name_en'] 65 | # name_cn = request.json['name_cn'] 66 | # price = request.json['price'] 67 | # quantity = request.json['quantity'] 68 | 69 | # new_item = Item(name_en, name_cn, price, quantity) 70 | 71 | # db.session.add(new_item) 72 | # db.session.commit() 73 | 74 | # return item_schema.jsonify(new_item) 75 | 76 | # Create a route to get all the items from the database 77 | # @app.route('/item', methods=['GET']) 78 | # def get_items(): 79 | # all_items = Item.query.all() 80 | # result = items_schema.dump(all_items) 81 | # return jsonify(result) 82 | 83 | # Create a route to get the unique categories 84 | @app.route('/category', methods=['GET']) 85 | def get_categories(): 86 | categories = Item.query.with_entities(Item.category).distinct().all() 87 | result = [category[0] for category in categories] 88 | return jsonify(result) 89 | 90 | # Create a route to get a item by ids, multiple ids are separated by comma 91 | @app.route('/item', methods=['GET']) 92 | def get_items(): 93 | ids = request.args.get('id') 94 | categories = request.args.get('category') 95 | if ids: 96 | ids = ids.split(',') 97 | items = Item.query.filter(Item.id.in_(ids)).all() 98 | result = items_schema.dump(items) 99 | elif categories: 100 | categories = categories.split(',') 101 | items = 
Item.query.filter(Item.category.in_(categories)).all() 102 | result = items_schema.dump(items) 103 | else: 104 | items = Item.query.all() 105 | result = items_schema.dump(items) 106 | return jsonify(result) 107 | 108 | # # Create a route to update a item by id 109 | # @app.route('/item/', methods=['PUT']) 110 | # def update_item(id): 111 | # item = Item.query.get(id) 112 | # name_en = request.json['name_en'] 113 | # name_cn = request.json['name_cn'] 114 | # price = request.json['price'] 115 | # quantity = request.json['quantity'] 116 | 117 | # item.name_en = name_en 118 | # item.name_cn = name_cn 119 | # item.price = price 120 | # item.quantity = quantity 121 | 122 | # db.session.commit() 123 | 124 | # return item_schema.jsonify(item) 125 | 126 | # # Create a route to delete a item by id 127 | # @app.route('/item/', methods=['DELETE']) 128 | # def delete_item(id): 129 | # item = Item.query.get(id) 130 | # db.session.delete(item) 131 | # db.session.commit() 132 | 133 | # return item_schema.jsonify(item) 134 | 135 | # Create a route to deduct the quantity of a item by id 136 | @app.route('/item/purchase', methods=['POST']) 137 | def purchase_item(): 138 | 139 | id = request.json['id'] 140 | quantity = request.json['quantity'] 141 | item = Item.query.get(id) 142 | 143 | item.quantity = item.quantity - quantity 144 | 145 | db.session.commit() 146 | 147 | return item_schema.jsonify(item) 148 | 149 | # Run the application 150 | if __name__ == '__main__': 151 | app.run(debug=True) 152 | 153 | # End of file -------------------------------------------------------------------------------- /function_calling_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Initialization" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this demo we are calling LLMs from OpenRouter, because with it you can 
access different LLM APIs from Hong Kong. But using the original OpenAIChatGenerator without overwritting the `api_base_url` would also work" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import os\n", 24 | "from dotenv import load_dotenv\n", 25 | "from haystack.components.generators.chat import OpenAIChatGenerator\n", 26 | "from haystack.utils import Secret\n", 27 | "from haystack.dataclasses import ChatMessage\n", 28 | "from haystack.components.generators.utils import print_streaming_chunk\n", 29 | "\n", 30 | "# Set your API key as environment variable before executing this\n", 31 | "load_dotenv()\n", 32 | "OPENROUTER_API_KEY = os.environ.get('OPENROUTER_API_KEY')\n", 33 | "\n", 34 | "chat_generator = OpenAIChatGenerator(api_key=Secret.from_env_var(\"OPENROUTER_API_KEY\"),\n", 35 | "\t\tapi_base_url=\"https://openrouter.ai/api/v1\",\n", 36 | "\t\tmodel=\"openai/gpt-4-turbo-preview\",\n", 37 | " streaming_callback=print_streaming_chunk)\n", 38 | "\n", 39 | "chat_generator.run(messages=[ChatMessage.from_user(\"Return this text: 'test'\")])" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "# Step 1 - Establish data store" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Index Documents with a Pipeline\n", 54 | "Here we provide sample texts for the model to perform Retrival Augmented Generation (RAG). 
The texts are turned into embeddings and stored in an in-memory document store" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 2, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stderr", 64 | "output_type": "stream", 65 | "text": [ 66 | "Batches: 100%|██████████| 1/1 [00:00<00:00, 1.17it/s]\n" 67 | ] 68 | }, 69 | { 70 | "data": { 71 | "text/plain": [ 72 | "{'doc_writer': {'documents_written': 2}}" 73 | ] 74 | }, 75 | "execution_count": 2, 76 | "metadata": {}, 77 | "output_type": "execute_result" 78 | } 79 | ], 80 | "source": [ 81 | "from haystack import Pipeline, Document\n", 82 | "from haystack.document_stores.in_memory import InMemoryDocumentStore\n", 83 | "from haystack.components.writers import DocumentWriter\n", 84 | "from haystack.components.embedders import SentenceTransformersDocumentEmbedder\n", 85 | "\n", 86 | "# Sample documents\n", 87 | "documents = [\n", 88 | " Document(content=\"Coffee shop opens at 9am and closes at 5pm.\"),\n", 89 | " Document(content=\"Gym room opens at 6am and closes at 10pm.\")\n", 90 | "]\n", 91 | "\n", 92 | "# Create the document store\n", 93 | "document_store = InMemoryDocumentStore()\n", 94 | "\n", 95 | "# Create a pipeline to turn the texts into embeddings and store them in the document store\n", 96 | "indexing_pipeline = Pipeline()\n", 97 | "indexing_pipeline.add_component(\n", 98 | " \"doc_embedder\", SentenceTransformersDocumentEmbedder(model=\"sentence-transformers/all-MiniLM-L6-v2\")\n", 99 | ")\n", 100 | "indexing_pipeline.add_component(\"doc_writer\", DocumentWriter(document_store=document_store))\n", 101 | "\n", 102 | "indexing_pipeline.connect(\"doc_embedder.documents\", \"doc_writer.documents\")\n", 103 | "\n", 104 | "indexing_pipeline.run({\"doc_embedder\": {\"documents\": documents}})" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "## Spin up API server" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | 
"metadata": {}, 117 | "source": [ 118 | "An API server made with Flask is created under `db_api.py` to connect to SQLite. Please spin it up by running `python db_api.py` in your terminal " 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "# Step 2 - Define the functions\n", 126 | "Here we prepare the actual functions for the model to invoke AFTER Function Calling. Function Calling provides ONLY the arguments for you to invoke these functions, it does not invoke the functions themselves" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "## RAG function\n", 134 | "Namely the `rag_pipeline_func`. This is for the model to provide an answer by searching through the texts stored in the Document Store. We first define the RAG retrieval as a Haystack pipeline" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 3, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "data": { 144 | "image/png": "", 145 | "text/plain": [ 146 | "" 147 | ] 148 | }, 149 | "metadata": {}, 150 | "output_type": "display_data" 151 | }, 152 | { 153 | "data": { 154 | "text/plain": [] 155 | }, 156 | "execution_count": 3, 157 | "metadata": {}, 158 | "output_type": "execute_result" 159 | } 160 | ], 161 | "source": [ 162 | "from haystack.components.embedders import SentenceTransformersTextEmbedder\n", 163 | "from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever\n", 164 | "from haystack.components.builders import PromptBuilder\n", 165 | "from haystack.components.generators import OpenAIGenerator\n", 166 | "\n", 167 | "template = \"\"\"\n", 168 | "Answer the questions based on the given context.\n", 169 | "\n", 170 | "Context:\n", 171 | "{% for document in documents %}\n", 172 | " {{ document.content }}\n", 173 | "{% endfor %}\n", 174 | "Question: {{ question }}\n", 175 | "Answer:\n", 176 | "\"\"\"\n", 177 | "rag_pipe = Pipeline()\n", 178 | 
"rag_pipe.add_component(\"embedder\", SentenceTransformersTextEmbedder(model=\"sentence-transformers/all-MiniLM-L6-v2\"))\n", 179 | "rag_pipe.add_component(\"retriever\", InMemoryEmbeddingRetriever(document_store=document_store))\n", 180 | "rag_pipe.add_component(\"prompt_builder\", PromptBuilder(template=template))\n", 181 | "# Note to llm: We are using OpenAIGenerator, not the OpenAIChatGenerator, because the latter only accepts List[str] as input and cannot accept prompt_builder's str output\n", 182 | "rag_pipe.add_component(\"llm\", OpenAIGenerator(api_key=Secret.from_env_var(\"OPENROUTER_API_KEY\"),\n", 183 | "\t\tapi_base_url=\"https://openrouter.ai/api/v1\",\n", 184 | "\t\tmodel=\"openai/gpt-4-turbo-preview\"))\n", 185 | "\n", 186 | "rag_pipe.connect(\"embedder.embedding\", \"retriever.query_embedding\")\n", 187 | "rag_pipe.connect(\"retriever\", \"prompt_builder.documents\")\n", 188 | "rag_pipe.connect(\"prompt_builder\", \"llm\")" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "Test if the pipeline works" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 4, 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "name": "stderr", 205 | "output_type": "stream", 206 | "text": [ 207 | "Batches: 100%|██████████| 1/1 [00:00<00:00, 63.95it/s]\n" 208 | ] 209 | }, 210 | { 211 | "data": { 212 | "text/plain": [ 213 | "{'llm': {'replies': ['The coffee shop opens at 9am.'],\n", 214 | " 'meta': [{'model': 'openai/gpt-4-turbo-preview',\n", 215 | " 'index': 0,\n", 216 | " 'finish_reason': 'stop',\n", 217 | " 'usage': {'completion_tokens': 9,\n", 218 | " 'prompt_tokens': 60,\n", 219 | " 'total_tokens': 69,\n", 220 | " 'total_cost': 0.00087}}]}}" 221 | ] 222 | }, 223 | "execution_count": 4, 224 | "metadata": {}, 225 | "output_type": "execute_result" 226 | } 227 | ], 228 | "source": [ 229 | "query = \"When does the coffee shop open?\"\n", 230 | "rag_pipe.run({\"embedder\": {\"text\": query}, 
\"prompt_builder\": {\"question\": query}})" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "The pipeline is turned into a function" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 5, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "def rag_pipeline_func(query: str):\n", 247 | " result = rag_pipe.run({\"embedder\": {\"text\": query}, \"prompt_builder\": {\"question\": query}})\n", 248 | "\n", 249 | " return {\"reply\": result[\"llm\"][\"replies\"][0]}" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "## API calls\n", 257 | "For interacting with the API server, which in turns interact with our database" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 6, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "# Flask's default local URL, change it if necessary\n", 267 | "db_base_url = 'http://127.0.0.1:5000'\n", 268 | "\n", 269 | "# Use requests to get the data from the database\n", 270 | "import requests\n", 271 | "import json\n", 272 | "\n", 273 | "def get_categories():\n", 274 | " response = requests.get(f'{db_base_url}/category')\n", 275 | " data = response.json()\n", 276 | " return data\n", 277 | "\n", 278 | "def get_items(ids=None,categories=None):\n", 279 | " params = {\n", 280 | " 'id': ids,\n", 281 | " 'category': categories,\n", 282 | " }\n", 283 | " response = requests.get(f'{db_base_url}/item', params=params)\n", 284 | " data = response.json()\n", 285 | " return data\n", 286 | "\n", 287 | "def purchase_item(id,quantity):\n", 288 | "\n", 289 | " headers = {\n", 290 | " 'Content-type':'application/json', \n", 291 | " 'Accept':'application/json'\n", 292 | " }\n", 293 | "\n", 294 | " data = {\n", 295 | " 'id': id,\n", 296 | " 'quantity': quantity,\n", 297 | " }\n", 298 | " response = requests.post(f'{db_base_url}/item/purchase', json=data, headers=headers)\n", 299 
| " return response.json()" 300 | ] 301 | }, 302 | { 303 | "cell_type": "markdown", 304 | "metadata": {}, 305 | "source": [ 306 | "## Define the tool list\n", 307 | "Now that we have defined the fuctions, we need to let the model to recognize those functions, and to instruct them how they are used, by providing descriptions for them." 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 7, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "tools = [\n", 317 | " {\n", 318 | " \"type\": \"function\",\n", 319 | " \"function\": {\n", 320 | " \"name\": \"get_items\",\n", 321 | " \"description\": \"Get a list of items from the database\",\n", 322 | " \"parameters\": {\n", 323 | " \"type\": \"object\",\n", 324 | " \"properties\": {\n", 325 | " \"ids\": {\n", 326 | " \"type\": \"string\",\n", 327 | " \"description\": \"Comma separated list of item ids to fetch\",\n", 328 | " },\n", 329 | " \"categories\": {\n", 330 | " \"type\": \"string\",\n", 331 | " \"description\": \"Comma separated list of item categories to fetch\",\n", 332 | " },\n", 333 | " },\n", 334 | " \"required\": [],\n", 335 | " },\n", 336 | " }\n", 337 | " },\n", 338 | " {\n", 339 | " \"type\": \"function\",\n", 340 | " \"function\": {\n", 341 | " \"name\": \"purchase_item\",\n", 342 | " \"description\": \"Purchase a particular item\",\n", 343 | " \"parameters\": {\n", 344 | " \"type\": \"object\",\n", 345 | " \"properties\": {\n", 346 | " \"id\": {\n", 347 | " \"type\": \"string\",\n", 348 | " \"description\": \"The given product ID, product name is not accepted here. 
Please obtain the product ID from the database first.\",\n", 349 | " },\n", 350 | " \"quantity\": {\n", 351 | " \"type\": \"integer\",\n", 352 | " \"description\": \"Number of items to purchase\",\n", 353 | " },\n", 354 | " },\n", 355 | " \"required\": [],\n", 356 | " },\n", 357 | " }\n", 358 | " },\n", 359 | " {\n", 360 | " \"type\": \"function\",\n", 361 | " \"function\": {\n", 362 | " \"name\": \"rag_pipeline_func\",\n", 363 | " \"description\": \"Get information from hotel brochure\",\n", 364 | " \"parameters\": {\n", 365 | " \"type\": \"object\",\n", 366 | " \"properties\": {\n", 367 | " \"query\": {\n", 368 | " \"type\": \"string\",\n", 369 | " \"description\": \"The query to use in the search. Infer this from the user's message. It should be a question or a statement\",\n", 370 | " }\n", 371 | " },\n", 372 | " \"required\": [\"query\"],\n", 373 | " },\n", 374 | " },\n", 375 | " }\n", 376 | "]" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": {}, 382 | "source": [ 383 | "# Step 3: Putting it all together\n", 384 | "We now have the necessary inputs to test Function Calling! Here we do a few things:\n", 385 | "1. Provide the initial prompt to the model, to give it some context\n", 386 | "2. Provide a sample user-generated message\n", 387 | "3. 
Most importantly, we pass the tool list defined above to the chat generator in `tools`" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 15, 393 | "metadata": {}, 394 | "outputs": [ 395 | { 396 | "data": { 397 | "text/plain": [ 398 | "{'replies': [ChatMessage(content='[{\"index\": 0, \"id\": \"call_AkTWoiJzx5uJSgKW0WAI1yBB\", \"function\": {\"arguments\": \"{\\\\\"categories\\\\\":\\\\\"Food and beverages\\\\\"}\", \"name\": \"get_items\"}, \"type\": \"function\"}]', role=, name=None, meta={'model': 'openai/gpt-4-turbo-preview', 'index': 0, 'finish_reason': 'tool_calls', 'usage': {}})]}" 399 | ] 400 | }, 401 | "execution_count": 15, 402 | "metadata": {}, 403 | "output_type": "execute_result" 404 | } 405 | ], 406 | "source": [ 407 | "# 1. Initial prompt\n", 408 | "context = f\"\"\"You are an assistant to tourists visiting a hotel.\n", 409 | "You have access to a database of items (which includes {get_categories()}) that tourists can buy, you also have access to the hotel's brochure.\n", 410 | "If the tourist's question cannot be answered from the database, you can refer to the brochure.\n", 411 | "If the tourist's question cannot be answered from the brochure, you can ask the tourist to ask the hotel staff.\n", 412 | "\"\"\"\n", 413 | "messages = [\n", 414 | " ChatMessage.from_system(context),\n", 415 | " # 2. Sample message from user\n", 416 | " ChatMessage.from_user(\"Can I buy a coffee?\"),\n", 417 | " ]\n", 418 | "\n", 419 | "# 3. Passing the tools list and invoke the chat generator\n", 420 | "response = chat_generator.run(messages=messages, generation_kwargs= {\"tools\": tools})\n", 421 | "response" 422 | ] 423 | }, 424 | { 425 | "cell_type": "markdown", 426 | "metadata": {}, 427 | "source": [ 428 | "Now let's inspect the response. Notice how the Function Calling returns both the function chosen by the model, and the arguments for invoking the chosen function." 
429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": 9, 434 | "metadata": {}, 435 | "outputs": [ 436 | { 437 | "name": "stdout", 438 | "output_type": "stream", 439 | "text": [ 440 | "Function Name: get_items\n", 441 | "Function Arguments: {'categories': 'Food and beverages'}\n" 442 | ] 443 | } 444 | ], 445 | "source": [ 446 | "function_call = json.loads(response[\"replies\"][0].content)[0]\n", 447 | "function_name = function_call[\"function\"][\"name\"]\n", 448 | "function_args = json.loads(function_call[\"function\"][\"arguments\"])\n", 449 | "print(\"Function Name:\", function_name)\n", 450 | "print(\"Function Arguments:\", function_args)" 451 | ] 452 | }, 453 | { 454 | "cell_type": "markdown", 455 | "metadata": {}, 456 | "source": [ 457 | "When presented with another question, the model will use another tool that is more relevant" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 10, 463 | "metadata": {}, 464 | "outputs": [ 465 | { 466 | "name": "stdout", 467 | "output_type": "stream", 468 | "text": [ 469 | "Function Name: rag_pipeline_func\n", 470 | "Function Arguments: {'query': \"Where's the coffee shop?\"}\n" 471 | ] 472 | } 473 | ], 474 | "source": [ 475 | "# Another question\n", 476 | "messages.append(ChatMessage.from_user(\"Where's the coffee shop?\"))\n", 477 | "\n", 478 | "# Invoke the chat generator, and passing the tools list\n", 479 | "response = chat_generator.run(messages=messages, generation_kwargs= {\"tools\": tools})\n", 480 | "function_call = json.loads(response[\"replies\"][0].content)[0]\n", 481 | "function_name = function_call[\"function\"][\"name\"]\n", 482 | "function_args = json.loads(function_call[\"function\"][\"arguments\"])\n", 483 | "print(\"Function Name:\", function_name)\n", 484 | "print(\"Function Arguments:\", function_args)" 485 | ] 486 | }, 487 | { 488 | "cell_type": "markdown", 489 | "metadata": {}, 490 | "source": [ 491 | "Notice that no actual function is invoked here, this 
is what we will do next" 492 | ] 493 | }, 494 | { 495 | "cell_type": "markdown", 496 | "metadata": {}, 497 | "source": [ 498 | "## Calling the function\n", 499 | "We can then feed the arguments into the chosen function" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": 11, 505 | "metadata": {}, 506 | "outputs": [ 507 | { 508 | "name": "stderr", 509 | "output_type": "stream", 510 | "text": [ 511 | "Batches: 100%|██████████| 1/1 [00:00<00:00, 63.99it/s]\n" 512 | ] 513 | }, 514 | { 515 | "name": "stdout", 516 | "output_type": "stream", 517 | "text": [ 518 | "Function Response: {'reply': 'The provided context does not specify a physical location for the coffee shop, only its operating hours. Therefore, I cannot determine where the coffee shop is located based on the given information.'}\n" 519 | ] 520 | } 521 | ], 522 | "source": [ 523 | "## Find the correspoding function and call it with the given arguments\n", 524 | "available_functions = {\"get_items\": get_items, \"purchase_item\": purchase_item,\"rag_pipeline_func\": rag_pipeline_func}\n", 525 | "function_to_call = available_functions[function_name]\n", 526 | "function_response = function_to_call(**function_args)\n", 527 | "print(\"Function Response:\", function_response)" 528 | ] 529 | }, 530 | { 531 | "cell_type": "markdown", 532 | "metadata": {}, 533 | "source": [ 534 | "The response can then passed as a context to the chat, under the `messages` argument." 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": 12, 540 | "metadata": {}, 541 | "outputs": [ 542 | { 543 | "name": "stdout", 544 | "output_type": "stream", 545 | "text": [ 546 | "For the location of the coffee shop within the hotel, I recommend asking the hotel staff directly. They will be able to guide you to it accurately.For the location of the coffee shop within the hotel, I recommend asking the hotel staff directly. 
They will be able to guide you to it accurately.\n" 547 | ] 548 | } 549 | ], 550 | "source": [ 551 | "messages.append(ChatMessage.from_function(content=json.dumps(function_response), name=function_name))\n", 552 | "response = chat_generator.run(messages=messages)\n", 553 | "response_msg = response[\"replies\"][0]\n", 554 | "\n", 555 | "print(response_msg.content)" 556 | ] 557 | }, 558 | { 559 | "cell_type": "markdown", 560 | "metadata": {}, 561 | "source": [ 562 | "We now have completed the chat cycle!" 563 | ] 564 | }, 565 | { 566 | "cell_type": "markdown", 567 | "metadata": {}, 568 | "source": [ 569 | "## Turn into interactive chat\n", 570 | "The below code is copied from [Haystack's tutorial](https://haystack.deepset.ai/tutorials/40_building_chat_application_with_function_calling). However, for interactive chat we better hook the model to Streamlit to provide a neat ChatGPT-like UI" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": null, 576 | "metadata": {}, 577 | "outputs": [], 578 | "source": [ 579 | "import json\n", 580 | "from haystack.dataclasses import ChatMessage, ChatRole\n", 581 | "\n", 582 | "response = None\n", 583 | "messages = [\n", 584 | " ChatMessage.from_system(context)\n", 585 | "]\n", 586 | "\n", 587 | "while True:\n", 588 | " # if OpenAI response is a tool call\n", 589 | " if response and response[\"replies\"][0].meta[\"finish_reason\"] == \"tool_calls\":\n", 590 | " function_calls = json.loads(response[\"replies\"][0].content)\n", 591 | "\n", 592 | " for function_call in function_calls:\n", 593 | " ## Parse function calling information\n", 594 | " function_name = function_call[\"function\"][\"name\"]\n", 595 | " function_args = json.loads(function_call[\"function\"][\"arguments\"])\n", 596 | "\n", 597 | " ## Find the correspoding function and call it with the given arguments\n", 598 | " function_to_call = available_functions[function_name]\n", 599 | " function_response = function_to_call(**function_args)\n", 600 | 
"\n", 601 | " ## Append function response to the messages list using `ChatMessage.from_function`\n", 602 | " messages.append(ChatMessage.from_function(content=json.dumps(function_response), name=function_name))\n", 603 | "\n", 604 | " # Regular Conversation\n", 605 | " else:\n", 606 | " # Append assistant messages to the messages list\n", 607 | " if not messages[-1].is_from(ChatRole.SYSTEM):\n", 608 | " messages.append(response[\"replies\"][0])\n", 609 | "\n", 610 | " user_input = input(\"ENTER YOUR MESSAGE 👇 INFO: Type 'exit' or 'quit' to stop\\n\")\n", 611 | " if user_input.lower() == \"exit\" or user_input.lower() == \"quit\":\n", 612 | " break\n", 613 | " else:\n", 614 | " messages.append(ChatMessage.from_user(user_input))\n", 615 | "\n", 616 | " response = chat_generator.run(messages=messages, generation_kwargs={\"tools\": tools})" 617 | ] 618 | } 619 | ], 620 | "metadata": { 621 | "kernelspec": { 622 | "display_name": "venv", 623 | "language": "python", 624 | "name": "python3" 625 | }, 626 | "language_info": { 627 | "codemirror_mode": { 628 | "name": "ipython", 629 | "version": 3 630 | }, 631 | "file_extension": ".py", 632 | "mimetype": "text/x-python", 633 | "name": "python", 634 | "nbconvert_exporter": "python", 635 | "pygments_lexer": "ipython3", 636 | "version": "3.10.4" 637 | } 638 | }, 639 | "nbformat": 4, 640 | "nbformat_minor": 2 641 | } 642 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python-dotenv==1.0.1 2 | Flask==3.0.2 3 | Flask_sqlalchemy==3.1.1 4 | flask_marshmallow==1.2.1 5 | marshmallow-sqlalchemy==1.0.0 6 | openai==1.14.3 7 | haystack-ai==2.0.0 8 | sentence-transformers==2.6.1 9 | streamlit==1.32.2 10 | ipykernel==6.25.2 -------------------------------------------------------------------------------- /streamlit/app.py: -------------------------------------------------------------------------------- 1 | 
import json
import os

import streamlit as st
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.utils import Secret

from utils.callback import StreamlitCallbackHandler
from utils.funcs.db_interactions import get_categories, get_items, purchase_item
from utils.funcs.rag_pipeline import rag_pipeline_func

# Read the API key from the process environment (export it in the terminal
# before launching the app). Nothing in this script loads a .env file; the chat
# generator further down resolves the same variable via Secret.from_env_var.
OPENROUTER_API_KEY = os.environ.get('OPENROUTER_API_KEY')

# Tool (function-calling) schemas advertised to the model. Each "name" must
# match a key of `available_functions` below, and each parameter name must
# match the keyword argument of the Python function that implements the tool.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_items",
            "description": "Get a list of items from the database",
            "parameters": {
                "type": "object",
                "properties": {
                    "ids": {
                        "type": "string",
                        "description": "Comma separated list of item ids to fetch",
                    },
                    "categories": {
                        "type": "string",
                        "description": "Comma separated list of item categories to fetch",
                    },
                },
                # Both filters are optional: get_items() defaults them to None.
                "required": [],
            },
        }
    },
    {
        "type": "function",
        "function": {
            "name": "purchase_item",
            "description": "Purchase a particular item",
            "parameters": {
                "type": "object",
                "properties": {
                    "id": {
                        "type": "string",
                        "description": "The given product ID, product name is not accepted here. Please obtain the product ID from the database first.",
                    },
                    "quantity": {
                        "type": "integer",
                        "description": "Number of items to purchase",
                    },
                },
                # purchase_item(id, quantity) has no default values, so the model
                # must always supply both; otherwise the call raises TypeError.
                "required": ["id", "quantity"],
            },
        }
    },
    {
        "type": "function",
        "function": {
            "name": "rag_pipeline_func",
            "description": "Get information from hotel brochure",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The query to use in the search. Infer this from the user's message. It should be a question or a statement",
                    }
                },
                "required": ["query"],
            },
        },
    }
]


def _build_context():
    """Return the system prompt, embedding the categories reported by the database API.

    Calling this performs an HTTP request (through ``get_categories``), so it
    is only invoked when a new chat session is created rather than on every
    Streamlit rerun of the script.
    """
    return f"""You are an assistant to tourists visiting a hotel.
You have access to a database of items (which includes {get_categories()}) that tourists can buy, you also have access to the hotel's brochure.
If the tourist's question cannot be answered by the database, you can refer to the brochure.
If the tourist's question cannot be answered by the brochure, you can ask the tourist to ask the hotel staff.
"""


# Maps the tool names used in `tools` to the Python callables that implement them.
available_functions = {
    "get_items": get_items,
    "purchase_item": purchase_item,
    "rag_pipeline_func": rag_pipeline_func,
}

# Streamlit chat interface
# The conversation lives in the session state so it survives script reruns;
# it is seeded once with the system prompt.
if "messages" not in st.session_state:
    context = _build_context()
    st.session_state["messages"] = [ChatMessage.from_system(context)]

# Only show chat messages from the user and the assistant. Initial system prompt and function calls are hidden.
for message in st.session_state.messages:
    # Logical `or` (not bitwise `|`) is the intended operator for combining the two role checks.
    if message.is_from(ChatRole.USER) or message.is_from(ChatRole.ASSISTANT):
        with st.chat_message(message.role.name):
            st.markdown(message.content)

if prompt := st.chat_input("ENTER YOUR MESSAGE 👇"):
    st.session_state.messages.append(ChatMessage.from_user(prompt))
    with st.chat_message("USER"):
        st.markdown(prompt)

    with st.chat_message("ASSISTANT"):
        # Initialize the callback handler, which creates an empty container for the responses to be streamed into
        st_callback = StreamlitCallbackHandler(st.empty())
        # Initialize the chat generator
        chat_generator = OpenAIChatGenerator(
            api_key=Secret.from_env_var("OPENROUTER_API_KEY"),
            api_base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4-turbo-preview",
            streaming_callback=st_callback.on_llm_new_token,
        )
        while True:
            # Run the chat generator, tool calls will be looped through and executed, until an assistant reply is generated
            response = chat_generator.run(
                messages=st.session_state.messages,
                generation_kwargs={"tools": tools},
            )

            if response and response["replies"][0].meta["finish_reason"] == "tool_calls":
                # For a tool-call reply the message content is a JSON list of calls.
                function_calls = json.loads(response["replies"][0].content)

                for function_call in function_calls:
                    ## Parse function calling information
                    function_name = function_call["function"]["name"]
                    function_args = json.loads(function_call["function"]["arguments"])

                    ## Find the corresponding function and call it with the given arguments
                    function_to_call = available_functions.get(function_name)
                    if function_to_call is None:
                        # The model asked for a tool we do not implement: report that
                        # back to it instead of crashing the app with a KeyError.
                        function_response = {"error": f"Unknown function: {function_name}"}
                    else:
                        function_response = function_to_call(**function_args)

                    ## Append function response to the messages list using `ChatMessage.from_function`
                    st.session_state.messages.append(
                        ChatMessage.from_function(content=json.dumps(function_response), name=function_name)
                    )
            # Regular conversation
            else:
                # Append assistant messages to the messages list
                if not st.session_state.messages[-1].is_from(ChatRole.SYSTEM):
                    st.session_state.messages.append(response["replies"][0])
                break

# --------------------------------------------------------------------------------
# /streamlit/utils/callback.py:
# --------------------------------------------------------------------------------
# The write_stream method newly supported in Streamlit does not seem to be usable with Haystack's generator
# So we will use the StreamlitCallbackHandler to stream the responses from the assistant to the UI

# This is adapted from Langchain's StreamlitCallbackHandler
# https://api.python.langchain.com/en/latest/_modules/langchain_community/callbacks/streamlit/streamlit_callback_handler.html#StreamlitCallbackHandler

from haystack.dataclasses import StreamingChunk


class StreamlitCallbackHandler:
    """Accumulate streamed tokens and re-render them into a Streamlit container."""

    def __init__(self, response_container):
        """Store the container to render into and start with an empty transcript.

        Args:
            response_container: Object exposing ``markdown(text)``; the app
                passes the placeholder returned by ``st.empty()``.
        """
        self.response_container = response_container
        self.current_text = ''

    # Stream the messages from the assistant to the UI
    def on_llm_new_token(self, chunk: StreamingChunk):
        """Append the text of ``chunk`` to the transcript and redraw the container.

        Chunks without text are ignored. Presumably these are the chunks
        emitted while the model produces a function/tool call (confirm against
        the Haystack version in use); skipping them avoids a ``TypeError`` on
        ``None`` content and a pointless re-render.
        """
        token = chunk.content
        if not token:
            return
        self.current_text += token
        self.response_container.markdown(self.current_text)

# --------------------------------------------------------------------------------
# /streamlit/utils/funcs/db_interactions.py:
# --------------------------------------------------------------------------------
# Use requests to get the data from the database
import logging

import requests

logger = logging.getLogger(__name__)

db_base_url = 'http://127.0.0.1:5000'

# Seconds to wait for the database API before giving up. Without a timeout,
# requests waits indefinitely and a stalled API would freeze the chat UI.
REQUEST_TIMEOUT = 10


def get_categories():
    """Return the decoded JSON body of ``GET {db_base_url}/category``."""
    response = requests.get(f'{db_base_url}/category', timeout=REQUEST_TIMEOUT)
    return response.json()


def get_items(ids=None, categories=None):
    """Return the decoded JSON body of ``GET {db_base_url}/item``.

    Args:
        ids: Value sent as the ``id`` query parameter (the tool schema
            describes it as a comma separated list of item ids).
        categories: Value sent as the ``category`` query parameter.

    Arguments left as ``None`` are omitted from the query string by requests.
    """
    params = {
        'id': ids,
        'category': categories,
    }
    response = requests.get(f'{db_base_url}/item', params=params, timeout=REQUEST_TIMEOUT)
    return response.json()


def purchase_item(id, quantity):
    """POST a purchase to ``{db_base_url}/item/purchase`` and return the decoded JSON reply.

    Args:
        id: Identifier of the item to buy. The name shadows the builtin but
            must stay as is, because the tool schema passes it by keyword.
        quantity: Number of items to purchase.
    """
    headers = {
        'Content-type': 'application/json',
        'Accept': 'application/json'
    }

    data = {
        'id': id,
        'quantity': quantity,
    }
    response = requests.post(
        f'{db_base_url}/item/purchase',
        json=data,
        headers=headers,
        timeout=REQUEST_TIMEOUT,
    )
    # Diagnostics go through logging instead of print so they can be silenced.
    logger.debug("purchase_item response: %s %s", response, response.text)
    return response.json()

# --------------------------------------------------------------------------------
# /streamlit/utils/funcs/rag_pipeline.py:
# --------------------------------------------------------------------------------
from haystack import Pipeline, Document
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.writers import DocumentWriter
from haystack.components.embedders import (SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder)
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.components.builders import PromptBuilder
from haystack.components.generators import OpenAIGenerator
from haystack.utils import Secret


# Embed documents
def embed_documents(documents):
    """Embed ``documents`` and return a new in-memory document store holding them.

    Builds and runs a two-step indexing pipeline: a sentence-transformers
    document embedder followed by a writer bound to a fresh
    ``InMemoryDocumentStore``.
    """
    document_store = InMemoryDocumentStore()

    indexing_pipeline = Pipeline()
    indexing_pipeline.add_component(
        "doc_embedder", SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
    )
    indexing_pipeline.add_component("doc_writer", DocumentWriter(document_store=document_store))

    indexing_pipeline.connect("doc_embedder.documents", "doc_writer.documents")

    indexing_pipeline.run({"doc_embedder": {"documents": documents}})

    return document_store


# Sample "brochure" content served by the RAG tool.
documents = [
    Document(content="Coffee shop opens at 9am and closes at 5pm."),
    Document(content="Gym room opens at 6am and closes at 10pm.")
]

# NOTE: this runs at import time, so importing the module embeds the documents.
document_store = embed_documents(documents)

# Create RAG pipeline
template = """
Answer the questions based on the given context.

Context:
{% for document in documents %}
    {{ document.content }}
{% endfor %}
Question: {{ question }}
Answer:
"""
rag_pipe = Pipeline()
rag_pipe.add_component("embedder", SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"))
rag_pipe.add_component("retriever", InMemoryEmbeddingRetriever(document_store=document_store))
rag_pipe.add_component("prompt_builder", PromptBuilder(template=template))
# Note: we use OpenAIGenerator rather than OpenAIChatGenerator because the chat
# generator expects a list of ChatMessage objects and cannot consume the plain
# string produced by prompt_builder.
rag_pipe.add_component("llm", OpenAIGenerator(api_key=Secret.from_env_var("OPENROUTER_API_KEY"),
                                              api_base_url="https://openrouter.ai/api/v1",
                                              model="openai/gpt-4-turbo-preview"))

rag_pipe.connect("embedder.embedding", "retriever.query_embedding")
rag_pipe.connect("retriever", "prompt_builder.documents")
rag_pipe.connect("prompt_builder", "llm")


def rag_pipeline_func(query: str):
    """Answer ``query`` with the RAG pipeline and return ``{"reply": <first LLM reply>}``.

    The same query string is used both as the text to embed for retrieval and
    as the ``question`` rendered into the prompt template.
    """
    result = rag_pipe.run({"embedder": {"text": query}, "prompt_builder": {"question": query}})

    return {"reply": result["llm"]["replies"][0]}

# --------------------------------------------------------------------------------