├── .gitignore ├── LICENSE ├── README.md ├── assets ├── img │ ├── inverse.png │ ├── overview_gh.png │ └── predictives.png └── matplotlibrc ├── automind ├── __init__.py ├── analysis │ ├── __init__.py │ ├── analysis_runners.py │ └── spikes_summary.py ├── inference │ ├── __init__.py │ ├── algorithms │ │ ├── GBI.py │ │ ├── Regression.py │ │ └── mcmc_posterior.py │ ├── inferer.py │ └── trainers.py ├── sim │ ├── __init__.py │ ├── b2_interface.py │ ├── b2_models.py │ ├── default_configs.py │ └── runners.py └── utils │ ├── __init__.py │ ├── analysis_utils.py │ ├── data_utils.py │ ├── dist_utils.py │ ├── organoid_utils.py │ └── plot_utils.py ├── datasets └── README.md ├── environment.yml ├── notebooks ├── demo-1_automind_inference_workflow.ipynb └── demo-2_automind_inference_from_spikes.ipynb └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | ### Exclude dataset files, tmp, and cache produced from simulations 165 | /datasets/** 166 | !datasets/README.md 167 | .cache/ 168 | tmp/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoMIND 2 | 3 | Automated Model Inference from Neural Dynamics **(AutoMIND)** is an inverse modeling framework for investigating neural circuit mechanisms underlying population dynamics. 4 | 5 | AutoMIND helps with efficient discovery of **many** parameter configurations that are consistent with target observations of neural population dynamics. 
To do so, it combines a flexible, highly parameterized spiking neural network as the mechanistic model (simulated in `brian2`), with powerful deep generative models (Normalizing Flows in `pytorch`) in the framework of simulation-based inference (powered by `sbi`). 6 | 7 | For a sneak peek of the workflow and what's possible, check out the [**overview demo**](./notebooks/demo-1_automind_inference_workflow.ipynb) and our **preprint**, [Deep inverse modeling reveals dynamic-dependent invariances in neural circuit mechanisms](https://www.biorxiv.org/content/10.1101/2024.08.21.608969v1). 8 | 9 | This repository contains the package `automind`, demo notebooks, links to generated simulation datasets and trained deep generative models ([figshare link](https://figshare.com/s/3f1467f8fb0f328aed16)), as well as code to reproduce figures and results from the manuscript. 10 | 11 | ![](./assets/img/overview_gh.png) 12 | 13 | --- 14 | 15 | ### Running the code 16 | After cloning this repo, we recommend creating a conda environment using the included `environment.yml` file, which installs the necessary `conda` and `pip` dependencies, as well as the package `automind` itself in editable mode: 17 | 18 | ``` 19 | git clone https://github.com/mackelab/automind.git 20 | cd automind 21 | conda env create -f environment.yml 22 | conda activate automind 23 | ``` 24 | 25 | The codebase will be updated over the next few weeks to enable successive capabilities: 26 | - [x] **Inference**: sampling from the included trained DGMs, conditioned on the same summary statistics computed from example or new target observations: 27 | - [**Demo-1**](./notebooks/demo-1_automind_inference_workflow.ipynb) on a synthetic target using PSD features. 28 | - [**Demo-2**](./notebooks/demo-2_automind_inference_from_spikes.ipynb) on a real recording using network burst features, starting from spike timestamps. 29 | - [ ] **Training**: training new DGMs on a different set of summary statistics or simulations. 30 | - [ ] **Parallel simulations**: running and saving many simulations to disk, e.g., on a compute cluster. 31 | - [ ] **Analysis**: analyzing and visualizing discovered parameter configurations. 32 | - [ ] **...and more!** 33 | 34 | 35 | 36 | 37 | --- 38 | # Dataset 39 | ### Model parameter configurations, simulations, and trained deep generative models 40 | Model configurations and simulations used to train DGMs, target observations (including experimental data from organoid and mouse), hundreds of discovered model configurations and corresponding simulations consistent with those targets, and trained posterior density estimators can be found [on figshare](https://figshare.com/s/3f1467f8fb0f328aed16). 41 | 42 | See [here](./datasets/README.md) for details and download instructions.
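Once downloaded, the core sampling step in both demos amounts to conditioning a trained posterior on the summary statistics of a target observation. The snippet below is only a minimal, illustrative sketch: the pickle file name and the `my_summary_stats` variable are placeholders, the required feature set depends on which trained DGM you load, and the demo notebooks remain the authoritative reference for the exact workflow.

```python
import torch
from automind.utils import data_utils

# Placeholder path: point this to a trained posterior downloaded from figshare.
posterior = data_utils.load_pickled("datasets/example_trained_posterior.pickle")

# Placeholder: summary statistics of the target observation (e.g., PSD or burst
# features), computed as in the demo notebooks.
x_o = torch.as_tensor(my_summary_stats, dtype=torch.float32)

# Draw candidate model parameter configurations consistent with x_o,
# using the sbi-style posterior interface.
theta_samples = posterior.sample((1000,), x=x_o)
```

Demo-1 and Demo-2 show the full pipeline, including how the summary statistics are computed from a target recording or from spike timestamps.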
43 | 44 | ![](./assets/img/predictives.png) -------------------------------------------------------------------------------- /assets/img/inverse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mackelab/automind/e5e3cd999b39bf2b9bd365e43c6d379d245d3e77/assets/img/inverse.png -------------------------------------------------------------------------------- /assets/img/overview_gh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mackelab/automind/e5e3cd999b39bf2b9bd365e43c6d379d245d3e77/assets/img/overview_gh.png -------------------------------------------------------------------------------- /assets/img/predictives.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mackelab/automind/e5e3cd999b39bf2b9bd365e43c6d379d245d3e77/assets/img/predictives.png -------------------------------------------------------------------------------- /assets/matplotlibrc: -------------------------------------------------------------------------------- 1 | # http://matplotlib.org/users/customizing.html 2 | 3 | # Note: Units are in pt not in px 4 | # 5 | # How to convert px to pt in Inkscape 6 | # > Inkscape pixel is 1/90 of an inch, other software usually uses 1/72. 7 | # > This means if you need 10pt - use 12.5 in Inkscape (multiply with 1.25). 8 | # > http://www.inkscapeforum.com/viewtopic.php?f=6&t=5964 9 | 10 | text.usetex : False 11 | mathtext.default : regular 12 | 13 | font.family : sans-serif 14 | font.serif : Arial, sans-serif 15 | font.sans-serif : Arial, sans-serif 16 | font.cursive : Arial, sans-serif 17 | font.size : 6 18 | figure.titlesize : 6 19 | figure.dpi : 300 20 | image.cmap : RdBu_r 21 | 22 | figure.figsize : 1.5, 1.5 23 | figure.constrained_layout.use: True 24 | figure.constrained_layout.h_pad: 0.#0.04167 25 | figure.constrained_layout.hspace: 0.#0.02 26 | figure.constrained_layout.w_pad: 0.#0.04167 27 | figure.constrained_layout.wspace: 0.#0.02 28 | 29 | legend.fontsize : 5 30 | axes.titlesize : 5 31 | axes.labelsize : 5 32 | xtick.labelsize : 5 33 | ytick.labelsize : 5 34 | 35 | image.interpolation : nearest 36 | image.resample : False 37 | image.composite_image : True 38 | 39 | axes.spines.left : True 40 | axes.spines.bottom : True 41 | axes.spines.top : False 42 | axes.spines.right : False 43 | 44 | # all width default 1.0s 45 | axes.linewidth : 1 46 | xtick.major.width : 1 47 | xtick.major.size : 2 48 | ytick.major.width : 1 49 | ytick.major.size : 2 50 | 51 | xtick.minor.width : 0.5 52 | xtick.minor.size : 1. 53 | ytick.minor.width : 0.5 54 | ytick.minor.size : 1. 55 | 56 | lines.linewidth : 1 57 | lines.markersize : 3 58 | 59 | savefig.dpi : 300 60 | savefig.format : svg 61 | savefig.bbox : tight 62 | savefig.pad_inches : 0. 
#0.1 63 | 64 | svg.image_inline : True 65 | svg.fonttype : none 66 | 67 | legend.frameon : False 68 | axes.prop_cycle : cycler('color', ['k', '1f77b4', 'ff7f0e', '2ca02c', 'd62728', '9467bd', '8c564b', 'e377c2', '7f7f7f', 'bcbd22', '17becf']) -------------------------------------------------------------------------------- /automind/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mackelab/automind/e5e3cd999b39bf2b9bd365e43c6d379d245d3e77/automind/__init__.py -------------------------------------------------------------------------------- /automind/analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mackelab/automind/e5e3cd999b39bf2b9bd365e43c6d379d245d3e77/automind/analysis/__init__.py -------------------------------------------------------------------------------- /automind/analysis/analysis_runners.py: -------------------------------------------------------------------------------- 1 | ### Script and helpers for running analysis (i.e., summary stats) on simulation data. 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import h5py as h5 6 | from time import time 7 | import os 8 | from multiprocessing import Pool 9 | from functools import partial 10 | from ..utils import data_utils, plot_utils, analysis_utils 11 | 12 | 13 | def set_up_analysis(run_folder, analysis_settings_update, ref_params_dict_path=None): 14 | """Set up analysis by loading all necessary files and updating the params dict. 15 | 16 | Args: 17 | run_folder (str): path to the simulation folder. 18 | analysis_settings_update (dict): dictionary of analysis settings to update. 19 | ref_params_dict_path (str, optional): Path to reference configurations. Defaults to None. 
20 | 21 | Returns: 22 | tuple: batch_seed, data_file_dict, prior, params_dict_default, df_prior_samples_ran 23 | """ 24 | # extract relevant file names 25 | file_type_list = [ 26 | "prior.pickle", 27 | "raw_data.hdf5", 28 | "prior_samples_ran.csv", 29 | "summary_data_merged.csv", 30 | "params_dict_default.pickle", 31 | ] 32 | data_folder = run_folder + "data/" 33 | data_file_dict = data_utils.extract_data_files(data_folder, file_type_list) 34 | batch_seed = data_file_dict["prior_samples_ran"].split("_")[0] 35 | 36 | # load everything 37 | prior = data_utils.load_pickled(data_folder + data_file_dict["prior"]) 38 | params_dict_default = data_utils.load_pickled( 39 | data_folder + data_file_dict["params_dict_default"] 40 | ) 41 | df_prior_samples_ran = pd.read_csv( 42 | data_folder + data_file_dict["prior_samples_ran"], index_col=0 43 | ) 44 | 45 | if ref_params_dict_path: 46 | # load reference parameter set 47 | params_dict_updated = data_utils.load_pickled(ref_params_dict_path) 48 | 49 | # swap in the new analysis sub-dict and path dicts 50 | params_dict_default["params_analysis"] = params_dict_updated["params_analysis"] 51 | path_dict = data_utils.make_subfolders(run_folder, ["figures"]) 52 | params_dict_default["path_dict"] = path_dict 53 | 54 | # update with new setting 55 | params_dict_default = data_utils.update_params_dict( 56 | params_dict_default, analysis_settings_update 57 | ) 58 | 59 | return batch_seed, data_file_dict, prior, params_dict_default, df_prior_samples_ran 60 | 61 | 62 | def make_summary_parallel( 63 | run_folder, 64 | analysis_settings_update, 65 | reference_params_path=None, 66 | num_cores=64, 67 | n_per_batch=200, 68 | overwrite_existing=False, 69 | ): 70 | """Run (parallelized) analysis on a simulation folder, streaming summary results to disk while analyzing. 71 | 72 | Args: 73 | run_folder (str): path to the simulation folder; should be one level above the '/data/' folder that contains the '{batchseed}_raw_data.hdf5' file. 74 | analysis_settings_update (dict): dictionary of analysis settings to update. 75 | reference_params_path (str, optional): Path to reference configurations. Defaults to None. 76 | num_cores (int, optional): Number of cores to use. Defaults to 64. 77 | n_per_batch (int, optional): Number of simulations to process per batch. Defaults to 200. 78 | overwrite_existing (bool, optional): Overwrite existing summary file. Defaults to False.
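        Example (illustrative sketch; the run folder path and chosen settings are placeholders):
            >>> make_summary_parallel(
            ...     "path/to/run_folder/",
            ...     {"params_analysis.do_plot": False},
            ...     num_cores=8,
            ...     overwrite_existing=True,
            ... )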
79 | """ 80 | 81 | if not os.path.isdir(run_folder + "/data/"): 82 | print("Invalid simulation folder, skip.") 83 | return 84 | 85 | ( 86 | batch_seed, 87 | data_file_dict, 88 | prior, 89 | params_dict_default, 90 | df_prior_samples_ran, 91 | ) = set_up_analysis(run_folder, analysis_settings_update, reference_params_path) 92 | analysis_name = params_dict_default["params_analysis"]["summary_set_name"] 93 | 94 | # get all random seeds 95 | h5_file_path = data_file_dict["root_path"] + data_file_dict["raw_data"] 96 | with h5.File(h5_file_path, "r") as h5_file: 97 | random_seeds = list(h5_file[batch_seed].keys()) 98 | random_seeds = list( 99 | np.array(random_seeds)[np.argsort([int(s) for s in random_seeds])] 100 | ) 101 | 102 | print(len(random_seeds), "random seeds found in h5 file") 103 | print(len(df_prior_samples_ran), "random seeds found in prior samples file") 104 | 105 | # check if random seeds are unique 106 | n_samples = len(random_seeds) 107 | print( 108 | f"Random seeds of runs in h5 file are unique: {n_samples==len(np.unique(random_seeds))}" 109 | ) 110 | 111 | # check if stream file already exists 112 | stream_file_path = ( 113 | data_file_dict["root_path"] + f"{batch_seed}_{analysis_name}_summary_stream.csv" 114 | ) 115 | if os.path.isfile(stream_file_path): 116 | print(f"{stream_file_path} exists, deleting...") 117 | os.remove(stream_file_path) 118 | 119 | # check if save file already exists 120 | save_file_path = ( 121 | data_file_dict["root_path"] 122 | + f"{batch_seed}_{analysis_name}_summary_data_merged.csv" 123 | ) 124 | if os.path.isfile(save_file_path): 125 | print(f"{save_file_path} exists.") 126 | if not overwrite_existing: 127 | df_summary_stream_previous = pd.read_csv(save_file_path, index_col=0) 128 | if np.all( 129 | [ 130 | int(rs) 131 | in df_summary_stream_previous["params_settings.random_seed"].values 132 | for rs in random_seeds 133 | ] 134 | ): 135 | print("Skipping analysis.") 136 | return 137 | else: 138 | print(f"File exists but incomplete, overwrite.") 139 | else: 140 | print(f"File exists but overwriting.") 141 | 142 | ### RUN ANALYSIS ### 143 | print("---- Running analysis... ----") 144 | n_batches = int(np.ceil(n_samples / n_per_batch)) 145 | print(f"{n_samples} files / {n_batches} batches to process.") 146 | 147 | for i_b in range(n_batches): 148 | start_time = time() 149 | 150 | i_start, i_end = i_b * n_per_batch, (i_b + 1) * n_per_batch 151 | random_seeds_batch = random_seeds[i_start:i_end] 152 | 153 | # grab spikes and process 154 | with h5.File(h5_file_path, "r") as h5_file: 155 | spikes_collected = [] 156 | # iterate through batches 157 | for random_seed in random_seeds_batch: 158 | matched_prior = df_prior_samples_ran[ 159 | df_prior_samples_ran["params_settings.random_seed"] 160 | == int(random_seed) 161 | ] 162 | if len(matched_prior) == 1: 163 | # only one match 164 | # check if early stopped 165 | 166 | if matched_prior["params_analysis.early_stopped"].values: 167 | print( 168 | f"{batch_seed}-{random_seed} early stopped run escaped, remove." 
169 | ) 170 | # remove the seed from this batch 171 | random_seeds_batch.remove(random_seed) 172 | else: 173 | # add to processing queue 174 | spikes_dict = data_utils.get_spikes_h5( 175 | h5_file, f"{batch_seed}/{random_seed}/" 176 | ) 177 | spikes_dict["t_end"] = matched_prior[ 178 | "params_settings.sim_time" 179 | ].item() 180 | spikes_collected.append(spikes_dict) 181 | else: 182 | print(f"{random_seed} with {len(matched_prior)} matches, skip.") 183 | 184 | # process in parallel 185 | # find the appropriate analysis function 186 | if analysis_name == "prescreen": 187 | # just single unit spiketrain analysis 188 | if num_cores == 1: 189 | summary_dfs = [ 190 | analysis_utils.compute_spike_features_only(s, params_dict_default) 191 | for s in spikes_collected 192 | ] 193 | else: 194 | func_analysis = partial( 195 | analysis_utils.compute_spike_features_only, 196 | params_dict=params_dict_default, 197 | ) 198 | with Pool(num_cores) as pool: 199 | # NOTE: POSSIBLE FAILURE HERE WHEN RETURN IS EMPTY 200 | summary_dfs = pool.map(func_analysis, spikes_collected) 201 | 202 | elif analysis_name == "spikes_bursts": 203 | # full burst analysis 204 | func_analysis = partial( 205 | analysis_utils.compute_spike_burst_features, 206 | params_dict=params_dict_default, 207 | ) 208 | with Pool(num_cores) as pool: 209 | # POSSIBLE FAILURE HERE WHEN RETURN IS EMPTY 210 | summary_dfs, pop_rates, burst_stats = list( 211 | zip(*pool.map(func_analysis, spikes_collected)) 212 | ) 213 | 214 | elif analysis_name == "MK1": 215 | # full features with spikes, bursts, PSDs, PCA 216 | func_analysis = partial( 217 | analysis_utils.compute_summary_features, params_dict=params_dict_default 218 | ) 219 | 220 | with Pool(num_cores) as pool: 221 | summary_stats = list(pool.map(func_analysis, spikes_collected)) 222 | # squish them into one big df row 223 | summary_dfs = [ 224 | pd.concat( 225 | [ 226 | s["summary_spikes"], 227 | s["summary_bursts"], 228 | s["summary_pca"], 229 | s["summary_psd"].loc["exc_rate"].to_frame(name=0).T, 230 | ], 231 | axis=1, 232 | ) 233 | for s in summary_stats 234 | ] 235 | 236 | # collect dfs and stream out 237 | # here it's assumed that the seeds and output df are aligned, but is not guaranteed 238 | # NOTE: it fails silently in a way such that, if the lengths of random_seeds_batch 239 | # and summary_dfs are not the same, it only collects up to the shorter length 240 | df_summaries = pd.concat( 241 | dict(zip(random_seeds_batch, summary_dfs)) 242 | ).reset_index(level=1, drop=True) 243 | df_summaries.to_csv(stream_file_path, mode="a", header=(i_b == 0)) 244 | 245 | print(f"[batch {i_b+1} of {n_batches}]: analysis time", time() - start_time) 246 | 247 | # plot 248 | if analysis_settings_update["params_analysis.do_plot"]: 249 | if analysis_name == "spikes_bursts": 250 | func_plot = partial( 251 | plot_utils.plot_wrapper, params_dict=params_dict_default 252 | ) 253 | with Pool(num_cores) as pool: 254 | pool.starmap( 255 | func_plot, list(zip(pop_rates, burst_stats, random_seeds_batch)) 256 | ) 257 | 258 | elif analysis_name == "MK1": 259 | func_plot = partial( 260 | plot_utils.plot_wrapper_MK1, params_dict=params_dict_default 261 | ) 262 | with Pool(num_cores) as pool: 263 | pool.starmap( 264 | func_plot, list(zip(summary_stats, random_seeds_batch)) 265 | ) 266 | 267 | print(f"[batch {i_b+1}]: plot time", time() - start_time) 268 | 269 | print("\n-----") 270 | 271 | # save out params dict 272 | data_utils.pickle_file( 273 | data_file_dict["root_path"] 274 | + 
f"{batch_seed}_{analysis_name}_params_dict_analysis_updated.pickle", 275 | params_dict_default, 276 | ) 277 | 278 | # reload and merge 279 | df_summary_stream = pd.read_csv(stream_file_path, index_col=0) 280 | df_summary = data_utils.merge_theta_and_x( 281 | df_prior_samples_ran, df_summary_stream.index.values, df_summary_stream 282 | ) 283 | df_summary.to_csv(save_file_path) 284 | 285 | # remove streaming file 286 | os.remove(stream_file_path) 287 | -------------------------------------------------------------------------------- /automind/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mackelab/automind/e5e3cd999b39bf2b9bd365e43c6d379d245d3e77/automind/inference/__init__.py -------------------------------------------------------------------------------- /automind/inference/algorithms/GBI.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from typing import Any, Callable, Dict, Optional, Union 3 | from copy import deepcopy 4 | import numpy as np 5 | from math import ceil 6 | import torch 7 | from torch import Tensor, nn, optim 8 | import torch.nn.functional as F 9 | from torch.utils.data import TensorDataset, DataLoader 10 | from torch.distributions import Distribution 11 | 12 | from sbi.utils.torchutils import atleast_2d 13 | from sbi.inference.potentials.base_potential import BasePotential 14 | from sbi.neural_nets.embedding_nets import PermutationInvariantEmbedding, FCEmbedding 15 | from time import time 16 | from pyknos.nflows.nn import nets 17 | from sbi.utils.sbiutils import standardizing_net 18 | 19 | 20 | class GBInference: 21 | def __init__( 22 | self, 23 | prior: Distribution, 24 | distance_func: Callable, 25 | do_precompute_distances: bool = True, 26 | include_bad_sims: Optional[bool] = False, 27 | nan_dists_replacement: Optional[float] = 5.0, 28 | ): 29 | self.prior = prior 30 | self.distance_func = distance_func 31 | self.do_precompute_distances = do_precompute_distances 32 | self.include_bad_sims = include_bad_sims 33 | self.nan_dists_replacement = nan_dists_replacement 34 | self.zscore_distance_precomputed = [] 35 | 36 | def append_simulations( 37 | self, 38 | theta: Tensor, 39 | x: Tensor, 40 | x_target: Tensor, 41 | n_dists_precompute: Union[int, float] = 1.0, 42 | ): 43 | """Append simulation data: theta, x, and target x.""" 44 | self.theta = theta 45 | self.x = x 46 | self.x_target = x_target 47 | 48 | if self.do_precompute_distances: 49 | # Pre-compute the distances between all x and x_targets. 50 | self._precompute_distance() 51 | # self._compute_index_pairs() 52 | 53 | # Precompute a subset of distances for training. 54 | self._precompute_subset_of_dists(n_dists_precompute) 55 | if self.include_bad_sims: 56 | # Get replacement distance by multiplying the max distance by a scaling factor. 57 | self.nan_dists_replacement = ( 58 | self.nan_dists_replacement * self.zscore_distance_precomputed.max() 59 | ) 60 | # Get rid of any x_targets with nans.
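# `~isnan().any(1)` below keeps only target observations whose summary features are all non-NaN.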
61 | self.x_target = self.x_target[~self.x_target.isnan().any(1)] 62 | 63 | return self 64 | 65 | def initialize_distance_estimator( 66 | self, 67 | num_layers: int, 68 | num_hidden: int, 69 | net_type: str = "resnet", 70 | positive_constraint_fn: str = None, 71 | net_kwargs: Optional[Dict] = {}, 72 | ): 73 | """Initialize neural network for distance regression.""" 74 | self.distance_net = DistanceEstimator( 75 | self.theta, 76 | self.x, 77 | self.zscore_distance_precomputed.flatten(), 78 | num_layers, 79 | num_hidden, 80 | net_type, 81 | positive_constraint_fn, 82 | **net_kwargs, 83 | ) 84 | 85 | def train( 86 | self, 87 | distance_net: Optional[nn.Module] = None, 88 | training_batch_size: int = 500, 89 | max_n_epochs: int = 1000, 90 | stop_after_counter_reaches: int = 50, 91 | validation_fraction: float = 0.1, 92 | n_train_per_theta: int = 1, 93 | n_val_per_theta: int = 1, 94 | print_every_n: int = 20, 95 | plot_losses: bool = True, 96 | ) -> nn.Module: 97 | # Can use custom distance net, otherwise take existing in class. 98 | if distance_net != None: 99 | if self.distance_net != None: 100 | print("Warning: Overwriting existing distance net.") 101 | self.distance_net = distance_net 102 | 103 | # Define loss and optimizer. 104 | nn_loss = nn.MSELoss() 105 | optimizer = optim.Adam(self.distance_net.parameters()) 106 | 107 | # Hold out entire rows of theta/x for validation, but leave all x_targets intact. 108 | # The other option is to hold out all thetas, as well as the corresponding xs in x_target. 109 | dataset = TensorDataset(torch.arange(len(self.theta))) 110 | train_set, val_set = torch.utils.data.random_split( 111 | dataset, [1 - validation_fraction, validation_fraction] 112 | ) 113 | dataloader = DataLoader(train_set, batch_size=training_batch_size, shuffle=True) 114 | 115 | # Training loop. 116 | train_losses, val_losses = [], [] 117 | epoch = 0 118 | self._val_loss = torch.inf 119 | while epoch <= max_n_epochs and not self._check_convergence( 120 | epoch, stop_after_counter_reaches 121 | ): 122 | time_start = time() 123 | 124 | # If use all validation data, then pre-compute and store all possible indices 125 | # Otherwise, sample new ones per epoch 126 | if (epoch == 0) or (n_val_per_theta != -1): 127 | # Only sample when first epoch, or when not using all validation data 128 | # i.e., don't resample if using all validation data and epoch > 0 129 | idx_val = self.make_index_pairs( 130 | torch.Tensor(val_set.indices).to(int), 131 | torch.arange(len(self.x_target), dtype=int), 132 | n_val_per_theta, 133 | ) 134 | theta_val, _, xt_val, dist_val = self.get_theta_x_distances(idx_val) 135 | 136 | for i_b, idx_theta_batch in enumerate(dataloader): 137 | optimizer.zero_grad() 138 | 139 | # Randomly sample n x_target for each theta, and get the data. 140 | idx_batch = self.make_index_pairs( 141 | idx_theta_batch[0], 142 | torch.arange(len(self.x_target), dtype=int), 143 | n_train_per_theta, 144 | ) 145 | theta_batch, _, xt_batch, dist_batch = self.get_theta_x_distances( 146 | idx_batch 147 | ) 148 | # Forward pass for distances. 149 | dist_pred = self.distance_net(theta_batch, xt_batch).squeeze() 150 | 151 | # Training loss. 152 | l = nn_loss(dist_batch, dist_pred) 153 | l.backward() 154 | optimizer.step() 155 | train_losses.append(l.detach().item()) 156 | 157 | # Compute validation loss each epoch. 
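# (`self._val_loss` is what `_check_convergence` monitors for early stopping and for checkpointing the best model state.)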
158 | with torch.no_grad(): 159 | dist_pred = self.distance_net(theta_val, xt_val).squeeze() 160 | self._val_loss = nn_loss(dist_val, dist_pred).item() 161 | val_losses.append([(i_b + 1) * epoch, self._val_loss]) 162 | 163 | # Print validation loss 164 | if epoch % print_every_n == 0: 165 | print( 166 | f"{epoch}: train loss: {train_losses[-1]:.6f}, val loss: {self._val_loss:.6f}, best val loss: {self._best_val_loss:.6f}, {(time()-time_start):.4f} seconds per epoch." 167 | ) 168 | 169 | epoch += 1 170 | 171 | print(f"Network converged after {epoch-1} of {max_n_epochs} epochs.") 172 | 173 | # Plot loss curves for convenience. 174 | self.train_losses = torch.Tensor(train_losses) 175 | self.val_losses = torch.Tensor(val_losses) 176 | if plot_losses: 177 | self._plot_losses(self.train_losses, self.val_losses) 178 | 179 | # Avoid keeping the gradients in the resulting network, which can 180 | # cause memory leakage when benchmarking. 181 | self.distance_net.zero_grad(set_to_none=True) 182 | return deepcopy(self.distance_net) 183 | 184 | def predict_distance(self, theta, x, require_grad=False): 185 | # Convenience function that does fixes the shape of x. 186 | # Expands to have the same batch size as theta, in case x is [1, n_dim]. 187 | if theta.shape[0] != x.shape[0]: 188 | if len(x.shape) == 2: 189 | x = x.repeat(theta.shape[0], 1) 190 | elif len(x.shape) == 3: 191 | # Has multiple independent observations, i.e., gaussian mixture task. 192 | x = x.repeat(theta.shape[0], 1, 1) 193 | 194 | dist = self.distance_net(theta, x).squeeze(1) 195 | return dist if require_grad else dist.detach() 196 | 197 | def _check_convergence(self, counter: int, stop_after_counter_reaches: int) -> bool: 198 | """Return whether the training converged yet and save best model state so far. 199 | Checks for improvement in validation performance over previous batches or epochs. 200 | """ 201 | converged = False 202 | 203 | assert self.distance_net is not None 204 | distance_net = self.distance_net 205 | 206 | # (Re)-start the epoch count with the first epoch or any improvement. 207 | if counter == 0 or self._val_loss < self._best_val_loss: 208 | self._best_val_loss = self._val_loss 209 | self._counts_since_last_improvement = 0 210 | self._best_model_state_dict = deepcopy(distance_net.state_dict()) 211 | else: 212 | self._counts_since_last_improvement += 1 213 | 214 | # If no validation improvement over many epochs, stop training. 215 | if self._counts_since_last_improvement > stop_after_counter_reaches - 1: 216 | distance_net.load_state_dict(self._best_model_state_dict) 217 | converged = True 218 | 219 | return converged 220 | 221 | def build_amortized_GLL(self, distance_net: nn.Module = None): 222 | """Build generalized likelihood function from distance predictor.""" 223 | 224 | # Can use custom distance net, otherwise take existing in class. 225 | if distance_net != None: 226 | if not hasattr(self, "distance_net"): 227 | self.distance_net = distance_net 228 | else: 229 | if self.distance_net != None: 230 | print("Warning: Overwriting existing distance net.") 231 | self.distance_net = distance_net 232 | 233 | if not hasattr(self, "distance_net"): 234 | raise ValueError( 235 | "Must initialize distance net before building amortized GLL." 
236 | ) 237 | 238 | # Return function 239 | return self._generalized_loglikelihood 240 | 241 | # Build likelihood function, moved outside so MCMC posterior can be pickled 242 | def _generalized_loglikelihood(self, theta: Tensor, x_o: Tensor): 243 | theta = atleast_2d(theta) 244 | dist_pred = self.predict_distance(theta, x_o, require_grad=True) 245 | assert dist_pred.shape == (theta.shape[0],) 246 | return dist_pred 247 | 248 | def get_potential(self, x_o: Tensor = None, beta: float = 1.0): 249 | """Make the potential function. Pass through call to GBIPotenial object.""" 250 | return GBIPotential(self.prior, self.build_amortized_GLL(), x_o, beta) 251 | 252 | def build_posterior( 253 | self, posterior_func: Callable, x_o: Tensor = None, beta: float = 1.0 254 | ): 255 | """Create posterior object using the defined potential function.""" 256 | potential_func = self.get_potential(x_o, beta) 257 | posterior = posterior_func(potential_func, self.prior) 258 | return posterior 259 | 260 | def _precompute_distance(self): 261 | """Pre-compute the distances of all pairs of x and x_target.""" 262 | self.distance_precomputed = [] 263 | t_start = time() 264 | print("Pre-computing distances...", end=" ") 265 | for x_t in self.x_target: 266 | self.distance_precomputed.append( 267 | self.compute_distance(self.x, x_t).unsqueeze(1) 268 | ) 269 | # self.distance_precomputed.append( 270 | # self.distance_func(self.x.unsqueeze(1), x_t).unsqueeze(1) 271 | # ) 272 | self.distance_precomputed = torch.hstack(self.distance_precomputed) 273 | print(f"finished in {time()-t_start} seconds.") 274 | 275 | # def _precompute_subset_of_dists(self, num_x_and_xt): 276 | # """Pre-compute the distances of some x and x_target pairs for z-scoring.""" 277 | # self.zscore_distance_precomputed = [] 278 | # random_inds_x = torch.randint(0, len(self.x), (num_x_and_xt,)) 279 | # random_inds_xtarget = torch.randint(0, len(self.x_target), (num_x_and_xt,)) 280 | 281 | # # directly compute distance between N random pairs, not N x N 282 | # self.zscore_distance_precomputed = self.compute_distance( 283 | # self.x[random_inds_x], self.x_target[random_inds_xtarget] 284 | # ) 285 | 286 | def _precompute_subset_of_dists(self, n_dists): 287 | """Pre-compute the distances of some x and x_target pairs for z-scoring.""" 288 | self.zscore_distance_precomputed = [] 289 | if type(n_dists) == int: 290 | n_dists = n_dists 291 | elif type(n_dists) == float: 292 | n_dists = int(self.theta.shape[0] * n_dists) 293 | 294 | """Compute valid distances, i.e. distances that are not nan.""" 295 | x_ = self.x[~self.x.isnan().any(1)] 296 | xt_ = self.x_target[~self.x_target.isnan().any(1)] 297 | random_inds_x = torch.randint(0, len(x_), (n_dists,)) 298 | random_inds_xtarget = torch.randint(0, len(xt_), (n_dists,)) 299 | self.zscore_distance_precomputed = self.compute_distance( 300 | x_[random_inds_x], xt_[random_inds_xtarget] 301 | ) 302 | 303 | def compute_distance(self, x: Tensor, x_target: Tensor): 304 | """Compute distance between x and x_target.""" 305 | # x_target should have leading dim of 1 or same as x. 306 | assert ( 307 | x_target.shape[0] == 1 308 | or x_target.shape[0] == x.shape[0] 309 | or len(x_target.shape) == len(x.shape) - 1 310 | ), f"x_target should have: same leading dim as x, leading dim of 1, or have 1 less dim than x, but have shapes x: {x.shape}, x_target: {x_target.shape}." 
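# A singleton "trial" dimension is added to both tensors below so that distance functions
# written for [num_thetas, num_xs, num_x_dims] inputs (e.g., mse_dist) broadcast correctly.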
311 | return self.distance_func(x.unsqueeze(1), x_target.unsqueeze(1)) 312 | 313 | def subsample_indices(self, n_total, n_samples): 314 | return torch.randperm(n_total)[:n_samples] 315 | 316 | def make_index_pairs(self, theta_indices, x_target_indices, n_draws_per_theta=1): 317 | if n_draws_per_theta == -1: 318 | # Return all possible index pairs. 319 | return torch.cartesian_prod(theta_indices, x_target_indices) 320 | else: 321 | # Return a random subset of index pairs, n per theta. 322 | index_pairs = [] 323 | for i in range(n_draws_per_theta): 324 | xt_idx = self.subsample_indices( 325 | len(x_target_indices), len(theta_indices) 326 | ) 327 | index_pairs.append( 328 | torch.vstack((theta_indices, x_target_indices[xt_idx])).T 329 | ) 330 | return torch.concat(index_pairs, dim=0) 331 | 332 | def get_theta_x_distances(self, index_pairs): 333 | """Return theta, x, x_target, and distance for each index pair.""" 334 | theta = self.theta[index_pairs[:, 0]] 335 | x = self.x[index_pairs[:, 0]] 336 | x_target = self.x_target[index_pairs[:, 1]] 337 | if self.do_precompute_distances: 338 | dist = self.distance_precomputed[index_pairs[:, 0], index_pairs[:, 1]] 339 | else: 340 | dist = self.compute_distance(x, x_target) 341 | 342 | if self.include_bad_sims: 343 | # Need to replace nan-dists with replacement. 344 | # Maybe do this stochastically? 345 | dist[dist.isnan()] = self.nan_dists_replacement 346 | return theta, x, x_target, dist 347 | 348 | def _plot_losses(self, train_losses, val_losses): 349 | plt.figure(figsize=(8, 3)) 350 | plt.plot(train_losses, "k", alpha=0.8) 351 | plt.plot(val_losses[:, 0], val_losses[:, 1], "r.-", alpha=0.8) 352 | plt.savefig("losses.png") 353 | 354 | 355 | class DistanceEstimator(nn.Module): 356 | def __init__( 357 | self, 358 | theta, 359 | x, 360 | dists, 361 | num_layers, 362 | hidden_features, 363 | net_type, 364 | positive_constraint_fn=None, 365 | dropout_prob=0.0, 366 | use_batch_norm=False, 367 | activation=F.relu, 368 | activate_output=False, 369 | trial_net_input_dim=None, 370 | trial_net_output_dim=None, 371 | z_score_theta: bool = True, 372 | z_score_x: bool = True, 373 | z_score_dists: bool = True, 374 | ): 375 | ## TO DO: probably should put all those kwargs in kwargs 376 | super().__init__() 377 | 378 | theta_dim = theta.shape[1] 379 | x_dim = x.shape[1] 380 | 381 | if trial_net_input_dim is not None and trial_net_output_dim is not None: 382 | output_dim_e_net = 20 383 | trial_net = FCEmbedding( 384 | input_dim=trial_net_input_dim, output_dim=trial_net_output_dim 385 | ) 386 | self.embedding_net_x = PermutationInvariantEmbedding( 387 | trial_net=trial_net, 388 | trial_net_output_dim=trial_net_output_dim, 389 | output_dim=output_dim_e_net, 390 | ) 391 | input_dim = theta_dim + output_dim_e_net 392 | else: 393 | self.embedding_net_x = nn.Identity() 394 | input_dim = theta_dim + x_dim 395 | 396 | if z_score_theta: 397 | self.z_score_theta_net = standardizing_net(theta, False) 398 | else: 399 | self.z_score_theta_net = nn.Identity() 400 | 401 | if z_score_x: 402 | self.z_score_x_net = standardizing_net(x, False) 403 | else: 404 | self.z_score_x_net = nn.Identity() 405 | 406 | if z_score_dists: 407 | mean_distance = torch.mean(dists) 408 | std_distance = torch.std(dists) 409 | self.z_score_dist_net = MultiplyByMean(mean_distance, std_distance) 410 | else: 411 | self.z_score_dist_net = nn.Identity() 412 | 413 | output_dim = 1 414 | if net_type == "MLP": 415 | net = nets.MLP( 416 | in_shape=[input_dim], 417 | out_shape=[output_dim], 418 | 
hidden_sizes=[hidden_features] * num_layers, 419 | activation=activation, 420 | activate_output=activate_output, 421 | ) 422 | 423 | elif net_type == "resnet": 424 | net = nets.ResidualNet( 425 | in_features=input_dim, 426 | out_features=output_dim, 427 | hidden_features=hidden_features, 428 | num_blocks=num_layers, 429 | activation=activation, 430 | dropout_probability=dropout_prob, 431 | use_batch_norm=use_batch_norm, 432 | ) 433 | else: 434 | raise NotImplementedError 435 | 436 | # ### TO DO: add activation at the end to force positive distance 437 | if positive_constraint_fn == None: 438 | self.positive_constraint_fn = lambda x: x 439 | elif positive_constraint_fn == "relu": 440 | self.positive_constraint_fn = F.relu 441 | elif positive_constraint_fn == "exponential": 442 | self.positive_constraint_fn = torch.exp 443 | elif positive_constraint_fn == "softplus": 444 | self.positive_constraint_fn = F.softplus 445 | else: 446 | raise NotImplementedError 447 | 448 | self.net = net 449 | 450 | def forward(self, theta, x): 451 | """ 452 | Predicts distance between theta and x. 453 | """ 454 | # Check for z-score and embedding nets at run time just in case 455 | if not hasattr(self, "embedding_net_x"): 456 | self.embedding_net_x = nn.Identity() 457 | if not hasattr(self, "z_score_theta_net"): 458 | self.z_score_theta_net = nn.Identity() 459 | if not hasattr(self, "z_score_x_net"): 460 | self.z_score_x_net = nn.Identity() 461 | if not hasattr(self, "z_score_dist_net"): 462 | self.z_score_dist_net = nn.Identity() 463 | 464 | theta = self.z_score_theta_net(theta) 465 | x = self.z_score_x_net(x) 466 | x_embedded = self.embedding_net_x(x) 467 | 468 | rectified_distance = self.positive_constraint_fn( 469 | self.net(torch.concat((theta, x_embedded), dim=-1)) 470 | ) 471 | return self.z_score_dist_net(rectified_distance) 472 | 473 | 474 | class GBIPotential(BasePotential): 475 | # Need to set this to True for gaussian mixture. 476 | allow_iid_x = True 477 | 478 | def __init__(self, prior, gen_llh_fn, x_o=None, beta=1.0): 479 | super().__init__(prior, x_o) 480 | self.gen_llh_fn = gen_llh_fn 481 | self.beta = beta 482 | 483 | def set_beta(self, beta): 484 | self.beta = beta 485 | 486 | def __call__(self, theta, track_gradients=True): 487 | with torch.set_grad_enabled(track_gradients): 488 | return -self.beta * self.gen_llh_fn(theta, self.x_o) + self.prior.log_prob( 489 | theta 490 | ) 491 | 492 | 493 | def mse_dist(xs: Tensor, x_o: Tensor) -> Tensor: 494 | # Shape of xs should be [num_thetas, num_xs, num_x_dims]. 495 | mse = ((xs - x_o) ** 2).mean(dim=2) # Average over data dimensions. 
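# At this point `mse` has shape [num_thetas, num_xs]; the mean over dim 1 below
# yields a single scalar distance per theta.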
496 | return mse.mean(dim=1) # Monte-Carlo average 497 | 498 | 499 | def get_distance_function(dist_name): 500 | if dist_name == "mse": 501 | return mse_dist 502 | 503 | 504 | class MultiplyByMean(nn.Module): 505 | def __init__(self, mean: Union[Tensor, float], std: Union[Tensor, float]): 506 | super(MultiplyByMean, self).__init__() 507 | mean, std = map(torch.as_tensor, (mean, std)) 508 | self.mean = mean 509 | self.std = std 510 | self.register_buffer("_mean", mean) 511 | self.register_buffer("_std", std) 512 | 513 | def forward(self, tensor): 514 | return tensor * self._mean 515 | -------------------------------------------------------------------------------- /automind/inference/algorithms/Regression.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | from torch import nn, Tensor 4 | from torch.nn import functional as F 5 | from torch.utils.data import TensorDataset, DataLoader 6 | from pyknos.nflows.nn import nets 7 | 8 | from typing import Any, Callable, Dict, Optional, Union 9 | from copy import deepcopy 10 | import numpy as np 11 | 12 | from torch import Tensor, nn, optim 13 | from torch.utils.data import TensorDataset, DataLoader 14 | 15 | 16 | from sbi.utils.torchutils import atleast_2d 17 | 18 | 19 | from time import time 20 | 21 | 22 | class RegressionInference: 23 | """ 24 | A bit unnecessary but copying the sbi/GBI API to be consistent. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | theta: Tensor, 30 | x: Tensor, 31 | predict_theta: bool, 32 | num_layers: int, 33 | num_hidden: int, 34 | net_type: str = "resnet", 35 | positive_constraint_fn: str = None, 36 | net_kwargs: Optional[Dict] = {}, 37 | ): 38 | if predict_theta: 39 | self.predict_theta = True 40 | self.X, self.Y = x, theta 41 | else: 42 | self.predict_theta = False 43 | self.X, self.Y = theta, x 44 | 45 | self.num_layers = num_layers 46 | self.num_hidden = num_hidden 47 | self.net_type = net_type 48 | self.net_kwargs = net_kwargs 49 | 50 | """Initialize neural network for regression.""" 51 | self.init_network() 52 | 53 | def init_network(self, seed=0): 54 | torch.manual_seed(seed) 55 | self.network = RegressionNet( 56 | self.X.shape[1], 57 | self.Y.shape[1], 58 | self.num_layers, 59 | self.num_hidden, 60 | self.net_type, 61 | **self.net_kwargs, 62 | ) 63 | 64 | def train( 65 | self, 66 | network: Optional[nn.Module] = None, 67 | training_batch_size: int = 500, 68 | max_n_epochs: int = 1000, 69 | stop_after_counter_reaches: int = 50, 70 | validation_fraction: float = 0.1, 71 | print_every_n: int = 20, 72 | plot_losses: bool = True, 73 | ) -> nn.Module: 74 | # Can use custom distance net, otherwise take existing in class. 75 | if network != None: 76 | if self.network != None: 77 | print("Warning: Overwriting existing network.") 78 | self.network = network 79 | 80 | # Define loss and optimizer. 81 | nn_loss = nn.MSELoss() 82 | optimizer = optim.Adam(self.network.parameters()) 83 | 84 | # Train/val split 85 | dataset = TensorDataset(self.X, self.Y) 86 | train_set, val_set = torch.utils.data.random_split( 87 | dataset, [1 - validation_fraction, validation_fraction] 88 | ) 89 | dataloader = DataLoader(train_set, batch_size=training_batch_size, shuffle=True) 90 | X_val, Y_val = val_set[:] 91 | 92 | # Training loop. 
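# (Mirrors GBInference.train: MSE loss, Adam optimizer, and early stopping on
# validation loss via `_check_convergence`.)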
93 | train_losses, val_losses = [], [] 94 | epoch = 0 95 | self._val_loss = torch.inf 96 | while epoch <= max_n_epochs and not self._check_convergence( 97 | epoch, stop_after_counter_reaches 98 | ): 99 | time_start = time() 100 | for i_b, batch in enumerate(dataloader): 101 | optimizer.zero_grad() 102 | # Forward pass for distances. 103 | Y_pred = self.network(batch[0]) 104 | 105 | # Training loss. 106 | l = nn_loss(batch[1], Y_pred) 107 | l.backward() 108 | optimizer.step() 109 | train_losses.append(l.detach().item()) 110 | 111 | # Compute validation loss each epoch. 112 | with torch.no_grad(): 113 | self._val_loss = ( 114 | nn_loss(Y_val, self.network(X_val).squeeze()).detach().item() 115 | ) 116 | val_losses.append([(i_b + 1) * epoch, self._val_loss]) 117 | 118 | # Print validation loss 119 | if epoch % print_every_n == 0: 120 | print( 121 | f"{epoch}: train loss: {train_losses[-1]:.6f}, val loss: {self._val_loss:.6f}, best val loss: {self._best_val_loss:.6f}, {(time()-time_start):.4f} seconds per epoch." 122 | ) 123 | 124 | epoch += 1 125 | 126 | print(f"Network converged after {epoch-1} of {max_n_epochs} epochs.") 127 | 128 | # Plot loss curves for convenience. 129 | self.train_losses = torch.Tensor(train_losses) 130 | self.val_losses = torch.Tensor(val_losses) 131 | if plot_losses: 132 | self._plot_losses(self.train_losses, self.val_losses) 133 | 134 | # Avoid keeping the gradients in the resulting network, which can 135 | # cause memory leakage when benchmarking. 136 | self.network.zero_grad(set_to_none=True) 137 | return deepcopy(self.network) 138 | 139 | def _check_convergence(self, counter: int, stop_after_counter_reaches: int) -> bool: 140 | """Return whether the training converged yet and save best model state so far. 141 | Checks for improvement in validation performance over previous batches or epochs. 142 | """ 143 | converged = False 144 | 145 | assert self.network is not None 146 | network = self.network 147 | 148 | # (Re)-start the epoch count with the first epoch or any improvement. 149 | if counter == 0 or self._val_loss < self._best_val_loss: 150 | self._best_val_loss = self._val_loss 151 | self._counts_since_last_improvement = 0 152 | self._best_model_state_dict = deepcopy(network.state_dict()) 153 | else: 154 | self._counts_since_last_improvement += 1 155 | 156 | # If no validation improvement over many epochs, stop training. 
157 | if self._counts_since_last_improvement > stop_after_counter_reaches - 1: 158 | network.load_state_dict(self._best_model_state_dict) 159 | converged = True 160 | 161 | return converged 162 | 163 | def _plot_losses(self, train_losses, val_losses): 164 | plt.figure(figsize=(8, 3)) 165 | plt.plot(train_losses, "k", alpha=0.8) 166 | plt.plot(val_losses[:, 0], val_losses[:, 1], "r.-", alpha=0.8) 167 | plt.savefig("losses.png") 168 | 169 | 170 | class RegressionNet(nn.Module): 171 | def __init__( 172 | self, 173 | input_dim, 174 | output_dim, 175 | num_layers, 176 | hidden_features, 177 | net_type="resnet", 178 | activation=F.relu, 179 | dropout_prob=0.0, 180 | use_batch_norm=False, 181 | activate_output=False, 182 | in_transform=None, 183 | ): 184 | super().__init__() 185 | if net_type == "MLP": 186 | net = nets.MLP( 187 | in_shape=[input_dim], 188 | out_shape=[output_dim], 189 | hidden_sizes=[hidden_features] * num_layers, 190 | activation=activation, 191 | activate_output=activate_output, 192 | ) 193 | 194 | elif net_type == "resnet": 195 | net = nets.ResidualNet( 196 | in_features=input_dim, 197 | out_features=output_dim, 198 | hidden_features=hidden_features, 199 | num_blocks=num_layers, 200 | activation=activation, 201 | dropout_probability=dropout_prob, 202 | use_batch_norm=use_batch_norm, 203 | ) 204 | elif net_type == "linear": 205 | net = LinearRegression(input_dim, output_dim) 206 | else: 207 | raise NotImplementedError 208 | self.network = net 209 | 210 | if in_transform is not None: 211 | self.in_transform = in_transform 212 | else: 213 | self.in_transform = nn.Identity() 214 | 215 | def forward(self, x): 216 | x = self.in_transform(x) 217 | x = self.network(x) 218 | return x 219 | 220 | 221 | class RegressionEnsemble: 222 | def __init__(self, neural_nets): 223 | if type(neural_nets) is not list: 224 | neural_nets = [neural_nets] 225 | self.n_ensemble = len(neural_nets) 226 | self.neural_nets = neural_nets 227 | 228 | def sample(self, x, sample_shape=None): 229 | # sample_shape is a dummy argument that makes it consistent with the other inference algorithms 230 | with torch.no_grad(): 231 | return torch.stack([net(x) for net in self.neural_nets]) 232 | 233 | 234 | class LinearRegression(torch.nn.Module): 235 | def __init__(self, input_dim, output_dim): 236 | super().__init__() 237 | self.linear = torch.nn.Linear(input_dim, output_dim) 238 | 239 | def forward(self, x): 240 | return self.linear(x) 241 | -------------------------------------------------------------------------------- /automind/inference/inferer.py: -------------------------------------------------------------------------------- 1 | # Defining the various inference algorithms, for both posterior and likelihood estimation, such that it can be called in the training script. 2 | # Includes inference instantiation, posterior building, and sampling functions. 3 | 4 | import pandas as pd 5 | import torch 6 | from sbi.inference import SNPE, SNLE, SNRE, MCMCPosterior 7 | from sbi.utils import mcmc_transform 8 | from automind.inference.algorithms.Regression import RegressionEnsemble 9 | from automind.inference.algorithms.GBI import ( 10 | GBInference, 11 | GBIPotential, 12 | mse_dist, 13 | ) 14 | from automind.utils import dist_utils 15 | 16 | 17 | def get_posterior_from_nn( 18 | neural_net, prior, algorithm, build_posterior_params=None, use_unity_prior=True 19 | ): 20 | """Build approximate posterior from neural net and prior using specified algorithm. 21 | 22 | Args: 23 | neural_net (nn.Module): Flow or ACE network. 
24 | prior (Distribution): Prior distribution. 25 | algorithm (str): Inference algorithm to use. 26 | build_posterior_params (dict, optional): Additional parameters for posterior building. Defaults to None. 27 | use_unity_prior (bool, optional): Whether to use a dummy Box[0,1] prior. Defaults to True. 28 | 29 | Returns: 30 | Distribution: Approximate posterior distribution. 31 | """ 32 | # Need to build this dummy prior if algorithms were trained with min-max standardized theta 33 | from sbi.utils import BoxUniform 34 | 35 | prior_unity = BoxUniform( 36 | torch.zeros( 37 | len( 38 | prior.names, 39 | ) 40 | ), 41 | torch.ones( 42 | len( 43 | prior.names, 44 | ) 45 | ), 46 | ) 47 | 48 | prior_use = prior_unity if use_unity_prior else prior 49 | 50 | if algorithm == "NPE": 51 | posterior = SNPE().build_posterior(neural_net, prior_use) 52 | 53 | elif algorithm == "NLE": 54 | posterior = SNLE().build_posterior( 55 | neural_net, 56 | prior_use, 57 | mcmc_method=build_posterior_params["mcmc_method"], 58 | mcmc_parameters=build_posterior_params["mcmc_parameters"], 59 | ) 60 | 61 | elif algorithm == "NRE": 62 | posterior = SNRE().build_posterior( 63 | neural_net, 64 | prior_use, 65 | mcmc_method=build_posterior_params["mcmc_method"], 66 | mcmc_parameters=build_posterior_params["mcmc_parameters"], 67 | ) 68 | 69 | elif (algorithm == "REGR-R") or (algorithm == "REGR-F"): 70 | posterior = RegressionEnsemble(neural_net) 71 | 72 | elif algorithm == "ACE": 73 | # prior log-prob is -55 for original, 0 for dummy 74 | inference = GBInference(prior_use, mse_dist) 75 | 76 | # Get generalized log-likelihood function 77 | genlike_func = inference.build_amortized_GLL(neural_net) 78 | potential_fn = GBIPotential( 79 | prior_use, genlike_func, beta=build_posterior_params["ace_beta"] 80 | ) 81 | 82 | # Make posterior 83 | theta_transform = mcmc_transform(prior_use) 84 | posterior = MCMCPosterior( 85 | potential_fn, 86 | theta_transform=theta_transform, 87 | proposal=prior_use, 88 | method=build_posterior_params["mcmc_method"], 89 | **build_posterior_params["mcmc_parameters"], 90 | ) 91 | # Save beta in object 92 | posterior.beta = build_posterior_params["ace_beta"] 93 | 94 | # Inherit stuff from prior 95 | posterior = dist_utils.pass_info_from_prior( 96 | prior, 97 | posterior, 98 | [ 99 | "names", 100 | "marginals", 101 | "b2_units", 102 | "as_dict", 103 | "x_bounds_and_transforms", 104 | "x_standardizing_func", 105 | ], 106 | ) 107 | if not hasattr(posterior, "prior"): 108 | posterior.prior = prior_use 109 | 110 | posterior.prior_original = prior 111 | return posterior 112 | 113 | 114 | def append_samples(df_to_append_to, new_samples, sample_type, param_names): 115 | """Append samples to existing dataframe. 116 | 117 | Args: 118 | df_to_append_to (DataFrame): Existing dataframe. 119 | new_samples (DataFrame): New samples to append. 120 | sample_type (str): Type of samples, e.g., 'NPE', 'NPE_oversample'. 121 | param_names (str): Parameter names. 122 | 123 | Returns: 124 | DataFrame: Updated dataframe. 125 | """ 126 | df_cur = pd.DataFrame(columns=df_to_append_to.columns) 127 | df_cur[param_names] = new_samples 128 | df_cur["inference.type"] = sample_type 129 | return pd.concat((df_to_append_to, df_cur), axis=0) 130 | 131 | 132 | def sample_from_posterior( 133 | posterior, prior, num_samples, x_o, cfg_algorithm, theta_o=None, batch_run=None 134 | ): 135 | """Sample from posterior distribution. 136 | 137 | Args: 138 | posterior (Distribution): Posterior distribution. 139 | prior (Distribution): Prior distribution. 
140 | num_samples (int): Number of samples to draw. 141 | x_o (tensor): Observed data to condition on. 142 | cfg_algorithm (dict): Sampling algorithm configurations. 143 | theta_o (tensor, optional): True parameter set. Defaults to None. 144 | batch_run (list, optional): Batch/random seed ID. Defaults to None. 145 | 146 | Returns: 147 | DataFrame: Posterior samples. 148 | """ 149 | df_posterior_samples = pd.DataFrame(columns=["inference.type"] + posterior.names) 150 | samples_dict = {} 151 | # Append GT if there is 152 | if theta_o is not None: 153 | df_posterior_samples = append_samples( 154 | df_posterior_samples, 155 | theta_o[None, :].numpy().astype(float), 156 | "gt_resim", 157 | posterior.names, 158 | ) 159 | if "Regression" in str(posterior): 160 | # Custom sampling for RegressionEnsemble and pass through 161 | posterior_samples = posterior.sample(x=x_o) 162 | sample_type_samples = f'{cfg_algorithm["name"]}_samples' 163 | theta, _ = dist_utils.standardize_theta( 164 | posterior_samples, prior, destandardize=True 165 | ) 166 | samples_dict[sample_type_samples] = theta 167 | # Just return all ones for log_probs 168 | log_prob_fn = lambda theta, x: ( 169 | torch.zeros((1,)) if len(theta.shape) == 1 else torch.zeros(theta.shape[0]) 170 | ) 171 | else: 172 | # Get log_prob / potential function for MCMC or DirectPosterior 173 | log_prob_fn = ( 174 | posterior.potential if "MCMC" in str(posterior) else posterior.log_prob 175 | ) 176 | 177 | # Posterior samples, with option to oversample high-prob 178 | if cfg_algorithm["oversample_factor"] > 1: 179 | samples_ = posterior.sample( 180 | x=x_o, sample_shape=(num_samples * cfg_algorithm["oversample_factor"],) 181 | ) 182 | posterior_samples = samples_[ 183 | torch.sort(log_prob_fn(samples_, x_o), descending=True)[1] 184 | ][:num_samples] 185 | sample_type_samples = f'{cfg_algorithm["name"]}_samples_prune_{cfg_algorithm["oversample_factor"]}' 186 | else: 187 | posterior_samples = posterior.sample(x=x_o, sample_shape=(num_samples,)) 188 | sample_type_samples = f'{cfg_algorithm["name"]}_samples' 189 | 190 | theta, _ = dist_utils.standardize_theta( 191 | posterior_samples, prior, destandardize=True 192 | ) 193 | samples_dict[sample_type_samples] = theta 194 | 195 | # MAP sample 196 | if cfg_algorithm["do_sample_map"]: 197 | posterior.set_default_x(x_o) 198 | posterior_map = posterior.map() 199 | theta_map = ( 200 | dist_utils.standardize_theta(posterior_map, prior, destandardize=True)[ 201 | 0 202 | ][None, :] 203 | .numpy() 204 | .astype(float) 205 | ) 206 | sample_type = f'{cfg_algorithm["name"]}_map' 207 | df_posterior_samples = append_samples( 208 | df_posterior_samples, 209 | theta_map, 210 | sample_type, 211 | posterior.names, 212 | ) 213 | samples_dict[sample_type] = theta_map 214 | 215 | # posterior sample mean 216 | if cfg_algorithm["do_sample_pmean"]: 217 | posterior_mean = posterior_samples.mean(0) 218 | theta_pmean = ( 219 | dist_utils.standardize_theta(posterior_mean, prior, destandardize=True)[0][ 220 | None, : 221 | ] 222 | .numpy() 223 | .astype(float) 224 | ) 225 | sample_type = f'{cfg_algorithm["name"]}_mean' 226 | df_posterior_samples = append_samples( 227 | df_posterior_samples, 228 | theta_pmean, 229 | sample_type, 230 | posterior.names, 231 | ) 232 | samples_dict[sample_type] = theta_pmean 233 | 234 | # Append posterior samples last 235 | df_posterior_samples = append_samples( 236 | df_posterior_samples, 237 | theta.numpy().astype(float), 238 | sample_type_samples, 239 | posterior.names, 240 | ) 241 | 242 | # Append log_prob 
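    # Samples are mapped back to min-max standardized (unit-box) space, since the
    # networks were trained on standardized theta, before evaluating the
    # log-probability / potential.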
243 | theta_scaled = dist_utils.standardize_theta( 244 | torch.Tensor(df_posterior_samples[posterior.names].values.astype(float)), prior 245 | )[0] 246 | df_posterior_samples.insert(1, "infer.log_prob", log_prob_fn(theta_scaled, x=x_o)) 247 | 248 | # Append xo info 249 | if batch_run is not None: 250 | df_posterior_samples.insert(0, "x_o", f"{batch_run[0]}_{batch_run[1]}") 251 | 252 | return df_posterior_samples, samples_dict 253 | -------------------------------------------------------------------------------- /automind/inference/trainers.py: -------------------------------------------------------------------------------- 1 | # Defining the various inference algorithms, for both posterior and likelihood estimation, such that it can be called in the training script. 2 | # Includes inference instantiation, posterior building, and sampling functions. 3 | 4 | import torch 5 | from torch import nn 6 | from sbi.inference import SNPE, SNLE, SNRE 7 | from sbi.utils import get_nn_models, RestrictionEstimator, process_prior 8 | from sbi.neural_nets.embedding_nets import FCEmbedding 9 | 10 | from automind.inference.algorithms.Regression import RegressionInference 11 | from automind.inference.algorithms.GBI import GBInference, get_distance_function 12 | from automind.utils import data_utils 13 | 14 | 15 | ## NPE 16 | def train_NPE(theta, x, prior, cfg): 17 | # some settings: 18 | # cfg.density_estimator: the type of density estimator to use, MAF, NSF, MoG, etc. 19 | # cfg.sigmoid_theta: apply sigmoid on theta to keep into prior range. 20 | # cfg.z_score_x: 'none', 'independent' or 'structured'. 21 | # cfg.z_score_theta: 'none', 'independent' or 'structured'. 22 | # cfg.use_embedding_net: whether to use embedding net or not, and its dimensions. 23 | embedding_net_x = _get_embedding_net(cfg, input_dim=x.shape[1]) 24 | net = get_nn_models.posterior_nn( 25 | model=cfg.density_estimator, 26 | sigmoid_theta=cfg.sigmoid_theta, 27 | prior=prior, 28 | z_score_x=cfg.z_score_x, 29 | z_score_theta=cfg.z_score_theta, 30 | embedding_net=embedding_net_x, 31 | hidden_features=cfg.hidden_features, 32 | num_transforms=cfg.num_transforms, 33 | num_bins=cfg.num_bins, 34 | ) 35 | inference = SNPE(prior=prior, density_estimator=net) 36 | density_estimator = inference.append_simulations(theta, x).train() 37 | return inference, density_estimator 38 | 39 | 40 | def train_Restrictor(theta, x, prior, cfg): 41 | # Train restriction estimator classifier first 42 | restriction_estimator = RestrictionEstimator( 43 | prior=prior, 44 | hidden_features=cfg.hidden_features_restrictor, 45 | num_blocks=cfg.num_blocks_restrictor, 46 | z_score=cfg.z_score_theta, 47 | ) 48 | 49 | restriction_estimator.append_simulations(theta, x) 50 | classifier = restriction_estimator.train() 51 | proposal = restriction_estimator.restrict_prior() 52 | return proposal, classifier 53 | 54 | 55 | def train_RestrictorNPE(theta, x, prior, cfg): 56 | # Train restriction estimator classifier first 57 | restriction_estimator = RestrictionEstimator( 58 | prior=prior, 59 | hidden_features=cfg.hidden_features_restrictor, 60 | num_blocks=cfg.num_blocks_restrictor, 61 | z_score=cfg.z_score_theta, 62 | ) 63 | 64 | restriction_estimator.append_simulations(theta, x) 65 | classifier = restriction_estimator.train() 66 | # proposal = restriction_estimator.restrict_prior() 67 | proposal = process_prior(restriction_estimator.restrict_prior())[0] 68 | 69 | inference, density_estimator = train_NPE(theta, x, proposal, cfg) 70 | return inference, density_estimator 71 | 72 | 73 | 
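# --- Illustrative usage sketch (not part of the original module) -------------
# Minimal example of how train_NPE above might be called. The cfg fields mirror
# the attributes read by train_NPE and _get_embedding_net; the concrete values
# and the toy "simulator" below are placeholders, not the project's actual
# configuration.
def _example_train_NPE_usage():
    from types import SimpleNamespace
    from sbi.utils import BoxUniform

    prior = BoxUniform(torch.zeros(3), torch.ones(3))
    theta = prior.sample((500,))
    x = theta + 0.05 * torch.randn_like(theta)  # toy observations

    cfg = SimpleNamespace(
        density_estimator="nsf",
        sigmoid_theta=False,
        z_score_x="independent",
        z_score_theta="independent",
        use_embedding_net=False,
        hidden_features=32,
        num_transforms=2,
        num_bins=8,
    )
    return train_NPE(theta, x, prior, cfg)
# -----------------------------------------------------------------------------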
## NLE 74 | def train_NLE(theta, x, prior, cfg): 75 | net = get_nn_models.likelihood_nn( 76 | model=cfg.density_estimator, 77 | z_score_x=cfg.z_score_x, 78 | z_score_theta=cfg.z_score_theta, 79 | num_transforms=cfg.num_transforms, 80 | num_bins=cfg.num_bins, 81 | ) 82 | inference = SNLE(prior=prior, density_estimator=cfg.density_estimator) 83 | density_estimator = inference.append_simulations(theta, x).train() 84 | return inference, density_estimator 85 | 86 | 87 | ## NRE 88 | def train_NRE(theta, x, prior, cfg): 89 | embedding_net_x = _get_embedding_net(cfg, input_dim=x.shape[1]) 90 | net = get_nn_models.classifier_nn( 91 | model=cfg.classifier, 92 | z_score_theta=cfg.z_score_theta, 93 | z_score_x=cfg.z_score_x, 94 | embedding_net_x=embedding_net_x, 95 | ) 96 | inference = SNRE(prior=prior, classifier=net) 97 | density_estimator = inference.append_simulations(theta, x).train() 98 | return inference, density_estimator 99 | 100 | 101 | # Regression 102 | def train_Regression(theta, x, prior, cfg): 103 | inference = RegressionInference( 104 | theta, 105 | x, 106 | predict_theta=(cfg.name == "REGR-R"), 107 | num_layers=cfg.num_layers, 108 | num_hidden=cfg.num_hidden, 109 | net_type=cfg.net_type, 110 | ) 111 | ensemble = [] 112 | for i in range(cfg.n_ensemble): 113 | print("--------------------") 114 | print(f"Training {i+1} of {cfg.n_ensemble} networks...") 115 | inference.init_network(seed=i) 116 | neural_net = inference.train( 117 | training_batch_size=cfg.training_batch_size, 118 | max_n_epochs=cfg.max_n_epochs, 119 | stop_after_counter_reaches=cfg.stop_after_counter_reaches, 120 | validation_fraction=cfg.validation_fraction, 121 | print_every_n=cfg.print_every_n, 122 | plot_losses=cfg.plot_losses, 123 | ) 124 | ensemble.append(neural_net) 125 | if cfg.n_ensemble == 1: 126 | ensemble = ensemble[0] 127 | return inference, ensemble 128 | 129 | 130 | def _concatenate_xs(x1, x2): 131 | """Concatenate two tensors along the first dimension.""" 132 | return torch.concat([x1, x2], dim=0) 133 | 134 | 135 | # GBI ACE 136 | def train_ACE(theta, x, prior, cfg): 137 | # Augment data with noise. 138 | if type(cfg.n_augmented_x) is int: 139 | n_augmented_x = cfg.n_augmented_x 140 | elif type(cfg.n_augmented_x) is float: 141 | n_augmented_x = int(cfg.n_augmented_x * x.shape[0]) 142 | else: 143 | n_augmented_x = 0 144 | 145 | x_aug = x[torch.randint(x.shape[0], size=(n_augmented_x,))] 146 | x_aug = x_aug + torch.randn(x_aug.shape) * x.std(dim=0) * cfg.noise_level 147 | x_target = _concatenate_xs(x, x_aug) 148 | 149 | if cfg.train_with_obs: 150 | # Append observations. 151 | x_obs = data_utils.load_pickled(cfg.obs_data_path) 152 | # Put all together. 
153 | x_target = _concatenate_xs(x_target, x_obs) 154 | 155 | distance_func = get_distance_function(cfg.dist_func) 156 | inference = GBInference( 157 | prior=prior, 158 | distance_func=distance_func, 159 | do_precompute_distances=cfg.do_precompute_distances, 160 | include_bad_sims=cfg.include_bad_sims, 161 | nan_dists_replacement=cfg.nan_dists_replacement, 162 | ) 163 | inference = inference.append_simulations(theta, x, x_target, n_dists_precompute=5.0) 164 | inference.initialize_distance_estimator( 165 | num_layers=cfg.num_layers, 166 | num_hidden=cfg.num_hidden, 167 | net_type=cfg.net_type, 168 | positive_constraint_fn=cfg.positive_constraint_fn, 169 | ) 170 | distance_net = inference.train( 171 | training_batch_size=cfg.training_batch_size, 172 | max_n_epochs=cfg.max_epochs, 173 | validation_fraction=cfg.validation_fraction, 174 | n_train_per_theta=cfg.n_train_per_theta, 175 | n_val_per_theta=cfg.n_val_per_theta, 176 | stop_after_counter_reaches=cfg.stop_after_counter_reaches, 177 | print_every_n=cfg.print_every_n, 178 | plot_losses=cfg.plot_losses, 179 | ) 180 | return inference, distance_net 181 | 182 | 183 | ## embedding nets 184 | def _get_embedding_net(cfg, input_dim): 185 | if cfg.use_embedding_net: 186 | embedding_net = FCEmbedding( 187 | input_dim=input_dim, 188 | output_dim=cfg.embedding_net_outdim, 189 | num_layers=cfg.embedding_net_layers, 190 | num_hiddens=cfg.embedding_net_num_hiddens, 191 | ) 192 | else: 193 | embedding_net = nn.Identity() 194 | return embedding_net 195 | -------------------------------------------------------------------------------- /automind/sim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mackelab/automind/e5e3cd999b39bf2b9bd365e43c6d379d245d3e77/automind/sim/__init__.py -------------------------------------------------------------------------------- /automind/sim/b2_interface.py: -------------------------------------------------------------------------------- 1 | ### Brian2 interface functions, e.g., remove / add units, etc. 2 | import brian2 as b2 3 | 4 | 5 | def _deunitize_spiketimes(spiketrains_dict): 6 | """ 7 | Takes a spiketrain dictionary and return a copy without b2 time units. 8 | Output spiketrains_dict are assumed to be in units of seconds. 9 | 10 | Args: 11 | spiketrains_dict (dict): Dictionary of spiketimes in the form {cell_index: array(spiketimes)} 12 | 13 | Returns: 14 | dict: Dictionary of spiketimes without units. 15 | 16 | """ 17 | return {k: v / b2.second for k, v in spiketrains_dict.items()} 18 | 19 | 20 | def _deunitize_rates(t, rate): 21 | """Deunitize timestamps and rate.""" 22 | return t / b2.second, rate / b2.Hz 23 | 24 | 25 | def set_adaptive_vcut(param_dict): 26 | param_dict["v_cut"] = param_dict["v_thresh"] + 5 * param_dict["delta_T"] 27 | 28 | 29 | def parse_timeseries(net_collect, record_defs): 30 | """Parses recordings of continuous variables 31 | 32 | Args: 33 | net_collect (brian2 net): collector that has all the monitors. 34 | record_defs (dict): nested dict that defines recorded populations & variables. 
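            e.g., {"exc": {"rate": True, "spikes": False, "trace": (["I"], np.arange(10))}},
            as in default_configs.SIM_SETTINGS_DEFAULTS["record_defs"].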
35 | 36 | Returns: 37 | dict: time series 38 | """ 39 | timeseries = {} 40 | for pop_name in record_defs.keys(): 41 | if record_defs[pop_name]["rate"]: 42 | # deunitize rate 43 | t_rate = net_collect[pop_name + "_rate"].t / b2.second 44 | timeseries[pop_name + "_rate"] = ( 45 | net_collect[pop_name + "_rate"].rate / b2.Hz 46 | ) 47 | timeseries["t_rate"] = t_rate 48 | 49 | if record_defs[pop_name]["trace"]: 50 | var_names = record_defs[pop_name]["trace"][0] 51 | for vn in var_names: 52 | # taking the average of the recorded state variables 53 | timeseries[pop_name + "_" + vn] = getattr( 54 | net_collect[pop_name + "_trace"], vn + "_" 55 | ).mean(0) 56 | t = getattr(net_collect[pop_name + "_trace"], "t") / b2.second 57 | timeseries["t"] = t 58 | 59 | return timeseries 60 | 61 | 62 | def strip_b2_units(theta_samples, theta_priors): 63 | """Strips brian2 units from theta samples. 64 | 65 | Args: 66 | theta_samples (dict): Dictionary of parameter samples. 67 | theta_priors (dict): Prior, in the form of a dictionary. 68 | 69 | Returns: 70 | dict: Dictionary of parameter samples without units. 71 | """ 72 | # this is pretty much only for converting to and saving as dataframe 73 | # WILL break if data not in array but I don't really care at this point 74 | theta_samples_unitless = {} 75 | for k, v in theta_samples.items(): 76 | if k in theta_priors.keys(): 77 | theta_samples_unitless[k] = (v / theta_priors[k]["unit"]).astype(v.dtype) 78 | else: 79 | theta_samples_unitless[k] = v 80 | 81 | return theta_samples_unitless 82 | 83 | 84 | def clear_b2_cache(cache_path=None): 85 | """Clears brian2 cache. 86 | 87 | Args: 88 | cache_path (str, optional): location of cache. 89 | """ 90 | if cache_path is None: 91 | try: 92 | b2.clear_cache("cython") 93 | print("cache cleared.") 94 | except: 95 | print("cache non-existent.") 96 | else: 97 | import shutil 98 | 99 | try: 100 | shutil.rmtree(cache_path) 101 | print(f"cache cleared: {cache_path}.") 102 | except: 103 | print("cache non-existent.") 104 | 105 | 106 | def set_b2_cache(cache_dir, file_lock=False): 107 | """Set brian2 cache, optionally apply file lock.""" 108 | b2.prefs.codegen.runtime.cython.cache_dir = cache_dir 109 | b2.prefs.codegen.runtime.cython.multiprocess_safe = file_lock 110 | -------------------------------------------------------------------------------- /automind/sim/b2_models.py: -------------------------------------------------------------------------------- 1 | ### Brian2 models for network construction. 
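# The networks below implement adaptive exponential (AdEx) integrate-and-fire
# neurons with conductance-based (COBA) synapses:
#   C dv/dt     = -g_L (v - v_rest) + g_L * delta_T * exp((v - v_thresh) / delta_T) - w + I
#   tau_w dw/dt = a (v - v_rest) - w,   with w -> w + b at each spike
#   I = I_bias + ge (E_ge - v) + gi (E_gi - v),   dge/dt = -ge / tau_ge,   dgi/dt = -gi / tau_gi
# When delta_T == 0 the exponential term is dropped and the neuron reduces to an
# adaptive LIF (adLIF); see adex_coba_eq / adlif_coba_eq below.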
2 | import brian2 as b2 3 | import numpy as np 4 | 5 | 6 | def adaptive_exp_net(all_param_dict): 7 | """Adaptive exponential integrate-and-fire network.""" 8 | # separate parameter dictionaries 9 | param_dict_net = all_param_dict["params_net"] 10 | param_dict_settings = all_param_dict["params_settings"] 11 | 12 | # set random seeds 13 | b2.seed(param_dict_settings["random_seed"]) 14 | np.random.seed(param_dict_settings["random_seed"]) 15 | b2.defaultclock.dt = param_dict_settings["dt"] 16 | 17 | param_dict_neuron_E = all_param_dict["params_Epop"] 18 | # check if there is inhibition 19 | has_inh = False if param_dict_net["exc_prop"] == 1 else True 20 | if has_inh: 21 | param_dict_neuron_I = all_param_dict["params_Ipop"] 22 | 23 | #### NETWORK CONSTRUCTION ############ 24 | ###################################### 25 | ### define neuron equation 26 | adex_coba_eq = """dv/dt = (-g_L * (v - v_rest) + g_L * delta_T * exp((v - v_thresh)/delta_T) - w + I)/C : volt (unless refractory)""" 27 | adlif_coba_eq = ( 28 | """dv/dt = (-g_L * (v - v_rest) - w + I)/C : volt (unless refractory)""" 29 | ) 30 | 31 | network_eqs = """ 32 | dw/dt = (-w + a * (v - v_rest))/tau_w : amp 33 | dge/dt = -ge / tau_ge : siemens 34 | dgi/dt = -gi / tau_gi : siemens 35 | Ie = ge * (E_ge - v): amp 36 | Ii = gi * (E_gi - v): amp 37 | I = I_bias + Ie + Ii : amp 38 | """ 39 | 40 | ### get cell counts 41 | N_pop, exc_prop = param_dict_net["N_pop"], param_dict_net["exc_prop"] 42 | N_exc, N_inh = int(N_pop * exc_prop), int(N_pop * (1 - exc_prop)) 43 | 44 | ### make neuron populations, set initial values and connect poisson inputs ### 45 | # make adlif if delta_t is 0, otherwise adex 46 | neuron_eq = ( 47 | adlif_coba_eq + network_eqs 48 | if param_dict_neuron_E["delta_T"] == 0 49 | else adex_coba_eq + network_eqs 50 | ) 51 | E_pop = b2.NeuronGroup( 52 | N_exc, 53 | model=neuron_eq, 54 | threshold="v > v_cut", 55 | reset="v = v_reset; w+=b", 56 | refractory=param_dict_neuron_E["t_refrac"], 57 | method=param_dict_settings["integ_method"], 58 | namespace=param_dict_neuron_E, 59 | ) 60 | E_pop.v = ( 61 | param_dict_neuron_E["v_rest"] 62 | + np.random.randn(N_exc) * param_dict_neuron_E["v_0_offset"][1] 63 | + param_dict_neuron_E["v_0_offset"][0] 64 | ) 65 | 66 | ### TO DO: also randomly initialize w to either randint(?)*b or randn*(v-v_rest)*a 67 | 68 | poisson_input_E = b2.PoissonInput( 69 | target=E_pop, 70 | target_var="ge", 71 | N=param_dict_neuron_E["N_poisson"], 72 | rate=param_dict_neuron_E["poisson_rate"], 73 | weight=param_dict_neuron_E["Q_poisson"], 74 | ) 75 | 76 | if has_inh: 77 | # make adlif if delta_t is 0, otherwise adex 78 | neuron_eq = ( 79 | adlif_coba_eq + network_eqs 80 | if param_dict_neuron_E["delta_T"] == 0 81 | else adex_coba_eq + network_eqs 82 | ) 83 | I_pop = b2.NeuronGroup( 84 | N_inh, 85 | model=neuron_eq, 86 | threshold="v > v_cut", 87 | reset="v = v_reset; w+=b", 88 | refractory=param_dict_neuron_I["t_refrac"], 89 | method=param_dict_settings["integ_method"], 90 | namespace=param_dict_neuron_I, 91 | ) 92 | I_pop.v = ( 93 | param_dict_neuron_I["v_rest"] 94 | + np.random.randn(N_inh) * param_dict_neuron_I["v_0_offset"][1] 95 | + param_dict_neuron_I["v_0_offset"][0] 96 | ) 97 | # check for poisson input 98 | if ( 99 | param_dict_neuron_I["N_poisson"] == 0 100 | or param_dict_neuron_I["poisson_rate"] == 0 101 | ): 102 | poisson_input_I = [] 103 | else: 104 | poisson_input_I = b2.PoissonInput( 105 | target=I_pop, 106 | target_var="ge", 107 | N=param_dict_neuron_I["N_poisson"], 108 | 
rate=param_dict_neuron_I["poisson_rate"], 109 | weight=param_dict_neuron_I["Q_poisson"], 110 | ) 111 | 112 | ### make and connect recurrent synapses ### 113 | syn_e2e = b2.Synapses( 114 | source=E_pop, 115 | target=E_pop, 116 | model="q: siemens", 117 | on_pre="ge_post += Q_ge", 118 | delay=param_dict_neuron_E["tdelay2e"], 119 | namespace=param_dict_neuron_E, 120 | ) 121 | syn_e2e.connect("i!=j", p=param_dict_net["p_e2e"]) 122 | if has_inh: 123 | syn_e2i = b2.Synapses( 124 | source=E_pop, 125 | target=I_pop, 126 | model="q: siemens", 127 | on_pre="ge_post += Q_ge", 128 | delay=param_dict_neuron_E["tdelay2i"], 129 | namespace=param_dict_neuron_I, 130 | ) 131 | syn_e2i.connect("i!=j", p=param_dict_net["p_e2i"]) 132 | syn_i2e = b2.Synapses( 133 | source=I_pop, 134 | target=E_pop, 135 | model="q: siemens", 136 | on_pre="gi_post += Q_gi", 137 | delay=param_dict_neuron_I["tdelay2e"], 138 | namespace=param_dict_neuron_E, 139 | ) 140 | syn_i2e.connect("i!=j", p=param_dict_net["p_i2e"]) 141 | syn_i2i = b2.Synapses( 142 | source=I_pop, 143 | target=I_pop, 144 | model="q: siemens", 145 | on_pre="gi_post += Q_gi", 146 | delay=param_dict_neuron_I["tdelay2i"], 147 | namespace=param_dict_neuron_I, 148 | ) 149 | syn_i2i.connect("i!=j", p=param_dict_net["p_i2i"]) 150 | 151 | ### define monitors ### 152 | rate_monitors, spike_monitors, trace_monitors = [], [], [] 153 | rec_defs = param_dict_settings["record_defs"] 154 | dt_ts = param_dict_settings["dt_ts"] # recording interval for continuous variables 155 | zipped_pops = ( 156 | zip(["exc", "inh"], [E_pop, I_pop]) if has_inh else zip(["exc"], [E_pop]) 157 | ) 158 | for pop_name, pop in zipped_pops: 159 | if pop_name in rec_defs.keys(): 160 | if rec_defs[pop_name]["rate"] is not False: 161 | rate_monitors.append( 162 | b2.PopulationRateMonitor(pop, name=pop_name + "_rate") 163 | ) 164 | if rec_defs[pop_name]["spikes"] is not False: 165 | rec_idx = ( 166 | np.arange(rec_defs[pop_name]["spikes"]) 167 | if type(rec_defs[pop_name]["spikes"]) is int 168 | else rec_defs[pop_name]["spikes"] 169 | ) 170 | spike_monitors.append( 171 | b2.SpikeMonitor(pop[rec_idx], name=pop_name + "_spikes") 172 | ) 173 | if rec_defs[pop_name]["trace"] is not False: 174 | rec_idx = ( 175 | np.arange(rec_defs[pop_name]["trace"][1]) 176 | if type(rec_defs[pop_name]["trace"][1]) is int 177 | else rec_defs[pop_name]["trace"][1] 178 | ) 179 | trace_monitors.append( 180 | b2.StateMonitor( 181 | pop, 182 | rec_defs[pop_name]["trace"][0], 183 | record=rec_idx, 184 | name=pop_name + "_trace", 185 | dt=dt_ts, 186 | ) 187 | ) 188 | 189 | ### collect into network object and return ### 190 | net_collect = b2.Network(b2.collect()) # magic collect all groups 191 | monitors = [rate_monitors, spike_monitors, trace_monitors] 192 | net_collect.add(monitors) 193 | return net_collect 194 | 195 | 196 | def make_clustered_edges(N_neurons, n_clusters, clusters_per_neuron, sort_by_cluster): 197 | """Make clustered connection graph. 198 | 199 | Args: 200 | N_neurons (int): Number of neurons. 201 | n_clusters (int): Number of clusters. 202 | clusters_per_neuron (int): Number of clusters per neuron. 203 | sort_by_cluster (bool): Sort by cluster. 204 | 205 | Returns: 206 | tuple: Membership, shared membership, connections in, and connections out. 
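            membership has shape (N_neurons, clusters_per_neuron) and holds random
            cluster ids; shared_membership is an (N_neurons, N_neurons) 0/1 matrix;
            conn_in / conn_out are (2, n_edges) index arrays for pairs that do /
            do not share a cluster (self-connections removed).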
207 | """ 208 | # assign cluster membership randomly 209 | if clusters_per_neuron > n_clusters: 210 | print("More clusters per neuron than exists.") 211 | return 212 | membership = np.random.randint( 213 | int(n_clusters), size=(N_neurons, int(clusters_per_neuron)) 214 | ) 215 | 216 | if sort_by_cluster: 217 | # sort to show block-wise membership 218 | membership = membership[np.argsort(membership[:, 0]), :] 219 | 220 | # find pairs that share membership 221 | shared_membership = ( 222 | np.array([np.any(m_i == membership, axis=1) for m_i in membership]) 223 | ).astype(int) 224 | 225 | # connections not in the same clusters 226 | conn_out = np.array((1 - shared_membership).nonzero()) 227 | # connections in the same clusters 228 | conn_in = np.array((shared_membership).nonzero()) 229 | # remove all i==j to get the correct # of edges as percentage of possible 230 | conn_in = conn_in[:, conn_in[0, :] != conn_in[1, :]] 231 | 232 | return membership, shared_membership, conn_in, conn_out 233 | 234 | 235 | def draw_prob_connections(conn, p): 236 | """Draw probabilistic connections given connectivity matrix shape.""" 237 | # randomly choose p * C connections 238 | # where C is the number of possible connections 239 | ij = np.sort( 240 | np.random.choice( 241 | conn.shape[1], np.ceil(conn.shape[1] * p).astype(int), replace=False 242 | ) 243 | ) 244 | return conn[:, ij] 245 | 246 | 247 | def make_clustered_network( 248 | N_neurons, n_clusters, clusters_per_neuron, p_in, p_out, sort_by_cluster=False 249 | ): 250 | """Make clustered network. 251 | 252 | Args: 253 | N_neurons (int): Number of neurons. 254 | n_clusters (int): Number of clusters. 255 | clusters_per_neuron (int): Number of clusters per neuron. 256 | p_in (float): probability of incoming connections. 257 | p_out (float): probability of outgoing connections. 258 | sort_by_cluster (bool, optional): Sort by cluster. Defaults to False. 259 | 260 | Returns: 261 | tuple: Membership, shared membership, connections in, and connections out. 
262 | """ 263 | membership, shared_membership, conn_in, conn_out = make_clustered_edges( 264 | N_neurons, n_clusters, clusters_per_neuron, sort_by_cluster 265 | ) 266 | conn_in = draw_prob_connections(conn_in, p_in) 267 | conn_out = draw_prob_connections(conn_out, p_out) 268 | return membership, shared_membership, conn_in, conn_out 269 | 270 | 271 | def adaptive_exp_net_clustered(all_param_dict): 272 | """Adaptive exponential integrate-and-fire network with clustered connections.""" 273 | # separate parameter dictionaries 274 | param_dict_net = all_param_dict["params_net"] 275 | param_dict_settings = all_param_dict["params_settings"] 276 | 277 | # set random seeds 278 | b2.seed(param_dict_settings["random_seed"]) 279 | np.random.seed(param_dict_settings["random_seed"]) 280 | b2.defaultclock.dt = param_dict_settings["dt"] 281 | 282 | param_dict_neuron_E = all_param_dict["params_Epop"] 283 | # check if there is inhibition 284 | has_inh = False if param_dict_net["exc_prop"] == 1 else True 285 | if has_inh: 286 | param_dict_neuron_I = all_param_dict["params_Ipop"] 287 | 288 | #### NETWORK CONSTRUCTION ############ 289 | ###################################### 290 | ### define neuron equation 291 | adex_coba_eq = """dv/dt = (-g_L * (v - v_rest) + g_L * delta_T * exp((v - v_thresh)/delta_T) - w + I)/C : volt (unless refractory)""" 292 | adlif_coba_eq = ( 293 | """dv/dt = (-g_L * (v - v_rest) - w + I)/C : volt (unless refractory)""" 294 | ) 295 | 296 | network_eqs = """ 297 | dw/dt = (-w + a * (v - v_rest))/tau_w : amp 298 | dge/dt = -ge / tau_ge : siemens 299 | dgi/dt = -gi / tau_gi : siemens 300 | Ie = ge * (E_ge - v): amp 301 | Ii = gi * (E_gi - v): amp 302 | I = I_bias + Ie + Ii : amp 303 | """ 304 | 305 | ### get cell counts 306 | N_pop, exc_prop = param_dict_net["N_pop"], param_dict_net["exc_prop"] 307 | N_exc, N_inh = int(N_pop * exc_prop), int(N_pop * (1 - exc_prop)) 308 | 309 | ### make neuron populations, set initial values and connect poisson inputs ### 310 | # make adlif if delta_t is 0, otherwise adex 311 | neuron_eq = ( 312 | adlif_coba_eq + network_eqs 313 | if param_dict_neuron_E["delta_T"] == 0 314 | else adex_coba_eq + network_eqs 315 | ) 316 | E_pop = b2.NeuronGroup( 317 | N_exc, 318 | model=neuron_eq, 319 | threshold="v > v_cut", 320 | reset="v = v_reset; w+=b", 321 | refractory=param_dict_neuron_E["t_refrac"], 322 | method=param_dict_settings["integ_method"], 323 | namespace=param_dict_neuron_E, 324 | name="Epop", 325 | ) 326 | E_pop.v = ( 327 | param_dict_neuron_E["v_rest"] 328 | + np.random.randn(N_exc) * param_dict_neuron_E["v_0_offset"][1] 329 | + param_dict_neuron_E["v_0_offset"][0] 330 | ) 331 | 332 | ### SUBSET INPUT 333 | # define subset of E cells that receive input, to model 334 | # spontaneously active cells (which are a small proportion) 335 | N_igniters = int(param_dict_neuron_E["p_igniters"] * N_exc) 336 | # has to be a contiguous chunk 337 | E_igniters = E_pop[:N_igniters] 338 | 339 | # connect poisson input only to igniter neurons 340 | poisson_input_E = b2.PoissonInput( 341 | target=E_igniters, 342 | target_var="ge", 343 | N=param_dict_neuron_E["N_poisson"], 344 | rate=param_dict_neuron_E["poisson_rate"], 345 | weight=param_dict_neuron_E["Q_poisson"], 346 | ) 347 | 348 | #### resolve inhibitory population 349 | if has_inh: 350 | # make adlif if delta_t is 0, otherwise adex 351 | neuron_eq = ( 352 | adlif_coba_eq + network_eqs 353 | if param_dict_neuron_E["delta_T"] == 0 354 | else adex_coba_eq + network_eqs 355 | ) 356 | I_pop = b2.NeuronGroup( 357 | N_inh, 
358 | model=neuron_eq, 359 | threshold="v > v_cut", 360 | reset="v = v_reset; w+=b", 361 | refractory=param_dict_neuron_I["t_refrac"], 362 | method=param_dict_settings["integ_method"], 363 | namespace=param_dict_neuron_I, 364 | name="Ipop", 365 | ) 366 | I_pop.v = ( 367 | param_dict_neuron_I["v_rest"] 368 | + np.random.randn(N_inh) * param_dict_neuron_I["v_0_offset"][1] 369 | + param_dict_neuron_I["v_0_offset"][0] 370 | ) 371 | # check for poisson input 372 | if (param_dict_neuron_I["N_poisson"] == 0) or ( 373 | param_dict_neuron_I["poisson_rate"] == 0 374 | ): 375 | poisson_input_I = [] 376 | else: 377 | poisson_input_I = b2.PoissonInput( 378 | target=I_pop, 379 | target_var="ge", 380 | N=param_dict_neuron_I["N_poisson"], 381 | rate=param_dict_neuron_I["poisson_rate"], 382 | weight=param_dict_neuron_I["Q_poisson"], 383 | ) 384 | 385 | ############## SYNAPSES #################### 386 | ### make and connect recurrent synapses ### 387 | if ( 388 | ("n_clusters" not in param_dict_net.keys()) 389 | or (param_dict_net["n_clusters"] < 2) 390 | or (param_dict_net["R_pe2e"] == 1) 391 | ): 392 | # make homogeneous connection if 0,1 (or unspecified) cluster, or R_pe2e (ratio between in:out connection prob) is 1 393 | # print('homogeneous') 394 | syn_e2e = b2.Synapses( 395 | source=E_pop, 396 | target=E_pop, 397 | model="q: siemens", 398 | on_pre="ge_post += Q_ge", 399 | delay=param_dict_neuron_E["tdelay2e"], 400 | namespace=param_dict_neuron_E, 401 | name="syn_e2e", 402 | ) 403 | syn_e2e.connect("i!=j", p=param_dict_net["p_e2e"]) 404 | 405 | else: 406 | # print(f'clustered: n_clusters={param_dict_net["n_clusters"]}') 407 | # scale connectivity probability and make clustered connections 408 | p_out = param_dict_net["p_e2e"] 409 | p_in = p_out * param_dict_net["R_pe2e"] 410 | 411 | # NOTE: if cluster membership is ordered (last arg), then it explicitly makes 412 | # block diagonal wrt neuron id, which coincides with input (first n) 413 | # otherwise, clusters id are randomly assigned 414 | # print('ordered cluster id' if param_dict_net['order_clusters'] else 'random cluster id') 415 | # 416 | # Also, given non-integer n_clusters, int() is applied, which floors it. 
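        # For example (illustrative numbers): with p_e2e = 0.05, R_pe2e = 4 and
        # R_Qe2e = 1.5, within-cluster pairs connect with p_in = 0.2 at weight
        # 1.5 * Q_ge, while across-cluster pairs keep p_out = 0.05 at weight Q_ge.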
417 | membership, shared_membership, conn_in, conn_out = make_clustered_network( 418 | N_exc, 419 | param_dict_net["n_clusters"], 420 | param_dict_net["clusters_per_neuron"], 421 | p_in, 422 | p_out, 423 | param_dict_net["order_clusters"], 424 | ) 425 | 426 | # scale synaptic weight 427 | Q_ge_out = param_dict_neuron_E["Q_ge"] 428 | param_dict_neuron_E["Q_ge_out"] = Q_ge_out 429 | param_dict_neuron_E["Q_ge_in"] = Q_ge_out * param_dict_net["R_Qe2e"] 430 | 431 | # make synapses and connect 432 | # in-cluster synapses 433 | syn_e2e_in = b2.Synapses( 434 | source=E_pop, 435 | target=E_pop, 436 | model="q: siemens", 437 | on_pre="ge_post += Q_ge_in", 438 | delay=param_dict_neuron_E["tdelay2e"], 439 | namespace=param_dict_neuron_E, 440 | name="syn_e2e_in", 441 | ) 442 | syn_e2e_in.connect(i=conn_in[0, :], j=conn_in[1, :]) 443 | 444 | # across-cluster synapses 445 | syn_e2e_out = b2.Synapses( 446 | source=E_pop, 447 | target=E_pop, 448 | model="q: siemens", 449 | on_pre="ge_post += Q_ge_out", 450 | delay=param_dict_neuron_E["tdelay2e"], 451 | namespace=param_dict_neuron_E, 452 | name="syn_e2e_out", 453 | ) 454 | syn_e2e_out.connect(i=conn_out[0, :], j=conn_out[1, :]) 455 | 456 | if has_inh: 457 | # no clustered connections to or from inhibitory populations 458 | # e to i 459 | syn_e2i = b2.Synapses( 460 | source=E_pop, 461 | target=I_pop, 462 | model="q: siemens", 463 | on_pre="ge_post += Q_ge", 464 | delay=param_dict_neuron_E["tdelay2i"], 465 | namespace=param_dict_neuron_I, 466 | name="syn_e2i", 467 | ) 468 | syn_e2i.connect("i!=j", p=param_dict_net["p_e2i"]) 469 | 470 | # i to e 471 | syn_i2e = b2.Synapses( 472 | source=I_pop, 473 | target=E_pop, 474 | model="q: siemens", 475 | on_pre="gi_post += Q_gi", 476 | delay=param_dict_neuron_I["tdelay2e"], 477 | namespace=param_dict_neuron_E, 478 | name="syn_i2e", 479 | ) 480 | syn_i2e.connect("i!=j", p=param_dict_net["p_i2e"]) 481 | 482 | # i to i 483 | syn_i2i = b2.Synapses( 484 | source=I_pop, 485 | target=I_pop, 486 | model="q: siemens", 487 | on_pre="gi_post += Q_gi", 488 | delay=param_dict_neuron_I["tdelay2i"], 489 | namespace=param_dict_neuron_I, 490 | name="syn_i2i", 491 | ) 492 | syn_i2i.connect("i!=j", p=param_dict_net["p_i2i"]) 493 | 494 | ### define monitors ### 495 | rate_monitors, spike_monitors, trace_monitors = [], [], [] 496 | rec_defs = param_dict_settings["record_defs"] 497 | dt_ts = param_dict_settings["dt_ts"] # recording interval for continuous variables 498 | zipped_pops = ( 499 | zip(["exc", "inh"], [E_pop, I_pop]) if has_inh else zip(["exc"], [E_pop]) 500 | ) 501 | for pop_name, pop in zipped_pops: 502 | if pop_name in rec_defs.keys(): 503 | if rec_defs[pop_name]["rate"] is not False: 504 | rate_monitors.append( 505 | b2.PopulationRateMonitor(pop, name=pop_name + "_rate") 506 | ) 507 | if rec_defs[pop_name]["spikes"] is not False: 508 | if pop_name == "exc": 509 | # special override for excitatory to record all first 510 | # and later drop randomly before saving, otherwise 511 | # recording only from first n neurons, which heavily overlap 512 | # with those stimulated, and the first few clusters 513 | rec_idx = np.arange(N_exc) 514 | else: 515 | rec_idx = ( 516 | np.arange(rec_defs[pop_name]["spikes"]) 517 | if type(rec_defs[pop_name]["spikes"]) is int 518 | else rec_defs[pop_name]["spikes"] 519 | ) 520 | spike_monitors.append( 521 | b2.SpikeMonitor(pop[rec_idx], name=pop_name + "_spikes") 522 | ) 523 | 524 | if rec_defs[pop_name]["trace"] is not False: 525 | rec_idx = ( 526 | np.arange(rec_defs[pop_name]["trace"][1]) 527 | if 
type(rec_defs[pop_name]["trace"][1]) is int 528 | else rec_defs[pop_name]["trace"][1] 529 | ) 530 | trace_monitors.append( 531 | b2.StateMonitor( 532 | pop, 533 | rec_defs[pop_name]["trace"][0], 534 | record=rec_idx, 535 | name=pop_name + "_trace", 536 | dt=dt_ts, 537 | ) 538 | ) 539 | 540 | ### collect into network object and return ### 541 | net_collect = b2.Network(b2.collect()) # magic collect all groups 542 | monitors = [rate_monitors, spike_monitors, trace_monitors] 543 | net_collect.add(monitors) 544 | return net_collect 545 | -------------------------------------------------------------------------------- /automind/sim/default_configs.py: -------------------------------------------------------------------------------- 1 | ### Various default configurations for the simulation and analysis. 2 | import brian2 as b2 3 | import numpy as np 4 | 5 | 6 | ADEX_NEURON_DEFAULTS_ZERLAUT = { 7 | ### parameters taken from Zerlaut et al. 2018, JCompNeurosci 8 | # these are free for each different type of population (i.e., E and I) 9 | #### NEURON PARAMS #### 10 | "C": 200 * b2.pF, # capacitance 11 | "g_L": 10 * b2.nS, # leak conductance 12 | "v_rest": -65 * b2.mV, # resting (leak) voltage 13 | "v_thresh": -50 * b2.mV, 14 | # note v_thresh for adex is not technically the spike cutoff 15 | # but where exp nonlinearity kicks in, i.e., spike initiation 16 | # spike threshold can be defined as 5*risetime from v_thresh 17 | "delta_T": 2 * b2.mV, # exponential nonlinearity scaling 18 | "a": 4 * b2.nS, 19 | "tau_w": 500 * b2.ms, 20 | "b": 20 * b2.pA, 21 | "v_reset": -65 * b2.mV, # reset voltage 22 | "t_refrac": 5 * b2.ms, 23 | "v_cut": 0 * b2.mV, # threshold voltage 24 | "v_0_offset": [ 25 | 0 * b2.mV, 26 | 4 * b2.mV, 27 | ], # starting voltage mean & variance from resting voltage 28 | #### SYNAPSE PARAMS #### 29 | "E_ge": 0 * b2.mV, 30 | "E_gi": -80 * b2.mV, 31 | "Q_ge": 1 * b2.nS, 32 | "Q_gi": 5 * b2.nS, 33 | "tau_ge": 5 * b2.ms, 34 | "tau_gi": 5 * b2.ms, 35 | "tdelay2e": 0 * b2.ms, 36 | "tdelay2i": 0 * b2.ms, 37 | #### external input #### 38 | "N_poisson": 500, 39 | "Q_poisson": 1 * b2.nS, 40 | "poisson_rate": 1 * b2.Hz, 41 | "I_bias": 0 * b2.pA, 42 | "p_igniters": 1.0, 43 | } 44 | 45 | ADEX_NET_DEFAULTS = { 46 | # cell counts 47 | "N_pop": 2000, 48 | "exc_prop": 0.8, 49 | # 2 population weight matrix 50 | "p_e2e": 0.05, 51 | "p_e2i": 0.05, 52 | "p_i2e": 0.05, 53 | "p_i2i": 0.05, 54 | # clustered connectivity scaling 55 | "n_clusters": 1, 56 | "clusters_per_neuron": 2, 57 | "R_pe2e": 1, 58 | "R_Qe2e": 1, 59 | "order_clusters": False, 60 | } 61 | 62 | SIM_SETTINGS_DEFAULTS = { 63 | ### simulation default settings 64 | "experiment": None, 65 | "network_type": "adex", 66 | "sim_time": 1.1 * b2.second, 67 | "dt": 0.2 * b2.ms, 68 | "dt_ts": 1.0 * b2.ms, 69 | "batch_seed": 0, 70 | "random_seed": 42, 71 | "record_defs": { 72 | "exc": {"rate": True, "spikes": False, "trace": (["I"], np.arange(10))} 73 | }, 74 | "save_sigdigs": 6, 75 | "t_sigdigs": 4, 76 | "integ_method": "euler", 77 | "real_run_time": 0.0, 78 | } 79 | 80 | ANALYSIS_DEFAULTS = { 81 | ### early stop settings 82 | "t_early_stop": 1.1 * b2.second, 83 | "early_stop_window": np.array([0.1, 1.1]), 84 | "early_stopped": False, 85 | "stop_fr_norm": (0.0001, 0.99), 86 | ### spike and rate summary settings 87 | "do_spikes": True, 88 | "pop_sampler": {"exc": None}, 89 | "analysis_window": np.array([0.1, None]), 90 | "smooth_std": 0.0005, 91 | "dt_poprate": 1.0 * b2.ms, 92 | "min_num_spikes": 3, 93 | ### bursts 94 | "do_bursts": True, 95 | 
"use_burst_prominence": True, 96 | "min_burst_height": 5, 97 | "min_burst_height_ratio": 0.5, 98 | "min_burst_distance": 1.0, 99 | "burst_win": [-0.5, 2.5], 100 | "burst_wlen": 10, 101 | "burst_rel_height": 0.95, 102 | ### PSD settings 103 | "do_psd": False, 104 | "nperseg_ratio": 0.5, 105 | "noverlap_ratio": 0.75, 106 | "f_lim": 500, 107 | ### PCA settings 108 | "do_pca": False, 109 | "n_pcs": 100, 110 | "pca_bin_width": 10.0 * b2.ms, 111 | "pca_smooth_std": 50.0 * b2.ms, 112 | } 113 | 114 | MKI_pretty_param_names = [ 115 | "$\% E:I$", 116 | r"$p_{E\rightarrow E}$", 117 | r"$p_{E\rightarrow I}$", 118 | r"$p_{I\rightarrow E}$", 119 | r"$p_{I\rightarrow I}$", 120 | r"$R_{p_{E-clus}}$", 121 | r"$R_{Q_{E-clus}}$", 122 | r"$\% input_E$", 123 | r"$\# clus$", 124 | r"C", 125 | r"$g_L$", 126 | r"$V_{rest}$", 127 | r"$V_{thresh}$", 128 | r"$V_{reset}$", 129 | r"$t_{refrac}$", 130 | r"$\Delta T$", 131 | r"$g_{adpt}$ (a)", 132 | r"$Q_{adpt}$ (b)", 133 | r"$\tau_{adpt}$", 134 | r"$E_{I\rightarrow E}$", 135 | r"$Q_{E\rightarrow E}$", 136 | r"$Q_{I\rightarrow E}$", 137 | r"$\tau_{E\rightarrow E}$", 138 | r"$\tau_{I\rightarrow E}$", 139 | r"$\nu_{input\rightarrow E}$", 140 | r"$Q_{E\rightarrow I}$", 141 | r"$Q_{I\rightarrow I}$", 142 | r"$\nu_{input\rightarrow I}$", 143 | ] 144 | 145 | MKI_3col_plot_order = np.array( 146 | [ 147 | 0, 148 | 1, 149 | 2, 150 | 3, 151 | 4, 152 | 5, 153 | 6, 154 | 7, 155 | 8, 156 | -1, 157 | 9, 158 | 10, 159 | 11, 160 | 12, 161 | 13, 162 | 14, 163 | 15, 164 | 16, 165 | 17, 166 | 18, 167 | 20, 168 | 21, 169 | 25, 170 | 26, 171 | 19, 172 | 22, 173 | 23, 174 | 24, 175 | 27, 176 | -1, 177 | ] 178 | ) 179 | -------------------------------------------------------------------------------- /automind/sim/runners.py: -------------------------------------------------------------------------------- 1 | ### This file contains the main simulation runner scripts and helpers. 2 | import numpy as np 3 | import brian2 as b2 4 | import pandas as pd 5 | from time import time 6 | import sys 7 | 8 | from . import b2_models, b2_interface 9 | from ..analysis import spikes_summary 10 | from ..utils import data_utils, dist_utils 11 | from ..sim import b2_interface, default_configs 12 | 13 | 14 | def adex_check_early_stop(net_collect, params_dict, verbose=False): 15 | """ 16 | Check if simulation should be stopped based on firing rate. 
17 | """ 18 | params_analysis = params_dict["params_analysis"] 19 | params_settings = params_dict["params_settings"] 20 | t_refrac = np.array( 21 | params_dict["params_Epop"]["t_refrac"] 22 | ) # get refractory period for norming FR 23 | 24 | # compute FR stats 25 | query_spikes = { 26 | "exc_spikes": b2_interface._deunitize_spiketimes( 27 | net_collect["exc_spikes"].spike_trains() 28 | ) 29 | } 30 | 31 | df_summary = spikes_summary.return_df_summary( 32 | query_spikes, 33 | params_analysis["analysis_window"], 34 | params_analysis["min_num_spikes"], 35 | ) 36 | 37 | fr_norm = ( 38 | t_refrac 39 | * df_summary["isi_numspks"].mean(skipna=True) 40 | / np.diff(params_analysis["early_stop_window"]) 41 | ) 42 | 43 | # compare 44 | if (fr_norm <= params_analysis["stop_fr_norm"][0]) or ( 45 | fr_norm >= params_analysis["stop_fr_norm"][1] 46 | ): 47 | if verbose: 48 | print("early stop: normed FR = %.4f%%" % (fr_norm * 100)) 49 | params_analysis["early_stopped"] = True 50 | params_settings["sim_time"] = params_analysis["t_early_stop"] 51 | else: 52 | if verbose: 53 | print("continuing...") 54 | 55 | return params_dict 56 | 57 | 58 | def run_net_early_stop(net_collect, params_dict): 59 | """ 60 | Run simulation while checking for early stoppage. 61 | """ 62 | # ---- early stoppage ---- 63 | start_time = time() 64 | if params_dict["params_analysis"]["t_early_stop"]: 65 | # run for short time and check for min/max firing 66 | net_collect.run(params_dict["params_analysis"]["t_early_stop"]) 67 | params_dict = adex_check_early_stop(net_collect, params_dict, verbose=False) 68 | # did not stop, continue with sim the rest of the way 69 | if not params_dict["params_analysis"]["early_stopped"]: 70 | net_collect.run( 71 | params_dict["params_settings"]["sim_time"] 72 | - params_dict["params_analysis"]["t_early_stop"] 73 | ) 74 | else: 75 | # no early stopping just go 76 | net_collect.run(params_dict["params_settings"]["sim_time"]) 77 | 78 | # update params 79 | params_dict["params_settings"]["real_run_time"] = time() - start_time 80 | return params_dict, net_collect 81 | 82 | 83 | def adex_simulator(params_dict): 84 | """ 85 | Simulator wrapper for AdEx net, run simulations and collect data. 
86 | """ 87 | 88 | print( 89 | f"{params_dict['params_settings']['batch_seed']}-{params_dict['params_settings']['random_seed']}", 90 | end="|", 91 | ) 92 | 93 | try: 94 | # set up and run model with early stopping 95 | network_type = params_dict["params_settings"]["network_type"] 96 | if (not network_type) or (network_type == "adex"): 97 | net_collect = b2_models.adaptive_exp_net(params_dict) 98 | elif network_type == "adex_clustered": 99 | net_collect = b2_models.adaptive_exp_net_clustered(params_dict) 100 | 101 | # run the model 102 | params_dict, net_collect = run_net_early_stop(net_collect, params_dict) 103 | 104 | # return pickleable outputs for pool 105 | spikes, timeseries = data_utils.collect_raw_data(net_collect, params_dict) 106 | 107 | return params_dict, spikes, timeseries 108 | 109 | except Exception as e: 110 | print("-----") 111 | print(e) 112 | print( 113 | f"{params_dict['params_settings']['batch_seed']}-{params_dict['params_settings']['random_seed']}: FAILED" 114 | ) 115 | print("-----") 116 | return params_dict, {}, {} 117 | 118 | 119 | ######## 120 | ## experiment setup stuff 121 | ######## 122 | def construct_experiment_settings_adex(update_dict=None): 123 | """Construct and fill experiment settings for AdEx model.""" 124 | #### grab default configs and fill in 125 | params_Epop = default_configs.ADEX_NEURON_DEFAULTS_ZERLAUT.copy() 126 | params_Ipop = default_configs.ADEX_NEURON_DEFAULTS_ZERLAUT.copy() 127 | params_net = default_configs.ADEX_NET_DEFAULTS.copy() 128 | params_analysis = default_configs.ANALYSIS_DEFAULTS.copy() 129 | params_settings = default_configs.SIM_SETTINGS_DEFAULTS.copy() 130 | 131 | # network configurations & set up simulation time and clock 132 | params_settings["t_sigs"] = int( 133 | np.ceil(-np.log10(params_settings["dt"] / b2.second)) 134 | ) 135 | 136 | # collect 137 | params_dict = { 138 | "params_net": params_net, 139 | "params_Epop": params_Epop, 140 | "params_Ipop": params_Ipop, 141 | "params_analysis": params_analysis, 142 | "params_settings": params_settings, 143 | } 144 | 145 | # update non-default values 146 | if update_dict: 147 | params_dict = data_utils.update_params_dict(params_dict, update_dict) 148 | 149 | # set some final configurations 150 | b2_interface.set_adaptive_vcut(params_dict["params_Epop"]) 151 | b2_interface.set_adaptive_vcut(params_dict["params_Ipop"]) 152 | 153 | return params_dict 154 | 155 | 156 | def set_up_for_presampled( 157 | cfg, hydra_path, exp_config, construct_experiment_settings_fn 158 | ): 159 | """Set up for experiments where parameter samples are pre-generated and saved in csv. 160 | This is the new workflow where prior or posterior samples are generated in advance and saved in csv files. 
161 | """ 162 | # make path dict 163 | path_dict = data_utils.make_subfolders(hydra_path) 164 | 165 | # Extract relevant file names 166 | file_type_list = ["prior.pickle", "_samples.csv"] 167 | data_file_dict = data_utils.extract_data_files(path_dict["data"], file_type_list) 168 | 169 | # Load samples 170 | df_prior_samples = pd.read_csv( 171 | path_dict["data"] + data_file_dict["_samples"], index_col=0 172 | ) 173 | batch_seed = int(df_prior_samples.iloc[0]["params_settings.batch_seed"]) 174 | 175 | # Check if there is already a ran file, if so, we continue from the end 176 | samples_ran_file_dict = data_utils.extract_data_files( 177 | path_dict["data"], ["samples_ran.csv"] 178 | ) 179 | if samples_ran_file_dict["samples_ran"]: 180 | # Pick up from where we left off, load the stuff 181 | print( 182 | f"Found {samples_ran_file_dict['samples_ran']}, continuing from where we left off..." 183 | ) 184 | 185 | # Extract params_dict_default 186 | params_file_dict = data_utils.extract_data_files( 187 | path_dict["data"], ["params_dict_default.pickle"] 188 | ) 189 | 190 | # Load params_dict_default 191 | params_dict_default = data_utils.load_pickled( 192 | path_dict["data"] + params_file_dict["params_dict_default"] 193 | ) 194 | 195 | else: 196 | # Set up params_dict from scratch 197 | exp_settings_update = exp_config.exp_settings_update 198 | if cfg["infrastructure"]["flag_test"]: 199 | exp_settings_update["params_settings.sim_time"] = 20 * b2.second 200 | 201 | if cfg["experiment"]["sim_time"]: 202 | print( 203 | f'Simulation duration updated: {cfg["experiment"]["sim_time"]} seconds.' 204 | ) 205 | exp_settings_update["params_settings.sim_time"] = ( 206 | cfg["experiment"]["sim_time"] * b2.second 207 | ) 208 | 209 | # construct params and priors 210 | params_dict_default = construct_experiment_settings_fn(exp_settings_update) 211 | params_dict_default["params_settings"]["experiment"] = cfg["experiment"][ 212 | "network_name" 213 | ] 214 | _, params_dict_default = data_utils.set_seed_by_time(params_dict_default) 215 | 216 | # Put path dict in params_dict 217 | params_dict_default["path_dict"] = path_dict 218 | 219 | # Set batch seed 220 | params_dict_default["params_settings"]["batch_seed"] = batch_seed 221 | 222 | # Save params_dict_default 223 | data_utils.pickle_file( 224 | params_dict_default["path_dict"]["data"] 225 | + f"{batch_seed}_params_dict_default.pickle", 226 | params_dict_default, 227 | ) 228 | 229 | # Chop off the ones that have already been ran 230 | if samples_ran_file_dict["samples_ran"]: 231 | df_prior_samples_ran = pd.read_csv( 232 | path_dict["data"] + samples_ran_file_dict["samples_ran"], index_col=0 233 | ) 234 | # Check that the samples that ran are consistent 235 | if df_prior_samples.iloc[: len(df_prior_samples_ran)][ 236 | df_prior_samples.columns[:4] 237 | ].equals(df_prior_samples_ran[df_prior_samples.columns[:4]]): 238 | print( 239 | "Samples ran are consistent with the samples in the csv file. Continuing from where we left off..." 240 | ) 241 | print(f"Samples ran: {len(df_prior_samples_ran)}") 242 | print(f"Samples total: {len(df_prior_samples)}") 243 | df_prior_samples = df_prior_samples.iloc[len(df_prior_samples_ran) :] 244 | print( 245 | f"Samples left to run: {len(df_prior_samples)}, from {df_prior_samples.index[0]} to {df_prior_samples.index[-1]}" 246 | ) 247 | else: 248 | print( 249 | "Samples ran are inconsistent with the samples in the csv file. Exiting..." 
250 | ) 251 | sys.exit() 252 | 253 | prior = data_utils.load_pickled(path_dict["data"] + data_file_dict["prior"]) 254 | n_samples = len(df_prior_samples) 255 | 256 | # Set seeds 257 | data_utils.set_all_seeds(batch_seed) 258 | 259 | # Plug samples into list of param dictionaries for simulation 260 | params_dict_list = data_utils.fill_params_dict( 261 | params_dict_default, df_prior_samples, prior.as_dict, n_samples 262 | ) 263 | return prior, df_prior_samples, params_dict_default, params_dict_list 264 | 265 | 266 | def set_up_from_hydra(cfg, hydra_path, exp_config, construct_experiment_settings_fn): 267 | """Set up for experiments where parameter samples are generated on the fly.""" 268 | exp_settings_update = exp_config.exp_settings_update 269 | if cfg["infrastructure"]["flag_test"]: 270 | exp_settings_update["params_settings.sim_time"] = 20 * b2.second 271 | 272 | if cfg["experiment"]["sim_time"]: 273 | print(f'Simulation duration updated: {cfg["experiment"]["sim_time"]} seconds.') 274 | exp_settings_update["params_settings.sim_time"] = ( 275 | cfg["experiment"]["sim_time"] * b2.second 276 | ) 277 | 278 | # construct params and priors 279 | params_dict_default = construct_experiment_settings_fn(exp_settings_update) 280 | params_dict_default["params_settings"]["experiment"] = cfg["experiment"]["name"] 281 | 282 | # option to manually set seed 283 | if cfg["experiment"]["batch_seed"]: 284 | batch_seed = cfg["experiment"]["batch_seed"] 285 | params_dict_default["params_settings"]["batch_seed"] = batch_seed 286 | else: 287 | batch_seed, params_dict_default = data_utils.set_seed_by_time( 288 | params_dict_default 289 | ) 290 | 291 | # make path dict 292 | path_dict = data_utils.make_subfolders(hydra_path) 293 | params_dict_default["path_dict"] = path_dict 294 | 295 | proposal_path = cfg["experiment"]["proposal_path"] 296 | # set seeds 297 | data_utils.set_all_seeds(batch_seed) 298 | 299 | ####--------------- 300 | if "multiround" in cfg["experiment"]["name"]: 301 | if proposal_path == "none": 302 | print("Multi-round requires density estimator / proposal. Exiting...") 303 | sys.exit() 304 | 305 | # load proposal from previously pickled file 306 | print(proposal_path) 307 | prior = data_utils.load_pickled(proposal_path) 308 | 309 | # load xo queries and data 310 | n_samples = int(cfg["experiment"]["n_samples"]) 311 | df_prior_samples = dist_utils.sample_proposal(prior, n_samples) 312 | batch, run = (cfg["experiment"]["xo_batch"], cfg["experiment"]["xo_run"]) 313 | round_number = int(proposal_path.split("posterior_R")[1][0]) 314 | df_prior_samples.insert(0, "x_o", f"{batch}_{run}") 315 | df_prior_samples.insert(2, "round", round_number) 316 | 317 | # specific setup depending on round1 or round2 318 | elif "round2" in cfg["experiment"]["name"]: 319 | # round 2, get density estimator and draw from posterior 320 | if proposal_path == "none": 321 | print("Round 2 requires density estimator / proposal. 
Exiting...") 322 | sys.exit() 323 | 324 | # load proposal from previously pickled file 325 | # technically, this is the unconditioned density estimator, but since 326 | # in round2 it's likely that we're running multiple observations in parallel 327 | # just save the density estimator as the "proposal" downstream 328 | print(proposal_path) 329 | prior = data_utils.load_pickled(proposal_path) 330 | 331 | # load xo queries and data 332 | n_samples_per = int(cfg["experiment"]["n_samples_per"]) 333 | 334 | xo_queries_database = exp_config.xo_queries_database 335 | xo_queries = xo_queries_database[cfg["experiment"]["xo_type"]] 336 | n_samples = len(xo_queries) * n_samples_per 337 | df_xos = pd.read_csv(cfg["experiment"]["xo_path"], index_col=0) 338 | 339 | # fill all the burst nans when there is valid isi stats 340 | # df_xos = data_utils.fill_nans_in_xos(df_xos, "isi_", 0) 341 | 342 | # condition and sample from posterior 343 | df_prior_samples = dist_utils.sample_different_xos_from_posterior( 344 | prior, df_xos, xo_queries, n_samples_per 345 | ) 346 | 347 | else: 348 | # assume its round1, restricted, or default 349 | # make or load prior 350 | if proposal_path == "none": 351 | # no proposal distribution object, make new one 352 | variable_params = exp_config.variable_params 353 | prior = dist_utils.CustomIndependentJoint(variable_params) 354 | else: 355 | # load proposal from previously pickled file 356 | print(proposal_path) 357 | prior = data_utils.load_pickled(proposal_path) 358 | 359 | # draw from prior 360 | # get sample size 361 | n_samples = cfg["infrastructure"]["n_samples"] 362 | 363 | # sample prior 364 | samples = prior.sample((n_samples,)).numpy().astype(float) 365 | df_prior_samples = pd.DataFrame(samples, columns=prior.names) 366 | # 367 | ####------------------- 368 | 369 | # sort out seed types and make into dataframe 370 | df_prior_samples.insert( 371 | loc=0, column="params_settings.batch_seed", value=batch_seed 372 | ) 373 | 374 | random_seeds = np.random.choice( 375 | a=int(n_samples * 100), size=n_samples, replace=False 376 | ) 377 | if "round2" in cfg["experiment"]["name"]: 378 | # sort seeds so it doesn't scramble the round 2 simulations 379 | random_seeds = np.sort(random_seeds) 380 | 381 | df_prior_samples.insert( 382 | loc=1, 383 | column="params_settings.random_seed", 384 | value=random_seeds, 385 | ) 386 | df_prior_samples = df_prior_samples.sort_values( 387 | "params_settings.random_seed", ignore_index=True 388 | ) 389 | 390 | # save prior, samples, and default configs 391 | df_prior_samples.to_csv( 392 | params_dict_default["path_dict"]["data"] + f"{batch_seed}_prior_samples.csv" 393 | ) 394 | data_utils.save_params_priors( 395 | params_dict_default["path_dict"]["data"] + f"{batch_seed}_", 396 | params_dict_default, 397 | prior, 398 | ) 399 | 400 | # plug samples into list of param dictionaries for simulation 401 | params_dict_list = data_utils.fill_params_dict( 402 | params_dict_default, df_prior_samples, prior.as_dict, n_samples 403 | ) 404 | 405 | return prior, df_prior_samples, params_dict_default, params_dict_list 406 | -------------------------------------------------------------------------------- /automind/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mackelab/automind/e5e3cd999b39bf2b9bd365e43c6d379d245d3e77/automind/utils/__init__.py -------------------------------------------------------------------------------- /automind/utils/analysis_utils.py: 
-------------------------------------------------------------------------------- 1 | ### Utility functions for computing summary features from simulations data. 2 | import numpy as np 3 | import pandas as pd 4 | from scipy.stats import pearsonr, spearmanr 5 | from brian2 import second 6 | from ..analysis import spikes_summary 7 | from . import data_utils 8 | 9 | 10 | def compute_spike_features_only(spikes, params_dict): 11 | """Compute single unit ISI features. 12 | 13 | Args: 14 | spikes (dict): Dictionary of spike times. 15 | params_dict (dict): Dictionary containing analysis configurations. 16 | 17 | Returns: 18 | DataFrame: Dataframe of single unit ISI features. 19 | """ 20 | # summary function for just SU features, usually for training restricted prior 21 | if "t_end" in spikes.keys(): 22 | # this has to be popped because downstream analysis looks at all keys as population definitions 23 | t_end = spikes.pop("t_end") 24 | # compute single unit isi features 25 | df_spikes_SU = spikes_summary.return_df_summary( 26 | spikes, 27 | params_dict["params_analysis"]["analysis_window"], 28 | params_dict["params_analysis"]["min_num_spikes"], 29 | ) 30 | df_spikes = spikes_summary.get_population_average( 31 | df_spikes_SU, params_dict["params_analysis"]["pop_sampler"] 32 | ) 33 | return df_spikes 34 | 35 | 36 | def compute_burst_features(rate, params_dict): 37 | """Compute burst features from population rate. 38 | 39 | Args: 40 | rate (array): Population rate. 41 | params_dict (dict): Dictionary containing analysis configurations. 42 | 43 | Returns: 44 | tuple: (DataFrame, dict) Dataframe of burst features and burst statistics. 45 | """ 46 | # get burst event statistics 47 | burst_analysis_params = { 48 | k: v for k, v in params_dict["params_analysis"].items() if "burst_" in k 49 | } 50 | burst_stats = spikes_summary.get_popburst_peaks( 51 | rate, 52 | fs=second / params_dict["params_analysis"]["dt_poprate"], 53 | **burst_analysis_params, 54 | ) 55 | # get subpeak features if required 56 | if "min_subpeak_prom_ratio" in params_dict["params_analysis"]: 57 | burst_stats = spikes_summary.get_burst_subpeaks( 58 | rate, 59 | second / params_dict["params_analysis"]["dt_poprate"], 60 | burst_stats, 61 | params_dict["params_analysis"]["min_subpeak_prom_ratio"], 62 | params_dict["params_analysis"]["min_subpeak_distance"], 63 | ) 64 | # get stats in dataframe 65 | if len(burst_stats["burst_kernels"]) > 1: 66 | # compute burst features 67 | df_bursts = spikes_summary.compute_burst_summaries(burst_stats) 68 | else: 69 | # return empty df row otherwise it screws up the final dataframe 70 | col_names = [ 71 | "burst_num", 72 | "burst_interval_mean", 73 | "burst_interval_std", 74 | "burst_interval_cv", 75 | "burst_peak_fr_mean", 76 | "burst_peak_fr_std", 77 | "burst_width_mean", 78 | "burst_width_std", 79 | "burst_onset_time_mean", 80 | "burst_onset_time_std", 81 | "burst_offset_time_mean", 82 | "burst_offset_time_std", 83 | "burst_corr_mean", 84 | "burst_corr_std", 85 | "burst_corr_interval2nextpeak", 86 | "burst_corr_interval2prevpeak", 87 | "burst_numsubpeaks_mean", 88 | "burst_numsubpeaks_std", 89 | "burst_mean_fr_mean", 90 | "burst_mean_fr_std", 91 | ] 92 | df_bursts = pd.DataFrame(columns=col_names) 93 | 94 | return df_bursts, burst_stats 95 | 96 | 97 | def compute_spike_burst_features(spikes, params_dict): 98 | """Compute single unit ISI and burst features.""" 99 | if "t_end" in spikes.keys(): 100 | # this has to be popped because downstream analysis looks at all keys as population definitions 101 | 
t_end = spikes.pop("t_end") 102 | else: 103 | print("No end time in spikes_dict, defaulting to 200.1s.") 104 | t_end = 200.1 105 | 106 | # compute single unit isi features 107 | df_spikes_SU = spikes_summary.return_df_summary( 108 | spikes, 109 | params_dict["params_analysis"]["analysis_window"], 110 | params_dict["params_analysis"]["min_num_spikes"], 111 | ) 112 | df_spikes = spikes_summary.get_population_average( 113 | df_spikes_SU, params_dict["params_analysis"]["pop_sampler"] 114 | ) 115 | 116 | # compute population rate 117 | pop_rates = spikes_summary.compute_poprate_from_spikes( 118 | spikes, 119 | t_collect=(0, t_end), 120 | dt=params_dict["params_analysis"]["dt_poprate"] / second, 121 | dt_bin=params_dict["params_settings"]["dt"] / second, 122 | pop_sampler=params_dict["params_analysis"]["pop_sampler"], 123 | smooth_std=params_dict["params_analysis"]["smooth_std"], 124 | ) 125 | pop_rates["avgpop_rate"] = spikes_summary.compute_average_pop_rates( 126 | pop_rates, params_dict["params_analysis"]["pop_sampler"] 127 | ) 128 | # get burst event statistics 129 | burst_analysis_params = { 130 | k: v for k, v in params_dict["params_analysis"].items() if "burst_" in k 131 | } 132 | burst_stats = spikes_summary.get_popburst_peaks( 133 | pop_rates["avgpop_rate"], 134 | fs=second / params_dict["params_analysis"]["dt_poprate"], 135 | **burst_analysis_params, 136 | ) 137 | 138 | if "min_subpeak_prom_ratio" in params_dict["params_analysis"]: 139 | burst_stats = spikes_summary.get_burst_subpeaks( 140 | pop_rates["avgpop_rate"], 141 | second / params_dict["params_analysis"]["dt_poprate"], 142 | burst_stats, 143 | params_dict["params_analysis"]["min_subpeak_prom_ratio"], 144 | params_dict["params_analysis"]["min_subpeak_distance"], 145 | ) 146 | 147 | if len(burst_stats["burst_kernels"]) > 1: 148 | # compute burst features 149 | df_bursts = spikes_summary.compute_burst_summaries(burst_stats) 150 | 151 | # merge spike and network features 152 | return pd.concat((df_spikes, df_bursts), axis=1), pop_rates, burst_stats 153 | else: 154 | return df_spikes, pop_rates, burst_stats 155 | 156 | 157 | def compute_summary_features(spikes, params_dict): 158 | """Compute all summary features from simulation data.""" 159 | if "t_end" in spikes.keys(): 160 | t_end = spikes["t_end"] 161 | else: 162 | t_end = 200.1 163 | 164 | result_collector = {} 165 | 166 | #### compute single unit isi features 167 | if params_dict["params_analysis"]["do_spikes"]: 168 | df_spikes_SU = spikes_summary.return_df_summary( 169 | # get just spikes in popsampler 170 | { 171 | k: v 172 | for k, v in spikes.items() 173 | if k.split("_")[0] 174 | in params_dict["params_analysis"]["pop_sampler"].keys() 175 | }, 176 | params_dict["params_analysis"]["analysis_window"], 177 | params_dict["params_analysis"]["min_num_spikes"], 178 | ) 179 | df_spikes = spikes_summary.get_population_average( 180 | df_spikes_SU, params_dict["params_analysis"]["pop_sampler"] 181 | ) 182 | result_collector["summary_spikes"] = df_spikes 183 | 184 | # compute unsmoothed total population rate 185 | dt = params_dict["params_analysis"]["dt_poprate"] / second 186 | pop_rates_raw = spikes_summary.compute_poprate_from_spikes( 187 | spikes, 188 | t_collect=(0, t_end), 189 | dt=dt, 190 | dt_bin=dt, 191 | # pop_sampler=params_dict["params_analysis"]["pop_sampler"], 192 | pop_sampler={k.split("_")[0]: None for k in spikes.keys() if k != "t_end"}, 193 | ) 194 | pop_rates_raw["avgpop_rate"] = spikes_summary.compute_average_pop_rates( 195 | pop_rates_raw, 
params_dict["params_analysis"]["pop_sampler"] 196 | ) 197 | 198 | ##### get PSD 199 | if params_dict["params_analysis"]["do_psd"]: 200 | df_psd = spikes_summary.compute_psd( 201 | pop_rates_raw, params_dict["params_analysis"] 202 | ) 203 | result_collector["summary_psd"] = df_psd 204 | 205 | ###### get burst features 206 | if params_dict["params_analysis"]["do_bursts"]: 207 | # smooth population rates 208 | pop_rates_smo = pop_rates_raw.copy() 209 | for pop in params_dict["params_analysis"]["pop_sampler"].keys(): 210 | pop_rates_smo[pop + "_rate"] = spikes_summary.smooth_with_gaussian( 211 | pop_rates_raw[pop + "_rate"], 212 | dt, 213 | params_dict["params_analysis"]["smooth_std"], 214 | ) 215 | 216 | pop_rates_smo["avgpop_rate"] = spikes_summary.compute_average_pop_rates( 217 | pop_rates_smo, params_dict["params_analysis"]["pop_sampler"] 218 | ) 219 | df_burst, burst_stats = compute_burst_features( 220 | pop_rates_smo["avgpop_rate"], params_dict 221 | ) 222 | 223 | result_collector["summary_bursts"] = df_burst 224 | result_collector["summary_burst_stats"] = burst_stats 225 | result_collector["pop_rates"] = pop_rates_smo 226 | 227 | ##### get PCA 228 | if params_dict["params_analysis"]["do_pca"]: 229 | df_pca = compute_pca_features(spikes, (0, t_end), params_dict) 230 | result_collector["summary_pca"] = df_pca 231 | return result_collector 232 | 233 | 234 | def compute_correlations(np_array, method="pearson"): 235 | """Compute correlation coefficients and p-values for all pairs of variables in a numpy array.""" 236 | # Get the number of variables (columns) 237 | n_vars = np_array.shape[1] 238 | 239 | # Initialize matrices to store correlation coefficients and p-values 240 | corr_matrix = np.zeros((n_vars, n_vars)) 241 | p_values = np.ones((n_vars, n_vars)) 242 | 243 | # Compute correlation coefficient and p-value for each pair of variables 244 | for i in range(n_vars): 245 | for j in range(i + 1, n_vars): 246 | if method == "pearson": 247 | corr, pval = pearsonr(np_array[:, i], np_array[:, j]) 248 | elif method == "spearman": 249 | corr, pval = spearmanr(np_array[:, i], np_array[:, j]) 250 | 251 | corr_matrix[i, j] = corr_matrix[j, i] = corr 252 | p_values[i, j] = p_values[j, i] = pval 253 | 254 | return corr_matrix, p_values 255 | 256 | 257 | def compute_participation(df_pca, up_to_npcs=10, norm_by=None): 258 | """Compute participation ratio for a given number of PCs.""" 259 | # sum squared divided by sum of squares 260 | if norm_by is None: 261 | norm_by = up_to_npcs 262 | eigvals = df_pca.values[:, :up_to_npcs] 263 | pr = (eigvals.sum(1)) ** 2 / (eigvals**2).sum(1) / norm_by 264 | return pr 265 | 266 | 267 | def check_ei_param_order(param_names): 268 | assert param_names.index("params_Epop.Q_ge") == 20 269 | assert param_names.index("params_Epop.tau_ge") == 22 270 | assert param_names.index("params_Epop.v_rest") == 11 271 | assert param_names.index("params_Epop.poisson_rate") == 24 272 | assert param_names.index("params_Epop.p_igniters") == 7 273 | 274 | assert param_names.index("params_Epop.Q_gi") == 21 275 | assert param_names.index("params_Epop.tau_gi") == 23 276 | assert param_names.index("params_Epop.E_gi") == 19 277 | assert param_names.index("params_Ipop.poisson_rate") == 27 278 | 279 | assert param_names.index("params_net.exc_prop") == 0 280 | assert param_names.index("params_net.p_e2e") == 1 281 | assert param_names.index("params_net.p_i2e") == 3 282 | assert param_names.index("params_net.R_pe2e") == 5 283 | assert param_names.index("params_net.R_Qe2e") == 6 284 | assert 
param_names.index("params_net.n_clusters") == 8 285 | 286 | 287 | def check_tau_param_order(param_names): 288 | assert param_names.index("params_Epop.C") == 9 289 | assert param_names.index("params_Epop.g_L") == 10 290 | 291 | 292 | def check_v_param_order(param_names): 293 | assert param_names.index("params_Epop.v_rest") == 11 294 | assert param_names.index("params_Epop.v_thresh") == 12 295 | assert param_names.index("params_Epop.v_reset") == 13 296 | assert param_names.index("params_Epop.E_gi") == 19 297 | 298 | 299 | def compute_composite_tau(samples, param_names=None): 300 | """Computes single neuron membrane time constant in ms, i.e., tau = C/g_L""" 301 | if param_names is not None: 302 | check_tau_param_order(param_names) 303 | C, g_L = np.hsplit(samples[:, [9, 10]], 2) 304 | return C / g_L 305 | 306 | 307 | def compute_composite_v_diff(samples, param_names=None): 308 | """Compute voltage difference between various and resting. 309 | v_thresh - v_rest, v_reset - v_rest, E_gi - v_rest 310 | """ 311 | if param_names is not None: 312 | check_v_param_order(param_names) 313 | v_rest, v_thresh, v_reset, E_gi = np.hsplit(samples[:, [11, 12, 13, 19]], 4) 314 | return np.hstack([v_thresh - v_rest, v_reset - v_rest, E_gi - v_rest]) 315 | 316 | 317 | def compute_composite_ei(samples, param_names=None, make_clus_adjust=False): 318 | """quant, count, total rec, input""" 319 | # Two current issues with this: 320 | # 1. what to do when E_ie is above the resting potential, i.e., no inhibition 321 | # another problem: cannot use v_rest for Vavg 322 | # maybe check here? https://neuronaldynamics.epfl.ch/online/Ch13.S6.html 323 | 324 | # 2. DONE how to account for the cluster amplification 325 | # (plus the variable checking) 326 | 327 | if param_names is not None: 328 | check_ei_param_order(param_names) 329 | N_neurons = 2000 # constant number of neurons per network 330 | N_poissons = 500 # hardcoded number of poisson inputs 331 | Q_poisson = 1.0 # hardcoded poisson synaptic conductance 332 | ### E parameters 333 | Q_ge, tau_ge, V_rest, poissonE_rate, p_igniters = np.hsplit( 334 | samples[:, [20, 22, 11, 24, 7]], 5 335 | ) 336 | E_ge = 0 # hardcoded reversal potential 337 | 338 | ### I parameters 339 | Q_gi, tau_gi, E_gi, poissonI_rate = np.hsplit(samples[:, [21, 23, 19, 27]], 4) 340 | 341 | ### network parameters 342 | exc_prop, p_e2e, p_i2e, R_pe, R_Qe, n_clusters = np.hsplit( 343 | samples[:, [0, 1, 3, 5, 6, 8]], 6 344 | ) 345 | clus_per_neuron = 2 # hardcoded, each neuron belongs to 2 clusters 346 | 347 | ### NOT SURE IF V_rest is the right thing to use here, in general NOT Vavg 348 | 349 | # E synapse quantile charge 350 | exc_quant = Q_ge * tau_ge # * (E_ge - V_rest) 351 | # E recurrent synapse count per neuron 352 | exc_count = p_e2e * exc_prop * N_neurons 353 | # cluster amplification of conductance 354 | exc_quant_clus = exc_quant * R_Qe 355 | # cluster amplification of intra-cluster connectivity 356 | exc_count_clus = ( 357 | (exc_count / np.floor(np.maximum(n_clusters, 1.0))) 358 | * clus_per_neuron 359 | * (R_pe - 1.0) 360 | ) 361 | 362 | # total recurrent E charge: baseline + intra-cluster 363 | exc_rec_total = exc_quant * exc_count 364 | exc_rec_clus = exc_quant_clus * exc_count_clus 365 | if make_clus_adjust: 366 | # Adjust for the clustering 367 | exc_count += exc_count_clus 368 | exc_rec_total += exc_rec_clus 369 | exc_quant = exc_rec_total / exc_count 370 | 371 | # external input to E charge 372 | exc_ext = ( 373 | poissonE_rate * p_igniters * N_poissons * (Q_poisson * tau_ge * (E_ge 
- V_rest)) 374 | ) 375 | 376 | # I synapse quantile charge 377 | inh_quant = Q_gi * tau_gi # * (-(E_gi - V_rest)) 378 | 379 | # print((inh_quant<0).sum(), len(inh_quant)) 380 | 381 | # I recurrent synapse count 382 | inh_count = p_i2e * (1.0 - exc_prop) * N_neurons 383 | # total recurrent I charge 384 | inh_rec_total = inh_quant * inh_count 385 | # external input to I charge 386 | inh_ext = poissonI_rate * N_poissons * (Q_poisson * tau_ge * (E_ge - V_rest)) 387 | 388 | return np.hstack([exc_quant, exc_count, exc_rec_total, exc_ext]), np.hstack( 389 | [inh_quant, inh_count, inh_rec_total, inh_ext] 390 | ) 391 | 392 | 393 | def discard_by_avg_power( 394 | df_summary, f_band=(2, 10), power_thresh=[1e-10, None], return_idx=False 395 | ): 396 | """Discard simulations based on average power in a given frequency range.""" 397 | # discard by average power in a given frequency range 398 | cols_psd = data_utils.subselect_features(df_summary, ["psd"]) 399 | f_axis = data_utils.decode_df_float_axis(cols_psd) 400 | f_sel_idx = (f_axis >= f_band[0]) & (f_axis <= f_band[1]) 401 | avg_power = df_summary[cols_psd].iloc[:, f_sel_idx].mean(1) 402 | if (power_thresh[0] == "None") or (power_thresh[0] is None): 403 | power_thresh[0] = -np.inf 404 | if (power_thresh[1] == "None") or (power_thresh[1] is None): 405 | power_thresh[1] = np.inf 406 | idx_keep = (avg_power > power_thresh[0]) & (avg_power < power_thresh[1]) 407 | # Return good indices if requested, otherwise the dataframe 408 | if return_idx: 409 | return idx_keep 410 | else: 411 | return df_summary[idx_keep] 412 | 413 | 414 | def manual_filter_logPSD( 415 | log_psd, 416 | f_axis, 417 | f_bounds=[1, 490], 418 | bounds_hyperactive=(6.5, 13), 419 | bounds_silent=(7e-4, 0.14), 420 | ): 421 | """ 422 | Filter out bad simulations based on per-sample PSD statistics. 423 | Bounds are hard-coded, not ideal but gets the job done. 424 | Returns indices of good ones. 
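    Example (an illustrative sketch mirroring the call in dist_utils.preproc_dataframe_psd;
    df_summary and cols_psd are assumed to exist):
        log_psd = np.log10(df_summary[cols_psd].values)
        f_axis = data_utils.decode_df_float_axis(cols_psd)
        good_idx = manual_filter_logPSD(log_psd, f_axis)
        df_summary_kept = df_summary[good_idx]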
425 | """ 426 | f_idx = np.logical_and(f_axis >= f_bounds[0], f_axis <= f_bounds[1]) 427 | f_axis = f_axis[f_idx] 428 | log_psd = log_psd[:, f_idx] 429 | 430 | # manual conditions for bad simulations 431 | logpsd_range = log_psd.max(1) - log_psd.min(1) 432 | 433 | bad_idx_hyperactive = (log_psd.var(1) > bounds_hyperactive[0]) & ( 434 | (logpsd_range) > bounds_hyperactive[1] 435 | ) # crazy active 436 | bad_idx_silent = (log_psd.var(1) < bounds_silent[0]) & ( 437 | (logpsd_range) < bounds_silent[1] 438 | ) # basically no activity 439 | bad_idx_infnans = np.isinf(log_psd).any(axis=1) | np.isnan(log_psd).any( 440 | axis=1 441 | ) # nans or infs 442 | 443 | good_idx = ~( 444 | bad_idx_infnans | bad_idx_silent | bad_idx_hyperactive 445 | ) # toss any bad ones 446 | return good_idx 447 | 448 | 449 | def compute_pca_features(spikes, t_collect, params_dict): 450 | """Compute PCA features from spikes.""" 451 | # aggregate all spikes in the sampled populations 452 | spikes_list = sum( 453 | [ 454 | list(spikes[f"{pop}_spikes"].values()) 455 | for pop in params_dict["params_analysis"]["pop_sampler"].keys() 456 | ], 457 | [], 458 | ) 459 | # bin and smooth 460 | t_bins, SU_rate = spikes_summary.compute_smoothed_SU_rate( 461 | spikes_list, 462 | t_collect, 463 | params_dict["params_analysis"]["pca_bin_width"] / second, 464 | params_dict["params_analysis"]["pca_smooth_std"] / second, 465 | ) 466 | # do PCA 467 | pca = spikes_summary.compute_PCA(SU_rate.T, params_dict["params_analysis"]["n_pcs"]) 468 | # record mean variance 469 | var_total = (SU_rate).var(1).mean() 470 | # record variance explained ratios 471 | df_var_exp = pd.DataFrame( 472 | np.hstack([var_total, pca.explained_variance_ratio_]).T, 473 | index=["pca_total"] 474 | + [f"pca_{i+1}" for i in np.arange(params_dict["params_analysis"]["n_pcs"])], 475 | ) 476 | 477 | return df_var_exp.T 478 | -------------------------------------------------------------------------------- /automind/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Utility functions for accessing and saving data, directories, etc. 
2 | import numpy as np 3 | import copy 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | from brian2 import second 7 | import pickle 8 | from os import listdir 9 | import h5py as h5 10 | from time import strftime 11 | from pathlib import Path 12 | from glob import glob 13 | from ..analysis import spikes_summary 14 | from ..sim import b2_interface 15 | from ..utils import dist_utils 16 | 17 | 18 | def _filter_spikes_random(spike_trains, n_to_save): 19 | """Filter a subset of spike trains randomly for saving.""" 20 | record_subset = np.sort( 21 | np.random.choice(len(spike_trains), n_to_save, replace=False) 22 | ) 23 | return {k: spike_trains[k] for k in spike_trains.keys() if k in record_subset} 24 | 25 | 26 | def collect_spikes(net_collect, params_dict): 27 | """Collect all spikes from all spike monitors, filter a subset if necessary.""" 28 | spike_dict = {} 29 | for sm in net_collect.objects: 30 | if "spikes" in sm.name: 31 | spike_trains = sm.spike_trains() 32 | pop_save_def = params_dict["params_settings"]["record_defs"][ 33 | sm.name.split("_")[0] 34 | ]["spikes"] 35 | n_to_save = pop_save_def if type(pop_save_def) == int else len(pop_save_def) 36 | if n_to_save == len(spike_trains): 37 | # recorded and to-be saved is the same length, go on a per usual 38 | spike_dict[sm.name] = b2_interface._deunitize_spiketimes(spike_trains) 39 | else: 40 | # recorded more than necessary, subselect for saving 41 | spike_dict[sm.name] = b2_interface._deunitize_spiketimes( 42 | _filter_spikes_random(spike_trains, n_to_save) 43 | ) 44 | return spike_dict 45 | 46 | 47 | def collect_timeseries(net_collect, params_dict): 48 | """Collect all timeseries data from the network collector.""" 49 | timeseries = b2_interface.parse_timeseries( 50 | net_collect, params_dict["params_settings"]["record_defs"] 51 | ) 52 | return timeseries 53 | 54 | 55 | def collect_raw_data(net_collect, params_dict): 56 | """Collect all raw data from the network collector.""" 57 | # get spikes 58 | spikes = collect_spikes(net_collect, params_dict) 59 | timeseries = collect_timeseries(net_collect, params_dict) 60 | 61 | all_pop_sampler = {pop.split("_")[0]: None for pop in spikes.keys()} 62 | # compute rates 63 | rates = spikes_summary.compute_poprate_from_spikes( 64 | spikes, 65 | t_collect=(0, net_collect.t / second), 66 | dt=params_dict["params_settings"]["dt_ts"] / second, 67 | dt_bin=params_dict["params_settings"]["dt"] / second, 68 | pop_sampler=all_pop_sampler, 69 | smooth_std=params_dict["params_analysis"]["smooth_std"], 70 | ) 71 | timeseries = {**timeseries, **rates} 72 | return spikes, timeseries 73 | 74 | 75 | def check_before_save_h5(h5_file, dataset_name, data): 76 | """Check if dataset exists before saving to h5 file.""" 77 | if dataset_name not in h5_file: 78 | h5_file.create_dataset(dataset_name, data=data) 79 | 80 | 81 | def save_spikes_h5(h5_file, params_dict, spikes): 82 | """Save spikes to h5 file.""" 83 | run_id = f"{params_dict['params_settings']['batch_seed']}/{params_dict['params_settings']['random_seed']}/" 84 | for pop_name, pop_spikes in spikes.items(): 85 | for cell, spks in pop_spikes.items(): 86 | dataset_name = run_id + "spikes/" + pop_name + f"/{cell}" 87 | check_before_save_h5( 88 | h5_file, 89 | dataset_name, 90 | np.around(spks, params_dict["params_settings"]["t_sigdigs"]), 91 | ) 92 | 93 | 94 | def get_spikes_h5(h5_file, run_id): 95 | """Get spikes of a specific run ID (random_seed) from h5 file.""" 96 | spikes = {} 97 | if run_id in h5_file: 98 | # found file 99 | for pop in 
h5_file[run_id + f"spikes/"].keys(): 100 | spikes[pop] = { 101 | cell: np.array( 102 | h5_file[run_id + f"spikes/{pop}/"][cell][()], dtype=float 103 | ) 104 | for cell in h5_file[run_id + f"spikes/{pop}/"].keys() 105 | } 106 | 107 | return spikes 108 | 109 | 110 | def save_timeseries_h5(h5_file, params_dict, timeseries): 111 | """Save timeseries data to h5 file.""" 112 | run_id = f"{params_dict['params_settings']['batch_seed']}/{params_dict['params_settings']['random_seed']}/" 113 | for k, v in timeseries.items(): 114 | if (k == "t") or (k == "t_ds"): 115 | # save dt and t_end 116 | check_before_save_h5(h5_file, run_id + "timeseries/dt", v[1] - v[0]) 117 | check_before_save_h5(h5_file, run_id + "timeseries/t_end", v[-1] + v[1]) 118 | else: 119 | check_before_save_h5(h5_file, run_id + "timeseries/" + k, v) 120 | 121 | 122 | def save_h5_and_plot_raw(sim_collector, do_plot_ts=True): 123 | """Save raw data to h5 file and plot network rate.""" 124 | # grab batch_seed from first run bc they're all identical 125 | batch_seed = sim_collector[0][0]["params_settings"]["batch_seed"] 126 | # create or access existing h5 file 127 | h5_file_path = ( 128 | sim_collector[0][0]["path_dict"]["data"] + f"{batch_seed}_raw_data.hdf5" 129 | ) 130 | 131 | with h5.File(h5_file_path, "a") as h5_raw_file: 132 | # loop over collector 133 | for params_dict, spikes, timeseries in sim_collector: 134 | # check if there's anything going on 135 | # basically don't bother plotting or saving anything if there is no spikes at all 136 | if (timeseries) and ((timeseries["exc_rate"] > 0).any()): 137 | if do_plot_ts: 138 | # plot network rate 139 | run_id = ( 140 | f'{batch_seed}_{params_dict["params_settings"]["random_seed"]}_' 141 | ) 142 | fig_name = ( 143 | params_dict["path_dict"]["figures"] 144 | + run_id 145 | + "exc_rate_%s.pdf" 146 | % ( 147 | "short" 148 | if params_dict["params_analysis"]["early_stopped"] 149 | else "full" 150 | ) 151 | ) 152 | 153 | if not Path(fig_name).exists(): 154 | # doesn't exist, plot and save 155 | fig = plt.figure(figsize=(12, 1.5)) 156 | t_plot = timeseries["t_ds"] <= 60 157 | plt.plot( 158 | timeseries["t_ds"][t_plot], 159 | timeseries["exc_rate"][t_plot], 160 | lw=0.2, 161 | ) 162 | plt.title( 163 | f'{batch_seed}_{params_dict["params_settings"]["random_seed"]}' 164 | ) 165 | plt.autoscale(tight=True) 166 | plt.savefig(fig_name) 167 | plt.close(fig) 168 | 169 | if not params_dict["params_analysis"]["early_stopped"]: 170 | # save only if the full simulation ran, but always plot 171 | # save spikes 172 | save_spikes_h5(h5_raw_file, params_dict, spikes) 173 | 174 | ### QUICK HACK TO NOT SAVE RATES FOR NOW 175 | # should almost surely change this hack but... 176 | if "exc_rate" in timeseries.keys(): 177 | timeseries.pop("exc_rate") 178 | if "inh_rate" in timeseries.keys(): 179 | timeseries.pop("inh_rate") 180 | 181 | # save time series 182 | save_timeseries_h5(h5_raw_file, params_dict, timeseries) 183 | 184 | return h5_file_path 185 | 186 | 187 | def fill_params_dict( 188 | params_dict_orig, theta_samples, theta_priors, return_n_dicts=False 189 | ): 190 | """Take a template params_dict and fill in generated samples. 191 | 192 | Args: 193 | params_dict_orig (dict): template parameter dictionary (in nested format). 194 | theta_samples (dict or pd.Dataframe): samples, can be organized as a dictionary or dataframe / series. 195 | theta_priors (dict): prior hyperparameters, really just needed for b2 units. 
196 | return_n_dicts (bool or int, optional): return a list of n dictionaries if int, otherwise return a single 197 | dictionary where some fields are arrays. Defaults to False. 198 | 199 | Returns: 200 | dict or list: dictionary (or list of dictionaries) with sampled values filled into the template parameter dict. 201 | """ 202 | # decide if samples is in dict, which assumes it has b2 units attached, otherwise in dataframe, 203 | # so need to convert to array and add unit because brian2 units work completely mysteriously 204 | samples_in_dict = type(theta_samples) is dict 205 | 206 | if return_n_dicts: 207 | # return list of dictionaries for multiple samples 208 | # deep copy so we don't modify original 209 | params_dict = [copy.deepcopy(params_dict_orig) for n in range(return_n_dicts)] 210 | else: 211 | # just make a single copy where fields with multiple values are stored as arrays 212 | params_dict = copy.deepcopy(params_dict_orig) 213 | 214 | for param, val in theta_samples.items(): 215 | if "params" in param: 216 | # check if it's a network parameter 217 | param_name = param.split(".") 218 | if param in theta_priors.keys(): 219 | # copy unit 220 | unit = theta_priors[param]["unit"] 221 | else: 222 | # param not in prior dictionary, copy directly but warn 223 | unit = 1 224 | print(param + " has no prior. Copied as bare value without unit.") 225 | 226 | if return_n_dicts: 227 | vals = val if samples_in_dict else np.array(val) * unit 228 | for i_n in range(return_n_dicts): 229 | # loop over dictionaries to fill in 230 | params_dict[i_n][param_name[0]][param_name[1]] = vals[i_n] 231 | else: 232 | # put directly into dictionary, useful for saving so non-repeating stuff doesn't get saved 233 | params_dict[param_name[0]][param_name[1]] = ( 234 | val if samples_in_dict else np.array(val) * unit 235 | ) 236 | 237 | return params_dict 238 | 239 | 240 | def update_params_dict(params_dict, update_dict): 241 | """Update a configuration dictionary with new values.""" 242 | # update non-default values 243 | for k, v in update_dict.items(): 244 | p, p_sub = k.split(".") 245 | params_dict[p][p_sub] = v 246 | return params_dict 247 | 248 | 249 | def pickle_file(full_path, to_pickled): 250 | """Save a pickled file.""" 251 | with open(full_path, "wb") as handle: 252 | pickle.dump(to_pickled, handle) 253 | 254 | 255 | def load_pickled(filename): 256 | """Load a pickled file.""" 257 | with open(filename, "rb") as f: 258 | loaded = pickle.load(f) 259 | return loaded 260 | 261 | 262 | def save_params_priors(path, params_dict, priors): 263 | """Save parameters and priors to pickle files.""" 264 | with open(path + "prior.pickle", "wb") as handle: 265 | pickle.dump(priors, handle) 266 | with open(path + "params_dict_default.pickle", "wb") as handle: 267 | pickle.dump(params_dict, handle) 268 | 269 | 270 | def process_simulated_samples(sim_collector, df_prior_samples_batch): 271 | """Process successfully-run simulated samples and append to dataframe.""" 272 | # make a copy of the df 273 | df_prior_samples_ran = df_prior_samples_batch.copy() 274 | 275 | # sort results by random_seed just in case pool returned in mixed order 276 | sort_idx = np.argsort( 277 | np.array([s[0]["params_settings"]["random_seed"] for s in sim_collector]) 278 | ) 279 | sim_collector = [sim_collector[i] for i in sort_idx] 280 | 281 | # check all seeds are equal and lined up 282 | if np.all( 283 | df_prior_samples_ran["params_settings.random_seed"].values 284 | == np.array([s[0]["params_settings"]["random_seed"] for s in sim_collector]) 
285 | ): 286 | ran_metainfo = [ 287 | [ 288 | s[0]["params_settings"]["sim_time"] / second, # get rid of b2 units 289 | s[0]["params_settings"]["real_run_time"], 290 | s[0]["params_analysis"]["early_stopped"], 291 | ] 292 | for s in sim_collector 293 | ] 294 | df_prior_samples_ran = pd.concat( 295 | ( 296 | df_prior_samples_ran, 297 | pd.DataFrame( 298 | ran_metainfo, 299 | index=df_prior_samples_ran.index, 300 | columns=[ 301 | "params_settings.sim_time", 302 | "params_settings.real_run_time", 303 | "params_analysis.early_stopped", 304 | ], 305 | ), 306 | ), 307 | axis=1, 308 | ) 309 | 310 | else: 311 | print("misaligned; failed to append") 312 | 313 | return df_prior_samples_ran 314 | 315 | 316 | def convert_spike_array_to_dict(spikes, fs, pop_name="exc"): 317 | """Convert tuple of spike times to dict with cell id. 318 | 319 | Args: 320 | spikes (tuple): tuple of list of spike time. 321 | fs (float): sampling frequency. 322 | pop_name (str, optional): Name of population. Defaults to "exc". 323 | 324 | Returns: 325 | _type_: _description_ 326 | """ 327 | return { 328 | pop_name: { 329 | "%i" % i_s: np.array(s, ndmin=1) / fs 330 | for i_s, s in enumerate(spikes) 331 | if s is not None 332 | } 333 | } 334 | 335 | 336 | def decode_df_float_axis(indices, out_type=float, splitter="_", idx_pos=1): 337 | """Decode dataframe indices as array of values. For freqs, etc.""" 338 | return np.array([i.split(splitter)[idx_pos] for i in indices], dtype=out_type) 339 | 340 | 341 | def subselect_features(df, feature_str): 342 | """ 343 | Select features from dataframe based on string. 344 | """ 345 | if type(feature_str) == str: 346 | feature_str = [feature_str] 347 | 348 | all_features = [] 349 | for f_str in feature_str: 350 | all_features += list(df.columns[df.columns.str.contains(f_str)]) 351 | 352 | return all_features 353 | 354 | 355 | def separate_feature_columns( 356 | df, 357 | col_sets=[ 358 | "params", 359 | "isi", 360 | "burst", 361 | "pca", 362 | "psd", 363 | ], 364 | ): 365 | """Separate columns of dataframe into different sets.""" 366 | return [subselect_features(df, c) for c in col_sets] 367 | 368 | 369 | def make_subfolders(parent_path, subfolders=["data", "figures"]): 370 | """Make directories for simulations, parameters, and figures.""" 371 | paths = {} 372 | for subfolder in subfolders: 373 | full_path_sub = parent_path + "/" + subfolder + "/" 374 | # MAKE ABSOLUTE PATH HERE 375 | Path(full_path_sub).mkdir(parents=True, exist_ok=True) 376 | paths[subfolder] = full_path_sub 377 | return paths 378 | 379 | 380 | def get_subfolders(parent_path, subpath_reqs=None): 381 | """Get subfolders in a parent directory, optionally those that satisfy the subpath requirements.""" 382 | folders = [parent_path + f + "/" for f in listdir(parent_path) if "." 
not in f] 383 | if subpath_reqs: 384 | # include only those that satisfy the subpath requirement 385 | # usually looking for a folder 386 | folders = [f for f in folders if Path(f + "/" + subpath_reqs).exists()] 387 | 388 | return folders 389 | 390 | 391 | def collect_csvs(run_folders, filename, subfolder="", merge=True): 392 | """Collect csv files across analysis of multiple runs.""" 393 | collector = [] 394 | for rf in run_folders: 395 | data_path = rf + subfolder 396 | path_dict = extract_data_files(data_path, [filename]) 397 | if path_dict[filename]: 398 | print(data_path + path_dict[filename]) 399 | collector.append(pd.read_csv(data_path + path_dict[filename], index_col=0)) 400 | 401 | if merge: 402 | return pd.concat(collector, ignore_index=True, axis=0) 403 | else: 404 | return collector 405 | 406 | 407 | def set_seed_by_time(params_dict): 408 | """Set batch seed based on current time.""" 409 | from time import time 410 | 411 | batch_seed = int((time() % 1) * 1e7) 412 | params_dict["params_settings"]["batch_seed"] = batch_seed 413 | return batch_seed, params_dict 414 | 415 | 416 | def set_all_seeds(seed): 417 | """Set all seeds for reproducibility.""" 418 | from numpy.random import seed as np_seed 419 | from torch import manual_seed as torch_seed 420 | from brian2 import seed as b2_seed 421 | 422 | for seed_fn in [np_seed, torch_seed, b2_seed]: 423 | seed_fn(seed) 424 | 425 | 426 | def merge_theta_and_x(df_prior_samples_ran, random_seeds, data_rows): 427 | """Align and merge theta and x dataframes based on random seeds.""" 428 | if type(data_rows) != type(df_prior_samples_ran): 429 | # make it into dataframe 430 | data_rows = pd.concat(data_rows) 431 | 432 | matching_idx = [] 433 | for rs in random_seeds: 434 | matched_prior = df_prior_samples_ran.index[ 435 | df_prior_samples_ran["params_settings.random_seed"] == int(rs) 436 | ] 437 | if len(matched_prior) == 1: 438 | # found exactly 1 prior sample that matches summary row 439 | matching_idx.append(matched_prior) 440 | else: 441 | print(f"{rs} has {len(matched_prior)} matched priors, drop.") 442 | data_rows.drop(index=rs, inplace=True) 443 | 444 | df_merged_data = pd.concat( 445 | (df_prior_samples_ran, data_rows.set_index(np.array(matching_idx).squeeze())), 446 | axis=1, 447 | ) 448 | return df_merged_data 449 | 450 | 451 | def extract_data_files(data_folder, file_type_list, verbose=True): 452 | """Extract data files from a folder based on file types.""" 453 | run_files = listdir(data_folder) 454 | data_file_dict = {} 455 | 456 | data_file_dict["root_path"] = data_folder 457 | for file_type in file_type_list: 458 | file_candidates = [f for f in run_files if file_type in f] 459 | if verbose: 460 | print(file_type, file_candidates) 461 | if file_candidates != []: 462 | data_file_dict[file_type.split(".")[0]] = file_candidates[0] 463 | else: 464 | data_file_dict[file_type.split(".")[0]] = None 465 | 466 | return data_file_dict 467 | 468 | 469 | def filter_df(df, filter_col, filter_match, return_cols=None, return_array=False): 470 | """Filter dataframe based on column and match.""" 471 | df_filtered = df[df[filter_col] == filter_match] 472 | if return_cols is not None: 473 | df_filtered = df_filtered[return_cols] 474 | return df_filtered.values if return_array else df_filtered 475 | 476 | 477 | # Helper function for sorting through stuff 478 | def grab_xo_and_posterior_preds( 479 | df_xos, 480 | query_xo, 481 | df_posterior_sims, 482 | query_pp, 483 | cols_summary, 484 | log_samples=True, 485 | stdz_func=None, 486 | 
include_mapnmean=False, 487 | ): 488 | """Grab xo and posterior predictives based on queries. Used in final analysis notebooks.""" 489 | # Grab original xo and posterior predictives 490 | df_xo = dist_utils.find_matching_xo(df_xos, query_xo) # xo 491 | df_per_xo = filter_df(df_posterior_sims, "x_o", query_pp) # predictives 492 | 493 | dfs_collect = [ 494 | filter_df(df_per_xo, "inference.type", i_t) 495 | for i_t in df_per_xo["inference.type"].unique() 496 | ] # split by type 497 | assert "samples" in str( 498 | dfs_collect[-1].head(1)["inference.type"].values 499 | ) # check that last entry are samples 500 | df_samples = ( 501 | pd.concat(dfs_collect[1:]) if include_mapnmean else dfs_collect[-1] 502 | ) # exclude gt_resim, and optionally map/mean samples 503 | 504 | # Preprocessing and whatnot 505 | xo = dist_utils.log_n_stdz(df_xo[cols_summary].values, log_samples, stdz_func) 506 | samples = dist_utils.log_n_stdz( 507 | df_samples[cols_summary].values, log_samples, stdz_func 508 | ) 509 | return df_xo, dfs_collect, df_samples, xo, samples 510 | 511 | 512 | def collect_all_xo_and_posterior_preds( 513 | df_xos, 514 | df_posterior_sims, 515 | xo_queries, 516 | cols_features, 517 | cols_params, 518 | log_samples=True, 519 | stdz_func=None, 520 | include_mapnmean=False, 521 | sort_samples=True, 522 | sort_weights=None, 523 | ): 524 | """Collect all xo and posterior predictives based on queries. Used in final analysis notebooks.""" 525 | df_collect, samples_x_collect, samples_theta_collect = [], [], [] 526 | 527 | for i_q, xo_query in enumerate(xo_queries): 528 | df_xo = dist_utils.find_matching_xo(df_xos, xo_query) 529 | df_samples = filter_df(df_posterior_sims, "x_o", "%s_%s" % xo_query) 530 | 531 | # log and standardize before sorting or not 532 | xo = dist_utils.log_n_stdz(df_xo[cols_features].values, log_samples, stdz_func) 533 | 534 | dfs, xs, thetas = {}, {}, {} 535 | for i_t, s_type in enumerate(df_samples["inference.type"].unique()): 536 | if ("samples" in s_type) and (include_mapnmean): 537 | df_type = pd.concat( 538 | [ 539 | filter_df(df_samples, "inference.type", tt) 540 | for tt in df_samples["inference.type"].unique() 541 | if tt != "gt_resim" 542 | ] 543 | ) 544 | else: 545 | df_type = filter_df(df_samples, "inference.type", s_type) 546 | 547 | samples = dist_utils.log_n_stdz( 548 | df_type[cols_features].values, log_samples, stdz_func 549 | ) 550 | if (sort_samples) and (samples.shape[0] > 1): 551 | samples_sorted, dists, idx_sorted = dist_utils.sort_closest_to_xo( 552 | xo, samples, top_n=None, weights=sort_weights 553 | ) 554 | samples = samples_sorted 555 | else: 556 | idx_sorted = np.arange(samples.shape[0]) 557 | 558 | for key_type in ["resim", "map", "mean", "samples"]: 559 | if key_type in s_type: 560 | dfs[key_type] = df_type.iloc[idx_sorted] 561 | xs[key_type] = samples 562 | thetas[key_type] = df_type[cols_params].iloc[idx_sorted].values 563 | 564 | dfs["xo"] = df_xo 565 | xs["xo"] = xo 566 | thetas["xo"] = None 567 | 568 | df_collect.append(dfs) 569 | samples_x_collect.append(xs) 570 | samples_theta_collect.append(thetas) 571 | 572 | return df_collect, samples_x_collect, samples_theta_collect 573 | 574 | 575 | ### Functions for loading stuff upfront for posterior predictive analysis 576 | def load_for_posterior_predictives_analyses( 577 | xo_type, 578 | feature_set, 579 | algorithm="NPE", 580 | sample_datetime=None, 581 | idx_inf=-1, 582 | using_copied=True, 583 | ): 584 | """Load all necessary files for posterior predictive analyses. 
Used in final analysis notebooks.""" 585 | # load posterior files 586 | head_dir = "/slurm_r2/" if using_copied else "/" 587 | data_dir_str = ( 588 | f"../data/{head_dir}/inference_r2/{xo_type}/{feature_set}/{algorithm}/" 589 | + (f"{sample_datetime}/" if sample_datetime else "") 590 | ) 591 | data_dirs = sorted(glob(data_dir_str + "*-*/*/data/")) 592 | 593 | print("All relevant inference directories:") 594 | [print(f"+++ {d}") for d in data_dirs] 595 | print(" ----- ") 596 | print(f"Loading...{data_dirs[idx_inf]}...") 597 | path_dict = extract_data_files( 598 | data_dirs[idx_inf], 599 | [ 600 | "posterior.pickle", 601 | "params_dict_analysis_updated.pickle", 602 | "summary_data_merged.csv", 603 | "raw_data.hdf5", 604 | ], 605 | ) 606 | 607 | # also load cfg but only when not on local 608 | if not using_copied: 609 | import yaml 610 | 611 | with open(path_dict["root_path"] + "../.hydra/overrides.yaml", "r") as f: 612 | cfg = yaml.safe_load(f) 613 | print(cfg) 614 | else: 615 | cfg = {} 616 | 617 | # load original xo 618 | df_xos = load_df_xos(xo_type) 619 | 620 | # load posterior pred and densities 621 | df_posterior_sims, posterior, params_dict_default = load_df_posteriors(path_dict) 622 | _, theta_minmax = dist_utils.standardize_theta( 623 | posterior.prior.sample((1,)), posterior 624 | ) 625 | _, cols_isi, cols_burst, cols_pca, cols_psd = separate_feature_columns( 626 | df_posterior_sims 627 | ) 628 | cols_params = posterior.names 629 | return ( 630 | df_xos, 631 | df_posterior_sims, 632 | posterior, 633 | theta_minmax, 634 | (cols_params, cols_isi, cols_burst, cols_pca, cols_psd), 635 | params_dict_default, 636 | path_dict, 637 | ) 638 | 639 | 640 | def load_df_xos(xo_type, xo_path=None): 641 | """Load target observation data based on xo type. 642 | NEED TO CHANGE DATA PATHS. 
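    Example (illustrative only; "organoid" is one of the supported xo_type values and the
    custom csv path is hypothetical):
        df_xos = load_df_xos("organoid")
        df_xos = load_df_xos("my_recordings", xo_path="./my_summary.csv")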
643 | """ 644 | # TO DO: CHANGE DATA PATHS 645 | if xo_path is None: 646 | if "simulation" in xo_type: 647 | load_path = "../data/adex_MKI-round1-testset/analysis_summary/MK1_summary_merged.csv" 648 | # '/analysis_summary/MK1_summary_merged.csv' 649 | elif xo_type == "organoid": 650 | load_path = "../data/adex_MKI/analysis_organoid/organoid_summary.csv" 651 | df_xos = pd.read_csv(load_path, index_col=0) 652 | elif xo_type in ["allen-hc", "allen-vis", "allenvc"]: 653 | load_path = "../data/adex_MKI/analysis_allenvc/allenvc_summary.csv" 654 | else: 655 | load_path = xo_path 656 | 657 | # load 658 | df_xos = pd.read_csv(load_path, index_col=0) 659 | if xo_type == "organoid": 660 | from .organoid_utils import convert_date_to_int 661 | 662 | df_xos.insert(0, ["day"], convert_date_to_int(df_xos), allow_duplicates=True) 663 | return df_xos 664 | 665 | 666 | def load_df_posteriors(path_dict): 667 | """Load posterior samples, network, and configurations.""" 668 | df_posterior_sims = pd.read_csv( 669 | path_dict["root_path"] + path_dict["summary_data_merged"], index_col=0 670 | ) 671 | print(f"{df_posterior_sims['x_o'].value_counts().values[0]} samples per xo.") 672 | posterior = load_pickled(path_dict["root_path"] + path_dict["posterior"]) 673 | params_dict_default = load_pickled( 674 | path_dict["root_path"] + path_dict["params_dict_analysis_updated"] 675 | ) 676 | return df_posterior_sims, posterior, params_dict_default 677 | -------------------------------------------------------------------------------- /automind/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | ### Utility functions for distributions and inference 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | from torch import nn, Tensor, stack 6 | from scipy.stats import gaussian_kde 7 | from sbi.utils import MultipleIndependent 8 | from . import data_utils 9 | from .analysis_utils import manual_filter_logPSD, discard_by_avg_power 10 | from ..inference import trainers 11 | 12 | 13 | class CustomIndependentJoint(MultipleIndependent): 14 | """Custom class for base indepedent-marginal prior distributions.""" 15 | 16 | def __init__(self, variable_params): 17 | dist_list = [vp[1](**vp[2]).expand((1,)) for vp in variable_params] 18 | super().__init__(dist_list) 19 | """ 20 | Set the list of 1D distributions and names. 
21 | """ 22 | self.names = [vp[0] for vp in variable_params] 23 | self.marginals = [vp[1](**vp[2]) for vp in variable_params] 24 | self.b2_units = [vp[3] for vp in variable_params] 25 | self.make_prior_dict() 26 | 27 | def __repr__(self) -> str: 28 | out = "" 29 | for n, u, m in zip(self.names, self.b2_units, self.marginals): 30 | out += f"{n} ({u}) ~ {m} \n" 31 | return out 32 | 33 | def make_prior_dict(self): 34 | self.as_dict = { 35 | name: {"unit": self.b2_units[i_n], "marginal": self.marginals[i_n]} 36 | for i_n, name in enumerate(self.names) 37 | } 38 | return self.as_dict 39 | 40 | 41 | class MinMaxStandardizeTransform(nn.Module): 42 | def __init__( 43 | self, 44 | min, 45 | max, 46 | new_min, 47 | new_max, 48 | ): 49 | """Transforms input tensor to new min/max range.""" 50 | super().__init__() 51 | min, max, new_min, new_max = map(torch.as_tensor, (min, max, new_min, new_max)) 52 | self.min = min 53 | self.max = max 54 | self.new_min = new_min 55 | self.new_max = new_max 56 | self.register_buffer("_min", min) 57 | self.register_buffer("_max", max) 58 | self.register_buffer("_new_min", new_min) 59 | self.register_buffer("_new_max", new_max) 60 | 61 | def forward(self, tensor): 62 | return standardize_minmax( 63 | tensor, self._min, self._max, self._new_min, self._new_max 64 | ) 65 | 66 | 67 | def pass_info_from_prior( 68 | dist1, dist2, passed_attrs=["names", "marginals", "b2_units", "as_dict"] 69 | ): 70 | """Pass attributes from one distribution to another.""" 71 | for attr in passed_attrs: 72 | if hasattr(dist1, attr): 73 | setattr(dist2, attr, getattr(dist1, attr)) 74 | else: 75 | print(f"{attr} not in giving distribution.") 76 | return dist2 77 | 78 | 79 | def find_matching_xo(df_xos, xo_query): 80 | """Find xo in dataframe given query. 81 | 82 | Args: 83 | df_xos (pandas dataframe): Summary dataframe of all xos. 84 | xo_query (tuple): querying info of xo, (batch_seed, random_seed) for sims, (date, well) for organoid. 85 | 86 | Returns: 87 | dataframe: queried xo in the dataframe. 88 | """ 89 | batch, run = xo_query 90 | if ("date" in df_xos.columns) & ("well" in df_xos.columns): 91 | # xos from organoid data, get xos by date and well 92 | # print(f"recording: {batch}, well: {run}") 93 | df_matched = df_xos[(df_xos["date"] == batch) & (df_xos["well"] == run)] 94 | 95 | elif ("mouse" in df_xos.columns) & ("area" in df_xos.columns): 96 | # xos from mouse neuropixels data, get by mouse and region 97 | # print(f"mouse: {batch}, area: {run}") 98 | df_matched = df_xos[(df_xos["mouse"] == batch) & (df_xos["area"] == run)] 99 | 100 | elif ("params_settings.batch_seed" in df_xos.columns) & ( 101 | "params_settings.random_seed" in df_xos.columns 102 | ): 103 | # xos from simulation, get by batch and random seed 104 | # print(f"batch seed: {batch}, random seed: {run}") 105 | df_matched = df_xos[ 106 | (df_xos["params_settings.batch_seed"] == batch) 107 | & (df_xos["params_settings.random_seed"] == run) 108 | ] 109 | return df_matched 110 | 111 | 112 | def find_posterior_preds_of_xo( 113 | df_sims, xo_query, query_string_form="%s_%s", discard_early_stopped=True 114 | ): 115 | """Find simulations whose xo is queried. 116 | 117 | Args: 118 | df_sims (pandas dataframe): Dataframe containing simulation data. 119 | xo_query (tuple): querying info of xo. 120 | query_string_form (str, optional): String format of xo column in df_sims. Defaults to "%s_%s". 121 | discard_early_stopped (bool, optional): Whether to discard sims that didn't run. Defaults to True. 
122 | 123 | Returns: 124 | dataframe: all simulation samples that have matching xo. 125 | """ 126 | if discard_early_stopped: 127 | return df_sims[ 128 | (df_sims["x_o"] == (query_string_form % xo_query)) 129 | & (df_sims["params_analysis.early_stopped"] == False) 130 | ] 131 | else: 132 | return df_sims[(df_sims["x_o"] == (query_string_form % xo_query))] 133 | 134 | 135 | def sample_proposal(proposal, n_samples, xo=None): 136 | """Sample from proposal distribution. If xo is provided, set it as default xo.""" 137 | if proposal.default_x is not None: 138 | print("Proposal has default xo.") 139 | if xo is not None: 140 | print( 141 | f"Default xo is identical to current xo: {(proposal.default_x==xo).all()}" 142 | ) 143 | else: 144 | print("Setting default xo") 145 | proposal.set_default_x(xo) 146 | 147 | samples = proposal.sample((n_samples,)) 148 | df_samples = pd.DataFrame(samples.numpy().astype(float), columns=proposal.names) 149 | return df_samples 150 | 151 | 152 | def sample_different_xos_from_posterior( 153 | density_estimator, df_xos, xo_queries, n_samples_per, return_xos=False 154 | ): 155 | """Draw samples conditioned on a list of observations (xos). 156 | 157 | Args: 158 | density_estimator (SBI posterior): posterior density estimator from SBI. 159 | df_xos (pandas dataframe): dataframe of all observations. 160 | xo_queries (list): list of tuples that uniquely identify xos to be matched in df_xos. 161 | n_samples_per (int): number of samples to be drawn per observation. 162 | return_xos (bool, optional): Whether to also return info on matched xos. Defaults to False. 163 | 164 | Returns: 165 | _type_: df_posterior_samples, list of [df_matched, batch, run, xo] 166 | """ 167 | posterior_samples, xos = [], [] 168 | # transform xos 169 | try: 170 | print("Applying custom feature transformations.") 171 | [print(k, v) for k, v in density_estimator.x_bounds_and_transforms.items()] 172 | if "freq_bounds" in density_estimator.x_bounds_and_transforms.keys(): 173 | # Using PSD features 174 | df_xos[density_estimator.x_feat_set] = preproc_dataframe_psd( 175 | df_xos, density_estimator.x_bounds_and_transforms, drop_nans=False 176 | )[0] 177 | else: 178 | # Using spike/burst features 179 | df_xos[density_estimator.x_feat_set] = preproc_dataframe( 180 | df_xos, density_estimator.x_bounds_and_transforms, drop_nans=False 181 | )[0] 182 | except Exception as e: 183 | print(e) 184 | print( 185 | "Transformations failed (likely no transformation included in posterior)." 
186 | ) 187 | 188 | for batch, run in xo_queries: 189 | if ("date" in df_xos.columns) & ("well" in df_xos.columns): 190 | # xos from organoid data, get xos by date and well 191 | print(f"recording: {batch}, well: {run}") 192 | df_matched = df_xos[(df_xos["date"] == batch) & (df_xos["well"] == run)] 193 | elif ("mouse" in df_xos.columns) & ("area" in df_xos.columns): 194 | # xos from mouse neuropixels data, get by mouse and region 195 | print(f"mouse: {batch}, area: {run}") 196 | df_matched = df_xos[(df_xos["mouse"] == batch) & (df_xos["area"] == run)] 197 | elif ("params_settings.batch_seed" in df_xos.columns) & ( 198 | "params_settings.random_seed" in df_xos.columns 199 | ): 200 | # xos from simulation, get by batch and random seed 201 | print(f"batch seed: {batch}, random seed: {run}") 202 | df_matched = df_xos[ 203 | (df_xos["params_settings.batch_seed"] == batch) 204 | & (df_xos["params_settings.random_seed"] == run) 205 | ] 206 | 207 | # get xo in tensor 208 | xo = Tensor(df_matched[density_estimator.x_feat_set].astype(float).values) 209 | # sample posterior 210 | if density_estimator.default_x is not None: 211 | print("Density estimator has default x.") 212 | print( 213 | f"Default x is identical to current xo: {(density_estimator.default_x==xo).all()}" 214 | ) 215 | 216 | samples = density_estimator.sample((n_samples_per,), x=xo) 217 | 218 | # store samples 219 | df_samples = pd.DataFrame( 220 | samples.numpy().astype(float), columns=density_estimator.names 221 | ) 222 | df_samples.insert(0, "x_o", f"{batch}_{run}") 223 | posterior_samples.append(df_samples) 224 | xos.append((xo, df_matched, batch, run)) 225 | 226 | df_posterior_samples = pd.concat(posterior_samples, ignore_index=True) 227 | if return_xos: 228 | return df_posterior_samples, xos 229 | else: 230 | return df_posterior_samples 231 | 232 | 233 | def fill_gaps_with_nans(theta_full, x_partial, idx_good): 234 | """Fill in missing values in x_partial with nans. For ACE-GBI.""" 235 | x_full = torch.zeros((theta_full.shape[0], x_partial.shape[1])) 236 | x_full[idx_good] = x_partial 237 | x_full[~idx_good] = torch.nan 238 | return Tensor(theta_full), x_full 239 | 240 | 241 | def proc_one_column(df_col, bound_feat, log_feat): 242 | """Process one column of dataframe given bounds and log flag.""" 243 | if bound_feat: 244 | df_col[(df_col < bound_feat[0]) | (df_col > bound_feat[1])] = np.nan 245 | if log_feat: 246 | df_col = np.log10(df_col) 247 | col_range = [df_col.min(), df_col.max()] 248 | return df_col, col_range 249 | 250 | 251 | def preproc_dataframe( 252 | df_summary, x_transforms_dict, drop_nans=False, replace_infs=False 253 | ): 254 | """Preprocess dataframe for inference and pairplot. 255 | 256 | Discard entries outside of bound, and log features if indicated. 257 | Args: 258 | df_summary (pd dataframe): dataframe to preprocess. 
259 | transforms_dict (dict): {name: (bounds, log_or_not)} 260 | 261 | Returns: 262 | df_copy, feat_bounds, feat_names_pretty : 263 | """ 264 | columns = list(x_transforms_dict.keys()) 265 | 266 | # make a copy of dataframe 267 | df_copy = df_summary[columns].copy() 268 | feat_bounds = [] 269 | for i_c, col in enumerate(columns): 270 | # call the processing function 271 | df_col, col_range = proc_one_column(df_copy[col], *x_transforms_dict[col]) 272 | df_copy[col] = df_col 273 | feat_bounds.append(col_range) 274 | 275 | if replace_infs: 276 | df_copy = df_copy.replace([np.inf, -np.inf], np.nan) 277 | 278 | if drop_nans: 279 | df_copy = df_copy[~df_copy.isin([np.nan, np.inf, -np.inf]).any(axis=1)] 280 | 281 | # Get binary indices of good samples 282 | idx_good = np.zeros(len(df_summary), dtype=bool) 283 | # Get indices that have no nans or inf 284 | idx_good[df_copy[~df_copy.isin([np.nan, np.inf, -np.inf]).any(axis=1)].index] = True 285 | 286 | # Get pretty feature names 287 | feat_names_pretty = [ 288 | ( 289 | f"$log_{{10}}$ {' '.join(col.split('_'))}" 290 | if x_transforms_dict[col][1] 291 | else " ".join(col.split("_")) 292 | ) 293 | for col in columns 294 | ] 295 | return df_copy, feat_bounds, idx_good, feat_names_pretty 296 | 297 | 298 | def preproc_dataframe_psd( 299 | df_summary, f_transforms_dict, drop_nans=False, replace_infs=False 300 | ): 301 | """Preprocess PSD dataframe for inference.""" 302 | 303 | # Get frequency axis and the indices of the frequencies we want to keep 304 | cols_psd = data_utils.subselect_features(df_summary, ["psd"]) 305 | f_axis = data_utils.decode_df_float_axis(cols_psd) 306 | f_sel_idx = (f_axis >= f_transforms_dict["freq_bounds"][0]) & ( 307 | f_axis <= f_transforms_dict["freq_bounds"][1] 308 | ) 309 | idx_good = np.ones(len(df_summary), dtype=bool) 310 | 311 | # Discard bad samples if required and grab only the frequencies we want 312 | if f_transforms_dict["discard_stopped"]: 313 | if "params_analysis.early_stopped" in df_summary.columns: 314 | # Discard samples that was early stopped 315 | idx_keep = df_summary["params_analysis.early_stopped"] == False 316 | print( 317 | f"{idx_good.sum()-(idx_good & idx_keep).sum()} sims dropped due to early stopping." 318 | ) 319 | idx_good = idx_good & idx_keep 320 | 321 | if f_transforms_dict["discard_conditions"]: 322 | if f_transforms_dict["discard_conditions"]["manual_filter_logPSD"]: 323 | log_psd = np.log10(df_summary[cols_psd].values) 324 | f_axis = data_utils.decode_df_float_axis(cols_psd) 325 | idx_keep = manual_filter_logPSD(log_psd, f_axis) 326 | print( 327 | f"{idx_good.sum()-(idx_good & idx_keep).sum()} sims dropped due to manual criteria." 328 | ) 329 | idx_good = idx_good & idx_keep 330 | 331 | if f_transforms_dict["discard_conditions"]["power_thresh"]: 332 | idx_keep = discard_by_avg_power( 333 | df_summary, 334 | f_band=f_transforms_dict["discard_conditions"]["f_band"], 335 | power_thresh=f_transforms_dict["discard_conditions"]["power_thresh"], 336 | return_idx=True, 337 | ) 338 | print( 339 | f"{idx_good.sum()-(idx_good & idx_keep).sum()} sims dropped due to power threshold discard." 
340 | ) 341 | idx_good = idx_good & idx_keep 342 | 343 | df_copy = df_summary[cols_psd][idx_good].iloc[:, f_sel_idx].copy() 344 | print(f"{len(df_summary) - len(df_copy)} samples discarded in total.") 345 | 346 | # Log transform if required 347 | if f_transforms_dict["log_power"]: 348 | df_copy = np.log10(df_copy) 349 | 350 | # Replace infs and nans 351 | if replace_infs: 352 | df_copy = df_copy.replace([np.inf, -np.inf], np.nan) 353 | if drop_nans: 354 | df_copy = df_copy.dropna() 355 | 356 | return df_copy, f_axis[f_sel_idx], idx_good, f_sel_idx 357 | 358 | 359 | def logPSD_scaler(log_psd, demean=True, scaler=0.1): 360 | """Scaling function for log-PSDs prior to training. 361 | 362 | Args: 363 | log_psd (tensor): log-PSD tensor. 364 | demean (bool, optional): Whether to demean the log PSD. Defaults to True. 365 | scaler (float, optional): Scaling factor. Defaults to 0.1. 366 | 367 | Returns: 368 | tensor: scaled log-PSD. 369 | """ 370 | 371 | if demean: 372 | log_psd = log_psd - log_psd.mean(1, keepdim=True) 373 | return log_psd * scaler 374 | 375 | 376 | def retrieve_trained_network( 377 | inf_dir, 378 | algorithm, 379 | feature_set, 380 | inference_datetime=None, 381 | job_id=None, 382 | load_network_prior=True, 383 | ): 384 | """Retrieve trained network and prior from inference directory.""" 385 | from os import listdir, path 386 | 387 | full_path = f"{inf_dir}/{feature_set}/{algorithm}/" 388 | if inference_datetime is None: 389 | # Get latest inference run 390 | inference_datetime = np.sort(listdir(full_path))[-1] 391 | print(listdir(full_path)) 392 | 393 | full_path += inference_datetime 394 | 395 | if job_id is None: 396 | # Get last job id 397 | job_id = np.sort( 398 | [f for f in listdir(full_path) if path.isdir(full_path + f"/{f}")] 399 | )[-1] 400 | 401 | full_path += f"/{job_id}/" 402 | # Check that the path exists 403 | 404 | if load_network_prior: 405 | if not path.isdir(full_path): 406 | raise ValueError(f"{full_path}: path does not exist.") 407 | else: 408 | print(f"Loading network and prior from {full_path}") 409 | neural_net = data_utils.load_pickled(full_path + "/neural_net.pickle") 410 | prior = data_utils.load_pickled(full_path + "/prior.pickle") 411 | return full_path, neural_net, prior 412 | else: 413 | if not path.isdir(full_path): 414 | print(f"Warning: {full_path} path does not exist but returned anyway.") 415 | return full_path 416 | 417 | 418 | def df_to_tensor(df_theta, df_x): 419 | """Convert dataframes to tensors.""" 420 | thetas = Tensor(df_theta.values) 421 | xs = Tensor(df_x.values) 422 | return thetas, xs 423 | 424 | 425 | def log_n_stdz(samples, do_log=True, standardizing_func=None): 426 | """Log and standardize samples.""" 427 | if do_log: 428 | samples = np.log10(samples) 429 | if standardizing_func: 430 | samples = standardizing_func(torch.Tensor(samples)).numpy() 431 | return samples 432 | 433 | 434 | def train_algorithm(theta, x, prior, cfg): 435 | """Train algorithm based on config.""" 436 | print(f"Training {cfg.name}: {cfg}...") 437 | if cfg.name == "NPE": 438 | inference, neural_net = trainers.train_NPE(theta, x, prior, cfg) 439 | elif cfg.name == "RestrictorNPE": 440 | inference, neural_net = trainers.train_RestrictorNPE(theta, x, prior, cfg) 441 | elif cfg.name == "NLE": 442 | inference, neural_net = trainers.train_NLE(theta, x, prior, cfg) 443 | elif cfg.name == "NRE": 444 | inference, neural_net = trainers.train_NRE(theta, x, prior, cfg) 445 | elif cfg.name == "ACE": 446 | inference, neural_net = trainers.train_ACE(theta, x, prior, cfg) 447 
| elif (cfg.name == "REGR-F") or (cfg.name == "REGR-R"): 448 | # Forward or reverse is taken care of in the trainer 449 | inference, neural_net = trainers.train_Regression(theta, x, prior, cfg) 450 | else: 451 | raise NotImplementedError("Algorithm not recognised.") 452 | 453 | print("Training finished.") 454 | return inference, neural_net 455 | 456 | 457 | def sort_closest_to_xo(xo, samples, distance_func="mse", top_n=None, weights=None): 458 | """Sort samples by distance to xo.""" 459 | if weights is None: 460 | weights = np.ones(xo.shape[1]) 461 | if distance_func in ["mse", "mae"]: 462 | dists = (samples - xo) ** 2 if distance_func == "mse" else np.abs(samples - xo) 463 | dist = (dists * weights).mean(1) 464 | else: 465 | raise NotImplementedError(f"Distance function {distance_func} not implemented.") 466 | 467 | # Sort by distance 468 | idx_dist_sorted = np.argsort(dist) 469 | if top_n is None: 470 | top_n = len(dist) 471 | return ( 472 | samples[idx_dist_sorted][:top_n], 473 | dist[idx_dist_sorted][:top_n], 474 | idx_dist_sorted, 475 | ) 476 | 477 | 478 | def standardize_theta(theta, prior, low_high=Tensor([0.0, 1.0]), destandardize=False): 479 | """Standardize theta to new min-max range based on prior bounds, or convert back.""" 480 | theta_minmax = get_minmax(prior) 481 | if destandardize: 482 | # Convert from standardized back to original range 483 | theta_standardized = standardize_minmax( 484 | theta, low_high[0], low_high[1], theta_minmax[:, 0], theta_minmax[:, 1] 485 | ) 486 | else: 487 | # Standardize to [low, high] from prior bounds 488 | theta_standardized = standardize_minmax( 489 | theta, theta_minmax[:, 0], theta_minmax[:, 1], low_high[0], low_high[1] 490 | ) 491 | 492 | return theta_standardized, theta_minmax 493 | 494 | 495 | def get_minmax(prior): 496 | """Get minmax range of prior.""" 497 | minmax = stack( 498 | [ 499 | Tensor([prior.marginals[i].low, prior.marginals[i].high]) 500 | for i in range(len(prior.names)) 501 | ], 502 | 0, 503 | ) 504 | return minmax 505 | 506 | 507 | def standardize_minmax(theta, min, max, new_min, new_max): 508 | """Standardize to between [low,high] based on prior bounds.""" 509 | theta_transformed = ((theta - min) / (max - min)) * (new_max - new_min) + new_min 510 | return theta_transformed 511 | 512 | 513 | def kde_estimate(data, bounds, points=1000): 514 | """Get kernel density estimate of samples.""" 515 | kde = gaussian_kde(data) 516 | grid = np.linspace(bounds[0], bounds[1], points) 517 | density = kde(grid) 518 | return grid, density 519 | -------------------------------------------------------------------------------- /automind/utils/organoid_utils.py: -------------------------------------------------------------------------------- 1 | ### Organoid data analysis utilities 2 | 3 | from datetime import datetime 4 | import numpy as np 5 | import pandas as pd 6 | import matplotlib.pyplot as plt 7 | 8 | from ..analysis import spikes_summary 9 | from .data_utils import ( 10 | convert_spike_array_to_dict, 11 | update_params_dict, 12 | make_subfolders, 13 | pickle_file, 14 | ) 15 | from .analysis_utils import compute_summary_features 16 | from .plot_utils import plot_rates_and_bursts 17 | from ..sim.runners import construct_experiment_settings_adex 18 | 19 | 20 | def convert_date_to_int(df_organoid): 21 | """ 22 | Convert dates in the dataframe to integer days since the first date. 23 | 24 | Args: 25 | df_organoid (pd.DataFrame): DataFrame containing a 'date' column with dates in '%y%m%d' format. 
26 | 27 | Returns: 28 | np.ndarray: Array of integers representing days since the first date. 29 | """ 30 | day1 = datetime.strptime(df_organoid["date"][0], "%y%m%d") 31 | return np.array( 32 | [ 33 | (datetime.strptime(d.split("_")[0], "%y%m%d") - day1).days 34 | for d in df_organoid["date"] 35 | ] 36 | ) 37 | 38 | 39 | def get_organoid_raw_data(org_spikes_path, query=None, bin_spike_params=None): 40 | """Quick access helper for loading organoid spiking data. 41 | 42 | Args: 43 | org_spikes_path (str): path string to organoid data file. 44 | query (dict, optional): {'date': YYMMDD, 'well': int}. Defaults to None. 45 | bin_spike_params (dict, optional): configurations for spike binning. Defaults to None. 46 | 47 | Returns: 48 | dict: dictionary with spikes and rate from queried date, or all spikes. 49 | """ 50 | fs = 12500 51 | # load data 52 | data_all = np.load(org_spikes_path, allow_pickle=True) 53 | org_spikes_all, org_t_all, org_name_all = ( 54 | data_all["spikes"], 55 | data_all["t"], 56 | data_all["recs"], 57 | ) 58 | org_name_all_matched = [n.split(".")[0][7:] for n in org_name_all] 59 | 60 | data_out = {} 61 | if query is not None: 62 | # find queried datapoint 63 | idx_date, idx_well = org_name_all_matched.index(query["date"]), query["well"] 64 | spikes_query = org_spikes_all[idx_date][idx_well] 65 | data_out["spikes"] = spikes_query 66 | 67 | # population firing rate 68 | t_s = org_t_all[idx_date] 69 | spikes_well = ( 70 | np.array([np.array(s, dtype=float) for s in spikes_query], dtype=object) 71 | / fs 72 | ) 73 | data_out["t"], data_out["pop_rate"] = spikes_summary.bin_population_spikes( 74 | spikes_well, t_s[0], t_s[-1], **bin_spike_params 75 | ) 76 | return data_out 77 | 78 | else: 79 | # no specific query, return all spikes 80 | data_out = { 81 | "spikes_all": org_spikes_all, 82 | "t_all": org_t_all, 83 | "name_all_matched": org_name_all_matched, 84 | } 85 | return data_out 86 | 87 | 88 | def compute_summary_organoid_well(organoid_raw_data, day, well, params_dict_default): 89 | """ 90 | Compute summary statistics for a specific well in the organoid data. 91 | 92 | Args: 93 | organoid_raw_data (dict): Raw organoid data containing spikes and time information. 94 | day (int): Day index to extract data from. 95 | well (int): Well index to extract data from. 96 | params_dict_default (dict): Default parameters for computing summary statistics. 97 | 98 | Returns: 99 | tuple: (dataframe, dict, dict) Summary statistics for the specified well. 100 | """ 101 | # reformat into spike dictionary like simulations 102 | spikes = organoid_raw_data["spikes_all"][day][well] 103 | t_s = organoid_raw_data["t_all"][day] 104 | spikes_dict = convert_spike_array_to_dict(spikes, fs=12500) 105 | spikes_dict["exc_spikes"] = spikes_dict["exc"] 106 | spikes_dict["t_end"] = t_s[-1] 107 | summary_stats = compute_summary_features(spikes_dict, params_dict_default) 108 | 109 | df_features = pd.concat( 110 | [summary_stats["summary_spikes"], summary_stats["summary_bursts"]], axis=1 111 | ) 112 | 113 | return df_features, summary_stats["pop_rates"], summary_stats["summary_burst_stats"] 114 | 115 | 116 | def run_organoid_analysis( 117 | organoid_data_folder, output_folder, exp_settings_update, analysis_settings_update 118 | ): 119 | """ 120 | Run analysis on organoid data. 
121 | """ 122 | params_dict_default = construct_experiment_settings_adex(exp_settings_update) 123 | params_dict_default = update_params_dict( 124 | params_dict_default, analysis_settings_update 125 | ) 126 | 127 | # turn off psd and pca analyses 128 | params_dict_default["params_analysis"]["do_pca"] = False 129 | params_dict_default["params_analysis"]["do_psd"] = False 130 | 131 | [print(f"{k}: {v}") for k, v in params_dict_default["params_analysis"].items()] 132 | organoid_raw_data = get_organoid_raw_data( 133 | organoid_data_folder + "organoid_spikes.npz" 134 | ) 135 | num_days = len(organoid_raw_data["name_all_matched"]) 136 | df_aggregate_summary = pd.DataFrame([]) 137 | do_plot = params_dict_default["params_analysis"]["do_plot"] 138 | 139 | for day in range(num_days): 140 | print(f"{day+1} / {num_days} days...") 141 | if do_plot: 142 | fig, axs = plt.subplots( 143 | 8, 2, gridspec_kw={"width_ratios": [5, 1]}, figsize=(20, 24) 144 | ) 145 | 146 | features = [] 147 | for well in range(8): 148 | df_features, rates, burst_stats = compute_summary_organoid_well( 149 | organoid_raw_data, day, well, params_dict_default 150 | ) 151 | features.append(df_features) 152 | 153 | if do_plot: 154 | plot_rates_and_bursts( 155 | rates, 156 | burst_stats, 157 | vars_to_plot={"exc_rate": "k"}, 158 | burst_alpha=0.2, 159 | fig_handles=(fig, axs[well, :]), 160 | burst_time_offset=params_dict_default["params_analysis"][ 161 | "burst_win" 162 | ][0], 163 | ) 164 | axs[well, 0].set_xlim([0, 120]) 165 | axs[well, 1].set_xlim( 166 | params_dict_default["params_analysis"]["burst_win"] 167 | ) 168 | df_info = pd.DataFrame( 169 | [ 170 | [organoid_raw_data["name_all_matched"][day], well, rates["t_ds"][-1]] 171 | for well in range(8) 172 | ], 173 | columns=["date", "well", "t_rec"], 174 | ) 175 | df_features = df_info.join(pd.concat(features).reset_index(drop=True)) 176 | df_aggregate_summary = pd.concat( 177 | (df_aggregate_summary, df_features), ignore_index=True 178 | ) 179 | path_dict = make_subfolders(output_folder, ["figures"]) 180 | if do_plot: 181 | plt.tight_layout() 182 | plt.savefig( 183 | path_dict["figures"] 184 | + organoid_raw_data["name_all_matched"][day] 185 | + ".pdf" 186 | ) 187 | plt.close() 188 | 189 | # save out params file 190 | pickle_file( 191 | output_folder + "/organoid_params_dict_default.pickle", params_dict_default 192 | ) 193 | df_aggregate_summary.to_csv(output_folder + "/organoid_summary.csv") 194 | return 195 | -------------------------------------------------------------------------------- /automind/utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | ### Utility functions for plotting, likely to be changed in the future. 
2 | 3 | import matplotlib.pyplot as plt 4 | 5 | # plt.style.use("../../../assets/matplotlibrc_notebook") 6 | import numpy as np 7 | from ..utils import data_utils, dist_utils 8 | from ..sim.default_configs import MKI_3col_plot_order 9 | 10 | 11 | def plot_rates_tiny( 12 | rates_to_plot, 13 | figsize, 14 | fig_axs=None, 15 | color="k", 16 | alpha=0.8, 17 | lw=1, 18 | XL=(0, 60), 19 | ylim_past=None, 20 | fontsize=14, 21 | ): 22 | if fig_axs is None: 23 | fig, axs = plt.subplots(ncols=1, nrows=len(rates_to_plot), figsize=figsize) 24 | else: 25 | fig, axs = fig_axs 26 | 27 | for i_r, rates in enumerate(rates_to_plot): 28 | axs[i_r].plot( 29 | rates["t_ds"], 30 | rates["exc_rate"], 31 | alpha=alpha, 32 | color=color[i_r] if type(color) == list else color, 33 | lw=lw, 34 | ) 35 | axs[i_r].set_xlim(XL) 36 | axs[i_r].set_xticks([]) 37 | 38 | if ylim_past: 39 | ymax = rates["exc_rate"][ylim_past:].max() 40 | axs[i_r].set_ylim([0, ymax * 1.1]) 41 | 42 | axs[i_r].set_yticks([int(np.ceil(axs[i_r].get_ylim()[1] / 1.1 / 5) * 5)]) 43 | # axs[i_r].set_yticks([round(rates["exc_rate"].max() / 10) * 10]) 44 | axs[i_r].spines["bottom"].set_visible(False) 45 | 46 | # axs[i_r].set_xticks(ticks=axs[i_r].get_xlim()) 47 | # axs[i_r].set_xticklabels(labels=XL) 48 | axs[i_r].set_xlabel(f"{XL[1]-XL[0]} s", fontsize=fontsize) 49 | return fig, axs 50 | 51 | 52 | def plot_rates_and_bursts( 53 | pop_rates, 54 | burst_stats, 55 | vars_to_plot=None, 56 | burst_alpha=0.5, 57 | figsize=(20, 3), 58 | w_ratio=(5, 1), 59 | fig_handles=None, 60 | burst_time_offset=-0.5, 61 | ): 62 | if fig_handles is None: 63 | fig, axs = plt.subplots( 64 | ncols=2, gridspec_kw={"width_ratios": w_ratio}, figsize=figsize 65 | ) 66 | else: 67 | # just unpack input 68 | fig, axs = fig_handles 69 | 70 | fs = 1 / (pop_rates["t_ds"][1] - pop_rates["t_ds"][0]) 71 | 72 | if vars_to_plot is None: 73 | # just plot exc rate 74 | ts_color = "C0" 75 | axs[0].plot(pop_rates["t_ds"], pop_rates["exc_rate"], alpha=0.9, color=ts_color) 76 | else: 77 | ts_color = list(vars_to_plot.values())[0] 78 | for k, c in vars_to_plot.items(): 79 | axs[0].plot(pop_rates["t_ds"], pop_rates[k], alpha=0.7, color=c) 80 | 81 | if len(burst_stats["burst_times"]) > 0: 82 | # if there are bursts, mark burst peak and widths 83 | axs[0].plot( 84 | burst_stats["burst_times"], 85 | burst_stats["burst_heights"], 86 | "ow", 87 | mec="k", 88 | ms=6, 89 | alpha=0.9, 90 | ) 91 | 92 | if len(burst_stats["burst_kernels"]) > 0: 93 | # plot kernels 94 | if "burst_kernels_refined" in burst_stats: 95 | ips = burst_stats["burst_width_ips"] 96 | for i_b, bk in enumerate(burst_stats["burst_kernels_refined"]): 97 | # t_burst = np.arange(int(ips[i][0]*fs), len(bk)) 98 | t_burst = ( 99 | burst_stats["burst_width_ips"][i_b][0] 100 | - burst_stats["burst_times"][i_b] 101 | ) + np.arange(len(bk)) / fs 102 | axs[1].plot(t_burst, bk, color=ts_color, alpha=burst_alpha, lw=1) 103 | 104 | # ips_idx = np.arange(int(ips[i][0]*fs), int(ips[i][1]*fs), dtype=int) 105 | 106 | else: 107 | axs[1].plot( 108 | pop_rates["t_ds"][: burst_stats["burst_kernels"][0].shape[0]] 109 | + burst_time_offset, 110 | np.array(burst_stats["burst_kernels"]).T, 111 | color=ts_color, 112 | alpha=burst_alpha, 113 | lw=1, 114 | ) 115 | for i, ips in enumerate(burst_stats["burst_width_ips"]): 116 | b_height = burst_stats["burst_width_heights"][i] 117 | b_time = burst_stats["burst_times"][i] 118 | axs[0].plot([ips[0], ips[1]], [b_height] * 2, "b", lw=4) 119 | axs[1].plot( 120 | [ips[0] - b_time, ips[1] - b_time], 121 | [b_height + 
abs(np.random.randn()) * b_height * 2] * 2, 122 | "b", 123 | lw=1, 124 | alpha=0.75, 125 | ) 126 | if "subpeak_times" in burst_stats: 127 | for subpeak_time in burst_stats["subpeak_times"]: 128 | subpeak_idx = (subpeak_time * fs).astype(int) 129 | axs[0].plot( 130 | subpeak_time, 131 | pop_rates["exc_rate"][subpeak_idx], 132 | "ow", 133 | mec="b", 134 | ms=3, 135 | alpha=0.8, 136 | ) 137 | 138 | plt.tight_layout() 139 | return fig, axs 140 | 141 | 142 | def plot_wrapper(pop_rates, burst_stats, random_seed, params_dict): 143 | fig, axs = plot_rates_and_bursts( 144 | pop_rates, 145 | burst_stats, 146 | vars_to_plot={"exc_rate": "k"}, 147 | burst_alpha=0.2, 148 | figsize=(20, 2.5), 149 | w_ratio=(8, 1), 150 | burst_time_offset=params_dict["params_analysis"]["burst_win"][0], 151 | ) 152 | axs[0].set_xlim([0, 120]) 153 | # axs[0].set_ylim([0, None]) 154 | axs[1].set_xlim(params_dict["params_analysis"]["burst_win"]) 155 | axs[1].set_ylim([0, None]) 156 | run_id = f"{params_dict['params_settings']['batch_seed']}_{random_seed}" 157 | axs[0].set_title(f"{run_id}") 158 | plt.savefig(params_dict["path_dict"]["figures"] + f"{run_id}_analyzed.pdf") 159 | plt.close() 160 | 161 | 162 | def plot_wrapper_MK1(summary_stats, random_seed, params_dict): 163 | fig, axs = plt.subplots( 164 | ncols=4, gridspec_kw={"width_ratios": (7, 1, 1, 1)}, figsize=(20, 2.5) 165 | ) 166 | plot_rates_and_bursts( 167 | summary_stats["pop_rates"], 168 | summary_stats["summary_burst_stats"], 169 | vars_to_plot={"exc_rate": "k"}, 170 | burst_alpha=0.2, 171 | fig_handles=(fig, axs), 172 | burst_time_offset=params_dict["params_analysis"]["burst_win"][0], 173 | ) 174 | 175 | axs[2].loglog( 176 | data_utils.decode_df_float_axis(summary_stats["summary_psd"].columns, float), 177 | summary_stats["summary_psd"].loc["exc_rate"], 178 | "C1", 179 | label="exc", 180 | ) 181 | axs[2].loglog( 182 | data_utils.decode_df_float_axis(summary_stats["summary_psd"].columns, float), 183 | summary_stats["summary_psd"].loc["inh_rate"], 184 | "C4", 185 | label="inh", 186 | ) 187 | axs[2].legend() 188 | 189 | # axs[3].loglog(np.arange(len(summary_stats['summary_pca'])-1)+1, summary_stats['summary_pca']['var_exp_ratio'].iloc[1:], 'o') 190 | axs[3].loglog( 191 | data_utils.decode_df_float_axis(summary_stats["summary_pca"].columns[1:], int), 192 | summary_stats["summary_pca"].iloc[0].values[1:], 193 | "o", 194 | ) 195 | 196 | run_id = f"{params_dict['params_settings']['batch_seed']}_{random_seed}" 197 | if len(summary_stats["summary_burst_stats"]["burst_heights"]) > 0: 198 | YMAX = np.median(summary_stats["summary_burst_stats"]["burst_heights"]) * 1.1 199 | else: 200 | fs = 1 / ( 201 | summary_stats["pop_rates"]["t_ds"][1] 202 | - summary_stats["pop_rates"]["t_ds"][0] 203 | ) 204 | YMAX = summary_stats["pop_rates"]["exc_rate"][int(20 * fs) :].max() * 1.1 205 | axs[0].set_xlim([0, 120]) 206 | axs[0].set_ylim([0, YMAX]) 207 | axs[0].set_title(run_id) 208 | axs[1].set_xlim(params_dict["params_analysis"]["burst_win"]) 209 | axs[1].set_ylim([0, YMAX]) 210 | axs[1].set_title("bursts") 211 | 212 | axs[2].set_xlim([None, params_dict["params_analysis"]["f_lim"]]) 213 | axs[2].set_xticks([1, 100]) 214 | axs[2].set_xticklabels(["1", "100"]) 215 | axs[2].set_title("PSD") 216 | 217 | axs[3].set_xticks([1, 10, 100]) 218 | axs[3].set_xticklabels(["1", "10", "100"]) 219 | axs[3].set_title("PCA eigval") 220 | 221 | plt.tight_layout() 222 | plt.savefig(params_dict["path_dict"]["figures"] + f"{run_id}_analyzed.pdf") 223 | plt.close() 224 | 225 | 226 | def 
add_point_to_pairplot(df_point_plot, x_features, axs, **kwargs): 227 | for i_y, feat_x in enumerate(x_features): 228 | for i_x, feat_y in enumerate(x_features): 229 | if i_x > i_y: 230 | axs[i_y, i_x].plot( 231 | df_point_plot[feat_y], df_point_plot[feat_x], **kwargs["upper"] 232 | ) 233 | if i_x == i_y: 234 | axs[i_y, i_x].axvline(df_point_plot[feat_y], **kwargs["diag"]) 235 | 236 | 237 | def _plot_raster_pretty( 238 | spikes, 239 | XL, 240 | every_other=1, 241 | ax=None, 242 | fontsize=14, 243 | plot_combined=False, 244 | plot_inh=True, 245 | E_color="k", 246 | I_color="gray", 247 | mew=0.5, 248 | ms=1, 249 | **plot_kwargs, 250 | ): 251 | if ax == None: 252 | ax = plt.axes() 253 | 254 | if plot_combined: 255 | combined_spikes = list(spikes["exc_spikes"].values()) + list( 256 | spikes["inh_spikes"].values() 257 | ) 258 | if plot_combined == "sorted": 259 | c_order = np.argsort([len(v) for v in combined_spikes]) 260 | elif plot_combined == "random": 261 | c_order = np.random.permutation(len(combined_spikes)) 262 | else: 263 | c_order = np.arange(len(combined_spikes)) 264 | 265 | combined_spikes = [combined_spikes[i] for i in c_order] 266 | 267 | # import pdb; pdb.set_trace() 268 | [ 269 | ( 270 | ax.plot( 271 | v[::every_other], 272 | i_v * np.ones_like(v[::every_other]), 273 | # c_order[i_v] * np.ones_like(v[::every_other]), 274 | "|", 275 | color=E_color, 276 | alpha=1, 277 | ms=ms, 278 | mew=mew, 279 | ) 280 | if len(v) > 0 281 | else None 282 | ) 283 | for i_v, v in enumerate(combined_spikes) 284 | ] 285 | else: 286 | # plot exc (and inh) separately 287 | [ 288 | ( 289 | ax.plot( 290 | v[::every_other], 291 | i_v * np.ones_like(v[::every_other]), 292 | "|", 293 | color=E_color, 294 | alpha=1, 295 | ms=ms, 296 | mew=mew, 297 | ) 298 | if len(v) > 0 299 | else None 300 | ) 301 | for i_v, v in enumerate(spikes["exc_spikes"].values()) 302 | ] 303 | # plot inh 304 | if plot_inh: 305 | [ 306 | ( 307 | ax.plot( 308 | v[::every_other], 309 | (i_v + len(spikes["exc_spikes"])) 310 | * np.ones_like(v[::every_other]), 311 | "|", 312 | color=I_color, 313 | alpha=1, 314 | ms=ms, 315 | mew=mew, 316 | ) 317 | if len(v) > 0 318 | else None 319 | ) 320 | for i_v, v in enumerate(spikes["inh_spikes"].values()) 321 | ] 322 | ax.set_xticks([]) 323 | ax.set_yticks([]) 324 | ax.spines.left.set_visible(False) 325 | ax.spines.bottom.set_visible(False) 326 | ax.set_xlim(XL) 327 | ax.set_ylabel("Raster", fontsize=fontsize) 328 | return ax 329 | 330 | 331 | def _plot_rates_pretty( 332 | rates, 333 | XL, 334 | pops_to_plot=["exc_rate"], 335 | ylim_past=None, 336 | ax=None, 337 | fontsize=14, 338 | lw=0.5, 339 | color=None, 340 | ): 341 | if ax == None: 342 | ax = plt.axes() 343 | if color is None: 344 | C_dict = {"exc_rate": "C1", "inh_rate": "C4", "avgpop_rate": "C0"} 345 | else: 346 | C_dict = {"avgpop_rate": color} 347 | 348 | for pop in pops_to_plot: 349 | ax.plot(rates["t_ds"], rates[pop], lw=lw, alpha=1, color=C_dict[pop]) 350 | # ax.set_yticks([]) 351 | 352 | ax.set_ylabel("Rate", fontsize=fontsize, labelpad=None) 353 | ax.set_xticks([]) 354 | ax.set_xlim(XL) 355 | if ylim_past: 356 | ymax = rates["exc_rate"][ylim_past:].max() 357 | ax.set_ylim([0, ymax * 1.1]) 358 | # ax.set_yticks([round(ymax / 10) * 10]) 359 | 360 | ax.set_yticks([0, int(np.ceil(ax.get_ylim()[1] / 5) * 5)]) 361 | # ax.set_yticks([round(1+ax.get_ylim()[1] / 10) * 10]) 362 | ax.set_xlabel(f"{XL[1]-XL[0]} s", fontsize=fontsize) 363 | ax.spines.bottom.set_visible(False) 364 | return ax 365 | 366 | 367 | def _plot_eigspec_pretty( 368 | df_pca, 
n_pcs, ax=None, fontsize=14, color="C0", alpha=1, ms=1, lw=0 369 | ): 370 | if ax == None: 371 | ax = plt.axes() 372 | pcs = data_utils.decode_df_float_axis(df_pca.columns.values[1:]) 373 | pca = df_pca.iloc[0, 1:].values 374 | ax.loglog( 375 | pcs[:n_pcs], pca[:n_pcs] * 100, "-o", color=color, lw=lw, alpha=alpha, ms=ms 376 | ) 377 | ax.set_xticks([1, 10], ["1", "10"]) 378 | # ax.set_yticks([1e-5, 1], [r"$10^{-5}$", "1"]) 379 | ax.minorticks_off() 380 | ax.set_ylabel("PCA", fontsize=fontsize) 381 | return ax 382 | 383 | 384 | def _plot_psd_pretty( 385 | df_psd, 386 | pops_to_plot=["exc_rate", "inh_rate"], 387 | ax=None, 388 | fontsize=14, 389 | color=None, 390 | alpha=1, 391 | ): 392 | if ax == None: 393 | ax = plt.axes() 394 | f_axis = data_utils.decode_df_float_axis(df_psd.columns.values) 395 | if color is None: 396 | C_dict = {"exc_rate": "C1", "inh_rate": "C4", "avgpop_rate": "C0"} 397 | else: 398 | C_dict = {"avgpop_rate": color} 399 | for pop in pops_to_plot: 400 | ax.loglog( 401 | f_axis, df_psd.loc[pop].values, color=C_dict[pop], lw=0.8, alpha=alpha 402 | ) 403 | ax.set_xticks([1, 10, 100], ["1", "10", "100"]) 404 | ax.set_yticks([]) 405 | ax.set_xlim([0.5, 450]) 406 | ax.set_ylabel("PSD", fontsize=fontsize) 407 | ax.minorticks_off() 408 | return ax 409 | 410 | 411 | def strip_decimal(s): 412 | return int(s) if int(s) == s else s 413 | 414 | 415 | def plot_params_pretty( 416 | thetas, 417 | prior, 418 | param_names, 419 | figsize=(2.5, 10), 420 | fig_axs=None, 421 | labelpad=40, 422 | fontsize=14, 423 | **plot_kwarg, 424 | ): 425 | 426 | n_params = thetas.shape[1] 427 | params_bound = [ 428 | [prior.marginals[i].low, prior.marginals[i].high] for i in range(n_params) 429 | ] 430 | 431 | if fig_axs is None: 432 | fig, axs = plt.subplots(n_params, 1, figsize=figsize, constrained_layout=False) 433 | else: 434 | fig, axs = fig_axs 435 | 436 | for i in range(n_params): 437 | if plot_kwarg == {}: 438 | [axs[i].plot(th[i], 0, "o", ms=5, alpha=0.8) for th in thetas] 439 | else: 440 | [axs[i].plot(th[i], 0, **plot_kwarg) for th in thetas] 441 | 442 | # axs[i].errorbar(posterior_samples[:n_best,i].mean(), 0, xerr=stats.sem(posterior_samples[:n_best,i])*3, color=plt_colors[i_q], alpha=0.8) 443 | 444 | if fig_axs is None: 445 | axs[i].spines.left.set_visible(False) 446 | axs[i].spines.bottom.set_visible(False) 447 | axs[i].axhline(0, alpha=0.1, lw=1) 448 | 449 | axs[i].text( 450 | np.array(params_bound[i])[0], 451 | 0, 452 | strip_decimal(np.array(params_bound[i])[0]), 453 | ha="right", 454 | va="center", 455 | fontsize=fontsize, 456 | ) 457 | axs[i].text( 458 | np.array(params_bound[i])[1], 459 | 0, 460 | strip_decimal(np.array(params_bound[i])[1]), 461 | ha="left", 462 | va="center", 463 | fontsize=fontsize, 464 | ) 465 | 466 | axs[i].set_xticks([]) 467 | axs[i].set_yticks([]) 468 | axs[i].set_ylabel( 469 | param_names[i], 470 | rotation=0, 471 | labelpad=labelpad, 472 | ha="right", 473 | va="center", 474 | fontsize=fontsize, 475 | ) 476 | 477 | axs[i].set_xlim(np.array(params_bound[i])) 478 | axs[i].set_ylim([-0.1, 0.1]) 479 | # plt.subplots_adjust( 480 | # left=None, bottom=None, right=None, top=None, wspace=0.0, hspace=0 481 | # ) 482 | return fig, axs 483 | 484 | 485 | def plot_params_1D( 486 | thetas, 487 | param_bounds, 488 | param_names, 489 | fig_axs, 490 | color, 491 | draw_canvas=True, 492 | draw_kde=True, 493 | flip_density=False, 494 | draw_median=True, 495 | draw_samples=False, 496 | labelpad=40, 497 | fontsize=5, 498 | kde_points=100, 499 | **plot_kwarg, 500 | ): 501 | fig, axs 
= fig_axs 502 | 503 | if len(axs.shape) == 2: 504 | # split mode 505 | axes = axs.T.flatten() 506 | # hard coded order 507 | plt_order = MKI_3col_plot_order 508 | else: 509 | # single column mode 510 | axes = axs 511 | plt_order = np.arange(thetas.shape[1]) 512 | assert len(axs) == thetas.shape[1] 513 | 514 | for i_, ax in enumerate(axes): 515 | i_a = plt_order[i_] 516 | if i_a == -1: 517 | ax.axis("off") 518 | continue 519 | 520 | # assert len(axs)==thetas.shape[1] 521 | # for i_a, ax in enumerate(axs): 522 | if draw_kde: 523 | kde = dist_utils.kde_estimate(thetas[:, i_a], param_bounds[i_a], kde_points) 524 | ax.fill_between( 525 | kde[0], 526 | -kde[1] if flip_density else kde[1], 527 | alpha=0.6, 528 | color=color, 529 | lw=0.5, 530 | ) 531 | 532 | if draw_median: 533 | ax.plot(np.median(thetas[:, i_a]), 0, "|", color=color, mew=0.75, ms=2) 534 | 535 | if draw_samples: 536 | samples = thetas[:10, i_a] 537 | ax.plot( 538 | samples, 539 | samples * 0.0, 540 | ".", 541 | color=color, 542 | mew=0.75, 543 | ms=plot_kwarg["sample_ms"], 544 | alpha=plot_kwarg["sample_alpha"], 545 | ) 546 | 547 | if draw_canvas: 548 | ax.axhline(0, lw=0.25, alpha=1, zorder=-1) 549 | ax.set_xticks([]) 550 | ax.set_yticks([]) 551 | ax.set_xlim(param_bounds[i_a]) 552 | ax.spines.left.set_visible(False) 553 | ax.spines.bottom.set_visible(False) 554 | 555 | for i_s, side in enumerate(["right", "left"]): 556 | # Aligned to the opposite end of word, so left to right 557 | ax.text( 558 | np.array(param_bounds[i_a])[i_s], 559 | 0, 560 | strip_decimal(np.array(param_bounds[i_a])[i_s]), 561 | ha=side, 562 | va="center", 563 | fontsize=fontsize, 564 | ) 565 | ax.set_xticks([]) 566 | ax.set_yticks([]) 567 | ax.set_ylabel( 568 | param_names[i_a], 569 | rotation=0, 570 | labelpad=labelpad, 571 | ha="right", 572 | va="center", 573 | fontsize=fontsize, 574 | ) 575 | ax.set_xlim(np.array(param_bounds[i_a])) 576 | 577 | # YL=np.abs(ax.get_ylim()).max() 578 | # ax.set_ylim(YL*np.array([-1,1])) 579 | return fig, axs 580 | 581 | 582 | def plot_corr_pv(pvals, ax, alpha_level=0.05, fmt="w*", ms=0.5): 583 | for i in range(pvals.shape[0]): 584 | for j in range(pvals.shape[0]): 585 | if pvals[i, j] < alpha_level: 586 | ax.plot(j, i, fmt, ms=ms, alpha=1) 587 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Data Download 2 | ### Summary data and trained DGMs 3 | 4 | All files can be downloaded as a single file [from figshare](https://figshare.com/s/3f1467f8fb0f328aed16). Place the `.zip` file in this directory (`./datasets/`) and unzip it in place to preserve the correct relative paths for code demos, model training and inference, etc. 5 | 6 | Alternatively, files can be individually downloaded and manually organized in the directory structure as follows: 7 | 8 | --- 9 | `./training_prior_samples/` 10 | - `training.zip`: 1 million parameter configurations and summary features from model simulations used to train deep generative models. 11 | - `heldout.zip`: additional network simulations not used for DGM training, a subset of which was used as synthetic observations. 12 | 13 | --- 14 | `./discovered_posterior_samples/` 15 | - `organoids.zip`: discovered model configurations consistent with human brain organoid network bursts across development. See ./organoid_predictives.png.
16 | - `mouse-vis.zip`: discovered model configurations consistent with population firing rate PSD of Neuropixels recordings from mouse visual areas. See ./allen_predicitves_all.png. 17 | - `mouse-hc.zip`: discovered model configurations consistent with population firing rate PSD of Neuropixels recordings from mouse hippocampal areas. See ./allen_predicitves_all.png. 18 | - `synthetic.zip`: discovered model configurations consistent with population firing rate PSD of synthetic observations, i.e., held-out network simulations. See ./synthetic.png. 19 | - NOTE: all `.zip` files also contain the prior distribution, posterior density estimator, and config files necessary for running the simulations and analyses. 20 | 21 | --- 22 | `./dgms/` 23 | - `burst_posterior.pickle`: trained conditional density estimator that approximates the posterior distribution conditioned on network burst summary features. 24 | - `psd_posterior.pickle`: trained conditional density estimator that approximates the posterior distribution conditioned on network firing rate power spectral densities. 25 | 26 | --- 27 | `./observations/` 28 | - `allenvc_summary.csv`: population firing rate PSD of Neuropixels recordings. 29 | - `synthetic_summary.csv`: various summary features of synthetic observations (i.e., held-out network simulations). 30 | - `organoid_summary.csv`: population firing and burst statistics of organoid multi-electrode array recordings. 31 | - `example_raw_data.npz`: raw spike train data from an example organoid recording. 32 | 33 | --- 34 | 35 | Example discovered model configurations and simulations, overlaid with target observation data (first row of each subplot in organoid panels, black lines in mouse and synthetic PSD panels): 36 | ![](../assets/img/predictives.png) 37 | 38 | --- 39 | ### Raw simulation data 40 | Coming soon. -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: automind 2 | channels: 3 | - defaults 4 | - conda-forge 5 | - anaconda 6 | dependencies: 7 | - python=3.11 8 | - jupyterlab 9 | - ipywidgets 10 | - brian2 11 | - scikit-learn 12 | - numba 13 | - seaborn 14 | - hdf5 15 | # - pytables 16 | - pip 17 | - pip: 18 | - sbi==0.22.0 19 | - hydra-core 20 | # - allensdk 21 | - -e . -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | REQUIRED = [ 4 | "jupyterlab", 5 | ] 6 | 7 | setup( 8 | name="automind", 9 | python_requires=">=3.11.0", 10 | packages=find_packages(), 11 | install_requires=REQUIRED, 12 | ) 13 | --------------------------------------------------------------------------------
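For orientation, here is a minimal sketch of how the downloaded files described in `datasets/README.md` might be used together, mirroring the sampling pattern in `automind/utils/dist_utils.py`. It assumes the figshare archive has been unzipped into `./datasets/` with the layout listed above, and that the pickled density estimator exposes the same `x_feat_set`, `names`, and `.sample()` interface used in that module; the file paths and the choice of observation row are illustrative, not a tested entry point.

```python
# Minimal sketch (assumptions noted above): load one observation and draw
# posterior samples from a trained density estimator, following the pattern
# in automind/utils/dist_utils.py.
import pickle

import pandas as pd
from torch import Tensor

# Trained conditional density estimator (assumed to unpickle directly)
with open("datasets/dgms/burst_posterior.pickle", "rb") as f:
    density_estimator = pickle.load(f)

# Burst summary features of organoid recordings, one row per recording/well
df_obs = pd.read_csv("datasets/observations/organoid_summary.csv")

# Condition on the first observation's feature vector and sample parameter sets
xo = Tensor(df_obs[density_estimator.x_feat_set].iloc[[0]].astype(float).values)
samples = density_estimator.sample((1000,), x=xo)

# Wrap samples with parameter names for downstream analysis or re-simulation
df_samples = pd.DataFrame(
    samples.numpy().astype(float), columns=density_estimator.names
)
print(df_samples.describe())
```

In practice, the code demos mentioned in `datasets/README.md` are the supported way to run this workflow; the sketch above only illustrates how the files under `./datasets/` fit together.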