├── GeminiFunctionCallingWithRAGChromaDb.ipynb
├── GeminiFunctionCalling_Base.ipynb
└── LICENSE
/GeminiFunctionCallingWithRAGChromaDb.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github",
7 | "colab_type": "text"
8 | },
9 | "source": [
10 | ""
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "source": [
16 | "# Putting together Function Calling and RAGs with Gemini\n",
17 | "\n",
18 | "In this notebook, we combine the Gemini API's function calling with Retrieval Augmented Generation (RAG), using a ChromaDB-powered vector store as a cache in front of an external API.\n",
19 | "\n",
20 | "We use the [Frankfurter API](https://www.frankfurter.app/) for fetching exchange rates between currencies."
21 | ],
22 | "metadata": {
23 | "id": "mQJc0rOrd9YP"
24 | }
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {
29 | "id": "uHK1b5sH1Tu8"
30 | },
31 | "source": [
32 | "## Setup\n",
33 | "\n",
34 | "1. Install ChromaDB\n",
35 | "2. Install the Google Generative AI Python SDK (`google-generativeai`)\n",
36 | "\n",
37 | "This notebook calls Gemini through the Google AI Studio API with an API key, which differs slightly from calling Gemini models through Vertex AI."
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 1,
43 | "metadata": {
44 | "id": "I73Zu7da0gpb"
45 | },
46 | "outputs": [],
47 | "source": [
48 | "%%capture\n",
49 | "!pip install chromadb\n",
50 | "!pip install -U -q google-generativeai"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "source": [
56 | "import requests\n",
57 | "import google.generativeai as genai\n",
58 | "import chromadb\n",
59 | "from datetime import datetime, timedelta"
60 | ],
61 | "metadata": {
62 | "id": "rWS3LVtIYf8V"
63 | },
64 | "execution_count": 2,
65 | "outputs": []
66 | },
67 | {
68 | "cell_type": "markdown",
69 | "source": [
70 | "Configure the Gemini API client with your API key."
71 | ],
72 | "metadata": {
73 | "id": "LqejWRLrfdZN"
74 | }
75 | },
76 | {
77 | "cell_type": "code",
78 | "source": [
79 | "try:\n",
80 | "    # In Colab, read the key from Colab Secrets\n",
81 | "    from google.colab import userdata\n",
82 | "\n",
83 | "    GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')\n",
84 | "except ImportError:\n",
85 | "    # Outside Colab, read the key from the GOOGLE_API_KEY environment variable\n",
86 | "    import os\n",
87 | "    GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']\n",
88 | "\n",
89 | "genai.configure(api_key=GOOGLE_API_KEY)"
90 | ],
91 | "metadata": {
92 | "id": "NBhrIrYo5Ric"
93 | },
94 | "execution_count": 3,
95 | "outputs": []
96 | },
97 | {
98 | "cell_type": "markdown",
99 | "source": [
100 | "Create an in-memory ChromaDB client."
101 | ],
102 | "metadata": {
103 | "id": "5GdFghY7fgl_"
104 | }
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 4,
109 | "metadata": {
110 | "id": "cFkvhQdE1znl"
111 | },
112 | "outputs": [],
113 | "source": [
114 | "chroma_client = chromadb.Client()"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "source": [
120 | "Create a collection (a client can hold multiple collections). With `chromadb.Client()` the data lives in memory only, so make sure you have enough RAM to hold your documents."
121 | ],
122 | "metadata": {
123 | "id": "92kO_9tufpmU"
124 | }
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 5,
129 | "metadata": {
130 | "id": "ZMabyYgw08Kp"
131 | },
132 | "outputs": [],
133 | "source": [
134 | "collection = chroma_client.create_collection(name=\"exchange-rates\")"
135 | ]
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "source": [
140 | "## [Optional] Initialize the collection with the last 5 days of exchange rates\n",
141 | "\n",
142 | "First, compute the five dates before today. We'll load the EUR-based rates for these dates into the DB as the initial data."
143 | ],
144 | "metadata": {
145 | "id": "GWoUNAE-f0sP"
146 | }
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": 6,
151 | "metadata": {
152 | "id": "m2JhWLSH097J"
153 | },
154 | "outputs": [],
155 | "source": [
156 | "current_date = datetime.now()\n",
157 | "dates = [(current_date - timedelta(days=i)).strftime('%Y-%m-%d') for i in range(1, 6)]"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 7,
163 | "metadata": {
164 | "colab": {
165 | "base_uri": "https://localhost:8080/"
166 | },
167 | "id": "rZVkHP5p2bS_",
168 | "outputId": "bd6c360c-829f-4c08-ec8a-4cf2f59f3668"
169 | },
170 | "outputs": [
171 | {
172 | "output_type": "execute_result",
173 | "data": {
174 | "text/plain": [
175 | "['2024-05-10', '2024-05-09', '2024-05-08', '2024-05-07', '2024-05-06']"
176 | ]
177 | },
178 | "metadata": {},
179 | "execution_count": 7
180 | }
181 | ],
182 | "source": [
183 | "dates"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": 8,
189 | "metadata": {
190 | "id": "1Eg_H3ol3Fgt"
191 | },
192 | "outputs": [],
193 | "source": [
194 | "url = \"https://api.frankfurter.app/latest\"\n",
195 | "response = requests.get(url)\n",
196 | "api_data = response.json()\n",
197 | "\n",
198 | "currencies = list(api_data['rates'].keys())"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 9,
204 | "metadata": {
205 | "id": "8AhPIXIR2cBK"
206 | },
207 | "outputs": [],
208 | "source": [
209 | "def initial_fetch_conversion_rates(date, currencies):\n",
210 | "    rates = {}\n",
211 | "    try:\n",
212 | "        response = requests.get(f'https://api.frankfurter.app/{date}')\n",
213 | "        data = response.json()\n",
214 | "        rates = {curr: data['rates'].get(curr, 'N/A') for curr in currencies}\n",
215 | "    except Exception as e:\n",
216 | "        print(f\"Error fetching conversion rates for {date}: {e}\")\n",
217 | "    return rates"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": 10,
223 | "metadata": {
224 | "id": "A9hgdjD627HU"
225 | },
226 | "outputs": [],
227 | "source": [
228 | "documents = []\n",
229 | "metadatas = []\n",
230 | "ids = []\n",
231 | "for date in dates:\n",
232 | "    rates = initial_fetch_conversion_rates(date, currencies)\n",
233 | "    for currency, rate in rates.items():\n",
234 | "        document = f\"EUR to {currency} conversion rate is {rate} on {date}\"\n",
235 | "        documents.append(document)\n",
236 | "        metadata = {\"base\": \"EUR\", \"target\": currency, \"date\": date, \"rate\": rate}\n",
237 | "        id_ = f\"EUR-{currency}-{date.replace('-', '')}\"\n",
238 | "        metadatas.append(metadata)\n",
239 | "        ids.append(id_)"
240 | ]
241 | },
242 | {
243 | "cell_type": "code",
244 | "execution_count": 11,
245 | "metadata": {
246 | "id": "KI5cUBjs3MIc"
247 | },
248 | "outputs": [],
249 | "source": [
250 | "collection.add(\n",
251 | " metadatas=metadatas,\n",
252 | " documents=documents,\n",
253 | " ids=ids,\n",
254 | ")"
255 | ]
256 | },
257 | {
258 | "cell_type": "markdown",
259 | "source": [
260 | "Quick check that the initial load worked:"
261 | ],
262 | "metadata": {
263 | "id": "G70L-HBPgL7i"
264 | }
265 | },
266 | {
267 | "cell_type": "code",
268 | "source": [
269 | "search_query = \"EUR to INR conversion rate on 09-05-2024\""
270 | ],
271 | "metadata": {
272 | "id": "zoOmlJ-n3pBo"
273 | },
274 | "execution_count": 12,
275 | "outputs": []
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": 13,
280 | "metadata": {
281 | "id": "d7qikJLg3TP1",
282 | "colab": {
283 | "base_uri": "https://localhost:8080/"
284 | },
285 | "outputId": "65280a65-d8c8-4af2-972c-d6ab617c6f8f"
286 | },
287 | "outputs": [
288 | {
289 | "output_type": "execute_result",
290 | "data": {
291 | "text/plain": [
292 | "{'ids': [['EUR-INR-20240509']],\n",
293 | " 'distances': [[0.08850838989019394]],\n",
294 | " 'metadatas': [[{'base': 'EUR',\n",
295 | " 'date': '2024-05-09',\n",
296 | " 'rate': 89.61,\n",
297 | " 'target': 'INR'}]],\n",
298 | " 'embeddings': None,\n",
299 | " 'documents': [['EUR to INR conversion rate is 89.61 on 2024-05-09']],\n",
300 | " 'uris': None,\n",
301 | " 'data': None}"
302 | ]
303 | },
304 | "metadata": {},
305 | "execution_count": 13
306 | }
307 | ],
308 | "source": [
309 | "results = collection.query(\n",
310 | " query_texts=[search_query], # Chroma will embed this for you\n",
311 | " n_results=1 # how many results to return\n",
312 | ")\n",
313 | "results"
314 | ]
315 | },
316 | {
317 | "cell_type": "markdown",
318 | "source": [
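319 | "## Fetch a rate directly from the API by date and currency pair\n",
320 | "\n",
321 | "We call this function whenever the rate is not already in the vector store. It also upserts the fetched rate into the collection, so later requests for the same pair and date can be served from the cache."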
322 | ],
323 | "metadata": {
324 | "id": "YKZTwv24gP4r"
325 | }
326 | },
327 | {
328 | "cell_type": "code",
329 | "source": [
330 | "def get_exchange_rate_from_api(currency_from, currency_to, currency_date=\"latest\"):\n",
331 | "    try:\n",
332 | "        response = requests.get(f'https://api.frankfurter.app/{currency_date}', params={\n",
333 | "            \"from\": currency_from,\n",
334 | "            \"to\": currency_to\n",
335 | "        })\n",
336 | "        response.raise_for_status()  # fail loudly on HTTP errors\n",
337 | "        data = response.json()\n",
338 | "        # Store the date the API reports, never the literal \"latest\"\n",
339 | "        rate, rate_date = data['rates'][currency_to], data['date']\n",
340 | "        document = f\"{currency_from} to {currency_to} conversion rate is {rate} on {rate_date}\"\n",
341 | "        metadata = {\"base\": currency_from, \"target\": currency_to, \"date\": rate_date, \"rate\": rate}\n",
342 | "        id_ = f\"{currency_from}-{currency_to}-{rate_date.replace('-', '')}\"\n",
343 | "        # Upsert so that re-fetching the same pair and date overwrites the entry\n",
344 | "        collection.upsert(\n",
345 | "            metadatas=[metadata],\n",
346 | "            documents=[document],\n",
347 | "            ids=[id_],\n",
348 | "        )\n",
349 | "\n",
350 | "        print(\"Added document\")\n",
351 | "        print(document)\n",
352 | "    except Exception as e:\n",
353 | "        print(f\"Error fetching conversion rates for {currency_date}: {e}\")\n",
354 | "        raise\n",
355 | "    return document"
356 | ],
357 | "metadata": {
358 | "id": "1Asd4prK6lht"
359 | },
360 | "execution_count": 14,
361 | "outputs": []
362 | },
363 | {
364 | "cell_type": "markdown",
365 | "source": [
366 | "Quick check that the function works:"
367 | ],
368 | "metadata": {
369 | "id": "Zo5RUNp5gZrB"
370 | }
371 | },
372 | {
373 | "cell_type": "code",
374 | "source": [
375 | "get_exchange_rate_from_api(\"USD\", \"INR\", \"2024-05-10\")"
376 | ],
377 | "metadata": {
378 | "colab": {
379 | "base_uri": "https://localhost:8080/",
380 | "height": 71
381 | },
382 | "id": "9S-UJXDM-70j",
383 | "outputId": "896b2875-8674-4e3a-8f2e-a9b347205713"
384 | },
385 | "execution_count": 15,
386 | "outputs": [
387 | {
388 | "output_type": "stream",
389 | "name": "stdout",
390 | "text": [
391 | "Added document\n",
392 | "USD to INR conversion rate is 83.51 on 2024-05-10\n"
393 | ]
394 | },
395 | {
396 | "output_type": "execute_result",
397 | "data": {
398 | "text/plain": [
399 | "'USD to INR conversion rate is 83.51 on 2024-05-10'"
400 | ],
401 | "application/vnd.google.colaboratory.intrinsic+json": {
402 | "type": "string"
403 | }
404 | },
405 | "metadata": {},
406 | "execution_count": 15
407 | }
408 | ]
409 | },
410 | {
411 | "cell_type": "markdown",
412 | "source": [
413 | "## Function for use with the Gemini API\n",
414 | "\n",
415 | "This function first checks whether the vector store already holds the rate for the requested date and currency pair. If not, it fetches the rate from the API and returns it.\n",
416 | "\n",
417 | "Note the detailed type hints and docstring: the SDK builds the function declaration from them, and they help Gemini decide when and how to call the function."
418 | ],
419 | "metadata": {
420 | "id": "WlIbIQfZgcZD"
421 | }
422 | },
423 | {
424 | "cell_type": "code",
425 | "source": [
426 | "def get_exchange_rate(currency_from: str, currency_to: str, currency_date: str = \"latest\"):\n",
427 | "    \"\"\"\n",
428 | "    This function retrieves the exchange rate between two currencies on a specific date.\n",
429 | "\n",
430 | "    Args:\n",
431 | "        currency_from (str): The currency to convert from (ISO 4217 format). (Required)\n",
432 | "        currency_to (str): The currency to convert to (ISO 4217 format). (Required)\n",
433 | "        currency_date (str, optional): The date for the exchange rate in YYYY-MM-DD format\n",
434 | "            or \"latest\" for the most recent rate. Defaults to \"latest\".\n",
435 | "\n",
436 | "    Returns:\n",
437 | "        The exchange rate as text: a sentence stating the rate (units of currency_to\n",
438 | "        per unit of currency_from) and the date it applies to.\n",
439 | "\n",
440 | "    Raises:\n",
441 | "        Exception: Propagated from get_exchange_rate_from_api if the API request fails.\n",
442 | "    \"\"\"\n",
443 | "\n",
444 | "    # First, check whether the vector store already has the data\n",
445 | "    search_query = f\"{currency_from} to {currency_to} conversion on {currency_date}\"\n",
446 | "    expected_id = f\"{currency_from}-{currency_to}-{currency_date.replace('-', '')}\"\n",
447 | "    results = collection.query(\n",
448 | "        query_texts=[search_query],  # Chroma will embed this for you\n",
449 | "        n_results=1  # how many results to return\n",
450 | "    )\n",
451 | "    # A semantic match is not enough: only an exact ID match counts as a hit\n",
452 | "    if results[\"ids\"][0] and results[\"ids\"][0][0] == expected_id:\n",
453 | "        print(\"Vector Store hit, let's return this.\")\n",
454 | "        print(\"DB return: \", results[\"documents\"][0])\n",
455 | "        return results[\"documents\"][0]\n",
456 | "\n",
457 | "    api_data = get_exchange_rate_from_api(currency_from, currency_to, currency_date)\n",
458 | "    print(\"Vector store miss. Fetching data from API and returning!\")\n",
459 | "    print(\"API return: \", api_data)\n",
460 | "    return api_data"
461 | ],
462 | "metadata": {
463 | "id": "o53em65cXWCd"
464 | },
465 | "execution_count": 47,
466 | "outputs": []
467 | },
468 | {
469 | "cell_type": "markdown",
470 | "source": [
471 | "Quick check of the function above:"
472 | ],
473 | "metadata": {
474 | "id": "n0lP-Hn_g5sC"
475 | }
476 | },
477 | {
478 | "cell_type": "code",
479 | "source": [
480 | "get_exchange_rate(\"USD\", \"INR\", \"2024-05-01\")"
481 | ],
482 | "metadata": {
483 | "colab": {
484 | "base_uri": "https://localhost:8080/"
485 | },
486 | "id": "q0x6AGJAbvy3",
487 | "outputId": "c7afba06-4465-47bb-d7db-f3265ea6b59c"
488 | },
489 | "execution_count": 48,
490 | "outputs": [
491 | {
492 | "output_type": "stream",
493 | "name": "stdout",
494 | "text": [
495 | "Vector Store hit, let's return this.\n",
496 | "DB return: ['USD to INR conversion rate is 83.43 on 2024-05-01']\n"
497 | ]
498 | },
499 | {
500 | "output_type": "execute_result",
501 | "data": {
502 | "text/plain": [
503 | "['USD to INR conversion rate is 83.43 on 2024-05-01']"
504 | ]
505 | },
506 | "metadata": {},
507 | "execution_count": 48
508 | }
509 | ]
510 | },
511 | {
512 | "cell_type": "markdown",
513 | "source": [
514 | "## Declare the model with the function declaration\n",
515 | "\n",
516 | "The SDK converts your function into a function declaration; it does not send the function's code anywhere."
517 | ],
518 | "metadata": {
519 | "id": "5adcuF4Ug9Jq"
520 | }
521 | },
522 | {
523 | "cell_type": "code",
524 | "source": [
525 | "model = genai.GenerativeModel(model_name='gemini-1.0-pro',\n",
526 | " tools=[get_exchange_rate])"
527 | ],
528 | "metadata": {
529 | "id": "U7R63JidWvS1"
530 | },
531 | "execution_count": 49,
532 | "outputs": []
533 | },
534 | {
535 | "cell_type": "markdown",
536 | "source": [
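537 | "Chat sessions support automatic function calling: with `enable_automatic_function_calling=True`, the SDK runs the function Gemini requests and sends the result back to the model for you. Let's use it."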
538 | ],
539 | "metadata": {
540 | "id": "4DhLbe5phGQx"
541 | }
542 | },
543 | {
544 | "cell_type": "code",
545 | "source": [
546 | "chat = model.start_chat(enable_automatic_function_calling=True)"
547 | ],
548 | "metadata": {
549 | "id": "3fDAELNMXUU0"
550 | },
551 | "execution_count": 50,
552 | "outputs": []
553 | },
554 | {
555 | "cell_type": "markdown",
556 | "source": [
557 | "Wrap the chat call in a small helper function."
558 | ],
559 | "metadata": {
560 | "id": "fKjed8tqhNU3"
561 | }
562 | },
563 | {
564 | "cell_type": "code",
565 | "source": [
566 | "def ask_gemini(query):\n",
567 | "    response = chat.send_message(query)\n",
568 | "    return response.text"
569 | ],
570 | "metadata": {
571 | "id": "getq18lSdlI1"
572 | },
573 | "execution_count": 51,
574 | "outputs": []
575 | },
576 | {
577 | "cell_type": "markdown",
578 | "source": [
579 | "## Testing"
580 | ],
581 | "metadata": {
582 | "id": "Ml0yMBIqhRzD"
583 | }
584 | },
585 | {
586 | "cell_type": "markdown",
587 | "source": [
588 | "Check: 20 USD to INR on 2024-05-10"
589 | ],
590 | "metadata": {
591 | "id": "aEKJrk_bhTEm"
592 | }
593 | },
594 | {
595 | "cell_type": "code",
596 | "source": [
597 | "ask_gemini('I have 20 US dollars with me. Calculate how much I have in INR as of 10th May, 2024.')"
598 | ],
599 | "metadata": {
600 | "colab": {
601 | "base_uri": "https://localhost:8080/",
602 | "height": 71
603 | },
604 | "id": "MdXHEUOEWxgd",
605 | "outputId": "10a11994-934d-45b4-9f22-4098d80b701f"
606 | },
607 | "execution_count": 52,
608 | "outputs": [
609 | {
610 | "output_type": "stream",
611 | "name": "stdout",
612 | "text": [
613 | "Vector Store hit, let's return this.\n",
614 | "DB return: ['USD to INR conversion rate is 83.51 on 2024-05-10']\n"
615 | ]
616 | },
617 | {
618 | "output_type": "execute_result",
619 | "data": {
620 | "text/plain": [
621 | "'20 US dollars would be equal to 1,670.2 INR as of May 10th, 2024.'"
622 | ],
623 | "application/vnd.google.colaboratory.intrinsic+json": {
624 | "type": "string"
625 | }
626 | },
627 | "metadata": {},
628 | "execution_count": 52
629 | }
630 | ]
631 | },
632 | {
633 | "cell_type": "markdown",
634 | "source": [
635 | "Check: 20 USD to INR on 2024-05-01"
636 | ],
637 | "metadata": {
638 | "id": "zR6X5uq_hbTK"
639 | }
640 | },
641 | {
642 | "cell_type": "code",
643 | "source": [
644 | "ask_gemini('I have 20 US dollars with me. Calculate how much I have in INR as of 1st May, 2024?')"
645 | ],
646 | "metadata": {
647 | "colab": {
648 | "base_uri": "https://localhost:8080/",
649 | "height": 71
650 | },
651 | "id": "om1caMytaZz_",
652 | "outputId": "03efd0c3-1bfd-4087-97c6-05ae4c1e0c05"
653 | },
654 | "execution_count": 53,
655 | "outputs": [
656 | {
657 | "output_type": "stream",
658 | "name": "stdout",
659 | "text": [
660 | "Vector Store hit, let's return this.\n",
661 | "DB return: ['USD to INR conversion rate is 83.43 on 2024-05-01']\n"
662 | ]
663 | },
664 | {
665 | "output_type": "execute_result",
666 | "data": {
667 | "text/plain": [
668 | "'20 US dollars would be equal to 1,668.6 INR as of May 1st, 2024.'"
669 | ],
670 | "application/vnd.google.colaboratory.intrinsic+json": {
671 | "type": "string"
672 | }
673 | },
674 | "metadata": {},
675 | "execution_count": 53
676 | }
677 | ]
678 | },
679 | {
680 | "cell_type": "markdown",
681 | "source": [
682 | "Check: 20 USD to INR on 2024-01-05"
683 | ],
684 | "metadata": {
685 | "id": "DQK791C4hczQ"
686 | }
687 | },
688 | {
689 | "cell_type": "code",
690 | "source": [
691 | "ask_gemini('I have 20 US dollars with me. Calculate how much I have in INR as of 5th January, 2024?')"
692 | ],
693 | "metadata": {
694 | "colab": {
695 | "base_uri": "https://localhost:8080/",
696 | "height": 107
697 | },
698 | "id": "9kNwEl-uciMD",
699 | "outputId": "d5c756b8-6305-46e4-90e6-3769f2c968d2"
700 | },
701 | "execution_count": 54,
702 | "outputs": [
703 | {
704 | "output_type": "stream",
705 | "name": "stdout",
706 | "text": [
707 | "Added document\n",
708 | "USD to INR conversion rate is 83.15 on 2024-01-05\n",
709 | "Vector store miss. Fetching data from API and returning!\n",
710 | "API return: USD to INR conversion rate is 83.15 on 2024-01-05\n"
711 | ]
712 | },
713 | {
714 | "output_type": "execute_result",
715 | "data": {
716 | "text/plain": [
717 | "'20 US dollars would be equal to 1,663 INR as of January 5th, 2024.'"
718 | ],
719 | "application/vnd.google.colaboratory.intrinsic+json": {
720 | "type": "string"
721 | }
722 | },
723 | "metadata": {},
724 | "execution_count": 54
725 | }
726 | ]
727 | },
728 | {
729 | "cell_type": "code",
730 | "source": [],
731 | "metadata": {
732 | "id": "e7Bb6eyWhosK"
733 | },
734 | "execution_count": 24,
735 | "outputs": []
736 | }
737 | ],
738 | "metadata": {
739 | "colab": {
740 | "provenance": [],
741 | "toc_visible": true,
742 | "authorship_tag": "ABX9TyPdOF9X71AfpufvKjC6Qq+c",
743 | "include_colab_link": true
744 | },
745 | "kernelspec": {
746 | "display_name": "Python 3",
747 | "name": "python3"
748 | },
749 | "language_info": {
750 | "name": "python"
751 | }
752 | },
753 | "nbformat": 4,
754 | "nbformat_minor": 0
755 | }
--------------------------------------------------------------------------------
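The RAG notebook's `get_exchange_rate` follows a cache-aside pattern. It looks up the rate under an exact ID of the form `FROM-TO-YYYYMMDD`, falls back to the API on a miss, then stores the result. The sketch below isolates that control flow in plain Python under a few assumptions. A dict stands in for the ChromaDB collection, and a caller-supplied `fetch` function replaces the Frankfurter request. `RateCache`, `make_cache_id` and `fake_fetch` are hypothetical names, not part of the notebook.

```python
from typing import Callable, Dict, Tuple

# fetch(currency_from, currency_to, currency_date) -> (rate, date_the_rate_applies_to)
# Hypothetical stand-in for the Frankfurter request made in the notebook.
Fetcher = Callable[[str, str, str], Tuple[float, str]]


def make_cache_id(currency_from: str, currency_to: str, currency_date: str) -> str:
    """Build an ID in the notebook's FROM-TO-YYYYMMDD style, e.g. USD-INR-20240510."""
    return f"{currency_from}-{currency_to}-{currency_date.replace('-', '')}"


class RateCache:
    """Cache-aside lookup keyed by exact ID; a dict stands in for the Chroma collection."""

    def __init__(self, fetch: Fetcher) -> None:
        self._fetch = fetch
        self._store: Dict[str, str] = {}
        self.api_calls = 0  # counts cache misses that reached the fetcher

    def get(self, currency_from: str, currency_to: str, currency_date: str = "latest") -> str:
        key = make_cache_id(currency_from, currency_to, currency_date)
        # "latest" changes over time, so it is never answered from the cache
        if currency_date != "latest" and key in self._store:
            return self._store[key]
        rate, rate_date = self._fetch(currency_from, currency_to, currency_date)
        self.api_calls += 1
        document = f"{currency_from} to {currency_to} conversion rate is {rate} on {rate_date}"
        # Store under the date the source reports, never under the literal "latest"
        self._store[make_cache_id(currency_from, currency_to, rate_date)] = document
        return document


if __name__ == "__main__":
    # Hypothetical fixed data instead of a network call
    def fake_fetch(currency_from: str, currency_to: str, currency_date: str) -> Tuple[float, str]:
        return 83.51, ("2024-05-10" if currency_date == "latest" else currency_date)

    cache = RateCache(fake_fetch)
    print(cache.get("USD", "INR", "2024-05-10"))  # miss: calls fake_fetch
    print(cache.get("USD", "INR", "2024-05-10"))  # hit: served from the dict
    print(cache.api_calls)  # 1
```

Exact-ID matching is the key design choice. A nearest-neighbour query against a non-empty collection usually returns some document even when the requested rate is missing. So a result only counts as a hit when its ID matches the requested pair and date.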
/GeminiFunctionCalling_Base.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "authorship_tag": "ABX9TyM49MTHCavJQ7lmQC0fZEAt",
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | ""
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "source": [
32 | "# Basic Gemini Function Calling\n",
33 | "\n",
34 | "This notebook contains the code for completing the codelab [How to Interact with APIs Using Function Calling in Gemini](https://codelabs.developers.google.com/codelabs/gemini-function-calling)\n",
35 | "by [@koverholt](https://github.com/koverholt)."
36 | ],
37 | "metadata": {
38 | "id": "EJOQZY29makV"
39 | }
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "source": [
44 | "## Step 1. Overview\n",
45 | "This section explains why function calling matters: it lets Gemini access data that is real-time, protected, or otherwise unavailable in its training data."
46 | ],
47 | "metadata": {
48 | "id": "74lqwXX-m-wf"
49 | }
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "source": [
54 | "## Step 2. Setup and requirements"
55 | ],
56 | "metadata": {
57 | "id": "Q0FxkQ9Nm4DE"
58 | }
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 2,
63 | "metadata": {
64 | "id": "YMqhtV0LT2Ff"
65 | },
66 | "outputs": [],
67 | "source": [
68 | "%%capture\n",
69 | "!pip install --upgrade google-cloud-aiplatform"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "source": [
75 | "## Step 3. Understand the problem"
76 | ],
77 | "metadata": {
78 | "id": "QLqRh1GenSAy"
79 | }
80 | },
81 | {
82 | "cell_type": "code",
83 | "source": [
84 | "import vertexai\n",
85 | "from vertexai.preview.generative_models import GenerativeModel"
86 | ],
87 | "metadata": {
88 | "id": "afMdh8d1T8Ub"
89 | },
90 | "execution_count": 25,
91 | "outputs": []
92 | },
93 | {
94 | "cell_type": "code",
95 | "source": [
96 | "from google.oauth2 import service_account"
97 | ],
98 | "metadata": {
99 | "id": "Rc4EfQ5jZG6U"
100 | },
101 | "execution_count": 26,
102 | "outputs": []
103 | },
104 | {
105 | "cell_type": "code",
106 | "source": [
107 | "credentials = service_account.Credentials.from_service_account_file('gcp-adventure-x-a3fb7a36e1e6.json')"
108 | ],
109 | "metadata": {
110 | "id": "qxQxLxscZJwG"
111 | },
112 | "execution_count": 27,
113 | "outputs": []
114 | },
115 | {
116 | "cell_type": "code",
117 | "source": [
118 | "vertexai.init(project=\"gcp-adventure-x\", location=\"us-central1\", credentials=credentials)"
119 | ],
120 | "metadata": {
121 | "id": "xI9lR5lkUos4"
122 | },
123 | "execution_count": 28,
124 | "outputs": []
125 | },
126 | {
127 | "cell_type": "code",
128 | "source": [
129 | "model = GenerativeModel(\"gemini-1.0-pro-001\")"
130 | ],
131 | "metadata": {
132 | "id": "tC_UCY5aUOxx"
133 | },
134 | "execution_count": 29,
135 | "outputs": []
136 | },
137 | {
138 | "cell_type": "code",
139 | "source": [
140 | "response = model.generate_content(\n",
141 | " \"What's the exchange rate for euros to dollars today?\"\n",
142 | ")\n",
143 | "print(response.text)"
144 | ],
145 | "metadata": {
146 | "colab": {
147 | "base_uri": "https://localhost:8080/"
148 | },
149 | "id": "XQvmgIhgUR9h",
150 | "outputId": "ca335332-2420-4606-8488-89e5ef3dd656"
151 | },
152 | "execution_count": 30,
153 | "outputs": [
154 | {
155 | "output_type": "stream",
156 | "name": "stdout",
157 | "text": [
158 | "I do not have access to real-time information and cannot provide the current exchange rate. Please check a currency converter or financial news source for the most up-to-date information.\n"
159 | ]
160 | }
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "source": [
166 | "## Step 4. Try common workarounds\n",
167 | "We skip this section in the interest of gathering ideas from the workshop participants."
168 | ],
169 | "metadata": {
170 | "id": "83gN8MgHnWkb"
171 | }
172 | },
173 | {
174 | "cell_type": "markdown",
175 | "source": [
176 | "## Step 5. How function calling works\n",
177 | "This section describes the execution flow of Function Calling."
178 | ],
179 | "metadata": {
180 | "id": "_C4PTcngnkM0"
181 | }
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "source": [
186 | "## Step 6. Choose your API"
187 | ],
188 | "metadata": {
189 | "id": "T0Khd0PGnt1m"
190 | }
191 | },
192 | {
193 | "cell_type": "code",
194 | "source": [
195 | "import requests\n",
196 | "url = \"https://api.frankfurter.app/latest\"\n",
197 | "response = requests.get(url)\n",
198 | "data = response.json()\n",
199 | "data"
200 | ],
201 | "metadata": {
202 | "colab": {
203 | "base_uri": "https://localhost:8080/"
204 | },
205 | "id": "-iHue7JDXJzo",
206 | "outputId": "efab08fc-386e-40dc-cdcd-6c6fc181e743"
207 | },
208 | "execution_count": 32,
209 | "outputs": [
210 | {
211 | "output_type": "execute_result",
212 | "data": {
213 | "text/plain": [
214 | "{'amount': 1.0,\n",
215 | " 'base': 'EUR',\n",
216 | " 'date': '2024-03-15',\n",
217 | " 'rates': {'AUD': 1.6579,\n",
218 | " 'BGN': 1.9558,\n",
219 | " 'BRL': 5.4461,\n",
220 | " 'CAD': 1.4731,\n",
221 | " 'CHF': 0.9613,\n",
222 | " 'CNY': 7.838,\n",
223 | " 'CZK': 25.166,\n",
224 | " 'DKK': 7.4571,\n",
225 | " 'GBP': 0.8541,\n",
226 | " 'HKD': 8.5199,\n",
227 | " 'HUF': 393.2,\n",
228 | " 'IDR': 17011,\n",
229 | " 'ILS': 3.9811,\n",
230 | " 'INR': 90.26,\n",
231 | " 'ISK': 148.9,\n",
232 | " 'JPY': 162.03,\n",
233 | " 'KRW': 1448.71,\n",
234 | " 'MXN': 18.1915,\n",
235 | " 'MYR': 5.1241,\n",
236 | " 'NOK': 11.5205,\n",
237 | " 'NZD': 1.786,\n",
238 | " 'PHP': 60.494,\n",
239 | " 'PLN': 4.2953,\n",
240 | " 'RON': 4.9711,\n",
241 | " 'SEK': 11.2674,\n",
242 | " 'SGD': 1.4562,\n",
243 | " 'THB': 39.053,\n",
244 | " 'TRY': 35.092,\n",
245 | " 'USD': 1.0892,\n",
246 | " 'ZAR': 20.352}}"
247 | ]
248 | },
249 | "metadata": {},
250 | "execution_count": 32
251 | }
252 | ]
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "source": [
257 | "## Step 7. Define a function and tool"
258 | ],
259 | "metadata": {
260 | "id": "Ntis3zdqnx-s"
261 | }
262 | },
263 | {
264 | "cell_type": "code",
265 | "source": [
266 | "from vertexai.preview.generative_models import (\n",
267 | " Content,\n",
268 | " FunctionDeclaration,\n",
269 | " GenerativeModel,\n",
270 | " Part,\n",
271 | " Tool,\n",
272 | ")\n",
273 | "\n",
274 | "model = GenerativeModel(\"gemini-1.0-pro-001\")"
275 | ],
276 | "metadata": {
277 | "id": "OECV1_BFbbAY"
278 | },
279 | "execution_count": 33,
280 | "outputs": []
281 | },
282 | {
283 | "cell_type": "code",
284 | "source": [
285 | "get_exchange_rate_func = FunctionDeclaration(\n",
286 | " name=\"get_exchange_rate\",\n",
287 | " description=\"Get the exchange rate for currencies between countries\",\n",
288 | " parameters={\n",
289 | " \"type\": \"object\",\n",
290 | " \"properties\": {\n",
291 | " \"currency_date\": {\n",
292 | " \"type\": \"string\",\n",
293 | " \"description\": \"A date that must always be in YYYY-MM-DD format or the value 'latest' if a time period is not specified\"\n",
294 | " },\n",
295 | " \"currency_from\": {\n",
296 | " \"type\": \"string\",\n",
297 | " \"description\": \"The currency to convert from in ISO 4217 format\"\n",
298 | " },\n",
299 | " \"currency_to\": {\n",
300 | " \"type\": \"string\",\n",
301 | " \"description\": \"The currency to convert to in ISO 4217 format\"\n",
302 | " }\n",
303 | " },\n",
304 | " \"required\": [\n",
305 | " \"currency_from\",\n",
306 | " \"currency_date\",\n",
307 | " ]\n",
308 | " },\n",
309 | ")"
310 | ],
311 | "metadata": {
312 | "id": "dW3THZYfb3xJ"
313 | },
314 | "execution_count": 34,
315 | "outputs": []
316 | },
317 | {
318 | "cell_type": "code",
319 | "source": [
320 | "exchange_rate_tool = Tool(\n",
321 | " function_declarations=[get_exchange_rate_func],\n",
322 | ")"
323 | ],
324 | "metadata": {
325 | "id": "mRCGq9UxcNjY"
326 | },
327 | "execution_count": 35,
328 | "outputs": []
329 | },
330 | {
331 | "cell_type": "markdown",
332 | "source": [
333 | "## Step 8. Generate a function call"
334 | ],
335 | "metadata": {
336 | "id": "vL07Qym0n1dV"
337 | }
338 | },
339 | {
340 | "cell_type": "code",
341 | "source": [
342 | "prompt = \"\"\"What is the latest exchange rate from Australian dollars to Swedish krona?\n",
343 | "How much is 500 Australian dollars worth in Swedish krona?\"\"\"\n",
344 | "\n",
345 | "response = model.generate_content(\n",
346 | " prompt,\n",
347 | " tools=[exchange_rate_tool],\n",
348 | ")"
349 | ],
350 | "metadata": {
351 | "id": "qUKr6ryDcTvl"
352 | },
353 | "execution_count": 57,
354 | "outputs": []
355 | },
356 | {
357 | "cell_type": "code",
358 | "source": [
359 | "print(response.candidates[0].content)"
360 | ],
361 | "metadata": {
362 | "colab": {
363 | "base_uri": "https://localhost:8080/"
364 | },
365 | "id": "D-xZ3rsbcXp6",
366 | "outputId": "d3fff990-c388-43ea-e40d-e393ba9b1bdf"
367 | },
368 | "execution_count": 59,
369 | "outputs": [
370 | {
371 | "output_type": "stream",
372 | "name": "stdout",
373 | "text": [
374 | "role: \"model\"\n",
375 | "parts {\n",
376 | " function_call {\n",
377 | " name: \"get_exchange_rate\"\n",
378 | " args {\n",
379 | " fields {\n",
380 | " key: \"currency_date\"\n",
381 | " value {\n",
382 | " string_value: \"latest\"\n",
383 | " }\n",
384 | " }\n",
385 | " fields {\n",
386 | " key: \"currency_from\"\n",
387 | " value {\n",
388 | " string_value: \"AUD\"\n",
389 | " }\n",
390 | " }\n",
391 | " fields {\n",
392 | " key: \"currency_to\"\n",
393 | " value {\n",
394 | " string_value: \"SEK\"\n",
395 | " }\n",
396 | " }\n",
397 | " }\n",
398 | " }\n",
399 | "}\n",
400 | "\n"
401 | ]
402 | }
403 | ]
404 | },
405 | {
406 | "cell_type": "markdown",
407 | "source": [
408 | "## Step 9. Make an API request"
409 | ],
410 | "metadata": {
411 | "id": "BIzZ906Bn6L8"
412 | }
413 | },
414 | {
415 | "cell_type": "code",
416 | "source": [
417 | "params = {}\n",
418 | "for key, value in response.candidates[0].content.parts[0].function_call.args.items():\n",
419 | "    params[key[9:]] = value  # drop the 9-character \"currency_\" prefix, e.g. currency_from -> from\n",
420 | "params"
421 | ],
422 | "metadata": {
423 | "colab": {
424 | "base_uri": "https://localhost:8080/"
425 | },
426 | "id": "YuflEdH6caCY",
427 | "outputId": "02b4fa48-e0d8-456e-e9a1-16a1e1de8ab3"
428 | },
429 | "execution_count": 60,
430 | "outputs": [
431 | {
432 | "output_type": "execute_result",
433 | "data": {
434 | "text/plain": [
435 | "{'from': 'AUD', 'date': 'latest', 'to': 'SEK'}"
436 | ]
437 | },
438 | "metadata": {},
439 | "execution_count": 60
440 | }
441 | ]
442 | },
443 | {
444 | "cell_type": "code",
445 | "source": [
446 | "import requests\n",
447 | "url = f\"https://api.frankfurter.app/{params['date']}\"\n",
448 | "api_response = requests.get(url, params=params)\n",
449 | "api_response.json()"
450 | ],
451 | "metadata": {
452 | "colab": {
453 | "base_uri": "https://localhost:8080/"
454 | },
455 | "id": "B_M0seAGgZ7u",
456 | "outputId": "dc1b47df-9765-4277-9d79-1f34799949e8"
457 | },
458 | "execution_count": 62,
459 | "outputs": [
460 | {
461 | "output_type": "execute_result",
462 | "data": {
463 | "text/plain": [
464 | "{'amount': 1.0, 'base': 'AUD', 'date': '2024-03-15', 'rates': {'SEK': 6.7962}}"
465 | ]
466 | },
467 | "metadata": {},
468 | "execution_count": 62
469 | }
470 | ]
471 | },
472 | {
473 | "cell_type": "markdown",
474 | "source": [
475 | "## Step 10. Generate a response"
476 | ],
477 | "metadata": {
478 | "id": "X0_w2t1Mn-F9"
479 | }
480 | },
481 | {
482 | "cell_type": "code",
483 | "source": [
484 | "response = model.generate_content(\n",
485 | " [\n",
486 | " Content(role=\"user\", parts=[\n",
487 | " Part.from_text(prompt + \"\"\"Give your answer in steps with lots of detail\n",
488 | " and context, including the exchange rate and date.\"\"\"),\n",
489 | " ]),\n",
490 | " Content(role=\"function\", parts=[\n",
491 | " Part.from_dict({\n",
492 | " \"function_call\": {\n",
493 | " \"name\": \"get_exchange_rate\",\n",
494 | " }\n",
495 | " })\n",
496 | " ]),\n",
497 | " Content(role=\"function\", parts=[\n",
498 | " Part.from_function_response(\n",
499 | " name=\"get_exchange_rate\",\n",
500 | " response={\n",
501 | " \"content\": api_response.text,\n",
502 | " }\n",
503 | " )\n",
504 | " ]),\n",
505 | " ],\n",
506 | " tools=[exchange_rate_tool],\n",
507 | ")\n",
508 | "\n",
509 | "\n",
510 | "response.candidates[0].content.parts[0].text"
511 | ],
512 | "metadata": {
513 | "colab": {
514 | "base_uri": "https://localhost:8080/",
515 | "height": 35
516 | },
517 | "id": "a2Fgz7mkgcgY",
518 | "outputId": "0bb0686a-f69f-49af-cc97-b3b8a2b36c19"
519 | },
520 | "execution_count": 63,
521 | "outputs": [
522 | {
523 | "output_type": "execute_result",
524 | "data": {
525 | "text/plain": [
526 | "'500 Australian dollars is worth 3398.1 Swedish krona as of 2024-03-15. The exchange rate is 1 AUD = 6.7962 SEK.'"
527 | ],
528 | "application/vnd.google.colaboratory.intrinsic+json": {
529 | "type": "string"
530 | }
531 | },
532 | "metadata": {},
533 | "execution_count": 63
534 | }
535 | ]
536 | },
537 | {
538 | "cell_type": "markdown",
539 | "source": [
540 | "## Conclusion\n",
541 | "\n",
542 | "Explore further:\n",
543 | "\n",
544 | "\n",
545 | "1. [Gemini Pro API Reference](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini#gemini-pro)\n",
546 | "2. [Sample Function Calling Notebook](https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/function-calling/intro_function_calling.ipynb)\n",
547 | "\n"
548 | ],
549 | "metadata": {
550 | "id": "ZleqMsPxoJpx"
551 | }
552 | },
553 | {
554 | "cell_type": "code",
555 | "source": [],
556 | "metadata": {
557 | "id": "ADJFJ6acog0u"
558 | },
559 | "execution_count": null,
560 | "outputs": []
561 | }
562 | ]
563 | }
--------------------------------------------------------------------------------
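Step 9 turns the model's `function_call.args` into Frankfurter request parameters by slicing off the first nine characters of each key. A slightly more explicit sketch of that mapping follows, under a few assumptions:

- The arguments are available as a plain `dict` with the keys declared in Step 7.
- It uses `str.removeprefix`, which needs Python 3.9+.
- It moves the date into the URL path instead of also sending it as a query parameter.

`to_frankfurter_request` is a hypothetical helper, not part of the codelab.

```python
from typing import Dict, Tuple
from urllib.parse import urlencode

BASE_URL = "https://api.frankfurter.app"


def to_frankfurter_request(args: Dict[str, str]) -> Tuple[str, Dict[str, str]]:
    """Map get_exchange_rate arguments to a Frankfurter URL and query parameters."""
    # currency_from -> from, currency_to -> to, currency_date -> date
    params = {key.removeprefix("currency_"): value for key, value in args.items()}
    # The date goes in the URL path; fall back to "latest" if the model omitted it
    date = params.pop("date", "latest")
    return f"{BASE_URL}/{date}", params


if __name__ == "__main__":
    # Arguments as returned by the model in Step 8
    args = {"currency_date": "latest", "currency_from": "AUD", "currency_to": "SEK"}
    url, params = to_frankfurter_request(args)
    print(url)  # https://api.frankfurter.app/latest
    print(f"{url}?{urlencode(params)}")  # https://api.frankfurter.app/latest?from=AUD&to=SEK
```

The returned pair can be passed straight to `requests.get(url, params=params)`, as in Step 9.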
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Anubhav Singh
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------