├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── enhancement.md ├── PULL_REQUEST_TEMPLATE │ └── pull_request_template.md ├── dependabot.yml └── workflows │ ├── main.yml │ └── publish-to-pypi.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .yamllint.yml ├── CHANGES.md ├── LICENSE ├── README.md ├── codecov.yml ├── examples ├── README.md └── long_running.py ├── explanations ├── README.md └── function_representation.ipynb ├── pixi.lock ├── pyproject.toml ├── src └── lcm │ ├── Q_and_F.py │ ├── __init__.py │ ├── _config.py │ ├── argmax.py │ ├── dispatchers.py │ ├── entry_point.py │ ├── exceptions.py │ ├── function_representation.py │ ├── functools.py │ ├── grid_helpers.py │ ├── grids.py │ ├── input_processing │ ├── __init__.py │ ├── create_params_template.py │ ├── process_model.py │ └── util.py │ ├── interfaces.py │ ├── logging.py │ ├── mark.py │ ├── max_Q_over_c.py │ ├── max_Qc_over_d.py │ ├── ndimage.py │ ├── next_state.py │ ├── py.typed │ ├── random.py │ ├── simulation │ ├── __init__.py │ ├── processing.py │ └── simulate.py │ ├── solution │ ├── __init__.py │ └── solve_brute.py │ ├── state_action_space.py │ ├── typing.py │ ├── user_model.py │ └── utils.py └── tests ├── __init__.py ├── conftest.py ├── data ├── analytical_solution │ ├── iskhakov_2017_five_periods__consumption.csv │ ├── iskhakov_2017_five_periods__values_retired.csv │ ├── iskhakov_2017_five_periods__values_worker.csv │ ├── iskhakov_2017_five_periods__work_decision.csv │ ├── iskhakov_2017_low_delta__consumption.csv │ ├── iskhakov_2017_low_delta__values_retired.csv │ ├── iskhakov_2017_low_delta__values_worker.csv │ └── iskhakov_2017_low_delta__work_decision.csv └── regression_tests │ ├── simulation.pkl │ └── solution.pkl ├── input_processing ├── __init__.py ├── test_create_params_template.py └── test_process_model.py ├── simulation ├── __init__.py ├── test_processing.py └── test_simulate.py ├── solution ├── __init__.py └── test_solve_brute.py ├── test_Q_and_F.py ├── 
test_analytical_solution.py ├── test_argmax.py ├── test_dispatchers.py ├── test_entry_point.py ├── test_function_representation.py ├── test_functools.py ├── test_grid_helpers.py ├── test_grids.py ├── test_max_Qc_over_d.py ├── test_model.py ├── test_models ├── __init__.py ├── deterministic.py ├── discrete_deterministic.py ├── get_model.py └── stochastic.py ├── test_ndimage.py ├── test_ndimage_unit.py ├── test_next_state.py ├── test_random.py ├── test_regression_test.py ├── test_solution_on_toy_model.py ├── test_state_action_space.py ├── test_stochastic.py └── test_utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # SCM syntax highlighting & preventing 3-way merges 2 | pixi.lock merge=binary linguist-language=YAML linguist-generated=true 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Create a bug report to help us improve PyLCM 4 | title: "BUG:" 5 | labels: "bug" 6 | --- 7 | 8 | - [ ] I have checked that this issue has not already been reported. 9 | 10 | - [ ] I have confirmed this bug exists on the latest version of PyLCM. 11 | 12 | - [ ] (optional) I have confirmed this bug exists on the `main` branch of PyLCM. 13 | 14 | --- 15 | 16 | **Note**: Please read [this 17 | guide](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) detailing 18 | how to provide the necessary information for us to reproduce your bug. 19 | 20 | #### Code Sample, a copy-pastable example 21 | 22 | ```python 23 | # Your code here 24 | ``` 25 | 26 | #### Problem description 27 | 28 | Explain **why** the current behaviour is a problem and why the expected output is a 29 | better solution. 
30 | 31 | #### Expected Output 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/enhancement.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Enhancement 3 | about: Suggest an idea for PyLCM 4 | title: "ENH:" 5 | labels: "enhancement" 6 | --- 7 | 8 | #### Is your feature request related to a problem? 9 | 10 | Provide a description of what the problem is, e.g. "I wish I could use PyLCM to do 11 | [...]". 12 | 13 | #### Describe the solution you'd like 14 | 15 | Provide a description of the feature request and how it might be implemented. 16 | 17 | #### API breaking implications 18 | 19 | Provide a description of how this feature will affect the API. 20 | 21 | #### Describe alternatives you've considered 22 | 23 | Provide a description of any alternative solutions or features you've considered. 24 | 25 | #### Additional context 26 | 27 | Add any other context, code examples, or references to existing implementations about 28 | the feature request here. 29 | 30 | ```python 31 | # Your code here, if applicable 32 | ``` 33 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ### What problem do you want to solve? 2 | 3 | Reference the issue or discussion, if there is any. Provide a description of your 4 | proposed solution. 5 | 6 | ### Todo 7 | 8 | - [ ] Target the right branch and pick an appropriate title. 9 | - [ ] Put `Closes #XXXX` in the first PR comment to auto-close the relevant issue once 10 | the PR is accepted. This is not applicable if there is no corresponding issue. 11 | - [ ] Any steps that still need to be done. 
12 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | --- 2 | version: 2 3 | updates: 4 | - package-ecosystem: github-actions 5 | directory: / 6 | schedule: 7 | interval: weekly 8 | groups: 9 | github-actions: 10 | patterns: 11 | - '*' 12 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: main 3 | # Automatically cancel a previous run. 4 | concurrency: 5 | group: ${{ github.head_ref || github.run_id }} 6 | cancel-in-progress: true 7 | on: 8 | push: 9 | branches: 10 | - main 11 | pull_request: 12 | branches: 13 | - '*' 14 | jobs: 15 | run-tests: 16 | name: Run tests for ${{ matrix.os }} on ${{ matrix.python-version }} 17 | runs-on: ${{ matrix.os }} 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | os: 22 | - ubuntu-latest 23 | - macos-latest 24 | - windows-latest 25 | python-version: 26 | - '3.12' 27 | steps: 28 | - uses: actions/checkout@v4 29 | - uses: prefix-dev/setup-pixi@v0.8.8 30 | with: 31 | pixi-version: v0.41.4 32 | cache: true 33 | cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} 34 | environments: test-cpu 35 | activate-environment: true 36 | frozen: true 37 | - name: Run pytest 38 | shell: bash {0} 39 | run: pixi run -e test-cpu tests 40 | if: runner.os != 'Linux' || matrix.python-version != '3.12' 41 | - name: Run pytest and collect coverage 42 | shell: bash {0} 43 | run: pixi run -e test-cpu tests-with-cov 44 | if: runner.os == 'Linux' && matrix.python-version == '3.12' 45 | - name: Upload coverage report 46 | if: runner.os == 'Linux' && matrix.python-version == '3.12' 47 | uses: codecov/codecov-action@v5 48 | run-mypy: 49 | name: Run mypy on Python 3.12 50 | runs-on: ubuntu-latest 51 | strategy: 52 | fail-fast: false 53 | 
steps: 54 | - uses: actions/checkout@v4 55 | - uses: prefix-dev/setup-pixi@v0.8.8 56 | with: 57 | pixi-version: v0.41.4 58 | cache: true 59 | cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} 60 | environments: test-cpu 61 | frozen: true 62 | - name: Run mypy 63 | shell: bash {0} 64 | run: pixi run -e test-cpu mypy 65 | # run-explanation-notebooks: 66 | # name: Run explanation notebooks on Python 3.12 67 | # runs-on: ubuntu-latest 68 | # steps: 69 | # - uses: actions/checkout@v4 70 | # - uses: prefix-dev/setup-pixi@v0.8.8 71 | # with: 72 | # pixi-version: v0.41.4 73 | # cache: true 74 | # cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} 75 | # environments: test-cpu 76 | # frozen: true 77 | # - name: Run explanation notebooks 78 | # shell: bash {0} 79 | # run: pixi run -e test-cpu explanation-notebooks 80 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: PyPI 3 | on: push 4 | jobs: 5 | build-n-publish: 6 | name: Build and publish Python 🐍 distributions 📦 to PyPI 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - name: Set up Python 3.13 11 | uses: actions/setup-python@v5 12 | with: 13 | python-version: '3.13' 14 | - name: Install pypa/build 15 | run: >- 16 | python -m 17 | pip install 18 | build 19 | --user 20 | - name: Build a binary wheel and a source tarball 21 | run: >- 22 | python -m 23 | build 24 | --sdist 25 | --wheel 26 | --outdir dist/ 27 | - name: Publish distribution 📦 to PyPI 28 | if: startsWith(github.ref, 'refs/tags') 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | with: 31 | password: ${{ secrets.PYLCM_PYPI_TOKEN }} 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 
| # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # MacOS specific service store 7 | .DS_Store 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | *build/ 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | *.sublime-workspace 38 | *.sublime-project 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | docs/build/ 75 | docs/source/_build/ 76 | docs/source/refs.bib.bak 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # VSCode project settings 107 | .vscode 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | 118 | *notes/ 119 | 120 | .idea/ 121 | 122 | 
*.bak 123 | 124 | 125 | *.db 126 | 127 | 128 | .pytask.sqlite3 129 | src/lcm/_version.py 130 | .pixi 131 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | - repo: meta 4 | hooks: 5 | - id: check-hooks-apply 6 | - id: check-useless-excludes 7 | # - id: identity # Prints all files passed to pre-commits. Debugging. 8 | - repo: https://github.com/lyz-code/yamlfix 9 | rev: 1.17.0 10 | hooks: 11 | - id: yamlfix 12 | - repo: https://github.com/pre-commit/pre-commit-hooks 13 | rev: v5.0.0 14 | hooks: 15 | - id: check-added-large-files 16 | args: 17 | - --maxkb=1300 18 | - id: check-case-conflict 19 | - id: check-merge-conflict 20 | - id: check-vcs-permalinks 21 | - id: check-yaml 22 | - id: check-toml 23 | - id: debug-statements 24 | - id: end-of-file-fixer 25 | - id: fix-byte-order-marker 26 | types: 27 | - text 28 | - id: forbid-submodules 29 | - id: mixed-line-ending 30 | args: 31 | - --fix=lf 32 | description: Forces to replace line ending by the UNIX 'lf' character. 33 | - id: name-tests-test 34 | args: 35 | - --pytest-test-first 36 | exclude: ^tests/test_models/ 37 | - id: no-commit-to-branch 38 | args: 39 | - --branch 40 | - main 41 | - id: trailing-whitespace 42 | - id: check-ast 43 | - id: check-docstring-first 44 | - repo: https://github.com/adrienverge/yamllint.git 45 | rev: v1.37.1 46 | hooks: 47 | - id: yamllint 48 | - repo: https://github.com/astral-sh/ruff-pre-commit 49 | rev: v0.11.8 50 | hooks: 51 | # Run the linter. 52 | - id: ruff 53 | types_or: 54 | - python 55 | - pyi 56 | - jupyter 57 | args: 58 | - --fix 59 | # Run the formatter. 
60 | - id: ruff-format 61 | types_or: 62 | - python 63 | - pyi 64 | - jupyter 65 | - repo: https://github.com/executablebooks/mdformat 66 | rev: 0.7.22 67 | hooks: 68 | - id: mdformat 69 | additional_dependencies: 70 | - mdformat-gfm 71 | - mdformat-gfm-alerts 72 | - mdformat-ruff 73 | args: 74 | - --wrap 75 | - '88' 76 | files: (README\.md) 77 | - repo: https://github.com/kynan/nbstripout 78 | rev: 0.8.1 79 | hooks: 80 | - id: nbstripout 81 | args: 82 | - --drop-empty-cells 83 | - --keep-output 84 | - repo: https://github.com/pre-commit/mirrors-mypy 85 | rev: v1.15.0 86 | hooks: 87 | - id: mypy 88 | files: src|tests 89 | additional_dependencies: 90 | - dags>=0.3.0 91 | - jax>=0.5.1 92 | - numpy 93 | - packaging 94 | - pandas-stubs 95 | - pytest 96 | - scipy-stubs 97 | args: 98 | - --config=pyproject.toml 99 | ci: 100 | autoupdate_schedule: monthly 101 | skip: 102 | - mypy # installing jax is not possible on pre-commit.ci due to size limits. 103 | -------------------------------------------------------------------------------- /.yamllint.yml: -------------------------------------------------------------------------------- 1 | --- 2 | yaml-files: 3 | - '*.yaml' 4 | - '*.yml' 5 | - .yamllint 6 | rules: 7 | braces: enable 8 | brackets: enable 9 | colons: enable 10 | commas: enable 11 | comments: 12 | level: warning 13 | comments-indentation: 14 | level: warning 15 | document-end: disable 16 | document-start: 17 | level: warning 18 | empty-lines: enable 19 | empty-values: disable 20 | float-values: disable 21 | hyphens: enable 22 | indentation: {spaces: 2} 23 | key-duplicates: enable 24 | key-ordering: disable 25 | line-length: 26 | max: 88 27 | allow-non-breakable-words: true 28 | allow-non-breakable-inline-mappings: false 29 | new-line-at-end-of-file: enable 30 | new-lines: 31 | type: unix 32 | octal-values: disable 33 | quoted-strings: disable 34 | trailing-spaces: enable 35 | truthy: 36 | level: warning 37 | 
-------------------------------------------------------------------------------- /CHANGES.md: -------------------------------------------------------------------------------- 1 | # Changes 2 | 3 | 4 | This is a record of all past PyLCM releases and what went into them in reverse 5 | chronological order. We follow [semantic versioning](https://semver.org/). 6 | 7 | 8 | ## 0.0.1 9 | 10 | ### Initial Release 11 | 12 | - First public release of PyLCM. 13 | 14 | - Includes core functionality: 15 | 16 | - Specification of finite-horizon discrete-continuous choice models with an 17 | arbitrary number of discrete and continuous states and actions. 18 | 19 | - Linearly and log-linearly spaced grids that approximate continuous states and 20 | actions. 21 | 22 | - Linear interpolation and extrapolation of the value function for continuous 23 | states. 24 | 25 | - Grid search (brute-force) for finding the optimal continuous policy. 26 | 27 | - Stochastic state transitions for discrete states which may depend on other 28 | discrete states and actions. 29 | 30 | - Built with contributions from the PyLCM team. 31 | 32 | 33 | ### Contributions 34 | 35 | Thanks to everyone who contributed to this release: 36 | 37 | - {ghuser}`hmgaudecker` 38 | 39 | Initiated and drove the development agenda for PyLCM, ensuring strategic direction 40 | and alignment. He actively steered the project, facilitated collaboration, and secured 41 | funding to support core development. Additionally, he reviewed pull requests and 42 | provided feedback on the internal and external code structure and design. 43 | 44 | - {ghuser}`janosg` 45 | 46 | Designed and implemented the initial prototype of PyLCM, laying the foundation for its 47 | development. He onboarded {ghuser}`timmens` and played a key role in shaping the 48 | project's direction. After stepping back from active development, he contributed to 49 | implementation discussions and later provided guidance on architectural decisions.
50 | 51 | - {ghuser}`timmens` 52 | 53 | Took over development of PyLCM, expanding its functionality with key features like 54 | the simulation function, extrapolation capabilities, and special arguments. He led 55 | extensive refactoring to improve code clarity, maintainability, and testability, 56 | making the package easier to develop and extend. His contributions also include 57 | improved documentation, type annotations, static type checking, and the introduction 58 | of example and explanation notebooks. 59 | 60 | - {ghuser}`mj023` 61 | 62 | Analyzed and optimized PyLCM's performance on the GPU, profiling execution and 63 | examining the computational graph of JAX-compiled functions. He fine-tuned the `solve` 64 | function's just-in-time compilation to reduce runtime and improve efficiency. 65 | Additionally, he compared PyLCM's performance against similar libraries, providing 66 | insights into its computational efficiency. 67 | 68 | - {ghuser}`mo2561057` 69 | 70 | Added tests for the model processing and fully discrete models. 71 | 72 | - {ghuser}`MImmesberger` 73 | 74 | Added checks to test PyLCM's results against analytical solutions. 
75 | 76 | #### Early contributors 77 | 78 | - {ghuser}`segsell` 79 | 80 | - {ghuser}`ChristianZimpelmann` 81 | 82 | - {ghuser}`tobiasraabe` 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Life Cycle Models 2 | 3 | [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) 4 | [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/opensourceeconomics/pylcm/main.svg)](https://results.pre-commit.ci/latest/github/opensourceeconomics/pylcm/main) 5 | [![image](https://codecov.io/gh/opensourceeconomics/pylcm/branch/main/graph/badge.svg)](https://codecov.io/gh/opensourceeconomics/pylcm) 6 | 7 | This package aims to generalize and facilitate the specification, solution, and 8 | simulation of finite-horizon discrete-continuous dynamic choice models. 9 | 10 | ## Installation 11 | 12 | PyLCM can be installed via PyPI or via GitHub. To do so, type the following in a 13 | terminal: 14 | 15 | ```console 16 | $ pip install pylcm 17 | ``` 18 | 19 | or, for the latest development version, type: 20 | 21 | ```console 22 | $ pip install git+https://github.com/OpenSourceEconomics/pylcm.git 23 | ``` 24 | 25 | ### GPU Support 26 | 27 | By default, the installation of PyLCM comes with the CPU version of `jax`. If you aim to 28 | run PyLCM on a GPU, you need to install a `jaxlib` version with GPU support. For the 29 | installation of `jaxlib`, please consult the `jax` 30 | [docs](https://jax.readthedocs.io/en/latest/installation.html#supported-platforms). 31 | 32 | > [!NOTE] 33 | > GPU support is currently only tested on Linux with CUDA 12. 34 | 35 | ## Developing 36 | 37 | We use [pixi](https://pixi.sh/latest/) for our local development environment. 
If you 38 | want to work with or extend the PyLCM code base, you can run the tests using 39 | 40 | ```console 41 | $ git clone https://github.com/OpenSourceEconomics/pylcm.git 42 | $ cd pylcm && pixi run tests 43 | ``` 44 | 45 | This will install the development environment and run the tests. You can run 46 | [mypy](https://mypy-lang.org/) using 47 | 48 | ```console 49 | $ pixi run mypy 50 | ``` 51 | 52 | Before committing, install the pre-commit hooks using 53 | 54 | ```console 55 | $ pixi global install pre-commit 56 | $ pre-commit install 57 | ``` 58 | 59 | ## Questions 60 | 61 | If you have any questions, feel free to ask them on the PyLCM 62 | [Zulip chat](https://ose.zulipchat.com/#narrow/channel/491562-PyLCM). 63 | 64 | ## License 65 | 66 | This project is licensed under the Apache License, Version 2.0. See the 67 | [LICENSE](LICENSE) file for details. 68 | 69 | Copyright (c) 2023- The PyLCM Authors 70 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | --- 2 | codecov: 3 | notify: 4 | require_ci_to_pass: true 5 | coverage: 6 | precision: 2 7 | round: down 8 | range: 50...100 9 | status: 10 | patch: 11 | default: 12 | target: 70% 13 | project: 14 | default: 15 | target: 90% 16 | ignore: 17 | - tests/**/* 18 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Example model specifications 2 | 3 | ## Choosing an example 4 | 5 | | Example name | Description | Runtime | 6 | | ----------------------------------- | ------------------------------------------------- | ------------- | 7 | | [`long_running`](./long_running.py) | Consumption-savings model with health and leisure | a few minutes | 8 | 9 | ## Running an example 10 | 11 | Say you want to solve the `long_running` example locally.
First, clone this repository, 12 | [install pixi if required](https://pixi.sh/latest/#installation), move into the examples 13 | folder, and open the interactive Python shell. In a console, type: 14 | 15 | ```console 16 | $ git clone https://github.com/opensourceeconomics/pylcm.git 17 | $ cd pylcm/examples 18 | $ pixi run ipython 19 | ``` 20 | 21 | In that shell, run the following code: 22 | 23 | ```python 24 | from lcm.entry_point import get_lcm_function 25 | 26 | from long_running import MODEL_CONFIG, PARAMS 27 | 28 | 29 | solve_model, _ = get_lcm_function(model=MODEL_CONFIG, targets="solve") 30 | V_arr_list = solve_model(PARAMS) 31 | ``` 32 | -------------------------------------------------------------------------------- /examples/long_running.py: -------------------------------------------------------------------------------- 1 | """Example specification for a consumption-savings model with health and exercise.""" 2 | 3 | from dataclasses import dataclass 4 | 5 | import jax.numpy as jnp 6 | 7 | from lcm import DiscreteGrid, LinspaceGrid, Model 8 | 9 | # ====================================================================================== 10 | # Model functions 11 | # ====================================================================================== 12 | 13 | 14 | # -------------------------------------------------------------------------------------- 15 | # Categorical variables 16 | # -------------------------------------------------------------------------------------- 17 | @dataclass 18 | class WorkingStatus: 19 | retired: int = 0 20 | working: int = 1 21 | 22 | 23 | # -------------------------------------------------------------------------------------- 24 | # Utility function 25 | # -------------------------------------------------------------------------------------- 26 | def utility(consumption, working, health, exercise, disutility_of_work): 27 | return jnp.log(consumption) - (disutility_of_work - health) * working - exercise 28 | 29 | 30 | # 
-------------------------------------------------------------------------------------- 31 | # Auxiliary variables 32 | # -------------------------------------------------------------------------------------- 33 | def labor_income(wage, working): 34 | return wage * working 35 | 36 | 37 | def wage(age): 38 | return 1 + 0.1 * age 39 | 40 | 41 | def age(_period): 42 | return _period + 18 43 | 44 | 45 | # -------------------------------------------------------------------------------------- 46 | # State transitions 47 | # -------------------------------------------------------------------------------------- 48 | def next_wealth(wealth, consumption, labor_income, interest_rate): 49 | return (1 + interest_rate) * (wealth + labor_income - consumption) 50 | 51 | 52 | def next_health(health, exercise, working): 53 | return health * (1 + exercise - working / 2) 54 | 55 | 56 | # -------------------------------------------------------------------------------------- 57 | # Constraints 58 | # -------------------------------------------------------------------------------------- 59 | def consumption_constraint(consumption, wealth, labor_income): 60 | return consumption <= wealth + labor_income 61 | 62 | 63 | # ====================================================================================== 64 | # Model specification and parameters 65 | # ====================================================================================== 66 | RETIREMENT_AGE = 65 67 | 68 | 69 | MODEL_CONFIG = Model( 70 | n_periods=RETIREMENT_AGE - 18, 71 | functions={ 72 | "utility": utility, 73 | "next_wealth": next_wealth, 74 | "next_health": next_health, 75 | "consumption_constraint": consumption_constraint, 76 | "labor_income": labor_income, 77 | "wage": wage, 78 | "age": age, 79 | }, 80 | actions={ 81 | "working": DiscreteGrid(WorkingStatus), 82 | "consumption": LinspaceGrid( 83 | start=1, 84 | stop=100, 85 | n_points=100, 86 | ), 87 | "exercise": LinspaceGrid( 88 | start=0, 89 | stop=1, 90 | 
n_points=200, 91 | ), 92 | }, 93 | states={ 94 | "wealth": LinspaceGrid( 95 | start=1, 96 | stop=100, 97 | n_points=100, 98 | ), 99 | "health": LinspaceGrid( 100 | start=0, 101 | stop=1, 102 | n_points=100, 103 | ), 104 | }, 105 | ) 106 | 107 | PARAMS = { 108 | "beta": 0.95, 109 | "utility": {"disutility_of_work": 0.05}, 110 | "next_wealth": {"interest_rate": 0.05}, 111 | } 112 | -------------------------------------------------------------------------------- /explanations/README.md: -------------------------------------------------------------------------------- 1 | # Explanations of internal PyLCM concepts 2 | 3 | > [!NOTE] 4 | > 1. The following explanations are designed for PyLCM developers and not users. 5 | > 1. Figures are only rendered correctly on nbviewer, not on GitHub. Please use the 6 | > links below to view the correctly rendered notebooks. 7 | 8 | | Module name | Description | 9 | | ------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------- | 10 | | [`function_representation.py`](https://nbviewer.org/github/opensourceeconomics/pylcm/blob/main/explanations/function_representation.ipynb) | Explanations of what the function representation does and how it works | 11 | -------------------------------------------------------------------------------- /src/lcm/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | import pdbp # noqa: F401 3 | except ImportError: 4 | pass 5 | 6 | from lcm import mark 7 | from lcm.grids import DiscreteGrid, LinspaceGrid, LogspaceGrid 8 | from lcm.user_model import Model 9 | 10 | __all__ = ["DiscreteGrid", "LinspaceGrid", "LogspaceGrid", "Model", "mark"] 11 | -------------------------------------------------------------------------------- /src/lcm/_config.py: 
-------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | TEST_DATA = Path(__file__).parent.parent.parent.resolve().joinpath("tests", "data") 4 | -------------------------------------------------------------------------------- /src/lcm/argmax.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import Array 3 | 4 | # ====================================================================================== 5 | # argmax 6 | # ====================================================================================== 7 | 8 | 9 | def argmax_and_max( 10 | a: Array, 11 | axis: int | tuple[int, ...] | None = None, 12 | initial: float | None = None, 13 | where: Array | None = None, 14 | ) -> tuple[Array, Array]: 15 | """Compute the argmax of an n-dim array along axis. 16 | 17 | If multiple maxima exist, the first index will be selected. 18 | 19 | Args: 20 | a: Multidimensional array. 21 | axis: Axis along which to compute the argmax. If None, the argmax is computed 22 | over all axes. 23 | initial: The minimum value of an output element. Must be present to 24 | allow computation on empty slice. See ~numpy.ufunc.reduce for details. 25 | where: Elements to compare for the maximum. See ~numpy.ufunc.reduce 26 | for details. 27 | 28 | Returns: 29 | - The argmax indices. Array with the same shape as a, except for the dimensions 30 | specified in axis, which are dropped. The value corresponds to an index that 31 | can be translated into a tuple of indices using jnp.unravel_index. 32 | - The corresponding maximum values. 
33 | 34 | """ 35 | # Preparation 36 | # ================================================================================== 37 | if axis is None: 38 | axis = tuple(range(a.ndim)) 39 | elif isinstance(axis, int): 40 | axis = (axis,) 41 | 42 | # Move axis over which to compute the argmax to the back and flatten last dims 43 | # ================================================================================== 44 | a = _move_axes_to_back(a, axes=axis) 45 | a = _flatten_last_n_axes(a, n=len(axis)) 46 | 47 | # Do same transformation for where 48 | # ================================================================================== 49 | if where is not None: 50 | where = _move_axes_to_back(where, axes=axis) 51 | where = _flatten_last_n_axes(where, n=len(axis)) 52 | 53 | # Compute argmax over last dimension 54 | # ---------------------------------------------------------------------------------- 55 | # Note: If multiple maxima exist, this approach will select the first index. 56 | # ================================================================================== 57 | _max = jnp.max(a, axis=-1, keepdims=True, initial=initial, where=where) 58 | max_value_mask = a == _max 59 | if where is not None: 60 | max_value_mask = jnp.logical_and(max_value_mask, where) 61 | _argmax = jnp.argmax(max_value_mask, axis=-1) 62 | 63 | return _argmax, _max.reshape(_argmax.shape) 64 | 65 | 66 | def _move_axes_to_back(a: Array, axes: tuple[int, ...]) -> Array: 67 | """Move specified axes to the back of the array. 68 | 69 | Args: 70 | a: Multidimensional jax array. 71 | axes: Axes to move to the back. 72 | 73 | Returns: 74 | Array a with shifted axes. 75 | 76 | """ 77 | front_axes = sorted(set(range(a.ndim)) - set(axes)) 78 | return a.transpose((*front_axes, *axes)) 79 | 80 | 81 | def _flatten_last_n_axes(a: Array, n: int) -> Array: 82 | """Flatten the last n axes of a to 1 dimension. 83 | 84 | Args: 85 | a: Multidimensional jax array. 86 | n: Number of axes to flatten. 
87 | 88 | Returns: 89 | Array a with flattened last n axes. 90 | 91 | """ 92 | return a.reshape(*a.shape[:-n], -1) 93 | -------------------------------------------------------------------------------- /src/lcm/entry_point.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from functools import partial 3 | from typing import Literal 4 | 5 | import jax 6 | import pandas as pd 7 | from jax import Array 8 | 9 | from lcm.input_processing import process_model 10 | from lcm.interfaces import StateActionSpace, StateSpaceInfo 11 | from lcm.logging import get_logger 12 | from lcm.max_Q_over_c import ( 13 | get_argmax_and_max_Q_over_c, 14 | get_max_Q_over_c, 15 | ) 16 | from lcm.max_Qc_over_d import get_max_Qc_over_d 17 | from lcm.next_state import get_next_state_function 18 | from lcm.Q_and_F import ( 19 | get_Q_and_F, 20 | ) 21 | from lcm.simulation.simulate import simulate, solve_and_simulate 22 | from lcm.solution.solve_brute import solve 23 | from lcm.state_action_space import ( 24 | create_state_action_space, 25 | create_state_space_info, 26 | ) 27 | from lcm.typing import ( 28 | ArgmaxQOverCFunction, 29 | MaxQcOverDFunction, 30 | MaxQOverCFunction, 31 | ParamsDict, 32 | Target, 33 | ) 34 | from lcm.user_model import Model 35 | 36 | 37 | def get_lcm_function( 38 | model: Model, 39 | *, 40 | targets: Literal["solve", "simulate", "solve_and_simulate"], 41 | debug_mode: bool = True, 42 | jit: bool = True, 43 | ) -> tuple[Callable[..., dict[int, Array] | pd.DataFrame], ParamsDict]: 44 | """Entry point for users to get high level functions generated by lcm. 45 | 46 | Return the function to solve and/or simulate a model along with a template for the 47 | parameters. 48 | 49 | Advanced users might want to use lower level functions instead, but can read the 50 | source code of this function to see how the lower level components are meant to be 51 | used. 
52 | 53 | Args: 54 | model: User model specification. 55 | targets: The requested function types. Currently only "solve", "simulate" and 56 | "solve_and_simulate" are supported. 57 | debug_mode: Whether to log debug messages. 58 | jit: Whether to jit the returned function. 59 | 60 | Returns: 61 | - A function that takes params (and possibly other arguments, such as initial 62 | states in the simulate case) and returns the requested targets. 63 | - A parameter dictionary where all parameter values are initialized to NaN. 64 | 65 | """ 66 | # ================================================================================== 67 | # preparations 68 | # ================================================================================== 69 | if targets not in {"solve", "simulate", "solve_and_simulate"}: 70 | raise NotImplementedError 71 | 72 | internal_model = process_model(model) 73 | last_period = internal_model.n_periods - 1 74 | 75 | logger = get_logger(debug_mode=debug_mode) 76 | 77 | # ================================================================================== 78 | # Create model functions and state-action-spaces 79 | # ================================================================================== 80 | state_action_spaces: dict[int, StateActionSpace] = {} 81 | state_space_infos: dict[int, StateSpaceInfo] = {} 82 | max_Q_over_c_functions: dict[int, MaxQOverCFunction] = {} 83 | argmax_and_max_Q_over_c_functions: dict[int, ArgmaxQOverCFunction] = {} 84 | max_Qc_over_d_functions: dict[int, MaxQcOverDFunction] = {} 85 | 86 | for period in reversed(range(internal_model.n_periods)): 87 | is_last_period = period == last_period 88 | 89 | state_action_space = create_state_action_space( 90 | model=internal_model, 91 | is_last_period=is_last_period, 92 | ) 93 | 94 | state_space_info = create_state_space_info( 95 | model=internal_model, 96 | is_last_period=is_last_period, 97 | ) 98 | 99 | if is_last_period: 100 | next_state_space_info = LastPeriodsNextStateSpaceInfo 
101 | else: 102 | next_state_space_info = state_space_infos[period + 1] 103 | 104 | Q_and_F = get_Q_and_F( 105 | model=internal_model, 106 | next_state_space_info=next_state_space_info, 107 | period=period, 108 | ) 109 | 110 | max_Q_over_c = get_max_Q_over_c( 111 | Q_and_F=Q_and_F, 112 | continuous_actions_names=tuple(state_action_space.continuous_actions), 113 | states_and_discrete_actions_names=state_action_space.states_and_discrete_actions_names, 114 | ) 115 | 116 | argmax_and_max_Q_over_c = get_argmax_and_max_Q_over_c( 117 | Q_and_F=Q_and_F, 118 | continuous_actions_names=tuple(state_action_space.continuous_actions), 119 | ) 120 | 121 | max_Qc_over_d = get_max_Qc_over_d( 122 | random_utility_shock_type=internal_model.random_utility_shocks, 123 | variable_info=internal_model.variable_info, 124 | is_last_period=is_last_period, 125 | ) 126 | 127 | state_action_spaces[period] = state_action_space 128 | state_space_infos[period] = state_space_info 129 | max_Q_over_c_functions[period] = max_Q_over_c 130 | argmax_and_max_Q_over_c_functions[period] = argmax_and_max_Q_over_c 131 | max_Qc_over_d_functions[period] = max_Qc_over_d 132 | 133 | # ================================================================================== 134 | # select requested solver and partial arguments into it 135 | # ================================================================================== 136 | _solve_model = partial( 137 | solve, 138 | state_action_spaces=state_action_spaces, 139 | max_Q_over_c_functions=max_Q_over_c_functions, 140 | max_Qc_over_d_functions=max_Qc_over_d_functions, 141 | logger=logger, 142 | ) 143 | solve_model = jax.jit(_solve_model) if jit else _solve_model 144 | 145 | _next_state_simulate = get_next_state_function( 146 | model=internal_model, target=Target.SIMULATE 147 | ) 148 | next_state_simulate = jax.jit(_next_state_simulate) if jit else _next_state_simulate 149 | simulate_model = partial( 150 | simulate, 151 | 
argmax_and_max_Q_over_c_functions=argmax_and_max_Q_over_c_functions, 152 | model=internal_model, 153 | next_state=next_state_simulate, # type: ignore[arg-type] 154 | logger=logger, 155 | ) 156 | 157 | solve_and_simulate_model = partial( 158 | solve_and_simulate, 159 | argmax_and_max_Q_over_c_functions=argmax_and_max_Q_over_c_functions, 160 | model=internal_model, 161 | next_state=next_state_simulate, # type: ignore[arg-type] 162 | logger=logger, 163 | solve_model=solve_model, 164 | ) 165 | 166 | target_func: Callable[..., dict[int, Array] | pd.DataFrame] 167 | 168 | if targets == "solve": 169 | target_func = solve_model 170 | elif targets == "simulate": 171 | target_func = simulate_model 172 | elif targets == "solve_and_simulate": 173 | target_func = solve_and_simulate_model 174 | 175 | return target_func, internal_model.params 176 | 177 | 178 | LastPeriodsNextStateSpaceInfo = StateSpaceInfo( 179 | states_names=(), 180 | discrete_states={}, 181 | continuous_states={}, 182 | ) 183 | -------------------------------------------------------------------------------- /src/lcm/exceptions.py: -------------------------------------------------------------------------------- 1 | class ModelInitilizationError(Exception): 2 | """Raised when there is an error in the model initialization.""" 3 | 4 | 5 | class GridInitializationError(Exception): 6 | """Raised when there is an error in the grid initialization.""" 7 | 8 | 9 | def format_messages(errors: str | list[str]) -> str: 10 | """Convert message or list of messages into a single string.""" 11 | if isinstance(errors, str): 12 | formatted = errors 13 | elif len(errors) == 1: 14 | formatted = errors[0] 15 | else: 16 | enumerated = "\n\n".join([f"{i}. 
{error}" for i, error in enumerate(errors, 1)]) 17 | formatted = f"The following errors occurred:\n\n{enumerated}" 18 | return formatted 19 | -------------------------------------------------------------------------------- /src/lcm/grid_helpers.py: -------------------------------------------------------------------------------- 1 | """Functions to generate and work with different kinds of grids. 2 | 3 | Grid generation functions must have the following signature: 4 | 5 | Signature (start: Scalar, stop: Scalar, n_points: int) -> jax.Array 6 | 7 | They take start and end points and create a grid of points between them. 8 | 9 | 10 | Interpolation info functions must have the following signature: 11 | 12 | Signature ( 13 | value: Scalar, 14 | start: Scalar, 15 | stop: Scalar, 16 | n_points: int 17 | ) -> Scalar 18 | 19 | They take the information required to generate a grid, and return an index corresponding 20 | to the value, which is a point in the space but not necessarily a grid point. 21 | 22 | Some of the arguments will not be used by all functions but the aligned interface makes 23 | it easy to call functions interchangeably. 24 | 25 | """ 26 | 27 | import jax.numpy as jnp 28 | from jax import Array 29 | 30 | from lcm.typing import Scalar 31 | 32 | 33 | def linspace(start: Scalar, stop: Scalar, n_points: int) -> Array: 34 | """Wrapper around jnp.linspace. 35 | 36 | Returns a linearly spaced grid between start and stop with n_points, including both 37 | endpoints. 38 | 39 | """ 40 | return jnp.linspace(start, stop, n_points) 41 | 42 | 43 | def get_linspace_coordinate( 44 | value: Scalar, 45 | start: Scalar, 46 | stop: Scalar, 47 | n_points: int, 48 | ) -> Scalar: 49 | """Map a value into the input needed for jax.scipy.ndimage.map_coordinates.""" 50 | step_length = (stop - start) / (n_points - 1) 51 | return (value - start) / step_length 52 | 53 | 54 | def logspace(start: Scalar, stop: Scalar, n_points: int) -> Array: 55 | """Wrapper around jnp.logspace. 
56 | 57 | Returns a logarithmically spaced grid between start and stop with n_points, 58 | including both endpoints. 59 | 60 | From the JAX documentation: 61 | 62 | In linear space, the sequence starts at base ** start (base to the power of 63 | start) and ends with base ** stop [...]. 64 | 65 | """ 66 | start_linear = jnp.log(start) 67 | stop_linear = jnp.log(stop) 68 | return jnp.logspace(start_linear, stop_linear, n_points, base=jnp.e) 69 | 70 | 71 | def get_logspace_coordinate( 72 | value: Scalar, 73 | start: Scalar, 74 | stop: Scalar, 75 | n_points: int, 76 | ) -> Scalar: 77 | """Map a value into the input needed for jax.scipy.ndimage.map_coordinates.""" 78 | # Transform start, stop, and value to linear scale 79 | start_linear = jnp.log(start) 80 | stop_linear = jnp.log(stop) 81 | value_linear = jnp.log(value) 82 | 83 | # Calculate coordinate in linear space 84 | coordinate_in_linear_space = get_linspace_coordinate( 85 | value_linear, 86 | start_linear, 87 | stop_linear, 88 | n_points, 89 | ) 90 | 91 | # Calculate rank of lower and upper point in logarithmic space 92 | rank_lower_gridpoint = jnp.floor(coordinate_in_linear_space) 93 | rank_upper_gridpoint = rank_lower_gridpoint + 1 94 | 95 | # Calculate lower and upper point in logarithmic space 96 | step_length_linear = (stop_linear - start_linear) / (n_points - 1) 97 | lower_gridpoint = jnp.exp(start_linear + step_length_linear * rank_lower_gridpoint) 98 | upper_gridpoint = jnp.exp(start_linear + step_length_linear * rank_upper_gridpoint) 99 | 100 | # Calculate the decimal part of coordinate 101 | logarithmic_step_size_at_coordinate = upper_gridpoint - lower_gridpoint 102 | distance_from_lower_gridpoint = value - lower_gridpoint 103 | 104 | # If the distance from the lower gridpoint is zero, the coordinate corresponds to 105 | # the rank of the lower gridpoint. 
The other extreme is when the distance is equal 106 | # to the logarithmic step size at the coordinate, in which case the coordinate 107 | # corresponds to the rank of the upper gridpoint. For values in between, the 108 | # coordinate lies on a linear scale between the ranks of the lower and upper 109 | # gridpoints. 110 | decimal_part = distance_from_lower_gridpoint / logarithmic_step_size_at_coordinate 111 | return rank_lower_gridpoint + decimal_part 112 | -------------------------------------------------------------------------------- /src/lcm/input_processing/__init__.py: -------------------------------------------------------------------------------- 1 | from .process_model import process_model 2 | 3 | __all__ = ["process_model"] 4 | -------------------------------------------------------------------------------- /src/lcm/input_processing/create_params_template.py: -------------------------------------------------------------------------------- 1 | """Create a parameter template for a model specification.""" 2 | 3 | import inspect 4 | 5 | import jax.numpy as jnp 6 | import pandas as pd 7 | from jax import Array 8 | 9 | from lcm.input_processing.util import get_grids, get_variable_info 10 | from lcm.typing import ParamsDict 11 | from lcm.user_model import Model 12 | 13 | 14 | def create_params_template( 15 | model: Model, 16 | default_params: dict[str, float] = {"beta": jnp.nan}, # noqa: B006 17 | ) -> ParamsDict: 18 | """Create parameter template from a model specification. 19 | 20 | Args: 21 | model: The model as provided by the user. 22 | default_params: A dictionary of default parameters. Default is None. If None, 23 | the default {"beta": np.nan} is used. For other lifetime reward objectives, 24 | additional parameters may be required, for example {"beta": np.nan, "delta": 25 | np.nan} for beta-delta discounting. 26 | 27 | Returns: 28 | A nested dictionary of model parameters. 
29 | 30 | """ 31 | variable_info = get_variable_info(model) 32 | grids = get_grids(model) 33 | 34 | if variable_info["is_stochastic"].any(): 35 | stochastic_transitions = _create_stochastic_transition_params( 36 | model=model, 37 | variable_info=variable_info, 38 | grids=grids, 39 | ) 40 | stochastic_transition_params = {"shocks": stochastic_transitions} 41 | else: 42 | stochastic_transition_params = {} 43 | 44 | function_params = _create_function_params(model) 45 | 46 | return default_params | function_params | stochastic_transition_params 47 | 48 | 49 | def _create_function_params(model: Model) -> dict[str, dict[str, float]]: 50 | """Get function parameters from a model specification. 51 | 52 | Explanation: We consider the arguments of all model functions, from which we exclude 53 | all variables that are states, actions or the period argument. Everything else is 54 | considered a parameter of the respective model function that is provided by the 55 | user. 56 | 57 | Args: 58 | model: The model as provided by the user. 59 | 60 | Returns: 61 | A dictionary for each model function, containing a parameters required in the 62 | model functions, initialized with jnp.nan. 63 | 64 | """ 65 | # Collect all model variables, that includes actions, states, the period, and 66 | # auxiliary variables (model function names). 67 | variables = { 68 | *model.functions, 69 | *model.actions, 70 | *model.states, 71 | "_period", 72 | } 73 | 74 | if hasattr(model, "shocks"): 75 | variables = variables | set(model.shocks) 76 | 77 | function_params = {} 78 | # For each model function, capture the arguments of the function that are not in the 79 | # set of model variables, and initialize them. 
80 | for name, func in model.functions.items(): 81 | arguments = set(inspect.signature(func).parameters) 82 | params = sorted(arguments.difference(variables)) 83 | function_params[name] = dict.fromkeys(params, jnp.nan) 84 | 85 | return function_params 86 | 87 | 88 | def _create_stochastic_transition_params( 89 | model: Model, 90 | variable_info: pd.DataFrame, 91 | grids: dict[str, Array], 92 | ) -> dict[str, Array]: 93 | """Create parameters for stochastic transitions. 94 | 95 | Args: 96 | model: The model as provided by the user. 97 | variable_info: A dataframe with information about the variables. 98 | grids: A dictionary of grids consistent with model. 99 | 100 | Returns: 101 | A dictionary of parameters required for stochastic transitions, initialized with 102 | jnp.nan matrices of the correct dimensions. 103 | 104 | """ 105 | stochastic_variables = variable_info.query("is_stochastic").index.tolist() 106 | 107 | # Assert that all stochastic variables are discrete state variables 108 | # ================================================================================== 109 | discrete_state_vars = set(variable_info.query("is_state & is_discrete").index) 110 | 111 | invalid = set(stochastic_variables) - discrete_state_vars 112 | if invalid: 113 | raise ValueError( 114 | f"The following variables are stochastic, but are not discrete state " 115 | f"variables: {invalid}. This is currently not supported.", 116 | ) 117 | 118 | # Create template matrices for stochastic transitions 119 | # ================================================================================== 120 | 121 | # Stochastic transition functions can only depend on discrete vars or '_period'. 
122 | valid_vars = set(variable_info.query("is_discrete").index) | {"_period"} 123 | 124 | stochastic_transition_params = {} 125 | invalid_dependencies = {} 126 | 127 | for var in stochastic_variables: 128 | # Retrieve corresponding next function and its arguments 129 | next_var = model.functions[f"next_{var}"] 130 | dependencies = list(inspect.signature(next_var).parameters) 131 | 132 | # If there are invalid dependencies, store them in a dictionary and continue 133 | # with the next variable to collect as many invalid arguments as possible. 134 | invalid = set(dependencies) - valid_vars 135 | if invalid: 136 | invalid_dependencies[var] = invalid 137 | else: 138 | # Get the dimensions of variables that influence the stochastic variable 139 | dimensions_of_deps = [ 140 | len(grids[arg]) if arg != "_period" else model.n_periods 141 | for arg in dependencies 142 | ] 143 | # Add the dimension of the stochastic variable itself at the end 144 | dimensions = (*dimensions_of_deps, len(grids[var])) 145 | 146 | stochastic_transition_params[var] = jnp.full(dimensions, jnp.nan) 147 | 148 | # Raise an error if there are invalid arguments 149 | # ================================================================================== 150 | if invalid_dependencies: 151 | raise ValueError( 152 | f"Stochastic transition functions can only depend on discrete variables or " 153 | "'_period'. 
The following variables have invalid arguments: " 154 | f"{invalid_dependencies}.", 155 | ) 156 | 157 | return stochastic_transition_params 158 | -------------------------------------------------------------------------------- /src/lcm/input_processing/util.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from dags import get_ancestors 3 | from jax import Array 4 | 5 | from lcm.grids import ContinuousGrid, Grid 6 | from lcm.typing import UserFunction 7 | from lcm.user_model import Model 8 | 9 | 10 | def get_function_info(model: Model) -> pd.DataFrame: 11 | """Derive information about functions in the model. 12 | 13 | Args: 14 | model: The model as provided by the user. 15 | 16 | Returns: 17 | A table with information about all functions in the model. The index contains 18 | the name of a model function. The columns are booleans that are True if the 19 | function has the corresponding property. The columns are: is_next, 20 | is_stochastic_next, is_constraint. 21 | 22 | """ 23 | info = pd.DataFrame(index=list(model.functions)) 24 | # Convert both filter and constraint to constraints, until we forbid filters. 25 | info["is_constraint"] = info.index.str.endswith(("_constraint", "_filter")) 26 | info["is_next"] = info.index.str.startswith("next_") & ~info["is_constraint"] 27 | info["is_stochastic_next"] = [ 28 | hasattr(func, "_stochastic_info") and info.loc[func_name]["is_next"] 29 | for func_name, func in model.functions.items() 30 | ] 31 | return info 32 | 33 | 34 | def get_variable_info(model: Model) -> pd.DataFrame: 35 | """Derive information about all variables in the model. 36 | 37 | Args: 38 | model: The model as provided by the user. 39 | 40 | Returns: 41 | A table with information about all variables in the model. The index contains 42 | the name of a model variable. The columns are booleans that are True if the 43 | variable has the corresponding property. 
The columns are: is_state, is_action, 44 | is_continuous, is_discrete. 45 | 46 | """ 47 | function_info = get_function_info(model) 48 | 49 | variables = model.states | model.actions 50 | 51 | info = pd.DataFrame(index=list(variables)) 52 | 53 | info["is_state"] = info.index.isin(model.states) 54 | info["is_action"] = ~info["is_state"] 55 | 56 | info["is_continuous"] = [ 57 | isinstance(spec, ContinuousGrid) for spec in variables.values() 58 | ] 59 | info["is_discrete"] = ~info["is_continuous"] 60 | 61 | info["is_stochastic"] = [ 62 | (var in model.states and function_info.loc[f"next_{var}", "is_stochastic_next"]) 63 | for var in variables 64 | ] 65 | 66 | auxiliary_variables = _get_auxiliary_variables( 67 | state_variables=info.query("is_state").index.tolist(), 68 | function_info=function_info, 69 | user_functions=model.functions, 70 | ) 71 | info["is_auxiliary"] = [var in auxiliary_variables for var in variables] 72 | 73 | order = info.query("is_discrete & is_state").index.tolist() 74 | order += info.query("is_discrete & is_action").index.tolist() 75 | order += info.query("is_continuous & is_state").index.tolist() 76 | order += info.query("is_continuous & is_action").index.tolist() 77 | 78 | if set(order) != set(info.index): 79 | raise ValueError("Order and index do not match.") 80 | 81 | return info.loc[order] 82 | 83 | 84 | def _get_auxiliary_variables( 85 | state_variables: list[str], 86 | function_info: pd.DataFrame, 87 | user_functions: dict[str, UserFunction], 88 | ) -> list[str]: 89 | """Get state variables that only occur in next functions. 90 | 91 | Args: 92 | state_variables: List of state variable names. 93 | function_info: A table with information about all 94 | functions in the model. The index contains the name of a function. The 95 | columns are booleans that are True if the function has the corresponding 96 | property. The columns are: is_filter, is_constraint, is_next. 97 | user_functions: Dictionary that maps names of functions to functions. 
98 | 99 | Returns: 100 | List of state variable names that are only used in next functions. 101 | 102 | """ 103 | non_next_functions = function_info.query("~is_next").index.tolist() 104 | user_functions = {name: user_functions[name] for name in non_next_functions} 105 | ancestors = get_ancestors( 106 | user_functions, 107 | targets=list(user_functions), 108 | include_targets=True, 109 | ) 110 | return list(set(state_variables).difference(set(ancestors))) 111 | 112 | 113 | def get_gridspecs( 114 | model: Model, 115 | ) -> dict[str, Grid]: 116 | """Create a dictionary of grid specifications for each variable in the model. 117 | 118 | Args: 119 | model (dict): The model as provided by the user. 120 | 121 | Returns: 122 | Dictionary containing all variables of the model. The keys are the names of the 123 | variables. The values describe which values the variable can take. For discrete 124 | variables these are the codes. For continuous variables this is information 125 | about how to build the grids. 126 | 127 | """ 128 | variable_info = get_variable_info(model) 129 | 130 | raw_variables = model.states | model.actions 131 | order = variable_info.index.tolist() 132 | return {k: raw_variables[k] for k in order} 133 | 134 | 135 | def get_grids( 136 | model: Model, 137 | ) -> dict[str, Array]: 138 | """Create a dictionary of array grids for each variable in the model. 139 | 140 | Args: 141 | model: The model as provided by the user. 142 | 143 | Returns: 144 | Dictionary containing all variables of the model. The keys are the names of the 145 | variables. The values are the grids. 
146 | 147 | """ 148 | variable_info = get_variable_info(model) 149 | gridspecs = get_gridspecs(model) 150 | 151 | grids = {name: spec.to_jax() for name, spec in gridspecs.items()} 152 | order = variable_info.index.tolist() 153 | return {k: grids[k] for k in order} 154 | -------------------------------------------------------------------------------- /src/lcm/interfaces.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from collections.abc import Mapping 3 | 4 | import pandas as pd 5 | from jax import Array 6 | 7 | from lcm.grids import ContinuousGrid, DiscreteGrid, Grid 8 | from lcm.typing import InternalUserFunction, ParamsDict, ShockType 9 | from lcm.utils import first_non_none 10 | 11 | 12 | @dataclasses.dataclass(frozen=True) 13 | class StateActionSpace: 14 | """The state-action space. 15 | 16 | When used for the model solution: 17 | --------------------------------- 18 | 19 | The state-action space becomes the full Cartesian product of the state variables and 20 | the action variables. 21 | 22 | When used for the simulation: 23 | ---------------------------- 24 | 25 | The state-action space becomes the product of state-combinations with the full 26 | Cartesian product of the action variables. 27 | 28 | In both cases, infeasible state-action combinations will be masked. 29 | 30 | Note: 31 | ----- 32 | We store discrete and continuous actions separately since these are handled during 33 | different stages of the solution and simulation processes. 34 | 35 | Attributes: 36 | states: Dictionary containing the values of the state variables. 37 | discrete_actions: Dictionary containing the values of the discrete action 38 | variables. 39 | continuous_actions: Dictionary containing the values of the continuous action 40 | variables. 41 | states_and_discrete_actions_names: Tuple with names of states and discrete 42 | action variables in the order they appear in the variable info table. 
43 | 44 | """ 45 | 46 | states: dict[str, Array] 47 | discrete_actions: dict[str, Array] 48 | continuous_actions: dict[str, Array] 49 | states_and_discrete_actions_names: tuple[str, ...] 50 | 51 | def replace( 52 | self, 53 | states: dict[str, Array] | None = None, 54 | discrete_actions: dict[str, Array] | None = None, 55 | continuous_actions: dict[str, Array] | None = None, 56 | ) -> "StateActionSpace": 57 | """Replace the states or actions in the state-action space. 58 | 59 | Args: 60 | states: Dictionary with new states. If None, the existing states are used. 61 | discrete_actions: Dictionary with new discrete actions. If None, the 62 | existing discrete actions are used. 63 | continuous_actions: Dictionary with new continuous actions. If None, the 64 | existing continuous actions are used. 65 | 66 | Returns: 67 | New state-action space with the replaced states or actions. 68 | 69 | """ 70 | states = first_non_none(states, self.states) 71 | discrete_actions = first_non_none(discrete_actions, self.discrete_actions) 72 | continuous_actions = first_non_none(continuous_actions, self.continuous_actions) 73 | return dataclasses.replace( 74 | self, 75 | states=states, 76 | discrete_actions=discrete_actions, 77 | continuous_actions=continuous_actions, 78 | ) 79 | 80 | 81 | @dataclasses.dataclass(frozen=True) 82 | class StateSpaceInfo: 83 | """Information to work with the output of a function evaluated on a state space. 84 | 85 | An example is the value function array, which is the output of the value function 86 | evaluated on the state space. 87 | 88 | Attributes: 89 | var_names: Tuple with names of state variables. 90 | discrete_vars: Dictionary with grids of discrete state variables. 91 | continuous_vars: Dictionary with grids of continuous state variables. 92 | 93 | """ 94 | 95 | states_names: tuple[str, ...] 
96 | discrete_states: Mapping[str, DiscreteGrid] 97 | continuous_states: Mapping[str, ContinuousGrid] 98 | 99 | 100 | @dataclasses.dataclass(frozen=True) 101 | class InternalModel: 102 | """Internal representation of a user model. 103 | 104 | Attributes: 105 | grids: Dictionary that maps names of model variables to grids of feasible values 106 | for that variable. 107 | gridspecs: Dictionary that maps names of model variables to specifications from 108 | which grids of feasible values can be built. 109 | variable_info: A table with information about all variables in the model. The 110 | index contains the name of a model variable. The columns are booleans that 111 | are True if the variable has the corresponding property. The columns are: 112 | is_state, is_action, is_continuous, is_discrete. 113 | functions: Dictionary that maps names of functions to functions. The functions 114 | differ from the user functions in that they take `params` as a keyword 115 | argument. Two cases: 116 | - If the original function depended on model parameters, those are 117 | automatically extracted from `params` and passed to the original 118 | function. 119 | - Otherwise, the `params` argument is simply ignored. 120 | function_info: A table with information about all functions in the model. The 121 | index contains the name of a function. The columns are booleans that are 122 | True if the function has the corresponding property. The columns are: 123 | is_constraint, is_next. 124 | params: Dict of model parameters. 125 | n_periods: Number of periods. 126 | random_utility_shocks: Type of random utility shocks. 
127 | 128 | """ 129 | 130 | grids: dict[str, Array] 131 | gridspecs: dict[str, Grid] 132 | variable_info: pd.DataFrame 133 | functions: dict[str, InternalUserFunction] 134 | function_info: pd.DataFrame 135 | params: ParamsDict 136 | n_periods: int 137 | # Not properly processed yet 138 | random_utility_shocks: ShockType 139 | 140 | 141 | @dataclasses.dataclass(frozen=True) 142 | class InternalSimulationPeriodResults: 143 | """The results of a simulation for one period.""" 144 | 145 | value: Array 146 | actions: dict[str, Array] 147 | states: dict[str, Array] 148 | -------------------------------------------------------------------------------- /src/lcm/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def get_logger(*, debug_mode: bool) -> logging.Logger: 5 | """Get a logger that logs to stdout. 6 | 7 | Args: 8 | debug_mode: Whether to log debug messages. 9 | 10 | Returns: 11 | Logger that logs to stdout. 12 | 13 | """ 14 | logging.basicConfig(level=logging.WARNING) 15 | logger = logging.getLogger("lcm") 16 | 17 | if debug_mode: 18 | logger.setLevel(logging.DEBUG) 19 | 20 | return logger 21 | -------------------------------------------------------------------------------- /src/lcm/mark.py: -------------------------------------------------------------------------------- 1 | """Collection of LCM marking decorators.""" 2 | 3 | import functools 4 | from collections.abc import Callable 5 | from dataclasses import dataclass 6 | from typing import Any, ParamSpec, TypeVar 7 | 8 | P = ParamSpec("P") 9 | R = TypeVar("R") 10 | 11 | 12 | @dataclass(frozen=True) 13 | class StochasticInfo: 14 | """Information on the stochastic nature of user provided functions.""" 15 | 16 | 17 | def stochastic( 18 | func: Callable[..., Any], 19 | *args: tuple[Any, ...], 20 | **kwargs: dict[str, Any], 21 | ) -> Callable[..., Any]: 22 | """Decorator to mark a function as stochastic and add information. 
23 | 24 | Args: 25 | func (callable): The function to be decorated. 26 | *args (list): Positional arguments to be passed to the StochasticInfo. 27 | **kwargs (dict): Keyword arguments to be passed to the StochasticInfo. 28 | 29 | Returns: 30 | The decorated function 31 | 32 | """ 33 | stochastic_info = StochasticInfo(*args, **kwargs) 34 | 35 | def decorator_stochastic(func: Callable[P, R]) -> Callable[P, R]: 36 | @functools.wraps(func) 37 | def wrapper_mark_stochastic(*args: P.args, **kwargs: P.kwargs) -> R: 38 | return func(*args, **kwargs) 39 | 40 | wrapper_mark_stochastic._stochastic_info = stochastic_info # type: ignore[attr-defined] 41 | return wrapper_mark_stochastic 42 | 43 | return decorator_stochastic(func) if callable(func) else decorator_stochastic 44 | -------------------------------------------------------------------------------- /src/lcm/max_Q_over_c.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from collections.abc import Callable 3 | 4 | import jax.numpy as jnp 5 | from jax import Array 6 | 7 | from lcm.argmax import argmax_and_max 8 | from lcm.dispatchers import productmap 9 | from lcm.typing import ArgmaxQOverCFunction, MaxQOverCFunction, ParamsDict, Scalar 10 | 11 | 12 | def get_max_Q_over_c( 13 | Q_and_F: Callable[..., tuple[Array, Array]], 14 | continuous_actions_names: tuple[str, ...], 15 | states_and_discrete_actions_names: tuple[str, ...], 16 | ) -> MaxQOverCFunction: 17 | r"""Get the function returning the maximum of Q over continuous actions. 18 | 19 | The state-action value function $Q$ is defined as: 20 | 21 | ```{math} 22 | Q(x, a) = H(U(x, a), \mathbb{E}[V(x', a') | x, a]), 23 | ``` 24 | with $H(u, v) = u + \beta \cdot v$ as the leading case (which is the only one that 25 | is pre-implemented in LCM).
26 | 27 | Fixing a state and discrete action, maximizing over the feasible continuous actions, 28 | we get the $Q^c$ function: 29 | 30 | ```{math} 31 | Q^{c}(x, a^d) = \max_{a^c} Q(x, a^d, a^c). 32 | ``` 33 | 34 | This last step is handled by the function returned here. 35 | 36 | Args: 37 | Q_and_F: A function that takes a state-action combination and returns the action 38 | value of that combination and whether the state-action combination is 39 | feasible. 40 | continuous_actions_names: Tuple of action variable names that are continuous. 41 | states_and_discrete_actions_names: Tuple of state and discrete action variable 42 | names. 43 | 44 | Returns: 45 | Qc, i.e., the function that calculates the maximum of the Q-function over the 46 | feasible continuous actions. 47 | 48 | """ 49 | if continuous_actions_names: 50 | Q_and_F = productmap( 51 | func=Q_and_F, 52 | variables=continuous_actions_names, 53 | ) 54 | 55 | @functools.wraps(Q_and_F) 56 | def max_Q_over_c(next_V_arr: Array, params: ParamsDict, **kwargs: Scalar) -> Array: 57 | Q_arr, F_arr = Q_and_F(params=params, next_V_arr=next_V_arr, **kwargs) 58 | return Q_arr.max(where=F_arr, initial=-jnp.inf) 59 | 60 | return productmap(max_Q_over_c, variables=states_and_discrete_actions_names) 61 | 62 | 63 | def get_argmax_and_max_Q_over_c( 64 | Q_and_F: Callable[..., tuple[Array, Array]], 65 | continuous_actions_names: tuple[str, ...], 66 | ) -> ArgmaxQOverCFunction: 67 | r"""Get the function returning the arguments maximizing Q over continuous actions. 68 | 69 | The state-action value function $Q$ is defined as: 70 | 71 | ```{math} 72 | Q(x, a) = H(U(x, a), \mathbb{E}[V(x', a') | x, a]), 73 | ``` 74 | with $H(u, v) = u + \beta \cdot v$ as the leading case (which is the only one that 75 | is pre-implemented in LCM). 76 | 77 | Fixing a state and discrete action but choosing the feasible continuous actions that 78 | maximize Q, we get 79 | 80 | ```{math} 81 | \pi^{c}(x, a^d) = \argmax_{a^c} Q(x, a^d, a^c).
82 | ``` 83 | 84 | This last step is handled by the function returned here. 85 | 86 | Args: 87 | Q_and_F: A function that takes a state-action combination and returns the action 88 | value of that combination and whether the state-action combination is 89 | feasible. 90 | continuous_actions_names: Tuple of action variable names that are continuous. 91 | 92 | Returns: 93 | Function that calculates the argument maximizing Q over the feasible continuous 94 | actions and the maximum itself. The argument maximizing Q is the policy 95 | function of the continuous actions, conditional on the states and discrete 96 | actions. The maximum corresponds to the Qc-function. 97 | 98 | """ 99 | if continuous_actions_names: 100 | Q_and_F = productmap( 101 | func=Q_and_F, 102 | variables=continuous_actions_names, 103 | ) 104 | 105 | @functools.wraps(Q_and_F) 106 | def argmax_and_max_Q_over_c( 107 | next_V_arr: Array, params: ParamsDict, **kwargs: Scalar 108 | ) -> tuple[Array, Array]: 109 | Q_arr, F_arr = Q_and_F(params=params, next_V_arr=next_V_arr, **kwargs) 110 | return argmax_and_max(Q_arr, where=F_arr, initial=-jnp.inf) 111 | 112 | return argmax_and_max_Q_over_c 113 | -------------------------------------------------------------------------------- /src/lcm/ndimage.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The JAX Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | # Modifications made by Tim Mensinger, 2024 16 | 17 | import functools 18 | import itertools 19 | import operator 20 | from collections.abc import Sequence 21 | 22 | import jax.numpy as jnp 23 | from jax import Array, jit, lax, util 24 | 25 | 26 | @jit 27 | def map_coordinates( 28 | input: Array, 29 | coordinates: Sequence[Array], 30 | ) -> Array: 31 | """Map the input array to new coordinates using linear interpolation. 32 | 33 | Modified from JAX implementation of :func:`scipy.ndimage.map_coordinates`. 34 | 35 | Given an input array and a set of coordinates, this function returns the 36 | interpolated values of the input array at those coordinates. For coordinates outside 37 | the input array, linear extrapolation is used. 38 | 39 | Args: 40 | input: N-dimensional input array from which values are interpolated. 41 | coordinates: length-N sequence of arrays specifying the coordinates 42 | at which to evaluate the interpolated values 43 | 44 | Returns: 45 | The interpolated (extrapolated) values at the specified coordinates. 
46 | 47 | """ 48 | if len(coordinates) != input.ndim: 49 | raise ValueError( 50 | "coordinates must be a sequence of length input.ndim, but " 51 | f"{len(coordinates)} != {input.ndim}" 52 | ) 53 | 54 | interpolation_data = [ 55 | _compute_indices_and_weights(coordinate, size) 56 | for coordinate, size in util.safe_zip(coordinates, input.shape) 57 | ] 58 | 59 | interpolation_values = [] 60 | for indices_and_weights in itertools.product(*interpolation_data): 61 | indices, weights = util.unzip2(indices_and_weights) 62 | contribution = input[indices] 63 | weighted_value = _multiply_all(weights) * contribution 64 | interpolation_values.append(weighted_value) 65 | 66 | result = _sum_all(interpolation_values) 67 | 68 | if jnp.issubdtype(input.dtype, jnp.integer): 69 | result = _round_half_away_from_zero(result) 70 | 71 | return result.astype(input.dtype) 72 | 73 | 74 | def _compute_indices_and_weights( 75 | coordinate: Array, input_size: int 76 | ) -> list[tuple[Array, Array]]: 77 | """Compute indices and weights for linear interpolation.""" 78 | lower_index = jnp.clip(jnp.floor(coordinate), 0, input_size - 2).astype(jnp.int32) 79 | upper_weight = coordinate - lower_index 80 | lower_weight = 1 - upper_weight 81 | return [(lower_index, lower_weight), (lower_index + 1, upper_weight)] 82 | 83 | 84 | def _multiply_all(arrs: Sequence[Array]) -> Array: 85 | """Multiply all arrays in the sequence.""" 86 | return functools.reduce(operator.mul, arrs) 87 | 88 | 89 | def _sum_all(arrs: Sequence[Array]) -> Array: 90 | """Sum all arrays in the sequence.""" 91 | return functools.reduce(operator.add, arrs) 92 | 93 | 94 | def _round_half_away_from_zero(a: Array) -> Array: 95 | return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a) 96 | -------------------------------------------------------------------------------- /src/lcm/next_state.py: -------------------------------------------------------------------------------- 1 | """Generate function that compute the next states for 
solution and simulation.""" 2 | 3 | from collections.abc import Callable 4 | 5 | from dags import concatenate_functions 6 | from dags.signature import with_signature 7 | from jax import Array 8 | 9 | from lcm.interfaces import InternalModel 10 | from lcm.random import random_choice 11 | from lcm.typing import Scalar, StochasticNextFunction, Target 12 | 13 | 14 | def get_next_state_function( 15 | model: InternalModel, 16 | target: Target, 17 | ) -> Callable[..., dict[str, Scalar]]: 18 | """Get function that computes the next states for solution or simulation. 19 | 20 | Args: 21 | model: Internal model instance. 22 | target: Whether to generate the function for the solve or simulate target. 23 | 24 | Returns: 25 | Function that computes the next states. Depends on states and actions of the 26 | current period, and the model parameters ("params"). If target is "simulate", 27 | the function also depends on the dictionary of random keys ("keys"), whose 28 | keys correspond to the names of stochastic next functions. 29 | 30 | """ 31 | targets = model.function_info.query("is_next").index.tolist() 32 | 33 | if target == Target.SOLVE: 34 | functions_dict = model.functions 35 | elif target == Target.SIMULATE: 36 | # For the simulation target, we need to extend the functions dictionary with 37 | # stochastic next states functions and their weights. 38 | functions_dict = _extend_functions_dict_for_simulation(model) 39 | else: 40 | raise ValueError(f"Invalid target: {target}") 41 | 42 | return concatenate_functions( 43 | functions=functions_dict, 44 | targets=targets, 45 | return_type="dict", 46 | enforce_signature=False, 47 | ) 48 | 49 | 50 | def get_next_stochastic_weights_function( 51 | model: InternalModel, 52 | ) -> Callable[..., dict[str, Array]]: 53 | """Get function that computes the weights for the next stochastic states. 54 | 55 | Args: 56 | model: Internal model instance. 57 | 58 | Returns: 59 | Function that computes the weights for the next stochastic states.
60 | 61 | """ 62 | targets = [ 63 | f"weight_{name}" 64 | for name in model.function_info.query("is_stochastic_next").index.tolist() 65 | ] 66 | 67 | return concatenate_functions( 68 | functions=model.functions, 69 | targets=targets, 70 | return_type="dict", 71 | enforce_signature=False, 72 | ) 73 | 74 | 75 | def _extend_functions_dict_for_simulation( 76 | model: InternalModel, 77 | ) -> dict[str, Callable[..., Scalar]]: 78 | """Extend the functions dictionary for the simulation target. 79 | 80 | Args: 81 | model: Internal model instance. 82 | 83 | Returns: 84 | Extended functions dictionary. 85 | 86 | """ 87 | stochastic_targets = model.function_info.query("is_stochastic_next").index 88 | 89 | # Handle stochastic next states functions 90 | # ---------------------------------------------------------------------------------- 91 | # We generate stochastic next states functions that simulate the next state given 92 | # a random key (think of a seed) and the weights corresponding to the labels of the 93 | # stochastic variable. The weights are computed using the stochastic weight 94 | # functions, which we add to the functions dict. `dags.concatenate_functions` then 95 | # generates a function that computes the weights and simulates the next state in 96 | # one go.
97 | # ---------------------------------------------------------------------------------- 98 | stochastic_next = { 99 | name: _create_stochastic_next_func( 100 | name, labels=model.grids[name.removeprefix("next_")] 101 | ) 102 | for name in stochastic_targets 103 | } 104 | 105 | stochastic_weights = { 106 | f"weight_{name}": model.functions[f"weight_{name}"] 107 | for name in stochastic_targets 108 | } 109 | 110 | # Overwrite model.functions with generated stochastic next states functions 111 | # ---------------------------------------------------------------------------------- 112 | return model.functions | stochastic_next | stochastic_weights 113 | 114 | 115 | def _create_stochastic_next_func(name: str, labels: Array) -> StochasticNextFunction: 116 | """Get function that simulates the next state of a stochastic variable. 117 | 118 | Args: 119 | name: Name of the stochastic variable. 120 | labels: 1d array of labels. 121 | 122 | Returns: 123 | A function that simulates the next state of the stochastic variable. The 124 | function must be called with keyword arguments: 125 | - weight_{name}: 2d array of weights. The first dimension corresponds to the 126 | number of simulation units. The second dimension corresponds to the number of 127 | grid points (labels). 128 | - keys: Dictionary with random key arrays. Dictionary keys correspond to the 129 | names of stochastic next functions, e.g. 'next_health'. 
130 | 131 | """ 132 | 133 | @with_signature(args=[f"weight_{name}", "keys"]) 134 | def next_stochastic_state(keys: dict[str, Array], **kwargs: Array) -> Array: 135 | return random_choice( 136 | labels=labels, 137 | probs=kwargs[f"weight_{name}"], 138 | key=keys[name], 139 | ) 140 | 141 | return next_stochastic_state 142 | -------------------------------------------------------------------------------- /src/lcm/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSourceEconomics/pylcm/5692daa838cc944ef5f8d438aba5507b736eaf6c/src/lcm/py.typed -------------------------------------------------------------------------------- /src/lcm/random.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | import jax 5 | from jax import Array 6 | 7 | 8 | def random_choice( 9 | labels: jax.Array, 10 | probs: jax.Array, 11 | key: jax.Array, 12 | ) -> jax.Array: 13 | """Draw multiple random choices. 14 | 15 | Args: 16 | labels: 1d array of labels. 17 | probs: 2d array of probabilities. Second dimension must be the same length as 18 | the dimension of labels. 19 | key: Random key. 20 | 21 | Returns: 22 | Selected labels. 1d array of length len(probs). 23 | 24 | """ 25 | keys = jax.random.split(key, probs.shape[0]) 26 | return _vmapped_choice(keys, probs, labels) 27 | 28 | 29 | @partial(jax.vmap, in_axes=(0, 0, None)) 30 | def _vmapped_choice(key: jax.Array, probs: jax.Array, labels: jax.Array) -> jax.Array: 31 | return jax.random.choice(key, a=labels, p=probs) 32 | 33 | 34 | def generate_simulation_keys( 35 | key: Array, ids: list[str] 36 | ) -> tuple[Array, dict[str, Array]]: 37 | """Generate pseudo-random number generator keys (PRNG keys) for simulation. 38 | 39 | PRNG keys in JAX are immutable objects used to control random number generation. 
40 | A key can be used to generate a stream of random numbers, e.g., given a key, one can 41 | call jax.random.normal(key) to generate a stream of normal random numbers. In order 42 | to ensure that each simulation is based on a different stream of random numbers, we 43 | split the key into one key per stochastic variable, and one key that will be passed 44 | to the next iteration in order to generate new keys. 45 | 46 | See the JAX documentation for more details: 47 | https://docs.jax.dev/en/latest/random-numbers.html#random-numbers-in-jax 48 | 49 | Args: 50 | key: Random key. 51 | ids: List of names for which a key is to be generated. 52 | 53 | Returns: 54 | - Updated random key. 55 | - Dict with random keys for each id in ids. 56 | 57 | """ 58 | keys = jax.random.split(key, num=len(ids) + 1) 59 | 60 | next_key = keys[0] 61 | simulation_keys = dict(zip(ids, keys[1:], strict=True)) 62 | 63 | return next_key, simulation_keys 64 | 65 | 66 | def draw_random_seed() -> int: 67 | """Generate a random seed using the operating system's secure entropy pool. 68 | 69 | Returns: 70 | Random seed. 
71 | 72 | """ 73 | return int.from_bytes(os.urandom(4), "little") 74 | -------------------------------------------------------------------------------- /src/lcm/simulation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSourceEconomics/pylcm/5692daa838cc944ef5f8d438aba5507b736eaf6c/src/lcm/simulation/__init__.py -------------------------------------------------------------------------------- /src/lcm/simulation/processing.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | import jax.numpy as jnp 4 | import pandas as pd 5 | from dags import concatenate_functions 6 | from jax import Array 7 | 8 | from lcm.dispatchers import vmap_1d 9 | from lcm.interfaces import InternalModel, InternalSimulationPeriodResults 10 | from lcm.typing import InternalUserFunction, ParamsDict 11 | 12 | 13 | def process_simulated_data( 14 | results: dict[int, InternalSimulationPeriodResults], 15 | model: InternalModel, 16 | params: ParamsDict, 17 | additional_targets: list[str] | None = None, 18 | ) -> dict[str, Array]: 19 | """Process and flatten the simulation results. 20 | 21 | This function produces a dict of arrays for each var with dimension (n_periods * 22 | n_initial_states,). The arrays are flattened, so that the resulting dictionary has a 23 | one-dimensional array for each variable. The length of this array is the number of 24 | periods times the number of initial states. The order of array elements is given by 25 | an outer level of periods and an inner level of initial states ids. 26 | 27 | Args: 28 | results: Dict with simulation results. Each dict contains the value, 29 | actions, and states for one period. Actions and states are stored in a 30 | nested dictionary. 31 | model: Model. 32 | params: Parameters. 33 | additional_targets: List of additional targets to compute. 34 | 35 | Returns: 36 | Dict with processed simulation results. 
The keys are the variable names and the 37 | values are the flattened arrays, with dimension (n_periods * n_initial_states,). 38 | Additionally, the _period variable is added. 39 | 40 | """ 41 | n_periods = len(results) 42 | n_initial_states = len(results[0].value) 43 | 44 | list_of_dicts = [ 45 | {"value": d.value, **d.actions, **d.states} for d in results.values() 46 | ] 47 | dict_of_lists = { 48 | key: [d[key] for d in list_of_dicts] for key in list(list_of_dicts[0]) 49 | } 50 | out = {key: jnp.concatenate(values) for key, values in dict_of_lists.items()} 51 | out["_period"] = jnp.repeat(jnp.arange(n_periods), n_initial_states) 52 | 53 | if additional_targets is not None: 54 | calculated_targets = _compute_targets( 55 | out, 56 | targets=additional_targets, 57 | model_functions=model.functions, 58 | params=params, 59 | ) 60 | out = {**out, **calculated_targets} 61 | 62 | return out 63 | 64 | 65 | def as_panel(processed: dict[str, Array], n_periods: int) -> pd.DataFrame: 66 | """Convert processed simulation results to panel. 67 | 68 | Args: 69 | processed: Dict with processed simulation results. 70 | n_periods: Number of periods. 71 | 72 | Returns: 73 | Panel with the simulation results. The index is a multi-index with the first 74 | level corresponding to the period and the second level corresponding to the 75 | initial state id. The columns correspond to the value, and the action and state 76 | variables, and potentially auxiliary variables. 77 | 78 | """ 79 | n_initial_states = len(processed["value"]) // n_periods 80 | index = pd.MultiIndex.from_product( 81 | [range(n_periods), range(n_initial_states)], 82 | names=["period", "initial_state_id"], 83 | ) 84 | return pd.DataFrame(processed, index=index) 85 | 86 | 87 | def _compute_targets( 88 | processed_results: dict[str, Array], 89 | targets: list[str], 90 | model_functions: dict[str, InternalUserFunction], 91 | params: ParamsDict, 92 | ) -> dict[str, Array]: 93 | """Compute targets.
94 | 95 | Args: 96 | processed_results: Dict with processed simulation results. Values must be 97 | one-dimensional arrays. 98 | targets: List of targets to compute. 99 | model_functions: Dict with model functions. 100 | params: Dict with model parameters. 101 | 102 | Returns: 103 | Dict with computed targets. 104 | 105 | """ 106 | target_func = concatenate_functions( 107 | functions=model_functions, 108 | targets=targets, 109 | return_type="dict", 110 | ) 111 | 112 | # get list of variables over which we want to vectorize the target function 113 | variables = tuple( 114 | p for p in list(inspect.signature(target_func).parameters) if p != "params" 115 | ) 116 | 117 | target_func = vmap_1d(target_func, variables=variables) 118 | 119 | kwargs = {k: v for k, v in processed_results.items() if k in variables} 120 | return target_func(params=params, **kwargs) 121 | -------------------------------------------------------------------------------- /src/lcm/solution/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSourceEconomics/pylcm/5692daa838cc944ef5f8d438aba5507b736eaf6c/src/lcm/solution/__init__.py -------------------------------------------------------------------------------- /src/lcm/solution/solve_brute.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import jax.numpy as jnp 4 | from jax import Array 5 | 6 | from lcm.interfaces import StateActionSpace 7 | from lcm.typing import MaxQcOverDFunction, MaxQOverCFunction, ParamsDict 8 | 9 | 10 | def solve( 11 | params: ParamsDict, 12 | state_action_spaces: dict[int, StateActionSpace], 13 | max_Q_over_c_functions: dict[int, MaxQOverCFunction], 14 | max_Qc_over_d_functions: dict[int, MaxQcOverDFunction], 15 | logger: logging.Logger, 16 | ) -> dict[int, Array]: 17 | """Solve a model using grid search. 18 | 19 | Args: 20 | params: Dict of model parameters. 
21 | state_action_spaces: Dict with one state_action_space per period. 22 | max_Q_over_c_functions: Dict with one function per period. The functions 23 | calculate the maximum of the Q-function over the continuous actions. The 24 | result corresponds to the Qc-function of that period. 25 | max_Qc_over_d_functions: Dict with one function per period. The functions 26 | calculate the (expected) maximum of the Qc-function over the discrete 27 | actions. The result corresponds to the value function array of that period. 28 | logger: Logger that logs to stdout. 29 | 30 | Returns: 31 | Dict with one value function array per period. 32 | 33 | """ 34 | n_periods = len(state_action_spaces) 35 | solution = {} 36 | next_V_arr = jnp.empty(0) 37 | 38 | logger.info("Starting solution") 39 | 40 | # backwards induction loop 41 | for period in reversed(range(n_periods)): 42 | state_action_space = state_action_spaces[period] 43 | 44 | max_Qc_over_d = max_Qc_over_d_functions[period] 45 | max_Q_over_c = max_Q_over_c_functions[period] 46 | 47 | # evaluate Q-function on states and actions, and maximize over continuous 48 | # actions 49 | Qc_arr = max_Q_over_c( 50 | **state_action_space.states, 51 | **state_action_space.discrete_actions, 52 | **state_action_space.continuous_actions, 53 | next_V_arr=next_V_arr, 54 | params=params, 55 | ) 56 | 57 | # maximize Qc-function evaluations over discrete actions 58 | V_arr = max_Qc_over_d(Qc_arr, params=params) 59 | 60 | solution[period] = V_arr 61 | next_V_arr = V_arr 62 | logger.info("Period: %s", period) 63 | 64 | return solution 65 | -------------------------------------------------------------------------------- /src/lcm/state_action_space.py: -------------------------------------------------------------------------------- 1 | """Create state-action spaces and state-space information for a given model.""" 2 | 3 | import pandas as pd 4 | from jax import Array 5 | 6 | from lcm.grids import ContinuousGrid, DiscreteGrid 7 | from lcm.interfaces import InternalModel,
StateActionSpace, StateSpaceInfo 8 | 9 | 10 | def create_state_action_space( 11 | model: InternalModel, 12 | *, 13 | initial_states: dict[str, Array] | None = None, 14 | is_last_period: bool = False, 15 | ) -> StateActionSpace: 16 | """Create a state-action-space. 17 | 18 | Creates the state-action-space for the solution and simulation of a model. In the 19 | simulation, initial states must be provided. 20 | 21 | Args: 22 | model: A processed model. 23 | initial_states: A dictionary with the initial values of the state variables. 24 | If None, the full grids of the state variables are used. 25 | is_last_period: Whether the state-action-space is created for the last period, 26 | in which case auxiliary variables are not included. 27 | 28 | Returns: 29 | A state-action-space. Contains the grids of the discrete and continuous actions, 30 | the grids of the state variables, or the initial values of the state variables, 31 | and the names of the state and action variables in the order they appear in the 32 | variable info table.
33 | 34 | """ 35 | vi = model.variable_info 36 | if is_last_period: 37 | vi = vi.query("~is_auxiliary") 38 | 39 | if initial_states is None: 40 | states = {sn: model.grids[sn] for sn in vi.query("is_state").index} 41 | else: 42 | _validate_initial_states_names(initial_states, variable_info=vi) 43 | states = initial_states 44 | 45 | discrete_actions = { 46 | name: model.grids[name] for name in vi.query("is_action & is_discrete").index 47 | } 48 | continuous_actions = { 49 | name: model.grids[name] for name in vi.query("is_action & is_continuous").index 50 | } 51 | ordered_var_names = tuple(vi.query("is_state | is_discrete").index) 52 | 53 | return StateActionSpace( 54 | states=states, 55 | discrete_actions=discrete_actions, 56 | continuous_actions=continuous_actions, 57 | states_and_discrete_actions_names=ordered_var_names, 58 | ) 59 | 60 | 61 | def create_state_space_info( 62 | model: InternalModel, 63 | *, 64 | is_last_period: bool, 65 | ) -> StateSpaceInfo: 66 | """Collect information on the state space for the model solution. 67 | 68 | A state-space information is a compressed representation of all feasible states. 69 | 70 | Args: 71 | model: A processed model. 72 | is_last_period: Whether the function is created for the last period. 73 | 74 | Returns: 75 | The state-space information. 
76 | 77 | """ 78 | vi = model.variable_info 79 | if is_last_period: 80 | vi = vi.query("~is_auxiliary") 81 | 82 | state_names = vi.query("is_state").index.tolist() 83 | 84 | discrete_states = { 85 | name: grid_spec 86 | for name, grid_spec in model.gridspecs.items() 87 | if name in state_names and isinstance(grid_spec, DiscreteGrid) 88 | } 89 | 90 | continuous_states = { 91 | name: grid_spec 92 | for name, grid_spec in model.gridspecs.items() 93 | if name in state_names and isinstance(grid_spec, ContinuousGrid) 94 | } 95 | 96 | return StateSpaceInfo( 97 | states_names=tuple(state_names), 98 | discrete_states=discrete_states, 99 | continuous_states=continuous_states, 100 | ) 101 | 102 | 103 | def _validate_initial_states_names( 104 | initial_states: dict[str, Array], variable_info: pd.DataFrame 105 | ) -> None: 106 | """Checks if each model-state has an initial value.""" 107 | states_names_in_model = set(variable_info.query("is_state").index) 108 | provided_states_names = set(initial_states) 109 | 110 | if states_names_in_model != provided_states_names: 111 | missing = states_names_in_model - provided_states_names 112 | too_many = provided_states_names - states_names_in_model 113 | raise ValueError( 114 | "You need to provide an initial array for each state variable in the model." 115 | f"\n\nMissing initial states: {missing}\n", 116 | f"Provided variables that are not states: {too_many}", 117 | ) 118 | -------------------------------------------------------------------------------- /src/lcm/typing.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Any, Protocol 3 | 4 | from jax import Array 5 | 6 | # Many JAX functions are designed to work with scalar numerical values. This also 7 | # includes zero dimensional jax arrays. 8 | Scalar = int | float | Array 9 | 10 | 11 | ParamsDict = dict[str, Any] 12 | 13 | 14 | class UserFunction(Protocol): 15 | """A function provided by the user. 
16 | 17 | Only used for type checking. 18 | 19 | """ 20 | 21 | def __call__(self, *args: Any, **kwargs: Any) -> Any: ... # noqa: ANN401, D102 22 | 23 | 24 | class InternalUserFunction(Protocol): 25 | """The internal representation of a function provided by the user. 26 | 27 | Only used for type checking. 28 | 29 | """ 30 | 31 | def __call__( # noqa: D102 32 | self, *args: Scalar, params: ParamsDict, **kwargs: Scalar 33 | ) -> Scalar: ... 34 | 35 | 36 | class MaxQOverCFunction(Protocol): 37 | """The function that maximizes Q over the continuous actions. 38 | 39 | Q is the state-action value function. The MaxQOverCFunction returns the maximum of Q 40 | over the continuous actions. 41 | 42 | Only used for type checking. 43 | 44 | """ 45 | 46 | def __call__( # noqa: D102 47 | self, next_V_arr: Array, params: ParamsDict, **kwargs: Scalar 48 | ) -> Array: ... 49 | 50 | 51 | class ArgmaxQOverCFunction(Protocol): 52 | """The function that finds the argmax of Q over the continuous actions. 53 | 54 | Q is the state-action value function. The ArgmaxQOverCFunction returns the argmax 55 | and the maximum of Q over the continuous actions. 56 | 57 | Only used for type checking. 58 | 59 | """ 60 | 61 | def __call__( # noqa: D102 62 | self, next_V_arr: Array, params: ParamsDict, **kwargs: Scalar 63 | ) -> tuple[Array, Array]: ... 64 | 65 | 66 | class MaxQcOverDFunction(Protocol): 67 | """The function that maximizes Qc over the discrete actions. 68 | 69 | Qc is the maximum of the state-action value function (Q) over the continuous 70 | actions, conditional on the discrete action. It depends on a state and the discrete 71 | actions. The MaxQcFunction returns the maximum of Qc over the discrete actions. 72 | 73 | Only used for type checking. 74 | 75 | """ 76 | 77 | def __call__(self, Qc_arr: Array, params: ParamsDict) -> Array: ... # noqa: D102 78 | 79 | 80 | class ArgmaxQcOverDFunction(Protocol): 81 | """The function that finds the argmax of Qc over the discrete actions. 
82 | 83 | Qc is the maximum of the state-action value function (Q) over the continuous 84 | actions, conditional on the discrete action. It depends on a state and the discrete 85 | actions. The ArgmaxQcFunction returns the argmax of Qc over the discrete actions. 86 | 87 | Only used for type checking. 88 | 89 | """ 90 | 91 | def __call__(self, Qc_arr: Array, params: ParamsDict) -> tuple[Array, Array]: ... # noqa: D102 92 | 93 | 94 | class StochasticNextFunction(Protocol): 95 | """The function that simulates the next state of a stochastic variable. 96 | 97 | Only used for type checking. 98 | 99 | """ 100 | 101 | def __call__(self, keys: dict[str, Array], **kwargs: Array) -> Array: ... # noqa: D102 102 | 103 | 104 | class ShockType(Enum): 105 | """Type of shocks.""" 106 | 107 | EXTREME_VALUE = "extreme_value" 108 | NONE = None 109 | 110 | 111 | class Target(Enum): 112 | """Target of the function.""" 113 | 114 | SOLVE = "solve" 115 | SIMULATE = "simulate" 116 | -------------------------------------------------------------------------------- /src/lcm/user_model.py: -------------------------------------------------------------------------------- 1 | """Collection of classes that are used by the user to define the model and grids.""" 2 | 3 | import dataclasses as dc 4 | from dataclasses import KW_ONLY, dataclass, field 5 | from typing import Any 6 | 7 | from lcm.exceptions import ModelInitilizationError, format_messages 8 | from lcm.grids import Grid 9 | from lcm.typing import UserFunction 10 | 11 | 12 | @dataclass(frozen=True) 13 | class Model: 14 | """A user model which can be processed into an internal model. 15 | 16 | Attributes: 17 | description: Description of the model. 18 | n_periods: Number of periods in the model. 19 | functions: Dictionary of user provided functions that define the functional 20 | relationships between model variables. It must include at least a function 21 | called 'utility'. 22 | actions: Dictionary of user provided actions. 
23 | states: Dictionary of user provided states. 24 | 25 | """ 26 | 27 | description: str | None = None 28 | _: KW_ONLY 29 | n_periods: int 30 | functions: dict[str, UserFunction] = field(default_factory=dict) 31 | actions: dict[str, Grid] = field(default_factory=dict) 32 | states: dict[str, Grid] = field(default_factory=dict) 33 | 34 | def __post_init__(self) -> None: 35 | _validate_attribute_types(self) 36 | _validate_logical_consistency(self) 37 | 38 | def replace(self, **kwargs: Any) -> "Model": # noqa: ANN401 39 | """Replace the attributes of the model. 40 | 41 | Args: 42 | **kwargs: Keyword arguments to replace the attributes of the model. 43 | 44 | Returns: 45 | A new model with the replaced attributes. 46 | 47 | """ 48 | try: 49 | return dc.replace(self, **kwargs) 50 | except TypeError as e: 51 | raise ModelInitilizationError( 52 | f"Failed to replace attributes of the model. The error was: {e}" 53 | ) from e 54 | 55 | 56 | def _validate_attribute_types(model: Model) -> None: # noqa: C901 57 | """Validate the types of the model attributes.""" 58 | error_messages = [] 59 | 60 | # Validate types of states and actions 61 | # ---------------------------------------------------------------------------------- 62 | for attr_name in ("actions", "states"): 63 | attr = getattr(model, attr_name) 64 | if isinstance(attr, dict): 65 | for k, v in attr.items(): 66 | if not isinstance(k, str): 67 | error_messages.append(f"{attr_name} key {k} must be a string.") 68 | if not isinstance(v, Grid): 69 | error_messages.append(f"{attr_name} value {v} must be an LCM grid.") 70 | else: 71 | error_messages.append(f"{attr_name} must be a dictionary.") 72 | 73 | # Validate types of functions 74 | # ---------------------------------------------------------------------------------- 75 | if isinstance(model.functions, dict): 76 | for k, v in model.functions.items(): 77 | if not isinstance(k, str): 78 | error_messages.append(f"function keys must be a strings, but is {k}.") 79 | if not
callable(v): 80 | error_messages.append( 81 | f"function values must be a callable, but is {v}." 82 | ) 83 | else: 84 | error_messages.append("functions must be a dictionary.") 85 | 86 | if error_messages: 87 | msg = format_messages(error_messages) 88 | raise ModelInitilizationError(msg) 89 | 90 | 91 | def _validate_logical_consistency(model: Model) -> None: 92 | """Validate the logical consistency of the model.""" 93 | error_messages = [] 94 | # NOTE(review): n_periods is not type-checked; e.g. None raises TypeError here. 95 | if model.n_periods < 1: 96 | error_messages.append("Number of periods must be a positive integer.") 97 | 98 | if "utility" not in model.functions: 99 | error_messages.append( 100 | "Utility function is not defined. LCM expects a function called 'utility' " 101 | "in the functions dictionary.", 102 | ) 103 | 104 | states_without_next_func = [ 105 | state for state in model.states if f"next_{state}" not in model.functions 106 | ] 107 | if states_without_next_func: 108 | error_messages.append( 109 | "Each state must have a corresponding next state function. For the " 110 | "following states, no next state function was found: " 111 | f"{states_without_next_func}.", 112 | ) 113 | 114 | states_and_actions_overlap = set(model.states) & set(model.actions) 115 | if states_and_actions_overlap: 116 | error_messages.append( 117 | "States and actions cannot have overlapping names.
The following names " 118 | f"are used in both states and actions: {states_and_actions_overlap}.", 119 | ) 120 | 121 | if error_messages: 122 | msg = format_messages(error_messages) 123 | raise ModelInitilizationError(msg) 124 | -------------------------------------------------------------------------------- /src/lcm/utils.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from collections.abc import Iterable 3 | from itertools import chain 4 | from typing import TypeVar 5 | 6 | T = TypeVar("T") 7 | 8 | 9 | def find_duplicates(*containers: Iterable[T]) -> set[T]: 10 | combined = chain.from_iterable(containers) 11 | counts = Counter(combined) # items counted more than once are duplicates 12 | return {v for v, count in counts.items() if count > 1} 13 | 14 | 15 | def first_non_none(*args: T | None) -> T: 16 | """Return the first non-None argument. 17 | 18 | Args: 19 | *args: Arguments to check. 20 | 21 | Returns: 22 | The first non-None argument. 23 | 24 | Raises: 25 | ValueError: If all arguments are None.
26 | 27 | """ 28 | for arg in args: 29 | if arg is not None: 30 | return arg 31 | raise ValueError("All arguments are None") 32 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSourceEconomics/pylcm/5692daa838cc944ef5f8d438aba5507b736eaf6c/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from dataclasses import make_dataclass 2 | 3 | import pytest 4 | from jax import config 5 | 6 | 7 | def pytest_sessionstart(session): # noqa: ARG001 8 | config.update("jax_enable_x64", val=True) # enable 64-bit JAX dtypes for the test session 9 | 10 | 11 | @pytest.fixture(scope="session") 12 | def binary_category_class(): 13 | return make_dataclass("BinaryCategoryClass", [("cat0", int, 0), ("cat1", int, 1)]) 14 | -------------------------------------------------------------------------------- /tests/data/analytical_solution/iskhakov_2017_five_periods__consumption.csv: -------------------------------------------------------------------------------- 1 | 1.000000000000000000e+00,6.210526315789473450e+00,1.142105263157894690e+01,1.663157894736842124e+01,2.103186820267919188e+01,2.228424633291450263e+01,2.336887779159618361e+01,2.029027738867141650e+01,2.137490884735309393e+01,2.245954030603477491e+01,1.938093990311000780e+01,2.046557136179168523e+01,2.155020282047336622e+01,1.847160241754859911e+01,1.955623387623028009e+01,2.064086533491195752e+01,2.172549679359363850e+01,1.864689639066887139e+01,1.973152784935054882e+01,2.081615930803222980e+01 2 |
2.000000000000000000e+01,2.000000000000000000e+01,2.000000000000000000e+01,2.000000000000000000e+01,2.061123083862560534e+01,2.183856140625621123e+01,2.290150023576426008e+01,1.988447184089798725e+01,2.094741067040603255e+01,2.201034949991408141e+01,1.899332110504780502e+01,2.005625993455585032e+01,2.111919876406390273e+01,1.810217036919762634e+01,1.916510919870567164e+01,2.022804802821371695e+01,2.129098685772176580e+01,1.827395846285549297e+01,1.933689729236353827e+01,2.039983612187158712e+01 3 | 2.000000000000000000e+01,2.000000000000000000e+01,2.000000000000000000e+01,2.000000000000000000e+01,2.019900622185309658e+01,2.140179017813109041e+01,2.244347023104897332e+01,1.948678240408002793e+01,2.052846245699791439e+01,2.157014250991579729e+01,1.861345468294685190e+01,1.965513473586473481e+01,2.069681478878262126e+01,1.774012696181367588e+01,1.878180701473155878e+01,1.982348706764943813e+01,2.086516712056733169e+01,1.790847929359838631e+01,1.895015934651626566e+01,1.999183939943415567e+01 4 | 2.000000000000000000e+01,2.000000000000000000e+01,2.000000000000000000e+01,2.000000000000000000e+01,2.000000000000000000e+01,2.097375437456846825e+01,2.199460082642799108e+01,1.909704675599842716e+01,2.011789320785795709e+01,2.113873965971747992e+01,1.824118558928791245e+01,1.926203204114743883e+01,2.028287849300696877e+01,1.738532442257740485e+01,1.840617087443692768e+01,1.942701732629645051e+01,2.044786377815598044e+01,1.755030970772641652e+01,1.857115615958593935e+01,1.959200261144546928e+01 5 | 
2.000000000000000000e+01,2.000000000000000000e+01,2.000000000000000000e+01,2.000000000000000000e+01,2.000000000000000000e+01,2.055427928707709384e+01,2.155470880989943439e+01,1.871510582087845620e+01,1.971553534370079674e+01,2.071596486652313018e+01,1.787636187750215200e+01,1.887679140032448899e+01,1.987722092314682953e+01,1.703761793412585490e+01,1.803804745694818834e+01,1.903847697977052178e+01,2.003890650259286232e+01,1.719930351357188769e+01,1.819973303639422113e+01,1.920016255921656168e+01 6 | -------------------------------------------------------------------------------- /tests/data/analytical_solution/iskhakov_2017_five_periods__work_decision.csv: -------------------------------------------------------------------------------- 1 | 1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00 2 | 1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00 3 | 
1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00 4 | 1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00 5 | 0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00 6 | -------------------------------------------------------------------------------- /tests/data/analytical_solution/iskhakov_2017_low_delta__consumption.csv: -------------------------------------------------------------------------------- 1 | 
1.000000000000000000e+00,6.210526315789473450e+00,1.142105263157894690e+01,1.663157894736842124e+01,2.103186820267919188e+01,2.280391497039429183e+01,2.457596173810939177e+01,2.634800850582448462e+01,2.812005527353958101e+01,2.989210204125468096e+01,3.166414880896978090e+01,3.343619557668487374e+01,3.520824234439997724e+01,3.698028911211507364e+01,3.875233587983017003e+01,4.052438264754526642e+01,4.229642941526036282e+01,4.406847618297545921e+01,4.584052295069056271e+01,4.761256971840565910e+01 2 | 2.000000000000000000e+01,2.000000000000000000e+01,2.000000000000000000e+01,2.000000000000000000e+01,2.061123083862560534e+01,2.234783667098639981e+01,2.408444250334719783e+01,2.582104833570799585e+01,2.755765416806879031e+01,2.929426000042958478e+01,3.103086583279038280e+01,3.276747166515117726e+01,3.450407749751197173e+01,3.624068332987276619e+01,3.797728916223356777e+01,3.971389499459435513e+01,4.145050082695515670e+01,4.318710665931595827e+01,4.492371249167674563e+01,4.666031832403754009e+01 3 | 2.000000000000000000e+01,2.000000000000000000e+01,2.000000000000000000e+01,2.000000000000000000e+01,2.019900622185309658e+01,2.190087993756667473e+01,2.360275365328025288e+01,2.530462736899383458e+01,2.700650108470741273e+01,2.870837480042099443e+01,3.041024851613457258e+01,3.211212223184815429e+01,3.381399594756173599e+01,3.551586966327531059e+01,3.721774337898889229e+01,3.891961709470247399e+01,4.062149081041605569e+01,4.232336452612963740e+01,4.402523824184320489e+01,4.572711195755679370e+01 4 | -------------------------------------------------------------------------------- /tests/data/analytical_solution/iskhakov_2017_low_delta__values_retired.csv: -------------------------------------------------------------------------------- 1 | 
-3.229959573599586520e+00,1.520198485547954448e+00,3.250493875862349391e+00,4.332508192337356512e+00,5.121711972527370094e+00,5.743293042426854456e+00,6.256130241457334762e+00,6.692666435337223163e+00,7.072693039657663050e+00,7.409177678641123421e+00,7.711080259224859290e+00,7.984851278818989151e+00,8.235289410597829374e+00,8.466061405685582741e+00,8.680032287809968494e+00,8.879483287060818242e+00,9.066260411337221825e+00,9.241878568790038884e+00,9.407596301982874110e+00,9.564470557851770138e+00,9.713397569967600731e+00,9.855143875628421313e+00,9.990370193381949093e+00,1.011965004656775946e+01,1.024348446189189232e+01,1.036231369565363991e+01,1.047652668095602380e+01,1.058646870759007896e+01,1.069244771707007402e+01,1.079473950207974653e+01,1.089359203146553057e+01,1.098922907152946848e+01,1.108185323669781042e+01,1.117164857417607138e+01,1.125878276548878532e+01,1.134340901109201383e+01,1.142566765127866546e+01,1.150568756645335000e+01,1.158358739186123287e+01,1.165947657551418004e+01,1.173345630299307629e+01,1.180562030873576873e+01,1.187605559013068302e+01,1.194484303806273218e+01,1.201205799537372343e+01,1.207777075290582403e+01,1.214204699131666132e+01,1.220494817562812884e+01,1.226653190844999131e+01,1.232685224696595938e+01,1.238595998805392639e+01,1.244390292530868436e+01,1.250072608122582452e+01,1.255647191737267931e+01,1.261118052500429343e+01,1.266488979826785766e+01,1.271763559186999260e+01,1.276945186484993400e+01,1.282037081190241778e+01,1.287042298352217529e+01,1.291963739609261452e+01,1.296804163291223233e+01,1.301566193703931873e+01,1.306252329673752044e+01,1.310864952421868779e+01,1.315406332830430713e+01,1.319878638156050599e+01,1.324283938240348135e+01,1.328624211262086874e+01,1.332901349070916908e+01,1.337117162138718385e+01,1.341273384160979276e+01,1.345371676337485312e+01,1.349413631358759602e+01,1.353400777122208964e+01,1.357334580199666973e+01,1.361216449076028923e+01,1.365047737176879039e+01,1.368829745701393819e+01,1.372563726275355123e+01
,1.376250883437819894e+01,1.379892376973792878e+01,1.383489324104220586e+01,1.387042801543642057e+01,1.390553847434985002e+01,1.394023463170201893e+01,1.397452615104741369e+01,1.400842236173187771e+01,1.404193227412837786e+01,1.407506459401433574e+01,1.410782773614790031e+01,1.414022983709618941e+01,1.417227876736436798e+01,1.420398214287082084e+01,1.423534733581024980e+01,1.426638148494351910e+01,1.429709150535007822e+01,1.432748409767640041e+01,1.435756575691126002e+01,1.438734278071672179e+01 2 | -1.372330405689927924e+00,1.826320443552396755e+00,2.991462897920000863e+00,3.720067287329779759e+00,4.251499570276393136e+00,4.670058446075938896e+00,5.015391616420228793e+00,5.309345385935320571e+00,5.565246853951397732e+00,5.791828130372038252e+00,5.995122617365574946e+00,6.179473936504597198e+00,6.348113407638490635e+00,6.503510139468768436e+00,6.647593375289162942e+00,6.781899244734366228e+00,6.907670808490075309e+00,7.025928172031286500e+00,7.137518809945112785e+00,7.243154446804184410e+00,7.343438586373406629e+00,7.438887397830495907e+00,7.529945794358854805e+00,7.616999973826920289e+00,7.700387316481587696e+00,7.780404281128663335e+00,7.857312766674412252e+00,7.931345283588922435e+00,8.002709192842900165e+00,8.071590207099175629e+00,8.138155303067021507e+00,8.202555160000294165e+00,8.264926213949864930e+00,8.325392398212876799e+00,8.384066625800675965e+00,8.441052058495001731e+00,8.496443198329572510e+00,8.550326830504065612e+00,8.602782841355516297e+00,8.653884930742277604e+00,8.703701234785336638e+00,8.752294872171587414e+00,8.799724425008593087e+00,8.846044363420189072e+00,8.891305421601309433e+00,8.935554931842622395e+00,8.978837122039010410e+00,9.021193381369954523e+00,9.062662498152462831e+00,9.103280873292449726e+00,9.143082712278360447e+00,9.182100198254563139e+00,9.220363648368810061e+00,9.257901655296702614e+00,9.294741215598223860e+00,9.330907846349774104e+00,9.366425691313802204e+00,9.401317617752477318e+00,9.435605304857608289e+00,9.46930932465325270
2e+00,9.502449216127020648e+00,9.535043553258990556e+00,9.567110007541277383e+00,9.598665405515170335e+00,9.629725781794830297e+00,9.660306427995866585e+00,9.690421937942566899e+00,9.720086249488304020e+00,9.749312683249163669e+00,9.778113978520158511e+00,9.806502326616493903e+00,9.834489401858219182e+00,9.862086390395431224e+00,9.889304017052086948e+00,9.916152570349700213e+00,9.942641925857000018e+00,9.968781567998194149e+00,9.994580610440323198e+00,1.002004781516940390e+01,1.004519161035524100e+01,1.007002010709609330e+01,1.009454111512638264e+01,1.011876215756362640e+01,1.014269048476419322e+01,1.016633308735179497e+01,1.018969670847725340e+01,1.021278785536335043e+01,1.023561281018420388e+01,1.025817764032468560e+01,1.028048820806179364e+01,1.030255017970662479e+01,1.032436903424261132e+01,1.034595007149298418e+01,1.036729841984789147e+01,1.038841904357938084e+01,1.040931674977033161e+01,1.042999619488153762e+01,1.045046189097938694e+01,1.047071821164495020e+01,1.049076939758387539e+01 3 | 
0.000000000000000000e+00,1.615480226890062987e+00,2.203936011924206184e+00,2.571918026777630040e+00,2.840318169679960647e+00,3.051711541295892172e+00,3.226122233388967864e+00,3.374583733144065345e+00,3.503826898808750201e+00,3.618261886899983182e+00,3.720935870230052345e+00,3.814042597067942353e+00,3.899214047135564964e+00,3.977697245029644613e+00,4.050466556050046130e+00,4.118297803244593602e+00,4.181818795040405767e+00,4.241544736222835787e+00,4.297903644260121503e+00,4.351254976007127517e+00,4.401903531345118914e+00,4.450110001777992430e+00,4.496099090933729059e+00,4.540065848240832480e+00,4.582180667763392279e+00,4.622593276171006238e+00,4.661435945638555900e+00,4.698826105696389810e+00,4.734868484107488840e+00,4.769656875146011998e+00,4.803275610483307645e+00,4.835800790752637468e+00,4.867301323050400086e+00,4.897839799950911122e+00,4.927473248227578040e+00,4.956253769790368580e+00,4.984229092939141736e+00,5.011443048582824744e+00,5.037935983356284453e+00,5.063745119410204332e+00,5.088904868926901415e+00,5.113447110031068554e+00,5.137401429645717776e+00,5.160795337934402482e+00,5.183654458227897166e+00,5.206002695723510598e+00,5.227862387741888384e+00,5.249254437909031346e+00,5.270198436284035282e+00,5.290712767162816910e+00,5.310814706044589606e+00,5.330520507042671774e+00,5.349845481847847140e+00,5.368804071205368622e+00,5.387409909741490921e+00,5.405675884868536230e+00,5.423614190405924518e+00,5.441236375475962284e+00,5.458553389165421876e+00,5.475575621385444514e+00,5.492312940311590452e+00,5.508774726741878069e+00,5.524969905672326043e+00,5.540906975356110209e+00,5.556594034083211398e+00,5.572038804891815467e+00,5.587248658400249646e+00,5.602230633928400039e+00,5.616991459060146852e+00,5.631537567782872067e+00,5.645875117326475845e+00,5.660010003812195656e+00,5.673947876810787605e+00,5.687694152900007794e+00,5.701254028302842336e+00,5.714632490680267018e+00,5.727834330145515906e+00,5.740864149560732876e+00,5.753726374171379554e+00,5.766425260628873062e+00,
5.778964905447485201e+00,5.791349252937530778e+00,5.803582102653310315e+00,5.815667116390970115e+00,5.827607824768547218e+00,5.839407633417768473e+00,5.851069828814787321e+00,5.862597583774814325e+00,5.873993962633643484e+00,5.885261926137233424e+00,5.896404336058864715e+00,5.907423959561888260e+00,5.918323473324702988e+00,5.929105467443343436e+00,5.939772449125913845e+00,5.950326846192050212e+00,5.960771010389628977e+00,5.971107220540058158e+00,5.981337685522666092e+00,5.991464547107981709e+00 4 | -------------------------------------------------------------------------------- /tests/data/analytical_solution/iskhakov_2017_low_delta__values_worker.csv: -------------------------------------------------------------------------------- 1 | 5.614918903604164235e+00,7.230399130494227222e+00,7.818854915528369531e+00,8.186836930381794275e+00,8.455237073284123994e+00,8.666947148202604367e+00,8.854622167963208668e+00,9.031032764173833272e+00,9.197455904884895617e+00,9.354962128287901635e+00,9.504458569255881173e+00,9.646720610548408459e+00,9.782416223144812406e+00,9.912124932958107593e+00,1.003635277753389943e+01,1.015554422896855336e+01,1.027009179278513074e+01,1.038034380604022644e+01,1.048661082543733336e+01,1.058917090073234135e+01,1.068827395899721644e+01,1.078414547378900501e+01,1.087698955477638307e+01,1.096699156431429856e+01,1.105432034530506336e+01,1.113913012764695765e+01,1.122156216735744927e+01,1.130174616212835659e+01,1.137980147893569338e+01,1.145583822287566811e+01,1.152995817124900491e+01,1.160225559277945528e+01,1.167281796851068876e+01,1.174172662821081481e+01,1.180905731389631796e+01,1.187488068026702592e+01,1.193926274034254043e+01,1.200226526334670751e+01,1.206394613085175571e+01,1.212435965632881540e+01,1.218355687252614317e+01,1.224158579048530093e+01,1.229849163348940877e+01,1.235431704879949422e+01,1.240910229966259948e+01,1.246288543975711072e+01,1.251570247196842089e+01,1.256758749315427792e+01,1.261857282635759248e+01,1.266868914175060112e+01,1.27
1796556744363649e+01,1.276642979116095233e+01,1.281410815367223677e+01,1.286102573476928868e+01,1.290720643249037813e+01,1.295267303621881183e+01,1.299744729421553302e+01,1.304154997608660871e+01,1.308500093063483760e+01,1.312781913949875801e+01,1.317002276694189788e+01,1.321162920611916824e+01,1.325265512211531771e+01,1.329311649202197110e+01,1.333302864229450257e+01,1.337240628360730099e+01,1.341126354340580917e+01,1.344961399633558408e+01,1.348747069271239951e+01,1.352484618518280968e+01,1.356175255371142896e+01,1.359820142901944351e+01,1.363420401458806452e+01,1.366977110733118828e+01,1.370491311703258219e+01,1.373964008463524067e+01,1.377396169946327120e+01,1.380788731545019132e+01,1.384142596644169032e+01,1.387458638063544214e+01,1.390737699421572948e+01,1.393980596423615559e+01,1.397188118079969854e+01,1.400361027858152596e+01,1.403500064773676925e+01,1.406605944423214716e+01,1.409679359963763190e+01,1.412720983041165645e+01,1.415731464671097228e+01,1.418711436075406063e+01,1.421661509476498431e+01,1.424582278852268580e+01,1.427474320653900719e+01,1.430338194488713555e+01,1.433174443770070638e+01,1.435983596336244972e+01,1.438766165040000367e+01,1.441522648310533938e+01,1.444253530689326936e+01,1.446959283341336189e+01 2 | 
2.835817628082911046e+00,4.451297854972974477e+00,5.039753640007116786e+00,5.407735654860541530e+00,5.676135797762871249e+00,5.887845872681351622e+00,6.072845205840623706e+00,6.242026713635254964e+00,6.397883537216893401e+00,6.542362196552742759e+00,6.677011560198576667e+00,6.803084288023638848e+00,6.921607856112522938e+00,7.033435492350255558e+00,7.139283477829295066e+00,7.239758971859660441e+00,7.335381110390393644e+00,7.426597239527264804e+00,7.513795571120716943e+00,7.597315166912579798e+00,7.677453900609573445e+00,7.754474870244774110e+00,7.828611609262651427e+00,7.900072356656149708e+00,7.969043582953901250e+00,8.035692922448703257e+00,8.100171627751457493e+00,8.162616637110367890e+00,8.223152325567564702e+00,8.281891996257535027e+00,8.338939156789853868e+00,8.394388616843983186e+00,8.448327436211176433e+00,8.500835747089057648e+00,8.551987470127489388e+00,8.601850940285672209e+00,8.650489455797751859e+00,8.697961761311660212e+00,8.744322474451768912e+00,8.789622463573934397e+00,8.833909183264786691e+00,8.877226973133494070e+00,8.919617324612378439e+00,8.961119119790739163e+00,9.001768845727561086e+00,9.041600787203568856e+00,9.080647200464252222e+00,9.118938470160010468e+00,9.156503251396493681e+00,9.193368598558850024e+00,9.229560082360631412e+00,9.265101896385839808e+00,9.300016954235957201e+00,9.334326978258875585e+00,9.368052580720233280e+00,9.401213338176653167e+00,9.433827859722917708e+00,9.465913849708726957e+00,9.497488165454329945e+00,9.528566870436019798e+00,9.559165283361664933e+00,9.589298023511524605e+00,9.618979052680391106e+00,9.648221714022156448e+00,9.677038768067454555e+00,9.705442426157588898e+00,9.733444381514122057e+00,9.761055838141906804e+00,9.788287537744379208e+00,9.815149784812945200e+00,9.841652470037100642e+00,9.867805092168371317e+00,9.893616778459065841e+00,9.919096303785817526e+00,9.944252108558254477e+00,9.969092315504212820e+00,9.993624745415019106e+00,1.001785693192723059e+01,1.004179613541071703e+01,1.006544935602715007e+01,
1.008882334601767994e+01,1.011192462127372949e+01,1.013475947224055673e+01,1.015733397419917772e+01,1.017965399696877071e+01,1.020172521406825794e+01,1.022355311137286193e+01,1.024514299529870698e+01,1.026650000054597278e+01,1.028762909742886755e+01,1.030853509881864483e+01,1.032922266672387046e+01,1.034969631853045691e+01,1.036996043292239378e+01,1.039001925550254235e+01,1.040987690413155775e+01,1.042999619488153940e+01,1.045046189097938694e+01,1.047071821164495020e+01,1.049076939758387539e+01 3 | 0.000000000000000000e+00,1.615480226890062987e+00,2.203936011924206184e+00,2.571918026777630040e+00,2.840318169679960647e+00,3.051711541295892172e+00,3.226122233388967864e+00,3.374583733144065345e+00,3.503826898808750201e+00,3.618261886899983182e+00,3.720935870230052345e+00,3.814042597067942353e+00,3.899214047135564964e+00,3.977697245029644613e+00,4.050466556050046130e+00,4.118297803244593602e+00,4.181818795040405767e+00,4.241544736222835787e+00,4.297903644260121503e+00,4.351254976007127517e+00,4.401903531345118914e+00,4.450110001777992430e+00,4.496099090933729059e+00,4.540065848240832480e+00,4.582180667763392279e+00,4.622593276171006238e+00,4.661435945638555900e+00,4.698826105696389810e+00,4.734868484107488840e+00,4.769656875146011998e+00,4.803275610483307645e+00,4.835800790752637468e+00,4.867301323050400086e+00,4.897839799950911122e+00,4.927473248227578040e+00,4.956253769790368580e+00,4.984229092939141736e+00,5.011443048582824744e+00,5.037935983356284453e+00,5.063745119410204332e+00,5.088904868926901415e+00,5.113447110031068554e+00,5.137401429645717776e+00,5.160795337934402482e+00,5.183654458227897166e+00,5.206002695723510598e+00,5.227862387741888384e+00,5.249254437909031346e+00,5.270198436284035282e+00,5.290712767162816910e+00,5.310814706044589606e+00,5.330520507042671774e+00,5.349845481847847140e+00,5.368804071205368622e+00,5.387409909741490921e+00,5.405675884868536230e+00,5.423614190405924518e+00,5.441236375475962284e+00,5.458553389165421876e+00,5.475575621385444514e
+00,5.492312940311590452e+00,5.508774726741878069e+00,5.524969905672326043e+00,5.540906975356110209e+00,5.556594034083211398e+00,5.572038804891815467e+00,5.587248658400249646e+00,5.602230633928400039e+00,5.616991459060146852e+00,5.631537567782872067e+00,5.645875117326475845e+00,5.660010003812195656e+00,5.673947876810787605e+00,5.687694152900007794e+00,5.701254028302842336e+00,5.714632490680267018e+00,5.727834330145515906e+00,5.740864149560732876e+00,5.753726374171379554e+00,5.766425260628873062e+00,5.778964905447485201e+00,5.791349252937530778e+00,5.803582102653310315e+00,5.815667116390970115e+00,5.827607824768547218e+00,5.839407633417768473e+00,5.851069828814787321e+00,5.862597583774814325e+00,5.873993962633643484e+00,5.885261926137233424e+00,5.896404336058864715e+00,5.907423959561888260e+00,5.918323473324702988e+00,5.929105467443343436e+00,5.939772449125913845e+00,5.950326846192050212e+00,5.960771010389628977e+00,5.971107220540058158e+00,5.981337685522666092e+00,5.991464547107981709e+00 4 | -------------------------------------------------------------------------------- /tests/data/analytical_solution/iskhakov_2017_low_delta__work_decision.csv: -------------------------------------------------------------------------------- 1 | 1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00 2 | 
1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00 3 | 0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00 4 | -------------------------------------------------------------------------------- /tests/data/regression_tests/simulation.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSourceEconomics/pylcm/5692daa838cc944ef5f8d438aba5507b736eaf6c/tests/data/regression_tests/simulation.pkl -------------------------------------------------------------------------------- /tests/data/regression_tests/solution.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSourceEconomics/pylcm/5692daa838cc944ef5f8d438aba5507b736eaf6c/tests/data/regression_tests/solution.pkl -------------------------------------------------------------------------------- /tests/input_processing/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OpenSourceEconomics/pylcm/5692daa838cc944ef5f8d438aba5507b736eaf6c/tests/input_processing/__init__.py -------------------------------------------------------------------------------- /tests/input_processing/test_create_params_template.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any 3 | 4 | import jax.numpy as jnp 5 | import pandas as pd 6 | import pytest 7 | 8 | from lcm.grids import DiscreteGrid 9 | from lcm.input_processing.create_params_template import ( 10 | _create_function_params, 11 | _create_stochastic_transition_params, 12 | create_params_template, 13 | ) 14 | 15 | 16 | @dataclass 17 | class ModelMock: 18 | """A model mock for testing the params creation functions. 19 | 20 | This dataclass has the same attributes as the Model dataclass, but does not perform 21 | any checks, which helps us to test the params creation functions in isolation. 22 | 23 | """ 24 | 25 | n_periods: int | None = None 26 | functions: dict[str, Any] | None = None 27 | actions: dict[str, Any] | None = None 28 | states: dict[str, Any] | None = None 29 | 30 | 31 | def test_create_params_without_shocks(binary_category_class): 32 | model = ModelMock( 33 | functions={ 34 | "f": lambda a, b, c: None, # noqa: ARG005 35 | "next_b": lambda b: b, 36 | }, 37 | actions={ 38 | "a": DiscreteGrid(binary_category_class), 39 | }, 40 | states={ 41 | "b": DiscreteGrid(binary_category_class), 42 | }, 43 | n_periods=None, 44 | ) 45 | got = create_params_template(model) # type: ignore[arg-type] 46 | assert got == {"beta": jnp.nan, "f": {"c": jnp.nan}, "next_b": {}} 47 | 48 | 49 | def test_create_function_params(): 50 | model = ModelMock( 51 | functions={ 52 | "f": lambda a, b, c: None, # noqa: ARG005 53 | }, 54 | actions={ 55 | "a": None, 56 | }, 57 | states={ 58 | "b": None, 59 | }, 60 | ) 61 | got = _create_function_params(model) # type: ignore[arg-type] 62 | assert got 
== {"f": {"c": jnp.nan}} 63 | 64 | 65 | def test_create_shock_params(): 66 | def next_a(a, _period): 67 | pass 68 | 69 | variable_info = pd.DataFrame( 70 | {"is_stochastic": True, "is_state": True, "is_discrete": True}, 71 | index=["a"], 72 | ) 73 | 74 | model = ModelMock( 75 | n_periods=3, 76 | functions={"next_a": next_a}, 77 | ) 78 | 79 | got = _create_stochastic_transition_params( 80 | model=model, # type: ignore[arg-type] 81 | variable_info=variable_info, 82 | grids={"a": jnp.array([1, 2])}, 83 | ) 84 | jnp.array_equal(got["a"], jnp.full((2, 3, 2), jnp.nan), equal_nan=True) 85 | 86 | 87 | def test_create_shock_params_invalid_variable(): 88 | def next_a(a): 89 | pass 90 | 91 | variable_info = pd.DataFrame( 92 | {"is_stochastic": True, "is_state": True, "is_discrete": False}, 93 | index=["a"], 94 | ) 95 | 96 | model = ModelMock( 97 | functions={"next_a": next_a}, 98 | ) 99 | 100 | with pytest.raises(ValueError, match="The following variables are stochastic, but"): 101 | _create_stochastic_transition_params( 102 | model=model, # type: ignore[arg-type] 103 | variable_info=variable_info, 104 | grids={"a": jnp.array([1, 2])}, 105 | ) 106 | 107 | 108 | def test_create_shock_params_invalid_dependency(): 109 | def next_a(a, b, _period): 110 | pass 111 | 112 | variable_info = pd.DataFrame( 113 | { 114 | "is_stochastic": [True, False], 115 | "is_state": [True, False], 116 | "is_discrete": [True, False], 117 | }, 118 | index=["a", "b"], 119 | ) 120 | 121 | model = ModelMock( 122 | functions={"next_a": next_a}, 123 | ) 124 | 125 | with pytest.raises(ValueError, match="Stochastic transition functions can only"): 126 | _create_stochastic_transition_params( 127 | model=model, # type: ignore[arg-type] 128 | variable_info=variable_info, 129 | grids={"a": jnp.array([1, 2])}, 130 | ) 131 | -------------------------------------------------------------------------------- /tests/simulation/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OpenSourceEconomics/pylcm/5692daa838cc944ef5f8d438aba5507b736eaf6c/tests/simulation/__init__.py -------------------------------------------------------------------------------- /tests/simulation/test_processing.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import pandas as pd 3 | from pybaum import tree_equal 4 | 5 | from lcm.interfaces import InternalSimulationPeriodResults 6 | from lcm.simulation.processing import ( 7 | _compute_targets, 8 | as_panel, 9 | process_simulated_data, 10 | ) 11 | 12 | 13 | def test_compute_targets(): 14 | processed_results = { 15 | "a": jnp.arange(3), 16 | "b": 1 + jnp.arange(3), 17 | "c": 2 + jnp.arange(3), 18 | } 19 | 20 | def f_a(a, params): 21 | return a + params["disutility_of_work"] 22 | 23 | def f_b(b, params): # noqa: ARG001 24 | return b 25 | 26 | def f_c(params): # noqa: ARG001 27 | return None 28 | 29 | model_functions = {"fa": f_a, "fb": f_b, "fc": f_c} 30 | # Only fa and fb are requested as targets, so expected omits fc. 31 | got = _compute_targets( 32 | processed_results=processed_results, 33 | targets=["fa", "fb"], 34 | model_functions=model_functions, # type: ignore[arg-type] 35 | params={"disutility_of_work": -1.0}, 36 | ) 37 | expected = { 38 | "fa": jnp.arange(3) - 1.0, 39 | "fb": 1 + jnp.arange(3), 40 | } 41 | assert tree_equal(expected, got) 42 | 43 | 44 | def test_as_panel(): 45 | processed = { 46 | "value": -6 + jnp.arange(6), 47 | "a": jnp.arange(6), 48 | "b": 6 + jnp.arange(6), 49 | } 50 | got = as_panel(processed, n_periods=2) 51 | expected = pd.DataFrame( 52 | { 53 | "period": [0, 0, 0, 1, 1, 1], 54 | "initial_state_id": [0, 1, 2, 0, 1, 2], 55 | **processed, 56 | }, 57 | ).set_index(["period", "initial_state_id"]) 58 | pd.testing.assert_frame_equal(got, expected) 59 | 60 | 61 | def test_process_simulated_data(): 62 | simulated = { 63 | 0: InternalSimulationPeriodResults( 64 | value=jnp.array([0.1, 0.2]), 65 | states={"a": jnp.array([1, 2]), "b": jnp.array([-1, -2])}, 66
| actions={"c": jnp.array([5, 6]), "d": jnp.array([-5, -6])}, 67 | ), 68 | 1: InternalSimulationPeriodResults( # keys in a different order than period 0 69 | value=jnp.array([0.3, 0.4]), 70 | states={ 71 | "b": jnp.array([-3, -4]), 72 | "a": jnp.array([3, 4]), 73 | }, 74 | actions={ 75 | "d": jnp.array([-7, -8]), 76 | "c": jnp.array([7, 8]), 77 | }, 78 | ), 79 | } 80 | expected = { 81 | "value": jnp.array([0.1, 0.2, 0.3, 0.4]), 82 | "c": jnp.array([5, 6, 7, 8]), 83 | "d": jnp.array([-5, -6, -7, -8]), 84 | "a": jnp.array([1, 2, 3, 4]), 85 | "b": jnp.array([-1, -2, -3, -4]), 86 | } 87 | 88 | got = process_simulated_data( 89 | simulated, 90 | # The remaining arguments are None, since no additional targets are computed 91 | model=None, # type: ignore[arg-type] 92 | params=None, # type: ignore[arg-type] 93 | additional_targets=None, 94 | ) 95 | assert tree_equal(expected, got) 96 | -------------------------------------------------------------------------------- /tests/solution/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSourceEconomics/pylcm/5692daa838cc944ef5f8d438aba5507b736eaf6c/tests/solution/__init__.py -------------------------------------------------------------------------------- /tests/solution/test_solve_brute.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | from numpy.testing import assert_array_almost_equal as aaae 4 | 5 | from lcm.interfaces import StateActionSpace 6 | from lcm.logging import get_logger 7 | from lcm.max_Q_over_c import get_max_Q_over_c 8 | from lcm.ndimage import map_coordinates 9 | from lcm.solution.solve_brute import solve 10 | 11 | 12 | def test_solve_brute(): 13 | """Test solve brute with hand written inputs. 14 | 15 | Normally, these inputs would be created from a model specification. For now this can 16 | be seen as reference of what the functions that process a model specification need 17 | to produce.
18 | 19 | """ 20 | # ================================================================================== 21 | # create the params 22 | # ================================================================================== 23 | params = {"beta": 0.9} 24 | 25 | # ================================================================================== 26 | # create the list of state_action_spaces 27 | # ================================================================================== 28 | _scs = StateActionSpace( 29 | discrete_actions={ 30 | # pick [0, 1] such that no label translation is needed 31 | # lazy is like a type, it influences utility but is not affected by actions 32 | "lazy": jnp.array([0, 1]), 33 | "working": jnp.array([0, 1]), 34 | }, 35 | continuous_actions={ 36 | "consumption": jnp.array([0, 1, 2, 3]), 37 | }, 38 | states={ 39 | # pick [0, 1, 2] such that no coordinate mapping is needed 40 | "wealth": jnp.array([0.0, 1.0, 2.0]), 41 | }, 42 | states_and_discrete_actions_names=("lazy", "working", "wealth"), 43 | ) 44 | state_action_spaces = {0: _scs, 1: _scs} 45 | 46 | # ================================================================================== 47 | # create the Q_and_F functions 48 | # ================================================================================== 49 | 50 | def _Q_and_F(consumption, lazy, wealth, working, next_V_arr, params): 51 | next_wealth = wealth + working - consumption 52 | next_lazy = lazy 53 | 54 | if next_V_arr.size == 0: 55 | # this is the last period, when next_V_arr = jnp.empty(0) 56 | expected_V = 0 57 | else: 58 | expected_V = map_coordinates( 59 | input=next_V_arr[next_lazy], 60 | coordinates=jnp.array([next_wealth]), 61 | ) 62 | 63 | U_arr = consumption - 0.2 * lazy * working 64 | F_arr = next_wealth >= 0 65 | 66 | Q_arr = U_arr + params["beta"] * expected_V 67 | 68 | return Q_arr, F_arr 69 | 70 | max_Q_over_c = get_max_Q_over_c( 71 | Q_and_F=_Q_and_F, 72 | continuous_actions_names=("consumption",), 73 | 
states_and_discrete_actions_names=("lazy", "working", "wealth"), 74 | ) 75 | 76 | max_Q_over_c_functions = {0: max_Q_over_c, 1: max_Q_over_c} 77 | 78 | # ================================================================================== 79 | # create max_Qc_over_d functions 80 | # ================================================================================== 81 | 82 | def max_Qc_over_d(Qc_arr, params): # noqa: ARG001 83 | """Take max over axis that corresponds to working.""" 84 | return Qc_arr.max(axis=1) 85 | 86 | max_Qc_over_d_functions = {0: max_Qc_over_d, 1: max_Qc_over_d} 87 | 88 | # ================================================================================== 89 | # call solve function 90 | # ================================================================================== 91 | 92 | solution = solve( 93 | params=params, 94 | state_action_spaces=state_action_spaces, 95 | max_Q_over_c_functions=max_Q_over_c_functions, 96 | max_Qc_over_d_functions=max_Qc_over_d_functions, 97 | logger=get_logger(debug_mode=False), 98 | ) 99 | 100 | assert isinstance(solution, dict) 101 | 102 | 103 | def test_solve_brute_single_period_Qc_arr(): 104 | state_action_space = StateActionSpace( 105 | discrete_actions={ 106 | "a": jnp.array([0, 1.0]), 107 | "b": jnp.array([2, 3.0]), 108 | "c": jnp.array([4, 5, 6]), 109 | }, 110 | continuous_actions={ 111 | "d": jnp.arange(12.0), 112 | }, 113 | states={}, 114 | states_and_discrete_actions_names=("a", "b", "c"), 115 | ) 116 | 117 | def _Q_and_F(a, c, b, d, next_V_arr, params): # noqa: ARG001 118 | util = d 119 | feasib = d <= a + b + c 120 | return util, feasib 121 | 122 | max_Q_over_c = get_max_Q_over_c( 123 | Q_and_F=_Q_and_F, 124 | continuous_actions_names=("d",), 125 | states_and_discrete_actions_names=("a", "b", "c"), 126 | ) 127 | 128 | expected = np.array([[[6.0, 7, 8], [7, 8, 9]], [[7, 8, 9], [8, 9, 10]]]) 129 | 130 | # by setting max_Qc_over_d to identity, we can test that the max_Q_over_c function 131 | # is correctly 
applied to the state_action_space 132 | got = solve( 133 | params={}, 134 | state_action_spaces={0: state_action_space}, 135 | max_Q_over_c_functions={0: max_Q_over_c}, 136 | max_Qc_over_d_functions={0: lambda x, params: x}, # noqa: ARG005 137 | logger=get_logger(debug_mode=False), 138 | ) 139 | 140 | aaae(got[0], expected) 141 | -------------------------------------------------------------------------------- /tests/test_Q_and_F.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import pandas as pd 3 | import pytest 4 | from jax import Array 5 | from numpy.testing import assert_array_equal 6 | 7 | from lcm.input_processing import process_model 8 | from lcm.interfaces import InternalModel 9 | from lcm.Q_and_F import ( 10 | _get_feasibility, 11 | _get_joint_weights_function, 12 | get_Q_and_F, 13 | ) 14 | from lcm.state_action_space import create_state_space_info 15 | from lcm.typing import ShockType 16 | from tests.test_models import get_model_config 17 | from tests.test_models.deterministic import utility 18 | 19 | 20 | @pytest.mark.illustrative 21 | def test_get_Q_and_F_function(): 22 | model = process_model( 23 | get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3), 24 | ) 25 | 26 | params = { 27 | "beta": 1.0, 28 | "utility": {"disutility_of_work": 1.0}, 29 | "next_wealth": { 30 | "interest_rate": 0.05, 31 | "wage": 1.0, 32 | }, 33 | } 34 | 35 | state_space_info = create_state_space_info( 36 | model=model, 37 | is_last_period=False, 38 | ) 39 | 40 | Q_and_F = get_Q_and_F( 41 | model=model, 42 | next_state_space_info=state_space_info, 43 | period=model.n_periods - 1, 44 | ) 45 | 46 | consumption = jnp.array([10, 20, 30]) 47 | retirement = jnp.array([0, 1, 0]) 48 | wealth = jnp.array([20, 20, 20]) 49 | 50 | Q_arr, F_arr = Q_and_F( 51 | consumption=consumption, 52 | retirement=retirement, 53 | wealth=wealth, 54 | params=params, 55 | next_V_arr=None, 56 | ) 57 | 58 | assert_array_equal( 
59 | Q_arr, 60 | utility( 61 | consumption=consumption, 62 | working=1 - retirement, 63 | disutility_of_work=1.0, 64 | ), 65 | ) 66 | assert_array_equal(F_arr, jnp.array([True, True, False])) 67 | 68 | 69 | @pytest.fixture 70 | def internal_model_illustrative(): 71 | def age(period): 72 | return period + 18 73 | 74 | def mandatory_retirement_constraint(retirement, age, params): # noqa: ARG001 75 | # Individuals must be retired from age 65 onwards 76 | return jnp.logical_or(retirement == 1, age < 65) 77 | 78 | def mandatory_lagged_retirement_constraint(lagged_retirement, age, params): # noqa: ARG001 79 | # Individuals must have been retired last year from age 66 onwards 80 | return jnp.logical_or(lagged_retirement == 1, age < 66) 81 | 82 | def absorbing_retirement_constraint(retirement, lagged_retirement, params): # noqa: ARG001 83 | # If an individual was retired last year, it must be retired this year 84 | return jnp.logical_or(retirement == 1, lagged_retirement == 0) 85 | 86 | grids = { 87 | "lagged_retirement": jnp.array([0, 1]), 88 | "retirement": jnp.array([0, 1]), 89 | } 90 | 91 | functions = { 92 | "mandatory_retirement_constraint": mandatory_retirement_constraint, 93 | "mandatory_lagged_retirement_constraint": ( 94 | mandatory_lagged_retirement_constraint 95 | ), 96 | "absorbing_retirement_constraint": absorbing_retirement_constraint, 97 | "age": age, 98 | } 99 | 100 | function_info = pd.DataFrame( 101 | {"is_constraint": [True, True, True, False]}, 102 | index=list(functions), 103 | ) 104 | 105 | # create a model instance where some attributes are set to None because they 106 | # are not needed to create the feasibilty mask 107 | return InternalModel( 108 | grids=grids, 109 | gridspecs={}, 110 | variable_info=pd.DataFrame(), 111 | functions=functions, # type: ignore[arg-type] 112 | function_info=function_info, 113 | params={}, 114 | random_utility_shocks=ShockType.NONE, 115 | n_periods=0, 116 | ) 117 | 118 | 119 | @pytest.mark.illustrative 120 | def 
test_get_combined_constraint_illustrative(internal_model_illustrative): 121 | combined_constraint = _get_feasibility(internal_model_illustrative) 122 | 123 | age, retirement, lagged_retirement = jnp.array( 124 | [ 125 | # feasible cases 126 | [60, 0, 0], # Young, never retired 127 | [64, 1, 0], # Near retirement, newly retired 128 | [70, 1, 1], # Properly retired with lagged retirement 129 | # infeasible cases 130 | [65, 0, 0], # Must be retired at 65 131 | [66, 0, 1], # Must have lagged retirement at 66 132 | [60, 0, 1], # Can't be not retired if was retired before 133 | ] 134 | ).T 135 | 136 | # combined constraint expects period not age 137 | period = age - 18 138 | 139 | exp = jnp.array(3 * [True] + 3 * [False]) 140 | got = combined_constraint( 141 | period=period, 142 | retirement=retirement, 143 | lagged_retirement=lagged_retirement, 144 | params={}, 145 | ) 146 | assert_array_equal(got, exp) 147 | 148 | 149 | def test_get_multiply_weights(): 150 | multiply_weights = _get_joint_weights_function( 151 | stochastic_variables=["a", "b"], 152 | ) 153 | 154 | a = jnp.array([1, 2]) 155 | b = jnp.array([3, 4]) 156 | 157 | got = multiply_weights(weight_next_a=a, weight_next_b=b) 158 | expected = jnp.array([[3, 4], [6, 8]]) 159 | assert_array_equal(got, expected) 160 | 161 | 162 | def test_get_combined_constraint(): 163 | def f(params): # noqa: ARG001 164 | return True 165 | 166 | def g(params): # noqa: ARG001 167 | return False 168 | 169 | def h(params): # noqa: ARG001 170 | return None 171 | 172 | function_info = pd.DataFrame( 173 | {"is_constraint": [True, True, False]}, 174 | index=["f", "g", "h"], 175 | ) 176 | model = InternalModel( 177 | grids={}, 178 | gridspecs={}, 179 | variable_info=pd.DataFrame(), 180 | functions={"f": f, "g": g, "h": h}, # type: ignore[dict-item] 181 | function_info=function_info, 182 | params={}, 183 | random_utility_shocks=ShockType.NONE, 184 | n_periods=0, 185 | ) 186 | combined_constraint = _get_feasibility(model) 187 | feasibility: 
Array = combined_constraint(params={}) # type: ignore[assignment] 188 | assert feasibility.item() is False 189 | -------------------------------------------------------------------------------- /tests/test_analytical_solution.py: -------------------------------------------------------------------------------- 1 | """Testing against the analytical solution of Iskhakov et al. (2017). 2 | 3 | The benchmark is taken from the paper "The endogenous grid method for 4 | discrete-continuous dynamic action models with (or without) taste shocks" by Fedor 5 | Iskhakov, Thomas H. Jørgensen, John Rust and Bertel Schjerning (2017, 6 | https://doi.org/10.3982/QE643). 7 | 8 | """ 9 | 10 | from typing import TYPE_CHECKING 11 | 12 | import numpy as np 13 | import pytest 14 | from numpy.testing import assert_array_almost_equal as aaae 15 | 16 | from lcm._config import TEST_DATA 17 | from lcm.entry_point import get_lcm_function 18 | from tests.test_models import get_model_config, get_params 19 | 20 | if TYPE_CHECKING: 21 | from jax import Array 22 | 23 | # ====================================================================================== 24 | # Model specifications 25 | # ====================================================================================== 26 | 27 | 28 | TEST_CASES = { 29 | "iskhakov_2017_five_periods": { 30 | "model": get_model_config("iskhakov_et_al_2017", n_periods=5), 31 | "params": get_params( 32 | beta=0.98, 33 | disutility_of_work=1.0, 34 | interest_rate=0.0, 35 | wage=20.0, 36 | ), 37 | }, 38 | "iskhakov_2017_low_delta": { 39 | "model": get_model_config("iskhakov_et_al_2017", n_periods=3), 40 | "params": get_params( 41 | beta=0.98, 42 | disutility_of_work=0.1, 43 | interest_rate=0.0, 44 | wage=20.0, 45 | ), 46 | }, 47 | } 48 | 49 | 50 | def mean_square_error(x, y, axis=None): 51 | return np.mean((x - y) ** 2, axis=axis) 52 | 53 | 54 | # ====================================================================================== 55 | # Test 56 | # 
====================================================================================== 57 | 58 | 59 | @pytest.mark.parametrize(("model_name", "model_and_params"), TEST_CASES.items()) 60 | def test_analytical_solution(model_name, model_and_params): 61 | """Test that the numerical solution matches the analytical solution. 62 | 63 | The analytical solution is from Iskhakov et al (2017) and is generated 64 | in the development repository: github.com/opensourceeconomics/pylcm-dev. 65 | 66 | """ 67 | # Compute LCM solution 68 | # ================================================================================== 69 | solve_model, _ = get_lcm_function(model=model_and_params["model"], targets="solve") 70 | 71 | V_arr_dict: dict[int, Array] = solve_model(params=model_and_params["params"]) # type: ignore[assignment] 72 | V_arr_list = list(dict(sorted(V_arr_dict.items(), key=lambda x: x[0])).values()) 73 | 74 | _numerical = np.stack(V_arr_list) 75 | numerical = { 76 | "worker": _numerical[:, 0, :], 77 | "retired": _numerical[:, 1, :], 78 | } 79 | 80 | # Load analytical solution 81 | # ================================================================================== 82 | analytical = { 83 | _type: np.genfromtxt( 84 | TEST_DATA.joinpath( 85 | "analytical_solution", 86 | f"{model_name}__values_{_type}.csv", 87 | ), 88 | delimiter=",", 89 | ) 90 | for _type in ["worker", "retired"] 91 | } 92 | 93 | # Compare 94 | # ================================================================================== 95 | for _type in ["worker", "retired"]: 96 | _analytical = np.array(analytical[_type]) 97 | _numerical = numerical[_type] 98 | 99 | # Compare the whole trajectory over time 100 | mse = mean_square_error(_analytical, _numerical, axis=0) 101 | # Exclude the first two initial wealth levels from the comparison, because the 102 | # numerical solution is unstable for very low wealth levels. 
103 | aaae(mse[2:], 0, decimal=1) 104 | -------------------------------------------------------------------------------- /tests/test_argmax.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import jit 3 | from numpy.testing import assert_array_equal 4 | 5 | from lcm.argmax import _flatten_last_n_axes, _move_axes_to_back, argmax_and_max 6 | 7 | # Test jitted functions 8 | # ====================================================================================== 9 | jitted_argmax = jit(argmax_and_max, static_argnums=[1, 2]) 10 | 11 | 12 | # ====================================================================================== 13 | # argmax 14 | # ====================================================================================== 15 | 16 | 17 | def test_argmax_1d_with_mask(): 18 | a = jnp.arange(10) 19 | mask = jnp.array([1, 0, 0, 1, 1, 0, 0, 0, 0, 0], dtype=bool) 20 | _argmax, _max = jitted_argmax(a, where=mask, initial=-1) 21 | assert _argmax == 4 22 | assert _max == 4 23 | 24 | 25 | def test_argmax_2d_with_mask(): 26 | a = jnp.arange(10).reshape(2, 5) 27 | mask = jnp.array([1, 0, 0, 1, 1, 0, 0, 0, 0, 0], dtype=bool).reshape(a.shape) 28 | 29 | _argmax, _max = jitted_argmax(a, axis=None, where=mask, initial=-1) 30 | assert _argmax == 4 31 | assert _max == 4 32 | 33 | _argmax, _max = jitted_argmax(a, axis=0, where=mask, initial=-1) 34 | assert_array_equal(_argmax, jnp.array([0, 0, 0, 0, 0])) 35 | assert_array_equal(_max, jnp.array([0, -1, -1, 3, 4])) 36 | 37 | _argmax, _max = jitted_argmax(a, axis=1, where=mask, initial=-1) 38 | assert_array_equal(_argmax, jnp.array([4, 0])) 39 | assert_array_equal(_max, jnp.array([4, -1])) 40 | 41 | 42 | def test_argmax_1d_no_mask(): 43 | a = jnp.arange(10) 44 | _argmax, _max = jitted_argmax(a) 45 | assert _argmax == 9 46 | assert _max == 9 47 | 48 | 49 | def test_argmax_2d_no_mask(): 50 | a = jnp.arange(10).reshape(2, 5) 51 | 52 | _argmax, _max = 
jitted_argmax(a, axis=None) 53 | assert _argmax == 9 54 | assert _max == 9 55 | 56 | _argmax, _max = jitted_argmax(a, axis=0) 57 | assert_array_equal(_argmax, jnp.array([1, 1, 1, 1, 1])) 58 | assert_array_equal(_max, jnp.array([5, 6, 7, 8, 9])) 59 | 60 | _argmax, _max = jitted_argmax(a, axis=1) 61 | assert_array_equal(_argmax, jnp.array([4, 4])) 62 | assert_array_equal(_max, jnp.array([4, 9])) 63 | 64 | _argmax, _max = jitted_argmax(a, axis=(0, 1)) 65 | assert _argmax == 9 66 | assert _max == 9 67 | 68 | 69 | def test_argmax_3d_no_mask(): 70 | a = jnp.arange(24).reshape(2, 3, 4) 71 | 72 | _argmax, _max = jitted_argmax(a, axis=None) 73 | assert _argmax == 23 74 | assert _max == 23 75 | 76 | _argmax, _max = jitted_argmax(a, axis=0) 77 | assert_array_equal(_argmax, jnp.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])) 78 | assert_array_equal( 79 | _max, 80 | jnp.array([[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]), 81 | ) 82 | 83 | _argmax, _max = jitted_argmax(a, axis=1) 84 | assert_array_equal(_argmax, jnp.array([[2, 2, 2, 2], [2, 2, 2, 2]])) 85 | assert_array_equal(_max, jnp.array([[8, 9, 10, 11], [20, 21, 22, 23]])) 86 | 87 | _argmax, _max = jitted_argmax(a, axis=2) 88 | assert_array_equal(_argmax, jnp.array([[3, 3, 3], [3, 3, 3]])) 89 | assert_array_equal(_max, jnp.array([[3, 7, 11], [15, 19, 23]])) 90 | 91 | _argmax, _max = jitted_argmax(a, axis=(0, 1)) 92 | assert_array_equal(_argmax, jnp.array([5, 5, 5, 5])) 93 | assert_array_equal(_max, jnp.array([20, 21, 22, 23])) 94 | 95 | _argmax, _max = jitted_argmax(a, axis=(0, 2)) 96 | assert_array_equal(_argmax, jnp.array([7, 7, 7])) 97 | assert_array_equal(_max, jnp.array([15, 19, 23])) 98 | 99 | _argmax, _max = jitted_argmax(a, axis=(1, 2)) 100 | assert_array_equal(_argmax, jnp.array([11, 11])) 101 | assert_array_equal(_max, jnp.array([11, 23])) 102 | 103 | 104 | def test_argmax_with_ties(): 105 | # If multiple maxima exist, argmax will select the first index. 
106 | a = jnp.zeros((2, 2, 2)) 107 | _argmax, _ = jitted_argmax(a, axis=(1, 2)) 108 | assert_array_equal(_argmax, jnp.array([0, 0])) 109 | 110 | 111 | # ====================================================================================== 112 | # Move axes to back 113 | # ====================================================================================== 114 | 115 | 116 | def test_move_axes_to_back_1d(): 117 | a = jnp.arange(4) 118 | got = _move_axes_to_back(a, axes=(0,)) 119 | assert_array_equal(got, a) 120 | 121 | 122 | def test_move_axes_to_back_2d(): 123 | a = jnp.arange(4).reshape(2, 2) 124 | got = _move_axes_to_back(a, axes=(0,)) 125 | assert_array_equal(got, a.transpose(1, 0)) 126 | 127 | 128 | def test_move_axes_to_back_3d(): 129 | # 2 dimensions in back 130 | a = jnp.arange(8).reshape(2, 2, 2) 131 | got = _move_axes_to_back(a, axes=(0, 1)) 132 | assert_array_equal(got, a.transpose(2, 0, 1)) 133 | 134 | # 2 dimensions in front 135 | a = jnp.arange(8).reshape(2, 2, 2) 136 | got = _move_axes_to_back(a, axes=(1,)) 137 | assert_array_equal(got, a.transpose(0, 2, 1)) 138 | 139 | 140 | # ====================================================================================== 141 | # Flatten last n axes 142 | # ====================================================================================== 143 | 144 | 145 | def test_flatten_last_n_axes_1d(): 146 | a = jnp.arange(4) 147 | got = _flatten_last_n_axes(a, n=1) 148 | assert_array_equal(got, a) 149 | 150 | 151 | def test_flatten_last_n_axes_2d(): 152 | a = jnp.arange(4).reshape(2, 2) 153 | 154 | got = _flatten_last_n_axes(a, n=1) 155 | assert_array_equal(got, a) 156 | 157 | got = _flatten_last_n_axes(a, n=2) 158 | assert_array_equal(got, a.reshape(4)) 159 | 160 | 161 | def test_flatten_last_n_axes_3d(): 162 | a = jnp.arange(8).reshape(2, 2, 2) 163 | 164 | got = _flatten_last_n_axes(a, n=1) 165 | assert_array_equal(got, a) 166 | 167 | got = _flatten_last_n_axes(a, n=2) 168 | assert_array_equal(got, a.reshape(2, 
4)) 169 | 170 | got = _flatten_last_n_axes(a, n=3) 171 | assert_array_equal(got, a.reshape(8)) 172 | -------------------------------------------------------------------------------- /tests/test_functools.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import re 3 | 4 | import jax.numpy as jnp 5 | import pytest 6 | from jax import vmap 7 | from numpy.testing import assert_array_almost_equal as aaae 8 | 9 | from lcm.functools import ( 10 | all_as_args, 11 | all_as_kwargs, 12 | allow_args, 13 | allow_only_kwargs, 14 | convert_kwargs_to_args, 15 | get_union_of_arguments, 16 | ) 17 | 18 | # ====================================================================================== 19 | # get_union_of_arguments 20 | # ====================================================================================== 21 | 22 | 23 | def test_get_union_of_arguments(): 24 | def f(a, b): 25 | pass 26 | 27 | def g(b, c): 28 | pass 29 | 30 | got = get_union_of_arguments([f, g]) 31 | assert got == {"a", "b", "c"} 32 | 33 | 34 | def test_get_union_of_arguments_no_args(): 35 | def f(): 36 | pass 37 | 38 | got = get_union_of_arguments([f]) 39 | assert got == set() 40 | 41 | 42 | # ====================================================================================== 43 | # all_as_kwargs 44 | # ====================================================================================== 45 | 46 | 47 | def test_all_as_kwargs(): 48 | got = all_as_kwargs( 49 | args=(1, 2), 50 | kwargs={"c": 3}, 51 | arg_names=["a", "b", "c"], 52 | ) 53 | assert got == {"a": 1, "b": 2, "c": 3} 54 | 55 | 56 | def test_all_as_kwargs_empty_args(): 57 | got = all_as_kwargs( 58 | args=(), 59 | kwargs={"a": 1, "b": 2, "c": 3}, 60 | arg_names=["a", "b", "c"], 61 | ) 62 | assert got == {"a": 1, "b": 2, "c": 3} 63 | 64 | 65 | def test_all_as_kwargs_empty_kwargs(): 66 | got = all_as_kwargs( 67 | args=(1, 2, 3), 68 | kwargs={}, 69 | arg_names=["a", "b", "c"], 70 | ) 71 | assert 
got == {"a": 1, "b": 2, "c": 3} 72 | 73 | 74 | # ====================================================================================== 75 | # all_as_args 76 | # ====================================================================================== 77 | 78 | 79 | def test_all_as_args(): 80 | got = all_as_args( 81 | args=(1, 2), 82 | kwargs={"c": 3}, 83 | arg_names=["a", "b", "c"], 84 | ) 85 | assert got == (1, 2, 3) 86 | 87 | 88 | def test_all_as_args_empty_args(): 89 | got = all_as_args( 90 | args=(), 91 | kwargs={"a": 1, "b": 2, "c": 3}, 92 | arg_names=["a", "b", "c"], 93 | ) 94 | assert got == (1, 2, 3) 95 | 96 | 97 | def test_all_as_args_empty_kwargs(): 98 | got = all_as_args( 99 | args=(1, 2, 3), 100 | kwargs={}, 101 | arg_names=["a", "b", "c"], 102 | ) 103 | assert got == (1, 2, 3) 104 | 105 | 106 | # ====================================================================================== 107 | # convert kwargs to args 108 | # ====================================================================================== 109 | 110 | 111 | def test_convert_kwargs_to_args(): 112 | kwargs = {"a": 1, "b": 2, "c": 3} 113 | parameters = ["c", "a", "b"] 114 | exp = [3, 1, 2] 115 | got = convert_kwargs_to_args(kwargs, parameters) 116 | assert got == exp 117 | 118 | 119 | # ====================================================================================== 120 | # allow kwargs 121 | # ====================================================================================== 122 | 123 | 124 | def test_allow_only_kwargs(): 125 | def f(a, /, b): 126 | # a is positional-only 127 | return a + b 128 | 129 | with pytest.raises(TypeError): 130 | f(a=1, b=2) # type: ignore[call-arg] 131 | 132 | assert allow_only_kwargs(f)(a=1, b=2) == 3 133 | 134 | 135 | def test_allow_only_kwargs_with_keyword_only_args(): 136 | def f(a, /, *, b): 137 | return a + b 138 | 139 | with pytest.raises(TypeError): 140 | f(a=1, b=2) # type: ignore[call-arg] 141 | 142 | assert allow_only_kwargs(f)(a=1, b=2) == 3 
143 | 144 | 145 | def test_allow_only_kwargs_too_many_args(): 146 | def f(a, /, b): 147 | return a + b 148 | 149 | too_many_match = re.escape("Expected arguments: ['a', 'b'], got extra: {'c'}") 150 | with pytest.raises(ValueError, match=too_many_match): 151 | allow_only_kwargs(f)(a=1, b=2, c=3) 152 | 153 | 154 | def test_allow_only_kwargs_too_few_args(): 155 | def f(a, /, b): 156 | return a + b 157 | 158 | too_few_match = re.escape("Expected arguments: ['a', 'b'], missing: {'b'}") 159 | with pytest.raises(ValueError, match=too_few_match): 160 | allow_only_kwargs(f)(a=1) 161 | 162 | 163 | def test_allow_only_kwargs_signature_change(): 164 | def f(a, /, b, *, c): 165 | pass 166 | 167 | decorated = allow_only_kwargs(f) 168 | parameters = inspect.signature(decorated).parameters 169 | 170 | assert parameters["a"].kind == inspect.Parameter.KEYWORD_ONLY 171 | assert parameters["b"].kind == inspect.Parameter.KEYWORD_ONLY 172 | assert parameters["c"].kind == inspect.Parameter.KEYWORD_ONLY 173 | 174 | 175 | # ====================================================================================== 176 | # allow args 177 | # ====================================================================================== 178 | 179 | 180 | def test_allow_args(): 181 | def f(a, *, b): 182 | # b is keyword-only 183 | return a + b 184 | 185 | with pytest.raises(TypeError): 186 | f(1, 2) # type: ignore[misc] 187 | 188 | assert allow_args(f)(1, 2) == 3 189 | assert allow_args(f)(1, b=2) == 3 190 | assert allow_args(f)(b=2, a=1) == 3 191 | 192 | 193 | def test_allow_args_different_kwargs_order(): 194 | def f(a, b, c, *, d): 195 | return a + b + c + d 196 | 197 | with pytest.raises(TypeError): 198 | f(1, 2, 3, 4) # type: ignore[misc] 199 | 200 | assert allow_args(f)(1, 2, 3, 4) == 10 201 | assert allow_args(f)(1, 2, d=4, c=3) == 10 202 | 203 | 204 | def test_allow_args_too_many_args(): 205 | def f(a, *, b): 206 | return a + b 207 | 208 | with pytest.raises(ValueError, match="Too many arguments 
provided."): 209 | allow_args(f)(1, 2, b=3) 210 | 211 | 212 | def test_allow_args_too_few_args(): 213 | def f(a, *, b): 214 | return a + b 215 | 216 | with pytest.raises(ValueError, match="Not all arguments provided."): 217 | allow_args(f)(1) 218 | 219 | 220 | def test_allow_args_with_vmap(): 221 | def f(a, *, b): 222 | # b is keyword-only 223 | return a + b 224 | 225 | f_vmapped = vmap(f, in_axes=(0, 0)) 226 | f_allow_args_vmapped = vmap(allow_args(f), in_axes=(0, 0)) 227 | 228 | a = jnp.arange(2) 229 | b = jnp.arange(2) 230 | 231 | with pytest.raises(TypeError): 232 | # TypeError since b is keyword-only 233 | f_vmapped(a, b) # type: ignore[misc] 234 | 235 | with pytest.raises(ValueError, match="vmap in_axes must be an int"): 236 | # ValueError since vmap doesn't support keyword arguments 237 | f_vmapped(a, b=b) 238 | 239 | aaae(f_allow_args_vmapped(a, b), jnp.array([0, 2])) 240 | 241 | 242 | def test_allow_args_signature_change(): 243 | def f(a, /, b, *, c): 244 | pass 245 | 246 | decorated = allow_args(f) 247 | parameters = inspect.signature(decorated).parameters 248 | 249 | assert parameters["a"].kind == inspect.Parameter.POSITIONAL_ONLY 250 | assert parameters["b"].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD 251 | assert parameters["c"].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD 252 | -------------------------------------------------------------------------------- /tests/test_grid_helpers.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | import pytest 4 | from numpy.testing import assert_array_almost_equal as aaae 5 | 6 | from lcm.grid_helpers import ( 7 | get_linspace_coordinate, 8 | get_logspace_coordinate, 9 | linspace, 10 | logspace, 11 | ) 12 | from lcm.ndimage import map_coordinates 13 | 14 | 15 | def test_linspace(): 16 | calculated = linspace(start=1, stop=2, n_points=6) 17 | expected = np.array([1, 1.2, 1.4, 1.6, 1.8, 2]) 18 | aaae(calculated, expected) 
19 | 20 | 21 | def test_linspace_mapped_value(): 22 | """For reference of the grid values, see expected grid in `test_linspace`.""" 23 | # Get position corresponding to a value in the grid 24 | calculated = get_linspace_coordinate( 25 | value=1.2, 26 | start=1, 27 | stop=2, 28 | n_points=6, 29 | ) 30 | assert np.allclose(calculated, 1.0) 31 | 32 | # Get position corresponding to a value that is between two grid points 33 | # ---------------------------------------------------------------------------------- 34 | # Here, the value is 1.3, that is in the middle of 1.2 and 1.4, which have the 35 | # positions 1 and 2, respectively. Therefore, we want the position to be 1.5. 36 | calculated = get_linspace_coordinate( 37 | value=1.3, 38 | start=1, 39 | stop=2, 40 | n_points=6, 41 | ) 42 | assert np.allclose(calculated, 1.5) 43 | 44 | # Get position corresponding to a value that is outside the grid 45 | calculated = get_linspace_coordinate( 46 | value=0.6, 47 | start=1, 48 | stop=2, 49 | n_points=6, 50 | ) 51 | assert np.allclose(calculated, -2.0) 52 | 53 | 54 | def test_logspace(): 55 | calculated = logspace(start=1, stop=100, n_points=7) 56 | expected = np.array( 57 | [1.0, 2.15443469, 4.64158883, 10.0, 21.5443469, 46.41588834, 100.0], 58 | ) 59 | aaae(calculated, expected) 60 | 61 | 62 | def test_logspace_mapped_value(): 63 | """For reference of the grid values, see expected grid in `test_logspace`.""" 64 | calculated = get_logspace_coordinate( 65 | value=(2.15443469 + 4.64158883) / 2, 66 | start=1, 67 | stop=100, 68 | n_points=7, 69 | ) 70 | assert np.allclose(calculated, 1.5) 71 | 72 | 73 | @pytest.mark.illustrative 74 | def test_map_coordinates_linear(): 75 | """Illustrative test on how the output of get_linspace_coordinate can be used.""" 76 | grid_info = { 77 | "start": 0, 78 | "stop": 1, 79 | "n_points": 3, 80 | } 81 | 82 | grid = linspace(**grid_info) # [0, 0.5, 1] 83 | 84 | values = 2 * grid # [0, 1.0, 2.0] 85 | 86 | # We choose a coordinate that is exactly in 
the middle between the first and second 87 | # entry of the grid. 88 | coordinate = get_linspace_coordinate( 89 | value=0.25, 90 | **grid_info, 91 | ) 92 | 93 | # Perform the linear interpolation 94 | interpolated_value = map_coordinates(values, [coordinate]) 95 | assert np.allclose(interpolated_value, 0.5) 96 | 97 | 98 | @pytest.mark.illustrative 99 | def test_map_coordinates_logarithmic(): 100 | """Illustrative test on how the output of get_logspace_coordinate can be used.""" 101 | grid_info = { 102 | "start": 1, 103 | "stop": 2, 104 | "n_points": 3, 105 | } 106 | 107 | grid = logspace(**grid_info) # [1.0, 1.414213562373095, 2.0] 108 | 109 | values = 2 * grid # [2.0, 2.82842712474619, 4.0] 110 | 111 | # We choose a coordinate that is exactly in the middle between the first and second 112 | # entry of the grid. 113 | coordinate = get_logspace_coordinate( 114 | value=(1.0 + 1.414213562373095) / 2, 115 | **grid_info, 116 | ) 117 | 118 | # Perform the linear interpolation 119 | interpolated_value = map_coordinates(values, [coordinate]) 120 | assert np.allclose(interpolated_value, (2.0 + 2.82842712474619) / 2) 121 | 122 | 123 | @pytest.mark.illustrative 124 | def test_map_coordinates_linear_outside_grid(): 125 | """Illustrative test on what happens to values outside the grid.""" 126 | grid_info = { 127 | "start": 1, 128 | "stop": 2, 129 | "n_points": 2, 130 | } 131 | 132 | grid = linspace(**grid_info) # [1, 2] 133 | 134 | values = 2 * grid # [2, 4] 135 | 136 | # Get coordinates corresponding to values outside the grid [1, 2] 137 | coordinates = jnp.array( 138 | [get_linspace_coordinate(grid_val, **grid_info) for grid_val in [-1, 0, 3]] 139 | ) 140 | 141 | interpolated_value = map_coordinates(values, [coordinates]) 142 | 143 | aaae(interpolated_value, [-2, 0, 6]) 144 | -------------------------------------------------------------------------------- /tests/test_grids.py: -------------------------------------------------------------------------------- 1 | import re 2 | 
from dataclasses import make_dataclass

