├── .github └── workflows │ └── main.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── .yamllint.yml ├── LICENSE ├── README.md ├── codecov.yml ├── docs ├── Makefile ├── make.bat └── source │ ├── _static │ ├── css │ │ └── custom.css │ └── images │ │ ├── book.svg │ │ ├── books.svg │ │ ├── coding.svg │ │ ├── light-bulb.svg │ │ └── logo.svg │ ├── chs_test_params.csv │ ├── conf.py │ ├── explanations │ ├── index.rst │ ├── names_and_concepts.rst │ └── notes_on_factor_scales.rst │ ├── getting_started │ ├── index.rst │ └── tutorial.ipynb │ ├── how_to_guides │ ├── how_to_simulate_dataset.ipynb │ ├── how_to_visualize_correlations.ipynb │ ├── how_to_visualize_pairwise_factor_distribution.ipynb │ ├── how_to_visualize_transition_equations.ipynb │ ├── index.rst │ ├── model_specs.rst │ └── utilities.rst │ ├── index.rst │ ├── reference_guides │ ├── endogeneity_corrections.rst │ ├── estimation.rst │ ├── index.rst │ ├── pre_processing.rst │ ├── simulation.rst │ └── transition_functions.rst │ ├── rtd_environment.yml │ ├── start_params.csv │ └── start_params_template.csv ├── environment.yml ├── pixi.lock ├── pyproject.toml ├── src └── skillmodels │ ├── __init__.py │ ├── check_model.py │ ├── clipping.py │ ├── config.py │ ├── constraints.py │ ├── correlation_heatmap.py │ ├── decorators.py │ ├── filtered_states.py │ ├── kalman_filters.py │ ├── kalman_filters_debug.py │ ├── likelihood_function.py │ ├── likelihood_function_debug.py │ ├── maximization_inputs.py │ ├── params_index.py │ ├── parse_params.py │ ├── process_data.py │ ├── process_debug_data.py │ ├── process_model.py │ ├── simulate_data.py │ ├── transition_functions.py │ ├── utilities.py │ ├── utils_plotting.py │ ├── visualize_factor_distributions.py │ └── visualize_transition_equations.py └── tests ├── model2.yaml ├── model2_correct_params_index.csv ├── model2_correct_update_info.csv ├── model2_simulated_data.dta ├── regression_vault ├── chs_results.csv ├── no_stages_anchoring.csv ├── no_stages_anchoring_result.json ├── one_stage.csv ├── one_stage_anchoring.csv ├── one_stage_anchoring_custom_functions.csv ├── one_stage_anchoring_custom_functions_result.json ├── one_stage_anchoring_result.json ├── one_stage_old.csv ├── one_stage_result.json ├── two_stages_anchoring.csv └── two_stages_anchoring_result.json ├── test_clipping.py ├── test_constraints.py ├── test_correlation_heatmap.py ├── test_decorators.py ├── test_filtered_states.py ├── test_kalman_filters.py ├── test_likelihood_regression.py ├── test_maximization_inputs.py ├── test_params_index.py ├── test_parse_params.py ├── test_process_data.py ├── test_process_model.py ├── test_simulate_data.py ├── test_transition_functions.py ├── test_utilities.py ├── test_visualize_factor_distributions.py └── test_visualize_transition_equations.py /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: main 3 | # Automatically cancel a previous run. 
4 | concurrency: 5 | group: ${{ github.head_ref || github.run_id }} 6 | cancel-in-progress: true 7 | on: 8 | push: 9 | branches: 10 | - main 11 | pull_request: 12 | branches: 13 | - '*' 14 | jobs: 15 | run-tests: 16 | name: Run tests for ${{ matrix.os }} on ${{ matrix.python-version }} 17 | runs-on: ${{ matrix.os }} 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | os: 22 | - ubuntu-latest 23 | - macos-latest 24 | - windows-latest 25 | python-version: 26 | - '3.12' 27 | steps: 28 | - uses: actions/checkout@v4 29 | - uses: prefix-dev/setup-pixi@v0.8.0 30 | with: 31 | pixi-version: v0.29.0 32 | cache: true 33 | cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} 34 | environments: test-cpu 35 | activate-environment: true 36 | - name: Run pytest 37 | shell: bash -l {0} 38 | run: pixi run -e test-cpu tests-with-cov 39 | - name: Upload coverage report 40 | if: runner.os == 'Linux' && matrix.python-version == '3.12' 41 | uses: codecov/codecov-action@v4 42 | env: 43 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 44 | # run-mypy: 45 | # name: Run mypy on Python 3.12 46 | # runs-on: ubuntu-latest 47 | # strategy: 48 | # fail-fast: false 49 | # steps: 50 | # - uses: actions/checkout@v4 51 | # - uses: prefix-dev/setup-pixi@v0.8.0 52 | # with: 53 | # pixi-version: v0.28.2 54 | # cache: true 55 | # cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} 56 | # environments: mypy 57 | # - name: Run mypy 58 | # shell: bash -l {0} 59 | # run: pixi run mypy 60 | # run-explanation-notebooks: 61 | # name: Run explanation notebooks on Python 3.12 62 | # runs-on: ubuntu-latest 63 | # steps: 64 | # - uses: actions/checkout@v4 65 | # - uses: prefix-dev/setup-pixi@v0.8.0 66 | # with: 67 | # pixi-version: v0.28.2 68 | # cache: true 69 | # cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} 70 | # environments: test 71 | # - name: Run explanation notebooks 72 | # shell: bash -l {0} 73 | # run: pixi run -e test explanation-notebooks 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | *build/ 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | *.sublime-workspace 35 | *.sublime-project 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | *notes/ 110 | 111 | .idea/ 112 | 113 | *.bak 114 | 115 | 116 | *.db 117 | 118 | 119 | mixed_documents/ 120 | src/skillmodels/_version.py 121 | .pixi 122 | .vscode 123 | *.py.*.bin 124 | *.py.*.html 125 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | - repo: meta 4 | hooks: 5 | - id: check-hooks-apply 6 | - id: check-useless-excludes 7 | # - id: identity # Prints all files passed to pre-commits. Debugging. 8 | - repo: https://github.com/lyz-code/yamlfix 9 | rev: 1.17.0 10 | hooks: 11 | - id: yamlfix 12 | - repo: https://github.com/pre-commit/pre-commit-hooks 13 | rev: v4.6.0 14 | hooks: 15 | - id: check-added-large-files 16 | args: 17 | - --maxkb=50000 18 | - id: check-case-conflict 19 | - id: check-merge-conflict 20 | - id: check-vcs-permalinks 21 | - id: check-yaml 22 | - id: check-toml 23 | - id: debug-statements 24 | - id: end-of-file-fixer 25 | - id: fix-byte-order-marker 26 | types: 27 | - text 28 | - id: forbid-submodules 29 | - id: mixed-line-ending 30 | args: 31 | - --fix=lf 32 | description: Forces to replace line ending by the UNIX 'lf' character. 33 | - id: name-tests-test 34 | args: 35 | - --pytest-test-first 36 | - id: trailing-whitespace 37 | - id: check-ast 38 | - id: check-docstring-first 39 | - repo: https://github.com/adrienverge/yamllint.git 40 | rev: v1.35.1 41 | hooks: 42 | - id: yamllint 43 | - repo: https://github.com/astral-sh/ruff-pre-commit 44 | rev: v0.6.4 45 | hooks: 46 | # Run the linter. 47 | - id: ruff 48 | types_or: 49 | - python 50 | - pyi 51 | - jupyter 52 | args: 53 | - --fix 54 | # - --unsafe-fixes 55 | # Run the formatter. 
56 | - id: ruff-format 57 | types_or: 58 | - python 59 | - pyi 60 | - jupyter 61 | - repo: https://github.com/kynan/nbstripout 62 | rev: 0.7.1 63 | hooks: 64 | - id: nbstripout 65 | args: 66 | - --extra-keys 67 | - metadata.kernelspec metadata.language_info.version metadata.vscode 68 | - repo: https://github.com/executablebooks/mdformat 69 | rev: 0.7.17 70 | hooks: 71 | - id: mdformat 72 | additional_dependencies: 73 | - mdformat-myst 74 | - mdformat-ruff 75 | args: 76 | - --wrap 77 | - '88' 78 | ci: 79 | autoupdate_schedule: monthly 80 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | --- 2 | version: 2 3 | sphinx: 4 | configuration: docs/source/conf.py 5 | python: 6 | version: 3.12 7 | conda: 8 | environment: docs/source/rtd_environment.yml 9 | -------------------------------------------------------------------------------- /.yamllint.yml: -------------------------------------------------------------------------------- 1 | --- 2 | yaml-files: 3 | - '*.yaml' 4 | - '*.yml' 5 | - .yamllint 6 | rules: 7 | braces: enable 8 | brackets: enable 9 | colons: enable 10 | commas: enable 11 | comments: 12 | level: warning 13 | comments-indentation: 14 | level: warning 15 | document-end: disable 16 | document-start: 17 | level: warning 18 | empty-lines: enable 19 | empty-values: disable 20 | float-values: disable 21 | hyphens: enable 22 | indentation: {spaces: 2} 23 | key-duplicates: enable 24 | key-ordering: disable 25 | line-length: 26 | max: 88 27 | allow-non-breakable-words: true 28 | allow-non-breakable-inline-mappings: false 29 | new-line-at-end-of-file: enable 30 | new-lines: 31 | type: unix 32 | octal-values: disable 33 | quoted-strings: disable 34 | trailing-spaces: enable 35 | truthy: 36 | level: warning 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016- Janoś Gabler 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # skillmodels 2 | 3 | ## Introduction 4 | 5 | Welcome to skillmodels, a Python implementation of estimators for skill formation 6 | models. 
The econometrics of skill formation models is a very active field and several 7 | estimators have been proposed. None of them is implemented in standard econometrics packages. 8 | 9 | ## Installation 10 | 11 | > **Warning:** To run skillmodels you need to install jax and jaxlib. At the time of 12 | > writing, in most use cases, it is faster on a CPU than on a GPU, so it should be 13 | > sufficient to install the CPU version, which is available on all platforms. In any 14 | > case, for the installation of jax and jaxlib, please consult the jax 15 | > [docs](https://jax.readthedocs.io/en/latest/installation.html#supported-platforms). 16 | 17 | Skillmodels can be installed via PyPI or via GitHub. To do so, type the following in a 18 | terminal: 19 | 20 | ```console 21 | $ pip install skillmodels 22 | ``` 23 | 24 | or, for the latest development version, type: 25 | 26 | ```console 27 | $ pip install git+https://github.com/OpenSourceEconomics/skillmodels.git 28 | ``` 29 | 30 | ## Documentation 31 | 32 | [The documentation is hosted at Read the Docs](https://skillmodels.readthedocs.io/en/latest/) 33 | 34 | ## Developing 35 | 36 | We use [pixi](https://pixi.sh/latest/) for our local development environment. If you 37 | want to work with or extend the skillmodels code base, you can run the tests using 38 | 39 | ```console 40 | $ git clone https://github.com/OpenSourceEconomics/skillmodels.git 41 | $ pixi run tests 42 | ``` 43 | 44 | This will install the development environment and run the tests. You can run 45 | [mypy](https://mypy-lang.org/) using 46 | 47 | ```console 48 | $ pixi run mypy 49 | ``` 50 | 51 | Before committing, install the pre-commit hooks using 52 | 53 | ```console 54 | $ pre-commit install 55 | ``` 56 | 57 | ### Documentation 58 | 59 | You can build the documentation locally. After cloning the repository, cd to the 60 | docs directory and type: 61 | 62 | ```console 63 | $ make html 64 | ``` 65 | 66 | ## Citation 67 | 68 | It took countless hours to write skillmodels. I make it available under a very 69 | permissive license in the hope that it helps other people to do great research that 70 | advances our knowledge about the formation of cognitive and noncognitive skills. If you 71 | find skillmodels helpful, please don't forget to cite it. Below you can find the BibTeX 72 | entry for a suggested citation. The suggested citation will be updated once the code 73 | becomes part of a published paper. 74 | 75 | ``` 76 | @Unpublished{Gabler2024, 77 | Title = {A Python Library to Estimate Nonlinear Dynamic Latent Factor Models}, 78 | Author = {Janos Gabler}, 79 | Year = {2024}, 80 | Url = {https://github.com/OpenSourceEconomics/skillmodels} 81 | } 82 | ``` 83 | 84 | ## Feedback 85 | 86 | If you find skillmodels helpful for research or teaching, please let me know.
If you 87 | encounter any problems with the installation or while using skillmodels, please complain 88 | or open an issue at [GitHub](https://github.com/OpenSourceEconomics/skillmodels) 89 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | --- 2 | codecov: 3 | notify: 4 | require_ci_to_pass: true 5 | coverage: 6 | precision: 2 7 | round: down 8 | range: 50...100 9 | status: 10 | patch: 11 | default: 12 | target: 80% 13 | project: 14 | default: 15 | target: 87.5% 16 | ignore: 17 | - .tox/**/* 18 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 21 | 22 | .PHONY: help 23 | help: 24 | @echo "Please use \`make ' where is one of" 25 | @echo " html to make standalone HTML files" 26 | @echo " dirhtml to make HTML files named index.html in directories" 27 | @echo " singlehtml to make a single large HTML file" 28 | @echo " pickle to make pickle files" 29 | @echo " json to make JSON files" 30 | @echo " htmlhelp to make HTML files and a HTML help project" 31 | @echo " qthelp to make HTML files and a qthelp project" 32 | @echo " applehelp to make an Apple Help Book" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | @echo " coverage to run coverage check of the documentation (if enabled)" 49 | 50 | .PHONY: clean 51 | clean: 52 | rm -rf $(BUILDDIR)/* 53 | 54 | .PHONY: html 55 | html: 56 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 57 | @echo 58 | @echo "Build 
finished. The HTML pages are in $(BUILDDIR)/html." 59 | 60 | .PHONY: dirhtml 61 | dirhtml: 62 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 63 | @echo 64 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 65 | 66 | .PHONY: singlehtml 67 | singlehtml: 68 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 69 | @echo 70 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 71 | 72 | .PHONY: pickle 73 | pickle: 74 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 75 | @echo 76 | @echo "Build finished; now you can process the pickle files." 77 | 78 | .PHONY: json 79 | json: 80 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 81 | @echo 82 | @echo "Build finished; now you can process the JSON files." 83 | 84 | .PHONY: htmlhelp 85 | htmlhelp: 86 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 87 | @echo 88 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 89 | ".hhp project file in $(BUILDDIR)/htmlhelp." 90 | 91 | .PHONY: qthelp 92 | qthelp: 93 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 94 | @echo 95 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 96 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 97 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/skillmodels.qhcp" 98 | @echo "To view the help file:" 99 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/skillmodels.qhc" 100 | 101 | .PHONY: applehelp 102 | applehelp: 103 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 104 | @echo 105 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 106 | @echo "N.B. You won't be able to view it unless you put it in" \ 107 | "~/Library/Documentation/Help or install it in your application" \ 108 | "bundle." 109 | 110 | .PHONY: devhelp 111 | devhelp: 112 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 113 | @echo 114 | @echo "Build finished." 115 | @echo "To view the help file:" 116 | @echo "# mkdir -p $$HOME/.local/share/devhelp/skillmodels" 117 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/skillmodels" 118 | @echo "# devhelp" 119 | 120 | .PHONY: epub 121 | epub: 122 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 123 | @echo 124 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 125 | 126 | .PHONY: latex 127 | latex: 128 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 129 | @echo 130 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 131 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 132 | "(use \`make latexpdf' here to do that automatically)." 133 | 134 | .PHONY: latexpdf 135 | latexpdf: 136 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 137 | @echo "Running LaTeX files through pdflatex..." 138 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 139 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 140 | 141 | .PHONY: latexpdfja 142 | latexpdfja: 143 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 144 | @echo "Running LaTeX files through platex and dvipdfmx..." 145 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 146 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 147 | 148 | .PHONY: text 149 | text: 150 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 151 | @echo 152 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 
153 | 154 | .PHONY: man 155 | man: 156 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 157 | @echo 158 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 159 | 160 | .PHONY: texinfo 161 | texinfo: 162 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 163 | @echo 164 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 165 | @echo "Run \`make' in that directory to run these through makeinfo" \ 166 | "(use \`make info' here to do that automatically)." 167 | 168 | .PHONY: info 169 | info: 170 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 171 | @echo "Running Texinfo files through makeinfo..." 172 | make -C $(BUILDDIR)/texinfo info 173 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 174 | 175 | .PHONY: gettext 176 | gettext: 177 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 178 | @echo 179 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 180 | 181 | .PHONY: changes 182 | changes: 183 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 184 | @echo 185 | @echo "The overview file is in $(BUILDDIR)/changes." 186 | 187 | .PHONY: linkcheck 188 | linkcheck: 189 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 190 | @echo 191 | @echo "Link check complete; look for any errors in the above output " \ 192 | "or in $(BUILDDIR)/linkcheck/output.txt." 193 | 194 | .PHONY: doctest 195 | doctest: 196 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 197 | @echo "Testing of doctests in the sources finished, look at the " \ 198 | "results in $(BUILDDIR)/doctest/output.txt." 199 | 200 | .PHONY: coverage 201 | coverage: 202 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 203 | @echo "Testing of coverage in the sources finished, look at the " \ 204 | "results in $(BUILDDIR)/coverage/python.txt." 205 | 206 | .PHONY: xml 207 | xml: 208 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 209 | @echo 210 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 211 | 212 | .PHONY: pseudoxml 213 | pseudoxml: 214 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 215 | @echo 216 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 217 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% source 10 | set I18NSPHINXOPTS=%SPHINXOPTS% source 11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^` where ^ is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. 
latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. xml to make Docutils-native XML files 37 | echo. pseudoxml to make pseudoxml-XML files for display purposes 38 | echo. linkcheck to check all external links for integrity 39 | echo. doctest to run all doctests embedded in the documentation if enabled 40 | echo. coverage to run coverage check of the documentation if enabled 41 | goto end 42 | ) 43 | 44 | if "%1" == "clean" ( 45 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 46 | del /q /s %BUILDDIR%\* 47 | goto end 48 | ) 49 | 50 | 51 | REM Check if sphinx-build is available and fallback to Python version if any 52 | %SPHINXBUILD% 1>NUL 2>NUL 53 | if errorlevel 9009 goto sphinx_python 54 | goto sphinx_ok 55 | 56 | :sphinx_python 57 | 58 | set SPHINXBUILD=python -m sphinx.__init__ 59 | %SPHINXBUILD% 2> nul 60 | if errorlevel 9009 ( 61 | echo. 62 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 63 | echo.installed, then set the SPHINXBUILD environment variable to point 64 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 65 | echo.may add the Sphinx directory to PATH. 66 | echo. 67 | echo.If you don't have Sphinx installed, grab it from 68 | echo.http://sphinx-doc.org/ 69 | exit /b 1 70 | ) 71 | 72 | :sphinx_ok 73 | 74 | 75 | if "%1" == "html" ( 76 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 77 | if errorlevel 1 exit /b 1 78 | echo. 79 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 80 | goto end 81 | ) 82 | 83 | if "%1" == "dirhtml" ( 84 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 85 | if errorlevel 1 exit /b 1 86 | echo. 87 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 88 | goto end 89 | ) 90 | 91 | if "%1" == "singlehtml" ( 92 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 93 | if errorlevel 1 exit /b 1 94 | echo. 95 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 96 | goto end 97 | ) 98 | 99 | if "%1" == "pickle" ( 100 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 101 | if errorlevel 1 exit /b 1 102 | echo. 103 | echo.Build finished; now you can process the pickle files. 104 | goto end 105 | ) 106 | 107 | if "%1" == "json" ( 108 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 109 | if errorlevel 1 exit /b 1 110 | echo. 111 | echo.Build finished; now you can process the JSON files. 112 | goto end 113 | ) 114 | 115 | if "%1" == "htmlhelp" ( 116 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 117 | if errorlevel 1 exit /b 1 118 | echo. 119 | echo.Build finished; now you can run HTML Help Workshop with the ^ 120 | .hhp project file in %BUILDDIR%/htmlhelp. 121 | goto end 122 | ) 123 | 124 | if "%1" == "qthelp" ( 125 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 126 | if errorlevel 1 exit /b 1 127 | echo. 
128 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 129 | .qhcp project file in %BUILDDIR%/qthelp, like this: 130 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\skillmodels.qhcp 131 | echo.To view the help file: 132 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\skillmodels.ghc 133 | goto end 134 | ) 135 | 136 | if "%1" == "devhelp" ( 137 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 138 | if errorlevel 1 exit /b 1 139 | echo. 140 | echo.Build finished. 141 | goto end 142 | ) 143 | 144 | if "%1" == "epub" ( 145 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 146 | if errorlevel 1 exit /b 1 147 | echo. 148 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 149 | goto end 150 | ) 151 | 152 | if "%1" == "latex" ( 153 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 154 | if errorlevel 1 exit /b 1 155 | echo. 156 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 157 | goto end 158 | ) 159 | 160 | if "%1" == "latexpdf" ( 161 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 162 | cd %BUILDDIR%/latex 163 | make all-pdf 164 | cd %~dp0 165 | echo. 166 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 167 | goto end 168 | ) 169 | 170 | if "%1" == "latexpdfja" ( 171 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 172 | cd %BUILDDIR%/latex 173 | make all-pdf-ja 174 | cd %~dp0 175 | echo. 176 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 177 | goto end 178 | ) 179 | 180 | if "%1" == "text" ( 181 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 182 | if errorlevel 1 exit /b 1 183 | echo. 184 | echo.Build finished. The text files are in %BUILDDIR%/text. 185 | goto end 186 | ) 187 | 188 | if "%1" == "man" ( 189 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 190 | if errorlevel 1 exit /b 1 191 | echo. 192 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 193 | goto end 194 | ) 195 | 196 | if "%1" == "texinfo" ( 197 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 198 | if errorlevel 1 exit /b 1 199 | echo. 200 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 201 | goto end 202 | ) 203 | 204 | if "%1" == "gettext" ( 205 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 206 | if errorlevel 1 exit /b 1 207 | echo. 208 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 209 | goto end 210 | ) 211 | 212 | if "%1" == "changes" ( 213 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 214 | if errorlevel 1 exit /b 1 215 | echo. 216 | echo.The overview file is in %BUILDDIR%/changes. 217 | goto end 218 | ) 219 | 220 | if "%1" == "linkcheck" ( 221 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 222 | if errorlevel 1 exit /b 1 223 | echo. 224 | echo.Link check complete; look for any errors in the above output ^ 225 | or in %BUILDDIR%/linkcheck/output.txt. 226 | goto end 227 | ) 228 | 229 | if "%1" == "doctest" ( 230 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 231 | if errorlevel 1 exit /b 1 232 | echo. 233 | echo.Testing of doctests in the sources finished, look at the ^ 234 | results in %BUILDDIR%/doctest/output.txt. 235 | goto end 236 | ) 237 | 238 | if "%1" == "coverage" ( 239 | %SPHINXBUILD% -b coverage %ALLSPHINXOPTS% %BUILDDIR%/coverage 240 | if errorlevel 1 exit /b 1 241 | echo. 242 | echo.Testing of coverage in the sources finished, look at the ^ 243 | results in %BUILDDIR%/coverage/python.txt. 
244 | goto end 245 | ) 246 | 247 | if "%1" == "xml" ( 248 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 249 | if errorlevel 1 exit /b 1 250 | echo. 251 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 252 | goto end 253 | ) 254 | 255 | if "%1" == "pseudoxml" ( 256 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 257 | if errorlevel 1 exit /b 1 258 | echo. 259 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 260 | goto end 261 | ) 262 | 263 | :end 264 | -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Remove execution count for notebook cells. */ 2 | div.prompt { 3 | display: none; 4 | } 5 | 6 | /* Getting started index page */ 7 | 8 | .intro-card { 9 | background: #fff; 10 | border-radius: 0; 11 | padding: 30px 10px 10px 10px; 12 | margin: 10px 0px; 13 | max-height: 85%; 14 | } 15 | 16 | .intro-card .card-text { 17 | margin: 20px 0px; 18 | } 19 | 20 | div#index-container { 21 | padding-bottom: 20px; 22 | } 23 | 24 | a#index-link { 25 | color: #333; 26 | text-decoration: none; 27 | } 28 | 29 | /* reference to user guide */ 30 | .gs-torefguide { 31 | align-items: center; 32 | font-size: 0.9rem; 33 | } 34 | 35 | .gs-torefguide .badge { 36 | background-color: #130654; 37 | margin: 10px 10px 10px 0px; 38 | padding: 5px; 39 | } 40 | 41 | .gs-torefguide a { 42 | margin-left: 5px; 43 | color: #130654; 44 | border-bottom: 1px solid #FFCA00f3; 45 | box-shadow: 0px -10px 0px #FFCA00f3 inset; 46 | } 47 | 48 | .gs-torefguide p { 49 | margin-top: 1rem; 50 | } 51 | 52 | .gs-torefguide a:hover { 53 | margin-left: 5px; 54 | color: grey; 55 | text-decoration: none; 56 | border-bottom: 1px solid #b2ff80f3; 57 | box-shadow: 0px -10px 0px #b2ff80f3 inset; 58 | } 59 | -------------------------------------------------------------------------------- /docs/source/_static/images/book.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /docs/source/_static/images/books.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/source/_static/images/coding.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | 9 | 10 | 11 | 12 | 13 | 15 | 16 | 17 | 18 | 19 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /docs/source/_static/images/light-bulb.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 11 | 13 | 15 | 17 | 19 | 21 | 23 | 25 | 27 | 29 | 30 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/source/_static/images/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 17 | 19 | 25 | 26 | 45 | 47 | 48 | 50 | image/svg+xml 51 | 53 | 54 | 55 | 56 | 57 | 62 | skillmodels 69 | 70 | 71 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # 2 | # Documentation build configuration file, created by sphinx-quickstart 3 | # 4 | # This file is 
execfile()d with the current directory set to its containing dir. 5 | # 6 | # Note that not all possible configuration values are present in this 7 | # autogenerated file. 8 | # 9 | # All configuration values have a default; values that are commented out 10 | # serve to show the default. 11 | import os 12 | import sys 13 | 14 | # If extensions (or modules to document with autodoc) are in another directory, 15 | # add these directories to sys.path here. If the directory is relative to the 16 | # documentation root, use os.path.abspath to make it absolute, like shown here. 17 | sys.path.insert(0, os.path.abspath("../..")) 18 | 19 | 20 | # -- General configuration ---------------------------------------------------- 21 | 22 | # If your documentation needs a minimal Sphinx, state it here. 23 | needs_sphinx = "1.6" 24 | 25 | # Add any Sphinx extension module names here, as strings. 26 | # They can be extensions coming with Sphinx (named "sphinx.ext.*") 27 | # or your custom ones. 28 | extensions = [ 29 | "sphinx.ext.autodoc", 30 | "sphinx.ext.viewcode", 31 | "sphinx.ext.mathjax", 32 | "sphinx.ext.napoleon", 33 | "sphinx.ext.todo", 34 | "nbsphinx", 35 | ] 36 | 37 | # Mock imports. 38 | autodoc_mock_imports = [ 39 | "optimagic", 40 | "matplotlib", 41 | "jax", 42 | "numpy", 43 | "pandas", 44 | "scipy", 45 | "filterpy", 46 | "dags", 47 | "plotly", 48 | ] 49 | 50 | # Add any paths that contain templates here, relative to this directory. 51 | templates_path = ["_templates"] 52 | 53 | # The suffix of source filenames. 54 | source_suffix = ".rst" 55 | 56 | # The encoding of source files. 57 | source_encoding = "utf-8" 58 | 59 | # The master toctree document. 60 | master_doc = "index" 61 | 62 | # General information about the project. 63 | project = "skillmodels" 64 | copyright = "2016-2021, Janos Gabler" 65 | 66 | # The version info for the project you"re documenting, acts as replacement for 67 | # |version| and |release|, also used in various other places throughout the 68 | # built documents. 69 | # 70 | # The short X.Y version. 71 | version = "0.2" 72 | # The full version, including alpha/beta/rc tags. 73 | release = "0.2.2" 74 | 75 | # The language for content autogenerated by Sphinx. Refer to documentation 76 | # for a list of supported languages. 77 | # language = None 78 | 79 | # There are two options for replacing |today|: either, you set today to some 80 | # non-false value, then it is used: 81 | # today = "" 82 | # Else, today_fmt is used as the format for a strftime call. 83 | today_fmt = "%d %B %Y" 84 | 85 | # List of patterns, relative to source directory, that match files and 86 | # directories to ignore when looking for source files. 87 | exclude_patterns = [] 88 | 89 | # The reST default role (used for this markup: `text`) to use for all documents. 90 | # default_role = None 91 | 92 | # If true, "()" will be appended to :func: etc. cross-reference text. 93 | add_function_parentheses = True 94 | 95 | # If true, the current module name will be prepended to all description 96 | # unit titles (such as .. function::). 97 | add_module_names = False 98 | 99 | # If true, sectionauthor and moduleauthor directives will be shown in the 100 | # output. They are ignored by default. 101 | show_authors = False 102 | 103 | # The name of the Pygments (syntax highlighting) style to use. 104 | pygments_style = "sphinx" 105 | 106 | # A list of ignored prefixes for module index sorting. 
107 | modindex_common_prefix = ["src."] 108 | 109 | 110 | # -- Options for HTML output -------------------------------------------------- 111 | 112 | # The theme to use for HTML and HTML Help pages. See the documentation for 113 | # a list of builtin themes. 114 | html_theme = "pydata_sphinx_theme" 115 | 116 | html_logo = "_static/images/logo.svg" 117 | 118 | html_theme_options = { 119 | "github_url": "https://github.com/OpenSourceEconomics/skillmodels" 120 | } 121 | 122 | html_css_files = ["css/custom.css"] 123 | 124 | html_sidebars = { 125 | "**": [ 126 | "relations.html", # needs 'show_related': True theme option to display 127 | "searchbox.html", 128 | ], 129 | } 130 | 131 | templates_path = ["_templates"] 132 | html_static_path = ["_static"] 133 | 134 | 135 | html_show_copyright = True 136 | 137 | # If true, an OpenSearch description file will be output, and all pages will 138 | # contain a tag referring to it. The value of this option must be the 139 | # base URL from which the finished HTML is served. 140 | # html_use_opensearch = "" 141 | 142 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 143 | html_file_suffix = ".html" 144 | 145 | # Output file base name for HTML help builder. 146 | htmlhelp_basename = "somedoc" 147 | 148 | # Other settings 149 | 150 | autodoc_member_order = "bysource" 151 | napoleon_use_rtype = False 152 | napoleon_include_private_with_doc = False 153 | todo_include_todos = True 154 | -------------------------------------------------------------------------------- /docs/source/explanations/index.rst: -------------------------------------------------------------------------------- 1 | Explanations 2 | ============ 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 1 7 | 8 | names_and_concepts 9 | notes_on_factor_scales 10 | -------------------------------------------------------------------------------- /docs/source/explanations/names_and_concepts.rst: -------------------------------------------------------------------------------- 1 | .. _names_and_concepts: 2 | 3 | 4 | ================== 5 | Names and concepts 6 | ================== 7 | 8 | This section contains an overview of frequently used variable names and 9 | concepts. It's not necessary to read this section if you are only interested in 10 | using the code, but you might want to skim it if you are interested in what the 11 | code actually does or plan to adapt it to your use case. 12 | 13 | Most of those quantities are generated once during the :ref:`model_processing` 14 | and appear as arguments of many other functions. 15 | 16 | .. _dimensions: 17 | 18 | ``dimensions`` 19 | ============== 20 | 21 | Dimensions of the model quantities. All of them are integers. 22 | 23 | - n_states: Number of latent factors or states in the model. Note that the terms 24 | state and factor are used interchangeably throughout the documentation. 25 | - n_periods: Number of periods of the model. There is one more period than 26 | transition equations of the model. 27 | - n_mixtures: Number of elements in the finite mixture of normals distribution. 28 | - n_controls: Number of control variables in the measurement equations. This 29 | includes the intercept of the measurement equation. Thus n_controls is always 30 | 1 or larger. 31 | 32 | 33 | .. _labels: 34 | 35 | ``labels`` 36 | ========== 37 | 38 | Labels for the model quantities. All of them are lists. 39 | 40 | 41 | - factors: Names of the latent factors. 42 | - controls: Names of the control variables. The first entry is always "constant". 
43 | - periods: List of integers, starting at zero. The indices of the periods. 44 | - stagemap: Maps periods to stages. Has one entry fewer than the number of periods. 45 | - stages: The indices of the stages of the model. 46 | 47 | 48 | .. _stages_vs_periods: 49 | 50 | 51 | Development-Stages vs Periods 52 | ============================= 53 | 54 | A development stage is a group of consecutive periods for which the technology of skill 55 | formation remains the same. The number of stages is therefore always <= the number of 56 | periods of a model. 57 | 58 | Development stages are thus just equality constraints on the estimated parameter 59 | vector. Because they are very frequently used, skillmodels can generate the 60 | constraints automatically if you specify a stagemap in your model dictionary. 61 | 62 | 63 | Example: If you have a model with 5 periods, you can estimate at most 4 different 64 | production functions (one for each transition between periods). If you want to 65 | keep the parameters of the technology of skill formation constant between two 66 | consecutive periods, you would specify the following stagemap: ``[0, 0, 1, 1]``
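As a purely illustrative sketch, such a stagemap could be specified as follows (all
other keys of the model dictionary are omitted here; see the how-to guide on model
specifications for the actual format)::

    model_dict = {
        # ... factors, measurements, normalizations, etc. ...
        # One entry per transition: with 5 periods there are 4 transitions.
        # Transitions 0->1 and 1->2 use the stage-0 technology, transitions
        # 2->3 and 3->4 the stage-1 technology.
        "stagemap": [0, 0, 1, 1],
    }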
67 | 68 | 69 | .. _anchoring: 70 | 71 | ``anchoring`` 72 | ============= 73 | 74 | 75 | 76 | 77 | .. _update_info: 78 | 79 | 80 | ``update_info`` 81 | =============== 82 | 83 | 84 | 85 | .. _normalizations: 86 | 87 | ``normalizations`` 88 | ================== 89 | 90 | 91 | .. _estimation_options: 92 | 93 | 94 | ``estimation_options`` 95 | ====================== 96 | -------------------------------------------------------------------------------- /docs/source/explanations/notes_on_factor_scales.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | Notes on Scales and Normalizations 3 | ********************************** 4 | 5 | Here I collect notes on different aspects of the discussion about factor 6 | scales and re-normalization. This discussion originates in the `critique`_ by 7 | Wiswall and Agostinelli, but I argue below that this critique is not yet 8 | complete. 9 | 10 | Wiswall and Agostinelli define a class of transition functions with Known 11 | Location and Scale (KLS) that require fewer normalizations. You should read 12 | this definition in their paper. 13 | 14 | The critique by Wiswall and Agostinelli potentially invalidates the empirical 15 | estimates of CHS, but not their general estimation routine. To get estimates 16 | that don't suffer from renormalization, you can either use fewer normalizations 17 | or non-KLS transition functions. As there is no natural scale of skills, none 18 | of the approaches is better or worse. Nevertheless, I prefer using flexible 19 | non-KLS transition functions with one normalization per period and factor. 20 | Firstly, because they are more compatible with using development stages that 21 | span more than one period. Secondly, because picking suitable normalizations 22 | might help to give the latent factors a more meaningful scale. 23 | 24 | 25 | .. _KLS_not_constant: 26 | 27 | Why KLS functions don't keep the scales constant 28 | ************************************************ 29 | 30 | Skills have no natural scale, but after reading the critique paper by Wiswall 31 | and Agostinelli one could easily get the impression that using KLS transition 32 | functions and fewer normalizations is better because it identifies some sort 33 | of natural scale. Moreover, in their `estimation`_ paper (p. 7), they write: 34 | "We argue that our limited normalization is appropriate for the dynamic 35 | setting of child development we analyze. With our normalization for the 36 | initial period only, latent skills in all periods share a common location 37 | and scale with respect to the one chosen normalizing measure." 38 | 39 | The following example intuitively shows, firstly, that the scale identified with 40 | KLS functions is as arbitrary as a scale identified through normalizations and, 41 | secondly, that this scale is in general not constant over time. 42 | 43 | The example completely abstracts from measurement and estimation problems and 44 | thereby allows us to focus on the essential aspects of the problem. 45 | 46 | Consider a simple model of financial investments with two latent factors: a 47 | stock variable wealth (w) and a flow variable investment (i). Suppose periods 48 | last one year and the annual interest rate on wealth is 10 percent. New 49 | investments are deposited at the end of the year (they earn interest only in the 50 | next year). 51 | 52 | The most intuitive scales to describe the system would be to measure all 53 | latent factors in all periods in the same currency, say Dollars. In this case 54 | the transition equation of wealth is given by: 55 | 56 | .. math:: 57 | 58 | w_{t + 1} = 1.1 w_t + i_t 59 | 60 | However, it would also be possible to measure w in period t in Dollars, i in 61 | period t in 1000 Dollars and w in period t + 1 in Dollar cents. The transition 62 | equation -- which still describes exactly the same system -- is then: 63 | 64 | .. math:: 65 | 66 | w_{t + 1} = 110 w_t + 100000 i_t 67 | 68 | The parameters now reflect both the actual technology and the scale changes between 69 | periods. They are much harder to interpret than before. In fact, any linear 70 | function 71 | 72 | .. math:: 73 | 74 | f: \mathbb{R}^2 \rightarrow \mathbb{R} 75 | 76 | could describe the example system -- just in different combinations of scales. 77 | 78 | When latent factor models are estimated, the scales of each factor are usually 79 | set through normalizations in each period. The main point of the first paper 80 | by Wiswall and Agostinelli is that a KLS transition function rules out such 81 | normalizations except for the initial period. One could say that after 82 | that, the transition function chooses the scale. 83 | 84 | The CES function has KLS and contains all linear functions 85 | without intercept whose parameters sum to 1 as special cases. It can therefore 86 | be used to describe the example system. After setting the scale of both 87 | factors to Dollars in the initial period, the CES function would then choose 88 | the scales for all other periods. 89 | 90 | The linear function that is a CES function and describes the system is: 91 | 92 | .. math:: 93 | w_{t + 1} = \frac{1}{2.1} (1.1 w_t + i_t) \approx 0.524 w_t + 0.476 i_t 94 | 95 | The scale of w in period t + 1 chosen by this function is thus 1 / 2.1, i.e. one 96 | unit of w in period t + 1 corresponds to about 2.1 Dollars, which means that 97 | wealth in period t + 1 is approximately measured in units of 100 Philippine Pesos. 98 | 99 |
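The arithmetic can be checked with a few lines of plain Python (the numbers are
purely illustrative; this is not part of skillmodels)::

    # Wealth and investment in period t, both measured in Dollars.
    w_t, i_t = 100.0, 50.0

    # All factors measured in Dollars in all periods:
    w_next = 1.1 * w_t + i_t  # 160.0 Dollars

    # w_t in Dollars, i_t in 1000 Dollars, w_{t+1} in Dollar cents:
    assert abs(110 * w_t + 100000 * (i_t / 1000) - w_next * 100) < 1e-9

    # Linear special case of the CES function (weights sum to 1); the result
    # is the same wealth, measured in units of 2.1 Dollars:
    w_next_ces = (1.1 * w_t + i_t) / 2.1
    assert abs(w_next_ces * 2.1 - w_next) < 1e-9

All three parameterizations describe the same system; only the units in which the
latent factors are measured differ.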
100 | .. _log_ces_problem: 101 | 102 | Why the CES and log_CES functions are problematic 103 | ************************************************* 104 | 105 | The definition of Known Location and Scale refers only to the scale of the 106 | (always one-dimensional) output of a transition function. After reading the 107 | Wiswall and Agostinelli critique, I wondered whether the CES and log_CES functions 108 | also pose restrictions on the scales of their inputs, i.e. whether they can describe 109 | a system only at a certain location or scale of inputs. 110 | 111 | According to Wiswall and Agostinelli, when using a log_CES function (which 112 | belongs to the KLS class), one needs initial normalizations of location and 113 | scale for all factors in the model. I made some pen-and-paper calculations and 114 | estimated models with simulated data, and the results suggest that fewer 115 | normalizations are needed with the log_CES function. 116 | 117 | While one does need to make initial normalizations for the location of all 118 | factors, it is sufficient to normalize the scale of only one factor in the 119 | initial period and the model is still identified. However, these are only 120 | simulations, and I do not have a formal result that shows that the restrictions 121 | the log_CES function poses on the scale of its inputs are always enough for 122 | identification. 123 | 124 | I would therefore currently advise against using the CES or log_CES function 125 | without thinking deeply about the normalizations you need. The automatic 126 | generation of normalizations treats the log_ces function simply as a KLS 127 | function. 128 | 129 | 130 | .. _normalization_and_stages: 131 | 132 | Normalizations and Development Stages 133 | ************************************* 134 | 135 | CHS use development stages, i.e. several periods of childhood in which the 136 | parameters of the technology of skill formation remain the same. Wiswall and 137 | Agostinelli do not use or analyze this case, but development stages do change 138 | the normalization requirements. 139 | 140 | I always had the intuition that with development stages it is possible to 141 | identify a scale from the first period of a stage, such that no later 142 | normalizations are necessary until the next stage. When extending the WA 143 | estimator to be compatible with development stages, I could confirm this 144 | intuition, as one nice feature of this estimator is that its identification 145 | strategy has to be made very explicit. 146 | 147 | If development stages are used, one only has to make normalizations in the first 148 | period of each stage, except for the initial stage, where the first two periods 149 | have to be normalized. My recommendation is to use automatic normalizations if 150 | you use development stages, because it is very easy to get confused. 151 | 152 | This shows another type of over-normalization in the original CHS paper. 153 | 154 | .. _critique: 155 | https://tinyurl.com/y3wl43kz 156 | 157 | .. _estimation: 158 | https://tinyurl.com/y5ezloh2 159 | -------------------------------------------------------------------------------- /docs/source/getting_started/index.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | ..
toctree:: 5 | :maxdepth: 1 6 | 7 | tutorial.ipynb 8 | -------------------------------------------------------------------------------- /docs/source/how_to_guides/how_to_simulate_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import yaml\n", 11 | "\n", 12 | "from skillmodels.config import TEST_DIR\n", 13 | "from skillmodels.simulate_data import simulate_dataset" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "# How to simulate dataset\n", 21 | "\n", 22 | "\n", 23 | "\n", 24 | "Below we show how to simulate dataset for a test model. \n", 25 | "\n", 26 | "## Getting inputs\n", 27 | "\n", 28 | "For more details on this check out the introductory tutorial. " 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "with open(TEST_DIR / \"model2.yaml\") as y:\n", 38 | " model_dict = yaml.load(y, Loader=yaml.FullLoader)\n", 39 | "\n", 40 | "data = pd.read_stata(TEST_DIR / \"model2_simulated_data.dta\")\n", 41 | "data = data.set_index([\"caseid\", \"period\"])\n", 42 | "\n", 43 | "params = pd.read_csv(TEST_DIR / \"regression_vault\" / \"one_stage_anchoring.csv\")\n", 44 | "params = params.set_index([\"category\", \"period\", \"name1\", \"name2\"])" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "## Simulated data without policy" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "initial_data = simulate_dataset(\n", 61 | " model_dict=model_dict,\n", 62 | " params=params,\n", 63 | " data=data,\n", 64 | ")\n", 65 | "initial_data[\"anchored_states\"][\"states\"]" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "## Why do I need data to simulate data?\n", 73 | "\n", 74 | "The data you pass to simulate_data contains information on observed factors and control variables. Those are not part of the latent factor model and a standard model specification does not have enough information to generate them. \n", 75 | "\n", 76 | "If you have a model without control variables and observed factors, you can simply pass `n_obs` instead of `data`." 
77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "## Simulated data with policy" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "policies = [\n", 93 | " {\"period\": 0, \"factor\": \"fac1\", \"effect_size\": 0.2, \"standard_deviation\": 0.0},\n", 94 | " {\"period\": 1, \"factor\": \"fac2\", \"effect_size\": 0.1, \"standard_deviation\": 0.0},\n", 95 | "]" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "data_after_policies = simulate_dataset(\n", 105 | " model_dict=model_dict,\n", 106 | " params=params,\n", 107 | " data=data,\n", 108 | ")\n", 109 | "data_after_policies[\"anchored_states\"][\"states\"]" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [] 118 | } 119 | ], 120 | "metadata": { 121 | "language_info": { 122 | "codemirror_mode": { 123 | "name": "ipython", 124 | "version": 3 125 | }, 126 | "file_extension": ".py", 127 | "mimetype": "text/x-python", 128 | "name": "python", 129 | "nbconvert_exporter": "python", 130 | "pygments_lexer": "ipython3" 131 | } 132 | }, 133 | "nbformat": 4, 134 | "nbformat_minor": 4 135 | } 136 | -------------------------------------------------------------------------------- /docs/source/how_to_guides/how_to_visualize_correlations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Visualizing correlations in a skill formation model" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pandas as pd\n", 17 | "import yaml\n", 18 | "\n", 19 | "from skillmodels.config import TEST_DIR\n", 20 | "from skillmodels.correlation_heatmap import (\n", 21 | " get_measurements_corr,\n", 22 | " get_quasi_scores_corr,\n", 23 | " get_scores_corr,\n", 24 | " plot_correlation_heatmap,\n", 25 | ")\n", 26 | "\n", 27 | "%load_ext nb_black" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "## Loading inputs" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "with open(TEST_DIR / \"model2.yaml\") as y:\n", 44 | " model_dict = yaml.load(y, Loader=yaml.FullLoader)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "params = pd.read_csv(TEST_DIR / \"regression_vault\" / \"one_stage_anchoring.csv\")\n", 54 | "params = params.set_index([\"category\", \"period\", \"name1\", \"name2\"])\n", 55 | "\n", 56 | "data = pd.read_stata(TEST_DIR / \"model2_simulated_data.dta\")\n", 57 | "data = data.set_index([\"caseid\", \"period\"])" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "## Plotting correlations of measurements" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "corr_meas = get_measurements_corr(\n", 74 | " periods=0, data=data, model_dict=model_dict, factors=[\"fac1\", \"fac2\"]\n", 75 | ")" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | 
"metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "fig_meas = plot_correlation_heatmap(\n", 85 | " corr_meas,\n", 86 | ")" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "fig_meas.update_layout(title=\"Measurement correlations of fac1 and fac2 in period 0\")" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "## Plotting correlations of factor scores" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "corr_score = get_scores_corr(\n", 112 | " periods=None, params=params, data=data, model_dict=model_dict, factors=\"fac1\"\n", 113 | ")" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "fig_scores = plot_correlation_heatmap(\n", 123 | " corr_score,\n", 124 | ")" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "fig_scores.update_layout(title=\"Stability of fac1 over time\")" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "quasi_corr_score = get_quasi_scores_corr(\n", 143 | " periods=None, data=data, model_dict=model_dict, factors=\"fac1\"\n", 144 | ")" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "fig_quasi_scores = plot_correlation_heatmap(\n", 154 | " quasi_corr_score,\n", 155 | ")" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "fig_quasi_scores.update_layout(title=\"Stability of fac1 over time\")" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "from skillmodels.visualize_transition_equations import _get_pardict, _set_index_params" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "from skillmodels.process_model import process_model" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "_get_pardict(\n", 192 | " params=_set_index_params(process_model(model_dict), params),\n", 193 | " model=process_model(model_dict),\n", 194 | ")[\"loadings\"]" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "params" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [] 212 | } 213 | ], 214 | "metadata": { 215 | "language_info": { 216 | "codemirror_mode": { 217 | "name": "ipython", 218 | "version": 3 219 | }, 220 | "file_extension": ".py", 221 | "mimetype": "text/x-python", 222 | "name": "python", 223 | "nbconvert_exporter": "python", 224 | "pygments_lexer": "ipython3" 225 | } 226 | }, 227 | "nbformat": 4, 228 | "nbformat_minor": 4 229 | } 230 | -------------------------------------------------------------------------------- /docs/source/how_to_guides/index.rst: 
-------------------------------------------------------------------------------- 1 | How-To Guides 2 | ============= 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 1 7 | 8 | model_specs 9 | utilities 10 | how_to_visualize_transition_equations.ipynb 11 | how_to_simulate_dataset.ipynb 12 | how_to_visualize_pairwise_factor_distribution.ipynb 13 | how_to_visualize_correlations.ipynb 14 | -------------------------------------------------------------------------------- /docs/source/how_to_guides/utilities.rst: -------------------------------------------------------------------------------- 1 | How to modify model specifications 2 | ================================== 3 | 4 | 5 | ``skillmodels.utilities`` contains functions to construct a model dictionary by varying 6 | an existing one and to update the parameters of a larger model from estimated parameters 7 | from smaller models. 8 | 9 | All functions that modify model dictionaries can also modify a params DataFrame 10 | that was constructed for the original model accordingly. 11 | 12 | 13 | .. automodule:: skillmodels.utilities 14 | :members: 15 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to the documentation of skillmodels! 2 | ============================================ 3 | 4 | 5 | 6 | Structure of the Documentation 7 | ============================== 8 | 9 | 10 | .. raw:: html 11 | 12 | 78 | 79 | 80 | Welcome to skillmodels, a Python implementation of estimators for skill 81 | formation models. The econometrics of skill formation models is a very active 82 | field and several estimators have been proposed. None of them is implemented in 83 | standard econometrics packages. 84 | 85 | 86 | Skillmodels implements the Kalman filter based maximum likelihood estimator 87 | proposed by Cunha, Heckman and Schennach (CHS) (`Econometrica 2010`_). 88 | 89 | 90 | Skillmodels was developed for skill formation models but is by no means 91 | limited to this particular application. It can be applied to any dynamic 92 | nonlinear latent factor model. 93 | 94 | The CHS estimator implemented here differs in two points from the one 95 | implemented in their `replication files`_: 1) It uses different normalizations 96 | that take into account the `critique`_ of Wiswall and Agostinelli. 2) It can 97 | optionally use more robust square-root implementations of the Kalman filters. 98 | 99 | 100 | Most of the code is unit tested. Furthermore, the results have been compared 101 | to the Fortran code by CHS for two basic models with hypothetical data from 102 | their `replication files`_. 103 | 104 | 105 | **Citation** 106 | 107 | It took countless hours to write skillmodels. I make it available under a very 108 | permissive license in the hope that it helps other people to do great research 109 | that advances our knowledge about the formation of cognitive and noncognitive 110 | skills. If you find skillmodels helpful, please don't forget to cite it. You 111 | can find a suggested citation in the README file on `GitHub`_. 112 | 113 | 114 | **Feedback** 115 | 116 | If you find skillmodels helpful for research or teaching, please let me know. 117 | If you encounter any problems with the installation or while using 118 | skillmodels, please complain or open an issue at `GitHub`_. 119 | 120 | 121 | 122 | .. _critique: 123 | https://tinyurl.com/y3wl43kz 124 | 125 | ..
 _replication files: 126 | https://tinyurl.com/yyuq2sa4 127 | 128 | .. _GitHub: 129 | https://github.com/OpenSourceEconomics/skillmodels 130 | 131 | 132 | .. _Econometrica 2010: 133 | http://onlinelibrary.wiley.com/doi/10.3982/ECTA6551/abstract 134 | 135 | 136 | .. toctree:: 137 | :maxdepth: 1 138 | 139 | getting_started/index 140 | how_to_guides/index 141 | explanations/index 142 | reference_guides/index 143 | -------------------------------------------------------------------------------- /docs/source/reference_guides/endogeneity_corrections.rst: -------------------------------------------------------------------------------- 1 | A note on endogeneity correction methods: 2 | ***************************************** 3 | 4 | In the empirical part of their paper, CHS use two methods for endogeneity 5 | correction. Both require very strong assumptions on the scale of factors. 6 | Below I give an overview of the proposed endogeneity correction methods that 7 | can serve as a starting point for someone who wants to extend skillmodels in 8 | that direction: 9 | 10 | In section 4.2.4 CHS extend their basic model with a time invariant individual 11 | specific heterogeneity component, i.e. a fixed effect. The time invariance 12 | assumption can only be valid if the scale of all factors remains the same 13 | throughout the model. This is highly unlikely, unless age invariant 14 | measurements (as defined by Wiswall and Agostinelli) are available and used 15 | for normalization in all periods for all factors. With KLS transition 16 | functions the assumption of constant factor scales in all 17 | periods is highly unlikely (see: :ref:`KLS_not_constant`). Moreover, this 18 | approach requires three adult outcomes. If you have a dataset with enough time 19 | invariant measurements and enough adult outcomes, this method is suitable for 20 | you and you could use the Fortran code by CHS as a starting point. 21 | 22 | In section 4.2.5 they make an endogeneity correction with time varying heterogeneity. 23 | However, this heterogeneity follows the same AR1 process in each period and 24 | relies on an estimated time invariant investment equation, so it also requires 25 | the factor scales to be constant. This might not be a good assumption in many 26 | applications. Moreover, this correction method relies on an exclusion 27 | restriction (income is an argument of the investment function but not of the 28 | transition functions of other latent factors) or suitable functional form 29 | assumptions for identification. 30 | 31 | To use this correction method in models where not enough age invariant 32 | measurements are available to ensure constant factor scales, one would have to 33 | replace the AR1 process by a linear transition function with different 34 | estimated parameters in each period and also estimate a different investment 35 | function in each period. I don't know if this model is identified. 36 | 37 | I don't know if these methods could be used in the WA estimator. 38 | 39 | Wiswall and Agostinelli use a simpler model of endogeneity of investments that 40 | could be used with both estimators. See section 6.1.2 of their `paper`_. 41 | 42 | .. _paper: 43 | https://tinyurl.com/y5ezloh2 44 | 45 | 46 | .. 
_replication files: 47 | https://tinyurl.com/yyuq2sa4 48 | -------------------------------------------------------------------------------- /docs/source/reference_guides/estimation.rst: -------------------------------------------------------------------------------- 1 | ============================= 2 | Modules Related to Estimation 3 | ============================= 4 | 5 | .. _likelihood_function: 6 | 7 | The Likelihood Function 8 | ======================= 9 | 10 | .. automodule:: skillmodels.likelihood_function 11 | :members: 12 | 13 | .. _kalman_filters: 14 | 15 | The Kalman Filters 16 | ================== 17 | 18 | 19 | .. automodule:: skillmodels.kalman_filters 20 | :members: 21 | 22 | 23 | The Index of the Parameter DataFrame 24 | ==================================== 25 | 26 | 27 | .. _params_index: 28 | 29 | 30 | .. automodule:: skillmodels.params_index 31 | :members: 32 | 33 | 34 | 35 | .. _parse_params: 36 | 37 | Parsing the Parameter Vector 38 | ============================ 39 | 40 | 41 | .. automodule:: skillmodels.parse_params 42 | :members: 43 | -------------------------------------------------------------------------------- /docs/source/reference_guides/index.rst: -------------------------------------------------------------------------------- 1 | Reference Guides 2 | ================ 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 1 7 | 8 | pre_processing 9 | estimation 10 | simulation 11 | transition_functions 12 | endogeneity_corrections 13 | -------------------------------------------------------------------------------- /docs/source/reference_guides/pre_processing.rst: -------------------------------------------------------------------------------- 1 | ================================= 2 | How the User Inputs are Processed 3 | ================================= 4 | 5 | 6 | 7 | 8 | .. _model_processing: 9 | 10 | Model Processing 11 | ================ 12 | 13 | 14 | .. automodule:: skillmodels.process_model 15 | :members: 16 | 17 | 18 | 19 | .. _data_processing: 20 | 21 | Data Processing 22 | =============== 23 | 24 | 25 | .. automodule:: skillmodels.process_data 26 | :members: 27 | 28 | 29 | 30 | .. _model_checking: 31 | 32 | Model Checking 33 | ============== 34 | 35 | 36 | .. automodule:: skillmodels.check_model 37 | :members: 38 | -------------------------------------------------------------------------------- /docs/source/reference_guides/simulation.rst: -------------------------------------------------------------------------------- 1 | ============================= 2 | Modules Related to Simulation 3 | ============================= 4 | 5 | .. _simulate_data: 6 | 7 | 8 | Simulating a Dataset 9 | ==================== 10 | 11 | .. automodule:: skillmodels.simulate_data 12 | :members: 13 | -------------------------------------------------------------------------------- /docs/source/reference_guides/transition_functions.rst: -------------------------------------------------------------------------------- 1 | .. _transition_functions: 2 | 3 | Transition Equations 4 | ==================== 5 | 6 | 7 | .. 
automodule:: skillmodels.transition_functions 8 | :members: 9 | -------------------------------------------------------------------------------- /docs/source/rtd_environment.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: skillmodels_rtd 3 | channels: 4 | - conda-forge 5 | dependencies: 6 | - python=3.12 7 | - skillmodels 8 | - sphinxcontrib-bibtex 9 | - pydata-sphinx-theme>=0.3 10 | - sphinx 11 | - nbsphinx 12 | -------------------------------------------------------------------------------- /docs/source/start_params_template.csv: -------------------------------------------------------------------------------- 1 | category,period,name1,name2,value,lower_bound,upper_bound 2 | delta,0,y1,constant,,-inf,inf 3 | delta,0,y1,x1,,-inf,inf 4 | delta,0,y2,constant,,-inf,inf 5 | delta,0,y2,x1,,-inf,inf 6 | delta,0,y3,constant,,-inf,inf 7 | delta,0,y3,x1,,-inf,inf 8 | delta,0,y4,constant,,-inf,inf 9 | delta,0,y4,x1,,-inf,inf 10 | delta,0,y5,constant,,-inf,inf 11 | delta,0,y5,x1,,-inf,inf 12 | delta,0,y6,constant,,-inf,inf 13 | delta,0,y6,x1,,-inf,inf 14 | delta,0,y7,constant,,-inf,inf 15 | delta,0,y7,x1,,-inf,inf 16 | delta,0,y8,constant,,-inf,inf 17 | delta,0,y8,x1,,-inf,inf 18 | delta,0,y9,constant,,-inf,inf 19 | delta,0,y9,x1,,-inf,inf 20 | delta,0,Q1_fac1,constant,,-inf,inf 21 | delta,0,Q1_fac1,x1,,-inf,inf 22 | delta,1,y1,constant,,-inf,inf 23 | delta,1,y1,x1,,-inf,inf 24 | delta,1,y2,constant,,-inf,inf 25 | delta,1,y2,x1,,-inf,inf 26 | delta,1,y3,constant,,-inf,inf 27 | delta,1,y3,x1,,-inf,inf 28 | delta,1,y4,constant,,-inf,inf 29 | delta,1,y4,x1,,-inf,inf 30 | delta,1,y5,constant,,-inf,inf 31 | delta,1,y5,x1,,-inf,inf 32 | delta,1,y6,constant,,-inf,inf 33 | delta,1,y6,x1,,-inf,inf 34 | delta,1,Q1_fac1,constant,,-inf,inf 35 | delta,1,Q1_fac1,x1,,-inf,inf 36 | delta,2,y1,constant,,-inf,inf 37 | delta,2,y1,x1,,-inf,inf 38 | delta,2,y2,constant,,-inf,inf 39 | delta,2,y2,x1,,-inf,inf 40 | delta,2,y3,constant,,-inf,inf 41 | delta,2,y3,x1,,-inf,inf 42 | delta,2,y4,constant,,-inf,inf 43 | delta,2,y4,x1,,-inf,inf 44 | delta,2,y5,constant,,-inf,inf 45 | delta,2,y5,x1,,-inf,inf 46 | delta,2,y6,constant,,-inf,inf 47 | delta,2,y6,x1,,-inf,inf 48 | delta,2,Q1_fac1,constant,,-inf,inf 49 | delta,2,Q1_fac1,x1,,-inf,inf 50 | delta,3,y1,constant,,-inf,inf 51 | delta,3,y1,x1,,-inf,inf 52 | delta,3,y2,constant,,-inf,inf 53 | delta,3,y2,x1,,-inf,inf 54 | delta,3,y3,constant,,-inf,inf 55 | delta,3,y3,x1,,-inf,inf 56 | delta,3,y4,constant,,-inf,inf 57 | delta,3,y4,x1,,-inf,inf 58 | delta,3,y5,constant,,-inf,inf 59 | delta,3,y5,x1,,-inf,inf 60 | delta,3,y6,constant,,-inf,inf 61 | delta,3,y6,x1,,-inf,inf 62 | delta,3,Q1_fac1,constant,,-inf,inf 63 | delta,3,Q1_fac1,x1,,-inf,inf 64 | delta,4,y1,constant,,-inf,inf 65 | delta,4,y1,x1,,-inf,inf 66 | delta,4,y2,constant,,-inf,inf 67 | delta,4,y2,x1,,-inf,inf 68 | delta,4,y3,constant,,-inf,inf 69 | delta,4,y3,x1,,-inf,inf 70 | delta,4,y4,constant,,-inf,inf 71 | delta,4,y4,x1,,-inf,inf 72 | delta,4,y5,constant,,-inf,inf 73 | delta,4,y5,x1,,-inf,inf 74 | delta,4,y6,constant,,-inf,inf 75 | delta,4,y6,x1,,-inf,inf 76 | delta,4,Q1_fac1,constant,,-inf,inf 77 | delta,4,Q1_fac1,x1,,-inf,inf 78 | delta,5,y1,constant,,-inf,inf 79 | delta,5,y1,x1,,-inf,inf 80 | delta,5,y2,constant,,-inf,inf 81 | delta,5,y2,x1,,-inf,inf 82 | delta,5,y3,constant,,-inf,inf 83 | delta,5,y3,x1,,-inf,inf 84 | delta,5,y4,constant,,-inf,inf 85 | delta,5,y4,x1,,-inf,inf 86 | delta,5,y5,constant,,-inf,inf 87 | delta,5,y5,x1,,-inf,inf 88 | 
delta,5,y6,constant,,-inf,inf 89 | delta,5,y6,x1,,-inf,inf 90 | delta,5,Q1_fac1,constant,,-inf,inf 91 | delta,5,Q1_fac1,x1,,-inf,inf 92 | delta,6,y1,constant,,-inf,inf 93 | delta,6,y1,x1,,-inf,inf 94 | delta,6,y2,constant,,-inf,inf 95 | delta,6,y2,x1,,-inf,inf 96 | delta,6,y3,constant,,-inf,inf 97 | delta,6,y3,x1,,-inf,inf 98 | delta,6,y4,constant,,-inf,inf 99 | delta,6,y4,x1,,-inf,inf 100 | delta,6,y5,constant,,-inf,inf 101 | delta,6,y5,x1,,-inf,inf 102 | delta,6,y6,constant,,-inf,inf 103 | delta,6,y6,x1,,-inf,inf 104 | delta,6,Q1_fac1,constant,,-inf,inf 105 | delta,6,Q1_fac1,x1,,-inf,inf 106 | delta,7,y1,constant,,-inf,inf 107 | delta,7,y1,x1,,-inf,inf 108 | delta,7,y2,constant,,-inf,inf 109 | delta,7,y2,x1,,-inf,inf 110 | delta,7,y3,constant,,-inf,inf 111 | delta,7,y3,x1,,-inf,inf 112 | delta,7,y4,constant,,-inf,inf 113 | delta,7,y4,x1,,-inf,inf 114 | delta,7,y5,constant,,-inf,inf 115 | delta,7,y5,x1,,-inf,inf 116 | delta,7,y6,constant,,-inf,inf 117 | delta,7,y6,x1,,-inf,inf 118 | delta,7,Q1_fac1,constant,,-inf,inf 119 | delta,7,Q1_fac1,x1,,-inf,inf 120 | loading,0,y2,fac1,,-inf,inf 121 | loading,0,y3,fac1,,-inf,inf 122 | loading,0,y5,fac2,,-inf,inf 123 | loading,0,y6,fac2,,-inf,inf 124 | loading,0,y8,fac3,,-inf,inf 125 | loading,0,y9,fac3,,-inf,inf 126 | loading,0,Q1_fac1,fac1,,-inf,inf 127 | loading,1,y2,fac1,,-inf,inf 128 | loading,1,y3,fac1,,-inf,inf 129 | loading,1,y5,fac2,,-inf,inf 130 | loading,1,y6,fac2,,-inf,inf 131 | loading,1,Q1_fac1,fac1,,-inf,inf 132 | loading,2,y2,fac1,,-inf,inf 133 | loading,2,y3,fac1,,-inf,inf 134 | loading,2,y5,fac2,,-inf,inf 135 | loading,2,y6,fac2,,-inf,inf 136 | loading,2,Q1_fac1,fac1,,-inf,inf 137 | loading,3,y2,fac1,,-inf,inf 138 | loading,3,y3,fac1,,-inf,inf 139 | loading,3,y5,fac2,,-inf,inf 140 | loading,3,y6,fac2,,-inf,inf 141 | loading,3,Q1_fac1,fac1,,-inf,inf 142 | loading,4,y2,fac1,,-inf,inf 143 | loading,4,y3,fac1,,-inf,inf 144 | loading,4,y5,fac2,,-inf,inf 145 | loading,4,y6,fac2,,-inf,inf 146 | loading,4,Q1_fac1,fac1,,-inf,inf 147 | loading,5,y2,fac1,,-inf,inf 148 | loading,5,y3,fac1,,-inf,inf 149 | loading,5,y5,fac2,,-inf,inf 150 | loading,5,y6,fac2,,-inf,inf 151 | loading,5,Q1_fac1,fac1,,-inf,inf 152 | loading,6,y2,fac1,,-inf,inf 153 | loading,6,y3,fac1,,-inf,inf 154 | loading,6,y5,fac2,,-inf,inf 155 | loading,6,y6,fac2,,-inf,inf 156 | loading,6,Q1_fac1,fac1,,-inf,inf 157 | loading,7,y2,fac1,,-inf,inf 158 | loading,7,y3,fac1,,-inf,inf 159 | loading,7,y5,fac2,,-inf,inf 160 | loading,7,y6,fac2,,-inf,inf 161 | loading,7,Q1_fac1,fac1,,-inf,inf 162 | meas_sd,0,y1,-,,-inf,inf 163 | meas_sd,0,y2,-,,-inf,inf 164 | meas_sd,0,y3,-,,-inf,inf 165 | meas_sd,0,y4,-,,-inf,inf 166 | meas_sd,0,y5,-,,-inf,inf 167 | meas_sd,0,y6,-,,-inf,inf 168 | meas_sd,0,y7,-,,-inf,inf 169 | meas_sd,0,y8,-,,-inf,inf 170 | meas_sd,0,y9,-,,-inf,inf 171 | meas_sd,0,Q1_fac1,-,,-inf,inf 172 | meas_sd,1,y1,-,,-inf,inf 173 | meas_sd,1,y2,-,,-inf,inf 174 | meas_sd,1,y3,-,,-inf,inf 175 | meas_sd,1,y4,-,,-inf,inf 176 | meas_sd,1,y5,-,,-inf,inf 177 | meas_sd,1,y6,-,,-inf,inf 178 | meas_sd,1,Q1_fac1,-,,-inf,inf 179 | meas_sd,2,y1,-,,-inf,inf 180 | meas_sd,2,y2,-,,-inf,inf 181 | meas_sd,2,y3,-,,-inf,inf 182 | meas_sd,2,y4,-,,-inf,inf 183 | meas_sd,2,y5,-,,-inf,inf 184 | meas_sd,2,y6,-,,-inf,inf 185 | meas_sd,2,Q1_fac1,-,,-inf,inf 186 | meas_sd,3,y1,-,,-inf,inf 187 | meas_sd,3,y2,-,,-inf,inf 188 | meas_sd,3,y3,-,,-inf,inf 189 | meas_sd,3,y4,-,,-inf,inf 190 | meas_sd,3,y5,-,,-inf,inf 191 | meas_sd,3,y6,-,,-inf,inf 192 | meas_sd,3,Q1_fac1,-,,-inf,inf 193 | meas_sd,4,y1,-,,-inf,inf 194 | 
meas_sd,4,y2,-,,-inf,inf 195 | meas_sd,4,y3,-,,-inf,inf 196 | meas_sd,4,y4,-,,-inf,inf 197 | meas_sd,4,y5,-,,-inf,inf 198 | meas_sd,4,y6,-,,-inf,inf 199 | meas_sd,4,Q1_fac1,-,,-inf,inf 200 | meas_sd,5,y1,-,,-inf,inf 201 | meas_sd,5,y2,-,,-inf,inf 202 | meas_sd,5,y3,-,,-inf,inf 203 | meas_sd,5,y4,-,,-inf,inf 204 | meas_sd,5,y5,-,,-inf,inf 205 | meas_sd,5,y6,-,,-inf,inf 206 | meas_sd,5,Q1_fac1,-,,-inf,inf 207 | meas_sd,6,y1,-,,-inf,inf 208 | meas_sd,6,y2,-,,-inf,inf 209 | meas_sd,6,y3,-,,-inf,inf 210 | meas_sd,6,y4,-,,-inf,inf 211 | meas_sd,6,y5,-,,-inf,inf 212 | meas_sd,6,y6,-,,-inf,inf 213 | meas_sd,6,Q1_fac1,-,,-inf,inf 214 | meas_sd,7,y1,-,,-inf,inf 215 | meas_sd,7,y2,-,,-inf,inf 216 | meas_sd,7,y3,-,,-inf,inf 217 | meas_sd,7,y4,-,,-inf,inf 218 | meas_sd,7,y5,-,,-inf,inf 219 | meas_sd,7,y6,-,,-inf,inf 220 | meas_sd,7,Q1_fac1,-,,-inf,inf 221 | shock_variance,0,fac1,-,,-inf,inf 222 | shock_variance,0,fac2,-,,-inf,inf 223 | initial_mean,0,mixture_0,fac1,,-inf,inf 224 | initial_mean,0,mixture_0,fac2,,-inf,inf 225 | initial_mean,0,mixture_0,fac3,,-inf,inf 226 | initial_cov,0,mixture_0,fac1-fac1,,-inf,inf 227 | initial_cov,0,mixture_0,fac2-fac1,,-inf,inf 228 | initial_cov,0,mixture_0,fac2-fac2,,-inf,inf 229 | initial_cov,0,mixture_0,fac3-fac1,,-inf,inf 230 | initial_cov,0,mixture_0,fac3-fac2,,-inf,inf 231 | initial_cov,0,mixture_0,fac3-fac3,,-inf,inf 232 | trans,0,fac1,fac1,,-inf,inf 233 | trans,0,fac1,fac2,,-inf,inf 234 | trans,0,fac1,fac3,,-inf,inf 235 | trans,0,fac1,phi,,-inf,inf 236 | trans,0,fac2,fac2,,-inf,inf 237 | trans,0,fac2,constant,,-inf,inf 238 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: skillmodels 3 | channels: 4 | - conda-forge 5 | - opensourceeconomics 6 | dependencies: 7 | - python=3.9 8 | - jax>=0.1.7 9 | - jaxlib>=0.1.51 10 | - pip 11 | - conda-build 12 | - doc8 13 | - jupyter 14 | - numpy 15 | - pandas 16 | - pytest 17 | - pytest-xdist 18 | - restructuredtext_lint 19 | - tox-conda 20 | - scipy 21 | - sphinxcontrib-bibtex 22 | - pydata-sphinx-theme>=0.3 23 | - anaconda-client 24 | - nbsphinx 25 | - pdbpp 26 | - filterpy 27 | - dags 28 | - plotly 29 | - pip: 30 | - pre-commit 31 | - black 32 | -------------------------------------------------------------------------------- /src/skillmodels/__init__.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | # Import pdbp if it is available, but do not fail if it is not installed. 4 | with contextlib.suppress(ImportError): 5 | import pdbp # noqa: F401 6 | 7 | 8 | from skillmodels.filtered_states import get_filtered_states 9 | from skillmodels.maximization_inputs import get_maximization_inputs 10 | from skillmodels.simulate_data import simulate_dataset 11 | 12 | __all__ = ["get_maximization_inputs", "simulate_dataset", "get_filtered_states"] 13 | -------------------------------------------------------------------------------- /src/skillmodels/check_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def check_model(model_dict, labels, dimensions, anchoring): 5 | """Check consistency and validity of the model specification. 6 | 7 | Processing of labels, dimensions and anchoring information is done before the 8 | model checking because it will not raise any errors except for 9 | easy-to-understand KeyErrors.
 10 | 11 | Other specifications are checked in the model dict before processing to make sure 12 | that the assumptions we make during the processing are fulfilled. 13 | 14 | Args: 15 | model_dict (dict): The model specification. See: :ref:`model_specs` 16 | dimensions (dict): Dimensional information like n_states, n_periods, n_controls, 17 | n_mixtures. See :ref:`dimensions`. 18 | 19 | labels (dict): Dict of lists with labels for the model quantities like 20 | factors, periods, controls, stagemap and stages. See :ref:`labels` 21 | 22 | anchoring (dict): Dictionary with information about anchoring. 23 | See :ref:`anchoring` 24 | 25 | Raises: 26 | ValueError 27 | 28 | """ 29 | report = _check_stagemap( 30 | labels["stagemap"], 31 | labels["stages"], 32 | dimensions["n_periods"], 33 | ) 34 | report += _check_anchoring(anchoring) 35 | report += _check_measurements(model_dict, labels["latent_factors"]) 36 | report += _check_normalizations(model_dict, labels["latent_factors"]) 37 | 38 | report = "\n".join(report) 39 | if report != "": 40 | raise ValueError(f"Invalid model specification:\n{report}") 41 | 42 | 43 | def _check_stagemap(stagemap, stages, n_periods): 44 | report = [] 45 | if len(stagemap) != n_periods - 1: 46 | report.append( 47 | "The stagemap needs to be of length n_periods - 1. n_periods is " 48 | f"{n_periods}, the stagemap has length {len(stagemap)}.", 49 | ) 50 | 51 | if stages != list(range(len(stages))): 52 | report.append("Stages need to be integers, start at zero and increase by 1.") 53 | 54 | if not np.isin(np.array(stagemap[1:]) - np.array(stagemap[:-1]), (0, 1)).all(): 55 | report.append("Consecutive entries in stagemap must be equal or increase by 1.") 56 | return report 57 | 58 | 59 | def _check_anchoring(anchoring): 60 | report = [] 61 | if not isinstance(anchoring["anchoring"], bool): 62 | report.append("anchoring['anchoring'] must be a bool.") 63 | if not isinstance(anchoring["outcomes"], dict): 64 | report.append("anchoring['outcomes'] must be a dict.") 65 | else: 66 | variables = list(anchoring["outcomes"].values()) 67 | for var in variables: 68 | if not isinstance(var, str | int | tuple): 69 | report.append("Outcome variables have to be valid variable names.") 70 | 71 | if not isinstance(anchoring["free_controls"], bool): 72 | report.append("anchoring['free_controls'] must be a bool.") 73 | if not isinstance(anchoring["free_constant"], bool): 74 | report.append("anchoring['free_constant'] must be a bool.") 75 | if not isinstance(anchoring["free_loadings"], bool): 76 | report.append("anchoring['free_loadings'] must be a bool.") 77 | return report 78 | 79 | 80 | def _check_measurements(model_dict, factors): 81 | report = [] 82 | for factor in factors: 83 | candidate = model_dict["factors"][factor]["measurements"] 84 | if not _is_list_of(candidate, list): 85 | report.append( 86 | f"measurements must be lists of lists. Check measurements of {factor}.", 87 | ) 88 | else: 89 | for period, meas_list in enumerate(candidate): 90 | for meas in meas_list: 91 | if not isinstance(meas, int | str | tuple): 92 | report.append( 93 | "Measurements need to be valid pandas column names. 
Check " 94 | f"{meas} for {factor} in period {period}.", 95 | ) 96 | return report 97 | 98 | 99 | def _check_normalizations(model_dict, factors): 100 | report = [] 101 | for factor in factors: 102 | norminfo = model_dict["factors"][factor].get("normalizations", {}) 103 | for norm_type in ["loadings", "intercepts"]: 104 | candidate = norminfo.get(norm_type, []) 105 | if not _is_list_of(candidate, dict): 106 | report.append( 107 | f"normalizations must be lists of dicts. Check {norm_type} " 108 | f"normalizations for {factor}.", 109 | ) 110 | else: 111 | report += _check_normalized_variables_are_present( 112 | candidate, 113 | model_dict, 114 | factor, 115 | ) 116 | 117 | if norm_type == "loadings": 118 | report += _check_loadings_are_not_normalized_to_zero( 119 | candidate, 120 | factor, 121 | ) 122 | return report 123 | 124 | 125 | def _check_normalized_variables_are_present(list_of_normdicts, model_dict, factor): 126 | report = [] 127 | for period, norm_dict in enumerate(list_of_normdicts): 128 | for var in norm_dict: 129 | if var not in model_dict["factors"][factor]["measurements"][period]: 130 | report.append( 131 | "You can only normalize variables that are specified as " 132 | f"measurements. Check {var} for {factor} in period " 133 | f"{period}.", 134 | ) 135 | 136 | return report 137 | 138 | 139 | def _check_loadings_are_not_normalized_to_zero(list_of_normdicts, factor): 140 | report = [] 141 | for period, norm_dict in enumerate(list_of_normdicts): 142 | for var, val in norm_dict.items(): 143 | if val == 0: 144 | report.append( 145 | f"loadings cannot be normalized to 0. Check measurement {var} " 146 | f"of {factor} in period {period}.", 147 | ) 148 | return report 149 | 150 | 151 | def _is_list_of(candidate, type_): 152 | """Check if candidate is a list that only contains elements of type. 153 | 154 | Note that this is always false if candidate is not a list and always true if 155 | it is an empty list. 156 | 157 | Examples: 158 | >>> _is_list_of([["a"], ["b"]], list) 159 | True 160 | >>> _is_list_of([{}], list) 161 | False 162 | >>> _is_list_of([], dict) 163 | True 164 | 165 | """ 166 | return isinstance(candidate, list) and all(isinstance(i, type_) for i in candidate) 167 | -------------------------------------------------------------------------------- /src/skillmodels/clipping.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | def soft_clipping(arr, lower=None, upper=None, lower_hardness=1, upper_hardness=1): 6 | """Clip values in an array elementwise using a soft maximum to avoid kinks. 7 | 8 | Clipping from below is taking a maximum between two values. Clipping 9 | from above is taking a minimum, but it can be rewritten as taking a maximum after 10 | switching the signs. 11 | 12 | To smooth out the kinks introduced by normal clipping, we first rewrite all clipping 13 | operations as taking maxima. Then we replace the normal maximum by the soft maximum. 14 | 15 | For background on the soft maximum check out this 16 | `article by John Cook: `_ 17 | 18 | Note that contrary to the name, the soft maximum can be calculated using 19 | ``scipy.special.logsumexp``. ``scipy.special.softmax`` is the gradient of 20 | ``scipy.special.logsumexp``. 21 | 22 | 23 | Args: 24 | arr (jax.numpy.array): Array that is clipped elementwise. 25 | lower (float): The value at which the array is clipped from below. 26 | upper (float): The value at which the array is clipped from above.
27 | lower_hardness (float): Scaling factor that is applied inside the soft maximum. 28 | High values imply a closer approximation of the real maximum. 29 | upper_hardness (float): Scaling factor that is applied inside the soft maximum. 30 | High values imply a closer approximation of the real maximum. 31 | 32 | """ 33 | shape = arr.shape 34 | flat = arr.flatten() 35 | dim = len(flat) 36 | if lower is not None: 37 | helper = jnp.column_stack([flat, jnp.full(dim, lower)]) 38 | flat = ( 39 | jax.scipy.special.logsumexp(lower_hardness * helper, axis=1) 40 | / lower_hardness 41 | ) 42 | if upper is not None: 43 | helper = jnp.column_stack([-flat, jnp.full(dim, -upper)]) 44 | flat = ( 45 | -jax.scipy.special.logsumexp(upper_hardness * helper, axis=1) 46 | / upper_hardness 47 | ) 48 | return flat.reshape(shape) 49 | -------------------------------------------------------------------------------- /src/skillmodels/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | TEST_DIR = Path(__file__).resolve().parent / "tests" 4 | -------------------------------------------------------------------------------- /src/skillmodels/decorators.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import jax.numpy as jnp 4 | 5 | 6 | def extract_params(func=None, *, key=None, names=None): 7 | """Process params before passing them to func. 8 | 9 | Note: The resulting function is keyword only! 10 | 11 | Args: 12 | key (str or None): If key is not None, we assume params is a dictionary of which 13 | only the params[key] should be passed into func. 14 | names (list or None): If names is provided, we assume that params 15 | (or params[key]) should be converted to a dictionary with names as keys 16 | before passing them to func. 
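 Example (an illustrative sketch; the function ``f`` and the params structure are hypothetical):

     >>> @extract_params(key="transition", names=["a", "b"])
     ... def f(params):
     ...     return params["a"] + params["b"]
     >>> f(params={"transition": [1, 2]})
     3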
 17 | 18 | """ 19 | 20 | def decorator_extract_params(func): 21 | if key is not None and names is None: 22 | 23 | @functools.wraps(func) 24 | def wrapper_extract_params(**kwargs): 25 | internal_kwargs = kwargs.copy() 26 | internal_kwargs["params"] = kwargs["params"][key] 27 | return func(**internal_kwargs) 28 | 29 | elif key is None and names is not None: 30 | 31 | @functools.wraps(func) 32 | def wrapper_extract_params(**kwargs): 33 | internal_kwargs = kwargs.copy() 34 | internal_kwargs["params"] = dict( 35 | zip(names, kwargs["params"], strict=False) 36 | ) 37 | return func(**internal_kwargs) 38 | 39 | elif key is not None and names is not None: 40 | 41 | @functools.wraps(func) 42 | def wrapper_extract_params(**kwargs): 43 | internal_kwargs = kwargs.copy() 44 | internal_kwargs["params"] = dict( 45 | zip(names, kwargs["params"][key], strict=False) 46 | ) 47 | return func(**internal_kwargs) 48 | 49 | else: 50 | raise ValueError("key and names cannot both be None.") 51 | 52 | return wrapper_extract_params 53 | 54 | if callable(func): 55 | return decorator_extract_params(func) 56 | return decorator_extract_params 57 | 58 | 59 | def jax_array_output(func): 60 | """Convert the output of a function to a jax array.""" 61 | 62 | @functools.wraps(func) 63 | def wrapper_jax_array_output(*args, **kwargs): 64 | raw = func(*args, **kwargs) 65 | out = jnp.array(raw) 66 | return out 67 | 68 | return wrapper_jax_array_output 69 | 70 | 71 | def register_params(func=None, *, params=None): 72 | def decorator_register_params(func): 73 | func.__registered_params__ = params 74 | return func 75 | 76 | if callable(func): 77 | return decorator_register_params(func) 78 | return decorator_register_params 79 | -------------------------------------------------------------------------------- /src/skillmodels/filtered_states.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | 4 | from skillmodels.maximization_inputs import get_maximization_inputs 5 | from skillmodels.params_index import get_params_index 6 | from skillmodels.parse_params import create_parsing_info, parse_params 7 | from skillmodels.process_debug_data import create_state_ranges 8 | from skillmodels.process_model import process_model 9 | 10 | 11 | def get_filtered_states(model_dict, data, params): 12 | max_inputs = get_maximization_inputs(model_dict=model_dict, data=data) 13 | params = params.loc[max_inputs["params_template"].index] 14 | debug_loglike = max_inputs["debug_loglike"] 15 | debug_data = debug_loglike(params) 16 | unanchored_states_df = debug_data["filtered_states"] 17 | unanchored_ranges = debug_data["state_ranges"] 18 | model = process_model(model_dict) 19 | 20 | anchored_states_df = anchor_states_df( 21 | states_df=unanchored_states_df, 22 | model_dict=model_dict, 23 | params=params, 24 | ) 25 | 26 | anchored_ranges = create_state_ranges( 27 | filtered_states=anchored_states_df, 28 | factors=model["labels"]["latent_factors"], 29 | ) 30 | 31 | out = { 32 | "anchored_states": { 33 | "states": anchored_states_df, 34 | "state_ranges": anchored_ranges, 35 | }, 36 | "unanchored_states": { 37 | "states": unanchored_states_df, 38 | "state_ranges": unanchored_ranges, 39 | }, 40 | } 41 | 42 | return out 43 | 44 | 45 | def anchor_states_df(states_df, model_dict, params): 46 | """Anchor states in a DataFrame. 47 | 48 | The DataFrame is expected to have a column called "period" as well as one column 49 | for each latent factor. 50 | 51 | All other columns are not affected.
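 For example (an illustration of the transformation, not an additional code path): a latent factor column ``fac1`` with anchoring scaling factor ``a`` and anchoring constant ``b`` in a given period is mapped to ``b + a * fac1`` in that period; see the loop at the end of this function.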
52 | 53 | This is a bit difficult because we need to re-use `parse_params` (which was meant 54 | as an internal function that only works with jax objects). 55 | 56 | """ 57 | model = process_model(model_dict) 58 | 59 | p_index = get_params_index( 60 | model["update_info"], 61 | model["labels"], 62 | model["dimensions"], 63 | model["transition_info"], 64 | ) 65 | 66 | params = params.loc[p_index] 67 | 68 | parsing_info = create_parsing_info( 69 | p_index, 70 | model["update_info"], 71 | model["labels"], 72 | model["anchoring"], 73 | ) 74 | 75 | *_, pardict = parse_params( 76 | params=jnp.array(params["value"].to_numpy()), 77 | parsing_info=parsing_info, 78 | dimensions=model["dimensions"], 79 | labels=model["labels"], 80 | n_obs=1, 81 | ) 82 | 83 | n_latent = model["dimensions"]["n_latent_factors"] 84 | 85 | scaling_factors = np.array(pardict["anchoring_scaling_factors"][:, :n_latent]) 86 | constants = np.array(pardict["anchoring_constants"][:, :n_latent]) 87 | 88 | period_arr = states_df["period"].to_numpy() 89 | scaling_arr = scaling_factors[period_arr] 90 | constants_arr = constants[period_arr] 91 | 92 | out = states_df.copy(deep=True) 93 | for pos, factor in enumerate(model["labels"]["latent_factors"]): 94 | out[factor] = constants_arr[:, pos] + states_df[factor] * scaling_arr[:, pos] 95 | 96 | out = out[states_df.columns] 97 | 98 | return out 99 | -------------------------------------------------------------------------------- /src/skillmodels/kalman_filters_debug.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | array_qr_jax = jax.vmap(jax.vmap(jnp.linalg.qr)) 5 | 6 | 7 | # ====================================================================================== 8 | # Update Step 9 | # ====================================================================================== 10 | 11 | 12 | def kalman_update( 13 | states, 14 | upper_chols, 15 | loadings, 16 | control_params, 17 | meas_sd, 18 | measurements, 19 | controls, 20 | log_mixture_weights, 21 | ): 22 | """Perform a Kalman update with likelihood evaluation, returning debug info on top. 23 | 24 | Args: 25 | states (jax.numpy.array): Array of shape (n_obs, n_mixtures, n_states) with 26 | pre-update states estimates. 27 | upper_chols (jax.numpy.array): Array of shape (n_obs, n_mixtures, n_states, 28 | n_states) with the transpose of the lower triangular cholesky factor 29 | of the pre-update covariance matrix of the state estimates. 30 | loadings (jax.numpy.array): 1d array of length n_states with factor loadings. 31 | control_params (jax.numpy.array): 1d array of length n_controls. 32 | meas_sd (float): Standard deviation of the measurement error. 33 | measurements (jax.numpy.array): 1d array of length n_obs with measurements. 34 | May contain NaNs if no measurement was observed. 35 | controls (jax.numpy.array): Array of shape (n_obs, n_controls) with data on the 36 | control variables. 37 | log_mixture_weights (jax.numpy.array): Array of shape (n_obs, n_mixtures) with 38 | the natural logarithm of the weights of each element of the mixture of 39 | normals distribution. 40 | 41 | Returns: 42 | states (jax.numpy.array): Same format as states. 43 | new_states (jax.numpy.array): Same format as states. 
44 | new_upper_chols (jax.numpy.array): Same format as upper_chols 45 | new_log_mixture_weights: (jax.numpy.array): Same format as log_mixture_weights 46 | new_loglikes: (jax.numpy.array): 1d array of length n_obs 47 | debug_info (dict): Empty or containing residuals and residual_sds 48 | 49 | """ 50 | n_obs, n_mixtures, n_states = states.shape 51 | 52 | not_missing = jnp.isfinite(measurements) 53 | 54 | # replace missing measurements and controls by reasonable fill values to avoid NaNs 55 | # in the gradient calculation. All values that are influenced by this, are 56 | # replaced by other values later. Choosing the average expected 57 | # expected measurements without controls as fill value ensures that all numbers 58 | # are well defined because the fill values have a reasonable order of magnitude. 59 | # See https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf 60 | # and https://jax.readthedocs.io/en/latest/faq.html 61 | # for more details on the issue of NaNs in gradient calculations. 62 | _safe_controls = jnp.where(not_missing.reshape(n_obs, 1), controls, 0) 63 | 64 | _safe_expected_measurements = jnp.dot(states, loadings) + jnp.dot( 65 | _safe_controls, 66 | control_params, 67 | ).reshape(n_obs, 1) 68 | 69 | _safe_measurements = jnp.where( 70 | not_missing, 71 | measurements, 72 | _safe_expected_measurements.mean(axis=1), 73 | ) 74 | 75 | _residuals = _safe_measurements.reshape(n_obs, 1) - _safe_expected_measurements 76 | _f_stars = jnp.dot(upper_chols, loadings.reshape(n_states, 1)) 77 | 78 | _m = jnp.zeros((n_obs, n_mixtures, n_states + 1, n_states + 1)) 79 | _m = _m.at[..., 0, 0].set(meas_sd) 80 | _m = _m.at[..., 1:, :1].set(_f_stars) 81 | _m = _m.at[..., 1:, 1:].set(upper_chols) 82 | 83 | _r = array_qr_jax(_m)[1] 84 | 85 | _new_upper_chols = _r[..., 1:, 1:] 86 | _root_sigmas = _r[..., 0, 0] 87 | _abs_root_sigmas = jnp.abs(_root_sigmas) 88 | # it is important not to divide by the absolute value of _root_sigmas in order 89 | # to recover the sign of the Kalman gain. 
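 # The first row of the triangular QR factor contains the square root of the
 # innovation variance and the scaled Kalman gain; this QR step is what makes
 # the update a numerically stable square-root implementation.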
90 | _kalman_gains = _r[..., 0, 1:] / _root_sigmas.reshape(n_obs, n_mixtures, 1) 91 | _new_states = states + _kalman_gains * _residuals.reshape(n_obs, n_mixtures, 1) 92 | 93 | # calculate log likelihood per individual and update mixture weights 94 | _loglikes_per_dist = jax.scipy.stats.norm.logpdf(_residuals, 0, _abs_root_sigmas) 95 | if n_mixtures >= 2: 96 | _weighted_loglikes_per_dist = _loglikes_per_dist + log_mixture_weights 97 | _loglikes = jax.scipy.special.logsumexp(_weighted_loglikes_per_dist, axis=1) 98 | _new_log_mixture_weights = _weighted_loglikes_per_dist - _loglikes.reshape( 99 | -1, 100 | 1, 101 | ) 102 | 103 | else: 104 | _loglikes = _loglikes_per_dist.flatten() 105 | _new_log_mixture_weights = log_mixture_weights 106 | 107 | # combine pre-update quantities for missing observations with updated quantities 108 | new_states = jnp.where(not_missing.reshape(n_obs, 1, 1), _new_states, states) 109 | new_upper_chols = jnp.where( 110 | not_missing.reshape(n_obs, 1, 1, 1), 111 | _new_upper_chols, 112 | upper_chols, 113 | ) 114 | new_loglikes = jnp.where(not_missing, _loglikes, 0) 115 | new_log_mixture_weights = jnp.where( 116 | not_missing.reshape(n_obs, 1), 117 | _new_log_mixture_weights, 118 | log_mixture_weights, 119 | ) 120 | 121 | debug_info = {} 122 | residuals = jnp.where(not_missing.reshape(n_obs, 1), _residuals, jnp.nan) 123 | debug_info["residuals"] = residuals 124 | residual_sds = jnp.where( 125 | not_missing.reshape(n_obs, 1), 126 | _abs_root_sigmas, 127 | jnp.nan, 128 | ) 129 | debug_info["residual_sds"] = residual_sds 130 | debug_info["log_mixture_weights"] = new_log_mixture_weights 131 | 132 | return ( 133 | new_states, 134 | new_upper_chols, 135 | new_log_mixture_weights, 136 | new_loglikes, 137 | debug_info, 138 | ) 139 | -------------------------------------------------------------------------------- /src/skillmodels/likelihood_function.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from skillmodels.clipping import soft_clipping 7 | from skillmodels.kalman_filters import ( 8 | kalman_predict, 9 | kalman_update, 10 | ) 11 | from skillmodels.parse_params import parse_params 12 | 13 | 14 | def log_likelihood( 15 | params, 16 | parsing_info, 17 | measurements, 18 | controls, 19 | transition_func, 20 | sigma_scaling_factor, 21 | sigma_weights, 22 | dimensions, 23 | labels, 24 | estimation_options, 25 | is_measurement_iteration, 26 | is_predict_iteration, 27 | iteration_to_period, 28 | observed_factors, 29 | ): 30 | return log_likelihood_obs( 31 | params=params, 32 | parsing_info=parsing_info, 33 | measurements=measurements, 34 | controls=controls, 35 | transition_func=transition_func, 36 | sigma_scaling_factor=sigma_scaling_factor, 37 | sigma_weights=sigma_weights, 38 | dimensions=dimensions, 39 | labels=labels, 40 | estimation_options=estimation_options, 41 | is_measurement_iteration=is_measurement_iteration, 42 | is_predict_iteration=is_predict_iteration, 43 | iteration_to_period=iteration_to_period, 44 | observed_factors=observed_factors, 45 | ).sum() 46 | 47 | 48 | def log_likelihood_obs( 49 | params, 50 | parsing_info, 51 | measurements, 52 | controls, 53 | transition_func, 54 | sigma_scaling_factor, 55 | sigma_weights, 56 | dimensions, 57 | labels, 58 | estimation_options, 59 | is_measurement_iteration, 60 | is_predict_iteration, 61 | iteration_to_period, 62 | observed_factors, 63 | ): 64 | """Log likelihood of a skill formation model. 
65 | 66 | This function is jax-differentiable and jax-jittable as long as all but the first 67 | argument are marked as static. 68 | 69 | The function returns both a tuple (float, dict). The first entry is the aggregated 70 | log likelihood value. The second additional information like the log likelihood 71 | contribution of each individual. Note that the dict also contains the aggregated 72 | value. Returning that value separately is only needed to calculate a gradient 73 | with Jax. 74 | 75 | Args: 76 | params (jax.numpy.array): 1d array with model parameters. 77 | parsing_info (dict): Contains information how to parse parameter vector. 78 | update_info (pandas.DataFrame): Contains information about number of updates in 79 | each period and purpose of each update. 80 | measurements (jax.numpy.array): Array of shape (n_updates, n_obs) with data on 81 | observed measurements. NaN if the measurement was not observed. 82 | controls (jax.numpy.array): Array of shape (n_periods, n_obs, n_controls) 83 | with observed control variables for the measurement equations. 84 | transition_func (Callable): The transition function. 85 | sigma_scaling_factor (float): A scaling factor that controls the spread of the 86 | sigma points. Bigger means that sigma points are further apart. Depends on 87 | the sigma_point algorithm chosen. 88 | sigma_weights (jax.numpy.array): 1d array of length n_sigma with non-negative 89 | sigma weights. 90 | dimensions (dict): Dimensional information like n_states, n_periods, n_controls, 91 | n_mixtures. See :ref:`dimensions`. 92 | labels (dict): Dict of lists with labels for the model quantities like 93 | factors, periods, controls, stagemap and stages. See :ref:`labels` 94 | observed_factors (jax.numpy.array): Array of shape (n_periods, n_obs, 95 | n_observed_factors) with data on the observed factors. 96 | 97 | Returns: 98 | jnp.array: 1d array of length N, the aggregated log likelihood. 99 | 100 | """ 101 | n_obs = measurements.shape[1] 102 | states, upper_chols, log_mixture_weights, pardict = parse_params( 103 | params, 104 | parsing_info, 105 | dimensions, 106 | labels, 107 | n_obs, 108 | ) 109 | 110 | carry = { 111 | "states": states, 112 | "upper_chols": upper_chols, 113 | "log_mixture_weights": log_mixture_weights, 114 | } 115 | 116 | loop_args = { 117 | "period": iteration_to_period, 118 | "loadings": pardict["loadings"], 119 | "control_params": pardict["controls"], 120 | "meas_sds": pardict["meas_sds"], 121 | "measurements": measurements, 122 | "is_measurement_iteration": is_measurement_iteration, 123 | "is_predict_iteration": is_predict_iteration, 124 | } 125 | 126 | _body = functools.partial( 127 | _scan_body, 128 | controls=controls, 129 | pardict=pardict, 130 | sigma_scaling_factor=sigma_scaling_factor, 131 | sigma_weights=sigma_weights, 132 | transition_func=transition_func, 133 | observed_factors=observed_factors, 134 | ) 135 | 136 | static_out = jax.lax.scan(_body, carry, loop_args)[1] 137 | 138 | # clip contributions before aggregation to preserve as much information as 139 | # possible. 
140 | return soft_clipping( 141 | arr=static_out["loglikes"], 142 | lower=estimation_options["clipping_lower_bound"], 143 | upper=estimation_options["clipping_upper_bound"], 144 | lower_hardness=estimation_options["clipping_lower_hardness"], 145 | upper_hardness=estimation_options["clipping_upper_hardness"], 146 | ).sum(axis=0) 147 | 148 | 149 | def _scan_body( 150 | carry, 151 | loop_args, 152 | controls, 153 | pardict, 154 | sigma_scaling_factor, 155 | sigma_weights, 156 | transition_func, 157 | observed_factors, 158 | ): 159 | # ================================================================================== 160 | # create arguments needed for update 161 | # ================================================================================== 162 | t = loop_args["period"] 163 | states = carry["states"] 164 | upper_chols = carry["upper_chols"] 165 | log_mixture_weights = carry["log_mixture_weights"] 166 | 167 | update_kwargs = { 168 | "states": states, 169 | "upper_chols": upper_chols, 170 | "loadings": loop_args["loadings"], 171 | "control_params": loop_args["control_params"], 172 | "meas_sd": loop_args["meas_sds"], 173 | "measurements": loop_args["measurements"], 174 | "controls": controls[t], 175 | "log_mixture_weights": log_mixture_weights, 176 | } 177 | 178 | # ================================================================================== 179 | # do a measurement or anchoring update 180 | # ================================================================================== 181 | states, upper_chols, log_mixture_weights, loglikes = jax.lax.cond( 182 | loop_args["is_measurement_iteration"], 183 | functools.partial(_one_arg_measurement_update), 184 | functools.partial(_one_arg_anchoring_update), 185 | update_kwargs, 186 | ) 187 | 188 | # ================================================================================== 189 | # create arguments needed for predict step 190 | # ================================================================================== 191 | predict_kwargs = { 192 | "states": states, 193 | "upper_chols": upper_chols, 194 | "sigma_scaling_factor": sigma_scaling_factor, 195 | "sigma_weights": sigma_weights, 196 | "trans_coeffs": {k: arr[t] for k, arr in pardict["transition"].items()}, 197 | "shock_sds": pardict["shock_sds"][t], 198 | "anchoring_scaling_factors": pardict["anchoring_scaling_factors"][ 199 | jnp.array([t, t + 1]) 200 | ], 201 | "anchoring_constants": pardict["anchoring_constants"][jnp.array([t, t + 1])], 202 | "observed_factors": observed_factors[t], 203 | } 204 | 205 | fixed_kwargs = {"transition_func": transition_func} 206 | 207 | # ================================================================================== 208 | # Do a predict step or a do-nothing fake predict step 209 | # ================================================================================== 210 | states, upper_chols, filtered_states = jax.lax.cond( 211 | loop_args["is_predict_iteration"], 212 | functools.partial(_one_arg_predict, **fixed_kwargs), 213 | functools.partial(_one_arg_no_predict, **fixed_kwargs), 214 | predict_kwargs, 215 | ) 216 | 217 | new_state = { 218 | "states": states, 219 | "upper_chols": upper_chols, 220 | "log_mixture_weights": log_mixture_weights, 221 | } 222 | 223 | static_out = {"loglikes": loglikes, "states": filtered_states} 224 | return new_state, static_out 225 | 226 | 227 | def _one_arg_measurement_update(kwargs): 228 | out = kalman_update(**kwargs) 229 | return out 230 | 231 | 232 | def _one_arg_anchoring_update(kwargs): 233 | _, _, 
new_log_mixture_weights, new_loglikes = kalman_update(**kwargs) 234 | out = ( 235 | kwargs["states"], 236 | kwargs["upper_chols"], 237 | new_log_mixture_weights, 238 | new_loglikes, 239 | ) 240 | return out 241 | 242 | 243 | def _one_arg_no_predict(kwargs, transition_func): # noqa: ARG001 244 | """Just return the states and chols without any changes.""" 245 | return kwargs["states"], kwargs["upper_chols"], kwargs["states"] 246 | 247 | 248 | def _one_arg_predict(kwargs, transition_func): 249 | """Do a predict step but also return the input states as filtered states.""" 250 | new_states, new_upper_chols = kalman_predict( 251 | transition_func, 252 | **kwargs, 253 | ) 254 | return new_states, new_upper_chols, kwargs["states"] 255 | -------------------------------------------------------------------------------- /src/skillmodels/maximization_inputs.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import pandas as pd 7 | 8 | import skillmodels.likelihood_function as lf 9 | import skillmodels.likelihood_function_debug as lfd 10 | from skillmodels.constraints import add_bounds, get_constraints 11 | from skillmodels.kalman_filters import calculate_sigma_scaling_factor_and_weights 12 | from skillmodels.params_index import get_params_index 13 | from skillmodels.parse_params import create_parsing_info 14 | from skillmodels.process_data import process_data 15 | from skillmodels.process_debug_data import process_debug_data 16 | from skillmodels.process_model import process_model 17 | 18 | jax.config.update("jax_enable_x64", True) # noqa: FBT003 19 | 20 | 21 | def get_maximization_inputs(model_dict, data): 22 | """Create inputs for optimagic's maximize function. 23 | 24 | Args: 25 | model_dict (dict): The model specification. See: :ref:`model_specs` 26 | data (DataFrame): dataset in long format. 27 | 28 | Returns a dictionary with keys: 29 | loglike (function): A jax jitted function that takes an optimagic-style 30 | params DataFrame as only input and returns the scalar log likelihood 31 | as a float. 32 | loglikeobs (function): Like loglike, but returns an array with one log 33 | likelihood contribution per observation. 34 | debug_loglike (function): Similar to loglike, with the following differences: 35 | - It is not jitted and thus faster on the first call and debuggable 36 | - It returns a dict that contains intermediate results in addition to 37 | the log likelihood value. Those can be used for debugging and plotting. 38 | loglike_and_gradient (function): Combination of loglike and its gradient 39 | that is faster than calculating the value and the gradient separately. 40 | constraints (list): List of optimagic constraints that are implied by the 41 | model specification. 42 | params_template (pd.DataFrame): Parameter DataFrame with correct index and 43 | bounds but with empty value column. 44 |
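 Example (a minimal sketch; assumes ``model_dict`` and ``data`` are specified as in the tutorial)::

     inputs = get_maximization_inputs(model_dict, data)
     params = inputs["params_template"].copy()
     params["value"] = 0.1  # hypothetical start values
     loglike_value = inputs["loglike"](params)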
45 | 46 | """ 47 | model = process_model(model_dict) 48 | p_index = get_params_index( 49 | model["update_info"], 50 | model["labels"], 51 | model["dimensions"], 52 | model["transition_info"], 53 | ) 54 | 55 | parsing_info = create_parsing_info( 56 | p_index, 57 | model["update_info"], 58 | model["labels"], 59 | model["anchoring"], 60 | ) 61 | measurements, controls, observed_factors = process_data( 62 | data, 63 | model["labels"], 64 | model["update_info"], 65 | model["anchoring"], 66 | ) 67 | 68 | sigma_scaling_factor, sigma_weights = calculate_sigma_scaling_factor_and_weights( 69 | model["dimensions"]["n_latent_factors"], 70 | model["estimation_options"]["sigma_points_scale"], 71 | ) 72 | 73 | partialed_get_jnp_params_vec = functools.partial( 74 | _get_jnp_params_vec, 75 | target_index=p_index, 76 | ) 77 | 78 | partialed_loglikes = {} 79 | for n, fun in { 80 | "ll": lf.log_likelihood, 81 | "llo": lf.log_likelihood_obs, 82 | "debug_ll": lfd.log_likelihood, 83 | }.items(): 84 | partialed_loglikes[n] = _partial_some_log_likelihood( 85 | fun=fun, 86 | parsing_info=parsing_info, 87 | measurements=measurements, 88 | controls=controls, 89 | observed_factors=observed_factors, 90 | model=model, 91 | sigma_weights=sigma_weights, 92 | sigma_scaling_factor=sigma_scaling_factor, 93 | ) 94 | 95 | _jitted_loglike = jax.jit(partialed_loglikes["ll"]) 96 | _jitted_loglikeobs = jax.jit(partialed_loglikes["llo"]) 97 | _gradient = jax.jit(jax.grad(partialed_loglikes["ll"])) 98 | 99 | def loglike(params): 100 | params_vec = partialed_get_jnp_params_vec(params) 101 | return float(_jitted_loglike(params_vec)) 102 | 103 | def loglikeobs(params): 104 | params_vec = partialed_get_jnp_params_vec(params) 105 | return _to_numpy(_jitted_loglikeobs(params_vec)) 106 | 107 | def loglike_and_gradient(params): 108 | params_vec = partialed_get_jnp_params_vec(params) 109 | crit = float(_jitted_loglike(params_vec)) 110 | grad = _to_numpy(_gradient(params_vec)) 111 | return crit, grad 112 | 113 | def debug_loglike(params): 114 | params_vec = partialed_get_jnp_params_vec(params) 115 | jax_output = partialed_loglikes["debug_ll"](params_vec) 116 | tmp = _to_numpy(jax_output) 117 | tmp["value"] = float(tmp["value"]) 118 | return process_debug_data(debug_data=tmp, model=model) 119 | 120 | constr = get_constraints( 121 | dimensions=model["dimensions"], 122 | labels=model["labels"], 123 | anchoring_info=model["anchoring"], 124 | update_info=model["update_info"], 125 | normalizations=model["normalizations"], 126 | ) 127 | 128 | params_template = pd.DataFrame(columns=["value"], index=p_index) 129 | params_template = add_bounds( 130 | params_template, 131 | model["estimation_options"]["bounds_distance"], 132 | ) 133 | 134 | out = { 135 | "loglike": loglike, 136 | "loglikeobs": loglikeobs, 137 | "debug_loglike": debug_loglike, 138 | "loglike_and_gradient": loglike_and_gradient, 139 | "constraints": constr, 140 | "params_template": params_template, 141 | } 142 | 143 | return out 144 | 145 | 146 | def _partial_some_log_likelihood( 147 | fun, 148 | parsing_info, 149 | measurements, 150 | controls, 151 | observed_factors, 152 | model, 153 | sigma_weights, 154 | sigma_scaling_factor, 155 | ): 156 | update_info = model["update_info"] 157 | is_measurement_iteration = (update_info["purpose"] == "measurement").to_numpy() 158 | _periods = pd.Series(update_info.index.get_level_values("period").to_numpy()) 159 | is_predict_iteration = ((_periods - _periods.shift(-1)) == -1).to_numpy() 160 | last_period = model["labels"]["periods"][-1] 161 | # 
iteration_to_period is used as an indexer to loop over arrays of different lengths 162 | # in a jax.lax.scan. It needs to work for arrays of length n_periods and not raise 163 | # IndexErrors on tracer arrays of length n_periods - 1 (i.e. n_transitions). 164 | # To achieve that, we replace the last period by -1. 165 | iteration_to_period = _periods.replace(last_period, -1).to_numpy() 166 | 167 | return functools.partial( 168 | fun, 169 | parsing_info=parsing_info, 170 | measurements=measurements, 171 | controls=controls, 172 | transition_func=model["transition_info"]["func"], 173 | sigma_scaling_factor=sigma_scaling_factor, 174 | sigma_weights=sigma_weights, 175 | dimensions=model["dimensions"], 176 | labels=model["labels"], 177 | estimation_options=model["estimation_options"], 178 | is_measurement_iteration=is_measurement_iteration, 179 | is_predict_iteration=is_predict_iteration, 180 | iteration_to_period=iteration_to_period, 181 | observed_factors=observed_factors, 182 | ) 183 | 184 | 185 | def _to_numpy(obj): 186 | if isinstance(obj, dict): 187 | res = {} 188 | for key, value in obj.items(): 189 | if np.isscalar(value): 190 | res[key] = value 191 | else: 192 | res[key] = np.array(value) 193 | 194 | elif np.isscalar(obj): 195 | res = obj 196 | else: 197 | res = np.array(obj) 198 | 199 | return res 200 | 201 | 202 | def _get_jnp_params_vec(params, target_index): 203 | if set(params.index) != set(target_index): 204 | additional_entries = params.index.difference(target_index).tolist() 205 | missing_entries = target_index.difference(params.index).tolist() 206 | msg = "Invalid params DataFrame. " 207 | if additional_entries: 208 | msg += f"Your params have additional entries: {additional_entries}. " 209 | if missing_entries: 210 | msg += f"Your params are missing entries: {missing_entries}. " 211 | raise ValueError(msg) 212 | 213 | vec = jnp.array(params.reindex(target_index)["value"].to_numpy()) 214 | return vec 215 | -------------------------------------------------------------------------------- /src/skillmodels/params_index.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | 4 | def get_params_index(update_info, labels, dimensions, transition_info): 5 | """Generate index for the params_df for optimagic. 6 | 7 | The index has four levels. The first is the parameter category. The second is the 8 | period in which the parameters are used. The third and fourth are additional 9 | descriptors that depend on the category. If the fourth level is not really needed, 10 | it contains the placeholder "-". 11 | 12 | Args: 13 | update_info (pandas.DataFrame): DataFrame with one row per Kalman update needed 14 | in the likelihood function. See :ref:`update_info`. 15 | labels (dict): Dict of lists with labels for the model quantities like 16 | factors, periods, controls, stagemap and stages. See :ref:`labels` 17 | dimensions (dict): Dimensional information like n_latent_factors and n_mixtures. 18 | transition_info (dict): Information about the transition functions, including the parameter names of each factor's transition function.
19 | 20 | Returns: 21 | params_index (pd.MultiIndex) 22 | 23 | """ 24 | ind_tups = get_control_params_index_tuples(labels["controls"], update_info) 25 | ind_tups += get_loadings_index_tuples(labels["latent_factors"], update_info) 26 | ind_tups += get_meas_sds_index_tuples(update_info) 27 | ind_tups += get_shock_sds_index_tuples(labels["periods"], labels["latent_factors"]) 28 | ind_tups += initial_mean_index_tuples( 29 | dimensions["n_mixtures"], 30 | labels["latent_factors"], 31 | ) 32 | ind_tups += get_mixture_weights_index_tuples(dimensions["n_mixtures"]) 33 | ind_tups += get_initial_cholcovs_index_tuples( 34 | dimensions["n_mixtures"], 35 | labels["latent_factors"], 36 | ) 37 | ind_tups += get_transition_index_tuples(transition_info, labels["periods"]) 38 | 39 | index = pd.MultiIndex.from_tuples( 40 | ind_tups, 41 | names=["category", "period", "name1", "name2"], 42 | ) 43 | return index 44 | 45 | 46 | def get_control_params_index_tuples(controls, update_info): 47 | """Index tuples for control coeffs. 48 | 49 | Args: 50 | controls (list): List of lists. There is one sublist per period which contains 51 | the names of the control variables in that period. Constant not included. 52 | update_info (pandas.DataFrame): DataFrame with one row per Kalman update needed 53 | in the likelihood function. See :ref:`update_info`. 54 | 55 | """ 56 | ind_tups = [] 57 | for period, meas in update_info.index: 58 | for cont in controls: 59 | ind_tups.append(("controls", period, meas, cont)) 60 | return ind_tups 61 | 62 | 63 | def get_loadings_index_tuples(factors, update_info): 64 | """Index tuples for loading. 65 | 66 | Args: 67 | factors (list): The latent factors of the model 68 | update_info (pandas.DataFrame): DataFrame with one row per Kalman update needed 69 | in the likelihood function. See :ref:`update_info`. 70 | 71 | Returns: 72 | ind_tups (list) 73 | 74 | """ 75 | mask = update_info[factors].to_numpy() 76 | ind_tups = [] 77 | for i, (period, meas) in enumerate(update_info.index): 78 | for f, factor in enumerate(factors): 79 | if mask[i, f]: 80 | ind_tups.append(("loadings", period, meas, factor)) 81 | return ind_tups 82 | 83 | 84 | def get_meas_sds_index_tuples(update_info): 85 | """Index tuples for meas_sd. 86 | 87 | Args: 88 | update_info (pandas.DataFrame): DataFrame with one row per Kalman update needed 89 | in the likelihood function. See :ref:`update_info`. 90 | 91 | Returns: 92 | ind_tups (list) 93 | 94 | """ 95 | ind_tups = [] 96 | for period, meas in update_info.index: 97 | ind_tups.append(("meas_sds", period, meas, "-")) 98 | return ind_tups 99 | 100 | 101 | def get_shock_sds_index_tuples(periods, factors): 102 | """Index tuples for shock_sd. 103 | 104 | Args: 105 | periods (list): The periods of the model. 106 | factors (list): The latent factors of the model. 107 | 108 | Returns: 109 | ind_tups (list) 110 | 111 | """ 112 | ind_tups = [] 113 | for period in periods[:-1]: 114 | for factor in factors: 115 | ind_tups.append(("shock_sds", period, factor, "-")) 116 | return ind_tups 117 | 118 | 119 | def initial_mean_index_tuples(n_mixtures, factors): 120 | """Index tuples for initial_mean. 121 | 122 | Args: 123 | n_mixtures (int): Number of elements in the mixture distribution of the factors. 
124 | factors (list): The latent factors of the model 125 | 126 | Returns: 127 | ind_tups (list) 128 | 129 | """ 130 | ind_tups = [] 131 | for emf in range(n_mixtures): 132 | for factor in factors: 133 | ind_tups.append(("initial_states", 0, f"mixture_{emf}", factor)) 134 | return ind_tups 135 | 136 | 137 | def get_mixture_weights_index_tuples(n_mixtures): 138 | """Index tuples for mixture_weights. 139 | 140 | Args: 141 | n_mixtures (int): Number of elements in the mixture distribution of the factors. 142 | 143 | Returns: 144 | ind_tups (list) 145 | 146 | """ 147 | ind_tups = [] 148 | for emf in range(n_mixtures): 149 | ind_tups.append(("mixture_weights", 0, f"mixture_{emf}", "-")) 150 | return ind_tups 151 | 152 | 153 | def get_initial_cholcovs_index_tuples(n_mixtures, factors): 154 | """Index tuples for initial_cholcovs. 155 | 156 | Args: 157 | n_mixtures (int): Number of elements in the mixture distribution of the factors. 158 | factors (list): The latent factors of the model 159 | 160 | Returns: 161 | ind_tups (list) 162 | 163 | """ 164 | ind_tups = [] 165 | for emf in range(n_mixtures): 166 | for row, factor1 in enumerate(factors): 167 | for col, factor2 in enumerate(factors): 168 | if col <= row: 169 | ind_tups.append( 170 | ( 171 | "initial_cholcovs", 172 | 0, 173 | f"mixture_{emf}", 174 | f"{factor1}-{factor2}", 175 | ), 176 | ) 177 | return ind_tups 178 | 179 | 180 | def get_transition_index_tuples(transition_info, periods): 181 | """Index tuples for transition equation coefficients. 182 | 183 | Args: 184 | transition_info (dict): Information about the transition functions. Its 185 | "param_names" entry maps each latent factor to the names of the 186 | parameters of its transition function. 187 | periods (list): The periods of the model 188 | 189 | Returns: 190 | ind_tups (list) 191 | 192 | """ 193 | ind_tups = [] 194 | for factor, names in transition_info["param_names"].items(): 195 | for period in periods[:-1]: 196 | for name in names: 197 | ind_tups.append(("transition", period, factor, name)) 198 | return ind_tups 199 | -------------------------------------------------------------------------------- /src/skillmodels/process_data.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from skillmodels.process_model import get_period_measurements 8 | 9 | 10 | def process_data(df, labels, update_info, anchoring_info, purpose="estimation"): 11 | """Process the data for estimation. 12 | 13 | Args: 14 | df (DataFrame): panel dataset in long format. It has a MultiIndex 15 | where the first level indicates the individual and the second the period. 16 | labels (dict): Dict of lists with labels for the model quantities like 17 | factors, periods, controls, stagemap and stages. See :ref:`labels` 18 | update_info (pandas.DataFrame): DataFrame with one row per Kalman update needed 19 | in the likelihood function. See :ref:`update_info`. 20 | anchoring_info (dict): Information about anchoring. See :ref:`anchoring` 21 | purpose (str): If "estimation" (the default), measurement data is processed 22 | and returned in addition to the controls and observed factors. 23 | 24 | Returns: 25 | meas_data (jax.numpy.array): Array of shape (n_updates, n_obs) with data on 26 | observed measurements. NaN if the measurement was not observed.
27 | control_data (jax.numpy.array): Array of shape (n_periods, n_obs, n_controls) 28 | with observed control variables for the measurement equations. 29 | observed_factors (jax.numpy.array): Array of shape (n_periods, n_obs, 30 | n_observed_factors) with data on the observed factors. 31 | 32 | """ 33 | df = pre_process_data(df, labels["periods"]) 34 | df["constant"] = 1 35 | df = _add_copies_of_anchoring_outcome(df, anchoring_info) 36 | _check_data(df, update_info, labels, purpose=purpose) 37 | n_obs = int(len(df) / len(labels["periods"])) 38 | df = _handle_controls_with_missings(df, labels["controls"], update_info) 39 | if purpose == "estimation": 40 | meas_data = _generate_measurements_array(df, update_info, n_obs) 41 | control_data = _generate_controls_array(df, labels, n_obs) 42 | observed_data = _generate_observed_factor_array(df, labels, n_obs) 43 | 44 | if purpose == "estimation": 45 | out = (meas_data, control_data, observed_data) 46 | else: 47 | out = (control_data, observed_data) 48 | return out 49 | 50 | 51 | def pre_process_data(df, periods): 52 | """Balance panel data in long format, drop unnecessary periods and set index. 53 | 54 | Args: 55 | df (DataFrame): panel dataset in long format. It has a MultiIndex 56 | where the first level indicates the individual and the second the period. 57 | periods (list): The periods of the model. Data from other periods is dropped. 58 | 59 | Returns: 60 | balanced (DataFrame): balanced panel. It has a MultiIndex. The first 61 | enumerates individuals. The second level counts periods, starting at 0. 62 | 63 | """ 64 | df = df.sort_index() 65 | df["__old_id__"] = df.index.get_level_values(0) 66 | df["__old_period__"] = df.index.get_level_values(1) 67 | 68 | # replace existing codes for periods and individuals by integers starting at 0 69 | df.index.names = ["id", "period"] 70 | for level in [0, 1]: 71 | df.index = df.index.set_levels(range(len(df.index.levels[level])), level=level) 72 | 73 | # create new index 74 | ids = sorted(df.index.get_level_values("id").unique()) 75 | new_index = pd.MultiIndex.from_product([ids, periods], names=["id", "period"]) 76 | 77 | # set new index 78 | df = df.reindex(new_index) 79 | 80 | return df 81 | 82 | 83 | def _add_copies_of_anchoring_outcome(df, anchoring_info): 84 | df = df.copy() 85 | for factor in anchoring_info["factors"]: 86 | outcome = anchoring_info["outcomes"][factor] 87 | df[f"{outcome}_{factor}"] = df[outcome] 88 | return df 89 | 90 | 91 | def _check_data(df, update_info, labels, purpose): # noqa: C901 92 | var_report = pd.DataFrame(index=update_info.index[:0], columns=["problem"]) 93 | for period in labels["periods"]: 94 | period_data = df.query(f"period == {period}") 95 | for cont in labels["controls"]: 96 | if cont not in period_data.columns or period_data[cont].isna().all(): 97 | var_report.loc[(period, cont), "problem"] = "Variable is missing" 98 | 99 | if purpose == "estimation": 100 | for meas in get_period_measurements(update_info, period): 101 | if meas not in period_data.columns: 102 | var_report.loc[(period, meas), "problem"] = "Variable is missing" 103 | elif len(period_data[meas].dropna().unique()) == 1: 104 | var_report.loc[(period, meas), "problem"] = ( 105 | "Variable has no variance" 106 | ) 107 | 108 | for factor in labels["observed_factors"]: 109 | if factor not in period_data.columns: 110 | var_report.loc[(period, factor), "problem"] = "Variable is missing" 111 | elif period_data[factor].isna().any(): 112 | var_report.loc[(period, factor), "problem"] = "Variable has missings" 113 | 114 | var_report = var_report.to_string() if len(var_report) > 0 else "" 115 | 116 | if var_report: 117 | 
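        # Raising with the full report surfaces every problematic (period, variable) pair at once instead of failing one variable at a time.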
raise ValueError(var_report) 118 | 119 | 120 | def _handle_controls_with_missings(df, controls, update_info): 121 | periods = update_info.index.get_level_values(0).unique().tolist() 122 | problematic_index = df.index[:0] 123 | for period in periods: 124 | period_data = df.query(f"period == {period}") 125 | control_data = period_data[controls] 126 | meas_data = period_data[get_period_measurements(update_info, period)] 127 | problem = control_data.isna().any(axis=1) & meas_data.notna().any(axis=1) 128 | problematic_index = problematic_index.union(period_data[problem].index) 129 | 130 | if len(problematic_index) > 0: 131 | old_names = df.loc[problematic_index][["__old_id__", "__old_period__"]] 132 | msg = "Set measurements to NaN because there are NaNs in the controls for:\n{}" 133 | msg = msg.format(list(map(tuple, old_names.to_numpy().tolist()))) 134 | warnings.warn(msg) 135 | df.loc[problematic_index] = np.nan 136 | return df 137 | 138 | 139 | def _generate_measurements_array(df, update_info, n_obs): 140 | arr = np.zeros((len(update_info), n_obs)) 141 | for k, (period, var) in enumerate(update_info.index): 142 | arr[k] = df.query(f"period == {period}")[var].to_numpy() 143 | return jnp.array(arr, dtype="float32") 144 | 145 | 146 | def _generate_controls_array(df, labels, n_obs): 147 | arr = np.zeros((len(labels["periods"]), n_obs, len(labels["controls"]))) 148 | for period in labels["periods"]: 149 | arr[period] = df.query(f"period == {period}")[labels["controls"]].to_numpy() 150 | return jnp.array(arr, dtype="float32") 151 | 152 | 153 | def _generate_observed_factor_array(df, labels, n_obs): 154 | arr = np.zeros((len(labels["periods"]), n_obs, len(labels["observed_factors"]))) 155 | for period in labels["periods"]: 156 | arr[period] = df.query(f"period == {period}")[ 157 | labels["observed_factors"] 158 | ].to_numpy() 159 | return jnp.array(arr, dtype="float32") 160 | -------------------------------------------------------------------------------- /src/skillmodels/process_debug_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | def process_debug_data(debug_data, model): 6 | """Process the raw debug data into pandas objects that make visualization easy. 7 | 8 | Args: 9 | debug_data (dict): Dictionary containing the following entries ( 10 | and potentially others which are not modified): 11 | - filtered_states (jax.numpy.array): Array of shape (n_updates, n_obs, 12 | n_mixtures, n_states) containing the filtered states after each Kalman 13 | update. 14 | - initial_states (jax.numpy.array): Array of shape (n_obs, n_mixtures, n_states) 15 | with the state estimates before the first Kalman update. 16 | - residuals (jax.numpy.array): Array of shape (n_updates, n_obs, n_mixtures) 17 | containing the residuals of a Kalman update. 18 | - residual_sds (jax.numpy.ndarray): Array of shape (n_updates, n_obs, 19 | n_mixtures) containing the theoretical standard deviation of the residuals. 20 | - all_contributions (jax.numpy.array): Array of shape (n_updates, n_obs) with 21 | the likelihood contributions per update and individual. 22 | - log_mixture_weights (jax.numpy.array): Array of shape (n_updates, n_obs, 23 | n_mixtures) containing the log mixture weights after each update. 24 | - initial_log_mixture_weights (jax.numpy.array): Array of shape (n_obs, 25 | n_mixtures) containing the log mixture weights before the first 26 | kalman update. 27 | 28 | model (dict): Processed model dictionary. 
29 | 30 | Returns: 31 | dict: Dictionary with processed debug data. It has the following entries: 32 | 33 | - post_update_states (pd.DataFrame): Tidy DataFrame with the filtered states 34 | after each Kalman update. The columns are the factor names, "mixture", 35 | "period", "id" and "measurement", where "period" and "measurement" identify 36 | the last measurement that was incorporated. 37 | - filtered_states (pd.DataFrame). Tidy DataFrame with filtered states 38 | after the last update of each period. The columns are the factor names, 39 | "period" and "id". The filtered states are already aggregated over 40 | mixture distributions. 41 | - state_ranges (dict): The keys are the names of the latent factors. 42 | The values are DataFrames with the columns "period", "minimum", "maximum". 43 | Note that this aggregates over mixture distributions. 44 | - residuals (pd.DataFrame): Tidy DataFrame with residuals of each Kalman update. 45 | Columns are "residual", "mixture", "period", "measurement" and "id". 46 | "period" and "measurement" identify the Kalman update to which the residual 47 | belongs. 48 | - residual_sds (pd.DataFrame): As residuals but containing the theoretical 49 | standard deviation of the corresponding residual. 50 | - all_contributions (pd.DataFrame): Tidy DataFrame with log likelihood 51 | contribution per individual and Kalman Update. The columns are 52 | "contribution", "period", "measurement" and "id". "period" and "measurement" 53 | identify the Kalman Update to which the likelihood contribution corresponds. 54 | 55 | """ 54 | update_info = model["update_info"] 55 | factors = model["labels"]["latent_factors"] 56 | 57 | post_update_states = _create_post_update_states( 58 | debug_data["filtered_states"], 59 | factors, 60 | update_info, 61 | ) 62 | 63 | filtered_states = _create_filtered_states( 64 | filtered_states=debug_data["filtered_states"], 65 | log_mixture_weights=debug_data["log_mixture_weights"], 66 | update_info=update_info, 67 | factors=factors, 68 | ) 69 | 70 | state_ranges = create_state_ranges(filtered_states, factors) 71 | 72 | residuals = _process_residuals(debug_data["residuals"], update_info) 73 | residual_sds = _process_residual_sds(debug_data["residual_sds"], update_info) 74 | 75 | all_contributions = _process_all_contributions( 76 | debug_data["all_contributions"], 77 | update_info, 78 | ) 79 | 80 | res = { 81 | "post_update_states": post_update_states, 82 | "filtered_states": filtered_states, 83 | "state_ranges": state_ranges, 84 | "residuals": residuals, 85 | "residual_sds": residual_sds, 86 | "all_contributions": all_contributions, 87 | } 88 | 89 | for key in ["value", "contributions"]: 90 | if key in debug_data: 91 | res[key] = debug_data[key] 92 | 93 | return res 94 | 95 | 96 | def _create_post_update_states(filtered_states, factors, update_info): 97 | to_concat = [] 98 | for (period, meas), data in zip(update_info.index, filtered_states, strict=False): 99 | df = _convert_state_array_to_df(data, factors) 100 | df["period"] = period 101 | df["id"] = np.arange(len(df)) 102 | df["measurement"] = meas 103 | to_concat.append(df) 104 | 105 | post_states = pd.concat(to_concat) 106 | 107 | return post_states 108 | 109 | 110 | def _convert_state_array_to_df(arr, factor_names): 111 | """Convert a 3d state array into a 2d DataFrame. 112 | 113 | Args: 114 | arr (np.ndarray): Array of shape (n_obs, n_mixtures, n_states) 115 | factor_names (list): Names of the latent factors.
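Returns:
    pd.DataFrame: Frame of shape (n_obs * n_mixtures, n_states + 1) with one row
    per observation-mixture combination. Columns are the factor names plus a
    "mixture" column that identifies the mixture component.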
116 | """ 117 | n_obs, n_mixtures, n_states = arr.shape 118 | df = pd.DataFrame(data=arr.reshape(-1, n_states), columns=factor_names) 119 | df["mixture"] = np.full((n_obs, n_mixtures), np.arange(n_mixtures)).flatten() 120 | return df 121 | 122 | 123 | def _create_filtered_states(filtered_states, log_mixture_weights, update_info, factors): 124 | filtered_states = np.array(filtered_states) 125 | log_mixture_weights = np.array(log_mixture_weights) 126 | weights = np.exp(log_mixture_weights) 127 | 128 | agg_states = (filtered_states * weights.reshape(*weights.shape, 1)).sum(axis=-2) 129 | 130 | keep = [] 131 | for i, (period, measurement) in enumerate(update_info.index): 132 | last_measurement = update_info.query( 133 | f"purpose == 'measurement' & period == {period}", 134 | ).index[-1][1] 135 | 136 | if measurement == last_measurement: 137 | keep.append(i) 138 | 139 | to_concat = [] 140 | for period, i in enumerate(keep): 141 | df = pd.DataFrame(data=agg_states[i], columns=factors) 142 | df["period"] = period 143 | df["id"] = np.arange(len(df)) 144 | to_concat.append(df) 145 | 146 | filtered_states = pd.concat(to_concat) 147 | 148 | return filtered_states 149 | 150 | 151 | def create_state_ranges(filtered_states, factors): 152 | ranges = {} 153 | minima = filtered_states.groupby("period").min() 154 | maxima = filtered_states.groupby("period").max() 155 | for factor in factors: 156 | df = pd.concat([minima[factor], maxima[factor]], axis=1) 157 | df.columns = ["minimum", "maximum"] 158 | ranges[factor] = df 159 | return ranges 160 | 161 | 162 | def _process_residuals(residuals, update_info): 163 | to_concat = [] 164 | n_obs, n_mixtures = residuals[0].shape 165 | for (period, meas), data in zip(update_info.index, residuals, strict=False): 166 | df = pd.DataFrame(data.reshape(-1, 1), columns=["residual"]) 167 | df["mixture"] = np.full((n_obs, n_mixtures), np.arange(n_mixtures)).flatten() 168 | df["period"] = period 169 | df["id"] = np.arange(len(df)) 170 | df["measurement"] = meas 171 | to_concat.append(df) 172 | return pd.concat(to_concat) 173 | 174 | 175 | def _process_residual_sds(residual_sds, update_info): 176 | return _process_residuals(residual_sds, update_info) 177 | 178 | 179 | def _process_all_contributions(all_contributions, update_info): 180 | to_concat = [] 181 | for (period, meas), contribs in zip( 182 | update_info.index, all_contributions, strict=False 183 | ): 184 | df = pd.DataFrame(data=contribs.reshape(-1, 1), columns=["contribution"]) 185 | df["measurement"] = meas 186 | df["period"] = period 187 | df["id"] = np.arange(len(df)) 188 | to_concat.append(df) 189 | return pd.concat(to_concat) 190 | -------------------------------------------------------------------------------- /src/skillmodels/transition_functions.py: -------------------------------------------------------------------------------- 1 | """Contains transition functions and corresponding helper functions. 2 | 3 | Below, the signature and purpose of a transition function and its helper 4 | functions are explained using a transition function called example_func: 5 | 6 | 7 | **example_func(** *states, params**)**: 8 | 9 | The actual transition function. 10 | 11 | Args: 12 | * states: 1d numpy array of length n_all_factors 13 | * params: 1d numpy array with coefficients specific to this transition function 14 | 15 | Returns: 16 | * float 17 | 18 | 19 | **names_example_func(** *factors* **)**: 20 | 21 | Generate a list of names for the params of the transition function.
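Args:
    * factors: list with the names of all factors that enter the transition function

Returns:
    * list of str with one name per estimated parameter of the transition function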
22 | 23 | The names will be used to construct index tuples in the following way: 24 | 25 | ('transition', period, factor, NAME) 26 | 27 | The transition functions have to be JAX jittable and differentiable. However, they 28 | should not be jitted yet. 29 | 30 | """ 31 | 32 | from itertools import combinations 33 | 34 | import jax 35 | import jax.numpy as jnp 36 | 37 | 38 | def linear(states, params): 39 | """Linear production function where the constant is the last parameter.""" 40 | constant = params[-1] 41 | betas = params[:-1] 42 | return jnp.dot(states, betas) + constant 43 | 44 | 45 | def params_linear(factors): 46 | """Index tuples for linear transition function.""" 47 | return [*factors, "constant"] 48 | 49 | 50 | def translog(states, params): 51 | """Translog transition function. 52 | 53 | The name is a convention in the skill formation literature even though the function 54 | is better described as a linear in parameters transition function with squares and 55 | interaction terms of the states. 56 | 57 | """ 58 | nfac = len(states) 59 | constant = params[-1] 60 | lin_beta = params[:nfac] 61 | square_beta = params[nfac : 2 * nfac] 62 | inter_beta = params[2 * nfac : -1] 63 | 64 | res = jnp.dot(states, lin_beta) 65 | res += jnp.dot(states**2, square_beta) 66 | for p, (a, b) in zip(inter_beta, combinations(range(nfac), 2), strict=False): 67 | res += p * states[a] * states[b] 68 | res += constant 69 | return res 70 | 71 | 72 | def params_translog(factors): 73 | """Index tuples for the translog production function.""" 74 | names = ( 75 | factors 76 | + [f"{factor} ** 2" for factor in factors] 77 | + [f"{a} * {b}" for a, b in combinations(factors, 2)] 78 | + ["constant"] 79 | ) 80 | return names 81 | 82 | 83 | def log_ces(states, params): 84 | """Log CES production function (KLS version).""" 85 | phi = params[-1] 86 | gammas = params[:-1] 87 | scaling_factor = 1 / phi 88 | 89 | # note: once the b argument is supported in jax.scipy.special.logsumexp, we can set 90 | # b = gammas instead of adding the log of gammas to sigma_points * phi 91 | 92 | # the log step for gammas underflows for gamma = 0, but this is handled correctly 93 | # by logsumexp and does not raise a warning. 94 | unscaled = jax.scipy.special.logsumexp(jnp.log(gammas) + states * phi) 95 | result = unscaled * scaling_factor 96 | return result 97 | 98 | 99 | def params_log_ces(factors): 100 | """Index tuples for the log_ces production function.""" 101 | return [*factors, "phi"] 102 | 103 | 104 | def constraints_log_ces(factor, factors, period): 105 | names = params_log_ces(factors) 106 | loc = [("transition", period, factor, name) for name in names[:-1]] 107 | return {"loc": loc, "type": "probability"} 108 | 109 | 110 | def constant(state, params): # noqa: ARG001 111 | """Constant production function.""" 112 | return state 113 | 114 | 115 | def params_constant(factors): # noqa: ARG001 116 | """Index tuples for the constant production function.""" 117 | return [] 118 | 119 | 120 | def robust_translog(states, params): 121 | """Numerically robust version of the translog transition function. 122 | 123 | This function does a clipping of the state vector at +- 1e12 before calling 124 | the standard translog function. It has no effect on the results if the 125 | states do not get close to the clipping values and prevents overflows otherwise.
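The bound of 1e12 is far outside any plausible range of latent states, so in practice only numerically diverged states are affected.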
126 | 127 | The name is a convention in the skill formation literature even though the function 128 | is better described as a linear in parameters transition function with squares and 129 | interaction terms of the states. 130 | 131 | """ 132 | clipped_states = jnp.clip(states, -1e12, 1e12) 133 | return translog(clipped_states, params) 134 | 135 | 136 | def params_robust_translog(factors): 137 | return params_translog(factors) 138 | 139 | 140 | def linear_and_squares(states, params): 141 | """linear_and_squares transition function.""" 142 | nfac = len(states) 143 | constant = params[-1] 144 | lin_beta = params[:nfac] 145 | square_beta = params[nfac : 2 * nfac] 146 | 147 | res = jnp.dot(states, lin_beta) 148 | res += jnp.dot(states**2, square_beta) 149 | res += constant 150 | return res 151 | 152 | 153 | def params_linear_and_squares(factors): 154 | """Index tuples for the linear_and_squares production function.""" 155 | names = factors + [f"{factor} ** 2" for factor in factors] + ["constant"] 156 | return names 157 | 158 | 159 | def log_ces_general(states, params): 160 | """Generalized log_ces production function without known location and scale.""" 161 | n = states.shape[-1] 162 | tfp = params[-1] 163 | gammas = params[:n] 164 | sigmas = params[n : 2 * n] 165 | 166 | # note: once the b argument is supported in jax.scipy.special.logsumexp, we can set 167 | # b = gammas instead of adding the log of gammas to sigma_points * phi 168 | 169 | # the log step for gammas underflows for gamma = 0, but this is handled correctly 170 | # by logsumexp and does not raise a warning. 171 | unscaled = jax.scipy.special.logsumexp(jnp.log(gammas) + states * sigmas) 172 | result = unscaled * tfp 173 | return result 174 | 175 | 176 | def params_log_ces_general(factors): 177 | """Index tuples for the generalized log_ces production function.""" 178 | return factors + [f"sigma_{fac}" for fac in factors] + ["tfp"] 179 | -------------------------------------------------------------------------------- /src/skillmodels/utils_plotting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_layout_kwargs( 5 | layout_kwargs=None, 6 | legend_kwargs=None, 7 | title_kwargs=None, 8 | showlegend=False, 9 | columns=None, 10 | rows=None, 11 | ): 12 | """Define and update default kwargs for update_layout. 13 | 14 | Defines some default keyword arguments to update figure layout, such as 15 | title and legend. 
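Args:
    layout_kwargs (dict or NoneType): Keyword arguments that override the defaults.
    legend_kwargs (dict or NoneType): Keyword arguments that update the legend settings.
    title_kwargs (dict or NoneType): Keyword arguments that update the title settings.
    showlegend (bool): Whether to display the figure legend.
    columns (list or NoneType): If given, the figure width is set to 300 * len(columns).
    rows (list or NoneType): If given, the figure height is set to 300 * len(rows).

Returns:
    dict: Keyword arguments to pass to the figure's update_layout method.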
16 | 17 | """ 18 | default_kwargs = { 19 | "template": "simple_white", 20 | "xaxis_showgrid": False, 21 | "yaxis_showgrid": False, 22 | "legend": {}, 23 | "title": {}, 24 | "showlegend": showlegend, 25 | } 26 | if rows is not None: 27 | default_kwargs["height"] = 300 * len(rows) 28 | if columns is not None: 29 | default_kwargs["width"] = 300 * len(columns) 30 | if title_kwargs: 31 | default_kwargs["title"] = title_kwargs 32 | if legend_kwargs: 33 | default_kwargs["legend"].update(legend_kwargs) 34 | if layout_kwargs: 35 | default_kwargs.update(layout_kwargs) 36 | return default_kwargs 37 | 38 | 39 | def get_make_subplot_kwargs( 40 | sharex, 41 | sharey, 42 | column_order, 43 | row_order, 44 | make_subplot_kwargs, 45 | add_scenes=False, 46 | ): 47 | """Define and update keywargs for instantiating figure with subplots.""" 48 | nrows = len(row_order) 49 | ncols = len(column_order) 50 | default_kwargs = { 51 | "rows": nrows, 52 | "cols": ncols, 53 | "start_cell": "top-left", 54 | "print_grid": False, 55 | "shared_yaxes": sharey, 56 | "shared_xaxes": sharex, 57 | "horizontal_spacing": 1 / (ncols * 6), 58 | } 59 | if nrows > 1: 60 | default_kwargs["vertical_spacing"] = (1 / (nrows - 1)) / 4 61 | if not sharey: 62 | default_kwargs["horizontal_spacing"] = 2 * default_kwargs["horizontal_spacing"] 63 | if add_scenes: 64 | specs = np.array([[{}] * ncols] * nrows) 65 | for i in range(nrows): 66 | for j in range(ncols): 67 | if j > i: 68 | specs[i, j] = {"type": "scene"} 69 | default_kwargs["specs"] = specs.tolist() 70 | if make_subplot_kwargs is not None: 71 | default_kwargs.update(make_subplot_kwargs) 72 | return default_kwargs 73 | -------------------------------------------------------------------------------- /tests/model2.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | factors: 3 | fac1: 4 | measurements: 5 | - [y1, y2, y3] 6 | - [y1, y2, y3] 7 | - [y1, y2, y3] 8 | - [y1, y2, y3] 9 | - [y1, y2, y3] 10 | - [y1, y2, y3] 11 | - [y1, y2, y3] 12 | - [y1, y2, y3] 13 | transition_function: log_ces 14 | normalizations: 15 | loadings: 16 | - {y1: 1} 17 | - {y1: 1} 18 | - {y1: 1} 19 | - {y1: 1} 20 | - {y1: 1} 21 | - {y1: 1} 22 | - {y1: 1} 23 | - {y1: 1} 24 | fac2: 25 | measurements: 26 | - [y4, y5, y6] 27 | - [y4, y5, y6] 28 | - [y4, y5, y6] 29 | - [y4, y5, y6] 30 | - [y4, y5, y6] 31 | - [y4, y5, y6] 32 | - [y4, y5, y6] 33 | - [y4, y5, y6] 34 | transition_function: linear 35 | normalizations: 36 | loadings: 37 | - {y4: 1} 38 | - {y4: 1} 39 | - {y4: 1} 40 | - {y4: 1} 41 | - {y4: 1} 42 | - {y4: 1} 43 | - {y4: 1} 44 | - {y4: 1} 45 | fac3: 46 | measurements: 47 | - [y7, y8, y9] 48 | transition_function: constant 49 | normalizations: 50 | loadings: 51 | - {y7: 1} 52 | anchoring: 53 | outcomes: {fac1: Q1} 54 | free_controls: true 55 | free_constant: true 56 | free_loadings: true 57 | ignore_constant_when_anchoring: true 58 | controls: 59 | - x1 60 | stagemap: 61 | - 0 62 | - 0 63 | - 0 64 | - 0 65 | - 0 66 | - 0 67 | - 0 68 | estimation_options: 69 | robust_bounds: true 70 | bounds_distance: 0.001 71 | n_mixtures: 1 72 | -------------------------------------------------------------------------------- /tests/model2_correct_params_index.csv: -------------------------------------------------------------------------------- 1 | category,period,name1,name2 2 | controls,0,y1,constant 3 | controls,0,y1,x1 4 | controls,0,y2,constant 5 | controls,0,y2,x1 6 | controls,0,y3,constant 7 | controls,0,y3,x1 8 | controls,0,y4,constant 9 | controls,0,y4,x1 10 | 
controls,0,y5,constant 11 | controls,0,y5,x1 12 | controls,0,y6,constant 13 | controls,0,y6,x1 14 | controls,0,y7,constant 15 | controls,0,y7,x1 16 | controls,0,y8,constant 17 | controls,0,y8,x1 18 | controls,0,y9,constant 19 | controls,0,y9,x1 20 | controls,0,Q1_fac1,constant 21 | controls,0,Q1_fac1,x1 22 | controls,1,y1,constant 23 | controls,1,y1,x1 24 | controls,1,y2,constant 25 | controls,1,y2,x1 26 | controls,1,y3,constant 27 | controls,1,y3,x1 28 | controls,1,y4,constant 29 | controls,1,y4,x1 30 | controls,1,y5,constant 31 | controls,1,y5,x1 32 | controls,1,y6,constant 33 | controls,1,y6,x1 34 | controls,1,Q1_fac1,constant 35 | controls,1,Q1_fac1,x1 36 | controls,2,y1,constant 37 | controls,2,y1,x1 38 | controls,2,y2,constant 39 | controls,2,y2,x1 40 | controls,2,y3,constant 41 | controls,2,y3,x1 42 | controls,2,y4,constant 43 | controls,2,y4,x1 44 | controls,2,y5,constant 45 | controls,2,y5,x1 46 | controls,2,y6,constant 47 | controls,2,y6,x1 48 | controls,2,Q1_fac1,constant 49 | controls,2,Q1_fac1,x1 50 | controls,3,y1,constant 51 | controls,3,y1,x1 52 | controls,3,y2,constant 53 | controls,3,y2,x1 54 | controls,3,y3,constant 55 | controls,3,y3,x1 56 | controls,3,y4,constant 57 | controls,3,y4,x1 58 | controls,3,y5,constant 59 | controls,3,y5,x1 60 | controls,3,y6,constant 61 | controls,3,y6,x1 62 | controls,3,Q1_fac1,constant 63 | controls,3,Q1_fac1,x1 64 | controls,4,y1,constant 65 | controls,4,y1,x1 66 | controls,4,y2,constant 67 | controls,4,y2,x1 68 | controls,4,y3,constant 69 | controls,4,y3,x1 70 | controls,4,y4,constant 71 | controls,4,y4,x1 72 | controls,4,y5,constant 73 | controls,4,y5,x1 74 | controls,4,y6,constant 75 | controls,4,y6,x1 76 | controls,4,Q1_fac1,constant 77 | controls,4,Q1_fac1,x1 78 | controls,5,y1,constant 79 | controls,5,y1,x1 80 | controls,5,y2,constant 81 | controls,5,y2,x1 82 | controls,5,y3,constant 83 | controls,5,y3,x1 84 | controls,5,y4,constant 85 | controls,5,y4,x1 86 | controls,5,y5,constant 87 | controls,5,y5,x1 88 | controls,5,y6,constant 89 | controls,5,y6,x1 90 | controls,5,Q1_fac1,constant 91 | controls,5,Q1_fac1,x1 92 | controls,6,y1,constant 93 | controls,6,y1,x1 94 | controls,6,y2,constant 95 | controls,6,y2,x1 96 | controls,6,y3,constant 97 | controls,6,y3,x1 98 | controls,6,y4,constant 99 | controls,6,y4,x1 100 | controls,6,y5,constant 101 | controls,6,y5,x1 102 | controls,6,y6,constant 103 | controls,6,y6,x1 104 | controls,6,Q1_fac1,constant 105 | controls,6,Q1_fac1,x1 106 | controls,7,y1,constant 107 | controls,7,y1,x1 108 | controls,7,y2,constant 109 | controls,7,y2,x1 110 | controls,7,y3,constant 111 | controls,7,y3,x1 112 | controls,7,y4,constant 113 | controls,7,y4,x1 114 | controls,7,y5,constant 115 | controls,7,y5,x1 116 | controls,7,y6,constant 117 | controls,7,y6,x1 118 | controls,7,Q1_fac1,constant 119 | controls,7,Q1_fac1,x1 120 | loadings,0,y1,fac1 121 | loadings,0,y2,fac1 122 | loadings,0,y3,fac1 123 | loadings,0,y4,fac2 124 | loadings,0,y5,fac2 125 | loadings,0,y6,fac2 126 | loadings,0,y7,fac3 127 | loadings,0,y8,fac3 128 | loadings,0,y9,fac3 129 | loadings,0,Q1_fac1,fac1 130 | loadings,1,y1,fac1 131 | loadings,1,y2,fac1 132 | loadings,1,y3,fac1 133 | loadings,1,y4,fac2 134 | loadings,1,y5,fac2 135 | loadings,1,y6,fac2 136 | loadings,1,Q1_fac1,fac1 137 | loadings,2,y1,fac1 138 | loadings,2,y2,fac1 139 | loadings,2,y3,fac1 140 | loadings,2,y4,fac2 141 | loadings,2,y5,fac2 142 | loadings,2,y6,fac2 143 | loadings,2,Q1_fac1,fac1 144 | loadings,3,y1,fac1 145 | loadings,3,y2,fac1 146 | loadings,3,y3,fac1 147 | 
loadings,3,y4,fac2 148 | loadings,3,y5,fac2 149 | loadings,3,y6,fac2 150 | loadings,3,Q1_fac1,fac1 151 | loadings,4,y1,fac1 152 | loadings,4,y2,fac1 153 | loadings,4,y3,fac1 154 | loadings,4,y4,fac2 155 | loadings,4,y5,fac2 156 | loadings,4,y6,fac2 157 | loadings,4,Q1_fac1,fac1 158 | loadings,5,y1,fac1 159 | loadings,5,y2,fac1 160 | loadings,5,y3,fac1 161 | loadings,5,y4,fac2 162 | loadings,5,y5,fac2 163 | loadings,5,y6,fac2 164 | loadings,5,Q1_fac1,fac1 165 | loadings,6,y1,fac1 166 | loadings,6,y2,fac1 167 | loadings,6,y3,fac1 168 | loadings,6,y4,fac2 169 | loadings,6,y5,fac2 170 | loadings,6,y6,fac2 171 | loadings,6,Q1_fac1,fac1 172 | loadings,7,y1,fac1 173 | loadings,7,y2,fac1 174 | loadings,7,y3,fac1 175 | loadings,7,y4,fac2 176 | loadings,7,y5,fac2 177 | loadings,7,y6,fac2 178 | loadings,7,Q1_fac1,fac1 179 | meas_sds,0,y1,- 180 | meas_sds,0,y2,- 181 | meas_sds,0,y3,- 182 | meas_sds,0,y4,- 183 | meas_sds,0,y5,- 184 | meas_sds,0,y6,- 185 | meas_sds,0,y7,- 186 | meas_sds,0,y8,- 187 | meas_sds,0,y9,- 188 | meas_sds,0,Q1_fac1,- 189 | meas_sds,1,y1,- 190 | meas_sds,1,y2,- 191 | meas_sds,1,y3,- 192 | meas_sds,1,y4,- 193 | meas_sds,1,y5,- 194 | meas_sds,1,y6,- 195 | meas_sds,1,Q1_fac1,- 196 | meas_sds,2,y1,- 197 | meas_sds,2,y2,- 198 | meas_sds,2,y3,- 199 | meas_sds,2,y4,- 200 | meas_sds,2,y5,- 201 | meas_sds,2,y6,- 202 | meas_sds,2,Q1_fac1,- 203 | meas_sds,3,y1,- 204 | meas_sds,3,y2,- 205 | meas_sds,3,y3,- 206 | meas_sds,3,y4,- 207 | meas_sds,3,y5,- 208 | meas_sds,3,y6,- 209 | meas_sds,3,Q1_fac1,- 210 | meas_sds,4,y1,- 211 | meas_sds,4,y2,- 212 | meas_sds,4,y3,- 213 | meas_sds,4,y4,- 214 | meas_sds,4,y5,- 215 | meas_sds,4,y6,- 216 | meas_sds,4,Q1_fac1,- 217 | meas_sds,5,y1,- 218 | meas_sds,5,y2,- 219 | meas_sds,5,y3,- 220 | meas_sds,5,y4,- 221 | meas_sds,5,y5,- 222 | meas_sds,5,y6,- 223 | meas_sds,5,Q1_fac1,- 224 | meas_sds,6,y1,- 225 | meas_sds,6,y2,- 226 | meas_sds,6,y3,- 227 | meas_sds,6,y4,- 228 | meas_sds,6,y5,- 229 | meas_sds,6,y6,- 230 | meas_sds,6,Q1_fac1,- 231 | meas_sds,7,y1,- 232 | meas_sds,7,y2,- 233 | meas_sds,7,y3,- 234 | meas_sds,7,y4,- 235 | meas_sds,7,y5,- 236 | meas_sds,7,y6,- 237 | meas_sds,7,Q1_fac1,- 238 | shock_sds,0,fac1,- 239 | shock_sds,0,fac2,- 240 | shock_sds,0,fac3,- 241 | shock_sds,1,fac1,- 242 | shock_sds,1,fac2,- 243 | shock_sds,1,fac3,- 244 | shock_sds,2,fac1,- 245 | shock_sds,2,fac2,- 246 | shock_sds,2,fac3,- 247 | shock_sds,3,fac1,- 248 | shock_sds,3,fac2,- 249 | shock_sds,3,fac3,- 250 | shock_sds,4,fac1,- 251 | shock_sds,4,fac2,- 252 | shock_sds,4,fac3,- 253 | shock_sds,5,fac1,- 254 | shock_sds,5,fac2,- 255 | shock_sds,5,fac3,- 256 | shock_sds,6,fac1,- 257 | shock_sds,6,fac2,- 258 | shock_sds,6,fac3,- 259 | initial_states,0,mixture_0,fac1 260 | initial_states,0,mixture_0,fac2 261 | initial_states,0,mixture_0,fac3 262 | mixture_weights,0,mixture_0,- 263 | initial_cholcovs,0,mixture_0,fac1-fac1 264 | initial_cholcovs,0,mixture_0,fac2-fac1 265 | initial_cholcovs,0,mixture_0,fac2-fac2 266 | initial_cholcovs,0,mixture_0,fac3-fac1 267 | initial_cholcovs,0,mixture_0,fac3-fac2 268 | initial_cholcovs,0,mixture_0,fac3-fac3 269 | transition,0,fac1,fac1 270 | transition,0,fac1,fac2 271 | transition,0,fac1,fac3 272 | transition,0,fac1,phi 273 | transition,1,fac1,fac1 274 | transition,1,fac1,fac2 275 | transition,1,fac1,fac3 276 | transition,1,fac1,phi 277 | transition,2,fac1,fac1 278 | transition,2,fac1,fac2 279 | transition,2,fac1,fac3 280 | transition,2,fac1,phi 281 | transition,3,fac1,fac1 282 | transition,3,fac1,fac2 283 | transition,3,fac1,fac3 284 | 
transition,3,fac1,phi 285 | transition,4,fac1,fac1 286 | transition,4,fac1,fac2 287 | transition,4,fac1,fac3 288 | transition,4,fac1,phi 289 | transition,5,fac1,fac1 290 | transition,5,fac1,fac2 291 | transition,5,fac1,fac3 292 | transition,5,fac1,phi 293 | transition,6,fac1,fac1 294 | transition,6,fac1,fac2 295 | transition,6,fac1,fac3 296 | transition,6,fac1,phi 297 | transition,0,fac2,fac1 298 | transition,0,fac2,fac2 299 | transition,0,fac2,fac3 300 | transition,0,fac2,constant 301 | transition,1,fac2,fac1 302 | transition,1,fac2,fac2 303 | transition,1,fac2,fac3 304 | transition,1,fac2,constant 305 | transition,2,fac2,fac1 306 | transition,2,fac2,fac2 307 | transition,2,fac2,fac3 308 | transition,2,fac2,constant 309 | transition,3,fac2,fac1 310 | transition,3,fac2,fac2 311 | transition,3,fac2,fac3 312 | transition,3,fac2,constant 313 | transition,4,fac2,fac1 314 | transition,4,fac2,fac2 315 | transition,4,fac2,fac3 316 | transition,4,fac2,constant 317 | transition,5,fac2,fac1 318 | transition,5,fac2,fac2 319 | transition,5,fac2,fac3 320 | transition,5,fac2,constant 321 | transition,6,fac2,fac1 322 | transition,6,fac2,fac2 323 | transition,6,fac2,fac3 324 | transition,6,fac2,constant 325 | -------------------------------------------------------------------------------- /tests/model2_correct_update_info.csv: -------------------------------------------------------------------------------- 1 | period,variable,fac1,fac2,fac3,purpose 2 | 0,y1,True,False,False,measurement 3 | 0,y2,True,False,False,measurement 4 | 0,y3,True,False,False,measurement 5 | 0,y4,False,True,False,measurement 6 | 0,y5,False,True,False,measurement 7 | 0,y6,False,True,False,measurement 8 | 0,y7,False,False,True,measurement 9 | 0,y8,False,False,True,measurement 10 | 0,y9,False,False,True,measurement 11 | 0,Q1_fac1,True,False,False,anchoring 12 | 1,y1,True,False,False,measurement 13 | 1,y2,True,False,False,measurement 14 | 1,y3,True,False,False,measurement 15 | 1,y4,False,True,False,measurement 16 | 1,y5,False,True,False,measurement 17 | 1,y6,False,True,False,measurement 18 | 1,Q1_fac1,True,False,False,anchoring 19 | 2,y1,True,False,False,measurement 20 | 2,y2,True,False,False,measurement 21 | 2,y3,True,False,False,measurement 22 | 2,y4,False,True,False,measurement 23 | 2,y5,False,True,False,measurement 24 | 2,y6,False,True,False,measurement 25 | 2,Q1_fac1,True,False,False,anchoring 26 | 3,y1,True,False,False,measurement 27 | 3,y2,True,False,False,measurement 28 | 3,y3,True,False,False,measurement 29 | 3,y4,False,True,False,measurement 30 | 3,y5,False,True,False,measurement 31 | 3,y6,False,True,False,measurement 32 | 3,Q1_fac1,True,False,False,anchoring 33 | 4,y1,True,False,False,measurement 34 | 4,y2,True,False,False,measurement 35 | 4,y3,True,False,False,measurement 36 | 4,y4,False,True,False,measurement 37 | 4,y5,False,True,False,measurement 38 | 4,y6,False,True,False,measurement 39 | 4,Q1_fac1,True,False,False,anchoring 40 | 5,y1,True,False,False,measurement 41 | 5,y2,True,False,False,measurement 42 | 5,y3,True,False,False,measurement 43 | 5,y4,False,True,False,measurement 44 | 5,y5,False,True,False,measurement 45 | 5,y6,False,True,False,measurement 46 | 5,Q1_fac1,True,False,False,anchoring 47 | 6,y1,True,False,False,measurement 48 | 6,y2,True,False,False,measurement 49 | 6,y3,True,False,False,measurement 50 | 6,y4,False,True,False,measurement 51 | 6,y5,False,True,False,measurement 52 | 6,y6,False,True,False,measurement 53 | 6,Q1_fac1,True,False,False,anchoring 54 | 7,y1,True,False,False,measurement 55 | 
7,y2,True,False,False,measurement 56 | 7,y3,True,False,False,measurement 57 | 7,y4,False,True,False,measurement 58 | 7,y5,False,True,False,measurement 59 | 7,y6,False,True,False,measurement 60 | 7,Q1_fac1,True,False,False,anchoring 61 | -------------------------------------------------------------------------------- /tests/model2_simulated_data.dta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenSourceEconomics/skillmodels/8417b4b9afe33ca9cf249136bae5608a1ca43bc5/tests/model2_simulated_data.dta -------------------------------------------------------------------------------- /tests/test_clipping.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | 4 | from skillmodels.clipping import soft_clipping 5 | 6 | 7 | def test_one_sided_soft_maximum(): 8 | arr = jnp.array([-10.0, -5, -1, 1, 5, 10]) 9 | lower_bound = -8 10 | lower_hardness = 3 11 | expected = [] 12 | for x in arr: 13 | exp_part = jnp.exp(lower_hardness * x) + jnp.exp(lower_hardness * lower_bound) 14 | entry = jnp.log(exp_part) / lower_hardness 15 | expected.append(entry) 16 | 17 | res = soft_clipping(arr=arr, lower=lower_bound, lower_hardness=lower_hardness) 18 | # compare to calculation "by hand" 19 | np.testing.assert_allclose(res, np.array(expected)) 20 | # compare that upper part is very close to true values 21 | np.testing.assert_allclose(res[1:], arr[1:], rtol=1e-05) 22 | 23 | 24 | def test_one_sided_soft_minimum(): 25 | arr = jnp.array([-10.0, -5, -1, 1, 5, 10]) 26 | upper_bound = 8 27 | upper_hardness = 3 28 | 29 | expected = [] 30 | for x in arr: 31 | # min(x, y) = -max(-x, -y) 32 | # min(3, 5) = 3 = -max(-3, -5) = -(-3) 33 | exp_part = jnp.exp(-upper_hardness * x) + jnp.exp(-upper_hardness * upper_bound) 34 | entry = -jnp.log(exp_part) / upper_hardness 35 | expected.append(entry) 36 | 37 | res = soft_clipping(arr=arr, upper=upper_bound, upper_hardness=upper_hardness) 38 | 39 | # compare to calculation "by hand" 40 | np.testing.assert_allclose(res, np.array(expected)) 41 | 42 | # compare that the lower part is very close to true values 43 | np.testing.assert_allclose(res[:-1], arr[:-1], rtol=1e-05) 44 | -------------------------------------------------------------------------------- /tests/test_decorators.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from skillmodels.decorators import extract_params, jax_array_output, register_params 4 | 5 | 6 | def test_extract_params_decorator_only_key(): 7 | @extract_params(key="a") 8 | def f(x, params): 9 | return x * params 10 | 11 | assert f(x=3, params={"a": 4, "b": 5}) == 12 12 | 13 | 14 | def test_extract_params_direct_call_only_key(): 15 | def f(x, params): 16 | return x * params 17 | 18 | g = extract_params(f, key="a") 19 | 20 | assert g(x=3, params={"a": 4, "b": 5}) == 12 21 | 22 | 23 | def test_extract_params_decorator_only_names(): 24 | @extract_params(names=["c", "d"]) 25 | def f(x, params): 26 | return x * params["c"] 27 | 28 | assert f(x=3, params=[4, 5]) == 12 29 | 30 | 31 | def test_extract_params_direct_call_only_names(): 32 | def f(x, params): 33 | return x * params["c"] 34 | 35 | g = extract_params(f, names=["c", "d"]) 36 | assert g(x=3, params=[4, 5]) == 12 37 | 38 | 39 | def test_extract_params_decorator_key_and_names(): 40 | @extract_params(key="a", names=["c", "d"]) 41 | def f(x, params): 42 | return x * params["c"] 43 | 44 | assert 
f(x=3, params={"a": [4, 5], "b": [5, 6]}) == 12 45 | 46 | 47 | def test_extract_params_direct_call_key_and_names(): 48 | def f(x, params): 49 | return x * params["c"] 50 | 51 | g = extract_params(f, key="a", names=["c", "d"]) 52 | assert g(x=3, params={"a": [4, 5], "b": [5, 6]}) == 12 53 | 54 | 55 | def test_jax_array_output_decorator(): 56 | @jax_array_output 57 | def f(): 58 | return (1, 2, 3) 59 | 60 | assert isinstance(f(), jnp.ndarray) 61 | 62 | 63 | def test_jax_array_output_direct_call(): 64 | def f(): 65 | return (1, 2, 3) 66 | 67 | g = jax_array_output(f) 68 | 69 | assert isinstance(g(), jnp.ndarray) 70 | 71 | 72 | def test_register_params_decorator(): 73 | @register_params(params=["a", "b", "c"]) 74 | def f(): 75 | return "bla" 76 | 77 | assert f.__registered_params__ == ["a", "b", "c"] 78 | assert f() == "bla" 79 | 80 | 81 | def test_register_params_direct_call(): 82 | def f(): 83 | return "bla" 84 | 85 | g = register_params(f, params=["a", "b", "c"]) 86 | assert g.__registered_params__ == ["a", "b", "c"] 87 | assert g() == "bla" 88 | -------------------------------------------------------------------------------- /tests/test_filtered_states.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | import yaml 7 | 8 | from skillmodels.filtered_states import get_filtered_states 9 | from skillmodels.maximization_inputs import get_maximization_inputs 10 | 11 | # importing the TEST_DIR from config does not work for test run in conda build 12 | TEST_DIR = Path(__file__).parent.resolve() 13 | 14 | 15 | @pytest.fixture 16 | def model2(): 17 | with open(TEST_DIR / "model2.yaml") as y: 18 | model_dict = yaml.load(y, Loader=yaml.FullLoader) 19 | return model_dict 20 | 21 | 22 | @pytest.fixture 23 | def model2_data(): 24 | data = pd.read_stata(TEST_DIR / "model2_simulated_data.dta") 25 | data = data.set_index(["caseid", "period"]) 26 | return data 27 | 28 | 29 | def test_get_filtered_states(model2, model2_data): 30 | params = pd.read_csv(TEST_DIR / "regression_vault" / "one_stage_anchoring.csv") 31 | params = params.set_index(["category", "period", "name1", "name2"]) 32 | 33 | max_inputs = get_maximization_inputs(model2, model2_data) 34 | params = params.loc[max_inputs["params_template"].index] 35 | 36 | calculated = get_filtered_states(model_dict=model2, data=model2_data, params=params) 37 | 38 | factors = ["fac1", "fac2", "fac3"] 39 | expected_ratios = [1.187757, 1, 1] 40 | for factor, expected_ratio in zip(factors, expected_ratios, strict=False): 41 | anch_ranges = calculated["anchored_states"]["state_ranges"][factor] 42 | unanch_ranges = calculated["unanchored_states"]["state_ranges"][factor] 43 | ratio = (anch_ranges / unanch_ranges).to_numpy() 44 | assert np.allclose(ratio, expected_ratio) 45 | -------------------------------------------------------------------------------- /tests/test_likelihood_regression.py: -------------------------------------------------------------------------------- 1 | import json 2 | from itertools import product 3 | from pathlib import Path 4 | 5 | import jax 6 | import numpy as np 7 | import pandas as pd 8 | import pytest 9 | import yaml 10 | from numpy.testing import assert_array_almost_equal as aaae 11 | 12 | from skillmodels.decorators import register_params 13 | from skillmodels.maximization_inputs import get_maximization_inputs 14 | from skillmodels.utilities import reduce_n_periods 15 | 16 | 
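# Use float64 precision so the likelihood values computed in these tests are comparable to the high-precision results stored in the regression vault.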
jax.config.update("jax_enable_x64", True) 17 | 18 | MODEL_NAMES = [ 19 | "no_stages_anchoring", 20 | "one_stage", 21 | "one_stage_anchoring", 22 | "two_stages_anchoring", 23 | "one_stage_anchoring_custom_functions", 24 | ] 25 | 26 | # importing the TEST_DIR from config does not work for test run in conda build 27 | TEST_DIR = Path(__file__).parent.resolve() 28 | 29 | 30 | @pytest.fixture 31 | def model2(): 32 | with open(TEST_DIR / "model2.yaml") as y: 33 | model_dict = yaml.load(y, Loader=yaml.FullLoader) 34 | return model_dict 35 | 36 | 37 | @pytest.fixture 38 | def model2_data(): 39 | data = pd.read_stata(TEST_DIR / "model2_simulated_data.dta") 40 | data = data.set_index(["caseid", "period"]) 41 | return data 42 | 43 | 44 | def _convert_model(base_model, model_name): 45 | model = base_model.copy() 46 | if model_name == "no_stages_anchoring": 47 | model.pop("stagemap") 48 | elif model_name == "one_stage": 49 | model.pop("anchoring") 50 | elif model_name == "one_stage_anchoring": 51 | pass 52 | elif model_name == "two_stages_anchoring": 53 | model["stagemap"] = [0, 0, 0, 0, 1, 1, 1] 54 | elif model_name == "one_stage_anchoring_custom_functions": 55 | 56 | @register_params(params=[]) 57 | def constant(fac3, params): 58 | return fac3 59 | 60 | @register_params(params=["fac1", "fac2", "fac3", "constant"]) 61 | def linear(fac1, fac2, fac3, params): 62 | p = params 63 | out = p["constant"] + fac1 * p["fac1"] + fac2 * p["fac2"] + fac3 * p["fac3"] 64 | return out 65 | 66 | model["factors"]["fac2"]["transition_function"] = linear 67 | model["factors"]["fac3"]["transition_function"] = constant 68 | else: 69 | raise ValueError("Invalid model name.") 70 | return model 71 | 72 | 73 | @pytest.mark.parametrize( 74 | ("model_name", "fun_key"), product(MODEL_NAMES, ["loglike", "debug_loglike"]) 75 | ) 76 | def test_likelihood_values_have_not_changed(model2, model2_data, model_name, fun_key): 77 | regvault = TEST_DIR / "regression_vault" 78 | model = _convert_model(model2, model_name) 79 | params = pd.read_csv(regvault / f"{model_name}.csv").set_index( 80 | ["category", "period", "name1", "name2"], 81 | ) 82 | 83 | inputs = get_maximization_inputs(model, model2_data) 84 | 85 | params = params.loc[inputs["params_template"].index] 86 | 87 | fun = inputs[fun_key] 88 | new_loglike = fun(params)["value"] if "debug" in fun_key else fun(params) 89 | 90 | with open(regvault / f"{model_name}_result.json") as j: 91 | old_loglike = np.array(json.load(j)).sum() 92 | aaae(new_loglike, old_loglike) 93 | 94 | 95 | @pytest.mark.parametrize( 96 | ("model_name", "fun_key"), product(MODEL_NAMES, ["loglikeobs"]) 97 | ) 98 | def test_likelihood_contributions_have_not_changed( 99 | model2, model2_data, model_name, fun_key 100 | ): 101 | regvault = TEST_DIR / "regression_vault" 102 | model = _convert_model(model2, model_name) 103 | params = pd.read_csv(regvault / f"{model_name}.csv").set_index( 104 | ["category", "period", "name1", "name2"], 105 | ) 106 | 107 | inputs = get_maximization_inputs(model, model2_data) 108 | 109 | params = params.loc[inputs["params_template"].index] 110 | 111 | fun = inputs[fun_key] 112 | new_loglikes = fun(params)["contributions"] if "debug" in fun_key else fun(params) 113 | 114 | with open(regvault / f"{model_name}_result.json") as j: 115 | old_loglikes = np.array(json.load(j)) 116 | aaae(new_loglikes, old_loglikes) 117 | 118 | 119 | @pytest.mark.parametrize( 120 | ("model_type", "fun_key"), 121 | product(["no_stages_anchoring", "with_missings"], ["loglike_and_gradient"]), 122 | ) 123 | def 
test_likelihood_contributions_large_nobs(model2, model2_data, model_type, fun_key): 124 | regvault = TEST_DIR / "regression_vault" 125 | model = _convert_model(model2, "no_stages_anchoring") 126 | params = pd.read_csv(regvault / "no_stages_anchoring.csv").set_index( 127 | ["category", "period", "name1", "name2"], 128 | ) 129 | 130 | to_concat = [model2_data] 131 | idx_names = model2_data.index.names 132 | n_repetitions = 5 133 | n_ids = model2_data.index.get_level_values("caseid").max() 134 | for i in range(1, 1 + n_repetitions): 135 | increment = i * n_ids 136 | this_round = model2_data.copy().reset_index() 137 | for col in ("caseid", "id"): 138 | this_round[col] += increment 139 | this_round = this_round.set_index(idx_names) 140 | cols = [ 141 | "y1", 142 | "y2", 143 | "y3", 144 | "y4", 145 | "y5", 146 | "y6", 147 | "y7", 148 | "y8", 149 | "y9", 150 | "Q1", 151 | "dy7", 152 | "dy8", 153 | "dy9", 154 | "x1", 155 | ] 156 | if model_type == "no_stages_anchoring": 157 | for col in cols: 158 | this_round[col] += np.random.normal(0, 0.1, (len(model2_data),)) 159 | elif model_type == "with_missings": 160 | fraction_to_set_missing = 0.9 161 | n_rows = len(this_round) 162 | n_missing = int(n_rows * fraction_to_set_missing) 163 | rows_to_set_missing = this_round.sample(n=n_missing).index 164 | this_round.loc[rows_to_set_missing, cols] = np.nan 165 | else: 166 | raise ValueError(f"Invalid model type: {model_type}") 167 | to_concat.append(this_round) 168 | 169 | stacked_data = pd.concat(to_concat) 170 | 171 | inputs = get_maximization_inputs(model, stacked_data) 172 | 173 | params = params.loc[inputs["params_template"].index] 174 | 175 | loglike = inputs[fun_key](params) 176 | 177 | assert np.isfinite(loglike[0]) 178 | assert np.isfinite(loglike[1]).all() 179 | 180 | 181 | def test_likelihood_runs_with_empty_periods(model2, model2_data): 182 | del model2["anchoring"] 183 | for factor in ["fac1", "fac2"]: 184 | model2["factors"][factor]["measurements"][-1] = [] 185 | model2["factors"][factor]["normalizations"]["loadings"][-1] = {} 186 | 187 | func_dict = get_maximization_inputs(model2, model2_data) 188 | 189 | params = func_dict["params_template"] 190 | params["value"] = 0.1 191 | 192 | debug_loglike = func_dict["debug_loglike"] 193 | debug_loglike(params) 194 | 195 | 196 | def test_likelihood_runs_with_too_long_data(model2, model2_data): 197 | model = reduce_n_periods(model2, 2) 198 | func_dict = get_maximization_inputs(model, model2_data) 199 | 200 | params = func_dict["params_template"] 201 | params["value"] = 0.1 202 | 203 | debug_loglike = func_dict["debug_loglike"] 204 | debug_loglike(params) 205 | 206 | 207 | def test_likelihood_runs_with_observed_factors(model2, model2_data): 208 | model2["observed_factors"] = ["ob1", "ob2"] 209 | model2_data["ob1"] = np.arange(len(model2_data)) 210 | model2_data["ob2"] = np.ones(len(model2_data)) 211 | func_dict = get_maximization_inputs(model2, model2_data) 212 | 213 | params = func_dict["params_template"] 214 | params["value"] = 0.1 215 | 216 | debug_loglike = func_dict["debug_loglike"] 217 | debug_loglike(params) 218 | -------------------------------------------------------------------------------- /tests/test_maximization_inputs.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | 4 | from skillmodels.maximization_inputs import _to_numpy 5 | 6 | 7 | def test_to_numpy_with_dict(): 8 | dict_ = {"a": jnp.ones(3), "b": 4.5} 9 | calculated = _to_numpy(dict_) 10 | assert 
isinstance(calculated["a"], np.ndarray) 11 | assert isinstance(calculated["b"], float) 12 | 13 | 14 | def test_to_numpy_one_array(): 15 | calculated = _to_numpy(jnp.ones(3)) 16 | assert isinstance(calculated, np.ndarray) 17 | 18 | 19 | def test_to_numpy_one_float(): 20 | calculated = _to_numpy(3.5) 21 | assert isinstance(calculated, float) 22 | -------------------------------------------------------------------------------- /tests/test_params_index.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pandas as pd 4 | import pytest 5 | import yaml 6 | 7 | from skillmodels.params_index import ( 8 | get_control_params_index_tuples, 9 | get_initial_cholcovs_index_tuples, 10 | get_loadings_index_tuples, 11 | get_meas_sds_index_tuples, 12 | get_mixture_weights_index_tuples, 13 | get_params_index, 14 | get_shock_sds_index_tuples, 15 | get_transition_index_tuples, 16 | initial_mean_index_tuples, 17 | ) 18 | from skillmodels.process_model import process_model 19 | 20 | 21 | @pytest.fixture 22 | def model2_inputs(): 23 | test_dir = Path(__file__).parent.resolve() 24 | with open(test_dir / "model2.yaml") as y: 25 | model_dict = yaml.load(y, Loader=yaml.FullLoader) 26 | processed = process_model(model_dict) 27 | 28 | out = { 29 | "update_info": processed["update_info"], 30 | "labels": processed["labels"], 31 | "dimensions": processed["dimensions"], 32 | "transition_info": processed["transition_info"], 33 | } 34 | return out 35 | 36 | 37 | def test_params_index_with_model2(model2_inputs): 38 | test_dir = Path(__file__).parent.resolve() 39 | calculated = get_params_index(**model2_inputs) 40 | expected = pd.read_csv( 41 | test_dir / "model2_correct_params_index.csv", 42 | index_col=["category", "period", "name1", "name2"], 43 | ).index 44 | 45 | assert calculated.equals(expected) 46 | 47 | 48 | def test_control_coeffs_index_tuples(): 49 | uinfo_tups = [(0, "m1"), (0, "m2"), (0, "bla"), (1, "m1"), (1, "m2")] 50 | uinfo = pd.DataFrame(index=pd.MultiIndex.from_tuples(uinfo_tups)) 51 | controls = ["constant", "c1"] 52 | 53 | expected = [ 54 | ("controls", 0, "m1", "constant"), 55 | ("controls", 0, "m1", "c1"), 56 | ("controls", 0, "m2", "constant"), 57 | ("controls", 0, "m2", "c1"), 58 | ("controls", 0, "bla", "constant"), 59 | ("controls", 0, "bla", "c1"), 60 | ("controls", 1, "m1", "constant"), 61 | ("controls", 1, "m1", "c1"), 62 | ("controls", 1, "m2", "constant"), 63 | ("controls", 1, "m2", "c1"), 64 | ] 65 | 66 | calculated = get_control_params_index_tuples(controls, uinfo) 67 | assert calculated == expected 68 | 69 | 70 | def test_loading_index_tuples(): 71 | uinfo_tups = [(0, "m1"), (0, "m2"), (0, "bla"), (1, "m1"), (1, "m2")] 72 | uinfo = pd.DataFrame( 73 | True, 74 | index=pd.MultiIndex.from_tuples(uinfo_tups), 75 | columns=["fac1", "fac2"], 76 | ) 77 | factors = ["fac1", "fac2"] 78 | expected = [ 79 | ("loadings", 0, "m1", "fac1"), 80 | ("loadings", 0, "m1", "fac2"), 81 | ("loadings", 0, "m2", "fac1"), 82 | ("loadings", 0, "m2", "fac2"), 83 | ("loadings", 0, "bla", "fac1"), 84 | ("loadings", 0, "bla", "fac2"), 85 | ("loadings", 1, "m1", "fac1"), 86 | ("loadings", 1, "m1", "fac2"), 87 | ("loadings", 1, "m2", "fac1"), 88 | ("loadings", 1, "m2", "fac2"), 89 | ] 90 | 91 | calculated = get_loadings_index_tuples(factors, uinfo) 92 | assert calculated == expected 93 | 94 | 95 | def test_meas_sd_index_tuples(): 96 | uinfo_tups = [(0, "m1"), (0, "m2"), (0, "bla"), (1, "m1"), (1, "m2")] 97 | uinfo = 
def test_control_coeffs_index_tuples():
    uinfo_tups = [(0, "m1"), (0, "m2"), (0, "bla"), (1, "m1"), (1, "m2")]
    uinfo = pd.DataFrame(index=pd.MultiIndex.from_tuples(uinfo_tups))
    controls = ["constant", "c1"]

    expected = [
        ("controls", 0, "m1", "constant"),
        ("controls", 0, "m1", "c1"),
        ("controls", 0, "m2", "constant"),
        ("controls", 0, "m2", "c1"),
        ("controls", 0, "bla", "constant"),
        ("controls", 0, "bla", "c1"),
        ("controls", 1, "m1", "constant"),
        ("controls", 1, "m1", "c1"),
        ("controls", 1, "m2", "constant"),
        ("controls", 1, "m2", "c1"),
    ]

    calculated = get_control_params_index_tuples(controls, uinfo)
    assert calculated == expected


def test_loading_index_tuples():
    uinfo_tups = [(0, "m1"), (0, "m2"), (0, "bla"), (1, "m1"), (1, "m2")]
    uinfo = pd.DataFrame(
        True,
        index=pd.MultiIndex.from_tuples(uinfo_tups),
        columns=["fac1", "fac2"],
    )
    factors = ["fac1", "fac2"]
    expected = [
        ("loadings", 0, "m1", "fac1"),
        ("loadings", 0, "m1", "fac2"),
        ("loadings", 0, "m2", "fac1"),
        ("loadings", 0, "m2", "fac2"),
        ("loadings", 0, "bla", "fac1"),
        ("loadings", 0, "bla", "fac2"),
        ("loadings", 1, "m1", "fac1"),
        ("loadings", 1, "m1", "fac2"),
        ("loadings", 1, "m2", "fac1"),
        ("loadings", 1, "m2", "fac2"),
    ]

    calculated = get_loadings_index_tuples(factors, uinfo)
    assert calculated == expected


def test_meas_sd_index_tuples():
    uinfo_tups = [(0, "m1"), (0, "m2"), (0, "bla"), (1, "m1"), (1, "m2")]
    uinfo = pd.DataFrame(index=pd.MultiIndex.from_tuples(uinfo_tups))

    expected = [
        ("meas_sds", 0, "m1", "-"),
        ("meas_sds", 0, "m2", "-"),
        ("meas_sds", 0, "bla", "-"),
        ("meas_sds", 1, "m1", "-"),
        ("meas_sds", 1, "m2", "-"),
    ]

    calculated = get_meas_sds_index_tuples(uinfo)
    assert calculated == expected


def test_shock_sd_index_tuples():
    periods = [0, 1, 2]
    factors = ["fac1", "fac2"]

    expected = [
        ("shock_sds", 0, "fac1", "-"),
        ("shock_sds", 0, "fac2", "-"),
        ("shock_sds", 1, "fac1", "-"),
        ("shock_sds", 1, "fac2", "-"),
    ]

    calculated = get_shock_sds_index_tuples(periods, factors)
    assert calculated == expected


def test_initial_mean_index_tuples():
    nmixtures = 3
    factors = ["fac1", "fac2"]

    expected = [
        ("initial_states", 0, "mixture_0", "fac1"),
        ("initial_states", 0, "mixture_0", "fac2"),
        ("initial_states", 0, "mixture_1", "fac1"),
        ("initial_states", 0, "mixture_1", "fac2"),
        ("initial_states", 0, "mixture_2", "fac1"),
        ("initial_states", 0, "mixture_2", "fac2"),
    ]

    calculated = initial_mean_index_tuples(nmixtures, factors)
    assert calculated == expected


def test_mixture_weight_index_tuples():
    nmixtures = 3
    expected = [
        ("mixture_weights", 0, "mixture_0", "-"),
        ("mixture_weights", 0, "mixture_1", "-"),
        ("mixture_weights", 0, "mixture_2", "-"),
    ]
    calculated = get_mixture_weights_index_tuples(nmixtures)
    assert calculated == expected


def test_initial_cov_index_tuples():
    nmixtures = 2
    factors = ["fac1", "fac2", "fac3"]
    expected = [
        ("initial_cholcovs", 0, "mixture_0", "fac1-fac1"),
        ("initial_cholcovs", 0, "mixture_0", "fac2-fac1"),
        ("initial_cholcovs", 0, "mixture_0", "fac2-fac2"),
        ("initial_cholcovs", 0, "mixture_0", "fac3-fac1"),
        ("initial_cholcovs", 0, "mixture_0", "fac3-fac2"),
        ("initial_cholcovs", 0, "mixture_0", "fac3-fac3"),
        ("initial_cholcovs", 0, "mixture_1", "fac1-fac1"),
        ("initial_cholcovs", 0, "mixture_1", "fac2-fac1"),
        ("initial_cholcovs", 0, "mixture_1", "fac2-fac2"),
        ("initial_cholcovs", 0, "mixture_1", "fac3-fac1"),
        ("initial_cholcovs", 0, "mixture_1", "fac3-fac2"),
        ("initial_cholcovs", 0, "mixture_1", "fac3-fac3"),
    ]

    calculated = get_initial_cholcovs_index_tuples(nmixtures, factors)
    assert calculated == expected


def test_trans_coeffs_index_tuples():
    periods = [0, 1, 2]

    param_names = {
        "fac1": ["fac1", "fac2", "fac3", "constant"],
        "fac2": [],
        "fac3": ["fac1", "fac2", "fac3", "phi"],
    }
    trans_info = {"param_names": param_names}

    expected = [
        ("transition", 0, "fac1", "fac1"),
        ("transition", 0, "fac1", "fac2"),
        ("transition", 0, "fac1", "fac3"),
        ("transition", 0, "fac1", "constant"),
        ("transition", 1, "fac1", "fac1"),
        ("transition", 1, "fac1", "fac2"),
        ("transition", 1, "fac1", "fac3"),
        ("transition", 1, "fac1", "constant"),
        ("transition", 0, "fac3", "fac1"),
        ("transition", 0, "fac3", "fac2"),
        ("transition", 0, "fac3", "fac3"),
        ("transition", 0, "fac3", "phi"),
        ("transition", 1, "fac3", "fac1"),
        ("transition", 1, "fac3", "fac2"),
        ("transition", 1, "fac3", "fac3"),
        ("transition", 1, "fac3", "phi"),
    ]

    calculated = get_transition_index_tuples(trans_info, periods)

    assert calculated == expected
--------------------------------------------------------------------------------
/tests/test_parse_params.py:
--------------------------------------------------------------------------------
"""Test parameter parsing with example model 2 from CHS2010.

Only test create_parsing_info and parse_params jointly, to abstract from
implementation details.

"""

from pathlib import Path

import jax.numpy as jnp
import numpy as np
import pandas as pd
import pytest
import yaml
from numpy.testing import assert_array_equal as aae

from skillmodels.parse_params import create_parsing_info, parse_params
from skillmodels.process_model import process_model


@pytest.fixture
def parsed_parameters():
    test_dir = Path(__file__).parent.resolve()
    p_index = pd.read_csv(
        test_dir / "model2_correct_params_index.csv",
        index_col=["category", "period", "name1", "name2"],
    ).index

    with open(test_dir / "model2.yaml") as y:
        model_dict = yaml.load(y, Loader=yaml.FullLoader)

    processed = process_model(model_dict)

    update_info = processed["update_info"]
    labels = processed["labels"]
    dimensions = processed["dimensions"]
    # This overwrites the anchoring setting from the model specification to get
    # a more meaningful test.
    anchoring = {"ignore_constant_when_anchoring": False}

    parsing_info = create_parsing_info(p_index, update_info, labels, anchoring)

    params_vec = jnp.arange(len(p_index))
    n_obs = 5

    parsed = parse_params(params_vec, parsing_info, dimensions, labels, n_obs)

    return dict(
        zip(["states", "upper_chols", "log_weights", "pardict"], parsed, strict=False)
    )

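# Because the fixture sets params_vec = jnp.arange(len(p_index)), each
# parameter's value equals its position in the params index. The expected
# arrays in the tests below are therefore simple ranges, which makes the
# slicing done by parse_params easy to verify by eye.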
"phi"), 203 | ] 204 | 205 | calculated = get_transition_index_tuples(trans_info, periods) 206 | 207 | assert calculated == expected 208 | -------------------------------------------------------------------------------- /tests/test_parse_params.py: -------------------------------------------------------------------------------- 1 | """Test parameter parsing with example model 2 from CHS2010. 2 | 3 | Only test the create_parsing_info and parse_params jointly, to abstract from 4 | implementation details. 5 | 6 | """ 7 | 8 | from pathlib import Path 9 | 10 | import jax.numpy as jnp 11 | import numpy as np 12 | import pandas as pd 13 | import pytest 14 | import yaml 15 | from numpy.testing import assert_array_equal as aae 16 | 17 | from skillmodels.parse_params import create_parsing_info, parse_params 18 | from skillmodels.process_model import process_model 19 | 20 | 21 | @pytest.fixture 22 | def parsed_parameters(): 23 | test_dir = Path(__file__).parent.resolve() 24 | p_index = pd.read_csv( 25 | test_dir / "model2_correct_params_index.csv", 26 | index_col=["category", "period", "name1", "name2"], 27 | ).index 28 | 29 | with open(test_dir / "model2.yaml") as y: 30 | model_dict = yaml.load(y, Loader=yaml.FullLoader) 31 | 32 | processed = process_model(model_dict) 33 | 34 | update_info = processed["update_info"] 35 | labels = processed["labels"] 36 | dimensions = processed["dimensions"] 37 | # this overwrites the anchoring setting from the model specification to get a 38 | # more meaningful test 39 | anchoring = {"ignore_constant_when_anchoring": False} 40 | 41 | parsing_info = create_parsing_info(p_index, update_info, labels, anchoring) 42 | 43 | params_vec = jnp.arange(len(p_index)) 44 | n_obs = 5 45 | 46 | parsed = parse_params(params_vec, parsing_info, dimensions, labels, n_obs) 47 | 48 | return dict( 49 | zip(["states", "upper_chols", "log_weights", "pardict"], parsed, strict=False) 50 | ) 51 | 52 | 53 | def test_controls(parsed_parameters): 54 | expected = jnp.arange(118).reshape(59, 2) 55 | aae(parsed_parameters["pardict"]["controls"], expected) 56 | 57 | 58 | def test_loadings(parsed_parameters): 59 | expected_values = jnp.arange(118, 177) 60 | calculated = parsed_parameters["pardict"]["loadings"] 61 | calculated_values = calculated[calculated != 0] 62 | aae(expected_values, calculated_values) 63 | 64 | 65 | def test_meas_sds(parsed_parameters): 66 | expected = jnp.arange(177, 236) 67 | aae(parsed_parameters["pardict"]["meas_sds"], expected) 68 | 69 | 70 | def test_shock_sds(parsed_parameters): 71 | expected = jnp.arange(236, 257).reshape(7, 3) 72 | aae(parsed_parameters["pardict"]["shock_sds"], expected) 73 | 74 | 75 | def test_initial_states(parsed_parameters): 76 | expected = jnp.arange(257, 260).reshape(1, 3).repeat(5, axis=0).reshape(5, 1, 3) 77 | aae(parsed_parameters["states"], expected) 78 | 79 | 80 | def test_initial_upper_chols(parsed_parameters): 81 | expected = ( 82 | jnp.array([[[261, 262, 264], [0, 263, 265], [0, 0, 266]]]) 83 | .repeat(5, axis=0) 84 | .reshape(5, 1, 3, 3) 85 | ) 86 | aae(parsed_parameters["upper_chols"], expected) 87 | 88 | 89 | def test_transition_parameters(parsed_parameters): 90 | calculated = parsed_parameters["pardict"]["transition"] 91 | 92 | aae(calculated["fac1"], jnp.arange(385, 413).reshape(7, 4) - 118) 93 | aae(calculated["fac2"], jnp.arange(413, 441).reshape(7, 4) - 118) 94 | aae(calculated["fac3"], jnp.zeros((7, 0))) 95 | 96 | assert isinstance(calculated, dict) 97 | 98 | 99 | def test_anchoring_scaling_factors(parsed_parameters): 100 | 
def test_pre_process_data():
    df = pd.DataFrame(data=np.arange(10).reshape(10, 1), columns=["var"])
    df["period"] = [1, 2, 3, 2, 3, 4, 2, 4, 3, 1]
    df["id"] = [1, 1, 1, 3, 3, 3, 4, 4, 5, 5]
    df.set_index(["id", "period"], inplace=True)

    period = [0, 1, 2, 3] * 4
    id_ = np.arange(4).repeat(4)
    nan = np.nan
    data = [0, 1, 2, nan, nan, 3, 4, 5, nan, 6, nan, 7, 9, nan, 8, nan]
    data = np.column_stack([period, id_, data])
    exp = pd.DataFrame(data=data, columns=["__period__", "__id__", "var"])
    exp.set_index(["__id__", "__period__"], inplace=True)

    res = pre_process_data(df, [0, 1, 2, 3])

    assert res["var"].equals(exp["var"])


def test_handle_controls_with_missings():
    controls = ["c1"]
    uinfo_ind_tups = [(0, "m1"), (0, "m2")]
    update_info = pd.DataFrame(index=pd.MultiIndex.from_tuples(uinfo_ind_tups))
    data = [[1, 1, 1], [np.nan, 1, 1], [np.nan, 1, np.nan], [np.nan, np.nan, np.nan]]
    df = pd.DataFrame(data=data, columns=["m1", "m2", "c1"])
    df["period"] = 0
    df["id"] = np.arange(4)
    df["__old_id__"] = df["id"]
    df["__old_period__"] = df["period"] + 1
    df.set_index(["id", "period"], inplace=True)

    with pytest.warns(UserWarning):
        calculated = _handle_controls_with_missings(df, controls, update_info)
    assert calculated.loc[(2, 0)].isna().all()


def test_generate_measurements_array():
    uinfo_ind_tups = [(0, "m1"), (0, "m2"), (1, "m1"), (1, "m3")]
    update_info = pd.DataFrame(index=pd.MultiIndex.from_tuples(uinfo_ind_tups))

    csv = """
    id,period,m1,m2,m3
    0,0,1,2,3
    0,1,4,5,6
    1,0,7,8,9
    1,1,10,11,12
    """
    data = _read_csv_string(csv, ["id", "period"])

    expected = jnp.array([[1, 7], [2, 8], [4, 10], [6, 12.0]])

    calculated = _generate_measurements_array(data, update_info, 2)
    aae(calculated, expected)


def test_generate_controls_array():
    csv = """
    id,period,c1,c2
    0, 0, 1, 2
    0, 1, 3, 4
    1, 0, 5, 8
    1, 1, 7, 8
    """
    data = _read_csv_string(csv, ["id", "period"])

    labels = {"controls": ["c1", "c2"], "periods": [0, 1]}

    calculated = _generate_controls_array(data, labels, 2)
    expected = jnp.array([[[1, 2], [5, 8]], [[3, 4], [7, 8]]])
    aae(calculated, expected)


def test_generate_observed_factor_array():
    csv = """
    id,period,v1,v2
    0, 0, 1, 2
    0, 1, 3, 4
    1, 0, 5, 8
    1, 1, 7, 8
    """
    data = _read_csv_string(csv, ["id", "period"])

    labels = {"observed_factors": ["v1", "v2"], "periods": [0, 1]}

    calculated = _generate_observed_factor_array(data, labels, 2)
    expected = jnp.array([[[1, 2], [5, 8]], [[3, 4], [7, 8]]])
    aae(calculated, expected)


def _read_csv_string(string, index_cols):
    string = textwrap.dedent(string)
    return pd.read_csv(io.StringIO(string), index_col=index_cols)
--------------------------------------------------------------------------------
/tests/test_process_model.py:
--------------------------------------------------------------------------------
import inspect
from pathlib import Path

import pandas as pd
import pytest
import yaml
from pandas.testing import assert_frame_equal

from skillmodels.process_model import process_model

# ======================================================================================
# Integration test with model2 from the replication files of CHS2010
# ======================================================================================

# Importing TEST_DIR from config does not work for tests run during conda build.
TEST_DIR = Path(__file__).parent.resolve()


@pytest.fixture
def model2():
    with open(TEST_DIR / "model2.yaml") as y:
        model_dict = yaml.load(y, Loader=yaml.FullLoader)
    return model_dict


def test_dimensions(model2):
    res = process_model(model2)["dimensions"]
    assert res["n_latent_factors"] == 3
    assert res["n_observed_factors"] == 0
    assert res["n_all_factors"] == 3
    assert res["n_periods"] == 8
    assert res["n_controls"] == 2
    assert res["n_mixtures"] == 1


def test_labels(model2):
    res = process_model(model2)["labels"]
    assert res["latent_factors"] == ["fac1", "fac2", "fac3"]
    assert res["observed_factors"] == []
    assert res["all_factors"] == ["fac1", "fac2", "fac3"]
    assert res["controls"] == ["constant", "x1"]
    assert res["periods"] == [0, 1, 2, 3, 4, 5, 6, 7]
    assert res["stagemap"] == [0, 0, 0, 0, 0, 0, 0]
    assert res["stages"] == [0]


def test_estimation_options(model2):
    res = process_model(model2)["estimation_options"]
    assert res["sigma_points_scale"] == 2
    assert res["robust_bounds"]
    assert res["bounds_distance"] == 0.001


def test_anchoring(model2):
    res = process_model(model2)["anchoring"]
    assert res["outcomes"] == {"fac1": "Q1"}
    assert res["factors"] == ["fac1"]
    assert res["free_controls"]
    assert res["free_constant"]
    assert res["free_loadings"]


def test_transition_info(model2):
    res = process_model(model2)["transition_info"]

    assert isinstance(res, dict)
    assert callable(res["func"])

    assert list(inspect.signature(res["func"]).parameters) == ["params", "states"]

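# update_info is a DataFrame indexed by (period, variable); it records which
# measurement variables are used in which period. The expected version is
# stored as a CSV fixture next to the model specification.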
| csv = """ 94 | id,period,v1,v2 95 | 0, 0, 1, 2 96 | 0, 1, 3, 4 97 | 1, 0, 5, 8 98 | 1, 1, 7, 8 99 | """ 100 | data = _read_csv_string(csv, ["id", "period"]) 101 | 102 | labels = {"observed_factors": ["v1", "v2"], "periods": [0, 1]} 103 | 104 | calculated = _generate_observed_factor_array(data, labels, 2) 105 | expected = jnp.array([[[1, 2], [5, 8]], [[3, 4], [7, 8]]]) 106 | aae(calculated, expected) 107 | 108 | 109 | def _read_csv_string(string, index_cols): 110 | string = textwrap.dedent(string) 111 | return pd.read_csv(io.StringIO(string), index_col=index_cols) 112 | -------------------------------------------------------------------------------- /tests/test_process_model.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from pathlib import Path 3 | 4 | import pandas as pd 5 | import pytest 6 | import yaml 7 | from pandas.testing import assert_frame_equal 8 | 9 | from skillmodels.process_model import process_model 10 | 11 | # ====================================================================================== 12 | # Integration test with model2 from the replication files of CHS2010 13 | # ====================================================================================== 14 | 15 | # importing the TEST_DIR from config does not work for test run in conda build 16 | TEST_DIR = Path(__file__).parent.resolve() 17 | 18 | 19 | @pytest.fixture 20 | def model2(): 21 | with open(TEST_DIR / "model2.yaml") as y: 22 | model_dict = yaml.load(y, Loader=yaml.FullLoader) 23 | return model_dict 24 | 25 | 26 | def test_dimensions(model2): 27 | res = process_model(model2)["dimensions"] 28 | assert res["n_latent_factors"] == 3 29 | assert res["n_observed_factors"] == 0 30 | assert res["n_all_factors"] == 3 31 | assert res["n_periods"] == 8 32 | assert res["n_controls"] == 2 33 | assert res["n_mixtures"] == 1 34 | 35 | 36 | def test_labels(model2): 37 | res = process_model(model2)["labels"] 38 | assert res["latent_factors"] == ["fac1", "fac2", "fac3"] 39 | assert res["observed_factors"] == [] 40 | assert res["all_factors"] == ["fac1", "fac2", "fac3"] 41 | assert res["controls"] == ["constant", "x1"] 42 | assert res["periods"] == [0, 1, 2, 3, 4, 5, 6, 7] 43 | assert res["stagemap"] == [0, 0, 0, 0, 0, 0, 0] 44 | assert res["stages"] == [0] 45 | 46 | 47 | def test_estimation_options(model2): 48 | res = process_model(model2)["estimation_options"] 49 | assert res["sigma_points_scale"] == 2 50 | assert res["robust_bounds"] 51 | assert res["bounds_distance"] == 0.001 52 | 53 | 54 | def test_anchoring(model2): 55 | res = process_model(model2)["anchoring"] 56 | assert res["outcomes"] == {"fac1": "Q1"} 57 | assert res["factors"] == ["fac1"] 58 | assert res["free_controls"] 59 | assert res["free_constant"] 60 | assert res["free_loadings"] 61 | 62 | 63 | def test_transition_info(model2): 64 | res = process_model(model2)["transition_info"] 65 | 66 | assert isinstance(res, dict) 67 | assert callable(res["func"]) 68 | 69 | assert list(inspect.signature(res["func"]).parameters) == ["params", "states"] 70 | 71 | 72 | def test_update_info(model2): 73 | res = process_model(model2)["update_info"] 74 | test_dir = Path(__file__).parent.resolve() 75 | expected = pd.read_csv( 76 | test_dir / "model2_correct_update_info.csv", 77 | index_col=["period", "variable"], 78 | ) 79 | assert_frame_equal(res, expected) 80 | 81 | 82 | def test_normalizations(model2): 83 | expected = { 84 | "fac1": { 85 | "loadings": [ 86 | {"y1": 1}, 87 | {"y1": 1}, 88 | {"y1": 1}, 89 | {"y1": 1}, 
90 | {"y1": 1}, 91 | {"y1": 1}, 92 | {"y1": 1}, 93 | {"y1": 1}, 94 | ], 95 | "intercepts": [{}, {}, {}, {}, {}, {}, {}, {}], 96 | }, 97 | "fac2": { 98 | "loadings": [ 99 | {"y4": 1}, 100 | {"y4": 1}, 101 | {"y4": 1}, 102 | {"y4": 1}, 103 | {"y4": 1}, 104 | {"y4": 1}, 105 | {"y4": 1}, 106 | {"y4": 1}, 107 | ], 108 | "intercepts": [{}, {}, {}, {}, {}, {}, {}, {}], 109 | }, 110 | "fac3": { 111 | "loadings": [{"y7": 1}, {}, {}, {}, {}, {}, {}, {}], 112 | "intercepts": [{}, {}, {}, {}, {}, {}, {}, {}], 113 | }, 114 | } 115 | res = process_model(model2)["normalizations"] 116 | 117 | assert res == expected 118 | -------------------------------------------------------------------------------- /tests/test_simulate_data.py: -------------------------------------------------------------------------------- 1 | """Tests for functions in simulate_data module.""" 2 | 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import pytest 8 | import yaml 9 | from numpy.testing import assert_array_almost_equal as aaae 10 | 11 | from skillmodels.simulate_data import measurements_from_states, simulate_dataset 12 | 13 | # importing the TEST_DIR from config does not work for test run in conda build 14 | TEST_DIR = Path(__file__).parent.resolve() 15 | 16 | 17 | @pytest.fixture 18 | def model2(): 19 | with open(TEST_DIR / "model2.yaml") as y: 20 | model_dict = yaml.load(y, Loader=yaml.FullLoader) 21 | return model_dict 22 | 23 | 24 | @pytest.fixture 25 | def model2_data(): 26 | data = pd.read_stata(TEST_DIR / "model2_simulated_data.dta") 27 | data = data.set_index(["caseid", "period"]) 28 | return data 29 | 30 | 31 | def test_simulate_dataset(model2, model2_data): 32 | model_dict = model2 33 | params = pd.read_csv(TEST_DIR / "regression_vault" / "one_stage_anchoring.csv") 34 | params = params.set_index(["category", "period", "name1", "name2"]) 35 | 36 | calculated = simulate_dataset( 37 | model_dict=model_dict, 38 | params=params, 39 | data=model2_data, 40 | ) 41 | 42 | factors = ["fac1", "fac2", "fac3"] 43 | expected_ratios = [1.187757, 1, 1] 44 | for factor, expected_ratio in zip(factors, expected_ratios, strict=False): 45 | anch_ranges = calculated["anchored_states"]["state_ranges"][factor] 46 | unanch_ranges = calculated["unanchored_states"]["state_ranges"][factor] 47 | ratio = (anch_ranges / unanch_ranges).to_numpy() 48 | assert np.allclose(ratio, expected_ratio) 49 | 50 | 51 | def test_measurements_from_factors(): 52 | inputs = { 53 | "states": np.array([[0, 0, 0], [1, 1, 1]]), 54 | "controls": np.array([[1, 1], [1, 1]]), 55 | "loadings": np.array([[0.3, 0.3, 0.3], [0.3, 0.3, 0.3], [0.3, 0.3, 0.3]]), 56 | "control_params": np.array([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]), 57 | "sds": np.zeros(3), 58 | } 59 | expected = np.array([[1, 1, 1], [1.9, 1.9, 1.9]]) 60 | aaae(measurements_from_states(**inputs), expected) 61 | -------------------------------------------------------------------------------- /tests/test_transition_functions.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | from numpy.testing import assert_array_almost_equal as aaae 5 | 6 | from skillmodels.transition_functions import ( 7 | constant, 8 | linear, 9 | log_ces, 10 | log_ces_general, 11 | params_log_ces_general, 12 | robust_translog, 13 | translog, 14 | ) 15 | 16 | jax.config.update("jax_enable_x64", True) 17 | 18 | 19 | def test_linear(): 20 | states = np.arange(3) 21 | params = np.array([0.1, 0.2, 0.3, 0.4]) 22 
def test_constant():
    assert constant("bla", "blubb") == "bla"


def test_robust_translog():
    all_states = np.array(
        [
            [2, 0, 0],
            [0, 3, 0],
            [0, 0, 4],
            [0, 0, 0],
            [1, 1, 1],
            [0, -3, 0],
            [-1, -1, -1],
            [1.5, -2, 1.8],
            [12, -34, 48],
        ],
    )

    params = np.array(
        [
            # linear terms
            0.2,
            0.1,
            0.12,
            # square terms
            0.08,
            0.04,
            0.05,
            # interactions: the order is 0-1, 0-2, 1-2
            0.05,
            0.03,
            0.06,
            # constant
            0.04,
        ],
    )

    expected_translog = [0.76, 0.7, 1.32, 0.04, 0.77, 0.1, -0.07, 0.573, 76.72]

    for states, expected in zip(all_states, expected_translog, strict=False):
        calculated = robust_translog(states, params)
        aaae(calculated, expected)


def test_log_ces_general():
    states = np.array([3, 7.5])
    params = jnp.array([0.4, 0.6, 2, 2, 0.5])
    expected = 7.244628323025
    calculated = log_ces_general(states, params)
    aaae(calculated, expected)


def test_log_ces_general_where_all_but_one_gammas_are_zero():
    """This has to be tested because it leads to an underflow in the log step."""
    states = jnp.ones(3)
    params = jnp.array([0, 0, 1, -0.5, -0.5, -0.5, -2])
    calculated = log_ces_general(states, params)
    expected = 1.0
    aaae(calculated, expected)


def test_param_names_log_ces_general():
    factors = ["a", "b"]
    expected = ["a", "b", "sigma_a", "sigma_b", "tfp"]
    calculated = params_log_ces_general(factors)
    assert calculated == expected
--------------------------------------------------------------------------------
/tests/test_utilities.py:
--------------------------------------------------------------------------------
"""Test utility functions.

All tests should not only assert that modified model specifications are correct but
also that there are no side effects on the inputs.

"""

5 | 6 | """ 7 | 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import pytest 13 | import yaml 14 | from pandas.testing import assert_frame_equal, assert_index_equal 15 | 16 | from skillmodels.process_model import process_model 17 | from skillmodels.utilities import ( 18 | _get_params_index_from_model_dict, 19 | _remove_from_dict, 20 | _remove_from_list, 21 | _shorten_if_necessary, 22 | extract_factors, 23 | reduce_n_periods, 24 | remove_controls, 25 | remove_factors, 26 | remove_measurements, 27 | switch_linear_to_translog, 28 | switch_translog_to_linear, 29 | update_parameter_values, 30 | ) 31 | 32 | # importing the TEST_DIR from config does not work for test run in conda build 33 | TEST_DIR = Path(__file__).parent.resolve() 34 | 35 | 36 | @pytest.fixture 37 | def model2(): 38 | with open(TEST_DIR / "model2.yaml") as y: 39 | model_dict = yaml.load(y, Loader=yaml.FullLoader) 40 | return model_dict 41 | 42 | 43 | @pytest.mark.parametrize("factors", ["fac2", ["fac2"]]) 44 | def test_extract_factors_single(model2, factors): 45 | reduced = extract_factors(factors, model2) 46 | assert list(reduced["factors"]) == ["fac2"] 47 | assert list(model2["factors"]) == ["fac1", "fac2", "fac3"] 48 | assert "anchoring" not in reduced 49 | assert model2["anchoring"]["outcomes"] == {"fac1": "Q1"} 50 | process_model(reduced) 51 | 52 | 53 | def test_update_parameter_values(): 54 | params = pd.DataFrame() 55 | params["value"] = np.arange(5, dtype=np.int64) 56 | 57 | others = [ 58 | pd.DataFrame([[7], [8]], columns=["value"], index=[1, 4]), 59 | pd.DataFrame([[9]], columns=["value"], index=[2]), 60 | ] 61 | 62 | expected = pd.DataFrame() 63 | expected["value"] = [0, 7, 9, 3, 8] 64 | 65 | calculated = update_parameter_values(params, others) 66 | assert_frame_equal(calculated, expected) 67 | 68 | 69 | @pytest.mark.parametrize("factors", ["fac2", ["fac2"]]) 70 | def test_remove_factors(model2, factors): 71 | reduced = remove_factors(factors, model2) 72 | assert list(reduced["factors"]) == ["fac1", "fac3"] 73 | assert list(model2["factors"]) == ["fac1", "fac2", "fac3"] 74 | assert "anchoring" in reduced 75 | process_model(reduced) 76 | 77 | 78 | @pytest.mark.parametrize("measurements", ["y5", ["y5"]]) 79 | def test_remove_measurements(model2, measurements): 80 | reduced = remove_measurements(measurements, model2) 81 | assert reduced["factors"]["fac2"]["measurements"] == [["y4", "y6"]] * 8 82 | assert "y5" in model2["factors"]["fac2"]["measurements"][0] 83 | process_model(reduced) 84 | 85 | 86 | @pytest.mark.parametrize("controls", ["x1", ["x1"]]) 87 | def test_remove_controls(model2, controls): 88 | reduced = remove_controls(controls, model2) 89 | assert "controls" not in reduced 90 | assert "controls" in model2 91 | process_model(reduced) 92 | 93 | 94 | def test_reduce_n_periods(model2): 95 | reduced = reduce_n_periods(model2, 1) 96 | assert reduced["factors"]["fac1"]["measurements"] == [["y1", "y2", "y3"]] 97 | assert reduced["factors"]["fac2"]["normalizations"]["loadings"] == [{"y4": 1}] 98 | process_model(reduced) 99 | 100 | 101 | def test_switch_linear_to_translog(model2): 102 | switched = switch_linear_to_translog(model2) 103 | assert switched["factors"]["fac2"]["transition_function"] == "translog" 104 | 105 | 106 | def test_switch_linear_and_translog_back_and_forth(model2): 107 | with_translog = switch_linear_to_translog(model2) 108 | with_linear = switch_translog_to_linear(with_translog) 109 | assert model2 == with_linear 110 | 111 | 112 | 
@pytest.mark.parametrize("to_remove", ["a", ["a"]]) 113 | def test_remove_from_list(to_remove): 114 | list_ = ["a", "b", "c"] 115 | calculated = _remove_from_list(list_, to_remove) 116 | assert calculated == ["b", "c"] 117 | assert list_ == ["a", "b", "c"] 118 | 119 | 120 | @pytest.mark.parametrize("to_remove", ["a", ["a"]]) 121 | def test_remove_from_dict(to_remove): 122 | dict_ = {"a": 1, "b": 2, "c": 3} 123 | calculated = _remove_from_dict(dict_, to_remove) 124 | assert calculated == {"b": 2, "c": 3} 125 | assert dict_ == {"a": 1, "b": 2, "c": 3} 126 | 127 | 128 | def test_reduce_params_via_extract_factors(model2): 129 | model_dict = reduce_n_periods(model2, 2) 130 | 131 | full_index = _get_params_index_from_model_dict(model_dict) 132 | params = pd.DataFrame(columns=["value"], index=full_index) 133 | 134 | _, reduced_params = extract_factors("fac3", model_dict, params) 135 | 136 | expected_index = pd.MultiIndex.from_tuples( 137 | [ 138 | ("controls", 0, "y7", "constant"), 139 | ("controls", 0, "y7", "x1"), 140 | ("controls", 0, "y8", "constant"), 141 | ("controls", 0, "y8", "x1"), 142 | ("controls", 0, "y9", "constant"), 143 | ("controls", 0, "y9", "x1"), 144 | ("loadings", 0, "y7", "fac3"), 145 | ("loadings", 0, "y8", "fac3"), 146 | ("loadings", 0, "y9", "fac3"), 147 | ("meas_sds", 0, "y7", "-"), 148 | ("meas_sds", 0, "y8", "-"), 149 | ("meas_sds", 0, "y9", "-"), 150 | ("initial_states", 0, "mixture_0", "fac3"), 151 | ("mixture_weights", 0, "mixture_0", "-"), 152 | ("initial_cholcovs", 0, "mixture_0", "fac3-fac3"), 153 | ], 154 | names=["category", "period", "name1", "name2"], 155 | ) 156 | 157 | assert_index_equal(reduced_params.index, expected_index) 158 | 159 | 160 | def test_extend_params_via_switch_to_translog(model2): 161 | model_dict = reduce_n_periods(model2, 2) 162 | normal_index = _get_params_index_from_model_dict(model_dict) 163 | params = pd.DataFrame(columns=["value"], index=normal_index) 164 | 165 | _, extended_params = switch_linear_to_translog(model_dict, params) 166 | 167 | added_index = extended_params.index.difference(normal_index) 168 | 169 | expected_added_index = pd.MultiIndex.from_tuples( 170 | [ 171 | ("transition", 0, "fac2", "fac1 * fac2"), 172 | ("transition", 0, "fac2", "fac1 * fac3"), 173 | ("transition", 0, "fac2", "fac1 ** 2"), 174 | ("transition", 0, "fac2", "fac2 * fac3"), 175 | ("transition", 0, "fac2", "fac2 ** 2"), 176 | ("transition", 0, "fac2", "fac3 ** 2"), 177 | ], 178 | names=["category", "period", "name1", "name2"], 179 | ) 180 | 181 | assert_index_equal(added_index, expected_added_index) 182 | 183 | assert extended_params.loc[added_index, "value"].unique()[0] == 0.05 184 | 185 | 186 | def test_shorten_if_necessary(): 187 | list_ = list(range(3)) 188 | not_necessary = _shorten_if_necessary(list_, 5) 189 | assert not_necessary == list_ 190 | 191 | necessary = _shorten_if_necessary(list_, 2) 192 | assert necessary == [0, 1] 193 | -------------------------------------------------------------------------------- /tests/test_visualize_factor_distributions.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pandas as pd 4 | import yaml 5 | 6 | from skillmodels.maximization_inputs import get_maximization_inputs 7 | from skillmodels.simulate_data import simulate_dataset 8 | from skillmodels.visualize_factor_distributions import ( 9 | bivariate_density_contours, 10 | bivariate_density_surfaces, 11 | combine_distribution_plots, 12 | univariate_densities, 13 | ) 14 | 15 | # 
# Importing TEST_DIR from config does not work for tests run during conda build.
TEST_DIR = Path(__file__).parent.resolve()


def test_visualize_factor_distributions_runs_with_filtered_states():
    with open(TEST_DIR / "model2.yaml") as y:
        model_dict = yaml.load(y, Loader=yaml.FullLoader)

    params = pd.read_csv(TEST_DIR / "regression_vault" / "one_stage_anchoring.csv")
    params = params.set_index(["category", "period", "name1", "name2"])

    data = pd.read_stata(TEST_DIR / "model2_simulated_data.dta")
    data.set_index(["caseid", "period"], inplace=True)

    max_inputs = get_maximization_inputs(model_dict, data)
    params = params.loc[max_inputs["params_template"].index]
    kde = univariate_densities(
        data=data,
        model_dict=model_dict,
        params=params,
        period=1,
    )
    contours = bivariate_density_contours(
        data=data,
        model_dict=model_dict,
        params=params,
        period=1,
    )
    surfaces = bivariate_density_surfaces(
        data=data,
        model_dict=model_dict,
        params=params,
        period=1,
    )
    combine_distribution_plots(
        kde_plots=kde,
        contour_plots=contours,
        surface_plots=surfaces,
    )


def test_visualize_factor_distributions_runs_with_simulated_states():
    with open(TEST_DIR / "model2.yaml") as y:
        model_dict = yaml.load(y, Loader=yaml.FullLoader)

    data = pd.read_stata(TEST_DIR / "model2_simulated_data.dta")
    data.set_index(["caseid", "period"], inplace=True)

    params = pd.read_csv(TEST_DIR / "regression_vault" / "one_stage_anchoring.csv")
    params = params.set_index(["category", "period", "name1", "name2"])

    max_inputs = get_maximization_inputs(model_dict, data)
    params = params.loc[max_inputs["params_template"].index]

    latent_data = simulate_dataset(model_dict, params, data=data, policies=None)[
        "unanchored_states"
    ]["states"]

    kde = univariate_densities(
        data=data,
        states=latent_data,
        model_dict=model_dict,
        params=params,
        period=1,
    )
    contours = bivariate_density_contours(
        data=data,
        states=latent_data,
        model_dict=model_dict,
        params=params,
        period=1,
    )
    combine_distribution_plots(
        kde_plots=kde,
        contour_plots=contours,
        surface_plots=None,
    )
--------------------------------------------------------------------------------
/tests/test_visualize_transition_equations.py:
--------------------------------------------------------------------------------
from pathlib import Path

import pandas as pd
import yaml

from skillmodels.maximization_inputs import get_maximization_inputs
from skillmodels.visualize_transition_equations import (
    combine_transition_plots,
    get_transition_plots,
)

TEST_DIR = Path(__file__).parent.resolve()


def test_visualize_transition_equations_runs():
    with open(TEST_DIR / "model2.yaml") as y:
        model_dict = yaml.load(y, Loader=yaml.FullLoader)

    model_dict["observed_factors"] = ["ob1"]

    params = pd.read_csv(TEST_DIR / "regression_vault" / "one_stage_anchoring.csv")
    params = params.set_index(["category", "period", "name1", "name2"])

    data = pd.read_stata(TEST_DIR / "model2_simulated_data.dta")
    data.set_index(["caseid", "period"], inplace=True)
    data["ob1"] = 0

    max_inputs = get_maximization_inputs(model_dict, data)
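    # The stored regression params have no entries for the newly added observed
    # factor, so the params template index is longer than the stored one;
    # reindexing and filling the new values with zero completes the params.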
max_inputs["params_template"].index 30 | params = params.reindex(full_index) 31 | params["value"] = params["value"].fillna(0) 32 | subplots = get_transition_plots( 33 | model_dict=model_dict, 34 | params=params, 35 | period=0, 36 | quantiles_of_other_factors=[0.1, 0.25, 0.5, 0.75, 0.9], 37 | data=data, 38 | ) 39 | combine_transition_plots(subplots) 40 | subplots = get_transition_plots( 41 | model_dict=model_dict, 42 | params=params, 43 | period=0, 44 | quantiles_of_other_factors=None, 45 | data=data, 46 | ) 47 | combine_transition_plots(subplots) 48 | --------------------------------------------------------------------------------