├── .env ├── .gitattributes ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── README.md ├── ctd ├── __init__.py ├── comparison │ ├── __init__.py │ ├── analysis │ │ ├── analysis.py │ │ ├── dd │ │ │ └── dd.py │ │ └── tt │ │ │ ├── tasks │ │ │ ├── tt_MultiTask.py │ │ │ └── tt_RandomTarget.py │ │ │ └── tt.py │ ├── comparison.py │ ├── fixedpoints.py │ ├── metrics.py │ └── utils.py ├── data_modeling │ ├── LICENSE │ ├── __init__.py │ ├── callbacks │ │ ├── LFADS │ │ │ ├── callbacks.py │ │ │ └── extensions │ │ │ │ ├── __init__.py │ │ │ │ ├── armnet.py │ │ │ │ ├── nlb.py │ │ │ │ └── task_trained.py │ │ ├── SAE │ │ │ └── sim_callbacks.py │ │ └── metrics.py │ ├── configs │ │ ├── callbacks │ │ │ ├── LFADS │ │ │ │ ├── default.yaml │ │ │ │ ├── default_MultiTask.yaml │ │ │ │ ├── default_NBFF.yaml │ │ │ │ └── default_RandomTarget.yaml │ │ │ └── SAE │ │ │ │ ├── default_MultiTask.yaml │ │ │ │ ├── default_NBFF.yaml │ │ │ │ └── default_RandomTarget.yaml │ │ ├── datamodules │ │ │ ├── LFADS │ │ │ │ ├── data_MultiTask.yaml │ │ │ │ ├── data_MultiTask_infer.yaml │ │ │ │ ├── data_NBFF.yaml │ │ │ │ ├── data_NBFF_infer.yaml │ │ │ │ ├── data_RandomTarget.yaml │ │ │ │ └── data_RandomTarget_infer.yaml │ │ │ └── SAE │ │ │ │ ├── data_MultiTask.yaml │ │ │ │ ├── data_NBFF.yaml │ │ │ │ └── data_RandomTarget.yaml │ │ ├── extensions │ │ │ └── LFADS │ │ │ │ └── posterior_sampling.yaml │ │ ├── loggers │ │ │ ├── LFADS │ │ │ │ ├── default.yaml │ │ │ │ └── default_no_wandb.yaml │ │ │ └── SAE │ │ │ │ ├── default.yaml │ │ │ │ └── default_no_wandb.yaml │ │ ├── models │ │ │ ├── LFADS │ │ │ │ ├── MultiTask │ │ │ │ │ ├── MultiTask_LFADS.yaml │ │ │ │ │ └── MultiTask_LFADS_infer.yaml │ │ │ │ ├── NBFF │ │ │ │ │ ├── NBFF_LFADS.yaml │ │ │ │ │ └── NBFF_LFADS_infer.yaml │ │ │ │ └── RandomTarget │ │ │ │ │ ├── RandomTarget_LFADS.yaml │ │ │ │ │ └── RandomTarget_LFADS_infer.yaml │ │ │ └── SAE │ │ │ │ ├── MultiTask │ │ │ │ ├── MultiTask_GRU_RNN.yaml │ │ │ │ ├── MultiTask_LDS.yaml │ │ │ │ ├── MultiTask_NODE.yaml │ │ │ │ └── MultiTask_Vanilla_RNN.yaml │ │ │ │ ├── NBFF │ │ │ │ ├── NBFF_GRU_RNN.yaml │ │ │ │ ├── NBFF_LDS.yaml │ │ │ │ ├── NBFF_NODE.yaml │ │ │ │ └── NBFF_Vanilla_RNN.yaml │ │ │ │ └── RandomTarget │ │ │ │ ├── RandomTarget_GRU_RNN.yaml │ │ │ │ ├── RandomTarget_LDS.yaml │ │ │ │ ├── RandomTarget_NODE.yaml │ │ │ │ └── RandomTarget_Vanilla_RNN.yaml │ │ └── trainers │ │ │ ├── LFADS │ │ │ ├── trainer_MultiTask.yaml │ │ │ ├── trainer_NBFF.yaml │ │ │ └── trainer_RandomTarget.yaml │ │ │ └── SAE │ │ │ ├── trainer_MultiTask.yaml │ │ │ ├── trainer_NBFF.yaml │ │ │ └── trainer_RandomTarget.yaml │ ├── datamodules │ │ ├── LFADS │ │ │ ├── __init__.py │ │ │ ├── datamodule.py │ │ │ └── tuples.py │ │ ├── SAE │ │ │ └── task_trained_data.py │ │ ├── __init__.py │ │ └── utils.py │ ├── extensions │ │ ├── LFADS │ │ │ ├── metrics.py │ │ │ ├── post_run │ │ │ │ ├── __init__.py │ │ │ │ ├── analysis.py │ │ │ │ └── pbt.py │ │ │ └── utils.py │ │ ├── SAE │ │ │ └── utils.py │ │ └── evaluation.py │ ├── models │ │ ├── LFADS │ │ │ ├── __init__.py │ │ │ ├── lfads.py │ │ │ └── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── augmentations.py │ │ │ │ ├── decoder.py │ │ │ │ ├── encoder.py │ │ │ │ ├── initializers.py │ │ │ │ ├── l2.py │ │ │ │ ├── l2_simple.py │ │ │ │ ├── priors.py │ │ │ │ ├── readin_readout.py │ │ │ │ ├── recons.py │ │ │ │ └── recurrent.py │ │ ├── SAE │ │ │ ├── LDS.py │ │ │ ├── __init__.py │ │ │ ├── dyn_models.py │ │ │ ├── dyn_models_gru.py │ │ │ ├── dyn_models_rnn.py │ │ │ ├── gru_rnn.py │ │ │ ├── lds.py │ │ │ ├── loss_func.py │ │ │ ├── node.py │ │ │ ├── 
readouts.py │ │ │ ├── template.py │ │ │ └── vanilla_rnn.py │ │ └── __init__.py │ └── train_PTL.py └── task_modeling │ ├── __init__.py │ ├── callbacks │ ├── __init__.py │ ├── callbacks.py │ ├── callbacks_coupled.py │ └── callbacks_multitask.py │ ├── configs │ ├── callbacks │ │ ├── default_MultiTask.yaml │ │ ├── default_NBFF.yaml │ │ ├── default_RandomTarget.yaml │ │ └── default_no_wandb.yaml │ ├── datamodule_sim │ │ ├── datamodule_MultiTask.yaml │ │ ├── datamodule_NBFF.yaml │ │ └── datamodule_RandomTarget.yaml │ ├── datamodule_train │ │ ├── datamodule_MultiTask.yaml │ │ ├── datamodule_NBFF.yaml │ │ └── datamodule_RandomTarget.yaml │ ├── env_sim │ │ ├── MultiTask.yaml │ │ ├── NBFF.yaml │ │ └── RandomTarget.yaml │ ├── env_task │ │ ├── MultiTask.yaml │ │ ├── NBFF.yaml │ │ └── RandomTarget.yaml │ ├── logger │ │ ├── default.yaml │ │ └── default_no_wandb.yaml │ ├── model │ │ ├── DriscollRNN.yaml │ │ ├── GRU_RNN.yaml │ │ ├── NODE.yaml │ │ ├── NoisyGRU.yaml │ │ ├── NoisyGRULatentL2.yaml │ │ └── Vanilla_RNN.yaml │ ├── simulator │ │ ├── default_MultiTask.yaml │ │ ├── default_NBFF.yaml │ │ └── default_RandomTarget.yaml │ ├── task_wrapper │ │ ├── MultiTask.yaml │ │ ├── NBFF.yaml │ │ └── RandomTarget.yaml │ └── trainer │ │ └── default.yaml │ ├── datamodule │ ├── __init__.py │ ├── samplers.py │ └── task_datamodule.py │ ├── model │ ├── __init__.py │ ├── node.py │ ├── rnn.py │ └── tt_template.py │ ├── simulator │ ├── __init__.py │ └── neural_simulator.py │ ├── task_env │ ├── __init__.py │ ├── loss_func.py │ ├── multitask.py │ ├── old │ │ └── alternative_tasks.py.py │ ├── random_target.py │ └── task_env.py │ ├── task_training.py │ └── task_wrapper │ ├── __init__.py │ ├── task_wrapper.py │ └── utils.py ├── examples ├── compare_dd_tt_models.ipynb ├── figures │ ├── Embedding │ │ ├── Figure4_NBFF_TTNODE_DTGRU_LatSweeps.ipynb │ │ └── Figure4_NBFF_TTNODE_DTNODE_LatSweeps.ipynb │ ├── Fig1PerfCriteria │ │ ├── Fig1OneBitFlipFlopGen.ipynb │ │ └── old │ │ │ ├── Fig8Gen.ipynb │ │ │ └── lorenz.ipynb │ ├── Fig3TaskPerformance │ │ ├── Figure3AllTasks.ipynb │ │ ├── Figure3MultiTask.ipynb │ │ ├── Figure3NBFF.ipynb │ │ ├── Figure3RandomTarget.ipynb │ │ ├── MemoryPro_MemoryProPCs_combined_video.gif │ │ └── bumpMove.ipynb │ ├── Fig4Canonical │ │ ├── CanonicalDatasetPerf.ipynb │ │ ├── LearningProgress.ipynb │ │ └── NBFF_Exploration.ipynb │ ├── Fig5Metrics │ │ ├── Figure1_NBFF_TTGRUSweep.ipynb │ │ └── SLDS │ │ │ └── slds.ipynb │ ├── Fig6InputInf │ │ ├── FigInputInf.ipynb │ │ └── FigInputInfFPFinding.ipynb │ ├── HPSweeping │ │ ├── Figure3_NBFF_TTGRU_DTNODE_LatSweeps.ipynb │ │ └── Figure3_NBFF_TTNODE_DTGRU_LatSweeps.ipynb │ ├── TaskPerformance │ │ └── Figure3MultiTask.ipynb │ └── websiteVids │ │ └── videoGen.ipynb ├── gen_datasets.py ├── notebooks │ ├── old │ │ ├── CtDTutorial.ipynb │ │ ├── NeuroMatchNotebook.ipynb │ │ └── WorkshopNotebook.ipynb │ └── png │ │ ├── AnalysisStructure-01.png │ │ ├── BenchmarkFlow2-01.png │ │ ├── BenchmarkFlowTTDT-01-01.png │ │ ├── BenchmarkGrey-01.png │ │ ├── BenchmarkSchematicSimple_steps.png │ │ ├── DSAPic.png │ │ ├── FinalGif.gif │ │ ├── Hourglass.png │ │ ├── MemoryPro_MemoryProPCs_combined_video.gif │ │ ├── MotorNet Illustration-01.png │ │ ├── MotorNetGif.gif │ │ ├── NoteBookQR.png │ │ ├── Problem.png │ │ ├── SAE.png │ │ ├── SimulationDiagram.png │ │ ├── StateR2-01.png │ │ ├── Step1-01.png │ │ ├── Step2-01.png │ │ ├── Step3-01.png │ │ ├── Step4-01.png │ │ ├── SussilloBarack.png │ │ ├── TTModelExample-01.png │ │ ├── TaskComplexity-01.png │ │ ├── TaskEnvs-01.png │ │ ├── TaskTrained-01.png │ │ ├── 
TaskTraininSchematic-01.png │ │ ├── Template.png │ │ ├── TutorialTT-01.png │ │ ├── TutorialTT0-01.png │ │ ├── TutorialTTComp-01.png │ │ ├── TutorialTT_model-01.png │ │ ├── lfads_fps.png │ │ └── loopingMultiTask.gif ├── run_data_training.py └── run_task_training.py ├── pretrained ├── .gitattributes ├── 20241017_NBFF_NoisyGRU_NewFinal │ ├── .gitattributes │ ├── datamodule_sim.pkl │ ├── datamodule_train.pkl │ ├── model.pkl │ └── simulator.pkl ├── 20241113_MultiTask_NoisyGRU_Final2 │ ├── datamodule_sim.pkl │ ├── datamodule_train.pkl │ ├── model.pkl │ └── simulator.pkl └── 20241113_RandomTarget_NoisyGRU_Final2 │ ├── datamodule_sim.pkl │ ├── datamodule_train.pkl │ ├── model.pkl │ └── simulator.pkl ├── requirements.txt ├── setup.py └── utils.py /.env: -------------------------------------------------------------------------------- 1 | HOME_DIR = /home/csverst/Github/CtDBenchmark/ 2 | 3 | # Don't change these 4 | TRAIN_INPUT_FILE=train_input.h5 5 | EVAL_INPUT_FILE=eval_input.h5 6 | EVAL_TARGET_FILE=eval_target.h5 7 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | pretrained/*.pkl filter=lfs diff=lfs merge=lfs -text 2 | pretrained/20241017_NBFF_NoisyGRU_NewFinal/datamodule_sim.pkl filter=lfs diff=lfs merge=lfs -text 3 | pretrained/20241017_NBFF_NoisyGRU_NewFinal/datamodule_train.pkl filter=lfs diff=lfs merge=lfs -text 4 | pretrained/20241017_NBFF_NoisyGRU_NewFinal/model.pkl filter=lfs diff=lfs merge=lfs -text 5 | pretrained/20241017_NBFF_NoisyGRU_NewFinal/simulator.pkl filter=lfs diff=lfs merge=lfs -text 6 | 7 | pretrained/20241113_RandomTarget_NoisyGRU_Final2/datamodule_sim.pkl filter=lfs diff=lfs merge=lfs -text 8 | pretrained/20241113_RandomTarget_NoisyGRU_Final2/datamodule_train.pkl filter=lfs diff=lfs merge=lfs -text 9 | pretrained/20241113_RandomTarget_NoisyGRU_Final2/model.pkl filter=lfs diff=lfs merge=lfs -text 10 | pretrained/20241113_RandomTarget_NoisyGRU_Final2/simulator.pkl filter=lfs diff=lfs merge=lfs -text 11 | 12 | pretrained/20241113_MultiTask_NoisyGRU_Final2/datamodule_sim.pkl filter=lfs diff=lfs merge=lfs -text 13 | pretrained/20241113_MultiTask_NoisyGRU_Final2/datamodule_train.pkl filter=lfs diff=lfs merge=lfs -text 14 | pretrained/20241113_MultiTask_NoisyGRU_Final2/model.pkl filter=lfs diff=lfs merge=lfs -text 15 | pretrained/20241113_MultiTask_NoisyGRU_Final2/simulator.pkl filter=lfs diff=lfs merge=lfs -text 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | *.pyc 9 | *.pkl 10 | *.h5 11 | 12 | # Distribution / packaging 13 | .Python 14 | events* 15 | *.png 16 | *.pdf 17 | *.pyc 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | *.mat 37 | *.p 38 | *.pyc 39 | *.html 40 | *.svg 41 | *.mp4 42 | 43 | *.avi 44 | *.json 45 | 46 | data/runs/* 47 | content/* 48 | old/* 49 | examples/dev/* 50 | 51 | !/examples/png/*.png 52 | !/pretrained/**.pkl 53 | # PyInstaller 54 | # Usually these files are written by a python script from a template 55 | # before PyInstaller builds the exe, so as to 
inject date/other infos into it. 56 | *.manifest 57 | *.spec 58 | 59 | # Installer logs 60 | pip-log.txt 61 | pip-delete-this-directory.txt 62 | 63 | # Unit test / coverage reports 64 | htmlcov/ 65 | .tox/ 66 | .nox/ 67 | .coverage 68 | .coverage.* 69 | .cache 70 | nosetests.xml 71 | coverage.xml 72 | *.cover 73 | *.py,cover 74 | .hypothesis/ 75 | .pytest_cache/ 76 | 77 | # Translations 78 | *.mo 79 | *.pot 80 | 81 | # Django stuff: 82 | *.log 83 | local_settings.py 84 | db.sqlite3 85 | db.sqlite3-journal 86 | 87 | # Flask stuff: 88 | instance/ 89 | .webassets-cache 90 | 91 | # Scrapy stuff: 92 | .scrapy 93 | 94 | # Sphinx documentation 95 | docs/_build/ 96 | 97 | # PyBuilder 98 | target/ 99 | 100 | # Jupyter Notebook 101 | .ipynb_checkpoints 102 | 103 | # IPython 104 | profile_default/ 105 | ipython_config.py 106 | 107 | # pyenv 108 | .python-version 109 | 110 | # pipenv 111 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 112 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 113 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 114 | # install all needed dependencies. 115 | #Pipfile.lock 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | #.env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "libs/DSA"] 2 | path = libs/DSA 3 | url = https://github.com/mitchellostrow/DSA.git 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: 'examples/run_task_training\.py' 2 | 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v3.4.0 6 | hooks: 7 | # list of supported hooks: https://pre-commit.com/hooks.html 8 | - id: trailing-whitespace 9 | - id: end-of-file-fixer 10 | - id: check-yaml 11 | - id: debug-statements 12 | - id: detect-private-key 13 | 14 | # python code formatting 15 | - repo: https://github.com/psf/black 16 | rev: 22.3.0 17 | hooks: 18 | - id: black 19 | 20 | # python import sorting 21 | - repo: https://github.com/PyCQA/isort 22 | rev: 5.8.0 23 | hooks: 24 | - id: isort 25 | args: ["--profile", "black"] 26 | 27 | # yaml formatting 28 | - repo: https://github.com/pre-commit/mirrors-prettier 29 | rev: v2.3.0 30 | hooks: 31 | - id: prettier 32 | types: [yaml] 33 | 34 | # python code analysis 35 | - repo: https://github.com/PyCQA/flake8 36 | rev: 3.9.2 37 | hooks: 38 | - id: flake8 39 | args: ["--max-line-length", "88", "--extend-ignore", "E203"] 40 | -------------------------------------------------------------------------------- /ctd/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/__init__.py
--------------------------------------------------------------------------------
/ctd/comparison/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/comparison/__init__.py
--------------------------------------------------------------------------------
/ctd/comparison/analysis/analysis.py:
--------------------------------------------------------------------------------
import warnings


class Analysis:
    """Base class for model analyses (see the tt and dd subclasses)."""

    def __init__(self, run_name, filepath):
        self.run_name = run_name
        self.filepath = filepath

    def load_wrapper(self, filepath):
        # Base-class stub: warn so callers know a subclass must implement this
        warnings.warn(f"load_wrapper is not implemented by {type(self).__name__}")
        return None

    def get_model_output(self):
        return None

    def compute_FPs(self, latents, inputs):
        return None
--------------------------------------------------------------------------------
/ctd/data_modeling/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Andrew Sedler

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/ctd/data_modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/data_modeling/__init__.py
--------------------------------------------------------------------------------
/ctd/data_modeling/callbacks/LFADS/extensions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/data_modeling/callbacks/LFADS/extensions/__init__.py
--------------------------------------------------------------------------------
/ctd/data_modeling/callbacks/LFADS/extensions/nlb.py:
--------------------------------------------------------------------------------
import warnings

import pytorch_lightning as pl
import torch
from nlb_tools.evaluation import (
    bits_per_spike,
    eval_psth,
    speed_tp_correlation,
    velocity_decoding,
)
from scipy.linalg import LinAlgWarning

from ..metrics import ExpSmoothedMetric
from ..utils import send_batch_to_device


class NLBEvaluation(pl.Callback):
    """Computes and logs all evaluation metrics for the Neural Latents
    Benchmark to tensorboard. These include `co_bps`, `fp_bps`,
    `behavior_r2`, `psth_r2`, and `tp_corr`.

    To enable this functionality, install nlb_tools
    (https://github.com/neurallatents/nlb_tools).
    """

    def __init__(self, log_every_n_epochs=20, decoding_cv_sweep=False):
        """Initializes the callback.

        Parameters
        ----------
        log_every_n_epochs : int, optional
            The frequency with which to compute and log, by default 20
        decoding_cv_sweep : bool, optional
            Whether to run a cross-validated hyperparameter sweep to
            find optimal regularization values, by default False
        """
        self.log_every_n_epochs = log_every_n_epochs
        self.decoding_cv_sweep = decoding_cv_sweep
        self.smth_metrics = {}

    def on_validation_epoch_end(self, trainer, pl_module):
        """Computes and logs metrics at the end of the validation epoch.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            The trainer currently handling the model.
        pl_module : pytorch_lightning.LightningModule
            The model currently being trained.
        """
        # Skip evaluation for most epochs to save time
        if (trainer.current_epoch % self.log_every_n_epochs) != 0:
            return
        # Get the dataloaders
        pred_dls = trainer.datamodule.predict_dataloader()
        s = 0  # index of the first (and typically only) session
        val_dataloader = pred_dls[s]["valid"]
        train_dataloader = pred_dls[s]["train"]
        # Create object to store evaluation metrics
        metrics = {}
        # Get entire validation dataset from datamodule
        (input_data, recon_data, *_), (behavior,) = trainer.datamodule.valid_data[s]
        recon_data = recon_data.detach().cpu().numpy()
        behavior = behavior.detach().cpu().numpy()
        # Pass the data through the model
        # TODO: Replace this with Trainer.predict? Hesitation is that switching to
        # Trainer.predict for posterior sampling is inefficient because we can't
        # tell it how many forward passes to use.
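        # The loop below runs one forward pass per validation batch and
        # concatenates the predicted rates across the dataset; with
        # sample_posteriors=False, predict_step should use the posterior
        # means rather than stochastic samples.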
        rates = []
        for batch in val_dataloader:
            batch = send_batch_to_device({s: batch}, pl_module.device)
            output = pl_module.predict_step(batch, None, sample_posteriors=False)[s]
            rates.append(output.output_params)
        rates = torch.cat(rates).detach().cpu().numpy()
        # Compute co-smoothing bits per spike
        _, n_obs, n_heldin = input_data.shape
        heldout = recon_data[:, :n_obs, n_heldin:]
        rates_heldout = rates[:, :n_obs, n_heldin:]
        co_bps = bits_per_spike(rates_heldout, heldout)
        metrics["nlb/co_bps"] = max(co_bps, -1.0)
        # Compute forward prediction bits per spike
        forward = recon_data[:, n_obs:]
        rates_forward = rates[:, n_obs:]
        fp_bps = bits_per_spike(rates_forward, forward)
        metrics["nlb/fp_bps"] = max(fp_bps, -1.0)
        # Get relevant training dataset from datamodule
        _, (train_behavior,) = trainer.datamodule.train_data[s]
        train_behavior = train_behavior.detach().cpu().numpy()
        # Get model predictions for the training dataset
        train_rates = []
        for batch in train_dataloader:
            batch = send_batch_to_device({s: batch}, pl_module.device)
            output = pl_module.predict_step(batch, None, sample_posteriors=False)[s]
            train_rates.append(output.output_params)
        train_rates = torch.cat(train_rates).detach().cpu().numpy()
        # Get firing rates for observed time points
        rates_obs = rates[:, :n_obs]
        train_rates_obs = train_rates[:, :n_obs]
        # Compute behavioral decoding performance
        if behavior.ndim < 3:
            tp_corr = speed_tp_correlation(heldout, rates_obs, behavior)
            metrics["nlb/tp_corr"] = tp_corr
        else:
            with warnings.catch_warnings():
                # Ignore LinAlgWarning from early in training
                warnings.filterwarnings("ignore", category=LinAlgWarning)
                behavior_r2 = velocity_decoding(
                    train_rates_obs,
                    train_behavior,
                    trainer.datamodule.train_decode_mask,
                    rates_obs,
                    behavior,
                    trainer.datamodule.valid_decode_mask,
                    self.decoding_cv_sweep,
                )
                metrics["nlb/behavior_r2"] = max(behavior_r2, -1.0)
        # Compute PSTH reconstruction performance
        if hasattr(trainer.datamodule, "psth"):
            psth = trainer.datamodule.psth
            cond_idxs = trainer.datamodule.valid_cond_idx
            jitter = getattr(trainer.datamodule, "valid_jitter", None)
            psth_r2 = eval_psth(psth, rates_obs, cond_idxs, jitter)
            metrics["nlb/psth_r2"] = max(psth_r2, -1.0)
        # Compute smoothed metrics
        for k, v in metrics.items():
            if k not in self.smth_metrics:
                self.smth_metrics[k] = ExpSmoothedMetric(coef=0.7)
            self.smth_metrics[k].update(v, 1)
        # Log actual and smoothed metrics
        pl_module.log_dict(
            {
                **metrics,
                **{k + "_smth": m.compute() for k, m in self.smth_metrics.items()},
            }
        )
        # Reset the smoothed metrics (per-step aggregation not necessary)
        for m in self.smth_metrics.values():
            m.reset()
--------------------------------------------------------------------------------
/ctd/data_modeling/callbacks/metrics.py:
--------------------------------------------------------------------------------
import math

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from torch.nn.functional import poisson_nll_loss


def _default_2d_array(array):
    return array.reshape(-1, array.shape[-1])


def _default_2d_func(func):
    def wrapper(preds, targets):
        return func(_default_2d_array(preds), _default_2d_array(targets))

    return wrapper


# Torch implementation kept for reference; it is commented out because it
# would shadow the sklearn r2_score imported above.
# @_default_2d_func
# def r2_score(preds, targets):
#     target_mean = torch.mean(targets, dim=0)
#     ss_tot = torch.sum((targets - target_mean) ** 2, dim=0)
#     ss_res = torch.sum((targets - preds) ** 2, dim=0)
#     return torch.mean(1 - ss_res / ss_tot)


@_default_2d_func
def linear_regression(preds, targets):
    preds_1 = F.pad(preds, (0, 1), value=1)
    W = preds_1.pinverse() @ targets
    return preds_1 @ W


@_default_2d_func
def regression_r2_score(preds, targets):
    projs = linear_regression(preds, targets)
    return torch.clamp_min(r2_score(projs, targets), -10)


@_default_2d_func
def regression_mse(preds, targets):
    projs = linear_regression(preds, targets)
    return F.mse_loss(projs, targets)


def weighted_loss(loss_fn, preds, targets, weight=1.0):
    loss_all = loss_fn(input=preds, target=targets, reduction="none")
    return torch.mean(weight * loss_all)


def bits_per_spike(preds, targets):
    """
    Computes BPS for n_samples x n_timesteps x n_neurons arrays.
    Preds are logrates and targets are binned spike counts.
    """
    if len(preds.shape) == 3:
        dim = (0, 1)
    elif len(preds.shape) == 2:
        dim = 0
    else:
        # Guard against silently undefined `dim` for unsupported shapes
        raise ValueError("preds must be 2- or 3-dimensional")
    nll_model = poisson_nll_loss(preds, targets, full=True, reduction="sum")
    nll_null = poisson_nll_loss(
        torch.mean(targets, dim=dim, keepdim=True),
        targets,
        log_input=False,
        full=True,
        reduction="sum",
    )
    return (nll_null - nll_model) / torch.nansum(targets) / math.log(2)


def compute_metrics(
    true_rates,
    inf_rates,
    true_latents,
    inf_latents,
    true_inputs,
    inf_inputs,
    true_spikes,
    n_heldin,
    device=None,
):
    if device is None:
        device = true_rates.device
    # Compute Rate R2
    rate_r2 = r2_score(true_rates, inf_rates, multioutput="variance_weighted")

    # Compute Input R2
    if inf_inputs is None or true_inputs is None:
        input_r2 = np.nan
    else:
        lm = LinearRegression()
        lm.fit(inf_inputs, true_inputs)
        true_inputs_pred = lm.predict(inf_inputs)
        input_r2 = r2_score(
            true_inputs, true_inputs_pred, multioutput="variance_weighted"
        )

    # Compute Latent R2
    lm = LinearRegression()
    lm.fit(true_latents, inf_latents)
    latent_pred_flat = lm.predict(true_latents)
    latent_r2 = r2_score(inf_latents, latent_pred_flat, multioutput="variance_weighted")

    bps = bits_per_spike(
        torch.tensor(np.log(inf_rates)).float(), torch.tensor(true_spikes).float()
    ).item()
    hi_bps = bits_per_spike(
        torch.tensor(np.log(inf_rates[:, :n_heldin])).float(),
        torch.tensor(true_spikes[:, :n_heldin]).float(),
    ).item()
    ho_bps = bits_per_spike(
        torch.tensor(np.log(inf_rates[:, n_heldin:])).float(),
        torch.tensor(true_spikes[:, n_heldin:]).float(),
    ).item()

    torch.set_grad_enabled(True)
    inf_latents_torch = torch.tensor(inf_latents).float().to(device)
    inf_rates_torch = torch.tensor(inf_rates).float().to(device)
    mlp = torch.nn.Sequential(
        torch.nn.Linear(inf_rates.shape[1], 100),
        torch.nn.ReLU(),
        torch.nn.Linear(100, 100),
        torch.nn.ReLU(),
        torch.nn.Linear(100, inf_latents.shape[1]),
    ).to(device)
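
    # Descriptive note: the MLP defined above is trained below to map inferred
    # rates back to inferred latents; the cycle_con_*_r2 metrics then measure
    # how well the latents are recovered from progressively noised rates.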
    optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)
    criterion = torch.nn.MSELoss()
    for _ in range(100):
        optimizer.zero_grad()
        pred = mlp(inf_rates_torch)
        loss = criterion(pred, inf_latents_torch)
        loss.backward()
        optimizer.step()

    metric_dict = {
        "rate_r2": rate_r2,
        "input_r2": input_r2,
        "state_r2": latent_r2,
        "bps": bps,
        "hi_bps": hi_bps,
        "ho_bps": ho_bps,
    }
    noise_levels = np.linspace(0, 1, 6)
    for noise in noise_levels:
        noised_rates_flat = inf_rates_torch + torch.rand_like(inf_rates_torch) * noise
        latent_pred_flat = mlp(noised_rates_flat).detach().cpu().numpy()
        cycle_con_r2 = r2_score(
            inf_latents, latent_pred_flat, multioutput="variance_weighted"
        )
        metric_dict[f"cycle_con_{noise:.2f}_r2"] = cycle_con_r2
    return metric_dict
--------------------------------------------------------------------------------
/ctd/data_modeling/configs/callbacks/LFADS/default.yaml:
--------------------------------------------------------------------------------
raster_plot:
  _target_: ctd.data_modeling.callbacks.LFADS.callbacks.RasterPlot
  log_every_n_epochs: 100
# trajectory_plot:
#   _target_: ctd.data_modeling.data_training.callbacks.TrajectoryPlot
#   log_every_n_epochs: 100
--------------------------------------------------------------------------------
/ctd/data_modeling/configs/callbacks/LFADS/default_MultiTask.yaml:
--------------------------------------------------------------------------------
raster_plot:
  _target_: ctd.data_modeling.callbacks.LFADS.callbacks.RasterPlot
  log_every_n_epochs: 100
tune_report_callback:
  _target_: ray.tune.integration.pytorch_lightning.TuneReportCallback
  metrics:
    loss: valid/loss
trajectory_plot:
  _target_: ctd.data_modeling.callbacks.LFADS.callbacks.TrajectoryPlot
  log_every_n_epochs: 100

input_accuracy:
  _target_: ctd.data_modeling.callbacks.LFADS.extensions.task_trained.InputR2Plot
  log_every_n_epochs: 20
  n_samples: 16

model_checkpoint:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: valid/recon_smth
  mode: min
  save_top_k: 1
  save_last: True
  verbose: False
  dirpath: lightning_checkpoints
  auto_insert_metric_name: False
early_stopping:
  _target_: pytorch_lightning.callbacks.EarlyStopping
  monitor: valid/recon_smth
  mode: min
  patience: 200
  min_delta: 0
learning_rate_monitor:
  _target_: pytorch_lightning.callbacks.LearningRateMonitor
  logging_interval: epoch
--------------------------------------------------------------------------------
/ctd/data_modeling/configs/callbacks/LFADS/default_NBFF.yaml:
--------------------------------------------------------------------------------
raster_plot:
  _target_: ctd.data_modeling.callbacks.LFADS.callbacks.RasterPlot
  log_every_n_epochs: 100
tune_report_callback:
  _target_: ray.tune.integration.pytorch_lightning.TuneReportCallback
  metrics:
    loss: valid/loss
trajectory_plot:
  _target_: ctd.data_modeling.callbacks.LFADS.callbacks.TrajectoryPlot
  log_every_n_epochs: 100
dt_metrics:
  _target_: ctd.data_modeling.callbacks.LFADS.callbacks.DTMetricsCallback
  log_every_n_epochs: 100

input_accuracy:
  _target_: ctd.data_modeling.callbacks.LFADS.extensions.task_trained.InputR2Plot
  log_every_n_epochs: 20
  n_samples: 16
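# Checkpoint selection and early stopping (below) both track the smoothed
# validation reconstruction loss, valid/recon_smth.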
model_checkpoint: 21 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 22 | monitor: valid/recon_smth 23 | mode: min 24 | save_top_k: 1 25 | save_last: True 26 | verbose: False 27 | dirpath: lightning_checkpoints 28 | auto_insert_metric_name: False 29 | early_stopping: 30 | _target_: pytorch_lightning.callbacks.EarlyStopping 31 | monitor: valid/recon_smth 32 | mode: min 33 | patience: 200 34 | min_delta: 0 35 | learning_rate_monitor: 36 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 37 | logging_interval: epoch 38 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/callbacks/LFADS/default_RandomTarget.yaml: -------------------------------------------------------------------------------- 1 | raster_plot: 2 | _target_: ctd.data_modeling.callbacks.LFADS.callbacks.RasterPlot 3 | log_every_n_epochs: 100 4 | tune_report_callback: 5 | _target_: ray.tune.integration.pytorch_lightning.TuneReportCallback 6 | metrics: 7 | loss: valid/loss 8 | trajectory_plot: 9 | _target_: ctd.data_modeling.callbacks.LFADS.callbacks.TrajectoryPlot 10 | log_every_n_epochs: 100 11 | 12 | input_accuracy: 13 | _target_: ctd.data_modeling.callbacks.LFADS.extensions.task_trained.InputR2Plot 14 | log_every_n_epochs: 20 15 | n_samples: 16 16 | 17 | model_checkpoint: 18 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 19 | monitor: valid/recon_smth 20 | mode: min 21 | save_top_k: 1 22 | save_last: True 23 | verbose: False 24 | dirpath: lightning_checkpoints 25 | auto_insert_metric_name: False 26 | early_stopping: 27 | _target_: pytorch_lightning.callbacks.EarlyStopping 28 | monitor: valid/recon_smth 29 | mode: min 30 | patience: 200 31 | min_delta: 0 32 | learning_rate_monitor: 33 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 34 | logging_interval: epoch 35 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/callbacks/SAE/default_MultiTask.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | dirpath: "." 4 | monitor: valid/loss_all 5 | save_last: True 6 | 7 | tune_report_callback: 8 | _target_: ray.tune.integration.pytorch_lightning.TuneReportCallback 9 | metrics: 10 | loss: valid/loss_all 11 | 12 | raster_plot_callback: 13 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.RasterPlot 14 | log_every_n_epochs: 100 15 | 16 | trajectory_plot_callback: 17 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.TrajectoryPlot 18 | log_every_n_epochs: 100 19 | 20 | trajectory_callback: 21 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.TrajectoryPlotOverTimeCallback 22 | log_every_n_epochs: 100 23 | 24 | avg_firing_rate_callback: 25 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.AvgFiringRateCallback 26 | log_every_n_epochs: 100 27 | 28 | inputs_plot: 29 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.InputsPlot 30 | log_every_n_epochs: 100 31 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/callbacks/SAE/default_NBFF.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | dirpath: "." 
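# Note: unlike the LFADS callback configs (which monitor valid/recon_smth),
# the SAE configs checkpoint on the total validation loss, valid/loss_all.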
4 | monitor: valid/loss_all 5 | save_last: True 6 | 7 | tune_report_callback: 8 | _target_: ray.tune.integration.pytorch_lightning.TuneReportCallback 9 | metrics: 10 | loss: valid/loss_all 11 | 12 | raster_plot_callback: 13 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.RasterPlot 14 | log_every_n_epochs: 100 15 | 16 | trajectory_plot_callback: 17 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.TrajectoryPlot 18 | log_every_n_epochs: 100 19 | 20 | trajectory_callback: 21 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.TrajectoryPlotOverTimeCallback 22 | log_every_n_epochs: 100 23 | 24 | avg_firing_rate_callback: 25 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.AvgFiringRateCallback 26 | log_every_n_epochs: 100 27 | 28 | inputs_plot: 29 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.InputsPlot 30 | log_every_n_epochs: 100 31 | 32 | dt_metrics: 33 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.DTMetricsCallback 34 | log_every_n_epochs: 100 35 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/callbacks/SAE/default_RandomTarget.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | dirpath: "." 4 | monitor: valid/loss_all 5 | save_last: True 6 | 7 | tune_report_callback: 8 | _target_: ray.tune.integration.pytorch_lightning.TuneReportCallback 9 | metrics: 10 | loss: valid/loss_all 11 | 12 | raster_plot_callback: 13 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.RasterPlot 14 | log_every_n_epochs: 100 15 | 16 | trajectory_plot_callback: 17 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.TrajectoryPlot 18 | log_every_n_epochs: 100 19 | 20 | trajectory_callback: 21 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.TrajectoryPlotOverTimeCallback 22 | log_every_n_epochs: 100 23 | 24 | avg_firing_rate_callback: 25 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.AvgFiringRateCallback 26 | log_every_n_epochs: 100 27 | 28 | inputs_plot: 29 | _target_: ctd.data_modeling.callbacks.SAE.sim_callbacks.InputsPlot 30 | log_every_n_epochs: 100 31 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/datamodules/LFADS/data_MultiTask.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.datamodules.LFADS.datamodule.BasicDataModule 2 | prefix: tt_MultiTask 3 | 4 | seed: 0 5 | batch_size: 250 6 | provide_inputs: True 7 | 8 | neuron_dict: 9 | n_heldin: 50 10 | n_heldout: 10 11 | 12 | embed_dict: 13 | rect_func: exp 14 | fr_scaling: 2.0 15 | noise_dict: 16 | obs_noise: pseudoPoisson 17 | dispersion: 1.0 18 | 19 | batch_keys: 20 | - inputs 21 | - latents 22 | - activity 23 | - extra 24 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/datamodules/LFADS/data_MultiTask_infer.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.datamodules.LFADS.datamodule.BasicDataModule 2 | prefix: tt_MultiTask 3 | 4 | seed: 0 5 | batch_size: 250 6 | provide_inputs: False 7 | 8 | neuron_dict: 9 | n_heldin: 50 10 | n_heldout: 10 11 | 12 | embed_dict: 13 | rect_func: exp 14 | fr_scaling: 2.0 15 | noise_dict: 16 | obs_noise: pseudoPoisson 17 | dispersion: 1.0 18 | 19 | batch_keys: 20 | - inputs 21 | - latents 22 | - 
activity 23 | - extra 24 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/datamodules/LFADS/data_NBFF.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.datamodules.LFADS.datamodule.BasicDataModule 2 | prefix: tt_3bff 3 | 4 | seed: 0 5 | batch_size: 500 6 | provide_inputs: True 7 | 8 | neuron_dict: 9 | n_heldin: 50 10 | n_heldout: 10 11 | embed_dict: 12 | rect_func: exp 13 | fr_scaling: 2.0 14 | noise_dict: 15 | obs_noise: pseudoPoisson 16 | dispersion: 1.0 17 | 18 | batch_keys: 19 | - inputs 20 | - latents 21 | - activity 22 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/datamodules/LFADS/data_NBFF_infer.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.datamodules.LFADS.datamodule.BasicDataModule 2 | prefix: tt_3bff 3 | seed: 0 4 | batch_size: 500 5 | provide_inputs: False 6 | 7 | neuron_dict: 8 | n_heldin: 50 9 | n_heldout: 10 10 | embed_dict: 11 | rect_func: exp 12 | fr_scaling: 2.0 13 | noise_dict: 14 | obs_noise: pseudoPoisson 15 | dispersion: 1.0 16 | 17 | batch_keys: 18 | - inputs 19 | - latents 20 | - activity 21 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/datamodules/LFADS/data_RandomTarget.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.datamodules.LFADS.datamodule.BasicDataModule 2 | prefix: tt_RandomTarget 3 | seed: 0 4 | batch_size: 250 5 | provide_inputs: True 6 | 7 | neuron_dict: 8 | n_heldin: 50 9 | n_heldout: 10 10 | 11 | embed_dict: 12 | rect_func: exp 13 | fr_scaling: 2.0 14 | noise_dict: 15 | obs_noise: pseudoPoisson 16 | dispersion: 1.0 17 | 18 | batch_keys: 19 | - inputs 20 | - latents 21 | - activity 22 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/datamodules/LFADS/data_RandomTarget_infer.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.datamodules.LFADS.datamodule.BasicDataModule 2 | prefix: tt_RandomTarget 3 | 4 | seed: 0 5 | batch_size: 250 6 | provide_inputs: False 7 | 8 | neuron_dict: 9 | n_heldin: 50 10 | n_heldout: 10 11 | 12 | embed_dict: 13 | rect_func: exp 14 | fr_scaling: 2.0 15 | noise_dict: 16 | obs_noise: pseudoPoisson 17 | dispersion: 1.0 18 | 19 | batch_keys: 20 | - inputs 21 | - latents 22 | - activity 23 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/datamodules/SAE/data_MultiTask.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.datamodules.SAE.task_trained_data.TaskTrainedRNNDataModule 2 | prefix: tt_MultiTask 3 | neuron_dict: 4 | n_heldin: 50 5 | n_heldout: 10 6 | embed_dict: 7 | rect_func: exp 8 | fr_scaling: 2.0 9 | noise_dict: 10 | obs_noise: pseudoPoisson 11 | dispersion: 1.0 12 | seed: 0 13 | batch_size: 250 14 | num_workers: 4 15 | provide_inputs: True 16 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/datamodules/SAE/data_NBFF.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.datamodules.SAE.task_trained_data.TaskTrainedRNNDataModule 2 | prefix: tt_3bff 3 | 
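# A descriptive reading of the fields below: neuron_dict splits the simulated
# units into held-in/held-out sets; embed_dict maps task-trained latents to
# firing rates via an exponential rectifier (rect_func: exp) scaled by
# fr_scaling; noise_dict then draws pseudo-Poisson spike counts with the
# given dispersion.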
neuron_dict: 4 | n_heldin: 50 5 | n_heldout: 10 6 | embed_dict: 7 | rect_func: exp 8 | fr_scaling: 2.0 9 | noise_dict: 10 | obs_noise: pseudoPoisson 11 | dispersion: 1.0 12 | seed: 0 13 | batch_size: 250 14 | num_workers: 4 15 | provide_inputs: True 16 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/datamodules/SAE/data_RandomTarget.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.datamodules.SAE.task_trained_data.TaskTrainedRNNDataModule 2 | prefix: tt_RandomTarget 3 | neuron_dict: 4 | n_heldin: 50 5 | n_heldout: 10 6 | embed_dict: 7 | rect_func: exp 8 | fr_scaling: 2.0 9 | noise_dict: 10 | obs_noise: pseudoPoisson 11 | dispersion: 1.0 12 | seed: 0 13 | batch_size: 250 14 | num_workers: 4 15 | provide_inputs: True 16 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/extensions/LFADS/posterior_sampling.yaml: -------------------------------------------------------------------------------- 1 | use_best_ckpt: True 2 | fn: 3 | _target_: ctd.data_modeling.extensions.LFADS.post_run.analysis.run_posterior_sampling 4 | filename: lfads_output.h5 5 | num_samples: 50 6 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/loggers/LFADS/default.yaml: -------------------------------------------------------------------------------- 1 | csv_logger: 2 | _target_: pytorch_lightning.loggers.CSVLogger 3 | save_dir: "csv_logs" 4 | version: "" 5 | name: "" 6 | tensorboard_logger: 7 | _target_: pytorch_lightning.loggers.TensorBoardLogger 8 | save_dir: "." 9 | version: "" 10 | name: "" 11 | 12 | wandb_logger: 13 | _target_: pytorch_lightning.loggers.WandbLogger 14 | save_dir: "." 15 | version: "" 16 | name: "" 17 | project: "data-trained" 18 | group: "" 19 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/loggers/LFADS/default_no_wandb.yaml: -------------------------------------------------------------------------------- 1 | csv_logger: 2 | _target_: pytorch_lightning.loggers.CSVLogger 3 | save_dir: "csv_logs" 4 | version: "" 5 | name: "" 6 | tensorboard_logger: 7 | _target_: pytorch_lightning.loggers.TensorBoardLogger 8 | save_dir: "." 9 | version: "" 10 | name: "" 11 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/loggers/SAE/default.yaml: -------------------------------------------------------------------------------- 1 | tensorboard_logger: 2 | _target_: pytorch_lightning.loggers.TensorBoardLogger 3 | save_dir: "." 4 | version: "" 5 | name: "" 6 | csv_logger: 7 | _target_: pytorch_lightning.loggers.CSVLogger 8 | save_dir: "." 9 | version: "" 10 | name: "" 11 | wandb_logger: 12 | _target_: pytorch_lightning.loggers.WandbLogger 13 | save_dir: "." 14 | version: "" 15 | name: "" 16 | project: "data-trained" 17 | group: "" 18 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/loggers/SAE/default_no_wandb.yaml: -------------------------------------------------------------------------------- 1 | tensorboard_logger: 2 | _target_: pytorch_lightning.loggers.TensorBoardLogger 3 | save_dir: "." 4 | version: "" 5 | name: "" 6 | csv_logger: 7 | _target_: pytorch_lightning.loggers.CSVLogger 8 | save_dir: "." 
9 | version: "" 10 | name: "" 11 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/LFADS/MultiTask/MultiTask_LFADS.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.LFADS.lfads.LFADS 2 | 3 | # --------- architecture --------- # 4 | gen_type: RNN 5 | inv_encoder: False 6 | encod_data_dim: 50 7 | encod_seq_len: 320 8 | recon_data_dim: 60 9 | recon_seq_len: 320 10 | ext_input_dim: 20 # Ext. Inputs 11 | ic_enc_seq_len: 0 12 | ic_enc_dim: 128 # Encoder for latent ICs hidden units 13 | ci_enc_dim: 0 # Controller encoder dimensionality 14 | ci_lag: 1 15 | con_dim: 0 # Hidden size of controller 16 | co_dim: 0 # # of controller inputs 17 | ic_dim: 128 # # neurons if Flow_inv, gen_dim if not 18 | gen_dim: 128 19 | fac_dim: 20 20 | 21 | # --------- readin / readout --------- # 22 | readin: 23 | - _target_: torch.nn.Identity 24 | readout: 25 | - _target_: ctd.data_modeling.models.LFADS.modules.readin_readout.FanInLinear 26 | in_features: ${fac_dim} 27 | out_features: 60 28 | # readout: 29 | # - _target_: ctd.data_modeling.models.LFADS.modules.readin_readout.Flow 30 | # in_features: ${fac_dim} 31 | # out_features: 50 32 | # readout_num_layers: 3 33 | # readout_hidden_size: 128 34 | # flow_num_steps: 25 35 | # --------- augmentation --------- # 36 | train_aug_stack: 37 | _target_: ctd.data_modeling.models.LFADS.modules.augmentations.AugmentationStack 38 | transforms: 39 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.CoordinatedDropout 40 | cd_rate: 0.3 41 | cd_pass_rate: 0.0 42 | ic_enc_seq_len: ${ic_enc_seq_len} 43 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.MultiTaskTrialLengthMasking 44 | batch_order: [0, 1] 45 | loss_order: [0, 1] 46 | infer_aug_stack: 47 | _target_: ctd.data_modeling.models.LFADS.modules.augmentations.AugmentationStack 48 | transforms: 49 | # Ignore NaNs for heldout data in test-phase validation loss 50 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.IgnoreNaNLoss 51 | encod_data_dim: ${encod_data_dim} 52 | encod_seq_len: ${encod_seq_len} 53 | scale_by_quadrant: False 54 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.MultiTaskTrialLengthMasking 55 | loss_order: [0, 1] 56 | 57 | # --------- priors / posteriors --------- # 58 | reconstruction: 59 | - _target_: ctd.data_modeling.models.LFADS.modules.recons.Poisson 60 | variational: True 61 | 62 | # Autoregressive input prior 63 | co_prior: 64 | _target_: ctd.data_modeling.models.LFADS.modules.priors.AutoregressiveMultivariateNormal 65 | tau: 10.0 66 | nvar: 0.1 67 | shape: ${co_dim} 68 | 69 | ic_prior: 70 | _target_: ctd.data_modeling.models.LFADS.modules.priors.MultivariateNormal 71 | mean: 0 72 | variance: 0.1 73 | shape: ${ic_dim} 74 | ic_post_var_min: 1.0e-4 75 | 76 | # --------- misc --------- # 77 | dropout_rate: 0.02 # sampled 78 | cell_clip: 5.0 79 | loss_scale: 1.0e+4 80 | recon_reduce_mean: True 81 | 82 | # --------- learning rate --------- # 83 | lr_init: 1.0e-3 84 | lr_stop: 1.0e-5 85 | lr_decay: 0.95 86 | lr_patience: 6 87 | lr_adam_beta1: 0.9 88 | lr_adam_beta2: 0.999 89 | lr_adam_epsilon: 1.0e-7 90 | lr_scheduler: True 91 | 92 | # --------- regularization --------- # 93 | weight_decay: 1.0e-5 94 | l2_start_epoch: 0 95 | l2_increase_epoch: 80 96 | l2_ic_enc_scale: 0.0 97 | l2_ci_enc_scale: 0.0 98 | l2_gen_scale: 0.0 # sampled 99 | l2_con_scale: 0.0 # sampled 100 | l2_readout_scale: 0 101 | 
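# Note (assuming the usual LFADS ramping scheme): the l2_* and kl_* penalty
# weights ramp in linearly, starting at *_start_epoch and reaching full
# strength over *_increase_epoch epochs.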
kl_start_epoch: 0 102 | kl_increase_epoch: 80 103 | kl_ic_scale: 1.0e-6 # sampled 104 | kl_co_scale: 1.0e-5 # sampled 105 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/LFADS/MultiTask/MultiTask_LFADS_infer.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.LFADS.lfads.LFADS 2 | 3 | # --------- architecture --------- # 4 | gen_type: RNN 5 | inv_encoder: False 6 | encod_data_dim: 50 7 | encod_seq_len: 320 8 | recon_data_dim: 60 9 | recon_seq_len: 320 10 | ext_input_dim: 0 # Ext. Inputs 11 | ic_enc_seq_len: 0 12 | ic_enc_dim: 128 # Encoder for latent ICs hidden units 13 | ci_enc_dim: 128 # Controller encoder dimensionality 14 | ci_lag: 1 15 | con_dim: 128 # Hidden size of controller 16 | co_dim: 15 # # of controller inputs 17 | ic_dim: 128 # # neurons if Flow_inv, gen_dim if not 18 | gen_dim: 128 19 | fac_dim: 20 20 | 21 | # --------- readin / readout --------- # 22 | readin: 23 | - _target_: torch.nn.Identity 24 | readout: 25 | - _target_: ctd.data_modeling.models.LFADS.modules.readin_readout.FanInLinear 26 | in_features: ${fac_dim} 27 | out_features: 60 28 | # readout: 29 | # - _target_: ctd.data_modeling.models.LFADS.modules.readin_readout.Flow 30 | # in_features: ${fac_dim} 31 | # out_features: 50 32 | # readout_num_layers: 3 33 | # readout_hidden_size: 128 34 | # flow_num_steps: 25 35 | # --------- augmentation --------- # 36 | train_aug_stack: 37 | _target_: ctd.data_modeling.models.LFADS.modules.augmentations.AugmentationStack 38 | transforms: 39 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.CoordinatedDropout 40 | cd_rate: 0.3 41 | cd_pass_rate: 0.0 42 | ic_enc_seq_len: ${ic_enc_seq_len} 43 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.MultiTaskTrialLengthMasking 44 | batch_order: [0] 45 | loss_order: [0] 46 | infer_aug_stack: 47 | _target_: ctd.data_modeling.models.LFADS.modules.augmentations.AugmentationStack 48 | transforms: 49 | # Ignore NaNs for heldout data in test-phase validation loss 50 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.IgnoreNaNLoss 51 | encod_data_dim: ${encod_data_dim} 52 | encod_seq_len: ${encod_seq_len} 53 | scale_by_quadrant: False 54 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.MultiTaskTrialLengthMasking 55 | loss_order: [0] 56 | 57 | # --------- priors / posteriors --------- # 58 | reconstruction: 59 | - _target_: ctd.data_modeling.models.LFADS.modules.recons.Poisson 60 | variational: True 61 | 62 | # Autoregressive input prior 63 | co_prior: 64 | _target_: ctd.data_modeling.models.LFADS.modules.priors.AutoregressiveMultivariateNormal 65 | tau: 10.0 66 | nvar: 0.1 67 | shape: ${co_dim} 68 | 69 | ic_prior: 70 | _target_: ctd.data_modeling.models.LFADS.modules.priors.MultivariateNormal 71 | mean: 0 72 | variance: 0.1 73 | shape: ${ic_dim} 74 | ic_post_var_min: 1.0e-4 75 | 76 | # --------- misc --------- # 77 | dropout_rate: 0.02 # sampled 78 | cell_clip: 5.0 79 | loss_scale: 1.0e+4 80 | recon_reduce_mean: True 81 | 82 | # --------- learning rate --------- # 83 | lr_init: 1.0e-3 84 | lr_stop: 1.0e-5 85 | lr_decay: 0.95 86 | lr_patience: 6 87 | lr_adam_beta1: 0.9 88 | lr_adam_beta2: 0.999 89 | lr_adam_epsilon: 1.0e-7 90 | lr_scheduler: True 91 | 92 | # --------- regularization --------- # 93 | weight_decay: 1.0e-5 94 | l2_start_epoch: 0 95 | l2_increase_epoch: 80 96 | l2_ic_enc_scale: 0.0 97 | l2_ci_enc_scale: 0.0 98 | l2_gen_scale: 
0.0 # sampled 99 | l2_con_scale: 0.0 # sampled 100 | l2_readout_scale: 0 101 | kl_start_epoch: 0 102 | kl_increase_epoch: 80 103 | kl_ic_scale: 1.0e-6 # sampled 104 | kl_co_scale: 1.0e-5 # sampled 105 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/LFADS/NBFF/NBFF_LFADS.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.LFADS.lfads.LFADS 2 | 3 | # --------- architecture --------- # 4 | gen_type: RNN 5 | inv_encoder: False 6 | encod_data_dim: 50 7 | encod_seq_len: 500 8 | recon_data_dim: 60 9 | recon_seq_len: 500 10 | ext_input_dim: 3 # Ext. Inputs 11 | ic_enc_seq_len: 0 12 | ic_enc_dim: 128 # Encoder for latent ICs hidden units 13 | ci_enc_dim: 0 # Controller encoder dimensionality 14 | ci_lag: 1 15 | con_dim: 0 # Hidden size of controller 16 | co_dim: 0 # # of controller inputs 17 | ic_dim: 128 # # neurons if Flow_inv, gen_dim if not 18 | gen_dim: 128 19 | fac_dim: 128 20 | 21 | # --------- readin / readout --------- # 22 | readin: 23 | - _target_: torch.nn.Identity 24 | readout: 25 | - _target_: ctd.data_modeling.models.LFADS.modules.readin_readout.FanInLinear 26 | in_features: ${fac_dim} 27 | out_features: ${recon_data_dim} 28 | # readout: 29 | # - _target_: ctd.data_modeling.models.LFADS.modules.readin_readout.Flow 30 | # in_features: ${fac_dim} 31 | # out_features: 50 32 | # readout_num_layers: 3 33 | # readout_hidden_size: 128 34 | # flow_num_steps: 25 35 | # --------- augmentation --------- # 36 | train_aug_stack: 37 | _target_: ctd.data_modeling.models.LFADS.modules.augmentations.AugmentationStack 38 | transforms: 39 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.CoordinatedDropout 40 | cd_rate: 0.3 41 | cd_pass_rate: 0.0 42 | ic_enc_seq_len: ${ic_enc_seq_len} 43 | batch_order: [0] 44 | loss_order: [0] 45 | infer_aug_stack: 46 | _target_: ctd.data_modeling.models.LFADS.modules.augmentations.AugmentationStack 47 | transforms: 48 | # Ignore NaNs for heldout data in test-phase validation loss 49 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.IgnoreNaNLoss 50 | encod_data_dim: ${encod_data_dim} 51 | encod_seq_len: ${encod_seq_len} 52 | scale_by_quadrant: False 53 | loss_order: [0] 54 | 55 | # --------- priors / posteriors --------- # 56 | reconstruction: 57 | - _target_: ctd.data_modeling.models.LFADS.modules.recons.Poisson 58 | variational: True 59 | 60 | # Autoregressive input prior 61 | co_prior: 62 | _target_: ctd.data_modeling.models.LFADS.modules.priors.AutoregressiveMultivariateNormal 63 | tau: 10.0 64 | nvar: 0.1 65 | shape: ${co_dim} 66 | 67 | ic_prior: 68 | _target_: ctd.data_modeling.models.LFADS.modules.priors.MultivariateNormal 69 | mean: 0 70 | variance: 0.1 71 | shape: ${ic_dim} 72 | ic_post_var_min: 1.0e-4 73 | 74 | # --------- misc --------- # 75 | dropout_rate: 0.02 # sampled 76 | cell_clip: 5.0 77 | loss_scale: 1.0e+4 78 | recon_reduce_mean: True 79 | 80 | # --------- learning rate --------- # 81 | lr_init: 5.0e-3 82 | lr_stop: 1.0e-5 83 | lr_decay: 0.95 84 | lr_patience: 6 85 | lr_adam_beta1: 0.9 86 | lr_adam_beta2: 0.999 87 | lr_adam_epsilon: 1.0e-7 88 | lr_scheduler: True 89 | 90 | # --------- regularization --------- # 91 | weight_decay: 1.0e-8 92 | l2_start_epoch: 0 93 | l2_increase_epoch: 80 94 | l2_ic_enc_scale: 0.0 95 | l2_ci_enc_scale: 0.0 96 | l2_gen_scale: 0.0 # sampled 97 | l2_con_scale: 0.0 # sampled 98 | l2_readout_scale: 0 99 | kl_start_epoch: 0 100 | kl_increase_epoch: 
80 101 | kl_ic_scale: 1.0e-8 # sampled 102 | kl_co_scale: 1.0e-8 # sampled 103 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/LFADS/NBFF/NBFF_LFADS_infer.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.LFADS.lfads.LFADS 2 | 3 | # --------- architecture --------- # 4 | gen_type: RNN 5 | inv_encoder: False 6 | encod_data_dim: 50 7 | encod_seq_len: 500 8 | recon_seq_len: 500 9 | recon_data_dim: 60 10 | ext_input_dim: 0 # Ext. Inputs 11 | ic_enc_seq_len: 0 12 | ic_enc_dim: 128 # Encoder for latent ICs hidden units 13 | ci_enc_dim: 128 # Controller encoder dimensionality 14 | ci_lag: 1 15 | con_dim: 128 # Hidden size of controller 16 | co_dim: 3 # # of controller inputs 17 | ic_dim: 128 # # neurons if Flow_inv, gen_dim if not 18 | gen_dim: 128 19 | fac_dim: 20 20 | 21 | # --------- readin / readout --------- # 22 | readin: 23 | - _target_: torch.nn.Identity 24 | readout: 25 | - _target_: ctd.data_modeling.models.LFADS.modules.readin_readout.FanInLinear 26 | in_features: ${fac_dim} 27 | out_features: ${recon_data_dim} 28 | # readout: 29 | # - _target_: ctd.data_modeling.models.LFADS.modules.readin_readout.Flow 30 | # in_features: ${fac_dim} 31 | # out_features: 50 32 | # readout_num_layers: 3 33 | # readout_hidden_size: 128 34 | # flow_num_steps: 25 35 | # --------- augmentation --------- # 36 | train_aug_stack: 37 | _target_: ctd.data_modeling.models.LFADS.modules.augmentations.AugmentationStack 38 | transforms: 39 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.CoordinatedDropout 40 | cd_rate: 0.3 41 | cd_pass_rate: 0.0 42 | ic_enc_seq_len: ${ic_enc_seq_len} 43 | batch_order: [0] 44 | loss_order: [0] 45 | infer_aug_stack: 46 | _target_: ctd.data_modeling.models.LFADS.modules.augmentations.AugmentationStack 47 | transforms: 48 | # Ignore NaNs for heldout data in test-phase validation loss 49 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.IgnoreNaNLoss 50 | encod_data_dim: ${encod_data_dim} 51 | encod_seq_len: ${encod_seq_len} 52 | scale_by_quadrant: False 53 | loss_order: [0] 54 | 55 | # --------- priors / posteriors --------- # 56 | reconstruction: 57 | - _target_: ctd.data_modeling.models.LFADS.modules.recons.Poisson 58 | variational: True 59 | 60 | # Autoregressive input prior 61 | # co_prior: 62 | # _target_: ctd.data_modeling.models.LFADS.modules.priors.AutoregressiveMultivariateNormal 63 | # tau: 10.0 64 | # nvar: 0.1 65 | # shape: ${co_dim} 66 | 67 | # MultivariateStudentT 68 | co_prior: 69 | _target_: ctd.data_modeling.models.LFADS.modules.priors.MultivariateStudentT 70 | loc: 0 71 | scale: 1.0 72 | shape: ${co_dim} 73 | df: 3 74 | 75 | ic_prior: 76 | _target_: ctd.data_modeling.models.LFADS.modules.priors.MultivariateNormal 77 | mean: 0 78 | variance: 0.1 79 | shape: ${ic_dim} 80 | ic_post_var_min: 1.0e-4 81 | 82 | # --------- misc --------- # 83 | dropout_rate: 0.02 # sampled 84 | cell_clip: 5.0 85 | loss_scale: 1.0e+4 86 | recon_reduce_mean: True 87 | 88 | # --------- learning rate --------- # 89 | lr_init: 1.0e-3 90 | lr_stop: 1.0e-5 91 | lr_decay: 0.95 92 | lr_patience: 6 93 | lr_adam_beta1: 0.9 94 | lr_adam_beta2: 0.999 95 | lr_adam_epsilon: 1.0e-7 96 | lr_scheduler: True 97 | 98 | # --------- regularization --------- # 99 | weight_decay: 1.0e-5 100 | l2_start_epoch: 0 101 | l2_increase_epoch: 80 102 | l2_ic_enc_scale: 0.0 103 | l2_ci_enc_scale: 0.0 104 | l2_gen_scale: 0.0 # sampled 105 | 
l2_con_scale: 0.0 # sampled 106 | l2_readout_scale: 0 107 | kl_start_epoch: 0 108 | kl_increase_epoch: 80 109 | kl_ic_scale: 1.0e-6 # sampled 110 | kl_co_scale: 1.0e-5 # sampled 111 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/LFADS/RandomTarget/RandomTarget_LFADS.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.LFADS.lfads.LFADS 2 | 3 | # --------- architecture --------- # 4 | gen_type: RNN 5 | inv_encoder: False 6 | encod_data_dim: 50 7 | encod_seq_len: 150 8 | recon_seq_len: 150 9 | recon_data_dim: 60 10 | ext_input_dim: 17 # Ext. Inputs 11 | ic_enc_seq_len: 0 12 | ic_enc_dim: 128 # Encoder for latent ICs hidden units 13 | ci_enc_dim: 0 # Controller encoder dimensionality 14 | ci_lag: 1 15 | con_dim: 0 # Hidden size of controller 16 | co_dim: 0 # # of controller inputs 17 | ic_dim: 128 # # neurons if Flow_inv, gen_dim if not 18 | gen_dim: 128 19 | fac_dim: 20 20 | 21 | # --------- readin / readout --------- # 22 | readin: 23 | - _target_: torch.nn.Identity 24 | readout: 25 | - _target_: ctd.data_modeling.models.LFADS.modules.readin_readout.FanInLinear 26 | in_features: ${fac_dim} 27 | out_features: ${recon_data_dim} 28 | # readout: 29 | # - _target_: ctd.data_modeling.models.LFADS.modules.readin_readout.Flow 30 | # in_features: ${fac_dim} 31 | # out_features: 50 32 | # readout_num_layers: 3 33 | # readout_hidden_size: 128 34 | # flow_num_steps: 25 35 | # --------- augmentation --------- # 36 | train_aug_stack: 37 | _target_: ctd.data_modeling.models.LFADS.modules.augmentations.AugmentationStack 38 | transforms: 39 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.CoordinatedDropout 40 | cd_rate: 0.3 41 | cd_pass_rate: 0.0 42 | ic_enc_seq_len: ${ic_enc_seq_len} 43 | batch_order: [0] 44 | loss_order: [0] 45 | infer_aug_stack: 46 | _target_: ctd.data_modeling.models.LFADS.modules.augmentations.AugmentationStack 47 | transforms: 48 | # Ignore NaNs for heldout data in test-phase validation loss 49 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.IgnoreNaNLoss 50 | encod_data_dim: ${encod_data_dim} 51 | encod_seq_len: ${encod_seq_len} 52 | scale_by_quadrant: False 53 | loss_order: [0] 54 | 55 | # --------- priors / posteriors --------- # 56 | reconstruction: 57 | - _target_: ctd.data_modeling.models.LFADS.modules.recons.Poisson 58 | variational: True 59 | 60 | # Sparse input prior 61 | 62 | # Autoregressive input prior 63 | co_prior: 64 | _target_: ctd.data_modeling.models.LFADS.modules.priors.AutoregressiveMultivariateNormal 65 | tau: 10.0 66 | nvar: 0.1 67 | shape: ${co_dim} 68 | 69 | ic_prior: 70 | _target_: ctd.data_modeling.models.LFADS.modules.priors.MultivariateNormal 71 | mean: 0 72 | variance: 0.1 73 | shape: ${ic_dim} 74 | ic_post_var_min: 1.0e-4 75 | 76 | # --------- misc --------- # 77 | dropout_rate: 0.02 # sampled 78 | cell_clip: 5.0 79 | loss_scale: 1.0e+4 80 | recon_reduce_mean: True 81 | 82 | # --------- learning rate --------- # 83 | lr_init: 1.0e-3 84 | lr_stop: 1.0e-5 85 | lr_decay: 0.95 86 | lr_patience: 6 87 | lr_adam_beta1: 0.9 88 | lr_adam_beta2: 0.999 89 | lr_adam_epsilon: 1.0e-7 90 | lr_scheduler: True 91 | 92 | # --------- regularization --------- # 93 | weight_decay: 1.0e-5 94 | l2_start_epoch: 0 95 | l2_increase_epoch: 80 96 | l2_ic_enc_scale: 0.0 97 | l2_ci_enc_scale: 0.0 98 | l2_gen_scale: 0.0 # sampled 99 | l2_con_scale: 0.0 # sampled 100 | l2_readout_scale: 0 101 | 
kl_start_epoch: 0 102 | kl_increase_epoch: 80 103 | kl_ic_scale: 1.0e-6 # sampled 104 | kl_co_scale: 1.0e-5 # sampled 105 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/LFADS/RandomTarget/RandomTarget_LFADS_infer.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.LFADS.lfads.LFADS 2 | 3 | # --------- architecture --------- # 4 | gen_type: RNN 5 | inv_encoder: False 6 | encod_data_dim: 50 7 | encod_seq_len: 150 8 | recon_data_dim: 60 9 | recon_seq_len: 150 10 | ext_input_dim: 0 # Ext. Inputs 11 | ic_enc_seq_len: 0 12 | ic_enc_dim: 128 # Encoder for latent ICs hidden units 13 | ci_enc_dim: 128 # Controller encoder dimensionality 14 | ci_lag: 1 15 | con_dim: 128 # Hidden size of controller 16 | co_dim: 17 # # of controller inputs 17 | ic_dim: 128 # # neurons if Flow_inv, gen_dim if not 18 | gen_dim: 128 19 | fac_dim: 64 20 | 21 | # --------- readin / readout --------- # 22 | readin: 23 | - _target_: torch.nn.Identity 24 | readout: 25 | - _target_: ctd.data_modeling.models.LFADS.modules.readin_readout.FanInLinear 26 | in_features: ${fac_dim} 27 | out_features: ${recon_data_dim} 28 | # readout: 29 | # - _target_: ctd.data_modeling.models.LFADS.modules.readin_readout.Flow 30 | # in_features: ${fac_dim} 31 | # out_features: 50 32 | # readout_num_layers: 3 33 | # readout_hidden_size: 128 34 | # flow_num_steps: 25 35 | # --------- augmentation --------- # 36 | train_aug_stack: 37 | _target_: ctd.data_modeling.models.LFADS.modules.augmentations.AugmentationStack 38 | transforms: 39 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.CoordinatedDropout 40 | cd_rate: 0.3 41 | cd_pass_rate: 0.0 42 | ic_enc_seq_len: ${ic_enc_seq_len} 43 | batch_order: [0] 44 | loss_order: [0] 45 | infer_aug_stack: 46 | _target_: ctd.data_modeling.models.LFADS.modules.augmentations.AugmentationStack 47 | transforms: 48 | # Ignore NaNs for heldout data in test-phase validation loss 49 | - _target_: ctd.data_modeling.models.LFADS.modules.augmentations.IgnoreNaNLoss 50 | encod_data_dim: ${encod_data_dim} 51 | encod_seq_len: ${encod_seq_len} 52 | scale_by_quadrant: False 53 | loss_order: [0] 54 | 55 | # --------- priors / posteriors --------- # 56 | reconstruction: 57 | - _target_: ctd.data_modeling.models.LFADS.modules.recons.Poisson 58 | variational: True 59 | 60 | # Autoregressive input prior 61 | co_prior: 62 | _target_: ctd.data_modeling.models.LFADS.modules.priors.AutoregressiveMultivariateNormal 63 | tau: 10.0 64 | nvar: 0.1 65 | shape: ${co_dim} 66 | 67 | ic_prior: 68 | _target_: ctd.data_modeling.models.LFADS.modules.priors.MultivariateNormal 69 | mean: 0 70 | variance: 0.1 71 | shape: ${ic_dim} 72 | ic_post_var_min: 1.0e-4 73 | 74 | # --------- misc --------- # 75 | dropout_rate: 0.02 # sampled 76 | cell_clip: 5.0 77 | loss_scale: 1.0e+4 78 | recon_reduce_mean: True 79 | 80 | # --------- learning rate --------- # 81 | lr_init: 1.0e-3 82 | lr_stop: 1.0e-5 83 | lr_decay: 0.95 84 | lr_patience: 6 85 | lr_adam_beta1: 0.9 86 | lr_adam_beta2: 0.999 87 | lr_adam_epsilon: 1.0e-7 88 | lr_scheduler: True 89 | 90 | # --------- regularization --------- # 91 | weight_decay: 1.0e-5 92 | l2_start_epoch: 0 93 | l2_increase_epoch: 80 94 | l2_ic_enc_scale: 0.0 95 | l2_ci_enc_scale: 0.0 96 | l2_gen_scale: 0.0 # sampled 97 | l2_con_scale: 0.0 # sampled 98 | l2_readout_scale: 0 99 | kl_start_epoch: 0 100 | kl_increase_epoch: 80 101 | kl_ic_scale: 1.0e-6 # sampled 102 | 
kl_co_scale: 1.0e-5 # sampled 103 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/SAE/MultiTask/MultiTask_GRU_RNN.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.SAE.gru_rnn.GRULatentSAE 2 | dataset: MultiTask 3 | 4 | encoder_size: 100 5 | encoder_window: -1 6 | 7 | heldin_size: 50 8 | heldout_size: 60 9 | latent_size: 128 10 | input_size: 20 11 | 12 | weight_decay: 0 13 | lr: 5e-3 14 | 15 | dropout: 0.05 16 | 17 | loss_func: 18 | _target_: ctd.data_modeling.models.SAE.loss_func.MultiTaskPoissonLossFunc 19 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/SAE/MultiTask/MultiTask_LDS.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.SAE.lds.LDSSAE 2 | dataset: MultiTask 3 | 4 | encoder_size: 100 5 | encoder_window: -1 6 | 7 | heldin_size: 50 8 | heldout_size: 60 9 | latent_size: 128 10 | input_size: 20 11 | 12 | weight_decay: 0 13 | lr: 1e-4 14 | 15 | dropout: 0.05 16 | loss_func: 17 | _target_: ctd.data_modeling.models.SAE.loss_func.MultiTaskPoissonLossFunc 18 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/SAE/MultiTask/MultiTask_NODE.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.SAE.node.NODELatentSAE 2 | dataset: MultiTask 3 | 4 | encoder_size: 100 5 | encoder_window: -1 6 | 7 | heldin_size: 50 8 | heldout_size: 60 9 | latent_size: 20 10 | input_size: 20 11 | 12 | dropout: 0.05 13 | 14 | vf_hidden_size: 128 15 | vf_num_layers: 6 16 | 17 | lr: 2e-3 18 | 19 | weight_decay: 1e-7 20 | 21 | loss_func: 22 | _target_: ctd.data_modeling.models.SAE.loss_func.MultiTaskPoissonLossFunc 23 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/SAE/MultiTask/MultiTask_Vanilla_RNN.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.SAE.vanilla_rnn.RNNLatentSAE 2 | dataset: MultiTask 3 | 4 | encoder_size: 100 5 | encoder_window: -1 6 | 7 | heldin_size: 50 8 | heldout_size: 60 9 | latent_size: 128 10 | input_size: 20 11 | 12 | weight_decay: 0 13 | lr: 5e-3 14 | 15 | dropout: 0.05 16 | 17 | loss_func: 18 | _target_: ctd.data_modeling.models.SAE.loss_func.MultiTaskPoissonLossFunc 19 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/SAE/NBFF/NBFF_GRU_RNN.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.SAE.gru_rnn.GRULatentSAE 2 | dataset: NBFF 3 | 4 | encoder_size: 100 5 | encoder_window: -1 6 | 7 | heldin_size: 50 8 | heldout_size: 60 9 | latent_size: 128 10 | input_size: 3 11 | 12 | weight_decay: 0 13 | lr: 2e-3 14 | 15 | dropout: 0.05 16 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/SAE/NBFF/NBFF_LDS.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.SAE.lds.LDSSAE 2 | dataset: NBFF 3 | 4 | encoder_size: 100 5 | encoder_window: -1 6 | 7 | heldin_size: 50 8 | heldout_size: 60 9 | latent_size: 128 10 | input_size: 3 11 | 12 | weight_decay: 0 13 | lr: 1e-4 
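# NOTE: the SAE model configs in this group are presumably consumed with
# hydra.utils.instantiate, so loading this file is roughly equivalent to the
# constructor call below (a sketch only; field meanings are inferred from the
# names and not verified against the model code):
#   LDSSAE(dataset="NBFF", encoder_size=100, encoder_window=-1,
#          heldin_size=50, heldout_size=60, latent_size=128,
#          input_size=3, weight_decay=0, lr=1e-4, dropout=0.05)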
14 | 15 | dropout: 0.05 16 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/SAE/NBFF/NBFF_NODE.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.SAE.node.NODELatentSAE 2 | dataset: NBFF 3 | 4 | encoder_size: 100 5 | encoder_window: -1 6 | 7 | heldin_size: 50 8 | heldout_size: 60 9 | latent_size: 10 10 | input_size: 3 11 | 12 | dropout: 0.05 13 | 14 | vf_hidden_size: 128 15 | vf_num_layers: 6 16 | 17 | lr: 2e-3 18 | 19 | weight_decay: 0 20 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/SAE/NBFF/NBFF_Vanilla_RNN.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.SAE.vanilla_rnn.RNNLatentSAE 2 | dataset: NBFF 3 | 4 | encoder_size: 100 5 | encoder_window: -1 6 | 7 | heldin_size: 50 8 | heldout_size: 60 9 | latent_size: 128 10 | input_size: 3 11 | 12 | weight_decay: 0 13 | lr: 5e-3 14 | 15 | dropout: 0.05 16 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/SAE/RandomTarget/RandomTarget_GRU_RNN.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.SAE.gru_rnn.GRULatentSAE 2 | dataset: RandomTargetDelay 3 | 4 | encoder_size: 100 5 | encoder_window: -1 6 | 7 | heldin_size: 50 8 | heldout_size: 60 9 | latent_size: 64 10 | input_size: 17 11 | 12 | weight_decay: 1e-6 13 | lr: 5e-3 14 | 15 | dropout: 0.2 16 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/SAE/RandomTarget/RandomTarget_LDS.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.SAE.lds.LDSSAE 2 | dataset: RandomTargetDelay 3 | 4 | encoder_size: 100 5 | encoder_window: -1 6 | 7 | heldin_size: 50 8 | heldout_size: 60 9 | latent_size: 128 10 | input_size: 17 11 | 12 | weight_decay: 0 13 | lr: 1e-4 14 | 15 | dropout: 0.05 16 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/SAE/RandomTarget/RandomTarget_NODE.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.SAE.node.NODELatentSAE 2 | dataset: RandomTarget 3 | 4 | encoder_size: 100 5 | encoder_window: -1 6 | 7 | heldin_size: 50 8 | heldout_size: 60 9 | latent_size: 15 10 | input_size: 17 11 | 12 | dropout: 0.05 13 | 14 | vf_hidden_size: 128 15 | vf_num_layers: 6 16 | 17 | inv_encoder: False 18 | flow_num_steps: 20 19 | readout_vf_hidden: 128 20 | readout_num_layers: 3 21 | 22 | lr_readout: 2e-3 23 | lr_encoder: 2e-3 24 | lr_decoder: 2e-3 25 | 26 | decay_readout: 0 27 | decay_encoder: 1e-7 28 | decay_decoder: 1e-7 29 | 30 | readout_type: Linear 31 | increment_trial: True 32 | points_per_group: 20 33 | epochs_per_group: 10 34 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/models/SAE/RandomTarget/RandomTarget_Vanilla_RNN.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.data_modeling.models.SAE.vanilla_rnn.RNNLatentSAE 2 | dataset: RandomTargetDelay 3 | 4 | encoder_size: 100 5 | encoder_window: -1 6 | 7 | heldin_size: 50 8 | heldout_size: 60 9 | latent_size: 128 10 | input_size: 17 11 | 12 | 
weight_decay: 0 13 | lr: 5e-3 14 | 15 | dropout: 0.05 16 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/trainers/LFADS/trainer_MultiTask.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | log_every_n_steps: 20 3 | max_epochs: 200 4 | gradient_clip_val: 0.01 5 | # Prevent console output from individual models 6 | enable_progress_bar: False 7 | # weights_summary: null 8 | #auto_scale_batch_size: True 9 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/trainers/LFADS/trainer_NBFF.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | log_every_n_steps: 20 3 | max_epochs: 1000 4 | gradient_clip_val: 0.01 5 | # Prevent console output from individual models 6 | enable_progress_bar: False 7 | # weights_summary: null 8 | #auto_scale_batch_size: True 9 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/trainers/LFADS/trainer_RandomTarget.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | log_every_n_steps: 20 3 | max_epochs: 200 4 | gradient_clip_val: 0.01 5 | # Prevent console output from individual models 6 | enable_progress_bar: False 7 | # weights_summary: null 8 | #auto_scale_batch_size: True 9 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/trainers/SAE/trainer_MultiTask.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | log_every_n_steps: 20 3 | max_epochs: 200 4 | gradient_clip_val: 0.01 5 | # Prevent console output from individual models 6 | enable_progress_bar: False 7 | # weights_summary: null 8 | #auto_scale_batch_size: True 9 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/trainers/SAE/trainer_NBFF.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | log_every_n_steps: 20 3 | max_epochs: 500 4 | gradient_clip_val: 0.01 5 | # Prevent console output from individual models 6 | enable_progress_bar: False 7 | # weights_summary: null 8 | #auto_scale_batch_size: True 9 | -------------------------------------------------------------------------------- /ctd/data_modeling/configs/trainers/SAE/trainer_RandomTarget.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | log_every_n_steps: 20 3 | max_epochs: 200 4 | gradient_clip_val: 0.01 5 | # Prevent console output from individual models 6 | enable_progress_bar: False 7 | # weights_summary: null 8 | #auto_scale_batch_size: True 9 | -------------------------------------------------------------------------------- /ctd/data_modeling/datamodules/LFADS/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/data_modeling/datamodules/LFADS/__init__.py -------------------------------------------------------------------------------- /ctd/data_modeling/datamodules/LFADS/tuples.py: 
-------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | SessionBatch = namedtuple( 4 | "SessionBatch", 5 | [ 6 | "encod_data", 7 | "recon_data", 8 | "ext_input", 9 | "truth", 10 | "sv_mask", 11 | ], 12 | ) 13 | 14 | SessionOutput = namedtuple( 15 | "SessionOutput", 16 | [ 17 | "output_params", 18 | "factors", 19 | "ic_mean", 20 | "ic_std", 21 | "co_means", 22 | "co_stds", 23 | "gen_states", 24 | "gen_init", 25 | "gen_inputs", 26 | "con_states", 27 | ], 28 | ) 29 | Batch = namedtuple( 30 | "Batch", 31 | [ 32 | "encod_data", 33 | "recon_data", 34 | "ext_input", 35 | "truth", 36 | "sv_mask", 37 | ], 38 | ) 39 | 40 | Output = namedtuple( 41 | "Output", 42 | [ 43 | "output_params", 44 | "factors", 45 | "ic_mean", 46 | "ic_std", 47 | "co_means", 48 | "co_stds", 49 | "gen_states", 50 | "gen_init", 51 | "gen_inputs", 52 | "con_states", 53 | ], 54 | ) 55 | -------------------------------------------------------------------------------- /ctd/data_modeling/datamodules/SAE/task_trained_data.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import dotenv 5 | import h5py 6 | import pytorch_lightning as pl 7 | import torch 8 | from torch.utils.data import DataLoader, TensorDataset 9 | 10 | logger = logging.getLogger(__name__) 11 | dotenv.load_dotenv(override=True) 12 | HOME_DIR = os.environ.get("HOME_DIR") 13 | 14 | 15 | def to_tensor(array): 16 | return torch.tensor(array, dtype=torch.float) 17 | 18 | 19 | class TaskTrainedRNNDataModule(pl.LightningDataModule): 20 | def __init__( 21 | self, 22 | embed_dict: dict, 23 | noise_dict: dict, 24 | neuron_dict: dict, 25 | prefix=None, 26 | system: str = "3BFF", 27 | seed: int = 0, 28 | batch_size: int = 64, 29 | num_workers: int = 2, 30 | provide_inputs: bool = True, 31 | file_index: int = 0, 32 | ): 33 | super().__init__() 34 | self.save_hyperparameters() 35 | self.neuron_dict = neuron_dict 36 | self.seed = seed 37 | self.noise_dict = noise_dict 38 | self.embed_dict = embed_dict 39 | self.data_dir = os.path.join(HOME_DIR, "content", "datasets", "dd") 40 | 41 | filedir = prefix 42 | fpath = os.path.join(self.data_dir, filedir) 43 | dirs = os.listdir(fpath) 44 | if file_index >= len(dirs): 45 | raise ValueError( 46 | f"File index {file_index} is out of range for directory {fpath}" 47 | ) 48 | else: 49 | run_folder = dirs[file_index] 50 | 51 | filename = ( 52 | f"heldin_{neuron_dict['n_heldin']}_heldout_{neuron_dict['n_heldout']}" 53 | ) 54 | if embed_dict["rect_func"] not in ["exp"]: 55 | for key, val in self.embed_dict.items(): 56 | filename += f"_{key}_{val}" 57 | 58 | if noise_dict["obs_noise"] not in ["poisson"]: 59 | for key, val in self.noise_dict.items(): 60 | filename += f"_{key}_{val}" 61 | 62 | filename += f"_seed_{seed}" 63 | 64 | self.run_folder = run_folder 65 | self.name = filename 66 | 67 | self.fpath = filedir 68 | self.system = system 69 | 70 | def prepare_data(self): 71 | filename = self.name 72 | fpath = os.path.join( 73 | self.data_dir, self.fpath, self.run_folder, filename + ".h5" 74 | ) 75 | if os.path.isfile(fpath): 76 | logger.info(f"Loading dataset {self.name}") 77 | return 78 | else: 79 | # throw an error here 80 | raise FileNotFoundError(f"Dataset {self.name} not found at {self.fpath}") 81 | 82 | def setup(self, stage=None): 83 | """ 84 | Attach data to the datamodule 85 | 86 | TODO: REVISE 87 | 88 | Args: 89 | stage (TODO: dtype) 90 | 91 | Returns: 92 | None 93 | """ 94 | 95 | # Load data 
arrays from file 96 | data_path = os.path.join( 97 | self.data_dir, self.fpath, self.run_folder, self.name + ".h5" 98 | ) 99 | with h5py.File(data_path, "r") as h5file: 100 | # Load the data 101 | train_data = to_tensor(h5file["train_encod_data"][()]) 102 | valid_data = to_tensor(h5file["valid_encod_data"][()]) 103 | 104 | train_recon_data = to_tensor(h5file["train_recon_data"][()]) 105 | valid_recon_data = to_tensor(h5file["valid_recon_data"][()]) 106 | # test_data = to_tensor(h5file["test_data"][()]) 107 | # Load the activity 108 | train_activity = to_tensor(h5file["train_activity"][()]) 109 | valid_activity = to_tensor(h5file["valid_activity"][()]) 110 | # test_activity = to_tensor(h5file["test_activity"][()]) 111 | # Load the latents 112 | train_latents = to_tensor(h5file["train_latents"][()]) 113 | valid_latents = to_tensor(h5file["valid_latents"][()]) 114 | # test_latents = to_tensor(h5file["test_latents"][()]) 115 | # Load the indices 116 | train_inds = to_tensor(h5file["train_inds"][()]) 117 | valid_inds = to_tensor(h5file["valid_inds"][()]) 118 | # test_inds = to_tensor(h5file["test_inds"][()]) 119 | # Load other parameters 120 | self.orig_mean = h5file["orig_mean"][()] 121 | self.orig_std = h5file["orig_std"][()] 122 | self.readout = h5file["readout"][()] 123 | 124 | train_inputs = h5file["train_inputs"][()] 125 | valid_inputs = h5file["valid_inputs"][()] 126 | 127 | train_extra = h5file["train_extra"][()] 128 | valid_extra = h5file["valid_extra"][()] 129 | 130 | # self.test_inputs = h5file["test_inputs"][()] 131 | 132 | train_inputs = to_tensor(train_inputs) 133 | valid_inputs = to_tensor(valid_inputs) 134 | 135 | train_extra = to_tensor(train_extra) 136 | valid_extra = to_tensor(valid_extra) 137 | 138 | if self.hparams.provide_inputs: 139 | # Store datasets 140 | self.train_ds = TensorDataset( 141 | train_data, 142 | train_recon_data, 143 | train_inputs, 144 | train_extra, 145 | train_latents, 146 | train_inds, 147 | train_activity, 148 | ) 149 | 150 | self.valid_ds = TensorDataset( 151 | valid_data, 152 | valid_recon_data, 153 | valid_inputs, 154 | valid_extra, 155 | valid_latents, 156 | valid_inds, 157 | valid_activity, 158 | ) 159 | else: 160 | self.train_ds = TensorDataset( 161 | train_data, 162 | train_recon_data, 163 | None, 164 | train_extra, 165 | train_latents, 166 | train_inds, 167 | train_activity, 168 | ) 169 | 170 | self.valid_ds = TensorDataset( 171 | valid_data, 172 | valid_recon_data, 173 | None, 174 | valid_extra, 175 | valid_latents, 176 | valid_inds, 177 | valid_activity, 178 | ) 179 | 180 | # self.test_ds = TensorDataset( 181 | # test_data, test_data, test_inputs, test_latents, test_inds, test_activity 182 | # ) 183 | 184 | def train_dataloader(self, shuffle=True): 185 | train_dl = DataLoader( 186 | self.train_ds, 187 | batch_size=self.hparams.batch_size, 188 | num_workers=self.hparams.num_workers, 189 | shuffle=shuffle, 190 | ) 191 | return train_dl 192 | 193 | def val_dataloader(self): 194 | valid_dl = DataLoader( 195 | self.valid_ds, 196 | batch_size=self.hparams.batch_size, 197 | num_workers=self.hparams.num_workers, 198 | ) 199 | return valid_dl 200 | -------------------------------------------------------------------------------- /ctd/data_modeling/datamodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/data_modeling/datamodules/__init__.py 
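Usage sketch: the `_target_` keys in the YAML configs above, together with LightningDataModules like `TaskTrainedRNNDataModule`, follow the standard Hydra + PyTorch Lightning recipe of loading a config from each group, building each component, and handing the pieces to the `Trainer`. The code below assumes `hydra.utils.instantiate` is the mechanism that resolves `_target_` (the actual wiring lives in `ctd/data_modeling/train_PTL.py`), and the `prefix` value is hypothetical; it must name a run folder with a dataset already generated under `$HOME_DIR/content/datasets/dd`.

from hydra.utils import instantiate
from omegaconf import OmegaConf

from ctd.data_modeling.datamodules.SAE.task_trained_data import TaskTrainedRNNDataModule

# Build the model and trainer from the config files shown above. `instantiate`
# constructs the class named by `_target_`, passing the remaining keys as
# keyword arguments and recursing into nested `_target_` blocks, like the
# `loss_func` entries in the MultiTask configs.
model = instantiate(OmegaConf.load("ctd/data_modeling/configs/models/SAE/NBFF/NBFF_GRU_RNN.yaml"))
trainer = instantiate(OmegaConf.load("ctd/data_modeling/configs/trainers/SAE/trainer_NBFF.yaml"))

# Build the datamodule directly. With rect_func "exp" and obs_noise "poisson",
# the filename it loads reduces to "heldin_50_heldout_60_seed_0.h5" inside the
# run folder found under `prefix`.
datamodule = TaskTrainedRNNDataModule(
    embed_dict={"rect_func": "exp"},
    noise_dict={"obs_noise": "poisson"},
    neuron_dict={"n_heldin": 50, "n_heldout": 60},
    prefix="tt_NBFF",  # hypothetical run-folder prefix
    system="3BFF",
    seed=0,
    batch_size=64,
)

trainer.fit(model, datamodule)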
-------------------------------------------------------------------------------- /ctd/data_modeling/datamodules/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def softplusActivation(module, input): 6 | return module.log(1 + module.exp(input)) 7 | 8 | 9 | def tanhActivation(module, input): 10 | return (module.tanh(input) + 1) / 2 11 | 12 | 13 | def sigmoidActivation(module, input): 14 | return 1 / (1 + module.exp(-1 * input)) 15 | 16 | 17 | def apply_data_warp(data): 18 | warp_functions = [tanhActivation, sigmoidActivation, softplusActivation] 19 | firingMax = [2, 4, 6, 8] 20 | numDims = data.shape[1] 21 | 22 | a = np.array(1) 23 | dataGen = type(a) == type(data) 24 | if dataGen: 25 | module = np 26 | else: 27 | module = torch 28 | 29 | for i in range(numDims): 30 | 31 | j = np.mod(i, len(warp_functions) * len(firingMax)) 32 | # print(f'Max firing {firingMax[np.mod(j, len(firingMax))]} 33 | # warp {warp_functions[int(np.floor((j)/(len(warp_functions)+1)))]}') 34 | data[:, i] = firingMax[np.mod(j, len(firingMax))] * warp_functions[ 35 | int(np.floor((j) / (len(warp_functions) + 1))) 36 | ](module, data[:, i]) 37 | 38 | return data 39 | 40 | 41 | def apply_data_warp_sigmoid(data): 42 | warp_functions = [sigmoidActivation, sigmoidActivation, sigmoidActivation] 43 | firingMax = [2, 2, 2, 2] 44 | numDims = data.shape[1] 45 | 46 | a = np.array(1) 47 | dataGen = type(a) == type(data) 48 | if dataGen: 49 | module = np 50 | else: 51 | module = torch 52 | 53 | for i in range(numDims): 54 | 55 | j = np.mod(i, len(warp_functions) * len(firingMax)) 56 | # print(f'Max firing {firingMax[np.mod(j, len(firingMax))]} 57 | # warp {warp_functions[int(np.floor((j)/(len(warp_functions)+1)))]}') 58 | data[:, i] = firingMax[np.mod(j, len(firingMax))] * warp_functions[ 59 | int(np.floor((j) / (len(warp_functions) + 1))) 60 | ](module, data[:, i]) 61 | 62 | return data 63 | 64 | 65 | def make_data_tag(dm_cfg): 66 | obs_dim = "" if "obs_dim" not in dm_cfg else dm_cfg.obs_dim 67 | obs_noise = "" if "obs_noise" not in dm_cfg else dm_cfg.obs_noise 68 | if "obs_noise_params" in dm_cfg: 69 | obs_noise_params = ",".join( 70 | [f"{k}={v}" for k, v in dm_cfg.obs_noise_params.items()] 71 | ) 72 | else: 73 | obs_noise_params = "" 74 | data_tag = ( 75 | f"{dm_cfg.system}{obs_dim}_" 76 | f"{dm_cfg.n_samples}S_" 77 | f"{dm_cfg.n_timesteps}T_" 78 | f"{dm_cfg.pts_per_period}P_" 79 | f"{dm_cfg.seed}seed" 80 | ) 81 | if obs_noise: 82 | data_tag += f"_{obs_noise}{obs_noise_params}" 83 | return data_tag 84 | 85 | 86 | def make_data_tag_multi_system(dm_cfg): 87 | if "obs_noise_params" in dm_cfg: 88 | obs_noise_params = ",".join( 89 | [f"{k}={v}" for k, v in dm_cfg.obs_noise_params.items()] 90 | ) 91 | else: 92 | obs_noise_params = "" 93 | data_tag = ( 94 | "MultiSystem_" 95 | f"{dm_cfg.n_samples}S_" 96 | f"{dm_cfg.n_timesteps}T_" 97 | f"{dm_cfg.pts_per_period}P_" 98 | f"{dm_cfg.seed}seed_" 99 | f"{dm_cfg.obs_noise}{obs_noise_params}" 100 | ) 101 | return data_tag 102 | 103 | 104 | def flatten(dictionary, level=[]): 105 | """Flattens a dictionary by placing '.' between levels. 106 | This function flattens a hierarchical dictionary by placing '.' 107 | between keys at various levels to create a single key for each 108 | value. It is used internally for converting the configuration 109 | dictionary to more convenient formats. Implementation was 110 | inspired by `this StackOverflow post 111 | `_. 
112 | Parameters 113 | ---------- 114 | dictionary : dict 115 | The hierarchical dictionary to be flattened. 116 | level : list, optional 117 | The list of parent keys to prepend to keys at this level, 118 | enabling recursive calls. By default, an empty list. 119 | Returns 120 | ------- 121 | dict 122 | The flattened dictionary. 123 | See Also 124 | -------- 125 | lfads_tf2.utils.unflatten : Performs the opposite of this operation. 126 | """ 127 | 128 | tmp_dict = {} 129 | for key, val in dictionary.items(): 130 | if type(val) == dict: 131 | tmp_dict.update(flatten(val, level + [key])) 132 | else: 133 | tmp_dict[".".join(level + [key])] = val 134 | return tmp_dict 135 | -------------------------------------------------------------------------------- /ctd/data_modeling/extensions/LFADS/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric 3 | 4 | 5 | def r2_score(preds, targets): 6 | if preds.ndim > 2: 7 | preds = preds.reshape(-1, preds.shape[-1]) 8 | if targets.ndim > 2: 9 | targets = targets.reshape(-1, targets.shape[-1]) 10 | target_mean = torch.mean(targets, dim=0) 11 | ss_tot = torch.sum((targets - target_mean) ** 2, dim=0) 12 | ss_res = torch.sum((targets - preds) ** 2, dim=0) 13 | return torch.mean(1 - ss_res / ss_tot) 14 | 15 | 16 | class ExpSmoothedMetric(Metric): 17 | """Averages within epochs and exponentially smooths between epochs.""" 18 | 19 | def __init__(self, coef=0.9, **kwargs): 20 | super().__init__(**kwargs) 21 | self.coef = coef 22 | # PTL will automatically `reset` these after each epoch 23 | self.add_state("value", default=torch.tensor(0.0)) 24 | self.add_state("count", default=torch.tensor(0)) 25 | # Previous value must be immune to `reset` 26 | self.prev = torch.tensor(float("nan")) 27 | 28 | def update(self, value, batch_size): 29 | self.value += value * batch_size 30 | self.count += batch_size 31 | 32 | def compute(self): 33 | curr = self.value / self.count 34 | if torch.isnan(self.prev): 35 | self.prev = curr 36 | smth = self.coef * self.prev + (1 - self.coef) * curr 37 | self.prev = smth 38 | return smth 39 | -------------------------------------------------------------------------------- /ctd/data_modeling/extensions/LFADS/post_run/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/data_modeling/extensions/LFADS/post_run/__init__.py -------------------------------------------------------------------------------- /ctd/data_modeling/extensions/LFADS/post_run/analysis.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import shutil 3 | from pathlib import Path 4 | 5 | import h5py 6 | import torch 7 | from tqdm import tqdm 8 | 9 | from ....datamodules.LFADS.tuples import SessionOutput 10 | from ..datamodules import reshuffle_train_valid 11 | from ..utils import send_batch_to_device, transpose_lists 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def run_posterior_sampling(model, datamodule, filename, num_samples=50): 17 | """Runs the model repeatedly to generate outputs for different samples 18 | of the posteriors. Averages these outputs and saves them to an output file. 19 | 20 | Parameters 21 | ---------- 22 | model : lfads_torch.models.base_model.LFADS 23 | A trained LFADS model. 
24 | datamodule : pytorch_lightning.LightningDataModule 25 | The `LightningDataModule` to pass through the `model`. 26 | filename : str 27 | The filename to use for saving output 28 | num_samples : int, optional 29 | The number of forward passes to average, by default 50 30 | """ 31 | # Convert filename to pathlib.Path for convenience 32 | filename = Path(filename) 33 | # Set up the dataloaders 34 | datamodule.setup() 35 | pred_dls = datamodule.predict_dataloader() 36 | 37 | # Function to run posterior sampling for a single session at a time 38 | def run_ps_batch(s, batch): 39 | # Move the batch to the model device 40 | batch = send_batch_to_device({s: batch}, model.device) 41 | # Repeatedly compute the model outputs for this batch 42 | for i in range(num_samples): 43 | # Perform the forward pass through the model 44 | output = model.predict_step(batch, None, sample_posteriors=True)[s] 45 | # Use running sum to save memory while averaging 46 | if i == 0: 47 | # Detach output from the graph to save memory on gradients 48 | sums = [o.detach() for o in output] 49 | else: 50 | sums = [s + o.detach() for s, o in zip(sums, output)] 51 | # Finish averaging by dividing by the total number of samples 52 | return [s / num_samples for s in sums] 53 | 54 | # Compute outputs for one session at a time 55 | for s, dataloaders in pred_dls.items(): 56 | # Give each session a unique file path 57 | sess_fname = f"{filename.stem}_sess{s}{filename.suffix}" 58 | # Copy data file for easy access to original data and indices 59 | dhps = datamodule.hparams 60 | if dhps.reshuffle_tv_seed is not None: 61 | # If the data was shuffled, shuffle it when copying 62 | with h5py.File(dhps.data_paths[s]) as h5file: 63 | data_dict = {k: v[()] for k, v in h5file.items()} 64 | data_dict = reshuffle_train_valid( 65 | data_dict, dhps.reshuffle_tv_seed, dhps.reshuffle_tv_ratio 66 | ) 67 | with h5py.File(sess_fname, "w") as h5file: 68 | for k, v in data_dict.items(): 69 | h5file.create_dataset(k, data=v) 70 | else: 71 | shutil.copyfile(datamodule.hparams.data_paths[s], sess_fname) 72 | for split in dataloaders.keys(): 73 | # Compute average model outputs for each session and then recombine batches 74 | logger.info(f"Running posterior sampling on Session {s} {split} data.") 75 | post_means = [run_ps_batch(s, batch) for batch in tqdm(dataloaders[split])] 76 | post_means = SessionOutput( 77 | *[torch.cat(o).cpu().numpy() for o in transpose_lists(post_means)] 78 | ) 79 | # Save the averages to the output file 80 | with h5py.File(sess_fname, mode="a") as h5file: 81 | for name in SessionOutput._fields: 82 | h5file.create_dataset( 83 | f"{split}_{name}", data=getattr(post_means, name) 84 | ) 85 | # Log message about successful completion 86 | logger.info(f"Session {s} posterior means successfully saved to `{sess_fname}`") 87 | -------------------------------------------------------------------------------- /ctd/data_modeling/extensions/LFADS/post_run/pbt.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | 8 | 9 | def read_pbt_hps(pbt_dir): 10 | # Get the initial values from the results files 11 | result_files = glob.glob(os.path.join(pbt_dir, "run_model_*/result.json")) 12 | 13 | def get_first_result(fpath): 14 | with open(fpath, "r") as file: 15 | result = json.loads(file.readline()) 16 | return result 17 | 18 | inits = pd.DataFrame([get_first_result(fpath) for fpath in result_files]) 19 | # 
Get the perturbations from `pbt_global.txt` 20 | pbt_global_path = os.path.join(pbt_dir, "pbt_global.txt") 21 | with open(pbt_global_path, "r") as file: 22 | perturbs = [json.loads(line) for line in file.read().splitlines()] 23 | perturbs = pd.DataFrame( 24 | perturbs, 25 | columns=[ 26 | "target_tag", 27 | "clone_tag", 28 | "target_iteration", 29 | "cur_epoch", 30 | "old_config", 31 | "config", 32 | ], 33 | ) 34 | # Use trial_num to match initial values to perturbations 35 | inits["trial_num"] = inits.trial_id.apply(lambda x: int(x.split("_")[1])) 36 | perturbs["trial_num"] = perturbs.target_tag.apply(lambda x: int(x.split("_")[0])) 37 | # Combine initial values and perturbations 38 | hps_df = pd.concat([inits, perturbs]).reset_index() 39 | hps_df = hps_df[["trial_num", "cur_epoch", "config"]] 40 | # Expand the config dictionary into separate columns and recombine 41 | configs = pd.json_normalize(hps_df.pop("config")) 42 | hps_df = pd.concat([hps_df, configs], axis=1) 43 | return hps_df 44 | 45 | 46 | def plot_pbt_hps(pbt_dir, plot_field, save_dir=None, **kwargs): 47 | """Plots an HP for all models over the course of PBT. 48 | This function generates a plot to visualize how an HP 49 | changes over the course of PBT. 50 | Parameters 51 | ---------- 52 | pbt_dir : str 53 | The path to the PBT run. 54 | plot_field : str 55 | The HP to plot. See the HP log headers or lfads_tf2 56 | source code for options. 57 | save_dir : str, optional 58 | The directory for saving the figure; by default (None) 59 | an interactive plot is shown 60 | kwargs: optional 61 | Any keyword arguments to be passed to pandas.DataFrame.plot 62 | """ 63 | 64 | hps_df = read_pbt_hps(pbt_dir) 65 | plot_df = hps_df.pivot(index="cur_epoch", columns="trial_num", values=plot_field) 66 | plot_df = plot_df.ffill() 67 | gen_range = plot_df.index.min(), plot_df.index.max() 68 | field_range = plot_df.min().min(), plot_df.max().max() 69 | plot_kwargs = dict( 70 | drawstyle="steps-post", 71 | legend=False, 72 | logy=True, 73 | c="b", 74 | alpha=0.2, 75 | title=f"{plot_field} for PBT run at {pbt_dir}", 76 | xlim=gen_range, 77 | ylim=field_range, 78 | figsize=(10, 5), 79 | ) 80 | plot_kwargs.update(kwargs) 81 | plot_df.plot(**plot_kwargs) 82 | if save_dir is not None: 83 | filename = plot_field.replace(".", "_").lower() 84 | fig_path = os.path.join(save_dir, f"{filename}.png") 85 | plt.savefig(fig_path, bbox_inches="tight") 86 | plt.close() 87 | -------------------------------------------------------------------------------- /ctd/data_modeling/extensions/LFADS/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ...datamodules.LFADS.tuples import Batch, SessionBatch 4 | 5 | 6 | def flatten(dictionary, level=[]): 7 | """Flattens a dictionary by placing '.' between levels. 8 | This function flattens a hierarchical dictionary by placing '.' 9 | between keys at various levels to create a single key for each 10 | value. It is used internally for converting the configuration 11 | dictionary to more convenient formats. Implementation was 12 | inspired by `this StackOverflow post 13 | `_. 14 | Parameters 15 | ---------- 16 | dictionary : dict 17 | The hierarchical dictionary to be flattened. 18 | level : list, optional 19 | The list of parent keys to prepend to keys at this level, 20 | enabling recursive calls. By default, an empty list. 21 | Returns 22 | ------- 23 | dict 24 | The flattened dictionary. 
25 | """ 26 | 27 | tmp_dict = {} 28 | for key, val in dictionary.items(): 29 | if type(val) == dict: 30 | tmp_dict.update(flatten(val, level + [key])) 31 | else: 32 | tmp_dict[".".join(level + [key])] = val 33 | return tmp_dict 34 | 35 | 36 | def transpose_lists(output: list[list]): 37 | """Transposes the ordering of a list of lists.""" 38 | return list(map(list, zip(*output))) 39 | 40 | 41 | def send_batch_to_device(batch, device): 42 | """Recursively searches the batch for tensors and sends them to the device""" 43 | 44 | def send_to_device(obj): 45 | obj_type = type(obj) 46 | if obj_type == torch.Tensor: 47 | return obj.to(device) 48 | elif obj_type == dict: 49 | return {k: send_to_device(v) for k, v in obj.items()} 50 | elif obj_type == list: 51 | return [send_to_device(o) for o in obj] 52 | elif obj_type == SessionBatch: 53 | return SessionBatch(*[send_to_device(o) for o in obj]) 54 | elif obj_type == Batch: 55 | return Batch(send_to_device(obj)) 56 | else: 57 | raise NotImplementedError( 58 | f"`send_batch_to_device` has not been implemented for {str(obj_type)}." 59 | ) 60 | 61 | return send_to_device(batch) 62 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/LFADS/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/data_modeling/models/LFADS/__init__.py -------------------------------------------------------------------------------- /ctd/data_modeling/models/LFADS/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/data_modeling/models/LFADS/modules/__init__.py -------------------------------------------------------------------------------- /ctd/data_modeling/models/LFADS/modules/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from .initializers import init_linear_ 6 | from .recurrent import BidirectionalClippedGRU 7 | 8 | 9 | class Encoder(nn.Module): 10 | def __init__(self, hparams: dict): 11 | super().__init__() 12 | self.hparams = hps = hparams 13 | 14 | # Initial hidden state for IC encoder 15 | self.ic_enc_h0 = nn.Parameter( 16 | torch.zeros((2, 1, hps.ic_enc_dim), requires_grad=True) 17 | ) 18 | # Initial condition encoder 19 | self.ic_enc = BidirectionalClippedGRU( 20 | input_size=hps.encod_data_dim, 21 | hidden_size=hps.ic_enc_dim, 22 | clip_value=hps.cell_clip, 23 | ) 24 | # Mapping from final IC encoder state to IC parameters 25 | self.ic_linear = nn.Linear(hps.ic_enc_dim * 2, hps.ic_dim * 2) 26 | init_linear_(self.ic_linear) 27 | # Decide whether to use the controller 28 | self.use_con = all( 29 | [ 30 | hps.ci_enc_dim > 0, 31 | hps.con_dim > 0, 32 | hps.co_dim > 0, 33 | ] 34 | ) 35 | if self.use_con: 36 | # Initial hidden state for CI encoder 37 | self.ci_enc_h0 = nn.Parameter( 38 | torch.zeros((2, 1, hps.ci_enc_dim), requires_grad=True) 39 | ) 40 | # CI encoder 41 | self.ci_enc = BidirectionalClippedGRU( 42 | input_size=hps.encod_data_dim, 43 | hidden_size=hps.ci_enc_dim, 44 | clip_value=hps.cell_clip, 45 | ) 46 | # Activation dropout layer 47 | self.dropout = nn.Dropout(hps.dropout_rate) 48 | 49 | def forward(self, data: torch.Tensor): 50 | hps 
= self.hparams 51 | batch_size = data.shape[0] 52 | assert data.shape[1] == hps.encod_seq_len, ( 53 | f"Sequence length specified in HPs ({hps.encod_seq_len}) " 54 | f"must match data dim 1 ({data.shape[1]})." 55 | ) 56 | data_drop = self.dropout(data) 57 | # option to use separate segment for IC encoding 58 | if hps.ic_enc_seq_len > 0: 59 | ic_enc_data = data_drop[:, : hps.ic_enc_seq_len, :] 60 | ci_enc_data = data_drop[:, hps.ic_enc_seq_len :, :] 61 | else: 62 | ic_enc_data = data_drop 63 | ci_enc_data = data_drop 64 | # Pass data through IC encoder 65 | ic_enc_h0 = torch.tile(self.ic_enc_h0, (1, batch_size, 1)) 66 | _, h_n = self.ic_enc(ic_enc_data, ic_enc_h0) 67 | h_n = torch.cat([*h_n], dim=1) 68 | # Compute initial condition posterior 69 | h_n_drop = self.dropout(h_n) 70 | ic_params = self.ic_linear(h_n_drop) 71 | ic_mean, ic_logvar = torch.split(ic_params, hps.ic_dim, dim=1) 72 | ic_std = torch.sqrt(torch.exp(ic_logvar) + hps.ic_post_var_min) 73 | if self.use_con: 74 | # Pass data through CI encoder 75 | ci_enc_h0 = torch.tile(self.ci_enc_h0, (1, batch_size, 1)) 76 | ci, _ = self.ci_enc(ci_enc_data, ci_enc_h0) 77 | # Add a lag to the controller input 78 | ci_fwd, ci_bwd = torch.split(ci, hps.ci_enc_dim, dim=2) 79 | ci_fwd = F.pad(ci_fwd, (0, 0, hps.ci_lag, 0, 0, 0)) 80 | ci_bwd = F.pad(ci_bwd, (0, 0, 0, hps.ci_lag, 0, 0)) 81 | ci_len = hps.encod_seq_len - hps.ic_enc_seq_len 82 | ci = torch.cat([ci_fwd[:, :ci_len, :], ci_bwd[:, -ci_len:, :]], dim=2) 83 | # Add extra zeros if necessary for forward prediction 84 | fwd_steps = hps.recon_seq_len - hps.encod_seq_len 85 | ci = F.pad(ci, (0, 0, 0, fwd_steps, 0, 0)) 86 | else: 87 | # Create a placeholder if there's no controller 88 | ci = torch.zeros(data.shape[0], hps.recon_seq_len, 0).to(data.device) 89 | 90 | return ic_mean, ic_std, ci 91 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/LFADS/modules/initializers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def init_variance_scaling_(weight, scale_dim: int): 6 | scale_dim = torch.tensor(scale_dim) 7 | nn.init.normal_(weight, std=1 / torch.sqrt(scale_dim)) 8 | 9 | 10 | def init_linear_(linear: nn.Linear): 11 | init_variance_scaling_(linear.weight, linear.in_features) 12 | if linear.bias is not None: 13 | nn.init.zeros_(linear.bias) 14 | 15 | 16 | def init_gru_cell_(cell: nn.GRUCell, scale_dim: int = None): 17 | if scale_dim is None: 18 | ih_scale = cell.input_size 19 | hh_scale = cell.hidden_size 20 | else: 21 | ih_scale = hh_scale = scale_dim 22 | init_variance_scaling_(cell.weight_ih, ih_scale) 23 | init_variance_scaling_(cell.weight_hh, hh_scale) 24 | nn.init.ones_(cell.bias_ih) 25 | cell.bias_ih.data[-cell.hidden_size :] = 0.0 26 | # NOTE: these weights are not present in TF 27 | nn.init.zeros_(cell.bias_hh) 28 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/LFADS/modules/l2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def compute_l2_penalty(lfads, hps): 5 | if hps.gen_type == "RNN": 6 | recurrent_kernels_and_weights = [ 7 | (lfads.encoder.ic_enc.fwd_gru.cell.weight_hh, hps.l2_ic_enc_scale), 8 | (lfads.encoder.ic_enc.bwd_gru.cell.weight_hh, hps.l2_ic_enc_scale), 9 | (lfads.decoder.rnn.cell.gen_cell.weight_hh, hps.l2_gen_scale), 10 | ] 11 | for param in lfads.readout.parameters(): 12 | 
recurrent_kernels_and_weights.append((param, hps.l2_readout_scale)) 13 | elif hps.gen_type == "NODE": 14 | recurrent_kernels_and_weights = [ 15 | (lfads.encoder.ic_enc.fwd_gru.cell.weight_hh, hps.l2_ic_enc_scale), 16 | (lfads.encoder.ic_enc.bwd_gru.cell.weight_hh, hps.l2_ic_enc_scale), 17 | ] 18 | for param in lfads.decoder.rnn.cell.gen_cell.parameters(): 19 | recurrent_kernels_and_weights.append((param, hps.l2_gen_scale)) 20 | for param in lfads.readout.parameters(): 21 | recurrent_kernels_and_weights.append((param, hps.l2_readout_scale)) 22 | if lfads.use_con: 23 | recurrent_kernels_and_weights.extend( 24 | [ 25 | (lfads.encoder.ci_enc.fwd_gru.cell.weight_hh, hps.l2_ci_enc_scale), 26 | (lfads.encoder.ci_enc.bwd_gru.cell.weight_hh, hps.l2_ci_enc_scale), 27 | (lfads.decoder.rnn.cell.con_cell.weight_hh, hps.l2_con_scale), 28 | ] 29 | ) 30 | # Add recurrent penalty 31 | recurrent_penalty = 0.0 32 | recurrent_size = 0 33 | for kernel, weight in recurrent_kernels_and_weights: 34 | if weight > 0: 35 | recurrent_penalty += weight * 0.5 * torch.norm(kernel, 2) ** 2 36 | recurrent_size += kernel.numel() 37 | recurrent_penalty /= recurrent_size + 1e-8 38 | # Add recon penalty if applicable 39 | recon_penalty = 0.0 40 | if hasattr(lfads.recon, "compute_l2"): 41 | recon_penalty = lfads.recon.compute_l2() 42 | return recurrent_penalty + recon_penalty 43 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/LFADS/modules/l2_simple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def compute_l2_penalty(lfads, hps): 5 | if hps.gen_type == "RNN": 6 | recurrent_kernels_and_weights = [ 7 | (lfads.encoder.ic_enc.fwd_gru.cell.weight_hh, hps.l2_ic_enc_scale), 8 | (lfads.encoder.ic_enc.bwd_gru.cell.weight_hh, hps.l2_ic_enc_scale), 9 | (lfads.decoder.rnn.cell.gen_cell.weight_hh, hps.l2_gen_scale), 10 | ] 11 | elif hps.gen_type == "NODE": 12 | recurrent_kernels_and_weights = [ 13 | (lfads.encoder.ic_enc.fwd_gru.cell.weight_hh, hps.l2_ic_enc_scale), 14 | (lfads.encoder.ic_enc.bwd_gru.cell.weight_hh, hps.l2_ic_enc_scale), 15 | ] 16 | for param in lfads.decoder.rnn.cell.gen_cell.parameters(): 17 | recurrent_kernels_and_weights.append((param, hps.l2_gen_scale)) 18 | if lfads.use_con: 19 | recurrent_kernels_and_weights.extend( 20 | [ 21 | (lfads.encoder.ci_enc.fwd_gru.cell.weight_hh, hps.l2_ci_enc_scale), 22 | (lfads.encoder.ci_enc.bwd_gru.cell.weight_hh, hps.l2_ci_enc_scale), 23 | (lfads.decoder.rnn.cell.con_cell.weight_hh, hps.l2_con_scale), 24 | ] 25 | ) 26 | # Add recurrent penalty 27 | recurrent_penalty = 0.0 28 | recurrent_size = 0 29 | for kernel, weight in recurrent_kernels_and_weights: 30 | if weight > 0: 31 | recurrent_penalty += weight * 0.5 * torch.norm(kernel, 2) ** 2 32 | recurrent_size += kernel.numel() 33 | recurrent_penalty /= recurrent_size + 1e-8 34 | # Add recon penalty if applicable 35 | recon_penalty = 0.0 36 | if hasattr(lfads.recon, "compute_l2"): 37 | recon_penalty = lfads.recon.compute_l2() 38 | return recurrent_penalty + recon_penalty 39 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/LFADS/modules/priors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.distributions import Independent, Normal, StudentT, kl_divergence 4 | from torch.distributions.transforms import AffineTransform 5 | 6 | 7 | class Null(nn.Module): 8 | def 
make_posterior(self, *args): 9 | return None 10 | 11 | def forward(self, *args): 12 | return 0 13 | 14 | 15 | class MultivariateNormal(nn.Module): 16 | def __init__( 17 | self, 18 | mean: float, 19 | variance: float, 20 | shape: int, 21 | ): 22 | super().__init__() 23 | # Create distribution parameter tensors 24 | means = torch.ones(shape) * mean 25 | logvars = torch.log(torch.ones(shape) * variance) 26 | self.mean = nn.Parameter(means, requires_grad=True) 27 | self.logvar = nn.Parameter(logvars, requires_grad=False) 28 | 29 | def make_posterior(self, post_mean, post_std): 30 | return Independent(Normal(post_mean, post_std), 1) 31 | 32 | def forward(self, post_mean, post_std): 33 | # Create the posterior distribution 34 | posterior = self.make_posterior(post_mean, post_std) 35 | # Create the prior distribution 36 | prior_std = torch.exp(0.5 * self.logvar) 37 | prior = Independent(Normal(self.mean, prior_std), 1) 38 | # Compute KL analytically 39 | kl_batch = kl_divergence(posterior, prior) 40 | return torch.mean(kl_batch) 41 | 42 | 43 | class AutoregressiveMultivariateNormal(nn.Module): 44 | def __init__( 45 | self, 46 | tau: float, 47 | nvar: float, 48 | shape: int, 49 | ): 50 | super().__init__() 51 | # Create the distribution parameters 52 | logtaus = torch.log(torch.ones(shape) * tau) 53 | lognvars = torch.log(torch.ones(shape) * nvar) 54 | self.logtaus = nn.Parameter(logtaus, requires_grad=True) 55 | self.lognvars = nn.Parameter(lognvars, requires_grad=True) 56 | 57 | def make_posterior(self, post_mean, post_std): 58 | return Independent(Normal(post_mean, post_std), 2) 59 | 60 | def log_prob(self, sample): 61 | # Compute alpha and process variance 62 | alphas = torch.exp(-1.0 / torch.exp(self.logtaus)) 63 | logpvars = self.lognvars - torch.log(1 - alphas**2) 64 | # Create autocorrelative transformation 65 | transform = AffineTransform(loc=0, scale=alphas) 66 | # Align previous samples and compute means and stddevs 67 | prev_samp = torch.roll(sample, shifts=1, dims=1) 68 | means = transform(prev_samp) 69 | stddevs = torch.ones_like(means) * torch.exp(0.5 * self.lognvars) 70 | # Correct the first time point 71 | means[:, 0] = 0.0 72 | stddevs[:, 0] = torch.exp(0.5 * logpvars) 73 | # Create the prior and compute the log-probability 74 | prior = Independent(Normal(means, stddevs), 2) 75 | return prior.log_prob(sample) 76 | 77 | def forward(self, post_mean, post_std): 78 | posterior = self.make_posterior(post_mean, post_std) 79 | sample = posterior.rsample() 80 | log_q = posterior.log_prob(sample) 81 | log_p = self.log_prob(sample) 82 | kl_batch = log_q - log_p 83 | return torch.mean(kl_batch) 84 | 85 | 86 | class MultivariateStudentT(nn.Module): 87 | def __init__( 88 | self, 89 | loc: float, 90 | scale: float, 91 | df: int, 92 | shape: int, 93 | ): 94 | super().__init__() 95 | # Create the distribution parameters 96 | loc = torch.ones(shape) * loc 97 | self.loc = nn.Parameter(loc, requires_grad=True) 98 | logscale = torch.log(torch.ones(shape) * scale) 99 | self.logscale = nn.Parameter(logscale, requires_grad=True) 100 | self.df = df 101 | 102 | def make_posterior(self, post_loc, post_scale): 103 | # TODO: Should probably be inferring degrees of freedom along with loc and scale 104 | return Independent(StudentT(self.df, post_loc, post_scale), 1) 105 | 106 | def forward(self, post_loc, post_scale): 107 | # Create the posterior distribution 108 | posterior = self.make_posterior(post_loc, post_scale) 109 | # Create the prior distribution 110 | prior_scale = torch.exp(self.logscale) 111 | 
prior = Independent(StudentT(self.df, self.loc, prior_scale), 1) 112 | # Approximate KL divergence 113 | sample = posterior.rsample() 114 | log_q = posterior.log_prob(sample) 115 | log_p = prior.log_prob(sample) 116 | kl_batch = log_q - log_p 117 | return torch.mean(kl_batch) 118 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/LFADS/modules/readin_readout.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | # import h5py 7 | 8 | 9 | class FanInLinear(nn.Linear): 10 | def reset_parameters(self): 11 | super().reset_parameters() 12 | nn.init.normal_(self.weight, std=1 / math.sqrt(self.in_features)) 13 | nn.init.constant_(self.bias, 0.0) 14 | 15 | 16 | class MLP(nn.Module): 17 | def __init__(self, in_features, hidden_size, num_layers, out_features): 18 | super(MLP, self).__init__() 19 | self.layers = nn.ModuleList() 20 | for i in range(num_layers): 21 | if i == 0: 22 | self.layers.append(nn.Linear(in_features, hidden_size)) 23 | else: 24 | self.layers.append(nn.Linear(hidden_size, hidden_size)) 25 | self.output_layer = nn.Linear(hidden_size, out_features) 26 | 27 | def forward(self, x): 28 | for layer in self.layers: 29 | x = torch.relu(layer(x)) 30 | x = self.output_layer(x) 31 | return x 32 | 33 | 34 | # class PCRInitModuleList(nn.ModuleList): 35 | # def __init__(self, inits_path: str, modules: list[nn.Module]): 36 | # super().__init__(modules) 37 | # # Pull pre-computed initialization from the file, assuming correct order 38 | # with h5py.File(inits_path, "r") as h5file: 39 | # weights = [v["/" + k + "/matrix"][()] for k, v in h5file.items()] 40 | # biases = [v["/" + k + "/bias"][()] for k, v in h5file.items()] 41 | # # Load the state dict for each layer 42 | # for layer, weight, bias in zip(self, weights, biases): 43 | # state_dict = {"weight": torch.tensor(weight), "bias": torch.tensor(bias)} 44 | # layer.load_state_dict(state_dict) 45 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/LFADS/modules/recons.py: -------------------------------------------------------------------------------- 1 | """This module specifies options for reconstruction losses and 2 | loss-specific parameter processing. 3 | Each loss class must set self.n_params for the number of parameters, 4 | self.process_output_params which performs any fixed transformations 5 | (may be different depending on boolean sample_and_average, i.e. rates 6 | instead of logrates) and separates different parameters in a new inner 7 | dimension, and self.compute_loss which computes the loss for given 8 | tensors of data and inferred parameters. 
9 | """ 10 | 11 | import abc 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn 16 | 17 | 18 | class Reconstruction(abc.ABC): 19 | @abc.abstractmethod 20 | def __init__(self): 21 | pass 22 | 23 | @abc.abstractmethod 24 | def reshape_output_params(self, output_params): 25 | pass 26 | 27 | @abc.abstractmethod 28 | def compute_loss(self, data, output_params): 29 | pass 30 | 31 | @abc.abstractmethod 32 | def compute_means(self, output_params): 33 | pass 34 | 35 | 36 | class Poisson(Reconstruction): 37 | def __init__(self): 38 | self.n_params = 1 39 | 40 | def reshape_output_params(self, output_params): 41 | return torch.unsqueeze(output_params, dim=-1) 42 | 43 | def compute_loss(self, data, output_params): 44 | return F.poisson_nll_loss( 45 | output_params[..., 0], 46 | data, 47 | full=True, 48 | reduction="none", 49 | ) 50 | 51 | def compute_means(self, output_params): 52 | return torch.exp(output_params[..., 0]) 53 | 54 | 55 | class MSE(Reconstruction): 56 | def __init__(self): 57 | self.n_params = 1 58 | 59 | def reshape_output_params(self, output_params): 60 | return torch.unsqueeze(output_params, dim=-1) 61 | 62 | def compute_loss(self, data, output_params): 63 | return (data - output_params[..., 0]) ** 2 64 | 65 | def compute_means(self, output_params): 66 | return output_params[..., 0] 67 | 68 | 69 | class Gaussian(Reconstruction): 70 | def __init__(self): 71 | self.n_params = 2 72 | 73 | def reshape_output_params(self, output_params): 74 | means, logvars = torch.chunk(output_params, 2, -1) 75 | return torch.stack([means, logvars], -1) 76 | 77 | def compute_loss(self, data, output_params): 78 | means, logvars = torch.unbind(output_params, axis=-1) 79 | recon_all = F.gaussian_nll_loss( 80 | input=means, target=data, var=torch.exp(logvars), reduction="none" 81 | ) 82 | return recon_all 83 | 84 | def compute_means(self, output_params): 85 | return output_params[..., 0] 86 | 87 | 88 | class Gamma(Reconstruction): 89 | def __init__(self): 90 | self.n_params = 2 91 | 92 | def reshape_output_params(self, output_params): 93 | logalphas, logbetas = torch.chunk(output_params, chunks=2, dim=-1) 94 | return torch.stack([logalphas, logbetas], -1) 95 | 96 | def compute_loss(self, data, output_params): 97 | alphas, betas = torch.unbind(torch.exp(output_params), axis=-1) 98 | output_dist = torch.distributions.Gamma(alphas, betas) 99 | recon_all = -output_dist.log_prob(data) 100 | return recon_all 101 | 102 | def compute_means(self, output_params): 103 | alphas, betas = torch.unbind(torch.exp(output_params), axis=-1) 104 | return alphas / betas 105 | 106 | 107 | class ZeroInflatedGamma(nn.Module, Reconstruction): 108 | def __init__( 109 | self, 110 | recon_dim: int, 111 | gamma_loc: float, 112 | scale_init: float, 113 | scale_prior: float, 114 | scale_penalty: float, 115 | ): 116 | super().__init__() 117 | self.n_params = 3 118 | self.gamma_loc = gamma_loc 119 | # Initialize gamma parameter scaling weights 120 | scale_inits = torch.ones(2, recon_dim) * scale_init 121 | self.scale = nn.Parameter(scale_inits, requires_grad=True) 122 | self.scale_prior = scale_prior 123 | self.scale_penalty = scale_penalty 124 | 125 | def reshape_output_params(self, output_params): 126 | alpha_ps, beta_ps, q_ps = torch.chunk(output_params, chunks=3, dim=-1) 127 | return torch.stack([alpha_ps, beta_ps, q_ps], -1) 128 | 129 | def compute_loss(self, data, output_params): 130 | # Compute the scaled output parameters 131 | alphas, betas, qs = self._compute_scaled_params(output_params) 132 | # 
Shift data and replace zeros for convenient NLL calculation 133 | nz_ctr_data = torch.where( 134 | data == 0, torch.ones_like(data), data - self.gamma_loc 135 | ) 136 | gamma = torch.distributions.Gamma(alphas, betas) 137 | recon_gamma = -gamma.log_prob(nz_ctr_data) 138 | # Replace with zero-inflated likelihoods 139 | recon_all = torch.where( 140 | data == 0, -torch.log(1 - qs), recon_gamma - torch.log(qs) 141 | ) 142 | return recon_all 143 | 144 | def compute_means(self, output_params): 145 | # Compute the means of the ZIG distribution 146 | alphas, betas, qs = self._compute_scaled_params(output_params) 147 | return qs * (alphas / betas + self.gamma_loc) 148 | 149 | def compute_l2(self): 150 | # Compute an L2 scaling penalty on the gamma parameter scaling 151 | l2 = torch.sum((self.scale - self.scale_prior) ** 2) 152 | return 0.5 * self.scale_penalty * l2 153 | 154 | def _compute_scaled_params(self, output_params): 155 | # Compute sigmoid and clamp to avoid zero-valued rates 156 | sig_params = torch.clamp_min(torch.sigmoid(output_params), 1e-5) 157 | # Separate the parameters 158 | sig_alphas, sig_betas, qs = torch.unbind(sig_params, axis=-1) 159 | # Scale alphas and betas by per-neuron multiplicative factors 160 | alphas = sig_alphas * self.scale[0] 161 | betas = sig_betas * self.scale[1] 162 | return alphas, betas, qs 163 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/LFADS/modules/recurrent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .initializers import init_gru_cell_ 5 | 6 | 7 | class ClippedGRUCell(nn.GRUCell): 8 | def __init__( 9 | self, 10 | input_size: int, 11 | hidden_size: int, 12 | clip_value: float = float("inf"), 13 | is_encoder: bool = False, 14 | ): 15 | super().__init__(input_size, hidden_size, bias=True) 16 | self.bias_hh.requires_grad = False 17 | self.clip_value = clip_value 18 | scale_dim = input_size + hidden_size if is_encoder else None 19 | init_gru_cell_(self, scale_dim=scale_dim) 20 | 21 | def forward(self, input: torch.Tensor, hidden: torch.Tensor): 22 | x_all = input @ self.weight_ih.T + self.bias_ih 23 | x_z, x_r, x_n = torch.chunk(x_all, chunks=3, dim=1) 24 | split_dims = [2 * self.hidden_size, self.hidden_size] 25 | weight_hh_zr, weight_hh_n = torch.split(self.weight_hh, split_dims) 26 | bias_hh_zr, bias_hh_n = torch.split(self.bias_hh, split_dims) 27 | h_all = hidden @ weight_hh_zr.T + bias_hh_zr 28 | h_z, h_r = torch.chunk(h_all, chunks=2, dim=1) 29 | z = torch.sigmoid(x_z + h_z) 30 | r = torch.sigmoid(x_r + h_r) 31 | h_n = (r * hidden) @ weight_hh_n.T + bias_hh_n 32 | n = torch.tanh(x_n + h_n) 33 | hidden = z * hidden + (1 - z) * n 34 | hidden = torch.clamp(hidden, -self.clip_value, self.clip_value) 35 | return hidden 36 | 37 | 38 | class ClippedGRU(nn.Module): 39 | def __init__( 40 | self, 41 | input_size: int, 42 | hidden_size: int, 43 | clip_value: float = float("inf"), 44 | ): 45 | super().__init__() 46 | self.cell = ClippedGRUCell( 47 | input_size, hidden_size, clip_value=clip_value, is_encoder=True 48 | ) 49 | 50 | def forward(self, input: torch.Tensor, h_0: torch.Tensor): 51 | hidden = h_0 52 | input = torch.transpose(input, 0, 1) 53 | output = [] 54 | for input_step in input: 55 | hidden = self.cell(input_step, hidden) 56 | output.append(hidden) 57 | output = torch.stack(output, dim=1) 58 | return output, hidden 59 | 60 | 61 | class BidirectionalClippedGRU(nn.Module): 62 | def __init__( 63 | 
self, 64 | input_size: int, 65 | hidden_size: int, 66 | clip_value: float = float("inf"), 67 | ): 68 | super().__init__() 69 | self.fwd_gru = ClippedGRU(input_size, hidden_size, clip_value=clip_value) 70 | self.bwd_gru = ClippedGRU(input_size, hidden_size, clip_value=clip_value) 71 | 72 | def forward(self, input: torch.Tensor, h_0: torch.Tensor): 73 | h0_fwd, h0_bwd = h_0 74 | input_fwd = input 75 | input_bwd = torch.flip(input, [1]) 76 | output_fwd, hn_fwd = self.fwd_gru(input_fwd, h0_fwd) 77 | output_bwd, hn_bwd = self.bwd_gru(input_bwd, h0_bwd) 78 | output_bwd = torch.flip(output_bwd, [1]) 79 | output = torch.cat([output_fwd, output_bwd], dim=2) 80 | h_n = torch.stack([hn_fwd, hn_bwd]) 81 | return output, h_n 82 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/SAE/LDS.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch import nn 4 | 5 | from .loss_func import LossFunc, PoissonLossFunc 6 | 7 | 8 | class LinearCell(nn.Module): 9 | def __init__(self, input_size, latent_size): 10 | super().__init__() 11 | self.input_size = input_size 12 | self.latent_size = latent_size 13 | self.linear_ih = nn.Linear(input_size, latent_size) 14 | self.linear_hh = nn.Linear(latent_size, latent_size) 15 | 16 | def forward(self, input, hidden): 17 | return self.linear_ih(input) + self.linear_hh(hidden) 18 | 19 | 20 | class RNN(nn.Module): 21 | def __init__(self, cell): 22 | super().__init__() 23 | self.cell = cell 24 | 25 | def forward(self, input, h_0): 26 | hidden = h_0 27 | states = [] 28 | for input_step in input.transpose(0, 1): 29 | hidden = self.cell(input_step, hidden) 30 | states.append(hidden) 31 | states = torch.stack(states, dim=1) 32 | return states, hidden 33 | 34 | 35 | class LDSSAE(pl.LightningModule): 36 | def __init__( 37 | self, 38 | dataset: str, 39 | encoder_size: int, 40 | encoder_window: int, 41 | heldin_size: int, 42 | heldout_size: int, 43 | latent_size: int, 44 | lr: float, 45 | weight_decay: float, 46 | dropout: float, 47 | input_size: int, 48 | loss_func: LossFunc = PoissonLossFunc(), 49 | ): 50 | super().__init__() 51 | # Instantiate bidirectional GRU encoder 52 | self.encoder = nn.GRU( 53 | input_size=heldin_size, 54 | hidden_size=encoder_size, 55 | batch_first=True, 56 | bidirectional=True, 57 | ) 58 | self.dropout = nn.Dropout(p=dropout) 59 | self.readout = nn.Linear(in_features=latent_size, out_features=heldout_size) 60 | self.ic_linear = nn.Linear(2 * encoder_size, latent_size) 61 | self.encoder_window = encoder_window 62 | self.latent_size = latent_size 63 | self.weight_decay = weight_decay 64 | self.lr = lr 65 | self.decoder = RNN(LinearCell(input_size, latent_size)) 66 | self.loss_func = loss_func 67 | 68 | def forward(self, data, inputs): 69 | # Pass data through the model 70 | _, h_n = self.encoder(data[:, : self.encoder_window, :]) 71 | h_n = torch.cat([*h_n], -1) 72 | h_n_drop = self.dropout(h_n) 73 | ic = self.ic_linear(h_n_drop) 74 | ic_drop = self.dropout(ic) 75 | # Evaluate the NeuralODE 76 | latents, _ = self.decoder(inputs, ic_drop) 77 | B, T, N = latents.shape 78 | # Map decoder state to data dimension 79 | rates = self.readout(latents) 80 | return rates, latents 81 | 82 | def configure_optimizers(self): 83 | optimizer = torch.optim.Adam( 84 | [ 85 | { 86 | "params": self.parameters(), 87 | "weight_decay": self.weight_decay, 88 | "lr": self.lr, 89 | }, 90 | ], 91 | ) 92 | return optimizer 93 | 94 | def 
training_step(self, batch, batch_ix): 95 | spikes, recon_spikes, inputs, extra, *_ = batch 96 | # Pass data through the model 97 | pred_logrates, pred_latents = self.forward(spikes, inputs) 98 | # Compute the weighted loss 99 | loss_dict = dict( 100 | controlled=pred_logrates, 101 | targets=recon_spikes, 102 | extra=extra, 103 | ) 104 | loss = self.loss_func(loss_dict) 105 | 106 | self.log("train/loss_all", loss) 107 | 108 | return loss 109 | 110 | def validation_step(self, batch, batch_ix): 111 | spikes, recon_spikes, inputs, extra, *_ = batch 112 | # Pass data through the model 113 | pred_logrates, latents = self.forward(spikes, inputs) 114 | loss_dict = dict( 115 | controlled=pred_logrates, 116 | targets=recon_spikes, 117 | extra=extra, 118 | ) 119 | 120 | loss = self.loss_func(loss_dict) 121 | 122 | self.log("valid/loss_all", loss) 123 | return loss 124 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/SAE/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/data_modeling/models/SAE/__init__.py -------------------------------------------------------------------------------- /ctd/data_modeling/models/SAE/dyn_models_gru.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch import nn 4 | 5 | from .loss_func import LossFunc, PoissonLossFunc 6 | 7 | 8 | class RNN(nn.Module): 9 | def __init__(self, cell): 10 | super().__init__() 11 | self.cell = cell 12 | 13 | def forward(self, input, h_0): 14 | hidden = h_0 15 | states = [] 16 | for input_step in input.transpose(0, 1): 17 | hidden = self.cell(input_step, hidden) 18 | states.append(hidden) 19 | states = torch.stack(states, dim=1) 20 | return states, hidden 21 | 22 | 23 | class GRULatentSAE(pl.LightningModule): 24 | def __init__( 25 | self, 26 | dataset: str, 27 | encoder_size: int, 28 | encoder_window: int, 29 | heldin_size: int, 30 | heldout_size: int, 31 | latent_size: int, 32 | lr: float, 33 | weight_decay: float, 34 | dropout: float, 35 | input_size: int, 36 | loss_func: LossFunc = PoissonLossFunc(), 37 | ): 38 | super().__init__() 39 | # Instantiate bidirectional GRU encoder 40 | self.encoder = nn.GRU( 41 | input_size=heldin_size, 42 | hidden_size=encoder_size, 43 | batch_first=True, 44 | bidirectional=True, 45 | ) 46 | self.dropout = nn.Dropout(p=dropout) 47 | self.readout = nn.Linear(in_features=latent_size, out_features=heldout_size) 48 | self.ic_linear = nn.Linear(2 * encoder_size, latent_size) 49 | self.save_hyperparameters() 50 | latent_size = self.hparams.latent_size 51 | self.loss_func = loss_func 52 | self.decoder = RNN(nn.GRUCell(input_size, latent_size)) 53 | 54 | def forward(self, data, inputs): 55 | # Pass data through the model 56 | _, h_n = self.encoder(data[:, : self.hparams.encoder_window, :]) 57 | h_n = torch.cat([*h_n], -1) 58 | h_n_drop = self.dropout(h_n) 59 | ic = self.ic_linear(h_n_drop) 60 | ic_drop = self.dropout(ic) 61 | # Evaluate the forward pass 62 | latents, _ = self.decoder(inputs, ic_drop) 63 | B, T, N = latents.shape 64 | # Map decoder state to data dimension 65 | rates = self.readout(latents) 66 | return rates, latents 67 | 68 | def configure_optimizers(self): 69 | optimizer = torch.optim.Adam( 70 | [ 71 | { 72 | "params": self.parameters(), 73 | "weight_decay": 
self.hparams.weight_decay, 74 | "lr": self.hparams.lr, 75 | }, 76 | ], 77 | ) 78 | return optimizer 79 | 80 | def training_step(self, batch, batch_ix): 81 | spikes, recon_spikes, inputs, extra, *_ = batch 82 | # Pass data through the model 83 | pred_logrates, pred_latents = self.forward(spikes, inputs) 84 | # Compute the weighted loss 85 | loss_dict = dict( 86 | controlled=pred_logrates, 87 | targets=recon_spikes, 88 | extra=extra, 89 | ) 90 | loss = self.loss_func(loss_dict) 91 | 92 | self.log("train/loss_all", loss) 93 | 94 | return loss 95 | 96 | def validation_step(self, batch, batch_ix): 97 | spikes, recon_spikes, inputs, extra, *_ = batch 98 | # Pass data through the model 99 | pred_logrates, latents = self.forward(spikes, inputs) 100 | loss_dict = dict( 101 | controlled=pred_logrates, 102 | targets=recon_spikes, 103 | extra=extra, 104 | ) 105 | 106 | loss = self.loss_func(loss_dict) 107 | 108 | self.log("valid/loss_all", loss) 109 | return loss 110 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/SAE/dyn_models_rnn.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch import nn 4 | 5 | from .loss_func import LossFunc, PoissonLossFunc 6 | 7 | 8 | class RNN(nn.Module): 9 | def __init__(self, cell): 10 | super().__init__() 11 | self.cell = cell 12 | 13 | def forward(self, input, h_0): 14 | hidden = h_0 15 | states = [] 16 | for input_step in input.transpose(0, 1): 17 | hidden = self.cell(input_step, hidden) 18 | states.append(hidden) 19 | states = torch.stack(states, dim=1) 20 | return states, hidden 21 | 22 | 23 | class RNNLatentSAE(pl.LightningModule): 24 | def __init__( 25 | self, 26 | dataset: str, 27 | encoder_size: int, 28 | encoder_window: int, 29 | heldin_size: int, 30 | heldout_size: int, 31 | latent_size: int, 32 | lr: float, 33 | weight_decay: float, 34 | dropout: float, 35 | input_size: int, 36 | loss_func: LossFunc = PoissonLossFunc(), 37 | ): 38 | super().__init__() 39 | # Instantiate bidirectional GRU encoder 40 | self.encoder = nn.GRU( 41 | input_size=heldin_size, 42 | hidden_size=encoder_size, 43 | batch_first=True, 44 | bidirectional=True, 45 | ) 46 | self.dropout = nn.Dropout(p=dropout) 47 | self.readout = nn.Linear(in_features=latent_size, out_features=heldout_size) 48 | self.ic_linear = nn.Linear(2 * encoder_size, latent_size) 49 | self.encoder_window = encoder_window 50 | self.latent_size = latent_size 51 | self.weight_decay = weight_decay 52 | self.lr = lr 53 | self.decoder = RNN(nn.RNNCell(input_size, latent_size)) 54 | self.loss_func = loss_func 55 | 56 | def forward(self, data, inputs): 57 | # Pass data through the model 58 | _, h_n = self.encoder(data[:, : self.encoder_window, :]) 59 | h_n = torch.cat([*h_n], -1) 60 | h_n_drop = self.dropout(h_n) 61 | ic = self.ic_linear(h_n_drop) 62 | ic_drop = self.dropout(ic) 63 | # Evaluate the NeuralODE 64 | latents, _ = self.decoder(inputs, ic_drop) 65 | B, T, N = latents.shape 66 | # Map decoder state to data dimension 67 | rates = self.readout(latents) 68 | return rates, latents 69 | 70 | def configure_optimizers(self): 71 | optimizer = torch.optim.Adam( 72 | [ 73 | { 74 | "params": self.parameters(), 75 | "weight_decay": self.weight_decay, 76 | "lr": self.lr, 77 | }, 78 | ], 79 | ) 80 | return optimizer 81 | 82 | def training_step(self, batch, batch_ix): 83 | spikes, recon_spikes, inputs, extra, *_ = batch 84 | # Pass data through the model 85 | pred_logrates, 
pred_latents = self.forward(spikes, inputs) 86 | # Compute the weighted loss 87 | loss_dict = dict( 88 | controlled=pred_logrates, 89 | targets=recon_spikes, 90 | extra=extra, 91 | ) 92 | loss = self.loss_func(loss_dict) 93 | 94 | self.log("train/loss_all", loss) 95 | 96 | return loss 97 | 98 | def validation_step(self, batch, batch_ix): 99 | spikes, recon_spikes, inputs, extra, *_ = batch 100 | # Pass data through the model 101 | pred_logrates, latents = self.forward(spikes, inputs) 102 | loss_dict = dict( 103 | controlled=pred_logrates, 104 | targets=recon_spikes, 105 | extra=extra, 106 | ) 107 | 108 | loss = self.loss_func(loss_dict) 109 | 110 | self.log("valid/loss_all", loss) 111 | return loss 112 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/SAE/gru_rnn.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch import nn 4 | 5 | from .loss_func import LossFunc, PoissonLossFunc 6 | 7 | 8 | class RNN(nn.Module): 9 | def __init__(self, cell): 10 | super().__init__() 11 | self.cell = cell 12 | 13 | def forward(self, input, h_0): 14 | hidden = h_0 15 | states = [] 16 | for input_step in input.transpose(0, 1): 17 | hidden = self.cell(input_step, hidden) 18 | states.append(hidden) 19 | states = torch.stack(states, dim=1) 20 | return states, hidden 21 | 22 | 23 | class GRULatentSAE(pl.LightningModule): 24 | def __init__( 25 | self, 26 | dataset: str, 27 | encoder_size: int, 28 | encoder_window: int, 29 | heldin_size: int, 30 | heldout_size: int, 31 | latent_size: int, 32 | lr: float, 33 | weight_decay: float, 34 | dropout: float, 35 | input_size: int, 36 | loss_func: LossFunc = PoissonLossFunc(), 37 | ): 38 | super().__init__() 39 | # Instantiate bidirectional GRU encoder 40 | self.encoder = nn.GRU( 41 | input_size=heldin_size, 42 | hidden_size=encoder_size, 43 | batch_first=True, 44 | bidirectional=True, 45 | ) 46 | self.dropout = nn.Dropout(p=dropout) 47 | self.readout = nn.Linear(in_features=latent_size, out_features=heldout_size) 48 | self.ic_linear = nn.Linear(2 * encoder_size, latent_size) 49 | self.save_hyperparameters() 50 | latent_size = self.hparams.latent_size 51 | self.loss_func = loss_func 52 | self.decoder = RNN(nn.GRUCell(input_size, latent_size)) 53 | 54 | def forward(self, data, inputs): 55 | # Pass data through the model 56 | _, h_n = self.encoder(data[:, : self.hparams.encoder_window, :]) 57 | h_n = torch.cat([*h_n], -1) 58 | h_n_drop = self.dropout(h_n) 59 | ic = self.ic_linear(h_n_drop) 60 | ic_drop = self.dropout(ic) 61 | # Evaluate the forward pass 62 | latents, _ = self.decoder(inputs, ic_drop) 63 | B, T, N = latents.shape 64 | # Map decoder state to data dimension 65 | rates = self.readout(latents) 66 | return rates, latents 67 | 68 | def configure_optimizers(self): 69 | optimizer = torch.optim.Adam( 70 | [ 71 | { 72 | "params": self.parameters(), 73 | "weight_decay": self.hparams.weight_decay, 74 | "lr": self.hparams.lr, 75 | }, 76 | ], 77 | ) 78 | return optimizer 79 | 80 | def training_step(self, batch, batch_ix): 81 | spikes, recon_spikes, inputs, extra, *_ = batch 82 | # Pass data through the model 83 | pred_logrates, pred_latents = self.forward(spikes, inputs) 84 | # Compute the weighted loss 85 | loss_dict = dict( 86 | controlled=pred_logrates, 87 | targets=recon_spikes, 88 | extra=extra, 89 | ) 90 | loss = self.loss_func(loss_dict) 91 | 92 | self.log("train/loss_all", loss) 93 | 94 | return loss 95 | 96 | 
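# --- Illustrative usage sketch (added; not part of the original file) ---
# A doctest-style example of running this SAE forward; the sizes below are
# hypothetical and chosen only so the tensor shapes line up.
# >>> import torch
# >>> model = GRULatentSAE(
# ...     dataset="example", encoder_size=32, encoder_window=10,
# ...     heldin_size=50, heldout_size=60, latent_size=16,
# ...     lr=1e-3, weight_decay=0.0, dropout=0.1, input_size=3)
# >>> spikes = torch.randint(0, 5, (8, 100, 50)).float()  # (batch, time, heldin)
# >>> inputs = torch.zeros(8, 100, 3)                     # (batch, time, input)
# >>> logrates, latents = model(spikes, inputs)
# >>> tuple(logrates.shape), tuple(latents.shape)
# ((8, 100, 60), (8, 100, 16))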
def validation_step(self, batch, batch_ix): 97 | spikes, recon_spikes, inputs, extra, *_ = batch 98 | # Pass data through the model 99 | pred_logrates, latents = self.forward(spikes, inputs) 100 | loss_dict = dict( 101 | controlled=pred_logrates, 102 | targets=recon_spikes, 103 | extra=extra, 104 | ) 105 | 106 | loss = self.loss_func(loss_dict) 107 | 108 | self.log("valid/loss_all", loss) 109 | return loss 110 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/SAE/lds.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch import nn 4 | 5 | from .loss_func import LossFunc, PoissonLossFunc 6 | 7 | 8 | class LinearCell(nn.Module): 9 | def __init__(self, input_size, latent_size): 10 | super().__init__() 11 | self.input_size = input_size 12 | self.latent_size = latent_size 13 | self.linear_ih = nn.Linear(input_size, latent_size) 14 | self.linear_hh = nn.Linear(latent_size, latent_size) 15 | 16 | def forward(self, input, hidden): 17 | return self.linear_ih(input) + self.linear_hh(hidden) 18 | 19 | 20 | class RNN(nn.Module): 21 | def __init__(self, cell): 22 | super().__init__() 23 | self.cell = cell 24 | 25 | def forward(self, input, h_0): 26 | hidden = h_0 27 | states = [] 28 | for input_step in input.transpose(0, 1): 29 | hidden = self.cell(input_step, hidden) 30 | states.append(hidden) 31 | states = torch.stack(states, dim=1) 32 | return states, hidden 33 | 34 | 35 | class LDSSAE(pl.LightningModule): 36 | def __init__( 37 | self, 38 | dataset: str, 39 | encoder_size: int, 40 | encoder_window: int, 41 | heldin_size: int, 42 | heldout_size: int, 43 | latent_size: int, 44 | lr: float, 45 | weight_decay: float, 46 | dropout: float, 47 | input_size: int, 48 | loss_func: LossFunc = PoissonLossFunc(), 49 | ): 50 | super().__init__() 51 | # Instantiate bidirectional GRU encoder 52 | self.encoder = nn.GRU( 53 | input_size=heldin_size, 54 | hidden_size=encoder_size, 55 | batch_first=True, 56 | bidirectional=True, 57 | ) 58 | self.dropout = nn.Dropout(p=dropout) 59 | self.readout = nn.Linear(in_features=latent_size, out_features=heldout_size) 60 | self.ic_linear = nn.Linear(2 * encoder_size, latent_size) 61 | self.encoder_window = encoder_window 62 | self.latent_size = latent_size 63 | self.weight_decay = weight_decay 64 | self.lr = lr 65 | self.decoder = RNN(LinearCell(input_size, latent_size)) 66 | self.loss_func = loss_func 67 | 68 | def forward(self, data, inputs): 69 | # Pass data through the model 70 | _, h_n = self.encoder(data[:, : self.encoder_window, :]) 71 | h_n = torch.cat([*h_n], -1) 72 | h_n_drop = self.dropout(h_n) 73 | ic = self.ic_linear(h_n_drop) 74 | ic_drop = self.dropout(ic) 75 | # Evaluate the NeuralODE 76 | latents, _ = self.decoder(inputs, ic_drop) 77 | B, T, N = latents.shape 78 | # Map decoder state to data dimension 79 | rates = self.readout(latents) 80 | return rates, latents 81 | 82 | def configure_optimizers(self): 83 | optimizer = torch.optim.Adam( 84 | [ 85 | { 86 | "params": self.parameters(), 87 | "weight_decay": self.weight_decay, 88 | "lr": self.lr, 89 | }, 90 | ], 91 | ) 92 | return optimizer 93 | 94 | def training_step(self, batch, batch_ix): 95 | spikes, recon_spikes, inputs, extra, *_ = batch 96 | # Pass data through the model 97 | pred_logrates, pred_latents = self.forward(spikes, inputs) 98 | # Compute the weighted loss 99 | loss_dict = dict( 100 | controlled=pred_logrates, 101 | targets=recon_spikes, 102 | 
extra=extra, 103 | ) 104 | loss = self.loss_func(loss_dict) 105 | 106 | self.log("train/loss_all", loss) 107 | 108 | return loss 109 | 110 | def validation_step(self, batch, batch_ix): 111 | spikes, recon_spikes, inputs, extra, *_ = batch 112 | # Pass data through the model 113 | pred_logrates, latents = self.forward(spikes, inputs) 114 | loss_dict = dict( 115 | controlled=pred_logrates, 116 | targets=recon_spikes, 117 | extra=extra, 118 | ) 119 | 120 | loss = self.loss_func(loss_dict) 121 | 122 | self.log("valid/loss_all", loss) 123 | return loss 124 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/SAE/loss_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class LossFunc: 6 | # Base class for SAE loss functions. Subclasses are called with a dict 7 | # carrying "controlled" (model predictions), "targets" (reconstruction 8 | # targets), and "extra" (per-trial metadata, e.g. trial lengths). 9 | def __init__(self): 10 | pass 11 | 12 | def __call__(self, loss_dict): 13 | raise NotImplementedError 14 | 15 | 16 | class PoissonLossFunc(LossFunc): 17 | def __init__(self): 18 | pass 19 | 20 | def __call__(self, loss_dict): 21 | pred = loss_dict["controlled"] 22 | target = loss_dict["targets"] 23 | # action = loss_dict["actions"] 24 | # inputs = loss_dict["inputs"] 25 | return F.poisson_nll_loss(pred, target) 26 | 27 | 28 | class MultiTaskPoissonLossFunc(LossFunc): 29 | def __init__(self): 30 | pass 31 | 32 | def __call__(self, loss_dict): 33 | pred = loss_dict["controlled"] 34 | target = loss_dict["targets"] 35 | extras = loss_dict["extra"] 36 | end_ind = extras[:, 1] 37 | # action = loss_dict["actions"] 38 | # inputs = loss_dict["inputs"] 39 | loss_all = F.poisson_nll_loss(pred, target, reduction="none") 40 | # Mask out time steps after each trial's end index 41 | weights = torch.ones_like(loss_all) 42 | for i in range(len(end_ind)): 43 | weights[i, int(end_ind[i]) :, :] = 0 44 | # Normalize each trial by the number of time steps 45 | weights = weights / weights.sum(dim=1, keepdim=True) 46 | loss = torch.mean(loss_all * weights) 47 | return loss 48 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/SAE/node.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch import nn 4 | 5 | from .loss_func import LossFunc, PoissonLossFunc 6 | 7 | 8 | class RNN(nn.Module): 9 | def __init__(self, cell): 10 | super().__init__() 11 | self.cell = cell 12 | 13 | def forward(self, input, h_0): 14 | hidden = h_0 15 | states = [] 16 | for input_step in input.transpose(0, 1): 17 | hidden = self.cell(input_step, hidden) 18 | states.append(hidden) 19 | states = torch.stack(states, dim=1) 20 | return states, hidden 21 | 22 | 23 | class MLPCell(nn.Module): 24 | def __init__(self, vf_net, input_size): 25 | super().__init__() 26 | self.vf_net = vf_net 27 | self.input_size = input_size 28 | 29 | def forward(self, input, hidden): 30 | input_hidden = torch.cat([hidden, input], dim=1) 31 | vf_out = 0.1 * self.vf_net(input_hidden) 32 | return hidden + vf_out 33 | 34 | 35 | class NODELatentSAE(pl.LightningModule): 36 | def __init__( 37 | self, 38 | dataset: str, 39 | encoder_size: int, 40 | encoder_window: int, 41 | heldin_size: int, 42 | heldout_size: int, 43 | latent_size: int, 44 | lr: float, 45 | weight_decay: float, 46 | dropout: float, 47 | input_size: int, 48 | vf_hidden_size: int, 49 | vf_num_layers: int, 50 | loss_func: LossFunc = PoissonLossFunc(), 51 | ): 52 | super().__init__() 53 | # Instantiate bidirectional GRU encoder 54 | self.encoder = nn.GRU( 55 | input_size=heldin_size, 56 |
hidden_size=encoder_size, 57 | batch_first=True, 58 | bidirectional=True, 59 | ) 60 | self.dropout = nn.Dropout(p=dropout) 61 | self.ic_linear = nn.Linear(2 * encoder_size, latent_size) 62 | self.save_hyperparameters() 63 | 64 | act_func = torch.nn.ReLU 65 | latent_size = self.hparams.latent_size 66 | vector_field = [] 67 | vector_field.append(nn.Linear(latent_size + input_size, vf_hidden_size)) 68 | vector_field.append(act_func()) 69 | for k in range(self.hparams.vf_num_layers - 1): 70 | vector_field.append(nn.Linear(vf_hidden_size, vf_hidden_size)) 71 | vector_field.append(act_func()) 72 | vector_field.append(nn.Linear(vf_hidden_size, latent_size)) 73 | vector_field_net = nn.Sequential(*vector_field) 74 | self.decoder = RNN(MLPCell(vector_field_net, input_size)) 75 | self.readout = nn.Linear(in_features=latent_size, out_features=heldout_size) 76 | self.loss_func = loss_func 77 | self.weight_decay = weight_decay 78 | self.lr = lr 79 | 80 | def forward(self, data, inputs): 81 | # Pass data through the model 82 | _, h_n = self.encoder(data[:, : self.hparams.encoder_window, :]) 83 | h_n = torch.cat([*h_n], -1) 84 | h_n_drop = self.dropout(h_n) 85 | ic = self.ic_linear(h_n_drop) 86 | # Evaluate the NeuralODE 87 | latents, _ = self.decoder(inputs, ic) 88 | B, T, N = latents.shape 89 | # Map decoder state to data dimension 90 | rates = self.readout(latents) 91 | return rates, latents 92 | 93 | def configure_optimizers(self): 94 | optimizer = torch.optim.Adam( 95 | [ 96 | { 97 | "params": self.parameters(), 98 | "weight_decay": self.weight_decay, 99 | "lr": self.lr, 100 | }, 101 | ], 102 | ) 103 | return optimizer 104 | 105 | def training_step(self, batch, batch_ix): 106 | 107 | spikes, recon_spikes, inputs, extra, *_ = batch 108 | # Pass data through the model 109 | pred_logrates, pred_latents = self.forward(spikes, inputs) 110 | 111 | # Compute the weighted loss 112 | loss_dict = dict( 113 | controlled=pred_logrates, 114 | targets=recon_spikes, 115 | extra=extra, 116 | ) 117 | loss_all_train = self.loss_func(loss_dict) 118 | self.log("train/loss_all_train", loss_all_train) 119 | 120 | return loss_all_train 121 | 122 | def validation_step(self, batch, batch_ix): 123 | if len(batch) == 1: 124 | (spikes,) = batch 125 | # No external inputs are available for these batches, so unroll 126 | # the decoder with zero inputs and leave `extra` empty (this 127 | # assumes a loss function, e.g. PoissonLossFunc, that ignores it) 128 | inputs = torch.zeros( 129 | *spikes.shape[:2], self.hparams.input_size, device=spikes.device 130 | ) 131 | pred_logrates, latents = self.forward(spikes, inputs) 132 | # Isolate heldin predictions 133 | _, n_obs, n_heldin = spikes.shape 134 | pred_logrates = pred_logrates[:, :n_obs, :n_heldin] 135 | recon_spikes = spikes 136 | extra = None 137 | else: 138 | spikes, recon_spikes, inputs, extra, *_ = batch 139 | # Pass data through the model 140 | pred_logrates, latents = self.forward(spikes, inputs) 141 | 142 | loss_dict = dict( 143 | controlled=pred_logrates, 144 | targets=recon_spikes, 145 | extra=extra, 146 | ) 147 | 148 | loss = self.loss_func(loss_dict) 149 | self.log("valid/loss_all", loss) 150 | 151 | return loss 152 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/SAE/readouts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class RNN(nn.Module): 6 | def __init__(self, cell): 7 | super().__init__() 8 | self.cell = cell 9 | 10 | def forward(self, h_0, num_steps, rev): 11 | hidden = h_0 12 | states = [] 13 | vf_out = [] 14 | for input_step in range(num_steps): 15 | hidden, vf_1 = self.cell(hidden, rev=rev) 16 | states.append(hidden) 17 | vf_out.append(vf_1) 18 | states = torch.stack(states, dim=1) 19 | vf_out =
torch.norm(torch.stack(vf_out, dim=1), dim=2) 20 | return states, hidden, vf_out 21 | 22 | 23 | class MLPCell(nn.Module): 24 | def __init__(self, vf_net): 25 | super().__init__() 26 | self.vf_net = vf_net 27 | self.input_size = 3 28 | 29 | def forward(self, hidden, rev): 30 | vf_out = 0.1 * self.vf_net(hidden) 31 | if rev: 32 | return hidden - vf_out, vf_out 33 | else: 34 | return hidden + vf_out, vf_out 35 | 36 | 37 | class MLPCellScale(nn.Module): 38 | def __init__(self, vf_net, scale=0.1): 39 | super().__init__() 40 | self.vf_net = vf_net 41 | self.input_size = 3 42 | self.scale = scale 43 | 44 | def forward(self, hidden, rev): 45 | vf_out = self.scale * self.vf_net(hidden) 46 | if rev: 47 | return hidden - vf_out, vf_out 48 | else: 49 | return hidden + vf_out, vf_out 50 | 51 | 52 | def build_subnet(dims_in, dims_out): 53 | return nn.Sequential( 54 | nn.Linear(dims_in, 64), 55 | nn.ReLU(), 56 | nn.Linear(64, dims_out), 57 | ) 58 | 59 | 60 | class FeedForwardNet(nn.Module): 61 | def __init__(self, input_dim, output_dim, hidden_dim=128, num_layers=2): 62 | super().__init__() 63 | self.network = [] 64 | self.network.append(nn.Linear(input_dim, hidden_dim)) 65 | self.network.append(nn.ReLU()) 66 | for k in range(num_layers - 1): 67 | self.network.append(nn.Linear(hidden_dim, hidden_dim)) 68 | self.network.append(nn.ReLU()) 69 | self.network.append(nn.Linear(hidden_dim, output_dim)) 70 | self.network = nn.Sequential(*self.network) 71 | 72 | def forward(self, input): 73 | return self.network(input) 74 | 75 | 76 | # class InvertibleNetNeural(nn.Module): 77 | # def __init__(self, node_dim, heldin_dim, heldout_dim, inn_num_layers): 78 | # super().__init__() 79 | # self.node_dim = node_dim 80 | # self.heldin_dim = heldin_dim 81 | # self.heldout_dim = heldout_dim 82 | # self.hidden_dim = max(node_dim, heldout_dim) 83 | 84 | # inn = Ff.SequenceINN(self.hidden_dim) 85 | # for k in range(inn_num_layers): 86 | # inn.append( 87 | # Fm.AllInOneBlock, 88 | # subnet_constructor=build_subnet, 89 | # permute_soft=True, 90 | # ) 91 | # self.network = inn 92 | 93 | # def forward(self, inputs, reverse=False): 94 | # if not reverse: 95 | # batch_size, n_steps, n_inputs = inputs.shape 96 | # assert n_inputs == self.node_dim 97 | # else: 98 | # batch_size, n_inputs = inputs.shape 99 | # assert n_inputs == self.heldout_dim 100 | # # Pad the inputs if necessary 101 | # inputs = F.pad(inputs, (0, self.hidden_dim - n_inputs)) 102 | # # Pass the inputs through the network 103 | # outputs, _ = self.network(inputs.reshape(-1, self.hidden_dim), rev=reverse) 104 | # if not reverse: 105 | # outputs = outputs.reshape(batch_size, n_steps, self.hidden_dim) 106 | # return outputs 107 | # # Remove padded elements if necessary 108 | # else: 109 | # # Trim the final dimension to match the node dimension 110 | # outputs = outputs[:, : self.node_dim] 111 | # return outputs 112 | 113 | 114 | class FlowReadout(nn.Module): 115 | def __init__( 116 | self, 117 | node_dim, 118 | heldin_dim, 119 | heldout_dim, 120 | vf_hidden_size, 121 | num_layers, 122 | num_steps, 123 | ): 124 | super().__init__() 125 | self.node_dim = node_dim 126 | self.heldin_dim = heldin_dim 127 | self.heldout_dim = heldout_dim 128 | 129 | self.vf_hidden_size = vf_hidden_size 130 | self.num_layers = num_layers 131 | self.num_steps = num_steps 132 | 133 | act_func = torch.nn.ReLU 134 | vector_field = [] 135 | vector_field.append(nn.Linear(self.heldout_dim, self.vf_hidden_size)) 136 | vector_field.append(act_func()) 137 | for k in range(self.num_layers - 1): 138 | 
vector_field.append(nn.Linear(vf_hidden_size, vf_hidden_size)) 139 | vector_field.append(act_func()) 140 | vector_field.append(nn.Linear(vf_hidden_size, self.heldout_dim)) 141 | vector_field_net = nn.Sequential(*vector_field) 142 | self.network = RNN(cell=MLPCell(vf_net=vector_field_net)) 143 | 144 | def forward(self, inputs, reverse=False): 145 | if not reverse: 146 | batch_size, n_time, n_inputs = inputs.shape 147 | assert n_inputs == self.node_dim 148 | inputs = torch.cat( 149 | [ 150 | inputs, 151 | torch.zeros( 152 | batch_size, 153 | n_time, 154 | self.heldout_dim - self.node_dim, 155 | device=inputs.device, 156 | ), 157 | ], 158 | dim=-1, 159 | ) 160 | else: 161 | batch_size, n_inputs = inputs.shape 162 | assert n_inputs == self.heldout_dim 163 | 164 | # Pass the inputs through the network 165 | _, outputs, _ = self.network( 166 | inputs.reshape(-1, self.heldout_dim), num_steps=self.num_steps, rev=reverse 167 | ) 168 | if not reverse: 169 | outputs = outputs.reshape(batch_size, n_time, self.heldout_dim) 170 | return outputs 171 | else: 172 | # Trim the final dimension to match the node dimension 173 | outputs = outputs[:, : self.node_dim] 174 | return outputs 175 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/SAE/template.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | 3 | 4 | class TemplateSAE(pl.LightningModule): 5 | def __init__( 6 | self, 7 | dataset: str, 8 | ): 9 | super().__init__() 10 | # Instantiate SAE model 11 | # To use fixed-point finding, must have a "decoder" 12 | # attribute with a "cell" attribute, where "cell" is a function 13 | # that takes input and hidden state and returns the new hidden state 14 | # 15 | 16 | def forward(self, data, inputs): 17 | # Pass data through the model 18 | # Inputs: 19 | # data: Tensor of shape (batch_size, seq_len, input_size) 20 | # containing the spiking activity 21 | # inputs: Tensor of shape (batch_size, seq_len, input_size) 22 | # containing the input to the model (if provided) 23 | # 24 | # Returns: 25 | # log_rates: Tensor of shape (batch_size, seq_len, input_size) 26 | # containing the predicted log-firing rates (log if Poisson Loss is used) 27 | # latents: Tensor of shape (batch_size, seq_len, latent_size) 28 | # containing the hidden state of the model 29 | # return log_rates, latents 30 | pass 31 | 32 | def configure_optimizers(self): 33 | # Define optimizer 34 | # Must return a pytorch optimizer 35 | pass 36 | 37 | def training_step(self, batch, batch_ix): 38 | # Define training step 39 | # Inputs: 40 | # batch: Tuple containing: 41 | # - data: Tensor of shape (batch_size, seq_len, input_size) 42 | # containing the spiking activity 43 | # - data: (used if different IC encoding than recon activity) 44 | # - inputs: Tensor of shape (batch_size, seq_len, input_size) 45 | # containing the input to the model (if provided) 46 | # - extra: Tuple containing any additional 47 | # data needed for training (trial lens, etc.) 
48 | 49 | # batch_ix: Index of the batch 50 | # 51 | # Returns: 52 | # loss: Tensor containing the loss for the batch 53 | pass 54 | 55 | def validation_step(self, batch, batch_ix): 56 | # Define validation step 57 | # Inputs: 58 | # batch: Tuple containing: 59 | # - data: Tensor of shape (batch_size, seq_len, input_size) 60 | # containing the spiking activity 61 | # - data: (used if different IC encoding than recon activity) 62 | # - inputs: Tensor of shape (batch_size, seq_len, input_size) 63 | # containing the input to the model (if provided) 64 | # - extra: Tuple containing any additional data 65 | # needed for training (trial lens, etc.) 66 | 67 | # batch_ix: Index of the batch 68 | # 69 | # Returns: 70 | # loss: Tensor containing the loss for the batch 71 | pass 72 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/SAE/vanilla_rnn.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch import nn 4 | 5 | from .loss_func import LossFunc, PoissonLossFunc 6 | 7 | 8 | class RNN(nn.Module): 9 | def __init__(self, cell): 10 | super().__init__() 11 | self.cell = cell 12 | 13 | def forward(self, input, h_0): 14 | hidden = h_0 15 | states = [] 16 | for input_step in input.transpose(0, 1): 17 | hidden = self.cell(input_step, hidden) 18 | states.append(hidden) 19 | states = torch.stack(states, dim=1) 20 | return states, hidden 21 | 22 | 23 | class RNNLatentSAE(pl.LightningModule): 24 | def __init__( 25 | self, 26 | dataset: str, 27 | encoder_size: int, 28 | encoder_window: int, 29 | heldin_size: int, 30 | heldout_size: int, 31 | latent_size: int, 32 | lr: float, 33 | weight_decay: float, 34 | dropout: float, 35 | input_size: int, 36 | loss_func: LossFunc = PoissonLossFunc(), 37 | ): 38 | super().__init__() 39 | # Instantiate bidirectional GRU encoder 40 | self.encoder = nn.GRU( 41 | input_size=heldin_size, 42 | hidden_size=encoder_size, 43 | batch_first=True, 44 | bidirectional=True, 45 | ) 46 | self.dropout = nn.Dropout(p=dropout) 47 | self.readout = nn.Linear(in_features=latent_size, out_features=heldout_size) 48 | self.ic_linear = nn.Linear(2 * encoder_size, latent_size) 49 | self.encoder_window = encoder_window 50 | self.latent_size = latent_size 51 | self.weight_decay = weight_decay 52 | self.lr = lr 53 | self.decoder = RNN(nn.RNNCell(input_size, latent_size)) 54 | self.loss_func = loss_func 55 | 56 | def forward(self, data, inputs): 57 | # Pass data through the model 58 | _, h_n = self.encoder(data[:, : self.encoder_window, :]) 59 | h_n = torch.cat([*h_n], -1) 60 | h_n_drop = self.dropout(h_n) 61 | ic = self.ic_linear(h_n_drop) 62 | ic_drop = self.dropout(ic) 63 | # Evaluate the NeuralODE 64 | latents, _ = self.decoder(inputs, ic_drop) 65 | B, T, N = latents.shape 66 | # Map decoder state to data dimension 67 | rates = self.readout(latents) 68 | return rates, latents 69 | 70 | def configure_optimizers(self): 71 | optimizer = torch.optim.Adam( 72 | [ 73 | { 74 | "params": self.parameters(), 75 | "weight_decay": self.weight_decay, 76 | "lr": self.lr, 77 | }, 78 | ], 79 | ) 80 | return optimizer 81 | 82 | def training_step(self, batch, batch_ix): 83 | spikes, recon_spikes, inputs, extra, *_ = batch 84 | # Pass data through the model 85 | pred_logrates, pred_latents = self.forward(spikes, inputs) 86 | # Compute the weighted loss 87 | loss_dict = dict( 88 | controlled=pred_logrates, 89 | targets=recon_spikes, 90 | extra=extra, 91 | ) 92 | loss 
= self.loss_func(loss_dict) 93 | 94 | self.log("train/loss_all", loss) 95 | 96 | return loss 97 | 98 | def validation_step(self, batch, batch_ix): 99 | spikes, recon_spikes, inputs, extra, *_ = batch 100 | # Pass data through the model 101 | pred_logrates, latents = self.forward(spikes, inputs) 102 | loss_dict = dict( 103 | controlled=pred_logrates, 104 | targets=recon_spikes, 105 | extra=extra, 106 | ) 107 | 108 | loss = self.loss_func(loss_dict) 109 | 110 | self.log("valid/loss_all", loss) 111 | return loss 112 | -------------------------------------------------------------------------------- /ctd/data_modeling/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/data_modeling/models/__init__.py -------------------------------------------------------------------------------- /ctd/data_modeling/train_PTL.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | from pathlib import Path 5 | from typing import List 6 | 7 | import dotenv 8 | import hydra 9 | import pytorch_lightning as pl 10 | 11 | from ctd.data_modeling.extensions.SAE.utils import flatten 12 | 13 | dotenv.load_dotenv(override=True) 14 | 15 | log = logging.getLogger(__name__) 16 | 17 | 18 | def train_PTL( 19 | overrides: dict = {}, 20 | config_dict: dict = {}, 21 | path_dict: dict = {}, 22 | run_tag: str = "", 23 | ): 24 | compose_list = config_dict.keys() 25 | # Convert the overrides dict into a list of override strings 26 | overrides_list = [f"{k}={v}" for k, v in overrides.items()] 27 | 28 | # Generate a run_name from the overrides 29 | run_list = [] 30 | for k, v in overrides.items(): 31 | if isinstance(v, float): 32 | v = "{:.2E}".format(v) 33 | k_list = k.split(".") 34 | run_list.append(f"{k_list[-1]}={v}") 35 | run_name = "_".join(run_list) 36 | 37 | # Compose the configs for all components 38 | config_all = {} 39 | for field in compose_list: 40 | with hydra.initialize( 41 | config_path=str(config_dict[field].parent), job_name=field 42 | ): 43 | # Filter overrides relevant to this field 44 | field_prefix = f"{field}."
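# (added clarifying comment) The filtering below routes dotted overrides to
# their config group: e.g. "model.latent_size=16" matches field_prefix
# "model." and is handed to hydra.compose for the "model" config as
# "latent_size=16".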
45 | field_overrides = [ 46 | override 47 | for override in overrides_list 48 | if override.startswith(field_prefix) 49 | ] 50 | # Remove the field prefix from the overrides 51 | field_overrides = [ 52 | override[len(field_prefix) :] 53 | if override.startswith(field_prefix) 54 | else override 55 | for override in field_overrides 56 | ] 57 | config_all[field] = hydra.compose( 58 | config_name=config_dict[field].name, overrides=field_overrides 59 | ) 60 | 61 | # Handle special parameters 62 | if "params.seed" in overrides: 63 | seed = overrides["params.seed"] 64 | pl.seed_everything(seed, workers=True) 65 | else: 66 | pl.seed_everything(0, workers=True) 67 | 68 | # --------------------------Instantiate datamodule------------------------------- 69 | log.info("Instantiating datamodule") 70 | datamodule: pl.LightningDataModule = hydra.utils.instantiate( 71 | config_all["datamodule"], _convert_="all" 72 | ) 73 | 74 | # ---------------------------Instantiate callbacks--------------------------- 75 | callbacks: List[pl.Callback] = [] 76 | if "callbacks" in config_all: 77 | for _, cb_conf in config_all["callbacks"].items(): 78 | if "_target_" in cb_conf: 79 | log.info(f"Instantiating callback <{cb_conf._target_}>") 80 | callbacks.append(hydra.utils.instantiate(cb_conf, _convert_="all")) 81 | 82 | # -----------------------------Instantiate loggers---------------------------- 83 | flat_list = flatten(overrides).items() 84 | run_list = [] 85 | for k, v in flat_list: 86 | if type(v) == float: 87 | v = "{:.2E}".format(v) 88 | k_list = k.split(".") 89 | run_list.append(f"{k_list[-1]}={v}") 90 | run_name = "_".join(run_list) 91 | 92 | logger: List[pl.LightningLoggerBase] = [] 93 | if "loggers" in config_all: 94 | for _, lg_conf in config_all["loggers"].items(): 95 | if "_target_" in lg_conf: 96 | log.info(f"Instantiating logger <{lg_conf._target_}>") 97 | if lg_conf._target_ == "pytorch_lightning.loggers.WandbLogger": 98 | lg_conf["group"] = run_tag 99 | lg_conf["name"] = run_name 100 | logger.append(hydra.utils.instantiate(lg_conf)) 101 | 102 | # ------------------------------Instantiate model-------------------------------- 103 | log.info(f"Instantiating model <{config_all['model']._target_}") 104 | model: pl.LightningModule = hydra.utils.instantiate( 105 | config_all["model"], _convert_="all" 106 | ) 107 | # -----------------------------Instantiate trainer--------------------------- 108 | targ_string = config_all["trainer"]._target_ 109 | log.info(f"Instantiating trainer <{targ_string}>") 110 | trainer: pl.Trainer = hydra.utils.instantiate( 111 | config_all["trainer"], 112 | logger=logger, 113 | callbacks=callbacks, 114 | accelerator="auto", 115 | _convert_="all", 116 | ) 117 | # -----------------------------Train the model------------------------------- 118 | log.info("Starting training") 119 | trainer.fit(model=model, datamodule=datamodule) 120 | 121 | # -----------------------------Save the model------------------------------- 122 | # Save the model, datamodule, and simulator to the directory 123 | save_path = path_dict["trained_models"] 124 | save_path = os.path.join(save_path, run_tag, run_name) 125 | 126 | Path(save_path).mkdir(parents=True, exist_ok=True) 127 | model_path = os.path.join(save_path, "model.pkl") 128 | datamodule_path = os.path.join(save_path, "datamodule.pkl") 129 | 130 | model = model.to("cpu") 131 | with open(model_path, "wb") as f: 132 | pickle.dump(model, f) 133 | 134 | with open(datamodule_path, "wb") as f: 135 | pickle.dump(datamodule, f) 136 | 
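# --- Illustrative usage sketch (added; not part of the original file) ---
# train_PTL is typically driven by a sweep/run script that supplies Hydra
# config paths and overrides. The paths and values below are hypothetical and
# only illustrate the expected argument structure (pathlib.Path values, whose
# .parent and .name feed hydra.initialize and hydra.compose above).
# >>> from pathlib import Path
# >>> config_dict = {
# ...     "datamodule": Path("configs/datamodules/SAE/data_NBFF.yaml"),
# ...     "model": Path("configs/models/SAE/NBFF/NBFF_GRU_RNN.yaml"),
# ...     "callbacks": Path("configs/callbacks/SAE/default_NBFF.yaml"),
# ...     "loggers": Path("configs/loggers/SAE/default_no_wandb.yaml"),
# ...     "trainer": Path("configs/trainers/SAE/trainer_NBFF.yaml"),
# ... }
# >>> train_PTL(
# ...     overrides={"model.latent_size": 16, "params.seed": 0},
# ...     config_dict=config_dict,
# ...     path_dict={"trained_models": "runs/trained_models"},  # hypothetical
# ...     run_tag="nbff_sweep",
# ... )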
-------------------------------------------------------------------------------- /ctd/task_modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/task_modeling/__init__.py -------------------------------------------------------------------------------- /ctd/task_modeling/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/task_modeling/callbacks/__init__.py -------------------------------------------------------------------------------- /ctd/task_modeling/configs/callbacks/default_MultiTask.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | dirpath: "." 4 | monitor: valid/loss 5 | save_last: True 6 | 7 | tune_report_callback: 8 | _target_: ray.tune.integration.pytorch_lightning.TuneReportCallback 9 | metrics: 10 | loss: valid/loss 11 | 12 | state_transition_callback: 13 | _target_: ctd.task_modeling.callbacks.callbacks_multitask.StateTransitionCallback 14 | log_every_n_epochs: 50 15 | 16 | state_transition_scatter_callback: 17 | _target_: ctd.task_modeling.callbacks.callbacks_multitask.StateTransitionScatterCallback 18 | log_every_n_epochs: 50 19 | 20 | performance_callback: 21 | _target_: ctd.task_modeling.callbacks.callbacks_multitask.MultiTaskPerformanceCallback 22 | log_every_n_epochs: 50 23 | shared_subspace: 24 | _target_: ctd.task_modeling.callbacks.callbacks_multitask.SharedSubspaceCallback 25 | log_every_n_epochs: 20 26 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/callbacks/default_NBFF.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | dirpath: "." 4 | monitor: valid/loss 5 | save_last: True 6 | 7 | tune_report_callback: 8 | _target_: ray.tune.integration.pytorch_lightning.TuneReportCallback 9 | metrics: 10 | loss: valid/loss 11 | 12 | state_transition_callback: 13 | _target_: ctd.task_modeling.callbacks.callbacks.StateTransitionCallback 14 | log_every_n_epochs: 20 15 | 16 | latent_traj: 17 | _target_: ctd.task_modeling.callbacks.callbacks.LatentTrajectoryPlot 18 | log_every_n_epochs: 20 19 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/callbacks/default_RandomTarget.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | dirpath: "." 
4 | monitor: valid/loss 5 | save_last: True 6 | 7 | tune_report_callback: 8 | _target_: ray.tune.integration.pytorch_lightning.TuneReportCallback 9 | metrics: 10 | loss: valid/loss 11 | 12 | video_generation_arm: 13 | _target_: ctd.task_modeling.callbacks.callbacks_coupled.MotorNetVideoGenerationArm 14 | log_every_n_epochs: 50 15 | 16 | latent_traj: 17 | _target_: ctd.task_modeling.callbacks.callbacks_coupled.LatentTrajectoryPlot 18 | log_every_n_epochs: 50 19 | trim_inds: [5, -1] 20 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/callbacks/default_no_wandb.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | dirpath: "." 4 | monitor: valid/loss 5 | save_last: True 6 | 7 | tune_report_callback: 8 | _target_: ray.tune.integration.pytorch_lightning.TuneReportCallback 9 | metrics: 10 | loss: valid/loss 11 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/datamodule_sim/datamodule_MultiTask.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.datamodule.task_datamodule.TaskDataModule 2 | n_samples: 500 3 | seed: 100 4 | batch_size: 64 5 | num_workers: 4 6 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/datamodule_sim/datamodule_NBFF.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.datamodule.task_datamodule.TaskDataModule 2 | n_samples: 1000 3 | seed: 100 4 | batch_size: 256 5 | num_workers: 4 6 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/datamodule_sim/datamodule_RandomTarget.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.datamodule.task_datamodule.TaskDataModule 2 | n_samples: 1000 3 | seed: 100 4 | batch_size: 256 5 | num_workers: 4 6 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/datamodule_train/datamodule_MultiTask.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.datamodule.task_datamodule.TaskDataModule 2 | n_samples: 1000 3 | seed: 0 4 | batch_size: 64 5 | num_workers: 4 6 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/datamodule_train/datamodule_NBFF.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.datamodule.task_datamodule.TaskDataModule 2 | n_samples: 1000 3 | seed: 0 4 | batch_size: 256 5 | num_workers: 4 6 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/datamodule_train/datamodule_RandomTarget.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.datamodule.task_datamodule.TaskDataModule 2 | n_samples: 1000 3 | seed: 0 4 | batch_size: 1000 5 | num_workers: 8 6 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/env_sim/MultiTask.yaml: -------------------------------------------------------------------------------- 1 | _target_: 
ctd.task_modeling.task_env.multitask.MultiTaskWrapper 2 | dataset_name: MultiTask 3 | task_list: 4 | - DelayPro 5 | - DelayAnti 6 | - MemoryPro 7 | - MemoryAnti 8 | - ReactPro 9 | - ReactAnti 10 | - IntMod1 11 | - IntMod2 12 | - ContextIntMod1 13 | - ContextIntMod2 14 | - ContextIntMultimodal 15 | - Match2Sample 16 | - NonMatch2Sample 17 | - MatchCatPro 18 | - MatchCatAnti 19 | 20 | n_timesteps: 320 21 | noise: 0.3 22 | dynamic_noise: 0.0 23 | num_targets: 32 24 | bin_size: 20 25 | grouped_sampler: False 26 | latent_l2_wt: 0.0 27 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/env_sim/NBFF.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.task_env.task_env.NBitFlipFlop 2 | 3 | n: 3 4 | n_timesteps: 500 5 | noise: 0.15 6 | dynamic_noise: 0.0 7 | switch_prob: 0.01 8 | 9 | transition_blind: 4 10 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/env_sim/RandomTarget.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.task_env.random_target.RandomTarget 2 | effector: 3 | _target_: motornet.effector.RigidTendonArm26 4 | muscle: 5 | _target_: motornet.muscle.MujocoHillMuscle 6 | max_ep_duration: 1.55 7 | action_noise: 0.005 8 | proprioception_noise: 0.005 9 | proprioception_delay: 0.02 10 | vision_noise: 0.005 11 | vision_delay: 0.05 12 | 13 | act_weight: 1.0 14 | pos_weight: 1.0 15 | 16 | context_input_noise: 0.04 17 | is_aligned: True 18 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/env_task/MultiTask.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.task_env.multitask.MultiTaskWrapper 2 | dataset_name: MultiTask 3 | task_list: 4 | - DelayPro 5 | - MemoryPro 6 | - ReactPro 7 | - DelayAnti 8 | - MemoryAnti 9 | - ReactAnti 10 | - IntMod1 11 | - IntMod2 12 | - ContextIntMod1 13 | - ContextIntMod2 14 | - ContextIntMultimodal 15 | - Match2Sample 16 | - NonMatch2Sample 17 | - MatchCatPro 18 | - MatchCatAnti 19 | 20 | n_timesteps: 320 21 | noise: 0.3 22 | dynamic_noise: 0.0 23 | num_targets: 32 24 | bin_size: 20 25 | grouped_sampler: True 26 | latent_l2_wt: 0.0 27 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/env_task/NBFF.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.task_env.task_env.NBitFlipFlop 2 | 3 | n: 3 4 | n_timesteps: 500 5 | noise: 0.15 6 | dynamic_noise: 0.0 7 | switch_prob: 0.01 8 | 9 | transition_blind: 4 10 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/env_task/RandomTarget.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.task_env.random_target.RandomTarget 2 | effector: 3 | _target_: motornet.effector.RigidTendonArm26 4 | muscle: 5 | _target_: motornet.muscle.MujocoHillMuscle 6 | max_ep_duration: 3.0 7 | action_noise: 0.005 8 | proprioception_noise: 0.005 9 | proprioception_delay: 0.02 10 | vision_noise: 0.005 11 | vision_delay: 0.05 12 | 13 | act_weight: 1.0 14 | pos_weight: 1.0 15 | 16 | context_input_noise: 0.04 17 | is_aligned: False 18 | -------------------------------------------------------------------------------- 
/ctd/task_modeling/configs/logger/default.yaml: -------------------------------------------------------------------------------- 1 | tensorboard_logger: 2 | _target_: pytorch_lightning.loggers.TensorBoardLogger 3 | save_dir: "." 4 | version: "" 5 | name: "" 6 | csv_logger: 7 | _target_: pytorch_lightning.loggers.CSVLogger 8 | save_dir: "." 9 | version: "" 10 | name: "" 11 | wandb_logger: 12 | _target_: pytorch_lightning.loggers.WandbLogger 13 | save_dir: "." 14 | version: "" 15 | name: "" 16 | project: "task_trained_RNN" 17 | group: "" 18 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/logger/default_no_wandb.yaml: -------------------------------------------------------------------------------- 1 | tensorboard_logger: 2 | _target_: pytorch_lightning.loggers.TensorBoardLogger 3 | save_dir: "." 4 | version: "" 5 | name: "" 6 | csv_logger: 7 | _target_: pytorch_lightning.loggers.CSVLogger 8 | save_dir: "." 9 | version: "" 10 | name: "" 11 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/model/DriscollRNN.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.model.rnn.DriscollRNN 2 | latent_size: 128 3 | noise_level: 0.05 4 | gamma: 0.2 5 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/model/GRU_RNN.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.model.rnn.GRU_RNN 2 | latent_size: 128 3 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/model/NODE.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.model.node.NODE 2 | latent_size: 3 3 | layer_hidden_size: 128 4 | num_layers: 3 5 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/model/NoisyGRU.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.model.rnn.NoisyGRU 2 | latent_size: 128 3 | noise_level: 0.01 4 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/model/NoisyGRULatentL2.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.model.rnn.NoisyGRU_LatentL2 2 | latent_size: 128 3 | noise_level: 0.01 4 | latent_ic_var: 0.05 5 | l2_wt: 1e-6 6 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/model/Vanilla_RNN.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.model.rnn.Vanilla_RNN 2 | latent_size: 128 3 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/simulator/default_MultiTask.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.simulator.neural_simulator.NeuralDataSimulator 2 | neuron_dict: 3 | n_neurons_heldin: 50 4 | n_neurons_heldout: 10 5 | embed_dict: 6 | rect_func: exp 7 | fr_scaling: 2.0 8 | noise_dict: 9 | obs_noise: pseudoPoisson 10 | dispersion: 1.0 11 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/simulator/default_NBFF.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.simulator.neural_simulator.NeuralDataSimulator 2 | neuron_dict: 3 | n_neurons_heldin: 50 4 | n_neurons_heldout: 10 5 | embed_dict: 6 | rect_func: exp 7 | fr_scaling: 2.0 8 | noise_dict: 9 | obs_noise: pseudoPoisson 10 | dispersion: 1.0 11 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/simulator/default_RandomTarget.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.simulator.neural_simulator.NeuralDataSimulator 2 | neuron_dict: 3 | n_neurons_heldin: 50 4 | n_neurons_heldout: 10 5 | embed_dict: 6 | rect_func: exp 7 | fr_scaling: 4.0 8 | noise_dict: 9 | obs_noise: pseudoPoisson 10 | dispersion: 1.0 11 | 12 | trim_inds: [5, -1] 13 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/task_wrapper/MultiTask.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.task_wrapper.task_wrapper.TaskTrainedWrapper 2 | 3 | learning_rate: 1.0e-3 4 | weight_decay: 1.0e-8 5 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/task_wrapper/NBFF.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.task_wrapper.task_wrapper.TaskTrainedWrapper 2 | 3 | learning_rate: 1.0e-3 4 | weight_decay: 1.0e-8 5 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/task_wrapper/RandomTarget.yaml: -------------------------------------------------------------------------------- 1 | _target_: ctd.task_modeling.task_wrapper.task_wrapper.TaskTrainedWrapper 2 | 3 | learning_rate: 1.0e-3 4 | weight_decay: 0 5 | -------------------------------------------------------------------------------- /ctd/task_modeling/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | log_every_n_steps: 100 3 | max_epochs: 1000 4 | gradient_clip_val: 1.0 5 | # Prevent console output from individual models 6 | enable_progress_bar: False 7 | -------------------------------------------------------------------------------- /ctd/task_modeling/datamodule/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/task_modeling/datamodule/__init__.py -------------------------------------------------------------------------------- /ctd/task_modeling/datamodule/samplers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import BatchSampler 3 | 4 | 5 | class GroupedSampler(BatchSampler): 6 | # This sampler yields batches of data grouped by trial type. 7 | # This is useful for getting shared motifs on the MultiTask dataset 8 | def __init__(self, data_source, num_samples): 9 | self.dataset = data_source 10 | self.batch_size = num_samples 11 | self.num_samples = len(data_source) 12 | self.grouped_indices = self._group_indices_by_trial_type() 13 | 14 | def _group_indices_by_trial_type(self): 15 | # Group indices by trial type.
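# NOTE: assumes the wrapped TensorDataset stores the per-trial task label
# as its fifth tensor (tensors[4], read below); every batch yielded by
# __iter__ then contains trials of a single type.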
16 | trial_type_indices = {} 17 | trial_type = self.dataset.tensors[4] 18 | unique_trial_types = np.unique(trial_type) 19 | for ind1, trial_type1 in enumerate(unique_trial_types): 20 | trial_type_indices[ind1] = np.where(trial_type == trial_type1)[0] 21 | return trial_type_indices 22 | 23 | def __iter__(self): 24 | group_indices = list(self.grouped_indices.values()) 25 | np.random.shuffle(group_indices) # Shuffle the groups 26 | indices_lens = np.array([len(x) for x in group_indices]) # lengths computed after shuffling so they stay aligned with group_counter 27 | group_counter = np.zeros(len(group_indices)).astype(int) 28 | while np.any(group_counter < indices_lens): 29 | for i, group in enumerate(group_indices): 30 | if group_counter[i] < len(group): 31 | yield group[group_counter[i] : group_counter[i] + self.batch_size] 32 | group_counter[i] += self.batch_size 33 | 34 | def __len__(self): 35 | # Calculate batches 36 | return (self.num_samples + self.batch_size - 1) // self.batch_size 37 | 38 | 39 | class RandomSampler(BatchSampler): 40 | def __init__(self, data_source, num_samples): 41 | self.dataset = data_source 42 | self.batch_size = num_samples 43 | self.num_samples = len(data_source) 44 | 45 | def __iter__(self): 46 | indices = np.arange(self.num_samples) 47 | np.random.shuffle(indices) 48 | 49 | for i in range(0, self.num_samples, self.batch_size): 50 | yield indices[i : i + self.batch_size] 51 | 52 | def __len__(self): 53 | # Calculate the number of batches 54 | return (self.num_samples + self.batch_size - 1) // self.batch_size 55 | 56 | 57 | class SequentialSampler(BatchSampler): 58 | def __init__(self, data_source, num_samples): 59 | self.dataset = data_source 60 | self.batch_size = num_samples 61 | self.num_samples = len(data_source) 62 | 63 | def __iter__(self): 64 | indices = np.arange(self.num_samples) 65 | 66 | for i in range(0, self.num_samples, self.batch_size): 67 | yield indices[i : i + self.batch_size] 68 | 69 | def __len__(self): 70 | # Calculate the number of batches 71 | return (self.num_samples + self.batch_size - 1) // self.batch_size 72 | -------------------------------------------------------------------------------- /ctd/task_modeling/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/task_modeling/model/__init__.py -------------------------------------------------------------------------------- /ctd/task_modeling/model/node.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | """ 5 | All models must meet a few requirements 6 | 1. They must have an init_model method that takes 7 | input_size and output_size as arguments 8 | 2. They must have a forward method that takes inputs and hidden 9 | as arguments and returns output and hidden for one time step 10 | 3. They must have a cell attribute that is the recurrent cell 11 | 4. They must have a readout attribute that is the output layer 12 | (mapping from latent to output) 13 | 14 | Optionally, 15 | 1. They can have an init_hidden method that takes 16 | batch_size as an argument and returns an initial hidden state 17 | 2. They can have a model_loss method that takes a loss_dict 18 | as an argument and returns a loss (L2 regularization on latents, etc.)
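
A minimal sketch of a model meeting these requirements (illustrative
only; not part of the benchmark code):

    class TinyGRU(nn.Module):
        def __init__(self, latent_size):
            super().__init__()
            self.latent_size = latent_size
            self.cell = None
            self.readout = None

        def init_model(self, input_size, output_size):
            self.cell = nn.GRUCell(input_size, self.latent_size)
            self.readout = nn.Linear(self.latent_size, output_size)

        def forward(self, inputs, hidden):
            hidden = self.cell(inputs, hidden)
            return self.readout(hidden), hidden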
19 | 20 | """ 21 | 22 | 23 | class NODE(nn.Module): 24 | def __init__( 25 | self, 26 | num_layers, 27 | layer_hidden_size, 28 | latent_size, 29 | output_size=None, 30 | input_size=None, 31 | ): 32 | super().__init__() 33 | self.num_layers = num_layers 34 | self.layer_hidden_size = layer_hidden_size 35 | self.latent_size = latent_size 36 | self.output_size = output_size 37 | self.input_size = input_size 38 | self.generator = None # the recurrent cell (this model's "cell") 39 | self.readout = None 40 | self.latent_ics = torch.nn.Parameter( 41 | torch.zeros(latent_size), requires_grad=True 42 | ) 43 | 44 | def init_hidden(self, batch_size): 45 | return self.latent_ics.unsqueeze(0).expand(batch_size, -1) 46 | 47 | def init_model(self, input_size, output_size): 48 | self.input_size = input_size 49 | self.output_size = output_size 50 | self.generator = MLPCell( 51 | input_size, self.num_layers, self.layer_hidden_size, self.latent_size 52 | ) 53 | self.readout = nn.Linear(self.latent_size, output_size) 54 | # Initialize weights and biases for the readout layer 55 | nn.init.normal_( 56 | self.readout.weight, mean=0.0, std=0.01 57 | ) # Small standard deviation 58 | nn.init.constant_(self.readout.bias, 0.0) # Zero bias initialization 59 | 60 | def forward(self, inputs, hidden=None): 61 | n_samples, n_inputs = inputs.shape 62 | dev = inputs.device 63 | if hidden is None: 64 | hidden = torch.zeros((n_samples, self.latent_size), device=dev) 65 | hidden = self.generator(inputs, hidden) 66 | output = self.readout(hidden) 67 | return output, hidden 68 | 69 | 70 | class MLPCell(nn.Module): 71 | def __init__(self, input_size, num_layers, layer_hidden_size, latent_size): 72 | super().__init__() 73 | self.input_size = input_size 74 | self.num_layers = num_layers 75 | self.layer_hidden_size = layer_hidden_size 76 | self.latent_size = latent_size 77 | layers = nn.ModuleList() 78 | for i in range(num_layers): 79 | if i == 0: 80 | layers.append(nn.Linear(input_size + latent_size, layer_hidden_size)) 81 | layers.append(nn.ReLU()) 82 | elif i == num_layers - 1: 83 | layers.append(nn.Linear(layer_hidden_size, latent_size)) 84 | else: 85 | layers.append(nn.Linear(layer_hidden_size, layer_hidden_size)) 86 | layers.append(nn.ReLU()) 87 | self.vf_net = nn.Sequential(*layers) 88 | 89 | def forward(self, input, hidden): 90 | input_hidden = torch.cat([hidden, input], dim=1) 91 | return hidden + 0.1 * self.vf_net(input_hidden) # one explicit-Euler step of the latent ODE (fixed step size 0.1) 92 | -------------------------------------------------------------------------------- /ctd/task_modeling/model/tt_template.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class TT_Template(nn.Module): 6 | """ 7 | Template for Task-Trained RNN models. 8 | 9 | All subclasses must implement: 10 | 1. init_model(input_size, output_size) 11 | 2. forward(inputs, hidden) -> (output, hidden) 12 | 3. self.cell # the recurrent cell module 13 | 4.
self.readout # the output layer (latent → output) 14 | 15 | Optional hooks: 16 | - init_hidden(batch_size) -> Tensor 17 | - model_loss(loss_dict) -> Tensor 18 | """ 19 | 20 | def __init__( 21 | self, latent_size: int, input_size: int = None, output_size: int = None 22 | ): 23 | super().__init__() 24 | # will be set in init_model(): 25 | self.cell: nn.Module = None 26 | self.readout: nn.Module = None 27 | 28 | self.input_size = input_size 29 | self.latent_size = latent_size 30 | self.output_size = output_size 31 | 32 | def init_model(self, input_size: int, output_size: int): 33 | """ 34 | Instantiate: 35 | - self.cell (e.g. GRUCell/LSTMCell) 36 | - self.readout (nn.Linear from latent_size→output_size) 37 | """ 38 | raise NotImplementedError("Must implement init_model()") 39 | 40 | def init_hidden(self, batch_size: int) -> torch.Tensor: 41 | """ 42 | (Optional) Return initial hidden state of shape 43 | (batch_size, latent_size). 44 | """ 45 | raise NotImplementedError("Optional: implement init_hidden()") 46 | 47 | def forward( 48 | self, inputs: torch.Tensor, hidden: torch.Tensor 49 | ) -> tuple[torch.Tensor, torch.Tensor]: 50 | """ 51 | One timestep of RNN dynamics: 52 | inputs: Tensor [batch_size, input_size] 53 | hidden: Tensor [batch_size, latent_size] 54 | Returns: 55 | output: Tensor [batch_size, output_size] 56 | hidden: Tensor [batch_size, latent_size] 57 | """ 58 | raise NotImplementedError("Must implement forward()") 59 | 60 | def model_loss(self, loss_dict: dict) -> torch.Tensor: 61 | """ 62 | (Optional) Compute extra loss terms (e.g. latent regularization) 63 | from loss_dict and return a scalar Tensor. 64 | """ 65 | raise NotImplementedError("Optional: implement model_loss()") 66 | -------------------------------------------------------------------------------- /ctd/task_modeling/simulator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/task_modeling/simulator/__init__.py -------------------------------------------------------------------------------- /ctd/task_modeling/task_env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/task_modeling/task_env/__init__.py -------------------------------------------------------------------------------- /ctd/task_modeling/task_env/loss_func.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class LossFunc: 7 | def __init__(self): 8 | pass 9 | 10 | def __call__(self, loss_dict): 11 | raise NotImplementedError # subclasses must implement __call__ 12 | 13 | 14 | class RandomTargetLoss(LossFunc): 15 | def __init__( 16 | self, position_loss, pos_weight, act_weight, full_trial_epoch: int = 200 17 | ): 18 | """Initialize the loss function 19 | Args: 20 | position_loss (nn.Module): The loss function to use for the position 21 | pos_weight (float): The weight to apply to the position loss 22 | act_weight (float): The weight to apply to the action loss 23 | full_trial_epoch (int): The number of epochs 24 | before the full trial is included in the loss""" 25 | self.position_loss = position_loss 26 | self.action_loss = nn.MSELoss() 27 | self.pos_weight = pos_weight 28 | self.act_weight = act_weight 29 | self.full_trial_epoch = full_trial_epoch 30 | 31 | def
__call__(self, loss_dict): 32 | pred = loss_dict["controlled"] 33 | target = loss_dict["targets"] 34 | act = loss_dict["actions"] 35 | epoch = loss_dict["epoch"] 36 | n_time = pred.shape[1] 37 | # Gradually increase the percent of the trial to include in the loss 38 | include_loss = np.ceil(n_time * min(1.0, epoch / self.full_trial_epoch)).astype( 39 | int 40 | ) 41 | pos_loss = self.pos_weight * self.position_loss( 42 | pred[:, :include_loss, :], target[:, :include_loss, :] 43 | ) 44 | act_loss = self.act_weight * self.action_loss(act, torch.zeros_like(act)) 45 | return pos_loss + act_loss 46 | 47 | 48 | class NBFFLoss(LossFunc): 49 | def __init__(self, transition_blind): 50 | """Initialize the loss function 51 | Args: 52 | transition_blind (int): The number of steps to 53 | ignore the effect of transitions for""" 54 | 55 | self.transition_blind = transition_blind 56 | 57 | def __call__(self, loss_dict): 58 | pred = loss_dict["controlled"] 59 | target = loss_dict["targets"] 60 | 61 | # Find where the change in the target is not zero 62 | # Step 1: Find where the transitions occur (change in value) 63 | transitions = torch.diff(target, dim=1) != 0 64 | 65 | # Initialize the mask with ones, with one less column (due to diff) 66 | mask = torch.ones_like(transitions, dtype=torch.float) 67 | 68 | # Step 2: Propagate the effect of transitions for 'transition_blind' steps 69 | for i in range(1, self.transition_blind + 1): 70 | # Shift the transition marks to the right to affect subsequent values 71 | shifted_transitions = torch.cat( 72 | (torch.zeros_like(transitions[:, :i]), transitions[:, :-i]), dim=1 73 | ) 74 | mask = mask * (1 - shifted_transitions.float()) 75 | 76 | # Step 3: Adjust mask size to match the original target tensor 77 | # Adding a column of ones at the beginning because diff reduces the size by 1 78 | final_mask = torch.cat((torch.ones_like(mask[:, :1]), mask), dim=1) 79 | final_mask[:, 0:5, :] = 0.0 80 | 81 | loss = nn.MSELoss(reduction="none")(pred, target) * final_mask 82 | return loss.mean() 83 | 84 | 85 | class MatchTargetLossMSE(LossFunc): 86 | def __init__(self): 87 | pass 88 | 89 | def __call__(self, loss_dict): 90 | pred = loss_dict["controlled"] 91 | target = loss_dict["targets"] 92 | # action = loss_dict["actions"] 93 | # inputs = loss_dict["inputs"] 94 | return nn.MSELoss()(pred, target) 95 | 96 | 97 | class MultiTaskLoss(LossFunc): 98 | def __init__(self, lat_loss_weight=1e-6): 99 | self.lat_loss_weight = lat_loss_weight 100 | pass 101 | 102 | def __call__(self, loss_dict): 103 | 104 | """Calculate the loss""" 105 | pred = loss_dict["controlled"] 106 | target = loss_dict["targets"] 107 | latents = loss_dict["latents"] 108 | # action = loss_dict["actions"] 109 | inputs = loss_dict["inputs"] 110 | extras = loss_dict["extra"] 111 | resp_start = extras[:, 0].long() 112 | resp_end = extras[:, 1].long() 113 | recon_loss = nn.MSELoss(reduction="none")(pred, target) 114 | mask = torch.ones_like(recon_loss) 115 | mask_lats = torch.ones_like(latents) 116 | 117 | # Ignore the first 5 time steps and the time steps after the response 118 | mask[:, 0:5, :] = 0 119 | for i in range(inputs.shape[0]): 120 | mask[i, resp_start[i] : resp_end[i], :] = 5.0 121 | mask[i, resp_start[i] : resp_start[i] + 5, :] = 0.0 122 | mask[i, resp_end[i] :, :] = 0.0 123 | # Mask the latents after the response 124 | mask_lats[i, resp_end[i] :, :] = 0 125 | 126 | masked_loss = recon_loss * mask 127 | lats_loss = ( 128 | nn.MSELoss(reduction="none")(latents, torch.zeros_like(latents)) * mask_lats 129 | 
) 130 | 131 | total_loss = ( 132 | masked_loss.sum(dim=1).mean() 133 | + self.lat_loss_weight * lats_loss.sum(dim=1).mean() 134 | ) 135 | return total_loss 136 | -------------------------------------------------------------------------------- /ctd/task_modeling/task_wrapper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/ctd/task_modeling/task_wrapper/__init__.py -------------------------------------------------------------------------------- /ctd/task_modeling/task_wrapper/utils.py: -------------------------------------------------------------------------------- 1 | def make_data_tag(dm_cfg): 2 | obs_dim = "" if "obs_dim" not in dm_cfg else dm_cfg.obs_dim 3 | obs_noise = "" if "obs_noise" not in dm_cfg else dm_cfg.obs_noise 4 | if "obs_noise_params" in dm_cfg: 5 | obs_noise_params = ",".join( 6 | [f"{k}={v}" for k, v in dm_cfg.obs_noise_params.items()] 7 | ) 8 | else: 9 | obs_noise_params = "" 10 | data_tag = ( 11 | f"{dm_cfg.system}{obs_dim}_" 12 | f"{dm_cfg.n_samples}S_" 13 | f"{dm_cfg.n_timesteps}T_" 14 | f"{dm_cfg.pts_per_period}P_" 15 | f"{dm_cfg.seed}seed" 16 | ) 17 | if obs_noise: 18 | data_tag += f"_{obs_noise}{obs_noise_params}" 19 | return data_tag 20 | -------------------------------------------------------------------------------- /examples/figures/Fig3TaskPerformance/MemoryPro_MemoryProPCs_combined_video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/figures/Fig3TaskPerformance/MemoryPro_MemoryProPCs_combined_video.gif -------------------------------------------------------------------------------- /examples/gen_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import dotenv 5 | 6 | from ctd.comparison.analysis.tt.tt import Analysis_TT 7 | 8 | 9 | def copy_folder_contents(src_folder, dest_folder): 10 | # Ensure the destination folder exists 11 | os.makedirs(dest_folder, exist_ok=True) 12 | 13 | # Iterate over all files and directories in the source folder 14 | for item_name in os.listdir(src_folder): 15 | src_item = os.path.join(src_folder, item_name) 16 | dest_item = os.path.join(dest_folder, item_name) 17 | 18 | # If it's a directory, copy it recursively 19 | if os.path.isdir(src_item): 20 | shutil.copytree(src_item, dest_item, dirs_exist_ok=True) 21 | else: 22 | shutil.copy2(src_item, dest_item) 23 | 24 | 25 | dotenv.load_dotenv(override=True) 26 | HOME_DIR = os.environ.get("HOME_DIR") 27 | print(HOME_DIR) 28 | 29 | tt_3bff_path = HOME_DIR + "pretrained/20241017_NBFF_NoisyGRU_NewFinal/" 30 | tt_MultiTask_path = HOME_DIR + "pretrained/20241113_MultiTask_NoisyGRU_Final2/" 31 | tt_RandomTarget_path = HOME_DIR + "pretrained/20241113_RandomTarget_NoisyGRU_Final2/" 32 | 33 | tt_3bff = Analysis_TT(run_name="tt_3bff", filepath=tt_3bff_path) 34 | tt_MultiTask = Analysis_TT(run_name="tt_MultiTask", filepath=tt_MultiTask_path) 35 | tt_RandomTarget = Analysis_TT(run_name="tt_RandomTarget", filepath=tt_RandomTarget_path) 36 | 37 | # Make copies of the pretrained models to the trained_models folder 38 | # if the folders don't already exist 39 | path_3bff = HOME_DIR + "content/trained_models/task-trained/tt_3bff/" 40 | path_MultiTask = HOME_DIR + 
"content/trained_models/task-trained/tt_MultiTask/" 41 | path_RandomTarget = HOME_DIR + "content/trained_models/task-trained/tt_RandomTarget/" 42 | 43 | if not os.path.exists(path_3bff): 44 | copy_folder_contents( 45 | tt_3bff_path, HOME_DIR + "content/trained_models/task-trained/tt_3bff/" 46 | ) 47 | 48 | if not os.path.exists(path_MultiTask): 49 | copy_folder_contents( 50 | tt_MultiTask_path, 51 | HOME_DIR + "content/trained_models/task-trained/tt_MultiTask/", 52 | ) 53 | 54 | if not os.path.exists(path_RandomTarget): 55 | copy_folder_contents( 56 | tt_RandomTarget_path, 57 | HOME_DIR + "content/trained_models/task-trained/tt_RandomTarget/", 58 | ) 59 | 60 | # Generate simulated datasets 61 | dataset_path = HOME_DIR + "content/datasets/dd/" 62 | 63 | tt_3bff.simulate_neural_data( 64 | subfolder="max_epochs=500 n_samples=1000 latent_size=64 seed=0 learning_rate=0.001", 65 | dataset_path=dataset_path, 66 | ) 67 | 68 | mt_subfolder = "max_epochs=500 seed=0" 69 | tt_MultiTask.simulate_neural_data( 70 | subfolder=mt_subfolder, 71 | dataset_path=dataset_path, 72 | ) 73 | 74 | rt_subfolder = ( 75 | "max_epochs=2000 latent_size=128 l2_wt=5e-05 " 76 | + "proprioception_delay=0.02 vision_delay=0.05 " 77 | + "n_samples=1100 n_samples=1100 seed=0 learning_rate=0.005" 78 | ) 79 | tt_RandomTarget.simulate_neural_data( 80 | subfolder=rt_subfolder, 81 | dataset_path=dataset_path, 82 | ) 83 | -------------------------------------------------------------------------------- /examples/notebooks/png/AnalysisStructure-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/AnalysisStructure-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/BenchmarkFlow2-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/BenchmarkFlow2-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/BenchmarkFlowTTDT-01-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/BenchmarkFlowTTDT-01-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/BenchmarkGrey-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/BenchmarkGrey-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/BenchmarkSchematicSimple_steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/BenchmarkSchematicSimple_steps.png -------------------------------------------------------------------------------- /examples/notebooks/png/DSAPic.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/DSAPic.png -------------------------------------------------------------------------------- /examples/notebooks/png/FinalGif.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/FinalGif.gif -------------------------------------------------------------------------------- /examples/notebooks/png/Hourglass.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/Hourglass.png -------------------------------------------------------------------------------- /examples/notebooks/png/MemoryPro_MemoryProPCs_combined_video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/MemoryPro_MemoryProPCs_combined_video.gif -------------------------------------------------------------------------------- /examples/notebooks/png/MotorNet Illustration-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/MotorNet Illustration-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/MotorNetGif.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/MotorNetGif.gif -------------------------------------------------------------------------------- /examples/notebooks/png/NoteBookQR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/NoteBookQR.png -------------------------------------------------------------------------------- /examples/notebooks/png/Problem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/Problem.png -------------------------------------------------------------------------------- /examples/notebooks/png/SAE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/SAE.png -------------------------------------------------------------------------------- /examples/notebooks/png/SimulationDiagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/SimulationDiagram.png 
-------------------------------------------------------------------------------- /examples/notebooks/png/StateR2-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/StateR2-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/Step1-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/Step1-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/Step2-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/Step2-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/Step3-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/Step3-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/Step4-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/Step4-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/SussilloBarack.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/SussilloBarack.png -------------------------------------------------------------------------------- /examples/notebooks/png/TTModelExample-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/TTModelExample-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/TaskComplexity-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/TaskComplexity-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/TaskEnvs-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/TaskEnvs-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/TaskTrained-01.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/TaskTrained-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/TaskTraininSchematic-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/TaskTraininSchematic-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/Template.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/Template.png -------------------------------------------------------------------------------- /examples/notebooks/png/TutorialTT-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/TutorialTT-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/TutorialTT0-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/TutorialTT0-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/TutorialTTComp-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/TutorialTTComp-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/TutorialTT_model-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/TutorialTT_model-01.png -------------------------------------------------------------------------------- /examples/notebooks/png/lfads_fps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/lfads_fps.png -------------------------------------------------------------------------------- /examples/notebooks/png/loopingMultiTask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/examples/notebooks/png/loopingMultiTask.gif -------------------------------------------------------------------------------- /examples/run_data_training.py: -------------------------------------------------------------------------------- 1 | # import os 2 | 3 | # os.environ["CUDA_VISIBLE_DEVICES"] = "" 4 | import logging 5 | import os 6 | import shutil 7 | from datetime import datetime 8 | from pathlib import Path 9 | 10 | import dotenv 11 | import ray 12 | from 
ray import tune 13 | from ray.tune import CLIReporter 14 | from ray.tune.schedulers import FIFOScheduler 15 | from ray.tune.search.basic_variant import BasicVariantGenerator 16 | 17 | from ctd.data_modeling.train_PTL import train_PTL 18 | 19 | dotenv.load_dotenv(override=True) 20 | HOME_DIR = Path(os.environ.get("HOME_DIR")) 21 | 22 | log = logging.getLogger(__name__) 23 | # ---------------Options--------------- 24 | LOCAL_MODE = False 25 | OVERWRITE = True 26 | WANDB_LOGGING = True # If users have a WandB account 27 | 28 | RUN_DESC = "3BFF_NODE_sweep" # Description of the run 29 | NUM_SAMPLES = 1 30 | MODEL_CLASS = "SAE" # "LFADS" or "SAE" 31 | MODEL = "NODE" # see /ctd/data_modeling/configs/models/{MODEL_CLASS}/ for options 32 | DATA = "NBFF" # "NBFF", "RandomTarget" or "MultiTask" 33 | INFER_INPUTS = False # Whether external inputs are inferred or supplied 34 | 35 | if DATA == "NBFF": 36 | prefix = "tt_3bff" 37 | elif DATA == "MultiTask": 38 | prefix = "tt_MultiTask" 39 | elif DATA == "RandomTarget": 40 | prefix = "tt_RandomTarget" 41 | 42 | # ------------------------------------- 43 | # Hyperparameter sweeping: 44 | # Default parameters chosen to replicate Fig. 5 45 | # ------------------------------------- 46 | SEARCH_SPACE = { 47 | "datamodule.prefix": tune.grid_search([prefix]), 48 | "model.latent_size": tune.grid_search([3, 5, 8, 16, 32, 64]), 49 | "trainer.max_epochs": tune.grid_search([1000]), 50 | "params.seed": tune.grid_search([0, 1, 2, 3, 4]), 51 | } 52 | 53 | # -----------------Default Parameter Sets ----------------------------------- 54 | cpath = "../data_modeling/configs" 55 | 56 | model_path = Path( 57 | ( 58 | f"{cpath}/models/{MODEL_CLASS}/{DATA}/{DATA}_{MODEL}" 59 | f"{'_infer' if INFER_INPUTS else ''}.yaml" 60 | ) 61 | ) 62 | 63 | datamodule_path = Path( 64 | ( 65 | f"{cpath}/datamodules/{MODEL_CLASS}/data_{DATA}" 66 | f"{'_infer' if INFER_INPUTS else ''}.yaml" 67 | ) 68 | ) 69 | 70 | callbacks_path = Path(f"{cpath}/callbacks/{MODEL_CLASS}/default_{DATA}.yaml") 71 | loggers_path = Path(f"{cpath}/loggers/{MODEL_CLASS}/default.yaml") 72 | trainer_path = Path(f"{cpath}/trainers/{MODEL_CLASS}/trainer_{DATA}.yaml") 73 | 74 | if not WANDB_LOGGING: 75 | loggers_path = Path(f"{cpath}/loggers/{MODEL_CLASS}/default_no_wandb.yaml") 76 | callbacks_path = Path(f"{cpath}/callbacks/{MODEL_CLASS}/default_no_wandb.yaml") 77 | 78 | if MODEL_CLASS not in ["LDS"]: 79 | config_dict = dict( 80 | model=model_path, 81 | datamodule=datamodule_path, 82 | callbacks=callbacks_path, 83 | loggers=loggers_path, 84 | trainer=trainer_path, 85 | ) 86 | train = train_PTL 87 | else: 88 | config_dict = dict( 89 | model=model_path, 90 | datamodule=datamodule_path, 91 | trainer=trainer_path, 92 | ) 93 | # train = train_JAX 94 | 95 | # ------------------Data Management Variables -------------------------------- 96 | DATE_STR = datetime.now().strftime("%Y%m%d") 97 | RUN_TAG = f"{DATE_STR}_{RUN_DESC}" 98 | RUNS_HOME = Path(HOME_DIR) 99 | RUN_DIR = HOME_DIR / "content" / "runs" / "data-trained" / RUN_TAG 100 | path_dict = dict( 101 | dd_datasets=HOME_DIR / "content" / "datasets" / "dd", 102 | trained_models=HOME_DIR / "content" / "trained_models" / "task-trained" / prefix, 103 | ) 104 | 105 | 106 | def trial_function(trial): 107 | return trial.experiment_tag 108 | 109 | 110 | # -------------------Main Function---------------------------------- 111 | def main( 112 | run_tag_in: str, 113 | path_dict: dict, 114 | config_dict: dict, 115 | ): 116 | if LOCAL_MODE: 117 | ray.init(local_mode=True) 118 | if
RUN_DIR.exists() and OVERWRITE: 119 | shutil.rmtree(RUN_DIR) 120 | 121 | RUN_DIR.mkdir(parents=True) 122 | shutil.copyfile(__file__, RUN_DIR / Path(__file__).name) 123 | run_dir = str(RUN_DIR) 124 | tune.run( 125 | tune.with_parameters( 126 | train, run_tag=run_tag_in, config_dict=config_dict, path_dict=path_dict 127 | ), 128 | config=SEARCH_SPACE, 129 | resources_per_trial=dict(cpu=4, gpu=0.45), 130 | num_samples=NUM_SAMPLES, 131 | storage_path=run_dir, 132 | search_alg=BasicVariantGenerator(), 133 | scheduler=FIFOScheduler(), 134 | verbose=1, 135 | progress_reporter=CLIReporter( 136 | metric_columns=["loss", "training_iteration"], 137 | sort_by_metric=True, 138 | ), 139 | trial_dirname_creator=trial_function, 140 | ) 141 | 142 | 143 | if __name__ == "__main__": 144 | main( 145 | run_tag_in=RUN_TAG, 146 | config_dict=config_dict, 147 | path_dict=path_dict, 148 | ) 149 | -------------------------------------------------------------------------------- /examples/run_task_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import ray 4 | 5 | LOCAL_MODE = False # Set to True to run locally (for debugging or RandomTarget) 6 | if LOCAL_MODE: 7 | ray.init(local_mode=True, num_gpus=0) # Ensure no GPUs are requested 8 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 9 | 10 | import shutil 11 | from pathlib import Path 12 | 13 | import dotenv 14 | from ray import tune 15 | from ray.tune import CLIReporter 16 | from ray.tune.schedulers import FIFOScheduler 17 | from ray.tune.search.basic_variant import BasicVariantGenerator 18 | 19 | from ctd.task_modeling.task_training import train 20 | from utils import generate_paths, trial_function 21 | 22 | dotenv.load_dotenv(override=True) 23 | 24 | 25 | # ---------------Options--------------- 26 | OVERWRITE = True # Set to True to overwrite existing run 27 | 28 | RUN_DESC = "NBFF_NoisyGRU_Final" 29 | TASK = "NBFF" # Task to train on (see configs/task_env for options) 30 | MODEL = "NoisyGRULatentL2" # Model to train (see configs/model for options) 31 | 32 | # ----------------- Parameter Selection ----------------------------------- 33 | SEARCH_SPACE = { 34 | "trainer.max_epochs": tune.choice([3000]), 35 | # 'datamodule_train.batch_size': tune.choice([1000]), 36 | # 'task_wrapper.weight_decay': tune.choice([1e-5]), 37 | "params.seed": tune.grid_search([0]), 38 | } 39 | 40 | # ------------------Data Management -------------------------------- 41 | combo_dict = generate_paths(RUN_DESC, TASK, MODEL) 42 | path_dict = combo_dict["path_dict"] 43 | RUN_TAG = combo_dict["RUN_TAG"] 44 | RUN_DIR = combo_dict["RUN_DIR"] 45 | config_dict = combo_dict["config_dict"] 46 | 47 | # -------------------Main Function---------------------------------- 48 | def main( 49 | run_tag_in: str, 50 | path_dict: dict, 51 | config_dict: dict, 52 | ): 53 | 54 | if RUN_DIR.exists() and OVERWRITE: 55 | shutil.rmtree(RUN_DIR) 56 | 57 | RUN_DIR.mkdir(parents=True) 58 | shutil.copyfile(__file__, RUN_DIR / Path(__file__).name) 59 | tune.run( 60 | tune.with_parameters( 61 | train, 62 | run_tag=run_tag_in, 63 | path_dict=path_dict, 64 | config_dict=config_dict, 65 | ), 66 | metric="loss", 67 | mode="min", 68 | config=SEARCH_SPACE, 69 | # resources_per_trial=dict(cpu=8, gpu=0.9), 70 | num_samples=1, 71 | storage_path=str(RUN_DIR), 72 | search_alg=BasicVariantGenerator(), 73 | scheduler=FIFOScheduler(), 74 | verbose=1, 75 | progress_reporter=CLIReporter( 76 | metric_columns=["loss", "training_iteration"], 77 | sort_by_metric=True, 78 | ), 79 |
trial_dirname_creator=trial_function, 80 | ) 81 | 82 | 83 | if __name__ == "__main__": 84 | main( 85 | run_tag_in=RUN_TAG, 86 | path_dict=path_dict, 87 | config_dict=config_dict, 88 | ) 89 | -------------------------------------------------------------------------------- /pretrained/.gitattributes: -------------------------------------------------------------------------------- 1 | 20240703_MultiTask_TrialLenFix/datamodule_sim.pkl filter=lfs diff=lfs merge=lfs -text 2 | 20240703_MultiTask_TrialLenFix/datamodule_train.pkl filter=lfs diff=lfs merge=lfs -text 3 | 20240703_MultiTask_TrialLenFix/model.pkl filter=lfs diff=lfs merge=lfs -text 4 | 20240703_MultiTask_TrialLenFix/simulator.pkl filter=lfs diff=lfs merge=lfs -text 5 | -------------------------------------------------------------------------------- /pretrained/20241017_NBFF_NoisyGRU_NewFinal/.gitattributes: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snel-repo/ComputationThroughDynamicsBenchmark/60caed2c6bb6e73ce3d05f99eda992f509e9ed94/pretrained/20241017_NBFF_NoisyGRU_NewFinal/.gitattributes -------------------------------------------------------------------------------- /pretrained/20241017_NBFF_NoisyGRU_NewFinal/datamodule_sim.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d985c4bdccbd7e5aeedde85833f12f824c426bc3a44086c087b4f5adab030978 3 | size 18031030 4 | -------------------------------------------------------------------------------- /pretrained/20241017_NBFF_NoisyGRU_NewFinal/datamodule_train.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7c1f76fda99c217d8e0cd17b8085aa55112f8d99d0858d5f226d827b4f2f761b 3 | size 18223748 4 | -------------------------------------------------------------------------------- /pretrained/20241017_NBFF_NoisyGRU_NewFinal/model.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8416f4c7b6477c52105e1fa0ece379ae60eeaac2d18aac0b719ee0924ce6ea13 3 | size 59762 4 | -------------------------------------------------------------------------------- /pretrained/20241017_NBFF_NoisyGRU_NewFinal/simulator.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a8674123b65bb2b29733ac563f5bd7baf3f7bac59706e9dc9f5dbead2d6d9a9e 3 | size 31885 4 | -------------------------------------------------------------------------------- /pretrained/20241113_MultiTask_NoisyGRU_Final2/datamodule_sim.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0af15ad46347b8c67225a909aca5095ecaafa3c47edf6319ae05c5162ec9430e 3 | size 413503210 4 | -------------------------------------------------------------------------------- /pretrained/20241113_MultiTask_NoisyGRU_Final2/datamodule_train.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:47929e26b9cc0463522a9e4459eb5db19310216339cac1ab6dee92ddeca868e5 3 | size 827826533 4 | -------------------------------------------------------------------------------- /pretrained/20241113_MultiTask_NoisyGRU_Final2/model.pkl: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:900a8fc4ab149a20d4dd333307e27495bafadcc03595fa56fb8194f59ea59020 3 | size 241791 4 | -------------------------------------------------------------------------------- /pretrained/20241113_MultiTask_NoisyGRU_Final2/simulator.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0b6ffc308ab224e09203313fd3a080c30f691e411132c6da7d1f655d0b939d4e 3 | size 62588 4 | -------------------------------------------------------------------------------- /pretrained/20241113_RandomTarget_NoisyGRU_Final2/datamodule_sim.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:f20775e0cdd09fe86b4b400b6baeaf72e64b34f3a04a6e65b2dc693be096e314 3 | size 6738868 4 | -------------------------------------------------------------------------------- /pretrained/20241113_RandomTarget_NoisyGRU_Final2/datamodule_train.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ab01dd23b551563358cca442cb0032dc1dec752c7523d0ac3e96e52c0424362e 3 | size 12890179 4 | -------------------------------------------------------------------------------- /pretrained/20241113_RandomTarget_NoisyGRU_Final2/model.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0d8d3fad457840aa8627c0286eecca14595cda862292737e5f6f0ce156558209 3 | size 253648 4 | -------------------------------------------------------------------------------- /pretrained/20241113_RandomTarget_NoisyGRU_Final2/simulator.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:9ad624df5f70f230f126cbc447768c17f6554536ba777032d83e056b388ae294 3 | size 62598 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchmetrics 3 | pytorch-lightning 4 | ray[tune] 5 | omegaconf 6 | wandb 7 | hydra-core 8 | gymnasium 9 | h5py 10 | scikit-learn 11 | matplotlib 12 | python-dotenv 13 | imageio[ffmpeg] 14 | motornet 15 | ipykernel 16 | ipywidgets 17 | opencv-python 18 | jax[cpu] 19 | #torchvision==0.14 20 | #torchdata==0.5.1 21 | #torchaudio==0.13.1 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | # Avoids duplication of requirements 4 | with open("requirements.txt") as file: 5 | requirements = file.read().splitlines() 6 | requirements.append("DSA @ git+https://github.com/mitchellostrow/DSA.git@main#egg=DSA") 7 | setup( 8 | name="ctd", 9 | version="1.0", 10 | install_requires=requirements, 11 | packages=find_packages(), 12 | ) 13 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from pathlib import Path 4 | 5 | 6 | def generate_paths(RUN_DESC: str, TASK: str, MODEL: str): 7 | # 
------------------Data Management -------------------------------- 8 | 9 | HOME_DIR = Path(os.environ.get("HOME_DIR")) 10 | paths = dict( 11 | tt_datasets=HOME_DIR / "content" / "datasets" / "tt", 12 | sim_datasets=HOME_DIR / "content" / "datasets" / "sim", 13 | dt_datasets=HOME_DIR / "content" / "datasets" / "dt", 14 | trained_models=HOME_DIR / "content" / "trained_models", 15 | ) 16 | for key, val in paths.items(): 17 | if not val.exists(): 18 | val.mkdir(parents=True) 19 | 20 | DATE_STR = datetime.now().strftime("%Y%m%d") 21 | RUN_TAG = f"{DATE_STR}_{RUN_DESC}" 22 | RUN_DIR = HOME_DIR / "content" / "runs" / "task-trained" / RUN_TAG 23 | 24 | # -----------------Default Parameter Sets ----------------------------------- 25 | configs = dict( 26 | task_wrapper=Path(f"configs/task_wrapper/{TASK}.yaml"), 27 | env_task=Path(f"configs/env_task/{TASK}.yaml"), 28 | env_sim=Path(f"configs/env_sim/{TASK}.yaml"), 29 | datamodule_task=Path(f"configs/datamodule_train/datamodule_{TASK}.yaml"), 30 | datamodule_sim=Path(f"configs/datamodule_sim/datamodule_{TASK}.yaml"), 31 | model=Path(f"configs/model/{MODEL}.yaml"), 32 | simulator=Path(f"configs/simulator/default_{TASK}.yaml"), 33 | callbacks=Path(f"configs/callbacks/default_{TASK}.yaml"), 34 | loggers=Path("configs/logger/default.yaml"), 35 | trainer=Path("configs/trainer/default.yaml"), 36 | ) 37 | output_dict = { 38 | "path_dict": paths, 39 | "RUN_TAG": RUN_TAG, 40 | "RUN_DIR": RUN_DIR, 41 | "config_dict": configs, 42 | } 43 | return output_dict 44 | 45 | 46 | def make_data_tag(dm_cfg): 47 | obs_dim = "" if "obs_dim" not in dm_cfg else dm_cfg.obs_dim 48 | obs_noise = "" if "obs_noise" not in dm_cfg else dm_cfg.obs_noise 49 | if "obs_noise_params" in dm_cfg: 50 | obs_noise_params = ",".join( 51 | [f"{k}={v}" for k, v in dm_cfg.obs_noise_params.items()] 52 | ) 53 | else: 54 | obs_noise_params = "" 55 | data_tag = ( 56 | f"{dm_cfg.system}{obs_dim}_" 57 | f"{dm_cfg.n_samples}S_" 58 | f"{dm_cfg.n_timesteps}T_" 59 | f"{dm_cfg.pts_per_period}P_" 60 | f"{dm_cfg.seed}seed" 61 | ) 62 | if obs_noise: 63 | data_tag += f"_{obs_noise}{obs_noise_params}" 64 | return data_tag 65 | 66 | 67 | def flatten(dictionary, level=[]): 68 | """Flattens a dictionary by placing '.' between levels. 69 | This function flattens a hierarchical dictionary by placing '.' 70 | between keys at various levels to create a single key for each 71 | value. It is used internally for converting the configuration 72 | dictionary to more convenient formats. Implementation was 73 | inspired by `this StackOverflow post 74 | `_. 75 | Parameters 76 | ---------- 77 | dictionary : dict 78 | The hierarchical dictionary to be flattened. 79 | level : list, optional 80 | The list of keys to prepend to keys at this level, 81 | enabling recursive calls. By default, an empty list. 82 | Returns 83 | ------- 84 | dict 85 | The flattened dictionary. 86 | See Also 87 | -------- 88 | lfads_tf2.utils.unflatten : Performs the opposite of this operation. 89 | """ 90 | 91 | tmp_dict = {} 92 | for key, val in dictionary.items(): 93 | if isinstance(val, dict): 94 | tmp_dict.update(flatten(val, level + [key])) 95 | else: 96 | tmp_dict[".".join(level + [key])] = val 97 | return tmp_dict 98 | 99 | 100 | def trial_function(trial): 101 | return trial.experiment_tag 102 | --------------------------------------------------------------------------------
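A quick usage sketch of flatten() above (hypothetical values): it produces the same dot-delimited key format used by the SEARCH_SPACE dictionaries in examples/run_task_training.py and examples/run_data_training.py.

    nested = {"model": {"latent_size": 64}, "trainer": {"max_epochs": 1000}}
    flatten(nested)
    # -> {"model.latent_size": 64, "trainer.max_epochs": 1000}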