import numpy as np
import pytest

from lcm.exceptions import GridInitializationError
from lcm.grids import (
    DiscreteGrid,
    LinspaceGrid,
    LogspaceGrid,
    _get_field_names_and_values,
    _validate_continuous_grid,
    _validate_discrete_grid,
)

# NOTE: `pytest.raises(..., match=...)` treats `match` as a regular expression that is
# searched in the exception message. Expected messages containing regex metacharacters
# such as "[", "]", "{" or "}" are wrapped in `re.escape` so they match literally.


def test_validate_discrete_grid_empty():
    category_class = make_dataclass("Category", [])
    error_msg = "category_class passed to DiscreteGrid must have at least one field"
    with pytest.raises(GridInitializationError, match=error_msg):
        _validate_discrete_grid(category_class)


def test_validate_discrete_grid_non_scalar_input():
    category_class = make_dataclass("Category", [("a", int, 1), ("b", str, "s")])
    error_msg = (
        "Field values of the category_class passed to DiscreteGrid can only be "
        "scalar int or float values. The values to the following fields are not: ['b']"
    )
    with pytest.raises(GridInitializationError, match=re.escape(error_msg)):
        _validate_discrete_grid(category_class)


def test_validate_discrete_grid_none_input():
    category_class = make_dataclass("Category", [("a", int), ("b", int, 1)])
    error_msg = (
        "Field values of the category_class passed to DiscreteGrid can only be "
        "scalar int or float values. The values to the following fields are not: ['a']"
    )
    with pytest.raises(GridInitializationError, match=re.escape(error_msg)):
        _validate_discrete_grid(category_class)


