├── LICENSE
├── Mixtral8x7B_AI_Chat_with_Tools.ipynb
└── README.md

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Will Spagnoli
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/Mixtral8x7B_AI_Chat_with_Tools.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "# Enhancing AI Chat with Mistral AI's Mixtral8x7B Large Language Model\n",
7 | "\n",
8 | "## Overview\n",
9 | "This notebook is your guide to leveraging Mistral AI's advanced Mixtral8x7B [Mixtral of Experts](https://mistral.ai/news/mixtral-of-experts/) Large Language Model. It utilizes the [Huggingface transformers](https://huggingface.co/docs/transformers/index) package to provide an interactive and powerful AI assistant.\n",
10 | "\n",
11 | "Inspired by the comprehensive [Pinecone Tutorial](https://www.pinecone.io/learn/mixtral-8x7b/), this setup includes 4-bit/8-bit quantization and Flash Attention 2 for improved efficiency, plus multi-step chat logic for a seamless interaction experience.\n",
12 | "\n",
13 | "### Tools\n",
14 | "\n",
15 | "- **Calculator**: A feature for the Mixtral AI Assistant to execute mathematical calculations. **Note**: This uses the `exec()` function, allowing execution of arbitrary Python code, which could pose security risks in environments like Google Colab.\n",
16 | "- **Search**: This enables the AI Assistant to perform real-time web searches using DuckDuckGo.\n",
17 | "\n"
18 | ],
19 | "metadata": {
20 | "id": "ZFaOHFn4lQCO"
21 | }
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "source": [
26 | "# Setup Instructions\n",
27 | "\n",
28 | "## 1) GPU Runtime Selection\n",
29 | "In Google Colab, go to Runtime > Change runtime type > Select A100 GPU (if available). For the free version of Google Colab, where A100 might not be available, opt for a 'High-RAM' Runtime with 4-bit quantization (`load_in_4_bit=True`). Note: This has been tested on an A100 GPU using 4-bit quantization and Flash Attention 2 (`use_flash_attention_2=True`). For more powerful hardware, use the 'Connect to a custom GCE VM' option.\n",
30 | "\n",
31 | "## 2) Installing Dependencies\n",
32 | "Execute the following cell to install the necessary packages. After installation, the runtime will restart automatically. Do not rerun the installation cell; proceed to step 3 for downloading and setting up the Mixtral model.\n",
33 | "\n"
34 | ],
35 | "metadata": {
36 | "id": "TPVUrcqKkvuG"
37 | }
38 | },
39 | {
40 | "cell_type": "code",
41 | "source": [
42 | "# @title Install Huggingface & dependencies\n",
43 | "\n",
44 | "from google.colab import output\n",
45 | "import os\n",
46 | "\n",
47 | "# Install tooling from the Pinecone tutorial plus additional dependencies\n",
48 | "!pip install -qU transformers==4.36.1 accelerate==0.25.0 duckduckgo_search==4.1.0\n",
49 | "# For 4-bit & 8-bit quantization\n",
50 | "!pip install -U bitsandbytes\n",
51 | "# For Flash Attention 2\n",
52 | "!pip install flash-attn --no-build-isolation\n",
53 | "output.clear()\n",
54 | "\n",
55 | "print(\"Dependencies installed successfully. Restarting Runtime...\")\n",
56 | "os.kill(os.getpid(), 9)\n"
57 | ],
58 | "metadata": {
59 | "cellView": "form",
60 | "id": "1y1cdhhiky5-"
61 | },
62 | "execution_count": null,
63 | "outputs": []
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "source": [
68 | "# 3) Download and Initialize the Mixtral Model\n",
69 | "\n",
70 | "Set up the Mixtral model by selecting your preferences below, then run the '**Download & Initialize Model**' cell. Remember to run this cell only once to avoid time-consuming reloads. To start chat sessions, use the 'Run AI Chat' cell.\n",
71 | "\n",
72 | "### Model Settings\n",
73 | "- `use_instruct_model: bool = True`: Enables the instruct-finetuned Mixtral-8x7B-Instruct-v0.1 model. Set to `False` for the base model, but note this notebook is optimized for the Instruct model. For base model usage, modify `sys_message: str` in the 'Run Mixtral8x7B AI Chat with Tools' cell of Step 4 with your text-completion prompt.\n",
74 | "\n",
75 | "- `load_in_4_bit: bool = True`: Activates 4-bit quantization for reduced memory use and faster inference.\n",
76 | "\n",
77 | "- `load_in_8_bit: bool = False`: Enables 8-bit quantization. If both 4-bit and 8-bit are true, 4-bit takes precedence.\n",
78 | "\n",
79 | "- `use_flash_attention_2: bool = True`: Utilizes Flash Attention 2 for faster inference.\n",
80 | "\n",
81 | "---\n",
82 | "### Text-Generation Arguments\n",
83 | "Default settings for the Huggingface `transformers.pipeline()` are recommended unless you have specific requirements. For detailed information, refer to the [Huggingface Pipelines Documentation](https://huggingface.co/docs/transformers/main_classes/pipelines).\n",
84 | "- `temperature: float = 0.1`: Controls output randomness. Range: 0.0 (min) to 1.0 (max).\n",
85 | "- `top_p: float = 0.15`: Samples from the smallest set of top tokens whose cumulative probability reaches `top_p`.\n",
86 | "- `top_k: int = 0`: Samples from the top `top_k` tokens. A value of zero disables top-k and relies on `top_p`.\n",
87 | "- `do_sample: bool = True`: Enables sampling instead of greedy decoding; required for `temperature`, `top_p`, and `top_k` to take effect.\n",
88 | "- `max_new_tokens: int = 512`: Limits the number of generated tokens per response.\n",
89 | "- `repetition_penalty: float = 1.1`: Discourages repetitive text. Increase if repetition occurs.\n",
90 | "\n"
91 | ],
92 | "metadata": {
93 | "id": "AQ5nkCe8h35L"
94 | }
95 | },
96 | {
97 | "cell_type": "code",
98 | "source": [
99 | "# @title Download & Initialize Mixtral8x7B Model\n",
100 | "\n",
101 | "# Model Settings\n",
102 | "use_instruct_model = True # @param {type:\"boolean\"}\n",
103 | "load_in_4_bit = True # @param {type:\"boolean\"}\n",
104 | "load_in_8_bit = False # @param {type:\"boolean\"}\n",
105 | "use_flash_attention_2 = True # @param {type:\"boolean\"}\n",
106 | "# Huggingface transformers.pipeline() Args\n",
107 | "temperature = 0.1 # @param {type:\"number\"}\n",
108 | "top_p = 0.15 # @param {type:\"number\"}\n",
109 | "top_k = 0 # @param {type:\"integer\"}\n",
110 | "do_sample = True # @param {type:\"boolean\"}\n",
111 | "max_new_tokens = 512 # @param {type:\"integer\"}\n",
112 | "repetition_penalty = 1.1 # @param {type:\"number\"}\n",
113 | "\n",
114 | "\n",
115 | "import torch\n",
116 | "import transformers\n",
117 | "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
118 | "import json\n",
119 | "import datetime\n",
120 | "from duckduckgo_search import DDGS\n",
121 | "import os\n",
122 | "from google.colab import output\n",
123 | "import logging\n",
124 | "import io\n",
125 | "\n",
126 | "# Set the logging level to suppress warnings\n",
127 | "logging.getLogger('transformers').setLevel(logging.ERROR)\n",
128 | "logging.getLogger('bitsandbytes').setLevel(logging.ERROR)\n",
129 | "\n",
130 | "\n",
131 | "# Select Model Version\n",
132 | "if use_instruct_model:\n",
133 | "    model_id = \"mistralai/Mixtral-8x7B-Instruct-v0.1\"\n",
134 | "else:\n",
135 | "    model_id = \"mistralai/Mixtral-8x7B-v0.1\"\n",
136 | "\n",
137 | "# Load Tokenizer\n",
138 | "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
139 | "if load_in_4_bit:\n",
140 | "    if use_flash_attention_2:\n",
141 | "        # Load model in 4-bit precision with Flash Attention 2\n",
142 | "        model = AutoModelForCausalLM.from_pretrained(\n",
143 | "            pretrained_model_name_or_path = model_id,\n",
144 | "            load_in_4bit=True,\n",
145 | "            attn_implementation=\"flash_attention_2\"\n",
146 | "        )\n",
147 | "    else:\n",
148 | "        # Load model in 4-bit precision\n",
149 | "        model = AutoModelForCausalLM.from_pretrained(\n",
150 | "            pretrained_model_name_or_path = model_id,\n",
151 | "            load_in_4bit=True\n",
152 | "        )\n",
153 | "\n",
154 | "elif load_in_8_bit:\n",
155 | "    if use_flash_attention_2:\n",
156 | "        # Load model in 8-bit precision with Flash Attention 2\n",
157 | "        model = AutoModelForCausalLM.from_pretrained(\n",
158 | "            pretrained_model_name_or_path = model_id,\n",
159 | "            load_in_8bit=True,\n",
160 | "            attn_implementation=\"flash_attention_2\"\n",
161 | "        )\n",
162 | "    else:\n",
163 | "        # Load model in 8-bit precision\n",
164 | "        model = AutoModelForCausalLM.from_pretrained(\n",
165 | "            pretrained_model_name_or_path = model_id,\n",
166 | "            load_in_8bit=True\n",
167 | "        )\n",
168 | "\n",
169 | "else:\n",
170 | "    if use_flash_attention_2:\n",
171 | "        # Load model in full precision with Flash Attention 2\n",
172 | "        model = AutoModelForCausalLM.from_pretrained(\n",
173 | "            pretrained_model_name_or_path = model_id,\n",
174 | "            attn_implementation=\"flash_attention_2\"\n",
175 | "        )\n",
176 | "    else:\n",
177 | "        # Load model in full precision\n",
178 | "        model = AutoModelForCausalLM.from_pretrained(\n",
179 | "            pretrained_model_name_or_path = model_id\n",
180 | "        )\n",
181 | "\n",
182 | "\n",
183 | "\n",
184 | "# Create Huggingface Text Generation Pipeline\n",
185 | "generate_text = transformers.pipeline(\n",
186 | "    model=model,\n",
187 | "    tokenizer=tokenizer,\n",
188 | "    return_full_text=False,  # if using langchain set True\n",
189 | "    task=\"text-generation\",\n",
190 | "    # we pass model parameters here too\n",
191 | "    temperature=temperature,  # 'randomness' of outputs: 0.0 is the min and 1.0 the max\n",
192 | "    top_p=top_p,  # sample from top tokens whose cumulative probability reaches top_p (here 15%)\n",
193 | "    top_k=top_k,  # sample from the top_k tokens (0 disables top-k and relies on top_p)\n",
194 | "    do_sample=do_sample,  # enable sampling; required for temperature/top_p/top_k to take effect\n",
195 | "    max_new_tokens=max_new_tokens,  # max number of tokens to generate in the output\n",
196 | "    repetition_penalty=repetition_penalty  # increase if output begins repeating\n",
197 | ")\n",
198 | "\n",
199 | "\n",
200 | "\n",
201 | "# Prints debug_message only if the \"DEBUG\" env variable is set to \"TRUE\"\n",
202 | "def print_debug_message(debug_message: str):\n",
203 | "    if os.environ.get(\"DEBUG\") and os.environ.get(\"DEBUG\").upper() == \"TRUE\":\n",
204 | "        print(f\"\\n\\n\\n##### Debug Message ######\\n {debug_message}\\n##### End Debug Message ######\\n\\n\\n\")\n",
205 | "\n",
206 | "# Sanitize the generated_text string for cases where the model returned\n",
207 | "# its response message directly instead of the intended action dict JSON\n",
208 | "# string. In that case we must build the action dict JSON string manually,\n",
209 | "# so any newlines and other control characters are removed from the\n",
210 | "# message to ensure the resulting JSON is parsable\n",
211 | "def sanitize_text_for_json(text: str):\n",
212 | "    # Strip whitespace and control characters from both ends of the text\n",
213 | "    sanitized_text = text.strip()\n",
214 | "\n",
215 | "    # Replace internal control characters like newlines with spaces\n",
216 | "    sanitized_text = sanitized_text.replace(\"\\n\", \" \").replace(\"\\r\", \" \").replace(\"\\t\", \" \")\n",
217 | "\n",
218 | "    return sanitized_text\n",
219 | "\n",
220 | "# Set up the first user query formatted with the initial system prompt (tool\n",
221 | "# use instructions are included in the default system prompt)\n",
222 | "def first_prompt_instruction_format(query: str, sys_message: str = None):\n",
223 | "\n",
224 | "    if sys_message is None:\n",
225 | "        sys_message = \"\"\"You are a helpful AI assistant, an agent capable of using a variety of tools to answer a question. Here are a few of the tools available to you:\n",
226 | "\n",
227 | "    - Calculator: the calculator should be used whenever you need to perform a calculation, no matter how simple. It uses Python, so make sure to write the complete Python code required to perform the calculation and make sure the code assigns your answer to the `output` variable.\n",
228 | "    - Search: the search tool should be used whenever you need to find information. It can be used to find information about almost anything.\n",
229 | "    - Final Answer: the final answer tool must be used to respond to the user. You must use this when you have decided on an answer.\n",
230 | "\n",
231 | "    To use these tools you must always respond in JSON format containing `\"tool_name\"` and `\"input\"` key-value pairs. For example, to answer the question, \"what is the square root of 51?\" you must use the calculator tool like so:\n",
232 | "\n",
233 | "    ```json\n",
234 | "    {\n",
235 | "        \"tool_name\": \"Calculator\",\n",
236 | "        \"input\": \"from math import sqrt; output = sqrt(51)\"\n",
237 | "    }\n",
238 | "    ```\n",
239 | "\n",
240 | "    Or to answer the question \"who is the current president of the USA?\" you must respond:\n",
241 | "\n",
242 | "    ```json\n",
243 | "    {\n",
244 | "        \"tool_name\": \"Search\",\n",
245 | "        \"input\": \"current president of USA\"\n",
246 | "    }\n",
247 | "    ```\n",
248 | "\n",
249 | "    Remember, even when answering the user, you must still use this JSON format! If you'd like to ask how the user is doing, you must write:\n",
250 | "\n",
251 | "    ```json\n",
252 | "    {\n",
253 | "        \"tool_name\": \"Final Answer\",\n",
254 | "        \"input\": \"How are you today?\"\n",
255 | "    }\n",
256 | "    ```\n",
257 | "\n",
258 | "    Let's get started. The user's query is as follows.\n",
259 | "    \"\"\"\n",
260 | "    # note: don't append anything (e.g. a newline) after 'Assistant: '\n",
261 | "\n",
262 | "    return f' [INST] {sys_message} [/INST]\\nUser: {query}\\nAssistant: '#```json\\n{{\\n\"tool_name\": '\n",
263 | "\n",
264 | "\n",
265 | "def extract_first_json(text: str, debug: bool = False):\n",
266 | "    # Find the first and last braces to extract the JSON object\n",
267 | "    # This will be more robust against irregular formatting\n",
268 | "    start_index = text.find('{')\n",
269 | "    end_index = text.rfind('}')  # Get the last closing brace\n",
270 | "\n",
271 | "    if start_index == -1 or end_index == -1 or end_index < start_index:\n",
272 | "        return None  # Return None if valid JSON braces are not found\n",
273 | "\n",
274 | "    # Extract the substring that forms the JSON object\n",
275 | "    json_str = text[start_index:end_index + 1]\n",
276 | "\n",
277 | "    # Replace newline characters and other potential issues\n",
278 | "    #json_str = json_str.replace('\\n', '\\\\n').replace('\\r', '\\\\r').replace('\\t', '\\\\t')\n",
279 | "\n",
280 | "    return json_str\n",
281 | "\n",
282 | "def format_output(text: str):\n",
283 | "    print_debug_message(f\"format_output(): initial text: {text}\\n\")\n",
284 | "    full_json_str = extract_first_json(text)\n",
285 | "\n",
286 | "    if full_json_str is None:\n",
287 | "        print_debug_message(f\"format_output(): No valid JSON found calling extract_first_json() with text - {text}\")\n",
288 | "        return None\n",
289 | "\n",
290 | "    print_debug_message(f\"format_output(): full_json_str after extract_first_json() call - {full_json_str}\\n\")\n",
291 | "\n",
292 | "    try:\n",
293 | "        return json.loads(full_json_str)\n",
294 | "    except json.JSONDecodeError as e:\n",
295 | "        print_debug_message(f\"format_output(): Error decoding JSON from json.loads(full_json_str) with text - {text}\\nand full_json_str - {full_json_str}\\nError Message - {e}\")\n",
296 | "        return None  # Handle the error as needed\n",
297 | "\n",
298 | "\n",
299 | "\n",
300 | "\n",
301 | "# Processes the action dict created by the format_output() function to execute\n",
302 | "# the selected tool or provide the AI's final response based on the value of\n",
303 | "# action[\"tool_name\"] (if the tool name isn't recognized, it is assumed to be a\n",
304 | "# \"Final Answer\" action)\n",
305 | "def use_tool(action: dict):\n",
306 | "    tool_name = action[\"tool_name\"]\n",
307 | "    if tool_name == \"Final Answer\":\n",
308 | "        is_tool_response = False\n",
309 | "        return f\"\\n\\nAI Assistant: {action['input']}\", is_tool_response\n",
310 | "    elif tool_name == \"Calculator\":\n",
311 | "        print(\"\\nUsing Calculator...\\n\")\n",
312 | "        # Create a dictionary to serve as a local namespace for exec\n",
313 | "        local_namespace = {}\n",
314 | "\n",
315 | "        # Execute the code within the local namespace\n",
316 | "        exec(action[\"input\"], {}, local_namespace)\n",
317 | "\n",
318 | "        # Access the value of 'output' from the local namespace\n",
319 | "        exec_result = local_namespace.get('output', None)\n",
320 | "        is_tool_response = True\n",
321 | "        return f\"\\n\\nTool Output: {exec_result}\", is_tool_response\n",
322 | "    elif tool_name == \"Search\":\n",
323 | "        print(f\"\\nSearching the Web for {action['input']}...\\n\")\n",
324 | "        contexts = []\n",
325 | "        with DDGS() as ddgs:\n",
326 | "            results = ddgs.text(\n",
327 | "                action[\"input\"],\n",
328 | "                region=\"wt-wt\", safesearch=\"on\",\n",
329 | "                max_results=3\n",
330 | "            )\n",
331 | "            for r in results:\n",
332 | "                contexts.append(r['body'])\n",
333 | "        info = \"\\n---\\n\".join(contexts)\n",
334 | "        is_tool_response = True\n",
335 | "        return f\"\\n\\nTool Output: {info}\", is_tool_response\n",
336 | "    else:\n",
337 | "        # otherwise just assume it's a final answer\n",
338 | "        is_tool_response = False\n",
339 | "        return f\"\\n\\nAI Assistant: {action['input']}\", is_tool_response\n",
340 | "\n",
341 | "# Takes a full input_prompt (with prior conversation history included),\n",
342 | "# queries the Mixtral model, processes the text-generation response, and\n",
343 | "# iteratively executes Tool Usage until a \"Final Answer\" action is received.\n",
344 | "# Then, returns both the response_message to be displayed and the\n",
345 | "# new input_prompt with all new messages appended\n",
346 | "def handle_message(input_prompt):\n",
347 | "\n",
348 | "    response = generate_text(input_prompt)\n",
349 | "    generated_text = response[0]['generated_text']\n",
350 | "    print_debug_message(f\"handle_message(): initial generated_text: {generated_text}\\n\")\n",
351 | "\n",
352 | "    # If it fails to load json, the model probably just responded\n",
353 | "    # directly without the dict, so build a dict out of it\n",
354 | "    # in a try/except clause\n",
355 | "    try:\n",
356 | "        action = format_output(generated_text)\n",
357 | "        if action is None:\n",
358 | "            # if action is None, the formatting was wrong, meaning the model probably didn't\n",
359 | "            # wrap its response message in an action dict, so raise an error to let the except\n",
360 | "            # block reformat it into a \"Final Answer\" action dict json string\n",
361 | "            raise Exception(\"Failed to parse json from generated_text to create action dict.\")\n",
362 | "        response_message, is_tool_response = use_tool(action)\n",
363 | "    except Exception as e:\n",
364 | "\n",
365 | "        # If formatting the output and using the tool fails, then the model\n",
366 | "        # probably didn't use the action dict json string formatting, so just\n",
367 | "        # assume its message was intended as a direct message to the user\n",
368 | "        # and reformat it as a \"Final Answer\" action\n",
369 | "        print_debug_message(f\"handle_message(): Exception Block triggered for first text-generation's format_output()/use_tool() call! Exception - {e}\")\n",
370 | "\n",
371 | "        # Since the model probably sent a direct response message rather than formatting it as\n",
372 | "        # an action dict json string, we need to remove any newlines, whitespace, etc.\n",
373 | "        # to ensure the new action dict input value is json parsable in the new_generated_text\n",
374 | "        # json string\n",
375 | "        generated_text = sanitize_text_for_json(text = generated_text)\n",
376 | "        # Now format the response into a valid action dict json string\n",
377 | "        new_generated_text = \"\"\"\n",
378 | "        ```json\n",
379 | "        {\n",
380 | "            \"tool_name\": \"Final Answer\",\n",
381 | "            \"input\": \"\"\" + f\"\\\"{generated_text}\\\"\" + \"\"\"\n",
382 | "        }\n",
383 | "        ```\n",
384 | "        \"\"\"\n",
385 | "        generated_text = new_generated_text\n",
386 | "        action = format_output(generated_text)\n",
387 | "        response_message, is_tool_response = use_tool(action)\n",
388 | "    # Add Initial Assistant Response to the prompt\n",
389 | "    input_prompt += generated_text\n",
390 | "\n",
391 | "    # If the response is a tool response, add the response message\n",
392 | "    # to the prompt and query again; return the full new\n",
393 | "    # input_prompt along with the response_message to be displayed\n",
394 | "    # to the user\n",
395 | "    while is_tool_response:\n",
396 | "        input_prompt += response_message + \"\\n\\nAI Assistant: \"\n",
397 | "        response = generate_text(input_prompt)\n",
398 | "        generated_text = response[0]['generated_text']\n",
399 | "        print_debug_message(f\"handle_message(): while is_tool_response generated_text : {generated_text}\")\n",
400 | "\n",
401 | "        # If it fails to load json, the model probably just responded\n",
402 | "        # directly without the dict, so build a dict out of it\n",
403 | "        # in a try/except clause\n",
404 | "        try:\n",
405 | "            action = format_output(generated_text)\n",
406 | "            if action is None:\n",
407 | "                # if action is None, the formatting was wrong, meaning the model probably didn't\n",
408 | "                # wrap its response message in an action dict, so raise an error to let the except\n",
409 | "                # block reformat it into a \"Final Answer\" action dict json string\n",
410 | "                raise Exception(\"Failed to parse json from generated_text to create action dict.\")\n",
411 | "            response_message, is_tool_response = use_tool(action)\n",
412 | "        except Exception as e:\n",
413 | "\n",
414 | "            # If formatting the output and using the tool fails, then the model\n",
415 | "            # probably didn't use the action dict json string formatting, so just\n",
416 | "            # assume its message was intended as a direct message to the user\n",
417 | "            # and reformat it as a \"Final Answer\" action\n",
418 | "            print_debug_message(f\"handle_message(): Exception Block triggered for 'while is_tool_response' loop format_output()/use_tool() call! Exception - {e}\")\n",
419 | "\n",
420 | "            # Since the model probably sent a direct response message rather than formatting it as\n",
421 | "            # an action dict json string, we need to remove any newlines, whitespace, etc.\n",
422 | "            # to ensure the new action dict input value is json parsable in the new_generated_text\n",
423 | "            # json string\n",
424 | "            generated_text = sanitize_text_for_json(text = generated_text)\n",
425 | "            # Now format the response into a valid action dict json string\n",
426 | "            new_generated_text = \"\"\"\n",
427 | "            ```json\n",
428 | "            {\n",
429 | "                \"tool_name\": \"Final Answer\",\n",
430 | "                \"input\": \"\"\" + f\"\\\"{generated_text}\\\"\" + \"\"\"\n",
431 | "            }\n",
432 | "            ```\n",
433 | "            \"\"\"\n",
434 | "            generated_text = new_generated_text\n",
435 | "            action = format_output(generated_text)\n",
436 | "            response_message, is_tool_response = use_tool(action)\n",
437 | "        # Add the new is_tool_response loop generated_text to the prompt\n",
438 | "        input_prompt += generated_text\n",
439 | "\n",
440 | "\n",
441 | "\n",
442 | "    return response_message, input_prompt\n",
443 | "\n",
444 | "# Initializes the AI Chat, orchestrating user inputs and multi-step chat logic\n",
445 | "def run_ai_chat(sys_message: str = None, max_user_messages: int = 10, chat_log_dir: str = \"logs/\"):\n",
446 | "\n",
447 | "    # Create Chat Log Filepath to save convo history\n",
448 | "    now = datetime.datetime.now()\n",
449 | "    formatted_datetime = now.strftime(\"%Y-%m-%d_%H-%M-%S\")\n",
450 | "\n",
451 | "    if not os.path.exists(chat_log_dir):\n",
452 | "        os.makedirs(chat_log_dir)\n",
453 | "    chat_log_file = f\"{chat_log_dir}User_Conversation_{formatted_datetime}.txt\"\n",
454 | "\n",
455 | "    # Begin AI Assistant Chat\n",
456 | "    output.clear()\n",
457 | "    print(\"\\n\\nBeginning AI Assistant Chat. Type 'exit' at any time to end the chat\\n\\n\\n\")\n",
458 | "    print(\"\\n\\n\")\n",
459 | "    user_message = input(\"User: \")\n",
460 | "    print(\"\\n\\n\")\n",
461 | "    input_prompt = first_prompt_instruction_format(user_message, sys_message = sys_message)\n",
462 | "    response_message, input_prompt = handle_message(\n",
463 | "        input_prompt\n",
464 | "    )\n",
465 | "\n",
466 | "    print(response_message)\n",
467 | "    print(\"\\n\\n\")\n",
468 | "    # Save Chat Log\n",
469 | "    with open(chat_log_file, 'w', encoding='utf-8') as file:\n",
470 | "        file.write(input_prompt)\n",
471 | "    # Start at 1, not 0, because the user already sent the first message\n",
472 | "    for i in range(1, max_user_messages):\n",
473 | "        print(\"\\n\\n\")\n",
474 | "        user_message = input(\"User: \")\n",
475 | "        print(\"\\n\\n\")\n",
476 | "        input_prompt += \"\\nUser: \" + user_message\n",
477 | "\n",
478 | "        if user_message.upper() == 'EXIT':\n",
479 | "            input_prompt += \"\\n\\n\\nChat Ended. Have a great day! (:\"\n",
480 | "            print(\"\\n\\n\\nChat Ended. Have a great day! (:\\n\\n\")\n",
481 | "            # Save Chat Log\n",
482 | "            with open(chat_log_file, 'w', encoding='utf-8') as file:\n",
483 | "                file.write(input_prompt)\n",
484 | "            return input_prompt\n",
485 | "        # (user message already appended above, before the exit check)\n",
486 | "        response_message, input_prompt = handle_message(\n",
487 | "            input_prompt\n",
488 | "        )\n",
489 | "\n",
490 | "        print(response_message)\n",
491 | "        print(\"\\n\\n\")\n",
492 | "        # Save Chat Log\n",
493 | "        with open(chat_log_file, 'w', encoding='utf-8') as file:\n",
494 | "            file.write(input_prompt)\n",
495 | "\n",
496 | "    print(\"\\n\\n\\nChat Ended. Have a great day! (:\")\n",
497 | "    # Save Chat Log\n",
498 | "    with open(chat_log_file, 'w', encoding='utf-8') as file:\n",
499 | "        file.write(input_prompt)\n",
500 | "    return input_prompt\n",
501 | "\n",
502 | "\n",
503 | "\n",
504 | "print(\"\\n\\n\\nMixtral-8x7B Model Loaded Successfully!\")\n",
505 | "\n",
506 | "\n",
507 | "\n"
508 | ],
509 | "metadata": {
510 | "cellView": "form",
511 | "id": "tcFSb5rejKzj"
512 | },
513 | "execution_count": null,
514 | "outputs": []
515 | },
516 | {
517 | "cell_type": "markdown",
518 | "source": [
519 | "# 4) Run Mixtral AI Chat\n",
520 | "\n",
521 | "Once the Mixtral model is downloaded and initialized, execute the cell below to begin your chat session.\n",
522 | "\n",
523 | "### Chat Session Settings\n",
524 | "- `max_user_messages: int = 20`: Sets the limit for user messages in a single AI Chat session. This helps keep conversations within the model's maximum sequence length.\n",
525 | "- `debug_mode: bool = False`: Toggle this to `True` to enable the display of debug messages.\n"
526 | ],
527 | "metadata": {
528 | "id": "rfVLoMBRxlpp"
529 | }
530 | },
531 | {
532 | "cell_type": "code",
533 | "source": [
534 | "# @title Run Mixtral8x7B AI Chat with Tools\n",
535 | "\n",
536 | "# How many total messages the user can send in a given AI Chat session\n",
537 | "# before it automatically ends the conversation. This prevents\n",
538 | "# conversations from becoming too long to fit within the model's maximum\n",
539 | "# sequence length\n",
540 | "max_user_messages = 20 # @param {type:\"integer\"}\n",
541 | "debug_mode = False # @param {type:\"boolean\"}\n",
542 | "\n",
543 | "# Set sys_message to None to use the default tool-enabled system prompt (the default\n",
544 | "# system prompt is located in the first_prompt_instruction_format() function\n",
545 | "# of the \"Download & Initialize Mixtral8x7B Model\" cell above)\n",
546 | "sys_message = None  # None selects the default system prompt\n",
547 | "\n",
548 | "# Set the \"DEBUG\" env variable to toggle print_debug_message() function printing\n",
549 | "if debug_mode:\n",
550 | "    os.environ['DEBUG'] = 'TRUE'\n",
551 | "else:\n",
552 | "    os.environ['DEBUG'] = 'FALSE'\n",
553 | "\n",
554 | "# Creates a chat_log_dir directory to save a .txt file of the full\n",
555 | "# chat conversation for each AI Chat session (including the system prompt,\n",
556 | "# Tool Calls, Tool Outputs, etc.) to help with debugging\n",
557 | "chat_log_dir = \"logs/\"\n",
558 | "\n",
559 | "# Set the logging level to suppress warnings\n",
560 | "logging.getLogger('transformers').setLevel(logging.ERROR)\n",
561 | "logging.getLogger('bitsandbytes').setLevel(logging.ERROR)\n",
562 | "\n",
563 | "input_prompt = run_ai_chat(\n",
564 | "    sys_message = sys_message,\n",
565 | "    max_user_messages = max_user_messages,\n",
566 | "    chat_log_dir = chat_log_dir\n",
567 | ")\n"
568 | ],
569 | "metadata": {
570 | "cellView": "form",
571 | "id": "SBUFYBDuxcgw"
572 | },
573 | "execution_count": null,
574 | "outputs": []
575 | }
576 | ],
577 | "metadata": {
578 | "accelerator": "GPU",
579 | "colab": {
580 | "provenance": [],
581 | "private_outputs": true,
582 | "gpuType": "A100",
583 | "machine_shape": "hm",
584 | "collapsed_sections": [
585 | "TPVUrcqKkvuG",
586 | "AQ5nkCe8h35L",
587 | "rfVLoMBRxlpp"
588 | ]
589 | },
590 | "kernelspec": {
591 | "display_name": "Python 3",
592 | "name": "python3"
593 | },
594 | "language_info": {
595 | "name": "python"
596 | }
597 | },
598 | "nbformat": 4,
599 | "nbformat_minor": 0
600 | }
601 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mixtral8x7B-AI-Chat-Colab
2 | Download, Initialization, Tooling, and Chat Session Logic for Mistral AI's Mixtral8x7B "Mixtral of Experts" model in Google Colab
3 | 
4 | [![Run in Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/willspag/Mixtral8x7B-AI-Chat-Colab/blob/main/Mixtral8x7B_AI_Chat_with_Tools.ipynb)
5 | 
6 | ## This notebook downloads and initializes Mistral AI's Mixtral8x7B [Mixtral of Experts](https://mistral.ai/news/mixtral-of-experts/) using the [Huggingface transformers](https://huggingface.co/docs/transformers/index) package, along with facilitating Tool Usage and Chat Session Logic.
7 | 
8 | The code is based on this [Pinecone Tutorial](https://www.pinecone.io/learn/mixtral-8x7b/) and extends it to include additional inference/memory improvements (4-bit/8-bit quantization and Flash Attention 2) along with full multi-step chat logic. Please check out their tutorial for a step-by-step walkthrough of the implementation.
9 | 
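10 | ### How Tool Calls Work
11 | 
12 | The assistant drives its tools through a JSON "action" object naming a tool and its input. The notebook extracts the first JSON object from each model generation, dispatches on `tool_name`, appends any Tool Output back into the prompt, and loops until a "Final Answer" action is produced. For example, the default system prompt instructs the model to answer "what is the square root of 51?" with:
13 | 
14 | ```json
15 | {
16 |   "tool_name": "Calculator",
17 |   "input": "from math import sqrt; output = sqrt(51)"
18 | }
19 | ```
20 | 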
21 | ### Tools
22 | 
23 | - "Calculator" - Enables the Mixtral AI Assistant to perform math calculations by providing Python code, which is then executed using exec() - **WARNING**: Using exec() to execute arbitrary Python code can be dangerous, as it can execute any Python command. This might lead to security vulnerabilities, especially in a web-based environment like Google Colab. This notebook does not currently apply any sanitization, filtering, or other safety measures to the code generated by the Mixtral model before execution, so please be cautious and use at your own risk (a possible mitigation is sketched below this list).
24 | - "Search" - Enables the Mixtral AI Assistant to search the web for real-time information using DuckDuckGo
25 | 
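26 | ### Safer Calculator Execution (Sketch)
27 | 
28 | As a possible mitigation for the exec() risk above (not implemented in this notebook), Calculator code could be run against a restricted namespace, as in the minimal sketch below. The `safe_calculate` helper and its token blocklist are hypothetical illustrations, not a real sandbox: blocklists can be bypassed, and removing builtins also breaks `import` statements, so the system prompt's Calculator instructions would need adjusting (e.g. to use the pre-exposed `math` module instead of imports).
29 | 
30 | ```python
31 | import math
32 | 
33 | # Hypothetical guard (illustrative only): reject obviously dangerous tokens.
34 | BLOCKED_TOKENS = ("import", "open(", "exec", "eval", "__")
35 | 
36 | def safe_calculate(code: str):
37 |     """Run model-generated Calculator code with a few basic guards."""
38 |     if any(token in code for token in BLOCKED_TOKENS):
39 |         raise ValueError("Calculator code rejected by safety filter")
40 |     # Expose only the math module; builtins are removed.
41 |     local_namespace = {}
42 |     exec(code, {"__builtins__": {}, "math": math}, local_namespace)
43 |     return local_namespace.get("output", None)
44 | 
45 | print(safe_calculate("output = math.sqrt(51)"))  # 7.1414...
46 | ```
47 | 
--------------------------------------------------------------------------------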