├── .gitignore
├── .travis.yml
├── 01-Introduction.ipynb
├── 02-ProbabilisticRepresentations.ipynb
├── 03-Inference.ipynb
├── 04-ParameterLearning.ipynb
├── 05-StructureLearning.ipynb
├── 06-DecisionNetworks.ipynb
├── 07-Games.ipynb
├── 08-MarkovDecisionProcesses.ipynb
├── 09-ApproximateDynamicProgramming.ipynb
├── 10-ExplorationExploitation.ipynb
├── 11-ModelBasedReinforcementLearning.ipynb
├── 12-ModelFreeReinforcementLearning.ipynb
├── 13-StateUncertainty.ipynb
├── 14-ExactPOMDPMethods.ipynb
├── 15-OfflinePOMDPMethods.ipynb
├── 16-OnlinePOMDPMethods.ipynb
├── POMDPs-jl-demo.ipynb
├── Project.toml
├── README.md
├── alpha_plots.jl
├── baby.jl
├── bandits.jl
├── gridworld.jl
├── helpers.jl
├── install.jl
├── rl.jl
└── runtests.jl
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | *.DS_Store
3 | tmp*
4 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: julia
2 | dist: trusty
3 | julia:
4 | - 1.2
5 | notifications:
6 | email: false
7 | before_install:
8 | - sudo apt-get update && sudo apt-get install -y texlive-latex-extra # refresh the package index first; -y avoids apt's interactive confirmation prompt on CI
9 | script:
10 | - git clone https://github.com/JuliaRegistries/General $(julia -e 'import Pkg; println(joinpath(Pkg.depots1(), "registries", "General"))')
11 | - git clone https://github.com/JuliaPOMDP/Registry $(julia -e 'import Pkg; println(joinpath(Pkg.depots1(), "registries", "JuliaPOMDP"))')
12 | - if [[ -e .git/shallow ]]; then git fetch --unshallow; fi # fetch full history when the checkout is a shallow clone
13 | - julia -e 'import Pkg; ENV["PYTHON"]=""; Pkg.add("PyCall"); Pkg.build("PyCall")' # empty PYTHON makes PyCall build against its own Conda Python
14 | - julia -e 'import Pkg; Pkg.add("Conda"); using Conda; Conda.add("matplotlib")'
15 | - julia --check-bounds=yes -e 'include("install.jl"); include("runtests.jl")' # install deps, then run the tests with bounds checking forced on
16 |
--------------------------------------------------------------------------------
/01-Introduction.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Quick introduction to Julia"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "These examples are based on http://learnxinyminutes.com/docs/julia/ and assume Julia 1.2."
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "## Types"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "Every value has a type, which `typeof` reports. Numbers come in several types arranged in a hierarchy that `supertype` walks up."
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 1,
34 | "metadata": {},
35 | "outputs": [
36 | {
37 | "data": {
38 | "text/plain": [
39 | "String"
40 | ]
41 | },
42 | "execution_count": 1,
43 | "metadata": {},
44 | "output_type": "execute_result"
45 | }
46 | ],
47 | "source": [
48 | "typeof(\"mykel\")"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 2,
54 | "metadata": {},
55 | "outputs": [
56 | {
57 | "data": {
58 | "text/plain": [
59 | "Float64"
60 | ]
61 | },
62 | "execution_count": 2,
63 | "metadata": {},
64 | "output_type": "execute_result"
65 | }
66 | ],
67 | "source": [
68 | "typeof(1.0)"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": 3,
74 | "metadata": {},
75 | "outputs": [
76 | {
77 | "data": {
78 | "text/plain": [
79 | "Complex{Int64}"
80 | ]
81 | },
82 | "execution_count": 3,
83 | "metadata": {},
84 | "output_type": "execute_result"
85 | }
86 | ],
87 | "source": [
88 | "typeof(1 + 1im)"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 4,
94 | "metadata": {},
95 | "outputs": [
96 | {
97 | "data": {
98 | "text/plain": [
99 | "AbstractFloat"
100 | ]
101 | },
102 | "execution_count": 4,
103 | "metadata": {},
104 | "output_type": "execute_result"
105 | }
106 | ],
107 | "source": [
108 | "supertype(Float64)"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 5,
114 | "metadata": {},
115 | "outputs": [
116 | {
117 | "data": {
118 | "text/plain": [
119 | "Real"
120 | ]
121 | },
122 | "execution_count": 5,
123 | "metadata": {},
124 | "output_type": "execute_result"
125 | }
126 | ],
127 | "source": [
128 | "supertype(AbstractFloat)"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 6,
134 | "metadata": {},
135 | "outputs": [
136 | {
137 | "data": {
138 | "text/plain": [
139 | "Number"
140 | ]
141 | },
142 | "execution_count": 6,
143 | "metadata": {},
144 | "output_type": "execute_result"
145 | }
146 | ],
147 | "source": [
148 | "supertype(Real)"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 7,
154 | "metadata": {},
155 | "outputs": [
156 | {
157 | "data": {
158 | "text/plain": [
159 | "Any"
160 | ]
161 | },
162 | "execution_count": 7,
163 | "metadata": {},
164 | "output_type": "execute_result"
165 | }
166 | ],
167 | "source": [
168 | "supertype(Number)"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 8,
174 | "metadata": {},
175 | "outputs": [
176 | {
177 | "data": {
178 | "text/plain": [
179 | "Signed"
180 | ]
181 | },
182 | "execution_count": 8,
183 | "metadata": {},
184 | "output_type": "execute_result"
185 | }
186 | ],
187 | "source": [
188 | "supertype(Int64)"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": 9,
194 | "metadata": {},
195 | "outputs": [
196 | {
197 | "data": {
198 | "text/plain": [
199 | "Integer"
200 | ]
201 | },
202 | "execution_count": 9,
203 | "metadata": {},
204 | "output_type": "execute_result"
205 | }
206 | ],
207 | "source": [
208 | "supertype(Signed)"
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "execution_count": 10,
214 | "metadata": {},
215 | "outputs": [
216 | {
217 | "data": {
218 | "text/plain": [
219 | "Real"
220 | ]
221 | },
222 | "execution_count": 10,
223 | "metadata": {},
224 | "output_type": "execute_result"
225 | }
226 | ],
227 | "source": [
228 | "supertype(Integer)"
229 | ]
230 | },
231 | {
232 | "cell_type": "markdown",
233 | "metadata": {},
234 | "source": [
235 | "Boolean types"
236 | ]
237 | },
238 | {
239 | "cell_type": "code",
240 | "execution_count": 11,
241 | "metadata": {},
242 | "outputs": [
243 | {
244 | "data": {
245 | "text/plain": [
246 | "Bool"
247 | ]
248 | },
249 | "execution_count": 11,
250 | "metadata": {},
251 | "output_type": "execute_result"
252 | }
253 | ],
254 | "source": [
255 | "typeof(true)"
256 | ]
257 | },
258 | {
259 | "cell_type": "markdown",
260 | "metadata": {},
261 | "source": [
262 | "## Boolean Operators"
263 | ]
264 | },
265 | {
266 | "cell_type": "markdown",
267 | "metadata": {},
268 | "source": [
269 | "Negation is done with `!`. Equality and inequality are tested with `==` and `!=`."
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": 12,
275 | "metadata": {},
276 | "outputs": [
277 | {
278 | "data": {
279 | "text/plain": [
280 | "false"
281 | ]
282 | },
283 | "execution_count": 12,
284 | "metadata": {},
285 | "output_type": "execute_result"
286 | }
287 | ],
288 | "source": [
289 | "!true"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": 13,
295 | "metadata": {},
296 | "outputs": [
297 | {
298 | "data": {
299 | "text/plain": [
300 | "true"
301 | ]
302 | },
303 | "execution_count": 13,
304 | "metadata": {},
305 | "output_type": "execute_result"
306 | }
307 | ],
308 | "source": [
309 | "1 == 1"
310 | ]
311 | },
312 | {
313 | "cell_type": "code",
314 | "execution_count": 14,
315 | "metadata": {},
316 | "outputs": [
317 | {
318 | "data": {
319 | "text/plain": [
320 | "false"
321 | ]
322 | },
323 | "execution_count": 14,
324 | "metadata": {},
325 | "output_type": "execute_result"
326 | }
327 | ],
328 | "source": [
329 | "1 != 1"
330 | ]
331 | },
332 | {
333 | "cell_type": "markdown",
334 | "metadata": {},
335 | "source": [
336 | "Comparisons can be chained: `1 < 2 < 3` means `1 < 2 && 2 < 3`."
337 | ]
338 | },
339 | {
340 | "cell_type": "code",
341 | "execution_count": 15,
342 | "metadata": {},
343 | "outputs": [
344 | {
345 | "data": {
346 | "text/plain": [
347 | "true"
348 | ]
349 | },
350 | "execution_count": 15,
351 | "metadata": {},
352 | "output_type": "execute_result"
353 | }
354 | ],
355 | "source": [
356 | "1 < 2 < 3"
357 | ]
358 | },
359 | {
360 | "cell_type": "markdown",
361 | "metadata": {},
362 | "source": [
363 | "## Strings"
364 | ]
365 | },
366 | {
367 | "cell_type": "markdown",
368 | "metadata": {},
369 | "source": [
370 | "Use double quotes for strings."
371 | ]
372 | },
373 | {
374 | "cell_type": "code",
375 | "execution_count": 16,
376 | "metadata": {},
377 | "outputs": [
378 | {
379 | "data": {
380 | "text/plain": [
381 | "\"This is a string\""
382 | ]
383 | },
384 | "execution_count": 16,
385 | "metadata": {},
386 | "output_type": "execute_result"
387 | }
388 | ],
389 | "source": [
390 | "\"This is a string\""
391 | ]
392 | },
393 | {
394 | "cell_type": "code",
395 | "execution_count": 17,
396 | "metadata": {},
397 | "outputs": [
398 | {
399 | "data": {
400 | "text/plain": [
401 | "String"
402 | ]
403 | },
404 | "execution_count": 17,
405 | "metadata": {},
406 | "output_type": "execute_result"
407 | }
408 | ],
409 | "source": [
410 | "typeof(\"This is a string\")"
411 | ]
412 | },
413 | {
414 | "cell_type": "markdown",
415 | "metadata": {},
416 | "source": [
417 | "Use single quotes for characters."
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": 18,
423 | "metadata": {},
424 | "outputs": [
425 | {
426 | "data": {
427 | "text/plain": [
428 | "'a': ASCII/Unicode U+0061 (category Ll: Letter, lowercase)"
429 | ]
430 | },
431 | "execution_count": 18,
432 | "metadata": {},
433 | "output_type": "execute_result"
434 | }
435 | ],
436 | "source": [
437 | "'a'"
438 | ]
439 | },
440 | {
441 | "cell_type": "code",
442 | "execution_count": 19,
443 | "metadata": {},
444 | "outputs": [
445 | {
446 | "data": {
447 | "text/plain": [
448 | "Char"
449 | ]
450 | },
451 | "execution_count": 19,
452 | "metadata": {},
453 | "output_type": "execute_result"
454 | }
455 | ],
456 | "source": [
457 | "typeof('a')"
458 | ]
459 | },
460 | {
461 | "cell_type": "code",
462 | "execution_count": 20,
463 | "metadata": {},
464 | "outputs": [
465 | {
466 | "data": {
467 | "text/plain": [
468 | "'T': ASCII/Unicode U+0054 (category Lu: Letter, uppercase)"
469 | ]
470 | },
471 | "execution_count": 20,
472 | "metadata": {},
473 | "output_type": "execute_result"
474 | }
475 | ],
476 | "source": [
477 | "\"This is a string\"[1] # note the 1-based indexing---similar to Matlab but unlike C/C++/Java"
478 | ]
479 | },
480 | {
481 | "cell_type": "markdown",
482 | "metadata": {},
483 | "source": [
484 | "`$` can be used for \"string interpolation\", which inserts the value of a variable or expression into a string."
485 | ]
486 | },
487 | {
488 | "cell_type": "code",
489 | "execution_count": 21,
490 | "metadata": {},
491 | "outputs": [
492 | {
493 | "data": {
494 | "text/plain": [
495 | "\"2 + 2 = 4\""
496 | ]
497 | },
498 | "execution_count": 21,
499 | "metadata": {},
500 | "output_type": "execute_result"
501 | }
502 | ],
503 | "source": [
504 | "\"2 + 2 = $(2+2)\""
505 | ]
506 | },
507 | {
508 | "cell_type": "code",
509 | "execution_count": 22,
510 | "metadata": {},
511 | "outputs": [
512 | {
513 | "name": "stdout",
514 | "output_type": "stream",
515 | "text": [
516 | "5 is less than 5.300000"
517 | ]
518 | }
519 | ],
520 | "source": [
521 | "using Printf\n",
522 | "Printf.@printf \"%d is less than %f\" 4.5 5.3"
523 | ]
524 | },
525 | {
526 | "cell_type": "code",
527 | "execution_count": 23,
528 | "metadata": {},
529 | "outputs": [
530 | {
531 | "name": "stdout",
532 | "output_type": "stream",
533 | "text": [
534 | "Welcome to Julia\n"
535 | ]
536 | }
537 | ],
538 | "source": [
539 | "println(\"Welcome to Julia\")"
540 | ]
541 | },
542 | {
543 | "cell_type": "markdown",
544 | "metadata": {},
545 | "source": [
546 | "## Variables"
547 | ]
548 | },
549 | {
550 | "cell_type": "code",
551 | "execution_count": 24,
552 | "metadata": {},
553 | "outputs": [
554 | {
555 | "data": {
556 | "text/plain": [
557 | "5"
558 | ]
559 | },
560 | "execution_count": 24,
561 | "metadata": {},
562 | "output_type": "execute_result"
563 | }
564 | ],
565 | "source": [
566 | "x = 5"
567 | ]
568 | },
569 | {
570 | "cell_type": "markdown",
571 | "metadata": {},
572 | "source": [
573 | "Variable names start with a letter or underscore; after that you can use letters, digits, underscores, and exclamation points."
574 | ]
575 | },
576 | {
577 | "cell_type": "code",
578 | "execution_count": 25,
579 | "metadata": {},
580 | "outputs": [
581 | {
582 | "data": {
583 | "text/plain": [
584 | "1"
585 | ]
586 | },
587 | "execution_count": 25,
588 | "metadata": {},
589 | "output_type": "execute_result"
590 | }
591 | ],
592 | "source": [
593 | "xMarksTheSpot2Dig! = 1"
594 | ]
595 | },
596 | {
597 | "cell_type": "markdown",
598 | "metadata": {},
599 | "source": [
600 | "## Arrays"
601 | ]
602 | },
603 | {
604 | "cell_type": "code",
605 | "execution_count": 26,
606 | "metadata": {},
607 | "outputs": [
608 | {
609 | "data": {
610 | "text/plain": [
611 | "0-element Array{Int64,1}"
612 | ]
613 | },
614 | "execution_count": 26,
615 | "metadata": {},
616 | "output_type": "execute_result"
617 | }
618 | ],
619 | "source": [
620 | "a = Int64[]"
621 | ]
622 | },
623 | {
624 | "cell_type": "code",
625 | "execution_count": 27,
626 | "metadata": {},
627 | "outputs": [
628 | {
629 | "data": {
630 | "text/plain": [
631 | "3-element Array{Int64,1}:\n",
632 | " 4\n",
633 | " 5\n",
634 | " 6"
635 | ]
636 | },
637 | "execution_count": 27,
638 | "metadata": {},
639 | "output_type": "execute_result"
640 | }
641 | ],
642 | "source": [
643 | "b = [4, 5, 6]"
644 | ]
645 | },
646 | {
647 | "cell_type": "code",
648 | "execution_count": 28,
649 | "metadata": {},
650 | "outputs": [
651 | {
652 | "data": {
653 | "text/plain": [
654 | "4"
655 | ]
656 | },
657 | "execution_count": 28,
658 | "metadata": {},
659 | "output_type": "execute_result"
660 | }
661 | ],
662 | "source": [
663 | "b[1]"
664 | ]
665 | },
666 | {
667 | "cell_type": "code",
668 | "execution_count": 29,
669 | "metadata": {},
670 | "outputs": [
671 | {
672 | "data": {
673 | "text/plain": [
674 | "5"
675 | ]
676 | },
677 | "execution_count": 29,
678 | "metadata": {},
679 | "output_type": "execute_result"
680 | }
681 | ],
682 | "source": [
683 | "b[end-1]"
684 | ]
685 | },
686 | {
687 | "cell_type": "code",
688 | "execution_count": 30,
689 | "metadata": {},
690 | "outputs": [
691 | {
692 | "data": {
693 | "text/plain": [
694 | "2×2 Array{Int64,2}:\n",
695 | " 1 2\n",
696 | " 3 4"
697 | ]
698 | },
699 | "execution_count": 30,
700 | "metadata": {},
701 | "output_type": "execute_result"
702 | }
703 | ],
704 | "source": [
705 | "matrix = [1 2; 3 4]"
706 | ]
707 | },
708 | {
709 | "cell_type": "code",
710 | "execution_count": 31,
711 | "metadata": {},
712 | "outputs": [
713 | {
714 | "data": {
715 | "text/plain": [
716 | "0-element Array{Int64,1}"
717 | ]
718 | },
719 | "execution_count": 31,
720 | "metadata": {},
721 | "output_type": "execute_result"
722 | }
723 | ],
724 | "source": [
725 | "a"
726 | ]
727 | },
728 | {
729 | "cell_type": "code",
730 | "execution_count": 32,
731 | "metadata": {},
732 | "outputs": [
733 | {
734 | "data": {
735 | "text/plain": [
736 | "1-element Array{Int64,1}:\n",
737 | " 1"
738 | ]
739 | },
740 | "execution_count": 32,
741 | "metadata": {},
742 | "output_type": "execute_result"
743 | }
744 | ],
745 | "source": [
746 | "push!(a, 1)"
747 | ]
748 | },
749 | {
750 | "cell_type": "code",
751 | "execution_count": 33,
752 | "metadata": {},
753 | "outputs": [
754 | {
755 | "data": {
756 | "text/plain": [
757 | "2-element Array{Int64,1}:\n",
758 | " 1\n",
759 | " 2"
760 | ]
761 | },
762 | "execution_count": 33,
763 | "metadata": {},
764 | "output_type": "execute_result"
765 | }
766 | ],
767 | "source": [
768 | "push!(a, 2)"
769 | ]
770 | },
771 | {
772 | "cell_type": "code",
773 | "execution_count": 34,
774 | "metadata": {},
775 | "outputs": [
776 | {
777 | "data": {
778 | "text/plain": [
779 | "5-element Array{Int64,1}:\n",
780 | " 1\n",
781 | " 2\n",
782 | " 4\n",
783 | " 5\n",
784 | " 6"
785 | ]
786 | },
787 | "execution_count": 34,
788 | "metadata": {},
789 | "output_type": "execute_result"
790 | }
791 | ],
792 | "source": [
793 | "append!(a, b)"
794 | ]
795 | },
796 | {
797 | "cell_type": "code",
798 | "execution_count": 35,
799 | "metadata": {},
800 | "outputs": [
801 | {
802 | "data": {
803 | "text/plain": [
804 | "5-element Array{Int64,1}:\n",
805 | " 1\n",
806 | " 2\n",
807 | " 4\n",
808 | " 5\n",
809 | " 6"
810 | ]
811 | },
812 | "execution_count": 35,
813 | "metadata": {},
814 | "output_type": "execute_result"
815 | }
816 | ],
817 | "source": [
818 | "a"
819 | ]
820 | },
821 | {
822 | "cell_type": "code",
823 | "execution_count": 36,
824 | "metadata": {},
825 | "outputs": [
826 | {
827 | "data": {
828 | "text/plain": [
829 | "6"
830 | ]
831 | },
832 | "execution_count": 36,
833 | "metadata": {},
834 | "output_type": "execute_result"
835 | }
836 | ],
837 | "source": [
838 | "pop!(a)"
839 | ]
840 | },
841 | {
842 | "cell_type": "code",
843 | "execution_count": 37,
844 | "metadata": {},
845 | "outputs": [
846 | {
847 | "data": {
848 | "text/plain": [
849 | "4-element Array{Int64,1}:\n",
850 | " 1\n",
851 | " 2\n",
852 | " 4\n",
853 | " 5"
854 | ]
855 | },
856 | "execution_count": 37,
857 | "metadata": {},
858 | "output_type": "execute_result"
859 | }
860 | ],
861 | "source": [
862 | "a"
863 | ]
864 | },
865 | {
866 | "cell_type": "code",
867 | "execution_count": 38,
868 | "metadata": {},
869 | "outputs": [
870 | {
871 | "data": {
872 | "text/plain": [
873 | "3-element Array{Int64,1}:\n",
874 | " 2\n",
875 | " 4\n",
876 | " 5"
877 | ]
878 | },
879 | "execution_count": 38,
880 | "metadata": {},
881 | "output_type": "execute_result"
882 | }
883 | ],
884 | "source": [
885 | "a[2:4]"
886 | ]
887 | },
888 | {
889 | "cell_type": "code",
890 | "execution_count": 39,
891 | "metadata": {},
892 | "outputs": [
893 | {
894 | "data": {
895 | "text/plain": [
896 | "3-element Array{Int64,1}:\n",
897 | " 2\n",
898 | " 4\n",
899 | " 5"
900 | ]
901 | },
902 | "execution_count": 39,
903 | "metadata": {},
904 | "output_type": "execute_result"
905 | }
906 | ],
907 | "source": [
908 | "a[2:end]"
909 | ]
910 | },
911 | {
912 | "cell_type": "code",
913 | "execution_count": 40,
914 | "metadata": {},
915 | "outputs": [
916 | {
917 | "data": {
918 | "text/plain": [
919 | "5-element Array{Int64,1}:\n",
920 | " 1\n",
921 | " 2\n",
922 | " 4\n",
923 | " 5\n",
924 | " 1"
925 | ]
926 | },
927 | "execution_count": 40,
928 | "metadata": {},
929 | "output_type": "execute_result"
930 | }
931 | ],
932 | "source": [
933 | "push!(a, round(Int64, 1.3))"
934 | ]
935 | },
936 | {
937 | "cell_type": "code",
938 | "execution_count": 41,
939 | "metadata": {},
940 | "outputs": [
941 | {
942 | "data": {
943 | "text/plain": [
944 | "true"
945 | ]
946 | },
947 | "execution_count": 41,
948 | "metadata": {},
949 | "output_type": "execute_result"
950 | }
951 | ],
952 | "source": [
953 | "in(4, a)"
954 | ]
955 | },
956 | {
957 | "cell_type": "code",
958 | "execution_count": 42,
959 | "metadata": {},
960 | "outputs": [
961 | {
962 | "data": {
963 | "text/plain": [
964 | "true"
965 | ]
966 | },
967 | "execution_count": 42,
968 | "metadata": {},
969 | "output_type": "execute_result"
970 | }
971 | ],
972 | "source": [
973 | "4 in a"
974 | ]
975 | },
976 | {
977 | "cell_type": "code",
978 | "execution_count": 43,
979 | "metadata": {},
980 | "outputs": [
981 | {
982 | "data": {
983 | "text/plain": [
984 | "5"
985 | ]
986 | },
987 | "execution_count": 43,
988 | "metadata": {},
989 | "output_type": "execute_result"
990 | }
991 | ],
992 | "source": [
993 | "length(a)"
994 | ]
995 | },
996 | {
997 | "cell_type": "markdown",
998 | "metadata": {},
999 | "source": [
1000 | "## Tuples"
1001 | ]
1002 | },
1003 | {
1004 | "cell_type": "code",
1005 | "execution_count": 44,
1006 | "metadata": {},
1007 | "outputs": [
1008 | {
1009 | "data": {
1010 | "text/plain": [
1011 | "(1, 5, 3)"
1012 | ]
1013 | },
1014 | "execution_count": 44,
1015 | "metadata": {},
1016 | "output_type": "execute_result"
1017 | }
1018 | ],
1019 | "source": [
1020 | "a = (1, 5, 3)"
1021 | ]
1022 | },
1023 | {
1024 | "cell_type": "code",
1025 | "execution_count": 45,
1026 | "metadata": {},
1027 | "outputs": [
1028 | {
1029 | "data": {
1030 | "text/plain": [
1031 | "Tuple{Int64,Int64,Int64}"
1032 | ]
1033 | },
1034 | "execution_count": 45,
1035 | "metadata": {},
1036 | "output_type": "execute_result"
1037 | }
1038 | ],
1039 | "source": [
1040 | "typeof(a)"
1041 | ]
1042 | },
1043 | {
1044 | "cell_type": "code",
1045 | "execution_count": 46,
1046 | "metadata": {},
1047 | "outputs": [
1048 | {
1049 | "data": {
1050 | "text/plain": [
1051 | "5"
1052 | ]
1053 | },
1054 | "execution_count": 46,
1055 | "metadata": {},
1056 | "output_type": "execute_result"
1057 | }
1058 | ],
1059 | "source": [
1060 | "a[2]"
1061 | ]
1062 | },
1063 | {
1064 | "cell_type": "code",
1065 | "execution_count": 47,
1066 | "metadata": {},
1067 | "outputs": [],
1068 | "source": [
1069 | "#a[2] = 3 # can't change elements in a tuple"
1070 | ]
1071 | },
1072 | {
1073 | "cell_type": "code",
1074 | "execution_count": 48,
1075 | "metadata": {},
1076 | "outputs": [
1077 | {
1078 | "data": {
1079 | "text/plain": [
1080 | "(1, 2, 3)"
1081 | ]
1082 | },
1083 | "execution_count": 48,
1084 | "metadata": {},
1085 | "output_type": "execute_result"
1086 | }
1087 | ],
1088 | "source": [
1089 | "a, b, c = (1, 2, 3)"
1090 | ]
1091 | },
1092 | {
1093 | "cell_type": "code",
1094 | "execution_count": 49,
1095 | "metadata": {},
1096 | "outputs": [
1097 | {
1098 | "data": {
1099 | "text/plain": [
1100 | "1"
1101 | ]
1102 | },
1103 | "execution_count": 49,
1104 | "metadata": {},
1105 | "output_type": "execute_result"
1106 | }
1107 | ],
1108 | "source": [
1109 | "a"
1110 | ]
1111 | },
1112 | {
1113 | "cell_type": "code",
1114 | "execution_count": 50,
1115 | "metadata": {},
1116 | "outputs": [
1117 | {
1118 | "data": {
1119 | "text/plain": [
1120 | "2"
1121 | ]
1122 | },
1123 | "execution_count": 50,
1124 | "metadata": {},
1125 | "output_type": "execute_result"
1126 | }
1127 | ],
1128 | "source": [
1129 | "b"
1130 | ]
1131 | },
1132 | {
1133 | "cell_type": "code",
1134 | "execution_count": 51,
1135 | "metadata": {},
1136 | "outputs": [
1137 | {
1138 | "data": {
1139 | "text/plain": [
1140 | "3"
1141 | ]
1142 | },
1143 | "execution_count": 51,
1144 | "metadata": {},
1145 | "output_type": "execute_result"
1146 | }
1147 | ],
1148 | "source": [
1149 | "c"
1150 | ]
1151 | },
1152 | {
1153 | "cell_type": "code",
1154 | "execution_count": 52,
1155 | "metadata": {},
1156 | "outputs": [
1157 | {
1158 | "data": {
1159 | "text/plain": [
1160 | "(1, 2, 3)"
1161 | ]
1162 | },
1163 | "execution_count": 52,
1164 | "metadata": {},
1165 | "output_type": "execute_result"
1166 | }
1167 | ],
1168 | "source": [
1169 | "a, b, c = 1, 2, 3 # you can also leave off parentheses"
1170 | ]
1171 | },
1172 | {
1173 | "cell_type": "code",
1174 | "execution_count": 53,
1175 | "metadata": {},
1176 | "outputs": [
1177 | {
1178 | "data": {
1179 | "text/plain": [
1180 | "1"
1181 | ]
1182 | },
1183 | "execution_count": 53,
1184 | "metadata": {},
1185 | "output_type": "execute_result"
1186 | }
1187 | ],
1188 | "source": [
1189 | "a"
1190 | ]
1191 | },
1192 | {
1193 | "cell_type": "code",
1194 | "execution_count": 54,
1195 | "metadata": {},
1196 | "outputs": [
1197 | {
1198 | "data": {
1199 | "text/plain": [
1200 | "2"
1201 | ]
1202 | },
1203 | "execution_count": 54,
1204 | "metadata": {},
1205 | "output_type": "execute_result"
1206 | }
1207 | ],
1208 | "source": [
1209 | "b"
1210 | ]
1211 | },
1212 | {
1213 | "cell_type": "code",
1214 | "execution_count": 55,
1215 | "metadata": {},
1216 | "outputs": [
1217 | {
1218 | "data": {
1219 | "text/plain": [
1220 | "3"
1221 | ]
1222 | },
1223 | "execution_count": 55,
1224 | "metadata": {},
1225 | "output_type": "execute_result"
1226 | }
1227 | ],
1228 | "source": [
1229 | "c"
1230 | ]
1231 | },
1232 | {
1233 | "cell_type": "code",
1234 | "execution_count": 56,
1235 | "metadata": {},
1236 | "outputs": [
1237 | {
1238 | "data": {
1239 | "text/plain": [
1240 | "(1,)"
1241 | ]
1242 | },
1243 | "execution_count": 56,
1244 | "metadata": {},
1245 | "output_type": "execute_result"
1246 | }
1247 | ],
1248 | "source": [
1249 | "(1,) # to create a single element tuple, you must add the \",\" at the end"
1250 | ]
1251 | },
1252 | {
1253 | "cell_type": "code",
1254 | "execution_count": 57,
1255 | "metadata": {},
1256 | "outputs": [
1257 | {
1258 | "data": {
1259 | "text/plain": [
1260 | "Tuple{Int64}"
1261 | ]
1262 | },
1263 | "execution_count": 57,
1264 | "metadata": {},
1265 | "output_type": "execute_result"
1266 | }
1267 | ],
1268 | "source": [
1269 | "typeof((1,))"
1270 | ]
1271 | },
1272 | {
1273 | "cell_type": "code",
1274 | "execution_count": 58,
1275 | "metadata": {},
1276 | "outputs": [
1277 | {
1278 | "data": {
1279 | "text/plain": [
1280 | "(x = 1, y = 2)"
1281 | ]
1282 | },
1283 | "execution_count": 58,
1284 | "metadata": {},
1285 | "output_type": "execute_result"
1286 | }
1287 | ],
1288 | "source": [
1289 | "n = (x=1, y=2) # use keyword assignments in a tuple to create a NamedTuple"
1290 | ]
1291 | },
1292 | {
1293 | "cell_type": "code",
1294 | "execution_count": 59,
1295 | "metadata": {},
1296 | "outputs": [
1297 | {
1298 | "data": {
1299 | "text/plain": [
1300 | "NamedTuple{(:x, :y),Tuple{Int64,Int64}}"
1301 | ]
1302 | },
1303 | "execution_count": 59,
1304 | "metadata": {},
1305 | "output_type": "execute_result"
1306 | }
1307 | ],
1308 | "source": [
1309 | "typeof(n)"
1310 | ]
1311 | },
1312 | {
1313 | "cell_type": "code",
1314 | "execution_count": 60,
1315 | "metadata": {},
1316 | "outputs": [
1317 | {
1318 | "data": {
1319 | "text/plain": [
1320 | "1"
1321 | ]
1322 | },
1323 | "execution_count": 60,
1324 | "metadata": {},
1325 | "output_type": "execute_result"
1326 | }
1327 | ],
1328 | "source": [
1329 | "n.x # NamedTuple fields can be accessed using dot syntax"
1330 | ]
1331 | },
1332 | {
1333 | "cell_type": "markdown",
1334 | "metadata": {},
1335 | "source": [
1336 | "## Dictionaries"
1337 | ]
1338 | },
1339 | {
1340 | "cell_type": "code",
1341 | "execution_count": 61,
1342 | "metadata": {},
1343 | "outputs": [
1344 | {
1345 | "data": {
1346 | "text/plain": [
1347 | "Dict{Any,Any} with 0 entries"
1348 | ]
1349 | },
1350 | "execution_count": 61,
1351 | "metadata": {},
1352 | "output_type": "execute_result"
1353 | }
1354 | ],
1355 | "source": [
1356 | "d = Dict()"
1357 | ]
1358 | },
1359 | {
1360 | "cell_type": "code",
1361 | "execution_count": 62,
1362 | "metadata": {},
1363 | "outputs": [
1364 | {
1365 | "data": {
1366 | "text/plain": [
1367 | "Dict{String,Int64} with 3 entries:\n",
1368 | " \"two\" => 2\n",
1369 | " \"one\" => 1\n",
1370 | " \"three\" => 3"
1371 | ]
1372 | },
1373 | "execution_count": 62,
1374 | "metadata": {},
1375 | "output_type": "execute_result"
1376 | }
1377 | ],
1378 | "source": [
1379 | "d = Dict(\"one\"=>1, \"two\"=>2, \"three\"=>3)"
1380 | ]
1381 | },
1382 | {
1383 | "cell_type": "code",
1384 | "execution_count": 63,
1385 | "metadata": {},
1386 | "outputs": [
1387 | {
1388 | "data": {
1389 | "text/plain": [
1390 | "1"
1391 | ]
1392 | },
1393 | "execution_count": 63,
1394 | "metadata": {},
1395 | "output_type": "execute_result"
1396 | }
1397 | ],
1398 | "source": [
1399 | "d[\"one\"]"
1400 | ]
1401 | },
1402 | {
1403 | "cell_type": "code",
1404 | "execution_count": 64,
1405 | "metadata": {},
1406 | "outputs": [
1407 | {
1408 | "data": {
1409 | "text/plain": [
1410 | "Base.KeySet for a Dict{String,Int64} with 3 entries. Keys:\n",
1411 | " \"two\"\n",
1412 | " \"one\"\n",
1413 | " \"three\""
1414 | ]
1415 | },
1416 | "execution_count": 64,
1417 | "metadata": {},
1418 | "output_type": "execute_result"
1419 | }
1420 | ],
1421 | "source": [
1422 | "keys(d)"
1423 | ]
1424 | },
1425 | {
1426 | "cell_type": "code",
1427 | "execution_count": 65,
1428 | "metadata": {},
1429 | "outputs": [
1430 | {
1431 | "data": {
1432 | "text/plain": [
1433 | "3-element Array{String,1}:\n",
1434 | " \"two\" \n",
1435 | " \"one\" \n",
1436 | " \"three\""
1437 | ]
1438 | },
1439 | "execution_count": 65,
1440 | "metadata": {},
1441 | "output_type": "execute_result"
1442 | }
1443 | ],
1444 | "source": [
1445 | "collect(keys(d))"
1446 | ]
1447 | },
1448 | {
1449 | "cell_type": "code",
1450 | "execution_count": 66,
1451 | "metadata": {},
1452 | "outputs": [
1453 | {
1454 | "data": {
1455 | "text/plain": [
1456 | "Base.ValueIterator for a Dict{String,Int64} with 3 entries. Values:\n",
1457 | " 2\n",
1458 | " 1\n",
1459 | " 3"
1460 | ]
1461 | },
1462 | "execution_count": 66,
1463 | "metadata": {},
1464 | "output_type": "execute_result"
1465 | }
1466 | ],
1467 | "source": [
1468 | "values(d)"
1469 | ]
1470 | },
1471 | {
1472 | "cell_type": "code",
1473 | "execution_count": 67,
1474 | "metadata": {},
1475 | "outputs": [
1476 | {
1477 | "data": {
1478 | "text/plain": [
1479 | "true"
1480 | ]
1481 | },
1482 | "execution_count": 67,
1483 | "metadata": {},
1484 | "output_type": "execute_result"
1485 | }
1486 | ],
1487 | "source": [
1488 | "haskey(d, \"one\")"
1489 | ]
1490 | },
1491 | {
1492 | "cell_type": "code",
1493 | "execution_count": 68,
1494 | "metadata": {},
1495 | "outputs": [
1496 | {
1497 | "data": {
1498 | "text/plain": [
1499 | "false"
1500 | ]
1501 | },
1502 | "execution_count": 68,
1503 | "metadata": {},
1504 | "output_type": "execute_result"
1505 | }
1506 | ],
1507 | "source": [
1508 | "haskey(d, 1)"
1509 | ]
1510 | },
1511 | {
1512 | "cell_type": "markdown",
1513 | "metadata": {},
1514 | "source": [
1515 | "## Control Flow"
1516 | ]
1517 | },
1518 | {
1519 | "cell_type": "code",
1520 | "execution_count": 69,
1521 | "metadata": {},
1522 | "outputs": [
1523 | {
1524 | "name": "stdout",
1525 | "output_type": "stream",
1526 | "text": [
1527 | "some_var is smaller than 10.\n"
1528 | ]
1529 | }
1530 | ],
1531 | "source": [
1532 | "# Let's make a variable\n",
1533 | "some_var = 5\n",
1534 | "\n",
1535 | "# Here is an if statement. Indentation is not meaningful in Julia.\n",
1536 | "if some_var > 10\n",
1537 | " println(\"some_var is totally bigger than 10.\")\n",
1538 | "elseif some_var < 10 # This elseif clause is optional.\n",
1539 | " println(\"some_var is smaller than 10.\")\n",
1540 | "else # The else clause is optional too.\n",
1541 | " println(\"some_var is indeed 10.\")\n",
1542 | "end"
1543 | ]
1544 | },
1545 | {
1546 | "cell_type": "code",
1547 | "execution_count": 70,
1548 | "metadata": {},
1549 | "outputs": [
1550 | {
1551 | "name": "stdout",
1552 | "output_type": "stream",
1553 | "text": [
1554 | "dog is a mammal\n",
1555 | "cat is a mammal\n",
1556 | "mouse is a mammal\n"
1557 | ]
1558 | }
1559 | ],
1560 | "source": [
1561 | "# For loops iterate over iterables.\n",
1562 | "# Iterable types include ranges (e.g. 1:10), Array, Set, Dict, and String.\n",
1563 | "for animal in [\"dog\", \"cat\", \"mouse\"]\n",
1564 | " println(\"$animal is a mammal\")\n",
1565 | " # You can use $ to interpolate variables or expressions into strings\n",
1566 | "end"
1567 | ]
1568 | },
1569 | {
1570 | "cell_type": "code",
1571 | "execution_count": 71,
1572 | "metadata": {},
1573 | "outputs": [
1574 | {
1575 | "name": "stdout",
1576 | "output_type": "stream",
1577 | "text": [
1578 | "mouse is a mammal\n",
1579 | "cat is a mammal\n",
1580 | "dog is a mammal\n"
1581 | ]
1582 | }
1583 | ],
1584 | "source": [
1585 | "for key_val in Dict(\"dog\"=>\"mammal\",\"cat\"=>\"mammal\",\"mouse\"=>\"mammal\")\n",
1586 | " println(\"$(key_val[1]) is a $(key_val[2])\")\n",
1587 | "end"
1588 | ]
1589 | },
1590 | {
1591 | "cell_type": "code",
1592 | "execution_count": 72,
1593 | "metadata": {},
1594 | "outputs": [
1595 | {
1596 | "name": "stdout",
1597 | "output_type": "stream",
1598 | "text": [
1599 | "mouse is a mammal\n",
1600 | "cat is a mammal\n",
1601 | "dog is a mammal\n"
1602 | ]
1603 | }
1604 | ],
1605 | "source": [
1606 | "for (k,v) in Dict(\"dog\"=>\"mammal\",\"cat\"=>\"mammal\",\"mouse\"=>\"mammal\")\n",
1607 | " println(\"$k is a $v\")\n",
1608 | "end"
1609 | ]
1610 | },
1611 | {
1612 | "cell_type": "code",
1613 | "execution_count": 73,
1614 | "metadata": {},
1615 | "outputs": [
1616 | {
1617 | "name": "stdout",
1618 | "output_type": "stream",
1619 | "text": [
1620 | "0\n",
1621 | "1\n",
1622 | "2\n",
1623 | "3\n"
1624 | ]
1625 | }
1626 | ],
1627 | "source": [
1628 | "x = 0\n",
1629 | "while x < 4\n",
1630 | " global x\n",
1631 | " println(x)\n",
1632 | " x += 1 # Shorthand for x = x + 1\n",
1633 | "end"
1634 | ]
1635 | },
1636 | {
1637 | "cell_type": "code",
1638 | "execution_count": 74,
1639 | "metadata": {},
1640 | "outputs": [],
1641 | "source": [
1642 | "# Handle exceptions with a try/catch block\n",
1643 | "try\n",
1644 | "# error(\"help\") # uncomment to throw an error that the catch block below handles\n",
1645 | "catch e\n",
1646 | " println(\"caught it $e\")\n",
1647 | "end"
1648 | ]
1649 | },
1650 | {
1651 | "cell_type": "markdown",
1652 | "metadata": {},
1653 | "source": [
1654 | "## Functions"
1655 | ]
1656 | },
1657 | {
1658 | "cell_type": "code",
1659 | "execution_count": 75,
1660 | "metadata": {},
1661 | "outputs": [
1662 | {
1663 | "name": "stdout",
1664 | "output_type": "stream",
1665 | "text": [
1666 | "x is 5 and y is 6\n"
1667 | ]
1668 | },
1669 | {
1670 | "data": {
1671 | "text/plain": [
1672 | "11"
1673 | ]
1674 | },
1675 | "execution_count": 75,
1676 | "metadata": {},
1677 | "output_type": "execute_result"
1678 | }
1679 | ],
1680 | "source": [
1681 | "function add(x, y)\n",
1682 | " println(\"x is $x and y is $y\")\n",
1683 | " # Functions return the value of their last expression (or the value given to an explicit \"return\")\n",
1684 | " x + y\n",
1685 | "end\n",
1686 | "add(5, 6) "
1687 | ]
1688 | },
1689 | {
1690 | "cell_type": "code",
1691 | "execution_count": 76,
1692 | "metadata": {},
1693 | "outputs": [
1694 | {
1695 | "data": {
1696 | "text/plain": [
1697 | "defaults (generic function with 3 methods)"
1698 | ]
1699 | },
1700 | "execution_count": 76,
1701 | "metadata": {},
1702 | "output_type": "execute_result"
1703 | }
1704 | ],
1705 | "source": [
1706 | "# You can define functions with optional positional arguments\n",
1707 | "function defaults(a,b,x=5,y=6)\n",
1708 | " return \"$a $b and $x $y\"\n",
1709 | "end"
1710 | ]
1711 | },
1712 | {
1713 | "cell_type": "code",
1714 | "execution_count": 77,
1715 | "metadata": {},
1716 | "outputs": [
1717 | {
1718 | "data": {
1719 | "text/plain": [
1720 | "\"h g and 5 6\""
1721 | ]
1722 | },
1723 | "execution_count": 77,
1724 | "metadata": {},
1725 | "output_type": "execute_result"
1726 | }
1727 | ],
1728 | "source": [
1729 | "defaults('h','g')"
1730 | ]
1731 | },
1732 | {
1733 | "cell_type": "code",
1734 | "execution_count": 78,
1735 | "metadata": {},
1736 | "outputs": [
1737 | {
1738 | "data": {
1739 | "text/plain": [
1740 | "\"h g and j 6\""
1741 | ]
1742 | },
1743 | "execution_count": 78,
1744 | "metadata": {},
1745 | "output_type": "execute_result"
1746 | }
1747 | ],
1748 | "source": [
1749 | "defaults('h','g','j')"
1750 | ]
1751 | },
1752 | {
1753 | "cell_type": "code",
1754 | "execution_count": 79,
1755 | "metadata": {},
1756 | "outputs": [
1757 | {
1758 | "data": {
1759 | "text/plain": [
1760 | "\"h g and j k\""
1761 | ]
1762 | },
1763 | "execution_count": 79,
1764 | "metadata": {},
1765 | "output_type": "execute_result"
1766 | }
1767 | ],
1768 | "source": [
1769 | "defaults('h','g','j','k')"
1770 | ]
1771 | },
1772 | {
1773 | "cell_type": "code",
1774 | "execution_count": 80,
1775 | "metadata": {},
1776 | "outputs": [
1777 | {
1778 | "data": {
1779 | "text/plain": [
1780 | "keyword_args (generic function with 1 method)"
1781 | ]
1782 | },
1783 | "execution_count": 80,
1784 | "metadata": {},
1785 | "output_type": "execute_result"
1786 | }
1787 | ],
1788 | "source": [
1789 | "# You can define functions that take keyword arguments\n",
1790 | "function keyword_args(;k1=4,name2=\"hello\") # note the ;\n",
1791 | " return Dict(\"k1\"=>k1,\"name2\"=>name2)\n",
1792 | "end"
1793 | ]
1794 | },
1795 | {
1796 | "cell_type": "code",
1797 | "execution_count": 81,
1798 | "metadata": {},
1799 | "outputs": [
1800 | {
1801 | "data": {
1802 | "text/plain": [
1803 | "Dict{String,Any} with 2 entries:\n",
1804 | " \"name2\" => \"ness\"\n",
1805 | " \"k1\" => 4"
1806 | ]
1807 | },
1808 | "execution_count": 81,
1809 | "metadata": {},
1810 | "output_type": "execute_result"
1811 | }
1812 | ],
1813 | "source": [
1814 | "keyword_args(name2=\"ness\")"
1815 | ]
1816 | },
1817 | {
1818 | "cell_type": "code",
1819 | "execution_count": 82,
1820 | "metadata": {},
1821 | "outputs": [
1822 | {
1823 | "data": {
1824 | "text/plain": [
1825 | "Dict{String,String} with 2 entries:\n",
1826 | " \"name2\" => \"hello\"\n",
1827 | " \"k1\" => \"mine\""
1828 | ]
1829 | },
1830 | "execution_count": 82,
1831 | "metadata": {},
1832 | "output_type": "execute_result"
1833 | }
1834 | ],
1835 | "source": [
1836 | "keyword_args(k1=\"mine\")"
1837 | ]
1838 | },
1839 | {
1840 | "cell_type": "code",
1841 | "execution_count": 83,
1842 | "metadata": {},
1843 | "outputs": [
1844 | {
1845 | "data": {
1846 | "text/plain": [
1847 | "Dict{String,Any} with 2 entries:\n",
1848 | " \"name2\" => \"hello\"\n",
1849 | " \"k1\" => 4"
1850 | ]
1851 | },
1852 | "execution_count": 83,
1853 | "metadata": {},
1854 | "output_type": "execute_result"
1855 | }
1856 | ],
1857 | "source": [
1858 | "keyword_args()"
1859 | ]
1860 | },
1861 | {
1862 | "cell_type": "code",
1863 | "execution_count": 84,
1864 | "metadata": {},
1865 | "outputs": [
1866 | {
1867 | "data": {
1868 | "text/plain": [
1869 | "true"
1870 | ]
1871 | },
1872 | "execution_count": 84,
1873 | "metadata": {},
1874 | "output_type": "execute_result"
1875 | }
1876 | ],
1877 | "source": [
1878 | "# This is \"stabby lambda syntax\" for creating anonymous functions\n",
1879 | "(x -> x > 2)(3) # => true"
1880 | ]
1881 | },
1882 | {
1883 | "cell_type": "code",
1884 | "execution_count": 85,
1885 | "metadata": {},
1886 | "outputs": [
1887 | {
1888 | "data": {
1889 | "text/plain": [
1890 | "create_adder (generic function with 1 method)"
1891 | ]
1892 | },
1893 | "execution_count": 85,
1894 | "metadata": {},
1895 | "output_type": "execute_result"
1896 | }
1897 | ],
1898 | "source": [
1899 | "# Functions can return functions: create_adder returns an anonymous function that captures x.\n",
1900 | "function create_adder(x)\n",
1901 | " y -> x + y\n",
1902 | "end"
1903 | ]
1904 | },
1905 | {
1906 | "cell_type": "code",
1907 | "execution_count": 86,
1908 | "metadata": {},
1909 | "outputs": [
1910 | {
1911 | "data": {
1912 | "text/plain": [
1913 | "create_adder2 (generic function with 1 method)"
1914 | ]
1915 | },
1916 | "execution_count": 86,
1917 | "metadata": {},
1918 | "output_type": "execute_result"
1919 | }
1920 | ],
1921 | "source": [
1922 | "# You can also name the internal function, if you want\n",
1923 | "function create_adder2(x)\n",
1924 | " function adder(y)\n",
1925 | " x + y\n",
1926 | " end\n",
1927 | " adder\n",
1928 | "end\n"
1929 | ]
1930 | },
1931 | {
1932 | "cell_type": "code",
1933 | "execution_count": 87,
1934 | "metadata": {},
1935 | "outputs": [
1936 | {
1937 | "data": {
1938 | "text/plain": [
1939 | "13"
1940 | ]
1941 | },
1942 | "execution_count": 87,
1943 | "metadata": {},
1944 | "output_type": "execute_result"
1945 | }
1946 | ],
1947 | "source": [
1948 | "add_10 = create_adder(10)\n",
1949 | "add_10(3) "
1950 | ]
1951 | },
1952 | {
1953 | "cell_type": "code",
1954 | "execution_count": 88,
1955 | "metadata": {},
1956 | "outputs": [
1957 | {
1958 | "data": {
1959 | "text/plain": [
1960 | "3-element Array{Int64,1}:\n",
1961 | " 11\n",
1962 | " 12\n",
1963 | " 13"
1964 | ]
1965 | },
1966 | "execution_count": 88,
1967 | "metadata": {},
1968 | "output_type": "execute_result"
1969 | }
1970 | ],
1971 | "source": [
1972 | "map(add_10, [1,2,3])"
1973 | ]
1974 | },
1975 | {
1976 | "cell_type": "code",
1977 | "execution_count": 89,
1978 | "metadata": {},
1979 | "outputs": [
1980 | {
1981 | "data": {
1982 | "text/plain": [
1983 | "2-element Array{Int64,1}:\n",
1984 | " 6\n",
1985 | " 7"
1986 | ]
1987 | },
1988 | "execution_count": 89,
1989 | "metadata": {},
1990 | "output_type": "execute_result"
1991 | }
1992 | ],
1993 | "source": [
1994 | "filter(x -> x > 5, [3, 4, 5, 6, 7])"
1995 | ]
1996 | },
1997 | {
1998 | "cell_type": "code",
1999 | "execution_count": 90,
2000 | "metadata": {},
2001 | "outputs": [
2002 | {
2003 | "data": {
2004 | "text/plain": [
2005 | "3-element Array{Int64,1}:\n",
2006 | " 11\n",
2007 | " 12\n",
2008 | " 13"
2009 | ]
2010 | },
2011 | "execution_count": 90,
2012 | "metadata": {},
2013 | "output_type": "execute_result"
2014 | }
2015 | ],
2016 | "source": [
2017 | "[add_10(i) for i in [1, 2, 3]]"
2018 | ]
2019 | },
2020 | {
2021 | "cell_type": "markdown",
2022 | "metadata": {},
2023 | "source": [
2024 | "## Composite Types"
2025 | ]
2026 | },
2027 | {
2028 | "cell_type": "code",
2029 | "execution_count": 91,
2030 | "metadata": {},
2031 | "outputs": [],
2032 | "source": [
2033 | "struct Tiger\n",
2034 | " taillength::Float64\n",
2035 | " coatcolor # not including a type annotation is the same as `::Any`\n",
2036 | "end"
2037 | ]
2038 | },
2039 | {
2040 | "cell_type": "code",
2041 | "execution_count": 92,
2042 | "metadata": {},
2043 | "outputs": [
2044 | {
2045 | "data": {
2046 | "text/plain": [
2047 | "Tiger(3.5, \"orange\")"
2048 | ]
2049 | },
2050 | "execution_count": 92,
2051 | "metadata": {},
2052 | "output_type": "execute_result"
2053 | }
2054 | ],
2055 | "source": [
2056 | "tigger = Tiger(3.5,\"orange\")"
2057 | ]
2058 | },
2059 | {
2060 | "cell_type": "code",
2061 | "execution_count": 93,
2062 | "metadata": {},
2063 | "outputs": [],
2064 | "source": [
2065 | "abstract type Cat end # just a name and point in the type hierarchy"
2066 | ]
2067 | },
2068 | {
2069 | "cell_type": "code",
2070 | "execution_count": 94,
2071 | "metadata": {},
2072 | "outputs": [
2073 | {
2074 | "data": {
2075 | "text/plain": [
2076 | "2-element Array{Any,1}:\n",
2077 | " Complex\n",
2078 | " Real "
2079 | ]
2080 | },
2081 | "execution_count": 94,
2082 | "metadata": {},
2083 | "output_type": "execute_result"
2084 | }
2085 | ],
2086 | "source": [
2087 | "subtypes(Number)"
2088 | ]
2089 | },
2090 | {
2091 | "cell_type": "code",
2092 | "execution_count": 95,
2093 | "metadata": {},
2094 | "outputs": [
2095 | {
2096 | "data": {
2097 | "text/plain": [
2098 | "0-element Array{Any,1}"
2099 | ]
2100 | },
2101 | "execution_count": 95,
2102 | "metadata": {},
2103 | "output_type": "execute_result"
2104 | }
2105 | ],
2106 | "source": [
2107 | "subtypes(Cat)"
2108 | ]
2109 | },
2110 | {
2111 | "cell_type": "code",
2112 | "execution_count": 96,
2113 | "metadata": {},
2114 | "outputs": [],
2115 | "source": [
2116 | "# <: is the subtyping operator\n",
2117 | "struct Lion <: Cat # Lion is a subtype of Cat\n",
2118 | " mane_color\n",
2119 | " roar::String\n",
2120 | "end"
2121 | ]
2122 | },
2123 | {
2124 | "cell_type": "code",
2125 | "execution_count": 97,
2126 | "metadata": {},
2127 | "outputs": [],
2128 | "source": [
2129 | "# You can define more constructors for your type\n",
2130 | "# Just define a function of the same name as the type\n",
2131 | "# and call an existing constructor to get a value of the correct type\n",
2132 | "Lion(roar::String) = Lion(\"green\",roar);\n",
2133 | "# This is an outer constructor because it's outside the type definition\n",
2134 | "# Note, the semicolon suppresses the output"
2135 | ]
2136 | },
2137 | {
2138 | "cell_type": "code",
2139 | "execution_count": 98,
2140 | "metadata": {},
2141 | "outputs": [],
2142 | "source": [
2143 | "struct Panther <: Cat # Panther is also a subtype of Cat\n",
2144 | " eye_color\n",
2145 | " Panther() = new(\"green\")\n",
2146 | " # Panthers will only have this constructor, and no default constructor.\n",
2147 | "end"
2148 | ]
2149 | },
2150 | {
2151 | "cell_type": "code",
2152 | "execution_count": 99,
2153 | "metadata": {},
2154 | "outputs": [
2155 | {
2156 | "data": {
2157 | "text/plain": [
2158 | "2-element Array{Any,1}:\n",
2159 | " Lion \n",
2160 | " Panther"
2161 | ]
2162 | },
2163 | "execution_count": 99,
2164 | "metadata": {},
2165 | "output_type": "execute_result"
2166 | }
2167 | ],
2168 | "source": [
2169 | "subtypes(Cat)"
2170 | ]
2171 | },
2172 | {
2173 | "cell_type": "markdown",
2174 | "metadata": {},
2175 | "source": [
2176 | "## Multiple Dispatch"
2177 | ]
2178 | },
2179 | {
2180 | "cell_type": "code",
2181 | "execution_count": 100,
2182 | "metadata": {},
2183 | "outputs": [
2184 | {
2185 | "data": {
2186 | "text/plain": [
2187 | "meow (generic function with 3 methods)"
2188 | ]
2189 | },
2190 | "execution_count": 100,
2191 | "metadata": {},
2192 | "output_type": "execute_result"
2193 | }
2194 | ],
2195 | "source": [
2196 | "function meow(animal::Lion)\n",
2197 | " animal.roar # access a struct's fields using dot notation\n",
2198 | "end\n",
2199 | "\n",
2200 | "function meow(animal::Panther)\n",
2201 | " \"grrr\"\n",
2202 | "end\n",
2203 | "\n",
2204 | "function meow(animal::Tiger)\n",
2205 | " \"rawwwr\"\n",
2206 | "end"
2207 | ]
2208 | },
2209 | {
2210 | "cell_type": "code",
2211 | "execution_count": 101,
2212 | "metadata": {},
2213 | "outputs": [
2214 | {
2215 | "data": {
2216 | "text/plain": [
2217 | "\"rawwwr\""
2218 | ]
2219 | },
2220 | "execution_count": 101,
2221 | "metadata": {},
2222 | "output_type": "execute_result"
2223 | }
2224 | ],
2225 | "source": [
2226 | "meow(tigger)"
2227 | ]
2228 | },
2229 | {
2230 | "cell_type": "code",
2231 | "execution_count": 102,
2232 | "metadata": {},
2233 | "outputs": [
2234 | {
2235 | "data": {
2236 | "text/plain": [
2237 | "\"ROAAR\""
2238 | ]
2239 | },
2240 | "execution_count": 102,
2241 | "metadata": {},
2242 | "output_type": "execute_result"
2243 | }
2244 | ],
2245 | "source": [
2246 | "meow(Lion(\"brown\",\"ROAAR\"))"
2247 | ]
2248 | },
2249 | {
2250 | "cell_type": "code",
2251 | "execution_count": 103,
2252 | "metadata": {},
2253 | "outputs": [
2254 | {
2255 | "data": {
2256 | "text/plain": [
2257 | "\"grrr\""
2258 | ]
2259 | },
2260 | "execution_count": 103,
2261 | "metadata": {},
2262 | "output_type": "execute_result"
2263 | }
2264 | ],
2265 | "source": [
2266 | "meow(Panther())"
2267 | ]
2268 | },
2269 | {
2270 | "cell_type": "markdown",
2271 | "metadata": {},
2272 | "source": [
2273 | "## Native Code"
2274 | ]
2275 | },
2276 | {
2277 | "cell_type": "code",
2278 | "execution_count": 104,
2279 | "metadata": {},
2280 | "outputs": [
2281 | {
2282 | "data": {
2283 | "text/plain": [
2284 | "square (generic function with 1 method)"
2285 | ]
2286 | },
2287 | "execution_count": 104,
2288 | "metadata": {},
2289 | "output_type": "execute_result"
2290 | }
2291 | ],
2292 | "source": [
2293 | "square(l) = l * l"
2294 | ]
2295 | },
2296 | {
2297 | "cell_type": "code",
2298 | "execution_count": 105,
2299 | "metadata": {},
2300 | "outputs": [
2301 | {
2302 | "data": {
2303 | "text/plain": [
2304 | "25"
2305 | ]
2306 | },
2307 | "execution_count": 105,
2308 | "metadata": {},
2309 | "output_type": "execute_result"
2310 | }
2311 | ],
2312 | "source": [
2313 | "square(5)"
2314 | ]
2315 | },
2316 | {
2317 | "cell_type": "code",
2318 | "execution_count": 106,
2319 | "metadata": {},
2320 | "outputs": [
2321 | {
2322 | "name": "stdout",
2323 | "output_type": "stream",
2324 | "text": [
2325 | "\t.text\n",
2326 | "; ┌ @ In[104]:1 within `square'\n",
2327 | "\tpushq\t%rbp\n",
2328 | "\tmovq\t%rsp, %rbp\n",
2329 | "; │┌ @ int.jl:54 within `*'\n",
2330 | "\timull\t%ecx, %ecx\n",
2331 | "; │└\n",
2332 | "\tmovl\t%ecx, %eax\n",
2333 | "\tpopq\t%rbp\n",
2334 | "\tretq\n",
2335 | "\tnopl\t(%rax,%rax)\n",
2336 | "; └\n"
2337 | ]
2338 | }
2339 | ],
2340 | "source": [
2341 | "code_native(square, (Int32,))"
2342 | ]
2343 | },
2344 | {
2345 | "cell_type": "code",
2346 | "execution_count": 107,
2347 | "metadata": {},
2348 | "outputs": [
2349 | {
2350 | "name": "stdout",
2351 | "output_type": "stream",
2352 | "text": [
2353 | "\t.text\n",
2354 | "; ┌ @ In[104]:1 within `square'\n",
2355 | "\tpushq\t%rbp\n",
2356 | "\tmovq\t%rsp, %rbp\n",
2357 | "; │┌ @ float.jl:399 within `*'\n",
2358 | "\tvmulsd\t%xmm0, %xmm0, %xmm0\n",
2359 | "; │└\n",
2360 | "\tpopq\t%rbp\n",
2361 | "\tretq\n",
2362 | "\tnopw\t(%rax,%rax)\n",
2363 | "; └\n"
2364 | ]
2365 | }
2366 | ],
2367 | "source": [
2368 | "code_native(square, (Float64,))"
2369 | ]
2370 | },
2371 | {
2372 | "cell_type": "code",
2373 | "execution_count": 108,
2374 | "metadata": {},
2375 | "outputs": [
2376 | {
2377 | "name": "stdout",
2378 | "output_type": "stream",
2379 | "text": [
2380 | "\n",
2381 | "; @ In[104]:1 within `square'\n",
2382 | "; Function Attrs: uwtable\n",
2383 | "define i32 @julia_square_17289(i32) #0 {\n",
2384 | "top:\n",
2385 | "; ┌ @ int.jl:54 within `*'\n",
2386 | " %1 = mul i32 %0, %0\n",
2387 | "; └\n",
2388 | " ret i32 %1\n",
2389 | "}\n"
2390 | ]
2391 | }
2392 | ],
2393 | "source": [
2394 | "code_llvm(square, (Int32,))"
2395 | ]
2396 | }
2397 | ],
2398 | "metadata": {
2399 | "@webio": {
2400 | "lastCommId": null,
2401 | "lastKernelId": null
2402 | },
2403 | "kernelspec": {
2404 | "display_name": "Julia 1.2.0",
2405 | "language": "julia",
2406 | "name": "julia-1.2"
2407 | },
2408 | "language_info": {
2409 | "file_extension": ".jl",
2410 | "mimetype": "application/julia",
2411 | "name": "julia",
2412 | "version": "1.2.0"
2413 | }
2414 | },
2415 | "nbformat": 4,
2416 | "nbformat_minor": 1
2417 | }
2418 |
--------------------------------------------------------------------------------
/03-Inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Inference"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "using Distributions\n",
17 | "using BayesNets"
18 | ]
19 | },
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {},
23 | "source": [
24 | "## Inference for Classification"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 2,
30 | "metadata": {},
31 | "outputs": [
32 | {
33 | "data": {
34 | "image/svg+xml": [
35 | "\n",
36 | "\n",
129 | "\n"
130 | ],
131 | "text/plain": [
132 | "BayesNet{CPD}({3, 2} directed simple Int64 graph, CPD[StaticCPD{NamedCategorical{String}}(:Class, Symbol[], NamedCategorical with entries:\n",
133 | "\t 0.5000: aircraft\n",
134 | "\t 0.5000: bird\n",
135 | "), FunctionalCPD{Normal}(:Airspeed, Symbol[:Class], airspeedDistributions), FunctionalCPD{NamedCategorical}(:Fluctuation, Symbol[:Class], fluctuationDistributions)], Dict(:Class => 1,:Fluctuation => 3,:Airspeed => 2))"
136 | ]
137 | },
138 | "execution_count": 2,
139 | "metadata": {},
140 | "output_type": "execute_result"
141 | }
142 | ],
143 | "source": [
144 | "b = BayesNet()\n",
145 | "\n",
146 | "# Set uniform prior over Class\n",
147 | "push!(b, StaticCPD(:Class, NamedCategorical([\"bird\", \"aircraft\"], [0.5, 0.5])))\n",
148 | "\n",
149 | "fluctuationStates = [\"low\", \"hi\"]\n",
150 | "fluctuationDistributions(a::Assignment) = a[:Class] == \"bird\" ? NamedCategorical(fluctuationStates, [0.1, 0.9]) : NamedCategorical(fluctuationStates, [0.9, 0.1])\n",
151 | "push!(b, FunctionalCPD{NamedCategorical}(:Fluctuation, [:Class], fluctuationDistributions))\n",
152 | "\n",
153 | "# if Bird, then Airspeed ~ N(45,10)\n",
154 | "# if Aircraft, then Airspeed ~ N(100,40)\n",
155 | "airspeedDistributions(a::Assignment) = a[:Class] == \"bird\" ? Normal(45,10) : Normal(100,40)\n",
156 | "push!(b, FunctionalCPD{Normal}(:Airspeed, [:Class], airspeedDistributions))"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 3,
162 | "metadata": {},
163 | "outputs": [
164 | {
165 | "data": {
166 | "text/plain": [
167 | "0.00026995483256594033"
168 | ]
169 | },
170 | "execution_count": 3,
171 | "metadata": {},
172 | "output_type": "execute_result"
173 | }
174 | ],
175 | "source": [
176 | "pb = pdf(b, :Class=>\"bird\", :Airspeed=>65, :Fluctuation=>\"low\")"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": 4,
182 | "metadata": {},
183 | "outputs": [
184 | {
185 | "data": {
186 | "text/plain": [
187 | "0.003060618731758615"
188 | ]
189 | },
190 | "execution_count": 4,
191 | "metadata": {},
192 | "output_type": "execute_result"
193 | }
194 | ],
195 | "source": [
196 | "pa = pdf(b, :Class=>\"aircraft\", :Airspeed=>65, :Fluctuation=>\"low\")"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 5,
202 | "metadata": {},
203 | "outputs": [
204 | {
205 | "data": {
206 | "text/plain": [
207 | "0.9189464435022358"
208 | ]
209 | },
210 | "execution_count": 5,
211 | "metadata": {},
212 | "output_type": "execute_result"
213 | }
214 | ],
215 | "source": [
216 | "# Probability of aircraft given data\n",
217 | "pa / (pa + pb)"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": 6,
223 | "metadata": {},
224 | "outputs": [
225 | {
226 | "data": {
227 | "text/plain": [
228 | "2-element Array{Float64,1}:\n",
229 | " 0.00026995483256594033\n",
230 | " 0.003060618731758615 "
231 | ]
232 | },
233 | "execution_count": 6,
234 | "metadata": {},
235 | "output_type": "execute_result"
236 | }
237 | ],
238 | "source": [
239 | "# View (unnormalized) distribution as a vector\n",
240 | "d = [pb, pa]"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 7,
246 | "metadata": {},
247 | "outputs": [
248 | {
249 | "data": {
250 | "text/plain": [
251 | "2-element Array{Float64,1}:\n",
252 | " 0.08105355649776423\n",
253 | " 0.9189464435022358 "
254 | ]
255 | },
256 | "execution_count": 7,
257 | "metadata": {},
258 | "output_type": "execute_result"
259 | }
260 | ],
261 | "source": [
262 | "# Now normalize\n",
263 | "d / sum(d)"
264 | ]
265 | },
266 | {
267 | "cell_type": "markdown",
268 | "metadata": {},
269 | "source": [
270 | "## Inference in temporal models"
271 | ]
272 | },
273 | {
274 | "cell_type": "markdown",
275 | "metadata": {},
276 | "source": [
277 | "Here is a simple crying baby temporal model. Whether the baby is crying is a noisy indication of whether the baby is hungry."
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": 8,
283 | "metadata": {},
284 | "outputs": [],
285 | "source": [
286 | "struct State\n",
287 | " hungry\n",
288 | "end\n",
289 | "struct Observation\n",
290 | " crying\n",
291 | "end\n",
292 | "\n",
293 | "States = [State(false), State(true)]\n",
294 | "Observations = [Observation(false), Observation(true)]\n",
295 | "\n",
296 | "# P(o|s)\n",
297 | "function P(o::Observation, s::State)\n",
298 | " if s.hungry\n",
299 | " return o.crying ? 0.8 : 0.2\n",
300 | " else\n",
301 | " return o.crying ? 0.1 : 0.9\n",
302 | " end\n",
303 | "end\n",
304 | "\n",
305 | "# P(s' | s)\n",
306 | "function P(s1::State, s0::State)\n",
307 | " if s0.hungry\n",
308 | " return s1.hungry ? 0.9 : 0.1\n",
309 | " else\n",
310 | " return s1.hungry ? 0.6 : 0.4\n",
311 | " end\n",
312 | "end\n",
313 | "\n",
314 | "# P(s)\n",
315 | "P(s::State) = 1/length(States)\n",
316 | "\n",
317 | "mutable struct Belief\n",
318 | " p::Vector{Float64}\n",
319 | "end"
320 | ]
321 | },
322 | {
323 | "cell_type": "markdown",
324 | "metadata": {},
325 | "source": [
326 | "Here are some sampling functions."
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": 9,
332 | "metadata": {},
333 | "outputs": [],
334 | "source": [
335 | "sampleState() = States[rand(Distributions.Categorical(Float64[P(s) for s in States]))]\n",
336 | "sampleState(s::State) = States[rand(Distributions.Categorical(Float64[P(s1, s) for s1 in States]))]\n",
337 | "sampleObservation(s::State) = Observations[rand(Distributions.Categorical(Float64[P(o, s) for o in Observations]))]\n",
338 | "function generateSequence(steps)\n",
339 | " S = State[]\n",
340 | " O = Observation[]\n",
341 | " s = sampleState()\n",
342 | " push!(S, s) \n",
343 | " o = sampleObservation(s)\n",
344 | " push!(O, o)\n",
345 | " for t = 2:steps\n",
346 | " s = sampleState(s)\n",
347 | " push!(S, s) \n",
348 | " o = sampleObservation(s)\n",
349 | " push!(O, o)\n",
350 | " end\n",
351 | " (S, O)\n",
352 | "end\n",
353 | "(S, O) = generateSequence(20);"
354 | ]
355 | },
356 | {
357 | "cell_type": "markdown",
358 | "metadata": {},
359 | "source": [
360 | "Update a belief as follows:\n",
361 | "\n",
362 | "$b_1(s) \\propto P(o \\mid s) \\sum_{s'} P(s \\mid s') b_0(s')$"
363 | ]
364 | },
365 | {
366 | "cell_type": "code",
367 | "execution_count": 10,
368 | "metadata": {},
369 | "outputs": [],
370 | "source": [
371 | "function update(b0::Belief, o::Observation)\n",
372 | " b1 = Belief(zeros(length(States)))\n",
373 | " for i = 1:length(States)\n",
374 | " s1 = States[i]\n",
375 | " b1.p[i] = P(o, s1) * sum([P(s1,States[j]) * b0.p[j] for j = 1:length(States)])\n",
376 | " end\n",
377 | " b1.p = b1.p / sum(b1.p)\n",
378 | " b1\n",
379 | "end;"
380 | ]
381 | },
382 | {
383 | "cell_type": "code",
384 | "execution_count": 11,
385 | "metadata": {},
386 | "outputs": [
387 | {
388 | "name": "stdout",
389 | "output_type": "stream",
390 | "text": [
391 | "s\to\tP(hungry)\n",
392 | "1\t1\t0.960\n",
393 | "1\t1\t0.984\n",
394 | "0\t0\t0.655\n",
395 | "1\t1\t0.969\n",
396 | "0\t0\t0.644\n",
397 | "1\t1\t0.968\n",
398 | "1\t1\t0.985\n",
399 | "0\t0\t0.656\n",
400 | "1\t0\t0.465\n",
401 | "1\t1\t0.958\n",
402 | "1\t1\t0.984\n",
403 | "1\t0\t0.655\n",
404 | "1\t0\t0.465\n",
405 | "1\t1\t0.958\n",
406 | "1\t1\t0.984\n",
407 | "1\t1\t0.986\n",
408 | "1\t0\t0.656\n",
409 | "1\t1\t0.969\n",
410 | "1\t1\t0.985\n",
411 | "1\t1\t0.986\n"
412 | ]
413 | }
414 | ],
415 | "source": [
416 | "using Printf\n",
417 | "function printBeliefs(S::Vector{State}, O::Vector{Observation})\n",
418 | " print(\"s\\to\\tP(hungry)\\n\")\n",
419 | " n = length(S)\n",
420 | " b = Belief([0.5, 0.5])\n",
421 | " for t = 1:n\n",
422 | " b = update(b, O[t])\n",
423 | " Printf.@printf(\"%.0f\\t%.0f\\t%.3f\\n\", float(S[t].hungry), float(O[t].crying), b.p[2])\n",
424 | " end\n",
425 | "end\n",
426 | "printBeliefs(S, O)"
427 | ]
428 | },
429 | {
430 | "cell_type": "markdown",
431 | "metadata": {},
432 | "source": [
433 | "## Exact Inference"
434 | ]
435 | },
436 | {
437 | "cell_type": "code",
438 | "execution_count": 12,
439 | "metadata": {},
440 | "outputs": [
441 | {
442 | "data": {
443 | "image/svg+xml": [
444 | "\n",
445 | "\n",
494 | "\n"
495 | ],
496 | "text/plain": [
497 | "BayesNet{CategoricalCPD{DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}}}({5, 4} directed simple Int64 graph, CategoricalCPD{DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}}[2 instantiations:\n",
498 | " B (2), 2 instantiations:\n",
499 | " S (2), 8 instantiations:\n",
500 | " E (2)\n",
501 | " B (2)\n",
502 | " S (2), 4 instantiations:\n",
503 | " C (2)\n",
504 | " E (2), 4 instantiations:\n",
505 | " D (2)\n",
506 | " E (2)], Dict(:D => 5,:B => 1,:S => 2,:E => 3,:C => 4))"
507 | ]
508 | },
509 | "execution_count": 12,
510 | "metadata": {},
511 | "output_type": "execute_result"
512 | }
513 | ],
514 | "source": [
515 | "b = DiscreteBayesNet()\n",
516 | "push!(b, DiscreteCPD(:B, [0.1,0.9]))\n",
517 | "push!(b, DiscreteCPD(:S, [0.5,0.5]))\n",
518 | "push!(b, rand_cpd(b, 2, :E, [:B, :S]))\n",
519 | "push!(b, rand_cpd(b, 2, :D, [:E]))\n",
520 | "push!(b, rand_cpd(b, 2, :C, [:E]))"
521 | ]
522 | },
523 | {
524 | "cell_type": "markdown",
525 | "metadata": {},
526 | "source": [
527 | "Compute \n",
528 | "\n",
529 | "$P(b^1, d^1, c^1) = \\sum_s \\sum_e P(b^1)P(s)P(e \\mid b^1, s)P(d^1 \\mid e)P(c^1 \\mid e)$"
530 | ]
531 | },
532 | {
533 | "cell_type": "code",
534 | "execution_count": 13,
535 | "metadata": {},
536 | "outputs": [
537 | {
538 | "data": {
539 | "text/plain": [
540 | "Dict{Symbol,Any} with 3 entries:\n",
541 | " :D => 2\n",
542 | " :B => 2\n",
543 | " :C => 2"
544 | ]
545 | },
546 | "execution_count": 13,
547 | "metadata": {},
548 | "output_type": "execute_result"
549 | }
550 | ],
551 | "source": [
552 | "a = Assignment(:B=>2, :D=>2, :C=>2)"
553 | ]
554 | },
555 | {
556 | "cell_type": "code",
557 | "execution_count": 14,
558 | "metadata": {},
559 | "outputs": [
560 | {
561 | "data": {
562 | "text/html": [
563 | "
| B | S | E | D | C | p |
---|
| Int64⍰ | Int64⍰ | Int64⍰ | Int64⍰ | Int64⍰ | Float64 |
---|
4 rows × 6 columns
1 | 2 | 1 | 1 | 2 | 2 | 0.00399303 |
---|
2 | 2 | 1 | 2 | 2 | 2 | 0.0534892 |
---|
3 | 2 | 2 | 1 | 2 | 2 | 0.149917 |
---|
4 | 2 | 2 | 2 | 2 | 2 | 0.0145783 |
---|
"
564 | ],
565 | "text/plain": [
566 | "Table(4×6 DataFrame\n",
567 | "│ Row │ B │ S │ E │ D │ C │ p │\n",
568 | "│ │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n",
569 | "├─────┼────────┼────────┼────────┼────────┼────────┼────────────┤\n",
570 | "│ 1 │ 2 │ 1 │ 1 │ 2 │ 2 │ 0.00399303 │\n",
571 | "│ 2 │ 2 │ 1 │ 2 │ 2 │ 2 │ 0.0534892 │\n",
572 | "│ 3 │ 2 │ 2 │ 1 │ 2 │ 2 │ 0.149917 │\n",
573 | "│ 4 │ 2 │ 2 │ 2 │ 2 │ 2 │ 0.0145783 │)"
574 | ]
575 | },
576 | "execution_count": 14,
577 | "metadata": {},
578 | "output_type": "execute_result"
579 | }
580 | ],
581 | "source": [
582 | "T = table(b,:B,a)*table(b,:S)*table(b,:E,a)*table(b,:D,a)*table(b,:C,a)"
583 | ]
584 | },
585 | {
586 | "cell_type": "markdown",
587 | "metadata": {},
588 | "source": [
589 | "The character ⍰ indicates that the column can hold `missing` values (its element type allows `Missing`)."
590 | ]
591 | },
592 | {
593 | "cell_type": "code",
594 | "execution_count": 15,
595 | "metadata": {},
596 | "outputs": [
597 | {
598 | "data": {
599 | "text/html": [
600 | " | B | D | C | p |
---|
| Int64⍰ | Int64⍰ | Int64⍰ | Float64 |
---|
1 rows × 4 columns
1 | 2 | 2 | 2 | 0.221977 |
---|
"
601 | ],
602 | "text/plain": [
603 | "Table(1×4 DataFrame\n",
604 | "│ Row │ B │ D │ C │ p │\n",
605 | "│ │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mInt64⍰\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n",
606 | "├─────┼────────┼────────┼────────┼──────────┤\n",
607 | "│ 1 │ 2 │ 2 │ 2 │ 0.221977 │)"
608 | ]
609 | },
610 | "execution_count": 15,
611 | "metadata": {},
612 | "output_type": "execute_result"
613 | }
614 | ],
615 | "source": [
616 | "sumout(T, [:S, :E])"
617 | ]
618 | },
619 | {
620 | "cell_type": "markdown",
621 | "metadata": {},
622 | "source": [
623 | "## Approximate Inference"
624 | ]
625 | },
626 | {
627 | "cell_type": "code",
628 | "execution_count": 16,
629 | "metadata": {},
630 | "outputs": [
631 | {
632 | "data": {
633 | "image/svg+xml": [
634 | "\n",
635 | "\n",
684 | "\n"
685 | ],
686 | "text/plain": [
687 | "BayesNet{CategoricalCPD{DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}}}({5, 4} directed simple Int64 graph, CategoricalCPD{DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}}[2 instantiations:\n",
688 | " B (2), 2 instantiations:\n",
689 | " S (2), 8 instantiations:\n",
690 | " E (2)\n",
691 | " B (2)\n",
692 | " S (2), 4 instantiations:\n",
693 | " C (2)\n",
694 | " E (2), 4 instantiations:\n",
695 | " D (2)\n",
696 | " E (2)], Dict(:D => 5,:B => 1,:S => 2,:E => 3,:C => 4))"
697 | ]
698 | },
699 | "execution_count": 16,
700 | "metadata": {},
701 | "output_type": "execute_result"
702 | }
703 | ],
704 | "source": [
705 | "b = DiscreteBayesNet()\n",
706 | "push!(b, DiscreteCPD(:B, [0.1,0.9]))\n",
707 | "push!(b, DiscreteCPD(:S, [0.5,0.5]))\n",
708 | "push!(b, rand_cpd(b, 2, :E, [:B, :S]))\n",
709 | "push!(b, rand_cpd(b, 2, :D, [:E]))\n",
710 | "push!(b, rand_cpd(b, 2, :C, [:E]))"
711 | ]
712 | },
713 | {
714 | "cell_type": "code",
715 | "execution_count": 17,
716 | "metadata": {},
717 | "outputs": [
718 | {
719 | "data": {
720 | "text/plain": [
721 | "Dict{Symbol,Any} with 5 entries:\n",
722 | " :D => 2\n",
723 | " :B => 2\n",
724 | " :S => 2\n",
725 | " :E => 2\n",
726 | " :C => 1"
727 | ]
728 | },
729 | "execution_count": 17,
730 | "metadata": {},
731 | "output_type": "execute_result"
732 | }
733 | ],
734 | "source": [
735 | "rand(b)"
736 | ]
737 | },
738 | {
739 | "cell_type": "code",
740 | "execution_count": 18,
741 | "metadata": {},
742 | "outputs": [
743 | {
744 | "data": {
745 | "text/html": [
746 | " | B | S | E | C | D |
---|
| Int64 | Int64 | Int64 | Int64 | Int64 |
---|
8 rows × 5 columns
1 | 2 | 1 | 2 | 2 | 2 |
---|
2 | 1 | 1 | 1 | 2 | 1 |
---|
3 | 2 | 2 | 1 | 1 | 2 |
---|
4 | 2 | 2 | 1 | 2 | 2 |
---|
5 | 2 | 2 | 1 | 2 | 1 |
---|
6 | 2 | 2 | 2 | 2 | 1 |
---|
7 | 2 | 2 | 2 | 2 | 2 |
---|
8 | 2 | 1 | 2 | 2 | 2 |
---|
"
747 | ],
748 | "text/latex": [
749 | "\\begin{tabular}{r|ccccc}\n",
750 | "\t& B & S & E & C & D\\\\\n",
751 | "\t\\hline\n",
752 | "\t& Int64 & Int64 & Int64 & Int64 & Int64\\\\\n",
753 | "\t\\hline\n",
754 | "\t1 & 2 & 1 & 2 & 2 & 2 \\\\\n",
755 | "\t2 & 1 & 1 & 1 & 2 & 1 \\\\\n",
756 | "\t3 & 2 & 2 & 1 & 1 & 2 \\\\\n",
757 | "\t4 & 2 & 2 & 1 & 2 & 2 \\\\\n",
758 | "\t5 & 2 & 2 & 1 & 2 & 1 \\\\\n",
759 | "\t6 & 2 & 2 & 2 & 2 & 1 \\\\\n",
760 | "\t7 & 2 & 2 & 2 & 2 & 2 \\\\\n",
761 | "\t8 & 2 & 1 & 2 & 2 & 2 \\\\\n",
762 | "\\end{tabular}\n"
763 | ],
764 | "text/plain": [
765 | "8×5 DataFrame\n",
766 | "│ Row │ B │ S │ E │ C │ D │\n",
767 | "│ │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m │ \u001b[90mInt64\u001b[39m │\n",
768 | "├─────┼───────┼───────┼───────┼───────┼───────┤\n",
769 | "│ 1 │ 2 │ 1 │ 2 │ 2 │ 2 │\n",
770 | "│ 2 │ 1 │ 1 │ 1 │ 2 │ 1 │\n",
771 | "│ 3 │ 2 │ 2 │ 1 │ 1 │ 2 │\n",
772 | "│ 4 │ 2 │ 2 │ 1 │ 2 │ 2 │\n",
773 | "│ 5 │ 2 │ 2 │ 1 │ 2 │ 1 │\n",
774 | "│ 6 │ 2 │ 2 │ 2 │ 2 │ 1 │\n",
775 | "│ 7 │ 2 │ 2 │ 2 │ 2 │ 2 │\n",
776 | "│ 8 │ 2 │ 1 │ 2 │ 2 │ 2 │"
777 | ]
778 | },
779 | "execution_count": 18,
780 | "metadata": {},
781 | "output_type": "execute_result"
782 | }
783 | ],
784 | "source": [
785 | "rand(b, 8)"
786 | ]
787 | },
788 | {
789 | "cell_type": "markdown",
790 | "metadata": {},
791 | "source": [
792 | "### Example chemical detection network"
793 | ]
794 | },
795 | {
796 | "cell_type": "code",
797 | "execution_count": 19,
798 | "metadata": {},
799 | "outputs": [
800 | {
801 | "data": {
802 | "image/svg+xml": [
803 | "\n",
804 | "\n",
863 | "\n"
864 | ],
865 | "text/plain": [
866 | "BayesNet{CPD}({2, 1} directed simple Int64 graph, CPD[StaticCPD{Bernoulli{Float64}}(:Present, Symbol[], Bernoulli{Float64}(p=0.001)), FunctionalCPD{Bernoulli}(:Detected, Symbol[:Present], getfield(Main, Symbol(\"##5#6\"))())], Dict(:Present => 1,:Detected => 2))"
867 | ]
868 | },
869 | "execution_count": 19,
870 | "metadata": {},
871 | "output_type": "execute_result"
872 | }
873 | ],
874 | "source": [
875 | "b = BayesNet()\n",
876 | "push!(b, StaticCPD(:Present, Bernoulli(0.001)))\n",
877 | "push!(b, FunctionalCPD{Bernoulli}(:Detected, [:Present], a->Bernoulli(a[:Present] == true ? 0.999 : 0.001)))"
878 | ]
879 | },
880 | {
881 | "cell_type": "code",
882 | "execution_count": 20,
883 | "metadata": {},
884 | "outputs": [
885 | {
886 | "data": {
887 | "text/html": [
888 | " | Present | Detected |
---|
| Bool | Bool |
---|
10 rows × 2 columns
1 | 0 | 0 |
---|
2 | 0 | 0 |
---|
3 | 0 | 0 |
---|
4 | 0 | 0 |
---|
5 | 0 | 0 |
---|
6 | 0 | 0 |
---|
7 | 0 | 0 |
---|
8 | 0 | 0 |
---|
9 | 0 | 0 |
---|
10 | 0 | 0 |
---|
"
889 | ],
890 | "text/latex": [
891 | "\\begin{tabular}{r|cc}\n",
892 | "\t& Present & Detected\\\\\n",
893 | "\t\\hline\n",
894 | "\t& Bool & Bool\\\\\n",
895 | "\t\\hline\n",
896 | "\t1 & 0 & 0 \\\\\n",
897 | "\t2 & 0 & 0 \\\\\n",
898 | "\t3 & 0 & 0 \\\\\n",
899 | "\t4 & 0 & 0 \\\\\n",
900 | "\t5 & 0 & 0 \\\\\n",
901 | "\t6 & 0 & 0 \\\\\n",
902 | "\t7 & 0 & 0 \\\\\n",
903 | "\t8 & 0 & 0 \\\\\n",
904 | "\t9 & 0 & 0 \\\\\n",
905 | "\t10 & 0 & 0 \\\\\n",
906 | "\\end{tabular}\n"
907 | ],
908 | "text/plain": [
909 | "10×2 DataFrame\n",
910 | "│ Row │ Present │ Detected │\n",
911 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mBool\u001b[39m │\n",
912 | "├─────┼─────────┼──────────┤\n",
913 | "│ 1 │ 0 │ 0 │\n",
914 | "│ 2 │ 0 │ 0 │\n",
915 | "│ 3 │ 0 │ 0 │\n",
916 | "│ 4 │ 0 │ 0 │\n",
917 | "│ 5 │ 0 │ 0 │\n",
918 | "│ 6 │ 0 │ 0 │\n",
919 | "│ 7 │ 0 │ 0 │\n",
920 | "│ 8 │ 0 │ 0 │\n",
921 | "│ 9 │ 0 │ 0 │\n",
922 | "│ 10 │ 0 │ 0 │"
923 | ]
924 | },
925 | "execution_count": 20,
926 | "metadata": {},
927 | "output_type": "execute_result"
928 | }
929 | ],
930 | "source": [
931 | "rand(b, 10)"
932 | ]
933 | },
934 | {
935 | "cell_type": "markdown",
936 | "metadata": {},
937 | "source": [
938 | "Not very interesting, since nearly all of the samples will be (false, false)."
939 | ]
940 | },
941 | {
942 | "cell_type": "code",
943 | "execution_count": 21,
944 | "metadata": {},
945 | "outputs": [
946 | {
947 | "data": {
948 | "text/plain": [
949 | "1"
950 | ]
951 | },
952 | "execution_count": 21,
953 | "metadata": {},
954 | "output_type": "execute_result"
955 | }
956 | ],
957 | "source": [
958 | "data = rand(b, 1000)\n",
959 | "sum(data[!,:Detected] .== 1)"
960 | ]
961 | },
962 | {
963 | "cell_type": "markdown",
964 | "metadata": {},
965 | "source": [
966 | "Even with 1000 samples, we are unlikely to get many samples consistent with Detected = true, so an estimate of P(Present | Detected = true) built from them would be very poor."
967 | ]
968 | },
969 | {
970 | "cell_type": "code",
971 | "execution_count": 22,
972 | "metadata": {},
973 | "outputs": [
974 | {
975 | "data": {
976 | "text/html": [
977 | " | Present | Detected | p |
---|
| Bool | Bool | Float64 |
---|
2 rows × 3 columns
1 | 0 | 1 | 0.493 |
---|
2 | 1 | 1 | 0.507 |
---|
"
978 | ],
979 | "text/plain": [
980 | "Table(2×3 DataFrame\n",
981 | "│ Row │ Present │ Detected │ p │\n",
982 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mBool\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n",
983 | "├─────┼─────────┼──────────┼─────────┤\n",
984 | "│ 1 │ 0 │ 1 │ 0.493 │\n",
985 | "│ 2 │ 1 │ 1 │ 0.507 │)"
986 | ]
987 | },
988 | "execution_count": 22,
989 | "metadata": {},
990 | "output_type": "execute_result"
991 | }
992 | ],
993 | "source": [
994 | "samples = rand(b, RejectionSampler(:Detected=>true, max_nsamples=100000000), 1000)\n",
995 | "fit(Table, samples)"
996 | ]
997 | },
998 | {
999 | "cell_type": "markdown",
1000 | "metadata": {},
1001 | "source": [
1002 | "### Likelihood weighted sampling"
1003 | ]
1004 | },
1005 | {
1006 | "cell_type": "code",
1007 | "execution_count": 23,
1008 | "metadata": {},
1009 | "outputs": [
1010 | {
1011 | "data": {
1012 | "text/html": [
1013 | " | Detected | Present | p |
---|
| Any | Any | Any |
---|
5 rows × 3 columns
1 | 1 | 0 | 0.2 |
---|
2 | 1 | 0 | 0.2 |
---|
3 | 1 | 0 | 0.2 |
---|
4 | 1 | 0 | 0.2 |
---|
5 | 1 | 0 | 0.2 |
---|
"
1014 | ],
1015 | "text/latex": [
1016 | "\\begin{tabular}{r|ccc}\n",
1017 | "\t& Detected & Present & p\\\\\n",
1018 | "\t\\hline\n",
1019 | "\t& Any & Any & Any\\\\\n",
1020 | "\t\\hline\n",
1021 | "\t1 & 1 & 0 & 0.2 \\\\\n",
1022 | "\t2 & 1 & 0 & 0.2 \\\\\n",
1023 | "\t3 & 1 & 0 & 0.2 \\\\\n",
1024 | "\t4 & 1 & 0 & 0.2 \\\\\n",
1025 | "\t5 & 1 & 0 & 0.2 \\\\\n",
1026 | "\\end{tabular}\n"
1027 | ],
1028 | "text/plain": [
1029 | "5×3 DataFrame\n",
1030 | "│ Row │ Detected │ Present │ p │\n",
1031 | "│ │ \u001b[90mAny\u001b[39m │ \u001b[90mAny\u001b[39m │ \u001b[90mAny\u001b[39m │\n",
1032 | "├─────┼──────────┼─────────┼─────┤\n",
1033 | "│ 1 │ 1 │ 0 │ 0.2 │\n",
1034 | "│ 2 │ 1 │ 0 │ 0.2 │\n",
1035 | "│ 3 │ 1 │ 0 │ 0.2 │\n",
1036 | "│ 4 │ 1 │ 0 │ 0.2 │\n",
1037 | "│ 5 │ 1 │ 0 │ 0.2 │"
1038 | ]
1039 | },
1040 | "execution_count": 23,
1041 | "metadata": {},
1042 | "output_type": "execute_result"
1043 | }
1044 | ],
1045 | "source": [
1046 | "rand(b, LikelihoodWeightedSampler(Assignment(:Detected=>true)), 5)"
1047 | ]
1048 | },
1049 | {
1050 | "cell_type": "code",
1051 | "execution_count": 24,
1052 | "metadata": {},
1053 | "outputs": [
1054 | {
1055 | "data": {
1056 | "text/html": [
1057 | " | Detected | Present | p |
---|
| Any | Any | Float64 |
---|
2 rows × 3 columns
1 | 1 | 0 | 0.476166 |
---|
2 | 1 | 1 | 0.523834 |
---|
"
1058 | ],
1059 | "text/plain": [
1060 | "Table(2×3 DataFrame\n",
1061 | "│ Row │ Detected │ Present │ p │\n",
1062 | "│ │ \u001b[90mAny\u001b[39m │ \u001b[90mAny\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n",
1063 | "├─────┼──────────┼─────────┼──────────┤\n",
1064 | "│ 1 │ 1 │ 0 │ 0.476166 │\n",
1065 | "│ 2 │ 1 │ 1 │ 0.523834 │)"
1066 | ]
1067 | },
1068 | "execution_count": 24,
1069 | "metadata": {},
1070 | "output_type": "execute_result"
1071 | }
1072 | ],
1073 | "source": [
1074 | "fit(Table, rand(b, LikelihoodWeightedSampler(:Detected=>true), 10000))"
1075 | ]
1076 | }
1077 | ],
1078 | "metadata": {
1079 | "kernelspec": {
1080 | "display_name": "Julia 1.2.0",
1081 | "language": "julia",
1082 | "name": "julia-1.2"
1083 | },
1084 | "language_info": {
1085 | "file_extension": ".jl",
1086 | "mimetype": "application/julia",
1087 | "name": "julia",
1088 | "version": "1.2.0"
1089 | }
1090 | },
1091 | "nbformat": 4,
1092 | "nbformat_minor": 1
1093 | }
1094 |
--------------------------------------------------------------------------------
/06-DecisionNetworks.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Decision Networks"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "using BayesNets\n",
17 | "using DataFrames"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 2,
23 | "metadata": {},
24 | "outputs": [
25 | {
26 | "data": {
27 | "image/svg+xml": [
28 | "\n",
29 | "\n",
76 | "\n"
77 | ],
78 | "text/plain": [
79 | "BayesNet{CPD}({4, 3} directed simple Int64 graph, CPD[StaticCPD{Bernoulli{Float64}}(:D, Symbol[], Bernoulli{Float64}(p=0.01)), FunctionalCPD{Bernoulli}(:O3, Symbol[:D], getfield(Main, Symbol(\"##5#6\"))()), StaticCPD{Bernoulli{Float64}}(:O1, Symbol[:D], Bernoulli{Float64}(p=0.5)), FunctionalCPD{Bernoulli}(:O2, Symbol[:D], getfield(Main, Symbol(\"##3#4\"))())], Dict(:O2 => 4,:D => 1,:O1 => 3,:O3 => 2))"
80 | ]
81 | },
82 | "execution_count": 2,
83 | "metadata": {},
84 | "output_type": "execute_result"
85 | }
86 | ],
87 | "source": [
88 | "b = BayesNet()\n",
89 | "push!(b, StaticCPD(:D, Bernoulli(0.01)))\n",
90 | "push!(b, StaticCPD(:O1, [:D], Bernoulli(0.5))) # no real signal of whether disease is present\n",
91 | "push!(b, FunctionalCPD{Bernoulli}(:O2, [:D], a->Bernoulli(a[:D] == true ? 0.9 : 0.01)))\n",
92 | "push!(b, FunctionalCPD{Bernoulli}(:O3, [:D], a->Bernoulli(a[:D] == true ? 0.6 : 0.3)))"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 3,
98 | "metadata": {},
99 | "outputs": [
100 | {
101 | "data": {
102 | "text/html": [
103 | " | T | D | U |
---|
| Bool | Bool | Int64 |
---|
4 rows × 3 columns
1 | 0 | 0 | 0 |
---|
2 | 0 | 1 | -10 |
---|
3 | 1 | 0 | -1 |
---|
4 | 1 | 1 | -1 |
---|
"
104 | ],
105 | "text/latex": [
106 | "\\begin{tabular}{r|ccc}\n",
107 | "\t& T & D & U\\\\\n",
108 | "\t\\hline\n",
109 | "\t& Bool & Bool & Int64\\\\\n",
110 | "\t\\hline\n",
111 | "\t1 & 0 & 0 & 0 \\\\\n",
112 | "\t2 & 0 & 1 & -10 \\\\\n",
113 | "\t3 & 1 & 0 & -1 \\\\\n",
114 | "\t4 & 1 & 1 & -1 \\\\\n",
115 | "\\end{tabular}\n"
116 | ],
117 | "text/plain": [
118 | "4×3 DataFrame\n",
119 | "│ Row │ T │ D │ U │\n",
120 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mBool\u001b[39m │ \u001b[90mInt64\u001b[39m │\n",
121 | "├─────┼──────┼──────┼───────┤\n",
122 | "│ 1 │ 0 │ 0 │ 0 │\n",
123 | "│ 2 │ 0 │ 1 │ -10 │\n",
124 | "│ 3 │ 1 │ 0 │ -1 │\n",
125 | "│ 4 │ 1 │ 1 │ -1 │"
126 | ]
127 | },
128 | "execution_count": 3,
129 | "metadata": {},
130 | "output_type": "execute_result"
131 | }
132 | ],
133 | "source": [
134 | "U = DataFrame()\n",
135 | "U[!,:T] = [false, false, true, true]\n",
136 | "U[!,:D] = [false, true, false, true]\n",
137 | "U[!,:U] = [0, -10, -1, -1]\n",
138 | "U"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 4,
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "using Random\n",
148 | "function estimate_table(b::BayesNet, target::NodeName, consistent_with::Assignment; nsamples = 10000)\n",
149 | " Random.seed!(0) # reseed the global RNG so every call draws the same samples (reproducible estimates)\n",
150 | " samples = rand(b, LikelihoodWeightedSampler(consistent_with), nsamples) # weighted samples given the evidence\n",
151 | " normalize(sumout(fit(Table, samples), setdiff(names(b), [target]))) # marginalize everything except target\n",
152 | "end;"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": 5,
158 | "metadata": {},
159 | "outputs": [
160 | {
161 | "data": {
162 | "text/html": [
163 | " | D | p |
---|
| Any | Float64 |
---|
2 rows × 2 columns
1 | 0 | 0.9899 |
---|
2 | 1 | 0.0101 |
---|
"
164 | ],
165 | "text/plain": [
166 | "Table(2×2 DataFrame\n",
167 | "│ Row │ D │ p │\n",
168 | "│ │ \u001b[90mAny\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n",
169 | "├─────┼─────┼─────────┤\n",
170 | "│ 1 │ 0 │ 0.9899 │\n",
171 | "│ 2 │ 1 │ 0.0101 │)"
172 | ]
173 | },
174 | "execution_count": 5,
175 | "metadata": {},
176 | "output_type": "execute_result"
177 | }
178 | ],
179 | "source": [
180 | "D = estimate_table(b, :D, Assignment(:O1=>true))"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": 6,
186 | "metadata": {},
187 | "outputs": [
188 | {
189 | "data": {
190 | "text/html": [
191 | " | T | D | U | p |
---|
| Bool | Bool | Int64 | Float64 |
---|
4 rows × 4 columns
1 | 0 | 0 | 0 | 0.9899 |
---|
2 | 0 | 1 | -10 | 0.0101 |
---|
3 | 1 | 0 | -1 | 0.9899 |
---|
4 | 1 | 1 | -1 | 0.0101 |
---|
"
192 | ],
193 | "text/latex": [
194 | "\\begin{tabular}{r|cccc}\n",
195 | "\t& T & D & U & p\\\\\n",
196 | "\t\\hline\n",
197 | "\t& Bool & Bool & Int64 & Float64\\\\\n",
198 | "\t\\hline\n",
199 | "\t1 & 0 & 0 & 0 & 0.9899 \\\\\n",
200 | "\t2 & 0 & 1 & -10 & 0.0101 \\\\\n",
201 | "\t3 & 1 & 0 & -1 & 0.9899 \\\\\n",
202 | "\t4 & 1 & 1 & -1 & 0.0101 \\\\\n",
203 | "\\end{tabular}\n"
204 | ],
205 | "text/plain": [
206 | "4×4 DataFrame\n",
207 | "│ Row │ T │ D │ U │ p │\n",
208 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mBool\u001b[39m │ \u001b[90mInt64\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n",
209 | "├─────┼──────┼──────┼───────┼─────────┤\n",
210 | "│ 1 │ 0 │ 0 │ 0 │ 0.9899 │\n",
211 | "│ 2 │ 0 │ 1 │ -10 │ 0.0101 │\n",
212 | "│ 3 │ 1 │ 0 │ -1 │ 0.9899 │\n",
213 | "│ 4 │ 1 │ 1 │ -1 │ 0.0101 │"
214 | ]
215 | },
216 | "execution_count": 6,
217 | "metadata": {},
218 | "output_type": "execute_result"
219 | }
220 | ],
221 | "source": [
222 | "EU = join(U, D.potential, on = :D)"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": 7,
228 | "metadata": {},
229 | "outputs": [
230 | {
231 | "data": {
232 | "text/html": [
233 | " | T | x1 |
---|
| Bool | Float64 |
---|
2 rows × 2 columns
1 | 0 | -0.101 |
---|
2 | 1 | -1.0 |
---|
"
234 | ],
235 | "text/latex": [
236 | "\\begin{tabular}{r|cc}\n",
237 | "\t& T & x1\\\\\n",
238 | "\t\\hline\n",
239 | "\t& Bool & Float64\\\\\n",
240 | "\t\\hline\n",
241 | "\t1 & 0 & -0.101 \\\\\n",
242 | "\t2 & 1 & -1.0 \\\\\n",
243 | "\\end{tabular}\n"
244 | ],
245 | "text/plain": [
246 | "2×2 DataFrame\n",
247 | "│ Row │ T │ x1 │\n",
248 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n",
249 | "├─────┼──────┼─────────┤\n",
250 | "│ 1 │ 0 │ -0.101 │\n",
251 | "│ 2 │ 1 │ -1.0 │"
252 | ]
253 | },
254 | "execution_count": 7,
255 | "metadata": {},
256 | "output_type": "execute_result"
257 | }
258 | ],
259 | "source": [
260 | "using LinearAlgebra\n",
261 | "by(EU, :T, df->LinearAlgebra.dot(df[!,:U], df[!,:p]))"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": 8,
267 | "metadata": {},
268 | "outputs": [],
269 | "source": [
270 | "function diseaseEU(b::BayesNet, a::Assignment, U::DataFrame)\n",
271 | " posterior = estimate_table(b, :D, a).potential # estimated P(D | a) as a DataFrame with a :p column\n",
272 | " weighted = join(U, posterior, on = :D) # attach P(D | a) to every utility row\n",
273 | " expected = by(weighted, :T, grp->LinearAlgebra.dot(grp[!,:U], grp[!,:p])) # sum_D U(T, D) * P(D | a)\n",
274 | " rename!(expected, :x1=>:EU)\n",
275 | " return expected\n",
276 | "end;"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": 9,
282 | "metadata": {},
283 | "outputs": [
284 | {
285 | "data": {
286 | "text/html": [
287 | " | T | EU |
---|
| Bool | Float64 |
---|
2 rows × 2 columns
1 | 0 | -0.101 |
---|
2 | 1 | -1.0 |
---|
"
288 | ],
289 | "text/latex": [
290 | "\\begin{tabular}{r|cc}\n",
291 | "\t& T & EU\\\\\n",
292 | "\t\\hline\n",
293 | "\t& Bool & Float64\\\\\n",
294 | "\t\\hline\n",
295 | "\t1 & 0 & -0.101 \\\\\n",
296 | "\t2 & 1 & -1.0 \\\\\n",
297 | "\\end{tabular}\n"
298 | ],
299 | "text/plain": [
300 | "2×2 DataFrame\n",
301 | "│ Row │ T │ EU │\n",
302 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n",
303 | "├─────┼──────┼─────────┤\n",
304 | "│ 1 │ 0 │ -0.101 │\n",
305 | "│ 2 │ 1 │ -1.0 │"
306 | ]
307 | },
308 | "execution_count": 9,
309 | "metadata": {},
310 | "output_type": "execute_result"
311 | }
312 | ],
313 | "source": [
314 | "diseaseEU(b, Assignment(:O1=>true), U)"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": 10,
320 | "metadata": {},
321 | "outputs": [
322 | {
323 | "data": {
324 | "text/html": [
325 | " | T | EU |
---|
| Bool | Float64 |
---|
2 rows × 2 columns
1 | 0 | -0.101 |
---|
2 | 1 | -1.0 |
---|
"
326 | ],
327 | "text/latex": [
328 | "\\begin{tabular}{r|cc}\n",
329 | "\t& T & EU\\\\\n",
330 | "\t\\hline\n",
331 | "\t& Bool & Float64\\\\\n",
332 | "\t\\hline\n",
333 | "\t1 & 0 & -0.101 \\\\\n",
334 | "\t2 & 1 & -1.0 \\\\\n",
335 | "\\end{tabular}\n"
336 | ],
337 | "text/plain": [
338 | "2×2 DataFrame\n",
339 | "│ Row │ T │ EU │\n",
340 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n",
341 | "├─────┼──────┼─────────┤\n",
342 | "│ 1 │ 0 │ -0.101 │\n",
343 | "│ 2 │ 1 │ -1.0 │"
344 | ]
345 | },
346 | "execution_count": 10,
347 | "metadata": {},
348 | "output_type": "execute_result"
349 | }
350 | ],
351 | "source": [
352 | "diseaseEU(b, Assignment(:O1=>false), U)"
353 | ]
354 | },
355 | {
356 | "cell_type": "code",
357 | "execution_count": 11,
358 | "metadata": {},
359 | "outputs": [
360 | {
361 | "data": {
362 | "text/html": [
363 | " | T | EU |
---|
| Bool | Float64 |
---|
2 rows × 2 columns
1 | 0 | -4.78698 |
---|
2 | 1 | -1.0 |
---|
"
364 | ],
365 | "text/latex": [
366 | "\\begin{tabular}{r|cc}\n",
367 | "\t& T & EU\\\\\n",
368 | "\t\\hline\n",
369 | "\t& Bool & Float64\\\\\n",
370 | "\t\\hline\n",
371 | "\t1 & 0 & -4.78698 \\\\\n",
372 | "\t2 & 1 & -1.0 \\\\\n",
373 | "\\end{tabular}\n"
374 | ],
375 | "text/plain": [
376 | "2×2 DataFrame\n",
377 | "│ Row │ T │ EU │\n",
378 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n",
379 | "├─────┼──────┼──────────┤\n",
380 | "│ 1 │ 0 │ -4.78698 │\n",
381 | "│ 2 │ 1 │ -1.0 │"
382 | ]
383 | },
384 | "execution_count": 11,
385 | "metadata": {},
386 | "output_type": "execute_result"
387 | }
388 | ],
389 | "source": [
390 | "diseaseEU(b, Assignment(:O2=>true), U)"
391 | ]
392 | },
393 | {
394 | "cell_type": "code",
395 | "execution_count": 12,
396 | "metadata": {},
397 | "outputs": [
398 | {
399 | "data": {
400 | "text/html": [
401 | " | T | EU |
---|
| Bool | Float64 |
---|
2 rows × 2 columns
1 | 0 | -0.19998 |
---|
2 | 1 | -1.0 |
---|
"
402 | ],
403 | "text/latex": [
404 | "\\begin{tabular}{r|cc}\n",
405 | "\t& T & EU\\\\\n",
406 | "\t\\hline\n",
407 | "\t& Bool & Float64\\\\\n",
408 | "\t\\hline\n",
409 | "\t1 & 0 & -0.19998 \\\\\n",
410 | "\t2 & 1 & -1.0 \\\\\n",
411 | "\\end{tabular}\n"
412 | ],
413 | "text/plain": [
414 | "2×2 DataFrame\n",
415 | "│ Row │ T │ EU │\n",
416 | "│ │ \u001b[90mBool\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n",
417 | "├─────┼──────┼──────────┤\n",
418 | "│ 1 │ 0 │ -0.19998 │\n",
419 | "│ 2 │ 1 │ -1.0 │"
420 | ]
421 | },
422 | "execution_count": 12,
423 | "metadata": {},
424 | "output_type": "execute_result"
425 | }
426 | ],
427 | "source": [
428 | "diseaseEU(b, Assignment(:O3=>true), U)"
429 | ]
430 | }
431 | ],
432 | "metadata": {
433 | "kernelspec": {
434 | "display_name": "Julia 1.2.0",
435 | "language": "julia",
436 | "name": "julia-1.2"
437 | },
438 | "language_info": {
439 | "file_extension": ".jl",
440 | "mimetype": "application/julia",
441 | "name": "julia",
442 | "version": "1.2.0"
443 | }
444 | },
445 | "nbformat": 4,
446 | "nbformat_minor": 1
447 | }
448 |
--------------------------------------------------------------------------------
/12-ModelFreeReinforcementLearning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "include(\"gridworld.jl\")\n",
10 | "g = DMUGridWorld();"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {},
16 | "source": [
17 | "Let's apply Q-learning from Algorithm 5.3 in the text. We'll train for 1000 episodes of at most 101 steps each (`t = 0:100`), ending an episode early if it reaches a terminal state:"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 2,
23 | "metadata": {},
24 | "outputs": [
25 | {
26 | "data": {
27 | "text/plain": [
28 | "Qlearn (generic function with 1 method)"
29 | ]
30 | },
31 | "execution_count": 2,
32 | "metadata": {},
33 | "output_type": "execute_result"
34 | }
35 | ],
36 | "source": [
37 | "function Qlearn(g, alpha, epsilon)\n",
38 | " # Tabular Q-learning; returns a Dict mapping each visited state to a vector of Q-values (one per action)\n",
39 | " Q = Dict{Int, Vector{Float64}}()\n",
40 | " # alpha: step size of the Q update; epsilon: probability of taking a uniformly random action\n",
41 | " # initialize Q-values at initial state (s = 1)\n",
42 | " Q[1] = zeros(n_actions(g))\n",
43 | " \n",
44 | " # 1000 training episodes, each at most 101 steps (t = 0:100)\n",
45 | " for k = 1:1000\n",
46 | " s = 1\n",
47 | " for t = 0:100\n",
48 | " # epsilon-greedy: greedy action from Q, or a random action with probability epsilon\n",
49 | " a_idx = findmax(Q[s])[2]\n",
50 | " if rand() < epsilon\n",
51 | " a_idx = rand(1:n_actions(g))\n",
52 | " end\n",
53 | " a = actions(g)[a_idx]\n",
54 | "\n",
55 | " # observe new state s_{t+1} and reward rt\n",
56 | " sp, r = simulate(g, s, a)\n",
57 | "\n",
58 | " # if we've never observed this state, initialize it to zeros\n",
59 | " if !haskey(Q, sp)\n",
60 | " Q[sp] = zeros(n_actions(g))\n",
61 | " end\n",
62 | "\n",
63 | " # update Q values\n",
64 | " Q[s][a_idx] += alpha * ( r + discount(g)*maximum(Q[sp]) - Q[s][a_idx] )\n",
65 | "\n",
66 | " # update s\n",
67 | " s = sp\n",
68 | " \n",
69 | " # 73 and 88 are terminal states. Just quit if we get in them.\n",
70 | " if s == 73 || s == 88\n",
71 | " break\n",
72 | " end\n",
73 | " end\n",
74 | " end\n",
75 | " \n",
76 | " return Q\n",
77 | "end"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 3,
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "Q = Qlearn(g, 0.5, 0.5);"
87 | ]
88 | },
89 | {
90 | "cell_type": "markdown",
91 | "metadata": {},
92 | "source": [
93 | "Did the Q-learning work? Let's compare it to a random policy during 10 simulations."
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 4,
99 | "metadata": {},
100 | "outputs": [
101 | {
102 | "name": "stdout",
103 | "output_type": "stream",
104 | "text": [
105 | "Q-learned policy: -209947\n",
106 | "random poilcy: -1531508\n"
107 | ]
108 | }
109 | ],
110 | "source": [
111 | "using Random\n",
112 | "Random.seed!(1) # for reproducibility, seed random number generator\n",
113 | "\n",
114 | "r_sum = 0.0 # discounted return accumulated by the greedy Q-learned policy\n",
115 | "rr_sum = 0.0 # discounted return accumulated by the uniformly random policy\n",
116 | "\n",
117 | "# run 10 simulations; the reward received at step t is weighted by discount(g)^t\n",
118 | "for k = 1:10\n",
119 | " global r_sum, rr_sum\n",
120 | " s = 1 # initial state for policy from Q-learning\n",
121 | " sr = 1 # initial state for random policy\n",
122 | " \n",
123 | " for t = 0:100\n",
124 | " \n",
125 | " # greedy action (states never seen in training fall back to all-zero Q-values) and a random action\n",
126 | " a = actions(g)[findmax(get(Q, s, zeros(n_actions(g))))[2]]\n",
127 | " ar = actions(g)[rand(1:n_actions(g))]\n",
128 | " \n",
129 | " # advance Q simulation if you aren't in a terminal state\n",
130 | " if s != 73 && s != 88\n",
131 | " sp, r = simulate(g, s, a)\n",
132 | " r_sum += r * discount(g) ^ t\n",
133 | " s = sp\n",
134 | " end\n",
135 | " \n",
136 | " # advance random simulation if you aren't in a terminal state\n",
137 | " if sr != 73 && sr != 88\n",
138 | " spr, rr = simulate(g, sr, ar)\n",
139 | " rr_sum += rr * discount(g) ^ t\n",
140 | " sr = spr\n",
141 | " end\n",
142 | " end\n",
143 | "end\n",
144 | "\n",
145 | "println(\"Q-learned policy: \", round(r_sum; digits = 2))\n",
146 | "println(\"random policy: \", round(rr_sum; digits = 2))"
147 | ]
148 | },
149 | {
150 | "cell_type": "markdown",
151 | "metadata": {},
152 | "source": [
153 | "The cumulative reward collected by the Q-learned policy is much higher than that of the random policy."
154 | ]
155 | }
156 | ],
157 | "metadata": {
158 | "@webio": {
159 | "lastCommId": null,
160 | "lastKernelId": null
161 | },
162 | "kernelspec": {
163 | "display_name": "Julia 1.2.0",
164 | "language": "julia",
165 | "name": "julia-1.2"
166 | },
167 | "language_info": {
168 | "file_extension": ".jl",
169 | "mimetype": "application/julia",
170 | "name": "julia",
171 | "version": "1.2.0"
172 | }
173 | },
174 | "nbformat": 4,
175 | "nbformat_minor": 1
176 | }
177 |
--------------------------------------------------------------------------------
/13-StateUncertainty.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "This example shows how to perform the discrete belief update discussed in section 6.2 of the course text.\n",
8 | "\n",
9 | "Read over the description of the baby problem before seeing how to express it in math below.\n",
10 | "\n",
11 | "Let's start by defining the transition function:"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 1,
17 | "metadata": {},
18 | "outputs": [
19 | {
20 | "data": {
21 | "text/plain": [
22 | "T (generic function with 1 method)"
23 | ]
24 | },
25 | "execution_count": 1,
26 | "metadata": {},
27 | "output_type": "execute_result"
28 | }
29 | ],
30 | "source": [
31 | "function T(s, a, sp)\n",
32 | " # Transition model for the baby problem: returns P(sp | s, a).\n",
33 | " # States are :hungry / :not_hungry; every action other than :feed\n",
34 | " # is treated as not feeding the baby.\n",
35 | " \n",
36 | " # Feeding always leaves the baby not hungry, whatever s was.\n",
37 | " if a == :feed\n",
38 | " return sp == :not_hungry ? 1.0 : 0.0\n",
39 | " end\n",
40 | " \n",
41 | " # Unfed and already hungry: the baby stays hungry.\n",
42 | " if s == :hungry\n",
43 | " return sp == :hungry ? 1.0 : 0.0\n",
44 | " end\n",
45 | " \n",
46 | " # Unfed and not hungry: 10% chance that the baby becomes hungry.\n",
47 | " return sp == :hungry ? 0.1 : 0.9\n",
48 | "end"
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "metadata": {},
66 | "source": [
67 | "Let's define the observation function:"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 2,
73 | "metadata": {},
74 | "outputs": [
75 | {
76 | "data": {
77 | "text/plain": [
78 | "O (generic function with 1 method)"
79 | ]
80 | },
81 | "execution_count": 2,
82 | "metadata": {},
83 | "output_type": "execute_result"
84 | }
85 | ],
86 | "source": [
87 | "function O(a, sp, o)\n",
88 | " # Observation model: probability of observing o given that action a\n",
89 | " # led to state sp. The action a does not influence the observation.\n",
90 | " \n",
91 | " # A hungry baby cries with probability 0.8; otherwise with probability 0.1.\n",
92 | " cry_prob = sp == :hungry ? 0.8 : 0.1\n",
93 | " \n",
94 | " # Any observation other than :cry gets the complementary probability.\n",
95 | " if o != :cry\n",
96 | " return 1.0 - cry_prob\n",
97 | " end\n",
98 | " return cry_prob\n",
99 | "end"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "metadata": {},
105 | "source": [
106 | "The discrete belief update is defined in equations 6.7-6.11 of the course text:"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 3,
112 | "metadata": {},
113 | "outputs": [
114 | {
115 | "data": {
116 | "text/plain": [
117 | "update_belief (generic function with 1 method)"
118 | ]
119 | },
120 | "execution_count": 3,
121 | "metadata": {},
122 | "output_type": "execute_result"
123 | }
124 | ],
125 | "source": [
126 | "function update_belief(b, a, o; states = (:hungry, :not_hungry))\n",
127 | " # Discrete belief update; b must map every element of states to a probability. Returns the new belief Dict.\n",
128 | " bp = Dict()\n",
129 | " for sp in states\n",
130 | " sum_over_s = 0.0 # predicted probability of sp: sum over s of T(s, a, sp) * b[s]\n",
131 | " for s in states\n",
132 | " sum_over_s += T(s, a, sp) * b[s]\n",
133 | " end\n",
134 | " bp[sp] = O(a, sp, o) * sum_over_s # weight by the likelihood of observing o in sp\n",
135 | " end\n",
136 | "\n",
137 | " # normalize so that probabilities sum to 1 (an impossible observation would otherwise give NaN)\n",
138 | " bp_sum = sum(bp[sp] for sp in states)\n",
139 | " bp_sum > 0 || error(\"observation $o has zero probability under belief b\")\n",
140 | " for sp in states; bp[sp] /= bp_sum; end\n",
141 | " return bp\n",
142 | "end"
143 | ]
144 | },
145 | {
146 | "cell_type": "markdown",
147 | "metadata": {},
148 | "source": [
149 | "Let's use our functions and follow the example in section 6.2.1 of the course textbook.\n",
150 | "\n",
151 | "Step 1. We start with a uniform belief:"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": 4,
157 | "metadata": {},
158 | "outputs": [
159 | {
160 | "data": {
161 | "text/plain": [
162 | "0.5"
163 | ]
164 | },
165 | "execution_count": 4,
166 | "metadata": {},
167 | "output_type": "execute_result"
168 | }
169 | ],
170 | "source": [
171 | "b1 = Dict()\n",
172 | "b1[:hungry] = 0.5\n",
173 | "b1[:not_hungry] = 0.5"
174 | ]
175 | },
176 | {
177 | "cell_type": "markdown",
178 | "metadata": {},
179 | "source": [
180 | "Step 2. We do not feed the baby and the baby cries."
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": 5,
186 | "metadata": {},
187 | "outputs": [
188 | {
189 | "data": {
190 | "text/plain": [
191 | "Dict{Any,Any} with 2 entries:\n",
192 | " :not_hungry => 0.0927835\n",
193 | " :hungry => 0.907216"
194 | ]
195 | },
196 | "execution_count": 5,
197 | "metadata": {},
198 | "output_type": "execute_result"
199 | }
200 | ],
201 | "source": [
202 | "b2 = update_belief(b1, :not_feed, :cry)"
203 | ]
204 | },
205 | {
206 | "cell_type": "markdown",
207 | "metadata": {},
208 | "source": [
209 | "Step 3. We feed the baby and it stops crying."
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 6,
215 | "metadata": {},
216 | "outputs": [
217 | {
218 | "data": {
219 | "text/plain": [
220 | "Dict{Any,Any} with 2 entries:\n",
221 | " :not_hungry => 1.0\n",
222 | " :hungry => 0.0"
223 | ]
224 | },
225 | "execution_count": 6,
226 | "metadata": {},
227 | "output_type": "execute_result"
228 | }
229 | ],
230 | "source": [
231 | "b3 = update_belief(b2, :feed, :not_cry)"
232 | ]
233 | },
234 | {
235 | "cell_type": "markdown",
236 | "metadata": {},
237 | "source": [
238 | "Step 4. We do not feed the baby and it does not cry."
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 7,
244 | "metadata": {},
245 | "outputs": [
246 | {
247 | "data": {
248 | "text/plain": [
249 | "Dict{Any,Any} with 2 entries:\n",
250 | " :not_hungry => 0.975904\n",
251 | " :hungry => 0.0240964"
252 | ]
253 | },
254 | "execution_count": 7,
255 | "metadata": {},
256 | "output_type": "execute_result"
257 | }
258 | ],
259 | "source": [
260 | "b4 = update_belief(b3, :not_feed, :not_cry)"
261 | ]
262 | },
263 | {
264 | "cell_type": "markdown",
265 | "metadata": {},
266 | "source": [
267 | "Step 5. Again, we do not feed the baby and it does not cry."
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": 8,
273 | "metadata": {},
274 | "outputs": [
275 | {
276 | "data": {
277 | "text/plain": [
278 | "Dict{Any,Any} with 2 entries:\n",
279 | " :not_hungry => 0.970132\n",
280 | " :hungry => 0.0298684"
281 | ]
282 | },
283 | "execution_count": 8,
284 | "metadata": {},
285 | "output_type": "execute_result"
286 | }
287 | ],
288 | "source": [
289 | "b5 = update_belief(b4, :not_feed, :not_cry)"
290 | ]
291 | },
292 | {
293 | "cell_type": "markdown",
294 | "metadata": {},
295 | "source": [
296 | "Step 6. We do not feed the baby and the baby begins to cry."
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": 9,
302 | "metadata": {},
303 | "outputs": [
304 | {
305 | "data": {
306 | "text/plain": [
307 | "Dict{Any,Any} with 2 entries:\n",
308 | " :not_hungry => 0.462415\n",
309 | " :hungry => 0.537585"
310 | ]
311 | },
312 | "execution_count": 9,
313 | "metadata": {},
314 | "output_type": "execute_result"
315 | }
316 | ],
317 | "source": [
318 | "b6 = update_belief(b5, :not_feed, :cry)"
319 | ]
320 | }
321 | ],
322 | "metadata": {
323 | "@webio": {
324 | "lastCommId": null,
325 | "lastKernelId": null
326 | },
327 | "kernelspec": {
328 | "display_name": "Julia 1.2.0",
329 | "language": "julia",
330 | "name": "julia-1.2"
331 | },
332 | "language_info": {
333 | "file_extension": ".jl",
334 | "mimetype": "application/julia",
335 | "name": "julia",
336 | "version": "1.2.0"
337 | }
338 | },
339 | "nbformat": 4,
340 | "nbformat_minor": 1
341 | }
342 |
--------------------------------------------------------------------------------
/Project.toml:
--------------------------------------------------------------------------------
1 | [deps]
2 | BasicPOMCP = "d721219e-3fc6-5570-a8ef-e5402f47c49e"
3 | BayesNets = "ba4760a4-c768-5bed-964b-cf806dc591cb"
4 | BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4"
5 | ContinuumWorld = "5cbb95a3-277b-5373-895a-7e14bd91b3cc"
6 | D3Trees = "e3df1716-f71e-5df9-9e2d-98e193103c45"
7 | DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
8 | DiscreteValueIteration = "4b033969-44f6-5439-a48b-c11fa3648068"
9 | Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
10 | GridInterpolations = "bb4c363b-b914-514b-8517-4eb369bc008a"
11 | IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
12 | Interact = "c601a237-2ae4-5e1e-952c-7a85b0c7eef1"
13 | LaserTag = "041f53e1-e4f8-54ec-814d-e9e995aa38d4"
14 | MCTS = "e12ccd36-dcad-5f33-8774-9175229e7b33"
15 | NBInclude = "0db19996-df87-5ea3-a455-e3a50d440464"
16 | PGFPlots = "3b7a836e-365b-5785-a47d-02c71176b4aa"
17 | POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415"
18 | POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
19 | POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4"
20 | POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd"
21 | POMDPToolbox = "0729bffe-8e6b-52fa-a3fa-893719b744f4"
22 | POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
23 | ParticleFilters = "c8b314e2-9260-5cf8-ae76-3be7461ca6d0"
24 | Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
25 | PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
26 | PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
27 | RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
28 | Reactive = "a223df75-4e93-5b7c-acf9-bdd599c0f4de"
29 | SARSOP = "cef570c6-3a94-5604-96b7-1a5e143043f2"
30 | StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
31 | TikzPictures = "37f6aa50-8035-52d0-81c2-5a1d08754b2d"
32 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AA228 Notebooks
2 |
[![Build Status](https://travis-ci.org/sisl/aa228-notebook.svg)](https://travis-ci.org/sisl/aa228-notebook)
4 |
5 | These notebooks are used for [AA228/CS238: Decision Making under Uncertainty](https://aa228.stanford.edu) taught by [Mykel Kochenderfer](https://mykel.kochenderfer.com) at Stanford University.
6 |
--------------------------------------------------------------------------------
/alpha_plots.jl:
--------------------------------------------------------------------------------
1 | using PGFPlots
2 |
# Convert an alpha vector stored as a Dict into a two-element array
# [alpha[:not_hungry], alpha[:hungry]]; `plot` uses these as the y-values at
# x = 0 and x = 1 respectively.
function alpha2vec(alpha::Dict)
    at_zero = alpha[:not_hungry]
    at_one = alpha[:hungry]
    return [at_zero, at_one]
end
4 |
# Draw a single alpha vector as a straight line from x = 0 (not hungry) to x = 1 (hungry).
plot(alpha::Dict) = Plots.Linear([0, 1], alpha2vec(alpha))
8 |
# Draw several alpha vectors as red line segments on a shared axis whose
# x-coordinate is labelled P(hungry=true).
function plot(alphas::Vector{Dict{Symbol, Float64}})
    segments = Plots.Linear[
        Plots.Linear([0, 1], alpha2vec(α), style="red,solid,thick", mark="none")
        for α in alphas
    ]
    return Axis(segments, xlabel="P(hungry=true)", xmin=0, xmax=1)
end
17 |
--------------------------------------------------------------------------------
/baby.jl:
--------------------------------------------------------------------------------
# Transition model T(s, a, sp) = P(sp | s, a) for the crying-baby problem.
#   :feed       -> the baby ends up :not_hungry with certainty
#   otherwise   -> a :hungry baby stays :hungry; a :not_hungry baby becomes
#                  :hungry with probability 0.1
function T(s, a, sp)
    a == :feed && return sp == :not_hungry ? 1.0 : 0.0
    s == :hungry && return sp == :hungry ? 1.0 : 0.0
    return sp == :hungry ? 0.1 : 0.9
end
31 |
# Observation model O(a, sp, o) = P(o | a, sp): a :hungry baby cries with
# probability 0.8, otherwise 0.1. Any observation other than :cry gets the
# complementary probability. The action `a` does not affect the result.
function O(a, sp, o)
    cry_prob = sp == :hungry ? 0.8 : 0.1
    o == :cry && return cry_prob
    return 1.0 - cry_prob
end
45 |
# Discrete Bayes filter for the two-state baby problem.
#   b: Dict mapping :hungry / :not_hungry to their current probabilities
#   a: action taken, o: observation received
# Returns a new Dict b'(sp) ∝ O(a, sp, o) * Σ_s T(s, a, sp) b(s), normalized to sum to 1.
function update_belief(b, a, o)
    state_space = (:hungry, :not_hungry)
    bp = Dict()
    for sp in state_space
        # predicted probability of reaching sp before seeing the observation
        predicted = 0.0
        for s in state_space
            predicted += T(s, a, sp) * b[s]
        end
        bp[sp] = O(a, sp, o) * predicted
    end

    # normalize so that probabilities sum to 1
    total = bp[:hungry] + bp[:not_hungry]
    for sp in state_space
        bp[sp] = bp[sp] / total
    end
    return bp
end
63 |
--------------------------------------------------------------------------------
/bandits.jl:
--------------------------------------------------------------------------------
1 | using Printf
2 | using Random
3 | using PGFPlots
4 |
# A multi-armed bandit: pulling arm i succeeds with probability θ[i] (see `pull`).
mutable struct Bandit
    θ::Vector{Float64} # true bandit probabilities, one per arm
end
9 | pull(b::Bandit, i::Integer) = rand() < b.θ[i]
10 | numArms(b::Bandit) = length(b.θ)
11 |
# Format probabilities as a comma-separated list of percentages with two decimals,
# e.g. [0.5, 0.25] -> "50.00 percent, 25.00 percent".
# An empty input yields "" (indexing strings[1] previously threw a BoundsError).
function _get_string_list_of_percentages(bandit_odds::Vector{R}) where {R<:Real}
    return join(map(θ -> Printf.@sprintf("%.2f percent", 100θ), bandit_odds), ", ")
end
20 |
# Display interactive widgets for playing bandit `b` by hand. Each arm gets a
# button whose value is used as that arm's try count, plus a live
# "wins out of tries" readout. A Hide/Show toggle reveals the true arm
# probabilities.
function banditTrial(b)

    for i in 1 : numArms(b)
        but=button("Arm $i",value=0)
        display(but)
        wins=Observable(0)
        # When the button value is positive, pull arm i and add the Bool outcome to `wins`.
        Interact.@on &but>0 ? (wins[] = wins[]+pull(b,i)) : 0
        # NOTE(review): 100*wins[]/but[] is a Float64 (NaN before the first click);
        # formatting it with %d relies on the installed Printf accepting floats — confirm.
        display(map(s -> Printf.@sprintf("%d wins out of %d tries (%d percent)", wins[], but[], 100*wins[]/but[]), but))
        # NOTE: we used to use the latex() wrapper
    end

    # Toggle that shows the hidden arm probabilities as percentages.
    t = togglebuttons(["Hide", "Show"], value="Hide", label="True Params")
    display(t)
    display(map(v -> v == "Show" ? _get_string_list_of_percentages(b.θ) : "", t))
end
36 |
# Like `banditTrial`, but for each arm also plots the Beta(w+1, t-w+1) density.
# Here w is the arm's win count and t its try count (the button value); the
# plot refreshes on every click. A Hide/Show toggle reveals the true
# probabilities.
# (An unused vector of never-displayed buttons was removed.)
# NOTE(review): `pdf` and `Beta` are not imported in this file; presumably the
# caller has loaded Distributions — confirm.
function banditEstimation(b)
    for i in 1 : numArms(b)
        but=button("Arm $i",value=0)
        display(but)
        wins=Observable(0)
        # When the button value is positive, pull arm i and add the Bool outcome to `wins`.
        Interact.@on &but>0 ? (wins[] = wins[]+pull(b,i)) : 0
        display(map(s -> Printf.@sprintf("%d wins out of %d tries (%d percent)", wins[], but[], 100*wins[]/but[]), but))
        # Re-render the density plot whenever the button value changes.
        display(map(s -> begin
                w = wins[]
                t = but[]
                Axis([
                    Plots.Linear(θ->pdf(Beta(w+1, t-w+1), θ), (0,1), legendentry="Beta($(w+1), $(t-w+1))")
                    ],
                    xmin=0,xmax=1,ymin=0, width="15cm", height="10cm")
            end, but
        ))
    end
    t = togglebuttons(["Hide", "Show"], value="Hide", label="True Params")
    display(t)
    display(map(v -> v == "Show" ? string(b.θ) : "", t))
end
59 |
# Per-arm success and try counters for a k-armed bandit, all starting at zero.
mutable struct BanditStatistics
    numWins::Vector{Int}   # successful pulls per arm
    numTries::Vector{Int}  # total pulls per arm
    BanditStatistics(k::Int) = new(zeros(Int, k), zeros(Int, k))
end
# Number of arms tracked by the statistics.
function numArms(b::BanditStatistics)
    return length(b.numWins)
end
# Record one pull of arm i: always bump its try count, and bump its win count on success.
function update!(b::BanditStatistics, i::Int, success::Bool)
    b.numTries[i] = b.numTries[i] + 1
    success ? (b.numWins[i] += 1) : nothing
end
# Posterior-mean win probability of each arm under a uniform prior:
# (wins + 1) / (tries + 2).
function winProbabilities(b::BanditStatistics)
    successes = b.numWins .+ 1
    trials = b.numTries .+ 2
    return successes ./ trials
end
74 |
# Supertype for bandit-playing strategies. `simulate` calls `arm(policy, stats)`
# to choose an arm and `reset!(policy)` before each run.
abstract type BanditPolicy end
76 |
# Default reset hook: does nothing. Policies that keep state can add their own method.
function reset!(p::BanditPolicy)
    return nothing
end
78 |
# Run `policy` on bandit `b` for `steps` pulls, starting from fresh statistics.
# Returns an Int vector with 1 where the pull at that step won and 0 otherwise.
function simulate(b::Bandit, policy::BanditPolicy; steps = 10)
    outcomes = zeros(Int, steps)
    stats = BanditStatistics(numArms(b))
    reset!(policy)
    for t in 1:steps
        chosen = arm(policy, stats)
        won = pull(b, chosen)
        update!(stats, chosen, won)
        outcomes[t] = won
    end
    return outcomes
end
91 |
# Average the per-step win indicator of `simulate` over `iterations` independent runs.
function simulateAverage(b::Bandit, policy::BanditPolicy; steps = 10, iterations = 10)
    totals = zeros(Int, steps)
    for _ in 1:iterations
        totals .+= simulate(b, policy; steps = steps)
    end
    return totals ./ iterations
end
99 |
# One averaged learning curve (a PGFPlots line) per (name, policy) pair in `policies`.
function learningCurves(b::Bandit, policies; steps=10, iterations=10)
    return Plots.Linear[
        Plots.Linear(simulateAverage(b, policy; steps=steps, iterations=iterations),
                     legendentry=name, style="very thick", mark="none")
        for (name, policy) in policies
    ]
end
108 |
--------------------------------------------------------------------------------
/gridworld.jl:
--------------------------------------------------------------------------------
1 | using POMDPs
2 |
3 | # Problem based on https://www.cs.ubc.ca/~poole/demos/mdp/vi.html
4 |
5 | using TikzPictures
6 | using Printf
7 |
# Tabular grid-world MDP with integer states and symbolic actions.
mutable struct DMUGridWorld <: MDP{Int, Symbol}
    S::Vector{Int}                                      # states (also used directly as indices)
    A::Vector{Symbol}                                   # actions; position = action index
    T::Array{Float64,3}                                 # T[s, ai, s'] transition probabilities
    R::Matrix{Float64}                                  # R[s, ai] immediate rewards
    discount::Float64                                   # discount factor
    actionIndex::Dict{Symbol, Int}                      # action symbol => index into T/R
    nextStates::Dict{Tuple{Int, Symbol}, Vector{Int}}   # (s, a) => states with nonzero T[s, ai, :]
end
17 |
# MDP-style accessors for DMUGridWorld. States are integers and serve as their
# own indices; actions are mapped to indices through `actionIndex`.
actions(w::DMUGridWorld) = w.A
states(w::DMUGridWorld) = w.S
n_actions(w::DMUGridWorld) = length(w.A)
n_states(w::DMUGridWorld) = length(w.S)
reward(w::DMUGridWorld, s::Int, a::Symbol) = w.R[s, w.actionIndex[a]]
transition_pdf(w::DMUGridWorld, s0::Int, a::Symbol, s1::Int) = w.T[s0, w.actionIndex[a], s1]
discount(w::DMUGridWorld) = w.discount
next_states(w::DMUGridWorld, s, a) = w.nextStates[(s, a)]
state_index(w::DMUGridWorld, s) = s
action_index(w::DMUGridWorld, a) = w.actionIndex[a]
28 |
# Unpack an MDP into the tuple (S, A, T, R, gamma), where T(s0, a, s1) and
# R(s, a) are closures over `transition_pdf` and `reward` for this mdp.
function locals(mdp::MDP)
    transition = (s0, a, s1) -> transition_pdf(mdp, s0, a, s1)
    rewardfn = (s, a) -> reward(mdp, s, a)
    return (states(mdp), actions(mdp), transition, rewardfn, discount(mdp))
end
37 |
# Convert a linear state index (1..100) into (x, y) coordinates on the 10x10 grid
# (column-major: x varies fastest).
function s2xy(s)
    cell = CartesianIndices((10, 10))[s]
    return Tuple(cell)
end
39 |
# Convert (x, y) grid coordinates to a linear state index. Coordinates are first
# clamped to 1..10, so moves off the edge map back onto the border cell.
function xy2s(x, y)
    cx = clamp(x, 1, 10)
    cy = clamp(y, 1, 10)
    return LinearIndices((10, 10))[cx, cy]
end
47 |
# Build the 10x10 grid world (states 1:100, coordinates from s2xy).
# Action indices follow A: 1 = :left (y-1), 2 = :right (y+1), 3 = :up (x-1), 4 = :down (x+1).
# Rewards:
#   (3,8) -> +3 and (8,9) -> +10 for every action; no transitions are written for
#            these two states, so their T rows stay all-zero.
#   (8,4) -> -10, (5,4) -> -5.
#   Border cells -> -0.1 (or -0.2 on corners), with -0.7 for the action that
#            pushes into the wall; the four corners' wall actions are overridden
#            to -0.8 after the loop.
# Transitions (for all states except the two reward states): 0.7 in the intended
# direction and 0.1 in each other direction. xy2s clamps, so off-grid moves
# add their probability to the current cell.
function DMUGridWorld()
    A = [:left, :right, :up, :down]
    S = 1:100
    T = zeros(length(S), length(A), length(S))
    R = zeros(length(S), length(A))
    for s in S
        (x, y) = s2xy(s)
        if x == 3 && y == 8
            R[s, :] .= 3
        elseif x == 8 && y == 9
            R[s, :] .= 10
        else
            if x == 8 && y == 4
                R[s, :] .= -10
            elseif x == 5 && y == 4
                R[s, :] .= -5
            elseif x == 1
                if y == 1 || y == 10
                    R[s, :] .= -0.2
                else
                    R[s, :] .= -0.1
                end

                # top row: moving up hits the wall
                R[s, 3] = -0.7
            elseif x == 10
                if y == 1 || y == 10
                    R[s, :] .= -0.2
                else
                    R[s, :] .= -0.1
                end
                # bottom row: moving down hits the wall
                R[s, 4] = -0.7
            elseif y == 1
                # NOTE(review): x == 1 / x == 10 were handled above, so this test is always false here.
                if x == 1 || x == 10
                    R[s, :] .= -0.2
                else
                    R[s, :] .= -0.1
                end
                # left column: moving left hits the wall
                R[s, 1] = -0.7
            elseif y == 10
                # NOTE(review): as above, this corner test cannot be true in this branch.
                if x == 1 || x == 10
                    R[s, :] .= -0.2
                else
                    R[s, :] .= -0.1
                end
                # right column: moving right hits the wall
                R[s, 2] = -0.7
            end
            # 0.7 intended move, 0.1 each other direction; += because clamping can
            # send several moves to the same cell.
            for a in A
                if a == :left
                    T[s, 1, xy2s(x, y - 1)] += 0.7
                    T[s, 1, xy2s(x, y + 1)] += 0.1
                    T[s, 1, xy2s(x - 1, y)] += 0.1
                    T[s, 1, xy2s(x + 1, y)] += 0.1
                elseif a == :right
                    T[s, 2, xy2s(x, y + 1)] += 0.7
                    T[s, 2, xy2s(x, y - 1)] += 0.1
                    T[s, 2, xy2s(x - 1, y)] += 0.1
                    T[s, 2, xy2s(x + 1, y)] += 0.1
                elseif a == :up
                    T[s, 3, xy2s(x - 1, y)] += 0.7
                    T[s, 3, xy2s(x + 1, y)] += 0.1
                    T[s, 3, xy2s(x, y - 1)] += 0.1
                    T[s, 3, xy2s(x, y + 1)] += 0.1
                elseif a == :down
                    T[s, 4, xy2s(x + 1, y)] += 0.7
                    T[s, 4, xy2s(x - 1, y)] += 0.1
                    T[s, 4, xy2s(x, y - 1)] += 0.1
                    T[s, 4, xy2s(x, y + 1)] += 0.1
                end
            end
        end
    end
    # Corner cells: both wall-facing actions cost -0.8.
    R[1,1] = -0.8
    R[10,1] = -0.8
    R[91,2] = -0.8
    R[100,2] = -0.8
    R[1,3] = -0.8
    R[91,3] = -0.8
    R[10,4] = -0.8
    R[100,4] = -0.8
    discount = 0.9
    # (s, a) => successor states with nonzero transition probability
    nextStates = Dict([(S[si], A[ai])=>findall(x->x!=0, T[si, ai, :]) for si=1:length(S), ai=1:length(A)])
    DMUGridWorld(S, A, T, R, discount, Dict([A[i]=>i for i=1:length(A)]), nextStates)
end
131 |
# Map values to (r, g, b) channel vectors for plotting.
# Non-negative values become green and negative values red. The remaining
# channels fade from 255 toward 0 as |v|/10 grows, raised to `brightness` and
# saturating at |v| >= 10. NaN entries stay white (255, 255, 255).
function colorval(val, brightness::Real = 1.0)
    v = convert(Vector{Float64}, val)
    shade = 255 .- min.(255, 255 .* (abs.(v) ./ 10.0) .^ brightness)
    r = fill(255.0, size(v))
    g = fill(255.0, size(v))
    b = fill(255.0, size(v))
    for i in eachindex(v)
        if v[i] >= 0
            r[i] = shade[i]
            b[i] = shade[i]
        elseif v[i] < 0
            g[i] = shade[i]
            b[i] = shade[i]
        end
    end
    (r, g, b)
end
144 |
# Plot the value function f(s) over all grid states.
plot(g::DMUGridWorld, f::Function) = plot(g, map(f, g.S))
149 |
# Render a value vector V (indexed by state) over the 10x10 grid as a TikZ picture.
# Each cell is shaded via colorval(V) and labelled with V[s] to two decimals. The
# cell whose state equals `curState` is overdrawn in orange.
# (Removed an unused local `twid`.)
function plot(obj::DMUGridWorld, V::Vector; curState=0)
    o = IOBuffer()
    sqsize = 1.0
    (r, g, b) = colorval(V)
    for s = obj.S
        (yval, xval) = s2xy(s)
        # flip so that the first grid coordinate 1 is drawn in the top row
        yval = 10 - yval
        println(o, "\\definecolor{currentcolor}{RGB}{$(r[s]),$(g[s]),$(b[s])}")
        println(o, "\\fill[currentcolor] ($((xval-1) * sqsize),$((yval) * sqsize)) rectangle +($sqsize,$sqsize);")
        if s == curState
            println(o, "\\fill[orange] ($((xval-1) * sqsize),$((yval) * sqsize)) rectangle +($sqsize,$sqsize);")
        end
        vs = Printf.@sprintf("%0.2f", V[s])
        println(o, "\\node[above right] at ($((xval-1) * sqsize), $((yval) * sqsize)) {\$$(vs)\$};")
    end
    println(o, "\\draw[black] grid(10,10);")
    # TikzPictures setting; presumably keeps intermediate LaTeX files — confirm.
    tikzDeleteIntermediate(false)
    TikzPicture(String(take!(o)), options="scale=1.25")
end
170 |
# Plot the value function f(s) together with a policy given as a function of the state.
plot(g::DMUGridWorld, f::Function, policy::Function; curState=0) =
    plot(g, map(f, g.S), policy; curState=curState)
175 |
# Plot values V with a policy function, materialized as one action per state.
plot(obj::DMUGridWorld, V::Vector, policy::Function; curState=0) =
    plot(obj, V, map(policy, obj.S); curState=curState)
180 |
# Render values V and a per-state action vector `policy` as a TikZ picture.
# Cells are shaded by colorval(V), with the `curState` cell overdrawn in orange.
# Each cell with a :left/:right/:up/:down action gets a gray triangular arrow,
# and every cell is labelled with V[s] to two decimals.
function plot(obj::DMUGridWorld, V::Vector, policy::Vector; curState=0)
    buf = IOBuffer()
    cell = 1.0
    halfwidth = 0.05
    (r, g, b) = colorval(V)

    # Pass 1: cell shading.
    for s in obj.S
        (row, col) = s2xy(s)
        top = 10 - row
        println(buf, "\\definecolor{currentcolor}{RGB}{$(r[s]),$(g[s]),$(b[s])}")
        println(buf, "\\fill[currentcolor] ($((col-1) * cell),$((top) * cell)) rectangle +($cell,$cell);")
        if s == curState
            println(buf, "\\fill[orange] ($((col-1) * cell),$((top) * cell)) rectangle +($cell,$cell);")
        end
    end

    # Pass 2: policy arrows and value labels.
    println(buf, "\\begin{scope}[fill=gray]")
    # Right-pointing triangle; each column is a vertex (x; y) relative to the cell centre.
    arrow_right = [0 0 cell/2; halfwidth -halfwidth 0]
    for s in obj.S
        (row, col) = s2xy(s)
        top = 10 - row + 1
        centre = [col, top] * cell .- cell / 2
        C = [centre'; centre'; centre']'
        dir = policy[s]
        verts = if dir == :left
            [-1 0; 0 -1] * arrow_right + C
        elseif dir == :right
            arrow_right + C
        elseif dir == :up
            [0 -1; 1 0] * arrow_right + C
        elseif dir == :down
            [0 1; -1 0] * arrow_right + C
        else
            nothing
        end
        if verts !== nothing
            println(buf, "\\fill ($(verts[1]), $(verts[2])) -- ($(verts[3]), $(verts[4])) -- ($(verts[5]), $(verts[6])) -- cycle;")
        end

        vs = Printf.@sprintf("%0.2f", V[s])
        println(buf, "\\node[above right] at ($((col-1) * cell), $((top-1) * cell)) {\$$(vs)\$};")
    end
    println(buf, "\\end{scope}")
    println(buf, "\\draw[black] grid(10,10);")
    TikzPicture(String(take!(buf)), options="scale=1.25")
end
226 |
# Sample one step of the grid world: take action `a` in state `s`.
# Returns (sp, r) where sp is drawn from transition_pdf(g, s, a, ·) over
# next_states(g, s, a), and r = reward(g, s, a).
# If (s, a) has no successors, s and a are printed for debugging before indexing fails.
function simulate(g::DMUGridWorld, s::Int, a::Symbol)
    candidates = next_states(g, s, a)
    if isempty(candidates)
        println("s = ", s)
        println("a = ", a)
    end
    probs = Float64[transition_pdf(g, s, a, sp) for sp in candidates]

    # make sure these sum to 1. They should, but let's be safe.
    probs = probs / sum(probs)

    # Inverse-CDF sampling. Default to the last candidate so that rounding in the
    # cumulative sum (ending just below rand_val) cannot run past the end of `probs`.
    rand_val = rand()
    sampled_idx = length(probs)
    prob_sum = 0.0
    for i in eachindex(probs)
        prob_sum += probs[i]
        if rand_val < prob_sum
            sampled_idx = i
            break
        end
    end
    sp = candidates[sampled_idx]

    return sp, reward(g, s, a)
end
258 |
--------------------------------------------------------------------------------
/helpers.jl:
--------------------------------------------------------------------------------
1 | using Printf
2 | using LinearAlgebra
# Split a macro iteration spec into (variable, collection).
# Accepts `x in X`, `x ∈ X`, `x = X`, and the legacy `Expr(:in, x, X)` form.
function _iteration_spec(range)
    if range isa Expr
        if range.head == :call && length(range.args) == 3 && range.args[1] in (:in, :∈)
            return range.args[2], range.args[3]
        elseif range.head in (:in, :(=)) && length(range.args) == 2
            return range.args[1], range.args[2]
        end
    end
    throw(ArgumentError("expected an iteration spec such as `x in X`, got `$(range)`"))
end

# Build `Float64[ex for var in coll]` using the Julia 1.x generator AST.
function _float_comprehension(range, ex)
    var, coll = _iteration_spec(range)
    return :(Float64[$ex for $var in $coll])
end

# Reduction macros over a Float64 comprehension, e.g. `@max(a in A, Q(s, a))`.
# Output is escaped so `ex` and the collection see the caller's local variables.
macro max(range, ex)
    esc(:(maximum($(_float_comprehension(range, ex)))))
end
macro sum(range, ex)
    esc(:(sum($(_float_comprehension(range, ex)))))
end
macro min(range, ex)
    esc(:(minimum($(_float_comprehension(range, ex)))))
end
macro prod(range, ex)
    esc(:(prod($(_float_comprehension(range, ex)))))
end

# Return the element of the collection that maximizes/minimizes `ex`.
# The collection is evaluated once. `argmax`/`argmin` replace the `indmax`/`indmin`
# functions removed in Julia 1.0.
macro argmax(range, ex)
    var, coll = _iteration_spec(range)
    c = gensym("collection")
    esc(quote
        let $c = $coll
            $c[argmax(Float64[$ex for $var in $c])]
        end
    end)
end
macro argmin(range, ex)
    var, coll = _iteration_spec(range)
    c = gensym("collection")
    esc(quote
        let $c = $coll
            $c[argmin(Float64[$ex for $var in $c])]
        end
    end)
end

# The Float64 comprehension itself.
macro array(range, ex)
    esc(_float_comprehension(range, ex))
end
28 |
# Least-squares fit of a degree-`n` polynomial to samples (x, y).
# Returns coefficients [c0, c1, ..., cn] with y ≈ Σ_p c_p * x^p.
function polyfit(x, y, n)
    # Vandermonde matrix: one row per sample, one column per power 0:n.
    A = [float(xi)^p for xi in x, p = 0:n]
    # Solve through the QR factorization directly; Q' is applied implicitly instead
    # of slicing the implicit Q factor element by element.
    return LinearAlgebra.qr(A) \ y
end
34 |
# Render polynomial coefficients λ (constant term first) as an inline LaTeX string,
# e.g. [1.0, -2.0, 3.5] -> "$1.00-2.00 x+3.50x^{2}$". Non-negative coefficients
# after the first are prefixed with "+".
function prettyPolynomial(λ)
    terms = String[]
    for i in 1:length(λ)
        c = λ[i]
        sign = (i == 1 || c < 0) ? "" : "+"
        term = if i == 1
            Printf.@sprintf("%0.2f", c)
        elseif i == 2
            Printf.@sprintf("%0.2f x", c)
        else
            Printf.@sprintf("%0.2fx^{%d}", c, i-1)
        end
        push!(terms, sign * term)
    end
    return "\$" * join(terms) * "\$"
end
58 |
59 | using TikzPictures
60 |
# Draw a chain of `len` circular nodes labelled 1..len, joined by edges, as a TikZ picture.
# `fill` maps node numbers to fill colours; nodes not listed are white.
function plot_chain(len; fill::Dict{Int,String}=Dict{Int64,String}())
    nodes = String[]
    for i in 1:len
        colour = get(fill, i, "white")
        push!(nodes, "($(i)cm, 0cm) node[draw=black,circle,fill=$colour]{$i}")
    end
    terminator = isempty(nodes) ? "" : ";"
    return TikzPicture(string("\\draw ", join(nodes, " -- "), terminator))
end
74 |
--------------------------------------------------------------------------------
/install.jl:
--------------------------------------------------------------------------------
1 | import Pkg
2 |
# Install every dependency listed in this directory's Project.toml into the
# active environment. First install POMDPs and call its add_registry(), which
# the message below describes as adding the JuliaPOMDP registry.
@info("Adding JuliaPOMDP Package Registry to your global list of registries.")
Pkg.add("POMDPs")
using POMDPs
POMDPs.add_registry()

# Set before Pkg.add so package builds see it (presumably for PyCall's Python selection — confirm).
ENV["PYTHON"]=""

# Read the [deps] table of the Project.toml that sits next to this script.
projdir = dirname(@__FILE__())
toml = open(joinpath(projdir, "Project.toml")) do f
    Pkg.TOML.parse(f)
end
pkgs = collect(keys(toml["deps"]))
# One package name per line for the log message below.
pkgstring = string([pkg*"\n " for pkg in pkgs]...)
@info("""
Installing the following packages to the current environment:

$pkgstring
""")

Pkg.add(pkgs)

@info("Dependency install complete! (check for errors)")
25 |
--------------------------------------------------------------------------------
/rl.jl:
--------------------------------------------------------------------------------
1 | using Distributions
2 | using StatsBase
3 | using Random
4 | include("gridworld.jl")
5 | include("helpers.jl")
6 |
# Tabular MDP over arbitrary state/action values. States and actions are mapped
# to integer positions in T and R through the index dictionaries.
mutable struct MappedDiscreteMDP{SType,AType} <: MDP{SType,AType}
    S::Vector{SType}        # states
    A::Vector{AType}        # actions
    T::Array{Float64,3}     # T[si, ai, s'i] transition probabilities
    R::Matrix{Float64}      # R[si, ai] rewards
    discount::Float64       # discount factor
    stateIndex::Dict        # state => index into S/T/R
    actionIndex::Dict       # action => index into A/T/R
    nextStates              # (s, a) => states with nonzero T[si, ai, :]
end
17 |
# Build a MappedDiscreteMDP from explicit T[si, ai, s'i] and R[si, ai] arrays,
# deriving the index lookups and the successor table from them.
function MappedDiscreteMDP(S::Vector, A::Vector, T, R; discount=0.9)
    stateIndex = Dict([s => i for (i, s) in enumerate(S)])
    actionIndex = Dict([a => i for (i, a) in enumerate(A)])
    # (s, a) => successor states with nonzero transition probability
    nextStates = Dict([(S[si], A[ai]) => S[findall(p -> p != 0, T[si, ai, :])]
                       for si in eachindex(S), ai in eachindex(A)])
    return MappedDiscreteMDP(S, A, T, R, discount, stateIndex, actionIndex, nextStates)
end
24 |
# Build a MappedDiscreteMDP with all-zero transition and reward tables.
function MappedDiscreteMDP(S::Vector, A::Vector; discount=0.9)
    nS, nA = length(S), length(A)
    return MappedDiscreteMDP(S, A, zeros(nS, nA, nS), zeros(nS, nA); discount=discount)
end
30 |
# MDP-interface accessors for MappedDiscreteMDP. States and actions are
# translated to table indices through the stateIndex / actionIndex dictionaries.
actions(mdp::MappedDiscreteMDP) = mdp.A
states(mdp::MappedDiscreteMDP) = mdp.S
n_states(mdp::MappedDiscreteMDP) = length(mdp.S)
n_actions(mdp::MappedDiscreteMDP) = length(mdp.A)
reward(mdp::MappedDiscreteMDP, s, a) = mdp.R[mdp.stateIndex[s], mdp.actionIndex[a]]
transition_pdf(mdp::MappedDiscreteMDP, s0, a, s1) = mdp.T[mdp.stateIndex[s0], mdp.actionIndex[a], mdp.stateIndex[s1]]
discount(mdp::MappedDiscreteMDP) = mdp.discount
state_index(mdp::MappedDiscreteMDP, s) = mdp.stateIndex[s]
# Index by the action argument `a`.
action_index(mdp::MappedDiscreteMDP, a) = mdp.actionIndex[a]
next_states(mdp::MappedDiscreteMDP, s, a) = mdp.nextStates[(s, a)]
41 |
42 |
# Draw a state uniformly at random from the MDP's state list.
function rand_state(mdp::MDP)
    S = states(mdp)
    idx = rand(DiscreteUniform(1, n_states(mdp)))
    return S[idx]
end
44 |
# Run `iterations` sweeps of value iteration from zero-initialized V and Q.
# Returns (V, Q).
function value_iteration(mdp::MDP, iterations::Integer)
    nS, nA = n_states(mdp), n_actions(mdp)
    V, Q = zeros(nS), zeros(nS, nA)
    value_iteration!(V, Q, mdp, iterations)
    return (V, Q)
end
51 |
# In-place synchronous value iteration. Each sweep computes
#   Q[s, a] = R(s, a) + γ Σ_{s'} T(s, a, s') V_prev[s']
# from the previous sweep's values, then sets V[s] = max_a Q[s, a].
function value_iteration!(V::Vector, Q::Matrix, mdp::MDP, iterations::Integer)
    (S, A, T, R, γ) = locals(mdp)
    V_prev = copy(V)
    for _ in 1:iterations
        for si in 1:n_states(mdp)
            s = S[si]
            for ai in 1:n_actions(mdp)
                a = A[ai]
                r = R(s, a)
                # leading 0.0 keeps the sum defined when (s, a) has no successors
                expected = sum([0.0; [T(s, a, sp)*V_prev[state_index(mdp, sp)] for sp in next_states(mdp, s, a)]])
                Q[si, ai] = r + γ * expected
            end
            V[si] = maximum(Q[si, :])
        end
        V_prev .= V
    end
end
67 |
# Re-estimate the (s, a) row of the model from counts:
#   T[si, ai, :] = N[si, ai, :] / Nsa[si, ai] and R[si, ai] = ρ[si, ai] / Nsa[si, ai].
# Then refresh the (s, a) successor list.
function update_parameters!(mdp::MappedDiscreteMDP, N, Nsa, ρ, s, a)
    si, ai = mdp.stateIndex[s], mdp.actionIndex[a]
    visits = Nsa[si, ai]
    probs = N[si, ai, :] ./ visits
    mdp.T[si, ai, :] = probs
    mdp.R[si, ai] = ρ[si, ai] / visits
    mdp.nextStates[(s, a)] = mdp.S[findall(p -> p != 0, probs)]
end
76 |
# (s0, a) is terminal when it has no listed successors or their total probability is zero.
function isterminal(mdp::MDP, s0, a)
    successors = next_states(mdp, s0, a)
    isempty(successors) && return true
    total = sum(s1 -> transition_pdf(mdp, s0, a, s1), successors)
    return total == 0
end
81 |
# Sample a successor state of (s0, a), weighting every state by transition_pdf.
function generate_s(mdp::MDP, s0, a, rng::AbstractRNG=Random.GLOBAL_RNG)
    S = states(mdp)
    w = Weights([transition_pdf(mdp, s0, a, s1) for s1 in S])
    return S[sample(rng, w)]
end
87 |
# Maximum-likelihood model-based RL policy. It counts transitions and rewards,
# re-estimates a MappedDiscreteMDP from them, and re-solves it with value iteration.
mutable struct MLRL <: Policy
    N::Array{Float64,3}     # transition counts N[s, a, s']
    Nsa::Matrix{Float64}    # state-action visit counts
    ρ::Matrix{Float64}      # sum of rewards per (s, a)
    lastState               # state of the previous step (nothing before the first)
    lastAction              # action of the previous step
    lastReward              # reward of the previous step
    newEpisode              # true until the first update of an episode
    mdp::MappedDiscreteMDP  # current model estimate
    Q::Matrix{Float64}      # action values of the model
    V::Vector{Float64}      # state values of the model
    iterations::Int         # value-iteration sweeps per update
    epsilon::Float64        # probability of exploration
    function MLRL(S, A; discount=0.9, iterations=20, epsilon=0.2)
        nS, nA = length(S), length(A)
        new(zeros(nS, nA, nS),   # N
            zeros(nS, nA),       # Nsa
            zeros(nS, nA),       # ρ
            nothing,             # lastState
            nothing,             # lastAction
            0.0,                 # lastReward
            true,                # newEpisode
            MappedDiscreteMDP(S, A, discount=discount),
            zeros(nS, nA),       # Q
            zeros(nS),           # V
            iterations,
            epsilon)
    end
end
115 |
# End the current episode. The final (lastState, lastAction) visit is recorded
# with its reward but no successor, so the re-estimated T[s, a, :] row for that
# pair can sum to less than 1. The model and values are then re-solved.
# The reward is accumulated into ρ (sum of rewards), consistent with `update`;
# assigning it would discard the previously accumulated rewards for that pair.
function reset(policy::MLRL)
    if !policy.newEpisode
        s0i = policy.mdp.stateIndex[policy.lastState]
        ai = policy.mdp.actionIndex[policy.lastAction]
        policy.Nsa[s0i, ai] += 1
        policy.ρ[s0i, ai] += policy.lastReward
        # update Q and V
        update_parameters!(policy.mdp, policy.N, policy.Nsa, policy.ρ, policy.lastState, policy.lastAction)
        value_iteration!(policy.V, policy.Q, policy.mdp, policy.iterations)
        policy.newEpisode = true
    end
end
128 |
# Observe that (s, a) was taken and yielded reward r. Unless this is the first
# step of an episode, first credit the previous step's transition
# (lastState, lastAction) -> s and its reward, then re-estimate and re-solve the model.
function update(policy::MLRL, s, a, r)
    if !policy.newEpisode
        sidx = policy.mdp.stateIndex
        aidx = policy.mdp.actionIndex
        s0i, ai, s1i = sidx[policy.lastState], aidx[policy.lastAction], sidx[s]
        policy.N[s0i, ai, s1i] += 1
        policy.Nsa[s0i, ai] += 1
        policy.ρ[s0i, ai] += policy.lastReward
        # refresh the model estimate, then Q and V
        update_parameters!(policy.mdp, policy.N, policy.Nsa, policy.ρ, policy.lastState, policy.lastAction)
        value_iteration!(policy.V, policy.Q, policy.mdp, policy.iterations)
    else
        policy.newEpisode = false
    end
    policy.lastState, policy.lastAction, policy.lastReward = s, a, r
    return nothing
end
148 |
# Greedy action for state s under the current Q, breaking ties uniformly at random.
function action(policy::MLRL, s)
    row = policy.Q[policy.mdp.stateIndex[s], :]
    best = maximum(row)
    candidates = findall(q -> q == best, row)
    return policy.mdp.A[rand(candidates)]
end
156 |
# Epsilon-greedy action for the most recently observed state. With probability
# `epsilon`, pick an action uniformly at random; otherwise act greedily.
# Uses `n_actions`, the accessor defined for MappedDiscreteMDP; `numActions`
# is not defined for it.
function action(policy::MLRL)
    if rand() < policy.epsilon
        policy.mdp.A[rand(DiscreteUniform(1, n_actions(policy.mdp)))]
    else
        action(policy, policy.lastState)
    end
end
164 |
# Roll out `policy` in `mdp` for `steps` steps, updating the policy after every step.
# If `script` is non-empty, its entries supply the start state and the next
# states for as long as it lasts. Afterwards, a terminal (s, a) restarts from a
# random state and resets the policy; otherwise the next state is sampled.
# Returns (visited states, rewards, snapshots of policy.V after each step).
function simulate(mdp::MDP, steps::Integer, policy::Policy; script=[])
    S = Any[]
    V = Any[]
    R = Float64[]
    s = isempty(script) ? rand_state(mdp) : script[1]
    for i in 1:steps
        push!(S, s)
        a = action(policy, s)
        r = reward(mdp, s, a)
        push!(R, r)
        update(policy, s, a, r)
        push!(V, copy(policy.V))
        if i < length(script)
            s = script[i + 1]
        elseif isterminal(mdp, s, a)
            s = rand_state(mdp)
            reset(policy)
        else
            s = generate_s(mdp, s, a)
        end
    end
    return (S, R, V)
end
194 |
--------------------------------------------------------------------------------
/runtests.jl:
--------------------------------------------------------------------------------
1 | using NBInclude
2 | using Test
3 |
# Execute every notebook in the current directory in a fresh `julia` process that
# uses this directory's project; each notebook passes if its process exits successfully.
@testset "notebooks" begin
    projdir = dirname(@__FILE__())
    for nb in readdir(".")
        # (An earlier variant skipped the 08-Markov*, 09-*, 11-*, 16-* and POM* notebooks.)
        endswith(nb, ".ipynb") || continue
        @info("Running "*nb)
        code = "using InteractiveUtils; using NBInclude; ENV[\"PYTHON\"]=\"\"; @nbinclude(\"" * nb * "\")"
        cmd = `julia --project=$projdir -e $code`
        proc = run(pipeline(cmd, stderr=stderr), wait=false)
        @test success(proc)
    end
end
17 |
--------------------------------------------------------------------------------