def test_validate_discrete_grid_non_unique():
    category_class = make_dataclass("Category", [("a", int, 1), ("b", int, 1)])
    error_msg = (
        "Field values of the category_class passed to DiscreteGrid must be unique. "
        "The following values are duplicated: {1}"
    )
    # Escaping is required here: unescaped, "{1}" is a regex quantifier (one repetition
    # of the preceding space), so the reported duplicated values were never checked.
    with pytest.raises(GridInitializationError, match=re.escape(error_msg)):
        _validate_discrete_grid(category_class)


def test_validate_discrete_grid_non_consecutive_unordered():
    category_class = make_dataclass("Category", [("a", int, 1), ("b", int, 0)])
    error_msg = "Field values of the category_class passed to DiscreteGrid must be "
    with pytest.raises(GridInitializationError, match=error_msg):
        _validate_discrete_grid(category_class)


def test_validate_discrete_grid_non_consecutive_jumps():
    category_class = make_dataclass("Category", [("a", int, 0), ("b", int, 2)])
    error_msg = "Field values of the category_class passed to DiscreteGrid must be "
    with pytest.raises(GridInitializationError, match=error_msg):
        _validate_discrete_grid(category_class)


def test_get_fields_with_defaults():
    category_class = make_dataclass("Category", [("a", int, 1), ("b", int, 2)])
    assert _get_field_names_and_values(category_class) == {"a": 1, "b": 2}


