├── In_Context_Learning_Demo_PatternCompletion.ipynb ├── In_Context_Learning_Demo_ImageClassification.ipynb ├── In_Context_Learning_Demo_LinearRegression.ipynb ├── In_Context_Learning_Demo_BinaryClassification.ipynb ├── In_Context_Learning_Demo_TimeSeriesPrediction.ipynb ├── LICENSE └── In_Context_Learning_Demo_LanguageConditionedTimeSeriesPrediction.ipynb /In_Context_Learning_Demo_PatternCompletion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# **Introduction to In-Context Learning Demo: Pattern Completion using the ARC Corpus**\n", 7 | "\n", 8 | "In this notebook we present a complete step-by-step walkthrough of how Large Language Models (LLMs) can **learn and apply visual completion and transformation patterns purely from few-shot examples**. Inspired by *[Large Language Models as General Pattern Machines](https://arxiv.org/pdf/2307.04721)* (Mirchandani et al., 2023), this demo uses the Abstraction and Reasoning Corpus (ARC) introduced in *[On the Measure of Intelligence](https://arxiv.org/pdf/1911.01547)* (Chollet, 2019) as the source of structured input–output patterns and evaluates Gemini’s ability to recognize and generalize these patterns through in-context learning (ICL). The goal is to illustrate how an LLM without updates to any weights can observe a handful of small grid transformations and then correctly generate the missing output for a novel test example. The result is a learning paradigm that is purely based on the forward pass of the LLM without requiring any gradient-based updates.\n", 9 | "\n", 10 | "## **Overview**\n", 11 | "\n", 12 | "This demonstration guides you through:\n", 13 | "1. Loading the ARC tasks\n", 14 | "2. Constructing an ICL prompt composed of example input–output pairs\n", 15 | "3. Prompting Gemini with the constructed prompt so that it can infer the underlying transformation\n", 16 | "4. Extracting and rendering the model’s predicted test output\n", 17 | "\n", 18 | "The result is a clear visualization of how LLMs can act as pattern machines, discovering abstract rules from minimal supervision.\n", 19 | "\n", 20 | "## **Background**\n", 21 | "The Abstraction and Reasoning Corpus (ARC) is a benchmark introduced by François Chollet to evaluate generalization, abstraction, and concept learning in AI systems. Unlike standard vision datasets, ARC focuses on few-shot learning and minimal prior knowledge. The underlying hypothesis is that a learning system should discover abstract rules that allow for human-like pattern induction.\n", 22 | "\n", 23 | "## **Let's Take a Look at an Example**\n", 24 | "\n", 25 | "The figure below shows an example from the ARC dataset and demonstrates how a model identifies and applies a transformation rule. The training examples provide few-shot demonstrations of the underlying pattern. Each training example has an input pattern and an output pattern. In the example below, the output can be regarded as a de-noised version of the input. The test query is a new input from the corpus that follows the same rule. In the case below, the test query is an input image that includes a number of noisy or corrupted pixels. 
The task of the LLM is to generate the corrected output image which adjusts for these corrupted pixels. The underlying assumption of ICL is that the large language model can learn what sort of transformation to perform, e.g., de-noising, from the few-shot examples alone.\n", 26 | "\n", 27 | "\n", 28 | "\n", 29 | "**In this specific example, the learned transformation is a de-noising operation that fills in and removes any corrupted pixels or segments.**\n", 30 | "\n", 31 | "## **LLM as the Optimizer**\n", 32 | "In our demo, we leverage Gemini 2.0/2.5 Flash to perform in-context learning (ICL) on Abstraction and Reasoning Corpus (ARC)–style grid transformation tasks. Unlike traditional supervised learning approaches, the LLM is never trained on the task. Instead, it infers the underlying transformation rule purely from the examples we provide in the prompt.\n", 33 | "\n", 34 | "The process can be understood as a sequence of structured steps: preparing the training examples, constructing a carefully engineered prompt, invoking the LLM to infer the rule, and finally converting its output back into a grid. The LLM behaves as a reasoning engine, reading the examples like a story and deducing the pattern that connects them.\n", 35 | "\n", 36 | "## **Code Overview**\n", 37 | "The implementation is organized into a modular structure, with each component responsible for a different stage of the ARC task-solving pipeline. This design separates data loading, visualization, prompt construction, LLM inference, and result interpretation, making the system easy to understand, modify, and extend.\n", 38 | "\n", 39 | "## **Before Running the Demo, Follow the Instructions Below**\n", 40 | "To run the full experiment:\n", 41 | "1. Ensure all dependencies are imported and installed.\n", 42 | "2.\tVisit Google AI Studio (https://aistudio.google.com/) to obtain your Gemini API key.\n", 43 | "3.\tOnce API key is generated, copy and paste it into the demo when prompted.\n", 44 | "\n", 45 | "#### **If having trouble creating an API Key, follow the link below for instructions:**\n", 46 | "* #### **[Instructions on How to Obtain a Gemini API Key](https://docs.google.com/document/d/17pgikIpvZCBgcRh4NNrcee-zp3Id0pU0vq2SG52KIGE/edit?usp=sharing)**\n", 47 | "\n", 48 | "\n", 49 | "### ***Note: Run all the cells below to initialize the environment.**\n", 50 | "\n", 51 | "\n", 52 | "\n" 53 | ], 54 | "metadata": { 55 | "id": "JWBl0iMOvbE_" 56 | } 57 | }, 58 | { 59 | "cell_type": "code", 60 | "source": [ 61 | "#@title **Import Necessary Libraries**\n", 62 | "import numpy as np\n", 63 | "import matplotlib.pyplot as plt\n", 64 | "from google import genai\n", 65 | "from google.genai import types\n", 66 | "import re\n", 67 | "import itertools\n", 68 | "import math\n", 69 | "import ast\n", 70 | "import json\n", 71 | "import ipywidgets as widgets\n", 72 | "from IPython.display import display\n", 73 | "import os\n", 74 | "import getpass" 75 | ], 76 | "metadata": { 77 | "id": "qDV03CJIve4R", 78 | "cellView": "form" 79 | }, 80 | "execution_count": null, 81 | "outputs": [] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": { 87 | "cellView": "form", 88 | "id": "lh6d1fWygOfe" 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "#@title **Setting Up Gemini Client**\n", 93 | "apikey = getpass.getpass(\"Enter your Gemini API Key: \")\n", 94 | "client = genai.Client(api_key=apikey)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "source": [ 100 | "#@title **Choose a Model**\n", 101 | "model_dropdown 
= widgets.Dropdown(\n", 102 | " options=[\n", 103 | " (\"Gemini 2.5 Flash\", \"gemini-2.5-flash\"),\n", 104 | " (\"Gemini 2.0 Flash\", \"gemini-2.0-flash\")\n", 105 | " ],\n", 106 | " description=\"Model:\",\n", 107 | " value=\"gemini-2.5-flash\",\n", 108 | " style={'description_width': 'initial'}\n", 109 | ")\n", 110 | "\n", 111 | "confirm_button = widgets.Button(\n", 112 | " description=\"Confirm Selection\"\n", 113 | ")\n", 114 | "\n", 115 | "output = widgets.Output()\n", 116 | "\n", 117 | "model_name = None\n", 118 | "\n", 119 | "def on_confirm_click(b):\n", 120 | " global model_name, batch_size\n", 121 | "\n", 122 | " model_name = model_dropdown.value\n", 123 | "\n", 124 | " with output:\n", 125 | " output.clear_output()\n", 126 | " print(f\"\\nSelected model: {model_name}\")\n", 127 | "\n", 128 | "confirm_button.on_click(on_confirm_click)\n", 129 | "\n", 130 | "display(model_dropdown, confirm_button, output)" 131 | ], 132 | "metadata": { 133 | "id": "OdA80Nrvdo2N", 134 | "cellView": "form" 135 | }, 136 | "execution_count": null, 137 | "outputs": [] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "source": [ 142 | "#@title **Download Data from GitHub**\n", 143 | "if not os.path.exists(\"intro_to_icl_data\"):\n", 144 | " !git clone https://github.com/hsiang-fu/intro_to_icl_data.git\n", 145 | "\n" 146 | ], 147 | "metadata": { 148 | "id": "6UMMqb9zX20A", 149 | "cellView": "form" 150 | }, 151 | "execution_count": null, 152 | "outputs": [] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "source": [ 157 | "The below cell `Grid Display Function` is responsible for the entire grid display and visual interpretation layer of the ARC demo. Its purpose is to transform the raw numerical grids—integers ranging from 0 to 9—into the colored, human-readable visualizations used throughout the notebook.\n", 158 | "\n", 159 | "ARC tasks define grids where each integer corresponds to a color category. The `arc_colors` dictionary encodes this mapping using RGB triplets.\n", 160 | "For example:\n", 161 | "- 0 → black,\n", 162 | "- 1 → blue,\n", 163 | "- 2 → green,\n", 164 | "- 3 → red, etc.\n", 165 | "\n", 166 | "**These colors reflect the official ARC palette.**" 167 | ], 168 | "metadata": { 169 | "id": "xhByfpVOhj6c" 170 | } 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": { 176 | "id": "3TuKafdRgOfe", 177 | "cellView": "form" 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "#@title **Grid Display Function**\n", 182 | "arc_colors = {\n", 183 | " 0: (0, 0, 0),\n", 184 | " 1: (0, 0, 255),\n", 185 | " 2: (0, 255, 0),\n", 186 | " 3: (255, 0, 0),\n", 187 | " 4: (255, 255, 0),\n", 188 | " 5: (255, 165, 0),\n", 189 | " 6: (128, 0, 128),\n", 190 | " 7: (150, 75, 0),\n", 191 | " 8: (128, 128, 128),\n", 192 | " 9: (255, 192, 203)\n", 193 | "}\n", 194 | "\n", 195 | "def show_grid(grid, title=\"\"):\n", 196 | " arr = np.array(grid)\n", 197 | " rgb = np.zeros((arr.shape[0], arr.shape[1], 3), dtype=np.uint8)\n", 198 | " for k, v in arc_colors.items():\n", 199 | " rgb[arr == k] = v\n", 200 | " plt.imshow(rgb)\n", 201 | " plt.axis(\"off\")\n", 202 | " if title:\n", 203 | " plt.title(title)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "source": [ 209 | "The below cell `ARC ICL Demo` is responsible for running the full interactive ARC demo, including user task selection, prompt construction, LLM inference, visualization of training/test grids, parsing model outputs, and saving the results. 
It effectively acts as the end-to-end execution and UI layer of the entire notebook.\n", 210 | "\n", 211 | "The cell creates a simple user interface, consisting of :\n", 212 | "- A dropdown menu that lists ARC tasks (Task 1 through Task 10),\n", 213 | "- A button to confirm the selection.\n", 214 | "- **This makes the demo interactive, users can dynamically choose which ARC task they want to solve without modifying any code.**\n", 215 | "\n", 216 | "The model will then generate and running the visualization cell would show the following:\n", 217 | "\n", 218 | "- Training examples\n", 219 | "- Test example\n", 220 | "- Expected output\n", 221 | "- Model-generated output\n", 222 | "\n", 223 | "### **Let's Take a Look at an Example of a Prompt:**\n", 224 | "```\n", 225 | "Train:\n", 226 | "Example 1 Input: [[0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0]] -> Example 1 Output: [[0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0]]\n", 227 | "Example 2 Input: [[0, 3, 0, 0, 0, 3, 3, 0, ..., 0, 0, 0]] -> Example 2 Output: [[0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0]]\n", 228 | "Example 3 Input: [[0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0]] -> Example 3 Output: [[0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0]]\n", 229 | "\n", 230 | "Test Input: [[0, 0, 3, 3, 0, 0, 0, 0, ..., 3, 0, 0, 0, 0]] ->\n", 231 | "```\n", 232 | "\n", 233 | "This prompt shows the model three training examples and one test query from an ARC task (same task as the image above), where each example provides an input grid. The grid consists of rows of colored cells represented by a list of numbers (each number corresponds to a color) and a corresponding output grid that reflects how the grid changes after applying a hidden transformation rule. In ARC, the model’s objective is to infer this underlying pattern solely from the provided examples and then apply it to the test input to generate the appropriate output.\n", 234 | "\n", 235 | "This is considered ICL because the model is not being updated or fine-tuned. All “learning” occurs within the prompt itself: the examples become the model’s temporary training data. By observing the input–output pairs embedded in the prompt, the model identifies the pattern, generalizes it, and produces the appropriate output for the test query—all during a single forward pass of the LLM." 
236 | ], 237 | "metadata": { 238 | "id": "6qsh3-ijiSq7" 239 | } 240 | }, 241 | { 242 | "cell_type": "code", 243 | "source": [ 244 | "#@title **ARC ICL Demo — Model Call & Inference**\n", 245 | "\n", 246 | "def count_pixel_differences(a, b):\n", 247 | " diff = 0\n", 248 | " for r1, r2 in zip(a, b):\n", 249 | " for p1, p2 in zip(r1, r2):\n", 250 | " if p1 != p2:\n", 251 | " diff += 1\n", 252 | " return diff\n", 253 | "\n", 254 | "dropdown = widgets.Dropdown(\n", 255 | " options=[f\"Task {i}\" for i in range(1, 11)],\n", 256 | " value=\"Task 1\",\n", 257 | " description=\"Select Task:\",\n", 258 | " style={'description_width': 'initial'}\n", 259 | ")\n", 260 | "\n", 261 | "button = widgets.Button(description=\"Run Inference\")\n", 262 | "\n", 263 | "file = None\n", 264 | "\n", 265 | "def run_llm_inference(file):\n", 266 | " file_path = \"intro_to_icl_data/\" + file + \".json\"\n", 267 | "\n", 268 | " if not os.path.exists(file_path):\n", 269 | " print(f\"File not found: {file_path}\")\n", 270 | " return None\n", 271 | "\n", 272 | " with open(file_path, \"r\") as f:\n", 273 | " task = json.load(f)\n", 274 | "\n", 275 | " # Build prompting string\n", 276 | " prompt = \"Train:\\n\"\n", 277 | " for i, example in enumerate(task[\"train\"]):\n", 278 | " prompt += (\n", 279 | " f\"Example {i+1} Input: {example['input']} -> \"\n", 280 | " f\"Example {i+1} Output: {example['output']}\\n\"\n", 281 | " )\n", 282 | " prompt += f\"\\nTest Input: {task['test'][0]['input']} -> \"\n", 283 | "\n", 284 | " # Run LLM\n", 285 | " response = client.models.generate_content(\n", 286 | " model=model_name,\n", 287 | " contents=[prompt],\n", 288 | " config=types.GenerateContentConfig(\n", 289 | " thinking_config=types.ThinkingConfig(thinking_budget=0)\n", 290 | " ),\n", 291 | " )\n", 292 | "\n", 293 | " # Extract output tokens\n", 294 | " match = re.search(r\"\\[.*\\]\", response.text, re.DOTALL)\n", 295 | " if not match:\n", 296 | " print(\"No valid output found in model response.\")\n", 297 | " return None\n", 298 | "\n", 299 | " # Save original ground truth\n", 300 | " task[\"test\"][0][\"ground_truth\"] = task[\"test\"][0][\"output\"]\n", 301 | " try:\n", 302 | " task[\"test\"][1][\"ground_truth\"] = task[\"test\"][1][\"output\"]\n", 303 | " except:\n", 304 | " pass\n", 305 | " # Save model output\n", 306 | " task[\"test\"][0][\"output\"] = ast.literal_eval(match.group(0))\n", 307 | "\n", 308 | " # Write output JSON\n", 309 | " directory, filename = os.path.split(file_path)\n", 310 | " output_filename = \"output_\" + filename\n", 311 | " output_file_path = os.path.join(directory, output_filename)\n", 312 | "\n", 313 | " with open(output_file_path, \"w\") as f:\n", 314 | " json.dump(task, f, indent=2)\n", 315 | "\n", 316 | " print(f\"\\nLLM output saved to: {output_file_path}\")\n", 317 | " return output_file_path\n", 318 | "\n", 319 | "def on_button_click(b):\n", 320 | " global file\n", 321 | " file = dropdown.value\n", 322 | " print(f\"\\nRunning LLM inference for: {file}…\")\n", 323 | " run_llm_inference(file)\n", 324 | "\n", 325 | "button.on_click(on_button_click)\n", 326 | "\n", 327 | "display(widgets.VBox([dropdown, button]))" 328 | ], 329 | "metadata": { 330 | "id": "MZ7Yx5sKw1sg", 331 | "cellView": "form" 332 | }, 333 | "execution_count": null, 334 | "outputs": [] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": { 340 | "id": "R5_8IkaFgOff", 341 | "cellView": "form" 342 | }, 343 | "outputs": [], 344 | "source": [ 345 | "#@title **ARC ICL Demo — Visualization**\n", 
346 | "\n", 347 | "def visualize_results(file):\n", 348 | " file_path = \"intro_to_icl_data/output_\" + file + \".json\"\n", 349 | "\n", 350 | " if not os.path.exists(file_path):\n", 351 | " print(\"Run inference first — output file not found.\")\n", 352 | " return\n", 353 | "\n", 354 | " with open(file_path, \"r\") as f:\n", 355 | " output = json.load(f)\n", 356 | "\n", 357 | " # Visualize train examples\n", 358 | " for i, pair in enumerate(output[\"train\"]):\n", 359 | " plt.figure(figsize=(4, 2))\n", 360 | " plt.subplot(1, 2, 1)\n", 361 | " show_grid(pair[\"input\"], \"\")\n", 362 | " plt.subplot(1, 2, 2)\n", 363 | " show_grid(pair[\"output\"], \"\")\n", 364 | " plt.suptitle(f\"Train Example {i+1}\")\n", 365 | " plt.show()\n", 366 | "\n", 367 | " for i, pair in enumerate(output[\"test\"]):\n", 368 | " plt.figure(figsize=(4, 2))\n", 369 | " plt.subplot(1, 2, 1)\n", 370 | " show_grid(pair[\"input\"], f\"Test Example {i+1}\")\n", 371 | " plt.show()\n", 372 | "\n", 373 | " # Visualize test results\n", 374 | " for i, pair in enumerate(output[\"test\"]):\n", 375 | " generated = pair[\"output\"]\n", 376 | " ground_truth = pair.get(\"ground_truth\", None)\n", 377 | "\n", 378 | " plt.figure(figsize=(4, 2))\n", 379 | "\n", 380 | " if ground_truth:\n", 381 | " plt.subplot(1, 2, 1)\n", 382 | " show_grid(ground_truth, \"Ground Truth\")\n", 383 | "\n", 384 | " plt.subplot(1, 2, 2)\n", 385 | " show_grid(generated, \"Generated Output\")\n", 386 | " else:\n", 387 | " show_grid(generated, \"Generated Output\")\n", 388 | "\n", 389 | " plt.show()\n", 390 | "\n", 391 | " if ground_truth:\n", 392 | " try:\n", 393 | " diff = count_pixel_differences(ground_truth, generated)\n", 394 | " print(f\"\\nThe Number of Pixel Differences Between the Predicted Output and Ground Truth is: {diff}\\n\")\n", 395 | " except:\n", 396 | " print(\"\\nError comparing outputs.\\n\")\n", 397 | "\n", 398 | "# Run it manually after inference:\n", 399 | "visualize_results(file)" 400 | ] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "source": [ 405 | "## **Summary**\n", 406 | "This demo showcases how Large Language Models can perform in-context learning by recognizing visual transformation patterns from a small number of examples. Using tasks from the Abstract and Reasoning Corpus (ARC), we construct prompts that present input–output grid pairs and then ask the model to generate the missing transformation for a new test input. Through this setup, we observe how the model identifies the hidden rule by examining the structure of the grids rather than relying on built-in symbolic reasoning.\n", 407 | "\n", 408 | "Beyond the specific task illustrated here, the same methodology can be applied to a wide spectrum of ARC-style transformations. By crafting carefully chosen few-shot examples, we can guide the model toward discovering diverse rules such as color mappings, object counting, symmetry detection, region extraction, movement operations, and shape reconstruction. These extensions demonstrate how flexible the in-context learning framework is when applied to structured, symbolic data.\n", 409 | "\n", 410 | "## **Conclusion**\n", 411 | "This demonstration highlights that LLMs can generalize abstract rules similar to human inductive reasoning, even when reasoning over symbolic grid data. By providing only a few examples, the model can infer the underlying transformation and apply it to new cases—capturing the essence of in-context learning. 
Although the model does not perform explicit algorithmic reasoning, its ability to internalize and reproduce patterns suggests strong potential for tasks requiring rapid generalization from limited data. ARC serves as a valuable testing ground for understanding these capabilities, revealing both the promise and the limitations of treating LLMs as general pattern machines.\n", 412 | "\n", 413 | "## **References**\n", 414 | "Chollet, F. (2019). On the Measure of Intelligence. *arXiv preprint arXiv:1911.01547.*\n", 415 | "\n", 416 | "Mirchandani, S., Xia, F., Florence, P., Ichter, B., Driess, D., Arenas, M.G., Rao, K., Sadigh, D. and Zeng, A. (2023, December). Large Language Models as General Pattern Machines. *In Conference on Robot Learning* (pp. 2498-2518). PMLR." 417 | ], 418 | "metadata": { 419 | "id": "5tkOJMOJX7ix" 420 | } 421 | }, 422 | { 423 | "cell_type": "code", 424 | "source": [], 425 | "metadata": { 426 | "id": "wSMbLRACZy3L" 427 | }, 428 | "execution_count": null, 429 | "outputs": [] 430 | } 431 | ], 432 | "metadata": { 433 | "colab": { 434 | "provenance": [] 435 | }, 436 | "kernelspec": { 437 | "display_name": "base", 438 | "language": "python", 439 | "name": "python3" 440 | }, 441 | "language_info": { 442 | "codemirror_mode": { 443 | "name": "ipython", 444 | "version": 3 445 | }, 446 | "file_extension": ".py", 447 | "mimetype": "text/x-python", 448 | "name": "python", 449 | "nbconvert_exporter": "python", 450 | "pygments_lexer": "ipython3", 451 | "version": "3.13.5" 452 | } 453 | }, 454 | "nbformat": 4, 455 | "nbformat_minor": 0 456 | } -------------------------------------------------------------------------------- /In_Context_Learning_Demo_ImageClassification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# **Introduction to In-Context Learning Demo: Image Classification**\n", 7 | "\n", 8 | "## **Overview**\n", 9 | "This notebook presents a step-by-step walkthrough of an interactive demo that explores how Large Language Models (LLMs) can perform image-based image classification purely through in-context learning. 
Inspired by the idea that LLMs can function as general pattern recognizers, this experiment uses small image snippets of weld defects, each labeled with its defect type, to examine how effectively Gemini can internalize visual defect patterns and reproduce consistent judgments on new samples.\n", 10 | "\n", 11 | "The goal is to illustrate how an LLM—without explicit training, fine-tuning, or feature engineering—can infer quality cues (defects) from a few examples and generalize to other images.\n", 12 | "\n", 13 | "## **Background**\n", 14 | "This demo focuses on real weld defect images categorized into 6 classes:\n", 15 | "- Weld Cracks\n", 16 | "- Burn Through\n", 17 | "- Lack of Fusion\n", 18 | "- Slag Inclusion\n", 19 | "- Weld Splatter\n", 20 | "- Surface Porosity\n", 21 | "\n", 22 | "Each example image serves as a pairing of:\n", 23 | "- A compact visual representation\n", 24 | "- A category label from one of the 6 classes\n", 25 | "\n", 26 | "What makes this setup particularly compelling:\n", 27 | "- The data was checked beforehand to confirm that the model could not classify these defects correctly without the in-context examples\n", 28 | "- The model receives visual examples only through prompt context\n", 29 | "- No training or gradient updates occur—classification arises from pattern matching\n", 30 | "- Prompts can include descriptions, annotations, or multi-step chains of examples\n", 31 | "- The test images require generalization, not memorization\n", 32 | "- Predictions are generated sample-by-sample, mimicking standard evaluation flows\n", 33 | "\n", 34 | "This setting provides a clear benchmark for understanding how well LLMs can perform visual classification tasks when guided only through carefully constructed prompts.\n", 35 | "\n", 36 | "## **Let's Take a Look at an Example**\n", 37 | "\n", 38 | "The illustration below shows an example of weld defects. Given a set of labeled samples in context, the LLM must detect the defect for new, unseen defect images during evaluation.\n", 39 | "\n", 40 | "\n", 41 | "\n", 42 | "## **LLM as the Classifier**\n", 43 | "\n", 44 | "In this demo, we use Gemini 2.0/2.5 Flash in non-reasoning mode, prompting the model to rely on direct pattern recognition rather than symbolic explanation or analytic reasoning.\n", 45 | "\n", 46 | "The workflow proceeds as follows:\n", 47 | "- Provide 9 total images of weld defects (1-2 for each defect) as in-context examples\n", 48 | "- Send a new unlabeled image through the prompt to classify\n", 49 | "- Ask the model to output the category of the new unlabeled image\n", 50 | "\n", 51 | "The LLM functions as a lightweight, prompt-driven classifier—absorbing visual differences, structural patterns, and defect signatures from the in-context examples.\n", 52 | "\n", 53 | "## **Evaluation**\n", 54 | "Finally, we compare the model’s predicted labels against ground-truth labels and compute accuracy, which provides insight into how effectively an LLM can approximate visual quality-control decisions through in-context learning alone—without any dedicated training pipeline.\n", 55 | "\n", 56 | "## **Code Overview**\n", 57 | "The implementation is structured modularly, with each component handling a distinct stage of the ICL classification pipeline. 
This separation makes the system easy to modify, extend, and reuse:\n", 58 | "- Data loading & preprocessing: Read images, convert to model-compatible format\n", 59 | "- Visualization: Display sets of good/bad examples\n", 60 | "- Prompt construction: Insert labeled samples into few-shot prompts\n", 61 | "- LLM inference: Retrieve predictions one image at a time\n", 62 | "\n", 63 | "## **Before Running the Demo, Follow the Instructions Below**\n", 64 | "To run the full experiment:\n", 65 | "1. Ensure all dependencies are imported and installed.\n", 66 | "2.\tVisit Google AI Studio (https://aistudio.google.com/) to obtain your Gemini API key.\n", 67 | "3.\tOnce API key is generated, copy and paste it into the demo when prompted.\n", 68 | "\n", 69 | "#### **If having trouble creating an API Key, follow the link below for instructions:**\n", 70 | "* #### **[Instructions on How to Obtain a Gemini API Key](https://docs.google.com/document/d/17pgikIpvZCBgcRh4NNrcee-zp3Id0pU0vq2SG52KIGE/edit?usp=sharing)**\n", 71 | "\n" 72 | ], 73 | "metadata": { 74 | "id": "JWBl0iMOvbE_" 75 | } 76 | }, 77 | { 78 | "cell_type": "code", 79 | "source": [ 80 | "#@title **Import Necessary Libraries**\n", 81 | "import numpy as np\n", 82 | "import matplotlib.pyplot as plt\n", 83 | "from google import genai\n", 84 | "from google.genai import types\n", 85 | "import re\n", 86 | "import itertools\n", 87 | "import math\n", 88 | "import ast\n", 89 | "import json\n", 90 | "import ipywidgets as widgets\n", 91 | "from IPython.display import display\n", 92 | "import os\n", 93 | "import getpass\n", 94 | "from IPython.display import Image\n", 95 | "import time" 96 | ], 97 | "metadata": { 98 | "cellView": "form", 99 | "id": "qDV03CJIve4R" 100 | }, 101 | "execution_count": null, 102 | "outputs": [] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": { 108 | "cellView": "form", 109 | "id": "lh6d1fWygOfe" 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "#@title **Setting Up Gemini Client**\n", 114 | "apikey = getpass.getpass(\"Enter your Gemini API Key: \")\n", 115 | "client = genai.Client(api_key=apikey)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "source": [ 121 | "#@title **Choose a Model**\n", 122 | "model_dropdown = widgets.Dropdown(\n", 123 | " options=[\n", 124 | " (\"Gemini 2.5 Flash\", \"gemini-2.5-flash\"),\n", 125 | " (\"Gemini 2.0 Flash\", \"gemini-2.0-flash\")\n", 126 | " ],\n", 127 | " description=\"Model:\",\n", 128 | " value=\"gemini-2.5-flash\",\n", 129 | " style={'description_width': 'initial'}\n", 130 | ")\n", 131 | "\n", 132 | "confirm_button = widgets.Button(\n", 133 | " description=\"Confirm Selection\"\n", 134 | ")\n", 135 | "\n", 136 | "output = widgets.Output()\n", 137 | "\n", 138 | "model_name = None\n", 139 | "\n", 140 | "def on_confirm_click(b):\n", 141 | " global model_name, batch_size\n", 142 | "\n", 143 | " model_name = model_dropdown.value\n", 144 | "\n", 145 | " with output:\n", 146 | " output.clear_output()\n", 147 | " print(f\"\\nSelected model: {model_name}\")\n", 148 | "\n", 149 | "confirm_button.on_click(on_confirm_click)\n", 150 | "\n", 151 | "display(model_dropdown, confirm_button, output)" 152 | ], 153 | "metadata": { 154 | "cellView": "form", 155 | "id": "FyQ7yQQV7qvt" 156 | }, 157 | "execution_count": null, 158 | "outputs": [] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "source": [ 163 | "#@title **Download Data from GitHub**\n", 164 | "if not os.path.exists(\"intro_to_icl_data\"):\n", 165 | " !git clone 
https://github.com/hsiang-fu/intro_to_icl_data.git" 166 | ], 167 | "metadata": { 168 | "id": "6UMMqb9zX20A", 169 | "cellView": "form" 170 | }, 171 | "execution_count": null, 172 | "outputs": [] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "source": [ 177 | "The cell `ICL Image Classification` below is responsible for running ICL weld defect image classification. It loads labeled training examples, constructs the few-shot prompt, sends the prompt to the LLM for each test image, parses the prediction, evaluates correctness, and finally reports overall accuracy.\n", 178 | "\n", 179 | "Each test image is displayed, and the model’s prediction and the ground truth label are printed. Running each cell will perform inference across all test images.\n", 180 | "\n", 181 | "After the cell runs, it performs several steps. First, it builds the ICL training examples by loading labeled training images. Each example is converted into a Gemini Part so it can be embedded directly into the prompt. These form the annotated few-shot demonstrations the model uses to learn the classification pattern. Next, it constructs the full ICL prompt for each test item by including the instruction, all labeled example images, the unlabeled test image, and a rule specifying that the model should respond only with the label of the weld defect. Then it loads the unseen test images and their ground-truth labels. For each test image, the cell displays the image, sends the entire ICL prompt to Gemini, reads the model’s label prediction, compares it to the ground truth, and stores the results. After all images are processed, the cell computes summary metrics such as accuracy, total number of correct predictions, and incorrect predictions.\n", 182 | "\n", 183 | "Each evaluation cycle outputs the test weld defect image, the model’s predicted label, the true label, and whether the prediction was correct. At the end, the code prints a performance summary showing the model’s accuracy across all of the test images." 184 | ], 185 | "metadata": { 186 | "id": "NkxcjLfcvzzw" 187 | } 188 | }, 189 | { 190 | "cell_type": "code", 191 | "source": [ 192 | "#@title **ICL Image Classification**\n", 193 | "test_labels = [\n", 194 | "    \"Burn Through\",\n", 195 | "    \"Cracks\",\n", 196 | "    \"Lack of Fusion\",\n", 197 | "    \"Slag Inclusion\",\n", 198 | "    \"Splatter\",\n", 199 | "    \"Surface Porosity\",\n", 200 | "    \"Cracks\",\n", 201 | "    \"Cracks\",\n", 202 | "    \"Lack of Fusion\",\n", 203 | "    \"Splatter\",\n", 204 | "    \"Surface Porosity\"\n", 205 | "    ]\n", 206 | "\n", 207 | "def load_part(path):\n", 208 | "    with open(path, \"rb\") as f:\n", 209 | "        return types.Part.from_bytes(\n", 210 | "            data=f.read(),\n", 211 | "            mime_type=\"image/jpeg\"\n", 212 | "        )\n", 213 | "\n", 214 | "train_paths = [f\"intro_to_icl_data/industrial_defects/train{i}.jpg\" for i in range(1, 10)]\n", 215 | "train_labels = [\"Cracks\", \"Surface Porosity\", \"Slag Inclusion\", \"Splatter\", \"Burn Through\", \"Lack of Fusion\", \"Splatter\", \"Slag Inclusion\", \"Surface Porosity\"]\n", 216 | "\n", 217 | "train_parts = [load_part(p) for p in train_paths]\n", 218 | "\n", 219 | "def classify_image(label, index):\n", 220 | "    test_path = f\"intro_to_icl_data/industrial_defects/test{index}.jpg\"\n", 221 | "    test_part = load_part(test_path)\n", 222 | "\n", 223 | "    contents = [\"You are an expert in detecting industrial defects. 
\"\n", 224 | " \"By only using the provided examples, classify the defect.\\n\"]\n", 225 | "\n", 226 | " for i, (tlabel, tpart) in enumerate(zip(train_labels, train_parts)):\n", 227 | " contents.append(f\"Example {i+1}: {tlabel}\")\n", 228 | " contents.append(tpart)\n", 229 | "\n", 230 | " contents.extend([\n", 231 | " \"What is the defect in the test image? Only return the label.\",\n", 232 | " test_part,\n", 233 | " ])\n", 234 | "\n", 235 | "\n", 236 | " try:\n", 237 | " response = client.models.generate_content(\n", 238 | " model=model_name,\n", 239 | " contents=contents,\n", 240 | " )\n", 241 | " except Exception as e:\n", 242 | " print(\"Waiting 60 seconds for quota limit reset\")\n", 243 | " time.sleep(60)\n", 244 | " response = client.models.generate_content(\n", 245 | " model=\"gemini-2.5-flash\",\n", 246 | " contents=contents,\n", 247 | " config=types.GenerateContentConfig(\n", 248 | " thinking_config=types.ThinkingConfig(thinking_budget=0)\n", 249 | " )\n", 250 | " )\n", 251 | "\n", 252 | " return response.text.strip(), test_path\n", 253 | "\n", 254 | "correct = 0\n", 255 | "results = []\n", 256 | "print(\"Starting Image Classification for Welding Defects\\n\")\n", 257 | "for i, label in enumerate(test_labels, start=1):\n", 258 | " pred, path = classify_image(label, i)\n", 259 | "\n", 260 | " display(Image(filename=path, height = 300))\n", 261 | "\n", 262 | " print(f\"\\nGround Truth: {label}\")\n", 263 | " print(f\"Model Output: {pred}\\n\")\n", 264 | "\n", 265 | " is_correct = (pred.lower() == label.lower())\n", 266 | "\n", 267 | " results.append((label, pred, is_correct))\n", 268 | " correct += int(is_correct)\n", 269 | "\n", 270 | "accuracy = correct / len(test_labels)\n", 271 | "print(f\"Overall Accuracy: {accuracy:.2f}\\n\")" 272 | ], 273 | "metadata": { 274 | "cellView": "form", 275 | "id": "VE8BH-DCqzUM" 276 | }, 277 | "execution_count": null, 278 | "outputs": [] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "source": [ 283 | "This cell `Baseline SVM Model` implements the traditional machine-learning baseline used to compare against the in-context learning approaches. Rather than learning patterns directly from image pixels, this baseline relies on hand-crafted visual descriptors—specifically Histogram of Oriented Gradients (HOG)—which are then classified using a linear Support Vector Machine (SVM). The code loads the training images, converts them to grayscale, resizes each to 256×256, and extracts their HOG feature vectors. The same preprocessing steps are applied to the test set, ensuring a consistent feature representation. Once the features are assembled, a LinearSVC classifier is trained and evaluated on the same set of test images used by the ICL methods. The resulting accuracy provides a structured, feature-engineered benchmark to compare against the LLM’s prompt-based classification performance." 
284 | ], 285 | "metadata": { 286 | "id": "OpteMWm3xbAy" 287 | } 288 | }, 289 | { 290 | "cell_type": "code", 291 | "source": [ 292 | "#@title **Baseline SVM Model**\n", 293 | "from skimage.feature import hog\n", 294 | "from skimage.io import imread\n", 295 | "from skimage.color import rgb2gray\n", 296 | "from skimage.transform import resize\n", 297 | "from sklearn.svm import LinearSVC\n", 298 | "from sklearn.metrics import accuracy_score\n", 299 | "\n", 300 | "# Load train data\n", 301 | "X_train = []\n", 302 | "y_train = []\n", 303 | "\n", 304 | "for path, label in zip(train_paths, train_labels):\n", 305 | " img = imread(path)\n", 306 | " img = resize(img, (256, 256), anti_aliasing=True)\n", 307 | " img = rgb2gray(img)\n", 308 | "\n", 309 | " feats = hog(img, pixels_per_cell=(16,16), cells_per_block=(2,2))\n", 310 | " X_train.append(feats)\n", 311 | " y_train.append(label)\n", 312 | "\n", 313 | "# Load test data\n", 314 | "X_test = []\n", 315 | "y_test = []\n", 316 | "\n", 317 | "for i, label in enumerate(test_labels, start=1):\n", 318 | " img = imread(f\"intro_to_icl_data/industrial_defects/test{i}.jpg\")\n", 319 | " img = resize(img, (256, 256), anti_aliasing=True)\n", 320 | " img = rgb2gray(img)\n", 321 | "\n", 322 | " feats = hog(img, pixels_per_cell=(16,16), cells_per_block=(2,2))\n", 323 | " X_test.append(feats)\n", 324 | " y_test.append(label)\n", 325 | "\n", 326 | "# Train classifier\n", 327 | "clf = LinearSVC()\n", 328 | "clf.fit(X_train, y_train)\n", 329 | "\n", 330 | "# Predict\n", 331 | "preds = clf.predict(X_test)\n", 332 | "print(\"Baseline HOG+SVM Accuracy:\", accuracy_score(y_test, preds))" 333 | ], 334 | "metadata": { 335 | "cellView": "form", 336 | "id": "heJjTaJ1RpPX" 337 | }, 338 | "execution_count": null, 339 | "outputs": [] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "source": [ 344 | "#@title **ICL and Baseline SVM Comparison**\n", 345 | "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report\n", 346 | "\n", 347 | "# ----------------------------------------------------\n", 348 | "# Extract predictions from your existing ICL results\n", 349 | "# results = [(label, pred, is_correct), ...]\n", 350 | "icl_true = [r[0] for r in results]\n", 351 | "icl_pred = [r[1] for r in results]\n", 352 | "\n", 353 | "# ----------------------------------------------------\n", 354 | "# Extract baseline predictions (already in 'preds')\n", 355 | "baseline_pred = list(preds)\n", 356 | "baseline_true = list(test_labels) # Same ordering as test inputs\n", 357 | "\n", 358 | "# ----------------------------------------------------\n", 359 | "# Compute accuracies\n", 360 | "icl_accuracy = sum(np.array(icl_true) == np.array(icl_pred)) / len(icl_true)\n", 361 | "baseline_accuracy = sum(np.array(baseline_true) == np.array(baseline_pred)) / len(baseline_true)\n", 362 | "\n", 363 | "print(\"===========================================\")\n", 364 | "print(\" Model Comparison Summary\")\n", 365 | "print(\"===========================================\\n\")\n", 366 | "\n", 367 | "print(f\"ICL Accuracy: {icl_accuracy:.3f}\")\n", 368 | "print(f\"Baseline Accuracy: {baseline_accuracy:.3f}\\n\")\n", 369 | "\n", 370 | "print(\"ICL Classification Report:\")\n", 371 | "print(classification_report(icl_true, icl_pred, zero_division=0))\n", 372 | "\n", 373 | "print(\"Baseline Classification Report:\")\n", 374 | "print(classification_report(baseline_true, baseline_pred, zero_division=0))\n", 375 | "\n", 376 | "# ----------------------------------------------------\n", 377 | 
"# Confusion Matrices (ICL and Baseline)\n", 378 | "labels_sorted = sorted(list(set(test_labels)))\n", 379 | "\n", 380 | "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", 381 | "\n", 382 | "# ICL Confusion Matrix\n", 383 | "cm_icl = confusion_matrix(icl_true, icl_pred, labels=labels_sorted)\n", 384 | "disp_icl = ConfusionMatrixDisplay(confusion_matrix=cm_icl, display_labels=labels_sorted)\n", 385 | "disp_icl.plot(ax=axes[0], xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n", 386 | "axes[0].set_title(\"ICL Confusion Matrix\")\n", 387 | "\n", 388 | "# Baseline Confusion Matrix\n", 389 | "cm_baseline = confusion_matrix(baseline_true, baseline_pred, labels=labels_sorted)\n", 390 | "disp_base = ConfusionMatrixDisplay(confusion_matrix=cm_baseline, display_labels=labels_sorted)\n", 391 | "disp_base.plot(ax=axes[1], xticks_rotation=45, cmap=\"Greens\", colorbar=False)\n", 392 | "axes[1].set_title(\"Baseline Confusion Matrix\")\n", 393 | "\n", 394 | "plt.tight_layout()\n", 395 | "plt.show()" 396 | ], 397 | "metadata": { 398 | "cellView": "form", 399 | "id": "W3eSReNEq4AZ" 400 | }, 401 | "execution_count": null, 402 | "outputs": [] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "source": [ 407 | "## **Summary**\n", 408 | "\n", 409 | "This demo illustrates how LLMs can perform in-context learning (ICL) for multi-class image classification, using real weld defect imagery as the target domain. By supplying the model with a small set of annotated image–label pairs—covering six defect categories such as weld cracks, burn through, slag inclusion, weld splatter, lack of fusion, and surface porosity—we show that the model can infer subtle structural cues that distinguish one defect type from another. These cues include local texture disruptions, cavity patterns, shape irregularities, and characteristic weld-surface anomalies. Crucially, the model learns entirely from the examples embedded in the prompt: no fine-tuning, no gradient updates, and no specialized vision training occurs.\n", 410 | "\n", 411 | "To contextualize performance, the ICL approach is compared against a traditional machine-learning baseline that must learn directly from pixel-level information. While the baseline relies on supervised training and engineered features, the LLM derives its classification behavior purely from pattern recognition using images and the prompt. When evaluated on unseen weld defect images, ICL consistently performs better than the baseline—demonstrating stronger generalization from just a few examples and outperforming the benchmark in overall accuracy. These results highlight the efficiency and adaptability of ICL for inspection-style tasks, especially when training data is scarce or rapid deployment is required.\n", 412 | "\n", 413 | "## **Conclusion**\n", 414 | "\n", 415 | "This demonstration shows that LLMs can successfully classify complex weld defects using only in-context visual examples, effectively acting as prompt-driven inspectors capable of recognizing defect signatures from minimal supervision. 
The ICL paradigm proves especially powerful in this setting: with only nine example images, the model generalizes to new, unseen weld defects more reliably than the supervised baseline, reflecting the model’s flexibility and its ability to internalize visual patterns without any training pipeline.\n", 416 | "\n", 417 | "Compared to traditional approaches—which typically require substantial datasets, model tuning, and iterative optimization—the ICL method provides a fast, low-overhead alternative that can be adapted to new defect categories simply by revising the prompt. Together, the comparison between ICL and the baseline model demonstrates why LLMs are well-suited for rapid, lightweight visual classification tasks such as weld inspection, quality assurance, and defect triage, offering accurate and consistent performance with minimal setup and significantly reduced computational cost." 418 | ], 419 | "metadata": { 420 | "id": "fGbNdcqOx4u_" 421 | } 422 | }, 423 | { 424 | "cell_type": "code", 425 | "source": [], 426 | "metadata": { 427 | "id": "DD3usKIxx7db" 428 | }, 429 | "execution_count": null, 430 | "outputs": [] 431 | } 432 | ], 433 | "metadata": { 434 | "colab": { 435 | "provenance": [] 436 | }, 437 | "kernelspec": { 438 | "display_name": "base", 439 | "language": "python", 440 | "name": "python3" 441 | }, 442 | "language_info": { 443 | "codemirror_mode": { 444 | "name": "ipython", 445 | "version": 3 446 | }, 447 | "file_extension": ".py", 448 | "mimetype": "text/x-python", 449 | "name": "python", 450 | "nbconvert_exporter": "python", 451 | "pygments_lexer": "ipython3", 452 | "version": "3.13.5" 453 | } 454 | }, 455 | "nbformat": 4, 456 | "nbformat_minor": 0 457 | } -------------------------------------------------------------------------------- /In_Context_Learning_Demo_LinearRegression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "6fhNlpM88MMV" 7 | }, 8 | "source": [ 9 | "# **Introduction to In-Context Learning Demo: Linear Regression**\n", 10 | "This notebook provides a step-by-step walkthrough of an interactive demonstration showing how **Large Language Models (LLMs) can act as optimizers for optimal parameter estimation**. Inspired by recent work positioning LLMs as flexible meta-learners, this experiment uses a simple linear regression setting to investigate whether an LLM can iteratively propose improved parameter values (w, b) by observing model performance, previous attempts, and its own update history. Rather than computing gradients analytically, we allow the LLM—operating purely in non-reasoning mode—to navigate the optimization landscape using only in-context cues.\n", 11 | "\n", 12 | "The goal is to illustrate how an **LLM can behave like a general optimization engine: adjusting parameters, reducing loss,** and **converging towards a solution** after multiple rounds of proposal and feedback.\n", 13 | "\n", 14 | "## **Overview**\n", 15 | "This demonstration guides you through:\n", 16 | "1. Generating synthetic training data from a ground-truth linear function\n", 17 | "2. Running an iterative loop where the LLM proposes updated parameters\n", 18 | "3. Visualizing the optimization path across iterations\n", 19 | "4. 
Rendering an animation showing the model’s gradual convergence\n", 20 | "\n", 21 | "The result is a clear visualization of how an LLM can perform optimization through pattern-based updates—even in a numerical, regression-style task.\n", 22 | "\n", 23 | "## **Background**\n", 24 | "\n", 25 | "This demo centers on the classic linear regression form:\n", 26 | "`y = w⋅x + b`\n", 27 | "\n", 28 | "We generate a set of 35 noisy points from this ground-truth linear regression line and then give the LLM only these example points and a record of previous attempts. The model is not given gradients, derivative formulas, or closed-form solutions. Instead, it receives:\n", 29 | "* The training pairs\n", 30 | "* Its previous guesses\n", 31 | "* The resulting loss values\n", 32 | "\n", 33 | "From this, the model must infer a direction of improvement and produce new (w, b) proposals.\n", 34 | "\n", 35 | "What makes this setup interesting:\n", 36 | "* The LLM acts like an optimizer rather than a solver\n", 37 | "* The updates depend solely on in-context patterns\n", 38 | "* The model relies on recognizing relationships between parameters and loss\n", 39 | "\n", 40 | "This makes it an intuitive testbed for evaluating whether LLMs can learn to optimize through context alone.\n", 41 | "\n", 42 | "## **Let’s Take a Look at an Example**\n", 43 | "\n", 44 | "The visualization below shows:\n", 45 | "* The training points used for estimation\n", 46 | "* The sequences of candidate regression lines proposed across iterations\n", 47 | "* A colorbar indicating iteration step\n", 48 | "\n", 49 | "Each line other than the initial line represents the LLM’s current guess for the underlying linear function.\n", 50 | "\n", 51 | "\n", 52 | "\n", 53 | "As the iterations progress, the proposed regression line typically shifts toward the true relationship, demonstrating how the model improves purely based on the examples and the loss given.\n", 54 | "\n", 55 | "## **LLM as the Optimizer**\n", 56 | "In this demo, we use Gemini 2.0/2.5 Flash in non-reasoning mode. This encourages the model to act not through symbolic derivation, but through pattern recognition and iterative refinement.\n", 57 | "\n", 58 | "The workflow is simple:\n", 59 | "- Provide the example (x,y) points\n", 60 | "- Supply the model’s current (w, b) guess and its loss\n", 61 | "- Ask the model to propose new values that reduce the loss\n", 62 | "- Repeat for a fixed number of iterations (50 in this demo)\n", 63 | "\n", 64 | "The LLM becomes an optimization engine—adjusting parameters using only contextual signals. This further highlights the emerging ability of LLMs to perform optimization without explicit algorithmic implementations.\n", 65 | "\n", 66 | "## **Code Overview**\n", 67 | "The implementation is organized into a modular structure, with each component responsible for a different stage of the linear regression optimization pipeline. This design separates data loading, visualization, prompt construction, and LLM inference, making the system easy to understand, modify, and extend.\n", 68 | "\n", 69 | "## **Before Running the Demo, Follow the Instructions Below**\n", 70 | "To run the full experiment:\n", 71 | "1. 
Ensure all dependencies are imported and installed.\n", 72 | "2.\tVisit Google AI Studio (https://aistudio.google.com/) to obtain your Gemini API key.\n", 73 | "3.\tOnce API key is generated, copy and paste it into the demo when prompted.\n", 74 | "\n", 75 | "#### **If having trouble creating an API Key, follow the link below for instructions:**\n", 76 | "* #### **[Instructions on How to Obtain a Gemini API Key](https://docs.google.com/document/d/17pgikIpvZCBgcRh4NNrcee-zp3Id0pU0vq2SG52KIGE/edit?usp=sharing)**\n", 77 | "\n" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": { 84 | "id": "LhzfLjFw50v5" 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "#@title **Import Necessary Libraries**\n", 89 | "import numpy as np\n", 90 | "import matplotlib.pyplot as plt\n", 91 | "from google import genai\n", 92 | "from google.genai import types\n", 93 | "import re\n", 94 | "import itertools\n", 95 | "import math\n", 96 | "import ast\n", 97 | "import getpass\n", 98 | "from tqdm import tqdm\n", 99 | "import time\n", 100 | "import ipywidgets as widgets\n", 101 | "import matplotlib.pyplot as plt\n", 102 | "import matplotlib.animation as animation\n", 103 | "from IPython.display import Image\n", 104 | "from matplotlib.animation import FuncAnimation\n", 105 | "from IPython.display import HTML" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": { 112 | "cellView": "form", 113 | "id": "d2ROw87H5j0O" 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "#@title **Setting Up Gemini Client**\n", 118 | "apikey = getpass.getpass(\"Enter your Gemini API Key: \")\n", 119 | "client = genai.Client(api_key=apikey)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "source": [ 125 | "#@title **Choose a Model**\n", 126 | "model_dropdown = widgets.Dropdown(\n", 127 | " options=[\n", 128 | " (\"Gemini 2.5 Flash\", \"gemini-2.5-flash\"),\n", 129 | " (\"Gemini 2.0 Flash\", \"gemini-2.0-flash\")\n", 130 | " ],\n", 131 | " description=\"Model:\",\n", 132 | " value=\"gemini-2.5-flash\",\n", 133 | " style={'description_width': 'initial'}\n", 134 | ")\n", 135 | "\n", 136 | "confirm_button = widgets.Button(\n", 137 | " description=\"Confirm Selection\"\n", 138 | ")\n", 139 | "\n", 140 | "output = widgets.Output()\n", 141 | "\n", 142 | "model_name = None\n", 143 | "\n", 144 | "def on_confirm_click(b):\n", 145 | " global model_name, batch_size\n", 146 | "\n", 147 | " model_name = model_dropdown.value\n", 148 | "\n", 149 | " with output:\n", 150 | " output.clear_output()\n", 151 | " print(f\"\\nSelected model: {model_name}\")\n", 152 | "\n", 153 | "confirm_button.on_click(on_confirm_click)\n", 154 | "\n", 155 | "display(model_dropdown, confirm_button, output)" 156 | ], 157 | "metadata": { 158 | "cellView": "form", 159 | "id": "ZsHMlYEy7kJe" 160 | }, 161 | "execution_count": null, 162 | "outputs": [] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": { 167 | "id": "vkEeqSfR5j0P" 168 | }, 169 | "source": [ 170 | "## **Utlity Functions**\n", 171 | "\n", 172 | "Below are the cells that act as the backbone to the entire LLM optimization pipeline.\n", 173 | "\n", 174 | "The cell `Linear Task Randomized Generation` is in charge of generating the linear task the LLM is supposed to predict. 
It uses the line generated by random values of the slope and intercept to obtain 35 random noisy data points using the line and returns the true slope and intercept, along with the 35 (x, y) pairs\n", 175 | "\n", 176 | "The next cell `Prompt Generation` builds the text prompt given to the LLM during each optimization step. It formats the training point pairs and the history of previous (w, b) guesses with their losses, then inserts them into a structured instruction block. The function returns a prompt that tells the model to propose a better (w, b) and output it strictly as a JSON object.\n", 177 | "\n", 178 | "The following cell `Model Inference` contains a function that sends the generated prompt to the LLM and retrieves its proposed (w, b) values. It calls the model, retries after a delay if a quota error occurs, and then cleans and parses the model’s JSON-only response. Finally, it extracts the \"w\" and \"b\" fields from the parsed output and returns them." 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": { 185 | "id": "9i30_Zzj1IOs" 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "#@title **Linear Task Randomized Generation**\n", 190 | "def generate_linear_task(noise_ratio=0.2):\n", 191 | " w = np.random.uniform(-5, 5) # slope\n", 192 | " b = np.random.uniform(-5, 5) # intercept\n", 193 | " n_train = 35 # 50 points\n", 194 | "\n", 195 | " x_train = np.random.uniform(-5, 5, n_train)\n", 196 | "\n", 197 | " # Clean line\n", 198 | " y_clean = w * x_train + b\n", 199 | "\n", 200 | " # Scale noise to the signal range (10% of typical variation)\n", 201 | " y_range = np.max(y_clean) - np.min(y_clean)\n", 202 | " noise_std = noise_ratio * y_range\n", 203 | "\n", 204 | " # Add Gaussian noise\n", 205 | " noise = np.random.normal(0, noise_std, n_train)\n", 206 | " y_train = y_clean + noise\n", 207 | "\n", 208 | " # Round for readability\n", 209 | " train_pairs = [(round(x, 2), round(y, 2)) for x, y in zip(x_train, y_train)]\n", 210 | " w = round(w, 2)\n", 211 | " b = round(b, 2)\n", 212 | " return train_pairs, w, b" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": { 219 | "id": "Bh5rOyr1_BjH" 220 | }, 221 | "outputs": [], 222 | "source": [ 223 | "#@title **Prompt Generation**\n", 224 | "def generate_optimizer_prompt(train_pairs, history, it):\n", 225 | "\n", 226 | " history_str = \"\\n\".join([f\"(w={w}, b={b}): loss={round(loss,2)}\" for w, b, loss in history])\n", 227 | "\n", 228 | " examples_str = \"\\n\".join([f\"({x},{y})\" for x, y in train_pairs])\n", 229 | "\n", 230 | " prompt = f\"\"\"\n", 231 | "You are helping to discover the underlying linear relationship y=w*x+b.\n", 232 | "\n", 233 | "Here are the (x, y) numerical examples:\n", 234 | "{examples_str}\n", 235 | "\n", 236 | "Here are the previously tried (w, b) values and their losses:\n", 237 | "{history_str}\n", 238 | "\n", 239 | "You are currently at iteration {it} of 50.\n", 240 | "\n", 241 | "Your task:\n", 242 | "- Infer a better (w, b) that reduces the loss.\n", 243 | "- Propose exactly one (w, b) pair.\n", 244 | "- Return ONLY a JSON object of the form:\n", 245 | " {{\"w\": number, \"b\": number}}\n", 246 | "\n", 247 | "Do not include any text, explanation, or code.\n", 248 | " \"\"\"\n", 249 | " return prompt\n" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": { 256 | "id": "i7BTEkiE_EVA" 257 | }, 258 | "outputs": [], 259 | "source": [ 260 | "#@title **Model 
Inference**\n", 261 | "def propose_wb(train_pairs, history,it):\n", 262 | " prompt = generate_optimizer_prompt(train_pairs, history,it)\n", 263 | " try:\n", 264 | " response = client.models.generate_content(\n", 265 | " model=model_name,\n", 266 | " contents=[prompt],\n", 267 | " config=types.GenerateContentConfig(\n", 268 | " thinking_config=types.ThinkingConfig(thinking_budget=0)\n", 269 | " ),\n", 270 | " )\n", 271 | " except:\n", 272 | " print(\"\\nWaiting 60 seconds for quota limit reset\")\n", 273 | " time.sleep(60)\n", 274 | " return propose_wb(train_pairs, history,it)\n", 275 | "\n", 276 | " cleaned = response.text.replace(\"```\",\"\").replace(\"json\",\"\").replace(\"python\",\"\").strip()\n", 277 | "\n", 278 | " parsed = ast.literal_eval(cleaned)\n", 279 | " w = parsed[\"w\"]\n", 280 | " b = parsed[\"b\"]\n", 281 | " return w, b\n" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": { 287 | "id": "bfSgciiSXPGj" 288 | }, 289 | "source": [ 290 | "## **Let's look at an example of the prompt**\n", 291 | "```\n", 292 | "You are trying to discover the underlying linear line from the numerical examples.\n", 293 | "\n", 294 | "Here are the example (x,y) points:\n", 295 | "{xy_points}\n", 296 | "\n", 297 | "Previously tried w, b and the loss:\n", 298 | "{optimization_history}\n", 299 | "\n", 300 | "You are currently at iteration {it} out of 50\n", 301 | "\n", 302 | "Your task:\n", 303 | "- Infer a better (w, b) that reduces the loss.\n", 304 | "- Propose exactly one (w, b) pair.\n", 305 | "- Return ONLY a JSON object of the form:\n", 306 | " {{\"w\": number, \"b\": number}}\n", 307 | "\n", 308 | "Do not include any text, explanation, or code.\n", 309 | "```\n", 310 | "\n", 311 | "This prompt shows the model several numerical point pairs along with a running history of previously attempted linear parameters, each associated with its loss. The examples illustrate the points from the true line the model is trying to find, while the optimization history demonstrates how well past guesses have performed. The model’s objective is to infer the underlying linear relationship and propose a better (w,b) that should reduce the loss in the next step.\n", 312 | "\n", 313 | "This is considered ICL because the model is never retrained or fine-tuned; it performs all reasoning directly inside the prompt. The example points and the optimization history act as temporary, in-context supervision that guides the model toward better parameter choices. By interpreting these examples and losses, the model infers the pattern, adjusts its guess, and outputs improved parameters—all within a single forward pass—demonstrating learning from context rather than parameter updates." 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": { 319 | "id": "xETg1ERbZZK5" 320 | }, 321 | "source": [ 322 | "The cell below `Main Optimization Loop` runs the full optimization loop where the LLM repeatedly proposes improved values of w and b. It begins by generating a synthetic linear task, computing an initial random guess, and recording its loss; then for each iteration, it calls the LLM to propose a new (w,b), evaluates the loss of that proposed values, and appends the result to the optimization history. After completing all iterations, it returns the training data, the full sequence of guesses, and the true underlying parameters, and finally prints the loss of the last LLM-proposed solution." 
323 |   ]
 324 |  },
 325 |  {
 326 |   "cell_type": "code",
 327 |   "execution_count": null,
 328 |   "metadata": {
 329 |    "id": "Fonu9u_D5j0P"
 330 |   },
 331 |   "outputs": [],
 332 |   "source": [
 333 |     "#@title **Main Optimization Loop**\n",
 334 |     "from tqdm import tqdm\n",
 335 |     "\n",
 336 |     "def compute_loss(train_pairs, w, b):\n",
 337 |     "    xs = np.array([p[0] for p in train_pairs])\n",
 338 |     "    ys = np.array([p[1] for p in train_pairs])\n",
 339 |     "    pred = w * xs + b\n",
 340 |     "    return float(np.mean((pred - ys) ** 2))\n",
 341 |     "\n",
 342 |     "train_pairs, w_true, b_true = generate_linear_task()\n",
 343 |     "\n",
 344 |     "def run_iterations(n_iterations=50):\n",
 345 |     "\n",
 346 |     "    history = []\n",
 347 |     "    w_initial = round(np.random.uniform(-5, 5), 2)\n",
 348 |     "    b_initial = round(np.random.uniform(-5, 5), 2)\n",
 349 |     "    loss_initial = compute_loss(train_pairs, w_initial, b_initial)\n",
 350 |     "    history.append((float(w_initial), float(b_initial), loss_initial))\n",
 351 |     "\n",
 352 |     "\n",
 353 |     "    # Add tqdm progress bar\n",
 354 |     "    for it in tqdm(range(n_iterations), desc=\"Optimizing (LLM proposing w,b)\"):\n",
 355 |     "        w, b = propose_wb(train_pairs, history, it)\n",
 356 |     "        if w is None or b is None:\n",
 357 |     "            history.append((None, None, 1e9))\n",
 358 |     "            continue\n",
 359 |     "\n",
 360 |     "        loss = compute_loss(train_pairs, w, b)\n",
 361 |     "        history.append((float(w), float(b), loss))\n",
 362 |     "\n",
 363 |     "    return train_pairs, history, w_true, b_true\n",
 364 |     "\n",
 365 |     "n_iterations = 50\n",
 366 |     "train_pairs, history, w_true, b_true = run_iterations(n_iterations)\n",
 367 |     "\n",
 368 |     "final_w, final_b, final_loss = history[-1]\n",
 369 |     "print(\"\\nFinal loss:\", final_loss)"
 370 |   ]
 371 |  },
 372 |  {
 373 |   "cell_type": "markdown",
 374 |   "metadata": {
 375 |    "id": "VdNBCw9NZut2"
 376 |   },
 377 |   "source": [
 378 |     "The cell below displays a visualization of how each LLM-proposed (w,b) evolves over the course of the ICL optimization process. It constructs every model-predicted line from the history, assigns each iteration a color on a gradient, and plots the example points provided. Then, using an animation, it draws one fitted line per frame, adding the newest line at each step and showing how the model's guesses gradually move toward the correct linear fit."
379 |   ]
 380 |  },
 381 |  {
 382 |   "cell_type": "code",
 383 |   "execution_count": null,
 384 |   "metadata": {
 385 |    "id": "x9yTQGc9DN2N"
 386 |   },
 387 |   "outputs": [],
 388 |   "source": [
 389 |     "#@title **Linear Regression Optimization Process**\n",
 390 |     "import numpy as np\n",
 391 |     "import matplotlib\n",
 392 |     "import matplotlib.pyplot as plt\n",
 393 |     "from matplotlib.animation import FuncAnimation\n",
 394 |     "from IPython.display import HTML, Image\n",
 395 |     "from matplotlib.lines import Line2D\n",
 396 |     "import matplotlib.cm as cm\n",
 397 |     "from matplotlib.colors import Normalize\n",
 398 |     "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
 399 |     "\n",
 400 |     "x = np.linspace(-10, 10, 200)\n",
 401 |     "y_true = w_true * x + b_true\n",
 402 |     "model_lines = [(w, b, w * x + b) for (w, b, _) in history]\n",
 403 |     "num_iters = len(model_lines)\n",
 404 |     "\n",
 405 |     "cmap = matplotlib.colormaps.get_cmap(\"plasma\")\n",
 406 |     "norm = Normalize(vmin=0, vmax=num_iters - 1)\n",
 407 |     "colors = [cmap(norm(i)) for i in range(num_iters)]\n",
 408 |     "\n",
 409 |     "xs = np.array([p[0] for p in train_pairs], dtype=float)\n",
 410 |     "ys = np.array([p[1] for p in train_pairs], dtype=float)\n",
 411 |     "\n",
 412 |     "fig, ax = plt.subplots(figsize=(10, 6))\n",
 413 |     "\n",
 414 |     "# ax.plot(x, y_true, \"k--\", label=\"Ground Truth\", linewidth=2)\n",
 415 |     "ax.scatter(xs, ys, facecolor='white', edgecolor='black', s=100, linewidths=2)\n",
 416 |     "ax.set_xlim(x.min(), x.max())\n",
 417 |     "ax.set_ylim(-25, 25)\n",
 418 |     "ax.set_xlabel(\"x\")\n",
 419 |     "ax.set_ylabel(\"y\")\n",
 420 |     "\n",
 421 |     "divider = make_axes_locatable(ax)\n",
 422 |     "cax = divider.append_axes(\"right\", size=\"5%\", pad=0.3)\n",
 423 |     "\n",
 424 |     "cb = plt.colorbar(\n",
 425 |     "    cm.ScalarMappable(norm=norm, cmap=cmap),\n",
 426 |     "    cax=cax\n",
 427 |     ")\n",
 428 |     "cb.set_label(\"Step\", rotation=90)\n",
 429 |     "\n",
 430 |     "plotted_lines = []\n",
 431 |     "\n",
 432 |     "def update(frame):\n",
 433 |     "    w, b, y = model_lines[frame]\n",
 434 |     "\n",
 435 |     "    # color for this iteration\n",
 436 |     "    color = colors[frame]\n",
 437 |     "\n",
 438 |     "    # newest line for this frame (drawn semi-transparent)\n",
 439 |     "    new_line, = ax.plot(x, y, color=color, alpha=0.5, linewidth=2)\n",
 440 |     "    plotted_lines.append(new_line)\n",
 441 |     "\n",
 442 |     "    # fade old lines (currently disabled)\n",
 443 |     "    for i, line in enumerate(plotted_lines):\n",
 444 |     "        age = len(plotted_lines) - 1 - i\n",
 445 |     "        # line.set_alpha(max(0.1, 1 - age * 0.07))\n",
 446 |     "\n",
 447 |     "    ax.set_title(\n",
 448 |     "\n",
 449 |     "        f\"Linear Regression Optimization | Iteration {frame+1}/{num_iters}\"\n",
 450 |     "    )\n",
 451 |     "\n",
 452 |     "    return plotted_lines\n",
 453 |     "\n",
 454 |     "ani = FuncAnimation(fig, update, frames=num_iters, interval=400, blit=False)\n",
 455 |     "ani.save(\"linear_regression_optimization_path.gif\", writer=\"pillow\", fps=5)\n",
 456 |     "\n",
 457 |     "plt.close()\n",
 458 |     "gif_path = \"linear_regression_optimization_path.gif\"\n",
 459 |     "# Display the saved GIF\n",
 460 |     "display(Image(filename=gif_path))\n"
 461 |   ]
 462 |  },
 463 |  {
 464 |   "cell_type": "code",
 465 |   "source": [
 466 |     "#@title **Comparison between true slope/intercept values & predicted slope/intercept values**\n",
 467 |     "print(f\"Ground Truth - Slope: {w_true}, Intercept: {b_true}\")\n",
 468 |     "print(f\"Predicted - Slope: {final_w}, Intercept: {final_b}\")"
 469 |   ],
 470 |   "metadata": {
 471 |    "cellView": "form",
 472 |    "id": "zCT9UEMoTREd"
 473 |   },
 474 |   "execution_count": null,
 475 |   "outputs": []
 476 |  },
 477 |  {
 478 |   "cell_type": "markdown",
 479 |   "metadata": {
 480 |    "id": 
"GVaMXybUZ_lC" 481 | }, 482 | "source": [ 483 | "The cell below then extracts the loss from each optimization step and plots how it changes over time. It creates a simple line plot where the x-axis represents the iteration number and the y-axis shows the corresponding loss value. The resulting figure visualizes whether the LLM’s proposed (w,b) values are improving by showing the trend of the loss across iterations." 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "metadata": { 490 | "cellView": "form", 491 | "id": "N3fHxaYa5j0Q" 492 | }, 493 | "outputs": [], 494 | "source": [ 495 | "#@title **Visualize Distribution of Loss Over Iterations**\n", 496 | "\n", 497 | "loss_values = [entry[2] for entry in history]\n", 498 | "\n", 499 | "plt.figure(figsize=(10,6))\n", 500 | "plt.plot(loss_values)\n", 501 | "plt.xlabel(\"Iteration\")\n", 502 | "plt.ylabel(\"Loss\")\n", 503 | "plt.title(\"Loss over Iterations\")\n", 504 | "plt.grid(True)\n", 505 | "plt.show()\n" 506 | ] 507 | }, 508 | { 509 | "cell_type": "markdown", 510 | "metadata": { 511 | "id": "Cbur2GHraYQk" 512 | }, 513 | "source": [ 514 | "## **Summary**\n", 515 | "\n", 516 | "This demo illustrates how LLMs can perform ICL by iteratively improving a linear line by using only numerical examples and loss feedback contained within the prompt. By presenting a small set of (x,y) pairs along with the history of previously attempted (w,b) values and their losses, the model is encouraged to infer the underlying linear relationship and propose better parameters without any explicit gradient formulas or training loops. Each iteration provides new context—updated guesses and their associated errors—allowing the LLM to adjust its predictions and gradually approach the true line through pattern recognition rather than parameter updates. Through this setup, we observe how the model implicitly learns to reduce error over time, effectively behaving as an optimizer driven purely by context.\n", 517 | "\n", 518 | "Beyond this specific linear regression example, the same methodology can be extended to a broad range of optimization and model-fitting problems. By embedding intermediate states, errors, or constraints into the prompt, LLMs can be guided to improve polynomial fits, tune hyperparameters, refine curve approximations, or even propose solutions to custom objective functions defined by the user. These variations highlight the versatility of the in-context learning paradigm for continuous and numerical reasoning tasks, demonstrating how LLMs can perform iterative improvement when framed through well-structured, example-driven prompts.\n", 519 | "\n", 520 | "## **Conclusion**\n", 521 | "\n", 522 | "This demonstration shows that LLMs can engage in iterative numerical reasoning using only the information supplied in the prompt—capturing the core idea of in-context learning in the setting of linear regression. Without receiving gradients, explicit optimization rules, or model updates, the LLM learns from the pattern of losses and examples to produce increasingly accurate (w,b) values over time. Although the model is not performing true mathematical optimization, its ability to approximate the underlying line through repeated contextual feedback reveals a promising direction for using LLMs as lightweight, prompt-driven optimizers. Linear regression therefore serves as a clear and intuitive example of how LLMs can adapt, refine, and improve numerical predictions purely from context." 
523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "source": [], 528 | "metadata": { 529 | "id": "q8pVXS6QVg7F" 530 | }, 531 | "execution_count": null, 532 | "outputs": [] 533 | } 534 | ], 535 | "metadata": { 536 | "colab": { 537 | "provenance": [] 538 | }, 539 | "kernelspec": { 540 | "display_name": "Python 3 (ipykernel)", 541 | "language": "python", 542 | "name": "python3" 543 | }, 544 | "language_info": { 545 | "codemirror_mode": { 546 | "name": "ipython", 547 | "version": 3 548 | }, 549 | "file_extension": ".py", 550 | "mimetype": "text/x-python", 551 | "name": "python", 552 | "nbconvert_exporter": "python", 553 | "pygments_lexer": "ipython3", 554 | "version": "3.10.9" 555 | } 556 | }, 557 | "nbformat": 4, 558 | "nbformat_minor": 0 559 | } -------------------------------------------------------------------------------- /In_Context_Learning_Demo_BinaryClassification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# **Introduction to In-Context Learning Demo: Binary Classification**\n", 7 | "\n", 8 | "## **Overview**\n", 9 | "This notebook presents a step-by-step walkthrough of an interactive demo that explores how Large Language Models (LLMs) can perform **image-based binary classification purely through in-context learning**. Inspired by the idea that LLMs can function as general pattern recognizers, this experiment uses small image snippets of hazelnuts—labeled as good or bad—to examine how effectively Gemini can internalize visual defect patterns and reproduce consistent judgments on new samples. The goal is to illustrate how an LLM—without explicit training, fine-tuning, or feature engineering—can infer quality cues such as surface texture, cracks, mold, or deformation from a few examples and generalize to unseen hazelnuts. The resulting approach **allows for online and incremental learning of perception** without any gradient-based updates. \n", 10 | "\n", 11 | "## **Background**\n", 12 | "This demo focuses on real hazelnut images categorized into two classes:\n", 13 | "- Good hazelnuts: Typically clean, symmetric, and free of surface defects.\n", 14 | "- Bad hazelnuts: Containing surface markings, cracks, discoloration, mold, dents, or structural abnormalities.\n", 15 | "\n", 16 | "Each example image serves as an example pairing:\n", 17 | "- A compact visual representation\n", 18 | "- A binary label: \"Good\" or \"Bad\"\n", 19 | "\n", 20 | "What makes this setup particularly compelling:\n", 21 | "- The data was checked before utilization to make sure the LLM cannot classify the quality of hazelnuts before ICL.\n", 22 | "- The model receives visual examples only through the prompt context\n", 23 | "- No training or gradient updates occur—classification arises from pattern matching\n", 24 | "- Prompts can include descriptions, annotations, or multi-step chains of examples\n", 25 | "- The test images require generalization, not memorization\n", 26 | "- Predictions are generated sample-by-sample, mimicking standard evaluation flows\n", 27 | "\n", 28 | "This setting provides a clear benchmark for understanding how well LLMs can perform visual classification tasks when guided only through carefully constructed prompts.\n", 29 | "\n", 30 | "## **Let's Take a Look at an Example**\n", 31 | "\n", 32 | "The illustration below shows good and bad hazelnut examples. 
Given a set of labeled samples in context, the LLM must determine the correct label for each new, unseen hazelnut image during evaluation.\n",
 33 |     "\n",
 34 |     "\n",
 35 |     "\n",
 36 |     "## **LLM as the Classifier**\n",
 37 |     "\n",
 38 |     "In this demo, we use Gemini 2.0/2.5 Flash in non-reasoning mode, prompting the model to rely on direct pattern recognition rather than symbolic explanation or analytic reasoning.\n",
 39 |     "\n",
 40 |     "The workflow proceeds as follows:\n",
 41 |     "- Provide 10 total labeled hazelnut images (5 good, 5 bad) as in-context examples\n",
 42 |     "- Send a new unlabeled image through the prompt to classify\n",
 43 |     "- Ask the model to output either \"Good\" or \"Bad\"\n",
 44 |     "\n",
 45 |     "The LLM functions as a lightweight, prompt-driven classifier—absorbing visual differences, structural patterns, and defect signatures from the in-context examples.\n",
 46 |     "\n",
 47 |     "## **Evaluation**\n",
 48 |     "Finally, we compare the model’s predicted labels against the ground-truth quality labels and report accuracy. This provides insight into how effectively an LLM can approximate visual quality-control decisions through in-context learning alone—without any dedicated training pipeline.\n",
 49 |     "\n",
 50 |     "## **Code Overview**\n",
 51 |     "The implementation is structured modularly, with each component handling a distinct stage of the ICL classification pipeline. This separation makes the system easy to modify, extend, and reuse:\n",
 52 |     "- Data loading & preprocessing: Read images, convert to model-compatible format\n",
 53 |     "- Visualization: Display sets of good/bad examples\n",
 54 |     "- Prompt construction: Insert labeled samples into few-shot prompts\n",
 55 |     "- LLM inference: Retrieve predictions one image at a time\n",
 56 |     "- Evaluation: Compute metrics and compare with a baseline (CNN)\n",
 57 |     "\n",
 58 |     "## **Before Running the Demo, Follow the Instructions Below**\n",
 59 |     "To run the full experiment:\n",
 60 |     "1. 
Ensure all dependencies are imported and installed.\n", 61 | "2.\tVisit Google AI Studio (https://aistudio.google.com/) to obtain your Gemini API key.\n", 62 | "3.\tOnce API key is generated, copy and paste it into the demo when prompted.\n", 63 | "\n", 64 | "#### **If having trouble creating an API Key, follow the link below for instructions:**\n", 65 | "* #### **[Instructions on How to Obtain a Gemini API Key](https://docs.google.com/document/d/17pgikIpvZCBgcRh4NNrcee-zp3Id0pU0vq2SG52KIGE/edit?usp=sharing)**\n", 66 | "\n" 67 | ], 68 | "metadata": { 69 | "id": "JWBl0iMOvbE_" 70 | } 71 | }, 72 | { 73 | "cell_type": "code", 74 | "source": [ 75 | "#@title **Import Necessary Libraries**\n", 76 | "import numpy as np\n", 77 | "import matplotlib.pyplot as plt\n", 78 | "from google import genai\n", 79 | "from google.genai import types\n", 80 | "import re\n", 81 | "import itertools\n", 82 | "import math\n", 83 | "import ast\n", 84 | "import json\n", 85 | "import ipywidgets as widgets\n", 86 | "from IPython.display import display\n", 87 | "from IPython.display import Image\n", 88 | "import os\n", 89 | "import getpass\n", 90 | "import time" 91 | ], 92 | "metadata": { 93 | "cellView": "form", 94 | "id": "qDV03CJIve4R" 95 | }, 96 | "execution_count": null, 97 | "outputs": [] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": { 103 | "cellView": "form", 104 | "id": "lh6d1fWygOfe" 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "#@title **Setting Up Gemini Client**\n", 109 | "apikey = getpass.getpass(\"Enter your Gemini API Key: \")\n", 110 | "client = genai.Client(api_key=apikey)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "source": [ 116 | "#@title **Choose a Model**\n", 117 | "model_dropdown = widgets.Dropdown(\n", 118 | " options=[\n", 119 | " (\"Gemini 2.5 Flash\", \"gemini-2.5-flash\"),\n", 120 | " (\"Gemini 2.0 Flash\", \"gemini-2.0-flash\"),\n", 121 | " ],\n", 122 | " description=\"Model:\",\n", 123 | " value=\"gemini-2.5-flash\",\n", 124 | " style={'description_width': 'initial'}\n", 125 | ")\n", 126 | "\n", 127 | "confirm_button = widgets.Button(\n", 128 | " description=\"Confirm Selection\"\n", 129 | ")\n", 130 | "\n", 131 | "output = widgets.Output()\n", 132 | "\n", 133 | "model_name = None\n", 134 | "\n", 135 | "def on_confirm_click(b):\n", 136 | " global model_name, batch_size\n", 137 | "\n", 138 | " model_name = model_dropdown.value\n", 139 | "\n", 140 | " with output:\n", 141 | " output.clear_output()\n", 142 | " print(f\"\\nSelected model: {model_name}\")\n", 143 | "\n", 144 | "confirm_button.on_click(on_confirm_click)\n", 145 | "\n", 146 | "display(model_dropdown, confirm_button, output)" 147 | ], 148 | "metadata": { 149 | "cellView": "form", 150 | "id": "3uUXIAcs2KxW" 151 | }, 152 | "execution_count": null, 153 | "outputs": [] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "source": [ 158 | "#@title **Download Data from GitHub**\n", 159 | "if not os.path.exists(\"intro_to_icl_data\"):\n", 160 | " !git clone https://github.com/hsiang-fu/intro_to_icl_data.git\n", 161 | "\n" 162 | ], 163 | "metadata": { 164 | "id": "6UMMqb9zX20A", 165 | "cellView": "form" 166 | }, 167 | "execution_count": null, 168 | "outputs": [] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "source": [ 173 | "The cell `ICL Binary Classification` below is responsible for running ICL binary hazelnut classification. 
It loads labeled training examples, constructs the few-shot prompt, sends the prompt to the LLM for each test image, parses the prediction, evaluates correctness, and finally reports overall accuracy.\n", 174 | "\n", 175 | "Each test image is displayed, and the model’s prediction and the ground truth label are printed. Running each cell will perform inference across all test images.\n", 176 | "\n", 177 | "After the cell runs, it performs several steps. First, it builds the ICL training examples by loading ten labeled hazelnut images (five Good and five Bad). Each example is converted into a Gemini Part so it can be embedded directly into the prompt. These form the annotated few-shot demonstrations the model uses to learn the classification pattern. Next, it constructs the full ICL prompt for each test item by including the instruction, all labeled example images, the unlabeled test image, and a rule specifying that the model should respond only with “Good” or “Bad.” Then it loads the unseen test images and their ground-truth labels. For each test image, the cell displays the image, sends the entire ICL prompt to Gemini, reads the model’s label prediction, compares it to the ground truth, and stores the results. After all images are processed, the cell computes summary metrics such as accuracy, total number of correct predictions, and incorrect predictions.\n", 178 | "\n", 179 | "Each evaluation cycle outputs the test hazelnut image, the model’s predicted label, the true label, and whether the prediction was correct. At the end, the code prints a performance summary showing the model’s accuracy across all ten test images." 180 | ], 181 | "metadata": { 182 | "id": "NkxcjLfcvzzw" 183 | } 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "source": [ 188 | "## **Let's take a look at the prompt**\n", 189 | "\n", 190 | "```\n", 191 | "You are an expert in visual inspection.\n", 192 | "Based on the examples provided, classify whether the test hazelnut is 'Good' or 'Bad'.\n", 193 | "\n", 194 | "{Example Image: Label Pairs}\n", 195 | "\n", 196 | "What is the condition of the hazelnut? Only return 'Good' or 'Bad'.\n", 197 | "```\n", 198 | "\n", 199 | "This prompt shows the model several example image–label pairs of hazelnuts, where each example presents an image of a hazelnut alongside its correct classification label (“Good” or “Bad”). These examples act as demonstrations of the visual criteria that define each category. After providing these labeled examples, the prompt then presents a new unseen and unlabeled test hazelnut and asks the model to classify it based on the pattern it inferred from the examples.\n", 200 | "\n", 201 | "This is considered ICL because the model is not being updated or fine-tuned. All “learning” occurs within the prompt itself: the example image–label pairs serve as the model’s temporary training data. By examining these demonstrations, the model infers the visual characteristics associated with “Good” versus “Bad” hazelnuts and applies that inferred pattern to the test hazelnut during the same forward pass. In other words, the model learns from context, not from parameter updates, which is exactly what defines ICL." 
202 | ], 203 | "metadata": { 204 | "id": "1dzsNIanp27e" 205 | } 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": { 211 | "id": "R5_8IkaFgOff" 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "#@title **ICL Binary Classification**\n", 216 | "\n", 217 | "print(\"\\nRunning Binary Image Classification\\n\")\n", 218 | "\n", 219 | "test_ground_truth = {\n", 220 | " \"test1\": \"Bad\",\n", 221 | " \"test2\": \"Good\",\n", 222 | " \"test3\": \"Good\",\n", 223 | " \"test4\": \"Bad\",\n", 224 | " \"test5\": \"Bad\",\n", 225 | " \"test6\": \"Bad\",\n", 226 | " \"test7\": \"Good\",\n", 227 | " \"test8\": \"Bad\",\n", 228 | " \"test9\": \"Good\",\n", 229 | " \"test10\": \"Good\",\n", 230 | "}\n", 231 | "\n", 232 | "test_images = list(test_ground_truth.keys())\n", 233 | "\n", 234 | "model_predictions = []\n", 235 | "icl_results = []\n", 236 | "\n", 237 | "train_filenames = [f\"good{i}\" for i in range(1, 6)] + [f\"bad{i}\" for i in range(1, 6)]\n", 238 | "train_labels = [\"Good\"] * 5 + [\"Bad\"] * 5\n", 239 | "\n", 240 | "def load_part(path):\n", 241 | " \"\"\"Load an image file as a Gemini Part.\"\"\"\n", 242 | " with open(path, \"rb\") as f:\n", 243 | " return types.Part.from_bytes(\n", 244 | " data=f.read(),\n", 245 | " mime_type=\"image/png\"\n", 246 | " )\n", 247 | "\n", 248 | "train_parts = [\n", 249 | " load_part(f\"intro_to_icl_data/hazelnuts/{name}.png\")\n", 250 | " for name in train_filenames\n", 251 | "]\n", 252 | "\n", 253 | "def classify_image(selected_label):\n", 254 | "\n", 255 | " test_path = f\"intro_to_icl_data/hazelnuts/{selected_label}.png\"\n", 256 | " test_part = load_part(test_path)\n", 257 | " with open(test_path, \"rb\") as f:\n", 258 | " display(Image(data=f.read()))\n", 259 | "\n", 260 | " contents = [\n", 261 | " \"You are an expert in visual inspection. \"\n", 262 | " \"Based on the examples provided, classify whether the test hazelnut is 'Good' or 'Bad'.\\n\"\n", 263 | " ]\n", 264 | "\n", 265 | " # Add example images as few-shot ICL\n", 266 | " for i, (label, part) in enumerate(zip(train_labels, train_parts)):\n", 267 | " contents.append(f\"Example {i+1}: {label}\")\n", 268 | " contents.append(part)\n", 269 | "\n", 270 | " # Add the test image\n", 271 | " contents.extend([\n", 272 | " \"What is the condition of the hazelnut? 
Only return 'Good' or 'Bad'.\",\n", 273 | " test_part,\n", 274 | " ])\n", 275 | "\n", 276 | " print(f\"\\nClassifying {selected_label}...\")\n", 277 | " try:\n", 278 | " response = client.models.generate_content(\n", 279 | " model=model_name,\n", 280 | " contents=contents,\n", 281 | " config=types.GenerateContentConfig(\n", 282 | " thinking_config=types.ThinkingConfig(thinking_budget=0)\n", 283 | " )\n", 284 | " )\n", 285 | " except Exception as e:\n", 286 | " print(\"Waiting 60 seconds for quota limit reset.\")\n", 287 | " time.sleep(60)\n", 288 | " response = client.models.generate_content(\n", 289 | " model=model_name,\n", 290 | " contents=contents,\n", 291 | " config=types.GenerateContentConfig(\n", 292 | " thinking_config=types.ThinkingConfig(thinking_budget=0)\n", 293 | " )\n", 294 | " )\n", 295 | "\n", 296 | " pred = response.text.strip()\n", 297 | " print(f\"LLM Prediction: {pred}\")\n", 298 | " return pred\n", 299 | "\n", 300 | "def evaluate_prediction(test_name, pred):\n", 301 | " true_label = test_ground_truth[test_name]\n", 302 | " correct = (pred == true_label)\n", 303 | "\n", 304 | " model_predictions.append({\n", 305 | " \"image\": test_name,\n", 306 | " \"prediction\": pred,\n", 307 | " \"label\": true_label,\n", 308 | " \"correct\": correct\n", 309 | " })\n", 310 | "\n", 311 | " icl_results.append({\n", 312 | " \"image\": test_name,\n", 313 | " \"pred\": pred,\n", 314 | " \"true\": true_label,\n", 315 | " \"correct\": correct\n", 316 | " })\n", 317 | "\n", 318 | " print(f\"Ground Truth: {true_label}\\n\")\n", 319 | "\n", 320 | "for test_name in test_images:\n", 321 | " pred = classify_image(test_name)\n", 322 | " evaluate_prediction(test_name, pred)\n", 323 | "\n", 324 | "def compute_final_metrics():\n", 325 | " if not model_predictions:\n", 326 | " print(\"No predictions collected.\")\n", 327 | " return\n", 328 | "\n", 329 | " acc = np.mean([p[\"correct\"] for p in model_predictions])\n", 330 | "\n", 331 | " print(\"\\n=== Model Performance Summary ===\")\n", 332 | " print(f\"Accuracy: {acc:.3f}\")\n", 333 | " print(f\"Total Samples: {len(model_predictions)}\")\n", 334 | " print(f\"Correct: {sum(p['correct'] for p in model_predictions)}\")\n", 335 | " print(f\"Incorrect: {len(model_predictions) - sum(p['correct'] for p in model_predictions)}\")\n", 336 | "\n", 337 | "compute_final_metrics()\n" 338 | ] 339 | }, 340 | { 341 | "cell_type": "markdown", 342 | "source": [ 343 | "This cell `Baseline CNN Binary Classification` implements the baseline model used to compare against the two in-context learning binary classification methods. This baseline relies on a standard convolutional neural network (CNN) that learns directly from pixel data to distinguish Good vs. Bad hazelnut images.\n", 344 | "\n", 345 | "After defining a small image-loading function, the code loads ten training images (five Good, five Bad) and assigns binary labels. It then loads a separate test set with known ground-truth labels. All images are resized to 128×128 and normalized. With the data prepared, the script builds a compact CNN consisting of three convolution-pooling blocks followed by dense layers, ending in a sigmoid output for binary classification.\n", 346 | "\n", 347 | "The model is trained for 100 epochs, evaluated on the same test images as the ones used in ICL, and its overall accuracy is reported. The final section generates predictions for each test image and prints whether each classification matches the true label. 
This produces a clear, image-based baseline to compare against the LLM’s in-context learning binary classification approaches." 348 | ], 349 | "metadata": { 350 | "id": "6eg0QRmsim4M" 351 | } 352 | }, 353 | { 354 | "cell_type": "code", 355 | "source": [ 356 | "\n", 357 | "#@title **Baseline CNN Binary Classification**\n", 358 | "import os\n", 359 | "import tensorflow as tf\n", 360 | "from tensorflow.keras import layers, models\n", 361 | "from tensorflow.keras.preprocessing.image import load_img, img_to_array\n", 362 | "\n", 363 | "print(\"\\nRunning Baseline CNN Model\\n\")\n", 364 | "\n", 365 | "def load_image(path, img_size=(128, 128)):\n", 366 | " img = load_img(path, target_size=img_size)\n", 367 | " arr = img_to_array(img) / 255.0\n", 368 | " return arr\n", 369 | "\n", 370 | "train_filenames = [f\"good{i}\" for i in range(1, 6)] + [f\"bad{i}\" for i in range(1, 6)]\n", 371 | "train_labels = [\"Good\"] * 5 + [\"Bad\"] * 5\n", 372 | "\n", 373 | "X_train = []\n", 374 | "y_train = []\n", 375 | "\n", 376 | "for file, label in zip(train_filenames, train_labels):\n", 377 | " path = f\"intro_to_icl_data/hazelnuts/{file}.png\"\n", 378 | " X_train.append(load_image(path))\n", 379 | " y_train.append(1 if label == \"Good\" else 0)\n", 380 | "\n", 381 | "X_train = np.array(X_train, dtype=\"float32\")\n", 382 | "y_train = np.array(y_train, dtype=\"int32\")\n", 383 | "\n", 384 | "print(f\"Loaded training images: {X_train.shape}\")\n", 385 | "\n", 386 | "test_ground_truth = {\n", 387 | " \"test1\": \"Bad\",\n", 388 | " \"test2\": \"Good\",\n", 389 | " \"test3\": \"Good\",\n", 390 | " \"test4\": \"Bad\",\n", 391 | " \"test5\": \"Bad\",\n", 392 | " \"test6\": \"Bad\",\n", 393 | " \"test7\": \"Good\",\n", 394 | " \"test8\": \"Bad\",\n", 395 | " \"test9\": \"Good\",\n", 396 | " \"test10\": \"Good\",\n", 397 | "}\n", 398 | "\n", 399 | "test_images = list(test_ground_truth.keys())\n", 400 | "\n", 401 | "X_test = []\n", 402 | "y_test = []\n", 403 | "\n", 404 | "for name in test_images:\n", 405 | " path = f\"intro_to_icl_data/hazelnuts/{name}.png\"\n", 406 | " X_test.append(load_image(path))\n", 407 | " y_test.append(1 if test_ground_truth[name] == \"Good\" else 0)\n", 408 | "\n", 409 | "X_test = np.array(X_test, dtype=\"float32\")\n", 410 | "y_test = np.array(y_test, dtype=\"int32\")\n", 411 | "\n", 412 | "print(f\"Loaded test images: {X_test.shape}\")\n", 413 | "\n", 414 | "baseline = models.Sequential([\n", 415 | " layers.Conv2D(16, (3, 3), activation=\"relu\", input_shape=(128, 128, 3)),\n", 416 | " layers.MaxPooling2D(2, 2),\n", 417 | "\n", 418 | " layers.Conv2D(32, (3, 3), activation=\"relu\"),\n", 419 | " layers.MaxPooling2D(2, 2),\n", 420 | "\n", 421 | " layers.Conv2D(64, (3, 3), activation=\"relu\"),\n", 422 | " layers.MaxPooling2D(2, 2),\n", 423 | "\n", 424 | " layers.Flatten(),\n", 425 | " layers.Dense(64, activation=\"relu\"),\n", 426 | " layers.Dense(1, activation=\"sigmoid\") # binary output\n", 427 | "])\n", 428 | "\n", 429 | "baseline.compile(\n", 430 | " optimizer=\"adam\",\n", 431 | " loss=\"binary_crossentropy\",\n", 432 | " metrics=[\"accuracy\"]\n", 433 | ")\n", 434 | "\n", 435 | "print(\"\\nTraining Baseline Model...\\n\")\n", 436 | "\n", 437 | "history = baseline.fit(\n", 438 | " X_train, y_train,\n", 439 | " epochs=100,\n", 440 | " batch_size=4,\n", 441 | " verbose=1\n", 442 | ")\n", 443 | "\n", 444 | "print(\"\\nEvaluating Baseline Model...\\n\")\n", 445 | "test_loss, test_acc = baseline.evaluate(X_test, y_test, verbose=0)\n", 446 | "print(f\"Baseline Test Accuracy: 
{test_acc:.3f}\")\n", 447 | "baseline_results = []\n", 448 | "preds = (baseline.predict(X_test) > 0.5).astype(int).flatten()\n", 449 | "\n", 450 | "print(\"\\n=== Baseline Model Predictions ===\")\n", 451 | "for i, name in enumerate(test_images):\n", 452 | " pred_label = \"Good\" if preds[i] == 1 else \"Bad\"\n", 453 | " true_label = test_ground_truth[name]\n", 454 | " correct = pred_label == true_label\n", 455 | " baseline_results.append({\n", 456 | " \"image\": name,\n", 457 | " \"pred\": pred_label,\n", 458 | " \"true\": true_label,\n", 459 | " \"correct\": pred_label == true_label\n", 460 | " })\n", 461 | " print(f\"{name}: Pred={pred_label} | True={true_label} | {'Correct' if correct else 'Wrong'}\")" 462 | ], 463 | "metadata": { 464 | "id": "SuuV-bgu6MYV" 465 | }, 466 | "execution_count": null, 467 | "outputs": [] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "source": [ 472 | "#@title **Comparing ICL(annotation aware and agnostic) vs Baseline Model**\n", 473 | "\n", 474 | "import pandas as pd\n", 475 | "\n", 476 | "# Convert to DataFrames\n", 477 | "df_icl = pd.DataFrame(icl_results)\n", 478 | "df_baseline = pd.DataFrame(baseline_results)\n", 479 | "\n", 480 | "# Merge by image name\n", 481 | "comparison = df_icl.merge(df_baseline, on=\"image\", suffixes=(\"_icl\", \"_baseline\"))\n", 482 | "\n", 483 | "display(comparison)\n", 484 | "\n", 485 | "# Compute accuracy\n", 486 | "acc_icl = comparison[\"correct_icl\"].mean()\n", 487 | "acc_base = comparison[\"correct_baseline\"].mean()\n", 488 | "\n", 489 | "print(f\"\\nICL Accuracy: {acc_icl:.3f}\")\n", 490 | "print(f\"Baseline Accuracy: {acc_base:.3f}\")\n", 491 | "\n", 492 | "# Bar Chart Comparison\n", 493 | "plt.figure(figsize=(6,4))\n", 494 | "plt.bar([\"ICL Aware\", \"Baseline CNN\"], [acc_icl, acc_base])\n", 495 | "plt.ylim(0,1)\n", 496 | "plt.title(\"Accuracy Comparison\")\n", 497 | "plt.ylabel(\"Accuracy\")\n", 498 | "plt.show()\n", 499 | "\n", 500 | "# Confusion Matrix Option\n", 501 | "from sklearn.metrics import confusion_matrix\n", 502 | "\n", 503 | "y_true = comparison[\"true_baseline\"].map({\"Good\":1,\"Bad\":0})\n", 504 | "\n", 505 | "def print_confusion_stats(cm, title=\"Confusion Matrix\"):\n", 506 | " TN, FP, FN, TP = cm.ravel()\n", 507 | "\n", 508 | " print(f\"\\n=== {title} ===\")\n", 509 | " print(f\"True Positive : {TP}\")\n", 510 | " print(f\"False Positive: {FP}\")\n", 511 | " print(f\"False Negative: {FN}\")\n", 512 | " print(f\"True Negative : {TN}\")\n", 513 | "\n", 514 | "cm_icl = confusion_matrix(y_true, comparison[\"pred_icl\"].map({\"Good\":1,\"Bad\":0}))\n", 515 | "cm_base = confusion_matrix(y_true, comparison[\"pred_baseline\"].map({\"Good\":1,\"Bad\":0}))\n", 516 | "\n", 517 | "print_confusion_stats(cm_icl, \"ICL\")\n", 518 | "print_confusion_stats(cm_base, \"Baseline CNN\")" 519 | ], 520 | "metadata": { 521 | "id": "OvcxVDPk-vMU" 522 | }, 523 | "execution_count": null, 524 | "outputs": [] 525 | }, 526 | { 527 | "cell_type": "markdown", 528 | "source": [ 529 | "# **Summary**\n", 530 | "\n", 531 | "This demo illustrates how LLMs can perform ICL for binary image classification, using hazelnuts as the target task. 
By providing the model with a small set of annotated image–label pairs, we show that it can infer the visual distinctions between “Good” and “Bad” hazelnuts—such as cracks, surface irregularities, or discoloration—without ever being trained for this specific purpose.\n",
 532 |     "\n",
 533 |     "To ground the evaluation, we compare the LLM’s classification behavior against a traditional Convolutional Neural Network (CNN) baseline trained directly on the pixel data. The CNN provides a fully supervised reference point, learning visual features from scratch and offering a stable benchmark for accuracy. When evaluated across a set of unseen test images, ICL performs better than the baseline CNN and achieves higher accuracy, classifying all of the test images correctly. Even when we increase the number of epochs the CNN is trained for, the small amount of training data is still not enough for the baseline to classify the hazelnuts reliably.\n",
 534 |     "\n",
 535 |     "The techniques demonstrated here can be extended beyond hazelnut inspection to a broad range of lightweight image-based classification tasks. With a small set of curated examples, LLMs can be prompted to generalize visual patterns related to binary categories: quality control, medical triage, material defects, etc. These results highlight how ICL can provide quick, adaptable classification capabilities without the need to train or fine-tune a neural network, significantly reducing computational costs and setup time.\n",
 536 |     "\n",
 537 |     "## **Conclusion**\n",
 538 |     "\n",
 539 |     "This demonstration shows that LLMs can perform meaningful visual classification through pattern recognition alone—interpreting defects in hazelnut images using only the examples supplied in the prompt. The ICL approach allows for stronger generalization from very small datasets (10 images). Compared to the CNN baseline, ICL performs better and more consistently, demonstrating the LLM’s flexibility and its ability to mimic human-like pattern learning. Together, the comparison between ICL and the CNN baseline offers a clear perspective on why LLMs can excel at tasks that would otherwise require traditional approaches to be retrained with far more time and much larger datasets."
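As an optional follow-up to the confusion-matrix counts printed by the comparison cell, the short sketch below derives precision, recall, and F1 from the same 2x2 matrices. It is an illustrative addition, not part of the original notebook, and assumes the sklearn layout [[TN, FP], [FN, TP]] returned by `confusion_matrix` (the same unpacking used in `print_confusion_stats`).

```python
# Optional sketch (not in the original notebook): derive precision, recall, and F1
# from a 2x2 confusion matrix in sklearn's [[TN, FP], [FN, TP]] layout.
def precision_recall_f1(cm):
    TN, FP, FN, TP = cm.ravel()
    precision = TP / (TP + FP) if (TP + FP) else 0.0
    recall = TP / (TP + FN) if (TP + FN) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return precision, recall, f1

# Example usage after running the comparison cell:
# for name, cm in [("ICL", cm_icl), ("Baseline CNN", cm_base)]:
#     p, r, f = precision_recall_f1(cm)
#     print(f"{name}: precision={p:.3f}, recall={r:.3f}, F1={f:.3f}")
```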
540 | ], 541 | "metadata": { 542 | "id": "fSvKO3MLk-ZN" 543 | } 544 | }, 545 | { 546 | "cell_type": "code", 547 | "source": [], 548 | "metadata": { 549 | "id": "9n-DrhP6pzmp" 550 | }, 551 | "execution_count": null, 552 | "outputs": [] 553 | } 554 | ], 555 | "metadata": { 556 | "colab": { 557 | "provenance": [] 558 | }, 559 | "kernelspec": { 560 | "display_name": "base", 561 | "language": "python", 562 | "name": "python3" 563 | }, 564 | "language_info": { 565 | "codemirror_mode": { 566 | "name": "ipython", 567 | "version": 3 568 | }, 569 | "file_extension": ".py", 570 | "mimetype": "text/x-python", 571 | "name": "python", 572 | "nbconvert_exporter": "python", 573 | "pygments_lexer": "ipython3", 574 | "version": "3.13.5" 575 | } 576 | }, 577 | "nbformat": 4, 578 | "nbformat_minor": 0 579 | } -------------------------------------------------------------------------------- /In_Context_Learning_Demo_TimeSeriesPrediction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# **Introduction to In-Context Learning Demo: Time-Series Prediction**\n", 7 | "\n", 8 | "In this notebook we show how Large Language Models (LLMs) can learn **to predict and complete temporal, numerical sequences from a handful of examples**. Inspired by *[Large Language Models as General Pattern Machines](https://arxiv.org/pdf/2307.04721)* (Mirchandani et al., 2023), this experiment uses training data derived from sinusoidal functions and evaluates Gemini’s ability to recognize, internalize, and extrapolate the underlying behavior. The goal is to show how an LLM —**without being explicitly told the analytical form of the function**— can observe short temporal sequences and, in turn, complete a new test sequence beyond its original domain.\n", 9 | "\n", 10 | "## **Overview**\n", 11 | "\n", 12 | "This demonstration guides you through:\n", 13 | "1. Preparing training examples for in-context learning\n", 14 | "2. Constructing prompts containing these examples\n", 15 | "3. Querying the Gemini model to predict a time series\n", 16 | "4. Generating extended values for a custom test sequence\n", 17 | "\n", 18 | "The result is a clear visualization of how LLMs behave as pattern machines—even when dealing with mathematical patterns instead of grid transformations.\n", 19 | "\n", 20 | "## **Background**\n", 21 | "This demo focuses on sequences generated from two families of functions:\n", 22 | "- a⋅x⋅sin(bx)\n", 23 | "- a⋅sin(bx)\n", 24 | "\n", 25 | "For each function type, we create 10 sample sequences, using x-values uniformly sampled from the interval [0,π]. These examples are given to the LLM as demonstrations of the kind of pattern we expect it to continue. What makes this setup interesting is that the model is not given the underlying analytical equations. In addition, the test sequence extends beyond the [0,π] range, challenging the model to generalize rather than memorize, i.e., the interpolation of values is not enough to generate good performance. This setting provides an intuitive benchmark of whether an LLM can identify the temporal behavior of a numerical time-series provided.\n", 26 | "\n", 27 | "## **Let's Take a Look at an Example**\n", 28 | "\n", 29 | "The illustration below shows a sample training sequence generated (red curve) from a sine-based function and a test sequence that slightly exceeds the seen interval. The model observes several patterns like this in the prompt. 
The test input challenges the LLM to extend the curve naturally, maintaining the oscillatory structure without explicit formulas or reasoning tools.\n", 30 | "\n", 31 | "\n", 32 | "\n", 33 | "## **LLM as the Optimizer**\n", 34 | "In this demo, we use Gemini 2.0/2.5 Flash in non-reasoning mode. This ensures that the model focuses on pattern continuation rather than symbolic mathematics or deliberate reasoning. The workflow is simple:\n", 35 | "\n", 36 | "- Provide several sine-based training sequences\n", 37 | "- Append a new test sequence that extends beyond the sampled domain\n", 38 | "- Ask the model to continue the sequence for 50, 100, or 200 additional time steps\n", 39 | "\n", 40 | "The LLM acts as a pattern completion machine, detecting the structure from the examples and applying it to the test sequence. This approach highlights how LLMs can approximate functional behavior purely through in-context pattern recognition.\n", 41 | "\n", 42 | "## **Code Overview**\n", 43 | "The implementation is organized into a modular structure, with each component responsible for a different stage of the sequence prediction pipeline. This design separates data loading, visualization, prompt construction, LLM inference, and result interpretation, making the system easy to understand, modify, and extend.\n", 44 | "\n", 45 | "## **Before Running the Demo, Follow the Instructions Below**\n", 46 | "To run the full experiment:\n", 47 | "1. Ensure all dependencies are imported and installed.\n", 48 | "2.\tVisit Google AI Studio (https://aistudio.google.com/) to obtain your Gemini API key.\n", 49 | "3.\tOnce API key is generated, copy and paste it into the demo when prompted.\n", 50 | "\n", 51 | "#### **If having trouble creating an API Key, follow the link below for instructions:**\n", 52 | "* #### **[Instructions on How to Obtain a Gemini API Key](https://docs.google.com/document/d/17pgikIpvZCBgcRh4NNrcee-zp3Id0pU0vq2SG52KIGE/edit?usp=sharing)**\n", 53 | "\n" 54 | ], 55 | "metadata": { 56 | "id": "6fhNlpM88MMV" 57 | } 58 | }, 59 | { 60 | "cell_type": "code", 61 | "source": [ 62 | "#@title **Import Necessary Libraries**\n", 63 | "import numpy as np\n", 64 | "import matplotlib.pyplot as plt\n", 65 | "from google import genai\n", 66 | "from google.genai import types\n", 67 | "import re\n", 68 | "import itertools\n", 69 | "import math\n", 70 | "import ast\n", 71 | "import getpass\n", 72 | "import matplotlib.pyplot as plt\n", 73 | "import matplotlib.animation as animation\n", 74 | "from IPython.display import Image\n", 75 | "import ipywidgets as widgets" 76 | ], 77 | "metadata": { 78 | "id": "LhzfLjFw50v5", 79 | "cellView": "form" 80 | }, 81 | "execution_count": null, 82 | "outputs": [] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": { 88 | "id": "d2ROw87H5j0O", 89 | "cellView": "form" 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "#@title **Setting Up Gemini Client**\n", 94 | "apikey = getpass.getpass(\"Enter your Gemini API Key: \")\n", 95 | "client = genai.Client(api_key=apikey)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "source": [ 101 | "#@title **Choose a Model**\n", 102 | "model_dropdown = widgets.Dropdown(\n", 103 | " options=[\n", 104 | " (\"Gemini 2.5 Flash\", \"gemini-2.5-flash\"),\n", 105 | " (\"Gemini 2.0 Flash\", \"gemini-2.0-flash\")\n", 106 | " ],\n", 107 | " description=\"Model:\",\n", 108 | " value=\"gemini-2.5-flash\",\n", 109 | " style={'description_width': 'initial'}\n", 110 | ")\n", 111 | "\n", 112 | "confirm_button = 
widgets.Button(\n", 113 | " description=\"Confirm Selection\"\n", 114 | ")\n", 115 | "\n", 116 | "output = widgets.Output()\n", 117 | "\n", 118 | "model_name = None\n", 119 | "\n", 120 | "def on_confirm_click(b):\n", 121 | " global model_name, batch_size\n", 122 | "\n", 123 | " model_name = model_dropdown.value\n", 124 | "\n", 125 | " with output:\n", 126 | " output.clear_output()\n", 127 | " print(f\"\\nSelected model: {model_name}\")\n", 128 | "\n", 129 | "confirm_button.on_click(on_confirm_click)\n", 130 | "\n", 131 | "display(model_dropdown, confirm_button, output)" 132 | ], 133 | "metadata": { 134 | "id": "THDQ-1H2cUqf", 135 | "cellView": "form" 136 | }, 137 | "execution_count": null, 138 | "outputs": [] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "source": [ 143 | "### **Let's Take a Look at an Example of the Prompt (applies to both examples!):**\n", 144 | "```\n", 145 | "You are given several example sequences of (x, y) pairs generated by different mathematical patterns.\n", 146 | "\n", 147 | "Example 1: [(0.0, 0.0), (0.025, 0.199), (0.05, 0.389), ...]\n", 148 | "Example 2: [(0.0, 0.0), (0.025, 0.397), (0.05, 0.779), ...]\n", 149 | "Example 3: [(0.0, 0.0), (0.025, 0.596), (0.05, 1.168), ...]\n", 150 | "Example 4: [(0.0, 0.0), (0.025, 0.795), (0.05, 1.558), ...]\n", 151 | "Example 5: [(0.0, 0.0), (0.025, 0.993), (0.05, 1.947), ...]\n", 152 | "Example 6: [(0.0, 0.0), (0.025, 1.192), (0.05, 2.337), ...]\n", 153 | "...\n", 154 | "\n", 155 | "The following sequence represents a partial test input: {test}\n", 156 | "\n", 157 | "Now generate the next 200 new (x, y) pairs that follow the same underlying mathematical pattern, continuing\n", 158 | "naturally from where the test sequence ends.\n", 159 | "\n", 160 | "Output a Python list of (x, y) pairs in this format, remember to close all brackets correctly:\n", 161 | "[(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x5, y5), ...]\n", 162 | "No explanations, no code, no comments — only the list.\n", 163 | "```\n", 164 | "\n", 165 | "The above prompt shows several training examples and one test query from a time series prediction task, where each example provides a short list of (x,y) pairs generated by a particular mathematical pattern (in this case, sinusoidal curves with different amplitudes or phases). Each example demonstrates how the values evolve as x increases, and the corresponding y-values reflect the underlying function that produced the sequence. The model’s objective is to infer this hidden generative rule solely from the provided examples and then apply it to the partial test sequence in order to produce the next 200 (x,y) pairs.\n", 166 | "\n", 167 | "This is considered ICL because the weights of the LLM model are not updated or fine-tuned to the training data. All “learning” occurs within the prompt itself: the example sequences become the model’s temporary training data. By observing the input–output behavior embedded in the prompt, the LLM identifies the shared mathematical pattern, generalizes it, and continues the test sequence beyond its given domain—all during a single forward pass of the LLM. In other words, the model learns from context, not from parameter updates." 
168 |   ],
 169 |   "metadata": {
 170 |    "id": "lPcPZlbaL62V"
 171 |   }
 172 |  },
 173 |  {
 174 |   "cell_type": "markdown",
 175 |   "source": [
 176 |     "# **First Example: a⋅x⋅sin(bx)**"
 177 |   ],
 178 |   "metadata": {
 179 |    "id": "O5NkOsR28l8U"
 180 |   }
 181 |  },
 182 |  {
 183 |   "cell_type": "markdown",
 184 |   "metadata": {
 185 |    "id": "vkEeqSfR5j0P"
 186 |   },
 187 |   "source": [
 188 |     "The cell below **`Example Generation`** is responsible for generating the example sequences for the sequence prediction demo. Its purpose is to output sequences of transformations of the original a⋅x⋅sin(bx) function"
 189 |   ]
 190 |  },
 191 |  {
 192 |   "cell_type": "code",
 193 |   "source": [
 194 |     "#@title **Example Generation**\n",
 195 |     "def generate_example(amplitude):\n",
 196 |     "    x = np.round(np.arange(0, 2*np.pi, 0.025), 3).tolist()\n",
 197 |     "    y = np.round(amplitude * np.array(x) * np.sin(frequency * np.array(x)), 3).tolist()\n",
 198 |     "    return list(zip(x, y))"
 199 |   ],
 200 |   "metadata": {
 201 |    "id": "9i30_Zzj1IOs",
 202 |    "cellView": "form"
 203 |   },
 204 |   "execution_count": null,
 205 |   "outputs": []
 206 |  },
 207 |  {
 208 |   "cell_type": "markdown",
 209 |   "source": [
 210 |     "The two cells below: **`Sequence Prediction of a⋅x⋅sin(bx)`** and **`Demo`** are responsible for running the full sequence-completion demo, including example generation, prompt construction, LLM inference, output parsing, and returning the predicted values. Together they act as the end-to-end execution layer for this example. The cell generates a set of example sequences based on variations of the ax·sin(bx) function and formats them into a prompt shown to the LLM. It then takes a user-supplied partial test sequence and asks the model to extend it by producing the next 50/100/200 (x, y) pairs that follow the same underlying mathematical pattern.\n",
 211 |     "\n",
 212 |     "The function performs the following steps:\n",
 213 |     "1. Takes input of frequency, amplitude, x_start, and x_end from the user\n",
 214 |     "2. Creates ten example sequences using those parameter values,\n",
 215 |     "3. Formats all examples and the test input into a structured LLM prompt,\n",
 216 |     "4. Sends the prompt to Gemini 2.5-Flash for sequence completion,\n",
 217 |     "5. Cleans and parses the raw model output into valid Python data,\n",
 218 |     "6. Extracts the resulting x-values and y-values for downstream visualization.\n",
 219 |     "\n",
 220 |     "The function will then visualize the output as an animated plot. The plot contains the original example sequences (solid gray curve), the model-generated continuation (red curve), as well as the ground-truth or expected values (dashed gray curve)."
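Besides the animated plot, one simple way to quantify how closely the continuation tracks the true curve is a mean-squared error against the analytic ground truth. The sketch below is an optional addition, not part of the original demo, and assumes the `x_pred`, `y_pred`, `amplitude`, and `frequency` variables produced by the Demo cell that follows.

```python
# Optional sketch (not part of the original demo): compare the LLM's continuation
# against the analytic ground truth a*x*sin(b*x) on the predicted x-values.
# Assumes x_pred, y_pred, amplitude, and frequency from the Demo cell below.
import numpy as np

def continuation_mse(x_pred, y_pred, amplitude, frequency):
    x = np.asarray(x_pred, dtype=float)
    y_true = amplitude * x * np.sin(frequency * x)
    return float(np.mean((np.asarray(y_pred, dtype=float) - y_true) ** 2))

# Example usage after running the Demo cell:
# print(f"Continuation MSE: {continuation_mse(x_pred, y_pred, amplitude, frequency):.4f}")
```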
221 | ], 222 | "metadata": { 223 | "id": "utEkjo--7KdY" 224 | } 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "metadata": { 230 | "id": "Fonu9u_D5j0P", 231 | "cellView": "form" 232 | }, 233 | "outputs": [], 234 | "source": [ 235 | "#@title **Sequence Prediction of ax⋅sin(bx)**\n", 236 | "def sequence_completion_demo(frequency, test):\n", 237 | " examples = [generate_example(a) for a in range(1, 11)]\n", 238 | "\n", 239 | " examples_text = \"\\n\".join([f\"Example {i+1}: {ex}\" for i, ex in enumerate(examples)])\n", 240 | "\n", 241 | " prompt = f\"\"\"\n", 242 | " You are given several example sequences of (x, y) pairs generated by different mathematical pattern.\n", 243 | "\n", 244 | " {examples_text}\n", 245 | "\n", 246 | " The following sequence represents a partial test input: {test}\n", 247 | "\n", 248 | " Now generate the next 200 new (x, y) pairs that follow the same underlying mathematical pattern, continuing naturally from where the test sequence ends.\n", 249 | "\n", 250 | " Output a Python list of [x, y] pairs in this format, remember to close all brackets correctly:\n", 251 | " [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x5, y5), ..., ..., ...]\n", 252 | " No explanations, no code, no comments — only the list.\n", 253 | " \"\"\"\n", 254 | "\n", 255 | " response = client.models.generate_content(\n", 256 | " model=model_name,\n", 257 | " contents=[prompt],\n", 258 | " config=types.GenerateContentConfig(\n", 259 | " thinking_config=types.ThinkingConfig(thinking_budget=0)\n", 260 | " ),\n", 261 | " )\n", 262 | "\n", 263 | " cleaned_response = response.text\n", 264 | " cleaned_response = cleaned_response.replace(\"python\",\"\").replace(\"```\",\"\").replace(\"json\",\"\")\n", 265 | "\n", 266 | " points = ast.literal_eval(cleaned_response.strip())\n", 267 | " x_values = [x for x, y in points]\n", 268 | " y_values = [y for x, y in points]\n", 269 | " return x_values, y_values" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": { 276 | "id": "N3fHxaYa5j0Q", 277 | "cellView": "form" 278 | }, 279 | "outputs": [], 280 | "source": [ 281 | "#@title **Demo**\n", 282 | "amplitude = float(input(\"Enter sample amplitude (e.g. 2): \"))\n", 283 | "frequency = float(input(\"Enter sample frequency (e.g. 8): \"))\n", 284 | "x_start = float(input(\"Enter start x value (e.g. 2): \"))\n", 285 | "x_end = float(input(\"Enter end x value (e.g. 
3): \"))\n", 286 | "\n", 287 | "x_test = np.round(np.arange(x_start, x_end, 0.025), 3)\n", 288 | "y_test = np.round(amplitude * x_test * np.sin(frequency * x_test), 3)\n", 289 | "\n", 290 | "test = [(float(x), float(y)) for x, y in zip(x_test, y_test)]\n", 291 | "\n", 292 | "x_pred, y_pred = sequence_completion_demo(frequency, test)\n", 293 | "\n", 294 | "fig, ax = plt.subplots(figsize=(10, 6))\n", 295 | "ax.set_xlim(0, max(x_end + 2, 2 * np.pi))\n", 296 | "ax.set_ylim(min(y_pred[:100]), max(y_pred[:100]))\n", 297 | "ax.set_xlabel(\"x\", fontsize=12)\n", 298 | "ax.set_ylabel(\"y\", fontsize=12)\n", 299 | "ax.set_title(f\"\\nSequence Completion: a·x·sin(bx)\\nAmplitude={amplitude}, Frequency={frequency}\\n\", fontsize=14)\n", 300 | "ax.grid(True, linestyle=\"--\", alpha=0.5)\n", 301 | "\n", 302 | "# Ground truth curve\n", 303 | "x_truth = np.arange(0, x_end, 0.025)\n", 304 | "y_truth = amplitude * x_truth * np.sin(frequency * x_truth)\n", 305 | "ax.plot(x_truth, y_truth, color=\"black\", alpha=0.3, linewidth=5, label=\"Ground Truth\")\n", 306 | "\n", 307 | "x_truth_2 = np.arange(x_end, x_end+2, 0.025)\n", 308 | "y_truth_2 = amplitude * x_truth_2 * np.sin(frequency * x_truth_2)\n", 309 | "ax.plot(x_truth_2, y_truth_2, color=\"gray\", linestyle = \"dashed\", alpha=0.3, linewidth=5, label=\"Ground Truth\")\n", 310 | "\n", 311 | "# Initialize animated line\n", 312 | "(pred_line,) = ax.plot([], [], color=\"red\", linewidth=5, label=\"Predicted Sequence\")\n", 313 | "\n", 314 | "# --- Update function for animation ---\n", 315 | "def update(frame):\n", 316 | " pred_line.set_data(x_pred[:frame], y_pred[:frame])\n", 317 | " return pred_line,\n", 318 | "\n", 319 | "# --- Create animation ---\n", 320 | "ani = animation.FuncAnimation(\n", 321 | " fig,\n", 322 | " update,\n", 323 | " frames=len(x_pred),\n", 324 | " interval=30, # milliseconds between frames\n", 325 | " blit=True\n", 326 | ")\n", 327 | "\n", 328 | "# --- Save as GIF ---\n", 329 | "gif_path = \"sequence_completion.gif\"\n", 330 | "ani.save(gif_path, writer=\"pillow\", fps=30)\n", 331 | "\n", 332 | "plt.close(fig) # close static plot to avoid double output\n", 333 | "\n", 334 | "# --- Display inline (works in notebooks) ---\n", 335 | "display(Image(filename=gif_path))\n", 336 | "\n" 337 | ] 338 | }, 339 | { 340 | "cell_type": "markdown", 341 | "source": [ 342 | "# **Second Example: a⋅sin(bx)**" 343 | ], 344 | "metadata": { 345 | "id": "OcIx-pTW8doX" 346 | } 347 | }, 348 | { 349 | "cell_type": "markdown", 350 | "metadata": { 351 | "id": "OPiklRgx5j0R" 352 | }, 353 | "source": [ 354 | "This second cell `Example Generation` is the same as the first and is also responsible for the generating the example sequences for the sequence prediction demo. 
However, it outputs the sequences of transformations of the original `a⋅sin(bx)` function"
 355 |   ]
 356 |  },
 357 |  {
 358 |   "cell_type": "code",
 359 |   "source": [
 360 |     "#@title **Example Generation**\n",
 361 |     "def generate_example_2(amplitude):\n",
 362 |     "    x = np.round(np.arange(0, 2*np.pi, 0.025), 3).tolist()\n",
 363 |     "    y = np.round(amplitude * np.sin(frequency * np.array(x)), 3).tolist()\n",
 364 |     "    return list(zip(x, y))\n",
 365 |     "# Reset Gemini Client\n",
 366 |     "client = genai.Client(api_key=apikey)"
 367 |   ],
 368 |   "metadata": {
 369 |    "id": "CoWk4EG062Ja",
 370 |    "cellView": "form"
 371 |   },
 372 |   "execution_count": null,
 373 |   "outputs": []
 374 |  },
 375 |  {
 376 |   "cell_type": "markdown",
 377 |   "source": [
 378 |     "The two cells below: **`Sequence Prediction of a⋅sin(bx)`** and **`Demo`** are responsible for running the full sequence-completion demo, including example generation, prompt construction, LLM inference, output parsing, and returning the predicted continuation. Together they act as the end-to-end execution layer for this example.\n",
 379 |     "\n",
 380 |     "The cell generates a set of example sequences based on variations of the a·sin(bx) function and formats them into a prompt shown to the model. It then takes a user-supplied partial test sequence and asks the model to extend it by producing the next 50/100/200 (x, y) pairs that follow the same underlying mathematical pattern.\n",
 381 |     "\n",
 382 |     "The function performs the following steps:\n",
 383 |     "1. Takes input of frequency, amplitude, x_start, and x_end from the user\n",
 384 |     "2. Creates ten example sequences using those parameter values,\n",
 385 |     "3. Formats all examples and the test input into a structured LLM prompt,\n",
 386 |     "4. Sends the prompt to Gemini 2.5-Flash for sequence completion,\n",
 387 |     "5. Cleans and parses the raw model output into valid Python data,\n",
 388 |     "6. 
Extracts the resulting x-values and y-values for downstream visualization.\n",
 389 |     "\n",
 390 |     "The function will then visualize the output as an animated plot containing:\n",
 391 |     "- The original example sequences (solid gray curve)\n",
 392 |     "- The model-generated continuation (red curve)\n",
 393 |     "- The ground-truth or expected values (dashed gray curve)"
 394 |   ],
 395 |   "metadata": {
 396 |    "id": "Noceiuw-8qpF"
 397 |   }
 398 |  },
 399 |  {
 400 |   "cell_type": "code",
 401 |   "execution_count": null,
 402 |   "metadata": {
 403 |    "id": "mrbDAPOe5j0R",
 404 |    "cellView": "form"
 405 |   },
 406 |   "outputs": [],
 407 |   "source": [
 408 |     "#@title **Sequence Prediction of a⋅sin(bx) using Gemini 2.5-Flash**\n",
 409 |     "def sequence_completion_2(frequency, test):\n",
 410 |     "    examples = [generate_example_2(a) for a in range(1, 11)]\n",
 411 |     "\n",
 412 |     "    examples_text = \"\\n\".join([f\"Example {i+1}: {ex}\" for i, ex in enumerate(examples)])\n",
 413 |     "\n",
 414 |     "    prompt = f\"\"\"\n",
 415 |     "    You are given several example sequences of (x, y) pairs generated by different mathematical patterns.\n",
 416 |     "\n",
 417 |     "    {examples_text}\n",
 418 |     "\n",
 419 |     "    The following sequence represents a partial test input: {test}\n",
 420 |     "\n",
 421 |     "    Now generate the next 100 new (x, y) pairs that follow the same underlying mathematical pattern, continuing naturally from where the test sequence ends.\n",
 422 |     "\n",
 423 |     "    Output a Python list of [x, y] pairs in this format, remember to close all brackets correctly:\n",
 424 |     "    [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x5, y5), ..., ..., ...]\n",
 425 |     "    No explanations, no code, no comments — only the list.\n",
 426 |     "    \"\"\"\n",
 427 |     "\n",
 428 |     "    # --- Gemini Query ---\n",
 429 |     "    response = client.models.generate_content(\n",
 430 |     "        model=model_name,\n",
 431 |     "        contents=[prompt],\n",
 432 |     "        config=types.GenerateContentConfig(\n",
 433 |     "            thinking_config=types.ThinkingConfig(thinking_budget=0)\n",
 434 |     "        ),\n",
 435 |     "    )\n",
 436 |     "\n",
 437 |     "    cleaned_response = response.text\n",
 438 |     "    cleaned_response = cleaned_response.replace(\"python\",\"\").replace(\"```\",\"\").replace(\"json\",\"\")\n",
 439 |     "    # --- Parse model response safely ---\n",
 440 |     "    points = ast.literal_eval(cleaned_response.strip())\n",
 441 |     "    x_values = [x for x, y in points]\n",
 442 |     "    y_values = [y for x, y in points]\n",
 443 |     "    return x_values, y_values"
 444 |   ]
 445 |  },
 446 |  {
 447 |   "cell_type": "code",
 448 |   "execution_count": null,
 449 |   "metadata": {
 450 |    "scrolled": false,
 451 |    "id": "WWfzKlro5j0R",
 452 |    "cellView": "form"
 453 |   },
 454 |   "outputs": [],
 455 |   "source": [
 456 |     "#@title **Demo**\n",
 457 |     "amplitude = float(input(\"Enter sample amplitude (e.g. 2): \"))\n",
 458 |     "frequency = float(input(\"Enter sample frequency (e.g. 8): \"))\n",
 459 |     "x_start = float(input(\"Enter start x value (e.g. 2): \"))\n",
 460 |     "x_end = float(input(\"Enter end x value (e.g. 
3): \"))\n", 461 | "\n", 462 | "x_test = np.round(np.arange(x_start, x_end, 0.025), 3)\n", 463 | "y_test = np.round(amplitude * np.sin(frequency * x_test), 3)\n", 464 | "\n", 465 | "test = [(float(x), float(y)) for x, y in zip(x_test, y_test)]\n", 466 | "\n", 467 | "x_pred, y_pred = sequence_completion_2(amplitude, test)\n", 468 | "\n", 469 | "fig, ax = plt.subplots(figsize=(10, 6))\n", 470 | "ax.set_xlim(0, max(x_end + 2, 2 * np.pi))\n", 471 | "ax.set_ylim(min(y_pred[:100]), max(y_pred[:100]))\n", 472 | "ax.set_xlabel(\"x\", fontsize=12)\n", 473 | "ax.set_ylabel(\"y\", fontsize=12)\n", 474 | "ax.set_title(f\"\\nSequence Completion: a·sin(bx)\\nAmplitude={amplitude}, Frequency={frequency}\\n\", fontsize=14)\n", 475 | "ax.grid(True, linestyle=\"--\", alpha=0.5)\n", 476 | "\n", 477 | "# Ground truth curve\n", 478 | "x_truth = np.arange(0, x_end, 0.025)\n", 479 | "y_truth = amplitude * np.sin(frequency * x_truth)\n", 480 | "ax.plot(x_truth, y_truth, color=\"black\", alpha=0.3, linewidth=5, label=\"Ground Truth\")\n", 481 | "\n", 482 | "x_truth_2 = np.arange(x_end, x_end+2, 0.025)\n", 483 | "y_truth_2 = amplitude * np.sin(frequency * x_truth_2)\n", 484 | "ax.plot(x_truth_2, y_truth_2, color=\"black\", linestyle = \"dashed\", alpha=0.3, linewidth=5)\n", 485 | "\n", 486 | "# Initialize animated line\n", 487 | "(pred_line,) = ax.plot([], [], color=\"red\", linewidth=5, label=\"Predicted Sequence\")\n", 488 | "\n", 489 | "# --- Update function for animation ---\n", 490 | "def update(frame):\n", 491 | " pred_line.set_data(x_pred[:frame], y_pred[:frame])\n", 492 | " return pred_line,\n", 493 | "\n", 494 | "# --- Create animation ---\n", 495 | "ani = animation.FuncAnimation(\n", 496 | " fig,\n", 497 | " update,\n", 498 | " frames=len(x_pred),\n", 499 | " interval=30, # milliseconds between frames\n", 500 | " blit=True\n", 501 | ")\n", 502 | "\n", 503 | "# --- Save as GIF ---\n", 504 | "gif_path = \"sequence_completion_2.gif\"\n", 505 | "ani.save(gif_path, writer=\"pillow\", fps=30)\n", 506 | "\n", 507 | "plt.close(fig) # close static plot to avoid double output\n", 508 | "\n", 509 | "# --- Display inline (works in notebooks) ---\n", 510 | "display(Image(filename=gif_path))" 511 | ] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "source": [ 516 | "## **Summary**\n", 517 | "\n", 518 | "This demo illustrates how LLMs can perform in-context learning by recognizing and extending mathematical patterns from a limited set of numerical examples. By generating sample sequences from sine-based functions and presenting them as few-shot demonstrations, we query the model to infer the underlying oscillatory structure without any explicit description of the equations involved. When prompted with a new test sequence (one that extends slightly beyond the domain of the training examples) the model is asked to continue it for dozens or even hundreds of additional steps. Although predicted sequences don't always yield 100% accuracy on all results, the trends/pattern of the predicted values always remain similar. Through this setup, we observe how the LLM captures the sinusoidal behavior embedded in the examples, relying on pattern induction rather than formal symbolic reasoning.\n", 519 | "\n", 520 | "Beyond the specific sine functions used here, the same methodology can be applied to a wide variety of numerical or functional patterns **(! spoiler: there's one demo similar to this)**. 
By crafting appropriate few-shot demonstrations, LLMs can be guided to model polynomial trends, exponential growth, periodic signals, noisy measurements, mixed-frequency curves, or even synthetic data constructed from custom rules. These extensions demonstrate the flexibility of the in-context learning framework when applied to continuous, mathematical, or time-series data, revealing how LLMs can generalize from limited examples to produce coherent and often surprisingly accurate continuations.\n", 521 | "\n", 522 | "## **Conclusion**\n", 523 | "\n", 524 | "This demonstration highlights that LLMs can generalize structured numerical behavior using only a handful of examples thereby capturing the essence of in-context learning in the domain of sequence prediction. Even without access to symbolic expressions, internal formulas, or explicit reasoning tools, the model can infer the qualitative shape of sine-like functions and extend them beyond the observed interval. While the LLM does not perform rigorous mathematical computation, its ability to internalize and reproduce oscillatory patterns shows strong potential for tasks requiring rapid approximation, pattern continuation, or short-range forecasting. Sequence prediction thus serves as a compelling testbed for understanding both the strengths and limitations of treating LLMs as general pattern machines.\n", 525 | "\n", 526 | "## **References**\n", 527 | "Mirchandani, S., Xia, F., Florence, P., Ichter, B., Driess, D., Arenas, M.G., Rao, K., Sadigh, D. and Zeng, A. (2023, December). Large Language Models as General Pattern Machines. *In Conference on Robot Learning* (pp. 2498-2518). PMLR." 528 | ], 529 | "metadata": { 530 | "id": "VTK2bac29wyi" 531 | } 532 | }, 533 | { 534 | "cell_type": "code", 535 | "source": [], 536 | "metadata": { 537 | "id": "0_ghCHQzg-3K" 538 | }, 539 | "execution_count": null, 540 | "outputs": [] 541 | } 542 | ], 543 | "metadata": { 544 | "colab": { 545 | "provenance": [] 546 | }, 547 | "kernelspec": { 548 | "display_name": "Python 3 (ipykernel)", 549 | "language": "python", 550 | "name": "python3" 551 | }, 552 | "language_info": { 553 | "codemirror_mode": { 554 | "name": "ipython", 555 | "version": 3 556 | }, 557 | "file_extension": ".py", 558 | "mimetype": "text/x-python", 559 | "name": "python", 560 | "nbconvert_exporter": "python", 561 | "pygments_lexer": "ipython3", 562 | "version": "3.10.9" 563 | } 564 | }, 565 | "nbformat": 4, 566 | "nbformat_minor": 0 567 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. 
You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 
83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. 
For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 
204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 
268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. 
But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 
387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. 
You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. 
"Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 
564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 
628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /In_Context_Learning_Demo_LanguageConditionedTimeSeriesPrediction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# **Introduction to In-Context Learning Demo: Language Conditioned Time Series Prediction**\n", 7 | "\n", 8 | "## **Overview**\n", 9 | "In this notebook we show how Large Language Models (LLMs) can **learn and forecast physiological patterns from real-world time-series data** while also taking into account information provided in language form. This experiment uses full-day glucose measurements paired with meal logs to examine how effectively Gemini can internalize meal–glucose dynamics and predict future glucose trajectories. Traditional approaches to blood glucose prediction only take into account numerical data. However, knowing what foods the user take is critical in order to make realistic predictions of the blood glucose levels. **We leverage the fact that ICL uses LLMs in order to incorporate linguistic information about the user's diet into the prediction process**. 
We also highlight that LLMs already have substantial medical knowledge about glucose levels, glycemic indices, and the role of diet in healthy blood sugar levels. The goal is to show how an LLM can observe natural daily glucose patterns, understand how meals impact glycemic response, and extend that behavior into an unseen test day.\n", 10 | "\n", 11 | "## **Background**\n", 12 | "This demo focuses on real or simulated full-day sequences that include:\n", 13 | "- Glucose measurements recorded over the course of a day\n", 14 | "- Meal logs that act as the driving input events\n", 15 | "- An explicit mention of the Glycemic Index (GI), so that the LLM can draw on extra nutritional context\n", 16 | "\n", 17 | "Each daily sequence captures natural dependencies such as meal timing, carbohydrate response, and gradual glucose decay. These sequences form the in-context “training set” that the LLM observes before generating predictions.\n", 18 | "\n", 19 | "What makes this setup particularly interesting:\n", 20 | "- The model sees realistic temporal meal–glucose dynamics\n", 21 | "- Prompts can include or exclude mention of GI without changing the data\n", 22 | "- The LLM is never given explicit physiological equations\n", 23 | "- The test day requires generalization, not memorization\n", 24 | "- Predictions are generated step-by-step (using time points), mimicking real forecasting behavior\n", 25 | "\n", 26 | "This setting offers an intuitive benchmark for understanding whether an LLM can learn biological response patterns through context alone.\n", 27 | "\n", 28 | "## **Let's Take a Look at an Example**\n", 29 | "The illustration below shows an example of a test day containing glucose readings and meal times. Starting from the beginning of a new day, the LLM must forecast the glucose readings that follow. The model receives several complete daily sequences of previous days as examples, each showing how meals shape the glucose curve.\n", 30 | "\n", 31 | "\n", 32 | "\n", 33 | "## **LLM as the Forecaster**\n", 34 | "\n", 35 | "In this demo, we use Gemini 2.0/2.5 Flash in non-reasoning mode, ensuring the model focuses on direct pattern continuation rather than analytical or symbolic interpretation.\n", 36 | "\n", 37 | "The workflow is straightforward:\n", 38 | "- Provide several full-day glucose + meal sequences\n", 39 | "- Append a future day with partial context\n", 40 | "- Ask the model to predict one glucose value at a time\n", 41 | "- Iterate until the full sequence is completed\n", 42 | "\n", 43 | "Two prompting strategies are tested:\n", 44 | "- Mention of Glycemic Index: The prompt instructs the model to factor in the Glycemic Index\n", 45 | "- No mention of Glycemic Index: No explicit nutritional guidance is included\n", 46 | "\n", 47 | "The LLM acts as a sequence forecaster, absorbing temporal structure, meal-driven spikes, and recovery patterns from the examples and applying them to the new day.\n", 48 | "\n", 49 | "## **Evaluation**\n", 50 | "Finally, we compare these predictions against a Gaussian Process Regression baseline and evaluate performance using RMSE and MSE, providing insight into how well LLMs can approximate physiological patterns through in-context learning alone (a rough, illustrative sketch of such a baseline is included below).\n", 51 | "\n", 52 | "## **Code Overview**\n", 53 | "The implementation is organized into a modular structure, with each component responsible for a different stage of the glucose forecasting pipeline. 
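Before the implementation details, here is a rough sketch of the kind of Gaussian Process Regression baseline referred to in the **Evaluation** section above. This is not the notebook's implementation (the actual baseline appears later in the notebook and may differ); it assumes scikit-learn is available, and the arrays `train_minutes`, `train_glucose`, `test_minutes`, and `test_glucose` are illustrative placeholders for the per-day values loaded from the CSV files.

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel

# Illustrative placeholder data: minutes since midnight and glucose readings.
train_minutes = np.array([420, 480, 540, 600, 660, 720], dtype=float)[:, None]
train_glucose = np.array([95, 110, 140, 120, 105, 100], dtype=float)
test_minutes = np.array([450, 510, 570, 630, 690], dtype=float)[:, None]
test_glucose = np.array([100, 150, 130, 110, 102], dtype=float)

# Smooth trend (RBF kernel) plus observation noise (WhiteKernel).
kernel = 1.0 * RBF(length_scale=60.0) + WhiteKernel(noise_level=5.0)
gpr = GaussianProcessRegressor(kernel=kernel, normalize_y=True)
gpr.fit(train_minutes, train_glucose)
pred_glucose = gpr.predict(test_minutes)

# Same metrics used for the LLM forecasts.
mse = float(np.mean((test_glucose - pred_glucose) ** 2))
rmse = float(np.sqrt(mse))
print(f"MSE={mse:.2f}, RMSE={rmse:.2f}")
```

In the notebook itself, these arrays would come from `glucose_levels` for the example days and the held-out test day, so the baseline and the LLM are scored on exactly the same points.
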
This design separates data preparation, visualization, prompt construction, LLM inference, and performance evaluation, making the system easy to understand, modify, and extend.\n", 54 | "\n", 55 | "## **Before Running the Demo, Follow the Instructions Below**\n", 56 | "To run the full experiment:\n", 57 | "1. Ensure all dependencies are imported and installed.\n", 58 | "2.\tVisit Google AI Studio (https://aistudio.google.com/) to obtain your Gemini API key.\n", 59 | "3.\tOnce API key is generated, copy and paste it into the demo when prompted.\n", 60 | "\n", 61 | "#### **If having trouble creating an API Key, follow the link below for instructions:**\n", 62 | "* #### **[Instructions on How to Obtain a Gemini API Key](https://docs.google.com/document/d/17pgikIpvZCBgcRh4NNrcee-zp3Id0pU0vq2SG52KIGE/edit?usp=sharing)**\n", 63 | "\n", 64 | "### ***Note: Run all three demos to get comparison results**" 65 | ], 66 | "metadata": { 67 | "id": "iZkM9dGZ5OOD" 68 | }, 69 | "id": "iZkM9dGZ5OOD" 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "id": "6e55c7eb", 75 | "metadata": { 76 | "cellView": "form", 77 | "id": "6e55c7eb" 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "#@title **Import Necessary Libraries**\n", 82 | "import numpy as np\n", 83 | "import pandas as pd\n", 84 | "import matplotlib.pyplot as plt\n", 85 | "from google import genai\n", 86 | "from google.genai import types\n", 87 | "import re\n", 88 | "import itertools\n", 89 | "import math\n", 90 | "import ast\n", 91 | "import os\n", 92 | "import getpass\n", 93 | "import ipywidgets as widgets\n", 94 | "import time" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "source": [ 100 | "#@title **Setting Up Gemini Client**\n", 101 | "apikey = getpass.getpass(\"Enter your Gemini API Key: \")\n", 102 | "client = genai.Client(api_key=apikey)\n", 103 | "comparison_results = {}\n", 104 | "config = types.GenerateContentConfig(\n", 105 | " thinking_config=types.ThinkingConfig(\n", 106 | " thinking_budget=0\n", 107 | " )\n", 108 | ")" 109 | ], 110 | "metadata": { 111 | "cellView": "form", 112 | "id": "uDeyfej15J9O" 113 | }, 114 | "id": "uDeyfej15J9O", 115 | "execution_count": null, 116 | "outputs": [] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "source": [ 121 | "#@title **Choose a Model**\n", 122 | "model_dropdown = widgets.Dropdown(\n", 123 | " options=[\n", 124 | " (\"Gemini 2.5 Flash\", \"gemini-2.5-flash\"),\n", 125 | " (\"Gemini 2.0 Flash\", \"gemini-2.0-flash\"),\n", 126 | " ],\n", 127 | " description=\"Model:\",\n", 128 | " value=\"gemini-2.5-flash\",\n", 129 | " style={'description_width': 'initial'}\n", 130 | ")\n", 131 | "\n", 132 | "confirm_button = widgets.Button(\n", 133 | " description=\"Confirm Selection\"\n", 134 | ")\n", 135 | "\n", 136 | "output = widgets.Output()\n", 137 | "\n", 138 | "model_name = None\n", 139 | "\n", 140 | "def on_confirm_click(b):\n", 141 | " global model_name, batch_size\n", 142 | "\n", 143 | " model_name = model_dropdown.value\n", 144 | "\n", 145 | " with output:\n", 146 | " output.clear_output()\n", 147 | " print(f\"\\nSelected model: {model_name}\")\n", 148 | "\n", 149 | "confirm_button.on_click(on_confirm_click)\n", 150 | "\n", 151 | "display(model_dropdown, confirm_button, output)" 152 | ], 153 | "metadata": { 154 | "cellView": "form", 155 | "id": "EGRpvXm5ojqQ" 156 | }, 157 | "id": "EGRpvXm5ojqQ", 158 | "execution_count": null, 159 | "outputs": [] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "source": [ 164 | "#@title **Download Data from GitHub**\n", 165 | "if 
not os.path.exists(\"intro_to_icl_data\"):\n", 166 | " !git clone https://github.com/hsiang-fu/intro_to_icl_data.git\n", 167 | "\n", 168 | "glucose_levels = pd.read_csv(\"intro_to_icl_data/blood_sugar_data.csv\")\n", 169 | "food_data = pd.read_csv(\"intro_to_icl_data/food_data.csv\")" 170 | ], 171 | "metadata": { 172 | "cellView": "form", 173 | "id": "Qe0-gZkR38oQ" 174 | }, 175 | "id": "Qe0-gZkR38oQ", 176 | "execution_count": null, 177 | "outputs": [] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "source": [ 182 | "This cell, `Model Inference Function`, is an essential part of the demo: it acts as the interface between our program and the language model, sending prompts to the LLM and returning the model’s generated output. Every prompt used throughout the workflow is passed through this function, making it the central mechanism for all model communication." 183 | ], 184 | "metadata": { 185 | "id": "2yC-jn4PpbDi" 186 | }, 187 | "id": "2yC-jn4PpbDi" 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "id": "d6846bc9", 193 | "metadata": { 194 | "id": "d6846bc9" 195 | }, 196 | "outputs": [], 197 | "source": [ 198 | "#@title **Model Inference Function**\n", 199 | "def predict_diabetes_level(prompt):  # uses the persistent 'chat' session created in the example cells below\n", 200 | " response = chat.send_message(prompt)\n", 201 | " return response.text" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "source": [ 207 | "The two cells below are responsible for running the full interactive glucose-forecasting demo, including user example selection, prompt construction, LLM inference, timeline processing, visualization of predicted vs. expected values, and parsing model outputs. Together they act as the end-to-end execution and UI layer of the notebook.\n", 208 | "\n", 209 | "**User Interface**\n", 210 | "- The cell provides a simple interface: a user prompt asking how many days of history to include as examples. This makes the demo interactive, allowing users to choose the size of the in-context learning window without modifying any code.\n", 211 | "\n", 212 | "**What the Cell Does**\n", 213 | "After the user selects the number of example days, the cell:\n", 214 | "1. Builds ICL Training Examples: Formats glucose + food logs from the chosen days into few-shot examples.\n", 215 | "2. Constructs the Initial Warm-Up Prompt: Creates a persistent LLM chat session and sends the examples as context.\n", 216 | "3. Prepares the Test Data: Loads food and glucose data for the next unseen day.\n", 217 | "4. Builds a Timeline: Merges meal events and glucose events into a single chronological sequence.\n", 218 | "5. 
Runs the Prediction Loop: for each glucose timestamp, the cell determines the relevant meal context, constructs a prediction prompt that incorporates this context along with the glycemic index, sends the prompt to the LLM, and finally parses and stores the model’s predicted glucose value.\n", 219 | "\n", 220 | "**What Each Example Outputs**\n", 221 | "A visualization containing:\n", 222 | "- The expected output (Black Curve)\n", 223 | "- The LLM-generated output (Red Curve)\n", 224 | "- The MSE and RMSE of predicted values\n", 225 | "\n", 226 | "**How Do Examples 1 and 2 Differ?**\n", 227 | "\n", 228 | "As mentioned above, we tested two prediction cases for ICL:\n", 229 | "- **Example 1**: GI-aware; the prompt instructs the model to factor in the Glycemic Index\n", 230 | "- **Example 2**: GI-agnostic; no explicit nutritional guidance is included" 231 | ], 232 | "metadata": { 233 | "id": "afSzg_Qpq86u" 234 | }, 235 | "id": "afSzg_Qpq86u" 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "source": [ 240 | "### **Let's Take a Look at an Example of a Prompt:**\n", 241 | "\n", 242 | "**Initial Prompt**\n", 243 | "```\n", 244 | "You are the world's best glucose level predictor for adults.\n", 245 | "Here are some examples of my glucose data for the previous days:\n", 246 | "{examples}\n", 247 | "\n", 248 | "*Note: I am not diabetic.\n", 249 | "```\n", 250 | "\n", 251 | "This prompt establishes the role of the LLM and provides it with the context needed for in-context learning. By telling the model that it is the “world’s best glucose level predictor,” the prompt sets an expectation that it should focus on accurate forecasting. The following line introduces several example days of glucose data—supplied in place of {examples}—which act as the training demonstrations from which the model learns the underlying patterns of glucose behavior. The note provides additional context about the individual.\n", 252 | "\n", 253 | "Then, to predict the individual's glucose level at the next time point, we pass in a prompt assembled from the pieces below.\n", 254 | "```\n", 255 | "My last glucose level was {glucose_level} at {previous_time_point}.\n", 256 | "```\n", 257 | "**plus one of the following:**\n", 258 | "\n", 259 | "Option 1:\n", 260 | "```\n", 261 | "I just ate ..., ..., ... 
at {time} on {date}, using the **Glycemic Index** of the meal, predict my glucose level at {next_time point}\n", 262 | "```\n", 263 | "Option 2:\n", 264 | "```\n", 265 | "I previously ate at {time} on {date}, infer the **Glycemic Index** values of the meal, predict my glucose level at {next_time point}\n", 266 | "```\n", 267 | "Option 3:\n", 268 | "```\n", 269 | "I haven't eaten anything yet, using the **Glycemic Index** of the meal, predict my glucose level at {next_time point}\n", 270 | "```\n", 271 | "**plus this last section**\n", 272 | "```\n", 273 | "Output only a Python list with a single dictionary in this format:\n", 274 | "[{{'Date': '{target_date}', 'Time': '{t}', 'Blood Sugar Level': value}}]\n", 275 | "No extra text, no explanations, no code, only the list.\n", 276 | "```\n", 277 | "This combined prompt gives the LLM clear situational context (recent glucose level, meal events, and timing) along with strict output formatting instructions, enabling it to generate accurate and consistent predictions for the next time point.\n" 278 | ], 279 | "metadata": { 280 | "id": "TnJpvB3BuiXB" 281 | }, 282 | "id": "TnJpvB3BuiXB" 283 | }, 284 | { 285 | "cell_type": "code", 286 | "source": [ 287 | "#@title ## **Example 1: LLM Forecasting w/ mention of the Glycemic Index**\n", 288 | "output = []\n", 289 | "days_in_example = int(input(\"Enter how many days you want to use as examples (e.g. 1-30): \"))\n", 290 | "\n", 291 | "days = pd.unique(glucose_levels[\"Date\"])[:days_in_example]\n", 292 | "examples = [\n", 293 | " f\"Here is an example of my glucose level data for {day}:\\n\"\n", 294 | " f\"{glucose_levels.loc[glucose_levels['Date'] == day]}\\n\"\n", 295 | " f\"and examples of the food I consumed at the time:\\n\"\n", 296 | " f\"{food_data.loc[food_data['Date'] == day]}\"\n", 297 | " for day in days\n", 298 | "]\n", 299 | "examples_text = \"\\n\\n\".join(examples)\n", 300 | "\n", 301 | "initial_prompt = f\"\"\"\n", 302 | " You are the world's best glucose level predictor for adults.\n", 303 | " Here are some examples of my glucose data for the previous days:\n", 304 | " {examples_text}\n", 305 | "\n", 306 | " *Note: I am not diabetic.\n", 307 | " \"\"\"\n", 308 | "\n", 309 | "chat = client.chats.create(model=model_name, config=config)\n", 310 | "\n", 311 | "chat.send_message(initial_prompt)\n", 312 | "\n", 313 | "target_date = pd.unique(food_data[\"Date\"])[days_in_example]\n", 314 | "\n", 315 | "target_idx = list(pd.unique(food_data[\"Date\"])).index(target_date)\n", 316 | "dates_to_use = pd.unique(food_data[\"Date\"])[target_idx:target_idx+2]\n", 317 | "food_date_data = food_data[food_data[\"Date\"].isin(dates_to_use)].copy()\n", 318 | "\n", 319 | "glucose_date_data = glucose_levels[glucose_levels[\"Date\"] == target_date].copy()\n", 320 | "\n", 321 | "meals = food_date_data[[\"Date\", \"Time\", \"Meal Type\", \"Food Items\"]].copy()\n", 322 | "meals[\"Event\"] = \"Meal\"\n", 323 | "\n", 324 | "glucose = glucose_date_data[[\"Date\", \"Time\", \"Blood Sugar Level\"]].copy()\n", 325 | "glucose[\"Event\"] = \"Glucose\"\n", 326 | "\n", 327 | "timeline = pd.concat([meals, glucose], ignore_index=True)\n", 328 | "\n", 329 | "timeline[\"DateTime\"] = pd.to_datetime(\n", 330 | " timeline[\"Date\"] + \" \" + timeline[\"Time\"],\n", 331 | " format=\"%d-%m-%Y %I:%M %p\",\n", 332 | " errors=\"coerce\"\n", 333 | ")\n", 334 | "\n", 335 | "timeline = timeline.sort_values(\"DateTime\").reset_index(drop=True)\n", 336 | "\n", 337 | "last_meal = None\n", 338 | "previous_context_meal = None\n", 339 | 
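"# last_context_meal_time records which meal has already been described to the model, so repeat timestamps fall back to 'previously ate' phrasing\n",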
"last_context_meal_time = None\n", 340 | "current_glucose = glucose_date_data[\"Blood Sugar Level\"].iloc[0]\n", 341 | "i = glucose_date_data[\"Time\"].iloc[0]\n", 342 | "\n", 343 | "print(f\"Running Prediction Pipeline\")\n", 344 | "for iteration, row in timeline.iterrows():\n", 345 | " if row[\"Event\"] == \"Meal\":\n", 346 | " last_meal = row\n", 347 | " continue\n", 348 | "\n", 349 | " if row[\"Event\"] == \"Glucose\":\n", 350 | " t = row[\"Time\"]\n", 351 | "\n", 352 | " # Determine if the last_meal is actually a *new* meal\n", 353 | " if last_meal is not None and last_meal[\"Time\"] != last_context_meal_time:\n", 354 | " # New meal detected\n", 355 | " meal_food = last_meal[\"Food Items\"]\n", 356 | " meal_type = last_meal[\"Meal Type\"]\n", 357 | " meal_time = last_meal[\"Time\"]\n", 358 | "\n", 359 | " context_meal = f\"I just ate {meal_food} at {meal_time}\"\n", 360 | "\n", 361 | " last_context_meal_time = meal_time\n", 362 | "\n", 363 | " else:\n", 364 | " # No new meal — use \"previously ate\" logic\n", 365 | " if previous_context_meal and previous_context_meal.startswith(\"I just ate\"):\n", 366 | " context_meal = f\"I previously ate a meal at {meal_time}\"\n", 367 | " elif previous_context_meal and previous_context_meal.startswith(\"I previously ate\"):\n", 368 | " context_meal = f\"I previously ate a meal at {meal_time}\"\n", 369 | " else:\n", 370 | " context_meal = \"I haven't eaten anything yet.\"\n", 371 | "\n", 372 | " previous_context_meal = context_meal\n", 373 | "\n", 374 | "\n", 375 | " prompt = f\"\"\"\n", 376 | " My last glucose level was {current_glucose} at {i}.\n", 377 | " {context_meal} on {target_date}, infer the **Glycemic Index** values of the meal, predict my glucose level at {t}\n", 378 | "\n", 379 | " Output only a Python list with a single dictionary in this format:\n", 380 | " [{{'Date': '{target_date}', 'Time': '{t}', 'Blood Sugar Level': value}}]\n", 381 | " No extra text, no explanations, no code, only the list.\n", 382 | " \"\"\"\n", 383 | "\n", 384 | " try:\n", 385 | " print(f\"Running Iteration {iteration}\")\n", 386 | " result = predict_diabetes_level(prompt).replace(\"`\", \"\").replace(\"python\", \"\").replace(\"json\", \"\")\n", 387 | " except Exception as e:\n", 388 | " print(\"Waiting 60 seconds due to quota limits\")\n", 389 | " time.sleep(60)\n", 390 | " print(f\"Rerunning Iteration {iteration}\")\n", 391 | " result = predict_diabetes_level(prompt).replace(\"`\", \"\").replace(\"python\", \"\").replace(\"json\", \"\")\n", 392 | "\n", 393 | " try:\n", 394 | " result_list = ast.literal_eval(result.strip())\n", 395 | " output.extend(result_list)\n", 396 | " current_glucose = result_list[-1][\"Blood Sugar Level\"]\n", 397 | " except Exception as e:\n", 398 | " print(\"Parsing error:\", result, e)\n", 399 | "\n", 400 | " i = t\n", 401 | "\n", 402 | "df2 = pd.DataFrame(output)\n", 403 | "\n", 404 | "df1 = glucose_levels.loc[glucose_levels['Date'] == target_date]\n", 405 | "\n", 406 | "merged = pd.merge(df1, df2, on=\"Time\", suffixes=(\"_expected\", \"_predicted\"))\n", 407 | "\n", 408 | "merged[\"Blood Sugar Level_predicted\"] = merged[\"Blood Sugar Level_predicted\"].astype(int)\n", 409 | "\n", 410 | "merged[\"Error\"] = merged[\"Blood Sugar Level_expected\"] - merged[\"Blood Sugar Level_predicted\"]\n", 411 | "merged[\"Absolute Error\"] = merged[\"Error\"].abs()\n", 412 | "merged[\"Squared Error\"] = merged[\"Error\"] ** 2\n", 413 | "\n", 414 | "mae = merged[\"Absolute Error\"].mean()\n", 415 | "mse = merged[\"Squared Error\"].mean()\n", 416 
| "rmse = np.sqrt(mse)\n", 417 | "\n", 418 | "comparison_results[\"MSE ICL (w/ mention)\"] = mse\n", 419 | "comparison_results[\"RMSE ICL (w/ mention)\"] = rmse\n", 420 | "\n", 421 | "print(\"Mean Squared Error (MSE):\", mse)\n", 422 | "print(\"Root Mean Squared Error (RMSE):\", rmse)\n", 423 | "\n", 424 | "fig, ax1 = plt.subplots(figsize=(10,7))\n", 425 | "\n", 426 | "ax1.plot(merged[\"Time\"], merged[\"Blood Sugar Level_expected\"], label=\"Expected Values\", color = \"black\", marker=\"o\")\n", 427 | "ax1.plot(merged[\"Time\"], merged[\"Blood Sugar Level_predicted\"], label=\"Predicted Values\", color = \"red\", marker=\"o\")\n", 428 | "ax1.set_xlabel(\"Time\")\n", 429 | "ax1.set_ylabel(\"Blood Sugar Level\")\n", 430 | "ax1.tick_params(axis=\"x\", rotation=45)\n", 431 | "\n", 432 | "lines, labels = ax1.get_legend_handles_labels()\n", 433 | "\n", 434 | "plt.title(\"Blood Sugar Levels (w/ mention of glycemic index)\")\n", 435 | "plt.legend()\n", 436 | "plt.grid(True)\n", 437 | "plt.show()" 438 | ], 439 | "metadata": { 440 | "id": "DbytXWjuLVal", 441 | "cellView": "form" 442 | }, 443 | "id": "DbytXWjuLVal", 444 | "execution_count": null, 445 | "outputs": [] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": null, 450 | "id": "7edeeb4c", 451 | "metadata": { 452 | "id": "7edeeb4c", 453 | "cellView": "form" 454 | }, 455 | "outputs": [], 456 | "source": [ 457 | "#@title ## **Example 2: LLM Forecasting w/o mention of the Glycemic Index**\n", 458 | "output = []\n", 459 | "days_in_example = int(input(\"Enter how many days you want to use as examples (use the same as above): \"))\n", 460 | "\n", 461 | "days = pd.unique(glucose_levels[\"Date\"])[:days_in_example]\n", 462 | "examples = [\n", 463 | " f\"Here is an example of my glucose level data for {day}:\\n\"\n", 464 | " f\"{glucose_levels.loc[glucose_levels['Date'] == day]}\\n\"\n", 465 | " f\"and examples of the food I consumed at the time:\\n\"\n", 466 | " f\"{food_data.loc[food_data['Date'] == day]}\"\n", 467 | " for day in days\n", 468 | "]\n", 469 | "examples_text = \"\\n\\n\".join(examples)\n", 470 | "\n", 471 | "initial_prompt = f\"\"\"\n", 472 | " You are the world's best glucose level predictor for adults.\n", 473 | " Here are some examples of my glucose data for the previous days:\n", 474 | " {examples_text}\n", 475 | "\n", 476 | " *Note: I am not diabetic.\n", 477 | " \"\"\"\n", 478 | "\n", 479 | "chat = client.chats.create(model=model_name,config=config)\n", 480 | "\n", 481 | "chat.send_message(initial_prompt)\n", 482 | "\n", 483 | "target_date = pd.unique(food_data[\"Date\"])[days_in_example]\n", 484 | "\n", 485 | "target_idx = list(pd.unique(food_data[\"Date\"])).index(target_date)\n", 486 | "dates_to_use = pd.unique(food_data[\"Date\"])[target_idx:target_idx+2]\n", 487 | "food_date_data = food_data[food_data[\"Date\"].isin(dates_to_use)].copy()\n", 488 | "\n", 489 | "glucose_date_data = glucose_levels[glucose_levels[\"Date\"] == target_date].copy()\n", 490 | "\n", 491 | "meals = food_date_data[[\"Date\", \"Time\", \"Meal Type\", \"Food Items\"]].copy()\n", 492 | "meals[\"Event\"] = \"Meal\"\n", 493 | "\n", 494 | "glucose = glucose_date_data[[\"Date\", \"Time\", \"Blood Sugar Level\"]].copy()\n", 495 | "glucose[\"Event\"] = \"Glucose\"\n", 496 | "\n", 497 | "timeline = pd.concat([meals, glucose], ignore_index=True)\n", 498 | "\n", 499 | "timeline[\"DateTime\"] = pd.to_datetime(\n", 500 | " timeline[\"Date\"] + \" \" + timeline[\"Time\"],\n", 501 | " format=\"%d-%m-%Y %I:%M %p\",\n", 502 | " errors=\"coerce\"\n", 
503 | ")\n", 504 | "\n", 505 | "timeline = timeline.sort_values(\"DateTime\").reset_index(drop=True)\n", 506 | "\n", 507 | "last_meal = None\n", 508 | "previous_context_meal = None\n", 509 | "last_context_meal_time = None\n", 510 | "current_glucose = glucose_date_data[\"Blood Sugar Level\"].iloc[0]\n", 511 | "i = glucose_date_data[\"Time\"].iloc[0]\n", 512 | "\n", 513 | "print(f\"Running Prediction Pipeline\")\n", 514 | "for iteration, row in timeline.iterrows():\n", 515 | " if row[\"Event\"] == \"Meal\":\n", 516 | " last_meal = row\n", 517 | " continue\n", 518 | "\n", 519 | " if row[\"Event\"] == \"Glucose\":\n", 520 | " t = row[\"Time\"]\n", 521 | "\n", 522 | " # Determine if the last_meal is actually a *new* meal\n", 523 | " if last_meal is not None and last_meal[\"Time\"] != last_context_meal_time:\n", 524 | " # New meal detected\n", 525 | " meal_food = last_meal[\"Food Items\"]\n", 526 | " meal_type = last_meal[\"Meal Type\"]\n", 527 | " meal_time = last_meal[\"Time\"]\n", 528 | "\n", 529 | " context_meal = f\"I just ate {meal_food} at {meal_time}\"\n", 530 | "\n", 531 | " last_context_meal_time = meal_time\n", 532 | "\n", 533 | " else:\n", 534 | " # No new meal — use \"previously ate\" logic\n", 535 | " if previous_context_meal and previous_context_meal.startswith(\"I just ate\"):\n", 536 | " context_meal = f\"I previously ate a meal at {meal_time}\"\n", 537 | " elif previous_context_meal and previous_context_meal.startswith(\"I previously ate\"):\n", 538 | " context_meal = f\"I previously ate a meal at {meal_time}\"\n", 539 | " else:\n", 540 | " context_meal = \"I haven't eaten anything yet.\"\n", 541 | "\n", 542 | " previous_context_meal = context_meal\n", 543 | "\n", 544 | "\n", 545 | " prompt = f\"\"\"\n", 546 | " Here are the examples:\n", 547 | " {examples_text}\n", 548 | "\n", 549 | " My last glucose level was {current_glucose} at {i}.\n", 550 | " {context_meal} on {target_date}, using the meal, predict my glucose level at {t}\n", 551 | "\n", 552 | " Output only a Python list with a single dictionary in this format:\n", 553 | " [{{'Date': '{target_date}', 'Time': '{t}', 'Blood Sugar Level': value}}]\n", 554 | " No extra text, no explanations, no code, only the list.\n", 555 | " \"\"\"\n", 556 | "\n", 557 | " try:\n", 558 | " print(f\"Running Iteration {iteration}\")\n", 559 | " result = predict_diabetes_level(prompt).replace(\"`\", \"\").replace(\"python\", \"\").replace(\"json\", \"\")\n", 560 | " except Exception as e:\n", 561 | " print(\"Waiting 60 seconds due to quota limits\")\n", 562 | " time.sleep(60)\n", 563 | " print(f\"Rerunning Iteration {iteration}\")\n", 564 | " result = predict_diabetes_level(prompt).replace(\"`\", \"\").replace(\"python\", \"\").replace(\"json\", \"\")\n", 565 | "\n", 566 | " try:\n", 567 | " result_list = ast.literal_eval(result.strip())\n", 568 | " output.extend(result_list)\n", 569 | " current_glucose = result_list[-1][\"Blood Sugar Level\"]\n", 570 | " except Exception as e:\n", 571 | " print(\"Parsing error:\", result, e)\n", 572 | "\n", 573 | " i = t\n", 574 | "\n", 575 | "df2 = pd.DataFrame(output)\n", 576 | "\n", 577 | "df1 = glucose_levels.loc[glucose_levels['Date'] == target_date]\n", 578 | "\n", 579 | "merged = pd.merge(df1, df2, on=\"Time\", suffixes=(\"_expected\", \"_predicted\"))\n", 580 | "\n", 581 | "merged[\"Blood Sugar Level_predicted\"] = merged[\"Blood Sugar Level_predicted\"].astype(int)\n", 582 | "\n", 583 | "merged[\"Error\"] = merged[\"Blood Sugar Level_expected\"] - merged[\"Blood Sugar Level_predicted\"]\n", 584 | 
"merged[\"Absolute Error\"] = merged[\"Error\"].abs()\n", 585 | "merged[\"Squared Error\"] = merged[\"Error\"] ** 2\n", 586 | "\n", 587 | "mae = merged[\"Absolute Error\"].mean()\n", 588 | "mse = merged[\"Squared Error\"].mean()\n", 589 | "rmse = np.sqrt(mse)\n", 590 | "\n", 591 | "comparison_results[\"MSE ICL (w/o mention)\"] = mse\n", 592 | "comparison_results[\"RMSE ICL (w/o mention)\"] = rmse\n", 593 | "\n", 594 | "print(\"Mean Squared Error (MSE):\", mse)\n", 595 | "print(\"Root Mean Squared Error (RMSE):\", rmse)\n", 596 | "\n", 597 | "fig, ax1 = plt.subplots(figsize=(10,7))\n", 598 | "\n", 599 | "ax1.plot(merged[\"Time\"], merged[\"Blood Sugar Level_expected\"], label=\"Expected Values\", color = \"black\", marker=\"o\")\n", 600 | "ax1.plot(merged[\"Time\"], merged[\"Blood Sugar Level_predicted\"], label=\"Predicted Values\", color = \"red\", marker=\"o\")\n", 601 | "ax1.set_xlabel(\"Time\")\n", 602 | "ax1.set_ylabel(\"Blood Sugar Level\")\n", 603 | "ax1.tick_params(axis=\"x\", rotation=45)\n", 604 | "\n", 605 | "lines, labels = ax1.get_legend_handles_labels()\n", 606 | "\n", 607 | "plt.title(\"Blood Sugar Levels (w/o mention of glycemic index)\")\n", 608 | "plt.legend()\n", 609 | "plt.grid(True)\n", 610 | "plt.show()" 611 | ] 612 | }, 613 | { 614 | "cell_type": "markdown", 615 | "source": [ 616 | "This cell `Baseline: Gaussian Process Regression`, implements the baseline model used to benchmark the two in-context learning forecasting approaches. We use a Gaussian Process Regression (GPR) model with a simple DotProduct + WhiteKernel combination to predict glucose levels based on the previous two time steps (lag-1 and lag-2).\n", 617 | "\n", 618 | "After loading and preprocessing the glucose and meal datasets, the code constructs lag features, splits the data into training and test sets, and fits the GPR model.\n", 619 | "\n", 620 | "It then computes overall MSE and RMSE across the full dataset to establish a general performance reference. Next, it isolates the specific day selected for comparison, generates GPR predictions for that day, and evaluates them with day-level MSE and RMSE. Finally, the code visualizes the actual glucose values alongside model predictions and their 95% confidence interval, providing a clear baseline curve to compare against the LLM’s GI-aware and GI-agnostic forecasts." 
621 | ], 622 | "metadata": { 623 | "id": "qiYjO0M5L9Li" 624 | }, 625 | "id": "qiYjO0M5L9Li" 626 | }, 627 | { 628 | "cell_type": "code", 629 | "source": [ 630 | "#@title **Baseline: Gaussian Process Regression**\n", 631 | "\n", 632 | "from sklearn.gaussian_process import GaussianProcessRegressor\n", 633 | "from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel\n", 634 | "from sklearn.metrics import mean_squared_error\n", 635 | "from sklearn.model_selection import train_test_split\n", 636 | "\n", 637 | "days_before_test = int(input(\"Enter the same amount of days you entered prior: \"))\n", 638 | "\n", 639 | "glucose = pd.read_csv(\"intro_to_icl_data/blood_sugar_data.csv\")\n", 640 | "food_data = pd.read_csv(\"intro_to_icl_data/food_data.csv\")\n", 641 | "\n", 642 | "glucose[\"Date\"] = pd.to_datetime(glucose[\"Date\"], format=\"%d-%m-%Y\", errors=\"coerce\")\n", 643 | "food_data[\"Date\"] = pd.to_datetime(food_data[\"Date\"], format=\"%d-%m-%Y\", errors=\"coerce\")\n", 644 | "\n", 645 | "glucose[\"glucose_lag1\"] = glucose[\"Blood Sugar Level\"].shift(1)\n", 646 | "glucose[\"glucose_lag2\"] = glucose[\"Blood Sugar Level\"].shift(2)\n", 647 | "glucose = glucose.dropna()\n", 648 | "\n", 649 | "X = glucose[[\"glucose_lag1\", \"glucose_lag2\"]]\n", 650 | "y = glucose[\"Blood Sugar Level\"]\n", 651 | "\n", 652 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", 653 | "\n", 654 | "kernel = DotProduct() + WhiteKernel()\n", 655 | "gpr = GaussianProcessRegressor(kernel=kernel, random_state=0)\n", 656 | "gpr.fit(X_train, y_train)\n", 657 | "\n", 658 | "preds, stds = gpr.predict(X_test, return_std=True)\n", 659 | "\n", 660 | "mse = mean_squared_error(y_test, preds)\n", 661 | "rmse = np.sqrt(mse)\n", 662 | "\n", 663 | "print(\"\\nMSE & RMSE for All Test Cases in the Dataset\")\n", 664 | "print(\"Mean Squared Error (MSE):\", mse)\n", 665 | "print(\"Root Mean Squared Error (RMSE):\", rmse)\n", 666 | "\n", 667 | "one_day = pd.unique(food_data[\"Date\"])[days_before_test]\n", 668 | "day_data = glucose[glucose[\"Date\"] == one_day].copy()\n", 669 | "\n", 670 | "if not day_data.empty:\n", 671 | " date_str = pd.Timestamp(one_day).strftime('%Y-%m-%d')\n", 672 | " day_data[\"Predicted\"], day_data[\"Std\"] = gpr.predict(day_data[[\"glucose_lag1\", \"glucose_lag2\"]], return_std=True)\n", 673 | "\n", 674 | " mse = mean_squared_error(day_data[\"Blood Sugar Level\"], day_data[\"Predicted\"])\n", 675 | " rmse = np.sqrt(mse)\n", 676 | "\n", 677 | " comparison_results[\"MSE Baseline\"] = mse\n", 678 | " comparison_results[\"RMSE Baseline\"] = rmse\n", 679 | "\n", 680 | " print(f\"\\nMSE & RMSE for {date_str}\")\n", 681 | " print(\"Mean Squared Error (MSE):\", mse)\n", 682 | " print(\"Root Mean Squared Error (RMSE):\", rmse)\n", 683 | "\n", 684 | " day_data[\"Lower\"] = day_data[\"Predicted\"] - 1.96 * day_data[\"Std\"]\n", 685 | " day_data[\"Upper\"] = day_data[\"Predicted\"] + 1.96 * day_data[\"Std\"]\n", 686 | "\n", 687 | " plt.figure(figsize=(10,7))\n", 688 | " x_axis = day_data[\"Time\"] if \"Time\" in day_data.columns else day_data[\"Date\"].dt.strftime(\"%H:%M\")\n", 689 | "\n", 690 | " plt.fill_between(x_axis, day_data[\"Lower\"], day_data[\"Upper\"], alpha=0.2, color='gray', label='95% Confidence Interval')\n", 691 | " plt.plot(x_axis, day_data[\"Blood Sugar Level\"], color=\"black\", label=\"Actual\", marker='o')\n", 692 | " plt.plot(x_axis, day_data[\"Predicted\"], color=\"red\", label=\"Predicted\", marker='o')\n", 693 | "\n", 694 | "\n", 695 | " 
date_str = pd.Timestamp(one_day).strftime('%Y-%m-%d')\n", 696 | " plt.title(f\"Blood Sugar Levels (Gaussian Process Regression) for {date_str}\")\n", 697 | " plt.xlabel(\"Time\")\n", 698 | " plt.ylabel(\"Blood Sugar Level\")\n", 699 | " plt.legend()\n", 700 | " plt.xticks(rotation=45)\n", 701 | " plt.tight_layout()\n", 702 | " plt.show()" 703 | ], 704 | "metadata": { 705 | "id": "umsxyQCk0tRK", 706 | "cellView": "form" 707 | }, 708 | "id": "umsxyQCk0tRK", 709 | "execution_count": null, 710 | "outputs": [] 711 | }, 712 | { 713 | "cell_type": "markdown", 714 | "source": [ 715 | "### **Before Running This Section, make sure you have already run all three demos:**\n", 716 | "	1.	ICL w/ mention of glycemic index (Example 1)\n", 717 | "	2.	ICL w/o mention of glycemic index (Example 2)\n", 718 | "	3.	Gaussian Process Regression (GPR) (Baseline)\n", 719 | "\n", 720 | "This cell `Comparison` brings together the results from all three forecasting approaches: the GI-aware ICL version, the GI-agnostic ICL version, and the GPR baseline, providing a direct performance comparison.\n", 721 | "\n", 722 | "This plot highlights how each approach performs relative to the others, making it easy to see the impact of including GI information and how both LLM-based methods compare to the traditional GPR baseline." 723 | ], 724 | "metadata": { 725 | "id": "CeV_VZY9xb7u" 726 | }, 727 | "id": "CeV_VZY9xb7u" 728 | }, 729 | { 730 | "cell_type": "code", 731 | "source": [ 732 | "#@title **Comparison Results**\n", 733 | "keys = list(comparison_results.keys())\n", 734 | "values = [float(v) for v in comparison_results.values()]\n", 735 | "\n", 736 | "colors = [\"#FF7777\" if \"Baseline\" in k else \"#88FF88\" for k in keys]\n", 737 | "\n", 738 | "plt.figure(figsize=(10, 5))\n", 739 | "bars = plt.bar(keys, values, color=colors)\n", 740 | "\n", 741 | "plt.xticks(rotation=45, ha='right')\n", 742 | "plt.ylabel(\"Error Value\")\n", 743 | "plt.title(\"MSE and RMSE Comparison Across Models\")\n", 744 | "\n", 745 | "plt.tight_layout()\n", 746 | "plt.show()" 747 | ], 748 | "metadata": { 749 | "id": "vfs58ikSoAht", 750 | "cellView": "form" 751 | }, 752 | "id": "vfs58ikSoAht", 753 | "execution_count": null, 754 | "outputs": [] 755 | }, 756 | { 757 | "cell_type": "markdown", 758 | "source": [ 759 | "## **Summary**\n", 760 | "\n", 761 | "This demo illustrates how LLMs can perform ICL to forecast physiological time-series data, using daily glucose measurements paired with meal events as the foundation. By providing the model with several full-day examples, we show that it can internalize natural glycemic patterns—such as post-meal spikes and gradual return to baseline—and extend them to generate predictions for an unseen future day.\n", 762 | "\n", 763 | "We explore two prompting strategies for predicting glucose levels. In the first strategy, the prompt explicitly instructs the model to infer GI values when generating its outputs. In the second strategy, the prompt provides only the input line, without mentioning GI at all, leaving the model to independently determine what information to extract and how to contextualize it. 
Comparing these two approaches allows us to analyze how the LLM’s outputs differ and how the presence or absence of explicit guidance shapes its reasoning process.\n", 764 | "\n", 765 | "Alongside the two LLM-based approaches, we also construct a Gaussian Process Regression (GPR) baseline using traditional time-series lag features, providing a grounded benchmark for assessing the strengths and limitations of in-context learning in this setting. When compared against this baseline, both ICL approaches, with and without explicit Glycemic Index guidance, perform the best, capturing meaningful glucose trends and achieving the lowest MSE and RMSE values. On the other hand, while the baseline model achieves comparable error, it fails to reproduce the realistic temporal patterns shown in the graph, instead relying strictly on short-term lag structure.\n", 766 | "\n", 767 | "The same methodology demonstrated here can be extended beyond glucose forecasting to a wide range of physiological and behavioral time-series tasks. By curating a small number of representative examples, LLMs can be prompted to learn patterns related to circadian rhythms, activity cycles, medication effects, or other temporal health signals. These extensions highlight the flexibility of in-context learning when applied to structured real-world sequences, opening the door to rapid pattern generalization without the need for explicit mechanistic models.\n", 768 | "\n", 769 | "## **Conclusion**\n", 770 | "\n", 771 | "This demonstration shows that LLMs can generalize temporal patterns in a way that resembles human intuition—capturing meal-driven glucose dynamics from only a handful of daily examples. Even without explicit physiological equations, the model can infer relationships between meals and glucose responses and apply them to forecast an entire future day. While the LLM does not perform mechanistic or medical reasoning, its ability to approximate nonlinear biological behavior through pattern recognition underscores the potential of in-context learning for health-related time-series tasks. Comparing the LLM’s predictions to a GPR baseline further reveals where LLMs excel and where traditional models remain competitive, offering a balanced perspective on the opportunities and limitations of treating LLMs as general sequence pattern learners." 772 | ], 773 | "metadata": { 774 | "id": "zxbHRie7uSK8" 775 | }, 776 | "id": "zxbHRie7uSK8" 777 | }, 778 | { 779 | "cell_type": "code", 780 | "source": [], 781 | "metadata": { 782 | "id": "K53XsVo34ict" 783 | }, 784 | "id": "K53XsVo34ict", 785 | "execution_count": null, 786 | "outputs": [] 787 | } 788 | ], 789 | "metadata": { 790 | "kernelspec": { 791 | "display_name": "base", 792 | "language": "python", 793 | "name": "python3" 794 | }, 795 | "language_info": { 796 | "codemirror_mode": { 797 | "name": "ipython", 798 | "version": 3 799 | }, 800 | "file_extension": ".py", 801 | "mimetype": "text/x-python", 802 | "name": "python", 803 | "nbconvert_exporter": "python", 804 | "pygments_lexer": "ipython3", 805 | "version": "3.13.5" 806 | }, 807 | "colab": { 808 | "provenance": [] 809 | } 810 | }, 811 | "nbformat": 4, 812 | "nbformat_minor": 5 813 | } --------------------------------------------------------------------------------