├── .gitattributes
├── .gitignore
├── README.md
├── UNLICENSE
├── assets
│   └── teaser.png
├── configs
│   ├── 20ms_arxiv.json
│   ├── 20ms_arxiv_no_reg.json
│   ├── 20ms_full_context.json
│   ├── 20ms_mask.json
│   ├── 20ms_no_span.json
│   ├── area2_bump.yaml
│   ├── arxiv
│   │   ├── chaotic.yaml
│   │   ├── lorenz.yaml
│   │   ├── m700_115.yaml
│   │   ├── m700_2296.yaml
│   │   ├── m700_2296_postnorm.yaml
│   │   ├── m700_230.yaml
│   │   ├── m700_460.yaml
│   │   ├── m700_no_log.yaml
│   │   ├── m700_no_reg.yaml
│   │   ├── m700_no_span.yaml
│   │   └── m700_nonzero.yaml
│   ├── dmfc_rsg.json
│   ├── dmfc_rsg.yaml
│   ├── mc_maze.json
│   ├── mc_maze.yaml
│   ├── mc_maze_large.yaml
│   ├── mc_maze_medium.yaml
│   ├── mc_maze_small.yaml
│   ├── mc_maze_small_from_scratch.yaml
│   ├── mc_rtt.yaml
│   ├── sweep_bump.json
│   ├── sweep_generic.json
│   └── sweep_simple.json
├── data
│   ├── chaotic_rnn
│   │   ├── gen_synth_data_no_inputs.sh
│   │   ├── generate_chaotic_rnn_data.py
│   │   ├── generate_chaotic_rnn_data_allowRandomSeed.py
│   │   ├── synthetic_data_utils.py
│   │   └── utils.py
│   ├── lfads_lorenz.h5
│   └── lorenz.py
├── defaults.py
├── environment.yml
├── nlb.yml
├── ray_get_lfve.py
├── ray_random.py
├── scripts
│   ├── analyze_ray.py
│   ├── analyze_utils.py
│   ├── clear.sh
│   ├── eval.sh
│   ├── fig_3a_synthetic_neurons.py
│   ├── fig_3b_4b_plot_hparams.py
│   ├── fig_5_lfads_times.py
│   ├── fig_5_ndt_times.py
│   ├── fig_6_plot_losses.py
│   ├── nlb.py
│   ├── nlb_from_scratch.ipynb
│   ├── nlb_from_scratch.py
│   ├── record_all_rates.py
│   ├── scratch.py
│   ├── simple_ci.py
│   └── train.sh
├── src
│   ├── __init__.py
│   ├── config
│   │   ├── __init__.py
│   │   └── default.py
│   ├── dataset.py
│   ├── logger_wrapper.py
│   ├── mask.py
│   ├── model.py
│   ├── model_baselines.py
│   ├── model_registry.py
│   ├── run.py
│   ├── runner.py
│   ├── tb_wrapper.py
│   └── utils.py
└── tune_models.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.h5 filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | *.h5
3 | *.pdf
4 | *.png
5 | logs/
6 | tb/
7 | tmp
8 |
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 | *.pyc
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .nox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | *.py,cover
59 | .hypothesis/
60 | .pytest_cache/
61 | cover/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | .pybuilder/
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | # For a library or package, you might want to ignore these files since the code is
96 | # intended to run in multiple environments; otherwise, check them in:
97 | # .python-version
98 |
99 | # pipenv
100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
103 | # install all needed dependencies.
104 | #Pipfile.lock
105 |
106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
107 | __pypackages__/
108 |
109 | # Celery stuff
110 | celerybeat-schedule
111 | celerybeat.pid
112 |
113 | # SageMath parsed files
114 | *.sage.py
115 |
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 |
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 |
129 | # Rope project settings
130 | .ropeproject
131 |
132 | # mkdocs documentation
133 | /site
134 |
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 |
140 | # Pyre type checker
141 | .pyre/
142 |
143 | # pytype static type analyzer
144 | .pytype/
145 |
146 | # Cython debug symbols
147 | cython_debug/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Data Transformers
2 |
 3 | ![Neural Data Transformers teaser](assets/teaser.png)
 4 |
 5 |
 6 | This is the code for the paper "Representation learning for neural population activity with Neural Data Transformers". We provide this code as a reference, but we are unable to help debug specific issues (e.g., in using the model) at this time. NLB-related configuration files are listed directly under `configs/`, whereas the original NDT paper configs are under `configs/arxiv/`.
7 |
8 | - Want to quickly get started with NLB? Check out `nlb_from_scratch.py` or `nlb_from_scratch.ipynb`.
9 |
10 | ## Setup
11 | We recommend you set up your code environment with `conda/miniconda`.
12 | The dependencies necessary for this project can then be installed with:
13 | `conda env create -f nlb.yml`
14 | This project was developed with Python 3.6, originally using `environment.yml`. `nlb.yml` specifies a fresh Python 3.7 environment with minimal dependencies; it has not been tested for exact result reproducibility, though no impact is expected.
15 |
16 | ## Data
17 | The Lorenz dataset is provided in `data/lfads_lorenz.h5`. This file is stored in the repo with [`git-lfs`](https://git-lfs.github.com/); if you have not used `git-lfs` before, run `git lfs install` followed by `git lfs pull` to fetch the full h5 file.
18 |
19 | The autonomous chaotic RNN dataset can be generated by running `./data/chaotic_rnn/gen_synth_data_no_inputs.sh`. The generator scripts are taken from the [TensorFlow release of LFADS](https://github.com/tensorflow/models/tree/master/research/lfads/synth_data) (Sussillo et al.). The maze dataset is unavailable at this time.
20 |
21 | ## Training + Evaluation
22 | Experimental configurations are set in `./configs/`. To train a single model with a configuration `./configs/<variant>.yaml`, run `./scripts/train.sh <variant>`.
23 |
24 | The provided sample configurations in `./configs/arxiv/` were used in the HP sweeps for the main results in the paper, with the sweep parameters given in `./configs/*.json`. Sweeping is done with the `ray[tune]` package: to run a sweep, run `python ray_random.py -e <variant>` (the same config system is used). Both the main paper and NLB results use this random search.
25 |
26 | R2 is reported automatically for synthetic datasets. Maze analyses + configurations are unfortunately unavailable at this time.
27 |
28 | ## Analysis
29 | Reference scripts used to produce most figures are available in `scripts/`. They were created and iterated on as VSCode notebooks, and they may require external information, run directories, or even other codebases. These scripts are provided to convey the analysis procedure, not as out-of-the-box reproducibility notebooks.
30 |
31 | ## Citation
32 | ```
33 | @article{ye2021ndt,
34 | title = {Representation learning for neural population activity with {Neural} {Data} {Transformers}},
35 | issn = {2690-2664},
36 | url = {https://nbdt.scholasticahq.com/article/27358-representation-learning-for-neural-population-activity-with-neural-data-transformers},
37 | doi = {10.51628/001c.27358},
38 | language = {en},
39 | urldate = {2021-08-29},
40 | journal = {Neurons, Behavior, Data analysis, and Theory},
41 | author = {Ye, Joel and Pandarinath, Chethan},
42 | month = aug,
43 | year = {2021},
44 | }
45 | ```
46 |
--------------------------------------------------------------------------------
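As a quick sanity check on the data described in the README's Data section, the pulled `data/lfads_lorenz.h5` file can be inspected with `h5py`. This is a minimal standalone sketch, not part of the repository; key names such as `train_data`/`train_truth` are assumptions based on the LFADS-style generator in `data/chaotic_rnn/`, so the snippet simply lists whatever datasets are actually present.

import h5py

# Minimal sketch (not part of the repo): list datasets stored in the Lorenz file.
with h5py.File("data/lfads_lorenz.h5", "r") as f:
    def show(name, obj):
        if isinstance(obj, h5py.Dataset):
            print(name, obj.shape, obj.dtype)
    f.visititems(show)

    if "train_data" in f:  # spiking data, typically (trials, time, neurons)
        print("train_data shape:", f["train_data"].shape)
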
/UNLICENSE:
--------------------------------------------------------------------------------
1 | This is free and unencumbered software released into the public domain.
2 |
3 | Anyone is free to copy, modify, publish, use, compile, sell, or
4 | distribute this software, either in source code form or as a compiled
5 | binary, for any purpose, commercial or non-commercial, and by any
6 | means.
7 |
8 | In jurisdictions that recognize copyright laws, the author or authors
9 | of this software dedicate any and all copyright interest in the
10 | software to the public domain. We make this dedication for the benefit
11 | of the public at large and to the detriment of our heirs and
12 | successors. We intend this dedication to be an overt act of
13 | relinquishment in perpetuity of all present and future rights to this
14 | software under copyright law.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 | OTHER DEALINGS IN THE SOFTWARE.
23 |
24 | For more information, please refer to <https://unlicense.org>
25 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snel-repo/neural-data-transformers/98dd85a24885ffb76adfeed0c2a89d3ea3ecf9d1/assets/teaser.png
--------------------------------------------------------------------------------
/configs/20ms_arxiv.json:
--------------------------------------------------------------------------------
1 | {
2 | "MODEL.DROPOUT": {
3 | "low": 0.2,
4 | "high": 0.6,
5 | "sample_fn": "uniform",
6 | "explore_wt": 0.3
7 | },
8 | "MODEL.DROPOUT_RATES": {
9 | "low": 0.2,
10 | "high": 0.6,
11 | "sample_fn": "uniform",
12 | "explore_wt": 0.3
13 | },
14 | "MODEL.DROPOUT_EMBEDDING": {
15 | "low": 0.2,
16 | "high": 0.6,
17 | "sample_fn": "uniform",
18 | "explore_wt": 0.3
19 | },
20 | "MODEL.CONTEXT_FORWARD": {
21 | "low": 10,
22 | "high": 40,
23 | "sample_fn": "randint",
24 | "enforce_limits": false,
25 | "is_integer": true
26 | },
27 | "MODEL.CONTEXT_BACKWARD": {
28 | "low": 10,
29 | "high": 40,
30 | "sample_fn": "randint",
31 | "enforce_limits": false,
32 | "is_integer": true
33 | },
34 | "TRAIN.LR.INIT": {
35 | "low": 1e-5,
36 | "high": 1e-3,
37 | "sample_fn": "loguniform"
38 | },
39 | "TRAIN.MASK_RANDOM_RATIO": {
40 | "low": 0.6,
41 | "high": 1.0,
42 | "sample_fn": "uniform"
43 | },
44 | "TRAIN.MASK_TOKEN_RATIO": {
45 | "low": 0.6,
46 | "high": 1.0,
47 | "sample_fn": "uniform"
48 | },
49 | "TRAIN.MASK_MAX_SPAN": {
50 | "low": 1,
51 | "high": 7,
52 | "sample_fn": "randint",
53 | "is_integer": true
54 | }
55 | }
--------------------------------------------------------------------------------
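Each entry in the sweep JSONs pairs a dotted config key with a sampling spec (`low`, `high`, `sample_fn`, plus optional `is_integer`, `explore_wt`, `enforce_limits`). The README notes that sweeps are run with `ray[tune]` via `ray_random.py`; below is a minimal sketch of how such a spec could be mapped onto `ray.tune` samplers. The actual translation in `ray_random.py` (not shown in this dump) may treat `explore_wt` and `enforce_limits` differently.

import json
from ray import tune

# Minimal sketch (an assumption, not the repo's ray_random.py): build a ray.tune
# search space from a sweep spec such as configs/20ms_arxiv.json.
def build_search_space(path):
    with open(path) as f:
        spec = json.load(f)
    space = {}
    for key, s in spec.items():
        fn = s.get("sample_fn", "uniform")
        if fn == "randint":
            # tune.randint samples integers in [low, high)
            space[key] = tune.randint(int(s["low"]), int(s["high"]))
        elif fn == "loguniform":
            space[key] = tune.loguniform(s["low"], s["high"])
        else:
            space[key] = tune.uniform(s["low"], s["high"])
    return space

# Example usage (train_fn would be the project's trainable):
# analysis = tune.run(train_fn, config=build_search_space("configs/20ms_arxiv.json"), num_samples=20)
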
/configs/20ms_arxiv_no_reg.json:
--------------------------------------------------------------------------------
1 | {
2 | "MODEL.DROPOUT": {
3 | "low": 0.1,
4 | "high": 0.2,
5 | "sample_fn": "uniform",
6 | "explore_wt": 0.3
7 | },
8 | "MODEL.DROPOUT_RATES": {
9 | "low": 0.1,
10 | "high": 0.2,
11 | "sample_fn": "uniform",
12 | "explore_wt": 0.3
13 | },
14 | "MODEL.DROPOUT_EMBEDDING": {
15 | "low": 0.1,
16 | "high": 0.2,
17 | "sample_fn": "uniform",
18 | "explore_wt": 0.3
19 | },
20 | "MODEL.CONTEXT_FORWARD": {
21 | "low": 10,
22 | "high": 40,
23 | "sample_fn": "randint",
24 | "enforce_limits": false,
25 | "is_integer": true
26 | },
27 | "MODEL.CONTEXT_BACKWARD": {
28 | "low": 10,
29 | "high": 40,
30 | "sample_fn": "randint",
31 | "enforce_limits": false,
32 | "is_integer": true
33 | },
34 | "TRAIN.LR.INIT": {
35 | "low": 1e-5,
36 | "high": 1e-3,
37 | "sample_fn": "loguniform"
38 | },
39 | "TRAIN.MASK_RANDOM_RATIO": {
40 | "low": 0.6,
41 | "high": 1.0,
42 | "sample_fn": "uniform"
43 | },
44 | "TRAIN.MASK_TOKEN_RATIO": {
45 | "low": 0.6,
46 | "high": 1.0,
47 | "sample_fn": "uniform"
48 | },
49 | "TRAIN.MASK_MAX_SPAN": {
50 | "low": 1,
51 | "high": 7,
52 | "sample_fn": "randint",
53 | "is_integer": true
54 | }
55 | }
--------------------------------------------------------------------------------
/configs/20ms_full_context.json:
--------------------------------------------------------------------------------
1 | {
2 | "MODEL.DROPOUT": {
3 | "low": 0.2,
4 | "high": 0.6,
5 | "sample_fn": "uniform",
6 | "explore_wt": 0.3
7 | },
8 | "MODEL.DROPOUT_RATES": {
9 | "low": 0.2,
10 | "high": 0.6,
11 | "sample_fn": "uniform",
12 | "explore_wt": 0.3
13 | },
14 | "MODEL.DROPOUT_EMBEDDING": {
15 | "low": 0.2,
16 | "high": 0.6,
17 | "sample_fn": "uniform",
18 | "explore_wt": 0.3
19 | },
20 | "TRAIN.LR.INIT": {
21 | "low": 1e-5,
22 | "high": 1e-3,
23 | "sample_fn": "loguniform"
24 | },
25 | "TRAIN.MASK_RANDOM_RATIO": {
26 | "low": 0.6,
27 | "high": 1.0,
28 | "sample_fn": "uniform"
29 | },
30 | "TRAIN.MASK_TOKEN_RATIO": {
31 | "low": 0.6,
32 | "high": 1.0,
33 | "sample_fn": "uniform"
34 | },
35 | "TRAIN.MASK_MAX_SPAN": {
36 | "low": 1,
37 | "high": 5,
38 | "sample_fn": "randint",
39 | "is_integer": true
40 | }
41 | }
--------------------------------------------------------------------------------
/configs/20ms_mask.json:
--------------------------------------------------------------------------------
1 | {
2 | "MODEL.DROPOUT": {
3 | "low": 0.2,
4 | "high": 0.6,
5 | "sample_fn": "uniform",
6 | "explore_wt": 0.3
7 | },
8 | "MODEL.DROPOUT_RATES": {
9 | "low": 0.2,
10 | "high": 0.6,
11 | "sample_fn": "uniform",
12 | "explore_wt": 0.3
13 | },
14 | "MODEL.DROPOUT_EMBEDDING": {
15 | "low": 0.2,
16 | "high": 0.6,
17 | "sample_fn": "uniform",
18 | "explore_wt": 0.3
19 | },
20 | "MODEL.CONTEXT_FORWARD": {
21 | "low": 10,
22 | "high": 40,
23 | "sample_fn": "randint",
24 | "enforce_limits": false,
25 | "is_integer": true
26 | },
27 | "MODEL.CONTEXT_BACKWARD": {
28 | "low": 10,
29 | "high": 40,
30 | "sample_fn": "randint",
31 | "enforce_limits": false,
32 | "is_integer": true
33 | },
34 | "TRAIN.MASK_RANDOM_RATIO": {
35 | "low": 0.6,
36 | "high": 1.0,
37 | "sample_fn": "uniform"
38 | },
39 | "TRAIN.MASK_TOKEN_RATIO": {
40 | "low": 0.6,
41 | "high": 1.0,
42 | "sample_fn": "uniform"
43 | },
44 | "TRAIN.MASK_MAX_SPAN": {
45 | "low": 1,
46 | "high": 5,
47 | "sample_fn": "randint",
48 | "is_integer": true
49 | }
50 | }
--------------------------------------------------------------------------------
/configs/20ms_no_span.json:
--------------------------------------------------------------------------------
1 | {
2 | "MODEL.DROPOUT": {
3 | "low": 0.2,
4 | "high": 0.6,
5 | "sample_fn": "uniform",
6 | "explore_wt": 0.3
7 | },
8 | "MODEL.DROPOUT_RATES": {
9 | "low": 0.2,
10 | "high": 0.6,
11 | "sample_fn": "uniform",
12 | "explore_wt": 0.3
13 | },
14 | "MODEL.DROPOUT_EMBEDDING": {
15 | "low": 0.2,
16 | "high": 0.6,
17 | "sample_fn": "uniform",
18 | "explore_wt": 0.3
19 | },
20 | "MODEL.CONTEXT_FORWARD": {
21 | "low": 10,
22 | "high": 40,
23 | "sample_fn": "randint",
24 | "enforce_limits": false,
25 | "is_integer": true
26 | },
27 | "MODEL.CONTEXT_BACKWARD": {
28 | "low": 10,
29 | "high": 40,
30 | "sample_fn": "randint",
31 | "enforce_limits": false,
32 | "is_integer": true
33 | },
34 | "TRAIN.LR.INIT": {
35 | "low": 1e-5,
36 | "high": 1e-3,
37 | "sample_fn": "loguniform"
38 | },
39 | "TRAIN.MASK_RANDOM_RATIO": {
40 | "low": 0.6,
41 | "high": 1.0,
42 | "sample_fn": "uniform"
43 | },
44 | "TRAIN.MASK_TOKEN_RATIO": {
45 | "low": 0.6,
46 | "high": 1.0,
47 | "sample_fn": "uniform"
48 | }
49 | }
--------------------------------------------------------------------------------
/configs/area2_bump.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/nlb/"
4 | TRAIN_FILENAME: 'area2_bump.h5'
5 | VAL_FILENAME: 'area2_bump.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 0
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 2 # We embed to 2 here so transformer can use 2 heads. Perf diff is minimal.
12 | LOGRATE: True
13 | NUM_LAYERS: 4
14 | TRAIN:
15 | LR:
16 | WARMUP: 5000
17 | WEIGHT_DECAY: 5.0e-05
18 | LOG_INTERVAL: 200
19 | VAL_INTERVAL: 20
20 | CHECKPOINT_INTERVAL: 1000
21 | PATIENCE: 2500
22 | NUM_UPDATES: 40001
23 | MASK_RATIO: 0.25
24 |
25 | # Aggressive regularization
26 | TUNE_HP_JSON: './configs/sweep_bump.json'
--------------------------------------------------------------------------------
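The nested YAML keys in configs like the one above flatten to exactly the dotted names used by the sweep JSONs (e.g. `TRAIN.LR.WARMUP`, `MODEL.EMBED_DIM`), which is how a sampled hyperparameter can override a base config. The project's own loader lives in `src/config/default.py`; the standalone flattening sketch below is only illustrative.

import yaml

# Illustrative sketch (not the repo's loader): flatten nested YAML sections into
# the dotted key names used by the sweep JSONs.
def flatten(d, prefix=""):
    flat = {}
    for k, v in d.items():
        name = f"{prefix}.{k}" if prefix else str(k)
        if isinstance(v, dict):
            flat.update(flatten(v, name))
        else:
            flat[name] = v
    return flat

with open("configs/area2_bump.yaml") as f:
    cfg = flatten(yaml.safe_load(f))
print(cfg["TRAIN.LR.WARMUP"])  # 5000
print(cfg["MODEL.EMBED_DIM"])  # 2
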
/configs/arxiv/chaotic.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/chaotic_rnn_data/data_no_inputs/"
4 | TRAIN_FILENAME: 'chaotic_rnn_no_inputs_dataset_N50_S50'
5 | VAL_FILENAME: 'chaotic_rnn_no_inputs_dataset_N50_S50'
6 | MODEL:
7 | TRIAL_LENGTH: 100
8 | LEARNABLE_POSITION: True
9 | EMBED_DIM: 2
10 | TRAIN:
11 | LR:
12 | SCHEDULE: false
13 | LOG_INTERVAL: 250
14 | CHECKPOINT_INTERVAL: 250
15 | NUM_UPDATES: 5001
16 | MASK_RATIO: 0.25
17 | WEIGHT_DECAY: 0.0001
18 |
19 | TUNE_EPOCHS_PER_GENERATION: 100
20 | TUNE_HP_JSON: './configs/sweep_generic.json'
21 |
--------------------------------------------------------------------------------
/configs/arxiv/lorenz.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/lfads_lorenz_20ms/"
4 | TRAIN_FILENAME: 'lfads_dataset001.h5'
5 | VAL_FILENAME: 'lfads_dataset001.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 50
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 2 # We embed to 2 here so transformer can use 2 heads. Perf diff is minimal.
12 | LOGRATE: True
13 | TRAIN:
14 | LR:
15 | SCHEDULE: False
16 | LOG_INTERVAL: 50
17 | CHECKPOINT_INTERVAL: 500
18 | PATIENCE: 2500
19 | NUM_UPDATES: 20001
20 | MASK_RATIO: 0.25
21 |
22 | TUNE_HP_JSON: './configs/sweep_generic.json'
--------------------------------------------------------------------------------
/configs/arxiv/m700_115.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/0115_trials/0_seed"
4 | TRAIN_FILENAME: 'lfads_input.h5'
5 | VAL_FILENAME: 'lfads_input.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 70
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 0
12 | LOGRATE: True
13 | TRAIN:
14 | LR:
15 | WARMUP: 5000
16 | MASK_RATIO: 0.25
17 | WEIGHT_DECAY: 5.0e-05
18 | PATIENCE: 3000
19 | LOG_INTERVAL: 200
20 | VAL_INTERVAL: 1
21 | CHECKPOINT_INTERVAL: 1000
22 | NUM_UPDATES: 50501
23 | MASK_SPAN_RAMP_START: 8000
24 | MASK_SPAN_RAMP_END: 12000
25 |
26 | TUNE_HP_JSON: './configs/20ms_arxiv.json' # This space has more aggressive regularization
--------------------------------------------------------------------------------
/configs/arxiv/m700_2296.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/2296_trials/0_seed"
4 | TRAIN_FILENAME: 'lfads_input.h5'
5 | VAL_FILENAME: 'lfads_input.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 70
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 0
12 | LOGRATE: True
13 | TRAIN:
14 | LR:
15 | WARMUP: 5000
16 | MASK_RATIO: 0.25
17 | WEIGHT_DECAY: 5.0e-05
18 | PATIENCE: 3000
19 | LOG_INTERVAL: 200
20 | VAL_INTERVAL: 20
21 | CHECKPOINT_INTERVAL: 1000
22 | NUM_UPDATES: 50501
23 | MASK_SPAN_RAMP_START: 8000
24 | MASK_SPAN_RAMP_END: 12000
25 |
26 | TUNE_HP_JSON: './configs/20ms_arxiv.json' # This space has more aggressive regularization
27 |
--------------------------------------------------------------------------------
/configs/arxiv/m700_2296_postnorm.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/2296_trials/0_seed"
4 | TRAIN_FILENAME: 'lfads_input.h5'
5 | VAL_FILENAME: 'lfads_input.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 70
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: False
10 | FIXUP_INIT: False
11 | EMBED_DIM: 0
12 | LOGRATE: True
13 | TRAIN:
14 | LR:
15 | WARMUP: 5000
16 | MASK_RATIO: 0.25
17 | WEIGHT_DECAY: 5.0e-05
18 | PATIENCE: 3000
19 | LOG_INTERVAL: 200
20 | VAL_INTERVAL: 20
21 | CHECKPOINT_INTERVAL: 1000
22 | NUM_UPDATES: 50501
23 | MASK_SPAN_RAMP_START: 8000
24 | MASK_SPAN_RAMP_END: 12000
25 |
26 | TUNE_HP_JSON: './configs/20ms_arxiv.json' # This space has more aggressive regularization
27 |
--------------------------------------------------------------------------------
/configs/arxiv/m700_230.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/0230_trials/0_seed"
4 | TRAIN_FILENAME: 'lfads_input.h5'
5 | VAL_FILENAME: 'lfads_input.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 70
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 0
12 | LOGRATE: True
13 | TRAIN:
14 | LR:
15 | WARMUP: 5000
16 | MASK_RATIO: 0.25
17 | WEIGHT_DECAY: 5.0e-05
18 | PATIENCE: 3000
19 | LOG_INTERVAL: 200
20 | VAL_INTERVAL: 2
21 | CHECKPOINT_INTERVAL: 1000
22 | NUM_UPDATES: 50501
23 | MASK_SPAN_RAMP_START: 8000
24 | MASK_SPAN_RAMP_END: 12000
25 |
26 | TUNE_HP_JSON: './configs/20ms_arxiv.json' # This space has more aggressive regularization
27 |
--------------------------------------------------------------------------------
/configs/arxiv/m700_460.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/0460_trials/0_seed"
4 | TRAIN_FILENAME: 'lfads_input.h5'
5 | VAL_FILENAME: 'lfads_input.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 70
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 0
12 | LOGRATE: True
13 | TRAIN:
14 | LR:
15 | WARMUP: 5000
16 | MASK_RATIO: 0.25
17 | WEIGHT_DECAY: 5.0e-05
18 | PATIENCE: 3000
19 | LOG_INTERVAL: 200
20 | VAL_INTERVAL: 4
21 | CHECKPOINT_INTERVAL: 1000
22 | NUM_UPDATES: 50501
23 | MASK_SPAN_RAMP_START: 8000
24 | MASK_SPAN_RAMP_END: 12000
25 |
26 | TUNE_HP_JSON: './configs/20ms_arxiv.json' # This space has more aggressive regularization
27 |
--------------------------------------------------------------------------------
/configs/arxiv/m700_no_log.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/2296_trials/0_seed"
4 | TRAIN_FILENAME: 'lfads_input.h5'
5 | VAL_FILENAME: 'lfads_input.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 70
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 0
12 | LOGRATE: False
13 | TRAIN:
14 | LR:
15 | WARMUP: 5000
16 | MASK_RATIO: 0.25
17 | WEIGHT_DECAY: 5.0e-05
18 | PATIENCE: 3000
19 | LOG_INTERVAL: 200
20 | VAL_INTERVAL: 20
21 | CHECKPOINT_INTERVAL: 1000
22 | NUM_UPDATES: 50501
23 | MASK_SPAN_RAMP_START: 8000
24 | MASK_SPAN_RAMP_END: 12000
25 |
26 | TUNE_HP_JSON: './configs/20ms_arxiv.json' # This space has more aggressive regularization
27 |
--------------------------------------------------------------------------------
/configs/arxiv/m700_no_reg.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/2296_trials/0_seed"
4 | TRAIN_FILENAME: 'lfads_input.h5'
5 | VAL_FILENAME: 'lfads_input.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 70
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 0
12 | LOGRATE: False
13 | TRAIN:
14 | LR:
15 | WARMUP: 5000
16 | MASK_RATIO: 0.25
17 | WEIGHT_DECAY: 5.0e-05
18 | PATIENCE: 3000
19 | LOG_INTERVAL: 200
20 | VAL_INTERVAL: 20
21 | CHECKPOINT_INTERVAL: 1000
22 | NUM_UPDATES: 50501
23 | MASK_SPAN_RAMP_START: 8000
24 | MASK_SPAN_RAMP_END: 12000
25 |
26 | TUNE_HP_JSON: './configs/20ms_arxiv_no_reg.json' # This space has more aggressive regularization
27 |
--------------------------------------------------------------------------------
/configs/arxiv/m700_no_span.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/2296_trials/0_seed"
4 | TRAIN_FILENAME: 'lfads_input.h5'
5 | VAL_FILENAME: 'lfads_input.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 70
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 0
12 | LOGRATE: True
13 | TRAIN:
14 | LR:
15 | WARMUP: 5000
16 | MASK_RATIO: 0.25
17 | WEIGHT_DECAY: 5.0e-05
18 | PATIENCE: 3000
19 | LOG_INTERVAL: 200
20 | VAL_INTERVAL: 20
21 | CHECKPOINT_INTERVAL: 1000
22 | NUM_UPDATES: 50501
23 | MASK_SPAN_RAMP_START: 8000
24 | MASK_SPAN_RAMP_END: 12000
25 |
26 | TUNE_HP_JSON: './configs/20ms_no_span.json'
27 |
--------------------------------------------------------------------------------
/configs/arxiv/m700_nonzero.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/ndt_paper/m1_maze/heldout_trial/2296_trials/0_seed"
4 | TRAIN_FILENAME: 'lfads_input.h5'
5 | VAL_FILENAME: 'lfads_input.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 70
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 0
12 | LOGRATE: False
13 | TRAIN:
14 | LR:
15 | WARMUP: 5000
16 | MASK_RATIO: 0.25
17 | WEIGHT_DECAY: 5.0e-05
18 | PATIENCE: 3000
19 | LOG_INTERVAL: 200
20 | VAL_INTERVAL: 20
21 | CHECKPOINT_INTERVAL: 1000
22 | NUM_UPDATES: 50501
23 | MASK_SPAN_RAMP_START: 8000
24 | MASK_SPAN_RAMP_END: 12000
25 | USE_ZERO_MASK: False
26 |
27 | TUNE_HP_JSON: './configs/20ms_arxiv.json' # This space has more aggressive regularization
28 |
--------------------------------------------------------------------------------
/configs/dmfc_rsg.json:
--------------------------------------------------------------------------------
1 | {
2 | "MODEL.DROPOUT": {
3 | "low": 0.3,
4 | "high": 0.7,
5 | "sample_fn": "uniform",
6 | "explore_wt": 0.3
7 | },
8 | "MODEL.DROPOUT_RATES": {
9 | "low": 0.3,
10 | "high": 0.7,
11 | "sample_fn": "uniform",
12 | "explore_wt": 0.3
13 | },
14 | "MODEL.DROPOUT_EMBEDDING": {
15 | "low": 0.3,
16 | "high": 0.7,
17 | "sample_fn": "uniform",
18 | "explore_wt": 0.3
19 | },
20 | "MODEL.CONTEXT_FORWARD": {
21 | "low": 40,
22 | "high": 240,
23 | "sample_fn": "randint",
24 | "enforce_limits": false,
25 | "is_integer": true
26 | },
27 | "MODEL.CONTEXT_BACKWARD": {
28 | "low": 40,
29 | "high": 240,
30 | "sample_fn": "randint",
31 | "enforce_limits": false,
32 | "is_integer": true
33 | },
34 | "TRAIN.LR.INIT": {
35 | "low": 1e-5,
36 | "high": 5e-3,
37 | "explore_wt": 0.3
38 | },
39 | "TRAIN.LR.WARMUP": {
40 | "low": 0,
41 | "high": 2000,
42 | "sample_fn": "randint",
43 | "is_integer": true
44 | },
45 | "TRAIN.MASK_RANDOM_RATIO": {
46 | "low": 0.9,
47 | "high": 1.0,
48 | "sample_fn": "uniform"
49 | },
50 | "TRAIN.MASK_TOKEN_RATIO": {
51 | "low": 0.5,
52 | "high": 1.0,
53 | "sample_fn": "loguniform"
54 | },
55 | "TRAIN.MASK_MAX_SPAN": {
56 | "low": 1,
57 | "high": 7,
58 | "sample_fn": "randint",
59 | "is_integer": true
60 | }
61 | }
--------------------------------------------------------------------------------
/configs/dmfc_rsg.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/nlb/"
4 | TRAIN_FILENAME: 'dmfc_rsg.h5'
5 | VAL_FILENAME: 'dmfc_rsg.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 0 # 300
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 0
12 | LOGRATE: True
13 | NUM_LAYERS: 6
14 | TRAIN:
15 | LR:
16 | WARMUP: 5000
17 | MASK_RATIO: 0.25
18 | WEIGHT_DECAY: 5.0e-05
19 | PATIENCE: 4000 # it's hard...
20 | LOG_INTERVAL: 200
21 | VAL_INTERVAL: 20
22 | CHECKPOINT_INTERVAL: 1000
23 | NUM_UPDATES: 50501
24 | MASK_SPAN_RAMP_START: 8000
25 | MASK_SPAN_RAMP_END: 12000
26 |
27 | TUNE_HP_JSON: './configs/dmfc_rsg.json' # This space has more aggressive regularization
28 |
29 | # current run as of Wednesday used 4 layers, 0.1-0.5 reg
--------------------------------------------------------------------------------
/configs/mc_maze.json:
--------------------------------------------------------------------------------
1 | {
2 | "MODEL.DROPOUT": {
3 | "low": 0.3,
4 | "high": 0.7,
5 | "sample_fn": "uniform",
6 | "explore_wt": 0.3
7 | },
8 | "MODEL.DROPOUT_RATES": {
9 | "low": 0.3,
10 | "high": 0.7,
11 | "sample_fn": "uniform",
12 | "explore_wt": 0.3
13 | },
14 | "MODEL.DROPOUT_EMBEDDING": {
15 | "low": 0.3,
16 | "high": 0.7,
17 | "sample_fn": "uniform",
18 | "explore_wt": 0.3
19 | },
20 | "MODEL.CONTEXT_FORWARD": {
21 | "low": 10,
22 | "high": 100,
23 | "sample_fn": "randint",
24 | "enforce_limits": false,
25 | "is_integer": true
26 | },
27 | "MODEL.CONTEXT_BACKWARD": {
28 | "low": 10,
29 | "high": 100,
30 | "sample_fn": "randint",
31 | "enforce_limits": false,
32 | "is_integer": true
33 | },
34 | "TRAIN.LR.INIT": {
35 | "low": 1e-5,
36 | "high": 5e-3,
37 | "explore_wt": 0.3
38 | },
39 | "TRAIN.LR.WARMUP": {
40 | "low": 0,
41 | "high": 2000,
42 | "sample_fn": "randint",
43 | "is_integer": true
44 | },
45 | "TRAIN.MASK_RANDOM_RATIO": {
46 | "low": 0.8,
47 | "high": 1.0,
48 | "sample_fn": "uniform"
49 | },
50 | "TRAIN.MASK_TOKEN_RATIO": {
51 | "low": 0.5,
52 | "high": 1.0,
53 | "sample_fn": "loguniform"
54 | },
55 | "TRAIN.MASK_MAX_SPAN": {
56 | "low": 1,
57 | "high": 7,
58 | "sample_fn": "randint",
59 | "is_integer": true
60 | }
61 | }
--------------------------------------------------------------------------------
/configs/mc_maze.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/nlb/"
4 | TRAIN_FILENAME: 'mc_maze.h5'
5 | VAL_FILENAME: 'mc_maze.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 0 # 140
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 0
12 | LOGRATE: True
13 | NUM_LAYERS: 4
14 | TRAIN:
15 | LR:
16 | WARMUP: 5000
17 | MASK_RATIO: 0.25
18 | WEIGHT_DECAY: 5.0e-05
19 | PATIENCE: 3000
20 | LOG_INTERVAL: 200
21 | VAL_INTERVAL: 20
22 | CHECKPOINT_INTERVAL: 1000
23 | NUM_UPDATES: 50501
24 | MASK_SPAN_RAMP_START: 8000
25 | MASK_SPAN_RAMP_END: 12000
26 |
27 | TUNE_HP_JSON: './configs/mc_maze.json' # This space has more aggressive regularization
28 |
--------------------------------------------------------------------------------
/configs/mc_maze_large.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/nlb/"
4 | TRAIN_FILENAME: 'mc_maze_large.h5'
5 | VAL_FILENAME: 'mc_maze_large.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 0 # 140
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 0
12 | LOGRATE: True
13 | NUM_LAYERS: 4
14 | TRAIN:
15 | LR:
16 | WARMUP: 5000
17 | MASK_RATIO: 0.25
18 | WEIGHT_DECAY: 5.0e-05
19 | PATIENCE: 3000
20 | LOG_INTERVAL: 200
21 | VAL_INTERVAL: 20
22 | CHECKPOINT_INTERVAL: 1000
23 | NUM_UPDATES: 50501
24 | MASK_SPAN_RAMP_START: 8000
25 | MASK_SPAN_RAMP_END: 12000
26 |
27 | TUNE_HP_JSON: './configs/mc_maze.json' # This space has more aggressive regularization
28 |
--------------------------------------------------------------------------------
/configs/mc_maze_medium.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/nlb/"
4 | TRAIN_FILENAME: 'mc_maze_medium.h5'
5 | VAL_FILENAME: 'mc_maze_medium.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 0 # 140
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 0
12 | LOGRATE: True
13 | NUM_LAYERS: 4
14 | TRAIN:
15 | LR:
16 | WARMUP: 5000
17 | MASK_RATIO: 0.25
18 | WEIGHT_DECAY: 5.0e-05
19 | PATIENCE: 3000
20 | LOG_INTERVAL: 200
21 | VAL_INTERVAL: 20
22 | CHECKPOINT_INTERVAL: 1000
23 | NUM_UPDATES: 50501
24 | MASK_SPAN_RAMP_START: 8000
25 | MASK_SPAN_RAMP_END: 12000
26 |
27 | TUNE_HP_JSON: './configs/mc_maze.json' # This space has more aggressive regularization
28 |
--------------------------------------------------------------------------------
/configs/mc_maze_small.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/home/joelye/user_data/nlb/ndt_runs/"
2 | # CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
3 | DATA:
4 | DATAPATH: "/home/joelye/user_data/nlb/ndt_old/"
5 | # DATAPATH: "/snel/share/data/nlb/"
6 | TRAIN_FILENAME: 'mc_maze_small.h5'
7 | VAL_FILENAME: 'mc_maze_small.h5'
8 | MODEL:
9 | TRIAL_LENGTH: 0 # 140
10 | LEARNABLE_POSITION: True
11 | PRE_NORM: True
12 | FIXUP_INIT: True
13 | EMBED_DIM: 0
14 | LOGRATE: True
15 | NUM_LAYERS: 4
16 | TRAIN:
17 | LR:
18 | WARMUP: 2000
19 | MASK_RATIO: 0.25
20 | WEIGHT_DECAY: 5.0e-05
21 | PATIENCE: 2000
22 | LOG_INTERVAL: 200
23 | VAL_INTERVAL: 20
24 | CHECKPOINT_INTERVAL: 1000
25 | NUM_UPDATES: 50501
26 | MASK_SPAN_RAMP_START: 5000
27 | MASK_SPAN_RAMP_END: 10000
28 |
29 | TUNE_HP_JSON: './configs/mc_maze.json' # This space has more aggressive regularization
30 |
--------------------------------------------------------------------------------
/configs/mc_maze_small_from_scratch.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/home/joelye/user_data/nlb/ndt_runs"
2 | DATA:
3 | DATAPATH: "/home/joelye/user_data/nlb"
4 | TRAIN_FILENAME: 'mc_maze_small.h5'
5 | VAL_FILENAME: 'mc_maze_small.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 0 # 140
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 0
12 | LOGRATE: True
13 | NUM_LAYERS: 4
14 | TRAIN:
15 | LR:
16 | WARMUP: 2000
17 | MASK_RATIO: 0.25
18 | WEIGHT_DECAY: 5.0e-05
19 | PATIENCE: 2000
20 | LOG_INTERVAL: 200
21 | VAL_INTERVAL: 20
22 | CHECKPOINT_INTERVAL: 1000
23 | NUM_UPDATES: 50501
24 | MASK_SPAN_RAMP_START: 5000
25 | MASK_SPAN_RAMP_END: 10000
26 |
27 | TUNE_HP_JSON: './configs/mc_maze.json' # This space has more aggressive regularization
28 |
--------------------------------------------------------------------------------
/configs/mc_rtt.yaml:
--------------------------------------------------------------------------------
1 | CHECKPOINT_DIR: "/snel/share/joel/transformer_modeling/"
2 | DATA:
3 | DATAPATH: "/snel/share/data/nlb/"
4 | TRAIN_FILENAME: 'mc_rtt.h5'
5 | VAL_FILENAME: 'mc_rtt.h5'
6 | MODEL:
7 | TRIAL_LENGTH: 0 # 140
8 | LEARNABLE_POSITION: True
9 | PRE_NORM: True
10 | FIXUP_INIT: True
11 | EMBED_DIM: 0
12 | LOGRATE: True
13 | NUM_LAYERS: 4
14 | TRAIN:
15 | LR:
16 | WARMUP: 5000
17 | MASK_RATIO: 0.25
18 | WEIGHT_DECAY: 5.0e-05
19 | PATIENCE: 3000
20 | LOG_INTERVAL: 200
21 | VAL_INTERVAL: 20
22 | CHECKPOINT_INTERVAL: 1000
23 | NUM_UPDATES: 50501
24 | MASK_SPAN_RAMP_START: 8000
25 | MASK_SPAN_RAMP_END: 12000
26 |
27 | TUNE_HP_JSON: './configs/mc_maze.json' # This space has more aggressive regularization
28 |
--------------------------------------------------------------------------------
/configs/sweep_bump.json:
--------------------------------------------------------------------------------
1 | {
2 | "MODEL.DROPOUT": {
3 | "low": 0.2,
4 | "high": 0.6,
5 | "sample_fn": "uniform",
6 | "explore_wt": 0.3
7 | },
8 | "MODEL.DROPOUT_RATES": {
9 | "low": 0.2,
10 | "high": 0.6,
11 | "sample_fn": "uniform",
12 | "explore_wt": 0.3
13 | },
14 | "MODEL.DROPOUT_EMBEDDING": {
15 | "low": 0.2,
16 | "high": 0.6,
17 | "sample_fn": "uniform",
18 | "explore_wt": 0.3
19 | },
20 | "MODEL.CONTEXT_FORWARD": {
21 | "low": 4,
22 | "high": 100,
23 | "sample_fn": "randint",
24 | "enforce_limits": false,
25 | "is_integer": true
26 | },
27 | "MODEL.CONTEXT_BACKWARD": {
28 | "low": 4,
29 | "high": 100,
30 | "sample_fn": "randint",
31 | "enforce_limits": false,
32 | "is_integer": true
33 | },
34 | "TRAIN.LR.INIT": {
35 | "low": 1e-5,
36 | "high": 5e-3,
37 | "explore_wt": 0.3
38 | },
39 | "TRAIN.MASK_RANDOM_RATIO": {
40 | "low": 0.9,
41 | "high": 1.0,
42 | "sample_fn": "uniform"
43 | },
44 | "TRAIN.MASK_TOKEN_RATIO": {
45 | "low": 0.5,
46 | "high": 1.0,
47 | "sample_fn": "loguniform"
48 | },
49 | "TRAIN.MASK_MAX_SPAN": {
50 | "low": 1,
51 | "high": 5,
52 | "sample_fn": "randint",
53 | "is_integer": true
54 | }
55 | }
--------------------------------------------------------------------------------
/configs/sweep_generic.json:
--------------------------------------------------------------------------------
1 | {
2 | "MODEL.DROPOUT": {
3 | "low": 0.0,
4 | "high": 0.3,
5 | "sample_fn": "uniform",
6 | "explore_wt": 0.3
7 | },
8 | "MODEL.DROPOUT_RATES": {
9 | "low": 0.0,
10 | "high": 0.3,
11 | "sample_fn": "uniform",
12 | "explore_wt": 0.3
13 | },
14 | "MODEL.DROPOUT_EMBEDDING": {
15 | "low": 0.0,
16 | "high": 0.3,
17 | "sample_fn": "uniform",
18 | "explore_wt": 0.3
19 | },
20 | "MODEL.CONTEXT_FORWARD": {
21 | "low": 4,
22 | "high": 32,
23 | "sample_fn": "randint",
24 | "enforce_limits": false,
25 | "is_integer": true
26 | },
27 | "MODEL.CONTEXT_BACKWARD": {
28 | "low": 4,
29 | "high": 32,
30 | "sample_fn": "randint",
31 | "enforce_limits": false,
32 | "is_integer": true
33 | },
34 | "TRAIN.LR.INIT": {
35 | "low": 1e-5,
36 | "high": 5e-3,
37 | "explore_wt": 0.3
38 | },
39 | "TRAIN.MASK_RANDOM_RATIO": {
40 | "low": 0.9,
41 | "high": 1.0,
42 | "sample_fn": "uniform"
43 | },
44 | "TRAIN.MASK_TOKEN_RATIO": {
45 | "low": 0.5,
46 | "high": 1.0,
47 | "sample_fn": "loguniform"
48 | },
49 | "TRAIN.MASK_MAX_SPAN": {
50 | "low": 1,
51 | "high": 5,
52 | "sample_fn": "randint",
53 | "is_integer": true
54 | }
55 | }
--------------------------------------------------------------------------------
/configs/sweep_simple.json:
--------------------------------------------------------------------------------
1 | {
2 | "MODEL.DROPOUT": {
3 | "low": 0.0,
4 | "high": 0.3,
5 | "sample_fn": "uniform",
6 | "explore_wt": 0.3
7 | },
8 | "MODEL.DROPOUT_RATES": {
9 | "low": 0.0,
10 | "high": 0.3,
11 | "sample_fn": "uniform",
12 | "explore_wt": 0.3
13 | },
14 | "MODEL.DROPOUT_EMBEDDING": {
15 | "low": 0.0,
16 | "high": 0.3,
17 | "sample_fn": "uniform",
18 | "explore_wt": 0.3
19 | },
20 | "TRAIN.LR.INIT": {
21 | "low": 1e-5,
22 | "high": 5e-3,
23 | "explore_wt": 0.3
24 | },
25 | "TRAIN.MASK_RANDOM_RATIO": {
26 | "low": 0.9,
27 | "high": 1.0,
28 | "sample_fn": "uniform"
29 | },
30 | "TRAIN.MASK_TOKEN_RATIO": {
31 | "low": 0.5,
32 | "high": 1.0,
33 | "sample_fn": "loguniform"
34 | }
35 | }
--------------------------------------------------------------------------------
/data/chaotic_rnn/gen_synth_data_no_inputs.sh:
--------------------------------------------------------------------------------
1 | SYNTH_PATH=/snel/share/data/chaotic_rnn_data/data_no_inputs
2 |
3 |
4 | python generate_chaotic_rnn_data.py --save_dir=$SYNTH_PATH --datafile_name=chaotic_rnn_no_inputs --synth_data_seed=5 --T=1.0 --C=130 --N=50 --S=50 --train_percentage=0.8 --nreplications=10 --g=1.5 --x0_std=1.0 --tau=0.025 --dt=0.01 --ninputs=0 --input_magnitude_list='' --max_firing_rate=30.0 --noise_type='poisson'
--------------------------------------------------------------------------------
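The generator asserts that `train_percentage * nreplications` is numerically an integer, so the flags in this script have to be chosen consistently. A small worked check using the values passed above:

# Worked check of the split constraint enforced in generate_chaotic_rnn_data.py,
# using the flags above (--C=130, --nreplications=10, --train_percentage=0.8).
C, nreplications, train_percentage = 130, 10, 0.8

E = C * nreplications                          # total trials: 1300
train_reps = nreplications * train_percentage  # 8.0 -> integral, so the assert passes
assert abs(train_reps - round(train_reps)) < 1e-8
print(E, int(round(train_reps)), "train replications per condition")  # 1300 8 ...
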
/data/chaotic_rnn/generate_chaotic_rnn_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # ==============================================================================
16 | from __future__ import print_function
17 |
18 | import h5py
19 | import numpy as np
20 | import os
21 | import tensorflow as tf # used for flags here
22 |
23 | from utils import write_datasets
24 | from synthetic_data_utils import add_alignment_projections, generate_data
25 | from synthetic_data_utils import generate_rnn, get_train_n_valid_inds
26 | from synthetic_data_utils import nparray_and_transpose
27 | from synthetic_data_utils import spikify_data, gaussify_data, split_list_by_inds
28 | #import matplotlib
29 | #import matplotlib.pyplot as plt
30 | import scipy.signal
31 |
32 | #matplotlib.rcParams['image.interpolation'] = 'nearest'
33 | DATA_DIR = "rnn_synth_data_v1.0"
34 |
35 | flags = tf.app.flags
36 | flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
37 | "Directory for saving data.")
38 | flags.DEFINE_string("datafile_name", "thits_data",
39 | "Name of data file for input case.")
40 | flags.DEFINE_string("noise_type", "poisson", "Noise type for data.")
41 | flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
42 | flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
43 | flags.DEFINE_integer("C", 100, "Number of conditions")
44 | flags.DEFINE_integer("N", 50, "Number of units for the RNN")
45 | flags.DEFINE_integer("S", 50, "Number of sampled units from RNN")
46 | flags.DEFINE_integer("npcs", 10, "Number of PCS for multi-session case.")
47 | flags.DEFINE_float("train_percentage", 4.0/5.0,
48 | "Percentage of train vs validation trials")
49 | flags.DEFINE_integer("nreplications", 40,
50 | "Number of noise replications of the same underlying rates.")
51 | flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
52 | flags.DEFINE_float("x0_std", 1.0,
53 | "Volume from which to pull initial conditions (affects diversity of dynamics.")
54 | flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
55 | flags.DEFINE_float("dt", 0.010, "Time bin")
56 | flags.DEFINE_float("input_magnitude", 20.0,
57 | "For the input case, what is the value of the input?")
58 | flags.DEFINE_integer("ninputs", 0, "number of inputs")
59 | flags.DEFINE_list("input_magnitude_list", "10,15,20,25,30", "Magnitudes for multiple inputs")
60 | flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second")
61 | flags.DEFINE_boolean("lorenz", False, "use lorenz system as generated inputs")
62 | FLAGS = flags.FLAGS
63 |
64 |
65 | # Note that with N small, (as it is 25 above), the finite size effects
66 | # will have pretty dramatic effects on the dynamics of the random RNN.
67 | # If you want more complex dynamics, you'll have to run the script a
68 | # lot, or increase N (or g).
69 |
70 | # Getting hard vs. easy data can be a little stochastic, so we set the seed.
71 |
72 | # Pull out some commonly used parameters.
73 | # These are user parameters (configuration)
74 | rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
75 | T = FLAGS.T
76 | C = FLAGS.C
77 | N = FLAGS.N
78 | S = FLAGS.S
79 | input_magnitude = FLAGS.input_magnitude
80 | input_magnitude_list = [float(i) for i in (FLAGS.input_magnitude_list)]
81 | ninputs = FLAGS.ninputs
82 | nreplications = FLAGS.nreplications
83 | E = nreplications * C # total number of trials
84 | # S is the number of measurements in each datasets, w/ each
85 | # dataset having a different set of observations.
 86 | ndatasets = N // S # integer division; ok if rounded down
87 | train_percentage = FLAGS.train_percentage
88 | ntime_steps = int(T / FLAGS.dt)
89 | # End of user parameters
90 |
91 | lorenz=FLAGS.lorenz
92 |
93 |
94 | if ninputs >= 1:
95 | rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate, ninputs)
96 | else:
97 | rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate)
98 |
99 | # Check to make sure the RNN is the one we used in the paper.
100 | if N == 50:
101 | assert abs(rnn['W'][0,0] - 0.06239899) < 1e-8, 'Error in random seed?'
102 | rem_check = nreplications * train_percentage
103 | assert abs(rem_check - int(rem_check)) < 1e-8, \
104 | 'Train percentage * nreplications should be integral number.'
105 |
106 |
107 | #if lorenz:
108 | # lorenz_input = generate_lorenz(ntime_steps, rng)
109 | # rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate, lorenz_input)
110 |
111 |
112 | # Initial condition generation, and condition label generation. This
113 | # happens outside of the dataset loop, so that all datasets have the
114 | # same conditions, which is similar to a neurophys setup.
115 | condition_number = 0
116 | x0s = []
117 | condition_labels = []
118 | print(FLAGS.x0_std)
119 | for c in range(C):
120 | x0 = FLAGS.x0_std * rng.randn(N, 1)
121 | x0s.append(np.tile(x0, nreplications)) # replicate x0 nreplications times
122 | # replicate the condition label nreplications times
123 | for ns in range(nreplications):
124 | condition_labels.append(condition_number)
125 | condition_number += 1
126 | x0s = np.concatenate(x0s, axis=1)
127 |
128 | #print(x0s.shape)
129 | #print(x0s[1,1:20])
130 |
131 | # Containers for storing data across data.
132 | datasets = {}
133 | for n in range(ndatasets):
134 | print(n+1, " of ", ndatasets)
135 |
136 | # First generate all firing rates. in the next loop, generate all
137 | # replications this allows the random state for rate generation to be
138 | # independent of n_replications.
139 | dataset_name = 'dataset_N' + str(N) + '_S' + str(S)
140 | if S < N:
141 | dataset_name += '_n' + str(n+1)
142 |
143 | # Sample neuron subsets. The assumption is the PC axes of the RNN
144 | # are not unit aligned, so sampling units is adequate to sample all
145 | # the high-variance PCs.
146 | P_sxn = np.eye(S,N)
147 | for m in range(n):
148 | P_sxn = np.roll(P_sxn, S, axis=1)
149 |
150 | if input_magnitude > 0.0:
151 | # time of "hits" randomly chosen between [1/4 and 3/4] of total time
152 | if ninputs>1:
153 | for n in range(ninputs):
154 | if n == 0:
155 | input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)
156 | else:
157 | input_times = np.vstack((input_times, (rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4))))
158 | else:
159 | input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)
160 | else:
161 | input_times = None
162 |
163 | print(ninputs)
164 |
165 |
166 | if ninputs > 1:
167 | rates, x0s, inputs = \
168 | generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
169 | input_magnitude=input_magnitude_list,
170 | input_times=input_times, ninputs=ninputs, rng=rng)
171 | else:
172 | rates, x0s, inputs = \
173 | generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
174 | input_magnitude=input_magnitude,
175 | input_times=input_times, ninputs=ninputs, rng=rng)
176 |
177 | if FLAGS.noise_type == "poisson":
178 | noisy_data = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
179 | elif FLAGS.noise_type == "gaussian":
180 | noisy_data = gaussify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
181 | else:
182 | raise ValueError("Only noise types supported are poisson or gaussian")
183 |
184 | # split into train and validation sets
185 | train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
186 | nreplications)
187 |
188 | # Split the data, inputs, labels and times into train vs. validation.
189 | rates_train, rates_valid = \
190 | split_list_by_inds(rates, train_inds, valid_inds)
191 | #rates_train_no_input, rates_valid_no_input = \
192 | # split_list_by_inds(rates_noinput, train_inds, valid_inds)
193 | noisy_data_train, noisy_data_valid = \
194 | split_list_by_inds(noisy_data, train_inds, valid_inds)
195 | input_train, inputs_valid = \
196 | split_list_by_inds(inputs, train_inds, valid_inds)
197 | condition_labels_train, condition_labels_valid = \
198 | split_list_by_inds(condition_labels, train_inds, valid_inds)
199 | if ninputs>1:
200 | input_times_train, input_times_valid = \
201 | split_list_by_inds(input_times, train_inds, valid_inds, ninputs)
202 | input_magnitude = input_magnitude_list
203 | else:
204 | input_times_train, input_times_valid = \
205 | split_list_by_inds(input_times, train_inds, valid_inds)
206 |
207 | #lorenz_train = np.expand_dims(lorenz_train, axis=1)
208 | #lorenz_valid = np.expand_dims(lorenz_valid, axis=1)
209 | #print((np.array(input_train)).shape)
210 | #print((np.array(lorenz_train)).shape)
211 | # Turn rates, noisy_data, and input into numpy arrays.
212 | rates_train = nparray_and_transpose(rates_train)
213 | rates_valid = nparray_and_transpose(rates_valid)
214 | #rates_train_no_input = nparray_and_transpose(rates_train_no_input)
215 | #rates_valid_no_input = nparray_and_transpose(rates_valid_no_input)
216 | noisy_data_train = nparray_and_transpose(noisy_data_train)
217 | noisy_data_valid = nparray_and_transpose(noisy_data_valid)
218 | input_train = nparray_and_transpose(input_train)
219 | inputs_valid = nparray_and_transpose(inputs_valid)
220 |
221 | # Note that we put these 'truth' rates and input into this
222 | # structure, the only data that is used in LFADS are the noisy
223 | # data e.g. spike trains. The rest is either for printing or posterity.
224 | data = {'train_truth': rates_train,
225 | 'valid_truth': rates_valid,
226 | #'train_truth_no_input': rates_train_no_input,
227 | #'valid_truth_no_input': rates_valid_no_input,
228 | 'input_train_truth' : input_train,
229 | 'input_valid_truth' : inputs_valid,
230 | 'train_data' : noisy_data_train,
231 | 'valid_data' : noisy_data_valid,
232 | 'train_percentage' : train_percentage,
233 | 'nreplications' : nreplications,
234 | 'dt' : rnn['dt'],
235 | 'input_magnitude' : input_magnitude,
236 | 'input_times_train' : input_times_train,
237 | 'input_times_valid' : input_times_valid,
238 | 'P_sxn' : P_sxn,
239 | 'condition_labels_train' : condition_labels_train,
240 | 'condition_labels_valid' : condition_labels_valid,
241 | 'conversion_factor': 1.0 / rnn['conversion_factor']}
242 | datasets[dataset_name] = data
243 |
244 | if S < N:
245 | # Note that this isn't necessary for this synthetic example, but
246 | # it's useful to see how the input factor matrices were initialized
247 | # for actual neurophysiology data.
248 | datasets = add_alignment_projections(datasets, npcs=FLAGS.npcs)
249 |
250 | # Write out the datasets.
251 | write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
252 |
--------------------------------------------------------------------------------
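The Poisson noise step above is delegated to `spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])` from `synthetic_data_utils.py`, which is not reproduced in this dump. A hypothetical standalone sketch of that step, assuming rates are normalized so that 1.0 corresponds to `max_firing_rate` spikes/s (as the flag help text suggests):

import numpy as np

# Hypothetical sketch of the Poisson spikification step; the real implementation is
# data/chaotic_rnn/synthetic_data_utils.py (not shown). Assumes rates in [0, 1] where
# 1.0 maps to max_firing_rate spikes per second.
def spikify_sketch(rates, rng, dt=0.01, max_firing_rate=30.0):
    lam = np.clip(rates, 0.0, None) * max_firing_rate * dt  # expected spikes per bin
    return rng.poisson(lam)

rng = np.random.RandomState(0)
rates = rng.rand(100, 50)                 # 100 bins of dt=10 ms, 50 sampled units
spikes = spikify_sketch(rates, rng)
print(spikes.shape, spikes.mean())        # mean count/bin ~= rates.mean() * 0.3
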
/data/chaotic_rnn/generate_chaotic_rnn_data_allowRandomSeed.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # ==============================================================================
16 | from __future__ import print_function
17 |
18 | import h5py
19 | import numpy as np
20 | import os
21 | import tensorflow as tf # used for flags here
22 |
23 | from utils import write_datasets
24 | from synthetic_data_utils import add_alignment_projections, generate_data
25 | from synthetic_data_utils import generate_rnn, get_train_n_valid_inds
26 | from synthetic_data_utils import nparray_and_transpose
27 | from synthetic_data_utils import spikify_data, gaussify_data, split_list_by_inds
28 | #import matplotlib
29 | #import matplotlib.pyplot as plt
30 | import scipy.signal
31 |
32 | #matplotlib.rcParams['image.interpolation'] = 'nearest'
33 | DATA_DIR = "rnn_synth_data_v1.0"
34 |
35 | flags = tf.app.flags
36 | flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
37 | "Directory for saving data.")
38 | flags.DEFINE_string("datafile_name", "thits_data",
39 | "Name of data file for input case.")
40 | flags.DEFINE_string("noise_type", "poisson", "Noise type for data.")
41 | flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
42 | flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
43 | flags.DEFINE_integer("C", 100, "Number of conditions")
44 | flags.DEFINE_integer("N", 50, "Number of units for the RNN")
45 | flags.DEFINE_integer("S", 50, "Number of sampled units from RNN")
46 | flags.DEFINE_integer("npcs", 10, "Number of PCS for multi-session case.")
47 | flags.DEFINE_float("train_percentage", 4.0/5.0,
48 | "Percentage of train vs validation trials")
49 | flags.DEFINE_integer("nreplications", 40,
50 | "Number of noise replications of the same underlying rates.")
51 | flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
52 | flags.DEFINE_float("x0_std", 1.0,
53 | "Volume from which to pull initial conditions (affects diversity of dynamics.")
54 | flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
55 | flags.DEFINE_float("dt", 0.010, "Time bin")
56 | flags.DEFINE_float("input_magnitude", 20.0,
57 | "For the input case, what is the value of the input?")
58 | flags.DEFINE_integer("ninputs", 0, "number of inputs")
59 | flags.DEFINE_list("input_magnitude_list", "10,15,20,25,30", "Magnitudes for multiple inputs")
60 | flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second")
61 | flags.DEFINE_boolean("lorenz", False, "use lorenz system as generated inputs")
62 | FLAGS = flags.FLAGS
63 |
64 |
65 | # Note that with N small, (as it is 25 above), the finite size effects
66 | # will have pretty dramatic effects on the dynamics of the random RNN.
67 | # If you want more complex dynamics, you'll have to run the script a
68 | # lot, or increase N (or g).
69 |
70 | # Getting hard vs. easy data can be a little stochastic, so we set the seed.
71 |
72 | # Pull out some commonly used parameters.
73 | # These are user parameters (configuration)
74 | rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
75 | T = FLAGS.T
76 | C = FLAGS.C
77 | N = FLAGS.N
78 | S = FLAGS.S
79 | input_magnitude = FLAGS.input_magnitude
80 | input_magnitude_list = [float(i) for i in (FLAGS.input_magnitude_list)]
81 | ninputs = FLAGS.ninputs
82 | nreplications = FLAGS.nreplications
83 | E = nreplications * C # total number of trials
84 | # S is the number of measurements in each datasets, w/ each
85 | # dataset having a different set of observations.
 86 | ndatasets = N // S # integer division; ok if rounded down
87 | train_percentage = FLAGS.train_percentage
88 | ntime_steps = int(T / FLAGS.dt)
89 | # End of user parameters
90 |
91 | lorenz=FLAGS.lorenz
92 |
93 |
94 | if ninputs >= 1:
95 | rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate, ninputs)
96 | else:
97 | rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate)
98 |
99 | # Check to make sure the RNN is the one we used in the paper.
100 | if N == 50:
101 | # assert abs(rnn['W'][0,0] - 0.06239899) < 1e-8, 'Error in random seed?'
102 | rem_check = nreplications * train_percentage
103 | assert abs(rem_check - int(rem_check)) < 1e-8, \
104 | 'Train percentage * nreplications should be integral number.'
105 |
106 |
107 | #if lorenz:
108 | # lorenz_input = generate_lorenz(ntime_steps, rng)
109 | # rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate, lorenz_input)
110 |
111 |
112 | # Initial condition generation, and condition label generation. This
113 | # happens outside of the dataset loop, so that all datasets have the
114 | # same conditions, which is similar to a neurophys setup.
115 | condition_number = 0
116 | x0s = []
117 | condition_labels = []
118 | print(FLAGS.x0_std)
119 | for c in range(C):
120 | x0 = FLAGS.x0_std * rng.randn(N, 1)
121 | x0s.append(np.tile(x0, nreplications)) # replicate x0 nreplications times
122 | # replicate the condition label nreplications times
123 | for ns in range(nreplications):
124 | condition_labels.append(condition_number)
125 | condition_number += 1
126 | x0s = np.concatenate(x0s, axis=1)
127 |
128 | #print(x0s.shape)
129 | #print(x0s[1,1:20])
130 |
131 | # Containers for storing data across datasets.
132 | datasets = {}
133 | for n in range(ndatasets):
134 | print(n+1, " of ", ndatasets)
135 |
136 |   # First generate all firing rates. In the next loop, generate all
137 |   # replications; this allows the random state for rate generation to be
138 |   # independent of nreplications.
139 | dataset_name = 'dataset_N' + str(N) + '_S' + str(S)
140 | if S < N:
141 | dataset_name += '_n' + str(n+1)
142 |
143 | # Sample neuron subsets. The assumption is the PC axes of the RNN
144 | # are not unit aligned, so sampling units is adequate to sample all
145 | # the high-variance PCs.
146 | P_sxn = np.eye(S,N)
147 | for m in range(n):
148 | P_sxn = np.roll(P_sxn, S, axis=1)
149 |
150 | if input_magnitude > 0.0:
151 | # time of "hits" randomly chosen between [1/4 and 3/4] of total time
152 |     if ninputs > 1:
153 |       for input_idx in range(ninputs):  # renamed to avoid shadowing the dataset loop variable n
154 |         if input_idx == 0:
155 | input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)
156 | else:
157 | input_times = np.vstack((input_times, (rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4))))
158 | else:
159 | input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)
160 | else:
161 | input_times = None
162 |
163 | print(ninputs)
164 |
165 |
166 | if ninputs > 1:
167 | rates, x0s, inputs = \
168 | generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
169 | input_magnitude=input_magnitude_list,
170 | input_times=input_times, ninputs=ninputs, rng=rng)
171 | else:
172 | rates, x0s, inputs = \
173 | generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
174 | input_magnitude=input_magnitude,
175 | input_times=input_times, ninputs=ninputs, rng=rng)
176 |
177 | if FLAGS.noise_type == "poisson":
178 | noisy_data = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
179 | elif FLAGS.noise_type == "gaussian":
180 | noisy_data = gaussify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
181 | else:
182 | raise ValueError("Only noise types supported are poisson or gaussian")
183 |
184 | # split into train and validation sets
185 | train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
186 | nreplications)
187 |
188 | # Split the data, inputs, labels and times into train vs. validation.
189 | rates_train, rates_valid = \
190 | split_list_by_inds(rates, train_inds, valid_inds)
191 | #rates_train_no_input, rates_valid_no_input = \
192 | # split_list_by_inds(rates_noinput, train_inds, valid_inds)
193 | noisy_data_train, noisy_data_valid = \
194 | split_list_by_inds(noisy_data, train_inds, valid_inds)
195 | input_train, inputs_valid = \
196 | split_list_by_inds(inputs, train_inds, valid_inds)
197 | condition_labels_train, condition_labels_valid = \
198 | split_list_by_inds(condition_labels, train_inds, valid_inds)
199 |   if ninputs > 1:
200 | input_times_train, input_times_valid = \
201 | split_list_by_inds(input_times, train_inds, valid_inds, ninputs)
202 | input_magnitude = input_magnitude_list
203 | else:
204 | input_times_train, input_times_valid = \
205 | split_list_by_inds(input_times, train_inds, valid_inds)
206 |
207 | #lorenz_train = np.expand_dims(lorenz_train, axis=1)
208 | #lorenz_valid = np.expand_dims(lorenz_valid, axis=1)
209 | #print((np.array(input_train)).shape)
210 | #print((np.array(lorenz_train)).shape)
211 | # Turn rates, noisy_data, and input into numpy arrays.
212 | rates_train = nparray_and_transpose(rates_train)
213 | rates_valid = nparray_and_transpose(rates_valid)
214 | #rates_train_no_input = nparray_and_transpose(rates_train_no_input)
215 | #rates_valid_no_input = nparray_and_transpose(rates_valid_no_input)
216 | noisy_data_train = nparray_and_transpose(noisy_data_train)
217 | noisy_data_valid = nparray_and_transpose(noisy_data_valid)
218 | input_train = nparray_and_transpose(input_train)
219 | inputs_valid = nparray_and_transpose(inputs_valid)
220 |
221 |   # Note that although we put these 'truth' rates and inputs into this
222 |   # structure, the only data LFADS actually uses is the noisy
223 |   # data, e.g. the spike trains. The rest is either for printing or posterity.
224 | data = {'train_truth': rates_train,
225 | 'valid_truth': rates_valid,
226 | #'train_truth_no_input': rates_train_no_input,
227 | #'valid_truth_no_input': rates_valid_no_input,
228 | 'input_train_truth' : input_train,
229 | 'input_valid_truth' : inputs_valid,
230 | 'train_data' : noisy_data_train,
231 | 'valid_data' : noisy_data_valid,
232 | 'train_percentage' : train_percentage,
233 | 'nreplications' : nreplications,
234 | 'dt' : rnn['dt'],
235 | 'input_magnitude' : input_magnitude,
236 | 'input_times_train' : input_times_train,
237 | 'input_times_valid' : input_times_valid,
238 | 'P_sxn' : P_sxn,
239 | 'condition_labels_train' : condition_labels_train,
240 | 'condition_labels_valid' : condition_labels_valid,
241 | 'conversion_factor': 1.0 / rnn['conversion_factor']}
242 | datasets[dataset_name] = data
243 |
244 | if S < N:
245 | # Note that this isn't necessary for this synthetic example, but
246 | # it's useful to see how the input factor matrices were initialized
247 | # for actual neurophysiology data.
248 | datasets = add_alignment_projections(datasets, npcs=FLAGS.npcs)
249 |
250 | # Write out the datasets.
251 | write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
252 |
--------------------------------------------------------------------------------
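
As a quick sanity check on the script's output, a generated file can be reopened with the read_data helper from data/chaotic_rnn/utils.py (listed next). This is only a sketch: the path below is hypothetical (it follows FLAGS.save_dir + FLAGS.datafile_name + "_" + dataset_name), it assumes the script is run from this directory so utils is importable, and the shape comment is the likely layout under the default flags.

from utils import read_data  # data/chaotic_rnn/utils.py, listed below

# Hypothetical path: FLAGS.save_dir + FLAGS.datafile_name + "_" + dataset_name
fname = "/tmp/synth_data/thits_data_dataset_N50_S50"
data = read_data(fname)
print(data["train_data"].shape)   # spike counts, likely (n_train_trials, ntime_steps, S)
print(data["valid_data"].shape)
print(data["train_truth"].shape)  # the underlying (noiseless) rates
print(float(data["conversion_factor"]))  # stored above as 1.0 / rnn['conversion_factor']
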
/data/chaotic_rnn/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # ==============================================================================
16 | from __future__ import print_function
17 |
18 | import os
19 | import h5py
20 | import json
21 |
22 | import numpy as np
23 | import tensorflow as tf
24 |
25 |
26 | def log_sum_exp(x_k):
27 | """Computes log \sum exp in a numerically stable way.
28 | log ( sum_i exp(x_i) )
29 | log ( sum_i exp(x_i - m + m) ), with m = max(x_i)
30 | log ( sum_i exp(x_i - m)*exp(m) )
31 |   log ( sum_i exp(x_i - m) ) + m
32 |
33 | Args:
34 |     x_k: k-dimensional list of arguments to log_sum_exp.
35 |
36 | Returns:
37 | log_sum_exp of the arguments.
38 | """
39 | m = tf.reduce_max(x_k)
40 | x1_k = x_k - m
41 | u_k = tf.exp(x1_k)
42 | z = tf.reduce_sum(u_k)
43 | return tf.log(z) + m
44 |
45 |
46 | def linear(x, out_size, do_bias=True, alpha=1.0, identity_if_possible=False,
47 | normalized=False, name=None, collections=None):
48 | """Linear (affine) transformation, y = x W + b, for a variety of
49 | configurations.
50 |
51 | Args:
52 |     x: The input tensor to transform.
53 | out_size: The integer size of non-batch output dimension.
54 | do_bias (optional): Add a learnable bias vector to the operation.
55 | alpha (optional): A multiplicative scaling for the weight initialization
56 | of the matrix, in the form \alpha * 1/\sqrt{x.shape[1]}.
57 | identity_if_possible (optional): just return identity,
58 | if x.shape[1] == out_size.
59 | normalized (optional): Option to divide out by the norms of the rows of W.
60 | name (optional): The name prefix to add to variables.
61 | collections (optional): List of additional collections. (Placed in
62 | tf.GraphKeys.GLOBAL_VARIABLES already, so no need for that.)
63 |
64 | Returns:
65 | In the equation, y = x W + b, returns the tensorflow op that yields y.
66 | """
67 | in_size = int(x.get_shape()[1]) # from Dimension(10) -> 10
68 | stddev = alpha/np.sqrt(float(in_size))
69 | mat_init = tf.random_normal_initializer(0.0, stddev)
70 | wname = (name + "/W") if name else "/W"
71 |
72 | if identity_if_possible and in_size == out_size:
73 | # Sometimes linear layers are nothing more than size adapters.
74 | return tf.identity(x, name=(wname+'_ident'))
75 |
76 | W,b = init_linear(in_size, out_size, do_bias=do_bias, alpha=alpha,
77 | normalized=normalized, name=name, collections=collections)
78 |
79 | if do_bias:
80 | return tf.matmul(x, W) + b
81 | else:
82 | return tf.matmul(x, W)
83 |
84 |
85 | def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
86 | bias_init_value=None, alpha=1.0, identity_if_possible=False,
87 | normalized=False, name=None, collections=None, trainable=True):
88 | """Linear (affine) transformation, y = x W + b, for a variety of
89 | configurations.
90 |
91 | Args:
92 |     in_size: The integer size of the non-batch input dimension. [(x),y]
93 | out_size: The integer size of non-batch output dimension. [x,(y)]
94 |     do_bias (optional): Add a (learnable) bias vector to the operation;
95 |       if False, b will be None.
96 |     mat_init_value (optional): numpy constant for matrix initialization; if None,
97 |       initialize randomly using the additional parameters.
98 | alpha (optional): A multiplicative scaling for the weight initialization
99 | of the matrix, in the form \alpha * 1/\sqrt{x.shape[1]}.
100 | identity_if_possible (optional): just return identity,
101 | if x.shape[1] == out_size.
102 | normalized (optional): Option to divide out by the norms of the rows of W.
103 | name (optional): The name prefix to add to variables.
104 | collections (optional): List of additional collections. (Placed in
105 | tf.GraphKeys.GLOBAL_VARIABLES already, so no need for that.)
106 |
107 | Returns:
108 | In the equation, y = x W + b, returns the pair (W, b).
109 | """
110 |
111 | if mat_init_value is not None and mat_init_value.shape != (in_size, out_size):
112 | raise ValueError(
113 | 'Provided mat_init_value must have shape [%d, %d].'%(in_size, out_size))
114 | if bias_init_value is not None and bias_init_value.shape != (1,out_size):
115 | raise ValueError(
116 | 'Provided bias_init_value must have shape [1,%d].'%(out_size,))
117 |
118 | if mat_init_value is None:
119 | stddev = alpha/np.sqrt(float(in_size))
120 | mat_init = tf.random_normal_initializer(0.0, stddev)
121 |
122 | wname = (name + "/W") if name else "/W"
123 |
124 | if identity_if_possible and in_size == out_size:
125 | return (tf.constant(np.eye(in_size).astype(np.float32)),
126 | tf.zeros(in_size))
127 |
128 |   # Note the use of get_variable vs. tf.Variable. This is because get_variable
129 | # does not allow the initialization of the variable with a value.
130 | if normalized:
131 | w_collections = [tf.GraphKeys.GLOBAL_VARIABLES, "norm-variables"]
132 | if collections:
133 | w_collections += collections
134 | if mat_init_value is not None:
135 | w = tf.Variable(mat_init_value, name=wname, collections=w_collections,
136 | trainable=trainable)
137 | else:
138 | w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
139 | collections=w_collections, trainable=trainable)
140 | w = tf.nn.l2_normalize(w, dim=0) # x W, so xW_j = \sum_i x_bi W_ij
141 | else:
142 | w_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
143 | if collections:
144 | w_collections += collections
145 | if mat_init_value is not None:
146 | w = tf.Variable(mat_init_value, name=wname, collections=w_collections,
147 | trainable=trainable)
148 | else:
149 | w = tf.get_variable(wname, [in_size, out_size], initializer=mat_init,
150 | collections=w_collections, trainable=trainable)
151 | b = None
152 | if do_bias:
153 | b_collections = [tf.GraphKeys.GLOBAL_VARIABLES]
154 | if collections:
155 | b_collections += collections
156 | bname = (name + "/b") if name else "/b"
157 | if bias_init_value is None:
158 | b = tf.get_variable(bname, [1, out_size],
159 | initializer=tf.zeros_initializer(),
160 | collections=b_collections,
161 | trainable=trainable)
162 | else:
163 | b = tf.Variable(bias_init_value, name=bname,
164 | collections=b_collections,
165 | trainable=trainable)
166 |
167 | return (w, b)
168 |
169 |
170 | def write_data(data_fname, data_dict, use_json=False, compression=None):
171 | """Write data in HD5F format.
172 |
173 | Args:
174 |     data_fname: The filename of the file in which to write the data.
175 | data_dict: The dictionary of data to write. The keys are strings
176 | and the values are numpy arrays.
177 | use_json (optional): human readable format for simple items
178 | compression (optional): The compression to use for h5py (disabled by
179 | default because the library borks on scalars, otherwise try 'gzip').
180 | """
181 |
182 | dir_name = os.path.dirname(data_fname)
183 | if not os.path.exists(dir_name):
184 | os.makedirs(dir_name)
185 |
186 | if use_json:
187 | the_file = open(data_fname,'w')
188 | json.dump(data_dict, the_file)
189 | the_file.close()
190 | else:
191 | try:
192 | with h5py.File(data_fname, 'w') as hf:
193 | for k, v in data_dict.items():
194 | clean_k = k.replace('/', '_')
195 |         if clean_k != k:
196 | print('Warning: saving variable with name: ', k, ' as ', clean_k)
197 | else:
198 | print('Saving variable with name: ', clean_k)
199 | hf.create_dataset(clean_k, data=v, compression=compression)
200 | except IOError:
201 | print("Cannot open %s for writing.", data_fname)
202 | raise
203 |
204 |
205 | def read_data(data_fname):
206 | """ Read saved data in HDF5 format.
207 |
208 | Args:
209 | data_fname: The filename of the file from which to read the data.
210 | Returns:
211 | A dictionary whose keys will vary depending on dataset (but should
212 | always contain the keys 'train_data' and 'valid_data') and whose
213 | values are numpy arrays.
214 | """
215 |
216 | try:
217 | with h5py.File(data_fname, 'r') as hf:
218 | data_dict = {k: np.array(v) for k, v in hf.items()}
219 | return data_dict
220 | except IOError:
221 | print("Cannot open %s for reading." % data_fname)
222 | raise
223 |
224 |
225 | def write_datasets(data_path, data_fname_stem, dataset_dict, compression=None):
226 | """Write datasets in HD5F format.
227 |
228 | This function assumes the dataset_dict is a mapping ( string ->
229 | to data_dict ). It calls write_data for each data dictionary,
230 | post-fixing the data filename with the key of the dataset.
231 |
232 | Args:
233 | data_path: The path to the save directory.
234 | data_fname_stem: The filename stem of the file in which to write the data.
235 | dataset_dict: The dictionary of datasets. The keys are strings
236 | and the values data dictionaries (str -> numpy arrays) associations.
237 | compression (optional): The compression to use for h5py (disabled by
238 | default because the library borks on scalars, otherwise try 'gzip').
239 | """
240 |
241 | full_name_stem = os.path.join(data_path, data_fname_stem)
242 | for s, data_dict in dataset_dict.items():
243 | write_data(full_name_stem + "_" + s, data_dict, compression=compression)
244 |
245 |
246 | def read_datasets(data_path, data_fname_stem):
247 | """Read dataset sin HD5F format.
248 |
249 | This function assumes the dataset_dict is a mapping ( string ->
250 | to data_dict ). It calls write_data for each data dictionary,
251 | post-fixing the data filename with the key of the dataset.
252 |
253 | Args:
254 | data_path: The path to the save directory.
255 | data_fname_stem: The filename stem of the file in which to write the data.
256 | """
257 |
258 | dataset_dict = {}
259 | fnames = os.listdir(data_path)
260 |
261 | print ('loading data from ' + data_path + ' with stem ' + data_fname_stem)
262 | for fname in fnames:
263 | if fname.startswith(data_fname_stem):
264 | data_dict = read_data(os.path.join(data_path,fname))
265 | idx = len(data_fname_stem) + 1
266 | key = fname[idx:]
267 | data_dict['data_dim'] = data_dict['train_data'].shape[2]
268 | data_dict['num_steps'] = data_dict['train_data'].shape[1]
269 | dataset_dict[key] = data_dict
270 |
271 | if len(dataset_dict) == 0:
272 | raise ValueError("Failed to load any datasets, are you sure that the "
273 | "'--data_dir' and '--data_filename_stem' flag values "
274 | "are correct?")
275 |
276 | print (str(len(dataset_dict)) + ' datasets loaded')
277 | return dataset_dict
278 |
279 |
280 | # NUMPY utility functions
281 | def list_t_bxn_to_list_b_txn(values_t_bxn):
282 | """Convert a length T list of BxN numpy tensors of length B list of TxN numpy
283 | tensors.
284 |
285 | Args:
286 | values_t_bxn: The length T list of BxN numpy tensors.
287 |
288 | Returns:
289 | The length B list of TxN numpy tensors.
290 | """
291 | T = len(values_t_bxn)
292 | B, N = values_t_bxn[0].shape
293 | values_b_txn = []
294 | for b in range(B):
295 | values_pb_txn = np.zeros([T,N])
296 | for t in range(T):
297 | values_pb_txn[t,:] = values_t_bxn[t][b,:]
298 | values_b_txn.append(values_pb_txn)
299 |
300 | return values_b_txn
301 |
302 |
303 | def list_t_bxn_to_tensor_bxtxn(values_t_bxn):
304 | """Convert a length T list of BxN numpy tensors to single numpy tensor with
305 | shape BxTxN.
306 |
307 | Args:
308 | values_t_bxn: The length T list of BxN numpy tensors.
309 |
310 | Returns:
311 | values_bxtxn: The BxTxN numpy tensor.
312 | """
313 |
314 | T = len(values_t_bxn)
315 | B, N = values_t_bxn[0].shape
316 | values_bxtxn = np.zeros([B,T,N])
317 | for t in range(T):
318 | values_bxtxn[:,t,:] = values_t_bxn[t]
319 |
320 | return values_bxtxn
321 |
322 |
323 | def tensor_bxtxn_to_list_t_bxn(tensor_bxtxn):
324 | """Convert a numpy tensor with shape BxTxN to a length T list of numpy tensors
325 |   with shape BxN.
326 |
327 | Args:
328 | tensor_bxtxn: The BxTxN numpy tensor.
329 |
330 | Returns:
331 |     A length T list of numpy tensors with shape BxN.
332 | """
333 |
334 | values_t_bxn = []
335 | B, T, N = tensor_bxtxn.shape
336 | for t in range(T):
337 | values_t_bxn.append(np.squeeze(tensor_bxtxn[:,t,:]))
338 |
339 | return values_t_bxn
340 |
341 |
342 | def flatten(list_of_lists):
343 | """Takes a list of lists and returns a list of the elements.
344 |
345 | Args:
346 | list_of_lists: List of lists.
347 |
348 | Returns:
349 | flat_list: Flattened list.
350 | flat_list_idxs: Flattened list indices.
351 | """
352 | flat_list = []
353 | flat_list_idxs = []
354 | start_idx = 0
355 | for item in list_of_lists:
356 | if isinstance(item, list):
357 | flat_list += item
358 | l = len(item)
359 | idxs = range(start_idx, start_idx+l)
360 | start_idx = start_idx+l
361 | else: # a value
362 | flat_list.append(item)
363 | idxs = [start_idx]
364 | start_idx += 1
365 | flat_list_idxs.append(idxs)
366 |
367 | return flat_list, flat_list_idxs
368 |
--------------------------------------------------------------------------------
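
A minimal, self-contained check of the reshaping helpers above (assuming this module is importable as utils); it exercises list_t_bxn_to_tensor_bxtxn and tensor_bxtxn_to_list_t_bxn and the BxN per-timestep shape documented in their docstrings.

import numpy as np
from utils import list_t_bxn_to_tensor_bxtxn, tensor_bxtxn_to_list_t_bxn

T, B, N = 5, 3, 4
values_t_bxn = [np.random.randn(B, N) for _ in range(T)]

tensor_bxtxn = list_t_bxn_to_tensor_bxtxn(values_t_bxn)
assert tensor_bxtxn.shape == (B, T, N)

values_back = tensor_bxtxn_to_list_t_bxn(tensor_bxtxn)
assert len(values_back) == T
assert np.allclose(values_back[0], values_t_bxn[0])  # each entry is BxN again
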
/data/lfads_lorenz.h5:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:d0766fc6e7d1e59be09827fcb839827bcd6904f91f1056975672b4e4b14266e0
3 | size 36207696
4 |
--------------------------------------------------------------------------------
/data/lorenz.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Joel Ye
3 |
4 | r""" Notebook for generating Lorenz attractor data.
5 | Note:
6 | This is an example reference,
7 | and does not reproduce the data with repeated conditions used in the paper.
8 | That was generated with LFADS utilities, which we are unable to release at this time.
9 | """
10 |
11 | #%%
12 | import os
13 | import os.path as osp
14 | import numpy as np
15 | import matplotlib.pyplot as plt
16 | # This import registers the 3D projection, but is otherwise unused.
17 | from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import
18 | import torch
19 | import torch.nn as nn
20 | from torch.distributions import Poisson
21 | torch.set_grad_enabled(False)
22 |
23 | default_s = 10 # sigma
24 | default_r = 28 # rho
25 | default_b = 2.667 # beta
26 |
27 | def lorenz_deriv(x, y, z, s=default_s, r=default_r, b=default_b):
28 | '''
29 | Given:
30 | x, y, z: a point of interest in three dimensional space
31 | s, r, b: parameters defining the lorenz attractor
32 | Returns:
33 |         x_dot, y_dot, z_dot: values of the lorenz system's time
34 |             derivatives at the point x, y, z
35 | '''
36 | x_dot = s*(y - x)
37 | y_dot = r*x - y - x*z
38 | z_dot = x*y - b*z
39 | return x_dot, y_dot, z_dot
40 |
41 | def lorenz_generator(
42 | initial_values,
43 | s=default_s, r=default_r, b=default_b,
44 | dt=0.01, num_steps=10000
45 | ):
46 |     TRANSIENT_TIME = 50 # initial steps to discard so trajectories settle onto the attractor
47 | values = np.empty((3, num_steps + TRANSIENT_TIME + 1))
48 | values[:, 0] = np.array(initial_values)
49 | # Step through "time", calculating the partial derivatives at the current point
50 | # and using them to estimate the next point
51 | for i in range(num_steps + TRANSIENT_TIME):
52 | x_dot, y_dot, z_dot = lorenz_deriv(
53 | values[0, i],
54 | values[1, i],
55 | values[2, i],
56 | s, r, b
57 | )
58 | values[0, i + 1] = values[0, i] + (x_dot * dt)
59 | values[1, i + 1] = values[1, i] + (y_dot * dt)
60 | values[2, i + 1] = values[2, i] + (z_dot * dt)
61 | return values.T[TRANSIENT_TIME:]
62 |
63 |
64 | #%%
65 | initial_values = (0., 1., 1.05)
66 | lorenz_data = lorenz_generator(initial_values, num_steps=5000).T
67 | # Plot
68 | fig = plt.figure()
69 | ax = fig.gca(projection='3d')
70 |
71 | ax.plot(*lorenz_data, lw=0.5)
72 | ax.set_xlabel("X Axis")
73 | ax.set_ylabel("Y Axis")
74 | ax.set_zlabel("Z Axis")
75 | ax.set_title("Lorenz Attractor")
76 |
77 | plt.show()
78 |
79 | #%%
80 |
81 | # Data HPs
82 | seed = 0
83 | np.random.seed(seed)
84 | n = 2000
85 | t = 200
86 | initial_values = np.random.random((n, 3))
87 | # initial_values = np.array([0., 1., 1.05])[None, :].repeat(n, axis=0)
88 | # print(initial_values[0])
89 | lorenz_params = np.array([default_s, default_r, default_b]) \
90 | + np.random.random((n, 3)) * 10 # Don't deviate too far from chaotic params
91 |
92 | # Generate data
93 | lorenz_dataset = np.empty((n, t, 3))
94 | for i in range(n):
95 | s, r, b = lorenz_params[i]
96 | lorenz_dataset[i] = lorenz_generator(
97 | initial_values[i], s=s, r=r, b=b,
98 | num_steps=t # Other values tend to be bad
99 | )[1:] # Drops initial value
100 |
101 | noise = np.random.random((n, t, 3)) * 0 # No noise
102 |
103 | trial_min = lorenz_dataset.min(axis=1)
104 | trial_max = lorenz_dataset.max(axis=1)
105 | lorenz_dataset -= trial_min[:, None, :]
106 | lorenz_dataset /= trial_max[:, None, :]
107 | lorenz_dataset += noise
108 | #%%
109 | torch.manual_seed(seed)
110 | # Convert factors to spike times
111 | n_neurons = 96
112 | SPARSE = False
113 | # Note: base and dynamic_range are hand-tuned so the firing-rate distribution looks reasonable
114 | def lorenz_rates_to_rates_and_spike_times(dataset, base=-10, dynamic_range=(-1, 0)):
115 | # lorenz b x t x 3 -> b x t x h (h neurons)
116 | linear_transform = nn.Linear(3, n_neurons) # Generates a random transform
117 | linear_transform.weight.data.uniform_(*dynamic_range)
118 | # base = torch.log(torch.tensor(base) / 1000.0)
119 | linear_transform.bias.data.fill_(base)
120 | raw_logrates = linear_transform(torch.tensor(dataset).float())
121 | if SPARSE:
122 | # raw_logrates /= raw_logrates.max() # make max firing lograte 0.5 ~ 1.65 firing rate
123 | clipped_logrates = torch.clamp(raw_logrates, -10, 1) # Double take
124 | else:
125 | clipped_logrates = torch.clamp(raw_logrates, -20, 20) # Double take
126 | # Leads to spikes mostly in 1-5
127 | neuron_rates = torch.exp(clipped_logrates)
128 | return Poisson(neuron_rates).sample(), clipped_logrates
129 | # default mean firing rate is 5
130 | spike_timings, logrates = lorenz_rates_to_rates_and_spike_times(lorenz_dataset)
131 | print(torch.exp(logrates).mean())
132 | plt.hist(torch.exp(logrates.flatten()).numpy())
133 | if SPARSE:
134 | plt.yscale("log")
135 | plt.hist(spike_timings.flatten().numpy(), bins=4)
136 | #%%
137 | # 2000
138 | train_n = 160
139 | val_n = int((n - train_n) / 2)
140 | test_n = int((n - train_n) / 2)
141 | splits = (train_n, val_n, test_n)
142 | assert sum(splits) <= n, "Not enough data trials"
143 |
144 | train_rates, val_rates, test_rates = torch.split(logrates, splits)
145 | train_spikes, val_spikes, test_spikes = \
146 | torch.split(spike_timings, splits)
147 |
148 |
149 | def pack_dataset(rates, spikes):
150 | return {
151 | "rates": rates,
152 | "spikes": spikes.long()
153 | }
154 |
155 | data_dir = "/snel/share/data/pth_lorenz_large"
156 | os.makedirs(data_dir, exist_ok=True)
157 |
158 | torch.save(pack_dataset(train_rates, train_spikes), osp.join(data_dir, "train.pth"))
159 | torch.save(pack_dataset(val_rates, val_spikes), osp.join(data_dir, "val.pth"))
160 | torch.save(pack_dataset(test_rates, test_spikes), osp.join(data_dir, "test.pth"))
161 |
162 | #%%
163 | # Plot trajectories
164 | np.random.seed(0)
165 | fig = plt.figure()
166 | NUM_PLOTS = 6
167 | for i in range(NUM_PLOTS):
168 | sample_idx = np.random.randint(0, train_n)
169 | sample = lorenz_dataset[sample_idx].T
170 | ax = fig.add_subplot(231 + i, projection='3d')
171 | color = np.linspace(0, 1, sample.shape[1])
172 | ax.scatter(*sample, lw=0.5, c=color)
173 |
174 | ax.set_title(f"{sample_idx}")
175 | fig.tight_layout()
176 |
177 | plt.show()
178 |
179 | #%%
180 | # Plot rates
181 | np.random.seed(0)
182 | fig = plt.figure()
183 | NUM_PLOTS = 8
184 | print(f"Train \t {train_rates[0].max().item()}")
185 | print(f"Val \t {val_rates[0].max().item()}")
186 | plot_t = 100
187 | for i in range(NUM_PLOTS):
188 | sample_idx = np.random.randint(0, n_neurons)
189 | sample = torch.exp(train_rates[0, :plot_t, sample_idx]).numpy()
190 | ax = fig.add_subplot(421 + i)
191 | ax.plot(np.arange(plot_t), sample, )
192 | # ax.scatter(np.arange(t), sample, c=np.arange(0, 1, 1.0/t))
193 |
194 | ax.set_title(f"{sample_idx}")
195 | plt.suptitle("Rates")
196 | fig.tight_layout()
197 | plt.show()
198 |
199 | #%%
200 | # Plot spike times
201 | np.random.seed(0)
202 | fig = plt.figure()
203 | NUM_PLOTS = 8
204 | print(f"Train \t {train_spikes[0].max().item()}")
205 | print(f"Val \t {val_spikes[0].max().item()}")
206 | plot_t = 50
207 | for i in range(NUM_PLOTS):
208 | sample_idx = np.random.randint(0, n_neurons)
209 | plot_data = train_spikes[0, :plot_t, sample_idx]
210 | sample = plot_data.numpy()
211 | ax = fig.add_subplot(421 + i)
212 | ax.plot(np.arange(plot_t), sample, )
213 | # ax.scatter(np.arange(t), sample, c=np.arange(0, 1, 1.0/t))
214 |
215 | ax.set_title(f"{sample_idx}")
216 | plt.suptitle("Spikes")
217 | fig.tight_layout()
218 | plt.show()
219 |
220 |
221 | #%%
222 | lfads_data = "/snel/share/data/lfads_lorenz_20ms"
223 | import h5py
224 | h5file = f"{lfads_data}/lfads_dataset001.h5"
225 | with h5py.File(h5file, "r") as f:
226 | h5dict = {key: f[key][()] for key in f.keys()}
227 | if 'train_truth' in h5dict:
228 | cf = h5dict['conversion_factor']
229 | train_truth = h5dict['train_truth'].astype(np.float32) / cf
230 | data = torch.tensor(train_truth)
231 | n_neurons = data.shape[-1]
232 | plt.hist(data.flatten().numpy(), bins=50)
--------------------------------------------------------------------------------
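
For reference, the splits written above can be reloaded as sketched below. The data_dir simply mirrors the hard-coded path above, and note that the "rates" entry stores the clipped log-rates (not exponentiated rates), matching how pack_dataset is called.

import os.path as osp
import torch

data_dir = "/snel/share/data/pth_lorenz_large"  # same hard-coded path as above
train = torch.load(osp.join(data_dir, "train.pth"))
spikes, logrates = train["spikes"], train["rates"]  # "rates" holds clipped log-rates
print(spikes.shape, spikes.dtype)         # expected (train_n, t, n_neurons), torch.int64
print(torch.exp(logrates).mean().item())  # mean firing rate per bin
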
/defaults.py:
--------------------------------------------------------------------------------
1 | # Src: Andrew's tune_tf2
2 |
3 | from os import path
4 |
5 | repo_path = path.dirname(path.realpath(__file__))
6 | DEFAULT_CONFIG_DIR = path.join(repo_path, 'configs')
7 |
8 | # contains data about general status of PBT optimization
9 | PBT_CSV = 'pbt_state.csv'
10 | # contains data about which models are exploited
11 | EXPLOIT_CSV = 'exploits.csv'
12 | # contains data about which hyperparameters are used
13 | HPS_CSV = 'gen_config.csv'
14 |
--------------------------------------------------------------------------------
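
These constants are consumed by the search scripts further below; for example, ray_random.py resolves a bare config filename against DEFAULT_CONFIG_DIR roughly as in this sketch (the config name here is just an example from configs/).

from os import path
from defaults import DEFAULT_CONFIG_DIR

exp_config = "mc_maze.yaml"  # any file under configs/
cfg_path = exp_config if path.split(exp_config)[0] else path.join(DEFAULT_CONFIG_DIR, exp_config)
print(cfg_path)
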
/environment.yml:
--------------------------------------------------------------------------------
1 | name: ndt_old
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - anaconda
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - attrs=19.3.0=py_0
10 | - backcall=0.1.0=py36_0
11 | - basemap=1.2.0=py36h705c2d8_0
12 | - blas=1.0=mkl
13 | - bleach=3.1.4=py_0
14 | - ca-certificates=2020.10.14=0
15 | - certifi=2020.6.20=py36_0
16 | - cudatoolkit=10.1.243=h6bb024c_0
17 | - cycler=0.10.0=py36_0
18 | - dbus=1.13.12=h746ee38_0
19 | - decorator=4.4.2=py_0
20 | - defusedxml=0.6.0=py_0
21 | - entrypoints=0.3=py36_0
22 | - expat=2.2.6=he6710b0_0
23 | - fontconfig=2.13.0=h9420a91_0
24 | - freetype=2.9.1=h8a8886c_1
25 | - geos=3.6.2=heeff764_2
26 | - git-lfs=2.6.1=h7b6447c_0
27 | - glib=2.56.2=hd408876_0
28 | - gmp=6.1.2=h6c8ec71_1
29 | - gst-plugins-base=1.14.0=hbbd80ab_1
30 | - gstreamer=1.14.0=hb453b48_1
31 | - icu=58.2=h211956c_0
32 | - importlib_metadata=1.5.0=py36_0
33 | - intel-openmp=2020.0=166
34 | - ipykernel=5.1.4=py36h39e3cac_0
35 | - ipython=7.13.0=py36h5ca1d4c_0
36 | - ipython_genutils=0.2.0=py36_0
37 | - ipywidgets=7.5.1=py_0
38 | - jedi=0.17.0=py36_0
39 | - jinja2=2.11.2=py_0
40 | - joblib=0.14.1=py_0
41 | - jpeg=9b=h024ee3a_2
42 | - jsonschema=3.2.0=py36_0
43 | - jupyter=1.0.0=py36_7
44 | - jupyter_client=6.1.3=py_0
45 | - jupyter_console=6.1.0=py_0
46 | - jupyter_core=4.6.3=py36_0
47 | - kiwisolver=1.1.0=py36he6710b0_0
48 | - ld_impl_linux-64=2.33.1=h53a641e_7
49 | - libedit=3.1.20181209=hc058e9b_0
50 | - libffi=3.2.1=hd88cf55_4
51 | - libgcc-ng=9.1.0=hdf63c60_0
52 | - libgfortran-ng=7.3.0=hdf63c60_0
53 | - libpng=1.6.37=hbc83047_0
54 | - libsodium=1.0.16=h1bed415_0
55 | - libstdcxx-ng=9.1.0=hdf63c60_0
56 | - libtiff=4.1.0=h2733197_0
57 | - libuuid=1.0.3=h1bed415_2
58 | - libxcb=1.13=h1bed415_1
59 | - libxml2=2.9.9=hea5a465_1
60 | - markupsafe=1.1.1=py36h7b6447c_0
61 | - matplotlib=3.1.3=py36_0
62 | - matplotlib-base=3.1.3=py36hef1b27d_0
63 | - mistune=0.8.4=py36h7b6447c_0
64 | - mkl=2020.0=166
65 | - mkl-service=2.3.0=py36he904b0f_0
66 | - mkl_fft=1.0.15=py36ha843d7b_0
67 | - mkl_random=1.1.0=py36hd6b4f25_0
68 | - nbconvert=5.6.1=py36_0
69 | - nbformat=5.0.4=py_0
70 | - ncurses=6.2=he6710b0_1
71 | - ninja=1.9.0=py36hfd86e86_0
72 | - notebook=6.0.3=py36_0
73 | - numpy=1.18.1=py36h4f9e942_0
74 | - numpy-base=1.18.1=py36hde5b4d6_1
75 | - olefile=0.46=py36_0
76 | - openssl=1.1.1h=h7b6447c_0
77 | - pandas=1.0.4=py36h0573a6f_0
78 | - pandoc=2.2.3.2=0
79 | - pandocfilters=1.4.2=py36_1
80 | - parso=0.7.0=py_0
81 | - pcre=8.43=he6710b0_0
82 | - pexpect=4.8.0=py36_0
83 | - pickleshare=0.7.5=py36_0
84 | - pillow=7.1.2=py36hb39fc2d_0
85 | - pip=20.0.2=py36_1
86 | - proj=6.2.1=haa6030c_0
87 | - prometheus_client=0.7.1=py_0
88 | - prompt-toolkit=3.0.4=py_0
89 | - prompt_toolkit=3.0.4=0
90 | - ptyprocess=0.6.0=py36_0
91 | - pygments=2.6.1=py_0
92 | - pyparsing=2.4.6=py_0
93 | - pyproj=2.6.0=py36hd003209_1
94 | - pyqt=5.9.2=py36h22d08a2_1
95 | - pyrsistent=0.16.0=py36h7b6447c_0
96 | - pyshp=2.1.0=py_0
97 | - python=3.6.10=hcf32534_1
98 | - python-dateutil=2.8.1=py_0
99 | - python_abi=3.6=1_cp36m
100 | - pytorch=1.5.0=py3.6_cuda10.1.243_cudnn7.6.3_0
101 | - pytz=2020.1=py_0
102 | - pyzmq=18.1.1=py36he6710b0_0
103 | - qt=5.9.7=h5867ecd_1
104 | - qtconsole=4.7.3=py_0
105 | - qtpy=1.9.0=py_0
106 | - readline=8.0=h7b6447c_0
107 | - scikit-learn=0.22.1=py36hd81dba3_0
108 | - scipy=1.4.1=py36h0b6359f_0
109 | - seaborn=0.10.1=py_0
110 | - send2trash=1.5.0=py36_0
111 | - setuptools=46.1.3=py36_0
112 | - sip=4.19.13=py36he6710b0_0
113 | - six=1.14.0=py36_0
114 | - sqlite=3.31.1=h62c20be_1
115 | - terminado=0.8.3=py36_0
116 | - testpath=0.4.4=py_0
117 | - tk=8.6.8=hbc83047_0
118 | - torchvision=0.6.0=py36_cu101
119 | - tornado=6.0.4=py36h7b6447c_1
120 | - traitlets=4.3.3=py36_0
121 | - wcwidth=0.1.9=py_0
122 | - webencodings=0.5.1=py36_1
123 | - wheel=0.34.2=py36_0
124 | - widgetsnbextension=3.5.1=py36_0
125 | - xz=5.2.5=h7b6447c_0
126 | - yaml=0.1.7=h96e3832_1
127 | - zeromq=4.3.1=he6710b0_3
128 | - zipp=3.1.0=py_0
129 | - zlib=1.2.11=h7b6447c_3
130 | - zstd=1.3.7=h0b5b093_0
131 | - pip:
132 | - absl-py==0.9.0
133 | - aiohttp==3.6.2
134 | - aiohttp-cors==0.7.0
135 | - aioredis==1.3.1
136 | - astor==0.8.1
137 | - async-timeout==3.0.1
138 | - beautifulsoup4==4.9.2
139 | - blessings==1.7
140 | - boto3==1.13.26
141 | - botocore==1.16.26
142 | - cachetools==4.1.0
143 | - chardet==3.0.4
144 | - click==7.1.2
145 | - colorama==0.4.3
146 | - colorful==0.5.4
147 | - contextvars==2.4
148 | - dataclasses==0.7
149 | - docutils==0.15.2
150 | - filelock==3.0.12
151 | - gast==0.2.2
152 | - google==3.0.0
153 | - google-api-core==1.22.2
154 | - google-auth==1.22.0
155 | - google-auth-oauthlib==0.4.1
156 | - google-pasta==0.2.0
157 | - googleapis-common-protos==1.52.0
158 | - gpustat==0.6.0
159 | - grpcio==1.28.1
160 | - h5py==2.10.0
161 | - hiredis==1.1.0
162 | - idna==2.9
163 | - idna-ssl==1.1.0
164 | - immutables==0.14
165 | - jmespath==0.10.0
166 | - keras-applications==1.0.8
167 | - keras-preprocessing==1.1.0
168 | - markdown==3.2.1
169 | - msgpack==1.0.0
170 | - multidict==4.7.6
171 | - nvidia-ml-py3==7.352.0
172 | - oauthlib==3.1.0
173 | - opencensus==0.7.10
174 | - opencensus-context==0.1.1
175 | - opt-einsum==3.2.1
176 | - protobuf==3.11.3
177 | - psutil==5.7.2
178 | - py-spy==0.3.3
179 | - pyasn1==0.4.8
180 | - pyasn1-modules==0.2.8
181 | - pytorch-transformers==1.2.0
182 | - pyyaml==5.3.1
183 | - ray==0.8.7
184 | - redis==3.4.1
185 | - regex==2020.6.8
186 | - requests==2.23.0
187 | - requests-oauthlib==1.3.0
188 | - rsa==4.0
189 | - s3transfer==0.3.3
190 | - sacremoses==0.0.43
191 | - sentencepiece==0.1.91
192 | - soupsieve==2.0.1
193 | - tabulate==0.8.7
194 | - tensorboard==2.2.0
195 | - tensorboard-plugin-wit==1.6.0.post3
196 | - tensorboardx==2.1
197 | - tensorflow==2.1.0
198 | - tensorflow-estimator==2.1.0
199 | - termcolor==1.1.0
200 | - tqdm==4.46.1
201 | - typing-extensions==3.7.4.3
202 | - urllib3==1.25.9
203 | - werkzeug==1.0.1
204 | - wrapt==1.12.1
205 | - yacs==0.1.7
206 | - yarl==1.6.0
207 | prefix: /snel/home/joely/anaconda3/envs/pytorch
208 |
209 |
--------------------------------------------------------------------------------
/nlb.yml:
--------------------------------------------------------------------------------
1 | name: ndt
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=4.5=1_gnu
8 | - blas=1.0=mkl
9 | - bottleneck=1.3.2=py37heb32a55_1
10 | - brotli=1.0.9=he6710b0_2
11 | - bzip2=1.0.8=h7b6447c_0
12 | - ca-certificates=2021.10.26=h06a4308_2
13 | - certifi=2021.10.8=py37h06a4308_2
14 | - cudatoolkit=10.2.89=hfd86e86_1
15 | - cycler=0.11.0=pyhd3eb1b0_0
16 | - dbus=1.13.18=hb2f20db_0
17 | - expat=2.4.4=h295c915_0
18 | - ffmpeg=4.3=hf484d3e_0
19 | - fontconfig=2.13.1=h6c09931_0
20 | - fonttools=4.25.0=pyhd3eb1b0_0
21 | - freetype=2.11.0=h70c0345_0
22 | - giflib=5.2.1=h7b6447c_0
23 | - glib=2.69.1=h4ff587b_1
24 | - gmp=6.2.1=h2531618_2
25 | - gnutls=3.6.15=he1e5248_0
26 | - gst-plugins-base=1.14.0=h8213a91_2
27 | - gstreamer=1.14.0=h28cd5cc_2
28 | - icu=58.2=he6710b0_3
29 | - intel-openmp=2021.4.0=h06a4308_3561
30 | - jpeg=9d=h7f8727e_0
31 | - kiwisolver=1.3.2=py37h295c915_0
32 | - lame=3.100=h7b6447c_0
33 | - lcms2=2.12=h3be6417_0
34 | - ld_impl_linux-64=2.35.1=h7274673_9
35 | - libffi=3.3=he6710b0_2
36 | - libgcc-ng=9.3.0=h5101ec6_17
37 | - libgfortran-ng=7.5.0=ha8ba4b0_17
38 | - libgfortran4=7.5.0=ha8ba4b0_17
39 | - libgomp=9.3.0=h5101ec6_17
40 | - libiconv=1.15=h63c8f33_5
41 | - libidn2=2.3.2=h7f8727e_0
42 | - libpng=1.6.37=hbc83047_0
43 | - libstdcxx-ng=9.3.0=hd4cf53a_17
44 | - libtasn1=4.16.0=h27cfd23_0
45 | - libtiff=4.2.0=h85742a9_0
46 | - libunistring=0.9.10=h27cfd23_0
47 | - libuuid=1.0.3=h7f8727e_2
48 | - libuv=1.40.0=h7b6447c_0
49 | - libwebp=1.2.0=h89dd481_0
50 | - libwebp-base=1.2.0=h27cfd23_0
51 | - libxcb=1.14=h7b6447c_0
52 | - libxml2=2.9.12=h03d6c58_0
53 | - lz4-c=1.9.3=h295c915_1
54 | - matplotlib=3.5.1=py37h06a4308_0
55 | - matplotlib-base=3.5.1=py37ha18d171_0
56 | - mkl=2021.4.0=h06a4308_640
57 | - mkl-service=2.4.0=py37h7f8727e_0
58 | - mkl_fft=1.3.1=py37hd3c417c_0
59 | - mkl_random=1.2.2=py37h51133e4_0
60 | - munkres=1.1.4=py_0
61 | - ncurses=6.3=h7f8727e_2
62 | - nettle=3.7.3=hbbd107a_1
63 | - numexpr=2.8.1=py37h6abb31d_0
64 | - numpy=1.21.2=py37h20f2e39_0
65 | - numpy-base=1.21.2=py37h79a1101_0
66 | - olefile=0.46=py37_0
67 | - openh264=2.1.1=h4ff587b_0
68 | - openssl=1.1.1m=h7f8727e_0
69 | - packaging=21.3=pyhd3eb1b0_0
70 | - pcre=8.45=h295c915_0
71 | - pillow=8.4.0=py37h5aabda8_0
72 | - pip=21.2.2=py37h06a4308_0
73 | - pyqt=5.9.2=py37h05f1152_2
74 | - python=3.7.11=h12debd9_0
75 | - python-dateutil=2.8.2=pyhd3eb1b0_0
76 | - pytorch=1.10.2=py3.7_cuda10.2_cudnn7.6.5_0
77 | - pytorch-mutex=1.0=cuda
78 | - qt=5.9.7=h5867ecd_1
79 | - readline=8.1.2=h7f8727e_1
80 | - scipy=1.7.3=py37hc147768_0
81 | - seaborn=0.11.2=pyhd3eb1b0_0
82 | - setuptools=58.0.4=py37h06a4308_0
83 | - sip=4.19.8=py37hf484d3e_0
84 | - six=1.16.0=pyhd3eb1b0_1
85 | - sqlite=3.37.2=hc218d9a_0
86 | - tk=8.6.11=h1ccaba5_0
87 | - torchaudio=0.10.2=py37_cu102
88 | - torchvision=0.11.3=py37_cu102
89 | - tornado=6.1=py37h27cfd23_0
90 | - typing_extensions=3.10.0.2=pyh06a4308_0
91 | - wheel=0.37.1=pyhd3eb1b0_0
92 | - xz=5.2.5=h7b6447c_0
93 | - zlib=1.2.11=h7f8727e_4
94 | - zstd=1.4.9=haebb681_0
95 | - pip:
96 | - absl-py==1.0.0
97 | - appdirs==1.4.4
98 | - asciitree==0.3.3
99 | - attrs==21.4.0
100 | - blessings==1.7
101 | - boto3==1.21.4
102 | - botocore==1.24.5
103 | - cached-property==1.5.2
104 | - cachetools==5.0.0
105 | - cffi==1.15.0
106 | - charset-normalizer==2.0.12
107 | - ci-info==0.2.0
108 | - click==8.0.4
109 | - click-didyoumean==0.3.0
110 | - cryptography==36.0.1
111 | - dandi==0.36.0
112 | - dandischema==0.5.2
113 | - deprecated==1.2.13
114 | - dnspython==2.2.0
115 | - email-validator==1.1.3
116 | - etelemetry==0.3.0
117 | - fasteners==0.17.3
118 | - filelock==3.6.0
119 | - fscacher==0.2.0
120 | - google-auth==2.6.0
121 | - google-auth-oauthlib==0.4.6
122 | - grpcio==1.44.0
123 | - h5py==3.6.0
124 | - hdmf==3.1.1
125 | - humanize==4.0.0
126 | - idna==3.3
127 | - importlib-metadata==4.11.1
128 | - interleave==0.2.0
129 | - jeepney==0.7.1
130 | - jmespath==0.10.0
131 | - joblib==1.1.0
132 | - jsonpointer==2.2
133 | - jsonschema==3.2.0
134 | - keyring==23.5.0
135 | - keyrings-alt==4.1.0
136 | - markdown==3.3.6
137 | - msgpack==1.0.3
138 | - numcodecs==0.9.1
139 | - oauthlib==3.2.0
140 | - pandas==1.3.5
141 | - protobuf==3.19.4
142 | - pyasn1==0.4.8
143 | - pyasn1-modules==0.2.8
144 | - pycparser==2.21
145 | - pycryptodomex==3.14.1
146 | - pydantic==1.9.0
147 | - pynwb==2.0.0
148 | - pyout==0.7.2
149 | - pyparsing==3.0.7
150 | - pyrsistent==0.18.1
151 | - pytorch-transformers==1.2.0
152 | - pytz==2021.3
153 | - pyyaml==6.0
154 | - ray==1.10.0
155 | - redis==4.1.4
156 | - regex==2022.1.18
157 | - requests==2.27.1
158 | - requests-oauthlib==1.3.1
159 | - rfc3987==1.3.8
160 | - rsa==4.8
161 | - ruamel-yaml==0.17.21
162 | - ruamel-yaml-clib==0.2.6
163 | - s3transfer==0.5.1
164 | - sacremoses==0.0.47
165 | - scikit-learn==1.0.2
166 | - secretstorage==3.3.1
167 | - semantic-version==2.9.0
168 | - sentencepiece==0.1.96
169 | - sklearn==0.0
170 | - strict-rfc3339==0.7
171 | - tabulate==0.8.9
172 | - tenacity==8.0.1
173 | - tensorboard==2.8.0
174 | - tensorboard-data-server==0.6.1
175 | - tensorboard-plugin-wit==1.8.1
176 | - tensorboardx==2.5
177 | - threadpoolctl==3.1.0
178 | - tqdm==4.62.3
179 | - urllib3==1.26.8
180 | - webcolors==1.11.1
181 | - werkzeug==2.0.3
182 | - wrapt==1.13.3
183 | - yacs==0.1.8
184 | - zarr==2.11.0
185 | - zipp==3.7.0
186 | prefix: /home/joelye/anaconda3/envs/ndt
187 |
--------------------------------------------------------------------------------
/ray_get_lfve.py:
--------------------------------------------------------------------------------
1 | """
2 | Post-hoc script: copy the trial whose LFVE checkpoint has the lowest unmasked validation loss into a "best_model_unmasked" directory.
3 | """
4 |
5 | from typing import List, Union
6 | from os import path
7 | import json
8 | import argparse
9 | import ray, yaml, shutil
10 | from ray import tune
11 | import torch
12 |
13 | from tune_models import tuneNDT
14 |
15 | from defaults import DEFAULT_CONFIG_DIR
16 | from src.config.default import flatten
17 |
18 | PBT_HOME = path.expanduser('~/ray_results/ndt/gridsearch')
19 | OVERWRITE = True
20 | PBT_METRIC = 'smth_masked_loss'
21 | BEST_MODEL_METRIC = 'best_masked_loss'
22 | LOGGED_COLUMNS = ['smth_masked_loss', 'masked_loss', 'r2', 'unmasked_loss']
23 |
24 | DEFAULT_HP_DICT = {
25 | 'TRAIN.WEIGHT_DECAY': tune.loguniform(1e-8, 1e-3),
26 | 'TRAIN.MASK_RATIO': tune.uniform(0.1, 0.4)
27 | }
28 |
29 | def get_parser():
30 | parser = argparse.ArgumentParser()
31 |
32 | parser.add_argument(
33 | "--exp-config", "-e",
34 | type=str,
35 | required=True,
36 | help="path to config yaml containing info about experiment",
37 | )
38 |
39 | parser.add_argument('--eval-only', '-ev', dest='eval_only', action='store_true')
40 | parser.add_argument('--no-eval-only', '-nev', dest='eval_only', action='store_false')
41 | parser.set_defaults(eval_only=False)
42 |
43 | parser.add_argument(
44 | "--name", "-n",
45 | type=str,
46 | default="",
47 | help="defaults to exp filename"
48 | )
49 |
50 | parser.add_argument(
51 | "--gpus-per-worker", "-g",
52 | type=float,
53 | default=0.5
54 | )
55 |
56 | parser.add_argument(
57 | "--cpus-per-worker", "-c",
58 | type=float,
59 | default=3.0
60 | )
61 |
62 | parser.add_argument(
63 | "--workers", "-w",
64 | type=int,
65 | default=-1,
66 | help="-1 indicates -- use max possible workers on machine (assuming 0.5 GPUs per trial)"
67 | )
68 |
69 | parser.add_argument(
70 | "--samples", "-s",
71 | type=int,
72 | default=20,
73 | help="samples for random search"
74 | )
75 |
76 | parser.add_argument(
77 | "--seed", "-d",
78 | type=int,
79 | default=-1,
80 | help="seed for config"
81 | )
82 |
83 | return parser
84 |
85 | def main():
86 | parser = get_parser()
87 | args = parser.parse_args()
88 | launch_search(**vars(args))
89 |
90 | def build_hp_dict(raw_json: dict):
91 | hp_dict = {}
92 | for key in raw_json:
93 | info: dict = raw_json[key]
94 | sample_fn = info.get("sample_fn", "uniform")
95 | assert hasattr(tune, sample_fn)
96 | if sample_fn == "choice":
97 | hp_dict[key] = tune.choice(info['opts'])
98 | else:
99 | assert "low" in info, "high" in info
100 | sample_fn = getattr(tune, sample_fn)
101 | hp_dict[key] = sample_fn(info['low'], info['high'])
102 | return hp_dict
103 |
104 | def launch_search(exp_config: Union[List[str], str], name: str, workers: int, gpus_per_worker: float, cpus_per_worker: float, eval_only: bool, samples: int, seed: int) -> None:
105 | # ---------- PBT I/O CONFIGURATION ----------
106 | # the directory to save PBT runs (usually '~/ray_results')
107 |
108 | if len(path.split(exp_config)[0]) > 0:
109 | CFG_PATH = exp_config
110 | else:
111 | CFG_PATH = path.join(DEFAULT_CONFIG_DIR, exp_config)
112 | variant_name = path.split(CFG_PATH)[1].split('.')[0]
113 | if seed > 0:
114 | variant_name = f"{variant_name}-s{seed}"
115 | if name == "":
116 | name = variant_name
117 | pbt_dir = path.join(PBT_HOME, name)
118 | # the name of this PBT run (run will be stored at `pbt_dir`)
119 |
120 | # ---------------------------------------------
121 | # * No train step
122 | # load the results dataframe for this run
123 | df = tune.Analysis(
124 | pbt_dir
125 | ).dataframe()
126 |     df = df[df.logdir.apply(lambda logdir: 'best_model' not in logdir)]
127 |
128 | lfves = []
129 | for logdir in df.logdir:
130 | ckpt = torch.load(path.join(logdir, f'ckpts/{variant_name}.lfve.pth'), map_location='cpu')
131 | lfves.append(ckpt['best_unmasked_val']['value'])
132 | df['best_unmasked_val'] = lfves
133 | best_model_logdir = df.loc[df['best_unmasked_val'].idxmin()].logdir
134 | best_model_dest = path.join(pbt_dir, 'best_model_unmasked')
135 | if path.exists(best_model_dest):
136 | shutil.rmtree(best_model_dest)
137 | shutil.copytree(best_model_logdir, best_model_dest)
138 |
139 | if __name__ == "__main__":
140 | main()
--------------------------------------------------------------------------------
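
A minimal sketch of inspecting a single trial's LFVE checkpoint by hand, assuming the layout used above (ckpts/<variant>.lfve.pth containing a 'best_unmasked_val' record); the path is hypothetical.

import torch

ckpt_path = "/path/to/trial/ckpts/lorenz.lfve.pth"  # hypothetical trial directory
ckpt = torch.load(ckpt_path, map_location="cpu")
print(ckpt["best_unmasked_val"]["value"])  # the metric minimized across trials above
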
/ray_random.py:
--------------------------------------------------------------------------------
1 | # Src: Andrew's run_random_search in tune_tf2
2 |
3 | """
4 | Run random hyperparameter search for NDT.
5 | """
6 |
7 | from typing import List, Union
8 | import os
9 | from os import path
10 | import json
11 | import argparse
12 | import ray, yaml, shutil
13 | from ray import tune
14 | import torch
15 |
16 | from tune_models import tuneNDT
17 |
18 | from defaults import DEFAULT_CONFIG_DIR
19 | from src.config.default import flatten
20 |
21 | PBT_HOME = path.expanduser('~/user_data/nlb/ndt_runs/ray/')
22 | os.makedirs(PBT_HOME, exist_ok=True) # ray hangs if this directory doesn't already exist...
23 | OVERWRITE = True
24 | PBT_METRIC = 'smth_masked_loss'
25 | BEST_MODEL_METRIC = 'best_masked_loss'
26 | LOGGED_COLUMNS = ['smth_masked_loss', 'masked_loss', 'r2', 'unmasked_loss']
27 |
28 | DEFAULT_HP_DICT = {
29 | 'TRAIN.WEIGHT_DECAY': tune.loguniform(1e-8, 1e-3),
30 | 'TRAIN.MASK_RATIO': tune.uniform(0.1, 0.4)
31 | }
32 |
33 | def get_parser():
34 | parser = argparse.ArgumentParser()
35 |
36 | parser.add_argument(
37 | "--exp-config", "-e",
38 | type=str,
39 | required=True,
40 | help="path to config yaml containing info about experiment",
41 | )
42 |
43 | parser.add_argument('--eval-only', '-ev', dest='eval_only', action='store_true')
44 | parser.add_argument('--no-eval-only', '-nev', dest='eval_only', action='store_false')
45 | parser.set_defaults(eval_only=False)
46 |
47 | parser.add_argument(
48 | "--name", "-n",
49 | type=str,
50 | default="",
51 | help="defaults to exp filename"
52 | )
53 |
54 | parser.add_argument(
55 | "--gpus-per-worker", "-g",
56 | type=float,
57 | default=0.5
58 | )
59 |
60 | parser.add_argument(
61 | "--cpus-per-worker", "-c",
62 | type=float,
63 | default=3.0
64 | )
65 |
66 | parser.add_argument(
67 | "--workers", "-w",
68 | type=int,
69 | default=-1,
70 | help="-1 indicates -- use max possible workers on machine (assuming 0.5 GPUs per trial)"
71 | )
72 |
73 | parser.add_argument(
74 | "--samples", "-s",
75 | type=int,
76 | default=20,
77 | help="samples for random search"
78 | )
79 |
80 | parser.add_argument(
81 | "--seed", "-d",
82 | type=int,
83 | default=-1,
84 | help="seed for config"
85 | )
86 |
87 | return parser
88 |
89 | def main():
90 | parser = get_parser()
91 | args = parser.parse_args()
92 | launch_search(**vars(args))
93 |
94 | def build_hp_dict(raw_json: dict):
95 | hp_dict = {}
96 | for key in raw_json:
97 | info: dict = raw_json[key]
98 | sample_fn = info.get("sample_fn", "uniform")
99 | assert hasattr(tune, sample_fn)
100 | if sample_fn == "choice":
101 | hp_dict[key] = tune.choice(info['opts'])
102 | else:
103 | assert "low" in info, "high" in info
104 | sample_fn = getattr(tune, sample_fn)
105 | hp_dict[key] = sample_fn(info['low'], info['high'])
106 | return hp_dict
107 |
108 | def launch_search(exp_config: Union[List[str], str], name: str, workers: int, gpus_per_worker: float, cpus_per_worker: float, eval_only: bool, samples: int, seed: int) -> None:
109 | # ---------- PBT I/O CONFIGURATION ----------
110 | # the directory to save PBT runs (usually '~/ray_results')
111 |
112 | if len(path.split(exp_config)[0]) > 0:
113 | CFG_PATH = exp_config
114 | else:
115 | CFG_PATH = path.join(DEFAULT_CONFIG_DIR, exp_config)
116 | variant_name = path.split(CFG_PATH)[1].split('.')[0]
117 | # Ok, now update the paths in the config
118 | if seed > 0:
119 | variant_name = f"{variant_name}-s{seed}"
120 | if name == "":
121 | name = variant_name
122 |
123 | pbt_dir = path.join(PBT_HOME, name)
124 | # the name of this PBT run (run will be stored at `pbt_dir`)
125 |
126 | # ---------- PBT RUN CONFIGURATION ----------
127 | # whether to use single machine or cluster
128 | SINGLE_MACHINE = True # Cluster not supported atm, don't know how to use it.
129 |
130 | NUM_WORKERS = workers if workers > 0 else int(torch.cuda.device_count() // gpus_per_worker)
131 | # the resources to allocate per model
132 | RESOURCES_PER_TRIAL = {"cpu": cpus_per_worker, "gpu": gpus_per_worker}
133 |
134 | # ---------------------------------------------
135 |
136 | def train():
137 |         if path.exists(pbt_dir):
138 |             print("Run exists!!! Overwriting.")
139 |             if not OVERWRITE:
140 |                 print("overwriting disallowed, exiting..")
141 |                 exit(0)
142 |             else:
143 |                 # remove the stale run directory before re-launching
144 |                 shutil.rmtree(pbt_dir)
145 |
146 | # load the configuration as a dictionary and update for this run
147 | flat_cfg_dict = flatten(yaml.full_load(open(CFG_PATH)))
148 |
149 | # Default behavior is to pull experiment name from config file
150 | # Bind variant name to directories
151 | flat_cfg_dict.update({'VARIANT': variant_name})
152 | if seed > 0:
153 | flat_cfg_dict.update({'SEED': seed})
154 |
155 | # the hyperparameter space to search
156 | assert 'TRAIN.TUNE_HP_JSON' in flat_cfg_dict, "please specify hp sweep (no default)"
157 | with open(flat_cfg_dict['TRAIN.TUNE_HP_JSON']) as f:
158 | raw_hp_json = json.load(f)
159 | cfg_samples = DEFAULT_HP_DICT
160 | cfg_samples.update(build_hp_dict(raw_hp_json))
161 |
162 | flat_cfg_dict.update(cfg_samples)
163 |
164 | # connect to Ray cluster or start on single machine
165 | address = None if SINGLE_MACHINE else 'localhost:6379'
166 | ray.init(address=address)
167 |
168 | reporter = tune.CLIReporter(metric_columns=LOGGED_COLUMNS)
169 |
170 | analysis = tune.run(
171 | tuneNDT,
172 | name=name,
173 | local_dir=pbt_dir,
174 | stop={'done': True},
175 | config=flat_cfg_dict,
176 | resources_per_trial=RESOURCES_PER_TRIAL,
177 | num_samples=samples,
178 | # sync_to_driver='# {source} {target}', # prevents rsync
179 | verbose=1,
180 | progress_reporter=reporter,
181 | # loggers=(tune.logger.JsonLogger, tune.logger.CSVLogger)
182 | )
183 |
184 | if not eval_only:
185 | train()
186 | # load the results dataframe for this run
187 | df = tune.Analysis(
188 | pbt_dir
189 | ).dataframe()
190 |     df = df[df.logdir.apply(lambda logdir: 'best_model' not in logdir)]
191 |
192 | # Hm... we need to go through each model, and run the lfve ckpt.
193 | # And then record that in the dataframe?
194 |
195 |     if df[BEST_MODEL_METRIC].dtype == 'O': # Accidentally didn't cast to scalar, now we have a tensor string
196 | df = df.assign(best_masked_loss=lambda df: df[BEST_MODEL_METRIC].str[7:13].astype(float))
197 | best_model_logdir = df.loc[df[BEST_MODEL_METRIC].idxmin()].logdir
198 | # copy the best model somewhere easy to find
199 | # best_model_src = path.join(best_model_logdir, 'model_dir')
200 | best_model_dest = path.join(pbt_dir, 'best_model')
201 | if path.exists(best_model_dest):
202 | shutil.rmtree(best_model_dest)
203 | shutil.copytree(best_model_logdir, best_model_dest)
204 |
205 | if __name__ == "__main__":
206 | main()
--------------------------------------------------------------------------------
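
The sweep spec pointed to by TRAIN.TUNE_HP_JSON is a JSON file of the shape build_hp_dict expects. The sketch below reuses the two hyperparameters from DEFAULT_HP_DICT purely as an illustration; the real sweep files live in configs/sweep_*.json and may differ, and the import assumes the repo root is on the Python path.

from ray_random import build_hp_dict  # defined above

raw_hp_json = {
    "TRAIN.WEIGHT_DECAY": {"sample_fn": "loguniform", "low": 1e-8, "high": 1e-3},
    "TRAIN.MASK_RATIO":   {"sample_fn": "uniform",    "low": 0.1,  "high": 0.4},
    # "choice" entries would instead provide an "opts" list.
}
hp_dict = build_hp_dict(raw_hp_json)
print(hp_dict)  # values are ray.tune search spaces, merged into the flat config before tune.run
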
/scripts/analyze_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 | import sys
5 | module_path = os.path.abspath(os.path.join('..'))
6 | if module_path not in sys.path:
7 | sys.path.append(module_path)
8 |
9 | import time
10 | import h5py
11 | import matplotlib.pyplot as plt
12 | import seaborn as sns
13 | import numpy as np
14 | import torch
15 | import torch.nn.functional as f
16 | from torch.utils import data
17 |
18 | from src.run import prepare_config
19 | from src.runner import Runner
20 | from src.dataset import SpikesDataset, DATASET_MODES
21 | from src.mask import Masker, UNMASKED_LABEL
22 |
23 | def make_runner(variant, ckpt, base="", prefix=""):
24 | run_type = "eval"
25 | exp_config = osp.join("../configs", prefix, f"{variant}.yaml")
26 | if base != "":
27 | exp_config = [osp.join("../configs", f"{base}.yaml"), exp_config]
28 | ckpt_path = f"{variant}.{ckpt}.pth"
29 | config, ckpt_path = prepare_config(
30 | exp_config, run_type, ckpt_path, [
31 | "USE_TENSORBOARD", False,
32 | "SYSTEM.NUM_GPUS", 1,
33 | ], suffix=prefix
34 | )
35 | return Runner(config), ckpt_path
36 |
37 | def setup_dataset(runner, mode):
38 | test_set = SpikesDataset(runner.config, runner.config.DATA.VAL_FILENAME, mode=mode, logger=runner.logger)
39 | runner.logger.info(f"Evaluating on {len(test_set)} samples.")
40 | test_set.clip_spikes(runner.max_spikes)
41 | spikes, rates, heldout_spikes, forward_spikes = test_set.get_dataset()
42 | if heldout_spikes is not None:
43 | heldout_spikes = heldout_spikes.to(runner.device)
44 | if forward_spikes is not None:
45 | forward_spikes = forward_spikes.to(runner.device)
46 | return spikes.to(runner.device), rates.to(runner.device), heldout_spikes, forward_spikes
47 |
48 | def init_by_ckpt(ckpt_path, mode=DATASET_MODES.val):
49 | runner = Runner(checkpoint_path=ckpt_path)
50 | runner.model.eval()
51 | torch.set_grad_enabled(False)
52 | spikes, rates, heldout_spikes, forward_spikes = setup_dataset(runner, mode)
53 | return runner, spikes, rates, heldout_spikes, forward_spikes
54 |
55 | def init(variant, ckpt, base="", prefix="", mode=DATASET_MODES.val):
56 | runner, ckpt_path = make_runner(variant, ckpt, base, prefix)
57 | return init_by_ckpt(ckpt_path, mode)
58 |
59 |
60 | # Accumulates multiplied attention matrices - examine lower layers to see where information is sourced in early processing.
61 | # Takes in a list of per-layer attention weights.
62 | def get_multiplicative_weights(weights_list):
63 | weights = weights_list[0]
64 | multiplied_weights = [weights]
65 | for layer_weights in weights_list[1:]:
66 | weights = torch.bmm(layer_weights, weights)
67 | multiplied_weights.append(weights)
68 | return multiplied_weights
69 |
--------------------------------------------------------------------------------
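
A toy check of get_multiplicative_weights above: with batched, row-stochastic (uniform) attention matrices at every layer, each accumulated product stays row-stochastic. The shapes here are illustrative, and the import assumes this module is importable from the scripts directory.

import torch
from analyze_utils import get_multiplicative_weights

B, T, n_layers = 2, 6, 3
layer_weights = [torch.full((B, T, T), 1.0 / T) for _ in range(n_layers)]

accumulated = get_multiplicative_weights(layer_weights)
assert len(accumulated) == n_layers
# products of row-stochastic matrices remain row-stochastic
assert torch.allclose(accumulated[-1].sum(dim=-1), torch.ones(B, T))
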
/scripts/clear.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [[ $# -eq 1 ]]
4 | then
5 | python -u src/run.py --run-type train --exp-config configs/$1.yaml --clear-only True
6 | else
7 | echo "Expected args (ckpt)"
8 | fi
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # * The evaluation code path has legacy code. Evaluation / analysis is done in analysis scripts.
4 |
5 | if [[ $# -eq 2 ]]
6 | then
7 | python -u src/run.py --run-type eval --exp-config configs/$1.yaml --ckpt-path $1.$2.pth
8 | else
9 | echo "Expected args (ckpt)"
10 | fi
--------------------------------------------------------------------------------
/scripts/fig_3a_synthetic_neurons.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Joel Ye
3 |
4 | # Notebook for interactive model evaluation/analysis
5 | # Allows us to interrogate model on variable data (instead of masked sample again)
6 |
7 | #%%
8 | import os
9 | import os.path as osp
10 |
11 | import sys
12 | module_path = os.path.abspath(os.path.join('..'))
13 | if module_path not in sys.path:
14 | sys.path.append(module_path)
15 |
16 | import time
17 | import h5py
18 | import matplotlib.pyplot as plt
19 | import seaborn as sns
20 | import numpy as np
21 | import torch
22 | import torch.nn.functional as f
23 | from torch.utils import data
24 |
25 | from src.run import prepare_config
26 | from src.runner import Runner
27 | from src.dataset import SpikesDataset, DATASET_MODES
28 | from src.mask import UNMASKED_LABEL
29 |
30 | from analyze_utils import init_by_ckpt
31 | # Note that a whole lot of logger items will still be dumped into ./scripts/logs
32 |
33 | grid = True
34 | ckpt = "Grid" if grid else "PBT"
35 |
36 | # Lorenz
37 | variant = "lorenz-s1"
38 |
39 | # Chaotic
40 | variant = "chaotic-s1"
41 |
42 | def get_info(variant):
43 | ckpt_path = f"/snel/home/joely/ray_results/ndt/gridsearch/{variant}/best_model/ckpts/{variant}.lve.pth"
44 |     runner, spikes, rates, *_ = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val)  # init_by_ckpt also returns heldout/forward spikes
45 |
46 | train_set = SpikesDataset(runner.config, runner.config.DATA.TRAIN_FILENAME, mode=DATASET_MODES.train, logger=runner.logger)
47 |     train_spikes, train_rates, *_ = train_set.get_dataset()  # get_dataset also returns heldout/forward spikes
48 | full_spikes = torch.cat([train_spikes.to(spikes.device), spikes], dim=0)
49 | full_rates = torch.cat([train_rates.to(rates.device), rates], dim=0)
50 | # full_spikes = spikes
51 | # full_rates = rates
52 |
53 | (
54 | unmasked_loss,
55 | pred_rates,
56 | layer_outputs,
57 | attn_weights,
58 | attn_list,
59 | ) = runner.model(full_spikes, mask_labels=full_spikes, return_weights=True)
60 |
61 | return full_spikes.cpu(), full_rates.cpu(), pred_rates, runner
62 |
63 | chaotic_spikes, chaotic_rates, chaotic_ndt, chaotic_runner = get_info('chaotic-s1')
64 | lorenz_spikes, lorenz_rates, lorenz_ndt, lorenz_runner = get_info('lorenz-s1')
65 |
66 |
67 | #%%
68 | import matplotlib
69 | import matplotlib.pyplot as plt
70 | import numpy as np
71 | import seaborn as sns
72 |
73 | palette = sns.color_palette(palette='muted', n_colors=3, desat=0.9)
74 | SMALL_SIZE = 12
75 | MEDIUM_SIZE = 15
76 | LARGE_SIZE = 18
77 |
78 | def prep_plt(ax=plt.gca()):
79 | plt.rc('font', size=MEDIUM_SIZE) # controls default text sizes
80 | plt.rc('axes', labelsize=LARGE_SIZE) # fontsize of the x and y labels
81 | plt.rc('xtick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
82 | plt.rc('ytick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
83 | # plt.rc('title', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
84 |
85 | plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize
86 | plt.style.use('seaborn-muted')
87 | # plt.figure(figsize=(6,4))
88 |
89 | spine_alpha = 0.5
90 | ax.spines['right'].set_alpha(0.0)
91 | # plt.gca().spines['bottom'].set_alpha(spine_alpha)
92 | ax.spines['bottom'].set_alpha(0)
93 | # plt.gca().spines['left'].set_alpha(spine_alpha)
94 | ax.spines['left'].set_alpha(0)
95 | ax.spines['top'].set_alpha(0.0)
96 |
97 | plt.tight_layout()
98 |
99 | prep_plt()
100 |
101 | # np.random.seed(5)
102 |
103 | # plt.text("NDT Predictions", color=palette[0])
104 | # plt.text("Ground Truth", color=palette[2])
105 | def plot_axis(
106 | spikes,
107 | true_rates,
108 | ndt_rates, # ndt
109 | ax,
110 | key="lorenz",
111 | condition_idx=0,
112 | # trial_idx=0,
113 | neuron_idx=0,
114 | seed=3, # legacy,
115 | first_n=8
116 | ):
117 | prep_plt(ax)
118 | # np.random.seed(seed)
119 | # neuron_idx = np.random.randint(0, true_rates.size(-1))
120 | # trial_idx = np.random.randint(0, true_rates.size(0))
121 |     _, unique, counts = np.unique(true_rates.numpy(), axis=0, return_inverse=True, return_counts=True)  # note: 'unique' holds the inverse indices, i.e. each trial's condition id
122 | trial_idx = (unique == condition_idx)
123 |
124 | true_trial_rates = true_rates[trial_idx][0, :, neuron_idx].cpu()
125 | pred_trial_rates = ndt_rates[trial_idx][0:first_n, :, neuron_idx].cpu()
126 | trial_spikes = spikes[trial_idx][0:first_n, :, neuron_idx].cpu()
127 |
128 | time = true_rates.size(1)
129 | if True: # lograte
130 | true_trial_rates = true_trial_rates.exp()
131 | pred_trial_rates = pred_trial_rates.exp()
132 | lfads_preds = np.load(f"/snel/home/joely/data/{key}.npy")
133 | lfads_trial_rates = lfads_preds[trial_idx][0:first_n, :, neuron_idx]
134 |
135 | if key == "lorenz":
136 | spike_level = -0.06
137 | else:
138 | spike_level = -0.03
139 | one_tenth_point = 0.1 * (50 if key == "lorenz" else 100)
140 |
141 | ax.plot([0, one_tenth_point], [(first_n + 1) * spike_level, (first_n + 1) * spike_level], 'k-', lw=3) # 5 / 50 = 0.1
142 |
143 | ax.plot(np.arange(time), true_trial_rates.numpy(), color='#111111', label="Ground Truth") #, linestyle="dashed")
144 |
145 | # Only show labels once
146 | labels = {
147 | "NDT": "NDT",
148 | 'LFADS': "AutoLFADS",
149 | 'Spikes': 'Spikes'
150 | }
151 |
152 | for trial_idx in range(pred_trial_rates.numpy().shape[0]):
153 | # print(trial_idx)
154 | # print(pred_trial_rates.size())
155 | ax.plot(np.arange(time), pred_trial_rates[trial_idx].numpy(), color=palette[0], label=labels['NDT'], alpha=0.4)
156 | labels['NDT'] = ""
157 | ax.plot(np.arange(time), lfads_trial_rates[trial_idx], color=palette[1], label=labels['LFADS'], alpha=0.4)
158 | labels["LFADS"] = ""
159 | spike_times, = np.where(trial_spikes[trial_idx].numpy())
160 | ax.scatter(spike_times, spike_level * (trial_idx + 1)*np.ones_like(spike_times), c='k', marker='|', label=labels['Spikes'], s=30)
161 | labels['Spikes'] = ""
162 |
163 | ax.set_xticks([])
164 | # ax.set_xticks(np.linspace(one_tenth_point, one_tenth_point, 1))
165 |
166 | if key == "lorenz":
167 | ax.set_ylim(-0.6, 0.65)
168 | ax.set_yticks([0.0, 0.4])
169 | ax.plot([-1, -1], [0, 0.4], 'k-', lw=3)
170 |
171 | ax.annotate("",
172 | xy=(-1.5, -0.5),
173 | xycoords="data",
174 | xytext=(-1.5, -0.3),
175 | textcoords="data",
176 | arrowprops=dict(
177 | arrowstyle="<-",
178 | connectionstyle="arc3,rad=0",
179 | linewidth="2",
180 | # color=(0.2, 0.2, 0.2)
181 | ),
182 | size=14
183 | )
184 |
185 | ax.text(
186 | -5.5, -0.5,
187 | # -3.5, -0.5,
188 | "Trials",
189 | fontsize=14,
190 | rotation=90
191 | )
192 |
193 | else:
194 | ax.set_ylim(-0.3, 0.5)
195 | ax.set_yticks([0, 0.2])
196 | ax.plot([-2, -2], [0, 0.2], 'k-', lw=3)
197 |
198 | ax.annotate("",
199 | xy=(-3, -0.25),
200 | xycoords="data",
201 | xytext=(-3, -0.15),
202 | textcoords="data",
203 | arrowprops=dict(
204 | arrowstyle="<-",
205 | connectionstyle="arc3,rad=0",
206 | linewidth="2",
207 | # color=(0.2, 0.2, 0.2)
208 | ),
209 | size=14
210 | )
211 |
212 | ax.text(
213 | -11, -0.25,
214 | # -7, -0.25,
215 | "Trials",
216 | fontsize=14,
217 | rotation=90
218 | )
219 |
220 | # ax.set_title(f"Trial {trial_idx}, Neuron {neuron_idx}")
221 |
222 | f, axes = plt.subplots(
223 | nrows=2, ncols=2, sharex=False, sharey=False, figsize=(8, 6)
224 | )
225 |
226 | plot_axis(
227 | lorenz_spikes,
228 | lorenz_rates,
229 | lorenz_ndt,
230 | axes[0, 0],
231 | condition_idx=0,
232 | key="lorenz"
233 | )
234 |
235 | plot_axis(
236 | lorenz_spikes,
237 | lorenz_rates,
238 | lorenz_ndt,
239 | axes[0, 1],
240 | condition_idx=1,
241 | key="lorenz"
242 | )
243 |
244 | plot_axis(
245 | chaotic_spikes,
246 | chaotic_rates,
247 | chaotic_ndt,
248 | axes[1, 0],
249 | condition_idx=0,
250 | key="chaotic"
251 | )
252 |
253 | plot_axis(
254 | chaotic_spikes,
255 | chaotic_rates,
256 | chaotic_ndt,
257 | axes[1, 1],
258 | condition_idx=1,
259 | key="chaotic"
260 | )
261 |
262 | # plt.suptitle(f"{variant} {ckpt}", y=1.05)
263 | axes[0, 0].text(15, 0.45, "Lorenz", size=18, rotation=0)
264 | # axes[0, 0].text(-20, 0.2, "Lorenz", size=18, rotation=45)
265 | axes[1, 0].text(30, 0.1, "Chaotic", size=18, rotation=0)
266 | # axes[1, 0].text(20, 0.25, "Chaotic", size=18, rotation=0)
267 |
268 | # plt.tight_layout()
269 | f.subplots_adjust(
270 | # left=0.15,
271 | # bottom=-0.1,
272 | hspace=0.0,
273 | wspace=0.0
274 | )
275 | # axes[0,0].set_xticks([])
276 | axes[0,1].set_yticks([])
277 | # axes[0,1].set_xticks([])
278 | axes[1,1].set_yticks([])
279 |
280 | legend = axes[1, 1].legend(
281 | # loc=(-.95, 0.97),
282 | loc=(-1.05, 0.85),
283 | fontsize=14,
284 | frameon=False,
285 | ncol=4,
286 | )
287 |
288 | # for line in legend.get_lines():
289 | # line.set_linewidth(3.0)
290 |
291 | # plt.savefig("lorenz_rates.png", dpi=300, bbox_inches="tight")
292 | # plt.savefig("lorenz_rates_2.png", dpi=300, bbox_inches="tight")
293 | plt.setp(legend.get_texts()[1], color=palette[0])
294 | plt.setp(legend.get_texts()[2], color=palette[1])
295 | plt.setp(legend.get_texts()[0], color="#111111")
296 | # plt.setp(legend.get_texts()[3], color=palette[2])
297 |
298 | plt.savefig("3a_synth_qual.pdf")
299 |
300 | #%%
301 | # Total 1300 (130 x 10 per)
302 | # Val 260 (130 x 2 per)
303 | _, unique, counts = np.unique(chaotic_rates.numpy(), axis=0, return_inverse=True, return_counts=True)
304 | # unique, counts = chaotic_rates.unique(dim=0, return_counts=True)
305 | trial_idx = (unique == 0)
306 | # Sanity check: which trials fall into condition 0?
307 | print(np.where(trial_idx))
308 | true_trial_rates = chaotic_rates[trial_idx][0, :, 0].cpu()
309 | print(true_trial_rates[:10])
--------------------------------------------------------------------------------
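The condition lookup in `plot_axis` (and in the scratch cell at the end) relies on `np.unique(..., axis=0, return_inverse=True)`: trials with identical ground-truth rate tensors are treated as one condition, and the inverse indices give each trial's condition id. A self-contained sketch with made-up shapes:

```
import numpy as np

# Synthetic "rates": 6 trials x 3 bins x 2 neurons, drawn from 2 conditions.
cond_a = np.full((3, 2), 0.1)
cond_b = np.full((3, 2), 0.7)
rates = np.stack([cond_a, cond_b, cond_a, cond_a, cond_b, cond_b])

# axis=0 treats each trial as one item; the inverse maps trial -> condition id.
_, condition_ids, counts = np.unique(
    rates, axis=0, return_inverse=True, return_counts=True
)
print(condition_ids)             # [0 1 0 0 1 1]
print(counts)                    # trials per condition: [3 3]
trial_mask = condition_ids == 0  # boolean mask over trials, as in plot_axis
print(rates[trial_mask].shape)   # (3, 3, 2)
```
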
/scripts/fig_3b_4b_plot_hparams.py:
--------------------------------------------------------------------------------
1 | # ! Reference script. Not supported out of the box.
2 |
3 | #%%
4 | # param vs nll (3b) and match to vaf vs nll (4b)
5 | # 3b requires a CSV downloaded from TB Hparams page (automatically generated by Ray)
6 | # 4b requires saving down of rate predictions from each model (for NDT, this step is in `record_all_rates.py`)
7 |
8 | from pathlib import Path
9 | import sys
10 | module_path = str(Path('..').resolve())
11 | if module_path not in sys.path:
12 | sys.path.append(module_path)
13 | from collections import defaultdict
14 | import time
15 | import matplotlib.pyplot as plt
16 | import seaborn as sns
17 | import numpy as np
18 | import pandas as pd
19 | import torch
20 | import torch.nn.functional as f
21 | from torch.utils import data
22 |
23 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
24 |
25 | from src.run import prepare_config
26 | from src.runner import Runner
27 | from src.dataset import SpikesDataset, DATASET_MODES
28 | from src.mask import UNMASKED_LABEL
29 |
30 | from analyze_utils import init_by_ckpt
31 |
32 | tf_size_guidance = {'scalars': 1000}
33 | #%%
34 | # Extract the NLL and R2
35 | plot_path = Path('~/projects/transformer-modeling/scripts/hparams.csv').expanduser().resolve()
36 | df = pd.read_csv(plot_path)
37 |
38 | #%%
39 | SMALL_SIZE = 12
40 | MEDIUM_SIZE = 20
41 | LARGE_SIZE = 24
42 |
43 | def prep_plt():
44 | plt.rc('font', size=MEDIUM_SIZE) # controls default text sizes
45 | plt.rc('axes', labelsize=LARGE_SIZE) # fontsize of the x and y labels
46 | plt.rc('xtick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
47 | plt.rc('ytick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
48 | # plt.rc('title', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
49 |
50 | plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize
51 | plt.style.use('seaborn-muted')
52 | # plt.figure(figsize=(6,4))
53 |
54 | spine_alpha = 0.5
55 | plt.gca().spines['right'].set_alpha(0.0)
56 | plt.gca().spines['bottom'].set_alpha(spine_alpha)
57 | # plt.gca().spines['bottom'].set_alpha(0)
58 | plt.gca().spines['left'].set_alpha(spine_alpha)
59 | # plt.gca().spines['left'].set_alpha(0)
60 | plt.gca().spines['top'].set_alpha(0.0)
61 |
62 | plt.tight_layout()
63 |
64 | plt.figure(figsize=(4, 6))
65 | # plt.figure(figsize=(6, 4))
66 | prep_plt()
67 | # So. I select the model with the least all-time MVE, and report metrics on that MVE ckpt.
68 | # Here I'm reporting numbers from the final ckpt instead.
69 |
70 | x_axis = 'unmasked_loss'
71 | # x_axis = 'masked_loss'
72 | sns.scatterplot(data=df, x=f"ray/tune/{x_axis}", y="ray/tune/r2", s=120)
73 | # sns.scatterplot(data=df, x="ray/tune/best_masked_loss", y="ray/tune/r2", s=120)
74 | plt.yticks(np.arange(0.8, 0.96, 0.15))
75 | # plt.yticks(np.arange(0.8, 0.96, 0.05))
76 | plt.xticks([0.37, 0.39])
77 | plt.xlim(0.365, 0.39)
78 | plt.xlabel("Match to Spikes (NLL)")
79 | plt.ylabel("Rate Prediction $R^2$", labelpad=-20)
80 | from scipy.stats import pearsonr
81 | r = pearsonr(df[f'ray/tune/{x_axis}'], df['ray/tune/r2'])
82 | print(r)
83 | plt.text(0.378, 0.92, f"$\it{{\\rho}}$ : {r[0]:.3f}", size=LARGE_SIZE)
84 |
85 | plt.savefig("3_match_spikes.pdf", bbox_inches="tight")
86 | #%%
87 | palette = sns.color_palette(palette='muted', n_colors=2, desat=0.9)
88 | variant = 'm700_2296-s1'
89 | nlls = torch.load(f'{variant}_val_errs_sweep.pth')
90 | matches = torch.load(f'/snel/home/joely/projects/rds/{variant}_psth_match_sweep.pth')
91 | decoding = torch.load(f'/snel/home/joely/projects/rds/{variant}_deocding_sweep.pth')
92 | nll_arr = []
93 | match_arr = []
94 | decoding_arr = []
95 | for key in nlls:
96 | nll_arr.append(nlls[key])
97 | match_arr.append(matches[key])
98 | decoding_arr.append(decoding[key])
99 | plt.figure(figsize=(4, 5))
100 | prep_plt()
101 | plt.scatter(nll_arr, match_arr, color = palette[0])
102 | plt.xticks([0.139, 0.144], rotation=20)
103 | plt.xlim(0.139, 0.144)
104 | plt.yticks([0.55, 0.75])
105 | # plt.yticks(np.linspace(0.5, 0.75, 2))
106 | plt.xlabel("Match to Spikes (NLL)", labelpad=0, fontsize=MEDIUM_SIZE)
107 | plt.ylabel("Match to Empirical PSTH ($R^2$)", labelpad=0, fontsize=MEDIUM_SIZE)
108 | r = pearsonr(nll_arr, match_arr)
109 | print(r)
110 |
111 | plt.text(0.1392, 0.53, f"$\it{{\\rho}}$ : {r[0]:.3f}", size=LARGE_SIZE, color=palette[0])
112 | # plt.text(0.141, 0.72, f"$\it{{\\rho}}$ : {r[0]:.3f}", size=LARGE_SIZE)
113 | plt.hlines(0.7078, 0.139, 0.144, linestyles="--", color=palette[1])
114 | plt.text(0.141, 0.715, f"AutoLFADS", size=MEDIUM_SIZE, color=palette[1])
115 | # plt.text(0.142, 0.715, f"LFADS", size=MEDIUM_SIZE, color=palette[1])
116 |
117 | plt.savefig("4b_match_psth.pdf", bbox_inches="tight")
118 | #%%
119 | plt.figure(figsize=(4, 5))
120 | prep_plt()
121 | plt.scatter(nll_arr, decoding_arr)
122 | plt.xticks([0.139, 0.144], rotation=20)
123 | plt.xlim(0.139, 0.144)
124 |
125 | #%%
126 | plt.figure(figsize=(4, 5))
127 | prep_plt()
128 | plt.scatter(match_arr, decoding_arr)
129 | # plt.xticks([0.139, 0.144], rotation=20)
130 | # plt.xlim(0.139, 0.144)
131 |
--------------------------------------------------------------------------------
/scripts/fig_5_lfads_times.py:
--------------------------------------------------------------------------------
1 | # Author: Joel Ye
2 |
3 | # LFADS timing
4 | # ! This is a reference script. May not run out of the box with the NDT environment (e.g. needs LFADS dependencies)
5 | #%%
6 | import os
7 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
8 |
9 | import copy
10 | import cProfile
11 | import os
12 | import time
13 | import gc
14 |
15 | import h5py
16 | import matplotlib.pyplot as plt
17 | import numpy as np
18 | import tensorflow as tf
19 | import tensorflow_probability as tfp
20 | from tensorflow.keras.utils import Progbar
21 |
22 | from lfads_tf2.defaults import get_cfg_defaults
23 | from lfads_tf2.models import LFADS
24 | from lfads_tf2.tuples import DecoderInput, SamplingOutput
25 | from lfads_tf2.utils import (load_data, load_posterior_averages,
26 | restrict_gpu_usage)
27 |
28 | tfd = tfp.distributions
29 | gc.disable() # disable garbage collection
30 |
31 | # tf.debugging.set_log_device_placement(True)
32 | restrict_gpu_usage(gpu_ix=0)
33 |
34 |
35 | # %%
36 | # restore the LFADS model
37 |
38 |
39 | # model = LFADS(model_dir) # A hardcoded chaotic path
40 | model_dir = '/snel/home/joely/ray_results/lfads/chaotic-s1/best_model'
41 | cfg_path = os.path.join(model_dir, 'model_spec.yaml') # Don't directly load model
42 |
43 | def sample_and_average(n_samples=50,
44 | batch_size=64,
45 | merge_tv=False,
46 | seq_len_cap=100):
47 | cfg = get_cfg_defaults()
48 | cfg.merge_from_file(cfg_path)
49 | cfg.MODEL.SEQ_LEN = seq_len_cap
50 | cfg.freeze()
51 | model = LFADS(
52 | cfg_node=cfg,
53 | # model_dir='/snel/home/joely/ray_results/lfads/chaotic-s1/best_model'
54 | )
55 | model.restore_weights()
56 | # ! Modified for timing
57 | if not model.is_trained:
58 | model.lgr.warn("Performing posterior sampling on an untrained model.")
59 |
60 | # define merging and splitting utilities
61 | def merge_samp_and_batch(data, batch_dim):
62 | """ Combines the sample and batch dimensions """
63 | return tf.reshape(data, [n_samples * batch_dim] +
64 | tf.unstack(tf.shape(data)[2:]))
65 |
66 | def split_samp_and_batch(data, batch_dim):
67 | """ Splits up the sample and batch dimensions """
68 | return tf.reshape(data, [n_samples, batch_dim] +
69 | tf.unstack(tf.shape(data)[1:]))
70 |
71 | # ========== POSTERIOR SAMPLING ==========
72 | # perform sampling on both training and validation data
73 | loop_times = []
74 | for prefix, dataset in zip(['train_', 'valid_'],
75 | [model._train_ds, model._val_ds]):
76 | data_len = len(model.train_tuple.data) if prefix == 'train_' else len(
77 | model.val_tuple.data)
78 |
79 | # initialize lists to store rates
80 | all_outputs = []
81 | model.lgr.info(
82 | "Posterior sample and average on {} segments.".format(data_len))
83 | if not model.cfg.TRAIN.TUNE_MODE:
84 | pbar = Progbar(data_len, width=50, unit_name='dataset')
85 |
86 | def process_batch():
87 | # unpack the batch
88 | data, _, ext_input = batch
89 | data = data[:,:seq_len_cap]
90 | ext_input = ext_input[:,:seq_len_cap]
91 |
92 | time_start = time.time()
93 |
94 | # for each chop in the dataset, compute the initial conditions
95 | # distribution
96 | ic_mean, ic_stddev, ci = model.encoder.graph_call(data)
97 | ic_post = tfd.MultivariateNormalDiag(ic_mean, ic_stddev)
98 |
99 | # sample from the posterior and merge sample and batch dimensions
100 | ic_post_samples = ic_post.sample(n_samples)
101 | ic_post_samples_merged = merge_samp_and_batch(
102 | ic_post_samples, len(data))
103 |
104 | # tile and merge the controller inputs and the external inputs
105 | ci_tiled = tf.tile(tf.expand_dims(ci, axis=0),
106 | [n_samples, 1, 1, 1])
107 | ci_merged = merge_samp_and_batch(ci_tiled, len(data))
108 | ext_tiled = tf.tile(tf.expand_dims(ext_input, axis=0),
109 | [n_samples, 1, 1, 1])
110 | ext_merged = merge_samp_and_batch(ext_tiled, len(data))
111 |
112 | # pass all samples into the decoder
113 | dec_input = DecoderInput(ic_samp=ic_post_samples_merged,
114 | ci=ci_merged,
115 | ext_input=ext_merged)
116 | output_samples_merged = model.decoder.graph_call(dec_input)
117 |
118 | # average the outputs across samples
119 | output_samples = [
120 | split_samp_and_batch(t, len(data))
121 | for t in output_samples_merged
122 | ]
123 | output = [np.mean(t, axis=0) for t in output_samples]
124 |
125 | time_elapsed = time.time() - time_start
126 |
127 | if not model.cfg.TRAIN.TUNE_MODE:
128 | pbar.add(len(data))
129 | return time_elapsed
130 |
131 | for batch in dataset.batch(batch_size):
132 | loop_times.append(process_batch())
133 | return loop_times
134 | # %%
135 | def time_bin(bins):
136 | n_samples = 1
137 | p_loop_times = sample_and_average(batch_size=1, n_samples=n_samples, seq_len_cap=bins)
138 | del p_loop_times[0] # first iteration is slower due to graph initialization
139 | p_loop_times = np.array(p_loop_times) * 1e3
140 | print(f"{p_loop_times.mean():.3f}ms for {bins} bins")
141 | return p_loop_times
142 |
143 | all_times = []
144 | # for bins in range(5, 10, 5):  # short debug sweep
145 | for bins in range(5, 105, 5):
146 | all_times.append(time_bin(bins))
147 | all_times = np.array(all_times)
148 |
149 | with open('lfads_times.npy', 'wb') as f:
150 | np.save(f, all_times)
151 |
--------------------------------------------------------------------------------
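The posterior-averaging loop above folds the sample dimension into the batch dimension before the decoder call and unfolds it afterwards. A framework-agnostic sketch of that merge/split pattern in PyTorch (shapes are illustrative assumptions):

```
import torch

n_samples, batch, time_bins, dim = 50, 8, 100, 64

# One posterior sample per (sample, trial) pair.
samples = torch.randn(n_samples, batch, time_bins, dim)

# Merge: the decoder only sees a batch axis, so fold samples into it.
merged = samples.reshape(n_samples * batch, time_bins, dim)

decoded = merged * 2.0  # stand-in for a decoder forward pass

# Split: recover the sample axis, then average across it.
split = decoded.reshape(n_samples, batch, time_bins, dim)
posterior_mean = split.mean(dim=0)  # (batch, time_bins, dim)
print(posterior_mean.shape)
```
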
/scripts/fig_5_ndt_times.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Joel Ye
3 |
4 | # Run from scripts directory.
5 | # python fig_5_ndt_times.py -l {1, 2, 6}
6 |
7 | #%%
8 | import os
9 | import os.path as osp
10 | import argparse
11 | import sys
12 | module_path = os.path.abspath(os.path.join('..'))
13 | if module_path not in sys.path:
14 | sys.path.append(module_path)
15 |
16 | import time
17 | import gc
18 | gc.disable()
19 |
20 | import h5py
21 | import matplotlib.pyplot as plt
22 | import seaborn as sns
23 | import numpy as np
24 | import torch
25 | from torch.utils import data
26 | import torch.nn.functional as f
27 |
28 | from src.dataset import DATASET_MODES, SpikesDataset
29 | from src.run import prepare_config
30 | from src.runner import Runner
31 | from analyze_utils import make_runner, get_multiplicative_weights, init_by_ckpt
32 |
33 | prefix="arxiv"
34 | base=""
35 | variant = "chaotic"
36 |
37 | run_type = "eval"
38 | exp_config = osp.join("../configs", prefix, f"{variant}.yaml")
39 | if base != "":
40 | exp_config = [osp.join("../configs", f"{base}.yaml"), exp_config]
41 |
42 | def get_parser():
43 | parser = argparse.ArgumentParser()
44 |
45 | parser.add_argument(
46 | "--num-layers", "-l",
47 | type=int,
48 | required=True,
49 | )
50 | return parser
51 |
52 | parser = get_parser()
53 | args = parser.parse_args()
54 | layers = vars(args)["num_layers"]
55 |
56 | def make_runner_of_bins(bin_count=100, layers=None):
57 | config, _ = prepare_config(
58 | exp_config, run_type, "", [
59 | "USE_TENSORBOARD", False,
60 | "SYSTEM.NUM_GPUS", 1,
61 | ], suffix=prefix
62 | )
63 | config.defrost()
64 | config.MODEL.TRIAL_LENGTH = bin_count
65 | if layers is not None:
66 | config.MODEL.NUM_LAYERS = layers
67 | config.MODEL.LEARNABLE_POSITION = False # Not sure why...
68 | config.freeze()
69 | return Runner(config)
70 |
71 | def time_length(trials=1300, bin_count=100, **kwargs):
72 | # 100 as upper bound
73 | runner = make_runner_of_bins(bin_count=bin_count, **kwargs)
74 | runner.logger.mute()
75 | runner.load_device()
76 | runner.max_spikes = 9 # from chaotic ckpt
77 | runner.num_neurons = 50 # from chaotic ckpt
78 | # runner.num_neurons = 202 # from chaotic ckpt
79 | runner.setup_model(runner.device)
80 | # whole_set = SpikesDataset(runner.config, runner.config.DATA.TRAIN_FILENAME, mode="trainval")
81 | # whole_set.clip_spikes(runner.max_spikes)
82 | # # print(f"Evaluating on {len(whole_set)} samples.")
83 | # data_generator = data.DataLoader(whole_set,
84 | # batch_size=1, shuffle=False
85 | # )
86 | loop_times = []
87 | with torch.no_grad():
88 | probs = torch.full((1, bin_count, runner.num_neurons), 0.1)
89 | # probs = torch.full((1, bin_count, runner.num_neurons), 0.01)
90 | while len(loop_times) < trials:
91 | spikes = torch.bernoulli(probs).long()
92 | spikes = spikes.to(runner.device)
93 | start = time.time()
94 | runner.model(spikes, mask_labels=spikes, passthrough=True)
95 | delta = time.time() - start
96 | loop_times.append(delta)
97 | p_loop_times = np.array(loop_times) * 1e3
98 | print(f"{p_loop_times.mean():.4f}ms for {bin_count} bins")
99 |
100 |     # A note about memory: it's unclear why `empty_cache` fails to release memory (torch still reports it as used), but the diagnostic below indicates the memory is not actually allocated and will not cause OOM. A minor inconvenience for now.
101 | # device = runner.device
102 | # runner.model.to('cpu')
103 | # t = torch.cuda.get_device_properties(device).total_memory
104 | # c = torch.cuda.memory_cached(device)
105 | # a = torch.cuda.memory_allocated(device)
106 | # print(device, t, c, a)
107 | # del runner
108 | # t = torch.cuda.get_device_properties(device).total_memory
109 | # c = torch.cuda.memory_cached(device)
110 | # a = torch.cuda.memory_allocated(device)
111 | # print(device, t, c, a)
112 | # del data_generator
113 | # del whole_set
114 | # del spikes
115 | # torch.cuda.empty_cache()
116 | return p_loop_times
117 |
118 | times = []
119 | for i in range(5, 105, 5):  # sweep bin counts 5-100, matching the LFADS timing sweep
120 | p_loop_times = time_length(trials=2000, bin_count=i, layers=layers)
121 | times.append(p_loop_times)
122 |
123 | times = np.stack(times, axis=0)
124 | np.save(f'ndt_times_layer_{layers}', times)
125 |
126 | #%%
127 |
128 |
--------------------------------------------------------------------------------
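A compact version of the wall-clock timing pattern above, with the usual `torch.cuda.synchronize()` guards when timing GPU work (the stand-in model and shapes are assumptions for illustration, not the script's exact setup):

```
import time

import torch

def time_forward(model, batch, trials=100, device="cpu"):
    """Average forward-pass wall-clock time in ms; a minimal sketch."""
    model = model.to(device).eval()
    batch = batch.to(device)
    times = []
    with torch.no_grad():
        for _ in range(trials):
            if device != "cpu":
                torch.cuda.synchronize()  # ensure prior GPU work has finished
            start = time.time()
            model(batch)
            if device != "cpu":
                torch.cuda.synchronize()  # wait for this forward pass to finish
            times.append(time.time() - start)
    return 1e3 * sum(times) / len(times)

# Hypothetical usage with a toy model: 1 trial x 10 bins x 50 channels.
toy = torch.nn.Linear(50, 50)
print(f"{time_forward(toy, torch.randn(1, 10, 50)):.4f} ms")
```
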
/scripts/fig_6_plot_losses.py:
--------------------------------------------------------------------------------
1 | #%%
2 | from pathlib import Path
3 | import sys
4 | module_path = str(Path('..').resolve())
5 | if module_path not in sys.path:
6 | sys.path.append(module_path)
7 | from collections import defaultdict
8 | import time
9 | import matplotlib.pyplot as plt
10 | import seaborn as sns
11 | import numpy as np
12 | import torch
13 | import torch.nn.functional as f
14 | from torch.utils import data
15 |
16 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
17 |
18 | from src.run import prepare_config
19 | from src.runner import Runner
20 | from src.dataset import SpikesDataset, DATASET_MODES
21 | from src.mask import UNMASKED_LABEL
22 |
23 | from analyze_utils import init_by_ckpt
24 |
25 | tf_size_guidance = {'scalars': 1000}
26 | #%%
27 |
28 | plot_folder = Path('~/ray_results/ndt/gridsearch/lograte_tb').expanduser().resolve()
29 | variants = {
30 | "m700_no_log-s1": "Rates",
31 | "m700_2296-s1": "Logrates"
32 | }
33 |
34 | all_info = defaultdict(dict) # key: variant, value: dict per variant, keyed by run (value will dict of step and value)
35 | for variant in variants:
36 | v_dir = plot_folder.joinpath(variant)
37 | for run_dir in v_dir.iterdir():
38 | if not run_dir.is_dir():
39 | continue
40 | tb_dir = run_dir.joinpath('tb')
41 | all_info[variant][tb_dir.parts[-2][:10]] = defaultdict(list) # key: "step" or "value", value: info
42 | for tb_file in tb_dir.iterdir():
43 | event_acc = EventAccumulator(str(tb_file), tf_size_guidance)
44 | event_acc.Reload()
45 | # print(event_acc.Tags())
46 | if 'val_loss' in event_acc.Tags()['scalars']:
47 | val_loss = event_acc.Scalars('val_loss')
48 | all_info[variant][tb_dir.parts[-2][:10]]['step'].append(val_loss[0].step)
49 | all_info[variant][tb_dir.parts[-2][:10]]['value'].append(val_loss[0].value)
50 |
51 | #%%
52 | SMALL_SIZE = 12
53 | MEDIUM_SIZE = 15
54 | LARGE_SIZE = 18
55 |
56 | def prep_plt():
57 | plt.rc('font', size=MEDIUM_SIZE) # controls default text sizes
58 | plt.rc('axes', labelsize=LARGE_SIZE) # fontsize of the x and y labels
59 | plt.rc('xtick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
60 | plt.rc('ytick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
61 | # plt.rc('title', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
62 |
63 | plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize
64 | plt.style.use('seaborn-muted')
65 | # plt.figure(figsize=(6,4))
66 |
67 | spine_alpha = 0.5
68 | plt.gca().spines['right'].set_alpha(0.0)
69 | # plt.gca().spines['bottom'].set_alpha(spine_alpha)
70 | plt.gca().spines['bottom'].set_alpha(0)
71 | # plt.gca().spines['left'].set_alpha(spine_alpha)
72 | plt.gca().spines['left'].set_alpha(0)
73 | plt.gca().spines['top'].set_alpha(0.0)
74 |
75 | plt.tight_layout()
76 |
77 | plt.figure(figsize=(6, 4))
78 |
79 | prep_plt()
80 | palette = sns.color_palette(palette='muted', n_colors=len(variants), desat=0.9)
81 | colors = {}
82 | color_ind = 0
83 | for variant, label in variants.items():
84 | colors[variant] = palette[color_ind]
85 | color_ind += 1
86 | legend_str = label
87 | for run, run_info in all_info[variant].items():
88 | # Sort
89 | sort_ind = np.argsort(run_info['step'])
90 | steps = np.array(run_info['step'])[sort_ind]
91 | vals = np.array(run_info['value'])[sort_ind]
92 | plt.plot(steps, vals, color=colors[variant], label=legend_str)
93 | # plt.hlines(vals.min(),0, 25000, color=colors[variant])
94 | legend_str = ""
95 | plt.legend(fontsize=MEDIUM_SIZE, frameon=False)
96 | plt.xticks(np.arange(0, 25001, 12500))
97 | plt.yticks(np.arange(0.1, 1.21, 0.5))
98 | plt.ylim(0.1, 1.21)
99 | # plt.yticks(np.arange(0.2, 1.21, 0.5))
100 | plt.ylabel("Loss (NLL)")
101 | plt.yscale('log')
102 | plt.xlabel("Epochs")
103 | plt.savefig("6_losses.pdf", bbox_inches="tight")
--------------------------------------------------------------------------------
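The loss-curve extraction above boils down to one `EventAccumulator` pattern per event file. A minimal helper capturing that pattern (the event-file path in the commented usage is hypothetical):

```
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

def read_scalar(event_file, tag, size_guidance=None):
    """Return (steps, values) for one scalar tag in a TensorBoard event file."""
    acc = EventAccumulator(str(event_file), size_guidance or {'scalars': 1000})
    acc.Reload()  # actually parse the file
    if tag not in acc.Tags()['scalars']:
        return [], []
    events = acc.Scalars(tag)
    return [e.step for e in events], [e.value for e in events]

# steps, vals = read_scalar('tb/events.out.tfevents.example', 'val_loss')
```
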
/scripts/nlb.py:
--------------------------------------------------------------------------------
1 | #%%
2 |
3 | # 1. Load model and get rate predictions
4 | import os
5 | import os.path as osp
6 | from pathlib import Path
7 | import sys
8 | module_path = os.path.abspath(os.path.join('..'))
9 | if module_path not in sys.path:
10 | sys.path.append(module_path)
11 |
12 | import time
13 | import h5py
14 | import matplotlib.pyplot as plt
15 | import seaborn as sns
16 | import numpy as np
17 | import torch
18 | import torch.nn.functional as f
19 | from torch.utils import data
20 |
21 | from nlb_tools.evaluation import evaluate
22 |
23 | from src.run import prepare_config
24 | from src.runner import Runner
25 | from src.dataset import SpikesDataset, DATASET_MODES
26 | from src.mask import UNMASKED_LABEL
27 |
28 | from analyze_utils import init_by_ckpt
29 |
30 | variant = "area2_bump"
31 | variant = "mc_maze"
32 | # variant = "mc_maze_large"
33 | # variant = "mc_maze_medium"
34 | # variant = "mc_maze_small"
35 | variant = 'dmfc_rsg'
36 | variant = 'dmfc_rsg2'
37 | # variant = 'mc_rtt'
38 |
39 | is_ray = True
40 | # is_ray = False
41 |
42 | if is_ray:
43 | best_model = "best_model"
44 | # best_model = "best_model_unmasked"
45 | lve = "lfve" if "unmasked" in best_model else "lve"
46 |
47 | def to_path(variant):
48 | grid_var = f"{variant}_lite"
49 | ckpt_path = f"/snel/home/joely/ray_results/ndt/gridsearch/{grid_var}/best_model/ckpts/{grid_var}.{lve}.pth"
50 | runner, spikes, rates, heldout_spikes, forward_spikes = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val)
51 | print(runner.config.VARIANT)
52 | return runner, spikes, rates, heldout_spikes, ckpt_path
53 |
54 | runner, spikes, rates, heldout_spikes, ckpt_path = to_path(variant)
55 | else:
56 | ckpt_dir = Path(f"/snel/share/joel/transformer_modeling/{variant}/")
57 | ckpt_path = ckpt_dir.joinpath(f"{variant}.lve.pth")
58 | runner, spikes, rates, heldout_spikes, forward_spikes = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val)
59 |
60 | eval_rates, _ = runner.get_rates(
61 | checkpoint_path=ckpt_path,
62 | save_path = None,
63 | mode = DATASET_MODES.val
64 | )
65 | train_rates, _ = runner.get_rates(
66 | checkpoint_path=ckpt_path,
67 | save_path = None,
68 | mode = DATASET_MODES.train
69 | )
70 |
71 | # * Val
72 | eval_rates, eval_rates_forward = torch.split(eval_rates, [spikes.size(1), eval_rates.size(1) - spikes.size(1)], 1)
73 | eval_rates_heldin_forward, eval_rates_heldout_forward = torch.split(eval_rates_forward, [spikes.size(-1), heldout_spikes.size(-1)], -1)
74 | train_rates, _ = torch.split(train_rates, [spikes.size(1), train_rates.size(1) - spikes.size(1)], 1)
75 | eval_rates_heldin, eval_rates_heldout = torch.split(eval_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1)
76 | train_rates_heldin, train_rates_heldout = torch.split(train_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1)
77 |
78 |
79 | #%%
80 |
81 | output_dict = {
82 | variant: {
83 | 'train_rates_heldin': train_rates_heldin.cpu().numpy(),
84 | 'train_rates_heldout': train_rates_heldout.cpu().numpy(),
85 | 'eval_rates_heldin': eval_rates_heldin.cpu().numpy(),
86 | 'eval_rates_heldout': eval_rates_heldout.cpu().numpy(),
87 | 'eval_rates_heldin_forward': eval_rates_heldin_forward.cpu().numpy(),
88 | 'eval_rates_heldout_forward': eval_rates_heldout_forward.cpu().numpy()
89 | }
90 | }
91 |
92 | # target_dict = torch.load(f'/snel/home/joely/tmp/{variant}_target.pth')
93 | target_dict = np.load(f'/snel/home/joely/tmp/{variant}_target.npy', allow_pickle=True).item()
94 |
95 | print(evaluate(target_dict, output_dict))
96 |
97 | #%%
98 | # * Test
99 |
100 | variant = 'dmfc_rsg'
101 | runner.config.defrost()
102 | runner.config.DATA.TRAIN_FILENAME = f'{variant}_test.h5'
103 | runner.config.freeze()
104 | train_rates, _ = runner.get_rates(
105 | checkpoint_path=ckpt_path,
106 | save_path = None,
107 | mode = DATASET_MODES.train
108 | )
109 | eval_rates, _ = runner.get_rates(
110 | checkpoint_path=ckpt_path,
111 | save_path = None,
112 | mode = DATASET_MODES.val,
113 | )
114 |
115 | eval_rates, eval_rates_forward = torch.split(eval_rates, [spikes.size(1), eval_rates.size(1) - spikes.size(1)], 1)
116 | eval_rates_heldin_forward, eval_rates_heldout_forward = torch.split(eval_rates_forward, [spikes.size(-1), heldout_spikes.size(-1)], -1)
117 | train_rates, _ = torch.split(train_rates, [spikes.size(1), train_rates.size(1) - spikes.size(1)], 1)
118 | eval_rates_heldin, eval_rates_heldout = torch.split(eval_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1)
119 | train_rates_heldin, train_rates_heldout = torch.split(train_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1)
120 |
121 | output_dict = {
122 | variant: {
123 | 'train_rates_heldin': train_rates_heldin.cpu().numpy(),
124 | 'train_rates_heldout': train_rates_heldout.cpu().numpy(),
125 | 'eval_rates_heldin': eval_rates_heldin.cpu().numpy(),
126 | 'eval_rates_heldout': eval_rates_heldout.cpu().numpy(),
127 | 'eval_rates_heldin_forward': eval_rates_heldin_forward.cpu().numpy(),
128 | 'eval_rates_heldout_forward': eval_rates_heldout_forward.cpu().numpy()
129 | }
130 | }
131 |
132 | print(evaluate('/snel/share/data/nlb/test_data_do_not_share/eval_data_test.h5', output_dict))
133 |
134 | #%%
135 | import h5py
136 | with h5py.File('ndt_maze_preds.h5', 'w') as f:
137 | group = f.create_group('mc_maze')
138 | for key in output_dict['mc_maze']:
139 | group.create_dataset(key, data=output_dict['mc_maze'][key])
140 | #%%
141 | # Viz some trials
142 | trials = [1, 2, 3]
143 | neuron = 10
144 | trial_rates = train_rates_heldout[trials, :, neuron].cpu()
145 | trial_spikes = heldout_spikes[trials, :, neuron].cpu()
146 | trial_time = heldout_spikes.size(1)
147 | """ """
148 | spike_level = 0.05
149 | # one_tenth_point = 0.1 * (50 if key == "lorenz" else 100)
150 | # ax.plot([0, one_tenth_point], [(first_n + 1) * spike_level, (first_n + 1) * spike_level], 'k-', lw=3) # 5 / 50 = 0.1
151 |
152 | for trial_index in range(trial_spikes.size(0)):
153 | print(trial_spikes[trial_index])
154 | # plt.plot(trial_rates[trial_index].exp())
155 | spike_times, = np.where(trial_spikes[trial_index].numpy())
156 | plt.scatter(spike_times, spike_level * (trial_index + 1)*np.ones_like(spike_times), c='k', marker='|', label='Spikes', s=30)
157 | # labels['Spikes'] = ""
158 |
159 | print(train_rates_heldout.size())
160 |
161 | # %%
162 | print(heldout_spikes.sum(0).sum(0).argmax())
163 | # print(heldout_spikes[:,:,15])
164 | for i in range(0, heldout_spikes.size(0), 6):
165 | plt.plot(heldout_spikes[i, :, 15].cpu())
--------------------------------------------------------------------------------
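The repeated `torch.split` calls above carve the model output along time (observed bins vs. forward-prediction bins) and along channels (held-in vs. held-out neurons). A shape-only sketch of that carving, with illustrative dimensions:

```
import torch

trials, t_obs, t_fwd = 32, 100, 20
n_heldin, n_heldout = 137, 35

# Model output spans observed + forward bins and held-in + held-out channels.
rates = torch.rand(trials, t_obs + t_fwd, n_heldin + n_heldout)

# Split along time: observed window vs. forward-prediction window.
rates_obs, rates_fwd = torch.split(rates, [t_obs, t_fwd], dim=1)

# Split along channels: held-in vs. held-out neurons, for each window.
heldin, heldout = torch.split(rates_obs, [n_heldin, n_heldout], dim=-1)
heldin_fwd, heldout_fwd = torch.split(rates_fwd, [n_heldin, n_heldout], dim=-1)

print(heldin.shape, heldout.shape, heldin_fwd.shape, heldout_fwd.shape)
```
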
/scripts/nlb_from_scratch.ipynb:
--------------------------------------------------------------------------------
1 | {"cells":[{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\"\"\"\n","- Let's setup NDT for train/eval from scratch.\n","- (Assumes your device is reasonably pytorch/GPU compatible)\n","- This is an interactive python script run via vscode. If you'd like to run as a notebook\n","\n","Run to setup requirements:\n","```\n"," Making a new env\n"," - python 3.7\n"," - conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch\n"," - conda install seaborn\n"," - pip install yacs pytorch_transformers tensorboard \"ray[tune]\" sklearn\n"," - pip install dandi \"pynwb>=2.0.0\"\n"," Or from environment.yaml\n","\n"," Then,\n"," - conda develop ~/path/to/nlb_tools\n","```\n","\n","Install NLB dataset(s), and create the h5s for training. (Here we install MC_Maze_Small)\n","```\n"," pip install dandi\n"," dandi download DANDI:000140/0.220113.0408\n","```\n","\n","This largely follows the `basic_example.ipynb` in `nlb_tools`. The only distinction is that we save to an h5.\n","\"\"\"\n","\n","from nlb_tools.nwb_interface import NWBDataset\n","from nlb_tools.make_tensors import (\n"," make_train_input_tensors, make_eval_input_tensors, make_eval_target_tensors, save_to_h5\n",")\n","from nlb_tools.evaluation import evaluate\n","\n","import numpy as np\n","import pandas as pd\n","import h5py\n","\n","import logging\n","logging.basicConfig(level=logging.INFO)\n","\n","dataset_name = 'mc_maze_small'\n","datapath = '/home/joelye/user_data/nlb/000140/sub-Jenkins/'\n","dataset = NWBDataset(datapath)\n","\n","# Prepare dataset\n","phase = 'val'\n","\n","# Choose bin width and resample\n","bin_width = 5\n","dataset.resample(bin_width)\n","\n","# Create suffix for group naming later\n","suffix = '' if (bin_width == 5) else f'_{int(bin_width)}'\n","\n","train_split = 'train' if (phase == 'val') else ['train', 'val']\n","train_dict = make_train_input_tensors(\n"," dataset, dataset_name=dataset_name, trial_split=train_split, save_file=False,\n"," include_behavior=True,\n"," include_forward_pred = True,\n",")\n","\n","# Show fields of returned dict\n","print(train_dict.keys())\n","\n","# Unpack data\n","train_spikes_heldin = train_dict['train_spikes_heldin']\n","train_spikes_heldout = train_dict['train_spikes_heldout']\n","\n","# Print 3d array shape: trials x time x channel\n","print(train_spikes_heldin.shape)\n","\n","## Make eval data (i.e. 
val)\n","\n","# Split for evaluation is same as phase name\n","eval_split = phase\n","# Make data tensors\n","eval_dict = make_eval_input_tensors(\n"," dataset, dataset_name=dataset_name, trial_split=eval_split, save_file=False,\n",")\n","print(eval_dict.keys()) # only includes 'eval_spikes_heldout' if available\n","eval_spikes_heldin = eval_dict['eval_spikes_heldin']\n","\n","print(eval_spikes_heldin.shape)\n","\n","h5_dict = {\n"," **train_dict,\n"," **eval_dict\n","}\n","\n","h5_target = '/home/joelye/user_data/nlb/mc_maze_small.h5'\n","save_to_h5(h5_dict, h5_target)\n","\n",""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\"\"\"\n","- At this point we should be able to train a basic model.\n","- In CLI, run a training call, replacing the appropriate paths\n","```\n"," ./scripts/train.sh mc_maze_small_from_scratch\n","\n"," OR\n","\n"," python ray_random.py -e ./configs/mc_maze_small_from_scratch.yaml\n"," (CLI overrides aren't available here, so make another config file)\n","```\n","- Once this is done training (~0.5hr for non-search), let's load the results...\n","\"\"\"\n","import os\n","import os.path as osp\n","from pathlib import Path\n","import sys\n","\n","# Add ndt src if not in path\n","module_path = osp.abspath(osp.join('..'))\n","if module_path not in sys.path:\n"," sys.path.append(module_path)\n","\n","import time\n","import h5py\n","import matplotlib.pyplot as plt\n","import seaborn as sns\n","import numpy as np\n","import torch\n","import torch.nn.functional as f\n","from torch.utils import data\n","\n","from nlb_tools.evaluation import evaluate\n","\n","from src.run import prepare_config\n","from src.runner import Runner\n","from src.dataset import SpikesDataset, DATASET_MODES\n","\n","from analyze_utils import init_by_ckpt\n","\n","variant = \"mc_maze_small_from_scratch\"\n","\n","is_ray = True\n","is_ray = False\n","\n","if is_ray:\n"," ckpt_path = f\"/home/joelye/user_data/nlb/ndt_runs/ray/{variant}/best_model/ckpts/{variant}.lve.pth\"\n"," runner, spikes, rates, heldout_spikes, forward_spikes = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val)\n","else:\n"," ckpt_dir = Path(\"/home/joelye/user_data/nlb/ndt_runs/\")\n"," ckpt_path = ckpt_dir / variant / f\"{variant}.lve.pth\"\n"," runner, spikes, rates, heldout_spikes, forward_spikes = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val)\n","\n","eval_rates, _ = runner.get_rates(\n"," checkpoint_path=ckpt_path,\n"," save_path = None,\n"," mode = DATASET_MODES.val\n",")\n","train_rates, _ = runner.get_rates(\n"," checkpoint_path=ckpt_path,\n"," save_path = None,\n"," mode = DATASET_MODES.train\n",")\n","\n","# * Val\n","eval_rates, eval_rates_forward = torch.split(eval_rates, [spikes.size(1), eval_rates.size(1) - spikes.size(1)], 1)\n","eval_rates_heldin_forward, eval_rates_heldout_forward = torch.split(eval_rates_forward, [spikes.size(-1), heldout_spikes.size(-1)], -1)\n","train_rates, _ = torch.split(train_rates, [spikes.size(1), train_rates.size(1) - spikes.size(1)], 1)\n","eval_rates_heldin, eval_rates_heldout = torch.split(eval_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1)\n","train_rates_heldin, train_rates_heldout = torch.split(train_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1)\n",""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# * Viz some trials\n","trials = [1, 2, 3]\n","neuron = 10\n","trial_rates = train_rates_heldout[trials, :, neuron].cpu()\n","trial_spikes = heldout_spikes[trials, :, neuron].cpu()\n","trial_time = 
heldout_spikes.size(1)\n","\n","spike_level = 0.05\n","f, axes = plt.subplots(2, figsize=(6, 4))\n","times = np.arange(0, trial_time * 0.05, 0.05)\n","for trial_index in range(trial_spikes.size(0)):\n"," spike_times, = np.where(trial_spikes[trial_index].numpy())\n"," spike_times = spike_times * 0.05\n"," axes[0].scatter(spike_times, spike_level * (trial_index + 1)*np.ones_like(spike_times), marker='|', label='Spikes', s=30)\n"," axes[1].plot(times, trial_rates[trial_index].exp())\n","\n",""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Submission e.g. as in `basic_example.ipynb`\n","# Looks like this model is pretty terrible :/\n","output_dict = {\n"," dataset_name + suffix: {\n"," 'train_rates_heldin': train_rates_heldin.cpu().numpy(),\n"," 'train_rates_heldout': train_rates_heldout.cpu().numpy(),\n"," 'eval_rates_heldin': eval_rates_heldin.cpu().numpy(),\n"," 'eval_rates_heldout': eval_rates_heldout.cpu().numpy(),\n"," 'eval_rates_heldin_forward': eval_rates_heldin_forward.cpu().numpy(),\n"," 'eval_rates_heldout_forward': eval_rates_heldout_forward.cpu().numpy()\n"," }\n","}\n","\n","# Reset logging level to hide excessive info messages\n","logging.getLogger().setLevel(logging.WARNING)\n","\n","# If 'val' phase, make the target data\n","if phase == 'val':\n"," # Note that the RTT task is not well suited to trial averaging, so PSTHs are not made for it\n"," target_dict = make_eval_target_tensors(dataset, dataset_name=dataset_name, train_trial_split='train', eval_trial_split='val', include_psth=True, save_file=False)\n","\n"," # Demonstrate target_dict structure\n"," print(target_dict.keys())\n"," print(target_dict[dataset_name + suffix].keys())\n","\n","# Set logging level again\n","logging.getLogger().setLevel(logging.INFO)\n","\n","if phase == 'val':\n"," print(evaluate(target_dict, output_dict))\n","\n","# e.g. with targets to compare to\n","# target_dict = torch.load(f'/snel/home/joely/tmp/{variant}_target.pth')\n","# target_dict = np.load(f'/snel/home/joely/tmp/{variant}_target.npy', allow_pickle=True).item()\n","\n","# print(evaluate(target_dict, output_dict))\n","\n","# e.g. to upload to EvalAI\n","# with h5py.File('ndt_maze_preds.h5', 'w') as f:\n","# group = f.create_group('mc_maze')\n","# for key in output_dict['mc_maze']:\n","# group.create_dataset(key, data=output_dict['mc_maze'][key])"]}],"nbformat":4,"nbformat_minor":2,"metadata":{"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":3},"orig_nbformat":4}}
--------------------------------------------------------------------------------
/scripts/nlb_from_scratch.py:
--------------------------------------------------------------------------------
1 | #%%
2 | """
3 | - Let's set up NDT for train/eval from scratch.
4 | - (Assumes your device is reasonably pytorch/GPU compatible.)
5 | - This is an interactive python script run via vscode. If you'd like to run it as a notebook, see `nlb_from_scratch.ipynb`.
6 |
7 | Run to setup requirements:
8 | ```
9 | Making a new env
10 | - python 3.7
11 | - conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
12 | - conda install seaborn
13 | - pip install yacs pytorch_transformers tensorboard "ray[tune]" sklearn
14 | - pip install dandi "pynwb>=2.0.0"
15 |     Or from nlb.yml
16 |
17 | Then,
18 | - conda develop ~/path/to/nlb_tools
19 | ```
20 |
21 | Install NLB dataset(s), and create the h5s for training. (Here we install MC_Maze_Small)
22 | ```
23 | pip install dandi
24 | dandi download DANDI:000140/0.220113.0408
25 | ```
26 |
27 | This largely follows the `basic_example.ipynb` in `nlb_tools`. The only distinction is that we save to an h5.
28 | """
29 |
30 | from nlb_tools.nwb_interface import NWBDataset
31 | from nlb_tools.make_tensors import (
32 | make_train_input_tensors, make_eval_input_tensors, make_eval_target_tensors, save_to_h5
33 | )
34 | from nlb_tools.evaluation import evaluate
35 |
36 | import numpy as np
37 | import pandas as pd
38 | import h5py
39 |
40 | import logging
41 | logging.basicConfig(level=logging.INFO)
42 |
43 | dataset_name = 'mc_maze_small'
44 | datapath = '/home/joelye/user_data/nlb/000140/sub-Jenkins/'
45 | dataset = NWBDataset(datapath)
46 |
47 | # Prepare dataset
48 | phase = 'val'
49 |
50 | # Choose bin width and resample
51 | bin_width = 5
52 | dataset.resample(bin_width)
53 |
54 | # Create suffix for group naming later
55 | suffix = '' if (bin_width == 5) else f'_{int(bin_width)}'
56 |
57 | train_split = 'train' if (phase == 'val') else ['train', 'val']
58 | train_dict = make_train_input_tensors(
59 | dataset, dataset_name=dataset_name, trial_split=train_split, save_file=False,
60 | include_behavior=True,
61 | include_forward_pred = True,
62 | )
63 |
64 | # Show fields of returned dict
65 | print(train_dict.keys())
66 |
67 | # Unpack data
68 | train_spikes_heldin = train_dict['train_spikes_heldin']
69 | train_spikes_heldout = train_dict['train_spikes_heldout']
70 |
71 | # Print 3d array shape: trials x time x channel
72 | print(train_spikes_heldin.shape)
73 |
74 | ## Make eval data (i.e. val)
75 |
76 | # Split for evaluation is same as phase name
77 | eval_split = phase
78 | # eval_dict = make_eval_input_tensors(
79 | # dataset, dataset_name=dataset_name, trial_split=eval_split, save_file=False,
80 | # )
81 | # Make data tensors - use all chunks including forward prediction for training NDT
82 | eval_dict = make_train_input_tensors(
83 | dataset, dataset_name=dataset_name, trial_split=['val'], save_file=False, include_forward_pred=True,
84 | )
85 | eval_dict = {
86 | f'eval{key[5:]}': val for key, val in eval_dict.items()
87 | }
88 | eval_spikes_heldin = eval_dict['eval_spikes_heldin']
89 |
90 | print(eval_spikes_heldin.shape)
91 |
92 | h5_dict = {
93 | **train_dict,
94 | **eval_dict
95 | }
96 |
97 | h5_target = '/home/joelye/user_data/nlb/mc_maze_small.h5'
98 | save_to_h5(h5_dict, h5_target, overwrite=True)
99 |
100 |
101 | #%%
102 | """
103 | - At this point we should be able to train a basic model.
104 | - In CLI, run a training call, replacing the appropriate paths
105 | ```
106 | ./scripts/train.sh mc_maze_small_from_scratch
107 |
108 | OR
109 |
110 | python ray_random.py -e ./configs/mc_maze_small_from_scratch.yaml
111 | (CLI overrides aren't available here, so make another config file)
112 | ```
113 | - Once this is done training (~0.5hr for non-search), let's load the results...
114 | """
115 | import os
116 | import os.path as osp
117 | from pathlib import Path
118 | import sys
119 |
120 | # Add ndt src if not in path
121 | module_path = osp.abspath(osp.join('..'))
122 | if module_path not in sys.path:
123 | sys.path.append(module_path)
124 |
125 | import time
126 | import h5py
127 | import matplotlib.pyplot as plt
128 | import seaborn as sns
129 | import numpy as np
130 | import torch
131 | import torch.nn.functional as f
132 | from torch.utils import data
133 | from ray import tune
134 |
135 | from nlb_tools.evaluation import evaluate
136 |
137 | from src.run import prepare_config
138 | from src.runner import Runner
139 | from src.dataset import SpikesDataset, DATASET_MODES
140 | from analyze_utils import init_by_ckpt
141 |
142 | variant = "mc_maze_small_from_scratch"
143 |
144 | is_ray = True
145 | # is_ray = False
146 |
147 | if is_ray:
148 | tune_dir = f"/home/joelye/user_data/nlb/ndt_runs/ray/{variant}"
149 | df = tune.ExperimentAnalysis(tune_dir).dataframe()
150 | # ckpt_path = f"/home/joelye/user_data/nlb/ndt_runs/ray/{variant}/best_model/ckpts/{variant}.lve.pth"
151 | ckpt_dir = df.loc[df["best_masked_loss"].idxmin()].logdir
152 | ckpt_path = f"{ckpt_dir}/ckpts/{variant}.lve.pth"
153 | runner, spikes, rates, heldout_spikes, forward_spikes = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val)
154 | else:
155 | ckpt_dir = Path("/home/joelye/user_data/nlb/ndt_runs/")
156 | ckpt_path = ckpt_dir / variant / f"{variant}.lve.pth"
157 | runner, spikes, rates, heldout_spikes, forward_spikes = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val)
158 |
159 | eval_rates, _ = runner.get_rates(
160 | checkpoint_path=ckpt_path,
161 | save_path = None,
162 | mode = DATASET_MODES.val
163 | )
164 | train_rates, _ = runner.get_rates(
165 | checkpoint_path=ckpt_path,
166 | save_path = None,
167 | mode = DATASET_MODES.train
168 | )
169 |
170 | # * Val
171 | eval_rates, eval_rates_forward = torch.split(eval_rates, [spikes.size(1), eval_rates.size(1) - spikes.size(1)], 1)
172 | eval_rates_heldin_forward, eval_rates_heldout_forward = torch.split(eval_rates_forward, [spikes.size(-1), heldout_spikes.size(-1)], -1)
173 | train_rates, _ = torch.split(train_rates, [spikes.size(1), train_rates.size(1) - spikes.size(1)], 1)
174 | eval_rates_heldin, eval_rates_heldout = torch.split(eval_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1)
175 | train_rates_heldin, train_rates_heldout = torch.split(train_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1)
176 | #%%
177 | # * Viz some trials
178 | trials = [1, 2, 3]
179 | neuron = 10
180 | trial_rates = train_rates_heldout[trials, :, neuron].cpu()
181 | trial_spikes = heldout_spikes[trials, :, neuron].cpu()
182 | trial_time = heldout_spikes.size(1)
183 |
184 | spike_level = 0.05
185 | f, axes = plt.subplots(2, figsize=(6, 4))
186 | times = np.arange(0, trial_time * 0.05, 0.05)
187 | for trial_index in range(trial_spikes.size(0)):
188 | spike_times, = np.where(trial_spikes[trial_index].numpy())
189 | spike_times = spike_times * 0.05
190 | axes[0].scatter(spike_times, spike_level * (trial_index + 1)*np.ones_like(spike_times), marker='|', label='Spikes', s=30)
191 | axes[1].plot(times, trial_rates[trial_index].exp())
192 |
193 | #%%
194 | # Submission e.g. as in `basic_example.ipynb`
195 | # Looks like this model is pretty terrible :/
196 | output_dict = {
197 | dataset_name + suffix: {
198 | 'train_rates_heldin': train_rates_heldin.cpu().numpy(),
199 | 'train_rates_heldout': train_rates_heldout.cpu().numpy(),
200 | 'eval_rates_heldin': eval_rates_heldin.cpu().numpy(),
201 | 'eval_rates_heldout': eval_rates_heldout.cpu().numpy(),
202 | # 'eval_rates_heldin_forward': eval_rates_heldin_forward.cpu().numpy(),
203 | # 'eval_rates_heldout_forward': eval_rates_heldout_forward.cpu().numpy()
204 | }
205 | }
206 |
207 | # Reset logging level to hide excessive info messages
208 | logging.getLogger().setLevel(logging.WARNING)
209 |
210 | # If 'val' phase, make the target data
211 | if phase == 'val':
212 | # Note that the RTT task is not well suited to trial averaging, so PSTHs are not made for it
213 | target_dict = make_eval_target_tensors(dataset, dataset_name=dataset_name, train_trial_split='train', eval_trial_split='val', include_psth=True, save_file=False)
214 |
215 | # Demonstrate target_dict structure
216 | print(target_dict.keys())
217 | print(target_dict[dataset_name + suffix].keys())
218 |
219 | # Set logging level again
220 | logging.getLogger().setLevel(logging.INFO)
221 |
222 | if phase == 'val':
223 | print(evaluate(target_dict, output_dict))
224 |
225 | # e.g. with targets to compare to
226 | # target_dict = torch.load(f'/snel/home/joely/tmp/{variant}_target.pth')
227 | # target_dict = np.load(f'/snel/home/joely/tmp/{variant}_target.npy', allow_pickle=True).item()
228 |
229 | # print(evaluate(target_dict, output_dict))
230 |
231 | # e.g. to upload to EvalAI
232 | # with h5py.File('ndt_maze_preds.h5', 'w') as f:
233 | # group = f.create_group('mc_maze')
234 | # for key in output_dict['mc_maze']:
235 | # group.create_dataset(key, data=output_dict['mc_maze'][key])
--------------------------------------------------------------------------------
/scripts/record_all_rates.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Joel Ye
3 |
4 | # Notebook for interactive model evaluation/analysis
5 | # Allows us to interrogate the model on arbitrary data (instead of just the masked samples seen in training)
6 |
7 | #%%
8 | import os
9 | import os.path as osp
10 | from pathlib import Path
11 | import sys
12 | module_path = os.path.abspath(os.path.join('..'))
13 | if module_path not in sys.path:
14 | sys.path.append(module_path)
15 |
16 | import time
17 | import h5py
18 | import matplotlib.pyplot as plt
19 | import seaborn as sns
20 | import numpy as np
21 | import torch
22 | import torch.nn.functional as f
23 | from torch.utils import data
24 |
25 | from src.run import prepare_config
26 | from src.runner import Runner
27 | from src.dataset import SpikesDataset, DATASET_MODES
28 | from src.mask import UNMASKED_LABEL
29 |
30 | from analyze_utils import init_by_ckpt
31 |
32 | grid = False
33 | grid = True
34 |
35 | variant = "m700_2296-s1"
36 |
37 | val_errs = {}
38 | def save_rates(ckpt_path, handle):
39 |     runner, spikes, rates, *_ = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val)  # heldout/forward spikes unused here
40 | if "maze" in variant or "m700" in variant:
41 | runner.config.defrost()
42 | runner.config.DATA.DATAPATH = "/snel/share/data/ndt_paper/m1_maze/heldout_trial/2296_trials/0_seed"
43 | runner.config.freeze()
44 | rate_output_pth = f"/snel/share/joel/ndt_rates/psth_match"
45 | rate_output_pth = osp.join(rate_output_pth, "grid" if grid else "pbt")
46 | rate_output_fn = f"{handle}_{variant}.h5"
47 | val_errs[handle] = runner.best_val['value'].cpu().numpy()
48 |
49 |
50 | sweep_dir = Path("/snel/home/joely/ray_results/ndt/")
51 | if grid:
52 | sweep_dir = sweep_dir.joinpath("gridsearch")
53 | sweep_dir = sweep_dir.joinpath(variant, variant)
54 | for run_dir in sweep_dir.glob("tuneNDT*/"):
55 | run_ckpt_path = run_dir.joinpath(f'ckpts/{variant}.lve.pth')
56 | handle = run_dir.parts[-1][:10]
57 | save_rates(run_ckpt_path, handle)
58 |
59 | #%%
60 | torch.save(val_errs, f'{variant}_val_errs_sweep.pth')
--------------------------------------------------------------------------------
/scripts/scratch.py:
--------------------------------------------------------------------------------
1 | #%%
2 | import h5py
3 | path_old = '/home/joelye/user_data/nlb/ndt_old/mc_maze_small.h5'
4 | info_old = h5py.File(path_old)
5 | path_new = '/home/joelye/user_data/nlb/mc_maze_small.h5'
6 | info_new = h5py.File(path_new)
7 | path_old_eval = '/home/joelye/user_data/nlb/ndt_old/eval_data_val.h5'
8 | info_eval = h5py.File(path_old_eval)
9 | #%%
10 | # print(info_old.keys())
11 | print(info_old.keys())
12 | print(info_new.keys())
13 | for key in info_old:
14 | new_key = key.split('_')
15 | new_key[1] = 'spikes'
16 | new_key = '_'.join(new_key)
17 | if new_key in info_new:
18 | print(key, (info_old[key][:] == info_new[new_key][:]).all())
19 | else:
20 | print(new_key, ' missing')
21 | # print((info_old['train_data_heldin'][:] == info_new['train_spikes_heldin'][:]).all())
22 | # print(info_new['train_spikes_heldin'][:].std())
23 | # print(info_old['train_data_heldout'][:].std())
24 | # print(info_new['train_spikes_heldout'][:].std())
25 | # print(info_old['eval_data_heldin'][:].std())
26 |
27 | # print((info_old['eval_data_heldout'][:] == info_new['eval_spikes_heldout'][:]).all())
28 | # print(info_eval['mc_maze_small']['eval_spikes_heldout'][:].std())
29 |
30 | #%%
31 | #%%
32 | import torch
33 | scratch_payload = torch.load('scratch_rates.pth')
34 | old_payload = torch.load('old_rates.pth')
35 |
36 | print((scratch_payload['spikes'] == old_payload['spikes']).all())
37 | print((scratch_payload['rates'] == old_payload['rates']).all())
38 | print((scratch_payload['labels'] == old_payload['labels']).all())
39 | print(scratch_payload['labels'].sum())
40 | print(old_payload['labels'].sum())
41 | print(scratch_payload['loss'].mean(), old_payload['loss'].mean())
--------------------------------------------------------------------------------
/scripts/simple_ci.py:
--------------------------------------------------------------------------------
1 | # Calculate some confidence intervals
2 |
3 | #%%
4 | import os
5 | import os.path as osp
6 |
7 | import sys
8 | module_path = os.path.abspath(os.path.join('..'))
9 | if module_path not in sys.path:
10 | sys.path.append(module_path)
11 |
12 | import time
13 | import h5py
14 | import matplotlib.pyplot as plt
15 | import seaborn as sns
16 | import numpy as np
17 | import torch
18 | import torch.nn.functional as f
19 | from torch.utils import data
20 |
21 | from src.run import prepare_config
22 | from src.runner import Runner
23 | from src.dataset import SpikesDataset, DATASET_MODES
24 | from src.mask import UNMASKED_LABEL
25 |
26 | from analyze_utils import init_by_ckpt
27 |
28 | grid = True
29 | ckpt = "Grid"
30 | best_model = "best_model"
31 | best_model = "best_model_unmasked"
32 | lve = "lfve" if "unmasked" in best_model else "lve"
33 | r2s = []
34 | mnlls = []
35 | for i in range(3):
36 | variant = f"lorenz-s{i+1}"
37 | # variant = f"lorenz_lite-s{i+1}"
38 | # variant = f"chaotic-s{i+1}"
39 | # variant = f"chaotic_lite-s{i+1}"
40 | ckpt_path = f"/snel/home/joely/ray_results/ndt/gridsearch/{variant}/{best_model}/ckpts/{variant}.{lve}.pth"
41 |     runner, spikes, rates, *_ = init_by_ckpt(ckpt_path, mode=DATASET_MODES.val)  # heldout/forward spikes unused here
42 |
43 | # print(runner.config.MODEL.CONTEXT_FORWARD)
44 | # print(runner.config.MODEL.CONTEXT_BACKWARD)
45 | # print(runner.config.TRAIN.MASK_MAX_SPAN)
46 | (
47 | unmasked_loss,
48 | pred_rates,
49 | layer_outputs,
50 | attn_weights,
51 | attn_list,
52 | ) = runner.model(spikes, mask_labels=spikes, return_weights=True)
53 | print(f"Best Unmasked Val: {runner.best_unmasked_val}") # .37763
54 |     print(f"Best Masked Val: {runner.best_val}") # best val is .380 at 1302...
55 | mnlls.append(runner.best_val['value'])
56 | print(f"Best R2: {runner.best_R2}") # best val is .300 at 1302...
57 |
58 | print(f"Unmasked: {unmasked_loss}")
59 |
60 | if "maze" not in variant:
61 | r2 = runner.neuron_r2(rates, pred_rates, flatten=True)
62 | # r2 = runner.neuron_r2(rates, pred_rates)
63 | r2s.append(r2)
64 | vaf = runner.neuron_vaf(rates, pred_rates)
65 | print(f"R2:\t{r2}, VAF:\t{vaf}")
66 |
67 | trials, time, num_neurons = rates.size()
68 | print(runner.count_updates)
69 | print(runner.config.MODEL.EMBED_DIM)
70 |
71 | #%%
72 | import math
73 | # print(sum(mnlls) / 3)
74 | # print(mnlls)
75 | def print_ci(r2s):
76 | r2s = np.array(r2s)
77 | mean = r2s.mean()
78 | ci = r2s.std() * 1.96 / math.sqrt(3)
79 | print(f"{mean:.3f} \pm {ci:.3f}")
80 | print_ci(r2s)
81 | # print_ci([0.2079, 0.2077, 0.2091])
82 | # print_ci([
83 | # 0.9225, 0.9113, 0.9183
84 | # ])
85 | # print_ci([0.8712, 0.8687, 0.8664])
86 | # print_ci([0.9255, 0.9271, 0.9096])
87 | # print_ci([.924, .914, .9174])
88 | # print_ci([.0496, .0095, .0054])
89 | # print_ci([.4003, .0382, .4014])
90 | # print_ci([.52, .4469, .6242])
91 | print_ci([.416, .0388, .4221])
92 | print_ci([.0516, .0074, .0062])
93 | print_ci([
94 | 0.5184, 0.4567, 0.6724
95 | ])
96 |
97 | #%%
98 | # Port for rds analysis
99 | if "maze" in variant:
100 | rate_output_pth = f"/snel/share/joel/ndt_rates/"
101 | rate_output_pth = osp.join(rate_output_pth, "grid" if grid else "pbt")
102 | rate_output_fn = f"{variant}.h5"
103 | pred_rates, layer_outputs = runner.get_rates(
104 | checkpoint_path=ckpt_path,
105 | save_path=osp.join(rate_output_pth, rate_output_fn)
106 | )
107 | trials, time, num_neurons = pred_rates.size()
108 |
109 |
110 | #%%
111 |
112 | (
113 | unmasked_loss,
114 | pred_rates,
115 | layer_outputs,
116 | attn_weights,
117 | attn_list,
118 | ) = runner.model(spikes, mask_labels=spikes, return_weights=True)
119 | print(f"Unmasked: {unmasked_loss}")
120 |
121 | if "maze" not in variant:
122 | r2 = runner.neuron_r2(rates, pred_rates)
123 | vaf = runner.neuron_vaf(rates, pred_rates)
124 | print(f"R2:\t{r2}, VAF:\t{vaf}")
125 |
126 | trials, time, num_neurons = rates.size()
127 | print(f"Best Masked Val: {runner.best_val}") # best val is .380 at 1302...
128 | print(f"Best Unmasked Val: {runner.best_unmasked_val}") # .37763
129 | print(f"Best R2: {runner.best_R2}") # best val is .300 at 1302...
130 | print(runner.count_updates)
131 |
--------------------------------------------------------------------------------
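Note on the interval above: print_ci hard-codes the three seeds and uses numpy's population standard deviation. A minimal, more general sketch (hypothetical helper name, assuming a normal approximation with the sample standard deviation; the values are taken from the calls above):

    import math
    import numpy as np

    def print_ci_general(values, z=1.96):
        # Mean +/- z * standard error, for any number of seeds (assumes approximate normality).
        values = np.asarray(values, dtype=float)
        sem = values.std(ddof=1) / math.sqrt(len(values))  # sample std, not population std
        print(f"{values.mean():.3f} \\pm {z * sem:.3f}")

    print_ci_general([0.9225, 0.9113, 0.9183])
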
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [[ $# -eq 1 ]]
4 | then
5 | python -u src/run.py --run-type train --exp-config configs/$1.yaml
6 | elif [[ $# -eq 2 ]]
7 | then
8 | python -u src/run.py --run-type train --exp-config configs/$1.yaml --ckpt-path $1.$2.pth
9 | else
10 | echo "Usage: $0 <config name> [<ckpt suffix>]"
11 | fi
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Joel Ye
3 |
4 | from src.logger_wrapper import create_logger
5 | from src.model_registry import (
6 | get_model_class, is_learning_model, is_input_masked_model
7 | )
8 | from src.tb_wrapper import TensorboardWriter
9 |
10 | __all__ = [
11 | "get_model_class",
12 | "is_learning_model",
13 | "is_input_masked_model",
14 | "create_logger",
15 | "TensorboardWriter"
16 | ]
--------------------------------------------------------------------------------
/src/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snel-repo/neural-data-transformers/98dd85a24885ffb76adfeed0c2a89d3ea3ecf9d1/src/config/__init__.py
--------------------------------------------------------------------------------
/src/config/default.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Joel Ye
3 |
4 | from typing import List, Optional, Union
5 |
6 | from yacs.config import CfgNode as CN
7 |
8 | DEFAULT_CONFIG_DIR = "config/"
9 | CONFIG_FILE_SEPARATOR = ","
10 |
11 | # -----------------------------------------------------------------------------
12 | # Config definition
13 | # -----------------------------------------------------------------------------
14 | _C = CN()
15 | _C.SEED = 100
16 |
17 | # Name of experiment
18 | _C.VARIANT = "experiment"
19 | _C.USE_TENSORBOARD = True
20 | _C.TENSORBOARD_DIR = "tb/"
21 | _C.CHECKPOINT_DIR = "ckpts/"
22 | _C.LOG_DIR = "logs/"
23 |
24 | # -----------------------------------------------------------------------------
25 | # System
26 | # -----------------------------------------------------------------------------
27 | _C.SYSTEM = CN()
28 | _C.SYSTEM.TORCH_GPU_ID = 0
29 | _C.SYSTEM.GPU_AUTO_ASSIGN = True # Auto-assign
30 | _C.SYSTEM.NUM_GPUS = 1
31 |
32 | # -----------------------------------------------------------------------------
33 | # Data
34 | # -----------------------------------------------------------------------------
35 | _C.DATA = CN()
36 | _C.DATA.DATAPATH = 'data/'
37 | _C.DATA.TRAIN_FILENAME = 'train.pth'
38 | _C.DATA.VAL_FILENAME = 'val.pth'
39 | _C.DATA.TEST_FILENAME = 'test.pth'
40 | _C.DATA.OVERFIT_TEST = False
41 | _C.DATA.RANDOM_SUBSET_TRIALS = 1.0 # Testing how NDT performs on a variety of dataset sizes
42 |
43 | _C.DATA.LOG_EPSILON = 1e-7 # prevent -inf if we use logrates
44 | # _C.DATA.IGNORE_FORWARD = True # Ignore forward prediction even if it's available in train. Useful if we don't have forward spikes in validation. (system misbehaves if we only have train...)
45 | # Performance with above seems subpar, i.e. we need forward and heldout together for some reason
46 | _C.DATA.IGNORE_FORWARD = False
47 |
48 | # -----------------------------------------------------------------------------
49 | # Model
50 | # -----------------------------------------------------------------------------
51 | _C.MODEL = CN()
52 | _C.MODEL.NAME = "NeuralDataTransformer"
53 | _C.MODEL.TRIAL_LENGTH = -1 # -1 represents "as much as available" # ! not actually supported in model yet...
54 | _C.MODEL.CONTEXT_FORWARD = 4 # -1 represents "as much as available"
55 | _C.MODEL.CONTEXT_BACKWARD = 8 # -1 represents "as much as available"
56 | _C.MODEL.CONTEXT_WRAP_INITIAL = False
57 | _C.MODEL.FULL_CONTEXT = False # Ignores CONTEXT_FORWARD and CONTEXT_BACKWARD if True (-1 for both)
58 | _C.MODEL.UNMASKED_LOSS_SCALE = 0.0 # Relative scale for predicting unmasked spikes (kinda silly - deprecated)
59 | _C.MODEL.HIDDEN_SIZE = 128 # Generic hidden size, used as default
60 | _C.MODEL.DROPOUT = .1 # Catch all
61 | _C.MODEL.DROPOUT_RATES = 0.2 # Specific for rates
62 | _C.MODEL.DROPOUT_EMBEDDING = 0.2 # Dropout Population Activity pre-transformer
63 | _C.MODEL.NUM_HEADS = 2
64 | _C.MODEL.NUM_LAYERS = 6
65 | _C.MODEL.ACTIVATION = "relu" # "gelu"
66 | _C.MODEL.LINEAR_EMBEDDER = False # Use linear layer instead of embedding layer
67 | _C.MODEL.EMBED_DIM = 2 # this greatly affects model size btw
68 | _C.MODEL.LEARNABLE_POSITION = False
69 | _C.MODEL.MAX_SPIKE_COUNT = 20
70 | _C.MODEL.REQUIRES_RATES = False
71 | _C.MODEL.LOGRATE = True # If true, we operate in lograte, and assume rates from data are logrates. Only for R2 do we exp
72 | _C.MODEL.SPIKE_LOG_INIT = False # If true, init spike embeddings as a 0 centered linear sequence
73 | _C.MODEL.FIXUP_INIT = False
74 | _C.MODEL.PRE_NORM = False # per transformers without tears
75 | _C.MODEL.SCALE_NORM = False # per transformers without tears
76 |
77 | _C.MODEL.DECODER = CN()
78 | _C.MODEL.DECODER.LAYERS = 1
79 |
80 | _C.MODEL.LOSS = CN()
81 | _C.MODEL.LOSS.TYPE = "poisson" # ["cel", "poisson"]
82 | _C.MODEL.LOSS.TOPK = 1.0 # In case we're neglecting some neurons, focus on them
83 |
84 | _C.MODEL.POSITION = CN()
85 | _C.MODEL.POSITION.OFFSET = True
86 | # -----------------------------------------------------------------------------
87 | # Train Config
88 | # -----------------------------------------------------------------------------
89 | _C.TRAIN = CN()
90 |
91 | _C.TRAIN.DO_VAL = True # Run validation while training
92 | _C.TRAIN.DO_R2 = True # Compute R2 against ground-truth rates during validation (when rates are available)
93 |
94 | _C.TRAIN.BATCH_SIZE = 64
95 | _C.TRAIN.NUM_UPDATES = 10000 # Max updates
96 | _C.TRAIN.MAX_GRAD_NORM = 200.0
97 | _C.TRAIN.USE_ZERO_MASK = True
98 | _C.TRAIN.MASK_RATIO = 0.2
99 | _C.TRAIN.MASK_TOKEN_RATIO = 1.0 # We don't need this if we use zero mask
100 | _C.TRAIN.MASK_RANDOM_RATIO = 0.5 # Of the non-replaced, what percentage should be random?
101 | _C.TRAIN.MASK_MODE = "timestep" # ["full", "timestep"]
102 | _C.TRAIN.MASK_MAX_SPAN = 1
103 | _C.TRAIN.MASK_SPAN_RAMP_START = 600
104 | _C.TRAIN.MASK_SPAN_RAMP_END = 1200
105 |
106 | _C.TRAIN.LR = CN()
107 | _C.TRAIN.LR.INIT = 1e-3
108 | _C.TRAIN.LR.SCHEDULE = True
109 | _C.TRAIN.LR.SCHEDULER = "cosine" # invsqrt
110 | _C.TRAIN.LR.WARMUP = 1000 # Mostly decay
111 | _C.TRAIN.WEIGHT_DECAY = 0.0
112 | _C.TRAIN.EPS = 1e-8
113 | _C.TRAIN.PATIENCE = 750 # For early stopping (be generous; the loss tends to decrease in steps)
114 |
115 | _C.TRAIN.CHECKPOINT_INTERVAL = 1000
116 | _C.TRAIN.LOG_INTERVAL = 50
117 | _C.TRAIN.VAL_INTERVAL = 10 # Val less often so things run faster
118 |
119 | _C.TRAIN.TUNE_MODE = False
120 | _C.TRAIN.TUNE_EPOCHS_PER_GENERATION = 500
121 | _C.TRAIN.TUNE_HP_JSON = "./lorenz_pbt.json"
122 | _C.TRAIN.TUNE_WARMUP = 0
123 | _C.TRAIN.TUNE_METRIC = "smth_masked_loss"
124 | # JSON schema - flattened config dict. Each entry has info to construct a hyperparam.
125 |
126 | def get_cfg_defaults():
127 |     """Get the default NDT config (yacs CfgNode)."""
128 | return _C.clone()
129 |
130 | def get_config(
131 | config_paths: Optional[Union[List[str], str]] = None,
132 | opts: Optional[list] = None,
133 | ) -> CN:
134 | r"""Create a unified config with default values overwritten by values from
135 | :p:`config_paths` and overwritten by options from :p:`opts`.
136 |
137 | :param config_paths: List of config paths or string that contains comma
138 | separated list of config paths.
139 |     :param opts: Config options (keys, values) in a list (e.g., passed from
140 |         the command line into the config). For example,
141 | :py:`opts = ['FOO.BAR', 0.5]`. Argument can be used for parameter
142 | sweeping or quick tests.
143 | """
144 | config = get_cfg_defaults()
145 | if config_paths:
146 | if isinstance(config_paths, str):
147 | if CONFIG_FILE_SEPARATOR in config_paths:
148 | config_paths = config_paths.split(CONFIG_FILE_SEPARATOR)
149 | else:
150 | config_paths = [config_paths]
151 |
152 | for config_path in config_paths:
153 | config.merge_from_file(config_path)
154 |
155 | if opts:
156 | config.merge_from_list(opts)
157 |
158 | config.freeze()
159 | return config
160 |
161 |
162 | # The flatten and unflatten snippets are from an internal lfads_tf2 implementation.
163 |
164 | def flatten(dictionary, level=[]):
165 | """ Flattens a dictionary by placing '.' between levels.
166 |
167 | This function flattens a hierarchical dictionary by placing '.'
168 | between keys at various levels to create a single key for each
169 | value. It is used internally for converting the configuration
170 | dictionary to more convenient formats. Implementation was
171 | inspired by `this StackOverflow post
172 | `_.
173 |
174 | Parameters
175 | ----------
176 | dictionary : dict
177 | The hierarchical dictionary to be flattened.
178 |     level : list of str, optional
179 |         The list of parent keys to prepend to keys in this dictionary,
180 |         enabling recursive calls. By default, an empty list.
181 |
182 | Returns
183 | -------
184 | dict
185 | The flattened dictionary.
186 |
187 | See Also
188 | --------
189 | lfads_tf2.utils.unflatten : Performs the opposite of this operation.
190 |
191 | """
192 |
193 | tmp_dict = {}
194 | for key, val in dictionary.items():
195 | if type(val) == dict:
196 | tmp_dict.update(flatten(val, level + [key]))
197 | else:
198 | tmp_dict['.'.join(level + [key])] = val
199 | return tmp_dict
200 |
201 |
202 | def unflatten(dictionary):
203 | """ Unflattens a dictionary by splitting keys at '.'s.
204 |
205 | This function unflattens a hierarchical dictionary by splitting
206 | its keys at '.'s. It is used internally for converting the
207 | configuration dictionary to more convenient formats. Implementation was
208 | inspired by `this StackOverflow post
209 | `_.
210 |
211 | Parameters
212 | ----------
213 | dictionary : dict
214 | The flat dictionary to be unflattened.
215 |
216 | Returns
217 | -------
218 | dict
219 | The unflattened dictionary.
220 |
221 | See Also
222 | --------
223 | lfads_tf2.utils.flatten : Performs the opposite of this operation.
224 |
225 | """
226 |
227 | resultDict = dict()
228 | for key, value in dictionary.items():
229 | parts = key.split(".")
230 | d = resultDict
231 | for part in parts[:-1]:
232 | if part not in d:
233 | d[part] = dict()
234 | d = d[part]
235 | d[parts[-1]] = value
236 | return resultDict
237 |
238 |
--------------------------------------------------------------------------------
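As the docstrings above describe, get_config layers one or more YAML files and a flat opts list over the defaults, while flatten/unflatten convert between nested dicts and the dotted keys used by the tuner. A minimal usage sketch (the config path is illustrative):

    from src.config.default import get_config, flatten, unflatten

    # Merge an experiment YAML over the defaults, then override two keys from "the command line".
    config = get_config("configs/arxiv/lorenz.yaml", opts=["TRAIN.BATCH_SIZE", 32, "MODEL.NUM_LAYERS", 4])
    print(config.VARIANT, config.TRAIN.BATCH_SIZE, config.MODEL.NUM_LAYERS)

    # Round-trip between the nested and dotted-key representations.
    flat = flatten({"TRAIN": {"LR": {"INIT": 1e-3}}})
    assert flat == {"TRAIN.LR.INIT": 1e-3}
    assert unflatten(flat) == {"TRAIN": {"LR": {"INIT": 1e-3}}}
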
/src/dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Joel Ye
3 | import os.path as osp
4 |
5 | import h5py
6 | import numpy as np
7 | import torch
8 | from torch.utils import data
9 |
10 | from src.utils import merge_train_valid
11 |
12 | class DATASET_MODES:
13 | train = "train"
14 | val = "val"
15 | test = "test"
16 | trainval = "trainval"
17 |
18 | class SpikesDataset(data.Dataset):
19 | r"""
20 | Dataset for single file of spike times (loads into memory)
21 | Lorenz data is NxTxH (H being number of neurons) - we load to T x N x H
22 | # ! Note that codepath for forward but not heldout neurons is not tested and likely broken
23 | """
24 |
25 | def __init__(self, config, filename, mode=DATASET_MODES.train, logger=None):
26 | r"""
27 | args:
28 | config: dataset config
29 | filename: excluding path
30 | mode: used to extract the right indices from LFADS h5 data
31 | """
32 | super().__init__()
33 | self.logger = logger
34 | if self.logger is not None:
35 | self.logger.info(f"Loading {filename} in {mode}")
36 | self.config = config.DATA
37 | self.use_lograte = config.MODEL.LOGRATE
38 | self.batch_size = config.TRAIN.BATCH_SIZE
39 | self.datapath = osp.join(config.DATA.DATAPATH, filename)
40 | split_path = self.datapath.split(".")
41 |
42 | self.has_rates = False
43 | self.has_heldout = False
44 | self.has_forward = False
45 | if len(split_path) == 1 or split_path[-1] == "h5":
46 | spikes, rates, heldout_spikes, forward_spikes = self.get_data_from_h5(mode, self.datapath)
47 |
48 | spikes = torch.tensor(spikes).long()
49 | if rates is not None:
50 | rates = torch.tensor(rates)
51 | if heldout_spikes is not None:
52 | self.has_heldout = True
53 | heldout_spikes = torch.tensor(heldout_spikes).long()
54 | if forward_spikes is not None and not config.DATA.IGNORE_FORWARD:
55 | self.has_forward = True
56 | forward_spikes = torch.tensor(forward_spikes).long()
57 | else:
58 | forward_spikes = None
59 | elif split_path[-1] == "pth":
60 | dataset_dict = torch.load(self.datapath)
61 | spikes = dataset_dict["spikes"]
62 | if "rates" in dataset_dict:
63 | self.has_rates = True
64 | rates = dataset_dict["rates"]
65 | heldout_spikes = None
66 | forward_spikes = None
67 | else:
68 | raise Exception(f"Unknown dataset extension {split_path[-1]}")
69 |
70 | self.num_trials, _, self.num_neurons = spikes.size()
71 | self.full_length = config.MODEL.TRIAL_LENGTH <= 0
72 | self.trial_length = spikes.size(1) if self.full_length else config.MODEL.TRIAL_LENGTH
73 | if self.has_heldout:
74 | self.num_neurons += heldout_spikes.size(-1)
75 | if self.has_forward:
76 | self.trial_length += forward_spikes.size(1)
77 | self.spikes = self.batchify(spikes)
78 | # Fake rates so we can skip None checks everywhere. Use `self.has_rates` when desired
79 | self.rates = self.batchify(rates) if self.has_rates else torch.zeros_like(spikes)
80 |         # * The else branches below are not exactly the right shape, but the shape is unused in that case
81 | self.heldout_spikes = self.batchify(heldout_spikes) if self.has_heldout else torch.zeros_like(spikes)
82 | self.forward_spikes = self.batchify(forward_spikes) if self.has_forward else torch.zeros_like(spikes)
83 |
84 | if config.DATA.OVERFIT_TEST:
85 | if self.logger is not None:
86 | self.logger.warning("Overfitting..")
87 | self.spikes = self.spikes[:2]
88 | self.rates = self.rates[:2]
89 | self.num_trials = 2
90 | elif hasattr(config.DATA, "RANDOM_SUBSET_TRIALS") and config.DATA.RANDOM_SUBSET_TRIALS < 1.0 and mode == DATASET_MODES.train:
91 | if self.logger is not None:
92 | self.logger.warning(f"!!!!! Training on {config.DATA.RANDOM_SUBSET_TRIALS} of the data with seed {config.SEED}.")
93 | reduced = int(self.num_trials * config.DATA.RANDOM_SUBSET_TRIALS)
94 | torch.random.manual_seed(config.SEED)
95 | random_subset = torch.randperm(self.num_trials)[:reduced]
96 | self.num_trials = reduced
97 | self.spikes = self.spikes[random_subset]
98 | self.rates = self.rates[random_subset]
99 |
100 | def batchify(self, x):
101 | r"""
102 | Chops data into uniform sizes as configured by trial_length.
103 |
104 | Returns:
105 | x reshaped as num_samples x trial_length x neurons
106 | """
107 | if self.full_length:
108 | return x
109 | trial_time = x.size(1)
110 | samples_per_trial = trial_time // self.trial_length
111 | if trial_time % self.trial_length != 0:
112 | if self.logger is not None:
113 |                 self.logger.debug(f"Trimming dangling trial info. Data trial length {trial_time} \
114 |                 is not divisible by the requested length {self.trial_length}.")
115 | x = x.narrow(1, 0, samples_per_trial * self.trial_length)
116 |
117 |         # ! Pretty sure this could be a view op instead of a copy
118 | # num_samples x trial_length x neurons
119 | return torch.cat(torch.split(x, self.trial_length, dim=1), dim=0)
120 |
121 | def get_num_neurons(self):
122 | return self.num_neurons
123 |
124 | def __len__(self):
125 | return self.spikes.size(0)
126 |
127 | def __getitem__(self, index):
128 | r"""
129 | Return spikes and rates, shaped T x N (num_neurons)
130 | """
131 | return (
132 | self.spikes[index],
133 | None if self.rates is None else self.rates[index],
134 | None if self.heldout_spikes is None else self.heldout_spikes[index],
135 | None if self.forward_spikes is None else self.forward_spikes[index]
136 | )
137 |
138 | def get_dataset(self):
139 | return self.spikes, self.rates, self.heldout_spikes, self.forward_spikes
140 |
141 | def get_max_spikes(self):
142 | return self.spikes.max().item()
143 |
144 | def get_num_batches(self):
145 | return self.spikes.size(0) // self.batch_size
146 |
147 | def clip_spikes(self, max_val):
148 | self.spikes = torch.clamp(self.spikes, max=max_val)
149 |
150 | def get_data_from_h5(self, mode, filepath):
151 | r"""
152 | returns:
153 | spikes
154 | rates (None if not available)
155 | held out spikes (for cosmoothing, None if not available)
156 | * Note, rates and held out spikes codepaths conflict
157 | """
158 | NLB_KEY = 'spikes' # curiously, old code thought NLB data keys came as "train_data_heldin" and not "train_spikes_heldin"
159 | NLB_KEY_ALT = 'data'
160 |
161 | with h5py.File(filepath, 'r') as h5file:
162 | h5dict = {key: h5file[key][()] for key in h5file.keys()}
163 | if f'eval_{NLB_KEY}_heldin' not in h5dict: # double check
164 | if f'eval_{NLB_KEY_ALT}_heldin' in h5dict:
165 | NLB_KEY = NLB_KEY_ALT
166 | if f'eval_{NLB_KEY}_heldin' in h5dict: # NLB data, presumes both heldout neurons and time are available
167 | get_key = lambda key: h5dict[key].astype(np.float32)
168 | train_data = get_key(f'train_{NLB_KEY}_heldin')
169 | train_data_fp = get_key(f'train_{NLB_KEY}_heldin_forward')
170 | train_data_heldout_fp = get_key(f'train_{NLB_KEY}_heldout_forward')
171 | train_data_all_fp = np.concatenate([train_data_fp, train_data_heldout_fp], -1)
172 | valid_data = get_key(f'eval_{NLB_KEY}_heldin')
173 | train_data_heldout = get_key(f'train_{NLB_KEY}_heldout')
174 | if f'eval_{NLB_KEY}_heldout' in h5dict:
175 | valid_data_heldout = get_key(f'eval_{NLB_KEY}_heldout')
176 | else:
177 | self.logger.warn('Substituting zero array for heldout neurons. Only done for evaluating models locally, i.e. will disrupt training due to early stopping.')
178 | valid_data_heldout = np.zeros((valid_data.shape[0], valid_data.shape[1], train_data_heldout.shape[2]), dtype=np.float32)
179 | if f'eval_{NLB_KEY}_heldin_forward' in h5dict:
180 | valid_data_fp = get_key(f'eval_{NLB_KEY}_heldin_forward')
181 | valid_data_heldout_fp = get_key(f'eval_{NLB_KEY}_heldout_forward')
182 | valid_data_all_fp = np.concatenate([valid_data_fp, valid_data_heldout_fp], -1)
183 | else:
184 | self.logger.warn('Substituting zero array for heldout forward neurons. Only done for evaluating models locally, i.e. will disrupt training due to early stopping.')
185 | valid_data_all_fp = np.zeros(
186 | (valid_data.shape[0], train_data_fp.shape[1], valid_data.shape[2] + valid_data_heldout.shape[2]), dtype=np.float32
187 | )
188 |
189 | # NLB data does not have ground truth rates
190 | if mode == DATASET_MODES.train:
191 | return train_data, None, train_data_heldout, train_data_all_fp
192 | elif mode == DATASET_MODES.val:
193 | return valid_data, None, valid_data_heldout, valid_data_all_fp
194 | train_data = h5dict['train_data'].astype(np.float32).squeeze()
195 | valid_data = h5dict['valid_data'].astype(np.float32).squeeze()
196 | train_rates = None
197 | valid_rates = None
198 |             if "train_truth" in h5dict and "valid_truth" in h5dict: # original LFADS-type datasets
199 | self.has_rates = True
200 | train_rates = h5dict['train_truth'].astype(np.float32)
201 | valid_rates = h5dict['valid_truth'].astype(np.float32)
202 | train_rates = train_rates / h5dict['conversion_factor']
203 | valid_rates = valid_rates / h5dict['conversion_factor']
204 | if self.use_lograte:
205 | train_rates = torch.log(torch.tensor(train_rates) + self.config.LOG_EPSILON)
206 | valid_rates = torch.log(torch.tensor(valid_rates) + self.config.LOG_EPSILON)
207 | if mode == DATASET_MODES.train:
208 | return train_data, train_rates, None, None
209 | elif mode == DATASET_MODES.val:
210 | return valid_data, valid_rates, None, None
211 | elif mode == DATASET_MODES.trainval:
212 | # merge training and validation data
213 | if 'train_inds' in h5dict and 'valid_inds' in h5dict:
214 | # if there are index labels, use them to reassemble full data
215 | train_inds = h5dict['train_inds'].squeeze()
216 | valid_inds = h5dict['valid_inds'].squeeze()
217 | file_data = merge_train_valid(
218 | train_data, valid_data, train_inds, valid_inds)
219 | if self.has_rates:
220 | merged_rates = merge_train_valid(
221 | train_rates, valid_rates, train_inds, valid_inds
222 | )
223 | else:
224 | if self.logger is not None:
225 | self.logger.info("No indices found for merge. "
226 | "Concatenating training and validation samples.")
227 | file_data = np.concatenate([train_data, valid_data], axis=0)
228 | if self.has_rates:
229 | merged_rates = np.concatenate([train_rates, valid_rates], axis=0)
230 | return file_data, merged_rates if self.has_rates else None, None, None
231 | else: # test unsupported
232 | return None, None, None, None
--------------------------------------------------------------------------------
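A minimal sketch of constructing SpikesDataset from a config and wrapping it in a DataLoader; the config path is illustrative and the clipping call just mirrors MODEL.MAX_SPIKE_COUNT, which may not be exactly how the runner wires it up:

    import torch
    from torch.utils import data

    from src.config.default import get_config
    from src.dataset import SpikesDataset, DATASET_MODES

    config = get_config("configs/arxiv/lorenz.yaml")  # illustrative config path
    train_set = SpikesDataset(config, config.DATA.TRAIN_FILENAME, mode=DATASET_MODES.train)
    train_set.clip_spikes(config.MODEL.MAX_SPIKE_COUNT)  # cap spike counts before embedding
    loader = data.DataLoader(train_set, batch_size=config.TRAIN.BATCH_SIZE, shuffle=True)

    spikes, rates, heldout_spikes, forward_spikes = next(iter(loader))
    print(spikes.shape)  # batch x trial_length x num_neurons
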
/src/logger_wrapper.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Adapted from Facebook Habitat Framework
4 |
5 | import logging
6 | import copy
7 |
8 | class Logger(logging.Logger):
9 | def __init__(
10 | self,
11 | name,
12 | level,
13 | filename=None,
14 | filemode="a",
15 | stream=None,
16 | format=None,
17 | dateformat=None,
18 | style="%",
19 | ):
20 | super().__init__(name, level)
21 | if filename is not None:
22 | handler = logging.FileHandler(filename, filemode)
23 | else:
24 | handler = logging.StreamHandler(stream)
25 | self._formatter = logging.Formatter(format, dateformat, style)
26 | handler.setFormatter(self._formatter)
27 | super().addHandler(handler)
28 | self.stat_queue = [] # Going to be tuples
29 |
30 | def clear_filehandlers(self):
31 | self.handlers = [h for h in self.handlers if not isinstance(h, logging.FileHandler)]
32 |
33 | def clear_streamhandlers(self):
34 | self.handlers = [h for h in self.handlers if (not isinstance(h, logging.StreamHandler) or isinstance(h, logging.FileHandler))]
35 |
36 | def add_filehandler(self, log_filename):
37 | filehandler = logging.FileHandler(log_filename)
38 | filehandler.setFormatter(self._formatter)
39 | self.addHandler(filehandler)
40 |
41 | def queue_stat(self, stat_name, stat):
42 | self.stat_queue.append((stat_name, stat))
43 |
44 | def empty_queue(self):
45 | queue = copy.deepcopy(self.stat_queue)
46 | self.stat_queue = []
47 | return queue
48 |
49 | def log_update(self, update):
50 | stat_str = "\t".join([f"{stat[0]}: {stat[1]:.3f}" for stat in self.empty_queue()])
51 | self.info("update: {}\t{}".format(update, stat_str))
52 |
53 | def mute(self):
54 | self.setLevel(logging.ERROR)
55 |
56 | def unmute(self):
57 | self.setLevel(logging.INFO)
58 |
59 | def create_logger():
60 | return Logger(
61 | name="NDT", level=logging.INFO, format="%(asctime)-15s %(message)s"
62 | )
63 |
64 | __all__ = ["create_logger"]
65 |
--------------------------------------------------------------------------------
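A small sketch of the queue-and-flush pattern this logger supports: stats are queued during an update and emitted as a single tab-separated line by log_update (values are made up):

    from src.logger_wrapper import create_logger

    logger = create_logger()
    logger.add_filehandler("example.log")  # optionally tee output to a file (illustrative filename)

    logger.queue_stat("masked_loss", 0.412)
    logger.queue_stat("lr", 1e-3)
    logger.log_update(100)  # -> "update: 100    masked_loss: 0.412    lr: 0.001"
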
/src/mask.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Joel Ye
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | # Some infeasibly high spike count
9 | DEFAULT_MASK_VAL = 30
10 | UNMASKED_LABEL = -100
11 | SUPPORTED_MODES = ["full", "timestep", "neuron", "timestep_only"]
12 |
13 | # Use a class so we can cache random mask
14 | class Masker:
15 |
16 | def __init__(self, train_cfg, device):
17 | self.update_config(train_cfg)
18 | if self.cfg.MASK_MODE not in SUPPORTED_MODES:
19 | raise Exception(f"Given {self.cfg.MASK_MODE} not in supported {SUPPORTED_MODES}")
20 | self.device = device
21 |
22 | def update_config(self, config):
23 | self.cfg = config
24 | self.prob_mask = None
25 |
26 | def expand_mask(self, mask, width):
27 | r"""
28 | args:
29 | mask: N x T
30 | width: expansion block size
31 | """
32 | kernel = torch.ones(width, device=mask.device).view(1, 1, -1)
33 |         expanded_mask = F.conv1d(mask.unsqueeze(1), kernel, padding=width // 2).clamp_(0, 1)
34 | if width % 2 == 0:
35 | expanded_mask = expanded_mask[...,:-1] # crop if even (we've added too much padding)
36 | return expanded_mask.squeeze(1)
37 |
38 | def mask_batch(
39 | self,
40 | batch,
41 | mask=None,
42 | max_spikes=DEFAULT_MASK_VAL - 1,
43 | should_mask=True,
44 | expand_prob=0.0,
45 | heldout_spikes=None,
46 | forward_spikes=None,
47 | ):
48 | r""" Given complete batch, mask random elements and return true labels separately.
49 | Modifies batch OUT OF place!
50 | Modeled after HuggingFace's `mask_tokens` in `run_language_modeling.py`
51 | args:
52 | batch: batch NxTxH
53 |             mask_ratio: ratio to randomly mask (pulled from the train config, MASK_RATIO)
54 |             mode: masking mode (pulled from the train config, MASK_MODE) - "full" randomly drops on the full matrix, "timestep" masks out whole random timesteps
55 |             mask: Optional mask to use
56 |             max_spikes: in case not zero masking, "mask token"
57 |             expand_prob: with this prob, uniformly expand spans up to MASK_MAX_SPAN; else, keep single tokens. UniLM uses 40% expand to fixed, else keep single.
58 |             heldout_spikes: heldout spikes to append as fully masked targets (or None)
59 |         returns:
60 |             batch: data batch NxTxH with masked elements replaced (zeroed or set to the mask token, a fraction randomized)
61 | labels: true data (also NxTxH)
62 | """
63 | batch = batch.clone() # make sure we don't corrupt the input data (which is stored in memory)
64 |
65 | mode = self.cfg.MASK_MODE
66 | should_expand = self.cfg.MASK_MAX_SPAN > 1 and expand_prob > 0.0 and torch.rand(1).item() < expand_prob
67 | width = torch.randint(1, self.cfg.MASK_MAX_SPAN + 1, (1, )).item() if should_expand else 1
68 | mask_ratio = self.cfg.MASK_RATIO if width == 1 else self.cfg.MASK_RATIO / width
69 |
70 | labels = batch.clone()
71 | if mask is None:
72 | if self.prob_mask is None or self.prob_mask.size() != labels.size():
73 | if mode == "full":
74 | mask_probs = torch.full(labels.shape, mask_ratio)
75 | elif mode == "timestep":
76 | single_timestep = labels[:, :, 0] # N x T
77 | mask_probs = torch.full(single_timestep.shape, mask_ratio)
78 | elif mode == "neuron":
79 | single_neuron = labels[:, 0] # N x H
80 | mask_probs = torch.full(single_neuron.shape, mask_ratio)
81 | elif mode == "timestep_only":
82 | single_timestep = labels[0, :, 0] # T
83 | mask_probs = torch.full(single_timestep.shape, mask_ratio)
84 | self.prob_mask = mask_probs.to(self.device)
85 | # If we want any tokens to not get masked, do it here (but we don't currently have any)
86 | mask = torch.bernoulli(self.prob_mask)
87 |
88 | # N x T
89 | if width > 1:
90 | mask = self.expand_mask(mask, width)
91 |
92 | mask = mask.bool()
93 | if mode == "timestep":
94 | mask = mask.unsqueeze(2).expand_as(labels)
95 | elif mode == "neuron":
96 | mask = mask.unsqueeze(0).expand_as(labels)
97 | elif mode == "timestep_only":
98 | mask = mask.unsqueeze(0).unsqueeze(2).expand_as(labels)
99 | # we want the shape of the mask to be T
100 | elif mask.size() != labels.size():
101 | raise Exception(f"Input mask of size {mask.size()} does not match input size {labels.size()}")
102 |
103 | labels[~mask] = UNMASKED_LABEL # No ground truth for unmasked - use this to mask loss
104 | if not should_mask:
105 | # Only do the generation
106 | return batch, labels
107 |
108 | # We use random assignment so the model learns embeddings for non-mask tokens, and must rely on context
109 | # Most times, we replace tokens with MASK token
110 | indices_replaced = torch.bernoulli(torch.full(labels.shape, self.cfg.MASK_TOKEN_RATIO, device=mask.device)).bool() & mask
111 | if self.cfg.USE_ZERO_MASK:
112 | batch[indices_replaced] = 0
113 | else:
114 | batch[indices_replaced] = max_spikes + 1
115 |
116 | # Random % of the time, we replace masked input tokens with random value (the rest are left intact)
117 | indices_random = torch.bernoulli(torch.full(labels.shape, self.cfg.MASK_RANDOM_RATIO, device=mask.device)).bool() & mask & ~indices_replaced
118 | random_spikes = torch.randint(batch.max(), labels.shape, dtype=torch.long, device=batch.device)
119 | batch[indices_random] = random_spikes[indices_random]
120 |
121 | if heldout_spikes is not None:
122 | # heldout spikes are all masked
123 | batch = torch.cat([batch, torch.zeros_like(heldout_spikes, device=batch.device)], -1)
124 | labels = torch.cat([labels, heldout_spikes.to(batch.device)], -1)
125 | if forward_spikes is not None:
126 | batch = torch.cat([batch, torch.zeros_like(forward_spikes, device=batch.device)], 1)
127 | labels = torch.cat([labels, forward_spikes.to(batch.device)], 1)
128 |         # Any remaining masked tokens (neither replaced nor randomized) are left intact
129 | return batch, labels
--------------------------------------------------------------------------------
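A minimal sketch of masking a batch of spike counts with the default train config; it roughly mirrors how a training loop would be expected to call mask_batch, though the actual runner wiring may differ:

    import torch

    from src.config.default import get_cfg_defaults
    from src.mask import Masker, UNMASKED_LABEL

    cfg = get_cfg_defaults()
    masker = Masker(cfg.TRAIN, device="cpu")

    spikes = torch.randint(0, 5, (8, 50, 29))  # B x T x N integer spike counts
    masked_batch, labels = masker.mask_batch(spikes, max_spikes=5)

    # Only masked positions carry supervision; everything else is UNMASKED_LABEL.
    print((labels != UNMASKED_LABEL).float().mean())  # roughly MASK_RATIO of positions
    print((masked_batch[labels != UNMASKED_LABEL] == 0).float().mean())  # zero-masked by default
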
/src/model_baselines.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Joel Ye
3 | import math
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn import Transformer, TransformerEncoder, TransformerEncoderLayer
9 | from torch.distributions import Poisson
10 |
11 | from src.utils import binary_mask_to_attn_mask
12 | from src.mask import UNMASKED_LABEL
13 |
14 | class RatesOracle(nn.Module):
15 |
16 | def __init__(self, config, num_neurons, device, **kwargs):
17 | super().__init__()
18 | assert config.REQUIRES_RATES == True, "Oracle requires rates"
19 | if config.LOSS.TYPE == "poisson":
20 | self.classifier = nn.PoissonNLLLoss(reduction='none', log_input=config.LOGRATE)
21 | else:
22 |             raise Exception(f"Loss type {config.LOSS.TYPE} not supported")
23 |
24 | def get_hidden_size(self):
25 | return 0
26 |
27 | def forward(self, src, mask_labels, rates=None, **kwargs):
28 | # output is t x b x neurons (rate predictions)
29 | loss = self.classifier(rates, mask_labels)
30 |         # Mask out losses at unmasked labels
31 | masked_loss = loss[mask_labels != UNMASKED_LABEL]
32 | masked_loss = masked_loss.mean()
33 | return (
34 | masked_loss.unsqueeze(0),
35 | rates,
36 | None,
37 | torch.tensor(0, device=masked_loss.device, dtype=torch.float),
38 | None,
39 | None,
40 | )
41 |
42 | class RandomModel(nn.Module):
43 | # Guess a random rate in LOGRATE_RANGE
44 | # Purpose - why is our initial loss so close to our final loss
45 | LOGRATE_RANGE = (-2.5, 2.5)
46 |
47 | def __init__(self, config, num_neurons, device):
48 | super().__init__()
49 | self.device = device
50 | if config.LOSS.TYPE == "poisson":
51 | self.classifier = nn.PoissonNLLLoss(reduction='none', log_input=config.LOGRATE)
52 | else:
53 |             raise Exception(f"Loss type {config.LOSS.TYPE} not supported")
54 |
55 | def forward(self, src, mask_labels, *args, **kwargs):
56 | # output is t x b x neurons (rate predictions)
57 | rates = torch.rand(mask_labels.size(), dtype=torch.float32).to(self.device)
58 | rates *= (self.LOGRATE_RANGE[1] - self.LOGRATE_RANGE[0])
59 | rates += self.LOGRATE_RANGE[0]
60 | loss = self.classifier(rates, mask_labels)
61 |         # Zero out losses at unmasked labels
62 | loss[mask_labels == UNMASKED_LABEL] = 0.0
63 |
64 | return loss.mean(), rates
65 |
--------------------------------------------------------------------------------
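Both baselines reduce to a masked Poisson negative log-likelihood between (log-)rates and spike labels. A standalone sketch of that loss computation, independent of the model classes (shapes and values are made up):

    import torch
    import torch.nn as nn

    from src.mask import UNMASKED_LABEL

    criterion = nn.PoissonNLLLoss(reduction="none", log_input=True)  # log_input matches MODEL.LOGRATE

    log_rates = torch.randn(4, 50, 29)                      # predicted lograte, B x T x N
    labels = torch.randint(0, 5, (4, 50, 29)).float()       # true spike counts
    labels[torch.rand_like(labels) > 0.2] = UNMASKED_LABEL  # pretend ~80% of positions are unmasked

    loss = criterion(log_rates, labels)
    masked_loss = loss[labels != UNMASKED_LABEL].mean()     # only masked positions contribute
    print(masked_loss)
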
/src/model_registry.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Joel Ye
3 |
4 | from src.model import (
5 | NeuralDataTransformer,
6 | )
7 |
8 | from src.model_baselines import (
9 | RatesOracle,
10 | RandomModel,
11 | )
12 |
13 | LEARNING_MODELS = {
14 | "NeuralDataTransformer": NeuralDataTransformer,
15 | }
16 |
17 | NONLEARNING_MODELS = {
18 | "Oracle": RatesOracle,
19 | "Random": RandomModel
20 | }
21 |
22 | INPUT_MASKED_MODELS = {
23 | "NeuralDataTransformer": NeuralDataTransformer,
24 | }
25 |
26 | MODELS = {**LEARNING_MODELS, **NONLEARNING_MODELS, **INPUT_MASKED_MODELS}
27 |
28 | def is_learning_model(model_name):
29 | return model_name in LEARNING_MODELS
30 |
31 | def is_input_masked_model(model_name):
32 | return model_name in INPUT_MASKED_MODELS
33 |
34 | def get_model_class(model_name):
35 | return MODELS[model_name]
--------------------------------------------------------------------------------
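A short sketch of the registry in use; the flags tell the runner whether a model needs an optimizer and whether its inputs should be masked before the forward pass:

    from src.model_registry import get_model_class, is_learning_model, is_input_masked_model

    name = "NeuralDataTransformer"
    ModelClass = get_model_class(name)        # class to instantiate from config

    print(is_learning_model(name))            # True  -> trained with an optimizer
    print(is_input_masked_model(name))        # True  -> inputs masked before the forward pass
    print(is_learning_model("Oracle"))        # False -> non-learning baseline
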
/src/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Joel Ye
3 |
4 | from typing import List, Union
5 | import os.path as osp
6 | import shutil
7 | import random
8 |
9 | import argparse
10 | import numpy as np
11 | import torch
12 |
13 | from src.config.default import get_config
14 | from src.runner import Runner
15 |
16 | DO_PRESERVE_RUNS = False
17 |
18 | def get_parser():
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument(
21 | "--run-type",
22 | choices=["train", "eval"],
23 | required=True,
24 | help="run type of the experiment (train or eval)",
25 | )
26 |
27 | parser.add_argument(
28 | "--exp-config",
29 | type=str,
30 | required=True,
31 | help="path to config yaml containing info about experiment",
32 | )
33 |
34 | parser.add_argument(
35 | "--ckpt-path",
36 | default=None,
37 | type=str,
38 | help="full path to a ckpt (for eval or resumption)"
39 | )
40 |
41 | parser.add_argument(
42 | "--clear-only",
43 | default=False,
44 | type=bool,
45 | )
46 |
47 | parser.add_argument(
48 | "opts",
49 | default=None,
50 | nargs=argparse.REMAINDER,
51 | help="Modify config options from command line",
52 | )
53 | return parser
54 |
55 | def main():
56 | parser = get_parser()
57 | args = parser.parse_args()
58 | run_exp(**vars(args))
59 |
60 | def check_exists(path, preserve=DO_PRESERVE_RUNS):
61 | if osp.exists(path):
62 | # logger.warn(f"{path} exists")
63 | print(f"{path} exists")
64 | if not preserve:
65 | # logger.warn(f"removing {path}")
66 | print(f"removing {path}")
67 | shutil.rmtree(path, ignore_errors=True)
68 | return True
69 | return False
70 |
71 |
72 | def prepare_config(exp_config: Union[List[str], str], run_type: str, ckpt_path="", opts=None, suffix=None):
73 | r"""Prepare config node / do some preprocessing
74 |
75 | Args:
76 | exp_config: path to config file.
77 |         run_type: "train" or "eval".
78 | ckpt_path: If training, ckpt to resume. If evaluating, ckpt to evaluate.
79 | opts: list of strings of additional config options.
80 |
81 |     Returns:
82 |         config, ckpt_path
83 | """
84 | config = get_config(exp_config, opts)
85 |
86 | # Default behavior is to pull experiment name from config file
87 | # Bind variant name to directories
88 | if isinstance(exp_config, str):
89 | variant_config = exp_config
90 | else:
91 | variant_config = exp_config[-1]
92 | variant_name = osp.split(variant_config)[1].split('.')[0]
93 | config.defrost()
94 | config.VARIANT = variant_name
95 | if suffix is not None:
96 | config.TENSORBOARD_DIR = osp.join(config.TENSORBOARD_DIR, suffix)
97 | config.CHECKPOINT_DIR = osp.join(config.CHECKPOINT_DIR, suffix)
98 | config.LOG_DIR = osp.join(config.LOG_DIR, suffix)
99 | config.TENSORBOARD_DIR = osp.join(config.TENSORBOARD_DIR, config.VARIANT)
100 | config.CHECKPOINT_DIR = osp.join(config.CHECKPOINT_DIR, config.VARIANT)
101 | config.LOG_DIR = osp.join(config.LOG_DIR, config.VARIANT)
102 | config.freeze()
103 |
104 | if ckpt_path is not None:
105 | if not osp.exists(ckpt_path):
106 | ckpt_path = osp.join(config.CHECKPOINT_DIR, ckpt_path)
107 |
108 | np.random.seed(config.SEED)
109 | random.seed(config.SEED)
110 | torch.random.manual_seed(config.SEED)
111 | torch.backends.cudnn.deterministic = True
112 |
113 | return config, ckpt_path
114 |
115 | def run_exp(exp_config: Union[List[str], str], run_type: str, ckpt_path="", clear_only=False, opts=None, suffix=None) -> None:
116 | config, ckpt_path = prepare_config(exp_config, run_type, ckpt_path, opts, suffix=suffix)
117 | if clear_only:
118 | check_exists(config.TENSORBOARD_DIR, preserve=False)
119 | check_exists(config.CHECKPOINT_DIR, preserve=False)
120 | check_exists(config.LOG_DIR, preserve=False)
121 | exit(0)
122 | if run_type == "train":
123 | if ckpt_path is not None:
124 | runner = Runner(config)
125 | runner.train(checkpoint_path=ckpt_path)
126 | else:
127 | if DO_PRESERVE_RUNS:
128 | if check_exists(config.TENSORBOARD_DIR) or \
129 | check_exists(config.CHECKPOINT_DIR) or \
130 | check_exists(config.LOG_DIR):
131 | exit(1)
132 | else:
133 | check_exists(config.TENSORBOARD_DIR)
134 | check_exists(config.CHECKPOINT_DIR)
135 | check_exists(config.LOG_DIR)
136 | runner = Runner(config)
137 | runner.train()
138 | # * The evaluation code path has legacy code (and will not run).
139 | # * Evaluation / analysis is done in analysis scripts.
140 | # elif run_type == "eval":
141 | # runner.eval(checkpoint_path=ckpt_path)
142 |
143 | if __name__ == "__main__":
144 | main()
145 |
--------------------------------------------------------------------------------
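The analysis scripts above import prepare_config directly rather than going through main. A minimal sketch of both entry points, mirroring what run_exp does (the config path is illustrative):

    # Programmatic use, as the analysis scripts do:
    from src.run import prepare_config
    from src.runner import Runner

    config, ckpt_path = prepare_config(
        "configs/arxiv/lorenz.yaml", "train", ckpt_path=None, opts=["TRAIN.NUM_UPDATES", 500]
    )
    runner = Runner(config)
    runner.train()

    # Equivalent command-line launch (what scripts/train.sh wraps):
    #   python -u src/run.py --run-type train --exp-config configs/arxiv/lorenz.yaml
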
/src/tb_wrapper.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 |
3 | TensorboardWriter = SummaryWriter
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Joel Ye
3 |
4 | import numpy as np
5 | import torch
6 | from torch.optim.lr_scheduler import LambdaLR
7 |
8 | def binary_mask_to_attn_mask(x):
9 | return x.float().masked_fill(x == 0, float('-inf')).masked_fill(x == 1, float(0.0))
10 |
11 | # Verbatim from LFADS_TF2
12 | def merge_train_valid(train_data, valid_data, train_ixs, valid_ixs):
13 | """Merges training and validation numpy arrays using indices.
14 |
15 | This function merges training and validation numpy arrays
16 | in the appropriate order using arrays of their indices. The
17 | lengths of the indices must be the same as the first dimension
18 | of the corresponding data.
19 |
20 | Parameters
21 | ----------
22 | train_data : np.ndarray
23 | An N-dimensional numpy array of training data with
24 | first dimension T.
25 | valid_data : np.ndarray
26 | An N-dimensional numpy array of validation data with
27 | first dimension V.
28 | train_ixs : np.ndarray
29 | A 1-D numpy array of training indices with length T.
30 | valid_ixs : np.ndarray
31 | A 1-D numpy array of validation indices with length V.
32 |
33 | Returns
34 | -------
35 | np.ndarray
36 |         An N-dimensional numpy array with first dimension T + V.
37 |
38 | """
39 |
40 | if train_data.shape[0] == train_ixs.shape[0] \
41 | and valid_data.shape[0] == valid_ixs.shape[0]:
42 | # if the indices match up, then we can use them to merge
43 | data = np.full_like(np.concatenate([train_data, valid_data]), np.nan)
44 | if min(min(train_ixs), min(valid_ixs)) > 0:
45 | # we've got matlab data...
46 | train_ixs -= 1
47 | valid_ixs -= 1
48 | data[train_ixs.astype(int)] = train_data
49 | data[valid_ixs.astype(int)] = valid_data
50 | else:
51 | # if the indices do not match, train and
52 | # valid data may be the same (e.g. for priors)
53 | if np.all(train_data == valid_data):
54 | data = train_data
55 | else:
56 | raise ValueError("shape mismatch: "
57 | f"Index shape {train_ixs.shape} does not "
58 | f"match the data shape {train_data.shape}.")
59 | return data
60 |
61 | def get_inverse_sqrt_schedule(optimizer, warmup_steps=1000, lr_init=1e-8, lr_max=5e-4):
62 | """
63 | https://github.com/pytorch/fairseq/blob/master/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py
64 | Decay the LR based on the inverse square root of the update number.
65 | We also support a warmup phase where we linearly increase the learning rate
66 | from some initial learning rate (``--warmup-init-lr``) until the configured
67 | learning rate (``--lr``). Thereafter we decay proportional to the number of
68 | updates, with a decay factor set to align with the configured learning rate.
69 | During warmup::
70 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
71 | lr = lrs[update_num]
72 | After warmup::
73 | decay_factor = args.lr * sqrt(args.warmup_updates)
74 | lr = decay_factor / sqrt(update_num)
75 | """
76 | def lr_lambda(current_step):
77 | lr_step = (lr_max - lr_init) / warmup_steps
78 | decay_factor = lr_max * warmup_steps ** 0.5
79 |
80 | if current_step < warmup_steps:
81 | lr = lr_init + current_step * lr_step
82 | else:
83 | lr = decay_factor * current_step ** -0.5
84 | return lr
85 |
86 | return LambdaLR(optimizer, lr_lambda)
--------------------------------------------------------------------------------
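Two small sketches of the helpers above: merging LFADS-style train/valid splits back into trial order, and attaching the inverse-sqrt schedule. Note lr_lambda returns an absolute learning rate while LambdaLR scales the optimizer's base LR multiplicatively, so the sketch sets the base LR to 1.0; the repo's runner may wire this differently.

    import numpy as np
    import torch

    from src.utils import merge_train_valid, get_inverse_sqrt_schedule

    # Reassemble 5 trials from a 3/2 train/valid split using their original indices.
    train = np.arange(12, dtype=float).reshape(3, 4)
    valid = np.arange(8, dtype=float).reshape(2, 4) + 100
    merged = merge_train_valid(train, valid, np.array([0, 2, 4]), np.array([1, 3]))
    print(merged[:, 0])  # trials back in original order: [0. 100. 4. 104. 8.]

    # Inverse-sqrt warmup/decay schedule on a dummy optimizer.
    model = torch.nn.Linear(8, 8)
    opt = torch.optim.AdamW(model.parameters(), lr=1.0)  # base LR 1.0 so LambdaLR yields the absolute value
    sched = get_inverse_sqrt_schedule(opt, warmup_steps=100, lr_init=1e-8, lr_max=5e-4)
    for _ in range(200):
        opt.step()
        sched.step()
    print(opt.param_groups[0]["lr"])  # ~5e-4 * sqrt(100 / 200)
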
/tune_models.py:
--------------------------------------------------------------------------------
1 | # Src: Andrew's tune_tf2
2 |
3 | import os
4 | import os.path as osp
5 | import numpy as np
6 | import ray
7 | from ray import tune
8 | from yacs.config import CfgNode as CN
9 | import torch
10 |
11 | from src.config.default import get_cfg_defaults, unflatten
12 | from src.runner import Runner
13 |
14 | class tuneNDT(tune.Trainable):
15 | """ A wrapper class that allows `tune` to interface with NDT.
16 | """
17 |
18 | def setup(self, config):
19 | yacs_cfg = self.convert_tune_cfg(config)
20 | self.epochs_per_generation = yacs_cfg.TRAIN.TUNE_EPOCHS_PER_GENERATION
21 | self.warmup_epochs = yacs_cfg.TRAIN.TUNE_WARMUP
22 | self.runner = Runner(config=yacs_cfg)
23 | self.runner.load_device()
24 | self.runner.load_train_val_data_and_masker()
25 | num_hidden = self.runner.setup_model(self.runner.device)
26 | self.runner.load_optimizer(num_hidden)
27 |
28 | def step(self):
29 | num_epochs = self.epochs_per_generation
30 | # the first generation always completes ramping (warmup)
31 | if self.runner.count_updates < self.warmup_epochs:
32 | num_epochs += self.warmup_epochs
33 | for i in range(num_epochs):
34 | metrics = self.runner.train_epoch()
35 | return metrics
36 |
37 | def save_checkpoint(self, tmp_ckpt_dir):
38 | path = osp.join(tmp_ckpt_dir, f"{self.runner.config.VARIANT}.{self.runner.count_checkpoints}.pth")
39 | self.runner.save_checkpoint(path)
40 | return path
41 |
42 | def load_checkpoint(self, path):
43 | self.runner.load_checkpoint(path)
44 |
45 | def reset_config(self, new_config):
46 | new_cfg_node = self.convert_tune_cfg(new_config)
47 | self.runner.update_config(new_cfg_node)
48 | return True
49 |
50 | def convert_tune_cfg(self, flat_cfg_dict):
51 |         """Converts the tune config dictionary into a CfgNode for NDT.
52 | """
53 | cfg_node = get_cfg_defaults()
54 |
55 | flat_cfg_dict['CHECKPOINT_DIR'] = osp.join(self.logdir, 'ckpts')
56 | flat_cfg_dict['TENSORBOARD_DIR'] = osp.join(self.logdir, 'tb')
57 | flat_cfg_dict['LOG_DIR'] = osp.join(self.logdir, 'logs')
58 | flat_cfg_dict['TRAIN.TUNE_MODE'] = True
59 | cfg_update = CN(unflatten(flat_cfg_dict))
60 | cfg_node.merge_from_other_cfg(cfg_update)
61 |
62 | return cfg_node
--------------------------------------------------------------------------------
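A hedged sketch of launching this Trainable with Ray Tune's tune.run API; the flat dotted keys mirror convert_tune_cfg above, but the repo's actual sweep drivers are ray_random.py and ray_get_lfve.py, so the experiment name, search space, and resources here are purely illustrative:

    import ray
    from ray import tune

    from tune_models import tuneNDT

    ray.init()
    analysis = tune.run(
        tuneNDT,
        name="ndt_sweep",                                  # illustrative experiment name
        config={
            "VARIANT": "lorenz",                           # flat dotted keys, unflattened by convert_tune_cfg
            "TRAIN.LR.INIT": tune.loguniform(1e-4, 5e-3),
            "TRAIN.MASK_RATIO": tune.uniform(0.1, 0.5),
        },
        num_samples=4,
        stop={"training_iteration": 10},                   # each training_iteration is one step() generation
        resources_per_trial={"cpu": 2, "gpu": 1},
    )
    print(analysis.get_best_config(metric="smth_masked_loss", mode="min"))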