├── GeminiFunctionCallingWithRAGChromaDb.ipynb ├── GeminiFunctionCalling_Base.ipynb └── LICENSE /GeminiFunctionCallingWithRAGChromaDb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "source": [ 16 | "# Putting together Function Calling and RAG with Gemini\n", 17 | "\n", 18 | "In this notebook, we explore using the Gemini API with Function Calling and Retrieval Augmented Generation (RAG). It shows how to use a ChromaDB-powered vector store as a cache for an external API.\n", 19 | "\n", 20 | "We use the [Frankfurter API](https://www.frankfurter.app/) for fetching exchange rates between currencies." 21 | ], 22 | "metadata": { 23 | "id": "mQJc0rOrd9YP" 24 | } 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": { 29 | "id": "uHK1b5sH1Tu8" 30 | }, 31 | "source": [ 32 | "## Setup\n", 33 | "\n", 34 | "1. Install ChromaDB\n", 35 | "2. Install Google's Generative AI SDK (`google-generativeai`)\n", 36 | "\n", 37 | "This notebook calls Gemini with an API key from Google AI Studio. This differs slightly from calling Gemini models through Vertex AI, as done in `GeminiFunctionCalling_Base.ipynb`." 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 1, 43 | "metadata": { 44 | "id": "I73Zu7da0gpb" 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "%%capture\n", 49 | "!pip install chromadb\n", 50 | "!pip install -U -q google-generativeai" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "source": [ 56 | "import requests\n", 57 | "import google.generativeai as genai\n", 58 | "import chromadb\n", 59 | "from datetime import datetime, timedelta" 60 | ], 61 | "metadata": { 62 | "id": "rWS3LVtIYf8V" 63 | }, 64 | "execution_count": 2, 65 | "outputs": [] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "source": [ 70 | "Configuring the Gemini API client with your API key."
71 | ], 72 | "metadata": { 73 | "id": "LqejWRLrfdZN" 74 | } 75 | }, 76 | { 77 | "cell_type": "code", 78 | "source": [ 79 | "try:\n", 80 | " # Used to securely store your API key\n", 81 | " from google.colab import userdata\n", 82 | "\n", 83 | " # Outside Colab, the except branch reads the GOOGLE_API_KEY environment variable instead.\n", 84 | " GOOGLE_API_KEY=userdata.get('GOOGLE_API_KEY')\n", 85 | "except ImportError:\n", 86 | " import os\n", 87 | " GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']\n", 88 | "\n", 89 | "genai.configure(api_key=GOOGLE_API_KEY)" 90 | ], 91 | "metadata": { 92 | "id": "NBhrIrYo5Ric" 93 | }, 94 | "execution_count": 3, 95 | "outputs": [] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "source": [ 100 | "Instantiating an in-memory ChromaDB client." 101 | ], 102 | "metadata": { 103 | "id": "5GdFghY7fgl_" 104 | } 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "metadata": { 110 | "id": "cFkvhQdE1znl" 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "chroma_client = chromadb.Client()" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "source": [ 120 | "Creating a collection. You can have multiple collections per client. With `chromadb.Client()` they are all stored in memory, so make sure you have enough RAM to hold these documents." 121 | ], 122 | "metadata": { 123 | "id": "92kO_9tufpmU" 124 | } 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 5, 129 | "metadata": { 130 | "id": "ZMabyYgw08Kp" 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "collection = chroma_client.create_collection(name=\"exchange-rates\")" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "source": [ 140 | "## [Optional] Initialize the collection with the latest 5 days of exchange-rate data\n", 141 | "\n", 142 | "Let's get the 5 most recent dates (yesterday and the 4 days before it) for which we'll fetch the initial chunk of data into the DB."
143 | ], 144 | "metadata": { 145 | "id": "GWoUNAE-f0sP" 146 | } 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 6, 151 | "metadata": { 152 | "id": "m2JhWLSH097J" 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "current_date = datetime.now()\n", 157 | "dates = [(current_date - timedelta(days=i)).strftime('%Y-%m-%d') for i in range(1, 6)]" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 7, 163 | "metadata": { 164 | "colab": { 165 | "base_uri": "https://localhost:8080/" 166 | }, 167 | "id": "rZVkHP5p2bS_", 168 | "outputId": "bd6c360c-829f-4c08-ec8a-4cf2f59f3668" 169 | }, 170 | "outputs": [ 171 | { 172 | "output_type": "execute_result", 173 | "data": { 174 | "text/plain": [ 175 | "['2024-05-10', '2024-05-09', '2024-05-08', '2024-05-07', '2024-05-06']" 176 | ] 177 | }, 178 | "metadata": {}, 179 | "execution_count": 7 180 | } 181 | ], 182 | "source": [ 183 | "dates" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 8, 189 | "metadata": { 190 | "id": "1Eg_H3ol3Fgt" 191 | }, 192 | "outputs": [], 193 | "source": [ 194 | "url = \"https://api.frankfurter.app/latest\"\n", 195 | "response = requests.get(url)\n", 196 | "api_data = response.json()\n", 197 | "\n", 198 | "currencies = list(api_data['rates'].keys())" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 9, 204 | "metadata": { 205 | "id": "8AhPIXIR2cBK" 206 | }, 207 | "outputs": [], 208 | "source": [ 209 | "def initial_fetch_conversion_rates(date, currencies):\n", 210 | " rates = {}\n", 211 | " try:\n", 212 | " response = requests.get(f'https://api.frankfurter.app/{date}')\n", 213 | " data = response.json()\n", 214 | " rates = {curr: data['rates'].get(curr, 'N/A') for curr in currencies}\n", 215 | " except Exception as e:\n", 216 | " print(f\"Error fetching conversion rates for {date}: {e}\")\n", 217 | " return rates" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 10, 223 | 
"metadata": { 224 | "id": "A9hgdjD627HU" 225 | }, 226 | "outputs": [], 227 | "source": [ 228 | "documents = []\n", 229 | "metadatas = []\n", 230 | "ids = []\n", 231 | "for date in dates:\n", 232 | " rates = initial_fetch_conversion_rates(date, currencies)\n", 233 | " for currency, rate in rates.items():\n", 234 | " document = f\"EUR to {currency} conversion rate is {rate} on {date}\"\n", 235 | " documents.append(document)\n", 236 | " metadata = {\"base\": \"EUR\", \"target\": currency, \"date\": date, \"rate\": rate}\n", 237 | " id_ = f\"EUR-{currency}-{date.replace('-', '')}\"\n", 238 | " metadatas.append(metadata)\n", 239 | " ids.append(id_)" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 11, 245 | "metadata": { 246 | "id": "KI5cUBjs3MIc" 247 | }, 248 | "outputs": [], 249 | "source": [ 250 | "collection.add(\n", 251 | " metadatas=metadatas,\n", 252 | " documents=documents,\n", 253 | " ids=ids,\n", 254 | ")" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "source": [ 260 | "Quick check of whether our initial fetch of data works!" 
261 | ], 262 | "metadata": { 263 | "id": "G70L-HBPgL7i" 264 | } 265 | }, 266 | { 267 | "cell_type": "code", 268 | "source": [ 269 | "search_query = \"EUR to INR conversion rate on 09-05-2024\"" 270 | ], 271 | "metadata": { 272 | "id": "zoOmlJ-n3pBo" 273 | }, 274 | "execution_count": 12, 275 | "outputs": [] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 13, 280 | "metadata": { 281 | "id": "d7qikJLg3TP1", 282 | "colab": { 283 | "base_uri": "https://localhost:8080/" 284 | }, 285 | "outputId": "65280a65-d8c8-4af2-972c-d6ab617c6f8f" 286 | }, 287 | "outputs": [ 288 | { 289 | "output_type": "execute_result", 290 | "data": { 291 | "text/plain": [ 292 | "{'ids': [['EUR-INR-20240509']],\n", 293 | " 'distances': [[0.08850838989019394]],\n", 294 | " 'metadatas': [[{'base': 'EUR',\n", 295 | " 'date': '2024-05-09',\n", 296 | " 'rate': 89.61,\n", 297 | " 'target': 'INR'}]],\n", 298 | " 'embeddings': None,\n", 299 | " 'documents': [['EUR to INR conversion rate is 89.61 on 2024-05-09']],\n", 300 | " 'uris': None,\n", 301 | " 'data': None}" 302 | ] 303 | }, 304 | "metadata": {}, 305 | "execution_count": 13 306 | } 307 | ], 308 | "source": [ 309 | "results = collection.query(\n", 310 | " query_texts=[search_query], # Chroma will embed this for you\n", 311 | " n_results=1 # how many results to return\n", 312 | ")\n", 313 | "results" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "source": [ 319 | "## Function for fetching a rate directly from the API by date, source currency and target currency\n", 320 | "\n", 321 | "We'll call this function whenever the requested rate is not yet in the vector store. It also upserts the fetched rate into the collection."
322 | ], 323 | "metadata": { 324 | "id": "YKZTwv24gP4r" 325 | } 326 | }, 327 | { 328 | "cell_type": "code", 329 | "source": [ 330 | "def get_exchange_rate_from_api(currency_from, currency_to, currency_date=\"latest\"):\n", 331 | " try:\n", 332 | " response = requests.get(f'https://api.frankfurter.app/{currency_date}', params={\n", 333 | " \"from\": currency_from,\n", 334 | " \"to\": currency_to\n", 335 | " })\n", 336 | " metadata, document, id_ = None, None, None\n", 337 | " data = response.json()\n", 338 | "\n", 339 | " rate = data['rates'][currency_to]\n", 340 | " document = f\"{currency_from} to {currency_to} conversion rate is {rate} on {currency_date}\"\n", 341 | " metadata = {\"base\": currency_from, \"target\": currency_to, \"date\": currency_date, \"rate\": rate}\n", 342 | " id_ = f\"{currency_from}-{currency_to}-{currency_date.replace('-', '')}\"\n", 343 | "\n", 344 | " collection.upsert(\n", 345 | " metadatas=[metadata],\n", 346 | " documents=[document],\n", 347 | " ids=[id_],\n", 348 | " )\n", 349 | "\n", 350 | " print(\"Added document\")\n", 351 | " print(document)\n", 352 | " except Exception as e:\n", 353 | " print(f\"Error fetching conversion rates for {currency_date}: {e}\")\n", 354 | " raise(e)\n", 355 | " return document" 356 | ], 357 | "metadata": { 358 | "id": "1Asd4prK6lht" 359 | }, 360 | "execution_count": 14, 361 | "outputs": [] 362 | }, 363 | { 364 | "cell_type": "markdown", 365 | "source": [ 366 | "Quick check if our function above is working!" 
367 | ], 368 | "metadata": { 369 | "id": "Zo5RUNp5gZrB" 370 | } 371 | }, 372 | { 373 | "cell_type": "code", 374 | "source": [ 375 | "get_exchange_rate_from_api(\"USD\", \"INR\", \"2024-05-10\")" 376 | ], 377 | "metadata": { 378 | "colab": { 379 | "base_uri": "https://localhost:8080/", 380 | "height": 71 381 | }, 382 | "id": "9S-UJXDM-70j", 383 | "outputId": "896b2875-8674-4e3a-8f2e-a9b347205713" 384 | }, 385 | "execution_count": 15, 386 | "outputs": [ 387 | { 388 | "output_type": "stream", 389 | "name": "stdout", 390 | "text": [ 391 | "Added document\n", 392 | "USD to INR conversion rate is 83.51 on 2024-05-10\n" 393 | ] 394 | }, 395 | { 396 | "output_type": "execute_result", 397 | "data": { 398 | "text/plain": [ 399 | "'USD to INR conversion rate is 83.51 on 2024-05-10'" 400 | ], 401 | "application/vnd.google.colaboratory.intrinsic+json": { 402 | "type": "string" 403 | } 404 | }, 405 | "metadata": {}, 406 | "execution_count": 15 407 | } 408 | ] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "source": [ 413 | "## Function that is ready to work with the Gemini API\n", 414 | "\n", 415 | "This function first checks whether the data for that date and currency pair is already in the vector store. If not, it fetches the rate from the API, caches it and returns it.\n", 416 | "\n", 417 | "Note the verbose docstring and type hints: the SDK uses them to build the function declaration, and they help Gemini decide when and how to call the function!" 418 | ], 419 | "metadata": { 420 | "id": "WlIbIQfZgcZD" 421 | } 422 | }, 423 | { 424 | "cell_type": "code", 425 | "source": [ 426 | "def get_exchange_rate(currency_from: str, currency_to: str, currency_date: str = \"latest\"):\n", 427 | " \"\"\"\n", 428 | " This function retrieves the exchange rate between two currencies on a specific date.\n", 429 | "\n", 430 | " Args:\n", 431 | " currency_from (str): The currency to convert from (ISO 4217 format). (Required)\n", 432 | " currency_to (str): The currency to convert to (ISO 4217 format). 
(Required)\n", 433 | " currency_date (str, optional): The date for the exchange rate in YYYY-MM-DD format\n", 434 | " or \"latest\" for the most recent rate. Defaults to \"latest\".\n", 435 | "\n", 436 | " Returns:\n", 437 | " str: A sentence such as \"USD to INR conversion rate is 83.51 on 2024-05-10\"\n", 438 | " (a one-element list of such sentences when served from the vector store).\n", 439 | "\n", 440 | " Raises:\n", 441 | " Exception: Re-raised from get_exchange_rate_from_api if the API call fails.\n", 442 | " \"\"\"\n", 443 | "\n", 444 | " ## First, check if the Vector Store has the data\n", 445 | " search_query = f\"{currency_from} to {currency_to} conversion on {currency_date}\"\n", 446 | "\n", 447 | " results = collection.query(\n", 448 | " query_texts=[search_query], # Chroma will embed this for you\n", 449 | " n_results=1 # how many results to return\n", 450 | " )\n", 451 | "\n", 452 | " if results[\"ids\"][0] and results[\"ids\"][0][0] == f\"{currency_from}-{currency_to}-{currency_date.replace('-', '')}\":\n", 453 | " print(\"Vector Store hit, let's return this.\")\n", 454 | " print(\"DB return: \", results[\"documents\"][0])\n", 455 | " return results[\"documents\"][0]\n", 456 | "\n", 457 | " api_data = get_exchange_rate_from_api(currency_from, currency_to, currency_date)\n", 458 | " print(\"Vector store miss. Fetching data from API and returning!\")\n", 459 | " print(\"API return: \", api_data)\n", 460 | " return api_data" 461 | ], 462 | "metadata": { 463 | "id": "o53em65cXWCd" 464 | }, 465 | "execution_count": 47, 466 | "outputs": [] 467 | }, 468 | { 469 | "cell_type": "markdown", 470 | "source": [ 471 | "Quick check that the function above runs!"
472 | ], 473 | "metadata": { 474 | "id": "n0lP-Hn_g5sC" 475 | } 476 | }, 477 | { 478 | "cell_type": "code", 479 | "source": [ 480 | "get_exchange_rate(\"USD\", \"INR\", \"2024-05-01\")" 481 | ], 482 | "metadata": { 483 | "colab": { 484 | "base_uri": "https://localhost:8080/" 485 | }, 486 | "id": "q0x6AGJAbvy3", 487 | "outputId": "c7afba06-4465-47bb-d7db-f3265ea6b59c" 488 | }, 489 | "execution_count": 48, 490 | "outputs": [ 491 | { 492 | "output_type": "stream", 493 | "name": "stdout", 494 | "text": [ 495 | "Vector Store hit, let's return this.\n", 496 | "DB return: ['USD to INR conversion rate is 83.43 on 2024-05-01']\n" 497 | ] 498 | }, 499 | { 500 | "output_type": "execute_result", 501 | "data": { 502 | "text/plain": [ 503 | "['USD to INR conversion rate is 83.43 on 2024-05-01']" 504 | ] 505 | }, 506 | "metadata": {}, 507 | "execution_count": 48 508 | } 509 | ] 510 | }, 511 | { 512 | "cell_type": "markdown", 513 | "source": [ 514 | "## Declare the model with the function declaration\n", 515 | "\n", 516 | "The SDK converts your Python function into a function declaration (name, parameters and description); it does not send the function's code anywhere!" 517 | ], 518 | "metadata": { 519 | "id": "5adcuF4Ug9Jq" 520 | } 521 | }, 522 | { 523 | "cell_type": "code", 524 | "source": [ 525 | "model = genai.GenerativeModel(model_name='gemini-1.0-pro',\n", 526 | " tools=[get_exchange_rate])" 527 | ], 528 | "metadata": { 529 | "id": "U7R63JidWvS1" 530 | }, 531 | "execution_count": 49, 532 | "outputs": [] 533 | }, 534 | { 535 | "cell_type": "markdown", 536 | "source": [ 537 | "Chat sessions can call functions automatically once you pass `enable_automatic_function_calling=True`: the SDK runs the requested function and sends its result back to the model for you. Let's use it for the oomph factor!"
538 | ], 539 | "metadata": { 540 | "id": "4DhLbe5phGQx" 541 | } 542 | }, 543 | { 544 | "cell_type": "code", 545 | "source": [ 546 | "chat = model.start_chat(enable_automatic_function_calling=True)" 547 | ], 548 | "metadata": { 549 | "id": "3fDAELNMXUU0" 550 | }, 551 | "execution_count": 50, 552 | "outputs": [] 553 | }, 554 | { 555 | "cell_type": "markdown", 556 | "source": [ 557 | "Wrap the chat invocation in a nice neat function!" 558 | ], 559 | "metadata": { 560 | "id": "fKjed8tqhNU3" 561 | } 562 | }, 563 | { 564 | "cell_type": "code", 565 | "source": [ 566 | "def ask_gemini(query):\n", 567 | " response = chat.send_message(query)\n", 568 | " return response.text" 569 | ], 570 | "metadata": { 571 | "id": "getq18lSdlI1" 572 | }, 573 | "execution_count": 51, 574 | "outputs": [] 575 | }, 576 | { 577 | "cell_type": "markdown", 578 | "source": [ 579 | "## Testing" 580 | ], 581 | "metadata": { 582 | "id": "Ml0yMBIqhRzD" 583 | } 584 | }, 585 | { 586 | "cell_type": "markdown", 587 | "source": [ 588 | "Check: 20 USD to INR on 2024-05-10" 589 | ], 590 | "metadata": { 591 | "id": "aEKJrk_bhTEm" 592 | } 593 | }, 594 | { 595 | "cell_type": "code", 596 | "source": [ 597 | "ask_gemini('I have 20 US dollars with me. 
Calculate how much I have in INR as of 10th May, 2024.')" 598 | ], 599 | "metadata": { 600 | "colab": { 601 | "base_uri": "https://localhost:8080/", 602 | "height": 71 603 | }, 604 | "id": "MdXHEUOEWxgd", 605 | "outputId": "10a11994-934d-45b4-9f22-4098d80b701f" 606 | }, 607 | "execution_count": 52, 608 | "outputs": [ 609 | { 610 | "output_type": "stream", 611 | "name": "stdout", 612 | "text": [ 613 | "Vector Store hit, let's return this.\n", 614 | "DB return: ['USD to INR conversion rate is 83.51 on 2024-05-10']\n" 615 | ] 616 | }, 617 | { 618 | "output_type": "execute_result", 619 | "data": { 620 | "text/plain": [ 621 | "'20 US dollars would be equal to 1,670.2 INR as of May 10th, 2024.'" 622 | ], 623 | "application/vnd.google.colaboratory.intrinsic+json": { 624 | "type": "string" 625 | } 626 | }, 627 | "metadata": {}, 628 | "execution_count": 52 629 | } 630 | ] 631 | }, 632 | { 633 | "cell_type": "markdown", 634 | "source": [ 635 | "Check: 20 USD to INR on 2024-05-01" 636 | ], 637 | "metadata": { 638 | "id": "zR6X5uq_hbTK" 639 | } 640 | }, 641 | { 642 | "cell_type": "code", 643 | "source": [ 644 | "ask_gemini('I have 20 US dollars with me. 
Calculate how much I have in INR as of 1st May, 2024?')" 645 | ], 646 | "metadata": { 647 | "colab": { 648 | "base_uri": "https://localhost:8080/", 649 | "height": 71 650 | }, 651 | "id": "om1caMytaZz_", 652 | "outputId": "03efd0c3-1bfd-4087-97c6-05ae4c1e0c05" 653 | }, 654 | "execution_count": 53, 655 | "outputs": [ 656 | { 657 | "output_type": "stream", 658 | "name": "stdout", 659 | "text": [ 660 | "Vector Store hit, let's return this.\n", 661 | "DB return: ['USD to INR conversion rate is 83.43 on 2024-05-01']\n" 662 | ] 663 | }, 664 | { 665 | "output_type": "execute_result", 666 | "data": { 667 | "text/plain": [ 668 | "'20 US dollars would be equal to 1,668.6 INR as of May 1st, 2024.'" 669 | ], 670 | "application/vnd.google.colaboratory.intrinsic+json": { 671 | "type": "string" 672 | } 673 | }, 674 | "metadata": {}, 675 | "execution_count": 53 676 | } 677 | ] 678 | }, 679 | { 680 | "cell_type": "markdown", 681 | "source": [ 682 | "Check: 20 USD to INR on 2024-01-05" 683 | ], 684 | "metadata": { 685 | "id": "DQK791C4hczQ" 686 | } 687 | }, 688 | { 689 | "cell_type": "code", 690 | "source": [ 691 | "ask_gemini('I have 20 US dollars with me. Calculate how much I have in INR as of 5th January, 2024?')" 692 | ], 693 | "metadata": { 694 | "colab": { 695 | "base_uri": "https://localhost:8080/", 696 | "height": 107 697 | }, 698 | "id": "9kNwEl-uciMD", 699 | "outputId": "d5c756b8-6305-46e4-90e6-3769f2c968d2" 700 | }, 701 | "execution_count": 54, 702 | "outputs": [ 703 | { 704 | "output_type": "stream", 705 | "name": "stdout", 706 | "text": [ 707 | "Added document\n", 708 | "USD to INR conversion rate is 83.15 on 2024-01-05\n", 709 | "Vector store miss. 
Fetching data from API and returning!\n", 710 | "API return: USD to INR conversion rate is 83.15 on 2024-01-05\n" 711 | ] 712 | }, 713 | { 714 | "output_type": "execute_result", 715 | "data": { 716 | "text/plain": [ 717 | "'20 US dollars would be equal to 1,663 INR as of January 5th, 2024.'" 718 | ], 719 | "application/vnd.google.colaboratory.intrinsic+json": { 720 | "type": "string" 721 | } 722 | }, 723 | "metadata": {}, 724 | "execution_count": 54 725 | } 726 | ] 727 | }, 728 | { 729 | "cell_type": "code", 730 | "source": [], 731 | "metadata": { 732 | "id": "e7Bb6eyWhosK" 733 | }, 734 | "execution_count": 24, 735 | "outputs": [] 736 | } 737 | ], 738 | "metadata": { 739 | "colab": { 740 | "provenance": [], 741 | "toc_visible": true, 742 | "authorship_tag": "ABX9TyPdOF9X71AfpufvKjC6Qq+c", 743 | "include_colab_link": true 744 | }, 745 | "kernelspec": { 746 | "display_name": "Python 3", 747 | "name": "python3" 748 | }, 749 | "language_info": { 750 | "name": "python" 751 | } 752 | }, 753 | "nbformat": 4, 754 | "nbformat_minor": 0 755 | } -------------------------------------------------------------------------------- /GeminiFunctionCalling_Base.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "authorship_tag": "ABX9TyM49MTHCavJQ7lmQC0fZEAt", 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "source": [ 32 | "# Basic Gemini Function Calling\n", 33 | "\n", 34 | "This notebook contains the code needed to complete the codelab [How to Interact with APIs Using Function 
Calling in Gemini\n", 35 | "](https://codelabs.developers.google.com/codelabs/gemini-function-calling) by [@koverholt](https://github.com/koverholt)." 36 | ], 37 | "metadata": { 38 | "id": "EJOQZY29makV" 39 | } 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "source": [ 44 | "## Step 1. Overview\n", 45 | "This section talks about the importance of Function Calling and how it enables Gemini to access data that may be real-time, protected or otherwise unavailable in the Gemini training datasets." 46 | ], 47 | "metadata": { 48 | "id": "74lqwXX-m-wf" 49 | } 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "source": [ 54 | "## Step 2. Setup and requirements" 55 | ], 56 | "metadata": { 57 | "id": "Q0FxkQ9Nm4DE" 58 | } 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 2, 63 | "metadata": { 64 | "id": "YMqhtV0LT2Ff" 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "%%capture\n", 69 | "!pip install --upgrade google-cloud-aiplatform" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "source": [ 75 | "## Step 3. 
Understand the problem" 76 | ], 77 | "metadata": { 78 | "id": "QLqRh1GenSAy" 79 | } 80 | }, 81 | { 82 | "cell_type": "code", 83 | "source": [ 84 | "import vertexai\n", 85 | "from vertexai.preview.generative_models import GenerativeModel" 86 | ], 87 | "metadata": { 88 | "id": "afMdh8d1T8Ub" 89 | }, 90 | "execution_count": 25, 91 | "outputs": [] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "source": [ 96 | "from google.oauth2 import service_account" 97 | ], 98 | "metadata": { 99 | "id": "Rc4EfQ5jZG6U" 100 | }, 101 | "execution_count": 26, 102 | "outputs": [] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "source": [ 107 | "credentials = service_account.Credentials.from_service_account_file('gcp-adventure-x-a3fb7a36e1e6.json')" 108 | ], 109 | "metadata": { 110 | "id": "qxQxLxscZJwG" 111 | }, 112 | "execution_count": 27, 113 | "outputs": [] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "source": [ 118 | "vertexai.init(project=\"gcp-adventure-x\", location=\"us-central1\", credentials=credentials)" 119 | ], 120 | "metadata": { 121 | "id": "xI9lR5lkUos4" 122 | }, 123 | "execution_count": 28, 124 | "outputs": [] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "source": [ 129 | "model = GenerativeModel(\"gemini-1.0-pro-001\")" 130 | ], 131 | "metadata": { 132 | "id": "tC_UCY5aUOxx" 133 | }, 134 | "execution_count": 29, 135 | "outputs": [] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "source": [ 140 | "response = model.generate_content(\n", 141 | " \"What's the exchange rate for euros to dollars today?\"\n", 142 | ")\n", 143 | "print(response.text)" 144 | ], 145 | "metadata": { 146 | "colab": { 147 | "base_uri": "https://localhost:8080/" 148 | }, 149 | "id": "XQvmgIhgUR9h", 150 | "outputId": "ca335332-2420-4606-8488-89e5ef3dd656" 151 | }, 152 | "execution_count": 30, 153 | "outputs": [ 154 | { 155 | "output_type": "stream", 156 | "name": "stdout", 157 | "text": [ 158 | "I do not have access to real-time information and cannot provide the current exchange 
rate. Please check a currency converter or financial news source for the most up-to-date information.\n" 159 | ] 160 | } 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "source": [ 166 | "## Step 4. Try common workarounds\n", 167 | "Skipped in the interest of gathering ideas from the workshop participants." 168 | ], 169 | "metadata": { 170 | "id": "83gN8MgHnWkb" 171 | } 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "source": [ 176 | "## Step 5. How function calling works\n", 177 | "This section describes the execution flow of Function Calling." 178 | ], 179 | "metadata": { 180 | "id": "_C4PTcngnkM0" 181 | } 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "source": [ 186 | "## Step 6. Choose your API" 187 | ], 188 | "metadata": { 189 | "id": "T0Khd0PGnt1m" 190 | } 191 | }, 192 | { 193 | "cell_type": "code", 194 | "source": [ 195 | "import requests\n", 196 | "url = \"https://api.frankfurter.app/latest\"\n", 197 | "response = requests.get(url)\n", 198 | "data = response.json()\n", 199 | "data" 200 | ], 201 | "metadata": { 202 | "colab": { 203 | "base_uri": "https://localhost:8080/" 204 | }, 205 | "id": "-iHue7JDXJzo", 206 | "outputId": "efab08fc-386e-40dc-cdcd-6c6fc181e743" 207 | }, 208 | "execution_count": 32, 209 | "outputs": [ 210 | { 211 | "output_type": "execute_result", 212 | "data": { 213 | "text/plain": [ 214 | "{'amount': 1.0,\n", 215 | " 'base': 'EUR',\n", 216 | " 'date': '2024-03-15',\n", 217 | " 'rates': {'AUD': 1.6579,\n", 218 | " 'BGN': 1.9558,\n", 219 | " 'BRL': 5.4461,\n", 220 | " 'CAD': 1.4731,\n", 221 | " 'CHF': 0.9613,\n", 222 | " 'CNY': 7.838,\n", 223 | " 'CZK': 25.166,\n", 224 | " 'DKK': 7.4571,\n", 225 | " 'GBP': 0.8541,\n", 226 | " 'HKD': 8.5199,\n", 227 | " 'HUF': 393.2,\n", 228 | " 'IDR': 17011,\n", 229 | " 'ILS': 3.9811,\n", 230 | " 'INR': 90.26,\n", 231 | " 'ISK': 148.9,\n", 232 | " 'JPY': 162.03,\n", 233 | " 'KRW': 1448.71,\n", 234 | " 'MXN': 18.1915,\n", 235 | " 'MYR': 5.1241,\n", 236 | " 'NOK': 
11.5205,\n", 237 | " 'NZD': 1.786,\n", 238 | " 'PHP': 60.494,\n", 239 | " 'PLN': 4.2953,\n", 240 | " 'RON': 4.9711,\n", 241 | " 'SEK': 11.2674,\n", 242 | " 'SGD': 1.4562,\n", 243 | " 'THB': 39.053,\n", 244 | " 'TRY': 35.092,\n", 245 | " 'USD': 1.0892,\n", 246 | " 'ZAR': 20.352}}" 247 | ] 248 | }, 249 | "metadata": {}, 250 | "execution_count": 32 251 | } 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "source": [ 257 | "## Step 7. Define a function and tool" 258 | ], 259 | "metadata": { 260 | "id": "Ntis3zdqnx-s" 261 | } 262 | }, 263 | { 264 | "cell_type": "code", 265 | "source": [ 266 | "from vertexai.preview.generative_models import (\n", 267 | " Content,\n", 268 | " FunctionDeclaration,\n", 269 | " GenerativeModel,\n", 270 | " Part,\n", 271 | " Tool,\n", 272 | ")\n", 273 | "\n", 274 | "model = GenerativeModel(\"gemini-1.0-pro-001\")" 275 | ], 276 | "metadata": { 277 | "id": "OECV1_BFbbAY" 278 | }, 279 | "execution_count": 33, 280 | "outputs": [] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "source": [ 285 | "get_exchange_rate_func = FunctionDeclaration(\n", 286 | " name=\"get_exchange_rate\",\n", 287 | " description=\"Get the exchange rate for currencies between countries\",\n", 288 | " parameters={\n", 289 | " \"type\": \"object\",\n", 290 | " \"properties\": {\n", 291 | " \"currency_date\": {\n", 292 | " \"type\": \"string\",\n", 293 | " \"description\": \"A date that must always be in YYYY-MM-DD format or the value 'latest' if a time period is not specified\"\n", 294 | " },\n", 295 | " \"currency_from\": {\n", 296 | " \"type\": \"string\",\n", 297 | " \"description\": \"The currency to convert from in ISO 4217 format\"\n", 298 | " },\n", 299 | " \"currency_to\": {\n", 300 | " \"type\": \"string\",\n", 301 | " \"description\": \"The currency to convert to in ISO 4217 format\"\n", 302 | " }\n", 303 | " },\n", 304 | " \"required\": [\n", 305 | " \"currency_from\",\n", 306 | " \"currency_date\",\n", 307 | " ]\n", 308 | " },\n", 309 | ")" 310 | 
], 311 | "metadata": { 312 | "id": "dW3THZYfb3xJ" 313 | }, 314 | "execution_count": 34, 315 | "outputs": [] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "source": [ 320 | "exchange_rate_tool = Tool(\n", 321 | " function_declarations=[get_exchange_rate_func],\n", 322 | ")" 323 | ], 324 | "metadata": { 325 | "id": "mRCGq9UxcNjY" 326 | }, 327 | "execution_count": 35, 328 | "outputs": [] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "source": [ 333 | "## Step 8. Generate a function call" 334 | ], 335 | "metadata": { 336 | "id": "vL07Qym0n1dV" 337 | } 338 | }, 339 | { 340 | "cell_type": "code", 341 | "source": [ 342 | "prompt = \"\"\"What is the latest exchange rate from Australian dollars to Swedish krona?\n", 343 | "How much is 500 Australian dollars worth in Swedish krona?\"\"\"\n", 344 | "\n", 345 | "response = model.generate_content(\n", 346 | " prompt,\n", 347 | " tools=[exchange_rate_tool],\n", 348 | ")" 349 | ], 350 | "metadata": { 351 | "id": "qUKr6ryDcTvl" 352 | }, 353 | "execution_count": 57, 354 | "outputs": [] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "source": [ 359 | "print(response.candidates[0].content)" 360 | ], 361 | "metadata": { 362 | "colab": { 363 | "base_uri": "https://localhost:8080/" 364 | }, 365 | "id": "D-xZ3rsbcXp6", 366 | "outputId": "d3fff990-c388-43ea-e40d-e393ba9b1bdf" 367 | }, 368 | "execution_count": 59, 369 | "outputs": [ 370 | { 371 | "output_type": "stream", 372 | "name": "stdout", 373 | "text": [ 374 | "role: \"model\"\n", 375 | "parts {\n", 376 | " function_call {\n", 377 | " name: \"get_exchange_rate\"\n", 378 | " args {\n", 379 | " fields {\n", 380 | " key: \"currency_date\"\n", 381 | " value {\n", 382 | " string_value: \"latest\"\n", 383 | " }\n", 384 | " }\n", 385 | " fields {\n", 386 | " key: \"currency_from\"\n", 387 | " value {\n", 388 | " string_value: \"AUD\"\n", 389 | " }\n", 390 | " }\n", 391 | " fields {\n", 392 | " key: \"currency_to\"\n", 393 | " value {\n", 394 | " string_value: \"SEK\"\n", 395 
| " }\n", 396 | " }\n", 397 | " }\n", 398 | " }\n", 399 | "}\n", 400 | "\n" 401 | ] 402 | } 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "source": [ 408 | "## Step 9. Make an API request" 409 | ], 410 | "metadata": { 411 | "id": "BIzZ906Bn6L8" 412 | } 413 | }, 414 | { 415 | "cell_type": "code", 416 | "source": [ 417 | "params = {}\n", 418 | "for key, value in response.candidates[0].content.parts[0].function_call.args.items():\n", 419 | " params[key[9:]] = value\n", 420 | "params" 421 | ], 422 | "metadata": { 423 | "colab": { 424 | "base_uri": "https://localhost:8080/" 425 | }, 426 | "id": "YuflEdH6caCY", 427 | "outputId": "02b4fa48-e0d8-456e-e9a1-16a1e1de8ab3" 428 | }, 429 | "execution_count": 60, 430 | "outputs": [ 431 | { 432 | "output_type": "execute_result", 433 | "data": { 434 | "text/plain": [ 435 | "{'from': 'AUD', 'date': 'latest', 'to': 'SEK'}" 436 | ] 437 | }, 438 | "metadata": {}, 439 | "execution_count": 60 440 | } 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "source": [ 446 | "import requests\n", 447 | "url = f\"https://api.frankfurter.app/{params['date']}\"\n", 448 | "api_response = requests.get(url, params=params)\n", 449 | "api_response.json()" 450 | ], 451 | "metadata": { 452 | "colab": { 453 | "base_uri": "https://localhost:8080/" 454 | }, 455 | "id": "B_M0seAGgZ7u", 456 | "outputId": "dc1b47df-9765-4277-9d79-1f34799949e8" 457 | }, 458 | "execution_count": 62, 459 | "outputs": [ 460 | { 461 | "output_type": "execute_result", 462 | "data": { 463 | "text/plain": [ 464 | "{'amount': 1.0, 'base': 'AUD', 'date': '2024-03-15', 'rates': {'SEK': 6.7962}}" 465 | ] 466 | }, 467 | "metadata": {}, 468 | "execution_count": 62 469 | } 470 | ] 471 | }, 472 | { 473 | "cell_type": "markdown", 474 | "source": [ 475 | "## Step 10. 
Generate a response" 476 | ], 477 | "metadata": { 478 | "id": "X0_w2t1Mn-F9" 479 | } 480 | }, 481 | { 482 | "cell_type": "code", 483 | "source": [ 484 | "response = model.generate_content(\n", 485 | " [\n", 486 | " Content(role=\"user\", parts=[\n", 487 | " Part.from_text(prompt + \"\"\"Give your answer in steps with lots of detail\n", 488 | " and context, including the exchange rate and date.\"\"\"),\n", 489 | " ]),\n", 490 | " Content(role=\"function\", parts=[\n", 491 | " Part.from_dict({\n", 492 | " \"function_call\": {\n", 493 | " \"name\": \"get_exchange_rate\",\n", 494 | " }\n", 495 | " })\n", 496 | " ]),\n", 497 | " Content(role=\"function\", parts=[\n", 498 | " Part.from_function_response(\n", 499 | " name=\"get_exchange_rate\",\n", 500 | " response={\n", 501 | " \"content\": api_response.text,\n", 502 | " }\n", 503 | " )\n", 504 | " ]),\n", 505 | " ],\n", 506 | " tools=[exchange_rate_tool],\n", 507 | ")\n", 508 | "\n", 509 | "\n", 510 | "response.candidates[0].content.parts[0].text" 511 | ], 512 | "metadata": { 513 | "colab": { 514 | "base_uri": "https://localhost:8080/", 515 | "height": 35 516 | }, 517 | "id": "a2Fgz7mkgcgY", 518 | "outputId": "0bb0686a-f69f-49af-cc97-b3b8a2b36c19" 519 | }, 520 | "execution_count": 63, 521 | "outputs": [ 522 | { 523 | "output_type": "execute_result", 524 | "data": { 525 | "text/plain": [ 526 | "'500 Australian dollars is worth 3398.1 Swedish krona as of 2024-03-15. The exchange rate is 1 AUD = 6.7962 SEK.'" 527 | ], 528 | "application/vnd.google.colaboratory.intrinsic+json": { 529 | "type": "string" 530 | } 531 | }, 532 | "metadata": {}, 533 | "execution_count": 63 534 | } 535 | ] 536 | }, 537 | { 538 | "cell_type": "markdown", 539 | "source": [ 540 | "## Conclusion\n", 541 | "\n", 542 | "Explore further:\n", 543 | "\n", 544 | "\n", 545 | "1. [Gemini Pro API Reference](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini#gemini-pro)\n", 546 | "2. 
[Sample Function Calling Notebook](https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/function-calling/intro_function_calling.ipynb)\n", 547 | "\n" 548 | ], 549 | "metadata": { 550 | "id": "ZleqMsPxoJpx" 551 | } 552 | }, 553 | { 554 | "cell_type": "code", 555 | "source": [], 556 | "metadata": { 557 | "id": "ADJFJ6acog0u" 558 | }, 559 | "execution_count": null, 560 | "outputs": [] 561 | } 562 | ] 563 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Anubhav Singh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | --------------------------------------------------------------------------------
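In `GeminiFunctionCallingWithRAGChromaDb.ipynb`, `get_exchange_rate` detects a cache hit by running a semantic `collection.query(..., n_results=1)` and then comparing the top result's id with the expected `FROM-TO-YYYYMMDD` id. Because that id is fully determined by the function arguments, an exact lookup by id could replace the similarity search and skip embedding the query. The sketch below assumes Chroma's `Collection.get(ids=[...])`, which returns flat `ids` and `documents` lists and simply returns empty lists for unknown ids. The helper names `cache_id` and `lookup_cached_rate` are hypothetical, and the collection is passed in rather than read from a global.

```python
from typing import Optional


def cache_id(currency_from: str, currency_to: str, currency_date: str) -> str:
    """Build the id format the notebook uses, e.g. 'USD-INR-20240510'."""
    return f"{currency_from}-{currency_to}-{currency_date.replace('-', '')}"


def lookup_cached_rate(collection, currency_from: str, currency_to: str,
                       currency_date: str) -> Optional[str]:
    """Return the cached sentence for this exact pair and date, or None on a miss.

    `collection` is assumed to be a Chroma collection, or any object whose
    `get(ids=[...])` returns a dict with flat 'ids' and 'documents' lists.
    Unlike a similarity query, an exact id lookup never returns a "nearby"
    document for a different date or currency pair.
    """
    result = collection.get(ids=[cache_id(currency_from, currency_to, currency_date)])
    if result["ids"]:
        return result["documents"][0]
    return None
```

Inside `get_exchange_rate`, a `None` result would then fall through to `get_exchange_rate_from_api`, exactly like the notebook's miss path.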
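`get_exchange_rate_from_api` in `GeminiFunctionCallingWithRAGChromaDb.ipynb` stores each rate under the date string it was called with. When that string is `"latest"`, the cached id becomes something like `USD-INR-latest`, which keeps being served after the rate changes. The Frankfurter response includes the date the rate applies to (for example `'date': '2024-03-15'` in the Step 6 output of `GeminiFunctionCalling_Base.ipynb`), so one option is to build the cache entry from that field. The helper below is a hypothetical sketch of that idea: it only builds the `(document, metadata, id)` triple in the notebook's format and leaves the `collection.upsert` call to the caller.

```python
def build_cache_entry(api_json: dict, currency_from: str, currency_to: str):
    """Turn a Frankfurter response into (document, metadata, id) for Chroma.

    Uses the date reported by the API rather than the requested date, so a
    request for "latest" is cached under the date the rate actually applies to.
    """
    rate = api_json["rates"][currency_to]
    rate_date = api_json["date"]  # e.g. "2024-03-15"
    document = f"{currency_from} to {currency_to} conversion rate is {rate} on {rate_date}"
    metadata = {"base": currency_from, "target": currency_to, "date": rate_date, "rate": rate}
    id_ = f"{currency_from}-{currency_to}-{rate_date.replace('-', '')}"
    return document, metadata, id_
```

The caller would pass `[document]`, `[metadata]` and `[id_]` to `collection.upsert`, as the notebook does. A lookup for `"latest"` still has to go to the API, since the applicable date is only known from the response.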
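Step 9 of `GeminiFunctionCalling_Base.ipynb` turns the model's function-call arguments into query parameters with `params[key[9:]] = value`, which works because every parameter name starts with the 9-character prefix `currency_`. A slightly more explicit sketch follows; the helper name `function_call_to_request` is hypothetical, it needs Python 3.9+ for `str.removeprefix`, and it assumes the arguments have already been turned into a plain dict (for example by iterating `.items()` as Step 9 does). It also moves the date into the URL path instead of additionally sending it as a query parameter.

```python
FRANKFURTER_BASE_URL = "https://api.frankfurter.app"


def function_call_to_request(args: dict) -> tuple:
    """Map get_exchange_rate arguments to a Frankfurter URL and query params.

    `args` is a plain dict such as
    {"currency_date": "latest", "currency_from": "AUD", "currency_to": "SEK"}.
    The date becomes the URL path segment; from/to become query parameters.
    """
    params = {key.removeprefix("currency_"): value for key, value in args.items()}
    # Fall back to "latest" if the model omitted the date.
    date = params.pop("date", "latest")
    return f"{FRANKFURTER_BASE_URL}/{date}", params
```

`requests.get(url, params=params)` with the returned pair would then make the same request as Step 9.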
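In the testing cells of both notebooks, Gemini multiplies the amount by the returned rate itself: 20 × 83.51 = 1,670.2 INR, 20 × 83.43 = 1,668.6 INR, 20 × 83.15 = 1,663 INR, and 500 × 6.7962 = 3,398.1 SEK. Those answers are correct, but a language model is not guaranteed to get arithmetic right. One option, sketched here as an assumption rather than something either notebook does, is to compute the converted amount in Python with exact decimal arithmetic (the helper name `convert_amount` is hypothetical).

```python
from decimal import Decimal, ROUND_HALF_UP


def convert_amount(amount, rate, places: int = 2) -> Decimal:
    """Multiply an amount by an exchange rate using exact decimal arithmetic.

    Values go through str() first, so a float rate such as 83.51 from the
    API JSON becomes Decimal("83.51") rather than its binary approximation.
    """
    exact = Decimal(str(amount)) * Decimal(str(rate))
    return exact.quantize(Decimal(10) ** -places, rounding=ROUND_HALF_UP)
```

For example, `convert_amount(20, 83.51)` gives `Decimal("1670.20")`. Whether to expose such a helper to the model as a second tool or to post-process its answer is a design choice the notebooks leave open.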