├── .gitattributes ├── .github └── workflows │ └── publish-to-pypi.yml ├── .gitignore ├── .pylintrc ├── .vscode ├── extensions.json ├── launch.json └── settings.json ├── ChangeLog.md ├── LICENSE.md ├── README.md ├── data └── sample_data.p ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt └── source ├── DPAD ├── DPADModel.py ├── DPADModelDoc.md ├── RNNModel.py ├── RNNModelDoc.md ├── RegressionModel.py ├── RegressionModelDoc.md ├── __init__.py ├── example │ ├── DPAD_tutorial.ipynb │ └── __init__.py ├── tests │ ├── test_DPADModel.py │ └── test_RNNModel.py └── tools │ ├── GaussianSmoother.py │ ├── LinearMapping.py │ ├── SSM.py │ ├── __init__.py │ ├── abstract_classes.py │ ├── evaluation.py │ ├── file_tools.py │ ├── flexible.py │ ├── model_base_classes.py │ ├── parse_tools.py │ ├── plot.py │ ├── plot_model_params.py │ ├── sim_tools.py │ ├── tests │ ├── test_LinearMapping.py │ ├── test_parse_tools.py │ ├── test_tf_losses.py │ └── test_tools.py │ ├── tf_losses.py │ ├── tf_tools.py │ └── tools.py ├── __init__.py └── setup.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Basic .gitattributes for a python repo. 2 | *.sh text eol=lf 3 | *.slurm text eol=lf 4 | 5 | # Source files 6 | # ============ 7 | *.pxd text diff=python 8 | *.py text diff=python 9 | *.py3 text diff=python 10 | *.pyc text diff=python 11 | *.pyd text diff=python 12 | *.pyo text diff=python 13 | *.pyw text diff=python 14 | *.pyx text diff=python 15 | *.pyz text diff=python 16 | 17 | # Binary files 18 | # ============ 19 | *.db binary 20 | *.p binary 21 | *.pkl binary 22 | *.pickle binary 23 | *.pyc binary 24 | *.pyd binary 25 | *.pyo binary 26 | 27 | # Jupyter notebook 28 | *.ipynb text 29 | 30 | # Note: .db, .p, and .pkl files are associated 31 | # with the python modules ``pickle``, ``dbm.*``, 32 | # ``shelve``, ``marshal``, ``anydbm``, & ``bsddb`` 33 | # (among others). 
-------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI 2 | 3 | on: push 4 | 5 | jobs: 6 | build: 7 | name: Build distribution 📦 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: "3.x" 16 | - name: Install pypa/build 17 | run: >- 18 | python3 -m 19 | pip install 20 | build 21 | --user 22 | - name: Build a binary wheel and a source tarball 23 | run: python3 -m build 24 | - name: Store the distribution packages 25 | uses: actions/upload-artifact@v4 26 | with: 27 | name: python-package-distributions 28 | path: dist/ 29 | 30 | publish-to-pypi: 31 | name: >- 32 | Publish Python 🐍 distribution 📦 to PyPI 33 | if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes 34 | needs: 35 | - build 36 | runs-on: ubuntu-latest 37 | environment: 38 | name: pypi 39 | url: https://pypi.org/p/DPAD # This project's PyPI page 40 | permissions: 41 | id-token: write # IMPORTANT: mandatory for trusted publishing 42 | 43 | steps: 44 | - name: Download all the dists 45 | uses: actions/download-artifact@v4 46 | with: 47 | name: python-package-distributions 48 | path: dist/ 49 | - name: Publish distribution 📦 to PyPI 50 | uses: pypa/gh-action-pypi-publish@release/v1 51 | 52 | github-release: 53 | name: >- 54 | Sign the Python 🐍 distribution 📦 with Sigstore 55 | and upload it to GitHub Release 56 | needs: 57 | - publish-to-pypi 58 | runs-on: ubuntu-latest 59 | 60 | permissions: 61 | contents: write # IMPORTANT: mandatory for making GitHub Releases 62 | id-token: write # IMPORTANT: mandatory for sigstore 63 | 64 | steps: 65 | - name: Download all the dists 66 | uses: actions/download-artifact@v4 67 | with: 68 | name: python-package-distributions 69 | path: dist/ 70 | - name: Sign the dists with Sigstore 71 | uses: sigstore/gh-action-sigstore-python@v3.0.0 72 | with: 73 | inputs: >- 74 | ./dist/*.tar.gz 75 | ./dist/*.whl 76 | - name: Create GitHub Release 77 | env: 78 | GITHUB_TOKEN: ${{ github.token }} 79 | run: >- 80 | gh release create 81 | '${{ github.ref_name }}' 82 | --repo '${{ github.repository }}' 83 | --notes "" 84 | - name: Upload artifact signatures to GitHub Release 85 | env: 86 | GITHUB_TOKEN: ${{ github.token }} 87 | # Upload to GitHub Release using the `gh` CLI. 88 | # `dist/` contains the built packages, and the 89 | # sigstore-produced signatures and certificates.
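# The `gh release upload` command below attaches every file matching dist/**
# (the wheel and sdist built earlier, plus the Sigstore signature files from
# the signing step) to the GitHub Release created above.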
90 | run: >- 91 | gh release upload 92 | '${{ github.ref_name }}' dist/** 93 | --repo '${{ github.repository }}' 94 | 95 | publish-to-testpypi: 96 | name: Publish Python 🐍 distribution 📦 to TestPyPI 97 | needs: 98 | - build 99 | runs-on: ubuntu-latest 100 | if: false 101 | 102 | environment: 103 | name: testpypi 104 | url: https://test.pypi.org/p/DPAD 105 | 106 | permissions: 107 | id-token: write # IMPORTANT: mandatory for trusted publishing 108 | 109 | steps: 110 | - name: Download all the dists 111 | uses: actions/download-artifact@v4 112 | with: 113 | name: python-package-distributions 114 | path: dist/ 115 | - name: Publish distribution 📦 to TestPyPI 116 | uses: pypa/gh-action-pypi-publish@release/v1 117 | with: 118 | repository-url: https://test.pypi.org/legacy/ 119 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | 3 | *.pyirc 4 | *.noseids 5 | 6 | # cProfile results 7 | *.prof 8 | 9 | # Auto-generated bash files 10 | source/genFigs/*.sh 11 | source/hpc/*.sh 12 | source/hpc/jobs/*.sh 13 | source/hpc/old/*.sh 14 | source/results 15 | models 16 | results 17 | data 18 | logs 19 | 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | pip-wheel-metadata/ 43 | share/python-wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | *.py,cover 70 | .hypothesis/ 71 | .pytest_cache/ 72 | 73 | # Translations 74 | *.mo 75 | *.pot 76 | 77 | # Django stuff: 78 | *.log 79 | local_settings.py 80 | db.sqlite3 81 | db.sqlite3-journal 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # IPython 100 | profile_default/ 101 | ipython_config.py 102 | 103 | # pyenv 104 | .python-version 105 | 106 | # pipenv 107 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 108 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 109 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 110 | # install all needed dependencies.
111 | #Pipfile.lock 112 | 113 | # celery beat schedule file 114 | celerybeat-schedule 115 | 116 | # SageMath parsed files 117 | *.sage.py 118 | 119 | # Environments 120 | .env 121 | .venv 122 | env/ 123 | venv/ 124 | ENV/ 125 | env.bak/ 126 | venv.bak/ 127 | 128 | # Spyder project settings 129 | .spyderproject 130 | .spyproject 131 | 132 | # Rope project settings 133 | .ropeproject 134 | 135 | # mkdocs documentation 136 | /site 137 | 138 | # mypy 139 | .mypy_cache/ 140 | .dmypy.json 141 | dmypy.json 142 | 143 | # Pyre type checker 144 | .pyre/ 145 | 146 | .DS_Store 147 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))" -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.python", 4 | "ms-python.black-formatter", 5 | "ms-python.pylint", 6 | "ms-python.isort" 7 | ] 8 | } -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "purpose": ["debug-test"], 14 | "env": { 15 | "CUDA_VISIBLE_DEVICES": "0", 16 | "TF_CPP_MIN_LOG_LEVEL": "1", 17 | "PYDEVD_WARN_SLOW_RESOLVE_TIMEOUT": "5", 18 | }, 19 | "justMyCode": false 20 | } 21 | ] 22 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.extraPaths": ["./source"], 3 | "python.terminal.activateEnvironment": true, 4 | "[python]": { 5 | "editor.defaultFormatter": "ms-python.black-formatter", 6 | "editor.formatOnSave": true, 7 | "editor.codeActionsOnSave": { 8 | "source.organizeImports": "explicit", 9 | }, 10 | "editor.tabSize": 4 11 | }, 12 | "isort.args":["--profile", "black"], 13 | "python.linting.pylintPath": "${workspaceFolder}\\.venv\\Scripts\\pylint.exe", 14 | "python.linting.pylintEnabled": true, 15 | "python.linting.enabled": true, 16 | "python.testing.pytestEnabled": true, 17 | "python.testing.unittestEnabled": false, 18 | "python.testing.unittestArgs": [ 19 | "-v", 20 | "-s", 21 | "./source", 22 | "-p", 23 | "test*.py" 24 | ], 25 | "python.testing.pytestArgs": [ 26 | "source" 27 | ], 28 | "python.analysis.typeCheckingMode": "basic", 29 | "python.analysis.diagnosticSeverityOverrides": { 30 | "reportUnboundVariable": "warning", 31 | "reportGeneralTypeIssues": "warning", 32 | "reportOptionalMemberAccess": "warning", 33 | "reportOptionalSubscript": "warning", 34 | "reportOptionalIterable": "warning" 35 | } 36 | } -------------------------------------------------------------------------------- /ChangeLog.md: -------------------------------------------------------------------------------- 1 | # Changes 2 | Versioning 
follows [semver](https://semver.org/). 3 | 4 | - v0.0.9: 5 | - Changes the default Early Stopping setting to be based on validation loss. 6 | - Makes the metric computation in inner cross validation immune to flat channels in the validation data (previously if a neuron was flat in validation data of one inner cross validation fold, the self-prediction for that fold would become NaN). 7 | - Adds Gaussian Smoother tool to use in notebooks. 8 | 9 | - v0.0.8: 10 | - Enables z-scoring of inputs to the model by default. Change `zscore_inputs` to `False` or add `nzs` to `methodCode` to disable. 11 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | This software is Copyright © 2024 The University of Southern California. All Rights Reserved. 2 | 3 | Permission to use, copy, modify, and distribute this software and its documentation for educational, research 4 | and non-profit purposes, without fee, and without a written agreement is hereby granted, provided that the 5 | above copyright notice, this paragraph and the following three paragraphs appear in all copies. 6 | 7 | Permission to make commercial use of this software may be obtained by contacting: 8 | USC Stevens Center for Innovation 9 | University of Southern California 10 | 1150 S. Olive Street, Suite 2300 11 | Los Angeles, CA 90015, USA 12 | 13 | This software program and documentation are copyrighted by The University of Southern California. The software 14 | program and documentation are supplied "as is", without any accompanying services from USC. USC does not warrant 15 | that the operation of the program will be uninterrupted or error-free. The end-user understands that the program 16 | was developed for research purposes and is advised not to rely exclusively on the program for any reason. 17 | 18 | IN NO EVENT SHALL THE UNIVERSITY OF SOUTHERN CALIFORNIA BE LIABLE TO ANY PARTY FOR 19 | DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST 20 | PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE 21 | UNIVERSITY OF SOUTHERN CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH 22 | DAMAGE. THE UNIVERSITY OF SOUTHERN CALIFORNIA SPECIFICALLY DISCLAIMS ANY 23 | WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 24 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED 25 | HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF SOUTHERN CALIFORNIA HAS NO 26 | OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR 27 | MODIFICATIONS. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Publication: 2 | The following paper introduces and provides results of DPAD (**dissociative and prioritized analysis of dynamics**) in multiple real neural datasets. 3 | 4 | Omid G. Sani, Bijan Pesaran, Maryam M. Shanechi. *Dissociative and prioritized modeling of behaviorally relevant neural dynamics using recurrent neural networks*. ***Nature Neuroscience*** (2024).
https://doi.org/10.1038/s41593-024-01731-2 5 | 6 | Original preprint: https://doi.org/10.1101/2021.09.03.458628 7 | 8 | 9 | # Usage examples 10 | The following notebook contains usage examples of DPAD for several use-cases: 11 | [source/DPAD/example/DPAD_tutorial.ipynb](https://github.dev/ShanechiLab/DPAD/blob/main/source/DPAD/example/DPAD_tutorial.ipynb). 12 | 13 | An HTML version of the notebook is also available next to it in the same directory. 14 | 15 | # Usage guide 16 | The following documents explain the formulation of the key classes that are used to implement DPAD (the code for these key classes is also available in the same directory): 17 | 18 | - [source/DPAD/DPADModelDoc.md](./source/DPAD/DPADModelDoc.md): The formulation implemented by the `DPADModel` class, which performs the overall 4-step DPAD modeling. 19 | 20 | - [source/DPAD/RNNModelDoc.md](./source/DPAD/RNNModelDoc.md): The formulation implemented by the custom `RNNModel` class, which implements the RNNs that are trained in steps 1 and 3 of DPAD. 21 | 22 | - [source/DPAD/RegressionModelDoc.md](./source/DPAD/RegressionModelDoc.md): The formulation implemented by the `RegressionModel` class, which `RNNModel` and `DPADModel` both internally use to build the general multilayer feed-forward neural networks that are used to implement each model parameter. 23 | 24 | We are working on various improvements to the DPAD codebase. Stay tuned! 25 | 26 | # Change Log 27 | You can see the change log in [ChangeLog.md](./ChangeLog.md) 28 | 29 | # License 30 | Copyright (c) 2024 University of Southern California 31 | See full notice in [LICENSE.md](./LICENSE.md) 32 | Omid G. Sani and Maryam M. Shanechi 33 | Shanechi Lab, University of Southern California 34 | -------------------------------------------------------------------------------- /data/sample_data.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShanechiLab/DPAD/a09e7d9c3e59f1adb2d75336705dc166c15f4039/data/sample_data.p -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "DPAD" 7 | version = "0.0.9" 8 | authors = [ 9 | {name = "Omid Sani", email = "omidsani@gmail.com"}, 10 | ] 11 | description = "Python implementation for DPAD (dissociative and prioritized analysis of dynamics)" 12 | requires-python = ">=3.11" 13 | classifiers = [ 14 | "Programming Language :: Python :: 3", 15 | "Operating System :: OS Independent", 16 | ] 17 | dependencies = [ 18 | "tensorflow==2.15.1", 19 | "numpy==1.26.4", 20 | "PSID==1.2.5", 21 | "coloredlogs==15.0.1", 22 | "tqdm==4.66.4", 23 | "xxhash==3.5.0" 24 | ] 25 | dynamic = ["readme"] 26 | 27 | [project.urls] 28 | Homepage = "https://github.com/ShanechiLab/DPAD" 29 | Issues = "https://github.com/ShanechiLab/DPAD/issues" 30 | 31 | [tool.setuptools.dynamic] 32 | readme = {file = ["README.md"], content-type = "text/markdown"} 33 | 34 | [tool.setuptools.packages.find] 35 | where = ["source"] # list of folders that contain the packages (["."] by default) -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # Additional requirements for development or for running the notebook 2 | sympy==1.13.2 # For
notebook 3 | ipython # For notebook 4 | ipykernel # For notebook -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # python version: 3.11.9 2 | tensorflow==2.15.1 3 | numpy==1.26.4 4 | PSID==1.2.5 # PSID library for comparisons 5 | coloredlogs==15.0.1 # For printing logs in color onto the terminal 6 | tqdm==4.66.4 7 | xxhash==3.5.0 8 | -------------------------------------------------------------------------------- /source/DPAD/DPADModelDoc.md: -------------------------------------------------------------------------------- 1 | # DPADModel formulation 2 | The model learning for `DPADModel` is done in 4 steps as follows ([**Methods**](https://doi.org/10.1038/s41593-024-01731-2)): 3 | 4 | 1. In the first optimization step, we learn the parameters $A'^{(1)}(\cdot)$, $K^{(1)}(\cdot)$, and $C^{(1)}_z(\cdot)$ of the following RNN: 5 | 6 | $$x^{(1)}_{k+1} = A'^{(1)}(x^{(1)}_k) + K^{(1)}( y_k )$$ 7 | 8 | $$\hat{z}^{(1)}_k = C_z^{(1)}( x^{(1)}_k )$$ 9 | 10 | and estimate its latent state $x^{(1)}_k\in\mathbb{R}^{n_1}$, while minimizing the negative log-likelihood (NLL) of predicting the behavior $z_k$ as $\hat{z}^{(1)}_k$. This RNN is implemented as an `RNNModel` object with $y_k$ as the input, $\hat{z}^{(1)}_k$ as the output, and a state dimension of $n_1$ as specified by the user. `RNNModel` implements each of the RNN parameters, $A'^{(1)}(\cdot)$, $K^{(1)}(\cdot)$, and $C^{(1)}_z(\cdot)$, as a multilayer feed-forward neural network implemented via the `RegressionModel` class. 11 | 12 | 13 | 2. The second optimization step uses the extracted latent state $x^{(1)}_k$ from the RNN and fits the parameter $C_y^{(1)}$ in 14 | 15 | $$\hat{y}_k = C_y^{(1)}( x^{(1)}_k )$$ 16 | 17 | while minimizing the NLL of predicting the neural activity $y_k$ as $\hat{y}_k$. The $C_y^{(1)}$ parameter that specifies this mapping is implemented as a flexible multilayer feed-forward neural network, via the `RegressionModel` class. 18 | 19 | 20 | 3. In the third optimization step, we learn the parameters $A'^{(2)}(\cdot)$, $K^{(2)}(\cdot)$, and $C^{(2)}_y(\cdot)$ of the following RNN: 21 | 22 | $$x^{(2)}_{k+1} = A'^{(2)}(x^{(2)}_k) + K^{(2)}( y_k, x^{(1)}_{k+1} )$$ 23 | 24 | $$\hat{y}_k = C_y^{(2)}( x^{(2)}_k )$$ 25 | 26 | and estimate its latent state $x^{(2)}_k$ while minimizing the aggregate neural prediction negative log-likelihood, which also takes into account the NLL obtained from step 2 via $C_y^{(1)}( x^{(1)}_k )$, computed using the previously learned parameter $C_y^{(1)}$ and the previously extracted states $x_k^{(1)}$ from steps 1-2. This RNN is also implemented as an `RNNModel` object with the concatenation of $y_k$ and $x^{(1)}_{k+1}$ as the input and the predicted neural activity as the output. The NLL for predicting neural activity from steps 1-2 is also provided as input, to allow formation of the aggregate neural prediction NLL as the loss. `RNNModel` again implements each of the RNN parameters, $A'^{(2)}(\cdot)$, $K^{(2)}(\cdot)$, and $C^{(2)}_y(\cdot)$, as a multilayer feed-forward neural network implemented via the `RegressionModel` class. 27 | 28 | 29 | 4. The fourth optimization step uses the extracted latent states in optimization steps 1 and 3 (i.e., $x^{(1)}_k$ and $x^{(2)}_k$) and fits $C_z$ in: 30 | 31 | $$\hat{z}_k = C_z( x^{(1)}_k, x^{(2)}_k )$$ 32 | 33 | while minimizing the behavior prediction negative log-likelihood.
This step again implements $C_z(.)$ as a flexible multilayer feed-forward neural network, via the `RegressionModel` class. 34 | 35 | For additional options and generalizations to these steps, please read **Methods** in the [DPAD paper](https://doi.org/10.1038/s41593-024-01731-2). 36 | 37 | # Objective function of each optimization step 38 | The objective function of each optimization step is the negative log-likelihood (NLL) associated with the time series predicted in that optimization step, i.e., $z_k$ for steps 1 and 4 and $y_k$ for steps 2 and 3. 39 | For Gaussian-distributed signals $z_k$ with isotropic noise, the NLL is proportional to the mean squared error (MSE). For example, for a Gaussian behavior, the loss of step 1 will be: 40 | 41 | $$\sum_{k}\Vert z_k-\hat{z}^{(1)}_k\Vert_2^2$$ 42 | 43 | To support non-Gaussian data modalities, e.g., categorical behavior, DPAD adjusts the objectives of the four optimization steps and the architecture of the readout parameters based on the NLL of the relevant distribution. For example, for a categorical behavior $z_k$, the NLL is proportional to the cross-entropy and the readout architecture is adjusted as follows: 44 | 1) we change the behavior readout parameter $C_z$ to have an output dimension of $n_z \times n_c$ instead of $n_z$, where $n_c$ denotes the number of behavior categories or classes, and 45 | 2) we apply a Softmax normalization to the output of the behavior readout parameter $C_z$ to ensure that for each of the $n_z$ behavior dimensions, the predicted probabilities for all the $n_c$ classes add up to 1, so that they represent valid probability mass functions. 46 | 47 | For details, see [**Methods**](https://doi.org/10.1038/s41593-024-01731-2). 48 | 49 | We also extend DPAD to modeling intermittently measured behavior time series. To do so, when forming the behavior loss, we only compute the NLL loss on samples where the behavior is measured (i.e., mask the other samples) and solve the optimization with this loss. This makes the modeling approach applicable to intermittently measured behavior signals (**ED Figs. 8-9, S Fig. 8** in the [DPAD paper](https://doi.org/10.1038/s41593-024-01731-2)). 50 | 51 | -------------------------------------------------------------------------------- /source/DPAD/RNNModelDoc.md: -------------------------------------------------------------------------------- 1 | # RNNModel formulation, 1-step ahead, no input 2 | The formulation for `RNNModel` is as follows: 3 | 4 | $$ 5 | x_{k+1} = A(x_{k}) + K( y_k ) \\ 6 | z_k = C( x_k ) 7 | $$ 8 | 9 | where $y_k$ and $z_k$ are the input and output, respectively, and $x_k$ is the latent state. Each parameter ($A(.)$, $K(.)$, and $C(.)$) is a multi-layer perceptron (MLP) implemented via the `RegressionModel` class. 10 | 11 | In the special case of a linear model with the same output time series as the input time series (from $y_k$ to $y_k$), this reduces to a Kalman filter doing one-step ahead prediction: 12 | 13 | $$ 14 | x_{k+1} = A x_k + K y_k \\ 15 | y_k = C x_k + e_k 16 | $$ 17 | 18 | where $x_k \triangleq x_{k|k-1}$ is the estimated latent state at time step $k$ given all inputs up to time step $k-1$. 19 | 20 | Read more on the links to the linear case in **S Note 1** in the [DPAD paper](https://doi.org/10.1038/s41593-024-01731-2).
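To make this linear special case concrete, the following is a minimal NumPy sketch of the above predictor-form recursion (not part of the package; all parameter values are made up for illustration). Here, the recursion's $A$ plays the role of $\bar{A} - KC$ for a standard Kalman predictor with state-transition matrix $\bar{A}$, observation matrix $C$, and steady-state Kalman gain $K$:

```python
import numpy as np

# Hypothetical linear model parameters, made up for illustration only
A_bar = np.array([[0.95, 0.10], [0.00, 0.80]])  # state transition of the underlying model
C = np.array([[1.0, 0.0]])                      # observation matrix
K = np.array([[0.5], [0.2]])                    # steady-state Kalman gain (assumed given)
A = A_bar - K @ C  # the "A" of the recursion above, in predictor form

rng = np.random.default_rng(0)
T = 100
y = rng.standard_normal((T, 1))  # stand-in observation time series

x = np.zeros((2, 1))  # x_1 = 0 (prediction given no observations)
y_hat = np.zeros((T, 1))
for k in range(T):
    y_hat[k] = (C @ x).ravel()     # one-step-ahead prediction of y_k given y_1, ..., y_{k-1}
    x = A @ x + K @ y[k][:, None]  # x_{k+1} = A x_k + K y_k
```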
-------------------------------------------------------------------------------- /source/DPAD/RegressionModelDoc.md: -------------------------------------------------------------------------------- 1 | # RegressionModel formulation 2 | `RegressionModel` implements a multilayer feed-forward neural network as follows. Each layer applies the following computation to its input: 3 | 4 | $$ 5 | h = f(Nx+b) 6 | $$ 7 | 8 | where $x\in\mathbb{R}^{n_x}$ is the input to the layer, $h\in\mathbb{R}^{n_h}$ is the output, and $f(.)$ is a fixed scalar function (typically nonlinear) applied on each dimension of its input vector. The bias vector $b\in\mathbb{R}^{n_h}$ and the matrix $N\in\mathbb{R}^{n_h\times n_x}$ are learnable parameters of the above feed-forward neural network layer. `RegressionModel` stacks multiple such layers (each created using `tf.keras.layers.Dense`), feeding the output of each as the input of the next, until finally returning the output of the last layer as the overall output of the multilayer feed-forward neural network. For example, with 1 hidden layer, the overall formulation will be as follows: 9 | 10 | $$ 11 | z = M f(Nx + b) + b' 12 | $$ 13 | 14 | where $z\in\mathbb{R}^{n_z}$ is the overall output of the multilayer feed-forward neural network, and $M\in\mathbb{R}^{n_z\times n_h}$ is a matrix. **S Fig. 1c** in the [DPAD paper](https://doi.org/10.1038/s41593-024-01731-2) depicts the computation graph for this case (not showing the bias vectors $b$ and $b'$). We use rectified linear unit (ReLU) functions as the nonlinearity $f(.)$ for all hidden layers and include a bias term $b$ for all hidden layers in this work ([**Methods**](https://doi.org/10.1038/s41593-024-01731-2)). 15 | 16 | When no hidden layers are used and biases are set to 0, the multilayer feed-forward neural network implemented by `RegressionModel` reduces to the special case of a linear matrix multiplication: 17 | 18 | $$ 19 | y = Mx 20 | $$ 21 | 22 | for which the computation graph is shown in **S Fig. 1b** in the [DPAD paper](https://doi.org/10.1038/s41593-024-01731-2). 23 | -------------------------------------------------------------------------------- /source/DPAD/__init__.py: -------------------------------------------------------------------------------- 1 | # Import DPAD classes 2 | from .DPADModel import DPADModel 3 | from .RegressionModel import RegressionModel 4 | from .RNNModel import RNNModel 5 | -------------------------------------------------------------------------------- /source/DPAD/example/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShanechiLab/DPAD/a09e7d9c3e59f1adb2d75336705dc166c15f4039/source/DPAD/example/__init__.py -------------------------------------------------------------------------------- /source/DPAD/tests/test_RNNModel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 University of Southern California 3 | See full notice in LICENSE.md 4 | Omid G. Sani and Maryam M.
Shanechi 5 | Shanechi Lab, University of Southern California 6 | """ 7 | 8 | """Tests RNNModel""" 9 | 10 | # pylint: disable=C0103, C0111 11 | 12 | import copy 13 | import os 14 | import sys 15 | import unittest 16 | 17 | sys.path.insert(0, os.path.dirname(__file__)) 18 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) 19 | 20 | import numpy as np 21 | from DPAD.DPADModel import shift_ms_to_1s_series 22 | from DPAD.RNNModel import RNNModel 23 | from DPAD.sim import ( 24 | generateRandomLinearModel, 25 | genRandomGaussianNoise, 26 | getSysSettingsFromSysCode, 27 | ) 28 | from DPAD.tools.tf_tools import set_global_tf_eagerly_flag 29 | 30 | numTests = 10 31 | 32 | 33 | class TestRNNModel(unittest.TestCase): 34 | def testRNNModel_initLSSM(self): 35 | np.random.seed(42) 36 | 37 | sysCode = "nyR1_5_nzR1_5_nuR0_3_NxR1_5_N1R0_5" 38 | sysSettings = getSysSettingsFromSysCode(sysCode) 39 | 40 | failInds = [] 41 | failErrs = [] 42 | 43 | # numTests = 100 44 | for ci in range(numTests): 45 | sOrig, sysU, zErrSys = generateRandomLinearModel(sysSettings) 46 | s = copy.deepcopy(sOrig) 47 | # print('Testing system {}/{}'.format(1+ci, numTests)) 48 | 49 | block_samples = 64 50 | stateful = True 51 | # Tests ARE EXPECTED to FAIL if stateful = False, because at edges of block_samples, initial state of RNN resets, but that does not happen for Kalman. 52 | # This is why it is important to have stateful=True 53 | 54 | N = ( 55 | 20 if ci % 5 > 0 else 200 56 | ) # Important to also have some longer data to test cases with multiple batches 57 | if not stateful: 58 | N = ( 59 | 1 * block_samples if ci % 5 > 0 else 4 * block_samples 60 | ) # For non-stateful (trial-based RNNs), input data must be a multiple of block_samples 61 | if s.input_dim: 62 | U, XU = sysU.generateRealization(N) 63 | UT = U.T 64 | else: 65 | U, UT = None, None 66 | Y, X = s.generateRealizationWithKF(N, u=U) 67 | if s.input_dim > 0: 68 | YU = np.concatenate((Y, U), axis=1) 69 | UT, YUT = U.T, YU.T 70 | else: 71 | YU, YUT = Y, Y.T 72 | 73 | allZp, allYp, allXp = s.predict(Y, U=U) 74 | 75 | model2 = RNNModel( 76 | initLSSM=s, block_samples=64, batch_size=10, stateful=stateful 77 | ) 78 | allXp2, allYp2 = model2.predict(YUT, FT_in=UT, use_quick_method=True)[:2] 79 | outs = model2.predict_with_keras( 80 | (YUT, UT) 81 | ) # Just making sure that this can run too 82 | 83 | try: 84 | np.testing.assert_allclose(allXp, allXp2.T, rtol=1e-3, atol=1e-6) 85 | np.testing.assert_allclose(allYp, allYp2.T, rtol=1e-3, atol=1e-6) 86 | except Exception as e: 87 | failInds.append(ci) 88 | failErrs.append(e) 89 | 90 | if len(failInds) > 0: 91 | raise ( 92 | Exception( 93 | "{} => {}/{} random systems (indices: {}) failed: \n{}".format( 94 | self.id(), len(failInds), numTests, failInds, failErrs 95 | ) 96 | ) 97 | ) 98 | else: 99 | print( 100 | "{} => Ok: Tested with {} random systems, all were ok!".format( 101 | self.id(), numTests 102 | ) 103 | ) 104 | 105 | def testRNNModel_MultiStepAheadPrediction(self): 106 | np.random.seed(42) 107 | 108 | sysCode = "nyR1_5_nzR1_5_nuR0_3_NxR1_5_N1R0_5" 109 | sysSettings = getSysSettingsFromSysCode(sysCode) 110 | 111 | failInds = [] 112 | failStepAheads = [] 113 | failErrs = [] 114 | 115 | # numTests = 100 116 | for ci in range(numTests): 117 | sOrig, sysU, zErrSys = generateRandomLinearModel(sysSettings) 118 | s = copy.deepcopy(sOrig) 119 | # print('Testing system {}/{}'.format(1+ci, numTests)) 120 | 121 | N = ( 122 | 20 if ci % 5 > 0 else 200 123 | ) # Important to also have some longer data to test
cases with multiple batches 124 | if s.input_dim: 125 | U, XU = sysU.generateRealization(N) 126 | UT = U.T 127 | else: 128 | U, UT = None, None 129 | Y, X = s.generateRealizationWithKF(N, u=U) 130 | if s.input_dim > 0: 131 | YU = np.concatenate((Y, U), axis=1) 132 | UT, YUT = U.T, YU.T 133 | else: 134 | YU, YUT = Y, Y.T 135 | 136 | allZp, allYp, allXp = s.predict(Y, U=U) 137 | 138 | model1 = RNNModel( 139 | initLSSM=s, block_samples=64, batch_size=10, steps_ahead=None 140 | ) 141 | outs1 = model1.predict(YUT, FT_in=UT, use_quick_method=True) 142 | 143 | steps_ahead = [1, 2, 5, 10] 144 | 145 | model2 = RNNModel( 146 | initLSSM=s, 147 | block_samples=64, 148 | batch_size=10, 149 | steps_ahead=steps_ahead, 150 | enable_forward_pred=True, 151 | multi_step_with_A_KC=True, 152 | ) 153 | outs2 = model2.predict(YUT, FT_in=UT, use_quick_method=True) 154 | 155 | model3 = RNNModel( 156 | initLSSM=s, 157 | block_samples=64, 158 | batch_size=10, 159 | steps_ahead=steps_ahead, 160 | enable_forward_pred=True, 161 | ) 162 | outs3 = model3.predict(YUT, FT_in=UT, use_quick_method=True) 163 | 164 | model3.set_use_feedthrough_in_fw(True) 165 | outs3_use_FT_in_fw = model3.predict(YUT, FT_in=UT, use_quick_method=True) 166 | 167 | model4 = RNNModel( 168 | initLSSM=s, 169 | block_samples=64, 170 | batch_size=10, 171 | steps_ahead=steps_ahead, 172 | enable_forward_pred=True, 173 | ) 174 | model4.set_steps_ahead([1]) 175 | outs4 = model4.predict(YUT, FT_in=UT, use_quick_method=True) 176 | 177 | outs_tmp = model2.predict_with_keras( 178 | (YUT, UT) 179 | ) # Just making sure that this can run too 180 | 181 | np.testing.assert_allclose(outs1[0], outs2[0], rtol=1e-3, atol=1e-6) 182 | np.testing.assert_allclose( 183 | outs1[1], outs2[len(steps_ahead)], rtol=1e-3, atol=1e-6 184 | ) 185 | np.testing.assert_allclose(outs1[0], outs3[0], rtol=1e-3, atol=1e-6) 186 | np.testing.assert_allclose( 187 | outs1[1], outs3[len(steps_ahead)], rtol=1e-3, atol=1e-6 188 | ) 189 | np.testing.assert_allclose(outs1[0], outs4[0], rtol=1e-3, atol=1e-6) 190 | np.testing.assert_allclose(outs1[1], outs4[1], rtol=1e-3, atol=1e-6) 191 | 192 | outs3Ys_shifted = shift_ms_to_1s_series( 193 | outs3[len(steps_ahead) : 2 * len(steps_ahead)], 194 | steps_ahead, 195 | time_first=False, 196 | ) 197 | 198 | for saInd, step_ahead in enumerate(steps_ahead): 199 | thisXp_A_KC = np.array(allXp) 200 | thisXp = np.array(allXp) 201 | for stepInd in range(step_ahead - 1): 202 | thisXp_A_KC = (s.A_KC @ thisXp_A_KC.T).T 203 | thisXp = (s.A @ thisXp.T).T 204 | thisYp_A_KC = (s.C @ thisXp_A_KC.T).T 205 | thisYp = (s.C @ thisXp.T).T 206 | if step_ahead == 1 and s.input_dim > 0: 207 | thisYp_A_KC += (s.D @ UT).T 208 | thisYp += (s.D @ UT).T 209 | 210 | UTShift = np.concatenate( 211 | ( 212 | UT[:, (step_ahead - 1) :], 213 | np.zeros_like(UT[:, : (step_ahead - 1)]), 214 | ), 215 | axis=1, 216 | ) 217 | thisYp_use_FT_in_fw = (s.C @ thisXp.T).T + ( 218 | s.D @ UTShift 219 | ).T # Always, even for forecasting 220 | 221 | try: 222 | if step_ahead == 1: 223 | np.testing.assert_allclose( 224 | thisYp_use_FT_in_fw, thisYp, rtol=1e-3, atol=1e-6 225 | ) 226 | np.testing.assert_allclose( 227 | thisXp_A_KC, outs2[saInd].T, rtol=1e-3, atol=1e-6 228 | ) 229 | np.testing.assert_allclose( 230 | thisYp_A_KC, 231 | outs2[saInd + len(steps_ahead)].T, 232 | rtol=1e-3, 233 | atol=1e-5, 234 | ) 235 | np.testing.assert_allclose( 236 | thisXp, outs3[saInd].T, rtol=1e-3, atol=1e-6 237 | ) 238 | np.testing.assert_allclose( 239 | thisYp, outs3[saInd + len(steps_ahead)].T, rtol=1e-3, atol=1e-5 240 
| ) 241 | np.testing.assert_allclose( 242 | thisYp_use_FT_in_fw, 243 | outs3_use_FT_in_fw[saInd + len(steps_ahead)].T, 244 | rtol=1e-3, 245 | atol=1e-5, 246 | ) 247 | except Exception as e: 248 | failInds.append(ci) 249 | failStepAheads.append(step_ahead) 250 | failErrs.append(e) 251 | 252 | # Test shift_ms_to_1s_series 253 | if step_ahead == 1: 254 | thisYp_same_time_as_Y = thisYp 255 | else: 256 | thisYp_same_time_as_Y = np.nan * np.ones_like(thisYp) 257 | thisYp_same_time_as_Y[(step_ahead - 1) :, :] = thisYp[ 258 | : (-step_ahead + 1), : 259 | ] 260 | 261 | np.testing.assert_allclose( 262 | thisYp_same_time_as_Y, 263 | outs3Ys_shifted[saInd].T, 264 | rtol=1e-3, 265 | atol=1e-5, 266 | ) 267 | 268 | pass 269 | 270 | if len(failInds) > 0: 271 | raise ( 272 | Exception( 273 | "{} => {}/{} random systems (indices: {}) failed: \n{}".format( 274 | self.id(), len(failInds), numTests, failInds, failErrs 275 | ) 276 | ) 277 | ) 278 | else: 279 | print( 280 | "{} => Ok: Tested with {} random systems, all were ok!".format( 281 | self.id(), numTests 282 | ) 283 | ) 284 | 285 | def testRNNModel_Bidirectional(self): 286 | np.random.seed(42) 287 | 288 | sysCode = "nyR1_5_nzR1_5_nuR0_0_NxR1_5_N1R0_5" 289 | sysSettings = getSysSettingsFromSysCode(sysCode) 290 | 291 | failInds = [] 292 | failErrs = [] 293 | 294 | # numTests = 100 295 | for ci in range(numTests): 296 | sOrig, sysU, zErrSys = generateRandomLinearModel(sysSettings) 297 | 298 | if ci == 0: # A special simple system that is easy to track 299 | sOrig, sysU, zErrSys = generateRandomLinearModel( 300 | getSysSettingsFromSysCode("nyR1_1_nzR1_1_nuR0_0_NxR1_1_N1R1_1") 301 | ) 302 | sOrig.changeParams({"A": sOrig.A / sOrig.A * 0.999}) 303 | # Temp, setting up an inconsistent model 304 | sOrig.A_KC = sOrig.A_KC / sOrig.A_KC * 0.01 305 | sOrig.K = sOrig.K / sOrig.K * 1 306 | sOrig.C = sOrig.C / sOrig.C * 1 307 | sOrig.useA_KC_plus_KC_in_KF = True 308 | 309 | s = copy.deepcopy(sOrig) 310 | # print('Testing system {}/{}'.format(1+ci, numTests)) 311 | 312 | block_samples = 64 313 | stateful = True 314 | # Tests ARE EXPECTED to FAIL if stateful = False, because at edges of block_samples, initial state of RNN resets, but that does not happen for Kalman. 
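# (The Kalman filter in s.predict carries its state across the entire sequence, whereas a non-stateful RNN re-initializes its state at the start of every block of block_samples samples.)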
315 | # This is why it is important to have stateful=True 316 | 317 | N = ( 318 | 20 if ci % 5 > 0 else 200 319 | ) # Important to also have some longer data to test cases with multiple batches 320 | if ci == 0: 321 | N = 20 322 | if not stateful: 323 | N = ( 324 | 1 * block_samples if ci % 5 > 0 else 4 * block_samples 325 | ) # For non-stateful (trial-based RNNs), input data must be a multiple of block_samples 326 | if s.input_dim: 327 | U, XU = sysU.generateRealization(N) 328 | UT = U.T 329 | else: 330 | U, UT = None, None 331 | Y, X = s.generateRealizationWithKF(N, u=U) 332 | 333 | if ci == 0: # A special simple realization that is easy to track 334 | e = np.zeros((N, s.output_dim)) 335 | e[0, :] = 1 336 | Y, X = s.generateRealizationWithKF(N, u=U, e=e) 337 | a = s.A 338 | apow = (a ** np.arange(len(Y))).T 339 | # import matplotlib.pyplot as plt 340 | # plt.figure() 341 | # ax = plt.gca() 342 | # ax.plot(X, label='X') 343 | # ax.plot(Y, label='Y') 344 | # ax.plot(apow, label='a^t') 345 | # ax.legend() 346 | # ax2 = ax.twinx() 347 | # ax2.plot(np.log10(np.abs(X)), linestyle='--', label='log10 |X|') 348 | # ax2.plot(np.log10(np.abs(Y)), linestyle='--', label='log10 |Y|') 349 | # ax2.plot(np.log10(np.abs(apow)), linestyle='--', label='log10 |a^t|') 350 | # ax2.legend(loc='lower right') 351 | # plt.show() 352 | 353 | Y = np.arange(1, N + 1)[:, np.newaxis] 354 | 355 | if s.input_dim > 0: 356 | YU = np.concatenate((Y, U), axis=1) 357 | UT, YUT = U.T, YU.T 358 | else: 359 | YU, YUT = Y, Y.T 360 | 361 | nx = s.state_dim 362 | sBW = copy.deepcopy(s) 363 | 364 | # Append one extra sample to Y to also get the 1-step prediction given the last actual sample of Y 365 | YFW = np.concatenate((Y, np.zeros_like(Y[0:1, :])), axis=0) 366 | UFW = ( 367 | np.concatenate((U, np.zeros_like(U[0:1, :])), axis=0) 368 | if U is not None 369 | else None 370 | ) 371 | allZpFW, allYpFW, allXpFW = s.predict(YFW, U=UFW) 372 | 373 | YBW = np.concatenate((np.flipud(Y), np.zeros_like(Y[0:2, :])), axis=0) 374 | UBW = ( 375 | np.concatenate((np.flipud(U), np.zeros_like(U[0:2, :])), axis=0) 376 | if U is not None 377 | else None 378 | ) 379 | allZpBW, allYpBW, allXpBW = sBW.predict(YBW, U=UBW) 380 | allZpBWFlip, allYpBWFlip, allXpBWFlip = ( 381 | np.flipud(allZpBW), 382 | np.flipud(allYpBW), 383 | np.flipud(allXpBW), 384 | ) 385 | 386 | shift_preds = False # Must be False; True emulates the old behavior, which matches the behavior of unidirectional RNNs 387 | if shift_preds: 388 | allZpFW, allYpFW, allXpFW = ( 389 | allZpFW[:-1, :], 390 | allYpFW[:-1, :], 391 | allXpFW[:-1, :], 392 | ) 393 | # We want the backward pass to see up to the exact same observation as was seen in the forward pass (that's how the Bidirectional RNN model works) 394 | # so we will append by two samples and remove the first two outputs 395 | allXp = np.concatenate((allXpFW, allXpBWFlip[:(-2), :]), axis=1) 396 | allYp = allYpFW + allYpBWFlip[:(-2), :] 397 | allZp = allZpFW + allZpBWFlip[:(-2), :] 398 | else: 399 | # For bidirectional, we expect shift_preds=False, in which case state at index i has seen inputs up to and including index i 400 | allZpFW, allYpFW, allXpFW = ( 401 | allZpFW[1:, :], 402 | allYpFW[1:, :], 403 | allXpFW[1:, :], 404 | ) # Drop first sample that is prediction given no observation 405 | allXp = np.concatenate((allXpFW, allXpBWFlip[1:(-1), :]), axis=1) 406 | allYp = allYpFW + allYpBWFlip[1:(-1), :] 407 | allZp = allZpFW + allZpBWFlip[1:(-1), :] 408 | 409 | # eagerly_flag_backup = set_global_tf_eagerly_flag(True) 410 | log_dir =
"./logs" 411 | model2 = RNNModel( 412 | initLSSM=s, 413 | initLSSM_backward=sBW, 414 | bidirectional=True, 415 | # linear_cell=True, 416 | block_samples=block_samples, # Making block samples the same as data length makes the results identical to LSSM, except for the edges 417 | batch_size=10, 418 | stateful=stateful, 419 | log_dir=log_dir, 420 | ) 421 | allXp2, allYp2 = model2.predict( 422 | YUT, FT_in=UT, use_quick_method=True, shift_preds=shift_preds 423 | )[:2] 424 | 425 | outs = model2.predict_with_keras( 426 | (YUT, UT) 427 | ) # Just making sure that this can run too 428 | 429 | """ 430 | import matplotlib.pyplot as plt 431 | plt.figure() 432 | plt.plot(Y[:,:1], label='Y') 433 | plt.legend() 434 | plt.figure() 435 | plt.plot(allXp2[0:1,:].T, label='FW RNN XHat') 436 | plt.plot(allXp[:, 0:1], '--', label='FW LSSM XHat') 437 | [plt.axvline(x=x, color = 'r') for x in block_samples*(1+np.arange(round(N/block_samples), dtype=float))] 438 | plt.legend() 439 | plt.show() 440 | plt.figure() 441 | plt.plot(allXp2[nx:(nx+1),:].T, label='BW RNN XHat') 442 | plt.plot(allXp[:, nx:(nx+1)], '--', label='BW LSSM XHat') 443 | [plt.axvline(x=x, color = 'r') for x in block_samples*(1+np.arange(round(N/block_samples), dtype=float))] 444 | plt.legend() 445 | plt.show() 446 | plt.figure() 447 | plt.plot(allYp2[0:1,:].T, label='BW RNN YHat') 448 | plt.plot(allYp[:, 0:1], '--', label='BW LSSM YHat') 449 | [plt.axvline(x=x, color = 'r') for x in block_samples*(1+np.arange(round(N/block_samples), dtype=float))] 450 | plt.legend() 451 | plt.show() 452 | #""" 453 | 454 | try: 455 | # TO DO: fix discrepancies along block edges for the backward pass 456 | np.testing.assert_allclose( 457 | allXpFW, allXp2[: allXpFW.shape[1], :].T, rtol=1e-3, atol=1e-6 458 | ) 459 | np.testing.assert_allclose( 460 | allXp[:, nx : (2 * nx)], 461 | allXp2[nx : (2 * nx) :, :].T, 462 | rtol=1e-3, 463 | atol=1e-6, 464 | ) 465 | np.testing.assert_allclose(allXp, allXp2.T, rtol=1e-3, atol=1e-6) 466 | np.testing.assert_allclose(allYp, allYp2.T, rtol=1e-3, atol=1e-6) 467 | except Exception as e: 468 | failInds.append(ci) 469 | failErrs.append(e) 470 | 471 | if len(failInds) > 0: 472 | raise ( 473 | Exception( 474 | "{} => {}/{} random systems (indices: {}) failed: \n{}".format( 475 | self.id(), len(failInds), numTests, failInds, failErrs 476 | ) 477 | ) 478 | ) 479 | else: 480 | print( 481 | "{} => Ok: Tested with {} random systems, all were ok!".format( 482 | self.id(), numTests 483 | ) 484 | ) 485 | 486 | def tearDown(self): 487 | pass 488 | 489 | 490 | if __name__ == "__main__": 491 | unittest.main() 492 | -------------------------------------------------------------------------------- /source/DPAD/tools/GaussianSmoother.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage import convolve1d 3 | from scipy.signal import filtfilt, lfilter 4 | from scipy.stats import norm 5 | 6 | 7 | class GaussianSmoother(): 8 | def __init__(self, std, Fs=1, axis=0, kernel_multiplier=None, causal=False): 9 | """_summary_ 10 | 11 | Args: 12 | std (float): standard deviation of the gaussian kernel to be used. In seconds. If Fs not provided, std 13 | can also be though of as being in the unit of signal samples. 14 | Fs (int, optional): sampling rate of signal. Defaults to 1. 
15 | """ 16 | self.std = std # In seconds 17 | self.Fs = Fs 18 | self.kernel_sigma = std * Fs # effective std for the gaussian kernel 19 | 20 | self.axis = axis 21 | if kernel_multiplier is None: 22 | conf = 0.99 23 | kernel_multiplier = norm.ppf( 1 - 0.5 * (1-conf) ) 24 | self.kernel_multiplier = kernel_multiplier 25 | self.causal = causal 26 | 27 | if self.causal: 28 | sigma = self.kernel_sigma 29 | else: 30 | sigma = self.kernel_sigma / np.sqrt(2) # Because applying the smoothing filter twice with filtfilt multiplies its effective std by sqrt(2) 31 | 32 | m = self.kernel_multiplier 33 | filterLen = int(sigma*m + 0.5) 34 | x0 = 0 35 | xVals = np.arange(-filterLen, filterLen+1, 1) 36 | weights = np.exp( -0.5 * (xVals-x0)**2 / sigma**2 ) 37 | if self.causal: 38 | w_norm = np.sum(weights) 39 | else: 40 | w_padded = np.concatenate((np.zeros_like(weights), weights, np.zeros_like(weights))) 41 | w_norm = np.sqrt( np.sum(convolve1d(w_padded, w_padded, mode='constant', cval=0, origin=0)) ) 42 | 43 | weights = weights / w_norm 44 | 45 | self.weights = weights 46 | 47 | def apply(self, data, axis=None): 48 | if axis is None: 49 | axis = self.axis 50 | if self.causal: 51 | filtered_data = lfilter(self.weights, 1, data, axis=axis) 52 | else: 53 | filtered_data = filtfilt(self.weights, 1, data, axis=axis, method='pad', padtype='constant') 54 | return filtered_data -------------------------------------------------------------------------------- /source/DPAD/tools/LinearMapping.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 University of Southern California 3 | See full notice in LICENSE.md 4 | Omid G. Sani and Maryam M. Shanechi 5 | Shanechi Lab, University of Southern California 6 | """ 7 | 8 | """ 9 | An object for building and applying linear mappings 10 | """ 11 | 12 | import logging 13 | 14 | import numpy as np 15 | 16 | from .tools import applyGivenScaling, isFlat, learnScaling, undoScaling 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class LinearMapping: 22 | """Implements a mapping in the form of f(x) = W x + b""" 23 | 24 | def __init__(self, W=None, b=None, missing_marker=None): 25 | self.set_params(W, b) 26 | self.removed_inds = None 27 | self.to_replace_vals = [] 28 | self.replacement_vals = [] 29 | self.missing_marker = None 30 | 31 | def set_params(self, W=None, b=None): 32 | self.set_weight(W) 33 | self.set_intercept(b) 34 | 35 | def set_intercept(self, b=None): 36 | self.b = b 37 | 38 | def set_weight(self, W=None): 39 | self.W = W 40 | if W is not None: 41 | self._W_pinv = np.linalg.pinv(self.W) 42 | else: 43 | self._W_pinv = None 44 | 45 | def get_overall_W(self): 46 | return self.W 47 | 48 | def set_to_dimension_remover(self, keep_vector): 49 | eye = np.eye(keep_vector.size) 50 | remove_inds = np.where(~np.array(keep_vector, dtype=bool))[0] 51 | self.set_weight(np.delete(eye, remove_inds, 0)) 52 | self.set_intercept(None) 53 | self.removed_inds = remove_inds 54 | 55 | def set_value_replacements(self, to_replace_vals, replacement_vals): 56 | self.to_replace_vals = to_replace_vals 57 | self.replacement_vals = replacement_vals 58 | 59 | def apply(self, x): 60 | """Applies mapping to a series of samples. Second dimension is the dimension of samples. 
61 | 62 | Args: 63 | x (np.array): input samples; each column is one sample. 64 | 65 | Returns: 66 | np.array: the mapped samples. 67 | """ 68 | out = x 69 | if hasattr(self, "to_replace_vals"): 70 | for ri, rep_val in enumerate(self.to_replace_vals): 71 | if np.isnan(rep_val): 72 | rep_inds = np.isnan(out) 73 | elif np.isinf(rep_val): 74 | rep_inds = np.isinf(out) 75 | else: 76 | rep_inds = out == rep_val 77 | rep_val = ( 78 | self.replacement_vals[ri % len(self.replacement_vals)] 79 | if isinstance(self.replacement_vals, (list, tuple)) 80 | else self.replacement_vals 81 | ) 82 | out[rep_inds] = rep_val 83 | if self.W is not None: 84 | out = self.W @ out 85 | if self.b is not None: 86 | out = out + self.b 87 | return out 88 | 89 | def apply_inverse(self, x): 90 | out = x 91 | if self.b is not None: 92 | out = out - self.b 93 | if self.W is not None: 94 | out = self._W_pinv @ out 95 | return out 96 | 97 | 98 | class LinearMappingPerDim(LinearMapping): 99 | def __init__(self, axis=0, **kw_args): 100 | super().__init__(**kw_args) 101 | self.axis = axis 102 | 103 | def set_weight(self, W=None): 104 | self.W = W 105 | 106 | def get_overall_W(self): 107 | if isinstance(self.W, (np.ndarray)): 108 | WMat = self.W 109 | else: 110 | WMat = np.array([self.W]) 111 | if WMat.size < self.b.size: 112 | WMat = WMat * np.ones_like(self.b) 113 | if len(WMat.shape) == 1: 114 | WMat = np.diag(WMat) 115 | return WMat 116 | 117 | def set_to_zscorer( 118 | self, 119 | Y, 120 | axis=0, 121 | remove_mean=True, 122 | zscore=True, 123 | zscore_per_dim=True, 124 | missing_marker=None, 125 | ): 126 | if missing_marker is not None: 127 | self.missing_marker = missing_marker 128 | yMean, yStd = learnScaling( 129 | Y, 130 | remove_mean, 131 | zscore, 132 | zscore_per_dim=zscore_per_dim, 133 | missing_marker=self.missing_marker, 134 | axis=axis, 135 | ) 136 | if not zscore_per_dim: 137 | yStd = yStd[0] 138 | self.axis = axis 139 | self.b = yMean 140 | self.W = yStd 141 | 142 | def apply(self, x): 143 | """Applies mapping to a series of samples. Second dimension is the dimension of samples. 144 | 145 | Args: 146 | x (np.array): input samples; each column is one sample. 147 | 148 | Returns: 149 | np.array: the scaled samples. 150 | """ 151 | return applyGivenScaling( 152 | x, self.b, self.W, axis=self.axis, missing_marker=self.missing_marker 153 | ) 154 | 155 | def apply_inverse(self, x): 156 | return undoScaling( 157 | self, 158 | x, 159 | meanField="b", 160 | stdField="W", 161 | axis=self.axis, 162 | missing_marker=self.missing_marker, 163 | ) 164 | 165 | 166 | class LinearMappingSequence: 167 | def __init__(self): 168 | self.maps = [] 169 | 170 | def append(self, map): 171 | if map is not None: 172 | self.maps.append(map) 173 | 174 | def get_overall_W(self): 175 | if len(self.maps) == 0: 176 | return None 177 | for mi, map in enumerate(self.maps): 178 | thisW = map.get_overall_W() 179 | if mi == 0: 180 | W = thisW 181 | else: 182 | W = thisW @ W 183 | return W 184 | 185 | def apply(self, Y): 186 | out = Y 187 | for map in self.maps: 188 | out = map.apply(out) 189 | return out 190 | 191 | def apply_inverse(self, Y): 192 | out = Y 193 | for map in reversed(self.maps): 194 | out = map.apply_inverse(out) 195 | return out 196 | 197 | 198 | def getNaNRemoverMapping(Y, signal_name="", axis=0, verbose=False): 199 | """Returns a LinearMapping that removes NaN/Inf dimensions of the given data 200 | 201 | Args: 202 | Y (np.array): input data 203 | signal_name (str, optional): name of the signal, used in log messages. Defaults to ''. 204 | axis (int, optional): Axis over which to check for NaN/Inf values. Defaults to 0.
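verbose (bool, optional): if True, logs a warning listing any removed dimensions. Defaults to False.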
205 | 206 | Returns: 207 | LinearMapping or None: a mapping that drops the all-NaN/Inf dimensions, or None if there is nothing to remove (or Y is None). 208 | """ 209 | 210 | # Detect and remove flat data dimensions 211 | if Y is not None: 212 | isAllNans = np.all(np.isnan(Y), axis=axis) 213 | isAllInfs = np.all(np.isinf(Y), axis=axis) 214 | isBadY = np.logical_or(isAllNans, isAllInfs) 215 | if np.any(isBadY): 216 | if verbose: 217 | logger.warning( 218 | "Warning: {}/{} dimensions of signal {} (dims: {}) were just NaN/Inf values, removing them as a preprocessing step".format( 219 | np.sum(isBadY), isBadY.size, signal_name, np.where(isBadY)[0] 220 | ) 221 | ) 222 | YPrepMap = LinearMapping() 223 | YPrepMap.set_to_dimension_remover(~isBadY) 224 | else: 225 | YPrepMap = None 226 | else: 227 | YPrepMap = None 228 | return YPrepMap 229 | 230 | 231 | def getFlatRemoverMapping(Y, signal_name="", axis=0, verbose=False): 232 | """Returns a LinearMapping that removes flat dimensions of the given data 233 | 234 | Args: 235 | Y (np.array): input data 236 | signal_name (str, optional): name of the signal, used in log messages. Defaults to ''. 237 | axis (int, optional): Axis over which to check flatness. Defaults to 0. 238 | 239 | Returns: 240 | LinearMapping or None: a mapping that drops the flat dimensions, or None if there is nothing to remove (or Y is None). 241 | """ 242 | 243 | # Detect and remove flat data dimensions 244 | if Y is not None: 245 | isFlatY = isFlat(Y, axis=axis) 246 | isAllNans = np.all(np.isnan(Y), axis=axis) 247 | isAllInfs = np.all(np.isinf(Y), axis=axis) 248 | isAllNaNsOrInfs = np.logical_or(isAllNans, isAllInfs) 249 | isFlatY = np.logical_or(isFlatY, isAllNaNsOrInfs) 250 | if np.any(isFlatY): 251 | if verbose: 252 | logger.warning( 253 | "Warning: {}/{} dimensions of signal {} (dims: {}) were flat, removing them as a preprocessing step".format( 254 | np.sum(isFlatY), isFlatY.size, signal_name, np.where(isFlatY)[0] 255 | ) 256 | ) 257 | YPrepMap = LinearMapping() 258 | YPrepMap.set_to_dimension_remover(~isFlatY) 259 | else: 260 | YPrepMap = None 261 | else: 262 | YPrepMap = None 263 | return YPrepMap 264 | 265 | 266 | def getZScoreMapping( 267 | Y, 268 | signal_name="", 269 | axis=0, 270 | verbose=False, 271 | remove_mean=True, 272 | zscore=True, 273 | zscore_per_dim=True, 274 | missing_marker=None, 275 | ): 276 | """Returns a LinearMapping that z-scores the given data 277 | 278 | Args: 279 | Y (np.array): input data 280 | signal_name (str, optional): name of the signal, used in log messages. Defaults to ''. 281 | axis (int, optional): Axis along which to compute the scaling statistics. Defaults to 0. 282 | 283 | Returns: 284 | LinearMappingPerDim or None: the z-scoring mapping, or None if Y is None. 285 | """ 286 | 287 | if Y is not None: 288 | YPrepMap = LinearMappingPerDim() 289 | YPrepMap.set_to_zscorer( 290 | Y, 291 | axis=axis, 292 | remove_mean=remove_mean, 293 | zscore=zscore, 294 | zscore_per_dim=zscore_per_dim, 295 | missing_marker=missing_marker, 296 | ) 297 | else: 298 | YPrepMap = None 299 | return YPrepMap 300 | -------------------------------------------------------------------------------- /source/DPAD/tools/SSM.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 University of Southern California 3 | See full notice in LICENSE.md 4 | Omid G. Sani and Maryam M.
Shanechi 5 | Shanechi Lab, University of Southern California 6 | """ 7 | 8 | """An SSM object for keeping parameters, filtering, etc""" 9 | import logging 10 | import time 11 | 12 | import numpy as np 13 | from PSID.LSSM import LSSM, genRandomGaussianNoise 14 | from sympy import Poly, symbols 15 | from tqdm import tqdm 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class SSM(LSSM): 21 | def __init__(self, lssm=None, **kwargs): 22 | if lssm is not None: 23 | if "params" not in kwargs: 24 | kwargs["params"] = {} 25 | for p in lssm.getListOfParams(): 26 | kwargs["params"][p] = getattr(lssm, p) 27 | super().__init__(**kwargs) 28 | 29 | def changeParamsIsolated(self, params={}): 30 | """ 31 | Changes given parameters but DOES NOT update any other potentially dependent parameters! 32 | Use with care! 33 | """ 34 | for k, v in params.items(): 35 | setattr(self, k, v) 36 | 37 | def get_pram_sym(self, param_name): 38 | """ 39 | Returns a symbolic var representing the operation of one parameter 40 | """ 41 | p = getattr(self, param_name) 42 | 43 | n_out, n_in = p.shape 44 | in_sym_names = ",".join(["x{}".format(ii) for ii in range(n_in)]) 45 | in_syms = symbols(in_sym_names) 46 | if not isinstance(in_syms, tuple): 47 | in_syms = (in_syms,) 48 | symP = [] 49 | for oi in range(n_out): 50 | dimP = 0 51 | for ii in range(n_in): 52 | dimP += float(p[oi, ii]) * in_syms[ii] 53 | symP.append(dimP.as_poly()) 54 | return symP 55 | 56 | def apply_param(self, param_name, input, **kwargs): 57 | p = getattr(self, param_name) 58 | if isinstance(p, np.ndarray) and not isinstance(p.flatten()[0], Poly): 59 | out = p @ input 60 | elif isinstance(np.array(p).flatten()[0], Poly): 61 | pFlat = np.array(p).flatten() 62 | out = np.empty((len(pFlat), input.shape[1])) 63 | for dimInd, dimP in enumerate(pFlat): 64 | dimPExpr = dimP.as_expr() 65 | var_syms = dimPExpr.free_symbols 66 | var_names = list(map(str, var_syms)) 67 | var_syms_sorted = [v2 for v1, v2 in sorted(zip(var_names, var_syms))] 68 | if isinstance(dimP, Poly): 69 | timeInds = range(input.shape[1]) 70 | if input.shape[1] > 100: 71 | timeInds = tqdm(timeInds, f"applying {param_name}") 72 | for ti in timeInds: 73 | subVarVals = input[:, ti] 74 | res = dimPExpr.subs(list(zip(var_syms_sorted, subVarVals))) 75 | out[dimInd, ti] = float(res) 76 | else: 77 | raise (Exception("Not supported")) 78 | return out 79 | 80 | def get_param_io_count(self, param_name): 81 | p = getattr(self, param_name) 82 | if isinstance(p, np.ndarray) and not isinstance(p.flatten()[0], Poly): 83 | in_dim = p.shape[-1] 84 | out_dim = p.shape[0] 85 | elif isinstance(np.array(p).flatten()[0], Poly): 86 | pFlat = np.array(p).flatten() 87 | in_dim = len((pFlat[0].as_expr()).free_symbols) 88 | out_dim = len(pFlat) 89 | else: 90 | raise (Exception("Not supported")) 91 | return in_dim, out_dim 92 | 93 | def generateObservationFromStates( 94 | self, X, u=None, param_names=["C", "D"], prep_model_param="", mapping_param="" 95 | ): 96 | Y = None 97 | if hasattr(self, param_names[0]): 98 | C = getattr(self, param_names[0]) 99 | else: 100 | C = None 101 | if len(param_names) > 1 and hasattr(self, param_names[1]): 102 | D = getattr(self, param_names[1]) 103 | else: 104 | D = None 105 | 106 | if len(param_names) > 2 and hasattr(self, param_names[2]): 107 | errSys = getattr(self, param_names[2]) 108 | else: 109 | errSys = None 110 | 111 | if C is not None or D is not None: 112 | ny = ( 113 | self.get_param_io_count(param_names[0])[1] 114 | if C is not None 115 | else
self.get_param_io_count(param_names[1])[1] 116 | ) 117 | N = X.shape[0] 118 | Y = np.zeros((N, ny)) 119 | if C is not None: 120 | Y += self.apply_param(param_names[0], X.T).T 121 | if D is not None and u is not None: 122 | if hasattr(self, "UPrepModel") and self.UPrepModel is not None: 123 | u = self.UPrepModel.apply( 124 | u, time_first=True 125 | ) # Apply any mean removal/zscoring 126 | Y += self.apply_param(param_names[1], u.T).T 127 | 128 | if errSys is not None: 129 | err = errSys.generateRealization(N=N, return_z=True)[2] 130 | if err is not None: 131 | Y = Y + err if Y is not None else Y 132 | 133 | if prep_model_param is not None and hasattr(self, prep_model_param): 134 | prep_model_param_obj = getattr(self, prep_model_param) 135 | if prep_model_param_obj is not None: 136 | Y = prep_model_param_obj.apply_inverse( 137 | Y 138 | ) # Apply inverse of any mean-removal/zscoring 139 | 140 | if mapping_param is not None and hasattr(self, mapping_param): 141 | mapping_param_obj = getattr(self, mapping_param) 142 | if mapping_param_obj is not None and hasattr(mapping_param_obj, "map"): 143 | Y = mapping_param_obj.map(Y) 144 | return Y 145 | 146 | def generateRealizationWithQRS( 147 | self, 148 | N, 149 | x0=None, 150 | w0=None, 151 | u0=None, 152 | u=None, 153 | wv=None, 154 | return_z=False, 155 | return_z_err=False, 156 | return_wv=False, 157 | blowup_threshold=None, 158 | reset_x_on_blowup=None, 159 | randomize_x_on_blowup=None, 160 | ): 161 | if blowup_threshold is None: 162 | if hasattr(self, "blowup_threshold"): 163 | blowup_threshold = self.blowup_threshold 164 | else: 165 | blowup_threshold = np.inf 166 | if reset_x_on_blowup is None: 167 | if hasattr(self, "reset_x_on_blowup"): 168 | reset_x_on_blowup = self.reset_x_on_blowup 169 | else: 170 | reset_x_on_blowup = False 171 | if randomize_x_on_blowup is None: 172 | if hasattr(self, "randomize_x_on_blowup"): 173 | randomize_x_on_blowup = self.randomize_x_on_blowup 174 | else: 175 | randomize_x_on_blowup = False 176 | QRS = np.block([[self.Q, self.S], [self.S.T, self.R]]) 177 | wv, self.QRSShaping = genRandomGaussianNoise(N, QRS) 178 | w = wv[:, : self.state_dim] 179 | v = wv[:, self.state_dim :] 180 | if x0 is None: 181 | if hasattr(self, "x0"): 182 | x0 = self.x0 183 | else: 184 | x0 = np.zeros((self.state_dim, 1)) 185 | if len(x0.shape) == 1: 186 | x0 = x0[:, np.newaxis] 187 | if w0 is None: 188 | w0 = np.zeros((self.state_dim, 1)) 189 | if self.input_dim > 0 and u0 is None: 190 | u0 = np.zeros((self.input_dim, 1)) 191 | X = np.empty((N, self.state_dim)) 192 | Y = np.empty((N, self.output_dim)) 193 | for i in tqdm(range(N), "Generating realization"): 194 | if i == 0: 195 | Xt_1 = x0 196 | Wt_1 = w0 197 | if self.input_dim > 0 and u is not None: 198 | Ut_1 = u0 199 | else: 200 | Xt_1 = X[(i - 1) : i, :].T 201 | Wt_1 = w[(i - 1) : i, :].T 202 | if self.input_dim > 0 and u is not None: 203 | Ut_1 = u[(i - 1) : i, :].T 204 | X[i, :] = (self.apply_param("A", Xt_1) + Wt_1).T 205 | if u is not None: 206 | X[i, :] += np.squeeze(self.apply_param("B", Ut_1).T) 207 | # Check if X[i, :] has blown up 208 | if ( 209 | np.any(np.isnan(X[i, :])) 210 | or np.any(np.isinf(X[i, :])) 211 | or np.any(np.abs(X[i, :]) > blowup_threshold) 212 | ): 213 | msg = f"Xp blew up at sample {i} (mean Xp={np.mean(X[i, :]):.3g})" 214 | if reset_x_on_blowup: 215 | X[i, :] = x0 216 | msg += f", so it was reset to initial x0 (mean x0={np.mean(X[i, :]):.3g})" 217 | if randomize_x_on_blowup: 218 | X[i, :] = np.atleast_2d( 219 | np.random.multivariate_normal( 220 | 
mean=np.zeros(self.state_dim), cov=self.XCov 221 | ) 222 | ).T 223 | msg += f", so it was reset to a random Gaussian x0 with XCov (mean x0={np.mean(X[i, :]):.3g})" 224 | logger.warning(msg) 225 | Y = v 226 | CxDu = self.generateObservationFromStates( 227 | X, u=u, param_names=["C", "D"], prep_model_param="YPrepModel" 228 | ) 229 | if CxDu is not None: 230 | Y += CxDu 231 | out = Y, X 232 | if return_z: 233 | Z, ZErr = self.generateZRealizationFromStates(X=X, U=u, return_err=True) 234 | out += (Z,) 235 | if return_z_err: 236 | out += (ZErr,) 237 | if return_wv: 238 | out += (wv,) 239 | return out 240 | 241 | def generateRealizationWithKF( 242 | self, 243 | N, 244 | x0=None, 245 | u0=None, 246 | u=None, 247 | e=None, 248 | return_z=False, 249 | return_z_err=False, 250 | return_e=False, 251 | blowup_threshold=None, 252 | reset_x_on_blowup=None, 253 | randomize_x_on_blowup=None, 254 | ): 255 | if blowup_threshold is None: 256 | if hasattr(self, "blowup_threshold"): 257 | blowup_threshold = self.blowup_threshold 258 | else: 259 | blowup_threshold = np.inf 260 | if reset_x_on_blowup is None: 261 | if hasattr(self, "reset_x_on_blowup"): 262 | reset_x_on_blowup = self.reset_x_on_blowup 263 | else: 264 | reset_x_on_blowup = False 265 | if randomize_x_on_blowup is None: 266 | if hasattr(self, "randomize_x_on_blowup"): 267 | randomize_x_on_blowup = self.randomize_x_on_blowup 268 | else: 269 | randomize_x_on_blowup = False 270 | if e is None: 271 | e, innovShaping = genRandomGaussianNoise(N, self.innovCov) 272 | if x0 is None: 273 | if hasattr(self, "x0"): 274 | x0 = self.x0 275 | else: 276 | x0 = np.zeros((self.state_dim, 1)) 277 | if len(x0.shape) == 1: 278 | x0 = x0[:, np.newaxis] 279 | if self.input_dim > 0 and u0 is None: 280 | u0 = np.zeros((self.input_dim, 1)) 281 | X = np.empty((N, self.state_dim)) 282 | Y = np.empty((N, self.output_dim)) 283 | Xp = x0 284 | tic = time.perf_counter() 285 | time_passed = 0 286 | for i in tqdm(range(N), "Generating realization"): 287 | ek = e[i, :][:, np.newaxis] 288 | yk = self.apply_param("C", Xp) + ek 289 | if u is not None: 290 | yk += self.apply_param("D", u[i, :][:, np.newaxis]) 291 | X[i, :] = np.squeeze(Xp) 292 | Y[i, :] = np.squeeze(yk) 293 | # Xp = self.apply_param('A', Xp) \ 294 | # - self.apply_param('K', self.apply_param('C', Xp)) \ 295 | # + self.apply_param('K', yk) 296 | Xp = self.apply_param("A_KC", Xp) + self.apply_param("K", yk) 297 | if u is not None: 298 | Ut = u[i, :][:, np.newaxis] 299 | Xp += self.apply_param("B_KD", Ut) 300 | # Check if Xp has blown up 301 | if ( 302 | np.any(np.isnan(Xp)) 303 | or np.any(np.isinf(Xp)) 304 | or np.any(np.abs(Xp) > blowup_threshold) 305 | ): 306 | msg = f"Xp blew up at sample {i} (mean Xp={np.mean(Xp):.3g})" 307 | if reset_x_on_blowup: 308 | Xp = x0 309 | msg += ( 310 | f", so it was reset to initial x0 (mean x0={np.mean(Xp):.3g})" 311 | ) 312 | if randomize_x_on_blowup: 313 | Xp = np.atleast_2d( 314 | np.random.multivariate_normal( 315 | mean=np.zeros(self.state_dim), cov=self.XCov 316 | ) 317 | ).T 318 | msg += f", so it was reset to a random Gaussian x0 with XCov (mean x0={np.mean(Xp):.3g})" 319 | logger.warning(msg) 320 | toc = time.perf_counter() 321 | print_secs = 60 322 | if ( 323 | (toc - tic) > print_secs 324 | and np.mod(toc - tic, print_secs) < 0.5 * print_secs 325 | and np.mod(time_passed, print_secs) >= 0.5 * print_secs 326 | ): 327 | logger.info( 328 | "{:.2f}% ({}/{} samples) generated after {:.3g} min(s) and {:.3g} second(s)".format( 329 | i / N * 100, i, N, (toc - tic) // 60, (toc - tic) 
% 60 330 | ) 331 | ) 332 | time_passed = toc - tic 333 | 334 | out = Y, X 335 | if return_z: 336 | Z, ZErr = self.generateZRealizationFromStates(X=X, U=u, return_err=True) 337 | out += (Z,) 338 | if return_z_err: 339 | out += (ZErr,) 340 | if return_e: 341 | out += (e,) 342 | return out 343 | 344 | def find_fixedpoints(self, Y=None, X=None, U=None, N=None): 345 | if X is None: 346 | N = 1000 347 | Y, X, Z = self.generateRealization(N, return_z=True, u=U) 348 | 349 | if N is None: 350 | inds = np.arange(X.shape[0]) 351 | else: 352 | inds = np.arange(np.min((N, X.shape[0]))) 353 | 354 | oDiff = ( 355 | X[inds[1:], :].T 356 | - self.apply_param("K", Y[inds[:-1], :].T) 357 | - X[inds[:-1], :].T 358 | ).T 359 | oDiffNorm = np.sum(oDiff**2, axis=1) 360 | rootInd = np.argsort(oDiffNorm) 361 | maxNorm = 1e-2 362 | rootInd = rootInd[oDiffNorm[rootInd] < maxNorm] 363 | rootVal = X[rootInd, :] 364 | 365 | if rootVal.size > 0: 366 | # Subsample to keep at most 10k examples otherwise clustering will be too slow 367 | if rootVal.size > 1000: 368 | rootVal = np.random.choice(rootVal.flatten(), 1000, replace=False)[ 369 | :, np.newaxis 370 | ] 371 | from sklearn import cluster 372 | 373 | clustering = cluster.MeanShift().fit(rootVal) 374 | fpVals = clustering.cluster_centers_ 375 | else: 376 | fpVals = [] 377 | return fpVals 378 | 379 | def kalman( 380 | self, 381 | Y, 382 | U=None, 383 | x0=None, 384 | P0=None, 385 | steady_state=True, 386 | blowup_threshold=None, 387 | clip_on_blowup=None, 388 | reset_x_on_blowup=None, 389 | randomize_x_on_blowup=None, 390 | ): 391 | if blowup_threshold is None: 392 | if hasattr(self, "blowup_threshold"): 393 | blowup_threshold = self.blowup_threshold 394 | else: 395 | blowup_threshold = np.inf 396 | if clip_on_blowup is None: 397 | if hasattr(self, "clip_on_blowup"): 398 | clip_on_blowup = self.clip_on_blowup 399 | else: 400 | clip_on_blowup = False 401 | if reset_x_on_blowup is None: 402 | if hasattr(self, "reset_x_on_blowup"): 403 | reset_x_on_blowup = self.reset_x_on_blowup 404 | else: 405 | reset_x_on_blowup = False 406 | if randomize_x_on_blowup is None: 407 | if hasattr(self, "randomize_x_on_blowup"): 408 | randomize_x_on_blowup = self.randomize_x_on_blowup 409 | else: 410 | randomize_x_on_blowup = False 411 | if self.state_dim == 0: 412 | allXp = np.zeros((Y.shape[0], self.state_dim)) 413 | allX = allXp 414 | allYp = np.zeros((Y.shape[0], self.output_dim)) 415 | return allXp, allYp, allX 416 | if not steady_state: 417 | raise (Exception("Not supported!")) 418 | N = Y.shape[0] 419 | allXp = np.empty((N, self.state_dim)) # X(i|i-1) 420 | # allX = np.empty((N, self.state_dim)) 421 | allX = None 422 | if x0 is None: 423 | if hasattr(self, "x0"): 424 | x0 = self.x0 425 | else: 426 | x0 = np.zeros((self.state_dim, 1)) 427 | if len(x0.shape) == 1: 428 | x0 = x0[:, np.newaxis] 429 | if P0 is None: 430 | if hasattr(self, "P0"): 431 | P0 = self.P0 432 | else: 433 | P0 = np.eye(self.state_dim) 434 | Xp = x0 435 | Pp = P0 436 | for i in tqdm(range(N), "Estimating latent states"): 437 | allXp[i, :] = np.transpose(Xp) # X(i|i-1) 438 | thisY = Y[i, :][np.newaxis, :] 439 | if hasattr(self, "YPrepModel") and self.YPrepModel is not None: 440 | thisY = self.YPrepModel.apply( 441 | thisY, time_first=True 442 | ) # Apply any mean removal/zscoring 443 | 444 | if U is not None: 445 | ui = U[i, :][:, np.newaxis] 446 | if hasattr(self, "UPrepModel") and self.UPrepModel is not None: 447 | ui = self.UPrepModel.apply( 448 | ui, time_first=False 449 | ) # Apply any mean removal/zscoring 450 
| 451 | if self.missing_marker is not None and np.any( 452 | Y[i, :] == self.missing_marker 453 | ): 454 | newXp = self.apply_param("A", Xp) 455 | if U is not None and self.B.size > 0: 456 | newXp += self.apply_param("B", ui) 457 | else: 458 | newXp = self.apply_param("A_KC", Xp) + self.apply_param("K", thisY.T) 459 | if U is not None: 460 | newXp += self.apply_param("B_KD", ui) 461 | # Check if Xp has blown up 462 | if ( 463 | np.any(np.isnan(newXp)) 464 | or np.any(np.isinf(newXp)) 465 | or np.any(np.abs(newXp) > blowup_threshold) 466 | ): 467 | msg = f"Xp blew up at sample {i} (mean Xp={np.mean(newXp):.3g})" 468 | if clip_on_blowup: 469 | msg += f", so it was clipped to its previous value (mean x0={np.mean(Xp):.3g})" 470 | newXp = Xp 471 | if reset_x_on_blowup: 472 | newXp = x0 473 | msg += f", so it was reset to initial x0 (mean x0={np.mean(newXp):.3g})" 474 | if randomize_x_on_blowup: 475 | newXp = np.atleast_2d( 476 | np.random.multivariate_normal( 477 | mean=np.zeros(self.state_dim), cov=self.XCov 478 | ) 479 | ).T 480 | msg += f", so it was reset to a random Gaussian x0 with XCov (mean x0={np.mean(newXp):.3g})" 481 | logger.warning(msg) 482 | Xp = newXp 483 | 484 | allYp = self.generateObservationFromStates( 485 | allXp, 486 | u=U, 487 | param_names=["C", "D"], 488 | prep_model_param="YPrepModel", 489 | mapping_param="cMapY", 490 | ) 491 | return allXp, allYp, allX 492 | 493 | def propagateStates(self, allXp, step_ahead=1): 494 | for step in range(step_ahead - 1): 495 | if ( 496 | hasattr(self, "multi_step_with_A_KC") and self.multi_step_with_A_KC 497 | ): # If true, forward predictions will be done with A-KC rather than the correct A (but will be useful for comparing with predictor form models) 498 | # allXp = self.apply_param('A', allXp.T).T - self.apply_param('K', self.apply_param('C', allXp.T)).T 499 | allXp = self.apply_param("A_KC", allXp.T).T 500 | else: 501 | allXp = self.apply_param("A", allXp.T).T 502 | return allXp 503 | -------------------------------------------------------------------------------- /source/DPAD/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShanechiLab/DPAD/a09e7d9c3e59f1adb2d75336705dc166c15f4039/source/DPAD/tools/__init__.py -------------------------------------------------------------------------------- /source/DPAD/tools/abstract_classes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 University of Southern California 3 | See full notice in LICENSE.md 4 | Omid G. Sani and Maryam M. Shanechi 5 | Shanechi Lab, University of Southern California 6 | """ 7 | 8 | """Abstract classes used to standardize predictor models""" 9 | 10 | from abc import ABC, abstractmethod 11 | 12 | 13 | class PredictorModel(ABC): 14 | @abstractmethod 15 | def predict(self, Y, U=None): 16 | pass 17 | -------------------------------------------------------------------------------- /source/DPAD/tools/file_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 University of Southern California 3 | See full notice in LICENSE.md 4 | Omid G. Sani and Maryam M. 
Shanechi 5 | Shanechi Lab, University of Southern California 6 | """ 7 | 8 | import bz2 9 | import gzip 10 | import os 11 | import pickle 12 | from collections import deque 13 | from itertools import chain 14 | from pathlib import Path 15 | from sys import stderr 16 | 17 | try: 18 | from reprlib import repr 19 | except ImportError: 20 | pass 21 | 22 | import numpy as np 23 | 24 | 25 | def pickle_load(filePath): 26 | """Loads a pickle file 27 | 28 | Args: 29 | filePath (string): file path 30 | 31 | Returns: 32 | data (Any): pickle file content 33 | """ 34 | with open(filePath, "rb") as f: 35 | return pickle.load(f) 36 | 37 | 38 | def pickle_save(filePath, data): 39 | """Saves a pickle file 40 | 41 | Args: 42 | filePath (string): file path 43 | data (Any): data to save in file 44 | """ 45 | with open(filePath, "wb") as f: 46 | pickle.dump(data, f) 47 | 48 | 49 | def pickle_load_compressed(filePath, format="bz2", auto_add_ext=False): 50 | """Loads data from a compressed pickle file 51 | 52 | Args: 53 | filePath (string): file path 54 | format (str, optional): the compression format, can be in ['bz2', 'gz']. Defaults to 'bz2'. 55 | auto_add_ext (bool, optional): if true, will automatically add the 56 | extension for the compression format to the file path. Defaults to False. 57 | 58 | Returns: 59 | data (Any): pickle file content 60 | """ 61 | if format == "bz2": 62 | if auto_add_ext: 63 | filePath += ".bz2" 64 | data = bz2.BZ2File(filePath, "rb") 65 | elif format == "gzip": 66 | if auto_add_ext: 67 | filePath += ".gz" 68 | data = gzip.open(filePath, "rb") 69 | else: 70 | raise (Exception("Unsupported format: {}".format(format))) 71 | return pickle.load(data) 72 | 73 | 74 | def pickle_save_compressed(filePath, data, format="bz2", auto_add_ext=False): 75 | """Saves data as a compressed pickle file 76 | 77 | Args: 78 | filePath (string): file path 79 | data (Any): data to save in file 80 | format (str, optional): the compression format, can be in ['bz2', 'gz']. Defaults to 'bz2'. 81 | auto_add_ext (bool, optional): if true, will automatically add the 82 | extension for the compression format to the file path. Defaults to False. 83 | """ 84 | if format == "bz2": 85 | if auto_add_ext: 86 | filePath += ".bz2" 87 | with bz2.BZ2File(filePath, "w") as f: 88 | pickle.dump(data, f) 89 | elif format == "gzip": 90 | if auto_add_ext: 91 | filePath += ".gz" 92 | with gzip.open(filePath, "w") as f: 93 | pickle.dump(data, f) 94 | else: 95 | raise (Exception("Unsupported format: {}".format(format))) 96 | 97 | 98 | def mk_parent_dir(file_path): 99 | return Path(os.path.dirname(file_path)).mkdir(parents=True, exist_ok=True) 100 | 101 | 102 | def bytes_to_string(num, suffix="B"): 103 | """Gets size in bytes and returns a human readable string 104 | 105 | Args: 106 | num (number): input size in bytes 107 | suffix (str, optional): Suffix to add to the final string. Defaults to 'B'. 
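        Example (illustrative doctest of the returned format; output follows
        from the format string in the function body):
            >>> bytes_to_string(2048)
            '2.0KiB'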
108 | 
109 |     Returns:
110 |         str: human-readable string of the input size
111 |     """
112 |     for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
113 |         if abs(num) < 1024.0:
114 |             return "%3.1f%s%s" % (num, unit, suffix)
115 |         num /= 1024.0
116 |     return "%.1f%s%s" % (num, "Yi", suffix)
117 | 
118 | 
119 | def file_exists(filePath):
120 |     """Returns true if file exists
121 | 
122 |     Args:
123 |         filePath (string): file path
124 | 
125 |     Returns:
126 |         output (bool): True if file exists
127 |     """
128 |     return os.path.exists(filePath)
129 | 
130 | 
131 | def get_file_size(filePath):
132 |     """Returns the file size in bytes
133 | 
134 |     Args:
135 |         filePath (string): file path
136 | 
137 |     Returns:
138 |         size (number): size of file in bytes
139 |     """
140 |     return os.path.getsize(filePath)
141 | 
--------------------------------------------------------------------------------
/source/DPAD/tools/model_base_classes.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024 University of Southern California
3 | See full notice in LICENSE.md
4 | Omid G. Sani and Maryam M. Shanechi
5 | Shanechi Lab, University of Southern California
6 | """
7 | 
8 | """Some base classes that RNNModel and RegressionModel inherit from"""
9 | 
10 | import copy
11 | import io
12 | import logging
13 | import os
14 | import time
15 | import warnings
16 | from datetime import datetime
17 | 
18 | import matplotlib.pyplot as plt
19 | import numpy as np
20 | import tensorflow as tf
21 | 
22 | from .plot import plotPredictionScatter, plotTimeSeriesPrediction
23 | from .tf_tools import (
24 |     convertHistoryToDict,
25 |     getModelFitHistoyStr,
26 |     set_global_tf_eagerly_flag,
27 | )
28 | 
29 | logger = logging.getLogger(__name__)
30 | 
31 | 
32 | class ReconstructionInfo:
33 |     """A class that can store information required to reconstruct a RegressionModel or RNNModel based on
34 |     their constructor arguments and tf weights (rather than tf objects), which can easily be pickled
35 |     """
36 | 
37 |     def __init__(self, weights, constructor_kwargs):
38 |         self.weights = weights
39 |         self.constructor_kwargs = constructor_kwargs
40 | 
41 | 
42 | class Reconstructable:
43 |     """A class that allows a child class with tf models to be saved into pickle files and later be
44 |     reconstructed
45 |     """
46 | 
47 |     def get_recreation_info(self):
48 |         constructor_kwargs = self.constructor_kwargs
49 |         # for k in self.constructor_kwargs.keys():
50 |         #     constructor_kwargs[k] = getattr(self, k, self.constructor_kwargs[k])
51 |         return ReconstructionInfo(
52 |             weights=self.model.get_weights(), constructor_kwargs=constructor_kwargs
53 |         )
54 | 
55 |     def reconstruct(self, reconstruction_info):
56 |         cls = type(self)  # Get the child class: either RNNModel or RegressionModel
57 |         newInstance = cls(**reconstruction_info.constructor_kwargs)
58 |         newInstance.model.set_weights(reconstruction_info.weights)
59 |         return newInstance
60 | 
61 |     def save_to_file(self, file_path):
62 |         """Calls the Keras save method to save the model to a file.
63 | 
64 |         Args:
65 |             file_path (str): path of the file to save the Keras model to
66 | 
67 |         Returns:
68 |             None: tf.keras Model.save saves in place and returns nothing
69 |         """
70 |         return self.model.save(file_path)
71 | 
72 |     def load_from_file(self, file_path):
73 |         """Loads a Keras model from a file via tf.keras.models.load_model and returns it."""
74 |         return tf.keras.models.load_model(file_path)
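# Usage sketch (illustrative only; `fitted_model` and the file name are made-up
# placeholders): `Reconstructable` lets a fitted RNNModel/RegressionModel
# round-trip through pickle without serializing tf objects directly.
#
#     from DPAD.tools.file_tools import pickle_save, pickle_load
#     info = fitted_model.get_recreation_info()  # weights + constructor kwargs
#     pickle_save("model_info.p", info)  # ReconstructionInfo pickles cleanly
#     restored = fitted_model.reconstruct(pickle_load("model_info.p"))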
75 | 76 | 77 | # Inherited from https://github.com/keras-team/keras/blob/v2.8.0/keras/callbacks.py#L1744-L1891 78 | # Changing one line to allow a min_epoch to be set so that early stopping kicks in only after that many trials 79 | # start_from_epoch = 0 will reduce to the original functionality 80 | # Standard in newer tf versions: 81 | # https://github.com/keras-team/keras/commit/05d90d2a6931b5a583579cd2ef2e6932919afa63 82 | class EarlyStoppingWithMinEpochs(tf.keras.callbacks.EarlyStopping): 83 | """Modified EarlyStopping class to allow a minimum number of epochs to be specified before early stopping kicks in.""" 84 | 85 | def __init__(self, start_from_epoch=0, **kwargs): 86 | super().__init__(**kwargs) 87 | self.start_from_epoch = start_from_epoch 88 | 89 | def on_epoch_end(self, epoch, logs=None): 90 | current = self.get_monitor_value(logs) 91 | if current is None or epoch <= self.start_from_epoch: 92 | # If no monitor value exists or still in initial warm-up stage. 93 | return 94 | if self.restore_best_weights and self.best_weights is None: 95 | # Restore the weights after first epoch if no progress is ever made. 96 | self.best_weights = self.model.get_weights() 97 | 98 | self.wait += 1 99 | if self._is_improvement(current, self.best): 100 | self.best = current 101 | self.best_epoch = epoch 102 | if self.restore_best_weights: 103 | self.best_weights = self.model.get_weights() 104 | # Only restart wait if we beat both the baseline and our previous best. 105 | if self.baseline is None or self._is_improvement(current, self.baseline): 106 | self.wait = 0 107 | 108 | # Only check after the first epoch. 109 | if self.wait >= self.patience: 110 | self.stopped_epoch = epoch 111 | self.model.stop_training = True 112 | if self.restore_best_weights and self.best_weights is not None: 113 | # if self.verbose > 0: 114 | logger.info( 115 | "Restoring model weights from the end of the best epoch: " 116 | f"{self.best_epoch + 1} (stopped at {self.stopped_epoch} epochs)." 117 | ) 118 | self.model.set_weights(self.best_weights) 119 | 120 | 121 | # https://www.tensorflow.org/tensorboard/image_summaries 122 | def plot_to_image(figure): 123 | """Converts the matplotlib plot specified by 'figure' to a PNG image and 124 | returns it. The supplied figure is closed and inaccessible after this call.""" 125 | # Save the plot to a PNG in memory. 126 | buf = io.BytesIO() 127 | plt.savefig(buf, format="png") 128 | # Closing the figure prevents it from being displayed directly inside the notebook. 129 | buf.seek(0) 130 | # Convert PNG buffer to TF image 131 | image = tf.image.decode_png(buf.getvalue(), channels=4) 132 | # Add the batch dimension 133 | image = tf.expand_dims(image, 0) 134 | return image 135 | 136 | 137 | class ModelWithFitWithRetry: 138 | """A class that adds a fit_with_retry method to classes inheriting from it. 139 | Used by RNNModel and RegressionModel. 
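    Example (sketch; assumes `model` is an already-built RNNModel or
    RegressionModel whose `model.model` is a compiled Keras model, and that
    the x/y arrays are in the shapes its `model.fit` expects):
        history = model.fit_with_retry(
            init_attempts=2,
            early_stopping_patience=5,
            early_stopping_measure="val_loss",
            x=x_train, y=y_train,
            validation_data=(x_val, y_val),
            epochs=100,
        )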
140 |     """
141 | 
142 |     def fit_with_retry(
143 |         self,
144 |         init_attempts=1,
145 |         early_stopping_patience=3,
146 |         early_stopping_measure="loss",
147 |         early_stopping_restore_best_weights=True,
148 |         start_from_epoch=0,
149 |         tb_make_prediction_plots=False,
150 |         tb_make_prediction_scatters=False,
151 |         tb_plot_epoch_mod=20,
152 |         x=None,
153 |         y=None,
154 |         callbacks=None,
155 |         validation_data=None,
156 |         keep_latest_nonnan_weights=True,
157 |         **kwargs,
158 |     ):
159 |         """Calls keras fit for the model, with the option to redo the fitting multiple
160 |         times with different initializations
161 | 
162 |         Args:
163 |             self (RegressionModel or RNNModel): the object to fit.
164 |                 Must be fully ready to call obj.model.fit
165 |             init_attempts (int, optional): The number of refitting attempts. Defaults to 1.
166 |                 If more than 1, the attempt with the smallest 'loss' will be selected in the end.
167 |             early_stopping_patience (int, optional): number of epochs with no improvement in early_stopping_measure after which training will be stopped. Defaults to 3.
168 |             # The rest of the arguments will be passed to keras model.fit and should include:
169 |             x: input
170 |             y: output
171 |             See the tf keras help for more:
172 |             https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
173 |         Returns:
174 |             history (tf.keras.callbacks.History): the output of keras model fit
175 |         """
176 |         if callbacks is None:
177 |             callbacks = []
178 | 
179 |         def compute_reg_loss(epoch, logs=None):
180 |             if logs is None:
181 |                 return
182 |             logs["learning_rate"] = float(
183 |                 self.model.optimizer.learning_rate
184 |             )  # Save current learning rate
185 |             if len(self.model.losses):  # If we have some regularization loss
186 |                 regularization_loss = float(tf.math.add_n(self.model.losses))
187 |                 logs["regularization_loss"] = regularization_loss
188 |                 total_loss = logs["loss"]
189 |                 logs["loss_minus_regularization"] = total_loss - regularization_loss
190 |                 if "val_loss" in logs:
191 |                     logs["val_loss_minus_regularization"] = (
192 |                         logs["val_loss"] - regularization_loss
193 |                     )
194 | 
195 |         latest_nonnan_weights, latest_nonnan_weights_epoch, latest_epoch = (
196 |             None,
197 |             None,
198 |             None,
199 |         )
200 | 
201 |         def keep_latest_nonnan_weights_callback(epoch, logs=None):
202 |             nonlocal latest_nonnan_weights, latest_nonnan_weights_epoch, latest_epoch
203 |             weights = self.model.get_weights()
204 |             nan_weights = [v for v in weights if np.any(np.isnan(v))]
205 |             if len(nan_weights) == 0:
206 |                 latest_nonnan_weights = copy.copy(weights)
207 |                 latest_nonnan_weights_epoch = epoch
208 |             else:
209 |                 logger.info(
210 |                     f"epoch {epoch} has {len(nan_weights)} weight arrays with NaN (blown-up) values!"
211 | ) 212 | latest_epoch = epoch 213 | 214 | keep_latest_nonnan_weights_callback( 215 | 0 216 | ) # To save initial weights, which should definitely be nonnan 217 | 218 | from ..DPADModel import shift_ms_to_1s_series 219 | from ..RegressionModel import RegressionModel 220 | 221 | attempt = 0 222 | modelsWeightsAll, historyAll, log_subdirs = [], [], [] 223 | while attempt < init_attempts: 224 | attempt += 1 225 | if init_attempts > 1: 226 | logger.info( 227 | "Starting fit attempt {} of {}".format(attempt, init_attempts) 228 | ) 229 | callbacks_this = copy.deepcopy(callbacks) 230 | 231 | # Early stopping: 232 | early_stopping_callback = EarlyStoppingWithMinEpochs( 233 | monitor=early_stopping_measure, 234 | patience=early_stopping_patience, 235 | restore_best_weights=early_stopping_restore_best_weights, 236 | start_from_epoch=start_from_epoch, 237 | ) 238 | callbacks_this.append(early_stopping_callback) 239 | callbacks_this.append( 240 | tf.keras.callbacks.LambdaCallback(on_epoch_end=compute_reg_loss) 241 | ) 242 | 243 | if keep_latest_nonnan_weights: 244 | callbacks_this.append( 245 | tf.keras.callbacks.LambdaCallback( 246 | on_epoch_end=keep_latest_nonnan_weights_callback 247 | ) 248 | ) 249 | 250 | if self.log_dir != "": 251 | log_subdir = datetime.now().strftime("%Y%m%d-%H%M%S") 252 | log_subdirs.append(log_subdir) 253 | log_dir = os.path.join(self.log_dir, log_subdir) 254 | logger.info("Tensorboard log_dir: {}".format(log_dir)) 255 | callbacks_this.append( 256 | tf.keras.callbacks.TensorBoard( 257 | log_dir=log_dir, histogram_freq=1, profile_batch="10,20" 258 | ) 259 | ) 260 | # # Save regularization loss 261 | # file_writer_metrics = tf.summary.create_file_writer(log_dir + '/metrics') 262 | # def tensorboard_save_reg_loss(epoch, logs): 263 | # total_loss = logs['loss'] 264 | # if len(self.model.losses): 265 | # regularization_loss = tf.math.add_n(self.model.losses) 266 | # else: 267 | # regularization_loss = 0 268 | # with file_writer_metrics.as_default(): 269 | # tf.summary.scalar('regularization loss', data=regularization_loss, step=epoch) 270 | # tf.summary.scalar('fit loss', data=total_loss-regularization_loss, step=epoch) 271 | # callbacks_this.append(tf.keras.callbacks.LambdaCallback( 272 | # on_epoch_end=tensorboard_save_reg_loss 273 | # )) 274 | # Save some data plots if requested 275 | if tb_make_prediction_plots or tb_make_prediction_scatters: 276 | file_writer_plot = tf.summary.create_file_writer(log_dir + "/plots") 277 | if isinstance(self, RegressionModel): # For RegressionModel 278 | y_in = y 279 | yAll = y_in 280 | else: # For RNNModel 281 | y_in = y[0] 282 | yAll = np.reshape( 283 | y_in, (int(y_in.size / y_in.shape[-1]), y_in.shape[-1]), "C" 284 | ) 285 | if validation_data is not None: 286 | x_val, y_val = validation_data[0], validation_data[1] 287 | if isinstance(self, RegressionModel): # For RegressionModel 288 | yAll_val = y_val 289 | else: # For RNNModel 290 | yAll_val = np.reshape( 291 | y_val[0], 292 | (int(y_val[0].size / y_in.shape[-1]), y_in.shape[-1]), 293 | "C", 294 | ) 295 | ny = y_in.shape[-1] 296 | nyIndsToPlot = np.array( 297 | np.unique(np.round(np.linspace(0, ny - 1, 2))), "int" 298 | ) 299 | fig1, fig2, fig3, fig4 = None, None, None, None 300 | 301 | def tensorboard_plot_signals(epoch, logs): 302 | nonlocal fig1, fig2, fig3, fig4 303 | if ( 304 | epoch != 0 305 | and epoch % tb_plot_epoch_mod < tb_plot_epoch_mod - 1 306 | ): 307 | return 308 | steps_ahead = ( 309 | [1] 310 | if not hasattr(self, "steps_ahead") 311 | or self.steps_ahead is None 312 
| else self.steps_ahead 313 | ) 314 | # Use the model to predict the values from the validation dataset. 315 | if isinstance(self, RegressionModel): # For RegressionModel 316 | yHat = [self.model.predict(x)] 317 | predLegStrs = ["Pred"] 318 | batchCntStr = "" 319 | else: # For RNNModel 320 | yHat = self.predict_with_keras(x)[: len(steps_ahead)] 321 | yHat = [ 322 | np.reshape(yHatThis, yAll.shape) 323 | for yHatThis in list(yHat) 324 | ] 325 | yHat = list( 326 | shift_ms_to_1s_series( 327 | yHat, 328 | steps_ahead, 329 | missing_marker=np.nan, 330 | time_first=True, 331 | ) 332 | ) 333 | predLegStrs = [ 334 | f"Pred {step_ahead}-step" for step_ahead in steps_ahead 335 | ] 336 | batch_count = int(x[0].shape[0] / self.batch_size) 337 | batchCntStr = ", {} batches".format(batch_count) 338 | titleHead = "Epoch:{}{} (training data)\n".format( 339 | epoch, batchCntStr 340 | ) 341 | if tb_make_prediction_plots: 342 | if fig1 is not None: 343 | fig1.clf() 344 | plotArgs = { 345 | "missing_marker": self.missing_marker, 346 | "addNaNInTimeGaps": False, 347 | "plotDims": nyIndsToPlot, 348 | "predLegStrs": predLegStrs, 349 | "y_pred_is_list": True, 350 | "lineStyles": ["-", "-", "--", "-.", ":"], 351 | "figsize": (11, 6), 352 | "predPerfsToAdd": ["R2", "CC", "MSE"], 353 | "return_fig": True, 354 | } 355 | fig1 = plotTimeSeriesPrediction( 356 | yAll, yHat, titleHead=titleHead, fig=fig1, **plotArgs 357 | ) 358 | plot_image = plot_to_image(fig1) 359 | with file_writer_plot.as_default(): 360 | tf.summary.image( 361 | "Training prediction", plot_image, step=epoch 362 | ) 363 | if tb_make_prediction_scatters: 364 | if fig2 is not None: 365 | fig2.clf() 366 | scatterArgs = { 367 | "missing_marker": self.missing_marker, 368 | "plot45DegLine": True, 369 | "plotLSLine": True, 370 | "styles": {"size": 10, "marker": "x"}, 371 | "figsize": (11, 4), 372 | "title": ["Dim{} ".format(di) for di in nyIndsToPlot], 373 | "legNames": [ 374 | f"{step_ahead}-step" for step_ahead in steps_ahead 375 | ], 376 | "addPerfMeasuresToLegend": ["CC", "R2"], 377 | "addPerfMeasuresToTitle": ["CC", "R2", "MSE"], 378 | "return_fig": True, 379 | } 380 | fig2 = plotPredictionScatter( 381 | [yAll[..., di] for di in nyIndsToPlot], 382 | [ 383 | np.array([yHatStep[..., di] for yHatStep in yHat]) 384 | for di in nyIndsToPlot 385 | ], 386 | titleHead=[titleHead] + [""] * len(nyIndsToPlot), 387 | fig=fig2, 388 | **scatterArgs, 389 | ) 390 | plot_image = plot_to_image(fig2) 391 | with file_writer_plot.as_default(): 392 | tf.summary.image( 393 | "Training prediction (scatter)", 394 | plot_image, 395 | step=epoch, 396 | ) 397 | # The same for validation data 398 | if validation_data is not None: 399 | if isinstance(self, RegressionModel): # For RegressionModel 400 | yHat_val = [self.model.predict(x_val)] 401 | batchCntStr = "" 402 | else: # For RNNModel 403 | yHat_val = self.predict_with_keras(x_val)[ 404 | : len(steps_ahead) 405 | ] 406 | yHat_val = [ 407 | np.reshape(yHatThis, yAll_val.shape) 408 | for yHatThis in list(yHat_val) 409 | ] 410 | yHat_val = list( 411 | shift_ms_to_1s_series( 412 | yHat_val, 413 | steps_ahead, 414 | missing_marker=np.nan, 415 | time_first=True, 416 | ) 417 | ) 418 | batchCntStr = ", {} batches".format( 419 | int(x_val[0].shape[0] / self.batch_size) 420 | ) 421 | titleHead = "Epoch:{}{} (validation)\n".format( 422 | epoch, batchCntStr 423 | ) 424 | if tb_make_prediction_plots: 425 | if fig3 is not None: 426 | fig3.clf() 427 | fig3 = plotTimeSeriesPrediction( 428 | yAll_val, 429 | yHat_val, 430 | titleHead=titleHead, 
431 | fig=fig3, 432 | **plotArgs, 433 | ) 434 | plot_image = plot_to_image(fig3) 435 | with file_writer_plot.as_default(): 436 | tf.summary.image( 437 | "Validation prediction", plot_image, step=epoch 438 | ) 439 | if tb_make_prediction_scatters: 440 | if fig4 is not None: 441 | fig4.clf() 442 | fig4 = plotPredictionScatter( 443 | [yAll_val[..., di] for di in nyIndsToPlot], 444 | [ 445 | np.array( 446 | [yHatStep[..., di] for yHatStep in yHat_val] 447 | ) 448 | for di in nyIndsToPlot 449 | ], 450 | titleHead=[titleHead] + [""] * len(nyIndsToPlot), 451 | fig=fig4, 452 | **scatterArgs, 453 | ) 454 | plot_image = plot_to_image(fig4) 455 | with file_writer_plot.as_default(): 456 | tf.summary.image( 457 | "Validation prediction (scatter)", 458 | plot_image, 459 | step=epoch, 460 | ) 461 | 462 | callbacks_this.append( 463 | tf.keras.callbacks.LambdaCallback( 464 | on_epoch_end=tensorboard_plot_signals 465 | ) 466 | ) 467 | eagerly_flag_backup = set_global_tf_eagerly_flag(False) 468 | if eagerly_flag_backup: 469 | logger.warning( 470 | "Tensorflow was set up globally to run eagerly. This is EXTREMELY slow so we have temporarily disabled it and will reenable it after model fitting. Consider fixing this global setting by running tf.config.run_functions_eagerly(False)." 471 | ) 472 | if self.model.run_eagerly or tf.config.functions_run_eagerly(): 473 | warnings.warn( 474 | "This Tensorflow model is set up to run eagerly. This will be EXTREMELY slow!!! Please fix." 475 | ) 476 | if len(self.model.trainable_weights) == 0: 477 | logger.info(f"No trainable weights... skipping training.") 478 | tic = time.perf_counter() 479 | history = self.model.fit( 480 | x=x, 481 | y=y, 482 | callbacks=callbacks_this, 483 | validation_data=validation_data, 484 | **kwargs, 485 | ) 486 | toc = time.perf_counter() 487 | fitTime = toc - tic 488 | set_global_tf_eagerly_flag(eagerly_flag_backup) 489 | if hasattr(early_stopping_callback, "stopped_epoch"): 490 | history.params["stopped_epoch"] = early_stopping_callback.stopped_epoch 491 | if hasattr(early_stopping_callback, "best_epoch"): 492 | history.params["best_epoch"] = early_stopping_callback.best_epoch 493 | if early_stopping_restore_best_weights: 494 | picked_epoch = history.params["best_epoch"] 495 | else: 496 | picked_epoch = history.history["epoch"][-1] 497 | if "verbose" in kwargs and kwargs["verbose"] != 2: 498 | logFields = [k for k in history.history.keys()] 499 | logger.info( 500 | "\n" 501 | + getModelFitHistoyStr(history, fields=logFields, keep_ratio=0.1) 502 | ) 503 | if "regularization_loss" in history.history: 504 | total_loss = np.array(history.history["loss"]) 505 | reg_loss = np.array(history.history["regularization_loss"]) 506 | loss_range = np.quantile(total_loss, [0.01, 0.99]) 507 | reg_loss_range = np.quantile(reg_loss, [0.01, 0.99]) 508 | reg_to_total_change_ratio = ( 509 | np.diff(reg_loss_range)[0] / np.diff(loss_range)[0] 510 | ) 511 | median_reg_to_total_ratio = np.median(reg_loss / total_loss) 512 | logger.info( 513 | "{:.2g}% of the changes in total loss ({:.2g} => {:.2g}) are due to changes in regularization loss ({:.2g} => {:.2g})".format( 514 | reg_to_total_change_ratio * 100, 515 | loss_range[1], 516 | loss_range[0], 517 | reg_loss_range[1], 518 | reg_loss_range[0], 519 | ) 520 | ) 521 | logger.info( 522 | "Median ratio of reg_loss to total_loss is {:.2g}%".format( 523 | median_reg_to_total_ratio * 100 524 | ) 525 | ) 526 | if np.any((total_loss - reg_loss) < 0): 527 | logger.info("Loss has negative values") 528 | reg_to_loss_ratio = 
reg_to_total_change_ratio 529 | else: 530 | reg_to_loss_ratio = median_reg_to_total_ratio 531 | if reg_to_loss_ratio > 0.5: 532 | logger.info( 533 | "Regularization lambda is too high, regularization is dominating the total loss" 534 | ) 535 | elif reg_to_loss_ratio < 0.01: 536 | logger.info( 537 | "Regularization lambda is too low, regularization is an almost negligible part (<1%) of the total loss" 538 | ) 539 | logger.info("Model fitting took {:.2f}s".format(fitTime)) 540 | weights = self.model.get_weights() 541 | nan_weights = [v for v in weights if np.any(np.isnan(v))] 542 | if ( 543 | len(nan_weights) > 0 544 | and keep_latest_nonnan_weights 545 | and latest_nonnan_weights is not None 546 | ): 547 | logger.warning( 548 | f"{len(nan_weights)} weights had nans, replacing with weights from the latest epoch with non-nan weights (epoch {latest_nonnan_weights_epoch})" 549 | ) 550 | self.model.set_weights(latest_nonnan_weights) 551 | epoch_ind = [ 552 | ep for ep in history.epoch if ep <= latest_nonnan_weights_epoch 553 | ][-1] 554 | for key in history.history: 555 | history.history[key][-1] = history.history[key][epoch_ind] 556 | picked_epoch = history.epoch[epoch_ind] 557 | history.params["picked_epoch"] = picked_epoch 558 | if init_attempts > 1: 559 | weights = self.model.get_weights() 560 | modelsWeightsAll.append(weights) 561 | historyAll.append(history) 562 | self.build() 563 | # Reset model weights 564 | if attempt == init_attempts: # Select final model 565 | lossAll = [ 566 | np.array(h.history[early_stopping_measure])[ 567 | np.where(np.array(h.epoch) == h.params["picked_epoch"])[0] 568 | ][0] 569 | for h in historyAll 570 | ] 571 | if np.all(np.isnan(lossAll)): 572 | msg = "All fit attempts ended up with a nan loss (probably blew up)!" 573 | if not keep_latest_nonnan_weights: 574 | msg += " Consider setting keep_latest_nonnan_weights=True to keep latest epoch with non-nan loss in case of blow up." 575 | # raise(Exception(msg)) 576 | logger.warning(msg) 577 | if np.all(np.isnan(lossAll)): 578 | logger.warning( 579 | "All attempts resulted in NaN loss for all epochs!! Keeping initial random params from attempt 1. " 580 | ) 581 | bestInd = 0 582 | else: 583 | bestInd = np.nanargmin(lossAll) 584 | logger.info( 585 | "Selected model from learning attempt {}/{}, which had the smallest loss ({:.8g})".format( 586 | 1 + bestInd, init_attempts, lossAll[bestInd] 587 | ) 588 | ) 589 | self.model.set_weights(modelsWeightsAll[bestInd]) 590 | history = historyAll[bestInd] 591 | history.params["history_all"] = [ 592 | convertHistoryToDict(h) for h in historyAll 593 | ] 594 | history.params["selected_ind"] = bestInd 595 | if self.log_dir != "": 596 | self.log_subdir = log_subdirs[bestInd] 597 | if self.log_dir != "" and ( 598 | tb_make_prediction_plots or tb_make_prediction_scatters 599 | ): 600 | plt.close("all") 601 | del fig1, fig2, fig3, fig4 602 | return history 603 | -------------------------------------------------------------------------------- /source/DPAD/tools/parse_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 University of Southern California 3 | See full notice in LICENSE.md 4 | Omid G. Sani and Maryam M. 
Shanechi 5 | Shanechi Lab, University of Southern California 6 | """ 7 | 8 | """ Tools for parsing method description strings """ 9 | 10 | import copy 11 | import re 12 | 13 | import numpy as np 14 | 15 | 16 | def extractNumberFromRegex( 17 | saveCode, regex=r"([-+]?[\d]+\.?[\d]*)([Ee][-+]?[\d]+)?", prefix=None 18 | ): 19 | if prefix is not None: 20 | regex = re.compile(prefix + str(regex)) 21 | out = [] 22 | out_matches = [] 23 | if len(re.findall(regex, saveCode)): 24 | matches = re.finditer(regex, saveCode) 25 | for matchNum, match in enumerate(matches, start=1): 26 | num = float(match.groups()[0]) 27 | if match.groups()[1] is not None: 28 | num *= 10 ** float(match.groups()[1][1:]) 29 | out_matches.append(match) 30 | out.append(num) 31 | return out, out_matches 32 | 33 | 34 | def extractPowRangesFromRegex(regex, saveCode, base_type=int): 35 | out = [] 36 | out_matches = [] 37 | if len(re.findall(regex, saveCode)): 38 | matches = re.finditer(regex, saveCode) 39 | for matchNum, match in enumerate(matches, start=1): 40 | base, min_val, step_val, max_val = match.groups() 41 | out_matches.append(match) 42 | pows = np.arange(int(min_val), 1 + int(max_val), int(step_val)) 43 | steps = [base_type(base) ** p for p in pows] 44 | out.extend(steps) 45 | return out, out_matches 46 | 47 | 48 | def extractLinearRangesFromRegex(regex, saveCode): 49 | out = [] 50 | out_matches = [] 51 | if len(re.findall(regex, saveCode)): 52 | matches = re.finditer(regex, saveCode) 53 | for matchNum, match in enumerate(matches, start=1): 54 | min_val, step_val, max_val = match.groups() 55 | out_matches.append(match) 56 | 57 | out.extend(np.arange(float(min_val), 1 + float(max_val), float(step_val))) 58 | return out, out_matches 59 | 60 | 61 | def extractIntRangesFromRegex(regex, saveCode): 62 | out = [] 63 | out_matches = [] 64 | if len(re.findall(regex, saveCode)): 65 | matches = re.finditer(regex, saveCode) 66 | for matchNum, match in enumerate(matches, start=1): 67 | min_val, step_val, max_val = match.groups() 68 | out_matches.append(match) 69 | 70 | out.extend(np.arange(int(min_val), 1 + int(max_val), int(step_val))) 71 | return out, out_matches 72 | 73 | 74 | def extractStrsFromRegex(regex, saveCode): 75 | out = [] 76 | out_matches = [] 77 | if len(re.findall(regex, saveCode)): 78 | matches = re.finditer(regex, saveCode) 79 | for matchNum, match in enumerate(matches, start=1): 80 | this_val = match.groups()[0] 81 | out_matches.append(match) 82 | out.extend([this_val]) 83 | return out, out_matches 84 | 85 | 86 | def extractIntsFromRegex(regex, saveCode): 87 | out = [] 88 | out_matches = [] 89 | if len(re.findall(regex, saveCode)): 90 | matches = re.finditer(regex, saveCode) 91 | for matchNum, match in enumerate(matches, start=1): 92 | this_val = match.groups()[0] 93 | out_matches.append(match) 94 | out.extend([int(this_val)]) 95 | return out, out_matches 96 | 97 | 98 | def extractFloatsFromRegex(regex, saveCode): 99 | out = [] 100 | out_matches = [] 101 | if len(re.findall(regex, saveCode)): 102 | matches = re.finditer(regex, saveCode) 103 | for matchNum, match in enumerate(matches, start=1): 104 | this_val = match.groups()[0] 105 | out_matches.append(match) 106 | out.extend([float(this_val)]) 107 | return out, out_matches 108 | 109 | 110 | def parseMethodCodeArg_kpp(saveCode): 111 | if ( 112 | "kpp" in saveCode 113 | ): # Keeps this portion of the whole preprocessed data for all analyses 114 | regex = r"(tr)?kpp(\d+\.?\d*|\d*\.?\d+)_(\d+\.?\d*|\d*\.?\d+)" # _kpp0.5_1 or _trkpp0.5_1 115 | if len(re.findall(regex, 
saveCode)):
116 |             matches = re.finditer(regex, saveCode)
117 |             for matchNum, match in enumerate(matches, start=1):
118 |                 dataPeriod = (
119 |                     "train"
120 |                     if match.groups()[0] == "tr"
121 |                     else "all"
122 |                 )
123 |                 keptPortion = (float(match.groups()[-2]), float(match.groups()[-1]))
124 |         else:
125 |             regex = r"(tr)kpp(\d+\.?\d*|\d*\.?\d+)"  # _kpp0.5 or _trkpp0.5
126 |             matches = re.finditer(regex, saveCode)
127 |             for matchNum, match in enumerate(matches, start=1):
128 |                 dataPeriod = (
129 |                     "train"
130 |                     if match.groups()[0] == "tr"
131 |                     else "all"
132 |                 )
133 |                 keptPortion = (0, float(match.groups()[-1]))
134 |         return dataPeriod, keptPortion, match
135 |     else:
136 |         return None, None, None
137 | 
138 | 
139 | def parseMethodCodeArgStepsAhead(saveCode):
140 |     out_matches = []
141 |     steps_ahead, matches1 = extractIntRangesFromRegex(
142 |         r"sta(\d+);(\d+);(\d+)", saveCode
143 |     )  # sta1;1;5
144 |     out_matches.extend(matches1)
145 | 
146 |     steps_ahead2, matches2 = extractPowRangesFromRegex(
147 |         r"sta(\d+)\^(\d+);(\d+);(\d+)", saveCode
148 |     )  # sta2^1;1;5
149 |     out_matches.extend(matches2)
150 |     steps_ahead.extend(steps_ahead2)
151 | 
152 |     steps_ahead3, matches3 = extractIntsFromRegex(
153 |         r"sta(\d+)(?![;\^])", saveCode
154 |     )  # sta10 (but not sta10;)
155 |     out_matches.extend(matches3)
156 |     steps_ahead.extend(steps_ahead3)
157 | 
158 |     if len(steps_ahead) == 0:
159 |         steps_ahead = None
160 | 
161 |     steps_ahead_loss_weights = None
162 | 
163 |     if steps_ahead is not None:
164 |         zeroWeightList, matches1 = extractIntRangesFromRegex(
165 |             r"staZW(\d+);(\d+);(\d+)", saveCode
166 |         )  # staZW2;1;5
167 |         out_matches.extend(matches1)
168 |         zeroWeightList2, matches2 = extractPowRangesFromRegex(
169 |             r"staZW(\d+)\^(\d+);(\d+);(\d+)", saveCode
170 |         )  # staZW2^2;1;5
171 |         zeroWeightList.extend(zeroWeightList2)
172 |         out_matches.extend(matches2)
173 |         zeroWeightList3, matches3 = extractIntsFromRegex(
174 |             r"staZW(\d+)(?!;)", saveCode
175 |         )  # staZW10 (but not staZW10;)
176 |         zeroWeightList.extend(zeroWeightList3)
177 |         out_matches.extend(matches3)
178 | 
179 |         if len(zeroWeightList) > 0:
180 |             steps_ahead_loss_weights = [
181 |                 0.0 if step_ahead in zeroWeightList else 1.0
182 |                 for step_ahead in steps_ahead
183 |             ]
184 | 
185 |     return steps_ahead, steps_ahead_loss_weights, out_matches
186 | 
187 | 
188 | def parseMethodCodeArgEnsemble(saveCode):
189 |     out_matches = []
190 |     ensemble_cnt, matches1 = extractIntsFromRegex(r"ensm(\d+)", saveCode)  # ensm10
191 |     out_matches.extend(matches1)
192 |     return ensemble_cnt, out_matches
193 | 
194 | 
195 | def parseMethodCodeArgOptimizer(saveCode):
196 |     """Parses the optimizer settings from methodCode
197 | 
198 |     Args:
199 |         saveCode (str): the string specifying the method settings
200 | 
201 |     Returns:
202 |         tuple: (outs, out_matches), where outs is a list of dicts with the parsed optimizer/scheduler settings and out_matches is the list of corresponding regex matches
203 |     """
204 |     outs = []
205 |     out_matches = []
206 | 
207 |     optimizer_args = None
208 |     learning_rates, matches = extractNumberFromRegex(
209 |         saveCode, prefix="LR"
210 |     )  # LR1e-05 or LR0.01
211 |     if len(learning_rates) > 0:
212 |         if optimizer_args is None:
213 |             optimizer_args = {}
214 |         optimizer_args.update(
215 |             {"learning_rate": learning_rates[0]}  # Default for Adam is 0.001
216 |         )
217 | 
218 |     weight_decays, matches = extractNumberFromRegex(
219 |         saveCode, prefix="WD"
220 |     )  # WD1e-05 or WD0.01
221 |     if len(weight_decays) > 0:
222 |         if optimizer_args is None:
223 |             optimizer_args = {}
224 |         optimizer_args.update(
225 |             {"weight_decay": weight_decays[0]}  # Default for AdamW is 1
226 | ) 227 | 228 | regex = r"opt(AdamW|Adam)(_sc)?(CDR|CD|ED|ITD|PCD|PD)?" 229 | if len(re.findall(regex, saveCode)) == 0: 230 | # Revert to default 231 | outs = [{"optimizer_args": optimizer_args}] 232 | else: 233 | matches = re.finditer(regex, saveCode) 234 | for matchNum, match in enumerate(matches, start=1): 235 | groups = match.groups() 236 | optimizer_name = groups[0] 237 | out = {"optimizer_name": optimizer_name, "optimizer_args": optimizer_args} 238 | if len(groups) > 0 and groups[1] == "_sc": 239 | scheduler_code = groups[2] 240 | scheduler_options = { 241 | "CD": { # https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay 242 | "name": "CosineDecay", 243 | "args": [ 244 | "initial_learning_rate", 245 | "decay_steps", 246 | "alpha", 247 | "warmup_target", 248 | "warmup_steps", 249 | ], 250 | }, 251 | "CDR": { # https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecayRestarts 252 | "name": "CosineDecayRestarts", 253 | "args": [ 254 | "initial_learning_rate", 255 | "first_decay_steps", 256 | "t_mul", 257 | "m_mul", 258 | "alpha", 259 | ], 260 | }, 261 | "ED": { # https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay 262 | "name": "ExponentialDecay", 263 | "args": [ 264 | "initial_learning_rate", 265 | "decay_steps", 266 | "decay_rate", 267 | "staircase", 268 | ], 269 | }, 270 | "ITD": { # https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/InverseTimeDecay 271 | "name": "InverseTimeDecay", 272 | "args": [ 273 | "initial_learning_rate", 274 | "decay_steps", 275 | "decay_rate", 276 | "staircase", 277 | ], 278 | }, 279 | "PCD": { # https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/PiecewiseConstantDecay 280 | "name": "PiecewiseConstantDecay", 281 | "args": [ 282 | "boundaries", 283 | "values", 284 | ], # boundaries and values are each lists 285 | }, 286 | "PD": { # https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/PolynomialDecay 287 | "name": "PolynomialDecay", 288 | "args": [ 289 | "initial_learning_rate", 290 | "decay_steps", 291 | "end_learning_rate", 292 | "power", 293 | "cycle", 294 | ], 295 | }, 296 | } 297 | scheduler_name = scheduler_options[scheduler_code]["name"] 298 | out["scheduler_name"] = scheduler_name 299 | regex2 = re.compile( 300 | f"opt{optimizer_name}_sc{scheduler_code}" 301 | + r"(_?(?:_?[-+]?[\d]+\.?[\d]*)(?:[Ee][-+]?[\d]+)?)*" 302 | ) 303 | if len(re.findall(regex2, saveCode)): 304 | matches2 = re.finditer(regex2, saveCode) 305 | for matchNum2, match2 in enumerate(matches2, start=1): 306 | matchstr = match2.group().replace( 307 | f"opt{optimizer_name}_sc{scheduler_code}", "" 308 | ) 309 | numbers = matchstr.split("_") 310 | args = {} 311 | for param_ind, num_str in enumerate(numbers): 312 | out_this, out_matches_this = extractNumberFromRegex(num_str) 313 | num = out_this[0] 314 | if float(num) == int(num): 315 | num = int(num) 316 | arg_name = scheduler_options[scheduler_code]["args"][ 317 | 1 + param_ind 318 | ] 319 | args[arg_name] = num 320 | out["scheduler_args"] = args 321 | regex2 = re.compile( 322 | f"opt{optimizer_name}_sc{scheduler_code}" 323 | + r"(?:_?(?:_?[-+]?[\d]+\.?[\d]*)(?:[Ee][-+]?[\d]+)?)*(_str(T|F))" 324 | ) # Check for staircase in ExponentialDecay and InverseTimeDecay 325 | if len(re.findall(regex2, saveCode)): 326 | matches2 = re.finditer(regex2, saveCode) 327 | for matchNum2, match2 in enumerate(matches2, start=1): 328 | out["scheduler_args"]["staircase"] = ( 329 | True if 
match2.groups()[-1] == "T" else False 330 | ) 331 | regex2 = re.compile( 332 | f"opt{optimizer_name}_sc{scheduler_code}" 333 | + r"(?:_?(?:_?[-+]?[\d]+\.?[\d]*)(?:[Ee][-+]?[\d]+)?)*(_cyc(T|F))" 334 | ) # Check for cycle in PolynomialDecay 335 | if len(re.findall(regex2, saveCode)): 336 | matches2 = re.finditer(regex2, saveCode) 337 | for matchNum2, match2 in enumerate(matches2, start=1): 338 | out["scheduler_args"]["cycle"] = ( 339 | True if match2.groups()[-1] == "T" else False 340 | ) 341 | out_matches.append(match) 342 | outs.append(out) 343 | 344 | return outs, out_matches 345 | 346 | 347 | def parseInnerCVFoldSettings(saveCode): 348 | out = [] 349 | out_matches = [] 350 | regex = r"iCVF(\d+)o?(\d+)?" # iCVF2, or iCVF5o5 351 | if len(re.findall(regex, saveCode)): 352 | matches = re.finditer(regex, saveCode) 353 | for matchNum, match in enumerate(matches, start=1): 354 | vals = match.groups() 355 | numFolds = int(vals[0]) 356 | foldsToRun = ( 357 | [int(vals[1])] if len(vals) > 1 and vals[1] is not None else None 358 | ) 359 | out_matches.append(match) 360 | out.extend([{"folds": numFolds, "foldsToRun": foldsToRun}]) 361 | return out, out_matches 362 | 363 | 364 | def extractValueRanges(methodCode, prefix="L"): 365 | methodCodeCpy = copy.copy(methodCode) 366 | lambdaVals = [] 367 | lambdaValStrs = [] 368 | # Find lambda vals provided as linear ranges 369 | regex = f"{prefix}" + r"(\d+)e([-+])?(\d+);(\d+);(\d+)e([-+])?(\d+)" # L1e-2:-2:-8 370 | matches = re.finditer(regex, methodCodeCpy) 371 | for matchNum, match in enumerate(matches, start=1): 372 | m, sgn, power, count, m2, sgn2, power2 = match.groups() 373 | power = -float(power) if sgn is not None and sgn == "-" else float(power) 374 | power2 = -float(power2) if sgn2 is not None and sgn2 == "-" else float(power2) 375 | lVals = np.linspace(float(m) * 10**power, float(m2) * 10**power2, int(count)) 376 | lValsC = np.array([float(f"{l:.5f}") for l in lVals]) 377 | if np.max(np.abs(lVals - lValsC)) < np.min(lVals) * 1e-3: 378 | lVals = lValsC 379 | lambdaVals.extend(list(lVals)) 380 | strSpan = match.span() 381 | lambdaValStrs.extend([methodCodeCpy[strSpan[0] : strSpan[1]]] * int(count)) 382 | for ls in lambdaValStrs: 383 | methodCodeCpy = methodCodeCpy.replace(ls, "") 384 | # Find lambda vals provided as ranges of exponents 385 | regex = f"{prefix}" + r"(\d+)e([-+])?(\d+);([-+])?(\d+);([-+])?(\d+)" # L1e-2:-2:-8 386 | matches = re.finditer(regex, methodCodeCpy) 387 | for matchNum, match in enumerate(matches, start=1): 388 | m, sgn, power, step_sgn, step_val, sgn2, power2 = match.groups() 389 | power = -float(power) if sgn is not None and sgn == "-" else float(power) 390 | power2 = -float(power2) if sgn2 is not None and sgn2 == "-" else float(power2) 391 | step_val = ( 392 | -float(step_val) 393 | if step_sgn is not None and step_sgn == "-" 394 | else float(step_val) 395 | ) 396 | pow_vals = np.array(np.arange(power, power2, step_val)) 397 | lVals = float(m) * 10**pow_vals 398 | lValsC = np.array([float(f"{l:.5f}") for l in lVals]) 399 | if np.max(np.abs(lVals - lValsC)) < np.min(lVals) * 1e-3: 400 | lVals = lValsC 401 | lambdaVals.extend(list(lVals)) 402 | strSpan = match.span() 403 | lambdaValStrs.extend([methodCodeCpy[strSpan[0] : strSpan[1]]] * pow_vals.size) 404 | for ls in lambdaValStrs: 405 | methodCodeCpy = methodCodeCpy.replace(ls, "") 406 | # Find individual lambda vals 407 | regex = f"{prefix}" + r"(\d+)+e([-+])?(\d+)+" # L1e-2 408 | matches = re.finditer(regex, methodCodeCpy) 409 | for matchNum, match in enumerate(matches, 
start=1): 410 | m, sgn, power = match.groups() 411 | if sgn is not None and sgn == "-": 412 | power = -float(power) 413 | lVals = np.array([float(m) * 10 ** float(power)]) 414 | lValsC = np.array([float(f"{l:.5f}") for l in lVals]) 415 | if np.max(np.abs(lVals - lValsC)) < np.min(lVals) * 1e-3: 416 | lVals = lValsC 417 | lambdaVals.append(lVals[0]) 418 | strSpan = match.span() 419 | lambdaValStrs.append(methodCodeCpy[strSpan[0] : strSpan[1]]) 420 | for ls in lambdaValStrs: 421 | methodCodeCpy = methodCodeCpy.replace(ls, "") 422 | return lambdaVals, lambdaValStrs 423 | -------------------------------------------------------------------------------- /source/DPAD/tools/plot_model_params.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PSID import LSSM 3 | from PSID.PSID import projOrth 4 | 5 | from .. import DPADModel 6 | from .plot import checkIfAllExtsAlreadyExist, plotPredictionScatter 7 | from .SSM import SSM 8 | from .tools import pickColumnOp 9 | 10 | 11 | def plot_model_params( 12 | sId, 13 | Z, 14 | Y, 15 | X, 16 | title, 17 | savePath=None, 18 | saveExtensions=["png", "svg"], 19 | skip_existing=False, 20 | trueModel=None, 21 | ZTrue=None, 22 | YTrue=None, 23 | XTrue=None, 24 | params_to_plot=None, 25 | plot_orig=None, 26 | figsize=None, 27 | show_hist_x=False, 28 | show_hist_y=False, 29 | XLimByQuantile=None, 30 | YLimByQuantile=None, 31 | x_keep_quantiles=None, 32 | ): 33 | """Plots parameters of the learned model 34 | 35 | Args: 36 | sId (object): learned model 37 | Z (np.array): signal 2 (e.g. behavior) 38 | Y (np.array): signal 1 (e.g. neural activity) 39 | X (np.array): latent state 40 | title (string): title of figure 41 | savePath (string, optional): path to save figure files. Defaults to None. 42 | saveExtensions (list, optional): list of figure file extensions to generate. Defaults to ['png', 'svg']. 43 | skip_existing (bool, optional): if True will skip generating plots if the figure file exists. Defaults to False. 44 | trueModel (object, optional): true model if known in simulations. Defaults to None. 45 | ZTrue (np.array, optional): true values of signal 2 (e.g. behavior) if known in simulations. Defaults to None. 46 | YTrue (np.array, optional): true values of signal 1 (e.g. neural activity) if known in simulations. Defaults to None. 47 | XTrue (np.array, optional): true values of latent states if known in simulations. Defaults to None. 48 | params_to_plot (list of string, optional): list of parameters to plot. Defaults to None. 49 | plot_orig (bool, optional): if True, will try to plot true parameter values in simulations when possible. Defaults to None. 50 | # The following are arguments to pass to plotPredictionScatter 51 | figsize (_type_, optional): _description_. Defaults to None. 52 | show_hist_x (bool, optional): _description_. Defaults to False. 53 | show_hist_y (bool, optional): _description_. Defaults to False. 54 | XLimByQuantile (_type_, optional): _description_. Defaults to None. 55 | YLimByQuantile (_type_, optional): _description_. Defaults to None. 56 | x_keep_quantiles (_type_, optional): _description_. Defaults to None. 
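        Example (sketch; the variables are placeholders for a fitted DPADModel
        `sId` and its extracted latent states/signals, as in the arguments above):
            plot_model_params(
                sId, Z, Y, X, title="DPAD ",
                savePath=None,  # set to a path prefix to also save figure files
                params_to_plot=["A", "Cz"],
            )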
57 | """ 58 | if plot_orig is None: 59 | plot_orig = XTrue is None 60 | nx = X.shape[-1] 61 | ny = Y.shape[-1] 62 | if isinstance(sId, LSSM) and not isinstance(sId, SSM): 63 | sId_DPAD = DPADModel() 64 | sId_DPAD.setToLSSM(sId, model1_Cy_Full=False, model2_Cz_Full=False) 65 | sId = sId_DPAD 66 | if isinstance(sId, DPADModel): 67 | if sId.model1 is not None: 68 | X1 = X[:, : sId.n1] 69 | X2 = X[:, sId.n1 :] 70 | XRng = range(X.shape[1]) 71 | X1Rng = range(X1.shape[1]) 72 | YRng = range(Y.shape[1]) 73 | X1Cols = [(X1 @ pickColumnOp(X1.shape[1], [xi])).T for xi in X1Rng] 74 | YCols = [(Y @ pickColumnOp(Y.shape[1], [yi])).T for yi in YRng] 75 | 76 | if XTrue is not None: 77 | zDims = trueModel.zDims 78 | n1True = zDims.size 79 | X1True = XTrue[:, :n1True] 80 | X2True = XTrue[:, n1True:] 81 | X1TrueCols = [ 82 | (X1True @ pickColumnOp(X1True.shape[1], [xi])).T for xi in X1Rng 83 | ] 84 | 85 | X1TrueHat, WToX1True = projOrth(X1True.T, X1.T) 86 | X1TrueHat = X1TrueHat.T 87 | # X1TrueHat is the estimated equivalent X1True for every identified X1. 88 | # We expect true_param(X1TrueHat) to be the same as learned_param(X1) 89 | 90 | X1Hat, WToX1 = projOrth(X1.T, X1True.T) 91 | X1Hat = X1Hat.T 92 | # X1Hat is the estimated equivalent identified X1 for every X1True. 93 | # We expect learned_param(X1Hat) to be the same as true_param(X1True) 94 | 95 | X1TrueCols = [ 96 | (X1True @ pickColumnOp(X1True.shape[1], [xi])).T for xi in X1Rng 97 | ] 98 | XTrueCols = [ 99 | (XTrue @ pickColumnOp(XTrue.shape[1], [xi])).T for xi in XRng 100 | ] 101 | else: 102 | X1Hat = X1 103 | X1HatCols = [(X1Hat @ pickColumnOp(X1Hat.shape[1], [xi])).T for xi in X1Rng] 104 | if ( 105 | (params_to_plot is None or "A" in params_to_plot) 106 | and hasattr(sId.model1.rnn.cell, "A") 107 | and ( 108 | "unifiedAK" not in sId.model1.cell_args 109 | or not sId.model1.cell_args["unifiedAK"] 110 | ) 111 | ): 112 | if plot_orig and ( 113 | savePath is None 114 | or not checkIfAllExtsAlreadyExist( 115 | savePath + "_paramA_KCy", saveExtensions 116 | ) 117 | or not skip_existing 118 | ): 119 | y_out_A_list = [ 120 | sId.model1.rnn.cell.A.predict(X1Cols[xi]).T for xi in X1Rng 121 | ] 122 | plotPredictionScatter( 123 | [ 124 | np.tile(X1[:, xi].T, (y_out_A_list[0].shape[1], 1)) 125 | for xi in X1Rng 126 | ], 127 | [y_out_A_list[xi].T for xi in X1Rng], 128 | connect_sorted=True, 129 | square=False, 130 | styles={"size": 1}, 131 | title=f"{title}A" "=(A-KCy) stage 1", 132 | xLabel=[f"x{xi+1}" for xi in X1Rng], 133 | yLabel=[f"A'(x{xi+1})" for xi in X1Rng], 134 | figsize=figsize, 135 | show_hist_x=show_hist_x, 136 | show_hist_y=show_hist_y, 137 | XLimByQuantile=XLimByQuantile, 138 | YLimByQuantile=YLimByQuantile, 139 | x_keep_quantiles=x_keep_quantiles, 140 | skip_existing=skip_existing, 141 | saveFile=None if savePath is None else savePath + "_paramA_KCy", 142 | saveExtensions=saveExtensions, 143 | ) 144 | if XTrue is not None and ( 145 | savePath is None 146 | or not checkIfAllExtsAlreadyExist( 147 | savePath + "_paramA_KCy_sim", saveExtensions 148 | ) 149 | or not skip_existing 150 | ): 151 | y_out_A_list = [ 152 | sId.model1.rnn.cell.A.predict(X1HatCols[xi]).T for xi in X1Rng 153 | ] 154 | if trueModel is not None: 155 | if isinstance(trueModel, SSM): 156 | # y_out_A_true = trueModel.apply_param('A', XTrue.T).T \ 157 | # - trueModel.apply_param('K', trueModel.apply_param('C', XTrue.T)).T 158 | y_out_A_true = [ 159 | trueModel.apply_param("A_KC", XTrueCols[xi]).T 160 | for xi in X1Rng 161 | ] 162 | else: 163 | y_out_A_true = [ 164 | 
(trueModel.A_KC @ XTrueCols[xi]).T for xi in X1Rng
165 | ]
166 | # Convert back to the basis of the learned model
167 | y_out_A_true_sim = [
168 | (WToX1 @ y_out_A_true[xi][:, :n1True].T).T for xi in X1Rng
169 | ]
170 | # y_out_A_true_sim, W2 = projOrth(y_out_A.T, y_out_A_true.T) # TEMP
171 | # y_out_A_true_sim = y_out_A_true_sim.T
172 | y_out_A_list = [
173 | np.concatenate(
174 | (y_out_A_list[xi], y_out_A_true_sim[xi]), axis=1
175 | )
176 | for xi in X1Rng
177 | ]
178 | plotPredictionScatter(
179 | [
180 | np.tile(X1True[:, xi].T, (y_out_A_list[0].shape[1], 1))
181 | for xi in X1Rng
182 | ],
183 | [y_out_A_list[xi].T for xi in X1Rng],
184 | connect_sorted=True,
185 | legNames=["Learned", "True (sim)"],
186 | square=False,
187 | styles=[{"linestyle": "-"}, {"linestyle": "--"}],
188 | title=f"{title}A'=(A-KCy) stage 1",
189 | xLabel=[f"x{xi+1}" for xi in X1Rng],
190 | yLabel=[f"A'(x{xi+1})" for xi in X1Rng],
191 | figsize=figsize,
192 | show_hist_x=show_hist_x,
193 | show_hist_y=show_hist_y,
194 | XLimByQuantile=XLimByQuantile,
195 | YLimByQuantile=YLimByQuantile,
196 | x_keep_quantiles=x_keep_quantiles,
197 | skip_existing=skip_existing,
198 | saveFile=(
199 | None if savePath is None else savePath + "_paramA_KCy_sim"
200 | ),
201 | saveExtensions=saveExtensions,
202 | )
203 | if (params_to_plot is None or "Cz" in params_to_plot) and hasattr(
204 | sId.model1.rnn.cell, "C"
205 | ):
206 | if plot_orig and (
207 | savePath is None
208 | or not checkIfAllExtsAlreadyExist(
209 | savePath + "_paramCz", saveExtensions
210 | )
211 | or not skip_existing
212 | ):
213 | z_out_Cz_list = [
214 | sId.model1.rnn.cell.C.predict(X1Cols[xi]).T for xi in X1Rng
215 | ]
216 | plotPredictionScatter(
217 | [
218 | np.tile(X1[:, xi].T, (z_out_Cz_list[0].shape[1], 1))
219 | for xi in X1Rng
220 | ],
221 | [z_out_Cz_list[xi].T for xi in X1Rng],
222 | connect_sorted=True,
223 | square=False,
224 | styles={"size": 1},
225 | title=[f"{title}Cz1{xi+1}" for xi in X1Rng],
226 | xLabel=[f"x{xi+1}" for xi in X1Rng],
227 | yLabel=[f"Cz(x{xi+1})" for xi in X1Rng],
228 | figsize=figsize,
229 | show_hist_x=show_hist_x,
230 | show_hist_y=show_hist_y,
231 | XLimByQuantile=XLimByQuantile,
232 | YLimByQuantile=YLimByQuantile,
233 | x_keep_quantiles=x_keep_quantiles,
234 | skip_existing=skip_existing,
235 | saveFile=None if savePath is None else savePath + "_paramCz",
236 | saveExtensions=saveExtensions,
237 | )
238 | if XTrue is not None and (
239 | savePath is None
240 | or not checkIfAllExtsAlreadyExist(
241 | savePath + "_paramCz_sim", saveExtensions
242 | )
243 | or not skip_existing
244 | ):
245 | z_out_Cz_list = [
246 | sId.model1.rnn.cell.C.predict(X1HatCols[xi]).T for xi in X1Rng
247 | ]
248 | if trueModel is not None: # Add true model's param
249 | if isinstance(trueModel, SSM):
250 | z_out_Cz_true = [
251 | trueModel.apply_param("Cz", XTrueCols[xi]).T
252 | for xi in X1Rng
253 | ]
254 | else:
255 | z_out_Cz_true = [
256 | (trueModel.Cz @ XTrueCols[xi]).T for xi in X1Rng
257 | ]
258 | z_out_Cz_list = [
259 | np.concatenate(
260 | (z_out_Cz_list[xi], z_out_Cz_true[xi]), axis=1
261 | )
262 | for xi in X1Rng
263 | ]
264 | plotPredictionScatter(
265 | [
266 | np.tile(X1True[:, xi].T, (z_out_Cz_list[0].shape[1], 1))
267 | for xi in X1Rng
268 | ],
269 | [z_out_Cz_list[xi].T for xi in X1Rng],
270 | connect_sorted=True,
271 | legNames=["Learned", "True (sim)"],
272 | square=False,
273 | styles=[{"linestyle": "-"}, {"linestyle": "--"}],
274 | title=[f"{title}Cz1{xi+1}" for xi in X1Rng],
275 | xLabel=[f"x{xi+1}" for xi in
X1Rng], 276 | yLabel=[f"Cz(x{xi+1})" for xi in X1Rng], 277 | figsize=figsize, 278 | show_hist_x=show_hist_x, 279 | show_hist_y=show_hist_y, 280 | XLimByQuantile=XLimByQuantile, 281 | YLimByQuantile=YLimByQuantile, 282 | x_keep_quantiles=x_keep_quantiles, 283 | skip_existing=skip_existing, 284 | saveFile=( 285 | None if savePath is None else savePath + "_paramCz_sim" 286 | ), 287 | saveExtensions=saveExtensions, 288 | ) 289 | if (params_to_plot is None or "K" in params_to_plot) and hasattr( 290 | sId.model1.rnn.cell, "K" 291 | ): 292 | if plot_orig and ( 293 | savePath is None 294 | or not checkIfAllExtsAlreadyExist( 295 | savePath + "_paramK1", saveExtensions 296 | ) 297 | or not skip_existing 298 | ): 299 | y_out_K_list = [ 300 | sId.model1.rnn.cell.K.predict(YCols[yi]).T for yi in YRng 301 | ] 302 | plotPredictionScatter( 303 | [ 304 | np.tile(Y[:, yi].T, (y_out_K_list[0].shape[1], 1)) 305 | for yi in YRng 306 | ], 307 | [y_out_K_list[yi].T for yi in YRng], 308 | connect_sorted=True, 309 | square=False, 310 | styles={"size": 1}, 311 | title=[f"{title}K1{yi+1}" for yi in YRng], 312 | xLabel=[f"y{yi+1}" for yi in YRng], 313 | yLabel=[f"K1(y{yi+1})" for yi in YRng], 314 | figsize=figsize, 315 | show_hist_x=show_hist_x, 316 | show_hist_y=show_hist_y, 317 | XLimByQuantile=XLimByQuantile, 318 | YLimByQuantile=YLimByQuantile, 319 | x_keep_quantiles=x_keep_quantiles, 320 | skip_existing=skip_existing, 321 | saveFile=None if savePath is None else savePath + "_paramK1", 322 | saveExtensions=saveExtensions, 323 | ) 324 | if XTrue is not None and ( 325 | savePath is None 326 | or not checkIfAllExtsAlreadyExist( 327 | savePath + "_paramK1_sim", saveExtensions 328 | ) 329 | or not skip_existing 330 | ): 331 | y_out_K_list = [ 332 | (WToX1True @ sId.model1.rnn.cell.K.predict(YCols[yi])).T 333 | for yi in YRng 334 | ] 335 | if trueModel is not None: # Add true model's param 336 | if isinstance(trueModel, SSM): 337 | y_out_K_true = [ 338 | trueModel.apply_param("K", YCols[yi]).T for yi in YRng 339 | ] 340 | else: 341 | y_out_K_true = [(trueModel.K @ YCols[yi]).T for yi in YRng] 342 | y_out_K_true_X1 = [y_out_K_true[yi][:, :n1True] for yi in YRng] 343 | y_out_K_list = [ 344 | np.concatenate( 345 | (y_out_K_list[yi], y_out_K_true_X1[yi]), axis=1 346 | ) 347 | for yi in YRng 348 | ] 349 | plotPredictionScatter( 350 | [ 351 | np.tile(Y[:, yi].T, (y_out_K_list[0].shape[1], 1)) 352 | for yi in YRng 353 | ], 354 | [y_out_K_list[yi].T for yi in YRng], 355 | connect_sorted=True, 356 | legNames=[ 357 | f"Learned (x{xi+1})" for xi in range(WToX1True.shape[0]) 358 | ] 359 | + [f"True (x{xi+1}) (sim)" for xi in range(WToX1True.shape[0])], 360 | square=False, 361 | styles=[{"linestyle": "-"}] * WToX1True.shape[0] 362 | + [{"linestyle": "--"}] * WToX1True.shape[0], 363 | title=[f"{title}K1{yi+1}" for yi in YRng], 364 | xLabel=[f"y{yi+1}" for yi in YRng], 365 | yLabel=[f"K1(y{yi+1})" for yi in YRng], 366 | figsize=figsize, 367 | show_hist_x=show_hist_x, 368 | show_hist_y=show_hist_y, 369 | XLimByQuantile=XLimByQuantile, 370 | YLimByQuantile=YLimByQuantile, 371 | x_keep_quantiles=x_keep_quantiles, 372 | skip_existing=skip_existing, 373 | saveFile=( 374 | None if savePath is None else savePath + "_paramK1_sim" 375 | ), 376 | saveExtensions=saveExtensions, 377 | ) 378 | if params_to_plot is None or "Cy" in params_to_plot: 379 | if plot_orig and ( 380 | savePath is None 381 | or not checkIfAllExtsAlreadyExist( 382 | savePath + "_paramCy", saveExtensions 383 | ) 384 | or not skip_existing 385 | ): 386 | y_out_Cy_list = [ 387 
| sId.model1_Cy.predict(X1Cols[xi]).T for xi in X1Rng 388 | ] 389 | plotPredictionScatter( 390 | [ 391 | np.tile(X1[:, xi].T, (y_out_Cy_list[0].shape[1], 1)) 392 | for xi in X1Rng 393 | ], 394 | [y_out_Cy_list[xi].T for xi in X1Rng], 395 | connect_sorted=True, 396 | square=False, 397 | styles={"size": 1}, 398 | title=[f"{title}Cy1{xi+1}" for xi in X1Rng], 399 | xLabel=[f"x{xi+1}" for xi in X1Rng], 400 | yLabel=[f"Cy(x{xi+1})" for xi in X1Rng], 401 | figsize=figsize, 402 | show_hist_x=show_hist_x, 403 | show_hist_y=show_hist_y, 404 | XLimByQuantile=XLimByQuantile, 405 | YLimByQuantile=YLimByQuantile, 406 | x_keep_quantiles=x_keep_quantiles, 407 | skip_existing=skip_existing, 408 | saveFile=None if savePath is None else savePath + "_paramCy", 409 | saveExtensions=saveExtensions, 410 | ) 411 | if XTrue is not None and ( 412 | savePath is None 413 | or not checkIfAllExtsAlreadyExist( 414 | savePath + "_paramCy_sim", saveExtensions 415 | ) 416 | or not skip_existing 417 | ): 418 | y_out_Cy_list = [ 419 | sId.model1_Cy.predict(X1HatCols[xi]).T for xi in X1Rng 420 | ] 421 | if trueModel is not None: # Add true model's param 422 | if isinstance(trueModel, SSM): 423 | y_out_Cy_true = [ 424 | trueModel.apply_param("C", XTrueCols[xi]).T 425 | for xi in X1Rng 426 | ] 427 | else: 428 | y_out_Cy_true = [ 429 | (trueModel.C @ XTrueCols[xi]).T for xi in X1Rng 430 | ] 431 | y_out_Cy_list = [ 432 | np.concatenate( 433 | (y_out_Cy_list[xi], y_out_Cy_true[xi]), axis=1 434 | ) 435 | for xi in X1Rng 436 | ] 437 | plotPredictionScatter( 438 | [ 439 | np.tile(X1True[:, xi].T, (y_out_Cy_list[0].shape[1], 1)) 440 | for xi in X1Rng 441 | ], 442 | [y_out_Cy_list[xi].T for xi in X1Rng], 443 | connect_sorted=True, 444 | legNames=["Learned", "True (sim)"], 445 | square=False, 446 | styles=[{"linestyle": "-"}, {"linestyle": "--"}], 447 | title=[f"{title}Cy1{xi+1}" for xi in X1Rng], 448 | xLabel=[f"x{xi+1}" for xi in X1Rng], 449 | yLabel=[f"Cy(x{xi+1})" for xi in X1Rng], 450 | figsize=figsize, 451 | show_hist_x=show_hist_x, 452 | show_hist_y=show_hist_y, 453 | XLimByQuantile=XLimByQuantile, 454 | YLimByQuantile=YLimByQuantile, 455 | x_keep_quantiles=x_keep_quantiles, 456 | skip_existing=skip_existing, 457 | saveFile=( 458 | None if savePath is None else savePath + "_paramCy_sim" 459 | ), 460 | saveExtensions=saveExtensions, 461 | ) 462 | pass 463 | -------------------------------------------------------------------------------- /source/DPAD/tools/sim_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 University of Southern California 3 | See full notice in LICENSE.md 4 | Omid G. Sani and Maryam M. Shanechi 5 | Shanechi Lab, University of Southern California 6 | """ 7 | 8 | """Tools for simulating models""" 9 | 10 | import numpy as np 11 | from scipy import stats 12 | 13 | 14 | def addConjs(vals): 15 | """Adds complex conjugate values for each complex value 16 | 17 | Args: 18 | vals (list of number): list of numbers (e.g. eigenvalues) 19 | 20 | Returns: 21 | output (list of numbers): new list of numbers where for each complex value in the original list, 22 | the conjugate is also added to the new list. For real values an np.nan is added. 
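Example (illustrative; nan marks the conjugate slot of a real value):
    addConjs([0.5, 0.2 + 0.3j])
    # -> [[0.5,       nan     ],
    #     [0.2+0.3j,  0.2-0.3j]]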
23 | """
24 | vals = np.atleast_2d(vals).T
25 | valsConj = vals.conj()
26 | valsConj[np.abs(vals - valsConj) < np.spacing(1)] = np.nan
27 | return np.concatenate((vals, valsConj), axis=1)
28 | 
29 | 
30 | def drawRandomPoles(N, poleDist={}):
31 | """Draws random eigenvalues from within the unit disk
32 | 
33 | Args:
34 | N (int): number of eigenvalues to draw
35 | poleDist (dict, optional): information about the distribution; supported keys are
36 | "magDist" ("beta"), "magDistParams" (e.g. {"a": 2, "b": 1}), and "angleDist" ("uniform"). Defaults to {}.
37 | 
38 | Returns:
39 | valsA (np.array): drawn random eigenvalues
40 | """
41 | nCplx = int(np.floor(N / 2))
42 | 
43 | if "magDist" not in poleDist:
44 | poleDist["magDist"] = "beta"
45 | if "magDistParams" not in poleDist and poleDist["magDist"] == "beta":
46 | poleDist["magDistParams"] = {"a": 2, "b": 1}
47 | if "angleDist" not in poleDist:
48 | poleDist["angleDist"] = "uniform"
49 | 
50 | # mag = np.random.rand(nCplx) # Uniform dist
51 | if poleDist["magDist"] == "beta":
52 | a, b = 2, 1 # Use a, b = 2, 1 for uniform prob over unit circle
53 | if "a" in poleDist["magDistParams"]:
54 | a = poleDist["magDistParams"]["a"]
55 | if "b" in poleDist["magDistParams"]:
56 | b = poleDist["magDistParams"]["b"]
57 | 
58 | """
59 | import matplotlib.pyplot as plt
60 | fig, ax = plt.subplots(1, 1)
61 | x = np.linspace(0, 1, 100)
62 | ax.plot(x, stats.beta.pdf(x, a, b), lw=5, alpha=0.6, label='beta pdf (a={}, b={})'.format(a, b))
63 | ax.legend()
64 | plt.show()
65 | """
66 | 
67 | mag = stats.beta.rvs(a=a, b=b, size=nCplx) # Beta dist
68 | else:
69 | raise Exception("Only beta distribution is supported for the magnitude")
70 | 
71 | if poleDist["angleDist"] == "uniform":
72 | theta = np.random.rand(nCplx) * np.pi
73 | else:
74 | raise Exception("Only uniform distribution is supported for the angle")
75 | 
76 | vals = mag * np.exp(1j * theta)
77 | 
78 | valsA = addConjs(vals)
79 | valsA = valsA.reshape(valsA.size)
80 | valsA = valsA[np.logical_not(np.isnan(valsA))]
81 | 
82 | # Add real mode(s) if needed
83 | nReal = N - 2 * nCplx
84 | if nReal > 0:
85 | # rVals = np.random.rand(nReal)
86 | rVals = stats.beta.rvs(a=a, b=b, size=nReal) # Beta dist
87 | rSign = 2 * (((np.random.rand(nReal) > 0.5).astype(float)) - 0.5)
88 | 
89 | valsA = np.concatenate((valsA, rVals * rSign))
90 | 
91 | return valsA
92 | --------------------------------------------------------------------------------
/source/DPAD/tools/tests/test_LinearMapping.py:
--------------------------------------------------------------------------------
1 | """ Omid Sani, Shanechi Lab, University of Southern California, 2020 """
2 | 
3 | # pylint: disable=C0103, C0111
4 | 
5 | "Tests LinearMapping"
6 | 
7 | import copy
8 | import os
9 | import sys
10 | import unittest
11 | 
12 | sys.path.insert(0, os.path.dirname(__file__))
13 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
14 | 
15 | import numpy as np
16 | 
17 | from ..LinearMapping import LinearMapping
18 | 
19 | # print('__file__={0:<35} | __name__={1:<20} | __package__={2:<20}'.format(__file__,__name__,str(__package__)))
20 | 
21 | 
22 | class TestLinearMapping(unittest.TestCase):
23 | def test_LinearMapping(self):
24 | np.random.seed(42)
25 | 
26 | numTests = 100
27 | for ci in range(numTests):
28 | ny = np.random.choice(np.arange(1, 10))
29 | n_rem = np.random.choice(np.arange(0, ny))
30 | 
31 | N = 100
32 | Y = np.random.randn(N, ny)
33 | rem_inds = np.unique(np.random.randint(0, ny, n_rem)) # np.random.random_integers is deprecated and removed in recent NumPy; randint's high is exclusive, so this draws from the same range
34 | keep_vector = np.array([yi not in rem_inds for yi in range(ny)])
35 | Y[:, rem_inds] = 0
36 | 
37 | LM =
LinearMapping() 38 | LM.set_to_dimension_remover(keep_vector) 39 | Y_rem = LM.apply(Y.T).T 40 | Y_recover = LM.apply_inverse(Y_rem.T).T 41 | 42 | np.testing.assert_array_equal(Y_rem, Y[:, keep_vector]) 43 | np.testing.assert_array_equal(Y_recover, Y) 44 | 45 | def tearDown(self): 46 | pass 47 | 48 | 49 | if __name__ == "__main__": 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /source/DPAD/tools/tests/test_parse_tools.py: -------------------------------------------------------------------------------- 1 | """ Omid Sani, Shanechi Lab, University of Southern California, 2020 """ 2 | 3 | # pylint: disable=C0103, C0111 4 | 5 | "Tests the module" 6 | 7 | import copy 8 | import os 9 | import pickle 10 | import sys 11 | import unittest 12 | 13 | sys.path.insert(0, os.path.dirname(__file__)) 14 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) 15 | 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | import tensorflow as tf 19 | from DPAD.tools.parse_tools import parseMethodCodeArgOptimizer 20 | 21 | 22 | class TestParseTools(unittest.TestCase): 23 | 24 | def test_optimizers(self): 25 | codes = [ 26 | "DPAD_LR1e-3", 27 | "DPAD_LR1e-3_optAdamW_ErSV128", 28 | "DPAD_LR1e-3_WD1e-3_optAdamW_scCDR2000_ErSV128", 29 | "DPAD_LR1e-3_optAdamW_scCD2000_ErSV128", 30 | "DPAD_LR1e-3_optAdamW_scED2000_0.96_ErSV128", 31 | "DPAD_LR1e-3_optAdamW_scED2000_0.96_strT_ErSV128", 32 | "DPAD_LR1e-3_optAdamW_scITD2000_0.96_ErSV128", 33 | "DPAD_LR1e-3_optAdamW_scPD2000_1e-4_ErSV128", 34 | "DPAD_LR1e-3_optAdamW_scPD2000_1e-4_cycT_ErSV128", 35 | "DPAD_LR1e-3_optAdamW_scPD2000_1e-4_2_cycT_ErSV128", 36 | ] 37 | 38 | fig = plt.figure(figsize=(10, 8)) 39 | ax = fig.add_subplot(1, 3, (1, 2)) 40 | lineStyles = ["-", "--", "-.", ":"] 41 | 42 | for mi, methodCode in enumerate(codes): 43 | lr_scheduler_name = None 44 | lr_scheduler_args = None 45 | 46 | optimizer_name = "Adam" # default 47 | optimizer_args = None 48 | optimizer_infos, matches = parseMethodCodeArgOptimizer(methodCode) 49 | if len(optimizer_infos) > 0: 50 | optimizer_info = optimizer_infos[0] 51 | if "optimizer_name" in optimizer_info: 52 | optimizer_name = optimizer_info["optimizer_name"] 53 | if "optimizer_args" in optimizer_info: 54 | optimizer_args = optimizer_info["optimizer_args"] 55 | if "scheduler_name" in optimizer_info: 56 | lr_scheduler_name = optimizer_info["scheduler_name"] 57 | if "scheduler_args" in optimizer_info: 58 | lr_scheduler_args = optimizer_info["scheduler_args"] 59 | 60 | if lr_scheduler_args is None: 61 | lr_scheduler_args = {} 62 | if optimizer_args is None: 63 | optimizer_args = {} 64 | optimizer_args_BU = copy.deepcopy(optimizer_args) 65 | 66 | if isinstance(lr_scheduler_name, str): 67 | if hasattr(tf.keras.optimizers.schedules, lr_scheduler_name): 68 | lr_scheduler_constructor = getattr( 69 | tf.keras.optimizers.schedules, lr_scheduler_name 70 | ) 71 | else: 72 | raise Exception( 73 | "Learning rate scheduler {lr_scheduler_name} not supported as string, pass actual class for the optimizer (e.g. tf.keras.optimizers.Adam)" 74 | ) 75 | else: 76 | lr_scheduler_constructor = lr_scheduler_name 77 | if isinstance(optimizer_name, str): 78 | if hasattr(tf.keras.optimizers, optimizer_name): 79 | optimizer_constructor = getattr(tf.keras.optimizers, optimizer_name) 80 | else: 81 | raise Exception( 82 | "optimizer not supported as string, pass actual class for the optimizer (e.g. 
tf.keras.optimizers.Adam)" 83 | ) 84 | else: 85 | optimizer_constructor = optimizer_name 86 | if lr_scheduler_constructor is not None: 87 | if ( 88 | "learning_rate" in optimizer_args 89 | and "initial_learning_rate" not in lr_scheduler_args 90 | ): 91 | lr_scheduler_args["initial_learning_rate"] = optimizer_args[ 92 | "learning_rate" 93 | ] 94 | lr_scheduler = lr_scheduler_constructor(**lr_scheduler_args) 95 | optimizer_args["learning_rate"] = lr_scheduler 96 | else: 97 | lr_scheduler = lambda steps: optimizer_args[ 98 | "learning_rate" 99 | ] * np.ones_like(steps) 100 | optimizer = optimizer_constructor(**optimizer_args) 101 | 102 | epochs = 2000 103 | batches = 20 104 | steps = np.arange(epochs * batches) 105 | lr = np.array(lr_scheduler(steps)) 106 | 107 | ax.plot( 108 | steps, 109 | lr, 110 | label=f"{methodCode}\nOptimizer: {optimizer_name}, {optimizer_args_BU}, Scheduler: {lr_scheduler_name}\n{lr_scheduler_args}", 111 | linestyle=lineStyles[mi % len(lineStyles)], 112 | ) 113 | 114 | print(f"Done with {lr_scheduler_name}") 115 | 116 | ax.set_xlabel(f"Training steps") 117 | ax.set_ylabel(f"Learning rate") 118 | ax.legend( 119 | bbox_to_anchor=(1.04, 0.5), 120 | loc="center left", 121 | borderaxespad=0, 122 | fontsize="x-small", 123 | ) 124 | plt.show() 125 | 126 | print("Test!") 127 | 128 | 129 | if __name__ == "__main__": 130 | unittest.main() 131 | -------------------------------------------------------------------------------- /source/DPAD/tools/tests/test_tf_losses.py: -------------------------------------------------------------------------------- 1 | """ Omid Sani, Shanechi Lab, University of Southern California, 2020 """ 2 | 3 | # pylint: disable=C0103, C0111 4 | 5 | "Tests tensorflow losses" 6 | 7 | import copy 8 | import os 9 | import sys 10 | import unittest 11 | 12 | sys.path.insert(0, os.path.dirname(__file__)) 13 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) 14 | 15 | import numpy as np 16 | from tools.evaluation import evalPrediction 17 | from tools.tf_losses import ( 18 | computeR2_masked, 19 | masked_CategoricalCrossentropy, 20 | masked_CC, 21 | masked_mse, 22 | masked_negativeCC, 23 | masked_negativeR2, 24 | masked_PoissonLL_loss, 25 | masked_R2, 26 | masked_SparseCategoricalCrossentropy, 27 | ) 28 | from tools.tools import get_one_hot 29 | 30 | # print('__file__={0:<35} | __name__={1:<20} | __package__={2:<20}'.format(__file__,__name__,str(__package__))) 31 | 32 | 33 | def prep_CCE_case(missing_marker=-1): 34 | # For 4D signals 35 | np.random.seed(42) 36 | 37 | N, nb, nx, nc = 100, 32, 2, 3 38 | ZClass = np.random.randint(0, nc, (N, nb, nx)) 39 | Z = get_one_hot(ZClass, nc) 40 | ZProb = np.random.rand(N, nb, nx, nc) 41 | ZProb = ZProb / np.sum(ZProb, axis=-1)[..., np.newaxis] 42 | ZLogProb = np.log(ZProb) 43 | missing_inds = np.random.randint(0, N, int(np.round(0.4 * N))) 44 | Z[missing_inds, 0] = missing_marker 45 | 46 | ZR = np.reshape(Z, [np.prod(Z.shape[0:2]), Z.shape[-2], Z.shape[-1]]) 47 | ZLogProbR = np.reshape( 48 | ZLogProb, [np.prod(ZLogProb.shape[0:2]), ZLogProb.shape[-2], ZLogProb.shape[-1]] 49 | ) 50 | 51 | isOk = np.all(ZR != missing_marker, axis=-1) 52 | isOk = np.all(isOk, axis=-1) 53 | 54 | ZROk = ZR[isOk, ...] 55 | ZLogProbROk = ZLogProbR[isOk, ...] 
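# For one-hot targets, categorical cross-entropy reduces to
# -E[sum(one_hot * log_prob)] over the kept (unmasked) samples; the reference
# value below is computed directly from that definition.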
56 | 57 | CCE = -np.mean(np.sum(ZROk * ZLogProbROk, axis=-1)) 58 | 59 | return Z, ZClass, ZLogProb, CCE 60 | 61 | 62 | def assert_params_are_close(s, sBack): 63 | skipParams = [] 64 | impossibleParams = [] 65 | okParams = [] 66 | errorParams = [] 67 | errorParamsErr = [] 68 | 69 | for p, valOrig in s.getListOfParams().items(): 70 | if not hasattr(sBack, p): 71 | skipParams.append(p) 72 | continue 73 | valNew = getattr(sBack, p) 74 | if valNew is not None: 75 | try: 76 | np.testing.assert_allclose(valNew, valOrig, rtol=1e-3, atol=1e-6) 77 | okParams.append(p) 78 | except Exception as e: 79 | errorParams.append(p) 80 | errorParamsErr.append(e) 81 | continue 82 | else: 83 | impossibleParams.append(p) 84 | 85 | return skipParams, impossibleParams, okParams, errorParams, errorParamsErr 86 | 87 | 88 | class TestLosses(unittest.TestCase): 89 | 90 | def test_masked_mse_for2D(self): 91 | # For 2D signals 92 | np.random.seed(42) 93 | 94 | N, nx = 100, 3 95 | missing_marker = -1 96 | Z = np.random.rand(N, nx) 97 | ZHat = Z + 0.5 * np.random.rand(N, nx) 98 | missing_inds = np.random.randint(0, N, int(np.round(0.4 * N))) 99 | Z[missing_inds, 0] = missing_marker 100 | 101 | isOk = np.all(Z != missing_marker, axis=-1) 102 | 103 | expected = np.mean(np.power(Z[isOk, :] - ZHat[isOk, :], 2)) 104 | 105 | lossFunc = masked_mse(missing_marker) 106 | computed = float(lossFunc(Z, ZHat)) 107 | np.testing.assert_allclose(expected, computed, rtol=1e-3) 108 | 109 | def test_masked_mse_for3D(self): 110 | # For 3D signals 111 | np.random.seed(42) 112 | 113 | N, nb, nx = 100, 32, 3 114 | missing_marker = -1 115 | Z = np.random.rand(N, nb, nx) 116 | ZHat = Z + 0.5 * np.random.rand(N, nb, nx) 117 | missing_inds = np.random.randint(0, N, int(np.round(0.4 * N))) 118 | Z[missing_inds, 0] = missing_marker 119 | 120 | isOk = np.all(Z != missing_marker, axis=-1) 121 | 122 | expected = np.mean(np.power(Z[isOk, ...] 
- ZHat[isOk, ...], 2)) 123 | 124 | lossFunc = masked_mse(missing_marker) 125 | computed = float(lossFunc(Z, ZHat)) 126 | np.testing.assert_allclose(expected, computed, rtol=1e-3) 127 | 128 | def test_masked_CC_for2D(self): 129 | # For 2D signals 130 | np.random.seed(42) 131 | 132 | N, nx = 100, 3 133 | missing_marker = -1 134 | Z = np.random.rand(N, nx) 135 | ZHat = Z + 0.5 * np.random.rand(N, nx) 136 | missing_inds = np.random.randint(0, N, int(np.round(0.4 * N))) 137 | Z[missing_inds, 0] = missing_marker 138 | 139 | isOk = np.all(Z != missing_marker, axis=-1) 140 | 141 | expected = np.mean(evalPrediction(Z[isOk, :], ZHat[isOk, :], "CC")) 142 | 143 | lossFunc = masked_CC(missing_marker) 144 | computed = float(lossFunc(Z, ZHat)) 145 | np.testing.assert_allclose(expected, computed, rtol=1e-3) 146 | 147 | lossFuncNeg = masked_negativeCC(missing_marker) 148 | computedNeg = float(lossFuncNeg(Z, ZHat)) 149 | np.testing.assert_allclose(-expected, computedNeg, rtol=1e-3) 150 | 151 | def test_masked_R2_for2D(self): 152 | # For 2D signals 153 | np.random.seed(42) 154 | 155 | N, nx = 100, 3 156 | missing_marker = -1 157 | for test_num in range(10): 158 | Z = np.random.rand(N, nx) 159 | ZHat = Z + 0.5 * np.random.rand(N, nx) 160 | missing_inds = np.random.randint(0, N, int(np.round(0.4 * N))) 161 | Z[missing_inds, 0] = missing_marker 162 | 163 | flat_chans = int(np.round(np.random.rand())) 164 | Z[:, :flat_chans] = np.mean(Z[:, :flat_chans], axis=0) 165 | 166 | isOk = np.all(Z != missing_marker, axis=-1) 167 | 168 | allR2_expected = evalPrediction(Z[isOk, :], ZHat[isOk, :], "R2") 169 | allR2_computed = np.array( 170 | computeR2_masked(Z, ZHat, missing_marker), dtype=float 171 | ) 172 | 173 | np.testing.assert_allclose(allR2_expected, allR2_computed, rtol=1e-3) 174 | 175 | expected = np.mean(allR2_expected) 176 | 177 | lossFunc = masked_R2(missing_marker) 178 | computed = float(lossFunc(Z, ZHat)) 179 | np.testing.assert_allclose(expected, computed, rtol=1e-3) 180 | 181 | lossFuncNeg = masked_negativeR2(missing_marker) 182 | computedNeg = float(lossFuncNeg(Z, ZHat)) 183 | np.testing.assert_allclose(-expected, computedNeg, rtol=1e-3) 184 | 185 | def test_masked_PoissonLL_loss_for3D(self): 186 | # For 3D signals 187 | np.random.seed(42) 188 | 189 | N, nb, nx = 100, 32, 3 190 | missing_marker = -1 191 | logR = np.random.randn(N, nb, nx) # Log rates 192 | R = np.exp(logR) # Rates 193 | Z = np.random.poisson(R) # Counts 194 | RHat = np.exp(logR + 0.5 * np.random.randn(N, nb, nx)) 195 | missing_inds = np.random.randint(0, N, int(np.round(0.4 * N))) 196 | Z[missing_inds, 0] = missing_marker 197 | 198 | isOk = np.all(Z != missing_marker, axis=-1) 199 | 200 | ZOk = Z[isOk, ...] 201 | RHatOk = RHat[isOk, ...] 
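# The Poisson log-likelihood of a count z under rate r is z*log(r) - r - log(z!);
# dropping the rate-independent log(z!) term and negating gives the Keras
# Poisson loss form used as the reference below.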
202 | 203 | # loss = y_pred - y_true * log(y_pred) 204 | expected = np.mean(RHatOk - ZOk * np.log(RHatOk)) 205 | 206 | lossFunc = masked_PoissonLL_loss(missing_marker) 207 | computed = float(lossFunc(Z, RHat)) 208 | np.testing.assert_allclose(expected, computed, rtol=1e-3) 209 | 210 | def test_masked_CategoricalCrossentropy(self): 211 | missing_marker = -1 212 | Z, ZClass, ZLogProb, expected = prep_CCE_case(missing_marker) 213 | lossFunc = masked_CategoricalCrossentropy(missing_marker) 214 | computed = float(lossFunc(Z, ZLogProb)) 215 | np.testing.assert_allclose(expected, computed, rtol=1e-3) 216 | 217 | def test_masked_SparseCategoricalCrossentropy(self): 218 | missing_marker = -1 219 | Z, ZClass, ZLogProb, expected = prep_CCE_case(missing_marker) 220 | lossFunc = masked_SparseCategoricalCrossentropy(missing_marker) 221 | computed = float(lossFunc(ZClass, ZLogProb)) 222 | np.testing.assert_allclose(expected, computed, rtol=1e-2) 223 | 224 | def tearDown(self): 225 | pass 226 | 227 | 228 | if __name__ == "__main__": 229 | unittest.main() 230 | -------------------------------------------------------------------------------- /source/DPAD/tools/tests/test_tools.py: -------------------------------------------------------------------------------- 1 | """ Omid Sani, Shanechi Lab, University of Southern California, 2020 """ 2 | 3 | # pylint: disable=C0103, C0111 4 | 5 | "Tests the module" 6 | 7 | import copy 8 | import os 9 | import sys 10 | import unittest 11 | 12 | sys.path.insert(0, os.path.dirname(__file__)) 13 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) 14 | 15 | import numpy as np 16 | from scipy import linalg 17 | from tools.tools import ( 18 | extractDiagonalBlocks, 19 | getBlockIndsFromBLKSArray, 20 | isClockwise, 21 | shortenGaps, 22 | standardizeStateTrajectory, 23 | ) 24 | 25 | 26 | class TestTools(unittest.TestCase): 27 | 28 | def test_extractDiagonalBlocks(self): 29 | testCases = [ 30 | [ 31 | { 32 | "A": linalg.block_diag( 33 | 1, 2 * np.ones((3, 3)), 3, 4 * np.ones((2, 2)) 34 | ), 35 | }, 36 | np.array([1, 3, 1, 2]), 37 | ], 38 | [ 39 | { 40 | "A": linalg.block_diag(5 * np.ones((5, 5))), 41 | }, 42 | np.array([5]), 43 | ], 44 | [ 45 | { 46 | "A": linalg.block_diag(1 * np.ones((5, 5)), 2 * np.ones((3, 3))), 47 | }, 48 | np.array([5, 3]), 49 | ], 50 | [{"A": np.array([[1, 2], [0, 3]]), "emptySide": "lower"}, np.array([1, 1])], 51 | [{"A": np.array([[1, 2], [0, 3]]), "emptySide": "upper"}, np.array([2])], 52 | ] 53 | for case in testCases: 54 | input_args = case[0] 55 | output_correct = case[1] 56 | 57 | BLKS = extractDiagonalBlocks(**input_args) 58 | np.testing.assert_array_equal(BLKS, output_correct) 59 | 60 | def test_getBlockIndsFromBLKSArray(self): 61 | testCases = [ 62 | [{"BLKS": [1]}, np.array([[0, 1]])], 63 | [{"BLKS": [2]}, np.array([[0, 2]])], 64 | [{"BLKS": [1, 2]}, np.array([[0, 1], [1, 3]])], 65 | [{"BLKS": [1, 3, 1]}, np.array([[0, 1], [1, 4], [4, 5]])], 66 | ] 67 | for case in testCases: 68 | input_args = case[0] 69 | output_correct = case[1] 70 | 71 | groups = getBlockIndsFromBLKSArray(**input_args) 72 | np.testing.assert_array_equal(groups, output_correct) 73 | 74 | def test_shortenGaps(self): 75 | N = 1000 76 | T = 100 77 | t = np.random.rand(N) * T 78 | t = t[np.argsort(t)][:, np.newaxis] 79 | timeCopy = copy.copy(t) 80 | dT = np.median(np.diff(t)) 81 | 82 | tNew, timeRemapper = shortenGaps(t) 83 | 84 | timeCopyRemap = timeRemapper.apply(timeCopy) 85 | np.testing.assert_allclose(tNew, timeCopyRemap) 86 | 87 | t = np.arange(N) 88 | 
tNew, timeRemapper = shortenGaps(np.array(t)) 89 | timeCopyRemap = timeRemapper.apply(copy.copy(t)) 90 | np.testing.assert_allclose(tNew, t) 91 | np.testing.assert_allclose(tNew, timeCopyRemap) 92 | 93 | def test_isClockwise(self): 94 | np.random.seed(42) 95 | 96 | num_tests = 1000 97 | for ti in range(num_tests): 98 | theta0, theta1 = np.sort(np.random.rand(2) * np.pi * 2) 99 | theta = np.linspace(theta0, theta1, 100) 100 | c = (np.random.rand(2) - 0.5) * 2 101 | r = np.random.rand(2) 102 | x = c[0] + r[0] * np.cos(theta) 103 | y = c[1] + 10 * r[1] * np.sin(theta) 104 | data = np.concatenate((x[:, np.newaxis], y[:, np.newaxis]), axis=1) 105 | 106 | rotTheta = (np.random.rand(1)[0] - 0.5) * np.pi 107 | R = np.array( 108 | [ 109 | [np.cos(rotTheta), -np.sin(rotTheta)], 110 | [np.sin(rotTheta), np.cos(rotTheta)], 111 | ] 112 | ) 113 | data = (R @ data.T).T 114 | 115 | obsInds = np.sort( 116 | np.random.randint(0, theta.size, np.random.randint(5, 30)) 117 | ) 118 | 119 | isCW = isClockwise(data[obsInds, :]) 120 | np.testing.assert_equal(isCW, False) 121 | 122 | isCW = isClockwise(np.flipud(data[obsInds, :])) 123 | np.testing.assert_equal(isCW, True) 124 | 125 | def test_standardizeStateTrajectory(self): 126 | np.random.seed(42) 127 | 128 | num_tests = 1000 129 | for ti in range(num_tests): 130 | theta0, theta1 = np.sort(np.random.rand(2) * np.pi * 2) 131 | theta = np.linspace(theta0, theta1, 100) 132 | elev0, elev1 = np.sort(np.random.rand(2) * np.pi * 2) 133 | elev = np.linspace(elev0, elev1, 100) 134 | c = (np.random.rand(3) - 0.5) * 2 135 | r = np.random.rand(3) 136 | x = c[0] + r[0] * np.cos(theta) 137 | y = c[1] + 10 * r[1] * np.sin(theta) 138 | z = c[2] + 5 * r[2] * np.sin(elev) 139 | data = np.concatenate( 140 | (x[:, np.newaxis], y[:, np.newaxis], z[:, np.newaxis]), axis=1 141 | ) 142 | 143 | rotTheta = (np.random.rand(1)[0] - 0.5) * np.pi 144 | R = np.array( 145 | [ 146 | [np.cos(rotTheta), -np.sin(rotTheta)], 147 | [np.sin(rotTheta), np.cos(rotTheta)], 148 | ] 149 | ) 150 | R1 = np.block([[R, np.zeros((2, 1))], [np.zeros((1, 2)), np.ones(1)]]) 151 | 152 | rotElev = (np.random.rand(1)[0] - 0.5) * np.pi 153 | R = np.array( 154 | [ 155 | [np.cos(rotElev), -np.sin(rotElev)], 156 | [np.sin(rotElev), np.cos(rotElev)], 157 | ] 158 | ) 159 | R2 = np.block([[np.ones(1), np.zeros((1, 2))], [np.zeros((2, 1)), R]]) 160 | data = (R2 @ R1 @ data.T).T 161 | 162 | xMean = np.random.rand(3) 163 | data = data + xMean 164 | 165 | obsInds = np.unique( 166 | np.random.randint(0, theta.size, np.random.randint(25, 40)) 167 | ) 168 | 169 | for nx in [1, 3]: 170 | xTest = data[obsInds, :nx] 171 | xTestN, E, X0 = standardizeStateTrajectory(xTest, generate_plot=False) 172 | 173 | # Outputs should describe the same similarity transform 174 | xTestNExpected = (E @ (xTest - X0).T).T 175 | np.testing.assert_allclose(xTestN, xTestNExpected) 176 | 177 | # Zero mean 178 | np.testing.assert_almost_equal(np.mean(xTestN, axis=0), np.zeros(nx)) 179 | 180 | # Start from the positive side of the x-axis (have no y-element in the start) 181 | np.testing.assert_array_less(0, xTestN[0, 0]) 182 | if nx > 1: 183 | np.testing.assert_almost_equal(xTestN[0, 1], 0) 184 | 185 | # Be counter clockwise on the xy-plane 186 | isCW = isClockwise(xTestN[:, :2]) 187 | np.testing.assert_equal(isCW, False) 188 | 189 | 190 | if __name__ == "__main__": 191 | unittest.main() 192 | -------------------------------------------------------------------------------- /source/DPAD/tools/tf_losses.py: 
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024 University of Southern California
3 | See full notice in LICENSE.md
4 | Omid G. Sani and Maryam M. Shanechi
5 | Shanechi Lab, University of Southern California
6 | """
7 | 
8 | """Tensorflow losses"""
9 | 
10 | import tensorflow as tf
11 | 
12 | 
13 | def masked_mse(mask_value=None):
14 | """Returns a tf MSE loss computation function, but with support for setting one value as a mask indicator
15 | 
16 | Args:
17 | mask_value (numpy value, optional): if not None, will treat this value as mask indicator. Defaults to None.
18 | """
19 | 
20 | def f(y_true, y_pred):
21 | # mse = tf.reduce_mean(tf.math.squared_difference(y_pred, y_true), axis=-1) # Without handling NaNs
22 | # Assumes that the last dimension is the only data dimension, others are sample dimensions (batch,time,etc)
23 | sh = tf.shape(y_true)
24 | y_true_r = tf.reshape(y_true, [tf.reduce_prod(sh[:-1]), sh[-1]])
25 | y_pred_r = tf.reshape(y_pred, [tf.reduce_prod(sh[:-1]), sh[-1]])
26 | y_true_f = tf.cast(y_true_r, dtype=y_pred.dtype)
27 | y_pred_f = tf.cast(y_pred_r, dtype=y_pred.dtype)
28 | if mask_value is not None:
29 | mask_value_cast = tf.constant(mask_value, dtype=y_true_f.dtype)
30 | isOk = tf.not_equal(y_true_f, mask_value_cast)
31 | else:
32 | isOk = tf.ones_like(y_true_f, dtype=bool)
33 | isOk1 = tf.math.reduce_all(isOk, axis=-1)
34 | y_true_masked = tf.boolean_mask(y_true_f, isOk1, axis=0)
35 | y_pred_masked = tf.boolean_mask(y_pred_f, isOk1, axis=0)
36 | lossFunc = tf.keras.losses.MeanSquaredError()
37 | return lossFunc(y_true_masked, y_pred_masked)
38 | 
39 | f.__name__ = str("MSE_maskV_{}".format(mask_value))
40 | return f
41 | 
42 | 
43 | def compute_CC(x, y): # https://stackoverflow.com/a/58890795/2275605
44 | """Computes correlation coefficient (CC) in tensorflow
45 | 
46 | Args:
47 | x (numpy array): input 1
48 | y (numpy array): input 2
49 | 
50 | Returns:
51 | tf.Tensor: CC value
52 | """
53 | mx = tf.math.reduce_mean(x)
54 | my = tf.math.reduce_mean(y)
55 | xm, ym = x - mx, y - my
56 | r_num = tf.math.reduce_mean(tf.multiply(xm, ym))
57 | r_den = tf.math.reduce_std(xm) * tf.math.reduce_std(ym)
58 | return r_num / r_den
59 | 
60 | 
61 | def compute_R2(y_true, y_pred): # https://stackoverflow.com/a/58890795/2275605
62 | """Computes coefficient of determination (R2) in tensorflow
63 | 
64 | Args:
65 | y_true (numpy array): ground truth values
66 | y_pred (numpy array): predicted values
67 | 
68 | Returns:
69 | tf.Tensor: R2 value for each dimension (last axis)
70 | """
71 | m_true = tf.math.reduce_mean(y_true, axis=0)
72 | 
73 | r_num = tf.math.reduce_sum(tf.math.pow(y_true - y_pred, 2), axis=0)
74 | r_den = tf.math.reduce_sum(tf.math.pow(y_true - m_true, 2), axis=0)
75 | 
76 | R2 = 1 - (r_num / r_den)
77 | 
78 | isFlat = (tf.reduce_max(y_true, axis=0) - tf.reduce_min(y_true, axis=0)) < 1e-9 # Flat (constant) true channels would make r_den zero, so their R2 is set to 0
79 | R2 = tf.where(isFlat, tf.zeros_like(R2), R2)
80 | return R2
81 | 
82 | 
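# Sanity check (illustrative, not part of the original module): for 1-D float
# tensors, compute_CC(x, y) agrees with np.corrcoef(x, y)[0, 1], and compute_R2
# matches the usual 1 - SS_res / SS_tot definition evaluated per column.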
83 | def computeCC_masked(y_true, y_pred, mask_value=None):
84 | """Computes correlation coefficient (CC) in tensorflow, with support for a masked value.
85 | First dimension of data is the sample/time dimension. If a sample has a mask_value in
86 | one of its dimensions, it will be discarded before the CC computation.
87 | Args:
88 | y_true (numpy array): input 1.
89 | y_pred (numpy array): input 2
90 | mask_value (numpy value, optional): if not None, will treat this value as mask indicator. Defaults to None.
91 | 
92 | Returns:
93 | tf.Tensor: CC value
94 | """
95 | # Assumes that the last dimension is the only data dimension, others are sample dimensions (batch,time,etc)
96 | sh = tf.shape(y_true)
97 | y_true_r = tf.reshape(y_true, [tf.reduce_prod(sh[:-1]), sh[-1]])
98 | y_pred_r = tf.reshape(y_pred, [tf.reduce_prod(sh[:-1]), sh[-1]])
99 | y_true_f = tf.cast(y_true_r, dtype=y_pred.dtype)
100 | y_pred_f = tf.cast(y_pred_r, dtype=y_pred.dtype)
101 | if mask_value is not None:
102 | mask_value_cast = tf.constant(mask_value, dtype=y_true_f.dtype)
103 | isOk = tf.not_equal(y_true_f, mask_value_cast)
104 | else:
105 | isOk = tf.ones_like(y_true_f, dtype=bool)
106 | isOk1 = tf.math.reduce_all(isOk, axis=-1)
107 | y_true_masked = tf.boolean_mask(y_true_f, isOk1, axis=0)
108 | y_pred_masked = tf.boolean_mask(y_pred_f, isOk1, axis=0)
109 | CC = compute_CC(y_true_masked, y_pred_masked)
110 | return CC
111 | 
112 | 
113 | def computeR2_masked(y_true, y_pred, mask_value=None):
114 | """Computes coefficient of determination (R2) in tensorflow, with support for a masked value.
115 | First dimension of data is the sample/time dimension. If a sample has a mask_value in
116 | one of its dimensions, it will be discarded before the R2 computation.
117 | Args:
118 | y_true (numpy array): input 1.
119 | y_pred (numpy array): input 2
120 | mask_value (numpy value, optional): if not None, will treat this value as mask indicator. Defaults to None.
121 | 
122 | Returns:
123 | tf.Tensor: R2 value
124 | """
125 | # Assumes that the last dimension is the only data dimension, others are sample dimensions (batch,time,etc)
126 | sh = tf.shape(y_true)
127 | y_true_r = tf.reshape(y_true, [tf.reduce_prod(sh[:-1]), sh[-1]])
128 | y_pred_r = tf.reshape(y_pred, [tf.reduce_prod(sh[:-1]), sh[-1]])
129 | y_true_f = tf.cast(y_true_r, dtype=y_pred.dtype)
130 | y_pred_f = tf.cast(y_pred_r, dtype=y_pred.dtype)
131 | if mask_value is not None:
132 | mask_value_cast = tf.constant(mask_value, dtype=y_true_f.dtype)
133 | isOk = tf.not_equal(y_true_f, mask_value_cast)
134 | else:
135 | isOk = tf.ones_like(y_true_f, dtype=bool)
136 | isOk1 = tf.math.reduce_all(isOk, axis=-1)
137 | y_true_masked = tf.boolean_mask(y_true_f, isOk1, axis=0)
138 | y_pred_masked = tf.boolean_mask(y_pred_f, isOk1, axis=0)
139 | R2 = compute_R2(y_true_masked, y_pred_masked)
140 | return R2
141 | 
142 | 
143 | def masked_CC(mask_value=None):
144 | """Returns a tf correlation coefficient (CC) computation function, but with support for setting one value as a mask indicator.
145 | Takes mean of CC across dimensions. See computeCC_masked for details of computing CC for each dimension.
146 | Args:
147 | mask_value (numpy value, optional): if not None, will treat this value as mask indicator. Defaults to None.
148 | """
149 | 
150 | def f(y_true, y_pred):
151 | meanCC = tf.math.reduce_mean(
152 | computeCC_masked(y_true, y_pred, mask_value)
153 | ) # Average across dimensions
154 | return meanCC
155 | 
156 | f.__name__ = str("CC_maskV_{}".format(mask_value))
157 | return f
158 | 
159 | 
160 | def masked_R2(mask_value=None):
161 | """Returns a tf R2 computation function, but with support for setting one value as a mask indicator.
162 | Takes mean of R2 across dimensions. See computeR2_masked for details of computing R2 for each dimension.
163 | Args:
164 | mask_value (numpy value, optional): if not None, will treat this value as mask indicator. Defaults to None.
165 | """
166 | 
167 | def f(y_true, y_pred):
168 | allR2 = computeR2_masked(y_true, y_pred, mask_value)
169 | meanR2 = tf.math.reduce_mean(allR2) # Average across dimensions
170 | return meanR2
171 | 
172 | f.__name__ = str("R2_maskV_{}".format(mask_value))
173 | return f
174 | 
175 | 
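# Illustrative usage (an assumption, not from the original code): these factory
# functions return standard Keras-compatible losses/metrics, e.g.
#   model.compile(optimizer="adam", loss=masked_mse(-1.0),
#                 metrics=[masked_CC(-1.0), masked_R2(-1.0)])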
176 | def masked_negativeCC(mask_value=None):
177 | """Returns a tf negative correlation coefficient (CC) computation function, but with support for setting one value as a mask indicator.
178 | Takes mean of negative CC across dimensions. See computeCC_masked for details of computing CC for each dimension.
179 | Args:
180 | mask_value (numpy value, optional): if not None, will treat this value as mask indicator. Defaults to None.
181 | """
182 | 
183 | def f(y_true, y_pred):
184 | meanCC = tf.math.reduce_mean(
185 | computeCC_masked(y_true, y_pred, mask_value)
186 | ) # Average across dimensions
187 | return -meanCC
188 | 
189 | f.__name__ = str("negCC_maskV_{}".format(mask_value))
190 | return f
191 | 
192 | 
193 | def masked_negativeR2(mask_value=None):
194 | """Returns a tf negative coefficient of determination (R2) computation function, but with support for setting one value as a mask indicator.
195 | Takes mean of negative R2 across dimensions. See computeR2_masked for details of computing R2 for each dimension.
196 | Args:
197 | mask_value (numpy value, optional): if not None, will treat this value as mask indicator. Defaults to None.
198 | """
199 | 
200 | def f(y_true, y_pred):
201 | meanR2 = tf.math.reduce_mean(
202 | computeR2_masked(y_true, y_pred, mask_value)
203 | ) # Average across dimensions
204 | return -meanR2
205 | 
206 | f.__name__ = str("negR2_maskV_{}".format(mask_value))
207 | return f
208 | 
209 | 
210 | def masked_PoissonLL_loss(mask_value=None):
211 | """Returns a tf function that computes the poisson negative log likelihood loss, with support for setting one value as a mask indicator.
212 | First dimension of data is the sample/time dimension. If a sample has a mask_value in
213 | one of its dimensions, it will be discarded before the loss computation.
214 | 
215 | Args:
216 | mask_value (numpy value, optional): if not None, will treat this value as mask indicator. Defaults to None.
217 | """
218 | 
219 | def f(true_counts, pred_logLambda):
220 | sh = tf.shape(true_counts)
221 | true_counts_f = tf.reshape(true_counts, [tf.reduce_prod(sh[:-1]), sh[-1]])
222 | pred_logLambda_f = tf.reshape(pred_logLambda, [tf.reduce_prod(sh[:-1]), sh[-1]])
223 | if mask_value is not None:
224 | mask_value_cast = tf.constant(int(mask_value), dtype=true_counts_f.dtype)
225 | isOk = tf.not_equal(true_counts_f, mask_value_cast)
226 | else:
227 | isOk = tf.ones_like(true_counts_f, dtype=bool)
228 | isOk1 = tf.math.reduce_all(isOk, axis=-1)
229 | y_true_masked = tf.boolean_mask(true_counts_f, isOk1, axis=0)
230 | y_pred_masked = tf.boolean_mask(pred_logLambda_f, isOk1, axis=0)
231 | # LL = true_counts_f * pred_logLambda_f - tf.math.exp(pred_logLambda_f) - tf.math.lgamma( true_counts_f+1 )
232 | # pLoss = - tf.reduce_mean(tf.boolean_mask(LL, isOk))
233 | # https://www.tensorflow.org/api_docs/python/tf/keras/losses/poisson
234 | lossFunc = tf.keras.losses.Poisson()
235 | return lossFunc(y_true_masked, y_pred_masked)
236 | 
237 | f.__name__ = str("PoissonLL_maskV_{}".format(mask_value))
238 | return f
239 | 
240 | 
241 | def masked_CategoricalCrossentropy(mask_value=None):
242 | """Returns a tf function that computes the Categorical Crossentropy loss, but with support for setting one value as a mask indicator.
243 | First dimension of data is the sample/time dimension. If a sample has a mask_value in 244 | one of its dimensions, it will be discarded before the loss computation. 245 | 246 | Args: 247 | mask_value (numpy value, optional): if not None, will treat this value as mask indicator. Defaults to None. 248 | """ 249 | 250 | def f(y_true, y_pred): 251 | # Assumes that the last two dimensions are the only data dimensions, others are sample dimensions (batch,time,etc) 252 | sh = tf.shape(y_true) 253 | y_true = tf.reshape(y_true, [tf.reduce_prod(sh[:-2]), sh[-2], sh[-1]]) 254 | y_pred = tf.reshape(y_pred, [tf.reduce_prod(sh[:-2]), sh[-2], sh[-1]]) 255 | if mask_value is not None: 256 | mask_value_cast = tf.constant(int(mask_value), dtype=y_true.dtype) 257 | isOk = tf.not_equal(y_true, mask_value_cast) 258 | else: 259 | isOk = tf.ones_like(y_true, dtype=bool) 260 | isOk1 = tf.math.reduce_all( 261 | isOk, axis=tf.range(tf.rank(isOk) - 2, tf.rank(isOk)) 262 | ) 263 | y_true_masked = tf.boolean_mask(y_true, isOk1, axis=0) 264 | y_pred_masked = tf.boolean_mask(y_pred, isOk1, axis=0) 265 | lossFunc = tf.keras.losses.CategoricalCrossentropy( 266 | from_logits=True 267 | ) # Later will need softmax for pred_model 268 | return lossFunc(y_true_masked, y_pred_masked) 269 | 270 | f.__name__ = str("CCE_maskV_{}".format(mask_value)) 271 | return f 272 | 273 | 274 | def masked_SparseCategoricalCrossentropy(mask_value=None): 275 | """Returns a tf function that computes the Sparse Categorical Crossentropy loss, but with support for setting one value as a mask indicator. 276 | First dimension of data is the sample/time dimension. If a sample has a mask_value in 277 | one of its dimensions, it will be discarded before the loss computation. 278 | 279 | Args: 280 | mask_value (numpy value, optional): if not None, will treat this value as mask indicator. Defaults to None. 281 | """ 282 | 283 | def f(y_true, y_pred): 284 | # Assumes that the last dimension is the only data dimension, others are sample dimensions (batch,time,etc) 285 | sh = tf.shape(y_true) 286 | y_true = tf.reshape(y_true, [tf.reduce_prod(sh[:-1]), sh[-1]]) 287 | sh2 = tf.shape(y_pred) 288 | y_pred = tf.reshape(y_pred, [tf.reduce_prod(sh2[:-2]), sh2[-2], sh2[-1]]) 289 | if mask_value is not None: 290 | mask_value_cast = tf.constant(int(mask_value), dtype=y_true.dtype) 291 | isOk = tf.not_equal(y_true, mask_value_cast) 292 | else: 293 | isOk = tf.ones_like(y_true, dtype=bool) 294 | isOk1 = tf.math.reduce_all(isOk, axis=-1) 295 | y_true_masked = tf.boolean_mask(y_true, isOk1, axis=0) 296 | y_pred_masked = tf.boolean_mask(y_pred, isOk1, axis=0) 297 | lossFunc = tf.keras.losses.SparseCategoricalCrossentropy( 298 | from_logits=True 299 | ) # Later will need softmax for pred_model 300 | return lossFunc(y_true_masked, y_pred_masked) 301 | 302 | f.__name__ = str("SCCE_maskV_{}".format(mask_value)) 303 | return f 304 | -------------------------------------------------------------------------------- /source/DPAD/tools/tf_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 University of Southern California 3 | See full notice in LICENSE.md 4 | Omid G. Sani and Maryam M. 
Shanechi 5 | Shanechi Lab, University of Southern California 6 | """ 7 | 8 | """Tensorflow tools""" 9 | 10 | import logging 11 | import os 12 | import time 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def set_global_tf_eagerly_flag(desired_flag=False): 21 | global_tf_eagerly_flag = ( 22 | tf.config.functions_run_eagerly() 23 | ) # Get global eager execution config 24 | if global_tf_eagerly_flag != desired_flag: 25 | logger.info( 26 | f"Changing global Tensorflow eagerly flag from {global_tf_eagerly_flag} to {desired_flag}" 27 | ) 28 | tf.config.run_functions_eagerly(desired_flag) # Disable global eager execution 29 | return global_tf_eagerly_flag 30 | 31 | 32 | def setupTensorflow(cpu=False): 33 | logger.info("Tensorflow version: {}".format(tf.__version__)) 34 | logger.info("Tensorflow path: {}".format(os.path.abspath(tf.__file__))) 35 | if cpu == False: 36 | gpus = tf.config.list_physical_devices("GPU") 37 | if gpus: 38 | try: 39 | # Currently, memory growth needs to be the same across GPUs 40 | for gpu in gpus: 41 | tf.config.experimental.set_memory_growth(gpu, True) 42 | logical_gpus = tf.config.list_logical_devices("GPU") 43 | logger.info( 44 | f"Found {len(logical_gpus)} Logical GPU(s) and {len(gpus)} Physical GPU(s): {gpus}" 45 | ) 46 | except RuntimeError as e: 47 | # Memory growth must be set before GPUs have been initialized 48 | logger.info(e) 49 | else: 50 | logger.info("No GPUs were found!") 51 | else: 52 | cpus = tf.config.list_physical_devices("CPU") 53 | logger.info("Using CPUs: {}".format(cpus)) 54 | pass 55 | 56 | 57 | def convertHistoryToDict(history, tic=None): 58 | """Converts tf model.fit history to a dictionary. 59 | 60 | Args: 61 | history (model.fit output): output of model.fit for a tf model. 62 | 63 | Returns: 64 | dict: dictionary form of the history. 65 | """ 66 | if tic is not None: 67 | toc = time.perf_counter() 68 | fit_time = toc - tic 69 | else: 70 | fit_time = None 71 | return { 72 | "epoch": history.epoch, 73 | "history": history.history, 74 | "params": history.params, 75 | "fit_time": fit_time, 76 | } 77 | 78 | 79 | def getModelFitHistoyStr( 80 | history=None, fields=["loss"], keep_ratio=1, history_dict=None, epoch=None 81 | ): 82 | """Prints a human readable summary of a tf model.fit history 83 | 84 | Args: 85 | history (model.fit output): output of model.fit for a tf model. Defaults to None. 86 | fields (list, optional): fields to print. Defaults to ['loss']. 87 | keep_ratio (int, optional): ratio of epochs to include in the log. Defaults to 1. 88 | history_dict (dict, optional): dictionary form of the history. Defaults to None. 89 | epoch (int, optional): number of epochs. Defaults to None. 
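Example (a sketch; assumes `history` is the object returned by a tf model.fit call):
    print(getModelFitHistoyStr(history, fields=["loss", "val_loss"], keep_ratio=0.1))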
90 | """ 91 | if epoch is None: 92 | epoch = history.epoch 93 | if history_dict is None: 94 | history_dict = history.history 95 | if keep_ratio < 1 and len(epoch) > 0: 96 | epochToPrint = list(range(0, len(epoch), int(np.ceil(keep_ratio * len(epoch))))) 97 | if (len(epoch) - 1) not in epochToPrint: 98 | epochToPrint.append(len(epoch) - 1) 99 | else: 100 | epochToPrint = range(len(epoch)) 101 | logStrAll = "" 102 | for ei in epochToPrint: 103 | logStr = "Epoch {}/{} - ".format(1 + epoch[ei], len(epoch)) 104 | metricNameStrs = [] 105 | metricVals = [] 106 | for f in fields: 107 | val = history_dict[f][ei] 108 | if val not in metricVals: 109 | metricVals.append(val) 110 | metricNameStrs.append(f) 111 | else: 112 | ind = metricVals.index(val) 113 | metricNameStrs[ind] += "={}".format(f) 114 | metricStrs = [ 115 | "{}={:.8g}".format(mName, mVal) 116 | for mName, mVal in zip(metricNameStrs, metricVals) 117 | ] 118 | logStr += ", ".join(metricStrs) 119 | logStrAll += logStr + "\n" 120 | return logStrAll 121 | -------------------------------------------------------------------------------- /source/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShanechiLab/DPAD/a09e7d9c3e59f1adb2d75336705dc166c15f4039/source/__init__.py -------------------------------------------------------------------------------- /source/setup.py: -------------------------------------------------------------------------------- 1 | # Release tutorial: 2 | # https://packaging.python.org/tutorials/packaging-projects/ 3 | # cd source 4 | # python setup.py sdist bdist_wheel 5 | # python -m twine upload --repository testpypi dist/* 6 | # pip install -i https://test.pypi.org/simple/ DPAD-omidsani --upgrade 7 | # python -m twine upload --repository pypi dist/* 8 | 9 | import setuptools 10 | 11 | with open("../README.md", "r", encoding="utf-8") as fh: 12 | long_description = fh.read() 13 | 14 | setuptools.setup( 15 | name="DPAD", 16 | version="1.0.0", 17 | author="Omid Sani", 18 | author_email="omidsani@gmail.com", 19 | description="Python implementation for DPAD (dissociative prioritized analysis of dynamics)", 20 | long_description=long_description, 21 | long_description_content_type="text/markdown", 22 | url="https://github.com/ShanechiLab/DPAD", 23 | packages=setuptools.find_packages(), 24 | classifiers=[ 25 | "Programming Language :: Python :: 3", 26 | "Operating System :: OS Independent", 27 | ], 28 | python_requires=">=3.6", 29 | ) 30 | --------------------------------------------------------------------------------
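For reference, a minimal usage sketch of the masked losses in source/DPAD/tools/tf_losses.py; this is not part of the repository, the data values are made up for illustration, and the import path assumes the package is installed as DPAD:

```python
import numpy as np
import tensorflow as tf

from DPAD.tools.tf_losses import masked_R2, masked_mse

# Hypothetical data: the last sample is fully marked missing with -1 and
# should be ignored by both the loss and the metric.
y_true = np.array([[1.0, 2.0], [3.0, 4.0], [-1.0, -1.0]])
y_pred = np.array([[1.5, 2.0], [3.0, 3.0], [9.0, 9.0]])

loss_fn = masked_mse(mask_value=-1.0)
loss = float(loss_fn(tf.constant(y_true), tf.constant(y_pred)))
# MSE over the first two samples only: mean([0.25, 0.0, 0.0, 1.0]) = 0.3125
print(loss)

r2_fn = masked_R2(mask_value=-1.0)
r2 = float(r2_fn(tf.constant(y_true), tf.constant(y_pred)))
# Average R2 across the two dimensions: (0.875 + 0.5) / 2 = 0.6875
print(r2)
```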