├── LICENSE
├── README.md
├── data
│   ├── JFVDSS_rho0p05.mat
│   ├── JFVSSS_rho0p05.mat
│   ├── KS_policy_N100_v1.mat
│   ├── KS_policy_N10_v1.mat
│   └── KS_policy_N50_v1.mat
├── environment.yml
└── src
    ├── configs
    │   ├── JFV_DSS
    │   │   ├── game_nn_n10.json
    │   │   └── game_nn_n50.json
    │   ├── JFV_SSS
    │   │   └── game_nn_n50.json
    │   └── KS
    │       ├── game_nn_n10.json
    │       ├── game_nn_n50.json
    │       └── game_nn_n50_0fm1gm.json
    ├── dataset.py
    ├── param.py
    ├── policy.py
    ├── simulation_JFV.py
    ├── simulation_KS.py
    ├── slurm_scripts
    │   ├── KS_1fm.slurm
    │   └── KS_1gm.slurm
    ├── train_JFV.py
    ├── train_KS.py
    ├── util.py
    ├── validate_JFV.py
    ├── validate_KS.py
    └── value.py
/LICENSE: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 2.1, February 1999 3 | 4 | Copyright (C) 1991, 1999 Free Software Foundation, Inc. 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | [This is the first released version of the Lesser GPL. It also counts 10 | as the successor of the GNU Library Public License, version 2, hence 11 | the version number 2.1.] 12 | 13 | Preamble 14 | 15 | The licenses for most software are designed to take away your 16 | freedom to share and change it. By contrast, the GNU General Public 17 | Licenses are intended to guarantee your freedom to share and change 18 | free software--to make sure the software is free for all its users. 19 | 20 | This license, the Lesser General Public License, applies to some 21 | specially designated software packages--typically libraries--of the 22 | Free Software Foundation and other authors who decide to use it. You 23 | can use it too, but we suggest you first think carefully about whether 24 | this license or the ordinary General Public License is the better 25 | strategy to use in any particular case, based on the explanations below. 26 | 27 | When we speak of free software, we are referring to freedom of use, 28 | not price. Our General Public Licenses are designed to make sure that 29 | you have the freedom to distribute copies of free software (and charge 30 | for this service if you wish); that you receive source code or can get 31 | it if you want it; that you can change the software and use pieces of 32 | it in new free programs; and that you are informed that you can do 33 | these things. 34 | 35 | To protect your rights, we need to make restrictions that forbid 36 | distributors to deny you these rights or to ask you to surrender these 37 | rights. These restrictions translate to certain responsibilities for 38 | you if you distribute copies of the library or if you modify it. 39 | 40 | For example, if you distribute copies of the library, whether gratis 41 | or for a fee, you must give the recipients all the rights that we gave 42 | you. You must make sure that they, too, receive or can get the source 43 | code. If you link other code with the library, you must provide 44 | complete object files to the recipients, so that they can relink them 45 | with the library after making changes to the library and recompiling 46 | it. And you must show them these terms so they know their rights. 47 | 48 | We protect your rights with a two-step method: (1) we copyright the 49 | library, and (2) we offer you this license, which gives you legal 50 | permission to copy, distribute and/or modify the library.
51 | 52 | To protect each distributor, we want to make it very clear that 53 | there is no warranty for the free library. Also, if the library is 54 | modified by someone else and passed on, the recipients should know 55 | that what they have is not the original version, so that the original 56 | author's reputation will not be affected by problems that might be 57 | introduced by others. 58 | 59 | Finally, software patents pose a constant threat to the existence of 60 | any free program. We wish to make sure that a company cannot 61 | effectively restrict the users of a free program by obtaining a 62 | restrictive license from a patent holder. Therefore, we insist that 63 | any patent license obtained for a version of the library must be 64 | consistent with the full freedom of use specified in this license. 65 | 66 | Most GNU software, including some libraries, is covered by the 67 | ordinary GNU General Public License. This license, the GNU Lesser 68 | General Public License, applies to certain designated libraries, and 69 | is quite different from the ordinary General Public License. We use 70 | this license for certain libraries in order to permit linking those 71 | libraries into non-free programs. 72 | 73 | When a program is linked with a library, whether statically or using 74 | a shared library, the combination of the two is legally speaking a 75 | combined work, a derivative of the original library. The ordinary 76 | General Public License therefore permits such linking only if the 77 | entire combination fits its criteria of freedom. The Lesser General 78 | Public License permits more lax criteria for linking other code with 79 | the library. 80 | 81 | We call this license the "Lesser" General Public License because it 82 | does Less to protect the user's freedom than the ordinary General 83 | Public License. It also provides other free software developers Less 84 | of an advantage over competing non-free programs. These disadvantages 85 | are the reason we use the ordinary General Public License for many 86 | libraries. However, the Lesser license provides advantages in certain 87 | special circumstances. 88 | 89 | For example, on rare occasions, there may be a special need to 90 | encourage the widest possible use of a certain library, so that it becomes 91 | a de-facto standard. To achieve this, non-free programs must be 92 | allowed to use the library. A more frequent case is that a free 93 | library does the same job as widely used non-free libraries. In this 94 | case, there is little to gain by limiting the free library to free 95 | software only, so we use the Lesser General Public License. 96 | 97 | In other cases, permission to use a particular library in non-free 98 | programs enables a greater number of people to use a large body of 99 | free software. For example, permission to use the GNU C Library in 100 | non-free programs enables many more people to use the whole GNU 101 | operating system, as well as its variant, the GNU/Linux operating 102 | system. 103 | 104 | Although the Lesser General Public License is Less protective of the 105 | users' freedom, it does ensure that the user of a program that is 106 | linked with the Library has the freedom and the wherewithal to run 107 | that program using a modified version of the Library. 108 | 109 | The precise terms and conditions for copying, distribution and 110 | modification follow. Pay close attention to the difference between a 111 | "work based on the library" and a "work that uses the library". 
The 112 | former contains code derived from the library, whereas the latter must 113 | be combined with the library in order to run. 114 | 115 | GNU LESSER GENERAL PUBLIC LICENSE 116 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 117 | 118 | 0. This License Agreement applies to any software library or other 119 | program which contains a notice placed by the copyright holder or 120 | other authorized party saying it may be distributed under the terms of 121 | this Lesser General Public License (also called "this License"). 122 | Each licensee is addressed as "you". 123 | 124 | A "library" means a collection of software functions and/or data 125 | prepared so as to be conveniently linked with application programs 126 | (which use some of those functions and data) to form executables. 127 | 128 | The "Library", below, refers to any such software library or work 129 | which has been distributed under these terms. A "work based on the 130 | Library" means either the Library or any derivative work under 131 | copyright law: that is to say, a work containing the Library or a 132 | portion of it, either verbatim or with modifications and/or translated 133 | straightforwardly into another language. (Hereinafter, translation is 134 | included without limitation in the term "modification".) 135 | 136 | "Source code" for a work means the preferred form of the work for 137 | making modifications to it. For a library, complete source code means 138 | all the source code for all modules it contains, plus any associated 139 | interface definition files, plus the scripts used to control compilation 140 | and installation of the library. 141 | 142 | Activities other than copying, distribution and modification are not 143 | covered by this License; they are outside its scope. The act of 144 | running a program using the Library is not restricted, and output from 145 | such a program is covered only if its contents constitute a work based 146 | on the Library (independent of the use of the Library in a tool for 147 | writing it). Whether that is true depends on what the Library does 148 | and what the program that uses the Library does. 149 | 150 | 1. You may copy and distribute verbatim copies of the Library's 151 | complete source code as you receive it, in any medium, provided that 152 | you conspicuously and appropriately publish on each copy an 153 | appropriate copyright notice and disclaimer of warranty; keep intact 154 | all the notices that refer to this License and to the absence of any 155 | warranty; and distribute a copy of this License along with the 156 | Library. 157 | 158 | You may charge a fee for the physical act of transferring a copy, 159 | and you may at your option offer warranty protection in exchange for a 160 | fee. 161 | 162 | 2. You may modify your copy or copies of the Library or any portion 163 | of it, thus forming a work based on the Library, and copy and 164 | distribute such modifications or work under the terms of Section 1 165 | above, provided that you also meet all of these conditions: 166 | 167 | a) The modified work must itself be a software library. 168 | 169 | b) You must cause the files modified to carry prominent notices 170 | stating that you changed the files and the date of any change. 171 | 172 | c) You must cause the whole of the work to be licensed at no 173 | charge to all third parties under the terms of this License. 
174 | 175 | d) If a facility in the modified Library refers to a function or a 176 | table of data to be supplied by an application program that uses 177 | the facility, other than as an argument passed when the facility 178 | is invoked, then you must make a good faith effort to ensure that, 179 | in the event an application does not supply such function or 180 | table, the facility still operates, and performs whatever part of 181 | its purpose remains meaningful. 182 | 183 | (For example, a function in a library to compute square roots has 184 | a purpose that is entirely well-defined independent of the 185 | application. Therefore, Subsection 2d requires that any 186 | application-supplied function or table used by this function must 187 | be optional: if the application does not supply it, the square 188 | root function must still compute square roots.) 189 | 190 | These requirements apply to the modified work as a whole. If 191 | identifiable sections of that work are not derived from the Library, 192 | and can be reasonably considered independent and separate works in 193 | themselves, then this License, and its terms, do not apply to those 194 | sections when you distribute them as separate works. But when you 195 | distribute the same sections as part of a whole which is a work based 196 | on the Library, the distribution of the whole must be on the terms of 197 | this License, whose permissions for other licensees extend to the 198 | entire whole, and thus to each and every part regardless of who wrote 199 | it. 200 | 201 | Thus, it is not the intent of this section to claim rights or contest 202 | your rights to work written entirely by you; rather, the intent is to 203 | exercise the right to control the distribution of derivative or 204 | collective works based on the Library. 205 | 206 | In addition, mere aggregation of another work not based on the Library 207 | with the Library (or with a work based on the Library) on a volume of 208 | a storage or distribution medium does not bring the other work under 209 | the scope of this License. 210 | 211 | 3. You may opt to apply the terms of the ordinary GNU General Public 212 | License instead of this License to a given copy of the Library. To do 213 | this, you must alter all the notices that refer to this License, so 214 | that they refer to the ordinary GNU General Public License, version 2, 215 | instead of to this License. (If a newer version than version 2 of the 216 | ordinary GNU General Public License has appeared, then you can specify 217 | that version instead if you wish.) Do not make any other change in 218 | these notices. 219 | 220 | Once this change is made in a given copy, it is irreversible for 221 | that copy, so the ordinary GNU General Public License applies to all 222 | subsequent copies and derivative works made from that copy. 223 | 224 | This option is useful when you wish to copy part of the code of 225 | the Library into a program that is not a library. 226 | 227 | 4. You may copy and distribute the Library (or a portion or 228 | derivative of it, under Section 2) in object code or executable form 229 | under the terms of Sections 1 and 2 above provided that you accompany 230 | it with the complete corresponding machine-readable source code, which 231 | must be distributed under the terms of Sections 1 and 2 above on a 232 | medium customarily used for software interchange. 
233 | 234 | If distribution of object code is made by offering access to copy 235 | from a designated place, then offering equivalent access to copy the 236 | source code from the same place satisfies the requirement to 237 | distribute the source code, even though third parties are not 238 | compelled to copy the source along with the object code. 239 | 240 | 5. A program that contains no derivative of any portion of the 241 | Library, but is designed to work with the Library by being compiled or 242 | linked with it, is called a "work that uses the Library". Such a 243 | work, in isolation, is not a derivative work of the Library, and 244 | therefore falls outside the scope of this License. 245 | 246 | However, linking a "work that uses the Library" with the Library 247 | creates an executable that is a derivative of the Library (because it 248 | contains portions of the Library), rather than a "work that uses the 249 | library". The executable is therefore covered by this License. 250 | Section 6 states terms for distribution of such executables. 251 | 252 | When a "work that uses the Library" uses material from a header file 253 | that is part of the Library, the object code for the work may be a 254 | derivative work of the Library even though the source code is not. 255 | Whether this is true is especially significant if the work can be 256 | linked without the Library, or if the work is itself a library. The 257 | threshold for this to be true is not precisely defined by law. 258 | 259 | If such an object file uses only numerical parameters, data 260 | structure layouts and accessors, and small macros and small inline 261 | functions (ten lines or less in length), then the use of the object 262 | file is unrestricted, regardless of whether it is legally a derivative 263 | work. (Executables containing this object code plus portions of the 264 | Library will still fall under Section 6.) 265 | 266 | Otherwise, if the work is a derivative of the Library, you may 267 | distribute the object code for the work under the terms of Section 6. 268 | Any executables containing that work also fall under Section 6, 269 | whether or not they are linked directly with the Library itself. 270 | 271 | 6. As an exception to the Sections above, you may also combine or 272 | link a "work that uses the Library" with the Library to produce a 273 | work containing portions of the Library, and distribute that work 274 | under terms of your choice, provided that the terms permit 275 | modification of the work for the customer's own use and reverse 276 | engineering for debugging such modifications. 277 | 278 | You must give prominent notice with each copy of the work that the 279 | Library is used in it and that the Library and its use are covered by 280 | this License. You must supply a copy of this License. If the work 281 | during execution displays copyright notices, you must include the 282 | copyright notice for the Library among them, as well as a reference 283 | directing the user to the copy of this License. 
Also, you must do one 284 | of these things: 285 | 286 | a) Accompany the work with the complete corresponding 287 | machine-readable source code for the Library including whatever 288 | changes were used in the work (which must be distributed under 289 | Sections 1 and 2 above); and, if the work is an executable linked 290 | with the Library, with the complete machine-readable "work that 291 | uses the Library", as object code and/or source code, so that the 292 | user can modify the Library and then relink to produce a modified 293 | executable containing the modified Library. (It is understood 294 | that the user who changes the contents of definitions files in the 295 | Library will not necessarily be able to recompile the application 296 | to use the modified definitions.) 297 | 298 | b) Use a suitable shared library mechanism for linking with the 299 | Library. A suitable mechanism is one that (1) uses at run time a 300 | copy of the library already present on the user's computer system, 301 | rather than copying library functions into the executable, and (2) 302 | will operate properly with a modified version of the library, if 303 | the user installs one, as long as the modified version is 304 | interface-compatible with the version that the work was made with. 305 | 306 | c) Accompany the work with a written offer, valid for at 307 | least three years, to give the same user the materials 308 | specified in Subsection 6a, above, for a charge no more 309 | than the cost of performing this distribution. 310 | 311 | d) If distribution of the work is made by offering access to copy 312 | from a designated place, offer equivalent access to copy the above 313 | specified materials from the same place. 314 | 315 | e) Verify that the user has already received a copy of these 316 | materials or that you have already sent this user a copy. 317 | 318 | For an executable, the required form of the "work that uses the 319 | Library" must include any data and utility programs needed for 320 | reproducing the executable from it. However, as a special exception, 321 | the materials to be distributed need not include anything that is 322 | normally distributed (in either source or binary form) with the major 323 | components (compiler, kernel, and so on) of the operating system on 324 | which the executable runs, unless that component itself accompanies 325 | the executable. 326 | 327 | It may happen that this requirement contradicts the license 328 | restrictions of other proprietary libraries that do not normally 329 | accompany the operating system. Such a contradiction means you cannot 330 | use both them and the Library together in an executable that you 331 | distribute. 332 | 333 | 7. You may place library facilities that are a work based on the 334 | Library side-by-side in a single library together with other library 335 | facilities not covered by this License, and distribute such a combined 336 | library, provided that the separate distribution of the work based on 337 | the Library and of the other library facilities is otherwise 338 | permitted, and provided that you do these two things: 339 | 340 | a) Accompany the combined library with a copy of the same work 341 | based on the Library, uncombined with any other library 342 | facilities. This must be distributed under the terms of the 343 | Sections above. 
344 | 345 | b) Give prominent notice with the combined library of the fact 346 | that part of it is a work based on the Library, and explaining 347 | where to find the accompanying uncombined form of the same work. 348 | 349 | 8. You may not copy, modify, sublicense, link with, or distribute 350 | the Library except as expressly provided under this License. Any 351 | attempt otherwise to copy, modify, sublicense, link with, or 352 | distribute the Library is void, and will automatically terminate your 353 | rights under this License. However, parties who have received copies, 354 | or rights, from you under this License will not have their licenses 355 | terminated so long as such parties remain in full compliance. 356 | 357 | 9. You are not required to accept this License, since you have not 358 | signed it. However, nothing else grants you permission to modify or 359 | distribute the Library or its derivative works. These actions are 360 | prohibited by law if you do not accept this License. Therefore, by 361 | modifying or distributing the Library (or any work based on the 362 | Library), you indicate your acceptance of this License to do so, and 363 | all its terms and conditions for copying, distributing or modifying 364 | the Library or works based on it. 365 | 366 | 10. Each time you redistribute the Library (or any work based on the 367 | Library), the recipient automatically receives a license from the 368 | original licensor to copy, distribute, link with or modify the Library 369 | subject to these terms and conditions. You may not impose any further 370 | restrictions on the recipients' exercise of the rights granted herein. 371 | You are not responsible for enforcing compliance by third parties with 372 | this License. 373 | 374 | 11. If, as a consequence of a court judgment or allegation of patent 375 | infringement or for any other reason (not limited to patent issues), 376 | conditions are imposed on you (whether by court order, agreement or 377 | otherwise) that contradict the conditions of this License, they do not 378 | excuse you from the conditions of this License. If you cannot 379 | distribute so as to satisfy simultaneously your obligations under this 380 | License and any other pertinent obligations, then as a consequence you 381 | may not distribute the Library at all. For example, if a patent 382 | license would not permit royalty-free redistribution of the Library by 383 | all those who receive copies directly or indirectly through you, then 384 | the only way you could satisfy both it and this License would be to 385 | refrain entirely from distribution of the Library. 386 | 387 | If any portion of this section is held invalid or unenforceable under any 388 | particular circumstance, the balance of the section is intended to apply, 389 | and the section as a whole is intended to apply in other circumstances. 390 | 391 | It is not the purpose of this section to induce you to infringe any 392 | patents or other property right claims or to contest validity of any 393 | such claims; this section has the sole purpose of protecting the 394 | integrity of the free software distribution system which is 395 | implemented by public license practices. 
Many people have made 396 | generous contributions to the wide range of software distributed 397 | through that system in reliance on consistent application of that 398 | system; it is up to the author/donor to decide if he or she is willing 399 | to distribute software through any other system and a licensee cannot 400 | impose that choice. 401 | 402 | This section is intended to make thoroughly clear what is believed to 403 | be a consequence of the rest of this License. 404 | 405 | 12. If the distribution and/or use of the Library is restricted in 406 | certain countries either by patents or by copyrighted interfaces, the 407 | original copyright holder who places the Library under this License may add 408 | an explicit geographical distribution limitation excluding those countries, 409 | so that distribution is permitted only in or among countries not thus 410 | excluded. In such case, this License incorporates the limitation as if 411 | written in the body of this License. 412 | 413 | 13. The Free Software Foundation may publish revised and/or new 414 | versions of the Lesser General Public License from time to time. 415 | Such new versions will be similar in spirit to the present version, 416 | but may differ in detail to address new problems or concerns. 417 | 418 | Each version is given a distinguishing version number. If the Library 419 | specifies a version number of this License which applies to it and 420 | "any later version", you have the option of following the terms and 421 | conditions either of that version or of any later version published by 422 | the Free Software Foundation. If the Library does not specify a 423 | license version number, you may choose any version ever published by 424 | the Free Software Foundation. 425 | 426 | 14. If you wish to incorporate parts of the Library into other free 427 | programs whose distribution conditions are incompatible with these, 428 | write to the author to ask for permission. For software which is 429 | copyrighted by the Free Software Foundation, write to the Free 430 | Software Foundation; we sometimes make exceptions for this. Our 431 | decision will be guided by the two goals of preserving the free status 432 | of all derivatives of our free software and of promoting the sharing 433 | and reuse of software generally. 434 | 435 | NO WARRANTY 436 | 437 | 15. BECAUSE THE LIBRARY IS LICENSED FREE OF CHARGE, THERE IS NO 438 | WARRANTY FOR THE LIBRARY, TO THE EXTENT PERMITTED BY APPLICABLE LAW. 439 | EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR 440 | OTHER PARTIES PROVIDE THE LIBRARY "AS IS" WITHOUT WARRANTY OF ANY 441 | KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE 442 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 443 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE 444 | LIBRARY IS WITH YOU. SHOULD THE LIBRARY PROVE DEFECTIVE, YOU ASSUME 445 | THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 446 | 447 | 16. 
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN 448 | WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY 449 | AND/OR REDISTRIBUTE THE LIBRARY AS PERMITTED ABOVE, BE LIABLE TO YOU 450 | FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR 451 | CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE 452 | LIBRARY (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING 453 | RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A 454 | FAILURE OF THE LIBRARY TO OPERATE WITH ANY OTHER SOFTWARE), EVEN IF 455 | SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH 456 | DAMAGES. 457 | 458 | END OF TERMS AND CONDITIONS 459 | 460 | How to Apply These Terms to Your New Libraries 461 | 462 | If you develop a new library, and you want it to be of the greatest 463 | possible use to the public, we recommend making it free software that 464 | everyone can redistribute and change. You can do so by permitting 465 | redistribution under these terms (or, alternatively, under the terms of the 466 | ordinary General Public License). 467 | 468 | To apply these terms, attach the following notices to the library. It is 469 | safest to attach them to the start of each source file to most effectively 470 | convey the exclusion of warranty; and each file should have at least the 471 | "copyright" line and a pointer to where the full notice is found. 472 | 473 | <one line to give the library's name and a brief idea of what it does.> 474 | Copyright (C) <year> <name of author> 475 | 476 | This library is free software; you can redistribute it and/or 477 | modify it under the terms of the GNU Lesser General Public 478 | License as published by the Free Software Foundation; either 479 | version 2.1 of the License, or (at your option) any later version. 480 | 481 | This library is distributed in the hope that it will be useful, 482 | but WITHOUT ANY WARRANTY; without even the implied warranty of 483 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 484 | Lesser General Public License for more details. 485 | 486 | You should have received a copy of the GNU Lesser General Public 487 | License along with this library; if not, write to the Free Software 488 | Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 489 | USA 490 | 491 | Also add information on how to contact you by electronic and paper mail. 492 | 493 | You should also get your employer (if you work as a programmer) or your 494 | school, if any, to sign a "copyright disclaimer" for the library, if 495 | necessary. Here is a sample; alter the names: 496 | 497 | Yoyodyne, Inc., hereby disclaims all copyright interest in the 498 | library `Frob' (a library for tweaking knobs) written by James Random 499 | Hacker. 500 | 501 | <signature of Ty Coon>, 1 April 1990 502 | Ty Coon, President of Vice 503 | 504 | That's all there is to it! 505 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # DeepHAM: A global solution method for heterogeneous agent models with aggregate shocks 4 | 5 | Jiequn Han, Yucheng Yang, Weinan E 6 | 7 | [![arXiv](https://img.shields.io/badge/arXiv-2112.14377-b31b1b.svg)](https://arxiv.org/abs/2112.14377) 8 | [![SSRN](https://img.shields.io/badge/SSRN-3990409-133a6f.svg)](https://papers.ssrn.com/sol3/papers.cfm?abstract_id=3990409) 9 | [![PDF](https://img.shields.io/badge/PDF-8A2BE2)](https://yangycpku.github.io/files/DeepHAM_paper.pdf) 10 | 11 | Link to code repository: https://github.com/frankhan91/DeepHAM 12 | 13 |
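At a glance, the full workflow described in the sections below is the following (assuming a conda installation, and that the training and validation scripts are run from inside ``src/`` so that the relative paths in the configs, e.g. ``../data/KS_policy_N10_v1.mat``, resolve):

```
conda env create -f environment.yml   # creates the conda environment "deepham"
conda activate deepham
cd src
python train_KS.py       # solve the KS model under the default config
python validate_KS.py    # evaluate the Bellman error of the solution
```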
14 | 15 | 16 | ## Dependencies 17 | * Quick installation of the conda environment for Python: ``conda env create -f environment.yml`` 18 | 19 | ## Running 20 | ### Quick start for the Krusell-Smith (KS) model under default configs: 21 | To use DeepHAM to solve for the competitive equilibrium of the KS model, run 22 | ``` 23 | python train_KS.py 24 | ``` 25 | To evaluate the Bellman error of the solution of the KS model, run 26 | ``` 27 | python validate_KS.py 28 | ``` 29 | 30 | Sample scripts for solving the KS model on a Slurm system are provided in the folder ``src/slurm_scripts``. 31 | 32 | ### Solve the model in Fernández-Villaverde, Hurtado, and Nuño (2019): 33 | ``` 34 | python train_JFV.py 35 | ``` 36 | ``` 37 | python validate_JFV.py 38 | ``` 39 | Details on the model setup and algorithm can be found in our paper. 40 | 41 | ## Citation 42 | If you find this work helpful, please consider starring this repo and citing our paper using the following BibTeX entry. 43 | ```bibtex 44 | @article{HanYangE2021deepham, 45 | title={Deep{HAM}: A global solution method for heterogeneous agent models with aggregate shocks}, 46 | author={Han, Jiequn and Yang, Yucheng and E, Weinan}, 47 | journal={arXiv preprint arXiv:2112.14377}, 48 | year={2021} 49 | } 50 | ``` 51 | 52 | ## Contact 53 | Please contact us at jiequnhan@gmail.com and yucheng.yang@uzh.ch if you have any questions. 54 | -------------------------------------------------------------------------------- /data/JFVDSS_rho0p05.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankhan91/DeepHAM/887d59535288d8b95502f145842b86a5bfde5cfa/data/JFVDSS_rho0p05.mat -------------------------------------------------------------------------------- /data/JFVSSS_rho0p05.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankhan91/DeepHAM/887d59535288d8b95502f145842b86a5bfde5cfa/data/JFVSSS_rho0p05.mat -------------------------------------------------------------------------------- /data/KS_policy_N100_v1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankhan91/DeepHAM/887d59535288d8b95502f145842b86a5bfde5cfa/data/KS_policy_N100_v1.mat -------------------------------------------------------------------------------- /data/KS_policy_N10_v1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankhan91/DeepHAM/887d59535288d8b95502f145842b86a5bfde5cfa/data/KS_policy_N10_v1.mat -------------------------------------------------------------------------------- /data/KS_policy_N50_v1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankhan91/DeepHAM/887d59535288d8b95502f145842b86a5bfde5cfa/data/KS_policy_N50_v1.mat -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: deepham 2 | dependencies: 3 | - python=3.8 4 | - numpy 5 | - scipy 6 | - matplotlib 7 | - jupyterlab 8 | - tqdm 9 | - quantecon 10 | - pip 11 | - pip: 12 | - tensorflow==2.13 13 | - seaborn 14 | - pandas 15 | -------------------------------------------------------------------------------- /src/configs/JFV_DSS/game_nn_n10.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_agt": 10, 3
| "dt": 0.2, 4 | "mats_path": "../data/JFVDSS_rho0p05.mat", 5 | "with_ashock": false, 6 | "n_basic": 3, 7 | "_comment_basic": "k_cross, N, ishock", 8 | "n_fm": 1, 9 | "n_gm": 0, 10 | "init_with_bchmk": false, 11 | "init_const_share": 0.4, 12 | "dataset_config": { 13 | "n_path": 384, 14 | "t_burn": 4000, 15 | "value_sampling": "nn", 16 | "moving_average": 1.0 17 | }, 18 | "value_config": { 19 | "num_vnet": 2, 20 | "T": 5000, 21 | "t_count": 1300, 22 | "t_skip": 200, 23 | "_comment_v_data": "the above is about ValueDataSet", 24 | "num_epoch": 200, 25 | "lr": 1e-4, 26 | "net_width": [24, 24], 27 | "activation": "tanh", 28 | "batch_size": 64, 29 | "valid_size": 128, 30 | "_comment_v_learn": "the above is about learning" 31 | }, 32 | "policy_config": { 33 | "opt_type": "game", 34 | "update_init": true, 35 | "t_unroll": 200, 36 | "num_step": 8000, 37 | "freq_update_v": 1600, 38 | "lr_beg": 4e-4, 39 | "lr_end": 4e-4, 40 | "net_width": [24, 24], 41 | "activation": "tanh", 42 | "batch_size": 384, 43 | "valid_size": 384, 44 | "freq_valid": 200, 45 | "sgm_scale": 1, 46 | "_comment_p_learn": "the above is about learning", 47 | "T": 450, 48 | "t_sample": 300, 49 | "t_skip": 6, 50 | "epoch_resample": 3, 51 | "_comment_p_data": "the above is about PolicyDataSet" 52 | }, 53 | "gm_config": { 54 | "net_width": [12, 12], 55 | "activation": "tanh" 56 | }, 57 | "simul_config": { 58 | "n_path": 32, 59 | "T": 6000 60 | } 61 | } -------------------------------------------------------------------------------- /src/configs/JFV_DSS/game_nn_n50.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_agt": 50, 3 | "dt": 0.2, 4 | "mats_path": "../data/JFVDSS_rho0p05.mat", 5 | "with_ashock": false, 6 | "n_basic": 3, 7 | "_comment_basic": "k_cross, N, ishock", 8 | "n_fm": 1, 9 | "n_gm": 0, 10 | "init_with_bchmk": false, 11 | "init_const_share": 0.4, 12 | "dataset_config": { 13 | "n_path": 384, 14 | "t_burn": 16000, 15 | "value_sampling": "nn", 16 | "moving_average": 1.0 17 | }, 18 | "value_config": { 19 | "num_vnet": 3, 20 | "T": 5000, 21 | "t_count": 2000, 22 | "t_skip": 200, 23 | "_comment_v_data": "the above is about ValueDataSet", 24 | "num_epoch": 200, 25 | "lr": 1e-4, 26 | "net_width": [24, 24], 27 | "activation": "tanh", 28 | "batch_size": 128, 29 | "valid_size": 1024, 30 | "_comment_v_learn": "the above is about learning" 31 | }, 32 | "policy_config": { 33 | "opt_type": "game", 34 | "update_init": true, 35 | "t_unroll": 200, 36 | "num_step": 8000, 37 | "freq_update_v": 1600, 38 | "lr_beg": 4e-4, 39 | "lr_end": 4e-4, 40 | "net_width": [24, 24], 41 | "activation": "tanh", 42 | "batch_size": 384, 43 | "valid_size": 384, 44 | "freq_valid": 200, 45 | "sgm_scale": 1, 46 | "_comment_p_learn": "the above is about learning", 47 | "T": 450, 48 | "t_sample": 300, 49 | "t_skip": 6, 50 | "epoch_resample": 3, 51 | "_comment_p_data": "the above is about PolicyDataSet" 52 | }, 53 | "gm_config": { 54 | "net_width": [12, 12], 55 | "activation": "tanh" 56 | }, 57 | "simul_config": { 58 | "n_path": 32, 59 | "T": 6000 60 | } 61 | } -------------------------------------------------------------------------------- /src/configs/JFV_SSS/game_nn_n50.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_agt": 50, 3 | "dt": 0.2, 4 | "mats_path": "../data/JFVSSS_rho0p05.mat", 5 | "with_ashock": true, 6 | "n_basic": 3, 7 | "_comment_basic": "k_cross, N, ishock", 8 | "n_fm": 1, 9 | "n_gm": 0, 10 | "init_with_bchmk": false, 11 | "init_const_share": 0.4, 12 
| "dataset_config": { 13 | "n_path": 384, 14 | "t_burn": 16000, 15 | "value_sampling": "nn", 16 | "moving_average": 1.0 17 | }, 18 | "value_config": { 19 | "num_vnet": 3, 20 | "T": 5000, 21 | "t_count": 2000, 22 | "t_skip": 200, 23 | "_comment_v_data": "the above is about ValueDataSet", 24 | "num_epoch": 200, 25 | "lr": 1e-4, 26 | "net_width": [24, 24], 27 | "activation": "tanh", 28 | "batch_size": 128, 29 | "valid_size": 1024, 30 | "_comment_v_learn": "the above is about learning" 31 | }, 32 | "policy_config": { 33 | "opt_type": "game", 34 | "update_init": true, 35 | "t_unroll": 200, 36 | "num_step": 8000, 37 | "freq_update_v": 1600, 38 | "lr_beg": 4e-4, 39 | "lr_end": 4e-4, 40 | "net_width": [24, 24], 41 | "activation": "tanh", 42 | "batch_size": 384, 43 | "valid_size": 384, 44 | "freq_valid": 200, 45 | "sgm_scale": 1, 46 | "_comment_p_learn": "the above is about learning", 47 | "T": 450, 48 | "t_sample": 300, 49 | "t_skip": 6, 50 | "epoch_resample": 3, 51 | "_comment_p_data": "the above is about PolicyDataSet" 52 | }, 53 | "gm_config": { 54 | "net_width": [12, 12], 55 | "activation": "tanh" 56 | }, 57 | "simul_config": { 58 | "n_path": 16, 59 | "T": 12000 60 | } 61 | } -------------------------------------------------------------------------------- /src/configs/KS/game_nn_n10.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_agt": 10, 3 | "beta": 0.99, 4 | "mats_path": "../data/KS_policy_N10_v1.mat", 5 | "n_basic": 3, 6 | "_comment_basic": "k_cross, ishock, ashock", 7 | "n_fm": 1, 8 | "n_gm": 0, 9 | "init_with_bchmk": false, 10 | "init_const_share": 0.4, 11 | "dataset_config": { 12 | "n_path": 384, 13 | "t_burn": 200, 14 | "value_sampling": "nn", 15 | "moving_average": 1.0 16 | }, 17 | "value_config": { 18 | "num_vnet": 2, 19 | "T": 2000, 20 | "t_count": 800, 21 | "t_skip": 100, 22 | "_comment_v_data": "the above is about ValueDataSet", 23 | "num_epoch": 200, 24 | "lr": 1e-4, 25 | "net_width": [24, 24], 26 | "activation": "tanh", 27 | "batch_size": 64, 28 | "valid_size": 64, 29 | "_comment_v_learn": "the above is about learning" 30 | }, 31 | "policy_config": { 32 | "opt_type": "game", 33 | "update_init": true, 34 | "t_unroll": 150, 35 | "num_step": 10000, 36 | "freq_update_v": 2000, 37 | "lr_beg": 4e-4, 38 | "lr_end": 4e-4, 39 | "net_width": [24, 24], 40 | "activation": "tanh", 41 | "batch_size": 384, 42 | "valid_size": 384, 43 | "freq_valid": 500, 44 | "sgm_scale": 1, 45 | "_comment_p_learn": "the above is about learning", 46 | "T": 450, 47 | "t_sample": 200, 48 | "t_skip": 4, 49 | "epoch_resample": 3, 50 | "_comment_p_data": "the above is about PolicyDataSet" 51 | }, 52 | "gm_config": { 53 | "net_width": [12, 12], 54 | "activation": "tanh" 55 | }, 56 | "simul_config": { 57 | "n_path": 64, 58 | "T": 2000 59 | } 60 | } -------------------------------------------------------------------------------- /src/configs/KS/game_nn_n50.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_agt": 50, 3 | "beta": 0.99, 4 | "mats_path": "../data/KS_policy_N50_v1.mat", 5 | "n_basic": 3, 6 | "_comment_basic": "k_cross, ishock, ashock", 7 | "n_fm": 1, 8 | "n_gm": 0, 9 | "init_with_bchmk": false, 10 | "init_const_share": 0.4, 11 | "dataset_config": { 12 | "n_path": 384, 13 | "t_burn": 6000, 14 | "value_sampling": "nn", 15 | "moving_average": 1.0 16 | }, 17 | "value_config": { 18 | "num_vnet": 3, 19 | "T": 2000, 20 | "t_count": 800, 21 | "t_skip": 100, 22 | "_comment_v_data": "the above is about ValueDataSet", 
23 | "num_epoch": 200, 24 | "lr": 1e-4, 25 | "net_width": [24, 24], 26 | "activation": "tanh", 27 | "batch_size": 128, 28 | "valid_size": 512, 29 | "_comment_v_learn": "the above is about learning" 30 | }, 31 | "policy_config": { 32 | "opt_type": "game", 33 | "update_init": true, 34 | "t_unroll": 150, 35 | "num_step": 10000, 36 | "freq_update_v": 2000, 37 | "lr_beg": 4e-4, 38 | "lr_end": 4e-4, 39 | "net_width": [24, 24], 40 | "activation": "tanh", 41 | "batch_size": 384, 42 | "valid_size": 384, 43 | "freq_valid": 500, 44 | "sgm_scale": 1, 45 | "_comment_p_learn": "the above is about learning", 46 | "T": 450, 47 | "t_sample": 200, 48 | "t_skip": 4, 49 | "epoch_resample": 0, 50 | "_comment_p_data": "the above is about PolicyDataSet" 51 | }, 52 | "gm_config": { 53 | "net_width": [12, 12], 54 | "activation": "tanh" 55 | }, 56 | "simul_config": { 57 | "n_path": 64, 58 | "T": 2000 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/configs/KS/game_nn_n50_0fm1gm.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_agt": 50, 3 | "beta": 0.99, 4 | "mats_path": "../data/KS_policy_N50_v1.mat", 5 | "n_basic": 3, 6 | "_comment_basic": "k_cross, ishock, ashock", 7 | "n_fm": 0, 8 | "n_gm": 1, 9 | "init_with_bchmk": false, 10 | "init_const_share": 0.4, 11 | "dataset_config": { 12 | "n_path": 384, 13 | "t_burn": 6000, 14 | "value_sampling": "nn", 15 | "moving_average": 1.0 16 | }, 17 | "value_config": { 18 | "num_vnet": 3, 19 | "T": 2000, 20 | "t_count": 800, 21 | "t_skip": 100, 22 | "_comment_v_data": "the above is about ValueDataSet", 23 | "num_epoch": 200, 24 | "lr": 1e-4, 25 | "net_width": [24, 24], 26 | "activation": "tanh", 27 | "batch_size": 128, 28 | "valid_size": 512, 29 | "_comment_v_learn": "the above is about learning" 30 | }, 31 | "policy_config": { 32 | "opt_type": "game", 33 | "update_init": true, 34 | "t_unroll": 150, 35 | "num_step": 10000, 36 | "freq_update_v": 2000, 37 | "lr_beg": 4e-4, 38 | "lr_end": 4e-4, 39 | "net_width": [24, 24], 40 | "activation": "tanh", 41 | "batch_size": 384, 42 | "valid_size": 384, 43 | "freq_valid": 500, 44 | "sgm_scale": 1, 45 | "_comment_p_learn": "the above is about learning", 46 | "T": 450, 47 | "t_sample": 200, 48 | "t_skip": 4, 49 | "epoch_resample": 0, 50 | "_comment_p_data": "the above is about PolicyDataSet" 51 | }, 52 | "gm_config": { 53 | "net_width": [12, 12], 54 | "activation": "tanh" 55 | }, 56 | "simul_config": { 57 | "n_path": 64, 58 | "T": 2000 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | import scipy.io as sio 7 | import simulation_KS as KS 8 | import simulation_JFV as JFV 9 | 10 | EPSILON = 1e-3 11 | DTYPE = "float64" 12 | tf.keras.backend.set_floatx(DTYPE) 13 | if DTYPE == "float64": 14 | NP_DTYPE = np.float64 15 | elif DTYPE == "float32": 16 | NP_DTYPE = np.float32 17 | else: 18 | raise ValueError("Unknown dtype.") 19 | 20 | 21 | class NumpyEncoder(json.JSONEncoder): 22 | def default(self, o): # pylint: disable=E0202 23 | if isinstance(o, np.ndarray): 24 | return o.tolist() 25 | return json.JSONEncoder.default(self, o) 26 | 27 | 28 | class BasicDataSet(): 29 | def __init__(self, datadict=None): 30 | self.datadict, self.keys = None, None 31 | self.size, self.idx_in_epoch, self.epoch_used = None, None, 
None 32 | if datadict: 33 | self.update_datadict(datadict) 34 | 35 | def update_datadict(self, datadict): 36 | self.datadict = datadict 37 | self.keys = datadict.keys() 38 | size_list = [datadict[k].shape[0] for k in self.keys] 39 | for i in range(1, len(size_list)): 40 | assert size_list[i] == size_list[0], "The size does not match." 41 | self.size = size_list[0] 42 | self.shuffle() 43 | self.epoch_used = 0 44 | 45 | def shuffle(self): 46 | idx = np.arange(0, self.size) 47 | np.random.shuffle(idx) 48 | self.datadict = dict((k, self.datadict[k][idx]) for k in self.keys) 49 | self.idx_in_epoch = 0 50 | 51 | def next_batch(self, batch_size): 52 | if self.idx_in_epoch + batch_size > self.size: 53 | self.shuffle() 54 | self.epoch_used += 1 55 | idx = slice(self.idx_in_epoch, self.idx_in_epoch+batch_size) 56 | self.idx_in_epoch += batch_size 57 | return dict((k, self.datadict[k][idx]) for k in self.keys) 58 | 59 | 60 | class DataSetwithStats(BasicDataSet): 61 | def __init__(self, stats_keys, datadict=None): 62 | super().__init__(datadict) 63 | self.stats_keys = stats_keys 64 | self.stats_dict, self.stats_dict_tf = {}, {} 65 | for k in stats_keys: 66 | self.stats_dict[k] = None 67 | self.stats_dict_tf[k] = None 68 | 69 | def update_stats(self, data, key, ma): 70 | # data can be of shape B * d or B * n_agt * d 71 | axis_for_mean = tuple(list(range(len(data.shape)-1))) 72 | if self.stats_dict[key] is None: 73 | mean, std = data.mean(axis=axis_for_mean), data.std(axis=axis_for_mean) 74 | else: 75 | mean_new, std_new = data.mean(axis=axis_for_mean), data.std(axis=axis_for_mean) 76 | mean, std = self.stats_dict[key] 77 | mean = mean * ma + mean_new * (1-ma) 78 | std = std * ma + std_new * (1-ma) 79 | self.stats_dict[key] = (mean, std) 80 | self.stats_dict_tf[key] = (tf.constant(mean, dtype=DTYPE), tf.constant(std, dtype=DTYPE)) 81 | 82 | def normalize_data(self, data, key, withtf=False): 83 | if withtf: 84 | mean, std = self.stats_dict_tf[key] 85 | else: 86 | mean, std = self.stats_dict[key] 87 | return (data - mean) / std 88 | 89 | def unnormalize_data(self, data, key, withtf=False): 90 | if withtf: 91 | mean, std = self.stats_dict_tf[key] 92 | else: 93 | mean, std = self.stats_dict[key] 94 | return data * std + mean 95 | 96 | def save_stats(self, path): 97 | with open(os.path.join(path, "stats.json"), "w") as fp: 98 | json.dump(self.stats_dict, fp, cls=NumpyEncoder) 99 | 100 | def load_stats(self, path): 101 | with open(os.path.join(path, "stats.json"), "r") as fp: 102 | saved_stats = json.load(fp) 103 | for key in saved_stats: 104 | assert key in self.stats_dict, "The key of stats_dict does not match!" 
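# (added note) stats.json stores each (mean, std) pair as plain lists, written by
# NumpyEncoder in save_stats; the lines below convert them back to numpy arrays and
# mirror them as tf constants so that normalize_data/unnormalize_data also work
# inside TensorFlow graphs.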
105 | mean, std = saved_stats[key] 106 | mean, std = np.asarray(mean).astype(NP_DTYPE), np.asarray(std).astype(NP_DTYPE) 107 | self.stats_dict[key] = (mean, std) 108 | self.stats_dict_tf[key] = (tf.constant(mean, dtype=DTYPE), tf.constant(std, dtype=DTYPE)) 109 | 110 | 111 | class InitDataSet(DataSetwithStats): 112 | def __init__(self, mparam, config): 113 | super().__init__(stats_keys=["basic_s", "agt_s", "value"]) 114 | self.mparam = mparam 115 | self.config = config 116 | self.n_basic = config["n_basic"] 117 | self.n_fm = config["n_fm"] # fixed moments 118 | self.n_path = config["dataset_config"]["n_path"] 119 | self.t_burn = config["dataset_config"]["t_burn"] 120 | self.c_policy_const_share = lambda *args: config["init_const_share"] 121 | if not config["init_with_bchmk"]: 122 | assert config["policy_config"]["update_init"], \ 123 | "Must update init data during learning if bchmk policy is not used for sampling init" 124 | 125 | def update_with_burn(self, policy, policy_type, t_burn=None, state_init=None): 126 | if t_burn is None: 127 | t_burn = self.t_burn 128 | if state_init is None: 129 | state_init = self.datadict 130 | simul_data = self.simul_k_func( 131 | self.n_path, t_burn, self.mparam, 132 | policy, policy_type, state_init=state_init 133 | ) 134 | self.update_from_simul(simul_data) 135 | 136 | def update_from_simul(self, simul_data): 137 | init_datadict = dict((k, simul_data[k][..., -1].copy()) for k in self.keys) 138 | for k in self.keys: 139 | if len(init_datadict[k].shape) == 1: 140 | init_datadict[k] = init_datadict[k][:, None] # for macro init state like N in JFV 141 | notnan = ~(np.isnan(init_datadict["k_cross"]).any(axis=1)) 142 | if np.sum(~notnan) > 0: 143 | num_nan = np.sum(~notnan) 144 | num_total = notnan.shape[0] 145 | print("Warning: {} of {} init samples are NaN!".format(num_nan, num_total)) 146 | idx = np.where(notnan)[0] 147 | idx = np.concatenate([idx, idx[:num_nan]]) 148 | for k in self.keys: 149 | init_datadict[k] = init_datadict[k][idx] 150 | self.update_datadict(init_datadict) 151 | 152 | def process_vdatadict(self, v_datadict): 153 | idx_nan = np.logical_or( 154 | np.isnan(v_datadict["basic_s"]).any(axis=(1, 2)), 155 | np.isnan(v_datadict["value"]).any(axis=(1, 2)) 156 | ) 157 | ma = self.config["dataset_config"]["moving_average"] 158 | for key, array in v_datadict.items(): 159 | array = array[~idx_nan].astype(NP_DTYPE) 160 | self.update_stats(array, key, ma) 161 | v_datadict[key] = self.normalize_data(array, key) 162 | print("Average of total utility: %f." % (self.stats_dict["value"][0][0])) 163 | 164 | valid_size = self.config["value_config"]["valid_size"] 165 | n_sample = v_datadict["value"].shape[0] 166 | if valid_size > 0.2*n_sample: 167 | valid_size = int(0.2*n_sample) 168 | print("Valid size is reduced to %d due to the small data size!" % valid_size) 169 | print("The dataset has %d samples in total."
% n_sample) 170 | 171 | dataset = tf.data.Dataset.from_tensor_slices(v_datadict) 172 | dataset = dataset.shuffle(n_sample) 173 | train_size = n_sample - valid_size 174 | train_vdataset = dataset.skip(valid_size).shuffle(train_size, reshuffle_each_iteration=True) 175 | valid_vdataset = dataset.take(valid_size).batch(valid_size) 176 | return train_vdataset, valid_vdataset 177 | 178 | def get_policydataset(self, policy, policy_type, update_init=False): 179 | policy_config = self.config["policy_config"] 180 | simul_data = self.simul_k_func( 181 | self.n_path, policy_config["T"], self.mparam, policy, policy_type, 182 | state_init=self.datadict 183 | ) 184 | if update_init: 185 | self.update_from_simul(simul_data) 186 | p_datadict = {} 187 | idx_nan = False 188 | for k in self.keys: 189 | arr = simul_data[k].astype(NP_DTYPE) 190 | arr = arr[..., slice(-policy_config["t_sample"], -1, policy_config["t_skip"])] 191 | if len(arr.shape) == 3: 192 | arr = np.swapaxes(arr, 1, 2) 193 | arr = np.reshape(arr, (-1, self.mparam.n_agt)) 194 | if k != "ishock": 195 | idx_nan = np.logical_or(idx_nan, np.isnan(arr).any(axis=1)) 196 | else: 197 | arr = np.reshape(arr, (-1, 1)) 198 | if k != "ashock": 199 | idx_nan = np.logical_or(idx_nan, np.isnan(arr[:, 0])) 200 | p_datadict[k] = arr 201 | for k in self.keys: 202 | p_datadict[k] = p_datadict[k][~idx_nan] 203 | if policy_config["opt_type"] == "game": 204 | p_datadict = crazyshuffle(p_datadict) 205 | policy_ds = BasicDataSet(p_datadict) 206 | return policy_ds 207 | 208 | def simul_k_func(self, n_sample, T, mparam, c_policy, policy_type, state_init=None, shocks=None): 209 | raise NotImplementedError 210 | 211 | 212 | class KSInitDataSet(InitDataSet): 213 | def __init__(self, mparam, config): 214 | super().__init__(mparam, config) 215 | mats = sio.loadmat(mparam.mats_path) 216 | self.splines = KS.construct_bspl(mats) 217 | self.keys = ["k_cross", "ashock", "ishock"] 218 | self.k_policy_bchmk = lambda k_cross, ashock, ishock: KS.k_policy_bspl(k_cross, ashock, ishock, self.splines) 219 | # the first burn for initialization 220 | self.update_with_burn(self.k_policy_bchmk, "pde") 221 | 222 | def get_valuedataset(self, policy, policy_type, update_init=False): 223 | value_config = self.config["value_config"] 224 | t_count = value_config["t_count"] 225 | t_skip = value_config["t_skip"] 226 | simul_data = self.simul_k_func( 227 | self.n_path, value_config["T"], self.mparam, policy, policy_type, 228 | state_init=self.datadict 229 | ) 230 | if update_init: 231 | self.update_from_simul(simul_data) 232 | 233 | ashock, ishock = simul_data["ashock"], simul_data["ishock"] 234 | k_cross, csmp = simul_data["k_cross"], simul_data["csmp"] 235 | k_mean = np.mean(k_cross, axis=1, keepdims=True) 236 | # k_fm = self.compute_fm(k_cross) # n_path*n_fm*T 237 | discount = np.power(self.mparam.beta, np.arange(t_count)) 238 | util = np.log(csmp) 239 | 240 | basic_s = np.zeros(shape=[0, self.mparam.n_agt, self.n_basic+1]) 241 | agt_s = np.zeros(shape=[0, self.mparam.n_agt, 1]) 242 | value = np.zeros(shape=[0, self.mparam.n_agt, 1]) 243 | t_idx = 0 244 | while t_idx + t_count < value_config["T"]-1: 245 | k_tmp = k_cross[:, :, t_idx:t_idx+1] 246 | i_tmp = ishock[:, :, t_idx:t_idx+1] 247 | k_mean_tmp = np.repeat(k_mean[:, :, t_idx:t_idx+1], self.mparam.n_agt, axis=1) 248 | a_tmp = np.repeat(ashock[:, None, t_idx:t_idx+1], self.mparam.n_agt, axis=1) 249 | basic_s_tmp = np.concatenate([k_tmp, k_mean_tmp, a_tmp, i_tmp], axis=-1) 250 | v_tmp = np.sum(util[..., t_idx:t_idx+t_count]*discount, axis=-1, 
keepdims=True) 251 | 252 | basic_s = np.concatenate([basic_s, basic_s_tmp], axis=0) 253 | agt_s = np.concatenate([agt_s, k_tmp], axis=0) 254 | value = np.concatenate([value, v_tmp], axis=0) 255 | t_idx += t_skip 256 | 257 | v_datadict = {"basic_s": basic_s, "agt_s": agt_s, "value": value} 258 | train_vdataset, valid_vdataset = self.process_vdatadict(v_datadict) 259 | return train_vdataset, valid_vdataset 260 | 261 | def simul_k_func(self, n_sample, T, mparam, c_policy, policy_type, state_init=None, shocks=None): 262 | return KS.simul_k(n_sample, T, mparam, c_policy, policy_type, state_init, shocks) 263 | 264 | class JFVInitDataSet(InitDataSet): 265 | def __init__(self, mparam, config): 266 | super().__init__(mparam, config) 267 | self.with_ashock = mparam.with_ashock 268 | self.keys = ["k_cross", "N", "ishock"] 269 | mats = sio.loadmat(mparam.mats_path) 270 | if self.with_ashock: 271 | self.splines = JFV.construct_spl_SSS(mats, 'c') 272 | self.c_policy_bchmk = lambda k_cross, N, ishock: JFV.c_policy_spl_SSS(k_cross, N, ishock, self.splines) 273 | # state_init = {"k_cross": mparam.B_sss, "N": mparam.N_sss} 274 | else: 275 | self.splines = JFV.construct_spl_DSS(mats, 'c') 276 | self.c_policy_bchmk = lambda k_cross, N, ishock: JFV.c_policy_spl_DSS(k_cross, N, ishock, self.splines) 277 | # state_init = {"k_cross": mparam.k_dss, "N": mparam.N_dss} 278 | # the first burn for initialization 279 | self.update_with_burn(self.c_policy_bchmk, "pde") 280 | 281 | def get_valuedataset(self, policy, policy_type, update_init=False): 282 | value_config = self.config["value_config"] 283 | t_count = value_config["t_count"] 284 | t_skip = value_config["t_skip"] 285 | simul_data = self.simul_k_func( 286 | self.n_path, value_config["T"], self.mparam, policy, policy_type, 287 | state_init=self.datadict 288 | ) 289 | if update_init: 290 | self.update_from_simul(simul_data) 291 | 292 | ishock = simul_data["ishock"] 293 | k_cross, csmp = simul_data["k_cross"], simul_data["csmp"] 294 | k_mean = np.mean(k_cross, axis=1, keepdims=True) 295 | discount = np.power(self.mparam.beta, np.arange(t_count)) 296 | util = 1 - 1/csmp 297 | 298 | basic_s = np.zeros(shape=[0, self.mparam.n_agt, self.n_basic+1]) 299 | agt_s = np.zeros(shape=[0, self.mparam.n_agt, 1]) 300 | value = np.zeros(shape=[0, self.mparam.n_agt, 1]) 301 | t_idx = 0 302 | while t_idx + t_count < value_config["T"]-1: 303 | k_tmp = k_cross[:, :, t_idx:t_idx+1] 304 | i_tmp = ishock[:, :, t_idx:t_idx+1] 305 | k_mean_tmp = np.repeat(k_mean[:, :, t_idx:t_idx+1], self.mparam.n_agt, axis=1) 306 | N_tmp = np.repeat(simul_data["N"][:, None, t_idx:t_idx+1], self.mparam.n_agt, axis=1) 307 | basic_s_tmp = np.concatenate([k_tmp, k_mean_tmp, N_tmp, i_tmp], axis=-1) 308 | v_tmp = np.sum(util[..., t_idx:t_idx+t_count]*discount, axis=-1, keepdims=True) * self.mparam.dt 309 | 310 | basic_s = np.concatenate([basic_s, basic_s_tmp], axis=0) 311 | agt_s = np.concatenate([agt_s, k_tmp], axis=0) 312 | value = np.concatenate([value, v_tmp], axis=0) 313 | t_idx += t_skip 314 | 315 | v_datadict = {"basic_s": basic_s, "agt_s": agt_s, "value": value} 316 | train_vdataset, valid_vdataset = self.process_vdatadict(v_datadict) 317 | return train_vdataset, valid_vdataset 318 | 319 | # def get_policydataset(self, policy, policy_type, update_init=False): 320 | # # it only include k_cross and N for policy 321 | # policy_config = self.config["policy_config"] 322 | # simul_data = self.simul_k_func( 323 | # self.n_path, policy_config["T"], self.mparam, policy, policy_type, 324 | # 
state_init=self.datadict 325 | # ) 326 | # if update_init: 327 | # self.update_from_simul(simul_data) 328 | 329 | # k_cross = simul_data["k_cross"].astype(NP_DTYPE) 330 | # k_cross = np.swapaxes(k_cross[..., slice(-policy_config["t_sample"], -1, policy_config["t_skip"])], 1, 2) 331 | # k_cross = np.reshape(k_cross, (-1, self.mparam.n_agt)) 332 | # if policy_config["opt_type"] == "game": 333 | # k_cross = crazyshuffle(k_cross) 334 | # N = simul_data["N"][:, slice(-policy_config["t_sample"], -1, policy_config["t_skip"])].astype(NP_DTYPE) 335 | # N = np.reshape(N, (-1, 1)) 336 | # idx_nan = np.isnan(k_cross).any(axis=1) 337 | # idx_nan = np.logical_or(idx_nan, np.isnan(N[:, 0])) 338 | # p_datadict = {"k_cross": k_cross[~idx_nan], "N": N[~idx_nan]} 339 | # policy_ds = BasicDataSet(p_datadict) 340 | # return policy_ds 341 | 342 | def simul_k_func(self, n_sample, T, mparam, c_policy, policy_type, state_init=None, shocks=None): 343 | return JFV.simul_k(n_sample, T, mparam, c_policy, policy_type, state_init, shocks) 344 | 345 | 346 | def crazyshuffle(data): 347 | assert data["k_cross"].shape == data["ishock"].shape 348 | x, y = data["k_cross"].shape 349 | rows = np.indices((x, y))[0] 350 | cols = [np.random.permutation(y) for _ in range(x)] 351 | data["k_cross"] = data["k_cross"][rows, cols] 352 | data["ishock"] = data["ishock"][rows, cols] 353 | return data 354 | -------------------------------------------------------------------------------- /src/param.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class KSParam(): 5 | def __init__(self, n_agt, beta, mats_path): 6 | self.n_agt = n_agt # number of finite agents 7 | self.beta = beta # discount factor 8 | self.mats_path = mats_path # matrix from Matlab policy 9 | self.gamma = 1.0 # utility-function parameter 10 | self.alpha = 0.36 # share of capital in the production function 11 | self.delta = 0.025 # depreciation rate 12 | self.delta_a = 0.01 # (1-delta_a) is the productivity level in a bad state, 13 | # and (1+delta_a) is the productivity level in a good state 14 | self.mu = 0.15 # unemployment benefits as a share of wage 15 | self.l_bar = 1.0 / 0.9 # time endowment normalizes labor supply to 1 in a bad state 16 | 17 | self.epsilon_u = 0 # idiosyncratic shock if the agent is unemployed 18 | self.epsilon_e = 1 # idiosyncratic shock if the agent is employed 19 | 20 | self.ur_b = 0.1 # unemployment rate in a bad aggregate state 21 | self.er_b = (1 - self.ur_b) # employment rate in a bad aggregate state 22 | self.ur_g = 0.04 # unemployment rate in a good aggregate state 23 | self.er_g = (1 - self.ur_g) # employment rate in a good aggregate state 24 | 25 | # labor tax rate in bad and good aggregate states 26 | self.tau_b = self.mu * self.ur_b / (self.l_bar * self.er_b) 27 | self.tau_g = self.mu * self.ur_g / (self.l_bar * self.er_g) 28 | 29 | self.k_ss = ((1 / self.beta - (1 - self.delta)) / self.alpha) ** (1 / (self.alpha - 1)) 30 | # steady-state capital in a deterministic model with employment rate of 0.9 31 | # (i.e., l_bar*L=1, where L is aggregate labor in the paper) 32 | 33 | self.prob_trans = np.array( 34 | [ 35 | [0.525, 0.35, 0.03125, 0.09375], 36 | [0.038889, 0.836111, 0.002083, 0.122917], 37 | [0.09375, 0.03125, 0.291667, 0.583333], 38 | [0.009115, 0.115885, 0.024306, 0.850694] 39 | ] 40 | ) 41 | 42 | self.prob_ag = np.zeros([2, 2]) 43 | self.prob_ag[0, 0] = self.prob_trans[0, 0] + self.prob_trans[0, 1] 44 | self.prob_ag[1, 1] = self.prob_trans[3, 2] + 
self.prob_trans[3, 3] 45 | self.prob_ag[0, 1] = 1 - self.prob_ag[0, 0] 46 | self.prob_ag[1, 0] = 1 - self.prob_ag[1, 1] 47 | 48 | self.p_bb_uu = self.prob_trans[0, 0] / self.prob_ag[0, 0] 49 | self.p_bb_ue = 1 - self.p_bb_uu 50 | self.p_bb_ee = self.prob_trans[1, 1] / self.prob_ag[0, 0] 51 | self.p_bb_eu = 1 - self.p_bb_ee 52 | self.p_bg_uu = self.prob_trans[0, 2] / self.prob_ag[0, 1] 53 | self.p_bg_ue = 1 - self.p_bg_uu 54 | self.p_bg_ee = self.prob_trans[1, 3] / self.prob_ag[0, 1] 55 | self.p_bg_eu = 1 - self.p_bg_ee 56 | self.p_gb_uu = self.prob_trans[2, 0] / self.prob_ag[1, 0] 57 | self.p_gb_ue = 1 - self.p_gb_uu 58 | self.p_gb_ee = self.prob_trans[3, 1] / self.prob_ag[1, 0] 59 | self.p_gb_eu = 1 - self.p_gb_ee 60 | self.p_gg_uu = self.prob_trans[2, 2] / self.prob_ag[1, 1] 61 | self.p_gg_ue = 1 - self.p_gg_uu 62 | self.p_gg_ee = self.prob_trans[3, 3] / self.prob_ag[1, 1] 63 | self.p_gg_eu = 1 - self.p_gg_ee 64 | 65 | 66 | class DavilaParam(): 67 | def __init__(self, n_agt, beta, mats_path, ashock_type): 68 | self.n_agt = n_agt # number of finite agents 69 | self.beta = beta # discount factor 70 | self.mats_path = mats_path # matrix from Matlab policy 71 | self.ashock_type = ashock_type # None, or IAS, or CIS 72 | self.gamma = 2.0 # utility-function parameter 73 | self.alpha = 0.36 # share of capital in the production function 74 | self.delta = 0.08 # annual depreciation rate 75 | self.delta_a = 0.02 # (1-delta_a) is the productivity level in a bad state, 76 | # and (1+delta_a) is the productivity level in a good state 77 | self.k_ss = ((1 / self.beta - (1 - self.delta)) / self.alpha) ** (1 / (self.alpha - 1)) 78 | # steady-state capital in a complete market model 79 | self.amin = 0.0 # borrowing constraint 80 | 81 | self.epsilon_0 = 1.0 # idiosyncratic state 0 82 | self.epsilon_1 = 5.29 # idiosyncratic state 1 83 | self.epsilon_2 = 46.55 # idiosyncratic state 2 84 | 85 | self.ur = 0.49833222 # unemployment rate 86 | self.er1 = 0.44296197 87 | self.er2 = 1 - self.ur - self.er1 88 | 89 | if ashock_type == "CIS": 90 | self.trans_g = np.array([ 91 | [0.98, 0.02, 0.0], 92 | [0.009, 0.980, 0.011], 93 | [0.0, 0.083, 0.917] 94 | ]) 95 | self.trans_b = np.array([ 96 | [0.6512248557478917, 0.34877514425210826, 0.0], 97 | [0.978, 0.011, 0.011], 98 | [0.0, 0.083, 0.917] 99 | ]) 100 | self.ur_g = 0.28435478 # unemployment rate in good aggregate state 101 | self.er1_g = 0.63189951 102 | self.er2_g = 1 - self.ur_g - self.er1_g 103 | self.ur_b = 0.71230967 # unemployment rate in bad aggregate state 104 | self.er1_b = 0.25402444 105 | self.er2_b = 1 - self.ur_b - self.er1_b 106 | self.emp_g = self.epsilon_0*self.ur_g + self.epsilon_1*self.er1_g + self.epsilon_2*self.er2_g 107 | self.emp_b = self.epsilon_0*self.ur_b + self.epsilon_1*self.er1_b + self.epsilon_2*self.er2_b 108 | else: 109 | self.prob_trans = np.array([ 110 | [0.992, 0.008, 0.0], 111 | [0.009, 0.980, 0.011], 112 | [0.0, 0.083, 0.917] 113 | ]) 114 | self.emp_g = self.epsilon_0*self.ur + self.epsilon_1*self.er1 + self.epsilon_2*self.er2 115 | self.emp_b = self.emp_g 116 | 117 | 118 | class JFVParam(): 119 | def __init__(self, n_agt, dt, mats_path, with_ashock): 120 | self.n_agt = n_agt # number of finite agents 121 | self.dt = dt 122 | self.mats_path = mats_path 123 | self.with_ashock = with_ashock 124 | self.rho = 0.05 # discount rate 125 | self.rhohat = 0.04971 # discount rate for experts 126 | self.gamma = 2.0 # utility-function parameter 127 | self.alpha = 0.35 # share of capital in the production function 128 | self.delta = 0.1 #
depreciation rate 129 | self.la1 = 0.986 # transition probability from low to high 130 | self.la2 = 0.052 # transition probability from high to low 131 | self.z1 = 0.72 # low type labor productivity 132 | self.z2 = 1 + self.la2/self.la1 * (1-self.z1) # high type labor productivity 133 | if with_ashock: # SSS 134 | self.sigma = 0.0140 # sigma for aggregate capital quality shock 135 | self.sigma2 = self.sigma**2 # sigma^2 136 | else: # DSS 137 | self.sigma = 0 138 | self.sigma2 = 0 139 | self.beta = np.exp(-self.rho * self.dt) 140 | self.k_dss = 1.8718155468494229 141 | self.N_dss = 1.8214550560258203 142 | # self.B_sss = 1.8531388366299562 143 | # self.N_sss = 1.8267371165081967 144 | self.B_sss = 1.9903560465449313 145 | self.N_sss = 1.6837810785365737 146 | 147 | self.amin = 0.0 # borrowing constraint 148 | self.amax = 20.0 # max value of individual savings 149 | self.Bmin = 0.7 # relevant range for aggregate savings 150 | self.Bmax = 2.7 151 | self.Nmin = 1.2 # relevant range for aggregate equity 152 | self.Nmax = 3.2 153 | 154 | self.nval_a = 501 # number of points in amin-to-amax range (individual savings) 155 | self.nval_z = 2 # number of options for z (the idiosyncratic shock) 156 | self.nval_B = 4 # number of points in Bmin-to-Bmax range (aggregate savings), on the coarse grid for HJB 157 | self.nval_N = 51 # number of points in Nmin-to-Nmax range (aggregate equity), on the coarse grid for HJB 158 | 159 | self.nval_BB = 101 # finer grid, used for training the NN, for determining the visited range, and for the convergence check 160 | self.nval_NN = 101 # finer grid, used for training the NN, for determining the visited range, and for the convergence check 161 | 162 | self.da = (self.amax-self.amin)/(self.nval_a-1) # size of a jump 163 | self.dB = (self.Bmax-self.Bmin)/(self.nval_B-1) # size of B jump on the coarse grid 164 | self.dN = (self.Nmax-self.Nmin)/(self.nval_N-1) # size of N jump on the coarse grid 165 | self.dBB = (self.Bmax-self.Bmin)/(self.nval_BB-1) # size of B jump on the fine grid 166 | self.dNN = (self.Nmax-self.Nmin)/(self.nval_NN-1) # size of N jump on the fine grid 167 | 168 | # prices in the DSS 169 | self.r_dss = self.rhohat # DSS interest rate 170 | self.K_dss = ((self.rhohat + self.delta)/self.alpha)**(1.0/(self.alpha-1.0)) # DSS capital 171 | self.w_dss = (1.0-self.alpha)*(self.K_dss**self.alpha) # DSS wage 172 | -------------------------------------------------------------------------------- /src/policy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | from tqdm import tqdm 5 | import util 6 | import simulation_KS as KS 7 | import simulation_JFV as JFV 8 | 9 | EPSILON = 1e-3 10 | DTYPE = "float64" 11 | tf.keras.backend.set_floatx(DTYPE) 12 | if DTYPE == "float64": 13 | NP_DTYPE = np.float64 14 | elif DTYPE == "float32": 15 | NP_DTYPE = np.float32 16 | else: 17 | raise ValueError("Unknown dtype.") 18 | 19 | 20 | class PolicyTrainer(): 21 | def __init__(self, vtrainers, init_ds, policy_path=None): 22 | self.config = init_ds.config 23 | self.policy_config = self.config["policy_config"] 24 | self.t_unroll = self.policy_config["t_unroll"] 25 | self.vtrainers = vtrainers 26 | self.valid_size = self.policy_config["valid_size"] 27 | self.sgm_scale = self.policy_config["sgm_scale"] # scaling param in sigmoid 28 | self.init_ds = init_ds 29 | self.value_sampling = self.config["dataset_config"]["value_sampling"] 30 | self.num_vnet = len(vtrainers) 31 | self.mparam = init_ds.mparam 32 | d_in =
self.config["n_basic"] + self.config["n_fm"] + self.config["n_gm"] 33 | self.model = util.FeedforwardModel(d_in, 1, self.policy_config, name="p_net") 34 | if self.config["n_gm"] > 0: 35 | # TODO generalize to multi-dimensional agt_s 36 | self.gm_model = util.GeneralizedMomModel(1, self.config["n_gm"], self.config["gm_config"], name="p_gm") 37 | self.train_vars = None 38 | if policy_path is not None: 39 | self.model.load_weights_after_init(policy_path) 40 | if self.config["n_gm"] > 0: 41 | self.gm_model.load_weights_after_init(policy_path.replace(".h5", "_gm.h5")) 42 | self.init_ds.load_stats(os.path.dirname(policy_path)) 43 | self.discount = np.power(self.mparam.beta, np.arange(self.t_unroll)) 44 | # to be generated in the child class 45 | self.policy_ds = None 46 | 47 | @tf.function 48 | def prepare_state(self, input_data): 49 | if self.config["n_fm"] == 2: 50 | k_var = tf.math.reduce_variance(input_data["agt_s"], axis=-2, keepdims=True) 51 | k_var = tf.tile(k_var, [1, input_data["agt_s"].shape[-2], 1]) 52 | state = tf.concat([input_data["basic_s"], k_var], axis=-1) 53 | elif self.config["n_fm"] == 0: 54 | state = tf.concat([input_data["basic_s"][..., 0:1], input_data["basic_s"][..., 2:]], axis=-1) 55 | elif self.config["n_fm"] == 1: # so far always add k_mean in the basic_state 56 | state = input_data["basic_s"] 57 | if self.config["n_gm"] > 0: 58 | gm = self.gm_model(input_data["agt_s"]) 59 | state = tf.concat([state, gm], axis=-1) 60 | return state 61 | 62 | @tf.function 63 | def policy_fn(self, input_data): 64 | state = self.prepare_state(input_data) 65 | policy = tf.sigmoid(self.sgm_scale*self.model(state)) 66 | return policy 67 | 68 | @tf.function 69 | def loss(self, input_data): 70 | raise NotImplementedError 71 | 72 | def grad(self, input_data): 73 | with tf.GradientTape(persistent=True) as tape: 74 | output_dict = self.loss(input_data) 75 | train_vars = self.model.trainable_variables 76 | if self.config["n_gm"] > 0: 77 | train_vars += self.gm_model.trainable_variables 78 | self.train_vars = train_vars 79 | grad = tape.gradient( 80 | output_dict["m_util"], 81 | train_vars, 82 | unconnected_gradients=tf.UnconnectedGradients.ZERO, 83 | ) 84 | del tape 85 | return grad, output_dict["k_end"] 86 | 87 | @tf.function 88 | def train_step(self, train_data): 89 | grad, k_end = self.grad(train_data) 90 | self.optimizer.apply_gradients( 91 | zip(grad, self.train_vars) 92 | ) 93 | return k_end 94 | 95 | def train(self, num_step=None, batch_size=None): 96 | lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( 97 | self.policy_config["lr_beg"], 98 | decay_steps=num_step, 99 | decay_rate=self.policy_config["lr_end"] / self.policy_config["lr_beg"], 100 | staircase=False, 101 | ) 102 | assert batch_size <= self.valid_size, "The valid size should be no smaller than batch_size." 
103 | self.optimizer = tf.keras.optimizers.Adam( # pylint: disable=W0201 104 | learning_rate=lr_schedule, epsilon=1e-8, 105 | beta_1=0.99, beta_2=0.99 106 | ) 107 | 108 | # assume valid_size = n_path in self.init_ds 109 | valid_data = dict((k, self.init_ds.datadict[k].astype(NP_DTYPE)) for k in self.init_ds.keys) 110 | ashock, ishock = self.simul_shocks( 111 | self.valid_size, self.t_unroll, self.mparam, 112 | state_init=self.init_ds.datadict 113 | ) 114 | valid_data["ashock"] = ashock.astype(NP_DTYPE) 115 | valid_data["ishock"] = ishock.astype(NP_DTYPE) 116 | 117 | freq_valid = self.policy_config["freq_valid"] 118 | n_epoch = num_step // freq_valid 119 | update_init = False 120 | for n in range(n_epoch): 121 | for step in tqdm(range(freq_valid)): 122 | train_data = self.sampler(batch_size, update_init) 123 | k_end = self.train_step(train_data) 124 | n_step = n*freq_valid + step 125 | if self.value_sampling != "bchmk" and n_step % self.policy_config["freq_update_v"] == 0 and n_step > 0: 126 | update_init = self.policy_config["update_init"] 127 | train_vds, valid_vds = self.get_valuedataset(update_init) 128 | for vtr in self.vtrainers: 129 | vtr.train( 130 | train_vds, valid_vds, 131 | self.config["value_config"]["num_epoch"], 132 | self.config["value_config"]["batch_size"] 133 | ) 134 | val_output = self.loss(valid_data) 135 | print( 136 | "Step: %d, valid util: %g, k_end: %g" % 137 | (freq_valid*(n+1), -val_output["m_util"], k_end) 138 | ) 139 | 140 | def save_model(self, path="policy_model"): 141 | self.model.save_weights(path) 142 | self.init_ds.save_stats(os.path.dirname(path)) 143 | if self.config["n_gm"] > 0: 144 | self.gm_model.save_weights(path.replace(".h5", "_gm.h5")) 145 | 146 | def simul_shocks(self, n_sample, T, mparam, state_init): 147 | raise NotImplementedError 148 | 149 | def sampler(self, batch_size, update_init=False): 150 | train_data = self.policy_ds.next_batch(batch_size) 151 | ashock, ishock = self.simul_shocks(batch_size, self.t_unroll, self.mparam, train_data) 152 | train_data["ashock"] = ashock.astype(NP_DTYPE) 153 | train_data["ishock"] = ishock.astype(NP_DTYPE) 154 | # TODO test the effect of epoch_resample 155 | if self.policy_ds.epoch_used > self.policy_config["epoch_resample"]: 156 | self.update_policydataset(update_init) 157 | return train_data 158 | 159 | def update_policydataset(self, update_init=False): 160 | raise NotImplementedError 161 | 162 | def get_valuedataset(self, update_init=False): 163 | raise NotImplementedError 164 | 165 | 166 | class KSPolicyTrainer(PolicyTrainer): 167 | def __init__(self, vtrainers, init_ds, policy_path=None): 168 | super().__init__(vtrainers, init_ds, policy_path) 169 | if self.config["init_with_bchmk"]: 170 | init_policy = self.init_ds.k_policy_bchmk 171 | policy_type = "pde" 172 | else: 173 | init_policy = self.init_ds.c_policy_const_share 174 | policy_type = "nn_share" 175 | self.policy_ds = self.init_ds.get_policydataset(init_policy, policy_type, update_init=False) 176 | 177 | @tf.function 178 | def loss(self, input_data): 179 | k_cross = input_data["k_cross"] 180 | ashock, ishock = input_data["ashock"], input_data["ishock"] 181 | util_sum = 0 182 | 183 | for t in range(self.t_unroll): 184 | k_mean = tf.reduce_mean(k_cross, axis=1, keepdims=True) 185 | k_mean_tmp = tf.tile(k_mean, [1, self.mparam.n_agt]) 186 | k_mean_tmp = tf.expand_dims(k_mean_tmp, axis=-1) 187 | i_tmp = ishock[:, :, t:t+1] # n_path*n_agt*1 188 | a_tmp = tf.tile(ashock[:, t:t+1], [1, self.mparam.n_agt]) 189 | a_tmp = tf.expand_dims(a_tmp, axis=2) # 
n_path*n_agt*1 190 | basic_s_tmp = tf.concat([tf.expand_dims(k_cross, axis=-1), k_mean_tmp, a_tmp, i_tmp], axis=-1) 191 | basic_s_tmp = self.init_ds.normalize_data(basic_s_tmp, key="basic_s", withtf=True) 192 | full_state_dict = { 193 | "basic_s": basic_s_tmp, 194 | "agt_s": self.init_ds.normalize_data(tf.expand_dims(k_cross, axis=-1), key="agt_s", withtf=True) 195 | } 196 | if t == self.t_unroll - 1: 197 | value = 0 198 | for vtr in self.vtrainers: 199 | value += self.init_ds.unnormalize_data( 200 | vtr.value_fn(full_state_dict)[..., 0], key="value", withtf=True) 201 | value /= self.num_vnet 202 | util_sum += self.discount[t]*value 203 | continue 204 | 205 | c_share = self.policy_fn(full_state_dict)[..., 0] 206 | if self.policy_config["opt_type"] == "game": 207 | # optimizing agent 0 only 208 | c_share = tf.concat([c_share[:, 0:1], tf.stop_gradient(c_share[:, 1:])], axis=1) 209 | # labor tax rate - depend on ashock 210 | tau = tf.where(ashock[:, t:t+1] < 1, self.mparam.tau_b, self.mparam.tau_g) 211 | # total labor supply - depend on ashock 212 | emp = tf.where( 213 | ashock[:, t:t+1] < 1, 214 | self.mparam.l_bar*self.mparam.er_b, 215 | self.mparam.l_bar*self.mparam.er_g 216 | ) 217 | tau, emp = tf.cast(tau, DTYPE), tf.cast(emp, DTYPE) 218 | R = 1 - self.mparam.delta + ashock[:, t:t+1] * self.mparam.alpha*(k_mean / emp)**(self.mparam.alpha-1) 219 | wage = ashock[:, t:t+1]*(1-self.mparam.alpha)*(k_mean / emp)**(self.mparam.alpha) 220 | wealth = R * k_cross + (1-tau)*wage*self.mparam.l_bar*ishock[:, :, t] + \ 221 | self.mparam.mu*wage*(1-ishock[:, :, t]) 222 | csmp = tf.clip_by_value(c_share * wealth, EPSILON, wealth-EPSILON) 223 | k_cross = wealth - csmp 224 | util_sum += self.discount[t] * tf.math.log(csmp) 225 | 226 | if self.policy_config["opt_type"] == "socialplanner": 227 | output_dict = {"m_util": -tf.reduce_mean(util_sum), "k_end": tf.reduce_mean(k_cross)} 228 | elif self.policy_config["opt_type"] == "game": 229 | # optimizing agent 0 only 230 | output_dict = {"m_util": -tf.reduce_mean(util_sum[:, 0]), "k_end": tf.reduce_mean(k_cross)} 231 | return output_dict 232 | 233 | def update_policydataset(self, update_init=False): 234 | self.policy_ds = self.init_ds.get_policydataset(self.current_c_policy, "nn_share", update_init) 235 | 236 | def get_valuedataset(self, update_init=False): 237 | return self.init_ds.get_valuedataset(self.current_c_policy, "nn_share", update_init) 238 | 239 | def current_c_policy(self, k_cross, ashock, ishock): 240 | k_mean = np.mean(k_cross, axis=1, keepdims=True) 241 | k_mean = np.repeat(k_mean, self.mparam.n_agt, axis=1) 242 | ashock = np.repeat(ashock, self.mparam.n_agt, axis=1) 243 | basic_s = np.stack([k_cross, k_mean, ashock, ishock], axis=-1) 244 | basic_s = self.init_ds.normalize_data(basic_s, key="basic_s") 245 | basic_s = basic_s.astype(NP_DTYPE) 246 | full_state_dict = { 247 | "basic_s": basic_s, 248 | "agt_s": self.init_ds.normalize_data(k_cross[:, :, None], key="agt_s") 249 | } 250 | c_share = self.policy_fn(full_state_dict)[..., 0] 251 | return c_share 252 | 253 | def simul_shocks(self, n_sample, T, mparam, state_init): 254 | return KS.simul_shocks(n_sample, T, mparam, state_init) 255 | 256 | 257 | class JFVPolicyTrainer(PolicyTrainer): 258 | def __init__(self, vtrainers, init_ds, policy_path=None): 259 | super().__init__(vtrainers, init_ds, policy_path) 260 | if self.config["init_with_bchmk"]: 261 | init_policy = self.init_ds.c_policy_bchmk 262 | policy_type = "pde" 263 | else: 264 | init_policy = self.init_ds.c_policy_const_share 265 | 
policy_type = "nn_share" 266 | self.policy_ds = self.init_ds.get_policydataset(init_policy, policy_type, update_init=False) 267 | self.with_ashock = self.mparam.with_ashock 268 | 269 | @tf.function 270 | def loss(self, input_data): 271 | k_cross, N = input_data["k_cross"], input_data["N"] 272 | ashock, ishock = input_data["ashock"], input_data["ishock"] 273 | util_sum = 0 274 | 275 | for t in range(self.t_unroll): 276 | k_mean = tf.reduce_mean(k_cross, axis=1, keepdims=True) 277 | k_mean_tmp = tf.tile(k_mean, [1, self.mparam.n_agt]) 278 | k_mean_tmp = tf.expand_dims(k_mean_tmp, axis=-1) 279 | i_tmp = ishock[:, :, t:t+1] 280 | N_tmp = tf.tile(N, [1, self.mparam.n_agt]) 281 | N_tmp = tf.expand_dims(N_tmp, axis=-1) # n_path*n_agt*1 282 | basic_s_tmp = tf.concat([tf.expand_dims(k_cross, axis=-1), k_mean_tmp, N_tmp, i_tmp], axis=-1) 283 | basic_s_tmp = self.init_ds.normalize_data(basic_s_tmp, key="basic_s", withtf=True) 284 | full_state_dict = { 285 | "basic_s": basic_s_tmp, 286 | "agt_s": self.init_ds.normalize_data(tf.expand_dims(k_cross, axis=-1), key="agt_s", withtf=True) 287 | } 288 | if t == self.t_unroll - 1: 289 | value = 0 290 | for vtr in self.vtrainers: 291 | value += self.init_ds.unnormalize_data( 292 | vtr.value_fn(full_state_dict)[..., 0], key="value", withtf=True) 293 | value /= self.num_vnet 294 | util_sum = util_sum * self.mparam.dt + self.discount[t]*value 295 | continue 296 | 297 | c_share = self.policy_fn(full_state_dict)[..., 0] 298 | if self.policy_config["opt_type"] == "game": 299 | # optimizing agent 0 only 300 | c_share = tf.concat([c_share[:, 0:1], tf.stop_gradient(c_share[:, 1:])], axis=1) 301 | 302 | K = N + k_mean 303 | wage_unit = (1 - self.mparam.alpha) * K**self.mparam.alpha 304 | r = self.mparam.alpha * K**(self.mparam.alpha-1) - self.mparam.delta - self.mparam.sigma2*K/N 305 | wage = (ishock[:, :, t] * (self.mparam.z2-self.mparam.z1) + self.mparam.z1) * wage_unit # map 0/1 to z1/z2 306 | wealth = (1 + r*self.mparam.dt) * k_cross + wage * self.mparam.dt 307 | csmp = tf.clip_by_value(c_share * wealth / self.mparam.dt, EPSILON, wealth/self.mparam.dt-EPSILON) 308 | k_cross = wealth - csmp * self.mparam.dt 309 | dN_drift = self.mparam.dt * (self.mparam.alpha * K**(self.mparam.alpha-1) - self.mparam.delta - \ 310 | self.mparam.rhohat - self.mparam.sigma2*(-k_mean/N)*(K/N))*N 311 | dN_diff = K * ashock[:, t:t+1] 312 | N = tf.maximum(N + dN_drift + dN_diff, 0.01) 313 | util_sum += self.discount[t] * (1 - 1/csmp) 314 | 315 | if self.policy_config["opt_type"] == "socialplanner": 316 | output_dict = {"m_util": -tf.reduce_mean(util_sum), "k_end": tf.reduce_mean(k_cross)} 317 | elif self.policy_config["opt_type"] == "game": 318 | # optimizing agent 0 only 319 | output_dict = {"m_util": -tf.reduce_mean(util_sum[:, 0]), "k_end": tf.reduce_mean(k_cross)} 320 | return output_dict 321 | 322 | def update_policydataset(self, update_init=False): 323 | # self.policy_ds = self.init_ds.get_policydataset(self.init_ds.c_policy_bchmk, "pde", update_init) 324 | self.policy_ds = self.init_ds.get_policydataset(self.current_c_policy, "nn_share", update_init) 325 | 326 | def get_valuedataset(self, update_init=False): 327 | return self.init_ds.get_valuedataset(self.current_c_policy, "nn_share", update_init) 328 | 329 | def current_c_policy(self, k_cross, N, ishock): 330 | k_mean = np.mean(k_cross, axis=1, keepdims=True) 331 | k_mean = np.repeat(k_mean, self.mparam.n_agt, axis=1) 332 | N_tmp = np.repeat(N, self.mparam.n_agt, axis=1) 333 | basic_s = np.stack([k_cross, k_mean, N_tmp, ishock], 
axis=-1) 334 | basic_s = self.init_ds.normalize_data(basic_s, key="basic_s") 335 | basic_s = basic_s.astype(NP_DTYPE) 336 | full_state_dict = { 337 | "basic_s": basic_s, 338 | "agt_s": self.init_ds.normalize_data(k_cross[:, :, None], key="agt_s") 339 | } 340 | c_share = self.policy_fn(full_state_dict)[..., 0] 341 | return c_share 342 | 343 | def simul_shocks(self, n_sample, T, mparam, state_init): 344 | return JFV.simul_shocks(n_sample, T, mparam, state_init) 345 | -------------------------------------------------------------------------------- /src/simulation_JFV.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.interpolate import RectBivariateSpline 3 | from scipy.interpolate import interp1d 4 | 5 | EPSILON = 1e-3 6 | 7 | # simulate Poisson in discrete time 8 | def simul_shocks(n_sample, T, mparam, state_init=None): 9 | n_agt = mparam.n_agt 10 | ashock = mparam.dt**0.5*np.random.normal(0, mparam.sigma, [n_sample, T]) 11 | ishock = np.ones([n_sample, n_agt, T]) 12 | if state_init: 13 | ishock[..., 0] = state_init["ishock"] 14 | else: 15 | ur_rate = mparam.la2/(mparam.la1 + mparam.la2)*np.ones([n_sample, n_agt]) 16 | rand = np.random.uniform(0, 1, size=(n_sample, n_agt)) 17 | ishock[rand < ur_rate, 0] = 0 18 | 19 | for t in range(1, T): 20 | y_agt = ishock[:, :, t - 1] 21 | ur_rate = (1 - y_agt) * (1 - mparam.la1 * mparam.dt) # unemployed now, (1-lambda1*dt) to remain unemployed 22 | ur_rate += y_agt * mparam.la2 * mparam.dt # employed now, lambda2*dt to become unemployed 23 | rand = np.random.uniform(0, 1, size=(n_sample, n_agt)) 24 | ishock[rand < ur_rate, t] = 0 25 | 26 | return ashock, ishock 27 | 28 | def simul_k(n_sample, T, mparam, c_policy, policy_type, state_init=None, shocks=None): 29 | # policy_type: "pde" or "nn_share" 30 | # return k_cross [n_sample, n_agt, T] 31 | assert policy_type in ["pde", "nn_share"], "Invalid policy type" 32 | n_agt = mparam.n_agt 33 | if shocks: 34 | ashock, ishock = shocks 35 | assert n_sample == ashock.shape[0], "n_sample is inconsistent with given shocks." 36 | assert T == ashock.shape[1], "T is inconsistent with given shocks." 37 | if state_init: 38 | assert np.array_equal(ishock[..., 0], state_init["ishock"]), \ 39 | "Shock inputs are inconsistent with state_init" 40 | else: 41 | ashock, ishock = simul_shocks(n_sample, T, mparam, state_init) 42 | k_cross = np.zeros([n_sample, n_agt, T]) 43 | B, N = np.zeros([n_sample, T]), np.zeros([n_sample, T]) 44 | csmp = np.zeros([n_sample, n_agt, T-1]) 45 | if state_init: 46 | assert n_sample == state_init["k_cross"].shape[0], "n_sample is inconsistent with state_init." 
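# initialize period 0 from state_init: individual capital k_cross[:, :, 0] and aggregate
# equity N[:, 0] are taken directly, while aggregate bonds B[:, 0] are the cross-agent mean of k_cross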
47 | k_cross[:, :, 0] = state_init["k_cross"] 48 | N[:, 0:1] = state_init["N"] 49 | B[:, 0] = np.mean(k_cross[:, :, 0], axis=-1) 50 | else: 51 | if mparam.with_ashock: 52 | k_cross[:, :, 0] = mparam.B_sss 53 | N[:, 0] = mparam.N_sss 54 | B[:, 0] = mparam.B_sss 55 | else: 56 | k_cross[:, :, 0] = mparam.k_dss 57 | N[:, 0] = mparam.N_dss 58 | B[:, 0] = mparam.k_dss 59 | 60 | for t in range(1, T): 61 | K = B[:, t-1] + N[:, t-1] 62 | wage_unit = (1 - mparam.alpha) * K[:, None]**mparam.alpha 63 | wage = (ishock[:, :, t-1] * (mparam.z2-mparam.z1) + mparam.z1) * wage_unit # map 0 to z1 and 1 to z2 64 | r = mparam.alpha * K[:, None]**(mparam.alpha-1) - mparam.delta - mparam.sigma2*K[:, None]/N[:, t-1:t] 65 | wealth = (1 + r*mparam.dt) * k_cross[:, :, t-1] + wage * mparam.dt 66 | if policy_type == "pde": 67 | # to avoid negative wealth 68 | csmp[:, :, t-1] = np.minimum( 69 | c_policy(k_cross[:, :, t-1], N[:, t-1:t], ishock[:, :, t-1]), 70 | wealth/mparam.dt-EPSILON) 71 | elif policy_type == "nn_share": 72 | csmp[:, :, t-1] = c_policy(k_cross[:, :, t-1], N[:, t-1:t], ishock[:, :, t-1]) * (wealth / mparam.dt) 73 | k_cross[:, :, t] = wealth - csmp[:, :, t-1] * mparam.dt 74 | B[:, t] = np.mean(k_cross[:, :, t], axis=1) 75 | dN_drift = mparam.dt * (mparam.alpha * K**(mparam.alpha-1) - mparam.delta - mparam.rhohat - \ 76 | mparam.sigma2*(-B[:, t-1]/N[:, t-1])*(K/N[:, t-1]))*N[:, t-1] 77 | dN_diff = K * ashock[:, t-1] 78 | N[:, t] = N[:, t-1] + dN_drift + dN_diff 79 | 80 | # print(B.max(), B.min(), N.max(), N.min(), csmp.min(), csmp.max()) 81 | # if k_cross.min() < 0 or N.min() < 0: 82 | # print(k_cross.min(), N.min()) 83 | simul_data = {"k_cross": k_cross, "csmp": csmp, "B": B, "N": N, "ishock": ishock} 84 | return simul_data 85 | 86 | 87 | def c_policy_spl_DSS(k_cross, N, ishock, splines): # pylint: disable=W0613 88 | c = np.zeros_like(k_cross) 89 | idx = (ishock == 0) 90 | c[idx] = splines["y0"](k_cross[idx]) 91 | idx = (ishock == 1) 92 | c[idx] = splines["y1"](k_cross[idx]) 93 | return c 94 | 95 | 96 | def construct_spl_DSS(mats, key): 97 | # mats is saved in Matlab through 98 | # save 'ss_for_JQ.mat' aa zz V c g_ss -mat (here z is idiosyncratic income level) 99 | splines = { 100 | 'y0': interp1d(mats['aa'][:, 0], mats[key][:, 0], kind='cubic', fill_value="extrapolate"), 101 | 'y1': interp1d(mats['aa'][:, 1], mats[key][:, 1], kind='cubic', fill_value="extrapolate"), 102 | } 103 | return splines 104 | 105 | 106 | def c_policy_spl_SSS(k_cross, N, ishock, splines): 107 | # this part is simplified than the notebook, considering that B is always <=Bmax (but possibly 1) & (ishock == 0)) 104 | k_tmp, km_tmp = k_cross[idx], k_mean[idx] 105 | k_next[idx] = splines['10'](k_tmp, km_tmp, grid=False) 106 | 107 | idx = ((ashock > 1) & (ishock == 1)) 108 | k_tmp, km_tmp = k_cross[idx], k_mean[idx] 109 | k_next[idx] = splines['11'](k_tmp, km_tmp, grid=False) 110 | 111 | return k_next 112 | 113 | 114 | def construct_bspl(mats): 115 | # mats is saved in Matlab through 116 | # "save(filename, 'kprime', 'k', 'km', 'agshock', 'idshock', 'kmts', 'kcross');" 117 | splines = { 118 | '00': RectBivariateSpline(mats['k'], mats['km'], mats['kprime'][:, :, 0, 0]), 119 | '01': RectBivariateSpline(mats['k'], mats['km'], mats['kprime'][:, :, 0, 1]), 120 | '10': RectBivariateSpline(mats['k'], mats['km'], mats['kprime'][:, :, 1, 0]), 121 | '11': RectBivariateSpline(mats['k'], mats['km'], mats['kprime'][:, :, 1, 1]), 122 | } 123 | return splines 124 | -------------------------------------------------------------------------------- 
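A minimal usage sketch of the benchmark KS policy splines defined above (construct_bspl and k_policy_bspl). The .mat file name is taken from the data folder of this repo; the state arrays are made-up toy inputs, with shapes following the [n_sample, n_agt] convention used in dataset.py:

import numpy as np
import scipy.io as sio
import simulation_KS as KS

mats = sio.loadmat("../data/KS_policy_N50_v1.mat")  # Matlab benchmark policy, see the comment in construct_bspl
splines = KS.construct_bspl(mats)  # four RectBivariateSpline objects keyed by (ashock, ishock) state
k_cross = np.full((2, 50), 40.0)   # toy [n_sample, n_agt] individual capital
ashock = np.full((2, 1), 1.01)     # good aggregate state, i.e. 1 + delta_a with delta_a = 0.01
ishock = np.ones((2, 50))          # toy idiosyncratic states: all agents employed
k_next = KS.k_policy_bspl(k_cross, ashock, ishock, splines)  # next-period capital, same shape as k_cross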
/src/slurm_scripts/KS_1fm.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes=1 3 | #SBATCH --ntasks-per-node=1 4 | #SBATCH --ntasks-per-socket=1 5 | #SBATCH --mem-per-cpu=60G 6 | #SBATCH --gres=gpu:1 7 | #SBATCH -t 6:00:00 # 6 hours 8 | #SBATCH -o KS50agt%j.out 9 | #SBATCH -e KS50agt%j.err 10 | 11 | export PYTHONUNBUFFERED=TRUE 12 | python train_KS.py -c ./configs/KS/game_nn_n50.json -n 1fm 13 | -------------------------------------------------------------------------------- /src/slurm_scripts/KS_1gm.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes=1 3 | #SBATCH --ntasks-per-node=1 4 | #SBATCH --ntasks-per-socket=1 5 | #SBATCH --mem-per-cpu=60G 6 | #SBATCH --gres=gpu:1 7 | #SBATCH -t 6:00:00 # 6 hours 8 | #SBATCH -o KS50agt%j.out 9 | #SBATCH -e KS50agt%j.err 10 | 11 | export PYTHONUNBUFFERED=TRUE 12 | python train_KS.py -c ./configs/KS/game_nn_n50_0fm1gm.json -n 1gm 13 | -------------------------------------------------------------------------------- /src/train_JFV.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from absl import app 5 | from absl import flags 6 | from param import JFVParam 7 | from dataset import JFVInitDataSet 8 | from value import ValueTrainer 9 | from policy import JFVPolicyTrainer 10 | from util import print_elapsedtime 11 | 12 | flags.DEFINE_string("config_path", "./configs/JFV_DSS/game_nn_n50.json", 13 | """The path to load json file.""", 14 | short_name='c') 15 | flags.DEFINE_string("exp_name", "test", 16 | """The suffix used in model_path for save.""", 17 | short_name='n') 18 | FLAGS = flags.FLAGS 19 | 20 | def main(argv): 21 | del argv 22 | folder = "JFV_DSS" if "DSS" in FLAGS.config_path else "JFV_SSS" 23 | with open(FLAGS.config_path, 'r') as f: 24 | config = json.load(f) 25 | print("Solving the problem based on the config path {}".format(FLAGS.config_path)) 26 | mparam = JFVParam(config["n_agt"], config["dt"], config["mats_path"], config["with_ashock"]) 27 | # save config at the beginning for checking 28 | model_path = "../data/simul_results/{}/{}_{}_n{}_{}".format( 29 | folder, 30 | "game" if config["policy_config"]["opt_type"] == "game" else "sp", 31 | config["dataset_config"]["value_sampling"], 32 | config["n_agt"], 33 | FLAGS.exp_name, 34 | ) 35 | os.makedirs(model_path, exist_ok=True) 36 | with open(os.path.join(model_path, "config_beg.json"), 'w') as f: 37 | json.dump(config, f) 38 | 39 | start_time = time.monotonic() 40 | 41 | # initial value training 42 | init_ds = JFVInitDataSet(mparam, config) 43 | value_config = config["value_config"] 44 | if config["init_with_bchmk"]: 45 | init_policy = init_ds.c_policy_bchmk 46 | policy_type = "pde" 47 | # TODO: change all "pde" to "conventional" 48 | else: 49 | init_policy = init_ds.c_policy_const_share 50 | policy_type = "nn_share" 51 | train_vds, valid_vds = init_ds.get_valuedataset(init_policy, policy_type, update_init=False) 52 | vtrainers = [ValueTrainer(config) for i in range(value_config["num_vnet"])] 53 | for vtr in vtrainers: 54 | vtr.train(train_vds, valid_vds, value_config["num_epoch"], value_config["batch_size"]) 55 | 56 | # iterative policy and value training 57 | policy_config = config["policy_config"] 58 | ptrainer = JFVPolicyTrainer(vtrainers, init_ds) 59 | ptrainer.train(policy_config["num_step"], policy_config["batch_size"]) 60 | 61 | # save config and models 62 | 
with open(os.path.join(model_path, "config.json"), 'w') as f: 63 | json.dump(config, f) 64 | for i, vtr in enumerate(vtrainers): 65 | vtr.save_model(os.path.join(model_path, "value{}.h5".format(i))) 66 | ptrainer.save_model(os.path.join(model_path, "policy.h5")) 67 | 68 | end_time = time.monotonic() 69 | print_elapsedtime(end_time - start_time) 70 | 71 | if __name__ == '__main__': 72 | app.run(main) 73 | -------------------------------------------------------------------------------- /src/train_KS.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from absl import app 5 | from absl import flags 6 | from param import KSParam 7 | from dataset import KSInitDataSet 8 | from value import ValueTrainer 9 | from policy import KSPolicyTrainer 10 | from util import print_elapsedtime 11 | 12 | flags.DEFINE_string("config_path", "./configs/KS/game_nn_n50.json", 13 | """The path to load json file.""", 14 | short_name='c') 15 | flags.DEFINE_string("exp_name", "test", 16 | """The suffix used in model_path for save.""", 17 | short_name='n') 18 | FLAGS = flags.FLAGS 19 | 20 | def main(argv): 21 | del argv 22 | with open(FLAGS.config_path, 'r') as f: 23 | config = json.load(f) 24 | print("Solving the problem based on the config path {}".format(FLAGS.config_path)) 25 | mparam = KSParam(config["n_agt"], config["beta"], config["mats_path"]) 26 | # save config at the beginning for checking 27 | model_path = "../data/simul_results/KS/{}_{}_n{}_{}".format( 28 | "game" if config["policy_config"]["opt_type"] == "game" else "sp", 29 | config["dataset_config"]["value_sampling"], 30 | config["n_agt"], 31 | FLAGS.exp_name, 32 | ) 33 | os.makedirs(model_path, exist_ok=True) 34 | with open(os.path.join(model_path, "config_beg.json"), 'w') as f: 35 | json.dump(config, f) 36 | 37 | start_time = time.monotonic() 38 | 39 | # initial value training 40 | init_ds = KSInitDataSet(mparam, config) 41 | value_config = config["value_config"] 42 | if config["init_with_bchmk"]: 43 | init_policy = init_ds.k_policy_bchmk 44 | policy_type = "pde" 45 | # TODO: change all "pde" to "conventional" 46 | else: 47 | init_policy = init_ds.c_policy_const_share 48 | policy_type = "nn_share" 49 | train_vds, valid_vds = init_ds.get_valuedataset(init_policy, policy_type, update_init=False) 50 | vtrainers = [ValueTrainer(config) for i in range(value_config["num_vnet"])] 51 | for vtr in vtrainers: 52 | vtr.train(train_vds, valid_vds, value_config["num_epoch"], value_config["batch_size"]) 53 | 54 | # iterative policy and value training 55 | policy_config = config["policy_config"] 56 | ptrainer = KSPolicyTrainer(vtrainers, init_ds) 57 | ptrainer.train(policy_config["num_step"], policy_config["batch_size"]) 58 | 59 | # save config and models 60 | with open(os.path.join(model_path, "config.json"), 'w') as f: 61 | json.dump(config, f) 62 | for i, vtr in enumerate(vtrainers): 63 | vtr.save_model(os.path.join(model_path, "value{}.h5".format(i))) 64 | ptrainer.save_model(os.path.join(model_path, "policy.h5")) 65 | 66 | end_time = time.monotonic() 67 | print_elapsedtime(end_time - start_time) 68 | 69 | if __name__ == '__main__': 70 | app.run(main) 71 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | # def create_model(d_in, d_out, config): 6 | # model = 
keras.Sequential() 7 | # model.add(keras.layers.InputLayer([d_in])) 8 | # for w in config["net_width"]: 9 | # model.add(keras.layers.Dense(w, activation=config["activation"])) 10 | # model.add((keras.layers.Dense(1, activation=None))) 11 | # return model 12 | 13 | class FeedforwardModel(keras.Model): 14 | def __init__(self, d_in, d_out, config, name="agentmodel", **kwargs): 15 | super(FeedforwardModel, self).__init__(name=name, **kwargs) 16 | self.dense_layers = [keras.layers.Dense(w, activation=config["activation"]) for w in config["net_width"]] 17 | self.dense_layers.append(keras.layers.Dense(d_out, activation=None)) 18 | self.d_in = d_in 19 | 20 | def call(self, inputs): 21 | x = self.dense_layers[0](inputs) 22 | for l in self.dense_layers[1:]: 23 | x = l(x) 24 | return x 25 | 26 | def load_weights_after_init(self, path): 27 | # evaluate once for creating variables before loading weights 28 | zeros = tf.ones([1, 1, self.d_in]) 29 | self.__call__(zeros) 30 | self.load_weights(path) 31 | 32 | class GeneralizedMomModel(FeedforwardModel): 33 | def __init__(self, d_in, d_out, config, name="generalizedmomentmodel", **kwargs): 34 | super(GeneralizedMomModel, self).__init__(d_in, d_out, config, name=name, **kwargs) 35 | 36 | def basis_fn(self, inputs): 37 | x = self.dense_layers[0](inputs) 38 | for l in self.dense_layers[1:]: 39 | x = l(x) 40 | return x 41 | 42 | def call(self, inputs): 43 | x = self.basis_fn(inputs) 44 | gm = tf.reduce_mean(x, axis=-2, keepdims=True) 45 | gm = tf.tile(gm, [1, inputs.shape[-2], 1]) 46 | return gm 47 | 48 | def print_elapsedtime(delta): 49 | hours, rem = divmod(delta, 3600) 50 | minutes, seconds = divmod(rem, 60) 51 | print("Elapsed time: {:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds)) 52 | 53 | def gini(array): #https://github.com/oliviaguest/gini 54 | """Calculate the Gini of a numpy array.""" 55 | # based on bottom eq: http://www.statsdirect.com/help/content/image/stat0206_wmf.gif 56 | # from: http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm 57 | array = array.flatten() # all values are treated equally, arrays must be 1d 58 | if np.amin(array) < 0: 59 | array -= np.amin(array) # values cannot be negative 60 | array += 0.0000001 # values cannot be 0 61 | array = np.sort(array) # values must be sorted 62 | index = np.arange(1, array.shape[0]+1) # index per array element 63 | n = array.shape[0] # number of array elements 64 | return (np.sum((2 * index - n - 1) * array)) / (n * np.sum(array)) # Gini coefficient 65 | -------------------------------------------------------------------------------- /src/validate_JFV.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | import scipy.io as sio 5 | from absl import app 6 | from absl import flags 7 | from param import JFVParam 8 | from dataset import JFVInitDataSet 9 | from value import ValueTrainer 10 | from policy import JFVPolicyTrainer 11 | import simulation_JFV as JFV 12 | 13 | flags.DEFINE_string("model_path", "../data/simul_results/JFV_DSS/game_nn_n50_test", 14 | """The path to load json file.""", 15 | short_name='m') 16 | flags.DEFINE_integer('n_agt', -1, "Number of agents in validation simulation") 17 | flags.DEFINE_boolean('save', True, "Save simulation results or not.") 18 | FLAGS = flags.FLAGS 19 | 20 | def bellman_err(simul_data, shocks, ptrainer, value_fn, prefix="", nnext=500, nc=10, seed=None, nt=None): 21 | if seed: 22 | np.random.seed(seed) 23 | k_cross, csmp, B, N = 
simul_data["k_cross"], simul_data["csmp"], simul_data["B"], simul_data["N"] 24 | if nt: 25 | nt = min(nt, csmp.shape[-1]) 26 | else: 27 | nt = csmp.shape[-1] 28 | mparam = ptrainer.mparam 29 | # compute error for n_path * n_agt * nt states 30 | t_idx = np.random.choice(csmp.shape[-1], nt) 31 | n_agt = csmp.shape[1] 32 | k_now, k_next = k_cross[:, :, t_idx], k_cross[:, :, t_idx+1] 33 | knormed_now = ptrainer.init_ds.normalize_data(k_now, key="agt_s") 34 | knormed_next = ptrainer.init_ds.normalize_data(k_next, key="agt_s") 35 | knormed_mean_now = np.repeat(np.mean(knormed_now, axis=-2, keepdims=True), n_agt, axis=1) 36 | knormed_mean_next = np.repeat(np.mean(knormed_next, axis=-2, keepdims=True), n_agt, axis=1) 37 | B_now, B_next = np.repeat(B[:, None, t_idx], n_agt, axis=1), np.repeat(B[:, None, t_idx+1], n_agt, axis=1) 38 | N_now, N_next = np.repeat(N[:, None, t_idx], n_agt, axis=1), np.repeat(N[:, None, t_idx+1], n_agt, axis=1) 39 | c_now = csmp[:, :, t_idx] 40 | K_now = B_now + N_now 41 | ashock = shocks[0][:, t_idx] 42 | ishock = shocks[1][:, :, t_idx] 43 | 44 | ashock_next = mparam.dt**0.5*np.random.normal(0, mparam.sigma, [ashock.shape[0], ashock.shape[1], nnext]) 45 | y_agt = ishock.copy() 46 | ur_rate = (1 - y_agt) * (1 - mparam.la1 * mparam.dt) 47 | ur_rate += y_agt * mparam.la2 * mparam.dt 48 | ishock0 = np.zeros_like(ishock) 49 | ishock1 = np.ones_like(ishock) 50 | 51 | def gm_fn(knormed): # k_normalized of shape B * n_agt * T 52 | knormed = knormed.transpose((0, 2, 1))[:, :, :, None] 53 | basis = [None] * len(ptrainer.vtrainers) 54 | gm = [None] * len(ptrainer.vtrainers) 55 | for i, vtr in enumerate(ptrainer.vtrainers): 56 | basis[i] = vtr.gm_model.basis_fn(knormed).numpy() 57 | basis[i] = basis[i].transpose((0, 2, 1, 3)) 58 | gm[i] = np.repeat(np.mean(basis[i], axis=1, keepdims=True), n_agt, axis=1) 59 | return basis, gm 60 | 61 | basic_s_now = np.stack([k_now, B_now, N_now, ishock], axis=-1) 62 | if ptrainer.init_ds.config["n_fm"] == 2 and "pde" not in prefix: 63 | knormed_sqr_mean_now = np.repeat(np.mean(knormed_now**2, axis=1, keepdims=True), n_agt, axis=1) 64 | knormed_sqr_mean_next = np.repeat(np.mean(knormed_next**2, axis=1, keepdims=True), n_agt, axis=1) 65 | fm_extra_now = knormed_sqr_mean_now-knormed_mean_now**2 66 | fm_extra_now = fm_extra_now[:, :, :, None] 67 | else: 68 | fm_extra_now = None 69 | if ptrainer.init_ds.config["n_gm"] > 0 and "pde" not in prefix: 70 | _, gm_now = gm_fn(knormed_now) 71 | gm_basis_next, gm_next = gm_fn(knormed_next) 72 | else: 73 | gm_now, gm_next = None, None 74 | v_now = value_fn(basic_s_now, fm_extra_now, gm_now) 75 | def next_value_fn(c_tmp): 76 | k_next_tmp = k_next + (c_now - c_tmp) * mparam.dt 77 | B_next_tmp = B_next + (k_next_tmp - k_next) / n_agt 78 | knormed_next_tmp = ptrainer.init_ds.normalize_data(k_next_tmp, key="agt_s") 79 | knormed_mean_next_tmp = knormed_mean_next + (knormed_next_tmp - knormed_next) / n_agt 80 | basic_s_next_tmp = [k_next_tmp, B_next_tmp] 81 | if ptrainer.init_ds.config["n_fm"] == 2 and "pde" not in prefix: 82 | knormed_sqr_mean_next_tmp = knormed_sqr_mean_next + (knormed_next_tmp**2 - knormed_next**2) / n_agt 83 | fm_extra_next_tmp = knormed_sqr_mean_next_tmp - knormed_mean_next_tmp**2 84 | fm_extra_next_tmp = fm_extra_next_tmp[:, :, :, None] 85 | else: 86 | fm_extra_next_tmp = None 87 | if ptrainer.init_ds.config["n_gm"] > 0 and "pde" not in prefix: 88 | gm_basis_next_tmp, _ = gm_fn(knormed_next_tmp) 89 | gm_next_tmp = [gm_next[i] + (gm_basis_next_tmp[i] - gm_basis_next[i]) / n_agt for i in 
range(len(gm_next))] 90 | else: 91 | gm_next_tmp = None 92 | v_tmp = np.zeros_like(v_now) 93 | for j in range(nnext): 94 | N_next_tmp = N_next - K_now * ashock[:, None, :] + K_now * ashock_next[:, None, :, j] 95 | basic_s_next0_tmp = np.stack(basic_s_next_tmp + [N_next_tmp, ishock0], axis=-1) 96 | basic_s_next1_tmp = np.stack(basic_s_next_tmp + [N_next_tmp, ishock1], axis=-1) 97 | v_next0 = value_fn(basic_s_next0_tmp, fm_extra_next_tmp, gm_next_tmp) 98 | v_next1 = value_fn(basic_s_next1_tmp, fm_extra_next_tmp, gm_next_tmp) 99 | v_tmp += mparam.beta*(v_next0 * ur_rate + v_next1 * (1-ur_rate)) + mparam.dt * (1-1/c_tmp) 100 | v_tmp /= nnext 101 | return v_tmp 102 | # Bellman expectation error 103 | v_next = next_value_fn(c_now) 104 | err_blmexpct = (v_now - v_next) / mparam.dt 105 | # Bellman optimality error 106 | v_next = np.zeros_like(v_now) 107 | c_max = c_now + np.minimum(k_next/mparam.dt-1e-6, 0.5) # sample-wise cmax 108 | c_min = c_now - np.minimum(c_now*0.95, 0.5) # sample-wise cmin 109 | dc = (c_max - c_min) / nc 110 | for i in range(nc+1): 111 | c_tmp = c_min + dc * i 112 | v_tmp = next_value_fn(c_tmp) 113 | v_next = np.maximum(v_tmp, v_next) 114 | err_blmopt = (v_now - v_next) / mparam.dt 115 | print("Bellman error of %3s: %.6f" % \ 116 | (prefix.upper(), np.abs(err_blmopt).mean())) 117 | return err_blmexpct, err_blmopt 118 | 119 | 120 | def main(argv): 121 | del argv 122 | print("Validating the model from {}".format(FLAGS.model_path)) 123 | with open(os.path.join(FLAGS.model_path, "config.json"), 'r') as f: 124 | config = json.load(f) 125 | config["dataset_config"]["n_path"] = config["simul_config"]["n_path"] 126 | config["init_with_bchmk"] = True 127 | if FLAGS.n_agt > 0: 128 | config["n_agt"] = FLAGS.n_agt 129 | mparam = JFVParam(config["n_agt"], config["dt"], config["mats_path"], config["with_ashock"]) 130 | 131 | init_ds = JFVInitDataSet(mparam, config) 132 | value_config = config["value_config"] 133 | vtrainers = [ValueTrainer(config) for i in range(value_config["num_vnet"])] 134 | for i, vtr in enumerate(vtrainers): 135 | vtr.load_model(os.path.join(FLAGS.model_path, "value{}.h5".format(i))) 136 | ptrainer = JFVPolicyTrainer(vtrainers, init_ds, os.path.join(FLAGS.model_path, "policy.h5")) 137 | 138 | # long simulation 139 | simul_config = config["simul_config"] 140 | n_path = simul_config["n_path"] 141 | T = simul_config["T"] 142 | state_init = init_ds.next_batch(n_path) 143 | shocks = JFV.simul_shocks(n_path, T, mparam, state_init) 144 | simul_data_bchmk = JFV.simul_k( 145 | n_path, T, mparam, init_ds.c_policy_bchmk, "pde", 146 | state_init=state_init, shocks=shocks 147 | ) 148 | idx = ~np.isnan(simul_data_bchmk["k_cross"]).any(axis=(1, 2)) # exclude nan path 149 | def remove_nan(simul_data, idx): 150 | new_dict = { 151 | "k_cross": simul_data["k_cross"][idx], 152 | "csmp": simul_data["csmp"][idx], 153 | "N": simul_data["N"][idx], 154 | "B": simul_data["B"][idx] 155 | } 156 | return new_dict 157 | simul_data_bchmk = remove_nan(simul_data_bchmk, idx) 158 | simul_data_nn = JFV.simul_k( 159 | n_path, T, mparam, ptrainer.current_c_policy, "nn_share", 160 | state_init=state_init, shocks=shocks 161 | ) 162 | simul_data_nn = remove_nan(simul_data_nn, idx) 163 | shocks = (shocks[0][idx], shocks[1][idx]) 164 | 165 | # calculate path stats 166 | def path_stats(simul_data, prefix=""): 167 | k_mean = np.mean(simul_data["k_cross"], axis=1) 168 | discount = np.power(mparam.beta, np.arange(simul_data["csmp"].shape[-1])) 169 | util_sum = np.sum((1-1/simul_data["csmp"])*discount, axis=-1)
* mparam.dt 170 | print( 171 | "%8s: total utility: %.5f, mean of k: %.5f, std of k: %.5f, max of k: %.5f, max of K: %.5f" % ( 172 | prefix, util_sum.mean(), simul_data["k_cross"].mean(), simul_data["k_cross"].std(), 173 | simul_data["k_cross"].max(), k_mean.max() 174 | ) 175 | ) 176 | path_stats(simul_data_bchmk, "Conventional") 177 | path_stats(simul_data_nn, "DeepHAM") 178 | 179 | # compute Bellman expectation and optimality errors 180 | mats = sio.loadmat(mparam.mats_path) 181 | if mparam.with_ashock: 182 | v_spline = JFV.construct_spl_SSS(mats, "V") 183 | value_fn_pde = lambda state, fm_extra, gm: \ 184 | JFV.value_spl_SSS(state[..., 0], state[..., 1], state[..., -2], state[..., -1], v_spline) 185 | else: 186 | v_spline = JFV.construct_spl_DSS(mats, "V") 187 | value_fn_pde = lambda state, fm_extra, gm: \ 188 | JFV.value_spl_DSS(state[..., 0], state[..., 1], state[..., -2], state[..., -1], v_spline) 189 | def value_fn_nn(basic_s, fm_extra, gm): 190 | basic_s = ptrainer.init_ds.normalize_data(basic_s, key="basic_s") 191 | if ptrainer.init_ds.config["n_fm"] == 0: 192 | basic_s = np.concatenate([basic_s[..., 0:1], basic_s[..., 2:]], axis=-1) 193 | if fm_extra is not None: 194 | n_state = basic_s.shape[-1] + fm_extra.shape[-1] 195 | state_fix = np.concatenate([basic_s, fm_extra], axis=-1) 196 | else: 197 | n_state = basic_s.shape[-1] 198 | state_fix = basic_s 199 | if gm is not None: 200 | n_state += gm[0].shape[-1] 201 | state = [None] * len(vtrainers) 202 | for i in range(len(vtrainers)): 203 | state[i] = np.concatenate([state_fix, gm[i]], axis=-1) 204 | state[i] = state[i].transpose((0, 2, 1, 3)).reshape((-1, config['n_agt'], n_state)) 205 | else: 206 | state = [state_fix.transpose((0, 2, 1, 3)).reshape((-1, config['n_agt'], n_state))] * len(vtrainers) 207 | v = 0 208 | for i, vtr in enumerate(vtrainers): 209 | v += vtr.model(state[i]).numpy() 210 | v /= len(vtrainers) 211 | v = ptrainer.init_ds.unnormalize_data(v, key="value") 212 | # reshape and transpose back to path * n_agt * time 213 | v = v.reshape([basic_s.shape[0], basic_s.shape[2], basic_s.shape[1]]) 214 | v = np.transpose(v, (0, 2, 1)) 215 | return v 216 | # err_blm_bchmk = bellman_err(simul_data_bchmk, shocks, ptrainer, value_fn_pde, "pde", seed=1, nt=100) 217 | err_blm_nn = bellman_err(simul_data_nn, shocks, ptrainer, value_fn_nn, "nn", seed=1, nt=100) 218 | 219 | # save data if required 220 | if FLAGS.save: 221 | to_save = { 222 | "config": config, 223 | "k_cross_bchmk": simul_data_bchmk["k_cross"], 224 | "k_cross_nn": simul_data_nn["k_cross"], 225 | "csmp_bchmk": simul_data_bchmk["csmp"], 226 | "csmp_nn": simul_data_nn["csmp"], 227 | "N_bchmk": simul_data_bchmk["N"], 228 | "N_nn": simul_data_nn["N"], 229 | # "err_blmexpct_bchmk": err_blm_bchmk[0], 230 | # "err_blmopt_bchmk": err_blm_bchmk[1], 231 | "err_blmexpct_nn": err_blm_nn[0], 232 | "err_blmopt_nn": err_blm_nn[1], 233 | "ashock": shocks[0], 234 | "ishock": shocks[1], 235 | } 236 | np.savez(os.path.join(FLAGS.model_path, "paths.npz"), **to_save) 237 | 238 | if __name__ == '__main__': 239 | app.run(main) 240 | -------------------------------------------------------------------------------- /src/validate_KS.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | from absl import app 5 | from absl import flags 6 | from param import KSParam 7 | from dataset import KSInitDataSet 8 | from value import ValueTrainer 9 | from policy import KSPolicyTrainer 10 | from simulation_KS import simul_shocks, simul_k
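# validate_KS.py simulates long paths under the conventional spline policy and the trained
# DeepHAM policy with identical shocks, reports discounted-utility path statistics, and
# estimates Bellman expectation and optimality errors for the neural-network solution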
11 | 12 | flags.DEFINE_string("model_path", "../data/simul_results/KS/game_nn_n50_test", 13 | """The path to load json file.""", 14 | short_name='m') 15 | flags.DEFINE_integer('n_agt', -1, "Number of agents in validation simulation") 16 | flags.DEFINE_boolean('save', True, "Save simulation results or not.") 17 | FLAGS = flags.FLAGS 18 | 19 | def bellman_err(simul_data, shocks, ptrainer, value_fn, prefix="", nnext=100, nc=10, seed=None, nt=None): 20 | if seed: 21 | np.random.seed(seed) 22 | k_cross, csmp = simul_data["k_cross"], simul_data["csmp"] 23 | K = np.mean(k_cross, axis=1, keepdims=True) 24 | if nt: 25 | nt = min(nt, csmp.shape[-1]) 26 | else: 27 | nt = csmp.shape[-1] 28 | mparam = ptrainer.mparam 29 | # compute error for n_path * n_agt * nt states 30 | t_idx = np.random.choice(csmp.shape[-1], nt) 31 | n_agt = csmp.shape[1] 32 | k_now, k_next = k_cross[:, :, t_idx], k_cross[:, :, t_idx+1] 33 | knormed_now = ptrainer.init_ds.normalize_data(k_now, key="agt_s") 34 | knormed_next = ptrainer.init_ds.normalize_data(k_next, key="agt_s") 35 | knormed_mean_now = np.repeat(np.mean(knormed_now, axis=-2, keepdims=True), n_agt, axis=1) 36 | knormed_mean_next = np.repeat(np.mean(knormed_next, axis=-2, keepdims=True), n_agt, axis=1) 37 | K_now, K_next = np.repeat(K[:, :, t_idx], n_agt, axis=1), np.repeat(K[:, :, t_idx+1], n_agt, axis=1) 38 | c_now = csmp[:, :, t_idx] 39 | ashock = shocks[0][:, t_idx] 40 | ishock = shocks[1][:, :, t_idx] 41 | 42 | if_keep = np.random.binomial(1, 0.875, [ashock.shape[0], ashock.shape[1], nnext]) # prob for Z to stay the same 43 | # ashock = 1-delta or 1+delta 44 | ashock_next = if_keep * ashock[:, :, None] + (1 - if_keep) * (2 - ashock[:, :, None]) 45 | ashock_01 = (((ashock-1) / mparam.delta_a + 1) / 2).astype(int) 46 | ashock_next_01 = (((ashock_next-1) / mparam.delta_a + 1) / 2).astype(int) 47 | y_agt = ishock.copy() 48 | ishock0 = np.zeros_like(ishock) 49 | ishock1 = np.ones_like(ishock) 50 | 51 | def gm_fn(knormed): # k_normalized of shape B * n_agt * T 52 | knormed = knormed.transpose((0, 2, 1))[:, :, :, None] 53 | basis = [None] * len(ptrainer.vtrainers) 54 | gm = [None] * len(ptrainer.vtrainers) 55 | for i, vtr in enumerate(ptrainer.vtrainers): 56 | basis[i] = vtr.gm_model.basis_fn(knormed).numpy() 57 | basis[i] = basis[i].transpose((0, 2, 1, 3)) 58 | gm[i] = np.repeat(np.mean(basis[i], axis=1, keepdims=True), n_agt, axis=1) 59 | return basis, gm 60 | 61 | basic_s_now = np.stack([k_now, K_now, np.repeat(ashock[:, None, :], n_agt, axis=1), ishock], axis=-1) 62 | if ptrainer.init_ds.config["n_fm"] == 2 and "pde" not in prefix: 63 | knormed_sqr_mean_now = np.repeat(np.mean(knormed_now**2, axis=1, keepdims=True), n_agt, axis=1) 64 | knormed_sqr_mean_next = np.repeat(np.mean(knormed_next**2, axis=1, keepdims=True), n_agt, axis=1) 65 | fm_extra_now = knormed_sqr_mean_now-knormed_mean_now**2 66 | fm_extra_now = fm_extra_now[:, :, :, None] 67 | else: 68 | fm_extra_now = None 69 | if ptrainer.init_ds.config["n_gm"] > 0 and "pde" not in prefix: 70 | _, gm_now = gm_fn(knormed_now) 71 | gm_basis_next, gm_next = gm_fn(knormed_next) 72 | else: 73 | gm_now, gm_next = None, None 74 | v_now = value_fn(basic_s_now, fm_extra_now, gm_now) 75 | def next_value_fn(c_tmp): 76 | k_next_tmp = k_next + (c_now - c_tmp) 77 | K_next_tmp = K_next + (k_next_tmp - k_next) / n_agt 78 | knormed_next_tmp = ptrainer.init_ds.normalize_data(k_next_tmp, key="agt_s") 79 | knormed_mean_next_tmp = knormed_mean_next + (knormed_next_tmp - knormed_next) / n_agt 80 | basic_s_next_tmp = [k_next_tmp, 
K_next_tmp] 81 | if ptrainer.init_ds.config["n_fm"] == 2 and "pde" not in prefix: 82 | knormed_sqr_mean_next_tmp = knormed_sqr_mean_next + (knormed_next_tmp**2 - knormed_next**2) / n_agt 83 | fm_extra_next_tmp = knormed_sqr_mean_next_tmp - knormed_mean_next_tmp**2 84 | fm_extra_next_tmp = fm_extra_next_tmp[:, :, :, None] 85 | else: 86 | fm_extra_next_tmp = None 87 | if ptrainer.init_ds.config["n_gm"] > 0 and "pde" not in prefix: 88 | gm_basis_next_tmp, _ = gm_fn(knormed_next_tmp) 89 | gm_next_tmp = [gm_next[i] + (gm_basis_next_tmp[i] - gm_basis_next[i]) / n_agt for i in range(len(gm_next))] 90 | else: 91 | gm_next_tmp = None 92 | v_tmp = np.zeros_like(v_now) 93 | for j in range(nnext): 94 | ashock_next_tmp = np.repeat(ashock_next[:, None, :, j], n_agt, axis=1) 95 | basic_s_next0_tmp = np.stack(basic_s_next_tmp + [ashock_next_tmp, ishock0], axis=-1) 96 | basic_s_next1_tmp = np.stack(basic_s_next_tmp + [ashock_next_tmp, ishock1], axis=-1) 97 | v_next0 = value_fn(basic_s_next0_tmp, fm_extra_next_tmp, gm_next_tmp) 98 | v_next1 = value_fn(basic_s_next1_tmp, fm_extra_next_tmp, gm_next_tmp) 99 | # convert to 0,1 for computing ishock transition 100 | a0, a1 = ashock_01[:, None, :], ashock_next_01[:, None, :, j] 101 | ur_rate = (1 - a0) * (1 - a1) * (1 - y_agt) * mparam.p_bb_uu + (1 - a0) * (1 - a1) * y_agt * mparam.p_bb_eu 102 | ur_rate += (1 - a0) * a1 * (1 - y_agt) * mparam.p_bg_uu + (1 - a0) * a1 * y_agt * mparam.p_bg_eu 103 | ur_rate += a0 * (1 - a1) * (1 - y_agt) * mparam.p_gb_uu + a0 * (1 - a1) * y_agt * mparam.p_gb_eu 104 | ur_rate += a0 * a1 * (1 - y_agt) * mparam.p_gg_uu + a0 * a1 * y_agt * mparam.p_gg_eu 105 | v_tmp += mparam.beta*(v_next0 * ur_rate + v_next1 * (1-ur_rate)) + np.log(c_tmp) 106 | v_tmp /= nnext 107 | return v_tmp 108 | # Bellman expectation error 109 | v_next = next_value_fn(c_now) 110 | err_blmexpct = v_now - v_next 111 | # Bellman optimality error 112 | v_next = np.zeros_like(v_now) 113 | c_max = c_now + np.minimum(k_next-1e-6, 5) # sample-wise cmax 114 | c_min = c_now - np.minimum(c_now*0.95, 5) # sample-wise cmin 115 | dc = (c_max - c_min) / nc 116 | for i in range(nc+1): 117 | c_tmp = c_min + dc * i 118 | v_tmp = next_value_fn(c_tmp) 119 | v_next = np.maximum(v_tmp, v_next) 120 | err_blmopt = v_now - v_next 121 | print("Bellman error of %3s: %.6f" % \ 122 | (prefix.upper(), np.abs(err_blmopt).mean())) 123 | return err_blmexpct, err_blmopt 124 | 125 | 126 | def main(argv): 127 | del argv 128 | print("Validating the model from {}".format(FLAGS.model_path)) 129 | with open(os.path.join(FLAGS.model_path, "config.json"), 'r') as f: 130 | config = json.load(f) 131 | config["dataset_config"]["n_path"] = config["simul_config"]["n_path"] 132 | config["init_with_bchmk"] = True 133 | if FLAGS.n_agt > 0: 134 | config["n_agt"] = FLAGS.n_agt 135 | mparam = KSParam(config["n_agt"], config["beta"], config["mats_path"]) 136 | 137 | init_ds = KSInitDataSet(mparam, config) 138 | value_config = config["value_config"] 139 | vtrainers = [ValueTrainer(config) for i in range(value_config["num_vnet"])] 140 | for i, vtr in enumerate(vtrainers): 141 | vtr.load_model(os.path.join(FLAGS.model_path, "value{}.h5".format(i))) 142 | ptrainer = KSPolicyTrainer(vtrainers, init_ds, os.path.join(FLAGS.model_path, "policy.h5")) 143 | 144 | # long simulation 145 | simul_config = config["simul_config"] 146 | n_path = simul_config["n_path"] 147 | T = simul_config["T"] 148 | state_init = init_ds.next_batch(n_path) 149 | shocks = simul_shocks(n_path, T, mparam, state_init) 150 | simul_data_bchmk = simul_k(
151 |         n_path, T, mparam, init_ds.k_policy_bchmk, policy_type="pde",
152 |         state_init=state_init, shocks=shocks
153 |     )
154 |     simul_data_nn = simul_k(
155 |         n_path, T, mparam, ptrainer.current_c_policy, policy_type="nn_share",
156 |         state_init=state_init, shocks=shocks
157 |     )
158 | 
159 |     # calculate path stats
160 |     def path_stats(simul_data, prefix=""):
161 |         k_mean = np.mean(simul_data["k_cross"], axis=1)
162 |         discount = np.power(mparam.beta, np.arange(simul_data["csmp"].shape[-1]))
163 |         util_sum = np.sum(np.log(simul_data["csmp"])*discount, axis=-1)
164 |         print(
165 |             "%8s: total utility: %.5f, mean of k: %.5f, std of k: %.5f, max of k: %.5f, max of K: %.5f" % (
166 |                 prefix, util_sum.mean(), simul_data["k_cross"].mean(), simul_data["k_cross"].std(),
167 |                 simul_data["k_cross"].max(), k_mean.max())
168 |         )
169 |     path_stats(simul_data_bchmk, "Conventional")
170 |     path_stats(simul_data_nn, "DeepHAM")
171 | 
172 |     # compute Bellman expectation and optimality errors
173 |     # a value_fn for the PDE benchmark is unavailable so far
174 |     def value_fn_nn(basic_s, fm_extra, gm):
175 |         basic_s = ptrainer.init_ds.normalize_data(basic_s, key="basic_s")
176 |         if ptrainer.init_ds.config["n_fm"] == 0:
177 |             basic_s = np.concatenate([basic_s[..., 0:1], basic_s[..., 2:]], axis=-1)
178 |         if fm_extra is not None:
179 |             n_state = basic_s.shape[-1] + fm_extra.shape[-1]
180 |             state_fix = np.concatenate([basic_s, fm_extra], axis=-1)
181 |         else:
182 |             n_state = basic_s.shape[-1]
183 |             state_fix = basic_s
184 |         if gm is not None:
185 |             n_state += gm[0].shape[-1]
186 |             state = [None] * len(vtrainers)
187 |             for i in range(len(vtrainers)):
188 |                 state[i] = np.concatenate([state_fix, gm[i]], axis=-1)
189 |                 state[i] = state[i].transpose((0, 2, 1, 3)).reshape((-1, config['n_agt'], n_state))
190 |         else:
191 |             state = [state_fix.transpose((0, 2, 1, 3)).reshape((-1, config['n_agt'], n_state))] * len(vtrainers)
192 |         v = 0
193 |         for i, vtr in enumerate(vtrainers):
194 |             v += vtr.model(state[i]).numpy()
195 |         v /= len(vtrainers)
196 |         v = ptrainer.init_ds.unnormalize_data(v, key="value")
197 |         # reshape and transpose back to path * n_agt * time
198 |         v = v.reshape([basic_s.shape[0], basic_s.shape[2], basic_s.shape[1]])
199 |         v = np.transpose(v, (0, 2, 1))
200 |         return v
201 |     err_blm_nn = bellman_err(simul_data_nn, shocks, ptrainer, value_fn_nn, "nn", seed=1, nt=100)
202 | 
203 |     if FLAGS.save:
204 |         to_save = {
205 |             "k_cross_bchmk": simul_data_bchmk["k_cross"],
206 |             "k_cross_nn": simul_data_nn["k_cross"],
207 |             "csmp_bchmk": simul_data_bchmk["csmp"],
208 |             "csmp_nn": simul_data_nn["csmp"],
209 |             "err_blmexpct_nn": err_blm_nn[0],
210 |             "err_blmopt_nn": err_blm_nn[1],
211 |             "ashock": shocks[0],
212 |             "ishock": shocks[1],
213 |         }
214 |         np.savez(os.path.join(FLAGS.model_path, "paths.npz"), **to_save)
215 | 
216 | if __name__ == '__main__':
217 |     app.run(main)
218 | 
--------------------------------------------------------------------------------
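One detail of value_fn_nn above that is easy to misread is the transpose/reshape round trip: states arrive as n_path x n_agt x T x n_state, the value networks expect batch x n_agt x n_state, and the per-agent values must be mapped back to n_path x n_agt x T. A minimal self-contained sketch of that round trip (editor's addition; shapes are toy values):

    import numpy as np
    n_path, n_agt, T, n_state = 2, 3, 4, 5
    s = np.random.rand(n_path, n_agt, T, n_state)
    flat = s.transpose((0, 2, 1, 3)).reshape((-1, n_agt, n_state))  # (n_path*T) x n_agt x n_state
    v = flat[..., 0]                                                # stand-in for a scalar network output per agent
    back = v.reshape((n_path, T, n_agt)).transpose((0, 2, 1))       # n_path x n_agt x T
    assert np.allclose(back, s[..., 0])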
/src/value.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import util
4 | 
5 | DTYPE = "float64"
6 | tf.keras.backend.set_floatx(DTYPE)
7 | if DTYPE == "float64":
8 |     NP_DTYPE = np.float64
9 | elif DTYPE == "float32":
10 |     NP_DTYPE = np.float32
11 | else:
12 |     raise ValueError("Unknown dtype.")
13 | 
14 | class ValueTrainer():
15 |     def __init__(self, config):
16 |         self.config = config
17 |         self.value_config = config["value_config"]
18 |         d_in = config["n_basic"] + config["n_fm"] + config["n_gm"]  # input dimension of the value net
19 |         self.model = util.FeedforwardModel(d_in, 1, self.value_config, name="v_net")
20 |         if config["n_gm"] > 0:
21 |             # TODO generalize to multi-dimensional agt_s
22 |             self.gm_model = util.GeneralizedMomModel(1, config["n_gm"], config["gm_config"], name="v_gm")
23 |         self.train_vars = None
24 |         self.optimizer = tf.keras.optimizers.Adam(
25 |             learning_rate=self.value_config["lr"], epsilon=1e-8,
26 |             beta_1=0.99, beta_2=0.99
27 |         )
28 | 
29 |     @tf.function
30 |     def prepare_state(self, input_data):
31 |         if self.config["n_fm"] == 2:
32 |             k_var = tf.math.reduce_variance(input_data["agt_s"], axis=-2, keepdims=True)
33 |             k_var = tf.tile(k_var, [1, input_data["agt_s"].shape[-2], 1])
34 |             state = tf.concat([input_data["basic_s"], k_var], axis=-1)
35 |         elif self.config["n_fm"] == 0:
36 |             state = tf.concat([input_data["basic_s"][..., 0:1], input_data["basic_s"][..., 2:]], axis=-1)
37 |         elif self.config["n_fm"] == 1:  # so far k_mean is always included in basic_s
38 |             state = input_data["basic_s"]
39 |         if self.config["n_gm"] > 0:
40 |             gm = self.gm_model(input_data["agt_s"])
41 |             state = tf.concat([state, gm], axis=-1)
42 |         return state
43 | 
44 |     @tf.function
45 |     def value_fn(self, input_data):
46 |         state = self.prepare_state(input_data)
47 |         value = self.model(state)
48 |         return value
49 | 
50 |     @tf.function
51 |     def loss(self, input_data):
52 |         y_pred = self.value_fn(input_data)
53 |         y = input_data["value"]
54 |         loss = tf.reduce_mean(tf.square(y_pred - y))
55 |         loss_dict = {"loss": loss}
56 |         return loss_dict
57 | 
58 |     def grad(self, input_data):
59 |         with tf.GradientTape(persistent=True) as tape:
60 |             loss = self.loss(input_data)["loss"]
61 |         train_vars = self.model.trainable_variables
62 |         if self.config["n_gm"] > 0:
63 |             train_vars += self.gm_model.trainable_variables
64 |         self.train_vars = train_vars
65 |         grad = tape.gradient(
66 |             loss,
67 |             train_vars,
68 |             unconnected_gradients=tf.UnconnectedGradients.ZERO,
69 |         )
70 |         del tape
71 |         return grad
72 | 
73 |     @tf.function
74 |     def train_step(self, train_data):
75 |         grad = self.grad(train_data)
76 |         self.optimizer.apply_gradients(
77 |             zip(grad, self.train_vars)
78 |         )
79 | 
80 |     def train(self, train_dataset, valid_dataset, num_epoch=None, batch_size=None):
81 |         train_dataset = train_dataset.batch(batch_size)
82 | 
83 |         for epoch in range(num_epoch+1):
84 |             for train_data in train_dataset:
85 |                 self.train_step(train_data)
86 |             if epoch % 20 == 0:
87 |                 for valid_data in valid_dataset:
88 |                     val_loss = self.loss(valid_data)
89 |                 print(
90 |                     "Value function learning epoch: %d, validation loss: %g"
91 |                     % (epoch, val_loss["loss"])
92 |                 )
93 | 
94 |     def save_model(self, path="value_model.h5"):
95 |         self.model.save_weights(path)
96 |         if self.config["n_gm"] > 0:
97 |             self.gm_model.save_weights(path.replace(".h5", "_gm.h5"))
98 | 
99 |     def load_model(self, path):
100 |         self.model.load_weights_after_init(path)
101 |         if self.config["n_gm"] > 0:
102 |             self.gm_model.load_weights_after_init(path.replace(".h5", "_gm.h5"))
103 | 
--------------------------------------------------------------------------------
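A closing note on value.py: prepare_state assembles the value-net input of width d_in = n_basic + n_fm + n_gm. In the KS setup basic_s stacks (k, K, ashock, ishock), and the n_fm == 1 comment above suggests the cross-sectional mean K already counts toward n_fm, so n_basic is effectively 3. A self-contained shape check of the n_fm == 2 branch (editor's sketch; all shapes below are toy values):

    import tensorflow as tf
    basic_s = tf.zeros([8, 50, 4], dtype="float64")  # batch x n_agt x (k, K, ashock, ishock)
    agt_s = tf.zeros([8, 50, 1], dtype="float64")    # individual capital holdings
    k_var = tf.math.reduce_variance(agt_s, axis=-2, keepdims=True)  # cross-sectional variance
    k_var = tf.tile(k_var, [1, agt_s.shape[-2], 1])                 # broadcast to every agent
    state = tf.concat([basic_s, k_var], axis=-1)
    assert state.shape == (8, 50, 5)  # d_in = n_basic (3) + n_fm (2) = 5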