def test_get_fields_no_defaults():
    category_class = make_dataclass("Category", [("a", int), ("b", int)])
    assert _get_field_names_and_values(category_class) == {"a": None, "b": None}


def test_get_fields_instance():
    category_class = make_dataclass("Category", [("a", int), ("b", int)])
    assert _get_field_names_and_values(category_class(a=1, b=2)) == {"a": 1, "b": 2}


def test_get_fields_empty():
    category_class = make_dataclass("Category", [])
    assert _get_field_names_and_values(category_class) == {}


def test_validate_continuous_grid_invalid_start():
    error_msg = "start must be a scalar int or float value"
    with pytest.raises(GridInitializationError, match=error_msg):
        _validate_continuous_grid("a", 1, 10)  # type: ignore[arg-type]


def test_validate_continuous_grid_invalid_stop():
    error_msg = "stop must be a scalar int or float value"
    with pytest.raises(GridInitializationError, match=error_msg):
        _validate_continuous_grid(1, "a", 10)  # type: ignore[arg-type]


def test_validate_continuous_grid_invalid_n_points():
    error_msg = "n_points must be an int greater than 0 but is a"
    with pytest.raises(GridInitializationError, match=error_msg):
        _validate_continuous_grid(1, 2, "a")  # type: ignore[arg-type]


