├── .github └── workflows │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── INPUT.md ├── README.md ├── REQUIREMENTS.txt ├── codecov.yml ├── conda_envs └── pymc_statespace.yml ├── data └── nile.csv ├── examples ├── ARMA Example.ipynb ├── Custom SSM - Daily Seasonality.ipynb ├── Nile Local Level Model.ipynb └── VARMAX Example.ipynb ├── pymc_statespace ├── __init__.py ├── core │ ├── __init__.py │ ├── representation.py │ └── statespace.py ├── filters │ ├── __init__.py │ ├── kalman_filter.py │ ├── kalman_smoother.py │ └── utilities.py ├── models │ ├── SARIMAX.py │ ├── VARMAX.py │ ├── __init__.py │ ├── local_level.py │ └── utilities.py └── utils │ ├── __init__.py │ ├── numba_linalg.py │ ├── pytensor_scipy.py │ └── simulation.py ├── pyproject.toml ├── setup.cfg ├── setup.py ├── svgs ├── 17a925cbd243c9dd3ce20fd8558993f1.svg ├── 17b59c002f249204f24e31507dc4957d.svg ├── 2d6502498c2ef42e278a774086e73a26.svg ├── 3b5e41543d7fc8cedf98ec609b343134.svg ├── 4881244fda86ce9792cccafb3bb7eb0c.svg ├── 523b266d36c270dbbb5daf2c9092ce0f.svg ├── 54221efbfb5e69569dfe8ddea785093a.svg ├── 5dfc2ae9e19de8995e3edb059f7abd19.svg ├── 92b8c1194757fb3131cda468a34be85f.svg ├── a06c0e58d4d162b0e87d32927c9812db.svg ├── a13d89295e999545a129b2d412e99f6d.svg ├── c1133ef3a57a33193b33c998a34246cc.svg ├── cac7e81ebde5e530e639eae5389f149e.svg ├── d523a14b8179ebe46f0ed16895ee46f0.svg ├── dfbe60bd49a89dc2de3950ee0ffab3f3.svg ├── edcff444fd5240add1c47d2de50ebd7e.svg ├── f566e90ed17c5292db4600846e0ace27.svg └── ff25a8f22c7430ca572d33206c0a9176.svg └── tests ├── __init__.py ├── test_VARMAX.py ├── test_data ├── nile.csv └── statsmodels_macrodata_processed.csv ├── test_kalman_filter.py ├── test_local_level.py ├── test_numba_linalg.py ├── test_pytensor_scipy.py ├── test_representation.py ├── test_simulations.py ├── test_statespace.py └── utilities ├── __init__.py ├── statsmodel_local_level.py └── test_helpers.py /.github/workflows/run_tests.yml: 
-------------------------------------------------------------------------------- 1 | name: run_tests 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | 9 | jobs: 10 | unittest: 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | os: [ubuntu-latest, windows-latest] 15 | python-version: ["3.9"] 16 | test-subset: 17 | - | 18 | tests/ 19 | 20 | runs-on: ${{ matrix.os }} 21 | 22 | env: 23 | TEST_SUBSET: ${{ matrix.test-subset }} 24 | 25 | defaults: 26 | run: 27 | shell: bash -l {0} 28 | 29 | steps: 30 | - uses: actions/checkout@v3 31 | - uses: actions/cache@v3 32 | env: 33 | # Increase this value to reset cache if pymc_statespace.yml has not changed 34 | CACHE_NUMBER: 0 35 | with: 36 | path: ~/conda_pkgs_dir 37 | key: ${{ runner.os }}-py${{matrix.python-version}}-conda-${{ env.CACHE_NUMBER }}-${{ 38 | hashFiles('conda_envs/pymc_statespace.yml') }} 39 | - name: Cache multiple paths 40 | uses: actions/cache@v3 41 | env: 42 | # Increase this value to reset cache if requirements.txt has not changed 43 | CACHE_NUMBER: 0 44 | with: 45 | path: | 46 | ~/.cache/pip 47 | $RUNNER_TOOL_CACHE/Python/* 48 | ~\AppData\Local\pip\Cache 49 | key: ${{ runner.os }}-build-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ 50 | hashFiles('requirements.txt') }} 51 | - uses: conda-incubator/setup-miniconda@v2 52 | with: 53 | miniforge-variant: Mambaforge 54 | miniforge-version: latest 55 | mamba-version: "*" 56 | activate-environment: pymc-statespace 57 | channel-priority: flexible 58 | environment-file: conda_envs/pymc_statespace.yml 59 | python-version: 3.9 60 | use-mamba: true 61 | use-only-tar-bz2: false # IMPORTANT: This may break caching of conda packages! See https://github.com/conda-incubator/setup-miniconda/issues/267 62 | 63 | - name: Install current branch 64 | run: | 65 | conda activate pymc-statespace 66 | pip install -e . 
67 | python --version 68 | 69 | - name: Run tests 70 | run: | 71 | python -m pytest -vv --cov=pymc_statespace --cov-report=xml --no-cov-on-fail --cov-report term $TEST_SUBSET 72 | - name: Upload coverage to Codecov 73 | uses: codecov/codecov-action@v3 74 | with: 75 | token: ${{ secrets.CODECOV_TOKEN }} # use token for more robust uploads 76 | env_vars: TEST_SUBSET 77 | name: ${{ matrix.os }} 78 | fail_ci_if_error: false 79 | verbose: true 80 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/macos,windows,pycharm,jupyternotebooks 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=macos,windows,pycharm,jupyternotebooks 3 | 4 | **/__pycache__/ 5 | 6 | ### JupyterNotebooks ### 7 | # gitignore template for Jupyter Notebooks 8 | # website: http://jupyter.org/ 9 | 10 | .ipynb_checkpoints 11 | */.ipynb_checkpoints/* 12 | 13 | # IPython 14 | profile_default/ 15 | ipython_config.py 16 | 17 | # Remove previous ipynb_checkpoints 18 | # git rm -r .ipynb_checkpoints/ 19 | 20 | ### macOS ### 21 | # General 22 | .DS_Store 23 | .AppleDouble 24 | .LSOverride 25 | 26 | # Icon must end with two \r 27 | Icon 28 | 29 | 30 | # Thumbnails 31 | ._* 32 | 33 | # Files that might appear in the root of a volume 34 | .DocumentRevisions-V100 35 | .fseventsd 36 | .Spotlight-V100 37 | .TemporaryItems 38 | .Trashes 39 | .VolumeIcon.icns 40 | .com.apple.timemachine.donotpresent 41 | 42 | # Directories potentially created on remote AFP share 43 | .AppleDB 44 | .AppleDesktop 45 | Network Trash Folder 46 | Temporary Items 47 | .apdisk 48 | 49 | ### macOS Patch ### 50 | # iCloud generated files 51 | *.icloud 52 | 53 | ### PyCharm ### 54 | .idea/ 55 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 56 | # Reference: 
https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 57 | 58 | # User-specific stuff 59 | .idea/**/workspace.xml 60 | .idea/**/tasks.xml 61 | .idea/**/usage.statistics.xml 62 | .idea/**/dictionaries 63 | .idea/**/shelf 64 | 65 | # AWS User-specific 66 | .idea/**/aws.xml 67 | 68 | # Generated files 69 | .idea/**/contentModel.xml 70 | 71 | # Sensitive or high-churn files 72 | .idea/**/dataSources/ 73 | .idea/**/dataSources.ids 74 | .idea/**/dataSources.local.xml 75 | .idea/**/sqlDataSources.xml 76 | .idea/**/dynamic.xml 77 | .idea/**/uiDesigner.xml 78 | .idea/**/dbnavigator.xml 79 | 80 | # Gradle 81 | .idea/**/gradle.xml 82 | .idea/**/libraries 83 | 84 | # Gradle and Maven with auto-import 85 | # When using Gradle or Maven with auto-import, you should exclude module files, 86 | # since they will be recreated, and may cause churn. Uncomment if using 87 | # auto-import. 88 | # .idea/artifacts 89 | # .idea/compiler.xml 90 | # .idea/jarRepositories.xml 91 | # .idea/modules.xml 92 | # .idea/*.iml 93 | # .idea/modules 94 | # *.iml 95 | # *.ipr 96 | 97 | # CMake 98 | cmake-build-*/ 99 | 100 | # Mongo Explorer plugin 101 | .idea/**/mongoSettings.xml 102 | 103 | # File-based project format 104 | *.iws 105 | 106 | # IntelliJ 107 | out/ 108 | 109 | # mpeltonen/sbt-idea plugin 110 | .idea_modules/ 111 | 112 | # JIRA plugin 113 | atlassian-ide-plugin.xml 114 | 115 | # Cursive Clojure plugin 116 | .idea/replstate.xml 117 | 118 | # SonarLint plugin 119 | .idea/sonarlint/ 120 | 121 | # Crashlytics plugin (for Android Studio and IntelliJ) 122 | com_crashlytics_export_strings.xml 123 | crashlytics.properties 124 | crashlytics-build.properties 125 | fabric.properties 126 | 127 | # Editor-based Rest Client 128 | .idea/httpRequests 129 | 130 | # Android studio 3.1+ serialized cache file 131 | .idea/caches/build_file_checksums.ser 132 | 133 | ### PyCharm Patch ### 134 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 135 | 136 | # 
*.iml 137 | # modules.xml 138 | # .idea/misc.xml 139 | # *.ipr 140 | 141 | # Sonarlint plugin 142 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 143 | .idea/**/sonarlint/ 144 | 145 | # SonarQube Plugin 146 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 147 | .idea/**/sonarIssues.xml 148 | 149 | # Markdown Navigator plugin 150 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 151 | .idea/**/markdown-navigator.xml 152 | .idea/**/markdown-navigator-enh.xml 153 | .idea/**/markdown-navigator/ 154 | 155 | # Cache file creation bug 156 | # See https://youtrack.jetbrains.com/issue/JBR-2257 157 | .idea/$CACHE_FILE$ 158 | 159 | # CodeStream plugin 160 | # https://plugins.jetbrains.com/plugin/12206-codestream 161 | .idea/codestream.xml 162 | 163 | # Azure Toolkit for IntelliJ plugin 164 | # https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij 165 | .idea/**/azureSettings.xml 166 | 167 | ### Windows ### 168 | # Windows thumbnail cache files 169 | Thumbs.db 170 | Thumbs.db:encryptable 171 | ehthumbs.db 172 | ehthumbs_vista.db 173 | 174 | # Dump file 175 | *.stackdump 176 | 177 | # Folder config file 178 | [Dd]esktop.ini 179 | 180 | # Recycle Bin used on file shares 181 | $RECYCLE.BIN/ 182 | 183 | # Windows Installer files 184 | *.cab 185 | *.msi 186 | *.msix 187 | *.msm 188 | *.msp 189 | 190 | # Windows shortcuts 191 | *.lnk 192 | 193 | # End of https://www.toptal.com/developers/gitignore/api/macos,windows,pycharm,jupyternotebooks 194 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.3.0 4 | hooks: 5 | - id: check-merge-conflict 6 | - id: check-toml 7 | - id: check-yaml 8 | - id: debug-statements 9 | - id: end-of-file-fixer 10 | exclude: .txt$ 11 | - id: trailing-whitespace 12 | - id: 
requirements-txt-fixer 13 | - repo: https://github.com/PyCQA/isort 14 | rev: 5.12.0 15 | hooks: 16 | - id: isort 17 | name: isort 18 | - repo: https://github.com/asottile/pyupgrade 19 | rev: v3.3.1 20 | hooks: 21 | - id: pyupgrade 22 | args: [--py37-plus] 23 | - repo: https://github.com/psf/black 24 | rev: 22.12.0 25 | hooks: 26 | - id: black 27 | - id: black-jupyter 28 | - repo: https://github.com/PyCQA/pylint 29 | rev: v2.16.0b1 30 | hooks: 31 | - id: pylint 32 | args: [--rcfile=.pylintrc] 33 | files: ^pymc_statespace/ 34 | - repo: local 35 | hooks: 36 | - id: no-relative-imports 37 | name: No relative imports 38 | entry: from \.[\.\w]* import 39 | types: [python] 40 | language: pygrep 41 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | # Use multiple processes to speed up Pylint. 3 | jobs=1 4 | 5 | # Allow loading of arbitrary C extensions. Extensions are imported into the 6 | # active Python interpreter and may run arbitrary code. 7 | unsafe-load-any-extension=no 8 | 9 | # Allow optimization of some AST trees. This will activate a peephole AST 10 | # optimizer, which will apply various small optimizations. For instance, it can 11 | # be used to obtain the result of joining multiple strings with the addition 12 | # operator. Joining a lot of strings can lead to a maximum recursion error in 13 | # Pylint and this flag can prevent that. It has one side effect, the resulting 14 | # AST will be different than the one from reality. 15 | optimize-ast=no 16 | 17 | [MESSAGES CONTROL] 18 | 19 | # Only show warnings with the listed confidence levels. Leave empty to show 20 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 21 | confidence= 22 | 23 | # Disable the message, report, category or checker with the given id(s).
You 24 | # can either give multiple identifiers separated by comma (,) or put this 25 | # option multiple times (only on the command line, not in the configuration 26 | # file where it should appear only once).You can also use "--disable=all" to 27 | # disable everything first and then reenable specific checks. For example, if 28 | # you want to run only the similarities checker, you can use "--disable=all 29 | # --enable=similarities". If you want to run only the classes checker, but have 30 | # no Warning level messages displayed, use"--disable=all --enable=classes 31 | # --disable=W" 32 | disable=all 33 | 34 | # Enable the message, report, category or checker with the given id(s). You can 35 | # either give multiple identifier separated by comma (,) or put this option 36 | # multiple time. See also the "--disable" option for examples. 37 | enable=import-self, 38 | reimported, 39 | wildcard-import, 40 | misplaced-future, 41 | relative-import, 42 | deprecated-module, 43 | unpacking-non-sequence, 44 | invalid-all-object, 45 | undefined-all-variable, 46 | used-before-assignment, 47 | cell-var-from-loop, 48 | global-variable-undefined, 49 | dangerous-default-value, 50 | # redefined-builtin, 51 | redefine-in-handler, 52 | unused-import, 53 | unused-wildcard-import, 54 | global-variable-not-assigned, 55 | undefined-loop-variable, 56 | global-statement, 57 | global-at-module-level, 58 | bad-open-mode, 59 | redundant-unittest-assert, 60 | boolean-datetime, 61 | # unused-variable 62 | 63 | 64 | [REPORTS] 65 | 66 | # Set the output format. Available formats are text, parseable, colorized, msvs 67 | # (visual studio) and html. You can also give a reporter class, eg 68 | # mypackage.mymodule.MyReporterClass. 69 | output-format=parseable 70 | 71 | # Put messages in a separate file for each module / package specified on the 72 | # command line instead of printing them on stdout. Reports (if any) will be 73 | # written in a file name "pylint_global.[txt|html]". 
74 | files-output=no 75 | 76 | # Tells whether to display a full report or only the messages 77 | reports=no 78 | 79 | # Python expression which should return a note less than 10 (10 is the highest 80 | # note). You have access to the variables errors warning, statement which 81 | # respectively contain the number of errors / warnings messages and the total 82 | # number of statements analyzed. This is used by the global evaluation report 83 | # (RP0004). 84 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 85 | 86 | [BASIC] 87 | 88 | # List of builtins function names that should not be used, separated by a comma 89 | bad-functions=map,filter,input 90 | 91 | # Good variable names which should always be accepted, separated by a comma 92 | good-names=i,j,k,ex,Run,_ 93 | 94 | # Bad variable names which should always be refused, separated by a comma 95 | bad-names=foo,bar,baz,toto,tutu,tata 96 | 97 | # Colon-delimited sets of names that determine each other's naming style when 98 | # the name regexes allow several styles. 
99 | name-group= 100 | 101 | # Include a hint for the correct naming format with invalid-name 102 | include-naming-hint=yes 103 | 104 | # Regular expression matching correct method names 105 | method-rgx=[a-z_][a-z0-9_]{2,30}$ 106 | 107 | # Naming hint for method names 108 | method-name-hint=[a-z_][a-z0-9_]{2,30}$ 109 | 110 | # Regular expression matching correct function names 111 | function-rgx=[a-z_][a-z0-9_]{2,30}$ 112 | 113 | # Naming hint for function names 114 | function-name-hint=[a-z_][a-z0-9_]{2,30}$ 115 | 116 | # Regular expression matching correct module names 117 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 118 | 119 | # Naming hint for module names 120 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 121 | 122 | # Regular expression matching correct attribute names 123 | attr-rgx=[a-z_][a-z0-9_]{2,30}$ 124 | 125 | # Naming hint for attribute names 126 | attr-name-hint=[a-z_][a-z0-9_]{2,30}$ 127 | 128 | # Regular expression matching correct class attribute names 129 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ 130 | 131 | # Naming hint for class attribute names 132 | class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ 133 | 134 | # Regular expression matching correct constant names 135 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 136 | 137 | # Naming hint for constant names 138 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 139 | 140 | # Regular expression matching correct class names 141 | class-rgx=[A-Z_][a-zA-Z0-9]+$ 142 | 143 | # Naming hint for class names 144 | class-name-hint=[A-Z_][a-zA-Z0-9]+$ 145 | 146 | # Regular expression matching correct argument names 147 | argument-rgx=[a-z_][a-z0-9_]{2,30}$ 148 | 149 | # Naming hint for argument names 150 | argument-name-hint=[a-z_][a-z0-9_]{2,30}$ 151 | 152 | # Regular expression matching correct inline iteration names 153 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ 154 | 155 | # Naming hint for inline iteration names 156 | 
inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ 157 | 158 | # Regular expression matching correct variable names 159 | variable-rgx=[a-z_][a-z0-9_]{2,30}$ 160 | 161 | # Naming hint for variable names 162 | variable-name-hint=[a-z_][a-z0-9_]{2,30}$ 163 | 164 | # Regular expression which should only match function or class names that do 165 | # not require a docstring. 166 | no-docstring-rgx=^_ 167 | 168 | # Minimum line length for functions/classes that require docstrings, shorter 169 | # ones are exempt. 170 | docstring-min-length=-1 171 | 172 | 173 | [ELIF] 174 | 175 | # Maximum number of nested blocks for function / method body 176 | max-nested-blocks=5 177 | 178 | 179 | [FORMAT] 180 | 181 | # Maximum number of characters on a single line. 182 | max-line-length=100 183 | 184 | # Regexp for a line that is allowed to be longer than the limit. 185 | ignore-long-lines=^\s*(# )??$ 186 | 187 | # Allow the body of an if to be on the same line as the test if there is no 188 | # else. 189 | single-line-if-stmt=no 190 | 191 | # List of optional constructs for which whitespace checking is disabled. `dict- 192 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 193 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 194 | # `empty-line` allows space-only lines. 195 | no-space-check=trailing-comma,dict-separator 196 | 197 | # Maximum number of lines in a module 198 | max-module-lines=1000 199 | 200 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 201 | # tab). 202 | indent-string=' ' 203 | 204 | # Number of spaces of indent required inside a hanging or continued line. 205 | indent-after-paren=4 206 | 207 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 
208 | expected-line-ending-format= 209 | 210 | 211 | [LOGGING] 212 | 213 | # Logging modules to check that the string format arguments are in logging 214 | # function parameter format 215 | logging-modules=logging 216 | 217 | 218 | [MISCELLANEOUS] 219 | 220 | # List of note tags to take in consideration, separated by a comma. 221 | notes=FIXME,XXX,TODO 222 | 223 | 224 | [SIMILARITIES] 225 | 226 | # Minimum lines number of a similarity. 227 | min-similarity-lines=4 228 | 229 | # Ignore comments when computing similarities. 230 | ignore-comments=yes 231 | 232 | # Ignore docstrings when computing similarities. 233 | ignore-docstrings=yes 234 | 235 | # Ignore imports when computing similarities. 236 | ignore-imports=no 237 | 238 | 239 | [SPELLING] 240 | 241 | # Spelling dictionary name. Available dictionaries: none. To make it working 242 | # install python-enchant package. 243 | spelling-dict= 244 | 245 | # List of comma separated words that should not be checked. 246 | spelling-ignore-words= 247 | 248 | # A path to a file that contains private dictionary; one word per line. 249 | spelling-private-dict-file= 250 | 251 | # Tells whether to store unknown words to indicated private dictionary in 252 | # --spelling-private-dict-file option instead of raising a message. 253 | spelling-store-unknown-words=no 254 | 255 | 256 | [TYPECHECK] 257 | 258 | # Tells whether missing members accessed in mixin class should be ignored. A 259 | # mixin class is detected if its name ends with "mixin" (case insensitive). 260 | ignore-mixin-members=yes 261 | 262 | # List of module names for which member attributes should not be checked 263 | # (useful for modules/projects where namespaces are manipulated during runtime 264 | # and thus existing member attributes cannot be deduced by static analysis. It 265 | # supports qualified module names, as well as Unix pattern matching. 
266 | ignored-modules= 267 | 268 | # List of classes names for which member attributes should not be checked 269 | # (useful for classes with attributes dynamically set). This supports can work 270 | # with qualified names. 271 | ignored-classes= 272 | 273 | # List of members which are set dynamically and missed by pylint inference 274 | # system, and so shouldn't trigger E1101 when accessed. Python regular 275 | # expressions are accepted. 276 | generated-members= 277 | 278 | 279 | [VARIABLES] 280 | 281 | # Tells whether we should check for unused import in __init__ files. 282 | init-import=no 283 | 284 | # A regular expression matching the name of dummy variables (i.e. expectedly 285 | # not used). 286 | dummy-variables-rgx=_$|dummy 287 | 288 | # List of additional names supposed to be defined in builtins. Remember that 289 | # you should avoid to define new builtins when possible. 290 | additional-builtins= 291 | 292 | # List of strings which can identify a callback function by name. A callback 293 | # name must start or end with one of those strings. 294 | callbacks=cb_,_cb 295 | 296 | 297 | [CLASSES] 298 | 299 | # List of method names used to declare (i.e. assign) instance attributes. 300 | defining-attr-methods=__init__,__new__,setUp 301 | 302 | # List of valid names for the first argument in a class method. 303 | valid-classmethod-first-arg=cls 304 | 305 | # List of valid names for the first argument in a metaclass class method. 306 | valid-metaclass-classmethod-first-arg=mcs 307 | 308 | # List of member names, which should be excluded from the protected access 309 | # warning. 310 | exclude-protected=_asdict,_fields,_replace,_source,_make 311 | 312 | 313 | [DESIGN] 314 | 315 | # Maximum number of arguments for function / method 316 | max-args=5 317 | 318 | # Argument names that match this expression will be ignored. 
Default to name 319 | # with leading underscore 320 | ignored-argument-names=_.* 321 | 322 | # Maximum number of locals for function / method body 323 | max-locals=15 324 | 325 | # Maximum number of return / yield for function / method body 326 | max-returns=6 327 | 328 | # Maximum number of branch for function / method body 329 | max-branches=12 330 | 331 | # Maximum number of statements in function / method body 332 | max-statements=50 333 | 334 | # Maximum number of parents for a class (see R0901). 335 | max-parents=7 336 | 337 | # Maximum number of attributes for a class (see R0902). 338 | max-attributes=7 339 | 340 | # Minimum number of public methods for a class (see R0903). 341 | min-public-methods=2 342 | 343 | # Maximum number of public methods for a class (see R0904). 344 | max-public-methods=20 345 | 346 | # Maximum number of boolean expressions in a if statement 347 | max-bool-expr=5 348 | 349 | 350 | [IMPORTS] 351 | 352 | # Deprecated modules which should not be used, separated by a comma 353 | deprecated-modules=optparse 354 | 355 | # Create a graph of every (i.e. internal and external) dependencies in the 356 | # given file (report RP0402 must not be disabled) 357 | import-graph= 358 | 359 | # Create a graph of external dependencies in the given file (report RP0402 must 360 | # not be disabled) 361 | ext-import-graph= 362 | 363 | # Create a graph of internal dependencies in the given file (report RP0402 must 364 | # not be disabled) 365 | int-import-graph= 366 | 367 | 368 | [EXCEPTIONS] 369 | 370 | # Exceptions that will emit a warning when being caught. Defaults to 371 | # "Exception" 372 | overgeneral-exceptions=Exception 373 | -------------------------------------------------------------------------------- /INPUT.md: -------------------------------------------------------------------------------- 1 | # PyMC StateSpace 2 | A system for Bayesian estimation of state space models using PyMC 4.0. 
This package is designed to mirror the functionality of the Statsmodels.api `tsa.statespace` module, except within a Bayesian estimation framework. To accomplish this, PyMC Statespace has a Kalman filter written in Pytensor, allowing the gradients of the iterative Kalman filter likelihood to be computed and provided to the PyMC NUTS sampler. 3 | 4 | ## State Space Models 5 | This package follows Statsmodels in using the Durbin and Koopman (2012) nomenclature for a linear state space model. Under this nomenclature, the model is written as: 6 | 7 | $y_t = Z_t \alpha_t + d_t + \varepsilon_t, \quad \varepsilon_t \sim N(0, H_t)$ 8 | $\alpha_{t+1} = T_t \alpha_t + c_t + R_t \eta_t, \quad \eta_t \sim N(0, Q_t)$ 9 | $\alpha_1 \sim N(a_1, P_1)$ 10 | 11 | The objects in the above equation have the following shapes and meanings: 12 | 13 | - $y_t, p \times 1$, observation vector 14 | - $\alpha_t, m \times 1$, state vector 15 | - $\varepsilon_t, p \times 1$, observation noise vector 16 | - $\eta_t, r \times 1$, state innovation vector 17 | 18 | 19 | - $Z_t, p \times m$, the design matrix 20 | - $H_t, p \times p$, observation noise covariance matrix 21 | - $T_t, m \times m$, the transition matrix 22 | - $R_t, m \times r$, the selection matrix 23 | - $Q_t, r \times r$, the state innovation covariance matrix 24 | 25 | 26 | - $c_t, m \times 1$, the state intercept vector 27 | - $d_t, p \times 1$, the observation intercept vector 28 | 29 | The linear state space model is a workhorse in many disciplines, and is flexible enough to represent a wide range of models, including Box-Jenkins SARIMAX class models, time series decompositions, and model of multiple time series (VARMAX) models. Use of a Kalman filter allows for estimation of unobserved and missing variables. Taken together, these are a powerful class of models that can be further augmented by Bayesian estimation. This allows the researcher to integrate uncertainty about the true model when estimating model parameteres. 
30 | 31 | 32 | ## Example Usage 33 | 34 | Currently, local level and ARMA models are implemented. To initialize a model, simply create a model object, provide data, and any additional arguments unique to that model. 35 | ```python 36 | import pymc_statespace as pmss 37 | #make sure data is a 2d numpy array! 38 | arma_model = pmss.BayesianARMA(data = data[:, None], order=(1, 1)) 39 | ``` 40 | You will see a message letting you know the model was created, as well as telling you the expected parameter names you will need to create inside a PyMC model block. 41 | ```commandline 42 | Model successfully initialized! The following parameters should be assigned priors inside a PyMC model block: x0, sigma_state, rho, theta 43 | ``` 44 | 45 | Next, a PyMC model is declared as usual, and the parameters can be passed into the state space model. This is done as follows: 46 | ```python 47 | with pm.Model() as arma_model: 48 | x0 = pm.Normal('x0', mu=0, sigma=1.0, shape=mod.k_states) 49 | sigma_state = pm.HalfNormal('sigma_state', sigma=1) 50 | 51 | rho = pm.TruncatedNormal('rho', mu=0.0, sigma=0.5, lower=-1.0, upper=1.0, shape=p) 52 | theta = pm.Normal('theta', mu=0.0, sigma=0.5, shape=q) 53 | 54 | arma_model.build_statespace_graph() 55 | trace_1 = pm.sample(init='jitter+adapt_diag_grad', target_accept=0.9) 56 | ``` 57 | 58 | After you place priors over the requested parameters, call `arma_model.build_statespace_graph()` to automatically put everything together into a state space system. If you are missing any parameters, you will get an error letting you know what else needs to be defined. 59 | 60 | And that's it! After this, you can sample the PyMC model as normal. 61 | 62 | 63 | ## Creating your own state space model 64 | 65 | Creating a custom state space model isn't complicated. Once again, the API follows the Statsmodels implementation.
All models need to subclass the `PyMCStateSpace` class, and pass three values into the class super construction: `data` (from which p is inferred), `k_states` (this is "m" in the shapes above), and `k_posdef` (this is "r" above). The user also needs to declare any required state space matrices. Here is an example of a simple local linear trend model: 66 | 67 | ```python 68 | def __init__(self, data): 69 | # Model shapes 70 | k_states = k_posdef = 2 71 | 72 | super().__init__(data, k_states, k_posdef) 73 | 74 | # Initialize the matrices 75 | self.ssm['design'] = np.array([[1.0, 0.0]]) 76 | self.ssm['transition'] = np.array([[1.0, 1.0], 77 | [0.0, 1.0]]) 78 | self.ssm['selection'] = np.eye(k_states) 79 | 80 | self.ssm['initial_state'] = np.array([[0.0], 81 | [0.0]]) 82 | self.ssm['initial_state_cov'] = np.array([[1.0, 0.0], 83 | [0.0, 1.0]]) 84 | ``` 85 | 86 | You will also need to set up the `param_names` class method, which returns a list of parameter names. This lets users know what priors they need to define, and lets the model gather the required parameters from the PyMC model context. 87 | 88 | ```python 89 | @property 90 | def param_names(self): 91 | return ['x0', 'P0', 'sigma_obs', 'sigma_state'] 92 | ``` 93 | 94 | `self.ssm` is an `PytensorRepresentation` class object that is created by the super constructor. Every model has a `self.ssm` and a `self.kalman_filter` created after the super constructor is called. All the matrices stored in `self.ssm` are Pytensor tensor variables, but numpy arrays can be passed to them for convenience. Behind the scenes, they will be converted to Pytensor tensors. 95 | 96 | Note that the names of the matrices correspond to the names listed above. 
They are (in the same order): 97 | 98 | - Z = design 99 | - H = obs_cov 100 | - T = transition 101 | - R = selection 102 | - Q = state_cov 103 | - c = state_intercept 104 | - d = obs_intercept 105 | - a1 = initial_state 106 | - P1 = initial_state_cov 107 | 108 | Indexing by name only will expose the entire matrix. A name can also be followed by the usual numpy slice notation to get a specific element, row, or column. 109 | 110 | The user also needs to implement an `update` method, which takes in a single Pytensor tensor as an argument. This method routes the parameters estimated by PyMC into the right spots in the state space matrices. The local level has at least two parameters to estimate: the variance of the level state innovations, and the variance of the trend state innovations. Here is the corresponding update method: 111 | 112 | ```python 113 | def update(self, theta: at.TensorVariable) -> None: 114 | # Observation covariance 115 | self.ssm['obs_cov', 0, 0] = theta[0] 116 | 117 | # State covariance 118 | self.ssm['state_cov', np.diag_indices(2)] = theta[1:] 119 | ``` 120 | 121 | And that's it! By making a model you also gain access to simulation helper functions, including prior and posterior simulation from the predicted, filtered, and smoothed states, as well as stability analysis and impulse response functions. 122 | 123 | This package is very much a work in progress and really needs help! If you're interested in time series analysis and want to contribute, please reach out! 124 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UPDATE 2 | 3 | This repo is no longer maintained! Statespace models are now a part of [pymc-experimental](https://github.com/pymc-devs/pymc-experimental/), and are maintained by the PyMC development team. Please look over there for the most up-to-date Bayesian State Space models! 
4 | 5 | # PyMC StateSpace 6 | A system for Bayesian estimation of state space models using PyMC 5.0. This package is designed to mirror the functionality of the Statsmodels.api `tsa.statespace` module, except within a Bayesian estimation framework. To accomplish this, PyMC Statespace has a Kalman filter written in Pytensor, allowing the gradients of the iterative Kalman filter likelihood to be computed and provided to the PyMC NUTS sampler. 7 | 8 | ## State Space Models 9 | This package follows Statsmodels in using the Durbin and Koopman (2012) nomenclature for a linear state space model. Under this nomenclature, the model is written as: 10 | 11 | 12 | 13 | 14 | 15 | The objects in the above equation have the following shapes and meanings: 16 | 17 | - , observation vector 18 | - , state vector 19 | - , observation noise vector 20 | - , state innovation vector 21 | 22 | 23 | - , the design matrix 24 | - , observation noise covariance matrix 25 | - , the transition matrix 26 | - , the selection matrix 27 | - , the state innovation covariance matrix 28 | 29 | 30 | - , the state intercept vector 31 | - , the observation intercept vector 32 | 33 | The linear state space model is a workhorse in many disciplines, and is flexible enough to represent a wide range of models, including Box-Jenkins SARIMAX class models, time series decompositions, and model of multiple time series (VARMAX) models. Use of a Kalman filter allows for estimation of unobserved and missing variables. Taken together, these are a powerful class of models that can be further augmented by Bayesian estimation. This allows the researcher to integrate uncertainty about the true model when estimating model parameteres. 34 | 35 | 36 | ## Example Usage 37 | 38 | Currently, local level and ARMA models are implemented. To initialize a model, simply create a model object, provide data, and any additional arguments unique to that model. In addition, you can choose from several Kalman Filter implementations. 
Currently these are `standard`, `steady_state`, `univariate`, `single`, and `cholesky`. For more information, see the docs (they're on the to-do list) 39 | ```python 40 | import pymc_statespace as pmss 41 | #make sure data is a 2d numpy array! 42 | arma_model = pmss.BayesianARMA(data = data[:, None], order=(1, 1), filter_type='standard', stationary_initialization=True) 43 | ``` 44 | You will see a message letting you know the model was created, as well as telling you the expected parameter names you will need to create inside a PyMC model block. 45 | ```commandline 46 | Model successfully initialized! The following parameters should be assigned priors inside a PyMC model block: x0, sigma_state, rho, theta 47 | ``` 48 | 49 | Next, a PyMC model is declared as usual, and the parameters can be passed into the state space model. This is done as follows: 50 | ```python 51 | with pm.Model() as arma_model: 52 | x0 = pm.Normal('x0', mu=0, sigma=1.0, shape=mod.k_states) 53 | sigma_state = pm.HalfNormal('sigma_state', sigma=1) 54 | 55 | rho = pm.TruncatedNormal('rho', mu=0.0, sigma=0.5, lower=-1.0, upper=1.0, shape=p) 56 | theta = pm.Normal('theta', mu=0.0, sigma=0.5, shape=q) 57 | 58 | arma_model.build_statespace_graph() 59 | trace_1 = pm.sample(init='jitter+adapt_diag_grad', target_accept=0.9) 60 | ``` 61 | 62 | After you place priors over the requested parameters, call `arma_model.build_statespace_graph()` to automatically put everything together into a state space system. If you are missing any parameters, you will get an error letting you know what else needs to be defined. 63 | 64 | And that's it! After this, you can sample the PyMC model as normal. 65 | 66 | 67 | ## Creating your own state space model 68 | 69 | Creating a custom state space model isn't complicated. Once again, the API follows the Statsmodels implementation.
All models need to subclass the `PyMCStateSpace` class, and pass three values into the class's super constructor: `data` (from which p is inferred), `k_states` (this is "m" in the shapes above), and `k_posdef` (this is "r" above). The user also needs to declare any required state space matrices. Here is an example of a simple local linear trend model: 70 | 71 | ```python 72 | def __init__(self, data): 73 | # Model shapes 74 | k_states = k_posdef = 2 75 | 76 | super().__init__(data, k_states, k_posdef) 77 | 78 | # Initialize the matrices 79 | self.ssm['design'] = np.array([[1.0, 0.0]]) 80 | self.ssm['transition'] = np.array([[1.0, 1.0], 81 | [0.0, 1.0]]) 82 | self.ssm['selection'] = np.eye(k_states) 83 | 84 | self.ssm['initial_state'] = np.array([[0.0], 85 | [0.0]]) 86 | self.ssm['initial_state_cov'] = np.array([[1.0, 0.0], 87 | [0.0, 1.0]]) 88 | ``` 89 | You will also need to set up the `param_names` property, which returns a list of parameter names. This lets users know what priors they need to define, and lets the model gather the required parameters from the PyMC model context. 90 | 91 | ```python 92 | @property 93 | def param_names(self): 94 | return ['x0', 'P0', 'sigma_obs', 'sigma_state'] 95 | ``` 96 | 97 | `self.ssm` is a `PytensorRepresentation` class object that is created by the super constructor. Every model has a `self.ssm` and a `self.kalman_filter` created after the super constructor is called. All the matrices stored in `self.ssm` are Pytensor tensor variables, but numpy arrays can be passed to them for convenience. Behind the scenes, they will be converted to Pytensor tensors. 98 | 99 | Note that the names of the matrices correspond to the names listed above.
They are (in the same order): 100 | 101 | - Z = design 102 | - H = obs_cov 103 | - T = transition 104 | - R = selection 105 | - Q = state_cov 106 | - c = state_intercept 107 | - d = obs_intercept 108 | - a1 = initial_state 109 | - P1 = initial_state_cov 110 | 111 | Indexing by name only will expose the entire matrix. A name can also be followed by the usual numpy slice notation to get a specific element, row, or column. 112 | 113 | The user also needs to implement an `update` method, which takes in a single Pytensor tensor as an argument. This method routes the parameters estimated by PyMC into the right spots in the state space matrices. The local level has at least two parameters to estimate: the variance of the level state innovations, and the variance of the trend state innovations. Here is the corresponding update method: 114 | 115 | ```python 116 | def update(self, theta: at.TensorVariable) -> None: 117 | # Observation covariance 118 | self.ssm['obs_cov', 0, 0] = theta[0] 119 | 120 | # State covariance 121 | self.ssm['state_cov', np.diag_indices(2)] = theta[1:] 122 | ``` 123 | 124 | This function is why the order matters when flattening and concatenating the random variables inside the PyMC model. In this case, we must first pass `sigma_obs`, followed by `sigma_level`, then `sigma_trend`. 125 | 126 | But that's it! Obviously this API isn't great, and will be subject to change as the package evolves, but it should be enough to get a motivated researcher going. Happy estimation, and let me know all the bugs you find by opening an issue.
127 | -------------------------------------------------------------------------------- /REQUIREMENTS.txt: -------------------------------------------------------------------------------- 1 | # Base dependencies 2 | pymc>=5.2 3 | pytensor 4 | numba>=0.57 5 | numpy 6 | pandas 7 | xarray 8 | matplotlib 9 | arviz 10 | statsmodels 11 | 12 | # Testing dependencies 13 | pre-commit 14 | pytest-cov>=2.5 15 | pytest>=3.0 16 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: yes 3 | 4 | coverage: 5 | precision: 2 6 | round: down 7 | range: "70...100" 8 | status: 9 | project: 10 | default: 11 | # basic 12 | target: auto 13 | threshold: 1% 14 | base: auto 15 | patch: 16 | default: 17 | # basic 18 | target: 50% 19 | threshold: 1% 20 | base: auto 21 | 22 | ignore: 23 | - "pymc_statespace/tests/*" 24 | - "pymc_statespace/examples/*" 25 | - "pymc_statespace/data/*" 26 | 27 | comment: 28 | layout: "reach, diff, flags, files" 29 | behavior: default 30 | require_changes: false # if true: only post the comment if coverage changes 31 | require_base: no # [yes :: must have a base report to post] 32 | require_head: yes # [yes :: must have a head report to post] 33 | branches: null # branch names that can post comment 34 | -------------------------------------------------------------------------------- /conda_envs/pymc_statespace.yml: -------------------------------------------------------------------------------- 1 | name: pymc-statespace 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | # Base dependencies 6 | - pymc 7 | - pytensor 8 | - numba::numba>=0.57 9 | - numpy 10 | - pandas 11 | - xarray 12 | - matplotlib 13 | - arviz 14 | - statsmodels 15 | # Testing dependencies 16 | - pre-commit 17 | - pytest-cov>=2.5 18 | - pytest>=3.0 19 | - pytest-env 20 | 
-------------------------------------------------------------------------------- /data/nile.csv: -------------------------------------------------------------------------------- 1 | "x" 2 | 1120 3 | 1160 4 | 963 5 | 1210 6 | 1160 7 | 1160 8 | 813 9 | 1230 10 | 1370 11 | 1140 12 | 995 13 | 935 14 | 1110 15 | 994 16 | 1020 17 | 960 18 | 1180 19 | 799 20 | 958 21 | 1140 22 | 1100 23 | 1210 24 | 1150 25 | 1250 26 | 1260 27 | 1220 28 | 1030 29 | 1100 30 | 774 31 | 840 32 | 874 33 | 694 34 | 940 35 | 833 36 | 701 37 | 916 38 | 692 39 | 1020 40 | 1050 41 | 969 42 | 831 43 | 726 44 | 456 45 | 824 46 | 702 47 | 1120 48 | 1100 49 | 832 50 | 764 51 | 821 52 | 768 53 | 845 54 | 864 55 | 862 56 | 698 57 | 845 58 | 744 59 | 796 60 | 1040 61 | 759 62 | 781 63 | 865 64 | 845 65 | 944 66 | 984 67 | 897 68 | 822 69 | 1010 70 | 771 71 | 676 72 | 649 73 | 846 74 | 812 75 | 742 76 | 801 77 | 1040 78 | 860 79 | 874 80 | 848 81 | 890 82 | 744 83 | 749 84 | 838 85 | 1050 86 | 918 87 | 986 88 | 797 89 | 923 90 | 975 91 | 815 92 | 1020 93 | 906 94 | 901 95 | 1170 96 | 912 97 | 746 98 | 919 99 | 718 100 | 714 101 | 740 102 | -------------------------------------------------------------------------------- /pymc_statespace/__init__.py: -------------------------------------------------------------------------------- 1 | from pymc_statespace.models.local_level import BayesianLocalLevel 2 | from pymc_statespace.models.SARIMAX import BayesianARMA 3 | from pymc_statespace.models.VARMAX import BayesianVARMAX 4 | 5 | __all__ = ["BayesianLocalLevel", "BayesianARMA", "BayesianVARMAX"] 6 | 7 | __version__ = "0.0.1" 8 | -------------------------------------------------------------------------------- /pymc_statespace/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessegrabowski/pymc_statespace/0659a3bfd9186f128f238c69d5193d3c461f3948/pymc_statespace/core/__init__.py 
-------------------------------------------------------------------------------- /pymc_statespace/core/representation.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from typing import List, Optional, Tuple, Type, Union 3 | 4 | import numpy as np 5 | import pandas.core.tools.datetimes 6 | import pytensor.tensor as pt 7 | from pandas import DataFrame 8 | from pytensor.tensor import TensorVariable 9 | 10 | KeyLike = Union[Tuple[Union[str, int]], str] 11 | 12 | 13 | def _preprocess_data(data: Union[DataFrame, np.ndarray], expected_dims=3): 14 | if isinstance(data, pandas.DataFrame): 15 | data = data.values 16 | elif not isinstance(data, np.ndarray): 17 | raise ValueError("Expected pandas Dataframe or numpy array as data") 18 | 19 | if data.ndim < expected_dims: 20 | n_dims = data.ndim 21 | n_to_add = expected_dims - n_dims + 1 22 | data = reduce(lambda a, b: np.expand_dims(a, -1), [data] * n_to_add) 23 | 24 | return data 25 | 26 | 27 | class PytensorRepresentation: 28 | def __init__( 29 | self, 30 | data: Union[DataFrame, np.ndarray], 31 | k_states: int, 32 | k_posdef: int, 33 | design: Optional[np.ndarray] = None, 34 | obs_intercept: Optional[np.ndarray] = None, 35 | obs_cov=None, 36 | transition=None, 37 | state_intercept=None, 38 | selection=None, 39 | state_cov=None, 40 | initial_state=None, 41 | initial_state_cov=None, 42 | ) -> None: 43 | """ 44 | A representation of a State Space model, in Pytensor. Shamelessly copied from the Statsmodels.api implementation 45 | found here: 46 | 47 | https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py 48 | 49 | Parameters 50 | ---------- 51 | data: ArrayLike 52 | Array of observed data (called exog in statsmodels) 53 | k_states: int 54 | Number of hidden states 55 | k_posdef: int 56 | Number of states that have exogenous shocks; also the rank of the selection matrix R. 
57 | design: ArrayLike, optional 58 | Design matrix, denoted 'Z' in [1]. 59 | obs_intercept: ArrayLike, optional 60 | Constant vector in the observation equation, denoted 'd' in [1]. Currently 61 | not used. 62 | obs_cov: ArrayLike, optional 63 | Covariance matrix for multivariate-normal errors in the observation equation. Denoted 'H' in 64 | [1]. 65 | transition: ArrayLike, optional 66 | Transition equation that updates the hidden state between time-steps. Denoted 'T' in [1]. 67 | state_intercept: ArrayLike, optional 68 | Constant vector for the observation equation, denoted 'c' in [1]. Currently not used. 69 | selection: ArrayLike, optional 70 | Selection matrix that matches shocks to hidden states, denoted 'R' in [1]. This is the identity 71 | matrix when k_posdef = k_states. 72 | state_cov: ArrayLike, optional 73 | Covariance matrix for state equations, denoted 'Q' in [1]. Null matrix when there is no observation 74 | noise. 75 | initial_state: ArrayLike, optional 76 | Experimental setting to allow for Bayesian estimation of the initial state, denoted `alpha_0` in [1]. Default 77 | It should potentially be removed in favor of the closed-form diffuse initialization. 78 | initial_state_cov: ArrayLike, optional 79 | Experimental setting to allow for Bayesian estimation of the initial state, denoted `P_0` in [1]. Default 80 | It should potentially be removed in favor of the closed-form diffuse initialization. 81 | 82 | References 83 | ---------- 84 | .. [1] Durbin, James, and Siem Jan Koopman. 2012. 85 | Time Series Analysis by State Space Methods: Second Edition. 86 | Oxford University Press. 
87 | """ 88 | 89 | # self.data = pt.tensor3(name="Data") 90 | # self.transition = pt.tensor3(name="transition") 91 | # self.selection = pt.tensor3(name="selection") 92 | # self.design = pt.tensor3(name="design") 93 | # self.obs_cov = pt.tensor3(name="obs_cov") 94 | # self.state_cov = pt.tensor3(name="state_cov") 95 | # self.state_intercept = pt.tensor3(name="state_cov") 96 | # self.obs_intercept = pt.tensor3(name="state_cov") 97 | # self.initial_state = pt.tensor3(name="state_cov") 98 | # self.initial_state_cov = pt.tensor3(name="state_cov") 99 | 100 | self.data = _preprocess_data(data) 101 | self.k_states = k_states 102 | self.k_posdef = k_posdef if k_posdef is not None else k_states 103 | 104 | self.n_obs, self.k_endog, *_ = data.shape 105 | 106 | # The last dimension is for time varying matrices; it could be n_obs. Not thinking about that now. 107 | self.shapes = { 108 | "data": (self.k_endog, self.n_obs, 1), 109 | "design": (self.k_endog, self.k_states, 1), 110 | "obs_intercept": (self.k_endog, 1, 1), 111 | "obs_cov": (self.k_endog, self.k_endog, 1), 112 | "transition": (self.k_states, self.k_states, 1), 113 | "state_intercept": (self.k_states, 1, 1), 114 | "selection": (self.k_states, self.k_posdef, 1), 115 | "state_cov": (self.k_posdef, self.k_posdef, 1), 116 | "initial_state": (self.k_states, 1, 1), 117 | "initial_state_cov": (self.k_states, self.k_states, 1), 118 | } 119 | 120 | # Initialize the representation matrices 121 | scope = locals() 122 | for name, shape in self.shapes.items(): 123 | if name == "data": 124 | continue 125 | 126 | elif scope[name] is not None: 127 | matrix = self._numpy_to_pytensor(name, scope[name]) 128 | setattr(self, name, matrix) 129 | 130 | else: 131 | setattr(self, name, pt.zeros(shape)) 132 | 133 | def _validate_key(self, key: KeyLike) -> None: 134 | if key not in self.shapes: 135 | raise IndexError(f"{key} is an invalid state space matrix name") 136 | 137 | def update_shape(self, key: KeyLike, value: Union[np.ndarray, 
pt.TensorType]) -> None: 138 | # TODO: Get rid of these evals 139 | 140 | if isinstance(value, (pt.TensorConstant, pt.TensorVariable)): 141 | shape = value.shape.eval() 142 | else: 143 | shape = value.shape 144 | 145 | old_shape = self.shapes[key] 146 | if not all([a == b for a, b in zip(shape[:2], old_shape[:2])]): 147 | raise ValueError( 148 | f"The first two dimensions of {key} must be {old_shape[:2]}, found {shape[:2]}" 149 | ) 150 | 151 | # Add time dimension dummy if none present 152 | if len(shape) == 2: 153 | self.shapes[key] = shape + (1,) 154 | 155 | self.shapes[key] = shape 156 | 157 | def _add_time_dim_to_slice( 158 | self, name: str, slice_: Union[List[int], Tuple[int]], n_dim: int 159 | ) -> Tuple[int]: 160 | no_time_dim = self.shapes[name][-1] == 1 161 | 162 | # Case 1: All dimensions are sliced 163 | if len(slice_) == n_dim: 164 | return slice_ 165 | 166 | # Case 2a: There is a time dim. Just return. 167 | if not no_time_dim: 168 | return slice_ 169 | 170 | # Case 2b: There's no time dim. Slice away the dummy dim. 
171 | if len(slice_) < n_dim: 172 | empty_slice = (slice(None, None, None),) 173 | n_omitted = n_dim - len(slice_) - 1 174 | return tuple(slice_) + empty_slice * n_omitted + (0,) 175 | 176 | @staticmethod 177 | def _validate_key_and_get_type(key: KeyLike) -> Type[str]: 178 | if isinstance(key, tuple) and not isinstance(key[0], str): 179 | raise IndexError("First index must the name of a valid state space matrix.") 180 | 181 | return type(key) 182 | 183 | def _validate_matrix_shape(self, name: str, X: np.ndarray) -> None: 184 | *expected_shape, time_dim = self.shapes[name] 185 | expected_shape = tuple(expected_shape) 186 | 187 | if X.ndim > 3 or X.ndim < 2: 188 | raise ValueError( 189 | f"Array provided for {name} has {X.ndim} dimensions, " 190 | f"expecting 2 (static) or 3 (time-varying)" 191 | ) 192 | 193 | if X.ndim == 2: 194 | if expected_shape != X.shape: 195 | raise ValueError( 196 | f"Array provided for {name} has shape {X.shape}, expected {expected_shape}" 197 | ) 198 | if X.ndim == 3: 199 | if X.shape[:2] != expected_shape: 200 | raise ValueError( 201 | f"First two dimensions of array provided for {name} has shape {X.shape[:2]}, " 202 | f"expected {expected_shape}" 203 | ) 204 | if X.shape[-1] != self.data.shape[0]: 205 | raise ValueError( 206 | f"Last dimension (time dimension) of array provided for {name} has shape " 207 | f"{X.shape[-1]}, expected {self.data.shape[0]} (equal to the first dimension of the " 208 | f"provided data)" 209 | ) 210 | 211 | def _numpy_to_pytensor(self, name: str, X: np.ndarray) -> pt.TensorVariable: 212 | X = X.copy() 213 | self._validate_matrix_shape(name, X) 214 | # Add a time dimension if one isn't provided 215 | if X.ndim == 2: 216 | X = X[..., None] 217 | return pt.as_tensor(X, name=name) 218 | 219 | def __getitem__(self, key: KeyLike) -> pt.TensorVariable: 220 | _type = self._validate_key_and_get_type(key) 221 | 222 | # Case 1: user asked for an entire matrix by name 223 | if _type is str: 224 | self._validate_key(key) 225 
| matrix = getattr(self, key) 226 | 227 | # Slice away the time dimension if it's a dummy 228 | if self.shapes[key][-1] == 1: 229 | return matrix[(slice(None),) * (matrix.ndim - 1) + (0,)] 230 | 231 | # If it's time varying, return everything 232 | else: 233 | return matrix 234 | 235 | # Case 2: user asked for a particular matrix and some slices of it 236 | elif _type is tuple: 237 | name, *slice_ = key 238 | self._validate_key(name) 239 | 240 | matrix = getattr(self, name) 241 | slice_ = self._add_time_dim_to_slice(name, slice_, matrix.ndim) 242 | 243 | return matrix[slice_] 244 | 245 | # Case 3: There is only one slice index, but it's not a string 246 | else: 247 | raise IndexError("First index must the name of a valid state space matrix.") 248 | 249 | def __setitem__(self, key: KeyLike, value: Union[float, int, np.ndarray]) -> None: 250 | _type = type(key) 251 | # Case 1: key is a string: we are setting an entire matrix. 252 | if _type is str: 253 | self._validate_key(key) 254 | if isinstance(value, np.ndarray): 255 | value = self._numpy_to_pytensor(key, value) 256 | setattr(self, key, value) 257 | self.update_shape(key, value) 258 | 259 | # Case 2: key is a string plus a slice: we are setting a subset of a matrix 260 | elif _type is tuple: 261 | name, *slice_ = key 262 | self._validate_key(name) 263 | 264 | matrix = getattr(self, name) 265 | 266 | slice_ = self._add_time_dim_to_slice(name, slice_, matrix.ndim) 267 | 268 | matrix = pt.set_subtensor(matrix[slice_], value) 269 | setattr(self, name, matrix) 270 | -------------------------------------------------------------------------------- /pymc_statespace/filters/__init__.py: -------------------------------------------------------------------------------- 1 | from pymc_statespace.filters.kalman_filter import ( 2 | CholeskyFilter, 3 | SingleTimeseriesFilter, 4 | StandardFilter, 5 | SteadyStateFilter, 6 | UnivariateFilter, 7 | ) 8 | from pymc_statespace.filters.kalman_smoother import KalmanSmoother 9 | 10 | 
__all__ = [ 11 | "StandardFilter", 12 | "UnivariateFilter", 13 | "SteadyStateFilter", 14 | "KalmanSmoother", 15 | "SingleTimeseriesFilter", 16 | "CholeskyFilter", 17 | ] 18 | -------------------------------------------------------------------------------- /pymc_statespace/filters/kalman_smoother.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytensor 4 | import pytensor.tensor as pt 5 | from pytensor.compile import get_mode 6 | from pytensor.tensor.nlinalg import matrix_dot 7 | 8 | from pymc_statespace.filters.utilities import split_vars_into_seq_and_nonseq 9 | 10 | 11 | class KalmanSmoother: 12 | def __init__(self, mode: Optional[str] = None): 13 | self.mode = mode 14 | self.seq_names = [] 15 | self.non_seq_names = [] 16 | 17 | def unpack_args(self, args): 18 | """ 19 | The order of inputs to the inner scan function is not known, since some, all, or none of the input matrices 20 | can be time varying. The order arguments are fed to the inner function is sequences, outputs_info, 21 | non-sequences. This function works out which matrices are where, and returns a standardized order expected 22 | by the kalman_step function. 23 | 24 | The standard order is: a, P, a_smooth, P_smooth, T, R, Q 25 | """ 26 | # If there are no sequence parameters (all params are static), 27 | # no changes are needed, params will be in order. 
28 | args = list(args) 29 | n_seq = len(self.seq_names) 30 | if n_seq == 0: 31 | return args 32 | 33 | # The first two args are always a and P 34 | a = args.pop(0) 35 | P = args.pop(0) 36 | 37 | # There are always two outputs_info wedged between the seqs and non_seqs 38 | seqs, (a_smooth, P_smooth), non_seqs = ( 39 | args[:n_seq], 40 | args[n_seq : n_seq + 2], 41 | args[n_seq + 2 :], 42 | ) 43 | return_ordered = [] 44 | for name in ["T", "R", "Q"]: 45 | if name in self.seq_names: 46 | idx = self.seq_names.index(name) 47 | return_ordered.append(seqs[idx]) 48 | else: 49 | idx = self.non_seq_names.index(name) 50 | return_ordered.append(non_seqs[idx]) 51 | 52 | T, R, Q = return_ordered 53 | 54 | return a, P, a_smooth, P_smooth, T, R, Q 55 | 56 | def build_graph(self, T, R, Q, filtered_states, filtered_covariances, mode=None): 57 | self.mode = mode 58 | 59 | a_last = filtered_states[-1] 60 | P_last = filtered_covariances[-1] 61 | 62 | sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq( 63 | [T, R, Q], ["T", "R", "Q"] 64 | ) 65 | 66 | self.seq_names = seq_names 67 | self.non_seq_names = non_seq_names 68 | 69 | smoother_result, updates = pytensor.scan( 70 | self.smoother_step, 71 | sequences=[filtered_states[:-1], filtered_covariances[:-1]] + sequences, 72 | outputs_info=[a_last, P_last], 73 | non_sequences=non_sequences, 74 | go_backwards=True, 75 | name="kalman_smoother", 76 | mode=get_mode(self.mode), 77 | ) 78 | 79 | smoothed_states, smoothed_covariances = smoother_result 80 | smoothed_states = pt.concatenate([smoothed_states[::-1], pt.atleast_3d(a_last)], axis=0) 81 | smoothed_covariances = pt.concatenate( 82 | [smoothed_covariances[::-1], pt.atleast_3d(P_last)], axis=0 83 | ) 84 | 85 | return smoothed_states, smoothed_covariances 86 | 87 | def smoother_step(self, *args): 88 | a, P, a_smooth, P_smooth, T, R, Q = self.unpack_args(args) 89 | a_hat, P_hat = self.predict(a, P, T, R, Q) 90 | 91 | # Use pinv, otherwise P_hat is singular 
when there is missing data 92 | smoother_gain = matrix_dot(pt.linalg.pinv(P_hat), T, P).T 93 | a_smooth_next = a + smoother_gain @ (a_smooth - a_hat) 94 | 95 | P_smooth_next = P + matrix_dot(smoother_gain, P_smooth - P_hat, smoother_gain.T) 96 | 97 | return a_smooth_next, P_smooth_next 98 | 99 | @staticmethod 100 | def predict(a, P, T, R, Q): 101 | a_hat = T.dot(a) 102 | P_hat = matrix_dot(T, P, T.T) + matrix_dot(R, Q, R.T) 103 | 104 | return a_hat, P_hat 105 | -------------------------------------------------------------------------------- /pymc_statespace/filters/utilities.py: -------------------------------------------------------------------------------- 1 | def split_vars_into_seq_and_nonseq(params, param_names): 2 | """ 3 | Split inputs into those that are time varying and those that are not. This division is required by scan. 4 | """ 5 | sequences, non_sequences = [], [] 6 | seq_names, non_seq_names = [], [] 7 | 8 | for param, name in zip(params, param_names): 9 | if param.ndim == 2: 10 | non_sequences.append(param) 11 | non_seq_names.append(name) 12 | elif param.ndim == 3: 13 | sequences.append(param) 14 | seq_names.append(name) 15 | else: 16 | raise ValueError( 17 | f"Matrix {name} has {param.ndim}, it should either 2 (static) or 3 (time varying)." 
18 | ) 19 | 20 | return sequences, non_sequences, seq_names, non_seq_names 21 | -------------------------------------------------------------------------------- /pymc_statespace/models/SARIMAX.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import pytensor.tensor as at 5 | 6 | from pymc_statespace.core.statespace import PyMCStateSpace 7 | from pymc_statespace.utils.pytensor_scipy import solve_discrete_lyapunov 8 | 9 | 10 | class BayesianARMA(PyMCStateSpace): 11 | def __init__( 12 | self, 13 | data, 14 | order: Tuple[int, int], 15 | stationary_initialization: bool = True, 16 | filter_type: str = "standard", 17 | ): 18 | 19 | # Model order 20 | self.p, self.q = order 21 | 22 | self.stationary_initialization = stationary_initialization 23 | 24 | k_states = max(self.p, self.q + 1) 25 | k_posdef = 1 26 | 27 | super().__init__(data, k_states, k_posdef, filter_type) 28 | 29 | # Initialize the matrices 30 | self.ssm["design"] = np.r_[[1.0], np.zeros(k_states - 1)][None] 31 | 32 | self.ssm["transition"] = np.eye(k_states, k=1) 33 | 34 | self.ssm["selection"] = np.r_[[[1.0]], np.zeros(k_states - 1)[:, None]] 35 | 36 | self.ssm["initial_state"] = np.zeros(k_states)[:, None] 37 | 38 | self.ssm["initial_state_cov"] = np.eye(k_states) 39 | 40 | # Cache some indices 41 | self._state_cov_idx = ("state_cov",) + np.diag_indices(k_posdef) 42 | self._ar_param_idx = ("transition",) + ( 43 | np.arange(self.p, dtype=int), 44 | np.zeros(self.p, dtype=int), 45 | ) 46 | self._ma_param_idx = ("selection",) + ( 47 | np.arange(1, self.q + 1, dtype=int), 48 | np.zeros(self.q, dtype=int), 49 | ) 50 | 51 | @property 52 | def param_names(self): 53 | names = ["x0", "P0", "sigma_state", "rho", "theta"] 54 | if self.stationary_initialization: 55 | names.remove("P0") 56 | 57 | return names 58 | 59 | def update(self, theta: at.TensorVariable) -> None: 60 | """ 61 | Put parameter values from vector theta 
into the correct positions in the state space matrices. 62 | TODO: Can this be done using variable names to avoid the need to ravel and concatenate all RVs in the 63 | PyMC model? 64 | 65 | Parameters 66 | ---------- 67 | theta: TensorVariable 68 | Vector of all variables in the state space model 69 | """ 70 | cursor = 0 71 | 72 | # initial states 73 | param_slice = slice(cursor, cursor + self.k_states) 74 | cursor += self.k_states 75 | self.ssm["initial_state", :, 0] = theta[param_slice] 76 | 77 | if not self.stationary_initialization: 78 | # initial covariance 79 | param_slice = slice(cursor, cursor + self.k_states**2) 80 | cursor += self.k_states**2 81 | self.ssm["initial_state_cov", :, :] = theta[param_slice].reshape( 82 | (self.k_states, self.k_states) 83 | ) 84 | 85 | # State covariance 86 | param_slice = slice(cursor, cursor + 1) 87 | cursor += 1 88 | self.ssm[self._state_cov_idx] = theta[param_slice] 89 | 90 | # AR parameteres 91 | param_slice = slice(cursor, cursor + self.p) 92 | cursor += self.p 93 | self.ssm[self._ar_param_idx] = theta[param_slice] 94 | 95 | # MA parameters 96 | param_slice = slice(cursor, cursor + self.q) 97 | cursor += self.q 98 | self.ssm[self._ma_param_idx] = theta[param_slice] 99 | 100 | if self.stationary_initialization: 101 | # Solve for matrix quadratic for P0 102 | T = self.ssm["transition"] 103 | R = self.ssm["selection"] 104 | Q = self.ssm["state_cov"] 105 | 106 | P0 = solve_discrete_lyapunov(T, at.linalg.matrix_dot(R, Q, R.T), method="bilinear") 107 | self.ssm["initial_state_cov", :, :] = P0 108 | -------------------------------------------------------------------------------- /pymc_statespace/models/VARMAX.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import pytensor.tensor as at 5 | 6 | from pymc_statespace.core.statespace import PyMCStateSpace 7 | from pymc_statespace.models.utilities import get_slice_and_move_cursor 8 | from 
pymc_statespace.utils.pytensor_scipy import solve_discrete_lyapunov 9 | 10 | 11 | class BayesianVARMAX(PyMCStateSpace): 12 | def __init__( 13 | self, 14 | data, 15 | order: Tuple[int, int], 16 | stationary_initialization: bool = True, 17 | filter_type: str = "standard", 18 | measurement_error: bool = True, 19 | verbose=True, 20 | ): 21 | 22 | self.p, self.q = order 23 | self.stationary_initialization = stationary_initialization 24 | self.measurement_error = measurement_error 25 | 26 | k_order = max(self.p, 1) + self.q 27 | k_obs = data.shape[1] 28 | k_states = k_obs * k_order 29 | k_posdef = data.shape[1] 30 | 31 | super().__init__(data, k_states, k_posdef, filter_type, verbose=verbose) 32 | 33 | # Save counts of the number of parameters in each category 34 | self.param_counts = { 35 | "x0": k_states, 36 | "P0": k_states**2 * (1 - self.stationary_initialization), 37 | "AR": k_obs**2 * self.p, 38 | "MA": k_obs**2 * self.q, 39 | "state_cov": k_posdef**2, 40 | "obs_cov": k_obs * self.measurement_error, 41 | } 42 | 43 | # Initialize the matrices 44 | # Design matrix is a truncated identity (first k_obs states observed) 45 | self.ssm[("design",) + np.diag_indices(k_obs)] = 1 46 | 47 | # Transition matrix has 4 blocks: 48 | self.ssm["transition"] = np.zeros((k_states, k_states)) 49 | 50 | # UL: AR coefs (k_obs, k_obs * min(p, 1)) 51 | # UR: MA coefs (k_obs, k_obs * q) 52 | # LL: Truncated identity (k_obs * min(p, 1), k_obs * min(p, 1)) 53 | # LR: Shifted identity (k_obs * p, k_obs * q) 54 | if self.p > 1: 55 | idx = (slice(k_obs, k_obs * self.p), slice(0, k_obs * (self.p - 1))) 56 | self.ssm[("transition",) + idx] = np.eye(k_obs * (self.p - 1)) 57 | 58 | if self.q > 1: 59 | idx = (slice(-k_obs * (self.q - 1), None), slice(-k_obs * self.q, -k_obs)) 60 | self.ssm[("transition",) + idx] = np.eye(k_obs * (self.q - 1)) 61 | 62 | # The selection matrix is (k_states, k_obs), with two (k_obs, k_obs) identity 63 | # matrix blocks inside. 
One is always on top, the other starts after (k_obs * p) rows 64 | self.ssm["selection"] = np.zeros((k_states, k_obs)) 65 | self.ssm["selection", slice(0, k_obs), :] = np.eye(k_obs) 66 | if self.q > 0: 67 | end = -k_obs * (self.q - 1) if self.q > 1 else None 68 | self.ssm["selection", slice(k_obs * -self.q, end), :] = np.eye(k_obs) 69 | 70 | # self.ssm["initial_state"] = np.zeros(k_states)[:, None] 71 | # self.ssm["initial_state_cov"] = np.eye(k_states) 72 | # self.ssm["state_cov"] = np.eye(k_posdef) 73 | # 74 | # if self.measurement_error: 75 | # self.ssm['obs_cov'] = np.eye(k_obs) 76 | 77 | # Cache some indices 78 | self._ar_param_idx = ("transition", slice(0, k_obs), slice(0, k_obs * self.p)) 79 | self._ma_param_idx = ("transition", slice(0, k_obs), slice(k_obs * max(1, self.p), None)) 80 | self._obs_cov_idx = ("obs_cov",) + np.diag_indices(k_obs) 81 | 82 | @property 83 | def param_names(self): 84 | names = ["x0", "P0", "ar_params", "ma_params", "state_cov", "obs_cov"] 85 | if self.stationary_initialization: 86 | names.remove("P0") 87 | if not self.measurement_error: 88 | names.remove("obs_cov") 89 | if self.p == 0: 90 | names.remove("ar_params") 91 | if self.q == 0: 92 | names.remove("ma_params") 93 | return names 94 | 95 | def update(self, theta: at.TensorVariable) -> None: 96 | """ 97 | Put parameter values from vector theta into the correct positions in the state space matrices. 
98 | 99 | Parameters 100 | ---------- 101 | theta: TensorVariable 102 | Vector of all variables in the state space model 103 | """ 104 | 105 | cursor = 0 106 | # initial states 107 | param_slice, cursor = get_slice_and_move_cursor(cursor, self.param_counts["x0"]) 108 | self.ssm["initial_state", :, 0] = theta[param_slice] 109 | 110 | if not self.stationary_initialization: 111 | # initial covariance 112 | param_slice, cursor = get_slice_and_move_cursor(cursor, self.param_counts["P0"]) 113 | self.ssm["initial_state_cov", :, :] = theta[param_slice].reshape( 114 | (self.k_states, self.k_states) 115 | ) 116 | 117 | # AR parameters 118 | if self.p > 0: 119 | ar_shape = (self.k_endog, self.k_endog * self.p) 120 | param_slice, cursor = get_slice_and_move_cursor(cursor, self.param_counts["AR"]) 121 | self.ssm[self._ar_param_idx] = theta[param_slice].reshape(ar_shape) 122 | 123 | # MA parameters 124 | if self.q > 0: 125 | ma_shape = (self.k_endog, self.k_endog * self.q) 126 | param_slice, cursor = get_slice_and_move_cursor(cursor, self.param_counts["MA"]) 127 | self.ssm[self._ma_param_idx] = theta[param_slice].reshape(ma_shape) 128 | 129 | # State covariance 130 | param_slice, cursor = get_slice_and_move_cursor( 131 | cursor, self.param_counts["state_cov"], last_slice=not self.measurement_error 132 | ) 133 | 134 | self.ssm["state_cov", :, :] = theta[param_slice].reshape((self.k_posdef, self.k_posdef)) 135 | 136 | # Measurement error 137 | if self.measurement_error: 138 | param_slice, cursor = get_slice_and_move_cursor( 139 | cursor, self.param_counts["obs_cov"], last_slice=True 140 | ) 141 | self.ssm[self._obs_cov_idx] = theta[param_slice] 142 | 143 | if self.stationary_initialization: 144 | # Solve for matrix quadratic for P0 145 | T = self.ssm["transition"] 146 | R = self.ssm["selection"] 147 | Q = self.ssm["state_cov"] 148 | 149 | P0 = solve_discrete_lyapunov(T, at.linalg.matrix_dot(R, Q, R.T), method="bilinear") 150 | self.ssm["initial_state_cov", :, :] = P0 151 | 
# --- pymc_statespace/models/local_level.py ---
import numpy as np
import pytensor.tensor as at
from numba import njit

from pymc_statespace.core.statespace import PyMCStateSpace


class BayesianLocalLevel(PyMCStateSpace):
    """Two-state structural time series model.

    The design row [1, 0] observes the first state, and the transition matrix
    [[1, 1], [0, 1]] advances the first state by the second each step.

    NOTE(review): that transition matrix encodes a level *plus slope* (local
    linear trend) rather than a pure local level — confirm the intended model.
    """

    def __init__(self, data):
        # Both states receive innovations, so k_posdef == k_states.
        k_states = k_posdef = 2

        super().__init__(data, k_states, k_posdef)

        # Initialize the time-invariant system matrices
        self.ssm["design"] = np.array([[1.0, 0.0]])
        self.ssm["transition"] = np.array([[1.0, 1.0], [0.0, 1.0]])
        self.ssm["selection"] = np.eye(k_states)

        self.ssm["initial_state"] = np.array([[0.0], [0.0]])
        self.ssm["initial_state_cov"] = np.array([[1.0, 0.0], [0.0, 1.0]])

        # Cache the diagonal indices of the state covariance so update() can
        # write the variances without rebuilding the index tuple on every call.
        self._state_cov_idx = ("state_cov",) + np.diag_indices(k_posdef)

    @property
    def param_names(self) -> list[str]:
        """Names of the parameters expected in ``theta``, in cursor order."""
        return ["x0", "P0", "sigma_obs", "sigma_state"]

    def update(self, theta: at.TensorVariable) -> None:
        """
        Put parameter values from vector theta into the correct positions in the state space matrices.
        TODO: Can this be done using variable names to avoid the need to ravel and concatenate all RVs in the
        PyMC model?

        Parameters
        ----------
        theta: TensorVariable
            Vector of all variables in the state space model: theta[:2] is x0,
            theta[2:6] the raveled 2x2 P0, theta[6] the observation variance,
            and theta[7:] the diagonal of the state covariance.
        """
        # initial states
        self.ssm["initial_state", :, 0] = theta[:2]

        # initial covariance
        self.ssm["initial_state_cov", :, :] = theta[2:6].reshape((2, 2))

        # Observation covariance
        self.ssm["obs_cov", 0, 0] = theta[6]

        # State covariance (diagonal entries only)
        self.ssm[self._state_cov_idx] = theta[7:]


# --- pymc_statespace/models/utilities.py ---
def get_slice_and_move_cursor(cursor, param_count, last_slice=False):
    """Return the slice covering the next ``param_count`` entries of the
    parameter vector, together with the advanced cursor.

    When ``last_slice`` is True the returned slice is open ended
    (``theta[cursor:]``), capturing everything remaining in the vector.
    """
    param_slice = slice(cursor, None if last_slice else cursor + param_count)
    cursor += param_count

    return param_slice, cursor


# --- pymc_statespace/utils/numba_linalg.py ---
@njit
def numba_block_diagonal(nd_array):
    """Assemble a stack of matrices with shape (n, rows, cols) into a single
    (n * rows, n * cols) block-diagonal matrix, with nd_array[i] as the i-th
    diagonal block.
    """
    n, rows, cols = nd_array.shape

    out = np.zeros((n * rows, n * cols))

    # Every block has the same shape, so the offsets advance by fixed strides;
    # this replaces the original enumerate-over-a-duplicated-list idiom.
    for i in range(n):
        out[i * rows : (i + 1) * rows, i * cols : (i + 1) * cols] = nd_array[i]
    return out
# --- pymc_statespace/utils/pytensor_scipy.py ---
from typing import Optional

import pytensor
import pytensor.tensor as at
import scipy
from pytensor.tensor import TensorVariable, as_tensor_variable
from pytensor.tensor.nlinalg import matrix_dot
from pytensor.tensor.slinalg import solve_discrete_lyapunov


class SolveDiscreteARE(at.Op):
    """PyTensor Op wrapping ``scipy.linalg.solve_discrete_are`` with an
    analytic gradient, so the DARE solution can be used inside PyMC graphs.
    """

    __props__ = ("enforce_Q_symmetric",)

    def __init__(self, enforce_Q_symmetric=False):
        # When True, Q is symmetrized as 0.5 * (Q + Q.T) before the scipy call;
        # useful when Q is symmetric only up to floating point error.
        self.enforce_Q_symmetric = enforce_Q_symmetric

    def make_node(self, A, B, Q, R):
        """Build the Apply node; the output dtype is the upcast of all inputs."""
        A = as_tensor_variable(A)
        B = as_tensor_variable(B)
        Q = as_tensor_variable(Q)
        R = as_tensor_variable(R)

        out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype, Q.dtype, R.dtype)
        X = pytensor.tensor.matrix(dtype=out_dtype)

        return pytensor.graph.basic.Apply(self, [A, B, Q, R], [X])

    def perform(self, node, inputs, output_storage):
        """Numerically solve the DARE with scipy and store the result."""
        A, B, Q, R = inputs
        X = output_storage[0]

        if self.enforce_Q_symmetric:
            Q = 0.5 * (Q + Q.T)
        X[0] = scipy.linalg.solve_discrete_are(A, B, Q, R)

    def infer_shape(self, fgraph, node, shapes):
        # The solution X is square with the same shape as A.
        return [shapes[0]]

    def grad(self, inputs, output_grads):
        # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
        A, B, Q, R = inputs

        (dX,) = output_grads
        X = self(A, B, Q, R)

        # Feedback gain K = (R + B'XB)^{-1} B'XA of the associated LQR problem
        K_inner = R + at.linalg.matrix_dot(B.T, X, B)
        K_inner_inv = at.linalg.solve(K_inner, at.eye(R.shape[0]))
        K = matrix_dot(K_inner_inv, B.T, X, A)

        A_tilde = A - B.dot(K)

        # Symmetrize the output gradient before the Lyapunov solve
        dX_symm = 0.5 * (dX + dX.T)
        S = solve_discrete_lyapunov(A_tilde, dX_symm)

        A_bar = 2 * matrix_dot(X, A_tilde, S)
        B_bar = -2 * matrix_dot(X, A_tilde, S, K.T)
        Q_bar = S
        R_bar = matrix_dot(K, S, K.T)

        return [A_bar, B_bar, Q_bar, R_bar]


def solve_discrete_are(A, B, Q, R) -> TensorVariable:
    """
    Solve the discrete Algebraic Riccati equation
    :math:`A^TXA - X - (A^TXB)(R + B^TXB)^{-1}(B^TXA) + Q = 0`.

    Parameters
    ----------
    A: ArrayLike
        Square matrix of shape M x M
    B: ArrayLike
        Matrix of shape M x N (scipy's DARE solver takes a rectangular input
        matrix here; it need not be square)
    Q: ArrayLike
        Square matrix of shape M x M
    R: ArrayLike
        Square matrix of shape N x N

    Returns
    -------
    X: at.matrix
        Square matrix of shape M x M, representing the solution to the DARE
    """

    return SolveDiscreteARE()(A, B, Q, R)
# --- pymc_statespace/utils/simulation.py ---
import numpy as np

try:
    from numba import njit
except ImportError:  # pragma: no cover - keep the module importable without numba
    def njit(func):
        return func

try:
    from pymc_statespace.utils.numba_linalg import numba_block_diagonal
except ImportError:  # pragma: no cover - only conditional_simulation needs it
    numba_block_diagonal = None


@njit
def numba_mvn_draws(mu, cov):
    """Draw one multivariate-normal sample with mean ``mu`` and covariance ``cov``.

    A small random diagonal jitter is added before the Cholesky factorization
    to guard against numerically singular covariance matrices; draws are
    therefore not exactly reproducible even under a fixed seed.
    """
    samples = np.random.randn(*mu.shape)
    k = cov.shape[0]
    jitter = np.eye(k) * np.random.uniform(1e-12, 1e-8)

    L = np.linalg.cholesky(cov + jitter)
    return mu + L @ samples


@njit
def conditional_simulation(mus, covs, n, k, n_simulations=100):
    """Simulate ``n_simulations`` draws from each posterior (mu, cov) pair.

    Each covs[i] is a stack of per-timestep covariance blocks, assembled into
    one block-diagonal matrix; draws are reshaped to (n, k) trajectories.
    """
    n_samples = mus.shape[0]
    simulations = np.empty((n_samples * n_simulations, n, k))

    for i in range(n_samples):
        for j in range(n_simulations):
            sim = numba_mvn_draws(mus[i], numba_block_diagonal(covs[i]))
            simulations[(i * n_simulations + j), :, :] = sim.reshape(n, k)
    return simulations


@njit
def simulate_statespace(T, Z, R, H, Q, n_steps, x0=None):
    """Simulate one trajectory of the linear Gaussian state space model

        x_t = T x_{t-1} + R eta_t,   eta_t ~ N(0, Q)
        y_t = Z x_t + eps_t,         eps_t ~ N(0, H)

    Parameters
    ----------
    T, Z, R, H, Q: np.ndarray
        Transition, design, selection, observation-covariance and
        state-covariance matrices.
    n_steps: int
        Number of time steps to simulate.
    x0: np.ndarray, optional
        Initial state; when given, row 0 of the outputs is (x0, Z @ x0).

    Returns
    -------
    (simulated_states, simulated_obs): tuple of np.ndarray
        Arrays of shape (n_steps, n_states) and (n_steps, n_obs).
    """
    n_obs, n_states = Z.shape
    k_posdef = R.shape[1]
    # An all-zero H disables the measurement-error branch entirely
    k_obs_noise = H.shape[0] * (1 - int(np.all(H == 0)))

    state_noise = np.random.randn(n_steps, k_posdef)
    state_chol = np.linalg.cholesky(Q)
    # BUG FIX: rows of randn(n, k) @ L have covariance L.T @ L, not Q = L @ L.T;
    # multiplying by the transpose gives innovations with covariance Q.
    state_innovations = state_noise @ state_chol.T

    if k_obs_noise != 0:
        obs_noise = np.random.randn(n_steps, k_obs_noise)
        obs_chol = np.linalg.cholesky(H)
        # Same transpose fix as for the state innovations
        obs_innovations = obs_noise @ obs_chol.T

    simulated_states = np.zeros((n_steps, n_states))
    simulated_obs = np.zeros((n_steps, n_obs))

    if x0 is not None:
        simulated_states[0] = x0
        simulated_obs[0] = Z @ x0

    # BUG FIX: the observation equation is y_t = Z x_t (+ noise). The original
    # loop used x_{t-1}, inconsistent with the same-time t == 0 branch above.
    if k_obs_noise != 0:
        for t in range(1, n_steps):
            simulated_states[t] = T @ simulated_states[t - 1] + R @ state_innovations[t]
            simulated_obs[t] = Z @ simulated_states[t] + obs_innovations[t]
    else:
        for t in range(1, n_steps):
            simulated_states[t] = T @ simulated_states[t - 1] + R @ state_innovations[t]
            simulated_obs[t] = Z @ simulated_states[t]

    return simulated_states, simulated_obs


def unconditional_simulations(thetas, update_funcs, n_steps=100, n_simulations=100):
    """Simulate unconditional trajectories for every posterior draw in ``thetas``.

    Parameters
    ----------
    thetas: sequence of np.ndarray
        Posterior draws of each parameter block; leading axis indexes samples.
    update_funcs: sequence of callables
        Compiled functions mapping parameter draws to the state space matrices
        (x0, P0, T, Z, R, H, Q), each returning a one-element sequence.
    n_steps, n_simulations: int
        Trajectory length and number of simulations per posterior draw.

    Returns
    -------
    (states, observed): tuple of np.ndarray
        Shapes (samples * n_simulations, n_steps, n_states) and
        (samples * n_simulations, n_steps, n_obs).
    """
    samples, *_ = thetas[0].shape
    # Evaluate once up front only to learn the system dimensions
    _, _, T, Z, R, H, Q = (f(*[theta[0] for theta in thetas])[0] for f in update_funcs)
    n_obs, n_states = Z.shape

    states = np.empty((samples * n_simulations, n_steps, n_states))
    observed = np.empty((samples * n_simulations, n_steps, n_obs))

    for i in range(samples):
        theta = [x[i] for x in thetas]
        _, _, T, Z, R, H, Q = (f(*theta)[0] for f in update_funcs)
        for j in range(n_simulations):
            sim_state, sim_obs = simulate_statespace(T, Z, R, H, Q, n_steps=n_steps)
            states[i * n_simulations + j, :, :] = sim_state
            observed[i * n_simulations + j, :, :] = sim_obs

    return states, observed
version_pattern = "MAJOR.MINOR.PATCH" 24 | commit_message = "Bump version {old_version} -> {new_version}" 25 | commit = true 26 | tag = true 27 | push = false 28 | 29 | [tool.bumpver.file_patterns] 30 | "pyproject.toml" = ['current_version = "{version}"'] 31 | "setup.cfg" = ['version = {version}'] 32 | "pymc_statespace/__init__.py" = ["{version}"] 33 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = pymc-statespace 3 | version = 0.0.1 4 | description = A system for Bayesian estimation of state space models 5 | long_description = file: README.md 6 | long_description_content_type = text/markdown 7 | url = https://github.com/jessegrabowski/pymc_statespace 8 | author = Jesse Grabowski 9 | author_email = jessegrabowski@gmail.com 10 | 11 | [options] 12 | packages = find: 13 | install_requires = 14 | pymc>=5.2 15 | pytensor 16 | numba>=0.57 17 | numpy 18 | pandas 19 | xarray 20 | matplotlib 21 | arviz 22 | statsmodels 23 | 24 | [options.packages.find] 25 | exclude = tests* 26 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | -------------------------------------------------------------------------------- /svgs/17b59c002f249204f24e31507dc4957d.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /svgs/3b5e41543d7fc8cedf98ec609b343134.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 
-------------------------------------------------------------------------------- /svgs/523b266d36c270dbbb5daf2c9092ce0f.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /svgs/54221efbfb5e69569dfe8ddea785093a.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /svgs/92b8c1194757fb3131cda468a34be85f.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /svgs/a06c0e58d4d162b0e87d32927c9812db.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /svgs/a13d89295e999545a129b2d412e99f6d.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /svgs/cac7e81ebde5e530e639eae5389f149e.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /svgs/d523a14b8179ebe46f0ed16895ee46f0.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 
5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /svgs/edcff444fd5240add1c47d2de50ebd7e.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /svgs/f566e90ed17c5292db4600846e0ace27.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /svgs/ff25a8f22c7430ca572d33206c0a9176.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jessegrabowski/pymc_statespace/0659a3bfd9186f128f238c69d5193d3c461f3948/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_VARMAX.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | from itertools import product 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import pymc as pm 10 | import pytensor.tensor as pt 11 | import pytest 12 | import statsmodels.api as sm 13 | from numpy.testing import assert_allclose 14 | 15 | from pymc_statespace import BayesianVARMAX 16 | 17 | ROOT = Path(__file__).parent.absolute() 18 | sys.path.append(ROOT) 19 | 20 | 21 | @pytest.fixture 22 | def data(): 23 | 
return pd.read_csv( 24 | os.path.join(ROOT, "test_data/statsmodels_macrodata_processed.csv"), index_col=0 25 | ) 26 | 27 | 28 | ps = [0, 1, 2, 3] 29 | qs = [0, 1, 2, 3] 30 | orders = list(product(ps, qs))[1:] 31 | ids = [f"p={x[0]}, q={x[1]}" for x in orders] 32 | 33 | 34 | @pytest.mark.parametrize("order", orders, ids=ids) 35 | @pytest.mark.parametrize("matrix", ["transition", "selection", "state_cov", "obs_cov", "design"]) 36 | def test_VARMAX_init_matches_statsmodels(data, order, matrix): 37 | p, q = order 38 | 39 | mod = BayesianVARMAX(data, order=(p, q), verbose=False) 40 | with warnings.catch_warnings(): 41 | warnings.simplefilter("ignore") 42 | sm_var = sm.tsa.VARMAX(data, order=(p, q)) 43 | 44 | assert_allclose(mod.ssm[matrix].eval(), sm_var.ssm[matrix]) 45 | 46 | 47 | @pytest.mark.parametrize("order", orders, ids=ids) 48 | @pytest.mark.parametrize("var", ["AR", "MA", "state_cov"]) 49 | def test_VARMAX_param_counts_match_statsmodels(data, order, var): 50 | p, q = order 51 | 52 | mod = BayesianVARMAX(data, order=(p, q), verbose=False) 53 | with warnings.catch_warnings(): 54 | warnings.simplefilter("ignore") 55 | sm_var = sm.tsa.VARMAX(data, order=(p, q)) 56 | 57 | count = mod.param_counts[var] 58 | if var == "state_cov": 59 | # Statsmodels only counts the lower triangle 60 | count = mod.k_posdef * (mod.k_posdef - 1) 61 | assert count == sm_var.parameters[var.lower()] 62 | 63 | 64 | @pytest.mark.parametrize("order", orders, ids=ids) 65 | @pytest.mark.parametrize("matrix", ["transition", "selection", "state_cov", "obs_cov", "design"]) 66 | def test_VARMAX_update_matches_statsmodels(data, order, matrix): 67 | p, q = order 68 | 69 | with warnings.catch_warnings(): 70 | warnings.simplefilter("ignore") 71 | sm_var = sm.tsa.VARMAX(data, order=(p, q)) 72 | 73 | param_counts = [None] + np.cumsum(list(sm_var.parameters.values())).tolist() 74 | param_slices = [slice(a, b) for a, b in zip(param_counts[:-1], param_counts[1:])] 75 | param_lists = [trend, ar, ma, reg, 
state_cov, obs_cov] = [ 76 | sm_var.param_names[idx] for idx in param_slices 77 | ] 78 | param_d = { 79 | k: np.random.normal(scale=0.1) ** 2 for param_list in param_lists for k in param_list 80 | } 81 | 82 | res = sm_var.fit_constrained(param_d) 83 | 84 | mod = BayesianVARMAX(data, order=(p, q), verbose=False, measurement_error=False) 85 | 86 | with pm.Model() as pm_mod: 87 | x0 = pm.Deterministic("x0", pt.zeros(mod.k_states)) 88 | ma_params = pm.Deterministic( 89 | "ma_params", pt.as_tensor_variable(np.array([param_d[var] for var in ma])) 90 | ) 91 | ar_params = pm.Deterministic( 92 | "ar_params", pt.as_tensor_variable(np.array([param_d[var] for var in ar])) 93 | ) 94 | state_chol = np.zeros((mod.k_posdef, mod.k_posdef)) 95 | state_chol[np.tril_indices(mod.k_posdef)] = np.array([param_d[var] for var in state_cov]) 96 | state_cov = pm.Deterministic("state_cov", pt.as_tensor_variable(state_chol @ state_chol.T)) 97 | mod.build_statespace_graph() 98 | 99 | assert_allclose(mod.ssm[matrix].eval(), sm_var.ssm[matrix]) 100 | -------------------------------------------------------------------------------- /tests/test_data/nile.csv: -------------------------------------------------------------------------------- 1 | "x" 2 | 1120 3 | 1160 4 | 963 5 | 1210 6 | 1160 7 | 1160 8 | 813 9 | 1230 10 | 1370 11 | 1140 12 | 995 13 | 935 14 | 1110 15 | 994 16 | 1020 17 | 960 18 | 1180 19 | 799 20 | 958 21 | 1140 22 | 1100 23 | 1210 24 | 1150 25 | 1250 26 | 1260 27 | 1220 28 | 1030 29 | 1100 30 | 774 31 | 840 32 | 874 33 | 694 34 | 940 35 | 833 36 | 701 37 | 916 38 | 692 39 | 1020 40 | 1050 41 | 969 42 | 831 43 | 726 44 | 456 45 | 824 46 | 702 47 | 1120 48 | 1100 49 | 832 50 | 764 51 | 821 52 | 768 53 | 845 54 | 864 55 | 862 56 | 698 57 | 845 58 | 744 59 | 796 60 | 1040 61 | 759 62 | 781 63 | 865 64 | 845 65 | 944 66 | 984 67 | 897 68 | 822 69 | 1010 70 | 771 71 | 676 72 | 649 73 | 846 74 | 812 75 | 742 76 | 801 77 | 1040 78 | 860 79 | 874 80 | 848 81 | 890 82 | 744 83 | 749 84 | 838 
85 | 1050 86 | 918 87 | 986 88 | 797 89 | 923 90 | 975 91 | 815 92 | 1020 93 | 906 94 | 901 95 | 1170 96 | 912 97 | 746 98 | 919 99 | 718 100 | 714 101 | 740 102 | -------------------------------------------------------------------------------- /tests/test_data/statsmodels_macrodata_processed.csv: -------------------------------------------------------------------------------- 1 | ,realgdp,realcons,realinv 2 | 1959-06-30,0.024942130816387298,0.015286107415635186,0.08021268127441772 3 | 1959-09-30,-0.0011929521106681662,0.010385977737146668,-0.0721310437426288 4 | 1959-12-31,0.003494532654372051,0.0010840109462586511,0.03442511117316993 5 | 1960-03-31,0.022190179514293362,0.0095341508767115,0.10266376814814393 6 | 1960-06-30,-0.004684553282733539,0.012572428049423046,-0.10669384516657932 7 | 1960-09-30,0.0016328801889198274,-0.003967926518265941,-0.0059778791939155695 8 | 1960-12-31,-0.01290635994607836,0.001343033218102363,-0.1318520178798508 9 | 1961-03-31,0.005922591255446363,-0.00027964988017536996,0.025244180753152712 10 | 1961-06-30,0.01853453415389339,0.01476984095507472,0.07183387371281658 11 | 1961-09-30,0.016031639161811384,0.004838630433301461,0.08045270660189097 12 | 1961-12-31,0.020152816046898003,0.019823062007010783,0.016737113362548683 13 | 1962-03-31,0.017777259286836156,0.01059116613253508,0.05791064029750803 14 | 1962-06-30,0.010980515349976017,0.01221623378687653,-0.009715848023952312 15 | 1962-09-30,0.009204067213396172,0.008062026704950043,0.017733971140915017 16 | 1962-12-31,0.002427018714243445,0.014082552171855056,-0.03414697935818811 17 | 1963-03-31,0.01298521045351464,0.006712294307324562,0.05400709679319515 18 | 1963-06-30,0.012452834593128514,0.009504277281920714,0.014467702035278585 19 | 1963-09-30,0.01865404073909538,0.013515416623270937,0.032089340831295665 20 | 1963-12-31,0.0075738617893286175,0.008349119168469699,0.012232500910016597 21 | 1964-03-31,0.022195863868844867,0.01955417478523991,0.04029537549762097 22 | 
1964-06-30,0.011419916678280018,0.01741600837232049,-0.004608479556662992 23 | 1964-09-30,0.01349678440620039,0.018195638939121572,0.02348211049511484 24 | 1964-12-31,0.0027684319766052568,0.0028061004289119396,0.008127111263337206 25 | 1965-03-31,0.02426471280607778,0.02198702890900517,0.09587891389684966 26 | 1965-06-30,0.013476920985249663,0.010995612622818562,-6.05874183419175e-05 27 | 1965-09-30,0.02009026988071838,0.01702550085568255,0.035089793025194105 28 | 1965-12-31,0.0238395627245076,0.027732704078884396,0.004599659936751266 29 | 1966-03-31,0.024249417896358594,0.014669560481433308,0.0811651892140155 30 | 1966-06-30,0.003323329257922225,0.002551564216559221,-0.018415529197198133 31 | 1966-09-30,0.00655531757692529,0.011402150557305646,-0.009958807818388316 32 | 1966-12-31,0.008069240525410137,0.0041484272445648784,0.004789900368053601 33 | 1967-03-31,0.008770749498840047,0.005795667399719484,-0.027762768509880686 34 | 1967-06-30,0.00020820851954361785,0.013544411542069312,-0.043574196677693244 35 | 1967-09-30,0.007946288894562059,0.00511384508807744,0.028297460430921184 36 | 1967-12-31,0.0076008372189466655,0.006142850307218062,0.021403487969130275 37 | 1968-03-31,0.020399308411914063,0.02360689149619688,0.02153029334303813 38 | 1968-06-30,0.016836250623496696,0.015212856527097252,0.03963280880042941 39 | 1968-09-30,0.006818187675689202,0.0186292780304953,-0.033002368053595355 40 | 1968-12-31,0.004323535018018632,0.004584535250413246,0.010333919522066637 41 | 1969-03-31,0.015626993226984354,0.011144077123032226,0.06380045990686156 42 | 1969-06-30,0.002908045754733024,0.006350181971833457,-0.007999752832058782 43 | 1969-09-30,0.00630412172871786,0.0048201868198987086,0.02285706056559711 44 | 1969-12-31,-0.004707590235586423,0.00794538968601799,-0.05536353177486397 45 | 1970-03-31,-0.0015699839630070045,0.006120060394082749,-0.031798102136913364 46 | 1970-06-30,0.0018110848665440216,0.004583883784991194,0.0031276717452559666 47 | 
1970-09-30,0.008864772287074274,0.008706318768779475,0.016943182250337863 48 | 1970-12-31,-0.010660821696957257,-0.002723956139778494,-0.05967484549706992 49 | 1971-03-31,0.027202168334472532,0.018949376401506512,0.12209449744029222 50 | 1971-06-30,0.0056567889131589055,0.00912956377520402,0.030519869627655183 51 | 1971-09-30,0.007950886761737053,0.007924948579291602,0.013068141575951486 52 | 1971-12-31,0.002774937094654817,0.016492493094995453,-0.03178387805772953 53 | 1972-03-31,0.01772331446789366,0.013266567575952237,0.06832901837757266 54 | 1972-06-30,0.023438898862099933,0.018924178786744683,0.059410089170633285 55 | 1972-09-30,0.009538014207757683,0.015320125879748403,0.01413153537262346 56 | 1972-12-31,0.016336792461098426,0.023192179368109578,0.0051755829879578386 57 | 1973-03-31,0.02525804234288387,0.018129887541462608,0.061563645765462915 58 | 1973-06-30,0.011501097796264403,-0.0005053376394563713,0.04568816612695592 59 | 1973-09-30,-0.005350042942591671,0.0035634922053358054,-0.03988458461383537 60 | 1973-12-31,0.009493238255322112,-0.0029318600468926093,0.037538073304505204 61 | 1974-03-31,-0.008807613756498966,-0.008783808120972125,-0.06593511447881983 62 | 1974-06-30,0.002557212404131093,0.0034656569352140565,-0.004967723257728096 63 | 1974-09-30,-0.009936678481238914,0.004117708773394568,-0.058061447421211554 64 | 1974-12-31,-0.003943318438729193,-0.014743377118552559,0.009461202519150724 65 | 1975-03-31,-0.012237921312634015,0.00833777411551928,-0.19316322686510023 66 | 1975-06-30,0.007613228387052473,0.016532431063058795,-0.035342692176429935 67 | 1975-09-30,0.01670305536159411,0.014167887518730993,0.08128912221366846 68 | 1975-12-31,0.01297845428771538,0.010526251464110459,0.027115371108036967 69 | 1976-03-31,0.02247811058416005,0.01979843719362151,0.09853530316550518 70 | 1976-06-30,0.007492297472733611,0.009116702499950335,0.04176338114011724 71 | 1976-09-30,0.004886706909426053,0.010532158188986784,0.0018540964384721192 72 | 
1976-12-31,0.007235398273953919,0.012916949526614374,0.006927681623511539 73 | 1977-03-31,0.011541159204879747,0.011378325577613424,0.048024592961263046 74 | 1977-06-30,0.019678246737301563,0.005512803347329509,0.07444626765491602 75 | 1977-09-30,0.017726137552946497,0.009497286008931738,0.053063063975045566 76 | 1977-12-31,-0.0002069209563817509,0.014954910730345716,-0.02863469545466213 77 | 1978-03-31,0.003408732702791184,0.0058204155446599515,0.01963890276843383 78 | 1978-06-30,0.038585475403756675,0.02116452181230777,0.0664049810036742 79 | 1978-09-30,0.009756162989988937,0.004185066791748682,0.030400604229731343 80 | 1978-12-31,0.013139436292622264,0.008023980684571441,0.022499917479481546 81 | 1979-03-31,0.0016709945702260143,0.005039002863874487,-0.00016226982579325977 82 | 1979-06-30,0.0009382908509696364,-0.0005852464852740269,-0.002316127923967848 83 | 1979-09-30,0.007162420120430113,0.009771159048007405,-0.019659741924709984 84 | 1979-12-31,0.0027476398395105406,0.002657931274367087,-0.0187911675618766 85 | 1980-03-31,0.0032161514595951957,-0.0017360658804363993,-0.007274586891929502 86 | 1980-06-30,-0.02070793158034334,-0.02295523265035193,-0.09455513343548372 87 | 1980-09-30,-0.001860258111918256,0.010664338361117132,-0.07927831360496729 88 | 1980-12-31,0.01832680700930389,0.013238212922260573,0.0968258047063566 89 | 1981-03-31,0.020566824679381313,0.005455894343649348,0.0947431099496372 90 | 1981-06-30,-0.00801140272402101,0.0,-0.048776329370762816 91 | 1981-09-30,0.012077078189896895,0.004046421832148539,0.06021997477221852 92 | 1981-12-31,-0.012535910197692957,-0.007584115017426285,-0.039552608640689435 93 | 1982-03-31,-0.016547233635337832,0.006437116638242202,-0.11001915580959931 94 | 1982-06-30,0.00540438914742758,0.0035930053234061177,-0.0008870179588731375 95 | 1982-09-30,-0.0038627257685650562,0.00763767645005764,-0.011739399570102726 96 | 1982-12-31,0.0007891034952027809,0.018070548333479763,-0.0932678840105785 97 | 
1983-03-31,0.012360524736578782,0.00975494533673249,0.034986691966604866 98 | 1983-06-30,0.02222733528868126,0.019647192590248608,0.0921358495310205 99 | 1983-09-30,0.019527814105522623,0.017529979698624132,0.06507760763708603 100 | 1983-12-31,0.020459959942552786,0.01573023035373744,0.10011095968525563 101 | 1984-03-31,0.019210165810124025,0.008528432247278062,0.09954287887801616 102 | 1984-06-30,0.01711776350258809,0.01421719441872682,0.033159907559543456 103 | 1984-09-30,0.009671516584397466,0.00766837614262883,0.02297814070131121 104 | 1984-12-31,0.008108095427875384,0.013091863113150026,-0.016542413296196656 105 | 1985-03-31,0.00939240246972517,0.016827618778211928,-0.03352576107703609 106 | 1985-06-30,0.008431222368500357,0.009052660542270274,0.017114336250190654 107 | 1985-09-30,0.015499913364575235,0.018822015632972366,-0.0111110951439235 108 | 1985-12-31,0.007560947026535203,0.0021976357071959995,0.03835114042144738 109 | 1986-03-31,0.009563067941808612,0.00831155625679969,-0.0020569212508529944 110 | 1986-06-30,0.004009178207507347,0.010591865745748663,-0.02244250448137919 111 | 1986-09-30,0.009595188450980174,0.017337686123433116,-0.03185277828355737 112 | 1986-12-31,0.004821793960115173,0.005997865443140071,0.001549575539982584 113 | 1987-03-31,0.005528935739866014,-0.0015221813917598581,0.030095986982594525 114 | 1987-06-30,0.010577916660121645,0.013308102382179499,0.0013989296878635926 115 | 1987-09-30,0.008635542829132703,0.011078573792220325,0.0010078332176899352 116 | 1987-12-31,0.01696579977682866,0.0024001801486601693,0.0750846836507959 117 | 1988-03-31,0.005159352817525331,0.01656194042911352,-0.057993128623962775 118 | 1988-06-30,0.01276151335619602,0.007247274184601693,0.02403489024554961 119 | 1988-09-30,0.005149580544561161,0.007864565206437746,0.006403454265695885 120 | 1988-12-31,0.013264351168576383,0.011701532598701547,0.013156828932998188 121 | 1989-03-31,0.009344884553845745,0.0036700139253920128,0.037605497844809044 122 | 
1989-06-30,0.007454656791988867,0.004467358058686699,-0.011753595376318593 123 | 1989-09-30,0.007899659182339036,0.010308676978862508,-0.011821928251866787 124 | 1989-12-31,0.0021804320484175577,0.004877222036274276,-0.010316302801358646 125 | 1990-03-31,0.010392526978691308,0.007875099856645917,0.009793458713496683 126 | 1990-06-30,0.003966490297244718,0.003294399681363913,0.0002839754625112434 127 | 1990-09-30,-1.5137345979354677e-05,0.003789233369559497,-0.023819672846864037 128 | 1990-12-31,-0.008799970050372252,-0.007800424872460354,-0.06532875105654501 129 | 1991-03-31,-0.004856014563534572,-0.002853392463110893,-0.04156716709528485 130 | 1991-06-30,0.00672662020935455,0.007597287958210686,-0.005040521569013912 131 | 1991-09-30,0.004203639797180969,0.0038051724967349543,0.02459129417067274 132 | 1991-12-31,0.003912442289637497,-0.00044911861228591476,0.03756840115321758 133 | 1992-03-31,0.010916709746682685,0.017055094156383177,-0.02248205323390895 134 | 1992-06-30,0.010569355265042546,0.0059076236833242035,0.06379346786306783 135 | 1992-09-30,0.010265417859880444,0.01098812198103083,0.010284555756322256 136 | 1992-12-31,0.010468628590384554,0.012138581118792402,0.031159828241962728 137 | 1993-03-31,0.0018361379777385167,0.004031472307636008,0.02322640760306438 138 | 1993-06-30,0.006377497246457864,0.009549783338858475,0.00782162022374333 139 | 1993-09-30,0.005250241191419036,0.010803521482550593,-0.0007046835401780527 140 | 1993-12-31,0.013119466990481499,0.00885699235809767,0.051424907521263385 141 | 1994-03-31,0.009688253462533325,0.011073205873081804,0.042240819567175514 142 | 1994-06-30,0.01358569880162186,0.007393812006782241,0.05665002899143978 143 | 1994-09-30,0.006420403956367338,0.00797962003214181,-0.018141383569249214 144 | 1994-12-31,0.01104477749296251,0.009819113744269359,0.04556673100613651 145 | 1995-03-31,0.0024502400068371344,0.0011665112641257025,0.010112657125509017 146 | 
1995-06-30,0.0021473245426548715,0.00816073795386174,-0.02726183773355295 147 | 1995-09-30,0.00836938846235924,0.008897280161562549,-0.009672229723568293 148 | 1995-12-31,0.006947981234826983,0.007015740708053997,0.02776802287631952 149 | 1996-03-31,0.0068267227982588,0.009111938423863819,0.013087865506328455 150 | 1996-06-30,0.01714011546601668,0.011246513926669977,0.050498380009807775 151 | 1996-09-30,0.008660860530211423,0.005956322159514471,0.049194061600450034 152 | 1996-12-31,0.010856683405549461,0.008122624642348697,-0.0027222953368513103 153 | 1997-03-31,0.0076617204401348005,0.010018604777197737,0.023078936598619038 154 | 1997-06-30,0.0147176177580981,0.004035267457265235,0.06189844480028306 155 | 1997-09-30,0.01247302253870508,0.016863909657265808,0.01762257993775762 156 | 1997-12-31,0.00764257404176405,0.011372793339003096,0.01584465220509479 157 | 1998-03-31,0.009402375786226713,0.009903963846072728,0.04663007831635557 158 | 1998-06-30,0.008952009121774296,0.017059054612118985,-0.012038222653899311 159 | 1998-09-30,0.013108366865449028,0.013177531481691318,0.028250019650299052 160 | 1998-12-31,0.01716157436950816,0.015216842332248959,0.03165239255225849 161 | 1999-03-31,0.008868795717511091,0.009810532772998926,0.03100057041875015 162 | 1999-06-30,0.007786660186548389,0.01562159957313547,-0.003497282090812348 163 | 1999-09-30,0.01263644360829197,0.011942459850613929,0.024791673771441758 164 | 1999-12-31,0.01780192698963745,0.014009877815061245,0.034874117857821574 165 | 2000-03-31,0.002610475349184682,0.015056864152766636,-0.014060287861793697 166 | 2000-06-30,0.0193185856258129,0.009354481042235463,0.06693782876834042 167 | 2000-09-30,0.0008357335001925037,0.009738736557368455,-0.015765208923020246 168 | 2000-12-31,0.005900007164653331,0.008802484015141943,0.00044731799897590463 169 | 2001-03-31,-0.003302713380032074,0.003985048423478688,-0.05434900101941054 170 | 2001-06-30,0.0065359877030051194,0.003763424514321656,-0.0032138639191021667 171 | 
2001-09-30,-0.0027454160608613165,0.004389912833836718,-0.021314183957510835 172 | 2001-12-31,0.0035257644037436364,0.015542609405262198,-0.05936563281615381 173 | 2002-03-31,0.008551982920579349,0.003436710689667777,0.03303111582599705 174 | 2002-06-30,0.005292010252050616,0.0050767575239785145,0.011917567926544415 175 | 2002-09-30,0.004984622513504178,0.006754003917142981,0.0020698926930968753 176 | 2002-12-31,0.0002064215385182422,0.003545619424254909,-0.000723313380029289 177 | 2003-03-31,0.004043517814475095,0.005147238095256412,-4.301834346964739e-05 178 | 2003-06-30,0.00794435538282201,0.009252460293302178,0.005805607370590771 179 | 2003-09-30,0.016622298075029462,0.013846458933041816,0.035648614625020336 180 | 2003-12-31,0.008954497678535844,0.005506879502322093,0.036318830433965665 181 | 2004-03-31,0.007017360710909415,0.009478743954543845,0.005207404106674751 182 | 2004-06-30,0.00708219043205105,0.005389829132301571,0.042516889635136224 183 | 2004-09-30,0.007318523149864475,0.008521678425312373,0.01288186847757089 184 | 2004-12-31,0.008638865661898976,0.01143533669801755,0.020403414183223667 185 | 2005-03-31,0.009928644671413522,0.007459800209780099,0.0210216168488655 186 | 2005-06-30,0.004253071337316783,0.009576660204851706,-0.018054001870183356 187 | 2005-09-30,0.007567538827411013,0.007097406480417234,0.010956113173677728 188 | 2005-12-31,0.00515464977792135,0.0025796872496552936,0.03521745293244205 189 | 2006-03-31,0.013032825454358132,0.010976272570252021,0.01446706222507732 190 | 2006-06-30,0.0035955893813284234,0.005371345093337254,-0.001535141513263838 191 | 2006-09-30,0.0002664262315530408,0.006145988881025133,-0.014078087577432008 192 | 2006-12-31,0.007282045058815356,0.009949568459104441,-0.02897189213120477 193 | 2007-03-31,0.0029985596181916208,0.00905317160290231,-0.015520338524286359 194 | 2007-06-30,0.007913399166394441,0.0028453507443444437,0.01378658394853094 195 | 
2007-09-30,0.008831847811689997,0.004735045433360341,0.001976111281232207 196 | 2007-12-31,0.0052515140142759265,0.0029947827636505053,-0.020077985987605906 197 | 2008-03-31,-0.0018225504794315839,-0.0014962702917689086,-0.01927639001132686 198 | 2008-06-30,0.0036144287915647055,0.00014972781649902345,-0.02743538266838197 199 | 2008-09-30,-0.0067813613915230775,-0.00894805285032696,-0.01783623003564294 200 | 2008-12-31,-0.01380482973598518,-0.007842752651521678,-0.06916464982757464 201 | 2009-03-31,-0.016611979742227945,0.0015105004366180452,-0.17559819913074826 202 | 2009-06-30,-0.001851247642470355,-0.00219586786933057,-0.06756146963272425 203 | 2009-09-30,0.006862187581308632,0.007264873372626823,0.02019724281421187 204 | -------------------------------------------------------------------------------- /tests/test_kalman_filter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import pytensor 5 | import pytensor.tensor as pt 6 | import pytest 7 | from numpy.testing import assert_allclose 8 | 9 | from pymc_statespace.filters import ( 10 | CholeskyFilter, 11 | KalmanSmoother, 12 | SingleTimeseriesFilter, 13 | StandardFilter, 14 | SteadyStateFilter, 15 | UnivariateFilter, 16 | ) 17 | from pymc_statespace.filters.kalman_filter import BaseFilter 18 | from tests.utilities.test_helpers import ( 19 | get_expected_shape, 20 | get_sm_state_from_output_name, 21 | initialize_filter, 22 | make_test_inputs, 23 | nile_test_test_helper, 24 | ) 25 | 26 | standard_inout = initialize_filter(StandardFilter()) 27 | cholesky_inout = initialize_filter(CholeskyFilter()) 28 | univariate_inout = initialize_filter(UnivariateFilter()) 29 | single_inout = initialize_filter(SingleTimeseriesFilter()) 30 | steadystate_inout = initialize_filter(SteadyStateFilter()) 31 | 32 | f_standard = pytensor.function(*standard_inout) 33 | f_cholesky = pytensor.function(*cholesky_inout) 34 | f_univariate = 
pytensor.function(*univariate_inout) 35 | f_single_ts = pytensor.function(*single_inout) 36 | f_steady = pytensor.function(*steadystate_inout) 37 | 38 | filter_funcs = [f_standard, f_cholesky, f_univariate, f_single_ts, f_steady] 39 | 40 | filter_names = [ 41 | "StandardFilter", 42 | "CholeskyFilter", 43 | "UnivariateFilter", 44 | "SingleTimeSeriesFilter", 45 | "SteadyStateFilter", 46 | ] 47 | output_names = [ 48 | "filtered_states", 49 | "predicted_states", 50 | "smoothed_states", 51 | "filtered_covs", 52 | "predicted_covs", 53 | "smoothed_covs", 54 | "log_likelihood", 55 | "ll_obs", 56 | ] 57 | 58 | 59 | def test_base_class_update_raises(): 60 | filter = BaseFilter() 61 | inputs = [None] * 8 62 | with pytest.raises(NotImplementedError): 63 | filter.update(*inputs) 64 | 65 | 66 | @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) 67 | @pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names) 68 | def test_output_shapes_one_state_one_observed(filter_func, output_idx, name): 69 | p, m, r, n = 1, 1, 1, 10 70 | inputs = make_test_inputs(p, m, r, n) 71 | 72 | outputs = filter_func(*inputs) 73 | expected_output = get_expected_shape(name, p, m, r, n) 74 | 75 | assert outputs[output_idx].shape == expected_output 76 | 77 | 78 | @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) 79 | @pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names) 80 | def test_output_shapes_when_all_states_are_stochastic(filter_func, output_idx, name): 81 | p, m, r, n = 1, 2, 2, 10 82 | inputs = make_test_inputs(p, m, r, n) 83 | 84 | outputs = filter_func(*inputs) 85 | expected_output = get_expected_shape(name, p, m, r, n) 86 | 87 | assert outputs[output_idx].shape == expected_output 88 | 89 | 90 | @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) 91 | @pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names) 92 | 
def test_output_shapes_when_some_states_are_deterministic(filter_func, output_idx, name): 93 | p, m, r, n = 1, 5, 2, 10 94 | inputs = make_test_inputs(p, m, r, n) 95 | 96 | outputs = filter_func(*inputs) 97 | expected_output = get_expected_shape(name, p, m, r, n) 98 | 99 | assert outputs[output_idx].shape == expected_output 100 | 101 | 102 | @pytest.fixture 103 | def f_standard_nd(): 104 | ksmoother = KalmanSmoother() 105 | data = pt.dtensor3(name="data") 106 | a0 = pt.matrix(name="a0") 107 | P0 = pt.matrix(name="P0") 108 | Q = pt.dtensor3(name="Q") 109 | H = pt.dtensor3(name="H") 110 | T = pt.dtensor3(name="T") 111 | R = pt.dtensor3(name="R") 112 | Z = pt.dtensor3(name="Z") 113 | 114 | inputs = [data, a0, P0, T, Z, R, H, Q] 115 | 116 | ( 117 | filtered_states, 118 | predicted_states, 119 | filtered_covs, 120 | predicted_covs, 121 | log_likelihood, 122 | ll_obs, 123 | ) = StandardFilter().build_graph(*inputs) 124 | 125 | smoothed_states, smoothed_covs = ksmoother.build_graph(T, R, Q, filtered_states, filtered_covs) 126 | 127 | outputs = [ 128 | filtered_states, 129 | predicted_states, 130 | smoothed_states, 131 | filtered_covs, 132 | predicted_covs, 133 | smoothed_covs, 134 | log_likelihood, 135 | ll_obs, 136 | ] 137 | 138 | f_standard = pytensor.function(inputs, outputs) 139 | 140 | return f_standard 141 | 142 | 143 | @pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names) 144 | def test_output_shapes_with_time_varying_matrices(f_standard_nd, output_idx, name): 145 | p, m, r, n = 1, 5, 2, 10 146 | data, a0, P0, T, Z, R, H, Q = make_test_inputs(p, m, r, n) 147 | T = np.concatenate([np.expand_dims(T, 0)] * n, axis=0) 148 | Z = np.concatenate([np.expand_dims(Z, 0)] * n, axis=0) 149 | R = np.concatenate([np.expand_dims(R, 0)] * n, axis=0) 150 | H = np.concatenate([np.expand_dims(H, 0)] * n, axis=0) 151 | Q = np.concatenate([np.expand_dims(Q, 0)] * n, axis=0) 152 | 153 | outputs = f_standard_nd(data, a0, P0, T, Z, R, H, Q) 154 
| expected_output = get_expected_shape(name, p, m, r, n) 155 | 156 | assert outputs[output_idx].shape == expected_output 157 | 158 | 159 | @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) 160 | @pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names) 161 | def test_output_with_deterministic_observation_equation(filter_func, output_idx, name): 162 | p, m, r, n = 1, 5, 1, 10 163 | inputs = make_test_inputs(p, m, r, n) 164 | 165 | outputs = filter_func(*inputs) 166 | expected_output = get_expected_shape(name, p, m, r, n) 167 | 168 | assert outputs[output_idx].shape == expected_output 169 | 170 | 171 | @pytest.mark.parametrize( 172 | ("filter_func", "filter_name"), zip(filter_funcs, filter_names), ids=filter_names 173 | ) 174 | @pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names) 175 | def test_output_with_multiple_observed(filter_func, filter_name, output_idx, name): 176 | p, m, r, n = 5, 5, 1, 10 177 | inputs = make_test_inputs(p, m, r, n) 178 | expected_output = get_expected_shape(name, p, m, r, n) 179 | 180 | if filter_name == "SingleTimeSeriesFilter": 181 | with pytest.raises( 182 | AssertionError, 183 | match="UnivariateTimeSeries filter requires data be at most 1-dimensional", 184 | ): 185 | filter_func(*inputs) 186 | 187 | else: 188 | outputs = filter_func(*inputs) 189 | assert outputs[output_idx].shape == expected_output 190 | 191 | 192 | @pytest.mark.parametrize( 193 | ("filter_func", "filter_name"), zip(filter_funcs, filter_names), ids=filter_names 194 | ) 195 | @pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names) 196 | @pytest.mark.parametrize("p", [1, 5], ids=["univariate (p=1)", "multivariate (p=5)"]) 197 | def test_missing_data(filter_func, filter_name, output_idx, name, p): 198 | m, r, n = 5, 1, 10 199 | inputs = make_test_inputs(p, m, r, n, missing_data=1) 200 | if p > 1 and filter_name == 
"SingleTimeSeriesFilter": 201 | with pytest.raises( 202 | AssertionError, 203 | match="UnivariateTimeSeries filter requires data be at most 1-dimensional", 204 | ): 205 | filter_func(*inputs) 206 | 207 | else: 208 | outputs = filter_func(*inputs) 209 | assert not np.any(np.isnan(outputs[output_idx])) 210 | 211 | 212 | @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) 213 | @pytest.mark.parametrize("output_idx", [(0, 2), (3, 5)], ids=["smoothed_states", "smoothed_covs"]) 214 | def test_last_smoother_is_last_filtered(filter_func, output_idx): 215 | p, m, r, n = 1, 5, 1, 10 216 | inputs = make_test_inputs(p, m, r, n) 217 | outputs = filter_func(*inputs) 218 | 219 | filtered = outputs[output_idx[0]] 220 | smoothed = outputs[output_idx[1]] 221 | 222 | assert_allclose(filtered[-1], smoothed[-1]) 223 | 224 | 225 | # TODO: These tests omit the SteadyStateFilter, because it gives different results to StatsModels (reason to dump it?) 226 | @pytest.mark.parametrize("filter_func", filter_funcs[:-1], ids=filter_names[:-1]) 227 | @pytest.mark.parametrize(("output_idx", "name"), list(enumerate(output_names)), ids=output_names) 228 | @pytest.mark.parametrize("n_missing", [0, 5], ids=["n_missing=0", "n_missing=5"]) 229 | def test_filters_match_statsmodel_output(filter_func, output_idx, name, n_missing): 230 | fit_sm_mod, inputs = nile_test_test_helper(n_missing) 231 | outputs = filter_func(*inputs) 232 | 233 | val_to_test = outputs[output_idx].squeeze() 234 | ref_val = get_sm_state_from_output_name(fit_sm_mod, name) 235 | 236 | if name == "smoothed_covs": 237 | # TODO: The smoothed covariance matrices have large errors (1e-2) ONLY in the first few states -- no idea why. 
238 | assert_allclose(val_to_test[5:], ref_val[5:]) 239 | else: 240 | # Need atol = 1e-7 for smoother tests to pass 241 | assert_allclose(val_to_test, ref_val, atol=1e-7) 242 | 243 | 244 | if __name__ == "__main__": 245 | unittest.main() 246 | -------------------------------------------------------------------------------- /tests/test_local_level.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from pathlib import Path 4 | 5 | import pandas as pd 6 | 7 | from pymc_statespace import BayesianLocalLevel 8 | 9 | ROOT = Path(__file__).parent.absolute() 10 | nile = pd.read_csv(os.path.join(ROOT, "test_data/nile.csv")) 11 | nile.index = pd.date_range(start="1871-01-01", end="1970-01-01", freq="AS-Jan") 12 | nile.rename(columns={"x": "height"}, inplace=True) 13 | nile = (nile - nile.mean()) / nile.std() 14 | 15 | 16 | def test_local_level_model(): 17 | mod = BayesianLocalLevel(data=nile.values) 18 | 19 | 20 | if __name__ == "__main__": 21 | unittest.main() 22 | -------------------------------------------------------------------------------- /tests/test_numba_linalg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from numpy.testing import assert_allclose 4 | 5 | from pymc_statespace.utils.numba_linalg import numba_block_diagonal 6 | 7 | 8 | def test_numba_block_diagonal(): 9 | stack = np.concatenate([np.eye(3)[None]] * 5, axis=0) 10 | block_stack = numba_block_diagonal(stack) 11 | assert_allclose(block_stack, np.eye(15)) 12 | -------------------------------------------------------------------------------- /tests/test_pytensor_scipy.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import unittest 3 | 4 | import numpy as np 5 | from numpy.testing import assert_allclose 6 | from pytensor.configdefaults import config 7 | from pytensor.gradient import verify_grad as 
orig_verify_grad 8 | 9 | from pymc_statespace.utils.pytensor_scipy import SolveDiscreteARE, solve_discrete_are 10 | 11 | solve_discrete_are_enforce = SolveDiscreteARE(enforce_Q_symmetric=True) 12 | 13 | 14 | def fetch_seed(pseed=None): 15 | """ 16 | Copied from pytensor.test.unittest_tools 17 | """ 18 | 19 | seed = pseed or config.unittests__rseed 20 | if seed == "random": 21 | seed = None 22 | 23 | try: 24 | if seed: 25 | seed = int(seed) 26 | else: 27 | seed = None 28 | except ValueError: 29 | print( 30 | ("Error: config.unittests__rseed contains " "invalid seed, using None instead"), 31 | file=sys.stderr, 32 | ) 33 | seed = None 34 | 35 | return seed 36 | 37 | 38 | def verify_grad(op, pt, n_tests=2, rng=None, *args, **kwargs): 39 | """ 40 | Copied from pytensor.test.unittest_tools 41 | """ 42 | if rng is None: 43 | rng = np.random.default_rng(fetch_seed()) 44 | 45 | # TODO: Needed to increase tolerance for certain tests when migrating to 46 | # Generators from RandomStates. Caused flaky test failures. 
Needs further investigation 47 | if "rel_tol" not in kwargs: 48 | kwargs["rel_tol"] = 0.05 49 | if "abs_tol" not in kwargs: 50 | kwargs["abs_tol"] = 0.05 51 | orig_verify_grad(op, pt, n_tests, rng, *args, **kwargs) 52 | 53 | 54 | class TestSolveDiscreteARE(unittest.TestCase): 55 | def test_forward(self): 56 | # TEST CASE 4 : darex #1 -- taken from Scipy tests 57 | a, b, q, r = ( 58 | np.array([[4, 3], [-4.5, -3.5]]), 59 | np.array([[1], [-1]]), 60 | np.array([[9, 6], [6, 4]]), 61 | np.array([[1]]), 62 | ) 63 | a, b, q, r = (x.astype("float64") for x in [a, b, q, r]) 64 | 65 | x = solve_discrete_are(a, b, q, r).eval() 66 | res = a.T.dot(x.dot(a)) - x + q 67 | res -= ( 68 | a.conj() 69 | .T.dot(x.dot(b)) 70 | .dot(np.linalg.solve(r + b.conj().T.dot(x.dot(b)), b.T).dot(x.dot(a))) 71 | ) 72 | 73 | assert_allclose(res, np.zeros_like(res), atol=1e-12) 74 | 75 | def test_backward(self): 76 | 77 | a, b, q, r = ( 78 | np.array([[4, 3], [-4.5, -3.5]]), 79 | np.array([[1], [-1]]), 80 | np.array([[9, 6], [6, 4]]), 81 | np.array([[1]]), 82 | ) 83 | a, b, q, r = (x.astype("float64") for x in [a, b, q, r]) 84 | 85 | rng = np.random.default_rng(fetch_seed()) 86 | verify_grad(solve_discrete_are_enforce, pt=[a, b, q, r], rng=rng) 87 | 88 | 89 | if __name__ == "__main__": 90 | unittest.main() 91 | -------------------------------------------------------------------------------- /tests/test_representation.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import pytensor 5 | import pytensor.tensor as pt 6 | from numpy.testing import assert_allclose 7 | 8 | from pymc_statespace.core.representation import PytensorRepresentation 9 | from tests.utilities.test_helpers import make_test_inputs 10 | 11 | 12 | class BasicFunctionality(unittest.TestCase): 13 | def setUp(self): 14 | self.data = np.arange(10)[:, None] 15 | 16 | def test_numpy_to_pytensor(self): 17 | ssm = PytensorRepresentation(data=np.zeros((4, 
1)), k_states=5, k_posdef=1) 18 | X = np.eye(5) 19 | X_pt = ssm._numpy_to_pytensor("transition", X) 20 | self.assertTrue(isinstance(X_pt, pt.TensorVariable)) 21 | 22 | def test_default_shapes_full_rank(self): 23 | ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=5) 24 | p = ssm.data.shape[1] 25 | m = ssm.k_states 26 | r = ssm.k_posdef 27 | 28 | self.assertTrue(ssm.data.shape == (10, 1, 1)) 29 | self.assertTrue(ssm["design"].eval().shape == (p, m)) 30 | self.assertTrue(ssm["transition"].eval().shape == (m, m)) 31 | self.assertTrue(ssm["selection"].eval().shape == (m, r)) 32 | self.assertTrue(ssm["state_cov"].eval().shape == (r, r)) 33 | self.assertTrue(ssm["obs_cov"].eval().shape == (p, p)) 34 | 35 | def test_default_shapes_low_rank(self): 36 | ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=2) 37 | p = ssm.data.shape[1] 38 | m = ssm.k_states 39 | r = ssm.k_posdef 40 | 41 | self.assertTrue(ssm.data.shape == (10, 1, 1)) 42 | self.assertTrue(ssm["design"].eval().shape == (p, m)) 43 | self.assertTrue(ssm["transition"].eval().shape == (m, m)) 44 | self.assertTrue(ssm["selection"].eval().shape == (m, r)) 45 | self.assertTrue(ssm["state_cov"].eval().shape == (r, r)) 46 | self.assertTrue(ssm["obs_cov"].eval().shape == (p, p)) 47 | 48 | def test_matrix_assignment(self): 49 | ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=2) 50 | 51 | ssm["design", 0, 0] = 3.0 52 | ssm["transition", 0, :] = 2.7 53 | ssm["selection", -1, -1] = 9.9 54 | 55 | self.assertTrue(ssm["design"].eval()[0, 0] == 3.0) 56 | self.assertTrue(np.all(ssm["transition"].eval()[0, :] == 2.7)) 57 | self.assertTrue(ssm["selection"].eval()[-1, -1] == 9.9) 58 | 59 | def test_build_representation_from_data(self): 60 | p, m, r, n = 3, 6, 1, 10 61 | inputs = [data, a0, P0, T, Z, R, H, Q] = make_test_inputs(p, m, r, n, missing_data=0) 62 | ssm = PytensorRepresentation( 63 | data=data, 64 | k_states=m, 65 | k_posdef=r, 66 | design=Z, 67 | transition=T, 68 | 
selection=R, 69 | state_cov=Q, 70 | obs_cov=H, 71 | initial_state=a0, 72 | initial_state_cov=P0, 73 | ) 74 | names = [ 75 | "initial_state", 76 | "initial_state_cov", 77 | "transition", 78 | "design", 79 | "selection", 80 | "obs_cov", 81 | "state_cov", 82 | ] 83 | for name, X in zip(names, inputs[1:]): 84 | self.assertTrue(np.allclose(X, ssm[name].eval())) 85 | 86 | def test_assign_time_varying_matrices(self): 87 | ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=2) 88 | n = self.data.shape[0] 89 | 90 | ssm["design", 0, 0] = 3.0 91 | ssm["transition", 0, :] = 2.7 92 | ssm["selection", -1, -1] = 9.9 93 | 94 | ssm["state_intercept"] = np.zeros((5, 1, self.data.shape[0])) 95 | ssm["state_intercept", 0, 0, :] = np.arange(n) 96 | 97 | self.assertTrue(ssm["design"].eval()[0, 0] == 3.0) 98 | self.assertTrue(np.all(ssm["transition"].eval()[0, :] == 2.7)) 99 | self.assertTrue(ssm["selection"].eval()[-1, -1] == 9.9) 100 | self.assertTrue(np.allclose(ssm["state_intercept"][0, 0, :].eval(), np.arange(n))) 101 | 102 | def test_invalid_key_name_raises(self): 103 | ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=1) 104 | with self.assertRaises(IndexError) as e: 105 | X = ssm["invalid_key"] 106 | msg = str(e.exception) 107 | self.assertEqual(msg, "invalid_key is an invalid state space matrix name") 108 | 109 | def test_non_string_key_raises(self): 110 | ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=1) 111 | with self.assertRaises(IndexError) as e: 112 | X = ssm[0] 113 | msg = str(e.exception) 114 | self.assertEqual(msg, "First index must the name of a valid state space matrix.") 115 | 116 | def test_invalid_key_tuple_raises(self): 117 | ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=1) 118 | with self.assertRaises(IndexError) as e: 119 | X = ssm[0, 1, 1] 120 | msg = str(e.exception) 121 | self.assertEqual(msg, "First index must the name of a valid state space matrix.") 122 | 123 | def 
test_slice_statespace_matrix(self): 124 | T = np.eye(5) 125 | ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=1, transition=T) 126 | T_out = ssm["transition", :3, :] 127 | assert_allclose(T[:3], T_out.eval()) 128 | 129 | def test_update_matrix_via_key(self): 130 | T = np.eye(5) 131 | ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=1) 132 | ssm["transition"] = T 133 | 134 | assert_allclose(T, ssm["transition"].eval()) 135 | 136 | def test_update_matrix_with_invalid_shape_raises(self): 137 | T = np.eye(10) 138 | ssm = PytensorRepresentation(data=self.data, k_states=5, k_posdef=1) 139 | with self.assertRaises(ValueError) as e: 140 | ssm["transition"] = T 141 | msg = str(e.exception) 142 | self.assertEqual(msg, "Array provided for transition has shape (10, 10), expected (5, 5)") 143 | 144 | 145 | if __name__ == "__main__": 146 | unittest.main() 147 | -------------------------------------------------------------------------------- /tests/test_simulations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from numpy.testing import assert_allclose 4 | 5 | from pymc_statespace.utils.simulation import ( 6 | conditional_simulation, 7 | numba_mvn_draws, 8 | simulate_statespace, 9 | unconditional_simulations, 10 | ) 11 | from tests.utilities.test_helpers import make_test_inputs 12 | 13 | 14 | def test_numba_mvn_draws(): 15 | cov = np.random.normal(size=(2, 2)) 16 | cov = cov @ cov.T 17 | draws = numba_mvn_draws(mu=np.zeros((2, 1_000_000)), cov=cov) 18 | 19 | assert_allclose(np.cov(draws), cov, atol=0.01, rtol=0.01) 20 | 21 | 22 | def test_simulate_statespace(): 23 | data, a0, P0, T, Z, R, H, Q = make_test_inputs(3, 5, 1, 100) 24 | simulated_states, simulated_obs = simulate_statespace(T, Z, R, H, Q, n_steps=100) 25 | 26 | assert simulated_states.shape == (100, 5) 27 | assert simulated_obs.shape == (100, 3) 28 | 29 | 30 | def test_simulate_statespace_with_x0(): 31 | 
data, a0, P0, T, Z, R, H, Q = make_test_inputs(3, 5, 1, 100) 32 | simulated_states, simulated_obs = simulate_statespace( 33 | T, Z, R, H, Q, n_steps=100, x0=a0.squeeze() 34 | ) 35 | 36 | assert simulated_states.shape == (100, 5) 37 | assert simulated_obs.shape == (100, 3) 38 | assert np.all(simulated_states[0] == a0) 39 | assert np.all(simulated_obs[0] == Z @ a0) 40 | 41 | 42 | def test_simulate_statespace_no_obs_noise(): 43 | data, a0, P0, T, Z, R, H, Q = make_test_inputs(3, 5, 1, 100) 44 | H = np.zeros_like(H) 45 | simulated_states, simulated_obs = simulate_statespace( 46 | T, Z, R, H, Q, n_steps=100, x0=a0.squeeze() 47 | ) 48 | 49 | assert simulated_states.shape == (100, 5) 50 | assert simulated_obs.shape == (100, 3) 51 | 52 | 53 | def test_conditional_simulation(): 54 | pass 55 | 56 | 57 | def test_unconditional_simulation(): 58 | pass 59 | -------------------------------------------------------------------------------- /tests/test_statespace.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import pymc as pm 8 | import pytest 9 | from numpy.testing import assert_allclose 10 | 11 | from pymc_statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace 12 | from tests.utilities.test_helpers import make_test_inputs 13 | 14 | ROOT = Path(__file__).parent.absolute() 15 | nile = pd.read_csv(os.path.join(ROOT, "test_data/nile.csv")) 16 | nile.index = pd.date_range(start="1871-01-01", end="1970-01-01", freq="AS-Jan") 17 | nile.rename(columns={"x": "height"}, inplace=True) 18 | nile = (nile - nile.mean()) / nile.std() 19 | 20 | 21 | @pytest.fixture() 22 | def ss_mod(): 23 | class StateSpace(PyMCStateSpace): 24 | @property 25 | def param_names(self): 26 | return ["rho", "zeta"] 27 | 28 | def update(self, theta): 29 | self.ssm["transition", 0, :] = theta 30 | 31 | T = np.zeros((2, 2)).astype("float64") 32 | T[1, 0] = 
1.0 33 | Z = np.array([[1.0, 0.0]]) 34 | R = np.array([[1.0], [0.0]]) 35 | H = np.array([[0.1]]) 36 | Q = np.array([[0.8]]) 37 | 38 | ss_mod = StateSpace(data=nile, k_states=2, k_posdef=1, filter_type="standard") 39 | for X, name in zip( 40 | [T, Z, R, H, Q], ["transition", "design", "selection", "obs_cov", "state_cov"] 41 | ): 42 | ss_mod.ssm[name] = X 43 | 44 | return ss_mod 45 | 46 | 47 | @pytest.fixture 48 | def pymc_mod(ss_mod): 49 | with pm.Model() as pymc_mod: 50 | rho = pm.Normal("rho") 51 | zeta = pm.Deterministic("zeta", 1 - rho) 52 | ss_mod.build_statespace_graph() 53 | ss_mod.build_smoother_graph() 54 | 55 | return pymc_mod 56 | 57 | 58 | @pytest.fixture 59 | def idata(pymc_mod): 60 | with pymc_mod: 61 | idata = pm.sample(draws=100, tune=0, chains=1) 62 | 63 | return idata 64 | 65 | 66 | def test_invalid_filter_name_raises(): 67 | msg = "The following are valid filter types: " + ", ".join(list(FILTER_FACTORY.keys())) 68 | with pytest.raises(NotImplementedError, match=msg): 69 | mod = PyMCStateSpace(data=nile.values, k_states=5, k_posdef=1, filter_type="invalid_filter") 70 | 71 | 72 | def test_singleseriesfilter_raises_if_data_is_nd(): 73 | data = np.random.normal(size=(4, 10)) 74 | msg = 'Cannot use filter_type = "single" with multiple observed time series' 75 | with pytest.raises(ValueError, match=msg): 76 | mod = PyMCStateSpace(data=data, k_states=5, k_posdef=1, filter_type="single") 77 | 78 | 79 | def test_unpack_matrices(): 80 | p, m, r, n = 2, 5, 1, 10 81 | data, *inputs = make_test_inputs(p, m, r, n, missing_data=0) 82 | mod = PyMCStateSpace( 83 | data=data[..., 0], k_states=m, k_posdef=r, filter_type="standard", verbose=False 84 | ) 85 | 86 | outputs = mod.unpack_statespace() 87 | for x, y in zip(inputs, outputs): 88 | assert_allclose(np.zeros_like(x), y.eval()) 89 | 90 | 91 | def test_param_names_raises_on_base_class(): 92 | mod = PyMCStateSpace(data=nile, k_states=5, k_posdef=1, filter_type="standard", verbose=False) 93 | with 
pytest.raises(NotImplementedError): 94 | x = mod.param_names 95 | 96 | 97 | def test_update_raises_on_base_class(): 98 | mod = PyMCStateSpace(data=nile, k_states=5, k_posdef=1, filter_type="standard", verbose=False) 99 | theta = np.zeros(4) 100 | with pytest.raises(NotImplementedError): 101 | mod.update(theta) 102 | 103 | 104 | def test_gather_pymc_variables(ss_mod): 105 | with pm.Model() as mod: 106 | rho = pm.Normal("rho") 107 | zeta = pm.Deterministic("zeta", 1 - rho) 108 | theta = ss_mod.gather_required_random_variables() 109 | 110 | assert_allclose(pm.math.stack([rho, zeta]).eval(), theta.eval()) 111 | 112 | 113 | def test_gather_raises_if_variable_missing(ss_mod): 114 | with pm.Model() as mod: 115 | rho = pm.Normal("rho") 116 | msg = "The following required model parameters were not found in the PyMC model: zeta" 117 | with pytest.raises(ValueError, match=msg): 118 | theta = ss_mod.gather_required_random_variables() 119 | 120 | 121 | def test_build_smoother_fails_if_statespace_not_built_first(ss_mod): 122 | msg = ( 123 | "Couldn't find Kalman filtered time series among model deterministics. Have you run" 124 | ".build_statespace_graph() ?" 
125 | ) 126 | 127 | with pm.Model() as mod: 128 | rho = pm.Normal("rho") 129 | zeta = pm.Normal("zeta") 130 | with pytest.raises(ValueError, match=msg): 131 | ss_mod.build_smoother_graph() 132 | 133 | 134 | def test_build_statespace_graph(pymc_mod): 135 | for name in [ 136 | "filtered_states", 137 | "predicted_states", 138 | "predicted_covariances", 139 | "filtered_covariances", 140 | ]: 141 | assert name in [x.name for x in pymc_mod.deterministics] 142 | 143 | 144 | def test_build_smoother_graph(ss_mod, pymc_mod): 145 | names = ["smoothed_states", "smoothed_covariances"] 146 | for name in names: 147 | assert name in [x.name for x in pymc_mod.deterministics] 148 | 149 | 150 | @pytest.mark.parametrize( 151 | "filter_output", 152 | ["filtered", "predicted", "smoothed", "invalid"], 153 | ids=["filtered", "predicted", "smoothed", "invalid"], 154 | ) 155 | def test_sample_conditional_prior(ss_mod, pymc_mod, filter_output): 156 | if filter_output == "invalid": 157 | msg = "filter_output should be one of filtered, predicted, or smoothed, recieved invalid" 158 | with pytest.raises(ValueError, match=msg), pymc_mod: 159 | ss_mod.sample_conditional_prior(filter_output=filter_output) 160 | else: 161 | with pymc_mod: 162 | conditional_prior = ss_mod.sample_conditional_prior( 163 | filter_output=filter_output, n_simulations=1, prior_samples=100 164 | ) 165 | 166 | 167 | @pytest.mark.parametrize( 168 | "filter_output", 169 | ["filtered", "predicted", "smoothed", "invalid"], 170 | ids=["filtered", "predicted", "smoothed", "invalid"], 171 | ) 172 | def test_sample_conditional_posterior(ss_mod, pymc_mod, idata, filter_output): 173 | if filter_output == "invalid": 174 | msg = "filter_output should be one of filtered, predicted, or smoothed, recieved invalid" 175 | with pytest.raises(ValueError, match=msg), pymc_mod: 176 | ss_mod.sample_conditional_prior(filter_output=filter_output) 177 | else: 178 | with pymc_mod: 179 | conditional_prior = ss_mod.sample_conditional_posterior( 180 | 
# NOTE(review): the original span is a multi-file repository dump; below is a
# properly formatted reconstruction of every complete definition visible in it.
# The partial test function truncated at the top of the dump is omitted rather
# than guessed at.

# ============================================================
# tests/test_statespace.py (tail of file)
# ============================================================


def test_sample_conditional_posterior_raises_on_invalid_samples(ss_mod, pymc_mod, idata):
    """A fractional ``posterior_samples`` outside (0, 1) must raise ValueError."""
    msg = (
        "If posterior_samples is a float, it should be between 0 and 1, representing the "
        "fraction of total posterior samples to re-sample."
    )

    with pymc_mod, pytest.raises(ValueError, match=msg):
        ss_mod.sample_conditional_posterior(
            idata, filter_output="predicted", n_simulations=1, posterior_samples=-0.3
        )


def test_sample_conditional_posterior_default_samples(ss_mod, pymc_mod, idata):
    """Smoke test: conditional posterior sampling with default posterior_samples."""
    with pymc_mod:
        ss_mod.sample_conditional_posterior(idata, filter_output="predicted", n_simulations=1)


def test_sample_unconditional_prior(ss_mod, pymc_mod):
    """Smoke test: unconditional prior simulation runs without error."""
    with pymc_mod:
        ss_mod.sample_unconditional_prior(n_simulations=1, prior_samples=100)


def test_sample_unconditional_posterior(ss_mod, pymc_mod, idata):
    """Smoke test: unconditional posterior simulation runs without error."""
    with pymc_mod:
        ss_mod.sample_unconditional_posterior(
            idata, n_steps=100, n_simulations=1, posterior_samples=10
        )


if __name__ == "__main__":
    unittest.main()


# ============================================================
# tests/utilities/statsmodel_local_level.py
# ============================================================
import numpy as np
import statsmodels.api as sm


class LocalLinearTrend(sm.tsa.statespace.MLEModel):
    """Statsmodels local-linear-trend model.

    Serves as the reference implementation against which the
    pymc_statespace Kalman filter outputs are compared in the tests.
    """

    def __init__(self, endog, **kwargs):
        # Two states (level, trend), each driven by its own shock.
        n_states = n_shocks = 2

        super().__init__(endog, k_states=n_states, k_posdef=n_shocks, **kwargs)

        # Fixed system matrices of the local linear trend model.
        self.ssm["design"] = np.array([1, 0])
        self.ssm["transition"] = np.array([[1, 1], [0, 1]])
        self.ssm["selection"] = np.eye(n_states)

        # Indices of the state-covariance diagonal, written to in update().
        self._state_cov_idx = ("state_cov",) + np.diag_indices(n_shocks)

    @property
    def param_names(self):
        return ["sigma2.measurement", "sigma2.level", "sigma2.trend"]

    @property
    def start_params(self):
        # Start every variance at the sample standard deviation of the data.
        return [np.std(self.endog)] * 3

    def transform_params(self, unconstrained):
        # Squaring keeps the variances non-negative during optimization.
        return unconstrained**2

    def untransform_params(self, constrained):
        return constrained**0.5

    def update(self, params, *args, **kwargs):
        params = super().update(params, *args, **kwargs)

        # Observation noise variance.
        self.ssm["obs_cov", 0, 0] = params[0]

        # Level and trend shock variances.
        self.ssm[self._state_cov_idx] = params[1:]


# ============================================================
# tests/utilities/test_helpers.py
# ============================================================
import os
from pathlib import Path

import numpy as np
import pandas as pd
import pytensor.tensor as pt

from pymc_statespace.filters.kalman_smoother import KalmanSmoother
from tests.utilities.statsmodel_local_level import LocalLinearTrend

ROOT = Path(__file__).parent.parent.absolute()
nile_data = pd.read_csv(os.path.join(ROOT, "test_data/nile.csv"))
nile_data["x"] = nile_data["x"].astype(float)


def initialize_filter(kfilter):
    """Wire a Kalman filter and smoother into one symbolic graph.

    Returns the (inputs, outputs) lists of symbolic variables needed to
    compile the combined graph.
    """
    smoother = KalmanSmoother()

    data = pt.dtensor3(name="data")
    # One symbolic matrix per system matrix; names must match what the
    # tests use when compiling the graph.
    a0, P0, Q, H, T, R, Z = (pt.matrix(name=nm) for nm in ("a0", "P0", "Q", "H", "T", "R", "Z"))

    inputs = [data, a0, P0, T, Z, R, H, Q]

    (
        filtered_states,
        predicted_states,
        filtered_covs,
        predicted_covs,
        log_likelihood,
        ll_obs,
    ) = kfilter.build_graph(*inputs)

    smoothed_states, smoothed_covs = smoother.build_graph(T, R, Q, filtered_states, filtered_covs)

    outputs = [
        filtered_states,
        predicted_states,
        smoothed_states,
        filtered_covs,
        predicted_covs,
        smoothed_covs,
        log_likelihood,
        ll_obs,
    ]

    return inputs, outputs


def add_missing_data(data, n_missing):
    """Set ``n_missing`` randomly chosen time steps of ``data`` to NaN (in place)."""
    dropped_idx = np.random.choice(data.shape[0], n_missing, replace=False)
    data[dropped_idx] = np.nan

    return data


def make_test_inputs(p, m, r, n, missing_data=None, H_is_zero=False):
    """Build a full (data, a0, P0, T, Z, R, H, Q) input set for the filter.

    p: observed dims, m: hidden states, r: shocks, n: time steps.
    """
    data = np.arange(n * p, dtype="float").reshape(-1, p, 1)
    if missing_data is not None:
        data = add_missing_data(data, missing_data)

    a0 = np.zeros((m, 1))
    P0 = np.eye(m)
    Q = np.eye(r)
    H = np.zeros((p, p)) if H_is_zero else np.eye(p)

    # Transition: sub-diagonal shifts states down, first row averages them.
    T = np.eye(m, k=-1)
    T[0, :] = 1 / m

    R = np.eye(m)[:, :r]
    Z = np.eye(m)[:p, :]

    return data, a0, P0, T, Z, R, H, Q


def get_expected_shape(name, p, m, r, n):
    """Expected array shape for a named Kalman filter output."""
    if name == "log_likelihood":
        return ()
    if name == "ll_obs":
        return (n,)

    filter_kind, variable = name.split("_")
    # Predicted outputs carry one extra (one-step-ahead) time step.
    steps = n + 1 if filter_kind == "predicted" else n
    if variable == "states":
        return steps, m, 1
    if variable == "covs":
        return steps, m, m


def get_sm_state_from_output_name(res, name):
    """Fetch the statsmodels result attribute matching one of our output names."""
    if name == "log_likelihood":
        return res.llf
    if name == "ll_obs":
        return res.llf_obs

    filter_kind, variable = name.split("_")
    sm_states = res.states

    if variable == "states":
        return getattr(sm_states, filter_kind)
    if variable == "covs":
        m = res.filter_results.k_states
        # Drop the trailing "s": e.g. "filtered_covs" -> "filtered_cov".
        return getattr(sm_states, name[:-1]).reshape(-1, m, m)


def nile_test_test_helper(n_missing=0):
    """Fit the statsmodels reference model to the Nile data.

    Returns the fitted statsmodels result together with the matching input
    matrices for the pymc_statespace filter.
    """
    a0 = np.zeros((2, 1))
    P0 = np.eye(2) * 1e6
    Q = np.eye(2) * np.array([0.5, 0.01])
    H = np.eye(1) * 0.8
    T = np.array([[1.0, 1.0], [0.0, 1.0]])
    R = np.eye(2)
    Z = np.array([[1.0, 0.0]])

    data = nile_data.values.copy()
    if n_missing > 0:
        data = add_missing_data(data, n_missing)

    sm_model = LocalLinearTrend(
        endog=data,
        initialization="known",
        initial_state_cov=P0,
        initial_state=a0.ravel(),
    )

    # Pin the variances to the values used for our filter so both
    # implementations are evaluated at identical parameters.
    res = sm_model.fit_constrained(
        constraints={
            "sigma2.measurement": 0.8,
            "sigma2.level": 0.5,
            "sigma2.trend": 0.01,
        }
    )

    inputs = [data[..., None], a0, P0, T, Z, R, H, Q]

    return res, inputs