├── requirements.txt
├── dunl
│   ├── postprocess_scripts
│   │   ├── neural_glm.sh
│   │   ├── save_data_for_pcanmf_local_speed.py
│   │   ├── learn_pcanmf_transforms.py
│   │   ├── save_data_for_pcanmf_dopamine_calcium_saramatias_uchida.py
│   │   ├── plot_kernels_dopamine_calcium_saramatias_uchida.py
│   │   ├── plot_kernels_whisker_thalamus.py
│   │   ├── plot_report_r2_local_deconv_on_test.py
│   │   ├── save_data_for_pcanmf_dopamine_spiking_eshel_uchida.py
│   │   ├── plot_expsetup_dopamine_calcium_saramatias_uchida.py
│   │   ├── plot_kernels_dopamine_spiking_eshel_uchida_hor.py
│   │   ├── plot_pcanmf_kernels_calcium.py
│   │   ├── plot_expsetup_dopamine_spiking_eshel_uchida.py
│   │   ├── plot_kernels_local_orthkernels_deconv_spiking_simulated_noisy.py
│   │   ├── plot_pcanmf_kernels_spiking.py
│   │   ├── plot_kernels_dopamine_spiking_eshel_uchida.py
│   │   └── plot_kernels_dopamine_spiking_eshel_uchida_vertical.py
│   ├── lossfunc.py
│   ├── datasetloader.py
│   └── preprocess_scripts
│       └── preprocess_data_dopamine_spiking_eshel_into_neuralgml_matlab.py
├── setup.py
├── LICENSE
├── .gitignore
├── README.md
└── config
    ├── instrcutions.yaml
    ├── whisker_glm_config.yaml
    ├── local_2kernelfornmf_simulated_config.yaml
    ├── local_orthkernels_simulated_config.yaml
    ├── local_orthkernels_simulated_config_for_noisy.yaml
    ├── local_deconv_simulated_config.yaml
    ├── dopamine_spiking_eshel_uchida_config.yaml
    ├── local_deconv_calscenario_simulated_config.yaml
    ├── local_deconv_calscenario_longtrial_simulated_config.yaml
    ├── local_deconv_calscenario_shorttrial_simulated_config.yaml
    ├── local_deconv_calscenario_shorttrial_structured_simulated_config.yaml
    ├── dopamine_spiking_eshel_uchida_limited_data_exp_config.yaml
    ├── whisker_simulated_config.yaml
    ├── dopamine_spiking_simulated_config.yaml
    ├── dopamine_spiking_eshel_uchida_code122_kernel011_config.yaml
    ├── dopamine_spiking_eshel_uchida_code122_kernel011_inferbaseline_config.yaml
    ├── dopamine_spiking_eshel_uchida_code122_kernel011_inferbaseline_independentkernelsamongneurons_config.yaml
    ├── whisker_config.yaml
    ├── dopamine_calcium_saramatias_uchida_config.yaml
    ├── dopamine_fiberphotometry_saramatias_uchida_config_1window_1kernel.yaml
    ├── dopamine_calcium_saramatias_uchida_inferbaseline_config.yaml
    ├── dopamine_calcium_saramatias_uchida_independentkernelsamongneurons_config.yaml
    ├── whisker_groupneuralfirings_config.yaml
    └── dopamine_calcium_saramatias_uchida_inferbaseline_independentkernelsamongneurons_config.yaml

-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
configmypy==0.1.0
h5py==3.10.0
hillfit==0.1.7
matplotlib==3.8.0
numpy==2.2.4
scikit_learn==1.3.2
scipy==1.15.2
tensorboardX==2.6.2.2
torch==2.1.2
tqdm==4.66.1

-------------------------------------------------------------------------------- /dunl/postprocess_scripts/neural_glm.sh: --------------------------------------------------------------------------------
#!/bin/bash

res_path_list="../results/local_raised_cosine_5_bases_res25ms ../results/local_raised_cosine_2_bases_res25ms ../results/local_raised_nonlin_cosine_5_bases_res25ms ../results/local_raised_nonlin_cosine_2_bases_res25ms"

for res_path in $res_path_list
do
    python postprocess_scripts/plot_neuralglm_dopamine_spiking_eshel_uchida_code.py \
        --res-path=$res_path

    python postprocess_scripts/plot_neuralglm_dopamine_spiking_eshel_uchida_base.py \
        --res-path=$res_path

    python postprocess_scripts/plot_neuralglm_dopamine_spiking_eshel_uchida_rec.py \
        --res-path=$res_path
done
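Assuming the four `../results/...` folders listed in `res_path_list` already exist from earlier runs, this sweep is launched from the `dunl` directory (see the README's PATH note) with `bash postprocess_scripts/neural_glm.sh`.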
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
from setuptools import setup

setup(
    name='dunl',
    version='0.1.0',
    packages=['dunl'],
    install_requires=[
        'configmypy',
        'h5py',
        'hillfit',
        'matplotlib>=3.4.0',
        'numpy==2.2.4',
        'scikit_learn',
        'scipy',
        'tensorboardX',
        'torch>=2.1.2',
        'torchvision>=0.6.2',
        'tqdm',
    ],
    author='Bahareh Tolooshams',
    author_email='btolooshams@gmail.com',
    description='DUNL for Computational Neuroscience (published at Neuron in 2025)',
    url='https://github.com/btolooshams/dunl-compneuro',
    classifiers=[
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: MIT License',
    ],
)
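Beyond the pip-from-git install mentioned in the README, a local editable install from the repository root (`pip install -e .`, a standard pip feature) is a convenient alternative during development.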
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 Bahareh Tolooshams

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------
# Latex related
**/*.aux
**/*.bbl
**/*.bib.bak
**/*.blg
**/*.cls.bak
**/*.dvi
**/*.log
**/*.nav
**/*.snm
**/*.out
**/*.synctex.gz
**/*.synctex.gz(busy)
**/*.tex.bak
**/*.tikz.bak
**/*.toc

# Jupyter notebooks
**/.ipynb_checkpoints

# Python
**/__pycache__
**/*.pyc

# Lyx backup files
**/*.lyx~

# Byte-compiled / optimized / DLL files
**/*.py[cod]
**/*.pytest_cache

# C extensions
**.so

# Unit test / coverage reports
**/htmlcov/
**/.tox/
**/.coverage
**/.coverage.*
**/.cache
**/nosetests.xml
**/coverage.xml
**/*,cover

# Translations
**/*.mo
**/*.pot

# Log files:
**/*.log

# Sphinx documentation
**/docs/_build/

# PyBuilder
**/target/

# DotEnv configuration
**/.env

# Database
**/*.db
**/*.rdb

# Pycharm
**/.idea

# VS Code
**/.vscode

# Spyder
**/.spyproject

# Mac OS-specific storage files
**/.DS_Store

# exclude model weights stored for training purposes
**/*.hdf5

# exclude data
**/data
**/results

# Deprecated code
**/old

-------------------------------------------------------------------------------- /dunl/lossfunc.py: --------------------------------------------------------------------------------
"""
Copyright (c) 2025 Bahareh Tolooshams

loss functions for training

:author: Bahareh Tolooshams
"""

import torch


class DUNL1DLoss(torch.nn.Module):
    def __init__(self, model_distribution):
        super(DUNL1DLoss, self).__init__()

        self.model_distribution = model_distribution

    def forward(self, y, Hxa):
        if self.model_distribution == "gaussian":
            loss = torch.nn.functional.mse_loss(y, Hxa, reduction="none")
        elif self.model_distribution == "binomial":
            loss = -torch.mean(y * Hxa, dim=-1) + torch.mean(
                torch.log1p(torch.exp(Hxa)), dim=-1
            )
        elif self.model_distribution == "poisson":
            loss = -torch.mean(y * Hxa, dim=-1) + torch.mean(torch.exp(Hxa), dim=-1)
        else:
            # guard against an unsupported distribution leaving `loss` undefined
            raise ValueError(
                "unknown model_distribution: {}".format(self.model_distribution)
            )

        return torch.mean(loss)


class Smoothloss(torch.nn.Module):
    def __init__(self):
        super(Smoothloss, self).__init__()

    def forward(self, H):
        loss = (H[:, :, 1:] - H[:, :, :-1]).pow(2).sum() / H.shape[0]
        return loss


class l1loss(torch.nn.Module):
    def __init__(self):
        super(l1loss, self).__init__()

    def forward(self, x):
        loss = torch.mean(torch.abs(x))
        return loss

-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/ambv/black)


# Deconvolutional Unrolled Neural Learning (DUNL) for Computational Neuroscience

This code accompanies the paper [https://www.cell.com/neuron/abstract/S0896-6273(25)00119-9](https://www.cell.com/neuron/abstract/S0896-6273(25)00119-9), published in Neuron.

Learning locally low-rank temporal representation from neural time series data.

### Usage

To clone: `git clone https://github.com/btolooshams/dunl-compneuro.git`

To install via pip: `pip install git+https://github.com/btolooshams/dunl-compneuro.git`

### PATH

To run any script, make sure you are in the `dunl` directory.

### Configuration

Check `config` for the detailed parameters of each experiment.

You should provide all detailed information about the model as a YAML file.

See `instrcutions.yaml` for information on the important parameters.

### Data generation

Go to `dunl/preprocess_scripts`. Create a script similar to `generate_simulated_data_dopamine_spiking_into_dataformat.py`.

### Data preparation

Go to `dunl/preprocess_scripts`. Create a data dictionary whose format is similar to those in the preprocess files, e.g., `preprocess_data_whisker_into_dataformat.pt`.

Run `prepare_data_and_save.py`, which takes a data folder with numpy files created in the last step (e.g., `...general_format_processed.npy`) and creates a `..._trainready.pt` data file.

The key module to load data is `DUNLdataset` in `dunl/datasetloader.py`. See the module for more info on the data.

### Tensorboard

`dunl/boardfunc.py` contains preliminary functions that are used during training to report training progress onto a board.

### Training

See `dunl/train_sharekernels_acrossneurons.py` for an example training script.

See `dunl/train_sharekernels_acrossneurons_groupneuralfirings.py` for using group sparsity across neurons.
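As a quick orientation, the sketch below shows the dictionary layout that `DUNLdataset` in `dunl/datasetloader.py` expects inside a `*_trainready.pt` file. The keys and shapes are taken from that module; the sizes and output path here are made-up placeholders.

```python
import torch

num_trials, num_neurons, trial_length = 100, 10, 200
kernel_length, num_kernels = 25, 2

# placeholder tensors with the shapes documented in dunl/datasetloader.py
data = {
    "y": torch.zeros(num_trials, num_neurons, trial_length),  # recorded activity
    "x": torch.zeros(num_trials, num_kernels, trial_length - kernel_length + 1),  # event onsets
    "a": torch.zeros(num_trials, num_neurons, 1),  # baseline activity
    "type": torch.zeros(num_trials),  # trial type
}
torch.save(data, "../data/example_trainready.pt")
```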
-------------------------------------------------------------------------------- /config/instrcutions.yaml: --------------------------------------------------------------------------------
# This shows the important variables that you want to change within each config
default: &DEFAULT
  data_path: ["../data/datafile_xxx.pt"] # give it a list of datasets
  #################################
  model_distribution: "binomial" # for spiking use binomial, for calcium use Gaussian
  share_kernels_among_neurons: True # set true to share kernels among neurons
  #################################
  # kernel (dictionary)
  # important for kernel
  kernel_nonneg: False # True: project kernels into non-negative values
  kernel_nonneg_indicator: [0] # 0 for +-, 1 for +
  kernel_num: 1 # number of kernels to learn
  kernel_length: 25 # number of samples for kernel in time
  kernel_smoother: True # flag to apply smoother to the kernel during training
  kernel_smoother_penalty_weight: 0.003 # this is easy to tune (set to a small value); make kernel_smoother False if you have a lot of data
  #################################
  # code (representation)
  code_nonneg: [1] # apply sign constraint on the code. 1 for pos, -1 for neg, 2 for twosided
  code_sparse_regularization: 0.01 # apply sparse (lambda l1-norm) regularization on the code
  code_group_neural_firings_regularization: 0.05 # if > 0, then it applies grouping across neurons
  # if you don't have the event onsets, then code_supp would be off
  code_supp: False # True: apply known event indices (supp) into code x
  code_topk: True # True: keep only top k indices in each kernel code non-zero (this is greedy)
  code_topk_sparse: 18 # number of top k non-zero entries in each code kernel
  code_topk_period: 10 # period on encoder iteration to apply topk
  code_l1loss_bp_penalty_weight: 0.01 # suggested to keep this the same as code_sparse_regularization
  #################################
  est_baseline_activity: True # set True if you also want to estimate the baseline activity
  #################################
  # unrolling parameters
  unrolling_num: 200 # if you want highly sparse codes, increase this. Recommended to be between 100 and 1000.
  unrolling_alpha: 0.5 # make sure that this is lower than 1 and small enough that the network does not blow up
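The training and postprocess scripts read these YAML files through `configmypy`. A minimal sketch of that loading pattern, mirroring the pipeline used verbatim in `dunl/postprocess_scripts/save_data_for_pcanmf_local_speed.py` (the config filename here is the instructions file itself, just for illustration):

```python
import configmypy

# read the "default" section of a config in ../config, allowing
# command-line overrides; this mirrors the scripts in this repo
pipe = configmypy.ConfigPipeline(
    [
        configmypy.YamlConfig(
            "./instrcutions.yaml", config_name="default", config_folder="../config"
        ),
        configmypy.ArgparseConfig(infer_types=True, config_name=None, config_file=None),
        configmypy.YamlConfig(config_folder="../config"),
    ]
)
params = pipe.read_conf()
print(params["kernel_num"], params["unrolling_num"])
```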
-------------------------------------------------------------------------------- /config/whisker_glm_config.yaml: --------------------------------------------------------------------------------
default: &DEFAULT
  exp_name: "whisker_glm_05msbinres"
  data_path: ["../data/whisker/whisker_train_5msbinres_general_format_processed_kernellength25_kernelnum1_trainready.pt"] # give it a list of datasets
  data_folder: None # this will look for data in format *trainready.pt

  test_data_path: ["../data/whisker/whisker_test_5msbinres_general_format_processed_kernellength25_kernelnum1_trainready.pt"] # give it a list of datasets
  # data_path: ["../data/whisker/whisker_train_10msbinres_general_format_processed_kernellength12_kernelnum1_trainready.pt"] # give it a list of datasets
  # test_data_path: ["../data/whisker/whisker_test_10msbinres_general_format_processed_kernellength12_kernelnum1_trainready.pt"] # give it a list of datasets
  #################################
  model_distribution: "binomial" # data distribution: gaussian, binomial, poisson
  share_kernels_among_neurons: True # set true to share kernels among neurons
  #################################
  # kernel (dictionary)
  kernel_normalize: True # True: l2-norm of kernels is set to one after each update
  kernel_nonneg: False # True: project kernels into non-negative values
  kernel_nonneg_indicator: [0] # 0 for +-, 1 for +
  kernel_num: 1 # number of kernels to learn
  kernel_length: 25 # number of samples for kernel in time
  kernel_stride: 1 # default 1, convolution stride
  kernel_init_smoother: False # flag to init kernels to be smooth
  kernel_init_smoother_sigma: 0 # sigma of the gaussian kernel for kernel_init_smoother
  kernel_smoother: False # flag to apply smoother to the kernel during training
  kernel_smoother_penalty_weight: 0 # penalty weight to apply for kernel smoother
  kernel_initialization: "../data/whisker/kernel_init_25.pt" # None, or a data path
  kernel_initialization_needs_adjustment_of_time_bin_resolution: False
  #################################
  # code (representation)
  code_nonneg: [1] # apply sign constraint on the code. 1 for pos, -1 for neg, 2 for twosided
  code_sparse_regularization: 0 # apply sparse (lambda l1-norm) regularization on the code
  code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration
  code_group_neural_firings_regularization: 0 # if > 0, then it applies grouping across neurons
  code_q_regularization: False # set True to apply Q-regularization on the norm of the code
  code_q_regularization_matrix: None # The matrix of relations between the codes (if flag is True, use the path to load)
  code_q_regularization_matrix_path: None
  code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations
  code_q_regularization_scale: 5 # scale factor in front of the Q-regularization term
  code_q_regularization_norm_type: 2 # Set to the norm number you want the Q-regularization to be applied
  code_supp: True # True: apply known event indices (supp) into code x
  code_topk: False # True: keep only top k indices in each kernel code non-zero (this is greedy)
  code_topk_sparse: 16 # number of top k non-zero entries in each code kernel
  code_topk_period: 10 # period on encoder iteration to apply topk
  code_l1loss_bp: True # True: to include l1-norm of the code in the loss during training
  code_l1loss_bp_penalty_weight: 0 # amount of sparse regularization of the code with bp during training
  #################################
  est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder
  poisson_stability_name: None # type of non-linearity to use on poisson case for encoder stability
  poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value
  #################################
  # unrolling parameters
  unrolling_num: 2000 # number of unrolling iterations in the encoder
  unrolling_mode: "fista" # ista or fista encoder
  unrolling_alpha: 1.0 # alpha step size in unrolling
  unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold)
  unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold"
-------------------------------------------------------------------------------- /dunl/postprocess_scripts/save_data_for_pcanmf_local_speed.py: --------------------------------------------------------------------------------
"""
Copyright (c) 2025 Bahareh Tolooshams

save data for pca/nmf

:author: Bahareh Tolooshams
"""

import numpy as np
import torch
import configmypy
import os
import argparse

import sys

sys.path.append("../dunl/")

import datasetloader


def init_params():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--config-folder",
        type=str,
        help="config folder",
        default="../config",
    )
    parser.add_argument(
        "--config-filename",
        type=str,
        help="config filename",
        default="./local_speed_simulated_config.yaml",
    )
    parser.add_argument(
        "--num-trials-list",
        type=int,
        nargs="+",
        help="number of trials list",  # see main
        default=[250, 50, 100, 25, 500, 750, 1000],
    )
    args = parser.parse_args()
    params = vars(args)

    return params


def main(params):
    print("Save data for PCA/NMF for the local speed simulated experiment.")

    params["window_dur"] = 16

    # create dataset and dataloaders ----------------------------------------#
    if params["data_path"] == "":
        data_folder = params["data_folder"]
        filename_list = os.listdir(data_folder)
        data_path_list = [
            f"{data_folder}/{x}" for x in filename_list if "trainready.pt" in x
        ]
    else:
        data_path_list = params["data_path"]

    print("There are {} datasets in the folder.".format(len(data_path_list)))

    data_path_cur = data_path_list[0]
    print(data_path_cur)
    train_dataset = datasetloader.DUNLdataset(data_path_cur)

    y = train_dataset.y
    x = train_dataset.x

    num_trials = y.shape[0]
    num_neurons = y.shape[1]
    num_kernels = x.shape[1]

    yavg = list()
    rew_amount = list()

    # go over all trials
    for i in range(num_trials):
        xi = x[i]
        yi = y[i]

        for kernel_ctr in range(num_kernels):
            # indices of event onsets for this kernel
            onset = np.where(xi[kernel_ctr] > 0)[-1]
            for on in onset:
                y_curr = yi[:, on : on + params["window_dur"]]
                yavg.append(y_curr)

    yavg = torch.stack(yavg, dim=0)

    print("yavg", yavg.shape)

    np.save(
        os.path.join(
            "../data/local-speed-simulated",
            "y_for_pcanmf_numtrials{}.npy".format(num_trials),
        ),
        yavg,
    )


if __name__ == "__main__":
    # init parameters -------------------------------------------------------#
    print("init parameters.")
    params_init = init_params()

    pipe = configmypy.ConfigPipeline(
        [
            configmypy.YamlConfig(
                params_init["config_filename"],
                config_name="default",
                config_folder=params_init["config_folder"],
            ),
            configmypy.ArgparseConfig(
                infer_types=True, config_name=None, config_file=None
            ),
            configmypy.YamlConfig(config_folder=params_init["config_folder"]),
        ]
    )
    params = pipe.read_conf()
    params["config_folder"] = params_init["config_folder"]
    params["config_filename"] = params_init["config_filename"]

    for num_trials in params_init["num_trials_list"]:
        print("num_trials", num_trials)
        params["num_trials"] = num_trials

        data_path_name = f"../data/local-speed-simulated/simulated_100neurons_{num_trials}trials_25msbinres_14Hzbaseline_long_general_format_processed_kernellength16_kernelnum2_trainready.pt"

        params["data_path"] = [data_path_name]

        main(params)
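For context on how these two scripts relate: the `save_data_for_pcanmf_*` family extracts fixed-length windows around event onsets and writes them as `y_for_pcanmf_*.npy` arrays (this local-speed variant writes under `../data/local-speed-simulated`, while other variants write into a result folder's `postprocess` directory), and `learn_pcanmf_transforms.py` below then fits a `StandardScaler`, PCA, and NMF on such arrays and pickles the fitted transforms into the result folder's `postprocess` directory.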
-------------------------------------------------------------------------------- /dunl/postprocess_scripts/learn_pcanmf_transforms.py: --------------------------------------------------------------------------------
"""
Copyright (c) 2025 Bahareh Tolooshams

learn pca/nmf transforms

:author: Bahareh Tolooshams
"""

import torch
import numpy as np
import os
import pickle
import argparse
from sklearn.decomposition import PCA, NMF
from sklearn.preprocessing import StandardScaler


def init_params():
    parser = argparse.ArgumentParser(description=__doc__)

    parser.add_argument(
        "--res-path",
        type=str,
        help="res path",
        default="../results/dopaminecalcium_kernellength60_kernelnum5_code2211n1_kernel00011_qreg_2023_07_13_11_37_31",
    )
    parser.add_argument(
        "--num-comp",
        type=int,
        help="number of components",
        default=2,
    )
    parser.add_argument(
        "--max-iter",
        type=int,
        help="max iter for nmf",
        default=1000,
    )

    args = parser.parse_args()
    params = vars(args)

    return params


def main():
    print("Learn PCA/NMF transforms.")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("device is", device)

    # init parameters -------------------------------------------------------#
    print("init parameters.")
    params_init = init_params()

    # take parameters from the result path
    params = pickle.load(
        open(os.path.join(params_init["res_path"], "params.pickle"), "rb")
    )
    for key in params_init.keys():
        params[key] = params_init[key]

    if params["data_path"] == "":
        data_folder = params["data_folder"]
        filename_list = os.listdir(data_folder)
        data_path_list = [
            f"{data_folder}/{x}" for x in filename_list if "trainready.pt" in x
        ]
    else:
        data_path_list = params["data_path"]

    print("There are {} datasets in the folder.".format(len(data_path_list)))

    # create folders -------------------------------------------------------#
    postprocess_path = os.path.join(
        params["res_path"],
        "postprocess",
    )

    # load data -------------------------------------------------------#
    y_all = list()
    label_all = list()

    for data_path in data_path_list:
        datafile_name = data_path.split("/")[-1].split(".pt")[0]

        # (neuron, time, trials)
        y = np.load(
            os.path.join(postprocess_path, "y_for_pcanmf_{}.npy".format(datafile_name))
        )
        label = np.load(
            os.path.join(
                postprocess_path, "label_for_pcanmf_{}.npy".format(datafile_name)
            )
        )

        y = np.transpose(y, (0, 2, 1))
        y = np.reshape(y, (-1, y.shape[-1]))

        y_all.append(y)
        label_all.append(label)

    y_all = np.concatenate(y_all, axis=0)
    label_all = np.concatenate(label_all, axis=0)

    # do transform -------------------------------------------------------#
    scaler = StandardScaler()
    y_all_standardized = scaler.fit_transform(y_all)

    print("y_all", y_all.shape)
    print("y_standardized", y_all_standardized.shape)

    pca_transform = PCA(n_components=params["num_comp"])
    y_pca_coeff = pca_transform.fit_transform(y_all_standardized)
    nmf_transform = NMF(n_components=params["num_comp"], max_iter=params["max_iter"])
    y_nmf_coeff = nmf_transform.fit_transform(y_all - np.min(y_all))

    pickle.dump(
        pca_transform,
        open(
            os.path.join(
                postprocess_path, "pca_transform_{}.pkl".format(params["num_comp"])
            ),
            "wb",
        ),
    )
    pickle.dump(
        nmf_transform,
        open(
            os.path.join(
                postprocess_path, "nmf_transform_{}.pkl".format(params["num_comp"])
            ),
            "wb",
        ),
    )
    pickle.dump(
        scaler, open(os.path.join(postprocess_path, "scaler_transform.pkl"), "wb")
    )


if __name__ == "__main__":
    main()
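Once fitted, the pickled transforms can be re-applied to new windows. A minimal sketch, assuming the same `postprocess` folder and `num_comp` used above (the result path and the placeholder matrix are hypothetical; file names match what the script saves):

```python
import os
import pickle

import numpy as np

postprocess_path = "../results/<res_path>/postprocess"  # hypothetical result folder
num_comp = 2

scaler = pickle.load(open(os.path.join(postprocess_path, "scaler_transform.pkl"), "rb"))
pca = pickle.load(open(os.path.join(postprocess_path, f"pca_transform_{num_comp}.pkl"), "rb"))
nmf = pickle.load(open(os.path.join(postprocess_path, f"nmf_transform_{num_comp}.pkl"), "rb"))

# placeholder (windows, time) matrix; the time dimension must match the fit
y_new = np.random.rand(50, 60)
y_pca = pca.transform(scaler.transform(y_new))
y_nmf = nmf.transform(y_new - np.min(y_new))  # NMF was fit on min-shifted data
print(y_pca.shape, y_nmf.shape)
```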
-------------------------------------------------------------------------------- /dunl/datasetloader.py: --------------------------------------------------------------------------------
"""
Copyright (c) 2025 Bahareh Tolooshams

create data generator

:author: Bahareh Tolooshams
"""

import torch


class DUNLdataset(torch.utils.data.Dataset):
    def __init__(self, data_path):
        data = torch.load(data_path)

        self.data_path = data_path

        self.y = data["y"]  # recorded data dim(num_trials, num_neurons, trial_length)
        # event onsets dim(num_trials, num_kernels, trial_length - kernel_length + 1)
        self.x = data["x"]
        self.a = data["a"]  # baseline dim(num_trials, num_neurons, 1)
        self.type = data["type"]  # trial type dim(num_trials)
        self.num_data = self.y.shape[0]  # number of trials

        print(
            "x is shared among neurons. It is a function of trials, and number of kernels!"
        )

    def __len__(self):
        return self.num_data

    def __getitem__(self, idx):
        return (
            self.y[idx],
            self.x[idx],
            self.a[idx],
            self.type[idx],
        )


class DUNLdatasetwithRasterNoRate(torch.utils.data.Dataset):
    def __init__(self, data_path):
        data = torch.load(data_path)

        self.data_path = data_path

        self.time_org_resolution = data["time_org_resolution"]
        self.time_bin_resolution = data["time_bin_resolution"]

        self.raster = data["raster"]
        self.y = data["y"]  # recorded data dim(num_trials, num_neurons, trial_length)
        # event onsets dim(num_trials, num_kernels, trial_length - kernel_length + 1)
        self.x = data["x"]
        self.a = data["a"]  # baseline dim(num_trials, num_neurons, 1)
        self.type = data["type"]  # trial type dim(num_trials)
        self.num_data = self.y.shape[0]  # number of trials

    def __len__(self):
        return self.num_data

    def __getitem__(self, idx):
        return (
            self.y[idx],
            self.x[idx],
            self.a[idx],
            self.type[idx],
            self.raster[idx],
        )


class DUNLdatasetwithRaster(torch.utils.data.Dataset):
    def __init__(self, data_path):
        data = torch.load(data_path)

        self.data_path = data_path

        self.time_org_resolution = data["time_org_resolution"]
        self.time_bin_resolution = data["time_bin_resolution"]

        self.raster = data["raster"]
        self.y = data["y"]  # recorded data dim(num_trials, num_neurons, trial_length)
        # event onsets dim(num_trials, num_kernels, trial_length - kernel_length + 1)
        self.x = data["x"]
        self.a = data["a"]  # baseline dim(num_trials, num_neurons, 1)
        self.type = data["type"]  # trial type dim(num_trials)
        self.num_data = self.y.shape[0]  # number of trials

        print(
            "x is shared among neurons. It is a function of trials, and number of kernels!"
        )

    def __len__(self):
        return self.num_data

    def __getitem__(self, idx):
        return (
            self.y[idx],
            self.x[idx],
            self.a[idx],
            self.type[idx],
            self.raster[idx],
        )


class DUNLdatasetwithRasterWithCodeRate(torch.utils.data.Dataset):
    def __init__(self, data_path):
        data = torch.load(data_path)

        self.data_path = data_path

        self.time_org_resolution = data["time_org_resolution"]
        self.time_bin_resolution = data["time_bin_resolution"]

        self.raster = data["raster"]
        self.y = data["y"]  # recorded data dim(num_trials, num_neurons, trial_length)
        # event onsets dim(num_trials, num_kernels, trial_length - kernel_length + 1)
        self.x = data["x"]
        self.a = data["a"]  # baseline dim(num_trials, num_neurons, 1)
        self.type = data["type"]  # trial type dim(num_trials)
        self.num_data = self.y.shape[0]  # number of trials
        self.codes = data["codes"]
        self.rate = data["rate"]

        print(
            "x is shared among neurons. It is a function of trials, and number of kernels!"
        )

    def __len__(self):
        return self.num_data

    def __getitem__(self, idx):
        return (
            self.y[idx],
            self.x[idx],
            self.a[idx],
            self.type[idx],
            self.raster[idx],
            self.codes[idx],
            self.rate[idx],
        )
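A minimal sketch of loading one of these datasets and batching it for training (the path is a placeholder; the unpacking order follows `DUNLdataset.__getitem__` above, and the import works when run from the `dunl` directory per the README):

```python
import torch

import datasetloader

dataset = datasetloader.DUNLdataset("../data/<experiment>/<name>_trainready.pt")
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

for y, x, a, trial_type in loader:
    # y: (batch, num_neurons, trial_length); x: (batch, num_kernels, ...)
    print(y.shape, x.shape, a.shape, trial_type.shape)
    break
```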
-------------------------------------------------------------------------------- /config/local_2kernelfornmf_simulated_config.yaml: --------------------------------------------------------------------------------
default: &DEFAULT
  exp_name: None
  data_path: None # this must be a list of data_path, set to "" if using data_folder as path.
  test_data_path: None # give it a list of datasets
  #################################
  model_distribution: "binomial" # data distribution: gaussian, binomial, poisson
  share_kernels_among_neurons: True # set true to share kernels among neurons
  #################################
  # kernel (dictionary)
  kernel_normalize: True # True: l2-norm of kernels is set to one after each update
  kernel_nonneg: False # True: project kernels into non-negative values
  kernel_nonneg_indicator: None # 0 for +-, 1 for +
  kernel_num: 3 # number of kernels to learn
  kernel_length: 80 # number of samples for kernel in time
  kernel_stride: 1 # default 1, convolution stride
  kernel_init_smoother: False # flag to init kernels to be smooth
  kernel_init_smoother_sigma: 0.1 # sigma of the gaussian kernel for kernel_init_smoother
  kernel_smoother: True # flag to apply smoother to the kernel during training
  kernel_smoother_penalty_weight: 0.01 # penalty weight to apply for kernel smoother
  # kernel_initialization: "../data/local-deconv-simulated/kernels_sin_and_bump.pt" # None, or a data path
  kernel_initialization: None # None, or a data path
  kernel_initialization_needs_adjustment_of_time_bin_resolution: False
  #################################
  # code (representation)
  code_nonneg: [1, 1, 1] # apply sign constraint on the code. 1 for pos, -1 for neg, 2 for twosided
  code_sparse_regularization: 0.005 # apply sparse (lambda l1-norm) regularization on the code
  code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration
  code_group_neural_firings_regularization: 0.025 # if > 0, then it applies grouping across neurons
  code_q_regularization: False # set True to apply Q-regularization on the norm of the code
  code_q_regularization_matrix: None # The matrix of relations between the codes (use the path to load)
  code_q_regularization_matrix_path: None
  code_q_regularization_period: 10 # the period to apply Q-regularization in encoder iterations
  code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term
  code_q_regularization_norm_type: None # Set to the norm number you want the Q-regularization to be applied
  code_supp: False # True: apply known event indices (supp) into code x
  code_topk: True # True: keep only top k indices in each kernel code non-zero (this is greedy)
  code_topk_sparse: 1 # number of top k non-zero entries in each code kernel
  code_topk_period: 10 # period on encoder iteration to apply topk
  code_l1loss_bp: True # True: to include l1-norm of the code in the loss during training
  code_l1loss_bp_penalty_weight: 0.005 # amount of sparse regularization of the code with bp during training
  #################################
  est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder
  poisson_stability_name: None # type of non-linearity to use on poisson case for encoder stability
  poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value
  #################################
  # unrolling parameters
  unrolling_num: 20 # number of unrolling iterations in the encoder
  unrolling_mode: "fista" # ista or fista encoder
  unrolling_alpha: 0.25 # alpha step size in unrolling; for known events it was 0.5
  unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold)
  unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold"
  #################################
  # training related
  # default optimizer is ADAM.
  optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels)
  optimizer_lr_step: 1000 # number of steps (updates) after which the lr will decay
  optimizer_lr_decay: 1 # decay factor for learning rate
  optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer
  optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer
  #
  backward_gradient_decsent: "truncated_bprop" # type of backward gradient update (bprop, truncated_bprop)
  backward_truncated_bprop_itr: 5 # must be set for truncated_bprop
  #
  train_num_steps: 1000 # number of steps for training
  train_data_shuffle: True # True: to shuffle dataset at every epoch for training
  train_batch_size: 10 # batch size for training
  train_num_workers: 4 # number of workers to load data
  train_val_split: 0.85 # 1: use all for train. percentage of data used to train; the rest is used for validation.
  train_val_max: 10
  #
  val_period: 5
  enable_board: True
  tqdm_prints_disable: False # True: to disable prints of epoch training process
  tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process
  log_fig_epoch_period: 10

-------------------------------------------------------------------------------- /config/local_orthkernels_simulated_config.yaml: --------------------------------------------------------------------------------
default: &DEFAULT
  exp_name: None
  data_path: None # this must be a list of data_path, set to "" if using data_folder as path.
  test_data_path: None # give it a list of datasets
  #################################
  model_distribution: "binomial" # data distribution: gaussian, binomial, poisson
  share_kernels_among_neurons: True # set true to share kernels among neurons
  #################################
  # kernel (dictionary)
  kernel_normalize: True # True: l2-norm of kernels is set to one after each update
  kernel_nonneg: False # True: project kernels into non-negative values
  kernel_nonneg_indicator: None # 0 for +-, 1 for +
  kernel_num: 5 # number of kernels to learn
  kernel_length: 80 # number of samples for kernel in time
  kernel_stride: 1 # default 1, convolution stride
  kernel_init_smoother: False # flag to init kernels to be smooth
  kernel_init_smoother_sigma: 0.1 # sigma of the gaussian kernel for kernel_init_smoother
  kernel_smoother: True # flag to apply smoother to the kernel during training
  kernel_smoother_penalty_weight: 0.01 # penalty weight to apply for kernel smoother
  # kernel_initialization: "../data/local-deconv-simulated/kernels_sin_and_bump.pt" # None, or a data path
  kernel_initialization: None # None, or a data path
  kernel_initialization_needs_adjustment_of_time_bin_resolution: False
  #################################
  # code (representation)
  code_nonneg: [1, 1, 1, 1, 1] # apply sign constraint on the code. 1 for pos, -1 for neg, 2 for twosided
  code_sparse_regularization: 0.005 # apply sparse (lambda l1-norm) regularization on the code
  code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration
  code_group_neural_firings_regularization: 0.05 # if > 0, then it applies grouping across neurons
  code_q_regularization: False # set True to apply Q-regularization on the norm of the code
  code_q_regularization_matrix: None # The matrix of relations between the codes (use the path to load)
  code_q_regularization_matrix_path: None
  code_q_regularization_period: 10 # the period to apply Q-regularization in encoder iterations
  code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term
  code_q_regularization_norm_type: None # Set to the norm number you want the Q-regularization to be applied
  code_supp: False # True: apply known event indices (supp) into code x
  code_topk: True # True: keep only top k indices in each kernel code non-zero (this is greedy)
  code_topk_sparse: 1 # number of top k non-zero entries in each code kernel
  code_topk_period: 10 # period on encoder iteration to apply topk
  code_l1loss_bp: True # True: to include l1-norm of the code in the loss during training
  code_l1loss_bp_penalty_weight: 0.005 # amount of sparse regularization of the code with bp during training
  #################################
  est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder
  poisson_stability_name: None # type of non-linearity to use on poisson case for encoder stability
  poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value
  #################################
  # unrolling parameters
  unrolling_num: 20 # number of unrolling iterations in the encoder
  unrolling_mode: "fista" # ista or fista encoder
  unrolling_alpha: 0.25 # alpha step size in unrolling; for known events it was 0.5
  unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold)
  unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold"
  #################################
  # training related
  # default optimizer is ADAM.
  optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels)
  optimizer_lr_step: 1000 # number of steps (updates) after which the lr will decay
  optimizer_lr_decay: 1 # decay factor for learning rate
  optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer
  optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer
  #
  backward_gradient_decsent: "truncated_bprop" # type of backward gradient update (bprop, truncated_bprop)
  backward_truncated_bprop_itr: 5 # must be set for truncated_bprop
  #
  train_num_steps: 1000 # number of steps for training
  train_data_shuffle: True # True: to shuffle dataset at every epoch for training
  train_batch_size: 5 # batch size for training
  train_num_workers: 4 # number of workers to load data
  train_val_split: 0.85 # 1: use all for train. percentage of data used to train; the rest is used for validation.
  train_val_max: 10
  #
  val_period: 5
  enable_board: True
  tqdm_prints_disable: False # True: to disable prints of epoch training process
  tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process
  log_fig_epoch_period: 20

-------------------------------------------------------------------------------- /config/local_orthkernels_simulated_config_for_noisy.yaml: --------------------------------------------------------------------------------
default: &DEFAULT
  exp_name: None
  data_path: None # this must be a list of data_path, set to "" if using data_folder as path.
  test_data_path: None # give it a list of datasets
  #################################
  model_distribution: "binomial" # data distribution: gaussian, binomial, poisson
  share_kernels_among_neurons: True # set true to share kernels among neurons
  #################################
  # kernel (dictionary)
  kernel_normalize: True # True: l2-norm of kernels is set to one after each update
  kernel_nonneg: False # True: project kernels into non-negative values
  kernel_nonneg_indicator: None # 0 for +-, 1 for +
  kernel_num: 6 # number of kernels to learn
  kernel_length: 84 # number of samples for kernel in time
  kernel_stride: 1 # default 1, convolution stride
  kernel_init_smoother: False # flag to init kernels to be smooth
  kernel_init_smoother_sigma: 0.1 # sigma of the gaussian kernel for kernel_init_smoother
  kernel_smoother: True # flag to apply smoother to the kernel during training
  kernel_smoother_penalty_weight: 0.01 # penalty weight to apply for kernel smoother
  # kernel_initialization: "../data/local-deconv-simulated/kernels_sin_and_bump.pt" # None, or a data path
  kernel_initialization: None # None, or a data path
  kernel_initialization_needs_adjustment_of_time_bin_resolution: False
  #################################
  # code (representation)
  code_nonneg: [1, 1, 1, 1, 1, 1] # apply sign constraint on the code. 1 for pos, -1 for neg, 2 for twosided
  code_sparse_regularization: 0.005 # apply sparse (lambda l1-norm) regularization on the code
  code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration
  code_group_neural_firings_regularization: 0.05 # if > 0, then it applies grouping across neurons
  code_q_regularization: False # set True to apply Q-regularization on the norm of the code
  code_q_regularization_matrix: None # The matrix of relations between the codes (use the path to load)
  code_q_regularization_matrix_path: None
  code_q_regularization_period: 10 # the period to apply Q-regularization in encoder iterations
  code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term
  code_q_regularization_norm_type: None # Set to the norm number you want the Q-regularization to be applied
  code_supp: False # True: apply known event indices (supp) into code x
  code_topk: True # True: keep only top k indices in each kernel code non-zero (this is greedy)
  code_topk_sparse: 1 # number of top k non-zero entries in each code kernel
  code_topk_period: 10 # period on encoder iteration to apply topk
  code_l1loss_bp: True # True: to include l1-norm of the code in the loss during training
  code_l1loss_bp_penalty_weight: 0.005 # amount of sparse regularization of the code with bp during training
  #################################
  est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder
  poisson_stability_name: None # type of non-linearity to use on poisson case for encoder stability
  poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value
  #################################
  # unrolling parameters
  unrolling_num: 20 # number of unrolling iterations in the encoder
  unrolling_mode: "fista" # ista or fista encoder
  unrolling_alpha: 0.25 # alpha step size in unrolling; for known events it was 0.5
  unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold)
  unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold"
  #################################
  # training related
  # default optimizer is ADAM.
  optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels)
  optimizer_lr_step: 1000 # number of steps (updates) after which the lr will decay
  optimizer_lr_decay: 1 # decay factor for learning rate
  optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer
  optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer
  #
  backward_gradient_decsent: "truncated_bprop" # type of backward gradient update (bprop, truncated_bprop)
  backward_truncated_bprop_itr: 5 # must be set for truncated_bprop
  #
  train_num_steps: 700 # number of steps for training
  train_data_shuffle: True # True: to shuffle dataset at every epoch for training
  train_batch_size: 5 # batch size for training
  train_num_workers: 4 # number of workers to load data
  train_val_split: 0.85 # 1: use all for train. percentage of data used to train; the rest is used for validation.
  train_val_max: 10
  #
  val_period: 5
  enable_board: True
  tqdm_prints_disable: False # True: to disable prints of epoch training process
  tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process
  log_fig_epoch_period: 50

-------------------------------------------------------------------------------- /config/local_deconv_simulated_config.yaml: --------------------------------------------------------------------------------
default: &DEFAULT
  exp_name: None
  data_path: None # this must be a list of data_path, set to "" if using data_folder as path.
  test_data_path: None # give it a list of datasets
  #################################
  model_distribution: "binomial" # data distribution: gaussian, binomial, poisson
  share_kernels_among_neurons: True # set true to share kernels among neurons
  #################################
  # kernel (dictionary)
  kernel_normalize: True # True: l2-norm of kernels is set to one after each update
  kernel_nonneg: False # True: project kernels into non-negative values
  kernel_nonneg_indicator: None # 0 for +-, 1 for +
  kernel_num: 2 # number of kernels to learn
  kernel_length: 16 # number of samples for kernel in time
  kernel_stride: 1 # default 1, convolution stride
  kernel_init_smoother: False # flag to init kernels to be smooth
  kernel_init_smoother_sigma: 0.5 # sigma of the gaussian kernel for kernel_init_smoother
  kernel_smoother: True # flag to apply smoother to the kernel during training
  kernel_smoother_penalty_weight: 0.007 # penalty weight to apply for kernel smoother
  # kernel_initialization: "../data/local-deconv-simulated/kernels_sin_and_bump.pt" # None, or a data path
  kernel_initialization: None # None, or a data path
  kernel_initialization_needs_adjustment_of_time_bin_resolution: False
  #################################
  # code (representation)
  code_nonneg: [1, 1] # apply sign constraint on the code. 1 for pos, -1 for neg, 2 for twosided
  code_sparse_regularization: 0.1 # apply sparse (lambda l1-norm) regularization on the code
  code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration
  code_q_regularization: True # set True to apply Q-regularization on the norm of the code
  code_q_regularization_matrix: None # The matrix of relations between the codes (use the path to load)
  code_q_regularization_matrix_path: "../data/local-deconv-simulated/code_q_regularization_matrix.pt"
  code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations
  code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term
  code_q_regularization_norm_type: None # Set to the norm number you want the Q-regularization to be applied
  code_supp: False # True: apply known event indices (supp) into code x
  code_topk: True # True: keep only top k indices in each kernel code non-zero (this is greedy)
  code_topk_sparse: 4 # number of top k non-zero entries in each code kernel
  code_topk_period: 10 # period on encoder iteration to apply topk
  code_l1loss_bp: True # True: to include l1-norm of the code in the loss during training
  code_l1loss_bp_penalty_weight: 0.1 # amount of sparse regularization of the code with bp during training
  #################################
  est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder
  poisson_stability_name: None # type of non-linearity to use on poisson case for encoder stability
  poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value
  #################################
  # unrolling parameters
  unrolling_num: 800 # number of unrolling iterations in the encoder
  unrolling_mode: "fista" # ista or fista encoder
  unrolling_alpha: 0.25 # alpha step size in unrolling; for known events it was 0.5
  unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold)
  unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold"
  #################################
  # training related
  # default optimizer is ADAM.
  optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels)
  optimizer_lr_step: 1000 # number of steps (updates) after which the lr will decay
  optimizer_lr_decay: 1 # decay factor for learning rate
  optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer
  optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer
  #
  backward_gradient_decsent: "truncated_bprop" # type of backward gradient update (bprop, truncated_bprop)
  backward_truncated_bprop_itr: 5 # must be set for truncated_bprop
  #
  train_num_epochs: None # number of epochs for training
  train_data_shuffle: True # True: to shuffle dataset at every epoch for training
  train_batch_size: 128 # batch size for training
  train_num_workers: 4 # number of workers to load data
  train_val_split: 0.85 # 1: use all for train. percentage of data used to train; the rest is used for validation.
  #
  enable_board: True
  log_info_epoch_period: 1 # period to push small info into the board
  log_model_epoch_period: 200 # period to save model
  log_fig_epoch_period: None # period to push figures into the board
  tqdm_prints_disable: False # True: to disable prints of epoch training process
  tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process
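This config points `code_q_regularization_matrix_path` at a saved tensor. A hypothetical sketch of creating such a file for the two kernels above; the actual matrix shipped with the simulated data may differ, and the `kernel_num x kernel_num` shape and values here are assumptions, not taken from the repo:

```python
import torch

# hypothetical 2x2 relation matrix between the two kernel codes
# (shape and values are assumptions for illustration only)
Q = torch.tensor([[1.0, -1.0], [-1.0, 1.0]])
torch.save(Q, "../data/local-deconv-simulated/code_q_regularization_matrix.pt")
```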
-------------------------------------------------------------------------------- /config/dopamine_spiking_eshel_uchida_config.yaml: --------------------------------------------------------------------------------
default: &DEFAULT
  exp_name: "dopaminespiking_25msbin_kernellength24_kernelnum3_codefree_kernel111"
  data_path: "" # this must be a list of data_path, set to "" if using data_folder as path.
  data_folder: "../data/dopamine-spiking-eshel-uchida" # this will look for data in format *trainready.pt
  test_data_path: None
  #################################
  model_distribution: "binomial" # data distribution: gaussian, binomial, poisson
  share_kernels_among_neurons: True # set true to share kernels among neurons
  #################################
  # kernel (dictionary)
  kernel_normalize: True # True: l2-norm of kernels is set to one after each update
  kernel_nonneg: True # True: project kernels into non-negative values
  kernel_nonneg_indicator: None # list of [0 for +-, 1 for +] or set to None for all +
  kernel_num: 3 # number of kernels to learn
  kernel_length: 24 # number of samples for kernel in time
  kernel_stride: 1 # default 1, convolution stride
  kernel_init_smoother: False # flag to init kernels to be smooth
  kernel_init_smoother_sigma: 0.5 # sigma of the gaussian kernel for kernel_init_smoother
  kernel_smoother: False # flag to apply smoother to the kernel during training
  kernel_smoother_penalty_weight: 0 # penalty weight to apply for kernel smoother
  kernel_initialization: None # None, or a data path
  #################################
  # code (representation)
  code_nonneg: False # apply sign constraint on the code. list of [1 for pos, -1 for neg, 2 for twosided], or set True/False for all
  code_sparse_regularization: 0 # apply sparse (lambda l1-norm) regularization on the code
  code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration
  code_q_regularization: False # set True to apply Q-regularization on the norm of the code
  code_q_regularization_matrix: None # The matrix of relations between the codes (if flag is True, use the path to load)
  code_q_regularization_matrix_path: ""
  code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations
  code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term
  code_q_regularization_norm_type: 2 # Set to the norm number you want the Q-regularization to be applied
  code_supp: True # True: apply known event indices (supp) into code x
  code_topk: False # True: keep only top k indices in each kernel code non-zero (this is greedy)
  code_topk_sparse: None # number of top k non-zero entries in each code kernel
  code_topk_period: None # period on encoder iteration to apply topk
  code_l1loss_bp: False # True: to include l1-norm of the code in the loss during training
  code_l1loss_bp_penalty_weight: 0 # amount of sparse regularization of the code with bp during training
  #################################
  est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder
  poisson_stability_name: None # type of non-linearity to use on poisson case for encoder stability
  poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value
  #################################
  # unrolling parameters
  unrolling_num: 100 # number of unrolling iterations in the encoder
  unrolling_mode: "fista" # ista or fista encoder
  unrolling_alpha: 0.1 # alpha step size in unrolling
  unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold)
  unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold"
  #################################
  # training related
  # default optimizer is ADAM.
  optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels)
  optimizer_lr_step: 20 # number of steps (updates) after which the lr will decay
  optimizer_lr_decay: 1 # decay factor for learning rate
  optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer
  optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer
  #
  backward_gradient_decsent: "bprop" # type of backward gradient update (bprop, truncated_bprop)
  backward_truncated_bprop_itr: 10 # must be set for truncated_bprop
  #
  train_num_epochs: 15 # number of epochs for training
  train_data_shuffle: True # True: to shuffle dataset at every epoch for training
  train_batch_size: 32 # batch size for training
  train_num_workers: 4 # number of workers to load data
  train_val_split: 1 # 1: use all for train. percentage of data used to train, rest to be used for validation.
  train_with_fraction: 1 # 1 for all the data, or a fraction, e.g., 0.1
  #
  enable_board: True
  log_info_epoch_period: 1 # period to push small info into the board
  log_model_epoch_period: 5 # period to save model
  log_fig_epoch_period: 1 # period to push figures into the board
  tqdm_prints_disable: False # True: to disable prints of epoch training process
  tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process
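To train with this config, the README points at `dunl/train_sharekernels_acrossneurons.py`. Assuming that script follows the same `configmypy` argument convention as the postprocess scripts above (an assumption; check the script's argparse), a run from the `dunl` directory would look like `python train_sharekernels_acrossneurons.py --config-filename ./dopamine_spiking_eshel_uchida_config.yaml`.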
-------------------------------------------------------------------------------- /config/local_deconv_calscenario_simulated_config.yaml: --------------------------------------------------------------------------------
default: &DEFAULT
  exp_name: None
  data_path: None # this must be a list of data_path, set to "" if using data_folder as path.
  test_data_path: None # give it a list of datasets
  #################################
  model_distribution: "binomial" # data distribution: gaussian, binomial, poisson
  share_kernels_among_neurons: True # set true to share kernels among neurons
  #################################
  # kernel (dictionary)
  kernel_normalize: True # True: l2-norm of kernels is set to one after each update
  kernel_nonneg: False # True: project kernels into non-negative values
  kernel_nonneg_indicator: None # 0 for +-, 1 for +
  kernel_num: 2 # number of kernels to learn
  kernel_length: 16 # number of samples for kernel in time
  kernel_stride: 1 # default 1, convolution stride
  kernel_init_smoother: False # flag to init kernels to be smooth
  kernel_init_smoother_sigma: 0.1 # sigma of the gaussian kernel for kernel_init_smoother
  kernel_smoother: True # flag to apply smoother to the kernel during training
  kernel_smoother_penalty_weight: 0.012 # penalty weight to apply for kernel smoother
  # kernel_initialization: "../data/local-deconv-simulated/kernels_sin_and_bump.pt" # None, or a data path
  kernel_initialization: None # None, or a data path
  kernel_initialization_needs_adjustment_of_time_bin_resolution: False
  #################################
  # code (representation)
1 for pos, -1 for neg, 2 for twosided 26 | code_sparse_regularization: 0.15 # apply sparse (lambda l1-norm) regularization on the code 27 | code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration 28 | code_q_regularization: False # set True to apply Q-regularization on the norm of the code 29 | code_q_regularization_matrix: None # The matrix of relations between the codes (use the path to load) 30 | code_q_regularization_matrix_path: "../data/local-deconv-calscenario-simulated/code_q_regularization_matrix.pt" 31 | code_q_regularization_period: 10 # the period to apply Q-regularization in encoder iterations 32 | code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term 33 | code_q_regularization_norm_type: None # Set to the norm number you want the Q-regularization to be applied 34 | code_supp: False # True: apply known event indices (supp) into code x 35 | code_topk: True # True: keep only top k indices in each kernel code non-zero (this is greedy) 36 | code_topk_sparse: 4 # number of top k non-zero entries in each code kernel 37 | code_topk_period: 10 # period on encoder iteration to apply topk 38 | code_l1loss_bp: True # True: to include l1-norm of the code in the loss during training 39 | code_l1loss_bp_penalty_weight: 0.15 # amount of sparse regularization of the code with bp during training 40 | ################################# 41 | est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder 42 | poisson_stability_name: None # type of non-linearity to use in the poisson case for encoder stability 43 | poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value 44 | ################################# 45 | # unrolling parameters 46 | unrolling_num: 800 # number of unrolling iterations in the encoder 47 | unrolling_mode: "fista" # ista or fista encoder 48 | unrolling_alpha: 0.25 # alpha step size in unrolling (for known events it was 0.5) 49 | unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold) 50 | unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold" 51 | ################################# 52 | # training related 53 | # default optimizer is ADAM. 54 | optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels) 55 | optimizer_lr_step: 1000 # number of steps (updates) after which the lr will decay 56 | optimizer_lr_decay: 1 # decay factor for learning rate 57 | optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer 58 | optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer 59 | # 60 | backward_gradient_decsent: "truncated_bprop" # type of backward gradient update (bprop, truncated_bprop) 61 | backward_truncated_bprop_itr: 5 # must be set for truncated_bprop 62 | # 63 | train_num_epochs: None # number of epochs for training 64 | train_data_shuffle: True # True: to shuffle dataset at every epoch for training 65 | train_batch_size: 128 # batch size for training 66 | train_num_workers: 4 # number of workers to load data 67 | train_val_split: 0.85 # 1: use all for train. percentage of data used to train, rest to be used for validation.
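The unrolling block in this config (unrolling_mode: "fista", unrolling_prox: "shrinkage", unrolling_alpha: 0.25, with the l1 weight code_sparse_regularization: 0.15) specifies the encoder as unrolled proximal-gradient iterations. A minimal sketch of such an encoder, assuming a dense dictionary H for readability (DUNL's actual encoder is convolutional and lives in the package's model code; the function names here are illustrative, not the repo's API):

import torch

def soft_shrinkage(x, lam):
    # "shrinkage" proximal operator: soft-thresholding for the l1 penalty
    return torch.sign(x) * torch.clamp(torch.abs(x) - lam, min=0.0)

def fista_encode(y, H, num_iter=800, alpha=0.25, lam=0.15):
    # unrolled FISTA iterations solving min_x 0.5*||y - H x||^2 + lam*||x||_1
    x = torch.zeros(H.shape[1])
    z = x.clone()
    t = 1.0
    for _ in range(num_iter):
        grad = H.t() @ (H @ z - y)  # gradient of the quadratic data term
        x_new = soft_shrinkage(z - alpha * grad, alpha * lam)
        t_new = (1.0 + (1.0 + 4.0 * t * t) ** 0.5) / 2.0
        z = x_new + ((t - 1.0) / t_new) * (x_new - x)  # momentum step
        x, t = x_new, t_new
    return x

Setting unrolling_mode: "ista" would correspond to dropping the momentum step and updating x directly.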
68 | # 69 | enable_board: True 70 | log_info_epoch_period: 1 # period to push small info into the board 71 | log_model_epoch_period: 200 # period to save model 72 | log_fig_epoch_period: None # period to push figures into the board 73 | tqdm_prints_disable: False # True: to disable prints of epoch training process 74 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /config/local_deconv_calscenario_longtrial_simulated_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | exp_name: None 3 | data_path: None # this must be a list of data_path, set to "" if using data_folder as path. 4 | test_data_path: None # give it a list of datasets 5 | ################################# 6 | model_distribution: "binomial" # data distribution: gaussian, binomial, poisson 7 | share_kernels_among_neurons: True # set True to share kernels among neurons 8 | ################################# 9 | # kernel (dictionary) 10 | kernel_normalize: True # True: l2-norm of kernels is set to one after each update 11 | kernel_nonneg: False # True: project kernels into non-negative values 12 | kernel_nonneg_indicator: None # 0 for +-, 1 for + 13 | kernel_num: 2 # number of kernels to learn 14 | kernel_length: 16 # number of samples for kernel in time 15 | kernel_stride: 1 # default 1, convolution stride 16 | kernel_init_smoother: False # flag to init kernels to be smooth 17 | kernel_init_smoother_sigma: 0.1 # sigma of the gaussian kernel for kernel_init_smoother 18 | kernel_smoother: True # flag to apply smoother to the kernel during training 19 | kernel_smoother_penalty_weight: 0.015 # penalty weight to apply for kernel smoother 20 | # kernel_initialization: "../data/local-deconv-simulated/kernels_sin_and_bump.pt" # None, or a data path 21 | kernel_initialization: None # None, or a data path 22 | kernel_initialization_needs_adjustment_of_time_bin_resolution: False 23 | ################################# 24 | # code (representation) 25 | code_nonneg: [1, 1] # apply sign constraint on the code.
1 for pos, -1 for neg, 2 for twosided 26 | code_sparse_regularization: 0.10 # apply sparse (lambda l1-norm) regularization on the code 27 | code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration 28 | code_q_regularization: False # set True to apply Q-regularization on the norm of the code 29 | code_q_regularization_matrix: None # The matrix of relations between the codes (use the path to load) 30 | code_q_regularization_matrix_path: "../data/local-deconv-calscenario-longtrial-simulated/code_q_regularization_matrix.pt" 31 | code_q_regularization_period: 10 # the period to apply Q-regularization in encoder iterations 32 | code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term 33 | code_q_regularization_norm_type: None # Set to the norm number you want the Q-regularization to be applied 34 | code_supp: False # True: apply known event indices (supp) into code x 35 | code_topk: True # True: keep only top k indices in each kernel code non-zero (this is greedy) 36 | code_topk_sparse: 3 # number of top k non-zero entries in each code kernel 37 | code_topk_period: 10 # period on encoder iteration to apply topk 38 | code_l1loss_bp: True # True: to include l1-norm of the code in the loss during training 39 | code_l1loss_bp_penalty_weight: 0.10 # amount of sparse regularization of the code with bp during training 40 | ################################# 41 | est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder 42 | poisson_stability_name: None # type of non-linearity to use in the poisson case for encoder stability 43 | poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value 44 | ################################# 45 | # unrolling parameters 46 | unrolling_num: 800 # number of unrolling iterations in the encoder 47 | unrolling_mode: "fista" # ista or fista encoder 48 | unrolling_alpha: 0.25 # alpha step size in unrolling (for known events it was 0.5) 49 | unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold) 50 | unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold" 51 | ################################# 52 | # training related 53 | # default optimizer is ADAM. 54 | optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels) 55 | optimizer_lr_step: 1000 # number of steps (updates) after which the lr will decay 56 | optimizer_lr_decay: 1 # decay factor for learning rate 57 | optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer 58 | optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer 59 | # 60 | backward_gradient_decsent: "truncated_bprop" # type of backward gradient update (bprop, truncated_bprop) 61 | backward_truncated_bprop_itr: 5 # must be set for truncated_bprop 62 | # 63 | train_num_epochs: None # number of epochs for training 64 | train_data_shuffle: True # True: to shuffle dataset at every epoch for training 65 | train_batch_size: 128 # batch size for training 66 | train_num_workers: 4 # number of workers to load data 67 | train_val_split: 0.85 # 1: use all for train. percentage of data used to train, rest to be used for validation.
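This config again relies on the greedy top-k path (code_topk: True, code_topk_sparse: 3, code_topk_period: 10). A sketch of what the periodic top-k projection could look like on a convolutional code tensor (hypothetical helper with illustrative shapes, not the repo's API):

import torch

def topk_project(code, k):
    # keep only the k largest-magnitude entries of each kernel's code,
    # zeroing the rest (greedy support selection)
    b, n, t = code.shape  # (batch, kernel_num, time)
    flat = code.reshape(b * n, t)
    idx = flat.abs().topk(k, dim=-1).indices
    mask = torch.zeros_like(flat).scatter_(-1, idx, 1.0)
    return (flat * mask).reshape(b, n, t)

With code_topk_period: 10, a projection like this would run on every 10th encoder iteration rather than on all of them.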
68 | # 69 | enable_board: True 70 | log_info_epoch_period: 1 # period to push small info into the board 71 | log_model_epoch_period: 200 # period to save model 72 | log_fig_epoch_period: None # period to push figures into the board 73 | tqdm_prints_disable: False # True: to disable prints of epoch training process 74 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /config/local_deconv_calscenario_shorttrial_simulated_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | exp_name: None 3 | data_path: None # this must be a list of data_path, set to "" if using data_folder as path. 4 | test_data_path: None # give it a list of datasets 5 | ################################# 6 | model_distribution: "binomial" # data distribution: gaussian, binomial, poisson 7 | share_kernels_among_neurons: True # set True to share kernels among neurons 8 | ################################# 9 | # kernel (dictionary) 10 | kernel_normalize: True # True: l2-norm of kernels is set to one after each update 11 | kernel_nonneg: False # True: project kernels into non-negative values 12 | kernel_nonneg_indicator: None # 0 for +-, 1 for + 13 | kernel_num: 2 # number of kernels to learn 14 | kernel_length: 16 # number of samples for kernel in time 15 | kernel_stride: 1 # default 1, convolution stride 16 | kernel_init_smoother: False # flag to init kernels to be smooth 17 | kernel_init_smoother_sigma: 0.1 # sigma of the gaussian kernel for kernel_init_smoother 18 | kernel_smoother: True # flag to apply smoother to the kernel during training 19 | kernel_smoother_penalty_weight: 0.015 # penalty weight to apply for kernel smoother 20 | # kernel_initialization: "../data/local-deconv-simulated/kernels_sin_and_bump.pt" # None, or a data path 21 | kernel_initialization: None # None, or a data path 22 | kernel_initialization_needs_adjustment_of_time_bin_resolution: False 23 | ################################# 24 | # code (representation) 25 | code_nonneg: [1, 1] # apply sign constraint on the code.
1 for pos, -1 for neg, 2 for twosided 26 | code_sparse_regularization: 0.10 # apply sparse (lambda l1-norm) regularization on the code 27 | code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration 28 | code_q_regularization: False # set True to apply Q-regularization on the norm of the code 29 | code_q_regularization_matrix: None # The matrix of relations between the codes (use the path to load) 30 | code_q_regularization_matrix_path: "../data/local-deconv-calscenario-shorttrial-simulated/code_q_regularization_matrix.pt" 31 | code_q_regularization_period: 10 # the period to apply Q-regularization in encoder iterations 32 | code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term 33 | code_q_regularization_norm_type: None # Set to the norm number you want the Q-regularization to be applied 34 | code_supp: False # True: apply known event indices (supp) into code x 35 | code_topk: True # True: keep only top k indices in each kernel code non-zero (this is greedy) 36 | code_topk_sparse: 2 # number of top k non-zero entries in each code kernel 37 | code_topk_period: 10 # period on encoder iteration to apply topk 38 | code_l1loss_bp: True # True: to include l1-norm of the code in the loss during training 39 | code_l1loss_bp_penalty_weight: 0.10 # amount of sparse regularization of the code with bp during training 40 | ################################# 41 | est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder 42 | poisson_stability_name: None # type of non-linearity to use in the poisson case for encoder stability 43 | poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value 44 | ################################# 45 | # unrolling parameters 46 | unrolling_num: 800 # number of unrolling iterations in the encoder 47 | unrolling_mode: "fista" # ista or fista encoder 48 | unrolling_alpha: 0.25 # alpha step size in unrolling (for known events it was 0.5) 49 | unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold) 50 | unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold" 51 | ################################# 52 | # training related 53 | # default optimizer is ADAM. 54 | optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels) 55 | optimizer_lr_step: 1000 # number of steps (updates) after which the lr will decay 56 | optimizer_lr_decay: 1 # decay factor for learning rate 57 | optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer 58 | optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer 59 | # 60 | backward_gradient_decsent: "truncated_bprop" # type of backward gradient update (bprop, truncated_bprop) 61 | backward_truncated_bprop_itr: 5 # must be set for truncated_bprop 62 | # 63 | train_num_epochs: None # number of epochs for training 64 | train_data_shuffle: True # True: to shuffle dataset at every epoch for training 65 | train_batch_size: 128 # batch size for training 66 | train_num_workers: 4 # number of workers to load data 67 | train_val_split: 0.85 # 1: use all for train. percentage of data used to train, rest to be used for validation.
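code_nonneg: [1, 1] above asks for both kernels' codes to be non-negative; the comment's convention (1 for pos, -1 for neg, 2 for twosided) suggests a per-kernel sign projection along these lines (a hypothetical helper mirroring the config semantics, not the repo's implementation):

import torch

def project_code_sign(code, code_nonneg):
    # code: (batch, kernel_num, time); one sign flag per kernel
    out = code.clone()
    for k, sign in enumerate(code_nonneg):
        if sign == 1:
            out[:, k] = out[:, k].clamp(min=0.0)  # keep positive part
        elif sign == -1:
            out[:, k] = out[:, k].clamp(max=0.0)  # keep negative part
        # sign == 2: two-sided, left unconstrained
    return out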
68 | # 69 | enable_board: True 70 | log_info_epoch_period: 1 # period to push small info into the board 71 | log_model_epoch_period: 200 # period to save model 72 | log_fig_epoch_period: None # period to push figures into the board 73 | tqdm_prints_disable: False # True: to disable prints of epoch training process 74 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /config/local_deconv_calscenario_shorttrial_structured_simulated_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | exp_name: None 3 | data_path: None # this must be a list of data_path, set to "" if using data_folder as path. 4 | test_data_path: None # give it a list of datasets 5 | ################################# 6 | model_distribution: "binomial" # data distribution: gaussian, binomial, poisson 7 | share_kernels_among_neurons: True # set True to share kernels among neurons 8 | ################################# 9 | # kernel (dictionary) 10 | kernel_normalize: True # True: l2-norm of kernels is set to one after each update 11 | kernel_nonneg: False # True: project kernels into non-negative values 12 | kernel_nonneg_indicator: None # 0 for +-, 1 for + 13 | kernel_num: 2 # number of kernels to learn 14 | kernel_length: 16 # number of samples for kernel in time 15 | kernel_stride: 1 # default 1, convolution stride 16 | kernel_init_smoother: False # flag to init kernels to be smooth 17 | kernel_init_smoother_sigma: 0.1 # sigma of the gaussian kernel for kernel_init_smoother 18 | kernel_smoother: True # flag to apply smoother to the kernel during training 19 | kernel_smoother_penalty_weight: 0.015 # penalty weight to apply for kernel smoother 20 | # kernel_initialization: "../data/local-deconv-simulated/kernels_sin_and_bump.pt" # None, or a data path 21 | kernel_initialization: None # None, or a data path 22 | kernel_initialization_needs_adjustment_of_time_bin_resolution: False 23 | ################################# 24 | # code (representation) 25 | code_nonneg: [1, 1] # apply sign constraint on the code.
1 for pos, -1 for neg, 2 for twosided 26 | code_sparse_regularization: 0.10 # apply sparse (lambda l1-norm) regularization on the code 27 | code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration 28 | code_q_regularization: False # set True to apply Q-regularization on the norm of the code 29 | code_q_regularization_matrix: None # The matrix of relations between the codes (use the path to load) 30 | code_q_regularization_matrix_path: "../data/local-deconv-calscenario-shorttrial-structured-simulated/code_q_regularization_matrix.pt" 31 | code_q_regularization_period: 10 # the period to apply Q-regularization in encoder iterations 32 | code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term 33 | code_q_regularization_norm_type: None # Set to the norm number you want the Q-regularization to be applied 34 | code_supp: False # True: apply known event indices (supp) into code x 35 | code_topk: True # True: keep only top k indices in each kernel code non-zero (this is greedy) 36 | code_topk_sparse: 1 # number of top k non-zero entries in each code kernel 37 | code_topk_period: 10 # period on encoder iteration to apply topk 38 | code_l1loss_bp: True # True: to include l1-norm of the code in the loss during training 39 | code_l1loss_bp_penalty_weight: 0.10 # amount of sparse regularization of the code with bp during training 40 | ################################# 41 | est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder 42 | poisson_stability_name: None # type of non-linearity to use in the poisson case for encoder stability 43 | poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value 44 | ################################# 45 | # unrolling parameters 46 | unrolling_num: 800 # number of unrolling iterations in the encoder 47 | unrolling_mode: "fista" # ista or fista encoder 48 | unrolling_alpha: 0.25 # alpha step size in unrolling (for known events it was 0.5) 49 | unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold) 50 | unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold" 51 | ################################# 52 | # training related 53 | # default optimizer is ADAM. 54 | optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels) 55 | optimizer_lr_step: 1000 # number of steps (updates) after which the lr will decay 56 | optimizer_lr_decay: 1 # decay factor for learning rate 57 | optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer 58 | optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer 59 | # 60 | backward_gradient_decsent: "truncated_bprop" # type of backward gradient update (bprop, truncated_bprop) 61 | backward_truncated_bprop_itr: 5 # must be set for truncated_bprop 62 | # 63 | train_num_epochs: None # number of epochs for training 64 | train_data_shuffle: True # True: to shuffle dataset at every epoch for training 65 | train_batch_size: 128 # batch size for training 66 | train_num_workers: 4 # number of workers to load data 67 | train_val_split: 0.85 # 1: use all for train. percentage of data used to train, rest to be used for validation.
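The kernel options shared by these local-deconv configs pair unit-norm kernels (kernel_normalize: True) with a smoothness penalty during training (kernel_smoother: True, kernel_smoother_penalty_weight: 0.015). A sketch of both, assuming kernels stored as (kernel_num, kernel_length); the squared-first-difference form of the smoother is an assumption here, and the repo's exact penalty may differ:

import torch

def normalize_kernels(H):
    # kernel_normalize: True -- reset each kernel to unit l2-norm after an update
    return H / (H.norm(dim=-1, keepdim=True) + 1e-12)

def kernel_smoothness_penalty(H, weight=0.015):
    # penalize squared first differences along time (assumed smoother form)
    return weight * (H[:, 1:] - H[:, :-1]).pow(2).sum()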
68 | # 69 | enable_board: True 70 | log_info_epoch_period: 1 # period to push small info into the board 71 | log_model_epoch_period: 200 # period to save model 72 | log_fig_epoch_period: None # period to push figures into the board 73 | tqdm_prints_disable: False # True: to disable prints of epoch training process 74 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /config/dopamine_spiking_eshel_uchida_limited_data_exp_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | exp_name: "dopaminespiking_25msbin_kernellength24_kernelnum3_codefree_kernel111_limiteddata0p1_smoothkernel_0p0005" 3 | data_path: "" # this must be a list of data_path, set to "" if using data_folder as path. 4 | data_folder: "../data/dopamine-spiking-eshel-uchida/train" # this will look for data in format *trainready.pt 5 | test_data_path: None 6 | ################################# 7 | model_distribution: "binomial" # data distribution: gaussian, binomial, poisson 8 | share_kernels_among_neurons: True # set True to share kernels among neurons 9 | ################################# 10 | # kernel (dictionary) 11 | kernel_normalize: True # True: l2-norm of kernels is set to one after each update 12 | kernel_nonneg: True # True: project kernels into non-negative values 13 | kernel_nonneg_indicator: None # list of [0 for +-, 1 for +] or set to None for all + 14 | kernel_num: 3 # number of kernels to learn 15 | kernel_length: 24 # number of samples for kernel in time 16 | kernel_stride: 1 # default 1, convolution stride 17 | kernel_init_smoother: False # flag to init kernels to be smooth 18 | kernel_init_smoother_sigma: 0.5 # sigma of the gaussian kernel for kernel_init_smoother 19 | kernel_smoother: True # flag to apply smoother to the kernel during training 20 | kernel_smoother_penalty_weight: 0.0005 # penalty weight to apply for kernel smoother 21 | kernel_initialization: None # None, or a data path 22 | ################################# 23 | # code (representation) 24 | code_nonneg: False # apply sign constraint on the code.
list of [1 for pos, -1 for neg, 2 for twosided], or set True/False for all 25 | code_sparse_regularization: 0 # apply sparse (lambda l1-norm) regularization on the code 26 | code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration 27 | code_q_regularization: False # set True to apply Q-regularization on the norm of the code 28 | code_q_regularization_matrix: None # The matrix of relations between the codes (if flag is True, use the path to load) 29 | code_q_regularization_matrix_path: "" 30 | code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations 31 | code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term 32 | code_q_regularization_norm_type: 2 # Set to the norm number you want the Q-regularization to be applied 33 | code_supp: True # True: apply known event indices (supp) into code x 34 | code_topk: False # True: keep only top k indices in each kernel code non-zero (this is greedy) 35 | code_topk_sparse: None # number of top k non-zero entries in each code kernel 36 | code_topk_period: None # period on encoder iteration to apply topk 37 | code_l1loss_bp: False # True: to include l1-norm of the code in the loss during training 38 | code_l1loss_bp_penalty_weight: 0 # amount of sparse regularization of the code with bp during training 39 | ################################# 40 | est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder 41 | poisson_stability_name: None # type of non-linearity to use in the poisson case for encoder stability 42 | poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value 43 | ################################# 44 | # unrolling parameters 45 | unrolling_num: 100 # number of unrolling iterations in the encoder 46 | unrolling_mode: "fista" # ista or fista encoder 47 | unrolling_alpha: 0.1 # alpha step size in unrolling 48 | unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold) 49 | unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold" 50 | ################################# 51 | # training related 52 | # default optimizer is ADAM. 53 | optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels) 54 | optimizer_lr_step: 20 # number of steps (updates) after which the lr will decay 55 | optimizer_lr_decay: 1 # decay factor for learning rate 56 | optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer 57 | optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer 58 | # 59 | backward_gradient_decsent: "bprop" # type of backward gradient update (bprop, truncated_bprop) 60 | backward_truncated_bprop_itr: 10 # must be set for truncated_bprop 61 | # 62 | train_num_epochs: 50 # number of epochs for training 63 | train_data_shuffle: True # True: to shuffle dataset at every epoch for training 64 | train_batch_size: 32 # batch size for training 65 | train_num_workers: 4 # number of workers to load data 66 | train_val_split: 1 # 1: use all for train. percentage of data used to train, rest to be used for validation. 67 | train_with_fraction: 0.1 # 1 for all the data, or a fraction e.g.
0.1 68 | # 69 | enable_board: True 70 | log_info_epoch_period: 1 # period to push small info into the board 71 | log_model_epoch_period: 50 # period to save model 72 | log_fig_epoch_period: 10 # period to push figures into the board 73 | tqdm_prints_disable: False # True: to disable prints of epoch training process 74 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /config/whisker_simulated_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | exp_name: None 3 | data_path: None 4 | test_data_path: None # give it a list of datasets 5 | # data_path: ["../data/whisker/whisker_train_10msbinres_general_format_processed_kernellength12_kernelnum1_trainready.pt"] # give it a list of datasets 6 | # test_data_path: ["../data/whisker/whisker_test_10msbinres_general_format_processed_kernellength12_kernelnum1_trainready.pt"] # give it a list of datasets 7 | ################################# 8 | model_distribution: "binomial" # data distribution: gaussian, binomial, poisson 9 | share_kernels_among_neurons: True # set True to share kernels among neurons 10 | ################################# 11 | # kernel (dictionary) 12 | kernel_normalize: True # True: l2-norm of kernels is set to one after each update 13 | kernel_nonneg: False # True: project kernels into non-negative values 14 | kernel_nonneg_indicator: [0] # 0 for +-, 1 for + 15 | kernel_num: 1 # number of kernels to learn 16 | kernel_length: None # number of samples for kernel in time 17 | kernel_stride: 1 # default 1, convolution stride 18 | kernel_init_smoother: False # flag to init kernels to be smooth 19 | kernel_init_smoother_sigma: 0.5 # sigma of the gaussian kernel for kernel_init_smoother 20 | kernel_smoother: True # flag to apply smoother to the kernel during training 21 | kernel_smoother_penalty_weight: None # penalty weight to apply for kernel smoother 22 | kernel_initialization: "../data/whisker-simulated/kernels_sin.pt" # None, or a data path 23 | kernel_initialization_needs_adjustment_of_time_bin_resolution: True 24 | ################################# 25 | # code (representation) 26 | code_nonneg: [1] # apply sign constraint on the code.
1 for pos, -1 for neg, 2 for twosided 27 | code_sparse_regularization: 0.03 # apply sparse (lambda l1-norm) regularization on the code 28 | code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration 29 | code_q_regularization: False # set True to apply Q-regularization on the norm of the code 30 | code_q_regularization_matrix: None # The matrix of relations between the codes (if flag is True, use the path to load) 31 | code_q_regularization_matrix_path: None 32 | code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations 33 | code_q_regularization_scale: 5 # scale factor in front of the Q-regularization term 34 | code_q_regularization_norm_type: 2 # Set to the norm number you want the Q-regularization to be applied 35 | code_supp: True # True: apply known event indices (supp) into code x 36 | code_topk: False # True: keep only top k indices in each kernel code non-zero (this is greedy) 37 | code_topk_sparse: 5 # number of top k non-zero entries in each code kernel 38 | code_topk_period: 10 # period on encoder iteration to apply topk 39 | code_l1loss_bp: True # True: to include l1-norm of the code in the loss during training 40 | code_l1loss_bp_penalty_weight: 0.03 # amount of sparse regularization of the code with bp during training 41 | ################################# 42 | est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder 43 | poisson_stability_name: None # type of non-linearity to use in the poisson case for encoder stability 44 | poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value 45 | ################################# 46 | # unrolling parameters 47 | unrolling_num: 800 # number of unrolling iterations in the encoder 48 | unrolling_mode: "fista" # ista or fista encoder 49 | unrolling_alpha: 0.25 # alpha step size in unrolling (for known events it was 0.5) 50 | unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold) 51 | unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold" 52 | ################################# 53 | # training related 54 | # default optimizer is ADAM. 55 | optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels) 56 | optimizer_lr_step: 1000 # number of steps (updates) after which the lr will decay 57 | optimizer_lr_decay: 1 # decay factor for learning rate 58 | optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer 59 | optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer 60 | # 61 | backward_gradient_decsent: "truncated_bprop" # type of backward gradient update (bprop, truncated_bprop) 62 | backward_truncated_bprop_itr: 20 # must be set for truncated_bprop 63 | # 64 | train_num_epochs: None # number of epochs for training 65 | train_data_shuffle: True # True: to shuffle dataset at every epoch for training 66 | train_batch_size: 128 # batch size for training 67 | train_num_workers: 4 # number of workers to load data 68 | train_val_split: 1 # 1: use all for train. percentage of data used to train, rest to be used for validation.
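train_val_split is read the same way across these configs: 1 means train on everything, while a smaller value is the training fraction with the remainder held out for validation. A sketch of that split using torch's random_split (illustrative only, not the repo's loader):

import torch
from torch.utils.data import TensorDataset, random_split

def split_train_val(dataset, train_val_split=1.0):
    # train_val_split: 1 -> no validation set; otherwise split by fraction
    if train_val_split >= 1.0:
        return dataset, None
    n_train = int(len(dataset) * train_val_split)
    return tuple(random_split(dataset, [n_train, len(dataset) - n_train]))

# e.g. train_set, val_set = split_train_val(TensorDataset(torch.randn(100, 8)), 0.85)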
69 | # 70 | enable_board: True 71 | log_info_epoch_period: 1 # period to push small info into the board 72 | log_model_epoch_period: 1 # period to save model 73 | log_fig_epoch_period: None # period to push figures into the board 74 | tqdm_prints_disable: False # True: to disable prints of epoch training process 75 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /config/dopamine_spiking_simulated_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | exp_name: "simulated_dopaminespiking_40neurons_14trials_25msbin_kernellength24_kernelnum3_codefree_kernel111" 3 | data_path: ["../data/dopamine-spiking-simulated/simulated_40nuerons_14trials_25msbinres_general_format_processed_kernellength24_kernelnum3_trainready.pt"] # this must be a list of data_path, set to "" if using data_folder as path. 4 | data_folder: None # this will look for data in format *trainready.pt 5 | test_data_path: None 6 | ################################# 7 | model_distribution: "binomial" # data distribution: gaussian, binomial, poisson 8 | share_kernels_among_neurons: True # set True to share kernels among neurons 9 | ################################# 10 | # kernel (dictionary) 11 | kernel_normalize: True # True: l2-norm of kernels is set to one after each update 12 | kernel_nonneg: True # True: project kernels into non-negative values 13 | kernel_nonneg_indicator: None # list of [0 for +-, 1 for +] or set to None for all + 14 | kernel_num: 3 # number of kernels to learn 15 | kernel_length: 24 # number of samples for kernel in time 16 | kernel_stride: 1 # default 1, convolution stride 17 | kernel_init_smoother: False # flag to init kernels to be smooth 18 | kernel_init_smoother_sigma: 0.5 # sigma of the gaussian kernel for kernel_init_smoother 19 | kernel_smoother: False # flag to apply smoother to the kernel during training 20 | kernel_smoother_penalty_weight: 0 # penalty weight to apply for kernel smoother 21 | kernel_initialization: None # None, or a data path 22 | ################################# 23 | # code (representation) 24 | code_nonneg: False # apply sign constraint on the code.
list of [1 for pos, -1 for neg, 2 for twosided], or set True/False for all 25 | code_sparse_regularization: 0 # apply sparse (lambda l1-norm) regularization on the code 26 | code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration 27 | code_q_regularization: False # set True to apply Q-regularization on the norm of the code 28 | code_q_regularization_matrix: None # The matrix of relations between the codes (if flag is True, use the path to load) 29 | code_q_regularization_matrix_path: "" 30 | code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations 31 | code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term 32 | code_q_regularization_norm_type: 2 # Set to the norm number you want the Q-regularization to be applied 33 | code_supp: True # True: apply known event indices (supp) into code x 34 | code_topk: False # True: keep only top k indices in each kernel code non-zero (this is greedy) 35 | code_topk_sparse: None # number of top k non-zero entries in each code kernel 36 | code_topk_period: None # period on encoder iteration to apply topk 37 | code_l1loss_bp: False # True: to include l1-norm of the code in the loss during training 38 | code_l1loss_bp_penalty_weight: 0 # amount of sparse regularization of the code with bp during training 39 | ################################# 40 | est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder 41 | poisson_stability_name: None # type of non-linearity to use in the poisson case for encoder stability 42 | poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value 43 | ################################# 44 | # unrolling parameters 45 | unrolling_num: 100 # number of unrolling iterations in the encoder 46 | unrolling_mode: "fista" # ista or fista encoder 47 | unrolling_alpha: 0.1 # alpha step size in unrolling 48 | unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold) 49 | unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold" 50 | ################################# 51 | # training related 52 | # default optimizer is ADAM. 53 | optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels) 54 | optimizer_lr_step: 20 # number of steps (updates) after which the lr will decay 55 | optimizer_lr_decay: 1 # decay factor for learning rate 56 | optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer 57 | optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer 58 | # 59 | backward_gradient_decsent: "bprop" # type of backward gradient update (bprop, truncated_bprop) 60 | backward_truncated_bprop_itr: 10 # must be set for truncated_bprop 61 | # 62 | train_num_epochs: 200 # number of epochs for training 63 | train_data_shuffle: True # True: to shuffle dataset at every epoch for training 64 | train_batch_size: 2 # batch size for training 65 | train_num_workers: 4 # number of workers to load data 66 | train_val_split: 1 # 1: use all for train. percentage of data used to train, rest to be used for validation. 67 | train_with_fraction: 1 # 1 for all the data, or a fraction e.g.
0.1 68 | # 69 | enable_board: True 70 | log_info_epoch_period: 1 # period to push small info into the board 71 | log_model_epoch_period: 50 # period to save model 72 | log_fig_epoch_period: 1 # period to push figures into the board 73 | tqdm_prints_disable: False # True: to disable prints of epoch training process 74 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /dunl/postprocess_scripts/save_data_for_pcanmf_dopamine_calcium_saramatias_uchida.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bahareh Tolooshams 3 | 4 | save data for pca/nmf 5 | 6 | :author: Bahareh Tolooshams 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | import pickle 13 | import argparse 14 | 15 | 16 | def init_params(): 17 | parser = argparse.ArgumentParser(description=__doc__) 18 | 19 | parser.add_argument( 20 | "--res-path", 21 | type=str, 22 | help="res path", 23 | default="../results/dopaminecalcium_kernellength60_kernelnum5_code2211n1_kernel00011_qreg_2023_07_13_11_37_31", 24 | ) 25 | parser.add_argument( 26 | "--reward-amount-list", 27 | type=list, 28 | help="reward amount list", 29 | default=[0.0, 0.3, 0.5, 1.2, 2.5, 5.0, 8.0, 11.0], 30 | ) 31 | parser.add_argument( 32 | "--window-dur", 33 | type=int, 34 | help="window duration to get average activity", 35 | default=60, # this is after time bin resolution 36 | ) 37 | parser.add_argument( 38 | "--save-only-sur", 39 | type=bool, 40 | help="save only surprise trials", 41 | default=False, 42 | ) 43 | args = parser.parse_args() 44 | params = vars(args) 45 | 46 | return params 47 | 48 | 49 | def main(): 50 | print("Predict.") 51 | 52 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 53 | print("device is", device) 54 | 55 | # init parameters -------------------------------------------------------# 56 | print("init parameters.") 57 | params_init = init_params() 58 | 59 | # take parameters from the result path 60 | params = pickle.load( 61 | open(os.path.join(params_init["res_path"], "params.pickle"), "rb") 62 | ) 63 | for key in params_init.keys(): 64 | params[key] = params_init[key] 65 | 66 | if params["data_path"] == "": 67 | data_folder = params["data_folder"] 68 | filename_list = os.listdir(data_folder) 69 | data_path_list = [ 70 | f"{data_folder}/{x}" for x in filename_list if "trainready.pt" in x 71 | ] 72 | else: 73 | data_path_list = params["data_path"] 74 | 75 | print("There are {} datasets in the folder.".format(len(data_path_list))) 76 | 77 | # create folders -------------------------------------------------------# 78 | 79 | postprocess_path = os.path.join( 80 | params["res_path"], 81 | "postprocess", 82 | ) 83 | 84 | # load data -------------------------------------------------------# 85 | 86 | for data_path in data_path_list: 87 | datafile_name = data_path.split("/")[-1].split(".pt")[0] 88 | 89 | y = torch.load(os.path.join(postprocess_path, "y_{}.pt".format(datafile_name))) 90 | x = torch.load(os.path.join(postprocess_path, "x_{}.pt".format(datafile_name))) 91 | label = torch.load( 92 | os.path.join(postprocess_path, "label_{}.pt".format(datafile_name)) 93 | ) 94 | 95 | num_trials = y.shape[0] 96 | 97 | yavg = list() 98 | rew_amount = list() 99 | 100 | # go over all trials 101 | for i in range(num_trials): 102 | yi = y[i] 103 | xi = x[i] 104 | labeli = label[i] 105 | 106 | if labeli == 0: 107 | continue 108 | else: 109 | # skip
if it's an expected trial 110 | if params["save_only_sur"]: 111 | cue_flag = torch.sum(torch.abs(xi[1]), dim=-1).item() 112 | if cue_flag: 113 | # expected trial, hence skip 114 | continue 115 | 116 | # reward presence 117 | reward_onset = np.where(xi[2] > 0)[-1][0] 118 | 119 | y_curr = yi[:, reward_onset : reward_onset + params["window_dur"]] 120 | yavg.append(y_curr) 121 | rew_amount.append(labeli) 122 | 123 | # (neurons, time, trials) 124 | yavg = torch.stack(yavg, dim=-1).clone().detach().cpu().numpy() 125 | rew_amount = np.array(rew_amount) 126 | if params["save_only_sur"]: 127 | np.save( 128 | os.path.join( 129 | postprocess_path, 130 | "y_for_pcanmf_{}_only_sur.npy".format(datafile_name), 131 | ), 132 | yavg, 133 | ) 134 | np.save( 135 | os.path.join( 136 | postprocess_path, 137 | "label_for_pcanmf_{}_only_sur.npy".format(datafile_name), 138 | ), 139 | rew_amount, 140 | ) 141 | else: 142 | np.save( 143 | os.path.join( 144 | postprocess_path, "y_for_pcanmf_{}.npy".format(datafile_name) 145 | ), 146 | yavg, 147 | ) 148 | np.save( 149 | os.path.join( 150 | postprocess_path, "label_for_pcanmf_{}.npy".format(datafile_name) 151 | ), 152 | rew_amount, 153 | ) 154 | 155 | 156 | if __name__ == "__main__": 157 | main() 158 | -------------------------------------------------------------------------------- /config/dopamine_spiking_eshel_uchida_code122_kernel011_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | exp_name: "dopaminespiking_25msbin_kernellength24_kernelnum3_code122_kernel011" 3 | data_path: "" # this must be a list of data_path, set to "" if using data_folder as path. 4 | data_folder: "../data/dopamine-spiking-eshel-uchida" # this will look for data in format *trainready.pt 5 | test_data_path: None 6 | ################################# 7 | model_distribution: "binomial" # data distribution: gaussian, binomial, poisson 8 | share_kernels_among_neurons: True # set True to share kernels among neurons 9 | ################################# 10 | # kernel (dictionary) 11 | kernel_normalize: True # True: l2-norm of kernels is set to one after each update 12 | kernel_nonneg: True # True: project kernels into non-negative values 13 | kernel_nonneg_indicator: [0, 1, 1] # list of [0 for +-, 1 for +] or set to None for all + 14 | kernel_num: 3 # number of kernels to learn 15 | kernel_length: 24 # number of samples for kernel in time 16 | kernel_stride: 1 # default 1, convolution stride 17 | kernel_init_smoother: False # flag to init kernels to be smooth 18 | kernel_init_smoother_sigma: 0.5 # sigma of the gaussian kernel for kernel_init_smoother 19 | kernel_smoother: False # flag to apply smoother to the kernel during training 20 | kernel_smoother_penalty_weight: 0 # penalty weight to apply for kernel smoother 21 | kernel_initialization: None # None, or a data path 22 | kernel_initialization_needs_adjustment_of_time_bin_resolution: False 23 | ################################# 24 | # code (representation) 25 | code_nonneg: [1, 2, 2] # apply sign constraint on the code.
list of [1 for pos, -1 for neg, 2 for twosided], or set True/False for all 26 | code_sparse_regularization: 0 # apply sparse (lambda l1-norm) regularization on the code 27 | code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration 28 | code_group_neural_firings_regularization: 0 # if > 0, then it applies grouping across neurons 29 | code_q_regularization: False # set True to apply Q-regularization on the norm of the code 30 | code_q_regularization_matrix: None # The matrix of relations between the codes (if flag is True, use the path to load) 31 | code_q_regularization_matrix_path: "" 32 | code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations 33 | code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term 34 | code_q_regularization_norm_type: 2 # Set to the norm number you want the Q-regularization to be applied 35 | code_supp: True # True: apply known event indices (supp) into code x 36 | code_topk: False # True: keep only top k indices in each kernel code non-zero (this is greedy) 37 | code_topk_sparse: None # number of top k non-zero entries in each code kernel 38 | code_topk_period: None # period on encoder iteration to apply topk 39 | code_l1loss_bp: False # True: to include l1-norm of the code in the loss during training 40 | code_l1loss_bp_penalty_weight: 0 # amount of sparse regularization of the code with bp during training 41 | ################################# 42 | est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder 43 | poisson_stability_name: None # type of non-linearity to use in the poisson case for encoder stability 44 | poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value 45 | ################################# 46 | # unrolling parameters 47 | unrolling_num: 100 # number of unrolling iterations in the encoder 48 | unrolling_mode: "fista" # ista or fista encoder 49 | unrolling_alpha: 0.1 # alpha step size in unrolling 50 | unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold) 51 | unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold" 52 | ################################# 53 | # training related 54 | # default optimizer is ADAM. 55 | optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels) 56 | optimizer_lr_step: 20 # number of steps (updates) after which the lr will decay 57 | optimizer_lr_decay: 1 # decay factor for learning rate 58 | optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer 59 | optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer 60 | # 61 | backward_gradient_decsent: "bprop" # type of backward gradient update (bprop, truncated_bprop) 62 | backward_truncated_bprop_itr: 10 # must be set for truncated_bprop 63 | # 64 | train_num_epochs: 15 # number of epochs for training 65 | train_data_shuffle: True # True: to shuffle dataset at every epoch for training 66 | train_batch_size: 32 # batch size for training 67 | train_num_workers: 4 # number of workers to load data 68 | train_val_split: 1 # 1: use all for train. percentage of data used to train, rest to be used for validation. 69 | train_with_fraction: 1 # 1 for all the data, or a fraction e.g.
0.1 70 | # 71 | enable_board: True 72 | log_info_epoch_period: 1 # period to push small info into the board 73 | log_model_epoch_period: 5 # period to save model 74 | log_fig_epoch_period: 1 # period to push figures into the board 75 | tqdm_prints_disable: False # True: to disable prints of epoch training process 76 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /config/dopamine_spiking_eshel_uchida_code122_kernel011_inferbaseline_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | exp_name: "dopaminespiking_25msbin_kernellength24_kernelnum3_code122_kernel011_inferbase" 3 | data_path: "" # this must be a list of data_path, set to "" if using data_folder as path. 4 | data_folder: "../data/dopamine-spiking-eshel-uchida" # this will look for data in format *trainready.pt 5 | test_data_path: None 6 | ################################# 7 | model_distribution: "binomial" # data distribution: gaussian, binomial, poisson 8 | share_kernels_among_neurons: True # set True to share kernels among neurons 9 | ################################# 10 | # kernel (dictionary) 11 | kernel_normalize: True # True: l2-norm of kernels is set to one after each update 12 | kernel_nonneg: True # True: project kernels into non-negative values 13 | kernel_nonneg_indicator: [0, 1, 1] # list of [0 for +-, 1 for +] or set to None for all + 14 | kernel_num: 3 # number of kernels to learn 15 | kernel_length: 24 # number of samples for kernel in time 16 | kernel_stride: 1 # default 1, convolution stride 17 | kernel_init_smoother: False # flag to init kernels to be smooth 18 | kernel_init_smoother_sigma: 0.5 # sigma of the gaussian kernel for kernel_init_smoother 19 | kernel_smoother: False # flag to apply smoother to the kernel during training 20 | kernel_smoother_penalty_weight: 0 # penalty weight to apply for kernel smoother 21 | kernel_initialization: None # None, or a data path 22 | kernel_initialization_needs_adjustment_of_time_bin_resolution: False 23 | ################################# 24 | # code (representation) 25 | code_nonneg: [1, 2, 2] # apply sign constraint on the code.
list of [1 for pos, -1 for neg, 2 for twosided], or set True/False for all 26 | code_sparse_regularization: 0 # apply sparse (lambda l1-norm) regularization on the code 27 | code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration 28 | code_group_neural_firings_regularization: 0 # if > 0, then it applies grouping across neurons 29 | code_q_regularization: False # set True to apply Q-regularization on the norm of the code 30 | code_q_regularization_matrix: None # The matrix of relations between the codes (if flag is True, use the path to load) 31 | code_q_regularization_matrix_path: "" 32 | code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations 33 | code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term 34 | code_q_regularization_norm_type: 2 # Set to the norm number you want the Q-regularization to be applied 35 | code_supp: True # True: apply known event indices (supp) into code x 36 | code_topk: False # True: keep only top k indices in each kernel code non-zero (this is greedy) 37 | code_topk_sparse: None # number of top k non-zero entries in each code kernel 38 | code_topk_period: None # period on encoder iteration to apply topk 39 | code_l1loss_bp: False # True: to include l1-norm of the code in the loss during training 40 | code_l1loss_bp_penalty_weight: 0 # amount of sparse regularization of the code with bp during training 41 | ################################# 42 | est_baseline_activity: True # True: estimate the baseline activity along with the code in the encoder 43 | poisson_stability_name: None # type of non-linearity to use in the poisson case for encoder stability 44 | poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value 45 | ################################# 46 | # unrolling parameters 47 | unrolling_num: 100 # number of unrolling iterations in the encoder 48 | unrolling_mode: "fista" # ista or fista encoder 49 | unrolling_alpha: 0.1 # alpha step size in unrolling 50 | unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold) 51 | unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold" 52 | ################################# 53 | # training related 54 | # default optimizer is ADAM. 55 | optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels) 56 | optimizer_lr_step: 20 # number of steps (updates) after which the lr will decay 57 | optimizer_lr_decay: 1 # decay factor for learning rate 58 | optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer 59 | optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer 60 | # 61 | backward_gradient_decsent: "bprop" # type of backward gradient update (bprop, truncated_bprop) 62 | backward_truncated_bprop_itr: 10 # must be set for truncated_bprop 63 | # 64 | train_num_epochs: 15 # number of epochs for training 65 | train_data_shuffle: True # True: to shuffle dataset at every epoch for training 66 | train_batch_size: 32 # batch size for training 67 | train_num_workers: 4 # number of workers to load data 68 | train_val_split: 1 # 1: use all for train. percentage of data used to train, rest to be used for validation. 69 | train_with_fraction: 1 # 1 for all the data, or a fraction e.g.
0.1 70 | # 71 | enable_board: True 72 | log_info_epoch_period: 1 # period to push small info into the board 73 | log_model_epoch_period: 5 # period to save model 74 | log_fig_epoch_period: 1 # period to push figures into the board 75 | tqdm_prints_disable: False # True: to disable prints of epoch training process 76 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /dunl/postprocess_scripts/plot_kernels_dopamine_calcium_saramatias_uchida.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bahareh Tolooshams 3 | 4 | plot kernels 5 | 6 | :author: Bahareh Tolooshams 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | import pickle 13 | import argparse 14 | import matplotlib as mpl 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | import sys 19 | 20 | sys.path.append("../dunl/") 21 | 22 | import model 23 | 24 | 25 | def init_params(): 26 | parser = argparse.ArgumentParser(description=__doc__) 27 | 28 | parser.add_argument( 29 | "--res-path", 30 | type=str, 31 | help="res path", 32 | default="../results/dopaminecalcium_kernellength60_kernelnum5_code2211n1_kernel00011_qreg_fixedq_2p5_firstshrinkage_2023_09_27_01_17_09", 33 | ) 34 | parser.add_argument( 35 | "--sampling-rate", 36 | type=int, 37 | help="sampling rate", 38 | default=15, 39 | ) 40 | parser.add_argument( 41 | "--color-list", 42 | type=list, 43 | help="color list", 44 | default=[ 45 | "black", 46 | "orange", 47 | "blue", 48 | "red", 49 | "green", 50 | ], # cue reg, cue exp, 3 rewards 51 | ) 52 | parser.add_argument( 53 | "--figsize", 54 | type=tuple, 55 | help="figsize", 56 | default=(8, 2), 57 | ) 58 | 59 | args = parser.parse_args() 60 | params = vars(args) 61 | 62 | return params 63 | 64 | 65 | def main(): 66 | print("Predict.") 67 | 68 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 69 | print("device is", device) 70 | 71 | # init parameters -------------------------------------------------------# 72 | print("init parameters.") 73 | params_init = init_params() 74 | 75 | # take parameters from the result path 76 | params = pickle.load( 77 | open(os.path.join(params_init["res_path"], "params.pickle"), "rb") 78 | ) 79 | for key in params_init.keys(): 80 | params[key] = params_init[key] 81 | 82 | # create folders -------------------------------------------------------# 83 | model_path = os.path.join( 84 | params["res_path"], 85 | "model", 86 | "model_final.pt", 87 | ) 88 | 89 | out_path = os.path.join( 90 | params["res_path"], 91 | "figures", 92 | ) 93 | if not os.path.exists(out_path): 94 | os.makedirs(out_path) 95 | 96 | # load model ------------------------------------------------------# 97 | net = torch.load(model_path, map_location=device) 98 | net.to(device) 99 | net.eval() 100 | 101 | kernels = np.squeeze(net.get_param("H").clone().detach().cpu().numpy()) 102 | 103 | plot_kernel(kernels, params, out_path) 104 | 105 | 106 | def plot_kernel(kernels, params, out_path): 107 | axes_fontsize = 15 108 | legend_fontsize = 8 109 | tick_fontsize = 15 110 | title_fontsize = 20 111 | fontfamily = "sans-serif" 112 | 113 | # update plot parameters 114 | # style 115 | mpl.rcParams.update( 116 | { 117 | "pgf.texsystem": "pdflatex", 118 | "text.usetex": True, 119 | "axes.labelsize": axes_fontsize, 120 | "axes.titlesize": title_fontsize, 121 | "legend.fontsize": legend_fontsize, 122 | "xtick.labelsize":
tick_fontsize, 123 | "ytick.labelsize": tick_fontsize, 124 | "text.latex.preamble": r"\usepackage{bm}", 125 | "axes.unicode_minus": False, 126 | "font.family": fontfamily, 127 | } 128 | ) 129 | 130 | fig, axn = plt.subplots(1, 4, sharex=True, sharey=True, figsize=params["figsize"]) 131 | 132 | for ax in axn.flat: 133 | ax.tick_params(axis="x", direction="out") 134 | ax.tick_params(axis="y", direction="out") 135 | ax.spines["right"].set_visible(False) 136 | ax.spines["top"].set_visible(False) 137 | 138 | t = np.linspace( 139 | 0, params["kernel_length"] / params["sampling_rate"], params["kernel_length"] 140 | ) 141 | 142 | for ctr in range(params["kernel_num"] - 1): 143 | plt.subplot(1, 4, ctr + 1) 144 | axn[ctr].axhline(0, color="gray", lw=0.3) 145 | 146 | plt.plot(t, kernels[ctr], color=params["color_list"][ctr]) 147 | if ctr == 3: 148 | plt.plot(t, kernels[ctr + 1], color=params["color_list"][ctr + 1]) 149 | 150 | if ctr == 0: 151 | plt.title(r"$\textbf{Cue\ Regret}$") 152 | elif ctr == 1: 153 | plt.title(r"$\textbf{Cue\ Expected}$") 154 | elif ctr == 2: 155 | plt.title(r"$\textbf{Reward}$") 156 | else: 157 | plt.title(r"$\textbf{Reward\ Coupled}$") 158 | xtic = np.array([0, 0.5, 1]) * params["kernel_length"] / params["sampling_rate"] 159 | plt.xticks(xtic, xtic) 160 | 161 | if ctr == 1: 162 | plt.xlabel("Time [s]", labelpad=0) 163 | 164 | fig.tight_layout(pad=0.8, w_pad=0.7, h_pad=0.5) 165 | plt.savefig( 166 | os.path.join(out_path, "kernels.svg"), 167 | bbox_inches="tight", 168 | pad_inches=0.02, 169 | ) 170 | plt.close() 171 | 172 | print(f"plotting of kernels is done. plots are saved at {out_path}") 173 | 174 | 175 | if __name__ == "__main__": 176 | main() 177 | -------------------------------------------------------------------------------- /config/dopamine_spiking_eshel_uchida_code122_kernel011_inferbaseline_independentkernelsamongneurons_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | exp_name: "dopaminespiking_25msbin_kernellength24_kernelnum3_code122_kernel011_inferbase_independentkernels_kernelsmoothing_0p0005" 3 | data_path: "" # this must be a list of data_path, set to "" if using data_folder as path. 
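An empty data_path, as here, defers to the data_folder entry that follows; the postprocess scripts above implement exactly that convention, and their lookup condenses to the following (illustrative function name, logic taken from save_data_for_pcanmf_dopamine_calcium_saramatias_uchida.py):

import os

def resolve_data_paths(params):
    # empty data_path -> scan data_folder for files named *trainready.pt,
    # otherwise use the given list of paths (as in the postprocess scripts)
    if params["data_path"] == "":
        folder = params["data_folder"]
        return [f"{folder}/{name}" for name in os.listdir(folder) if "trainready.pt" in name]
    return params["data_path"]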
4 | data_folder: "../data/dopamine-spiking-eshel-uchida" # this will look for data in format *trainready.pt 5 | test_data_path: None 6 | ################################# 7 | model_distribution: "binomial" # data distribution: gaussian, binomial, poisson 8 | share_kernels_among_neurons: False # set true to share kernels among neurons 9 | ################################# 10 | # kernel (dictionary) 11 | kernel_normalize: True # True: l2-norm of kernels is set to one after each update 12 | kernel_nonneg: True # True: project kernels into non-negative values 13 | kernel_nonneg_indicator: [0, 1, 1] # list of [0 for +-, 1 for +] or set to None for all + 14 | kernel_num: 3 # number of kernels to learn 15 | kernel_length: 24 # number of samples for kernel in time 16 | kernel_stride: 1 # default 1, convolution stride 17 | kernel_init_smoother: False # flag to init kernels to be smooth 18 | kernel_init_smoother_sigma: 0.5 # sigma of the gaussian kernel for kernel_init_smoother 19 | kernel_smoother: True # flag to apply smoother to the kernel during training 20 | kernel_smoother_penalty_weight: 0.0005 # penalty weight to apply for kernel smoother 21 | kernel_initialization: None # None, or a data path 22 | kernel_initialization_needs_adjustment_of_time_bin_resolution: False 23 | ################################# 24 | # code (representation) 25 | code_nonneg: [1, 2, 2] # apply sign constraint on the code. list of [1 for pos, -1 for neg, 2 for twosided], or set True/False for all 26 | code_sparse_regularization: 0 # apply sparse (lambda l1-norm) regularization on the code 27 | code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration 28 | code_group_neural_firings_regularization: 0 # if > 0, then it applies grouping across neurons 29 | code_q_regularization: False # set True to apply Q-regularization on the norm of the code 30 | code_q_regularization_matrix: None # The matrix of relations between the codes (if flag is True, use the path to load) 31 | code_q_regularization_matrix_path: "" 32 | code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations 33 | code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term 34 | code_q_regularization_norm_type: 2 # Set to the norm number you want the Q-regularization to be applied 35 | code_supp: True # True: apply known event indices (supp) into code x 36 | code_topk: False # True: keep only top k indices in each kernel code non-zero (this is greedy) 37 | code_topk_sparse: None # number of top k non-zero entries in each code kernel 38 | code_topk_period: None # period on encoder iteration to apply topk 39 | code_l1loss_bp: False # True: to include l1-norm of the code in the loss during training 40 | code_l1loss_bp_penalty_weight: 0 # amount of sparse regularization of the code with bp during training 41 | ################################# 42 | est_baseline_activity: True # True: estimate the baseline activity along with the code in the encoder 43 | poisson_stability_name: None # type of non-linearity to use on poisson case for encoder stability 44 | poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value 45 | ################################# 46 | # unrolling parameters 47 | unrolling_num: 100 # number of unrolling iterations in the encoder 48 | unrolling_mode: "fista" # ista or fista encoder 49 | unrolling_alpha: 0.1 # alpha step size in unrolling 50 | unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold) 51 |
unrolling_threshold: None # must set to a value if unrolling_prox is "threshold" 52 | ################################# 53 | # training related 54 | # default optimizer is ADAM. 55 | optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels) 56 | optimizer_lr_step: 1000 # number of steps (updates) after which the lr will decay 57 | optimizer_lr_decay: 1 # decay factor for learning rate 58 | optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer 59 | optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer 60 | # 61 | backward_gradient_decsent: "bprop" # type of backward gradient update (bprop, truncated_bprop) 62 | backward_truncated_bprop_itr: 10 # must be set for truncated_bprop 63 | # 64 | train_num_epochs: 600 # number of epochs for training 65 | train_data_shuffle: True # True: to shuffle dataset at every epoch for training 66 | train_batch_size: 32 # batch size for training 67 | train_num_workers: 4 # number of workers to load data 68 | train_val_split: 1 # 1: use all for train. percentage of data used to train, rest to be used for validation. 69 | train_with_fraction: 1 # 1 for all the data, or a fraction e.g. 0.1 70 | # 71 | enable_board: True 72 | log_info_epoch_period: 40 # period to push small info into the board 73 | log_model_epoch_period: 3000 # period to save model 74 | log_fig_epoch_period: 40 # period to push figures into the board 75 | tqdm_prints_disable: False # True: to disable prints of epoch training process 76 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /config/whisker_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | exp_name: "whisker_05msbinres_lamp03_topk16_smoothkernelp003" 3 | data_path: ["../data/whisker/whisker_train_5msbinres_general_format_processed_kernellength25_kernelnum1_trainready.pt"] # give it a list of datasets 4 | data_folder: None # this will look for data in format *trainready.pt 5 | 6 | test_data_path: ["../data/whisker/whisker_test_5msbinres_general_format_processed_kernellength25_kernelnum1_trainready.pt"] # give it a list of datasets 7 | # data_path: ["../data/whisker/whisker_train_10msbinres_general_format_processed_kernellength12_kernelnum1_trainready.pt"] # give it a list of datasets 8 | # test_data_path: ["../data/whisker/whisker_test_10msbinres_general_format_processed_kernellength12_kernelnum1_trainready.pt"] # give it a list of datasets 9 | ################################# 10 | model_distribution: "binomial" # data distribution: gaussian, binomial, poisson 11 | share_kernels_among_neurons: True # set true to share kernels among neurons 12 | ################################# 13 | # kernel (dictionary) 14 | kernel_normalize: True # True: l2-norm of kernels is set to one after each update 15 | kernel_nonneg: False # True: project kernels into non-negative values 16 | kernel_nonneg_indicator: [0] # 0 for +-, 1 for + 17 | kernel_num: 1 # number of kernels to learn 18 | kernel_length: 25 # number of samples for kernel in time 19 | kernel_stride: 1 # default 1, convolution stride 20 | kernel_init_smoother: False # flag to init kernels to be smooth 21 | kernel_init_smoother_sigma: 0.5 # sigma of the gaussian kernel for kernel_init_smoother 22 | kernel_smoother: True # flag to apply smoother to the kernel during training 23 | kernel_smoother_penalty_weight: 0.003 # penalty weight to apply for kernel
smoother 24 | kernel_initialization: "../data/whisker/kernel_init_25.pt" # None, or a data path 25 | ################################# 26 | # code (representation) 27 | code_nonneg: [1] # apply sign constraint on the code. 1 for pos, -1 for neg, 2 for twosided 28 | code_sparse_regularization: 0.03 # apply sparse (lambda l1-norm) regularization on the code 29 | code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration 30 | code_q_regularization: False # set True to apply Q-regularization on the norm of the code 31 | code_q_regularization_matrix: None # The matrix of relations between the codes (if flag is True, use the path to load) 32 | code_q_regularization_matrix_path: None 33 | code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations 34 | code_q_regularization_scale: 5 # scale factor in front of the Q-regularization term 35 | code_q_regularization_norm_type: 2 # Set to the norm number you want the Q-regularization to be applied 36 | code_supp: False # True: apply known event indices (supp) into code x 37 | code_topk: True # True: keep only top k indices in each kernel code non-zero (this is greedy) 38 | code_topk_sparse: 16 # number of top k non-zero entries in each code kernel 39 | code_topk_period: 10 # period on encoder iteration to apply topk 40 | code_l1loss_bp: True # True: to include l1-norm of the code in the loss during training 41 | code_l1loss_bp_penalty_weight: 0.03 # amount of sparse regularization of the code with bp during training 42 | ################################# 43 | est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder 44 | poisson_stability_name: None # type of non-linearity to use on poisson case for encoder stability 45 | poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value 46 | ################################# 47 | # unrolling parameters 48 | unrolling_num: 800 # number of unrolling iterations in the encoder 49 | unrolling_mode: "fista" # ista or fista encoder 50 | unrolling_alpha: 0.5 # alpha step size in unrolling 51 | unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold) 52 | unrolling_threshold: None # must set to a value if unrolling_prox is "threshold" 53 | ################################# 54 | # training related 55 | # default optimizer is ADAM. 56 | optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels) 57 | optimizer_lr_step: 1000 # number of steps (updates) after which the lr will decay 58 | optimizer_lr_decay: 1 # decay factor for learning rate 59 | optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer 60 | optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer 61 | # 62 | backward_gradient_decsent: "truncated_bprop" # type of backward gradient update (bprop, truncated_bprop) 63 | backward_truncated_bprop_itr: 20 # must be set for truncated_bprop 64 | # 65 | train_num_epochs: 120 # number of epochs for training 66 | train_data_shuffle: True # True: to shuffle dataset at every epoch for training 67 | train_batch_size: 30 # batch size for training 68 | train_num_workers: 4 # number of workers to load data 69 | train_val_split: 1 # 1: use all for train. percentage of data used to train, rest to be used for validation.
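  # (editor's note, illustrative only) e.g. train_val_split: 0.8 would train on 80% of the trials and hold out the remaining 20% for validation; 1 disables the split.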
70 | # 71 | enable_board: True 72 | log_info_epoch_period: 1 # period to push small info into the board 73 | log_model_epoch_period: 200 # period to save model 74 | log_fig_epoch_period: 20 # period to push figures into the board 75 | tqdm_prints_disable: False # True: to disable prints of epoch training process 76 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /config/dopamine_calcium_saramatias_uchida_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | exp_name: "dopaminecalcium_kernellength60_kernelnum5_code2211n1_kernel00011_qreg_fixedq_2p5_firstshrinkage" 3 | data_path: [ 4 | "../data/dopamine-calcium-saramatias-uchida/VarMag_SM103_20191104_general_format_processed_kernellength60_kernelnum5_trainready.pt", # 20 neurons, 299 trials 5 | "../data/dopamine-calcium-saramatias-uchida/VarMag_SM99_20191109_general_format_processed_kernellength60_kernelnum5_trainready.pt", # 30 neurons, 195 trials 6 | "../data/dopamine-calcium-saramatias-uchida/VarMag_SM104_20191103_general_format_processed_kernellength60_kernelnum5_trainready.pt", # 6 neurons, 252 trials 7 | ] # must be a list of datasets 8 | data_folder: None 9 | test_data_path: None 10 | ################################# 11 | model_distribution: "gaussian" # data distribution: gaussian, binomial, poisson 12 | share_kernels_among_neurons: True # set true to share kernels among neurons 13 | ################################# 14 | # kernel (dictionary) 15 | kernel_normalize: True # True: l2-norm of kernels is set to one after each update 16 | kernel_nonneg: True # True: project kernels into non-negative values 17 | kernel_nonneg_indicator: [0, 0, 0, 1, 1] # 0 for +-, 1 for + 18 | kernel_num: 5 # number of kernels to learn 19 | kernel_length: 60 # number of samples for kernel in time 20 | kernel_stride: 1 # default 1, convolution stride 21 | kernel_init_smoother: False # flag to init kernels to be smooth 22 | kernel_init_smoother_sigma: 0.2 # sigma of the gaussian kernel for kernel_init_smoother 23 | kernel_smoother: False # flag to apply smoother to the kernel during training 24 | kernel_smoother_penalty_weight: 0 # penalty weight to apply for kernel smoother 25 | kernel_initialization: None # None, or a data path 26 | ################################# 27 | # code (representation) 28 | code_nonneg: [2, 2, 1, 1, -1] # apply sign constraint on the code.
1 for pos, -1 for neg, 2 for twosided 29 | code_sparse_regularization: 0 # apply sparse (lambda l1-norm) regularization on the code 30 | code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration 31 | code_q_regularization: True # set True to apply Q-regularization on the norm of the code 32 | code_q_regularization_matrix: None # The matrix of relations between the codes (use the path to load) 33 | code_q_regularization_matrix_path: "../data/dopamine-calcium-saramatias-uchida/code_q_regularization_matrix.pt" 34 | code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations 35 | code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term 36 | code_q_regularization_norm_type: 2 # Set to the norm number you want the Q-regularization to be applied 37 | code_supp: True # True: apply known event indices (supp) into code x 38 | code_topk: False # True: keep only top k indices in each kernel code non-zero (this is greedy) 39 | code_topk_sparse: None # number of top k non-zero entries in each code kernel 40 | code_topk_period: None # period on encoder iteration to apply topk 41 | code_l1loss_bp: False # True: to include l1-norm of the code in the loss during training 42 | code_l1loss_bp_penalty_weight: 0 # amount of sparse regularization of the code with bp during training 43 | ################################# 44 | est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder 45 | poisson_stability_name: None # type of non-linearity to use on poisson case for encoder stability 46 | poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value 47 | ################################# 48 | # unrolling parameters 49 | unrolling_num: 100 # number of unrolling iterations in the encoder 50 | unrolling_mode: "fista" # ista or fista encoder 51 | unrolling_alpha: 0.1 # alpha step size in unrolling 52 | unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold) 53 | unrolling_threshold: None # must set to a value if unrolling_prox is "threshold" 54 | ################################# 55 | # training related 56 | # default optimizer is ADAM. 57 | optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels) 58 | optimizer_lr_step: 20 # number of steps (updates) after which the lr will decay 59 | optimizer_lr_decay: 1 # decay factor for learning rate 60 | optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer 61 | optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer 62 | # 63 | backward_gradient_decsent: "bprop" # type of backward gradient update (bprop, truncated_bprop) 64 | backward_truncated_bprop_itr: 10 # must be set for truncated_bprop 65 | # 66 | train_num_epochs: 15 # number of epochs for training 67 | train_data_shuffle: True # True: to shuffle dataset at every epoch for training 68 | train_batch_size: 8 # batch size for training 69 | train_num_workers: 4 # number of workers to load data 70 | train_val_split: 1 # 1: use all for train. percentage of data used to train, rest to be used for validation. 71 | train_with_fraction: 1 # 1 for all the data, or a fraction e.g.
0.1 72 | # 73 | enable_board: True 74 | log_info_epoch_period: 1 # period to push small info into the board 75 | log_model_epoch_period: 10 # period to save model 76 | log_fig_epoch_period: 1 # period to push figures into the board 77 | tqdm_prints_disable: False # True: to disable prints of epoch training process 78 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /dunl/postprocess_scripts/plot_kernels_whisker_thalamus.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bahareh Tolooshams 3 | 4 | plot kernels for the whisker data 5 | 6 | :author: Bahareh Tolooshams 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | import pickle 13 | import argparse 14 | import matplotlib as mpl 15 | import matplotlib.pyplot as plt 16 | 17 | import sys 18 | 19 | sys.path.append("../dunl/") 20 | 21 | 22 | def init_params(): 23 | parser = argparse.ArgumentParser(description=__doc__) 24 | 25 | parser.add_argument( 26 | "--res-path", 27 | type=str, 28 | help="res path", 29 | default="../results/whisker_05msbinres_lamp03_topk18_smoothkernelp003_2023_07_19_00_03_18", 30 | # default="../results//whisker_05msbinres_lamp03_topk16_smoothkernelp003_2023_07_20_23_11_21", 31 | ) 32 | parser.add_argument( 33 | "--color-list", 34 | type=list, 35 | help="color list", 36 | default=[ 37 | "cyan", 38 | ], # learning one kernel 39 | ) 40 | parser.add_argument( 41 | "--figsize", 42 | type=tuple, 43 | help="figsize", 44 | default=(1.6, 2), 45 | ) 46 | 47 | args = parser.parse_args() 48 | params = vars(args) 49 | 50 | return params 51 | 52 | 53 | def main(): 54 | print("Predict.") 55 | 56 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 57 | print("device is", device) 58 | 59 | # init parameters -------------------------------------------------------# 60 | print("init parameters.") 61 | params_init = init_params() 62 | 63 | # take parameters from the result path 64 | params = pickle.load( 65 | open(os.path.join(params_init["res_path"], "params.pickle"), "rb") 66 | ) 67 | for key in params_init.keys(): 68 | params[key] = params_init[key] 69 | 70 | if params["data_path"] == "": 71 | data_folder = params["data_folder"] 72 | filename_list = os.listdir(data_folder) 73 | data_path_list = [ 74 | f"{data_folder}/{x}" for x in filename_list if "trainready.pt" in x 75 | ] 76 | else: 77 | data_path_list = params["data_path"] 78 | 79 | print("There are {} dataset(s) in the folder.".format(len(data_path_list))) 80 | 81 | # set time bin resolution -----------------------------------------------# 82 | data_dict = torch.load(data_path_list[0]) 83 | params["time_bin_resolution"] = data_dict["time_bin_resolution"] 84 | # create folders -------------------------------------------------------# 85 | model_path = os.path.join( 86 | params["res_path"], 87 | "model", 88 | "model_final.pt", 89 | ) 90 | 91 | out_path = os.path.join( 92 | params["res_path"], 93 | "figures", 94 | ) 95 | if not os.path.exists(out_path): 96 | os.makedirs(out_path) 97 | 98 | # load model ------------------------------------------------------# 99 | net = torch.load(model_path, map_location=device) 100 | net.to(device) 101 | net.eval() 102 | 103 | kernels = np.squeeze(net.get_param("H").clone().detach().cpu().numpy(), axis=1) 104 | t = np.linspace( 105 | 0, 106 | params["kernel_length"] * params["time_bin_resolution"], 107 | params["kernel_length"], 108 | )
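    # (editor's note, illustrative only) with kernel_length=25 bins at a time_bin_resolution of 5 ms, t spans 0..125 ms over 25 samples.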
109 | 110 | # plot configuration -------------------------------------------------------# 111 | 112 | axes_fontsize = 10 113 | legend_fontsize = 8 114 | tick_fontsize = 10 115 | title_fontsize = 10 116 | 117 | # update plot parameters 118 | # style 119 | mpl.rcParams.update( 120 | { 121 | "pgf.texsystem": "pdflatex", 122 | "text.usetex": True, 123 | "axes.labelsize": axes_fontsize, 124 | "axes.titlesize": title_fontsize, 125 | "legend.fontsize": legend_fontsize, 126 | "xtick.labelsize": tick_fontsize, 127 | "ytick.labelsize": tick_fontsize, 128 | "text.latex.preamble": r"\usepackage{bm}", 129 | "axes.unicode_minus": False, 130 | } 131 | ) 132 | 133 | # plot -------------------------------------------------------# 134 | fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=params["figsize"]) 135 | 136 | ax.tick_params(axis="x", direction="out") 137 | ax.tick_params(axis="y", direction="out") 138 | ax.spines["right"].set_visible(False) 139 | ax.spines["top"].set_visible(False) 140 | 141 | for ctr in range(params["kernel_num"]): 142 | plt.subplot(1, 1, ctr + 1) 143 | ax.axhline(0, color="gray", lw=0.3) 144 | 145 | plt.plot(t, kernels[ctr], color=params["color_list"][ctr]) 146 | 147 | print(t) 148 | stim = np.sin(2 * np.pi * (1 / 125 * t)) 149 | stim /= np.linalg.norm(stim) 150 | plt.plot(t, stim, color="gray", lw=0.5) 151 | xtic = ( 152 | np.array([0, 0.5, 1]) 153 | * params["kernel_length"] 154 | * params["time_bin_resolution"] 155 | ) 156 | xtic = [int(x) for x in xtic] 157 | plt.xticks(xtic, xtic) 158 | plt.xlabel("Time [ms]", labelpad=0) 159 | 160 | fig.tight_layout(pad=0.8, w_pad=0.7, h_pad=0.5) 161 | plt.savefig( 162 | os.path.join(out_path, "kernels.svg"), 163 | bbox_inches="tight", 164 | pad_inches=0.02, 165 | ) 166 | plt.close() 167 | 168 | print(f"plotting of kernels is done.
plots are saved at {out_path}") 169 | 170 | 171 | if __name__ == "__main__": 172 | main() 173 | -------------------------------------------------------------------------------- /config/dopamine_fiberphotometry_saramatias_uchida_config_1window_1kernel.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | exp_name: "fiber" 3 | data_path: [ 4 | "../data/fiberphotometry-saramatias-uchida/Gauntlet_FreeBeh_SM156_20221012_general_format_processed_trainready.pt", # 5 | ] # must be a list of datasets 6 | data_folder: None 7 | test_data_path: None 8 | ################################# 9 | ##### these are about the data 10 | number_of_window: 19 #19 11 | neuron_index: nan # it starts from 0 12 | ################################# 13 | model_distribution: "gaussian" # data distribution: gaussian, binomial, poisson 14 | share_kernels_among_neurons: False # set true to share kernels among neurons 15 | ################################# 16 | # kernel (dictionary) 17 | kernel_normalize: True # True: l2-norm of kernels is set to one after each update 18 | kernel_nonneg: True # True: project kernels into non-negative values 19 | kernel_nonneg_indicator: [1] # 0 for +-, 1 for + 20 | kernel_num: 1 # number of kernels to learn 21 | kernel_length: 30 # number of samples for kernel in time (20Hz acquisition) 30 for VS fibers, 20 for DS 22 | kernel_stride: 1 # default 1, convolution stride 23 | kernel_init_smoother: False # flag to init kernels to be smooth 24 | kernel_init_smoother_sigma: 1 # sigma of the gaussian kernel for kernel_init_smoother 25 | kernel_smoother: True # flag to apply smoother to the kernel during training ********* 26 | kernel_smoother_penalty_weight: 0.002 # 0.01 # penalty weight to apply for kernel smoother 27 | kernel_initialization: None # None, or a data path 28 | kernel_initialization_needs_adjustment_of_time_bin_resolution: False 29 | ################################# 30 | # code (representation) 31 | code_nonneg: [1] # apply sign constraint on the code.
1 for pos, -1 for neg, 2 for twosided 32 | code_sparse_regularization: 0.25 # apply sparse (lambda l1-norm) regularization on the code - default: 0.05 33 | code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration 34 | code_group_neural_firings_regularization: 0 # if > 0, then it applies grouping across neurons 35 | code_q_regularization: False # set True to apply Q-regularization on the norm of the code 36 | code_q_regularization_matrix: None # The matrix of relations between the codes (use the path to load) 37 | code_q_regularization_matrix_path: None 38 | code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations 39 | code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term 40 | code_q_regularization_norm_type: 2 # Set to the norm number you want the Q-regularization to be applied 41 | code_supp: False # True: apply known event indices (supp) into code x 42 | code_topk: False # True: keep only top k indices in each kernel code non-zero (this is greedy) 43 | code_topk_sparse: 100 # number of top k non-zero entries in each code kernel - default 20 44 | code_topk_period: 300 # period on encoder iteration to apply topk 45 | code_l1loss_bp: True # True: to include l1-norm of the code in the loss during training 46 | code_l1loss_bp_penalty_weight: 0.25 # amount of sparse regularization of the code with bp during training 47 | ################################# 48 | est_baseline_activity: True # True: estimate the baseline activity along with the code in the encoder 49 | poisson_stability_name: None # type of non-linearity to use on poisson case for encoder stability 50 | poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value 51 | ################################# 52 | # unrolling parameters 53 | unrolling_num: 1000 # number of unrolling iterations in the encoder 54 | unrolling_mode: "fista" # ista or fista encoder 55 | unrolling_alpha: 0.01 # alpha step size in unrolling 56 | unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold) 57 | unrolling_threshold: None # must set to a value if unrolling_prox is "threshold" 58 | ################################# 59 | # training related 60 | # default optimizer is ADAM. 61 | optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels) 62 | optimizer_lr_step: 20 # number of steps (updates) after which the lr will decay 63 | optimizer_lr_decay: 1 # decay factor for learning rate 64 | optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer 65 | optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer 66 | # 67 | backward_gradient_decsent: "bprop" # type of backward gradient update (bprop, truncated_bprop) 68 | backward_truncated_bprop_itr: 10 # must be set for truncated_bprop; this can be increased if topk is False - then we leverage backprop, which is lost when topk is True 69 | # 70 | train_num_epochs: 1000 # number of epochs for training - default : 1000 71 | train_data_shuffle: True # True: to shuffle dataset at every epoch for training 72 | train_batch_size: 32 # batch size for training 73 | train_num_workers: 8 # number of workers to load data 74 | train_val_split: 1 # 1: use all for train. percentage of data used to train, rest to be used for validation. 75 | train_with_fraction: 1 # 1 for all the data, or a fraction e.g.
0.1 76 | # 77 | enable_board: True 78 | log_info_epoch_period: 10 # period to push small info into the board 79 | log_model_epoch_period: 50 # period to save model 80 | log_fig_epoch_period: 10 # period to push figures into the board 81 | tqdm_prints_disable: False # True: to disable prints of epoch training process 82 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /dunl/postprocess_scripts/plot_report_r2_local_deconv_on_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bahareh Tolooshams 3 | 4 | report r2 of local deconv on test data 5 | 6 | :author: Bahareh Tolooshams 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | import pickle 13 | import argparse 14 | 15 | import sys 16 | 17 | sys.path.append("../dunl/") 18 | 19 | import datasetloader, utils 20 | 21 | 22 | def init_params(): 23 | parser = argparse.ArgumentParser(description=__doc__) 24 | 25 | parser.add_argument( 26 | "--res-path-partial", 27 | type=str, 28 | help="res path partial", 29 | default="../results/2000_1sparse_local_deconv_calscenario_shorttrial_structured", 30 | # default="../results/6000_3sparse_local_deconv_calscenario_longtrial", 31 | ) 32 | parser.add_argument( 33 | "--batch-size", 34 | type=int, 35 | help="batch size", 36 | default=128, 37 | ) 38 | parser.add_argument( 39 | "--num-workers", 40 | type=int, 41 | help="number of workers for dataloader", 42 | default=4, 43 | ) 44 | parser.add_argument( 45 | "--color-list", 46 | type=list, 47 | help="color decomposition list", 48 | default=[ 49 | "blue", 50 | "red", 51 | ], # 52 | ) 53 | parser.add_argument( 54 | "--swap-kernel", 55 | type=bool, 56 | help="bool to swap kernel", 57 | default=True, 58 | ) 59 | 60 | args = parser.parse_args() 61 | params = vars(args) 62 | 63 | return params 64 | 65 | 66 | def compute_r2_score(spikes, rate_hat): 67 | # compute r2 score 68 | ss_res = np.mean((spikes - rate_hat) ** 2, axis=1)  # mean squared residual per trial 69 | ss_tot = np.var(spikes) 70 | 71 | r2_fit = 1 - ss_res / ss_tot 72 | 73 | return np.mean(r2_fit) 74 | 75 | 76 | def main(): 77 | print("Predict.") 78 | 79 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 80 | print("device is", device) 81 | 82 | for trials in [25, 50, 100, 200, 400, 800, 1600]: 83 | # init parameters -------------------------------------------------------# 84 | print("init parameters.") 85 | params_init = init_params() 86 | 87 | # this is to make sure the inference runs on the full eshel data 88 | if ( 89 | params_init["res_path_partial"] 90 | == f"../results/2000_1sparse_local_deconv_calscenario_shorttrial_structured" 91 | ): 92 | params_init[ 93 | "res_path" 94 | ] = f"../results/2000_1sparse_local_deconv_calscenario_shorttrial_structured_{trials}trials_25msbinres_kernellength16_kernelnum2_lam0.1_lamloss0.1_lamdecay1_code_topkTruesparse1period10_kernelsmooth0.015_knownsuppFalse_2023_10_27_07_45_35" 95 | 96 | params_init["test_data_path"] = [ 97 | "../data/local-deconv-calscenario-shorttrial-structured-simulated/test_simulated_1neurons_500trials_25msbinres_8Hzbaseline_nov_general_format_processed_kernellength16_kernelnum2_trainready.pt" 98 | ] 99 | 100 | elif ( 101 | params_init["res_path_partial"] 102 | == f"../results/6000_3sparse_local_deconv_calscenario_longtrial" 103 | ): 104 | params_init[ 105 | "res_path" 106 | ] =
f"../results/6000_3sparse_local_deconv_calscenario_longtrial_{trials}trials_lam0.1_lamloss0.1_lamdecay1_code_topkTruesparse3period10_kernelsmooth0.015_knownsuppFalse_2023_11_02_17_48_18" 107 | params_init["test_data_path"] = [ 108 | "../data/local-deconv-calscenario-longtrial-simulated/test_simulated_1neurons_500trials_25msbinres_8Hzbaseline_long_general_format_processed_kernellength16_kernelnum2_trainready.pt" 109 | ] 110 | 111 | # take parameters from the result path 112 | params = pickle.load( 113 | open(os.path.join(params_init["res_path"], "params.pickle"), "rb") 114 | ) 115 | for key in params_init.keys(): 116 | params[key] = params_init[key] 117 | 118 | postprocess_path = os.path.join( 119 | params["res_path"], 120 | "postprocess", 121 | ) 122 | 123 | data_path_list = params["test_data_path"] 124 | 125 | print("There {} dataset in the folder.".format(len(data_path_list))) 126 | 127 | # set time bin resolution -----------------------------------------------# 128 | data_dict = torch.load(data_path_list[0]) 129 | params["time_bin_resolution"] = data_dict["time_bin_resolution"] 130 | 131 | # create datasets -------------------------------------------------------# 132 | dataset = datasetloader.DUNLdatasetwithRasterWithCodeRate( 133 | params["test_data_path"][0] 134 | ) 135 | datafile_name = params["test_data_path"][0].split("/")[-1].split(".pt")[0] 136 | 137 | # create folders -------------------------------------------------------# 138 | 139 | codes = dataset.codes 140 | rate = dataset.rate 141 | 142 | # train_num_trials = len(dataset) 143 | 144 | y = torch.load( 145 | os.path.join(postprocess_path, "test_y_{}.pt".format(datafile_name)) 146 | ) 147 | y = y[:, 0, :] 148 | 149 | rate_hat = torch.load( 150 | os.path.join(postprocess_path, "test_ratehat_{}.pt".format(datafile_name)) 151 | ) 152 | rate_hat = rate_hat[:, 0, :] 153 | 154 | r2_score = compute_r2_score(y.numpy(), rate_hat.numpy()) 155 | print(f"DUNL trial {trials}, r2 {r2_score}") 156 | 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /dunl/postprocess_scripts/save_data_for_pcanmf_dopamine_spiking_eshel_uchida.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bahareh Tolooshams 3 | 4 | plot code data 5 | 6 | :author: Bahareh Tolooshams 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | import pickle 13 | import argparse 14 | 15 | 16 | def init_params(): 17 | parser = argparse.ArgumentParser(description=__doc__) 18 | 19 | parser.add_argument( 20 | "--res-path", 21 | type=str, 22 | help="res path", 23 | default="../results/dopaminespiking_25msbin_kernellength24_kernelnum3_codefree_kernel111_2023_07_14_12_37_30", 24 | ) 25 | parser.add_argument( 26 | "--reward-amount-list", 27 | type=list, 28 | help="reward amount list", 29 | default=[0.1, 0.3, 1.2, 2.5, 5.0, 10.0, 20.0], 30 | ) 31 | parser.add_argument( 32 | "--window-dur", 33 | type=int, 34 | help="window duration to get average activity", 35 | default=24, # this is after time bin resolution 36 | ) 37 | parser.add_argument( 38 | "--save-only-sur", 39 | type=bool, 40 | help="save only surprise trials", 41 | default=False, 42 | ) 43 | args = parser.parse_args() 44 | params = vars(args) 45 | 46 | return params 47 | 48 | 49 | def main(): 50 | print("Predict.") 51 | 52 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 53 | print("device is", device) 54 | 55 | # init parameters 
-------------------------------------------------------# 56 | print("init parameters.") 57 | params_init = init_params() 58 | 59 | # take parameters from the result path 60 | params = pickle.load( 61 | open(os.path.join(params_init["res_path"], "params.pickle"), "rb") 62 | ) 63 | for key in params_init.keys(): 64 | params[key] = params_init[key] 65 | 66 | if params["data_path"] == "": 67 | data_folder = params["data_folder"] 68 | filename_list = os.listdir(data_folder) 69 | data_path_list = [ 70 | f"{data_folder}/{x}" for x in filename_list if "trainready.pt" in x 71 | ] 72 | else: 73 | data_path_list = params["data_path"] 74 | 75 | print("There are {} dataset(s) in the folder.".format(len(data_path_list))) 76 | 77 | # create folders -------------------------------------------------------# 78 | 79 | postprocess_path = os.path.join( 80 | params["res_path"], 81 | "postprocess", 82 | ) 83 | 84 | # load data -------------------------------------------------------# 85 | 86 | for data_path in data_path_list: 87 | datafile_name = data_path.split("/")[-1].split(".pt")[0] 88 | 89 | y = torch.load(os.path.join(postprocess_path, "y_{}.pt".format(datafile_name))) 90 | x = torch.load(os.path.join(postprocess_path, "x_{}.pt".format(datafile_name))) 91 | label_int = torch.load( 92 | os.path.join(postprocess_path, "label_{}.pt".format(datafile_name)) 93 | ) 94 | 95 | label = label_int.clone() 96 | tmp_ctr = 0 97 | for reward in params["reward_amount_list"]: 98 | tmp_ctr += 1 99 | label[label == tmp_ctr] = reward 100 | 101 | num_trials = y.shape[0] 102 | 103 | yavg = list() 104 | rew_amount = list() 105 | 106 | # go over all trials 107 | for i in range(num_trials): 108 | yi = y[i] 109 | xi = x[i] 110 | labeli = label[i] 111 | 112 | # skip if it's an expected trial 113 | if params["save_only_sur"]: 114 | cue_flag = torch.sum(torch.abs(xi[0]), dim=-1).item() 115 | if cue_flag: 116 | # expected trial, hence skip 117 | continue 118 | 119 | cue_flag = torch.sum(torch.abs(xi[0]), dim=-1).item() 120 | if cue_flag: 121 | pass 122 | else: 123 | labeli = -1 * labeli # surprise is negative 124 | 125 | # reward presence 126 | reward_onset = np.where(xi[1] > 0)[-1][0] 127 | 128 | y_curr = yi[:, reward_onset : reward_onset + params["window_dur"]] 129 | yavg.append(y_curr) 130 | rew_amount.append(labeli) 131 | 132 | # (neurons, time, trials) 133 | yavg = torch.stack(yavg, dim=-1).clone().detach().cpu().numpy() 134 | rew_amount = np.array(rew_amount) 135 | 136 | if 1: 137 | if params["save_only_sur"]: 138 | np.save( 139 | os.path.join( 140 | postprocess_path, 141 | "y_for_pcanmf_{}_only_sur.npy".format(datafile_name), 142 | ), 143 | yavg, 144 | ) 145 | np.save( 146 | os.path.join( 147 | postprocess_path, 148 | "label_for_pcanmf_{}_only_sur.npy".format(datafile_name), 149 | ), 150 | rew_amount, 151 | ) 152 | else: 153 | np.save( 154 | os.path.join( 155 | postprocess_path, "y_for_pcanmf_{}.npy".format(datafile_name) 156 | ), 157 | yavg, 158 | ) 159 | np.save( 160 | os.path.join( 161 | postprocess_path, 162 | "label_for_pcanmf_{}.npy".format(datafile_name), 163 | ), 164 | rew_amount, 165 | ) 166 | 167 | 168 | if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /dunl/postprocess_scripts/plot_expsetup_dopamine_calcium_saramatias_uchida.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bahareh Tolooshams 3 | 4 | plot experimental setup 5 | 6 | :author: Bahareh Tolooshams 7 | """ 8 | 9 | import
torch 10 | import numpy as np 11 | import os 12 | import pickle 13 | import argparse 14 | import matplotlib as mpl 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | 19 | def init_params(): 20 | parser = argparse.ArgumentParser(description=__doc__) 21 | 22 | parser.add_argument( 23 | "--res-path", 24 | type=str, 25 | help="res path", 26 | default="../results/dopaminecalcium_kernellength60_kernelnum5_code2211n1_kernel00011_qreg_2023_07_13_11_37_31", 27 | ) 28 | parser.add_argument( 29 | "--batch-size", 30 | type=int, 31 | help="batch size", 32 | default=128, 33 | ) 34 | parser.add_argument( 35 | "--num-workers", 36 | type=int, 37 | help="number of workers for dataloader", 38 | default=4, 39 | ) 40 | parser.add_argument( 41 | "--regret-dur", 42 | type=int, 43 | help="regret duration after onset in samples", 44 | default=60, 45 | ) 46 | parser.add_argument( 47 | "--sampling-rate", 48 | type=int, 49 | help="sampling rate", 50 | default=15, 51 | ) 52 | parser.add_argument( 53 | "--reward-delay", 54 | type=int, 55 | help="reward delay from the cue onset", 56 | default=45, 57 | ) 58 | parser.add_argument( 59 | "--duration", 60 | type=int, 61 | help="duration", 62 | default=90, 63 | ) 64 | 65 | args = parser.parse_args() 66 | params = vars(args) 67 | 68 | return params 69 | 70 | 71 | def main(): 72 | print("Predict.") 73 | 74 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 75 | print("device is", device) 76 | 77 | # init parameters -------------------------------------------------------# 78 | print("init parameters.") 79 | params_init = init_params() 80 | 81 | # take parameters from the result path 82 | params = pickle.load( 83 | open(os.path.join(params_init["res_path"], "params.pickle"), "rb") 84 | ) 85 | for key in params_init.keys(): 86 | params[key] = params_init[key] 87 | 88 | # create folders -------------------------------------------------------# 89 | out_path = os.path.join( 90 | params["res_path"], 91 | "figures", 92 | ) 93 | if not os.path.exists(out_path): 94 | os.makedirs(out_path) 95 | 96 | # plot configuration -------------------------------------------------------# 97 | 98 | axes_fontsize = 10 99 | legend_fontsize = 8 100 | tick_fontsize = 10 101 | title_fontsize = 10 102 | 103 | # update plot parameters 104 | # style 105 | mpl.rcParams.update( 106 | { 107 | "pgf.texsystem": "pdflatex", 108 | "text.usetex": True, 109 | "axes.labelsize": axes_fontsize, 110 | "axes.titlesize": title_fontsize, 111 | "legend.fontsize": legend_fontsize, 112 | "xtick.labelsize": tick_fontsize, 113 | "ytick.labelsize": tick_fontsize, 114 | "text.latex.preamble": r"\usepackage{bm}", 115 | "axes.unicode_minus": False, 116 | } 117 | ) 118 | 119 | # plot -------------------------------------------------------# 120 | fig, axn = plt.subplots(3, 1, sharex=True, sharey=True) 121 | 122 | for ax in axn.flat: 123 | ax.tick_params(axis="x", direction="out") 124 | ax.tick_params(axis="y", direction="out") 125 | ax.spines["right"].set_visible(False) 126 | ax.spines["top"].set_visible(False) 127 | 128 | dot_loc = 1 129 | 130 | cue_regret_dot = 0 131 | cue_expected_dot = 0 132 | reward_expected_dot = params["reward_delay"] 133 | reward_surprise_dot = 0 134 | 135 | plt.subplot(3, 1, 1) 136 | plt.title(r"$\textbf{Regret\ Trials}$") 137 | plt.axvline(x=0, linestyle="--", linewidth=0.5, color="black") 138 | ctr = -1 139 | plt.plot( 140 | cue_regret_dot, 141 | dot_loc, 142 | ".", 143 | markersize=10, 144 | color="Orange", 145 | ) 146 | 147 | plt.subplot(3, 1, 2) 148 | plt.title(r"$\textbf{Surprise\ 
Trials}$") 149 | plt.axvline(x=0, linestyle="--", linewidth=0.5, color="black") 150 | plt.plot( 151 | reward_surprise_dot, 152 | dot_loc, 153 | ".", 154 | markersize=10, 155 | color="Blue", 156 | ) 157 | 158 | plt.subplot(3, 1, 3) 159 | plt.title(r"$\textbf{Expected\ Trials}$") 160 | plt.axvline(x=0, linestyle="--", linewidth=0.5, color="black") 161 | plt.axvline(x=params["reward_delay"], linestyle="--", linewidth=0.5, color="black") 162 | plt.plot( 163 | cue_expected_dot, 164 | dot_loc, 165 | ".", 166 | markersize=10, 167 | color="Orange", 168 | ) 169 | plt.plot( 170 | reward_expected_dot, 171 | dot_loc, 172 | ".", 173 | markersize=10, 174 | color="Blue", 175 | ) 176 | xtic = np.array([0, 0.5, 1]) * params["duration"] 177 | plt.xticks(xtic, xtic / params["sampling_rate"]) 178 | plt.xlabel("Time [s]", labelpad=0) 179 | 180 | fig.tight_layout(pad=0.8, w_pad=0.7, h_pad=0.5) 181 | plt.savefig( 182 | os.path.join(out_path, "experiment_setup.svg"), 183 | bbox_inches="tight", 184 | pad_inches=0.02, 185 | ) 186 | plt.close() 187 | 188 | print(f"plotting of experimental setup is done. plots are saved at {out_path}") 189 | 190 | 191 | if __name__ == "__main__": 192 | main() 193 | -------------------------------------------------------------------------------- /config/dopamine_calcium_saramatias_uchida_inferbaseline_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | exp_name: "dopaminecalcium_kernellength60_kernelnum5_code2211n1_kernel00011_qreg_fixedq_2p5_firstshrinkage_inferbase" 3 | data_path: [ 4 | "../data/dopamine-calcium-saramatias-uchida/VarMag_SM103_20191104_general_format_processed_kernellength60_kernelnum5_trainready.pt", # 20 neurons, 299 trials 5 | "../data/dopamine-calcium-saramatias-uchida/VarMag_SM99_20191109_general_format_processed_kernellength60_kernelnum5_trainready.pt", # 30 neurons, 195 trials 6 | "../data/dopamine-calcium-saramatias-uchida/VarMag_SM104_20191103_general_format_processed_kernellength60_kernelnum5_trainready.pt", # 6 neurons, 252 trials 7 | ] # must be a list of datasets 8 | data_folder: None 9 | test_data_path: None 10 | ################################# 11 | model_distribution: "gaussian" # data distrbution gaussian, binomila, poisson 12 | share_kernels_among_neurons: True # set true to share kernels among neurons 13 | ################################# 14 | # kernel (dictionary) 15 | kernel_normalize: True # True: l2-norm of kernels is set to one after each update 16 | kernel_nonneg: True # True: project kernels into non-negative values 17 | kernel_nonneg_indicator: [0, 0, 0, 1, 1] # 0 for +-, 1 for + 18 | kernel_num: 5 # number of kernels to learn 19 | kernel_length: 60 # number of samples for kernel in time 20 | kernel_stride: 1 # default 1, convolution stride 21 | kernel_init_smoother: False # flag to init kernels to be smooth 22 | kernel_init_smoother_sigma: 0.2 # sigma of the gaussian kernel for kernel_init_smoother 23 | kernel_smoother: False # flag to apply smoother to the kernel during training 24 | kernel_smoother_penalty_weight: 0 # penalty weight to apply for kernel smoother 25 | kernel_initialization: None # None, or a data path 26 | kernel_initialization_needs_adjustment_of_time_bin_resolution: False 27 | ################################# 28 | # code (representation) 29 | code_nonneg: [2, 2, 1, 1, -1] # apply sign constraint on the code. 
1 for pos, -1 for neg, 2 for twosided 30 | code_sparse_regularization: 0 # apply sparse (lambda l1-norm) regularization on the code 31 | code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration 32 | code_group_neural_firings_regularization: 0 # if > 0, then it applies grouping across neurons 33 | code_q_regularization: True # set True to apply Q-regularization on the norm of the code 34 | code_q_regularization_matrix: None # The matrix of relations between the codes (use the path to load) 35 | code_q_regularization_matrix_path: "../data/dopamine-calcium-saramatias-uchida/code_q_regularization_matrix.pt" 36 | code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations 37 | code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term 38 | code_q_regularization_norm_type: 2 # Set to the norm number you want the Q-regularization to be applied 39 | code_supp: True # True: apply known event indices (supp) into code x 40 | code_topk: False # True: keep only top k indices in each kernel code non-zero (this is greedy) 41 | code_topk_sparse: None # number of top k non-zero entries in each code kernel 42 | code_topk_period: None # period on encoder iteration to apply topk 43 | code_l1loss_bp: False # True: to include l1-norm of the code in the loss during training 44 | code_l1loss_bp_penalty_weight: 0 # amount of sparse regularization of the code with bp during training 45 | ################################# 46 | est_baseline_activity: True # True: estimate the baseline activity along with the code in the encoder 47 | poisson_stability_name: None # type of non-linearity to use on poisson case for encoder stability 48 | poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value 49 | ################################# 50 | # unrolling parameters 51 | unrolling_num: 100 # number of unrolling iterations in the encoder 52 | unrolling_mode: "fista" # ista or fista encoder 53 | unrolling_alpha: 0.1 # alpha step size in unrolling 54 | unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold) 55 | unrolling_threshold: None # must set to a value if unrolling_prox is "threshold" 56 | ################################# 57 | # training related 58 | # default optimizer is ADAM. 59 | optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels) 60 | optimizer_lr_step: 20 # number of steps (updates) after which the lr will decay 61 | optimizer_lr_decay: 1 # decay factor for learning rate 62 | optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer 63 | optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer 64 | # 65 | backward_gradient_decsent: "bprop" # type of backward gradient update (bprop, truncated_bprop) 66 | backward_truncated_bprop_itr: 10 # must be set for truncated_bprop 67 | # 68 | train_num_epochs: 30 # number of epochs for training 69 | train_data_shuffle: True # True: to shuffle dataset at every epoch for training 70 | train_batch_size: 8 # batch size for training 71 | train_num_workers: 4 # number of workers to load data 72 | train_val_split: 1 # 1: use all for train. percentage of data used to train, rest to be used for validation. 73 | train_with_fraction: 1 # 1 for all the data, or a fraction e.g.
0.1 74 | # 75 | enable_board: True 76 | log_info_epoch_period: 1 # period to push small info into the board 77 | log_model_epoch_period: 30 # period to save model 78 | log_fig_epoch_period: 1 # period to push figures into the board 79 | tqdm_prints_disable: False # True: to disable prints of epoch training process 80 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /dunl/postprocess_scripts/plot_kernels_dopamine_spiking_eshel_uchida_hor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bahareh Tolooshams 3 | 4 | plot kernels for the dopamine spiking data (horizontal layout) 5 | 6 | :author: Bahareh Tolooshams 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | import pickle 13 | import argparse 14 | import matplotlib as mpl 15 | import matplotlib.pyplot as plt 16 | 17 | import sys 18 | 19 | sys.path.append("../dunl/") 20 | 21 | import utils 22 | 23 | 24 | def init_params(): 25 | parser = argparse.ArgumentParser(description=__doc__) 26 | 27 | parser.add_argument( 28 | "--res-path", 29 | type=str, 30 | help="res path", 31 | default="../results/dopaminespiking_25msbin_kernellength24_kernelnum3_codefree_kernel111_2023_07_14_12_37_30", 32 | ) 33 | parser.add_argument( 34 | "--color-list", 35 | type=list, 36 | help="color list", 37 | default=[ 38 | # "orange", 39 | # "blue", 40 | # "red", 41 | "orange", 42 | "brown", 43 | "black", 44 | ], # cue exp, 2 rewards 45 | ) 46 | parser.add_argument( 47 | "--figsize", 48 | type=tuple, 49 | help="figsize", 50 | default=(1.5, 1.5), 51 | ) 52 | parser.add_argument( 53 | "--swap-kernel", 54 | type=bool, 55 | help="bool to swap kernels", 56 | default=True, 57 | ) 58 | 59 | args = parser.parse_args() 60 | params = vars(args) 61 | 62 | return params 63 | 64 | 65 | def main(): 66 | print("Predict.") 67 | 68 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 69 | print("device is", device) 70 | 71 | # init parameters -------------------------------------------------------# 72 | print("init parameters.") 73 | params_init = init_params() 74 | 75 | # take parameters from the result path 76 | params = pickle.load( 77 | open(os.path.join(params_init["res_path"], "params.pickle"), "rb") 78 | ) 79 | for key in params_init.keys(): 80 | params[key] = params_init[key] 81 | 82 | if params["data_path"] == "": 83 | data_folder = params["data_folder"] 84 | filename_list = os.listdir(data_folder) 85 | data_path_list = [ 86 | f"{data_folder}/{x}" for x in filename_list if "trainready.pt" in x 87 | ] 88 | else: 89 | data_path_list = params["data_path"] 90 | 91 | print("There are {} dataset(s) in the folder.".format(len(data_path_list))) 92 | 93 | # set time bin resolution -----------------------------------------------# 94 | data_dict = torch.load(data_path_list[0]) 95 | params["time_bin_resolution"] = data_dict["time_bin_resolution"] 96 | 97 | # create folders -------------------------------------------------------# 98 | model_path = os.path.join( 99 | params["res_path"], 100 | "model", 101 | "model_final.pt", 102 | ) 103 | 104 | out_path = os.path.join( 105 | params["res_path"], 106 | "figures", 107 | ) 108 | if not os.path.exists(out_path): 109 | os.makedirs(out_path) 110 | 111 | # load model ------------------------------------------------------# 112 | net = torch.load(model_path, map_location=device) 113 | net.to(device) 114 | net.eval() 115 | 116 | kernels = net.get_param("H").clone().detach() 117
| if params["swap_kernel"]: 118 | kernels = utils.swap_kernel(kernels, 1, 2) 119 | kernels = np.squeeze(kernels.cpu().numpy()) 120 | 121 | plot_kernel(kernels, params, out_path) 122 | 123 | 124 | def plot_kernel(kernels, params, out_path): 125 | axes_fontsize = 15 126 | legend_fontsize = 8 127 | tick_fontsize = 15 128 | title_fontsize = 15 129 | fontfamily = "sans-serif" 130 | 131 | # update plot parameters 132 | # style 133 | mpl.rcParams.update( 134 | { 135 | "pgf.texsystem": "pdflatex", 136 | "text.usetex": True, 137 | "axes.labelsize": axes_fontsize, 138 | "axes.titlesize": title_fontsize, 139 | "legend.fontsize": legend_fontsize, 140 | "xtick.labelsize": tick_fontsize, 141 | "ytick.labelsize": tick_fontsize, 142 | "text.latex.preamble": r"\usepackage{bm}", 143 | "axes.unicode_minus": False, 144 | "font.family": fontfamily, 145 | } 146 | ) 147 | 148 | fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=params["figsize"]) 149 | 150 | ax.tick_params(axis="x", direction="out") 151 | ax.tick_params(axis="y", direction="out") 152 | ax.spines["right"].set_visible(False) 153 | ax.spines["top"].set_visible(False) 154 | 155 | t = np.linspace( 156 | 0, 157 | params["kernel_length"] * params["time_bin_resolution"], 158 | params["kernel_length"], 159 | ) 160 | 161 | for ctr in range(params["kernel_num"]): 162 | plt.subplot(1, 1, 1) 163 | plt.plot(t, kernels[ctr], color=params["color_list"][ctr], lw=2.5) 164 | xtic = ( 165 | np.array([0, 0.5, 1]) 166 | * params["kernel_length"] 167 | * params["time_bin_resolution"] 168 | ) 169 | xtic = [int(x) for x in xtic] 170 | plt.xticks(xtic, xtic) 171 | plt.yticks([]) 172 | 173 | plt.xlabel("Time (ms)", labelpad=0) 174 | 175 | fig.tight_layout(pad=0.8, w_pad=0.7, h_pad=0.5) 176 | plt.savefig( 177 | os.path.join(out_path, "kernels_one.svg"), 178 | bbox_inches="tight", 179 | pad_inches=0.02, 180 | ) 181 | plt.close() 182 | 183 | print(f"plotting of kernels is done.
plots are saved at {out_path}") 184 | 185 | 186 | if __name__ == "__main__": 187 | main() 188 | -------------------------------------------------------------------------------- /dunl/postprocess_scripts/plot_pcanmf_kernels_calcium.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bahareh Tolooshams 3 | 4 | plot pca and nmf kernels for the calcium data 5 | 6 | :author: Bahareh Tolooshams 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | import pickle 13 | import argparse 14 | import matplotlib as mpl 15 | import matplotlib.pyplot as plt 16 | 17 | def init_params(): 18 | parser = argparse.ArgumentParser(description=__doc__) 19 | 20 | parser.add_argument( 21 | "--res-path", 22 | type=str, 23 | help="res path", 24 | default="../results/dopaminecalcium_kernellength60_kernelnum5_code2211n1_kernel00011_qreg_2023_07_13_11_37_31", 25 | ) 26 | parser.add_argument( 27 | "--num-comp", 28 | type=int, 29 | help="number of components", 30 | default=2, 31 | ) 32 | parser.add_argument( 33 | "--sampling-rate", 34 | type=int, 35 | help="sampling rate", 36 | default=15, 37 | ) 38 | parser.add_argument( 39 | "--color-list", 40 | type=list, 41 | help="color list", 42 | default=[ 43 | "blue", 44 | "red", 45 | "green", 46 | ], 47 | ) 48 | parser.add_argument( 49 | "--figsize", 50 | type=tuple, 51 | help="figsize", 52 | default=(1.6, 2), 53 | ) 54 | 55 | args = parser.parse_args() 56 | params = vars(args) 57 | 58 | return params 59 | 60 | 61 | def main(): 62 | print("Predict.") 63 | 64 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 65 | print("device is", device) 66 | 67 | # init parameters -------------------------------------------------------# 68 | print("init parameters.") 69 | params_init = init_params() 70 | 71 | # take parameters from the result path 72 | params = pickle.load( 73 | open(os.path.join(params_init["res_path"], "params.pickle"), "rb") 74 | ) 75 | for key in params_init.keys(): 76 | params[key] = params_init[key] 77 | 78 | if params["data_path"] == "": 79 | data_folder = params["data_folder"] 80 | filename_list = os.listdir(data_folder) 81 | data_path_list = [ 82 | f"{data_folder}/{x}" for x in filename_list if "trainready.pt" in x 83 | ] 84 | else: 85 | data_path_list = params["data_path"] 86 | 87 | print("There are {} dataset(s) in the folder.".format(len(data_path_list))) 88 | 89 | # create folders -------------------------------------------------------# 90 | out_path = os.path.join( 91 | params["res_path"], 92 | "figures", 93 | ) 94 | if not os.path.exists(out_path): 95 | os.makedirs(out_path) 96 | 97 | postprocess_path = os.path.join( 98 | params["res_path"], 99 | "postprocess", 100 | ) 101 | 102 | # load data -------------------------------------------------------# 103 | pca_transform = pickle.load( 104 | open( 105 | os.path.join( 106 | postprocess_path, "pca_transform_{}.pkl".format(params["num_comp"]) 107 | ), 108 | "rb", 109 | ) 110 | ) 111 | nmf_transform = pickle.load( 112 | open( 113 | os.path.join( 114 | postprocess_path, "nmf_transform_{}.pkl".format(params["num_comp"]) 115 | ), 116 | "rb", 117 | ) 118 | ) 119 | 120 | pca_W = pca_transform.components_.T 121 | nmf_W = nmf_transform.components_.T 122 | 123 | plot_pca_nmf(pca_W, params, out_path, "pca_{}".format(params["num_comp"])) 124 | plot_pca_nmf(nmf_W, params, out_path, "nmf_{}".format(params["num_comp"])) 125 | 126 | print(f"plotting of kernels is done.
plots are saved at {out_path}") 127 | 128 | 129 | def plot_pca_nmf(W, params, out_path, name): 130 | axes_fontsize = 10 131 | legend_fontsize = 8 132 | tick_fontsize = 10 133 | title_fontsize = 10 134 | 135 | # update plot parameters 136 | # style 137 | mpl.rcParams.update( 138 | { 139 | "pgf.texsystem": "pdflatex", 140 | "text.usetex": True, 141 | "axes.labelsize": axes_fontsize, 142 | "axes.titlesize": title_fontsize, 143 | "legend.fontsize": legend_fontsize, 144 | "xtick.labelsize": tick_fontsize, 145 | "ytick.labelsize": tick_fontsize, 146 | "text.latex.preamble": r"\usepackage{bm}", 147 | "axes.unicode_minus": False, 148 | } 149 | ) 150 | 151 | # plot -------------------------------------------------------# 152 | fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=params["figsize"]) 153 | 154 | ax.tick_params(axis="x", direction="out") 155 | ax.tick_params(axis="y", direction="out") 156 | ax.spines["right"].set_visible(False) 157 | ax.spines["top"].set_visible(False) 158 | 159 | ax.axhline(0, color="gray", lw=0.3) 160 | 161 | t = np.linspace( 162 | 0, 163 | params["kernel_length"] * params["sampling_rate"], 164 | params["kernel_length"], 165 | ) 166 | 167 | plt.subplot(1, 1, 1) 168 | for ctr in range(W.shape[1]): 169 | plt.plot(t, W[:, ctr], color=params["color_list"][ctr]) 170 | xtic = np.array([0, 0.5, 1]) * params["kernel_length"] * params["sampling_rate"] 171 | xtic = [int(x) for x in xtic] 172 | plt.xticks(xtic, xtic) 173 | plt.xlabel("Time [ms]", labelpad=0) 174 | 175 | fig.tight_layout(pad=0.8, w_pad=0.7, h_pad=0.5) 176 | plt.savefig( 177 | os.path.join(out_path, "kernels_{}.svg".format(name)), 178 | bbox_inches="tight", 179 | pad_inches=0.02, 180 | ) 181 | plt.close() 182 | 183 | 184 | if __name__ == "__main__": 185 | main() 186 | -------------------------------------------------------------------------------- /config/dopamine_calcium_saramatias_uchida_independentkernelsamongneurons_config.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | exp_name: "dopaminecalcium_kernellength60_kernelnum5_code2211n1_kernel00011_qreg_fixedq_2p5_firstshrinkage_independentkernels_kernelsmoothing_0p0005" 3 | data_path: [ 4 | "../data/dopamine-calcium-saramatias-uchida/VarMag_SM103_20191104_general_format_processed_kernellength60_kernelnum5_trainready.pt", # 20 neurons, 299 trials 5 | "../data/dopamine-calcium-saramatias-uchida/VarMag_SM99_20191109_general_format_processed_kernellength60_kernelnum5_trainready.pt", # 30 neurons, 195 trials 6 | "../data/dopamine-calcium-saramatias-uchida/VarMag_SM104_20191103_general_format_processed_kernellength60_kernelnum5_trainready.pt", # 6 neurons, 252 trials 7 | ] # must be a list of datasets 8 | data_folder: None 9 | test_data_path: None 10 | ################################# 11 | model_distribution: "gaussian" # data distribution: gaussian, binomial, poisson 12 | share_kernels_among_neurons: False # set true to share kernels among neurons 13 | ################################# 14 | # kernel (dictionary) 15 | kernel_normalize: True # True: l2-norm of kernels is set to one after each update 16 | kernel_nonneg: True # True: project kernels into non-negative values 17 | kernel_nonneg_indicator: [0, 0, 0, 1, 1] # 0 for +-, 1 for + 18 | kernel_num: 5 # number of kernels to learn 19 | kernel_length: 60 # number of samples for kernel in time 20 | kernel_stride: 1 # default 1, convolution stride 21 | kernel_init_smoother: False # flag to init kernels to be smooth 22 |
22 |   kernel_init_smoother_sigma: 0.2 # sigma of the gaussian kernel for kernel_init_smoother
23 |   kernel_smoother: True # flag to apply smoother to the kernel during training
24 |   kernel_smoother_penalty_weight: 0.0005 # penalty weight to apply for kernel smoother
25 |   kernel_initialization: None # None, or a data path
26 |   kernel_initialization_needs_adjustment_of_time_bin_resolution: False
27 |   #################################
28 |   # code (representation)
29 |   code_nonneg: [2, 2, 1, 1, -1] # apply sign constraint on the code. 1 for pos, -1 for neg, 2 for twosided
30 |   code_sparse_regularization: 0 # apply sparse (lambda l1-norm) regularization on the code
31 |   code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration
32 |   code_group_neural_firings_regularization: 0 # if > 0, then it applies grouping across neurons
33 |   code_q_regularization: True # set True to apply Q-regularization on the norm of the code
34 |   code_q_regularization_matrix: None # The matrix of relations between the codes (use the path to load)
35 |   code_q_regularization_matrix_path: "../data/dopamine-calcium-saramatias-uchida/code_q_regularization_matrix.pt"
36 |   code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations
37 |   code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term
38 |   code_q_regularization_norm_type: 2 # the norm order used for the Q-regularization term
39 |   code_supp: True # True: apply known event indices (supp) into code x
40 |   code_topk: False # True: keep only top k indices in each kernel code non-zero (this is greedy)
41 |   code_topk_sparse: None # number of top k non-zero entries in each code kernel
42 |   code_topk_period: None # period on encoder iteration to apply topk
43 |   code_l1loss_bp: False # True: to include l1-norm of the code in the loss during training
44 |   code_l1loss_bp_penalty_weight: 0 # amount of sparse regularization of the code with bp during training
45 |   #################################
46 |   est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder
47 |   poisson_stability_name: None # type of non-linearity to use on poisson case for encoder stability
48 |   poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value
49 |   #################################
50 |   # unrolling parameters
51 |   unrolling_num: 100 # number of unrolling iterations in the encoder
52 |   unrolling_mode: "fista" # ista or fista encoder
53 |   unrolling_alpha: 0.1 # alpha step size in unrolling
54 |   unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold)
55 |   unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold"
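  # For intuition, one unrolled "ista" encoder iteration with the "shrinkage"
  # prox is roughly (a sketch with illustrative names, not the repo's exact code):
  #   x <- soft_threshold(x + unrolling_alpha * H^T (y - H x), unrolling_alpha * lambda)
  # where soft_threshold(v, b) = sign(v) * max(|v| - b, 0); "fista" adds a
  # momentum step on top of this update, and unrolling_num such iterations are
  # stacked to form the encoder.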
56 |   #################################
57 |   # training related
58 |   # default optimizer is ADAM.
59 |   optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels)
60 |   optimizer_lr_step: 1000 # number of steps (updates) after which the lr will decay
61 |   optimizer_lr_decay: 1 # decay factor for learning rate
62 |   optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer
63 |   optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer
64 |   #
65 |   backward_gradient_decsent: "bprop" # type of backward gradient update (bprop, truncated_bprop)
66 |   backward_truncated_bprop_itr: 10 # must be set for truncated_bprop
67 |   #
68 |   train_num_epochs: 600 # number of epochs for training
69 |   train_data_shuffle: True # True: to shuffle dataset at every epoch for training
70 |   train_batch_size: 32 # batch size for training
71 |   train_num_workers: 4 # number of workers to load data
72 |   train_val_split: 1 # 1: use all for train. percentage of data used to train, rest to be used for validation.
73 |   train_with_fraction: 1 # 1 for all the data, or a fraction e.g. 0.1
74 |   #
75 |   enable_board: True
76 |   log_info_epoch_period: 40 # period to push small info into the board
77 |   log_model_epoch_period: 3000 # period to save model
78 |   log_fig_epoch_period: 40 # period to push figures into the board
79 |   tqdm_prints_disable: False # True: to disable prints of epoch training process
80 |   tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process
--------------------------------------------------------------------------------
/config/whisker_groupneuralfirings_config.yaml:
--------------------------------------------------------------------------------
1 | default: &DEFAULT
2 |   exp_name: "whisker_05msbinres_lamp01_grouptop18_smoothkernelp003_groupneuralfiringsp05"
3 |   data_path: ["../data/whisker/whisker_train_5msbinres_general_format_processed_kernellength25_kernelnum1_trainready.pt"] # give it a list of datasets
4 |   data_folder: None # this will look for data in format *trainready.pt
5 | 
6 |   test_data_path: ["../data/whisker/whisker_test_5msbinres_general_format_processed_kernellength25_kernelnum1_trainready.pt"] # give it a list of datasets
7 |   # data_path: ["../data/whisker/whisker_train_10msbinres_general_format_processed_kernellength12_kernelnum1_trainready.pt"] # give it a list of datasets
8 |   # test_data_path: ["../data/whisker/whisker_test_10msbinres_general_format_processed_kernellength12_kernelnum1_trainready.pt"] # give it a list of datasets
9 |   #################################
10 |   model_distribution: "binomial" # data distribution: gaussian, binomial, poisson
11 |   share_kernels_among_neurons: True # set true to share kernels among neurons
12 |   #################################
13 |   # kernel (dictionary)
14 |   kernel_normalize: True # True: l2-norm of kernels is set to one after each update
15 |   kernel_nonneg: False # True: project kernels into non-negative values
16 |   kernel_nonneg_indicator: [0] # 0 for +-, 1 for +
17 |   kernel_num: 1 # number of kernels to learn
18 |   kernel_length: 25 # number of samples for kernel in time
19 |   kernel_stride: 1 # default 1, convolution stride
20 |   kernel_init_smoother: False # flag to init kernels to be smooth
21 |   kernel_init_smoother_sigma: 0.5 # sigma of the gaussian kernel for kernel_init_smoother
22 |   kernel_smoother: True # flag to apply smoother to the kernel during training
23 |   kernel_smoother_penalty_weight: 0.003 # penalty weight to apply for kernel smoother
24 |   kernel_initialization: "../data/whisker/kernel_init_25.pt" # None, or a data path
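  # (Here the single whisker kernel is warm-started from a saved tensor instead
  # of a random init; judging by the file name, kernel_init_25.pt presumably
  # matches kernel_length: 25 above.)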
25 |   kernel_initialization_needs_adjustment_of_time_bin_resolution: False
26 |   #################################
27 |   # code (representation)
28 |   code_nonneg: [1] # apply sign constraint on the code. 1 for pos, -1 for neg, 2 for twosided
29 |   code_sparse_regularization: 0.01 # apply sparse (lambda l1-norm) regularization on the code
30 |   code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration
31 |   code_group_neural_firings_regularization: 0.05 # if > 0, then it applies grouping across neurons
32 |   code_q_regularization: False # set True to apply Q-regularization on the norm of the code
33 |   code_q_regularization_matrix: None # The matrix of relations between the codes (if flag is True, use the path to load)
34 |   code_q_regularization_matrix_path: None
35 |   code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations
36 |   code_q_regularization_scale: 5 # scale factor in front of the Q-regularization term
37 |   code_q_regularization_norm_type: 2 # the norm order used for the Q-regularization term
38 |   code_supp: False # True: apply known event indices (supp) into code x
39 |   code_topk: True # True: keep only top k indices in each kernel code non-zero (this is greedy)
40 |   code_topk_sparse: 18 # number of top k non-zero entries in each code kernel
41 |   code_topk_period: 10 # period on encoder iteration to apply topk
42 |   code_l1loss_bp: True # True: to include l1-norm of the code in the loss during training
43 |   code_l1loss_bp_penalty_weight: 0.03 # amount of sparse regularization of the code with bp during training
44 |   #################################
45 |   est_baseline_activity: False # True: estimate the baseline activity along with the code in the encoder
46 |   poisson_stability_name: None # type of non-linearity to use on poisson case for encoder stability
47 |   poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value
48 |   #################################
49 |   # unrolling parameters
50 |   unrolling_num: 800 # number of unrolling iterations in the encoder
51 |   unrolling_mode: "fista" # ista or fista encoder
52 |   unrolling_alpha: 0.5 # alpha step size in unrolling
53 |   unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold)
54 |   unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold"
55 |   #################################
56 |   # training related
57 |   # default optimizer is ADAM.
58 |   optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels)
59 |   optimizer_lr_step: 1000 # number of steps (updates) after which the lr will decay
60 |   optimizer_lr_decay: 1 # decay factor for learning rate
61 |   optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer
62 |   optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer
63 |   #
64 |   backward_gradient_decsent: "truncated_bprop" # type of backward gradient update (bprop, truncated_bprop)
65 |   backward_truncated_bprop_itr: 20 # must be set for truncated_bprop
66 |   #
67 |   train_num_epochs: 120 # number of epochs for training
68 |   train_data_shuffle: True # True: to shuffle dataset at every epoch for training
69 |   train_batch_size: 30 # batch size for training
70 |   train_num_workers: 4 # number of workers to load data
71 |   train_val_split: 1 # 1: use all for train. percentage of data used to train, rest to be used for validation.
72 |   train_with_fraction: 1 # 1 for all the data, or a fraction e.g. 0.1
73 |   #
74 |   enable_board: True
75 |   log_info_epoch_period: 1 # period to push small info into the board
76 |   log_model_epoch_period: 200 # period to save model
77 |   log_fig_epoch_period: 20 # period to push figures into the board
78 |   # log_fig_epoch_period: 1 # period to push figures into the board
79 |   tqdm_prints_disable: False # True: to disable prints of epoch training process
80 |   tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process
--------------------------------------------------------------------------------
/dunl/postprocess_scripts/plot_expsetup_dopamine_spiking_eshel_uchida.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2025 Bahareh Tolooshams
3 | 
4 | plot experimental setup
5 | 
6 | :author: Bahareh Tolooshams
7 | """
8 | 
9 | import torch
10 | import numpy as np
11 | import os
12 | import pickle
13 | import argparse
14 | import matplotlib as mpl
15 | import matplotlib.pyplot as plt
16 | 
17 | 
18 | def init_params():
19 |     parser = argparse.ArgumentParser(description=__doc__)
20 | 
21 |     parser.add_argument(
22 |         "--res-path",
23 |         type=str,
24 |         help="res path",
25 |         default="../results/dopaminespiking_25msbin_kernellength24_kernelnum3_codefree_kernel111_2023_07_14_12_37_30",
26 |     )
27 |     parser.add_argument(
28 |         "--batch-size",
29 |         type=int,
30 |         help="batch size",
31 |         default=128,
32 |     )
33 |     parser.add_argument(
34 |         "--num-workers",
35 |         type=int,
36 |         help="number of workers for dataloader",
37 |         default=4,
38 |     )
39 |     parser.add_argument(
40 |         "--reward-delay",
41 |         type=int,
42 |         help="reward delay from the cue onset",
43 |         default=60,  # this is after the binning
44 |     )
45 |     parser.add_argument(
46 |         "--duration",
47 |         type=int,
48 |         help="duration",
49 |         default=120,  # this is after the binning
50 |     )
51 | 
52 |     args = parser.parse_args()
53 |     params = vars(args)
54 | 
55 |     return params
56 | 
57 | 
58 | def main():
59 |     print("Predict.")
60 | 
61 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
62 |     print("device is", device)
63 | 
64 |     # init parameters -------------------------------------------------------#
65 |     print("init parameters.")
66 |     params_init = init_params()
67 | 
68 |     # take parameters from the result path
69 |     params = pickle.load(
70 |         open(os.path.join(params_init["res_path"], "params.pickle"), "rb")
71 |     )
72 |     for key in params_init.keys():
73 |         params[key] = params_init[key]
74 | 
75 |     if params["data_path"] == "":
76 |         data_folder = params["data_folder"]
77 |         filename_list = os.listdir(data_folder)
78 |         data_path_list = [
79 |             f"{data_folder}/{x}" for x in filename_list if "trainready.pt" in x
80 |         ]
81 |     else:
82 |         data_path_list = params["data_path"]
83 | 
84 |     print("There are {} dataset(s) in the folder.".format(len(data_path_list)))
85 | 
86 |     # set time bin resolution -----------------------------------------------#
87 |     data_dict = torch.load(data_path_list[0])
88 |     params["time_bin_resolution"] = data_dict["time_bin_resolution"]
89 | 
90 |     # create folders -------------------------------------------------------#
91 |     out_path = os.path.join(
92 |         params["res_path"],
93 |         "figures",
94 |     )
95 |     if not os.path.exists(out_path):
96 |         os.makedirs(out_path)
97 | 
98 |     # plot configuration -------------------------------------------------------#
99 | 
100 |     axes_fontsize = 10
101 |     legend_fontsize = 8
102 |     tick_fontsize = 10
103 |     title_fontsize = 10
104 | 
105 |     # update plot parameters
106 |     # style
107 |     mpl.rcParams.update(
108 |         {
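            # note: "text.usetex": True requires a working LaTeX installation on
            # the machine rendering the figures; set it to False to fall back to
            # matplotlib's built-in mathtext.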
109 | "pgf.texsystem": "pdflatex", 110 | "text.usetex": True, 111 | "axes.labelsize": axes_fontsize, 112 | "axes.titlesize": title_fontsize, 113 | "legend.fontsize": legend_fontsize, 114 | "xtick.labelsize": tick_fontsize, 115 | "ytick.labelsize": tick_fontsize, 116 | "text.latex.preamble": r"\usepackage{bm}", 117 | "axes.unicode_minus": False, 118 | } 119 | ) 120 | 121 | # plot -------------------------------------------------------# 122 | fig, axn = plt.subplots(2, 1, sharex=True, sharey=True) 123 | 124 | for ax in axn.flat: 125 | ax.tick_params(axis="x", direction="out") 126 | ax.tick_params(axis="y", direction="out") 127 | ax.spines["right"].set_visible(False) 128 | ax.spines["top"].set_visible(False) 129 | 130 | dot_loc = 1 131 | 132 | cue_regret_dot = 0 133 | cue_expected_dot = 0 134 | reward_expected_dot = params["reward_delay"] 135 | reward_surprise_dot = 0 136 | 137 | plt.subplot(2, 1, 1) 138 | plt.title(r"$\textbf{Surprise\ Trials}$") 139 | plt.axvline(x=0, linestyle="--", linewidth=0.5, color="black") 140 | plt.plot( 141 | reward_surprise_dot, 142 | dot_loc, 143 | ".", 144 | markersize=10, 145 | color="Blue", 146 | ) 147 | 148 | plt.subplot(2, 1, 2) 149 | plt.title(r"$\textbf{Expected\ Trials}$") 150 | plt.axvline(x=0, linestyle="--", linewidth=0.5, color="black") 151 | plt.axvline(x=params["reward_delay"], linestyle="--", linewidth=0.5, color="black") 152 | plt.plot( 153 | cue_expected_dot, 154 | dot_loc, 155 | ".", 156 | markersize=10, 157 | color="Orange", 158 | ) 159 | plt.plot( 160 | reward_expected_dot, 161 | dot_loc, 162 | ".", 163 | markersize=10, 164 | color="Blue", 165 | ) 166 | xtic = np.array([0, 0.5, 1]) * params["duration"] 167 | xtic_figure = [int(x * params["time_bin_resolution"]) for x in xtic] 168 | plt.xticks(xtic, xtic_figure) 169 | plt.xlabel("Time [ms]", labelpad=0) 170 | 171 | fig.tight_layout(pad=0.8, w_pad=0.7, h_pad=0.5) 172 | plt.savefig( 173 | os.path.join(out_path, "experiment_setup.svg"), 174 | bbox_inches="tight", 175 | pad_inches=0.02, 176 | ) 177 | plt.close() 178 | 179 | print(f"plotting of experimental setup is done. 
plots are saved at {out_path}") 180 | 181 | 182 | if __name__ == "__main__": 183 | main() 184 | -------------------------------------------------------------------------------- /dunl/postprocess_scripts/plot_kernels_local_orthkernels_deconv_spiking_simulated_noisy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bahareh Tolooshams 3 | 4 | plot rec data kernel 5 | 6 | :author: Bahareh Tolooshams 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | import pickle 13 | import argparse 14 | import matplotlib as mpl 15 | import matplotlib.pyplot as plt 16 | 17 | import sys 18 | 19 | sys.path.append("../dunl/") 20 | 21 | import utils 22 | 23 | 24 | def init_params(): 25 | parser = argparse.ArgumentParser(description=__doc__) 26 | 27 | parser.add_argument( 28 | "--res-path-list", 29 | type=str, 30 | help="res path list", 31 | default=["../results"], 32 | ) 33 | parser.add_argument( 34 | "--color-list", 35 | type=list, 36 | help="color list", 37 | default=[ 38 | "blue", 39 | "blue", 40 | "blue", 41 | "blue", 42 | "blue", 43 | "blue", 44 | "blue", 45 | ], # 2 kernels 46 | ) 47 | parser.add_argument( 48 | "--figsize", 49 | type=tuple, 50 | help="figsize", 51 | default=(8, 2), 52 | ) 53 | 54 | args = parser.parse_args() 55 | params = vars(args) 56 | 57 | return params 58 | 59 | 60 | def main(): 61 | print("Predict.") 62 | 63 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 64 | print("device is", device) 65 | 66 | # init parameters -------------------------------------------------------# 67 | print("init parameters.") 68 | params_init = init_params() 69 | 70 | out_path = os.path.join( 71 | "../", 72 | "figures", 73 | "orthkernels", 74 | "01noise_kernels", 75 | ) 76 | if not os.path.exists(out_path): 77 | os.makedirs(out_path) 78 | 79 | epoch_type = "best_val" 80 | 81 | res_folder = "../results" 82 | filename_list = os.listdir(res_folder) 83 | filename_list = [f"{x}" for x in filename_list if "20unrolling" in x] 84 | filename_list = [f"{x}" for x in filename_list if "01noise" in x] 85 | res_path_list = [f"{res_folder}/{x}" for x in filename_list if "orth_" in x] 86 | 87 | data_folder = "../data/local-orthkernels-simulated" 88 | kernels_true = np.load(os.path.join(data_folder, "kernels.npy")) 89 | 90 | for res_path in res_path_list: 91 | num_trials = int(res_path.split("_")[1].split("trials")[0]) 92 | 93 | num_kernels = int(res_path.split("_")[2].split("kernel")[0]) 94 | 95 | # take parameters from the result path 96 | params = pickle.load(open(os.path.join(res_path, "params.pickle"), "rb")) 97 | params["time_bin_resolution"] = 5 98 | for key in params_init.keys(): 99 | params[key] = params_init[key] 100 | 101 | # model 102 | model_path = os.path.join( 103 | res_path, 104 | "model", 105 | f"model_{epoch_type}.pt", 106 | ) 107 | net = torch.load(model_path, map_location=device) 108 | net.to(device) 109 | net.eval() 110 | kernels = net.get_param("H").clone().detach() 111 | kernels = np.squeeze(kernels.cpu().numpy()) 112 | 113 | outname = res_path.split("/")[-1] 114 | plot_kernel_est( 115 | kernels, 116 | params, 117 | out_path, 118 | f"{outname}_onlyest.png", 119 | ) 120 | plot_kernel_est( 121 | kernels, 122 | params, 123 | out_path, 124 | f"{outname}_onlyest.svg", 125 | ) 126 | 127 | 128 | def plot_kernel_est(kernels, params, out_path, outname): 129 | axes_fontsize = 10 130 | legend_fontsize = 8 131 | tick_fontsize = 10 132 | title_fontsize = 10 133 | fontfamily = "sans-serif" 134 | 135 | # 
update plot parameters
136 |     # style
137 |     mpl.rcParams.update(
138 |         {
139 |             "pgf.texsystem": "pdflatex",
140 |             "text.usetex": False,
141 |             "axes.labelsize": axes_fontsize,
142 |             "axes.titlesize": title_fontsize,
143 |             "legend.fontsize": legend_fontsize,
144 |             "xtick.labelsize": tick_fontsize,
145 |             "ytick.labelsize": tick_fontsize,
146 |             "text.latex.preamble": r"\usepackage{bm}",
147 |             "axes.unicode_minus": False,
148 |             "font.family": fontfamily,
149 |         }
150 |     )
151 | 
152 |     row = 1
153 |     col = kernels.shape[0]
154 |     fig, axn = plt.subplots(
155 |         row, col, sharex=True, sharey=True, figsize=params["figsize"]
156 |     )
157 | 
158 |     for ax in axn.flat:
159 |         ax.tick_params(axis="x", direction="out")
160 |         ax.tick_params(axis="y", direction="out")
161 |         ax.spines["right"].set_visible(False)
162 |         ax.spines["top"].set_visible(False)
163 | 
164 |     t = np.linspace(
165 |         0,
166 |         kernels.shape[-1] * params["time_bin_resolution"],
167 |         kernels.shape[-1],
168 |     )
169 | 
170 |     for ctr in range(col):
171 |         plt.subplot(row, col, ctr + 1)
172 | 
173 |         plt.plot(t, kernels[ctr], color=params["color_list"][ctr])
174 | 
175 |         xtic = (
176 |             np.array([0, 0.5, 1])
177 |             * params["kernel_length"]
178 |             * params["time_bin_resolution"]
179 |         )
180 |         xtic = [int(x) for x in xtic]
181 |         plt.xticks(xtic, xtic)
182 | 
183 |         plt.xlabel("Time [ms]", labelpad=0)
184 | 
185 |     fig.tight_layout(pad=0.8, w_pad=0.7, h_pad=0.5)
186 | 
187 |     plt.savefig(
188 |         os.path.join(out_path, outname),
189 |         bbox_inches="tight",
190 |         pad_inches=0.02,
191 |     )
192 |     plt.close()
193 | 
194 | 
195 | if __name__ == "__main__":
196 |     main()
197 | 
--------------------------------------------------------------------------------
/config/dopamine_calcium_saramatias_uchida_inferbaseline_independentkernelsamongneurons_config.yaml:
--------------------------------------------------------------------------------
1 | default: &DEFAULT
2 |   exp_name: "dopaminecalcium_kernellength60_kernelnum5_code2211n1_kernel00011_qreg_fixedq_2p5_firstshrinkage_inferbase_independentkernels_kernelsmoothing_0p0005"
3 |   data_path: [
4 |     "../data/dopamine-calcium-saramatias-uchida/VarMag_SM103_20191104_general_format_processed_kernellength60_kernelnum5_trainready.pt", # 20 neurons, 299 trials
5 |     "../data/dopamine-calcium-saramatias-uchida/VarMag_SM99_20191109_general_format_processed_kernellength60_kernelnum5_trainready.pt", # 30 neurons, 195 trials
6 |     "../data/dopamine-calcium-saramatias-uchida/VarMag_SM104_20191103_general_format_processed_kernellength60_kernelnum5_trainready.pt", # 6 neurons, 252 trials
7 |   ] # must be a list of datasets
8 |   data_folder: None
9 |   test_data_path: None
10 |   #################################
11 |   model_distribution: "gaussian" # data distribution: gaussian, binomial, poisson
12 |   share_kernels_among_neurons: False # set true to share kernels among neurons
13 |   #################################
14 |   # kernel (dictionary)
15 |   kernel_normalize: True # True: l2-norm of kernels is set to one after each update
16 |   kernel_nonneg: True # True: project kernels into non-negative values
17 |   kernel_nonneg_indicator: [0, 0, 0, 1, 1] # 0 for +-, 1 for +
18 |   kernel_num: 5 # number of kernels to learn
19 |   kernel_length: 60 # number of samples for kernel in time
20 |   kernel_stride: 1 # default 1, convolution stride
21 |   kernel_init_smoother: False # flag to init kernels to be smooth
22 |   kernel_init_smoother_sigma: 0.2 # sigma of the gaussian kernel for kernel_init_smoother
23 |   kernel_smoother: True # flag to apply smoother to the kernel during training
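  # (When kernel_smoother is on, training adds a smoothness penalty on the
  # kernels weighted by kernel_smoother_penalty_weight below; the exact penalty
  # form is defined in the training code, not in this config.)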
24 |   kernel_smoother_penalty_weight: 0.0005 # penalty weight to apply for kernel smoother
25 |   kernel_initialization: None # None, or a data path
26 |   kernel_initialization_needs_adjustment_of_time_bin_resolution: False
27 |   #################################
28 |   # code (representation)
29 |   code_nonneg: [2, 2, 1, 1, -1] # apply sign constraint on the code. 1 for pos, -1 for neg, 2 for twosided
30 |   code_sparse_regularization: 0 # apply sparse (lambda l1-norm) regularization on the code
31 |   code_sparse_regularization_decay: 1 # apply decay factor to lambda at every encoder iteration
32 |   code_group_neural_firings_regularization: 0 # if > 0, then it applies grouping across neurons
33 |   code_q_regularization: True # set True to apply Q-regularization on the norm of the code
34 |   code_q_regularization_matrix: None # The matrix of relations between the codes (use the path to load)
35 |   code_q_regularization_matrix_path: "../data/dopamine-calcium-saramatias-uchida/code_q_regularization_matrix.pt"
36 |   code_q_regularization_period: 1 # the period to apply Q-regularization in encoder iterations
37 |   code_q_regularization_scale: 2.5 # scale factor in front of the Q-regularization term
38 |   code_q_regularization_norm_type: 2 # the norm order used for the Q-regularization term
39 |   code_supp: True # True: apply known event indices (supp) into code x
40 |   code_topk: False # True: keep only top k indices in each kernel code non-zero (this is greedy)
41 |   code_topk_sparse: None # number of top k non-zero entries in each code kernel
42 |   code_topk_period: None # period on encoder iteration to apply topk
43 |   code_l1loss_bp: False # True: to include l1-norm of the code in the loss during training
44 |   code_l1loss_bp_penalty_weight: 0 # amount of sparse regularization of the code with bp during training
45 |   #################################
46 |   est_baseline_activity: True # True: estimate the baseline activity along with the code in the encoder
47 |   poisson_stability_name: None # type of non-linearity to use on poisson case for encoder stability
48 |   poisson_peak: 1 # For ELU "poisson_stability_name", this peak must be set to a value
49 |   #################################
50 |   # unrolling parameters
51 |   unrolling_num: 100 # number of unrolling iterations in the encoder
52 |   unrolling_mode: "fista" # ista or fista encoder
53 |   unrolling_alpha: 0.1 # alpha step size in unrolling
54 |   unrolling_prox: "shrinkage" # type of proximal operator (shrinkage, threshold)
55 |   unrolling_threshold: None # must be set to a value if unrolling_prox is "threshold"
56 |   #################################
57 |   # training related
58 |   # default optimizer is ADAM.
59 |   optimizer_lr: 1e-2 # learning rate for training the model (learning the kernels)
60 |   optimizer_lr_step: 1000 # number of steps (updates) after which the lr will decay
61 |   optimizer_lr_decay: 1 # decay factor for learning rate
62 |   optimizer_adam_eps: 1e-3 # eps parameter of adam optimizer
63 |   optimizer_adam_weight_decay: 0 # weight_decay parameter for adam optimizer
64 |   #
65 |   backward_gradient_decsent: "bprop" # type of backward gradient update (bprop, truncated_bprop)
66 |   backward_truncated_bprop_itr: 10 # must be set for truncated_bprop
67 |   #
68 |   train_num_epochs: 150 # number of epochs for training
69 |   train_data_shuffle: True # True: to shuffle dataset at every epoch for training
70 |   train_batch_size: 32 # batch size for training
71 |   train_num_workers: 4 # number of workers to load data
72 |   train_val_split: 1 # 1: use all for train.
percentage of data used to train, rest to be used for validation. 73 | train_with_fraction: 1 # 1 for all the data, or a fraction e.g. 0.1 74 | # 75 | enable_board: True 76 | log_info_epoch_period: 40 # period to push small info into the board 77 | log_model_epoch_period: 3000 # period to save model 78 | log_fig_epoch_period: 40 # period to push figures into the board 79 | tqdm_prints_disable: False # True: to disable prints of epoch training process 80 | tqdm_prints_inside_disable: True # True: to disable prints inside of epoch training process -------------------------------------------------------------------------------- /dunl/postprocess_scripts/plot_pcanmf_kernels_spiking.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bahareh Tolooshams 3 | 4 | plot code data 5 | 6 | :author: Bahareh Tolooshams 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | import pickle 13 | import argparse 14 | import matplotlib as mpl 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | def init_params(): 19 | parser = argparse.ArgumentParser(description=__doc__) 20 | 21 | parser.add_argument( 22 | "--res-path", 23 | type=str, 24 | help="res path", 25 | default="../results/dopaminespiking_25msbin_kernellength24_kernelnum3_codefree_kernel111_2023_07_14_12_37_30", 26 | ) 27 | parser.add_argument( 28 | "--num-comp", 29 | type=int, 30 | help="number of components", 31 | default=2, 32 | ) 33 | parser.add_argument( 34 | "--color-list", 35 | type=list, 36 | help="color list", 37 | default=[ 38 | "blue", 39 | "red", 40 | ], 41 | ) 42 | parser.add_argument( 43 | "--figsize", 44 | type=tuple, 45 | help="figsize", 46 | default=(1.6, 2), 47 | ) 48 | 49 | args = parser.parse_args() 50 | params = vars(args) 51 | 52 | return params 53 | 54 | 55 | def main(): 56 | print("Predict.") 57 | 58 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 59 | print("device is", device) 60 | 61 | # init parameters -------------------------------------------------------# 62 | print("init parameters.") 63 | params_init = init_params() 64 | 65 | # take parameters from the result path 66 | params = pickle.load( 67 | open(os.path.join(params_init["res_path"], "params.pickle"), "rb") 68 | ) 69 | for key in params_init.keys(): 70 | params[key] = params_init[key] 71 | 72 | if params["data_path"] == "": 73 | data_folder = params["data_folder"] 74 | filename_list = os.listdir(data_folder) 75 | data_path_list = [ 76 | f"{data_folder}/{x}" for x in filename_list if "trainready.pt" in x 77 | ] 78 | else: 79 | data_path_list = params["data_path"] 80 | 81 | print("There {} dataset in the folder.".format(len(data_path_list))) 82 | 83 | # set time bin resolution -----------------------------------------------# 84 | data_dict = torch.load(data_path_list[0]) 85 | params["time_bin_resolution"] = data_dict["time_bin_resolution"] 86 | 87 | # create folders -------------------------------------------------------# 88 | out_path = os.path.join( 89 | params["res_path"], 90 | "figures", 91 | ) 92 | if not os.path.exists(out_path): 93 | os.makedirs(out_path) 94 | 95 | postprocess_path = os.path.join( 96 | params["res_path"], 97 | "postprocess", 98 | ) 99 | 100 | # load data -------------------------------------------------------# 101 | pca_transform = pickle.load( 102 | open( 103 | os.path.join( 104 | postprocess_path, "pca_transform_{}.pkl".format(params["num_comp"]) 105 | ), 106 | "rb", 107 | ) 108 | ) 109 | nmf_transform = pickle.load( 110 | open( 111 | 
os.path.join( 112 | postprocess_path, "nmf_transform_{}.pkl".format(params["num_comp"]) 113 | ), 114 | "rb", 115 | ) 116 | ) 117 | 118 | pca_W = pca_transform.components_.T 119 | nmf_W = nmf_transform.components_.T 120 | 121 | plot_pca_nmf(pca_W, params, out_path, "pca_{}".format(params["num_comp"])) 122 | plot_pca_nmf(nmf_W, params, out_path, "nmf_{}".format(params["num_comp"])) 123 | 124 | print(f"plotting of kernels is done. plots are saved at {out_path}") 125 | 126 | 127 | def plot_pca_nmf(W, params, out_path, name): 128 | axes_fontsize = 10 129 | legend_fontsize = 8 130 | tick_fontsize = 10 131 | title_fontsize = 10 132 | 133 | # upadte plot parameters 134 | # style 135 | mpl.rcParams.update( 136 | { 137 | "pgf.texsystem": "pdflatex", 138 | "text.usetex": True, 139 | "axes.labelsize": axes_fontsize, 140 | "axes.titlesize": title_fontsize, 141 | "legend.fontsize": legend_fontsize, 142 | "xtick.labelsize": tick_fontsize, 143 | "ytick.labelsize": tick_fontsize, 144 | "text.latex.preamble": r"\usepackage{bm}", 145 | "axes.unicode_minus": False, 146 | } 147 | ) 148 | 149 | # plot -------------------------------------------------------# 150 | fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=params["figsize"]) 151 | 152 | ax.tick_params(axis="x", direction="out") 153 | ax.tick_params(axis="y", direction="out") 154 | ax.spines["right"].set_visible(False) 155 | ax.spines["top"].set_visible(False) 156 | 157 | ax.axhline(0, color="gray", lw=0.3) 158 | 159 | t = np.linspace( 160 | 0, 161 | params["kernel_length"] * params["time_bin_resolution"], 162 | params["kernel_length"], 163 | ) 164 | 165 | plt.subplot(1, 1, 1) 166 | for ctr in range(W.shape[1]): 167 | plt.plot(t, W[:, ctr], color=params["color_list"][ctr]) 168 | xtic = ( 169 | np.array([0, 0.5, 1]) 170 | * params["kernel_length"] 171 | * params["time_bin_resolution"] 172 | ) 173 | xtic = [int(x) for x in xtic] 174 | plt.xticks(xtic, xtic) 175 | plt.xlabel("Time [ms]", labelpad=0) 176 | 177 | fig.tight_layout(pad=0.8, w_pad=0.7, h_pad=0.5) 178 | plt.savefig( 179 | os.path.join(out_path, "kernels_{}.svg".format(name)), 180 | bbox_inches="tight", 181 | pad_inches=0.02, 182 | ) 183 | plt.close() 184 | 185 | 186 | if __name__ == "__main__": 187 | main() 188 | -------------------------------------------------------------------------------- /dunl/postprocess_scripts/plot_kernels_dopamine_spiking_eshel_uchida.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bahareh Tolooshams 3 | 4 | plot rec data 5 | 6 | :author: Bahareh Tolooshams 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | import pickle 13 | import argparse 14 | import matplotlib as mpl 15 | import matplotlib.pyplot as plt 16 | 17 | import sys 18 | 19 | sys.path.append("../dunl/") 20 | 21 | import utils 22 | 23 | 24 | def init_params(): 25 | parser = argparse.ArgumentParser(description=__doc__) 26 | 27 | parser.add_argument( 28 | "--res-path", 29 | type=str, 30 | help="res path", 31 | # default="../results/dopaminespiking_25msbin_kernellength24_kernelnum3_codefree_kernel111_limiteddata0p1_smoothkernel_0p0005_2023_08_12_22_09_40", 32 | default="../results/dopaminespiking_25msbin_kernellength24_kernelnum3_codefree_kernel111_2023_07_14_12_37_30", 33 | ) 34 | parser.add_argument( 35 | "--color-list", 36 | type=list, 37 | help="color list", 38 | default=[ 39 | "orange", 40 | "blue", 41 | "red", 42 | ], # cue exp, 2 rewards 43 | ) 44 | parser.add_argument( 45 | "--figsize", 46 | type=tuple, 47 
| help="figsize", 48 | default=(6, 2), 49 | ) 50 | parser.add_argument( 51 | "--swap-kernel", 52 | type=bool, 53 | help="bool to swap kernels", 54 | default=True, 55 | ) 56 | 57 | args = parser.parse_args() 58 | params = vars(args) 59 | 60 | return params 61 | 62 | 63 | def main(): 64 | print("Predict.") 65 | 66 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 67 | print("device is", device) 68 | 69 | # init parameters -------------------------------------------------------# 70 | print("init parameters.") 71 | params_init = init_params() 72 | 73 | # take parameters from the result path 74 | params = pickle.load( 75 | open(os.path.join(params_init["res_path"], "params.pickle"), "rb") 76 | ) 77 | for key in params_init.keys(): 78 | params[key] = params_init[key] 79 | 80 | if params["data_path"] == "": 81 | data_folder = params["data_folder"] 82 | filename_list = os.listdir(data_folder) 83 | data_path_list = [ 84 | f"{data_folder}/{x}" for x in filename_list if "trainready.pt" in x 85 | ] 86 | else: 87 | data_path_list = params["data_path"] 88 | 89 | print("There {} dataset in the folder.".format(len(data_path_list))) 90 | 91 | # set time bin resolution -----------------------------------------------# 92 | data_dict = torch.load(data_path_list[0]) 93 | params["time_bin_resolution"] = data_dict["time_bin_resolution"] 94 | 95 | # create folders -------------------------------------------------------# 96 | model_path = os.path.join( 97 | params["res_path"], 98 | "model", 99 | "model_final.pt", 100 | ) 101 | 102 | out_path = os.path.join( 103 | params["res_path"], 104 | "figures", 105 | ) 106 | if not os.path.exists(out_path): 107 | os.makedirs(out_path) 108 | 109 | # load model ------------------------------------------------------# 110 | net = torch.load(model_path, map_location=device) 111 | net.to(device) 112 | net.eval() 113 | 114 | kernels = net.get_param("H").clone().detach() 115 | if params["swap_kernel"]: 116 | kernels = utils.swap_kernel(kernels, 1, 2) 117 | kernels = np.squeeze(kernels.cpu().numpy()) 118 | 119 | plot_kernel(kernels, params, out_path) 120 | 121 | 122 | def plot_kernel(kernels, params, out_path): 123 | axes_fontsize = 15 124 | legend_fontsize = 8 125 | tick_fontsize = 15 126 | title_fontsize = 20 127 | fontfamily = "sans-serif" 128 | 129 | # upadte plot parameters 130 | # style 131 | mpl.rcParams.update( 132 | { 133 | "pgf.texsystem": "pdflatex", 134 | "text.usetex": True, 135 | "axes.labelsize": axes_fontsize, 136 | "axes.titlesize": title_fontsize, 137 | "legend.fontsize": legend_fontsize, 138 | "xtick.labelsize": tick_fontsize, 139 | "ytick.labelsize": tick_fontsize, 140 | "text.latex.preamble": r"\usepackage{bm}", 141 | "axes.unicode_minus": False, 142 | "font.family": fontfamily, 143 | } 144 | ) 145 | 146 | fig, axn = plt.subplots(1, 3, sharex=True, sharey=True, figsize=params["figsize"]) 147 | 148 | for ax in axn.flat: 149 | ax.tick_params(axis="x", direction="out") 150 | ax.tick_params(axis="y", direction="out") 151 | ax.spines["right"].set_visible(False) 152 | ax.spines["top"].set_visible(False) 153 | 154 | t = np.linspace( 155 | 0, 156 | params["kernel_length"] * params["time_bin_resolution"], 157 | params["kernel_length"], 158 | ) 159 | 160 | for ctr in range(params["kernel_num"]): 161 | plt.subplot(1, 3, ctr + 1) 162 | axn[ctr].axhline(0, color="gray", lw=0.3) 163 | 164 | plt.plot(t, kernels[ctr], color=params["color_list"][ctr]) 165 | 166 | if ctr == 0: 167 | plt.title(r"$\textbf{Cue}$") 168 | elif ctr == 1: 169 | 
plt.title(r"$\textbf{Reward\ I}$") 170 | else: 171 | plt.title(r"$\textbf{Reward\ II}$") 172 | xtic = ( 173 | np.array([0, 0.5, 1]) 174 | * params["kernel_length"] 175 | * params["time_bin_resolution"] 176 | ) 177 | xtic = [int(x) for x in xtic] 178 | plt.xticks(xtic, xtic) 179 | 180 | if ctr == 1: 181 | plt.xlabel("Time [ms]", labelpad=0) 182 | 183 | fig.tight_layout(pad=0.8, w_pad=0.7, h_pad=0.5) 184 | plt.savefig( 185 | os.path.join(out_path, "kernels.svg"), 186 | bbox_inches="tight", 187 | pad_inches=0.02, 188 | ) 189 | plt.close() 190 | 191 | print(f"plotting of kernels is done. plots are saved at {out_path}") 192 | 193 | 194 | if __name__ == "__main__": 195 | main() 196 | -------------------------------------------------------------------------------- /dunl/postprocess_scripts/plot_kernels_dopamine_spiking_eshel_uchida_vertical.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bahareh Tolooshams 3 | 4 | plot rec data 5 | 6 | :author: Bahareh Tolooshams 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | import pickle 13 | import argparse 14 | import matplotlib as mpl 15 | import matplotlib.pyplot as plt 16 | 17 | import sys 18 | 19 | sys.path.append("../dunl/") 20 | 21 | import utils 22 | 23 | 24 | def init_params(): 25 | parser = argparse.ArgumentParser(description=__doc__) 26 | 27 | parser.add_argument( 28 | "--res-path", 29 | type=str, 30 | help="res path", 31 | default="../results/dopaminespiking_25msbin_kernellength24_kernelnum3_codefree_kernel111_2023_07_14_12_37_30", 32 | ) 33 | parser.add_argument( 34 | "--color-list", 35 | type=list, 36 | help="color list", 37 | default=[ 38 | "orange", 39 | "blue", 40 | "red", 41 | ], # cue exp, 2 rewards 42 | ) 43 | parser.add_argument( 44 | "--figsize", 45 | type=tuple, 46 | help="figsize", 47 | default=(1.5, 4), 48 | ) 49 | parser.add_argument( 50 | "--swap-kernel", 51 | type=bool, 52 | help="bool to swap kernels", 53 | default=True, 54 | ) 55 | 56 | args = parser.parse_args() 57 | params = vars(args) 58 | 59 | return params 60 | 61 | 62 | def main(): 63 | print("Predict.") 64 | 65 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 66 | print("device is", device) 67 | 68 | # init parameters -------------------------------------------------------# 69 | print("init parameters.") 70 | params_init = init_params() 71 | 72 | # take parameters from the result path 73 | params = pickle.load( 74 | open(os.path.join(params_init["res_path"], "params.pickle"), "rb") 75 | ) 76 | for key in params_init.keys(): 77 | params[key] = params_init[key] 78 | 79 | if params["data_path"] == "": 80 | data_folder = params["data_folder"] 81 | filename_list = os.listdir(data_folder) 82 | data_path_list = [ 83 | f"{data_folder}/{x}" for x in filename_list if "trainready.pt" in x 84 | ] 85 | else: 86 | data_path_list = params["data_path"] 87 | 88 | print("There {} dataset in the folder.".format(len(data_path_list))) 89 | 90 | # set time bin resolution -----------------------------------------------# 91 | data_dict = torch.load(data_path_list[0]) 92 | params["time_bin_resolution"] = data_dict["time_bin_resolution"] 93 | 94 | # create folders -------------------------------------------------------# 95 | model_path = os.path.join( 96 | params["res_path"], 97 | "model", 98 | "model_final.pt", 99 | ) 100 | 101 | out_path = os.path.join( 102 | params["res_path"], 103 | "figures", 104 | ) 105 | if not os.path.exists(out_path): 106 | os.makedirs(out_path) 107 | 108 | 
# load model ------------------------------------------------------# 109 | net = torch.load(model_path, map_location=device) 110 | net.to(device) 111 | net.eval() 112 | 113 | kernels = net.get_param("H").clone().detach() 114 | if params["swap_kernel"]: 115 | kernels = utils.swap_kernel(kernels, 1, 2) 116 | kernels = np.squeeze(kernels.cpu().numpy()) 117 | 118 | plot_kernel(kernels, params, out_path) 119 | 120 | 121 | def plot_kernel(kernels, params, out_path): 122 | axes_fontsize = 15 123 | legend_fontsize = 8 124 | tick_fontsize = 15 125 | title_fontsize = 15 126 | fontfamily = "sans-serif" 127 | 128 | # upadte plot parameters 129 | # style 130 | mpl.rcParams.update( 131 | { 132 | "pgf.texsystem": "pdflatex", 133 | "text.usetex": True, 134 | "axes.labelsize": axes_fontsize, 135 | "axes.titlesize": title_fontsize, 136 | "legend.fontsize": legend_fontsize, 137 | "xtick.labelsize": tick_fontsize, 138 | "ytick.labelsize": tick_fontsize, 139 | "text.latex.preamble": r"\usepackage{bm}", 140 | "axes.unicode_minus": False, 141 | "font.family": fontfamily, 142 | } 143 | ) 144 | 145 | fig, axn = plt.subplots(3, 1, sharex=True, sharey=True, figsize=params["figsize"]) 146 | 147 | for ax in axn.flat: 148 | ax.tick_params(axis="x", direction="out") 149 | ax.tick_params(axis="y", direction="out") 150 | ax.spines["right"].set_visible(False) 151 | ax.spines["top"].set_visible(False) 152 | 153 | t = np.linspace( 154 | 0, 155 | params["kernel_length"] * params["time_bin_resolution"], 156 | params["kernel_length"], 157 | ) 158 | 159 | for ctr in range(params["kernel_num"]): 160 | plt.subplot(3, 1, ctr + 1) 161 | axn[ctr].axhline(0, color="gray", lw=0.3) 162 | 163 | plt.plot(t, kernels[ctr], color=params["color_list"][ctr], lw=2.5) 164 | 165 | # if ctr == 0: 166 | # plt.title(r"$\textbf{Cue}$") 167 | # elif ctr == 1: 168 | # plt.title(r"$\textbf{Reward\ I}$") 169 | # else: 170 | # plt.title(r"$\textbf{Reward\ II}$") 171 | if ctr == 0: 172 | plt.title("Cue") 173 | elif ctr == 1: 174 | plt.title("Reward I") 175 | else: 176 | plt.title("Reward II") 177 | xtic = ( 178 | np.array([0, 0.5, 1]) 179 | * params["kernel_length"] 180 | * params["time_bin_resolution"] 181 | ) 182 | xtic = [int(x) for x in xtic] 183 | plt.xticks(xtic, xtic) 184 | plt.yticks([]) 185 | 186 | if ctr == 2: 187 | plt.xlabel("Time (ms)", labelpad=0) 188 | 189 | fig.tight_layout(pad=0.8, w_pad=0.7, h_pad=0.5) 190 | plt.savefig( 191 | os.path.join(out_path, "kernels_vertical.svg"), 192 | bbox_inches="tight", 193 | pad_inches=0.02, 194 | ) 195 | plt.close() 196 | 197 | print(f"plotting of kernels is done. 
plots are saved at {out_path}") 198 | 199 | 200 | if __name__ == "__main__": 201 | main() 202 | -------------------------------------------------------------------------------- /dunl/preprocess_scripts/preprocess_data_dopamine_spiking_eshel_into_neuralgml_matlab.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bahareh Tolooshams 3 | 4 | preprocess data for lfads 5 | 6 | :author: Bahareh Tolooshams 7 | """ 8 | 9 | import torch 10 | import numpy as np 11 | from tqdm import tqdm 12 | import os 13 | import argparse 14 | import scipy 15 | 16 | import sys 17 | 18 | sys.path.append("../dunl/") 19 | 20 | import datasetloader 21 | 22 | 23 | def init_params(): 24 | parser = argparse.ArgumentParser(description=__doc__) 25 | 26 | parser.add_argument( 27 | "--data-folder", 28 | type=str, 29 | help="data path", 30 | default="../data/dopamine-spiking-eshel-uchida", 31 | ) 32 | parser.add_argument( 33 | "--out-path", 34 | type=str, 35 | help="out path", 36 | default="../data/dopamine-spiking-eshel-uchida/neuralgml_matlab", 37 | ) 38 | parser.add_argument( 39 | "--batch-size", 40 | type=int, 41 | help="batch size", 42 | default=800, 43 | ) 44 | parser.add_argument( 45 | "--num-workers", 46 | type=int, 47 | help="number of workers for dataloader", 48 | default=4, 49 | ) 50 | parser.add_argument( 51 | "--reward-dur", 52 | type=int, 53 | help="duration after onset in samples", 54 | default=600, 55 | ) 56 | parser.add_argument( 57 | "--reward-amount-list", 58 | type=list, 59 | help="reward amount list", 60 | default=[0.1, 0.3, 1.2, 2.5, 5.0, 10.0, 20.0], 61 | ) 62 | args = parser.parse_args() 63 | params = vars(args) 64 | 65 | return params 66 | 67 | 68 | def main(): 69 | # init parameters -------------------------------------------------------# 70 | print("init parameters.") 71 | params = init_params() 72 | 73 | data_folder = params["data_folder"] 74 | filename_list = os.listdir(data_folder) 75 | data_path_list = [ 76 | f"{data_folder}/{x}" for x in filename_list if "trainready.pt" in x 77 | ] 78 | 79 | print("There {} dataset in the folder.".format(len(data_path_list))) 80 | 81 | # create datasets -------------------------------------------------------# 82 | dataset_list = list() 83 | dataloader_list = list() 84 | for data_path_cur in data_path_list: 85 | print(data_path_cur) 86 | dataset = datasetloader.DUNLdatasetwithRaster(data_path_cur) 87 | 88 | dataset_list.append(dataset) 89 | 90 | dataloader = torch.utils.data.DataLoader( 91 | dataset, 92 | shuffle=False, 93 | batch_size=params["batch_size"], 94 | num_workers=params["num_workers"], 95 | ) 96 | dataloader_list.append(dataloader) 97 | 98 | out_path = params["out_path"] 99 | if not os.path.exists(out_path): 100 | os.makedirs(out_path) 101 | 102 | # go over data -------------------------------------------------------# 103 | y_surprise = list() 104 | y_expected = list() 105 | 106 | y_surprise_rew_amount = list() 107 | y_expected_rew_amount = list() 108 | 109 | neuron_surprise = list() 110 | neuron_expected = list() 111 | 112 | neuron_ctr = -1 113 | for dataloader in dataloader_list: 114 | neuron_ctr += 1 115 | 116 | for idx, (y, x, a, label_int, raster) in tqdm( 117 | enumerate(dataloader), disable=True 118 | ): 119 | label = label_int.clone() 120 | tmp_ctr = 0 121 | for reward in params["reward_amount_list"]: 122 | tmp_ctr += 1 123 | label[label == tmp_ctr] = reward 124 | 125 | # send data to device (cpu or gpu) 126 | 127 | for i in range(y.shape[0]): 128 | yi = raster[i] 129 | xi = x[i] 
130 | labeli = label[i] 131 | 132 | # reward presence 133 | cue_flag = torch.sum(torch.abs(xi[0]), dim=-1).item() 134 | 135 | reward_onset = ( 136 | np.where(xi[1] > 0)[-1][0] 137 | ) * dataset.time_bin_resolution 138 | 139 | if cue_flag: 140 | # expected trial 141 | y_expected_curr = yi[ 142 | :, 143 | reward_onset : reward_onset + params["reward_dur"], 144 | ] 145 | y_expected.append(y_expected_curr) 146 | y_expected_rew_amount.append(labeli) 147 | neuron_expected.append(torch.tensor(neuron_ctr)) 148 | else: 149 | # surprise trial 150 | y_surprise_curr = yi[ 151 | :, 152 | reward_onset : reward_onset + params["reward_dur"], 153 | ] 154 | y_surprise.append(y_surprise_curr) 155 | y_surprise_rew_amount.append(labeli) 156 | neuron_surprise.append(torch.tensor(neuron_ctr)) 157 | 158 | # stack after all data from all datasets 159 | y_expected = torch.stack(y_expected, dim=0) 160 | y_surprise = torch.stack(y_surprise, dim=0) 161 | y_expected_rew_amount = torch.stack(y_expected_rew_amount, dim=0) 162 | y_surprise_rew_amount = torch.stack(y_surprise_rew_amount, dim=0) 163 | neuron_expected = torch.stack(neuron_expected, dim=0) 164 | neuron_surprise = torch.stack(neuron_surprise, dim=0) 165 | 166 | y = torch.cat([y_expected, y_surprise], dim=0) 167 | # (reward amount is label) surprise has negative 168 | label = torch.cat([y_expected_rew_amount, -y_surprise_rew_amount], dim=0) 169 | neuron = torch.cat([neuron_expected, neuron_surprise], dim=0) 170 | 171 | # (trial, time_window) 172 | y = torch.squeeze(y, dim=1) 173 | y = np.array(y.detach().cpu().numpy(), dtype=int) 174 | label = label.detach().cpu().numpy() 175 | neuron = neuron.detach().cpu().numpy() 176 | 177 | print("saving data!") 178 | scipy.io.savemat( 179 | os.path.join(params["out_path"], f"dopamine_eshel_reward_onset.mat"), 180 | {"y": y, "label": label, "neuron": neuron}, 181 | ) 182 | 183 | 184 | if __name__ == "__main__": 185 | main() 186 | --------------------------------------------------------------------------------
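For reference, a minimal sketch (not a file in the repository) of loading the exported .mat back in Python to sanity-check shapes before fitting the neural GLM in MATLAB; it assumes only scipy and the default --out-path above:

import scipy.io

mat = scipy.io.loadmat(
    "../data/dopamine-spiking-eshel-uchida/neuralgml_matlab/dopamine_eshel_reward_onset.mat"
)
y = mat["y"]  # (trials, reward_dur) integer spike counts after reward onset
label = mat["label"]  # reward amount per trial; negative entries mark surprise trials
neuron = mat["neuron"]  # index of the source dataset for each trial
print(y.shape, label.shape, neuron.shape)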