├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── MIMXRT1176xxxxx_cm7_ram.ld ├── README.md ├── llama2.h ├── llama4micro.gif ├── main.cc ├── models ├── README.md ├── export_coco_labels.py ├── llama2 │ ├── stories15M_q80.bin │ └── tokenizer.bin └── yolov5 │ ├── coco_labels.txt │ └── yolov5n-int8_edgetpu.tflite └── yolov5.h /.gitignore: -------------------------------------------------------------------------------- 1 | # VS Code 2 | .vscode 3 | 4 | # CMake 5 | build 6 | 7 | # Python 8 | venv 9 | __pycache__ 10 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "llama2.c"] 2 | path = llama2.c 3 | url = git@github.com:karpathy/llama2.c.git 4 | [submodule "coralmicro"] 5 | path = coralmicro 6 | url = git@github.com:google-coral/coralmicro.git 7 | [submodule "yolov5"] 8 | path = yolov5 9 | url = git@github.com:ultralytics/yolov5.git 10 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.13) 2 | 3 | set(CMAKE_TOOLCHAIN_FILE 4 | ${CMAKE_CURRENT_LIST_DIR}/coralmicro/cmake/toolchain-arm-none-eabi-gcc.cmake 5 | ) 6 | 7 | project(llama4micro) 8 | 9 | set(CMAKE_CXX_STANDARD 17) 10 | set(CMAKE_CXX_STANDARD_REQUIRED True) 11 | 12 | include_directories(coralmicro) 13 | add_subdirectory(coralmicro) 14 | 15 | include_directories(llama2.c) 16 | 17 | add_executable_m7(llama4micro 18 | main.cc 19 | LINKER_SCRIPT 20 | ${PROJECT_SOURCE_DIR}/MIMXRT1176xxxxx_cm7_ram.ld 21 | DATA 22 | ${PROJECT_SOURCE_DIR}/models/llama2/stories15M_q80.bin 23 | ${PROJECT_SOURCE_DIR}/models/llama2/tokenizer.bin 24 | ${PROJECT_SOURCE_DIR}/models/yolov5/yolov5n-int8_edgetpu.tflite 25 | ${PROJECT_SOURCE_DIR}/models/yolov5/coco_labels.txt 26 | ) 27 | 28 | target_link_libraries(llama4micro 29 | 
libs_base-m7_freertos 30 | ) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Max Braun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MIMXRT1176xxxxx_cm7_ram.ld: -------------------------------------------------------------------------------- 1 | /* 2 | ** ################################################################### 3 | ** Processors: MIMXRT1176AVM8A_cm7 4 | ** MIMXRT1176CVM8A_cm7 5 | ** MIMXRT1176DVMAA_cm7 6 | ** 7 | ** Compiler: GNU C Compiler 8 | ** Reference manual: IMXRT1170RM, Rev E, 12/2019 9 | ** Version: rev. 
0.1, 2018-03-05 10 | ** Build: b200828 11 | ** 12 | ** Abstract: 13 | ** Linker file for the GNU C Compiler 14 | ** 15 | ** Copyright 2016 Freescale Semiconductor, Inc. 16 | ** Copyright 2016-2020 NXP 17 | ** All rights reserved. 18 | ** 19 | ** SPDX-License-Identifier: BSD-3-Clause 20 | ** 21 | ** http: www.nxp.com 22 | ** mail: support@nxp.com 23 | ** 24 | ** ################################################################### 25 | */ 26 | 27 | /* Entry Point */ 28 | ENTRY(Reset_Handler) 29 | 30 | HEAP_SIZE = DEFINED(__heap_size__) ? __heap_size__ : 0x03F09800; 31 | STACK_SIZE = DEFINED(__stack_size__) ? __stack_size__ : 0x0400; 32 | RPMSG_SHMEM_SIZE = DEFINED(__use_shmem__) ? 0x2000 : 0; 33 | NCACHE_SIZE = 0x8000; 34 | TEXT_SIZE = 0x5B400; 35 | 36 | /* 37 | * Valid memory regions: 38 | * 0x00000000 - 0x0003FFFF (256KB) 39 | * 0x20000000 - 0x2003FFFF (256KB) 40 | * 0x20240000 - 0x202BFFFF (512KB) 41 | * 0x202C0000 - 0x2033FFFF (512KB) 42 | */ 43 | 44 | /* Specify the memory areas */ 45 | MEMORY 46 | { 47 | m_interrupts (RX) : ORIGIN = 0x00000800, LENGTH = 0x00000400 48 | m_ncache (RW) : ORIGIN = 0x20000000, LENGTH = NCACHE_SIZE 49 | m_data (RW) : ORIGIN = 0x20000000 + NCACHE_SIZE, LENGTH = 0x00040000 - NCACHE_SIZE 50 | m_text (RX) : ORIGIN = 0x20240000, LENGTH = TEXT_SIZE 51 | m_ocram (RX) : ORIGIN = 0x20240000 + TEXT_SIZE, LENGTH = 0x00080000 - TEXT_SIZE 52 | rpmsg_sh_mem (RW) : ORIGIN = 0x202C0000, LENGTH = RPMSG_SHMEM_SIZE 53 | m_heap (RW) : ORIGIN = 0x80000000, LENGTH = HEAP_SIZE 54 | m_sdram (RX) : ORIGIN = 0x80000000 + HEAP_SIZE, LENGTH = 0x04000000 - HEAP_SIZE 55 | } 56 | 57 | /* Define output sections */ 58 | SECTIONS 59 | { 60 | __NCACHE_REGION_START = ORIGIN(m_ncache); 61 | __NCACHE_REGION_SIZE = LENGTH(m_ncache); 62 | 63 | __RPMSG_SH_MEM_START = ORIGIN(rpmsg_sh_mem); 64 | __RPMSG_SH_MEM_SIZE = LENGTH(rpmsg_sh_mem); 65 | 66 | /* NOINIT section for rpmsg_sh_mem */ 67 | .noinit_rpmsg_sh_mem (NOLOAD) : ALIGN(4) 68 | { 69 | __RPMSG_SH_MEM_START__ = .; 70 | 
*(.noinit.$rpmsg_sh_mem*) 71 | . = ALIGN(4) ; 72 | __RPMSG_SH_MEM_END__ = .; 73 | } > rpmsg_sh_mem 74 | 75 | /* section for storing the secondary core image */ 76 | .core1_code : 77 | { 78 | . = ALIGN(4) ; 79 | KEEP (*(.core1_code)) 80 | *(.core1_code*) 81 | . = ALIGN(4) ; 82 | } > m_sdram 83 | 84 | /* The startup code goes first into internal RAM */ 85 | .interrupts : 86 | { 87 | __VECTOR_TABLE = .; 88 | __Vectors = .; 89 | . = ALIGN(4); 90 | KEEP(*(.isr_vector)) /* Startup code */ 91 | . = ALIGN(4); 92 | } > m_interrupts 93 | 94 | .a71ch : 95 | { 96 | . = ALIGN(4); 97 | __a71ch_start = .; 98 | *liblibs_a71ch.a:*(.text .text*) 99 | *liblibs_a71ch.a:*(.data .data*) 100 | *liblibs_a71ch.a:*(.rodata .rodata*) 101 | __a71ch_end = .; 102 | . = ALIGN(4); 103 | } > m_ocram 104 | 105 | .curl : 106 | { 107 | . = ALIGN(4); 108 | __curl_start = .; 109 | *liblibs_curl.a:*(.text .text*) 110 | *liblibs_curl.a:*(.data .data*) 111 | *liblibs_curl.a:*(.rodata .rodata*) 112 | __curl_end = .; 113 | . = ALIGN(4); 114 | } > m_ocram 115 | 116 | .mbedtls : 117 | { 118 | . = ALIGN(4); 119 | __mbedtls_start = .; 120 | *liblibs_nxp_rt1176-sdk-mbedtls.a:*(.text .text*) 121 | *liblibs_nxp_rt1176-sdk-mbedtls.a:*(.data .data*) 122 | *liblibs_nxp_rt1176-sdk-mbedtls.a:*(.rodata .rodata*) 123 | __mbedtls_end = .; 124 | . = ALIGN(4); 125 | } > m_ocram 126 | 127 | .lwip : 128 | { 129 | . = ALIGN(4); 130 | __lwip_start = .; 131 | *liblibs_nxp_rt1176-sdk_lwip.a:*(.text .text*) 132 | *liblibs_nxp_rt1176-sdk_lwip.a:*(.data .data*) 133 | *liblibs_nxp_rt1176-sdk_lwip.a:*(.rodata .rodata*) 134 | *liblibs_nxp_rt1176-sdk_lwip_httpd.a:*(.text .text*) 135 | *liblibs_nxp_rt1176-sdk_lwip_httpd.a:*(.data .data*) 136 | *liblibs_nxp_rt1176-sdk_lwip_httpd.a:*(.rodata .rodata*) 137 | *liblibs_nxp_rt1176-sdk_lwip_mdns.a:*(.text .text*) 138 | *liblibs_nxp_rt1176-sdk_lwip_mdns.a:*(.data .data*) 139 | *liblibs_nxp_rt1176-sdk_lwip_mdns.a:*(.rodata .rodata*) 140 | __lwip_end = .; 141 | . 
= ALIGN(4); 142 | } > m_ocram 143 | 144 | .wiced : 145 | { 146 | . = ALIGN(4); 147 | __wiced_start = .; 148 | *liblibs_nxp_rt1176-sdk_wiced.a:*(.text .text*) 149 | *liblibs_nxp_rt1176-sdk_wiced.a:*(.data .data*) 150 | *liblibs_nxp_rt1176-sdk_wiced.a:*(.rodata .rodata*) 151 | __wiced_end = .; 152 | . = ALIGN(4); 153 | } > m_ocram 154 | 155 | .tensorflow : 156 | { 157 | . = ALIGN(4); 158 | __tensorflow_start = .; 159 | *liblibs_tensorflow-m7.a:*(.text .text*) 160 | *liblibs_tensorflow-m7.a:*(.data .data*) 161 | *liblibs_tensorflow-m7.a:*(.rodata .rodata*) 162 | *liblibs_kissfft-m7.a*:*(.text .text*) 163 | *liblibs_kissfft-m7.a*:*(.data .data*) 164 | *liblibs_kissfft-m7.a*:*(.rodata .rodata*) 165 | __tensorflow_end = .; 166 | . = ALIGN(4); 167 | } > m_ocram 168 | 169 | .arduino : 170 | { 171 | . = ALIGN(4); 172 | __arduino_start = .; 173 | *liblibs_arduino_coral_micro_bundled.a:*(.text .text*) 174 | *liblibs_arduino_coral_micro_bundled.a:*(.data .data*) 175 | *liblibs_arduino_coral_micro_bundled.a:*(.rodata .rodata*) 176 | *liblibs_arduino_coral_micro_poe_bundled.a:*(.text .text*) 177 | *liblibs_arduino_coral_micro_poe_bundled.a:*(.data .data*) 178 | *liblibs_arduino_coral_micro_poe_bundled.a:*(.rodata .rodata*) 179 | *liblibs_arduino_coral_micro_wifi_bundled.a:*(.text .text*) 180 | *liblibs_arduino_coral_micro_wifi_bundled.a:*(.data .data*) 181 | *liblibs_arduino_coral_micro_wifi_bundled.a:*(.rodata .rodata*) 182 | __arduino_end = .; 183 | . = ALIGN(4); 184 | } > m_ocram 185 | 186 | .libjpeg : 187 | { 188 | . = ALIGN(4); 189 | __libjpeg_start = .; 190 | *liblibs_libjpeg.a:*(.text .text*) 191 | *liblibs_libjpeg.a:*(.data .data*) 192 | *liblibs_libjpeg.a:*(.rodata .rodata*) 193 | __libjpeg_end = .; 194 | . = ALIGN(4); 195 | } > m_ocram 196 | 197 | .edgefast_bluetooth_text : 198 | { 199 | . 
= ALIGN(4); 200 | __edgefast_bluetooth_text_start__ = .; 201 | *libedgefast_bluetooth_internal.a:*(.text .text*) 202 | *lib_crypto_m7.a:*(.text .text*) 203 | *libethermind_ble_protocol.a:*(.text .text*) 204 | *libethermind_ble_gatt.a:*(.text .text*) 205 | *libethermind_ble_core.a:*(.text .text*) 206 | *libethermind_ble_util.a:*(.text .text*) 207 | __edgefast_bluetooth_text_end__ = .; 208 | . = ALIGN(4); 209 | } > m_sdram 210 | 211 | .edgefast_bluetooth_data : 212 | { 213 | . = ALIGN(4); 214 | __edgefast_bluetooth_data_start__ = .; 215 | *libedgefast_bluetooth_internal.a:*(.data .data*) 216 | *lib_crypto_m7.a:*(.data .data*) 217 | *libethermind_ble_protocol.a:*(.data .data*) 218 | *libethermind_ble_gatt.a:*(.data .data*) 219 | *libethermind_ble_core.a:*(.data .data*) 220 | *libethermind_ble_util.a:*(.data .data*) 221 | __edgefast_bluetooth_data_end__ = .; 222 | . = ALIGN(4); 223 | } > m_sdram 224 | 225 | .edgefast_bluetooth_rodata : 226 | { 227 | . = ALIGN(4); 228 | __edgefast_bluetooth_rodata_start__ = .; 229 | *libedgefast_bluetooth_internal.a:*(.rodata .rodata*) 230 | *lib_crypto_m7.a:*(.rodata .rodata*) 231 | *libethermind_ble_protocol.a:*(.rodata .rodata*) 232 | *libethermind_ble_gatt.a:*(.rodata .rodata*) 233 | *libethermind_ble_core.a:*(.rodata .rodata*) 234 | *libethermind_ble_util.a:*(.rodata .rodata*) 235 | __edgefast_bluetooth_rodata_end__ = .; 236 | . = ALIGN(4); 237 | } > m_sdram 238 | 239 | /* The program code and other data goes into internal RAM */ 240 | .text : 241 | { 242 | . = ALIGN(4); 243 | *(.text) /* .text sections (code) */ 244 | *(.text*) /* .text* sections (code) */ 245 | *(.rodata) /* .rodata sections (constants, strings, etc.) */ 246 | *(.rodata*) /* .rodata* sections (constants, strings, etc.) */ 247 | KEEP(*(.rodata.debug)) 248 | *(.glue_7) /* glue arm to thumb code */ 249 | *(.glue_7t) /* glue thumb to arm code */ 250 | *(.eh_frame) 251 | KEEP (*(.init)) 252 | KEEP (*(.fini)) 253 | . 
= ALIGN(4); 254 | } > m_text 255 | 256 | ._settings_handler_static : 257 | { 258 | . = ALIGN(4); 259 | _settings_handler_static_list_start = .; 260 | KEEP(*(SORT(._settings_handler_static.static.*))) 261 | _settings_handler_static_list_end = .; 262 | . = ALIGN(4); 263 | } > m_text 264 | 265 | ._bt_gatt_service_static : 266 | { 267 | . = ALIGN(4); 268 | _bt_gatt_service_static_list_start = .; 269 | KEEP(*(SORT(._bt_gatt_service_static.static.*))) 270 | _bt_gatt_service_static_list_end = .; 271 | . = ALIGN(4); 272 | } > m_text 273 | 274 | ._bt_l2cap_fixed_chan : 275 | { 276 | . = ALIGN(4); 277 | _bt_l2cap_fixed_chan_list_start = .; 278 | KEEP(*(SORT(._bt_l2cap_fixed_chan.static.*))) 279 | _bt_l2cap_fixed_chan_list_end = .; 280 | . = ALIGN(4); 281 | } > m_text 282 | 283 | .ARM.extab : 284 | { 285 | *(.ARM.extab* .gnu.linkonce.armextab.*) 286 | } > m_text 287 | 288 | .ARM : 289 | { 290 | __exidx_start = .; 291 | *(.ARM.exidx*) 292 | __exidx_end = .; 293 | } > m_text 294 | 295 | .ctors : 296 | { 297 | __CTOR_LIST__ = .; 298 | /* gcc uses crtbegin.o to find the start of 299 | the constructors, so we make sure it is 300 | first. Because this is a wildcard, it 301 | doesn't matter if the user does not 302 | actually link against crtbegin.o; the 303 | linker won't look for a file to match a 304 | wildcard. The wildcard also means that it 305 | doesn't matter which directory crtbegin.o 306 | is in. */ 307 | KEEP (*crtbegin.o(.ctors)) 308 | KEEP (*crtbegin?.o(.ctors)) 309 | /* We don't want to include the .ctor section from 310 | from the crtend.o file until after the sorted ctors. 
311 | The .ctor section from the crtend file contains the 312 | end of ctors marker and it must be last */ 313 | KEEP (*(EXCLUDE_FILE(*crtend?.o *crtend.o) .ctors)) 314 | KEEP (*(SORT(.ctors.*))) 315 | KEEP (*(.ctors)) 316 | __CTOR_END__ = .; 317 | } > m_text 318 | 319 | .dtors : 320 | { 321 | __DTOR_LIST__ = .; 322 | KEEP (*crtbegin.o(.dtors)) 323 | KEEP (*crtbegin?.o(.dtors)) 324 | KEEP (*(EXCLUDE_FILE(*crtend?.o *crtend.o) .dtors)) 325 | KEEP (*(SORT(.dtors.*))) 326 | KEEP (*(.dtors)) 327 | __DTOR_END__ = .; 328 | } > m_text 329 | 330 | .preinit_array : 331 | { 332 | PROVIDE_HIDDEN (__preinit_array_start = .); 333 | KEEP (*(.preinit_array*)) 334 | PROVIDE_HIDDEN (__preinit_array_end = .); 335 | } > m_text 336 | 337 | .init_array : 338 | { 339 | PROVIDE_HIDDEN (__init_array_start = .); 340 | KEEP (*(SORT(.init_array.*))) 341 | KEEP (*(.init_array*)) 342 | PROVIDE_HIDDEN (__init_array_end = .); 343 | } > m_text 344 | 345 | .fini_array : 346 | { 347 | PROVIDE_HIDDEN (__fini_array_start = .); 348 | KEEP (*(SORT(.fini_array.*))) 349 | KEEP (*(.fini_array*)) 350 | PROVIDE_HIDDEN (__fini_array_end = .); 351 | } > m_text 352 | 353 | __etext = .; /* define a global symbol at end of code */ 354 | __DATA_ROM = .; /* Symbol is used by startup for data initialization */ 355 | 356 | __VECTOR_RAM = ORIGIN(m_interrupts); 357 | __RAM_VECTOR_TABLE_SIZE_BYTES = 0x0; 358 | 359 | .data : AT(__DATA_ROM) 360 | { 361 | . = ALIGN(4); 362 | __DATA_RAM = .; 363 | __data_start__ = .; /* create a global symbol at data start */ 364 | *(m_usb_dma_init_data) 365 | *(.data) /* .data sections */ 366 | *(.data*) /* .data* sections */ 367 | *(.got) 368 | *(.got*) 369 | *(.got.plt) 370 | *(.got.plt*) 371 | KEEP(*(.jcr*)) 372 | . = ALIGN(4); 373 | _net_buf_pool_list = .; 374 | KEEP(*(SORT(._net_buf_pool*))) 375 | . 
= ALIGN(4); 376 | __data_end__ = .; /* define a global symbol at data end */ 377 | } > m_data 378 | __NDATA_ROM = __DATA_ROM + (__data_end__ - __data_start__); 379 | .ncache.init : AT(__NDATA_ROM) 380 | { 381 | __noncachedata_start__ = .; /* create a global symbol at ncache data start */ 382 | *(NonCacheable.init) 383 | . = ALIGN(4); 384 | __noncachedata_init_end__ = .; /* create a global symbol at initialized ncache data end */ 385 | } > m_ncache 386 | . = __noncachedata_init_end__; 387 | .ncache : 388 | { 389 | *(NonCacheable) 390 | . = ALIGN(4); 391 | __noncachedata_end__ = .; /* define a global symbol at ncache data end */ 392 | } > m_ncache 393 | __SDRAM_ROM = __NDATA_ROM + (__noncachedata_init_end__ - __noncachedata_start__); 394 | .sdram_data : AT(__SDRAM_ROM) 395 | { 396 | . = ALIGN(4); 397 | __sdram_data_start__ = .; 398 | KEEP(*(.sdram_data*)) 399 | . = ALIGN(4); 400 | __sdram_data_end__ = .; 401 | } > m_sdram 402 | 403 | /* Uninitialized data section */ 404 | .bss : 405 | { 406 | /* This is used by the startup in order to initialize the .bss section */ 407 | . = ALIGN(4); 408 | __START_BSS = .; 409 | __bss_start__ = .; 410 | *(m_usb_dma_noninit_data) 411 | *(.bss) 412 | *(.bss*) 413 | *(COMMON) 414 | . = ALIGN(4); 415 | __bss_end__ = .; 416 | __END_BSS = .; 417 | } > m_data 418 | __DATA_END = __SDRAM_ROM; 419 | text_end = ORIGIN(m_text) + LENGTH(m_text); 420 | ASSERT(__DATA_END <= text_end, "region m_text overflowed with text and data") 421 | 422 | .sdram_bss (NOLOAD) : 423 | { 424 | . = ALIGN(4); 425 | __sdram_bss_start__ = .; 426 | *(.sdram_bss*) 427 | /* edgefast_bluetooth_bss */ 428 | *lib_crypto_m7.a:*(.bss .bss*) 429 | *libedgefast_bluetooth_internal.a:*(.bss .bss*) 430 | *libethermind_ble_protocol.a:*(.bss .bss*) 431 | *libethermind_ble_gatt.a:*(.bss .bss*) 432 | *libethermind_ble_core.a:*(.bss .bss*) 433 | *libethermind_ble_util.a:*(.bss .bss*) 434 | . = ALIGN(4); 435 | __sdram_bss_end__ = .; 436 | } > m_sdram 437 | 438 | .heap : 439 | { 440 | . 
= ALIGN(8); 441 | __end__ = .; 442 | PROVIDE(end = .); 443 | __HeapBase = .; 444 | . += HEAP_SIZE; 445 | __HeapLimit = .; 446 | __heap_limit = .; /* Add for _sbrk */ 447 | } > m_heap 448 | 449 | .stack : 450 | { 451 | . = ALIGN(8); 452 | . += STACK_SIZE; 453 | } > m_data 454 | 455 | /* Initializes stack on the end of block */ 456 | __StackTop = ORIGIN(m_data) + LENGTH(m_data); 457 | __StackLimit = __StackTop - STACK_SIZE; 458 | PROVIDE(__stack = __StackTop); 459 | 460 | .ARM.attributes 0 : { *(.ARM.attributes) } 461 | } 462 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llama4micro 🦙🔬 2 | 3 | A "large" language model running on a microcontroller. 4 | 5 | ![Example run](llama4micro.gif) 6 | 7 | ## Background 8 | 9 | I was wondering if it's possible to fit a non-trivial language model on a microcontroller. Turns out the answer is some version of yes! (Later, things got a bit out of hand and now the prompt is based on objects detected by the camera.) 10 | 11 | This project is using the [Coral Dev Board Micro](https://coral.ai/products/dev-board-micro) with its [FreeRTOS toolchain](https://coral.ai/docs/dev-board-micro/freertos/). The board has a number of neat [hardware features](https://coral.ai/docs/dev-board-micro/get-started/#the-hardware), but – most importantly for our purposes – it has 64MB of RAM. That's tiny for LLMs, which are typically measured in the GBs, but comparatively huge for a microcontroller. 12 | 13 | The LLM implementation itself is an adaptation of [llama2.c](https://github.com/karpathy/llama2.c) and the [tinyllamas](https://huggingface.co/karpathy/tinyllamas/tree/main) checkpoints trained on the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset. The quality of the smaller model versions isn't ideal, but good enough to generate somewhat coherent (and occasionally weird) stories. 
14 | 15 | > [!NOTE] 16 | > Language model inference runs on the 800 MHz [Arm Cortex-M7](https://developer.arm.com/Processors/Cortex-M7) CPU core. Camera image classification uses the [Edge TPU](https://coral.ai/technology/) and a [compiled](https://coral.ai/docs/edgetpu/compiler/) [YOLOv5 model](https://github.com/ultralytics/yolov5). The board also has a second 400 MHz [Arm Cortex-M4](https://developer.arm.com/Processors/Cortex-M4) CPU core, which is currently unused. 17 | 18 | ## Setup 19 | 20 | Clone this repo with its submodules [`karpathy/llama2.c`](https://github.com/karpathy/llama2.c), [`google-coral/coralmicro`](https://github.com/google-coral/coralmicro), and [`ultralytics/yolov5`](https://github.com/ultralytics/yolov5). 21 | 22 | ```bash 23 | git clone --recurse-submodules https://github.com/maxbbraun/llama4micro.git 24 | 25 | cd llama4micro 26 | ``` 27 | 28 | The pre-trained models are in the [`models/`](models/) directory. Refer to the [instructions](models/README.md) on how to download and convert them. 29 | 30 | Build the image: 31 | 32 | ```bash 33 | mkdir build 34 | cd build 35 | 36 | cmake .. 37 | make -j 38 | ``` 39 | 40 | Flash the image: 41 | 42 | ```bash 43 | python3 -m venv venv 44 | . venv/bin/activate 45 | 46 | pip install -r ../coralmicro/scripts/requirements.txt 47 | 48 | python ../coralmicro/scripts/flashtool.py \ 49 | --build_dir . \ 50 | --elf_path llama4micro 51 | ``` 52 | 53 | ## Usage 54 | 55 | 1. The models load automatically when the board powers up. 56 | - This takes ~7 seconds. 57 | - The green light will turn on when ready. 58 | 2. Point the camera at an object and press the button. 59 | - The green light will turn off. 60 | - The camera will take a picture and detect an object. 61 | 3. The model now generates tokens starting with a prompt based on the object. 62 | - The results are streamed to the serial port. 63 | - This happens at a rate of ~2.5 tokens per second. 64 | 4. Generation stops after the end token or maximum steps. 
65 | - The green light will turn on again. 66 | - Goto 2. 67 | -------------------------------------------------------------------------------- /llama2.h: -------------------------------------------------------------------------------- 1 | /* Forked from llama2.c/runq.c (https://github.com/karpathy/llama2.c) */ 2 | 3 | /* Inference for Llama-2 Transformer model in pure C, int8 quantized forward pass. */ 4 | 5 | #include 6 | #include 7 | 8 | #include "libs/base/filesystem.h" 9 | #include "libs/base/timer.h" 10 | 11 | // Similar API to fread() for reading from a byte array instead. 12 | inline void sread(void *buffer, size_t size, uint8_t* *stream) { 13 | memcpy(buffer, *stream, size); 14 | *stream += size; 15 | } 16 | 17 | // ---------------------------------------------------------------------------- 18 | // Transformer model 19 | 20 | typedef struct { 21 | int dim; // transformer dimension 22 | int hidden_dim; // for ffn layers 23 | int n_layers; // number of layers 24 | int n_heads; // number of query heads 25 | int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery) 26 | int vocab_size; // vocabulary size, usually 256 (byte-level) 27 | int seq_len; // max sequence length 28 | } Config; 29 | 30 | typedef struct { 31 | int8_t* q; // quantized values 32 | float* s; // scaling factors 33 | } QuantizedTensor; 34 | 35 | typedef struct { 36 | // token embedding table 37 | QuantizedTensor *q_tokens; // (vocab_size, dim) 38 | float* token_embedding_table; // same, but dequantized 39 | 40 | // weights for rmsnorms 41 | float* rms_att_weight; // (layer, dim) rmsnorm weights 42 | float* rms_ffn_weight; // (layer, dim) 43 | // weights for matmuls. 
note dim == n_heads * head_size 44 | QuantizedTensor *wq; // (layer, dim, n_heads * head_size) 45 | QuantizedTensor *wk; // (layer, dim, n_kv_heads * head_size) 46 | QuantizedTensor *wv; // (layer, dim, n_kv_heads * head_size) 47 | QuantizedTensor *wo; // (layer, n_heads * head_size, dim) 48 | // weights for ffn 49 | QuantizedTensor *w1; // (layer, hidden_dim, dim) 50 | QuantizedTensor *w2; // (layer, dim, hidden_dim) 51 | QuantizedTensor *w3; // (layer, hidden_dim, dim) 52 | // final rmsnorm 53 | float* rms_final_weight; // (dim,) 54 | // (optional) classifier weights for the logits, on the last layer 55 | QuantizedTensor *wcls; 56 | } TransformerWeights; 57 | 58 | typedef struct { 59 | // current wave of activations 60 | float *x; // activation at current time stamp (dim,) 61 | float *xb; // same, but inside a residual branch (dim,) 62 | float *xb2; // an additional buffer just for convenience (dim,) 63 | float *hb; // buffer for hidden dimension in the ffn (hidden_dim,) 64 | float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,) 65 | QuantizedTensor xq; // quantized x (dim,) 66 | QuantizedTensor hq; // quantized hb (hidden_dim,) 67 | float *q; // query (dim,) 68 | float *k; // key (dim,) 69 | float *v; // value (dim,) 70 | float *att; // buffer for scores/attention values (n_heads, seq_len) 71 | float *logits; // output logits 72 | // kv cache 73 | float* key_cache; // (layer, seq_len, dim) 74 | float* value_cache; // (layer, seq_len, dim) 75 | } RunState; 76 | 77 | typedef struct { 78 | Config config; // the hyperparameters of the architecture (the blueprint) 79 | TransformerWeights weights; // the weights of the model 80 | RunState state; // buffers for the "wave" of activations in the forward pass 81 | // some more state needed to properly clean up the memory mapping (sigh) 82 | float* data; // memory mapped data pointer 83 | } Transformer; 84 | 85 | void malloc_run_state(RunState* s, Config* p) { 86 | // we calloc instead of malloc to keep 
valgrind happy 87 | int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; 88 | s->x = (float*)calloc(p->dim, sizeof(float)); 89 | s->xb = (float*)calloc(p->dim, sizeof(float)); 90 | s->xb2 = (float*)calloc(p->dim, sizeof(float)); 91 | s->hb = (float*)calloc(p->hidden_dim, sizeof(float)); 92 | s->hb2 = (float*)calloc(p->hidden_dim, sizeof(float)); 93 | s->xq = (QuantizedTensor) { .q = (int8_t*)calloc(p->dim, sizeof(int8_t)), .s = (float*)calloc(p->dim, sizeof(float)) }; 94 | s->hq = (QuantizedTensor) { .q = (int8_t*)calloc(p->hidden_dim, sizeof(int8_t)), .s = (float*)calloc(p->hidden_dim, sizeof(float)) }; 95 | s->q = (float*)calloc(p->dim, sizeof(float)); 96 | s->k = (float*)calloc(kv_dim, sizeof(float)); 97 | s->v = (float*)calloc(kv_dim, sizeof(float)); 98 | s->att = (float*)calloc(p->n_heads * p->seq_len, sizeof(float)); 99 | s->logits = (float*)calloc(p->vocab_size, sizeof(float)); 100 | s->key_cache = (float*)calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); 101 | s->value_cache = (float*)calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); 102 | // ensure all mallocs went fine 103 | if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q 104 | || !s->k || !s->v || !s->att || !s->logits || !s->key_cache 105 | || !s->value_cache) { 106 | fprintf(stderr, "malloc failed!\n"); 107 | exit(EXIT_FAILURE); 108 | } 109 | } 110 | 111 | void free_run_state(RunState* s) { 112 | free(s->x); 113 | free(s->xb); 114 | free(s->xb2); 115 | free(s->hb); 116 | free(s->hb2); 117 | free(s->xq.q); 118 | free(s->xq.s); 119 | free(s->hq.q); 120 | free(s->hq.s); 121 | free(s->q); 122 | free(s->k); 123 | free(s->v); 124 | free(s->att); 125 | free(s->logits); 126 | free(s->key_cache); 127 | free(s->value_cache); 128 | } 129 | 130 | // ---------------------------------------------------------------------------- 131 | // Quantization functions 132 | 133 | void dequantize(QuantizedTensor *qx, float* x, int n, int gs) { 134 | for (int i = 0; i < n; i++) { 135 | x[i] = 
qx->q[i] * qx->s[i / gs]; 136 | } 137 | } 138 | 139 | void quantize(QuantizedTensor *qx, float* x, int n, int gs) { 140 | int num_groups = n / gs; 141 | float Q_MAX = 127.0f; 142 | 143 | for (int group = 0; group < num_groups; group++) { 144 | 145 | // find the max absolute value in the current group 146 | float wmax = 0.0; 147 | for (int i = 0; i < gs; i++) { 148 | float val = fabs(x[group * gs + i]); 149 | if (val > wmax) { 150 | wmax = val; 151 | } 152 | } 153 | 154 | // calculate and write the scaling factor 155 | float scale = wmax / Q_MAX; 156 | qx->s[group] = scale; 157 | 158 | // calculate and write the quantized values 159 | for (int i = 0; i < gs; i++) { 160 | float quant_value = x[group * gs + i] / scale; // scale 161 | int8_t quantized = (int8_t) round(quant_value); // round and clamp 162 | qx->q[group * gs + i] = quantized; 163 | } 164 | } 165 | } 166 | 167 | /* initialize `n` x quantized tensor (with `size_each` elements), starting from memory pointed at *ptr */ 168 | QuantizedTensor *init_quantized_tensors(void **ptr, int n, int size_each, int gs) { 169 | void *p = *ptr; 170 | QuantizedTensor *res = (QuantizedTensor*)malloc(n * sizeof(QuantizedTensor)); 171 | for(int i=0; idim / p->n_heads; 185 | // first are the parameters that are kept in fp32 (the rmsnorm (1D) weights) 186 | float* fptr = (float*) ptr; // cast our pointer to float* 187 | w->rms_att_weight = fptr; 188 | fptr += p->n_layers * p->dim; 189 | w->rms_ffn_weight = fptr; 190 | fptr += p->n_layers * p->dim; 191 | w->rms_final_weight = fptr; 192 | fptr += p->dim; 193 | 194 | // now read all the quantized weights 195 | ptr = (void*)fptr; // now cast the pointer back to void* 196 | w->q_tokens = init_quantized_tensors(&ptr, 1, p->vocab_size * p->dim, gs); 197 | // dequantize token embedding table 198 | w->token_embedding_table = (float*)malloc(p->vocab_size * p->dim * sizeof(float)); 199 | dequantize(w->q_tokens, w->token_embedding_table, p->vocab_size * p->dim, gs); 200 | 201 | w->wq = 
init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_heads * head_size), gs); 202 | w->wk = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size), gs); 203 | w->wv = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size), gs); 204 | w->wo = init_quantized_tensors(&ptr, p->n_layers, (p->n_heads * head_size) * p->dim, gs); 205 | 206 | w->w1 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim, gs); 207 | w->w2 = init_quantized_tensors(&ptr, p->n_layers, p->hidden_dim * p->dim, gs); 208 | w->w3 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim, gs); 209 | 210 | w->wcls = shared_classifier ? w->q_tokens : init_quantized_tensors(&ptr, 1, p->dim * p->vocab_size, gs); 211 | } 212 | 213 | void read_checkpoint(const char* checkpoint, std::vector* checkpoint_buffer, 214 | Config* config, TransformerWeights* weights, float** data, int* group_size) { 215 | if (!coralmicro::LfsReadFile(checkpoint, checkpoint_buffer)) { 216 | fprintf(stderr, "Couldn't open file %s\n", checkpoint); 217 | exit(EXIT_FAILURE); 218 | } 219 | uint8_t* checkpoint_ptr = checkpoint_buffer->data(); 220 | // read in magic number (uint32), has to be 0x616b3432, i.e. 
"ak42" in ASCII 221 | uint32_t magic_number; 222 | sread(&magic_number, sizeof(uint32_t), &checkpoint_ptr); 223 | if (magic_number != 0x616b3432) { fprintf(stderr, "Bad magic number\n"); exit(EXIT_FAILURE); } 224 | // read in the version number (uint32), has to be 2 225 | int version; 226 | sread(&version, sizeof(int), &checkpoint_ptr); 227 | if (version != 2) { fprintf(stderr, "Bad version %d, need version 2\n", version); exit(EXIT_FAILURE); } 228 | int header_size = 256; // the header size for version 2 in bytes 229 | // read in the Config 230 | sread(config, sizeof(Config), &checkpoint_ptr); 231 | // read in flags 232 | uint8_t shared_classifier; // a byte to indicate if the classifier is shared 233 | sread(&shared_classifier, sizeof(uint8_t), &checkpoint_ptr); 234 | // the group size used in quantization 235 | sread(group_size, sizeof(int), &checkpoint_ptr); 236 | // point the Transformer weights at the data pointer 237 | *data = (float*)checkpoint_buffer->data(); 238 | void* weights_ptr = ((char*)*data) + header_size; // skip header bytes. 
char is 1 byte 239 | memory_map_weights(weights, config, weights_ptr, shared_classifier, *group_size); 240 | } 241 | 242 | void build_transformer(Transformer *t, const char* checkpoint_path, 243 | std::vector* checkpoint_buffer, int* group_size) { 244 | // read in the Config and the Weights from the checkpoint 245 | read_checkpoint(checkpoint_path, checkpoint_buffer, &t->config, &t->weights, &t->data, 246 | group_size); 247 | // allocate the RunState buffers 248 | malloc_run_state(&t->state, &t->config); 249 | } 250 | 251 | void free_transformer(Transformer* t) { 252 | // free QuantizedTensors 253 | free(t->weights.q_tokens); 254 | free(t->weights.token_embedding_table); 255 | free(t->weights.wq); 256 | free(t->weights.wk); 257 | free(t->weights.wv); 258 | free(t->weights.wo); 259 | free(t->weights.w1); 260 | free(t->weights.w2); 261 | free(t->weights.w3); 262 | if(t->weights.wcls != t->weights.q_tokens) { free(t->weights.wcls); } 263 | // free the RunState buffers 264 | free_run_state(&t->state); 265 | } 266 | 267 | // ---------------------------------------------------------------------------- 268 | // neural net blocks; the dynamics of the Transformer 269 | 270 | void rmsnorm(float* o, float* x, float* weight, int size) { 271 | // calculate sum of squares 272 | float ss = 0.0f; 273 | for (int j = 0; j < size; j++) { 274 | ss += x[j] * x[j]; 275 | } 276 | ss /= size; 277 | ss += 1e-5f; 278 | ss = 1.0f / sqrtf(ss); 279 | // normalize and scale 280 | for (int j = 0; j < size; j++) { 281 | o[j] = weight[j] * (ss * x[j]); 282 | } 283 | } 284 | 285 | void softmax(float* x, int size) { 286 | // find max value (for numerical stability) 287 | float max_val = x[0]; 288 | for (int i = 1; i < size; i++) { 289 | if (x[i] > max_val) { 290 | max_val = x[i]; 291 | } 292 | } 293 | // exp and sum 294 | float sum = 0.0f; 295 | for (int i = 0; i < size; i++) { 296 | x[i] = expf(x[i] - max_val); 297 | sum += x[i]; 298 | } 299 | // normalize 300 | for (int i = 0; i < size; i++) { 301 
| x[i] /= sum; 302 | } 303 | } 304 | 305 | void matmul(float* xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d, int gs) { 306 | // W (d,n) @ x (n,) -> xout (d,) 307 | // by far the most amount of time is spent inside this little function 308 | // inputs to this function are both quantized 309 | 310 | int i; 311 | for (i = 0; i < d; i++) { 312 | 313 | float val = 0.0f; 314 | int32_t ival = 0; 315 | int in = i * n; 316 | 317 | // do the matmul in groups of GS 318 | int j; 319 | for (j = 0; j <= n - gs; j += gs) { 320 | for (int k = 0; k < gs; k++) { 321 | ival += ((int32_t) x->q[j + k]) * ((int32_t) w->q[in + j + k]); 322 | } 323 | val += ((float) ival) * w->s[(in + j) / gs] * x->s[j / gs]; 324 | ival = 0; 325 | } 326 | 327 | xout[i] = val; 328 | } 329 | } 330 | 331 | float* forward(Transformer* transformer, int token, int pos, int gs) { 332 | 333 | // a few convenience variables 334 | Config* p = &transformer->config; 335 | TransformerWeights* w = &transformer->weights; 336 | RunState* s = &transformer->state; 337 | float *x = s->x; 338 | int dim = p->dim; 339 | int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; 340 | int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery 341 | int hidden_dim = p->hidden_dim; 342 | int head_size = dim / p->n_heads; 343 | 344 | // copy the token embedding into x 345 | memcpy(x, w->token_embedding_table + token*dim, dim * sizeof(float)); 346 | 347 | // forward all the layers 348 | for(int l = 0; l < p->n_layers; l++) { 349 | 350 | // attention rmsnorm 351 | rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim); 352 | 353 | // qkv matmuls for this position 354 | quantize(&s->xq, s->xb, dim, gs); 355 | matmul(s->q, &s->xq, w->wq + l, dim, dim, gs); 356 | matmul(s->k, &s->xq, w->wk + l, dim, kv_dim, gs); 357 | matmul(s->v, &s->xq, w->wv + l, dim, kv_dim, gs); 358 | 359 | // RoPE relative positional encoding: complex-valued rotate q and k in each head 360 | for (int i = 0; i < dim; i+=2) { 
361 | int head_dim = i % head_size; 362 | float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size); 363 | float val = pos * freq; 364 | float fcr = cosf(val); 365 | float fci = sinf(val); 366 | int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only 367 | for (int v = 0; v < rotn; v++) { 368 | float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key) 369 | float v0 = vec[i]; 370 | float v1 = vec[i+1]; 371 | vec[i] = v0 * fcr - v1 * fci; 372 | vec[i+1] = v0 * fci + v1 * fcr; 373 | } 374 | } 375 | 376 | // save key,value at this time step (pos) to our kv cache 377 | int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience 378 | float* key_cache_row = s->key_cache + loff + pos * kv_dim; 379 | float* value_cache_row = s->value_cache + loff + pos * kv_dim; 380 | memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row)); 381 | memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row)); 382 | 383 | // multihead attention. iterate over all heads 384 | int h; 385 | for (h = 0; h < p->n_heads; h++) { 386 | // get the query vector for this head 387 | float* q = s->q + h * head_size; 388 | // attention scores for this head 389 | float* att = s->att + h * p->seq_len; 390 | // iterate over all timesteps, including the current one 391 | for (int t = 0; t <= pos; t++) { 392 | // get the key vector for this head and at this timestep 393 | float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size; 394 | // calculate the attention score as the dot product of q and k 395 | float score = 0.0f; 396 | for (int i = 0; i < head_size; i++) { 397 | score += q[i] * k[i]; 398 | } 399 | score /= sqrtf(head_size); 400 | // save the score to the attention buffer 401 | att[t] = score; 402 | } 403 | 404 | // softmax the scores to get attention weights, from 0..pos inclusively 405 | softmax(att, pos + 1); 406 | 407 | // weighted sum of the values, store back into xb 408 | float* xb = s->xb + h * head_size; 409 | 
memset(xb, 0, head_size * sizeof(float)); 410 | for (int t = 0; t <= pos; t++) { 411 | // get the value vector for this head and at this timestep 412 | float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size; 413 | // get the attention weight for this timestep 414 | float a = att[t]; 415 | // accumulate the weighted value into xb 416 | for (int i = 0; i < head_size; i++) { 417 | xb[i] += a * v[i]; 418 | } 419 | } 420 | } 421 | 422 | // final matmul to get the output of the attention 423 | quantize(&s->xq, s->xb, dim, gs); 424 | matmul(s->xb2, &s->xq, w->wo + l, dim, dim, gs); 425 | 426 | // residual connection back into x 427 | for (int i = 0; i < dim; i++) { 428 | x[i] += s->xb2[i]; 429 | } 430 | 431 | // ffn rmsnorm 432 | rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim); 433 | 434 | // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) 435 | // first calculate self.w1(x) and self.w3(x) 436 | quantize(&s->xq, s->xb, dim, gs); 437 | matmul(s->hb, &s->xq, w->w1 + l, dim, hidden_dim, gs); 438 | matmul(s->hb2, &s->xq, w->w3 + l, dim, hidden_dim, gs); 439 | 440 | // SwiGLU non-linearity 441 | for (int i = 0; i < hidden_dim; i++) { 442 | float val = s->hb[i]; 443 | // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid 444 | val *= (1.0f / (1.0f + expf(-val))); 445 | // elementwise multiply with w3(x) 446 | val *= s->hb2[i]; 447 | s->hb[i] = val; 448 | } 449 | 450 | // final matmul to get the output of the ffn 451 | quantize(&s->hq, s->hb, hidden_dim, gs); 452 | matmul(s->xb, &s->hq, w->w2 + l, hidden_dim, dim, gs); 453 | 454 | // residual connection 455 | for (int i = 0; i < dim; i++) { 456 | x[i] += s->xb[i]; 457 | } 458 | } 459 | 460 | // final rmsnorm 461 | rmsnorm(x, x, w->rms_final_weight, dim); 462 | 463 | // classifier into logits 464 | quantize(&s->xq, x, dim, gs); 465 | matmul(s->logits, &s->xq, w->wcls, dim, p->vocab_size, gs); 466 | return s->logits; 467 | } 468 | 469 | // 
---------------------------------------------------------------------------- 470 | // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens 471 | 472 | typedef struct { 473 | const char *str; 474 | int id; 475 | } TokenIndex; 476 | 477 | typedef struct { 478 | char** vocab; 479 | float* vocab_scores; 480 | TokenIndex *sorted_vocab; 481 | int vocab_size; 482 | unsigned int max_token_length; 483 | unsigned char byte_pieces[512]; // stores all single-byte strings 484 | } Tokenizer; 485 | 486 | int compare_tokens(const void *a, const void *b) { 487 | return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str); 488 | } 489 | 490 | void build_tokenizer(Tokenizer* t, const char* tokenizer_path, 491 | std::vector* tokenizer_buffer, int vocab_size) { 492 | // i should have written the vocab_size into the tokenizer file... sigh 493 | t->vocab_size = vocab_size; 494 | // malloc space to hold the scores and the strings 495 | t->vocab = (char**)malloc(vocab_size * sizeof(char*)); 496 | t->vocab_scores = (float*)malloc(vocab_size * sizeof(float)); 497 | t->sorted_vocab = NULL; // initialized lazily 498 | for (int i = 0; i < 256; i++) { 499 | t->byte_pieces[i * 2] = (unsigned char)i; 500 | t->byte_pieces[i * 2 + 1] = '\0'; 501 | } 502 | // read in the file 503 | if (!coralmicro::LfsReadFile(tokenizer_path, tokenizer_buffer)) { 504 | fprintf(stderr, "couldn't load %s\n", tokenizer_path); 505 | exit(EXIT_FAILURE); 506 | } 507 | uint8_t* tokenizer_ptr = tokenizer_buffer->data(); 508 | sread(&t->max_token_length, sizeof(int), &tokenizer_ptr); 509 | int len; 510 | for (int i = 0; i < vocab_size; i++) { 511 | sread(t->vocab_scores + i, sizeof(float), &tokenizer_ptr); 512 | sread(&len, sizeof(int), &tokenizer_ptr); 513 | t->vocab[i] = (char *)malloc(len + 1); 514 | sread(t->vocab[i], len, &tokenizer_ptr); 515 | t->vocab[i][len] = '\0'; // add the string terminating token 516 | } 517 | } 518 | 519 | void free_tokenizer(Tokenizer* t) { 520 | for (int i = 0; i < 
t->vocab_size; i++) { free(t->vocab[i]); } 521 | free(t->vocab); 522 | free(t->vocab_scores); 523 | free(t->sorted_vocab); 524 | } 525 | 526 | char* decode(Tokenizer* t, int prev_token, int token) { 527 | char *piece = t->vocab[token]; 528 | // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) 529 | if (prev_token == 1 && piece[0] == ' ') { piece++; } 530 | // careful, some tokens designate raw bytes, and look like e.g. '<0x01>' 531 | // parse this and convert and return the actual byte 532 | unsigned char byte_val; 533 | if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) { 534 | piece = (char*)t->byte_pieces + byte_val * 2; 535 | } 536 | return piece; 537 | } 538 | 539 | void safe_printf(char *piece) { 540 | // piece might be a raw byte token, and we only want to print printable chars or whitespace 541 | // because some of the other bytes can be various control codes, backspace, etc. 542 | if (piece == NULL) { return; } 543 | if (piece[0] == '\0') { return; } 544 | if (piece[1] == '\0') { 545 | unsigned char byte_val = piece[0]; 546 | if (!(isprint(byte_val) || isspace(byte_val))) { 547 | return; // bad byte, don't print it 548 | } 549 | } 550 | printf("%s", piece); 551 | } 552 | 553 | int str_lookup(const char *str, TokenIndex *sorted_vocab, int vocab_size) { 554 | // efficiently find the perfect match for str in vocab, return its index or -1 if not found 555 | TokenIndex tok = { .str = str }; // acts as the key to search for 556 | TokenIndex *res = (TokenIndex*)bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); 557 | return res != NULL ? 
res->id : -1; 558 | } 559 | 560 | void encode(Tokenizer* t, const char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) { 561 | // encode the string text (input) into an upper-bound preallocated tokens[] array 562 | // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2) 563 | if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); } 564 | 565 | if (t->sorted_vocab == NULL) { 566 | // lazily malloc and sort the vocabulary 567 | t->sorted_vocab = (TokenIndex*)malloc(t->vocab_size * sizeof(TokenIndex)); 568 | for (int i = 0; i < t->vocab_size; i++) { 569 | t->sorted_vocab[i].str = t->vocab[i]; 570 | t->sorted_vocab[i].id = i; 571 | } 572 | qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens); 573 | } 574 | 575 | // create a temporary buffer that will store merge candidates of always two consecutive tokens 576 | // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1) 577 | char* str_buffer = (char*)malloc((t->max_token_length*2 +1 +2) * sizeof(char)); 578 | size_t str_len = 0; 579 | 580 | // start at 0 tokens 581 | *n_tokens = 0; 582 | 583 | // add optional BOS (=1) token, if desired 584 | if (bos) tokens[(*n_tokens)++] = 1; 585 | 586 | // add_dummy_prefix is true by default 587 | // so prepend a dummy prefix token to the input string, but only if text != "" 588 | // TODO: pretty sure this isn't correct in the general case but I don't have the 589 | // energy to read more of the sentencepiece code to figure out what it's doing 590 | if (text[0] != '\0') { 591 | int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size); 592 | tokens[(*n_tokens)++] = dummy_prefix; 593 | } 594 | 595 | // Okay UTF-8 time. This will get messy. 
Here is the reference from Wikipedia: 596 | // Code point ↔ UTF-8 conversion 597 | // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4 598 | // U+0000 U+007F 0xxxxxxx 599 | // U+0080 U+07FF 110xxxxx 10xxxxxx 600 | // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx 601 | // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx 602 | 603 | // process the raw (UTF-8) byte sequence of the input string 604 | for (const char *c = text; *c != '\0'; c++) { 605 | 606 | // reset buffer if the current byte is ASCII or a leading byte 607 | // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest 608 | // 0x80 is 10000000 609 | // in UTF-8, all continuation bytes start with "10" in first two bits 610 | // so in English this is: "if this byte is not a continuation byte" 611 | if ((*c & 0xC0) != 0x80) { 612 | // this byte must be either a leading byte (11...) or an ASCII char (0x...) 613 | // => reset our location, as we're starting a new UTF-8 codepoint 614 | str_len = 0; 615 | } 616 | 617 | // append the current byte to the buffer 618 | str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line 619 | str_buffer[str_len] = '\0'; 620 | 621 | // while the next character is a continuation byte, continue appending 622 | // but if there are too many of them, just stop to avoid overruning str_buffer size. 
623 | if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) { 624 | continue; 625 | } 626 | 627 | // ok c+1 is not a continuation byte, so we've read in a full codepoint 628 | int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size); 629 | 630 | if (id != -1) { 631 | // we found this codepoint in vocab, add it as a token 632 | tokens[(*n_tokens)++] = id; 633 | } else { 634 | // byte_fallback encoding: just encode each byte as a token 635 | // +3 is here because the first 3 vocab elements are , , 636 | // so the individual bytes only start at index 3 637 | for (size_t i=0; i < str_len; i++) { 638 | tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3; 639 | } 640 | } 641 | str_len = 0; // protect against a sequence of stray UTF8 continuation bytes 642 | } 643 | 644 | // merge the best consecutive pair each iteration, according the scores in vocab_scores 645 | while (1) { 646 | float best_score = -1e10; 647 | int best_id = -1; 648 | int best_idx = -1; 649 | 650 | for (int i=0; i < (*n_tokens-1); i++) { 651 | // check if we can merge the pair (tokens[i], tokens[i+1]) 652 | sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]); 653 | int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size); 654 | if (id != -1 && t->vocab_scores[id] > best_score) { 655 | // this merge pair exists in vocab! 
record its score and position 656 | best_score = t->vocab_scores[id]; 657 | best_id = id; 658 | best_idx = i; 659 | } 660 | } 661 | 662 | if (best_idx == -1) { 663 | break; // we couldn't find any more pairs to merge, so we're done 664 | } 665 | 666 | // merge the consecutive pair (best_idx, best_idx+1) into new token best_id 667 | tokens[best_idx] = best_id; 668 | // delete token at position best_idx+1, shift the entire sequence back 1 669 | for (int i = best_idx+1; i < (*n_tokens-1); i++) { 670 | tokens[i] = tokens[i+1]; 671 | } 672 | (*n_tokens)--; // token length decreased 673 | } 674 | 675 | // add optional EOS (=2) token, if desired 676 | if (eos) tokens[(*n_tokens)++] = 2; 677 | 678 | free(str_buffer); 679 | } 680 | 681 | // ---------------------------------------------------------------------------- 682 | // The Sampler, which takes logits and returns a sampled token 683 | // sampling can be done in a few ways: greedy argmax, sampling, top-p sampling 684 | 685 | typedef struct { 686 | float prob; 687 | int index; 688 | } ProbIndex; // struct used when sorting probabilities during top-p sampling 689 | 690 | typedef struct { 691 | int vocab_size; 692 | ProbIndex* probindex; // buffer used in top-p sampling 693 | float temperature; 694 | float topp; 695 | unsigned long long rng_state; 696 | } Sampler; 697 | 698 | int sample_argmax(float* probabilities, int n) { 699 | // return the index that has the highest probability 700 | int max_i = 0; 701 | float max_p = probabilities[0]; 702 | for (int i = 1; i < n; i++) { 703 | if (probabilities[i] > max_p) { 704 | max_i = i; 705 | max_p = probabilities[i]; 706 | } 707 | } 708 | return max_i; 709 | } 710 | 711 | int sample_mult(float* probabilities, int n, float coin) { 712 | // sample index from probabilities (they must sum to 1!) 
713 | // coin is a random number in [0, 1), usually from random_f32() 714 | float cdf = 0.0f; 715 | for (int i = 0; i < n; i++) { 716 | cdf += probabilities[i]; 717 | if (coin < cdf) { 718 | return i; 719 | } 720 | } 721 | return n - 1; // in case of rounding errors 722 | } 723 | 724 | int compare(const void* a, const void* b) { 725 | ProbIndex* a_ = (ProbIndex*) a; 726 | ProbIndex* b_ = (ProbIndex*) b; 727 | if (a_->prob > b_->prob) return -1; 728 | if (a_->prob < b_->prob) return 1; 729 | return 0; 730 | } 731 | 732 | int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) { 733 | // top-p sampling (or "nucleus sampling") samples from the smallest set of 734 | // tokens that exceed probability topp. This way we never sample tokens that 735 | // have very low probabilities and are less likely to go "off the rails". 736 | // coin is a random number in [0, 1), usually from random_f32() 737 | 738 | int n0 = 0; 739 | // quicksort indices in descending order of probabilities 740 | // values smaller than (1 - topp) / (n - 1) cannot be part of the result 741 | // so for efficiency we crop these out as candidates before sorting 742 | const float cutoff = (1.0f - topp) / (n - 1); 743 | for (int i = 0; i < n; i++) { 744 | if (probabilities[i] >= cutoff) { 745 | probindex[n0].index = i; 746 | probindex[n0].prob = probabilities[i]; 747 | n0++; 748 | } 749 | } 750 | qsort(probindex, n0, sizeof(ProbIndex), compare); 751 | 752 | // truncate the list where cumulative probability exceeds topp 753 | float cumulative_prob = 0.0f; 754 | int last_idx = n0 - 1; // in case of rounding errors consider all elements 755 | for (int i = 0; i < n0; i++) { 756 | cumulative_prob += probindex[i].prob; 757 | if (cumulative_prob > topp) { 758 | last_idx = i; 759 | break; // we've exceeded topp by including last_idx 760 | } 761 | } 762 | 763 | // sample from the truncated list 764 | float r = coin * cumulative_prob; 765 | float cdf = 0.0f; 766 | for (int i = 0; i 
<= last_idx; i++) { 767 | cdf += probindex[i].prob; 768 | if (r < cdf) { 769 | return probindex[i].index; 770 | } 771 | } 772 | return probindex[last_idx].index; // in case of rounding errors 773 | } 774 | 775 | void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) { 776 | sampler->vocab_size = vocab_size; 777 | sampler->temperature = temperature; 778 | sampler->topp = topp; 779 | sampler->rng_state = rng_seed; 780 | // buffer only used with nucleus sampling; may not need but it's ~small 781 | sampler->probindex = (ProbIndex*)malloc(sampler->vocab_size * sizeof(ProbIndex)); 782 | } 783 | 784 | void free_sampler(Sampler* sampler) { 785 | free(sampler->probindex); 786 | } 787 | 788 | unsigned int random_u32(unsigned long long *state) { 789 | // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A 790 | *state ^= *state >> 12; 791 | *state ^= *state << 25; 792 | *state ^= *state >> 27; 793 | return (*state * 0x2545F4914F6CDD1Dull) >> 32; 794 | } 795 | 796 | float random_f32(unsigned long long *state) { // random float32 in [0,1) 797 | return (random_u32(state) >> 8) / 16777216.0f; 798 | } 799 | 800 | int sample(Sampler* sampler, float* logits) { 801 | // sample the token given the logits and some hyperparameters 802 | int next; 803 | if (sampler->temperature == 0.0f) { 804 | // greedy argmax sampling: take the token with the highest probability 805 | next = sample_argmax(logits, sampler->vocab_size); 806 | } else { 807 | // apply the temperature to the logits 808 | for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; } 809 | // apply softmax to the logits to get the probabilities for next token 810 | softmax(logits, sampler->vocab_size); 811 | // flip a (float) coin (this is our source of entropy for sampling) 812 | float coin = random_f32(&sampler->rng_state); 813 | // we sample from this distribution to get the next token 814 | if (sampler->topp <= 0 || sampler->topp >= 1) { 815 | 
// simply sample from the predicted probability distribution 816 | next = sample_mult(logits, sampler->vocab_size, coin); 817 | } else { 818 | // top-p (nucleus) sampling, clamping the least likely tokens to zero 819 | next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin); 820 | } 821 | } 822 | return next; 823 | } 824 | 825 | // ---------------------------------------------------------------------------- 826 | // generation loop 827 | 828 | void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, const char *prompt, 829 | int steps, int GA, float* tokens_s) { 830 | const char *empty_prompt = ""; 831 | if (prompt == NULL) { prompt = empty_prompt; } 832 | 833 | // encode the (string) prompt into tokens sequence 834 | int num_prompt_tokens = 0; 835 | int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS 836 | encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens); 837 | if (num_prompt_tokens < 1) { 838 | fprintf(stderr, "something is wrong, expected at least 1 prompt token\n"); 839 | exit(EXIT_FAILURE); 840 | } 841 | 842 | // start the main loop 843 | long start = 0; // used to time our code, only initialized after first iteration 844 | int next; // will store the next token in the sequence 845 | int token = prompt_tokens[0]; // kick off with the first token in the prompt 846 | int pos = 0; // position in the sequence 847 | while (pos < steps) { 848 | 849 | // forward the transformer to get logits for the next token 850 | float* logits = forward(transformer, token, pos, GA); 851 | 852 | // advance the state state machine 853 | if (pos < num_prompt_tokens - 1) { 854 | // if we are still processing the input prompt, force the next prompt token 855 | next = prompt_tokens[pos + 1]; 856 | } else { 857 | // otherwise sample the next token from the logits 858 | next = sample(sampler, logits); 859 | } 860 | pos++; 861 | 862 | // data-dependent terminating 
condition: the BOS (=1) token delimits sequences 863 | if (next == 1) { break; } 864 | 865 | // print the token as string, decode it with the Tokenizer object 866 | char* piece = decode(tokenizer, token, next); 867 | safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes 868 | fflush(stdout); 869 | token = next; 870 | 871 | // init the timer here because the first iteration can be slower 872 | if (tokens_s != nullptr && start == 0) { 873 | start = coralmicro::TimerMillis(); 874 | } 875 | } 876 | printf("\n"); 877 | 878 | // return achieved tok/s (pos-1 because the timer starts after first iteration) 879 | if (tokens_s != nullptr && pos > 1) { 880 | long end = coralmicro::TimerMillis(); 881 | *tokens_s = (pos - 1) / (float)(end - start) * 1000; 882 | } 883 | 884 | free(prompt_tokens); 885 | } 886 | -------------------------------------------------------------------------------- /llama4micro.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxbbraun/llama4micro/822e56baed3d3f919cea8d6b9fc752142b74a042/llama4micro.gif -------------------------------------------------------------------------------- /main.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "libs/base/filesystem.h" 6 | #include "libs/base/gpio.h" 7 | #include "libs/base/led.h" 8 | #include "libs/base/timer.h" 9 | #include "libs/camera/camera.h" 10 | #include "libs/tensorflow/utils.h" 11 | #include "libs/tpu/edgetpu_manager.h" 12 | #include "libs/tpu/edgetpu_op.h" 13 | #include "third_party/freertos_kernel/include/FreeRTOS.h" 14 | #include "third_party/freertos_kernel/include/task.h" 15 | #include "third_party/tflite-micro/tensorflow/lite/micro/micro_error_reporter.h" 16 | #include "third_party/tflite-micro/tensorflow/lite/micro/micro_interpreter.h" 17 | #include 
"third_party/tflite-micro/tensorflow/lite/micro/micro_mutable_op_resolver.h" 18 | 19 | #include "llama2.h" 20 | #include "yolov5.h" 21 | 22 | using namespace coralmicro; 23 | using namespace tflite; 24 | 25 | // Llama model data paths. 26 | const char* kLlamaModelPath = "/models/llama2/stories15M_q80.bin"; 27 | const char* kLlamaTokenizerPath = "/models/llama2/tokenizer.bin"; 28 | 29 | // Llama model inference parameters. 30 | const float kTemperature = 1.0f; 31 | const float kTopP = 0.9f; 32 | const int kSteps = 256; 33 | const char* kPromptPattern = "Once upon a time, there was a "; 34 | 35 | // Llama model data structures. 36 | Transformer transformer; 37 | std::vector<uint8_t>* llama_model_buffer; 38 | int group_size; 39 | int steps = kSteps; 40 | Tokenizer tokenizer; 41 | std::vector<uint8_t>* llama_tokenizer_buffer; 42 | Sampler sampler; 43 | 44 | // Vision model data path. 45 | const char* kVisionModelPath = "/models/yolov5/yolov5n-int8_edgetpu.tflite"; 46 | const char* kVisionLabelsPath = "/models/yolov5/coco_labels.txt"; 47 | 48 | // Vision model data structures. 49 | std::vector<uint8_t>* vision_model_buffer; 50 | std::vector<std::string>* vision_labels; 51 | const size_t kTensorArenaSize = 575 * 1024; 52 | STATIC_TENSOR_ARENA_IN_SDRAM(tensor_arena, kTensorArenaSize); 53 | PerformanceMode kTpuPerformanceMode = PerformanceMode::kLow; // Fast enough. 54 | 55 | // Camera and object detection configuration. 56 | CameraFrameFormat frame_format; 57 | const int kDiscardFrames = 30; 58 | const float kLabelConfidenceThreshold = 0.4f; 59 | const float kBboxScoreThreshold = 0.2f; 60 | const float kMinBboxSize = 0.1f; 61 | 62 | // Debounce interval for the button interrupt. 63 | const uint64_t kButtonDebounceUs = 50000; 64 | 65 | // Loads the Llama model and tokenizer into memory and sets up data structures. 
66 | void LoadLlamaModel() { 67 | int64_t timer_start = TimerMillis(); 68 | 69 | printf(">>> Loading Llama model %s...\n", kLlamaModelPath); 70 | llama_model_buffer = new std::vector(); 71 | build_transformer(&transformer, kLlamaModelPath, llama_model_buffer, 72 | &group_size); 73 | if (steps == 0 || steps > transformer.config.seq_len) { 74 | steps = transformer.config.seq_len; 75 | } 76 | 77 | printf(">>> Loading Llama tokenizer %s...\n", kLlamaTokenizerPath); 78 | llama_tokenizer_buffer = new std::vector(); 79 | build_tokenizer(&tokenizer, kLlamaTokenizerPath, llama_tokenizer_buffer, 80 | transformer.config.vocab_size); 81 | 82 | unsigned long long rng_seed = xTaskGetTickCount(); 83 | build_sampler(&sampler, transformer.config.vocab_size, kTemperature, kTopP, 84 | rng_seed); 85 | 86 | int64_t timer_stop = TimerMillis(); 87 | float timer_s = (timer_stop - timer_start) / 1000.0f; 88 | printf(">>> Llama model loading took %.2f s\n", timer_s); 89 | } 90 | 91 | // Frees the memory associated with the Llama model. 92 | void UnloadLlamaModel() { 93 | printf(">>> Unloading Llama model...\n"); 94 | 95 | free_sampler(&sampler); 96 | free_tokenizer(&tokenizer); 97 | free_transformer(&transformer); 98 | 99 | delete llama_tokenizer_buffer; 100 | delete llama_model_buffer; 101 | } 102 | 103 | // Loads the vision model and labels into memory. 104 | void LoadVisionModel() { 105 | int64_t timer_start = TimerMillis(); 106 | 107 | // Load the model weights. 108 | printf(">>> Loading vision model %s...\n", kVisionModelPath); 109 | vision_model_buffer = new std::vector(); 110 | if (!LfsReadFile(kVisionModelPath, vision_model_buffer)) { 111 | printf("ERROR: Failed to load vision model weights: %s\n", 112 | kVisionModelPath); 113 | return; 114 | } 115 | 116 | // Load the model labels. 
117 | printf(">>> Loading vision labels %s...\n", kVisionLabelsPath); 118 | std::string vision_labels_buffer; 119 | vision_labels = new std::vector(); 120 | if (!LfsReadFile(kVisionLabelsPath, &vision_labels_buffer)) { 121 | printf("ERROR: Failed to load vision labels: %s\n", kVisionLabelsPath); 122 | return; 123 | } 124 | std::istringstream labels_stream(vision_labels_buffer); 125 | std::string label; 126 | while (std::getline(labels_stream, label)) { 127 | vision_labels->push_back(label); 128 | } 129 | 130 | int64_t timer_stop = TimerMillis(); 131 | float timer_s = (timer_stop - timer_start) / 1000.0f; 132 | printf(">>> Vision model loading took %.2f s\n", timer_s); 133 | } 134 | 135 | // Frees the memory associated with the vision model. 136 | void UnloadVisionModel() { 137 | printf(">>> Unloading vision model...\n"); 138 | 139 | delete vision_model_buffer; 140 | delete vision_labels; 141 | } 142 | 143 | // Takes a picture and returns the label of the main detected object. 144 | std::string TakePicture() { 145 | printf(">>> Taking picture...\n"); 146 | int64_t timer_start = TimerMillis(); 147 | 148 | // Turn on the camera. 149 | if (!CameraTask::GetSingleton()->SetPower(true)) { 150 | printf("ERROR: Failed to power on camera\n"); 151 | return ""; 152 | } 153 | if (!CameraTask::GetSingleton()->Enable(CameraMode::kStreaming)) { 154 | printf("ERROR: Failed to enable camera\n"); 155 | return ""; 156 | } 157 | 158 | // Initialize the TPU. 159 | auto tpu_context = 160 | EdgeTpuManager::GetSingleton()->OpenDevice(kTpuPerformanceMode); 161 | if (!tpu_context) { 162 | printf("ERROR: Failed to get TPU context\n"); 163 | return ""; 164 | } 165 | 166 | // Initialize the TF Lite interpreter. 
167 | MicroMutableOpResolver<4> tf_resolver; 168 | tf_resolver.AddCustom(kCustomOp, RegisterCustomOp()); 169 | tf_resolver.AddQuantize(); 170 | tf_resolver.AddConcatenation(); 171 | tf_resolver.AddReshape(); 172 | MicroErrorReporter tf_error_reporter; 173 | MicroInterpreter tf_interpreter(GetModel(vision_model_buffer->data()), 174 | tf_resolver, tensor_arena, kTensorArenaSize, 175 | &tf_error_reporter); 176 | if (tf_interpreter.AllocateTensors() != kTfLiteOk) { 177 | printf("ERROR: Failed to allocate tensors\n"); 178 | return ""; 179 | } 180 | 181 | // Initialize the camera capture. 182 | auto* input_tensor = tf_interpreter.input_tensor(0); 183 | frame_format.fmt = CameraFormat::kRgb; 184 | frame_format.filter = CameraFilterMethod::kBilinear; 185 | frame_format.rotation = CameraRotation::k270; 186 | frame_format.width = input_tensor->dims->data[1]; 187 | frame_format.height = input_tensor->dims->data[2]; 188 | frame_format.preserve_ratio = false; 189 | frame_format.buffer = GetTensorData(input_tensor); 190 | frame_format.white_balance = true; 191 | 192 | // Discard some frames to get a recent one and give auto exposure more time. 193 | CameraTask::GetSingleton()->DiscardFrames(10); 194 | if (!CameraTask::GetSingleton()->GetFrame({frame_format})) { 195 | printf("ERROR: Failed to take picture\n"); 196 | return ""; 197 | } 198 | 199 | // Turn off the camera. 200 | CameraTask::GetSingleton()->Disable(); 201 | CameraTask::GetSingleton()->SetPower(false); 202 | 203 | // Run the object vision model on the image. 204 | if (tf_interpreter.Invoke() != kTfLiteOk) { 205 | printf("ERROR: Failed to detect objects\n"); 206 | return ""; 207 | } 208 | 209 | // Process the results. 
210 | auto results = yolo::GetDetectionResults( 211 | &tf_interpreter, kLabelConfidenceThreshold, kBboxScoreThreshold, 212 | kMinBboxSize, vision_labels); 213 | if (results.empty()) { 214 | printf(">>> Found no objects\n"); 215 | return ""; 216 | } 217 | for (auto result : results) { 218 | printf(">>> Found %s (%.2f @ %.2f|%.2f %.2fx%.2f)\n", result.label.c_str(), 219 | result.confidence, result.x, result.y, result.width, result.height); 220 | } 221 | 222 | int64_t timer_stop = TimerMillis(); 223 | float timer_s = (timer_stop - timer_start) / 1000.0f; 224 | printf(">>> Picture taking took %.2f s\n", timer_s); 225 | 226 | // Use the top result's label. 227 | return results[0].label; 228 | } 229 | 230 | // Generates a story beginning with the specified prompt. 231 | void TellStory(std::string prompt) { 232 | printf(">>> Generating tokens...\n"); 233 | 234 | float tokens_s; 235 | generate(&transformer, &tokenizer, &sampler, prompt.c_str(), steps, 236 | group_size, &tokens_s); 237 | 238 | printf(">>> Averaged %.2f tokens/s\n", tokens_s); 239 | } 240 | 241 | extern "C" [[noreturn]] void app_main(void* param) { 242 | (void)param; 243 | 244 | // Set up the button interrupt. 245 | GpioConfigureInterrupt( 246 | Gpio::kUserButton, GpioInterruptMode::kIntModeFalling, 247 | [handle = xTaskGetCurrentTaskHandle()]() { xTaskResumeFromISR(handle); }, 248 | kButtonDebounceUs); 249 | 250 | // Load the models while showing the status LED. 251 | LedSet(Led::kStatus, true); 252 | LedSet(Led::kUser, false); 253 | LoadLlamaModel(); 254 | LoadVisionModel(); 255 | 256 | while (true) { 257 | // Wait for a button press while showing the user LED. 258 | LedSet(Led::kStatus, false); 259 | LedSet(Led::kUser, true); 260 | vTaskSuspend(nullptr); 261 | // Continuing here after the button interrupt. 262 | 263 | // Take a picture while (automatically) showing the camera LED. The result 264 | // is the label of the main detected object. 
265 | LedSet(Led::kStatus, false); 266 | LedSet(Led::kUser, false); 267 | std::string label = TakePicture(); 268 | 269 | // The label might have multiple comma-separated parts. Pick the first one 270 | // and combine it with the prompt. 271 | std::istringstream label_stream(label); 272 | std::getline(label_stream, label, ','); 273 | std::string prompt = kPromptPattern; 274 | prompt += label; 275 | 276 | // Tell a story while showing the status LED. 277 | LedSet(Led::kStatus, true); 278 | LedSet(Led::kUser, false); 279 | TellStory(prompt); 280 | } 281 | 282 | // Unreachable in regular operation. The models stay in memory. 283 | UnloadLlamaModel(); 284 | UnloadVisionModel(); 285 | } 286 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | ## Models 2 | 3 | This directory contains the pre-trained model weights and metadata. See instructions below about their origins. 4 | 5 | Some of the tools use Python. Install their dependencies: 6 | 7 | ```bash 8 | python3 -m venv venv 9 | . venv/bin/activate 10 | 11 | pip install -r llama2.c/requirements.txt 12 | pip install -r yolov5/requirements.txt 13 | 14 | ``` 15 | 16 | ### Llama 17 | 18 | The model used by [llama2.c](https://github.com/karpathy/llama2.c) is based on [Llama 2](https://ai.meta.com/llama/) and the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset. The model weights are from the [tinyllamas](https://huggingface.co/karpathy/tinyllamas/tree/main) repository. We are using the [OG version](https://github.com/karpathy/llama2.c#models) (with 15M parameters) and quantize it. This model runs on the [Arm Cortex-M7 CPU](https://developer.arm.com/Processors/Cortex-M7). 
19 | 20 | 21 | ```bash 22 | LLAMA_MODEL_NAME=stories15M 23 | LLAMA_MODEL_DIR=llama2 24 | 25 | wget -P models/${LLAMA_MODEL_DIR} \ 26 | https://huggingface.co/karpathy/tinyllamas/resolve/main/${LLAMA_MODEL_NAME}.pt 27 | 28 | python llama2.c/export.py \ 29 | models/${LLAMA_MODEL_DIR}/${LLAMA_MODEL_NAME}_q80.bin \ 30 | --version 2 \ 31 | --checkpoint models/${LLAMA_MODEL_DIR}/${LLAMA_MODEL_NAME}.pt 32 | ``` 33 | 34 | The tokenizer comes from the [llama2.c](https://github.com/karpathy/llama2.c) repository. 35 | 36 | ```bash 37 | cp llama2.c/tokenizer.bin models/${LLAMA_MODEL_DIR}/ 38 | ``` 39 | 40 | ### Vision 41 | 42 | Object detection (with labels used for prompting Llama) is based on [YOLOv5](https://github.com/ultralytics/yolov5), specifially the smallest version "n" at a 224x224 resolution. This model runs on the [Coral Edge TPU](https://coral.ai/technology/), which requires an additional compilation step (handled by the exporter). 43 | 44 | ```bash 45 | git clone https://github.com/ultralytics/yolov5.git 46 | 47 | YOLO_RESOLUTION=224 48 | YOLO_VERSION=n 49 | VISION_MODEL_NAME=yolov5${YOLO_VERSION}-int8_edgetpu 50 | YOLO_MODEL_DIR=yolov5 51 | 52 | python yolov5/export.py \ 53 | --weights yolov5${YOLO_VERSION}.pt \ 54 | --include edgetpu \ 55 | --int8 \ 56 | --img ${YOLO_RESOLUTION} \ 57 | --data yolov5/data/coco128.yaml 58 | 59 | mkdir models/${YOLO_MODEL_DIR}/ 60 | cp yolov5/${VISION_MODEL_NAME}.tflite models/${YOLO_MODEL_DIR}/ 61 | ``` 62 | 63 | The labels are from the [COCO dataset](https://cocodataset.org/). Convert them to an easily readable format. 
64 | 65 | ```bash 66 | python models/export_coco_labels.py 67 | ``` 68 | -------------------------------------------------------------------------------- /models/export_coco_labels.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | YAML_FILE = 'yolov5/data/coco.yaml' 4 | TXT_FILE = 'models/yolov5/coco_labels.txt' 5 | 6 | # Open the YAML file containing the COCO labels. 7 | with open(YAML_FILE) as f: 8 | coco_yaml = yaml.safe_load(f) 9 | 10 | # Extract the labels from the data structure. 11 | labels = coco_yaml['names'] 12 | 13 | print(f'Opened {YAML_FILE} containing {len(labels)} labels.') 14 | 15 | # Save one label per line to a text file. 16 | with open(TXT_FILE, 'w') as f: 17 | for i in range(len(labels)): 18 | f.write(f'{labels[i]}\n') 19 | 20 | print(f'Saved {len(labels)} labels to {TXT_FILE}.') 21 | -------------------------------------------------------------------------------- /models/llama2/stories15M_q80.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxbbraun/llama4micro/822e56baed3d3f919cea8d6b9fc752142b74a042/models/llama2/stories15M_q80.bin -------------------------------------------------------------------------------- /models/llama2/tokenizer.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxbbraun/llama4micro/822e56baed3d3f919cea8d6b9fc752142b74a042/models/llama2/tokenizer.bin -------------------------------------------------------------------------------- /models/yolov5/coco_labels.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorcycle 5 | airplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 
25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | couch 59 | potted plant 60 | bed 61 | dining table 62 | toilet 63 | tv 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /models/yolov5/yolov5n-int8_edgetpu.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxbbraun/llama4micro/822e56baed3d3f919cea8d6b9fc752142b74a042/models/yolov5/yolov5n-int8_edgetpu.tflite -------------------------------------------------------------------------------- /yolov5.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "third_party/tflite-micro/tensorflow/lite/micro/micro_interpreter.h" 5 | 6 | namespace yolo { 7 | 8 | // The intersection over union threshold used in non-maximum suppression. 9 | const float kNmsIouThreshold = 0.1f; // Aggressive 10 | 11 | // An object detection result. 12 | struct Object { 13 | std::string label; 14 | float confidence; 15 | float x; 16 | float y; 17 | float width; 18 | float height; 19 | }; 20 | 21 | // Calculates the intersection over union of two objects' bounding boxes. 
22 | inline float IntersectionOverUnion(Object& a, Object& b) { 23 | float intersection_width = std::max( 24 | 0.0f, std::min(a.x + a.width, b.x + b.width) - std::max(a.x, b.x)); 25 | float intersection_height = std::max( 26 | 0.0f, std::min(a.y + a.height, b.y + b.height) - std::max(a.y, b.y)); 27 | float intersection_area = intersection_width * intersection_height; 28 | float union_area = 29 | a.width * a.height + b.width * b.height - intersection_area; 30 | return intersection_area / union_area; 31 | } 32 | 33 | // Performs non-maximum suppression on a list of objects. 34 | std::vector NonMaximumSuppression(std::vector& objects) { 35 | std::vector final_objects; 36 | 37 | for (size_t index_a = 0; index_a < objects.size(); ++index_a) { 38 | Object object_a = objects[index_a]; 39 | 40 | // Compare each object to all others to determine whether to keep it. 41 | bool discard_a = false; 42 | for (size_t index_b = 0; index_b < objects.size(); ++index_b) { 43 | Object object_b = objects[index_b]; 44 | 45 | // Don't compare the object to itself. 46 | if (index_a == index_b) { 47 | continue; 48 | } 49 | 50 | // Only compare objects with the same label. 51 | if (object_a.label != object_b.label) { 52 | continue; 53 | } 54 | 55 | // Scrutinize object pairs with overlapping bounding boxes. 56 | if (IntersectionOverUnion(object_a, object_b) > kNmsIouThreshold) { 57 | // Keep the object if it has the highest confidence. 58 | if (object_a.confidence > object_b.confidence) { 59 | continue; 60 | } 61 | 62 | // Keep the object if confidences are tied and it has a larger area. 63 | if (object_a.confidence == object_b.confidence && 64 | object_a.width * object_a.height > 65 | object_b.width * object_b.height) { 66 | continue; 67 | } 68 | 69 | // Otherwise, discard the object. 70 | discard_a = true; 71 | break; 72 | } 73 | } 74 | 75 | // Only keep non-discarded objects. 
76 | if (!discard_a) { 77 | final_objects.push_back(object_a); 78 | } 79 | } 80 | 81 | return final_objects; 82 | } 83 | 84 | // Dequantizes a quantized value based on the quantization parameters. 85 | inline float Dequantize(uint8_t quantized_value, 86 | TfLiteQuantizationParams& quantization_params) { 87 | return (static_cast(quantized_value) - quantization_params.zero_point) * 88 | quantization_params.scale; 89 | } 90 | 91 | // Processes the output tensor and returns a list of detected objects. 92 | std::vector GetDetectionResults(tflite::MicroInterpreter* interpreter, 93 | float label_confidence_threshold, 94 | float class_score_threshold, 95 | float min_bbox_size, 96 | std::vector* labels) { 97 | // Extract the data from the output tensor. 98 | auto output_tensor = interpreter->output_tensor(0); 99 | const int num_rows = output_tensor->dims->data[1]; 100 | int row_dims = output_tensor->dims->data[2]; 101 | int header_size = 5; // x, y, width, height, confidence 102 | uint8_t* data = output_tensor->data.uint8; 103 | TfLiteQuantizationParams quantization_params = output_tensor->params; 104 | 105 | // Rows come in groups of header_size + num_labels. 106 | std::vector raw_results; 107 | for (int row = 0; row < num_rows; ++row) { 108 | // The first number is the confidence. 109 | float confidence = Dequantize(data[row * row_dims], quantization_params); 110 | 111 | // Discard low confidence rows. 112 | if (confidence < label_confidence_threshold) { 113 | continue; 114 | } 115 | 116 | // The next four numbers are the bounding box. 117 | float x = Dequantize(data[row * row_dims + 1], quantization_params); 118 | float y = Dequantize(data[row * row_dims + 2], quantization_params); 119 | float width = Dequantize(data[row * row_dims + 3], quantization_params); 120 | float height = Dequantize(data[row * row_dims + 4], quantization_params); 121 | 122 | // Clip the bounding box to the image. 
123 | x = std::max(0.0f, std::min(x, 1.0f)); 124 | y = std::max(0.0f, std::min(y, 1.0f)); 125 | width = std::max(0.0f, std::min(width, 1.0f - x)); 126 | height = std::max(0.0f, std::min(height, 1.0f - y)); 127 | 128 | // The remaining numbers are the label scores. Pick the highest one. 129 | float max_score = 0.0f; 130 | int max_score_label = 0; 131 | int num_labels = row_dims - header_size; 132 | for (int label = 0; label < num_labels; ++label) { 133 | float score = Dequantize(data[row * row_dims + header_size + label], 134 | quantization_params); 135 | if (score > max_score) { 136 | max_score = score; 137 | max_score_label = label; 138 | } 139 | } 140 | 141 | // Discard low score classes. 142 | if (max_score < class_score_threshold) { 143 | continue; 144 | } 145 | 146 | // Discard small bounding boxes. Both sides have to be large enough. 147 | if (width < min_bbox_size || height < min_bbox_size) { 148 | continue; 149 | } 150 | 151 | // Assemble the result. 152 | Object object; 153 | object.label = labels->at(max_score_label); 154 | object.confidence = confidence; 155 | object.x = x; 156 | object.y = y; 157 | object.width = width; 158 | object.height = height; 159 | raw_results.push_back(object); 160 | } 161 | 162 | // Perform naive non-maximum suppression. 163 | auto filtered_results = NonMaximumSuppression(raw_results); 164 | 165 | // Sort the results by closeness to the center of the image. 
166 | std::sort(filtered_results.begin(), filtered_results.end(), 167 | [](auto& a, auto& b) { 168 | float a_horizontal_distance = a.x + a.width / 2 - 0.5f; 169 | float a_vertical_distance = a.y + a.height / 2 - 0.5f; 170 | float a_distance_squared = 171 | a_horizontal_distance * a_horizontal_distance + 172 | a_vertical_distance * a_vertical_distance; 173 | float b_horizontal_distance = b.x + b.width / 2 - 0.5f; 174 | float b_vertical_distance = b.y + b.height / 2 - 0.5f; 175 | float b_distance_squared = 176 | b_horizontal_distance * b_horizontal_distance + 177 | b_vertical_distance * b_vertical_distance; 178 | return a_distance_squared < b_distance_squared; 179 | }); 180 | 181 | return filtered_results; 182 | } 183 | 184 | } // namespace yolo 185 | --------------------------------------------------------------------------------