def test_validate_continuous_grid_negative_n_points():
    error_msg = "n_points must be an int greater than 0 but is -1"
    with pytest.raises(GridInitializationError, match=error_msg):
        _validate_continuous_grid(1, 2, -1)


def test_validate_continuous_grid_start_greater_than_stop():
    error_msg = "start must be less than stop"
    with pytest.raises(GridInitializationError, match=error_msg):
        _validate_continuous_grid(2, 1, 10)


def test_linspace_grid_creation():
    grid = LinspaceGrid(start=1, stop=5, n_points=5)
    assert np.allclose(grid.to_jax(), np.linspace(1, 5, 5))


def test_logspace_grid_creation():
    grid = LogspaceGrid(start=1, stop=10, n_points=3)
    assert np.allclose(grid.to_jax(), np.logspace(np.log10(1), np.log10(10), 3))


def test_discrete_grid_creation():
    category_class = make_dataclass(
        "Category", [("a", int, 0), ("b", int, 1), ("c", int, 2)]
    )
    grid = DiscreteGrid(category_class)
    assert np.allclose(grid.to_jax(), np.arange(3))


def test_linspace_grid_invalid_start():
    with pytest.raises(GridInitializationError, match="start must be less than stop"):
        LinspaceGrid(start=1, stop=0, n_points=10)


