├── .gitignore ├── .travis.yml ├── 01-Introduction.ipynb ├── 02-ProbabilisticRepresentations.ipynb ├── 03-Inference.ipynb ├── 04-ParameterLearning.ipynb ├── 05-StructureLearning.ipynb ├── 06-DecisionNetworks.ipynb ├── 07-Games.ipynb ├── 08-MarkovDecisionProcesses.ipynb ├── 09-ApproximateDynamicProgramming.ipynb ├── 10-ExplorationExploitation.ipynb ├── 11-ModelBasedReinforcementLearning.ipynb ├── 12-ModelFreeReinforcementLearning.ipynb ├── 13-StateUncertainty.ipynb ├── 14-ExactPOMDPMethods.ipynb ├── 15-OfflinePOMDPMethods.ipynb ├── 16-OnlinePOMDPMethods.ipynb ├── POMDPs-jl-demo.ipynb ├── Project.toml ├── README.md ├── alpha_plots.jl ├── baby.jl ├── bandits.jl ├── gridworld.jl ├── helpers.jl ├── install.jl ├── rl.jl └── runtests.jl /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | *.DS_Store 3 | tmp* 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: julia 2 | dist: trusty 3 | julia: 4 | - 1.2 5 | notifications: 6 | email: false 7 | before_install: 8 | - sudo apt-get install texlive-latex-extra 9 | script: 10 | - git clone https://github.com/JuliaRegistries/General $(julia -e 'import Pkg; println(joinpath(Pkg.depots1(), "registries", "General"))') 11 | - git clone https://github.com/JuliaPOMDP/Registry $(julia -e 'import Pkg; println(joinpath(Pkg.depots1(), "registries", "JuliaPOMDP"))') 12 | - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi 13 | - julia -e 'import Pkg; ENV["PYTHON"]=""; Pkg.add("PyCall"); Pkg.build("PyCall")' 14 | - julia -e 'import Pkg; Pkg.add("Conda"); using Conda; Conda.add("matplotlib")' 15 | - julia --check-bounds=yes -e 'include("install.jl"); include("runtests.jl")' 16 | -------------------------------------------------------------------------------- /01-Introduction.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Quick introduction to Julia" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "These examples are based on http://learnxinyminutes.com/docs/julia/. Assumes Julia 1.2" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## Types" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "There are different types of numbers." 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "text/plain": [ 39 | "String" 40 | ] 41 | }, 42 | "execution_count": 1, 43 | "metadata": {}, 44 | "output_type": "execute_result" 45 | } 46 | ], 47 | "source": [ 48 | "typeof(\"mykel\")" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "data": { 58 | "text/plain": [ 59 | "Float64" 60 | ] 61 | }, 62 | "execution_count": 2, 63 | "metadata": {}, 64 | "output_type": "execute_result" 65 | } 66 | ], 67 | "source": [ 68 | "typeof(1.0)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "Complex{Int64}" 80 | ] 81 | }, 82 | "execution_count": 3, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "typeof(1 + 1im)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 4, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "data": { 98 | "text/plain": [ 99 | "AbstractFloat" 100 | ] 101 | }, 102 | "execution_count": 4, 103 | "metadata": {}, 104 | "output_type": "execute_result" 105 | } 106 | ], 107 | "source": [ 108 | "supertype(Float64)" 109 | ] 110 | }, 111 | { 112 | 
"cell_type": "code", 113 | "execution_count": 5, 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "data": { 118 | "text/plain": [ 119 | "Real" 120 | ] 121 | }, 122 | "execution_count": 5, 123 | "metadata": {}, 124 | "output_type": "execute_result" 125 | } 126 | ], 127 | "source": [ 128 | "supertype(AbstractFloat)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 6, 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "data": { 138 | "text/plain": [ 139 | "Number" 140 | ] 141 | }, 142 | "execution_count": 6, 143 | "metadata": {}, 144 | "output_type": "execute_result" 145 | } 146 | ], 147 | "source": [ 148 | "supertype(Real)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 7, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/plain": [ 159 | "Any" 160 | ] 161 | }, 162 | "execution_count": 7, 163 | "metadata": {}, 164 | "output_type": "execute_result" 165 | } 166 | ], 167 | "source": [ 168 | "supertype(Number)" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 8, 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "data": { 178 | "text/plain": [ 179 | "Signed" 180 | ] 181 | }, 182 | "execution_count": 8, 183 | "metadata": {}, 184 | "output_type": "execute_result" 185 | } 186 | ], 187 | "source": [ 188 | "supertype(Int64)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 9, 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "data": { 198 | "text/plain": [ 199 | "Integer" 200 | ] 201 | }, 202 | "execution_count": 9, 203 | "metadata": {}, 204 | "output_type": "execute_result" 205 | } 206 | ], 207 | "source": [ 208 | "supertype(Signed)" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 10, 214 | "metadata": {}, 215 | "outputs": [ 216 | { 217 | "data": { 218 | "text/plain": [ 219 | "Real" 220 | ] 221 | }, 222 | "execution_count": 10, 223 | "metadata": {}, 224 | "output_type": "execute_result" 225 
| } 226 | ], 227 | "source": [ 228 | "supertype(Integer)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "Boolean types" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 11, 241 | "metadata": {}, 242 | "outputs": [ 243 | { 244 | "data": { 245 | "text/plain": [ 246 | "Bool" 247 | ] 248 | }, 249 | "execution_count": 11, 250 | "metadata": {}, 251 | "output_type": "execute_result" 252 | } 253 | ], 254 | "source": [ 255 | "typeof(true)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "metadata": {}, 261 | "source": [ 262 | "## Boolean Operators" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": {}, 268 | "source": [ 269 | "Negation is done with !" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 12, 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "data": { 279 | "text/plain": [ 280 | "false" 281 | ] 282 | }, 283 | "execution_count": 12, 284 | "metadata": {}, 285 | "output_type": "execute_result" 286 | } 287 | ], 288 | "source": [ 289 | "!true" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 13, 295 | "metadata": {}, 296 | "outputs": [ 297 | { 298 | "data": { 299 | "text/plain": [ 300 | "true" 301 | ] 302 | }, 303 | "execution_count": 13, 304 | "metadata": {}, 305 | "output_type": "execute_result" 306 | } 307 | ], 308 | "source": [ 309 | "1 == 1" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 14, 315 | "metadata": {}, 316 | "outputs": [ 317 | { 318 | "data": { 319 | "text/plain": [ 320 | "false" 321 | ] 322 | }, 323 | "execution_count": 14, 324 | "metadata": {}, 325 | "output_type": "execute_result" 326 | } 327 | ], 328 | "source": [ 329 | "1 != 1" 330 | ] 331 | }, 332 | { 333 | "cell_type": "markdown", 334 | "metadata": {}, 335 | "source": [ 336 | "Comparisons can be chained" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 15, 342 
| "metadata": {}, 343 | "outputs": [ 344 | { 345 | "data": { 346 | "text/plain": [ 347 | "true" 348 | ] 349 | }, 350 | "execution_count": 15, 351 | "metadata": {}, 352 | "output_type": "execute_result" 353 | } 354 | ], 355 | "source": [ 356 | "1 < 2 < 3" 357 | ] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "metadata": {}, 362 | "source": [ 363 | "## Strings" 364 | ] 365 | }, 366 | { 367 | "cell_type": "markdown", 368 | "metadata": {}, 369 | "source": [ 370 | "Use double quotes for strings." 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 16, 376 | "metadata": {}, 377 | "outputs": [ 378 | { 379 | "data": { 380 | "text/plain": [ 381 | "\"This is a string\"" 382 | ] 383 | }, 384 | "execution_count": 16, 385 | "metadata": {}, 386 | "output_type": "execute_result" 387 | } 388 | ], 389 | "source": [ 390 | "\"This is a string\"" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": 17, 396 | "metadata": {}, 397 | "outputs": [ 398 | { 399 | "data": { 400 | "text/plain": [ 401 | "String" 402 | ] 403 | }, 404 | "execution_count": 17, 405 | "metadata": {}, 406 | "output_type": "execute_result" 407 | } 408 | ], 409 | "source": [ 410 | "typeof(\"This is a string\")" 411 | ] 412 | }, 413 | { 414 | "cell_type": "markdown", 415 | "metadata": {}, 416 | "source": [ 417 | "Use single quotes for characters." 
418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 18, 423 | "metadata": {}, 424 | "outputs": [ 425 | { 426 | "data": { 427 | "text/plain": [ 428 | "'a': ASCII/Unicode U+0061 (category Ll: Letter, lowercase)" 429 | ] 430 | }, 431 | "execution_count": 18, 432 | "metadata": {}, 433 | "output_type": "execute_result" 434 | } 435 | ], 436 | "source": [ 437 | "'a'" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 19, 443 | "metadata": {}, 444 | "outputs": [ 445 | { 446 | "data": { 447 | "text/plain": [ 448 | "Char" 449 | ] 450 | }, 451 | "execution_count": 19, 452 | "metadata": {}, 453 | "output_type": "execute_result" 454 | } 455 | ], 456 | "source": [ 457 | "typeof('a')" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 20, 463 | "metadata": {}, 464 | "outputs": [ 465 | { 466 | "data": { 467 | "text/plain": [ 468 | "'T': ASCII/Unicode U+0054 (category Lu: Letter, uppercase)" 469 | ] 470 | }, 471 | "execution_count": 20, 472 | "metadata": {}, 473 | "output_type": "execute_result" 474 | } 475 | ], 476 | "source": [ 477 | "\"This is a string\"[1] # note the 1-based indexing---similar to Matlab but unlike C/C++/Java" 478 | ] 479 | }, 480 | { 481 | "cell_type": "markdown", 482 | "metadata": {}, 483 | "source": [ 484 | "$ can be used for \"string interpolation\"" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 21, 490 | "metadata": {}, 491 | "outputs": [ 492 | { 493 | "data": { 494 | "text/plain": [ 495 | "\"2 + 2 = 4\"" 496 | ] 497 | }, 498 | "execution_count": 21, 499 | "metadata": {}, 500 | "output_type": "execute_result" 501 | } 502 | ], 503 | "source": [ 504 | "\"2 + 2 = $(2+2)\"" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": 22, 510 | "metadata": {}, 511 | "outputs": [ 512 | { 513 | "name": "stdout", 514 | "output_type": "stream", 515 | "text": [ 516 | "5 is less than 5.300000" 517 | ] 518 | } 519 | ], 520 | "source": [ 521 | 
"using Printf\n", 522 | "Printf.@printf \"%d is less than %f\" 4.5 5.3" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": 23, 528 | "metadata": {}, 529 | "outputs": [ 530 | { 531 | "name": "stdout", 532 | "output_type": "stream", 533 | "text": [ 534 | "Welcome to Julia\n" 535 | ] 536 | } 537 | ], 538 | "source": [ 539 | "println(\"Welcome to Julia\")" 540 | ] 541 | }, 542 | { 543 | "cell_type": "markdown", 544 | "metadata": {}, 545 | "source": [ 546 | "## Variables" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": 24, 552 | "metadata": {}, 553 | "outputs": [ 554 | { 555 | "data": { 556 | "text/plain": [ 557 | "5" 558 | ] 559 | }, 560 | "execution_count": 24, 561 | "metadata": {}, 562 | "output_type": "execute_result" 563 | } 564 | ], 565 | "source": [ 566 | "x = 5" 567 | ] 568 | }, 569 | { 570 | "cell_type": "markdown", 571 | "metadata": {}, 572 | "source": [ 573 | "Variable names start with a letter or an underscore; after that you can use letters, digits, underscores, and exclamation points." 574 | ] 575 | }, 576 | { 577 | "cell_type": "code", 578 | "execution_count": 25, 579 | "metadata": {}, 580 | "outputs": [ 581 | { 582 | "data": { 583 | "text/plain": [ 584 | "1" 585 | ] 586 | }, 587 | "execution_count": 25, 588 | "metadata": {}, 589 | "output_type": "execute_result" 590 | } 591 | ], 592 | "source": [ 593 | "xMarksTheSpot2Dig!
= 1" 594 | ] 595 | }, 596 | { 597 | "cell_type": "markdown", 598 | "metadata": {}, 599 | "source": [ 600 | "## Arrays" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 26, 606 | "metadata": {}, 607 | "outputs": [ 608 | { 609 | "data": { 610 | "text/plain": [ 611 | "0-element Array{Int64,1}" 612 | ] 613 | }, 614 | "execution_count": 26, 615 | "metadata": {}, 616 | "output_type": "execute_result" 617 | } 618 | ], 619 | "source": [ 620 | "a = Int64[]" 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "execution_count": 27, 626 | "metadata": {}, 627 | "outputs": [ 628 | { 629 | "data": { 630 | "text/plain": [ 631 | "3-element Array{Int64,1}:\n", 632 | " 4\n", 633 | " 5\n", 634 | " 6" 635 | ] 636 | }, 637 | "execution_count": 27, 638 | "metadata": {}, 639 | "output_type": "execute_result" 640 | } 641 | ], 642 | "source": [ 643 | "b = [4, 5, 6]" 644 | ] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "execution_count": 28, 649 | "metadata": {}, 650 | "outputs": [ 651 | { 652 | "data": { 653 | "text/plain": [ 654 | "4" 655 | ] 656 | }, 657 | "execution_count": 28, 658 | "metadata": {}, 659 | "output_type": "execute_result" 660 | } 661 | ], 662 | "source": [ 663 | "b[1]" 664 | ] 665 | }, 666 | { 667 | "cell_type": "code", 668 | "execution_count": 29, 669 | "metadata": {}, 670 | "outputs": [ 671 | { 672 | "data": { 673 | "text/plain": [ 674 | "5" 675 | ] 676 | }, 677 | "execution_count": 29, 678 | "metadata": {}, 679 | "output_type": "execute_result" 680 | } 681 | ], 682 | "source": [ 683 | "b[end-1]" 684 | ] 685 | }, 686 | { 687 | "cell_type": "code", 688 | "execution_count": 30, 689 | "metadata": {}, 690 | "outputs": [ 691 | { 692 | "data": { 693 | "text/plain": [ 694 | "2×2 Array{Int64,2}:\n", 695 | " 1 2\n", 696 | " 3 4" 697 | ] 698 | }, 699 | "execution_count": 30, 700 | "metadata": {}, 701 | "output_type": "execute_result" 702 | } 703 | ], 704 | "source": [ 705 | "matrix = [1 2; 3 4]" 706 | ] 707 | }, 708 | { 709 | "cell_type": 
"code", 710 | "execution_count": 31, 711 | "metadata": {}, 712 | "outputs": [ 713 | { 714 | "data": { 715 | "text/plain": [ 716 | "0-element Array{Int64,1}" 717 | ] 718 | }, 719 | "execution_count": 31, 720 | "metadata": {}, 721 | "output_type": "execute_result" 722 | } 723 | ], 724 | "source": [ 725 | "a" 726 | ] 727 | }, 728 | { 729 | "cell_type": "code", 730 | "execution_count": 32, 731 | "metadata": {}, 732 | "outputs": [ 733 | { 734 | "data": { 735 | "text/plain": [ 736 | "1-element Array{Int64,1}:\n", 737 | " 1" 738 | ] 739 | }, 740 | "execution_count": 32, 741 | "metadata": {}, 742 | "output_type": "execute_result" 743 | } 744 | ], 745 | "source": [ 746 | "push!(a, 1)" 747 | ] 748 | }, 749 | { 750 | "cell_type": "code", 751 | "execution_count": 33, 752 | "metadata": {}, 753 | "outputs": [ 754 | { 755 | "data": { 756 | "text/plain": [ 757 | "2-element Array{Int64,1}:\n", 758 | " 1\n", 759 | " 2" 760 | ] 761 | }, 762 | "execution_count": 33, 763 | "metadata": {}, 764 | "output_type": "execute_result" 765 | } 766 | ], 767 | "source": [ 768 | "push!(a, 2)" 769 | ] 770 | }, 771 | { 772 | "cell_type": "code", 773 | "execution_count": 34, 774 | "metadata": {}, 775 | "outputs": [ 776 | { 777 | "data": { 778 | "text/plain": [ 779 | "5-element Array{Int64,1}:\n", 780 | " 1\n", 781 | " 2\n", 782 | " 4\n", 783 | " 5\n", 784 | " 6" 785 | ] 786 | }, 787 | "execution_count": 34, 788 | "metadata": {}, 789 | "output_type": "execute_result" 790 | } 791 | ], 792 | "source": [ 793 | "append!(a, b)" 794 | ] 795 | }, 796 | { 797 | "cell_type": "code", 798 | "execution_count": 35, 799 | "metadata": {}, 800 | "outputs": [ 801 | { 802 | "data": { 803 | "text/plain": [ 804 | "5-element Array{Int64,1}:\n", 805 | " 1\n", 806 | " 2\n", 807 | " 4\n", 808 | " 5\n", 809 | " 6" 810 | ] 811 | }, 812 | "execution_count": 35, 813 | "metadata": {}, 814 | "output_type": "execute_result" 815 | } 816 | ], 817 | "source": [ 818 | "a" 819 | ] 820 | }, 821 | { 822 | "cell_type": "code", 823 | 
"execution_count": 36, 824 | "metadata": {}, 825 | "outputs": [ 826 | { 827 | "data": { 828 | "text/plain": [ 829 | "6" 830 | ] 831 | }, 832 | "execution_count": 36, 833 | "metadata": {}, 834 | "output_type": "execute_result" 835 | } 836 | ], 837 | "source": [ 838 | "pop!(a)" 839 | ] 840 | }, 841 | { 842 | "cell_type": "code", 843 | "execution_count": 37, 844 | "metadata": {}, 845 | "outputs": [ 846 | { 847 | "data": { 848 | "text/plain": [ 849 | "4-element Array{Int64,1}:\n", 850 | " 1\n", 851 | " 2\n", 852 | " 4\n", 853 | " 5" 854 | ] 855 | }, 856 | "execution_count": 37, 857 | "metadata": {}, 858 | "output_type": "execute_result" 859 | } 860 | ], 861 | "source": [ 862 | "a" 863 | ] 864 | }, 865 | { 866 | "cell_type": "code", 867 | "execution_count": 38, 868 | "metadata": {}, 869 | "outputs": [ 870 | { 871 | "data": { 872 | "text/plain": [ 873 | "3-element Array{Int64,1}:\n", 874 | " 2\n", 875 | " 4\n", 876 | " 5" 877 | ] 878 | }, 879 | "execution_count": 38, 880 | "metadata": {}, 881 | "output_type": "execute_result" 882 | } 883 | ], 884 | "source": [ 885 | "a[2:4]" 886 | ] 887 | }, 888 | { 889 | "cell_type": "code", 890 | "execution_count": 39, 891 | "metadata": {}, 892 | "outputs": [ 893 | { 894 | "data": { 895 | "text/plain": [ 896 | "3-element Array{Int64,1}:\n", 897 | " 2\n", 898 | " 4\n", 899 | " 5" 900 | ] 901 | }, 902 | "execution_count": 39, 903 | "metadata": {}, 904 | "output_type": "execute_result" 905 | } 906 | ], 907 | "source": [ 908 | "a[2:end]" 909 | ] 910 | }, 911 | { 912 | "cell_type": "code", 913 | "execution_count": 40, 914 | "metadata": {}, 915 | "outputs": [ 916 | { 917 | "data": { 918 | "text/plain": [ 919 | "5-element Array{Int64,1}:\n", 920 | " 1\n", 921 | " 2\n", 922 | " 4\n", 923 | " 5\n", 924 | " 1" 925 | ] 926 | }, 927 | "execution_count": 40, 928 | "metadata": {}, 929 | "output_type": "execute_result" 930 | } 931 | ], 932 | "source": [ 933 | "push!(a, round(Int64, 1.3))" 934 | ] 935 | }, 936 | { 937 | "cell_type": "code", 938 | 
"execution_count": 41, 939 | "metadata": {}, 940 | "outputs": [ 941 | { 942 | "data": { 943 | "text/plain": [ 944 | "true" 945 | ] 946 | }, 947 | "execution_count": 41, 948 | "metadata": {}, 949 | "output_type": "execute_result" 950 | } 951 | ], 952 | "source": [ 953 | "in(4, a)" 954 | ] 955 | }, 956 | { 957 | "cell_type": "code", 958 | "execution_count": 42, 959 | "metadata": {}, 960 | "outputs": [ 961 | { 962 | "data": { 963 | "text/plain": [ 964 | "true" 965 | ] 966 | }, 967 | "execution_count": 42, 968 | "metadata": {}, 969 | "output_type": "execute_result" 970 | } 971 | ], 972 | "source": [ 973 | "4 in a" 974 | ] 975 | }, 976 | { 977 | "cell_type": "code", 978 | "execution_count": 43, 979 | "metadata": {}, 980 | "outputs": [ 981 | { 982 | "data": { 983 | "text/plain": [ 984 | "5" 985 | ] 986 | }, 987 | "execution_count": 43, 988 | "metadata": {}, 989 | "output_type": "execute_result" 990 | } 991 | ], 992 | "source": [ 993 | "length(a)" 994 | ] 995 | }, 996 | { 997 | "cell_type": "markdown", 998 | "metadata": {}, 999 | "source": [ 1000 | "## Tuples" 1001 | ] 1002 | }, 1003 | { 1004 | "cell_type": "code", 1005 | "execution_count": 44, 1006 | "metadata": {}, 1007 | "outputs": [ 1008 | { 1009 | "data": { 1010 | "text/plain": [ 1011 | "(1, 5, 3)" 1012 | ] 1013 | }, 1014 | "execution_count": 44, 1015 | "metadata": {}, 1016 | "output_type": "execute_result" 1017 | } 1018 | ], 1019 | "source": [ 1020 | "a = (1, 5, 3)" 1021 | ] 1022 | }, 1023 | { 1024 | "cell_type": "code", 1025 | "execution_count": 45, 1026 | "metadata": {}, 1027 | "outputs": [ 1028 | { 1029 | "data": { 1030 | "text/plain": [ 1031 | "Tuple{Int64,Int64,Int64}" 1032 | ] 1033 | }, 1034 | "execution_count": 45, 1035 | "metadata": {}, 1036 | "output_type": "execute_result" 1037 | } 1038 | ], 1039 | "source": [ 1040 | "typeof(a)" 1041 | ] 1042 | }, 1043 | { 1044 | "cell_type": "code", 1045 | "execution_count": 46, 1046 | "metadata": {}, 1047 | "outputs": [ 1048 | { 1049 | "data": { 1050 | "text/plain": [ 
1051 | "5" 1052 | ] 1053 | }, 1054 | "execution_count": 46, 1055 | "metadata": {}, 1056 | "output_type": "execute_result" 1057 | } 1058 | ], 1059 | "source": [ 1060 | "a[2]" 1061 | ] 1062 | }, 1063 | { 1064 | "cell_type": "code", 1065 | "execution_count": 47, 1066 | "metadata": {}, 1067 | "outputs": [], 1068 | "source": [ 1069 | "#a[2] = 3 # can't change elements in a tuple" 1070 | ] 1071 | }, 1072 | { 1073 | "cell_type": "code", 1074 | "execution_count": 48, 1075 | "metadata": {}, 1076 | "outputs": [ 1077 | { 1078 | "data": { 1079 | "text/plain": [ 1080 | "(1, 2, 3)" 1081 | ] 1082 | }, 1083 | "execution_count": 48, 1084 | "metadata": {}, 1085 | "output_type": "execute_result" 1086 | } 1087 | ], 1088 | "source": [ 1089 | "a, b, c = (1, 2, 3)" 1090 | ] 1091 | }, 1092 | { 1093 | "cell_type": "code", 1094 | "execution_count": 49, 1095 | "metadata": {}, 1096 | "outputs": [ 1097 | { 1098 | "data": { 1099 | "text/plain": [ 1100 | "1" 1101 | ] 1102 | }, 1103 | "execution_count": 49, 1104 | "metadata": {}, 1105 | "output_type": "execute_result" 1106 | } 1107 | ], 1108 | "source": [ 1109 | "a" 1110 | ] 1111 | }, 1112 | { 1113 | "cell_type": "code", 1114 | "execution_count": 50, 1115 | "metadata": {}, 1116 | "outputs": [ 1117 | { 1118 | "data": { 1119 | "text/plain": [ 1120 | "2" 1121 | ] 1122 | }, 1123 | "execution_count": 50, 1124 | "metadata": {}, 1125 | "output_type": "execute_result" 1126 | } 1127 | ], 1128 | "source": [ 1129 | "b" 1130 | ] 1131 | }, 1132 | { 1133 | "cell_type": "code", 1134 | "execution_count": 51, 1135 | "metadata": {}, 1136 | "outputs": [ 1137 | { 1138 | "data": { 1139 | "text/plain": [ 1140 | "3" 1141 | ] 1142 | }, 1143 | "execution_count": 51, 1144 | "metadata": {}, 1145 | "output_type": "execute_result" 1146 | } 1147 | ], 1148 | "source": [ 1149 | "c" 1150 | ] 1151 | }, 1152 | { 1153 | "cell_type": "code", 1154 | "execution_count": 52, 1155 | "metadata": {}, 1156 | "outputs": [ 1157 | { 1158 | "data": { 1159 | "text/plain": [ 1160 | "(1, 2, 3)" 
1161 | ] 1162 | }, 1163 | "execution_count": 52, 1164 | "metadata": {}, 1165 | "output_type": "execute_result" 1166 | } 1167 | ], 1168 | "source": [ 1169 | "a, b, c = 1, 2, 3 # you can also leave off parentheses" 1170 | ] 1171 | }, 1172 | { 1173 | "cell_type": "code", 1174 | "execution_count": 53, 1175 | "metadata": {}, 1176 | "outputs": [ 1177 | { 1178 | "data": { 1179 | "text/plain": [ 1180 | "1" 1181 | ] 1182 | }, 1183 | "execution_count": 53, 1184 | "metadata": {}, 1185 | "output_type": "execute_result" 1186 | } 1187 | ], 1188 | "source": [ 1189 | "a" 1190 | ] 1191 | }, 1192 | { 1193 | "cell_type": "code", 1194 | "execution_count": 54, 1195 | "metadata": {}, 1196 | "outputs": [ 1197 | { 1198 | "data": { 1199 | "text/plain": [ 1200 | "2" 1201 | ] 1202 | }, 1203 | "execution_count": 54, 1204 | "metadata": {}, 1205 | "output_type": "execute_result" 1206 | } 1207 | ], 1208 | "source": [ 1209 | "b" 1210 | ] 1211 | }, 1212 | { 1213 | "cell_type": "code", 1214 | "execution_count": 55, 1215 | "metadata": {}, 1216 | "outputs": [ 1217 | { 1218 | "data": { 1219 | "text/plain": [ 1220 | "3" 1221 | ] 1222 | }, 1223 | "execution_count": 55, 1224 | "metadata": {}, 1225 | "output_type": "execute_result" 1226 | } 1227 | ], 1228 | "source": [ 1229 | "c" 1230 | ] 1231 | }, 1232 | { 1233 | "cell_type": "code", 1234 | "execution_count": 56, 1235 | "metadata": {}, 1236 | "outputs": [ 1237 | { 1238 | "data": { 1239 | "text/plain": [ 1240 | "(1,)" 1241 | ] 1242 | }, 1243 | "execution_count": 56, 1244 | "metadata": {}, 1245 | "output_type": "execute_result" 1246 | } 1247 | ], 1248 | "source": [ 1249 | "(1,) # to create a single element tuple, you must add the \",\" at the end" 1250 | ] 1251 | }, 1252 | { 1253 | "cell_type": "code", 1254 | "execution_count": 57, 1255 | "metadata": {}, 1256 | "outputs": [ 1257 | { 1258 | "data": { 1259 | "text/plain": [ 1260 | "Tuple{Int64}" 1261 | ] 1262 | }, 1263 | "execution_count": 57, 1264 | "metadata": {}, 1265 | "output_type": "execute_result" 
1266 | } 1267 | ], 1268 | "source": [ 1269 | "typeof((1,))" 1270 | ] 1271 | }, 1272 | { 1273 | "cell_type": "code", 1274 | "execution_count": 58, 1275 | "metadata": {}, 1276 | "outputs": [ 1277 | { 1278 | "data": { 1279 | "text/plain": [ 1280 | "(x = 1, y = 2)" 1281 | ] 1282 | }, 1283 | "execution_count": 58, 1284 | "metadata": {}, 1285 | "output_type": "execute_result" 1286 | } 1287 | ], 1288 | "source": [ 1289 | "n = (x=1, y=2) # use keyword assignments in a tuple to create a NamedTuple" 1290 | ] 1291 | }, 1292 | { 1293 | "cell_type": "code", 1294 | "execution_count": 59, 1295 | "metadata": {}, 1296 | "outputs": [ 1297 | { 1298 | "data": { 1299 | "text/plain": [ 1300 | "NamedTuple{(:x, :y),Tuple{Int64,Int64}}" 1301 | ] 1302 | }, 1303 | "execution_count": 59, 1304 | "metadata": {}, 1305 | "output_type": "execute_result" 1306 | } 1307 | ], 1308 | "source": [ 1309 | "typeof(n)" 1310 | ] 1311 | }, 1312 | { 1313 | "cell_type": "code", 1314 | "execution_count": 60, 1315 | "metadata": {}, 1316 | "outputs": [ 1317 | { 1318 | "data": { 1319 | "text/plain": [ 1320 | "1" 1321 | ] 1322 | }, 1323 | "execution_count": 60, 1324 | "metadata": {}, 1325 | "output_type": "execute_result" 1326 | } 1327 | ], 1328 | "source": [ 1329 | "n.x # NamedTuple fields can be accessed using dot syntax" 1330 | ] 1331 | }, 1332 | { 1333 | "cell_type": "markdown", 1334 | "metadata": {}, 1335 | "source": [ 1336 | "## Dictionaries" 1337 | ] 1338 | }, 1339 | { 1340 | "cell_type": "code", 1341 | "execution_count": 61, 1342 | "metadata": {}, 1343 | "outputs": [ 1344 | { 1345 | "data": { 1346 | "text/plain": [ 1347 | "Dict{Any,Any} with 0 entries" 1348 | ] 1349 | }, 1350 | "execution_count": 61, 1351 | "metadata": {}, 1352 | "output_type": "execute_result" 1353 | } 1354 | ], 1355 | "source": [ 1356 | "d = Dict()" 1357 | ] 1358 | }, 1359 | { 1360 | "cell_type": "code", 1361 | "execution_count": 62, 1362 | "metadata": {}, 1363 | "outputs": [ 1364 | { 1365 | "data": { 1366 | "text/plain": [ 1367 | 
"Dict{String,Int64} with 3 entries:\n", 1368 | " \"two\" => 2\n", 1369 | " \"one\" => 1\n", 1370 | " \"three\" => 3" 1371 | ] 1372 | }, 1373 | "execution_count": 62, 1374 | "metadata": {}, 1375 | "output_type": "execute_result" 1376 | } 1377 | ], 1378 | "source": [ 1379 | "d = Dict(\"one\"=>1, \"two\"=>2, \"three\"=>3)" 1380 | ] 1381 | }, 1382 | { 1383 | "cell_type": "code", 1384 | "execution_count": 63, 1385 | "metadata": {}, 1386 | "outputs": [ 1387 | { 1388 | "data": { 1389 | "text/plain": [ 1390 | "1" 1391 | ] 1392 | }, 1393 | "execution_count": 63, 1394 | "metadata": {}, 1395 | "output_type": "execute_result" 1396 | } 1397 | ], 1398 | "source": [ 1399 | "d[\"one\"]" 1400 | ] 1401 | }, 1402 | { 1403 | "cell_type": "code", 1404 | "execution_count": 64, 1405 | "metadata": {}, 1406 | "outputs": [ 1407 | { 1408 | "data": { 1409 | "text/plain": [ 1410 | "Base.KeySet for a Dict{String,Int64} with 3 entries. Keys:\n", 1411 | " \"two\"\n", 1412 | " \"one\"\n", 1413 | " \"three\"" 1414 | ] 1415 | }, 1416 | "execution_count": 64, 1417 | "metadata": {}, 1418 | "output_type": "execute_result" 1419 | } 1420 | ], 1421 | "source": [ 1422 | "keys(d)" 1423 | ] 1424 | }, 1425 | { 1426 | "cell_type": "code", 1427 | "execution_count": 65, 1428 | "metadata": {}, 1429 | "outputs": [ 1430 | { 1431 | "data": { 1432 | "text/plain": [ 1433 | "3-element Array{String,1}:\n", 1434 | " \"two\" \n", 1435 | " \"one\" \n", 1436 | " \"three\"" 1437 | ] 1438 | }, 1439 | "execution_count": 65, 1440 | "metadata": {}, 1441 | "output_type": "execute_result" 1442 | } 1443 | ], 1444 | "source": [ 1445 | "collect(keys(d))" 1446 | ] 1447 | }, 1448 | { 1449 | "cell_type": "code", 1450 | "execution_count": 66, 1451 | "metadata": {}, 1452 | "outputs": [ 1453 | { 1454 | "data": { 1455 | "text/plain": [ 1456 | "Base.ValueIterator for a Dict{String,Int64} with 3 entries. 
Values:\n", 1457 | " 2\n", 1458 | " 1\n", 1459 | " 3" 1460 | ] 1461 | }, 1462 | "execution_count": 66, 1463 | "metadata": {}, 1464 | "output_type": "execute_result" 1465 | } 1466 | ], 1467 | "source": [ 1468 | "values(d)" 1469 | ] 1470 | }, 1471 | { 1472 | "cell_type": "code", 1473 | "execution_count": 67, 1474 | "metadata": {}, 1475 | "outputs": [ 1476 | { 1477 | "data": { 1478 | "text/plain": [ 1479 | "true" 1480 | ] 1481 | }, 1482 | "execution_count": 67, 1483 | "metadata": {}, 1484 | "output_type": "execute_result" 1485 | } 1486 | ], 1487 | "source": [ 1488 | "haskey(d, \"one\")" 1489 | ] 1490 | }, 1491 | { 1492 | "cell_type": "code", 1493 | "execution_count": 68, 1494 | "metadata": {}, 1495 | "outputs": [ 1496 | { 1497 | "data": { 1498 | "text/plain": [ 1499 | "false" 1500 | ] 1501 | }, 1502 | "execution_count": 68, 1503 | "metadata": {}, 1504 | "output_type": "execute_result" 1505 | } 1506 | ], 1507 | "source": [ 1508 | "haskey(d, 1)" 1509 | ] 1510 | }, 1511 | { 1512 | "cell_type": "markdown", 1513 | "metadata": {}, 1514 | "source": [ 1515 | "## Control Flow" 1516 | ] 1517 | }, 1518 | { 1519 | "cell_type": "code", 1520 | "execution_count": 69, 1521 | "metadata": {}, 1522 | "outputs": [ 1523 | { 1524 | "name": "stdout", 1525 | "output_type": "stream", 1526 | "text": [ 1527 | "some_var is smaller than 10.\n" 1528 | ] 1529 | } 1530 | ], 1531 | "source": [ 1532 | "# Let's make a variable\n", 1533 | "some_var = 5\n", 1534 | "\n", 1535 | "# Here is an if statement. 
Indentation is not meaningful in Julia.\n", 1536 | "if some_var > 10\n", 1537 | " println(\"some_var is totally bigger than 10.\")\n", 1538 | "elseif some_var < 10 # This elseif clause is optional.\n", 1539 | " println(\"some_var is smaller than 10.\")\n", 1540 | "else # The else clause is optional too.\n", 1541 | " println(\"some_var is indeed 10.\")\n", 1542 | "end" 1543 | ] 1544 | }, 1545 | { 1546 | "cell_type": "code", 1547 | "execution_count": 70, 1548 | "metadata": {}, 1549 | "outputs": [ 1550 | { 1551 | "name": "stdout", 1552 | "output_type": "stream", 1553 | "text": [ 1554 | "dog is a mammal\n", 1555 | "cat is a mammal\n", 1556 | "mouse is a mammal\n" 1557 | ] 1558 | } 1559 | ], 1560 | "source": [ 1561 | "# For loops iterate over iterables.\n", 1562 | "# Iterable types include Range, Array, Set, Dict, and String.\n", 1563 | "for animal in [\"dog\", \"cat\", \"mouse\"]\n", 1564 | " println(\"$animal is a mammal\")\n", 1565 | " # You can use $ to interpolate variables or expressions into strings\n", 1566 | "end" 1567 | ] 1568 | }, 1569 | { 1570 | "cell_type": "code", 1571 | "execution_count": 71, 1572 | "metadata": {}, 1573 | "outputs": [ 1574 | { 1575 | "name": "stdout", 1576 | "output_type": "stream", 1577 | "text": [ 1578 | "mouse is a mammal\n", 1579 | "cat is a mammal\n", 1580 | "dog is a mammal\n" 1581 | ] 1582 | } 1583 | ], 1584 | "source": [ 1585 | "for key_val in Dict(\"dog\"=>\"mammal\",\"cat\"=>\"mammal\",\"mouse\"=>\"mammal\")\n", 1586 | " println(\"$(key_val[1]) is a $(key_val[2])\")\n", 1587 | "end" 1588 | ] 1589 | }, 1590 | { 1591 | "cell_type": "code", 1592 | "execution_count": 72, 1593 | "metadata": {}, 1594 | "outputs": [ 1595 | { 1596 | "name": "stdout", 1597 | "output_type": "stream", 1598 | "text": [ 1599 | "mouse is a mammal\n", 1600 | "cat is a mammal\n", 1601 | "dog is a mammal\n" 1602 | ] 1603 | } 1604 | ], 1605 | "source": [ 1606 | "for (k,v) in Dict(\"dog\"=>\"mammal\",\"cat\"=>\"mammal\",\"mouse\"=>\"mammal\")\n", 1607 | "
println(\"$k is a $v\")\n", 1608 | "end" 1609 | ] 1610 | }, 1611 | { 1612 | "cell_type": "code", 1613 | "execution_count": 73, 1614 | "metadata": {}, 1615 | "outputs": [ 1616 | { 1617 | "name": "stdout", 1618 | "output_type": "stream", 1619 | "text": [ 1620 | "0\n", 1621 | "1\n", 1622 | "2\n", 1623 | "3\n" 1624 | ] 1625 | } 1626 | ], 1627 | "source": [ 1628 | "x = 0\n", 1629 | "while x < 4\n", 1630 | " global x\n", 1631 | " println(x)\n", 1632 | " x += 1 # Shorthand for x = x + 1\n", 1633 | "end" 1634 | ] 1635 | }, 1636 | { 1637 | "cell_type": "code", 1638 | "execution_count": 74, 1639 | "metadata": {}, 1640 | "outputs": [], 1641 | "source": [ 1642 | "# Handle exceptions with a try/catch block\n", 1643 | "try\n", 1644 | "# error(\"help\")\n", 1645 | "catch e\n", 1646 | " println(\"caught it $e\")\n", 1647 | "end" 1648 | ] 1649 | }, 1650 | { 1651 | "cell_type": "markdown", 1652 | "metadata": {}, 1653 | "source": [ 1654 | "## Functions" 1655 | ] 1656 | }, 1657 | { 1658 | "cell_type": "code", 1659 | "execution_count": 75, 1660 | "metadata": {}, 1661 | "outputs": [ 1662 | { 1663 | "name": "stdout", 1664 | "output_type": "stream", 1665 | "text": [ 1666 | "x is 5 and y is 6\n" 1667 | ] 1668 | }, 1669 | { 1670 | "data": { 1671 | "text/plain": [ 1672 | "11" 1673 | ] 1674 | }, 1675 | "execution_count": 75, 1676 | "metadata": {}, 1677 | "output_type": "execute_result" 1678 | } 1679 | ], 1680 | "source": [ 1681 | "function add(x, y)\n", 1682 | " println(\"x is $x and y is $y\")\n", 1683 | " # Functions return the value of their last statement (or where you specify \"return\")\n", 1684 | " x + y\n", 1685 | "end\n", 1686 | "add(5, 6) " 1687 | ] 1688 | }, 1689 | { 1690 | "cell_type": "code", 1691 | "execution_count": 76, 1692 | "metadata": {}, 1693 | "outputs": [ 1694 | { 1695 | "data": { 1696 | "text/plain": [ 1697 | "defaults (generic function with 3 methods)" 1698 | ] 1699 | }, 1700 | "execution_count": 76, 1701 | "metadata": {}, 1702 | "output_type": "execute_result" 1703 | 
} 1704 | ], 1705 | "source": [ 1706 | "# You can define functions with optional positional arguments\n", 1707 | "function defaults(a,b,x=5,y=6)\n", 1708 | " return \"$a $b and $x $y\"\n", 1709 | "end" 1710 | ] 1711 | }, 1712 | { 1713 | "cell_type": "code", 1714 | "execution_count": 77, 1715 | "metadata": {}, 1716 | "outputs": [ 1717 | { 1718 | "data": { 1719 | "text/plain": [ 1720 | "\"h g and 5 6\"" 1721 | ] 1722 | }, 1723 | "execution_count": 77, 1724 | "metadata": {}, 1725 | "output_type": "execute_result" 1726 | } 1727 | ], 1728 | "source": [ 1729 | "defaults('h','g')" 1730 | ] 1731 | }, 1732 | { 1733 | "cell_type": "code", 1734 | "execution_count": 78, 1735 | "metadata": {}, 1736 | "outputs": [ 1737 | { 1738 | "data": { 1739 | "text/plain": [ 1740 | "\"h g and j 6\"" 1741 | ] 1742 | }, 1743 | "execution_count": 78, 1744 | "metadata": {}, 1745 | "output_type": "execute_result" 1746 | } 1747 | ], 1748 | "source": [ 1749 | "defaults('h','g','j')" 1750 | ] 1751 | }, 1752 | { 1753 | "cell_type": "code", 1754 | "execution_count": 79, 1755 | "metadata": {}, 1756 | "outputs": [ 1757 | { 1758 | "data": { 1759 | "text/plain": [ 1760 | "\"h g and j k\"" 1761 | ] 1762 | }, 1763 | "execution_count": 79, 1764 | "metadata": {}, 1765 | "output_type": "execute_result" 1766 | } 1767 | ], 1768 | "source": [ 1769 | "defaults('h','g','j','k')" 1770 | ] 1771 | }, 1772 | { 1773 | "cell_type": "code", 1774 | "execution_count": 80, 1775 | "metadata": {}, 1776 | "outputs": [ 1777 | { 1778 | "data": { 1779 | "text/plain": [ 1780 | "keyword_args (generic function with 1 method)" 1781 | ] 1782 | }, 1783 | "execution_count": 80, 1784 | "metadata": {}, 1785 | "output_type": "execute_result" 1786 | } 1787 | ], 1788 | "source": [ 1789 | "# You can define functions that take keyword arguments\n", 1790 | "function keyword_args(;k1=4,name2=\"hello\") # note the ;\n", 1791 | " return Dict(\"k1\"=>k1,\"name2\"=>name2)\n", 1792 | "end" 1793 | ] 1794 | }, 1795 | { 1796 | "cell_type": "code", 1797 | 
"execution_count": 81, 1798 | "metadata": {}, 1799 | "outputs": [ 1800 | { 1801 | "data": { 1802 | "text/plain": [ 1803 | "Dict{String,Any} with 2 entries:\n", 1804 | " \"name2\" => \"ness\"\n", 1805 | " \"k1\" => 4" 1806 | ] 1807 | }, 1808 | "execution_count": 81, 1809 | "metadata": {}, 1810 | "output_type": "execute_result" 1811 | } 1812 | ], 1813 | "source": [ 1814 | "keyword_args(name2=\"ness\")" 1815 | ] 1816 | }, 1817 | { 1818 | "cell_type": "code", 1819 | "execution_count": 82, 1820 | "metadata": {}, 1821 | "outputs": [ 1822 | { 1823 | "data": { 1824 | "text/plain": [ 1825 | "Dict{String,String} with 2 entries:\n", 1826 | " \"name2\" => \"hello\"\n", 1827 | " \"k1\" => \"mine\"" 1828 | ] 1829 | }, 1830 | "execution_count": 82, 1831 | "metadata": {}, 1832 | "output_type": "execute_result" 1833 | } 1834 | ], 1835 | "source": [ 1836 | "keyword_args(k1=\"mine\")" 1837 | ] 1838 | }, 1839 | { 1840 | "cell_type": "code", 1841 | "execution_count": 83, 1842 | "metadata": {}, 1843 | "outputs": [ 1844 | { 1845 | "data": { 1846 | "text/plain": [ 1847 | "Dict{String,Any} with 2 entries:\n", 1848 | " \"name2\" => \"hello\"\n", 1849 | " \"k1\" => 4" 1850 | ] 1851 | }, 1852 | "execution_count": 83, 1853 | "metadata": {}, 1854 | "output_type": "execute_result" 1855 | } 1856 | ], 1857 | "source": [ 1858 | "keyword_args()" 1859 | ] 1860 | }, 1861 | { 1862 | "cell_type": "code", 1863 | "execution_count": 84, 1864 | "metadata": {}, 1865 | "outputs": [ 1866 | { 1867 | "data": { 1868 | "text/plain": [ 1869 | "true" 1870 | ] 1871 | }, 1872 | "execution_count": 84, 1873 | "metadata": {}, 1874 | "output_type": "execute_result" 1875 | } 1876 | ], 1877 | "source": [ 1878 | "# This is \"stabby lambda syntax\" for creating anonymous functions\n", 1879 | "(x -> x > 2)(3) # => true" 1880 | ] 1881 | }, 1882 | { 1883 | "cell_type": "code", 1884 | "execution_count": 85, 1885 | "metadata": {}, 1886 | "outputs": [ 1887 | { 1888 | "data": { 1889 | "text/plain": [ 1890 | "create_adder (generic 
function with 1 method)" 1891 | ] 1892 | }, 1893 | "execution_count": 85, 1894 | "metadata": {}, 1895 | "output_type": "execute_result" 1896 | } 1897 | ], 1898 | "source": [ 1899 | "# Functions can return anonymous functions (closures) that capture their arguments.\n", 1900 | "function create_adder(x)\n", 1901 | " y -> x + y\n", 1902 | "end" 1903 | ] 1904 | }, 1905 | { 1906 | "cell_type": "code", 1907 | "execution_count": 86, 1908 | "metadata": {}, 1909 | "outputs": [ 1910 | { 1911 | "data": { 1912 | "text/plain": [ 1913 | "create_adder2 (generic function with 1 method)" 1914 | ] 1915 | }, 1916 | "execution_count": 86, 1917 | "metadata": {}, 1918 | "output_type": "execute_result" 1919 | } 1920 | ], 1921 | "source": [ 1922 | "# You can also name the internal function, if you want\n", 1923 | "function create_adder2(x)\n", 1924 | " function adder(y)\n", 1925 | " x + y\n", 1926 | " end\n", 1927 | " adder\n", 1928 | "end\n" 1929 | ] 1930 | }, 1931 | { 1932 | "cell_type": "code", 1933 | "execution_count": 87, 1934 | "metadata": {}, 1935 | "outputs": [ 1936 | { 1937 | "data": { 1938 | "text/plain": [ 1939 | "13" 1940 | ] 1941 | }, 1942 | "execution_count": 87, 1943 | "metadata": {}, 1944 | "output_type": "execute_result" 1945 | } 1946 | ], 1947 | "source": [ 1948 | "add_10 = create_adder(10)\n", 1949 | "add_10(3) " 1950 | ] 1951 | }, 1952 | { 1953 | "cell_type": "code", 1954 | "execution_count": 88, 1955 | "metadata": {}, 1956 | "outputs": [ 1957 | { 1958 | "data": { 1959 | "text/plain": [ 1960 | "3-element Array{Int64,1}:\n", 1961 | " 11\n", 1962 | " 12\n", 1963 | " 13" 1964 | ] 1965 | }, 1966 | "execution_count": 88, 1967 | "metadata": {}, 1968 | "output_type": "execute_result" 1969 | } 1970 | ], 1971 | "source": [ 1972 | "map(add_10, [1,2,3])" 1973 | ] 1974 | }, 1975 | { 1976 | "cell_type": "code", 1977 | "execution_count": 89, 1978 | "metadata": {}, 1979 | "outputs": [ 1980 | { 1981 | "data": { 1982 | "text/plain": [ 1983 | "2-element Array{Int64,1}:\n", 1984 | " 6\n", 1985 | " 7" 
1986 | ] 1987 | }, 1988 | "execution_count": 89, 1989 | "metadata": {}, 1990 | "output_type": "execute_result" 1991 | } 1992 | ], 1993 | "source": [ 1994 | "filter(x -> x > 5, [3, 4, 5, 6, 7])" 1995 | ] 1996 | }, 1997 | { 1998 | "cell_type": "code", 1999 | "execution_count": 90, 2000 | "metadata": {}, 2001 | "outputs": [ 2002 | { 2003 | "data": { 2004 | "text/plain": [ 2005 | "3-element Array{Int64,1}:\n", 2006 | " 11\n", 2007 | " 12\n", 2008 | " 13" 2009 | ] 2010 | }, 2011 | "execution_count": 90, 2012 | "metadata": {}, 2013 | "output_type": "execute_result" 2014 | } 2015 | ], 2016 | "source": [ 2017 | "[add_10(i) for i in [1, 2, 3]]" 2018 | ] 2019 | }, 2020 | { 2021 | "cell_type": "markdown", 2022 | "metadata": {}, 2023 | "source": [ 2024 | "## Composite Types" 2025 | ] 2026 | }, 2027 | { 2028 | "cell_type": "code", 2029 | "execution_count": 91, 2030 | "metadata": {}, 2031 | "outputs": [], 2032 | "source": [ 2033 | "struct Tiger\n", 2034 | " taillength::Float64\n", 2035 | " coatcolor # not including a type annotation is the same as `::Any`\n", 2036 | "end" 2037 | ] 2038 | }, 2039 | { 2040 | "cell_type": "code", 2041 | "execution_count": 92, 2042 | "metadata": {}, 2043 | "outputs": [ 2044 | { 2045 | "data": { 2046 | "text/plain": [ 2047 | "Tiger(3.5, \"orange\")" 2048 | ] 2049 | }, 2050 | "execution_count": 92, 2051 | "metadata": {}, 2052 | "output_type": "execute_result" 2053 | } 2054 | ], 2055 | "source": [ 2056 | "tigger = Tiger(3.5,\"orange\")" 2057 | ] 2058 | }, 2059 | { 2060 | "cell_type": "code", 2061 | "execution_count": 93, 2062 | "metadata": {}, 2063 | "outputs": [], 2064 | "source": [ 2065 | "abstract type Cat end # just a name and point in the type hierarchy" 2066 | ] 2067 | }, 2068 | { 2069 | "cell_type": "code", 2070 | "execution_count": 94, 2071 | "metadata": {}, 2072 | "outputs": [ 2073 | { 2074 | "data": { 2075 | "text/plain": [ 2076 | "2-element Array{Any,1}:\n", 2077 | " Complex\n", 2078 | " Real " 2079 | ] 2080 | }, 2081 | "execution_count": 
94, 2082 | "metadata": {}, 2083 | "output_type": "execute_result" 2084 | } 2085 | ], 2086 | "source": [ 2087 | "subtypes(Number)" 2088 | ] 2089 | }, 2090 | { 2091 | "cell_type": "code", 2092 | "execution_count": 95, 2093 | "metadata": {}, 2094 | "outputs": [ 2095 | { 2096 | "data": { 2097 | "text/plain": [ 2098 | "0-element Array{Any,1}" 2099 | ] 2100 | }, 2101 | "execution_count": 95, 2102 | "metadata": {}, 2103 | "output_type": "execute_result" 2104 | } 2105 | ], 2106 | "source": [ 2107 | "subtypes(Cat)" 2108 | ] 2109 | }, 2110 | { 2111 | "cell_type": "code", 2112 | "execution_count": 96, 2113 | "metadata": {}, 2114 | "outputs": [], 2115 | "source": [ 2116 | "# <: is the subtyping operator\n", 2117 | "struct Lion <: Cat # Lion is a subtype of Cat\n", 2118 | " mane_color\n", 2119 | " roar::String\n", 2120 | "end" 2121 | ] 2122 | }, 2123 | { 2124 | "cell_type": "code", 2125 | "execution_count": 97, 2126 | "metadata": {}, 2127 | "outputs": [], 2128 | "source": [ 2129 | "# You can define more constructors for your type\n", 2130 | "# Just define a function of the same name as the type\n", 2131 | "# and call an existing constructor to get a value of the correct type\n", 2132 | "Lion(roar::String) = Lion(\"green\",roar);\n", 2133 | "# This is an outer constructor because it's outside the type definition\n", 2134 | "# Note, the semicolon suppresses the output" 2135 | ] 2136 | }, 2137 | { 2138 | "cell_type": "code", 2139 | "execution_count": 98, 2140 | "metadata": {}, 2141 | "outputs": [], 2142 | "source": [ 2143 | "struct Panther <: Cat # Panther is also a subtype of Cat\n", 2144 | " eye_color\n", 2145 | " Panther() = new(\"green\")\n", 2146 | " # Panthers will only have this constructor, and no default constructor.\n", 2147 | "end" 2148 | ] 2149 | }, 2150 | { 2151 | "cell_type": "code", 2152 | "execution_count": 99, 2153 | "metadata": {}, 2154 | "outputs": [ 2155 | { 2156 | "data": { 2157 | "text/plain": [ 2158 | "2-element Array{Any,1}:\n", 2159 | " Lion \n", 2160 | " 
Panther" 2161 | ] 2162 | }, 2163 | "execution_count": 99, 2164 | "metadata": {}, 2165 | "output_type": "execute_result" 2166 | } 2167 | ], 2168 | "source": [ 2169 | "subtypes(Cat)" 2170 | ] 2171 | }, 2172 | { 2173 | "cell_type": "markdown", 2174 | "metadata": {}, 2175 | "source": [ 2176 | "## Multiple Dispatch" 2177 | ] 2178 | }, 2179 | { 2180 | "cell_type": "code", 2181 | "execution_count": 100, 2182 | "metadata": {}, 2183 | "outputs": [ 2184 | { 2185 | "data": { 2186 | "text/plain": [ 2187 | "meow (generic function with 3 methods)" 2188 | ] 2189 | }, 2190 | "execution_count": 100, 2191 | "metadata": {}, 2192 | "output_type": "execute_result" 2193 | } 2194 | ], 2195 | "source": [ 2196 | "function meow(animal::Lion)\n", 2197 | " animal.roar # access type properties using dot notation\n", 2198 | "end\n", 2199 | "\n", 2200 | "function meow(animal::Panther)\n", 2201 | " \"grrr\"\n", 2202 | "end\n", 2203 | "\n", 2204 | "function meow(animal::Tiger)\n", 2205 | " \"rawwwr\"\n", 2206 | "end" 2207 | ] 2208 | }, 2209 | { 2210 | "cell_type": "code", 2211 | "execution_count": 101, 2212 | "metadata": {}, 2213 | "outputs": [ 2214 | { 2215 | "data": { 2216 | "text/plain": [ 2217 | "\"rawwwr\"" 2218 | ] 2219 | }, 2220 | "execution_count": 101, 2221 | "metadata": {}, 2222 | "output_type": "execute_result" 2223 | } 2224 | ], 2225 | "source": [ 2226 | "meow(tigger)" 2227 | ] 2228 | }, 2229 | { 2230 | "cell_type": "code", 2231 | "execution_count": 102, 2232 | "metadata": {}, 2233 | "outputs": [ 2234 | { 2235 | "data": { 2236 | "text/plain": [ 2237 | "\"ROAAR\"" 2238 | ] 2239 | }, 2240 | "execution_count": 102, 2241 | "metadata": {}, 2242 | "output_type": "execute_result" 2243 | } 2244 | ], 2245 | "source": [ 2246 | "meow(Lion(\"brown\",\"ROAAR\"))" 2247 | ] 2248 | }, 2249 | { 2250 | "cell_type": "code", 2251 | "execution_count": 103, 2252 | "metadata": {}, 2253 | "outputs": [ 2254 | { 2255 | "data": { 2256 | "text/plain": [ 2257 | "\"grrr\"" 2258 | ] 2259 | }, 2260 | 
"execution_count": 103, 2261 | "metadata": {}, 2262 | "output_type": "execute_result" 2263 | } 2264 | ], 2265 | "source": [ 2266 | "meow(Panther())" 2267 | ] 2268 | }, 2269 | { 2270 | "cell_type": "markdown", 2271 | "metadata": {}, 2272 | "source": [ 2273 | "## Native Code" 2274 | ] 2275 | }, 2276 | { 2277 | "cell_type": "code", 2278 | "execution_count": 104, 2279 | "metadata": {}, 2280 | "outputs": [ 2281 | { 2282 | "data": { 2283 | "text/plain": [ 2284 | "square (generic function with 1 method)" 2285 | ] 2286 | }, 2287 | "execution_count": 104, 2288 | "metadata": {}, 2289 | "output_type": "execute_result" 2290 | } 2291 | ], 2292 | "source": [ 2293 | "square(l) = l * l" 2294 | ] 2295 | }, 2296 | { 2297 | "cell_type": "code", 2298 | "execution_count": 105, 2299 | "metadata": {}, 2300 | "outputs": [ 2301 | { 2302 | "data": { 2303 | "text/plain": [ 2304 | "25" 2305 | ] 2306 | }, 2307 | "execution_count": 105, 2308 | "metadata": {}, 2309 | "output_type": "execute_result" 2310 | } 2311 | ], 2312 | "source": [ 2313 | "square(5)" 2314 | ] 2315 | }, 2316 | { 2317 | "cell_type": "code", 2318 | "execution_count": 106, 2319 | "metadata": {}, 2320 | "outputs": [ 2321 | { 2322 | "name": "stdout", 2323 | "output_type": "stream", 2324 | "text": [ 2325 | "\t.text\n", 2326 | "; ┌ @ In[104]:1 within `square'\n", 2327 | "\tpushq\t%rbp\n", 2328 | "\tmovq\t%rsp, %rbp\n", 2329 | "; │┌ @ int.jl:54 within `*'\n", 2330 | "\timull\t%ecx, %ecx\n", 2331 | "; │└\n", 2332 | "\tmovl\t%ecx, %eax\n", 2333 | "\tpopq\t%rbp\n", 2334 | "\tretq\n", 2335 | "\tnopl\t(%rax,%rax)\n", 2336 | "; └\n" 2337 | ] 2338 | } 2339 | ], 2340 | "source": [ 2341 | "code_native(square, (Int32,))" 2342 | ] 2343 | }, 2344 | { 2345 | "cell_type": "code", 2346 | "execution_count": 107, 2347 | "metadata": {}, 2348 | "outputs": [ 2349 | { 2350 | "name": "stdout", 2351 | "output_type": "stream", 2352 | "text": [ 2353 | "\t.text\n", 2354 | "; ┌ @ In[104]:1 within `square'\n", 2355 | "\tpushq\t%rbp\n", 2356 | "\tmovq\t%rsp, 
%rbp\n", 2357 | "; │┌ @ float.jl:399 within `*'\n", 2358 | "\tvmulsd\t%xmm0, %xmm0, %xmm0\n", 2359 | "; │└\n", 2360 | "\tpopq\t%rbp\n", 2361 | "\tretq\n", 2362 | "\tnopw\t(%rax,%rax)\n", 2363 | "; └\n" 2364 | ] 2365 | } 2366 | ], 2367 | "source": [ 2368 | "code_native(square, (Float64,))" 2369 | ] 2370 | }, 2371 | { 2372 | "cell_type": "code", 2373 | "execution_count": 108, 2374 | "metadata": {}, 2375 | "outputs": [ 2376 | { 2377 | "name": "stdout", 2378 | "output_type": "stream", 2379 | "text": [ 2380 | "\n", 2381 | "; @ In[104]:1 within `square'\n", 2382 | "; Function Attrs: uwtable\n", 2383 | "define i32 @julia_square_17289(i32) #0 {\n", 2384 | "top:\n", 2385 | "; ┌ @ int.jl:54 within `*'\n", 2386 | " %1 = mul i32 %0, %0\n", 2387 | "; └\n", 2388 | " ret i32 %1\n", 2389 | "}\n" 2390 | ] 2391 | } 2392 | ], 2393 | "source": [ 2394 | "code_llvm(square, (Int32,))" 2395 | ] 2396 | } 2397 | ], 2398 | "metadata": { 2399 | "@webio": { 2400 | "lastCommId": null, 2401 | "lastKernelId": null 2402 | }, 2403 | "kernelspec": { 2404 | "display_name": "Julia 1.2.0", 2405 | "language": "julia", 2406 | "name": "julia-1.2" 2407 | }, 2408 | "language_info": { 2409 | "file_extension": ".jl", 2410 | "mimetype": "application/julia", 2411 | "name": "julia", 2412 | "version": "1.2.0" 2413 | } 2414 | }, 2415 | "nbformat": 4, 2416 | "nbformat_minor": 1 2417 | } 2418 | -------------------------------------------------------------------------------- /03-Inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Inference" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "using Distributions\n", 17 | "using BayesNets" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "## Inference for Classification" 25 | ] 26 | }, 27 | { 28 | 
"cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "data": { 34 | "image/svg+xml": [ 35 | "\n", 36 | "\n", 37 | "\n", 38 | "\n", 39 | "\n", 40 | "\n", 41 | "\n", 42 | "\n", 43 | "\n", 44 | "\n", 45 | "\n", 46 | "\n", 47 | "\n", 48 | "\n", 49 | "\n", 50 | "\n", 51 | "\n", 52 | "\n", 53 | "\n", 54 | "\n", 55 | "\n", 56 | "\n", 57 | "\n", 58 | "\n", 59 | "\n", 60 | "\n", 61 | "\n", 62 | "\n", 63 | "\n", 64 | "\n", 65 | "\n", 66 | "\n", 67 | "\n", 68 | "\n", 69 | "\n", 70 | "\n", 71 | "\n", 72 | "\n", 73 | "\n", 74 | "\n", 75 | "\n", 76 | "\n", 77 | "\n", 78 | "\n", 79 | "\n", 80 | "\n", 81 | "\n", 82 | "\n", 83 | "\n", 84 | "\n", 85 | "\n", 86 | "\n", 87 | "\n", 88 | "\n", 89 | "\n", 90 | "\n", 91 | "\n", 92 | "\n", 93 | "\n", 94 | "\n", 95 | "\n", 96 | "\n", 97 | "\n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | "\n", 104 | "\n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | "\n", 114 | "\n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | "\n", 127 | "\n", 128 | "\n", 129 | "\n" 130 | ], 131 | "text/plain": [ 132 | "BayesNet{CPD}({3, 2} directed simple Int64 graph, CPD[StaticCPD{NamedCategorical{String}}(:Class, Symbol[], NamedCategorical with entries:\n", 133 | "\t 0.5000: aircraft\n", 134 | "\t 0.5000: bird\n", 135 | "), FunctionalCPD{Normal}(:Airspeed, Symbol[:Class], airspeedDistributions), FunctionalCPD{NamedCategorical}(:Fluctuation, Symbol[:Class], fluctuationDistributions)], Dict(:Class => 1,:Fluctuation => 3,:Airspeed => 2))" 136 | ] 137 | }, 138 | "execution_count": 2, 139 | "metadata": {}, 140 | "output_type": "execute_result" 141 | } 142 | ], 143 | "source": [ 144 | "b = BayesNet()\n", 145 | "\n", 146 | "# Set uniform prior over Class\n", 147 | "push!(b, StaticCPD(:Class, NamedCategorical([\"bird\", \"aircraft\"], [0.5, 
0.5])))\n", 148 | "\n", 149 | "fluctuationStates = [\"low\", \"hi\"]\n", 150 | "fluctuationDistributions(a::Assignment) = a[:Class] == \"bird\" ? NamedCategorical(fluctuationStates, [0.1, 0.9]) : NamedCategorical(fluctuationStates, [0.9, 0.1])\n", 151 | "push!(b, FunctionalCPD{NamedCategorical}(:Fluctuation, [:Class], fluctuationDistributions))\n", 152 | "\n", 153 | "# if Bird, then Airspeed ~ N(45,10)\n", 154 | "# if Aircraft, then Airspeed ~ N(100,40)\n", 155 | "airspeedDistributions(a::Assignment) = a[:Class] == \"bird\" ? Normal(45,10) : Normal(100,40)\n", 156 | "push!(b, FunctionalCPD{Normal}(:Airspeed, [:Class], airspeedDistributions))" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 3, 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "data": { 166 | "text/plain": [ 167 | "0.00026995483256594033" 168 | ] 169 | }, 170 | "execution_count": 3, 171 | "metadata": {}, 172 | "output_type": "execute_result" 173 | } 174 | ], 175 | "source": [ 176 | "pb = pdf(b, :Class=>\"bird\", :Airspeed=>65, :Fluctuation=>\"low\")" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 4, 182 | "metadata": {}, 183 | "outputs": [ 184 | { 185 | "data": { 186 | "text/plain": [ 187 | "0.003060618731758615" 188 | ] 189 | }, 190 | "execution_count": 4, 191 | "metadata": {}, 192 | "output_type": "execute_result" 193 | } 194 | ], 195 | "source": [ 196 | "pa = pdf(b, :Class=>\"aircraft\", :Airspeed=>65, :Fluctuation=>\"low\")" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 5, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "data": { 206 | "text/plain": [ 207 | "0.9189464435022358" 208 | ] 209 | }, 210 | "execution_count": 5, 211 | "metadata": {}, 212 | "output_type": "execute_result" 213 | } 214 | ], 215 | "source": [ 216 | "# Probability of aircraft given data\n", 217 | "pa / (pa + pb)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 6, 223 | "metadata": {}, 224 | 
"outputs": [ 225 | { 226 | "data": { 227 | "text/plain": [ 228 | "2-element Array{Float64,1}:\n", 229 | " 0.00026995483256594033\n", 230 | " 0.003060618731758615 " 231 | ] 232 | }, 233 | "execution_count": 6, 234 | "metadata": {}, 235 | "output_type": "execute_result" 236 | } 237 | ], 238 | "source": [ 239 | "# View (unnormalized) distribution as a vector\n", 240 | "d = [pb, pa]" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 7, 246 | "metadata": {}, 247 | "outputs": [ 248 | { 249 | "data": { 250 | "text/plain": [ 251 | "2-element Array{Float64,1}:\n", 252 | " 0.08105355649776423\n", 253 | " 0.9189464435022358 " 254 | ] 255 | }, 256 | "execution_count": 7, 257 | "metadata": {}, 258 | "output_type": "execute_result" 259 | } 260 | ], 261 | "source": [ 262 | "# Now normalize\n", 263 | "d / sum(d)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "## Inference in temporal models" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "Here is a simple crying baby temporal model. Whether the baby is crying is a noisy indication of whether the baby is hungry." 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 8, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "struct State\n", 287 | " hungry\n", 288 | "end\n", 289 | "struct Observation\n", 290 | " crying\n", 291 | "end\n", 292 | "\n", 293 | "States = [State(false), State(true)]\n", 294 | "Observations = [Observation(false), Observation(true)]\n", 295 | "\n", 296 | "# P(o|s)\n", 297 | "function P(o::Observation, s::State)\n", 298 | " if s.hungry\n", 299 | " return o.crying ? 0.8 : 0.2\n", 300 | " else\n", 301 | " return o.crying ? 0.1 : 0.9\n", 302 | " end\n", 303 | "end\n", 304 | "\n", 305 | "# P(s' | s)\n", 306 | "function P(s1::State, s0::State)\n", 307 | " if s0.hungry\n", 308 | " return s1.hungry ? 
0.9 : 0.1\n", 309 | " else\n", 310 | " return s1.hungry ? 0.6 : 0.4\n", 311 | " end\n", 312 | "end\n", 313 | "\n", 314 | "# P(s)\n", 315 | "P(s::State) = 1/length(States)\n", 316 | "\n", 317 | "mutable struct Belief\n", 318 | " p::Vector{Float64}\n", 319 | "end" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": {}, 325 | "source": [ 326 | "Here are some sampling functions." 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 9, 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "sampleState() = States[rand(Distributions.Categorical(Float64[P(s) for s in States]))]\n", 336 | "sampleState(s::State) = States[rand(Distributions.Categorical(Float64[P(s1, s) for s1 in States]))]\n", 337 | "sampleObservation(s::State) = Observations[rand(Distributions.Categorical(Float64[P(o, s) for o in Observations]))]\n", 338 | "function generateSequence(steps)\n", 339 | " S = State[]\n", 340 | " O = Observation[]\n", 341 | " s = sampleState()\n", 342 | " push!(S, s) \n", 343 | " o = sampleObservation(s)\n", 344 | " push!(O, o)\n", 345 | " for t = 2:steps\n", 346 | " s = sampleState(s)\n", 347 | " push!(S, s) \n", 348 | " o = sampleObservation(s)\n", 349 | " push!(O, o)\n", 350 | " end\n", 351 | " (S, O)\n", 352 | "end\n", 353 | "(S, O) = generateSequence(20);" 354 | ] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "metadata": {}, 359 | "source": [ 360 | "Update a belief as follows \n", 361 | "\n", 362 | "$b_1(s) \\propto P(o \\mid s) \\sum_{s'} P(s \\mid s') b_0(s')$" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 10, 368 | "metadata": {}, 369 | "outputs": [], 370 | "source": [ 371 | "function update(b0::Belief, o::Observation)\n", 372 | " b1 = Belief(zeros(length(States)))\n", 373 | " for i = 1:length(States)\n", 374 | " s1 = States[i]\n", 375 | " b1.p[i] = P(o, s1) * sum([P(s1,States[j]) * b0.p[j] for j = 1:length(States)])\n", 376 | " end\n", 377 | " b1.p = b1.p / sum(b1.p)\n", 378 
| " b1\n", 379 | "end;" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": 11, 385 | "metadata": {}, 386 | "outputs": [ 387 | { 388 | "name": "stdout", 389 | "output_type": "stream", 390 | "text": [ 391 | "s\to\tP(hungry)\n", 392 | "1\t1\t0.960\n", 393 | "1\t1\t0.984\n", 394 | "0\t0\t0.655\n", 395 | "1\t1\t0.969\n", 396 | "0\t0\t0.644\n", 397 | "1\t1\t0.968\n", 398 | "1\t1\t0.985\n", 399 | "0\t0\t0.656\n", 400 | "1\t0\t0.465\n", 401 | "1\t1\t0.958\n", 402 | "1\t1\t0.984\n", 403 | "1\t0\t0.655\n", 404 | "1\t0\t0.465\n", 405 | "1\t1\t0.958\n", 406 | "1\t1\t0.984\n", 407 | "1\t1\t0.986\n", 408 | "1\t0\t0.656\n", 409 | "1\t1\t0.969\n", 410 | "1\t1\t0.985\n", 411 | "1\t1\t0.986\n" 412 | ] 413 | } 414 | ], 415 | "source": [ 416 | "using Printf\n", 417 | "function printBeliefs(S::Vector{State}, O::Vector{Observation})\n", 418 | " print(\"s\\to\\tP(hungry)\\n\")\n", 419 | " n = length(S)\n", 420 | " b = Belief([0.5, 0.5])\n", 421 | " for t = 1:n\n", 422 | " b = update(b, O[t])\n", 423 | " Printf.@printf(\"%.0f\\t%.0f\\t%.3f\\n\", float(S[t].hungry), float(O[t].crying), b.p[2])\n", 424 | " end\n", 425 | "end\n", 426 | "printBeliefs(S, O)" 427 | ] 428 | }, 429 | { 430 | "cell_type": "markdown", 431 | "metadata": {}, 432 | "source": [ 433 | "## Exact Inference" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 12, 439 | "metadata": {}, 440 | "outputs": [ 441 | { 442 | "data": { 443 | "image/svg+xml": [ 444 | "\n", 445 | "\n", 446 | "\n", 447 | "\n", 448 | "\n", 449 | "\n", 450 | "\n", 451 | "\n", 452 | "\n", 453 | "\n", 454 | "\n", 455 | "\n", 456 | "\n", 457 | "\n", 458 | "\n", 459 | "\n", 460 | "\n", 461 | "\n", 462 | "\n", 463 | "\n", 464 | "\n", 465 | "\n", 466 | "\n", 467 | "\n", 468 | "\n", 469 | "\n", 470 | "\n", 471 | "\n", 472 | "\n", 473 | "\n", 474 | "\n", 475 | "\n", 476 | "\n", 477 | "\n", 478 | " \n", 479 | "\n", 480 | "\n", 481 | " \n", 482 | "\n", 483 | "\n", 484 | " \n", 485 | "\n", 486 | "\n", 487 | " \n", 
488 | "\n", 489 | "\n", 490 | " \n", 491 | "\n", 492 | "\n", 493 | "\n", 494 | "\n" 495 | ], 496 | "text/plain": [ 497 | "BayesNet{CategoricalCPD{DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}}}({5, 4} directed simple Int64 graph, CategoricalCPD{DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}}[2 instantiations:\n", 498 | " B (2), 2 instantiations:\n", 499 | " S (2), 8 instantiations:\n", 500 | " E (2)\n", 501 | " B (2)\n", 502 | " S (2), 4 instantiations:\n", 503 | " C (2)\n", 504 | " E (2), 4 instantiations:\n", 505 | " D (2)\n", 506 | " E (2)], Dict(:D => 5,:B => 1,:S => 2,:E => 3,:C => 4))" 507 | ] 508 | }, 509 | "execution_count": 12, 510 | "metadata": {}, 511 | "output_type": "execute_result" 512 | } 513 | ], 514 | "source": [ 515 | "b = DiscreteBayesNet()\n", 516 | "push!(b, DiscreteCPD(:B, [0.1,0.9]))\n", 517 | "push!(b, DiscreteCPD(:S, [0.5,0.5]))\n", 518 | "push!(b, rand_cpd(b, 2, :E, [:B, :S]))\n", 519 | "push!(b, rand_cpd(b, 2, :D, [:E]))\n", 520 | "push!(b, rand_cpd(b, 2, :C, [:E]))" 521 | ] 522 | }, 523 | { 524 | "cell_type": "markdown", 525 | "metadata": {}, 526 | "source": [ 527 | "Compute \n", 528 | "\n", 529 | "$P(b^1, d^1, c^1) = \\sum_s \\sum_e P(b^1)P(s)P(e \\mid b^1, s)P(d^1 \\mid e)P(c^1 \\mid e)$" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": 13, 535 | "metadata": {}, 536 | "outputs": [ 537 | { 538 | "data": { 539 | "text/plain": [ 540 | "Dict{Symbol,Any} with 3 entries:\n", 541 | " :D => 2\n", 542 | " :B => 2\n", 543 | " :C => 2" 544 | ] 545 | }, 546 | "execution_count": 13, 547 | "metadata": {}, 548 | "output_type": "execute_result" 549 | } 550 | ], 551 | "source": [ 552 | "a = Assignment(:B=>2, :D=>2, :C=>2)" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": 14, 558 | "metadata": {}, 559 | "outputs": [ 560 | { 561 | "data": { 562 | "text/html": [ 563 | "

4 rows × 6 columns

BSEDCp
Int64⍰Int64⍰Int64⍰Int64⍰Int64⍰Float64
1211220.00399303
2212220.0534892
3221220.149917
4222220.0145783
" 564 | ], 565 | "text/plain": [ 566 | "Table(4×6 DataFrame\n", 567 | "│ Row │ B │ S │ E │ D │ C │ p │\n", 568 | "│ │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n", 569 | "├─────┼────────┼────────┼────────┼────────┼────────┼────────────┤\n", 570 | "│ 1 │ 2 │ 1 │ 1 │ 2 │ 2 │ 0.00399303 │\n", 571 | "│ 2 │ 2 │ 1 │ 2 │ 2 │ 2 │ 0.0534892 │\n", 572 | "│ 3 │ 2 │ 2 │ 1 │ 2 │ 2 │ 0.149917 │\n", 573 | "│ 4 │ 2 │ 2 │ 2 │ 2 │ 2 │ 0.0145783 │)" 574 | ] 575 | }, 576 | "execution_count": 14, 577 | "metadata": {}, 578 | "output_type": "execute_result" 579 | } 580 | ], 581 | "source": [ 582 | "T = table(b,:B,a)*table(b,:S)*table(b,:E,a)*table(b,:D,a)*table(b,:C,a)" 583 | ] 584 | }, 585 | { 586 | "cell_type": "markdown", 587 | "metadata": {}, 588 | "source": [ 589 | "The character ⍰ indicates that the column can hold a `Missing` value" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": 15, 595 | "metadata": {}, 596 | "outputs": [ 597 | { 598 | "data": { 599 | "text/html": [ 600 | "

1 rows × 4 columns

BDCp
Int64⍰Int64⍰Int64⍰Float64
12220.221977
" 601 | ], 602 | "text/plain": [ 603 | "Table(1×4 DataFrame\n", 604 | "│ Row │ B │ D │ C │ p │\n", 605 | "│ │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n", 606 | "├─────┼────────┼────────┼────────┼──────────┤\n", 607 | "│ 1 │ 2 │ 2 │ 2 │ 0.221977 │)" 608 | ] 609 | }, 610 | "execution_count": 15, 611 | "metadata": {}, 612 | "output_type": "execute_result" 613 | } 614 | ], 615 | "source": [ 616 | "sumout(T, [:S, :E])" 617 | ] 618 | }, 619 | { 620 | "cell_type": "markdown", 621 | "metadata": {}, 622 | "source": [ 623 | "## Approximate Inference" 624 | ] 625 | }, 626 | { 627 | "cell_type": "code", 628 | "execution_count": 16, 629 | "metadata": {}, 630 | "outputs": [ 631 | { 632 | "data": { 633 | "image/svg+xml": [ 634 | "\n", 635 | "\n", 636 | "\n", 637 | "\n", 638 | "\n", 639 | "\n", 640 | "\n", 641 | "\n", 642 | "\n", 643 | "\n", 644 | "\n", 645 | "\n", 646 | "\n", 647 | "\n", 648 | "\n", 649 | "\n", 650 | "\n", 651 | "\n", 652 | "\n", 653 | "\n", 654 | "\n", 655 | "\n", 656 | "\n", 657 | "\n", 658 | "\n", 659 | "\n", 660 | "\n", 661 | "\n", 662 | "\n", 663 | "\n", 664 | "\n", 665 | "\n", 666 | "\n", 667 | "\n", 668 | " \n", 669 | "\n", 670 | "\n", 671 | " \n", 672 | "\n", 673 | "\n", 674 | " \n", 675 | "\n", 676 | "\n", 677 | " \n", 678 | "\n", 679 | "\n", 680 | " \n", 681 | "\n", 682 | "\n", 683 | "\n", 684 | "\n" 685 | ], 686 | "text/plain": [ 687 | "BayesNet{CategoricalCPD{DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}}}({5, 4} directed simple Int64 graph, CategoricalCPD{DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}}[2 instantiations:\n", 688 | " B (2), 2 instantiations:\n", 689 | " S (2), 8 instantiations:\n", 690 | " E (2)\n", 691 | " B (2)\n", 692 | " S (2), 4 instantiations:\n", 693 | " C (2)\n", 694 | " E (2), 4 instantiations:\n", 695 | " D (2)\n", 696 | " E (2)], Dict(:D => 5,:B => 1,:S => 2,:E => 3,:C => 4))" 697 | ] 698 | }, 699 | 
"execution_count": 16, 700 | "metadata": {}, 701 | "output_type": "execute_result" 702 | } 703 | ], 704 | "source": [ 705 | "b = DiscreteBayesNet()\n", 706 | "push!(b, DiscreteCPD(:B, [0.1,0.9]))\n", 707 | "push!(b, DiscreteCPD(:S, [0.5,0.5]))\n", 708 | "push!(b, rand_cpd(b, 2, :E, [:B, :S]))\n", 709 | "push!(b, rand_cpd(b, 2, :D, [:E]))\n", 710 | "push!(b, rand_cpd(b, 2, :C, [:E]))" 711 | ] 712 | }, 713 | { 714 | "cell_type": "code", 715 | "execution_count": 17, 716 | "metadata": {}, 717 | "outputs": [ 718 | { 719 | "data": { 720 | "text/plain": [ 721 | "Dict{Symbol,Any} with 5 entries:\n", 722 | " :D => 2\n", 723 | " :B => 2\n", 724 | " :S => 2\n", 725 | " :E => 2\n", 726 | " :C => 1" 727 | ] 728 | }, 729 | "execution_count": 17, 730 | "metadata": {}, 731 | "output_type": "execute_result" 732 | } 733 | ], 734 | "source": [ 735 | "rand(b)" 736 | ] 737 | }, 738 | { 739 | "cell_type": "code", 740 | "execution_count": 18, 741 | "metadata": {}, 742 | "outputs": [ 743 | { 744 | "data": { 745 | "text/html": [ 746 | "

8 rows × 5 columns

BSECD
Int64Int64Int64Int64Int64
121222
211121
322112
422122
522121
622221
722222
821222
" 747 | ], 748 | "text/latex": [ 749 | "\\begin{tabular}{r|ccccc}\n", 750 | "\t& B & S & E & C & D\\\\\n", 751 | "\t\\hline\n", 752 | "\t& Int64 & Int64 & Int64 & Int64 & Int64\\\\\n", 753 | "\t\\hline\n", 754 | "\t1 & 2 & 1 & 2 & 2 & 2 \\\\\n", 755 | "\t2 & 1 & 1 & 1 & 2 & 1 \\\\\n", 756 | "\t3 & 2 & 2 & 1 & 1 & 2 \\\\\n", 757 | "\t4 & 2 & 2 & 1 & 2 & 2 \\\\\n", 758 | "\t5 & 2 & 2 & 1 & 2 & 1 \\\\\n", 759 | "\t6 & 2 & 2 & 2 & 2 & 1 \\\\\n", 760 | "\t7 & 2 & 2 & 2 & 2 & 2 \\\\\n", 761 | "\t8 & 2 & 1 & 2 & 2 & 2 \\\\\n", 762 | "\\end{tabular}\n" 763 | ], 764 | "text/plain": [ 765 | "8×5 DataFrame\n", 766 | "│ Row │ B │ S │ E │ C │ D │\n", 767 | "│ │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m │\n", 768 | "├─────┼───────┼───────┼───────┼───────┼───────┤\n", 769 | "│ 1 │ 2 │ 1 │ 2 │ 2 │ 2 │\n", 770 | "│ 2 │ 1 │ 1 │ 1 │ 2 │ 1 │\n", 771 | "│ 3 │ 2 │ 2 │ 1 │ 1 │ 2 │\n", 772 | "│ 4 │ 2 │ 2 │ 1 │ 2 │ 2 │\n", 773 | "│ 5 │ 2 │ 2 │ 1 │ 2 │ 1 │\n", 774 | "│ 6 │ 2 │ 2 │ 2 │ 2 │ 1 │\n", 775 | "│ 7 │ 2 │ 2 │ 2 │ 2 │ 2 │\n", 776 | "│ 8 │ 2 │ 1 │ 2 │ 2 │ 2 │" 777 | ] 778 | }, 779 | "execution_count": 18, 780 | "metadata": {}, 781 | "output_type": "execute_result" 782 | } 783 | ], 784 | "source": [ 785 | "rand(b, 8)" 786 | ] 787 | }, 788 | { 789 | "cell_type": "markdown", 790 | "metadata": {}, 791 | "source": [ 792 | "### Example chemical detection network" 793 | ] 794 | }, 795 | { 796 | "cell_type": "code", 797 | "execution_count": 19, 798 | "metadata": {}, 799 | "outputs": [ 800 | { 801 | "data": { 802 | "image/svg+xml": [ 803 | "\n", 804 | "\n", 805 | "\n", 806 | "\n", 807 | "\n", 808 | "\n", 809 | "\n", 810 | "\n", 811 | "\n", 812 | "\n", 813 | "\n", 814 | "\n", 815 | "\n", 816 | "\n", 817 | "\n", 818 | "\n", 819 | "\n", 820 | "\n", 821 | "\n", 822 | "\n", 823 | "\n", 824 | "\n", 825 | "\n", 826 | "\n", 827 | "\n", 828 | "\n", 829 | "\n", 830 | "\n", 831 | "\n", 832 | "\n", 833 | "\n", 
834 | "\n", 835 | "\n", 836 | "\n", 837 | "\n", 838 | "\n", 839 | "\n", 840 | "\n", 841 | "\n", 842 | "\n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | "\n", 851 | "\n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | "\n", 861 | "\n", 862 | "\n", 863 | "\n" 864 | ], 865 | "text/plain": [ 866 | "BayesNet{CPD}({2, 1} directed simple Int64 graph, CPD[StaticCPD{Bernoulli{Float64}}(:Present, Symbol[], Bernoulli{Float64}(p=0.001)), FunctionalCPD{Bernoulli}(:Detected, Symbol[:Present], getfield(Main, Symbol(\"##5#6\"))())], Dict(:Present => 1,:Detected => 2))" 867 | ] 868 | }, 869 | "execution_count": 19, 870 | "metadata": {}, 871 | "output_type": "execute_result" 872 | } 873 | ], 874 | "source": [ 875 | "b = BayesNet()\n", 876 | "push!(b, StaticCPD(:Present, Bernoulli(0.001)))\n", 877 | "push!(b, FunctionalCPD{Bernoulli}(:Detected, [:Present], a->Bernoulli(a[:Present] == true ? 0.999 : 0.001)))" 878 | ] 879 | }, 880 | { 881 | "cell_type": "code", 882 | "execution_count": 20, 883 | "metadata": {}, 884 | "outputs": [ 885 | { 886 | "data": { 887 | "text/html": [ 888 | "

10 rows × 2 columns

PresentDetected
BoolBool
100
200
300
400
500
600
700
800
900
1000
" 889 | ], 890 | "text/latex": [ 891 | "\\begin{tabular}{r|cc}\n", 892 | "\t& Present & Detected\\\\\n", 893 | "\t\\hline\n", 894 | "\t& Bool & Bool\\\\\n", 895 | "\t\\hline\n", 896 | "\t1 & 0 & 0 \\\\\n", 897 | "\t2 & 0 & 0 \\\\\n", 898 | "\t3 & 0 & 0 \\\\\n", 899 | "\t4 & 0 & 0 \\\\\n", 900 | "\t5 & 0 & 0 \\\\\n", 901 | "\t6 & 0 & 0 \\\\\n", 902 | "\t7 & 0 & 0 \\\\\n", 903 | "\t8 & 0 & 0 \\\\\n", 904 | "\t9 & 0 & 0 \\\\\n", 905 | "\t10 & 0 & 0 \\\\\n", 906 | "\\end{tabular}\n" 907 | ], 908 | "text/plain": [ 909 | "10×2 DataFrame\n", 910 | "│ Row │ Present │ Detected │\n", 911 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mBool\u001b[39m │\n", 912 | "├─────┼─────────┼──────────┤\n", 913 | "│ 1 │ 0 │ 0 │\n", 914 | "│ 2 │ 0 │ 0 │\n", 915 | "│ 3 │ 0 │ 0 │\n", 916 | "│ 4 │ 0 │ 0 │\n", 917 | "│ 5 │ 0 │ 0 │\n", 918 | "│ 6 │ 0 │ 0 │\n", 919 | "│ 7 │ 0 │ 0 │\n", 920 | "│ 8 │ 0 │ 0 │\n", 921 | "│ 9 │ 0 │ 0 │\n", 922 | "│ 10 │ 0 │ 0 │" 923 | ] 924 | }, 925 | "execution_count": 20, 926 | "metadata": {}, 927 | "output_type": "execute_result" 928 | } 929 | ], 930 | "source": [ 931 | "rand(b, 10)" 932 | ] 933 | }, 934 | { 935 | "cell_type": "markdown", 936 | "metadata": {}, 937 | "source": [ 938 | "Not very interesting since all the samples are likely to be (false, false)" 939 | ] 940 | }, 941 | { 942 | "cell_type": "code", 943 | "execution_count": 21, 944 | "metadata": {}, 945 | "outputs": [ 946 | { 947 | "data": { 948 | "text/plain": [ 949 | "1" 950 | ] 951 | }, 952 | "execution_count": 21, 953 | "metadata": {}, 954 | "output_type": "execute_result" 955 | } 956 | ], 957 | "source": [ 958 | "data = rand(b, 1000)\n", 959 | "sum(data[!,:Detected] .== 1)" 960 | ] 961 | }, 962 | { 963 | "cell_type": "markdown", 964 | "metadata": {}, 965 | "source": [ 966 | "Even with 1000 samples, we are not likely to get many samples that are consistent with Detected = true. This can result in a pretty poor estimate." 
967 | ] 968 | }, 969 | { 970 | "cell_type": "code", 971 | "execution_count": 22, 972 | "metadata": {}, 973 | "outputs": [ 974 | { 975 | "data": { 976 | "text/html": [ 977 | "

2 rows × 3 columns

PresentDetectedp
BoolBoolFloat64
1010.493
2110.507
" 978 | ], 979 | "text/plain": [ 980 | "Table(2×3 DataFrame\n", 981 | "│ Row │ Present │ Detected │ p │\n", 982 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mBool\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n", 983 | "├─────┼─────────┼──────────┼─────────┤\n", 984 | "│ 1 │ 0 │ 1 │ 0.493 │\n", 985 | "│ 2 │ 1 │ 1 │ 0.507 │)" 986 | ] 987 | }, 988 | "execution_count": 22, 989 | "metadata": {}, 990 | "output_type": "execute_result" 991 | } 992 | ], 993 | "source": [ 994 | "samples = rand(b, RejectionSampler(:Detected=>true, max_nsamples=100000000), 1000)\n", 995 | "fit(Table, samples)" 996 | ] 997 | }, 998 | { 999 | "cell_type": "markdown", 1000 | "metadata": {}, 1001 | "source": [ 1002 | "### Likelihood weighted sampling" 1003 | ] 1004 | }, 1005 | { 1006 | "cell_type": "code", 1007 | "execution_count": 23, 1008 | "metadata": {}, 1009 | "outputs": [ 1010 | { 1011 | "data": { 1012 | "text/html": [ 1013 | "

5 rows × 3 columns

DetectedPresentp
AnyAnyAny
1100.2
2100.2
3100.2
4100.2
5100.2
" 1014 | ], 1015 | "text/latex": [ 1016 | "\\begin{tabular}{r|ccc}\n", 1017 | "\t& Detected & Present & p\\\\\n", 1018 | "\t\\hline\n", 1019 | "\t& Any & Any & Any\\\\\n", 1020 | "\t\\hline\n", 1021 | "\t1 & 1 & 0 & 0.2 \\\\\n", 1022 | "\t2 & 1 & 0 & 0.2 \\\\\n", 1023 | "\t3 & 1 & 0 & 0.2 \\\\\n", 1024 | "\t4 & 1 & 0 & 0.2 \\\\\n", 1025 | "\t5 & 1 & 0 & 0.2 \\\\\n", 1026 | "\\end{tabular}\n" 1027 | ], 1028 | "text/plain": [ 1029 | "5×3 DataFrame\n", 1030 | "│ Row │ Detected │ Present │ p │\n", 1031 | "│ │ \u001b[90mAny\u001b[39m │ \u001b[90mAny\u001b[39m │ \u001b[90mAny\u001b[39m │\n", 1032 | "├─────┼──────────┼─────────┼─────┤\n", 1033 | "│ 1 │ 1 │ 0 │ 0.2 │\n", 1034 | "│ 2 │ 1 │ 0 │ 0.2 │\n", 1035 | "│ 3 │ 1 │ 0 │ 0.2 │\n", 1036 | "│ 4 │ 1 │ 0 │ 0.2 │\n", 1037 | "│ 5 │ 1 │ 0 │ 0.2 │" 1038 | ] 1039 | }, 1040 | "execution_count": 23, 1041 | "metadata": {}, 1042 | "output_type": "execute_result" 1043 | } 1044 | ], 1045 | "source": [ 1046 | "rand(b, LikelihoodWeightedSampler(Assignment(:Detected=>true)), 5)" 1047 | ] 1048 | }, 1049 | { 1050 | "cell_type": "code", 1051 | "execution_count": 24, 1052 | "metadata": {}, 1053 | "outputs": [ 1054 | { 1055 | "data": { 1056 | "text/html": [ 1057 | "

2 rows × 3 columns

DetectedPresentp
AnyAnyFloat64
1100.476166
2110.523834
" 1058 | ], 1059 | "text/plain": [ 1060 | "Table(2×3 DataFrame\n", 1061 | "│ Row │ Detected │ Present │ p │\n", 1062 | "│ │ \u001b[90mAny\u001b[39m │ \u001b[90mAny\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n", 1063 | "├─────┼──────────┼─────────┼──────────┤\n", 1064 | "│ 1 │ 1 │ 0 │ 0.476166 │\n", 1065 | "│ 2 │ 1 │ 1 │ 0.523834 │)" 1066 | ] 1067 | }, 1068 | "execution_count": 24, 1069 | "metadata": {}, 1070 | "output_type": "execute_result" 1071 | } 1072 | ], 1073 | "source": [ 1074 | "fit(Table, rand(b, LikelihoodWeightedSampler(:Detected=>true), 10000))" 1075 | ] 1076 | } 1077 | ], 1078 | "metadata": { 1079 | "kernelspec": { 1080 | "display_name": "Julia 1.2.0", 1081 | "language": "julia", 1082 | "name": "julia-1.2" 1083 | }, 1084 | "language_info": { 1085 | "file_extension": ".jl", 1086 | "mimetype": "application/julia", 1087 | "name": "julia", 1088 | "version": "1.2.0" 1089 | } 1090 | }, 1091 | "nbformat": 4, 1092 | "nbformat_minor": 1 1093 | } 1094 | -------------------------------------------------------------------------------- /06-DecisionNetworks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Decision Networks" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "using BayesNets\n", 17 | "using DataFrames" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "data": { 27 | "image/svg+xml": [ 28 | "\n", 29 | "\n", 30 | "\n", 31 | "\n", 32 | "\n", 33 | "\n", 34 | "\n", 35 | "\n", 36 | "\n", 37 | "\n", 38 | "\n", 39 | "\n", 40 | "\n", 41 | "\n", 42 | "\n", 43 | "\n", 44 | "\n", 45 | "\n", 46 | "\n", 47 | "\n", 48 | "\n", 49 | "\n", 50 | "\n", 51 | "\n", 52 | "\n", 53 | "\n", 54 | "\n", 55 | "\n", 56 | "\n", 57 | "\n", 58 | "\n", 59 | "\n", 60 | " \n", 61 | "\n", 62 | 
"\n", 63 | " \n", 64 | " \n", 65 | "\n", 66 | "\n", 67 | " \n", 68 | " \n", 69 | "\n", 70 | "\n", 71 | " \n", 72 | " \n", 73 | "\n", 74 | "\n", 75 | "\n", 76 | "\n" 77 | ], 78 | "text/plain": [ 79 | "BayesNet{CPD}({4, 3} directed simple Int64 graph, CPD[StaticCPD{Bernoulli{Float64}}(:D, Symbol[], Bernoulli{Float64}(p=0.01)), FunctionalCPD{Bernoulli}(:O3, Symbol[:D], getfield(Main, Symbol(\"##5#6\"))()), StaticCPD{Bernoulli{Float64}}(:O1, Symbol[:D], Bernoulli{Float64}(p=0.5)), FunctionalCPD{Bernoulli}(:O2, Symbol[:D], getfield(Main, Symbol(\"##3#4\"))())], Dict(:O2 => 4,:D => 1,:O1 => 3,:O3 => 2))" 80 | ] 81 | }, 82 | "execution_count": 2, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "b = BayesNet()\n", 89 | "push!(b, StaticCPD(:D, Bernoulli(0.01)))\n", 90 | "push!(b, StaticCPD(:O1, [:D], Bernoulli(0.5))) # no real signal of whether disease is present\n", 91 | "push!(b, FunctionalCPD{Bernoulli}(:O2, [:D], a->Bernoulli(a[:D] == true ? 0.9 : 0.01)))\n", 92 | "push!(b, FunctionalCPD{Bernoulli}(:O3, [:D], a->Bernoulli(a[:D] == true ? 0.6 : 0.3)))" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 3, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "data": { 102 | "text/html": [ 103 | "

4 rows × 3 columns

TDU
BoolBoolInt64
1000
201-10
310-1
411-1
" 104 | ], 105 | "text/latex": [ 106 | "\\begin{tabular}{r|ccc}\n", 107 | "\t& T & D & U\\\\\n", 108 | "\t\\hline\n", 109 | "\t& Bool & Bool & Int64\\\\\n", 110 | "\t\\hline\n", 111 | "\t1 & 0 & 0 & 0 \\\\\n", 112 | "\t2 & 0 & 1 & -10 \\\\\n", 113 | "\t3 & 1 & 0 & -1 \\\\\n", 114 | "\t4 & 1 & 1 & -1 \\\\\n", 115 | "\\end{tabular}\n" 116 | ], 117 | "text/plain": [ 118 | "4×3 DataFrame\n", 119 | "│ Row │ T │ D │ U │\n", 120 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mBool\u001b[39m │ \u001b[90mInt64\u001b[39m │\n", 121 | "├─────┼──────┼──────┼───────┤\n", 122 | "│ 1 │ 0 │ 0 │ 0 │\n", 123 | "│ 2 │ 0 │ 1 │ -10 │\n", 124 | "│ 3 │ 1 │ 0 │ -1 │\n", 125 | "│ 4 │ 1 │ 1 │ -1 │" 126 | ] 127 | }, 128 | "execution_count": 3, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "U = DataFrame()\n", 135 | "U[!,:T] = [false, false, true, true]\n", 136 | "U[!,:D] = [false, true, false, true]\n", 137 | "U[!,:U] = [0, -10, -1, -1]\n", 138 | "U" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 4, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "using Random\n", 148 | "function estimate_table(b::BayesNet, target::NodeName, consistent_with::Assignment; nsamples = 10000)\n", 149 | " Random.seed!(0)\n", 150 | " t = fit(Table, rand(b, LikelihoodWeightedSampler(consistent_with), nsamples))\n", 151 | " normalize(sumout(t, setdiff(names(b), [target])))\n", 152 | "end;" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 5, 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "text/html": [ 163 | "

2 rows × 2 columns

Dp
AnyFloat64
100.9899
210.0101
" 164 | ], 165 | "text/plain": [ 166 | "Table(2×2 DataFrame\n", 167 | "│ Row │ D │ p │\n", 168 | "│ │ \u001b[90mAny\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n", 169 | "├─────┼─────┼─────────┤\n", 170 | "│ 1 │ 0 │ 0.9899 │\n", 171 | "│ 2 │ 1 │ 0.0101 │)" 172 | ] 173 | }, 174 | "execution_count": 5, 175 | "metadata": {}, 176 | "output_type": "execute_result" 177 | } 178 | ], 179 | "source": [ 180 | "D = estimate_table(b, :D, Assignment(:O1=>true))" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 6, 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "data": { 190 | "text/html": [ 191 | "

4 rows × 4 columns

TDUp
BoolBoolInt64Float64
10000.9899
201-100.0101
310-10.9899
411-10.0101
" 192 | ], 193 | "text/latex": [ 194 | "\\begin{tabular}{r|cccc}\n", 195 | "\t& T & D & U & p\\\\\n", 196 | "\t\\hline\n", 197 | "\t& Bool & Bool & Int64 & Float64\\\\\n", 198 | "\t\\hline\n", 199 | "\t1 & 0 & 0 & 0 & 0.9899 \\\\\n", 200 | "\t2 & 0 & 1 & -10 & 0.0101 \\\\\n", 201 | "\t3 & 1 & 0 & -1 & 0.9899 \\\\\n", 202 | "\t4 & 1 & 1 & -1 & 0.0101 \\\\\n", 203 | "\\end{tabular}\n" 204 | ], 205 | "text/plain": [ 206 | "4×4 DataFrame\n", 207 | "│ Row │ T │ D │ U │ p │\n", 208 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mBool\u001b[39m │ \u001b[90mInt64\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n", 209 | "├─────┼──────┼──────┼───────┼─────────┤\n", 210 | "│ 1 │ 0 │ 0 │ 0 │ 0.9899 │\n", 211 | "│ 2 │ 0 │ 1 │ -10 │ 0.0101 │\n", 212 | "│ 3 │ 1 │ 0 │ -1 │ 0.9899 │\n", 213 | "│ 4 │ 1 │ 1 │ -1 │ 0.0101 │" 214 | ] 215 | }, 216 | "execution_count": 6, 217 | "metadata": {}, 218 | "output_type": "execute_result" 219 | } 220 | ], 221 | "source": [ 222 | "EU = join(U, D.potential, on = :D)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 7, 228 | "metadata": {}, 229 | "outputs": [ 230 | { 231 | "data": { 232 | "text/html": [ 233 | "

2 rows × 2 columns

Tx1
BoolFloat64
10-0.101
21-1.0
" 234 | ], 235 | "text/latex": [ 236 | "\\begin{tabular}{r|cc}\n", 237 | "\t& T & x1\\\\\n", 238 | "\t\\hline\n", 239 | "\t& Bool & Float64\\\\\n", 240 | "\t\\hline\n", 241 | "\t1 & 0 & -0.101 \\\\\n", 242 | "\t2 & 1 & -1.0 \\\\\n", 243 | "\\end{tabular}\n" 244 | ], 245 | "text/plain": [ 246 | "2×2 DataFrame\n", 247 | "│ Row │ T │ x1 │\n", 248 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n", 249 | "├─────┼──────┼─────────┤\n", 250 | "│ 1 │ 0 │ -0.101 │\n", 251 | "│ 2 │ 1 │ -1.0 │" 252 | ] 253 | }, 254 | "execution_count": 7, 255 | "metadata": {}, 256 | "output_type": "execute_result" 257 | } 258 | ], 259 | "source": [ 260 | "using LinearAlgebra\n", 261 | "by(EU, :T, df->LinearAlgebra.dot(df[!,:U], df[!,:p]))" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 8, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "function diseaseEU(b::BayesNet, a::Assignment, U::DataFrame)\n", 271 | " D = estimate_table(b, :D, a).potential\n", 272 | " EU = join(U, D, on = :D)\n", 273 | " t = by(EU, :T, df->LinearAlgebra.dot(df[!,:U], df[!,:p]))\n", 274 | " rename!(t, :x1=>:EU)\n", 275 | " t\n", 276 | "end;" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 9, 282 | "metadata": {}, 283 | "outputs": [ 284 | { 285 | "data": { 286 | "text/html": [ 287 | "

2 rows × 2 columns

TEU
BoolFloat64
10-0.101
21-1.0
" 288 | ], 289 | "text/latex": [ 290 | "\\begin{tabular}{r|cc}\n", 291 | "\t& T & EU\\\\\n", 292 | "\t\\hline\n", 293 | "\t& Bool & Float64\\\\\n", 294 | "\t\\hline\n", 295 | "\t1 & 0 & -0.101 \\\\\n", 296 | "\t2 & 1 & -1.0 \\\\\n", 297 | "\\end{tabular}\n" 298 | ], 299 | "text/plain": [ 300 | "2×2 DataFrame\n", 301 | "│ Row │ T │ EU │\n", 302 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n", 303 | "├─────┼──────┼─────────┤\n", 304 | "│ 1 │ 0 │ -0.101 │\n", 305 | "│ 2 │ 1 │ -1.0 │" 306 | ] 307 | }, 308 | "execution_count": 9, 309 | "metadata": {}, 310 | "output_type": "execute_result" 311 | } 312 | ], 313 | "source": [ 314 | "diseaseEU(b, Assignment(:O1=>true), U)" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 10, 320 | "metadata": {}, 321 | "outputs": [ 322 | { 323 | "data": { 324 | "text/html": [ 325 | "

2 rows × 2 columns

TEU
BoolFloat64
10-0.101
21-1.0
" 326 | ], 327 | "text/latex": [ 328 | "\\begin{tabular}{r|cc}\n", 329 | "\t& T & EU\\\\\n", 330 | "\t\\hline\n", 331 | "\t& Bool & Float64\\\\\n", 332 | "\t\\hline\n", 333 | "\t1 & 0 & -0.101 \\\\\n", 334 | "\t2 & 1 & -1.0 \\\\\n", 335 | "\\end{tabular}\n" 336 | ], 337 | "text/plain": [ 338 | "2×2 DataFrame\n", 339 | "│ Row │ T │ EU │\n", 340 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n", 341 | "├─────┼──────┼─────────┤\n", 342 | "│ 1 │ 0 │ -0.101 │\n", 343 | "│ 2 │ 1 │ -1.0 │" 344 | ] 345 | }, 346 | "execution_count": 10, 347 | "metadata": {}, 348 | "output_type": "execute_result" 349 | } 350 | ], 351 | "source": [ 352 | "diseaseEU(b, Assignment(:O1=>false), U)" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": 11, 358 | "metadata": {}, 359 | "outputs": [ 360 | { 361 | "data": { 362 | "text/html": [ 363 | "

2 rows × 2 columns

TEU
BoolFloat64
10-4.78698
21-1.0
" 364 | ], 365 | "text/latex": [ 366 | "\\begin{tabular}{r|cc}\n", 367 | "\t& T & EU\\\\\n", 368 | "\t\\hline\n", 369 | "\t& Bool & Float64\\\\\n", 370 | "\t\\hline\n", 371 | "\t1 & 0 & -4.78698 \\\\\n", 372 | "\t2 & 1 & -1.0 \\\\\n", 373 | "\\end{tabular}\n" 374 | ], 375 | "text/plain": [ 376 | "2×2 DataFrame\n", 377 | "│ Row │ T │ EU │\n", 378 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n", 379 | "├─────┼──────┼──────────┤\n", 380 | "│ 1 │ 0 │ -4.78698 │\n", 381 | "│ 2 │ 1 │ -1.0 │" 382 | ] 383 | }, 384 | "execution_count": 11, 385 | "metadata": {}, 386 | "output_type": "execute_result" 387 | } 388 | ], 389 | "source": [ 390 | "diseaseEU(b, Assignment(:O2=>true), U)" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": 12, 396 | "metadata": {}, 397 | "outputs": [ 398 | { 399 | "data": { 400 | "text/html": [ 401 | "

2 rows × 2 columns

TEU
BoolFloat64
10-0.19998
21-1.0
" 402 | ], 403 | "text/latex": [ 404 | "\\begin{tabular}{r|cc}\n", 405 | "\t& T & EU\\\\\n", 406 | "\t\\hline\n", 407 | "\t& Bool & Float64\\\\\n", 408 | "\t\\hline\n", 409 | "\t1 & 0 & -0.19998 \\\\\n", 410 | "\t2 & 1 & -1.0 \\\\\n", 411 | "\\end{tabular}\n" 412 | ], 413 | "text/plain": [ 414 | "2×2 DataFrame\n", 415 | "│ Row │ T │ EU │\n", 416 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n", 417 | "├─────┼──────┼──────────┤\n", 418 | "│ 1 │ 0 │ -0.19998 │\n", 419 | "│ 2 │ 1 │ -1.0 │" 420 | ] 421 | }, 422 | "execution_count": 12, 423 | "metadata": {}, 424 | "output_type": "execute_result" 425 | } 426 | ], 427 | "source": [ 428 | "diseaseEU(b, Assignment(:O3=>true), U)" 429 | ] 430 | } 431 | ], 432 | "metadata": { 433 | "kernelspec": { 434 | "display_name": "Julia 1.2.0", 435 | "language": "julia", 436 | "name": "julia-1.2" 437 | }, 438 | "language_info": { 439 | "file_extension": ".jl", 440 | "mimetype": "application/julia", 441 | "name": "julia", 442 | "version": "1.2.0" 443 | } 444 | }, 445 | "nbformat": 4, 446 | "nbformat_minor": 1 447 | } 448 | -------------------------------------------------------------------------------- /12-ModelFreeReinforcementLearning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "include(\"gridworld.jl\")\n", 10 | "g = DMUGridWorld();" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": {}, 16 | "source": [ 17 | "Let's apply Q-learning from Algorithm 5.3 in the text. 
We'll train over 1000 100-step runs:" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "data": { 27 | "text/plain": [ 28 | "Qlearn (generic function with 1 method)" 29 | ] 30 | }, 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "output_type": "execute_result" 34 | } 35 | ], 36 | "source": [ 37 | "function Qlearn(g, alpha, epsilon)\n", 38 | " # Q maps a state index to its vector of action values (one entry per action)\n", 39 | " Q = Dict{Int, Vector{Float64}}()\n", 40 | " \n", 41 | " # initialize Q-values at initial state (s = 1)\n", 42 | " Q[1] = zeros(n_actions(g))\n", 43 | " \n", 44 | " # 1000 simulations, each restarted from s = 1\n", 45 | " for k = 1:1000\n", 46 | " s = 1\n", 47 | " for t = 0:100\n", 48 | " # epsilon-greedy: take the greedy action, or with probability epsilon a uniformly random one\n", 49 | " a_idx = findmax(Q[s])[2]\n", 50 | " if rand() < epsilon\n", 51 | " a_idx = rand(1:n_actions(g))\n", 52 | " end\n", 53 | " a = actions(g)[a_idx]\n", 54 | "\n", 55 | " # observe new state s_{t+1} and reward rt\n", 56 | " sp, r = simulate(g, s, a)\n", 57 | "\n", 58 | " # if we've never observed this state, initialize it to zeros\n", 59 | " if !haskey(Q, sp)\n", 60 | " Q[sp] = zeros(n_actions(g))\n", 61 | " end\n", 62 | "\n", 63 | " # Q-learning update: move Q[s][a_idx] toward r + discount * max over next-state values, with step size alpha\n", 64 | " Q[s][a_idx] += alpha * ( r + discount(g)*maximum(Q[sp]) - Q[s][a_idx] )\n", 65 | "\n", 66 | " # update s\n", 67 | " s = sp\n", 68 | " \n", 69 | " # 73 and 88 are terminal states. Just quit if we get in them.\n", 70 | " if s == 73 || s == 88\n", 71 | " break\n", 72 | " end\n", 73 | " end\n", 74 | " end\n", 75 | " \n", 76 | " return Q\n", 77 | "end" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 3, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "Q = Qlearn(g, 0.5, 0.5);" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "Did the Q-learning work? Let's compare it to a random policy during 10 simulations."
94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 4, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "Q-learned policy: -209947\n", 106 | "random poilcy: -1531508\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "using Random\n", 112 | "Random.seed!(1) # for reproducibility, seed random number generator\n", 113 | "\n", 114 | "r_sum = 0.0 # discounted reward sum for policy from Q-learning\n", 115 | "rr_sum = 0.0 # discounted reward sum for random policy\n", 116 | "\n", 117 | "# run 10 simulations\n", 118 | "for k = 1:10\n", 119 | " global r_sum, rr_sum\n", 120 | " s = 1 # initial state for policy from Q-learning\n", 121 | " sr = 1 # initial state for random policy\n", 122 | " \n", 123 | " for t = 0:100\n", 124 | " \n", 125 | " # generate actions for both policies\n", 126 | " a = actions(g)[findmax(Q[s])[2]]\n", 127 | " ar = actions(g)[rand(1:n_actions(g))]\n", 128 | " \n", 129 | " # advance Q simulation if you aren't in a terminal state\n", 130 | " if s != 73 && s != 88\n", 131 | " sp, r = simulate(g, s, a)\n", 132 | " r_sum += r * discount(g) ^ t # weight the step-t reward by discount^t\n", 133 | " s = sp\n", 134 | " end\n", 135 | " \n", 136 | " # advance random simulation if you aren't in a terminal state\n", 137 | " if sr != 73 && sr != 88\n", 138 | " spr, rr = simulate(g, sr, ar)\n", 139 | " rr_sum += rr * discount(g) ^ t # weight the step-t reward by discount^t\n", 140 | " sr = spr\n", 141 | " end\n", 142 | " end\n", 143 | "end\n", 144 | "\n", 145 | "println(\"Q-learned policy: \", round(Int, r_sum))\n", 146 | "println(\"random policy: \", round(Int, rr_sum))" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "The cumulative sum from Q-learning is much better."
154 | ] 155 | } 156 | ], 157 | "metadata": { 158 | "@webio": { 159 | "lastCommId": null, 160 | "lastKernelId": null 161 | }, 162 | "kernelspec": { 163 | "display_name": "Julia 1.2.0", 164 | "language": "julia", 165 | "name": "julia-1.2" 166 | }, 167 | "language_info": { 168 | "file_extension": ".jl", 169 | "mimetype": "application/julia", 170 | "name": "julia", 171 | "version": "1.2.0" 172 | } 173 | }, 174 | "nbformat": 4, 175 | "nbformat_minor": 1 176 | } 177 | -------------------------------------------------------------------------------- /13-StateUncertainty.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This example shows how to perform the discrete belief update discussed in section 6.2 of the course text.\n", 8 | "\n", 9 | "Read over the description of the baby problem before seeing how to express it in math below.\n", 10 | "\n", 11 | "Let's start by defining the transition function:" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "data": { 21 | "text/plain": [ 22 | "T (generic function with 1 method)" 23 | ] 24 | }, 25 | "execution_count": 1, 26 | "metadata": {}, 27 | "output_type": "execute_result" 28 | } 29 | ], 30 | "source": [ 31 | "function T(s, a, sp)\n", 32 | " \n", 33 | " # if we feed the baby, probability that it becomes not hungry is 1.0\n", 34 | " if a == :feed\n", 35 | " if sp == :not_hungry\n", 36 | " return 1.0\n", 37 | " else\n", 38 | " return 0.0\n", 39 | " end\n", 40 | " \n", 41 | " # if we don't feed baby...\n", 42 | " else\n", 43 | " # baby remains hungry if unfed\n", 44 | " if s == :hungry\n", 45 | " if sp == :hungry\n", 46 | " return 1.0\n", 47 | " else\n", 48 | " return 0.0\n", 49 | " end\n", 50 | " else\n", 51 | " # 10% chance of baby becoming hungry given it is not hungry and unfed\n", 52 | " if sp == :hungry\n", 53 | " 
return 0.1\n", 54 | " else\n", 55 | " return 0.9\n", 56 | " end\n", 57 | " end\n", 58 | " end\n", 59 | " \n", 60 | "end" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "Let's define the observation function:" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 2, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "O (generic function with 1 method)" 79 | ] 80 | }, 81 | "execution_count": 2, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | } 85 | ], 86 | "source": [ 87 | "function O(a, sp, o)\n", 88 | " if sp == :hungry\n", 89 | " p_cry = 0.8\n", 90 | " else\n", 91 | " p_cry = 0.1\n", 92 | " end\n", 93 | " \n", 94 | " if o == :cry\n", 95 | " return p_cry\n", 96 | " else\n", 97 | " return 1.0 - p_cry\n", 98 | " end \n", 99 | "end" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "The discrete belief update is defined in equations 6.7-6.11 of the course text:" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 3, 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/plain": [ 117 | "update_belief (generic function with 1 method)" 118 | ] 119 | }, 120 | "execution_count": 3, 121 | "metadata": {}, 122 | "output_type": "execute_result" 123 | } 124 | ], 125 | "source": [ 126 | "function update_belief(b, a, o)\n", 127 | " bp = Dict()\n", 128 | " for sp in [:hungry, :not_hungry]\n", 129 | " sum_over_s = 0.0\n", 130 | " for s in [:hungry, :not_hungry]\n", 131 | " sum_over_s += T(s, a, sp) * b[s]\n", 132 | " end\n", 133 | " bp[sp] = O(a, sp, o) * sum_over_s\n", 134 | " end\n", 135 | "\n", 136 | " # normalize so that probabilities sum to 1\n", 137 | " bp_sum = bp[:hungry] + bp[:not_hungry]\n", 138 | " bp[:hungry] = bp[:hungry] / bp_sum\n", 139 | " bp[:not_hungry] = bp[:not_hungry] / bp_sum\n", 140 | "\n", 141 | " return bp\n", 142 | "end" 143 | ] 144 | }, 145 | 
{ 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "Let's use our functions and follow the example in chapter 6.2.1 of the course textbook.\n", 150 | "\n", 151 | "Step 1. We start with a uniform belief:" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 4, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "data": { 161 | "text/plain": [ 162 | "0.5" 163 | ] 164 | }, 165 | "execution_count": 4, 166 | "metadata": {}, 167 | "output_type": "execute_result" 168 | } 169 | ], 170 | "source": [ 171 | "b1 = Dict()\n", 172 | "b1[:hungry] = 0.5\n", 173 | "b1[:not_hungry] = 0.5" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "Step 2. We do not feed the baby and the baby cries." 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 5, 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "data": { 190 | "text/plain": [ 191 | "Dict{Any,Any} with 2 entries:\n", 192 | " :not_hungry => 0.0927835\n", 193 | " :hungry => 0.907216" 194 | ] 195 | }, 196 | "execution_count": 5, 197 | "metadata": {}, 198 | "output_type": "execute_result" 199 | } 200 | ], 201 | "source": [ 202 | "b2 = update_belief(b1, :not_feed, :cry)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "Step 3. We feed the baby and it stops crying." 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 6, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "data": { 219 | "text/plain": [ 220 | "Dict{Any,Any} with 2 entries:\n", 221 | " :not_hungry => 1.0\n", 222 | " :hungry => 0.0" 223 | ] 224 | }, 225 | "execution_count": 6, 226 | "metadata": {}, 227 | "output_type": "execute_result" 228 | } 229 | ], 230 | "source": [ 231 | "b3 = update_belief(b2, :feed, :not_cry)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "Step 4. 
We do not feed the baby and it does not cry." 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 7, 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "data": { 248 | "text/plain": [ 249 | "Dict{Any,Any} with 2 entries:\n", 250 | " :not_hungry => 0.975904\n", 251 | " :hungry => 0.0240964" 252 | ] 253 | }, 254 | "execution_count": 7, 255 | "metadata": {}, 256 | "output_type": "execute_result" 257 | } 258 | ], 259 | "source": [ 260 | "b4 = update_belief(b3, :not_feed, :not_cry)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": {}, 266 | "source": [ 267 | "Step 5. Again, we do not feed the baby and it does not cry." 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 8, 273 | "metadata": {}, 274 | "outputs": [ 275 | { 276 | "data": { 277 | "text/plain": [ 278 | "Dict{Any,Any} with 2 entries:\n", 279 | " :not_hungry => 0.970132\n", 280 | " :hungry => 0.0298684" 281 | ] 282 | }, 283 | "execution_count": 8, 284 | "metadata": {}, 285 | "output_type": "execute_result" 286 | } 287 | ], 288 | "source": [ 289 | "b5 = update_belief(b4, :not_feed, :not_cry)" 290 | ] 291 | }, 292 | { 293 | "cell_type": "markdown", 294 | "metadata": {}, 295 | "source": [ 296 | "Step 6. We do not feed the baby and the baby begins to cry." 
297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 9, 302 | "metadata": {}, 303 | "outputs": [ 304 | { 305 | "data": { 306 | "text/plain": [ 307 | "Dict{Any,Any} with 2 entries:\n", 308 | " :not_hungry => 0.462415\n", 309 | " :hungry => 0.537585" 310 | ] 311 | }, 312 | "execution_count": 9, 313 | "metadata": {}, 314 | "output_type": "execute_result" 315 | } 316 | ], 317 | "source": [ 318 | "b6 = update_belief(b5, :not_feed, :cry)" 319 | ] 320 | } 321 | ], 322 | "metadata": { 323 | "@webio": { 324 | "lastCommId": null, 325 | "lastKernelId": null 326 | }, 327 | "kernelspec": { 328 | "display_name": "Julia 1.2.0", 329 | "language": "julia", 330 | "name": "julia-1.2" 331 | }, 332 | "language_info": { 333 | "file_extension": ".jl", 334 | "mimetype": "application/julia", 335 | "name": "julia", 336 | "version": "1.2.0" 337 | } 338 | }, 339 | "nbformat": 4, 340 | "nbformat_minor": 1 341 | } 342 | -------------------------------------------------------------------------------- /Project.toml: -------------------------------------------------------------------------------- 1 | [deps] 2 | BasicPOMCP = "d721219e-3fc6-5570-a8ef-e5402f47c49e" 3 | BayesNets = "ba4760a4-c768-5bed-964b-cf806dc591cb" 4 | BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4" 5 | ContinuumWorld = "5cbb95a3-277b-5373-895a-7e14bd91b3cc" 6 | D3Trees = "e3df1716-f71e-5df9-9e2d-98e193103c45" 7 | DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" 8 | DiscreteValueIteration = "4b033969-44f6-5439-a48b-c11fa3648068" 9 | Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" 10 | GridInterpolations = "bb4c363b-b914-514b-8517-4eb369bc008a" 11 | IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" 12 | Interact = "c601a237-2ae4-5e1e-952c-7a85b0c7eef1" 13 | LaserTag = "041f53e1-e4f8-54ec-814d-e9e995aa38d4" 14 | MCTS = "e12ccd36-dcad-5f33-8774-9175229e7b33" 15 | NBInclude = "0db19996-df87-5ea3-a455-e3a50d440464" 16 | PGFPlots = "3b7a836e-365b-5785-a47d-02c71176b4aa" 17 | POMDPModelTools = 
"08074719-1b2a-587c-a292-00f91cc44415" 18 | POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca" 19 | POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4" 20 | POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd" 21 | POMDPToolbox = "0729bffe-8e6b-52fa-a3fa-893719b744f4" 22 | POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" 23 | ParticleFilters = "c8b314e2-9260-5cf8-ae76-3be7461ca6d0" 24 | Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" 25 | PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" 26 | PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" 27 | RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b" 28 | Reactive = "a223df75-4e93-5b7c-acf9-bdd599c0f4de" 29 | SARSOP = "cef570c6-3a94-5604-96b7-1a5e143043f2" 30 | StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" 31 | TikzPictures = "37f6aa50-8035-52d0-81c2-5a1d08754b2d" 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AA228 Notebooks 2 | 3 | [![Build Status](https://travis-ci.org/sisl/aa228-notebook.svg)](https://travis-ci.org/sisl/aa228-notebook) 4 | 5 | These notebooks are used for [AA228/CS238: Decision Making under Uncertainty](https://aa228.stanford.edu) taught by [Mykel Kochenderfer](https://mykel.kochenderfer.com) at Stanford University. 
6 | -------------------------------------------------------------------------------- /alpha_plots.jl: -------------------------------------------------------------------------------- 1 | using PGFPlots 2 | 3 | alpha2vec(alpha::Dict) = [ alpha[:not_hungry], alpha[:hungry] ] 4 | 5 | function plot(alpha::Dict) 6 | Plots.Linear([0,1], alpha2vec(alpha)) 7 | end 8 | 9 | function plot(alphas::Vector{Dict{Symbol, Float64}}) 10 | plot_array = Plots.Linear[] 11 | for alpha in alphas 12 | push!(plot_array, Plots.Linear([0,1], alpha2vec(alpha), style="red,solid,thick", mark="none") ) 13 | end 14 | #return plot_array 15 | Axis(plot_array, xlabel="P(hungry=true)", xmin=0,xmax=1) 16 | end 17 | -------------------------------------------------------------------------------- /baby.jl: -------------------------------------------------------------------------------- 1 | function T(s, a, sp) 2 | 3 | # if we feed the baby, probability that it becomes not hungry is 1.0 4 | if a == :feed 5 | if sp == :not_hungry 6 | return 1.0 7 | else 8 | return 0.0 9 | end 10 | 11 | # if we don't feed baby... 
12 | else 13 | # baby remains hungry if unfed 14 | if s == :hungry 15 | if sp == :hungry 16 | return 1.0 17 | else 18 | return 0.0 19 | end 20 | else 21 | # 10% chance of baby becoming hungry given it is not hungry and unfed 22 | if sp == :hungry 23 | return 0.1 24 | else 25 | return 0.9 26 | end 27 | end 28 | end 29 | 30 | end 31 | 32 | function O(a, sp, o) 33 | if sp == :hungry 34 | p_cry = 0.8 35 | else 36 | p_cry = 0.1 37 | end 38 | 39 | if o == :cry 40 | return p_cry 41 | else 42 | return 1.0 - p_cry 43 | end 44 | end 45 | 46 | function update_belief(b, a, o) 47 | bp = Dict() 48 | for sp in [:hungry, :not_hungry] 49 | sum_over_s = 0.0 50 | for s in [:hungry, :not_hungry] 51 | sum_over_s += T(s, a, sp) * b[s] 52 | end 53 | bp[sp] = O(a, sp, o) * sum_over_s 54 | end 55 | 56 | # normalize so that probabilities sum to 1 57 | bp_sum = bp[:hungry] + bp[:not_hungry] 58 | bp[:hungry] = bp[:hungry] / bp_sum 59 | bp[:not_hungry] = bp[:not_hungry] / bp_sum 60 | 61 | return bp 62 | end 63 | -------------------------------------------------------------------------------- /bandits.jl: -------------------------------------------------------------------------------- 1 | using Printf 2 | using Random 3 | using PGFPlots 4 | 5 | mutable struct Bandit 6 | θ::Vector{Float64} # true bandit probabilities 7 | end 8 | Bandit(k::Integer) = Bandit(rand(k)) 9 | pull(b::Bandit, i::Integer) = rand() < b.θ[i] 10 | numArms(b::Bandit) = length(b.θ) 11 | 12 | function _get_string_list_of_percentages(bandit_odds::Vector{R}) where {R<:Real} 13 | strings = map(θ->Printf.@sprintf("%.2f percent", 100θ), bandit_odds) 14 | retval = strings[1] 15 | for i in 2 : length(strings) 16 | retval = retval * ", " * strings[i] 17 | end 18 | retval 19 | end 20 | 21 | function banditTrial(b) 22 | 23 | for i in 1 : numArms(b) 24 | but=button("Arm $i",value=0) 25 | display(but) 26 | wins=Observable(0) 27 | Interact.@on &but>0 ? 
(wins[] = wins[]+pull(b,i)) : 0 28 | display(map(s -> Printf.@sprintf("%d wins out of %d tries (%d percent)", wins[], but[], 100*wins[]/but[]), but)) 29 | # NOTE: we used to use the latex() wrapper 30 | end 31 | 32 | t = togglebuttons(["Hide", "Show"], value="Hide", label="True Params") 33 | display(t) 34 | display(map(v -> v == "Show" ? _get_string_list_of_percentages(b.θ) : "", t)) 35 | end 36 | 37 | function banditEstimation(b) 38 | B = [button("Arm $i") for i = 1:numArms(b)] 39 | for i in 1 : numArms(b) 40 | but=button("Arm $i",value=0) 41 | display(but) 42 | wins=Observable(0) 43 | Interact.@on &but>0 ? (wins[] = wins[]+pull(b,i)) : 0 44 | display(map(s -> Printf.@sprintf("%d wins out of %d tries (%d percent)", wins[], but[], 100*wins[]/but[]), but)) 45 | display(map(s -> begin 46 | w = wins[] 47 | t = but[] 48 | Axis([ 49 | Plots.Linear(θ->pdf(Beta(w+1, t-w+1), θ), (0,1), legendentry="Beta($(w+1), $(t-w+1))") 50 | ], 51 | xmin=0,xmax=1,ymin=0, width="15cm", height="10cm") 52 | end, but 53 | )) 54 | end 55 | t = togglebuttons(["Hide", "Show"], value="Hide", label="True Params") 56 | display(t) 57 | display(map(v -> v == "Show" ? 
string(b.θ) : "", t)) 58 | end 59 | 60 | mutable struct BanditStatistics 61 | numWins::Vector{Int} 62 | numTries::Vector{Int} 63 | BanditStatistics(k::Int) = new(zeros(k), zeros(k)) 64 | end 65 | numArms(b::BanditStatistics) = length(b.numWins) 66 | function update!(b::BanditStatistics, i::Int, success::Bool) 67 | b.numTries[i] += 1 68 | if success 69 | b.numWins[i] += 1 70 | end 71 | end 72 | # win probability assuming uniform prior 73 | winProbabilities(b::BanditStatistics) = (b.numWins .+ 1)./(b.numTries .+ 2) 74 | 75 | abstract type BanditPolicy end 76 | 77 | reset!(p::BanditPolicy) = nothing 78 | 79 | function simulate(b::Bandit, policy::BanditPolicy; steps = 10) 80 | wins = zeros(Int, steps) 81 | s = BanditStatistics(numArms(b)) 82 | reset!(policy) 83 | for step = 1:steps 84 | i = arm(policy, s) 85 | win = pull(b, i) 86 | update!(s, i, win) 87 | wins[step] = win 88 | end 89 | wins 90 | end 91 | 92 | function simulateAverage(b::Bandit, policy::BanditPolicy; steps = 10, iterations = 10) 93 | ret = zeros(Int, steps) 94 | for i = 1:iterations 95 | ret .+= simulate(b, policy, steps=steps) 96 | end 97 | ret ./ iterations 98 | end 99 | 100 | function learningCurves(b::Bandit, policies; steps=10, iterations=10) 101 | lines = Plots.Linear[] 102 | for (name, policy) in policies 103 | results = simulateAverage(b, policy; steps=steps, iterations=iterations) 104 | push!(lines, Plots.Linear(results, legendentry=name, style="very thick", mark="none")) 105 | end 106 | return lines 107 | end 108 | -------------------------------------------------------------------------------- /gridworld.jl: -------------------------------------------------------------------------------- 1 | using POMDPs 2 | 3 | # Problem based on https://www.cs.ubc.ca/~poole/demos/mdp/vi.html 4 | 5 | using TikzPictures 6 | using Printf 7 | 8 | mutable struct DMUGridWorld <: MDP{Int, Symbol} 9 | S::Vector{Int} 10 | A::Vector{Symbol} 11 | T::Array{Float64,3} 12 | R::Matrix{Float64} 13 | discount::Float64 14 | 
actionIndex::Dict{Symbol, Int} 15 | nextStates::Dict{Tuple{Int, Symbol}, Vector{Int}} 16 | end 17 | 18 | actions(g::DMUGridWorld) = g.A 19 | states(g::DMUGridWorld) = g.S 20 | n_actions(g::DMUGridWorld) = length(g.A) 21 | n_states(g::DMUGridWorld) = length(g.S) 22 | reward(g::DMUGridWorld, s::Int, a::Symbol) = g.R[s, g.actionIndex[a]] 23 | transition_pdf(g::DMUGridWorld, s0::Int, a::Symbol, s1::Int) = g.T[s0, g.actionIndex[a], s1] 24 | discount(g::DMUGridWorld) = g.discount 25 | next_states(g::DMUGridWorld, s, a) = g.nextStates[(s, a)] 26 | state_index(g::DMUGridWorld, s) = s 27 | action_index(g::DMUGridWorld, a) = g.actionIndex[a] 28 | 29 | function locals(mdp::MDP) 30 | S = states(mdp) 31 | A = actions(mdp) 32 | T = (s0, a, s1) -> transition_pdf(mdp, s0, a, s1) 33 | R = (s, a) -> reward(mdp, s, a) 34 | gamma = discount(mdp) 35 | (S, A, T, R, gamma) 36 | end 37 | 38 | s2xy(s) = Tuple(CartesianIndices((10,10))[s]) 39 | 40 | function xy2s(x, y) 41 | x = max(x, 1) 42 | y = max(y, 1) 43 | x = min(x, 10) 44 | y = min(y, 10) 45 | LinearIndices((10, 10))[x,y] 46 | end 47 | 48 | function DMUGridWorld() 49 | A = [:left, :right, :up, :down] 50 | S = 1:100 51 | T = zeros(length(S), length(A), length(S)) 52 | R = zeros(length(S), length(A)) 53 | for s in S 54 | (x, y) = s2xy(s) 55 | if x == 3 && y == 8 56 | R[s, :] .= 3 57 | elseif x == 8 && y == 9 58 | R[s, :] .= 10 59 | else 60 | if x == 8 && y == 4 61 | R[s, :] .= -10 62 | elseif x == 5 && y == 4 63 | R[s, :] .= -5 64 | elseif x == 1 65 | if y == 1 || y == 10 66 | R[s, :] .= -0.2 67 | else 68 | R[s, :] .= -0.1 69 | end 70 | 71 | R[s, 3] = -0.7 72 | elseif x == 10 73 | if y == 1 || y == 10 74 | R[s, :] .= -0.2 75 | else 76 | R[s, :] .= -0.1 77 | end 78 | R[s, 4] = -0.7 79 | elseif y == 1 80 | if x == 1 || x == 10 81 | R[s, :] .= -0.2 82 | else 83 | R[s, :] .= -0.1 84 | end 85 | R[s, 1] = -0.7 86 | elseif y == 10 87 | if x == 1 || x == 10 88 | R[s, :] .= -0.2 89 | else 90 | R[s, :] .= -0.1 91 | end 92 | R[s, 2] = -0.7 93 | 
end 94 | for a in A 95 | if a == :left 96 | T[s, 1, xy2s(x, y - 1)] += 0.7 97 | T[s, 1, xy2s(x, y + 1)] += 0.1 98 | T[s, 1, xy2s(x - 1, y)] += 0.1 99 | T[s, 1, xy2s(x + 1, y)] += 0.1 100 | elseif a == :right 101 | T[s, 2, xy2s(x, y + 1)] += 0.7 102 | T[s, 2, xy2s(x, y - 1)] += 0.1 103 | T[s, 2, xy2s(x - 1, y)] += 0.1 104 | T[s, 2, xy2s(x + 1, y)] += 0.1 105 | elseif a == :up 106 | T[s, 3, xy2s(x - 1, y)] += 0.7 107 | T[s, 3, xy2s(x + 1, y)] += 0.1 108 | T[s, 3, xy2s(x, y - 1)] += 0.1 109 | T[s, 3, xy2s(x, y + 1)] += 0.1 110 | elseif a == :down 111 | T[s, 4, xy2s(x + 1, y)] += 0.7 112 | T[s, 4, xy2s(x - 1, y)] += 0.1 113 | T[s, 4, xy2s(x, y - 1)] += 0.1 114 | T[s, 4, xy2s(x, y + 1)] += 0.1 115 | end 116 | end 117 | end 118 | end 119 | R[1,1] = -0.8 120 | R[10,1] = -0.8 121 | R[91,2] = -0.8 122 | R[100,2] = -0.8 123 | R[1,3] = -0.8 124 | R[91,3] = -0.8 125 | R[10,4] = -0.8 126 | R[100,4] = -0.8 127 | discount = 0.9 128 | nextStates = Dict([(S[si], A[ai])=>findall(x->x!=0, T[si, ai, :]) for si=1:length(S), ai=1:length(A)]) 129 | DMUGridWorld(S, A, T, R, discount, Dict([A[i]=>i for i=1:length(A)]), nextStates) 130 | end 131 | 132 | function colorval(val, brightness::Real = 1.0) 133 | val = convert(Vector{Float64}, val) 134 | x = 255 .- min.(255, 255 * (abs.(val) ./ 10.0) .^ brightness) 135 | r = 255 * ones(size(val)) 136 | g = 255 * ones(size(val)) 137 | b = 255 * ones(size(val)) 138 | r[val .>= 0] .= x[val .>= 0] 139 | b[val .>= 0] .= x[val .>= 0] 140 | g[val .< 0] .= x[val .< 0] 141 | b[val .< 0] .= x[val .< 0] 142 | (r, g, b) 143 | end 144 | 145 | function plot(g::DMUGridWorld, f::Function) 146 | V = map(f, g.S) 147 | plot(g, V) 148 | end 149 | 150 | function plot(obj::DMUGridWorld, V::Vector; curState=0) 151 | o = IOBuffer() 152 | sqsize = 1.0 153 | twid = 0.05 154 | (r, g, b) = colorval(V) 155 | for s = obj.S 156 | (yval, xval) = s2xy(s) 157 | yval = 10 - yval 158 | println(o, "\\definecolor{currentcolor}{RGB}{$(r[s]),$(g[s]),$(b[s])}") 159 | println(o, 
"\\fill[currentcolor] ($((xval-1) * sqsize),$((yval) * sqsize)) rectangle +($sqsize,$sqsize);") 160 | if s == curState 161 | println(o, "\\fill[orange] ($((xval-1) * sqsize),$((yval) * sqsize)) rectangle +($sqsize,$sqsize);") 162 | end 163 | vs = Printf.@sprintf("%0.2f", V[s]) 164 | println(o, "\\node[above right] at ($((xval-1) * sqsize), $((yval) * sqsize)) {\$$(vs)\$};") 165 | end 166 | println(o, "\\draw[black] grid(10,10);") 167 | tikzDeleteIntermediate(false) 168 | TikzPicture(String(take!(o)), options="scale=1.25") 169 | end 170 | 171 | function plot(g::DMUGridWorld, f::Function, policy::Function; curState=0) 172 | V = map(f, g.S) 173 | plot(g, V, policy, curState=curState) 174 | end 175 | 176 | function plot(obj::DMUGridWorld, V::Vector, policy::Function; curState=0) 177 | P = map(policy, obj.S) 178 | plot(obj, V, P, curState=curState) 179 | end 180 | 181 | function plot(obj::DMUGridWorld, V::Vector, policy::Vector; curState=0) 182 | o = IOBuffer() 183 | sqsize = 1.0 184 | twid = 0.05 185 | (r, g, b) = colorval(V) 186 | for s in obj.S 187 | (yval, xval) = s2xy(s) 188 | yval = 10 - yval 189 | println(o, "\\definecolor{currentcolor}{RGB}{$(r[s]),$(g[s]),$(b[s])}") 190 | println(o, "\\fill[currentcolor] ($((xval-1) * sqsize),$((yval) * sqsize)) rectangle +($sqsize,$sqsize);") 191 | if s == curState 192 | println(o, "\\fill[orange] ($((xval-1) * sqsize),$((yval) * sqsize)) rectangle +($sqsize,$sqsize);") 193 | end 194 | end 195 | println(o, "\\begin{scope}[fill=gray]") 196 | for s in obj.S 197 | (yval, xval) = s2xy(s) 198 | yval = 10 - yval + 1 199 | c = [xval, yval] * sqsize .- sqsize / 2 200 | C = [c'; c'; c']' 201 | RightArrow = [0 0 sqsize/2; twid -twid 0] 202 | if policy[s] == :left 203 | A = [-1 0; 0 -1] * RightArrow + C 204 | println(o, "\\fill ($(A[1]), $(A[2])) -- ($(A[3]), $(A[4])) -- ($(A[5]), $(A[6])) -- cycle;") 205 | end 206 | if policy[s] == :right 207 | A = RightArrow + C 208 | println(o, "\\fill ($(A[1]), $(A[2])) -- ($(A[3]), $(A[4])) -- 
($(A[5]), $(A[6])) -- cycle;") 209 | end 210 | if policy[s] == :up 211 | A = [0 -1; 1 0] * RightArrow + C 212 | println(o, "\\fill ($(A[1]), $(A[2])) -- ($(A[3]), $(A[4])) -- ($(A[5]), $(A[6])) -- cycle;") 213 | end 214 | if policy[s] == :down 215 | A = [0 1; -1 0] * RightArrow + C 216 | println(o, "\\fill ($(A[1]), $(A[2])) -- ($(A[3]), $(A[4])) -- ($(A[5]), $(A[6])) -- cycle;") 217 | end 218 | 219 | vs = Printf.@sprintf("%0.2f", V[s]) 220 | println(o, "\\node[above right] at ($((xval-1) * sqsize), $((yval-1) * sqsize)) {\$$(vs)\$};") 221 | end 222 | println(o, "\\end{scope}"); 223 | println(o, "\\draw[black] grid(10,10);"); 224 | TikzPicture(String(take!(o)), options="scale=1.25") 225 | end 226 | 227 | # simulates taking action a from s 228 | function simulate(g::DMUGridWorld, s::Int, a::Symbol) 229 | probs = Float64[] 230 | if length(next_states(g,s,a)) == 0 231 | println("s = ", s) 232 | println("a = ", a) 233 | end 234 | for sp in next_states(g, s, a) 235 | push!(probs, transition_pdf(g, s, a, sp) ) 236 | end 237 | 238 | # make sure these sum to 1. They should, but let's be safe. 
239 | probs = probs / sum(probs) 240 | 241 | # sample a random value from next states 242 | rand_val = rand() 243 | sampled_idx = 1 244 | prob_sum = 0.0 245 | i = 1 246 | while true 247 | prob_sum += probs[i] 248 | if rand_val < prob_sum 249 | sampled_idx = i 250 | break 251 | end 252 | i += 1 253 | end 254 | sp = next_states(g,s,a)[sampled_idx] 255 | 256 | return sp, reward(g,s,a) 257 | end 258 | -------------------------------------------------------------------------------- /helpers.jl: -------------------------------------------------------------------------------- 1 | using Printf 2 | using LinearAlgebra 3 | macro max(range, ex) 4 | :(maximum($(Expr(:typed_comprehension, :Float64, ex, range)))) 5 | end 6 | macro sum(range, ex) 7 | :(sum($(Expr(:typed_comprehension, :Float64, ex, range)))) 8 | end 9 | macro min(range, ex) 10 | :(minimum($(Expr(:typed_comprehension, :Float64, ex, range)))) 11 | end 12 | macro prod(range, ex) 13 | :(prod($(Expr(:typed_comprehension, :Float64, ex, range)))) 14 | end 15 | macro argmax(range, ex) 16 | @assert(range.head == :in) 17 | @assert(length(range.args) == 2) 18 | :($(range.args[2])[argmax($(Expr(:typed_comprehension, :Float64, ex, range)))]) # argmax replaces indmax, which was removed in Julia 1.0 19 | end 20 | macro argmin(range, ex) 21 | @assert(range.head == :in) 22 | @assert(length(range.args) == 2) 23 | :($(range.args[2])[argmin($(Expr(:typed_comprehension, :Float64, ex, range)))]) # argmin replaces indmin, which was removed in Julia 1.0 24 | end 25 | macro array(range, ex) 26 | :($(Expr(:typed_comprehension, :Float64, ex, range))) 27 | end 28 | 29 | function polyfit(x, y, n) 30 | A = [float(xi)^p for xi in x, p = 0:n] 31 | (q, r) = LinearAlgebra.qr(A) 32 | r \ (q[:,1:n+1]' * y) 33 | end 34 | 35 | function prettyPolynomial(λ) 36 | o = IOBuffer() 37 | Printf.@printf(o, "\$") 38 | for i = 1:length(λ) 39 | if i == 1 40 | Printf.@printf(o, "%0.2f", λ[i]) 41 | elseif i == 2 42 | if λ[i] < 0 43 | Printf.@printf(o, "%0.2f x", λ[i]) 44 | else 45 | Printf.@printf(o, "+%0.2f x", λ[i]) 46 | end 47 | else 48 | if λ[i] < 0 49 | Printf.@printf(o, 
"%0.2fx^{%d}", λ[i], i-1) 50 | else 51 | Printf.@printf(o, "+%0.2fx^{%d}", λ[i], i-1) 52 | end 53 | end 54 | end 55 | Printf.@printf(o, "\$") 56 | String(take!(o)) 57 | end 58 | 59 | using TikzPictures 60 | 61 | function plot_chain(len; fill::Dict{Int,String}=Dict{Int64,String}()) 62 | str = "\\draw " 63 | for i in 1:len 64 | fl = get(fill, i, "white") 65 | str = string(str, "($(i)cm, 0cm) node[draw=black,circle,fill=$fl]{$i}") 66 | if i == len 67 | str = string(str, ";") 68 | else 69 | str = string(str, " -- ") 70 | end 71 | end 72 | return TikzPicture(str) 73 | end 74 | -------------------------------------------------------------------------------- /install.jl: -------------------------------------------------------------------------------- 1 | import Pkg 2 | 3 | @info("Adding JuliaPOMDP Package Registry to your global list of registries.") 4 | Pkg.add("POMDPs") 5 | using POMDPs 6 | POMDPs.add_registry() 7 | 8 | ENV["PYTHON"]="" 9 | 10 | projdir = dirname(@__FILE__()) 11 | toml = open(joinpath(projdir, "Project.toml")) do f 12 | Pkg.TOML.parse(f) 13 | end 14 | pkgs = collect(keys(toml["deps"])) 15 | pkgstring = string([pkg*"\n " for pkg in pkgs]...) 16 | @info(""" 17 | Installing the following packages to the current environment: 18 | 19 | $pkgstring 20 | """) 21 | 22 | Pkg.add(pkgs) 23 | 24 | @info("Dependency install complete! 
(check for errors)") 25 | -------------------------------------------------------------------------------- /rl.jl: -------------------------------------------------------------------------------- 1 | using Distributions 2 | using StatsBase 3 | using Random 4 | include("gridworld.jl") 5 | include("helpers.jl") 6 | 7 | mutable struct MappedDiscreteMDP{SType,AType} <: MDP{SType,AType} 8 | S::Vector{SType} 9 | A::Vector{AType} 10 | T::Array{Float64,3} 11 | R::Matrix{Float64} 12 | discount::Float64 13 | stateIndex::Dict 14 | actionIndex::Dict 15 | nextStates 16 | end 17 | 18 | function MappedDiscreteMDP(S::Vector, A::Vector, T, R; discount=0.9) 19 | stateIndex = Dict([S[i]=>i for i in 1:length(S)]) 20 | actionIndex = Dict([A[i]=>i for i in 1:length(A)]) 21 | nextStates = Dict([(S[si], A[ai])=>S[findall(x->x!=0, T[si, ai, :])] for si=1:length(S), ai=1:length(A)]) 22 | MappedDiscreteMDP(S, A, T, R, discount, stateIndex, actionIndex, nextStates) 23 | end 24 | 25 | MappedDiscreteMDP(S::Vector, A::Vector; discount=0.9) = 26 | MappedDiscreteMDP(S, A, 27 | zeros(length(S), length(A), length(S)), 28 | zeros(length(S), length(A)), 29 | discount=discount) 30 | 31 | actions(mdp::MappedDiscreteMDP) = mdp.A 32 | states(mdp::MappedDiscreteMDP) = mdp.S 33 | n_states(mdp::MappedDiscreteMDP) = length(mdp.S) 34 | n_actions(mdp::MappedDiscreteMDP) = length(mdp.A) 35 | reward(mdp::MappedDiscreteMDP, s, a) = mdp.R[mdp.stateIndex[s], mdp.actionIndex[a]] 36 | transition_pdf(mdp::MappedDiscreteMDP, s0, a, s1) = mdp.T[mdp.stateIndex[s0], mdp.actionIndex[a], mdp.stateIndex[s1]] 37 | discount(mdp::MappedDiscreteMDP) = mdp.discount 38 | state_index(mdp::MappedDiscreteMDP, s) = mdp.stateIndex[s] 39 | action_index(mdp::MappedDiscreteMDP, a) = mdp.actionIndex[a] # look up by the action argument `a`; `s` is not defined in this method 40 | next_states(mdp::MappedDiscreteMDP, s, a) = mdp.nextStates[(s, a)] 41 | 42 | 43 | rand_state(mdp::MDP) = states(mdp)[rand(DiscreteUniform(1,n_states(mdp)))] 44 | 45 | function value_iteration(mdp::MDP, iterations::Integer) 46 | V = 
zeros(n_states(mdp)) 47 | Q = zeros(n_states(mdp), n_actions(mdp)) 48 | value_iteration!(V, Q, mdp, iterations) 49 | (V, Q) 50 | end 51 | 52 | function value_iteration!(V::Vector, Q::Matrix, mdp::MDP, iterations::Integer) 53 | (S, A, T, R, discount) = locals(mdp) 54 | V_old = copy(V) 55 | for i = 1:iterations 56 | for s0i in 1:n_states(mdp) 57 | s0 = S[s0i] 58 | for ai = 1:n_actions(mdp) 59 | a = A[ai] 60 | Q[s0i,ai] = R(s0, a) + discount * sum([0.0; [T(s0, a, s1)*V_old[state_index(mdp, s1)] for s1 in next_states(mdp, s0, a)]]) 61 | end 62 | V[s0i] = maximum(Q[s0i,:]) 63 | end 64 | copyto!(V_old, V) 65 | end 66 | end 67 | 68 | function update_parameters!(mdp::MappedDiscreteMDP, N, Nsa, ρ, s, a) 69 | si = mdp.stateIndex[s] 70 | ai = mdp.actionIndex[a] 71 | denom = Nsa[si, ai] 72 | mdp.T[si, ai, :] = N[si, ai, :] ./ denom 73 | mdp.R[si, ai] = ρ[si, ai] / denom 74 | mdp.nextStates[(s, a)]= mdp.S[findall(x->x!=0, mdp.T[si, ai, :])] 75 | end 76 | 77 | function isterminal(mdp::MDP, s0, a) 78 | S1 = next_states(mdp, s0, a) 79 | length(S1) == 0 || 0 == sum(s1 -> transition_pdf(mdp, s0, a, s1), S1) 80 | end 81 | 82 | function generate_s(mdp::MDP, s0, a, rng::AbstractRNG=Random.GLOBAL_RNG) 83 | p = [transition_pdf(mdp, s0, a, s1) for s1 in states(mdp)] 84 | s1i = sample(rng, Weights(p)) 85 | states(mdp)[s1i] 86 | end 87 | 88 | mutable struct MLRL <: Policy 89 | N::Array{Float64,3} # transition counts 90 | Nsa::Matrix{Float64} # state-action counts 91 | ρ::Matrix{Float64} # sum of rewards 92 | lastState 93 | lastAction 94 | lastReward 95 | newEpisode 96 | mdp::MappedDiscreteMDP 97 | Q::Matrix{Float64} 98 | V::Vector{Float64} 99 | iterations::Int 100 | epsilon::Float64 # probability of exploration 101 | function MLRL(S, A; discount=0.9, iterations=20, epsilon=0.2) 102 | N = zeros(length(S), length(A), length(S)) 103 | Nsa = zeros(length(S), length(A)) 104 | ρ = zeros(length(S), length(A)) 105 | lastState = nothing 106 | lastAction = nothing 107 | lastReward = 0. 
108 | mdp = MappedDiscreteMDP(S, A, discount=discount) 109 | Q = zeros(length(S), length(A)) 110 | V = zeros(length(S)) 111 | newEpisode = true 112 | new(N, Nsa, ρ, lastState, lastAction, lastReward, newEpisode, mdp, Q, V, iterations, epsilon) 113 | end 114 | end 115 | 116 | function reset(policy::MLRL) 117 | if !policy.newEpisode 118 | s0i = policy.mdp.stateIndex[policy.lastState] 119 | ai = policy.mdp.actionIndex[policy.lastAction] 120 | policy.Nsa[s0i, ai] += 1 121 | policy.ρ[s0i, ai] += policy.lastReward # accumulate the reward sum (ρ/Nsa is the mean reward); plain assignment discarded earlier visits 122 | # update Q and V 123 | update_parameters!(policy.mdp, policy.N, policy.Nsa, policy.ρ, policy.lastState, policy.lastAction) 124 | value_iteration!(policy.V, policy.Q, policy.mdp, policy.iterations) 125 | policy.newEpisode = true 126 | end 127 | end 128 | 129 | function update(policy::MLRL, s, a, r) 130 | if policy.newEpisode 131 | policy.newEpisode = false 132 | else 133 | s0i = policy.mdp.stateIndex[policy.lastState] 134 | ai = policy.mdp.actionIndex[policy.lastAction] 135 | s1i = policy.mdp.stateIndex[s] 136 | policy.N[s0i, ai, s1i] += 1 137 | policy.Nsa[s0i, ai] += 1 138 | policy.ρ[s0i, ai] += policy.lastReward 139 | # update Q and V 140 | update_parameters!(policy.mdp, policy.N, policy.Nsa, policy.ρ, policy.lastState, policy.lastAction) 141 | value_iteration!(policy.V, policy.Q, policy.mdp, policy.iterations) 142 | end 143 | policy.lastState = s 144 | policy.lastAction = a 145 | policy.lastReward = r 146 | nothing 147 | end 148 | 149 | function action(policy::MLRL, s) 150 | si = policy.mdp.stateIndex[s] 151 | Qs = policy.Q[si, :] 152 | ais = findall((in)(maximum(Qs)), Qs) 153 | ai = rand(ais) 154 | policy.mdp.A[ai] 155 | end 156 | 157 | function action(policy::MLRL) 158 | if rand() < policy.epsilon 159 | policy.mdp.A[rand(DiscreteUniform(1,n_actions(policy.mdp)))] # n_actions is the accessor defined for MappedDiscreteMDP; numActions does not exist 160 | else 161 | action(policy, policy.lastState) 162 | end 163 | end 164 | 165 | function simulate(mdp::MDP, steps::Integer, policy::Policy; script=[]) 166 | S = Any[] 167 | V = Any[] 168 | 
R = Float64[] 169 | if length(script) == 0 170 | s = rand_state(mdp) 171 | else 172 | s = script[1] 173 | end 174 | for i = 1:steps 175 | push!(S, s) 176 | a = action(policy, s) 177 | r = reward(mdp, s, a) 178 | push!(R, r) 179 | update(policy, s, a, r) 180 | push!(V, copy(policy.V)) 181 | if i < length(script) 182 | s = script[i + 1] 183 | else 184 | if isterminal(mdp, s, a) 185 | s = rand_state(mdp) 186 | reset(policy) 187 | else 188 | s = generate_s(mdp, s, a) 189 | end 190 | end 191 | end 192 | (S, R, V) 193 | end 194 | -------------------------------------------------------------------------------- /runtests.jl: -------------------------------------------------------------------------------- 1 | using NBInclude 2 | using Test 3 | 4 | @testset "notebooks" begin 5 | for d in readdir(".") 6 | # if endswith(d, ".ipynb") && !startswith(d, "08-Markov") && !startswith(d, "09-") && !startswith(d, "11-") && !startswith(d, "16-") && !startswith(d, "POM") # ignore MDP notebook because it fails for some reason 7 | if endswith(d, ".ipynb") 8 | @info("Running "*d) 9 | stuff = "using InteractiveUtils; using NBInclude; ENV[\"PYTHON\"]=\"\"; @nbinclude(\"" * d * "\")" 10 | projdir = dirname(@__FILE__()) 11 | cmd = `julia --project=$projdir -e $stuff` 12 | proc = run(pipeline(cmd, stderr=stderr), wait=false) 13 | @test success(proc) 14 | end 15 | end 16 | end 17 | --------------------------------------------------------------------------------