├── README.md
├── zreferences.ipynb
├── troubleshooting.ipynb
├── intro.ipynb
├── status.ipynb
├── about.ipynb
├── keras.ipynb
├── job_search.ipynb
├── kesten_processes.ipynb
├── inventory_ssd.ipynb
├── short_path.ipynb
├── cake_eating_numerical.ipynb
├── mle.ipynb
├── opt_savings.ipynb
└── autodiff.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # lecture-jax.notebooks
2 | Notebooks for the lecture-jax lecture series
3 |
--------------------------------------------------------------------------------
/zreferences.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "fb395af3",
6 | "metadata": {},
7 | "source": [
8 | "$$\n",
9 | "\\newcommand{\\argmax}{arg\\,max}\n",
10 | "\\newcommand{\\argmin}{arg\\,min}\n",
11 | "$$"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "917ade05",
17 | "metadata": {},
18 | "source": [
19 | "\n",
20 | ""
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "id": "cd18ce36",
26 | "metadata": {},
27 | "source": [
28 | "# References\n",
29 | "\n",
30 | "\n",
31 | "\\[Are08\\] Cristina Arellano. Default risk and income fluctuations in emerging economies. *The American Economic Review*, pages 690–712, 2008.\n",
32 | "\n",
33 | "\n",
34 | "\\[Bia11\\] Javier Bianchi. Overborrowing and systemic externalities in the business cycle. *American Economic Review*, 101(7):3400–3426, December 2011. URL: [https://www.aeaweb.org/articles?id=10.1257/aer.101.7.3400](https://www.aeaweb.org/articles?id=10.1257/aer.101.7.3400), [doi:10.1257/aer.101.7.3400](https://doi.org/10.1257/aer.101.7.3400).\n",
35 | "\n",
36 | "\n",
37 | "\\[Luc78\\] Robert E Lucas, Jr. Asset prices in an exchange economy. *Econometrica: Journal of the Econometric Society*, 46(6):1429–1445, 1978."
38 | ]
39 | }
40 | ],
41 | "metadata": {
42 | "date": 1765244755.815626,
43 | "filename": "zreferences.md",
44 | "kernelspec": {
45 | "display_name": "Python",
46 | "language": "python3",
47 | "name": "python3"
48 | },
49 | "title": "References"
50 | },
51 | "nbformat": 4,
52 | "nbformat_minor": 5
53 | }
--------------------------------------------------------------------------------
/troubleshooting.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "cef471a4",
6 | "metadata": {},
7 | "source": [
8 | "$$\n",
9 | "\\newcommand{\\argmax}{arg\\,max}\n",
10 | "\\newcommand{\\argmin}{arg\\,min}\n",
11 | "$$"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "41adcbd2",
17 | "metadata": {},
18 | "source": [
19 | "\n",
20 | ""
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "id": "d2bd14a0",
26 | "metadata": {},
27 | "source": [
28 | "# Troubleshooting\n",
29 | "\n",
30 | "This page is for readers experiencing errors when running the code from the lectures."
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "id": "a403eb71",
36 | "metadata": {},
37 | "source": [
38 | "## Fixing Your Local Environment\n",
39 | "\n",
40 | "The basic assumption of the lectures is that code in a lecture should execute whenever\n",
41 | "\n",
42 | "1. it is executed in a Jupyter notebook and \n",
43 | "1. the notebook is running on a machine with the latest version of Anaconda Python. \n",
44 | "\n",
45 | "\n",
46 | "If you have not already done so, install Anaconda by following the instructions in [this lecture](https://python-programming.quantecon.org/getting_started.html).\n",
47 | "\n",
48 | "Once Anaconda is installed, the most common source of problems for our readers is an Anaconda distribution that is not up to date.\n",
49 | "\n",
50 | "[Here’s a useful article](https://www.anaconda.com/blog/keeping-anaconda-date)\n",
51 | "on how to update Anaconda.\n",
52 | "\n",
53 | "Another option is to simply remove Anaconda and reinstall.\n",
54 | "\n",
55 | "You also need to keep external code libraries, such as [QuantEcon.py](https://quantecon.org/quantecon-py), up to date.\n",
56 | "\n",
57 | "For this task you can either\n",
58 | "\n",
59 | "- run `conda install -y quantecon` on the command line, or \n",
60 | "- execute `!conda install -y quantecon` within a Jupyter notebook. \n",
61 | "\n",
62 | "\n",
63 | "If your local environment is still not working, you can do two things.\n",
64 | "\n",
65 | "First, you can use a remote machine instead, by clicking on the Launch Notebook icon available for each lecture.\n",
66 | "\n",
67 | "\n",
68 | "\n",
69 | "Second, you can report an issue, so we can try to fix your local setup.\n",
70 | "\n",
71 | "We like getting feedback on the lectures, so please don’t hesitate to get in\n",
72 | "touch."
73 | ]
74 | },
75 | {
76 | "cell_type": "markdown",
77 | "id": "dba61004",
78 | "metadata": {},
79 | "source": [
80 | "## Reporting an Issue\n",
81 | "\n",
82 | "One way to give feedback is to raise an issue through our [issue tracker](https://github.com/QuantEcon/lecture-python/issues).\n",
83 | "\n",
84 | "Please be as specific as possible. Tell us where the problem is and as much\n",
85 | "detail about your local setup as you can provide.\n",
86 | "\n",
87 | "Another feedback option is to use our [discourse forum](https://discourse.quantecon.org/).\n",
88 | "\n",
89 | "Finally, you can send feedback directly to [contact@quantecon.org](mailto:contact@quantecon.org)."
90 | ]
91 | }
92 | ],
93 | "metadata": {
94 | "date": 1765244755.7753887,
95 | "filename": "troubleshooting.md",
96 | "kernelspec": {
97 | "display_name": "Python",
98 | "language": "python3",
99 | "name": "python3"
100 | },
101 | "title": "Troubleshooting"
102 | },
103 | "nbformat": 4,
104 | "nbformat_minor": 5
105 | }
--------------------------------------------------------------------------------
/intro.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "079656f1",
6 | "metadata": {},
7 | "source": [
8 | "$$\n",
9 | "\\newcommand{\\argmax}{arg\\,max}\n",
10 | "\\newcommand{\\argmin}{arg\\,min}\n",
11 | "$$"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "4220111f",
17 | "metadata": {},
18 | "source": [
19 | "# Quantitative Economics with JAX\n",
20 | "\n",
21 | "This website presents a set of lectures on quantitative economic modeling\n",
22 | "using GPUs and [Google JAX](https://jax.readthedocs.io)."
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "id": "fe0de9c3",
28 | "metadata": {},
29 | "source": [
30 | "# Introduction\n",
31 | "\n",
32 | "- [About](https://jax.quantecon.org/about.html)\n",
33 | "- [An Introduction to JAX](https://jax.quantecon.org/jax_intro.html)\n",
34 | "- [Adventures with Autodiff](https://jax.quantecon.org/autodiff.html)\n",
35 | "- [Newton’s Method via JAX](https://jax.quantecon.org/newtons_method.html)"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "id": "31b70446",
41 | "metadata": {},
42 | "source": [
43 | "# Simulation\n",
44 | "\n",
45 | "- [Inventory Dynamics](https://jax.quantecon.org/inventory_dynamics.html)\n",
46 | "- [Kesten Processes and Firm Dynamics](https://jax.quantecon.org/kesten_processes.html)\n",
47 | "- [Wealth Distribution Dynamics](https://jax.quantecon.org/wealth_dynamics.html)"
48 | ]
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "id": "12d226ef",
53 | "metadata": {},
54 | "source": [
55 | "# Asset Pricing\n",
56 | "\n",
57 | "- [Asset Pricing: The Lucas Asset Pricing Model](https://jax.quantecon.org/lucas_model.html)\n",
58 | "- [An Asset Pricing Problem](https://jax.quantecon.org/markov_asset.html)"
59 | ]
60 | },
61 | {
62 | "cell_type": "markdown",
63 | "id": "a5f010e1",
64 | "metadata": {},
65 | "source": [
66 | "# Dynamic Programming\n",
67 | "\n",
68 | "- [Job Search](https://jax.quantecon.org/job_search.html)\n",
69 | "- [Optimal Savings I: Value Function Iteration](https://jax.quantecon.org/opt_savings_1.html)\n",
70 | "- [Optimal Savings II: Alternative Algorithms](https://jax.quantecon.org/opt_savings_2.html)\n",
71 | "- [Shortest Paths](https://jax.quantecon.org/short_path.html)\n",
72 | "- [Optimal Investment](https://jax.quantecon.org/opt_invest.html)\n",
73 | "- [Inventory Management Model](https://jax.quantecon.org/inventory_ssd.html)\n",
74 | "- [Endogenous Grid Method](https://jax.quantecon.org/ifp_egm.html)"
75 | ]
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "id": "3ff19208",
80 | "metadata": {},
81 | "source": [
82 | "# Macroeconomic Models\n",
83 | "\n",
84 | "- [Default Risk and Income Fluctuations](https://jax.quantecon.org/arellano.html)\n",
85 | "- [The Aiyagari Model](https://jax.quantecon.org/aiyagari_jax.html)\n",
86 | "- [The Hopenhayn Entry-Exit Model](https://jax.quantecon.org/hopenhayn.html)\n",
87 | "- [Bianchi Overborrowing Model](https://jax.quantecon.org/overborrowing.html)"
88 | ]
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "id": "9d8f0a40",
93 | "metadata": {},
94 | "source": [
95 | "# Data and Empirics\n",
96 | "\n",
97 | "- [Maximum Likelihood Estimation](https://jax.quantecon.org/mle.html)\n",
98 | "- [Simple Neural Network Regression with Keras and JAX](https://jax.quantecon.org/keras.html)\n",
99 | "- [Neural Network Regression with JAX](https://jax.quantecon.org/jax_nn.html)\n",
100 | "- [Policy Gradient-Based Optimal Savings](https://jax.quantecon.org/ifp_dl.html)"
101 | ]
102 | },
103 | {
104 | "cell_type": "markdown",
105 | "id": "0575e09a",
106 | "metadata": {},
107 | "source": [
108 | "# Other\n",
109 | "\n",
110 | "- [Troubleshooting](https://jax.quantecon.org/troubleshooting.html)\n",
111 | "- [References](https://jax.quantecon.org/zreferences.html)\n",
112 | "- [Execution Statistics](https://jax.quantecon.org/status.html)"
113 | ]
114 | }
115 | ],
116 | "metadata": {
117 | "date": 1765244755.2207153,
118 | "filename": "intro.md",
119 | "kernelspec": {
120 | "display_name": "Python",
121 | "language": "python3",
122 | "name": "python3"
123 | },
124 | "title": "Quantitative Economics with JAX"
125 | },
126 | "nbformat": 4,
127 | "nbformat_minor": 5
128 | }
--------------------------------------------------------------------------------
/status.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "3a0b62d4",
6 | "metadata": {},
7 | "source": [
8 | "$$\n",
9 | "\\newcommand{\\argmax}{arg\\,max}\n",
10 | "\\newcommand{\\argmin}{arg\\,min}\n",
11 | "$$"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "f56814b0",
17 | "metadata": {},
18 | "source": [
19 | "# Execution Statistics\n",
20 | "\n",
21 | "This table contains the latest execution statistics.\n",
22 | "\n",
23 | "|Document|Modified|Method|Run Time (s)|Status|\n",
24 | "|:------------------:|:------------------:|:------------------:|:------------------:|:------------------:|\n",
25 | "|aiyagari_jax|2025-12-08 03:38|cache|67.79|✅|\n",
26 | "|arellano|2025-12-08 03:38|cache|22.96|✅|\n",
27 | "|autodiff|2025-12-08 03:38|cache|13.56|✅|\n",
28 | "|hopenhayn|2025-12-08 03:39|cache|24.93|✅|\n",
29 | "|ifp_dl|2025-12-08 03:39|cache|35.94|✅|\n",
30 | "|ifp_egm|2025-12-08 03:42|cache|144.16|✅|\n",
31 | "|intro|2025-12-08 03:42|cache|0.9|✅|\n",
32 | "|inventory_dynamics|2025-12-08 03:42|cache|8.71|✅|\n",
33 | "|inventory_ssd|2025-12-08 03:42|cache|10.03|✅|\n",
34 | "|jax_intro|2025-12-08 03:42|cache|19.67|✅|\n",
35 | "|jax_nn|2025-12-09 01:44|cache|82.48|✅|\n",
36 | "|job_search|2025-12-08 03:43|cache|10.42|✅|\n",
37 | "|keras|2025-12-08 03:44|cache|26.9|✅|\n",
38 | "|kesten_processes|2025-12-08 03:44|cache|11.27|✅|\n",
39 | "|lucas_model|2025-12-08 03:44|cache|18.51|✅|\n",
40 | "|markov_asset|2025-12-08 03:44|cache|12.2|✅|\n",
41 | "|mle|2025-12-08 03:45|cache|16.87|✅|\n",
42 | "|newtons_method|2025-12-08 03:48|cache|190.25|✅|\n",
43 | "|opt_invest|2025-12-08 03:48|cache|22.8|✅|\n",
44 | "|opt_savings_1|2025-12-08 03:49|cache|43.51|✅|\n",
45 | "|opt_savings_2|2025-12-08 03:49|cache|21.13|✅|\n",
46 | "|overborrowing|2025-12-08 03:50|cache|24.55|✅|\n",
47 | "|short_path|2025-12-08 03:50|cache|4.31|✅|\n",
48 | "|status|2025-12-08 03:50|cache|7.24|✅|\n",
49 | "|troubleshooting|2025-12-08 03:42|cache|0.9|✅|\n",
50 | "|wealth_dynamics|2025-12-08 03:53|cache|153.26|✅|\n",
51 | "|zreferences|2025-12-08 03:42|cache|0.9|✅|\n",
52 | "\n",
53 | "\n",
54 | "These lectures are built on Linux instances, via GitHub Actions, that have\n",
55 | "access to a GPU. These lectures make use of the NVIDIA `T4` card.\n",
56 | "\n",
57 | "These lectures use the following Python version"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "id": "7ce3e17d",
64 | "metadata": {
65 | "hide-output": false
66 | },
67 | "outputs": [],
68 | "source": [
69 | "!python --version"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "id": "6c50bce6",
75 | "metadata": {},
76 | "source": [
77 | "and the following package versions"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "id": "3a2e3d21",
84 | "metadata": {
85 | "hide-output": false
86 | },
87 | "outputs": [],
88 | "source": [
89 | "!conda list"
90 | ]
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "id": "a53e8185",
95 | "metadata": {},
96 | "source": [
97 | "You can check the backend used by JAX using:"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "id": "a6c9a78e",
104 | "metadata": {
105 | "hide-output": false
106 | },
107 | "outputs": [],
108 | "source": [
109 | "import jax\n",
110 | "# Check if JAX is using GPU\n",
111 | "print(f\"JAX backend: {jax.devices()[0].platform}\")"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "id": "ae4bcca3",
117 | "metadata": {},
118 | "source": [
119 | "and the hardware we are running on:"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "id": "5f456c05",
126 | "metadata": {
127 | "hide-output": false
128 | },
129 | "outputs": [],
130 | "source": [
131 | "!nvidia-smi"
132 | ]
133 | }
134 | ],
135 | "metadata": {
136 | "date": 1765244755.7709463,
137 | "filename": "status.md",
138 | "kernelspec": {
139 | "display_name": "Python",
140 | "language": "python3",
141 | "name": "python3"
142 | },
143 | "title": "Execution Statistics"
144 | },
145 | "nbformat": 4,
146 | "nbformat_minor": 5
147 | }
--------------------------------------------------------------------------------
/about.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "630c2ecf",
6 | "metadata": {},
7 | "source": [
8 | "$$\n",
9 | "\\newcommand{\\argmax}{arg\\,max}\n",
10 | "\\newcommand{\\argmin}{arg\\,min}\n",
11 | "$$"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "1bedc334",
17 | "metadata": {},
18 | "source": [
19 | "# About\n",
20 | "\n",
21 | "Perhaps the single most notable feature of scientific computing in the past\n",
22 | "two decades is the rise and rise of parallel computation.\n",
23 | "\n",
24 | "For example, the advanced artificial intelligence applications now shaking the\n",
25 | "worlds of business and academia require massive computing power to train, and the\n",
26 | "great majority of that computing power is supplied by GPUs.\n",
27 | "\n",
28 | "For us economists, with our ever-growing need for more compute cycles,\n",
29 | "parallel computing provides both opportunities and new difficulties.\n",
30 | "\n",
31 | "The main difficulty we face vis-a-vis parallel computation is accessibility.\n",
32 | "\n",
33 | "Even for those with time to invest in careful parallelization of their programs,\n",
34 | "exploiting the full power of parallel hardware is challenging for non-experts.\n",
35 | "\n",
36 | "Moreover, that hardware changes from year to year, so any human capital\n",
37 | "associated with mastering intricacies of a particular GPU has a very high\n",
38 | "depreciation rate.\n",
39 | "\n",
40 | "For these reasons, we find [Google JAX](https://github.com/google/jax) compelling.\n",
41 | "\n",
42 | "In short, JAX makes high performance and parallel computing accessible (and fun!).\n",
43 | "\n",
44 | "It provides a familiar array programming interface based on NumPy, and, as long as\n",
45 | "some simple conventions are adhered to, this code compiles to extremely\n",
46 | "efficient and well-parallelized machine code.\n",
47 | "\n",
48 | "One of the most agreeable features of JAX is that the same code can be run on\n",
49 | "either CPUs or GPUs, which allows users to test and develop locally, before\n",
50 | "deploying to a more powerful machine for heavier computations.\n",
51 | "\n",
52 | "JAX is relatively easy to learn and highly portable, allowing us programmers to\n",
53 | "focus on the algorithms we want to implement, rather than particular features of\n",
54 | "our hardware.\n",
55 | "\n",
56 | "This lecture series provides an introduction to using Google JAX for\n",
57 | "quantitative economics.\n",
58 | "\n",
59 | "The rest of this page provides some background information on JAX, notes on\n",
60 | "how to run the lectures, and credits for our colleagues and RAs."
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "id": "661f3734",
66 | "metadata": {},
67 | "source": [
68 | "## What is JAX?\n",
69 | "\n",
70 | "JAX is an open source Python library developed by Google Research to support\n",
71 | "in-house artificial intelligence and machine learning.\n",
72 | "\n",
73 | "JAX provides data types, functions and a compiler for fast linear\n",
74 | "algebra operations and automatic differentiation.\n",
75 | "\n",
76 | "Loosely speaking, JAX is like [NumPy](https://numpy.org/) with the addition of\n",
77 | "\n",
78 | "- automatic differentiation \n",
79 | "- automated GPU/TPU support \n",
80 | "- a just-in-time compiler \n",
81 | "\n",
82 | "\n",
83 | "In short, JAX delivers\n",
84 | "\n",
85 | "1. high execution speeds on CPUs due to efficient parallelization and JIT\n",
86 | " compilation, \n",
87 | "1. a powerful and convenient environment for GPU programming, and \n",
88 | "1. the ability to efficiently differentiate smooth functions for optimization\n",
89 | " and estimation. \n",
90 | "\n",
91 | "\n",
92 | "These features make JAX ideal for almost all quantitative economic modeling\n",
93 | "problems that require heavy-duty computing."
94 | ]
95 | },
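{
"cell_type": "markdown",
"id": "5e2a9c71",
"metadata": {},
"source": [
"To make the list above concrete, here is a minimal sketch of our own (not code taken from the lectures) that uses `jax.numpy`, `jax.grad`, `jax.jit` and `jax.devices`.\n",
"\n",
"The function `f` is a hypothetical example, chosen because its gradient, $ 2 \\sin(x) \\cos(x) = \\sin(2x) $ elementwise, is easy to check by hand. The same code should run unchanged on a CPU or a GPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b83d0f4e",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def f(x):\n",
"    # Hypothetical smooth scalar function, written with NumPy-style jax.numpy calls\n",
"    return jnp.sum(jnp.sin(x) ** 2)\n",
"\n",
"grad_f = jax.grad(f)   # automatic differentiation\n",
"f_jit = jax.jit(f)     # just-in-time compilation\n",
"\n",
"x = jnp.linspace(0.0, 1.0, 5)\n",
"print(f_jit(x))\n",
"print(grad_f(x))       # should agree with jnp.sin(2 * x)\n",
"print(jax.devices())   # devices (CPU, GPU or TPU) visible to JAX"
]
},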
96 | {
97 | "cell_type": "markdown",
98 | "id": "ad470294",
99 | "metadata": {},
100 | "source": [
101 | "## How to run these lectures\n",
102 | "\n",
103 | "The easiest way to run these lectures is via [Google Colab](https://colab.research.google.com/).\n",
104 | "\n",
105 | "JAX is pre-installed with GPU support on Colab and Colab provides GPU access\n",
106 | "even on the free tier.\n",
107 | "\n",
108 | "Each lecture has a “play” button on the top right that you can use to launch the\n",
109 | "lecture on Colab.\n",
110 | "\n",
111 | "You might also like to try using JAX locally.\n",
112 | "\n",
113 | "If you do not own a GPU, you can still install JAX for the CPU by following the relevant [install instructions](https://github.com/google/jax).\n",
114 | "\n",
115 | "(We recommend that you install [Anaconda\n",
116 | "Python](https://www.anaconda.com/download) first.)\n",
117 | "\n",
118 | "If you do have a GPU, you can try installing JAX for the GPU by following the\n",
119 | "install instructions for GPUs.\n",
120 | "\n",
121 | "(This is not always trivial but is starting to get easier.)"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "id": "cb0922fd",
127 | "metadata": {},
128 | "source": [
129 | "## Credits\n",
130 | "\n",
131 | "In building this lecture series, we had invaluable assistance from research\n",
132 | "assistants at QuantEcon and our QuantEcon colleagues.\n",
133 | "\n",
134 | "In particular, we thank and credit\n",
135 | "\n",
136 | "- [Shu Hu](https://github.com/shlff) \n",
137 | "- [Smit Lunagariya](https://github.com/Smit-create) \n",
138 | "- [Matthew McKay](https://github.com/mmcky) \n",
139 | "- [Humphrey Yang](https://github.com/HumphreyYang) \n",
140 | "- [Hengcheng Zhang](https://github.com/HengchengZhang) \n",
141 | "- [Frank Wu](https://github.com/chappiewuzefan) "
142 | ]
143 | },
144 | {
145 | "cell_type": "markdown",
146 | "id": "66f95e11",
147 | "metadata": {},
148 | "source": [
149 | "## Prerequisites\n",
150 | "\n",
151 | "We assume that readers have covered most of the QuantEcon lecture\n",
152 | "series [on Python programming](https://python-programming.quantecon.org/intro.html)."
153 | ]
154 | }
155 | ],
156 | "metadata": {
157 | "date": 1765244754.897235,
158 | "filename": "about.md",
159 | "kernelspec": {
160 | "display_name": "Python",
161 | "language": "python3",
162 | "name": "python3"
163 | },
164 | "title": "About"
165 | },
166 | "nbformat": 4,
167 | "nbformat_minor": 5
168 | }
--------------------------------------------------------------------------------
/keras.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "9e08cab2",
6 | "metadata": {},
7 | "source": [
8 | "$$\n",
9 | "\\newcommand{\\argmax}{arg\\,max}\n",
10 | "\\newcommand{\\argmin}{arg\\,min}\n",
11 | "$$"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "20950cb4",
17 | "metadata": {},
18 | "source": [
19 | "# Simple Neural Network Regression with Keras and JAX"
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "id": "4541aa5a",
25 | "metadata": {},
26 | "source": [
27 | "# GPU\n",
28 | "\n",
29 | "This lecture was built using a machine with access to a GPU.\n",
30 | "\n",
31 | "[Google Colab](https://colab.research.google.com/) has a free tier with GPUs\n",
32 | "that you can access as follows:\n",
33 | "\n",
34 | "1. Click on the “play” icon top right \n",
35 | "1. Select Colab \n",
36 | "1. Set the runtime environment to include a GPU \n",
37 | "\n",
38 | "\n",
39 | "In this lecture we show how to implement one-dimensional nonlinear regression\n",
40 | "using a neural network.\n",
41 | "\n",
42 | "We will use the popular deep learning library [Keras](https://keras.io/), which\n",
43 | "provides a simple interface to deep learning.\n",
44 | "\n",
45 | "The emphasis in Keras is on providing an intuitive API, while the heavy lifting is\n",
46 | "done by one of several possible backends.\n",
47 | "\n",
48 | "Currently, the backend library options are TensorFlow, PyTorch, and JAX.\n",
49 | "\n",
50 | "In this lecture we will use JAX.\n",
51 | "\n",
52 | "The objective of this lecture is to provide a very simple introduction to deep\n",
53 | "learning in a regression setting.\n",
54 | "\n",
55 | "Later, in [a separate lecture](https://jax.quantecon.org/jax_nn.html), we will investigate how to do the same learning task using pure JAX, rather than relying on Keras.\n",
56 | "\n",
57 | "We begin this lecture with some standard imports."
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "id": "9f9a25e7",
64 | "metadata": {
65 | "hide-output": false
66 | },
67 | "outputs": [],
68 | "source": [
69 | "import numpy as np\n",
70 | "import matplotlib.pyplot as plt"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "id": "b882d8c1",
76 | "metadata": {},
77 | "source": [
78 | "Let’s install Keras."
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "id": "a2f2307f",
85 | "metadata": {
86 | "hide-output": false
87 | },
88 | "outputs": [],
89 | "source": [
90 | "!pip install keras"
91 | ]
92 | },
93 | {
94 | "cell_type": "markdown",
95 | "id": "70cd4f78",
96 | "metadata": {},
97 | "source": [
98 | "Now we specify that the desired backend is JAX."
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": null,
104 | "id": "45d8c44b",
105 | "metadata": {
106 | "hide-output": false
107 | },
108 | "outputs": [],
109 | "source": [
110 | "import os\n",
111 | "os.environ['KERAS_BACKEND'] = 'jax'"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "id": "4af49655",
117 | "metadata": {},
118 | "source": [
119 | "Now we should be able to import some tools from Keras.\n",
120 | "\n",
121 | "(Without setting the backend to JAX, these imports might fail, unless you have another supported backend, such as TensorFlow (the Keras default) or PyTorch, installed and configured.)"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "id": "cea8ca1c",
128 | "metadata": {
129 | "hide-output": false
130 | },
131 | "outputs": [],
132 | "source": [
133 | "import keras\n",
134 | "from keras import Sequential\n",
135 | "from keras.layers import Dense"
136 | ]
137 | },
138 | {
139 | "cell_type": "markdown",
140 | "id": "679d6c73",
141 | "metadata": {},
142 | "source": [
143 | "## Data\n",
144 | "\n",
145 | "First let’s write a function to generate some data.\n",
146 | "\n",
147 | "The data has the form\n",
148 | "\n",
149 | "$$\n",
150 | "y_i = f(x_i) + \\epsilon_i,\n",
151 | " \\qquad i=1, \\ldots, n,\n",
152 | "$$\n",
153 | "\n",
154 | "where\n",
155 | "\n",
156 | "- the input sequence $ (x_i) $ is an evenly-spaced grid, \n",
157 | "- $ f $ is a nonlinear transformation, and \n",
158 | "- each $ \\epsilon_i $ is independent white noise. \n",
159 | "\n",
160 | "\n",
161 | "Here’s the function that creates vectors `x` and `y` according to the rule\n",
162 | "above."
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": null,
168 | "id": "ce4068d7",
169 | "metadata": {
170 | "hide-output": false
171 | },
172 | "outputs": [],
173 | "source": [
174 | "def generate_data(x_min=0, # Minimum x value\n",
175 | " x_max=5, # Max x value\n",
176 | " data_size=400, # Default size for dataset\n",
177 | " seed=1234):\n",
178 | " np.random.seed(seed)\n",
179 | " x = np.linspace(x_min, x_max, num=data_size)\n",
180 | " \n",
181 | " ϵ = 0.2 * np.random.randn(data_size)\n",
182 | " y = x**0.5 + np.sin(x) + ϵ\n",
183 | " # Keras expects two dimensions, not flat arrays\n",
184 | " x, y = [np.reshape(z, (data_size, 1)) for z in (x, y)]\n",
185 | " return x, y"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "id": "74c096a7",
191 | "metadata": {},
192 | "source": [
193 | "Now we generate some data to train the model."
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": null,
199 | "id": "2b2c8b21",
200 | "metadata": {
201 | "hide-output": false
202 | },
203 | "outputs": [],
204 | "source": [
205 | "x, y = generate_data()"
206 | ]
207 | },
208 | {
209 | "cell_type": "markdown",
210 | "id": "8e828fe3",
211 | "metadata": {},
212 | "source": [
213 | "Here’s a plot of the training data."
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": null,
219 | "id": "1e608434",
220 | "metadata": {
221 | "hide-output": false
222 | },
223 | "outputs": [],
224 | "source": [
225 | "fig, ax = plt.subplots()\n",
226 | "ax.scatter(x, y)\n",
227 | "ax.set_xlabel('x')\n",
228 | "ax.set_ylabel('y')\n",
229 | "plt.show()"
230 | ]
231 | },
232 | {
233 | "cell_type": "markdown",
234 | "id": "d8f8558b",
235 | "metadata": {},
236 | "source": [
237 | "We’ll also use data from the same process for cross-validation."
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": null,
243 | "id": "bdcde9da",
244 | "metadata": {
245 | "hide-output": false
246 | },
247 | "outputs": [],
248 | "source": [
249 | "x_validate, y_validate = generate_data(seed=5678)  # fresh noise draws, not a copy of the training set"
250 | ]
251 | },
252 | {
253 | "cell_type": "markdown",
254 | "id": "f8b78914",
255 | "metadata": {},
256 | "source": [
257 | "## Models\n",
258 | "\n",
259 | "We supply functions to build two types of models."
260 | ]
261 | },
262 | {
263 | "cell_type": "markdown",
264 | "id": "1840963d",
265 | "metadata": {},
266 | "source": [
267 | "### Regression model\n",
268 | "\n",
269 | "The first implements linear regression.\n",
270 | "\n",
271 | "This is achieved by constructing a neural network with just one layer that maps\n",
272 | "to a single dimension (since the prediction is real-valued).\n",
273 | "\n",
274 | "The object `model` will be an instance of `keras.Sequential`, which is used to\n",
275 | "group a stack of layers into a single prediction model."
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": null,
281 | "id": "0f1ba3e5",
282 | "metadata": {
283 | "hide-output": false
284 | },
285 | "outputs": [],
286 | "source": [
287 | "def build_regression_model():\n",
288 | " # Generate an instance of Sequential, to store layers and training attributes\n",
289 | " model = Sequential()\n",
290 | " # Add a single layer with scalar output\n",
291 | " model.add(Dense(units=1)) \n",
292 | " # Configure the model for training\n",
293 | " model.compile(optimizer=keras.optimizers.SGD(), \n",
294 | " loss='mean_squared_error')\n",
295 | " return model"
296 | ]
297 | },
298 | {
299 | "cell_type": "markdown",
300 | "id": "2bef39a1",
301 | "metadata": {},
302 | "source": [
303 | "In the function above you can see that\n",
304 | "\n",
305 | "- we use stochastic gradient descent to train the model, and \n",
306 | "- the loss is mean squared error (MSE). \n",
307 | "\n",
308 | "\n",
309 | "The call `model.add` adds a single layer whose activation function is the identity map (the default for a `Dense` layer).\n",
310 | "\n",
311 | "MSE is the standard loss function for ordinary least squares regression."
312 | ]
313 | },
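{
"cell_type": "markdown",
"id": "7a61c2d9",
"metadata": {},
"source": [
"Since a single `Dense(units=1)` layer with identity activation computes the affine map $ x \\mapsto w x + b $, minimizing MSE over $ (w, b) $ is an ordinary least squares problem.\n",
"\n",
"As an optional benchmark (our own addition, not part of the original lecture), the sketch below solves that problem in closed form with `np.linalg.lstsq`. It regenerates the training data with the same rule and default seed as `generate_data`, but under separate names so that `x` and `y` are left untouched. If SGD training converges, the fitted `regression_model` should end up with a similar intercept, slope and MSE."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e0f4b58a",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Same rule and defaults as generate_data: 400 grid points on [0, 5], seed 1234\n",
"np.random.seed(1234)\n",
"x_ols = np.linspace(0, 5, num=400)\n",
"y_ols = x_ols**0.5 + np.sin(x_ols) + 0.2 * np.random.randn(400)\n",
"\n",
"# Design matrix with a column of ones, playing the role of the Dense bias term\n",
"X = np.column_stack([np.ones_like(x_ols), x_ols])\n",
"coef, *_ = np.linalg.lstsq(X, y_ols, rcond=None)\n",
"b_hat, w_hat = coef\n",
"mse_ols = np.mean((X @ coef - y_ols) ** 2)\n",
"print(f'intercept = {b_hat:.3f}, slope = {w_hat:.3f}, MSE = {mse_ols:.4f}')"
]
},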
314 | {
315 | "cell_type": "markdown",
316 | "id": "23382ccf",
317 | "metadata": {},
318 | "source": [
319 | "### Deep Network\n",
320 | "\n",
321 | "The second function creates a dense (i.e., fully connected) neural network with\n",
322 | "`num_layers` hidden layers (3 by default), where each hidden layer maps to a $ k $-dimensional output space, with $ k $ set by `output_dim` (10 by default)."
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": null,
328 | "id": "052ff4e5",
329 | "metadata": {
330 | "hide-output": false
331 | },
332 | "outputs": [],
333 | "source": [
334 | "def build_nn_model(output_dim=10, num_layers=3, activation_function='tanh'):\n",
335 | " # Create a Keras Model instance using Sequential()\n",
336 | " model = Sequential()\n",
337 | " # Add layers to the network sequentially, from inputs towards outputs\n",
338 | " for i in range(num_layers):\n",
339 | " model.add(Dense(units=output_dim, activation=activation_function))\n",
340 | " # Add a final layer that maps to a scalar value, for regression.\n",
341 | " model.add(Dense(units=1))\n",
342 | " # Embed training configurations\n",
343 | " model.compile(optimizer=keras.optimizers.SGD(), \n",
344 | " loss='mean_squared_error')\n",
345 | " return model"
346 | ]
347 | },
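{
"cell_type": "markdown",
"id": "3c9d7e12",
"metadata": {},
"source": [
"To get a feel for the size of this network, here is a small sketch that counts trainable parameters by hand. The helper `count_dense_params` is hypothetical (it is not part of Keras).\n",
"\n",
"Each `Dense` layer with $ m $ inputs and $ k $ units has $ m k $ weights and $ k $ biases, so with a scalar input and the defaults `output_dim=10` and `num_layers=3` we get $ (1 \\cdot 10 + 10) + 2 (10 \\cdot 10 + 10) + (10 \\cdot 1 + 1) = 251 $ parameters. Once a Keras model has been built (for example, after it has been trained on data), `model.count_params()` should report the same total."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1a8b6c3",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"def count_dense_params(input_dim=1, output_dim=10, num_layers=3):\n",
"    # Hypothetical helper: parameter count for the architecture in build_nn_model\n",
"    total, m = 0, input_dim\n",
"    for _ in range(num_layers):\n",
"        total += m * output_dim + output_dim  # weights + biases of one hidden layer\n",
"        m = output_dim\n",
"    total += m * 1 + 1  # final Dense(units=1) output layer\n",
"    return total\n",
"\n",
"print(count_dense_params())  # 251 with the default arguments"
]
},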
348 | {
349 | "cell_type": "markdown",
350 | "id": "c58d1b6d",
351 | "metadata": {},
352 | "source": [
353 | "### Tracking errors\n",
354 | "\n",
355 | "The following function plots the MSE of the model on the training and validation\n",
356 | "data over the course of training.\n",
357 | "\n",
358 | "Initially the MSE will be relatively high, but it should fall as training proceeds,\n",
359 | "as the parameters are adjusted to better fit the data."
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": null,
365 | "id": "affb57ec",
366 | "metadata": {
367 | "hide-output": false
368 | },
369 | "outputs": [],
370 | "source": [
371 | "def plot_loss_history(training_history, ax):\n",
372 | " # Plot MSE of training data against epoch\n",
373 | " epochs = training_history.epoch\n",
374 | " ax.plot(epochs, \n",
375 | " np.array(training_history.history['loss']), \n",
376 | " label='training loss')\n",
377 | " # Plot MSE of validation data against epoch\n",
378 | " ax.plot(epochs, \n",
379 | " np.array(training_history.history['val_loss']),\n",
380 | " label='validation loss')\n",
381 | " # Add labels\n",
382 | " ax.set_xlabel('Epoch')\n",
383 | " ax.set_ylabel('Loss (Mean squared error)')\n",
384 | " ax.legend()"
385 | ]
386 | },
387 | {
388 | "cell_type": "markdown",
389 | "id": "2ccc4fdf",
390 | "metadata": {},
391 | "source": [
392 | "## Training\n",
393 | "\n",
394 | "Now let’s go ahead and train our models."
395 | ]
396 | },
397 | {
398 | "cell_type": "markdown",
399 | "id": "51c299f9",
400 | "metadata": {},
401 | "source": [
402 | "### Linear regression\n",
403 | "\n",
404 | "We’ll start with linear regression."
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": null,
410 | "id": "1a3fe5e3",
411 | "metadata": {
412 | "hide-output": false
413 | },
414 | "outputs": [],
415 | "source": [
416 | "regression_model = build_regression_model()"
417 | ]
418 | },
419 | {
420 | "cell_type": "markdown",
421 | "id": "386b7ac3",
422 | "metadata": {},
423 | "source": [
424 | "Now we train the model using the training data."
425 | ]
426 | },
427 | {
428 | "cell_type": "code",
429 | "execution_count": null,
430 | "id": "1297d274",
431 | "metadata": {
432 | "hide-output": false
433 | },
434 | "outputs": [],
435 | "source": [
436 | "training_history = regression_model.fit(\n",
437 | " x, y, batch_size=x.shape[0], verbose=0,\n",
438 | " epochs=2000, validation_data=(x_validate, y_validate))"
439 | ]
440 | },
441 | {
442 | "cell_type": "markdown",
443 | "id": "e732a40d",
444 | "metadata": {},
445 | "source": [
446 | "Let’s have a look at the evolution of MSE as the model is trained."
447 | ]
448 | },
449 | {
450 | "cell_type": "code",
451 | "execution_count": null,
452 | "id": "a024eebe",
453 | "metadata": {
454 | "hide-output": false
455 | },
456 | "outputs": [],
457 | "source": [
458 | "fig, ax = plt.subplots()\n",
459 | "plot_loss_history(training_history, ax)\n",
460 | "plt.show()"
461 | ]
462 | },
463 | {
464 | "cell_type": "markdown",
465 | "id": "18024b5f",
466 | "metadata": {},
467 | "source": [
468 | "Let’s print the final MSE on the validation data."
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "execution_count": null,
474 | "id": "17053ff3",
475 | "metadata": {
476 | "hide-output": false
477 | },
478 | "outputs": [],
479 | "source": [
480 | "print(\"Testing loss on the validation set.\")\n",
481 | "regression_model.evaluate(x_validate, y_validate, verbose=2)"
482 | ]
483 | },
484 | {
485 | "cell_type": "markdown",
486 | "id": "b2bf21a0",
487 | "metadata": {},
488 | "source": [
489 | "Here are our predictions on the validation data."
490 | ]
491 | },
492 | {
493 | "cell_type": "code",
494 | "execution_count": null,
495 | "id": "1efc6ace",
496 | "metadata": {
497 | "hide-output": false
498 | },
499 | "outputs": [],
500 | "source": [
501 | "y_predict = regression_model.predict(x_validate, verbose=2)"
502 | ]
503 | },
504 | {
505 | "cell_type": "markdown",
506 | "id": "52bdf066",
507 | "metadata": {},
508 | "source": [
509 | "We use the following function to plot our predictions along with the data."
510 | ]
511 | },
512 | {
513 | "cell_type": "code",
514 | "execution_count": null,
515 | "id": "4e2358c4",
516 | "metadata": {
517 | "hide-output": false
518 | },
519 | "outputs": [],
520 | "source": [
521 | "def plot_results(x, y, y_predict, ax):\n",
522 | " ax.scatter(x, y)\n",
523 | " ax.plot(x, y_predict, label=\"fitted model\", color='black')\n",
524 | " ax.set_xlabel('x')\n",
525 | " ax.set_ylabel('y')"
526 | ]
527 | },
528 | {
529 | "cell_type": "markdown",
530 | "id": "308c497b",
531 | "metadata": {},
532 | "source": [
533 | "Let’s now call the function on the validation data."
534 | ]
535 | },
536 | {
537 | "cell_type": "code",
538 | "execution_count": null,
539 | "id": "5a030972",
540 | "metadata": {
541 | "hide-output": false
542 | },
543 | "outputs": [],
544 | "source": [
545 | "fig, ax = plt.subplots()\n",
546 | "plot_results(x_validate, y_validate, y_predict, ax)\n",
547 | "plt.show()"
548 | ]
549 | },
550 | {
551 | "cell_type": "markdown",
552 | "id": "fb193972",
553 | "metadata": {},
554 | "source": [
555 | "### Deep learning\n",
556 | "\n",
557 | "Now let’s switch to a neural network with multiple layers.\n",
558 | "\n",
559 | "We implement the same steps as before."
560 | ]
561 | },
562 | {
563 | "cell_type": "code",
564 | "execution_count": null,
565 | "id": "7b662ebb",
566 | "metadata": {
567 | "hide-output": false
568 | },
569 | "outputs": [],
570 | "source": [
571 | "nn_model = build_nn_model()"
572 | ]
573 | },
574 | {
575 | "cell_type": "code",
576 | "execution_count": null,
577 | "id": "313b8bc3",
578 | "metadata": {
579 | "hide-output": false
580 | },
581 | "outputs": [],
582 | "source": [
583 | "training_history = nn_model.fit(\n",
584 | " x, y, batch_size=x.shape[0], verbose=0,\n",
585 | " epochs=2000, validation_data=(x_validate, y_validate))"
586 | ]
587 | },
588 | {
589 | "cell_type": "code",
590 | "execution_count": null,
591 | "id": "9f42d477",
592 | "metadata": {
593 | "hide-output": false
594 | },
595 | "outputs": [],
596 | "source": [
597 | "fig, ax = plt.subplots()\n",
598 | "plot_loss_history(training_history, ax)\n",
599 | "plt.show()"
600 | ]
601 | },
602 | {
603 | "cell_type": "markdown",
604 | "id": "b5892fec",
605 | "metadata": {},
606 | "source": [
607 | "Here’s the final MSE for the deep learning model."
608 | ]
609 | },
610 | {
611 | "cell_type": "code",
612 | "execution_count": null,
613 | "id": "b330351c",
614 | "metadata": {
615 | "hide-output": false
616 | },
617 | "outputs": [],
618 | "source": [
619 | "print(\"Testing loss on the validation set.\")\n",
620 | "nn_model.evaluate(x_validate, y_validate, verbose=2)"
621 | ]
622 | },
623 | {
624 | "cell_type": "markdown",
625 | "id": "e1a6512b",
626 | "metadata": {},
627 | "source": [
628 | "You will notice that this loss is much lower than the one we achieved with\n",
629 | "linear regression, suggesting a better fit.\n",
630 | "\n",
631 | "To confirm this, let’s look at the fitted function."
632 | ]
633 | },
634 | {
635 | "cell_type": "code",
636 | "execution_count": null,
637 | "id": "3fdabaff",
638 | "metadata": {
639 | "hide-output": false
640 | },
641 | "outputs": [],
642 | "source": [
643 | "y_predict = nn_model.predict(x_validate, verbose=2)"
644 | ]
645 | },
662 | {
663 | "cell_type": "code",
664 | "execution_count": null,
665 | "id": "febff227",
666 | "metadata": {
667 | "hide-output": false
668 | },
669 | "outputs": [],
670 | "source": [
671 | "fig, ax = plt.subplots()\n",
672 | "plot_results(x_validate, y_validate, y_predict, ax)\n",
673 | "plt.show()"
674 | ]
675 | },
676 | {
677 | "cell_type": "markdown",
678 | "id": "64d31a5c",
679 | "metadata": {},
680 | "source": [
681 | "Not surprisingly, the multilayer neural network does a much better job of fitting the data.\n",
682 | "\n",
683 | "In [a follow-up lecture](https://jax.quantecon.org/jax_nn.html), we will try to achieve the same fit using pure JAX, rather than relying on the Keras front-end."
684 | ]
685 | }
686 | ],
687 | "metadata": {
688 | "date": 1765244755.4090016,
689 | "filename": "keras.md",
690 | "kernelspec": {
691 | "display_name": "Python",
692 | "language": "python3",
693 | "name": "python3"
694 | },
695 | "title": "Simple Neural Network Regression with Keras and JAX"
696 | },
697 | "nbformat": 4,
698 | "nbformat_minor": 5
699 | }
--------------------------------------------------------------------------------
/job_search.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "e4cb52e5",
6 | "metadata": {},
7 | "source": [
8 | "$$\n",
9 | "\\newcommand{\\argmax}{arg\\,max}\n",
10 | "\\newcommand{\\argmin}{arg\\,min}\n",
11 | "$$"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "a408abde",
17 | "metadata": {},
18 | "source": [
19 | "# Job Search"
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "id": "720c7ba3",
25 | "metadata": {},
26 | "source": [
27 | "# GPU\n",
28 | "\n",
29 | "This lecture was built using a machine with access to a GPU.\n",
30 | "\n",
31 | "[Google Colab](https://colab.research.google.com/) has a free tier with GPUs\n",
32 | "that you can access as follows:\n",
33 | "\n",
34 | "1. Click on the “play” icon top right \n",
35 | "1. Select Colab \n",
36 | "1. Set the runtime environment to include a GPU \n",
37 | "\n",
38 | "\n",
39 | "In this lecture we study a basic infinite-horizon job search problem with Markov wage\n",
40 | "draws.\n",
41 | "\n",
42 | ">**Note**\n",
43 | ">\n",
44 | ">For background on infinite horizon job search see, e.g., [DP1](https://dp.quantecon.org/).\n",
45 | "\n",
46 | "The exercise at the end asks you to add risk-sensitive preferences and see how\n",
47 | "the main results change.\n",
48 | "\n",
49 | "In addition to what’s in Anaconda, this lecture will need the following libraries:"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "id": "f3bd6c7c",
56 | "metadata": {
57 | "hide-output": false
58 | },
59 | "outputs": [],
60 | "source": [
61 | "!pip install quantecon"
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "id": "f76d9ee5",
67 | "metadata": {},
68 | "source": [
69 | "We use the following imports."
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "id": "5d8adf8d",
76 | "metadata": {
77 | "hide-output": false
78 | },
79 | "outputs": [],
80 | "source": [
81 | "import matplotlib.pyplot as plt\n",
82 | "import quantecon as qe\n",
83 | "import jax\n",
84 | "import jax.numpy as jnp\n",
85 | "from collections import namedtuple\n",
86 | "\n",
87 | "jax.config.update(\"jax_enable_x64\", True)"
88 | ]
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "id": "33107907",
93 | "metadata": {},
94 | "source": [
95 | "## Model\n",
96 | "\n",
97 | "We study an elementary model where\n",
98 | "\n",
99 | "- jobs are permanent \n",
100 | "- unemployed workers receive current compensation $ c $ \n",
101 | "- the horizon is infinite \n",
102 | "- an unemployed agent discounts the future via discount factor $ \\beta \\in (0,1) $ "
103 | ]
104 | },
105 | {
106 | "cell_type": "markdown",
107 | "id": "53ace009",
108 | "metadata": {},
109 | "source": [
110 | "### Set up\n",
111 | "\n",
112 | "At the start of each period, an unemployed worker receives wage offer $ W_t $.\n",
113 | "\n",
114 | "To build a wage offer process we consider the dynamics\n",
115 | "\n",
116 | "$$\n",
117 | "W_{t+1} = \\rho W_t + \\nu Z_{t+1}\n",
118 | "$$\n",
119 | "\n",
120 | "where $ (Z_t)_{t \\geq 0} $ is IID and standard normal.\n",
121 | "\n",
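122 | "We then discretize this process using Tauchen’s method to produce a finite grid of states and a stochastic matrix $ P $.\n",
123 | "\n",
124 | "Successive wage offers are then drawn from the Markov chain with transition matrix $ P $, whose states are the exponentiated grid points (see `create_js_model` below)."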
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "id": "2352ba3a",
130 | "metadata": {},
131 | "source": [
132 | "### Rewards\n",
133 | "\n",
134 | "Since jobs are permanent, the return to accepting wage offer $ w $ today is\n",
135 | "\n",
136 | "$$\n",
137 | "w + \\beta w + \\beta^2 w + \n",
138 | " \\cdots = \\frac{w}{1-\\beta}\n",
139 | "$$\n",
140 | "\n",
141 | "The Bellman equation is\n",
142 | "\n",
143 | "$$\n",
144 | "v(w) = \\max\n",
145 | " \\left\\{\n",
146 | " \\frac{w}{1-\\beta}, c + \\beta \\sum_{w'} v(w') P(w, w')\n",
147 | " \\right\\}\n",
148 | "$$\n",
149 | "\n",
150 | "We solve this model using value function iteration."
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "id": "f2415736",
156 | "metadata": {},
157 | "source": [
158 | "## Code\n",
159 | "\n",
160 | "Let’s set up a `namedtuple` to store information needed to solve the model."
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "id": "32dcf609",
167 | "metadata": {
168 | "hide-output": false
169 | },
170 | "outputs": [],
171 | "source": [
172 | "Model = namedtuple('Model', ('n', 'w_vals', 'P', 'β', 'c'))"
173 | ]
174 | },
175 | {
176 | "cell_type": "markdown",
177 | "id": "a3f20539",
178 | "metadata": {},
179 | "source": [
180 | "The function below holds default values and populates the `namedtuple`."
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "id": "40f12e37",
187 | "metadata": {
188 | "hide-output": false
189 | },
190 | "outputs": [],
191 | "source": [
192 | "def create_js_model(\n",
193 | " n=500, # wage grid size\n",
194 | " ρ=0.9, # wage persistence\n",
195 | " ν=0.2, # wage volatility\n",
196 | " β=0.99, # discount factor\n",
197 | " c=1.0, # unemployment compensation\n",
198 | " ):\n",
199 | " \"Creates an instance of the job search model with Markov wages.\"\n",
200 | " mc = qe.tauchen(n, ρ, ν)\n",
201 | " w_vals, P = jnp.exp(mc.state_values), jnp.array(mc.P)\n",
202 | " return Model(n, w_vals, P, β, c)"
203 | ]
204 | },
205 | {
206 | "cell_type": "markdown",
207 | "id": "cec2bb8a",
208 | "metadata": {},
209 | "source": [
210 | "Let’s test it:"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": null,
216 | "id": "44705c2c",
217 | "metadata": {
218 | "hide-output": false
219 | },
220 | "outputs": [],
221 | "source": [
222 | "model = create_js_model(β=0.98)"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "id": "77e12fb3",
229 | "metadata": {
230 | "hide-output": false
231 | },
232 | "outputs": [],
233 | "source": [
234 | "model.c"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": null,
240 | "id": "45d394f5",
241 | "metadata": {
242 | "hide-output": false
243 | },
244 | "outputs": [],
245 | "source": [
246 | "model.β"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": null,
252 | "id": "a3e1375e",
253 | "metadata": {
254 | "hide-output": false
255 | },
256 | "outputs": [],
257 | "source": [
258 | "model.w_vals.mean() "
259 | ]
260 | },
261 | {
262 | "cell_type": "markdown",
263 | "id": "6effac50",
264 | "metadata": {},
265 | "source": [
266 | "Here’s the Bellman operator."
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "execution_count": null,
272 | "id": "2e1b87de",
273 | "metadata": {
274 | "hide-output": false
275 | },
276 | "outputs": [],
277 | "source": [
278 | "@jax.jit\n",
279 | "def T(v, model):\n",
280 | " \"\"\"\n",
281 | " The Bellman operator Tv = max{e, c + β E v} with \n",
282 | "\n",
283 | " e(w) = w / (1-β) and (Ev)(w) = E_w[ v(W')]\n",
284 | "\n",
285 | " \"\"\"\n",
286 | " n, w_vals, P, β, c = model\n",
287 | " h = c + β * P @ v\n",
288 | " e = w_vals / (1 - β)\n",
289 | "\n",
290 | " return jnp.maximum(e, h)"
291 | ]
292 | },
293 | {
294 | "cell_type": "markdown",
295 | "id": "9c51c489",
296 | "metadata": {},
297 | "source": [
298 | "The next function computes a $ v $-greedy policy, which is optimal when $ v $ is\n",
299 | "the value function.\n",
300 | "\n",
301 | "The policy takes the form\n",
302 | "\n",
303 | "$$\n",
304 | "\\sigma(w) = \\mathbf 1 \n",
305 | " \\left\\{\n",
306 | " \\frac{w}{1-\\beta} \\geq c + \\beta \\sum_{w'} v(w') P(w, w')\n",
307 | " \\right\\}\n",
308 | "$$\n",
309 | "\n",
310 | "Here $ \\mathbf 1 $ is an indicator function.\n",
311 | "\n",
312 | "- $ \\sigma(w) = 1 $ means stop \n",
313 | "- $ \\sigma(w) = 0 $ means continue. "
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": null,
319 | "id": "a92c62f9",
320 | "metadata": {
321 | "hide-output": false
322 | },
323 | "outputs": [],
324 | "source": [
325 | "@jax.jit\n",
326 | "def get_greedy(v, model):\n",
327 | " \"Get a v-greedy policy.\"\n",
328 | " n, w_vals, P, β, c = model\n",
329 | " e = w_vals / (1 - β)\n",
330 | " h = c + β * P @ v\n",
331 | " σ = jnp.where(e >= h, 1, 0)\n",
332 | " return σ"
333 | ]
334 | },
335 | {
336 | "cell_type": "markdown",
337 | "id": "8b05a07f",
338 | "metadata": {},
339 | "source": [
340 | "Here’s a routine for value function iteration."
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": null,
346 | "id": "ed121cfe",
347 | "metadata": {
348 | "hide-output": false
349 | },
350 | "outputs": [],
351 | "source": [
352 | "def vfi(model, max_iter=10_000, tol=1e-4):\n",
353 | " \"Solve the infinite-horizon Markov job search model by VFI.\"\n",
354 | " print(\"Starting VFI iteration.\")\n",
355 | " v = jnp.zeros_like(model.w_vals) # Initial guess\n",
356 | " i = 0\n",
357 | " error = tol + 1\n",
358 | "\n",
359 | " while error > tol and i < max_iter:\n",
360 | " new_v = T(v, model)\n",
361 | " error = jnp.max(jnp.abs(new_v - v))\n",
362 | " i += 1\n",
363 | " v = new_v\n",
364 | "\n",
365 | " v_star = v\n",
366 | " σ_star = get_greedy(v_star, model)\n",
367 | " return v_star, σ_star"
368 | ]
369 | },
370 | {
371 | "cell_type": "markdown",
372 | "id": "31486c61",
373 | "metadata": {},
374 | "source": [
375 | "## Computing the solution\n",
376 | "\n",
377 | "Let’s set up and solve the model."
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": null,
383 | "id": "9a1569bc",
384 | "metadata": {
385 | "hide-output": false
386 | },
387 | "outputs": [],
388 | "source": [
389 | "model = create_js_model()\n",
390 | "n, w_vals, P, β, c = model\n",
391 | "\n",
392 | "v_star, σ_star = vfi(model)"
393 | ]
394 | },
395 | {
396 | "cell_type": "markdown",
397 | "id": "b6d31d9f",
398 | "metadata": {},
399 | "source": [
400 | "Here’s the optimal policy:"
401 | ]
402 | },
403 | {
404 | "cell_type": "code",
405 | "execution_count": null,
406 | "id": "2aa17c35",
407 | "metadata": {
408 | "hide-output": false
409 | },
410 | "outputs": [],
411 | "source": [
412 | "fig, ax = plt.subplots()\n",
413 | "ax.plot(w_vals, σ_star)\n",
414 | "ax.set_xlabel(\"wage values\")\n",
415 | "ax.set_ylabel(\"optimal choice (stop=1)\")\n",
416 | "plt.show()"
417 | ]
418 | },
419 | {
420 | "cell_type": "markdown",
421 | "id": "73839e6b",
422 | "metadata": {},
423 | "source": [
424 | "We compute the reservation wage as the first $ w $ such that $ \\sigma(w)=1 $."
425 | ]
426 | },
427 | {
428 | "cell_type": "code",
429 | "execution_count": null,
430 | "id": "72647e40",
431 | "metadata": {
432 | "hide-output": false
433 | },
434 | "outputs": [],
435 | "source": [
436 | "stop_indices = jnp.where(σ_star == 1)\n",
437 | "stop_indices"
438 | ]
439 | },
440 | {
441 | "cell_type": "code",
442 | "execution_count": null,
443 | "id": "d81ccf2c",
444 | "metadata": {
445 | "hide-output": false
446 | },
447 | "outputs": [],
448 | "source": [
449 | "res_wage_index = min(stop_indices[0])"
450 | ]
451 | },
452 | {
453 | "cell_type": "code",
454 | "execution_count": null,
455 | "id": "497a530d",
456 | "metadata": {
457 | "hide-output": false
458 | },
459 | "outputs": [],
460 | "source": [
461 | "res_wage = w_vals[res_wage_index]"
462 | ]
463 | },
464 | {
465 | "cell_type": "markdown",
466 | "id": "26366628",
467 | "metadata": {},
468 | "source": [
469 | "Here’s a joint plot of the value function and the reservation wage."
470 | ]
471 | },
472 | {
473 | "cell_type": "code",
474 | "execution_count": null,
475 | "id": "cd9a2b1b",
476 | "metadata": {
477 | "hide-output": false
478 | },
479 | "outputs": [],
480 | "source": [
481 | "fig, ax = plt.subplots()\n",
482 | "ax.plot(w_vals, v_star, alpha=0.8, label=\"value function\")\n",
483 | "ax.vlines((res_wage,), 150, 400, 'k', ls='--', label=\"reservation wage\")\n",
484 | "ax.legend(frameon=False, fontsize=12, loc=\"lower right\")\n",
485 | "ax.set_xlabel(\"$w$\", fontsize=12)\n",
486 | "plt.show()"
487 | ]
488 | },
489 | {
490 | "cell_type": "markdown",
491 | "id": "1909356b",
492 | "metadata": {},
493 | "source": [
494 | "## Exercise"
495 | ]
496 | },
497 | {
498 | "cell_type": "markdown",
499 | "id": "a8dc3054",
500 | "metadata": {},
501 | "source": [
502 | "## Exercise 10.1\n",
503 | "\n",
504 | "In the setting above, the agent is risk-neutral vis-a-vis future utility risk.\n",
505 | "\n",
506 | "Now solve the same problem but this time assuming that the agent has risk-sensitive\n",
507 | "preferences, which are a type of nonlinear recursive preferences.\n",
508 | "\n",
509 | "The Bellman equation becomes\n",
510 | "\n",
511 | "$$\n",
512 | "v(w) = \\max\n",
513 | " \\left\\{\n",
514 | " \\frac{w}{1-\\beta}, \n",
515 | " c + \\frac{\\beta}{\\theta}\n",
516 | " \\ln \\left[ \n",
517 | " \\sum_{w'} \\exp(\\theta v(w')) P(w, w')\n",
518 | " \\right]\n",
519 | " \\right\\}\n",
520 | "$$\n",
521 | "\n",
522 | "When $ \\theta < 0 $ the agent is risk averse.\n",
523 | "\n",
524 | "Solve the model when $ \\theta = -0.1 $ and compare your result to the risk neutral\n",
525 | "case.\n",
526 | "\n",
527 | "Try to interpret your result.\n",
528 | "\n",
529 | "You can start with the following code:"
530 | ]
531 | },
532 | {
533 | "cell_type": "code",
534 | "execution_count": null,
535 | "id": "17ee5327",
536 | "metadata": {
537 | "hide-output": false
538 | },
539 | "outputs": [],
540 | "source": [
541 | "RiskModel = namedtuple('RiskModel', ('n', 'w_vals', 'P', 'β', 'c', 'θ'))\n",
542 | "\n",
543 | "def create_risk_sensitive_js_model(\n",
544 | " n=500, # wage grid size\n",
545 | " ρ=0.9, # wage persistence\n",
546 | " ν=0.2, # wage volatility\n",
547 | " β=0.99, # discount factor\n",
548 | " c=1.0, # unemployment compensation\n",
549 | " θ=-0.1 # risk parameter\n",
550 | " ):\n",
551 | " \"Creates an instance of the job search model with Markov wages.\"\n",
552 | " mc = qe.tauchen(n, ρ, ν)\n",
553 | " w_vals, P = jnp.exp(mc.state_values), mc.P\n",
554 | " P = jnp.array(P)\n",
555 | " return RiskModel(n, w_vals, P, β, c, θ)"
556 | ]
557 | },
558 | {
559 | "cell_type": "markdown",
560 | "id": "055a1e01",
561 | "metadata": {},
562 | "source": [
563 | "Now you need to modify `T` and `get_greedy` and then run value function iteration again."
564 | ]
565 | },
566 | {
567 | "cell_type": "markdown",
568 | "id": "fc67c4ed",
569 | "metadata": {},
570 | "source": [
571 | "## Solution"
572 | ]
573 | },
574 | {
575 | "cell_type": "code",
576 | "execution_count": null,
577 | "id": "c3a0e61f",
578 | "metadata": {
579 | "hide-output": false
580 | },
581 | "outputs": [],
582 | "source": [
583 | "RiskModel = namedtuple('RiskModel', ('n', 'w_vals', 'P', 'β', 'c', 'θ'))\n",
584 | "\n",
585 | "def create_risk_sensitive_js_model(\n",
586 | " n=500, # wage grid size\n",
587 | " ρ=0.9, # wage persistence\n",
588 | " ν=0.2, # wage volatility\n",
589 | " β=0.99, # discount factor\n",
590 | " c=1.0, # unemployment compensation\n",
591 | " θ=-0.1 # risk parameter\n",
592 | " ):\n",
593 | " \"Creates an instance of the job search model with Markov wages.\"\n",
594 | " mc = qe.tauchen(n, ρ, ν)\n",
595 | " w_vals, P = jnp.exp(mc.state_values), mc.P\n",
596 | " P = jnp.array(P)\n",
597 | " return RiskModel(n, w_vals, P, β, c, θ)\n",
598 | "\n",
599 | "\n",
600 | "@jax.jit\n",
601 | "def T_rs(v, model):\n",
602 | " \"\"\"\n",
603 | " The Bellman operator Tv = max{e, c + β R v} with \n",
604 | "\n",
605 | " e(w) = w / (1-β) and\n",
606 | "\n",
607 | " (Rv)(w) = (1/θ) ln{E_w[ exp(θ v(W'))]}\n",
608 | "\n",
609 | " \"\"\"\n",
610 | " n, w_vals, P, β, c, θ = model\n",
611 | " h = c + (β / θ) * jnp.log(P @ (jnp.exp(θ * v)))\n",
612 | " e = w_vals / (1 - β)\n",
613 | "\n",
614 | " return jnp.maximum(e, h)\n",
615 | "\n",
616 | "\n",
617 | "@jax.jit\n",
618 | "def get_greedy_rs(v, model):\n",
619 | " \" Get a v-greedy policy.\"\n",
620 | " n, w_vals, P, β, c, θ = model\n",
621 | " e = w_vals / (1 - β)\n",
622 | " h = c + (β / θ) * jnp.log(P @ (jnp.exp(θ * v)))\n",
623 | " σ = jnp.where(e >= h, 1, 0)\n",
624 | " return σ\n",
625 | "\n",
626 | "\n",
627 | "\n",
628 | "def vfi(model, max_iter=10_000, tol=1e-4):\n",
629 | " \"Solve the infinite-horizon Markov job search model by VFI.\"\n",
630 | " print(\"Starting VFI iteration.\")\n",
631 | " v = jnp.zeros_like(model.w_vals) # Initial guess\n",
632 | " i = 0\n",
633 | " error = tol + 1\n",
634 | "\n",
635 | " while error > tol and i < max_iter:\n",
636 | " new_v = T_rs(v, model)\n",
637 | " error = jnp.max(jnp.abs(new_v - v))\n",
638 | " i += 1\n",
639 | " v = new_v\n",
640 | "\n",
641 | " v_star = v\n",
642 | " σ_star = get_greedy_rs(v_star, model)\n",
643 | " return v_star, σ_star\n",
644 | "\n",
645 | "\n",
646 | "\n",
647 | "model_rs = create_risk_sensitive_js_model()\n",
648 | "\n",
649 | "n, w_vals, P, β, c, θ = model_rs\n",
650 | "\n",
651 | "v_star_rs, σ_star_rs = vfi(model_rs)"
652 | ]
653 | },
654 | {
655 | "cell_type": "markdown",
656 | "id": "7bae9631",
657 | "metadata": {},
658 | "source": [
659 | "Let’s plot the results together with the original risk neutral case and see what we get."
660 | ]
661 | },
662 | {
663 | "cell_type": "code",
664 | "execution_count": null,
665 | "id": "3d7985bb",
666 | "metadata": {
667 | "hide-output": false
668 | },
669 | "outputs": [],
670 | "source": [
671 | "stop_indices = jnp.where(σ_star_rs == 1)\n",
672 | "res_wage_index = min(stop_indices[0])\n",
673 | "res_wage_rs = w_vals[res_wage_index]"
674 | ]
675 | },
676 | {
677 | "cell_type": "code",
678 | "execution_count": null,
679 | "id": "929ef432",
680 | "metadata": {
681 | "hide-output": false
682 | },
683 | "outputs": [],
684 | "source": [
685 | "fig, ax = plt.subplots()\n",
686 | "ax.plot(w_vals, v_star, alpha=0.8, label=\"risk neutral $v$\")\n",
687 | "ax.plot(w_vals, v_star_rs, alpha=0.8, label=\"risk sensitive $v$\")\n",
688 | "ax.vlines((res_wage,), 100, 400, ls='--', color='darkblue', \n",
689 | " alpha=0.5, label=r\"risk neutral $\\bar w$\")\n",
690 | "ax.vlines((res_wage_rs,), 100, 400, ls='--', color='orange', \n",
691 | " alpha=0.5, label=r\"risk sensitive $\\bar w$\")\n",
692 | "ax.legend(frameon=False, fontsize=12, loc=\"lower right\")\n",
693 | "ax.set_xlabel(\"$w$\", fontsize=12)\n",
694 | "plt.show()"
695 | ]
696 | },
697 | {
698 | "cell_type": "markdown",
699 | "id": "7cca1d2a",
700 | "metadata": {},
701 | "source": [
702 | "The figure shows that the reservation wage shifts down under risk-sensitive preferences.\n",
703 | "\n",
704 | "This makes sense – the agent does not like risk and hence is more inclined to\n",
705 | "accept the current offer, even when it’s lower."
706 | ]
707 | }
708 | ],
709 | "metadata": {
710 | "date": 1765244755.386771,
711 | "filename": "job_search.md",
712 | "kernelspec": {
713 | "display_name": "Python",
714 | "language": "python3",
715 | "name": "python3"
716 | },
717 | "title": "Job Search"
718 | },
719 | "nbformat": 4,
720 | "nbformat_minor": 5
721 | }
--------------------------------------------------------------------------------
/kesten_processes.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "7a07b487",
6 | "metadata": {},
7 | "source": [
8 | "$$\n",
9 | "\\newcommand{\\argmax}{arg\\,max}\n",
10 | "\\newcommand{\\argmin}{arg\\,min}\n",
11 | "$$"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "e8cb22f4",
17 | "metadata": {},
18 | "source": [
19 | "# Kesten Processes and Firm Dynamics\n",
20 | "\n",
21 | "\n",
22 | ""
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "id": "db370006",
28 | "metadata": {},
29 | "source": [
30 | "# GPU\n",
31 | "\n",
32 | "This lecture was built using a machine with access to a GPU.\n",
33 | "\n",
34 | "[Google Colab](https://colab.research.google.com/) has a free tier with GPUs\n",
35 | "that you can access as follows:\n",
36 | "\n",
37 | "1. Click on the “play” icon top right \n",
38 | "1. Select Colab \n",
39 | "1. Set the runtime environment to include a GPU \n",
40 | "\n",
41 | "\n",
42 | "In addition to JAX and Anaconda, this lecture will need the following libraries:"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "id": "35806a7e",
49 | "metadata": {
50 | "hide-output": false
51 | },
52 | "outputs": [],
53 | "source": [
54 | "!pip install quantecon"
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "id": "8c670a2a",
60 | "metadata": {},
61 | "source": [
62 | "## Overview\n",
63 | "\n",
64 | "This lecture describes Kesten processes, which are an important class of\n",
65 | "stochastic processes, and their application to firm dynamics.\n",
66 | "\n",
67 | "The lecture draws on [an earlier QuantEcon lecture](https://python.quantecon.org/kesten_processes.html),\n",
68 | "which uses Numba to accelerate the computations.\n",
69 | "\n",
70 | "In that earlier lecture you can find a more detailed discussion of the concepts involved.\n",
71 | "\n",
72 | "This lecture focuses on implementing the same computations in JAX.\n",
73 | "\n",
74 | "Let’s start with some imports:"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "id": "0c295302",
81 | "metadata": {
82 | "hide-output": false
83 | },
84 | "outputs": [],
85 | "source": [
86 | "import matplotlib.pyplot as plt\n",
87 | "import quantecon as qe\n",
88 | "import jax\n",
89 | "import jax.numpy as jnp\n",
90 | "from jax import random\n",
91 | "from jax import lax\n",
92 | "from quantecon import tic, toc\n",
93 | "from typing import NamedTuple\n",
94 | "from functools import partial"
95 | ]
96 | },
97 | {
98 | "cell_type": "markdown",
99 | "id": "91a6ba4a",
100 | "metadata": {},
101 | "source": [
102 | "Let’s check the GPU we are running on:"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "id": "6464bb03",
109 | "metadata": {
110 | "hide-output": false
111 | },
112 | "outputs": [],
113 | "source": [
114 | "!nvidia-smi"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "id": "ecf858e4",
120 | "metadata": {},
121 | "source": [
122 | "## Kesten processes\n",
123 | "\n",
124 | "\n",
125 | "\n",
126 | "A **Kesten process** is a stochastic process of the form\n",
127 | "\n",
128 | "\n",
129 | "\n",
130 | "$$\n",
131 | "X_{t+1} = a_{t+1} X_t + \\eta_{t+1} \\tag{6.1}\n",
132 | "$$\n",
133 | "\n",
134 | "where $ \\{a_t\\}_{t \\geq 1} $ and $ \\{\\eta_t\\}_{t \\geq 1} $ are IID\n",
135 | "sequences.\n",
136 | "\n",
137 | "We are interested in the dynamics of $ \\{X_t\\}_{t \\geq 0} $ when $ X_0 $ is given.\n",
138 | "\n",
139 | "We will focus on the nonnegative scalar case, where $ X_t $ takes values in $ \\mathbb R_+ $.\n",
140 | "\n",
141 | "In particular, we will assume that\n",
142 | "\n",
143 | "- the initial condition $ X_0 $ is nonnegative, \n",
144 | "- $ \\{a_t\\}_{t \\geq 1} $ is a nonnegative IID stochastic process and \n",
145 | "- $ \\{\\eta_t\\}_{t \\geq 1} $ is another nonnegative IID stochastic process, independent of the first. "
146 | ]
147 | },
148 | {
149 | "cell_type": "markdown",
150 | "id": "8b666033",
151 | "metadata": {},
152 | "source": [
153 | "### Application: firm dynamics\n",
154 | "\n",
155 | "In this section we apply Kesten process theory to the study of firm dynamics."
156 | ]
157 | },
158 | {
159 | "cell_type": "markdown",
160 | "id": "4acb3b07",
161 | "metadata": {},
162 | "source": [
163 | "#### Gibrat’s law\n",
164 | "\n",
165 | "It was postulated many years ago by Robert Gibrat that firm size evolves\n",
166 | "according to a simple rule whereby size next period is proportional to current\n",
167 | "size.\n",
168 | "\n",
169 | "This is now known as [Gibrat’s law of proportional growth](https://en.wikipedia.org/wiki/Gibrat%27s_law).\n",
170 | "\n",
171 | "We can express this idea by stating that a suitably defined measure\n",
172 | "$ s_t $ of firm size obeys\n",
173 | "\n",
174 | "\n",
175 | "\n",
176 | "$$\n",
177 | "\\frac{s_{t+1}}{s_t} = a_{t+1} \\tag{6.2}\n",
178 | "$$\n",
179 | "\n",
180 | "for some positive IID sequence $ \\{a_t\\} $.\n",
181 | "\n",
182 | "Subsequent empirical research has shown that this specification is not accurate,\n",
183 | "particularly for small firms.\n",
184 | "\n",
185 | "However, we can get close to the data by modifying [(6.2)](#equation-firm-dynam-gb) to\n",
186 | "\n",
187 | "\n",
188 | "\n",
189 | "$$\n",
190 | "s_{t+1} = a_{t+1} s_t + b_{t+1} \\tag{6.3}\n",
191 | "$$\n",
192 | "\n",
193 | "where $ \\{a_t\\} $ and $ \\{b_t\\} $ are both IID and independent of each\n",
194 | "other.\n",
195 | "\n",
196 | "We now study the implications of this specification."
197 | ]
198 | },
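To see why (6.2) on its own is problematic, take logs: $\ln s_T = \ln s_0 + \sum_{t=1}^{T} \ln a_t$. Suppose, as an assumption for illustration, that $\ln a_t \sim N(\mu, \sigma^2)$. Then $\operatorname{Var}(\ln s_T) = T \sigma^2$ grows without bound, and there is no stationary size distribution. The sketch below checks this variance formula by simulation. `gibrat_log_size` is a hypothetical helper, not part of the lecture.

```python
import jax.numpy as jnp
from jax import random

def gibrat_log_size(key, T, M=50_000, μ=0.0, σ=0.1, s0=1.0):
    """Cross-section of ln s_T when s_{t+1} = a_{t+1} s_t and ln a ~ N(μ, σ²)."""
    log_a = μ + σ * random.normal(key, (T, M))
    # ln s_T = ln s_0 + ln a_1 + ... + ln a_T
    return jnp.log(s0) + log_a.sum(axis=0)

σ = 0.1
for T in (10, 50, 250):
    log_s = gibrat_log_size(random.PRNGKey(T), T, σ=σ)
    print(T, float(log_s.var()), T * σ**2)   # sample variance vs. T σ²
```

The additive term $b_{t+1}$ in (6.3) is what prevents firm sizes from collapsing toward zero or spreading out forever when $\mathbb{E} \ln a_t < 0$.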
199 | {
200 | "cell_type": "markdown",
201 | "id": "6cebea72",
202 | "metadata": {},
203 | "source": [
204 | "#### Heavy tails\n",
205 | "\n",
206 | "If the conditions of the [Kesten–Goldie Theorem](https://python.quantecon.org/kesten_processes.html#the-kestengoldie-theorem)\n",
207 | "are satisfied, then [(6.3)](#equation-firm-dynam) implies that the firm size distribution will have Pareto tails.\n",
208 | "\n",
209 | "This matches empirical findings across many data sets.\n",
210 | "\n",
211 | "But there is another unrealistic aspect of the firm dynamics specified in [(6.3)](#equation-firm-dynam) that we need to address: it ignores entry and exit.\n",
212 | "\n",
213 | "In any given period and in any given market, we observe significant numbers of\n",
214 | "firms entering and exiting the market.\n",
215 | "\n",
216 | "In this setting, firm dynamics can be expressed as\n",
217 | "\n",
218 | "\n",
219 | "\n",
220 | "$$\n",
221 | "s_{t+1} = e_{t+1} \\mathbb{1}\\{s_t < \\bar s\\} +\n",
222 | " (a_{t+1} s_t + b_{t+1}) \\mathbb{1}\\{s_t \\geq \\bar s\\} \\tag{6.4}\n",
223 | "$$\n",
224 | "\n",
225 | "The motivation behind and interpretation of [(6.4)](#equation-firm-dynam-ee) can be found in\n",
226 | "[our earlier Kesten process lecture](https://python.quantecon.org/kesten_processes.html).\n",
227 | "\n",
228 | "What can we say about dynamics?\n",
229 | "\n",
230 | "Although [(6.4)](#equation-firm-dynam-ee) is not a Kesten process, it does update in the\n",
231 | "same way as a Kesten process when $ s_t $ is large.\n",
232 | "\n",
233 | "So perhaps its stationary distribution still has Pareto tails?\n",
234 | "\n",
235 | "We can investigate this question via simulation and rank-size plots.\n",
236 | "\n",
237 | "The approach will be to\n",
238 | "\n",
239 | "1. generate $ M $ draws of $ s_T $ when $ M $ and $ T $ are large and \n",
240 | "1. plot the largest 1,000 of the resulting draws in a rank-size plot. \n",
241 | "\n",
242 | "\n",
243 | "(The distribution of $ s_T $ will be close to the stationary distribution\n",
244 | "when $ T $ is large.)\n",
245 | "\n",
246 | "In the simulation, we assume that each of $ a_t, b_t $ and $ e_t $ is lognormal.\n",
247 | "\n",
248 | "Here’s a class to store parameters:"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": null,
254 | "id": "3e0db88b",
255 | "metadata": {
256 | "hide-output": false
257 | },
258 | "outputs": [],
259 | "source": [
260 | "class Firm(NamedTuple):\n",
261 | " μ_a: float = -0.5\n",
262 | " σ_a: float = 0.1\n",
263 | " μ_b: float = 0.0\n",
264 | " σ_b: float = 0.5\n",
265 | " μ_e: float = 0.0\n",
266 | " σ_e: float = 0.5\n",
267 | " s_bar: float = 1.0"
268 | ]
269 | },
270 | {
271 | "cell_type": "markdown",
272 | "id": "5f6185f9",
273 | "metadata": {},
274 | "source": [
275 | "Here’s code to update a cross-section of firms according to the dynamics in\n",
276 | "[(6.4)](#equation-firm-dynam-ee)."
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": null,
282 | "id": "2bf8fe0a",
283 | "metadata": {
284 | "hide-output": false
285 | },
286 | "outputs": [],
287 | "source": [
288 | "@jax.jit\n",
289 | "def update_cross_section(s, a, b, e, firm):\n",
290 | " s_bar = firm.s_bar\n",
291 | " s = jnp.where(s < s_bar, e, a * s + b)\n",
292 | " return s"
293 | ]
294 | },
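The next sketch applies the rule in (6.4) to three hypothetical firms. `entry_exit_update` is an illustrative standalone function, not part of the lecture code, and it mirrors the `jnp.where` call used there. A firm sitting exactly at the threshold $\bar s$ follows the Kesten update, which matches the indicator $\mathbb{1}\{s_t \geq \bar s\}$.

```python
import jax.numpy as jnp

def entry_exit_update(s, a, b, e, s_bar):
    # Firms below s_bar exit and are replaced by entrants of size e;
    # all other firms follow the Kesten update a * s + b
    return jnp.where(s < s_bar, e, a * s + b)

s = jnp.array([0.5, 2.0, 1.0])   # current sizes: below, above, and at the threshold
a = jnp.array([0.5, 0.5, 0.5])
b = jnp.array([1.0, 1.0, 1.0])
e = jnp.array([3.0, 3.0, 3.0])
print(entry_exit_update(s, a, b, e, s_bar=1.0))   # values 3.0 (entrant), 2.0, 1.5
```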
295 | {
296 | "cell_type": "markdown",
297 | "id": "988a8179",
298 | "metadata": {},
299 | "source": [
300 | "Now we write a `for` loop that repeatedly calls `update_cross_section` to push a\n",
301 | "cross-section of firms forward in time.\n",
302 | "\n",
303 | "For sufficiently large `T`, the cross-section it returns (the cross-section at\n",
304 | "time `T`) approximates the stationary firm size distribution."
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": null,
310 | "id": "b87a0cd9",
311 | "metadata": {
312 | "hide-output": false
313 | },
314 | "outputs": [],
315 | "source": [
316 | "def generate_cross_section(\n",
317 | " firm, M=500_000, T=500, s_init=1.0, seed=123\n",
318 | " ):\n",
319 | "\n",
320 | " μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm\n",
321 | " key = random.PRNGKey(seed)\n",
322 | "\n",
323 | " # Initialize the cross-section to a common value\n",
324 | " s = jnp.full((M, ), s_init)\n",
325 | "\n",
326 | " # Perform updates on s for time t\n",
327 | " for t in range(T):\n",
328 | " key, *subkeys = random.split(key, 4)\n",
329 | " a = μ_a + σ_a * random.normal(subkeys[0], (M,))\n",
330 | " b = μ_b + σ_b * random.normal(subkeys[1], (M,))\n",
331 | " e = μ_e + σ_e * random.normal(subkeys[2], (M,))\n",
332 | " # Exponentiate shocks\n",
333 | " a, b, e = jax.tree.map(jnp.exp, (a, b, e))\n",
334 | " # Update the cross-section of firms\n",
335 | " s = update_cross_section(s, a, b, e, firm)\n",
336 | "\n",
337 | " return s"
338 | ]
339 | },
340 | {
341 | "cell_type": "markdown",
342 | "id": "80a8f0ff",
343 | "metadata": {},
344 | "source": [
345 | "Let’s try running the code and generating a cross-section."
346 | ]
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": null,
351 | "id": "aa982236",
352 | "metadata": {
353 | "hide-output": false
354 | },
355 | "outputs": [],
356 | "source": [
357 | "firm = Firm()\n",
358 | "tic()\n",
359 | "data = generate_cross_section(firm).block_until_ready()\n",
360 | "toc()"
361 | ]
362 | },
363 | {
364 | "cell_type": "markdown",
365 | "id": "92989c4f",
366 | "metadata": {},
367 | "source": [
368 | "We run the function again so we can see the speed without compile time."
369 | ]
370 | },
371 | {
372 | "cell_type": "code",
373 | "execution_count": null,
374 | "id": "885e3c94",
375 | "metadata": {
376 | "hide-output": false
377 | },
378 | "outputs": [],
379 | "source": [
380 | "tic()\n",
381 | "data = generate_cross_section(firm).block_until_ready()\n",
382 | "toc()"
383 | ]
384 | },
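Why is the second call faster? `jax.jit` traces and compiles a function the first time it sees a given combination of argument shapes and dtypes. After that, it reuses the cached executable. The sketch below makes this visible by counting how often the Python body runs, which only happens during tracing. `scaled_sum` is a hypothetical function used only for this demonstration.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1          # Python side effect: runs only while JAX traces
    return 2.0 * jnp.sum(x)

scaled_sum(jnp.ones(1_000))   # first call: trace and compile
scaled_sum(jnp.ones(1_000))   # same shape and dtype: cached executable is reused
scaled_sum(jnp.ones(2_000))   # new shape: traced and compiled again
print(trace_count)            # 2
```

`generate_cross_section` is not jitted as a whole. Roughly speaking, its first call pays for compiling `update_cross_section` and the array operations inside the loop, and later calls with the same shapes reuse that work.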
385 | {
386 | "cell_type": "markdown",
387 | "id": "bdee97b1",
388 | "metadata": {},
389 | "source": [
390 | "Let’s produce the rank-size plot and check the distribution:"
391 | ]
392 | },
393 | {
394 | "cell_type": "code",
395 | "execution_count": null,
396 | "id": "d467d198",
397 | "metadata": {
398 | "hide-output": false
399 | },
400 | "outputs": [],
401 | "source": [
402 | "fig, ax = plt.subplots()\n",
403 | "\n",
404 | "rank_data, size_data = qe.rank_size(data, c=0.01)\n",
405 | "ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5)\n",
406 | "ax.set_xlabel(\"log rank\")\n",
407 | "ax.set_ylabel(\"log size\")\n",
408 | "\n",
409 | "plt.show()"
410 | ]
411 | },
412 | {
413 | "cell_type": "markdown",
414 | "id": "4227bae9",
415 | "metadata": {},
416 | "source": [
417 | "The upper tail of the rank-size plot is close to a straight line, which is consistent with a Pareto tail."
418 | ]
419 | },
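A straight line on a log-log rank-size plot is exactly what a Pareto tail predicts. Suppose $\mathbb{P}\{s > x\} \approx C x^{-\alpha}$ for large $x$. Then the firm with rank $i$ in a sample of size $M$ has size roughly $(CM/i)^{1/\alpha}$, so log size is linear in log rank with slope $-1/\alpha$.

The sketch below checks this on synthetic Pareto data with an assumed $\alpha = 1.5$. `rank_size_slope` is a hypothetical helper that mimics the idea behind `qe.rank_size` (sort in descending order, keep the top fraction `c`) and then fits a line by least squares.

```python
import numpy as np

def rank_size_slope(data, c=0.01):
    """Least-squares slope of log size on log rank for the top fraction c of data."""
    w = np.sort(data)[::-1]            # sizes in descending order
    w = w[:int(len(w) * c)]            # keep only the largest observations
    rank = np.arange(1, len(w) + 1)
    slope, _ = np.polyfit(np.log(rank), np.log(w), 1)
    return slope

rng = np.random.default_rng(1234)
α = 1.5
# Pareto(α) draws with minimum 1 via inverse transform: S = U^(-1/α), U in (0, 1]
sizes = (1 - rng.uniform(size=500_000)) ** (-1 / α)
print(rank_size_slope(sizes), -1 / α)   # estimated vs. theoretical slope
```

The same helper could be applied to the simulated `data` to get a rough estimate of the tail index. Keep in mind that least-squares fits to rank-size plots are known to be biased in small samples.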
420 | {
421 | "cell_type": "markdown",
422 | "id": "4bbadac5",
423 | "metadata": {},
424 | "source": [
425 | "#### Alternative implementation with `lax.fori_loop`\n",
426 | "\n",
427 | "Although we JIT-compiled some of the code above,\n",
428 | "we did not JIT-compile the `for` loop.\n",
429 | "\n",
430 | "Let’s try squeezing out a bit more speed\n",
431 | "by\n",
432 | "\n",
433 | "- replacing the `for` loop with [`lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html) and \n",
434 | "- JIT-compiling the whole function. \n",
435 | "\n",
436 | "\n",
437 | "Here is the `lax.fori_loop` version:"
438 | ]
439 | },
440 | {
441 | "cell_type": "code",
442 | "execution_count": null,
443 | "id": "b17ff2ec",
444 | "metadata": {
445 | "hide-output": false
446 | },
447 | "outputs": [],
448 | "source": [
449 | "@jax.jit\n",
450 | "def generate_cross_section_lax(\n",
451 | " firm, T=500, M=500_000, s_init=1.0, seed=123\n",
452 | " ):\n",
453 | "\n",
454 | " μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm\n",
455 | " key = random.PRNGKey(seed)\n",
456 | " \n",
457 | " # Initial cross section\n",
458 | " s = jnp.full((M, ), s_init)\n",
459 | "\n",
460 | " def update_cross_section(t, state):\n",
461 | " s, key = state\n",
462 | " key, *subkeys = jax.random.split(key, 4)\n",
463 | " # Generate current random draws \n",
464 | " a = μ_a + σ_a * random.normal(subkeys[0], (M,))\n",
465 | " b = μ_b + σ_b * random.normal(subkeys[1], (M,))\n",
466 | " e = μ_e + σ_e * random.normal(subkeys[2], (M,))\n",
467 | " # Exponentiate them\n",
468 | " a, b, e = jax.tree.map(jnp.exp, (a, b, e))\n",
469 | " # Update the cross-section according to (6.4)\n",
470 | " s = jnp.where(s < s_bar, e, a * s + b)\n",
471 | " new_state = s, key\n",
472 | " return new_state\n",
473 | "\n",
474 | " # Use fori_loop \n",
475 | " initial_state = s, key\n",
476 | " final_s, final_key = lax.fori_loop(\n",
477 | " 0, T, update_cross_section, initial_state\n",
478 | " )\n",
479 | " return final_s"
480 | ]
481 | },
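The `lax.fori_loop` version threads a tuple `(s, key)` through the loop. The following sketch uses a hypothetical noisy AR(1) process rather than the firm model. It checks that this pattern reproduces a plain Python loop that splits the key in the same order. In other words, switching to `fori_loop` changes how the loop is compiled, not what it computes.

```python
import jax
import jax.numpy as jnp
from jax import random, lax

N_STEPS = 100

@jax.jit
def noisy_ar1(seed):
    """x_{t+1} = 0.9 x_t + noise, looped with lax.fori_loop over an (x, key) state."""
    key = random.PRNGKey(seed)
    x = jnp.zeros(4)

    def body(t, state):
        x, key = state
        key, subkey = random.split(key)          # fresh subkey each period
        x = 0.9 * x + random.normal(subkey, x.shape)
        return x, key                            # same structure as the input state

    x_final, _ = lax.fori_loop(0, N_STEPS, body, (x, key))
    return x_final

def noisy_ar1_python(seed):
    """The same recursion with an ordinary Python loop and the same key splits."""
    key = random.PRNGKey(seed)
    x = jnp.zeros(4)
    for t in range(N_STEPS):
        key, subkey = random.split(key)
        x = 0.9 * x + random.normal(subkey, x.shape)
    return x

print(noisy_ar1(7))
print(noisy_ar1_python(7))
```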
482 | {
483 | "cell_type": "markdown",
484 | "id": "0039f701",
485 | "metadata": {},
486 | "source": [
487 | "Let’s see if we get any speed gain:"
488 | ]
489 | },
490 | {
491 | "cell_type": "code",
492 | "execution_count": null,
493 | "id": "056c0a8f",
494 | "metadata": {
495 | "hide-output": false
496 | },
497 | "outputs": [],
498 | "source": [
499 | "tic()\n",
500 | "data = generate_cross_section_lax(firm).block_until_ready()\n",
501 | "toc()"
502 | ]
503 | },
504 | {
505 | "cell_type": "code",
506 | "execution_count": null,
507 | "id": "87633363",
508 | "metadata": {
509 | "hide-output": false
510 | },
511 | "outputs": [],
512 | "source": [
513 | "tic()\n",
514 | "data = generate_cross_section_lax(firm).block_until_ready()\n",
515 | "toc()"
516 | ]
517 | },
518 | {
519 | "cell_type": "markdown",
520 | "id": "27a4953d",
521 | "metadata": {},
522 | "source": [
523 | "Here we produce the same rank-size plot:"
524 | ]
525 | },
526 | {
527 | "cell_type": "code",
528 | "execution_count": null,
529 | "id": "65d1bf97",
530 | "metadata": {
531 | "hide-output": false
532 | },
533 | "outputs": [],
534 | "source": [
535 | "fig, ax = plt.subplots()\n",
536 | "\n",
537 | "rank_data, size_data = qe.rank_size(data, c=0.01)\n",
538 | "ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5)\n",
539 | "ax.set_xlabel(\"log rank\")\n",
540 | "ax.set_ylabel(\"log size\")\n",
541 | "\n",
542 | "plt.show()"
543 | ]
544 | },
545 | {
546 | "cell_type": "markdown",
547 | "id": "53c80c22",
548 | "metadata": {},
549 | "source": [
550 | "## Exercises"
551 | ]
552 | },
553 | {
554 | "cell_type": "markdown",
555 | "id": "10efbecf",
556 | "metadata": {},
557 | "source": [
558 | "## Exercise 6.1\n",
559 | "\n",
560 | "Try writing an alternative version of `generate_cross_section_lax()` where the entire sequence of random draws is generated at once, so that all of `a`, `b`, and `e` are of shape `(T, M)`.\n",
561 | "\n",
562 | "(The `update_cross_section()` function should not generate any random numbers.)\n",
563 | "\n",
564 | "Does it improve the runtime?\n",
565 | "\n",
566 | "What are the pros and cons of this approach?"
567 | ]
568 | },
569 | {
570 | "cell_type": "markdown",
571 | "id": "7c65d570",
572 | "metadata": {},
573 | "source": [
574 | "## Solution"
575 | ]
576 | },
577 | {
578 | "cell_type": "code",
579 | "execution_count": null,
580 | "id": "2b22aca8",
581 | "metadata": {
582 | "hide-output": false
583 | },
584 | "outputs": [],
585 | "source": [
586 | "@jax.jit\n",
587 | "def generate_cross_section_lax(\n",
588 | " firm, T=500, M=500_000, s_init=1.0, seed=123\n",
589 | " ):\n",
590 | "\n",
591 | " μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm\n",
592 | " key = random.PRNGKey(seed)\n",
593 | " subkey_1, subkey_2, subkey_3 = random.split(key, 3)\n",
594 | " \n",
595 | " # Generate entire sequence of random draws \n",
596 | " a = μ_a + σ_a * random.normal(subkey_1, (T, M))\n",
597 | " b = μ_b + σ_b * random.normal(subkey_2, (T, M))\n",
598 | " e = μ_e + σ_e * random.normal(subkey_3, (T, M))\n",
599 | " # Exponentiate them\n",
600 | " a, b, e = jax.tree.map(jnp.exp, (a, b, e))\n",
601 | " # Initial cross section\n",
602 | " s = jnp.full((M, ), s_init)\n",
603 | "\n",
604 | " def update_cross_section(t, s):\n",
605 | " # Pull out the t-th cross-section of shocks\n",
606 | " a_t, b_t, e_t = a[t], b[t], e[t]\n",
607 | " s = jnp.where(s < s_bar, e_t, a_t * s + b_t)\n",
608 | " return s\n",
609 | "\n",
610 | " # Use lax.fori_loop to push the cross-section forward T periods\n",
611 | " s_final = lax.fori_loop(0, T, update_cross_section, s)\n",
612 | " return s_final"
613 | ]
614 | },
615 | {
616 | "cell_type": "markdown",
617 | "id": "454d1308",
618 | "metadata": {},
619 | "source": [
620 | "Here are the run times."
621 | ]
622 | },
623 | {
624 | "cell_type": "code",
625 | "execution_count": null,
626 | "id": "c213fd00",
627 | "metadata": {
628 | "hide-output": false
629 | },
630 | "outputs": [],
631 | "source": [
632 | "tic()\n",
633 | "data = generate_cross_section_lax(firm).block_until_ready()\n",
634 | "toc()"
635 | ]
636 | },
637 | {
638 | "cell_type": "code",
639 | "execution_count": null,
640 | "id": "66300fab",
641 | "metadata": {
642 | "hide-output": false
643 | },
644 | "outputs": [],
645 | "source": [
646 | "tic()\n",
647 | "data = generate_cross_section_lax(firm).block_until_ready()\n",
648 | "toc()"
649 | ]
650 | },
651 | {
652 | "cell_type": "markdown",
653 | "id": "7a870d0d",
654 | "metadata": {},
655 | "source": [
656 | "This method might or might not be faster.\n",
657 | "\n",
658 | "In general, the relative speed will depend on the size of the cross-section and the length of\n",
659 | "the simulation paths.\n",
660 | "\n",
661 | "However, this method is far more memory intensive, since all three shock arrays of shape `(T, M)` must be held in memory at once.\n",
662 | "\n",
663 | "It will run out of memory when $ T $ and $ M $ become sufficiently large."
664 | ]
665 | }
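To put a number on the memory cost, consider the default `T=500` and `M=500_000`. Each shock array then has $2.5 \times 10^8$ entries. JAX uses 32-bit floats unless 64-bit mode is enabled, so the three arrays alone need about 3 GB, before counting any intermediate arrays created along the way. By contrast, generating shocks one period at a time needs only three arrays of length `M`, or about 6 MB. A quick calculation, using a hypothetical helper:

```python
def shock_memory_gb(T=500, M=500_000, n_shocks=3, bytes_per_float=4):
    """Memory (in GB) needed to hold n_shocks arrays of shape (T, M)."""
    return n_shocks * T * M * bytes_per_float / 1e9

print(shock_memory_gb())                    # 3.0 with 32-bit floats
print(shock_memory_gb(bytes_per_float=8))   # 6.0 if 64-bit floats are enabled
print(shock_memory_gb(T=1))                 # 0.006, i.e. one period at a time
```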
666 | ],
667 | "metadata": {
668 | "date": 1765244755.4315395,
669 | "filename": "kesten_processes.md",
670 | "kernelspec": {
671 | "display_name": "Python",
672 | "language": "python3",
673 | "name": "python3"
674 | },
675 | "title": "Kesten Processes and Firm Dynamics"
676 | },
677 | "nbformat": 4,
678 | "nbformat_minor": 5
679 | }
--------------------------------------------------------------------------------
/inventory_ssd.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "315153b7",
6 | "metadata": {},
7 | "source": [
8 | "$$\n",
9 | "\\newcommand{\\argmax}{arg\\,max}\n",
10 | "\\newcommand{\\argmin}{arg\\,min}\n",
11 | "$$"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "d5577058",
17 | "metadata": {},
18 | "source": [
19 | "# Inventory Management Model"
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "id": "610a1a38",
25 | "metadata": {},
26 | "source": [
27 | "# GPU\n",
28 | "\n",
29 | "This lecture was built using a machine with access to a GPU.\n",
30 | "\n",
31 | "[Google Colab](https://colab.research.google.com/) has a free tier with GPUs\n",
32 | "that you can access as follows:\n",
33 | "\n",
34 | "1. Click on the “play” icon top right \n",
35 | "1. Select Colab \n",
36 | "1. Set the runtime environment to include a GPU \n",
37 | "\n",
38 | "\n",
39 | "This lecture provides a JAX implementation of a model in [Dynamic Programming](https://dp.quantecon.org/).\n",
40 | "\n",
41 | "In addition to JAX and Anaconda, this lecture will need the following libraries:"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "id": "855a67dd",
48 | "metadata": {
49 | "hide-output": false
50 | },
51 | "outputs": [],
52 | "source": [
53 | "!pip install --upgrade quantecon"
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "id": "89375451",
59 | "metadata": {},
60 | "source": [
61 | "## A model with constant discounting\n",
62 | "\n",
63 | "We study a firm where a manager tries to maximize shareholder value.\n",
64 | "\n",
65 | "To simplify the problem, we assume that the firm only sells one product.\n",
66 | "\n",
67 | "Letting $ \\pi_t $ be profits at time $ t $ and $ r > 0 $ be the interest rate, the value of the firm is\n",
68 | "\n",
69 | "$$\n",
70 | "V_0 = \\sum_{t \\geq 0} \\beta^t \\pi_t\n",
71 | " \\qquad\n",
72 | " \\text{ where }\n",
73 | " \\quad \\beta := \\frac{1}{1+r}.\n",
74 | "$$\n",
75 | "\n",
76 | "Suppose the firm faces an exogenous demand process $ (D_t)_{t \\geq 0} $.\n",
77 | "\n",
78 | "We assume $ (D_t)_{t \\geq 0} $ is IID with common distribution $ \\phi $ on $ \\mathbb{Z}_+ $.\n",
79 | "\n",
80 | "Inventory $ (X_t)_{t \\geq 0} $ of the product obeys\n",
81 | "\n",
82 | "$$\n",
83 | "X_{t+1} = f(X_t, A_t, D_{t+1})\n",
84 | " \\qquad\n",
85 | " \\text{where}\n",
86 | " \\quad\n",
87 | " f(x,a,d) := (x - d)\\vee 0 + a.\n",
88 | "$$\n",
89 | "\n",
90 | "The term $ A_t $ is the number of units ordered this period; these units take one period to\n",
91 | "arrive.\n",
92 | "\n",
93 | "We assume that the firm can store at most $ K $ items at one time.\n",
94 | "\n",
95 | "Profits are given by\n",
96 | "\n",
97 | "$$\n",
98 | "\\pi_t := X_t \\wedge D_{t+1} - c A_t - \\kappa 1\\{A_t > 0\\}.\n",
99 | "$$\n",
100 | "\n",
101 | "We take the minimum of current stock and demand because orders in excess of\n",
102 | "inventory are assumed to be lost rather than back-filled.\n",
103 | "\n",
104 | "Here $ c $ is unit product cost and $ \\kappa $ is a fixed cost of ordering inventory.\n",
105 | "\n",
106 | "We can map our inventory problem into a dynamic program with state space\n",
107 | "$ X := \\{0, \\ldots, K\\} $ and action space $ A := X $.\n",
108 | "\n",
109 | "The feasible correspondence $ \\Gamma $ is\n",
110 | "\n",
111 | "$$\n",
112 | "\\Gamma(x) := \\{0, \\ldots, K - x\\},\n",
113 | "$$\n",
114 | "\n",
115 | "which represents the set of feasible orders when the current inventory\n",
116 | "state is $ x $.\n",
117 | "\n",
118 | "The reward function is expected current profits, or\n",
119 | "\n",
120 | "$$\n",
121 | "r(x, a) := \\sum_{d \\geq 0} (x \\wedge d) \\phi(d)\n",
122 | " - c a - \\kappa 1\\{a > 0\\}.\n",
123 | "$$\n",
124 | "\n",
125 | "The stochastic kernel (i.e., state-transition probabilities) from the set of feasible state-action pairs is\n",
126 | "\n",
127 | "$$\n",
128 | "P(x, a, x') := P\\{ f(x, a, D) = x' \\}\n",
129 | " \\qquad \\text{when} \\quad\n",
130 | " D \\sim \\phi.\n",
131 | "$$\n",
132 | "\n",
133 | "When discounting is constant, the Bellman equation takes the form\n",
134 | "\n",
135 | "\n",
136 | "\n",
137 | "$$\n",
138 | "v(x)\n",
139 | " = \\max_{a \\in \\Gamma(x)} \\left\\{\n",
140 | " r(x, a)\n",
141 | " + \\beta\n",
142 | " \\sum_{d \\geq 0} v(f(x, a, d)) \\phi(d)\n",
143 | " \\right\\} \\tag{15.1}\n",
144 | "$$"
145 | ]
146 | },
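The code later in this lecture uses the demand distribution $\phi(d) = (1-p)^d p$ on $d = 0, 1, \ldots$. With this distribution, the expected-sales term in $r(x, a)$ has a closed form. Since $x \wedge D = \sum_{k=1}^{x} \mathbb{1}\{D \geq k\}$ and $\mathbb{P}\{D \geq k\} = (1-p)^k$, we have

$$
\sum_{d \geq 0} (x \wedge d) \phi(d) = \sum_{k=1}^{x} (1-p)^k = \frac{(1-p)\left(1 - (1-p)^x\right)}{p}.
$$

The sketch below compares this formula with the truncated sum used in the code. The values $p = 0.6$, $c = 0.2$, $\kappa = 0.8$ and the truncation point `D_MAX = 101` mirror the model defaults defined later. The names `expected_sales` and `reward` are hypothetical helpers.

```python
import numpy as np

def expected_sales(x, p=0.6, D_MAX=101):
    """Truncated sum Σ_{d < D_MAX} min(x, d) φ(d) with φ(d) = (1 - p)^d p."""
    d = np.arange(D_MAX)
    ϕ = (1 - p)**d * p
    return np.sum(np.minimum(x, d) * ϕ)

def expected_sales_closed_form(x, p=0.6):
    # Σ_{k=1}^{x} P(D ≥ k), where P(D ≥ k) = (1 - p)^k
    q = 1 - p
    return q * (1 - q**x) / p

def reward(x, a, p=0.6, c=0.2, κ=0.8):
    """r(x, a): expected sales minus the unit cost and the fixed cost of ordering."""
    return expected_sales(x, p) - c * a - κ * (a > 0)

for x in (0, 1, 5, 50):
    print(x, expected_sales(x), expected_sales_closed_form(x))
```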
147 | {
148 | "cell_type": "markdown",
149 | "id": "d4490853",
150 | "metadata": {},
151 | "source": [
152 | "## Time-varying discount rates\n",
153 | "\n",
154 | "We wish to consider a more sophisticated model with time-varying discounting.\n",
155 | "\n",
156 | "This time variation accommodates non-constant interest rates.\n",
157 | "\n",
158 | "To this end, we replace the constant $ \\beta $ in\n",
159 | "[(15.1)](#equation-inventory-ssd-v1) with a stochastic process $ (\\beta_t) $ where\n",
160 | "\n",
161 | "- $ \\beta_t = 1/(1+r_t) $ and \n",
162 | "- $ r_t $ is the interest rate at time $ t $ \n",
163 | "\n",
164 | "\n",
165 | "We suppose that $ \\beta_t = \\beta(Z_t) $, where the exogenous process $ (Z_t)_{t \\geq 0} $ is a Markov chain\n",
166 | "on a finite set $ Z $ with Markov matrix $ Q $.\n",
167 | "\n",
168 | "After relabeling inventory $ X_t $ as $ Y_t $ and $ x $ as $ y $, the Bellman equation becomes\n",
169 | "\n",
170 | "$$\n",
171 | "v(y, z) = \\max_{a \\in \\Gamma(y)} B((y, z), a, v)\n",
172 | "$$\n",
173 | "\n",
174 | "where\n",
175 | "\n",
176 | "\n",
177 | "\n",
178 | "$$\n",
179 | "B((y, z), a, v)\n",
180 | " =\n",
181 | " r(y, a)\n",
182 | " + \\beta(z)\n",
183 | " \\sum_{d, \\, z'} v(f(y, a, d), z') \\phi(d) Q(z, z'). \\tag{15.2}\n",
184 | "$$\n",
185 | "\n",
186 | "We set $ \\beta(z) := z $ and\n",
187 | "\n",
188 | "$$\n",
189 | "R(y, a, y')\n",
190 | " := P\\{f(y, a, D) = y'\\} \\quad \\text{when} \\quad D \\sim \\phi.\n",
191 | "$$\n",
192 | "\n",
193 | "Now $ R(y, a, y') $ is the probability of realizing next period inventory level\n",
194 | "$ y' $ when the current level is $ y $ and the action is $ a $.\n",
195 | "\n",
196 | "Hence we can rewrite [(15.2)](#equation-inventory-ssd-b1) as\n",
197 | "\n",
198 | "$$\n",
199 | "B((y, z), a, v)\n",
200 | " = r(y, a)\n",
201 | " + \\beta(z)\n",
202 | " \\sum_{y', z'} v(y', z') Q(z, z') R(y, a, y') .\n",
203 | "$$\n",
204 | "\n",
205 | "Let’s begin with the following imports:"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": null,
211 | "id": "f6aaa001",
212 | "metadata": {
213 | "hide-output": false
214 | },
215 | "outputs": [],
216 | "source": [
217 | "import quantecon as qe\n",
218 | "import jax\n",
219 | "import jax.numpy as jnp\n",
220 | "import numpy as np\n",
221 | "import matplotlib.pyplot as plt\n",
222 | "from time import time\n",
223 | "from functools import partial\n",
224 | "from typing import NamedTuple"
225 | ]
226 | },
227 | {
228 | "cell_type": "markdown",
229 | "id": "c90a02b2",
230 | "metadata": {},
231 | "source": [
232 | "Let’s check which GPU we are running on:"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "id": "be0c6d79",
239 | "metadata": {
240 | "hide-output": false
241 | },
242 | "outputs": [],
243 | "source": [
244 | "!nvidia-smi"
245 | ]
246 | },
247 | {
248 | "cell_type": "markdown",
249 | "id": "47f256f3",
250 | "metadata": {},
251 | "source": [
252 | "We will use 64-bit floats with JAX to increase precision."
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": null,
258 | "id": "0e789cfd",
259 | "metadata": {
260 | "hide-output": false
261 | },
262 | "outputs": [],
263 | "source": [
264 | "jax.config.update(\"jax_enable_x64\", True)"
265 | ]
266 | },
267 | {
268 | "cell_type": "markdown",
269 | "id": "5c9ad3f9",
270 | "metadata": {},
271 | "source": [
272 | "Let’s define a model to represent the inventory management problem."
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": null,
278 | "id": "341feeac",
279 | "metadata": {
280 | "hide-output": false
281 | },
282 | "outputs": [],
283 | "source": [
284 | "# NamedTuple Model\n",
285 | "class Model(NamedTuple):\n",
286 | " z_values: jnp.ndarray # Exogenous shock values\n",
287 | " Q: jnp.ndarray # Exogenous shock probabilities\n",
288 | " x_values: jnp.ndarray # Inventory values\n",
289 | " d_values: jnp.ndarray # Demand values for summation\n",
290 | " ϕ_values: jnp.ndarray # Demand probabilities\n",
291 | " p: float # Demand parameter\n",
292 | " c: float = 0.2 # Unit cost\n",
293 | " κ: float = 0.8 # Fixed cost"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": null,
299 | "id": "a377819d",
300 | "metadata": {
301 | "hide-output": false
302 | },
303 | "outputs": [],
304 | "source": [
305 | "def create_sdd_inventory_model(\n",
306 | " ρ: float = 0.98, # Exogenous state autocorrelation parameter\n",
307 | " ν: float = 0.002, # Exogenous state volatility parameter\n",
308 | " n_z: int = 10, # Exogenous state discretization size\n",
309 | " b: float = 0.97, # Exogenous state offset\n",
310 | " K: int = 100, # Max inventory\n",
311 | " D_MAX: int = 101, # Demand upper bound for summation\n",
312 | " p: float = 0.6 # Demand parameter\n",
313 | " ) -> Model:\n",
314 | " \n",
315 | " # Demand\n",
316 | " def demand_pdf(p, d):\n",
317 | " return (1 - p)**d * p\n",
318 | " \n",
319 | " d_values = jnp.arange(D_MAX)\n",
320 | " ϕ_values = demand_pdf(p, d_values)\n",
321 | " \n",
322 | " # Exogenous state process\n",
323 | " mc = qe.tauchen(n_z, ρ, ν)\n",
324 | " z_values, Q = map(jnp.array, (mc.state_values + b, mc.P))\n",
325 | " \n",
326 | " # Endogenous state\n",
327 | " x_values = jnp.arange(K + 1) # 0, 1, ..., K\n",
328 | " \n",
329 | " return Model(\n",
330 | " z_values=z_values, Q=Q, \n",
331 | " x_values=x_values, d_values=d_values, ϕ_values=ϕ_values,\n",
332 | " p=p\n",
333 | " )"
334 | ]
335 | },
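The rewritten Bellman equation treats $R(y, a, y')$ as a three-dimensional array. The NumPy sketch below works on a deliberately small problem; the sizes `K = 5` and `D_MAX = 60`, the 2-state matrix `Q` and the random `v` are all assumptions for illustration. It builds $R$, checks that each feasible row is a probability distribution, and confirms that the kernel form of the continuation value matches the original sum over $d$ and $z'$.

```python
import numpy as np

K, D_MAX, p = 5, 60, 0.6              # small illustrative sizes (assumed)
d = np.arange(D_MAX)
ϕ = (1 - p)**d * p                    # truncated geometric demand pmf
y_vals = np.arange(K + 1)

# R[y, a, y'] = P{f(y, a, D) = y'} with f(y, a, d) = max(y - d, 0) + a
R = np.zeros((K + 1, K + 1, K + 1))
for y in y_vals:
    for a in range(K + 1 - y):        # feasible orders only: a <= K - y
        for dd, prob in zip(d, ϕ):
            R[y, a, max(y - dd, 0) + a] += prob

# Each feasible (y, a) row is a distribution over y' (up to a negligible truncation error)
feasible = y_vals[:, None] + y_vals[None, :] <= K
assert np.allclose(R.sum(axis=2)[feasible], 1.0)

rng = np.random.default_rng(0)
n_z = 2
v = rng.normal(size=(K + 1, n_z))     # arbitrary candidate value function v(y', z')
Q = np.array([[0.9, 0.1], [0.2, 0.8]])

# Kernel form: cont[y, a, z] = Σ_{y', z'} v(y', z') Q(z, z') R(y, a, y')
cont_kernel = np.einsum('ijk,kl,ml->ijm', R, v, Q)

# Original form: Σ_{d, z'} v(f(y, a, d), z') φ(d) Q(z, z')
cont_direct = np.zeros_like(cont_kernel)
for y in y_vals:
    for a in range(K + 1 - y):
        for z in range(n_z):
            cont_direct[y, a, z] = sum(
                v[max(y - dd, 0) + a, zp] * ϕ[dd] * Q[z, zp]
                for dd in d for zp in range(n_z))

print(np.allclose(cont_kernel, cont_direct))   # True
```

The lecture's `B` never builds $R$ explicitly. It indexes `v` at $f(x, a, d)$ directly, which is closer to the second form and avoids storing a $(K+1)^3$ array.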
336 | {
337 | "cell_type": "markdown",
338 | "id": "3c8f3ba6",
339 | "metadata": {},
340 | "source": [
341 | "Here’s the function `B` on the right-hand side of the Bellman equation."
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": null,
347 | "id": "790e79f2",
348 | "metadata": {
349 | "hide-output": false
350 | },
351 | "outputs": [],
352 | "source": [
353 | "@jax.jit\n",
354 | "def B(x, z_idx, v, model):\n",
355 | " \"\"\"\n",
356 | " Take z_idx and convert it to z. Then compute\n",
357 | "\n",
358 | " B(x, z, a, v) = r(x, a) + β(z) Σ_{x′, z′} v(x′, z′) Q(z, z′) R(x, a, x′)\n",
359 | "\n",
360 | " for all possible choices of a, assigning -inf to infeasible ones.\n",
361 | " \"\"\"\n",
362 | " \n",
363 | " z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model\n",
364 | " z = z_values[z_idx]\n",
365 | "\n",
366 | " def _B(a):\n",
367 | " \"\"\"\n",
368 | " Returns r(x, a) + β(z) Σ_{x′, z′} v(x′, z′) Q(z, z′) R(x, a, x′) for a single a.\n",
369 | " \"\"\"\n",
370 | " revenue = jnp.sum(jnp.minimum(x, d_values) * ϕ_values)\n",
371 | " profit = revenue - c * a - κ * (a > 0)\n",
372 | " v_R = jnp.sum(v[jnp.maximum(x - d_values, 0) + a].T * ϕ_values, axis=1)\n",
373 | " cv = jnp.sum(v_R * Q[z_idx])\n",
374 | " return profit + z * cv\n",
375 | "\n",
376 | " a_values = x_values # Set of possible order sizes\n",
377 | " B_values = jax.vmap(_B)(a_values)\n",
378 | " max_x = len(x_values) - 1\n",
379 | " \n",
380 | " return jnp.where(a_values <= max_x - x, B_values, -jnp.inf)"
381 | ]
382 | },
383 | {
384 | "cell_type": "markdown",
385 | "id": "fe96c993",
386 | "metadata": {},
387 | "source": [
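388 | "We need to vectorize `B` so that we can use it efficiently in JAX.\n",
389 | "\n",
390 | "We apply two `vmap` operations: the first maps over `z_idx` and the second over `x`, so that `B(x_values, z_indices, v, model)`\n",
391 | "returns an array whose axes index `x`, `z` and the action `a`, in that order."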
392 | ]
393 | },
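To see how two stacked `vmap` calls determine the output layout, here is a toy stand-in for `B`; the function `f` and the sizes are hypothetical. The inner `vmap` maps over the second argument and the outer one over the first. The function itself returns one value per action, so the result is indexed by `(x, z, a)`. Taking `jnp.max(..., axis=2)` then maximizes over actions, as `T` does.

```python
import jax
import jax.numpy as jnp

n_x, n_z, n_a = 3, 2, 4

def f(x, z_idx):
    """Toy stand-in for B: returns one value per action a = 0, ..., n_a - 1."""
    a = jnp.arange(n_a)
    return 100 * x + 10 * z_idx + a

f_vec = jax.vmap(f, in_axes=(None, 0))       # map over z_idx, holding x fixed
f_vec = jax.vmap(f_vec, in_axes=(0, None))   # then map over x

out = f_vec(jnp.arange(n_x), jnp.arange(n_z))
print(out.shape)    # (3, 2, 4): axes are (x, z, a)
print(out[2, 1])    # values 210, 211, 212, 213
```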
394 | {
395 | "cell_type": "code",
396 | "execution_count": null,
397 | "id": "6b44b26a",
398 | "metadata": {
399 | "hide-output": false
400 | },
401 | "outputs": [],
402 | "source": [
403 | "B = jax.vmap(B, in_axes=(None, 0, None, None))\n",
404 | "B = jax.vmap(B, in_axes=(0, None, None, None))"
405 | ]
406 | },
407 | {
408 | "cell_type": "markdown",
409 | "id": "51b2a6b9",
410 | "metadata": {},
411 | "source": [
412 | "Next we define the Bellman operator."
413 | ]
414 | },
415 | {
416 | "cell_type": "code",
417 | "execution_count": null,
418 | "id": "63628588",
419 | "metadata": {
420 | "hide-output": false
421 | },
422 | "outputs": [],
423 | "source": [
424 | "@jax.jit\n",
425 | "def T(v, model):\n",
426 | " \"\"\"The Bellman operator.\"\"\"\n",
427 | " z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model\n",
428 | " z_indices = jnp.arange(len(z_values))\n",
429 | " res = B(x_values, z_indices, v, model)\n",
430 | " return jnp.max(res, axis=2)"
431 | ]
432 | },
433 | {
434 | "cell_type": "markdown",
435 | "id": "2940a43b",
436 | "metadata": {},
437 | "source": [
438 | "The following function computes a v-greedy policy."
439 | ]
440 | },
441 | {
442 | "cell_type": "code",
443 | "execution_count": null,
444 | "id": "237f295c",
445 | "metadata": {
446 | "hide-output": false
447 | },
448 | "outputs": [],
449 | "source": [
450 | "@jax.jit\n",
451 | "def get_greedy(v, model):\n",
452 | " \"\"\"Get a v-greedy policy. Returns a zero-based array.\"\"\"\n",
453 | " z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model\n",
454 | " z_indices = jnp.arange(len(z_values))\n",
455 | " res = B(x_values, z_indices, v, model)\n",
456 | " return jnp.argmax(res, axis=2)"
457 | ]
458 | },
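Because `B` assigns $-\infty$ to infeasible orders, `jnp.max` and `jnp.argmax` over the action axis automatically ignore them. Here is a small check with made-up values, assuming `K = 4` and current inventory `x = 3`, so that only $a \in \{0, 1\}$ is feasible.

```python
import jax.numpy as jnp

K = 4
x = 3                                          # current inventory
a_values = jnp.arange(K + 1)                   # candidate orders 0, ..., K
values = jnp.array([1.0, 2.0, 5.0, 9.0, 7.0])  # hypothetical B values for each a
masked = jnp.where(a_values <= K - x, values, -jnp.inf)
print(jnp.max(masked), jnp.argmax(masked))     # max 2.0, attained at a = 1
```

Since `a_values` runs over $0, \ldots, K$, the index returned by `argmax` is the order size itself. This is why `get_greedy` can return indices directly as the policy.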
459 | {
460 | "cell_type": "markdown",
461 | "id": "67d75d23",
462 | "metadata": {},
463 | "source": [
464 | "Here’s code to solve the model using value function iteration."
465 | ]
466 | },
467 | {
468 | "cell_type": "code",
469 | "execution_count": null,
470 | "id": "886c4c72",
471 | "metadata": {
472 | "hide-output": false
473 | },
474 | "outputs": [],
475 | "source": [
476 | "@jax.jit\n",
477 | "def solve_inventory_model(v_init, model, max_iter=10_000, tol=1e-6):\n",
478 | " \"\"\"Compute v_star by value function iteration, then compute a v_star-greedy policy.\"\"\"\n",
479 | "\n",
480 | " def update(state):\n",
481 | " error, i, v = state\n",
482 | " new_v = T(v, model)\n",
483 | " new_error = jnp.max(jnp.abs(new_v - v))\n",
484 | " new_i = i + 1\n",
485 | " return new_error, new_i, new_v\n",
486 | "\n",
487 | " def test(state):\n",
488 | " error, i, v = state\n",
489 | " return (i < max_iter) & (error > tol)\n",
490 | "\n",
491 | " i, error = 0, tol + 1\n",
492 | " initial_state = error, i, v_init\n",
493 | " final_state = jax.lax.while_loop(test, update, initial_state)\n",
494 | " error, i, v_star = final_state\n",
495 | " σ_star = get_greedy(v_star, model)\n",
496 | " return v_star, σ_star"
497 | ]
498 | },
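`solve_inventory_model` applies `T` inside `jax.lax.while_loop` until successive iterates are within `tol` in the sup norm. Here is the same pattern applied to a toy scalar contraction, $v \mapsto 0.5 v + 1$, whose fixed point is $v = 2$. The map is chosen purely for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def successive_approx(v_init, tol=1e-6, max_iter=1_000):
    """Iterate v ↦ 0.5 v + 1 until the sup-norm change falls below tol."""
    def update(state):
        error, i, v = state
        new_v = 0.5 * v + 1.0             # a contraction with fixed point v = 2
        return jnp.max(jnp.abs(new_v - v)), i + 1, new_v

    def test(state):
        error, i, _ = state
        return (i < max_iter) & (error > tol)

    init = (jnp.asarray(tol + 1.0, dtype=v_init.dtype), 0, v_init)
    error, i, v = jax.lax.while_loop(test, update, init)
    return v, i

v, n_iter = successive_approx(jnp.zeros(3))
print(v, n_iter)
```

`lax.while_loop` requires the loop state to keep the same shapes and dtypes across iterations. Creating the initial error with the dtype of `v_init` is one simple way to keep them consistent.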
499 | {
500 | "cell_type": "markdown",
501 | "id": "726c6612",
502 | "metadata": {},
503 | "source": [
504 | "Now let’s create an instance and solve it."
505 | ]
506 | },
507 | {
508 | "cell_type": "code",
509 | "execution_count": null,
510 | "id": "88026937",
511 | "metadata": {
512 | "hide-output": false
513 | },
514 | "outputs": [],
515 | "source": [
516 | "model = create_sdd_inventory_model()\n",
517 | "z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model\n",
518 | "n_z = len(z_values)\n",
519 | "n_x = len(x_values)\n",
520 | "v_init = jnp.zeros((n_x, n_z), dtype=float)"
521 | ]
522 | },
523 | {
524 | "cell_type": "code",
525 | "execution_count": null,
526 | "id": "b9d9236b",
527 | "metadata": {
528 | "hide-output": false
529 | },
530 | "outputs": [],
531 | "source": [
532 | "start = time()\n",
533 | "v_star, σ_star = solve_inventory_model(v_init, model)\n",
534 | "\n",
535 | "# Pause until execution finishes\n",
536 | "jax.tree_util.tree_map(lambda x: x.block_until_ready(), (v_star, σ_star))\n",
537 | "\n",
538 | "jax_time_with_compile = time() - start\n",
539 | "print(f\"compile plus execution time = {jax_time_with_compile * 1000:.6f} ms\")"
540 | ]
541 | },
542 | {
543 | "cell_type": "markdown",
544 | "id": "eaf220cd",
545 | "metadata": {},
546 | "source": [
547 | "Let’s run it again to measure the execution time without compilation."
548 | ]
549 | },
550 | {
551 | "cell_type": "code",
552 | "execution_count": null,
553 | "id": "f66ffaf3",
554 | "metadata": {
555 | "hide-output": false
556 | },
557 | "outputs": [],
558 | "source": [
559 | "start = time()\n",
560 | "v_star, σ_star = solve_inventory_model(v_init, model)\n",
561 | "\n",
562 | "# Pause until execution finishes\n",
563 | "jax.tree_util.tree_map(lambda x: x.block_until_ready(), (v_star, σ_star))\n",
564 | "\n",
565 | "jax_time_without_compile = time() - start\n",
566 | "print(f\"execution time = {jax_time_without_compile * 1000:.6f} ms\")"
567 | ]
568 | },
569 | {
570 | "cell_type": "markdown",
571 | "id": "48baebd3",
572 | "metadata": {},
573 | "source": [
574 | "Now let’s do a simulation.\n",
575 | "\n",
576 | "We’ll begin by converting `Q` and `z_values` back to NumPy arrays for convenience and building a `MarkovChain` for the exogenous state $ Z $."
577 | ]
578 | },
579 | {
580 | "cell_type": "code",
581 | "execution_count": null,
582 | "id": "2a4a59a0",
583 | "metadata": {
584 | "hide-output": false
585 | },
586 | "outputs": [],
587 | "source": [
588 | "Q = np.array(Q)\n",
589 | "z_values = np.array(z_values)\n",
590 | "z_mc = qe.MarkovChain(Q, z_values)"
591 | ]
592 | },
593 | {
594 | "cell_type": "markdown",
595 | "id": "a386caee",
596 | "metadata": {},
597 | "source": [
598 | "Here’s code to simulate inventories under the optimal policy `σ_star`."
599 | ]
600 | },
601 | {
602 | "cell_type": "code",
603 | "execution_count": null,
604 | "id": "fd7e64ac",
605 | "metadata": {
606 | "hide-output": false
607 | },
608 | "outputs": [],
609 | "source": [
610 | "def sim_inventories(ts_length, X_init=0):\n",
611 | "    \"\"\"Simulate inventories under the optimal policy σ_star.\"\"\"\n",
612 | "    σ = np.asarray(σ_star)  # NumPy copy, cheaper to index inside a Python loop\n",
613 | "    z_idx = z_mc.simulate_indices(ts_length, init=1)\n",
614 | "    X = np.zeros(ts_length, dtype=np.int32)\n",
615 | "    X[0] = X_init\n",
616 | "    # Demand draws on {0, 1, 2, ...} (NumPy's geometric draws start at 1)\n",
617 | "    rand = np.random.default_rng().geometric(p=p, size=ts_length-1) - 1\n",
618 | "\n",
619 | "    for t in range(ts_length-1):\n",
620 | "        X[t+1] = np.maximum(X[t] - rand[t], 0) + σ[X[t], z_idx[t]]\n",
621 | "\n",
622 | "    return X, z_values[z_idx]"
623 | ]
624 | },
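| {
| "cell_type": "markdown",
| "id": "demand-draws-check-md",
| "metadata": {},
| "source": [
| "NumPy’s `geometric` sampler counts the number of trials up to and including the first success, so its draws start at 1. Subtracting 1, as `sim_inventories` does, gives draws $ D $ on $ \\{0, 1, 2, \\ldots\\} $ with $ \\mathbb P\\{D = d\\} = (1-p)^d p $ and mean $ (1-p)/p $.\n",
| "\n",
| "As a rough Monte Carlo sanity check (the seed and sample size below are arbitrary choices, not part of the lecture), the sample mean of such draws should be close to $ (1-p)/p $:"
| ]
| },
| {
| "cell_type": "code",
| "execution_count": null,
| "id": "demand-draws-check",
| "metadata": {
| "hide-output": false
| },
| "outputs": [],
| "source": [
| "p_float = float(p)\n",
| "rng = np.random.default_rng(1234)  # arbitrary seed for reproducibility\n",
| "demand_draws = rng.geometric(p=p_float, size=100_000) - 1\n",
| "\n",
| "# Compare the sample mean with the theoretical mean (1 - p) / p\n",
| "print(f'sample mean = {demand_draws.mean():.4f}, theoretical mean = {(1 - p_float) / p_float:.4f}')"
| ]
| },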
625 | {
626 | "cell_type": "markdown",
627 | "id": "baeb9cf8",
628 | "metadata": {},
629 | "source": [
630 | "Here’s code to generate a plot."
631 | ]
632 | },
633 | {
634 | "cell_type": "code",
635 | "execution_count": null,
636 | "id": "457f28ee",
637 | "metadata": {
638 | "hide-output": false
639 | },
640 | "outputs": [],
641 | "source": [
642 | "def plot_ts(ts_length=400, fontsize=10):\n",
643 | " X, Z = sim_inventories(ts_length)\n",
644 | " fig, axes = plt.subplots(2, 1, figsize=(9, 5.5))\n",
645 | "\n",
646 | " ax = axes[0]\n",
647 | " ax.plot(X, label=r\"$X_t$\", alpha=0.7)\n",
648 | " ax.set_xlabel(r\"$t$\", fontsize=fontsize)\n",
649 | " ax.set_ylabel(\"inventory\", fontsize=fontsize)\n",
650 | " ax.legend(fontsize=fontsize, frameon=False)\n",
651 | " ax.set_ylim(0, np.max(X)+3)\n",
652 | "\n",
653 | " # calculate interest rate from discount factors\n",
654 | " r = (1 / Z) - 1\n",
655 | "\n",
656 | " ax = axes[1]\n",
657 | " ax.plot(r, label=r\"$r_t$\", alpha=0.7)\n",
658 | " ax.set_xlabel(r\"$t$\", fontsize=fontsize)\n",
659 | " ax.set_ylabel(\"interest rate\", fontsize=fontsize)\n",
660 | " ax.legend(fontsize=fontsize, frameon=False)\n",
661 | "\n",
662 | " plt.tight_layout()\n",
663 | " plt.show()"
664 | ]
665 | },
666 | {
667 | "cell_type": "markdown",
668 | "id": "5e6d097d",
669 | "metadata": {},
670 | "source": [
671 | "Let’s take a look."
672 | ]
673 | },
674 | {
675 | "cell_type": "code",
676 | "execution_count": null,
677 | "id": "d39b8cfd",
678 | "metadata": {
679 | "hide-output": false
680 | },
681 | "outputs": [],
682 | "source": [
683 | "plot_ts()"
684 | ]
685 | }
686 | ],
687 | "metadata": {
688 | "date": 1765244755.2634375,
689 | "filename": "inventory_ssd.md",
690 | "kernelspec": {
691 | "display_name": "Python",
692 | "language": "python3",
693 | "name": "python3"
694 | },
695 | "title": "Inventory Management Model"
696 | },
697 | "nbformat": 4,
698 | "nbformat_minor": 5
699 | }
--------------------------------------------------------------------------------
/short_path.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "79d5381e",
6 | "metadata": {},
7 | "source": [
8 | "$$\n",
9 | "\\newcommand{\\argmax}{arg\\,max}\n",
10 | "\\newcommand{\\argmin}{arg\\,min}\n",
11 | "$$"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "4b58f30e",
17 | "metadata": {},
18 | "source": [
19 | "# Shortest Paths"
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "id": "2f5cffec",
25 | "metadata": {},
26 | "source": [
27 | "# GPU\n",
28 | "\n",
29 | "This lecture was built using a machine with access to a GPU.\n",
30 | "\n",
31 | "[Google Colab](https://colab.research.google.com/) has a free tier with GPUs\n",
32 | "that you can access as follows:\n",
33 | "\n",
34 | "1. Click on the “play” icon top right \n",
35 | "1. Select Colab \n",
36 | "1. Set the runtime environment to include a GPU "
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "id": "b938687a",
42 | "metadata": {},
43 | "source": [
44 | "## Overview\n",
45 | "\n",
46 | "This lecture extends the [shortest path lecture](https://python.quantecon.org/short_path.html) using JAX. Please see that lecture for all background and notation.\n",
47 | "\n",
48 | "Let’s start by importing the libraries."
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "id": "111fbcb3",
55 | "metadata": {
56 | "hide-output": false
57 | },
58 | "outputs": [],
59 | "source": [
60 | "import numpy as np\n",
61 | "import jax.numpy as jnp\n",
62 | "import jax"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "id": "3d963fbe",
68 | "metadata": {},
69 | "source": [
70 | "Let’s check which GPU we are running on."
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "id": "187d3db2",
77 | "metadata": {
78 | "hide-output": false
79 | },
80 | "outputs": [],
81 | "source": [
82 | "!nvidia-smi"
83 | ]
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "id": "e339b99c",
88 | "metadata": {},
89 | "source": [
90 | "## Solving for Minimum Cost-to-Go\n",
91 | "\n",
92 | "Let $ J(v) $ denote the minimum cost-to-go from node $ v $,\n",
93 | "understood as the total cost from $ v $ if we take the best route.\n",
94 | "\n",
95 | "Let’s look at an algorithm for computing $ J $ and then think about how to\n",
96 | "implement it."
97 | ]
98 | },
99 | {
100 | "cell_type": "markdown",
101 | "id": "d4020f4f",
102 | "metadata": {},
103 | "source": [
104 | "### The Algorithm\n",
105 | "\n",
106 | "The standard algorithm for finding $ J $ is to start with an initial guess and then iterate.\n",
107 | "\n",
108 | "This is a standard approach to solving nonlinear equations, often called\n",
109 | "the method of **successive approximations**.\n",
110 | "\n",
111 | "Our initial guess will be\n",
112 | "\n",
113 | "\n",
114 | "\n",
115 | "$$\n",
116 | "J_0(v) = 0 \\text{ for all } v \\tag{13.1}\n",
117 | "$$\n",
118 | "\n",
119 | "Now\n",
120 | "\n",
121 | "1. Set $ n = 0 $ \n",
122 | "1. Set $ J_{n+1} (v) = \\min_{w \\in F_v} \\{ c(v, w) + J_n(w) \\} $ for all $ v $ \n",
123 | "1. If $ J_{n+1} $ and $ J_n $ are not equal, increment $ n $ and go to step 2 \n",
124 | "\n",
125 | "\n",
126 | "This sequence converges to $ J $.\n",
127 | "\n",
128 | "Let’s start by defining the **distance matrix** $ Q $."
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "id": "14576fec",
135 | "metadata": {
136 | "hide-output": false
137 | },
138 | "outputs": [],
139 | "source": [
140 | "inf = jnp.inf\n",
141 | "Q = jnp.array([[inf, 1, 5, 3, inf, inf, inf],\n",
142 | " [inf, inf, inf, 9, 6, inf, inf],\n",
143 | " [inf, inf, inf, inf, inf, 2, inf],\n",
144 | " [inf, inf, inf, inf, inf, 4, 8],\n",
145 | " [inf, inf, inf, inf, inf, inf, 4],\n",
146 | " [inf, inf, inf, inf, inf, inf, 1],\n",
147 | " [inf, inf, inf, inf, inf, inf, 0]])"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "id": "9f8be79d",
153 | "metadata": {},
154 | "source": [
155 | "Notice that the cost of staying still (on the principal diagonal) is set to\n",
156 | "\n",
157 | "- `jnp.inf` for non-destination nodes — moving on is required. \n",
158 | "- `0` for the destination node — here is where we stop. \n",
159 | "\n",
160 | "\n",
161 | "Let’s try this example using a Python `while` loop and some vectorized JAX code:"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "id": "cf0f454b",
168 | "metadata": {
169 | "hide-output": false
170 | },
171 | "outputs": [],
172 | "source": [
173 | "%%time\n",
174 | "\n",
175 | "num_nodes = Q.shape[0]\n",
176 | "J = jnp.zeros(num_nodes)\n",
177 | "\n",
178 | "max_iter = 500\n",
179 | "i = 0\n",
180 | "\n",
181 | "while i < max_iter:\n",
182 | " next_J = jnp.min(Q + J, axis=1)\n",
183 | " if jnp.allclose(next_J, J):\n",
184 | " break\n",
185 | " else:\n",
186 | " J = next_J.copy()\n",
187 | " i += 1\n",
188 | "\n",
189 | "print(\"The cost-to-go function is\", J)"
190 | ]
191 | },
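| {
| "cell_type": "markdown",
| "id": "worked-iterates",
| "metadata": {},
| "source": [
| "As a check on the output above, here are the first few iterates for this example written out by hand. Starting from $ J_0 = 0 $, each step takes the row-wise minimum of $ Q + J_n $:\n",
| "\n",
| "$$\n",
| "\\begin{aligned}\n",
| "J_1 &= (1, 6, 2, 4, 4, 1, 0) \\\\\n",
| "J_2 &= (7, 10, 3, 5, 4, 1, 0) \\\\\n",
| "J_3 &= (8, 10, 3, 5, 4, 1, 0)\n",
| "\\end{aligned}\n",
| "$$\n",
| "\n",
| "For example, $ J_2(0) = \\min\\{1 + J_1(1), \\; 5 + J_1(2), \\; 3 + J_1(3)\\} = \\min\\{7, 7, 7\\} = 7 $. One more step returns $ J_4 = J_3 $, so the iteration stops and the minimum cost-to-go from node 0 is $ J(0) = 8 $."
| ]
| },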
192 | {
193 | "cell_type": "markdown",
194 | "id": "ffdf10ca",
195 | "metadata": {},
196 | "source": [
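197 | "We can further optimize the above code by using [jax.lax.while_loop](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html). The extra speed comes from compiling the whole loop, including the convergence test, into a single XLA computation that runs on the device. This avoids the per-iteration Python overhead of the loop above, as well as the host-device synchronization forced by converting the result of `jnp.allclose` to a Python `bool` in the `if` statement."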
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": null,
203 | "id": "c5ce6cc9",
204 | "metadata": {
205 | "hide-output": false
206 | },
207 | "outputs": [],
208 | "source": [
209 | "max_iter = 500\n",
210 | "num_nodes = Q.shape[0]\n",
211 | "J = jnp.zeros(num_nodes)"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": null,
217 | "id": "bffa66dc",
218 | "metadata": {
219 | "hide-output": false
220 | },
221 | "outputs": [],
222 | "source": [
223 | "def body_fun(values):\n",
224 | " # Define the body function of while loop\n",
225 | " i, J, break_cond = values\n",
226 | "\n",
227 | " # Update J and break condition\n",
228 | " next_J = jnp.min(Q + J, axis=1)\n",
229 | " break_condition = jnp.allclose(next_J, J)\n",
230 | "\n",
231 | " # Return next iteration values\n",
232 | " return i + 1, next_J, break_condition"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "id": "c6d641be",
239 | "metadata": {
240 | "hide-output": false
241 | },
242 | "outputs": [],
243 | "source": [
244 | "def cond_fun(values):\n",
245 | " i, J, break_condition = values\n",
246 | " return ~break_condition & (i < max_iter)"
247 | ]
248 | },
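| {
| "cell_type": "markdown",
| "id": "while-loop-reference-md",
| "metadata": {},
| "source": [
| "The JAX documentation describes `jax.lax.while_loop` as having the same semantics as a plain Python `while` loop over a single carried value. The next cell writes that loop out explicitly as a debugging aid; `while_loop_reference` is our own name and is not part of JAX. Run with the same `cond_fun`, `body_fun` and initial value, it should reproduce the result of the compiled version below, just without tracing or compilation."
| ]
| },
| {
| "cell_type": "code",
| "execution_count": null,
| "id": "while-loop-reference",
| "metadata": {
| "hide-output": false
| },
| "outputs": [],
| "source": [
| "def while_loop_reference(cond_fun, body_fun, init_val):\n",
| "    # Eager Python loop with the documented semantics of jax.lax.while_loop\n",
| "    val = init_val\n",
| "    while cond_fun(val):\n",
| "        val = body_fun(val)\n",
| "    return val\n",
| "\n",
| "# A JAX boolean for the flag ensures that ~ in cond_fun acts as a logical not\n",
| "n_iter, J_ref, _ = while_loop_reference(cond_fun, body_fun, (0, J, jnp.array(False)))\n",
| "print(f'iterations: {n_iter}, cost-to-go: {J_ref}')"
| ]
| },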
249 | {
250 | "cell_type": "markdown",
251 | "id": "535ce0bb",
252 | "metadata": {},
253 | "source": [
254 | "Let’s time the first call, which includes the cost of tracing and compiling the loop."
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "id": "3aba86ea",
261 | "metadata": {
262 | "hide-output": false
263 | },
264 | "outputs": [],
265 | "source": [
266 | "%%time\n",
267 | "jax.lax.while_loop(cond_fun, body_fun, init_val=(0, J, False))[1].block_until_ready()"
268 | ]
269 | },
270 | {
271 | "cell_type": "markdown",
272 | "id": "7631c516",
273 | "metadata": {},
274 | "source": [
275 | "A second call runs faster, because the compiled loop can be reused rather than compiled again."
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": null,
281 | "id": "df6dd19d",
282 | "metadata": {
283 | "hide-output": false
284 | },
285 | "outputs": [],
286 | "source": [
287 | "%%time\n",
288 | "jax.lax.while_loop(cond_fun, body_fun, init_val=(0, J, False))[1].block_until_ready()"
289 | ]
290 | },
291 | {
292 | "cell_type": "markdown",
293 | "id": "d4a23d7e",
294 | "metadata": {},
295 | "source": [
296 | ">**Note**\n",
297 | ">\n",
298 | ">Large speed gains while using `jax.lax.while_loop` won’t be realized unless the shortest path problem is relatively large."
299 | ]
300 | },
301 | {
302 | "cell_type": "markdown",
303 | "id": "83aa59a4",
304 | "metadata": {},
305 | "source": [
306 | "## Exercises"
307 | ]
308 | },
309 | {
310 | "cell_type": "markdown",
311 | "id": "38a014b2",
312 | "metadata": {},
313 | "source": [
314 | "## Exercise 13.1\n",
315 | "\n",
316 | "The text below describes a weighted directed graph.\n",
317 | "\n",
318 | "The line `node0, node1 0.04, node8 11.11, node14 72.21` means that from node0 we can go to\n",
319 | "\n",
320 | "- node1 at cost 0.04 \n",
321 | "- node8 at cost 11.11 \n",
322 | "- node14 at cost 72.21 \n",
323 | "\n",
324 | "\n",
325 | "No other nodes can be reached directly from node0.\n",
326 | "\n",
327 | "Other lines have a similar interpretation.\n",
328 | "\n",
329 | "Your task is to use the algorithm given above to find the optimal path and its cost."
330 | ]
331 | },
332 | {
333 | "cell_type": "code",
334 | "execution_count": null,
335 | "id": "8fbab404",
336 | "metadata": {
337 | "hide-output": false
338 | },
339 | "outputs": [],
340 | "source": [
341 | "%%file graph.txt\n",
342 | "node0, node1 0.04, node8 11.11, node14 72.21\n",
343 | "node1, node46 1247.25, node6 20.59, node13 64.94\n",
344 | "node2, node66 54.18, node31 166.80, node45 1561.45\n",
345 | "node3, node20 133.65, node6 2.06, node11 42.43\n",
346 | "node4, node75 3706.67, node5 0.73, node7 1.02\n",
347 | "node5, node45 1382.97, node7 3.33, node11 34.54\n",
348 | "node6, node31 63.17, node9 0.72, node10 13.10\n",
349 | "node7, node50 478.14, node9 3.15, node10 5.85\n",
350 | "node8, node69 577.91, node11 7.45, node12 3.18\n",
351 | "node9, node70 2454.28, node13 4.42, node20 16.53\n",
352 | "node10, node89 5352.79, node12 1.87, node16 25.16\n",
353 | "node11, node94 4961.32, node18 37.55, node20 65.08\n",
354 | "node12, node84 3914.62, node24 34.32, node28 170.04\n",
355 | "node13, node60 2135.95, node38 236.33, node40 475.33\n",
356 | "node14, node67 1878.96, node16 2.70, node24 38.65\n",
357 | "node15, node91 3597.11, node17 1.01, node18 2.57\n",
358 | "node16, node36 392.92, node19 3.49, node38 278.71\n",
359 | "node17, node76 783.29, node22 24.78, node23 26.45\n",
360 | "node18, node91 3363.17, node23 16.23, node28 55.84\n",
361 | "node19, node26 20.09, node20 0.24, node28 70.54\n",
362 | "node20, node98 3523.33, node24 9.81, node33 145.80\n",
363 | "node21, node56 626.04, node28 36.65, node31 27.06\n",
364 | "node22, node72 1447.22, node39 136.32, node40 124.22\n",
365 | "node23, node52 336.73, node26 2.66, node33 22.37\n",
366 | "node24, node66 875.19, node26 1.80, node28 14.25\n",
367 | "node25, node70 1343.63, node32 36.58, node35 45.55\n",
368 | "node26, node47 135.78, node27 0.01, node42 122.00\n",
369 | "node27, node65 480.55, node35 48.10, node43 246.24\n",
370 | "node28, node82 2538.18, node34 21.79, node36 15.52\n",
371 | "node29, node64 635.52, node32 4.22, node33 12.61\n",
372 | "node30, node98 2616.03, node33 5.61, node35 13.95\n",
373 | "node31, node98 3350.98, node36 20.44, node44 125.88\n",
374 | "node32, node97 2613.92, node34 3.33, node35 1.46\n",
375 | "node33, node81 1854.73, node41 3.23, node47 111.54\n",
376 | "node34, node73 1075.38, node42 51.52, node48 129.45\n",
377 | "node35, node52 17.57, node41 2.09, node50 78.81\n",
378 | "node36, node71 1171.60, node54 101.08, node57 260.46\n",
379 | "node37, node75 269.97, node38 0.36, node46 80.49\n",
380 | "node38, node93 2767.85, node40 1.79, node42 8.78\n",
381 | "node39, node50 39.88, node40 0.95, node41 1.34\n",
382 | "node40, node75 548.68, node47 28.57, node54 53.46\n",
383 | "node41, node53 18.23, node46 0.28, node54 162.24\n",
384 | "node42, node59 141.86, node47 10.08, node72 437.49\n",
385 | "node43, node98 2984.83, node54 95.06, node60 116.23\n",
386 | "node44, node91 807.39, node46 1.56, node47 2.14\n",
387 | "node45, node58 79.93, node47 3.68, node49 15.51\n",
388 | "node46, node52 22.68, node57 27.50, node67 65.48\n",
389 | "node47, node50 2.82, node56 49.31, node61 172.64\n",
390 | "node48, node99 2564.12, node59 34.52, node60 66.44\n",
391 | "node49, node78 53.79, node50 0.51, node56 10.89\n",
392 | "node50, node85 251.76, node53 1.38, node55 20.10\n",
393 | "node51, node98 2110.67, node59 23.67, node60 73.79\n",
394 | "node52, node94 1471.80, node64 102.41, node66 123.03\n",
395 | "node53, node72 22.85, node56 4.33, node67 88.35\n",
396 | "node54, node88 967.59, node59 24.30, node73 238.61\n",
397 | "node55, node84 86.09, node57 2.13, node64 60.80\n",
398 | "node56, node76 197.03, node57 0.02, node61 11.06\n",
399 | "node57, node86 701.09, node58 0.46, node60 7.01\n",
400 | "node58, node83 556.70, node64 29.85, node65 34.32\n",
401 | "node59, node90 820.66, node60 0.72, node71 0.67\n",
402 | "node60, node76 48.03, node65 4.76, node67 1.63\n",
403 | "node61, node98 1057.59, node63 0.95, node64 4.88\n",
404 | "node62, node91 132.23, node64 2.94, node76 38.43\n",
405 | "node63, node66 4.43, node72 70.08, node75 56.34\n",
406 | "node64, node80 47.73, node65 0.30, node76 11.98\n",
407 | "node65, node94 594.93, node66 0.64, node73 33.23\n",
408 | "node66, node98 395.63, node68 2.66, node73 37.53\n",
409 | "node67, node82 153.53, node68 0.09, node70 0.98\n",
410 | "node68, node94 232.10, node70 3.35, node71 1.66\n",
411 | "node69, node99 247.80, node70 0.06, node73 8.99\n",
412 | "node70, node76 27.18, node72 1.50, node73 8.37\n",
413 | "node71, node89 104.50, node74 8.86, node91 284.64\n",
414 | "node72, node76 15.32, node84 102.77, node92 133.06\n",
415 | "node73, node83 52.22, node76 1.40, node90 243.00\n",
416 | "node74, node81 1.07, node76 0.52, node78 8.08\n",
417 | "node75, node92 68.53, node76 0.81, node77 1.19\n",
418 | "node76, node85 13.18, node77 0.45, node78 2.36\n",
419 | "node77, node80 8.94, node78 0.98, node86 64.32\n",
420 | "node78, node98 355.90, node81 2.59\n",
421 | "node79, node81 0.09, node85 1.45, node91 22.35\n",
422 | "node80, node92 121.87, node88 28.78, node98 264.34\n",
423 | "node81, node94 99.78, node89 39.52, node92 99.89\n",
424 | "node82, node91 47.44, node88 28.05, node93 11.99\n",
425 | "node83, node94 114.95, node86 8.75, node88 5.78\n",
426 | "node84, node89 19.14, node94 30.41, node98 121.05\n",
427 | "node85, node97 94.51, node87 2.66, node89 4.90\n",
428 | "node86, node97 85.09\n",
429 | "node87, node88 0.21, node91 11.14, node92 21.23\n",
430 | "node88, node93 1.31, node91 6.83, node98 6.12\n",
431 | "node89, node97 36.97, node99 82.12\n",
432 | "node90, node96 23.53, node94 10.47, node99 50.99\n",
433 | "node91, node97 22.17\n",
434 | "node92, node96 10.83, node97 11.24, node99 34.68\n",
435 | "node93, node94 0.19, node97 6.71, node99 32.77\n",
436 | "node94, node98 5.91, node96 2.03\n",
437 | "node95, node98 6.17, node99 0.27\n",
438 | "node96, node98 3.32, node97 0.43, node99 5.87\n",
439 | "node97, node98 0.30\n",
440 | "node98, node99 0.33\n",
441 | "node99,"
442 | ]
443 | },
444 | {
445 | "cell_type": "markdown",
446 | "id": "ac653b66",
447 | "metadata": {},
448 | "source": [
449 | "## Solution\n",
450 | "\n",
451 | "First let’s write a function that reads in the graph data above and builds a distance matrix."
452 | ]
453 | },
454 | {
455 | "cell_type": "code",
456 | "execution_count": null,
457 | "id": "19136c0a",
458 | "metadata": {
459 | "hide-output": false
460 | },
461 | "outputs": [],
462 | "source": [
463 | "num_nodes = 100\n",
464 | "destination_node = 99\n",
465 | "def map_graph_to_distance_matrix(in_file):\n",
466 | "\n",
467 | "    # First, set up the distance matrix Q with inf everywhere\n",
468 | " Q = np.full((num_nodes, num_nodes), np.inf)\n",
469 | "\n",
470 | " # Now we read in the data and modify Q\n",
471 | " with open(in_file) as infile:\n",
472 | " for line in infile:\n",
473 | " elements = line.split(',')\n",
474 | " node = elements.pop(0)\n",
475 | " node = int(node[4:]) # convert node description to integer\n",
476 | " if node != destination_node:\n",
477 | " for element in elements:\n",
478 | " destination, cost = element.split()\n",
479 | " destination = int(destination[4:])\n",
480 | " Q[node, destination] = float(cost)\n",
481 | " Q[destination_node, destination_node] = 0\n",
482 | " return jnp.array(Q)"
483 | ]
484 | },
485 | {
486 | "cell_type": "markdown",
487 | "id": "d9938b61",
488 | "metadata": {},
489 | "source": [
490 | "Let’s write a function `compute_cost_to_go` that returns $ J $ given any valid $ Q $."
491 | ]
492 | },
493 | {
494 | "cell_type": "code",
495 | "execution_count": null,
496 | "id": "98065ae6",
497 | "metadata": {
498 | "hide-output": false
499 | },
500 | "outputs": [],
501 | "source": [
502 | "@jax.jit\n",
503 | "def compute_cost_to_go(Q):\n",
504 | " num_nodes = Q.shape[0]\n",
505 | " J = jnp.zeros(num_nodes) # Initial guess\n",
506 | " max_iter = 500\n",
507 | " i = 0\n",
508 | "\n",
509 | " def body_fun(values):\n",
510 | " # Define the body function of while loop\n",
511 | " i, J, break_cond = values\n",
512 | "\n",
513 | " # Update J and break condition\n",
514 | " next_J = jnp.min(Q + J, axis=1)\n",
515 | " break_condition = jnp.allclose(next_J, J)\n",
516 | "\n",
517 | " # Return next iteration values\n",
518 | " return i + 1, next_J, break_condition\n",
519 | "\n",
520 | " def cond_fun(values):\n",
521 | " i, J, break_condition = values\n",
522 | " return ~break_condition & (i < max_iter)\n",
523 | "\n",
524 | " return jax.lax.while_loop(cond_fun, body_fun,\n",
525 | " init_val=(0, J, False))[1]"
526 | ]
527 | },
528 | {
529 | "cell_type": "markdown",
530 | "id": "b77c52ec",
531 | "metadata": {},
532 | "source": [
533 | "Finally, here’s a function that uses the cost-to-go function $ J $ to trace out the\n",
534 | "optimal path from node 0 and print it, along with its cost."
535 | ]
536 | },
537 | {
538 | "cell_type": "code",
539 | "execution_count": null,
540 | "id": "103ec6f4",
541 | "metadata": {
542 | "hide-output": false
543 | },
544 | "outputs": [],
545 | "source": [
546 | "def print_best_path(J, Q):\n",
547 | " sum_costs = 0\n",
548 | " current_node = 0\n",
549 | " while current_node != destination_node:\n",
550 | " print(current_node)\n",
551 | " # Move to the next node and increment costs\n",
552 | " next_node = jnp.argmin(Q[current_node, :] + J)\n",
553 | " sum_costs += Q[current_node, next_node]\n",
554 | " current_node = next_node\n",
555 | " print(destination_node)\n",
556 | " print('Cost: ', sum_costs)"
557 | ]
558 | },
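| {
| "cell_type": "markdown",
| "id": "get-best-path-md",
| "metadata": {},
| "source": [
| "If you would rather have the path returned as data than printed, a variant along the following lines could be used. `get_best_path` is a hypothetical helper, not part of the original solution. Like `print_best_path`, it starts at node 0 (by default) and repeatedly moves to the node $ w $ that minimizes $ Q(v, w) + J(w) $, stopping at `destination_node`, so it should only be called once `J` has been computed for the graph stored in `Q`."
| ]
| },
| {
| "cell_type": "code",
| "execution_count": null,
| "id": "get-best-path",
| "metadata": {
| "hide-output": false
| },
| "outputs": [],
| "source": [
| "def get_best_path(J, Q, start=0):\n",
| "    \"\"\"Return the greedy path implied by J and its total cost.\"\"\"\n",
| "    path, total_cost = [start], 0.0\n",
| "    current_node = start\n",
| "    while current_node != destination_node:\n",
| "        # Move to the node minimizing step cost plus cost-to-go\n",
| "        next_node = int(jnp.argmin(Q[current_node, :] + J))\n",
| "        total_cost += float(Q[current_node, next_node])\n",
| "        path.append(next_node)\n",
| "        current_node = next_node\n",
| "    return path, total_cost"
| ]
| },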
559 | {
560 | "cell_type": "markdown",
561 | "id": "3eeaad1b",
562 | "metadata": {},
563 | "source": [
564 | "Now that we have the necessary functions, let’s use them to solve the exercise."
565 | ]
566 | },
567 | {
568 | "cell_type": "code",
569 | "execution_count": null,
570 | "id": "1189e88d",
571 | "metadata": {
572 | "hide-output": false
573 | },
574 | "outputs": [],
575 | "source": [
576 | "Q = map_graph_to_distance_matrix('graph.txt')"
577 | ]
578 | },
579 | {
580 | "cell_type": "markdown",
581 | "id": "ccc0445c",
582 | "metadata": {},
583 | "source": [
584 | "Let’s time the first call, which includes JIT compilation."
585 | ]
586 | },
587 | {
588 | "cell_type": "code",
589 | "execution_count": null,
590 | "id": "403b58dd",
591 | "metadata": {
592 | "hide-output": false
593 | },
594 | "outputs": [],
595 | "source": [
596 | "%%time\n",
597 | "J = compute_cost_to_go(Q).block_until_ready()"
598 | ]
599 | },
600 | {
601 | "cell_type": "markdown",
602 | "id": "c037b83a",
603 | "metadata": {},
604 | "source": [
605 | "Let’s run it again, so that the timing excludes compile time."
606 | ]
607 | },
608 | {
609 | "cell_type": "code",
610 | "execution_count": null,
611 | "id": "aba6d2cc",
612 | "metadata": {
613 | "hide-output": false
614 | },
615 | "outputs": [],
616 | "source": [
617 | "%%time\n",
618 | "J = compute_cost_to_go(Q).block_until_ready()"
619 | ]
620 | },
621 | {
622 | "cell_type": "code",
623 | "execution_count": null,
624 | "id": "b35e8c9c",
625 | "metadata": {
626 | "hide-output": false
627 | },
628 | "outputs": [],
629 | "source": [
630 | "print_best_path(J, Q)"
631 | ]
632 | },
633 | {
634 | "cell_type": "markdown",
635 | "id": "28a00eae",
636 | "metadata": {},
637 | "source": [
638 | "The total cost of the path should agree with $ J[0] $, so let’s check this."
639 | ]
640 | },
641 | {
642 | "cell_type": "code",
643 | "execution_count": null,
644 | "id": "1a2922b9",
645 | "metadata": {
646 | "hide-output": false
647 | },
648 | "outputs": [],
649 | "source": [
650 | "J[0].item()"
651 | ]
652 | }
653 | ],
654 | "metadata": {
655 | "date": 1765244755.749365,
656 | "filename": "short_path.md",
657 | "kernelspec": {
658 | "display_name": "Python",
659 | "language": "python3",
660 | "name": "python3"
661 | },
662 | "title": "Shortest Paths"
663 | },
664 | "nbformat": 4,
665 | "nbformat_minor": 5
666 | }
--------------------------------------------------------------------------------
/cake_eating_numerical.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "0355c95e",
6 | "metadata": {},
7 | "source": [
8 | "# Cake Eating: Numerical Methods"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "518c125c",
14 | "metadata": {},
15 | "source": [
16 | "# GPU\n",
17 | "\n",
18 | "This lecture was built using a machine with JAX installed and access to a GPU.\n",
19 | "\n",
20 | "To run this lecture on [Google Colab](https://colab.research.google.com/), click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.\n",
21 | "\n",
22 | "To run this lecture on your own machine, you need to install [Google JAX](https://github.com/google/jax).\n",
23 | "\n",
24 | "This lecture is the extended JAX implementation of [this lecture](https://python.quantecon.org/cake_eating_numerical.html).\n",
25 | "\n",
26 | "Please refer to that lecture for all background and notation.\n",
27 | "\n",
28 | "In addition to JAX and Anaconda, this lecture will need the following libraries:"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "id": "82c19e76",
35 | "metadata": {
36 | "hide-output": false
37 | },
38 | "outputs": [],
39 | "source": [
40 | "!pip install quantecon"
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "id": "73461742",
46 | "metadata": {},
47 | "source": [
48 | "We will use the following imports."
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "id": "2772d468",
55 | "metadata": {
56 | "hide-output": false
57 | },
58 | "outputs": [],
59 | "source": [
60 | "import jax\n",
61 | "import jax.numpy as jnp\n",
62 | "import matplotlib.pyplot as plt\n",
63 | "from collections import namedtuple\n",
64 | "import time"
65 | ]
66 | },
67 | {
68 | "cell_type": "markdown",
69 | "id": "9d4caca9",
70 | "metadata": {},
71 | "source": [
72 | "Let’s check which GPU we are running on."
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "id": "5ebff432",
79 | "metadata": {
80 | "hide-output": false
81 | },
82 | "outputs": [],
83 | "source": [
84 | "!nvidia-smi"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "id": "484b052f",
90 | "metadata": {},
91 | "source": [
92 | "## Reviewing the Model\n",
93 | "\n",
94 | "Recall in particular that the Bellman equation is\n",
95 | "\n",
96 | "\n",
97 | "\n",
98 | "$$\n",
99 | "v(x) = \\max_{0\\leq c \\leq x} \\{u(c) + \\beta v(x-c)\\}\n",
100 | "\\quad \\text{for all } x \\geq 0. \\tag{16.1}\n",
101 | "$$\n",
102 | "\n",
103 | "where $ u $ is the CRRA utility function."
104 | ]
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "id": "cff7a4e5",
109 | "metadata": {},
110 | "source": [
111 | "## Implementation using JAX\n",
112 | "\n",
113 | "The analytical solutions for the value function and optimal policy were found\n",
114 | "to be as follows."
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "id": "2a182152",
121 | "metadata": {
122 | "hide-output": false
123 | },
124 | "outputs": [],
125 | "source": [
126 | "@jax.jit\n",
127 | "def c_star(x, β, γ):\n",
128 | " return (1 - β ** (1/γ)) * x\n",
129 | "\n",
130 | "@jax.jit\n",
131 | "def v_star(x, β, γ):\n",
132 | " return (1 - β**(1 / γ))**(-γ) * (x**(1-γ) / (1-γ))"
133 | ]
134 | },
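| {
| "cell_type": "markdown",
| "id": "bellman-check-md",
| "metadata": {},
| "source": [
| "As a quick sanity check, the closed forms above should satisfy the Bellman equation (16.1) with equality at the proposed policy, that is, $ v^*(x) = u(c^*(x)) + \\beta v^*(x - c^*(x)) $. The next cell compares both sides at a few arbitrarily chosen points. The parameter values are the defaults used in `create_cake_eating_model` below, and the CRRA utility $ u(c) = c^{1-\\gamma}/(1-\\gamma) $ is written out inline because `u` is only defined later. This confirms the equation holds at $ c^* $, but does not by itself show that $ c^* $ is the maximizer."
| ]
| },
| {
| "cell_type": "code",
| "execution_count": null,
| "id": "bellman-check",
| "metadata": {
| "hide-output": false
| },
| "outputs": [],
| "source": [
| "β_chk, γ_chk = 0.96, 1.5  # same as the defaults in create_cake_eating_model\n",
| "x_chk = jnp.linspace(0.1, 2.5, 5)  # a few test points\n",
| "c_chk = c_star(x_chk, β_chk, γ_chk)\n",
| "\n",
| "lhs = v_star(x_chk, β_chk, γ_chk)\n",
| "# u(c*) + β v*(x - c*), with CRRA utility written out explicitly\n",
| "rhs = c_chk**(1 - γ_chk) / (1 - γ_chk) + β_chk * v_star(x_chk - c_chk, β_chk, γ_chk)\n",
| "\n",
| "# Should be small relative to |lhs| (JAX uses 32-bit floats by default)\n",
| "print(f'max |lhs - rhs| = {float(jnp.max(jnp.abs(lhs - rhs))):.2e}')"
| ]
| },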
135 | {
136 | "cell_type": "markdown",
137 | "id": "e79f2516",
138 | "metadata": {},
139 | "source": [
140 | "Let’s define a model to represent the Cake Eating Problem."
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": null,
146 | "id": "829658ad",
147 | "metadata": {
148 | "hide-output": false
149 | },
150 | "outputs": [],
151 | "source": [
152 | "CEM = namedtuple('CakeEatingModel',\n",
153 | " ('β', 'γ', 'x_grid', 'c_grid'))"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": null,
159 | "id": "befb17aa",
160 | "metadata": {
161 | "hide-output": false
162 | },
163 | "outputs": [],
164 | "source": [
165 | "def create_cake_eating_model(β=0.96, # discount factor\n",
166 | " γ=1.5, # degree of relative risk aversion\n",
167 | " x_grid_min=1e-3, # exclude zero for numerical stability\n",
168 | " x_grid_max=2.5, # size of cake\n",
169 | " x_grid_size=200):\n",
170 | " x_grid = jnp.linspace(x_grid_min, x_grid_max, x_grid_size)\n",
171 | "\n",
172 | "    # c_grid is used to find the maximizing c by brute-force search\n",
173 | " c_grid = jnp.linspace(x_grid_min, x_grid_max, 100*x_grid_size)\n",
174 | " return CEM(β=β, γ=γ, x_grid=x_grid, c_grid=c_grid)"
175 | ]
176 | },
177 | {
178 | "cell_type": "markdown",
179 | "id": "d33a5d72",
180 | "metadata": {},
181 | "source": [
182 | "Now let’s define the CRRA utility function."
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": null,
188 | "id": "810800a0",
189 | "metadata": {
190 | "hide-output": false
191 | },
192 | "outputs": [],
193 | "source": [
194 | "# Utility function\n",
195 | "@jax.jit\n",
196 | "def u(c, cem):\n",
197 | " return (c ** (1 - cem.γ)) / (1 - cem.γ)"
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "id": "51c9fd98",
203 | "metadata": {},
204 | "source": [
205 | "### The Bellman Operator\n",
206 | "\n",
207 | "We introduce the **Bellman operator** $ T $ that takes a function $ v $ as an\n",
208 | "argument and returns a new function $ Tv $ defined by\n",
209 | "\n",
210 | "$$\n",
211 | "Tv(x) = \\max_{0 \\leq c \\leq x} \\{u(c) + \\beta v(x - c)\\}\n",
212 | "$$\n",
213 | "\n",
214 | "From $ v $ we get $ Tv $, and applying $ T $ to this yields\n",
215 | "$ T^2 v := T (Tv) $ and so on.\n",
216 | "\n",
217 | "This is called **iterating with the Bellman operator** from initial guess\n",
218 | "$ v $."
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": null,
224 | "id": "e501164c",
225 | "metadata": {
226 | "hide-output": false
227 | },
228 | "outputs": [],
229 | "source": [
230 | "@jax.jit\n",
231 | "def state_action_value(x, c, v_array, ce):\n",
232 | " \"\"\"\n",
233 | "    Right-hand side of the Bellman equation given x and c.\n",
234 | " * x: scalar element `x`\n",
235 | " * c: c_grid, 1-D array\n",
236 | " * v_array: value function array guess, 1-D array\n",
237 | " * ce: Cake Eating Model instance\n",
238 | " \"\"\"\n",
239 | "\n",
240 | " return jnp.where(c <= x,\n",
241 | " u(c, ce) + ce.β * jnp.interp(x - c, ce.x_grid, v_array),\n",
242 | " -jnp.inf)"
243 | ]
244 | },
245 | {
246 | "cell_type": "markdown",
247 | "id": "485be622",
248 | "metadata": {},
249 | "source": [
250 | "To vectorize `state_action_value` over the state, we use [jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html).\n",
251 | "With `in_axes` set to `(0, None, None, None)`, the returned function maps over the first argument `x` and passes `c`, `v_array` and `ce` through unchanged, so for a grid of `x` values it returns a 2-D array with one row per `x` and one column per element of `c`."
252 | ]
253 | },
254 | {
255 | "cell_type": "code",
256 | "execution_count": null,
257 | "id": "5c2396c2",
258 | "metadata": {
259 | "hide-output": false
260 | },
261 | "outputs": [],
262 | "source": [
263 | "state_action_value_vec = jax.vmap(state_action_value, (0, None, None, None))"
264 | ]
265 | },
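| {
| "cell_type": "markdown",
| "id": "vmap-shape-check-md",
| "metadata": {},
| "source": [
| "To see what these `in_axes` do, we can evaluate `state_action_value_vec` on a small model instance; the grid size of 5 below is an arbitrary choice for illustration. Mapping over the `x` grid while holding `c_grid`, the value array and the model fixed produces one row per state $ x $ and one column per consumption choice $ c $, which is why `T` below takes the maximum over `axis=1`."
| ]
| },
| {
| "cell_type": "code",
| "execution_count": null,
| "id": "vmap-shape-check",
| "metadata": {
| "hide-output": false
| },
| "outputs": [],
| "source": [
| "ce_small = create_cake_eating_model(x_grid_size=5)  # c_grid has 100 * 5 = 500 points\n",
| "v_small = jnp.zeros(len(ce_small.x_grid))\n",
| "vals = state_action_value_vec(ce_small.x_grid, ce_small.c_grid, v_small, ce_small)\n",
| "print(vals.shape)  # (5, 500): one row per x, one column per c"
| ]
| },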
266 | {
267 | "cell_type": "code",
268 | "execution_count": null,
269 | "id": "ef15c222",
270 | "metadata": {
271 | "hide-output": false
272 | },
273 | "outputs": [],
274 | "source": [
275 | "@jax.jit\n",
276 | "def T(v, ce):\n",
277 | " \"\"\"\n",
278 | " The Bellman operator. Updates the guess of the value function.\n",
279 | "\n",
280 | " * ce: Cake Eating Model instance\n",
281 | " * v: value function array guess, 1-D array\n",
282 | "\n",
283 | " \"\"\"\n",
284 | " return jnp.max(state_action_value_vec(ce.x_grid, ce.c_grid, v, ce), axis=1)"
285 | ]
286 | },
287 | {
288 | "cell_type": "markdown",
289 | "id": "929b1623",
290 | "metadata": {},
291 | "source": [
292 | "Let’s start by creating a Cake Eating Model instance using the default parameterization."
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": null,
298 | "id": "f0229093",
299 | "metadata": {
300 | "hide-output": false
301 | },
302 | "outputs": [],
303 | "source": [
304 | "ce = create_cake_eating_model()"
305 | ]
306 | },
307 | {
308 | "cell_type": "markdown",
309 | "id": "c40d7b39",
310 | "metadata": {},
311 | "source": [
312 | "Now let’s see the iteration of the value function in action.\n",
313 | "\n",
314 | "We start from guess $ v $ given by $ v(x) = u(x) $ for every\n",
315 | "$ x $ grid point."
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": null,
321 | "id": "5aecbae0",
322 | "metadata": {
323 | "hide-output": false
324 | },
325 | "outputs": [],
326 | "source": [
327 | "x_grid = ce.x_grid\n",
328 | "v = u(x_grid, ce) # Initial guess\n",
329 | "n = 12 # Number of iterations\n",
330 | "\n",
331 | "fig, ax = plt.subplots()\n",
332 | "\n",
333 | "ax.plot(x_grid, v, color=plt.cm.jet(0),\n",
334 | " lw=2, alpha=0.6, label='Initial guess')\n",
335 | "\n",
336 | "for i in range(n):\n",
337 | " v = T(v, ce) # Apply the Bellman operator\n",
338 | " ax.plot(x_grid, v, color=plt.cm.jet(i / n), lw=2, alpha=0.6)\n",
339 | "\n",
340 | "ax.legend()\n",
341 | "ax.set_ylabel('value', fontsize=12)\n",
342 | "ax.set_xlabel('cake size $x$', fontsize=12)\n",
343 | "ax.set_title('Value function iterations')\n",
344 | "\n",
345 | "plt.show()"
346 | ]
347 | },
348 | {
349 | "cell_type": "markdown",
350 | "id": "75d7f9da",
351 | "metadata": {},
352 | "source": [
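353 | "Let’s introduce a wrapper function called `compute_value_function`\n",
354 | "that applies $ T $ repeatedly until the maximum absolute change between successive iterates falls below `tol`, or until `max_iter` iterations have been performed."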
355 | ]
356 | },
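The Python `while` loop in `compute_value_function` returns to the host every iteration to evaluate `error > tol`. One alternative is to keep the whole loop on the device with `jax.lax.while_loop`. The sketch below assumes a generic operator `T` that maps an array to an array of the same shape and dtype. `fixed_point` and `T_toy` are hypothetical names, and `T_toy` stands in for the Bellman operator.

```python
import jax
import jax.numpy as jnp

def fixed_point(T, v0, tol=1e-6, max_iter=1000):
    """Iterate v -> T(v) until the sup-norm change is at most tol or max_iter is hit."""
    def cond(state):
        i, v, error = state
        return (i < max_iter) & (error > tol)

    def body(state):
        i, v, _ = state
        v_new = T(v)
        error = jnp.max(jnp.abs(v_new - v))
        return i + 1, v_new, error

    # Loop state: (iteration count, current guess, latest error)
    init = (jnp.array(0, dtype=jnp.int32), v0, jnp.array(jnp.inf, dtype=v0.dtype))
    i, v, error = jax.lax.while_loop(cond, body, init)
    return v, i, error

def T_toy(v):
    # Hypothetical stand-in for the Bellman operator: contraction with fixed point 2
    return 1.0 + 0.5 * v

v, n_iter, err = fixed_point(T_toy, jnp.zeros(3))
```

The trade-off is that `lax.while_loop` cannot print progress messages from inside the loop as `compute_value_function` does, so this style suits cases where only the final result matters.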
357 | {
358 | "cell_type": "code",
359 | "execution_count": null,
360 | "id": "acfbb8a7",
361 | "metadata": {
362 | "hide-output": false
363 | },
364 | "outputs": [],
365 | "source": [
366 | "def compute_value_function(ce,\n",
367 | " tol=1e-4,\n",
368 | " max_iter=1000,\n",
369 | " verbose=True,\n",
370 | " print_skip=25):\n",
371 | "\n",
372 | " # Set up loop\n",
373 | " v = jnp.zeros(len(ce.x_grid)) # Initial guess\n",
374 | " i = 0\n",
375 | " error = tol + 1\n",
376 | "\n",
377 | " while i < max_iter and error > tol:\n",
378 | " v_new = T(v, ce)\n",
379 | "\n",
380 | " error = jnp.max(jnp.abs(v - v_new))\n",
381 | " i += 1\n",
382 | "\n",
383 | " if verbose and i % print_skip == 0:\n",
384 | " print(f\"Error at iteration {i} is {error}.\")\n",
385 | "\n",
386 | " v = v_new\n",
387 | "\n",
388 | " if error > tol:\n",
389 | " print(\"Failed to converge!\")\n",
390 | " elif verbose:\n",
391 | " print(f\"\\nConverged in {i} iterations.\")\n",
392 | "\n",
393 | " return v_new"
394 | ]
395 | },
396 | {
397 | "cell_type": "code",
398 | "execution_count": null,
399 | "id": "8e13e06b",
400 | "metadata": {
401 | "hide-output": false
402 | },
403 | "outputs": [],
404 | "source": [
405 | "in_time = time.time()\n",
406 | "v_jax = compute_value_function(ce)\n",
407 | "jax_time = time.time() - in_time"
408 | ]
409 | },
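The measured `jax_time` includes the one-off cost of tracing and compiling `T` on its first call. Here is a minimal sketch of separating compilation from execution. It uses a hypothetical jitted function `step` and calls `block_until_ready` so that the timer waits for JAX's asynchronous dispatch to finish.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(v):
    # Hypothetical stand-in for the Bellman operator T
    return 1.0 + 0.5 * v

v0 = jnp.zeros(1_000)

start = time.perf_counter()
step(v0).block_until_ready()          # first call: trace + compile + run
compile_and_run = time.perf_counter() - start

start = time.perf_counter()
out = step(v0).block_until_ready()    # later calls reuse the compiled executable
run_only = time.perf_counter() - start
```

Reporting both numbers gives a fairer picture than a single measurement, because the compilation cost is paid only once per input shape and dtype.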
410 | {
411 | "cell_type": "code",
412 | "execution_count": null,
413 | "id": "5851cbba",
414 | "metadata": {
415 | "hide-output": false
416 | },
417 | "outputs": [],
418 | "source": [
419 | "fig, ax = plt.subplots()\n",
420 | "\n",
421 | "ax.plot(x_grid, v_jax, label='Approximate value function')\n",
422 | "ax.set_ylabel('$V(x)$', fontsize=12)\n",
423 | "ax.set_xlabel('$x$', fontsize=12)\n",
424 | "ax.set_title('Value function')\n",
425 | "ax.legend()\n",
426 | "plt.show()"
427 | ]
428 | },
429 | {
430 | "cell_type": "markdown",
431 | "id": "ad382e7f",
432 | "metadata": {},
433 | "source": [
434 | "Next let’s compare it to the analytical solution."
435 | ]
436 | },
437 | {
438 | "cell_type": "code",
439 | "execution_count": null,
440 | "id": "b7030243",
441 | "metadata": {
442 | "hide-output": false
443 | },
444 | "outputs": [],
445 | "source": [
446 | "v_analytical = v_star(ce.x_grid, ce.β, ce.γ)"
447 | ]
448 | },
449 | {
450 | "cell_type": "code",
451 | "execution_count": null,
452 | "id": "3d3f60c8",
453 | "metadata": {
454 | "hide-output": false
455 | },
456 | "outputs": [],
457 | "source": [
458 | "fig, ax = plt.subplots()\n",
459 | "\n",
460 | "ax.plot(x_grid, v_analytical, label='analytical solution')\n",
461 | "ax.plot(x_grid, v_jax, label='numerical solution')\n",
462 | "ax.set_ylabel('$V(x)$', fontsize=12)\n",
463 | "ax.set_xlabel('$x$', fontsize=12)\n",
464 | "ax.legend()\n",
465 | "ax.set_title('Comparison between analytical and numerical value functions')\n",
466 | "plt.show()"
467 | ]
468 | },
469 | {
470 | "cell_type": "markdown",
471 | "id": "d73a7a7c",
472 | "metadata": {},
473 | "source": [
474 | "### Policy Function\n",
475 | "\n",
476 | "Recall that the optimal consumption policy was shown to be\n",
477 | "\n",
478 | "$$\n",
479 | "\\sigma^*(x) = \\left(1-\\beta^{1/\\gamma} \\right) x\n",
480 | "$$\n",
481 | "\n",
482 | "Let’s see if our numerical results lead to something similar.\n",
483 | "\n",
484 | "Our numerical strategy will be to compute\n",
485 | "\n",
486 | "$$\n",
487 | "\\sigma(x) = \\arg \\max_{0 \\leq c \\leq x} \\{u(c) + \\beta v(x - c)\\}\n",
488 | "$$\n",
489 | "\n",
490 | "on a grid of $ x $ points and then interpolate.\n",
491 | "\n",
492 | "For $ v $ we will use the approximation of the value function we obtained\n",
493 | "above.\n",
494 | "\n",
495 | "Here’s the function:"
496 | ]
497 | },
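To check that the argmax strategy recovers the analytical policy, we can feed it the analytical value function from the earlier cake eating lecture, $ v^*(x) = (1 - \beta^{1/\gamma})^{-\gamma} u(x) $, instead of a numerical approximation. This sketch is self-contained, and its parameter values and grids are assumptions: $ \beta = 0.96 $, $ \gamma = 1.5 $, 200 grid points on $ [10^{-3}, 2.5] $, and a consumption grid equal to the state grid. These match `create_cake_eating_model_numba` below, but are not guaranteed to match `create_cake_eating_model`.

```python
import jax
import jax.numpy as jnp

β, γ = 0.96, 1.5                       # assumed parameter values
x_grid = jnp.linspace(1e-3, 2.5, 200)  # assumed state grid
c_grid = x_grid                        # assumed consumption grid (same points)

def u(c):
    # CRRA utility
    return c ** (1 - γ) / (1 - γ)

def v_star(x):
    # Analytical value function
    return (1 - β ** (1 / γ)) ** (-γ) * u(x)

def c_star(x):
    # Analytical optimal policy σ*(x)
    return (1 - β ** (1 / γ)) * x

def rhs(x, c, v_array):
    # u(c) + β v(x - c) for feasible c, -inf otherwise
    return jnp.where(c <= x,
                     u(c) + β * jnp.interp(x - c, x_grid, v_array),
                     -jnp.inf)

rhs_vec = jax.vmap(rhs, in_axes=(0, None, None))

i_cs = jnp.argmax(rhs_vec(x_grid, c_grid, v_star(x_grid)), axis=1)
σ_grid = c_grid[i_cs]
max_err = jnp.max(jnp.abs(σ_grid - c_star(x_grid)))
```

Because consumption is restricted to grid points, the policy is only accurate up to roughly the grid spacing. Comparisons with the numerical policy should therefore be read with that resolution in mind.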
498 | {
499 | "cell_type": "code",
500 | "execution_count": null,
501 | "id": "57c0d4fd",
502 | "metadata": {
503 | "hide-output": false
504 | },
505 | "outputs": [],
506 | "source": [
507 | "@jax.jit\n",
508 | "def σ(ce, v):\n",
509 | " \"\"\"\n",
510 | " The optimal policy function. Given the value function,\n",
511 | " it finds optimal consumption in each state.\n",
512 | "\n",
513 | " * ce: Cake Eating Model instance\n",
514 | " * v: value function array guess, 1-D array\n",
515 | "\n",
516 | " \"\"\"\n",
517 | " i_cs = jnp.argmax(state_action_value_vec(ce.x_grid, ce.c_grid, v, ce), axis=1)\n",
518 | " return ce.c_grid[i_cs]"
519 | ]
520 | },
521 | {
522 | "cell_type": "markdown",
523 | "id": "7bbd3db5",
524 | "metadata": {},
525 | "source": [
526 | "Now let’s pass the approximate value function and compute optimal consumption:"
527 | ]
528 | },
529 | {
530 | "cell_type": "code",
531 | "execution_count": null,
532 | "id": "5930f970",
533 | "metadata": {
534 | "hide-output": false
535 | },
536 | "outputs": [],
537 | "source": [
538 | "c = σ(ce, v_jax)"
539 | ]
540 | },
541 | {
542 | "cell_type": "markdown",
543 | "id": "508f4c0a",
544 | "metadata": {},
545 | "source": [
546 | "Let’s plot this next to the true analytical solution."
547 | ]
548 | },
549 | {
550 | "cell_type": "code",
551 | "execution_count": null,
552 | "id": "0a2aadef",
553 | "metadata": {
554 | "hide-output": false
555 | },
556 | "outputs": [],
557 | "source": [
558 | "c_analytical = c_star(ce.x_grid, ce.β, ce.γ)\n",
559 | "\n",
560 | "fig, ax = plt.subplots()\n",
561 | "\n",
562 | "ax.plot(ce.x_grid, c_analytical, label='analytical')\n",
563 | "ax.plot(ce.x_grid, c, label='numerical')\n",
564 | "ax.set_ylabel(r'$\\sigma(x)$')\n",
565 | "ax.set_xlabel('$x$')\n",
566 | "ax.legend()\n",
567 | "\n",
568 | "plt.show()"
569 | ]
570 | },
571 | {
572 | "cell_type": "markdown",
573 | "id": "4c502de0",
574 | "metadata": {},
575 | "source": [
576 | "## Numba implementation\n",
577 | "\n",
578 | "This section of the lecture is directly adapted from [this lecture](https://python.quantecon.org/cake_eating_numerical.html)\n",
579 | "so that we can compare its results and run time with those of the JAX implementation above."
580 | ]
581 | },
582 | {
583 | "cell_type": "code",
584 | "execution_count": null,
585 | "id": "6a9482c1",
586 | "metadata": {
587 | "hide-output": false
588 | },
589 | "outputs": [],
590 | "source": [
591 | "import numpy as np\n",
592 | "from numba import prange, njit\n",
593 | "from quantecon.optimize import brent_max"
594 | ]
595 | },
596 | {
597 | "cell_type": "code",
598 | "execution_count": null,
599 | "id": "9a80ca6c",
600 | "metadata": {
601 | "hide-output": false
602 | },
603 | "outputs": [],
604 | "source": [
605 | "CEMN = namedtuple('CakeEatingModelNumba',\n",
606 | " ('β', 'γ', 'x_grid'))"
607 | ]
608 | },
609 | {
610 | "cell_type": "code",
611 | "execution_count": null,
612 | "id": "f535dabb",
613 | "metadata": {
614 | "hide-output": false
615 | },
616 | "outputs": [],
617 | "source": [
618 | "def create_cake_eating_model_numba(β=0.96, # discount factor\n",
619 | " γ=1.5, # degree of relative risk aversion\n",
620 | " x_grid_min=1e-3, # exclude zero for numerical stability\n",
621 | " x_grid_max=2.5, # size of cake\n",
622 | " x_grid_size=200):\n",
623 | " x_grid = np.linspace(x_grid_min, x_grid_max, x_grid_size)\n",
624 | " return CEMN(β=β, γ=γ, x_grid=x_grid)"
625 | ]
626 | },
627 | {
628 | "cell_type": "code",
629 | "execution_count": null,
630 | "id": "3687d69c",
631 | "metadata": {
632 | "hide-output": false
633 | },
634 | "outputs": [],
635 | "source": [
636 | "# Utility function\n",
637 | "@njit\n",
638 | "def u_numba(c, cem):\n",
639 | " return (c ** (1 - cem.γ)) / (1 - cem.γ)"
640 | ]
641 | },
642 | {
643 | "cell_type": "code",
644 | "execution_count": null,
645 | "id": "636bc343",
646 | "metadata": {
647 | "hide-output": false
648 | },
649 | "outputs": [],
650 | "source": [
651 | "@njit\n",
652 | "def state_action_value_numba(c, x, v_array, cem):\n",
653 | " \"\"\"\n",
654 | " Right-hand side of the Bellman equation given x and c.\n",
655 | " * x: scalar element `x`\n",
656 | " * c: consumption\n",
657 | " * v_array: value function array guess, 1-D array\n",
658 | " * cem: Cake Eating Numba Model instance\n",
659 | " \"\"\"\n",
660 | " return u_numba(c, cem) + cem.β * np.interp(x - c, cem.x_grid, v_array)"
661 | ]
662 | },
663 | {
664 | "cell_type": "code",
665 | "execution_count": null,
666 | "id": "4c7dac6b",
667 | "metadata": {
668 | "hide-output": false
669 | },
670 | "outputs": [],
671 | "source": [
672 | "@njit\n",
673 | "def T_numba(v, ce):\n",
674 | " \"\"\"\n",
675 | " The Bellman operator. Updates the guess of the value function.\n",
676 | "\n",
677 | " * ce is an instance of CakeEatingNumba Model\n",
678 | " * v is an array representing a guess of the value function\n",
679 | "\n",
680 | " \"\"\"\n",
681 | " v_new = np.empty_like(v)\n",
682 | "\n",
683 | " for i in prange(len(ce.x_grid)):\n",
684 | " # Maximize RHS of Bellman equation at state x\n",
685 | " v_new[i] = brent_max(state_action_value_numba, 1e-10, ce.x_grid[i],\n",
686 | " args=(ce.x_grid[i], v, ce))[1]\n",
687 | " return v_new"
688 | ]
689 | },
690 | {
691 | "cell_type": "code",
692 | "execution_count": null,
693 | "id": "c9a98455",
694 | "metadata": {
695 | "hide-output": false
696 | },
697 | "outputs": [],
698 | "source": [
699 | "def compute_value_function_numba(ce,\n",
700 | " tol=1e-4,\n",
701 | " max_iter=1000,\n",
702 | " verbose=True,\n",
703 | " print_skip=25):\n",
704 | "\n",
705 | " # Set up loop\n",
706 | " v = np.zeros(len(ce.x_grid)) # Initial guess\n",
707 | " i = 0\n",
708 | " error = tol + 1\n",
709 | "\n",
710 | " while i < max_iter and error > tol:\n",
711 | " v_new = T_numba(v, ce)\n",
712 | "\n",
713 | " error = np.max(np.abs(v - v_new))\n",
714 | " i += 1\n",
715 | "\n",
716 | " if verbose and i % print_skip == 0:\n",
717 | " print(f\"Error at iteration {i} is {error}.\")\n",
718 | "\n",
719 | " v = v_new\n",
720 | "\n",
721 | " if error > tol:\n",
722 | " print(\"Failed to converge!\")\n",
723 | " elif verbose:\n",
724 | " print(f\"\\nConverged in {i} iterations.\")\n",
725 | "\n",
726 | " return v_new"
727 | ]
728 | },
729 | {
730 | "cell_type": "code",
731 | "execution_count": null,
732 | "id": "c6512d1a",
733 | "metadata": {
734 | "hide-output": false
735 | },
736 | "outputs": [],
737 | "source": [
738 | "cen = create_cake_eating_model_numba()"
739 | ]
740 | },
741 | {
742 | "cell_type": "code",
743 | "execution_count": null,
744 | "id": "d6ead6a1",
745 | "metadata": {
746 | "hide-output": false
747 | },
748 | "outputs": [],
749 | "source": [
750 | "in_time = time.time()\n",
751 | "v_np = compute_value_function_numba(cen)\n",
752 | "numba_time = time.time() - in_time"
753 | ]
754 | },
755 | {
756 | "cell_type": "code",
757 | "execution_count": null,
758 | "id": "f29046d8",
759 | "metadata": {
760 | "hide-output": false
761 | },
762 | "outputs": [],
763 | "source": [
764 | "ratio = numba_time/jax_time\n",
765 | "print(f\"JAX implementation is {ratio} times faster than Numba.\")\n",
766 | "print(f\"JAX time: {jax_time}\")\n",
767 | "print(f\"Numba time: {numba_time}\")"
768 | ]
769 | }
770 | ],
771 | "metadata": {
772 | "date": 1715241420.1465895,
773 | "filename": "cake_eating_numerical.md",
774 | "kernelspec": {
775 | "display_name": "Python",
776 | "language": "python3",
777 | "name": "python3"
778 | },
779 | "title": "Cake Eating: Numerical Methods"
780 | },
781 | "nbformat": 4,
782 | "nbformat_minor": 5
783 | }
--------------------------------------------------------------------------------
/mle.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "b8bcb944",
6 | "metadata": {},
7 | "source": [
8 | "$$\n",
9 | "\\newcommand{\\argmax}{arg\\,max}\n",
10 | "\\newcommand{\\argmin}{arg\\,min}\n",
11 | "$$"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "9169cefc",
17 | "metadata": {},
18 | "source": [
19 | "# Maximum Likelihood Estimation"
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "id": "0f45b177",
25 | "metadata": {},
26 | "source": [
27 | "# GPU\n",
28 | "\n",
29 | "This lecture was built using a machine with access to a GPU.\n",
30 | "\n",
31 | "[Google Colab](https://colab.research.google.com/) has a free tier with GPUs\n",
32 | "that you can access as follows:\n",
33 | "\n",
34 | "1. Click on the “play” icon top right \n",
35 | "1. Select Colab \n",
36 | "1. Set the runtime environment to include a GPU "
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "id": "a0f7ae40",
42 | "metadata": {},
43 | "source": [
44 | "## Overview\n",
45 | "\n",
46 | "This lecture is the extended JAX implementation of [this section](https://python.quantecon.org/mle.html#mle-with-numerical-methods) of [this lecture](https://python.quantecon.org/mle.html).\n",
47 | "\n",
48 | "Please refer to that lecture for all background and notation.\n",
49 | "\n",
50 | "Here we will exploit the automatic differentiation capabilities of JAX rather than calculating derivatives by hand.\n",
51 | "\n",
52 | "We’ll require the following imports:"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "id": "01722300",
59 | "metadata": {
60 | "hide-output": false
61 | },
62 | "outputs": [],
63 | "source": [
64 | "import matplotlib.pyplot as plt\n",
65 | "from collections import namedtuple\n",
66 | "import jax.numpy as jnp\n",
67 | "import jax\n",
68 | "from statsmodels.api import Poisson"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "id": "20f36900",
74 | "metadata": {},
75 | "source": [
76 | "Let’s check the GPU we are running"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "id": "d99eaec5",
83 | "metadata": {
84 | "hide-output": false
85 | },
86 | "outputs": [],
87 | "source": [
88 | "!nvidia-smi"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "id": "c1bb02cf",
94 | "metadata": {},
95 | "source": [
96 | "We will use 64-bit floats with JAX to increase precision."
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "id": "66583da0",
103 | "metadata": {
104 | "hide-output": false
105 | },
106 | "outputs": [],
107 | "source": [
108 | "jax.config.update(\"jax_enable_x64\", True)"
109 | ]
110 | },
111 | {
112 | "cell_type": "markdown",
113 | "id": "2f2f785b",
114 | "metadata": {},
115 | "source": [
116 | "## MLE with numerical methods (JAX)\n",
117 | "\n",
118 | "For many models, the maximum likelihood estimates have no closed-form solution, so\n",
119 | "numerical methods are required to compute them.\n",
120 | "\n",
121 | "One such numerical method is the Newton-Raphson algorithm.\n",
122 | "\n",
123 | "Let’s start with a simple example to illustrate the algorithm."
124 | ]
125 | },
126 | {
127 | "cell_type": "markdown",
128 | "id": "55084c33",
129 | "metadata": {},
130 | "source": [
131 | "### A toy model\n",
132 | "\n",
133 | "Our goal is to find the maximum likelihood estimate $ \\hat{\\boldsymbol{\\beta}} $.\n",
134 | "\n",
135 | "At $ \\hat{\\boldsymbol{\\beta}} $, the first derivative of the log-likelihood\n",
136 | "function will be equal to 0.\n",
137 | "\n",
138 | "Let’s illustrate this by supposing\n",
139 | "\n",
140 | "$$\n",
141 | "\\log \\mathcal{L(\\beta)} = - (\\beta - 10) ^2 - 10\n",
142 | "$$\n",
143 | "\n",
144 | "Define the function `logL`."
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "id": "43bde8ea",
151 | "metadata": {
152 | "hide-output": false
153 | },
154 | "outputs": [],
155 | "source": [
156 | "@jax.jit\n",
157 | "def logL(β):\n",
158 | " return -(β - 10) ** 2 - 10"
159 | ]
160 | },
161 | {
162 | "cell_type": "markdown",
163 | "id": "66bc32e4",
164 | "metadata": {},
165 | "source": [
166 | "To compute $ \\frac{d \\log \\mathcal{L(\\boldsymbol{\\beta})}}{d \\boldsymbol{\\beta}} $, we use [jax.grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), which automatically differentiates a scalar-valued function.\n",
167 | "\n",
168 | "We then wrap the result in [jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html), which vectorizes it: the derivative, defined for scalar inputs, can now be evaluated elementwise on a vector of $ \\beta $ values."
169 | ]
170 | },
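Because `jax.grad` returns an ordinary function, it can be applied again to obtain the second derivative. This minimal sketch redefines the toy `logL` without `jax.jit` so that it is self-contained. The first derivative $ -2(\beta - 10) $ vanishes at $ \beta = 10 $, and the second derivative is the constant $ -2 $. This confirms that the stationary point is a maximum.

```python
import jax
import jax.numpy as jnp

def logL(β):
    return -(β - 10) ** 2 - 10

dlogL_scalar = jax.grad(logL)        # first derivative, scalar input
d2logL = jax.grad(dlogL_scalar)      # second derivative via nested grad
dlogL_vec = jax.vmap(dlogL_scalar)   # elementwise over a vector of β values

at_max = dlogL_scalar(10.0)
curvature = d2logL(10.0)
slopes = dlogL_vec(jnp.array([8.0, 10.0, 12.0]))
```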
171 | {
172 | "cell_type": "code",
173 | "execution_count": null,
174 | "id": "2ce8e4e8",
175 | "metadata": {
176 | "hide-output": false
177 | },
178 | "outputs": [],
179 | "source": [
180 | "dlogL = jax.vmap(jax.grad(logL))"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "id": "0573fb6f",
187 | "metadata": {
188 | "hide-output": false
189 | },
190 | "outputs": [],
191 | "source": [
192 | "β = jnp.linspace(1, 20)\n",
193 | "\n",
194 | "fig, (ax1, ax2) = plt.subplots(2, sharex=True, figsize=(12, 8))\n",
195 | "\n",
196 | "ax1.plot(β, logL(β), lw=2)\n",
197 | "ax2.plot(β, dlogL(β), lw=2)\n",
198 | "\n",
199 | "ax1.set_ylabel(r'$\\log \\mathcal{L(\\beta)}$',\n",
200 | " rotation=0,\n",
201 | " labelpad=35,\n",
202 | " fontsize=15)\n",
203 | "ax2.set_ylabel(r'$\\frac{d \\log \\mathcal{L(\\beta)}}{d \\beta}$ ',\n",
204 | " rotation=0,\n",
205 | " labelpad=35,\n",
206 | " fontsize=19)\n",
207 | "\n",
208 | "ax2.set_xlabel(r'$\\beta$', fontsize=15)\n",
209 | "ax1.grid(), ax2.grid()\n",
210 | "ax2.axhline(c='black')  # zero line on the derivative plot\n",
211 | "plt.show()"
212 | ]
213 | },
214 | {
215 | "cell_type": "markdown",
216 | "id": "ce9de378",
217 | "metadata": {},
218 | "source": [
219 | "The plot shows that the maximum likelihood value (the top plot) occurs\n",
220 | "when $ \\frac{d \\log \\mathcal{L(\\boldsymbol{\\beta})}}{d \\boldsymbol{\\beta}} = 0 $ (the bottom\n",
221 | "plot).\n",
222 | "\n",
223 | "Therefore, the likelihood is maximized when $ \\beta = 10 $.\n",
224 | "\n",
225 | "We can also ensure that this value is a *maximum* (as opposed to a\n",
226 | "minimum) by checking that the second derivative (slope of the bottom\n",
227 | "plot) is negative.\n",
228 | "\n",
229 | "The Newton-Raphson algorithm finds a point where the first derivative is\n",
230 | "0.\n",
231 | "\n",
232 | "To use the algorithm, we take an initial guess at the maximum value,\n",
233 | "$ \\beta_0 $ (the OLS parameter estimates might be a reasonable\n",
234 | "guess).\n",
235 | "\n",
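236 | "Then we repeatedly apply the updating rule $ \\boldsymbol{\\beta}_{(k+1)} = \\boldsymbol{\\beta}_{(k)} - H^{-1}(\\boldsymbol{\\beta}_{(k)}) G(\\boldsymbol{\\beta}_{(k)}) $, where $ G $ and $ H $ are the gradient and Hessian of the log-likelihood, until the change in $ \\boldsymbol{\\beta} $ is sufficiently small or the algorithm reaches the maximum number of iterations.\n",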
237 | "\n",
238 | "Please refer to [this section](https://python.quantecon.org/mle.html#mle-with-numerical-methods) for the detailed algorithm."
239 | ]
240 | },
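In the scalar case the updating rule reduces to $ \beta_{(k+1)} = \beta_{(k)} - \log \mathcal{L}'(\beta_{(k)}) / \log \mathcal{L}''(\beta_{(k)}) $. Here is a minimal sketch for the toy log-likelihood above, with derivatives supplied by `jax.grad`. The helper name `newton_raphson_scalar` is hypothetical. Because the toy objective is quadratic, the first step lands exactly on $ \beta = 10 $, and the second step only confirms that the estimate has stopped moving.

```python
import jax

def logL(β):
    return -(β - 10) ** 2 - 10

def newton_raphson_scalar(f, β0, tol=1e-8, max_iter=50):
    """Newton-Raphson for a scalar maximization: β <- β - f'(β) / f''(β)."""
    g = jax.grad(f)      # first derivative
    h = jax.grad(g)      # second derivative
    β = β0
    for k in range(max_iter):
        β_new = β - g(β) / h(β)
        if abs(float(β_new - β)) < tol:
            return float(β_new), k + 1
        β = β_new
    return float(β), max_iter

β_hat_toy, n_iter = newton_raphson_scalar(logL, 1.0)
```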
241 | {
242 | "cell_type": "markdown",
243 | "id": "1f2df3c2",
244 | "metadata": {},
245 | "source": [
246 | "### A Poisson model\n",
247 | "\n",
248 | "Let’s have a go at implementing the Newton-Raphson algorithm to calculate the maximum likelihood estimates of a Poisson regression.\n",
249 | "\n",
250 | "The Poisson regression has a joint pmf:\n",
251 | "\n",
252 | "$$\n",
253 | "f(y_1, y_2, \\ldots, y_n \\mid \\mathbf{x}_1, \\mathbf{x}_2, \\ldots, \\mathbf{x}_n; \\boldsymbol{\\beta})\n",
254 | " = \\prod_{i=1}^{n} \\frac{\\mu_i^{y_i}}{y_i!} e^{-\\mu_i}\n",
255 | "$$\n",
256 | "\n",
257 | "$$\n",
258 | "\\text{where}\\ \\mu_i\n",
259 | " = \\exp(\\mathbf{x}_i' \\boldsymbol{\\beta})\n",
260 | " = \\exp(\\beta_0 + \\beta_1 x_{i1} + \\ldots + \\beta_k x_{ik})\n",
261 | "$$\n",
262 | "\n",
263 | "We create a `namedtuple` to store the observed values."
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": null,
269 | "id": "75d72b4e",
270 | "metadata": {
271 | "hide-output": false
272 | },
273 | "outputs": [],
274 | "source": [
275 | "RegressionModel = namedtuple('RegressionModel', ['X', 'y'])\n",
276 | "\n",
277 | "def create_regression_model(X, y):\n",
278 | " n, k = X.shape\n",
279 | " # Reshape y as a n_by_1 column vector\n",
280 | " y = y.reshape(n, 1)\n",
281 | " X, y = jax.device_put((X, y))\n",
282 | " return RegressionModel(X=X, y=y)"
283 | ]
284 | },
285 | {
286 | "cell_type": "markdown",
287 | "id": "b335e441",
288 | "metadata": {},
289 | "source": [
290 | "The maximum likelihood estimate of the Poisson regression solves\n",
291 | "\n",
292 | "$$\n",
293 | "\\underset{\\beta}{\\max} \\Big(\n",
294 | "\\sum_{i=1}^{n} y_i \\log{\\mu_i} -\n",
295 | "\\sum_{i=1}^{n} \\mu_i -\n",
296 | "\\sum_{i=1}^{n} \\log y_i! \\Big)\n",
297 | "$$\n",
298 | "\n",
299 | "The full derivation can be found [here](https://python.quantecon.org/mle.html#id2).\n",
300 | "\n",
301 | "The log-likelihood involves the factorial $ y_i! $, but JAX doesn’t have a readily available function for computing factorials directly.\n",
302 | "\n",
303 | "To compute the factorial in a form that can be JIT-compiled, we use the identity\n",
304 | "\n",
305 | "$$\n",
306 | "n! = e^{\\log(\\Gamma(n+1))}\n",
307 | "$$\n",
308 | "\n",
309 | "since [jax.lax.lgamma](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.lgamma.html) and [jax.lax.exp](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.exp.html) are available.\n",
310 | "\n",
311 | "The following function `jax_factorial` computes the factorial using this idea.\n",
312 | "\n",
313 | "Let’s define this function in Python"
314 | ]
315 | },
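Since the log-likelihood only needs $ \log y_i! $, another option is to skip the exponentiation and use $ \log y! = \log \Gamma(y + 1) $ directly, via `jax.scipy.special.gammaln`. This is a substitute technique, not the approach taken in this lecture. The sketch below shows why it can matter: $ 200! $ overflows floating point, while its logarithm (about 863) does not.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln

y = jnp.array([0, 1, 4, 10, 200])
log_fact = gammaln(y + 1.0)        # log(y!) computed directly
fact_via_exp = jnp.exp(log_fact)   # y! itself, as exp(lgamma(y + 1)) would give
```

With this approach, the term `jnp.log(jax_factorial(y))` in the log-likelihood could be replaced by `gammaln(y + 1.0)`. That term does not depend on $ \boldsymbol{\beta} $, so the gradient and Hessian are unaffected either way.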
316 | {
317 | "cell_type": "code",
318 | "execution_count": null,
319 | "id": "3d4f0ebc",
320 | "metadata": {
321 | "hide-output": false
322 | },
323 | "outputs": [],
324 | "source": [
325 | "@jax.jit\n",
326 | "def _factorial(n):\n",
327 | " return jnp.round(jax.lax.exp(jax.lax.lgamma(n + 1.0))).astype(int)  # round first: exp(lgamma(n + 1)) can fall just below an integer\n",
328 | "\n",
329 | "jax_factorial = jax.vmap(_factorial)"
330 | ]
331 | },
332 | {
333 | "cell_type": "markdown",
334 | "id": "82faad1c",
335 | "metadata": {},
336 | "source": [
337 | "Now we can define the log likelihood function in Python"
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": null,
343 | "id": "75d8f728",
344 | "metadata": {
345 | "hide-output": false
346 | },
347 | "outputs": [],
348 | "source": [
349 | "@jax.jit\n",
350 | "def poisson_logL(β, model):\n",
351 | " y = model.y\n",
352 | " μ = jnp.exp(model.X @ β)\n",
353 | " return jnp.sum(y * jnp.log(μ) - μ - jnp.log(jax_factorial(y)))"
354 | ]
355 | },
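As a sanity check on the formula, the sum $ \sum_i \left( y_i \log \mu_i - \mu_i - \log y_i! \right) $ should equal the sum of Poisson log-pmfs computed by `jax.scipy.stats.poisson.logpmf`. This self-contained sketch evaluates both at the small dataset used below and at $ \boldsymbol{\beta} = (0.1, 0.1, 0.1)' $. It uses `gammaln` for $ \log y_i! $ rather than `jax_factorial`.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import poisson

X = jnp.array([[1., 2., 5.],
               [1., 1., 3.],
               [1., 4., 2.],
               [1., 5., 2.],
               [1., 3., 1.]])
y = jnp.array([1., 0., 1., 1., 0.]).reshape(-1, 1)
β = jnp.full((3, 1), 0.1)

μ = jnp.exp(X @ β)
# Hand-coded log-likelihood, with log(y_i!) = gammaln(y_i + 1)
ll_formula = jnp.sum(y * jnp.log(μ) - μ - gammaln(y + 1.0))
# The same quantity from the library's Poisson log-pmf
ll_library = jnp.sum(poisson.logpmf(y, μ))
```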
356 | {
357 | "cell_type": "markdown",
358 | "id": "b716b54f",
359 | "metadata": {},
360 | "source": [
361 | "To find the gradient of `poisson_logL`, we again use [jax.grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html).\n",
362 | "\n",
363 | "According to [the documentation](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev),\n",
364 | "\n",
365 | "- `jax.jacfwd` uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while \n",
366 | "- `jax.jacrev` uses reverse-mode, which is more efficient for “wide” Jacobian matrices. \n",
367 | "\n",
368 | "\n",
369 | "(The documentation also notes that for near-square matrices, `jax.jacfwd` probably has an edge over `jax.jacrev`.)\n",
370 | "\n",
371 | "Since the Hessian is a square matrix, we compute it by applying `jax.jacfwd` to the gradient function."
372 | ]
373 | },
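Here is a minimal sketch of computing a Hessian by applying `jax.jacfwd` to `jax.grad`, on a concave quadratic whose Hessian is known to be $ -A $. It also shows why `newton_raphson` squeezes the result. When the parameter is stored as a $ k \times 1 $ column vector, as `init_β` is in this lecture, the Hessian comes back with shape `(k, 1, k, 1)`. For comparison, the sketch also calls `jax.hessian`, which the JAX documentation describes as `jacfwd(jacrev(f))`.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 0.5],
               [0.5, 1.0]])

def f(b):
    # Concave quadratic f(b) = -0.5 * b'Ab, with gradient -Ab and Hessian -A
    return -0.5 * jnp.sum(b * (A @ b))

b_vec = jnp.array([1.0, -1.0])
H_vec = jax.jacfwd(jax.grad(f))(b_vec)     # shape (2, 2)
H_hess = jax.hessian(f)(b_vec)             # built-in helper, same result

b_col = b_vec.reshape(2, 1)                # column vector, like β here
H_col = jax.jacfwd(jax.grad(f))(b_col)     # shape (2, 1, 2, 1)
H_squeezed = jnp.squeeze(H_col)            # shape (2, 2)
```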
374 | {
375 | "cell_type": "code",
376 | "execution_count": null,
377 | "id": "cda4dac1",
378 | "metadata": {
379 | "hide-output": false
380 | },
381 | "outputs": [],
382 | "source": [
383 | "G_poisson_logL = jax.grad(poisson_logL)\n",
384 | "H_poisson_logL = jax.jacfwd(G_poisson_logL)"
385 | ]
386 | },
387 | {
388 | "cell_type": "markdown",
389 | "id": "0a6dc97e",
390 | "metadata": {},
391 | "source": [
392 | "Our function `newton_raphson` takes a `RegressionModel` object\n",
393 | "and an initial guess of the parameter vector $ \\boldsymbol{\\beta}_0 $.\n",
394 | "\n",
395 | "The algorithm will update the parameter vector according to the updating\n",
396 | "rule, and recalculate the gradient and Hessian matrices at the new\n",
397 | "parameter estimates."
398 | ]
399 | },
400 | {
401 | "cell_type": "code",
402 | "execution_count": null,
403 | "id": "2bff392c",
404 | "metadata": {
405 | "hide-output": false
406 | },
407 | "outputs": [],
408 | "source": [
409 | "def newton_raphson(model, β, tol=1e-3, max_iter=100, display=True):\n",
410 | "\n",
411 | " i = 0\n",
412 | " error = 100 # Initial error value\n",
413 | "\n",
414 | " # Print header of output\n",
415 | " if display:\n",
416 | " header = f'{\"Iteration_k\":<13}{\"Log-likelihood\":<16}{\"β\":<60}'\n",
417 | " print(header)\n",
418 | " print(\"-\" * len(header))\n",
419 | "\n",
420 | " # While loop runs while any value in error is greater\n",
421 | " # than the tolerance until max iterations are reached\n",
422 | " while jnp.any(error > tol) and i < max_iter:\n",
423 | " H, G = jnp.squeeze(H_poisson_logL(β, model)), G_poisson_logL(β, model)\n",
424 | " β_new = β - jnp.linalg.solve(H, G)  # Newton step; solving is more stable than inverting H\n",
425 | " error = jnp.abs(β_new - β)\n",
426 | " β = β_new\n",
427 | "\n",
428 | " if display:\n",
429 | " β_list = [f'{t:.3}' for t in list(β.flatten())]\n",
430 | " update = f'{i:<13}{poisson_logL(β, model):<16.8}{β_list}'\n",
431 | " print(update)\n",
432 | "\n",
433 | " i += 1\n",
434 | "\n",
435 | " print(f'Number of iterations: {i}')\n",
436 | " print(f'β_hat = {β.flatten()}')\n",
437 | "\n",
438 | " return β"
439 | ]
440 | },
441 | {
442 | "cell_type": "markdown",
443 | "id": "776f9c51",
444 | "metadata": {},
445 | "source": [
446 | "Let’s try out our algorithm with a small dataset of 5 observations and 3\n",
447 | "variables in $ \\mathbf{X} $."
448 | ]
449 | },
450 | {
451 | "cell_type": "code",
452 | "execution_count": null,
453 | "id": "ef3011a5",
454 | "metadata": {
455 | "hide-output": false
456 | },
457 | "outputs": [],
458 | "source": [
459 | "X = jnp.array([[1, 2, 5],\n",
460 | " [1, 1, 3],\n",
461 | " [1, 4, 2],\n",
462 | " [1, 5, 2],\n",
463 | " [1, 3, 1]])\n",
464 | "\n",
465 | "y = jnp.array([1, 0, 1, 1, 0])\n",
466 | "\n",
467 | "# Take a guess at initial βs\n",
468 | "init_β = jnp.array([0.1, 0.1, 0.1]).reshape(X.shape[1], 1)\n",
469 | "\n",
470 | "# Create an object with Poisson model values\n",
471 | "poi = create_regression_model(X, y)\n",
472 | "\n",
473 | "# Use newton_raphson to find the MLE\n",
474 | "β_hat = newton_raphson(poi, init_β, display=True)"
475 | ]
476 | },
477 | {
478 | "cell_type": "markdown",
479 | "id": "8b548f3a",
480 | "metadata": {},
481 | "source": [
482 | "As this was a simple model with few observations, the algorithm achieved\n",
483 | "convergence in only 7 iterations.\n",
484 | "\n",
485 | "The gradient vector should be close to 0 at $ \\hat{\\boldsymbol{\\beta}} $"
486 | ]
487 | },
488 | {
489 | "cell_type": "code",
490 | "execution_count": null,
491 | "id": "f8b213f6",
492 | "metadata": {
493 | "hide-output": false
494 | },
495 | "outputs": [],
496 | "source": [
497 | "G_poisson_logL(β_hat, poi)"
498 | ]
499 | },
500 | {
501 | "cell_type": "markdown",
502 | "id": "c142fd08",
503 | "metadata": {},
504 | "source": [
505 | "## MLE with `statsmodels`\n",
506 | "\n",
507 | "We’ll use the Poisson regression model in `statsmodels` to verify the results\n",
508 | "obtained using JAX.\n",
509 | "\n",
510 | "`statsmodels` uses the same algorithm as above to find the maximum\n",
511 | "likelihood estimates.\n",
512 | "\n",
513 | "Since `statsmodels` works with NumPy arrays rather than JAX arrays, we use the `__array__` method\n",
514 | "of the JAX arrays to convert them to NumPy arrays."
515 | ]
516 | },
517 | {
518 | "cell_type": "code",
519 | "execution_count": null,
520 | "id": "95c4e2f3",
521 | "metadata": {
522 | "hide-output": false
523 | },
524 | "outputs": [],
525 | "source": [
526 | "X_numpy = X.__array__()\n",
527 | "y_numpy = y.__array__()"
528 | ]
529 | },
530 | {
531 | "cell_type": "code",
532 | "execution_count": null,
533 | "id": "329591cc",
534 | "metadata": {
535 | "hide-output": false
536 | },
537 | "outputs": [],
538 | "source": [
539 | "stats_poisson = Poisson(y_numpy, X_numpy).fit()\n",
540 | "print(stats_poisson.summary())"
541 | ]
542 | },
543 | {
544 | "cell_type": "markdown",
545 | "id": "3465be9e",
546 | "metadata": {},
547 | "source": [
548 | "The benefits of writing our own procedure, relative to using `statsmodels`, are that\n",
549 | "\n",
550 | "- we can exploit the power of the GPU and \n",
551 | "- we learn the underlying methodology, which can be extended to complex situations where no existing routines are available. "
552 | ]
553 | },
554 | {
555 | "cell_type": "markdown",
556 | "id": "67fb7de7",
557 | "metadata": {},
558 | "source": [
559 | "## Exercise 21.1\n",
560 | "\n",
561 | "We define a quadratic model for a single explanatory variable by\n",
562 | "\n",
563 | "$$\n",
564 | "\\log(\\lambda_t) = \\beta_0 + \\beta_1 x_t + \\beta_2 x_{t}^2\n",
565 | "$$\n",
566 | "\n",
567 | "We calculate the mean on the original scale instead of the log scale by exponentiating both sides of the above equation, which gives\n",
568 | "\n",
569 | "\n",
570 | "\n",
571 | "$$\n",
572 | "\\lambda_t = \\exp(\\beta_0 + \\beta_1 x_t + \\beta_2 x_{t}^2) \\tag{21.1}\n",
573 | "$$\n",
574 | "\n",
575 | "Simulate the values of $ x_t $ by sampling from a normal distribution and $ \\lambda_t $ by using [(21.1)](#equation-lambda-mle) and the following constants:\n",
576 | "\n",
577 | "$$\n",
578 | "\\beta_0 = -2.5,\n",
579 | " \\quad\n",
580 | " \\beta_1 = 0.25,\n",
581 | " \\quad\n",
582 | " \\beta_2 = 0.5\n",
583 | "$$\n",
584 | "\n",
585 | "Try to recover approximate values of $ \\beta_0,\\beta_1,\\beta_2 $ by simulating a Poisson regression model such that\n",
586 | "\n",
587 | "$$\n",
588 | "y_t \\sim {\\rm Poisson}(\\lambda_t)\n",
589 | " \\quad \\text{for all } t.\n",
590 | "$$\n",
591 | "\n",
592 | "Using our `newton_raphson` function on the data set $ X = [1, x_t, x_t^{2}] $ and\n",
593 | "$ y $, obtain the maximum likelihood estimates of $ \\beta_0,\\beta_1,\\beta_2 $.\n",
594 | "\n",
595 | "With a sufficient large sample size, you should approximately\n",
596 | "recover the true values of of these parameters."
597 | ]
598 | },
599 | {
600 | "cell_type": "markdown",
601 | "id": "4c847896",
602 | "metadata": {},
603 | "source": [
604 | "## Solution\n",
605 | "\n",
606 | "Let’s start by defining “true” parameter values."
607 | ]
608 | },
609 | {
610 | "cell_type": "code",
611 | "execution_count": null,
612 | "id": "c1d1de5a",
613 | "metadata": {
614 | "hide-output": false
615 | },
616 | "outputs": [],
617 | "source": [
618 | "β_0 = -2.5\n",
619 | "β_1 = 0.25\n",
620 | "β_2 = 0.5"
621 | ]
622 | },
623 | {
624 | "cell_type": "markdown",
625 | "id": "ef634b5b",
626 | "metadata": {},
627 | "source": [
628 | "To simulate the model, we sample 500,000 values of $ x_t $ from the standard normal distribution."
629 | ]
630 | },
631 | {
632 | "cell_type": "code",
633 | "execution_count": null,
634 | "id": "ef1e2ef4",
635 | "metadata": {
636 | "hide-output": false
637 | },
638 | "outputs": [],
639 | "source": [
640 | "seed = 32\n",
641 | "shape = (500_000, 1)\n",
642 | "key = jax.random.PRNGKey(seed)\n",
643 | "x = jax.random.normal(key, shape)"
644 | ]
645 | },
646 | {
647 | "cell_type": "markdown",
648 | "id": "e3ba8ba8",
649 | "metadata": {},
650 | "source": [
651 | "We compute $ \\lambda $ using [(21.1)](#equation-lambda-mle)"
652 | ]
653 | },
654 | {
655 | "cell_type": "code",
656 | "execution_count": null,
657 | "id": "27b50f21",
658 | "metadata": {
659 | "hide-output": false
660 | },
661 | "outputs": [],
662 | "source": [
663 | "λ = jnp.exp(β_0 + β_1 * x + β_2 * x**2)"
664 | ]
665 | },
666 | {
667 | "cell_type": "markdown",
668 | "id": "4a030085",
669 | "metadata": {},
670 | "source": [
671 | "Let’s define $ y_t $ by sampling from a Poisson distribution with mean $ \\lambda_t $."
672 | ]
673 | },
674 | {
675 | "cell_type": "code",
676 | "execution_count": null,
677 | "id": "f1a1b0be",
678 | "metadata": {
679 | "hide-output": false
680 | },
681 | "outputs": [],
682 | "source": [
683 | "y = jax.random.poisson(jax.random.PRNGKey(seed + 1), λ, shape)  # separate key, so y does not reuse the key used to draw x"
684 | ]
685 | },
686 | {
687 | "cell_type": "markdown",
688 | "id": "b9662db5",
689 | "metadata": {},
690 | "source": [
691 | "Now let’s try to recover the true parameter values using the Newton-Raphson\n",
692 | "method described above."
693 | ]
694 | },
695 | {
696 | "cell_type": "code",
697 | "execution_count": null,
698 | "id": "ab7fa8be",
699 | "metadata": {
700 | "hide-output": false
701 | },
702 | "outputs": [],
703 | "source": [
704 | "X = jnp.hstack((jnp.ones(shape), x, x**2))\n",
705 | "\n",
706 | "# Take a guess at initial βs\n",
707 | "init_β = jnp.array([0.1, 0.1, 0.1]).reshape(X.shape[1], 1)\n",
708 | "\n",
709 | "# Create an object with Poisson model values\n",
710 | "poi = create_regression_model(X, y)\n",
711 | "\n",
712 | "# Use newton_raphson to find the MLE\n",
713 | "β_hat = newton_raphson(poi, init_β, tol=1e-5, display=True)"
714 | ]
715 | },
716 | {
717 | "cell_type": "markdown",
718 | "id": "22da66ba",
719 | "metadata": {},
720 | "source": [
721 | "The maximum likelihood estimates are similar to the true parameter values."
722 | ]
723 | }
724 | ],
725 | "metadata": {
726 | "date": 1765244755.559107,
727 | "filename": "mle.md",
728 | "kernelspec": {
729 | "display_name": "Python",
730 | "language": "python3",
731 | "name": "python3"
732 | },
733 | "title": "Maximum Likelihood Estimation"
734 | },
735 | "nbformat": 4,
736 | "nbformat_minor": 5
737 | }
--------------------------------------------------------------------------------
/opt_savings.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "1301ffe6",
6 | "metadata": {},
7 | "source": [
8 | "# Optimal Savings"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "40dd4c48",
14 | "metadata": {},
15 | "source": [
16 | "# GPU\n",
17 | "\n",
18 | "This lecture was built using [hardware](https://jax.quantecon.org/status.html#status-machine-details) that has access to a GPU.\n",
19 | "\n",
20 | "To run this lecture on [Google Colab](https://colab.research.google.com/), click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.\n",
21 | "\n",
22 | "To run this lecture on your own machine, you need to install [Google JAX](https://github.com/google/jax).\n",
23 | "\n",
24 | "In addition to what’s in Anaconda, this lecture will need the following libraries:"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "id": "273d7940",
31 | "metadata": {
32 | "hide-output": false
33 | },
34 | "outputs": [],
35 | "source": [
36 | "!pip install quantecon"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "id": "4481b2aa",
42 | "metadata": {},
43 | "source": [
44 | "We will use the following imports:"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "id": "6aa3677a",
51 | "metadata": {
52 | "hide-output": false
53 | },
54 | "outputs": [],
55 | "source": [
56 | "import quantecon as qe\n",
57 | "import jax\n",
58 | "import jax.numpy as jnp\n",
59 | "from collections import namedtuple\n",
60 | "import matplotlib.pyplot as plt\n",
61 | "import time"
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "id": "41b21f55",
67 | "metadata": {},
68 | "source": [
69 | "Let’s check which GPU we are running on:"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "id": "0f1b1199",
76 | "metadata": {
77 | "hide-output": false
78 | },
79 | "outputs": [],
80 | "source": [
81 | "!nvidia-smi"
82 | ]
83 | },
84 | {
85 | "cell_type": "markdown",
86 | "id": "ebc1c96f",
87 | "metadata": {},
88 | "source": [
89 | "We use 64-bit floats with JAX in order to match NumPy code:\n",
90 | "\n",
91 | "- By default, JAX uses 32-bit datatypes. \n",
92 | "- By default, NumPy uses 64-bit datatypes. "
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "id": "e85c00da",
99 | "metadata": {
100 | "hide-output": false
101 | },
102 | "outputs": [],
103 | "source": [
104 | "jax.config.update(\"jax_enable_x64\", True)"
105 | ]
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "id": "6a09799f",
110 | "metadata": {},
111 | "source": [
112 | "## Overview\n",
113 | "\n",
114 | "We consider an optimal savings problem with CRRA utility and budget constraint\n",
115 | "\n",
116 | "$$\n",
117 | "W_{t+1} + C_t \\leq R W_t + Y_t\n",
118 | "$$\n",
119 | "\n",
120 | "We assume that labor income $ (Y_t) $ is a discretized AR(1) process.\n",
121 | "\n",
122 | "The right-hand side of the Bellman equation is\n",
123 | "\n",
124 | "$$\n",
125 | "B((w, y), w', v) = u(Rw + y - w') + \\beta \\sum_{y'} v(w', y') Q(y, y'),\n",
126 | "$$\n",
127 | "\n",
128 | "where\n",
129 | "\n",
130 | "$$\n",
131 | "u(c) = \\frac{c^{1-\\gamma}}{1-\\gamma}\n",
132 | "$$"
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "id": "dd59a12d",
138 | "metadata": {},
139 | "source": [
140 | "## Model primitives\n",
141 | "\n",
142 | "First we define a function that creates a model, storing parameters and grids:"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "id": "fc4a0026",
149 | "metadata": {
150 | "hide-output": false
151 | },
152 | "outputs": [],
153 | "source": [
154 | "def create_consumption_model(R=1.01, # Gross interest rate\n",
155 | " β=0.98, # Discount factor\n",
156 | " γ=2, # CRRA parameter\n",
157 | " w_min=0.01, # Min wealth\n",
158 | " w_max=5.0, # Max wealth\n",
159 | " w_size=150, # Grid size\n",
160 | " ρ=0.9, ν=0.1, y_size=100): # Income parameters\n",
161 | " \"\"\"\n",
162 | " A function that takes in parameters and returns parameters and grids \n",
163 | " for the optimal savings problem.\n",
164 | " \"\"\"\n",
165 | " w_grid = jnp.linspace(w_min, w_max, w_size)\n",
166 | " mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)\n",
167 | " y_grid, Q = jnp.exp(mc.state_values), mc.P\n",
168 | " β, R, γ = jax.device_put([β, R, γ])\n",
169 | " w_grid, y_grid, Q = tuple(map(jax.device_put, [w_grid, y_grid, Q]))\n",
170 | " sizes = w_size, y_size\n",
171 | " return (β, R, γ), sizes, (w_grid, y_grid, Q)"
172 | ]
173 | },
174 | {
175 | "cell_type": "markdown",
176 | "id": "d0627bb2",
177 | "metadata": {},
178 | "source": [
179 | "Here’s the right-hand side of the Bellman equation:"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "id": "79caa560",
186 | "metadata": {
187 | "hide-output": false
188 | },
189 | "outputs": [],
190 | "source": [
191 | "def B(v, constants, sizes, arrays):\n",
192 | " \"\"\"\n",
193 | " A vectorized version of the right-hand side of the Bellman equation\n",
194 | " (before maximization), which is a 3D array representing\n",
195 | "\n",
196 | " B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′)\n",
197 | "\n",
198 | " for all (w, y, w′).\n",
199 | " \"\"\"\n",
200 | "\n",
201 | " # Unpack\n",
202 | " β, R, γ = constants\n",
203 | " w_size, y_size = sizes\n",
204 | " w_grid, y_grid, Q = arrays\n",
205 | "\n",
206 | " # Compute consumption c = Rw + y - wp at all (w, y, wp) as array c[i, j, ip]\n",
207 | " w = jnp.reshape(w_grid, (w_size, 1, 1)) # w[i] -> w[i, j, ip]\n",
208 | " y = jnp.reshape(y_grid, (1, y_size, 1)) # y[j] -> y[i, j, ip]\n",
209 | " wp = jnp.reshape(w_grid, (1, 1, w_size)) # wp[ip] -> wp[i, j, ip]\n",
210 | " c = R * w + y - wp\n",
211 | "\n",
212 | " # Calculate continuation rewards at all combinations of (w, y, wp)\n",
213 | " v = jnp.reshape(v, (1, 1, w_size, y_size)) # v[ip, jp] -> v[i, j, ip, jp]\n",
214 | " Q = jnp.reshape(Q, (1, y_size, 1, y_size)) # Q[j, jp] -> Q[i, j, ip, jp]\n",
215 | " EV = jnp.sum(v * Q, axis=3) # sum over last index jp\n",
216 | "\n",
217 | " # Compute the right-hand side of the Bellman equation\n",
218 | " return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)"
219 | ]
220 | },
221 | {
222 | "cell_type": "markdown",
223 | "id": "8925d0b3",
224 | "metadata": {},
225 | "source": [
226 | "## Operators\n",
227 | "\n",
228 | "We define a function to compute the current rewards $ r_\\sigma $ given policy $ \\sigma $,\n",
229 | "which is defined as the vector\n",
230 | "\n",
231 | "$$\n",
232 | "r_\\sigma(w, y) := r(w, y, \\sigma(w, y)), \\quad \\text{where} \\quad r(w, y, w') := u(Rw + y - w')\n",
233 | "$$"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "execution_count": null,
239 | "id": "7296a174",
240 | "metadata": {
241 | "hide-output": false
242 | },
243 | "outputs": [],
244 | "source": [
245 | "def compute_r_σ(σ, constants, sizes, arrays):\n",
246 | " \"\"\"\n",
247 | " Compute the array r_σ[i, j] = r[i, j, σ[i, j]], which gives current\n",
248 | " rewards given policy σ.\n",
249 | " \"\"\"\n",
250 | "\n",
251 | " # Unpack model\n",
252 | " β, R, γ = constants\n",
253 | " w_size, y_size = sizes\n",
254 | " w_grid, y_grid, Q = arrays\n",
255 | "\n",
256 | " # Compute r_σ[i, j]\n",
257 | " w = jnp.reshape(w_grid, (w_size, 1))\n",
258 | " y = jnp.reshape(y_grid, (1, y_size))\n",
259 | " wp = w_grid[σ]\n",
260 | " c = R * w + y - wp\n",
261 | " r_σ = c**(1-γ)/(1-γ)\n",
262 | "\n",
263 | " return r_σ"
264 | ]
265 | },
266 | {
267 | "cell_type": "markdown",
268 | "id": "dbaa052f",
269 | "metadata": {},
270 | "source": [
271 | "Now we define the policy operator $ T_\\sigma $"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": null,
277 | "id": "e4bc60a1",
278 | "metadata": {
279 | "hide-output": false
280 | },
281 | "outputs": [],
282 | "source": [
283 | "def T_σ(v, σ, constants, sizes, arrays):\n",
284 | " \"The σ-policy operator.\"\n",
285 | "\n",
286 | " # Unpack model\n",
287 | " β, R, γ = constants\n",
288 | " w_size, y_size = sizes\n",
289 | " w_grid, y_grid, Q = arrays\n",
290 | "\n",
291 | " r_σ = compute_r_σ(σ, constants, sizes, arrays)\n",
292 | "\n",
293 | " # Compute the array v[σ[i, j], jp]\n",
294 | " yp_idx = jnp.arange(y_size)\n",
295 | " yp_idx = jnp.reshape(yp_idx, (1, 1, y_size))\n",
296 | " σ = jnp.reshape(σ, (w_size, y_size, 1))\n",
297 | " V = v[σ, yp_idx]\n",
298 | "\n",
299 | " # Convert Q[j, jp] to Q[i, j, jp]\n",
300 | " Q = jnp.reshape(Q, (1, y_size, y_size))\n",
301 | "\n",
302 | " # Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp]\n",
303 | " EV = jnp.sum(V * Q, axis=2)\n",
304 | "\n",
305 | " return r_σ + β * EV"
306 | ]
307 | },
308 | {
309 | "cell_type": "markdown",
310 | "id": "e5f4deb0",
311 | "metadata": {},
312 | "source": [
313 | "and the Bellman operator $ T $"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": null,
319 | "id": "04eddd8f",
320 | "metadata": {
321 | "hide-output": false
322 | },
323 | "outputs": [],
324 | "source": [
325 | "def T(v, constants, sizes, arrays):\n",
326 | " \"The Bellman operator.\"\n",
327 | " return jnp.max(B(v, constants, sizes, arrays), axis=2)"
328 | ]
329 | },
330 | {
331 | "cell_type": "markdown",
332 | "id": "915332d7",
333 | "metadata": {},
334 | "source": [
335 | "The next function computes a $ v $-greedy policy given $ v $"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": null,
341 | "id": "6a013b75",
342 | "metadata": {
343 | "hide-output": false
344 | },
345 | "outputs": [],
346 | "source": [
347 | "def get_greedy(v, constants, sizes, arrays):\n",
348 | " \"Computes a v-greedy policy, returned as a set of indices.\"\n",
349 | " return jnp.argmax(B(v, constants, sizes, arrays), axis=2)"
350 | ]
351 | },
352 | {
353 | "cell_type": "markdown",
354 | "id": "83f6f249",
355 | "metadata": {},
356 | "source": [
357 | "The function below computes the value $ v_\\sigma $ of following policy $ \\sigma $.\n",
358 | "\n",
359 | "This lifetime value is a function $ v_\\sigma $ that satisfies\n",
360 | "\n",
361 | "$$\n",
362 | "v_\\sigma(w, y) = r_\\sigma(w, y) + \\beta \\sum_{y'} v_\\sigma(\\sigma(w, y), y') Q(y, y')\n",
363 | "$$\n",
364 | "\n",
365 | "We wish to solve this equation for $ v_\\sigma $.\n",
366 | "\n",
367 | "Suppose we define the linear operator $ L_\\sigma $ by\n",
368 | "\n",
369 | "$$\n",
370 | "(L_\\sigma v)(w, y) = v(w, y) - \\beta \\sum_{y'} v(\\sigma(w, y), y') Q(y, y')\n",
371 | "$$\n",
372 | "\n",
373 | "With this notation, the problem is to solve for $ v $ via\n",
374 | "\n",
375 | "$$\n",
376 | "(L_{\\sigma} v)(w, y) = r_\\sigma(w, y)\n",
377 | "$$\n",
378 | "\n",
379 | "In vector form, this is $ L_\\sigma v = r_\\sigma $, which tells us that the function\n",
380 | "we seek is\n",
381 | "\n",
382 | "$$\n",
383 | "v_\\sigma = L_\\sigma^{-1} r_\\sigma\n",
384 | "$$\n",
385 | "\n",
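386 | "JAX's iterative solvers, such as `jax.scipy.sparse.linalg.bicgstab`, allow us to solve linear systems defined by a function (an operator) rather than an explicit matrix; the first\n",
387 | "step is to define the function $ L_{\\sigma} $."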
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": null,
393 | "id": "1a5b6b90",
394 | "metadata": {
395 | "hide-output": false
396 | },
397 | "outputs": [],
398 | "source": [
399 | "def L_σ(v, σ, constants, sizes, arrays):\n",
400 | " \"\"\"\n",
401 | " Here we set up the linear map v -> L_σ v, where \n",
402 | "\n",
403 | " (L_σ v)(w, y) = v(w, y) - β Σ_y′ v(σ(w, y), y′) Q(y, y′)\n",
404 | "\n",
405 | " \"\"\"\n",
406 | "\n",
407 | " β, R, γ = constants\n",
408 | " w_size, y_size = sizes\n",
409 | " w_grid, y_grid, Q = arrays\n",
410 | "\n",
411 | " # Set up the array v[σ[i, j], jp]\n",
412 | " zp_idx = jnp.arange(y_size)\n",
413 | " zp_idx = jnp.reshape(zp_idx, (1, 1, y_size))\n",
414 | " σ = jnp.reshape(σ, (w_size, y_size, 1))\n",
415 | " V = v[σ, zp_idx]\n",
416 | "\n",
417 | " # Expand Q[j, jp] to Q[i, j, jp]\n",
418 | " Q = jnp.reshape(Q, (1, y_size, y_size))\n",
419 | "\n",
420 | " # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp]\n",
421 | " return v - β * jnp.sum(V * Q, axis=2)"
422 | ]
423 | },
424 | {
425 | "cell_type": "markdown",
426 | "id": "236e0462",
427 | "metadata": {},
428 | "source": [
429 | "Now we can define a function to compute $ v_{\\sigma} $"
430 | ]
431 | },
432 | {
433 | "cell_type": "code",
434 | "execution_count": null,
435 | "id": "ba95b4e9",
436 | "metadata": {
437 | "hide-output": false
438 | },
439 | "outputs": [],
440 | "source": [
441 | "def get_value(σ, constants, sizes, arrays):\n",
442 | " \"Get the value v_σ of policy σ by inverting the linear map L_σ.\"\n",
443 | "\n",
444 | " # Unpack\n",
445 | " β, R, γ = constants\n",
446 | " w_size, y_size = sizes\n",
447 | " w_grid, y_grid, Q = arrays\n",
448 | "\n",
449 | " r_σ = compute_r_σ(σ, constants, sizes, arrays)\n",
450 | "\n",
451 | " # Reduce L_σ to a function in v\n",
452 | " partial_L_σ = lambda v: L_σ(v, σ, constants, sizes, arrays)\n",
453 | "\n",
454 | " return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0]"
455 | ]
456 | },
457 | {
458 | "cell_type": "markdown",
459 | "id": "03eb7d54",
460 | "metadata": {},
461 | "source": [
462 | "## JIT compiled versions"
463 | ]
464 | },
465 | {
466 | "cell_type": "code",
467 | "execution_count": null,
468 | "id": "3e719d81",
469 | "metadata": {
470 | "hide-output": false
471 | },
472 | "outputs": [],
473 | "source": [
474 | "B = jax.jit(B, static_argnums=(2,))\n",
475 | "compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,))\n",
476 | "T = jax.jit(T, static_argnums=(2,))\n",
477 | "get_greedy = jax.jit(get_greedy, static_argnums=(2,))\n",
478 | "get_value = jax.jit(get_value, static_argnums=(2,))\n",
479 | "T_σ = jax.jit(T_σ, static_argnums=(3,))\n",
480 | "L_σ = jax.jit(L_σ, static_argnums=(3,))"
481 | ]
482 | },
483 | {
484 | "cell_type": "markdown",
485 | "id": "990192bb",
486 | "metadata": {},
487 | "source": [
488 | "We use successive approximation for VFI."
489 | ]
490 | },
491 | {
492 | "cell_type": "code",
493 | "execution_count": null,
494 | "id": "390db1df",
495 | "metadata": {
496 | "hide-output": false
497 | },
498 | "outputs": [],
499 | "source": [
500 | "def successive_approx_jax(x_0, # Initial condition\n",
501 | " constants,\n",
502 | " sizes,\n",
503 | " arrays, \n",
504 | " tolerance=1e-6, # Error tolerance\n",
505 | " max_iter=10_000): # Max iteration bound\n",
506 | "\n",
507 | " def body_fun(k_x_err):\n",
508 | " k, x, error = k_x_err\n",
509 | " x_new = T(x, constants, sizes, arrays)\n",
510 | " error = jnp.max(jnp.abs(x_new - x))\n",
511 | " return k + 1, x_new, error\n",
512 | "\n",
513 | " def cond_fun(k_x_err):\n",
514 | " k, x, error = k_x_err\n",
515 | " return jnp.logical_and(error > tolerance, k < max_iter)\n",
516 | "\n",
517 | " k, x, error = jax.lax.while_loop(cond_fun, body_fun, (1, x_0, tolerance + 1))\n",
518 | " return x\n",
519 | "\n",
520 | "successive_approx_jax = jax.jit(successive_approx_jax, static_argnums=(2,))"
521 | ]
522 | },
523 | {
524 | "cell_type": "markdown",
525 | "id": "282c4611",
526 | "metadata": {},
527 | "source": [
528 | "## Solvers\n",
529 | "\n",
530 | "Now we define the solvers, which implement VFI, HPI and OPI."
531 | ]
532 | },
533 | {
534 | "cell_type": "code",
535 | "execution_count": null,
536 | "id": "06b01a4a",
537 | "metadata": {
538 | "hide-output": false
539 | },
540 | "outputs": [],
541 | "source": [
542 | "# Implements VFI (value function iteration)\n",
543 | "\n",
544 | "def value_iteration(model, tol=1e-5):\n",
545 | " constants, sizes, arrays = model\n",
546 | " vz = jnp.zeros(sizes)\n",
547 | "\n",
548 | " v_star = successive_approx_jax(vz, constants, sizes, arrays, tolerance=tol)\n",
549 | " return get_greedy(v_star, constants, sizes, arrays)"
550 | ]
551 | },
552 | {
553 | "cell_type": "code",
554 | "execution_count": null,
555 | "id": "44d50bee",
556 | "metadata": {
557 | "hide-output": false
558 | },
559 | "outputs": [],
560 | "source": [
561 | "# Implements HPI (Howard policy iteration)\n",
562 | "\n",
563 | "def policy_iteration(model, maxiter=250):\n",
564 | " constants, sizes, arrays = model\n",
565 | " σ = jnp.zeros(sizes, dtype=int)\n",
566 | " i, error = 0, 1.0\n",
567 | " while error > 0 and i < maxiter:\n",
568 | " v_σ = get_value(σ, constants, sizes, arrays)\n",
569 | " σ_new = get_greedy(v_σ, constants, sizes, arrays)\n",
570 | " error = jnp.max(jnp.abs(σ_new - σ))\n",
571 | " σ = σ_new\n",
572 | " i = i + 1\n",
573 | " print(f\"Concluded loop {i} with error {error}.\")\n",
574 | " return σ"
575 | ]
576 | },
577 | {
578 | "cell_type": "code",
579 | "execution_count": null,
580 | "id": "edfc6a7f",
581 | "metadata": {
582 | "hide-output": false
583 | },
584 | "outputs": [],
585 | "source": [
586 | "# Implements OPI (optimistic policy iteration)\n",
587 | "\n",
588 | "def optimistic_policy_iteration(model, tol=1e-5, m=10):\n",
589 | " constants, sizes, arrays = model\n",
590 | " v = jnp.zeros(sizes)\n",
591 | " error = tol + 1\n",
592 | " while error > tol:\n",
593 | " last_v = v\n",
594 | " σ = get_greedy(v, constants, sizes, arrays)\n",
595 | " for _ in range(m):\n",
596 | " v = T_σ(v, σ, constants, sizes, arrays)\n",
597 | " error = jnp.max(jnp.abs(v - last_v))\n",
598 | " return get_greedy(v, constants, sizes, arrays)"
599 | ]
600 | },
601 | {
602 | "cell_type": "markdown",
603 | "id": "9470e3a4",
604 | "metadata": {},
605 | "source": [
606 | "## Plots\n",
607 | "\n",
608 | "Create a model for consumption, perform policy iteration, and plot the resulting optimal policy function."
609 | ]
610 | },
611 | {
612 | "cell_type": "code",
613 | "execution_count": null,
614 | "id": "78468634",
615 | "metadata": {
616 | "hide-output": false
617 | },
618 | "outputs": [],
619 | "source": [
620 | "fontsize = 12\n",
621 | "model = create_consumption_model()\n",
622 | "# Unpack\n",
623 | "constants, sizes, arrays = model\n",
624 | "β, R, γ = constants\n",
625 | "w_size, y_size = sizes\n",
626 | "w_grid, y_grid, Q = arrays"
627 | ]
628 | },
629 | {
630 | "cell_type": "code",
631 | "execution_count": null,
632 | "id": "8fa58eb9",
633 | "metadata": {
634 | "hide-output": false
635 | },
636 | "outputs": [],
637 | "source": [
638 | "σ_star = policy_iteration(model)\n",
639 | "\n",
640 | "fig, ax = plt.subplots(figsize=(9, 5.2))\n",
641 | "ax.plot(w_grid, w_grid, \"k--\", label=\"45\")\n",
642 | "ax.plot(w_grid, w_grid[σ_star[:, 0]], label=r\"$\\sigma^*(\\cdot, y_1)$\")\n",
643 | "ax.plot(w_grid, w_grid[σ_star[:, -1]], label=r\"$\\sigma^*(\\cdot, y_N)$\")\n",
644 | "ax.legend(fontsize=fontsize)\n",
645 | "plt.show()"
646 | ]
647 | },
648 | {
649 | "cell_type": "markdown",
650 | "id": "1d9c9eee",
651 | "metadata": {},
652 | "source": [
653 | "## Tests\n",
654 | "\n",
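655 | "Here’s a quick test of the timing of each solver. Note that the first call to a jitted function includes compilation time, so these numbers are only rough."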
656 | ]
657 | },
658 | {
659 | "cell_type": "code",
660 | "execution_count": null,
661 | "id": "6ef2af4a",
662 | "metadata": {
663 | "hide-output": false
664 | },
665 | "outputs": [],
666 | "source": [
667 | "model = create_consumption_model()"
668 | ]
669 | },
670 | {
671 | "cell_type": "code",
672 | "execution_count": null,
673 | "id": "26c4fa88",
674 | "metadata": {
675 | "hide-output": false
676 | },
677 | "outputs": [],
678 | "source": [
679 | "print(\"Starting HPI.\")\n",
680 | "start_time = time.time()\n",
681 | "out = policy_iteration(model).block_until_ready()  # wait for asynchronous GPU work to finish\n",
682 | "elapsed = time.time() - start_time\n",
683 | "print(f\"HPI completed in {elapsed} seconds.\")"
684 | ]
685 | },
686 | {
687 | "cell_type": "code",
688 | "execution_count": null,
689 | "id": "b6d920e0",
690 | "metadata": {
691 | "hide-output": false
692 | },
693 | "outputs": [],
694 | "source": [
695 | "print(\"Starting VFI.\")\n",
696 | "start_time = time.time()\n",
697 | "out = value_iteration(model).block_until_ready()  # wait for asynchronous GPU work to finish\n",
698 | "elapsed = time.time() - start_time\n",
699 | "print(f\"VFI completed in {elapsed} seconds.\")"
700 | ]
701 | },
702 | {
703 | "cell_type": "code",
704 | "execution_count": null,
705 | "id": "a9d48feb",
706 | "metadata": {
707 | "hide-output": false
708 | },
709 | "outputs": [],
710 | "source": [
711 | "print(\"Starting OPI.\")\n",
712 | "start_time = time.time()\n",
713 | "out = optimistic_policy_iteration(model, m=100).block_until_ready()  # wait for asynchronous GPU work to finish\n",
714 | "elapsed = time.time() - start_time\n",
715 | "print(f\"OPI completed in {elapsed} seconds.\")"
716 | ]
717 | },
718 | {
719 | "cell_type": "code",
720 | "execution_count": null,
721 | "id": "26273b26",
722 | "metadata": {
723 | "hide-output": false
724 | },
725 | "outputs": [],
726 | "source": [
727 | "def run_algorithm(algorithm, model, **kwargs):\n",
728 | " start_time = time.time()\n",
729 | " result = algorithm(model, **kwargs).block_until_ready()  # wait for asynchronous GPU work to finish\n",
730 | " end_time = time.time()\n",
731 | " elapsed_time = end_time - start_time\n",
732 | " print(f\"{algorithm.__name__} completed in {elapsed_time:.2f} seconds.\")\n",
733 | " return result, elapsed_time"
734 | ]
735 | },
736 | {
737 | "cell_type": "code",
738 | "execution_count": null,
739 | "id": "f9b2ea24",
740 | "metadata": {
741 | "hide-output": false
742 | },
743 | "outputs": [],
744 | "source": [
745 | "model = create_consumption_model()\n",
746 | "σ_pi, pi_time = run_algorithm(policy_iteration, model)\n",
747 | "σ_vfi, vfi_time = run_algorithm(value_iteration, model, tol=1e-5)\n",
748 | "\n",
749 | "m_vals = range(5, 600, 40)\n",
750 | "opi_times = []\n",
751 | "for m in m_vals:\n",
752 | " σ_opi, opi_time = run_algorithm(optimistic_policy_iteration, model, m=m, tol=1e-5)\n",
753 | " opi_times.append(opi_time)"
754 | ]
755 | },
756 | {
757 | "cell_type": "code",
758 | "execution_count": null,
759 | "id": "cdb08e56",
760 | "metadata": {
761 | "hide-output": false
762 | },
763 | "outputs": [],
764 | "source": [
765 | "fig, ax = plt.subplots(figsize=(9, 5.2))\n",
766 | "ax.plot(m_vals, jnp.full(len(m_vals), pi_time), lw=2, label=\"Howard policy iteration\")\n",
767 | "ax.plot(m_vals, jnp.full(len(m_vals), vfi_time), lw=2, label=\"value function iteration\")\n",
768 | "ax.plot(m_vals, opi_times, lw=2, label=\"optimistic policy iteration\")\n",
769 | "ax.legend(fontsize=fontsize, frameon=False)\n",
770 | "ax.set_xlabel(\"$m$\", fontsize=fontsize)\n",
771 | "ax.set_ylabel(\"time (seconds)\", fontsize=fontsize)\n",
772 | "plt.show()"
773 | ]
774 | }
775 | ],
776 | "metadata": {
777 | "date": 1709777244.168797,
778 | "filename": "opt_savings.md",
779 | "kernelspec": {
780 | "display_name": "Python",
781 | "language": "python3",
782 | "name": "python3"
783 | },
784 | "title": "Optimal Savings"
785 | },
786 | "nbformat": 4,
787 | "nbformat_minor": 5
788 | }
--------------------------------------------------------------------------------
/autodiff.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "aec1bc28",
6 | "metadata": {},
7 | "source": [
8 | "$$\n",
9 | "\\newcommand{\\argmax}{arg\\,max}\n",
10 | "\\newcommand{\\argmin}{arg\\,min}\n",
11 | "$$"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "74b5ce9b",
17 | "metadata": {},
18 | "source": [
19 | "# Adventures with Autodiff"
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "id": "c6eb2d83",
25 | "metadata": {},
26 | "source": [
27 | "# GPU\n",
28 | "\n",
29 | "This lecture was built using a machine with access to a GPU.\n",
30 | "\n",
31 | "[Google Colab](https://colab.research.google.com/) has a free tier with GPUs\n",
32 | "that you can access as follows:\n",
33 | "\n",
34 | "1. Click on the “play” icon top right \n",
35 | "1. Select Colab \n",
36 | "1. Set the runtime environment to include a GPU "
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "id": "c4b3fb66",
42 | "metadata": {},
43 | "source": [
44 | "## Overview\n",
45 | "\n",
46 | "This lecture gives a brief introduction to automatic differentiation using\n",
47 | "Google JAX.\n",
48 | "\n",
49 | "Automatic differentiation is one of the key elements of modern machine learning\n",
50 | "and artificial intelligence.\n",
51 | "\n",
52 | "As such, it has attracted a great deal of investment, and there are several\n",
53 | "powerful implementations available.\n",
54 | "\n",
55 | "One of the best of these is the set of automatic differentiation routines contained\n",
56 | "in JAX.\n",
57 | "\n",
58 | "While other software packages also offer this feature, the JAX version is\n",
59 | "particularly powerful because it integrates so well with other core\n",
60 | "components of JAX (e.g., JIT compilation and parallelization).\n",
61 | "\n",
62 | "As we will see in later lectures, automatic differentiation can be used not only\n",
63 | "for AI but also for many problems faced in mathematical modeling, such as\n",
64 | "multi-dimensional nonlinear optimization and root-finding problems.\n",
65 | "\n",
66 | "We need the following imports"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "id": "e96596b7",
73 | "metadata": {
74 | "hide-output": false
75 | },
76 | "outputs": [],
77 | "source": [
78 | "import jax\n",
79 | "import jax.numpy as jnp\n",
80 | "import matplotlib.pyplot as plt\n",
81 | "import numpy as np"
82 | ]
83 | },
84 | {
85 | "cell_type": "markdown",
86 | "id": "636ed1a2",
87 | "metadata": {},
88 | "source": [
89 | "Checking for a GPU:"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "id": "a884d274",
96 | "metadata": {
97 | "hide-output": false
98 | },
99 | "outputs": [],
100 | "source": [
101 | "!nvidia-smi"
102 | ]
103 | },
104 | {
105 | "cell_type": "markdown",
106 | "id": "57771d89",
107 | "metadata": {},
108 | "source": [
109 | "## What is automatic differentiation?\n",
110 | "\n",
111 | "Autodiff is a technique for calculating derivatives on a computer."
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "id": "f0bfd211",
117 | "metadata": {},
118 | "source": [
119 | "### Autodiff is not finite differences\n",
120 | "\n",
121 | "The derivative of $ f(x) = \\exp(2x) $ is\n",
122 | "\n",
123 | "$$\n",
124 | "f'(x) = 2 \\exp(2x)\n",
125 | "$$\n",
126 | "\n",
127 | "A computer that doesn’t know how to take derivatives might approximate this with the finite difference ratio\n",
128 | "\n",
129 | "$$\n",
130 | "(Df)(x) := \\frac{f(x+h) - f(x)}{h}\n",
131 | "$$\n",
132 | "\n",
133 | "where $ h $ is a small positive number."
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": null,
139 | "id": "6bdd0a76",
140 | "metadata": {
141 | "hide-output": false
142 | },
143 | "outputs": [],
144 | "source": [
145 | "def f(x):\n",
146 | " \"Original function.\"\n",
147 | " return np.exp(2 * x)\n",
148 | "\n",
149 | "def f_prime(x):\n",
150 | " \"True derivative.\"\n",
151 | " return 2 * np.exp(2 * x)\n",
152 | "\n",
153 | "def Df(x, h=0.1):\n",
154 | " \"Approximate derivative (finite difference).\"\n",
155 | " return (f(x + h) - f(x))/h\n",
156 | "\n",
157 | "x_grid = np.linspace(-2, 1, 200)\n",
158 | "fig, ax = plt.subplots()\n",
159 | "ax.plot(x_grid, f_prime(x_grid), label=\"$f'$\")\n",
160 | "ax.plot(x_grid, Df(x_grid), label=\"$Df$\")\n",
161 | "ax.legend()\n",
162 | "plt.show()"
163 | ]
164 | },
165 | {
166 | "cell_type": "markdown",
167 | "id": "fbfbaf53",
168 | "metadata": {},
169 | "source": [
170 | "This kind of numerical derivative is often inaccurate and unstable.\n",
171 | "\n",
172 | "One reason is that\n",
173 | "\n",
174 | "$$\n",
175 | "\\frac{f(x+h) - f(x)}{h} \\approx \\frac{0}{0}\n",
176 | "$$\n",
177 | "\n",
178 | "When $ h $ is small, the numerator is a difference of two nearly equal numbers, so floating-point rounding errors are amplified when we divide by $ h $.\n",
179 | "\n",
180 | "The situation becomes much worse in high dimensions and with higher-order derivatives."
181 | ]
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "id": "f5b5184b",
186 | "metadata": {},
187 | "source": [
188 | "### Autodiff is not symbolic calculus\n",
189 | "\n",
190 | "Symbolic calculus tries to use rules for differentiation to produce a single\n",
191 | "closed-form expression representing a derivative."
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": null,
197 | "id": "b363aac1",
198 | "metadata": {
199 | "hide-output": false
200 | },
201 | "outputs": [],
202 | "source": [
203 | "from sympy import symbols, diff\n",
204 | "\n",
205 | "m, a, b, x = symbols('m a b x')\n",
206 | "f_x = (a*x + b)**m\n",
207 | "f_x.diff((x, 6)) # 6-th order derivative"
208 | ]
209 | },
210 | {
211 | "cell_type": "markdown",
212 | "id": "5087cf5f",
213 | "metadata": {},
214 | "source": [
215 | "Symbolic calculus is not well suited to high performance\n",
216 | "computing.\n",
217 | "\n",
218 | "One disadvantage is that symbolic calculus cannot differentiate through control flow.\n",
219 | "\n",
220 | "Also, using symbolic calculus might involve redundant calculations.\n",
221 | "\n",
222 | "For example, consider\n",
223 | "\n",
224 | "$$\n",
225 | "(f g h)'\n",
226 | " = (f' g + g' f) h + (f g) h'\n",
227 | "$$\n",
228 | "\n",
229 | "If we evaluate this expression at $ x $, then we compute $ f(x) $ and $ g(x) $ twice each.\n",
230 | "\n",
231 | "Also, computing $ f'(x) $ and $ f(x) $ might involve similar terms (e.g., $ f(x) = \\exp(2x) \\implies f'(x) = 2f(x) $), but symbolic differentiation does not exploit this."
232 | ]
233 | },
234 | {
235 | "cell_type": "markdown",
236 | "id": "09cf9012",
237 | "metadata": {},
238 | "source": [
239 | "### Autodiff\n",
240 | "\n",
241 | "Autodiff produces functions that evaluate derivatives at numerical values\n",
242 | "passed in by the calling code, rather than producing a single symbolic\n",
243 | "expression representing the entire derivative.\n",
244 | "\n",
245 | "Derivatives are constructed by breaking calculations into component parts via the chain rule.\n",
246 | "\n",
247 | "The chain rule is applied until the terms reduce to primitive functions that the program knows how to differentiate exactly (addition, subtraction, exponentiation, sine and cosine, etc.)."
248 | ]
249 | },
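{
"cell_type": "markdown",
"id": "c2a81f6d",
"metadata": {},
"source": [
"To make the idea concrete, here is a minimal sketch of *forward-mode* autodiff using dual numbers in plain Python. It is not how JAX works internally (JAX traces the program, and `jax.grad` is built on reverse mode), but it follows the same principle: each primitive knows its own derivative, and the chain rule combines these numerically at the point of evaluation. The class `Dual` and the helpers `dual_sin`, `dual_exp` and `derivative` are hypothetical names introduced for this illustration.\n",
"\n",
"Notice that `dual_exp` computes $ \\exp(u) $ once and reuses it for the derivative, which is the kind of sharing that the symbolic approach misses."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71e5b0a9",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"import math\n",
"\n",
"class Dual:\n",
"    \"\"\"A number val + der * ε with ε**2 = 0; der carries the derivative.\"\"\"\n",
"\n",
"    def __init__(self, val, der=0.0):\n",
"        self.val, self.der = val, der\n",
"\n",
"    def __add__(self, other):\n",
"        other = other if isinstance(other, Dual) else Dual(other)\n",
"        # Sum rule\n",
"        return Dual(self.val + other.val, self.der + other.der)\n",
"\n",
"    __radd__ = __add__\n",
"\n",
"    def __mul__(self, other):\n",
"        other = other if isinstance(other, Dual) else Dual(other)\n",
"        # Product rule\n",
"        return Dual(self.val * other.val,\n",
"                    self.der * other.val + self.val * other.der)\n",
"\n",
"    __rmul__ = __mul__\n",
"\n",
"def dual_sin(u):\n",
"    # Chain rule for a primitive: (sin u)' = cos(u) * u'\n",
"    return Dual(math.sin(u.val), math.cos(u.val) * u.der)\n",
"\n",
"def dual_exp(u):\n",
"    # (exp u)' = exp(u) * u', so exp(u) is computed once and reused\n",
"    e = math.exp(u.val)\n",
"    return Dual(e, e * u.der)\n",
"\n",
"def derivative(func, x):\n",
"    # Seed the derivative part with 1 and read it off the output\n",
"    return func(Dual(x, 1.0)).der\n",
"\n",
"g_demo = lambda z: z * dual_sin(z) + dual_exp(2 * z)\n",
"dg = derivative(g_demo, 0.5)\n",
"exact_dg = math.sin(0.5) + 0.5 * math.cos(0.5) + 2 * math.exp(1.0)\n",
"print(dg, exact_dg)"
]
},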
250 | {
251 | "cell_type": "markdown",
252 | "id": "46b0f00d",
253 | "metadata": {},
254 | "source": [
255 | "## Some experiments\n",
256 | "\n",
257 | "Let’s start with some real-valued functions on $ \\mathbb R $."
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "id": "27427a8c",
263 | "metadata": {},
264 | "source": [
265 | "### A differentiable function\n",
266 | "\n",
267 | "Let’s test JAX’s autodiff with a relatively simple function."
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": null,
273 | "id": "c5cddeba",
274 | "metadata": {
275 | "hide-output": false
276 | },
277 | "outputs": [],
278 | "source": [
279 | "def f(x):\n",
280 | " return jnp.sin(x) - 2 * jnp.cos(3 * x) * jnp.exp(- x**2)"
281 | ]
282 | },
283 | {
284 | "cell_type": "markdown",
285 | "id": "a69a7ced",
286 | "metadata": {},
287 | "source": [
288 | "We use `grad` to compute the gradient of a real-valued function:"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": null,
294 | "id": "1cd5d97e",
295 | "metadata": {
296 | "hide-output": false
297 | },
298 | "outputs": [],
299 | "source": [
300 | "f_prime = jax.grad(f)"
301 | ]
302 | },
303 | {
304 | "cell_type": "markdown",
305 | "id": "228b8f98",
306 | "metadata": {},
307 | "source": [
308 | "Let’s plot the result:"
309 | ]
310 | },
311 | {
312 | "cell_type": "code",
313 | "execution_count": null,
314 | "id": "e4ec4923",
315 | "metadata": {
316 | "hide-output": false
317 | },
318 | "outputs": [],
319 | "source": [
320 | "x_grid = jnp.linspace(-5, 5, 100)"
321 | ]
322 | },
323 | {
324 | "cell_type": "code",
325 | "execution_count": null,
326 | "id": "356faf27",
327 | "metadata": {
328 | "hide-output": false
329 | },
330 | "outputs": [],
331 | "source": [
332 | "fig, ax = plt.subplots()\n",
333 | "ax.plot(x_grid, [f(x) for x in x_grid], label=\"$f$\")\n",
334 | "ax.plot(x_grid, [f_prime(x) for x in x_grid], label=\"$f'$\")\n",
335 | "ax.legend()\n",
336 | "plt.show()"
337 | ]
338 | },
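{
"cell_type": "markdown",
"id": "0d6f3b8e",
"metadata": {},
"source": [
"As a sanity check, we can compare `jax.grad` with the derivative worked out by hand using the product and chain rules,\n",
"\n",
"$$\n",
"f'(x) = \\cos x + 6 \\sin(3x) e^{-x^2} + 4x \\cos(3x) e^{-x^2}\n",
"$$\n",
"\n",
"The sketch below redefines the function as `f_test` (a name of our choosing, so that `f` is left untouched) and uses `jax.vmap` to evaluate the derivative on the whole grid at once instead of looping in Python."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b84c1e27",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def f_test(x):\n",
"    # Same function as f above\n",
"    return jnp.sin(x) - 2 * jnp.cos(3 * x) * jnp.exp(- x**2)\n",
"\n",
"def f_test_prime_exact(x):\n",
"    # Derivative worked out by hand\n",
"    return (jnp.cos(x)\n",
"            + 6 * jnp.sin(3 * x) * jnp.exp(- x**2)\n",
"            + 4 * x * jnp.cos(3 * x) * jnp.exp(- x**2))\n",
"\n",
"grid = jnp.linspace(-5, 5, 100)\n",
"f_test_prime_vec = jax.vmap(jax.grad(f_test))  # Vectorized derivative\n",
"max_err = jnp.max(jnp.abs(f_test_prime_vec(grid) - f_test_prime_exact(grid)))\n",
"print(max_err)"
]
},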
339 | {
340 | "cell_type": "markdown",
341 | "id": "0ffbe3f4",
342 | "metadata": {},
343 | "source": [
344 | "### Absolute value function\n",
345 | "\n",
346 | "What happens if the function is not differentiable?"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": null,
352 | "id": "d9010eac",
353 | "metadata": {
354 | "hide-output": false
355 | },
356 | "outputs": [],
357 | "source": [
358 | "def f(x):\n",
359 | " return jnp.abs(x)"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": null,
365 | "id": "a0bfb19c",
366 | "metadata": {
367 | "hide-output": false
368 | },
369 | "outputs": [],
370 | "source": [
371 | "f_prime = jax.grad(f)"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": null,
377 | "id": "08281fc5",
378 | "metadata": {
379 | "hide-output": false
380 | },
381 | "outputs": [],
382 | "source": [
383 | "fig, ax = plt.subplots()\n",
384 | "ax.plot(x_grid, [f(x) for x in x_grid], label=\"$f$\")\n",
385 | "ax.plot(x_grid, [f_prime(x) for x in x_grid], label=\"$f'$\")\n",
386 | "ax.legend()\n",
387 | "plt.show()"
388 | ]
389 | },
390 | {
391 | "cell_type": "markdown",
392 | "id": "86723cb2",
393 | "metadata": {},
394 | "source": [
395 | "At the nondifferentiable point $ 0 $, `jax.grad` returns the right derivative:"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "execution_count": null,
401 | "id": "e93b475b",
402 | "metadata": {
403 | "hide-output": false
404 | },
405 | "outputs": [],
406 | "source": [
407 | "f_prime(0.0)"
408 | ]
409 | },
410 | {
411 | "cell_type": "markdown",
412 | "id": "74ea684a",
413 | "metadata": {},
414 | "source": [
415 | "### Differentiating through control flow\n",
416 | "\n",
417 | "Let’s try differentiating through some loops and conditions."
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": null,
423 | "id": "74d552ec",
424 | "metadata": {
425 | "hide-output": false
426 | },
427 | "outputs": [],
428 | "source": [
429 | "def f(x):\n",
430 | " def f1(x):\n",
431 | " for i in range(2):\n",
432 | " x *= 0.2 * x\n",
433 | " return x\n",
434 | " def f2(x):\n",
435 | " x = sum((x**i + i) for i in range(3))\n",
436 | " return x\n",
437 | " y = f1(x) if x < 0 else f2(x)\n",
438 | " return y"
439 | ]
440 | },
441 | {
442 | "cell_type": "code",
443 | "execution_count": null,
444 | "id": "47ed86fd",
445 | "metadata": {
446 | "hide-output": false
447 | },
448 | "outputs": [],
449 | "source": [
450 | "f_prime = jax.grad(f)"
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": null,
456 | "id": "d5bffb0c",
457 | "metadata": {
458 | "hide-output": false
459 | },
460 | "outputs": [],
461 | "source": [
462 | "x_grid = jnp.linspace(-5, 5, 100)"
463 | ]
464 | },
465 | {
466 | "cell_type": "code",
467 | "execution_count": null,
468 | "id": "fadeee11",
469 | "metadata": {
470 | "hide-output": false
471 | },
472 | "outputs": [],
473 | "source": [
474 | "fig, ax = plt.subplots()\n",
475 | "ax.plot(x_grid, [f(x) for x in x_grid], label=\"$f$\")\n",
476 | "ax.plot(x_grid, [f_prime(x) for x in x_grid], label=\"$f'$\")\n",
477 | "ax.legend()\n",
478 | "plt.show()"
479 | ]
480 | },
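{
"cell_type": "markdown",
"id": "5a7e9d13",
"metadata": {},
"source": [
"The Python `if` in `f` works with `jax.grad` because the comparison `x < 0` is made on a concrete number. Inside `jax.jit`, however, `x` is an abstract tracer and the same `if` raises a concretization error. A common jit-friendly alternative, sketched below, is to compute both branches and select with `jnp.where` (`jax.lax.cond` is another option). The name `f_branchless` is ours, and the branch formulas are closed forms of the loop and the sum in `f`: two passes of `x *= 0.2 * x` give $ 0.008 x^4 $, and $ \\sum_{i=0}^{2} (x^i + i) = x^2 + x + 4 $.\n",
"\n",
"One caveat: because both branches are evaluated, a branch that produces `nan` or `inf` for some inputs can contaminate the gradient even where it is not selected."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1f0c64b",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def f_branchless(x):\n",
"    left = 0.008 * x**4     # Closed form of f1 (two passes of x *= 0.2 * x)\n",
"    right = x**2 + x + 4    # Closed form of f2 (sum over i = 0, 1, 2)\n",
"    return jnp.where(x < 0, left, right)\n",
"\n",
"# Unlike f, this version can be differentiated inside jax.jit\n",
"f_branchless_prime = jax.jit(jax.grad(f_branchless))\n",
"print(f_branchless_prime(-1.0), f_branchless_prime(2.0))"
]
},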
481 | {
482 | "cell_type": "markdown",
483 | "id": "79567851",
484 | "metadata": {},
485 | "source": [
486 | "### Differentiating through a linear interpolation\n",
487 | "\n",
488 | "We can differentiate through linear interpolation, even though the function is not smooth:"
489 | ]
490 | },
491 | {
492 | "cell_type": "code",
493 | "execution_count": null,
494 | "id": "2324c34e",
495 | "metadata": {
496 | "hide-output": false
497 | },
498 | "outputs": [],
499 | "source": [
500 | "n = 20\n",
501 | "xp = jnp.linspace(-5, 5, n)\n",
502 | "yp = jnp.cos(2 * xp)\n",
503 | "\n",
504 | "fig, ax = plt.subplots()\n",
505 | "ax.plot(x_grid, jnp.interp(x_grid, xp, yp))\n",
506 | "plt.show()"
507 | ]
508 | },
509 | {
510 | "cell_type": "code",
511 | "execution_count": null,
512 | "id": "ae67f453",
513 | "metadata": {
514 | "hide-output": false
515 | },
516 | "outputs": [],
517 | "source": [
518 | "f_prime = jax.grad(jnp.interp)"
519 | ]
520 | },
521 | {
522 | "cell_type": "code",
523 | "execution_count": null,
524 | "id": "00f88ae6",
525 | "metadata": {
526 | "hide-output": false
527 | },
528 | "outputs": [],
529 | "source": [
530 | "f_prime_vec = jax.vmap(f_prime, in_axes=(0, None, None))"
531 | ]
532 | },
533 | {
534 | "cell_type": "code",
535 | "execution_count": null,
536 | "id": "3720f0ef",
537 | "metadata": {
538 | "hide-output": false
539 | },
540 | "outputs": [],
541 | "source": [
542 | "fig, ax = plt.subplots()\n",
543 | "ax.plot(x_grid, f_prime_vec(x_grid, xp, yp))\n",
544 | "plt.show()"
545 | ]
546 | },
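{
"cell_type": "markdown",
"id": "3c9b2a7f",
"metadata": {},
"source": [
"Because the interpolant is piecewise linear, its derivative inside each segment $ [x_j, x_{j+1}] $ should equal that segment's slope $ (y_{j+1} - y_j)/(x_{j+1} - x_j) $. The sketch below checks this at the midpoint of each segment, which avoids the kinks at the interpolation nodes where the derivative is undefined. It rebuilds the same nodes under new names (`knots` and `vals`, chosen here for illustration) so that the variables above are left unchanged."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "86d4e015",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"knots = jnp.linspace(-5, 5, 20)\n",
"vals = jnp.cos(2 * knots)\n",
"\n",
"mids = (knots[:-1] + knots[1:]) / 2        # One point inside each segment\n",
"slopes = np.diff(vals) / np.diff(knots)    # Slope of each linear piece\n",
"\n",
"# Differentiate jnp.interp with respect to its first argument, at all midpoints\n",
"interp_prime = jax.vmap(jax.grad(jnp.interp), in_axes=(0, None, None))\n",
"ad_slopes = np.asarray(interp_prime(mids, knots, vals))\n",
"print(np.max(np.abs(ad_slopes - slopes)))"
]
},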
547 | {
548 | "cell_type": "markdown",
549 | "id": "57b3cf7b",
550 | "metadata": {},
551 | "source": [
552 | "## Gradient Descent\n",
553 | "\n",
554 | "Let’s try implementing gradient descent.\n",
555 | "\n",
556 | "As a simple application, we’ll use gradient descent to solve for the OLS parameter estimates in simple linear regression."
557 | ]
558 | },
559 | {
560 | "cell_type": "markdown",
561 | "id": "43f8136b",
562 | "metadata": {},
563 | "source": [
564 | "### A function for gradient descent\n",
565 | "\n",
566 | "Here’s an implementation of gradient descent."
567 | ]
568 | },
569 | {
570 | "cell_type": "code",
571 | "execution_count": null,
572 | "id": "6f278b7f",
573 | "metadata": {
574 | "hide-output": false
575 | },
576 | "outputs": [],
577 | "source": [
578 | "def grad_descent(f, # Function to be minimized\n",
579 | " args, # Extra arguments to the function\n",
580 | " x0, # Initial condition\n",
581 | " λ=0.1, # Initial learning rate\n",
582 | " tol=1e-5, # Tolerance on the largest step component\n",
583 | " max_iter=1_000): # Maximum number of iterations\n",
584 | " \"\"\"\n",
585 | " Minimize the function f via gradient descent, starting from guess x0.\n",
586 | "\n",
587 | " The learning rate is computed according to the Barzilai-Borwein method.\n",
588 | " \n",
589 | " \"\"\"\n",
590 | " \n",
591 | " f_grad = jax.grad(f)\n",
592 | " x = jnp.array(x0)\n",
593 | " df = f_grad(x, args)\n",
594 | " ϵ = tol + 1\n",
595 | " i = 0\n",
596 | " while ϵ > tol and i < max_iter:\n",
597 | " new_x = x - λ * df\n",
598 | " new_df = f_grad(new_x, args)\n",
599 | " Δx = new_x - x\n",
600 | " Δdf = new_df - df\n",
601 | " λ = jnp.abs(Δx @ Δdf) / (Δdf @ Δdf)\n",
602 | " ϵ = jnp.max(jnp.abs(Δx))\n",
603 | " x, df = new_x, new_df\n",
604 | " i += 1\n",
605 | " \n",
606 | " return x"
607 | ]
608 | },
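{
"cell_type": "markdown",
"id": "f2b7c930",
"metadata": {},
"source": [
"The step-size update inside the loop is the Barzilai-Borwein rule (in the variant with $ \\Delta g $ in the denominator, plus an absolute value to keep the step positive):\n",
"\n",
"$$\n",
"\\lambda_{k+1} = \\frac{|\\Delta x_k^\\top \\Delta g_k|}{\\Delta g_k^\\top \\Delta g_k},\n",
"\\qquad \\Delta x_k = x_{k+1} - x_k,\n",
"\\quad \\Delta g_k = \\nabla f(x_{k+1}) - \\nabla f(x_k)\n",
"$$\n",
"\n",
"To see the update in isolation, here is a hedged NumPy sketch of the same loop applied to the quadratic $ q(x) = \\tfrac{1}{2} x^\\top A x - b^\\top x $, whose gradient $ Ax - b $ we code by hand and whose minimizer $ A^{-1} b $ is known. The function `bb_descent` and the arrays `A_q` and `b_q` are illustrative choices, not part of the lecture. Unlike `grad_descent`, the sketch tests the stopping rule before updating the step size, which avoids a $ 0/0 $ step once the iterates stop moving."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a4e8d6c",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def bb_descent(grad, x0, λ=0.1, tol=1e-10, max_iter=1_000):\n",
"    \"\"\"Gradient descent with Barzilai-Borwein step sizes (NumPy sketch).\"\"\"\n",
"    x = np.asarray(x0, dtype=float)\n",
"    g = grad(x)\n",
"    for _ in range(max_iter):\n",
"        new_x = x - λ * g\n",
"        new_g = grad(new_x)\n",
"        Δx, Δg = new_x - x, new_g - g\n",
"        if np.max(np.abs(Δx)) < tol:    # Stop before forming a 0/0 step size\n",
"            return new_x\n",
"        λ = abs(Δx @ Δg) / (Δg @ Δg)    # Barzilai-Borwein step size\n",
"        x, g = new_x, new_g\n",
"    return x\n",
"\n",
"A_q = np.array([[3.0, 0.5],\n",
"                [0.5, 1.0]])            # Symmetric positive definite\n",
"b_q = np.array([1.0, -2.0])\n",
"grad_q = lambda v: A_q @ v - b_q        # Gradient of q(v) = v'A v / 2 - b'v\n",
"x_star = bb_descent(grad_q, np.zeros(2))\n",
"print(x_star, np.linalg.solve(A_q, b_q))"
]
},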
609 | {
610 | "cell_type": "markdown",
611 | "id": "c2a7ffc6",
612 | "metadata": {},
613 | "source": [
614 | "### Simulated data\n",
615 | "\n",
616 | "We’re going to test our gradient descent function by minimizing the sum of squared residuals in a regression problem.\n",
617 | "\n",
618 | "Let’s generate some simulated data:"
619 | ]
620 | },
621 | {
622 | "cell_type": "code",
623 | "execution_count": null,
624 | "id": "826f2c73",
625 | "metadata": {
626 | "hide-output": false
627 | },
628 | "outputs": [],
629 | "source": [
630 | "n = 100\n",
631 | "key = jax.random.PRNGKey(1234)\n",
632 | "key, subkey = jax.random.split(key) # Split before drawing: never reuse a key\n",
633 | "x = jax.random.uniform(key, (n,))\n",
634 | "\n",
635 | "α, β, σ = 0.5, 1.0, 0.1 # True slope, intercept and noise scale\n",
636 | "ϵ = jax.random.normal(subkey, (n,))\n",
637 | "\n",
638 | "y = α * x + β + σ * ϵ"
639 | ]
640 | },
641 | {
642 | "cell_type": "code",
643 | "execution_count": null,
644 | "id": "4c444d34",
645 | "metadata": {
646 | "hide-output": false
647 | },
648 | "outputs": [],
649 | "source": [
650 | "fig, ax = plt.subplots()\n",
651 | "ax.scatter(x, y)\n",
652 | "plt.show()"
653 | ]
654 | },
655 | {
656 | "cell_type": "markdown",
657 | "id": "b1708a68",
658 | "metadata": {},
659 | "source": [
660 | "Let’s start by calculating the estimated slope and intercept using the closed-form solutions."
661 | ]
662 | },
663 | {
664 | "cell_type": "code",
665 | "execution_count": null,
666 | "id": "31071eea",
667 | "metadata": {
668 | "hide-output": false
669 | },
670 | "outputs": [],
671 | "source": [
672 | "mx = x.mean()\n",
673 | "my = y.mean()\n",
674 | "α_hat = jnp.sum((x - mx) * (y - my)) / jnp.sum((x - mx)**2)\n",
675 | "β_hat = my - α_hat * mx"
676 | ]
677 | },
678 | {
679 | "cell_type": "code",
680 | "execution_count": null,
681 | "id": "d916e891",
682 | "metadata": {
683 | "hide-output": false
684 | },
685 | "outputs": [],
686 | "source": [
687 | "α_hat, β_hat"
688 | ]
689 | },
690 | {
691 | "cell_type": "code",
692 | "execution_count": null,
693 | "id": "0c0823f8",
694 | "metadata": {
695 | "hide-output": false
696 | },
697 | "outputs": [],
698 | "source": [
699 | "fig, ax = plt.subplots()\n",
700 | "ax.scatter(x, y)\n",
701 | "ax.plot(x, α_hat * x + β_hat, 'k-')\n",
702 | "ax.text(0.1, 1.55, rf'$\\hat \\alpha = {α_hat:.3}$')\n",
703 | "ax.text(0.1, 1.50, rf'$\\hat \\beta = {β_hat:.3}$')\n",
704 | "plt.show()"
705 | ]
706 | },
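{
"cell_type": "markdown",
"id": "7d0c5f92",
"metadata": {},
"source": [
"The closed-form expressions used above are the standard least-squares formulas for a line,\n",
"\n",
"$$\n",
"\\hat \\alpha = \\frac{\\sum_i (x_i - \\bar x)(y_i - \\bar y)}{\\sum_i (x_i - \\bar x)^2},\n",
"\\qquad\n",
"\\hat \\beta = \\bar y - \\hat \\alpha \\bar x\n",
"$$\n",
"\n",
"As an independent cross-check, the NumPy sketch below applies them to its own simulated data (the names `x_sim` and `y_sim` and the seed are ours, so the variables above are untouched) and compares the result with `np.polyfit`, which returns the coefficients from the highest degree down."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9e3a41b",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x_sim = rng.uniform(size=100)\n",
"y_sim = 0.5 * x_sim + 1.0 + 0.1 * rng.standard_normal(100)\n",
"\n",
"# Closed-form OLS slope and intercept, as in the cell above\n",
"x_bar, y_bar = x_sim.mean(), y_sim.mean()\n",
"slope = np.sum((x_sim - x_bar) * (y_sim - y_bar)) / np.sum((x_sim - x_bar)**2)\n",
"intercept = y_bar - slope * x_bar\n",
"\n",
"# np.polyfit with degree 1 returns [slope, intercept]\n",
"slope_pf, intercept_pf = np.polyfit(x_sim, y_sim, 1)\n",
"print(slope, intercept)\n",
"print(slope_pf, intercept_pf)"
]
},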
707 | {
708 | "cell_type": "markdown",
709 | "id": "77c1b99a",
710 | "metadata": {},
711 | "source": [
712 | "### Minimizing squared loss by gradient descent\n",
713 | "\n",
714 | "Let’s see if we can get the same values with our gradient descent function.\n",
715 | "\n",
716 | "First we set up the least squares loss function."
717 | ]
718 | },
719 | {
720 | "cell_type": "code",
721 | "execution_count": null,
722 | "id": "662adb3f",
723 | "metadata": {
724 | "hide-output": false
725 | },
726 | "outputs": [],
727 | "source": [
728 | "@jax.jit\n",
729 | "def loss(params, data):\n",
730 | " a, b = params\n",
731 | " x, y = data\n",
732 | " return jnp.sum((y - a * x - b)**2)"
733 | ]
734 | },
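{
"cell_type": "markdown",
"id": "2e8f6b07",
"metadata": {},
"source": [
"`grad_descent` builds `f_grad = jax.grad(f)` and calls it as `f_grad(x, args)`. By default `jax.grad` differentiates with respect to the first argument only, so the data are treated as constants. For this loss the gradient is also easy to derive by hand,\n",
"\n",
"$$\n",
"\\frac{\\partial \\ell}{\\partial a} = -2 \\sum_i x_i (y_i - a x_i - b),\n",
"\\qquad\n",
"\\frac{\\partial \\ell}{\\partial b} = -2 \\sum_i (y_i - a x_i - b)\n",
"$$\n",
"\n",
"and a quick sketch on a tiny made-up dataset (`xs`, `ys` and `params_chk` are illustrative) confirms that the two agree. The loss is redefined as `sq_loss` so that `loss` above is not replaced."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0d9c35e",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def sq_loss(params, data):\n",
"    # Same least-squares loss as loss above\n",
"    a, b = params\n",
"    x, y = data\n",
"    return jnp.sum((y - a * x - b)**2)\n",
"\n",
"xs = jnp.array([0.0, 0.5, 1.0, 1.5])\n",
"ys = jnp.array([1.0, 1.4, 1.4, 1.8])\n",
"params_chk = jnp.array([0.3, 0.9])\n",
"\n",
"# Hand-derived gradient\n",
"resid = ys - params_chk[0] * xs - params_chk[1]\n",
"exact_grad = jnp.array([-2 * jnp.sum(xs * resid), -2 * jnp.sum(resid)])\n",
"\n",
"# Autodiff gradient with respect to params only (argnums=0 is the default)\n",
"auto_grad = jax.grad(sq_loss)(params_chk, (xs, ys))\n",
"print(auto_grad, exact_grad)"
]
},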
735 | {
736 | "cell_type": "markdown",
737 | "id": "f784db32",
738 | "metadata": {},
739 | "source": [
740 | "Now we minimize it:"
741 | ]
742 | },
743 | {
744 | "cell_type": "code",
745 | "execution_count": null,
746 | "id": "1e7ac9c1",
747 | "metadata": {
748 | "hide-output": false
749 | },
750 | "outputs": [],
751 | "source": [
752 | "p0 = jnp.zeros(2) # Initial guess for α, β\n",
753 | "data = x, y\n",
754 | "α_hat, β_hat = grad_descent(loss, data, p0)"
755 | ]
756 | },
757 | {
758 | "cell_type": "markdown",
759 | "id": "d159123c",
760 | "metadata": {},
761 | "source": [
762 | "Let’s plot the results."
763 | ]
764 | },
765 | {
766 | "cell_type": "code",
767 | "execution_count": null,
768 | "id": "a47ed873",
769 | "metadata": {
770 | "hide-output": false
771 | },
772 | "outputs": [],
773 | "source": [
774 | "fig, ax = plt.subplots()\n",
775 | "x_grid = jnp.linspace(0, 1, 100)\n",
776 | "ax.scatter(x, y)\n",
777 | "ax.plot(x_grid, α_hat * x_grid + β_hat, 'k-', alpha=0.6)\n",
778 | "ax.text(0.1, 1.55, rf'$\\hat \\alpha = {α_hat:.3}$')\n",
779 | "ax.text(0.1, 1.50, rf'$\\hat \\beta = {β_hat:.3}$')\n",
780 | "plt.show()"
781 | ]
782 | },
783 | {
784 | "cell_type": "markdown",
785 | "id": "f13c4b2d",
786 | "metadata": {},
787 | "source": [
788 | "Notice that we get the same estimates as we did from the closed-form solutions."
789 | ]
790 | },
791 | {
792 | "cell_type": "markdown",
793 | "id": "f6c7af8f",
794 | "metadata": {},
795 | "source": [
796 | "### Adding a squared term\n",
797 | "\n",
798 | "Now let’s try fitting a second-order polynomial.\n",
799 | "\n",
800 | "Here’s our new loss function."
801 | ]
802 | },
803 | {
804 | "cell_type": "code",
805 | "execution_count": null,
806 | "id": "afa165c9",
807 | "metadata": {
808 | "hide-output": false
809 | },
810 | "outputs": [],
811 | "source": [
812 | "@jax.jit\n",
813 | "def loss(params, data):\n",
814 | " a, b, c = params\n",
815 | " x, y = data\n",
816 | " return jnp.sum((y - a * x**2 - b * x - c)**2)"
817 | ]
818 | },
819 | {
820 | "cell_type": "markdown",
821 | "id": "5172996a",
822 | "metadata": {},
823 | "source": [
824 | "Now we’re minimizing in three dimensions.\n",
825 | "\n",
826 | "Let’s try it."
827 | ]
828 | },
829 | {
830 | "cell_type": "code",
831 | "execution_count": null,
832 | "id": "fe9b6d98",
833 | "metadata": {
834 | "hide-output": false
835 | },
836 | "outputs": [],
837 | "source": [
838 | "p0 = jnp.zeros(3)\n",
839 | "α_hat, β_hat, γ_hat = grad_descent(loss, data, p0)\n",
840 | "\n",
841 | "fig, ax = plt.subplots()\n",
842 | "ax.scatter(x, y)\n",
843 | "ax.plot(x_grid, α_hat * x_grid**2 + β_hat * x_grid + γ_hat, 'k-', alpha=0.6)\n",
844 | "ax.text(0.1, 1.55, rf'$\\hat \\alpha = {α_hat:.3}$')\n",
845 | "ax.text(0.1, 1.50, rf'$\\hat \\beta = {β_hat:.3}$')\n",
846 | "plt.show()"
847 | ]
848 | },
849 | {
850 | "cell_type": "markdown",
851 | "id": "8eb16699",
852 | "metadata": {},
853 | "source": [
854 | "## Exercises"
855 | ]
856 | },
857 | {
858 | "cell_type": "markdown",
859 | "id": "fe37f6ae",
860 | "metadata": {},
861 | "source": [
862 | "## Exercise 3.1\n",
863 | "\n",
864 | "The function `jnp.polyval` evaluates polynomials.\n",
865 | "\n",
866 | "For example, if `len(p)` is 3, then `jnp.polyval(p, x)` returns\n",
867 | "\n",
868 | "$$\n",
869 | "f(p, x) := p_0 x^2 + p_1 x + p_2\n",
870 | "$$\n",
871 | "\n",
872 | "Use this function for polynomial regression.\n",
873 | "\n",
874 | "The (empirical) loss becomes\n",
875 | "\n",
876 | "$$\n",
877 | "\\ell(p, x, y) \n",
878 | " = \\sum_{i=1}^n (y_i - f(p, x_i))^2\n",
879 | "$$\n",
880 | "\n",
881 | "Here $ k $ is the number of coefficients, so the fitted polynomial has degree $ k-1 $. Set $ k=4 $ and set the initial guess of `params` to `jnp.zeros(k)`.\n",
882 | "\n",
883 | "Use gradient descent to find the array `params` that minimizes the loss\n",
884 | "function and plot the result (following the examples above)."
885 | ]
886 | },
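{
"cell_type": "markdown",
"id": "64b1f8a3",
"metadata": {},
"source": [
"Note the coefficient order: like `np.polyval`, `jnp.polyval` expects the coefficient on the highest power first, which matches $ p_0 x^2 + p_1 x + p_2 $ above. A quick check, with `p_demo` as an illustrative name:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7a2e9c0",
"metadata": {
"hide-output": false
},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"\n",
"# [1, 2, 3] represents 1 * x**2 + 2 * x + 3\n",
"p_demo = jnp.array([1.0, 2.0, 3.0])\n",
"val = jnp.polyval(p_demo, 2.0)   # 1 * 4 + 2 * 2 + 3 = 11\n",
"print(val)"
]
},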
887 | {
888 | "cell_type": "markdown",
889 | "id": "92c471df",
890 | "metadata": {},
891 | "source": [
892 | "## Solution\n",
893 | "\n",
894 | "Here’s one solution."
895 | ]
896 | },
897 | {
898 | "cell_type": "code",
899 | "execution_count": null,
900 | "id": "e278fb9b",
901 | "metadata": {
902 | "hide-output": false
903 | },
904 | "outputs": [],
905 | "source": [
906 | "def loss(params, data):\n",
907 | " x, y = data\n",
908 | " return jnp.sum((y - jnp.polyval(params, x))**2)"
909 | ]
910 | },
911 | {
912 | "cell_type": "code",
913 | "execution_count": null,
914 | "id": "ccc336c8",
915 | "metadata": {
916 | "hide-output": false
917 | },
918 | "outputs": [],
919 | "source": [
920 | "k = 4\n",
921 | "p0 = jnp.zeros(k)\n",
922 | "p_hat = grad_descent(loss, data, p0)\n",
923 | "print('Estimated parameter vector:')\n",
924 | "print(p_hat)\n",
925 | "print('\\n\\n')\n",
926 | "\n",
927 | "fig, ax = plt.subplots()\n",
928 | "ax.scatter(x, y)\n",
929 | "ax.plot(x_grid, jnp.polyval(p_hat, x_grid), 'k-', alpha=0.6)\n",
930 | "plt.show()"
931 | ]
932 | }
933 | ],
934 | "metadata": {
935 | "date": 1765244755.101336,
936 | "filename": "autodiff.md",
937 | "kernelspec": {
938 | "display_name": "Python",
939 | "language": "python3",
940 | "name": "python3"
941 | },
942 | "title": "Adventures with Autodiff"
943 | },
944 | "nbformat": 4,
945 | "nbformat_minor": 5
946 | }
--------------------------------------------------------------------------------