def test_logspace_grid_invalid_start():
    with pytest.raises(GridInitializationError, match="start must be less than stop"):
        LogspaceGrid(start=1, stop=0, n_points=10)


def test_discrete_grid_invalid_category_class():
    category_class = make_dataclass(
        "Category", [("a", int, 0), ("b", str, "wrong_type")]
    )
    with pytest.raises(
        GridInitializationError,
        match="Field values of the category_class passed to DiscreteGrid can only be",
    ):
        DiscreteGrid(category_class)


def test_replace_mixin():
    grid = LinspaceGrid(start=1, stop=5, n_points=5)
    new_grid = grid.replace(start=0)
    # Only the replaced field changes; the others are carried over.
    assert new_grid.start == 0
    assert new_grid.stop == 5
    assert new_grid.n_points == 5


# --------------------------------------------------------------------------------
# /tests/test_max_Qc_over_d.py:
# --------------------------------------------------------------------------------
import jax.numpy as jnp
import pandas as pd
import pytest
from numpy.testing import assert_array_almost_equal as aaae

from lcm.max_Qc_over_d import (
    _determine_discrete_action_axes_simulation,
    _determine_discrete_action_axes_solution,
    _max_Qc_over_d_extreme_value_shocks,
    _max_Qc_over_d_no_shocks,
    get_max_Qc_over_d,
)
from lcm.typing import ShockType

