├── 5_friday └── .foo ├── 3_wednesday └── .foo ├── 4_thursday ├── .foo └── dp_slides.pdf ├── source_files ├── monday │ ├── mathfoo.py │ ├── intro_slides │ │ ├── matinv.jl │ │ ├── gdi.png │ │ ├── gdi2.png │ │ ├── loss.pdf │ │ ├── loss2.jpg │ │ ├── loss3.png │ │ ├── loss4.jpg │ │ ├── numpy.pdf │ │ ├── pvr.png │ │ ├── solow │ │ │ ├── out │ │ │ ├── f_out │ │ │ ├── solow.py │ │ │ ├── solow.c │ │ │ └── solow.f90 │ │ ├── admired.pdf │ │ ├── matlab.pdf │ │ ├── matlab.png │ │ ├── nvidia.png │ │ ├── ppf_plus.pdf │ │ ├── python_vs_rest.png │ │ ├── admired.py │ │ └── loss.py │ ├── ai_revolution │ │ ├── dgx.png │ │ ├── gdi.png │ │ ├── gdi2.png │ │ ├── jax.png │ │ ├── loss.pdf │ │ ├── loss2.jpg │ │ ├── loss3.png │ │ ├── loss4.jpg │ │ ├── python.png │ │ ├── solow │ │ │ ├── out │ │ │ └── solow.c │ │ ├── ppf_plus.pdf │ │ ├── ai_revolution.pdf │ │ ├── loss.py │ │ └── ai_revolution.tex │ ├── us_cities.txt │ └── fun_with_jax.md ├── friday │ └── bianchi.pdf ├── thursday │ ├── dp_slides │ │ ├── main.pdf │ │ ├── figures │ │ │ └── howard_newton_1.pdf │ │ └── tikz │ │ │ ├── triangle2.tex │ │ │ └── flint.tex │ ├── opt_savings_2.md │ └── egm.md ├── tuesday │ ├── temp │ │ ├── foo.f │ │ ├── solow.f90 │ │ ├── test.f90 │ │ └── timed_solow.f90 │ ├── regression.md │ ├── exercises │ │ ├── markov_homework.md │ │ └── simulation_exercises.md │ ├── numpy.md │ ├── inventory_dynamics.md │ ├── lorenz_gini.md │ └── numba.md ├── build_notebooks.py └── wednesday │ ├── ez_preferences.md │ ├── job_search.md │ ├── autodiff.md │ └── inventory_dynamics_jax.md ├── qe-logo-large.png ├── 1_monday ├── solow │ ├── out │ ├── solow.py │ ├── solow.c │ └── solow.f90 ├── ai_revolution.pdf ├── intro_slides.pdf └── sci_comp_intro.pdf ├── README.md ├── .gitignore ├── LICENSE └── 2_tuesday └── regression.ipynb /5_friday/.foo: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /3_wednesday/.foo: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /4_thursday/.foo: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /source_files/monday/mathfoo.py: -------------------------------------------------------------------------------- 1 | pi = 'foobar' 2 | -------------------------------------------------------------------------------- /qe-logo-large.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/qe-logo-large.png -------------------------------------------------------------------------------- /1_monday/solow/out: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/1_monday/solow/out -------------------------------------------------------------------------------- /4_thursday/dp_slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/4_thursday/dp_slides.pdf -------------------------------------------------------------------------------- /1_monday/ai_revolution.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/1_monday/ai_revolution.pdf 
-------------------------------------------------------------------------------- /1_monday/intro_slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/1_monday/intro_slides.pdf -------------------------------------------------------------------------------- /1_monday/sci_comp_intro.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/1_monday/sci_comp_intro.pdf -------------------------------------------------------------------------------- /source_files/friday/bianchi.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/friday/bianchi.pdf -------------------------------------------------------------------------------- /source_files/monday/intro_slides/matinv.jl: -------------------------------------------------------------------------------- 1 | A = [2.0 -1.0 2 | 5.0 -0.5] 3 | 4 | b = [0.5 1.0]' 5 | 6 | x = inv(A) * b 7 | -------------------------------------------------------------------------------- /source_files/monday/ai_revolution/dgx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/ai_revolution/dgx.png -------------------------------------------------------------------------------- /source_files/monday/ai_revolution/gdi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/ai_revolution/gdi.png -------------------------------------------------------------------------------- /source_files/monday/ai_revolution/gdi2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/ai_revolution/gdi2.png -------------------------------------------------------------------------------- /source_files/monday/ai_revolution/jax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/ai_revolution/jax.png -------------------------------------------------------------------------------- /source_files/monday/ai_revolution/loss.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/ai_revolution/loss.pdf -------------------------------------------------------------------------------- /source_files/monday/intro_slides/gdi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/gdi.png -------------------------------------------------------------------------------- /source_files/monday/intro_slides/gdi2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/gdi2.png -------------------------------------------------------------------------------- /source_files/monday/intro_slides/loss.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/loss.pdf -------------------------------------------------------------------------------- /source_files/monday/intro_slides/loss2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/loss2.jpg -------------------------------------------------------------------------------- /source_files/monday/intro_slides/loss3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/loss3.png -------------------------------------------------------------------------------- /source_files/monday/intro_slides/loss4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/loss4.jpg -------------------------------------------------------------------------------- /source_files/monday/intro_slides/numpy.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/numpy.pdf -------------------------------------------------------------------------------- /source_files/monday/intro_slides/pvr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/pvr.png -------------------------------------------------------------------------------- /source_files/monday/intro_slides/solow/out: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/solow/out -------------------------------------------------------------------------------- /source_files/thursday/dp_slides/main.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/thursday/dp_slides/main.pdf -------------------------------------------------------------------------------- /source_files/monday/ai_revolution/loss2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/ai_revolution/loss2.jpg -------------------------------------------------------------------------------- /source_files/monday/ai_revolution/loss3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/ai_revolution/loss3.png -------------------------------------------------------------------------------- /source_files/monday/ai_revolution/loss4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/ai_revolution/loss4.jpg -------------------------------------------------------------------------------- /source_files/monday/ai_revolution/python.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/ai_revolution/python.png 
-------------------------------------------------------------------------------- /source_files/monday/ai_revolution/solow/out: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/ai_revolution/solow/out -------------------------------------------------------------------------------- /source_files/monday/intro_slides/admired.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/admired.pdf -------------------------------------------------------------------------------- /source_files/monday/intro_slides/matlab.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/matlab.pdf -------------------------------------------------------------------------------- /source_files/monday/intro_slides/matlab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/matlab.png -------------------------------------------------------------------------------- /source_files/monday/intro_slides/nvidia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/nvidia.png -------------------------------------------------------------------------------- /source_files/monday/intro_slides/solow/f_out: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/solow/f_out -------------------------------------------------------------------------------- /source_files/monday/ai_revolution/ppf_plus.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/ai_revolution/ppf_plus.pdf -------------------------------------------------------------------------------- /source_files/monday/intro_slides/ppf_plus.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/ppf_plus.pdf -------------------------------------------------------------------------------- /source_files/monday/ai_revolution/ai_revolution.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/ai_revolution/ai_revolution.pdf -------------------------------------------------------------------------------- /source_files/monday/intro_slides/python_vs_rest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/monday/intro_slides/python_vs_rest.png -------------------------------------------------------------------------------- /source_files/thursday/dp_slides/figures/howard_newton_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantEcon/cbc_2024/main/source_files/thursday/dp_slides/figures/howard_newton_1.pdf 
-------------------------------------------------------------------------------- /1_monday/solow/solow.py: -------------------------------------------------------------------------------- 1 | 2 | α = 0.4 3 | s = 0.3 4 | δ = 0.1 5 | n = 1_000 6 | k = 0.2 7 | 8 | for i in range(n): 9 | k = s * k**α + (1 - δ) * k 10 | 11 | print(k) 12 | -------------------------------------------------------------------------------- /source_files/monday/intro_slides/solow/solow.py: -------------------------------------------------------------------------------- 1 | 2 | α = 0.4 3 | s = 0.3 4 | δ = 0.1 5 | n = 1_000 6 | k = 0.2 7 | 8 | for i in range(n): 9 | k = s * k**α + (1 - δ) * k 10 | 11 | print(k) 12 | -------------------------------------------------------------------------------- /source_files/monday/us_cities.txt: -------------------------------------------------------------------------------- 1 | new york: 8244910 2 | los angeles: 3819702 3 | chicago: 2707120 4 | houston: 2145146 5 | philadelphia: 1536471 6 | phoenix: 1469471 7 | san antonio: 1359758 8 | san diego: 1326179 9 | dallas: 1223229 10 | -------------------------------------------------------------------------------- /source_files/monday/intro_slides/admired.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | fig, ax = plt.subplots() 3 | langs = ['Rust', 'Elixir', 'Python','Julia', 'C', 'R', 'Fortran', 'Matlab'] 4 | vals = [88.66, 73, 65.5, 62.7, 43.2, 39.0, 24.4, 18.3] 5 | ax.bar(langs, vals) 6 | ax.set_title("Language I want to work with next year") 7 | plt.show() 8 | -------------------------------------------------------------------------------- /1_monday/solow/solow.c: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <math.h> 3 | 4 | int main() { 5 | double k = 0.2; 6 | double alpha = 0.4; 7 | double s = 0.3; 8 | double delta = 0.1; 9 | int i; 10 | int n = 1000; 11 | for (i = 0; i < n; i++) { 12 | k = s * pow(k, alpha) + (1 - delta) * k; 13 | } 14 | printf("k = %f\n", k); 15 | } 16 | 17 | 18 | -------------------------------------------------------------------------------- /source_files/monday/ai_revolution/solow/solow.c: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <math.h> 3 | 4 | int main() { 5 | double k = 0.2; 6 | double alpha = 0.4; 7 | double s = 0.3; 8 | double delta = 0.1; 9 | int i; 10 | int n = 1000; 11 | for (i = 0; i < n; i++) { 12 | k = s * pow(k, alpha) + (1 - delta) * k; 13 | } 14 | printf("k = %f\n", k); 15 | } 16 | 17 | 18 | -------------------------------------------------------------------------------- /source_files/monday/intro_slides/solow/solow.c: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <math.h> 3 | 4 | int main() { 5 | double k = 0.2; 6 | double alpha = 0.4; 7 | double s = 0.3; 8 | double delta = 0.1; 9 | int i; 10 | int n = 1000; 11 | for (i = 0; i < n; i++) { 12 | k = s * pow(k, alpha) + (1 - delta) * k; 13 | } 14 | printf("k = %f\n", k); 15 | } 16 | 17 | 18 | -------------------------------------------------------------------------------- /1_monday/solow/solow.f90: -------------------------------------------------------------------------------- 1 | program main 2 | implicit none 3 | integer, parameter :: dp=kind(0.d0) 4 | integer :: n=1000 5 | real(dp) :: s=0.3_dp 6 | real(dp) :: a=1.0_dp 7 | real(dp) :: delta=0.1_dp 8 | real(dp) :: alpha=0.4_dp 9 | real(dp) :: k=0.2_dp 10 | 
integer :: i 11 | do i = 1, n - 1 12 | k = a * s * k**alpha + (1 - delta) * k 13 | end do 14 | print *,'k = ', k 15 | end program main 16 | -------------------------------------------------------------------------------- /source_files/tuesday/temp/foo.f: -------------------------------------------------------------------------------- 1 | subroutine solow_fortran(k0, s, a, delta, alpha, n, kt) 2 | implicit none 3 | integer, parameter :: dp=kind(0.d0) 4 | integer, intent(in) :: n 5 | real(dp), intent(in) :: k0, s, a, delta, alpha 6 | real(dp), intent(out) :: kt 7 | real(dp) :: k 8 | integer :: i 9 | k = k0 10 | do i = 1, n - 1 11 | k = a * s * k**alpha + (1 - delta) * k 12 | end do 13 | end subroutine solow_fortran 14 | 15 | -------------------------------------------------------------------------------- /source_files/monday/intro_slides/solow/solow.f90: -------------------------------------------------------------------------------- 1 | program main 2 | implicit none 3 | integer, parameter :: dp=kind(0.d0) 4 | integer :: n=1000 5 | real(dp) :: s=0.3_dp 6 | real(dp) :: a=1.0_dp 7 | real(dp) :: delta=0.1_dp 8 | real(dp) :: alpha=0.4_dp 9 | real(dp) :: k=0.2_dp 10 | integer :: i 11 | do i = 1, n - 1 12 | k = a * s * k**alpha + (1 - delta) * k 13 | end do 14 | print *,'k = ', k 15 | end program main 16 | -------------------------------------------------------------------------------- /source_files/tuesday/temp/solow.f90: -------------------------------------------------------------------------------- 1 | program main 2 | implicit none 3 | integer, parameter :: dp=kind(0.d0) 4 | integer :: n=1000 5 | real(dp) :: s=0.3_dp 6 | real(dp) :: a=1.0_dp 7 | real(dp) :: delta=0.1_dp 8 | real(dp) :: alpha=0.4_dp 9 | real(dp) :: k=0.2_dp 10 | integer :: i 11 | do i = 1, n - 1 12 | k = a * s * k**alpha + (1 - delta) * k 13 | end do 14 | print *,'Capital stock = ', k 15 | end program main 16 | -------------------------------------------------------------------------------- /source_files/build_notebooks.py: -------------------------------------------------------------------------------- 1 | #! 
python 2 | import os 3 | import glob 4 | dirs = 'monday', 'tuesday', 'wednesday', 'thursday', 'friday' 5 | 6 | print('Converting Myst files to notebooks.') 7 | for dir in dirs: 8 | files = glob.glob(f'{dir}/*.md') 9 | for file in files: 10 | name, ext = file.split('.') 11 | print(f'Executing command jupytext --update --output {name}.ipynb {name}.md') 12 | os.system(f'jupytext --update --output {name}.ipynb {name}.md') 13 | print('\nDone.') 14 | 15 | -------------------------------------------------------------------------------- /source_files/monday/ai_revolution/loss.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from mpl_toolkits.mplot3d.axes3d import Axes3D 3 | from matplotlib import cm 4 | import numpy as np 5 | 6 | def f(x, y): 7 | return -np.cos(x**2 + y**2) / (1 + x**2 + y**2) 8 | 9 | xgrid = np.linspace(-3, 3, 50) 10 | ygrid = xgrid 11 | x, y = np.meshgrid(xgrid, ygrid) 12 | 13 | fig = plt.figure(figsize=(10, 6)) 14 | ax = fig.add_subplot(111, projection='3d') 15 | ax.plot_surface(x, 16 | y, 17 | f(x, y), 18 | rstride=2, cstride=2, 19 | cmap=cm.jet, 20 | alpha=0.7, 21 | linewidth=0.25) 22 | ax.set_zlim(-1.0, 0.5) 23 | plt.show() 24 | -------------------------------------------------------------------------------- /source_files/monday/intro_slides/loss.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from mpl_toolkits.mplot3d.axes3d import Axes3D 3 | from matplotlib import cm 4 | import numpy as np 5 | 6 | def f(x, y): 7 | return -np.cos(x**2 + y**2) / (1 + x**2 + y**2) 8 | 9 | xgrid = np.linspace(-3, 3, 50) 10 | ygrid = xgrid 11 | x, y = np.meshgrid(xgrid, ygrid) 12 | 13 | fig = plt.figure(figsize=(10, 6)) 14 | ax = fig.add_subplot(111, projection='3d') 15 | ax.plot_surface(x, 16 | y, 17 | f(x, y), 18 | rstride=2, cstride=2, 19 | cmap=cm.jet, 20 | alpha=0.7, 21 | linewidth=0.25) 22 | ax.set_zlim(-1.0, 0.5) 23 | plt.show() 24 | -------------------------------------------------------------------------------- /source_files/thursday/dp_slides/tikz/triangle2.tex: -------------------------------------------------------------------------------- 1 | 2 | 3 | %\tikzset{font={\fontsize{13pt}{12}\selectfont}} 4 | 5 | 6 | \begin{tikzpicture}[thick, 7 | scale=0.8, 8 | shape=circle, 9 | shorten <= 5pt, 10 | shorten >= 5pt, 11 | minimum width=60pt] 12 | 13 | \node[draw] (v) at (0, 0) {$\RR^\Xsf$ (value functions)}; 14 | \node[draw] (g) at (6, -4) {$\RR^\Gsf$ (EV functions)}; 15 | \node[draw] (h) at (-5, -4) {$\RR^\Gsf$ ($Q$-factors)}; 16 | 17 | \draw[->] (v) to node[shift={(5pt, 10pt)}] {$E$} (g); 18 | \draw[->] (g) to node[yshift=-10pt] {$D$} (h); 19 | \draw[->] (h) to node[shift={(-5pt, 10pt)}] {$M$} (v); 20 | 21 | \end{tikzpicture} 22 | 23 | -------------------------------------------------------------------------------- /source_files/thursday/dp_slides/tikz/flint.tex: -------------------------------------------------------------------------------- 1 | 2 | \begin{tikzpicture}[auto, thick, node distance=30pt] 3 | 4 | \node (s) {initial $\sigma$}; 5 | \node (s1) [below of=s] {$v_{\sigma}$-greedy $\sigma'$}; 6 | \node (el1) [below of=s1] {$\vdots$}; 7 | \node (s2) [below of=el1] {$\sigma^*$}; 8 | 9 | \node (v) [right=70pt of s] {$v_{\sigma} = (I - \beta P_\sigma)^{-1} r_\sigma$}; 10 | \node (v1) [below of=v] {$v_{\sigma'} = (I - \beta P_{\sigma'})^{-1} r_{\sigma'}$}; 11 | \node (el2) [below of=v1] {$\vdots$}; 12 | \node (v2) [below of=el2] 
{$v_{\sigma^*} = (I - \beta P_{\sigma^*})^{-1} r_{\sigma^*}$}; 13 | 14 | \node[draw, ellipse, minimum width=50pt, fit=(s) (s1) (s2)](f1) {}; 15 | \node[draw, ellipse, minimum width=50pt, fit=(v) (v1) (v2)](f2) {}; 16 | 17 | \draw[->, bend left=0] (s) to (v); 18 | \draw[->] (v) to (s1); 19 | \draw[->] (s1) to (v1); 20 | \draw[->] (el2) to (s2); 21 | 22 | \end{tikzpicture} 23 | -------------------------------------------------------------------------------- /source_files/tuesday/temp/test.f90: -------------------------------------------------------------------------------- 1 | 2 | pure function capital(k0, s, a, delta, alpha, n) 3 | implicit none 4 | integer, parameter :: dp=kind(0.d0) 5 | real(dp), intent(in) :: k0, s, a, delta, alpha 6 | real(dp) :: capital, r 7 | integer :: i 8 | integer, intent(in) :: n 9 | capital = k0 10 | do i = 1, n - 1 11 | capital = a * s * capital**alpha + (1 - delta) * capital 12 | end do 13 | return 14 | end function capital 15 | 16 | program main 17 | implicit none 18 | integer, parameter :: dp=kind(0.d0) 19 | real(dp) :: start, finish, x, capital 20 | integer :: n 21 | real(dp) :: s=3.0_dp 22 | real(dp) :: a=1.0_dp 23 | real(dp) :: delta=0.1_dp 24 | real(dp) :: alpha=0.4_dp 25 | n = 1000000 26 | call cpu_time(start) 27 | x = capital(0.2_dp, s, a, delta, alpha, n) 28 | call cpu_time(finish) 29 | print *,'Last val = ', x 30 | print *,'Elapsed time = ', finish - start 31 | end program main 32 | -------------------------------------------------------------------------------- /source_files/tuesday/temp/timed_solow.f90: -------------------------------------------------------------------------------- 1 | 2 | pure function capital(k0, s, a, delta, alpha, n) 3 | implicit none 4 | integer, parameter :: dp=kind(0.d0) 5 | real(dp), intent(in) :: k0, s, a, delta, alpha 6 | real(dp) :: capital, r 7 | integer :: i 8 | integer, intent(in) :: n 9 | capital = k0 10 | do i = 1, n - 1 11 | capital = a * s * capital**alpha + (1 - delta) * capital 12 | end do 13 | return 14 | end function capital 15 | 16 | program main 17 | implicit none 18 | integer, parameter :: dp=kind(0.d0) 19 | real(dp) :: start, finish, x, capital 20 | integer :: n 21 | real(dp) :: s=0.3_dp 22 | real(dp) :: a=1.0_dp 23 | real(dp) :: delta=0.1_dp 24 | real(dp) :: alpha=0.4_dp 25 | n = 100000 26 | call cpu_time(start) 27 | x = capital(0.2_dp, s, a, delta, alpha, n) 28 | call cpu_time(finish) 29 | print *,'Steady state capital stock = ', x 30 | print *,'Elapsed time = ', finish - start 31 | end program main 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Central Bank of Chile 2024 Scientific Computing Workshop 2 | 3 | ![](qe-logo-large.png) 4 | 5 | This is the homepage for the [QuantEcon](https://quantecon.org/) scientific 6 | and high performance computing workshop to be held at the Central Bank of 7 | Chile in May 2024. 8 | 9 | 10 | 11 | ## Instructor 12 | 13 | * [John Stachurski](https://johnstachurski.net/) (Australian National University) 14 | 15 | Bio: John Stachurski is a mathematical and computational economist who works on 16 | algorithms at the intersection of dynamic programming, Markov dynamics, 17 | economics, and finance. His work is published in journals such as the Journal 18 | of Finance, the Journal of Economic Theory, Automatica, Econometrica, and 19 | Operations Research. In 2016 he co-founded QuantEcon with Thomas J. Sargent. 
20 | 21 | 22 | ## Abstract 23 | 24 | Open source scientific computing environments built around the Python 25 | programming language have expanded rapidly in recent years. They now form the 26 | dominant paradigm in artificial intelligence and many fields within the natural 27 | sciences. Economists can greatly enhance their modeling and data processing 28 | capabilities by exploiting Python's scientific ecosystem. This course will 29 | cover the foundations of Python programming and Python scientific libraries, as 30 | well as showing how they can be used in economic applications for rapid 31 | development and high performance computing. 32 | 33 | ## Topics 34 | 35 | ### Monday: overview and Python intro 36 | 37 | * An overview of modern scientific computing 38 | * AI and its impact on economic modeling 39 | * Quick introduction to Python 40 | 41 | ### Tuesday: scientific Python 42 | 43 | * Linear regression with Python 44 | * Accelerating Python using Numba and Fortran 45 | * Inventory dynamics 46 | * Gini coefficients and Lorenz curves 47 | * Wealth dynamics (simple model) 48 | * Markov chains 49 | 50 | ### Wednesday: JAX, GPUs and autodiff 51 | 52 | * Introduction to JAX and GPU computing 53 | * Automatic differentiation 54 | * Autodiff application: Epstein-Zin preferences 55 | * Wealth dynamics revisited 56 | * Inventory dynamics revisited 57 | * Job search 58 | 59 | ### Thursday: dynamic programming with JAX 60 | 61 | * Dynamic programming: theory and algorithms 62 | * Optimal savings problems (JAX) 63 | * Endogenous grid method (JAX) 64 | 65 | ### Friday: equilibrium models with JAX 66 | 67 | * Aiyagari model 68 | * Arellano sovereign default model 69 | * Bianchi overborrowing model 70 | * Hopenhayn industry model 71 | 72 | 73 | ## Dates 74 | 75 | * May 13th - 17th 76 | 77 | ## Prerequisites 78 | 79 | All participants should bring laptop computers. If possible, participants 80 | should bring laptops with the ability to install open source software. For those 81 | without such permissions, a cloud computing option will be provided. The courses 82 | assume knowledge of the fundamentals of linear algebra, analysis, dynamic optimization 83 | and probability. 84 | 85 | Suitable background can be found in the first few chapters of [Dynamic Programming](https://dp.quantecon.org). 86 | 87 | 88 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # 2 | .ipynb_checkpoints/ 3 | ipynb_checkpoints/ 4 | */.ipynb_checkpoints/ 5 | __pycache__/ 6 | Untitled* 7 | 8 | 9 | ## Core latex/pdflatex auxiliary files: 10 | *.aux 11 | *.lof 12 | *.log 13 | *.lot 14 | *.fls 15 | *.out 16 | *.toc 17 | *.fmt 18 | *.fot 19 | *.cb 20 | *.cb2 21 | .*.lb 22 | 23 | ## Intermediate documents: 24 | *.dvi 25 | *.xdv 26 | *-converted-to.* 27 | # these rules might exclude image files for figures etc. 
28 | # *.ps 29 | # *.eps 30 | # *.pdf 31 | 32 | ## Bibliography auxiliary files (bibtex/biblatex/biber): 33 | *.bbl 34 | *.bcf 35 | *.blg 36 | *-blx.aux 37 | *-blx.bib 38 | *.run.xml 39 | 40 | ## Build tool auxiliary files: 41 | *.fdb_latexmk 42 | *.synctex 43 | *.synctex(busy) 44 | *.synctex.gz 45 | *.synctex.gz(busy) 46 | *.pdfsync 47 | 48 | ## Build tool directories for auxiliary files 49 | # latexrun 50 | latex.out/ 51 | 52 | ## Auxiliary and intermediate files from other packages: 53 | # algorithms 54 | *.alg 55 | *.loa 56 | 57 | # achemso 58 | acs-*.bib 59 | 60 | # amsthm 61 | *.thm 62 | 63 | # beamer 64 | *.nav 65 | *.pre 66 | *.snm 67 | *.vrb 68 | 69 | # changes 70 | *.soc 71 | 72 | # comment 73 | *.cut 74 | 75 | # cprotect 76 | *.cpt 77 | 78 | # elsarticle (documentclass of Elsevier journals) 79 | *.spl 80 | 81 | # endnotes 82 | *.ent 83 | 84 | # fixme 85 | *.lox 86 | 87 | # feynmf/feynmp 88 | *.mf 89 | *.mp 90 | *.t[1-9] 91 | *.t[1-9][0-9] 92 | *.tfm 93 | 94 | #(r)(e)ledmac/(r)(e)ledpar 95 | *.end 96 | *.?end 97 | *.[1-9] 98 | *.[1-9][0-9] 99 | *.[1-9][0-9][0-9] 100 | *.[1-9]R 101 | *.[1-9][0-9]R 102 | *.[1-9][0-9][0-9]R 103 | *.eledsec[1-9] 104 | *.eledsec[1-9]R 105 | *.eledsec[1-9][0-9] 106 | *.eledsec[1-9][0-9]R 107 | *.eledsec[1-9][0-9][0-9] 108 | *.eledsec[1-9][0-9][0-9]R 109 | 110 | # glossaries 111 | *.acn 112 | *.acr 113 | *.glg 114 | *.glo 115 | *.gls 116 | *.glsdefs 117 | *.lzo 118 | *.lzs 119 | 120 | # uncomment this for glossaries-extra (will ignore makeindex's style files!) 121 | # *.ist 122 | 123 | # gnuplottex 124 | *-gnuplottex-* 125 | 126 | # gregoriotex 127 | *.gaux 128 | *.gtex 129 | 130 | # htlatex 131 | *.4ct 132 | *.4tc 133 | *.idv 134 | *.lg 135 | *.trc 136 | *.xref 137 | 138 | # hyperref 139 | *.brf 140 | 141 | # knitr 142 | *-concordance.tex 143 | # TODO Comment the next line if you want to keep your tikz graphics files 144 | *.tikz 145 | *-tikzDictionary 146 | 147 | # listings 148 | *.lol 149 | 150 | # luatexja-ruby 151 | *.ltjruby 152 | 153 | # makeidx 154 | *.idx 155 | *.ilg 156 | *.ind 157 | 158 | # minitoc 159 | *.maf 160 | *.mlf 161 | *.mlt 162 | *.mtc[0-9]* 163 | *.slf[0-9]* 164 | *.slt[0-9]* 165 | *.stc[0-9]* 166 | 167 | # minted 168 | _minted* 169 | *.pyg 170 | 171 | # morewrites 172 | *.mw 173 | 174 | # nomencl 175 | *.nlg 176 | *.nlo 177 | *.nls 178 | 179 | # pax 180 | *.pax 181 | 182 | # pdfpcnotes 183 | *.pdfpc 184 | 185 | # sagetex 186 | *.sagetex.sage 187 | *.sagetex.py 188 | *.sagetex.scmd 189 | 190 | # scrwfile 191 | *.wrt 192 | 193 | # sympy 194 | *.sout 195 | *.sympy 196 | sympy-plots-for-*.tex/ 197 | 198 | # pdfcomment 199 | *.upa 200 | *.upb 201 | 202 | # pythontex 203 | *.pytxcode 204 | pythontex-files-*/ 205 | 206 | # tcolorbox 207 | *.listing 208 | 209 | # thmtools 210 | *.loe 211 | 212 | # TikZ & PGF 213 | *.dpth 214 | *.md5 215 | *.auxlock 216 | 217 | # todonotes 218 | *.tdo 219 | 220 | # vhistory 221 | *.hst 222 | *.ver 223 | 224 | # easy-todo 225 | *.lod 226 | 227 | # xcolor 228 | *.xcp 229 | 230 | # xmpincl 231 | *.xmpi 232 | 233 | # xindy 234 | *.xdy 235 | 236 | # xypic precompiled matrices and outlines 237 | *.xyc 238 | *.xyd 239 | 240 | # endfloat 241 | *.ttt 242 | *.fff 243 | 244 | # Latexian 245 | TSWLatexianTemp* 246 | 247 | ## Editors: 248 | # WinEdt 249 | *.bak 250 | *.sav 251 | 252 | # Texpad 253 | .texpadtmp 254 | 255 | # LyX 256 | *.lyx~ 257 | 258 | # Kile 259 | *.backup 260 | 261 | # gummi 262 | .*.swp 263 | 264 | # KBibTeX 265 | *~[0-9]* 266 | 267 | # TeXnicCenter 268 | *.tps 269 | 270 | # auto folder when using 
emacs and auctex 271 | ./auto/* 272 | *.el 273 | 274 | # expex forward references with \gathertags 275 | *-tags.tex 276 | 277 | # standalone packages 278 | *.sta 279 | 280 | # Makeindex log files 281 | *.lpz 282 | 283 | 284 | *reveal.js* 285 | node_modules 286 | -------------------------------------------------------------------------------- /source_files/tuesday/regression.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | format_version: 0.13 7 | jupytext_version: 1.16.1 8 | kernelspec: 9 | display_name: Python 3 (ipykernel) 10 | language: python 11 | name: python3 12 | --- 13 | 14 | # Linear Regression with Python 15 | 16 | ---- 17 | 18 | #### John Stachurski 19 | #### Prepared for the CBC Computational Workshop (May 2024) 20 | 21 | ---- 22 | 23 | +++ 24 | 25 | Let's have a very quick look at linear regression in Python. 26 | 27 | We'll also show how to download some data from [FRED](https://fred.stlouisfed.org/). 28 | 29 | +++ 30 | 31 | Uncomment the next line if you don't have this library installed: 32 | 33 | ```{code-cell} ipython3 34 | #!pip install pandas_datareader 35 | ``` 36 | 37 | Let's do some imports. 38 | 39 | ```{code-cell} ipython3 40 | import pandas_datareader.data as web 41 | import datetime 42 | import matplotlib.pyplot as plt 43 | import plotly.express as px 44 | import statsmodels.api as sm 45 | import statsmodels.formula.api as smf 46 | ``` 47 | 48 | We use the `datetime` module from the standard library to pick start dates and end dates. 49 | 50 | ```{code-cell} ipython3 51 | start = datetime.datetime(1947, 1, 1) 52 | end = datetime.datetime(2019, 12, 1) 53 | ``` 54 | 55 | Now let's read in data on GDP and unemployment from FRED. 56 | 57 | ```{code-cell} ipython3 58 | series = 'GDP', 'UNRATE' 59 | source = 'fred' 60 | data = web.DataReader(series, source, start, end) 61 | ``` 62 | 63 | Data is read in as a `pandas` dataframe. 64 | 65 | ```{code-cell} ipython3 66 | type(data) 67 | ``` 68 | 69 | ```{code-cell} ipython3 70 | data = data.dropna() 71 | data 72 | ``` 73 | 74 | We'll convert both series into rates of change. 75 | 76 | ```{code-cell} ipython3 77 | data = data.pct_change() * 100 78 | data = data.reset_index().dropna() 79 | data 80 | ``` 81 | 82 | Let's have a look at our data. 83 | 84 | Notice in the code below that Matplotlib plays well with pandas. 85 | 86 | ```{code-cell} ipython3 87 | fig, ax = plt.subplots() 88 | ax.scatter(x='UNRATE', y='GDP', data=data, color='k', alpha=0.5) 89 | ax.set_xlabel('% change in unemployment rate') 90 | ax.set_ylabel('% change in GDP') 91 | plt.show() 92 | ``` 93 | 94 | If you want an interactive graph, you can use Plotly instead: 95 | 96 | ```{code-cell} ipython3 97 | fig = px.scatter(data, x='UNRATE', y='GDP') 98 | fig.update_layout( 99 | showlegend=False, 100 | autosize=False, 101 | width=600, 102 | height=400, 103 | ) 104 | 105 | fig.show() 106 | ``` 107 | 108 | Let's fit a regression line, which can be used to measure [Okun's law](https://en.wikipedia.org/wiki/Okun%27s_law). 109 | 110 | To do so we'll use [Statsmodels](https://www.statsmodels.org/stable/index.html). 
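For comparison with the least absolute deviations criterion introduced further below, recall that ordinary least squares chooses the coefficients to minimize the sum of squared residuals

$$
\ell(\alpha, \beta) = \sum_{i=1}^n (y_i - (\alpha x_i + \beta))^2
$$

where $y_i$ is the percentage change in GDP and $x_i$ is the percentage change in the unemployment rate.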
111 | 112 | ```{code-cell} ipython3 113 | model = smf.ols(formula='GDP ~ UNRATE', data=data) 114 | ols = model.fit() 115 | ``` 116 | 117 | ```{code-cell} ipython3 118 | ols.summary() 119 | ``` 120 | 121 | ```{code-cell} ipython3 122 | ols.params 123 | ``` 124 | 125 | ```{code-cell} ipython3 126 | X = data['UNRATE'] 127 | fig, ax = plt.subplots() 128 | plt.scatter(x='UNRATE', y='GDP', data=data, color='k', alpha=0.5) 129 | plt.plot(X, ols.fittedvalues, label='OLS') 130 | ax.set_xlabel('% change in unemployment rate') 131 | ax.set_ylabel('% change in GDP') 132 | plt.legend() 133 | plt.show() 134 | ``` 135 | 136 | Next let's try using least absolute deviations, which means that we minimize 137 | 138 | $$ 139 | \ell(\alpha, \beta) = \sum_{i=1}^n |y_i - (\alpha x_i + \beta_i)| 140 | $$ 141 | 142 | over parameters $\alpha, \beta$. 143 | 144 | This is a special case of quantile regression when the quantile is the median (0.5). 145 | 146 | ```{code-cell} ipython3 147 | mod = smf.quantreg(formula="GDP ~ UNRATE", data=data) 148 | lad = mod.fit(q=0.5) # LAD model is a special case of quantile regression when q = 0.5 149 | ``` 150 | 151 | ```{code-cell} ipython3 152 | lad.summary() 153 | ``` 154 | 155 | ```{code-cell} ipython3 156 | lad.params 157 | ``` 158 | 159 | Let's compare the LAD regression line to the least squares regression line. 160 | 161 | ```{code-cell} ipython3 162 | fig, ax = plt.subplots() 163 | plt.scatter(x='UNRATE', y='GDP', data=data, color='k', alpha=0.5) 164 | plt.plot(X, ols.fittedvalues, label='OLS') 165 | plt.plot(X, lad.fittedvalues, label='LAD') 166 | ax.set_xlabel('% change in unemployment rate') 167 | ax.set_ylabel('% change in GDP') 168 | plt.legend() 169 | plt.show() 170 | ``` 171 | 172 | ```{code-cell} ipython3 173 | 174 | ``` 175 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 
27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. 
Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. 122 | -------------------------------------------------------------------------------- /source_files/tuesday/exercises/markov_homework.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | format_version: 0.13 7 | jupytext_version: 1.16.1 8 | kernelspec: 9 | display_name: Python 3 (ipykernel) 10 | language: python 11 | name: python3 12 | --- 13 | 14 | # Markov Chains Homework 15 | 16 | 17 | ## Ergodicity 18 | 19 | Let $\{X_t\} = \{X_0, X_1, \ldots\}$ be a Markov chain generated by stochastic 20 | matrix $P$ and any initial condition. 
21 | 22 | Recall that $P$ has a unique stationary distribution whenever $P$ is 23 | irreducible. 24 | 25 | Let $\psi^*$ denote this stationary distribution. 26 | 27 | Under irreducibility, the following result holds: for all $ x \in S $, 28 | 29 | $$ 30 | \frac{1}{m} \sum_{t = 1}^m \mathbf{1}\{X_t = x\} \to \psi^*(x) 31 | \quad \text{as } m \to \infty 32 | $$ 33 | 34 | Here 35 | 36 | - $ \mathbf{1}\{X_t = x\} = 1 $ if $ X_t = x $ and zero otherwise 37 | - convergence is with probability one 38 | - the result does not depend on the marginal distribution of $ X_0 $ 39 | 40 | The convergence asserted above is a special case of a law of large numbers 41 | result for Markov chains -- see, for example, [EDTC](http://johnstachurski.net/edtc.html), 42 | section 4.3.4. 43 | 44 | The result tells us that the fraction of time the chain spends at state $ x $ converges to $ \psi^*(x) $ as time goes to infinity. 45 | 46 | This gives us another way to interpret the stationary distribution. 47 | 48 | +++ 49 | 50 | ### Example 51 | 52 | Recall our cross-sectional interpretation of the employment/unemployment model discussed in the finite Markov chain lecture. 53 | 54 | Assume that $ \alpha \in (0,1) $ and $ \beta \in (0,1) $, so that irreducibility and aperiodicity both hold. 55 | 56 | We saw that the stationary distribution is $ (p, 1-p) $, where 57 | 58 | $$ 59 | p = \frac{\beta}{\alpha + \beta} 60 | $$ 61 | 62 | In the cross-sectional interpretation, this is the fraction of people unemployed. 63 | 64 | In view of our latest (ergodicity) result, it is also the fraction of time that a single worker can expect to spend unemployed. 65 | 66 | Thus, in the long-run, cross-sectional averages for a population and time-series averages for a given person coincide. 67 | 68 | This is one aspect of the concept of ergodicity. 69 | 70 | +++ 71 | 72 | **Exercise** 73 | 74 | According to the discussion above, if a worker’s employment dynamics obey the stochastic matrix 75 | 76 | $$ 77 | P 78 | = \left( 79 | \begin{array}{cc} 80 | 1 - \alpha & \alpha \\ 81 | \beta & 1 - \beta 82 | \end{array} 83 | \right) 84 | $$ 85 | 86 | with $ \alpha \in (0,1) $ and $ \beta \in (0,1) $, then, in the long-run, the fraction 87 | of time spent unemployed will be 88 | 89 | $$ 90 | p := \frac{\beta}{\alpha + \beta} 91 | $$ 92 | 93 | In other words, if $ \{X_t\} $ represents the Markov chain for 94 | employment, then $ \bar X_m \to p $ as $ m \to \infty $, where 95 | 96 | $$ 97 | \bar X_m := \frac{1}{m} \sum_{t = 1}^m \mathbf{1}\{X_t = 0\} 98 | $$ 99 | 100 | This exercise asks you to illustrate convergence by computing 101 | $ \bar X_m $ for large $ m $ and checking that 102 | it is close to $ p $. 103 | 104 | You will see that this statement is true regardless of the choice of initial 105 | condition or the values of $ \alpha, \beta $, provided both lie in 106 | $ (0, 1) $. 107 | 108 | Here's some code to start you off. 109 | 110 | ```{code-cell} ipython3 111 | α = β = 0.1 112 | p = β / (α + β) 113 | 114 | P = ((1 - α, α), # Careful: P and p are distinct 115 | ( β, 1 - β)) 116 | mc = MarkovChain(P) 117 | ``` 118 | 119 | **Solution** 120 | 121 | The plots below show the time series of $ \bar X_m - p $ for two initial 122 | conditions. 123 | 124 | As $ m $ gets large, both series converge to zero. 
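The code cells in this notebook assume that NumPy, Matplotlib and a `MarkovChain` class are already available. A minimal setup cell, assuming `MarkovChain` comes from the [QuantEcon.py](https://quantecon.org/quantecon-py/) package (the package is not named explicitly in this file), would be:

```{code-cell} ipython3
# Assumed setup cell, not part of the original notebook
import numpy as np
import matplotlib.pyplot as plt
from quantecon import MarkovChain
```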
125 | 126 | ```{code-cell} ipython3 127 | N = 20_000 128 | fig, ax = plt.subplots() 129 | ax.set_ylim(-0.25, 0.25) 130 | ax.grid() 131 | ax.hlines(0, 0, N, lw=2, alpha=0.6) # Horizonal line at zero 132 | 133 | for x0 in (0, 1): 134 | # Generate time series for worker that starts at x0 135 | X = mc.simulate(N, init=x0) 136 | # Compute fraction of time spent unemployed, for each n 137 | X_bar = (X == 0).cumsum() / (1 + np.arange(N, dtype=float)) 138 | # Plot 139 | ax.fill_between(range(N), np.zeros(N), X_bar - p, alpha=0.1) 140 | ax.plot(X_bar - p, label=f'$X_0 = \, {x0} $') 141 | # Overlay in black--make lines clearer 142 | ax.plot(X_bar - p, 'k-', alpha=0.6) 143 | 144 | ax.legend(loc='upper right') 145 | plt.show() 146 | ``` 147 | 148 | ## Computing Expectations 149 | 150 | 151 | We are interested in computing expressions like 152 | 153 | $$ 154 | \mathbb E [ h(X_t) ] 155 | $$ 156 | 157 | and conditional expectations such as 158 | 159 | $$ 160 | \mathbb E [ h(X_{t + k}) \mid X_t = x] 161 | $$ 162 | 163 | where 164 | 165 | - $ \{X_t\} $ is a Markov chain generated by $ n \times n $ stochastic matrix $ P $ 166 | - $ \psi $ is the distribution of $ X_0 $ 167 | - $ h $ is a given function, which, in terms of matrix 168 | algebra, we’ll think of as the column vector 169 | 170 | 171 | $$ 172 | h 173 | = \left( 174 | \begin{array}{c} 175 | h(x_1) \\ 176 | \vdots \\ 177 | h(x_n) 178 | \end{array} 179 | \right) 180 | $$ 181 | 182 | Computing the unconditional expectation is easy. 183 | 184 | We just sum over the marginal distribution of $ X_t $ to get 185 | 186 | $$ 187 | \mathbb E [ h(X_t) ] 188 | = \sum_{x \in S} (\psi P^t)(x) h(x) 189 | $$ 190 | 191 | 192 | 193 | Since $ \psi $ and hence $ \psi P^t $ are row vectors, we can also 194 | write this as 195 | 196 | $$ 197 | \mathbb E [ h(X_t) ] 198 | = \psi P^t h 199 | $$ 200 | 201 | For the conditional expectation we need to sum over the conditional distribution 202 | of $ X_{t + k} $ given $ X_t = x $. 203 | 204 | We already know that this is $ P^k(x, \cdot) $, so 205 | 206 | 207 | $$ 208 | \mathbb E [ h(X_{t + k}) \mid X_t = x] 209 | = (P^k h)(x) 210 | $$ 211 | 212 | The vector $ P^k h $ stores the conditional expectation $ \mathbb E [ h(X_{t + k}) \mid X_t = x] $ over all $ x $. 213 | 214 | +++ 215 | 216 | ### Expectations of Geometric Sums 217 | 218 | To compute present values we often need to calculate expectation of a geometric sum, such as 219 | $ \sum_t \beta^t h(X_t) $. 220 | 221 | In view of the preceding discussion, this is 222 | 223 | $$ 224 | \mathbb{E} \left[ 225 | \sum_{t=0}^\infty \beta^t h(X_t) \mid X_0 = x 226 | \right] 227 | = \sum_{t=0}^\infty \mathbb{E} \left[ \beta^t h(X_t) \mid X_t = x 228 | \right] 229 | = \sum_{t=0}^\infty ((\beta P)^t h)(x) 230 | $$ 231 | 232 | +++ 233 | 234 | **Exercise** Suppose that the state of the economy is given by Hamilton's Markov chain, so that 235 | 236 | ```{code-cell} ipython3 237 | P = ((0.971, 0.029, 0.0), 238 | (0.145, 0.778, 0.077), 239 | (0.0, 0.508, 0.492)) 240 | ``` 241 | 242 | Suppose that current profits $\pi(X_t)$ of a firm are given by the vector 243 | 244 | ```{code-cell} ipython3 245 | π = (10, 5, -25) 246 | ``` 247 | 248 | Let the discount factor be 249 | 250 | ```{code-cell} ipython3 251 | β = 0.99 252 | ``` 253 | 254 | Using the Neumann series lemma, which tells us that 255 | 256 | $$ 257 | (I - \beta P)^{-1} = I + \beta P + \beta^2 P^2 + \cdots 258 | $$ 259 | 260 | compute the expected present value of the firm. 
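Combining the geometric-sum expression above with the Neumann series lemma, the vector of expected present values is

$$
v = \sum_{t=0}^\infty (\beta P)^t \pi = (I - \beta P)^{-1} \pi
$$

so that $v(x)$ is the expected present value of the firm when the current state is $x$. Note that the solution cell below writes the profit vector as `h`, so bind `h = π` (or replace `h` with `π`) before running it.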
261 | 262 | ```{code-cell} ipython3 263 | 264 | h, P = np.asarray(h), np.asarray(P) 265 | I = np.identity(len(h)) 266 | v = np.linalg.solve(I - β * P, h) 267 | v 268 | ``` 269 | 270 | ```{code-cell} ipython3 271 | 272 | ``` 273 | -------------------------------------------------------------------------------- /source_files/tuesday/numpy.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | format_version: 0.13 7 | jupytext_version: 1.16.1 8 | kernelspec: 9 | display_name: Python 3 (ipykernel) 10 | language: python 11 | name: python3 12 | --- 13 | 14 | # Vectorization and Array Operations with NumPy 15 | 16 | ### Written for the CBC Workshop (May 2024) 17 | 18 | #### John Stachurski 19 | 20 | +++ 21 | 22 | [NumPy](https://numpy.org/) is the standard library for numerical array operations in Python. 23 | 24 | +++ 25 | 26 | This notebook contains a very quick introduction to NumPy. 27 | 28 | (Although the syntax and some reference concepts differ, the basic framework is similar to Matlab.) 29 | 30 | We use the following imports 31 | 32 | ```{code-cell} ipython3 33 | import numpy as np 34 | import matplotlib.pyplot as plt 35 | ``` 36 | 37 | ## NumPy arrays 38 | 39 | Let's review the basics of NumPy arrays. 40 | 41 | ### Creating arrays 42 | 43 | Here are a few ways to create arrays: 44 | 45 | ```{code-cell} ipython3 46 | a = (10.0, 20.0) 47 | print(type(a)) 48 | a = np.array(a) # Create array from Python tuple 49 | type(a) 50 | ``` 51 | 52 | ```{code-cell} ipython3 53 | a = np.array((10, 20), dtype='float64') # Specify data type -- must be homogeneous 54 | a 55 | ``` 56 | 57 | ```{code-cell} ipython3 58 | a = np.linspace(0, 10, 5) 59 | a 60 | ``` 61 | 62 | ```{code-cell} ipython3 63 | a = np.ones(3) 64 | a 65 | ``` 66 | 67 | ```{code-cell} ipython3 68 | a = np.zeros(3) 69 | a 70 | ``` 71 | 72 | ```{code-cell} ipython3 73 | a = np.random.randn(4) 74 | a 75 | ``` 76 | 77 | ```{code-cell} ipython3 78 | a = np.random.randn(2, 2) 79 | a 80 | ``` 81 | 82 | ```{code-cell} ipython3 83 | b = np.zeros_like(a) 84 | b 85 | ``` 86 | 87 | ### Reshaping 88 | 89 | ```{code-cell} ipython3 90 | a = np.random.randn(2, 2) 91 | a 92 | ``` 93 | 94 | ```{code-cell} ipython3 95 | a.shape 96 | ``` 97 | 98 | ```{code-cell} ipython3 99 | np.reshape(a, (1, 4)) 100 | ``` 101 | 102 | ```{code-cell} ipython3 103 | np.reshape(a, (4, 1)) 104 | ``` 105 | 106 | ### Array operations 107 | 108 | +++ 109 | 110 | Standard arithmetic operators are pointwise: 111 | 112 | ```{code-cell} ipython3 113 | a 114 | ``` 115 | 116 | ```{code-cell} ipython3 117 | b 118 | ``` 119 | 120 | ```{code-cell} ipython3 121 | a + b 122 | ``` 123 | 124 | ```{code-cell} ipython3 125 | a * b # pointwise multiplication 126 | ``` 127 | 128 | ### Matrix multiplication 129 | 130 | ```{code-cell} ipython3 131 | a @ b 132 | ``` 133 | 134 | ```{code-cell} ipython3 135 | np.ones(3) @ np.zeros(3) # inner product 136 | ``` 137 | 138 | ### Reductions 139 | 140 | +++ 141 | 142 | There are various functions for acting on arrays, such as 143 | 144 | ```{code-cell} ipython3 145 | np.mean(a) 146 | ``` 147 | 148 | ```{code-cell} ipython3 149 | np.sum(a) 150 | ``` 151 | 152 | These operations have an equivalent OOP syntax, as in 153 | 154 | ```{code-cell} ipython3 155 | a.mean() 156 | ``` 157 | 158 | ```{code-cell} ipython3 159 | a.sum() 160 | ``` 161 | 162 | These operations also work on higher-dimensional arrays: 163 | 164 | ```{code-cell} ipython3 
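# a 2 x 2 array with entries 0., 1., 2., 3., used to illustrate summing along each axis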
165 | a = np.linspace(0, 3, 4).reshape(2, 2) 166 | a 167 | ``` 168 | 169 | ```{code-cell} ipython3 170 | a.sum(axis=0) # sum columns 171 | ``` 172 | 173 | ```{code-cell} ipython3 174 | a.sum(axis=1) # sum rows 175 | ``` 176 | 177 | ### Broadcasting 178 | 179 | +++ 180 | 181 | When possible, arrays are "streched" across missing dimensions to perform array operations. 182 | 183 | For example, 184 | 185 | ```{code-cell} ipython3 186 | a = np.zeros((3, 3)) 187 | a 188 | ``` 189 | 190 | ```{code-cell} ipython3 191 | b = np.array((1.0, 2.0, 3.0)) 192 | b = np.reshape(b, (1, 3)) 193 | b 194 | ``` 195 | 196 | ```{code-cell} ipython3 197 | a + b 198 | ``` 199 | 200 | ```{code-cell} ipython3 201 | b = np.reshape(b, (3, 1)) 202 | b 203 | ``` 204 | 205 | ```{code-cell} ipython3 206 | a + b 207 | ``` 208 | 209 | For more on broadcasting see [this tutorial](https://jakevdp.github.io/PythonDataScienceHandbook/02.05-computation-on-arrays-broadcasting.html). 210 | 211 | +++ 212 | 213 | ### Ufuncs 214 | 215 | +++ 216 | 217 | Many NumPy functions can act on either scalars or arrays. 218 | 219 | When they act on arrays, they act pointwise (element-by-element). 220 | 221 | These kinds of functions are called **universal functions** or **ufuncs**. 222 | 223 | ```{code-cell} ipython3 224 | np.cos(np.pi) 225 | ``` 226 | 227 | ```{code-cell} ipython3 228 | a = np.random.choice((0, np.pi), 6).reshape(2, 3) 229 | a 230 | ``` 231 | 232 | ```{code-cell} ipython3 233 | np.cos(a) 234 | ``` 235 | 236 | Some user-defined functions will be ufuncs, such as 237 | 238 | ```{code-cell} ipython3 239 | def f(x): 240 | return np.cos(np.sin(x)) 241 | ``` 242 | 243 | ```{code-cell} ipython3 244 | f(a) 245 | ``` 246 | 247 | But some are not: 248 | 249 | ```{code-cell} ipython3 250 | def f(x): 251 | if x < 0: 252 | return np.cos(x) 253 | else: 254 | return np.sin(x) 255 | ``` 256 | 257 | ```{code-cell} ipython3 258 | f(a) 259 | ``` 260 | 261 | If we want to turn this into a vectorized function we can use `np.vectorize` 262 | 263 | ```{code-cell} ipython3 264 | f_vec = np.vectorize(f) 265 | ``` 266 | 267 | Let's test it, and also time it. 268 | 269 | ```{code-cell} ipython3 270 | a = np.linspace(0, 1, 10_000_000) 271 | %time f_vec(a) 272 | ``` 273 | 274 | This is pretty slow. 275 | 276 | Here's a version of `f` that uses NumPy functions to create a more efficient ufunc. 277 | 278 | ```{code-cell} ipython3 279 | def f(x): 280 | return np.where(x < 0, np.cos(x), np.sin(x)) 281 | ``` 282 | 283 | ```{code-cell} ipython3 284 | %time f(a) 285 | ``` 286 | 287 | Moral of the story: Don't use `np.vectorize` unless you have to. 288 | 289 | (There are good alternatives, which we will discuss soon.) 290 | 291 | +++ 292 | 293 | ### Mutability 294 | 295 | +++ 296 | 297 | NumPy arrays are mutable (can be altered in memory). 298 | 299 | ```{code-cell} ipython3 300 | a = np.array((10.0, 20.0)) 301 | a 302 | ``` 303 | 304 | ```{code-cell} ipython3 305 | a[0] = 1 306 | ``` 307 | 308 | ```{code-cell} ipython3 309 | a 310 | ``` 311 | 312 | ```{code-cell} ipython3 313 | a[:] = 42 314 | ``` 315 | 316 | ```{code-cell} ipython3 317 | a 318 | ``` 319 | 320 | **All names** bound to an array have equal rights. 
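As the next few cells demonstrate, an assignment like `b = a` does not copy any data; both names refer to the same array object. When an independent copy is wanted, the `copy` method can be used instead — a minimal sketch:

```{code-cell} ipython3
c = np.array((10.0, 20.0))
d = c.copy()      # d owns its own data rather than aliasing c
d[0] = -1.0
c                 # unchanged
```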
321 | 322 | ```{code-cell} ipython3 323 | a 324 | ``` 325 | 326 | ```{code-cell} ipython3 327 | b = a # bind the name b to the same array object 328 | ``` 329 | 330 | ```{code-cell} ipython3 331 | id(a) 332 | ``` 333 | 334 | ```{code-cell} ipython3 335 | id(b) 336 | ``` 337 | 338 | ```{code-cell} ipython3 339 | b[0] = 1_000 340 | ``` 341 | 342 | ```{code-cell} ipython3 343 | b 344 | ``` 345 | 346 | ```{code-cell} ipython3 347 | a 348 | ``` 349 | 350 | ## Vectorizing loops 351 | 352 | +++ 353 | 354 | ### Accelerating slow loops 355 | 356 | In scripting languages, native loops are slow: 357 | 358 | ```{code-cell} ipython3 359 | n = 10_000_000 360 | x_vec = np.linspace(0.1, 1.1, n) 361 | ``` 362 | 363 | Let's say we want to compute the sum of of $\cos(2\pi / x)$ over $x$ in 364 | 365 | ```{code-cell} ipython3 366 | %%time 367 | current_sum = 0.0 368 | for x in x_vec: 369 | current_sum += np.cos(2 * np.pi / x) 370 | ``` 371 | 372 | The reason is that Python, like most high level languages is dynamically typed. 373 | 374 | This means that the type of a variable can freely change. 375 | 376 | Moreover, the interpreter doesn't compile the whole program at once, so it doesn't know when types will change. 377 | 378 | So the interpreter has to check the type of variables before any operation like addition, comparison, etc. 379 | 380 | Hence there's a lot of fixed cost for each such operation 381 | 382 | +++ 383 | 384 | The code runs much faster if we use **vectorized** expressions to avoid explicit loops. 385 | 386 | ```{code-cell} ipython3 387 | %%time 388 | np.sum(np.cos(2 * np.pi / x_vec)) 389 | ``` 390 | 391 | Now high level overheads are paid *per array rather than per float*. 392 | 393 | +++ 394 | 395 | ### Implict Multithreading 396 | 397 | 398 | Recent versions of Anaconda are compiled with Intel MKL support, which accelerates NumPy operations. 399 | 400 | Watch system resources when you run this code. 401 | 402 | (For example, install `htop` (Linux / Mac), `perfmon` (Windows) or another system load monitor and set it running in another window.) 403 | 404 | ```{code-cell} ipython3 405 | n = 20 406 | m = 1000 407 | for i in range(n): 408 | X = np.random.randn(m, m) 409 | λ = np.linalg.eigvals(X) 410 | ``` 411 | 412 | You should see all your cores light up. With MKL, many matrix operations are automatically parallelized. 413 | -------------------------------------------------------------------------------- /source_files/wednesday/ez_preferences.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | format_version: 0.13 7 | jupytext_version: 1.16.1 8 | kernelspec: 9 | display_name: Python 3 (ipykernel) 10 | language: python 11 | name: python3 12 | --- 13 | 14 | # Recursive Utility: Solution Methods 15 | 16 | ------ 17 | 18 | #### Prepared for the CBC Workshop May 2024 19 | #### John Stachurski 20 | 21 | ------ 22 | 23 | +++ 24 | 25 | ## Outline 26 | 27 | In some dynamic models, lifetime utility is nonlinear and defined recursively. 28 | 29 | * models with Epstein-Zin preferences 30 | * risk-sensitive preferences 31 | * models with ambiguity aversion 32 | * robust control models 33 | * adversarial agents 34 | 35 | In this lecture we explore how to compute lifetime utility in some of these settings. 36 | 37 | Our main focus will be the Epstein-Zin setting. 
38 | 39 | Uncomment if necessary: 40 | 41 | ```{code-cell} ipython3 42 | #!pip install quantecon 43 | ``` 44 | 45 | We will use the following imports. 46 | 47 | ```{code-cell} ipython3 48 | import jax 49 | import jax.numpy as jnp 50 | import matplotlib.pyplot as plt 51 | import numpy as np 52 | from collections import namedtuple 53 | import time 54 | import quantecon as qe 55 | 56 | jax.config.update("jax_enable_x64", True) 57 | ``` 58 | 59 | Looking for a GPU: 60 | 61 | ```{code-cell} ipython3 62 | !nvidia-smi 63 | ``` 64 | 65 | ## Model 66 | 67 | 68 | ### The EZ recursion 69 | 70 | For Epstein--Zin preferences, lifetime utility from time $t$ onwards is given by 71 | 72 | $$ 73 | V_t = \left\{ 74 | C_t^\rho + \beta (\mathbb E_t V_{t+1}^\gamma)^{\rho/\gamma} 75 | \right\}^{1/\rho} 76 | $$ 77 | 78 | Here 79 | 80 | * $(C_t)$ is a consumption path that is being valued by the agent 81 | * $\beta \in (0,1)$ and $\gamma, \rho$ are nonzero 82 | * $\mathbb E_t$ is time $t$ expectation 83 | * $V_t$ is lifetime utility generated by $C_t, C_{t+1}, C_{t+2}, \ldots$ 84 | 85 | One way to understand the recursion above is to write it as 86 | 87 | $$ 88 | V_t = \left\{ 89 | C_t^\rho + \beta (\mathcal E_t V_{t+1})^\rho 90 | \right\}^{1/\rho} 91 | $$ 92 | 93 | where $\mathcal E$ computes the ``risk-adjusted expectation'' 94 | 95 | $$ 96 | \mathcal E Y = (\mathbb E Y^\gamma)^{1/\gamma} 97 | $$ 98 | 99 | * $\gamma$ governs risk aversion 100 | * $\rho$ governs elasticity of intertemporal substitution 101 | 102 | +++ 103 | 104 | ### A Markov Formulation 105 | 106 | We suppose $C_t = c(X_t)$ where $X_t$ is a Markov process taking values in state 107 | space $S$. 108 | 109 | We guess the solution has the form $V_t = v(X_t)$ for all $t$, where $v$ is some 110 | function over $S$. 111 | 112 | In this case we can write the above equation as 113 | 114 | $$ 115 | v(X_t) = \left\{ 116 | c(X_t)^\rho + \beta (\mathbb E_t v(X_{t+1})^\gamma)^{\rho/\gamma} 117 | \right\}^{1/\rho} 118 | $$ 119 | 120 | Let's suppose that $(X_t)$ is a Markov chain with transition matrix $P$. 121 | 122 | It suffices to find a $v \colon S \to \mathbb R$ such that 123 | 124 | $$ 125 | v(x) 126 | = \left\{ 127 | c(x)^\rho + \beta \left[\sum_{x'} v(x')^\gamma P(x, x')\right]^{\rho/\gamma} 128 | \right\}^{1/\rho} 129 | \qquad (x \in S) 130 | $$ 131 | 132 | 133 | We define the operator $K$ sending $v$ into $Kv$ by 134 | 135 | $$ 136 | (Kv)(x) 137 | = \left\{ 138 | c(x)^\rho + \beta \left[\sum_{x'} v(x')^\gamma P(x, x')\right]^{\rho/\gamma} 139 | \right\}^{1/\rho} 140 | $$ 141 | 142 | We seek a fixed point of $v$ (i.e., a $v$ with $Kv=v$). 143 | 144 | +++ 145 | 146 | ## Solvers 147 | 148 | To solve for a fixed point of $K$ we use two methods. 149 | 150 | The first is successive approximation (also called fixed point iteration): pick any $v$ and then iterate with $K$. 151 | 152 | The second is Newton's method, which is used to find the $v \in \mathbb R^n$ such that $F(v) = 0$. 153 | 154 | We can use Newton's method to find a fixed point of $K$ by setting $F(v) = Kv - v$. 155 | 156 | Newton's method for finding a zero of $F$ is to iterate on 157 | 158 | $$ 159 | v_{k+1} = v_k - J(v_k)^{-1} F(v_k) 160 | $$ 161 | 162 | where $J(v)$ is the Jacobian of $F(v)$. 163 | 164 | In general 165 | 166 | * Newton's method has a faster rate of convergence (quadratic vs linear) 167 | * Successive approximation is more robust and can be quicker due to smaller constant terms in the rate of convergence. 168 | 169 | Here are our two solvers. 
170 | 171 | ```{code-cell} ipython3 172 | def successive_approx(K, v_init, max_iter=50_000, tol=1e-8): 173 | "Compute the fixed point of K by iterating from guess v_init." 174 | 175 | i, error = 0, tol + 1 176 | v = v_init 177 | while error > tol and i < max_iter: 178 | v_new = K(v) 179 | error = np.max(np.abs(v_new - v)) 180 | i += 1 181 | v = v_new 182 | return v, i 183 | 184 | 185 | def newton_solver(K, v_init, max_iter=10_000, tol=1e-8): 186 | """ 187 | Apply Newton's algorithm to find a fixed point of K. 188 | 189 | We use a root-finding operation on F(v) = K(v) - v, which requires 190 | iterating with the map 191 | 192 | Q(v) = v - J(v)^{-1} F(v). 193 | 194 | Here J(v) is the Jacobian of F evaluated at v. 195 | 196 | """ 197 | F = lambda v: K(v) - v 198 | @jax.jit 199 | def Q(v): 200 | J = jax.jacobian(F) 201 | return v - jnp.linalg.solve(J(v), F(v)) 202 | return successive_approx(Q, v_init, tol=tol, max_iter=max_iter) 203 | ``` 204 | 205 | ### Solution 206 | 207 | In solving the model, we need to specify the function $c$ in $C_t = c(X_t)$ and the state process. 208 | 209 | For the state process we use a Tauchen discretization of 210 | 211 | $$ 212 | X_{t+1} = \alpha X_t + \sigma Z_{t+1}, 213 | \qquad \{Z_t\} \text{ is IID and } N(0, 1) 214 | $$ 215 | 216 | We assume that $c(x) = \exp(x)$. 217 | 218 | ```{code-cell} ipython3 219 | c = jnp.exp 220 | ``` 221 | 222 | Here's the model. 223 | 224 | ```{code-cell} ipython3 225 | Model = namedtuple('Model', ('ρ', 'γ', 'β', 'α', 'σ', 'x_vals', 'c_vals', 'P')) 226 | 227 | def create_ez_model(ρ=1.6, 228 | γ=-2.0, 229 | α=0.9, 230 | β=0.998, 231 | σ=0.1, 232 | n=1_000): 233 | mc = qe.tauchen(n, α, σ) 234 | x_vals, P = jnp.exp(mc.state_values), mc.P 235 | c_vals = c(x_vals) 236 | P = jnp.array(P) 237 | return Model(ρ, γ, β, α, σ, x_vals, c_vals, P) 238 | ``` 239 | 240 | Here's the operator $K$. 241 | 242 | ```{code-cell} ipython3 243 | @jax.jit 244 | def K(v, model): 245 | ρ, γ, β, α, σ, x_vals, c_vals, P = model 246 | return (c_vals**ρ + β * (P @ v**γ)**(ρ / γ))**(1 / ρ) 247 | ``` 248 | 249 | Let's solve it using the two different solvers and see which is faster. 250 | 251 | ```{code-cell} ipython3 252 | model = create_ez_model() 253 | ρ, γ, β, α, σ, x_vals, c_vals, P = model 254 | v_init = c_vals 255 | ``` 256 | 257 | ```{code-cell} ipython3 258 | start_time = time.time() 259 | v_sa, num_iter = successive_approx(lambda v: K(v, model), v_init) 260 | sa_time = time.time() - start_time 261 | print(f"Successive approximation converged in {num_iter} iterations.") 262 | print(f"Execution time = {sa_time:.5} seconds.") 263 | ``` 264 | 265 | ```{code-cell} ipython3 266 | start_time = time.time() 267 | v_newton, num_iter = newton_solver(lambda v: K(v, model), v_init) 268 | newton_time = time.time() - start_time 269 | print(f"Successive approximation converged in {num_iter} iterations.") 270 | print(f"Execution time = {newton_time:.5} seconds.") 271 | ``` 272 | 273 | ```{code-cell} ipython3 274 | print(f"Successive approx time / Newton time = {sa_time / newton_time}") 275 | ``` 276 | 277 | ```{code-cell} ipython3 278 | fig, ax = plt.subplots() 279 | ax.plot(x_vals, v_sa, label='successive approx', ls='--') 280 | ax.plot(x_vals, v_newton, lw=1, label='newton') 281 | ax.legend() 282 | plt.show() 283 | ``` 284 | 285 | ## Exercise 286 | 287 | Step $\sigma$ through `0.05, 0.075, 0.1, 0.125`, in each case computing the solution and plotting lifetime utility (all on the same figure). 288 | 289 | How does increasing volatility affect lifetime utility? 
290 | 291 | (You might find that lifetime utility goes up with $\sigma$, even though the agent is risk-averse ($\gamma < 0$). Can you explain this?) 292 | 293 | ```{code-cell} ipython3 294 | # Put your code here. 295 | ``` 296 | 297 | ```{code-cell} ipython3 298 | for i in range(18): 299 | print("Solution below! 🐇") 300 | ``` 301 | 302 | ```{code-cell} ipython3 303 | v_init = c_vals 304 | sig_vals = 0.05, 0.075, 0.1, 0.125 305 | 306 | fig, ax = plt.subplots() 307 | for σ in sig_vals: 308 | model= create_ez_model(σ=σ, γ=0.5) 309 | v, _ = newton_solver(lambda v: K(v, model), v_init) 310 | v_init = v 311 | ax.plot(x_vals, v, label=f"$\sigma = {σ:.4}$") 312 | ax.legend(frameon=False) 313 | plt.show() 314 | 315 | ``` 316 | 317 | ```{code-cell} ipython3 318 | 319 | ``` 320 | -------------------------------------------------------------------------------- /source_files/tuesday/exercises/simulation_exercises.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | format_version: 0.13 7 | jupytext_version: 1.15.0 8 | kernelspec: 9 | display_name: Python 3 (ipykernel) 10 | language: python 11 | name: python3 12 | --- 13 | 14 | # Simulation Exercises 15 | 16 | #### Prepared for the CBC Computational Workshop (May 2024) 17 | 18 | #### John Stachurski 19 | 20 | +++ 21 | 22 | This notebook contains some exercises related to simulation. 23 | 24 | ```{code-cell} ipython3 25 | import numpy as np 26 | import matplotlib.pyplot as plt 27 | import numba 28 | from numba import jit, prange 29 | ``` 30 | 31 | 32 | ## Exercise 33 | 34 | Compute an approximation to $ \pi $ using [Monte Carlo](https://en.wikipedia.org/wiki/Monte_Carlo_method). 35 | 36 | Your hints are as follows: 37 | 38 | - If $ U $ is a bivariate uniform random variable on the unit square $ (0, 1)^2 $, then the probability that $ U $ lies in a subset $ B $ of $ (0,1)^2 $ is equal to the area of $ B $. 39 | - If $ U_1,\ldots,U_n $ are IID copies of $ U $, then, as $ n $ gets large, the fraction that falls in $ B $, converges to the probability of landing in $ B $. 40 | - For a circle, $ area = \pi * radius^2 $. 41 | 42 | ```{code-cell} ipython3 43 | # Put your code here 44 | ``` 45 | 46 | ```{code-cell} ipython3 47 | for _ in range(12): 48 | print('solution below') 49 | ``` 50 | 51 | Consider the circle of diameter 1 embedded in the unit square. 52 | 53 | Let $ A $ be its area and let $ r=1/2 $ be its radius, so that $A = \pi r^2 $. 54 | 55 | If we can estimate $A$ then we can estimate $ \pi $ via $ \pi = A / r^2 = 4A$. 56 | 57 | We estimate $A$ by sampling bivariate uniforms and looking at the fraction that falls into the circle. 58 | 59 | ```{code-cell} ipython3 60 | n = 1_000_000 # sample size for Monte Carlo simulation 61 | 62 | def in_circle(u, v): 63 | """ 64 | Test whether (u, v) falls within the unit circle centred at (0.5,0.5) 65 | """ 66 | d = np.sqrt((u - 0.5)**2 + (v - 0.5)**2) 67 | return d < 0.5 68 | 69 | count = 0 70 | for i in range(n): 71 | 72 | # drawing random positions on the square 73 | u, v = np.random.uniform(0, 1), np.random.uniform(0, 1) 74 | 75 | # if it falls within the circle, add it to the count 76 | if in_circle(u, v): 77 | count += 1 78 | 79 | area_estimate = count / n 80 | 81 | print(area_estimate * 4) # dividing by radius**2 82 | ``` 83 | 84 | 85 | ## Exercise 86 | 87 | Accelerate the code from the previous exercise using Numba. Time the difference. 
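As with the previous exercise, you can draft your answer in the next cell before scrolling down to the solution.

```{code-cell} ipython3
# Put your code here
```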
88 | 89 | ```{code-cell} ipython3 90 | for _ in range(12): 91 | print('solution below') 92 | ``` 93 | 94 | ```{code-cell} ipython3 95 | def calculate_pi(n=1_000_000): 96 | count = 0 97 | for i in range(n): 98 | u, v = np.random.uniform(0, 1), np.random.uniform(0, 1) 99 | d = np.sqrt((u - 0.5)**2 + (v - 0.5)**2) 100 | if d < 0.5: 101 | count += 1 102 | area_estimate = count / n 103 | return area_estimate * 4 # dividing by radius**2 104 | ``` 105 | 106 | ```{code-cell} ipython3 107 | %time calculate_pi() 108 | ``` 109 | 110 | ```{code-cell} ipython3 111 | fast_calc_pi = jit(calculate_pi) 112 | ``` 113 | 114 | ```{code-cell} ipython3 115 | %time fast_calc_pi() 116 | ``` 117 | 118 | And again to omit compile time: 119 | 120 | ```{code-cell} ipython3 121 | %time fast_calc_pi() 122 | ``` 123 | 124 | ## Exercise 125 | 126 | Suppose that the volatility of returns on an asset can be in one of two regimes — high or low. 127 | 128 | The transition probabilities across states are as follows 129 | 130 | ![https://python-programming.quantecon.org/_static/lecture_specific/sci_libs/nfs_ex1.png](https://python-programming.quantecon.org/_static/lecture_specific/sci_libs/nfs_ex1.png) 131 | 132 | 133 | For example, let the period length be one day, and suppose the current state is high. 134 | 135 | We see from the graph that the state tomorrow will be 136 | 137 | - high with probability 0.8 138 | - low with probability 0.2 139 | 140 | 141 | Your task is to simulate a sequence of daily volatility states according to this rule. 142 | 143 | Set the length of the sequence to `n = 1_000_000` and start in the high state. 144 | 145 | Implement a pure Python version and a Numba version, and compare speeds. 146 | 147 | To test your code, evaluate the fraction of time that the chain spends in the low state. 148 | 149 | If your code is correct, it should be about 2/3. 150 | 151 | Hints: 152 | 153 | - Represent the low state as 0 and the high state as 1. 154 | - If you want to store integers in a NumPy array and then apply JIT compilation, use `x = np.empty(n, dtype=numba.int64)` or similar. 
155 | 156 | ```{code-cell} ipython3 157 | # Put your code here 158 | ``` 159 | 160 | ```{code-cell} ipython3 161 | for _ in range(12): 162 | print('solution below') 163 | ``` 164 | 165 | We let 166 | 167 | - 0 represent “low” 168 | - 1 represent “high” 169 | 170 | ```{code-cell} ipython3 171 | p, q = 0.1, 0.2 # Prob of leaving low and high state respectively 172 | ``` 173 | 174 | Here’s a pure Python version of the function 175 | 176 | ```{code-cell} ipython3 177 | def compute_series(n): 178 | x = np.empty(n, dtype=int) 179 | x[0] = 1 # Start in state 1 180 | U = np.random.uniform(0, 1, size=n) 181 | for t in range(1, n): 182 | current_x = x[t-1] 183 | if current_x == 0: 184 | x[t] = U[t] < p 185 | else: 186 | x[t] = U[t] > q 187 | return x 188 | ``` 189 | 190 | ```{code-cell} ipython3 191 | n = 1_000_000 192 | ``` 193 | 194 | ```{code-cell} ipython3 195 | %time x = compute_series(n) 196 | ``` 197 | 198 | ```{code-cell} ipython3 199 | print(np.mean(x == 0)) # Fraction of time x is in state 0 200 | ``` 201 | 202 | Now let's speed it up: 203 | 204 | 205 | ```{code-cell} ipython3 206 | @jit 207 | def fast_compute_series(n): 208 | x = np.empty(n, dtype=numba.int8) 209 | x[0] = 1 # Start in state 1 210 | U = np.random.uniform(0, 1, size=n) 211 | for t in range(1, n): 212 | current_x = x[t-1] 213 | if current_x == 0: 214 | x[t] = U[t] < p 215 | else: 216 | x[t] = U[t] > q 217 | return x 218 | ``` 219 | 220 | 221 | Run once to compile: 222 | 223 | ```{code-cell} ipython3 224 | %time fast_compute_series(n) 225 | ``` 226 | 227 | Now let's check the speed: 228 | 229 | ```{code-cell} ipython3 230 | %time fast_compute_series(n) 231 | ``` 232 | 233 | 234 | **Exercise** 235 | 236 | 237 | We consider using Monte Carlo to price a European call option. 238 | 239 | The price of the option obeys 240 | 241 | $$ 242 | P = \beta^n \mathbb E \max\{ S_n - K, 0 \} 243 | $$ 244 | 245 | where 246 | 247 | 1. $\beta$ is a discount factor, 248 | 2. $n$ is the expiry date, 249 | 2. $K$ is the strike price and 250 | 3. $\{S_t\}$ is the price of the underlying asset at each time $t$. 251 | 252 | Suppose that `n, β, K = 20, 0.99, 100`. 253 | 254 | Assume that the stock price obeys 255 | 256 | $$ 257 | \ln \frac{S_{t+1}}{S_t} = \mu + \sigma_t \xi_{t+1} 258 | $$ 259 | 260 | where 261 | 262 | $$ 263 | \sigma_t = \exp(h_t), 264 | \quad 265 | h_{t+1} = \rho h_t + \nu \eta_{t+1} 266 | $$ 267 | 268 | Here $\{\xi_t\}$ and $\{\eta_t\}$ are IID and standard normal. 269 | 270 | (This is a stochastic volatility model, where the volatility $\sigma_t$ varies over time.) 271 | 272 | Use the defaults `μ, ρ, ν, S0, h0 = 0.0001, 0.1, 0.001, 10, 0`. 273 | 274 | (Here `S0` is $S_0$ and `h0` is $h_0$.) 275 | 276 | By generating $M$ paths $s_0, \ldots, s_n$, compute the Monte Carlo estimate 277 | 278 | $$ 279 | \hat P_M 280 | := \beta^n \mathbb E \max\{ S_n - K, 0 \} 281 | \approx 282 | \beta^n \frac{1}{M} \sum_{m=1}^M \max \{S_n^m - K, 0 \} 283 | $$ 284 | 285 | 286 | If you can, use Numba to speed up loops. 287 | 288 | If possible, use Numba-based multithreading (`parallel=True`) to speed it even 289 | further. 290 | 291 | 292 | 293 | 294 | ```{code-cell} ipython3 295 | for _ in range(12): 296 | print('solution below') 297 | ``` 298 | 299 | 300 | **Solution** 301 | 302 | 303 | With $s_t := \ln S_t$, the price dynamics become 304 | 305 | $$ 306 | s_{t+1} = s_t + \mu + \exp(h_t) \xi_{t+1} 307 | $$ 308 | 309 | Using this fact, the solution can be written as follows. 
310 | 311 | 312 | ```{code-cell} ipython3 313 | from numpy.random import randn 314 | M = 10_000_000 315 | 316 | n, β, K = 20, 0.99, 100 317 | μ, ρ, ν, S0, h0 = 0.0001, 0.1, 0.001, 10, 0 318 | 319 | @jit(parallel=True) 320 | def compute_call_price_parallel(β=β, 321 | μ=μ, 322 | S0=S0, 323 | h0=h0, 324 | K=K, 325 | n=n, 326 | ρ=ρ, 327 | ν=ν, 328 | M=M): 329 | current_sum = 0.0 330 | # For each sample path 331 | for m in prange(M): 332 | s = np.log(S0) 333 | h = h0 334 | # Simulate forward in time 335 | for t in range(n): 336 | s = s + μ + np.exp(h) * randn() 337 | h = ρ * h + ν * randn() 338 | # And add the value max{S_n - K, 0} to current_sum 339 | current_sum += np.maximum(np.exp(s) - K, 0) 340 | 341 | return β**n * current_sum / M 342 | ``` 343 | 344 | Try swapping between `parallel=True` and `parallel=False` and noting the run time. 345 | 346 | If you are on a machine with many CPUs, the difference should be significant. 347 | 348 | 349 | 350 | 351 | 352 | -------------------------------------------------------------------------------- /2_tuesday/regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "fc8d66d4", 6 | "metadata": {}, 7 | "source": [ 8 | "# Linear Regression with Python\n", 9 | "\n", 10 | "----\n", 11 | "\n", 12 | "#### John Stachurski\n", 13 | "#### Prepared for the CBC Computational Workshop (May 2024)\n", 14 | "\n", 15 | "----" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "id": "6fc0b912", 21 | "metadata": {}, 22 | "source": [ 23 | "Let's have a very quick look at linear regression in Python.\n", 24 | "\n", 25 | "We'll also show how to download some data from [FRED](https://fred.stlouisfed.org/)." 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "id": "cff63b7c", 31 | "metadata": {}, 32 | "source": [ 33 | "Uncomment the next line if you don't have this library installed:" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 1, 39 | "id": "347c4645", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#!pip install pandas_datareader" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "id": "c728b0ec", 49 | "metadata": {}, 50 | "source": [ 51 | "Let's do some imports." 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 2, 57 | "id": "1fa7d6b3", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "import pandas_datareader.data as web\n", 62 | "import datetime\n", 63 | "import matplotlib.pyplot as plt\n", 64 | "import plotly.express as px\n", 65 | "import statsmodels.api as sm\n", 66 | "import statsmodels.formula.api as smf" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "ad82983b", 72 | "metadata": {}, 73 | "source": [ 74 | "We use the `datetime` module from the standard library to pick start dates and end dates." 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "76700b1c", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "start = datetime.datetime(1947, 1, 1)\n", 85 | "end = datetime.datetime(2019, 12, 1)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "id": "e95b566c", 91 | "metadata": {}, 92 | "source": [ 93 | "Now let's read in data on GDP and unemployment from FRED." 
94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "f115ccd6", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "series = 'GDP', 'UNRATE' \n", 104 | "source = 'fred'\n", 105 | "data = web.DataReader(series, source, start, end)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "id": "bd5e104d", 111 | "metadata": {}, 112 | "source": [ 113 | "Data is read in as a `pandas` dataframe." 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "id": "ebb85714", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "type(data)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "id": "bd2f9d8f", 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "data = data.dropna()\n", 134 | "data" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "id": "db887181", 140 | "metadata": {}, 141 | "source": [ 142 | "We'll convert both series into rates of change." 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "cb097768", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "data = data.pct_change() * 100\n", 153 | "data = data.reset_index().dropna()\n", 154 | "data" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "id": "6064ae79", 160 | "metadata": {}, 161 | "source": [ 162 | "Let's have a look at our data.\n", 163 | "\n", 164 | "Notice in the code below that Matplotlib plays well with pandas." 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "id": "85eb362e", 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "fig, ax = plt.subplots()\n", 175 | "ax.scatter(x='UNRATE', y='GDP', data=data, color='k', alpha=0.5)\n", 176 | "ax.set_xlabel('% change in unemployment rate')\n", 177 | "ax.set_ylabel('% change in GDP')\n", 178 | "plt.show()" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "id": "636c40e6", 184 | "metadata": {}, 185 | "source": [ 186 | "If you want an interactive graph, you can use Plotly instead:" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "id": "96aaa58e", 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "fig = px.scatter(data, x='UNRATE', y='GDP')\n", 197 | "fig.update_layout(\n", 198 | " showlegend=False,\n", 199 | " autosize=False,\n", 200 | " width=600,\n", 201 | " height=400,\n", 202 | ")\n", 203 | "\n", 204 | "fig.show()" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "id": "32765158", 210 | "metadata": {}, 211 | "source": [ 212 | "Let's fit a regression line, which can be used to measure [Okun's law](https://en.wikipedia.org/wiki/Okun%27s_law).\n", 213 | "\n", 214 | "To do so we'll use [Statsmodels](https://www.statsmodels.org/stable/index.html)." 
215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "id": "ef56a9ea", 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "model = smf.ols(formula='GDP ~ UNRATE', data=data)\n", 225 | "ols = model.fit()" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "id": "2dde7d7c", 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "ols.summary()" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "id": "dfa195a4", 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "ols.params" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "id": "1a9b70f0", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "X = data['UNRATE']\n", 256 | "fig, ax = plt.subplots()\n", 257 | "plt.scatter(x='UNRATE', y='GDP', data=data, color='k', alpha=0.5)\n", 258 | "plt.plot(X, ols.fittedvalues, label='OLS')\n", 259 | "ax.set_xlabel('% change in unemployment rate')\n", 260 | "ax.set_ylabel('% change in GDP')\n", 261 | "plt.legend()\n", 262 | "plt.show()" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "id": "e2dd974e", 268 | "metadata": {}, 269 | "source": [ 270 | "Next let's try using least absolute deviations, which means that we minimize\n", 271 | "\n", 272 | "$$\n", 273 | "\\ell(\\alpha, \\beta) = \\sum_{i=1}^n |y_i - (\\alpha x_i + \\beta_i)|\n", 274 | "$$\n", 275 | "\n", 276 | "over parameters $\\alpha, \\beta$.\n", 277 | "\n", 278 | "This is a special case of quantile regression when the quantile is the median (0.5)." 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "id": "f733ea4c", 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "mod = smf.quantreg(formula=\"GDP ~ UNRATE\", data=data)\n", 289 | "lad = mod.fit(q=0.5) # LAD model is a special case of quantile regression when q = 0.5" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "id": "8ed98143", 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "lad.summary()" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": null, 305 | "id": "89d6e607", 306 | "metadata": {}, 307 | "outputs": [], 308 | "source": [ 309 | "lad.params" 310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "id": "3459c349", 315 | "metadata": {}, 316 | "source": [ 317 | "Let's compare the LAD regression line to the least squares regression line." 
318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "id": "2a22deaa", 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [ 327 | "fig, ax = plt.subplots()\n", 328 | "plt.scatter(x='UNRATE', y='GDP', data=data, color='k', alpha=0.5)\n", 329 | "plt.plot(X, ols.fittedvalues, label='OLS')\n", 330 | "plt.plot(X, lad.fittedvalues, label='LAD')\n", 331 | "ax.set_xlabel('% change in unemployment rate')\n", 332 | "ax.set_ylabel('% change in GDP')\n", 333 | "plt.legend()\n", 334 | "plt.show()" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "id": "1acce101", 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [] 344 | } 345 | ], 346 | "metadata": { 347 | "kernelspec": { 348 | "display_name": "Python 3 (ipykernel)", 349 | "language": "python", 350 | "name": "python3" 351 | }, 352 | "language_info": { 353 | "codemirror_mode": { 354 | "name": "ipython", 355 | "version": 3 356 | }, 357 | "file_extension": ".py", 358 | "mimetype": "text/x-python", 359 | "name": "python", 360 | "nbconvert_exporter": "python", 361 | "pygments_lexer": "ipython3", 362 | "version": "3.11.7" 363 | } 364 | }, 365 | "nbformat": 4, 366 | "nbformat_minor": 5 367 | } 368 | -------------------------------------------------------------------------------- /source_files/tuesday/inventory_dynamics.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | format_version: 0.13 7 | jupytext_version: 1.16.1 8 | kernelspec: 9 | display_name: Python 3 (ipykernel) 10 | language: python 11 | name: python3 12 | --- 13 | 14 | # Inventory Dynamics 15 | 16 | ------ 17 | 18 | #### John Stachurski 19 | #### Prepared for the CBC Computational Workshop May 2024 20 | 21 | ----- 22 | 23 | ## Overview 24 | 25 | This lecture explores the inventory dynamics of a firm using so-called s-S inventory control. 26 | 27 | Loosely speaking, this means that the firm 28 | 29 | - waits until inventory falls below some value $ s $ 30 | - and then restocks with a bulk order of $ S $ units (or, in some models, restocks up to level $ S $). 31 | 32 | This lecture will help use become familiar with NumPy. 33 | 34 | (Later we will try similar operations with JAX) 35 | 36 | We will use the following imports: 37 | 38 | ```{code-cell} ipython3 39 | import matplotlib.pyplot as plt 40 | import numpy as np 41 | from collections import namedtuple 42 | ``` 43 | 44 | ## Sample paths 45 | 46 | Consider a firm with inventory $ X_t $. 47 | 48 | The firm waits until $ X_t \leq s $ and then restocks up to $ S $ units. 49 | 50 | It faces stochastic demand $ \{ D_t \} $, which we assume is IID across time and 51 | firms. 52 | 53 | With notation $ a^+ := \max\{a, 0\} $, inventory dynamics can be written 54 | as 55 | 56 | $$ 57 | X_{t+1} = 58 | \begin{cases} 59 | ( S - D_{t+1})^+ & \quad \text{if } X_t \leq s \\ 60 | ( X_t - D_{t+1} )^+ & \quad \text{if } X_t > s 61 | \end{cases} 62 | $$ 63 | 64 | In what follows, we will assume that each $ D_t $ is lognormal, so that 65 | 66 | $$ 67 | D_t = \exp(\mu + \sigma Z_t) 68 | $$ 69 | 70 | where $ \mu $ and $ \sigma $ are parameters and $ \{Z_t\} $ is IID 71 | and standard normal. 72 | 73 | Here’s a `namedtuple` that stores parameters. 
74 | 75 | ```{code-cell} ipython3 76 | Parameters = namedtuple('Parameters', ['s', 'S', 'μ', 'σ']) 77 | ``` 78 | 79 | Here's a function that updates from $X_t = x$ to $X_{t+1}$ 80 | 81 | ```{code-cell} ipython3 82 | def update(params, x): 83 | """ 84 | Update the state from t to t+1 given current state x. 85 | 86 | """ 87 | s, S, μ, σ = params 88 | Z = np.random.randn() 89 | D = np.exp(μ + σ * Z) 90 | return max(S - D, 0) if x <= s else max(x - D, 0) 91 | ``` 92 | 93 | Here's a function that generates a time series. 94 | 95 | ```{code-cell} ipython3 96 | def sim_inventory_path(x_init, params, sim_length): 97 | """ 98 | Simulate a time series (X_t) with X_0 = x_init. 99 | 100 | """ 101 | X = np.empty(sim_length) 102 | X[0] = x_init 103 | for t in range(sim_length-1): 104 | X[t+1] = update(params, X[t]) 105 | return X 106 | ``` 107 | 108 | Let's test it. 109 | 110 | ```{code-cell} ipython3 111 | params = Parameters(s=10, S=100, μ=1.0, σ=0.5) 112 | s, S = params.s, params.S 113 | sim_length = 100 114 | x_init = 50 115 | 116 | X = sim_inventory_path(x_init, params, sim_length) 117 | 118 | fig, ax = plt.subplots() 119 | bbox = (0., 1.02, 1., .102) 120 | legend_args = {'ncol': 3, 121 | 'bbox_to_anchor': bbox, 122 | 'loc': 3, 123 | 'mode': 'expand'} 124 | 125 | ax.plot(X, label="inventory") 126 | ax.plot(np.full(sim_length, s), 'k--', label="$s$") 127 | ax.plot(np.full(sim_length, S), 'k-', label="$S$") 128 | ax.set_ylim(0, S+10) 129 | ax.set_xlabel("time") 130 | ax.legend(**legend_args) 131 | 132 | plt.show() 133 | ``` 134 | 135 | ## Cross-sectional distributions 136 | 137 | Now let’s look at the marginal distribution $ \psi_T $ of $ X_T $ for some fixed $ T $. 138 | 139 | The probability distribution $ \psi_T $ is the time $ T $ distribution of firm 140 | inventory levels implied by the model. 141 | 142 | We will approximate this distribution by 143 | 144 | 1. fixing $ n $ to be some large number, indicating the number of firms in the 145 | simulation, 146 | 1. fixing $ T $, the time period we are interested in, 147 | 1. generating $ n $ independent draws from some fixed distribution $ \psi_0 $ that gives the 148 | initial cross-section of inventories for the $ n $ firms, and 149 | 1. shifting this distribution forward in time $ T $ periods, updating each firm 150 | $ T $ times via the dynamics described above (independent of other firms). 151 | 152 | 153 | We will then visualize $ \psi_T $ by histogramming the cross-section. 154 | 155 | We will use the following code to update the cross-section of firms by one period. 156 | 157 | ```{code-cell} ipython3 158 | def update_cross_section(params, X): 159 | """ 160 | Update by one period a cross-section of firms with inventory levels 161 | given by array X. (Thus, X[i] is the inventory of the i-th firm.) 162 | 163 | """ 164 | s, S, μ, σ = params 165 | num_firms = len(X) 166 | Z = np.random.randn(num_firms) 167 | D = np.exp(μ + σ * Z) 168 | X_new = np.where(X <= s, np.maximum(S - D, 0), np.maximum(X - D, 0)) 169 | return X_new 170 | ``` 171 | 172 | ### Shifting the cross-section 173 | 174 | Now we provide code to compute the cross-sectional distribution $ \psi_T $ given some 175 | initial distribution $ \psi_0 $ and a positive integer $ T $. 176 | 177 | In the code below, the initial distribution $ \psi_0 $ takes all firms to have 178 | initial inventory `x_init`. 179 | 180 | ```{code-cell} ipython3 181 | def shift_cross_section(params, X, T): 182 | """ 183 | Shift the cross-sectional distribution X = X[i] forward by T periods. 
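    Returns a new array of the same length as X giving each firm's
    inventory level after T updates.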
184 | 185 | """ 186 | for i in range(T): 187 | X = update_cross_section(params, X) 188 | return X 189 | ``` 190 | 191 | We’ll use the following specification 192 | 193 | ```{code-cell} ipython3 194 | x_init = 50 # All firms start at this value 195 | num_firms = 100_000 196 | X = np.full(num_firms, x_init) 197 | T = 500 198 | ``` 199 | 200 | ```{code-cell} ipython3 201 | X = shift_cross_section(params, X, T) 202 | ``` 203 | 204 | Here’s a histogram of inventory levels at time $ T $. 205 | 206 | ```{code-cell} ipython3 207 | fig, ax = plt.subplots() 208 | ax.hist(X, bins=50, 209 | density=True, 210 | histtype='step', 211 | label=f'cross-section when $t = {T}$') 212 | ax.set_xlabel('inventory') 213 | ax.set_ylabel('probability') 214 | ax.legend() 215 | plt.show() 216 | ``` 217 | 218 | ## Distribution dynamics 219 | 220 | Next let’s take a look at how the distribution sequence evolves over time. 221 | 222 | Here is code that repeatedly shifts the cross-section forward while 223 | recording the cross-section at the dates in `sample_dates`. 224 | 225 | All firms start at the same level `x_init`. 226 | 227 | ```{code-cell} ipython3 228 | def shift_forward_and_sample(x_init, params, sample_dates, 229 | num_firms=50_000): 230 | 231 | X = np.full(num_firms, x_init) 232 | X_samples = [] 233 | sim_length = sample_dates[-1] + 1 234 | # Use for loop to update X and collect samples 235 | for i in range(sim_length): 236 | if i in sample_dates: 237 | X_samples.append(X) 238 | X = update_cross_section(params, X) 239 | 240 | return X_samples 241 | ``` 242 | 243 | Let’s test it 244 | 245 | ```{code-cell} ipython3 246 | x_init = 50 247 | num_firms = 10_000 248 | sample_dates = 10, 50, 250, 500 249 | 250 | X_samples = shift_forward_and_sample(x_init, params, sample_dates) 251 | ``` 252 | 253 | Let’s plot the output. 254 | 255 | ```{code-cell} ipython3 256 | fig, ax = plt.subplots() 257 | 258 | for i, date in enumerate(sample_dates): 259 | ax.hist(X_samples, bins=50, 260 | density=True, 261 | histtype='step', 262 | label=f'cross-section when $t = {date}$') 263 | 264 | ax.set_xlabel('inventory') 265 | ax.set_ylabel('probability') 266 | ax.legend() 267 | plt.show() 268 | ``` 269 | 270 | This model for inventory dynamics is asymptotically stationary, with a unique 271 | stationary distribution. 272 | 273 | In particular, the sequence of marginal distributions $ \{\psi_t\} $ 274 | converges to a unique limiting distribution that does not depend on 275 | initial conditions. 276 | 277 | That's why, by $ t=500 $, the distributions are barely changing. 278 | 279 | If you test a few different initial conditions, you will see that they do not affect long-run outcomes. 280 | 281 | +++ 282 | 283 | ## Exercise: Restock frequency 284 | 285 | Let’s study the probability that firms need to restock at least twice over periods $1, \ldots, 50$ when $ X_0 = 70 $. 286 | 287 | We will do this by Monte Carlo: 288 | 289 | * Set the number of firms to `1_000_000`. 290 | * Calculate the fraction of firms that need to order twice or more in the first 50 periods. 291 | 292 | This proportion approximates the probability of the event when the sample size 293 | is large. 
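One idiom that may help (a toy illustration, separate from the model): `np.where` can increment a per-firm counter only where a condition holds.

```{code-cell} ipython3
z = np.array([5.0, 20.0, 8.0])                     # pretend inventory levels
counter = np.zeros(len(z))
counter = np.where(z <= 10, counter + 1, counter)  # add 1 where z <= 10
counter
```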
294 | 295 | ```{code-cell} ipython3 296 | # Put your code here 297 | ``` 298 | 299 | ```{code-cell} ipython3 300 | for i in range(18): 301 | print("Solution below!") 302 | ``` 303 | 304 | ```{code-cell} ipython3 305 | def compute_freq(params, 306 | x_init=70, 307 | sim_length=50, 308 | num_firms=1_000_000): 309 | s = params.s 310 | # Prepare initial arrays 311 | X = np.full(num_firms, x_init) 312 | # Restock counter starts at zero 313 | counter = np.zeros(num_firms) 314 | 315 | for i in range(sim_length): 316 | X = update_cross_section(params, X) 317 | counter = np.where(X <= s, counter + 1, counter) 318 | return np.mean(counter > 1, axis=0) 319 | ``` 320 | 321 | ```{code-cell} ipython3 322 | freq = compute_freq(params) 323 | print(f"Frequency of at least two stock outs = {freq}") 324 | ``` 325 | 326 | ```{code-cell} ipython3 327 | 328 | ``` 329 | -------------------------------------------------------------------------------- /source_files/wednesday/job_search.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | format_version: 0.13 7 | jupytext_version: 1.16.1 8 | kernelspec: 9 | display_name: Python 3 (ipykernel) 10 | language: python 11 | name: python3 12 | --- 13 | 14 | # Job Search 15 | 16 | ---- 17 | #### John Stachurski 18 | #### Prepared for the CBC Computational Workshop (2024) 19 | 20 | ---- 21 | 22 | Uncomment if necessary 23 | 24 | ```{code-cell} ipython3 25 | #!pip install quantecon 26 | ``` 27 | 28 | In this lecture we study a basic infinite-horizon job search problem with Markov wage 29 | draws 30 | 31 | The exercise at the end asks you to add recursive preferences and compare 32 | the result. 33 | 34 | We use the following imports. 35 | 36 | ```{code-cell} ipython3 37 | import matplotlib.pyplot as plt 38 | import quantecon as qe 39 | import jax 40 | import jax.numpy as jnp 41 | from collections import namedtuple 42 | 43 | jax.config.update("jax_enable_x64", True) 44 | ``` 45 | 46 | Let's check our GPU status: 47 | 48 | ```{code-cell} ipython3 49 | !nvidia-smi 50 | ``` 51 | 52 | ## Model 53 | 54 | We study an elementary model where 55 | 56 | * jobs are permanent 57 | * unemployed workers receive current compensation $c$ 58 | * the wage offer distribution $\{W_t\}$ is Markovian 59 | * the horizon is infinite 60 | * an unemployment agent discounts the future via discount factor $\beta \in (0,1)$ 61 | 62 | ### Set up 63 | 64 | The wage offer process obeys 65 | 66 | $$ 67 | W_{t+1} = \rho W_t + \nu Z_{t+1} 68 | $$ 69 | 70 | where $(Z_t)_{t \geq 0}$ is IID and standard normal. 71 | 72 | We discretize this wage process using Tauchen's method to produce a stochastic matrix $P$ 73 | 74 | ### Rewards 75 | 76 | Since jobs are permanent, the return to accepting wage offer $w$ today is 77 | 78 | $$ 79 | w + \beta w + \beta^2 w + \frac{w}{1-\beta} 80 | $$ 81 | 82 | The Bellman equation is 83 | 84 | $$ 85 | v(w) = \max 86 | \left\{ 87 | \frac{w}{1-\beta}, c + \beta \sum_{w'} v(w') P(w, w') 88 | \right\} 89 | $$ 90 | 91 | We solve this model using value function iteration. 92 | 93 | +++ 94 | 95 | ## Code 96 | 97 | Let's set up a namedtuple to store information needed to solve the model. 98 | 99 | ```{code-cell} ipython3 100 | Model = namedtuple('Model', ('n', 'w_vals', 'P', 'β', 'c')) 101 | ``` 102 | 103 | The function below holds default values and populates the namedtuple. 
104 | 105 | ```{code-cell} ipython3 106 | def create_js_model( 107 | n=500, # wage grid size 108 | ρ=0.9, # wage persistence 109 | ν=0.2, # wage volatility 110 | β=0.99, # discount factor 111 | c=1.0, # unemployment compensation 112 | ): 113 | "Creates an instance of the job search model with Markov wages." 114 | mc = qe.tauchen(n, ρ, ν) 115 | w_vals, P = jnp.exp(mc.state_values), jnp.array(mc.P) 116 | return Model(n, w_vals, P, β, c) 117 | ``` 118 | 119 | Let's test it: 120 | 121 | ```{code-cell} ipython3 122 | model = create_js_model(β=0.98) 123 | ``` 124 | 125 | ```{code-cell} ipython3 126 | model.c 127 | ``` 128 | 129 | ```{code-cell} ipython3 130 | model.β 131 | ``` 132 | 133 | ```{code-cell} ipython3 134 | model.w_vals.mean() 135 | ``` 136 | 137 | Here's the Bellman operator. 138 | 139 | ```{code-cell} ipython3 140 | @jax.jit 141 | def T(v, model): 142 | """ 143 | The Bellman operator Tv = max{e, c + β E v} with 144 | 145 | e(w) = w / (1-β) and (Ev)(w) = E_w[ v(W')] 146 | 147 | """ 148 | n, w_vals, P, β, c = model 149 | h = c + β * P @ v 150 | e = w_vals / (1 - β) 151 | 152 | return jnp.maximum(e, h) 153 | ``` 154 | 155 | Question: Is this a pure function? 156 | 157 | +++ 158 | 159 | The next function computes the optimal policy under the assumption that $v$ is 160 | the value function. 161 | 162 | The policy takes the form 163 | 164 | $$ 165 | \sigma(w) = \mathbf 1 166 | \left\{ 167 | \frac{w}{1-\beta} \geq c + \beta \sum_{w'} v(w') P(w, w') 168 | \right\} 169 | $$ 170 | 171 | Here $\mathbf 1$ is an indicator function. 172 | 173 | * $\sigma(w) = 1$ means stop 174 | * $\sigma(w) = 0$ means continue. 175 | 176 | ```{code-cell} ipython3 177 | @jax.jit 178 | def get_greedy(v, model): 179 | "Get a v-greedy policy." 180 | n, w_vals, P, β, c = model 181 | e = w_vals / (1 - β) 182 | h = c + β * P @ v 183 | σ = jnp.where(e >= h, 1, 0) 184 | return σ 185 | ``` 186 | 187 | Here's a routine for value function iteration. 188 | 189 | ```{code-cell} ipython3 190 | def vfi(model, max_iter=10_000, tol=1e-4): 191 | "Solve the infinite-horizon Markov job search model by VFI." 192 | print("Starting VFI iteration.") 193 | v = jnp.zeros_like(model.w_vals) # Initial guess 194 | i = 0 195 | error = tol + 1 196 | 197 | while error > tol and i < max_iter: 198 | new_v = T(v, model) 199 | error = jnp.max(jnp.abs(new_v - v)) 200 | i += 1 201 | v = new_v 202 | 203 | v_star = v 204 | σ_star = get_greedy(v_star, model) 205 | return v_star, σ_star 206 | ``` 207 | 208 | Question: Is this a pure function? 209 | 210 | +++ 211 | 212 | ## Computing the solution 213 | 214 | Let's set up and solve the model. 215 | 216 | ```{code-cell} ipython3 217 | model = create_js_model() 218 | n, w_vals, P, β, c = model 219 | 220 | v_star, σ_star = vfi(model) 221 | ``` 222 | 223 | Here's the optimal policy: 224 | 225 | ```{code-cell} ipython3 226 | fig, ax = plt.subplots() 227 | ax.plot(σ_star) 228 | ax.set_xlabel("wage values") 229 | ax.set_ylabel("optimal choice (stop=1)") 230 | plt.show() 231 | ``` 232 | 233 | We compute the reservation wage as the first $w$ such that $\sigma(w)=1$. 
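In symbols, writing $\sigma^*$ for the optimal policy computed above, the reservation wage is

$$
\bar w := \min \{ w : \sigma^*(w) = 1 \}
$$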
234 | 235 | ```{code-cell} ipython3 236 | stop_indices = jnp.where(σ_star == 1) 237 | stop_indices 238 | ``` 239 | 240 | ```{code-cell} ipython3 241 | res_wage_index = min(stop_indices[0]) 242 | ``` 243 | 244 | ```{code-cell} ipython3 245 | res_wage = w_vals[res_wage_index] 246 | ``` 247 | 248 | ```{code-cell} ipython3 249 | fig, ax = plt.subplots() 250 | ax.plot(w_vals, v_star, alpha=0.8, label="value function") 251 | ax.vlines((res_wage,), 150, 400, 'k', ls='--', label="reservation wage") 252 | ax.legend(frameon=False, fontsize=12, loc="lower right") 253 | ax.set_xlabel("$w$", fontsize=12) 254 | plt.show() 255 | ``` 256 | 257 | ## Exercise 258 | 259 | In the setting above, the agent is risk-neutral vis-a-vis future utility risk. 260 | 261 | Now solve the same problem but this time assuming that the agent has risk-sensitive 262 | preferences, which are a type of nonlinear recursive preferences. 263 | 264 | The Bellman equation becomes 265 | 266 | $$ 267 | v(w) = \max 268 | \left\{ 269 | \frac{w}{1-\beta}, 270 | c + \frac{\beta}{\theta} 271 | \ln \left[ 272 | \sum_{w'} \exp(\theta v(w')) P(w, w') 273 | \right] 274 | \right\} 275 | $$ 276 | 277 | 278 | When $\theta < 0$ the agent is risk averse. 279 | 280 | Solve the model when $\theta = -0.1$ and compare your result to the risk neutral 281 | case. 282 | 283 | Try to interpret your result. 284 | 285 | You can start with the following code: 286 | 287 | ```{code-cell} ipython3 288 | Model = namedtuple('Model', ('n', 'w_vals', 'P', 'β', 'c', 'θ')) 289 | ``` 290 | 291 | ```{code-cell} ipython3 292 | def create_risk_sensitive_js_model( 293 | n=500, # wage grid size 294 | ρ=0.9, # wage persistence 295 | ν=0.2, # wage volatility 296 | β=0.99, # discount factor 297 | c=1.0, # unemployment compensation 298 | θ=-0.1 # risk parameter 299 | ): 300 | "Creates an instance of the job search model with Markov wages." 301 | mc = qe.tauchen(n, ρ, ν) 302 | w_vals, P = jnp.exp(mc.state_values), mc.P 303 | P = jnp.array(P) 304 | return Model(n, w_vals, P, β, c, θ) 305 | ``` 306 | 307 | Now you need to modify `T` and `get_greedy` and then run value function iteration again. 308 | 309 | ```{code-cell} ipython3 310 | # Put your code here 311 | ``` 312 | 313 | ```{code-cell} ipython3 314 | for i in range(20): 315 | print("Solution below!") 316 | ``` 317 | 318 | ```{code-cell} ipython3 319 | @jax.jit 320 | def T_rs(v, model): 321 | """ 322 | The Bellman operator Tv = max{e, c + β R v} with 323 | 324 | e(w) = w / (1-β) and 325 | 326 | (Rv)(w) = (1/θ) ln{E_w[ exp(θ v(W'))]} 327 | 328 | """ 329 | n, w_vals, P, β, c, θ = model 330 | h = c + (β / θ) * jnp.log(P @ (jnp.exp(θ * v))) 331 | e = w_vals / (1 - β) 332 | 333 | return jnp.maximum(e, h) 334 | 335 | 336 | @jax.jit 337 | def get_greedy_rs(v, model): 338 | " Get a v-greedy policy." 339 | n, w_vals, P, β, c, θ = model 340 | e = w_vals / (1 - β) 341 | h = c + (β / θ) * jnp.log(P @ (jnp.exp(θ * v))) 342 | σ = jnp.where(e >= h, 1, 0) 343 | return σ 344 | 345 | 346 | 347 | def vfi(model, max_iter=10_000, tol=1e-4): 348 | "Solve the infinite-horizon Markov job search model by VFI." 
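    # Same routine as before, except that it iterates with T_rs and uses get_greedy_rs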
349 | print("Starting VFI iteration.") 350 | v = jnp.zeros_like(model.w_vals) # Initial guess 351 | i = 0 352 | error = tol + 1 353 | 354 | while error > tol and i < max_iter: 355 | new_v = T_rs(v, model) 356 | error = jnp.max(jnp.abs(new_v - v)) 357 | i += 1 358 | v = new_v 359 | 360 | v_star = v 361 | σ_star = get_greedy_rs(v_star, model) 362 | return v_star, σ_star 363 | 364 | 365 | 366 | model_rs = create_risk_sensitive_js_model() 367 | n, w_vals, P, β, c, θ = model_rs 368 | 369 | v_star_rs, σ_star_rs = vfi(model_rs) 370 | ``` 371 | 372 | Let's plot the results together with the original risk neutral case and see what we get. 373 | 374 | ```{code-cell} ipython3 375 | stop_indices = jnp.where(σ_star_rs == 1) 376 | res_wage_index = min(stop_indices[0]) 377 | res_wage_rs = w_vals[res_wage_index] 378 | ``` 379 | 380 | ```{code-cell} ipython3 381 | fig, ax = plt.subplots() 382 | ax.plot(w_vals, v_star, alpha=0.8, label="risk neutral $v$") 383 | ax.plot(w_vals, v_star_rs, alpha=0.8, label="risk sensitive $v$") 384 | ax.vlines((res_wage,), 100, 400, ls='--', color='darkblue', 385 | alpha=0.5, label=r"risk neutral $\bar w$") 386 | ax.vlines((res_wage_rs,), 100, 400, ls='--', color='orange', 387 | alpha=0.5, label=r"risk sensitive $\bar w$") 388 | ax.legend(frameon=False, fontsize=12, loc="lower right") 389 | ax.set_xlabel("$w$", fontsize=12) 390 | plt.show() 391 | ``` 392 | 393 | ```{code-cell} ipython3 394 | 395 | ``` 396 | -------------------------------------------------------------------------------- /source_files/tuesday/lorenz_gini.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | format_version: 0.13 7 | jupytext_version: 1.16.1 8 | kernelspec: 9 | display_name: Python 3 (ipykernel) 10 | language: python 11 | name: python3 12 | --- 13 | 14 | # Lorenz Curves and Gini Coefficients 15 | 16 | ----- 17 | 18 | #### John Stachurski 19 | 20 | #### Prepared for the CBC Computational Workshop (May 2024) 21 | 22 | ----- 23 | 24 | This notebook contains some exercises related to the Lorenz curve and the Gini 25 | coefficient, which are often used to study inequality. 26 | 27 | Our task will be to compute and examine these curves and values. 28 | 29 | Uncomment the following if necessary 30 | 31 | ```{code-cell} ipython3 32 | #!pip install quantecon 33 | ``` 34 | 35 | We use the following imports. 36 | 37 | ```{code-cell} ipython3 38 | import numba 39 | import numpy as np 40 | import matplotlib.pyplot as plt 41 | import quantecon as qe 42 | ``` 43 | 44 | ## The Lorenz curve 45 | 46 | Let's start by examining the Lorenz curve. 47 | 48 | ### Definition 49 | 50 | Let $w_1, \ldots, w_n$ be a sample of observations of wealth (or income, or consumption, or firm sizes, etc.) in a population. 51 | 52 | Suppose the sample has been sorted from smallest to largest. 53 | 54 | The Lorenz curve takes this sample and produces a curve $L$. 55 | 56 | To create it we first generate data points $(x_i, y_i)_{i=0}^n$ according to 57 | 58 | \begin{equation*} 59 | x_0 = y_0 = 0 60 | \qquad \text{and, for $i \geq 1$,} \quad 61 | x_i = \frac{i}{n}, 62 | \qquad 63 | y_i = 64 | \frac{\sum_{j \leq i} w_j}{\sum_{j \leq n} w_j} 65 | \end{equation*} 66 | 67 | Now the Lorenz curve $L$ is formed from these data points using interpolation. 68 | 69 | The meaning of the statement $y = L(x)$ is that the lowest $(100 \times x)$\% of 70 | people have $(100 \times y)$\% of all wealth. 
71 | 72 | * if $x=0.5$ and $y=0.1$, then the bottom 50% of the population 73 | owns 10% of the wealth. 74 | 75 | +++ 76 | 77 | ### Using QuantEcon's routine 78 | 79 | Let's look at an example. 80 | 81 | First we generate $n=2000$ draws from a lognormal distribution and treat these draws as our population. 82 | 83 | ```{code-cell} ipython3 84 | n = 2000 85 | sample = np.exp(np.random.randn(n)) # Lognormal sample 86 | ``` 87 | 88 | We then generate the Lorenz curve using a routine from `quantecon`. 89 | 90 | ```{code-cell} ipython3 91 | x, y = qe.lorenz_curve(sample) # QuantEcon routine (no need to sort) 92 | ``` 93 | 94 | Now let's plot. 95 | 96 | The straight line ($x=L(x)$ for all $x$) corresponds to perfect equality. 97 | 98 | The lognormal draws produce a less equal distribution. 99 | 100 | ```{code-cell} ipython3 101 | fig, ax = plt.subplots() 102 | ax.plot(x, y, label=f'lognormal sample', lw=2) 103 | ax.plot(x, x, label='equality', lw=2) 104 | ax.legend(fontsize=12) 105 | ax.set_ylim((0, 1)) 106 | ax.set_xlim((0, 1)) 107 | j = 1600 # dashed lines for j-th element 108 | ax.vlines(x[j], [0.0], y[j], alpha=0.5, colors='k', ls='--') 109 | ax.hlines(y[j], [0], x[j], alpha=0.5, colors='k', ls='--') 110 | plt.show() 111 | ``` 112 | 113 | In this sample, the dashed lines show that the bottom 80\% of 114 | households own just over 40\% of total wealth. 115 | 116 | +++ 117 | 118 | ### Exercise: write a NumPy version 119 | 120 | Using the definition of the Lorenz curve given above and NumPy, try to write 121 | your own version of `qe.lorenz_curve`. 122 | 123 | See if you can write a version without any explicity loops. 124 | 125 | Try to replicate the figure above, using the same lognormal data set. 126 | 127 | ```{code-cell} ipython3 128 | # Put your code here 129 | ``` 130 | 131 | ```{code-cell} ipython3 132 | for i in range(16): 133 | print("Solution below!") 134 | ``` 135 | 136 | Here's one solution: 137 | 138 | ```{code-cell} ipython3 139 | def lorenz_curve(w): 140 | n = len(w) 141 | w = np.sort(w) 142 | x = np.arange(n + 1) / n 143 | s = np.concatenate((np.zeros(1), np.cumsum(w))) 144 | y = s / s[n] 145 | return x, y 146 | ``` 147 | 148 | Let's test it: 149 | 150 | ```{code-cell} ipython3 151 | x, y = lorenz_curve(sample) # Our routine 152 | 153 | fig, ax = plt.subplots() 154 | ax.plot(x, y, label=f'lognormal sample', lw=2) 155 | ax.plot(x, x, label='equality', lw=2) 156 | ax.legend(fontsize=12) 157 | ax.set_ylim((0, 1)) 158 | ax.set_xlim((0, 1)) 159 | j = 1600 # dashed lines for j-th element 160 | ax.vlines(x[j], [0.0], y[j], alpha=0.5, colors='k', ls='--') 161 | ax.hlines(y[j], [0], x[j], alpha=0.5, colors='k', ls='--') 162 | plt.show() 163 | ``` 164 | 165 | ### A Numba version 166 | 167 | If you prefer, you can use a for loop accelerated by Numba to compute the 168 | Lorenz curve: 169 | 170 | ```{code-cell} ipython3 171 | @numba.jit 172 | def lorenz_curve(w): 173 | n = len(w) 174 | w = np.sort(w) 175 | s = np.zeros(n + 1) 176 | s[1:] = np.cumsum(w) # s[i] = sum_{j <= i} w_j 177 | x = np.zeros(n + 1) 178 | y = np.zeros(n + 1) 179 | for i in range(1, n + 1): 180 | x[i] = i / n 181 | y[i] = s[i] / s[n] 182 | return x, y 183 | 184 | 185 | x, y = lorenz_curve(sample) # Our routine 186 | 187 | fig, ax = plt.subplots() 188 | ax.plot(x, y, label=f'lognormal sample', lw=2) 189 | ax.plot(x, x, label='equality', lw=2) 190 | ax.legend(fontsize=12) 191 | ax.set_ylim((0, 1)) 192 | ax.set_xlim((0, 1)) 193 | j = 1600 # dashed lines for j-th element 194 | ax.vlines(x[j], [0.0], y[j], alpha=0.5, 
colors='k', ls='--') 195 | ax.hlines(y[j], [0], x[j], alpha=0.5, colors='k', ls='--') 196 | plt.show() 197 | ``` 198 | 199 | ## The Gini coefficient 200 | 201 | Now let's examine the Gini coefficient. 202 | 203 | 204 | ### Definition 205 | 206 | 207 | Continuing to assume that $w_1, \ldots, w_n$ has been sorted from smallest to largest, 208 | the Gini coefficient of the sample is defined by 209 | 210 | \begin{equation} 211 | \label{eq:gini} 212 | G := 213 | \frac 214 | {\sum_{i=1}^n \sum_{j = 1}^n |w_j - w_i|} 215 | {2n\sum_{i=1}^n w_i}. 216 | \end{equation} 217 | 218 | 219 | 220 | ### Using QuantEcon's routine 221 | 222 | Let's study the Gini coefficient in some simulations using `gini_coefficient` 223 | from `quantecon`. 224 | 225 | The following code computes the Gini coefficients for five different populations. 226 | 227 | Each of these populations is generated by drawing from a lognormal distribution with parameters $\mu$ (mean) and $\sigma$ (standard deviation). 228 | 229 | To create the five populations, we vary $\sigma$ over a grid of length $5$ 230 | between $0.2$ and $4$. 231 | 232 | In each case we set $\mu = - \sigma^2 / 2$, so that the mean of the distribution does not change with $\sigma$. 233 | 234 | ```{code-cell} ipython3 235 | k = 5 236 | σ_vals = np.linspace(0.2, 4, k) 237 | n = 2_000 238 | ginis = [] 239 | for σ in σ_vals: 240 | # Generate the data 241 | μ = -σ**2 / 2 242 | y = np.exp(μ + σ * np.random.randn(n)) 243 | ginis.append(qe.gini_coefficient(y)) # Uses quantecon routine 244 | 245 | fig, ax = plt.subplots() 246 | ax.plot(σ_vals, ginis, marker='o') 247 | ax.set_xlabel('$\sigma$', fontsize=12) 248 | ax.set_ylabel('Gini coefficient', fontsize=12) 249 | plt.show() 250 | ``` 251 | 252 | The plots show that inequality rises with $\sigma$ (as measured by the Gini coefficient). 253 | 254 | +++ 255 | 256 | ### A NumPy version 257 | 258 | Let's write our own function to compute the Gini coefficient. 259 | 260 | We'll start with a NumPy version that uses vectorized code to avoid loops. 261 | 262 | ```{code-cell} ipython3 263 | def gini(w): 264 | n = len(w) 265 | w_1 = np.reshape(w, (n, 1)) 266 | w_2 = np.reshape(w, (1, n)) 267 | g_sum = np.sum(np.abs(w_1 - w_2)) 268 | return g_sum / (2 * n * np.sum(w)) 269 | ``` 270 | 271 | ```{code-cell} ipython3 272 | ginis = [] 273 | for σ in σ_vals: 274 | # Generate the data 275 | μ = -σ**2 / 2 276 | y = np.exp(μ + σ * np.random.randn(n)) 277 | ginis.append(gini(y)) # Use our NumPy version 278 | 279 | fig, ax = plt.subplots() 280 | ax.plot(σ_vals, ginis, marker='o') 281 | ax.set_xlabel('$\sigma$', fontsize=12) 282 | ax.set_ylabel('Gini coefficient', fontsize=12) 283 | plt.show() 284 | ``` 285 | 286 | Notice, however, that the NumPy version of the Gini function is very memory intensive, since we create large intermediate arrays. 287 | 288 | For example, consider the following 289 | 290 | ```{code-cell} ipython3 291 | w = np.exp(np.random.randn(1_000_000)) 292 | w.sort() 293 | ``` 294 | 295 | ```{code-cell} ipython3 296 | gini(w) 297 | ``` 298 | 299 | Unless you have massive memory, the code above gives an out-of-memory error. 300 | 301 | The next exercise asks you to write a more memory efficient version. 302 | 303 | +++ 304 | 305 | ### Exercise: A Numba version 306 | 307 | Try to write your own function that computes the Gini coefficient, this time 308 | using Numba and loops to produce effient code 309 | 310 | * Try to replicate the Gini figure above. 
311 | * If possible, parallelize one of the loops 312 | * See if your code runs on `w = np.exp(np.random.randn(1_000_000))` 313 | 314 | ```{code-cell} ipython3 315 | # Put your code here 316 | ``` 317 | 318 | ```{code-cell} ipython3 319 | for i in range(18): 320 | print("Solution below!") 321 | ``` 322 | 323 | Here's one solution. 324 | 325 | Notice how easy it is to parallelize the loop --- even though `s` is common across the outer loops, which violates independence, this loop is still efficiently parallelized. 326 | 327 | ```{code-cell} ipython3 328 | @numba.jit(parallel=True) 329 | def gini_numba(w): 330 | n = len(w) 331 | s = 0.0 332 | for i in numba.prange(n): 333 | for j in range(n): 334 | s += abs(w[i] - w[j]) 335 | return s / (2 * n * np.sum(w)) 336 | ``` 337 | 338 | Let's recreate the figure. 339 | 340 | ```{code-cell} ipython3 341 | ginis = [] 342 | 343 | for σ in σ_vals: 344 | μ = -σ**2 / 2 345 | y = np.exp(μ + σ * np.random.randn(n)) 346 | ginis.append(gini_numba(y)) # Use Numba version 347 | 348 | 349 | fig, ax = plt.subplots() 350 | ax.plot(σ_vals, ginis, marker='o') 351 | ax.set_xlabel('$\sigma$', fontsize=12) 352 | ax.set_ylabel('Gini coefficient', fontsize=12) 353 | plt.show() 354 | ``` 355 | 356 | And let's see if it works on the large data set we considered above. 357 | 358 | (Note that it will take a couple of minutes to run!) 359 | 360 | ```{code-cell} ipython3 361 | w = np.exp(np.random.randn(1_000_000)) 362 | w.sort() 363 | gini_numba(w) 364 | ``` 365 | 366 | ```{code-cell} ipython3 367 | 368 | ``` 369 | -------------------------------------------------------------------------------- /source_files/tuesday/numba.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | format_version: 0.13 7 | jupytext_version: 1.16.2 8 | kernelspec: 9 | display_name: Python 3 (ipykernel) 10 | language: python 11 | name: python3 12 | --- 13 | 14 | # Accelerating Python Code with Numba 15 | 16 | ---- 17 | 18 | ### Written for the CBC Workshop (May 2024) 19 | 20 | #### John Stachurski 21 | 22 | 23 | ----- 24 | 25 | +++ 26 | 27 | In the lecture on NumPy we saw that vectorization via NumPy can help accelerate our code. 28 | 29 | But 30 | 31 | * Vectorization can be very memory intensive. 32 | * NumPy-style vectorization cannot fully exploit parallel hardware (precompiled NumPy binaries cannot optimize on array size or over hardware accelerators). 33 | * Some problems cannot be vectorized --- they need to be written in loops. 34 | 35 | This notebook contains a very quick introduction to an alternative method for accelerating code, via 36 | [Numba](https://numba.pydata.org/). 37 | 38 | We use the following imports: 39 | 40 | ```{code-cell} ipython3 41 | import numpy as np 42 | from numba import vectorize, jit, float64 43 | import matplotlib.pyplot as plt 44 | ``` 45 | 46 | ## Example: Solow-Swan dynamics 47 | 48 | 49 | Here we look at one example of the last case and discuss how to accelerate it. 50 | 51 | ### A loop in Python 52 | 53 | Let's suppose we are interested in the long-run behavior of capital stock in the stochasic 54 | Solow-Swan model 55 | 56 | $$ 57 | k_{t+1} = a_{t+1} s k_t^\alpha + (1-\delta)k_t 58 | $$ 59 | 60 | where $(a_t)$ is IID and uniform. 61 | 62 | Here's some code to generate a short time series, plus a plot. 
63 | 64 | ```{code-cell} ipython3 65 | α, s, δ, a = 0.4, 0.3, 0.1, 1.0 66 | n = 120 67 | k = np.empty(n) 68 | k[0] = 0.2 69 | for t in range(n-1): 70 | k[t+1] = np.random.rand() * s * k[t]**α + (1 - δ) * k[t] 71 | 72 | fig, ax = plt.subplots() 73 | ax.plot(k, 'o-', ms=2, lw=1) 74 | ax.set_xlabel('time') 75 | ax.set_ylabel('capital') 76 | plt.show() 77 | ``` 78 | 79 | Let's say that we want to compute long-run average capital stock. 80 | 81 | ```{code-cell} ipython3 82 | def solow(n=10_000_000, α=0.4, s=0.3, δ=0.1, k0=0.2): 83 | k = k0 84 | k_sum = k0 85 | for t in range(n): 86 | a = np.random.rand() 87 | k = a * s * k**α + (1 - δ) * k 88 | k_sum += k 89 | return k_sum / n 90 | 91 | %time k = solow() 92 | print(f"Steady-state capital = {k}.") 93 | ``` 94 | 95 | Steady-state capital: 96 | 97 | +++ 98 | 99 | Notice that the run-time is pretty slow. 100 | 101 | Also, we can't use NumPy to accelerate it because there's no way to vectorize the loop. 102 | 103 | Let's look at some alternatives. 104 | 105 | +++ 106 | 107 | ### A Fortran version 108 | 109 | We can make it fast if we rewrite it in Fortran. 110 | 111 | To execute the following Fortran code on your machine, you need 112 | 113 | * a Fortran compiler (such as the open source `gfortran` compiler) and also 114 | * the `fortranmagic` Jupyter extension, which can be installed by uncommenting 115 | 116 | ```{code-cell} ipython3 117 | #!pip install fortran-magic 118 | ``` 119 | 120 | Now we load the extension (skip executing this and the rest of the section if you don't have a Fortran compiler). 121 | 122 | ```{code-cell} ipython3 123 | %load_ext fortranmagic 124 | ``` 125 | 126 | In the following code, all parameters are the same as for the Python code above. 127 | 128 | +++ 129 | 130 | Now we add the cell magic ``%%fortran`` to a cell that contains a Fortran subroutine for the Solow-Swan computation: 131 | 132 | ```{code-cell} ipython3 133 | %%fortran 134 | 135 | subroutine solow_fortran(k0, s, delta, alpha, n, kt) 136 | implicit none 137 | integer, parameter :: dp=kind(0.d0) 138 | integer, intent(in) :: n 139 | real(dp), intent(in) :: k0, s, delta, alpha 140 | real(dp), intent(out) :: kt 141 | real(dp) :: k, k_sum, a 142 | integer :: i 143 | k = k0 144 | k_sum = k0 145 | call random_seed 146 | do i = 1, n - 1 147 | call random_number(a) 148 | k = a * s * k**alpha + (1 - delta) * k 149 | k_sum = k_sum + k 150 | end do 151 | kt = k_sum / real(n) 152 | end subroutine solow_fortran 153 | ``` 154 | 155 | Now we can call the function `solow_fortran` from Python. 156 | 157 | (`fortranmagic` uses a program called `F2Py` to create a Python "wrapper" for the Fortran subroutine so we can access it from within Python.) 158 | 159 | Let's make sure it gives a reasonable answer. 160 | 161 | ```{code-cell} ipython3 162 | n = 10_000_000 163 | solow_fortran(0.2, s, δ, α, n) 164 | ``` 165 | 166 | Now let's time it: 167 | 168 | ```{code-cell} ipython3 169 | %time solow_fortran(0.2, s, δ, α, n) 170 | ``` 171 | 172 | Let's time it more carefully, over multiple runs: 173 | 174 | ```{code-cell} ipython3 175 | %timeit solow_fortran(0.2, s, δ, α, n) 176 | ``` 177 | 178 | The speed gain is about 1 order of magnitude. 179 | 180 | +++ 181 | 182 | ## Numba 183 | 184 | Now let's try the same thing in Python using Numba's JIT compilation. 185 | 186 | We recall the Python function from above. 
187 | 188 | ```{code-cell} ipython3 189 | def solow(n=10_000_000, α=0.4, s=0.3, δ=0.1, k0=0.2): 190 | k = k_sum = k0 191 | for t in range(n-1): 192 | a = np.random.rand() 193 | k = a * s * k**α + (1 - δ) * k 194 | k_sum += k 195 | return k_sum / n 196 | ``` 197 | 198 | Now let's flag it for JIT-compilation: 199 | 200 | ```{code-cell} ipython3 201 | solow_jitted = jit(solow) 202 | ``` 203 | 204 | And then run it: 205 | 206 | ```{code-cell} ipython3 207 | %time k = solow_jitted() 208 | ``` 209 | 210 | ```{code-cell} ipython3 211 | %time k = solow_jitted() 212 | ``` 213 | 214 | ```{code-cell} ipython3 215 | %timeit k = solow_jitted() 216 | ``` 217 | 218 | Hopefully we get a similar value (about 1.95): 219 | 220 | ```{code-cell} ipython3 221 | k 222 | ``` 223 | 224 | Here's the same thing using decorator notation. 225 | 226 | ```{code-cell} ipython3 227 | @jit 228 | def solow(n=10_000_000, α=0.4, s=0.3, δ=0.1, k0=0.2): 229 | k = k_sum = k0 230 | for t in range(n-1): 231 | a = np.random.rand() 232 | k = a * s * k**α + (1 - δ) * k 233 | k_sum += k 234 | return k_sum / n 235 | ``` 236 | 237 | ```{code-cell} ipython3 238 | %time k = solow() 239 | ``` 240 | 241 | ```{code-cell} ipython3 242 | %time k = solow() 243 | ``` 244 | 245 | After JIT compilation, function execution speed is about the same as Fortran. 246 | 247 | ```{code-cell} ipython3 248 | k 249 | ``` 250 | 251 | #### How does it work? 252 | 253 | +++ 254 | 255 | The secret sauce is type inference inside the function body. 256 | 257 | When we call `solow_jitted` with particular arguments, Numba's compiler works 258 | through the function body and infers the types of the variables inside the 259 | function. 260 | 261 | It then produces compiled code *specialized to that type signature* 262 | 263 | For example, we called `solow_jitted` with a `float, int` pair in the cell above and the compiler produced code specialized to those types. 264 | 265 | That code runs fast because the compiler can fully specialize all operations inside the function based on that information and hence write very efficient machine code. 266 | 267 | +++ 268 | 269 | #### Limitations of Numba 270 | 271 | Numba is great when it works but it can't compile functions that aren't 272 | themselves JIT compiled. 273 | 274 | +++ 275 | 276 | In practice, this means we can't use most third party libraries: 277 | 278 | ```{code-cell} ipython3 279 | from scipy.integrate import quad 280 | 281 | def compute_integral(n): 282 | return quad(lambda x: x**(1/n), 0, 1) 283 | ``` 284 | 285 | This works fine if we don't jit the function. 286 | 287 | ```{code-cell} ipython3 288 | compute_integral(4) 289 | ``` 290 | 291 | But if we do... 292 | 293 | ```{code-cell} ipython3 294 | @jit 295 | def compute_integral(n): 296 | return quad(lambda x: x**(1/n), 0, 1) 297 | 298 | ``` 299 | 300 | ```{code-cell} ipython3 301 | compute_integral(4) 302 | ``` 303 | 304 | The reason is that the `quad` function is from SciPy and Numba doesn't know how 305 | to handle it. 306 | 307 | Key message: even though it might not be possible to JIT-compile your whole 308 | program, you might well be able to compile the hot loops that are eating up 99% 309 | of your computation time. 310 | 311 | If you can do this, you open up large speed gains, as the following sections 312 | make clear. 
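As a rough illustration of this pattern (a minimal sketch, reusing `np`, `jit` and `quad` from the imports above; the function names here are made up for the example), we can keep the SciPy call in ordinary Python and JIT-compile only the numerical kernel:

```{code-cell} ipython3
@jit
def sum_of_squares(x):
    # The "hot loop": pure numerics, which Numba compiles easily
    s = 0.0
    for i in range(x.shape[0]):
        s += x[i]**2
    return s

def compute_quantity(n=1_000_000):
    # Ordinary Python: free to call SciPy, read files, etc.
    weight, error = quad(lambda t: t**0.5, 0, 1)  # not compiled
    x = np.random.randn(n)
    return weight * sum_of_squares(x)             # only this call is compiled

compute_quantity()
```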
313 | 314 | +++ 315 | 316 | ### Vectorization vs Numba 317 | 318 | +++ 319 | 320 | We made the point above that some problems are hard or impossible to vectorize and, in these situations, that we can use Numba instead of NumPy to accelerate our code. 321 | 322 | However, there are also many situations where we *can* vectorize our code but Numba is still the better option. 323 | 324 | Let's look at an example. 325 | 326 | +++ 327 | 328 | The problem is to maximize the function 329 | 330 | $$ f(x, y) = \frac{\cos \left(x^2 + y^2 \right)}{1 + x^2 + y^2} + 1$$ 331 | 332 | using brute force --- searching over a grid of $(x, y)$ pairs. 333 | 334 | ```{code-cell} ipython3 335 | def f(x, y): 336 | return np.cos(x**2 + y**2) / (1 + x**2 + y**2) + 1 337 | ``` 338 | 339 | ```{code-cell} ipython3 340 | from mpl_toolkits.mplot3d.axes3d import Axes3D 341 | from matplotlib import cm 342 | 343 | gridsize = 50 344 | gmin, gmax = -3, 3 345 | xgrid = np.linspace(gmin, gmax, gridsize) 346 | ygrid = xgrid 347 | x, y = np.meshgrid(xgrid, ygrid) 348 | 349 | # === plot value function === # 350 | fig = plt.figure(figsize=(10, 8)) 351 | ax = fig.add_subplot(111, projection='3d') 352 | ax.plot_surface(x, 353 | y, 354 | f(x, y), 355 | rstride=2, cstride=2, 356 | cmap=cm.jet, 357 | alpha=0.4, 358 | linewidth=0.05) 359 | 360 | 361 | ax.scatter(x, y, c='k', s=0.6) 362 | 363 | ax.scatter(x, y, f(x, y), c='k', s=0.6) 364 | 365 | ax.view_init(25, -57) 366 | ax.set_zlim(-0, 2.0) 367 | ax.set_xlim(gmin, gmax) 368 | ax.set_ylim(gmin, gmax) 369 | 370 | plt.show() 371 | ``` 372 | 373 | #### Vectorized code 374 | 375 | ```{code-cell} ipython3 376 | n = 10_000 377 | grid = np.linspace(-3, 3, n) 378 | ``` 379 | 380 | ```{code-cell} ipython3 381 | x, y = np.meshgrid(grid, grid) 382 | ``` 383 | 384 | ```{code-cell} ipython3 385 | --- 386 | nbpresent: 387 | id: 1ba9f9f9-f737-4ee1-86e6-0a33c4752188 388 | --- 389 | %%time 390 | np.max(f(x, y)) 391 | ``` 392 | 393 | #### JITTed code 394 | 395 | +++ 396 | 397 | Let's try a jitted version with loops. 398 | 399 | ```{code-cell} ipython3 400 | @jit 401 | def compute_max(): 402 | m = -np.inf 403 | for x in grid: 404 | for y in grid: 405 | z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) + 1 406 | if z > m: 407 | m = z 408 | return m 409 | ``` 410 | 411 | ```{code-cell} ipython3 412 | %%time 413 | compute_max() 414 | ``` 415 | 416 | ```{code-cell} ipython3 417 | %%time 418 | compute_max() 419 | ``` 420 | 421 | Why the speed gain? 422 | 423 | +++ 424 | 425 | #### JITTed, parallelized code 426 | 427 | We can parallelize on the CPU through Numba via `@jit(parallel=True)` 428 | 429 | Our strategy is 430 | 431 | - Compute the max value along each row, parallelizing this task across rows 432 | - Take the max of these row maxes 433 | 434 | The memory footprint is still relatively light, because the size of the rows is only `n x 1`. 
435 | 436 | ```{code-cell} ipython3 437 | @jit 438 | def f(x, y): 439 | return np.cos(x**2 + y**2) / (1 + x**2 + y**2) + 1 440 | ``` 441 | 442 | ```{code-cell} ipython3 443 | from numba import prange 444 | ``` 445 | 446 | ```{code-cell} ipython3 447 | @jit(parallel=True) 448 | def compute_max(): 449 | row_maxes = np.empty(n) 450 | y_grid = grid 451 | for i in prange(n): 452 | x = grid[i] 453 | row_maxes[i] = np.max(f(x, y_grid)) 454 | return np.max(row_maxes) 455 | ``` 456 | 457 | ```{code-cell} ipython3 458 | %%time 459 | compute_max() 460 | ``` 461 | 462 | ```{code-cell} ipython3 463 | %%time 464 | compute_max() 465 | ``` 466 | 467 | ```{code-cell} ipython3 468 | 469 | ``` 470 | -------------------------------------------------------------------------------- /source_files/wednesday/autodiff.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | format_version: 0.13 7 | jupytext_version: 1.16.2 8 | kernelspec: 9 | display_name: Python 3 (ipykernel) 10 | language: python 11 | name: python3 12 | --- 13 | 14 | # Adventures with Autodiff 15 | 16 | ------ 17 | 18 | #### Prepared for the CBC Workshop May 2024 19 | #### John Stachurski 20 | 21 | ------ 22 | 23 | ```{code-cell} ipython3 24 | import jax 25 | import jax.numpy as jnp 26 | import matplotlib.pyplot as plt 27 | import numpy as np 28 | ``` 29 | 30 | Checking for a GPU: 31 | 32 | ```{code-cell} ipython3 33 | !nvidia-smi 34 | ``` 35 | 36 | ## What is automatic differentiation? 37 | 38 | Autodiff is a technique for calculating derivatives on a computer. 39 | 40 | ### Autodiff is not finite differences 41 | 42 | The derivative of $f(x) = \exp(2x)$ is 43 | 44 | $$ 45 | f'(x) = 2 \exp(2x) 46 | $$ 47 | 48 | 49 | 50 | A computer that doesn't know how to take derivatives might approximate this with the finite difference ratio 51 | 52 | $$ 53 | (Df)(x) := \frac{f(x+h) - f(x)}{h} 54 | $$ 55 | 56 | where $h$ is a small positive number. 57 | 58 | ```{code-cell} ipython3 59 | def f(x): 60 | "Original function." 61 | return np.exp(2 * x) 62 | 63 | def f_prime(x): 64 | "True derivative." 65 | return 2 * np.exp(2 * x) 66 | 67 | def Df(x, h=0.1): 68 | "Approximate derivative (finite difference)." 69 | return (f(x + h) - f(x))/h 70 | 71 | x_grid = np.linspace(-2, 1, 200) 72 | fig, ax = plt.subplots() 73 | ax.plot(x_grid, f_prime(x_grid), label="$f'$") 74 | ax.plot(x_grid, Df(x_grid), label="$Df$") 75 | ax.legend() 76 | plt.show() 77 | ``` 78 | 79 | This kind of numerical derivative is often inaccurate and unstable. 80 | 81 | One reason is that 82 | 83 | $$ 84 | \frac{f(x+h) - f(x)}{h} \approx \frac{0}{0} 85 | $$ 86 | 87 | Small numbers in the numerator and denominator causes rounding errors. 88 | 89 | The situation is exponentially worse in high dimensions / with higher order derivatives 90 | 91 | +++ 92 | 93 | ### Autodiff is not symbolic calculus 94 | 95 | +++ 96 | 97 | Symbolic calculus tries to use rules for differentiation to produce a single 98 | closed-form expression representing a derivative. 99 | 100 | ```{code-cell} ipython3 101 | from sympy import symbols, diff 102 | 103 | m, a, b, x = symbols('m a b x') 104 | f_x = (a*x + b)**m 105 | f_x.diff((x, 6)) # 6-th order derivative 106 | ``` 107 | 108 | Symbolic calculus is not well suited to high performance 109 | computing. 110 | 111 | One disadvantage is that symbolic calculus cannot differentiate through control flow. 
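For example, here is a minimal sketch (hypothetical, reusing the `symbols` import from above) of the problem with data-dependent branching:

```{code-cell} ipython3
def g(z):
    # The branch depends on the numerical value of z, which a symbolic
    # engine cannot resolve when z is a free symbol.
    return z**2 if z > 0 else -z

z = symbols('z')
# g(z)   # raises TypeError: cannot determine truth value of Relational
```

Autodiff avoids this issue because it differentiates along the branch actually taken at the numerical input, as we will see when JAX handles control flow below.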
112 | 113 | Also, using symbolic calculus might involve redundant calculations. 114 | 115 | For example, consider 116 | 117 | $$ 118 | (f g h)' 119 | = (f' g + g' f) h + (f g) h' 120 | $$ 121 | 122 | If we evaluate at $x$, then we evalute $f(x)$ and $g(x)$ twice each. 123 | 124 | Also, computing $f'(x)$ and $f(x)$ might involve similar terms (e.g., $(f(x) = \exp(2x)' \implies f'(x) = 2f(x)$) but this is not exploited in symbolic algebra. 125 | 126 | +++ 127 | 128 | ### Autodiff 129 | 130 | Autodiff produces functions that evaluates derivatives at numerical values 131 | passed in by the calling code, rather than producing a single symbolic 132 | expression representing the entire derivative. 133 | 134 | Derivatives are constructed by breaking calculations into component parts via the chain rule. 135 | 136 | The chain rule is applied until the point where the terms reduce to primitive functions that the program knows how to differentiate exactly (addition, subtraction, exponentiation, sine and cosine, etc.) 137 | 138 | +++ 139 | 140 | ## Some experiments 141 | 142 | +++ 143 | 144 | Let's start with some real-valued functions on $\mathbb R$. 145 | 146 | +++ 147 | 148 | ### A differentiable function 149 | 150 | +++ 151 | 152 | Let's test JAX's auto diff with a relatively simple function. 153 | 154 | ```{code-cell} ipython3 155 | def f(x): 156 | return jnp.sin(x) - 2 * jnp.cos(3 * x) * jnp.exp(- x**2) 157 | ``` 158 | 159 | We use `grad` to compute the gradient of a real-valued function: 160 | 161 | ```{code-cell} ipython3 162 | f_prime = jax.grad(f) 163 | ``` 164 | 165 | Let's plot the result: 166 | 167 | ```{code-cell} ipython3 168 | x_grid = jnp.linspace(-5, 5, 100) 169 | ``` 170 | 171 | ```{code-cell} ipython3 172 | fig, ax = plt.subplots() 173 | ax.plot(x_grid, [f(x) for x in x_grid], label="$f$") 174 | ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$") 175 | ax.legend() 176 | plt.show() 177 | ``` 178 | 179 | ### Absolute value function 180 | 181 | +++ 182 | 183 | What happens if the function is not differentiable? 184 | 185 | ```{code-cell} ipython3 186 | def f(x): 187 | return jnp.abs(x) 188 | ``` 189 | 190 | ```{code-cell} ipython3 191 | f_prime = jax.grad(f) 192 | ``` 193 | 194 | ```{code-cell} ipython3 195 | fig, ax = plt.subplots() 196 | ax.plot(x_grid, [f(x) for x in x_grid], label="$f$") 197 | ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$") 198 | ax.legend() 199 | plt.show() 200 | ``` 201 | 202 | At the nondifferentiable point $0$, `jax.grad` returns the right derivative: 203 | 204 | ```{code-cell} ipython3 205 | f_prime(0.0) 206 | ``` 207 | 208 | ### Differentiating through control flow 209 | 210 | +++ 211 | 212 | Let's try differentiating through some loops and conditions. 
213 | 214 | ```{code-cell} ipython3 215 | def f(x): 216 | def f1(x): 217 | for i in range(2): 218 | x *= 0.2 * x 219 | return x 220 | def f2(x): 221 | x = sum((x**i + i) for i in range(3)) 222 | return x 223 | y = f1(x) if x < 0 else f2(x) 224 | return y 225 | ``` 226 | 227 | ```{code-cell} ipython3 228 | f_prime = jax.grad(f) 229 | ``` 230 | 231 | ```{code-cell} ipython3 232 | x_grid = jnp.linspace(-5, 5, 100) 233 | ``` 234 | 235 | ```{code-cell} ipython3 236 | fig, ax = plt.subplots() 237 | ax.plot(x_grid, [f(x) for x in x_grid], label="$f$") 238 | ax.plot(x_grid, [f_prime(x) for x in x_grid], label="$f'$") 239 | ax.legend() 240 | plt.show() 241 | ``` 242 | 243 | ### Differentiating through a linear interpolation 244 | 245 | +++ 246 | 247 | We can differentiate through linear interpolation, even though the function is not smooth: 248 | 249 | ```{code-cell} ipython3 250 | n = 20 251 | xp = jnp.linspace(-5, 5, n) 252 | yp = jnp.cos(2 * xp) 253 | 254 | fig, ax = plt.subplots() 255 | ax.plot(x_grid, jnp.interp(x_grid, xp, yp)) 256 | plt.show() 257 | ``` 258 | 259 | ```{code-cell} ipython3 260 | f_prime = jax.grad(jnp.interp) 261 | ``` 262 | 263 | ```{code-cell} ipython3 264 | f_prime_vec = jax.vmap(f_prime, in_axes=(0, None, None)) 265 | ``` 266 | 267 | ```{code-cell} ipython3 268 | fig, ax = plt.subplots() 269 | ax.plot(x_grid, f_prime_vec(x_grid, xp, yp)) 270 | plt.show() 271 | ``` 272 | 273 | ## Gradient Descent 274 | 275 | +++ 276 | 277 | Let's try implementing gradient descent. 278 | 279 | As a simple application, we'll use gradient descent to solve for the OLS parameter estimates in simple linear regression. 280 | 281 | +++ 282 | 283 | ### A function for gradient descent 284 | 285 | +++ 286 | 287 | Here's an implementation of gradient descent. 288 | 289 | ```{code-cell} ipython3 290 | def grad_descent(f, # Function to be minimized 291 | args, # Extra arguments to the function 292 | x0, # Initial condition 293 | λ=0.1, # Initial learning rate 294 | tol=1e-5, 295 | max_iter=1_000): 296 | """ 297 | Minimize the function f via gradient descent, starting from guess x0. 298 | 299 | The learning rate is computed according to the Barzilai-Borwein method. 300 | 301 | """ 302 | 303 | f_grad = jax.grad(f) 304 | x = jnp.array(x0) 305 | df = f_grad(x, args) 306 | ϵ = tol + 1 307 | i = 0 308 | while ϵ > tol and i < max_iter: 309 | new_x = x - λ * df 310 | new_df = f_grad(new_x, args) 311 | Δx = new_x - x 312 | Δdf = new_df - df 313 | λ = jnp.abs(Δx @ Δdf) / (Δdf @ Δdf) 314 | ϵ = jnp.max(jnp.abs(Δx)) 315 | x, df = new_x, new_df 316 | i += 1 317 | 318 | return x 319 | 320 | ``` 321 | 322 | ### Simulated data 323 | 324 | We're going to test our gradient descent function my minimizing a sum of least squares in a regression problem. 325 | 326 | Let's generate some simulated data: 327 | 328 | ```{code-cell} ipython3 329 | n = 100 330 | key = jax.random.PRNGKey(1234) 331 | x = jax.random.uniform(key, (n,)) 332 | 333 | α, β, σ = 0.5, 1.0, 0.1 # Set the true intercept and slope. 334 | key, subkey = jax.random.split(key) 335 | ϵ = jax.random.normal(subkey, (n,)) 336 | 337 | y = α * x + β + σ * ϵ 338 | ``` 339 | 340 | ```{code-cell} ipython3 341 | fig, ax = plt.subplots() 342 | ax.scatter(x, y) 343 | plt.show() 344 | ``` 345 | 346 | Let's start by calculating the estimated slope and intercept using closed form solutions. 
347 | 348 | ```{code-cell} ipython3 349 | mx = x.mean() 350 | my = y.mean() 351 | α_hat = jnp.sum((x - mx) * (y - my)) / jnp.sum((x - mx)**2) 352 | β_hat = my - α_hat * mx 353 | ``` 354 | 355 | ```{code-cell} ipython3 356 | α_hat, β_hat 357 | ``` 358 | 359 | ```{code-cell} ipython3 360 | fig, ax = plt.subplots() 361 | ax.scatter(x, y) 362 | ax.plot(x, α_hat * x + β_hat, 'k-') 363 | ax.text(0.1, 1.55, rf'$\hat \alpha = {α_hat:.3}$') 364 | ax.text(0.1, 1.50, rf'$\hat \beta = {β_hat:.3}$') 365 | plt.show() 366 | ``` 367 | 368 | ### Minimizing squared loss by gradient descent 369 | 370 | +++ 371 | 372 | Let's see if we can get the same values with our gradient descent function. 373 | 374 | First we set up the least squares loss function. 375 | 376 | ```{code-cell} ipython3 377 | @jax.jit 378 | def loss(params, data): 379 | a, b = params 380 | x, y = data 381 | return jnp.sum((y - a * x - b)**2) 382 | ``` 383 | 384 | Now we minimize it: 385 | 386 | ```{code-cell} ipython3 387 | p0 = jnp.zeros(2) # Initial guess for α, β 388 | data = x, y 389 | α_hat, β_hat = grad_descent(loss, data, p0) 390 | ``` 391 | 392 | Let's plot the results. 393 | 394 | ```{code-cell} ipython3 395 | fig, ax = plt.subplots() 396 | x_grid = jnp.linspace(0, 1, 100) 397 | ax.scatter(x, y) 398 | ax.plot(x_grid, α_hat * x_grid + β_hat, 'k-', alpha=0.6) 399 | ax.text(0.1, 1.55, rf'$\hat \alpha = {α_hat:.3}$') 400 | ax.text(0.1, 1.50, rf'$\hat \beta = {β_hat:.3}$') 401 | plt.show() 402 | ``` 403 | 404 | Notice that we get the same estimates as we did from the closed form solutions. 405 | 406 | +++ 407 | 408 | ### Adding a squared term 409 | 410 | +++ 411 | 412 | Now let's try fitting a second order polynomial. 413 | 414 | Here's our new loss function. 415 | 416 | ```{code-cell} ipython3 417 | @jax.jit 418 | def loss(params, data): 419 | a, b, c = params 420 | x, y = data 421 | return jnp.sum((y - a * x**2 - b * x - c)**2) 422 | ``` 423 | 424 | Now we're minimizing in three dimensions. 425 | 426 | Let's try it. 427 | 428 | ```{code-cell} ipython3 429 | p0 = jnp.zeros(3) 430 | α_hat, β_hat, γ_hat = grad_descent(loss, data, p0) 431 | 432 | fig, ax = plt.subplots() 433 | ax.scatter(x, y) 434 | ax.plot(x_grid, α_hat * x_grid**2 + β_hat * x_grid + γ_hat, 'k-', alpha=0.6) 435 | ax.text(0.1, 1.55, rf'$\hat \alpha = {α_hat:.3}$') 436 | ax.text(0.1, 1.50, rf'$\hat \beta = {β_hat:.3}$') 437 | plt.show() 438 | ``` 439 | 440 | ## Exercise 441 | 442 | The function `jnp.polyval` evaluates polynomials. 443 | 444 | For example, if `len(p)` is 3, then `jnp.polyval(p, x)` returns 445 | 446 | $$ 447 | f(p, x) := p_0 x^2 + p_1 x + p_2 448 | $$ 449 | 450 | Use this function for polynomial regression. 451 | 452 | The (empirical) loss becomes 453 | 454 | $$ 455 | \ell(p, x, y) 456 | = \sum_{i=1}^n (y_i - f(p, x_i))^2 457 | $$ 458 | 459 | Set $k=4$ and set the initial guess of `params` to `jnp.zeros(k)`. 460 | 461 | Use gradient descent to find the array `params` that minimizes the loss 462 | function and plot the result (following the examples above). 
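Before you start, here is a quick check of the coefficient convention used by `jnp.polyval` (highest-degree coefficient first); this is just an illustration of the API, not part of the solution:

```{code-cell} ipython3
p_test = jnp.array([2.0, 3.0, 1.0])
jnp.polyval(p_test, 2.0)   # 2 * 2**2 + 3 * 2 + 1 = 15
```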
463 | 464 | ```{code-cell} ipython3 465 | for i in range(18): 466 | print("Solution below 🦀") 467 | ``` 468 | 469 | ```{code-cell} ipython3 470 | def loss(params, data): 471 | x, y = data 472 | return jnp.sum((y - jnp.polyval(params, x))**2) 473 | ``` 474 | 475 | ```{code-cell} ipython3 476 | k = 4 477 | p0 = jnp.zeros(k) 478 | p_hat = grad_descent(loss, data, p0) 479 | print('Estimated parameter vector:') 480 | print(p_hat) 481 | print('\n\n') 482 | 483 | fig, ax = plt.subplots() 484 | ax.scatter(x, y) 485 | ax.plot(x_grid, jnp.polyval(p_hat, x_grid), 'k-', alpha=0.6) 486 | plt.show() 487 | ``` 488 | 489 | ```{code-cell} ipython3 490 | 491 | ``` 492 | -------------------------------------------------------------------------------- /source_files/wednesday/inventory_dynamics_jax.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | format_version: 0.13 7 | jupytext_version: 1.16.1 8 | kernelspec: 9 | display_name: Python 3 (ipykernel) 10 | language: python 11 | name: python3 12 | --- 13 | 14 | # Inventory Dynamics 15 | 16 | +++ 17 | 18 | ---- 19 | 20 | #### John Stachurski 21 | #### Prepared for the CBC Computational Workshop (May 2024) 22 | 23 | ---- 24 | 25 | +++ 26 | 27 | ## Overview 28 | 29 | This lecture explores s-S inventory dynamics. 30 | 31 | We also studied this model in an earlier notebook using NumPy. 32 | 33 | Here we study the same problem using JAX. 34 | 35 | One issue we will consider is whether or not we can improve execution speed by using JAX's `fori_loop` function (as a replacement for looping in Python). 36 | 37 | We will use the following imports: 38 | 39 | ```{code-cell} ipython3 40 | :hide-output: false 41 | 42 | import matplotlib.pyplot as plt 43 | import numpy as np 44 | import jax 45 | import jax.numpy as jnp 46 | from jax import random, lax 47 | from collections import namedtuple 48 | ``` 49 | 50 | Here’s a description of our GPU: 51 | 52 | ```{code-cell} ipython3 53 | :hide-output: false 54 | 55 | !nvidia-smi 56 | ``` 57 | 58 | ## Model 59 | 60 | We briefly recall the dynamics. 61 | 62 | Let $a \vee b = \max\{a, b\}$ 63 | 64 | Consider a firm with inventory $ X_t $. 65 | 66 | The firm waits until $ X_t \leq s $ and then restocks up to $ S $ units. 67 | 68 | It faces stochastic demand $ \{ D_t \} $, which we assume is IID across time and 69 | firms. 70 | 71 | $$ 72 | X_{t+1} = 73 | \begin{cases} 74 | (S - D_{t+1}) \vee 0 & \quad \text{if } X_t \leq s \\ 75 | (X_t - D_{t+1}) \vee 0 & \quad \text{if } X_t > s 76 | \end{cases} 77 | $$ 78 | 79 | In what follows, we will assume that each $ D_t $ is lognormal, so that 80 | 81 | $$ 82 | D_t = \exp(\mu + \sigma Z_t) 83 | $$ 84 | 85 | where $ \mu $ and $ \sigma $ are parameters and $ \{Z_t\} $ is IID 86 | and standard normal. 87 | 88 | Here’s a `namedtuple` that stores parameters. 89 | 90 | ```{code-cell} ipython3 91 | :hide-output: false 92 | 93 | Parameters = namedtuple('Parameters', ('s', 'S', 'μ', 'σ')) 94 | 95 | # Create a default instance 96 | params = Parameters(s=10, S=100, μ=1.0, σ=0.5) 97 | ``` 98 | 99 | ## Cross-sectional distributions 100 | 101 | Let's simulate the inventories of a cross-section of firms. 102 | 103 | We will use the following code to update the cross-section by one period. 
104 | 105 | ```{code-cell} ipython3 106 | :hide-output: false 107 | 108 | @jax.jit 109 | def update_cross_section(params, X_vec, D): 110 | """ 111 | Update by one period a cross-section of firms with inventory levels 112 | X_vec, given the vector of demand shocks in D. 113 | 114 | * X_vec[i] is the inventory of firm i 115 | * D[i] is the demand shock for firm i 116 | 117 | """ 118 | # Unpack 119 | s, S = params.s, params.S 120 | # Restock if the inventory is below the threshold 121 | X_new = jnp.where(X_vec <= s, 122 | jnp.maximum(S - D, 0), jnp.maximum(X_vec - D, 0)) 123 | return X_new 124 | ``` 125 | 126 | ### For loop version 127 | 128 | Now we provide code to compute the cross-sectional distribution $ \psi_T $ given some 129 | initial distribution $ \psi_0 $ and a positive integer $ T $. 130 | 131 | In this code we use an ordinary Python `for` loop to step forward through time 132 | 133 | While Python loops are slow, this approach is reasonable here because 134 | efficiency of outer loops has far less influence on runtime than efficiency of inner loops. 135 | 136 | (Below we will squeeze out more speed by compiling the outer loop as well as the 137 | update rule.) 138 | 139 | In the code below, the initial distribution $ \psi_0 $ takes all firms to have 140 | initial inventory `x_init`. 141 | 142 | ```{code-cell} ipython3 143 | :hide-output: false 144 | 145 | def compute_cross_section(params, x_init, T, key, num_firms=50_000): 146 | # Unpack 147 | μ, σ = params.μ, params.σ 148 | # Set up initial distribution 149 | X = jnp.full((num_firms, ), x_init) 150 | # Loop 151 | for i in range(T): 152 | Z = random.normal(key, shape=(num_firms, )) 153 | D = jnp.exp(params.μ + params.σ * Z) 154 | X = update_cross_section(params, X, D) 155 | key = random.fold_in(key, i) 156 | 157 | return X 158 | ``` 159 | 160 | We’ll use the following specification 161 | 162 | ```{code-cell} ipython3 163 | :hide-output: false 164 | 165 | x_init = 50 166 | T = 500 167 | # Initialize random number generator 168 | key = random.PRNGKey(10) 169 | ``` 170 | 171 | Let’s look at the timing. 172 | 173 | ```{code-cell} ipython3 174 | :hide-output: false 175 | 176 | %time X_vec = compute_cross_section(params, x_init, T, key).block_until_ready() 177 | ``` 178 | 179 | ```{code-cell} ipython3 180 | :hide-output: false 181 | 182 | %time X_vec = compute_cross_section(params, x_init, T, key).block_until_ready() 183 | ``` 184 | 185 | Here’s a histogram of inventory levels at time $ T $. 186 | 187 | ```{code-cell} ipython3 188 | :hide-output: false 189 | 190 | fig, ax = plt.subplots() 191 | ax.hist(X_vec, bins=50, 192 | density=True, 193 | histtype='step', 194 | label=f'cross-section when $T = {T}$') 195 | ax.set_xlabel('inventory') 196 | ax.set_ylabel('probability') 197 | ax.legend() 198 | plt.show() 199 | ``` 200 | 201 | ### Compiling the outer loop 202 | 203 | Now let’s see if we can gain some speed by compiling the outer loop, which steps 204 | through the time dimension. 205 | 206 | We will do this using `jax.jit` and a `fori_loop`, which is a compiler-ready version of a `for` loop provided by JAX. 
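If `fori_loop` is new to you, the call `lax.fori_loop(lower, upper, body_fun, init_val)` applies `state = body_fun(i, state)` for `i = lower, ..., upper - 1` and returns the final state. Here is a minimal sketch (a toy cumulative sum, unrelated to the inventory model) using the `lax` import from above:

```{code-cell} ipython3
:hide-output: false

def add_index(i, total):
    # body_fun takes the loop index and the carried state and
    # returns the updated state
    return total + i

lax.fori_loop(0, 5, add_index, 0)   # computes 0 + 1 + 2 + 3 + 4 = 10
```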
207 | 208 | ```{code-cell} ipython3 209 | :hide-output: false 210 | 211 | def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000): 212 | 213 | s, S, μ, σ = params.s, params.S, params.μ, params.σ 214 | 215 | # Define the function for each update 216 | def fori_update(t, state): 217 | # Unpack 218 | X, key = state 219 | # Draw shocks using key 220 | Z = random.normal(key, shape=(num_firms,)) 221 | D = jnp.exp(μ + σ * Z) 222 | # Update X 223 | X = jnp.where(X <= s, 224 | jnp.maximum(S - D, 0), 225 | jnp.maximum(X - D, 0)) 226 | # Refresh the key 227 | key = random.fold_in(key, t) 228 | new_state = X, key 229 | return new_state 230 | 231 | # Loop t from 0 to T, applying fori_update each time. 232 | X = jnp.full((num_firms, ), x_init) 233 | initial_state = X, key 234 | X, key = lax.fori_loop(0, T, fori_update, initial_state) 235 | 236 | return X 237 | 238 | # Compile taking T and num_firms as static (changes trigger recompile) 239 | compute_cross_section_fori = jax.jit( 240 | compute_cross_section_fori, static_argnums=(2, 4)) 241 | ``` 242 | 243 | Let’s see how fast this runs with compile time. 244 | 245 | ```{code-cell} ipython3 246 | :hide-output: false 247 | 248 | %time X_vec = compute_cross_section_fori(params, x_init, T, key).block_until_ready() 249 | ``` 250 | 251 | And let’s see how fast it runs without compile time. 252 | 253 | ```{code-cell} ipython3 254 | :hide-output: false 255 | 256 | %time X_vec = compute_cross_section_fori(params, x_init, T, key).block_until_ready() 257 | ``` 258 | 259 | Compared to the original version with a pure Python outer loop, we have 260 | produced a nontrivial speed gain. 261 | 262 | This is due to the fact that we have compiled the whole operation. 263 | 264 | Let's check that we get a similar cross-section. 265 | 266 | ```{code-cell} ipython3 267 | :hide-output: false 268 | 269 | fig, ax = plt.subplots() 270 | ax.hist(X_vec, bins=50, 271 | density=True, 272 | histtype='step', 273 | label=f'cross-section when $T = {T}$') 274 | ax.set_xlabel('inventory') 275 | ax.set_ylabel('probability') 276 | ax.legend() 277 | plt.show() 278 | ``` 279 | 280 | ### Further vectorization 281 | 282 | For relatively small problems, we can make this code run even faster by generating 283 | all random variables at once. 284 | 285 | This improves efficiency because we are taking more operations out of the loop. 286 | 287 | ```{code-cell} ipython3 288 | :hide-output: false 289 | 290 | def compute_cross_section_fori(params, x_init, T, key, num_firms=50_000): 291 | 292 | s, S, μ, σ = params.s, params.S, params.μ, params.σ 293 | X = jnp.full((num_firms, ), x_init) 294 | Z = random.normal(key, shape=(T, num_firms)) 295 | D = jnp.exp(μ + σ * Z) 296 | 297 | def update_cross_section(i, X): 298 | X = jnp.where(X <= s, 299 | jnp.maximum(S - D[i, :], 0), 300 | jnp.maximum(X - D[i, :], 0)) 301 | return X 302 | 303 | X = lax.fori_loop(0, T, update_cross_section, X) 304 | 305 | return X 306 | 307 | # Compile taking T and num_firms as static (changes trigger recompile) 308 | compute_cross_section_fori = jax.jit( 309 | compute_cross_section_fori, static_argnums=(2, 4)) 310 | ``` 311 | 312 | Let’s test it with compile time included. 313 | 314 | ```{code-cell} ipython3 315 | :hide-output: false 316 | 317 | %time X_vec = compute_cross_section_fori(params, x_init, T, key).block_until_ready() 318 | ``` 319 | 320 | Let’s run again to eliminate compile time. 
321 | 322 | ```{code-cell} ipython3 323 | :hide-output: false 324 | 325 | %time X_vec = compute_cross_section_fori(params, x_init, T, key).block_until_ready() 326 | ``` 327 | 328 | On one hand, this version is faster than the previous one, where random variables were 329 | generated inside the loop. 330 | 331 | On the other hand, this implementation consumes far more memory, as we need to 332 | store large arrays of random draws. 333 | 334 | The high memory consumption becomes problematic for large problems. 335 | 336 | +++ 337 | 338 | ## Restock frequency 339 | 340 | As an exercise, let’s study the probability that firms need to restock over a given time period. 341 | 342 | In the exercise, we will 343 | 344 | - set the starting stock level to $ X_0 = 70 $ and 345 | - calculate the proportion of firms that need to order twice or more in the first 50 periods. 346 | 347 | 348 | This proportion approximates the probability of the event when the sample size 349 | is large. 350 | 351 | +++ 352 | 353 | ### For loop version 354 | 355 | We start with an easier `for` loop implementation 356 | 357 | ```{code-cell} ipython3 358 | :hide-output: false 359 | 360 | # Define a jitted function for each update 361 | @jax.jit 362 | def update_stock(params, counter, X, D): 363 | s, S = params.s, params.S 364 | X = jnp.where(X <= s, 365 | jnp.maximum(S - D, 0), 366 | jnp.maximum(X - D, 0)) 367 | counter = jnp.where(X <= s, counter + 1, counter) 368 | return counter, X 369 | 370 | def compute_freq(params, key, 371 | x_init=70, 372 | sim_length=50, 373 | num_firms=1_000_000): 374 | 375 | # Prepare initial arrays 376 | X = jnp.full((num_firms, ), x_init) 377 | counter = jnp.zeros((num_firms, )) 378 | 379 | # Use a for loop to perform the calculations on all states 380 | for i in range(sim_length): 381 | Z = random.normal(key, (num_firms, )) 382 | D = jnp.exp(params.μ + params.σ * Z) 383 | counter, X = update_stock(params, counter, X, D) 384 | key = random.fold_in(key, i) 385 | 386 | return jnp.mean(counter > 1, axis=0) 387 | ``` 388 | 389 | ```{code-cell} ipython3 390 | :hide-output: false 391 | 392 | key = random.PRNGKey(42) 393 | ``` 394 | 395 | ```{code-cell} ipython3 396 | :hide-output: false 397 | 398 | %time freq = compute_freq(params, key).block_until_ready() 399 | print(f"Frequency of at least two stock outs = {freq}") 400 | ``` 401 | 402 | ```{code-cell} ipython3 403 | :hide-output: false 404 | 405 | %time freq = compute_freq(params, key).block_until_ready() 406 | print(f"Frequency of at least two stock outs = {freq}") 407 | ``` 408 | 409 | ### Exercise 4.1 410 | 411 | Write a `fori_loop` version of the last function. See if you can increase the 412 | speed while generating a similar answer. 
413 | 414 | +++ 415 | 416 | ### Solution to[ Exercise 4.1](https://jax.quantecon.org/#inventory_dynamics_ex1) 417 | 418 | Here is a `fori_loop` version that JIT compiles the whole function 419 | 420 | ```{code-cell} ipython3 421 | :hide-output: false 422 | 423 | def compute_freq_fori_loop(params, 424 | key, 425 | x_init=70, 426 | sim_length=50, 427 | num_firms=1_000_000): 428 | 429 | s, S, μ, σ = params 430 | Z = random.normal(key, shape=(sim_length, num_firms)) 431 | D = jnp.exp(μ + σ * Z) 432 | 433 | # Define the function for each update 434 | def update(t, state): 435 | # Separate the inventory and restock counter 436 | X, counter = state 437 | X = jnp.where(X <= s, 438 | jnp.maximum(S - D[t, :], 0), 439 | jnp.maximum(X - D[t, :], 0)) 440 | counter = jnp.where(X <= s, counter + 1, counter) 441 | return X, counter 442 | 443 | X = jnp.full((num_firms, ), x_init) 444 | counter = jnp.zeros(num_firms) 445 | initial_state = X, counter 446 | X, counter = lax.fori_loop(0, sim_length, update, initial_state) 447 | 448 | return jnp.mean(counter > 1) 449 | 450 | compute_freq_fori_loop = jax.jit(compute_freq_fori_loop, static_argnums=((3, 4))) 451 | ``` 452 | 453 | Note the time the routine takes to run, as well as the output 454 | 455 | ```{code-cell} ipython3 456 | :hide-output: false 457 | 458 | %time freq = compute_freq_fori_loop(params, key).block_until_ready() 459 | ``` 460 | 461 | ```{code-cell} ipython3 462 | :hide-output: false 463 | 464 | %time freq = compute_freq_fori_loop(params, key).block_until_ready() 465 | ``` 466 | 467 | ```{code-cell} ipython3 468 | :hide-output: false 469 | 470 | print(f"Frequency of at least two stock outs = {freq}") 471 | ``` 472 | 473 | ```{code-cell} ipython3 474 | 475 | ``` 476 | -------------------------------------------------------------------------------- /source_files/monday/fun_with_jax.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | format_version: 0.13 7 | jupytext_version: 1.16.2 8 | kernelspec: 9 | display_name: Python 3 (ipykernel) 10 | language: python 11 | name: python3 12 | --- 13 | 14 | # Fun with JAX 15 | 16 | ----- 17 | 18 | #### [John Stachurski](https://johnstachurski.net/) 19 | 20 | ### Prepared for the CBC Computational Workshop (May 2024) 21 | 22 | ----- 23 | 24 | This notebook illustrates the power of [JAX](https://github.com/google/jax), a Python library built by Google Research. 25 | 26 | It should be run on a machine with a GPU --- for example, try Google Colab with the runtime environment set to include a GPU. 27 | 28 | The aim is just to give a small taste of high performance computing in Python -- details will be covered later in the course. 29 | 30 | +++ 31 | 32 | We start with some imports 33 | 34 | ```{code-cell} ipython3 35 | import numpy as np 36 | import scipy 37 | import jax 38 | import jax.numpy as jnp 39 | import matplotlib.pyplot as plt 40 | ``` 41 | 42 | Let's check our hardware: 43 | 44 | ```{code-cell} ipython3 45 | !nvidia-smi 46 | ``` 47 | 48 | ```{code-cell} ipython3 49 | !lscpu -e 50 | ``` 51 | 52 | ## Transforming Data 53 | 54 | +++ 55 | 56 | A very common numerical task is to apply a transformation to a set of data points. 57 | 58 | Our transformation will be the cosine function. 59 | 60 | +++ 61 | 62 | Here we evaluate the cosine function at 50 points. 63 | 64 | ```{code-cell} ipython3 65 | x = np.linspace(0, 10, 50) 66 | y = np.cos(x) 67 | ``` 68 | 69 | Let's plot. 
70 | 71 | ```{code-cell} ipython3 72 | fig, ax = plt.subplots() 73 | ax.scatter(x, y) 74 | plt.show() 75 | ``` 76 | 77 | Our aim is to evaluate the cosine function at many points. 78 | 79 | ```{code-cell} ipython3 80 | n = 50_000_000 81 | x = np.linspace(0, 10, n) 82 | ``` 83 | 84 | ### With NumPy 85 | 86 | ```{code-cell} ipython3 87 | %time np.cos(x) 88 | ``` 89 | 90 | ```{code-cell} ipython3 91 | %time np.cos(x) 92 | ``` 93 | 94 | The next line of code frees some memory -- can you explain why? 95 | 96 | ```{code-cell} ipython3 97 | x = None 98 | ``` 99 | 100 | ### With JAX 101 | 102 | ```{code-cell} ipython3 103 | x_jax = jnp.linspace(0, 10, n) 104 | ``` 105 | 106 | Let's run the same operation on JAX 107 | 108 | (The `block_until_ready()` method is explained a bit later.) 109 | 110 | ```{code-cell} ipython3 111 | %time jnp.cos(x_jax).block_until_ready() 112 | ``` 113 | 114 | ```{code-cell} ipython3 115 | %time jnp.cos(x_jax).block_until_ready() 116 | ``` 117 | 118 | ```{code-cell} ipython3 119 | x_jax = None # Free memory 120 | ``` 121 | 122 | ## Evaluating a more complicated function 123 | 124 | ```{code-cell} ipython3 125 | def f(x): 126 | y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - 0.1 * x**2 127 | return y 128 | ``` 129 | 130 | ```{code-cell} ipython3 131 | fig, ax = plt.subplots() 132 | x = np.linspace(0, 10, 100) 133 | ax.plot(x, f(x)) 134 | ax.scatter(x, f(x)) 135 | plt.show() 136 | ``` 137 | 138 | Now let's try with a large array. 139 | 140 | +++ 141 | 142 | ### With NumPy 143 | 144 | ```{code-cell} ipython3 145 | n = 50_000_000 146 | x = np.linspace(0, 10, n) 147 | ``` 148 | 149 | ```{code-cell} ipython3 150 | %time f(x) 151 | ``` 152 | 153 | ```{code-cell} ipython3 154 | %time f(x) 155 | ``` 156 | 157 | ### With JAX 158 | 159 | ```{code-cell} ipython3 160 | def f(x): 161 | y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2 162 | return y 163 | ``` 164 | 165 | ```{code-cell} ipython3 166 | x_jax = jnp.linspace(0, 10, n) 167 | ``` 168 | 169 | ```{code-cell} ipython3 170 | %time f(x_jax).block_until_ready() 171 | ``` 172 | 173 | ```{code-cell} ipython3 174 | %time f(x_jax).block_until_ready() 175 | ``` 176 | 177 | ### Compiling the Whole Function 178 | 179 | ```{code-cell} ipython3 180 | f_jax = jax.jit(f) 181 | ``` 182 | 183 | ```{code-cell} ipython3 184 | %time f_jax(x_jax).block_until_ready() 185 | ``` 186 | 187 | ```{code-cell} ipython3 188 | %time f_jax(x_jax).block_until_ready() 189 | ``` 190 | 191 | ## Solving Linear Systems 192 | 193 | ```{code-cell} ipython3 194 | np.random.seed(1234) 195 | n = 5_000 196 | A = np.random.randn(n, n) 197 | b = np.ones(n) 198 | ``` 199 | 200 | ```{code-cell} ipython3 201 | %time np.linalg.solve(A, b) 202 | ``` 203 | 204 | ```{code-cell} ipython3 205 | A, b = [jax.device_put(v) for v in (A, b)] 206 | ``` 207 | 208 | ```{code-cell} ipython3 209 | %time jnp.linalg.solve(A, b).block_until_ready() 210 | ``` 211 | 212 | ```{code-cell} ipython3 213 | %time jnp.linalg.solve(A, b).block_until_ready() 214 | ``` 215 | 216 | ## Nonlinear Equations 217 | 218 | +++ 219 | 220 | In many cases we want to solve a system of nonlinear equations. 221 | 222 | This section gives an example --- solving for an equilibrium price vector when supply and demand are nonlinear. 223 | 224 | We start with a simple two good market. 225 | 226 | Then we shift up to high dimensions. 227 | 228 | We will see that, in high dimensions, automatic differentiation and the GPU are very helpful. 
229 | 230 | +++ 231 | 232 | ### A Two Goods Market Equilibrium 233 | 234 | Let’s start by computing the market equilibrium of a two-good problem. 235 | 236 | Here's the excess demand function 237 | 238 | $$ 239 | e(p) = 240 | \begin{pmatrix} 241 | e_0(p_0, p_1) \\ 242 | e_1(p_0, p_1) 243 | \end{pmatrix} 244 | $$ 245 | 246 | An equilibrium price vector is a $p=(p_0, p_1)$ such that 247 | 248 | $$ 249 | e(p) = 0 250 | $$ 251 | 252 | The function below calculates the excess demand for given parameters 253 | 254 | ```{code-cell} ipython3 255 | :hide-output: false 256 | 257 | def e(p, A, b, c): 258 | "Excess demand is demand - supply at price vector p" 259 | return np.exp(- A @ p) + c - b * np.sqrt(p) 260 | ``` 261 | 262 | Our default parameter values will be 263 | 264 | $$ 265 | A = \begin{pmatrix} 266 | 0.5 & 0.4 \\ 267 | 0.8 & 0.2 268 | \end{pmatrix}, 269 | \qquad 270 | b = \begin{pmatrix} 271 | 1 \\ 272 | 1 273 | \end{pmatrix} 274 | \qquad \text{and} \qquad 275 | c = \begin{pmatrix} 276 | 1 \\ 277 | 1 278 | \end{pmatrix} 279 | $$ 280 | 281 | ```{code-cell} ipython3 282 | :hide-output: false 283 | 284 | A = np.array(((0.5, 0.4), 285 | (0.8, 0.2))) 286 | b = np.ones(2) 287 | c = np.ones(2) 288 | ``` 289 | 290 | Next we plot the two functions $ e_0 $ and $ e_1 $ on a grid of $ (p_0, p_1) $ values, using contour surfaces and lines. 291 | 292 | We will use the following function to build the contour plots 293 | 294 | ```{code-cell} ipython3 295 | :hide-output: false 296 | 297 | def plot_excess_demand(ax, good=0, grid_size=100, grid_max=4, surface=True): 298 | p_grid = np.linspace(0, grid_max, grid_size) 299 | z = np.empty((100, 100)) 300 | 301 | for i, p_1 in enumerate(p_grid): 302 | for j, p_2 in enumerate(p_grid): 303 | z[i, j] = e((p_1, p_2), A, b, c)[good] 304 | 305 | if surface: 306 | cs1 = ax.contourf(p_grid, p_grid, z.T, alpha=0.5) 307 | plt.colorbar(cs1, ax=ax, format="%.6f") 308 | 309 | ctr1 = ax.contour(p_grid, p_grid, z.T, levels=[0.0]) 310 | ax.set_xlabel("$p_0$") 311 | ax.set_ylabel("$p_1$") 312 | ax.set_title(f'Excess Demand for Good {good}') 313 | plt.clabel(ctr1, inline=1, fontsize=13) 314 | ``` 315 | 316 | Here’s our plot of $ e_0 $: 317 | 318 | ```{code-cell} ipython3 319 | :hide-output: false 320 | 321 | fig, ax = plt.subplots() 322 | plot_excess_demand(ax, good=0) 323 | plt.show() 324 | ``` 325 | 326 | Here’s our plot of $ e_1 $: 327 | 328 | ```{code-cell} ipython3 329 | :hide-output: false 330 | 331 | fig, ax = plt.subplots() 332 | plot_excess_demand(ax, good=1) 333 | plt.show() 334 | ``` 335 | 336 | We see the black contour line of zero, which tells us when $ e_i(p)=0 $. 337 | 338 | For a price vector $ p $ such that $ e_i(p)=0 $ we know that good $ i $ is in equilibrium (demand equals supply). 339 | 340 | If these two contour lines cross at some price vector $ p^* $, then $ p^* $ is an equilibrium price vector. 341 | 342 | ```{code-cell} ipython3 343 | :hide-output: false 344 | 345 | fig, ax = plt.subplots() 346 | for good in (0, 1): 347 | plot_excess_demand(ax, good=good, surface=False) 348 | plt.show() 349 | ``` 350 | 351 | It seems there is an equilibrium close to $ p = (1.6, 1.5) $. 352 | 353 | +++ 354 | 355 | #### Using a Multidimensional Root Finder 356 | 357 | To solve for $ p^* $ more precisely, we use a zero-finding algorithm from `scipy.optimize`. 358 | 359 | We supply $ p = (1, 1) $ as our initial guess. 
360 | 361 | ```{code-cell} ipython3 362 | :hide-output: false 363 | 364 | init_p = np.ones(2) 365 | ``` 366 | 367 | Now we use a standard hybrid algorithm to find the zero 368 | 369 | ```{code-cell} ipython3 370 | :hide-output: false 371 | 372 | solution = scipy.optimize.root(lambda p: e(p, A, b, c), init_p, method='hybr') 373 | ``` 374 | 375 | Here’s the resulting value: 376 | 377 | ```{code-cell} ipython3 378 | :hide-output: false 379 | 380 | p = solution.x 381 | p 382 | ``` 383 | 384 | This looks close to our guess from observing the figure. We can plug it back into $ e $ to test that $ e(p) \approx 0 $: 385 | 386 | ```{code-cell} ipython3 387 | :hide-output: false 388 | 389 | np.max(np.abs(e(p, A, b, c))) 390 | ``` 391 | 392 | This is indeed a very small error. 393 | 394 | +++ 395 | 396 | #### Adding Gradient Information 397 | 398 | In many cases, for zero-finding algorithms applied to smooth functions, supplying the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) of the function leads to better convergence properties. 399 | 400 | Here we manually calculate the elements of the Jacobian 401 | 402 | $$ 403 | J(p) = 404 | \begin{pmatrix} 405 | \frac{\partial e_0}{\partial p_0}(p) & \frac{\partial e_0}{\partial p_1}(p) \\ 406 | \frac{\partial e_1}{\partial p_0}(p) & \frac{\partial e_1}{\partial p_1}(p) 407 | \end{pmatrix} 408 | $$ 409 | 410 | ```{code-cell} ipython3 411 | :hide-output: false 412 | 413 | def jacobian_e(p, A, b, c): 414 | p_0, p_1 = p 415 | a_00, a_01 = A[0, :] 416 | a_10, a_11 = A[1, :] 417 | j_00 = -a_00 * np.exp(-a_00 * p_0) - (b[0]/2) * p_0**(-1/2) 418 | j_01 = -a_01 * np.exp(-a_01 * p_1) 419 | j_10 = -a_10 * np.exp(-a_10 * p_0) 420 | j_11 = -a_11 * np.exp(-a_11 * p_1) - (b[1]/2) * p_1**(-1/2) 421 | J = [[j_00, j_01], 422 | [j_10, j_11]] 423 | return np.array(J) 424 | ``` 425 | 426 | ```{code-cell} ipython3 427 | :hide-output: false 428 | 429 | solution = scipy.optimize.root(lambda p: e(p, A, b, c), 430 | init_p, 431 | jac=lambda p: jacobian_e(p, A, b, c), 432 | method='hybr') 433 | ``` 434 | 435 | Now the solution is even more accurate (although, in this low-dimensional problem, the difference is quite small): 436 | 437 | ```{code-cell} ipython3 438 | :hide-output: false 439 | 440 | p = solution.x 441 | np.max(np.abs(e(p, A, b, c))) 442 | ``` 443 | 444 | #### Newton’s Method via JAX 445 | 446 | We use a multivariate version of Newton’s method to compute the equilibrium price. 447 | 448 | The rule for updating a guess $ p_n $ of the equilibrium price vector is 449 | 450 | 451 | 452 | $$ 453 | p_{n+1} = p_n - J_e(p_n)^{-1} e(p_n) \tag{3.1} 454 | $$ 455 | 456 | Here $ J_e(p_n) $ is the Jacobian of $ e $ evaluated at $ p_n $. 457 | 458 | Iteration starts from initial guess $ p_0 $. 459 | 460 | Instead of coding the Jacobian by hand, we use automatic differentiation via `jax.jacobian()`. 461 | 462 | ```{code-cell} ipython3 463 | :hide-output: false 464 | 465 | def newton(f, x_0, tol=1e-5, max_iter=15): 466 | """ 467 | A multivariate Newton root-finding routine. 468 | 469 | """ 470 | x = x_0 471 | f_jac = jax.jacobian(f) 472 | @jax.jit 473 | def q(x): 474 | " Updates the current guess. 
" 475 | return x - jnp.linalg.solve(f_jac(x), f(x)) 476 | error = tol + 1 477 | n = 0 478 | while error > tol: 479 | n += 1 480 | if(n > max_iter): 481 | raise Exception('Max iteration reached without convergence') 482 | y = q(x) 483 | error = jnp.linalg.norm(x - y) 484 | x = y 485 | print(f'iteration {n}, error = {error}') 486 | return x 487 | ``` 488 | 489 | ```{code-cell} ipython3 490 | :hide-output: false 491 | 492 | @jax.jit 493 | def e(p, A, b, c): 494 | return jnp.exp(- A @ p) + c - b * jnp.sqrt(p) 495 | ``` 496 | 497 | ```{code-cell} ipython3 498 | :hide-output: false 499 | 500 | p = newton(lambda p: e(p, A, b, c), init_p) 501 | p 502 | ``` 503 | 504 | ```{code-cell} ipython3 505 | :hide-output: false 506 | 507 | jnp.max(jnp.abs(e(p, A, b, c))) 508 | ``` 509 | 510 | ### A High-Dimensional Problem 511 | 512 | Let’s now apply the method just described to investigate a large market with 5,000 goods. 513 | 514 | We randomly generate the matrix $ A $ and set the parameter vectors $ b, c $ to $ 1 $. 515 | 516 | ```{code-cell} ipython3 517 | :hide-output: false 518 | 519 | dim = 5_000 520 | seed = 32 521 | 522 | # Create a random matrix A and normalize the rows to sum to one 523 | key = jax.random.PRNGKey(seed) 524 | A = jax.random.uniform(key, (dim, dim)) 525 | s = jnp.sum(A, axis=0) 526 | A = A / s 527 | 528 | # Set up b and c 529 | b = jnp.ones(dim) 530 | c = jnp.ones(dim) 531 | ``` 532 | 533 | Here’s our initial condition $ p_0 $ 534 | 535 | ```{code-cell} ipython3 536 | :hide-output: false 537 | 538 | init_p = jnp.ones(dim) 539 | ``` 540 | 541 | By combining the power of Newton’s method, JAX accelerated linear algebra, 542 | automatic differentiation, and a GPU, we obtain a relatively small error for 543 | this high-dimensional problem in just a few seconds: 544 | 545 | ```{code-cell} ipython3 546 | :hide-output: false 547 | 548 | %time p = newton(lambda p: e(p, A, b, c), init_p).block_until_ready() 549 | ``` 550 | 551 | Here’s the size of the error: 552 | 553 | ```{code-cell} ipython3 554 | :hide-output: false 555 | 556 | jnp.max(jnp.abs(e(p, A, b, c))) 557 | ``` 558 | 559 | With the same tolerance, SciPy’s `root` function takes much longer to run. 560 | 561 | ```{code-cell} ipython3 562 | :hide-output: false 563 | 564 | %%time 565 | 566 | solution = scipy.optimize.root(lambda p: e(p, A, b, c), 567 | init_p, 568 | method='hybr', 569 | tol=1e-5) 570 | ``` 571 | 572 | ```{code-cell} ipython3 573 | :hide-output: false 574 | 575 | p = solution.x 576 | jnp.max(jnp.abs(e(p, A, b, c))) 577 | ``` 578 | -------------------------------------------------------------------------------- /source_files/monday/ai_revolution/ai_revolution.tex: -------------------------------------------------------------------------------- 1 | \documentclass[ 2 | xcolor={svgnames,dvipsnames}, 3 | hyperref={colorlinks, citecolor=DeepPink4, linkcolor=DarkRed, urlcolor=DarkBlue} 4 | ]{beamer} % for hardcopy add 'trans' 5 | 6 | 7 | \mode 8 | { 9 | \usetheme{Singapore} 10 | % or ... 11 | \setbeamercovered{transparent} 12 | % or whatever (possibly just delete it) 13 | } 14 | 15 | %\usefonttheme{professionalfonts} 16 | %\usepackage[english]{babel} 17 | % or whatever 18 | %\usepackage[latin1]{inputenc} 19 | % or whatever 20 | %\usepackage{times} 21 | %\usepackage[T1]{fontenc} 22 | % Or whatever. Note that the encoding and the font should match. If T1 23 | % does not look nice, try deleting the line with the fontenc. 
24 | 25 | %\usepackage{fontspec} 26 | %\setmonofont{CMU Typewriter Text} 27 | %\setmonofont{Consolas} 28 | 29 | \usepackage{fontspec} 30 | %\usepackage[xcharter]{newtxmath} 31 | %\setmainfont{XCharter} 32 | \usepackage{unicode-math} 33 | %\setmathfont{XCharter-Math.otf} 34 | \setmonofont{DejaVu Sans Mono}[Scale=MatchLowercase] % provides unicode characters 35 | 36 | 37 | 38 | %%%%%%%%%%%%%%%%%%%%%% start my preamble %%%%%%%%%%%%%%%%%%%%%% 39 | 40 | \addtobeamertemplate{navigation symbols}{}{% 41 | \usebeamerfont{footline}% 42 | \usebeamercolor[fg]{footline}% 43 | \hspace{1em}% 44 | \insertframenumber/\inserttotalframenumber 45 | } 46 | 47 | 48 | \usepackage{graphicx} 49 | \usepackage{amsmath, amssymb, amsthm} 50 | \usepackage{bbm} 51 | \usepackage{mathrsfs} 52 | \usepackage{xcolor} 53 | \usepackage{fancyvrb} 54 | 55 | 56 | % Quotes at start of chapters / sections 57 | \usepackage{epigraph} 58 | %\renewcommand{\epigraphflush}{flushleft} 59 | %\renewcommand{\sourceflush}{flushleft} 60 | \renewcommand{\epigraphwidth}{6in} 61 | 62 | %% Fonts 63 | 64 | %\usepackage[T1]{fontenc} 65 | \usepackage{mathpazo} 66 | %\usepackage{fontspec} 67 | %\defaultfontfeatures{Ligatures=TeX} 68 | %\setsansfont[Scale=MatchLowercase]{DejaVu Sans} 69 | %\setmonofont[Scale=MatchLowercase]{DejaVu Sans Mono} 70 | %\setmathfont{Asana Math} 71 | %\setmainfont{Optima} 72 | %\setmathrm{Optima} 73 | %\setboldmathrm[BoldFont={Optima ExtraBlack}]{Optima Bold} 74 | 75 | % Some colors 76 | 77 | \definecolor{aquamarine}{RGB}{69,139,116} 78 | \definecolor{midnightblue}{RGB}{25,25,112} 79 | \definecolor{darkslategrey}{RGB}{47,79,79} 80 | \definecolor{darkorange4}{RGB}{139,90,0} 81 | \definecolor{dogerblue}{RGB}{24,116,205} 82 | \definecolor{blue2}{RGB}{0,0,238} 83 | \definecolor{bg}{rgb}{0.95,0.95,0.95} 84 | \definecolor{DarkOrange1}{RGB}{255,127,0} 85 | \definecolor{ForestGreen}{RGB}{34,139,34} 86 | \definecolor{DarkRed}{RGB}{139, 0, 0} 87 | \definecolor{DarkBlue}{RGB}{0, 0, 139} 88 | \definecolor{Blue}{RGB}{0, 0, 255} 89 | \definecolor{Brown}{RGB}{165,42,42} 90 | 91 | 92 | \setlength{\parskip}{1.5ex plus0.5ex minus0.5ex} 93 | 94 | %\renewcommand{\baselinestretch}{1.05} 95 | %\setlength{\parskip}{1.5ex plus0.5ex minus0.5ex} 96 | %\setlength{\parindent}{0pt} 97 | 98 | % Typesetting code 99 | \definecolor{bg}{rgb}{0.95,0.95,0.95} 100 | \usepackage{minted} 101 | \setminted{mathescape, frame=lines, framesep=3mm} 102 | \usemintedstyle{friendly} 103 | %\newminted{python}{} 104 | %\newminted{c}{mathescape,frame=lines,framesep=4mm,bgcolor=bg} 105 | %\newminted{java}{mathescape,frame=lines,framesep=4mm,bgcolor=bg} 106 | %\newminted{julia}{mathescape,frame=lines,framesep=4mm,bgcolor=bg} 107 | %\newminted{ipython}{mathescape,frame=lines,framesep=4mm,bgcolor=bg} 108 | 109 | 110 | \newcommand{\Fact}{\textcolor{Brown}{\bf Fact. }} 111 | \newcommand{\Facts}{\textcolor{Brown}{\bf Facts }} 112 | \newcommand{\keya}{\textcolor{turquois4}{\bf Key Idea. }} 113 | \newcommand{\Factnodot}{\textcolor{Brown}{\bf Fact }} 114 | \newcommand{\Eg}{\textcolor{ForestGreen}{Example. }} 115 | \newcommand{\Egs}{\textcolor{ForestGreen}{Examples. }} 116 | \newcommand{\Ex}{{\bf Ex. 
}} 117 | 118 | 119 | 120 | \renewcommand{\theFancyVerbLine}{\sffamily 121 | \textcolor[rgb]{0.5,0.5,1.0}{\scriptsize {\arabic{FancyVerbLine}}}} 122 | 123 | \newcommand{\navy}[1]{\textcolor{Blue}{\bf #1}} 124 | \newcommand{\brown}[1]{\textcolor{Brown}{\sf #1}} 125 | \newcommand{\green}[1]{\textcolor{ForestGreen}{\sf #1}} 126 | \newcommand{\blue}[1]{\textcolor{Blue}{\sf #1}} 127 | \newcommand{\navymth}[1]{\textcolor{Blue}{#1}} 128 | \newcommand{\emp}[1]{\textcolor{DarkOrange1}{\bf #1}} 129 | \newcommand{\red}[1]{\textcolor{Red}{\bf #1}} 130 | 131 | % Symbols, redefines, etc. 132 | 133 | \newcommand{\code}[1]{\texttt{#1}} 134 | 135 | \newcommand{\argmax}{\operatornamewithlimits{argmax}} 136 | \newcommand{\argmin}{\operatornamewithlimits{argmin}} 137 | 138 | \DeclareMathOperator{\cl}{cl} 139 | \DeclareMathOperator{\interior}{int} 140 | \DeclareMathOperator{\Prob}{Prob} 141 | \DeclareMathOperator{\determinant}{det} 142 | \DeclareMathOperator{\trace}{trace} 143 | \DeclareMathOperator{\Span}{span} 144 | \DeclareMathOperator{\rank}{rank} 145 | \DeclareMathOperator{\cov}{cov} 146 | \DeclareMathOperator{\corr}{corr} 147 | \DeclareMathOperator{\var}{var} 148 | \DeclareMathOperator{\mse}{mse} 149 | \DeclareMathOperator{\se}{se} 150 | \DeclareMathOperator{\row}{row} 151 | \DeclareMathOperator{\col}{col} 152 | \DeclareMathOperator{\range}{rng} 153 | \DeclareMathOperator{\dimension}{dim} 154 | \DeclareMathOperator{\bias}{bias} 155 | 156 | 157 | % mics short cuts and symbols 158 | \newcommand{\st}{\ensuremath{\ \mathrm{s.t.}\ }} 159 | \newcommand{\setntn}[2]{ \{ #1 : #2 \} } 160 | \newcommand{\cf}[1]{ \lstinline|#1| } 161 | \newcommand{\fore}{\therefore \quad} 162 | \newcommand{\tod}{\stackrel { d } {\to} } 163 | \newcommand{\toprob}{\stackrel { p } {\to} } 164 | \newcommand{\toms}{\stackrel { ms } {\to} } 165 | \newcommand{\eqdist}{\stackrel {\textrm{ \scriptsize{d} }} {=} } 166 | \newcommand{\iidsim}{\stackrel {\textrm{ {\sc iid }}} {\sim} } 167 | \newcommand{\1}{\mathbbm 1} 168 | \newcommand{\dee}{\,{\rm d}} 169 | \newcommand{\given}{\, | \,} 170 | \newcommand{\la}{\langle} 171 | \newcommand{\ra}{\rangle} 172 | 173 | \newcommand{\boldA}{\mathbf A} 174 | \newcommand{\boldB}{\mathbf B} 175 | \newcommand{\boldC}{\mathbf C} 176 | \newcommand{\boldD}{\mathbf D} 177 | \newcommand{\boldM}{\mathbf M} 178 | \newcommand{\boldP}{\mathbf P} 179 | \newcommand{\boldQ}{\mathbf Q} 180 | \newcommand{\boldI}{\mathbf I} 181 | \newcommand{\boldX}{\mathbf X} 182 | \newcommand{\boldY}{\mathbf Y} 183 | \newcommand{\boldZ}{\mathbf Z} 184 | 185 | \newcommand{\bSigmaX}{ {\boldsymbol \Sigma_{\hboldbeta}} } 186 | \newcommand{\hbSigmaX}{ \mathbf{\hat \Sigma_{\hboldbeta}} } 187 | 188 | \newcommand{\RR}{\mathbbm R} 189 | \newcommand{\NN}{\mathbbm N} 190 | \newcommand{\PP}{\mathbbm P} 191 | \newcommand{\EE}{\mathbbm E \,} 192 | \newcommand{\XX}{\mathbbm X} 193 | \newcommand{\ZZ}{\mathbbm Z} 194 | \newcommand{\QQ}{\mathbbm Q} 195 | 196 | \newcommand{\fF}{\mathcal F} 197 | \newcommand{\dD}{\mathcal D} 198 | \newcommand{\lL}{\mathcal L} 199 | \newcommand{\gG}{\mathcal G} 200 | \newcommand{\hH}{\mathcal H} 201 | \newcommand{\nN}{\mathcal N} 202 | \newcommand{\pP}{\mathcal P} 203 | 204 | 205 | 206 | 207 | \title{Python and the AI Revolution} 208 | 209 | 210 | \author{John Stachurski} 211 | 212 | 213 | \date{May 2024} 214 | 215 | 216 | \begin{document} 217 | 218 | \begin{frame} 219 | \titlepage 220 | \end{frame} 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | \begin{frame} 229 | \frametitle{Topics} 230 | 231 | \begin{itemize} 232 | \item 
Deep learning and AI 233 | \vspace{0.5em} 234 | \item AI-driven scientific computing 235 | \vspace{0.5em} 236 | \item Where are we heading? 237 | \vspace{0.5em} 238 | \item How will that impact economic modeling for policy work? 239 | \end{itemize} 240 | 241 | 242 | \end{frame} 243 | 244 | 245 | 246 | \begin{frame} 247 | \frametitle{AI-driven scientific computing} 248 | 249 | AI is changing the world 250 | 251 | \begin{itemize} 252 | \item image processing / computer vision 253 | \vspace{0.5em} 254 | \item speech recognition, translation 255 | \vspace{0.5em} 256 | \item scientific knowledge discovery 257 | \vspace{0.5em} 258 | \item forecasting and prediction 259 | \vspace{0.5em} 260 | \item generative AI 261 | \end{itemize} 262 | 263 | \pause 264 | 265 | \vspace{0.5em} 266 | \vspace{0.5em} 267 | \vspace{0.5em} 268 | Plus killer drones, skynet, etc.\ldots 269 | 270 | 271 | \end{frame} 272 | 273 | \begin{frame} 274 | 275 | Projected spending on AI in 2024: 276 | 277 | \begin{itemize} 278 | \item Google: \$48 billion 279 | \vspace{0.5em} 280 | \item Microsoft: \$60 billion 281 | \vspace{0.5em} 282 | \item Meta: \$40 billion 283 | \vspace{0.5em} 284 | \item etc. 285 | \end{itemize} 286 | 287 | \pause 288 | \vspace{0.5em} 289 | \vspace{0.5em} 290 | \vspace{0.5em} 291 | \vspace{0.5em} 292 | \navy{Key point:} vast investment in AI is changing the production 293 | possibility frontier for \emp{all} scientific coders 294 | 295 | \end{frame} 296 | 297 | 298 | \begin{frame} 299 | 300 | 301 | Platforms / libraries 302 | 303 | \begin{itemize} 304 | \item PyTorch (ChatGPT, LLaMA 3, Github Copilot) 305 | \vspace{0.5em} 306 | \item Google JAX (Gemini) 307 | \vspace{0.5em} 308 | \item Tensorflow? 309 | \vspace{0.5em} 310 | \item Mojo (by Modular)? 311 | \end{itemize} 312 | 313 | \end{frame} 314 | 315 | 316 | \begin{frame} 317 | \frametitle{Deep learning in two slides} 318 | 319 | Supervised deep learning: find a good approximation to an unknown functional 320 | relationship 321 | % 322 | \begin{equation*} 323 | y = f(x) 324 | \qquad (x \in \RR^d, \; y \in \RR) 325 | \end{equation*} 326 | 327 | \Egs 328 | % 329 | \begin{itemize} 330 | \item $x = $ sequence of words, $y = $ next word 331 | \vspace{0.5em} 332 | \item $x = $ weather sensor data, $y = $ max temp tomorrow 333 | \end{itemize} 334 | \vspace{0.5em} 335 | \vspace{0.5em} 336 | 337 | Problem: 338 | 339 | \begin{itemize} 340 | \item observe $(x_i, y_i)_{i=1}^n$ and seek $f$ such that $y_{n+1} 341 | \approx f(x_{n+1})$ 342 | \end{itemize} 343 | 344 | 345 | \end{frame} 346 | 347 | 348 | \begin{frame} 349 | 350 | Nonlinear regression: minimize the empirical loss 351 | % 352 | \begin{equation*} 353 | \ell(\theta) := \sum_{i=1}^n (y_i - f_\theta(x_i))^2 354 | \quad \st \quad \theta \in \Theta 355 | \end{equation*} 356 | 357 | 358 | \pause 359 | But what is $\{f_\theta\}_{\theta \in \Theta}$? 360 | 361 | \pause 362 | \vspace{0.5em} 363 | In the case of ANNs, we consider all $f_\theta$ having the form 364 | % 365 | \begin{equation*} 366 | f_\theta 367 | = \sigma \circ A_{1} 368 | \circ \cdots \circ \sigma \circ A_{k-1} \circ \sigma \circ A_{k} 369 | \end{equation*} 370 | % 371 | where 372 | % 373 | \begin{itemize} 374 | \item $A_{i} x = W_i x + b_i $ is an affine map 375 | \vspace{0.5em} 376 | \item $\sigma$ is a nonlinear ``activation'' function 377 | \end{itemize} 378 | 379 | \end{frame} 380 | 381 | 382 | \begin{frame} 383 | 384 | 385 | Minimizing a smooth loss functions -- what algorithm? 
386 | 387 | \begin{figure} 388 | \begin{center} 389 | \scalebox{0.15}{\includegraphics[trim={0cm 0cm 0cm 0cm},clip]{gdi.png}} 390 | \end{center} 391 | \end{figure} 392 | 393 | Source: \url{https://danielkhv.com/} 394 | 395 | \end{frame} 396 | 397 | 398 | \begin{frame} 399 | 400 | Deep learning: $\theta \in \RR^d$ where $d = ?$ 401 | 402 | \begin{figure} 403 | \begin{center} 404 | \scalebox{0.14}{\includegraphics[trim={0cm 0cm 0cm 0cm},clip]{loss2.jpg}} 405 | \end{center} 406 | \end{figure} 407 | 408 | Source: \url{https://losslandscape.com/gallery/} 409 | 410 | \end{frame} 411 | 412 | \begin{frame} 413 | 414 | But what about the curse of dimensionality!??? 415 | 416 | \end{frame} 417 | 418 | \begin{frame} 419 | \frametitle{Software} 420 | 421 | \begin{figure} 422 | \begin{center} 423 | \scalebox{0.032}{\includegraphics[trim={0cm 0cm 0cm 0cm},clip]{python.png}} 424 | \end{center} 425 | \end{figure} 426 | 427 | 428 | \end{frame} 429 | 430 | \begin{frame} 431 | 432 | 433 | Core elements 434 | % 435 | \begin{itemize} 436 | \item automatic differentiation (for \underline{gradient} descent) 437 | \vspace{0.5em} 438 | \item parallelization (GPUs! --- how many?) 439 | \vspace{0.5em} 440 | \item Compilers / JIT-compilers 441 | \end{itemize} 442 | 443 | \vspace{0.5em} 444 | \vspace{0.5em} 445 | \pause 446 | Crucially, these components are all \underline{integrated} 447 | 448 | \begin{itemize} 449 | \item autodiff is JIT compiled 450 | \vspace{0.5em} 451 | \item JIT compiled functions are automatically parallelized 452 | \vspace{0.5em} 453 | \item etc. 454 | \end{itemize} 455 | 456 | \end{frame} 457 | 458 | 459 | \begin{frame} 460 | 461 | 462 | \begin{figure} 463 | \begin{center} 464 | \scalebox{2.0}{\includegraphics[trim={0cm 0cm 0cm 0cm},clip]{jax.png}} 465 | \end{center} 466 | \end{figure} 467 | 468 | 469 | \end{frame} 470 | 471 | 472 | \begin{frame}[fragile] 473 | 474 | \vspace{-1em} 475 | \begin{minted}{python} 476 | import jax.numpy as jnp 477 | from jax import grad, jit 478 | 479 | def f(θ, x): 480 | for W, b in θ: 481 | w = W @ x + b 482 | x = jnp.tanh(w) 483 | return x 484 | 485 | def loss(θ, x, y): 486 | return jnp.sum((y - f(θ, x))**2) 487 | 488 | grad_loss = jit(grad(loss)) # Now use gradient descent 489 | \end{minted} 490 | 491 | {\footnotesize Source: JAX readthedocs} 492 | 493 | \end{frame} 494 | 495 | \begin{frame} 496 | \frametitle{Hardware} 497 | 498 | \begin{figure} 499 | \begin{center} 500 | \scalebox{0.18}{\includegraphics[trim={0cm 0cm 0cm 0cm},clip]{dgx.png}} 501 | \end{center} 502 | \end{figure} 503 | 504 | 505 | \end{frame} 506 | 507 | 508 | \begin{frame} 509 | 510 | ``NVIDIA today announced its next-generation AI supercomputer — the NVIDIA 511 | DGX SuperPOD powered by GB200 Grace Blackwell Superchips — for 512 | processing trillion-parameter models for superscale 513 | generative AI training and inference workloads...'' 514 | \vspace{0.5em} 515 | \vspace{0.5em} 516 | 517 | \vspace{0.5em} 518 | \vspace{0.5em} 519 | ``NVIDIA supercomputers are the factories of the AI industrial 520 | revolution.'' -- Jensen Huang 521 | 522 | \end{frame} 523 | 524 | 525 | 526 | \begin{frame} 527 | \frametitle{Example: Weather forecasting} 528 | 529 | 530 | ``ECMWF's model is considered the gold standard for 531 | medium-term weather forecasting\ldots '' 532 | 533 | \pause 534 | 535 | \vspace{0.5em} 536 | Google DeepMind claims to now beat it 90\% of the time\ldots 537 | \vspace{0.5em} 538 | 539 | ``Traditional forecasting models are big, complex computer algorithms based 540 | on atmospheric 
physics and take hours to run. AI models can create forecasts 541 | in just seconds.'' 542 | \vspace{0.5em} 543 | \vspace{0.5em} 544 | 545 | $\quad \qquad$$\quad \qquad$ Source: MIT Technology Review 546 | 547 | \end{frame} 548 | 549 | 550 | \begin{frame} 551 | 552 | 553 | Relevant to economics? 554 | 555 | \vspace{0.5em} 556 | \pause 557 | Deep learning provides massively powerful pattern recognition 558 | 559 | \pause 560 | \vspace{0.5em} 561 | \vspace{0.5em} 562 | \vspace{0.5em} 563 | \navy{But} macroeconomic data is 564 | 565 | \begin{itemize} 566 | \item extremely limited 567 | \vspace{0.5em} 568 | \item generally nonstationary 569 | \vspace{0.5em} 570 | \item sensitive to policy changes (Lucas critique) 571 | \end{itemize} 572 | 573 | 574 | 575 | \end{frame} 576 | 577 | 578 | \begin{frame} 579 | 580 | My view 581 | 582 | \begin{itemize} 583 | \item Policy-centric macroeconomic modeling will survive much longer 584 | than traditional weather forecasting 585 | \vspace{0.5em} 586 | \item Deep learning is yet to prove itself as a ``better'' approach to 587 | numerical methods 588 | \vspace{0.5em} 589 | \end{itemize} 590 | 591 | \pause 592 | And yet, 593 | 594 | \begin{itemize} 595 | \item \brown{the AI computing revolution is 596 | generating tools that are enormously beneficial for macroeconomic 597 | modeling} 598 | \vspace{0.5em} 599 | \begin{itemize} 600 | \item autodiff, JIT compilers, parallelization, GPUs, etc. 601 | \end{itemize} 602 | \vspace{0.5em} 603 | \item We can take full advantage of them right now! 604 | \end{itemize} 605 | 606 | 607 | \end{frame} 608 | 609 | \end{document} 610 | 611 | 612 | -------------------------------------------------------------------------------- /source_files/thursday/opt_savings_2.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | format_version: 0.13 7 | jupytext_version: 1.16.1 8 | kernelspec: 9 | display_name: Python 3 (ipykernel) 10 | language: python 11 | name: python3 12 | --- 13 | 14 | # Optimal Savings II: Alternative Algorithms 15 | 16 | ----- 17 | 18 | #### John Stachurski 19 | 20 | #### Prepared for the CBC Workshop (May 2024) 21 | 22 | ----- 23 | 24 | In `opt_savings_1.ipynb` we solved a simple version of the household optimal 25 | savings problem via value function iteration (VFI) using JAX. 26 | 27 | In this lecture we tackle exactly the same problem while adding in two 28 | alternative algorithms: 29 | 30 | * optimistic policy iteration (OPI) and 31 | * Howard policy iteration (HPI). 32 | 33 | We will see that both of these algorithms outperform traditional VFI. 34 | 35 | One reason for this is that the algorithms have good convergence properties. 36 | 37 | Another is that one of them, HPI, is particularly well suited to pairing with 38 | JAX. 39 | 40 | The reason is that HPI uses a relatively small number of computationally expensive steps, 41 | whereas VFI uses a longer sequence of small steps. 42 | 43 | In other words, VFI is inherently more sequential than HPI, and sequential 44 | routines are hard to parallelize. 45 | 46 | By comparison, HPI is less sequential -- the small number of computationally 47 | intensive steps can be effectively parallelized by JAX. 48 | 49 | This is particularly valuable when the underlying hardware includes a GPU. 50 | 51 | Details on VFI, HPI and OPI can be found in [this book](https://dp.quantecon.org), for which a PDF is freely available. 
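In brief: VFI iterates the Bellman operator, $v \leftarrow Tv$, and then takes a $v$-greedy policy; HPI alternates greedy policy improvement with exact policy evaluation (solving $L_\sigma v_\sigma = r_\sigma$ for the current policy $\sigma$); and OPI replaces the exact evaluation step with $m$ applications of the policy operator $T_\sigma$, so that, roughly speaking, VFI and HPI correspond to the cases $m = 1$ and $m \to \infty$. (The operators $T$, $T_\sigma$ and $L_\sigma$ are all defined below.)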
52 | 53 | Here we assume readers have some knowledge of the algorithms and focus on 54 | computation. 55 | 56 | 57 | ---- 58 | 59 | Uncomment if necessary: 60 | 61 | ```{code-cell} 62 | #!pip install quantecon 63 | ``` 64 | 65 | We will use the following imports: 66 | 67 | ```{code-cell} 68 | import quantecon as qe 69 | import jax 70 | import jax.numpy as jnp 71 | from collections import namedtuple 72 | import matplotlib.pyplot as plt 73 | import time 74 | ``` 75 | 76 | Let's check the GPU we are running. 77 | 78 | ```{code-cell} 79 | !nvidia-smi 80 | ``` 81 | 82 | We'll use 64 bit floats to gain extra precision. 83 | 84 | ```{code-cell} 85 | jax.config.update("jax_enable_x64", True) 86 | ``` 87 | 88 | ## Model primitives 89 | 90 | We start with a namedtuple to store parameters and arrays 91 | 92 | ```{code-cell} 93 | Model = namedtuple('Model', ('β', 'R', 'γ', 'w_grid', 'y_grid', 'Q')) 94 | ``` 95 | 96 | The following code is repeated from `opt_savings_1`. 97 | 98 | ```{code-cell} 99 | def create_consumption_model(R=1.01, # Gross interest rate 100 | β=0.98, # Discount factor 101 | γ=2, # CRRA parameter 102 | w_min=0.01, # Min wealth 103 | w_max=5.0, # Max wealth 104 | w_size=150, # Grid size 105 | ρ=0.9, ν=0.1, y_size=100): # Income parameters 106 | """ 107 | A function that takes in parameters and returns parameters and grids 108 | for the optimal savings problem. 109 | """ 110 | w_grid = jnp.linspace(w_min, w_max, w_size) 111 | mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν) 112 | y_grid, Q = jnp.exp(mc.state_values), jax.device_put(mc.P) 113 | return Model(β, R, γ, w_grid, y_grid, Q) 114 | ``` 115 | 116 | Here's the right hand side of the Bellman equation: 117 | 118 | ```{code-cell} 119 | def _B(v, model, i, j, ip): 120 | """ 121 | The right-hand side of the Bellman equation before maximization, which takes 122 | the form 123 | 124 | B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′) 125 | 126 | The indices are (i, j, ip) -> (w, y, w′). 127 | """ 128 | β, R, γ, w_grid, y_grid, Q = model 129 | w, y, wp = w_grid[i], y_grid[j], w_grid[ip] 130 | c = R * w + y - wp 131 | EV = jnp.sum(v[ip, :] * Q[j, :]) 132 | return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf) 133 | ``` 134 | 135 | Now we successively apply `vmap` to vectorize $B$ by simulating nested loops. 136 | 137 | ```{code-cell} 138 | B_vmap = jax.vmap(_B, in_axes=(None, None, None, None, 0)) 139 | B_vmap = jax.vmap(B_vmap, in_axes=(None, None, None, 0, None)) 140 | B_vmap = jax.vmap(B_vmap, in_axes=(None, None, 0, None, None)) 141 | ``` 142 | 143 | Here's a fully vectorized version of $B$. 144 | 145 | ```{code-cell} 146 | @jax.jit 147 | def B(v, model): 148 | β, R, γ, w_grid, y_grid, Q = model 149 | w_size, y_size = len(w_grid), len(y_grid) 150 | w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size) 151 | return B_vmap(v, model, w_indices, y_indices, w_indices) 152 | ``` 153 | 154 | ## Operators 155 | 156 | 157 | Here's the Bellman operator $T$ 158 | 159 | ```{code-cell} 160 | @jax.jit 161 | def T(v, model): 162 | "The Bellman operator." 163 | return jnp.max(B(v, model), axis=-1) 164 | ``` 165 | 166 | The next function computes a $v$-greedy policy given $v$ 167 | 168 | ```{code-cell} 169 | @jax.jit 170 | def get_greedy(v, modeld): 171 | "Computes a v-greedy policy, returned as a set of indices." 
172 | return jnp.argmax(B(v, model), axis=-1) 173 | ``` 174 | 175 | We define a function to compute the current rewards $r_\sigma$ given policy $\sigma$, 176 | which is defined as the vector 177 | 178 | $$ 179 | r_\sigma(w, y) := r(w, y, \sigma(w, y)) 180 | $$ 181 | 182 | ```{code-cell} 183 | def _compute_r_σ(σ, model, i, j): 184 | """ 185 | With indices (i, j) -> (w, y) and wp = σ[i, j], compute 186 | 187 | r_σ[i, j] = u(Rw + y - wp) 188 | 189 | which gives current rewards under policy σ. 190 | """ 191 | 192 | # Unpack model 193 | β, R, γ, w_grid, y_grid, Q = model 194 | # Compute r_σ[i, j] 195 | w, y, wp = w_grid[i], y_grid[j], w_grid[σ[i, j]] 196 | c = R * w + y - wp 197 | r_σ = c**(1-γ)/(1-γ) 198 | 199 | return r_σ 200 | ``` 201 | 202 | Now we successively apply `vmap` to simulate nested loops. 203 | 204 | ```{code-cell} 205 | compute_r_σ_vmap = jax.vmap(_compute_r_σ, in_axes=(None, None, None, 0)) 206 | compute_r_σ_vmap = jax.vmap(compute_r_σ_vmap, in_axes=(None, None, 0, None)) 207 | ``` 208 | 209 | Here's a fully vectorized version of $r_\sigma$. 210 | 211 | ```{code-cell} 212 | @jax.jit 213 | def compute_r_σ(σ, model): 214 | β, R, γ, w_grid, y_grid, Q = model 215 | w_size, y_size = len(w_grid), len(y_grid) 216 | w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size) 217 | return compute_r_σ_vmap(σ, model, w_indices, y_indices) 218 | ``` 219 | 220 | Now we define the policy operator $T_\sigma$ going through similar steps 221 | 222 | ```{code-cell} 223 | def _T_σ(v, σ, model, i, j): 224 | "The σ-policy operator." 225 | 226 | # Unpack model 227 | β, R, γ, w_grid, y_grid, Q = model 228 | 229 | r_σ = _compute_r_σ(σ, model, i, j) 230 | # Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp] 231 | EV = jnp.sum(v[σ[i, j], :] * Q[j, :]) 232 | 233 | return r_σ + β * EV 234 | 235 | 236 | T_σ_vmap = jax.vmap(_T_σ, in_axes=(None, None, None, None, 0)) 237 | T_σ_vmap = jax.vmap(T_σ_vmap, in_axes=(None, None, None, 0, None)) 238 | 239 | 240 | @jax.jit 241 | def T_σ(v, σ, model): 242 | β, R, γ, w_grid, y_grid, Q = model 243 | w_size, y_size = len(w_grid), len(y_grid) 244 | w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size) 245 | return T_σ_vmap(v, σ, model, w_indices, y_indices) 246 | ``` 247 | 248 | The function below computes the value $v_\sigma$ of following policy $\sigma$. 249 | 250 | This lifetime value is a function $v_\sigma$ that satisfies 251 | 252 | $$ 253 | v_\sigma(w, y) = r_\sigma(w, y) + \beta \sum_{y'} v_\sigma(\sigma(w, y), y') Q(y, y') 254 | $$ 255 | 256 | We wish to solve this equation for $v_\sigma$. 257 | 258 | Suppose we define the linear operator $L_\sigma$ by 259 | 260 | $$ 261 | (L_\sigma v)(w, y) = v(w, y) - \beta \sum_{y'} v(\sigma(w, y), y') Q(y, y') 262 | $$ 263 | 264 | With this notation, the problem is to solve for $v$ via 265 | 266 | $$ 267 | (L_{\sigma} v)(w, y) = r_\sigma(w, y) 268 | $$ 269 | 270 | In vector for this is $L_\sigma v = r_\sigma$, which tells us that the function 271 | we seek is 272 | 273 | $$ 274 | v_\sigma = L_\sigma^{-1} r_\sigma 275 | $$ 276 | 277 | JAX allows us to solve linear systems defined in terms of operators; the first 278 | step is to define the function $L_{\sigma}$. 
279 | 280 | ```{code-cell} 281 | def _L_σ(v, σ, model, i, j): 282 | """ 283 | Here we set up the linear map v -> L_σ v, where 284 | 285 | (L_σ v)(w, y) = v(w, y) - β Σ_y′ v(σ(w, y), y′) Q(y, y′) 286 | 287 | """ 288 | # Unpack 289 | β, R, γ, w_grid, y_grid, Q = model 290 | # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp] 291 | return v[i, j] - β * jnp.sum(v[σ[i, j], :] * Q[j, :]) 292 | 293 | L_σ_vmap = jax.vmap(_L_σ, in_axes=(None, None, None, None, 0)) 294 | L_σ_vmap = jax.vmap(L_σ_vmap, in_axes=(None, None, None, 0, None)) 295 | 296 | @jax.jit 297 | def L_σ(v, σ, model): 298 | β, R, γ, w_grid, y_grid, Q = model 299 | w_size, y_size = len(w_grid), len(y_grid) 300 | w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size) 301 | return L_σ_vmap(v, σ, model, w_indices, y_indices) 302 | ``` 303 | 304 | Now we can define a function to compute $v_{\sigma}$ 305 | 306 | ```{code-cell} 307 | @jax.jit 308 | def get_value(σ, model): 309 | "Get the value v_σ of policy σ by inverting the linear map L_σ." 310 | 311 | r_σ = compute_r_σ(σ, model) 312 | partial_L_σ = lambda v: L_σ(v, σ, model) 313 | return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0] 314 | ``` 315 | 316 | ## Iteration 317 | 318 | 319 | We use successive approximation for VFI. 320 | 321 | ```{code-cell} 322 | def successive_approx_jax(T, # Operator (callable) 323 | x_0, # Initial condition 324 | tol=1e-6, # Error tolerance 325 | max_iter=10_000): # Max iteration bound 326 | def update(inputs): 327 | k, x, error = inputs 328 | x_new = T(x) 329 | error = jnp.max(jnp.abs(x_new - x)) 330 | return k + 1, x_new, error 331 | 332 | def condition_function(inputs): 333 | k, x, error = inputs 334 | return jnp.logical_and(error > tol, k < max_iter) 335 | 336 | k, x, error = jax.lax.while_loop(condition_function, 337 | update, 338 | (1, x_0, tol + 1)) 339 | return x 340 | 341 | successive_approx_jax = jax.jit(successive_approx_jax, static_argnums=(0,)) 342 | ``` 343 | 344 | For OPI we'll add a compiled routine that computes $T_σ^m v$. 345 | 346 | ```{code-cell} 347 | @jax.jit 348 | def iterate_policy_operator(σ, v, m, model): 349 | 350 | def update(i, v): 351 | v = T_σ(v, σ, model) 352 | return v 353 | 354 | v = jax.lax.fori_loop(0, m, update, v) 355 | return v 356 | 357 | 358 | ``` 359 | 360 | ## Solvers 361 | 362 | Now we define the solvers, which implement VFI, HPI and OPI. 363 | 364 | Here's VFI. 365 | 366 | ```{code-cell} 367 | def value_function_iteration(model, tol=1e-4): 368 | """ 369 | Implements value function iteration. 370 | """ 371 | β, R, γ, w_grid, y_grid, Q = model 372 | sizes = len(w_grid), len(y_grid) 373 | vz = jnp.zeros(sizes) 374 | _T = lambda v: T(v, model) 375 | v_star = successive_approx_jax(_T, vz, tol=tol) 376 | return get_greedy(v_star, model) 377 | ``` 378 | 379 | For OPI we will use a compiled JAX `lax.while_loop` operation to speed execution. 380 | 381 | ```{code-cell} 382 | def opi_loop(model, m, tol, max_iter): 383 | """ 384 | Implements optimistic policy iteration (see dp.quantecon.org) with 385 | step size m. 
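    Arguments: `model` packs the parameters and grids, `m` sets the number of
    policy operator applications per greedy update, and iteration stops when
    the sup-norm change in the value array falls below `tol` or after
    `max_iter` greedy updates.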
386 | 387 | """ 388 | β, R, γ, w_grid, y_grid, Q = model 389 | sizes = len(w_grid), len(y_grid) 390 | v_init = jnp.zeros(sizes) 391 | 392 | def condition_function(inputs): 393 | i, v, error = inputs 394 | return jnp.logical_and(error > tol, i < max_iter) 395 | 396 | def update(inputs): 397 | i, v, error = inputs 398 | last_v = v 399 | σ = get_greedy(v, model) 400 | v = iterate_policy_operator(σ, v, m, model) 401 | error = jnp.max(jnp.abs(v - last_v)) 402 | i += 1 403 | return i, v, error 404 | 405 | num_iter, v, error = jax.lax.while_loop(condition_function, 406 | update, 407 | (0, v_init, tol + 1)) 408 | 409 | return get_greedy(v, model) 410 | 411 | opi_loop = jax.jit(opi_loop, static_argnums=(1,)) 412 | ``` 413 | 414 | Here's a friendly interface to OPI 415 | 416 | ```{code-cell} 417 | def optimistic_policy_iteration(model, m=10, tol=1e-4, max_iter=10_000): 418 | σ_star = opi_loop(model, m, tol, max_iter) 419 | return σ_star 420 | ``` 421 | 422 | Here's HPI. 423 | 424 | ```{code-cell} 425 | def howard_policy_iteration(model, tol=1e-4, maxiter=250): 426 | """ 427 | Implements Howard policy iteration (see dp.quantecon.org) 428 | """ 429 | β, R, γ, w_grid, y_grid, Q = model 430 | sizes = len(w_grid), len(y_grid) 431 | v_σ = jnp.zeros(sizes) 432 | i, error = 0, 1.0 433 | while error > tol and i < maxiter: 434 | σ = get_greedy(v_σ, model) 435 | v_σ_new = get_value(σ, model) 436 | error = jnp.max(jnp.abs(v_σ_new - v_σ)) 437 | v_σ = v_σ_new 438 | i = i + 1 439 | print(f"Concluded loop {i} with error {error}.") 440 | return σ 441 | ``` 442 | 443 | ## Plots 444 | 445 | Create a model for consumption, perform policy iteration, and plot the resulting optimal policy function. 446 | 447 | ```{code-cell} 448 | model = create_consumption_model() 449 | β, R, γ, w_grid, y_grid, Q = model 450 | ``` 451 | 452 | ```{code-cell} 453 | σ_star = howard_policy_iteration(model) 454 | 455 | fig, ax = plt.subplots() 456 | ax.plot(w_grid, w_grid, "k--", label="45") 457 | ax.plot(w_grid, w_grid[σ_star[:, 1]], label="$\\sigma^*(\cdot, y_1)$") 458 | ax.plot(w_grid, w_grid[σ_star[:, -1]], label="$\\sigma^*(\cdot, y_N)$") 459 | ax.legend() 460 | plt.show() 461 | ``` 462 | 463 | ## Tests 464 | 465 | Let's create an instance of the model. 466 | 467 | ```{code-cell} 468 | model = create_consumption_model() 469 | ``` 470 | 471 | Here's a function that runs any one of the algorithms and returns the result and 472 | elapsed time. 473 | 474 | ```{code-cell} 475 | def run_algorithm(algorithm, model, **kwargs): 476 | start_time = time.time() 477 | result = algorithm(model, **kwargs) 478 | end_time = time.time() 479 | elapsed_time = end_time - start_time 480 | print(f"{algorithm.__name__} completed in {elapsed_time:.2f} seconds.") 481 | return result, elapsed_time 482 | ``` 483 | 484 | Here's a quick test of each model. 
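Each algorithm is run twice: the first run includes JIT compilation time, while the second run measures execution time alone.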
485 | 486 | HPI first run: 487 | 488 | ```{code-cell} 489 | σ_pi, pi_time = run_algorithm(howard_policy_iteration, 490 | model) 491 | ``` 492 | 493 | HPI second run: 494 | 495 | ```{code-cell} 496 | σ_pi, pi_time = run_algorithm(howard_policy_iteration, 497 | model) 498 | ``` 499 | 500 | VFI first run: 501 | 502 | ```{code-cell} 503 | print("Starting VFI.") 504 | σ_vfi, vfi_time = run_algorithm(value_function_iteration, 505 | model, tol=1e-4) 506 | ``` 507 | 508 | VFI second run: 509 | 510 | ```{code-cell} 511 | print("Starting VFI.") 512 | σ_vfi, vfi_time = run_algorithm(value_function_iteration, 513 | model, tol=1e-4) 514 | ``` 515 | 516 | OPI first run: 517 | 518 | ```{code-cell} 519 | m = 100 520 | print(f"Starting OPI with $m = {m}$.") 521 | σ_opi, opi_time = run_algorithm(optimistic_policy_iteration, 522 | model, m=m, tol=1e-4) 523 | ``` 524 | 525 | OPI second run: 526 | 527 | ```{code-cell} 528 | m = 100 529 | print(f"Starting OPI with $m = {m}$.") 530 | σ_opi, opi_time = run_algorithm(optimistic_policy_iteration, 531 | model, m=m, tol=1e-4) 532 | ``` 533 | 534 | Now let's run OPI at a range of $m$ values and plot the execution time along 535 | side the execution time for VFI and HPI. 536 | 537 | ```{code-cell} 538 | σ_pi, pi_time = run_algorithm(howard_policy_iteration, model) 539 | σ_vfi, vfi_time = run_algorithm(value_function_iteration, model, tol=1e-4) 540 | m_vals = range(5, 600, 40) 541 | opi_times = [] 542 | for m in m_vals: 543 | σ_opi, opi_time = run_algorithm(optimistic_policy_iteration, 544 | model, m=m, tol=1e-4) 545 | opi_times.append(opi_time) 546 | ``` 547 | 548 | Here's the plot. 549 | 550 | ```{code-cell} 551 | fig, ax = plt.subplots() 552 | ax.plot(m_vals, 553 | jnp.full(len(m_vals), pi_time), 554 | lw=2, label="Howard policy iteration") 555 | ax.plot(m_vals, 556 | jnp.full(len(m_vals), vfi_time), 557 | lw=2, label="value function iteration") 558 | ax.plot(m_vals, opi_times, 559 | lw=2, label="optimistic policy iteration") 560 | ax.legend(frameon=False) 561 | ax.set_xlabel("$m$") 562 | ax.set_ylabel("time") 563 | plt.show() 564 | ``` 565 | -------------------------------------------------------------------------------- /source_files/thursday/egm.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | text_representation: 4 | extension: .md 5 | format_name: myst 6 | format_version: 0.13 7 | jupytext_version: 1.16.1 8 | kernelspec: 9 | display_name: Python 3 (ipykernel) 10 | language: python 11 | name: python3 12 | --- 13 | 14 | # Endogenous Grid Method 15 | 16 | ---- 17 | 18 | #### John Stachurski 19 | #### Prepared for the CBC Computational Workshop (May 2024) 20 | 21 | ---- 22 | 23 | ## Overview 24 | 25 | In this lecture we use the endogenous grid method (EGM) to solve a basic optimal savings problem. 26 | 27 | Our main implementation is in JAX, although we also run a Numba version for comparison. 28 | 29 | Our treatment of EGM is quite brief -- we just outline the algorithm. 30 | 31 | Readers who want more background can see [this lecture](https://python.quantecon.org/egm_policy_iter.html). 32 | 33 | 34 | Uncomment if necessary: 35 | 36 | ```{code-cell} 37 | #!pip install --upgrade quantecon 38 | ``` 39 | 40 | Let's run some imports 41 | 42 | ```{code-cell} 43 | import quantecon as qe 44 | from collections import namedtuple 45 | import matplotlib.pyplot as plt 46 | import numpy as np 47 | import jax 48 | import jax.numpy as jnp 49 | import numba 50 | ``` 51 | 52 | What GPU are we running? 
53 | 54 | ```{code-cell} 55 | !nvidia-smi 56 | ``` 57 | 58 | We use 64 bit floating point numbers for extra precision. 59 | 60 | ```{code-cell} 61 | jax.config.update("jax_enable_x64", True) 62 | ``` 63 | 64 | ## Setup 65 | 66 | Consider a household that chooses $\{c_t\}_{t \geq 0}$ to maximize 67 | 68 | $$ 69 | \mathbb{E} \, \sum_{t=0}^{\infty} \beta^t u(c_t) 70 | $$ 71 | 72 | subject to 73 | 74 | $$ 75 | a_{t+1} = R(a_t - c_t) + Y_{t+1} 76 | \quad \text{and} 77 | \quad 0 \leq c_t \leq a_t \quad \text{for } t \geq 0 78 | $$ 79 | 80 | Here 81 | 82 | * $\beta \in (0,1)$ is a discount factor 83 | * $R = 1 + r$ where $r$ is the interest rate 84 | * the income process $\{Y_t\}$ is a generated by stochastic matrix $P$ 85 | 86 | The matrix $P$ and the grid of values taken by $Y_t$ are obtained by discretizing the AR(1) process 87 | 88 | $$ 89 | Y_{t+1} = \rho Y_t + \nu \epsilon_{t+1} 90 | $$ 91 | 92 | where $\{\epsilon_t\}$ is IID and standard normal. 93 | 94 | +++ 95 | 96 | ### Euler iteration (time iteration) 97 | 98 | Let $S = \mathbb R_+ \times \mathsf Y$ be the set of possible values for the state $(a_t, Y_t)$. 99 | 100 | We aim to compute an optimal consumption policy $\sigma \colon S \to \mathbb R$, under which dynamics are given by 101 | 102 | $$ 103 | c_t = \sigma(a_t, Y_t) 104 | \quad \text{and} \quad 105 | a_{t+1} = R (a_t - c_t) + Y_{t+1} 106 | $$ 107 | 108 | 109 | We solve for this policy via the endogenous grid method (EGM). 110 | 111 | EGM is a special case of an algorithm called either "time iteration" (a bad name) or "Euler iteration" (a better one). 112 | 113 | Euler iteration involves guessing a consumption policy $\sigma$ and then updating it using the Euler equation. 114 | 115 | We call the update rule $K$ and think of $K$ as an operator. 116 | 117 | EGM is a technique for computing the update (computing $K\sigma$ from $\sigma$) via approximation that is very fast in simple settings (but often fails or does not exist in more complex ones). 118 | 119 | +++ 120 | 121 | ## Specification 122 | 123 | Utility has the CRRA specification 124 | 125 | $$ 126 | u(c) = \frac{c^{1 - \gamma}} {1 - \gamma} 127 | $$ 128 | 129 | We start with a namedtuple to store parameters and arrays 130 | 131 | ```{code-cell} 132 | Model = namedtuple('Model', ('β', 'R', 'γ', 's_grid', 'y_grid', 'P')) 133 | ``` 134 | 135 | The following function stores default parameter values for the income fluctuation problem and creates suitable arrays. 136 | 137 | ```{code-cell} 138 | def ifp(R=1.01, # Gross interest rate 139 | β=0.99, # Discount factor 140 | γ=1.5, # CRRA preference parameter 141 | s_max=16, # Savings grid max 142 | s_size=200, # Savings grid size 143 | ρ=0.99, # Income persistence 144 | ν=0.02, # Income volatility 145 | y_size=25): # Income grid size 146 | 147 | # require R β < 1 for convergence 148 | assert R * β < 1, "Stability condition failed." 149 | 150 | # Create arrays 151 | mc = qe.tauchen(y_size, ρ, ν) 152 | y_grid, P = jnp.exp(mc.state_values), jnp.array(mc.P) 153 | s_grid = jnp.linspace(0, s_max, s_size) 154 | 155 | # Pack and return 156 | return Model(β, R, γ, s_grid, y_grid, P) 157 | ``` 158 | 159 | ### The EGM algorithm 160 | 161 | 162 | We take as given a fixed and *exogenous* grid of saving values $s = s[i]$ 163 | 164 | We begin with a pair of arrays $a[i, j]$ and $\sigma[i, j]$, where $i$ is in `range(len(s_grid))` and $j$ is in `range(len(y_grid))`. 
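Here $a[i, j]$ is interpreted as an asset level, namely the endogenous grid point implied by exogenous saving $s[i]$ and income state $j$, while $\sigma[i, j]$ is the consumption choice at that asset level.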
165 | 166 | Fixing $j$, we linearly interpolate $a[:, j]$ and $\sigma[:, j]$ to great a consumption policy when income is in state $j$ --- let's also denote this consumption policy by $\sigma$. 167 | 168 | 169 | Next we set $a_0 = c_0 = 0$, since zero consumption is the only choice when $a=0$. 170 | 171 | Then, for $i > 0$, we compute 172 | 173 | $$ 174 | c[i, j] 175 | = (u')^{-1} 176 | \left\{ 177 | \beta R \, \mathbb \sum_{j'} u' 178 | \left [ 179 | \sigma(R s[i] + y[j'], \, j') 180 | \right] P[j, j'] 181 | \right\} 182 | $$ 183 | 184 | and we set 185 | 186 | $$ 187 | a[i, j] = s[i] + c[i, j] 188 | $$ 189 | 190 | Now we increment $j$ and then repeat over all $i$. 191 | 192 | This gives us a new pair of arrays $a = a[i, j]$ and $\sigma = \sigma[i, j]$. 193 | 194 | Now we repeat the whole procedure, each time updating $a$ and $\sigma$. 195 | 196 | Once iteration finishes, $\sigma[i, j]$ contains (approximately) optimal consumption when assets are $a[i,j]$ and income is in state $j$. 197 | 198 | +++ 199 | 200 | ## JAX version 201 | 202 | First we define an operator $K$ for updating consumption policies based on the EGM. 203 | 204 | We'll try a `vmap` version and then we'll try a vectorized version using reshapes. 205 | 206 | Here's a non-vectorized version of $K$ that we can apply `vmap` to: 207 | 208 | ```{code-cell} 209 | def K_jax_generator(a_vec, σ_vec, model, i, j): 210 | """ 211 | Computes Kσ evaluated at (a_i, y_j). 212 | 213 | """ 214 | # Unpack 215 | β, R, γ, s_grid, y_grid, P = model 216 | y_grid_idx = jnp.arange(len(y_grid)) 217 | s_i = s_grid[i] 218 | 219 | def u_prime(c): 220 | return c**(-γ) 221 | 222 | def u_prime_inv(u): 223 | return u**(-1/γ) 224 | 225 | # Evaluate σ(R s_i + y_k, y_k) over all future income states k 226 | def f(k): 227 | return jnp.interp(R * s_i + y_grid[k], a_vec[:, k], σ_vec[:, k]) 228 | consumption_vals = jax.vmap(f)(y_grid_idx) 229 | 230 | # Evaluate consumption choice 231 | E = u_prime(consumption_vals) @ P[j, :] 232 | c_i = u_prime_inv(β * R * E) 233 | c_i = c_i * (i > 0) # When s_i = 0, set c_i = 0 234 | a_i = s_i + c_i 235 | 236 | return a_i, c_i 237 | ``` 238 | 239 | ```{code-cell} 240 | K_jax_generator = jax.vmap(K_jax_generator, 241 | in_axes=(None, None, None, None, 0)) 242 | K_jax_generator = jax.vmap(K_jax_generator, 243 | in_axes=(None, None, None, 0, None)) 244 | ``` 245 | 246 | ```{code-cell} 247 | @jax.jit 248 | def K_jax(a_vec, σ_vec, model): 249 | # Unpack 250 | β, R, γ, s_grid, y_grid, P = model 251 | s_size, y_size = len(s_grid), len(y_grid) 252 | s_indices, y_indices = jnp.arange(s_size), jnp.arange(y_size) 253 | return K_jax_generator(a_vec, σ_vec, model, s_indices, y_indices) 254 | ``` 255 | 256 | Here's the vectorized version using reshapes. 257 | 258 | ```{code-cell} 259 | @jax.jit 260 | def K_jax_vectorized(a_vec, σ, model): 261 | "The vectorized operator K using EGM." 
262 | 263 | # Unpack 264 | β, R, γ, s_grid, y_grid, P = model 265 | s_size, y_size = len(s_grid), len(y_grid) 266 | 267 | def u_prime(c): 268 | return c**(-γ) 269 | 270 | def u_prime_inv(u): 271 | return u**(-1/γ) 272 | 273 | # Linearly interpolate σ(a, y) 274 | def σ_f(a, y): 275 | return jnp.interp(a, a_vec[:, y], σ[:, y]) 276 | σ_vec = jnp.vectorize(σ_f) 277 | 278 | # Broadcast and vectorize 279 | y_hat = jnp.reshape(y_grid, (1, 1, y_size)) 280 | y_hat_idx = jnp.reshape(jnp.arange(y_size), (1, 1, y_size)) 281 | s = jnp.reshape(s_grid, (s_size, 1, 1)) 282 | P = jnp.reshape(P, (1, y_size, y_size)) 283 | 284 | # Evaluate consumption choice 285 | a_next = R * s + y_hat 286 | σ_next = σ_vec(a_next, y_hat_idx) 287 | up = u_prime(σ_next) 288 | E = jnp.sum(up * P, axis=-1) 289 | c = u_prime_inv(β * R * E) 290 | 291 | # Set up a column vector with zero in the first row and ones elsewhere 292 | e_0 = jnp.ones(s_size) - jnp.identity(s_size)[:, 0] 293 | e_0 = jnp.reshape(e_0, (s_size, 1)) 294 | 295 | # The policy is computed consumption with the first row set to zero 296 | σ_out = c * e_0 297 | 298 | # Compute a_out by a = s + c 299 | a_out = np.reshape(s_grid, (s_size, 1)) + σ_out 300 | 301 | return a_out, σ_out 302 | ``` 303 | 304 | Let's check that they compute the same thing. 305 | 306 | ```{code-cell} 307 | # Unpack 308 | model = ifp() 309 | β, R, γ, s_grid, y_grid, P = model 310 | s_size, y_size = len(s_grid), len(y_grid) 311 | ``` 312 | 313 | ```{code-cell} 314 | # Initial condition is to consume all in every state 315 | σ_vec = jnp.repeat(s_grid, y_size) 316 | σ_vec = jnp.reshape(σ_vec, (s_size, y_size)) 317 | a_vec = jnp.copy(σ_vec) 318 | ``` 319 | 320 | ```{code-cell} 321 | a_vmap, σ_vmap = K_jax(a_vec, σ_vec, model) 322 | ``` 323 | 324 | ```{code-cell} 325 | a_vectorized, σ_vectorized = K_jax_vectorized(a_vec, σ_vec, model) 326 | ``` 327 | 328 | ```{code-cell} 329 | jnp.allclose(a_vmap, a_vectorized) 330 | ``` 331 | 332 | ```{code-cell} 333 | jnp.allclose(σ_vmap, σ_vectorized) 334 | ``` 335 | 336 | OK, so they compute the same thing. Now let's test timing over multiple runs via `timeit`: 337 | 338 | ```{code-cell} 339 | %timeit _, _ = K_jax(a_vec, σ_vec, model) 340 | ``` 341 | 342 | ```{code-cell} 343 | %timeit _, _ = K_jax_vectorized(a_vec, σ_vec, model) 344 | ``` 345 | 346 | The two versions run in about the same time. 347 | 348 | We'll use the `vmap` version in what follows. 349 | 350 | +++ 351 | 352 | Next we define a successive approximator that repeatedly applies $K$. 
353 | 354 | ```{code-cell} 355 | def successive_approx_jax(model, 356 | tol=1e-5, 357 | max_iter=100_000, 358 | verbose=True, 359 | print_skip=25): 360 | 361 | # Unpack 362 | β, R, γ, s_grid, y_grid, P = model 363 | s_size, y_size = len(s_grid), len(y_grid) 364 | 365 | # Initial condition is to consume all in every state 366 | σ_init = jnp.repeat(s_grid, y_size) 367 | σ_init = jnp.reshape(σ_init, (s_size, y_size)) 368 | a_init = jnp.copy(σ_init) 369 | a_vec, σ_vec = a_init, σ_init 370 | 371 | i = 0 372 | error = tol + 1 373 | 374 | while i < max_iter and error > tol: 375 | a_new, σ_new = K_jax(a_vec, σ_vec, model) 376 | error = jnp.max(jnp.abs(σ_vec - σ_new)) 377 | i += 1 378 | if verbose and i % print_skip == 0: 379 | print(f"Error at iteration {i} is {error}.") 380 | a_vec, σ_vec = a_new, σ_new 381 | 382 | if error > tol: 383 | print("Failed to converge!") 384 | else: 385 | print(f"\nConverged in {i} iterations.") 386 | 387 | return a_new, σ_new 388 | ``` 389 | 390 | ### Numba version 391 | 392 | Below we provide a second set of code, which solves the same model with Numba. 393 | 394 | The purpose of this code is to cross-check our results from the JAX version, as 395 | well as to do a runtime comparison. 396 | 397 | Most readers will want to skip ahead to the next section, where we solve the 398 | model and run the cross-check. 399 | 400 | ```{code-cell} 401 | 402 | def ifp_numba( 403 | R=1.01, # Gross interest rate 404 | β=0.99, # Discount factor 405 | γ=1.5, # CRRA preference parameter 406 | s_max=16, # Savings grid max 407 | s_size=200, # Savings grid size 408 | ρ=0.99, # Income persistence 409 | ν=0.02, # Income volatility 410 | y_size=25): # Income grid size 411 | 412 | # require R β < 1 for convergence 413 | assert R * β < 1, "Stability condition failed." 414 | 415 | # Create arrays 416 | mc = qe.tauchen(y_size, ρ, ν) 417 | y_grid, P = np.exp(mc.state_values), mc.P 418 | s_grid = np.linspace(0, s_max, s_size) 419 | 420 | # Pack and return 421 | return Model(β, R, γ, s_grid, y_grid, P) 422 | 423 | 424 | @numba.jit 425 | def K_nb(a_vec, σ, model): 426 | "The operator K using Numba." 
427 | 428 | # Unpack 429 | β, R, γ, s_grid, y_grid, P = model 430 | s_size, y_size = len(s_grid), len(y_grid) 431 | 432 | def u_prime(c): 433 | return c**(-γ) 434 | 435 | def u_prime_inv(u): 436 | return u**(-1/γ) 437 | 438 | # Linear interpolation of policy using endogenous grid 439 | def σ_f(a, z): 440 | return np.interp(a, a_vec[:, z], σ[:, z]) 441 | 442 | # Allocate memory for new consumption array 443 | σ_out = np.zeros_like(σ) 444 | a_out = np.zeros_like(σ_out) 445 | 446 | for i, s in enumerate(s_grid[1:]): 447 | i += 1 448 | for z in range(y_size): 449 | expect = 0.0 450 | for z_hat in range(y_size): 451 | expect += u_prime(σ_f(R * s + y_grid[z_hat], z_hat)) * \ 452 | P[z, z_hat] 453 | c = u_prime_inv(β * R * expect) 454 | σ_out[i, z] = c 455 | a_out[i, z] = s + c 456 | 457 | return a_out, σ_out 458 | ``` 459 | 460 | ```{code-cell} 461 | def successive_approx_numba(model, # Class with model information 462 | tol=1e-5, 463 | max_iter=100_000, 464 | verbose=True, 465 | print_skip=25): 466 | 467 | # Unpack 468 | β, R, γ, s_grid, y_grid, P = model 469 | s_size, y_size = len(s_grid), len(y_grid) 470 | 471 | # make NumPy versions of arrays 472 | s_grid, y_grid, P = [np.array(x) for x in (s_grid, y_grid, P)] 473 | 474 | σ_init = np.repeat(s_grid, y_size) 475 | σ_init = np.reshape(σ_init, (s_size, y_size)) 476 | a_init = np.copy(σ_init) 477 | a_vec, σ_vec = a_init, σ_init 478 | 479 | # Set up loop 480 | i = 0 481 | error = tol + 1 482 | 483 | while i < max_iter and error > tol: 484 | a_new, σ_new = K_nb(a_vec, σ_vec, model) 485 | error = np.max(np.abs(σ_vec - σ_new)) 486 | i += 1 487 | if verbose and i % print_skip == 0: 488 | print(f"Error at iteration {i} is {error}.") 489 | a_vec, σ_vec = a_new, σ_new 490 | 491 | if error > tol: 492 | print("Failed to converge!") 493 | else: 494 | print(f"\nConverged in {i} iterations.") 495 | 496 | return a_new, σ_new 497 | ``` 498 | 499 | ## Solutions 500 | 501 | Here we solve the IFP with JAX and Numba. 502 | 503 | We will compare both the outputs and the execution time. 504 | 505 | ### Outputs 506 | 507 | 508 | Here's a first run of the JAX code. 509 | 510 | ```{code-cell} 511 | model = ifp() 512 | a_star_jax, σ_star_jax = successive_approx_jax(model, 513 | print_skip=100) 514 | ``` 515 | 516 | Here's a first run of the Numba code. 517 | 518 | ```{code-cell} 519 | model = ifp_numba() 520 | a_star_nb, σ_star_nb = successive_approx_numba(model, 521 | print_skip=100) 522 | ``` 523 | 524 | Now let's check the outputs in a plot to make sure they are the same. 525 | 526 | ```{code-cell} 527 | fig, ax = plt.subplots() 528 | β, R, γ, s_grid, y_grid, P = model 529 | s_size, y_size = len(s_grid), len(y_grid) 530 | 531 | for z in (0, y_size-1): 532 | ax.plot(a_star_nb[:, z], 533 | σ_star_nb[:, z], 534 | '--', lw=2, 535 | label=f"Numba EGM: consumption when $z={z}$") 536 | ax.plot(a_star_jax[:, z], 537 | σ_star_jax[:, z], 538 | label=f"JAX EGM: consumption when $z={z}$") 539 | 540 | ax.set_xlabel('asset') 541 | plt.legend() 542 | plt.show() 543 | ``` 544 | 545 | ### Timing 546 | 547 | Now let's compare execution time of the two methods 548 | 549 | ```{code-cell} 550 | model = ifp() 551 | qe.tic() 552 | a_star_jax, σ_star_jax = successive_approx_jax(model, 553 | print_skip=1000) 554 | jax_time = qe.toc() 555 | ``` 556 | 557 | ```{code-cell} 558 | model = ifp_numba() 559 | qe.tic() 560 | a_star_nb, σ_star_nb = successive_approx_numba(model, 561 | print_skip=1000) 562 | numba_time = qe.toc() 563 | ``` 564 | 565 | How much faster is JAX? 
566 | 567 | ```{code-cell} 568 | numba_time / jax_time 569 | ``` 570 | 571 | The JAX code is significantly faster, as expected. 572 | 573 | This difference will increase when more features (and state variables) are added 574 | to the model. 575 | 576 | +++ 577 | 578 | ## Exercise 579 | 580 | Try replacing `successive_approx_jax` with a jitted version (`@jax.jit` at the top) that uses `jax.lax.while_loop`. 581 | 582 | Measure the execution time (after running once to compile) and compare it with the timings above. 583 | 584 | Also plot the resulting functions using the plotting code above to make sure that you're still getting the same outputs. 585 | 586 | ```{code-cell} 587 | #Put your code here 588 | ``` 589 | 590 | ```{code-cell} 591 | for i in range(18): 592 | print("Solution below! 🐘") 593 | ``` 594 | 595 | ```{code-cell} 596 | @jax.jit 597 | def successive_approx_jax_jitted( 598 | model, 599 | tol=1e-5, 600 | max_iter=100_000, 601 | verbose=True, 602 | print_skip=25): 603 | 604 | # Unpack 605 | β, R, γ, s_grid, y_grid, P = model 606 | s_size, y_size = len(s_grid), len(y_grid) 607 | 608 | # Initial condition is to consume all in every state 609 | σ_init = jnp.repeat(s_grid, y_size) 610 | σ_init = jnp.reshape(σ_init, (s_size, y_size)) 611 | a_init = jnp.copy(σ_init) 612 | 613 | def update(state): 614 | i, a_vec, σ_vec, error = state 615 | a_new, σ_new = K_jax(a_vec, σ_vec, model) 616 | error = jnp.max(jnp.abs(σ_vec - σ_new)) 617 | i += 1 618 | return i, a_new, σ_new, error 619 | 620 | def condition(state): 621 | i, a_vec, σ_vec, error = state 622 | return jnp.logical_and(i < max_iter, error > tol) 623 | 624 | init_state = (0, a_init, σ_init, tol + 1) 625 | state = jax.lax.while_loop(condition, update, init_state) 626 | 627 | return state 628 | ``` 629 | 630 | Here's a first run. 631 | 632 | ```{code-cell} 633 | model = ifp() 634 | i, a_star_jax_jit, σ_star_jax_jit, error = successive_approx_jax_jitted(model, 635 | print_skip=1000) 636 | ``` 637 | 638 | ```{code-cell} 639 | print(f"Run completed in {i} iterations with error {error:.5}.") 640 | ``` 641 | 642 | Now let's time it. 643 | 644 | ```{code-cell} 645 | qe.tic() 646 | i, a_star_jax_jit, σ_star_jax_jit, error = successive_approx_jax_jitted(model, 647 | print_skip=1000) 648 | jax_jit_time = qe.toc() 649 | ``` 650 | 651 | ```{code-cell} 652 | jax_time / jax_jit_time 653 | ``` 654 | 655 | ```{code-cell} 656 | numba_time / jax_jit_time 657 | ``` 658 | 659 | ```{code-cell} 660 | β, R, γ, s_grid, y_grid, P = model 661 | s_size, y_size = len(s_grid), len(y_grid) 662 | 663 | fig, ax = plt.subplots() 664 | 665 | for z in (0, y_size-1): 666 | ax.plot(a_star_jax_jit[:, z], 667 | σ_star_jax_jit[:, z], 668 | label=f"JAX EGM: consumption when $z={z}$") 669 | 670 | ax.set_xlabel('asset') 671 | plt.legend() 672 | plt.show() 673 | ``` 674 | --------------------------------------------------------------------------------