├── .gitignore ├── README.md ├── fp16-to-bf16 ├── bfloat16-vs-float16-study.ipynb ├── fp16-bf16-convert.ipynb └── train-fp16.ipynb ├── images ├── regularization-Mkclz.png ├── regularization-XmtF2.png ├── regularization-cmWO0.png └── regularization-jlQYp.png ├── ml.md ├── nlp.md ├── numbers ├── .ipynb_checkpoints │ └── bfloat16-vs-float16-study-checkpoint.ipynb ├── bfloat16-vs-float16-study.ipynb └── detect-model-pretrained-in-bf16-fp16-fp32.ipynb ├── pytorch.md ├── regularization.md ├── stats.md ├── symbols.md └── tools ├── html2md └── images2local.pl /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ml-ways 2 | 3 | ML/DL Math and Method notes 4 | 5 | 6 | ## Viewing Notes with LaTex Formulas properly rendered 7 | 8 | Since github doesn't support JS, you will need to either download these files and view them in a markdown editor that support LaTeX or alternatively you can use some kind of browser extension that will get LaTeX formulas automatically rendered. You just install 9 | the extension and when you browse any website with $\LaTeX$ in the html it'll automatically render it correctly. 10 | 11 | * Chrome extensions: 12 | 13 | - [TeX All the Things](https://chrome.google.com/webstore/detail/tex-all-the-things/cbimabofgmfdkicghcadidpemeenbffn/related) 14 | 15 | If you know of good extensions for other browsers please send a PR or send me a note via the issues. 
16 | 17 | 18 | -------------------------------------------------------------------------------- /fp16-to-bf16/bfloat16-vs-float16-study.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "black-beijing", 6 | "metadata": {}, 7 | "source": [ 8 | "# float16 vs bfloat16 numerical properties comparison\n", 9 | "\n", 10 | "This a short notebook to help understand `fp16` vs `bfloat16` in particular when converting a model trained\n", 11 | "in `bfloat16` to mixed precision - it should be possible to look at the numbers to know which ranges\n", 12 | "are safe and which need to be scaled/avoided.\n", 13 | "\n", 14 | "I needed to do that in the context of trying to understand why bfloat16 t5/mt5 models that were pretrained in bfloat16 had a lot of `nan`/`inf` problems when finetuned in mixed precision." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "id": "eastern-variation", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "import torch" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "adult-daughter", 30 | "metadata": {}, 31 | "source": [ 32 | "This is the main function, that tries to do very simply increments in `bfloat16` and then converting the result to `float16` and showing the discrepancies." 
33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "id": "resistant-chile", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "def find_mismatch(start, incr):\n", 43 | " bf16 = torch.tensor(start, dtype=torch.bfloat16)\n", 44 | " print(f\"\\nfp32 start={start:.2e} using increment={incr}\")\n", 45 | " print(f\"{'bfloat16':>18} {'float16':>18} {'diff':>8}\")\n", 46 | " c = 0\n", 47 | " tries = 0\n", 48 | " while c < 8:\n", 49 | " fp16 = bf16.to(torch.float16)\n", 50 | " if not (fp16 == bf16):\n", 51 | " print(f\"{bf16:.16f} {fp16:.16f} {torch.sub(fp16.to(dtype=torch.float32), bf16):+.2e}\")\n", 52 | " c += 1\n", 53 | " bf16 += incr\n", 54 | " tries += 1\n", 55 | " if tries >= 1e5:\n", 56 | " print(f\"gave up finding mismatch after {tries} steps\")\n", 57 | " return" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "applied-damages", 63 | "metadata": {}, 64 | "source": [ 65 | "## Underflow for fp16\n", 66 | "\n", 67 | "when numbers become 0.0" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "id": "cooperative-government", 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "\n", 81 | "fp32 start=1.00e-08 using increment=1e-09\n", 82 | " bfloat16 float16 diff\n", 83 | "0.0000000100117177 0.0000000000000000 -1.00e-08\n", 84 | "0.0000000110012479 0.0000000000000000 -1.10e-08\n", 85 | "0.0000000119907781 0.0000000000000000 -1.20e-08\n", 86 | "0.0000000129803084 0.0000000000000000 -1.30e-08\n", 87 | "0.0000000139698386 0.0000000000000000 -1.40e-08\n", 88 | "0.0000000150175765 0.0000000000000000 -1.50e-08\n", 89 | "0.0000000160653144 0.0000000000000000 -1.61e-08\n", 90 | "0.0000000171130523 0.0000000000000000 -1.71e-08\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "find_mismatch(1e-08, 1e-09)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "id": "decimal-fraction", 101 | "metadata": {}, 102 | "source": [ 103 | 
"## Subnormal range for fp16\n", 104 | "\n", 105 | "starting from 5.96e-8 \n", 106 | "\n", 107 | "usually expensive and very low precision" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 4, 113 | "id": "statutory-procurement", 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "\n", 121 | "fp32 start=1.00e-07 using increment=1e-08\n", 122 | " bfloat16 float16 diff\n", 123 | "0.0000001001171768 0.0000001192092896 +1.91e-08\n", 124 | "0.0000001098960638 0.0000001192092896 +9.31e-09\n", 125 | "0.0000001201406121 0.0000001192092896 -9.31e-10\n", 126 | "0.0000001303851604 0.0000001192092896 -1.12e-08\n", 127 | "0.0000001406297088 0.0000001192092896 -2.14e-08\n", 128 | "0.0000001508742571 0.0000001788139343 +2.79e-08\n", 129 | "0.0000001611188054 0.0000001788139343 +1.77e-08\n", 130 | "0.0000001713633537 0.0000001788139343 +7.45e-09\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "# very limited range for fp16\n", 136 | "find_mismatch(1e-07, 1e-08)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 5, 142 | "id": "distributed-puppy", 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stdout", 147 | "output_type": "stream", 148 | "text": [ 149 | "\n", 150 | "fp32 start=1.00e-06 using increment=1e-07\n", 151 | " bfloat16 float16 diff\n", 152 | "0.0000009983778000 0.0000010132789612 +1.49e-08\n", 153 | "0.0000010952353477 0.0000010728836060 -2.24e-08\n", 154 | "0.0000012889504433 0.0000013113021851 +2.24e-08\n", 155 | "0.0000013858079910 0.0000013709068298 -1.49e-08\n", 156 | "0.0000014826655388 0.0000014901161194 +7.45e-09\n", 157 | "0.0000015795230865 0.0000015497207642 -2.98e-08\n", 158 | "0.0000016763806343 0.0000016689300537 -7.45e-09\n", 159 | "0.0000017732381821 0.0000017881393433 +1.49e-08\n" 160 | ] 161 | } 162 | ], 163 | "source": [ 164 | "# things starting to improve slightly for fp16\n", 165 | "find_mismatch(1e-06, 
1e-07)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "id": "expired-drinking", 171 | "metadata": {}, 172 | "source": [ 173 | "## Normal numbers\n", 174 | "\n", 175 | "Min positive normal fp16: 6.104e-05 (`np.finfo(np.float16).tiny`)\n", 176 | "\n", 177 | "These ranges match much better and thus will not easily find a mismatch if at all" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 6, 183 | "id": "seven-caution", 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "name": "stdout", 188 | "output_type": "stream", 189 | "text": [ 190 | "\n", 191 | "fp32 start=1.00e-05 using increment=1e-06\n", 192 | " bfloat16 float16 diff\n", 193 | "gave up finding mismatch after 100000 steps\n", 194 | "\n", 195 | "fp32 start=1.00e-04 using increment=1e-06\n", 196 | " bfloat16 float16 diff\n", 197 | "gave up finding mismatch after 100000 steps\n", 198 | "\n", 199 | "fp32 start=1.00e-03 using increment=0.0001\n", 200 | " bfloat16 float16 diff\n", 201 | "gave up finding mismatch after 100000 steps\n", 202 | "\n", 203 | "fp32 start=1.00e-02 using increment=0.001\n", 204 | " bfloat16 float16 diff\n", 205 | "gave up finding mismatch after 100000 steps\n", 206 | "\n", 207 | "fp32 start=1.00e-01 using increment=0.01\n", 208 | " bfloat16 float16 diff\n", 209 | "gave up finding mismatch after 100000 steps\n", 210 | "\n", 211 | "fp32 start=1.00e+01 using increment=1e-06\n", 212 | " bfloat16 float16 diff\n", 213 | "gave up finding mismatch after 100000 steps\n", 214 | "\n", 215 | "fp32 start=1.00e+01 using increment=10.0\n", 216 | " bfloat16 float16 diff\n", 217 | "gave up finding mismatch after 100000 steps\n", 218 | "\n", 219 | "fp32 start=1.00e+04 using increment=1\n", 220 | " bfloat16 float16 diff\n", 221 | "gave up finding mismatch after 100000 steps\n" 222 | ] 223 | } 224 | ], 225 | "source": [ 226 | "find_mismatch(1e-05, 1e-06)\n", 227 | "find_mismatch(1e-04, 1e-06)\n", 228 | "find_mismatch(1e-03, 1e-04)\n", 229 | "find_mismatch(1e-02, 
1e-03)\n", 230 | "find_mismatch(1e-01, 1e-02)\n", 231 | "find_mismatch(1e1, 1e-06)\n", 232 | "find_mismatch(1e1, 1e1)\n", 233 | "find_mismatch(1e4, 1)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 7, 239 | "id": "mighty-injection", 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "name": "stdout", 244 | "output_type": "stream", 245 | "text": [ 246 | "\n", 247 | "fp32 start=5.00e+04 using increment=1000.0\n", 248 | " bfloat16 float16 diff\n", 249 | "66048.0000000000000000 inf +inf\n", 250 | "67072.0000000000000000 inf +inf\n", 251 | "68096.0000000000000000 inf +inf\n", 252 | "69120.0000000000000000 inf +inf\n", 253 | "70144.0000000000000000 inf +inf\n", 254 | "71168.0000000000000000 inf +inf\n", 255 | "72192.0000000000000000 inf +inf\n", 256 | "73216.0000000000000000 inf +inf\n" 257 | ] 258 | } 259 | ], 260 | "source": [ 261 | "# hitting max range for fp16\n", 262 | "find_mismatch(5e4, 1e3)" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 8, 268 | "id": "alleged-lobby", 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "# --- roundoff ---\n", 273 | "# fp16 4.88e-4\n", 274 | "# bf16 3.91e-3" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "id": "cheap-natural", 280 | "metadata": {}, 281 | "source": [ 282 | "## Big numbers\n", 283 | "\n", 284 | "`bfloat16` numbers have a terrible range for numbers `> 1` but `fp16` matches those exactly\n", 285 | "e.g. 
one can't represent 283 in bf16\n", 286 | "\n", 287 | "```\n", 288 | "python -c \"import torch; print( torch.tensor(283, dtype=torch.bfloat16) )\"\n", 289 | "tensor(284., dtype=torch.bfloat16)\n", 290 | "```" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 9, 296 | "id": "integrated-individual", 297 | "metadata": {}, 298 | "outputs": [ 299 | { 300 | "name": "stdout", 301 | "output_type": "stream", 302 | "text": [ 303 | "282.00\n", 304 | "284.00\n", 305 | "286.00\n" 306 | ] 307 | } 308 | ], 309 | "source": [ 310 | "start = 280\n", 311 | "fp32 = torch.tensor(start, dtype=torch.float32)\n", 312 | "for i in range(3):\n", 313 | " bf16 = fp32.to(torch.bfloat16)\n", 314 | " bf16d = bf16\n", 315 | " while bf16 == bf16d:\n", 316 | " fp32 += 1\n", 317 | " bf16d = fp32.to(torch.bfloat16)\n", 318 | " print(f\"{bf16d:.2f}\")\n", 319 | "# 282\n", 320 | "# 284\n", 321 | "# 286" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "id": "ae153e44", 327 | "metadata": {}, 328 | "source": [ 329 | "## How many positions between 2 numbers\n", 330 | "\n", 331 | "Let's see how many `fp16` numbers can fit between `bf16` numbers - which should help to understand how converting a model trained in `fp16` to `bf16` in a way quantizes the model - since there are less `bf16` numbers in the same range." 
332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 5, 337 | "id": "278e6249", 338 | "metadata": {}, 339 | "outputs": [ 340 | { 341 | "name": "stdout", 342 | "output_type": "stream", 343 | "text": [ 344 | "9 fp16s: [0.10009765625, 0.10015869140625, 0.1002197265625, 0.10028076171875, 0.100341796875, 0.10040283203125, 0.1004638671875, 0.10052490234375, 0.1005859375]\n", 345 | "2 fp16s: [0.10009765625, 0.1005859375]\n" 346 | ] 347 | } 348 | ], 349 | "source": [ 350 | "fp16 = torch.tensor(0.1001, dtype=torch.float16)\n", 351 | "bf16 = torch.tensor(0.1001, dtype=torch.bfloat16)\n", 352 | "fp16s = [fp16]\n", 353 | "bf16s = [bf16]\n", 354 | "\n", 355 | "delta = 0.00001\n", 356 | "for i in range(100):\n", 357 | " fp16_new = fp16 + delta*i\n", 358 | " bf16_new = bf16 + delta*i\n", 359 | " if fp16s[-1] != fp16_new:\n", 360 | " fp16s.append(fp16_new)\n", 361 | " if bf16s[-1] != bf16_new:\n", 362 | " bf16s.append(bf16_new)\n", 363 | " if len(bf16s) > 1 and bf16s[-1] == fp16s[-1]:\n", 364 | " break\n", 365 | " \n", 366 | "print(f\"{len(fp16s)} fp16s: {[x.item() for x in fp16s]}\")\n", 367 | "print(f\"{len(bf16s)} fp16s: {[x.item() for x in bf16s]}\")\n" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "id": "fa0ae862", 373 | "metadata": {}, 374 | "source": [ 375 | "So it can be seen that in this particular range of numbers every 8 \"positions\" in `fp16` get remapped to a single \"position\" in `bf16`. As `exponent(fp16) = 10` and `exponent(bf16) = 7` - so we have `2**3=8` different positions between 2 representations." 
376 | ] 377 | }, 378 | { 379 | "cell_type": "markdown", 380 | "id": "revolutionary-force", 381 | "metadata": {}, 382 | "source": [ 383 | "# Math" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "id": "systematic-latex", 389 | "metadata": {}, 390 | "source": [ 391 | "## Summation\n", 392 | "\n", 393 | "A very narrow dynamic range means that for largish numbers NN trained in `bfloat16` **expects** bad\n", 394 | "precision and when the precision is suddenly higher unexpected outcomes happen:" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": 10, 400 | "id": "unlike-raising", 401 | "metadata": {}, 402 | "outputs": [ 403 | { 404 | "name": "stdout", 405 | "output_type": "stream", 406 | "text": [ 407 | "tensor(284., dtype=torch.bfloat16)\n", 408 | "tensor(283., dtype=torch.float16)\n" 409 | ] 410 | } 411 | ], 412 | "source": [ 413 | "# small sum\n", 414 | "print(torch.tensor(282, dtype=torch.bfloat16)+1) # 284\n", 415 | "print(torch.tensor(282, dtype=torch.float16)+1) # 283" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": 11, 421 | "id": "competitive-average", 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "name": "stdout", 426 | "output_type": "stream", 427 | "text": [ 428 | "tensor(2848., dtype=torch.bfloat16)\n", 429 | "tensor(2830., dtype=torch.float16)\n" 430 | ] 431 | } 432 | ], 433 | "source": [ 434 | "# sum several of these\n", 435 | "print(torch.tensor(283, dtype=torch.bfloat16)*10) # 2848\n", 436 | "print(torch.tensor(283, dtype=torch.float16)*10) # 2830" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "id": "choice-enemy", 442 | "metadata": {}, 443 | "source": [ 444 | "As you can see numbers start to diverge quickly!\n", 445 | "\n", 446 | "Now in practice we typically add up thousands of numbers.\n", 447 | "\n", 448 | "The solution is to always do this kind of operations in double precision of the operands and then if needed casting back to the original. i.e. 
the accumulator of `sum(fp16_tensor)`
so `256**2+1` will be `inf`\n", 497 | "\n", 498 | "You can't even do `pow(2)` for fp16 in pytorch, the following will give an error: that it doesn't suppor power for fp16.\n", 499 | "\n", 500 | "`torch.tensor(256, dtype=torch.float16).pow(2)`\n", 501 | "\n", 502 | "You have to cast to `float32` first:" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 13, 508 | "id": "quick-local", 509 | "metadata": {}, 510 | "outputs": [ 511 | { 512 | "data": { 513 | "text/plain": [ 514 | "tensor(65024., dtype=torch.float16)" 515 | ] 516 | }, 517 | "execution_count": 13, 518 | "metadata": {}, 519 | "output_type": "execute_result" 520 | } 521 | ], 522 | "source": [ 523 | "x = torch.tensor(255, dtype=torch.float16)\n", 524 | "x_squared = x.float().pow(2)\n", 525 | "x_squared.to(dtype=torch.float16)" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 14, 531 | "id": "nonprofit-charleston", 532 | "metadata": {}, 533 | "outputs": [ 534 | { 535 | "data": { 536 | "text/plain": [ 537 | "tensor(inf, dtype=torch.float16)" 538 | ] 539 | }, 540 | "execution_count": 14, 541 | "metadata": {}, 542 | "output_type": "execute_result" 543 | } 544 | ], 545 | "source": [ 546 | "# let's cross into the overflow\n", 547 | "x += 1\n", 548 | "x_squared = x.float().pow(2)\n", 549 | "x_squared.to(dtype=torch.float16)" 550 | ] 551 | }, 552 | { 553 | "cell_type": "markdown", 554 | "id": "dominican-retro", 555 | "metadata": {}, 556 | "source": [ 557 | "And that's how `inf` comes about.\n", 558 | "\n", 559 | "Or if you need to create one, you can just do:" 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": 15, 565 | "id": "returning-certificate", 566 | "metadata": {}, 567 | "outputs": [], 568 | "source": [ 569 | "t_inf = torch.tensor(float('inf'))" 570 | ] 571 | }, 572 | { 573 | "cell_type": "markdown", 574 | "id": "eastern-orleans", 575 | "metadata": {}, 576 | "source": [ 577 | "If you need to compare if a tensor has `inf` elements:" 
578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "execution_count": 16, 583 | "id": "ethical-wayne", 584 | "metadata": {}, 585 | "outputs": [ 586 | { 587 | "data": { 588 | "text/plain": [ 589 | "tensor(True)" 590 | ] 591 | }, 592 | "execution_count": 16, 593 | "metadata": {}, 594 | "output_type": "execute_result" 595 | } 596 | ], 597 | "source": [ 598 | "torch.isinf(t_inf).any()" 599 | ] 600 | }, 601 | { 602 | "cell_type": "markdown", 603 | "id": "derived-canberra", 604 | "metadata": {}, 605 | "source": [ 606 | "## Getting NaNs \n", 607 | "\n", 608 | "While there are many ways to get `NaN` during calculations, the most common for machine learning are:" 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": 17, 614 | "id": "formed-certificate", 615 | "metadata": {}, 616 | "outputs": [ 617 | { 618 | "data": { 619 | "text/plain": [ 620 | "tensor(nan)" 621 | ] 622 | }, 623 | "execution_count": 17, 624 | "metadata": {}, 625 | "output_type": "execute_result" 626 | } 627 | ], 628 | "source": [ 629 | "# 0/0\n", 630 | "t_zero = torch.tensor(0)\n", 631 | "t_zero/t_zero" 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 18, 637 | "id": "authorized-window", 638 | "metadata": {}, 639 | "outputs": [ 640 | { 641 | "data": { 642 | "text/plain": [ 643 | "tensor(nan)" 644 | ] 645 | }, 646 | "execution_count": 18, 647 | "metadata": {}, 648 | "output_type": "execute_result" 649 | } 650 | ], 651 | "source": [ 652 | "# inf/inf\n", 653 | "t_inf = torch.tensor(float('inf'))\n", 654 | "t_inf/t_inf" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": 19, 660 | "id": "elegant-grant", 661 | "metadata": {}, 662 | "outputs": [ 663 | { 664 | "data": { 665 | "text/plain": [ 666 | "tensor(nan)" 667 | ] 668 | }, 669 | "execution_count": 19, 670 | "metadata": {}, 671 | "output_type": "execute_result" 672 | } 673 | ], 674 | "source": [ 675 | "# 0*inf\n", 676 | "t_zero * t_inf" 677 | ] 678 | }, 679 | { 680 | 
"cell_type": "code", 681 | "execution_count": 20, 682 | "id": "several-council", 683 | "metadata": {}, 684 | "outputs": [ 685 | { 686 | "data": { 687 | "text/plain": [ 688 | "tensor(nan)" 689 | ] 690 | }, 691 | "execution_count": 20, 692 | "metadata": {}, 693 | "output_type": "execute_result" 694 | } 695 | ], 696 | "source": [ 697 | "# inf - inf\n", 698 | "t_inf - t_inf" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": 21, 704 | "id": "local-change", 705 | "metadata": {}, 706 | "outputs": [ 707 | { 708 | "data": { 709 | "text/plain": [ 710 | "tensor(nan)" 711 | ] 712 | }, 713 | "execution_count": 21, 714 | "metadata": {}, 715 | "output_type": "execute_result" 716 | } 717 | ], 718 | "source": [ 719 | "# to get one explicitly\n", 720 | "t_nan = torch.tensor(float('nan'))\n", 721 | "t_nan" 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "execution_count": 22, 727 | "id": "hawaiian-evaluation", 728 | "metadata": {}, 729 | "outputs": [ 730 | { 731 | "data": { 732 | "text/plain": [ 733 | "tensor(True)" 734 | ] 735 | }, 736 | "execution_count": 22, 737 | "metadata": {}, 738 | "output_type": "execute_result" 739 | } 740 | ], 741 | "source": [ 742 | "# comparison\n", 743 | "torch.isnan(t_nan).any()" 744 | ] 745 | }, 746 | { 747 | "cell_type": "markdown", 748 | "id": "heated-equivalent", 749 | "metadata": {}, 750 | "source": [ 751 | "# Debugging process" 752 | ] 753 | }, 754 | { 755 | "cell_type": "markdown", 756 | "id": "inside-investor", 757 | "metadata": {}, 758 | "source": [ 759 | "As you can see, since ML is mostly matrix multiplications, which is sums and multiplications, it's enough to get one `inf` or `nan`, and the whole training goes down the rails.\n", 760 | "\n", 761 | "Here is a helper that you can run after suspect functions to see if the output gets any `inf` or `nan`s and also if you want to get an indication on whether you have some large numbers that are likely to overflow - remember in fp16 65K is the biggest number 
one can have." 762 | ] 763 | }, 764 | { 765 | "cell_type": "code", 766 | "execution_count": 23, 767 | "id": "alpine-wrist", 768 | "metadata": {}, 769 | "outputs": [], 770 | "source": [ 771 | "def detect_overflow(var, ctx):\n", 772 | " \"\"\"\n", 773 | " Report the count of ``nan`` and ``inf`` entries in the tensor.\n", 774 | "\n", 775 | " This is useful for detecting overflows/underflows and best to call right after the function that did some math that\n", 776 | " modified the variable in question.\n", 777 | "\n", 778 | " Args:\n", 779 | " var: tensor variable to check\n", 780 | " ctx: the message to print as a context\n", 781 | " \"\"\"\n", 782 | " if torch.isnan(var).any().item():\n", 783 | " logger.warning(f\"{ctx} has nans\")\n", 784 | " if torch.isinf(var).any().item():\n", 785 | " logger.warning(f\"{ctx} has inf\")\n", 786 | "\n", 787 | " # if needed to monitor large elements can enable the following\n", 788 | " if 0:\n", 789 | " n100 = var[torch.ge(var.abs(), 100)]\n", 790 | " if n100.numel() > 0:\n", 791 | " logger.warning(f\"{ctx}: n100={n100.numel()}\")\n", 792 | " n1000 = var[torch.ge(var.abs(), 1000)]\n", 793 | " if n1000.numel() > 0:\n", 794 | " logger.warning(f\"{ctx}: n1000={n1000.numel()}\")" 795 | ] 796 | }, 797 | { 798 | "cell_type": "markdown", 799 | "id": "vocal-passing", 800 | "metadata": {}, 801 | "source": [ 802 | "So, if you training gives you say a loss of `nan`, you can go to the layers of your model and inject this function, in one or more places, e.g.:" 803 | ] 804 | }, 805 | { 806 | "cell_type": "code", 807 | "execution_count": 24, 808 | "id": "exempt-environment", 809 | "metadata": {}, 810 | "outputs": [], 811 | "source": [ 812 | "def forward(x):\n", 813 | " detect_overflow(x, \"x / enter\")\n", 814 | " y = self.ff(x)\n", 815 | " detect_overflow(x, \"y / after ff\") " 816 | ] 817 | }, 818 | { 819 | "cell_type": "markdown", 820 | "id": "final-tracy", 821 | "metadata": {}, 822 | "source": [ 823 | "or you use an advanced debugger you can 
assign watches that will immediately tell you if a tensor just got some `inf`s, by having a dynamically evaluated watch expression: `torch.isinf(x).any().item()` - in this example we watch the tensor `x`. So as you step through the code you can visually immediately see if it went from `False` to `True`. " 824 | ] 825 | }, 826 | { 827 | "cell_type": "markdown", 828 | "id": "caroline-group", 829 | "metadata": {}, 830 | "source": [ 831 | "\n" 832 | ] 833 | }, 834 | { 835 | "cell_type": "markdown", 836 | "id": "imported-empty", 837 | "metadata": {}, 838 | "source": [ 839 | "# Disabling subnormal numbers in pytorch\n", 840 | "\n", 841 | "In some systems subnormal number calculation can be suboptimial as it's often done in software, so if your network deals a lot with subnormal numbers you might want to disable those and scale your numbers to a normal range instead.\n", 842 | "\n", 843 | "The following demonstrates how it works in pytorch" 844 | ] 845 | }, 846 | { 847 | "cell_type": "code", 848 | "execution_count": 25, 849 | "id": "fuzzy-pound", 850 | "metadata": { 851 | "run_control": { 852 | "marked": false 853 | } 854 | }, 855 | "outputs": [ 856 | { 857 | "data": { 858 | "text/plain": [ 859 | "tensor([0.])" 860 | ] 861 | }, 862 | "execution_count": 25, 863 | "metadata": {}, 864 | "output_type": "execute_result" 865 | }, 866 | { 867 | "data": { 868 | "text/plain": [ 869 | "tensor([1.0000e-39])" 870 | ] 871 | }, 872 | "execution_count": 25, 873 | "metadata": {}, 874 | "output_type": "execute_result" 875 | } 876 | ], 877 | "source": [ 878 | "_ = torch.set_flush_denormal(True)\n", 879 | "torch.tensor([1e-39], dtype=torch.float32)\n", 880 | "_ = torch.set_flush_denormal(False)\n", 881 | "torch.tensor([1e-39], dtype=torch.float32)" 882 | ] 883 | }, 884 | { 885 | "cell_type": "code", 886 | "execution_count": 26, 887 | "id": "surrounded-upper", 888 | "metadata": {}, 889 | "outputs": [ 890 | { 891 | "data": { 892 | "text/plain": [ 893 | "tensor([1.0133e-06], 
dtype=torch.float16)" 894 | ] 895 | }, 896 | "execution_count": 26, 897 | "metadata": {}, 898 | "output_type": "execute_result" 899 | }, 900 | { 901 | "data": { 902 | "text/plain": [ 903 | "tensor([1.0133e-06], dtype=torch.float16)" 904 | ] 905 | }, 906 | "execution_count": 26, 907 | "metadata": {}, 908 | "output_type": "execute_result" 909 | } 910 | ], 911 | "source": [ 912 | "# broken for fp16\n", 913 | "_ = torch.set_flush_denormal(True)\n", 914 | "torch.tensor([1e-6], dtype=torch.float16)\n", 915 | "_ = torch.set_flush_denormal(False)\n", 916 | "torch.tensor([1e-6], dtype=torch.float16)" 917 | ] 918 | }, 919 | { 920 | "cell_type": "code", 921 | "execution_count": 27, 922 | "id": "attempted-deficit", 923 | "metadata": {}, 924 | "outputs": [ 925 | { 926 | "data": { 927 | "text/plain": [ 928 | "tensor([0.], dtype=torch.bfloat16)" 929 | ] 930 | }, 931 | "execution_count": 27, 932 | "metadata": {}, 933 | "output_type": "execute_result" 934 | }, 935 | { 936 | "data": { 937 | "text/plain": [ 938 | "tensor([1.0102e-39], dtype=torch.bfloat16)" 939 | ] 940 | }, 941 | "execution_count": 27, 942 | "metadata": {}, 943 | "output_type": "execute_result" 944 | } 945 | ], 946 | "source": [ 947 | "_ = torch.set_flush_denormal(True)\n", 948 | "torch.tensor([1e-39], dtype=torch.bfloat16)\n", 949 | "_ = torch.set_flush_denormal(False)\n", 950 | "torch.tensor([1e-39], dtype=torch.bfloat16)" 951 | ] 952 | }, 953 | { 954 | "cell_type": "code", 955 | "execution_count": null, 956 | "id": "hollywood-karma", 957 | "metadata": {}, 958 | "outputs": [], 959 | "source": [] 960 | }, 961 | { 962 | "cell_type": "code", 963 | "execution_count": null, 964 | "id": "indie-message", 965 | "metadata": {}, 966 | "outputs": [], 967 | "source": [ 968 | "%%javascript # prevent committing an unsaved notebook\n", 969 | "IPython.notebook.save_notebook()" 970 | ] 971 | } 972 | ], 973 | "metadata": { 974 | "hide_input": false, 975 | "kernelspec": { 976 | "display_name": "Python 3 (ipykernel)", 977 | 
"language": "python", 978 | "name": "python3" 979 | }, 980 | "language_info": { 981 | "codemirror_mode": { 982 | "name": "ipython", 983 | "version": 3 984 | }, 985 | "file_extension": ".py", 986 | "mimetype": "text/x-python", 987 | "name": "python", 988 | "nbconvert_exporter": "python", 989 | "pygments_lexer": "ipython3", 990 | "version": "3.8.15" 991 | }, 992 | "toc": { 993 | "base_numbering": 1, 994 | "nav_menu": {}, 995 | "number_sections": true, 996 | "sideBar": true, 997 | "skip_h1_title": false, 998 | "title_cell": "Table of Contents", 999 | "title_sidebar": "Contents", 1000 | "toc_cell": false, 1001 | "toc_position": { 1002 | "height": "calc(100% - 180px)", 1003 | "left": "10px", 1004 | "top": "150px", 1005 | "width": "256.391px" 1006 | }, 1007 | "toc_section_display": true, 1008 | "toc_window_display": true 1009 | } 1010 | }, 1011 | "nbformat": 4, 1012 | "nbformat_minor": 5 1013 | } 1014 | -------------------------------------------------------------------------------- /fp16-to-bf16/fp16-bf16-convert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "7f243247", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import torch" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 4, 17 | "id": "5500ad8b", 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "ename": "AttributeError", 22 | "evalue": "'Tensor' object has no attribute 'bytes'", 23 | "output_type": "error", 24 | "traceback": [ 25 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 26 | "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", 27 | "Cell \u001b[0;32mIn [4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m 
\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbfloat16\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbytes\u001b[49m()\n", 28 | "\u001b[0;31mAttributeError\u001b[0m: 'Tensor' object has no attribute 'bytes'" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "torch.tensor([2], dtype=torch.bfloat16)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "8ae686a5", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "bfloat16 = True" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 32, 49 | "id": "dc51a3ad", 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "0\n", 57 | "10000\n", 58 | "20000\n", 59 | "30000\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "bf16 = {}\n", 65 | "fp16 = {}\n", 66 | "for index in range(2**15):\n", 67 | " binary = \"{0:b}\".format(index)\n", 68 | " binary = f\"{binary:0>16}\"\n", 69 | " \n", 70 | " sign, exp, frac = binary[:1], binary[2:9], binary[9:]\n", 71 | " bf16[index] = int(frac, 2)*2**int(exp, 2)\n", 72 | " \n", 73 | " sign, exp, frac = binary[:1], binary[2:6], binary[6:]\n", 74 | " fp16[index] = int(frac, 2)*2**int(exp, 2)\n", 75 | " \n", 76 | " \n", 77 | " if index % 10000 == 0:\n", 78 | " print(index)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "ae99bb10", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [] 88 | } 89 | ], 90 | "metadata": { 91 | "kernelspec": { 92 | "display_name": "Python 3 (ipykernel)", 93 | "language": "python", 94 | "name": "python3" 95 | }, 96 | "language_info": { 97 | "codemirror_mode": { 98 | "name": 
"ipython", 99 | "version": 3 100 | }, 101 | "file_extension": ".py", 102 | "mimetype": "text/x-python", 103 | "name": "python", 104 | "nbconvert_exporter": "python", 105 | "pygments_lexer": "ipython3", 106 | "version": "3.10.6" 107 | } 108 | }, 109 | "nbformat": 4, 110 | "nbformat_minor": 5 111 | } 112 | -------------------------------------------------------------------------------- /fp16-to-bf16/train-fp16.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 109, 6 | "id": "17e58755", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import random\n", 11 | "\n", 12 | "import numpy as np\n", 13 | "import torch\n", 14 | "\n", 15 | "import torch.nn as nn\n", 16 | "import torchvision \n", 17 | "import torchvision.transforms as transforms" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "id": "05f2ded0", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "device = torch.device(\"cuda\")" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 179, 33 | "id": "2b5ae7a6", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "class MyModel(nn.Module):\n", 38 | " def __init__(self, input_size, output_size, dtype):\n", 39 | " super().__init__()\n", 40 | " \n", 41 | " \n", 42 | " self.linear = nn.Linear(input_size, 3, bias=False)\n", 43 | " if dtype == torch.bfloat16:\n", 44 | " self.linear.weight = nn.Parameter(self.linear.weight * 3)\n", 45 | " else:\n", 46 | " self.linear.weight = nn.Parameter(self.linear.weight / 10.)\n", 47 | " \n", 48 | " self.non_lin = nn.ReLU()\n", 49 | " self.output = nn.Linear(3, output_size, bias=False)\n", 50 | " if dtype == torch.bfloat16:\n", 51 | " self.output.weight = nn.Parameter(self.output.weight * 3)\n", 52 | " else:\n", 53 | " self.output.weight = nn.Parameter(self.output.weight / 10.)\n", 54 | " \n", 55 | " def forward(self, inputs):\n", 56 | " 
inputs = self.linear(inputs)\n", 57 | " inputs = self.non_lin(inputs)\n", 58 | " return self.output(inputs)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 180, 64 | "id": "0775f6d3", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "random.seed(0)\n", 69 | "torch.manual_seed(0)\n", 70 | "torch.cuda.manual_seed(0)\n", 71 | "torch.backends.cudnn.deterministic = True\n", 72 | "dataset_input = torch.randn((1000, 4), device=device)\n", 73 | "dataset_label = (0.1 * (torch.randint(0, 10, (1000,), device=device)) + 5 ).long() * 10" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 189, 79 | "id": "9e3ca4f7", 80 | "metadata": { 81 | "scrolled": true 82 | }, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "3.751953125\n", 89 | "2.7734375\n", 90 | "1.1162109375\n", 91 | "0.1075439453125\n", 92 | "0.0081787109375\n", 93 | "0.0009007453918457031\n", 94 | "0.00014412403106689453\n", 95 | "2.4437904357910156e-05\n", 96 | "5.125999450683594e-06\n", 97 | "finished\n", 98 | "4.1875\n", 99 | "4.0\n", 100 | "3.65625\n", 101 | "2.28125\n", 102 | "0.2734375\n", 103 | "0.01104736328125\n", 104 | "0.0004291534423828125\n", 105 | "1.239776611328125e-05\n", 106 | "2.384185791015625e-07\n", 107 | "finished\n" 108 | ] 109 | } 110 | ], 111 | "source": [ 112 | "data = {}\n", 113 | "for j, dtype in enumerate((torch.float16, torch.bfloat16)):\n", 114 | " \n", 115 | " random.seed(0)\n", 116 | " torch.manual_seed(0)\n", 117 | " torch.cuda.manual_seed(0)\n", 118 | " torch.backends.cudnn.deterministic = True\n", 119 | " \n", 120 | " model = MyModel(4, 60, dtype=dtype).to(device=device)\n", 121 | "\n", 122 | " optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", 123 | " scaler = torch.cuda.amp.GradScaler()\n", 124 | " \n", 125 | " f_loss = nn.CrossEntropyLoss()\n", 126 | " \n", 127 | " for epoch in range(9):\n", 128 | " for i in range(1000):\n", 129 | "\n", 130 | " with 
torch.cuda.amp.autocast(enabled=True, dtype=dtype):\n", 131 | " output = model(dataset_input[i].unsqueeze(0))\n", 132 | " loss = f_loss(output, dataset_label[i].unsqueeze(0))\n", 133 | "\n", 134 | " # Backward and optimize\n", 135 | " optimizer.zero_grad()\n", 136 | " \n", 137 | " if dtype==torch.bfloat16:\n", 138 | " loss.backward()\n", 139 | " optimizer.step()\n", 140 | " else:\n", 141 | " scaler.scale(loss).backward()\n", 142 | "\n", 143 | " # scaler.step() first unscales the gradients of the optimizer's assigned params.\n", 144 | " # If these gradients do not contain infs or NaNs, optimizer.step() is then called,\n", 145 | " # otherwise, optimizer.step() is skipped.\n", 146 | " scaler.step(optimizer)\n", 147 | "\n", 148 | " # Updates the scale for next iteration.\n", 149 | " scaler.update()\n", 150 | "\n", 151 | " print(loss.item())\n", 152 | " \n", 153 | " \n", 154 | " for name, parameter in model.named_parameters():\n", 155 | " u, s, v = torch.svd_lowrank(parameter, q=3)\n", 156 | " \n", 157 | " data[f\"{j}_{dtype}_{name}\"] = parameter\n", 158 | " \n", 159 | " data[f\"{j}_{dtype}_{name}_u\"] = u\n", 160 | " data[f\"{j}_{dtype}_{name}_s\"] = s\n", 161 | " data[f\"{j}_{dtype}_{name}_v\"] = v\n", 162 | " \n", 163 | " print(\"finished\")" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 190, 169 | "id": "ae31a9b0", 170 | "metadata": { 171 | "scrolled": true 172 | }, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "tensor([[-0.4757, 0.6781, -0.5603],\n", 179 | " [ 0.7616, -0.0011, -0.6480],\n", 180 | " [-0.4400, -0.7350, -0.5159]], device='cuda:0',\n", 181 | " grad_fn=)\n", 182 | "tensor([3.2457, 2.2845, 0.1098], device='cuda:0', grad_fn=)\n", 183 | "tensor([[-0.3599, 0.7246, 0.5300],\n", 184 | " [-0.5523, -0.4250, 0.4669],\n", 185 | " [ 0.1861, -0.5160, 0.5169],\n", 186 | " [ 0.7286, 0.1676, 0.4837]], device='cuda:0', grad_fn=)\n" 187 | ] 188 | } 189 | ], 190 | "source": [ 191 | 
"for x in (\"u\", \"s\", \"v\"):\n", 192 | " print(data[f'0_torch.float16_linear.weight_{x}'])" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 191, 198 | "id": "73aa9221", 199 | "metadata": {}, 200 | "outputs": [ 201 | { 202 | "name": "stdout", 203 | "output_type": "stream", 204 | "text": [ 205 | "tensor([[-0.1856, -0.9520, 0.2434],\n", 206 | " [ 0.6623, 0.0618, 0.7467],\n", 207 | " [-0.7259, 0.2998, 0.6190]], device='cuda:0',\n", 208 | " grad_fn=)\n", 209 | "tensor([3.6283, 1.8956, 0.3372], device='cuda:0', grad_fn=)\n", 210 | "tensor([[-0.1361, -0.2753, -0.9333],\n", 211 | " [-0.1401, -0.1053, 0.2421],\n", 212 | " [ 0.4078, 0.8456, -0.2649],\n", 213 | " [ 0.8920, -0.4451, 0.0167]], device='cuda:0', grad_fn=)\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "for x in (\"u\", \"s\", \"v\"):\n", 219 | " print(data[f'1_torch.bfloat16_linear.weight_{x}'])" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 192, 225 | "id": "16a0f8a4", 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "name": "stdout", 230 | "output_type": "stream", 231 | "text": [ 232 | "Parameter containing:\n", 233 | "tensor([[ 1.6456, 0.1657, -1.1185, -0.8951],\n", 234 | " [-0.9291, -1.3975, 0.4246, 1.7662],\n", 235 | " [-0.7328, 1.4759, 0.5712, -1.3493]], device='cuda:0',\n", 236 | " requires_grad=True)\n", 237 | "Parameter containing:\n", 238 | "tensor([[-3.0909, -3.2617, -2.9867],\n", 239 | " [-3.0261, -3.1987, -2.9237],\n", 240 | " [-3.0696, -3.2513, -2.9388],\n", 241 | " [-2.9730, -3.2431, -2.9065],\n", 242 | " [-3.0353, -3.2192, -2.9048],\n", 243 | " [-3.0885, -3.2604, -2.9769],\n", 244 | " [-3.0623, -3.1596, -3.0052],\n", 245 | " [-3.0633, -3.2635, -3.0158],\n", 246 | " [-3.0674, -3.1641, -2.9391],\n", 247 | " [-3.0043, -3.2185, -2.9878],\n", 248 | " [-3.0221, -3.2824, -2.9980],\n", 249 | " [-3.0613, -3.1806, -2.9290],\n", 250 | " [-3.0543, -3.2264, -2.9223],\n", 251 | " [-2.9691, -3.2000, -2.9461],\n", 252 | " [-2.9860, 
-3.2657, -2.9405],\n", 253 | " [-3.0804, -3.2640, -2.9919],\n", 254 | " [-3.0080, -3.1943, -2.9946],\n", 255 | " [-3.0149, -3.1858, -2.9676],\n", 256 | " [-3.0253, -3.2099, -2.9217],\n", 257 | " [-2.9709, -3.2762, -2.9719],\n", 258 | " [-3.0029, -3.1717, -2.9066],\n", 259 | " [-2.9822, -3.2084, -3.0078],\n", 260 | " [-3.0292, -3.2606, -3.0127],\n", 261 | " [-2.9842, -3.1695, -3.0182],\n", 262 | " [-3.0200, -3.2339, -2.9672],\n", 263 | " [-3.0646, -3.1923, -2.9999],\n", 264 | " [-3.0048, -3.1931, -2.9152],\n", 265 | " [-3.0087, -3.2873, -2.9918],\n", 266 | " [-3.0056, -3.2067, -3.0048],\n", 267 | " [-3.0656, -3.1597, -2.9255],\n", 268 | " [-3.0634, -3.2337, -3.0187],\n", 269 | " [-3.0347, -3.2694, -3.0029],\n", 270 | " [-3.0377, -3.2104, -2.9853],\n", 271 | " [-2.9856, -3.2701, -2.8965],\n", 272 | " [-2.9872, -3.2805, -2.9662],\n", 273 | " [-3.0272, -3.2140, -2.9450],\n", 274 | " [-3.0100, -3.2176, -2.9869],\n", 275 | " [-3.0024, -3.2859, -2.9881],\n", 276 | " [-3.0474, -3.2529, -2.9798],\n", 277 | " [-3.0800, -3.2349, -2.9498],\n", 278 | " [-3.0668, -3.2270, -2.9188],\n", 279 | " [-3.0382, -3.2199, -2.9651],\n", 280 | " [-3.0140, -3.1846, -2.9025],\n", 281 | " [-2.9949, -3.1600, -2.9640],\n", 282 | " [-3.0810, -3.2545, -2.9202],\n", 283 | " [-3.0348, -3.2525, -3.0038],\n", 284 | " [-3.0879, -3.2749, -2.9730],\n", 285 | " [-3.0047, -3.1836, -3.0174],\n", 286 | " [-2.9910, -3.2765, -2.9645],\n", 287 | " [-3.0571, -3.2333, -2.9727],\n", 288 | " [ 2.9798, 3.1726, 2.9504],\n", 289 | " [-3.0263, -3.2550, -2.9336],\n", 290 | " [-3.0812, -3.2280, -2.9100],\n", 291 | " [-3.0572, -3.1614, -2.9435],\n", 292 | " [-3.0902, -3.1759, -2.9745],\n", 293 | " [-3.0666, -3.1628, -3.0150],\n", 294 | " [-3.0189, -3.2405, -2.9135],\n", 295 | " [-3.0128, -3.1927, -2.9379],\n", 296 | " [-3.0507, -3.2320, -3.0100],\n", 297 | " [-2.9950, -3.1737, -2.9171]], device='cuda:0', requires_grad=True)\n" 298 | ] 299 | } 300 | ], 301 | "source": [ 302 | "for x in (\"linear\", \"output\"):\n", 303 | " 
print(data[f'0_torch.float16_{x}.weight'])" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 193, 309 | "id": "2e4932ae", 310 | "metadata": {}, 311 | "outputs": [ 312 | { 313 | "name": "stdout", 314 | "output_type": "stream", 315 | "text": [ 316 | "Parameter containing:\n", 317 | "tensor([[ 0.5118, 0.3042, -1.8223, 0.2039],\n", 318 | " [-0.5942, -0.2880, 1.0123, 2.0953],\n", 319 | " [ 0.0071, 0.3596, -0.6487, -2.5987]], device='cuda:0',\n", 320 | " requires_grad=True)\n", 321 | "Parameter containing:\n", 322 | "tensor([[-4.4419, -4.9634, -4.3817],\n", 323 | " [-2.9160, -3.0629, -2.7884],\n", 324 | " [-4.1905, -4.6531, -3.2930],\n", 325 | " [-0.4725, -3.3024, -0.9470],\n", 326 | " [-3.3107, -3.7164, -2.3899],\n", 327 | " [-4.4433, -4.9202, -4.1564],\n", 328 | " [-3.5436, -1.4986, -4.4286],\n", 329 | " [-3.4613, -5.0607, -5.1958],\n", 330 | " [-4.0473, -1.5142, -2.7730],\n", 331 | " [-1.8561, -3.8900, -4.3769],\n", 332 | " [-2.3192, -5.5579, -4.8079],\n", 333 | " [-4.0529, -2.3221, -2.7766],\n", 334 | " [-3.8760, -3.9557, -2.8824],\n", 335 | " [-0.2306, -2.4180, -2.1388],\n", 336 | " [-1.2093, -4.8995, -2.8423],\n", 337 | " [-4.1038, -5.0288, -4.5613],\n", 338 | " [-2.0090, -3.0834, -4.5075],\n", 339 | " [-2.3666, -2.6742, -3.8056],\n", 340 | " [-2.8977, -3.4578, -2.7808],\n", 341 | " [-0.3357, -4.9915, -3.4214],\n", 342 | " [-1.7502, -1.6765, -1.5579],\n", 343 | " [-0.8048, -3.5790, -4.5389],\n", 344 | " [-2.4829, -5.0370, -5.1643],\n", 345 | " [-0.9278, -2.0825, -4.6137],\n", 346 | " [-2.4998, -4.2839, -4.0042],\n", 347 | " [-3.7818, -2.9778, -4.6614],\n", 348 | " [-2.0236, -2.7549, -2.2046],\n", 349 | " [-1.9007, -5.6943, -4.6004],\n", 350 | " [-1.8497, -3.5475, -4.7965],\n", 351 | " [-3.9569, -1.1794, -2.3032],\n", 352 | " [-3.5146, -4.3009, -5.2434],\n", 353 | " [-2.6969, -5.2310, -4.9397],\n", 354 | " [-3.0324, -3.5839, -4.4497],\n", 355 | " [-1.1692, -4.5088, -1.2816],\n", 356 | " [-1.1784, -5.4390, -3.6473],\n", 357 | " [-2.8891, 
-3.6375, -3.4287],\n", 358 | " [-2.0753, -3.8593, -4.4060],\n", 359 | " [-1.6883, -5.6669, -4.4468],\n", 360 | " [-3.2633, -4.7668, -4.3690],\n", 361 | " [-4.4553, -4.2293, -3.4943],\n", 362 | " [-4.2752, -3.9635, -2.7446],\n", 363 | " [-3.1437, -3.8463, -3.9787],\n", 364 | " [-2.3573, -2.3726, -1.8725],\n", 365 | " [-1.4033, -1.3287, -3.0406],\n", 366 | " [-4.6340, -4.7124, -2.7629],\n", 367 | " [-2.7265, -4.8092, -4.9508],\n", 368 | " [-4.4315, -5.2667, -4.0727],\n", 369 | " [-1.8000, -2.7150, -4.9971],\n", 370 | " [-1.3535, -5.3647, -3.6639],\n", 371 | " [-3.6391, -4.2246, -4.1554],\n", 372 | " [ 1.2187, 2.2229, 3.2045],\n", 373 | " [-2.8537, -4.7942, -3.1981],\n", 374 | " [-4.7412, -3.9762, -2.4346],\n", 375 | " [-3.6646, -1.3561, -2.8693],\n", 376 | " [-4.6493, -2.2637, -3.7564],\n", 377 | " [-3.6628, -1.7340, -4.7200],\n", 378 | " [-2.6589, -4.3691, -2.5699],\n", 379 | " [-2.3692, -2.8685, -3.0278],\n", 380 | " [-3.2079, -4.2578, -5.0685],\n", 381 | " [-1.4336, -1.7821, -1.7491]], device='cuda:0', requires_grad=True)\n" 382 | ] 383 | } 384 | ], 385 | "source": [ 386 | "for x in (\"linear\", \"output\"):\n", 387 | " print(data[f'1_torch.bfloat16_{x}.weight'])" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 176, 393 | "id": "ee5a7f88", 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "def fp16_to_bf16(weight, q=6):\n", 398 | " # basic: convert each sep\n", 399 | " \n", 400 | " u, s, v = torch.svd_lowrank(weight, q=q, niter=10)\n", 401 | " return u, s, v" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": 177, 407 | "id": "4a337e90", 408 | "metadata": {}, 409 | "outputs": [ 410 | { 411 | "ename": "AssertionError", 412 | "evalue": "(torch.Size([3, 3]), 6)", 413 | "output_type": "error", 414 | "traceback": [ 415 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 416 | "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", 
417 | "Cell \u001b[0;32mIn [177], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, parameter \u001b[38;5;129;01min\u001b[39;00m model\u001b[38;5;241m.\u001b[39mnamed_parameters():\n\u001b[0;32m----> 2\u001b[0m u, s, v \u001b[38;5;241m=\u001b[39m \u001b[43mfp16_to_bf16\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparameter\u001b[49m\u001b[43m)\u001b[49m\n", 418 | "Cell \u001b[0;32mIn [176], line 4\u001b[0m, in \u001b[0;36mfp16_to_bf16\u001b[0;34m(weight, q)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfp16_to_bf16\u001b[39m(weight, q\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m6\u001b[39m):\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m# basic: convert each sep\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m u, s, v \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msvd_lowrank\u001b[49m\u001b[43m(\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mq\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mniter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m u, s, v\n", 419 | "File \u001b[0;32m~/.env/pytorch/lib/python3.10/site-packages/torch/_lowrank.py:137\u001b[0m, in \u001b[0;36msvd_lowrank\u001b[0;34m(A, q, niter, M)\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mset\u001b[39m(\u001b[38;5;28mmap\u001b[39m(\u001b[38;5;28mtype\u001b[39m, tensor_ops))\u001b[38;5;241m.\u001b[39missubset(\n\u001b[1;32m 132\u001b[0m (torch\u001b[38;5;241m.\u001b[39mTensor, \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28;01mNone\u001b[39;00m))\n\u001b[1;32m 133\u001b[0m ) \u001b[38;5;129;01mand\u001b[39;00m has_torch_function(tensor_ops):\n\u001b[1;32m 134\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
handle_torch_function(\n\u001b[1;32m 135\u001b[0m svd_lowrank, tensor_ops, A, q\u001b[38;5;241m=\u001b[39mq, niter\u001b[38;5;241m=\u001b[39mniter, M\u001b[38;5;241m=\u001b[39mM\n\u001b[1;32m 136\u001b[0m )\n\u001b[0;32m--> 137\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_svd_lowrank\u001b[49m\u001b[43m(\u001b[49m\u001b[43mA\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mq\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mniter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mniter\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mM\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mM\u001b[49m\u001b[43m)\u001b[49m\n", 420 | "File \u001b[0;32m~/.env/pytorch/lib/python3.10/site-packages/torch/_lowrank.py:168\u001b[0m, in \u001b[0;36m_svd_lowrank\u001b[0;34m(A, q, niter, M)\u001b[0m\n\u001b[1;32m 166\u001b[0m B_t \u001b[38;5;241m=\u001b[39m matmul(A, Q_c) \u001b[38;5;241m-\u001b[39m matmul(M, Q_c)\n\u001b[1;32m 167\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m B_t\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m] \u001b[38;5;241m==\u001b[39m m, (B_t\u001b[38;5;241m.\u001b[39mshape, m)\n\u001b[0;32m--> 168\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m B_t\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m==\u001b[39m q, (B_t\u001b[38;5;241m.\u001b[39mshape, q)\n\u001b[1;32m 169\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m B_t\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m B_t\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m], B_t\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 170\u001b[0m U, S, Vh \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mlinalg\u001b[38;5;241m.\u001b[39msvd(B_t, 
full_matrices\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", 421 | "\u001b[0;31mAssertionError\u001b[0m: (torch.Size([3, 3]), 6)" 422 | ] 423 | } 424 | ], 425 | "source": [ 426 | "for name, parameter in model.named_parameters():\n", 427 | " u, s, v = fp16_to_bf16(parameter) \n", 428 | " \n", 429 | " " 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": 178, 435 | "id": "8b1cc9d9", 436 | "metadata": {}, 437 | "outputs": [ 438 | { 439 | "name": "stdout", 440 | "output_type": "stream", 441 | "text": [ 442 | "Parameter containing:\n", 443 | "tensor([[ 1.0611, 0.1330, -1.7605, 0.1466],\n", 444 | " [-0.7716, -0.2960, 1.1901, 2.2352],\n", 445 | " [-0.0403, 0.3733, -0.5693, -2.6609]], device='cuda:0',\n", 446 | " requires_grad=True)\n", 447 | "Parameter containing:\n", 448 | "tensor([[-3.9159, -4.2902, -4.0074],\n", 449 | " [-2.6747, -2.9679, -2.6374],\n", 450 | " [-3.6547, -4.0776, -3.1965],\n", 451 | " [-1.1376, -3.5658, -1.4755],\n", 452 | " [-2.9043, -3.4260, -2.3562],\n", 453 | " [-3.9046, -4.2604, -3.8511],\n", 454 | " [-3.3239, -1.9507, -4.1121],\n", 455 | " [-3.2540, -4.3546, -4.4939],\n", 456 | " [-3.6044, -1.9820, -2.9902],\n", 457 | " [-2.0055, -3.5081, -3.7839],\n", 458 | " [-2.3974, -4.7042, -4.1317],\n", 459 | " [-3.5135, -2.4737, -2.8746],\n", 460 | " [-3.3641, -3.5888, -2.8332],\n", 461 | " [-0.9957, -2.8861, -2.3710],\n", 462 | " [-1.5866, -4.2982, -2.6797],\n", 463 | " [-3.6862, -4.3365, -4.1030],\n", 464 | " [-2.1020, -2.9508, -3.8896],\n", 465 | " [-2.3492, -2.6882, -3.4017],\n", 466 | " [-2.6588, -3.2364, -2.6149],\n", 467 | " [-1.0227, -4.4652, -3.2033],\n", 468 | " [-1.9665, -2.1739, -1.8420],\n", 469 | " [-1.2726, -3.3118, -3.9496],\n", 470 | " [-2.5283, -4.3307, -4.3942],\n", 471 | " [-1.3413, -2.3135, -4.0655],\n", 472 | " [-2.4632, -3.7983, -3.5229],\n", 473 | " [-3.4194, -2.8780, -4.1572],\n", 474 | " [-2.0788, -2.7962, -2.1693],\n", 475 | " [-2.0808, -4.7921, -3.9656],\n", 476 | " [-1.9980, -3.2654, 
-4.0889],\n", 477 | " [-3.5653, -1.8077, -2.7016],\n", 478 | " [-3.2813, -3.8062, -4.5239],\n", 479 | " [-2.6803, -4.4737, -4.2473],\n", 480 | " [-2.8596, -3.3028, -3.9003],\n", 481 | " [-1.5199, -4.1593, -1.5349],\n", 482 | " [-1.5643, -4.6262, -3.2834],\n", 483 | " [-2.6949, -3.3527, -3.1123],\n", 484 | " [-2.1609, -3.4884, -3.8051],\n", 485 | " [-1.9231, -4.7669, -3.8524],\n", 486 | " [-3.0583, -4.1479, -3.8725],\n", 487 | " [-3.8560, -3.7697, -3.3829],\n", 488 | " [-3.6453, -3.5935, -2.8057],\n", 489 | " [-2.9173, -3.4935, -3.5576],\n", 490 | " [-2.2856, -2.5579, -1.9568],\n", 491 | " [-1.7839, -1.8945, -3.0339],\n", 492 | " [-3.9329, -4.1244, -2.8829],\n", 493 | " [-2.6949, -4.1705, -4.2525],\n", 494 | " [-3.8950, -4.5112, -3.7919],\n", 495 | " [-1.9580, -2.7017, -4.2606],\n", 496 | " [-1.6814, -4.5647, -3.2842],\n", 497 | " [-3.3084, -3.7611, -3.7534],\n", 498 | " [ 1.6819, 2.4937, 3.0662],\n", 499 | " [-2.6795, -4.1715, -2.9375],\n", 500 | " [-3.9633, -3.6030, -2.6700],\n", 501 | " [-3.3674, -1.8966, -3.0372],\n", 502 | " [-4.0097, -2.4085, -3.6645],\n", 503 | " [-3.3919, -2.0825, -4.2951],\n", 504 | " [-2.4937, -3.8754, -2.4265],\n", 505 | " [-2.3258, -2.8296, -2.7955],\n", 506 | " [-3.0465, -3.7761, -4.3718],\n", 507 | " [-1.7695, -2.2439, -1.9860]], device='cuda:0', requires_grad=True)\n" 508 | ] 509 | } 510 | ], 511 | "source": [ 512 | "for name, parameter in model.named_parameters():\n", 513 | " #bf16_weight = fp16_to_bf16(parameter) \n", 514 | " print(parameter)" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": 74, 520 | "id": "ed27d074", 521 | "metadata": {}, 522 | "outputs": [ 523 | { 524 | "ename": "SyntaxError", 525 | "evalue": "invalid syntax (3891572883.py, line 1)", 526 | "output_type": "error", 527 | "traceback": [ 528 | "\u001b[0;36m Cell \u001b[0;32mIn [74], line 1\u001b[0;36m\u001b[0m\n\u001b[0;31m Parameter containing:\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid 
syntax\n" 529 | ] 530 | } 531 | ], 532 | "source": [ 533 | "#bf16\n", 534 | "Parameter containing:\n", 535 | "tensor([[-0.1229, -0.6971, -0.0126, 0.3644],\n", 536 | " [ 0.4843, 0.5767, 0.0836, 0.4326],\n", 537 | " [-0.1458, -0.1250, -0.5431, -0.6094]], device='cuda:0',\n", 538 | " requires_grad=True)\n", 539 | "\n", 540 | "\n", 541 | "#fb16\n", 542 | "tensor([[-0.6800, 0.8049, 0.4577, 0.3503],\n", 543 | " [ 0.2505, -0.3497, 0.0312, 0.3005],\n", 544 | " [-0.5663, 0.8540, 0.1833, 0.6032]], device='cuda:0',\n", 545 | " requires_grad=True)" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": null, 551 | "id": "1b96806c", 552 | "metadata": {}, 553 | "outputs": [], 554 | "source": [ 555 | "torch.svd_lowrank(torch)" 556 | ] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "execution_count": null, 561 | "id": "5a87a092", 562 | "metadata": {}, 563 | "outputs": [], 564 | "source": [] 565 | } 566 | ], 567 | "metadata": { 568 | "kernelspec": { 569 | "display_name": "Python 3 (ipykernel)", 570 | "language": "python", 571 | "name": "python3" 572 | }, 573 | "language_info": { 574 | "codemirror_mode": { 575 | "name": "ipython", 576 | "version": 3 577 | }, 578 | "file_extension": ".py", 579 | "mimetype": "text/x-python", 580 | "name": "python", 581 | "nbconvert_exporter": "python", 582 | "pygments_lexer": "ipython3", 583 | "version": "3.10.6" 584 | } 585 | }, 586 | "nbformat": 4, 587 | "nbformat_minor": 5 588 | } 589 | -------------------------------------------------------------------------------- /images/regularization-Mkclz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stas00/ml-ways/1a9887be460aba9b7a5f8f3ea8002066cba5bf22/images/regularization-Mkclz.png -------------------------------------------------------------------------------- /images/regularization-XmtF2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stas00/ml-ways/1a9887be460aba9b7a5f8f3ea8002066cba5bf22/images/regularization-XmtF2.png -------------------------------------------------------------------------------- /images/regularization-cmWO0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stas00/ml-ways/1a9887be460aba9b7a5f8f3ea8002066cba5bf22/images/regularization-cmWO0.png -------------------------------------------------------------------------------- /images/regularization-jlQYp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stas00/ml-ways/1a9887be460aba9b7a5f8f3ea8002066cba5bf22/images/regularization-jlQYp.png -------------------------------------------------------------------------------- /ml.md: -------------------------------------------------------------------------------- 1 | # Machine Learning 2 | 3 | 4 | ## Concepts 5 | 6 | ### Inductive Bias 7 | 8 | Fundamental assumptions that the learner makes about the target function that enables it to generalize beyond the training data. These assumptions are used to choose one generalization over another. 9 | 10 | Examples: 11 | 12 | * Support Vector Machines - Distinct classes tend to be separated by wide margins. 13 | * Naive Bayes - Each input depends only on the output class or label; the inputs are independent from each other. 14 | * Linear Regression - The relationship between the attributes x and the output y is linear. 15 | 16 | More common types: 17 | 18 | * [The Inductive Biases of Various Machine Learning Algorithms](http://www.lauradhamilton.com/inductive-biases-various-machine-learning-algorithms) 19 | * [wiki](https://en.wikipedia.org/wiki/Inductive_bias#Types) 20 | 21 | 22 | ## Model Explainability 23 | 24 | 25 | * [Interpretable Machine Learning](https://christophm.github.io/interpretable-ml-book/) - A Guide for Making Black Box Models Explainable. 
26 | * [Machine Learning Explainability](https://www.kaggle.com/learn/machine-learning-explainability) - Kaggle Tutorial 27 | 28 | Tools: 29 | 30 | * [SHAP](https://github.com/slundberg/shap) - A game theoretic approach to explain the output of any machine learning model. 31 | * [LIME](https://github.com/marcotcr/lime) - Local Interpretable Model-Agnostic Explanations 32 | 33 | 34 | 35 | 36 | ## Distillation 37 | 38 | Distillation == once a neural network has been trained, its full output distributions can be approximated using a smaller network. 39 | 40 | * [Awesome Knowledge Distillation](https://github.com/dkozlov/awesome-knowledge-distillation) - a great compilation of resources. 41 | 42 | 43 | ## Algorithms 44 | 45 | System order: 46 | * First-order algorithms require a first-derivative/gradient (Jacobian). 47 | * Second-order algorithms require a second-derivative/gradient (Hessian). 48 | 49 | 50 | ## Competitions 51 | 52 | * https://www.kaggle.com 53 | * https://tianchi.aliyun.com/competition/gameList/activeList 54 | * https://evalai.cloudcv.org/ 55 | * https://www.drivendata.org/competitions/ 56 | * https://www.aicrowd.com/ 57 | * https://datahack.analyticsvidhya.com/ 58 | * https://competitions.codalab.org/competitions/ 59 | * http://tunedit.org/challenges 60 | * https://www.innocentive.com/ar/challenge/browse 61 | * https://www.crowdanalytix.com/community 62 | * https://www.hackerearth.com/challenges 63 | * https://www.topcoder.com/challenges?filter[tracks][data_science]=true&bucket=ongoing 64 | * https://www.machinehack.com/course-cat/modeling/ 65 | * https://quant-quest.com/competitions/ 66 | 67 | Aggregators: 68 | 69 | * https://www.reddit.com/r/ai_competitions/ 70 | * https://mlcontests.com/ 71 | -------------------------------------------------------------------------------- /nlp.md: -------------------------------------------------------------------------------- 1 | # Natural Language Processing 2 | 3 | Dense notes on NLP. 
4 | 5 | ## Abbreviations 6 | 7 | More extensive lists of abbreviations [1](https://github.com/AgaMiko/machine-learning-acronyms), [2](https://machinelearning.wtf/acronyms/) 8 | 9 | ``` 10 | ASGD Averaged Stochastic Gradient Descent 11 | AWD-LSTM ASGD Weight-Dropped LSTM 12 | BERT Bidirectional Encoder Representations from Transformers 13 | BPE Byte Pair Encoding 14 | BiLM Bidirectional Language Model 15 | CBOW Continuous Bag-Of-Words 16 | CFG Context-free Grammar 17 | CL Computational Linguistics 18 | CVT Cross-View Training 19 | CoLA Corpus of Linguistic Acceptability 20 | CoVe Contextual Word Vectors 21 | CRF Conditional Random Field 22 | DAG Directed Acyclic Graph 23 | DAE Denoising Auto-Encoder 24 | DCN Dynamic Coattention Network 25 | DCNN Dynamic Convolutional Neural Network 26 | DMN Dynamic Memory Network 27 | EDA Exploratory Data Analysis 28 | ELMo Embeddings from Language Model 29 | ESA Explicit Semantic Analysis 30 | FGN Fine-Grained NER 31 | FOL First-Order Logic 32 | GAN Generative Adversarial Network 33 | GEC Grammatical Error Correction 34 | GPT Generative Pre-training Transformer 35 | GRU Gated-Recurrent Network 36 | GloVe Global Vectors for Word Representation 37 | HAL Hyperspace Analogue to Language 38 | HDP Hierarchical Dirichlet Process 39 | IE Information Extraction 40 | IR Information Retrieval 41 | LDA Latent Dirichlet Allocation 42 | LSA Latent Semantic Analysis (Truncated SVD) 43 | LSI Latent Semantic Indexing 44 | LSTM Long Short-Term Memory 45 | MAE Mean Absolute Error 46 | MLM Mask Language Model 47 | MNLI Multi-Genre NLI 48 | MRPC MicRosoft Paraphrase Corpus 49 | MSE Mean Squared Error 50 | MaxEnt Maximum Entropy (classifier) (softmax) 51 | NER Named-Entity Recognition 52 | NLG Natural Language Generation 53 | NLI Natural Language Inference (Text Entailment) 54 | NLP Natural Language Processing 55 | NLU Natural Language Understanding 56 | NMT Neural Machine Translation 57 | NTN Neural Tensor Network 58 | NiN Network-in-network (1x1 
convconnections) 59 | PCFG Probabilistic Context Free Grammar 60 | POS Parts-Of-Speech 61 | QRNN Quasi-Recurrent Neural Networks 62 | QNLI Question NLI 63 | RACE ReAding Comprehension from Examinations 64 | RMSE Root Mean Squared Error 65 | RNN Recurrent Neural Network 66 | RNN Recursive Neural Network 67 | RNTN Recursive Neural Tensor Network 68 | RP Random Projections 69 | RTE Recognizing Textual Entailment (now called NLI) 70 | SG Skip-Gram 71 | SNLI Stanford Natural Language Inference 72 | SOTA State-Of-The-Art 73 | SQuAD Stanford Question Answering Dataset 74 | SRL Semantic Role Labeling 75 | SST Stanford Sentiment Treebank 76 | STLR Slanted Triangular Learning Rates 77 | SWAG Situations With Adversarial Generations 78 | TDNN Time-Delayed Neural Network 79 | TF Term­Frequency 80 | TF­IDF Term­Frequency­Inverse­Document­Frequency 81 | TLM Translation Language Modeling 82 | ULMFiT Universal Language Model Fine-tuning 83 | USE Universal Sentence Encoder 84 | VAE Variational Autoenconder 85 | VSM Vector Space Model 86 | WSD Word Sense Disambiguation 87 | ZSL Zero-Shot Learning 88 | t-SNE t-distributed Stochastic Neighbor Embedding 89 | 90 | 91 | 92 | ``` 93 | 94 | ## Glossary and Terminology 95 | 96 | **Denotational semantics**: The concept of representing an idea as a symbol (a word or a one-hot vector). It is sparse and cannot capture similarity. This is a "localist" representation. 97 | 98 | **Distributional semantics**: The concept of representing the meaning of a word based on the context in which it usually appears. It is dense and can better capture similarity. 99 | 100 | **Distributional similarity**: similar words have similar context. 101 | 102 | **Transformer** is an architecture for transforming one sequence into another one with the help of two parts (Encoder and Decoder). 103 | 104 | **Constituency Parsing** is a way to break a piece of text (e.g. one sentence) into sub-phrases. 
One of the goals of constituency parsing (also known as "phrase structure parsing") is to identify the constituents in the text which would be useful when extracting information from text. By knowing the constituents after parsing the sentence, it is possible to generate similar sentences that are syntactically correct. 105 | 106 | **Lemmas** are root forms of words. 107 | 108 | **Named Entity Recognition**: which words in a sentence are a proper name, organization name, or entity? 109 | 110 | **Textual Entailment**: given two sentences, does the first sentence entail or contradict the second sentence? 111 | 112 | **Coreference Resolution**: given a pronoun like “it” in a sentence that discusses multiple objects, which object does “it” refer to? 113 | 114 | 115 | 116 | ### Word Embedding Models 117 | 118 | Popular off-the-shelf word embedding models: 119 | 120 | - Word2Vec (by Google) 121 | - GloVe (by Stanford) 122 | - fastText (by Facebook) 123 | 124 | 125 | #### Word2vec 126 | 127 | - 2 algorithms: continuous bag-of-words (CBOW) and skip-gram. CBOW aims to predict a center word from the surrounding context in terms of word vectors. Skip-gram does the opposite, and predicts the distribution (probability) of context words from a center word. 128 | 129 | - 2 training methods: negative sampling and hierarchical softmax. Negative sampling defines an objective by sampling negative examples, while hierarchical softmax defines an objective using an efficient tree structure to compute probabilities for all the vocabulary. 130 | 131 | 132 | 133 | ## Augmentation 134 | 135 | 136 | https://amitness.com/2020/05/data-augmentation-for-nlp/ 137 | 138 | 139 | ## Metrics 140 | 141 | 142 | 143 | ### Perplexity 144 | 145 | Perplexity is often used as an intrinsic evaluation metric for gauging how well a language model can capture the real word distribution conditioned on the context. 
146 | 147 | A [perplexity](https://en.wikipedia.org/wiki/Perplexity) of a discrete probability distribution $p$ is defined as the exponentiation of the entropy: 148 | 149 | $2^{H(p)} = 2^{-\sum_x p(x) \log_2 p(x)}$ 150 | 151 | Given a sentence with $N$ words, $s = (w_1, \dots, w_N)$, the entropy looks as follows, simply assuming that each word has the same frequency, $\frac{1}{N}$: 152 | 153 | $H(s) = -\sum_{i=1}^N P(w_i) \log_2 p(w_i) = -\sum_{i=1}^N \frac{1}{N} \log_2 p(w_i)$ 154 | 155 | The perplexity for the sentence becomes: 156 | 157 | $ 158 | 2^{H(s)} = 2^{-\frac{1}{N} \sum_{i=1}^N \log_2 p(w_i)} 159 | = (2^{\sum_{i=1}^N \log_2 p(w_i)})^{-\frac{1}{N}} 160 | = (p(w_1) \dots p(w_N))^{-\frac{1}{N}} 161 | $ 162 | 163 | A good language model should predict high word probabilities. Therefore, the smaller the perplexity, the better. 164 | 165 | 166 | ### Compilations 167 | 168 | * cs224u: [Evaluation metrics in NLP](https://github.com/cgpotts/cs224u/blob/master/evaluation_metrics.ipynb) 169 | * scikit: [Metrics and scoring: quantifying the quality of predictions](https://scikit-learn.org/stable/modules/model_evaluation.html) 170 | 171 | 172 | ## Functions 173 | 174 | ### Softmax 175 | 176 | After applying softmax, each component will be in the interval (0, 1) and the total will add up to 1, so that they can be interpreted as probabilities. 177 | 178 | The larger input components will correspond to larger probabilities. 179 | 180 | **Temperature** is used to scale the logits before applying softmax. (logits/τ) 181 | 182 | 1. For high temperatures (τ → ∞), all components have nearly the same probability and the lower the temperature, the more expected values affect the probability. This results in more diversity and also more mistakes. 183 | 184 | 2. When the temperature is 1, the softmax is computed on unscaled logits. 185 | 186 | 3. For a low temperature (τ → 0), the probability of the action with the highest expected value tends to 1. 
Larger logit values makes softmax more confident, but also more conservative in its samples (it is less likely to sample from unlikely candidates). 187 | 188 | https://cs.stackexchange.com/a/79242/113823 189 | 190 | 191 | 192 | 193 | 194 | 195 | ## Linguistics 196 | 197 | - **[Hyponymy](https://en.wikipedia.org/wiki/Hyponymy_and_hypernymy)** && **Hypernymy**: a hyponym is a word or phrase whose semantic field is included within that of another word. San Francisco (hyponym) is an **instance of** a city (hypernym). A pigeon is a hyponym of bird; which, in turn, is a hyponym of animal. A bird is a hypernym of a pigeon. An animal is a hypernym of a bird. 198 | 199 | - **Antonymy**: acidic is the **opposite** of basic 200 | 201 | - **Meronymy**: an alternator is a **part of** a car 202 | 203 | - **[Polysemy](https://en.wikipedia.org/wiki/Polysemy)** is the capacity for a word or phrase to have multiple meanings, usually related by contiguity of meaning within a semantic field. e.g. crane: (n) machine, (n) bird, (v) to strain out one's neck. 204 | 205 | [Semantic change](https://en.wikipedia.org/wiki/Semantic_change) (also semantic shift, semantic progression, semantic development, or semantic drift) is a form of language change regarding the evolution of word usage—usually to the point that the modern meaning is radically different from the original usage. 206 | 207 | 208 | 209 | ### Monotonicity reasoning 210 | 211 | **Monotonicity**. A system is monotonic if it grows without shrinking. 212 | 213 | **Monotonicity reasoning** is a type of reasoning based on word replacement, requires the ability to capture the interaction between lexical and syntactic structures. Consider examples in (1) and (2). 214 | 215 | (1) a. All [ workers ↓] joined for a [French dinner ↑] 216 | b. All [ workers ] joined for a [ dinner ] 217 | c. All [new workers ] joined for a [French dinner ] 218 | 219 | (2) a. Not all [new workers ↑] joined for a dinner 220 | b. 
Not all [ workers ] joined for a dinner 221 | 222 | A context is **upward entailing** (shown by [... ↑]) that allows an inference from (1a) to (1b), where *French dinner* is replaced by a more general concept *dinner*. On the other hand, a **downward entailing** context (shown by [... ↓]) allows an inference from (1a) to (1c), where *workers* is replaced by a more specific concept *new workers*. Interestingly, the direction of monotonicity can be reversed again by embedding yet another downward entailing context (e.g., *not* in (2)), as witness the fact that (2a) entails (2b). To properly handle both directions of monotonicity, NLI models must detect monotonicity operators (e.g., all, not) and their arguments from the syntactic structure. 223 | (this excerpt is from [Can neural networks understand monotonicity reasoning?](https://arxiv.org/abs/1906.06448)) 224 | 225 | 226 | 227 | 228 | ## Libraries 229 | 230 | Useful libraries and modules: 231 | 232 | - [Annoy](https://github.com/spotify/annoy) (Approximate Nearest Neighbors Oh Yeah) is a C++ library with Python bindings to search for points in space that are close to a given query point. 233 | 234 | ## Transformers 235 | 236 | - Huggingface [transformers](https://github.com/huggingface/transformers). The main transformers library. 237 | 238 | - [Simple Transformers](https://github.com/ThilinaRajapakse/simpletransformers). Transformers made simple with training, evaluation, and prediction possible with one line each 239 | 240 | - [AdaptNLP](https://github.com/Novetta/adaptnlp). A high level framework and library for running, training, and deploying state-of-the-art Natural Language Processing (NLP) models for end to end tasks. Built on top of Zalando Research's Flair and Hugging Face's Transformers. 241 | 242 | - [spacy-transformers](https://github.com/explosion/spacy-transformers) provides spaCy model pipelines that wrap Hugging Face's transformers package, so you can use them in spaCy. 
243 | 244 | ## Good Paper Explanations 245 | 246 | - AWD-LSTM: Average-SGD Weight-Dropped LSTM 247 | * [What makes the AWD-LSTM great](https://yashuseth.blog/2018/09/12/awd-lstm-explanation-understanding-language-model/) 248 | 249 | - ULMFiT: Universal Language Model Fine Tuning 250 | * [Understanding the Working of ULMFiT](https://yashuseth.blog/2018/06/17/understanding-universal-language-model-fine-tuning-ulmfit/) 251 | 252 | 253 | ## fastai NLP notebooks 254 | 255 | - seq2seq: 256 | * https://github.com/ohmeow/seq2seq-pytorch-fastai/blob/master/seq2seq-rnn-attn.ipynb 257 | 258 | 259 | ## Sources 260 | 261 | - [Stanford cs224n](http://web.stanford.edu/class/cs224n/) 262 | - [Generalized Language Models](https://lilianweng.github.io/lil-log/2019/01/31/generalized-language-models.html#metric-perplexity) 263 | 264 | ## Books 265 | 266 | - Chris Manning and Hinrich Schuetze - Foundations of Statistical Natural Language Processing 267 | 268 | 269 | ## Newsletters 270 | 271 | - [NLP Newsletter](https://github.com/dair-ai/nlp_newsletter) 272 | - [NLP News](http://newsletter.ruder.io/) 273 | -------------------------------------------------------------------------------- /numbers/.ipynb_checkpoints/bfloat16-vs-float16-study-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "black-beijing", 6 | "metadata": {}, 7 | "source": [ 8 | "# float16 vs bfloat16 numerical properties comparison\n", 9 | "\n", 10 | "This a short notebook to help understand `fp16` vs `bfloat16` in particular when converting a model trained\n", 11 | "in `bfloat16` to mixed precision - it should be possible to look at the numbers to know which ranges\n", 12 | "are safe and which need to be scaled/avoided.\n", 13 | "\n", 14 | "I needed to do that in the context of trying to understand why bfloat16 t5/mt5 models that were pretrained in bfloat16 had a lot of `nan`/`inf` problems when finetuned 
in mixed precision." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "id": "eastern-variation", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "import torch" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "adult-daughter", 30 | "metadata": {}, 31 | "source": [ 32 | "This is the main function, that tries to do very simply increments in `bfloat16` and then converting the result to `float16` and showing the discrepancies." 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "id": "resistant-chile", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "def find_mismatch(start, incr):\n", 43 | " bf16 = torch.tensor(start, dtype=torch.bfloat16)\n", 44 | " print(f\"\\nfp32 start={start:.2e} using increment={incr}\")\n", 45 | " print(f\"{'bfloat16':>18} {'float16':>18} {'diff':>8}\")\n", 46 | " c = 0\n", 47 | " tries = 0\n", 48 | " while c < 8:\n", 49 | " fp16 = bf16.to(torch.float16)\n", 50 | " if not (fp16 == bf16):\n", 51 | " print(f\"{bf16:.16f} {fp16:.16f} {torch.sub(fp16.to(dtype=torch.float32), bf16):+.2e}\")\n", 52 | " c += 1\n", 53 | " bf16 += incr\n", 54 | " tries += 1\n", 55 | " if tries >= 1e5:\n", 56 | " print(f\"gave up finding mismatch after {tries} steps\")\n", 57 | " return" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "applied-damages", 63 | "metadata": {}, 64 | "source": [ 65 | "## Underflow for fp16\n", 66 | "\n", 67 | "when numbers become 0.0" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "id": "cooperative-government", 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "\n", 81 | "fp32 start=1.00e-08 using increment=1e-09\n", 82 | " bfloat16 float16 diff\n", 83 | "0.0000000100117177 0.0000000000000000 -1.00e-08\n", 84 | "0.0000000110012479 0.0000000000000000 -1.10e-08\n", 85 | "0.0000000119907781 0.0000000000000000 -1.20e-08\n", 86 | 
"0.0000000129803084 0.0000000000000000 -1.30e-08\n", 87 | "0.0000000139698386 0.0000000000000000 -1.40e-08\n", 88 | "0.0000000150175765 0.0000000000000000 -1.50e-08\n", 89 | "0.0000000160653144 0.0000000000000000 -1.61e-08\n", 90 | "0.0000000171130523 0.0000000000000000 -1.71e-08\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "find_mismatch(1e-08, 1e-09)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "id": "decimal-fraction", 101 | "metadata": {}, 102 | "source": [ 103 | "## Subnormal range for fp16\n", 104 | "\n", 105 | "starting from 5.96e-8 \n", 106 | "\n", 107 | "usually expensive and very low precision" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 4, 113 | "id": "statutory-procurement", 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "\n", 121 | "fp32 start=1.00e-07 using increment=1e-08\n", 122 | " bfloat16 float16 diff\n", 123 | "0.0000001001171768 0.0000001192092896 +1.91e-08\n", 124 | "0.0000001098960638 0.0000001192092896 +9.31e-09\n", 125 | "0.0000001201406121 0.0000001192092896 -9.31e-10\n", 126 | "0.0000001303851604 0.0000001192092896 -1.12e-08\n", 127 | "0.0000001406297088 0.0000001192092896 -2.14e-08\n", 128 | "0.0000001508742571 0.0000001788139343 +2.79e-08\n", 129 | "0.0000001611188054 0.0000001788139343 +1.77e-08\n", 130 | "0.0000001713633537 0.0000001788139343 +7.45e-09\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "# very limited range for fp16\n", 136 | "find_mismatch(1e-07, 1e-08)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 5, 142 | "id": "distributed-puppy", 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stdout", 147 | "output_type": "stream", 148 | "text": [ 149 | "\n", 150 | "fp32 start=1.00e-06 using increment=1e-07\n", 151 | " bfloat16 float16 diff\n", 152 | "0.0000009983778000 0.0000010132789612 +1.49e-08\n", 153 | "0.0000010952353477 0.0000010728836060 
-2.24e-08\n", 154 | "0.0000012889504433 0.0000013113021851 +2.24e-08\n", 155 | "0.0000013858079910 0.0000013709068298 -1.49e-08\n", 156 | "0.0000014826655388 0.0000014901161194 +7.45e-09\n", 157 | "0.0000015795230865 0.0000015497207642 -2.98e-08\n", 158 | "0.0000016763806343 0.0000016689300537 -7.45e-09\n", 159 | "0.0000017732381821 0.0000017881393433 +1.49e-08\n" 160 | ] 161 | } 162 | ], 163 | "source": [ 164 | "# things starting to improve slightly for fp16\n", 165 | "find_mismatch(1e-06, 1e-07)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "id": "expired-drinking", 171 | "metadata": {}, 172 | "source": [ 173 | "## Normal numbers\n", 174 | "\n", 175 | "Min positive normal fp16: 6.104e-05 (`np.finfo(np.float16).tiny`)\n", 176 | "\n", 177 | "These ranges match much better and thus will not easily find a mismatch if at all" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 6, 183 | "id": "seven-caution", 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "name": "stdout", 188 | "output_type": "stream", 189 | "text": [ 190 | "\n", 191 | "fp32 start=1.00e-05 using increment=1e-06\n", 192 | " bfloat16 float16 diff\n", 193 | "gave up finding mismatch after 100000 steps\n", 194 | "\n", 195 | "fp32 start=1.00e-04 using increment=1e-06\n", 196 | " bfloat16 float16 diff\n", 197 | "gave up finding mismatch after 100000 steps\n", 198 | "\n", 199 | "fp32 start=1.00e-03 using increment=0.0001\n", 200 | " bfloat16 float16 diff\n", 201 | "gave up finding mismatch after 100000 steps\n", 202 | "\n", 203 | "fp32 start=1.00e-02 using increment=0.001\n", 204 | " bfloat16 float16 diff\n", 205 | "gave up finding mismatch after 100000 steps\n", 206 | "\n", 207 | "fp32 start=1.00e-01 using increment=0.01\n", 208 | " bfloat16 float16 diff\n", 209 | "gave up finding mismatch after 100000 steps\n", 210 | "\n", 211 | "fp32 start=1.00e+01 using increment=1e-06\n", 212 | " bfloat16 float16 diff\n", 213 | "gave up finding mismatch after 100000 
steps\n", 214 | "\n", 215 | "fp32 start=1.00e+01 using increment=10.0\n", 216 | " bfloat16 float16 diff\n", 217 | "gave up finding mismatch after 100000 steps\n", 218 | "\n", 219 | "fp32 start=1.00e+04 using increment=1\n", 220 | " bfloat16 float16 diff\n", 221 | "gave up finding mismatch after 100000 steps\n" 222 | ] 223 | } 224 | ], 225 | "source": [ 226 | "find_mismatch(1e-05, 1e-06)\n", 227 | "find_mismatch(1e-04, 1e-06)\n", 228 | "find_mismatch(1e-03, 1e-04)\n", 229 | "find_mismatch(1e-02, 1e-03)\n", 230 | "find_mismatch(1e-01, 1e-02)\n", 231 | "find_mismatch(1e1, 1e-06)\n", 232 | "find_mismatch(1e1, 1e1)\n", 233 | "find_mismatch(1e4, 1)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 7, 239 | "id": "mighty-injection", 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "name": "stdout", 244 | "output_type": "stream", 245 | "text": [ 246 | "\n", 247 | "fp32 start=5.00e+04 using increment=1000.0\n", 248 | " bfloat16 float16 diff\n", 249 | "66048.0000000000000000 inf +inf\n", 250 | "67072.0000000000000000 inf +inf\n", 251 | "68096.0000000000000000 inf +inf\n", 252 | "69120.0000000000000000 inf +inf\n", 253 | "70144.0000000000000000 inf +inf\n", 254 | "71168.0000000000000000 inf +inf\n", 255 | "72192.0000000000000000 inf +inf\n", 256 | "73216.0000000000000000 inf +inf\n" 257 | ] 258 | } 259 | ], 260 | "source": [ 261 | "# hitting max range for fp16\n", 262 | "find_mismatch(5e4, 1e3)" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 8, 268 | "id": "alleged-lobby", 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "# --- roundoff ---\n", 273 | "# fp16 4.88e-4\n", 274 | "# bf16 3.91e-3" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "id": "cheap-natural", 280 | "metadata": {}, 281 | "source": [ 282 | "## Big numbers\n", 283 | "\n", 284 | "`bfloat16` numbers have a terrible range for numbers `> 1` but `fp16` matches those exactly\n", 285 | "e.g. 
one can't represent 283 in bf16\n", 286 | "\n", 287 | "```\n", 288 | "python -c \"import torch; print( torch.tensor(283, dtype=torch.bfloat16) )\"\n", 289 | "tensor(284., dtype=torch.bfloat16)\n", 290 | "```" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 9, 296 | "id": "integrated-individual", 297 | "metadata": {}, 298 | "outputs": [ 299 | { 300 | "name": "stdout", 301 | "output_type": "stream", 302 | "text": [ 303 | "282.00\n", 304 | "284.00\n", 305 | "286.00\n" 306 | ] 307 | } 308 | ], 309 | "source": [ 310 | "start = 280\n", 311 | "fp32 = torch.tensor(start, dtype=torch.float32)\n", 312 | "for i in range(3):\n", 313 | " bf16 = fp32.to(torch.bfloat16)\n", 314 | " bf16d = bf16\n", 315 | " while bf16 == bf16d:\n", 316 | " fp32 += 1\n", 317 | " bf16d = fp32.to(torch.bfloat16)\n", 318 | " print(f\"{bf16d:.2f}\")\n", 319 | "# 282\n", 320 | "# 284\n", 321 | "# 286" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "id": "ae153e44", 327 | "metadata": {}, 328 | "source": [ 329 | "## How many positions between 2 numbers\n", 330 | "\n", 331 | "Let's see how many `fp16` numbers can fit between `bf16` numbers - which should help to understand how converting a model trained in `fp16` to `bf16` in a way quantizes the model - since there are less `bf16` numbers in the same range." 
332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 5, 337 | "id": "278e6249", 338 | "metadata": {}, 339 | "outputs": [ 340 | { 341 | "name": "stdout", 342 | "output_type": "stream", 343 | "text": [ 344 | "9 fp16s: [0.10009765625, 0.10015869140625, 0.1002197265625, 0.10028076171875, 0.100341796875, 0.10040283203125, 0.1004638671875, 0.10052490234375, 0.1005859375]\n", 345 | "2 fp16s: [0.10009765625, 0.1005859375]\n" 346 | ] 347 | } 348 | ], 349 | "source": [ 350 | "fp16 = torch.tensor(0.1001, dtype=torch.float16)\n", 351 | "bf16 = torch.tensor(0.1001, dtype=torch.bfloat16)\n", 352 | "fp16s = [fp16]\n", 353 | "bf16s = [bf16]\n", 354 | "\n", 355 | "delta = 0.00001\n", 356 | "for i in range(100):\n", 357 | " fp16_new = fp16 + delta*i\n", 358 | " bf16_new = bf16 + delta*i\n", 359 | " if fp16s[-1] != fp16_new:\n", 360 | " fp16s.append(fp16_new)\n", 361 | " if bf16s[-1] != bf16_new:\n", 362 | " bf16s.append(bf16_new)\n", 363 | " if len(bf16s) > 1 and bf16s[-1] == fp16s[-1]:\n", 364 | " break\n", 365 | " \n", 366 | "print(f\"{len(fp16s)} fp16s: {[x.item() for x in fp16s]}\")\n", 367 | "print(f\"{len(bf16s)} fp16s: {[x.item() for x in bf16s]}\")\n" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "id": "fa0ae862", 373 | "metadata": {}, 374 | "source": [ 375 | "So it can be seen that in this particular range of numbers every 8 \"positions\" in `fp16` get remapped to a single \"position\" in `bf16`. As `exponent(fp16) = 10` and `exponent(bf16) = 7` - so we have `2**3=8` different positions between 2 representations." 
376 | ] 377 | }, 378 | { 379 | "cell_type": "markdown", 380 | "id": "revolutionary-force", 381 | "metadata": {}, 382 | "source": [ 383 | "# Math" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "id": "systematic-latex", 389 | "metadata": {}, 390 | "source": [ 391 | "## Summation\n", 392 | "\n", 393 | "A very narrow dynamic range means that for largish numbers NN trained in `bfloat16` **expects** bad\n", 394 | "precision and when the precision is suddenly higher unexpected outcomes happen:" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": 10, 400 | "id": "unlike-raising", 401 | "metadata": {}, 402 | "outputs": [ 403 | { 404 | "name": "stdout", 405 | "output_type": "stream", 406 | "text": [ 407 | "tensor(284., dtype=torch.bfloat16)\n", 408 | "tensor(283., dtype=torch.float16)\n" 409 | ] 410 | } 411 | ], 412 | "source": [ 413 | "# small sum\n", 414 | "print(torch.tensor(282, dtype=torch.bfloat16)+1) # 284\n", 415 | "print(torch.tensor(282, dtype=torch.float16)+1) # 283" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": 11, 421 | "id": "competitive-average", 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "name": "stdout", 426 | "output_type": "stream", 427 | "text": [ 428 | "tensor(2848., dtype=torch.bfloat16)\n", 429 | "tensor(2830., dtype=torch.float16)\n" 430 | ] 431 | } 432 | ], 433 | "source": [ 434 | "# sum several of these\n", 435 | "print(torch.tensor(283, dtype=torch.bfloat16)*10) # 2848\n", 436 | "print(torch.tensor(283, dtype=torch.float16)*10) # 2830" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "id": "choice-enemy", 442 | "metadata": {}, 443 | "source": [ 444 | "As you can see numbers start to diverge quickly!\n", 445 | "\n", 446 | "Now in practice we typically add up thousands of numbers.\n", 447 | "\n", 448 | "The solution is to always do this kind of operations in double precision of the operands and then if needed casting back to the original. i.e. 
the accumulate of `sum(fp16_tensor)` should be at least a `float32` tensor." 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 12, 454 | "id": "liquid-purple", 455 | "metadata": {}, 456 | "outputs": [ 457 | { 458 | "data": { 459 | "text/plain": [ 460 | "tensor(inf, dtype=torch.float16)" 461 | ] 462 | }, 463 | "execution_count": 12, 464 | "metadata": {}, 465 | "output_type": "execute_result" 466 | }, 467 | { 468 | "data": { 469 | "text/plain": [ 470 | "tensor(250394.1875)" 471 | ] 472 | }, 473 | "execution_count": 12, 474 | "metadata": {}, 475 | "output_type": "execute_result" 476 | } 477 | ], 478 | "source": [ 479 | "x = torch.rand((10000)).half()*50\n", 480 | "\n", 481 | "# this overflows\n", 482 | "x.sum()\n", 483 | "# this succeeds\n", 484 | "x.sum(dtype=torch.float32)" 485 | ] 486 | }, 487 | { 488 | "cell_type": "markdown", 489 | "id": "dated-scope", 490 | "metadata": {}, 491 | "source": [ 492 | "## Getting overflows\n", 493 | "\n", 494 | "Full numbers range: ``float16: ±65,504``\n", 495 | "\n", 496 | "So fp16 overflows easily in say variance calculation when you try to just square a number bigger than `256` - as it'd overflow, i.e. you get `inf`! 
so `256**2+1` will be `inf`\n", 497 | "\n", 498 | "You can't even do `pow(2)` for fp16 in pytorch, the following will give an error: that it doesn't support power for fp16.\n", 499 | "\n", 500 | "`torch.tensor(256, dtype=torch.float16).pow(2)`\n", 501 | "\n", 502 | "You have to cast to `float32` first:" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 13, 508 | "id": "quick-local", 509 | "metadata": {}, 510 | "outputs": [ 511 | { 512 | "data": { 513 | "text/plain": [ 514 | "tensor(65024., dtype=torch.float16)" 515 | ] 516 | }, 517 | "execution_count": 13, 518 | "metadata": {}, 519 | "output_type": "execute_result" 520 | } 521 | ], 522 | "source": [ 523 | "x = torch.tensor(255, dtype=torch.float16)\n", 524 | "x_squared = x.float().pow(2)\n", 525 | "x_squared.to(dtype=torch.float16)" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 14, 531 | "id": "nonprofit-charleston", 532 | "metadata": {}, 533 | "outputs": [ 534 | { 535 | "data": { 536 | "text/plain": [ 537 | "tensor(inf, dtype=torch.float16)" 538 | ] 539 | }, 540 | "execution_count": 14, 541 | "metadata": {}, 542 | "output_type": "execute_result" 543 | } 544 | ], 545 | "source": [ 546 | "# let's cross into the overflow\n", 547 | "x += 1\n", 548 | "x_squared = x.float().pow(2)\n", 549 | "x_squared.to(dtype=torch.float16)" 550 | ] 551 | }, 552 | { 553 | "cell_type": "markdown", 554 | "id": "dominican-retro", 555 | "metadata": {}, 556 | "source": [ 557 | "And that's how `inf` comes about.\n", 558 | "\n", 559 | "Or if you need to create one, you can just do:" 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": 15, 565 | "id": "returning-certificate", 566 | "metadata": {}, 567 | "outputs": [], 568 | "source": [ 569 | "t_inf = torch.tensor(float('inf'))" 570 | ] 571 | }, 572 | { 573 | "cell_type": "markdown", 574 | "id": "eastern-orleans", 575 | "metadata": {}, 576 | "source": [ 577 | "If you need to compare if a tensor has `inf` elements:" 
578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "execution_count": 16, 583 | "id": "ethical-wayne", 584 | "metadata": {}, 585 | "outputs": [ 586 | { 587 | "data": { 588 | "text/plain": [ 589 | "tensor(True)" 590 | ] 591 | }, 592 | "execution_count": 16, 593 | "metadata": {}, 594 | "output_type": "execute_result" 595 | } 596 | ], 597 | "source": [ 598 | "torch.isinf(t_inf).any()" 599 | ] 600 | }, 601 | { 602 | "cell_type": "markdown", 603 | "id": "derived-canberra", 604 | "metadata": {}, 605 | "source": [ 606 | "## Getting NaNs \n", 607 | "\n", 608 | "While there are many ways to get `NaN` during calculations, the most common for machine learning are:" 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": 17, 614 | "id": "formed-certificate", 615 | "metadata": {}, 616 | "outputs": [ 617 | { 618 | "data": { 619 | "text/plain": [ 620 | "tensor(nan)" 621 | ] 622 | }, 623 | "execution_count": 17, 624 | "metadata": {}, 625 | "output_type": "execute_result" 626 | } 627 | ], 628 | "source": [ 629 | "# 0/0\n", 630 | "t_zero = torch.tensor(0)\n", 631 | "t_zero/t_zero" 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 18, 637 | "id": "authorized-window", 638 | "metadata": {}, 639 | "outputs": [ 640 | { 641 | "data": { 642 | "text/plain": [ 643 | "tensor(nan)" 644 | ] 645 | }, 646 | "execution_count": 18, 647 | "metadata": {}, 648 | "output_type": "execute_result" 649 | } 650 | ], 651 | "source": [ 652 | "# inf/inf\n", 653 | "t_inf = torch.tensor(float('inf'))\n", 654 | "t_inf/t_inf" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": 19, 660 | "id": "elegant-grant", 661 | "metadata": {}, 662 | "outputs": [ 663 | { 664 | "data": { 665 | "text/plain": [ 666 | "tensor(nan)" 667 | ] 668 | }, 669 | "execution_count": 19, 670 | "metadata": {}, 671 | "output_type": "execute_result" 672 | } 673 | ], 674 | "source": [ 675 | "# 0*inf\n", 676 | "t_zero * t_inf" 677 | ] 678 | }, 679 | { 680 | 
"cell_type": "code", 681 | "execution_count": 20, 682 | "id": "several-council", 683 | "metadata": {}, 684 | "outputs": [ 685 | { 686 | "data": { 687 | "text/plain": [ 688 | "tensor(nan)" 689 | ] 690 | }, 691 | "execution_count": 20, 692 | "metadata": {}, 693 | "output_type": "execute_result" 694 | } 695 | ], 696 | "source": [ 697 | "# inf - inf\n", 698 | "t_inf - t_inf" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": 21, 704 | "id": "local-change", 705 | "metadata": {}, 706 | "outputs": [ 707 | { 708 | "data": { 709 | "text/plain": [ 710 | "tensor(nan)" 711 | ] 712 | }, 713 | "execution_count": 21, 714 | "metadata": {}, 715 | "output_type": "execute_result" 716 | } 717 | ], 718 | "source": [ 719 | "# to get one explicitly\n", 720 | "t_nan = torch.tensor(float('nan'))\n", 721 | "t_nan" 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "execution_count": 22, 727 | "id": "hawaiian-evaluation", 728 | "metadata": {}, 729 | "outputs": [ 730 | { 731 | "data": { 732 | "text/plain": [ 733 | "tensor(True)" 734 | ] 735 | }, 736 | "execution_count": 22, 737 | "metadata": {}, 738 | "output_type": "execute_result" 739 | } 740 | ], 741 | "source": [ 742 | "# comparison\n", 743 | "torch.isnan(t_nan).any()" 744 | ] 745 | }, 746 | { 747 | "cell_type": "markdown", 748 | "id": "heated-equivalent", 749 | "metadata": {}, 750 | "source": [ 751 | "# Debugging process" 752 | ] 753 | }, 754 | { 755 | "cell_type": "markdown", 756 | "id": "inside-investor", 757 | "metadata": {}, 758 | "source": [ 759 | "As you can see, since ML is mostly matrix multiplications, which is sums and multiplications, it's enough to get one `inf` or `nan`, and the whole training goes down the rails.\n", 760 | "\n", 761 | "Here is a helper that you can run after suspect functions to see if the output gets any `inf` or `nan`s and also if you want to get an indication on whether you have some large numbers that are likely to overflow - remember in fp16 65K is the biggest number 
one can have." 762 | ] 763 | }, 764 | { 765 | "cell_type": "code", 766 | "execution_count": 23, 767 | "id": "alpine-wrist", 768 | "metadata": {}, 769 | "outputs": [], 770 | "source": [ 771 | "def detect_overflow(var, ctx):\n", 772 | " \"\"\"\n", 773 | " Report the count of ``nan`` and ``inf`` entries in the tensor.\n", 774 | "\n", 775 | " This is useful for detecting overflows/underflows and best to call right after the function that did some math that\n", 776 | " modified the variable in question.\n", 777 | "\n", 778 | " Args:\n", 779 | " var: tensor variable to check\n", 780 | " ctx: the message to print as a context\n", 781 | " \"\"\"\n", 782 | " if torch.isnan(var).any().item():\n", 783 | " logger.warning(f\"{ctx} has nans\")\n", 784 | " if torch.isinf(var).any().item():\n", 785 | " logger.warning(f\"{ctx} has inf\")\n", 786 | "\n", 787 | " # if needed to monitor large elements can enable the following\n", 788 | " if 0:\n", 789 | " n100 = var[torch.ge(var.abs(), 100)]\n", 790 | " if n100.numel() > 0:\n", 791 | " logger.warning(f\"{ctx}: n100={n100.numel()}\")\n", 792 | " n1000 = var[torch.ge(var.abs(), 1000)]\n", 793 | " if n1000.numel() > 0:\n", 794 | " logger.warning(f\"{ctx}: n1000={n1000.numel()}\")" 795 | ] 796 | }, 797 | { 798 | "cell_type": "markdown", 799 | "id": "vocal-passing", 800 | "metadata": {}, 801 | "source": [ 802 | "So, if you training gives you say a loss of `nan`, you can go to the layers of your model and inject this function, in one or more places, e.g.:" 803 | ] 804 | }, 805 | { 806 | "cell_type": "code", 807 | "execution_count": 24, 808 | "id": "exempt-environment", 809 | "metadata": {}, 810 | "outputs": [], 811 | "source": [ 812 | "def forward(x):\n", 813 | " detect_overflow(x, \"x / enter\")\n", 814 | " y = self.ff(x)\n", 815 | " detect_overflow(x, \"y / after ff\") " 816 | ] 817 | }, 818 | { 819 | "cell_type": "markdown", 820 | "id": "final-tracy", 821 | "metadata": {}, 822 | "source": [ 823 | "or you use an advanced debugger you can 
assign watches that will immediately tell you if a tensor just got some `inf`s, by having a dynamically evaluated watch expression: `torch.isinf(x).any().item()` - in this example we watch the tensor `x`. So as you step through the code you can visually immediately see if it went from `False` to `True`. " 824 | ] 825 | }, 826 | { 827 | "cell_type": "markdown", 828 | "id": "caroline-group", 829 | "metadata": {}, 830 | "source": [ 831 | "\n" 832 | ] 833 | }, 834 | { 835 | "cell_type": "markdown", 836 | "id": "imported-empty", 837 | "metadata": {}, 838 | "source": [ 839 | "# Disabling subnormal numbers in pytorch\n", 840 | "\n", 841 | "In some systems subnormal number calculation can be suboptimal as it's often done in software, so if your network deals a lot with subnormal numbers you might want to disable those and scale your numbers to a normal range instead.\n", 842 | "\n", 843 | "The following demonstrates how it works in pytorch" 844 | ] 845 | }, 846 | { 847 | "cell_type": "code", 848 | "execution_count": 25, 849 | "id": "fuzzy-pound", 850 | "metadata": { 851 | "run_control": { 852 | "marked": false 853 | } 854 | }, 855 | "outputs": [ 856 | { 857 | "data": { 858 | "text/plain": [ 859 | "tensor([0.])" 860 | ] 861 | }, 862 | "execution_count": 25, 863 | "metadata": {}, 864 | "output_type": "execute_result" 865 | }, 866 | { 867 | "data": { 868 | "text/plain": [ 869 | "tensor([1.0000e-39])" 870 | ] 871 | }, 872 | "execution_count": 25, 873 | "metadata": {}, 874 | "output_type": "execute_result" 875 | } 876 | ], 877 | "source": [ 878 | "_ = torch.set_flush_denormal(True)\n", 879 | "torch.tensor([1e-39], dtype=torch.float32)\n", 880 | "_ = torch.set_flush_denormal(False)\n", 881 | "torch.tensor([1e-39], dtype=torch.float32)" 882 | ] 883 | }, 884 | { 885 | "cell_type": "code", 886 | "execution_count": 26, 887 | "id": "surrounded-upper", 888 | "metadata": {}, 889 | "outputs": [ 890 | { 891 | "data": { 892 | "text/plain": [ 893 | "tensor([1.0133e-06], 
dtype=torch.float16)" 894 | ] 895 | }, 896 | "execution_count": 26, 897 | "metadata": {}, 898 | "output_type": "execute_result" 899 | }, 900 | { 901 | "data": { 902 | "text/plain": [ 903 | "tensor([1.0133e-06], dtype=torch.float16)" 904 | ] 905 | }, 906 | "execution_count": 26, 907 | "metadata": {}, 908 | "output_type": "execute_result" 909 | } 910 | ], 911 | "source": [ 912 | "# broken for fp16\n", 913 | "_ = torch.set_flush_denormal(True)\n", 914 | "torch.tensor([1e-6], dtype=torch.float16)\n", 915 | "_ = torch.set_flush_denormal(False)\n", 916 | "torch.tensor([1e-6], dtype=torch.float16)" 917 | ] 918 | }, 919 | { 920 | "cell_type": "code", 921 | "execution_count": 27, 922 | "id": "attempted-deficit", 923 | "metadata": {}, 924 | "outputs": [ 925 | { 926 | "data": { 927 | "text/plain": [ 928 | "tensor([0.], dtype=torch.bfloat16)" 929 | ] 930 | }, 931 | "execution_count": 27, 932 | "metadata": {}, 933 | "output_type": "execute_result" 934 | }, 935 | { 936 | "data": { 937 | "text/plain": [ 938 | "tensor([1.0102e-39], dtype=torch.bfloat16)" 939 | ] 940 | }, 941 | "execution_count": 27, 942 | "metadata": {}, 943 | "output_type": "execute_result" 944 | } 945 | ], 946 | "source": [ 947 | "_ = torch.set_flush_denormal(True)\n", 948 | "torch.tensor([1e-39], dtype=torch.bfloat16)\n", 949 | "_ = torch.set_flush_denormal(False)\n", 950 | "torch.tensor([1e-39], dtype=torch.bfloat16)" 951 | ] 952 | }, 953 | { 954 | "cell_type": "code", 955 | "execution_count": null, 956 | "id": "hollywood-karma", 957 | "metadata": {}, 958 | "outputs": [], 959 | "source": [] 960 | }, 961 | { 962 | "cell_type": "code", 963 | "execution_count": null, 964 | "id": "indie-message", 965 | "metadata": {}, 966 | "outputs": [], 967 | "source": [ 968 | "%%javascript # prevent committing an unsaved notebook\n", 969 | "IPython.notebook.save_notebook()" 970 | ] 971 | } 972 | ], 973 | "metadata": { 974 | "hide_input": false, 975 | "kernelspec": { 976 | "display_name": "Python 3 (ipykernel)", 977 | 
"language": "python", 978 | "name": "python3" 979 | }, 980 | "language_info": { 981 | "codemirror_mode": { 982 | "name": "ipython", 983 | "version": 3 984 | }, 985 | "file_extension": ".py", 986 | "mimetype": "text/x-python", 987 | "name": "python", 988 | "nbconvert_exporter": "python", 989 | "pygments_lexer": "ipython3", 990 | "version": "3.8.12" 991 | }, 992 | "toc": { 993 | "base_numbering": 1, 994 | "nav_menu": {}, 995 | "number_sections": true, 996 | "sideBar": true, 997 | "skip_h1_title": false, 998 | "title_cell": "Table of Contents", 999 | "title_sidebar": "Contents", 1000 | "toc_cell": false, 1001 | "toc_position": { 1002 | "height": "calc(100% - 180px)", 1003 | "left": "10px", 1004 | "top": "150px", 1005 | "width": "256.391px" 1006 | }, 1007 | "toc_section_display": true, 1008 | "toc_window_display": true 1009 | } 1010 | }, 1011 | "nbformat": 4, 1012 | "nbformat_minor": 5 1013 | } 1014 | -------------------------------------------------------------------------------- /numbers/bfloat16-vs-float16-study.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "black-beijing", 6 | "metadata": {}, 7 | "source": [ 8 | "# float16 vs bfloat16 numerical properties comparison\n", 9 | "\n", 10 | "This a short notebook to help understand `fp16` vs `bfloat16` in particular when converting a model trained\n", 11 | "in `bfloat16` to mixed precision - it should be possible to look at the numbers to know which ranges\n", 12 | "are safe and which need to be scaled/avoided.\n", 13 | "\n", 14 | "I needed to do that in the context of trying to understand why bfloat16 t5/mt5 models that were pretrained in bfloat16 had a lot of `nan`/`inf` problems when finetuned in mixed precision." 
15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "id": "eastern-variation", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "import torch" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "adult-daughter", 30 | "metadata": {}, 31 | "source": [ 32 | "This is the main function, that tries to do very simply increments in `bfloat16` and then converting the result to `float16` and showing the discrepancies." 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "id": "resistant-chile", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "def find_mismatch(start, incr):\n", 43 | " bf16 = torch.tensor(start, dtype=torch.bfloat16)\n", 44 | " print(f\"\\nfp32 start={start:.2e} using increment={incr}\")\n", 45 | " print(f\"{'bfloat16':>18} {'float16':>18} {'diff':>8}\")\n", 46 | " c = 0\n", 47 | " tries = 0\n", 48 | " while c < 8:\n", 49 | " fp16 = bf16.to(torch.float16)\n", 50 | " if not (fp16 == bf16):\n", 51 | " print(f\"{bf16:.16f} {fp16:.16f} {torch.sub(fp16.to(dtype=torch.float32), bf16):+.2e}\")\n", 52 | " c += 1\n", 53 | " bf16 += incr\n", 54 | " tries += 1\n", 55 | " if tries >= 1e5:\n", 56 | " print(f\"gave up finding mismatch after {tries} steps\")\n", 57 | " return" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "applied-damages", 63 | "metadata": {}, 64 | "source": [ 65 | "## Underflow for fp16\n", 66 | "\n", 67 | "when numbers become 0.0" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "id": "cooperative-government", 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "\n", 81 | "fp32 start=1.00e-08 using increment=1e-09\n", 82 | " bfloat16 float16 diff\n", 83 | "0.0000000100117177 0.0000000000000000 -1.00e-08\n", 84 | "0.0000000110012479 0.0000000000000000 -1.10e-08\n", 85 | "0.0000000119907781 0.0000000000000000 -1.20e-08\n", 86 | "0.0000000129803084 
0.0000000000000000 -1.30e-08\n", 87 | "0.0000000139698386 0.0000000000000000 -1.40e-08\n", 88 | "0.0000000150175765 0.0000000000000000 -1.50e-08\n", 89 | "0.0000000160653144 0.0000000000000000 -1.61e-08\n", 90 | "0.0000000171130523 0.0000000000000000 -1.71e-08\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "find_mismatch(1e-08, 1e-09)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "id": "decimal-fraction", 101 | "metadata": {}, 102 | "source": [ 103 | "## Subnormal range for fp16\n", 104 | "\n", 105 | "starting from 5.96e-8 \n", 106 | "\n", 107 | "usually expensive and very low precision" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 4, 113 | "id": "statutory-procurement", 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "\n", 121 | "fp32 start=1.00e-07 using increment=1e-08\n", 122 | " bfloat16 float16 diff\n", 123 | "0.0000001001171768 0.0000001192092896 +1.91e-08\n", 124 | "0.0000001098960638 0.0000001192092896 +9.31e-09\n", 125 | "0.0000001201406121 0.0000001192092896 -9.31e-10\n", 126 | "0.0000001303851604 0.0000001192092896 -1.12e-08\n", 127 | "0.0000001406297088 0.0000001192092896 -2.14e-08\n", 128 | "0.0000001508742571 0.0000001788139343 +2.79e-08\n", 129 | "0.0000001611188054 0.0000001788139343 +1.77e-08\n", 130 | "0.0000001713633537 0.0000001788139343 +7.45e-09\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "# very limited range for fp16\n", 136 | "find_mismatch(1e-07, 1e-08)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 5, 142 | "id": "distributed-puppy", 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stdout", 147 | "output_type": "stream", 148 | "text": [ 149 | "\n", 150 | "fp32 start=1.00e-06 using increment=1e-07\n", 151 | " bfloat16 float16 diff\n", 152 | "0.0000009983778000 0.0000010132789612 +1.49e-08\n", 153 | "0.0000010952353477 0.0000010728836060 -2.24e-08\n", 154 | 
"0.0000012889504433 0.0000013113021851 +2.24e-08\n", 155 | "0.0000013858079910 0.0000013709068298 -1.49e-08\n", 156 | "0.0000014826655388 0.0000014901161194 +7.45e-09\n", 157 | "0.0000015795230865 0.0000015497207642 -2.98e-08\n", 158 | "0.0000016763806343 0.0000016689300537 -7.45e-09\n", 159 | "0.0000017732381821 0.0000017881393433 +1.49e-08\n" 160 | ] 161 | } 162 | ], 163 | "source": [ 164 | "# things starting to improve slightly for fp16\n", 165 | "find_mismatch(1e-06, 1e-07)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "id": "expired-drinking", 171 | "metadata": {}, 172 | "source": [ 173 | "## Normal numbers\n", 174 | "\n", 175 | "Min positive normal fp16: 6.104e-05 (`np.finfo(np.float16).tiny`)\n", 176 | "\n", 177 | "These ranges match much better and thus will not easily find a mismatch if at all" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 6, 183 | "id": "seven-caution", 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "name": "stdout", 188 | "output_type": "stream", 189 | "text": [ 190 | "\n", 191 | "fp32 start=1.00e-05 using increment=1e-06\n", 192 | " bfloat16 float16 diff\n", 193 | "gave up finding mismatch after 100000 steps\n", 194 | "\n", 195 | "fp32 start=1.00e-04 using increment=1e-06\n", 196 | " bfloat16 float16 diff\n", 197 | "gave up finding mismatch after 100000 steps\n", 198 | "\n", 199 | "fp32 start=1.00e-03 using increment=0.0001\n", 200 | " bfloat16 float16 diff\n", 201 | "gave up finding mismatch after 100000 steps\n", 202 | "\n", 203 | "fp32 start=1.00e-02 using increment=0.001\n", 204 | " bfloat16 float16 diff\n", 205 | "gave up finding mismatch after 100000 steps\n", 206 | "\n", 207 | "fp32 start=1.00e-01 using increment=0.01\n", 208 | " bfloat16 float16 diff\n", 209 | "gave up finding mismatch after 100000 steps\n", 210 | "\n", 211 | "fp32 start=1.00e+01 using increment=1e-06\n", 212 | " bfloat16 float16 diff\n", 213 | "gave up finding mismatch after 100000 steps\n", 214 | "\n", 
215 | "fp32 start=1.00e+01 using increment=10.0\n", 216 | " bfloat16 float16 diff\n", 217 | "gave up finding mismatch after 100000 steps\n", 218 | "\n", 219 | "fp32 start=1.00e+04 using increment=1\n", 220 | " bfloat16 float16 diff\n", 221 | "gave up finding mismatch after 100000 steps\n" 222 | ] 223 | } 224 | ], 225 | "source": [ 226 | "find_mismatch(1e-05, 1e-06)\n", 227 | "find_mismatch(1e-04, 1e-06)\n", 228 | "find_mismatch(1e-03, 1e-04)\n", 229 | "find_mismatch(1e-02, 1e-03)\n", 230 | "find_mismatch(1e-01, 1e-02)\n", 231 | "find_mismatch(1e1, 1e-06)\n", 232 | "find_mismatch(1e1, 1e1)\n", 233 | "find_mismatch(1e4, 1)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 7, 239 | "id": "mighty-injection", 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "name": "stdout", 244 | "output_type": "stream", 245 | "text": [ 246 | "\n", 247 | "fp32 start=5.00e+04 using increment=1000.0\n", 248 | " bfloat16 float16 diff\n", 249 | "66048.0000000000000000 inf +inf\n", 250 | "67072.0000000000000000 inf +inf\n", 251 | "68096.0000000000000000 inf +inf\n", 252 | "69120.0000000000000000 inf +inf\n", 253 | "70144.0000000000000000 inf +inf\n", 254 | "71168.0000000000000000 inf +inf\n", 255 | "72192.0000000000000000 inf +inf\n", 256 | "73216.0000000000000000 inf +inf\n" 257 | ] 258 | } 259 | ], 260 | "source": [ 261 | "# hitting max range for fp16\n", 262 | "find_mismatch(5e4, 1e3)" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 8, 268 | "id": "alleged-lobby", 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "# --- roundoff ---\n", 273 | "# fp16 4.88e-4\n", 274 | "# bf16 3.91e-3" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "id": "cheap-natural", 280 | "metadata": {}, 281 | "source": [ 282 | "## Big numbers\n", 283 | "\n", 284 | "`bfloat16` numbers have a terrible range for numbers `> 1` but `fp16` matches those exactly\n", 285 | "e.g. 
one can't represent 283 in bf16\n", 286 | "\n", 287 | "```\n", 288 | "python -c \"import torch; print( torch.tensor(283, dtype=torch.bfloat16) )\"\n", 289 | "tensor(284., dtype=torch.bfloat16)\n", 290 | "```" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 9, 296 | "id": "integrated-individual", 297 | "metadata": {}, 298 | "outputs": [ 299 | { 300 | "name": "stdout", 301 | "output_type": "stream", 302 | "text": [ 303 | "282.00\n", 304 | "284.00\n", 305 | "286.00\n" 306 | ] 307 | } 308 | ], 309 | "source": [ 310 | "start = 280\n", 311 | "fp32 = torch.tensor(start, dtype=torch.float32)\n", 312 | "for i in range(3):\n", 313 | " bf16 = fp32.to(torch.bfloat16)\n", 314 | " bf16d = bf16\n", 315 | " while bf16 == bf16d:\n", 316 | " fp32 += 1\n", 317 | " bf16d = fp32.to(torch.bfloat16)\n", 318 | " print(f\"{bf16d:.2f}\")\n", 319 | "# 282\n", 320 | "# 284\n", 321 | "# 286" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "id": "ae153e44", 327 | "metadata": {}, 328 | "source": [ 329 | "## How many positions between 2 numbers\n", 330 | "\n", 331 | "Let's see how many `fp16` numbers can fit between `bf16` numbers - which should help to understand how converting a model trained in `fp16` to `bf16` in a way quantizes the model - since there are less `bf16` numbers in the same range." 
332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 5, 337 | "id": "278e6249", 338 | "metadata": {}, 339 | "outputs": [ 340 | { 341 | "name": "stdout", 342 | "output_type": "stream", 343 | "text": [ 344 | "9 fp16s: [0.10009765625, 0.10015869140625, 0.1002197265625, 0.10028076171875, 0.100341796875, 0.10040283203125, 0.1004638671875, 0.10052490234375, 0.1005859375]\n", 345 | "2 bf16s: [0.10009765625, 0.1005859375]\n" 346 | ] 347 | } 348 | ], 349 | "source": [ 350 | "fp16 = torch.tensor(0.1001, dtype=torch.float16)\n", 351 | "bf16 = torch.tensor(0.1001, dtype=torch.bfloat16)\n", 352 | "fp16s = [fp16]\n", 353 | "bf16s = [bf16]\n", 354 | "\n", 355 | "delta = 0.00001\n", 356 | "for i in range(100):\n", 357 | " fp16_new = fp16 + delta*i\n", 358 | " bf16_new = bf16 + delta*i\n", 359 | " if fp16s[-1] != fp16_new:\n", 360 | " fp16s.append(fp16_new)\n", 361 | " if bf16s[-1] != bf16_new:\n", 362 | " bf16s.append(bf16_new)\n", 363 | " if len(bf16s) > 1 and bf16s[-1] == fp16s[-1]:\n", 364 | " break\n", 365 | " \n", 366 | "print(f\"{len(fp16s)} fp16s: {[x.item() for x in fp16s]}\")\n", 367 | "print(f\"{len(bf16s)} bf16s: {[x.item() for x in bf16s]}\")\n" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "id": "fa0ae862", 373 | "metadata": {}, 374 | "source": [ 375 | "So it can be seen that in this particular range of numbers every 8 \"positions\" in `fp16` get remapped to a single \"position\" in `bf16`. As `exponent(fp16) = 10` and `exponent(bf16) = 7` - so we have `2**3=8` different positions between 2 representations." 
376 | ] 377 | }, 378 | { 379 | "cell_type": "markdown", 380 | "id": "revolutionary-force", 381 | "metadata": {}, 382 | "source": [ 383 | "# Math" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "id": "systematic-latex", 389 | "metadata": {}, 390 | "source": [ 391 | "## Summation\n", 392 | "\n", 393 | "A very narrow dynamic range means that for largish numbers NN trained in `bfloat16` **expects** bad\n", 394 | "precision and when the precision is suddenly higher unexpected outcomes happen:" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": 10, 400 | "id": "unlike-raising", 401 | "metadata": {}, 402 | "outputs": [ 403 | { 404 | "name": "stdout", 405 | "output_type": "stream", 406 | "text": [ 407 | "tensor(284., dtype=torch.bfloat16)\n", 408 | "tensor(283., dtype=torch.float16)\n" 409 | ] 410 | } 411 | ], 412 | "source": [ 413 | "# small sum\n", 414 | "print(torch.tensor(282, dtype=torch.bfloat16)+1) # 284\n", 415 | "print(torch.tensor(282, dtype=torch.float16)+1) # 283" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": 11, 421 | "id": "competitive-average", 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "name": "stdout", 426 | "output_type": "stream", 427 | "text": [ 428 | "tensor(2848., dtype=torch.bfloat16)\n", 429 | "tensor(2830., dtype=torch.float16)\n" 430 | ] 431 | } 432 | ], 433 | "source": [ 434 | "# sum several of these\n", 435 | "print(torch.tensor(283, dtype=torch.bfloat16)*10) # 2848\n", 436 | "print(torch.tensor(283, dtype=torch.float16)*10) # 2830" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "id": "choice-enemy", 442 | "metadata": {}, 443 | "source": [ 444 | "As you can see numbers start to diverge quickly!\n", 445 | "\n", 446 | "Now in practice we typically add up thousands of numbers.\n", 447 | "\n", 448 | "The solution is to always do this kind of operations in double precision of the operands and then if needed casting back to the original. i.e. 
the accumulate of `sum(fp16_tensor)` should be at least a `float32` tensor." 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 12, 454 | "id": "liquid-purple", 455 | "metadata": {}, 456 | "outputs": [ 457 | { 458 | "data": { 459 | "text/plain": [ 460 | "tensor(inf, dtype=torch.float16)" 461 | ] 462 | }, 463 | "execution_count": 12, 464 | "metadata": {}, 465 | "output_type": "execute_result" 466 | }, 467 | { 468 | "data": { 469 | "text/plain": [ 470 | "tensor(250394.1875)" 471 | ] 472 | }, 473 | "execution_count": 12, 474 | "metadata": {}, 475 | "output_type": "execute_result" 476 | } 477 | ], 478 | "source": [ 479 | "x = torch.rand((10000)).half()*50\n", 480 | "\n", 481 | "# this overflows\n", 482 | "x.sum()\n", 483 | "# this succeeds\n", 484 | "x.sum(dtype=torch.float32)" 485 | ] 486 | }, 487 | { 488 | "cell_type": "markdown", 489 | "id": "dated-scope", 490 | "metadata": {}, 491 | "source": [ 492 | "## Getting overflows\n", 493 | "\n", 494 | "Full numbers range: ``float16: ±65,504``\n", 495 | "\n", 496 | "So fp16 overflows easily in say variance calculation when you try to just square a number bigger than `256` - as it'd overflow, i.e. you get `inf`! 
so `256**2+1` will be `inf`\n", 497 | "\n", 498 | "You can't even do `pow(2)` for fp16 in pytorch, the following will give an error that it doesn't support power for fp16.\n", 499 | "\n", 500 | "`torch.tensor(256, dtype=torch.float16).pow(2)`\n", 501 | "\n", 502 | "You have to cast to `float32` first:" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 13, 508 | "id": "quick-local", 509 | "metadata": {}, 510 | "outputs": [ 511 | { 512 | "data": { 513 | "text/plain": [ 514 | "tensor(65024., dtype=torch.float16)" 515 | ] 516 | }, 517 | "execution_count": 13, 518 | "metadata": {}, 519 | "output_type": "execute_result" 520 | } 521 | ], 522 | "source": [ 523 | "x = torch.tensor(255, dtype=torch.float16)\n", 524 | "x_squared = x.float().pow(2)\n", 525 | "x_squared.to(dtype=torch.float16)" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 14, 531 | "id": "nonprofit-charleston", 532 | "metadata": {}, 533 | "outputs": [ 534 | { 535 | "data": { 536 | "text/plain": [ 537 | "tensor(inf, dtype=torch.float16)" 538 | ] 539 | }, 540 | "execution_count": 14, 541 | "metadata": {}, 542 | "output_type": "execute_result" 543 | } 544 | ], 545 | "source": [ 546 | "# let's cross into the overflow\n", 547 | "x += 1\n", 548 | "x_squared = x.float().pow(2)\n", 549 | "x_squared.to(dtype=torch.float16)" 550 | ] 551 | }, 552 | { 553 | "cell_type": "markdown", 554 | "id": "dominican-retro", 555 | "metadata": {}, 556 | "source": [ 557 | "And that's how `inf` comes about.\n", 558 | "\n", 559 | "Or if you need to create one, you can just do:" 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": 15, 565 | "id": "returning-certificate", 566 | "metadata": {}, 567 | "outputs": [], 568 | "source": [ 569 | "t_inf = torch.tensor(float('inf'))" 570 | ] 571 | }, 572 | { 573 | "cell_type": "markdown", 574 | "id": "eastern-orleans", 575 | "metadata": {}, 576 | "source": [ 577 | "If you need to compare if a tensor has `inf` elements:" 
578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "execution_count": 16, 583 | "id": "ethical-wayne", 584 | "metadata": {}, 585 | "outputs": [ 586 | { 587 | "data": { 588 | "text/plain": [ 589 | "tensor(True)" 590 | ] 591 | }, 592 | "execution_count": 16, 593 | "metadata": {}, 594 | "output_type": "execute_result" 595 | } 596 | ], 597 | "source": [ 598 | "torch.isinf(t_inf).any()" 599 | ] 600 | }, 601 | { 602 | "cell_type": "markdown", 603 | "id": "derived-canberra", 604 | "metadata": {}, 605 | "source": [ 606 | "## Getting NaNs \n", 607 | "\n", 608 | "While there are many ways to get `NaN` during calculations, the most common for machine learning are:" 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": 17, 614 | "id": "formed-certificate", 615 | "metadata": {}, 616 | "outputs": [ 617 | { 618 | "data": { 619 | "text/plain": [ 620 | "tensor(nan)" 621 | ] 622 | }, 623 | "execution_count": 17, 624 | "metadata": {}, 625 | "output_type": "execute_result" 626 | } 627 | ], 628 | "source": [ 629 | "# 0/0\n", 630 | "t_zero = torch.tensor(0)\n", 631 | "t_zero/t_zero" 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 18, 637 | "id": "authorized-window", 638 | "metadata": {}, 639 | "outputs": [ 640 | { 641 | "data": { 642 | "text/plain": [ 643 | "tensor(nan)" 644 | ] 645 | }, 646 | "execution_count": 18, 647 | "metadata": {}, 648 | "output_type": "execute_result" 649 | } 650 | ], 651 | "source": [ 652 | "# inf/inf\n", 653 | "t_inf = torch.tensor(float('inf'))\n", 654 | "t_inf/t_inf" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": 19, 660 | "id": "elegant-grant", 661 | "metadata": {}, 662 | "outputs": [ 663 | { 664 | "data": { 665 | "text/plain": [ 666 | "tensor(nan)" 667 | ] 668 | }, 669 | "execution_count": 19, 670 | "metadata": {}, 671 | "output_type": "execute_result" 672 | } 673 | ], 674 | "source": [ 675 | "# 0*inf\n", 676 | "t_zero * t_inf" 677 | ] 678 | }, 679 | { 680 | 
"cell_type": "code", 681 | "execution_count": 20, 682 | "id": "several-council", 683 | "metadata": {}, 684 | "outputs": [ 685 | { 686 | "data": { 687 | "text/plain": [ 688 | "tensor(nan)" 689 | ] 690 | }, 691 | "execution_count": 20, 692 | "metadata": {}, 693 | "output_type": "execute_result" 694 | } 695 | ], 696 | "source": [ 697 | "# inf - inf\n", 698 | "t_inf - t_inf" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": 21, 704 | "id": "local-change", 705 | "metadata": {}, 706 | "outputs": [ 707 | { 708 | "data": { 709 | "text/plain": [ 710 | "tensor(nan)" 711 | ] 712 | }, 713 | "execution_count": 21, 714 | "metadata": {}, 715 | "output_type": "execute_result" 716 | } 717 | ], 718 | "source": [ 719 | "# to get one explicitly\n", 720 | "t_nan = torch.tensor(float('nan'))\n", 721 | "t_nan" 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "execution_count": 22, 727 | "id": "hawaiian-evaluation", 728 | "metadata": {}, 729 | "outputs": [ 730 | { 731 | "data": { 732 | "text/plain": [ 733 | "tensor(True)" 734 | ] 735 | }, 736 | "execution_count": 22, 737 | "metadata": {}, 738 | "output_type": "execute_result" 739 | } 740 | ], 741 | "source": [ 742 | "# comparison\n", 743 | "torch.isnan(t_nan).any()" 744 | ] 745 | }, 746 | { 747 | "cell_type": "markdown", 748 | "id": "heated-equivalent", 749 | "metadata": {}, 750 | "source": [ 751 | "# Debugging process" 752 | ] 753 | }, 754 | { 755 | "cell_type": "markdown", 756 | "id": "inside-investor", 757 | "metadata": {}, 758 | "source": [ 759 | "As you can see, since ML is mostly matrix multiplications, which is sums and multiplications, it's enough to get one `inf` or `nan`, and the whole training goes down the rails.\n", 760 | "\n", 761 | "Here is a helper that you can run after suspect functions to see if the output gets any `inf` or `nan`s and also if you want to get an indication on whether you have some large numbers that are likely to overflow - remember in fp16 65K is the biggest number 
one can have." 762 | ] 763 | }, 764 | { 765 | "cell_type": "code", 766 | "execution_count": 23, 767 | "id": "alpine-wrist", 768 | "metadata": {}, 769 | "outputs": [], 770 | "source": [ 771 | "def detect_overflow(var, ctx):\n", 772 | " \"\"\"\n", 773 | " Report the count of ``nan`` and ``inf`` entries in the tensor.\n", 774 | "\n", 775 | " This is useful for detecting overflows/underflows and best to call right after the function that did some math that\n", 776 | " modified the variable in question.\n", 777 | "\n", 778 | " Args:\n", 779 | " var: tensor variable to check\n", 780 | " ctx: the message to print as a context\n", 781 | " \"\"\"\n", 782 | " if torch.isnan(var).any().item():\n", 783 | " logger.warning(f\"{ctx} has nans\")\n", 784 | " if torch.isinf(var).any().item():\n", 785 | " logger.warning(f\"{ctx} has inf\")\n", 786 | "\n", 787 | " # if needed to monitor large elements can enable the following\n", 788 | " if 0:\n", 789 | " n100 = var[torch.ge(var.abs(), 100)]\n", 790 | " if n100.numel() > 0:\n", 791 | " logger.warning(f\"{ctx}: n100={n100.numel()}\")\n", 792 | " n1000 = var[torch.ge(var.abs(), 1000)]\n", 793 | " if n1000.numel() > 0:\n", 794 | " logger.warning(f\"{ctx}: n1000={n1000.numel()}\")" 795 | ] 796 | }, 797 | { 798 | "cell_type": "markdown", 799 | "id": "vocal-passing", 800 | "metadata": {}, 801 | "source": [ 802 | "So, if you training gives you say a loss of `nan`, you can go to the layers of your model and inject this function, in one or more places, e.g.:" 803 | ] 804 | }, 805 | { 806 | "cell_type": "code", 807 | "execution_count": 24, 808 | "id": "exempt-environment", 809 | "metadata": {}, 810 | "outputs": [], 811 | "source": [ 812 | "def forward(x):\n", 813 | " detect_overflow(x, \"x / enter\")\n", 814 | " y = self.ff(x)\n", 815 | " detect_overflow(x, \"y / after ff\") " 816 | ] 817 | }, 818 | { 819 | "cell_type": "markdown", 820 | "id": "final-tracy", 821 | "metadata": {}, 822 | "source": [ 823 | "or you use an advanced debugger you can 
assign watches that will immediately tell you if a tensor just got some `inf`s, by having a dynamically evaluated watch expression: `torch.isinf(x).any().item()` - in this example we watch the tensor `x`. So as you step through the code you can visually immediately see if it went from `False` to `True`. " 824 | ] 825 | }, 826 | { 827 | "cell_type": "markdown", 828 | "id": "caroline-group", 829 | "metadata": {}, 830 | "source": [ 831 | "\n" 832 | ] 833 | }, 834 | { 835 | "cell_type": "markdown", 836 | "id": "imported-empty", 837 | "metadata": {}, 838 | "source": [ 839 | "# Disabling subnormal numbers in pytorch\n", 840 | "\n", 841 | "In some systems subnormal number calculation can be suboptimal as it's often done in software, so if your network deals a lot with subnormal numbers you might want to disable those and scale your numbers to a normal range instead.\n", 842 | "\n", 843 | "The following demonstrates how it works in pytorch" 844 | ] 845 | }, 846 | { 847 | "cell_type": "code", 848 | "execution_count": 25, 849 | "id": "fuzzy-pound", 850 | "metadata": { 851 | "run_control": { 852 | "marked": false 853 | } 854 | }, 855 | "outputs": [ 856 | { 857 | "data": { 858 | "text/plain": [ 859 | "tensor([0.])" 860 | ] 861 | }, 862 | "execution_count": 25, 863 | "metadata": {}, 864 | "output_type": "execute_result" 865 | }, 866 | { 867 | "data": { 868 | "text/plain": [ 869 | "tensor([1.0000e-39])" 870 | ] 871 | }, 872 | "execution_count": 25, 873 | "metadata": {}, 874 | "output_type": "execute_result" 875 | } 876 | ], 877 | "source": [ 878 | "_ = torch.set_flush_denormal(True)\n", 879 | "torch.tensor([1e-39], dtype=torch.float32)\n", 880 | "_ = torch.set_flush_denormal(False)\n", 881 | "torch.tensor([1e-39], dtype=torch.float32)" 882 | ] 883 | }, 884 | { 885 | "cell_type": "code", 886 | "execution_count": 26, 887 | "id": "surrounded-upper", 888 | "metadata": {}, 889 | "outputs": [ 890 | { 891 | "data": { 892 | "text/plain": [ 893 | "tensor([1.0133e-06], 
dtype=torch.float16)" 894 | ] 895 | }, 896 | "execution_count": 26, 897 | "metadata": {}, 898 | "output_type": "execute_result" 899 | }, 900 | { 901 | "data": { 902 | "text/plain": [ 903 | "tensor([1.0133e-06], dtype=torch.float16)" 904 | ] 905 | }, 906 | "execution_count": 26, 907 | "metadata": {}, 908 | "output_type": "execute_result" 909 | } 910 | ], 911 | "source": [ 912 | "# broken for fp16\n", 913 | "_ = torch.set_flush_denormal(True)\n", 914 | "torch.tensor([1e-6], dtype=torch.float16)\n", 915 | "_ = torch.set_flush_denormal(False)\n", 916 | "torch.tensor([1e-6], dtype=torch.float16)" 917 | ] 918 | }, 919 | { 920 | "cell_type": "code", 921 | "execution_count": 27, 922 | "id": "attempted-deficit", 923 | "metadata": {}, 924 | "outputs": [ 925 | { 926 | "data": { 927 | "text/plain": [ 928 | "tensor([0.], dtype=torch.bfloat16)" 929 | ] 930 | }, 931 | "execution_count": 27, 932 | "metadata": {}, 933 | "output_type": "execute_result" 934 | }, 935 | { 936 | "data": { 937 | "text/plain": [ 938 | "tensor([1.0102e-39], dtype=torch.bfloat16)" 939 | ] 940 | }, 941 | "execution_count": 27, 942 | "metadata": {}, 943 | "output_type": "execute_result" 944 | } 945 | ], 946 | "source": [ 947 | "_ = torch.set_flush_denormal(True)\n", 948 | "torch.tensor([1e-39], dtype=torch.bfloat16)\n", 949 | "_ = torch.set_flush_denormal(False)\n", 950 | "torch.tensor([1e-39], dtype=torch.bfloat16)" 951 | ] 952 | }, 953 | { 954 | "cell_type": "code", 955 | "execution_count": null, 956 | "id": "hollywood-karma", 957 | "metadata": {}, 958 | "outputs": [], 959 | "source": [] 960 | }, 961 | { 962 | "cell_type": "code", 963 | "execution_count": null, 964 | "id": "indie-message", 965 | "metadata": {}, 966 | "outputs": [], 967 | "source": [ 968 | "%%javascript # prevent committing an unsaved notebook\n", 969 | "IPython.notebook.save_notebook()" 970 | ] 971 | } 972 | ], 973 | "metadata": { 974 | "hide_input": false, 975 | "kernelspec": { 976 | "display_name": "Python 3 (ipykernel)", 977 | 
"language": "python", 978 | "name": "python3" 979 | }, 980 | "language_info": { 981 | "codemirror_mode": { 982 | "name": "ipython", 983 | "version": 3 984 | }, 985 | "file_extension": ".py", 986 | "mimetype": "text/x-python", 987 | "name": "python", 988 | "nbconvert_exporter": "python", 989 | "pygments_lexer": "ipython3", 990 | "version": "3.8.15" 991 | }, 992 | "toc": { 993 | "base_numbering": 1, 994 | "nav_menu": {}, 995 | "number_sections": true, 996 | "sideBar": true, 997 | "skip_h1_title": false, 998 | "title_cell": "Table of Contents", 999 | "title_sidebar": "Contents", 1000 | "toc_cell": false, 1001 | "toc_position": { 1002 | "height": "calc(100% - 180px)", 1003 | "left": "10px", 1004 | "top": "150px", 1005 | "width": "256.391px" 1006 | }, 1007 | "toc_section_display": true, 1008 | "toc_window_display": true 1009 | } 1010 | }, 1011 | "nbformat": 4, 1012 | "nbformat_minor": 5 1013 | } 1014 | -------------------------------------------------------------------------------- /numbers/detect-model-pretrained-in-bf16-fp16-fp32.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7428a31f", 6 | "metadata": {}, 7 | "source": [ 8 | "# bf16, fp16 or fp32 Model Pretraining Detection\n", 9 | "\n", 10 | "The goal is to autodetect if a model has been trained in bf16, fp16 or fp32 precision. We want this since we know that bf16-pretrained models tend to overflow when consequently finetuned with fp16 (mixed).\n", 11 | "\n", 12 | "We know that fp16's max number is `2**16=65536` (`~6.5e04`), so it should be easy to look at the weights and if they are larger than `1e02` (`sqrt(1e04)`) then the model has most likely been trained in other than fp16 precision (mixed or not).\n", 13 | "\n", 14 | "Let's write a script to look at the absolute min/max values of any model's weights, apply it to a bunch of models that we have information on how they were trained and find a pattern. 
\n", 15 | "\n", 16 | "I thought that abs min values could give us some info about the precision, but most likely it's the abs max values that are most telling. Let's see.\n", 17 | "\n", 18 | "I also added min and max norms, which I see are quite telling as well.\n", 19 | "\n", 20 | "**I'm currently needing more public models to get the patterns right. Please help by adding more models that you know how they were trained. Thank you!**\n", 21 | "\n", 22 | "You can submit your contribution and/or read the database gathered so far [here](https://discuss.huggingface.co/t/compiling-data-on-how-models-were-pre-trained-fp16-fp32-bf16/5671).\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 12, 28 | "id": "e951074c", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "import torch\n", 33 | "import logging\n", 34 | "import transformers" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "id": "028d21bb", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "from transformers import AutoModel" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "7b70ce36", 50 | "metadata": {}, 51 | "source": [ 52 | "## Module weights abs min/max analyser" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "id": "fa0b9662", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "def analyze(modules, verbose=True):\n", 63 | " \"\"\"\n", 64 | " modules is a list of sub-modules to search recursively. 
\n", 65 | " \n", 66 | " this can be the whole model, but sometimes only some submodules want to be inspected\n", 67 | " \"\"\"\n", 68 | " if verbose:\n", 69 | " print(\"\\nSearching:\")\n", 70 | " print(\"module | params\")\n", 71 | " abs_min, abs_max = 1e10, 0\n", 72 | " norm_min, norm_max = 1e10, 0\n", 73 | " for i,m in enumerate(modules):\n", 74 | " for j,p in enumerate(m.parameters(recurse=True)):\n", 75 | " p_abs = p.abs()\n", 76 | " p_abs_max = p_abs.max().item()\n", 77 | " p_abs_min = p_abs.min().item()\n", 78 | " if p_abs_min < abs_min: abs_min = p_abs_min\n", 79 | " if p_abs_max > abs_max: abs_max = p_abs_max\n", 80 | " \n", 81 | " p_norm = torch.linalg.norm(p.data)\n", 82 | " if p_norm > 0:\n", 83 | " if p_norm < norm_min: norm_min = p_norm\n", 84 | " if p_norm > norm_max: norm_max = p_norm\n", 85 | " if verbose:\n", 86 | " print(f\"{i:>6} | {j}\")\n", 87 | " return abs_min, abs_max, norm_min, norm_max" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "id": "1d3a43e2", 93 | "metadata": {}, 94 | "source": [ 95 | "the only concern I have here is that some models when trained in mixed precision may have some segment trained in fp32 and may end up with larger weights, though it is very unlikely since these then have to interact with the rest of the system. But more thought is needed." 
96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 4, 101 | "id": "332ce2fe", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "from transformers.utils.logging import disable_progress_bar\n", 106 | "disable_progress_bar() # disable tqdm!\n", 107 | "\n", 108 | "model = AutoModel.from_pretrained(\"t5-3b\")" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 5, 114 | "id": "197c886d", 115 | "metadata": { 116 | "run_control": { 117 | "marked": false 118 | } 119 | }, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "\n", 126 | "Searching:\n", 127 | "module | params\n", 128 | " 0 | 192\n", 129 | " 1 | 312\n", 130 | "\n", 131 | "Results:\n", 132 | "abs min | abs max | norm min | norm max\n", 133 | "1.455e-11 | 6.950e+01 | 5.201e+00 | 2.535e+03\n", 134 | "\n", 135 | "Searching:\n", 136 | "module | params\n", 137 | " 0 | 508\n", 138 | "\n", 139 | "Results:\n", 140 | "abs min | abs max | norm min | norm max\n", 141 | "1.455e-11 | 2.340e+02 | 5.201e+00 | 6.349e+04\n" 142 | ] 143 | } 144 | ], 145 | "source": [ 146 | "# Let's look at t5-small in verbose mode\n", 147 | "#model = AutoModel.from_pretrained(\"t5-small\")\n", 148 | "\n", 149 | "# let's look at just transformer blocks\n", 150 | "abs_min, abs_max, norm_min, norm_max = analyze([model.encoder.block, model.decoder.block])\n", 151 | "print(\"\\nResults:\")\n", 152 | "print(\"abs min | abs max | norm min | norm max\")\n", 153 | "print(f\"{abs_min:.3e} | {abs_max:.3e} | {norm_min:.3e} | {norm_max:.3e}\")\n", 154 | "\n", 155 | "# now the whole model\n", 156 | "abs_min, abs_max, norm_min, norm_max = analyze([model])\n", 157 | "print(\"\\nResults:\")\n", 158 | "print(\"abs min | abs max | norm min | norm max\")\n", 159 | "print(f\"{abs_min:.3e} | {abs_max:.3e} | {norm_min:.3e} | {norm_max:.3e}\")\n", 160 | "\n", 161 | "del model" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "id": "48185d16", 
167 | "metadata": {}, 168 | "source": [ 169 | "## Multiple model weights abs min/max analyser\n", 170 | "\n", 171 | "Now let's write a nice wrapper to process many models" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 6, 177 | "id": "6bcf8583", 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "def models_analyze(mnames):\n", 182 | " transformers.logging.set_verbosity_error() # be quiet\n", 183 | " print(f\"{'name':^40} | {'abs min':^9} | {'abs max':^9} | {'norm min':^9} | {'norm max':^9} \")\n", 184 | " print(f\"{'-'*40}-|-{'-'*9}-|-{'-'*9}-|-{'-'*9}-|-{'-'*9}-\")\n", 185 | " for mname in mnames:\n", 186 | " model = AutoModel.from_pretrained(mname)\n", 187 | " abs_min, abs_max, norm_min, norm_max = analyze([model], verbose=False)\n", 188 | " print(f\"{mname:<40} | {abs_min:.3e} | {abs_max:.3e} | {norm_min:.3e} | {norm_max:.3e}\")\n", 189 | " del model" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "id": "50f92dc2", 195 | "metadata": {}, 196 | "source": [ 197 | "## fp16 models\n", 198 | "\n", 199 | "Let's look at fp16-pretrained models" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 7, 205 | "id": "a74dfcdb", 206 | "metadata": {}, 207 | "outputs": [ 208 | { 209 | "name": "stdout", 210 | "output_type": "stream", 211 | "text": [ 212 | " name | abs min | abs max | norm min | norm max \n", 213 | "-----------------------------------------|-----------|-----------|-----------|-----------\n", 214 | "allenai/longformer-base-4096 | 0.000e+00 | 1.510e+00 | 2.272e-02 | 7.993e+02\n", 215 | "allenai/longformer-large-4096 | 0.000e+00 | 1.146e+00 | 9.087e-02 | 9.428e+02\n", 216 | "allenai/led-base-16384 | 0.000e+00 | 1.600e+01 | 1.611e-02 | 4.147e+02\n", 217 | "allenai/led-large-16384 | 0.000e+00 | 2.320e+01 | 4.799e-02 | 6.362e+02\n", 218 | "lvwerra/codeparrot | 1.245e-11 | 1.832e+00 | 1.185e-01 | 2.112e+02\n", 219 | "facebook/m2m100_418M | 0.000e+00 | 1.000e+00 | 4.792e-01 | 
4.829e+02\n", 220 | "facebook/m2m100_1.2B | 0.000e+00 | 1.000e+00 | 4.835e-01 | 4.925e+02\n", 221 | "facebook/opt-1.3b | 0.000e+00 | 1.000e+00 | 4.852e-02 | 3.619e+02\n", 222 | "facebook/opt-13b | 0.000e+00 | 1.000e+00 | 7.830e-02 | 3.136e+02\n", 223 | "bigscience/bloom-7b1 | 0.000e+00 | 1.783e+01 | 1.645e+00 | 2.669e+02\n", 224 | "bigscience/bloom-3b | 0.000e+00 | 2.522e+01 | 1.406e+00 | 2.606e+02\n" 225 | ] 226 | } 227 | ], 228 | "source": [ 229 | "# fp16-pretrained models\n", 230 | "mnames = [\"allenai/longformer-base-4096\", \"allenai/longformer-large-4096\", \n", 231 | " \"allenai/led-base-16384\", \"allenai/led-large-16384\", \"lvwerra/codeparrot\", \n", 232 | " \"facebook/m2m100_418M\", \"facebook/m2m100_1.2B\",\n", 233 | " \"facebook/opt-1.3b\", \"facebook/opt-13b\",\n", 234 | " \"bigscience/bloom-7b1\", \"bigscience/bloom-3b\", \n", 235 | " ]\n", 236 | "models_analyze(mnames)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "id": "3500cd2b", 242 | "metadata": {}, 243 | "source": [ 244 | "So we can see the fp16 abs max weights are quite small - they are in the range of 1e0 - 1e1.\n", 245 | "\n", 246 | "The norm max is also always under 1e3 in our samples\n", 247 | "\n", 248 | "abs max for \"led\" models is oddly pretty high. They are supposed to be the same as longformer, which are fp16. But norm max matches other models." 
249 | ] 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "id": "b266fb1a", 254 | "metadata": {}, 255 | "source": [ 256 | "## bf16 models\n", 257 | "\n", 258 | "Let's look at bf16-pretrained models" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 8, 264 | "id": "d1479a3b", 265 | "metadata": {}, 266 | "outputs": [ 267 | { 268 | "name": "stdout", 269 | "output_type": "stream", 270 | "text": [ 271 | " name | abs min | abs max | norm min | norm max \n", 272 | "-----------------------------------------|-----------|-----------|-----------|-----------\n", 273 | "t5-small | 5.442e-09 | 7.920e+02 | 1.780e+00 | 9.403e+04\n", 274 | "t5-base | 1.273e-10 | 5.600e+02 | 1.647e+00 | 9.332e+04\n", 275 | "t5-large | 3.638e-11 | 5.200e+02 | 3.797e+00 | 8.237e+04\n", 276 | "google/mt5-small | 3.201e-09 | 1.140e+02 | 2.662e+00 | 1.610e+05\n", 277 | "google/mt5-base | 1.848e-09 | 1.135e+02 | 3.445e+00 | 1.639e+05\n", 278 | "google/mt5-large | 1.892e-10 | 1.750e+02 | 4.472e+00 | 2.029e+05\n", 279 | "google/bigbird-pegasus-large-arxiv | 0.000e+00 | 2.424e+02 | 4.955e-01 | 3.183e+03\n", 280 | "google/pegasus-cnn_dailymail | 0.000e+00 | 2.416e+02 | 4.926e-01 | 4.423e+03\n", 281 | "google/pegasus-large | 0.000e+00 | 2.417e+02 | 4.912e-01 | 4.745e+03\n", 282 | "google/pegasus-multi_news | 0.000e+00 | 2.412e+02 | 4.925e-01 | 4.377e+03\n", 283 | "google/pegasus-xsum | 0.000e+00 | 2.418e+02 | 4.914e-01 | 4.402e+03\n", 284 | "bigscience/T0_3B | 5.755e-13 | 1.680e+02 | 3.114e+00 | 7.432e+04\n", 285 | "EleutherAI/gpt-neo-1.3B | 2.456e-10 | 5.125e+00 | 1.337e+00 | 1.055e+03\n" 286 | ] 287 | } 288 | ], 289 | "source": [ 290 | "# bf16-pretrained models\n", 291 | "mnames = [\"t5-small\", \"t5-base\", \"t5-large\", \"google/mt5-small\", \"google/mt5-base\", \n", 292 | " \"google/mt5-large\",\n", 293 | " \"google/bigbird-pegasus-large-arxiv\", \"google/pegasus-cnn_dailymail\", \n", 294 | " \"google/pegasus-large\", \"google/pegasus-multi_news\", 
\"google/pegasus-xsum\",\n", 295 | " \"bigscience/T0_3B\", \"EleutherAI/gpt-neo-1.3B\",\n", 296 | "]\n", 297 | "# \"bigscience/T0pp\", T0 are huge!\n", 298 | "models_analyze(mnames)" 299 | ] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "id": "31280426", 304 | "metadata": {}, 305 | "source": [ 306 | "We can see big abs max weight values - pretty consistently - so perhaps if the max weight > 1e2 it's a good candidate for bf16 group." 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "id": "33149dff", 312 | "metadata": {}, 313 | "source": [ 314 | "## fp32 models\n", 315 | "\n", 316 | "Let's look at fp32-pretrained models" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 9, 322 | "id": "517e0226", 323 | "metadata": {}, 324 | "outputs": [ 325 | { 326 | "name": "stdout", 327 | "output_type": "stream", 328 | "text": [ 329 | " name | abs min | abs max | norm min | norm max \n", 330 | "-----------------------------------------|-----------|-----------|-----------|-----------\n", 331 | "gsarti/it5-small | 6.114e-08 | 4.693e+02 | 8.411e-02 | 6.881e+04\n", 332 | "gsarti/it5-base | 1.068e-08 | 1.598e+03 | 3.596e-01 | 8.997e+04\n", 333 | "gsarti/it5-base-oscar | 3.638e-12 | 2.092e+01 | 3.637e+00 | 5.758e+03\n", 334 | "gsarti/it5-large | 2.094e-09 | 4.388e+04 | 7.982e-02 | 1.105e+06\n", 335 | "EleutherAI/gpt-neo-2.7B | 2.319e-11 | 3.563e+00 | 1.322e+00 | 9.850e+02\n" 336 | ] 337 | } 338 | ], 339 | "source": [ 340 | "# fp32-pretrained models\n", 341 | "mnames = [\"gsarti/it5-small\", \"gsarti/it5-base\", \"gsarti/it5-base-oscar\", \n", 342 | " \"gsarti/it5-large\", \"EleutherAI/gpt-neo-2.7B\", \n", 343 | " ]\n", 344 | "models_analyze(mnames)" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "id": "081c1126", 350 | "metadata": {}, 351 | "source": [ 352 | "The abs max is all over the map here.\n", 353 | "\n", 354 | "\"EleutherAI/gpt-neo-2.7B\"'s abs max is very low.\n", 355 | "\n", 356 | "XXX: need more inputs" 357 | ] 358 
| }, 359 | { 360 | "cell_type": "markdown", 361 | "id": "a447e780", 362 | "metadata": {}, 363 | "source": [ 364 | "## Unknown models\n", 365 | "\n", 366 | "Let's look at some uknown models" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 10, 372 | "id": "f6e8df28", 373 | "metadata": {}, 374 | "outputs": [], 375 | "source": [ 376 | "# fp32? (XXX: need to check)\n", 377 | "#mnames = [\"bigscience/T0_3B\"] \n", 378 | "# mnames = [\"bigscience/T0pp\", \"bigscience/T0_3B\"] \"bigscience/T0pp\" is huge!\n", 379 | "#models_analyze(mnames)" 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "id": "b7bb4a50", 385 | "metadata": {}, 386 | "source": [ 387 | "need to check how it was trained - looks like bf16 to me" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 11, 393 | "id": "16f37c4a", 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "# fp32? (XXX: need to check)\n", 398 | "#mnames = [\"google/pegasus-pubmed\"] \n", 399 | "#mnames = [] \n", 400 | "\n", 401 | "#mnames = [\"\"] \n", 402 | "#models_analyze(mnames)" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "id": "7af36a2a", 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [] 412 | } 413 | ], 414 | "metadata": { 415 | "hide_input": false, 416 | "kernelspec": { 417 | "display_name": "Python 3 (ipykernel)", 418 | "language": "python", 419 | "name": "python3" 420 | }, 421 | "language_info": { 422 | "codemirror_mode": { 423 | "name": "ipython", 424 | "version": 3 425 | }, 426 | "file_extension": ".py", 427 | "mimetype": "text/x-python", 428 | "name": "python", 429 | "nbconvert_exporter": "python", 430 | "pygments_lexer": "ipython3", 431 | "version": "3.8.15" 432 | }, 433 | "toc": { 434 | "base_numbering": 1, 435 | "nav_menu": {}, 436 | "number_sections": true, 437 | "sideBar": true, 438 | "skip_h1_title": false, 439 | "title_cell": "Table of Contents", 440 | "title_sidebar": "Contents", 
441 | "toc_cell": false, 442 | "toc_position": { 443 | "height": "calc(100% - 180px)", 444 | "left": "10px", 445 | "top": "150px", 446 | "width": "263.391px" 447 | }, 448 | "toc_section_display": true, 449 | "toc_window_display": true 450 | } 451 | }, 452 | "nbformat": 4, 453 | "nbformat_minor": 5 454 | } 455 | -------------------------------------------------------------------------------- /pytorch.md: -------------------------------------------------------------------------------- 1 | # Pytorch ecosystem 2 | 3 | https://pytorch.org/ecosystem/ 4 | 5 | ## Pytorch frameworks 6 | 7 | high-level training APIs built on top of pytorch 8 | 9 | a comparison article: https://neptune.ai/blog/model-training-libraries-pytorch-ecosystem 10 | 11 | https://github.com/yukkyo/Compare-PyTorch-Catalyst-Ignite-Lightning 12 | 13 | 14 | 15 | - [catalyst](https://github.com/catalyst-team/catalyst) Accelerated DL R&D 16 | 17 | https://catalyst-team.github.io/catalyst 18 | 19 | 20 | 21 | - [pytorch lightning](https://github.com/PyTorchLightning/pytorch-lightning) 22 | 23 | The lightweight PyTorch wrapper for ML researchers. Scale your models. 
Write less boilerplate 24 | 25 | 26 | 27 | - [pytorch ignite](https://github.com/pytorch/ignite) 28 | 29 | https://pytorch.org/ignite/ 30 | 31 | 32 | 33 | - [skorch](https://github.com/skorch-dev/skorch) A scikit-learn compatible neural network library that wraps pytorch 34 | 35 | follow the sklearn API 36 | 37 | 38 | - [fastai](https://github.com/fastai/fastai) 39 | 40 | 41 | 42 | - [flambé](https://github.com/asappresearch/flambe) 43 | 44 | An ML framework to accelerate research and its path to production 45 | https://flambe.ai 46 | 47 | 48 | - [pywick](https://github.com/achaiah/pywick) 49 | High-level batteries-included neural network training library for Pytorch 50 | 51 | 52 | - [poutyne](https://github.com/GRAAL-Research/poutyne) 53 | A Keras-like framework and utilities for PyTorch 54 | -------------------------------------------------------------------------------- /regularization.md: -------------------------------------------------------------------------------- 1 | # Regularization Notes 2 | 3 | 4 | 5 | # L1 and L2 6 | 7 | https://stats.stackexchange.com/a/159379/202309 8 | 9 | With a sparse model, we think of a model where many of the weights are 0. Let us therefore reason about how L1-regularization is more likely to create 0-weights. 10 | 11 | Consider a model consisting of the weights $(w_1, w_2, \dots, w_m)$. 12 | 13 | With L1 regularization, you penalize the model by a loss function $L_1(w)$ = $\Sigma_i |w_i|$. 14 | 15 | With L2-regularization, you penalize the model by a loss function $L_2(w)$ = $\frac{1}{2} \Sigma_i w_i^2$ 16 | 17 | If using gradient descent, you will iteratively make the weights change in the opposite direction of the gradient with a step size $\eta$ multiplied with the gradient. This means that a more steep gradient will make us take a larger step, while a more flat gradient will make us take a smaller step. 
Let us look at the gradients (subgradient in case of L1): 18 | 19 | $\frac{dL_1(w)}{dw} = sign(w)$, where $sign(w) = (\frac{w_1}{|w_1|}, \frac{w_2}{|w_2|}, \dots, \frac{w_m}{|w_m|})$ 20 | 21 | $\frac{dL_2(w)}{dw} = w$ 22 | 23 | If we plot the loss function and it's derivative for a model consisting of just a single parameter, it looks like this for L1: 24 | 25 | ![enter image description here](images/regularization-cmWO0.png) 26 | 27 | And like this for L2: 28 | 29 | [![enter image description here](images/regularization-Mkclz.png)][2] 30 | 31 | Notice that for $L_1$, the gradient is either 1 or -1, except for when $w_1 = 0$. That means that L1-regularization will move any weight towards 0 with the same step size, regardless the weight's value. In contrast, you can see that the $L_2$ gradient is linearly decreasing towards 0 as the weight goes towards 0. Therefore, L2-regularization will also move any weight towards 0, but it will take smaller and smaller steps as a weight approaches 0. 32 | 33 | Try to imagine that you start with a model with $w_1 = 5$ and using $\eta = \frac{1}{2}$. In the following picture, you can see how gradient descent using L1-regularization makes 10 of the updates $w_1 := w_1 - \eta \cdot \frac{dL_1(w)}{dw} = w_1 - \frac{1}{2} \cdot 1$, until reaching a model with $w_1 = 0$: 34 | 35 | ![enter image description here](images/regularization-XmtF2.png) 36 | 37 | In constrast, with L2-regularization where $\eta = \frac{1}{2}$, the gradient is $w_1$, causing every step to be only halfway towards 0. That is, we make the update $w_1 := w_1 - \eta \cdot \frac{dL_2(w)}{dw} = w_1 - \frac{1}{2} \cdot w_1$ 38 | Therefore, the model never reaches a weight of 0, regardless of how many steps we take: 39 | 40 | ![enter image description here](images/regularization-jlQYp.png) 41 | 42 | Note that L2-regularization **can** make a weight reach zero if the step size $\eta$ is so high that it reaches zero in a single step. 
Even if L2-regularization on its own over or undershoots 0, it can still reach a weight of 0 when used together with an objective function that tries to minimize the error of the model with respect to the weights. In that case, finding the best weights of the model is a trade-off between regularizing (having small weights) and minimizing loss (fitting the training data), and the result of that trade-off can be that the best value for some weights are 0. 43 | -------------------------------------------------------------------------------- /stats.md: -------------------------------------------------------------------------------- 1 | # Statistics 2 | 3 | 4 | ## Feature selection 5 | 6 | **Mutual information** (MI) of term t and class c measures how much information the presence/absence of a term contributes to making the correct classification decision on c. 7 | 8 | **X^2** is a measure of how much expected counts E and observed counts N deviate from each other. A high value of X^2 indicates that the hypothesis of independence, which implies that expected and observed counts are similar, is incorrect. 9 | 10 | **Frequency-based feature selection** - selecting the terms that are most common in the class. 11 | 12 | 13 | ## Evaluation Metrics 14 | 15 | The **precision** is the ratio tp / (tp + fp) where tp is the number of true positives and fp the number of false positives. The precision is intuitively the ability of the classifier not to label as positive a sample that is negative. 16 | 17 | The **recall** is the ratio tp / (tp + fn) where tp is the number of true positives and fn the number of false negatives. The recall is intuitively the ability of the classifier to find all the positive samples. 18 | 19 | The **F-beta** score can be interpreted as a weighted harmonic mean of the precision and recall, where an F-beta score reaches its best value at 1 and worst score at 0. The F-beta score weights recall more than precision by a factor of beta. 
beta == 1.0 means recall and precision are equally important. 20 | 21 | The **support** is the number of occurrences of each class in ground truth (correct) target values. 22 | 23 | [scikit-learn.org](https://scikit-learn.org/stable/modules/model_evaluation.html) 24 | 25 | A **macro-average** treats all classes equally - it computes the metric independently for each class and then takes the average. 26 | 27 | A **micro-average** aggregates the contributions of all classes to compute the average metric. If there is a class imbalance this might be preferable over macro-average. (usually for multi-class classification) 28 | 29 | 30 | 31 | 32 | ## T-test, p-value 33 | 34 | The **p-value** is the probability that the results from your sample data occurred by chance. Typical threshold p = 0.05 => 5% probability. Small p-value indicates that your hypotheses (feature) has a statistical significance. 35 | 36 | 37 | 38 | ## Probabilities 39 | 40 | The **prior probabilities** are also called **class priors**, which describe ”the general probability of encountering a particular class.” 41 | 42 | 43 | ## Distributions 44 | 45 | | Distribution | Categories | Number of trials | Example | 46 | | ------------ | ---------- | ---------------- | ---------------- | 47 | | Bernoulli | 2 | 1 | 1 coin toss | 48 | | Binomial | 2 | many | 2 heads 3 tosses | 49 | | Multinoulli | many | 1 | 1 dice roll | 50 | | Multinomial | many | many | 2 6s in 3 rolls | 51 | 52 | 53 | ### Naive Bayes estimators 54 | 55 | **Multi-variate Bernoulli Naive Bayes**. The binomial model is useful if your feature vectors are binary (i.e., 0s and 1s). One application would be text classification with a bag of words model where the 0s 1s are "word occurs in the document" and "word does not occur in the document" 56 | 57 | **Multinomial Naive Bayes**. The multinomial naive Bayes model is typically used for discrete counts. 
E.g., if we have a text classification problem, we can take the idea of bernoulli trials one step further and instead of "word occurs in the document" we have "count how often word occurs in the document", you can think of it as "number of times outcome number x_i is observed over the n trials" 58 | 59 | **Gaussian Naive Bayes**. we assume that the features follow a normal distribution. Instead of discrete counts, we have continuous features (e.g., the popular Iris dataset where the features are sepal width, petal width, sepal length, petal length). 60 | 61 | 62 | 63 | ## Classification 64 | 65 | 1. **Binary classification** or **binomial classification** is classifying instances into one of 2 classes 66 | 2. **Multi-class classification** is classifying instances into one of 3+ classes 67 | 3. **Multi-label classification** assigns one or more classes to a single instance. 68 | 69 | 70 | ## Books 71 | 72 | - [Daniel Jurafsky, James H. Martin - Speech and Language Processing - An Introduction to Natural Language Processing, Computational Linguistics, and Speech Recognition](https://amzn.to/2PG3mJz) 73 | - [Ian Goodfellow, Yoshua Bengio, Aaron Courville - Deep Learning - Adaptive Computation And Machine Learning](https://amzn.to/2sPHYIK) 74 | - [Christopher D. 
Manning, Prabhakar Raghavan, Hinrich Schutze - Introduction to information retrieval](https://amzn.to/2Mdv5iv) 75 | - [Yoav Goldberg - Neural Network Methods for Natural Language Processing (2017)](https://amzn.to/2sSKWw5) 76 | -------------------------------------------------------------------------------- /symbols.md: -------------------------------------------------------------------------------- 1 | # Symbols 2 | 3 | ≡ is defined as equivalent (or modulo arithmetic) 4 | := is defined to be equal to (y := 2z+5) 5 | ⊂ entails (turtle ⊂ reptile) 6 | ⊃ reverse entails (reptile ⊃ turtle) 7 | -------------------------------------------------------------------------------- /tools/html2md: -------------------------------------------------------------------------------- 1 | 2 | # HTML to MD 3 | 4 | perl -pi -e 's|||g' *html 5 | perl -pi -e 's|]*>|("#" x $1)." "|e; s|||' *html 6 | 7 | # mathjax/tex 8 | 9 | perl -0777 -pi -e 's||\$$1\$|msg' *html 10 | -------------------------------------------------------------------------------- /tools/images2local.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl 2 | 3 | # this script: 4 | # 1. mirrors images and adjusts markdown to use a local copy 5 | # 2.
converts reference style image markdown (e.g. StackOverflow (SO) answers) to inline style 6 | 7 | # usage: 8 | # tools/images2local.pl *.md 9 | 10 | use strict; 11 | use warnings; 12 | 13 | use LWP::UserAgent; 14 | use File::Copy; 15 | use Data::Dumper; 16 | use File::Basename; 17 | 18 | my $agent = "Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US; rv:58.0) Gecko/20100101 Firefox/58.0"; 19 | my $base = "images"; 20 | 21 | my @files = @ARGV; 22 | 23 | # XXX: make sure we are in the root dir 24 | for my $f (@files) { 25 | print "- Doc: $f\n"; 26 | replace($f) 27 | } 28 | 29 | sub replace { 30 | my $f = shift; 31 | 32 | my ($file, $dir, $ext) = fileparse($f, qr/\.[^.]*/); 33 | #print "$file\n"; 34 | 35 | my $prefix = "$base/$file"; 36 | 37 | my $fi = $f; 38 | my $fo = "$f.new"; 39 | 40 | open my $fh, '<', $fi or die "Can't open $fi: $!\n"; 41 | my @lines = <$fh>; 42 | close $fh; 43 | 44 | my %refs = (); 45 | # mirror images, save their mapping 46 | foreach my $l (@lines) { 47 | # mirror images # XXX: more image formats? 48 | while ($l =~ m/(https?:.*?(?:png|jpg|jpeg|gif))/ig) { 49 | my $src = $1; 50 | my $dest = mirror($src, $prefix); 51 | $l =~ s/\Q$src\E/$dest/; 52 | } 53 | 54 | # image reference style parsing 55 | while ($l =~ m/^(\s+\[([^]]+)\]:\s+(.*))$/g){ 56 | my $pat = $1; 57 | my $id = $2; 58 | my $url = $3; 59 | $refs{$id} = $url; 60 | $l =~ s/\Q$pat\E//; # remove 61 | } 62 | } 63 | #print Dumper \%refs; 64 | 65 | # replace reference style with inline style of images: 66 | # SO examples of the used reference style 67 | # ![enter image description here][1] 68 | # [![enter image description here][2]][2] 69 | # and then at the end of the file: 70 | # [1]: http://i.stack.imgur.com/cmWO0.png 71 | # [2]: https://i.stack.imgur.com/Mkclz.png 72 | foreach (@lines) { 73 | while (/(\!\[([^]]*)\]\[([^]]+)\])/g) { # id may be multi-char, e.g. [10] - must match the [^]]+ used when collecting %refs above 74 | my $pat = $1; 75 | my $anchor = $2; # XXX: fixup description if none?
 76 | my $id = $3; 77 | if (exists $refs{$id}) { 78 | s/\Q$pat\E/![$anchor]($refs{$id})/g; 79 | } 80 | } 81 | } 82 | 83 | # done 84 | open my $fho, '>', $fo or die "Can't open $fo: $!\n"; 85 | print $fho @lines; 86 | close $fho; 87 | 88 | # rename back 89 | File::Copy::move($fi, "$fi.bak"); # backup 90 | File::Copy::move($fo, $fi); 91 | } 92 | 93 | sub mirror { 94 | my ($url, $prefix) = @_; 95 | 96 | my $dest; 97 | if ($url =~ m|([^/]+)$|) { 98 | $dest = unique_dest("$prefix-$1"); # ensure a unique path 99 | } 100 | else { 101 | return ''; 102 | } 103 | print "Mirror [$url] => [$dest]\n"; 104 | my $lwp = LWP::UserAgent->new(agent => $agent, cookie_jar=>{}); 105 | #return ''; 106 | 107 | my $res = $lwp->mirror($url, $dest); 108 | unless ($res->is_success) { 109 | if ($res->code == 304) { 110 | print "File [$dest] hasn't changed - skipping\n"; 111 | return $dest; 112 | } 113 | else { 114 | print "Failed to get $url: ", $res->status_line; 115 | return ''; 116 | } 117 | } 118 | 119 | # don't get blocked 120 | sleep int rand 7; 121 | 122 | return $dest; 123 | } 124 | 125 | # find a variation on the proposed destination that's not taken already 126 | sub unique_dest { 127 | my $dest = shift; 128 | 129 | my $try = $dest; 130 | my $c = 0; 131 | while (-e $try) { 132 | $c++; 133 | $try =~ s/(-\d+)?\.([^\.]+)$/-$c.$2/; 134 | } 135 | 136 | return $try; 137 | } 138 | --------------------------------------------------------------------------------