# ======================================================================================
# Illustrative
# ======================================================================================


@pytest.mark.illustrative
def test_get_solve_discrete_problem_illustrative():
    """Without shocks, the maximum is taken over the discrete action axis (axis 1)."""
    variable_info = pd.DataFrame(
        {
            "is_action": [False, True],
            "is_state": [True, False],
            "is_discrete": [True, True],
            "is_continuous": [False, False],
        },
    )  # leads to discrete_action_axes = [1]

    max_Qc_over_d = get_max_Qc_over_d(
        random_utility_shock_type=ShockType.NONE,
        variable_info=variable_info,
        is_last_period=False,
    )

    Qc_arr = jnp.array(
        [
            [0, 1],
            [2, 3],
            [4, 5],
        ],
    )

    got = max_Qc_over_d(Qc_arr, params={})
    # Row-wise maxima of Qc_arr.
    aaae(got, jnp.array([1, 3, 5]))


@pytest.mark.illustrative
def test_solve_discrete_problem_no_shocks_illustrative_single_action_axis():
    """Maximizing over axis 0 yields the column-wise maxima."""
    Qc_arr = jnp.array(
        [
            [0, 1],
            [2, 3],
            [4, 5],
        ],
    )
    got = _max_Qc_over_d_no_shocks(
        Qc_arr,
        discrete_action_axes=(0,),
        params={},
    )
    aaae(got, jnp.array([4, 5]))


@pytest.mark.illustrative
def test_solve_discrete_problem_no_shocks_illustrative_multiple_action_axes():
    """Maximizing over both axes yields the overall maximum."""
    Qc_arr = jnp.array(
        [
            [0, 1],
            [2, 3],
            [4, 5],
        ],
    )
    got = _max_Qc_over_d_no_shocks(
        Qc_arr,
        discrete_action_axes=(0, 1),
        params={},
    )
    aaae(got, 5)


@pytest.mark.illustrative
def test_max_Qc_over_d_extreme_value_shocks_illustrative_single_action_axis():
    """With a small shock scale (0.1) the result is close to the hard maximum."""
    Qc_arr = jnp.array(
        [
            [0, 1],
            [2, 3],
            [4, 5],
        ],
    )

    got = _max_Qc_over_d_extreme_value_shocks(
        Qc_arr,
        discrete_action_axes=(0,),
        params={"additive_utility_shock": {"scale": 0.1}},
    )
    aaae(got, jnp.array([4, 5]), decimal=5)


@pytest.mark.illustrative
def test_max_Qc_over_d_extreme_value_shocks_illustrative_multiple_action_axes():
    """With a small shock scale (0.1) the result is close to the overall maximum."""
    Qc_arr = jnp.array(
        [
            [0, 1],
            [2, 3],
            [4, 5],
        ],
    )
    got = _max_Qc_over_d_extreme_value_shocks(
        Qc_arr,
        discrete_action_axes=(0, 1),
        params={"additive_utility_shock": {"scale": 0.1}},
    )
    aaae(got, 5, decimal=5)


# ======================================================================================
# Determine discrete action axes
# ======================================================================================


@pytest.mark.illustrative
def test_determine_discrete_action_axes_illustrative_one_var():
    variable_info = pd.DataFrame(
        {
            "is_action": [False, True],
            "is_state": [True, False],
            "is_discrete": [True, True],
            "is_continuous": [False, False],
        },
    )

    assert _determine_discrete_action_axes_solution(variable_info) == (1,)


@pytest.mark.illustrative
def test_determine_discrete_action_axes_illustrative_three_var():
    variable_info = pd.DataFrame(
        {
            "is_action": [False, True, True, True],
            "is_state": [True, False, False, False],
            "is_discrete": [True, True, True, True],
            "is_continuous": [False, False, False, False],
        },
    )

    assert _determine_discrete_action_axes_solution(variable_info) == (1, 2, 3)


def test_determine_discrete_action_axes():
    variable_info = pd.DataFrame(
        {
            "is_state": [True, True, False, True, False, False],
            "is_action": [False, False, True, True, True, True],
            "is_discrete": [True, True, True, True, True, False],
            "is_continuous": [False, True, False, False, False, True],
        },
    )
    got = _determine_discrete_action_axes_simulation(variable_info)
    assert got == (1, 2, 3)


# --------------------------------------------------------------------------------
# /tests/test_model.py:
# --------------------------------------------------------------------------------
import pytest

from lcm.exceptions import ModelInitilizationError
from lcm.grids import DiscreteGrid
from lcm.user_model import Model


def test_model_invalid_states():
    with pytest.raises(ModelInitilizationError, match="states must be a dictionary"):
        Model(
            n_periods=2,
            states="health",  # type: ignore[arg-type]
            actions={},
            functions={"utility": lambda: 0},
        )


def test_model_invalid_actions():
    with pytest.raises(ModelInitilizationError, match="actions must be a dictionary"):
        Model(
            n_periods=2,
            states={},
            actions="exercise",  # type: ignore[arg-type]
            functions={"utility": lambda: 0},
        )


def test_model_invalid_functions():
    with pytest.raises(ModelInitilizationError, match="functions must be a dictionary"):
        Model(
            n_periods=2,
            states={},
            actions={},
            functions="utility",  # type: ignore[arg-type]
        )


def test_model_invalid_functions_values():
    with pytest.raises(
        ModelInitilizationError, match="function values must be a callable, but is 0."
    ):
        Model(
            n_periods=2,
            states={},
            actions={},
            functions={"utility": 0},  # type: ignore[dict-item]
        )


def test_model_invalid_functions_keys():
    with pytest.raises(
        ModelInitilizationError, match="function keys must be a strings, but is 0."
    ):
        Model(
            n_periods=2,
            states={},
            actions={},
            functions={0: lambda: 0},  # type: ignore[dict-item]
        )


def test_model_invalid_actions_values():
    with pytest.raises(
        ModelInitilizationError, match="actions value 0 must be an LCM grid."
    ):
        Model(
            n_periods=2,
            states={},
            actions={"exercise": 0},  # type: ignore[dict-item]
            functions={"utility": lambda: 0},
        )


def test_model_invalid_states_values():
    with pytest.raises(
        ModelInitilizationError, match="states value 0 must be an LCM grid."
    ):
        Model(
            n_periods=2,
            states={"health": 0},  # type: ignore[dict-item]
            actions={},
            functions={"utility": lambda: 0},
        )


def test_model_invalid_n_periods():
    with pytest.raises(
        ModelInitilizationError, match="Number of periods must be a positive integer."
    ):
        Model(
            n_periods=0,
            states={},
            actions={},
            functions={"utility": lambda: 0},
        )


# `binary_category_class` is a pytest fixture (presumably from tests/conftest.py).
def test_model_missing_next_func(binary_category_class):
    with pytest.raises(
        ModelInitilizationError,
        match="Each state must have a corresponding next state function.",
    ):
        Model(
            n_periods=2,
            states={"health": DiscreteGrid(binary_category_class)},
            actions={"exercise": DiscreteGrid(binary_category_class)},
            functions={"utility": lambda: 0},
        )


def test_model_missing_utility():
    with pytest.raises(
        ModelInitilizationError,
        match="Utility function is not defined. LCM expects a function called 'utility",
    ):
        Model(
            n_periods=2,
            states={},
            actions={},
            functions={},
        )


def test_model_overlapping_states_actions(binary_category_class):
    with pytest.raises(
        ModelInitilizationError,
        match="States and actions cannot have overlapping names.",
    ):
        Model(
            n_periods=2,
            states={"health": DiscreteGrid(binary_category_class)},
            actions={"health": DiscreteGrid(binary_category_class)},
            functions={"utility": lambda: 0},
        )


# --------------------------------------------------------------------------------
# /tests/test_models/__init__.py:
# --------------------------------------------------------------------------------
from .get_model import get_model_config, get_params

__all__ = ["get_model_config", "get_params"]


# --------------------------------------------------------------------------------
# /tests/test_models/deterministic.py:
# --------------------------------------------------------------------------------
"""Example specifications of a deterministic consumption-saving model.

The specification builds on the example model presented in the paper: "The endogenous
grid method for discrete-continuous dynamic choice models with (or without) taste
shocks" by Fedor Iskhakov, Thomas H. Jørgensen, John Rust and Bertel Schjerning (2017,
https://doi.org/10.3982/QE643).

"""

from dataclasses import dataclass

import jax.numpy as jnp

from lcm import DiscreteGrid, LinspaceGrid, Model

# ======================================================================================
# Model functions
# ======================================================================================


# --------------------------------------------------------------------------------------
# Categorical variables
# --------------------------------------------------------------------------------------
@dataclass
class RetirementStatus:
    working: int = 0
    retired: int = 1


# --------------------------------------------------------------------------------------
# Utility functions
# --------------------------------------------------------------------------------------
def utility(consumption, working, disutility_of_work):
    """Log utility of consumption minus a linear disutility of working."""
    return jnp.log(consumption) - disutility_of_work * working


def utility_with_constraint(
    consumption,
    working,
    disutility_of_work,
    # Temporary workaround for bug described in issue #30, which requires us to pass
    # all state variables to the utility function.
    # TODO(@timmens): Remove function once #30 is fixed (re-use "utility").
    # https://github.com/opensourceeconomics/pylcm/issues/30
    lagged_retirement,  # noqa: ARG001
):
    """Same as `utility`; `lagged_retirement` is accepted but unused (see #30)."""
    return utility(consumption, working, disutility_of_work)


# --------------------------------------------------------------------------------------
# Auxiliary variables
# --------------------------------------------------------------------------------------
def labor_income(working, wage):
    """Labor income: the wage scaled by the working indicator."""
    return working * wage


def working(retirement):
    """Working indicator derived from the retirement action (1 - retirement)."""
    return 1 - retirement


def wage(age):
    """Wage that increases linearly with age."""
    return 1 + 0.1 * age


def age(_period):
    """Age implied by the model period (age 18 in period 0)."""
    return _period + 18


# --------------------------------------------------------------------------------------
# State transitions
# --------------------------------------------------------------------------------------
def next_wealth(wealth, consumption, labor_income, interest_rate):
    """Next-period wealth: savings earn interest, then labor income is added."""
    return (1 + interest_rate) * (wealth - consumption) + labor_income


# --------------------------------------------------------------------------------------
# Constraints
# --------------------------------------------------------------------------------------
def consumption_constraint(consumption, wealth):
    """Consumption may not exceed current wealth."""
    return consumption <= wealth


def absorbing_retirement_constraint(retirement, lagged_retirement):
    """Retirement is absorbing: after retiring, the agent must stay retired."""
    return jnp.logical_or(
        retirement == RetirementStatus.retired,
        lagged_retirement == RetirementStatus.working,
    )


# ======================================================================================
# Model specifications
# ======================================================================================

ISKHAKOV_ET_AL_2017 = Model(
    description=(
        "Corresponds to the example model in Iskhakov et al. (2017). In comparison to "
        "the extensions below, wage is treated as a constant parameter and therefore "
        "there is no need for the wage and age functions."
    ),
    n_periods=3,
    functions={
        "utility": utility_with_constraint,
        "next_wealth": next_wealth,
        "next_lagged_retirement": lambda retirement: retirement,
        "consumption_constraint": consumption_constraint,
        "absorbing_retirement_constraint": absorbing_retirement_constraint,
        "labor_income": labor_income,
        "working": working,
    },
    actions={
        "retirement": DiscreteGrid(RetirementStatus),
        "consumption": LinspaceGrid(
            start=1,
            stop=400,
            n_points=500,
        ),
    },
    states={
        "wealth": LinspaceGrid(
            start=1,
            stop=400,
            n_points=100,
        ),
        "lagged_retirement": DiscreteGrid(RetirementStatus),
    },
)


ISKHAKOV_ET_AL_2017_STRIPPED_DOWN = Model(
    description=(
        "Starts from Iskhakov et al. (2017), removes absorbing retirement constraint "
        "and the lagged_retirement state, and adds wage function that depends on age."
    ),
    n_periods=3,
    functions={
        "utility": utility,
        "next_wealth": next_wealth,
        "consumption_constraint": consumption_constraint,
        "labor_income": labor_income,
        "working": working,
        "wage": wage,
        "age": age,
    },
    actions={
        "retirement": DiscreteGrid(RetirementStatus),
        "consumption": LinspaceGrid(
            start=1,
            stop=400,
            n_points=500,
        ),
    },
    states={
        "wealth": LinspaceGrid(
            start=1,
            stop=400,
            n_points=100,
        ),
    },
)


# --------------------------------------------------------------------------------
# /tests/test_models/discrete_deterministic.py:
# --------------------------------------------------------------------------------
"""Example specifications of a fully discrete deterministic consumption-saving model.

The specification builds on the example model presented in the paper: "The endogenous
grid method for discrete-continuous dynamic choice models with (or without) taste
shocks" by Fedor Iskhakov, Thomas H. Jørgensen, John Rust and Bertel Schjerning (2017,
https://doi.org/10.3982/QE643). See module `tests.test_models.deterministic` for the
continuous version.

"""

from dataclasses import dataclass

import jax.numpy as jnp

from lcm import DiscreteGrid, Model
from tests.test_models.deterministic import (
    RetirementStatus,
    labor_income,
    next_wealth,
    utility,
    working,
)

# ======================================================================================
# Model functions
# ======================================================================================


# --------------------------------------------------------------------------------------
# Categorical variables
# --------------------------------------------------------------------------------------
@dataclass
class ConsumptionChoice:
    low: int = 0
    high: int = 1


@dataclass
class WealthStatus:
    low: int = 0
    medium: int = 1
    high: int = 2


# --------------------------------------------------------------------------------------
# Utility functions
# --------------------------------------------------------------------------------------
def utility_discrete(consumption, working, disutility_of_work):
    """Utility of the discrete model, reusing the continuous `utility`."""
    # In the discrete model, consumption is defined as "low" or "high". This can be
    # translated to the levels 1 and 2.
    consumption_level = 1 + (consumption == ConsumptionChoice.high)
    return utility(consumption_level, working, disutility_of_work)


# --------------------------------------------------------------------------------------
# State transitions
# --------------------------------------------------------------------------------------
def next_wealth_discrete(wealth, consumption, labor_income, interest_rate):
    """Next-period wealth, rounded and clipped onto the `WealthStatus` codes."""
    # For discrete state variables, we need to ensure that the next state is also a
    # valid state, i.e., it is a member of the discrete grid.
    continuous = next_wealth(wealth, consumption, labor_income, interest_rate)
    return jnp.clip(jnp.rint(continuous), WealthStatus.low, WealthStatus.high).astype(
        jnp.int32
    )


# --------------------------------------------------------------------------------------
# Constraints
# --------------------------------------------------------------------------------------
def consumption_constraint(consumption, wealth):
    """Consumption may not exceed wealth.

    In this model both arguments are category codes (`ConsumptionChoice` and
    `WealthStatus`), so the comparison is between codes.
    """
    return consumption <= wealth


# ======================================================================================
# Model specifications
# ======================================================================================
ISKHAKOV_ET_AL_2017_DISCRETE = Model(
    description=(
        "Starts from Iskhakov et al. (2017), removes absorbing retirement constraint "
        "and the lagged_retirement state, and makes the consumption decision discrete."
    ),
    n_periods=3,
    functions={
        "utility": utility_discrete,
        "next_wealth": next_wealth_discrete,
        "consumption_constraint": consumption_constraint,
        "labor_income": labor_income,
        "working": working,
    },
    actions={
        "retirement": DiscreteGrid(RetirementStatus),
        "consumption": DiscreteGrid(ConsumptionChoice),
    },
    states={
        "wealth": DiscreteGrid(WealthStatus),
    },
)


# --------------------------------------------------------------------------------
# /tests/test_models/get_model.py:
# --------------------------------------------------------------------------------
from copy import deepcopy

import jax.numpy as jnp

from lcm.user_model import Model
from tests.test_models.deterministic import (
    ISKHAKOV_ET_AL_2017,
    ISKHAKOV_ET_AL_2017_STRIPPED_DOWN,
)
from tests.test_models.discrete_deterministic import ISKHAKOV_ET_AL_2017_DISCRETE
from tests.test_models.stochastic import ISKHAKOV_ET_AL_2017_STOCHASTIC

TEST_MODELS = {
    "iskhakov_et_al_2017": ISKHAKOV_ET_AL_2017,
    "iskhakov_et_al_2017_stripped_down": ISKHAKOV_ET_AL_2017_STRIPPED_DOWN,
    "iskhakov_et_al_2017_discrete": ISKHAKOV_ET_AL_2017_DISCRETE,
    "iskhakov_et_al_2017_stochastic": ISKHAKOV_ET_AL_2017_STOCHASTIC,
}


def get_model_config(model_name: str, n_periods: int) -> Model:
    """Return a copy of the named test model with `n_periods` replaced.

    The registry entry is deep-copied first so the shared module-level model object in
    `TEST_MODELS` is never modified.

    Raises:
        KeyError: If `model_name` is not a key of `TEST_MODELS`.
    """
    model_config = deepcopy(TEST_MODELS[model_name])
    return model_config.replace(n_periods=n_periods)


def get_params(
    beta=0.95,
    disutility_of_work=0.5,
    interest_rate=0.05,
    wage=10.0,
    health_transition=None,
    partner_transition=None,
):
    """Create the parameter dictionary used with the test models.

    Args:
        beta: Stored under "beta" (presumably the discount factor).
        disutility_of_work: Parameter of the "utility" function.
        interest_rate: Parameter of the "next_wealth" function.
        wage: Parameter of the "labor_income" function.
        health_transition: Health transition array. If None, a default array of shape
            (2, 2, 2) is used: current health x current partner x next health.
        partner_transition: Partner transition array. If None, a default array of shape
            (2, 2, 2, 2) is used: period x working decision x current partner x next
            partner.

    Returns:
        Nested dictionary keyed by model function name; the transition arrays are
        stored under "shocks".
    """
    # ----------------------------------------------------------------------------------
    # Transition matrices
    # ----------------------------------------------------------------------------------

    # Health shock transition:
    # ------------------------------------------------------------------------------
    # 1st dimension: Current health state
    # 2nd dimension: Current partner state
    # 3rd dimension: Probability distribution over next period's health state
    default_health_transition = jnp.array(
        [
            # Current health state 0
            [
                # Current partner state 0
                [0.9, 0.1],
                # Current partner state 1
                [0.5, 0.5],
            ],
            # Current health state 1
            [
                # Current partner state 0
                [0.5, 0.5],
                # Current partner state 1
                [0.1, 0.9],
            ],
        ],
    )
    health_transition = (
        default_health_transition if health_transition is None else health_transition
    )

    # Partner shock transition:
    # ------------------------------------------------------------------------------
    # 1st dimension: The period
    # 2nd dimension: Current working decision
    # 3rd dimension: Current partner state
    # 4th dimension: Probability distribution over next period's partner state
    default_partner_transition = jnp.array(
        [
            # Transition from period 0 to period 1
            [
                # Current working decision 0
                [
                    # Current partner state 0
                    [0, 1.0],
                    # Current partner state 1
                    [1.0, 0],
                ],
                # Current working decision 1
                [
                    # Current partner state 0
                    [0, 1.0],
                    # Current partner state 1
                    [0.0, 1.0],
                ],
            ],
            # Transition from period 1 to period 2
            [
                # Current working decision 0
                [
                    # Current partner state 0
                    [0, 1.0],
                    # Current partner state 1
                    [1.0, 0],
                ],
                # Current working decision 1
                [
                    # Current partner state 0
                    [0, 1.0],
                    # Current partner state 1
                    [0.0, 1.0],
                ],
            ],
        ],
    )
    partner_transition = (
        default_partner_transition if partner_transition is None else partner_transition
    )

    # ----------------------------------------------------------------------------------
    # Model parameters
    # ----------------------------------------------------------------------------------
    return {
        "beta": beta,
        "utility": {"disutility_of_work": disutility_of_work},
        "next_wealth": {"interest_rate": interest_rate},
        "next_health": {},
        "consumption_constraint": {},
        "labor_income": {"wage": wage},
        "shocks": {
            "health": health_transition,
            "partner": partner_transition,
        },
    }


