├── .gitattributes ├── .gitignore ├── README.md ├── UNLICENSE ├── assets └── teaser.png ├── configs ├── 20ms_arxiv.json ├── 20ms_arxiv_no_reg.json ├── 20ms_full_context.json ├── 20ms_mask.json ├── 20ms_no_span.json ├── area2_bump.yaml ├── arxiv │ ├── chaotic.yaml │ ├── lorenz.yaml │ ├── m700_115.yaml │ ├── m700_2296.yaml │ ├── m700_2296_postnorm.yaml │ ├── m700_230.yaml │ ├── m700_460.yaml │ ├── m700_no_log.yaml │ ├── m700_no_reg.yaml │ ├── m700_no_span.yaml │ └── m700_nonzero.yaml ├── dmfc_rsg.json ├── dmfc_rsg.yaml ├── mc_maze.json ├── mc_maze.yaml ├── mc_maze_large.yaml ├── mc_maze_medium.yaml ├── mc_maze_small.yaml ├── mc_maze_small_from_scratch.yaml ├── mc_rtt.yaml ├── sweep_bump.json ├── sweep_generic.json └── sweep_simple.json ├── data ├── chaotic_rnn │ ├── gen_synth_data_no_inputs.sh │ ├── generate_chaotic_rnn_data.py │ ├── generate_chaotic_rnn_data_allowRandomSeed.py │ ├── synthetic_data_utils.py │ └── utils.py ├── lfads_lorenz.h5 └── lorenz.py ├── defaults.py ├── environment.yml ├── nlb.yml ├── ray_get_lfve.py ├── ray_random.py ├── scripts ├── analyze_ray.py ├── analyze_utils.py ├── clear.sh ├── eval.sh ├── fig_3a_synthetic_neurons.py ├── fig_3b_4b_plot_hparams.py ├── fig_5_lfads_times.py ├── fig_5_ndt_times.py ├── fig_6_plot_losses.py ├── nlb.py ├── nlb_from_scratch.ipynb ├── nlb_from_scratch.py ├── record_all_rates.py ├── scratch.py ├── simple_ci.py └── train.sh ├── src ├── __init__.py ├── config │ ├── __init__.py │ └── default.py ├── dataset.py ├── logger_wrapper.py ├── mask.py ├── model.py ├── model_baselines.py ├── model_registry.py ├── run.py ├── runner.py ├── tb_wrapper.py └── utils.py └── tune_models.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.h5 filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.h5 3 | *.pdf 4 | *.png 5 | logs/ 6 | tb/ 7 | tmp 8 | 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | *.pyc 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # pytype static type analyzer 144 | .pytype/ 145 | 146 | # Cython debug symbols 147 | cython_debug/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Data Transformers 2 |

3 | ![NDT teaser](assets/teaser.png)
5 | 6 | This is the code for the paper "Representation learning for neural population activity with Neural Data Transformers". We provide the code as a reference, but we are unable to help debug specific issues (e.g., in using the model) at this time. NLB-related configuration files are listed directly under `configs/`, whereas original NDT paper configs are under `configs/arxiv/`. 7 | 8 | - Want to quickly get started with NLB? Check out `scripts/nlb_from_scratch.py` or `scripts/nlb_from_scratch.ipynb`. 9 | 10 | ## Setup 11 | We recommend you set up your code environment with `conda/miniconda`. 12 | The dependencies necessary for this project can then be installed with: 13 | `conda env create -f nlb.yml` 14 | This project was originally developed with Python 3.6 and `environment.yml`. `nlb.yml` is a fresh Python 3.7 environment with minimal dependencies, but has not been tested for result reproducibility (no impact expected). 15 | 16 | ## Data 17 | The Lorenz dataset is provided in `data/lfads_lorenz.h5`. This file is stored in this repo with [`git-lfs`](https://git-lfs.github.com/). Therefore, if you've not used `git-lfs` before, please run `git lfs install` and `git lfs pull` to pull down the full h5 file. 18 | 19 | The autonomous chaotic RNN dataset can be generated by running `./data/gen_synth_data_no_inputs.sh`. The generating script is taken from [the TensorFlow release](https://github.com/tensorflow/models/tree/master/research/lfads/synth_data) of LFADS (Sussillo et al.). The maze dataset is unavailable at this time. 20 | 21 | ## Training + Evaluation 22 | Experimental configurations are set in `./configs/`. To train a single model with a configuration `./configs/<variant>.yaml`, run `./scripts/train.sh <variant>`. 23 | 24 | The provided sample configurations in `./configs/arxiv/` were used in the HP sweeps for the main results in the paper, with the sweep parameters in `./configs/*.json`. Note that sweeping is done with the `ray[tune]` package. To run a sweep, run `python ray_random.py -e <variant>` (the same config system is used). Both the main paper and the NLB results use this random search. 25 | 26 | R2 is reported automatically for synthetic datasets. Maze analyses and configurations are unfortunately unavailable at this time. 27 | 28 | ## Analysis 29 | Reference scripts that were used to produce most figures are available in `scripts/`. They were created and iterated on as VSCode notebooks. They may require external information, run directories, or even other codebases. The scripts are provided to give a sense of the analysis procedure, not to serve as out-of-the-box reproducibility notebooks. 30 | 31 | ## Citation 32 | ``` 33 | @article{ye2021ndt, 34 | title = {Representation learning for neural population activity with {Neural} {Data} {Transformers}}, 35 | issn = {2690-2664}, 36 | url = {https://nbdt.scholasticahq.com/article/27358-representation-learning-for-neural-population-activity-with-neural-data-transformers}, 37 | doi = {10.51628/001c.27358}, 38 | language = {en}, 39 | urldate = {2021-08-29}, 40 | journal = {Neurons, Behavior, Data analysis, and Theory}, 41 | author = {Ye, Joel and Pandarinath, Chethan}, 42 | month = aug, 43 | year = {2021}, 44 | } 45 | ``` 46 | -------------------------------------------------------------------------------- /UNLICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain.
2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/neural-data-transformers/98dd85a24885ffb76adfeed0c2a89d3ea3ecf9d1/assets/teaser.png -------------------------------------------------------------------------------- /configs/20ms_arxiv.json: -------------------------------------------------------------------------------- 1 | { 2 | "MODEL.DROPOUT": { 3 | "low": 0.2, 4 | "high": 0.6, 5 | "sample_fn": "uniform", 6 | "explore_wt": 0.3 7 | }, 8 | "MODEL.DROPOUT_RATES": { 9 | "low": 0.2, 10 | "high": 0.6, 11 | "sample_fn": "uniform", 12 | "explore_wt": 0.3 13 | }, 14 | "MODEL.DROPOUT_EMBEDDING": { 15 | "low": 0.2, 16 | "high": 0.6, 17 | "sample_fn": "uniform", 18 | "explore_wt": 0.3 19 | }, 20 | "MODEL.CONTEXT_FORWARD": { 21 | "low": 10, 22 | "high": 40, 23 | "sample_fn": "randint", 24 | "enforce_limits": false, 25 | "is_integer": true 26 | }, 27 | "MODEL.CONTEXT_BACKWARD": { 28 | "low": 10, 29 | "high": 40, 30 | "sample_fn": "randint", 31 | "enforce_limits": false, 32 | "is_integer": true 33 | }, 34 | "TRAIN.LR.INIT": { 35 | "low": 1e-5, 36 | "high": 1e-3, 37 | "sample_fn": "loguniform" 38 | }, 39 | "TRAIN.MASK_RANDOM_RATIO": { 40 | "low": 0.6, 41 | "high": 1.0, 42 | "sample_fn": "uniform" 43 | }, 44 | "TRAIN.MASK_TOKEN_RATIO": { 45 | "low": 0.6, 46 | "high": 1.0, 47 | "sample_fn": "uniform" 48 | }, 49 | "TRAIN.MASK_MAX_SPAN": { 50 | "low": 1, 51 | "high": 7, 52 | "sample_fn": "randint", 53 | "is_integer": true 54 | } 55 | } -------------------------------------------------------------------------------- /configs/20ms_arxiv_no_reg.json: -------------------------------------------------------------------------------- 1 | { 2 | "MODEL.DROPOUT": { 3 | "low": 0.1, 4 | "high": 0.2, 5 | "sample_fn": "uniform", 6 | "explore_wt": 0.3 7 | }, 8 | "MODEL.DROPOUT_RATES": { 9 | "low": 0.1, 10 | "high": 0.2, 11 | "sample_fn": "uniform", 12 | "explore_wt": 0.3 13 | }, 14 | "MODEL.DROPOUT_EMBEDDING": { 15 | "low": 0.1, 16 | "high": 0.2, 17 | "sample_fn": "uniform", 18 | "explore_wt": 0.3 19 | }, 20 | "MODEL.CONTEXT_FORWARD": { 21 | "low": 10, 22 | "high": 40, 23 | "sample_fn": "randint", 24 | "enforce_limits": false, 25 | "is_integer": 
true 26 | }, 27 | "MODEL.CONTEXT_BACKWARD": { 28 | "low": 10, 29 | "high": 40, 30 | "sample_fn": "randint", 31 | "enforce_limits": false, 32 | "is_integer": true 33 | }, 34 | "TRAIN.LR.INIT": { 35 | "low": 1e-5, 36 | "high": 1e-3, 37 | "sample_fn": "loguniform" 38 | }, 39 | "TRAIN.MASK_RANDOM_RATIO": { 40 | "low": 0.6, 41 | "high": 1.0, 42 | "sample_fn": "uniform" 43 | }, 44 | "TRAIN.MASK_TOKEN_RATIO": { 45 | "low": 0.6, 46 | "high": 1.0, 47 | "sample_fn": "uniform" 48 | }, 49 | "TRAIN.MASK_MAX_SPAN": { 50 | "low": 1, 51 | "high": 7, 52 | "sample_fn": "randint", 53 | "is_integer": true 54 | } 55 | } -------------------------------------------------------------------------------- /configs/20ms_full_context.json: -------------------------------------------------------------------------------- 1 | { 2 | "MODEL.DROPOUT": { 3 | "low": 0.2, 4 | "high": 0.6, 5 | "sample_fn": "uniform", 6 | "explore_wt": 0.3 7 | }, 8 | "MODEL.DROPOUT_RATES": { 9 | "low": 0.2, 10 | "high": 0.6, 11 | "sample_fn": "uniform", 12 | "explore_wt": 0.3 13 | }, 14 | "MODEL.DROPOUT_EMBEDDING": { 15 | "low": 0.2, 16 | "high": 0.6, 17 | "sample_fn": "uniform", 18 | "explore_wt": 0.3 19 | }, 20 | "TRAIN.LR.INIT": { 21 | "low": 1e-5, 22 | "high": 1e-3, 23 | "sample_fn": "loguniform" 24 | }, 25 | "TRAIN.MASK_RANDOM_RATIO": { 26 | "low": 0.6, 27 | "high": 1.0, 28 | "sample_fn": "uniform" 29 | }, 30 | "TRAIN.MASK_TOKEN_RATIO": { 31 | "low": 0.6, 32 | "high": 1.0, 33 | "sample_fn": "uniform" 34 | }, 35 | "TRAIN.MASK_MAX_SPAN": { 36 | "low": 1, 37 | "high": 5, 38 | "sample_fn": "randint", 39 | "is_integer": true 40 | } 41 | } -------------------------------------------------------------------------------- /configs/20ms_mask.json: -------------------------------------------------------------------------------- 1 | { 2 | "MODEL.DROPOUT": { 3 | "low": 0.2, 4 | "high": 0.6, 5 | "sample_fn": "uniform", 6 | "explore_wt": 0.3 7 | }, 8 | "MODEL.DROPOUT_RATES": { 9 | "low": 0.2, 10 | "high": 0.6, 11 | "sample_fn": "uniform", 12 | "explore_wt": 0.3 13 | }, 14 | "MODEL.DROPOUT_EMBEDDING": { 15 | "low": 0.2, 16 | "high": 0.6, 17 | "sample_fn": "uniform", 18 | "explore_wt": 0.3 19 | }, 20 | "MODEL.CONTEXT_FORWARD": { 21 | "low": 10, 22 | "high": 40, 23 | "sample_fn": "randint", 24 | "enforce_limits": false, 25 | "is_integer": true 26 | }, 27 | "MODEL.CONTEXT_BACKWARD": { 28 | "low": 10, 29 | "high": 40, 30 | "sample_fn": "randint", 31 | "enforce_limits": false, 32 | "is_integer": true 33 | }, 34 | "TRAIN.MASK_RANDOM_RATIO": { 35 | "low": 0.6, 36 | "high": 1.0, 37 | "sample_fn": "uniform" 38 | }, 39 | "TRAIN.MASK_TOKEN_RATIO": { 40 | "low": 0.6, 41 | "high": 1.0, 42 | "sample_fn": "uniform" 43 | }, 44 | "TRAIN.MASK_MAX_SPAN": { 45 | "low": 1, 46 | "high": 5, 47 | "sample_fn": "randint", 48 | "is_integer": true 49 | } 50 | } -------------------------------------------------------------------------------- /configs/20ms_no_span.json: -------------------------------------------------------------------------------- 1 | { 2 | "MODEL.DROPOUT": { 3 | "low": 0.2, 4 | "high": 0.6, 5 | "sample_fn": "uniform", 6 | "explore_wt": 0.3 7 | }, 8 | "MODEL.DROPOUT_RATES": { 9 | "low": 0.2, 10 | "high": 0.6, 11 | "sample_fn": "uniform", 12 | "explore_wt": 0.3 13 | }, 14 | "MODEL.DROPOUT_EMBEDDING": { 15 | "low": 0.2, 16 | "high": 0.6, 17 | "sample_fn": "uniform", 18 | "explore_wt": 0.3 19 | }, 20 | "MODEL.CONTEXT_FORWARD": { 21 | "low": 10, 22 | "high": 40, 23 | "sample_fn": "randint", 24 | "enforce_limits": false, 25 | "is_integer": true 26 | }, 27 | 
"MODEL.CONTEXT_BACKWARD": { 28 | "low": 10, 29 | "high": 40, 30 | "sample_fn": "randint", 31 | "enforce_limits": false, 32 | "is_integer": true 33 | }, 34 | "TRAIN.LR.INIT": { 35 | "low": 1e-5, 36 | "high": 1e-3, 37 | "sample_fn": "loguniform" 38 | }, 39 | "TRAIN.MASK_RANDOM_RATIO": { 40 | "low": 0.6, 41 | "high": 1.0, 42 | "sample_fn": "uniform" 43 | }, 44 | "TRAIN.MASK_TOKEN_RATIO": { 45 | "low": 0.6, 46 | "high": 1.0, 47 | "sample_fn": "uniform" 48 | } 49 | } -------------------------------------------------------------------------------- /configs/area2_bump.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/nlb/" 4 | TRAIN_FILENAME: 'area2_bump.h5' 5 | VAL_FILENAME: 'area2_bump.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 0 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 2 # We embed to 2 here so transformer can use 2 heads. Perf diff is minimal. 12 | LOGRATE: True 13 | NUM_LAYERS: 4 14 | TRAIN: 15 | LR: 16 | WARMUP: 5000 17 | WEIGHT_DECAY: 5.0e-05 18 | LOG_INTERVAL: 200 19 | VAL_INTERVAL: 20 20 | CHECKPOINT_INTERVAL: 1000 21 | PATIENCE: 2500 22 | NUM_UPDATES: 40001 23 | MASK_RATIO: 0.25 24 | 25 | # Aggressive regularization 26 | TUNE_HP_JSON: './configs/sweep_bump.json' -------------------------------------------------------------------------------- /configs/arxiv/chaotic.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/chaotic_rnn_data/data_no_inputs/" 4 | TRAIN_FILENAME: 'chaotic_rnn_no_inputs_dataset_N50_S50' 5 | VAL_FILENAME: 'chaotic_rnn_no_inputs_dataset_N50_S50' 6 | MODEL: 7 | TRIAL_LENGTH: 100 8 | LEARNABLE_POSITION: True 9 | EMBED_DIM: 2 10 | TRAIN: 11 | LR: 12 | SCHEDULE: false 13 | LOG_INTERVAL: 250 14 | CHECKPOINT_INTERVAL: 250 15 | NUM_UPDATES: 5001 16 | MASK_RATIO: 0.25 17 | WEIGHT_DECAY: 0.0001 18 | 19 | TUNE_EPOCHS_PER_GENERATION: 100 20 | TUNE_HP_JSON: './configs/sweep_generic.json' 21 | -------------------------------------------------------------------------------- /configs/arxiv/lorenz.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/lfads_lorenz_20ms/" 4 | TRAIN_FILENAME: 'lfads_dataset001.h5' 5 | VAL_FILENAME: 'lfads_dataset001.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 50 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 2 # We embed to 2 here so transformer can use 2 heads. Perf diff is minimal. 
12 | LOGRATE: True 13 | TRAIN: 14 | LR: 15 | SCHEDULE: False 16 | LOG_INTERVAL: 50 17 | CHECKPOINT_INTERVAL: 500 18 | PATIENCE: 2500 19 | NUM_UPDATES: 20001 20 | MASK_RATIO: 0.25 21 | 22 | TUNE_HP_JSON: './configs/sweep_generic.json' -------------------------------------------------------------------------------- /configs/arxiv/m700_115.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/0115_trials/0_seed" 4 | TRAIN_FILENAME: 'lfads_input.h5' 5 | VAL_FILENAME: 'lfads_input.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 70 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 0 12 | LOGRATE: True 13 | TRAIN: 14 | LR: 15 | WARMUP: 5000 16 | MASK_RATIO: 0.25 17 | WEIGHT_DECAY: 5.0e-05 18 | PATIENCE: 3000 19 | LOG_INTERVAL: 200 20 | VAL_INTERVAL: 1 21 | CHECKPOINT_INTERVAL: 1000 22 | NUM_UPDATES: 50501 23 | MASK_SPAN_RAMP_START: 8000 24 | MASK_SPAN_RAMP_END: 12000 25 | 26 | TUNE_HP_JSON: './configs/20ms_arxiv.json' # This space has more aggressive regularization -------------------------------------------------------------------------------- /configs/arxiv/m700_2296.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/2296_trials/0_seed" 4 | TRAIN_FILENAME: 'lfads_input.h5' 5 | VAL_FILENAME: 'lfads_input.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 70 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 0 12 | LOGRATE: True 13 | TRAIN: 14 | LR: 15 | WARMUP: 5000 16 | MASK_RATIO: 0.25 17 | WEIGHT_DECAY: 5.0e-05 18 | PATIENCE: 3000 19 | LOG_INTERVAL: 200 20 | VAL_INTERVAL: 20 21 | CHECKPOINT_INTERVAL: 1000 22 | NUM_UPDATES: 50501 23 | MASK_SPAN_RAMP_START: 8000 24 | MASK_SPAN_RAMP_END: 12000 25 | 26 | TUNE_HP_JSON: './configs/20ms_arxiv.json' # This space has more aggressive regularization 27 | -------------------------------------------------------------------------------- /configs/arxiv/m700_2296_postnorm.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/2296_trials/0_seed" 4 | TRAIN_FILENAME: 'lfads_input.h5' 5 | VAL_FILENAME: 'lfads_input.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 70 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: False 10 | FIXUP_INIT: False 11 | EMBED_DIM: 0 12 | LOGRATE: True 13 | TRAIN: 14 | LR: 15 | WARMUP: 5000 16 | MASK_RATIO: 0.25 17 | WEIGHT_DECAY: 5.0e-05 18 | PATIENCE: 3000 19 | LOG_INTERVAL: 200 20 | VAL_INTERVAL: 20 21 | CHECKPOINT_INTERVAL: 1000 22 | NUM_UPDATES: 50501 23 | MASK_SPAN_RAMP_START: 8000 24 | MASK_SPAN_RAMP_END: 12000 25 | 26 | TUNE_HP_JSON: './configs/20ms_arxiv.json' # This space has more aggressive regularization 27 | -------------------------------------------------------------------------------- /configs/arxiv/m700_230.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/0230_trials/0_seed" 4 | TRAIN_FILENAME: 'lfads_input.h5' 5 | VAL_FILENAME: 'lfads_input.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 70 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | 
FIXUP_INIT: True 11 | EMBED_DIM: 0 12 | LOGRATE: True 13 | TRAIN: 14 | LR: 15 | WARMUP: 5000 16 | MASK_RATIO: 0.25 17 | WEIGHT_DECAY: 5.0e-05 18 | PATIENCE: 3000 19 | LOG_INTERVAL: 200 20 | VAL_INTERVAL: 2 21 | CHECKPOINT_INTERVAL: 1000 22 | NUM_UPDATES: 50501 23 | MASK_SPAN_RAMP_START: 8000 24 | MASK_SPAN_RAMP_END: 12000 25 | 26 | TUNE_HP_JSON: './configs/20ms_arxiv.json' # This space has more aggressive regularization 27 | -------------------------------------------------------------------------------- /configs/arxiv/m700_460.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/0460_trials/0_seed" 4 | TRAIN_FILENAME: 'lfads_input.h5' 5 | VAL_FILENAME: 'lfads_input.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 70 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 0 12 | LOGRATE: True 13 | TRAIN: 14 | LR: 15 | WARMUP: 5000 16 | MASK_RATIO: 0.25 17 | WEIGHT_DECAY: 5.0e-05 18 | PATIENCE: 3000 19 | LOG_INTERVAL: 200 20 | VAL_INTERVAL: 4 21 | CHECKPOINT_INTERVAL: 1000 22 | NUM_UPDATES: 50501 23 | MASK_SPAN_RAMP_START: 8000 24 | MASK_SPAN_RAMP_END: 12000 25 | 26 | TUNE_HP_JSON: './configs/20ms_arxiv.json' # This space has more aggressive regularization 27 | -------------------------------------------------------------------------------- /configs/arxiv/m700_no_log.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/2296_trials/0_seed" 4 | TRAIN_FILENAME: 'lfads_input.h5' 5 | VAL_FILENAME: 'lfads_input.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 70 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 0 12 | LOGRATE: False 13 | TRAIN: 14 | LR: 15 | WARMUP: 5000 16 | MASK_RATIO: 0.25 17 | WEIGHT_DECAY: 5.0e-05 18 | PATIENCE: 3000 19 | LOG_INTERVAL: 200 20 | VAL_INTERVAL: 20 21 | CHECKPOINT_INTERVAL: 1000 22 | NUM_UPDATES: 50501 23 | MASK_SPAN_RAMP_START: 8000 24 | MASK_SPAN_RAMP_END: 12000 25 | 26 | TUNE_HP_JSON: './configs/20ms_arxiv.json' # This space has more aggressive regularization 27 | -------------------------------------------------------------------------------- /configs/arxiv/m700_no_reg.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/2296_trials/0_seed" 4 | TRAIN_FILENAME: 'lfads_input.h5' 5 | VAL_FILENAME: 'lfads_input.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 70 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 0 12 | LOGRATE: False 13 | TRAIN: 14 | LR: 15 | WARMUP: 5000 16 | MASK_RATIO: 0.25 17 | WEIGHT_DECAY: 5.0e-05 18 | PATIENCE: 3000 19 | LOG_INTERVAL: 200 20 | VAL_INTERVAL: 20 21 | CHECKPOINT_INTERVAL: 1000 22 | NUM_UPDATES: 50501 23 | MASK_SPAN_RAMP_START: 8000 24 | MASK_SPAN_RAMP_END: 12000 25 | 26 | TUNE_HP_JSON: './configs/20ms_arxiv_no_reg.json' # This space has more aggressive regularization 27 | -------------------------------------------------------------------------------- /configs/arxiv/m700_no_span.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: 
"/snel/share/data/ndt_paper/m1_maze/heldout_trial/2296_trials/0_seed" 4 | TRAIN_FILENAME: 'lfads_input.h5' 5 | VAL_FILENAME: 'lfads_input.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 70 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 0 12 | LOGRATE: True 13 | TRAIN: 14 | LR: 15 | WARMUP: 5000 16 | MASK_RATIO: 0.25 17 | WEIGHT_DECAY: 5.0e-05 18 | PATIENCE: 3000 19 | LOG_INTERVAL: 200 20 | VAL_INTERVAL: 20 21 | CHECKPOINT_INTERVAL: 1000 22 | NUM_UPDATES: 50501 23 | MASK_SPAN_RAMP_START: 8000 24 | MASK_SPAN_RAMP_END: 12000 25 | 26 | TUNE_HP_JSON: './configs/20ms_no_span.json' 27 | -------------------------------------------------------------------------------- /configs/arxiv/m700_nonzero.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/2296_trials/0_seed" 4 | TRAIN_FILENAME: 'lfads_input.h5' 5 | VAL_FILENAME: 'lfads_input.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 70 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 0 12 | LOGRATE: False 13 | TRAIN: 14 | LR: 15 | WARMUP: 5000 16 | MASK_RATIO: 0.25 17 | WEIGHT_DECAY: 5.0e-05 18 | PATIENCE: 3000 19 | LOG_INTERVAL: 200 20 | VAL_INTERVAL: 20 21 | CHECKPOINT_INTERVAL: 1000 22 | NUM_UPDATES: 50501 23 | MASK_SPAN_RAMP_START: 8000 24 | MASK_SPAN_RAMP_END: 12000 25 | USE_ZERO_MASK: False 26 | 27 | TUNE_HP_JSON: './configs/20ms_arxiv.json' # This space has more aggressive regularization 28 | -------------------------------------------------------------------------------- /configs/dmfc_rsg.json: -------------------------------------------------------------------------------- 1 | { 2 | "MODEL.DROPOUT": { 3 | "low": 0.3, 4 | "high": 0.7, 5 | "sample_fn": "uniform", 6 | "explore_wt": 0.3 7 | }, 8 | "MODEL.DROPOUT_RATES": { 9 | "low": 0.3, 10 | "high": 0.7, 11 | "sample_fn": "uniform", 12 | "explore_wt": 0.3 13 | }, 14 | "MODEL.DROPOUT_EMBEDDING": { 15 | "low": 0.3, 16 | "high": 0.7, 17 | "sample_fn": "uniform", 18 | "explore_wt": 0.3 19 | }, 20 | "MODEL.CONTEXT_FORWARD": { 21 | "low": 40, 22 | "high": 240, 23 | "sample_fn": "randint", 24 | "enforce_limits": false, 25 | "is_integer": true 26 | }, 27 | "MODEL.CONTEXT_BACKWARD": { 28 | "low": 40, 29 | "high": 240, 30 | "sample_fn": "randint", 31 | "enforce_limits": false, 32 | "is_integer": true 33 | }, 34 | "TRAIN.LR.INIT": { 35 | "low": 1e-5, 36 | "high": 5e-3, 37 | "explore_wt": 0.3 38 | }, 39 | "TRAIN.LR.WARMUP": { 40 | "low": 0, 41 | "high": 2000, 42 | "sample_fn": "randint", 43 | "is_integer": true 44 | }, 45 | "TRAIN.MASK_RANDOM_RATIO": { 46 | "low": 0.9, 47 | "high": 1.0, 48 | "sample_fn": "uniform" 49 | }, 50 | "TRAIN.MASK_TOKEN_RATIO": { 51 | "low": 0.5, 52 | "high": 1.0, 53 | "sample_fn": "loguniform" 54 | }, 55 | "TRAIN.MASK_MAX_SPAN": { 56 | "low": 1, 57 | "high": 7, 58 | "sample_fn": "randint", 59 | "is_integer": true 60 | } 61 | } -------------------------------------------------------------------------------- /configs/dmfc_rsg.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/nlb/" 4 | TRAIN_FILENAME: 'dmfc_rsg.h5' 5 | VAL_FILENAME: 'dmfc_rsg.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 0 # 300 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 0 12 | LOGRATE: True 13 | NUM_LAYERS: 6 14 | TRAIN: 15 | LR: 16 | 
WARMUP: 5000 17 | MASK_RATIO: 0.25 18 | WEIGHT_DECAY: 5.0e-05 19 | PATIENCE: 4000 # it's hard... 20 | LOG_INTERVAL: 200 21 | VAL_INTERVAL: 20 22 | CHECKPOINT_INTERVAL: 1000 23 | NUM_UPDATES: 50501 24 | MASK_SPAN_RAMP_START: 8000 25 | MASK_SPAN_RAMP_END: 12000 26 | 27 | TUNE_HP_JSON: './configs/dmfc_rsg.json' # This space has more aggressive regularization 28 | 29 | # current run as of Wednesday used 4 layers, 0.1-0.5 reg -------------------------------------------------------------------------------- /configs/mc_maze.json: -------------------------------------------------------------------------------- 1 | { 2 | "MODEL.DROPOUT": { 3 | "low": 0.3, 4 | "high": 0.7, 5 | "sample_fn": "uniform", 6 | "explore_wt": 0.3 7 | }, 8 | "MODEL.DROPOUT_RATES": { 9 | "low": 0.3, 10 | "high": 0.7, 11 | "sample_fn": "uniform", 12 | "explore_wt": 0.3 13 | }, 14 | "MODEL.DROPOUT_EMBEDDING": { 15 | "low": 0.3, 16 | "high": 0.7, 17 | "sample_fn": "uniform", 18 | "explore_wt": 0.3 19 | }, 20 | "MODEL.CONTEXT_FORWARD": { 21 | "low": 10, 22 | "high": 100, 23 | "sample_fn": "randint", 24 | "enforce_limits": false, 25 | "is_integer": true 26 | }, 27 | "MODEL.CONTEXT_BACKWARD": { 28 | "low": 10, 29 | "high": 100, 30 | "sample_fn": "randint", 31 | "enforce_limits": false, 32 | "is_integer": true 33 | }, 34 | "TRAIN.LR.INIT": { 35 | "low": 1e-5, 36 | "high": 5e-3, 37 | "explore_wt": 0.3 38 | }, 39 | "TRAIN.LR.WARMUP": { 40 | "low": 0, 41 | "high": 2000, 42 | "sample_fn": "randint", 43 | "is_integer": true 44 | }, 45 | "TRAIN.MASK_RANDOM_RATIO": { 46 | "low": 0.8, 47 | "high": 1.0, 48 | "sample_fn": "uniform" 49 | }, 50 | "TRAIN.MASK_TOKEN_RATIO": { 51 | "low": 0.5, 52 | "high": 1.0, 53 | "sample_fn": "loguniform" 54 | }, 55 | "TRAIN.MASK_MAX_SPAN": { 56 | "low": 1, 57 | "high": 7, 58 | "sample_fn": "randint", 59 | "is_integer": true 60 | } 61 | } -------------------------------------------------------------------------------- /configs/mc_maze.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/nlb/" 4 | TRAIN_FILENAME: 'mc_maze.h5' 5 | VAL_FILENAME: 'mc_maze.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 0 # 140 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 0 12 | LOGRATE: True 13 | NUM_LAYERS: 4 14 | TRAIN: 15 | LR: 16 | WARMUP: 5000 17 | MASK_RATIO: 0.25 18 | WEIGHT_DECAY: 5.0e-05 19 | PATIENCE: 3000 20 | LOG_INTERVAL: 200 21 | VAL_INTERVAL: 20 22 | CHECKPOINT_INTERVAL: 1000 23 | NUM_UPDATES: 50501 24 | MASK_SPAN_RAMP_START: 8000 25 | MASK_SPAN_RAMP_END: 12000 26 | 27 | TUNE_HP_JSON: './configs/mc_maze.json' # This space has more aggressive regularization 28 | -------------------------------------------------------------------------------- /configs/mc_maze_large.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/nlb/" 4 | TRAIN_FILENAME: 'mc_maze_large.h5' 5 | VAL_FILENAME: 'mc_maze_large.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 0 # 140 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 0 12 | LOGRATE: True 13 | NUM_LAYERS: 4 14 | TRAIN: 15 | LR: 16 | WARMUP: 5000 17 | MASK_RATIO: 0.25 18 | WEIGHT_DECAY: 5.0e-05 19 | PATIENCE: 3000 20 | LOG_INTERVAL: 200 21 | VAL_INTERVAL: 20 22 | CHECKPOINT_INTERVAL: 1000 23 | NUM_UPDATES: 50501 24 | MASK_SPAN_RAMP_START: 8000 25 | MASK_SPAN_RAMP_END: 12000 
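# NOTE: MASK_SPAN_RAMP_START/END presumably set the update window over which the
# masked span length is ramped up (cf. TRAIN.MASK_MAX_SPAN in the sweep JSONs and src/mask.py).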
26 | 27 | TUNE_HP_JSON: './configs/mc_maze.json' # This space has more aggressive regularization 28 | -------------------------------------------------------------------------------- /configs/mc_maze_medium.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/nlb/" 4 | TRAIN_FILENAME: 'mc_maze_medium.h5' 5 | VAL_FILENAME: 'mc_maze_medium.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 0 # 140 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 0 12 | LOGRATE: True 13 | NUM_LAYERS: 4 14 | TRAIN: 15 | LR: 16 | WARMUP: 5000 17 | MASK_RATIO: 0.25 18 | WEIGHT_DECAY: 5.0e-05 19 | PATIENCE: 3000 20 | LOG_INTERVAL: 200 21 | VAL_INTERVAL: 20 22 | CHECKPOINT_INTERVAL: 1000 23 | NUM_UPDATES: 50501 24 | MASK_SPAN_RAMP_START: 8000 25 | MASK_SPAN_RAMP_END: 12000 26 | 27 | TUNE_HP_JSON: './configs/mc_maze.json' # This space has more aggressive regularization 28 | -------------------------------------------------------------------------------- /configs/mc_maze_small.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/home/joelye/user_data/nlb/ndt_runs/" 2 | # CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 3 | DATA: 4 | DATAPATH: "/home/joelye/user_data/nlb/ndt_old/" 5 | # DATAPATH: "/snel/share/data/nlb/" 6 | TRAIN_FILENAME: 'mc_maze_small.h5' 7 | VAL_FILENAME: 'mc_maze_small.h5' 8 | MODEL: 9 | TRIAL_LENGTH: 0 # 140 10 | LEARNABLE_POSITION: True 11 | PRE_NORM: True 12 | FIXUP_INIT: True 13 | EMBED_DIM: 0 14 | LOGRATE: True 15 | NUM_LAYERS: 4 16 | TRAIN: 17 | LR: 18 | WARMUP: 2000 19 | MASK_RATIO: 0.25 20 | WEIGHT_DECAY: 5.0e-05 21 | PATIENCE: 2000 22 | LOG_INTERVAL: 200 23 | VAL_INTERVAL: 20 24 | CHECKPOINT_INTERVAL: 1000 25 | NUM_UPDATES: 50501 26 | MASK_SPAN_RAMP_START: 5000 27 | MASK_SPAN_RAMP_END: 10000 28 | 29 | TUNE_HP_JSON: './configs/mc_maze.json' # This space has more aggressive regularization 30 | -------------------------------------------------------------------------------- /configs/mc_maze_small_from_scratch.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/home/joelye/user_data/nlb/ndt_runs" 2 | DATA: 3 | DATAPATH: "/home/joelye/user_data/nlb" 4 | TRAIN_FILENAME: 'mc_maze_small.h5' 5 | VAL_FILENAME: 'mc_maze_small.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 0 # 140 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 0 12 | LOGRATE: True 13 | NUM_LAYERS: 4 14 | TRAIN: 15 | LR: 16 | WARMUP: 2000 17 | MASK_RATIO: 0.25 18 | WEIGHT_DECAY: 5.0e-05 19 | PATIENCE: 2000 20 | LOG_INTERVAL: 200 21 | VAL_INTERVAL: 20 22 | CHECKPOINT_INTERVAL: 1000 23 | NUM_UPDATES: 50501 24 | MASK_SPAN_RAMP_START: 5000 25 | MASK_SPAN_RAMP_END: 10000 26 | 27 | TUNE_HP_JSON: './configs/mc_maze.json' # This space has more aggressive regularization 28 | -------------------------------------------------------------------------------- /configs/mc_rtt.yaml: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/" 2 | DATA: 3 | DATAPATH: "/snel/share/data/nlb/" 4 | TRAIN_FILENAME: 'mc_rtt.h5' 5 | VAL_FILENAME: 'mc_rtt.h5' 6 | MODEL: 7 | TRIAL_LENGTH: 0 # 140 8 | LEARNABLE_POSITION: True 9 | PRE_NORM: True 10 | FIXUP_INIT: True 11 | EMBED_DIM: 0 12 | LOGRATE: True 13 | NUM_LAYERS: 4 14 | TRAIN: 15 | LR: 16 | WARMUP: 5000 17 | MASK_RATIO: 0.25 18 
| WEIGHT_DECAY: 5.0e-05 19 | PATIENCE: 3000 20 | LOG_INTERVAL: 200 21 | VAL_INTERVAL: 20 22 | CHECKPOINT_INTERVAL: 1000 23 | NUM_UPDATES: 50501 24 | MASK_SPAN_RAMP_START: 8000 25 | MASK_SPAN_RAMP_END: 12000 26 | 27 | TUNE_HP_JSON: './configs/mc_maze.json' # This space has more aggressive regularization 28 | -------------------------------------------------------------------------------- /configs/sweep_bump.json: -------------------------------------------------------------------------------- 1 | { 2 | "MODEL.DROPOUT": { 3 | "low": 0.2, 4 | "high": 0.6, 5 | "sample_fn": "uniform", 6 | "explore_wt": 0.3 7 | }, 8 | "MODEL.DROPOUT_RATES": { 9 | "low": 0.2, 10 | "high": 0.6, 11 | "sample_fn": "uniform", 12 | "explore_wt": 0.3 13 | }, 14 | "MODEL.DROPOUT_EMBEDDING": { 15 | "low": 0.2, 16 | "high": 0.6, 17 | "sample_fn": "uniform", 18 | "explore_wt": 0.3 19 | }, 20 | "MODEL.CONTEXT_FORWARD": { 21 | "low": 4, 22 | "high": 100, 23 | "sample_fn": "randint", 24 | "enforce_limits": false, 25 | "is_integer": true 26 | }, 27 | "MODEL.CONTEXT_BACKWARD": { 28 | "low": 4, 29 | "high": 100, 30 | "sample_fn": "randint", 31 | "enforce_limits": false, 32 | "is_integer": true 33 | }, 34 | "TRAIN.LR.INIT": { 35 | "low": 1e-5, 36 | "high": 5e-3, 37 | "explore_wt": 0.3 38 | }, 39 | "TRAIN.MASK_RANDOM_RATIO": { 40 | "low": 0.9, 41 | "high": 1.0, 42 | "sample_fn": "uniform" 43 | }, 44 | "TRAIN.MASK_TOKEN_RATIO": { 45 | "low": 0.5, 46 | "high": 1.0, 47 | "sample_fn": "loguniform" 48 | }, 49 | "TRAIN.MASK_MAX_SPAN": { 50 | "low": 1, 51 | "high": 5, 52 | "sample_fn": "randint", 53 | "is_integer": true 54 | } 55 | } -------------------------------------------------------------------------------- /configs/sweep_generic.json: -------------------------------------------------------------------------------- 1 | { 2 | "MODEL.DROPOUT": { 3 | "low": 0.0, 4 | "high": 0.3, 5 | "sample_fn": "uniform", 6 | "explore_wt": 0.3 7 | }, 8 | "MODEL.DROPOUT_RATES": { 9 | "low": 0.0, 10 | "high": 0.3, 11 | "sample_fn": "uniform", 12 | "explore_wt": 0.3 13 | }, 14 | "MODEL.DROPOUT_EMBEDDING": { 15 | "low": 0.0, 16 | "high": 0.3, 17 | "sample_fn": "uniform", 18 | "explore_wt": 0.3 19 | }, 20 | "MODEL.CONTEXT_FORWARD": { 21 | "low": 4, 22 | "high": 32, 23 | "sample_fn": "randint", 24 | "enforce_limits": false, 25 | "is_integer": true 26 | }, 27 | "MODEL.CONTEXT_BACKWARD": { 28 | "low": 4, 29 | "high": 32, 30 | "sample_fn": "randint", 31 | "enforce_limits": false, 32 | "is_integer": true 33 | }, 34 | "TRAIN.LR.INIT": { 35 | "low": 1e-5, 36 | "high": 5e-3, 37 | "explore_wt": 0.3 38 | }, 39 | "TRAIN.MASK_RANDOM_RATIO": { 40 | "low": 0.9, 41 | "high": 1.0, 42 | "sample_fn": "uniform" 43 | }, 44 | "TRAIN.MASK_TOKEN_RATIO": { 45 | "low": 0.5, 46 | "high": 1.0, 47 | "sample_fn": "loguniform" 48 | }, 49 | "TRAIN.MASK_MAX_SPAN": { 50 | "low": 1, 51 | "high": 5, 52 | "sample_fn": "randint", 53 | "is_integer": true 54 | } 55 | } -------------------------------------------------------------------------------- /configs/sweep_simple.json: -------------------------------------------------------------------------------- 1 | { 2 | "MODEL.DROPOUT": { 3 | "low": 0.0, 4 | "high": 0.3, 5 | "sample_fn": "uniform", 6 | "explore_wt": 0.3 7 | }, 8 | "MODEL.DROPOUT_RATES": { 9 | "low": 0.0, 10 | "high": 0.3, 11 | "sample_fn": "uniform", 12 | "explore_wt": 0.3 13 | }, 14 | "MODEL.DROPOUT_EMBEDDING": { 15 | "low": 0.0, 16 | "high": 0.3, 17 | "sample_fn": "uniform", 18 | "explore_wt": 0.3 19 | }, 20 | "TRAIN.LR.INIT": { 21 | "low": 1e-5, 22 | "high": 5e-3, 23 | 
"explore_wt": 0.3 24 | }, 25 | "TRAIN.MASK_RANDOM_RATIO": { 26 | "low": 0.9, 27 | "high": 1.0, 28 | "sample_fn": "uniform" 29 | }, 30 | "TRAIN.MASK_TOKEN_RATIO": { 31 | "low": 0.5, 32 | "high": 1.0, 33 | "sample_fn": "loguniform" 34 | } 35 | } -------------------------------------------------------------------------------- /data/chaotic_rnn/gen_synth_data_no_inputs.sh: -------------------------------------------------------------------------------- 1 | SYNTH_PATH=/snel/share/data/chaotic_rnn_data/data_no_inputs 2 | 3 | 4 | python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=130 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --ninputs=0 --input_magnitude_list='' --max_firing_rate=30.0 --noise_type='poisson' -------------------------------------------------------------------------------- /data/chaotic_rnn/generate_chaotic_rnn_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # ============================================================================== 16 | from __future__ import print_function 17 | 18 | import h5py 19 | import numpy as np 20 | import os 21 | import tensorflow as tf # used for flags here 22 | 23 | from utils import write_datasets 24 | from synthetic_data_utils import add_alignment_projections, generate_data 25 | from synthetic_data_utils import generate_rnn, get_train_n_valid_inds 26 | from synthetic_data_utils import nparray_and_transpose 27 | from synthetic_data_utils import spikify_data, gaussify_data, split_list_by_inds 28 | #import matplotlib 29 | #import matplotlib.pyplot as plt 30 | import scipy.signal 31 | 32 | #matplotlib.rcParams['image.interpolation'] = 'nearest' 33 | DATA_DIR = "rnn_synth_data_v1.0" 34 | 35 | flags = tf.app.flags 36 | flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/", 37 | "Directory for saving data.") 38 | flags.DEFINE_string("datafile_name", "thits_data", 39 | "Name of data file for input case.") 40 | flags.DEFINE_string("noise_type", "poisson", "Noise type for data.") 41 | flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.") 42 | flags.DEFINE_float("T", 1.0, "Time in seconds to generate.") 43 | flags.DEFINE_integer("C", 100, "Number of conditions") 44 | flags.DEFINE_integer("N", 50, "Number of units for the RNN") 45 | flags.DEFINE_integer("S", 50, "Number of sampled units from RNN") 46 | flags.DEFINE_integer("npcs", 10, "Number of PCS for multi-session case.") 47 | flags.DEFINE_float("train_percentage", 4.0/5.0, 48 | "Percentage of train vs validation trials") 49 | flags.DEFINE_integer("nreplications", 40, 50 | "Number of noise replications of the same underlying rates.") 51 | flags.DEFINE_float("g", 1.5, "Complexity of dynamics") 52 | flags.DEFINE_float("x0_std", 1.0, 53 | "Volume from which to 
pull initial conditions (affects diversity of dynamics.") 54 | flags.DEFINE_float("tau", 0.025, "Time constant of RNN") 55 | flags.DEFINE_float("dt", 0.010, "Time bin") 56 | flags.DEFINE_float("input_magnitude", 20.0, 57 | "For the input case, what is the value of the input?") 58 | flags.DEFINE_integer("ninputs", 0, "number of inputs") 59 | flags.DEFINE_list("input_magnitude_list", "10,15,20,25,30", "Magnitudes for multiple inputs") 60 | flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second") 61 | flags.DEFINE_boolean("lorenz", False, "use lorenz system as generated inputs") 62 | FLAGS = flags.FLAGS 63 | 64 | 65 | # Note that with N small, (as it is 25 above), the finite size effects 66 | # will have pretty dramatic effects on the dynamics of the random RNN. 67 | # If you want more complex dynamics, you'll have to run the script a 68 | # lot, or increase N (or g). 69 | 70 | # Getting hard vs. easy data can be a little stochastic, so we set the seed. 71 | 72 | # Pull out some commonly used parameters. 73 | # These are user parameters (configuration) 74 | rng = np.random.RandomState(seed=FLAGS.synth_data_seed) 75 | T = FLAGS.T 76 | C = FLAGS.C 77 | N = FLAGS.N 78 | S = FLAGS.S 79 | input_magnitude = FLAGS.input_magnitude 80 | input_magnitude_list = [float(i) for i in (FLAGS.input_magnitude_list)] 81 | ninputs = FLAGS.ninputs 82 | nreplications = FLAGS.nreplications 83 | E = nreplications * C # total number of trials 84 | # S is the number of measurements in each datasets, w/ each 85 | # dataset having a different set of observations. 86 | ndatasets = N/S # ok if rounded down 87 | train_percentage = FLAGS.train_percentage 88 | ntime_steps = int(T / FLAGS.dt) 89 | # End of user parameters 90 | 91 | lorenz=FLAGS.lorenz 92 | 93 | 94 | if ninputs >= 1: 95 | rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate, ninputs) 96 | else: 97 | rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate) 98 | 99 | # Check to make sure the RNN is the one we used in the paper. 100 | if N == 50: 101 | assert abs(rnn['W'][0,0] - 0.06239899) < 1e-8, 'Error in random seed?' 102 | rem_check = nreplications * train_percentage 103 | assert abs(rem_check - int(rem_check)) < 1e-8, \ 104 | 'Train percentage * nreplications should be integral number.' 105 | 106 | 107 | #if lorenz: 108 | # lorenz_input = generate_lorenz(ntime_steps, rng) 109 | # rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate, lorenz_input) 110 | 111 | 112 | # Initial condition generation, and condition label generation. This 113 | # happens outside of the dataset loop, so that all datasets have the 114 | # same conditions, which is similar to a neurophys setup. 115 | condition_number = 0 116 | x0s = [] 117 | condition_labels = [] 118 | print(FLAGS.x0_std) 119 | for c in range(C): 120 | x0 = FLAGS.x0_std * rng.randn(N, 1) 121 | x0s.append(np.tile(x0, nreplications)) # replicate x0 nreplications times 122 | # replicate the condition label nreplications times 123 | for ns in range(nreplications): 124 | condition_labels.append(condition_number) 125 | condition_number += 1 126 | x0s = np.concatenate(x0s, axis=1) 127 | 128 | #print(x0s.shape) 129 | #print(x0s[1,1:20]) 130 | 131 | # Containers for storing data across data. 132 | datasets = {} 133 | for n in range(ndatasets): 134 | print(n+1, " of ", ndatasets) 135 | 136 | # First generate all firing rates. 
in the next loop, generate all 137 | # replications this allows the random state for rate generation to be 138 | # independent of n_replications. 139 | dataset_name = 'dataset_N' + str(N) + '_S' + str(S) 140 | if S < N: 141 | dataset_name += '_n' + str(n+1) 142 | 143 | # Sample neuron subsets. The assumption is the PC axes of the RNN 144 | # are not unit aligned, so sampling units is adequate to sample all 145 | # the high-variance PCs. 146 | P_sxn = np.eye(S,N) 147 | for m in range(n): 148 | P_sxn = np.roll(P_sxn, S, axis=1) 149 | 150 | if input_magnitude > 0.0: 151 | # time of "hits" randomly chosen between [1/4 and 3/4] of total time 152 | if ninputs>1: 153 | for n in range(ninputs): 154 | if n == 0: 155 | input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4) 156 | else: 157 | input_times = np.vstack((input_times, (rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)))) 158 | else: 159 | input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4) 160 | else: 161 | input_times = None 162 | 163 | print(ninputs) 164 | 165 | 166 | if ninputs > 1: 167 | rates, x0s, inputs = \ 168 | generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn, 169 | input_magnitude=input_magnitude_list, 170 | input_times=input_times, ninputs=ninputs, rng=rng) 171 | else: 172 | rates, x0s, inputs = \ 173 | generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn, 174 | input_magnitude=input_magnitude, 175 | input_times=input_times, ninputs=ninputs, rng=rng) 176 | 177 | if FLAGS.noise_type == "poisson": 178 | noisy_data = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate']) 179 | elif FLAGS.noise_type == "gaussian": 180 | noisy_data = gaussify_data(rates, rng, rnn['dt'], rnn['max_firing_rate']) 181 | else: 182 | raise ValueError("Only noise types supported are poisson or gaussian") 183 | 184 | # split into train and validation sets 185 | train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage, 186 | nreplications) 187 | 188 | # Split the data, inputs, labels and times into train vs. validation. 189 | rates_train, rates_valid = \ 190 | split_list_by_inds(rates, train_inds, valid_inds) 191 | #rates_train_no_input, rates_valid_no_input = \ 192 | # split_list_by_inds(rates_noinput, train_inds, valid_inds) 193 | noisy_data_train, noisy_data_valid = \ 194 | split_list_by_inds(noisy_data, train_inds, valid_inds) 195 | input_train, inputs_valid = \ 196 | split_list_by_inds(inputs, train_inds, valid_inds) 197 | condition_labels_train, condition_labels_valid = \ 198 | split_list_by_inds(condition_labels, train_inds, valid_inds) 199 | if ninputs>1: 200 | input_times_train, input_times_valid = \ 201 | split_list_by_inds(input_times, train_inds, valid_inds, ninputs) 202 | input_magnitude = input_magnitude_list 203 | else: 204 | input_times_train, input_times_valid = \ 205 | split_list_by_inds(input_times, train_inds, valid_inds) 206 | 207 | #lorenz_train = np.expand_dims(lorenz_train, axis=1) 208 | #lorenz_valid = np.expand_dims(lorenz_valid, axis=1) 209 | #print((np.array(input_train)).shape) 210 | #print((np.array(lorenz_train)).shape) 211 | # Turn rates, noisy_data, and input into numpy arrays. 
212 | rates_train = nparray_and_transpose(rates_train) 213 | rates_valid = nparray_and_transpose(rates_valid) 214 | #rates_train_no_input = nparray_and_transpose(rates_train_no_input) 215 | #rates_valid_no_input = nparray_and_transpose(rates_valid_no_input) 216 | noisy_data_train = nparray_and_transpose(noisy_data_train) 217 | noisy_data_valid = nparray_and_transpose(noisy_data_valid) 218 | input_train = nparray_and_transpose(input_train) 219 | inputs_valid = nparray_and_transpose(inputs_valid) 220 | 221 | # Note that we put these 'truth' rates and input into this 222 | # structure, the only data that is used in LFADS are the noisy 223 | # data e.g. spike trains. The rest is either for printing or posterity. 224 | data = {'train_truth': rates_train, 225 | 'valid_truth': rates_valid, 226 | #'train_truth_no_input': rates_train_no_input, 227 | #'valid_truth_no_input': rates_valid_no_input, 228 | 'input_train_truth' : input_train, 229 | 'input_valid_truth' : inputs_valid, 230 | 'train_data' : noisy_data_train, 231 | 'valid_data' : noisy_data_valid, 232 | 'train_percentage' : train_percentage, 233 | 'nreplications' : nreplications, 234 | 'dt' : rnn['dt'], 235 | 'input_magnitude' : input_magnitude, 236 | 'input_times_train' : input_times_train, 237 | 'input_times_valid' : input_times_valid, 238 | 'P_sxn' : P_sxn, 239 | 'condition_labels_train' : condition_labels_train, 240 | 'condition_labels_valid' : condition_labels_valid, 241 | 'conversion_factor': 1.0 / rnn['conversion_factor']} 242 | datasets[dataset_name] = data 243 | 244 | if S < N: 245 | # Note that this isn't necessary for this synthetic example, but 246 | # it's useful to see how the input factor matrices were initialized 247 | # for actual neurophysiology data. 248 | datasets = add_alignment_projections(datasets, npcs=FLAGS.npcs) 249 | 250 | # Write out the datasets. 251 | write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets) 252 | -------------------------------------------------------------------------------- /data/chaotic_rnn/generate_chaotic_rnn_data_allowRandomSeed.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # ============================================================================== 16 | from __future__ import print_function 17 | 18 | import h5py 19 | import numpy as np 20 | import os 21 | import tensorflow as tf # used for flags here 22 | 23 | from utils import write_datasets 24 | from synthetic_data_utils import add_alignment_projections, generate_data 25 | from synthetic_data_utils import generate_rnn, get_train_n_valid_inds 26 | from synthetic_data_utils import nparray_and_transpose 27 | from synthetic_data_utils import spikify_data, gaussify_data, split_list_by_inds 28 | #import matplotlib 29 | #import matplotlib.pyplot as plt 30 | import scipy.signal 31 | 32 | #matplotlib.rcParams['image.interpolation'] = 'nearest' 33 | DATA_DIR = "rnn_synth_data_v1.0" 34 | 35 | flags = tf.app.flags 36 | flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/", 37 | "Directory for saving data.") 38 | flags.DEFINE_string("datafile_name", "thits_data", 39 | "Name of data file for input case.") 40 | flags.DEFINE_string("noise_type", "poisson", "Noise type for data.") 41 | flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.") 42 | flags.DEFINE_float("T", 1.0, "Time in seconds to generate.") 43 | flags.DEFINE_integer("C", 100, "Number of conditions") 44 | flags.DEFINE_integer("N", 50, "Number of units for the RNN") 45 | flags.DEFINE_integer("S", 50, "Number of sampled units from RNN") 46 | flags.DEFINE_integer("npcs", 10, "Number of PCS for multi-session case.") 47 | flags.DEFINE_float("train_percentage", 4.0/5.0, 48 | "Percentage of train vs validation trials") 49 | flags.DEFINE_integer("nreplications", 40, 50 | "Number of noise replications of the same underlying rates.") 51 | flags.DEFINE_float("g", 1.5, "Complexity of dynamics") 52 | flags.DEFINE_float("x0_std", 1.0, 53 | "Volume from which to pull initial conditions (affects diversity of dynamics.") 54 | flags.DEFINE_float("tau", 0.025, "Time constant of RNN") 55 | flags.DEFINE_float("dt", 0.010, "Time bin") 56 | flags.DEFINE_float("input_magnitude", 20.0, 57 | "For the input case, what is the value of the input?") 58 | flags.DEFINE_integer("ninputs", 0, "number of inputs") 59 | flags.DEFINE_list("input_magnitude_list", "10,15,20,25,30", "Magnitudes for multiple inputs") 60 | flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second") 61 | flags.DEFINE_boolean("lorenz", False, "use lorenz system as generated inputs") 62 | FLAGS = flags.FLAGS 63 | 64 | 65 | # Note that with N small, (as it is 25 above), the finite size effects 66 | # will have pretty dramatic effects on the dynamics of the random RNN. 67 | # If you want more complex dynamics, you'll have to run the script a 68 | # lot, or increase N (or g). 69 | 70 | # Getting hard vs. easy data can be a little stochastic, so we set the seed. 71 | 72 | # Pull out some commonly used parameters. 73 | # These are user parameters (configuration) 74 | rng = np.random.RandomState(seed=FLAGS.synth_data_seed) 75 | T = FLAGS.T 76 | C = FLAGS.C 77 | N = FLAGS.N 78 | S = FLAGS.S 79 | input_magnitude = FLAGS.input_magnitude 80 | input_magnitude_list = [float(i) for i in (FLAGS.input_magnitude_list)] 81 | ninputs = FLAGS.ninputs 82 | nreplications = FLAGS.nreplications 83 | E = nreplications * C # total number of trials 84 | # S is the number of measurements in each datasets, w/ each 85 | # dataset having a different set of observations. 
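# NOTE: this script targets the original Python 2 / TF1 release; under Python 3,
# N/S below is a float, so the later `for n in range(ndatasets)` loop would need
# `ndatasets = N // S` instead.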
86 | ndatasets = N/S # ok if rounded down 87 | train_percentage = FLAGS.train_percentage 88 | ntime_steps = int(T / FLAGS.dt) 89 | # End of user parameters 90 | 91 | lorenz=FLAGS.lorenz 92 | 93 | 94 | if ninputs >= 1: 95 | rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate, ninputs) 96 | else: 97 | rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate) 98 | 99 | # Check to make sure the RNN is the one we used in the paper. 100 | if N == 50: 101 | # assert abs(rnn['W'][0,0] - 0.06239899) < 1e-8, 'Error in random seed?' 102 | rem_check = nreplications * train_percentage 103 | assert abs(rem_check - int(rem_check)) < 1e-8, \ 104 | 'Train percentage * nreplications should be integral number.' 105 | 106 | 107 | #if lorenz: 108 | # lorenz_input = generate_lorenz(ntime_steps, rng) 109 | # rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate, lorenz_input) 110 | 111 | 112 | # Initial condition generation, and condition label generation. This 113 | # happens outside of the dataset loop, so that all datasets have the 114 | # same conditions, which is similar to a neurophys setup. 115 | condition_number = 0 116 | x0s = [] 117 | condition_labels = [] 118 | print(FLAGS.x0_std) 119 | for c in range(C): 120 | x0 = FLAGS.x0_std * rng.randn(N, 1) 121 | x0s.append(np.tile(x0, nreplications)) # replicate x0 nreplications times 122 | # replicate the condition label nreplications times 123 | for ns in range(nreplications): 124 | condition_labels.append(condition_number) 125 | condition_number += 1 126 | x0s = np.concatenate(x0s, axis=1) 127 | 128 | #print(x0s.shape) 129 | #print(x0s[1,1:20]) 130 | 131 | # Containers for storing data across data. 132 | datasets = {} 133 | for n in range(ndatasets): 134 | print(n+1, " of ", ndatasets) 135 | 136 | # First generate all firing rates. in the next loop, generate all 137 | # replications this allows the random state for rate generation to be 138 | # independent of n_replications. 139 | dataset_name = 'dataset_N' + str(N) + '_S' + str(S) 140 | if S < N: 141 | dataset_name += '_n' + str(n+1) 142 | 143 | # Sample neuron subsets. The assumption is the PC axes of the RNN 144 | # are not unit aligned, so sampling units is adequate to sample all 145 | # the high-variance PCs. 
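# P_sxn is an S x N selection matrix: dataset n observes units n*S .. n*S+S-1
# (rolled right by S columns per dataset, wrapping around the N-unit RNN).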
146 | P_sxn = np.eye(S,N) 147 | for m in range(n): 148 | P_sxn = np.roll(P_sxn, S, axis=1) 149 | 150 | if input_magnitude > 0.0: 151 | # time of "hits" randomly chosen between [1/4 and 3/4] of total time 152 | if ninputs>1: 153 | for n in range(ninputs): 154 | if n == 0: 155 | input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4) 156 | else: 157 | input_times = np.vstack((input_times, (rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)))) 158 | else: 159 | input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4) 160 | else: 161 | input_times = None 162 | 163 | print(ninputs) 164 | 165 | 166 | if ninputs > 1: 167 | rates, x0s, inputs = \ 168 | generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn, 169 | input_magnitude=input_magnitude_list, 170 | input_times=input_times, ninputs=ninputs, rng=rng) 171 | else: 172 | rates, x0s, inputs = \ 173 | generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn, 174 | input_magnitude=input_magnitude, 175 | input_times=input_times, ninputs=ninputs, rng=rng) 176 | 177 | if FLAGS.noise_type == "poisson": 178 | noisy_data = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate']) 179 | elif FLAGS.noise_type == "gaussian": 180 | noisy_data = gaussify_data(rates, rng, rnn['dt'], rnn['max_firing_rate']) 181 | else: 182 | raise ValueError("Only noise types supported are poisson or gaussian") 183 | 184 | # split into train and validation sets 185 | train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage, 186 | nreplications) 187 | 188 | # Split the data, inputs, labels and times into train vs. validation. 189 | rates_train, rates_valid = \ 190 | split_list_by_inds(rates, train_inds, valid_inds) 191 | #rates_train_no_input, rates_valid_no_input = \ 192 | # split_list_by_inds(rates_noinput, train_inds, valid_inds) 193 | noisy_data_train, noisy_data_valid = \ 194 | split_list_by_inds(noisy_data, train_inds, valid_inds) 195 | input_train, inputs_valid = \ 196 | split_list_by_inds(inputs, train_inds, valid_inds) 197 | condition_labels_train, condition_labels_valid = \ 198 | split_list_by_inds(condition_labels, train_inds, valid_inds) 199 | if ninputs>1: 200 | input_times_train, input_times_valid = \ 201 | split_list_by_inds(input_times, train_inds, valid_inds, ninputs) 202 | input_magnitude = input_magnitude_list 203 | else: 204 | input_times_train, input_times_valid = \ 205 | split_list_by_inds(input_times, train_inds, valid_inds) 206 | 207 | #lorenz_train = np.expand_dims(lorenz_train, axis=1) 208 | #lorenz_valid = np.expand_dims(lorenz_valid, axis=1) 209 | #print((np.array(input_train)).shape) 210 | #print((np.array(lorenz_train)).shape) 211 | # Turn rates, noisy_data, and input into numpy arrays. 212 | rates_train = nparray_and_transpose(rates_train) 213 | rates_valid = nparray_and_transpose(rates_valid) 214 | #rates_train_no_input = nparray_and_transpose(rates_train_no_input) 215 | #rates_valid_no_input = nparray_and_transpose(rates_valid_no_input) 216 | noisy_data_train = nparray_and_transpose(noisy_data_train) 217 | noisy_data_valid = nparray_and_transpose(noisy_data_valid) 218 | input_train = nparray_and_transpose(input_train) 219 | inputs_valid = nparray_and_transpose(inputs_valid) 220 | 221 | # Note that we put these 'truth' rates and input into this 222 | # structure, the only data that is used in LFADS are the noisy 223 | # data e.g. spike trains. The rest is either for printing or posterity. 
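# For reference (shape convention inferred from read_datasets() in utils.py,
# not stated explicitly here): train_data / valid_data end up as arrays of
# shape (trials, ntime_steps, S), and read_datasets() later recovers
# num_steps and data_dim from axes 1 and 2 of 'train_data'.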
224 | data = {'train_truth': rates_train, 225 | 'valid_truth': rates_valid, 226 | #'train_truth_no_input': rates_train_no_input, 227 | #'valid_truth_no_input': rates_valid_no_input, 228 | 'input_train_truth' : input_train, 229 | 'input_valid_truth' : inputs_valid, 230 | 'train_data' : noisy_data_train, 231 | 'valid_data' : noisy_data_valid, 232 | 'train_percentage' : train_percentage, 233 | 'nreplications' : nreplications, 234 | 'dt' : rnn['dt'], 235 | 'input_magnitude' : input_magnitude, 236 | 'input_times_train' : input_times_train, 237 | 'input_times_valid' : input_times_valid, 238 | 'P_sxn' : P_sxn, 239 | 'condition_labels_train' : condition_labels_train, 240 | 'condition_labels_valid' : condition_labels_valid, 241 | 'conversion_factor': 1.0 / rnn['conversion_factor']} 242 | datasets[dataset_name] = data 243 | 244 | if S < N: 245 | # Note that this isn't necessary for this synthetic example, but 246 | # it's useful to see how the input factor matrices were initialized 247 | # for actual neurophysiology data. 248 | datasets = add_alignment_projections(datasets, npcs=FLAGS.npcs) 249 | 250 | # Write out the datasets. 251 | write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets) 252 | -------------------------------------------------------------------------------- /data/chaotic_rnn/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # ============================================================================== 16 | from __future__ import print_function 17 | 18 | import os 19 | import h5py 20 | import json 21 | 22 | import numpy as np 23 | import tensorflow as tf 24 | 25 | 26 | def log_sum_exp(x_k): 27 | """Computes log \sum exp in a numerically stable way. 28 | log ( sum_i exp(x_i) ) 29 | log ( sum_i exp(x_i - m + m) ), with m = max(x_i) 30 | log ( sum_i exp(x_i - m)*exp(m) ) 31 | log ( sum_i exp(x_i - m) + m 32 | 33 | Args: 34 | x_k - k -dimensional list of arguments to log_sum_exp. 35 | 36 | Returns: 37 | log_sum_exp of the arguments. 38 | """ 39 | m = tf.reduce_max(x_k) 40 | x1_k = x_k - m 41 | u_k = tf.exp(x1_k) 42 | z = tf.reduce_sum(u_k) 43 | return tf.log(z) + m 44 | 45 | 46 | def linear(x, out_size, do_bias=True, alpha=1.0, identity_if_possible=False, 47 | normalized=False, name=None, collections=None): 48 | """Linear (affine) transformation, y = x W + b, for a variety of 49 | configurations. 50 | 51 | Args: 52 | x: input The tensor to tranformation. 53 | out_size: The integer size of non-batch output dimension. 54 | do_bias (optional): Add a learnable bias vector to the operation. 55 | alpha (optional): A multiplicative scaling for the weight initialization 56 | of the matrix, in the form \alpha * 1/\sqrt{x.shape[1]}. 57 | identity_if_possible (optional): just return identity, 58 | if x.shape[1] == out_size. 59 | normalized (optional): Option to divide out by the norms of the rows of W. 
60 | name (optional): The name prefix to add to variables. 61 | collections (optional): List of additional collections. (Placed in 62 | tf.GraphKeys.GLOBAL_VARIABLES already, so no need for that.) 63 | 64 | Returns: 65 | In the equation, y = x W + b, returns the tensorflow op that yields y. 66 | """ 67 | in_size = int(x.get_shape()[1]) # from Dimension(10) -> 10 68 | stddev = alpha/np.sqrt(float(in_size)) 69 | mat_init = tf.random_normal_initializer(0.0, stddev) 70 | wname = (name + "/W") if name else "/W" 71 | 72 | if identity_if_possible and in_size == out_size: 73 | # Sometimes linear layers are nothing more than size adapters. 74 | return tf.identity(x, name=(wname+'_ident')) 75 | 76 | W,b = init_linear(in_size, out_size, do_bias=do_bias, alpha=alpha, 77 | normalized=normalized, name=name, collections=collections) 78 | 79 | if do_bias: 80 | return tf.matmul(x, W) + b 81 | else: 82 | return tf.matmul(x, W) 83 | 84 | 85 | def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, 86 | bias_init_value=None, alpha=1.0, identity_if_possible=False, 87 | normalized=False, name=None, collections=None, trainable=True): 88 | """Linear (affine) transformation, y = x W + b, for a variety of 89 | configurations. 90 | 91 | Args: 92 | in_size: The integer size of the non-batc input dimension. [(x),y] 93 | out_size: The integer size of non-batch output dimension. [x,(y)] 94 | do_bias (optional): Add a (learnable) bias vector to the operation, 95 | if false, b will be None 96 | mat_init_value (optional): numpy constant for matrix initialization, if None 97 | , do random, with additional parameters. 98 | alpha (optional): A multiplicative scaling for the weight initialization 99 | of the matrix, in the form \alpha * 1/\sqrt{x.shape[1]}. 100 | identity_if_possible (optional): just return identity, 101 | if x.shape[1] == out_size. 102 | normalized (optional): Option to divide out by the norms of the rows of W. 103 | name (optional): The name prefix to add to variables. 104 | collections (optional): List of additional collections. (Placed in 105 | tf.GraphKeys.GLOBAL_VARIABLES already, so no need for that.) 106 | 107 | Returns: 108 | In the equation, y = x W + b, returns the pair (W, b). 109 | """ 110 | 111 | if mat_init_value is not None and mat_init_value.shape != (in_size, out_size): 112 | raise ValueError( 113 | 'Provided mat_init_value must have shape [%d, %d].'%(in_size, out_size)) 114 | if bias_init_value is not None and bias_init_value.shape != (1,out_size): 115 | raise ValueError( 116 | 'Provided bias_init_value must have shape [1,%d].'%(out_size,)) 117 | 118 | if mat_init_value is None: 119 | stddev = alpha/np.sqrt(float(in_size)) 120 | mat_init = tf.random_normal_initializer(0.0, stddev) 121 | 122 | wname = (name + "/W") if name else "/W" 123 | 124 | if identity_if_possible and in_size == out_size: 125 | return (tf.constant(np.eye(in_size).astype(np.float32)), 126 | tf.zeros(in_size)) 127 | 128 | # Note the use of get_variable vs. tf.Variable. this is because get_variable 129 | # does not allow the initialization of the variable with a value. 
130 | if normalized: 131 | w_collections = [tf.GraphKeys.GLOBAL_VARIABLES, "norm-variables"] 132 | if collections: 133 | w_collections += collections 134 | if mat_init_value is not None: 135 | w = tf.Variable(mat_init_value, name=wname, collections=w_collections, 136 | trainable=trainable) 137 | else: 138 | w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init, 139 | collections=w_collections, trainable=trainable) 140 | w = tf.nn.l2_normalize(w, dim=0) # x W, so xW_j = \sum_i x_bi W_ij 141 | else: 142 | w_collections = [tf.GraphKeys.GLOBAL_VARIABLES] 143 | if collections: 144 | w_collections += collections 145 | if mat_init_value is not None: 146 | w = tf.Variable(mat_init_value, name=wname, collections=w_collections, 147 | trainable=trainable) 148 | else: 149 | w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init, 150 | collections=w_collections, trainable=trainable) 151 | b = None 152 | if do_bias: 153 | b_collections = [tf.GraphKeys.GLOBAL_VARIABLES] 154 | if collections: 155 | b_collections += collections 156 | bname = (name + "/b") if name else "/b" 157 | if bias_init_value is None: 158 | b = tf.get_variable(bname, [1, out_size], 159 | initializer=tf.zeros_initializer(), 160 | collections=b_collections, 161 | trainable=trainable) 162 | else: 163 | b = tf.Variable(bias_init_value, name=bname, 164 | collections=b_collections, 165 | trainable=trainable) 166 | 167 | return (w, b) 168 | 169 | 170 | def write_data(data_fname, data_dict, use_json=False, compression=None): 171 | """Write data in HD5F format. 172 | 173 | Args: 174 | data_fname: The filename of teh file in which to write the data. 175 | data_dict: The dictionary of data to write. The keys are strings 176 | and the values are numpy arrays. 177 | use_json (optional): human readable format for simple items 178 | compression (optional): The compression to use for h5py (disabled by 179 | default because the library borks on scalars, otherwise try 'gzip'). 180 | """ 181 | 182 | dir_name = os.path.dirname(data_fname) 183 | if not os.path.exists(dir_name): 184 | os.makedirs(dir_name) 185 | 186 | if use_json: 187 | the_file = open(data_fname,'w') 188 | json.dump(data_dict, the_file) 189 | the_file.close() 190 | else: 191 | try: 192 | with h5py.File(data_fname, 'w') as hf: 193 | for k, v in data_dict.items(): 194 | clean_k = k.replace('/', '_') 195 | if clean_k is not k: 196 | print('Warning: saving variable with name: ', k, ' as ', clean_k) 197 | else: 198 | print('Saving variable with name: ', clean_k) 199 | hf.create_dataset(clean_k, data=v, compression=compression) 200 | except IOError: 201 | print("Cannot open %s for writing.", data_fname) 202 | raise 203 | 204 | 205 | def read_data(data_fname): 206 | """ Read saved data in HDF5 format. 207 | 208 | Args: 209 | data_fname: The filename of the file from which to read the data. 210 | Returns: 211 | A dictionary whose keys will vary depending on dataset (but should 212 | always contain the keys 'train_data' and 'valid_data') and whose 213 | values are numpy arrays. 214 | """ 215 | 216 | try: 217 | with h5py.File(data_fname, 'r') as hf: 218 | data_dict = {k: np.array(v) for k, v in hf.items()} 219 | return data_dict 220 | except IOError: 221 | print("Cannot open %s for reading." % data_fname) 222 | raise 223 | 224 | 225 | def write_datasets(data_path, data_fname_stem, dataset_dict, compression=None): 226 | """Write datasets in HD5F format. 227 | 228 | This function assumes the dataset_dict is a mapping ( string -> 229 | to data_dict ). 
It calls write_data for each data dictionary, 230 | post-fixing the data filename with the key of the dataset. 231 | 232 | Args: 233 | data_path: The path to the save directory. 234 | data_fname_stem: The filename stem of the file in which to write the data. 235 | dataset_dict: The dictionary of datasets. The keys are strings 236 | and the values data dictionaries (str -> numpy arrays) associations. 237 | compression (optional): The compression to use for h5py (disabled by 238 | default because the library borks on scalars, otherwise try 'gzip'). 239 | """ 240 | 241 | full_name_stem = os.path.join(data_path, data_fname_stem) 242 | for s, data_dict in dataset_dict.items(): 243 | write_data(full_name_stem + "_" + s, data_dict, compression=compression) 244 | 245 | 246 | def read_datasets(data_path, data_fname_stem): 247 | """Read dataset sin HD5F format. 248 | 249 | This function assumes the dataset_dict is a mapping ( string -> 250 | to data_dict ). It calls write_data for each data dictionary, 251 | post-fixing the data filename with the key of the dataset. 252 | 253 | Args: 254 | data_path: The path to the save directory. 255 | data_fname_stem: The filename stem of the file in which to write the data. 256 | """ 257 | 258 | dataset_dict = {} 259 | fnames = os.listdir(data_path) 260 | 261 | print ('loading data from ' + data_path + ' with stem ' + data_fname_stem) 262 | for fname in fnames: 263 | if fname.startswith(data_fname_stem): 264 | data_dict = read_data(os.path.join(data_path,fname)) 265 | idx = len(data_fname_stem) + 1 266 | key = fname[idx:] 267 | data_dict['data_dim'] = data_dict['train_data'].shape[2] 268 | data_dict['num_steps'] = data_dict['train_data'].shape[1] 269 | dataset_dict[key] = data_dict 270 | 271 | if len(dataset_dict) == 0: 272 | raise ValueError("Failed to load any datasets, are you sure that the " 273 | "'--data_dir' and '--data_filename_stem' flag values " 274 | "are correct?") 275 | 276 | print (str(len(dataset_dict)) + ' datasets loaded') 277 | return dataset_dict 278 | 279 | 280 | # NUMPY utility functions 281 | def list_t_bxn_to_list_b_txn(values_t_bxn): 282 | """Convert a length T list of BxN numpy tensors of length B list of TxN numpy 283 | tensors. 284 | 285 | Args: 286 | values_t_bxn: The length T list of BxN numpy tensors. 287 | 288 | Returns: 289 | The length B list of TxN numpy tensors. 290 | """ 291 | T = len(values_t_bxn) 292 | B, N = values_t_bxn[0].shape 293 | values_b_txn = [] 294 | for b in range(B): 295 | values_pb_txn = np.zeros([T,N]) 296 | for t in range(T): 297 | values_pb_txn[t,:] = values_t_bxn[t][b,:] 298 | values_b_txn.append(values_pb_txn) 299 | 300 | return values_b_txn 301 | 302 | 303 | def list_t_bxn_to_tensor_bxtxn(values_t_bxn): 304 | """Convert a length T list of BxN numpy tensors to single numpy tensor with 305 | shape BxTxN. 306 | 307 | Args: 308 | values_t_bxn: The length T list of BxN numpy tensors. 309 | 310 | Returns: 311 | values_bxtxn: The BxTxN numpy tensor. 312 | """ 313 | 314 | T = len(values_t_bxn) 315 | B, N = values_t_bxn[0].shape 316 | values_bxtxn = np.zeros([B,T,N]) 317 | for t in range(T): 318 | values_bxtxn[:,t,:] = values_t_bxn[t] 319 | 320 | return values_bxtxn 321 | 322 | 323 | def tensor_bxtxn_to_list_t_bxn(tensor_bxtxn): 324 | """Convert a numpy tensor with shape BxTxN to a length T list of numpy tensors 325 | with shape BxT. 326 | 327 | Args: 328 | tensor_bxtxn: The BxTxN numpy tensor. 329 | 330 | Returns: 331 | A length T list of numpy tensors with shape BxT. 
332 | """ 333 | 334 | values_t_bxn = [] 335 | B, T, N = tensor_bxtxn.shape 336 | for t in range(T): 337 | values_t_bxn.append(np.squeeze(tensor_bxtxn[:,t,:])) 338 | 339 | return values_t_bxn 340 | 341 | 342 | def flatten(list_of_lists): 343 | """Takes a list of lists and returns a list of the elements. 344 | 345 | Args: 346 | list_of_lists: List of lists. 347 | 348 | Returns: 349 | flat_list: Flattened list. 350 | flat_list_idxs: Flattened list indices. 351 | """ 352 | flat_list = [] 353 | flat_list_idxs = [] 354 | start_idx = 0 355 | for item in list_of_lists: 356 | if isinstance(item, list): 357 | flat_list += item 358 | l = len(item) 359 | idxs = range(start_idx, start_idx+l) 360 | start_idx = start_idx+l 361 | else: # a value 362 | flat_list.append(item) 363 | idxs = [start_idx] 364 | start_idx += 1 365 | flat_list_idxs.append(idxs) 366 | 367 | return flat_list, flat_list_idxs 368 | -------------------------------------------------------------------------------- /data/lfads_lorenz.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d0766fc6e7d1e59be09827fcb839827bcd6904f91f1056975672b4e4b14266e0 3 | size 36207696 4 | -------------------------------------------------------------------------------- /data/lorenz.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Joel Ye 3 | 4 | r""" Notebook for generating Lorenz attractor data. 5 | Note: 6 | This is an example reference, 7 | and does not reproduce the data with repeated conditions used in the paper. 8 | That was generated with LFADS utilities, which we are unable to release at this time. 9 | """ 10 | 11 | #%% 12 | import os 13 | import os.path as osp 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | # This import registers the 3D projection, but is otherwise unused. 
17 | from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import 18 | import torch 19 | import torch.nn as nn 20 | from torch.distributions import Poisson 21 | torch.set_grad_enabled(False) 22 | 23 | default_s = 10 # sigma 24 | default_r = 28 # rho 25 | default_b = 2.667 # beta 26 | 27 | def lorenz_deriv(x, y, z, s=default_s, r=default_r, b=default_b): 28 | ''' 29 | Given: 30 | x, y, z: a point of interest in three dimensional space 31 | s, r, b: parameters defining the lorenz attractor 32 | Returns: 33 | x_dot, y_dot, z_dot: values of the lorenz attractor's partial 34 | derivatives at the point x, y, z 35 | ''' 36 | x_dot = s*(y - x) 37 | y_dot = r*x - y - x*z 38 | z_dot = x*y - b*z 39 | return x_dot, y_dot, z_dot 40 | 41 | def lorenz_generator( 42 | initial_values, 43 | s=default_s, r=default_r, b=default_b, 44 | dt=0.01, num_steps=10000 45 | ): 46 | TRANSIENT_TIME = 50 # dock this 47 | values = np.empty((3, num_steps + TRANSIENT_TIME + 1)) 48 | values[:, 0] = np.array(initial_values) 49 | # Step through "time", calculating the partial derivatives at the current point 50 | # and using them to estimate the next point 51 | for i in range(num_steps + TRANSIENT_TIME): 52 | x_dot, y_dot, z_dot = lorenz_deriv( 53 | values[0, i], 54 | values[1, i], 55 | values[2, i], 56 | s, r, b 57 | ) 58 | values[0, i + 1] = values[0, i] + (x_dot * dt) 59 | values[1, i + 1] = values[1, i] + (y_dot * dt) 60 | values[2, i + 1] = values[2, i] + (z_dot * dt) 61 | return values.T[TRANSIENT_TIME:] 62 | 63 | 64 | #%% 65 | initial_values = (0., 1., 1.05) 66 | lorenz_data = lorenz_generator(initial_values, num_steps=5000).T 67 | # Plot 68 | fig = plt.figure() 69 | ax = fig.gca(projection='3d') 70 | 71 | ax.plot(*lorenz_data, lw=0.5) 72 | ax.set_xlabel("X Axis") 73 | ax.set_ylabel("Y Axis") 74 | ax.set_zlabel("Z Axis") 75 | ax.set_title("Lorenz Attractor") 76 | 77 | plt.show() 78 | 79 | #%% 80 | 81 | # Data HPs 82 | seed = 0 83 | np.random.seed(seed) 84 | n = 2000 85 | t = 200 86 | initial_values = np.random.random((n, 3)) 87 | # initial_values = np.array([0., 1., 1.05])[None, :].repeat(n, axis=0) 88 | # print(initial_values[0]) 89 | lorenz_params = np.array([default_s, default_r, default_b]) \ 90 | + np.random.random((n, 3)) * 10 # Don't deviate too far from chaotic params 91 | 92 | # Generate data 93 | lorenz_dataset = np.empty((n, t, 3)) 94 | for i in range(n): 95 | s, r, b = lorenz_params[i] 96 | lorenz_dataset[i] = lorenz_generator( 97 | initial_values[i], s=s, r=r, b=b, 98 | num_steps=t # Other values tend to be bad 99 | )[1:] # Drops initial value 100 | 101 | noise = np.random.random((n, t, 3)) * 0 # No noise 102 | 103 | trial_min = lorenz_dataset.min(axis=1) 104 | trial_max = lorenz_dataset.max(axis=1) 105 | lorenz_dataset -= trial_min[:, None, :] 106 | lorenz_dataset /= trial_max[:, None, :] 107 | lorenz_dataset += noise 108 | #%% 109 | torch.manual_seed(seed) 110 | # Convert factors to spike times 111 | n_neurons = 96 112 | SPARSE = False 113 | # What. 
How do we get this to look normal 114 | def lorenz_rates_to_rates_and_spike_times(dataset, base=-10, dynamic_range=(-1, 0)): 115 | # lorenz b x t x 3 -> b x t x h (h neurons) 116 | linear_transform = nn.Linear(3, n_neurons) # Generates a random transform 117 | linear_transform.weight.data.uniform_(*dynamic_range) 118 | # base = torch.log(torch.tensor(base) / 1000.0) 119 | linear_transform.bias.data.fill_(base) 120 | raw_logrates = linear_transform(torch.tensor(dataset).float()) 121 | if SPARSE: 122 | # raw_logrates /= raw_logrates.max() # make max firing lograte 0.5 ~ 1.65 firing rate 123 | clipped_logrates = torch.clamp(raw_logrates, -10, 1) # Double take 124 | else: 125 | clipped_logrates = torch.clamp(raw_logrates, -20, 20) # Double take 126 | # Leads to spikes mostly in 1-5 127 | neuron_rates = torch.exp(clipped_logrates) 128 | return Poisson(neuron_rates).sample(), clipped_logrates 129 | # default mean firing rate is 5 130 | spike_timings, logrates = lorenz_rates_to_rates_and_spike_times(lorenz_dataset) 131 | print(torch.exp(logrates).mean()) 132 | plt.hist(torch.exp(logrates.flatten()).numpy()) 133 | if SPARSE: 134 | plt.yscale("log") 135 | plt.hist(spike_timings.flatten().numpy(), bins=4) 136 | #%% 137 | # 2000 138 | train_n = 160 139 | val_n = int((n - train_n) / 2) 140 | test_n = int((n - train_n) / 2) 141 | splits = (train_n, val_n, test_n) 142 | assert sum(splits) <= n, "Not enough data trials" 143 | 144 | train_rates, val_rates, test_rates = torch.split(logrates, splits) 145 | train_spikes, val_spikes, test_spikes = \ 146 | torch.split(spike_timings, splits) 147 | 148 | 149 | def pack_dataset(rates, spikes): 150 | return { 151 | "rates": rates, 152 | "spikes": spikes.long() 153 | } 154 | 155 | data_dir = "/snel/share/data/pth_lorenz_large" 156 | os.makedirs(data_dir, exist_ok=True) 157 | 158 | torch.save(pack_dataset(train_rates, train_spikes), osp.join(data_dir, "train.pth")) 159 | torch.save(pack_dataset(val_rates, val_spikes), osp.join(data_dir, "val.pth")) 160 | torch.save(pack_dataset(test_rates, test_spikes), osp.join(data_dir, "test.pth")) 161 | 162 | #%% 163 | # Plot trajectories 164 | np.random.seed(0) 165 | fig = plt.figure() 166 | NUM_PLOTS = 6 167 | for i in range(NUM_PLOTS): 168 | sample_idx = np.random.randint(0, train_n) 169 | sample = lorenz_dataset[sample_idx].T 170 | ax = fig.add_subplot(231 + i, projection='3d') 171 | color = np.linspace(0, 1, sample.shape[1]) 172 | ax.scatter(*sample, lw=0.5, c=color) 173 | 174 | ax.set_title(f"{sample_idx}") 175 | fig.tight_layout() 176 | 177 | plt.show() 178 | 179 | #%% 180 | # Plot rates 181 | np.random.seed(0) 182 | fig = plt.figure() 183 | NUM_PLOTS = 8 184 | print(f"Train \t {train_rates[0].max().item()}") 185 | print(f"Val \t {val_rates[0].max().item()}") 186 | plot_t = 100 187 | for i in range(NUM_PLOTS): 188 | sample_idx = np.random.randint(0, n_neurons) 189 | sample = torch.exp(train_rates[0, :plot_t, sample_idx]).numpy() 190 | ax = fig.add_subplot(421 + i) 191 | ax.plot(np.arange(plot_t), sample, ) 192 | # ax.scatter(np.arange(t), sample, c=np.arange(0, 1, 1.0/t)) 193 | 194 | ax.set_title(f"{sample_idx}") 195 | plt.suptitle("Rates") 196 | fig.tight_layout() 197 | plt.show() 198 | 199 | #%% 200 | # Plot spike times 201 | np.random.seed(0) 202 | fig = plt.figure() 203 | NUM_PLOTS = 8 204 | print(f"Train \t {train_spikes[0].max().item()}") 205 | print(f"Val \t {val_spikes[0].max().item()}") 206 | plot_t = 50 207 | for i in range(NUM_PLOTS): 208 | sample_idx = np.random.randint(0, n_neurons) 209 | plot_data = 
train_spikes[0, :plot_t, sample_idx] 210 | sample = plot_data.numpy() 211 | ax = fig.add_subplot(421 + i) 212 | ax.plot(np.arange(plot_t), sample, ) 213 | # ax.scatter(np.arange(t), sample, c=np.arange(0, 1, 1.0/t)) 214 | 215 | ax.set_title(f"{sample_idx}") 216 | plt.suptitle("Spikes") 217 | fig.tight_layout() 218 | plt.show() 219 | 220 | 221 | #%% 222 | lfads_data = "/snel/share/data/lfads_lorenz_20ms" 223 | import h5py 224 | h5file = f"{lfads_data}/lfads_dataset001.h5" 225 | with h5py.File(h5file, "r") as f: 226 | h5dict = {key: f[key][()] for key in f.keys()} 227 | if 'train_truth' in h5dict: 228 | cf = h5dict['conversion_factor'] 229 | train_truth = h5dict['train_truth'].astype(np.float32) / cf 230 | data = torch.tensor(train_truth) 231 | n_neurons = data.shape[-1] 232 | plt.hist(data.flatten().numpy(), bins=50) -------------------------------------------------------------------------------- /defaults.py: -------------------------------------------------------------------------------- 1 | # Src: Andrew's tune_tf2 2 | 3 | from os import path 4 | 5 | repo_path = path.dirname(path.realpath(__file__)) 6 | DEFAULT_CONFIG_DIR = path.join(repo_path, 'configs') 7 | 8 | # contains data about general status of PBT optimization 9 | PBT_CSV = 'pbt_state.csv' 10 | # contains data about which models are exploited 11 | EXPLOIT_CSV = 'exploits.csv' 12 | # contains data about which hyperparameters are used 13 | HPS_CSV= 'gen_config.csv' 14 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ndt_old 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - attrs=19.3.0=py_0 10 | - backcall=0.1.0=py36_0 11 | - basemap=1.2.0=py36h705c2d8_0 12 | - blas=1.0=mkl 13 | - bleach=3.1.4=py_0 14 | - ca-certificates=2020.10.14=0 15 | - certifi=2020.6.20=py36_0 16 | - cudatoolkit=10.1.243=h6bb024c_0 17 | - cycler=0.10.0=py36_0 18 | - dbus=1.13.12=h746ee38_0 19 | - decorator=4.4.2=py_0 20 | - defusedxml=0.6.0=py_0 21 | - entrypoints=0.3=py36_0 22 | - expat=2.2.6=he6710b0_0 23 | - fontconfig=2.13.0=h9420a91_0 24 | - freetype=2.9.1=h8a8886c_1 25 | - geos=3.6.2=heeff764_2 26 | - git-lfs=2.6.1=h7b6447c_0 27 | - glib=2.56.2=hd408876_0 28 | - gmp=6.1.2=h6c8ec71_1 29 | - gst-plugins-base=1.14.0=hbbd80ab_1 30 | - gstreamer=1.14.0=hb453b48_1 31 | - icu=58.2=h211956c_0 32 | - importlib_metadata=1.5.0=py36_0 33 | - intel-openmp=2020.0=166 34 | - ipykernel=5.1.4=py36h39e3cac_0 35 | - ipython=7.13.0=py36h5ca1d4c_0 36 | - ipython_genutils=0.2.0=py36_0 37 | - ipywidgets=7.5.1=py_0 38 | - jedi=0.17.0=py36_0 39 | - jinja2=2.11.2=py_0 40 | - joblib=0.14.1=py_0 41 | - jpeg=9b=h024ee3a_2 42 | - jsonschema=3.2.0=py36_0 43 | - jupyter=1.0.0=py36_7 44 | - jupyter_client=6.1.3=py_0 45 | - jupyter_console=6.1.0=py_0 46 | - jupyter_core=4.6.3=py36_0 47 | - kiwisolver=1.1.0=py36he6710b0_0 48 | - ld_impl_linux-64=2.33.1=h53a641e_7 49 | - libedit=3.1.20181209=hc058e9b_0 50 | - libffi=3.2.1=hd88cf55_4 51 | - libgcc-ng=9.1.0=hdf63c60_0 52 | - libgfortran-ng=7.3.0=hdf63c60_0 53 | - libpng=1.6.37=hbc83047_0 54 | - libsodium=1.0.16=h1bed415_0 55 | - libstdcxx-ng=9.1.0=hdf63c60_0 56 | - libtiff=4.1.0=h2733197_0 57 | - libuuid=1.0.3=h1bed415_2 58 | - libxcb=1.13=h1bed415_1 59 | - libxml2=2.9.9=hea5a465_1 60 | - markupsafe=1.1.1=py36h7b6447c_0 61 | - matplotlib=3.1.3=py36_0 62 | - matplotlib-base=3.1.3=py36hef1b27d_0 63 | - 
mistune=0.8.4=py36h7b6447c_0 64 | - mkl=2020.0=166 65 | - mkl-service=2.3.0=py36he904b0f_0 66 | - mkl_fft=1.0.15=py36ha843d7b_0 67 | - mkl_random=1.1.0=py36hd6b4f25_0 68 | - nbconvert=5.6.1=py36_0 69 | - nbformat=5.0.4=py_0 70 | - ncurses=6.2=he6710b0_1 71 | - ninja=1.9.0=py36hfd86e86_0 72 | - notebook=6.0.3=py36_0 73 | - numpy=1.18.1=py36h4f9e942_0 74 | - numpy-base=1.18.1=py36hde5b4d6_1 75 | - olefile=0.46=py36_0 76 | - openssl=1.1.1h=h7b6447c_0 77 | - pandas=1.0.4=py36h0573a6f_0 78 | - pandoc=2.2.3.2=0 79 | - pandocfilters=1.4.2=py36_1 80 | - parso=0.7.0=py_0 81 | - pcre=8.43=he6710b0_0 82 | - pexpect=4.8.0=py36_0 83 | - pickleshare=0.7.5=py36_0 84 | - pillow=7.1.2=py36hb39fc2d_0 85 | - pip=20.0.2=py36_1 86 | - proj=6.2.1=haa6030c_0 87 | - prometheus_client=0.7.1=py_0 88 | - prompt-toolkit=3.0.4=py_0 89 | - prompt_toolkit=3.0.4=0 90 | - ptyprocess=0.6.0=py36_0 91 | - pygments=2.6.1=py_0 92 | - pyparsing=2.4.6=py_0 93 | - pyproj=2.6.0=py36hd003209_1 94 | - pyqt=5.9.2=py36h22d08a2_1 95 | - pyrsistent=0.16.0=py36h7b6447c_0 96 | - pyshp=2.1.0=py_0 97 | - python=3.6.10=hcf32534_1 98 | - python-dateutil=2.8.1=py_0 99 | - python_abi=3.6=1_cp36m 100 | - pytorch=1.5.0=py3.6_cuda10.1.243_cudnn7.6.3_0 101 | - pytz=2020.1=py_0 102 | - pyzmq=18.1.1=py36he6710b0_0 103 | - qt=5.9.7=h5867ecd_1 104 | - qtconsole=4.7.3=py_0 105 | - qtpy=1.9.0=py_0 106 | - readline=8.0=h7b6447c_0 107 | - scikit-learn=0.22.1=py36hd81dba3_0 108 | - scipy=1.4.1=py36h0b6359f_0 109 | - seaborn=0.10.1=py_0 110 | - send2trash=1.5.0=py36_0 111 | - setuptools=46.1.3=py36_0 112 | - sip=4.19.13=py36he6710b0_0 113 | - six=1.14.0=py36_0 114 | - sqlite=3.31.1=h62c20be_1 115 | - terminado=0.8.3=py36_0 116 | - testpath=0.4.4=py_0 117 | - tk=8.6.8=hbc83047_0 118 | - torchvision=0.6.0=py36_cu101 119 | - tornado=6.0.4=py36h7b6447c_1 120 | - traitlets=4.3.3=py36_0 121 | - wcwidth=0.1.9=py_0 122 | - webencodings=0.5.1=py36_1 123 | - wheel=0.34.2=py36_0 124 | - widgetsnbextension=3.5.1=py36_0 125 | - xz=5.2.5=h7b6447c_0 126 | - yaml=0.1.7=h96e3832_1 127 | - zeromq=4.3.1=he6710b0_3 128 | - zipp=3.1.0=py_0 129 | - zlib=1.2.11=h7b6447c_3 130 | - zstd=1.3.7=h0b5b093_0 131 | - pip: 132 | - absl-py==0.9.0 133 | - aiohttp==3.6.2 134 | - aiohttp-cors==0.7.0 135 | - aioredis==1.3.1 136 | - astor==0.8.1 137 | - async-timeout==3.0.1 138 | - beautifulsoup4==4.9.2 139 | - blessings==1.7 140 | - boto3==1.13.26 141 | - botocore==1.16.26 142 | - cachetools==4.1.0 143 | - chardet==3.0.4 144 | - click==7.1.2 145 | - colorama==0.4.3 146 | - colorful==0.5.4 147 | - contextvars==2.4 148 | - dataclasses==0.7 149 | - docutils==0.15.2 150 | - filelock==3.0.12 151 | - gast==0.2.2 152 | - google==3.0.0 153 | - google-api-core==1.22.2 154 | - google-auth==1.22.0 155 | - google-auth-oauthlib==0.4.1 156 | - google-pasta==0.2.0 157 | - googleapis-common-protos==1.52.0 158 | - gpustat==0.6.0 159 | - grpcio==1.28.1 160 | - h5py==2.10.0 161 | - hiredis==1.1.0 162 | - idna==2.9 163 | - idna-ssl==1.1.0 164 | - immutables==0.14 165 | - jmespath==0.10.0 166 | - keras-applications==1.0.8 167 | - keras-preprocessing==1.1.0 168 | - markdown==3.2.1 169 | - msgpack==1.0.0 170 | - multidict==4.7.6 171 | - nvidia-ml-py3==7.352.0 172 | - oauthlib==3.1.0 173 | - opencensus==0.7.10 174 | - opencensus-context==0.1.1 175 | - opt-einsum==3.2.1 176 | - protobuf==3.11.3 177 | - psutil==5.7.2 178 | - py-spy==0.3.3 179 | - pyasn1==0.4.8 180 | - pyasn1-modules==0.2.8 181 | - pytorch-transformers==1.2.0 182 | - pyyaml==5.3.1 183 | - ray==0.8.7 184 | - redis==3.4.1 185 | - regex==2020.6.8 186 | - 
requests==2.23.0 187 | - requests-oauthlib==1.3.0 188 | - rsa==4.0 189 | - s3transfer==0.3.3 190 | - sacremoses==0.0.43 191 | - sentencepiece==0.1.91 192 | - soupsieve==2.0.1 193 | - tabulate==0.8.7 194 | - tensorboard==2.2.0 195 | - tensorboard-plugin-wit==1.6.0.post3 196 | - tensorboardx==2.1 197 | - tensorflow==2.1.0 198 | - tensorflow-estimator==2.1.0 199 | - termcolor==1.1.0 200 | - tqdm==4.46.1 201 | - typing-extensions==3.7.4.3 202 | - urllib3==1.25.9 203 | - werkzeug==1.0.1 204 | - wrapt==1.12.1 205 | - yacs==0.1.7 206 | - yarl==1.6.0 207 | prefix: /snel/home/joely/anaconda3/envs/pytorch 208 | 209 | -------------------------------------------------------------------------------- /nlb.yml: -------------------------------------------------------------------------------- 1 | name: ndt 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=4.5=1_gnu 8 | - blas=1.0=mkl 9 | - bottleneck=1.3.2=py37heb32a55_1 10 | - brotli=1.0.9=he6710b0_2 11 | - bzip2=1.0.8=h7b6447c_0 12 | - ca-certificates=2021.10.26=h06a4308_2 13 | - certifi=2021.10.8=py37h06a4308_2 14 | - cudatoolkit=10.2.89=hfd86e86_1 15 | - cycler=0.11.0=pyhd3eb1b0_0 16 | - dbus=1.13.18=hb2f20db_0 17 | - expat=2.4.4=h295c915_0 18 | - ffmpeg=4.3=hf484d3e_0 19 | - fontconfig=2.13.1=h6c09931_0 20 | - fonttools=4.25.0=pyhd3eb1b0_0 21 | - freetype=2.11.0=h70c0345_0 22 | - giflib=5.2.1=h7b6447c_0 23 | - glib=2.69.1=h4ff587b_1 24 | - gmp=6.2.1=h2531618_2 25 | - gnutls=3.6.15=he1e5248_0 26 | - gst-plugins-base=1.14.0=h8213a91_2 27 | - gstreamer=1.14.0=h28cd5cc_2 28 | - icu=58.2=he6710b0_3 29 | - intel-openmp=2021.4.0=h06a4308_3561 30 | - jpeg=9d=h7f8727e_0 31 | - kiwisolver=1.3.2=py37h295c915_0 32 | - lame=3.100=h7b6447c_0 33 | - lcms2=2.12=h3be6417_0 34 | - ld_impl_linux-64=2.35.1=h7274673_9 35 | - libffi=3.3=he6710b0_2 36 | - libgcc-ng=9.3.0=h5101ec6_17 37 | - libgfortran-ng=7.5.0=ha8ba4b0_17 38 | - libgfortran4=7.5.0=ha8ba4b0_17 39 | - libgomp=9.3.0=h5101ec6_17 40 | - libiconv=1.15=h63c8f33_5 41 | - libidn2=2.3.2=h7f8727e_0 42 | - libpng=1.6.37=hbc83047_0 43 | - libstdcxx-ng=9.3.0=hd4cf53a_17 44 | - libtasn1=4.16.0=h27cfd23_0 45 | - libtiff=4.2.0=h85742a9_0 46 | - libunistring=0.9.10=h27cfd23_0 47 | - libuuid=1.0.3=h7f8727e_2 48 | - libuv=1.40.0=h7b6447c_0 49 | - libwebp=1.2.0=h89dd481_0 50 | - libwebp-base=1.2.0=h27cfd23_0 51 | - libxcb=1.14=h7b6447c_0 52 | - libxml2=2.9.12=h03d6c58_0 53 | - lz4-c=1.9.3=h295c915_1 54 | - matplotlib=3.5.1=py37h06a4308_0 55 | - matplotlib-base=3.5.1=py37ha18d171_0 56 | - mkl=2021.4.0=h06a4308_640 57 | - mkl-service=2.4.0=py37h7f8727e_0 58 | - mkl_fft=1.3.1=py37hd3c417c_0 59 | - mkl_random=1.2.2=py37h51133e4_0 60 | - munkres=1.1.4=py_0 61 | - ncurses=6.3=h7f8727e_2 62 | - nettle=3.7.3=hbbd107a_1 63 | - numexpr=2.8.1=py37h6abb31d_0 64 | - numpy=1.21.2=py37h20f2e39_0 65 | - numpy-base=1.21.2=py37h79a1101_0 66 | - olefile=0.46=py37_0 67 | - openh264=2.1.1=h4ff587b_0 68 | - openssl=1.1.1m=h7f8727e_0 69 | - packaging=21.3=pyhd3eb1b0_0 70 | - pcre=8.45=h295c915_0 71 | - pillow=8.4.0=py37h5aabda8_0 72 | - pip=21.2.2=py37h06a4308_0 73 | - pyqt=5.9.2=py37h05f1152_2 74 | - python=3.7.11=h12debd9_0 75 | - python-dateutil=2.8.2=pyhd3eb1b0_0 76 | - pytorch=1.10.2=py3.7_cuda10.2_cudnn7.6.5_0 77 | - pytorch-mutex=1.0=cuda 78 | - qt=5.9.7=h5867ecd_1 79 | - readline=8.1.2=h7f8727e_1 80 | - scipy=1.7.3=py37hc147768_0 81 | - seaborn=0.11.2=pyhd3eb1b0_0 82 | - setuptools=58.0.4=py37h06a4308_0 83 | - sip=4.19.8=py37hf484d3e_0 84 | - six=1.16.0=pyhd3eb1b0_1 85 | - 
sqlite=3.37.2=hc218d9a_0 86 | - tk=8.6.11=h1ccaba5_0 87 | - torchaudio=0.10.2=py37_cu102 88 | - torchvision=0.11.3=py37_cu102 89 | - tornado=6.1=py37h27cfd23_0 90 | - typing_extensions=3.10.0.2=pyh06a4308_0 91 | - wheel=0.37.1=pyhd3eb1b0_0 92 | - xz=5.2.5=h7b6447c_0 93 | - zlib=1.2.11=h7f8727e_4 94 | - zstd=1.4.9=haebb681_0 95 | - pip: 96 | - absl-py==1.0.0 97 | - appdirs==1.4.4 98 | - asciitree==0.3.3 99 | - attrs==21.4.0 100 | - blessings==1.7 101 | - boto3==1.21.4 102 | - botocore==1.24.5 103 | - cached-property==1.5.2 104 | - cachetools==5.0.0 105 | - cffi==1.15.0 106 | - charset-normalizer==2.0.12 107 | - ci-info==0.2.0 108 | - click==8.0.4 109 | - click-didyoumean==0.3.0 110 | - cryptography==36.0.1 111 | - dandi==0.36.0 112 | - dandischema==0.5.2 113 | - deprecated==1.2.13 114 | - dnspython==2.2.0 115 | - email-validator==1.1.3 116 | - etelemetry==0.3.0 117 | - fasteners==0.17.3 118 | - filelock==3.6.0 119 | - fscacher==0.2.0 120 | - google-auth==2.6.0 121 | - google-auth-oauthlib==0.4.6 122 | - grpcio==1.44.0 123 | - h5py==3.6.0 124 | - hdmf==3.1.1 125 | - humanize==4.0.0 126 | - idna==3.3 127 | - importlib-metadata==4.11.1 128 | - interleave==0.2.0 129 | - jeepney==0.7.1 130 | - jmespath==0.10.0 131 | - joblib==1.1.0 132 | - jsonpointer==2.2 133 | - jsonschema==3.2.0 134 | - keyring==23.5.0 135 | - keyrings-alt==4.1.0 136 | - markdown==3.3.6 137 | - msgpack==1.0.3 138 | - numcodecs==0.9.1 139 | - oauthlib==3.2.0 140 | - pandas==1.3.5 141 | - protobuf==3.19.4 142 | - pyasn1==0.4.8 143 | - pyasn1-modules==0.2.8 144 | - pycparser==2.21 145 | - pycryptodomex==3.14.1 146 | - pydantic==1.9.0 147 | - pynwb==2.0.0 148 | - pyout==0.7.2 149 | - pyparsing==3.0.7 150 | - pyrsistent==0.18.1 151 | - pytorch-transformers==1.2.0 152 | - pytz==2021.3 153 | - pyyaml==6.0 154 | - ray==1.10.0 155 | - redis==4.1.4 156 | - regex==2022.1.18 157 | - requests==2.27.1 158 | - requests-oauthlib==1.3.1 159 | - rfc3987==1.3.8 160 | - rsa==4.8 161 | - ruamel-yaml==0.17.21 162 | - ruamel-yaml-clib==0.2.6 163 | - s3transfer==0.5.1 164 | - sacremoses==0.0.47 165 | - scikit-learn==1.0.2 166 | - secretstorage==3.3.1 167 | - semantic-version==2.9.0 168 | - sentencepiece==0.1.96 169 | - sklearn==0.0 170 | - strict-rfc3339==0.7 171 | - tabulate==0.8.9 172 | - tenacity==8.0.1 173 | - tensorboard==2.8.0 174 | - tensorboard-data-server==0.6.1 175 | - tensorboard-plugin-wit==1.8.1 176 | - tensorboardx==2.5 177 | - threadpoolctl==3.1.0 178 | - tqdm==4.62.3 179 | - urllib3==1.26.8 180 | - webcolors==1.11.1 181 | - werkzeug==2.0.3 182 | - wrapt==1.13.3 183 | - yacs==0.1.8 184 | - zarr==2.11.0 185 | - zipp==3.7.0 186 | prefix: /home/joelye/anaconda3/envs/ndt 187 | -------------------------------------------------------------------------------- /ray_get_lfve.py: -------------------------------------------------------------------------------- 1 | """ 2 | posthoc script to create a directory for "best lfve" 3 | """ 4 | 5 | from typing import List, Union 6 | from os import path 7 | import json 8 | import argparse 9 | import ray, yaml, shutil 10 | from ray import tune 11 | import torch 12 | 13 | from tune_models import tuneNDT 14 | 15 | from defaults import DEFAULT_CONFIG_DIR 16 | from src.config.default import flatten 17 | 18 | PBT_HOME = path.expanduser('~/ray_results/ndt/gridsearch') 19 | OVERWRITE = True 20 | PBT_METRIC = 'smth_masked_loss' 21 | BEST_MODEL_METRIC = 'best_masked_loss' 22 | LOGGED_COLUMNS = ['smth_masked_loss', 'masked_loss', 'r2', 'unmasked_loss'] 23 | 24 | DEFAULT_HP_DICT = { 25 | 'TRAIN.WEIGHT_DECAY': 
tune.loguniform(1e-8, 1e-3), 26 | 'TRAIN.MASK_RATIO': tune.uniform(0.1, 0.4) 27 | } 28 | 29 | def get_parser(): 30 | parser = argparse.ArgumentParser() 31 | 32 | parser.add_argument( 33 | "--exp-config", "-e", 34 | type=str, 35 | required=True, 36 | help="path to config yaml containing info about experiment", 37 | ) 38 | 39 | parser.add_argument('--eval-only', '-ev', dest='eval_only', action='store_true') 40 | parser.add_argument('--no-eval-only', '-nev', dest='eval_only', action='store_false') 41 | parser.set_defaults(eval_only=False) 42 | 43 | parser.add_argument( 44 | "--name", "-n", 45 | type=str, 46 | default="", 47 | help="defaults to exp filename" 48 | ) 49 | 50 | parser.add_argument( 51 | "--gpus-per-worker", "-g", 52 | type=float, 53 | default=0.5 54 | ) 55 | 56 | parser.add_argument( 57 | "--cpus-per-worker", "-c", 58 | type=float, 59 | default=3.0 60 | ) 61 | 62 | parser.add_argument( 63 | "--workers", "-w", 64 | type=int, 65 | default=-1, 66 | help="-1 indicates -- use max possible workers on machine (assuming 0.5 GPUs per trial)" 67 | ) 68 | 69 | parser.add_argument( 70 | "--samples", "-s", 71 | type=int, 72 | default=20, 73 | help="samples for random search" 74 | ) 75 | 76 | parser.add_argument( 77 | "--seed", "-d", 78 | type=int, 79 | default=-1, 80 | help="seed for config" 81 | ) 82 | 83 | return parser 84 | 85 | def main(): 86 | parser = get_parser() 87 | args = parser.parse_args() 88 | launch_search(**vars(args)) 89 | 90 | def build_hp_dict(raw_json: dict): 91 | hp_dict = {} 92 | for key in raw_json: 93 | info: dict = raw_json[key] 94 | sample_fn = info.get("sample_fn", "uniform") 95 | assert hasattr(tune, sample_fn) 96 | if sample_fn == "choice": 97 | hp_dict[key] = tune.choice(info['opts']) 98 | else: 99 | assert "low" in info, "high" in info 100 | sample_fn = getattr(tune, sample_fn) 101 | hp_dict[key] = sample_fn(info['low'], info['high']) 102 | return hp_dict 103 | 104 | def launch_search(exp_config: Union[List[str], str], name: str, workers: int, gpus_per_worker: float, cpus_per_worker: float, eval_only: bool, samples: int, seed: int) -> None: 105 | # ---------- PBT I/O CONFIGURATION ---------- 106 | # the directory to save PBT runs (usually '~/ray_results') 107 | 108 | if len(path.split(exp_config)[0]) > 0: 109 | CFG_PATH = exp_config 110 | else: 111 | CFG_PATH = path.join(DEFAULT_CONFIG_DIR, exp_config) 112 | variant_name = path.split(CFG_PATH)[1].split('.')[0] 113 | if seed > 0: 114 | variant_name = f"{variant_name}-s{seed}" 115 | if name == "": 116 | name = variant_name 117 | pbt_dir = path.join(PBT_HOME, name) 118 | # the name of this PBT run (run will be stored at `pbt_dir`) 119 | 120 | # --------------------------------------------- 121 | # * No train step 122 | # load the results dataframe for this run 123 | df = tune.Analysis( 124 | pbt_dir 125 | ).dataframe() 126 | df = df[df.logdir.apply(lambda path: not 'best_model' in path)] 127 | 128 | lfves = [] 129 | for logdir in df.logdir: 130 | ckpt = torch.load(path.join(logdir, f'ckpts/{variant_name}.lfve.pth'), map_location='cpu') 131 | lfves.append(ckpt['best_unmasked_val']['value']) 132 | df['best_unmasked_val'] = lfves 133 | best_model_logdir = df.loc[df['best_unmasked_val'].idxmin()].logdir 134 | best_model_dest = path.join(pbt_dir, 'best_model_unmasked') 135 | if path.exists(best_model_dest): 136 | shutil.rmtree(best_model_dest) 137 | shutil.copytree(best_model_logdir, best_model_dest) 138 | 139 | if __name__ == "__main__": 140 | main() 
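# A minimal sketch (not part of the original script) of the JSON format that
# build_hp_dict() above expects: each key is a flattened config field, and each
# value names a ray.tune sampler plus its arguments. The field names below are
# placeholders for illustration, not the contents of the shipped sweep files.
_example_sweep = {
    "TRAIN.LR.INIT": {"sample_fn": "loguniform", "low": 1e-5, "high": 1e-2},
    "TRAIN.MASK_RATIO": {"sample_fn": "uniform", "low": 0.1, "high": 0.4},
    "MODEL.NUM_LAYERS": {"sample_fn": "choice", "opts": [2, 4, 6]},
}
# build_hp_dict(_example_sweep) returns {key: tune sampler}; in ray_random.py the
# result is merged over DEFAULT_HP_DICT and folded into the flattened config
# handed to tune.run. This script itself is run post hoc, pointing
# -e/--exp-config at the same experiment YAML as the finished sweep, so it can
# load each trial's '<variant>.lfve.pth' checkpoint under PBT_HOME and copy the
# trial with the lowest 'best_unmasked_val' to '<sweep_dir>/best_model_unmasked'.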
-------------------------------------------------------------------------------- /ray_random.py: -------------------------------------------------------------------------------- 1 | # Src: Andrew's run_random_search in tune_tf2 2 | 3 | """ 4 | Run grid search for NDT. 5 | """ 6 | 7 | from typing import List, Union 8 | import os 9 | from os import path 10 | import json 11 | import argparse 12 | import ray, yaml, shutil 13 | from ray import tune 14 | import torch 15 | 16 | from tune_models import tuneNDT 17 | 18 | from defaults import DEFAULT_CONFIG_DIR 19 | from src.config.default import flatten 20 | 21 | PBT_HOME = path.expanduser('~/user_data/nlb/ndt_runs/ray/') 22 | os.makedirs(PBT_HOME, exist_ok=True) # ray hangs if this isn't true... 23 | OVERWRITE = True 24 | PBT_METRIC = 'smth_masked_loss' 25 | BEST_MODEL_METRIC = 'best_masked_loss' 26 | LOGGED_COLUMNS = ['smth_masked_loss', 'masked_loss', 'r2', 'unmasked_loss'] 27 | 28 | DEFAULT_HP_DICT = { 29 | 'TRAIN.WEIGHT_DECAY': tune.loguniform(1e-8, 1e-3), 30 | 'TRAIN.MASK_RATIO': tune.uniform(0.1, 0.4) 31 | } 32 | 33 | def get_parser(): 34 | parser = argparse.ArgumentParser() 35 | 36 | parser.add_argument( 37 | "--exp-config", "-e", 38 | type=str, 39 | required=True, 40 | help="path to config yaml containing info about experiment", 41 | ) 42 | 43 | parser.add_argument('--eval-only', '-ev', dest='eval_only', action='store_true') 44 | parser.add_argument('--no-eval-only', '-nev', dest='eval_only', action='store_false') 45 | parser.set_defaults(eval_only=False) 46 | 47 | parser.add_argument( 48 | "--name", "-n", 49 | type=str, 50 | default="", 51 | help="defaults to exp filename" 52 | ) 53 | 54 | parser.add_argument( 55 | "--gpus-per-worker", "-g", 56 | type=float, 57 | default=0.5 58 | ) 59 | 60 | parser.add_argument( 61 | "--cpus-per-worker", "-c", 62 | type=float, 63 | default=3.0 64 | ) 65 | 66 | parser.add_argument( 67 | "--workers", "-w", 68 | type=int, 69 | default=-1, 70 | help="-1 indicates -- use max possible workers on machine (assuming 0.5 GPUs per trial)" 71 | ) 72 | 73 | parser.add_argument( 74 | "--samples", "-s", 75 | type=int, 76 | default=20, 77 | help="samples for random search" 78 | ) 79 | 80 | parser.add_argument( 81 | "--seed", "-d", 82 | type=int, 83 | default=-1, 84 | help="seed for config" 85 | ) 86 | 87 | return parser 88 | 89 | def main(): 90 | parser = get_parser() 91 | args = parser.parse_args() 92 | launch_search(**vars(args)) 93 | 94 | def build_hp_dict(raw_json: dict): 95 | hp_dict = {} 96 | for key in raw_json: 97 | info: dict = raw_json[key] 98 | sample_fn = info.get("sample_fn", "uniform") 99 | assert hasattr(tune, sample_fn) 100 | if sample_fn == "choice": 101 | hp_dict[key] = tune.choice(info['opts']) 102 | else: 103 | assert "low" in info, "high" in info 104 | sample_fn = getattr(tune, sample_fn) 105 | hp_dict[key] = sample_fn(info['low'], info['high']) 106 | return hp_dict 107 | 108 | def launch_search(exp_config: Union[List[str], str], name: str, workers: int, gpus_per_worker: float, cpus_per_worker: float, eval_only: bool, samples: int, seed: int) -> None: 109 | # ---------- PBT I/O CONFIGURATION ---------- 110 | # the directory to save PBT runs (usually '~/ray_results') 111 | 112 | if len(path.split(exp_config)[0]) > 0: 113 | CFG_PATH = exp_config 114 | else: 115 | CFG_PATH = path.join(DEFAULT_CONFIG_DIR, exp_config) 116 | variant_name = path.split(CFG_PATH)[1].split('.')[0] 117 | # Ok, now update the paths in the config 118 | if seed > 0: 119 | variant_name = f"{variant_name}-s{seed}" 120 | if name == 
"": 121 | name = variant_name 122 | 123 | pbt_dir = path.join(PBT_HOME, name) 124 | # the name of this PBT run (run will be stored at `pbt_dir`) 125 | 126 | # ---------- PBT RUN CONFIGURATION ---------- 127 | # whether to use single machine or cluster 128 | SINGLE_MACHINE = True # Cluster not supported atm, don't know how to use it. 129 | 130 | NUM_WORKERS = workers if workers > 0 else int(torch.cuda.device_count() // gpus_per_worker) 131 | # the resources to allocate per model 132 | RESOURCES_PER_TRIAL = {"cpu": cpus_per_worker, "gpu": gpus_per_worker} 133 | 134 | # --------------------------------------------- 135 | 136 | def train(): 137 | if path.exists(pbt_dir): 138 | print("Run exists!!! Overwriting.") 139 | if not OVERWRITE: 140 | print("overwriting disallowed, exiting..") 141 | exit(0) 142 | else: 143 | if path.exists(pbt_dir): 144 | shutil.rmtree(pbt_dir) 145 | 146 | # load the configuration as a dictionary and update for this run 147 | flat_cfg_dict = flatten(yaml.full_load(open(CFG_PATH))) 148 | 149 | # Default behavior is to pull experiment name from config file 150 | # Bind variant name to directories 151 | flat_cfg_dict.update({'VARIANT': variant_name}) 152 | if seed > 0: 153 | flat_cfg_dict.update({'SEED': seed}) 154 | 155 | # the hyperparameter space to search 156 | assert 'TRAIN.TUNE_HP_JSON' in flat_cfg_dict, "please specify hp sweep (no default)" 157 | with open(flat_cfg_dict['TRAIN.TUNE_HP_JSON']) as f: 158 | raw_hp_json = json.load(f) 159 | cfg_samples = DEFAULT_HP_DICT 160 | cfg_samples.update(build_hp_dict(raw_hp_json)) 161 | 162 | flat_cfg_dict.update(cfg_samples) 163 | 164 | # connect to Ray cluster or start on single machine 165 | address = None if SINGLE_MACHINE else 'localhost:6379' 166 | ray.init(address=address) 167 | 168 | reporter = tune.CLIReporter(metric_columns=LOGGED_COLUMNS) 169 | 170 | analysis = tune.run( 171 | tuneNDT, 172 | name=name, 173 | local_dir=pbt_dir, 174 | stop={'done': True}, 175 | config=flat_cfg_dict, 176 | resources_per_trial=RESOURCES_PER_TRIAL, 177 | num_samples=samples, 178 | # sync_to_driver='# {source} {target}', # prevents rsync 179 | verbose=1, 180 | progress_reporter=reporter, 181 | # loggers=(tune.logger.JsonLogger, tune.logger.CSVLogger) 182 | ) 183 | 184 | if not eval_only: 185 | train() 186 | # load the results dataframe for this run 187 | df = tune.Analysis( 188 | pbt_dir 189 | ).dataframe() 190 | df = df[df.logdir.apply(lambda path: not 'best_model' in path)] 191 | 192 | # Hm... we need to go through each model, and run the lfve ckpt. 193 | # And then record that in the dataframe? 
194 | 195 | if df[BEST_MODEL_METRIC].dtype == 'O': # Accidentally didn't case to scalar, now we have a tensor string 196 | df = df.assign(best_masked_loss=lambda df: df[BEST_MODEL_METRIC].str[7:13].astype(float)) 197 | best_model_logdir = df.loc[df[BEST_MODEL_METRIC].idxmin()].logdir 198 | # copy the best model somewhere easy to find 199 | # best_model_src = path.join(best_model_logdir, 'model_dir') 200 | best_model_dest = path.join(pbt_dir, 'best_model') 201 | if path.exists(best_model_dest): 202 | shutil.rmtree(best_model_dest) 203 | shutil.copytree(best_model_logdir, best_model_dest) 204 | 205 | if __name__ == "__main__": 206 | main() -------------------------------------------------------------------------------- /scripts/analyze_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import sys 5 | module_path = os.path.abspath(os.path.join('..')) 6 | if module_path not in sys.path: 7 | sys.path.append(module_path) 8 | 9 | import time 10 | import h5py 11 | import matplotlib.pyplot as plt 12 | import seaborn as sns 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as f 16 | from torch.utils import data 17 | 18 | from src.run import prepare_config 19 | from src.runner import Runner 20 | from src.dataset import SpikesDataset, DATASET_MODES 21 | from src.mask import Masker, UNMASKED_LABEL 22 | 23 | def make_runner(variant, ckpt, base="", prefix=""): 24 | run_type = "eval" 25 | exp_config = osp.join("../configs", prefix, f"{variant}.yaml") 26 | if base != "": 27 | exp_config = [osp.join("../configs", f"{base}.yaml"), exp_config] 28 | ckpt_path = f"{variant}.{ckpt}.pth" 29 | config, ckpt_path = prepare_config( 30 | exp_config, run_type, ckpt_path, [ 31 | "USE_TENSORBOARD", False, 32 | "SYSTEM.NUM_GPUS", 1, 33 | ], suffix=prefix 34 | ) 35 | return Runner(config), ckpt_path 36 | 37 | def setup_dataset(runner, mode): 38 | test_set = SpikesDataset(runner.config, runner.config.DATA.VAL_FILENAME, mode=mode, logger=runner.logger) 39 | runner.logger.info(f"Evaluating on {len(test_set)} samples.") 40 | test_set.clip_spikes(runner.max_spikes) 41 | spikes, rates, heldout_spikes, forward_spikes = test_set.get_dataset() 42 | if heldout_spikes is not None: 43 | heldout_spikes = heldout_spikes.to(runner.device) 44 | if forward_spikes is not None: 45 | forward_spikes = forward_spikes.to(runner.device) 46 | return spikes.to(runner.device), rates.to(runner.device), heldout_spikes, forward_spikes 47 | 48 | def init_by_ckpt(ckpt_path, mode=DATASET_MODES.val): 49 | runner = Runner(checkpoint_path=ckpt_path) 50 | runner.model.eval() 51 | torch.set_grad_enabled(False) 52 | spikes, rates, heldout_spikes, forward_spikes = setup_dataset(runner, mode) 53 | return runner, spikes, rates, heldout_spikes, forward_spikes 54 | 55 | def init(variant, ckpt, base="", prefix="", mode=DATASET_MODES.val): 56 | runner, ckpt_path = make_runner(variant, ckpt, base, prefix) 57 | return init_by_ckpt(ckpt_path, mode) 58 | 59 | 60 | # Accumulates multiplied attentions - examine at lower layers to see where information is sourced in early processing. 
61 | # Takes in layer weights 62 | def get_multiplicative_weights(weights_list): 63 | weights = weights_list[0] 64 | multiplied_weights = [weights] 65 | for layer_weights in weights_list[1:]: 66 | weights = torch.bmm(layer_weights, weights) 67 | multiplied_weights.append(weights) 68 | return multiplied_weights 69 | -------------------------------------------------------------------------------- /scripts/clear.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ $# -eq 1 ]] 4 | then 5 | python -u src/run.py --run-type train --exp-config configs/$1.yaml --clear-only True 6 | else 7 | echo "Expected args (ckpt)" 8 | fi -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # * The evaluation code path has legacy code. Evaluation / analysis is done in analysis scripts. 4 | 5 | if [[ $# -eq 2 ]] 6 | then 7 | python -u src/run.py --run-type eval --exp-config configs/$1.yaml --ckpt-path $1.$2.pth 8 | else 9 | echo "Expected args (ckpt)" 10 | fi -------------------------------------------------------------------------------- /scripts/fig_3a_synthetic_neurons.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Joel Ye 3 | 4 | # Notebook for interactive model evaluation/analysis 5 | # Allows us to interrogate model on variable data (instead of masked sample again) 6 | 7 | #%% 8 | import os 9 | import os.path as osp 10 | 11 | import sys 12 | module_path = os.path.abspath(os.path.join('..')) 13 | if module_path not in sys.path: 14 | sys.path.append(module_path) 15 | 16 | import time 17 | import h5py 18 | import matplotlib.pyplot as plt 19 | import seaborn as sns 20 | import numpy as np 21 | import torch 22 | import torch.nn.functional as f 23 | from torch.utils import data 24 | 25 | from src.run import prepare_config 26 | from src.runner import Runner 27 | from src.dataset import SpikesDataset, DATASET_MODES 28 | from src.mask import UNMASKED_LABEL 29 | 30 | from analyze_utils import init_by_ckpt 31 | # Note that a whole lot of logger items will still be dumped into ./scripts/logs 32 | 33 | grid = True 34 | ckpt = "Grid" if grid else "PBT" 35 | 36 | # Lorenz 37 | variant = "lorenz-s1" 38 | 39 | # Chaotic 40 | variant = "chaotic-s1" 41 | 42 | def get_info(variant): 43 | ckpt_path = f"/snel/home/joely/ray_results/ndt/gridsearch/{variant}/best_model/ckpts/{variant}.lve.pth" 44 | runner, spikes, rates = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val) 45 | 46 | train_set = SpikesDataset(runner.config, runner.config.DATA.TRAIN_FILENAME, mode=DATASET_MODES.train, logger=runner.logger) 47 | train_spikes, train_rates = train_set.get_dataset() 48 | full_spikes = torch.cat([train_spikes.to(spikes.device), spikes], dim=0) 49 | full_rates = torch.cat([train_rates.to(rates.device), rates], dim=0) 50 | # full_spikes = spikes 51 | # full_rates = rates 52 | 53 | ( 54 | unmasked_loss, 55 | pred_rates, 56 | layer_outputs, 57 | attn_weights, 58 | attn_list, 59 | ) = runner.model(full_spikes, mask_labels=full_spikes, return_weights=True) 60 | 61 | return full_spikes.cpu(), full_rates.cpu(), pred_rates, runner 62 | 63 | chaotic_spikes, chaotic_rates, chaotic_ndt, chaotic_runner = get_info('chaotic-s1') 64 | lorenz_spikes, lorenz_rates, lorenz_ndt, lorenz_runner = get_info('lorenz-s1') 65 | 66 | 67 | #%% 68 | import matplotlib 69 | import 
matplotlib.pyplot as plt 70 | import numpy as np 71 | import seaborn as sns 72 | 73 | palette = sns.color_palette(palette='muted', n_colors=3, desat=0.9) 74 | SMALL_SIZE = 12 75 | MEDIUM_SIZE = 15 76 | LARGE_SIZE = 18 77 | 78 | def prep_plt(ax=plt.gca()): 79 | plt.rc('font', size=MEDIUM_SIZE) # controls default text sizes 80 | plt.rc('axes', labelsize=LARGE_SIZE) # fontsize of the x and y labels 81 | plt.rc('xtick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels 82 | plt.rc('ytick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels 83 | # plt.rc('title', labelsize=MEDIUM_SIZE) # fontsize of the tick labels 84 | 85 | plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize 86 | plt.style.use('seaborn-muted') 87 | # plt.figure(figsize=(6,4)) 88 | 89 | spine_alpha = 0.5 90 | ax.spines['right'].set_alpha(0.0) 91 | # plt.gca().spines['bottom'].set_alpha(spine_alpha) 92 | ax.spines['bottom'].set_alpha(0) 93 | # plt.gca().spines['left'].set_alpha(spine_alpha) 94 | ax.spines['left'].set_alpha(0) 95 | ax.spines['top'].set_alpha(0.0) 96 | 97 | plt.tight_layout() 98 | 99 | prep_plt() 100 | 101 | # np.random.seed(5) 102 | 103 | # plt.text("NDT Predictions", color=palette[0]) 104 | # plt.text("Ground Truth", color=palette[2]) 105 | def plot_axis( 106 | spikes, 107 | true_rates, 108 | ndt_rates, # ndt 109 | ax, 110 | key="lorenz", 111 | condition_idx=0, 112 | # trial_idx=0, 113 | neuron_idx=0, 114 | seed=3, # legacy, 115 | first_n=8 116 | ): 117 | prep_plt(ax) 118 | # np.random.seed(seed) 119 | # neuron_idx = np.random.randint(0, true_rates.size(-1)) 120 | # trial_idx = np.random.randint(0, true_rates.size(0)) 121 | _, unique, counts = np.unique(true_rates.numpy(), axis=0, return_inverse=True, return_counts=True) 122 | trial_idx = (unique == condition_idx) 123 | 124 | true_trial_rates = true_rates[trial_idx][0, :, neuron_idx].cpu() 125 | pred_trial_rates = ndt_rates[trial_idx][0:first_n, :, neuron_idx].cpu() 126 | trial_spikes = spikes[trial_idx][0:first_n, :, neuron_idx].cpu() 127 | 128 | time = true_rates.size(1) 129 | if True: # lograte 130 | true_trial_rates = true_trial_rates.exp() 131 | pred_trial_rates = pred_trial_rates.exp() 132 | lfads_preds = np.load(f"/snel/home/joely/data/{key}.npy") 133 | lfads_trial_rates = lfads_preds[trial_idx][0:first_n, :, neuron_idx] 134 | 135 | if key == "lorenz": 136 | spike_level = -0.06 137 | else: 138 | spike_level = -0.03 139 | one_tenth_point = 0.1 * (50 if key == "lorenz" else 100) 140 | 141 | ax.plot([0, one_tenth_point], [(first_n + 1) * spike_level, (first_n + 1) * spike_level], 'k-', lw=3) # 5 / 50 = 0.1 142 | 143 | ax.plot(np.arange(time), true_trial_rates.numpy(), color='#111111', label="Ground Truth") #, linestyle="dashed") 144 | 145 | # Only show labels once 146 | labels = { 147 | "NDT": "NDT", 148 | 'LFADS': "AutoLFADS", 149 | 'Spikes': 'Spikes' 150 | } 151 | 152 | for trial_idx in range(pred_trial_rates.numpy().shape[0]): 153 | # print(trial_idx) 154 | # print(pred_trial_rates.size()) 155 | ax.plot(np.arange(time), pred_trial_rates[trial_idx].numpy(), color=palette[0], label=labels['NDT'], alpha=0.4) 156 | labels['NDT'] = "" 157 | ax.plot(np.arange(time), lfads_trial_rates[trial_idx], color=palette[1], label=labels['LFADS'], alpha=0.4) 158 | labels["LFADS"] = "" 159 | spike_times, = np.where(trial_spikes[trial_idx].numpy()) 160 | ax.scatter(spike_times, spike_level * (trial_idx + 1)*np.ones_like(spike_times), c='k', marker='|', label=labels['Spikes'], s=30) 161 | labels['Spikes'] = "" 162 | 163 | ax.set_xticks([]) 164 | # 
ax.set_xticks(np.linspace(one_tenth_point, one_tenth_point, 1)) 165 | 166 | if key == "lorenz": 167 | ax.set_ylim(-0.6, 0.65) 168 | ax.set_yticks([0.0, 0.4]) 169 | ax.plot([-1, -1], [0, 0.4], 'k-', lw=3) 170 | 171 | ax.annotate("", 172 | xy=(-1.5, -0.5), 173 | xycoords="data", 174 | xytext=(-1.5, -0.3), 175 | textcoords="data", 176 | arrowprops=dict( 177 | arrowstyle="<-", 178 | connectionstyle="arc3,rad=0", 179 | linewidth="2", 180 | # color=(0.2, 0.2, 0.2) 181 | ), 182 | size=14 183 | ) 184 | 185 | ax.text( 186 | -5.5, -0.5, 187 | # -3.5, -0.5, 188 | "Trials", 189 | fontsize=14, 190 | rotation=90 191 | ) 192 | 193 | else: 194 | ax.set_ylim(-0.3, 0.5) 195 | ax.set_yticks([0, 0.2]) 196 | ax.plot([-2, -2], [0, 0.2], 'k-', lw=3) 197 | 198 | ax.annotate("", 199 | xy=(-3, -0.25), 200 | xycoords="data", 201 | xytext=(-3, -0.15), 202 | textcoords="data", 203 | arrowprops=dict( 204 | arrowstyle="<-", 205 | connectionstyle="arc3,rad=0", 206 | linewidth="2", 207 | # color=(0.2, 0.2, 0.2) 208 | ), 209 | size=14 210 | ) 211 | 212 | ax.text( 213 | -11, -0.25, 214 | # -7, -0.25, 215 | "Trials", 216 | fontsize=14, 217 | rotation=90 218 | ) 219 | 220 | # ax.set_title(f"Trial {trial_idx}, Neuron {neuron_idx}") 221 | 222 | f, axes = plt.subplots( 223 | nrows=2, ncols=2, sharex=False, sharey=False, figsize=(8, 6) 224 | ) 225 | 226 | plot_axis( 227 | lorenz_spikes, 228 | lorenz_rates, 229 | lorenz_ndt, 230 | axes[0, 0], 231 | condition_idx=0, 232 | key="lorenz" 233 | ) 234 | 235 | plot_axis( 236 | lorenz_spikes, 237 | lorenz_rates, 238 | lorenz_ndt, 239 | axes[0, 1], 240 | condition_idx=1, 241 | key="lorenz" 242 | ) 243 | 244 | plot_axis( 245 | chaotic_spikes, 246 | chaotic_rates, 247 | chaotic_ndt, 248 | axes[1, 0], 249 | condition_idx=0, 250 | key="chaotic" 251 | ) 252 | 253 | plot_axis( 254 | chaotic_spikes, 255 | chaotic_rates, 256 | chaotic_ndt, 257 | axes[1, 1], 258 | condition_idx=1, 259 | key="chaotic" 260 | ) 261 | 262 | # plt.suptitle(f"{variant} {ckpt}", y=1.05) 263 | axes[0, 0].text(15, 0.45, "Lorenz", size=18, rotation=0) 264 | # axes[0, 0].text(-20, 0.2, "Lorenz", size=18, rotation=45) 265 | axes[1, 0].text(30, 0.1, "Chaotic", size=18, rotation=0) 266 | # axes[1, 0].text(20, 0.25, "Chaotic", size=18, rotation=0) 267 | 268 | # plt.tight_layout() 269 | f.subplots_adjust( 270 | # left=0.15, 271 | # bottom=-0.1, 272 | hspace=0.0, 273 | wspace=0.0 274 | ) 275 | # axes[0,0].set_xticks([]) 276 | axes[0,1].set_yticks([]) 277 | # axes[0,1].set_xticks([]) 278 | axes[1,1].set_yticks([]) 279 | 280 | legend = axes[1, 1].legend( 281 | # loc=(-.95, 0.97), 282 | loc=(-1.05, 0.85), 283 | fontsize=14, 284 | frameon=False, 285 | ncol=4, 286 | ) 287 | 288 | # for line in legend.get_lines(): 289 | # line.set_linewidth(3.0) 290 | 291 | # plt.savefig("lorenz_rates.png", dpi=300, bbox_inches="tight") 292 | # plt.savefig("lorenz_rates_2.png", dpi=300, bbox_inches="tight") 293 | plt.setp(legend.get_texts()[1], color=palette[0]) 294 | plt.setp(legend.get_texts()[2], color=palette[1]) 295 | plt.setp(legend.get_texts()[0], color="#111111") 296 | # plt.setp(legend.get_texts()[3], color=palette[2]) 297 | 298 | plt.savefig("3a_synth_qual.pdf") 299 | 300 | #%% 301 | # Total 1300 (130 x 10 per) 302 | # Val 260 (130 x 2 per) 303 | _, unique, counts = np.unique(chaotic_rates.numpy(), axis=0, return_inverse=True, return_counts=True) 304 | # unique, counts = chaotic_rates.unique(dim=0, return_counts=True) 305 | trial_idx = (unique == 0) 306 | # Oh, a lot of people look like 307 | print(np.where(trial_idx)) 308 | true_trial_rates = 
chaotic_rates[trial_idx][0, :, 0].cpu() 309 | print(true_trial_rates[:10]) -------------------------------------------------------------------------------- /scripts/fig_3b_4b_plot_hparams.py: -------------------------------------------------------------------------------- 1 | # ! Reference script. Not supported out of the box. 2 | 3 | #%% 4 | # param vs nll (3b) and match to vaf vs nll (4b) 5 | # 3b requires a CSV downloaded from TB Hparams page (automatically generated by Ray) 6 | # 4b requires saving down of rate predictions from each model (for NDT, this step is in `record_all_rates.py`) 7 | 8 | from pathlib import Path 9 | import sys 10 | module_path = str(Path('..').resolve()) 11 | if module_path not in sys.path: 12 | sys.path.append(module_path) 13 | from collections import defaultdict 14 | import time 15 | import matplotlib.pyplot as plt 16 | import seaborn as sns 17 | import numpy as np 18 | import pandas as pd 19 | import torch 20 | import torch.nn.functional as f 21 | from torch.utils import data 22 | 23 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 24 | 25 | from src.run import prepare_config 26 | from src.runner import Runner 27 | from src.dataset import SpikesDataset, DATASET_MODES 28 | from src.mask import UNMASKED_LABEL 29 | 30 | from analyze_utils import init_by_ckpt 31 | 32 | tf_size_guidance = {'scalars': 1000} 33 | #%% 34 | # Extract the NLL and R2 35 | plot_path = Path('~/projects/transformer-modeling/scripts/hparams.csv').expanduser().resolve() 36 | df = pd.read_csv(plot_path) 37 | 38 | #%% 39 | SMALL_SIZE = 12 40 | MEDIUM_SIZE = 20 41 | LARGE_SIZE = 24 42 | 43 | def prep_plt(): 44 | plt.rc('font', size=MEDIUM_SIZE) # controls default text sizes 45 | plt.rc('axes', labelsize=LARGE_SIZE) # fontsize of the x and y labels 46 | plt.rc('xtick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels 47 | plt.rc('ytick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels 48 | # plt.rc('title', labelsize=MEDIUM_SIZE) # fontsize of the tick labels 49 | 50 | plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize 51 | plt.style.use('seaborn-muted') 52 | # plt.figure(figsize=(6,4)) 53 | 54 | spine_alpha = 0.5 55 | plt.gca().spines['right'].set_alpha(0.0) 56 | plt.gca().spines['bottom'].set_alpha(spine_alpha) 57 | # plt.gca().spines['bottom'].set_alpha(0) 58 | plt.gca().spines['left'].set_alpha(spine_alpha) 59 | # plt.gca().spines['left'].set_alpha(0) 60 | plt.gca().spines['top'].set_alpha(0.0) 61 | 62 | plt.tight_layout() 63 | 64 | plt.figure(figsize=(4, 6)) 65 | # plt.figure(figsize=(6, 4)) 66 | prep_plt() 67 | # So. I select the model with the least all-time MVE, and report metrics on that MVE ckpt. 68 | # Here I'm reporting numbers from the final ckpt instead. 
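# (Hedged aside, not part of the original analysis) if you do want numbers at the
# lowest-validation-error checkpoint rather than the final one, a minimal sketch over the
# same hparams CSV could be (column names assumed to match the export referenced below):
#   best_run = df.loc[df['ray/tune/best_masked_loss'].idxmin()]
#   print(best_run['ray/tune/r2'])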
69 | 70 | x_axis = 'unmasked_loss' 71 | # x_axis = 'masked_loss' 72 | sns.scatterplot(data=df, x=f"ray/tune/{x_axis}", y="ray/tune/r2", s=120) 73 | # sns.scatterplot(data=df, x="ray/tune/best_masked_loss", y="ray/tune/r2", s=120) 74 | plt.yticks(np.arange(0.8, 0.96, 0.15)) 75 | # plt.yticks(np.arange(0.8, 0.96, 0.05)) 76 | plt.xticks([0.37, 0.39]) 77 | plt.xlim(0.365, 0.39) 78 | plt.xlabel("Match to Spikes (NLL)") 79 | plt.ylabel("Rate Prediction $R^2$", labelpad=-20) 80 | from scipy.stats import pearsonr 81 | r = pearsonr(df[f'ray/tune/{x_axis}'], df['ray/tune/r2']) 82 | print(r) 83 | plt.text(0.378, 0.92, f"$\it{{\\rho}}$ : {r[0]:.3f}", size=LARGE_SIZE) 84 | 85 | plt.savefig("3_match_spikes.pdf", bbox_inches="tight") 86 | #%% 87 | palette = sns.color_palette(palette='muted', n_colors=2, desat=0.9) 88 | variant = 'm700_2296-s1' 89 | nlls = torch.load(f'{variant}_val_errs_sweep.pth') 90 | matches = torch.load(f'/snel/home/joely/projects/rds/{variant}_psth_match_sweep.pth') 91 | decoding = torch.load(f'/snel/home/joely/projects/rds/{variant}_deocding_sweep.pth') 92 | nll_arr = [] 93 | match_arr = [] 94 | decoding_arr = [] 95 | for key in nlls: 96 | nll_arr.append(nlls[key]) 97 | match_arr.append(matches[key]) 98 | decoding_arr.append(decoding[key]) 99 | plt.figure(figsize=(4, 5)) 100 | prep_plt() 101 | plt.scatter(nll_arr, match_arr, color = palette[0]) 102 | plt.xticks([0.139, 0.144], rotation=20) 103 | plt.xlim(0.139, 0.144) 104 | plt.yticks([0.55, 0.75]) 105 | # plt.yticks(np.linspace(0.5, 0.75, 2)) 106 | plt.xlabel("Match to Spikes (NLL)", labelpad=0, fontsize=MEDIUM_SIZE) 107 | plt.ylabel("Match to Empirical PSTH ($R^2$)", labelpad=0, fontsize=MEDIUM_SIZE) 108 | r = pearsonr(nll_arr, match_arr) 109 | print(r) 110 | 111 | plt.text(0.1392, 0.53, f"$\it{{\\rho}}$ : {r[0]:.3f}", size=LARGE_SIZE, color=palette[0]) 112 | # plt.text(0.141, 0.72, f"$\it{{\\rho}}$ : {r[0]:.3f}", size=LARGE_SIZE) 113 | plt.hlines(0.7078, 0.139, 0.144, linestyles="--", color=palette[1]) 114 | plt.text(0.141, 0.715, f"AutoLFADS", size=MEDIUM_SIZE, color=palette[1]) 115 | # plt.text(0.142, 0.715, f"LFADS", size=MEDIUM_SIZE, color=palette[1]) 116 | 117 | plt.savefig("4b_match_psth.pdf", bbox_inches="tight") 118 | #%% 119 | plt.figure(figsize=(4, 5)) 120 | prep_plt() 121 | plt.scatter(nll_arr, decoding_arr) 122 | plt.xticks([0.139, 0.144], rotation=20) 123 | plt.xlim(0.139, 0.144) 124 | 125 | #%% 126 | plt.figure(figsize=(4, 5)) 127 | prep_plt() 128 | plt.scatter(match_arr, decoding_arr) 129 | # plt.xticks([0.139, 0.144], rotation=20) 130 | # plt.xlim(0.139, 0.144) 131 | -------------------------------------------------------------------------------- /scripts/fig_5_lfads_times.py: -------------------------------------------------------------------------------- 1 | # Author: Joel Ye 2 | 3 | # LFADS timing 4 | # ! This is a reference script. May not run out of the box with the NDT environment (e.g. 
needs LFADS dependencies) 5 | #%% 6 | import os 7 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 8 | 9 | import copy 10 | import cProfile 11 | import os 12 | import time 13 | import gc 14 | 15 | import h5py 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | import tensorflow as tf 19 | import tensorflow_probability as tfp 20 | from tensorflow.keras.utils import Progbar 21 | 22 | from lfads_tf2.defaults import get_cfg_defaults 23 | from lfads_tf2.models import LFADS 24 | from lfads_tf2.tuples import DecoderInput, SamplingOutput 25 | from lfads_tf2.utils import (load_data, load_posterior_averages, 26 | restrict_gpu_usage) 27 | 28 | tfd = tfp.distributions 29 | gc.disable() # disable garbage collection 30 | 31 | # tf.debugging.set_log_device_placement(True) 32 | restrict_gpu_usage(gpu_ix=0) 33 | 34 | 35 | # %% 36 | # restore the LFADS model 37 | 38 | 39 | # model = LFADS(model_dir) # A hardcoded chaotic path 40 | model_dir = '/snel/home/joely/ray_results/lfads/chaotic-s1/best_model' 41 | cfg_path = os.path.join(model_dir, 'model_spec.yaml') # Don't directly load model 42 | 43 | def sample_and_average(n_samples=50, 44 | batch_size=64, 45 | merge_tv=False, 46 | seq_len_cap=100): 47 | cfg = get_cfg_defaults() 48 | cfg.merge_from_file(cfg_path) 49 | cfg.MODEL.SEQ_LEN = seq_len_cap 50 | cfg.freeze() 51 | model = LFADS( 52 | cfg_node=cfg, 53 | # model_dir='/snel/home/joely/ray_results/lfads/chaotic-s1/best_model' 54 | ) 55 | model.restore_weights() 56 | # ! Modified for timing 57 | if not model.is_trained: 58 | model.lgr.warn("Performing posterior sampling on an untrained model.") 59 | 60 | # define merging and splitting utilities 61 | def merge_samp_and_batch(data, batch_dim): 62 | """ Combines the sample and batch dimensions """ 63 | return tf.reshape(data, [n_samples * batch_dim] + 64 | tf.unstack(tf.shape(data)[2:])) 65 | 66 | def split_samp_and_batch(data, batch_dim): 67 | """ Splits up the sample and batch dimensions """ 68 | return tf.reshape(data, [n_samples, batch_dim] + 69 | tf.unstack(tf.shape(data)[1:])) 70 | 71 | # ========== POSTERIOR SAMPLING ========== 72 | # perform sampling on both training and validation data 73 | loop_times = [] 74 | for prefix, dataset in zip(['train_', 'valid_'], 75 | [model._train_ds, model._val_ds]): 76 | data_len = len(model.train_tuple.data) if prefix == 'train_' else len( 77 | model.val_tuple.data) 78 | 79 | # initialize lists to store rates 80 | all_outputs = [] 81 | model.lgr.info( 82 | "Posterior sample and average on {} segments.".format(data_len)) 83 | if not model.cfg.TRAIN.TUNE_MODE: 84 | pbar = Progbar(data_len, width=50, unit_name='dataset') 85 | 86 | def process_batch(): 87 | # unpack the batch 88 | data, _, ext_input = batch 89 | data = data[:,:seq_len_cap] 90 | ext_input = ext_input[:,:seq_len_cap] 91 | 92 | time_start = time.time() 93 | 94 | # for each chop in the dataset, compute the initial conditions 95 | # distribution 96 | ic_mean, ic_stddev, ci = model.encoder.graph_call(data) 97 | ic_post = tfd.MultivariateNormalDiag(ic_mean, ic_stddev) 98 | 99 | # sample from the posterior and merge sample and batch dimensions 100 | ic_post_samples = ic_post.sample(n_samples) 101 | ic_post_samples_merged = merge_samp_and_batch( 102 | ic_post_samples, len(data)) 103 | 104 | # tile and merge the controller inputs and the external inputs 105 | ci_tiled = tf.tile(tf.expand_dims(ci, axis=0), 106 | [n_samples, 1, 1, 1]) 107 | ci_merged = merge_samp_and_batch(ci_tiled, len(data)) 108 | ext_tiled = tf.tile(tf.expand_dims(ext_input, axis=0), 109 | 
[n_samples, 1, 1, 1]) 110 | ext_merged = merge_samp_and_batch(ext_tiled, len(data)) 111 | 112 | # pass all samples into the decoder 113 | dec_input = DecoderInput(ic_samp=ic_post_samples_merged, 114 | ci=ci_merged, 115 | ext_input=ext_merged) 116 | output_samples_merged = model.decoder.graph_call(dec_input) 117 | 118 | # average the outputs across samples 119 | output_samples = [ 120 | split_samp_and_batch(t, len(data)) 121 | for t in output_samples_merged 122 | ] 123 | output = [np.mean(t, axis=0) for t in output_samples] 124 | 125 | time_elapsed = time.time() - time_start 126 | 127 | if not model.cfg.TRAIN.TUNE_MODE: 128 | pbar.add(len(data)) 129 | return time_elapsed 130 | 131 | for batch in dataset.batch(batch_size): 132 | loop_times.append(process_batch()) 133 | return loop_times 134 | # %% 135 | def time_bin(bins): 136 | n_samples = 1 137 | p_loop_times = sample_and_average(batch_size=1, n_samples=n_samples, seq_len_cap=bins) 138 | del p_loop_times[0] # first iteration is slower due to graph initialization 139 | p_loop_times = np.array(p_loop_times) * 1e3 140 | print(f"{p_loop_times.mean():.3f}ms for {bins} bins") 141 | return p_loop_times 142 | 143 | all_times = [] 144 | for bins in range(5, 10, 5): 145 | for bins in range(5, 105, 5): 146 | all_times.append(time_bin(bins)) 147 | all_times = np.array(all_times) 148 | 149 | with open('lfads_times.npy', 'wb') as f: 150 | np.save(f, all_times) 151 | -------------------------------------------------------------------------------- /scripts/fig_5_ndt_times.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Joel Ye 3 | 4 | # Run from scripts directory. 5 | # python timing_tests.py -l {1, 2, 6} 6 | 7 | #%% 8 | import os 9 | import os.path as osp 10 | import argparse 11 | import sys 12 | module_path = os.path.abspath(os.path.join('..')) 13 | if module_path not in sys.path: 14 | sys.path.append(module_path) 15 | 16 | import time 17 | import gc 18 | gc.disable() 19 | 20 | import h5py 21 | import matplotlib.pyplot as plt 22 | import seaborn as sns 23 | import numpy as np 24 | import torch 25 | from torch.utils import data 26 | import torch.nn.functional as f 27 | 28 | from src.dataset import DATASET_MODES, SpikesDataset 29 | from src.run import prepare_config 30 | from src.runner import Runner 31 | from analyze_utils import make_runner, get_multiplicative_weights, init_by_ckpt 32 | 33 | prefix="arxiv" 34 | base="" 35 | variant = "chaotic" 36 | 37 | run_type = "eval" 38 | exp_config = osp.join("../configs", prefix, f"{variant}.yaml") 39 | if base != "": 40 | exp_config = [osp.join("../configs", f"{base}.yaml"), exp_config] 41 | 42 | def get_parser(): 43 | parser = argparse.ArgumentParser() 44 | 45 | parser.add_argument( 46 | "--num-layers", "-l", 47 | type=int, 48 | required=True, 49 | ) 50 | return parser 51 | 52 | parser = get_parser() 53 | args = parser.parse_args() 54 | layers = vars(args)["num_layers"] 55 | 56 | def make_runner_of_bins(bin_count=100, layers=None): 57 | config, _ = prepare_config( 58 | exp_config, run_type, "", [ 59 | "USE_TENSORBOARD", False, 60 | "SYSTEM.NUM_GPUS", 1, 61 | ], suffix=prefix 62 | ) 63 | config.defrost() 64 | config.MODEL.TRIAL_LENGTH = bin_count 65 | if layers is not None: 66 | config.MODEL.NUM_LAYERS = layers 67 | config.MODEL.LEARNABLE_POSITION = False # Not sure why... 
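    # (Illustrative aside) yacs also allows batching such overrides in a single call, e.g.
    #   config.merge_from_list(["MODEL.TRIAL_LENGTH", bin_count])
    # between defrost() and freeze(); the explicit attribute assignments above are equivalent.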
68 | config.freeze() 69 | return Runner(config) 70 | 71 | def time_length(trials=1300, bin_count=100, **kwargs): 72 | # 100 as upper bound 73 | runner = make_runner_of_bins(bin_count=bin_count, **kwargs) 74 | runner.logger.mute() 75 | runner.load_device() 76 | runner.max_spikes = 9 # from chaotic ckpt 77 | runner.num_neurons = 50 # from chaotic ckpt 78 | # runner.num_neurons = 202 # from chaotic ckpt 79 | runner.setup_model(runner.device) 80 | # whole_set = SpikesDataset(runner.config, runner.config.DATA.TRAIN_FILENAME, mode="trainval") 81 | # whole_set.clip_spikes(runner.max_spikes) 82 | # # print(f"Evaluating on {len(whole_set)} samples.") 83 | # data_generator = data.DataLoader(whole_set, 84 | # batch_size=1, shuffle=False 85 | # ) 86 | loop_times = [] 87 | with torch.no_grad(): 88 | probs = torch.full((1, bin_count, runner.num_neurons), 0.1) 89 | # probs = torch.full((1, bin_count, runner.num_neurons), 0.01) 90 | while len(loop_times) < trials: 91 | spikes = torch.bernoulli(probs).long() 92 | spikes = spikes.to(runner.device) 93 | start = time.time() 94 | runner.model(spikes, mask_labels=spikes, passthrough=True) 95 | delta = time.time() - start 96 | loop_times.append(delta) 97 | p_loop_times = np.array(loop_times) * 1e3 98 | print(f"{p_loop_times.mean():.4f}ms for {bin_count} bins") 99 | 100 | # A note about memory: It's a bit unclear why `empty_cache` is failing and memory still shows as used on torch, but the below diagnostic indicates the memory is not allocated, and will not cause OOM. So, a minor inconvenience for now. 101 | # device = runner.device 102 | # runner.model.to('cpu') 103 | # t = torch.cuda.get_device_properties(device).total_memory 104 | # c = torch.cuda.memory_cached(device) 105 | # a = torch.cuda.memory_allocated(device) 106 | # print(device, t, c, a) 107 | # del runner 108 | # t = torch.cuda.get_device_properties(device).total_memory 109 | # c = torch.cuda.memory_cached(device) 110 | # a = torch.cuda.memory_allocated(device) 111 | # print(device, t, c, a) 112 | # del data_generator 113 | # del whole_set 114 | # del spikes 115 | # torch.cuda.empty_cache() 116 | return p_loop_times 117 | 118 | times = [] 119 | for i in range(5, 15, 5): 120 | p_loop_times = time_length(trials=2000, bin_count=i, layers=layers) 121 | times.append(p_loop_times) 122 | 123 | times = np.stack(times, axis=0) 124 | np.save(f'ndt_times_layer_{layers}', times) 125 | 126 | #%% 127 | 128 | -------------------------------------------------------------------------------- /scripts/fig_6_plot_losses.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from pathlib import Path 3 | import sys 4 | module_path = str(Path('..').resolve()) 5 | if module_path not in sys.path: 6 | sys.path.append(module_path) 7 | from collections import defaultdict 8 | import time 9 | import matplotlib.pyplot as plt 10 | import seaborn as sns 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as f 14 | from torch.utils import data 15 | 16 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 17 | 18 | from src.run import prepare_config 19 | from src.runner import Runner 20 | from src.dataset import SpikesDataset, DATASET_MODES 21 | from src.mask import UNMASKED_LABEL 22 | 23 | from analyze_utils import init_by_ckpt 24 | 25 | tf_size_guidance = {'scalars': 1000} 26 | #%% 27 | 28 | plot_folder = Path('~/ray_results/ndt/gridsearch/lograte_tb').expanduser().resolve() 29 | variants = { 30 | "m700_no_log-s1": "Rates", 31 | 
"m700_2296-s1": "Logrates" 32 | } 33 | 34 | all_info = defaultdict(dict) # key: variant, value: dict per variant, keyed by run (value will dict of step and value) 35 | for variant in variants: 36 | v_dir = plot_folder.joinpath(variant) 37 | for run_dir in v_dir.iterdir(): 38 | if not run_dir.is_dir(): 39 | continue 40 | tb_dir = run_dir.joinpath('tb') 41 | all_info[variant][tb_dir.parts[-2][:10]] = defaultdict(list) # key: "step" or "value", value: info 42 | for tb_file in tb_dir.iterdir(): 43 | event_acc = EventAccumulator(str(tb_file), tf_size_guidance) 44 | event_acc.Reload() 45 | # print(event_acc.Tags()) 46 | if 'val_loss' in event_acc.Tags()['scalars']: 47 | val_loss = event_acc.Scalars('val_loss') 48 | all_info[variant][tb_dir.parts[-2][:10]]['step'].append(val_loss[0].step) 49 | all_info[variant][tb_dir.parts[-2][:10]]['value'].append(val_loss[0].value) 50 | 51 | #%% 52 | SMALL_SIZE = 12 53 | MEDIUM_SIZE = 15 54 | LARGE_SIZE = 18 55 | 56 | def prep_plt(): 57 | plt.rc('font', size=MEDIUM_SIZE) # controls default text sizes 58 | plt.rc('axes', labelsize=LARGE_SIZE) # fontsize of the x and y labels 59 | plt.rc('xtick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels 60 | plt.rc('ytick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels 61 | # plt.rc('title', labelsize=MEDIUM_SIZE) # fontsize of the tick labels 62 | 63 | plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize 64 | plt.style.use('seaborn-muted') 65 | # plt.figure(figsize=(6,4)) 66 | 67 | spine_alpha = 0.5 68 | plt.gca().spines['right'].set_alpha(0.0) 69 | # plt.gca().spines['bottom'].set_alpha(spine_alpha) 70 | plt.gca().spines['bottom'].set_alpha(0) 71 | # plt.gca().spines['left'].set_alpha(spine_alpha) 72 | plt.gca().spines['left'].set_alpha(0) 73 | plt.gca().spines['top'].set_alpha(0.0) 74 | 75 | plt.tight_layout() 76 | 77 | plt.figure(figsize=(6, 4)) 78 | 79 | prep_plt() 80 | palette = sns.color_palette(palette='muted', n_colors=len(variants), desat=0.9) 81 | colors = {} 82 | color_ind = 0 83 | for variant, label in variants.items(): 84 | colors[variant] = palette[color_ind] 85 | color_ind += 1 86 | legend_str = label 87 | for run, run_info in all_info[variant].items(): 88 | # Sort 89 | sort_ind = np.argsort(run_info['step']) 90 | steps = np.array(run_info['step'])[sort_ind] 91 | vals = np.array(run_info['value'])[sort_ind] 92 | plt.plot(steps, vals, color=colors[variant], label=legend_str) 93 | # plt.hlines(vals.min(),0, 25000, color=colors[variant]) 94 | legend_str = "" 95 | plt.legend(fontsize=MEDIUM_SIZE, frameon=False) 96 | plt.xticks(np.arange(0, 25001, 12500)) 97 | plt.yticks(np.arange(0.1, 1.21, 0.5)) 98 | plt.ylim(0.1, 1.21) 99 | # plt.yticks(np.arange(0.2, 1.21, 0.5)) 100 | plt.ylabel("Loss (NLL)") 101 | plt.yscale('log') 102 | plt.xlabel("Epochs") 103 | plt.savefig("6_losses.pdf", bbox_inches="tight") -------------------------------------------------------------------------------- /scripts/nlb.py: -------------------------------------------------------------------------------- 1 | #%% 2 | 3 | # 1. 
Load model and get rate predictions 4 | import os 5 | import os.path as osp 6 | from pathlib import Path 7 | import sys 8 | module_path = os.path.abspath(os.path.join('..')) 9 | if module_path not in sys.path: 10 | sys.path.append(module_path) 11 | 12 | import time 13 | import h5py 14 | import matplotlib.pyplot as plt 15 | import seaborn as sns 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as f 19 | from torch.utils import data 20 | 21 | from nlb_tools.evaluation import evaluate 22 | 23 | from src.run import prepare_config 24 | from src.runner import Runner 25 | from src.dataset import SpikesDataset, DATASET_MODES 26 | from src.mask import UNMASKED_LABEL 27 | 28 | from analyze_utils import init_by_ckpt 29 | 30 | variant = "area2_bump" 31 | variant = "mc_maze" 32 | # variant = "mc_maze_large" 33 | # variant = "mc_maze_medium" 34 | # variant = "mc_maze_small" 35 | variant = 'dmfc_rsg' 36 | variant = 'dmfc_rsg2' 37 | # variant = 'mc_rtt' 38 | 39 | is_ray = True 40 | # is_ray = False 41 | 42 | if is_ray: 43 | best_model = "best_model" 44 | # best_model = "best_model_unmasked" 45 | lve = "lfve" if "unmasked" in best_model else "lve" 46 | 47 | def to_path(variant): 48 | grid_var = f"{variant}_lite" 49 | ckpt_path = f"/snel/home/joely/ray_results/ndt/gridsearch/{grid_var}/best_model/ckpts/{grid_var}.{lve}.pth" 50 | runner, spikes, rates, heldout_spikes, forward_spikes = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val) 51 | print(runner.config.VARIANT) 52 | return runner, spikes, rates, heldout_spikes, ckpt_path 53 | 54 | runner, spikes, rates, heldout_spikes, ckpt_path = to_path(variant) 55 | else: 56 | ckpt_dir = Path(f"/snel/share/joel/transformer_modeling/{variant}/") 57 | ckpt_path = ckpt_dir.joinpath(f"{variant}.lve.pth") 58 | runner, spikes, rates, heldout_spikes, forward_spikes = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val) 59 | 60 | eval_rates, _ = runner.get_rates( 61 | checkpoint_path=ckpt_path, 62 | save_path = None, 63 | mode = DATASET_MODES.val 64 | ) 65 | train_rates, _ = runner.get_rates( 66 | checkpoint_path=ckpt_path, 67 | save_path = None, 68 | mode = DATASET_MODES.train 69 | ) 70 | 71 | # * Val 72 | eval_rates, eval_rates_forward = torch.split(eval_rates, [spikes.size(1), eval_rates.size(1) - spikes.size(1)], 1) 73 | eval_rates_heldin_forward, eval_rates_heldout_forward = torch.split(eval_rates_forward, [spikes.size(-1), heldout_spikes.size(-1)], -1) 74 | train_rates, _ = torch.split(train_rates, [spikes.size(1), train_rates.size(1) - spikes.size(1)], 1) 75 | eval_rates_heldin, eval_rates_heldout = torch.split(eval_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1) 76 | train_rates_heldin, train_rates_heldout = torch.split(train_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1) 77 | 78 | 79 | #%% 80 | 81 | output_dict = { 82 | variant: { 83 | 'train_rates_heldin': train_rates_heldin.cpu().numpy(), 84 | 'train_rates_heldout': train_rates_heldout.cpu().numpy(), 85 | 'eval_rates_heldin': eval_rates_heldin.cpu().numpy(), 86 | 'eval_rates_heldout': eval_rates_heldout.cpu().numpy(), 87 | 'eval_rates_heldin_forward': eval_rates_heldin_forward.cpu().numpy(), 88 | 'eval_rates_heldout_forward': eval_rates_heldout_forward.cpu().numpy() 89 | } 90 | } 91 | 92 | # target_dict = torch.load(f'/snel/home/joely/tmp/{variant}_target.pth') 93 | target_dict = np.load(f'/snel/home/joely/tmp/{variant}_target.npy', allow_pickle=True).item() 94 | 95 | print(evaluate(target_dict, output_dict)) 96 | 97 | #%% 98 | # * Test 99 | 100 | variant = 'dmfc_rsg' 101 | 
runner.config.defrost() 102 | runner.config.DATA.TRAIN_FILENAME = f'{variant}_test.h5' 103 | runner.config.freeze() 104 | train_rates, _ = runner.get_rates( 105 | checkpoint_path=ckpt_path, 106 | save_path = None, 107 | mode = DATASET_MODES.train 108 | ) 109 | eval_rates, _ = runner.get_rates( 110 | checkpoint_path=ckpt_path, 111 | save_path = None, 112 | mode = DATASET_MODES.val, 113 | ) 114 | 115 | eval_rates, eval_rates_forward = torch.split(eval_rates, [spikes.size(1), eval_rates.size(1) - spikes.size(1)], 1) 116 | eval_rates_heldin_forward, eval_rates_heldout_forward = torch.split(eval_rates_forward, [spikes.size(-1), heldout_spikes.size(-1)], -1) 117 | train_rates, _ = torch.split(train_rates, [spikes.size(1), train_rates.size(1) - spikes.size(1)], 1) 118 | eval_rates_heldin, eval_rates_heldout = torch.split(eval_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1) 119 | train_rates_heldin, train_rates_heldout = torch.split(train_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1) 120 | 121 | output_dict = { 122 | variant: { 123 | 'train_rates_heldin': train_rates_heldin.cpu().numpy(), 124 | 'train_rates_heldout': train_rates_heldout.cpu().numpy(), 125 | 'eval_rates_heldin': eval_rates_heldin.cpu().numpy(), 126 | 'eval_rates_heldout': eval_rates_heldout.cpu().numpy(), 127 | 'eval_rates_heldin_forward': eval_rates_heldin_forward.cpu().numpy(), 128 | 'eval_rates_heldout_forward': eval_rates_heldout_forward.cpu().numpy() 129 | } 130 | } 131 | 132 | print(evaluate('/snel/share/data/nlb/test_data_do_not_share/eval_data_test.h5', output_dict)) 133 | 134 | #%% 135 | import h5py 136 | with h5py.File('ndt_maze_preds.h5', 'w') as f: 137 | group = f.create_group('mc_maze') 138 | for key in output_dict['mc_maze']: 139 | group.create_dataset(key, data=output_dict['mc_maze'][key]) 140 | #%% 141 | # Viz some trials 142 | trials = [1, 2, 3] 143 | neuron = 10 144 | trial_rates = train_rates_heldout[trials, :, neuron].cpu() 145 | trial_spikes = heldout_spikes[trials, :, neuron].cpu() 146 | trial_time = heldout_spikes.size(1) 147 | """ """ 148 | spike_level = 0.05 149 | # one_tenth_point = 0.1 * (50 if key == "lorenz" else 100) 150 | # ax.plot([0, one_tenth_point], [(first_n + 1) * spike_level, (first_n + 1) * spike_level], 'k-', lw=3) # 5 / 50 = 0.1 151 | 152 | for trial_index in range(trial_spikes.size(0)): 153 | print(trial_spikes[trial_index]) 154 | # plt.plot(trial_rates[trial_index].exp()) 155 | spike_times, = np.where(trial_spikes[trial_index].numpy()) 156 | plt.scatter(spike_times, spike_level * (trial_index + 1)*np.ones_like(spike_times), c='k', marker='|', label='Spikes', s=30) 157 | # labels['Spikes'] = "" 158 | 159 | print(train_rates_heldout.size()) 160 | 161 | # %% 162 | print(heldout_spikes.sum(0).sum(0).argmax()) 163 | # print(heldout_spikes[:,:,15]) 164 | for i in range(0, heldout_spikes.size(0), 6): 165 | plt.plot(heldout_spikes[i, :, 15].cpu()) -------------------------------------------------------------------------------- /scripts/nlb_from_scratch.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\"\"\"\n","- Let's setup NDT for train/eval from scratch.\n","- (Assumes your device is reasonably pytorch/GPU compatible)\n","- This is an interactive python script run via vscode. 
If you'd like to run as a notebook\n","\n","Run to setup requirements:\n","```\n"," Making a new env\n"," - python 3.7\n"," - conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch\n"," - conda install seaborn\n"," - pip install yacs pytorch_transformers tensorboard \"ray[tune]\" sklearn\n"," - pip install dandi \"pynwb>=2.0.0\"\n"," Or from environment.yaml\n","\n"," Then,\n"," - conda develop ~/path/to/nlb_tools\n","```\n","\n","Install NLB dataset(s), and create the h5s for training. (Here we install MC_Maze_Small)\n","```\n"," pip install dandi\n"," dandi download DANDI:000140/0.220113.0408\n","```\n","\n","This largely follows the `basic_example.ipynb` in `nlb_tools`. The only distinction is that we save to an h5.\n","\"\"\"\n","\n","from nlb_tools.nwb_interface import NWBDataset\n","from nlb_tools.make_tensors import (\n"," make_train_input_tensors, make_eval_input_tensors, make_eval_target_tensors, save_to_h5\n",")\n","from nlb_tools.evaluation import evaluate\n","\n","import numpy as np\n","import pandas as pd\n","import h5py\n","\n","import logging\n","logging.basicConfig(level=logging.INFO)\n","\n","dataset_name = 'mc_maze_small'\n","datapath = '/home/joelye/user_data/nlb/000140/sub-Jenkins/'\n","dataset = NWBDataset(datapath)\n","\n","# Prepare dataset\n","phase = 'val'\n","\n","# Choose bin width and resample\n","bin_width = 5\n","dataset.resample(bin_width)\n","\n","# Create suffix for group naming later\n","suffix = '' if (bin_width == 5) else f'_{int(bin_width)}'\n","\n","train_split = 'train' if (phase == 'val') else ['train', 'val']\n","train_dict = make_train_input_tensors(\n"," dataset, dataset_name=dataset_name, trial_split=train_split, save_file=False,\n"," include_behavior=True,\n"," include_forward_pred = True,\n",")\n","\n","# Show fields of returned dict\n","print(train_dict.keys())\n","\n","# Unpack data\n","train_spikes_heldin = train_dict['train_spikes_heldin']\n","train_spikes_heldout = train_dict['train_spikes_heldout']\n","\n","# Print 3d array shape: trials x time x channel\n","print(train_spikes_heldin.shape)\n","\n","## Make eval data (i.e. 
val)\n","\n","# Split for evaluation is same as phase name\n","eval_split = phase\n","# Make data tensors\n","eval_dict = make_eval_input_tensors(\n"," dataset, dataset_name=dataset_name, trial_split=eval_split, save_file=False,\n",")\n","print(eval_dict.keys()) # only includes 'eval_spikes_heldout' if available\n","eval_spikes_heldin = eval_dict['eval_spikes_heldin']\n","\n","print(eval_spikes_heldin.shape)\n","\n","h5_dict = {\n"," **train_dict,\n"," **eval_dict\n","}\n","\n","h5_target = '/home/joelye/user_data/nlb/mc_maze_small.h5'\n","save_to_h5(h5_dict, h5_target)\n","\n",""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\"\"\"\n","- At this point we should be able to train a basic model.\n","- In CLI, run a training call, replacing the appropriate paths\n","```\n"," ./scripts/train.sh mc_maze_small_from_scratch\n","\n"," OR\n","\n"," python ray_random.py -e ./configs/mc_maze_small_from_scratch.yaml\n"," (CLI overrides aren't available here, so make another config file)\n","```\n","- Once this is done training (~0.5hr for non-search), let's load the results...\n","\"\"\"\n","import os\n","import os.path as osp\n","from pathlib import Path\n","import sys\n","\n","# Add ndt src if not in path\n","module_path = osp.abspath(osp.join('..'))\n","if module_path not in sys.path:\n"," sys.path.append(module_path)\n","\n","import time\n","import h5py\n","import matplotlib.pyplot as plt\n","import seaborn as sns\n","import numpy as np\n","import torch\n","import torch.nn.functional as f\n","from torch.utils import data\n","\n","from nlb_tools.evaluation import evaluate\n","\n","from src.run import prepare_config\n","from src.runner import Runner\n","from src.dataset import SpikesDataset, DATASET_MODES\n","\n","from analyze_utils import init_by_ckpt\n","\n","variant = \"mc_maze_small_from_scratch\"\n","\n","is_ray = True\n","is_ray = False\n","\n","if is_ray:\n"," ckpt_path = f\"/home/joelye/user_data/nlb/ndt_runs/ray/{variant}/best_model/ckpts/{variant}.lve.pth\"\n"," runner, spikes, rates, heldout_spikes, forward_spikes = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val)\n","else:\n"," ckpt_dir = Path(\"/home/joelye/user_data/nlb/ndt_runs/\")\n"," ckpt_path = ckpt_dir / variant / f\"{variant}.lve.pth\"\n"," runner, spikes, rates, heldout_spikes, forward_spikes = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val)\n","\n","eval_rates, _ = runner.get_rates(\n"," checkpoint_path=ckpt_path,\n"," save_path = None,\n"," mode = DATASET_MODES.val\n",")\n","train_rates, _ = runner.get_rates(\n"," checkpoint_path=ckpt_path,\n"," save_path = None,\n"," mode = DATASET_MODES.train\n",")\n","\n","# * Val\n","eval_rates, eval_rates_forward = torch.split(eval_rates, [spikes.size(1), eval_rates.size(1) - spikes.size(1)], 1)\n","eval_rates_heldin_forward, eval_rates_heldout_forward = torch.split(eval_rates_forward, [spikes.size(-1), heldout_spikes.size(-1)], -1)\n","train_rates, _ = torch.split(train_rates, [spikes.size(1), train_rates.size(1) - spikes.size(1)], 1)\n","eval_rates_heldin, eval_rates_heldout = torch.split(eval_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1)\n","train_rates_heldin, train_rates_heldout = torch.split(train_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1)\n",""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# * Viz some trials\n","trials = [1, 2, 3]\n","neuron = 10\n","trial_rates = train_rates_heldout[trials, :, neuron].cpu()\n","trial_spikes = heldout_spikes[trials, :, neuron].cpu()\n","trial_time = 
heldout_spikes.size(1)\n","\n","spike_level = 0.05\n","f, axes = plt.subplots(2, figsize=(6, 4))\n","times = np.arange(0, trial_time * 0.05, 0.05)\n","for trial_index in range(trial_spikes.size(0)):\n"," spike_times, = np.where(trial_spikes[trial_index].numpy())\n"," spike_times = spike_times * 0.05\n"," axes[0].scatter(spike_times, spike_level * (trial_index + 1)*np.ones_like(spike_times), marker='|', label='Spikes', s=30)\n"," axes[1].plot(times, trial_rates[trial_index].exp())\n","\n",""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Submission e.g. as in `basic_example.ipynb`\n","# Looks like this model is pretty terrible :/\n","output_dict = {\n"," dataset_name + suffix: {\n"," 'train_rates_heldin': train_rates_heldin.cpu().numpy(),\n"," 'train_rates_heldout': train_rates_heldout.cpu().numpy(),\n"," 'eval_rates_heldin': eval_rates_heldin.cpu().numpy(),\n"," 'eval_rates_heldout': eval_rates_heldout.cpu().numpy(),\n"," 'eval_rates_heldin_forward': eval_rates_heldin_forward.cpu().numpy(),\n"," 'eval_rates_heldout_forward': eval_rates_heldout_forward.cpu().numpy()\n"," }\n","}\n","\n","# Reset logging level to hide excessive info messages\n","logging.getLogger().setLevel(logging.WARNING)\n","\n","# If 'val' phase, make the target data\n","if phase == 'val':\n"," # Note that the RTT task is not well suited to trial averaging, so PSTHs are not made for it\n"," target_dict = make_eval_target_tensors(dataset, dataset_name=dataset_name, train_trial_split='train', eval_trial_split='val', include_psth=True, save_file=False)\n","\n"," # Demonstrate target_dict structure\n"," print(target_dict.keys())\n"," print(target_dict[dataset_name + suffix].keys())\n","\n","# Set logging level again\n","logging.getLogger().setLevel(logging.INFO)\n","\n","if phase == 'val':\n"," print(evaluate(target_dict, output_dict))\n","\n","# e.g. with targets to compare to\n","# target_dict = torch.load(f'/snel/home/joely/tmp/{variant}_target.pth')\n","# target_dict = np.load(f'/snel/home/joely/tmp/{variant}_target.npy', allow_pickle=True).item()\n","\n","# print(evaluate(target_dict, output_dict))\n","\n","# e.g. to upload to EvalAI\n","# with h5py.File('ndt_maze_preds.h5', 'w') as f:\n","# group = f.create_group('mc_maze')\n","# for key in output_dict['mc_maze']:\n","# group.create_dataset(key, data=output_dict['mc_maze'][key])"]}],"nbformat":4,"nbformat_minor":2,"metadata":{"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":3},"orig_nbformat":4}} -------------------------------------------------------------------------------- /scripts/nlb_from_scratch.py: -------------------------------------------------------------------------------- 1 | #%% 2 | """ 3 | - Let's setup NDT for train/eval from scratch. 4 | - (Assumes your device is reasonably pytorch/GPU compatible) 5 | - This is an interactive python script run via vscode. If you'd like to run as a notebook 6 | 7 | Run to setup requirements: 8 | ``` 9 | Making a new env 10 | - python 3.7 11 | - conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch 12 | - conda install seaborn 13 | - pip install yacs pytorch_transformers tensorboard "ray[tune]" sklearn 14 | - pip install dandi "pynwb>=2.0.0" 15 | Or from nlb.yaml 16 | 17 | Then, 18 | - conda develop ~/path/to/nlb_tools 19 | ``` 20 | 21 | Install NLB dataset(s), and create the h5s for training. 
(Here we install MC_Maze_Small) 22 | ``` 23 | pip install dandi 24 | dandi download DANDI:000140/0.220113.0408 25 | ``` 26 | 27 | This largely follows the `basic_example.ipynb` in `nlb_tools`. The only distinction is that we save to an h5. 28 | """ 29 | 30 | from nlb_tools.nwb_interface import NWBDataset 31 | from nlb_tools.make_tensors import ( 32 | make_train_input_tensors, make_eval_input_tensors, make_eval_target_tensors, save_to_h5 33 | ) 34 | from nlb_tools.evaluation import evaluate 35 | 36 | import numpy as np 37 | import pandas as pd 38 | import h5py 39 | 40 | import logging 41 | logging.basicConfig(level=logging.INFO) 42 | 43 | dataset_name = 'mc_maze_small' 44 | datapath = '/home/joelye/user_data/nlb/000140/sub-Jenkins/' 45 | dataset = NWBDataset(datapath) 46 | 47 | # Prepare dataset 48 | phase = 'val' 49 | 50 | # Choose bin width and resample 51 | bin_width = 5 52 | dataset.resample(bin_width) 53 | 54 | # Create suffix for group naming later 55 | suffix = '' if (bin_width == 5) else f'_{int(bin_width)}' 56 | 57 | train_split = 'train' if (phase == 'val') else ['train', 'val'] 58 | train_dict = make_train_input_tensors( 59 | dataset, dataset_name=dataset_name, trial_split=train_split, save_file=False, 60 | include_behavior=True, 61 | include_forward_pred = True, 62 | ) 63 | 64 | # Show fields of returned dict 65 | print(train_dict.keys()) 66 | 67 | # Unpack data 68 | train_spikes_heldin = train_dict['train_spikes_heldin'] 69 | train_spikes_heldout = train_dict['train_spikes_heldout'] 70 | 71 | # Print 3d array shape: trials x time x channel 72 | print(train_spikes_heldin.shape) 73 | 74 | ## Make eval data (i.e. val) 75 | 76 | # Split for evaluation is same as phase name 77 | eval_split = phase 78 | # eval_dict = make_eval_input_tensors( 79 | # dataset, dataset_name=dataset_name, trial_split=eval_split, save_file=False, 80 | # ) 81 | # Make data tensors - use all chunks including forward prediction for training NDT 82 | eval_dict = make_train_input_tensors( 83 | dataset, dataset_name=dataset_name, trial_split=['val'], save_file=False, include_forward_pred=True, 84 | ) 85 | eval_dict = { 86 | f'eval{key[5:]}': val for key, val in eval_dict.items() 87 | } 88 | eval_spikes_heldin = eval_dict['eval_spikes_heldin'] 89 | 90 | print(eval_spikes_heldin.shape) 91 | 92 | h5_dict = { 93 | **train_dict, 94 | **eval_dict 95 | } 96 | 97 | h5_target = '/home/joelye/user_data/nlb/mc_maze_small.h5' 98 | save_to_h5(h5_dict, h5_target, overwrite=True) 99 | 100 | 101 | #%% 102 | """ 103 | - At this point we should be able to train a basic model. 104 | - In CLI, run a training call, replacing the appropriate paths 105 | ``` 106 | ./scripts/train.sh mc_maze_small_from_scratch 107 | 108 | OR 109 | 110 | python ray_random.py -e ./configs/mc_maze_small_from_scratch.yaml 111 | (CLI overrides aren't available here, so make another config file) 112 | ``` 113 | - Once this is done training (~0.5hr for non-search), let's load the results... 
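- (Hedged aside on the "make another config file" step above) the repo ships
  `configs/mc_maze_small_from_scratch.yaml`; if you were writing your own variant config,
  it is just a yacs override of `src/config/default.py`, so a minimal sketch might look like
  (the name and values below echo the paths used earlier in this script, not canonical settings):
```
    VARIANT: my_mc_maze_small_variant
    DATA:
      DATAPATH: /home/joelye/user_data/nlb/
      TRAIN_FILENAME: mc_maze_small.h5
```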
114 | """ 115 | import os 116 | import os.path as osp 117 | from pathlib import Path 118 | import sys 119 | 120 | # Add ndt src if not in path 121 | module_path = osp.abspath(osp.join('..')) 122 | if module_path not in sys.path: 123 | sys.path.append(module_path) 124 | 125 | import time 126 | import h5py 127 | import matplotlib.pyplot as plt 128 | import seaborn as sns 129 | import numpy as np 130 | import torch 131 | import torch.nn.functional as f 132 | from torch.utils import data 133 | from ray import tune 134 | 135 | from nlb_tools.evaluation import evaluate 136 | 137 | from src.run import prepare_config 138 | from src.runner import Runner 139 | from src.dataset import SpikesDataset, DATASET_MODES 140 | from analyze_utils import init_by_ckpt 141 | 142 | variant = "mc_maze_small_from_scratch" 143 | 144 | is_ray = True 145 | # is_ray = False 146 | 147 | if is_ray: 148 | tune_dir = f"/home/joelye/user_data/nlb/ndt_runs/ray/{variant}" 149 | df = tune.ExperimentAnalysis(tune_dir).dataframe() 150 | # ckpt_path = f"/home/joelye/user_data/nlb/ndt_runs/ray/{variant}/best_model/ckpts/{variant}.lve.pth" 151 | ckpt_dir = df.loc[df["best_masked_loss"].idxmin()].logdir 152 | ckpt_path = f"{ckpt_dir}/ckpts/{variant}.lve.pth" 153 | runner, spikes, rates, heldout_spikes, forward_spikes = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val) 154 | else: 155 | ckpt_dir = Path("/home/joelye/user_data/nlb/ndt_runs/") 156 | ckpt_path = ckpt_dir / variant / f"{variant}.lve.pth" 157 | runner, spikes, rates, heldout_spikes, forward_spikes = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val) 158 | 159 | eval_rates, _ = runner.get_rates( 160 | checkpoint_path=ckpt_path, 161 | save_path = None, 162 | mode = DATASET_MODES.val 163 | ) 164 | train_rates, _ = runner.get_rates( 165 | checkpoint_path=ckpt_path, 166 | save_path = None, 167 | mode = DATASET_MODES.train 168 | ) 169 | 170 | # * Val 171 | eval_rates, eval_rates_forward = torch.split(eval_rates, [spikes.size(1), eval_rates.size(1) - spikes.size(1)], 1) 172 | eval_rates_heldin_forward, eval_rates_heldout_forward = torch.split(eval_rates_forward, [spikes.size(-1), heldout_spikes.size(-1)], -1) 173 | train_rates, _ = torch.split(train_rates, [spikes.size(1), train_rates.size(1) - spikes.size(1)], 1) 174 | eval_rates_heldin, eval_rates_heldout = torch.split(eval_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1) 175 | train_rates_heldin, train_rates_heldout = torch.split(train_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1) 176 | #%% 177 | # * Viz some trials 178 | trials = [1, 2, 3] 179 | neuron = 10 180 | trial_rates = train_rates_heldout[trials, :, neuron].cpu() 181 | trial_spikes = heldout_spikes[trials, :, neuron].cpu() 182 | trial_time = heldout_spikes.size(1) 183 | 184 | spike_level = 0.05 185 | f, axes = plt.subplots(2, figsize=(6, 4)) 186 | times = np.arange(0, trial_time * 0.05, 0.05) 187 | for trial_index in range(trial_spikes.size(0)): 188 | spike_times, = np.where(trial_spikes[trial_index].numpy()) 189 | spike_times = spike_times * 0.05 190 | axes[0].scatter(spike_times, spike_level * (trial_index + 1)*np.ones_like(spike_times), marker='|', label='Spikes', s=30) 191 | axes[1].plot(times, trial_rates[trial_index].exp()) 192 | 193 | #%% 194 | # Submission e.g. 
as in `basic_example.ipynb` 195 | # Looks like this model is pretty terrible :/ 196 | output_dict = { 197 | dataset_name + suffix: { 198 | 'train_rates_heldin': train_rates_heldin.cpu().numpy(), 199 | 'train_rates_heldout': train_rates_heldout.cpu().numpy(), 200 | 'eval_rates_heldin': eval_rates_heldin.cpu().numpy(), 201 | 'eval_rates_heldout': eval_rates_heldout.cpu().numpy(), 202 | # 'eval_rates_heldin_forward': eval_rates_heldin_forward.cpu().numpy(), 203 | # 'eval_rates_heldout_forward': eval_rates_heldout_forward.cpu().numpy() 204 | } 205 | } 206 | 207 | # Reset logging level to hide excessive info messages 208 | logging.getLogger().setLevel(logging.WARNING) 209 | 210 | # If 'val' phase, make the target data 211 | if phase == 'val': 212 | # Note that the RTT task is not well suited to trial averaging, so PSTHs are not made for it 213 | target_dict = make_eval_target_tensors(dataset, dataset_name=dataset_name, train_trial_split='train', eval_trial_split='val', include_psth=True, save_file=False) 214 | 215 | # Demonstrate target_dict structure 216 | print(target_dict.keys()) 217 | print(target_dict[dataset_name + suffix].keys()) 218 | 219 | # Set logging level again 220 | logging.getLogger().setLevel(logging.INFO) 221 | 222 | if phase == 'val': 223 | print(evaluate(target_dict, output_dict)) 224 | 225 | # e.g. with targets to compare to 226 | # target_dict = torch.load(f'/snel/home/joely/tmp/{variant}_target.pth') 227 | # target_dict = np.load(f'/snel/home/joely/tmp/{variant}_target.npy', allow_pickle=True).item() 228 | 229 | # print(evaluate(target_dict, output_dict)) 230 | 231 | # e.g. to upload to EvalAI 232 | # with h5py.File('ndt_maze_preds.h5', 'w') as f: 233 | # group = f.create_group('mc_maze') 234 | # for key in output_dict['mc_maze']: 235 | # group.create_dataset(key, data=output_dict['mc_maze'][key]) -------------------------------------------------------------------------------- /scripts/record_all_rates.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Joel Ye 3 | 4 | # Notebook for interactive model evaluation/analysis 5 | # Allows us to interrogate model on variable data (instead of masked sample again) 6 | 7 | #%% 8 | import os 9 | import os.path as osp 10 | from pathlib import Path 11 | import sys 12 | module_path = os.path.abspath(os.path.join('..')) 13 | if module_path not in sys.path: 14 | sys.path.append(module_path) 15 | 16 | import time 17 | import h5py 18 | import matplotlib.pyplot as plt 19 | import seaborn as sns 20 | import numpy as np 21 | import torch 22 | import torch.nn.functional as f 23 | from torch.utils import data 24 | 25 | from src.run import prepare_config 26 | from src.runner import Runner 27 | from src.dataset import SpikesDataset, DATASET_MODES 28 | from src.mask import UNMASKED_LABEL 29 | 30 | from analyze_utils import init_by_ckpt 31 | 32 | grid = False 33 | grid = True 34 | 35 | variant = "m700_2296-s1" 36 | 37 | val_errs = {} 38 | def save_rates(ckpt_path, handle): 39 | runner, spikes, rates = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val) 40 | if "maze" in variant or "m700" in variant: 41 | runner.config.defrost() 42 | runner.config.DATA.DATAPATH = "/snel/share/data/ndt_paper/m1_maze/heldout_trial/2296_trials/0_seed" 43 | runner.config.freeze() 44 | rate_output_pth = f"/snel/share/joel/ndt_rates/psth_match" 45 | rate_output_pth = osp.join(rate_output_pth, "grid" if grid else "pbt") 46 | rate_output_fn = f"{handle}_{variant}.h5" 47 | val_errs[handle] = 
runner.best_val['value'].cpu().numpy() 48 | 49 | 50 | sweep_dir = Path("/snel/home/joely/ray_results/ndt/") 51 | if grid: 52 | sweep_dir = sweep_dir.joinpath("gridsearch") 53 | sweep_dir = sweep_dir.joinpath(variant, variant) 54 | for run_dir in sweep_dir.glob("tuneNDT*/"): 55 | run_ckpt_path = run_dir.joinpath(f'ckpts/{variant}.lve.pth') 56 | handle = run_dir.parts[-1][:10] 57 | save_rates(run_ckpt_path, handle) 58 | 59 | #%% 60 | torch.save(val_errs, f'{variant}_val_errs_sweep.pth') -------------------------------------------------------------------------------- /scripts/scratch.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import h5py 3 | path_old = '/home/joelye/user_data/nlb/ndt_old/mc_maze_small.h5' 4 | info_old = h5py.File(path_old) 5 | path_new = '/home/joelye/user_data/nlb/mc_maze_small.h5' 6 | info_new = h5py.File(path_new) 7 | path_old_eval = '/home/joelye/user_data/nlb/ndt_old/eval_data_val.h5' 8 | info_eval = h5py.File(path_old_eval) 9 | #%% 10 | # print(info_old.keys()) 11 | print(info_old.keys()) 12 | print(info_new.keys()) 13 | for key in info_old: 14 | new_key = key.split('_') 15 | new_key[1] = 'spikes' 16 | new_key = '_'.join(new_key) 17 | if new_key in info_new: 18 | print(key, (info_old[key][:] == info_new[new_key][:]).all()) 19 | else: 20 | print(new_key, ' missing') 21 | # print((info_old['train_data_heldin'][:] == info_new['train_spikes_heldin'][:]).all()) 22 | # print(info_new['train_spikes_heldin'][:].std()) 23 | # print(info_old['train_data_heldout'][:].std()) 24 | # print(info_new['train_spikes_heldout'][:].std()) 25 | # print(info_old['eval_data_heldin'][:].std()) 26 | 27 | # print((info_old['eval_data_heldout'][:] == info_new['eval_spikes_heldout'][:]).all()) 28 | # print(info_eval['mc_maze_small']['eval_spikes_heldout'][:].std()) 29 | 30 | #%% 31 | #%% 32 | import torch 33 | scratch_payload = torch.load('scratch_rates.pth') 34 | old_payload = torch.load('old_rates.pth') 35 | 36 | print((scratch_payload['spikes'] == old_payload['spikes']).all()) 37 | print((scratch_payload['rates'] == old_payload['rates']).all()) 38 | print((scratch_payload['labels'] == old_payload['labels']).all()) 39 | print(scratch_payload['labels'].sum()) 40 | print(old_payload['labels'].sum()) 41 | print(scratch_payload['loss'].mean(), old_payload['loss'].mean()) -------------------------------------------------------------------------------- /scripts/simple_ci.py: -------------------------------------------------------------------------------- 1 | # Calculate some confidence intervals 2 | 3 | #%% 4 | import os 5 | import os.path as osp 6 | 7 | import sys 8 | module_path = os.path.abspath(os.path.join('..')) 9 | if module_path not in sys.path: 10 | sys.path.append(module_path) 11 | 12 | import time 13 | import h5py 14 | import matplotlib.pyplot as plt 15 | import seaborn as sns 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as f 19 | from torch.utils import data 20 | 21 | from src.run import prepare_config 22 | from src.runner import Runner 23 | from src.dataset import SpikesDataset, DATASET_MODES 24 | from src.mask import UNMASKED_LABEL 25 | 26 | from analyze_utils import init_by_ckpt 27 | 28 | grid = True 29 | ckpt = "Grid" 30 | best_model = "best_model" 31 | best_model = "best_model_unmasked" 32 | lve = "lfve" if "unmasked" in best_model else "lve" 33 | r2s = [] 34 | mnlls = [] 35 | for i in range(3): 36 | variant = f"lorenz-s{i+1}" 37 | # variant = f"lorenz_lite-s{i+1}" 38 | # variant = f"chaotic-s{i+1}" 39 | 
# variant = f"chaotic_lite-s{i+1}" 40 | ckpt_path = f"/snel/home/joely/ray_results/ndt/gridsearch/{variant}/{best_model}/ckpts/{variant}.{lve}.pth" 41 | runner, spikes, rates = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val) 42 | 43 | # print(runner.config.MODEL.CONTEXT_FORWARD) 44 | # print(runner.config.MODEL.CONTEXT_BACKWARD) 45 | # print(runner.config.TRAIN.MASK_MAX_SPAN) 46 | ( 47 | unmasked_loss, 48 | pred_rates, 49 | layer_outputs, 50 | attn_weights, 51 | attn_list, 52 | ) = runner.model(spikes, mask_labels=spikes, return_weights=True) 53 | print(f"Best Unmasked Val: {runner.best_unmasked_val}") # .37763 54 | print(f"Best Masked Val: {runner.best_val}") # ` best val is .380 at 1302... 55 | mnlls.append(runner.best_val['value']) 56 | print(f"Best R2: {runner.best_R2}") # best val is .300 at 1302... 57 | 58 | print(f"Unmasked: {unmasked_loss}") 59 | 60 | if "maze" not in variant: 61 | r2 = runner.neuron_r2(rates, pred_rates, flatten=True) 62 | # r2 = runner.neuron_r2(rates, pred_rates) 63 | r2s.append(r2) 64 | vaf = runner.neuron_vaf(rates, pred_rates) 65 | print(f"R2:\t{r2}, VAF:\t{vaf}") 66 | 67 | trials, time, num_neurons = rates.size() 68 | print(runner.count_updates) 69 | print(runner.config.MODEL.EMBED_DIM) 70 | 71 | #%% 72 | import math 73 | # print(sum(mnlls) / 3) 74 | # print(mnlls) 75 | def print_ci(r2s): 76 | r2s = np.array(r2s) 77 | mean = r2s.mean() 78 | ci = r2s.std() * 1.96 / math.sqrt(3) 79 | print(f"{mean:.3f} \pm {ci:.3f}") 80 | print_ci(r2s) 81 | # print_ci([0.2079, 0.2077, 0.2091]) 82 | # print_ci([ 83 | # 0.9225, 0.9113, 0.9183 84 | # ]) 85 | # print_ci([0.8712, 0.8687, 0.8664]) 86 | # print_ci([0.9255, 0.9271, 0.9096]) 87 | # print_ci([.924, .914, .9174]) 88 | # print_ci([.0496, .0095, .0054]) 89 | # print_ci([.4003, .0382, .4014]) 90 | # print_ci([.52, .4469, .6242]) 91 | print_ci([.416, .0388, .4221]) 92 | print_ci([.0516, .0074, .0062]) 93 | print_ci([ 94 | 0.5184, 0.4567, 0.6724 95 | ]) 96 | 97 | #%% 98 | # Port for rds analysis 99 | if "maze" in variant: 100 | rate_output_pth = f"/snel/share/joel/ndt_rates/" 101 | rate_output_pth = osp.join(rate_output_pth, "grid" if grid else "pbt") 102 | rate_output_fn = f"{variant}.h5" 103 | pred_rates, layer_outputs = runner.get_rates( 104 | checkpoint_path=ckpt_path, 105 | save_path=osp.join(rate_output_pth, rate_output_fn) 106 | ) 107 | trials, time, num_neurons = pred_rates.size() 108 | 109 | 110 | #%% 111 | 112 | ( 113 | unmasked_loss, 114 | pred_rates, 115 | layer_outputs, 116 | attn_weights, 117 | attn_list, 118 | ) = runner.model(spikes, mask_labels=spikes, return_weights=True) 119 | print(f"Unmasked: {unmasked_loss}") 120 | 121 | if "maze" not in variant: 122 | r2 = runner.neuron_r2(rates, pred_rates) 123 | vaf = runner.neuron_vaf(rates, pred_rates) 124 | print(f"R2:\t{r2}, VAF:\t{vaf}") 125 | 126 | trials, time, num_neurons = rates.size() 127 | print(f"Best Masked Val: {runner.best_val}") # ` best val is .380 at 1302... 128 | print(f"Best Unmasked Val: {runner.best_unmasked_val}") # .37763 129 | print(f"Best R2: {runner.best_R2}") # best val is .300 at 1302... 
130 | print(runner.count_updates) 131 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ $# -eq 1 ]] 4 | then 5 | python -u src/run.py --run-type train --exp-config configs/$1.yaml 6 | elif [[ $# -eq 2 ]] 7 | then 8 | python -u src/run.py --run-type train --exp-config configs/$1.yaml --ckpt-path $1.$2.pth 9 | else 10 | echo "Expected args (ckpt)" 11 | fi -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Joel Ye 3 | 4 | from src.logger_wrapper import create_logger 5 | from src.model_registry import ( 6 | get_model_class, is_learning_model, is_input_masked_model 7 | ) 8 | from src.tb_wrapper import TensorboardWriter 9 | 10 | __all__ = [ 11 | "get_model_class", 12 | "is_learning_model", 13 | "is_input_masked_model", 14 | "create_logger", 15 | "TensorboardWriter" 16 | ] -------------------------------------------------------------------------------- /src/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/neural-data-transformers/98dd85a24885ffb76adfeed0c2a89d3ea3ecf9d1/src/config/__init__.py -------------------------------------------------------------------------------- /src/config/default.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Joel Ye 3 | 4 | from typing import List, Optional, Union 5 | 6 | from yacs.config import CfgNode as CN 7 | 8 | DEFAULT_CONFIG_DIR = "config/" 9 | CONFIG_FILE_SEPARATOR = "," 10 | 11 | # ----------------------------------------------------------------------------- 12 | # Config definition 13 | # ----------------------------------------------------------------------------- 14 | _C = CN() 15 | _C.SEED = 100 16 | 17 | # Name of experiment 18 | _C.VARIANT = "experiment" 19 | _C.USE_TENSORBOARD = True 20 | _C.TENSORBOARD_DIR = "tb/" 21 | _C.CHECKPOINT_DIR = "ckpts/" 22 | _C.LOG_DIR = "logs/" 23 | 24 | # ----------------------------------------------------------------------------- 25 | # System 26 | # ----------------------------------------------------------------------------- 27 | _C.SYSTEM = CN() 28 | _C.SYSTEM.TORCH_GPU_ID = 0 29 | _C.SYSTEM.GPU_AUTO_ASSIGN = True # Auto-assign 30 | _C.SYSTEM.NUM_GPUS = 1 31 | 32 | # ----------------------------------------------------------------------------- 33 | # Data 34 | # ----------------------------------------------------------------------------- 35 | _C.DATA = CN() 36 | _C.DATA.DATAPATH = 'data/' 37 | _C.DATA.TRAIN_FILENAME = 'train.pth' 38 | _C.DATA.VAL_FILENAME = 'val.pth' 39 | _C.DATA.TEST_FILENAME = 'test.pth' 40 | _C.DATA.OVERFIT_TEST = False 41 | _C.DATA.RANDOM_SUBSET_TRIALS = 1.0 # Testing how NDT performs on a variety of dataset sizes 42 | 43 | _C.DATA.LOG_EPSILON = 1e-7 # prevent -inf if we use logrates 44 | # _C.DATA.IGNORE_FORWARD = True # Ignore forward prediction even if it's available in train. Useful if we don't have forward spikes in validation. (system misbehaves if we only have train...) 45 | # Performance with above seems subpar, i.e. 
we need forward and heldout together for some reason 46 | _C.DATA.IGNORE_FORWARD = False 47 | 48 | # ----------------------------------------------------------------------------- 49 | # Model 50 | # ----------------------------------------------------------------------------- 51 | _C.MODEL = CN() 52 | _C.MODEL.NAME = "NeuralDataTransformer" 53 | _C.MODEL.TRIAL_LENGTH = -1 # -1 represents "as much as available" # ! not actually supported in model yet... 54 | _C.MODEL.CONTEXT_FORWARD = 4 # -1 represents "as much as available" 55 | _C.MODEL.CONTEXT_BACKWARD = 8 # -1 represents "as much as available" 56 | _C.MODEL.CONTEXT_WRAP_INITIAL = False 57 | _C.MODEL.FULL_CONTEXT = False # Ignores CONTEXT_FORWARD and CONTEXT_BACKWARD if True (-1 for both) 58 | _C.MODEL.UNMASKED_LOSS_SCALE = 0.0 # Relative scale for predicting unmasked spikes (kinda silly - deprecated) 59 | _C.MODEL.HIDDEN_SIZE = 128 # Generic hidden size, used as default 60 | _C.MODEL.DROPOUT = .1 # Catch all 61 | _C.MODEL.DROPOUT_RATES = 0.2 # Specific for rates 62 | _C.MODEL.DROPOUT_EMBEDDING = 0.2 # Dropout Population Activity pre-transformer 63 | _C.MODEL.NUM_HEADS = 2 64 | _C.MODEL.NUM_LAYERS = 6 65 | _C.MODEL.ACTIVATION = "relu" # "gelu" 66 | _C.MODEL.LINEAR_EMBEDDER = False # Use linear layer instead of embedding layer 67 | _C.MODEL.EMBED_DIM = 2 # this greatly affects model size btw 68 | _C.MODEL.LEARNABLE_POSITION = False 69 | _C.MODEL.MAX_SPIKE_COUNT = 20 70 | _C.MODEL.REQUIRES_RATES = False 71 | _C.MODEL.LOGRATE = True # If true, we operate in lograte, and assume rates from data are logrates. Only for R2 do we exp 72 | _C.MODEL.SPIKE_LOG_INIT = False # If true, init spike embeddings as a 0 centered linear sequence 73 | _C.MODEL.FIXUP_INIT = False 74 | _C.MODEL.PRE_NORM = False # per transformers without tears 75 | _C.MODEL.SCALE_NORM = False # per transformers without tears 76 | 77 | _C.MODEL.DECODER = CN() 78 | _C.MODEL.DECODER.LAYERS = 1 79 | 80 | _C.MODEL.LOSS = CN() 81 | _C.MODEL.LOSS.TYPE = "poisson" # ["cel", "poisson"] 82 | _C.MODEL.LOSS.TOPK = 1.0 # In case we're neglecting some neurons, focus on them 83 | 84 | _C.MODEL.POSITION = CN() 85 | _C.MODEL.POSITION.OFFSET = True 86 | # ----------------------------------------------------------------------------- 87 | # Train Config 88 | # ----------------------------------------------------------------------------- 89 | _C.TRAIN = CN() 90 | 91 | _C.TRAIN.DO_VAL = True # Run validation while training 92 | _C.TRAIN.DO_R2 = True # Run validation while training 93 | 94 | _C.TRAIN.BATCH_SIZE = 64 95 | _C.TRAIN.NUM_UPDATES = 10000 # Max updates 96 | _C.TRAIN.MAX_GRAD_NORM = 200.0 97 | _C.TRAIN.USE_ZERO_MASK = True 98 | _C.TRAIN.MASK_RATIO = 0.2 99 | _C.TRAIN.MASK_TOKEN_RATIO = 1.0 # We don't need this if we use zero mask 100 | _C.TRAIN.MASK_RANDOM_RATIO = 0.5 # Of the non-replaced, what percentage should be random? 
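# Worked example of how the masking knobs above compose (see src/mask.py, Masker.mask_batch):
# with MASK_RATIO = 0.2 and MASK_MODE = "timestep", roughly 20% of timesteps in each trial are
# selected for masking. A MASK_TOKEN_RATIO fraction of those positions is overwritten with the
# mask value (zeros when USE_ZERO_MASK); only the masked positions *not* replaced that way are
# then eligible for random spike-count replacement at rate MASK_RANDOM_RATIO. With the default
# MASK_TOKEN_RATIO = 1.0, every masked position is zeroed and the random-replacement branch is
# effectively inactive.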
101 | _C.TRAIN.MASK_MODE = "timestep" # ["full", "timestep"] 102 | _C.TRAIN.MASK_MAX_SPAN = 1 103 | _C.TRAIN.MASK_SPAN_RAMP_START = 600 104 | _C.TRAIN.MASK_SPAN_RAMP_END = 1200 105 | 106 | _C.TRAIN.LR = CN() 107 | _C.TRAIN.LR.INIT = 1e-3 108 | _C.TRAIN.LR.SCHEDULE = True 109 | _C.TRAIN.LR.SCHEDULER = "cosine" # invsqrt 110 | _C.TRAIN.LR.WARMUP = 1000 # Mostly decay 111 | _C.TRAIN.WEIGHT_DECAY = 0.0 112 | _C.TRAIN.EPS = 1e-8 113 | _C.TRAIN.PATIENCE = 750 # For early stopping (be generous, our loss steps) 114 | 115 | _C.TRAIN.CHECKPOINT_INTERVAL = 1000 116 | _C.TRAIN.LOG_INTERVAL = 50 117 | _C.TRAIN.VAL_INTERVAL = 10 # Val less often so things run faster 118 | 119 | _C.TRAIN.TUNE_MODE = False 120 | _C.TRAIN.TUNE_EPOCHS_PER_GENERATION = 500 121 | _C.TRAIN.TUNE_HP_JSON = "./lorenz_pbt.json" 122 | _C.TRAIN.TUNE_WARMUP = 0 123 | _C.TRAIN.TUNE_METRIC = "smth_masked_loss" 124 | # JSON schema - flattened config dict. Each entry has info to construct a hyperparam. 125 | 126 | def get_cfg_defaults(): 127 | """Get default LFADS config (yacs config node).""" 128 | return _C.clone() 129 | 130 | def get_config( 131 | config_paths: Optional[Union[List[str], str]] = None, 132 | opts: Optional[list] = None, 133 | ) -> CN: 134 | r"""Create a unified config with default values overwritten by values from 135 | :p:`config_paths` and overwritten by options from :p:`opts`. 136 | 137 | :param config_paths: List of config paths or string that contains comma 138 | separated list of config paths. 139 | :param opts: Config options (keys, values) in a list (e.g., passed from 140 | command line into the config. For example, 141 | :py:`opts = ['FOO.BAR', 0.5]`. Argument can be used for parameter 142 | sweeping or quick tests. 143 | """ 144 | config = get_cfg_defaults() 145 | if config_paths: 146 | if isinstance(config_paths, str): 147 | if CONFIG_FILE_SEPARATOR in config_paths: 148 | config_paths = config_paths.split(CONFIG_FILE_SEPARATOR) 149 | else: 150 | config_paths = [config_paths] 151 | 152 | for config_path in config_paths: 153 | config.merge_from_file(config_path) 154 | 155 | if opts: 156 | config.merge_from_list(opts) 157 | 158 | config.freeze() 159 | return config 160 | 161 | 162 | # The flatten and unflatten snippets are from an internal lfads_tf2 implementation. 163 | 164 | def flatten(dictionary, level=[]): 165 | """ Flattens a dictionary by placing '.' between levels. 166 | 167 | This function flattens a hierarchical dictionary by placing '.' 168 | between keys at various levels to create a single key for each 169 | value. It is used internally for converting the configuration 170 | dictionary to more convenient formats. Implementation was 171 | inspired by `this StackOverflow post 172 | `_. 173 | 174 | Parameters 175 | ---------- 176 | dictionary : dict 177 | The hierarchical dictionary to be flattened. 178 | level : str, optional 179 | The string to append to the beginning of this dictionary, 180 | enabling recursive calls. By default, an empty string. 181 | 182 | Returns 183 | ------- 184 | dict 185 | The flattened dictionary. 186 | 187 | See Also 188 | -------- 189 | lfads_tf2.utils.unflatten : Performs the opposite of this operation. 190 | 191 | """ 192 | 193 | tmp_dict = {} 194 | for key, val in dictionary.items(): 195 | if type(val) == dict: 196 | tmp_dict.update(flatten(val, level + [key])) 197 | else: 198 | tmp_dict['.'.join(level + [key])] = val 199 | return tmp_dict 200 | 201 | 202 | def unflatten(dictionary): 203 | """ Unflattens a dictionary by splitting keys at '.'s. 
204 | 205 | This function unflattens a hierarchical dictionary by splitting 206 | its keys at '.'s. It is used internally for converting the 207 | configuration dictionary to more convenient formats. Implementation was 208 | inspired by `this StackOverflow post 209 | `_. 210 | 211 | Parameters 212 | ---------- 213 | dictionary : dict 214 | The flat dictionary to be unflattened. 215 | 216 | Returns 217 | ------- 218 | dict 219 | The unflattened dictionary. 220 | 221 | See Also 222 | -------- 223 | lfads_tf2.utils.flatten : Performs the opposite of this operation. 224 | 225 | """ 226 | 227 | resultDict = dict() 228 | for key, value in dictionary.items(): 229 | parts = key.split(".") 230 | d = resultDict 231 | for part in parts[:-1]: 232 | if part not in d: 233 | d[part] = dict() 234 | d = d[part] 235 | d[parts[-1]] = value 236 | return resultDict 237 | 238 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Joel Ye 3 | import os.path as osp 4 | 5 | import h5py 6 | import numpy as np 7 | import torch 8 | from torch.utils import data 9 | 10 | from src.utils import merge_train_valid 11 | 12 | class DATASET_MODES: 13 | train = "train" 14 | val = "val" 15 | test = "test" 16 | trainval = "trainval" 17 | 18 | class SpikesDataset(data.Dataset): 19 | r""" 20 | Dataset for single file of spike times (loads into memory) 21 | Lorenz data is NxTxH (H being number of neurons) - we load to T x N x H 22 | # ! Note that codepath for forward but not heldout neurons is not tested and likely broken 23 | """ 24 | 25 | def __init__(self, config, filename, mode=DATASET_MODES.train, logger=None): 26 | r""" 27 | args: 28 | config: dataset config 29 | filename: excluding path 30 | mode: used to extract the right indices from LFADS h5 data 31 | """ 32 | super().__init__() 33 | self.logger = logger 34 | if self.logger is not None: 35 | self.logger.info(f"Loading {filename} in {mode}") 36 | self.config = config.DATA 37 | self.use_lograte = config.MODEL.LOGRATE 38 | self.batch_size = config.TRAIN.BATCH_SIZE 39 | self.datapath = osp.join(config.DATA.DATAPATH, filename) 40 | split_path = self.datapath.split(".") 41 | 42 | self.has_rates = False 43 | self.has_heldout = False 44 | self.has_forward = False 45 | if len(split_path) == 1 or split_path[-1] == "h5": 46 | spikes, rates, heldout_spikes, forward_spikes = self.get_data_from_h5(mode, self.datapath) 47 | 48 | spikes = torch.tensor(spikes).long() 49 | if rates is not None: 50 | rates = torch.tensor(rates) 51 | if heldout_spikes is not None: 52 | self.has_heldout = True 53 | heldout_spikes = torch.tensor(heldout_spikes).long() 54 | if forward_spikes is not None and not config.DATA.IGNORE_FORWARD: 55 | self.has_forward = True 56 | forward_spikes = torch.tensor(forward_spikes).long() 57 | else: 58 | forward_spikes = None 59 | elif split_path[-1] == "pth": 60 | dataset_dict = torch.load(self.datapath) 61 | spikes = dataset_dict["spikes"] 62 | if "rates" in dataset_dict: 63 | self.has_rates = True 64 | rates = dataset_dict["rates"] 65 | heldout_spikes = None 66 | forward_spikes = None 67 | else: 68 | raise Exception(f"Unknown dataset extension {split_path[-1]}") 69 | 70 | self.num_trials, _, self.num_neurons = spikes.size() 71 | self.full_length = config.MODEL.TRIAL_LENGTH <= 0 72 | self.trial_length = spikes.size(1) if self.full_length else config.MODEL.TRIAL_LENGTH 73 | if self.has_heldout: 74 | 
self.num_neurons += heldout_spikes.size(-1) 75 | if self.has_forward: 76 | self.trial_length += forward_spikes.size(1) 77 | self.spikes = self.batchify(spikes) 78 | # Fake rates so we can skip None checks everywhere. Use `self.has_rates` when desired 79 | self.rates = self.batchify(rates) if self.has_rates else torch.zeros_like(spikes) 80 | # * else condition below is not precisely correctly shaped as correct shape isn't used 81 | self.heldout_spikes = self.batchify(heldout_spikes) if self.has_heldout else torch.zeros_like(spikes) 82 | self.forward_spikes = self.batchify(forward_spikes) if self.has_forward else torch.zeros_like(spikes) 83 | 84 | if config.DATA.OVERFIT_TEST: 85 | if self.logger is not None: 86 | self.logger.warning("Overfitting..") 87 | self.spikes = self.spikes[:2] 88 | self.rates = self.rates[:2] 89 | self.num_trials = 2 90 | elif hasattr(config.DATA, "RANDOM_SUBSET_TRIALS") and config.DATA.RANDOM_SUBSET_TRIALS < 1.0 and mode == DATASET_MODES.train: 91 | if self.logger is not None: 92 | self.logger.warning(f"!!!!! Training on {config.DATA.RANDOM_SUBSET_TRIALS} of the data with seed {config.SEED}.") 93 | reduced = int(self.num_trials * config.DATA.RANDOM_SUBSET_TRIALS) 94 | torch.random.manual_seed(config.SEED) 95 | random_subset = torch.randperm(self.num_trials)[:reduced] 96 | self.num_trials = reduced 97 | self.spikes = self.spikes[random_subset] 98 | self.rates = self.rates[random_subset] 99 | 100 | def batchify(self, x): 101 | r""" 102 | Chops data into uniform sizes as configured by trial_length. 103 | 104 | Returns: 105 | x reshaped as num_samples x trial_length x neurons 106 | """ 107 | if self.full_length: 108 | return x 109 | trial_time = x.size(1) 110 | samples_per_trial = trial_time // self.trial_length 111 | if trial_time % self.trial_length != 0: 112 | if self.logger is not None: 113 | self.logger.debug(f"Trimming dangling trial info. Data trial length {trial_time} \ 114 | is not divisible by asked length {self.trial_length})") 115 | x = x.narrow(1, 0, samples_per_trial * self.trial_length) 116 | 117 | # ! 
P sure this can be a view op 118 | # num_samples x trial_length x neurons 119 | return torch.cat(torch.split(x, self.trial_length, dim=1), dim=0) 120 | 121 | def get_num_neurons(self): 122 | return self.num_neurons 123 | 124 | def __len__(self): 125 | return self.spikes.size(0) 126 | 127 | def __getitem__(self, index): 128 | r""" 129 | Return spikes and rates, shaped T x N (num_neurons) 130 | """ 131 | return ( 132 | self.spikes[index], 133 | None if self.rates is None else self.rates[index], 134 | None if self.heldout_spikes is None else self.heldout_spikes[index], 135 | None if self.forward_spikes is None else self.forward_spikes[index] 136 | ) 137 | 138 | def get_dataset(self): 139 | return self.spikes, self.rates, self.heldout_spikes, self.forward_spikes 140 | 141 | def get_max_spikes(self): 142 | return self.spikes.max().item() 143 | 144 | def get_num_batches(self): 145 | return self.spikes.size(0) // self.batch_size 146 | 147 | def clip_spikes(self, max_val): 148 | self.spikes = torch.clamp(self.spikes, max=max_val) 149 | 150 | def get_data_from_h5(self, mode, filepath): 151 | r""" 152 | returns: 153 | spikes 154 | rates (None if not available) 155 | held out spikes (for cosmoothing, None if not available) 156 | * Note, rates and held out spikes codepaths conflict 157 | """ 158 | NLB_KEY = 'spikes' # curiously, old code thought NLB data keys came as "train_data_heldin" and not "train_spikes_heldin" 159 | NLB_KEY_ALT = 'data' 160 | 161 | with h5py.File(filepath, 'r') as h5file: 162 | h5dict = {key: h5file[key][()] for key in h5file.keys()} 163 | if f'eval_{NLB_KEY}_heldin' not in h5dict: # double check 164 | if f'eval_{NLB_KEY_ALT}_heldin' in h5dict: 165 | NLB_KEY = NLB_KEY_ALT 166 | if f'eval_{NLB_KEY}_heldin' in h5dict: # NLB data, presumes both heldout neurons and time are available 167 | get_key = lambda key: h5dict[key].astype(np.float32) 168 | train_data = get_key(f'train_{NLB_KEY}_heldin') 169 | train_data_fp = get_key(f'train_{NLB_KEY}_heldin_forward') 170 | train_data_heldout_fp = get_key(f'train_{NLB_KEY}_heldout_forward') 171 | train_data_all_fp = np.concatenate([train_data_fp, train_data_heldout_fp], -1) 172 | valid_data = get_key(f'eval_{NLB_KEY}_heldin') 173 | train_data_heldout = get_key(f'train_{NLB_KEY}_heldout') 174 | if f'eval_{NLB_KEY}_heldout' in h5dict: 175 | valid_data_heldout = get_key(f'eval_{NLB_KEY}_heldout') 176 | else: 177 | self.logger.warn('Substituting zero array for heldout neurons. Only done for evaluating models locally, i.e. will disrupt training due to early stopping.') 178 | valid_data_heldout = np.zeros((valid_data.shape[0], valid_data.shape[1], train_data_heldout.shape[2]), dtype=np.float32) 179 | if f'eval_{NLB_KEY}_heldin_forward' in h5dict: 180 | valid_data_fp = get_key(f'eval_{NLB_KEY}_heldin_forward') 181 | valid_data_heldout_fp = get_key(f'eval_{NLB_KEY}_heldout_forward') 182 | valid_data_all_fp = np.concatenate([valid_data_fp, valid_data_heldout_fp], -1) 183 | else: 184 | self.logger.warn('Substituting zero array for heldout forward neurons. Only done for evaluating models locally, i.e. 
will disrupt training due to early stopping.') 185 | valid_data_all_fp = np.zeros( 186 | (valid_data.shape[0], train_data_fp.shape[1], valid_data.shape[2] + valid_data_heldout.shape[2]), dtype=np.float32 187 | ) 188 | 189 | # NLB data does not have ground truth rates 190 | if mode == DATASET_MODES.train: 191 | return train_data, None, train_data_heldout, train_data_all_fp 192 | elif mode == DATASET_MODES.val: 193 | return valid_data, None, valid_data_heldout, valid_data_all_fp 194 | train_data = h5dict['train_data'].astype(np.float32).squeeze() 195 | valid_data = h5dict['valid_data'].astype(np.float32).squeeze() 196 | train_rates = None 197 | valid_rates = None 198 | if "train_truth" in h5dict and "valid_truth" in h5dict: # original LFADS-type datasets 199 | self.has_rates = True 200 | train_rates = h5dict['train_truth'].astype(np.float32) 201 | valid_rates = h5dict['valid_truth'].astype(np.float32) 202 | train_rates = train_rates / h5dict['conversion_factor'] 203 | valid_rates = valid_rates / h5dict['conversion_factor'] 204 | if self.use_lograte: 205 | train_rates = torch.log(torch.tensor(train_rates) + self.config.LOG_EPSILON) 206 | valid_rates = torch.log(torch.tensor(valid_rates) + self.config.LOG_EPSILON) 207 | if mode == DATASET_MODES.train: 208 | return train_data, train_rates, None, None 209 | elif mode == DATASET_MODES.val: 210 | return valid_data, valid_rates, None, None 211 | elif mode == DATASET_MODES.trainval: 212 | # merge training and validation data 213 | if 'train_inds' in h5dict and 'valid_inds' in h5dict: 214 | # if there are index labels, use them to reassemble full data 215 | train_inds = h5dict['train_inds'].squeeze() 216 | valid_inds = h5dict['valid_inds'].squeeze() 217 | file_data = merge_train_valid( 218 | train_data, valid_data, train_inds, valid_inds) 219 | if self.has_rates: 220 | merged_rates = merge_train_valid( 221 | train_rates, valid_rates, train_inds, valid_inds 222 | ) 223 | else: 224 | if self.logger is not None: 225 | self.logger.info("No indices found for merge.
" 226 | "Concatenating training and validation samples.") 227 | file_data = np.concatenate([train_data, valid_data], axis=0) 228 | if self.has_rates: 229 | merged_rates = np.concatenate([train_rates, valid_rates], axis=0) 230 | return file_data, merged_rates if self.has_rates else None, None, None 231 | else: # test unsupported 232 | return None, None, None, None -------------------------------------------------------------------------------- /src/logger_wrapper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Adapted from Facebook Habitat Framework 4 | 5 | import logging 6 | import copy 7 | 8 | class Logger(logging.Logger): 9 | def __init__( 10 | self, 11 | name, 12 | level, 13 | filename=None, 14 | filemode="a", 15 | stream=None, 16 | format=None, 17 | dateformat=None, 18 | style="%", 19 | ): 20 | super().__init__(name, level) 21 | if filename is not None: 22 | handler = logging.FileHandler(filename, filemode) 23 | else: 24 | handler = logging.StreamHandler(stream) 25 | self._formatter = logging.Formatter(format, dateformat, style) 26 | handler.setFormatter(self._formatter) 27 | super().addHandler(handler) 28 | self.stat_queue = [] # Going to be tuples 29 | 30 | def clear_filehandlers(self): 31 | self.handlers = [h for h in self.handlers if not isinstance(h, logging.FileHandler)] 32 | 33 | def clear_streamhandlers(self): 34 | self.handlers = [h for h in self.handlers if (not isinstance(h, logging.StreamHandler) or isinstance(h, logging.FileHandler))] 35 | 36 | def add_filehandler(self, log_filename): 37 | filehandler = logging.FileHandler(log_filename) 38 | filehandler.setFormatter(self._formatter) 39 | self.addHandler(filehandler) 40 | 41 | def queue_stat(self, stat_name, stat): 42 | self.stat_queue.append((stat_name, stat)) 43 | 44 | def empty_queue(self): 45 | queue = copy.deepcopy(self.stat_queue) 46 | self.stat_queue = [] 47 | return queue 48 | 49 | def log_update(self, update): 50 | stat_str = "\t".join([f"{stat[0]}: {stat[1]:.3f}" for stat in self.empty_queue()]) 51 | self.info("update: {}\t{}".format(update, stat_str)) 52 | 53 | def mute(self): 54 | self.setLevel(logging.ERROR) 55 | 56 | def unmute(self): 57 | self.setLevel(logging.INFO) 58 | 59 | def create_logger(): 60 | return Logger( 61 | name="NDT", level=logging.INFO, format="%(asctime)-15s %(message)s" 62 | ) 63 | 64 | __all__ = ["create_logger"] 65 | -------------------------------------------------------------------------------- /src/mask.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Joel Ye 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | # Some infeasibly high spike count 9 | DEFAULT_MASK_VAL = 30 10 | UNMASKED_LABEL = -100 11 | SUPPORTED_MODES = ["full", "timestep", "neuron", "timestep_only"] 12 | 13 | # Use a class so we can cache random mask 14 | class Masker: 15 | 16 | def __init__(self, train_cfg, device): 17 | self.update_config(train_cfg) 18 | if self.cfg.MASK_MODE not in SUPPORTED_MODES: 19 | raise Exception(f"Given {self.cfg.MASK_MODE} not in supported {SUPPORTED_MODES}") 20 | self.device = device 21 | 22 | def update_config(self, config): 23 | self.cfg = config 24 | self.prob_mask = None 25 | 26 | def expand_mask(self, mask, width): 27 | r""" 28 | args: 29 | mask: N x T 30 | width: expansion block size 31 | """ 32 | kernel = torch.ones(width, device=mask.device).view(1, 1, -1) 33 | expanded_mask = 
F.conv1d(mask.unsqueeze(1), kernel, padding= width// 2).clamp_(0, 1) 34 | if width % 2 == 0: 35 | expanded_mask = expanded_mask[...,:-1] # crop if even (we've added too much padding) 36 | return expanded_mask.squeeze(1) 37 | 38 | def mask_batch( 39 | self, 40 | batch, 41 | mask=None, 42 | max_spikes=DEFAULT_MASK_VAL - 1, 43 | should_mask=True, 44 | expand_prob=0.0, 45 | heldout_spikes=None, 46 | forward_spikes=None, 47 | ): 48 | r""" Given complete batch, mask random elements and return true labels separately. 49 | Modifies batch OUT OF place! 50 | Modeled after HuggingFace's `mask_tokens` in `run_language_modeling.py` 51 | args: 52 | batch: batch NxTxH 53 | mask_ratio: ratio to randomly mask 54 | mode: "full" or "timestep" - if "full", will randomly drop on full matrix, whereas on "timestep", will mask out random timesteps 55 | mask: Optional mask to use 56 | max_spikes: in case not zero masking, "mask token" 57 | expand_prob: with this prob, uniformly expand. else, keep single tokens. UniLM does, with 40% expand to fixed, else keep single. 58 | heldout_spikes: None 59 | returns: 60 | batch: list of data batches NxTxH, with some elements along H set to -1s (we allow peeking between rates) 61 | labels: true data (also NxTxH) 62 | """ 63 | batch = batch.clone() # make sure we don't corrupt the input data (which is stored in memory) 64 | 65 | mode = self.cfg.MASK_MODE 66 | should_expand = self.cfg.MASK_MAX_SPAN > 1 and expand_prob > 0.0 and torch.rand(1).item() < expand_prob 67 | width = torch.randint(1, self.cfg.MASK_MAX_SPAN + 1, (1, )).item() if should_expand else 1 68 | mask_ratio = self.cfg.MASK_RATIO if width == 1 else self.cfg.MASK_RATIO / width 69 | 70 | labels = batch.clone() 71 | if mask is None: 72 | if self.prob_mask is None or self.prob_mask.size() != labels.size(): 73 | if mode == "full": 74 | mask_probs = torch.full(labels.shape, mask_ratio) 75 | elif mode == "timestep": 76 | single_timestep = labels[:, :, 0] # N x T 77 | mask_probs = torch.full(single_timestep.shape, mask_ratio) 78 | elif mode == "neuron": 79 | single_neuron = labels[:, 0] # N x H 80 | mask_probs = torch.full(single_neuron.shape, mask_ratio) 81 | elif mode == "timestep_only": 82 | single_timestep = labels[0, :, 0] # T 83 | mask_probs = torch.full(single_timestep.shape, mask_ratio) 84 | self.prob_mask = mask_probs.to(self.device) 85 | # If we want any tokens to not get masked, do it here (but we don't currently have any) 86 | mask = torch.bernoulli(self.prob_mask) 87 | 88 | # N x T 89 | if width > 1: 90 | mask = self.expand_mask(mask, width) 91 | 92 | mask = mask.bool() 93 | if mode == "timestep": 94 | mask = mask.unsqueeze(2).expand_as(labels) 95 | elif mode == "neuron": 96 | mask = mask.unsqueeze(0).expand_as(labels) 97 | elif mode == "timestep_only": 98 | mask = mask.unsqueeze(0).unsqueeze(2).expand_as(labels) 99 | # we want the shape of the mask to be T 100 | elif mask.size() != labels.size(): 101 | raise Exception(f"Input mask of size {mask.size()} does not match input size {labels.size()}") 102 | 103 | labels[~mask] = UNMASKED_LABEL # No ground truth for unmasked - use this to mask loss 104 | if not should_mask: 105 | # Only do the generation 106 | return batch, labels 107 | 108 | # We use random assignment so the model learns embeddings for non-mask tokens, and must rely on context 109 | # Most times, we replace tokens with MASK token 110 | indices_replaced = torch.bernoulli(torch.full(labels.shape, self.cfg.MASK_TOKEN_RATIO, device=mask.device)).bool() & mask 111 | if self.cfg.USE_ZERO_MASK: 112 | 
batch[indices_replaced] = 0 113 | else: 114 | batch[indices_replaced] = max_spikes + 1 115 | 116 | # Random % of the time, we replace masked input tokens with random value (the rest are left intact) 117 | indices_random = torch.bernoulli(torch.full(labels.shape, self.cfg.MASK_RANDOM_RATIO, device=mask.device)).bool() & mask & ~indices_replaced 118 | random_spikes = torch.randint(batch.max(), labels.shape, dtype=torch.long, device=batch.device) 119 | batch[indices_random] = random_spikes[indices_random] 120 | 121 | if heldout_spikes is not None: 122 | # heldout spikes are all masked 123 | batch = torch.cat([batch, torch.zeros_like(heldout_spikes, device=batch.device)], -1) 124 | labels = torch.cat([labels, heldout_spikes.to(batch.device)], -1) 125 | if forward_spikes is not None: 126 | batch = torch.cat([batch, torch.zeros_like(forward_spikes, device=batch.device)], 1) 127 | labels = torch.cat([labels, forward_spikes.to(batch.device)], 1) 128 | # Any remaining masked tokens are left intact 129 | return batch, labels -------------------------------------------------------------------------------- /src/model_baselines.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Joel Ye 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn import Transformer, TransformerEncoder, TransformerEncoderLayer 9 | from torch.distributions import Poisson 10 | 11 | from src.utils import binary_mask_to_attn_mask 12 | from src.mask import UNMASKED_LABEL 13 | 14 | class RatesOracle(nn.Module): 15 | 16 | def __init__(self, config, num_neurons, device, **kwargs): 17 | super().__init__() 18 | assert config.REQUIRES_RATES == True, "Oracle requires rates" 19 | if config.LOSS.TYPE == "poisson": 20 | self.classifier = nn.PoissonNLLLoss(reduction='none', log_input=config.LOGRATE) 21 | else: 22 | raise Exception(f"Loss type {config.LOSS.TYPE} not supported") 23 | 24 | def get_hidden_size(self): 25 | return 0 26 | 27 | def forward(self, src, mask_labels, rates=None, **kwargs): 28 | # output is t x b x neurons (rate predictions) 29 | loss = self.classifier(rates, mask_labels) 30 | # Mask out losses for unmasked labels 31 | masked_loss = loss[mask_labels != UNMASKED_LABEL] 32 | masked_loss = masked_loss.mean() 33 | return ( 34 | masked_loss.unsqueeze(0), 35 | rates, 36 | None, 37 | torch.tensor(0, device=masked_loss.device, dtype=torch.float), 38 | None, 39 | None, 40 | ) 41 | 42 | class RandomModel(nn.Module): 43 | # Guess a random rate in LOGRATE_RANGE 44 | # Purpose: diagnose why our initial loss is so close to our final loss 45 | LOGRATE_RANGE = (-2.5, 2.5) 46 | 47 | def __init__(self, config, num_neurons, device): 48 | super().__init__() 49 | self.device = device 50 | if config.LOSS.TYPE == "poisson": 51 | self.classifier = nn.PoissonNLLLoss(reduction='none', log_input=config.LOGRATE) 52 | else: 53 | raise Exception(f"Loss type {config.LOSS.TYPE} not supported") 54 | 55 | def forward(self, src, mask_labels, *args, **kwargs): 56 | # output is t x b x neurons (rate predictions) 57 | rates = torch.rand(mask_labels.size(), dtype=torch.float32).to(self.device) 58 | rates *= (self.LOGRATE_RANGE[1] - self.LOGRATE_RANGE[0]) 59 | rates += self.LOGRATE_RANGE[0] 60 | loss = self.classifier(rates, mask_labels) 61 | # Mask out losses for unmasked labels 62 | loss[mask_labels == UNMASKED_LABEL] = 0.0 63 | 64 | return loss.mean(), rates 65 | --------------------------------------------------------------------------------
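The masker and the baselines above share one contract: wherever no loss should be taken, the label tensor holds the UNMASKED_LABEL sentinel, and models average the pointwise Poisson NLL over the remaining (masked) positions only. Below is a minimal, self-contained sketch of that contract for reference; it is not repo code, and the shapes and the 20% timestep ratio are illustrative stand-ins for the configured TRAIN.MASK_RATIO.

# Sketch (not part of the repo) of the masked Poisson NLL convention used by Masker and the baselines.
import torch
import torch.nn as nn

UNMASKED_LABEL = -100  # same sentinel as src/mask.py

B, T, H = 4, 10, 3                                   # trials x time x neurons (illustrative)
spikes = torch.randint(0, 5, (B, T, H)).float()      # observed spike counts
logrates = torch.zeros(B, T, H)                      # model output in lograte space (MODEL.LOGRATE)

labels = spikes.clone()
mask = torch.rand(B, T, 1).expand_as(labels) < 0.2   # mask ~20% of timesteps, as in "timestep" mode
labels[~mask] = UNMASKED_LABEL                       # unmasked positions carry no ground truth

criterion = nn.PoissonNLLLoss(reduction='none', log_input=True)
loss = criterion(logrates, labels)                   # pointwise NLL, including junk at unmasked spots
masked_loss = loss[labels != UNMASKED_LABEL].mean()  # average over masked positions only
print(masked_loss.item())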
/src/model_registry.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Joel Ye 3 | 4 | from src.model import ( 5 | NeuralDataTransformer, 6 | ) 7 | 8 | from src.model_baselines import ( 9 | RatesOracle, 10 | RandomModel, 11 | ) 12 | 13 | LEARNING_MODELS = { 14 | "NeuralDataTransformer": NeuralDataTransformer, 15 | } 16 | 17 | NONLEARNING_MODELS = { 18 | "Oracle": RatesOracle, 19 | "Random": RandomModel 20 | } 21 | 22 | INPUT_MASKED_MODELS = { 23 | "NeuralDataTransformer": NeuralDataTransformer, 24 | } 25 | 26 | MODELS = {**LEARNING_MODELS, **NONLEARNING_MODELS, **INPUT_MASKED_MODELS} 27 | 28 | def is_learning_model(model_name): 29 | return model_name in LEARNING_MODELS 30 | 31 | def is_input_masked_model(model_name): 32 | return model_name in INPUT_MASKED_MODELS 33 | 34 | def get_model_class(model_name): 35 | return MODELS[model_name] -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Joel Ye 3 | 4 | from typing import List, Union 5 | import os.path as osp 6 | import shutil 7 | import random 8 | 9 | import argparse 10 | import numpy as np 11 | import torch 12 | 13 | from src.config.default import get_config 14 | from src.runner import Runner 15 | 16 | DO_PRESERVE_RUNS = False 17 | 18 | def get_parser(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | "--run-type", 22 | choices=["train", "eval"], 23 | required=True, 24 | help="run type of the experiment (train or eval)", 25 | ) 26 | 27 | parser.add_argument( 28 | "--exp-config", 29 | type=str, 30 | required=True, 31 | help="path to config yaml containing info about experiment", 32 | ) 33 | 34 | parser.add_argument( 35 | "--ckpt-path", 36 | default=None, 37 | type=str, 38 | help="full path to a ckpt (for eval or resumption)" 39 | ) 40 | 41 | parser.add_argument( 42 | "--clear-only", 43 | default=False, 44 | type=bool, 45 | ) 46 | 47 | parser.add_argument( 48 | "opts", 49 | default=None, 50 | nargs=argparse.REMAINDER, 51 | help="Modify config options from command line", 52 | ) 53 | return parser 54 | 55 | def main(): 56 | parser = get_parser() 57 | args = parser.parse_args() 58 | run_exp(**vars(args)) 59 | 60 | def check_exists(path, preserve=DO_PRESERVE_RUNS): 61 | if osp.exists(path): 62 | # logger.warn(f"{path} exists") 63 | print(f"{path} exists") 64 | if not preserve: 65 | # logger.warn(f"removing {path}") 66 | print(f"removing {path}") 67 | shutil.rmtree(path, ignore_errors=True) 68 | return True 69 | return False 70 | 71 | 72 | def prepare_config(exp_config: Union[List[str], str], run_type: str, ckpt_path="", opts=None, suffix=None) -> None: 73 | r"""Prepare config node / do some preprocessing 74 | 75 | Args: 76 | exp_config: path to config file. 77 | run_type: "train" or "eval. 78 | ckpt_path: If training, ckpt to resume. If evaluating, ckpt to evaluate. 79 | opts: list of strings of additional config options. 
80 | 81 | Returns: 82 | Runner, config, ckpt_path 83 | """ 84 | config = get_config(exp_config, opts) 85 | 86 | # Default behavior is to pull experiment name from config file 87 | # Bind variant name to directories 88 | if isinstance(exp_config, str): 89 | variant_config = exp_config 90 | else: 91 | variant_config = exp_config[-1] 92 | variant_name = osp.split(variant_config)[1].split('.')[0] 93 | config.defrost() 94 | config.VARIANT = variant_name 95 | if suffix is not None: 96 | config.TENSORBOARD_DIR = osp.join(config.TENSORBOARD_DIR, suffix) 97 | config.CHECKPOINT_DIR = osp.join(config.CHECKPOINT_DIR, suffix) 98 | config.LOG_DIR = osp.join(config.LOG_DIR, suffix) 99 | config.TENSORBOARD_DIR = osp.join(config.TENSORBOARD_DIR, config.VARIANT) 100 | config.CHECKPOINT_DIR = osp.join(config.CHECKPOINT_DIR, config.VARIANT) 101 | config.LOG_DIR = osp.join(config.LOG_DIR, config.VARIANT) 102 | config.freeze() 103 | 104 | if ckpt_path is not None: 105 | if not osp.exists(ckpt_path): 106 | ckpt_path = osp.join(config.CHECKPOINT_DIR, ckpt_path) 107 | 108 | np.random.seed(config.SEED) 109 | random.seed(config.SEED) 110 | torch.random.manual_seed(config.SEED) 111 | torch.backends.cudnn.deterministic = True 112 | 113 | return config, ckpt_path 114 | 115 | def run_exp(exp_config: Union[List[str], str], run_type: str, ckpt_path="", clear_only=False, opts=None, suffix=None) -> None: 116 | config, ckpt_path = prepare_config(exp_config, run_type, ckpt_path, opts, suffix=suffix) 117 | if clear_only: 118 | check_exists(config.TENSORBOARD_DIR, preserve=False) 119 | check_exists(config.CHECKPOINT_DIR, preserve=False) 120 | check_exists(config.LOG_DIR, preserve=False) 121 | exit(0) 122 | if run_type == "train": 123 | if ckpt_path is not None: 124 | runner = Runner(config) 125 | runner.train(checkpoint_path=ckpt_path) 126 | else: 127 | if DO_PRESERVE_RUNS: 128 | if check_exists(config.TENSORBOARD_DIR) or \ 129 | check_exists(config.CHECKPOINT_DIR) or \ 130 | check_exists(config.LOG_DIR): 131 | exit(1) 132 | else: 133 | check_exists(config.TENSORBOARD_DIR) 134 | check_exists(config.CHECKPOINT_DIR) 135 | check_exists(config.LOG_DIR) 136 | runner = Runner(config) 137 | runner.train() 138 | # * The evaluation code path has legacy code (and will not run). 139 | # * Evaluation / analysis is done in analysis scripts. 140 | # elif run_type == "eval": 141 | # runner.eval(checkpoint_path=ckpt_path) 142 | 143 | if __name__ == "__main__": 144 | main() 145 | -------------------------------------------------------------------------------- /src/tb_wrapper.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | 3 | TensorboardWriter = SummaryWriter -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Joel Ye 3 | 4 | import numpy as np 5 | import torch 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | def binary_mask_to_attn_mask(x): 9 | return x.float().masked_fill(x == 0, float('-inf')).masked_fill(x == 1, float(0.0)) 10 | 11 | # Verbatim from LFADS_TF2 12 | def merge_train_valid(train_data, valid_data, train_ixs, valid_ixs): 13 | """Merges training and validation numpy arrays using indices. 14 | 15 | This function merges training and validation numpy arrays 16 | in the appropriate order using arrays of their indices. 
The 17 | lengths of the indices must be the same as the first dimension 18 | of the corresponding data. 19 | 20 | Parameters 21 | ---------- 22 | train_data : np.ndarray 23 | An N-dimensional numpy array of training data with 24 | first dimension T. 25 | valid_data : np.ndarray 26 | An N-dimensional numpy array of validation data with 27 | first dimension V. 28 | train_ixs : np.ndarray 29 | A 1-D numpy array of training indices with length T. 30 | valid_ixs : np.ndarray 31 | A 1-D numpy array of validation indices with length V. 32 | 33 | Returns 34 | ------- 35 | np.ndarray 36 | An N-dimensional numpy array with dimension T + V. 37 | 38 | """ 39 | 40 | if train_data.shape[0] == train_ixs.shape[0] \ 41 | and valid_data.shape[0] == valid_ixs.shape[0]: 42 | # if the indices match up, then we can use them to merge 43 | data = np.full_like(np.concatenate([train_data, valid_data]), np.nan) 44 | if min(min(train_ixs), min(valid_ixs)) > 0: 45 | # we've got matlab data... 46 | train_ixs -= 1 47 | valid_ixs -= 1 48 | data[train_ixs.astype(int)] = train_data 49 | data[valid_ixs.astype(int)] = valid_data 50 | else: 51 | # if the indices do not match, train and 52 | # valid data may be the same (e.g. for priors) 53 | if np.all(train_data == valid_data): 54 | data = train_data 55 | else: 56 | raise ValueError("shape mismatch: " 57 | f"Index shape {train_ixs.shape} does not " 58 | f"match the data shape {train_data.shape}.") 59 | return data 60 | 61 | def get_inverse_sqrt_schedule(optimizer, warmup_steps=1000, lr_init=1e-8, lr_max=5e-4): 62 | """ 63 | https://github.com/pytorch/fairseq/blob/master/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py 64 | Decay the LR based on the inverse square root of the update number. 65 | We also support a warmup phase where we linearly increase the learning rate 66 | from some initial learning rate (``--warmup-init-lr``) until the configured 67 | learning rate (``--lr``). Thereafter we decay proportional to the number of 68 | updates, with a decay factor set to align with the configured learning rate. 69 | During warmup:: 70 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) 71 | lr = lrs[update_num] 72 | After warmup:: 73 | decay_factor = args.lr * sqrt(args.warmup_updates) 74 | lr = decay_factor / sqrt(update_num) 75 | """ 76 | def lr_lambda(current_step): 77 | lr_step = (lr_max - lr_init) / warmup_steps 78 | decay_factor = lr_max * warmup_steps ** 0.5 79 | 80 | if current_step < warmup_steps: 81 | lr = lr_init + current_step * lr_step 82 | else: 83 | lr = decay_factor * current_step ** -0.5 84 | return lr 85 | 86 | return LambdaLR(optimizer, lr_lambda) -------------------------------------------------------------------------------- /tune_models.py: -------------------------------------------------------------------------------- 1 | # Src: Andrew's tune_tf2 2 | 3 | import os 4 | import os.path as osp 5 | import numpy as np 6 | import ray 7 | from ray import tune 8 | from yacs.config import CfgNode as CN 9 | import torch 10 | 11 | from src.config.default import get_cfg_defaults, unflatten 12 | from src.runner import Runner 13 | 14 | class tuneNDT(tune.Trainable): 15 | """ A wrapper class that allows `tune` to interface with NDT. 
16 | """ 17 | 18 | def setup(self, config): 19 | yacs_cfg = self.convert_tune_cfg(config) 20 | self.epochs_per_generation = yacs_cfg.TRAIN.TUNE_EPOCHS_PER_GENERATION 21 | self.warmup_epochs = yacs_cfg.TRAIN.TUNE_WARMUP 22 | self.runner = Runner(config=yacs_cfg) 23 | self.runner.load_device() 24 | self.runner.load_train_val_data_and_masker() 25 | num_hidden = self.runner.setup_model(self.runner.device) 26 | self.runner.load_optimizer(num_hidden) 27 | 28 | def step(self): 29 | num_epochs = self.epochs_per_generation 30 | # the first generation always completes ramping (warmup) 31 | if self.runner.count_updates < self.warmup_epochs: 32 | num_epochs += self.warmup_epochs 33 | for i in range(num_epochs): 34 | metrics = self.runner.train_epoch() 35 | return metrics 36 | 37 | def save_checkpoint(self, tmp_ckpt_dir): 38 | path = osp.join(tmp_ckpt_dir, f"{self.runner.config.VARIANT}.{self.runner.count_checkpoints}.pth") 39 | self.runner.save_checkpoint(path) 40 | return path 41 | 42 | def load_checkpoint(self, path): 43 | self.runner.load_checkpoint(path) 44 | 45 | def reset_config(self, new_config): 46 | new_cfg_node = self.convert_tune_cfg(new_config) 47 | self.runner.update_config(new_cfg_node) 48 | return True 49 | 50 | def convert_tune_cfg(self, flat_cfg_dict): 51 | """Converts the tune config dictionary into a CfgNode for NDT. 52 | """ 53 | cfg_node = get_cfg_defaults() 54 | 55 | flat_cfg_dict['CHECKPOINT_DIR'] = osp.join(self.logdir, 'ckpts') 56 | flat_cfg_dict['TENSORBOARD_DIR'] = osp.join(self.logdir, 'tb') 57 | flat_cfg_dict['LOG_DIR'] = osp.join(self.logdir, 'logs') 58 | flat_cfg_dict['TRAIN.TUNE_MODE'] = True 59 | cfg_update = CN(unflatten(flat_cfg_dict)) 60 | cfg_node.merge_from_other_cfg(cfg_update) 61 | 62 | return cfg_node --------------------------------------------------------------------------------
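For completeness, a minimal usage sketch of the tuneNDT Trainable above; it is not part of the repo and assumes the ray<2.0-style tune.run API. The search space, sample count, and resource numbers are illustrative placeholders: keys are flattened yacs paths, which convert_tune_cfg() unflattens back into a full CfgNode, and the metric name follows the TRAIN.TUNE_METRIC default.

# Hypothetical launcher sketch (assumed Ray Tune Trainable workflow, not repo code).
import ray
from ray import tune

from tune_models import tuneNDT

if __name__ == "__main__":
    ray.init()
    analysis = tune.run(
        tuneNDT,
        config={
            # illustrative search space over flattened config keys
            "TRAIN.LR.INIT": tune.loguniform(1e-4, 5e-3),
            "TRAIN.MASK_RATIO": tune.uniform(0.1, 0.3),
        },
        num_samples=4,
        resources_per_trial={"cpu": 2, "gpu": 1},
        stop={"training_iteration": 20},
    )
    # assumes train_epoch() reports the metric named by TRAIN.TUNE_METRIC
    print(analysis.get_best_config(metric="smth_masked_loss", mode="min"))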