# --------------------------------------------------------------------------------
# /tests/test_models/stochastic.py:
# --------------------------------------------------------------------------------
"""Example specification of a stochastic consumption-saving model.

This specification is motivated by the example model presented in the paper: "The
endogenous grid method for discrete-continuous dynamic choice models with (or without)
taste shocks" by Fedor Iskhakov, Thomas H. Jørgensen, John Rust and Bertel Schjerning
(2017, https://doi.org/10.3982/QE643).

See also the specifications in tests/test_models/deterministic.py.

"""

from dataclasses import dataclass

import jax.numpy as jnp

import lcm
from lcm import DiscreteGrid, LinspaceGrid, Model

# ======================================================================================
# Model functions
# ======================================================================================


# --------------------------------------------------------------------------------------
# Categorical variables
# --------------------------------------------------------------------------------------
@dataclass
class HealthStatus:
    bad: int = 0
    good: int = 1


@dataclass
class PartnerStatus:
    single: int = 0
    partnered: int = 1


@dataclass
class WorkingStatus:
    retired: int = 0
    working: int = 1


# --------------------------------------------------------------------------------------
# Utility function
# --------------------------------------------------------------------------------------
def utility(
    consumption,
    working,
    health,
    # Temporary workaround for bug described in issue #30, which requires us to pass
    # all state variables to the utility function.
    # TODO(@timmens): Remove function arguments once #30 is fixed.
    # https://github.com/opensourceeconomics/pylcm/issues/30
    partner,  # noqa: ARG001
    disutility_of_work,
):
    """Log utility of consumption minus a health-dependent disutility of working.

    With `HealthStatus` codes bad=0 and good=1, good health halves the disutility of
    work. `partner` is accepted but unused (see #30).
    """
    return jnp.log(consumption) - (1 - health / 2) * disutility_of_work * working


# --------------------------------------------------------------------------------------
# Auxiliary variables
# --------------------------------------------------------------------------------------
def labor_income(working, wage):
    """Labor income: the wage scaled by the working indicator."""
    return working * wage


# --------------------------------------------------------------------------------------
# Deterministic state transitions
# --------------------------------------------------------------------------------------
def next_wealth(wealth, consumption, labor_income, interest_rate):
    """Next-period wealth: savings earn interest, then labor income is added."""
    return (1 + interest_rate) * (wealth - consumption) + labor_income


# --------------------------------------------------------------------------------------
# Stochastic state transitions
# --------------------------------------------------------------------------------------
# The bodies are intentionally empty: the functions are marked with
# `lcm.mark.stochastic`, and their arguments declare what the transition depends on.
# The probabilities presumably come from the parameters (cf. "shocks" in `get_params`)
# -- confirm against `lcm.mark`.
@lcm.mark.stochastic
def next_health(health, partner):
    pass


@lcm.mark.stochastic
def next_partner(_period, working, partner):
    pass


# --------------------------------------------------------------------------------------
# Constraints
# --------------------------------------------------------------------------------------
def consumption_constraint(consumption, wealth):
    """Consumption may not exceed current wealth."""
    return consumption <= wealth


# ======================================================================================
# Model specification
# ======================================================================================

ISKHAKOV_ET_AL_2017_STOCHASTIC = Model(
    description=(
        "Starts from Iskhakov et al. (2017), removes absorbing retirement constraint "
        "and the lagged_retirement state, and adds discrete stochastic state variables "
        "health and partner."
    ),
    n_periods=3,
    functions={
        "utility": utility,
        "next_wealth": next_wealth,
        "next_health": next_health,
        "next_partner": next_partner,
        "consumption_constraint": consumption_constraint,
        "labor_income": labor_income,
    },
    actions={
        "working": DiscreteGrid(WorkingStatus),
        "consumption": LinspaceGrid(
            start=1,
            stop=100,
            n_points=200,
        ),
    },
    states={
        "health": DiscreteGrid(HealthStatus),
        "partner": DiscreteGrid(PartnerStatus),
        "wealth": LinspaceGrid(
            start=1,
            stop=100,
            n_points=100,
        ),
    },
)


# --------------------------------------------------------------------------------
# /tests/test_ndimage.py:
# --------------------------------------------------------------------------------
# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modifications made by Tim Mensinger, 2024.
16 | 17 | from functools import partial 18 | 19 | import jax.numpy as jnp 20 | import jax.scipy.ndimage 21 | import numpy as np 22 | import pytest 23 | import scipy.ndimage 24 | from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal 25 | 26 | import lcm.ndimage 27 | 28 | jax_map_coordinates = partial(jax.scipy.ndimage.map_coordinates, order=1, cval=0) 29 | scipy_map_coordinates = partial(scipy.ndimage.map_coordinates, order=1, cval=0) 30 | lcm_map_coordinates = lcm.ndimage.map_coordinates 31 | 32 | JAX_BASED_IMPLEMENTATIONS = [jax_map_coordinates, lcm_map_coordinates] 33 | 34 | 35 | TEST_SHAPES = [ 36 | (5,), 37 | (3, 4), 38 | (3, 4, 5), 39 | ] 40 | 41 | TEST_COORDINATES_SHAPES = [ 42 | (7,), 43 | (2, 3, 4), 44 | ] 45 | 46 | 47 | def _make_test_data(shape, coordinates_shape, dtype): 48 | rng = np.random.default_rng() 49 | x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) 50 | c = [(size - 1) * rng.random(coordinates_shape).astype(dtype) for size in shape] 51 | return x, c 52 | 53 | 54 | @pytest.mark.parametrize("map_coordinates", JAX_BASED_IMPLEMENTATIONS) 55 | @pytest.mark.parametrize("shape", TEST_SHAPES) 56 | @pytest.mark.parametrize("coordinates_shape", TEST_COORDINATES_SHAPES) 57 | @pytest.mark.parametrize("dtype", [np.int64, np.float64]) 58 | def test_map_coordinates_against_scipy( 59 | map_coordinates, shape, coordinates_shape, dtype 60 | ): 61 | """Test that JAX and LCM implementations behave as scipy.""" 62 | x, c = _make_test_data(shape, coordinates_shape, dtype=dtype) 63 | 64 | x_jax = jnp.asarray(x) 65 | c_jax = [jnp.asarray(c_i) for c_i in c] 66 | 67 | expected = scipy_map_coordinates(x, c) 68 | got = map_coordinates(x_jax, c_jax) 69 | 70 | assert_array_almost_equal(got, expected, decimal=14) 71 | 72 | 73 | @pytest.mark.parametrize("map_coordinates", JAX_BASED_IMPLEMENTATIONS) 74 | @pytest.mark.parametrize("dtype", [np.int64, np.float64]) 75 | def test_map_coordinates_round_half_against_scipy(map_coordinates, 
dtype): 76 | """Test that JAX and LCM implementations round as scipy.""" 77 | x = np.arange(-5, 5, dtype=dtype) 78 | c = np.array([[0.5, 1.5, 2.5, 6.5, 8.5]]) 79 | 80 | x_jax = jnp.asarray(x) 81 | c_jax = [jnp.asarray(c_i) for c_i in c] 82 | 83 | expected = scipy_map_coordinates(x, c) 84 | got = map_coordinates(x_jax, c_jax) 85 | 86 | assert_array_equal(got, expected) 87 | 88 | 89 | @pytest.mark.parametrize("map_coordinates", JAX_BASED_IMPLEMENTATIONS) 90 | def test_gradients(map_coordinates): 91 | """Test that JAX and LCM implementations exhibit same gradient behavior.""" 92 | x = jnp.arange(9.0) 93 | border = 3 # square root of 9, as we are considering a parabola on x. 94 | 95 | def f(step): 96 | coordinates = x + step 97 | shifted = map_coordinates(x, [coordinates]) 98 | return ((x - shifted) ** 2)[border:-border].mean() 99 | 100 | # Gradient of f(step) is 2 * step 101 | assert_allclose(jax.grad(f)(0.5), 1.0) 102 | assert_allclose(jax.grad(f)(1.0), 2.0) 103 | -------------------------------------------------------------------------------- /tests/test_ndimage_unit.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import pytest 3 | from numpy.testing import assert_array_equal 4 | 5 | from lcm.ndimage import ( 6 | _compute_indices_and_weights, 7 | _multiply_all, 8 | _round_half_away_from_zero, 9 | _sum_all, 10 | map_coordinates, 11 | ) 12 | 13 | 14 | def test_map_coordinates_wrong_input_dimensions(): 15 | values = jnp.arange(2) # ndim = 1 16 | coordinates = [jnp.array([0]), jnp.array([1])] # len = 2 17 | with pytest.raises(ValueError, match="coordinates must be a sequence of length"): 18 | map_coordinates(values, coordinates) 19 | 20 | 21 | def test_map_coordinates_extrapolation(): 22 | x = jnp.arange(3.0) 23 | c = [jnp.array([-2.0, -1.0, 5.0, 10.0])] 24 | 25 | got = map_coordinates(x, c) 26 | expected = c[0] 27 | 28 | assert_array_equal(got, expected) 29 | 30 | 31 | def test_nonempty_sum(): 32 | a 
= jnp.arange(3) 33 | 34 | expected = a + a + a 35 | got = _sum_all([a, a, a]) 36 | 37 | assert_array_equal(got, expected) 38 | 39 | 40 | def test_nonempty_prod(): 41 | a = jnp.arange(3) 42 | 43 | expected = a * a * a 44 | got = _multiply_all([a, a, a]) 45 | 46 | assert_array_equal(got, expected) 47 | 48 | 49 | def test_round_half_away_from_zero_integer(): 50 | a = jnp.array([1, 2], dtype=jnp.int32) 51 | assert_array_equal(_round_half_away_from_zero(a), a) 52 | 53 | 54 | def test_round_half_away_from_zero_float(): 55 | a = jnp.array([0.5, 1.5], dtype=jnp.float32) 56 | 57 | expected = jnp.array([1, 2], dtype=jnp.int32) 58 | got = _round_half_away_from_zero(a) 59 | 60 | assert_array_equal(got, expected) 61 | 62 | 63 | def test_linear_indices_and_weights_inside_domain(): 64 | """Test that the indices and weights are correct for a points inside the domain.""" 65 | coordinates = jnp.array([0, 0.5, 1]) 66 | 67 | (idx_low, weight_low), (idx_high, weight_high) = _compute_indices_and_weights( 68 | coordinates, input_size=2 69 | ) 70 | 71 | assert_array_equal(idx_low, jnp.array([0, 0, 0], dtype=jnp.int32)) 72 | assert_array_equal(weight_low, jnp.array([1, 0.5, 0], dtype=jnp.float32)) 73 | assert_array_equal(idx_high, jnp.array([1, 1, 1], dtype=jnp.int32)) 74 | assert_array_equal(weight_high, jnp.array([0, 0.5, 1], dtype=jnp.float32)) 75 | 76 | 77 | def test_linear_indices_and_weights_outside_domain(): 78 | coordinates = jnp.array([-1, 2]) 79 | 80 | (idx_low, weight_low), (idx_high, weight_high) = _compute_indices_and_weights( 81 | coordinates, input_size=2 82 | ) 83 | 84 | assert_array_equal(idx_low, jnp.array([0, 0], dtype=jnp.int32)) 85 | assert_array_equal(weight_low, jnp.array([2, -1], dtype=jnp.float32)) 86 | assert_array_equal(idx_high, jnp.array([1, 1], dtype=jnp.int32)) 87 | assert_array_equal(weight_high, jnp.array([-1, 2], dtype=jnp.float32)) 88 | -------------------------------------------------------------------------------- /tests/test_next_state.py: 
-------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import pandas as pd 3 | from jax import Array 4 | from pybaum import tree_equal 5 | 6 | from lcm.input_processing import process_model 7 | from lcm.interfaces import InternalModel 8 | from lcm.next_state import _create_stochastic_next_func, get_next_state_function 9 | from lcm.typing import ParamsDict, Scalar, ShockType, Target 10 | from tests.test_models import get_model_config 11 | 12 | 13 | def test_get_next_state_function_with_solve_target(): 14 | model = process_model( 15 | get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3), 16 | ) 17 | got_func = get_next_state_function(model, target=Target.SOLVE) 18 | 19 | params = { 20 | "beta": 1.0, 21 | "utility": {"disutility_of_work": 1.0}, 22 | "next_wealth": { 23 | "interest_rate": 0.05, 24 | }, 25 | } 26 | 27 | action = {"retirement": 1, "consumption": 10} 28 | state = {"wealth": 20} 29 | 30 | got = got_func(**action, **state, _period=1, params=params) 31 | assert got == {"next_wealth": 1.05 * (20 - 10)} 32 | 33 | 34 | def test_get_next_state_function_with_simulate_target(): 35 | def f_a(state: Array, params: ParamsDict) -> Scalar: # noqa: ARG001 36 | return state[0] 37 | 38 | def f_b(state: Scalar, params: ParamsDict) -> Scalar: # noqa: ARG001 39 | return None # type: ignore[return-value] 40 | 41 | def f_weight_b(state: Scalar, params: ParamsDict) -> Array: # noqa: ARG001 42 | return jnp.array([[0.0, 1.0]]) 43 | 44 | functions = { 45 | "a": f_a, 46 | "b": f_b, 47 | "weight_b": f_weight_b, 48 | } 49 | 50 | grids = {"b": jnp.arange(2)} 51 | 52 | function_info = pd.DataFrame( 53 | { 54 | "is_next": [True, True], 55 | "is_stochastic_next": [False, True], 56 | }, 57 | index=["a", "b"], 58 | ) 59 | 60 | model = InternalModel( 61 | functions=functions, # type: ignore[arg-type] 62 | grids=grids, 63 | function_info=function_info, 64 | gridspecs={}, 65 | variable_info=pd.DataFrame(), 66 | params={}, 
67 | random_utility_shocks=ShockType.NONE, 68 | n_periods=1, 69 | ) 70 | 71 | got_func = get_next_state_function(model, target=Target.SIMULATE) 72 | 73 | keys = {"b": jnp.arange(2, dtype="uint32")} 74 | got = got_func(state=jnp.arange(2), keys=keys, params={}) 75 | 76 | expected = {"a": jnp.array([0]), "b": jnp.array([1])} 77 | assert tree_equal(expected, got) 78 | 79 | 80 | def test_create_stochastic_next_func(): 81 | labels = jnp.arange(2) 82 | got_func = _create_stochastic_next_func(name="a", labels=labels) 83 | 84 | keys = {"a": jnp.arange(2, dtype="uint32")} # PRNG dtype 85 | weights = jnp.array([[0.0, 1], [1, 0]]) 86 | got = got_func(keys=keys, weight_a=weights) 87 | 88 | assert jnp.array_equal(got, jnp.array([1, 0])) 89 | -------------------------------------------------------------------------------- /tests/test_random.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from lcm.random import generate_simulation_keys, random_choice 5 | 6 | 7 | def test_random_choice(): 8 | key = jax.random.key(0) 9 | probs = jnp.array([[0.0, 0, 1], [1, 0, 0], [0, 1, 0]]) 10 | labels = jnp.array([1, 2, 3]) 11 | got = random_choice(labels=labels, probs=probs, key=key) 12 | assert jnp.array_equal(got, jnp.array([3, 1, 2])) 13 | 14 | 15 | def test_generate_simulation_keys(): 16 | key = jnp.arange(2, dtype="uint32") # PRNG dtype 17 | stochastic_next_functions = ["a", "b"] 18 | got = generate_simulation_keys(key, stochastic_next_functions) 19 | # assert that all generated keys are different from each other 20 | matrix = jnp.array([key, got[0], got[1]["a"], got[1]["b"]]) 21 | assert jnp.linalg.matrix_rank(matrix) == 2 22 | -------------------------------------------------------------------------------- /tests/test_regression_test.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import pandas as pd 3 | from jax import Array 4 | from 
numpy.testing import assert_array_almost_equal as aaae 5 | from pandas.testing import assert_frame_equal 6 | 7 | from lcm._config import TEST_DATA 8 | from lcm.entry_point import get_lcm_function 9 | from tests.test_models import get_model_config, get_params 10 | 11 | 12 | def test_regression_test(): 13 | """Test that the output of lcm does not change.""" 14 | # Load generated output 15 | # ================================================================================== 16 | expected_simulate = pd.read_pickle( 17 | TEST_DATA.joinpath("regression_tests", "simulation.pkl"), 18 | ) 19 | 20 | expected_solve = pd.read_pickle( 21 | TEST_DATA.joinpath("regression_tests", "solution.pkl"), 22 | ) 23 | 24 | # Generate current lcm ouput 25 | # ================================================================================== 26 | model_config = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3) 27 | 28 | solve, _ = get_lcm_function(model=model_config, targets="solve") 29 | 30 | params = get_params( 31 | beta=0.95, 32 | disutility_of_work=1.0, 33 | interest_rate=0.05, 34 | ) 35 | got_solve: dict[int, Array] = solve(params) # type: ignore[assignment] 36 | 37 | solve_and_simulate, _ = get_lcm_function( 38 | model=model_config, 39 | targets="solve_and_simulate", 40 | ) 41 | 42 | got_simulate = solve_and_simulate( 43 | params=params, 44 | initial_states={ 45 | "wealth": jnp.array([5.0, 20, 40, 70]), 46 | }, 47 | ) 48 | 49 | # Compare 50 | # ================================================================================== 51 | aaae(expected_solve, list(got_solve.values()), decimal=5) 52 | assert_frame_equal(expected_simulate, got_simulate) # type: ignore[arg-type] 53 | -------------------------------------------------------------------------------- /tests/test_state_action_space.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from numpy.testing import assert_array_equal 3 | 4 | from 
lcm.input_processing import process_model 5 | from lcm.interfaces import StateActionSpace, StateSpaceInfo 6 | from lcm.state_action_space import ( 7 | create_state_action_space, 8 | create_state_space_info, 9 | ) 10 | from tests.test_models import get_model_config 11 | 12 | 13 | def test_create_state_action_space_solution(): 14 | model = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3) 15 | internal_model = process_model(model) 16 | 17 | state_action_space = create_state_action_space( 18 | model=internal_model, 19 | is_last_period=False, 20 | ) 21 | 22 | assert isinstance(state_action_space, StateActionSpace) 23 | assert jnp.array_equal( 24 | state_action_space.discrete_actions["retirement"], 25 | model.actions["retirement"].to_jax(), 26 | ) 27 | assert jnp.array_equal( 28 | state_action_space.states["wealth"], model.states["wealth"].to_jax() 29 | ) 30 | 31 | 32 | def test_create_state_action_space_simulation(): 33 | model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) 34 | model = process_model(model_config) 35 | got_space = create_state_action_space( 36 | model=model, 37 | initial_states={ 38 | "wealth": jnp.array([10.0, 20.0]), 39 | "lagged_retirement": jnp.array([0, 1]), 40 | }, 41 | ) 42 | assert_array_equal(got_space.discrete_actions["retirement"], jnp.array([0, 1])) 43 | assert_array_equal(got_space.states["wealth"], jnp.array([10.0, 20.0])) 44 | assert_array_equal(got_space.states["lagged_retirement"], jnp.array([0, 1])) 45 | 46 | 47 | def test_create_state_space_info(): 48 | model = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3) 49 | internal_model = process_model(model) 50 | 51 | state_space_info = create_state_space_info( 52 | model=internal_model, 53 | is_last_period=False, 54 | ) 55 | 56 | assert isinstance(state_space_info, StateSpaceInfo) 57 | assert state_space_info.states_names == ("wealth",) 58 | assert state_space_info.discrete_states == {} 59 | assert state_space_info.continuous_states == 
model.states 60 | 61 | 62 | def test_create_state_action_space_replace(): 63 | model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) 64 | model = process_model(model_config) 65 | space = create_state_action_space( 66 | model=model, 67 | initial_states={ 68 | "wealth": jnp.array([10.0, 20.0]), 69 | "lagged_retirement": jnp.array([0, 1]), 70 | }, 71 | ) 72 | new_space = space.replace( 73 | states={"wealth": jnp.array([10.0, 30.0])}, 74 | ) 75 | assert_array_equal(new_space.states["wealth"], jnp.array([10.0, 30.0])) 76 | -------------------------------------------------------------------------------- /tests/test_stochastic.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import pandas as pd 3 | import pytest 4 | from jax import Array 5 | 6 | import lcm 7 | from lcm.entry_point import ( 8 | get_lcm_function, 9 | ) 10 | from tests.test_models import get_model_config, get_params 11 | 12 | # ====================================================================================== 13 | # Simulate 14 | # ====================================================================================== 15 | 16 | 17 | def test_get_lcm_function_with_simulate_target(): 18 | simulate_model, _ = get_lcm_function( 19 | model=get_model_config("iskhakov_et_al_2017_stochastic", n_periods=3), 20 | targets="solve_and_simulate", 21 | ) 22 | 23 | res: pd.DataFrame = simulate_model( # type: ignore[assignment] 24 | params=get_params(), 25 | initial_states={ 26 | "health": jnp.array([1, 1, 0, 0]), 27 | "partner": jnp.array([0, 0, 1, 0]), 28 | "wealth": jnp.array([10.0, 50.0, 30, 80.0]), 29 | }, 30 | ) 31 | 32 | # This is derived from the partner transition in get_params. 
33 | expected_next_partner = ( 34 | (res.working.astype(bool) | ~res.partner.astype(bool)).astype(int).loc[:1] 35 | ) 36 | 37 | pd.testing.assert_series_equal( 38 | res["partner"].loc[1:], 39 | expected_next_partner, 40 | check_index=False, 41 | check_names=False, 42 | ) 43 | 44 | 45 | # ====================================================================================== 46 | # Solve 47 | # ====================================================================================== 48 | 49 | 50 | def test_get_lcm_function_with_solve_target(): 51 | solve_model, _ = get_lcm_function( 52 | model=get_model_config("iskhakov_et_al_2017_stochastic", n_periods=3), 53 | targets="solve", 54 | ) 55 | solve_model(params=get_params()) 56 | 57 | 58 | # ====================================================================================== 59 | # Comparison with deterministic results 60 | # ====================================================================================== 61 | 62 | 63 | @pytest.fixture 64 | def model_and_params(): 65 | """Return a simple deterministic and stochastic model with parameters. 66 | 67 | TODO(@timmens): Add this to tests/test_models/stochastic.py. 
68 | 69 | """ 70 | model_deterministic = get_model_config( 71 | "iskhakov_et_al_2017_stochastic", n_periods=3 72 | ) 73 | model_stochastic = get_model_config("iskhakov_et_al_2017_stochastic", n_periods=3) 74 | 75 | # Overwrite health transition with simple stochastic version and deterministic one 76 | # ================================================================================== 77 | @lcm.mark.stochastic 78 | def next_health_stochastic(health): 79 | pass 80 | 81 | def next_health_deterministic(health): 82 | return health 83 | 84 | model_deterministic.functions["next_health"] = next_health_deterministic 85 | model_stochastic.functions["next_health"] = next_health_stochastic 86 | 87 | params = get_params( 88 | beta=0.95, 89 | disutility_of_work=1.0, 90 | interest_rate=0.05, 91 | wage=10.0, 92 | health_transition=jnp.identity(2), 93 | ) 94 | 95 | return model_deterministic, model_stochastic, params 96 | 97 | 98 | def test_compare_deterministic_and_stochastic_results_value_function(model_and_params): 99 | """Test that the deterministic and stochastic models produce the same results.""" 100 | model_deterministic, model_stochastic, params = model_and_params 101 | 102 | # ================================================================================== 103 | # Compare value function arrays 104 | # ================================================================================== 105 | solve_model_deterministic, _ = get_lcm_function( 106 | model=model_deterministic, 107 | targets="solve", 108 | ) 109 | solve_model_stochastic, _ = get_lcm_function( 110 | model=model_stochastic, 111 | targets="solve", 112 | ) 113 | 114 | solution_deterministic: dict[int, Array] = solve_model_deterministic(params) # type: ignore[assignment] 115 | solution_stochastic: dict[int, Array] = solve_model_stochastic(params) # type: ignore[assignment] 116 | 117 | assert jnp.array_equal( 118 | jnp.array(list(solution_deterministic.values())), 119 | jnp.array(list(solution_stochastic.values())), 
120 | equal_nan=True, 121 | ) 122 | 123 | # ================================================================================== 124 | # Compare simulation results 125 | # ================================================================================== 126 | simulate_model_deterministic, _ = get_lcm_function( 127 | model=model_deterministic, 128 | targets="simulate", 129 | ) 130 | simulate_model_stochastic, _ = get_lcm_function( 131 | model=model_stochastic, 132 | targets="simulate", 133 | ) 134 | 135 | initial_states = { 136 | "health": jnp.array([1, 1, 0, 0]), 137 | "partner": jnp.array([0, 0, 0, 0]), 138 | "wealth": jnp.array([10.0, 50.0, 30, 80.0]), 139 | } 140 | 141 | simulation_deterministic = simulate_model_deterministic( 142 | params, 143 | V_arr_dict=solution_deterministic, 144 | initial_states=initial_states, 145 | ) 146 | simulation_stochastic = simulate_model_stochastic( 147 | params, 148 | V_arr_dict=solution_stochastic, 149 | initial_states=initial_states, 150 | ) 151 | pd.testing.assert_frame_equal(simulation_deterministic, simulation_stochastic) # type: ignore[arg-type] 152 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from lcm.utils import find_duplicates 2 | 3 | 4 | def test_find_duplicates_singe_container_no_duplicates(): 5 | assert find_duplicates([1, 2, 3, 4, 5]) == set() 6 | 7 | 8 | def test_find_duplicates_single_container_with_duplicates(): 9 | assert find_duplicates([1, 2, 3, 4, 5, 5]) == {5} 10 | 11 | 12 | def test_find_duplicates_multiple_containers_no_duplicates(): 13 | assert find_duplicates([1, 2, 3, 4, 5], [6, 7, 8, 9, 10]) == set() 14 | 15 | 16 | def test_find_duplicates_multiple_containers_with_duplicates(): 17 | assert find_duplicates([1, 2, 3, 4, 5, 5], [6, 7, 8, 9, 10, 5]) == {5} 18 | --------------------------------------------------------------------------------