├── data └── trained_models │ └── 0 │ ├── model.ckpt.meta │ ├── model.ckpt.index │ ├── task_ckpts │ ├── model.ckpt-0.meta │ ├── model.ckpt-1.meta │ ├── model.ckpt-2.meta │ ├── model.ckpt-3.meta │ ├── model.ckpt-0.index │ ├── model.ckpt-1.index │ ├── model.ckpt-2.index │ ├── model.ckpt-3.index │ ├── model.ckpt-0.data-00000-of-00001 │ ├── model.ckpt-1.data-00000-of-00001 │ ├── model.ckpt-2.data-00000-of-00001 │ ├── model.ckpt-3.data-00000-of-00001 │ └── checkpoint │ ├── model.ckpt.data-00000-of-00001 │ ├── hp.json │ └── checkpoint ├── .gitignore ├── requirements-cpu.txt ├── requirements-gpu.txt ├── environment-cpu.yml ├── environment-gpu.yml ├── seq_tools.py ├── example_sequential_training.py ├── README.md ├── opt_tools.py ├── analyses └── run_fixed_point_finder.py ├── tools.py ├── train.py ├── network.py ├── tools_lnd.py └── task.py /data/trained_models/0/model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/model.ckpt.meta -------------------------------------------------------------------------------- /data/trained_models/0/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/model.ckpt.index -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | /data/example_run/* 3 | analyses/.ipynb_checkpoints/* 4 | *.DS_Store 5 | /data/trained_models/0/tf_fixed_pts_all_init/* 6 | -------------------------------------------------------------------------------- /data/trained_models/0/task_ckpts/model.ckpt-0.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/task_ckpts/model.ckpt-0.meta -------------------------------------------------------------------------------- /data/trained_models/0/task_ckpts/model.ckpt-1.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/task_ckpts/model.ckpt-1.meta -------------------------------------------------------------------------------- /data/trained_models/0/task_ckpts/model.ckpt-2.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/task_ckpts/model.ckpt-2.meta -------------------------------------------------------------------------------- /data/trained_models/0/task_ckpts/model.ckpt-3.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/task_ckpts/model.ckpt-3.meta -------------------------------------------------------------------------------- /data/trained_models/0/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /data/trained_models/0/task_ckpts/model.ckpt-0.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/task_ckpts/model.ckpt-0.index -------------------------------------------------------------------------------- /data/trained_models/0/task_ckpts/model.ckpt-1.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/task_ckpts/model.ckpt-1.index -------------------------------------------------------------------------------- /data/trained_models/0/task_ckpts/model.ckpt-2.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/task_ckpts/model.ckpt-2.index -------------------------------------------------------------------------------- /data/trained_models/0/task_ckpts/model.ckpt-3.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/task_ckpts/model.ckpt-3.index -------------------------------------------------------------------------------- /data/trained_models/0/task_ckpts/model.ckpt-0.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/task_ckpts/model.ckpt-0.data-00000-of-00001 -------------------------------------------------------------------------------- /data/trained_models/0/task_ckpts/model.ckpt-1.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/task_ckpts/model.ckpt-1.data-00000-of-00001 -------------------------------------------------------------------------------- /data/trained_models/0/task_ckpts/model.ckpt-2.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/task_ckpts/model.ckpt-2.data-00000-of-00001 -------------------------------------------------------------------------------- /data/trained_models/0/task_ckpts/model.ckpt-3.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LDlabs/seqMultiTaskRNN/HEAD/data/trained_models/0/task_ckpts/model.ckpt-3.data-00000-of-00001 -------------------------------------------------------------------------------- /requirements-cpu.txt: -------------------------------------------------------------------------------- 1 | numpy==1.14.5 2 | tensorflow==1.10.0 3 | scipy==1.1.0 4 | scikit-learn==0.20.0 5 | matplotlib==2.2.3 6 | PyYAML==3.13 7 | -e git://github.com/mattgolub/recurrent-whisperer.git@v1.2.0#egg=recurrent-whisperer 8 | -------------------------------------------------------------------------------- /requirements-gpu.txt: -------------------------------------------------------------------------------- 1 | numpy==1.14.5 2 | tensorflow-gpu==1.10.0 3 | scipy==1.1.0 4 | scikit-learn==0.20.0 5 | matplotlib==2.2.3 6 | PyYAML==3.13 7 | -e git://github.com/mattgolub/recurrent-whisperer.git@v1.2.0#egg=recurrent-whisperer 8 | -------------------------------------------------------------------------------- /environment-cpu.yml: -------------------------------------------------------------------------------- 1 | name: seqRNN 2 | channels: 
3 | - defaults 4 | dependencies: 5 | - python==2.7.18 6 | - numpy==1.14.5 7 | - tensorflow==1.10.0 8 | - scipy==1.1.0 9 | - scikit-learn==0.20.0 10 | - matplotlib==2.2.3 11 | - PyYAML==3.13 12 | - pip 13 | - pip: 14 | - git+https://github.com/mattgolub/recurrent-whisperer.git@v1.2.0#egg=recurrent-whisperer 15 | -------------------------------------------------------------------------------- /environment-gpu.yml: -------------------------------------------------------------------------------- 1 | name: seqRNN 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python==2.7.18 6 | - numpy==1.14.5 7 | - tensorflow-gpu==1.10.0 8 | - scipy==1.1.0 9 | - scikit-learn==0.20.0 10 | - matplotlib==2.2.3 11 | - PyYAML==3.13 12 | - pip 13 | - pip: 14 | - git+https://github.com/mattgolub/recurrent-whisperer.git@v1.2.0#egg=recurrent-whisperer 15 | -------------------------------------------------------------------------------- /data/trained_models/0/task_ckpts/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/task_ckpts/model.ckpt-3" 2 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/task_ckpts/model.ckpt-0" 3 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/task_ckpts/model.ckpt-1" 4 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/task_ckpts/model.ckpt-2" 5 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/task_ckpts/model.ckpt-3" 6 | -------------------------------------------------------------------------------- /data/trained_models/0/hp.json: -------------------------------------------------------------------------------- 1 | {"tau": 100, "l2_weight_init": 0, "w_rec_init": "randortho", "num_ring": 2, "target_cost": 0, "l2_h": 1e-07, "n_eachring": 2, "rule_start": 5, "sigma_x": 0.1, "batch_size_train": 64, "sequential_orthog": 0, "l1_h": 0, "n_rnn": 256, "n_input": 25, "n_rep": 256, "n_output": 3, "p_weight_train": null, "l1_weight": 0, "activation": "relu", "momentum": 0.9, "ruleset": "all", "loss_type": "lsq", "optimizer": "sgd_mom", "rules": ["fdgo", "fdanti", "delaygo", "delayanti"], "target_perf": 1.0, "learning_rate": 0.001, "delay_fac": 1, "c_intsyn": 0, "l2_weight": 1e-05, "alpha": 0.2, "dt": 20, "in_type": "normal", "use_separate_input": false, "rule_trains": ["fdgo", "fdanti", "delaygo", "delayanti"], "batch_size_test": 8192, "sigma_rec": 0.05, "n_rule": 20, "rnn_type": "LeakyRNN", "seed": 0, "save_name": "test", "max_steps": 12500000.0, "alpha_projection": 0.001, "ksi_intsyn": 0} -------------------------------------------------------------------------------- /seq_tools.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def compute_projection_matrices(Swh, Suh, Syy, Shh, alpha): 6 | # ------ rescaled eigenvalue approach ------ 7 | 8 | # get eigendecomposition of covariance matrices 9 | Dwh, Vwh = np.linalg.eig(Swh) 10 | Duh, Vuh = np.linalg.eig(Suh) 11 | Dyy, Vyy = np.linalg.eig(Syy) 12 | Dhh, Vhh = np.linalg.eig(Shh) 13 | 14 | # recompute eigenvalue 15 | Dwh_scaled = rescale_cov_evals(Dwh, alpha) 16 | Duh_scaled = rescale_cov_evals(Duh, alpha) 
17 | Dyy_scaled = rescale_cov_evals(Dyy, alpha) 18 | Dhh_scaled = rescale_cov_evals(Dhh, alpha) 19 | 20 | # reconstruct projection matrices with eigenvalues rescaled (and inverted: high-variance dims are zeroed out) 21 | P1 = tf.constant(np.matmul(np.matmul(Vwh, np.diag(Dwh_scaled)), Vwh.T), dtype=tf.float32) # output space W cov(Z) W' 22 | P2 = tf.constant(np.matmul(np.matmul(Vuh, np.diag(Duh_scaled)), Vuh.T), dtype=tf.float32) # input space cov(Z) 23 | P3 = tf.constant(np.matmul(np.matmul(Vyy, np.diag(Dyy_scaled)), Vyy.T), dtype=tf.float32) # readout space cov(Y) 24 | P4 = tf.constant(np.matmul(np.matmul(Vhh, np.diag(Dhh_scaled)), Vhh.T), dtype=tf.float32) # recurrent space cov(H) 25 | 26 | return P1, P2, P3, P4 27 | 28 | 29 | def rescale_cov_evals(evals, alpha): 30 | # ---- cut-off ---- 31 | fvals = alpha / (alpha + evals) 32 | 33 | return fvals 34 | 35 | 36 | def compute_covariance(x): 37 | # computes the (unbiased) covariance X * X.T / (N - 1) 38 | return np.matmul(x, x.T) / (x.shape[1] - 1) # or use biased estimate? 39 | -------------------------------------------------------------------------------- /example_sequential_training.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import os 4 | import train 5 | 6 | # set regularization parameters 7 | 8 | # 1. L2 activity reg 9 | l2_h_value = 1e-7 10 | 11 | # 2. L2 weight reg 12 | l2_w_value = 1e-5 13 | 14 | # 3. value for alpha parameter in gradient projection for continual learning 15 | alpha_value = 0.001 16 | 17 | # set total number of iterations 18 | # max_iter_steps = 7.5e6 # from paper 19 | max_iter_steps = 1e6 20 | 21 | # save some info on what regularizers were used in folder structure 22 | folder = 'example_run' 23 | seed_val = 0 24 | 25 | # which rules to train on 26 | tasklabel = 'goantiset' 27 | rule_set = ['fdgo', 'fdanti', 'delaygo', 'delayanti'] 28 | 29 | # set directory for saving model 30 | homedir = 'data' 31 | filedir = os.path.join(homedir, folder, tasklabel, 'proj_both', str(seed_val)) 32 | 33 | # perform model training 34 | train.train_sequential_orthogonalized(filedir, projGrad=True, applyProj='both', 35 | alpha=alpha_value, seed=seed_val, max_steps=max_iter_steps, ruleset='all', 36 | rule_trains=rule_set, 37 | hp={'activation': 'relu', 38 | 'l1_h': 0, 39 | 'l2_h': l2_h_value, 40 | 'l1_weight': 0, 41 | 'l2_weight': l2_w_value, 42 | 'l2_weight_init': 0, 43 | 'n_eachring': 2, 44 | 'n_output': 1 + 2, 45 | 'n_input': 1 + 2 * 2 + 20, 46 | 'delay_fac': 1, 47 | 'sigma_rec': 0.05, 48 | 'sigma_x': 0.1, 49 | 'optimizer': 'sgd_mom', 50 | 'momentum': 0.9, 51 | 'learning_rate': 0.001, 52 | 'use_separate_input': False}, 53 | display_step=1000, 54 | rich_output=False) 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # seqMultiTaskRNN 2 | This repository contains code to accompany the [paper](https://proceedings.neurips.cc/paper/2020/hash/a576eafbce762079f7d1f77fca1c5cc2-Abstract.html): 3 | 4 | Duncker, L.\*, Driscoll, L.\*, Shenoy, K. V., Sahani, M.\*\*, & Sussillo, D.\*\* (2020). Organizing recurrent network dynamics by task-computation to enable continual learning. Advances in Neural Information Processing Systems, 33.
5 | 6 | ``` 7 | @inproceedings{duncker+driscoll:2020:neurips, 8 | title={Organizing recurrent network dynamics by task-computation to enable continual learning}, 9 | author={Duncker, Lea and Driscoll, Laura N and Shenoy, Krishna V and Sahani, Maneesh and Sussillo, David}, 10 | booktitle={Advances in Neural Information Processing Systems}, 11 | volume={33}, 12 | year={2020} 13 | } 14 | ``` 15 | The repository is based on work in 16 | 17 | Yang, G. R., Joglekar, M. R., Song, H. F., Newsome, W. T., & Wang, X. J. (2019). Task representations in neural networks trained to perform many cognitive tasks. Nature Neuroscience, 22(2), 297-306. 18 | 19 | and on code which can be found [here](https://github.com/gyyang/multitask) (though the versions may have diverged). 20 | 21 | ## Installation 22 | The code runs on Python 2.7 and an older TensorFlow version (1.10). After cloning the repository, you can create a virtual environment and install the requirements using 23 | 24 | ``` 25 | virtualenv -p /usr/bin/python2.7 seqRNN 26 | source seqRNN/bin/activate 27 | pip install -r requirements-cpu.txt 28 | ``` 29 | 30 | Alternatively, to create an environment with conda: 31 | 32 | ``` 33 | conda env create -f environment-cpu.yml 34 | conda activate seqRNN 35 | ``` 36 | 37 | Note that this installs TensorFlow to run on your CPU. If you'd rather run TensorFlow on a GPU, use `requirements-gpu.txt` or `environment-gpu.yml` instead. 38 | 39 | ## Examples 40 | An example script for sequentially training an RNN on the task set from the paper using our continual learning approach is provided in `example_sequential_training.py`. 41 | 42 | 43 | The folder `data/trained_models/` contains an example trained network. `analyses/demos.ipynb` contains some examples to reproduce analyses from the paper. 44 | 45 | Some of the analyses rely on running [FixedPointFinder](https://github.com/mattgolub/fixed-point-finder) on the trained RNN. 46 | You need to [install FixedPointFinder](https://github.com/mattgolub/fixed-point-finder) and then specify the relevant Python path in `analyses/run_fixed_point_finder.py` by editing the line 47 | 48 | ``` 49 | PATH_TO_FIXED_POINT_FINDER = '/path/to/your/directory/fixed-point-finder/' 50 | ``` 51 | to match the location of your local FixedPointFinder copy.
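
A typical workflow (assuming the `seqRNN` environment is active, the commands are run from the repository root, and `PATH_TO_FIXED_POINT_FINDER` has been edited as described above) could look like:

```
# sequential training on the go/anti task set; with the script's default settings,
# model checkpoints are written under data/example_run/goantiset/proj_both/0/
python example_sequential_training.py

# fixed point analysis of the example trained network; results are saved as .npz
# files under data/trained_models/0/tf_fixed_pts_all_init/<rule>/
cd analyses
python run_fixed_point_finder.py
```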
52 | -------------------------------------------------------------------------------- /opt_tools.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class AdamOptimizer_withProjection(tf.train.Optimizer): 5 | """ 6 | implements modified version of adam optimizer 7 | """ 8 | 9 | def __init__(self, learning_rate=0.001, 10 | beta1=0.9, 11 | beta2=0.999, 12 | epsilon=1e-8): 13 | 14 | self.learning_rate = learning_rate 15 | # adam parameters 16 | self.beta1 = beta1 17 | self.beta2 = beta2 18 | self.epsilon = epsilon 19 | 20 | self.m = {} 21 | self.u = {} 22 | self.t = tf.Variable(0.0, trainable=False) 23 | 24 | for v in tf.trainable_variables(): 25 | self.m[v] = tf.Variable(tf.zeros(tf.shape(v.initial_value)), trainable=False) 26 | self.u[v] = tf.Variable(tf.zeros(tf.shape(v.initial_value)), trainable=False) 27 | 28 | def apply_gradients(self, gvs, P1, P2, P3, P4, taskNumber): 29 | 30 | t = self.t.assign_add(1.0) 31 | 32 | if taskNumber == 0: 33 | doProj = False 34 | else: 35 | doProj = True 36 | 37 | update_ops = [] 38 | for (g, v) in gvs: 39 | m = self.m[v].assign(self.beta1 * self.m[v] + (1 - self.beta1) * g) 40 | u = self.u[v].assign(self.beta2 * self.u[v] + (1 - self.beta2) * g * g) 41 | 42 | m_hat = m / (1 - tf.pow(self.beta1, t)) 43 | u_hat = u / (1 - tf.pow(self.beta2, t)) 44 | 45 | update = -self.learning_rate * m_hat / (tf.sqrt(u_hat) + self.epsilon) 46 | 47 | # projections are specific to recurrent or readout matrices, so check for name 48 | if doProj: 49 | if 'rnn/leaky_rnn_cell/kernel:0' in v.name: 50 | # continual learning correction for recurrent/input weight update 51 | update_proj = tf.matmul(tf.matmul(P2, update), P1) 52 | elif 'output/weights:0' in v.name: 53 | # continual learning correction for readout weight update 54 | update_proj = tf.matmul(tf.matmul(P4, update), P3) 55 | # update_proj = tf.matmul(P1, update) 56 | 57 | else: 58 | update_proj = update 59 | 60 | update_ops.append(v.assign_add(update_proj)) 61 | 62 | return tf.group(*update_ops) 63 | 64 | 65 | class GradientDescentOptimizer_withProjection(tf.train.Optimizer): 66 | """ 67 | implements modified version of SGD optimizer 68 | """ 69 | 70 | def __init__(self, learning_rate=0.001): 71 | 72 | self.learning_rate = learning_rate 73 | 74 | def apply_gradients(self, gvs, P1, P2, P3, P4, taskNumber): 75 | 76 | if taskNumber == 0: 77 | doProj = False 78 | else: 79 | doProj = True 80 | 81 | update_ops = [] 82 | for (g, v) in gvs: 83 | 84 | update = -self.learning_rate * g 85 | 86 | if doProj: 87 | if 'rnn/leaky_rnn_cell/kernel:0' in v.name: 88 | # continual learning correction for recurrent/input weight update 89 | update_proj = tf.matmul(tf.matmul(P2, update), P1) 90 | elif 'output/weights:0' in v.name: 91 | # continual learning correction for readout weight update 92 | update_proj = tf.matmul(tf.matmul(P4, update), P3) 93 | # update_proj = tf.matmul(P1, update) 94 | else: 95 | update_proj = update 96 | 97 | update_ops.append(v.assign_add(update_proj)) 98 | 99 | return tf.group(*update_ops) 100 | 101 | 102 | class MomentumOptimizer_withProjection(tf.train.Optimizer): 103 | """ 104 | implements modified version of SGD optimizer 105 | """ 106 | 107 | def __init__(self, learning_rate=0.001, 108 | momentum=0.1): 109 | 110 | self.learning_rate = learning_rate 111 | self.momentum = momentum 112 | 113 | self.m = {} 114 | for v in tf.trainable_variables(): 115 | self.m[v] = tf.Variable(tf.zeros(tf.shape(v.initial_value)), trainable=False) 116 | 117 | def 
apply_gradients(self, gvs, P1, P2, P3, P4, taskNumber): 118 | 119 | if taskNumber == 0: 120 | doProj = False 121 | else: 122 | doProj = True 123 | 124 | update_ops = [] 125 | for (g, v) in gvs: 126 | self.m[v] = self.momentum * self.m[v] + g 127 | 128 | update = -self.learning_rate * self.m[v] 129 | 130 | if doProj: 131 | if 'rnn/leaky_rnn_cell/kernel:0' in v.name: 132 | # continual learning correction for recurrent/input weight update 133 | update_proj = tf.matmul(tf.matmul(P2, update), P1) 134 | elif 'output/weights:0' in v.name: 135 | # continual learning correction for readout weight update 136 | # update_proj = tf.matmul(P1, update) 137 | update_proj = tf.matmul(tf.matmul(P4, update), P3) 138 | else: 139 | update_proj = update 140 | 141 | update_ops.append(v.assign_add(update_proj)) 142 | 143 | return tf.group(*update_ops) 144 | -------------------------------------------------------------------------------- /analyses/run_fixed_point_finder.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import os 3 | import pdb 4 | import numpy as np 5 | import numpy.random as npr 6 | import tensorflow as tf 7 | import sys 8 | import getpass 9 | 10 | ################################################################# 11 | # Setup path to data and code 12 | ################################################################## 13 | 14 | PATH_NET = os.path.dirname(os.getcwd()) 15 | sys.path.insert(0, PATH_NET) 16 | 17 | # directory where example network is stored 18 | model_dir_all = os.path.join(PATH_NET, 'data', 'trained_models', '0') 19 | 20 | from task import generate_trials 21 | from network import FixedPoint_Model 22 | import tools 23 | 24 | from RecurrentWhisperer import RecurrentWhisperer 25 | 26 | ################################################################## 27 | # set up path to fixed point finder (EDIT THIS) 28 | ################################################################## 29 | 30 | PATH_TO_FIXED_POINT_FINDER = '/path/to/your/directory/fixed-point-finder/' 31 | 32 | # add fixed point finder to path 33 | sys.path.insert(0, PATH_TO_FIXED_POINT_FINDER) 34 | from FixedPointFinder import FixedPointFinder 35 | 36 | ################################################################## 37 | # run fixed point finder 38 | ################################################################## 39 | 40 | 41 | NOISE_SCALE = 0.05 # 0.5 # Standard deviation of noise added to initial states 42 | N_INITS = 1000 # The number of initial states to provide 43 | 44 | 45 | task_list = ['delaygo', 'delayanti', 'fdgo', 'fdanti'] 46 | 47 | ################################################################## 48 | 49 | 50 | def add_unique_to_inputs_list(dict_list, key, value): 51 | for d in range(len(dict_list)): 52 | if (dict_list.values()[d] == value).all(): 53 | return False, dict_list 54 | 55 | dict_list.update({key: value}) 56 | return True, dict_list 57 | 58 | 59 | def get_filename(trial, epoch, t): 60 | ind_stim_loc = 180 * trial.y_loc[-1, t] / np.pi 61 | filename = trial.epochs.keys()[epoch] + '_' + str(round(ind_stim_loc, 2)) 62 | 63 | return filename, ind_stim_loc 64 | 65 | 66 | for rule in task_list: 67 | model = FixedPoint_Model(model_dir_all) 68 | with tf.Session() as sess: 69 | model.restore() 70 | model._sigma = 0 71 | # get all connection weights and biases as tensorflow variables 72 | var_list = model.var_list 73 | # evaluate the parameters after training 74 | params = [sess.run(var) for var in 
var_list] 75 | # get hparams 76 | hparams = model.hp 77 | # create a trial 78 | trial = generate_trials(rule, hparams, mode='test', noise_on=False, batch_size=40) # get feed_dict 79 | feed_dict = tools.gen_feed_dict(model, trial, hparams) 80 | # run model 81 | h_tf, y_hat_tf = sess.run([model.h, model.y_hat], feed_dict=feed_dict) # (n_time, n_condition, n_neuron) 82 | 83 | ################################################################## 84 | # get shapes 85 | n_steps, n_trials, n_input_dim = np.shape(trial.x) 86 | n_rnn = np.shape(h_tf)[2] 87 | n_output = np.shape(y_hat_tf)[2] 88 | 89 | # Fixed point finder hyperparameters 90 | # See FixedPointFinder.py for detailed descriptions of available 91 | # hyperparameters. 92 | fpf_hps = {} 93 | alr_dict = ({'decrease_factor': .95, 'initial_rate': 1}) 94 | 95 | n_epochs = len(trial.epochs) 96 | for epoch in range(n_epochs): 97 | e_start = max([0, trial.epochs.values()[epoch][0]]) 98 | end_set = [n_steps, trial.epochs.values()[epoch][1]] 99 | e_end = min(x for x in end_set if x is not None) 100 | 101 | n_inputs = 0 102 | input_set = {str(n_inputs): np.zeros((1, n_input_dim))} 103 | 104 | for t in range(0, np.shape(h_tf)[1], 10): # Set which trials you want to find fixed points on 105 | 106 | inputs = np.squeeze(trial.x[e_start, t, :]) 107 | inputs = inputs[np.newaxis, :] 108 | inputs_big = inputs[np.newaxis, :] 109 | 110 | unique_input, input_set = add_unique_to_inputs_list(input_set, str(n_inputs), inputs) 111 | 112 | if unique_input: 113 | n_inputs += 1 114 | input_set[str(n_inputs)] = inputs 115 | 116 | fpf = [] 117 | fpf = FixedPointFinder(model.cell, sess, alr_hps=alr_dict, method='joint', verbose=True, **fpf_hps) # do_compute_input_jacobians = True , q_tol = 1e-1, do_q_tol = True 118 | 119 | example_predictions = {'state': np.transpose(h_tf, (1, 0, 2)), # [0:90,0:1,:] 120 | 'output': np.transpose(y_hat_tf, (1, 0, 2))} 121 | 122 | initial_states = fpf.sample_states(example_predictions['state'][:, :, :], # specify T inds removed e_start:e_end 123 | n_inits=N_INITS, 124 | noise_scale=NOISE_SCALE) 125 | # Run the fixed point finder 126 | unique_fps, all_fps = fpf.find_fixed_points(initial_states, inputs) 127 | 128 | if unique_fps.xstar.shape[0] > 0: 129 | 130 | all_fps = {} 131 | all_fps = {'xstar': unique_fps.xstar, 132 | # 'J_inputs':unique_fps.J_inputs, 133 | 'J_xstar': unique_fps.J_xstar, 134 | 'qstar': unique_fps.qstar, 135 | 'inputs': unique_fps.inputs, 136 | 'epoch_inds': range(e_start, e_end), 137 | 'noise_var': NOISE_SCALE, 138 | 'state_traj': example_predictions['state'], 139 | 'out_dir': 180 * trial.y_loc[-1, t] / np.pi} 140 | 141 | save_dir = os.path.join(model_dir_all, 'tf_fixed_pts_all_init', rule) 142 | filename, ind_stim_loc = get_filename(trial, epoch, t) 143 | 144 | if not os.path.exists(save_dir): 145 | os.makedirs(save_dir) 146 | np.savez(os.path.join(save_dir, filename + '.npz'), **all_fps) 147 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | 3 | import os 4 | import errno 5 | import six 6 | import json 7 | import pickle 8 | import numpy as np 9 | 10 | 11 | def gen_feed_dict(model, trial, hp): 12 | """Generate feed_dict for session run.""" 13 | if hp['in_type'] == 'normal': 14 | feed_dict = {model.x: trial.x, 15 | model.y: trial.y, 16 | model.c_mask: trial.c_mask} 17 | elif hp['in_type'] == 'multi': 18 | n_time, batch_size = trial.x.shape[:2] 19 | new_shape = [n_time, 
20 | batch_size, 21 | hp['rule_start']*hp['n_rule']] 22 | 23 | x = np.zeros(new_shape, dtype=np.float32) 24 | for i in range(batch_size): 25 | ind_rule = np.argmax(trial.x[0, i, hp['rule_start']:]) 26 | i_start = ind_rule*hp['rule_start'] 27 | x[:, i, i_start:i_start+hp['rule_start']] = \ 28 | trial.x[:, i, :hp['rule_start']] 29 | 30 | feed_dict = {model.x: x, 31 | model.y: trial.y, 32 | model.c_mask: trial.c_mask} 33 | else: 34 | raise ValueError() 35 | 36 | return feed_dict 37 | 38 | 39 | def _contain_model_file(model_dir): 40 | """Check if the directory contains model files.""" 41 | for f in os.listdir(model_dir): 42 | if 'model.ckpt' in f: 43 | return True 44 | return False 45 | 46 | 47 | def _valid_model_dirs(root_dir): 48 | """Get valid model directories given a root directory.""" 49 | return [x[0] for x in os.walk(root_dir) if _contain_model_file(x[0])] 50 | 51 | 52 | def valid_model_dirs(root_dir): 53 | """Get valid model directories given a root directory(s). 54 | 55 | Args: 56 | root_dir: str or list of strings 57 | """ 58 | if isinstance(root_dir, six.string_types): 59 | return _valid_model_dirs(root_dir) 60 | else: 61 | model_dirs = list() 62 | for d in root_dir: 63 | model_dirs.extend(_valid_model_dirs(d)) 64 | return model_dirs 65 | 66 | 67 | def load_log(model_dir): 68 | """Load the log file of model save_name""" 69 | fname = os.path.join(model_dir, 'log.json') 70 | if not os.path.isfile(fname): 71 | return None 72 | 73 | with open(fname, 'r') as f: 74 | log = json.load(f) 75 | return log 76 | 77 | 78 | def save_log(log): 79 | """Save the log file of model.""" 80 | model_dir = log['model_dir'] 81 | fname = os.path.join(model_dir, 'log.json') 82 | with open(fname, 'w') as f: 83 | json.dump(log, f) 84 | 85 | 86 | def load_hp(model_dir): 87 | """Load the hyper-parameter file of model save_name""" 88 | fname = os.path.join(model_dir, 'hp.json') 89 | if not os.path.isfile(fname): 90 | fname = os.path.join(model_dir, 'hparams.json') # backward compat 91 | if not os.path.isfile(fname): 92 | return None 93 | 94 | with open(fname, 'r') as f: 95 | hp = json.load(f) 96 | 97 | # Use a different seed aftering loading, 98 | # since loading is typically for analysis 99 | hp['rng'] = np.random.RandomState(hp['seed']+1000) 100 | return hp 101 | 102 | 103 | def save_hp(hp, model_dir): 104 | """Save the hyper-parameter file of model save_name""" 105 | hp_copy = hp.copy() 106 | hp_copy.pop('rng') # rng can not be serialized 107 | with open(os.path.join(model_dir, 'hp.json'), 'w') as f: 108 | json.dump(hp_copy, f) 109 | 110 | 111 | def load_pickle(file): 112 | try: 113 | with open(file, 'rb') as f: 114 | data = pickle.load(f) 115 | except UnicodeDecodeError as e: 116 | with open(file, 'rb') as f: 117 | data = pickle.load(f, encoding='latin1') 118 | except Exception as e: 119 | print('Unable to load data ', file, ':', e) 120 | raise 121 | return data 122 | 123 | 124 | def find_all_models(root_dir, hp_target): 125 | """Find all models that satisfy hyperparameters. 126 | 127 | Args: 128 | root_dir: root directory 129 | hp_target: dictionary of hyperparameters 130 | 131 | Returns: 132 | model_dirs: list of model directories 133 | """ 134 | dirs = valid_model_dirs(root_dir) 135 | 136 | model_dirs = list() 137 | for d in dirs: 138 | hp = load_hp(d) 139 | if all(hp[key] == val for key, val in hp_target.items()): 140 | model_dirs.append(d) 141 | 142 | return model_dirs 143 | 144 | 145 | def find_model(root_dir, hp_target, perf_min=None): 146 | """Find one model that satisfies hyperparameters. 
147 | 148 | Args: 149 | root_dir: root directory 150 | hp_target: dictionary of hyperparameters 151 | perf_min: float or None. If not None, minimum performance to be chosen 152 | 153 | Returns: 154 | d: model directory 155 | """ 156 | model_dirs = find_all_models(root_dir, hp_target) 157 | if perf_min is not None: 158 | model_dirs = select_by_perf(model_dirs, perf_min) 159 | 160 | if not model_dirs: 161 | # If list empty 162 | print('Model not found') 163 | return None, None 164 | 165 | d = model_dirs[0] 166 | hp = load_hp(d) 167 | 168 | log = load_log(d) 169 | # check if performance exceeds target 170 | if log['perf_min'][-1] < hp['target_perf']: 171 | print("""Warning: this network perform {:0.2f}, not reaching target 172 | performance {:0.2f}.""".format( 173 | log['perf_min'][-1], hp['target_perf'])) 174 | 175 | return d 176 | 177 | 178 | def select_by_perf(model_dirs, perf_min): 179 | """Select a list of models by a performance threshold.""" 180 | new_model_dirs = list() 181 | for model_dir in model_dirs: 182 | log = load_log(model_dir) 183 | # check if performance exceeds target 184 | if log['perf_min'][-1] > perf_min: 185 | new_model_dirs.append(model_dir) 186 | return new_model_dirs 187 | 188 | 189 | def mkdir_p(path): 190 | """ 191 | Portable mkdir -p 192 | 193 | """ 194 | try: 195 | os.makedirs(path) 196 | except OSError as e: 197 | if e.errno == errno.EEXIST and os.path.isdir(path): 198 | pass 199 | else: 200 | raise 201 | 202 | 203 | def gen_ortho_matrix(dim, rng=None): 204 | """Generate random orthogonal matrix 205 | Taken from scipy.stats.ortho_group 206 | Copied here from compatibilty with older versions of scipy 207 | """ 208 | H = np.eye(dim) 209 | for n in range(1, dim): 210 | if rng is None: 211 | x = np.random.normal(size=(dim-n+1,)) 212 | else: 213 | x = rng.normal(size=(dim-n+1,)) 214 | # random sign, 50/50, but chosen carefully to avoid roundoff error 215 | D = np.sign(x[0]) 216 | x[0] += D*np.sqrt((x*x).sum()) 217 | # Householder transformation 218 | Hx = -D*(np.eye(dim-n+1) - 2.*np.outer(x, x)/(x*x).sum()) 219 | mat = np.eye(dim) 220 | mat[n-1:, n-1:] = Hx 221 | H = np.dot(H, mat) 222 | return H 223 | -------------------------------------------------------------------------------- /data/trained_models/0/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/model.ckpt" 2 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-682939" 3 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-683939" 4 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-684939" 5 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-685939" 6 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-686939" 7 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-687939" 8 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-688939" 9 | 
all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-689939" 10 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-690939" 11 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-691939" 12 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-692939" 13 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-693939" 14 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-694939" 15 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-695939" 16 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-696939" 17 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-697939" 18 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-698939" 19 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-699939" 20 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-700939" 21 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-701939" 22 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-702939" 23 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-703939" 24 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-704939" 25 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-705939" 26 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-706939" 27 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-707939" 28 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-708939" 29 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-709939" 30 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-710939" 31 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-711939" 32 | all_model_checkpoint_paths: 
"/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-712939" 33 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-713939" 34 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-714939" 35 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-715939" 36 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-716939" 37 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-717939" 38 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-718939" 39 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-719939" 40 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-720939" 41 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-721939" 42 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-722939" 43 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-723939" 44 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-724939" 45 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-725939" 46 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-726939" 47 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-727939" 48 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-728939" 49 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-729939" 50 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-730939" 51 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-731939" 52 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-732939" 53 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-733939" 54 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-734939" 55 | all_model_checkpoint_paths: 
"/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-735939" 56 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-736939" 57 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-737939" 58 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-738939" 59 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-739939" 60 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-740939" 61 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-741939" 62 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-742939" 63 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-743939" 64 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-744939" 65 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-745939" 66 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-746939" 67 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-747939" 68 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-748939" 69 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-749939" 70 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-750939" 71 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-751939" 72 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-752939" 73 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-753939" 74 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-754939" 75 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-755939" 76 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-756939" 77 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-757939" 78 | all_model_checkpoint_paths: 
"/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-758939" 79 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-759939" 80 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-760939" 81 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-761939" 82 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-762939" 83 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-763939" 84 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-764939" 85 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-765939" 86 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-766939" 87 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-767939" 88 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-768939" 89 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-769939" 90 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-770939" 91 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-771939" 92 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-772939" 93 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-773939" 94 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-774939" 95 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-775939" 96 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-776939" 97 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-777939" 98 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-778939" 99 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-779939" 100 | all_model_checkpoint_paths: "/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/ckpts/model.ckpt-780939" 101 | all_model_checkpoint_paths: 
"/scratch/gpfs/lduncker/data/l2h_0.0000001_l2w_0.00001_alphaProj_0.001/goantiset/new_opt/1/0/model.ckpt" 102 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """Main training loop""" 2 | 3 | from __future__ import division 4 | 5 | import sys 6 | import time 7 | from collections import defaultdict 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | import task 13 | from task import generate_trials, generate_datasetTensors, datasetGeneratorFromTaskDef, defineDatasetFormat 14 | from network import Model, get_perf, Sequential_Model 15 | import tools 16 | from datetime import datetime as datetime 17 | from tensorflow.python.ops import array_ops 18 | import pdb 19 | from seq_tools import compute_projection_matrices, compute_covariance 20 | 21 | 22 | def get_default_hp(ruleset): 23 | '''Get a default hp. 24 | 25 | Useful for debugging. 26 | 27 | Returns: 28 | hp : a dictionary containing training hpuration 29 | ''' 30 | num_ring = task.get_num_ring(ruleset) 31 | n_rule = task.get_num_rule(ruleset) 32 | 33 | n_eachring = 2 34 | n_input, n_output = 1 + num_ring * n_eachring + n_rule, n_eachring + 1 35 | hp = { # factor to multiply delay periods during training 36 | 'delay_fac': 1, 37 | # batch size for training 38 | 'batch_size_train': 64, 39 | # batch_size for testing 40 | 'batch_size_test': 8192, # changed from 512 jan 8th 2019 41 | # n_reps for testing 42 | 'n_rep': 256, # changed from 16 jan 8th 2019 43 | # input type: normal, multi 44 | 'in_type': 'normal', 45 | # Type of RNNs: LeakyRNN, LeakyGRU, EILeakyGRU, GRU, LSTM 46 | 'rnn_type': 'LeakyRNN', 47 | # whether rule and stimulus inputs are represented separately 48 | 'use_separate_input': False, 49 | # Type of loss functions 50 | 'loss_type': 'lsq', 51 | # Optimizer 52 | 'optimizer': 'adam', 53 | # Type of activation runctions, relu, softplus, tanh, elu 54 | 'activation': 'relu', 55 | # Time constant (ms) 56 | 'tau': 100, 57 | # discretization time step (ms) 58 | 'dt': 20, 59 | # discretization time step/time constant 60 | 'alpha': 0.2, 61 | # recurrent noise 62 | 'sigma_rec': 0.05, 63 | # input noise 64 | 'sigma_x': 0.01, 65 | # leaky_rec weight initialization, diag, randortho, randgauss 66 | 'w_rec_init': 'randortho', 67 | # a default weak regularization prevents instability 68 | 'l1_h': 0, 69 | # l2 regularization on activity 70 | 'l2_h': 0, 71 | # l2 regularization on weight 72 | 'l1_weight': 0, 73 | # l2 regularization on weight 74 | 'l2_weight': 0, 75 | # orthogonalize separate task input, activity for sequential learning 76 | 'sequential_orthog': 0, 77 | # l2 regularization on deviation from initialization 78 | 'l2_weight_init': 0, 79 | # proportion of weights to train, None or float between (0, 1) 80 | 'p_weight_train': None, 81 | # Stopping performance 82 | 'target_perf': 1., 83 | # Stopping cost 84 | 'target_cost': 0, # basically off 85 | # number of units each ring 86 | 'n_eachring': n_eachring, 87 | # number of rings 88 | 'num_ring': num_ring, 89 | # number of rules 90 | 'n_rule': n_rule, 91 | # first input index for rule units 92 | 'rule_start': 1 + num_ring * n_eachring, 93 | # number of input units 94 | 'n_input': 1 + num_ring * n_eachring + n_rule, 95 | # number of output units 96 | 'n_output': n_eachring + 1, 97 | # number of recurrent units 98 | 'n_rnn': 256, 99 | # number of input units 100 | 'ruleset': ruleset, 101 | # name to save 102 | 'save_name': 'test', 103 | # learning rate 
104 | 'learning_rate': 0.001, 105 | # momentum for sgd_mom 106 | 'momentum': 0.1, 107 | # intelligent synapses parameters, tuple (c, ksi) 108 | 'c_intsyn': 0, 109 | 'ksi_intsyn': 0, 110 | } 111 | 112 | return hp 113 | 114 | 115 | def do_eval(sess, model, log, rule_train): 116 | """Do evaluation. 117 | 118 | Args: 119 | sess: tensorflow session 120 | model: Model class instance 121 | log: dictionary that stores the log 122 | rule_train: string or list of strings, the rules being trained 123 | """ 124 | hp = model.hp 125 | if not hasattr(rule_train, '__iter__'): 126 | rule_name_print = rule_train 127 | else: 128 | rule_name_print = ' & '.join(rule_train) 129 | 130 | print('Trial {:7d}'.format(log['trials'][-1]) + 131 | ' | Time {:0.2f} s'.format(log['times'][-1]) + 132 | ' | Now training ' + rule_name_print) 133 | 134 | # print(hp['rules']) 135 | for rule_test in hp['rules']: 136 | # rule_test = rule_train 137 | # for rule_test in hp['rule_trains']: 138 | 139 | n_rep = hp['n_rep'] 140 | batch_size_test_rep = int(hp['batch_size_test'] / n_rep) 141 | clsq_tmp = list() 142 | creg_tmp = list() 143 | perf_tmp = list() 144 | for i_rep in range(n_rep): 145 | trial = generate_trials( 146 | rule_test, hp, 'random', batch_size=batch_size_test_rep, delay_fac=hp['delay_fac']) 147 | feed_dict = tools.gen_feed_dict(model, trial, hp) 148 | 149 | # import pdb 150 | # pdb.set_trace() 151 | 152 | c_lsq, c_reg, y_hat_test = sess.run( 153 | [model.cost_lsq, model.cost_reg, model.y_hat], 154 | feed_dict=feed_dict) 155 | 156 | # Cost is first summed over time, 157 | # and averaged across batch and units 158 | # We did the averaging over time through c_mask 159 | perf_test = np.mean(get_perf(y_hat_test, trial.y_loc)) 160 | clsq_tmp.append(c_lsq) 161 | creg_tmp.append(c_reg) 162 | perf_tmp.append(perf_test) 163 | 164 | log['cost_' + rule_test].append(np.mean(clsq_tmp, dtype=np.float64)) 165 | log['creg_' + rule_test].append(np.mean(creg_tmp, dtype=np.float64)) 166 | log['perf_' + rule_test].append(np.mean(perf_tmp, dtype=np.float64)) 167 | print('{:15s}'.format(rule_test) + 168 | '| cost {:0.6f}'.format(np.mean(clsq_tmp)) + 169 | '| c_reg {:0.6f}'.format(np.mean(creg_tmp)) + 170 | ' | perf {:0.2f}'.format(np.mean(perf_tmp))) 171 | sys.stdout.flush() 172 | 173 | # TODO: This needs to be fixed since now rules are strings 174 | if hasattr(rule_train, '__iter__'): 175 | rule_tmp = rule_train 176 | else: 177 | rule_tmp = [rule_train] 178 | perf_tests_mean = np.mean([log['perf_' + r][-1] for r in rule_tmp]) 179 | log['perf_avg'].append(perf_tests_mean) 180 | 181 | perf_tests_min = np.min([log['perf_' + r][-1] for r in rule_tmp]) 182 | log['perf_min'].append(perf_tests_min) 183 | 184 | cost_tests_max = np.max([log['cost_' + r][-1] for r in rule_tmp]) # jan 4 2019 185 | log['cost_max'].append(cost_tests_max) # jan 4 2019 186 | 187 | # Saving the model 188 | model.save() 189 | tools.save_log(log) 190 | 191 | return log 192 | 193 | 194 | def do_eval_test(sess, model, rule): 195 | """Do evaluation. 
196 | 197 | Args: 198 | sess: tensorflow session 199 | model: Model class instance 200 | rule_train: string or list of strings, the rules being trained 201 | """ 202 | hp = model.hp 203 | 204 | trial = generate_trials(rule, hp, 'test') 205 | feed_dict = tools.gen_feed_dict(model, trial, hp) 206 | c_lsq, c_reg, y_hat_test = sess.run( 207 | [model.cost_lsq, model.cost_reg, model.y_hat], feed_dict=feed_dict) 208 | 209 | # Cost is first summed over time, 210 | # and averaged across batch and units 211 | # We did the averaging over time through c_mask 212 | perf_test = np.mean(get_perf(y_hat_test, trial.y_loc)) 213 | sys.stdout.flush() 214 | 215 | return c_lsq, c_reg, perf_test 216 | 217 | 218 | def display_rich_output(model, sess, step, log, model_dir): 219 | """Display step by step outputs during training.""" 220 | variance._compute_variance_bymodel(model, sess) 221 | rule_pair = ['contextdm1', 'contextdm2'] 222 | save_name = '_atstep' + str(step) 223 | title = ('Step ' + str(step) + 224 | ' Perf. {:0.2f}'.format(log['perf_avg'][-1])) 225 | variance.plot_hist_varprop(model_dir, rule_pair, 226 | figname_extra=save_name, 227 | title=title) 228 | plt.close('all') 229 | 230 | 231 | def train(model_dir, 232 | hp=None, 233 | max_steps=1e7, 234 | display_step=500, 235 | ruleset='mante', 236 | rule_trains=None, 237 | rule_prob_map=None, 238 | seed=0, 239 | rich_output=True, 240 | load_dir=None, 241 | trainables=None, 242 | fixReadoutandBias=False, 243 | fixBias=False, 244 | ): 245 | """Train the network. 246 | 247 | Args: 248 | model_dir: str, training directory 249 | hp: dictionary of hyperparameters 250 | max_steps: int, maximum number of training steps 251 | display_step: int, display steps 252 | ruleset: the set of rules to train 253 | rule_trains: list of rules to train, if None then all rules possible 254 | rule_prob_map: None or dictionary of relative rule probability 255 | seed: int, random seed to be used 256 | 257 | Returns: 258 | model is stored at model_dir/model.ckpt 259 | training configuration is stored at model_dir/hp.json 260 | """ 261 | 262 | tools.mkdir_p(model_dir) 263 | 264 | # Network parameters 265 | default_hp = get_default_hp(ruleset) 266 | if hp is not None: 267 | default_hp.update(hp) 268 | hp = default_hp 269 | hp['seed'] = seed 270 | hp['rng'] = np.random.RandomState(seed) 271 | 272 | # Rules to train and test. Rules in a set are trained together 273 | if rule_trains is None: 274 | # By default, training all rules available to this ruleset 275 | hp['rule_trains'] = task.rules_dict[ruleset] 276 | else: 277 | hp['rule_trains'] = rule_trains 278 | hp['rules'] = hp['rule_trains'] 279 | 280 | # Assign probabilities for rule_trains. 281 | if rule_prob_map is None: 282 | rule_prob_map = dict() 283 | 284 | # Turn into rule_trains format 285 | hp['rule_probs'] = None 286 | if hasattr(hp['rule_trains'], '__iter__'): 287 | # Set default as 1. 288 | 289 | rule_prob = np.array( 290 | [rule_prob_map.get(r, 1.) 
for r in hp['rule_trains']]) 291 | hp['rule_probs'] = list(rule_prob / np.sum(rule_prob)) 292 | 293 | tools.save_hp(hp, model_dir) 294 | 295 | # Build the model 296 | with tf.device('gpu:0'): 297 | model = Model(model_dir, hp=hp) 298 | 299 | # Display hp 300 | for key, val in hp.items(): 301 | print('{:20s} = '.format(key) + str(val)) 302 | 303 | if fixReadoutandBias is True: 304 | my_var_list = [var for var in model.var_list if 'rnn/leaky_rnn_cell/kernel:0' in var.name] 305 | print(my_var_list) 306 | elif fixBias is True: 307 | my_var_list = [var for var in model.var_list if 'rnn/leaky_rnn_cell/kernel:0' in var.name or 'output/weights:0' in var.name] 308 | else: 309 | my_var_list = model.var_list 310 | 311 | model.set_optimizer(var_list=my_var_list) 312 | 313 | # Store results 314 | log = defaultdict(list) 315 | log['model_dir'] = model_dir 316 | 317 | # Record time 318 | t_start = time.time() 319 | 320 | # Use customized session that launches the graph as well 321 | with tf.Session() as sess: 322 | sess.run(tf.global_variables_initializer()) 323 | 324 | # penalty on deviation from initial weight 325 | if hp['l2_weight_init'] > 0: 326 | anchor_ws = sess.run(model.weight_list) 327 | for w, w_val in zip(model.weight_list, anchor_ws): 328 | model.cost_reg += (hp['l2_weight_init'] * 329 | tf.nn.l2_loss(w - w_val)) 330 | 331 | model.set_optimizer(var_list=my_var_list) 332 | 333 | # partial weight training 334 | if ('p_weight_train' in hp and 335 | (hp['p_weight_train'] is not None) and 336 | hp['p_weight_train'] < 1.0): 337 | for w in model.weight_list: 338 | w_val = sess.run(w) 339 | w_size = sess.run(tf.size(w)) 340 | w_mask_tmp = np.linspace(0, 1, w_size) 341 | hp['rng'].shuffle(w_mask_tmp) 342 | ind_fix = w_mask_tmp > hp['p_weight_train'] 343 | w_mask = np.zeros(w_size, dtype=np.float32) 344 | w_mask[ind_fix] = 1e-1 # will be squared in l2_loss 345 | w_mask = tf.constant(w_mask) 346 | w_mask = tf.reshape(w_mask, w.shape) 347 | model.cost_reg += tf.nn.l2_loss((w - w_val) * w_mask) 348 | model.set_optimizer(var_list=my_var_list) 349 | 350 | step = 0 351 | run_ave_time = [] 352 | while step * hp['batch_size_train'] <= max_steps: 353 | try: 354 | # Validation 355 | if step % display_step == 0: 356 | grad_norm = tf.global_norm(model.clipped_gs) 357 | grad_norm_np = sess.run(grad_norm) 358 | # import pdb 359 | # pdb.set_trace() 360 | log['grad_norm'].append(grad_norm_np.item()) 361 | log['trials'].append(step * hp['batch_size_train']) 362 | log['times'].append(time.time() - t_start) 363 | log = do_eval(sess, model, log, hp['rule_trains']) 364 | # if log['perf_avg'][-1] > model.hp['target_perf']: 365 | # check if minimum performance is above target 366 | if log['perf_min'][-1] > model.hp['target_perf']: 367 | print('Perf reached the target: {:0.2f}'.format( 368 | hp['target_perf'])) 369 | break 370 | 371 | if rich_output: 372 | display_rich_output(model, sess, step, log, model_dir) 373 | 374 | # Training 375 | 376 | dtStart = datetime.now() 377 | sess.run(model.train_step) 378 | dtEnd = datetime.now() 379 | 380 | if len(run_ave_time) is 0: 381 | run_ave_time = np.expand_dims((dtEnd - dtStart).total_seconds(), axis=0) 382 | else: 383 | run_ave_time = np.concatenate((run_ave_time, np.expand_dims((dtEnd - dtStart).total_seconds(), axis=0))) 384 | 385 | # print(np.mean(run_ave_time)) 386 | # print((dtEnd-dtStart).total_seconds()) 387 | 388 | step += 1 389 | 390 | if step < 10: 391 | model.save_ckpt(step) 392 | 393 | if step < 1000: 394 | if step % display_step / 10 == 0: 395 | model.save_ckpt(step) 
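                    # NOTE on the two conditions just above:
                    # `len(run_ave_time) is 0` is an identity test that happens to work in
                    # CPython but is not guaranteed; `len(run_ave_time) == 0` is the reliable
                    # spelling. Likewise `step % display_step / 10 == 0` parses as
                    # `(step % display_step) / 10 == 0` rather than the presumably intended
                    # early-checkpoint cadence `step % (display_step // 10) == 0`.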
396 | 397 | if step % display_step == 0: 398 | model.save_ckpt(step) 399 | 400 | except KeyboardInterrupt: 401 | print("Optimization interrupted by user") 402 | break 403 | 404 | print("Optimization finished!") 405 | 406 | 407 | def train_sequential_orthogonalized( 408 | model_dir, 409 | rule_trains, 410 | hp=None, 411 | max_steps=1e7, 412 | display_step=500, 413 | rich_output=False, 414 | ruleset='mante', 415 | applyProj='both', 416 | seed=0, 417 | nEpisodeBatches=100, 418 | projGrad=True, 419 | alpha=0.001, 420 | fixReadout=False): 421 | '''Train the network sequentially. 422 | 423 | Args: 424 | model_dir: str, training directory 425 | rule_trains: a list of list of tasks to train sequentially 426 | hp: dictionary of hyperparameters 427 | max_steps: int, maximum number of training steps for each list of tasks 428 | display_step: int, display steps 429 | ruleset: the set of rules to train 430 | seed: int, random seed to be used 431 | 432 | Returns: 433 | model is stored at model_dir/model.ckpt 434 | training configuration is stored at model_dir/hp.json 435 | ''' 436 | 437 | tools.mkdir_p(model_dir) 438 | 439 | # Network parameters 440 | default_hp = get_default_hp(ruleset) 441 | if hp is not None: 442 | default_hp.update(hp) 443 | hp = default_hp 444 | hp['seed'] = seed 445 | hp['rng'] = np.random.RandomState(seed) 446 | hp['rule_trains'] = rule_trains 447 | # Get all rules by flattening the list of lists 448 | # hp['rules'] = [r for rs in rule_trains for r in rs] 449 | hp['rules'] = rule_trains 450 | 451 | # save some other parameters 452 | hp['alpha_projection'] = alpha 453 | hp['max_steps'] = max_steps 454 | 455 | # Number of training iterations for each rule 456 | rule_train_iters = [max_steps for _ in rule_trains] 457 | 458 | tools.save_hp(hp, model_dir) 459 | # Display hp 460 | for key, val in hp.items(): 461 | print('{:20s} = '.format(key) + str(val)) 462 | 463 | # Build the model 464 | model = Sequential_Model(model_dir, projGrad=projGrad, applyProj=applyProj, hp=hp) 465 | 466 | # Store results 467 | log = defaultdict(list) 468 | log['model_dir'] = model_dir 469 | 470 | # Record time 471 | t_start = time.time() 472 | 473 | def relu(x): 474 | return x * (x > 0.) 
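    # The loop below implements the projection-based continual-learning scheme:
    # after each task, input/activity/output covariances are estimated from fresh
    # trials, and the *_withProjection optimizers (opt_tools.py) use the resulting
    # projection matrices to remove gradient components that would interfere with
    # previously learned tasks. compute_projection_matrices is defined elsewhere in
    # this repo; purely as a sketch (assuming a ridge-style projector with
    # regularizer alpha, which may differ from the actual implementation), a
    # projector of this flavour looks like:
    #
    #     def sketch_projector(cov, alpha):
    #         # ~identity on directions with little variance so far,
    #         # ~0 on directions heavily used by earlier tasks
    #         return alpha * np.linalg.inv(cov + alpha * np.eye(cov.shape[0]))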
475 | 476 | # ------------------------------------------------------- 477 | 478 | # Use customized session that launches the graph as well 479 | with tf.Session() as sess: 480 | sess.run(tf.global_variables_initializer()) 481 | 482 | # penalty on deviation from initial weight 483 | if hp['l2_weight_init'] > 0: 484 | raise NotImplementedError() 485 | 486 | # Looping 487 | step_total = 0 488 | taskNumber = 0 489 | 490 | if fixReadout is True: 491 | my_var_list = [var for var in model.var_list if 'rnn/leaky_rnn_cell/kernel:0' in var.name] 492 | else: 493 | my_var_list = [var for var in model.var_list if 'rnn/leaky_rnn_cell/kernel:0' in var.name or 'output/weights:0' in var.name] 494 | 495 | # initialise projection matrices 496 | input_proj = tf.zeros((hp['n_rnn'] + hp['n_input'], hp['n_rnn'] + hp['n_input'])) 497 | activity_proj = tf.zeros((hp['n_rnn'], hp['n_rnn'])) 498 | output_proj = tf.zeros((hp['n_output'], hp['n_output'])) 499 | recurrent_proj = tf.zeros((hp['n_rnn'], hp['n_rnn'])) 500 | 501 | for i_rule_train, rule_train in enumerate(hp['rule_trains']): 502 | 503 | step = 0 504 | 505 | model.set_optimizer(activity_proj=activity_proj, input_proj=input_proj, output_proj=output_proj, recurrent_proj=recurrent_proj, taskNumber=taskNumber, var_list=my_var_list, alpha=alpha) 506 | 507 | # Keep training until reach max iterations 508 | while (step * hp['batch_size_train'] <= 509 | rule_train_iters[i_rule_train]): 510 | # Validation 511 | if step % display_step == 0: 512 | trial = step_total * hp['batch_size_train'] 513 | log['trials'].append(trial) 514 | log['times'].append(time.time() - t_start) 515 | log['rule_now'].append(rule_train) 516 | log = do_eval(sess, model, log, rule_train) 517 | if log['perf_avg'][-1] > model.hp['target_perf']: 518 | print('Perf reached the target: {:0.2f}'.format( 519 | hp['target_perf'])) 520 | break 521 | 522 | # Training 523 | # rule_train_now = hp['rng'].choice(rule_train) 524 | 525 | # Generate a random batch of trials. 526 | # Each batch has the same trial length 527 | trial = generate_trials( 528 | rule_train, hp, 'random', 529 | batch_size=hp['batch_size_train'], delay_fac=hp['delay_fac']) 530 | 531 | # Generating feed_dict. 532 | feed_dict = tools.gen_feed_dict(model, trial, hp) 533 | 534 | # update model 535 | sess.run(model.train_step, feed_dict=feed_dict) 536 | 537 | # # Get the weight after train step 538 | # v_current = sess.run(model.var_list) 539 | 540 | step += 1 541 | step_total += 1 542 | 543 | if step % display_step == 0: 544 | model.save_ckpt(step_total) 545 | 546 | # ---------- save model after its completed training the current task ---------- 547 | model.save_after_task(taskNumber) 548 | 549 | # ---------- generate task activity for continual learning ------- 550 | trial = generate_trials( 551 | rule_train, hp, 'random', 552 | batch_size=hp['batch_size_test'], delay_fac=hp['delay_fac']) 553 | 554 | # Generating feed_dict. 
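                # The activity evaluated from this batch provides the task's covariance
                # estimates; after task k they are folded into running means,
                #     cov <- k/(k+1) * cov + S_task / (k+1),
                # before compute_projection_matrices is re-run with regularizer alpha.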
555 | feed_dict = tools.gen_feed_dict(model, trial, hp) 556 | eval_h, eval_x, eval_y, Wrec, Win = sess.run([model.h, model.x, model.y, model.w_rec, model.w_in], feed_dict=feed_dict) 557 | full_state = np.concatenate([eval_x, eval_h], -1) 558 | 559 | # get weight matrix after current task 560 | Wfull = np.concatenate([Win, Wrec], 0) 561 | 562 | # joint covariance matrix of input and activity 563 | Shx_task = compute_covariance(np.reshape(full_state, (-1, hp['n_rnn'] + hp['n_input'])).T) 564 | 565 | # covariance matrix of output 566 | Sy_task = compute_covariance(np.reshape(eval_y, (-1, hp['n_output'])).T) 567 | 568 | # get block matrices from Shx_task 569 | # Sh_task = Shx_task[-hp['n_rnn']:, -hp['n_rnn']:] 570 | Sh_task = np.matmul(np.matmul(Wfull.T, Shx_task), Wfull) 571 | 572 | # ---------- update stored covariance matrices for continual learning ------- 573 | if taskNumber == 0: 574 | input_cov = Shx_task 575 | activity_cov = Sh_task 576 | output_cov = Sy_task 577 | else: 578 | input_cov = taskNumber / (taskNumber + 1) * input_cov + Shx_task / (taskNumber + 1) 579 | activity_cov = taskNumber / (taskNumber + 1) * activity_cov + Sh_task / (taskNumber + 1) 580 | output_cov = taskNumber / (taskNumber + 1) * output_cov + Sy_task / (taskNumber + 1) 581 | 582 | # ---------- update projection matrices for continual learning ---------- 583 | activity_proj, input_proj, output_proj, recurrent_proj = compute_projection_matrices(activity_cov, input_cov, output_cov, input_cov[-hp['n_rnn']:, -hp['n_rnn']:], alpha) 584 | 585 | # update task number 586 | taskNumber += 1 587 | 588 | print("Optimization Finished!") 589 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | """Definition of the network model and various RNN cells""" 2 | 3 | from __future__ import division 4 | 5 | import os 6 | import numpy as np 7 | 8 | import tensorflow as tf 9 | from tensorflow.python.ops import array_ops 10 | from tensorflow.python.ops import init_ops 11 | from tensorflow.python.ops import math_ops 12 | from tensorflow.python.ops import nn_ops 13 | from tensorflow.python.ops import rnn 14 | from tensorflow.python.ops.rnn_cell_impl import RNNCell 15 | 16 | import tools 17 | from task import datasetGeneratorFromTaskDef, defineDatasetFormat 18 | from opt_tools import AdamOptimizer_withProjection, GradientDescentOptimizer_withProjection, MomentumOptimizer_withProjection 19 | 20 | 21 | def clip_grad(grad, max_norm): 22 | n = tf.norm(grad) 23 | 24 | # do_clip = tf.math.greater(n,max_norm) 25 | do_clip = 0 26 | 27 | with tf.Session() as session: 28 | session.run(tf.initialize_all_variables()) 29 | 30 | if do_clip: # .eval(): 31 | clipped_grad = (max_norm / n) * grad 32 | else: 33 | clipped_grad = grad 34 | 35 | return clipped_grad 36 | 37 | 38 | def is_weight(v): 39 | """Check if Tensorflow variable v is a connection weight.""" 40 | return ('kernel' in v.name or 'weight' in v.name) 41 | 42 | 43 | def popvec(y): 44 | """Population vector read out. 45 | 46 | Assuming the last dimension is the dimension to be collapsed 47 | 48 | Args: 49 | y: population output on a ring network. Numpy array (Batch, Units) 50 | 51 | Returns: 52 | Readout locations: Numpy array (Batch,) 53 | """ 54 | 55 | loc = np.arctan2(y[:, 0], y[:, 1]) 56 | return np.mod(loc, 2 * np.pi) # check this? 
January 22 2019 57 | 58 | 59 | def tf_popvec(y): 60 | """Population vector read-out in tensorflow.""" 61 | 62 | loc = tf.atan2(y[:, 0], y[:, 1]) 63 | return tf.mod(loc + np.pi, 2 * np.pi) # check this? January 22 2019 64 | 65 | 66 | def get_perf(y_hat, y_loc): 67 | """Get performance. 68 | 69 | Args: 70 | y_hat: Actual output. Numpy array (Time, Batch, Unit) 71 | y_loc: Target output location (-1 for fixation). 72 | Numpy array (Time, Batch) 73 | 74 | Returns: 75 | perf: Numpy array (Batch,) 76 | """ 77 | if len(y_hat.shape) != 3: 78 | raise ValueError('y_hat must have shape (Time, Batch, Unit)') 79 | # Only look at last time points 80 | 81 | y_loc = y_loc[-1] 82 | y_hat = y_hat[-1] 83 | 84 | # Fixation and location of y_hat 85 | y_hat_fix = y_hat[..., 0] 86 | y_hat_loc = popvec(y_hat[..., 1:]) 87 | 88 | # Fixating? Correctly saccading? 89 | fixating = y_hat_fix > 0.5 90 | 91 | original_dist = y_loc - y_hat_loc 92 | dist = np.minimum(abs(original_dist), 2 * np.pi - abs(original_dist)) 93 | corr_loc = dist < 0.1 * np.pi 94 | 95 | # Should fixate? 96 | should_fix = y_loc < 0 97 | 98 | # performance 99 | perf = should_fix * fixating + (1 - should_fix) * corr_loc * (1 - fixating) 100 | return perf 101 | 102 | 103 | class LeakyRNNCell(RNNCell): 104 | """The most basic RNN cell. 105 | 106 | Args: 107 | num_units: int, The number of units in the RNN cell. 108 | activation: Nonlinearity to use. Default: `tanh`. 109 | reuse: (optional) Python boolean describing whether to reuse variables 110 | in an existing scope. If not `True`, and the existing scope already has 111 | the given variables, an error is raised. 112 | name: String, the name of the layer. Layers with the same name will 113 | share weights, but to avoid mistakes we require reuse=True in such 114 | cases. 115 | """ 116 | 117 | def __init__(self, 118 | num_units, 119 | n_input, 120 | alpha, 121 | sigma_rec=0, 122 | activation='softplus', 123 | w_rec_init='diag', 124 | rng=None, 125 | reuse=None, 126 | name=None): 127 | super(LeakyRNNCell, self).__init__(_reuse=reuse, name=name) 128 | 129 | # Inputs must be 2-dimensional. 
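        # One call() step below implements a forward-Euler discretization of leaky
        # rate dynamics:
        #     h_{t+1} = (1 - alpha) * h_t + alpha * f(W_in x_t + W_rec h_t + b + eps_t),
        # with alpha = dt / tau and private noise eps_t ~ N(0, sigma^2), where
        # sigma = sqrt(2 / alpha) * sigma_rec (the usual scaling for discretized
        # continuous-time noise).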
130 | # self.input_spec = base_layer.InputSpec(ndim=2) 131 | 132 | self._num_units = num_units 133 | self._w_rec_init = w_rec_init 134 | self._reuse = reuse 135 | 136 | if activation == 'softplus': 137 | self._activation = tf.nn.softplus 138 | self._w_in_start = 1.0 139 | self._w_rec_start = 0.6 # 0.5 140 | elif activation == 'tanh': 141 | self._activation = tf.tanh 142 | self._w_in_start = 1.0 143 | self._w_rec_start = 1.0 144 | elif activation == 'relu': 145 | self._activation = tf.nn.relu 146 | self._w_in_start = 1.0 147 | self._w_rec_start = 0.5 148 | elif activation == 'power': 149 | self._activation = lambda x: tf.square(tf.nn.relu(x)) 150 | self._w_in_start = 1.0 151 | self._w_rec_start = 0.01 152 | elif activation == 'retanh': 153 | self._activation = lambda x: tf.tanh(tf.nn.relu(x)) 154 | self._w_in_start = 1.0 155 | self._w_rec_start = 0.5 156 | else: 157 | raise ValueError('Unknown activation') 158 | self._alpha = alpha 159 | self._sigma = np.sqrt(2 / alpha) * sigma_rec 160 | if rng is None: 161 | self.rng = np.random.RandomState() 162 | else: 163 | self.rng = rng 164 | 165 | # Generate initialization matrix 166 | n_hidden = self._num_units 167 | w_in0 = (self.rng.randn(n_input, n_hidden) / 168 | np.sqrt(n_input) * self._w_in_start) 169 | 170 | if self._w_rec_init == 'diag': 171 | w_rec0 = self._w_rec_start * np.eye(n_hidden) 172 | elif self._w_rec_init == 'randortho': 173 | w_rec0 = self._w_rec_start * tools.gen_ortho_matrix(n_hidden, 174 | rng=self.rng) 175 | elif self._w_rec_init == 'randgauss': 176 | w_rec0 = (self._w_rec_start * 177 | self.rng.randn(n_hidden, n_hidden) / np.sqrt(n_hidden)) 178 | 179 | matrix0 = np.concatenate((w_in0, w_rec0), axis=0) 180 | 181 | self.w_rnn0 = matrix0 182 | self._initializer = tf.constant_initializer(matrix0, dtype=tf.float32) 183 | 184 | @property 185 | def state_size(self): 186 | return self._num_units 187 | 188 | @property 189 | def output_size(self): 190 | return self._num_units 191 | 192 | def build(self, inputs_shape): 193 | if inputs_shape[1].value is None: 194 | raise ValueError( 195 | "Expected inputs.shape[-1] to be known, saw shape: %s" 196 | % inputs_shape) 197 | 198 | input_depth = inputs_shape[1].value 199 | self._kernel = self.add_variable( 200 | 'kernel', 201 | shape=[input_depth + self._num_units, self._num_units], 202 | initializer=self._initializer) 203 | self._bias = self.add_variable( 204 | 'bias', 205 | shape=[self._num_units], 206 | initializer=init_ops.zeros_initializer(dtype=self.dtype)) 207 | 208 | self.built = True 209 | 210 | def call(self, inputs, state): 211 | """Most basic RNN: output = new_state = act(W * input + U * state + B).""" 212 | 213 | gate_inputs = math_ops.matmul( 214 | array_ops.concat([inputs, state], 1), self._kernel) 215 | gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) 216 | 217 | noise = tf.random_normal(tf.shape(state), mean=0, stddev=self._sigma) 218 | gate_inputs = gate_inputs + noise 219 | 220 | output = self._activation(gate_inputs) 221 | 222 | output = (1 - self._alpha) * state + self._alpha * output 223 | 224 | return output, output 225 | 226 | 227 | class Model(object): 228 | """The model.""" 229 | 230 | def __init__(self, 231 | model_dir, 232 | hp=None, 233 | sigma_rec=None, 234 | dt=None): 235 | """ 236 | Initializing the model with information from hp 237 | 238 | Args: 239 | model_dir: string, directory of the model 240 | hp: a dictionary or None 241 | sigma_rec: if not None, overwrite the sigma_rec passed by hp 242 | """ 243 | 244 | # Reset tensorflow graphs 245 | 
tf.reset_default_graph() # must be in the beginning 246 | 247 | if hp is None: 248 | hp = tools.load_hp(model_dir) 249 | if hp is None: 250 | raise ValueError( 251 | 'No hp found for model_dir {:s}'.format(model_dir)) 252 | 253 | tf.set_random_seed(hp['seed']) 254 | self.rng = np.random.RandomState(hp['seed']) 255 | 256 | if sigma_rec is not None: 257 | print('Overwrite sigma_rec with {:0.3f}'.format(sigma_rec)) 258 | hp['sigma_rec'] = sigma_rec 259 | 260 | if dt is not None: 261 | print('Overwrite original dt with {:0.1f}'.format(dt)) 262 | hp['dt'] = dt 263 | 264 | hp['alpha'] = 1.0 * hp['dt'] / hp['tau'] 265 | 266 | # Input, target output, and cost mask 267 | # Shape: [Time, Batch, Num_units] 268 | if hp['in_type'] != 'normal': 269 | raise ValueError('Only support in_type ' + hp['in_type']) 270 | 271 | datasetType, datasetShape = defineDatasetFormat(hp) 272 | dataset = tf.data.Dataset.from_generator(lambda: datasetGeneratorFromTaskDef( 273 | hp, 'random'), datasetType, datasetShape) 274 | dataset = dataset.prefetch(4) 275 | self.datasetTensors = dataset.make_one_shot_iterator().get_next() 276 | 277 | self._build(hp) 278 | 279 | self.model_dir = model_dir 280 | self.hp = hp 281 | 282 | def _build(self, hp): 283 | if 'use_separate_input' in hp and hp['use_separate_input']: 284 | self._build_seperate(hp) 285 | else: 286 | self._build_fused(hp) 287 | 288 | self.var_list = tf.trainable_variables() 289 | self.weight_list = [v for v in self.var_list if is_weight(v)] 290 | 291 | if 'use_separate_input' in hp and hp['use_separate_input']: 292 | self._set_weights_separate(hp) 293 | else: 294 | self._set_weights_fused(hp) 295 | 296 | # Regularization terms 297 | self.cost_reg = tf.constant(0.) 298 | if hp['l1_h'] > 0: 299 | self.cost_reg += tf.reduce_mean(tf.abs(self.h)) * hp['l1_h'] 300 | if hp['l2_h'] > 0: 301 | self.cost_reg += tf.nn.l2_loss(self.h) * hp['l2_h'] 302 | 303 | if hp['l1_weight'] > 0: 304 | self.cost_reg += hp['l1_weight'] * tf.add_n( 305 | [tf.reduce_mean(tf.abs(v)) for v in self.weight_list]) 306 | if hp['l2_weight'] > 0: 307 | self.cost_reg += hp['l2_weight'] * tf.add_n( 308 | [tf.nn.l2_loss(v) for v in self.weight_list]) 309 | 310 | # Create an optimizer. 
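        # hp['optimizer'] selects 'adam' (the default), 'sgd', or 'sgd_mom' (plain
        # momentum, using hp['momentum']); hp['learning_rate'] is passed to all
        # three. The optimizer is attached to the cost via set_optimizer() below.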
311 | if 'optimizer' not in hp or hp['optimizer'] == 'adam': 312 | self.opt = tf.train.AdamOptimizer( 313 | learning_rate=hp['learning_rate']) 314 | elif hp['optimizer'] == 'sgd': 315 | self.opt = tf.train.GradientDescentOptimizer( 316 | learning_rate=hp['learning_rate']) 317 | elif hp['optimizer'] == 'sgd_mom': 318 | self.opt = tf.train.MomentumOptimizer( 319 | learning_rate=hp['learning_rate'], momentum=hp['momentum']) 320 | 321 | # Set cost 322 | self.set_optimizer() 323 | 324 | # Variable saver 325 | # self.saver = tf.train.Saver(self.var_list) 326 | self.saver = tf.train.Saver(max_to_keep=100) 327 | self.saver_task = tf.train.Saver(max_to_keep=100) 328 | 329 | def _build_fused(self, hp): 330 | n_input = hp['n_input'] 331 | n_rnn = hp['n_rnn'] 332 | n_output = hp['n_output'] 333 | 334 | self.x = self.datasetTensors[0] # tf.placeholder("float", [None, None, n_input]) #add January 11 2019 # 335 | self.y = self.datasetTensors[1] # tf.placeholder("float", [None, None, n_output]) #add January 11 2019 # 336 | self.c_mask = self.datasetTensors[2] # tf.placeholder("float", [None, n_output]) #add January 11 2019 # 337 | 338 | # self.x = tf.placeholder("float", [None, None, n_input]) #add January 11 2019 # 339 | # self.y = tf.placeholder("float", [None, None, n_output]) #add January 11 2019 # 340 | # self.c_mask = tf.placeholder("float", [None, n_output]) #add January 11 2019 # 341 | 342 | # Activation functions 343 | if hp['activation'] == 'power': 344 | def f_act(x): return tf.square(tf.nn.relu(x)) 345 | elif hp['activation'] == 'retanh': 346 | def f_act(x): return tf.tanh(tf.nn.relu(x)) 347 | elif hp['activation'] == 'relu+': 348 | def f_act(x): return tf.nn.relu(x + tf.constant(1.)) 349 | else: 350 | f_act = getattr(tf.nn, hp['activation']) 351 | 352 | # Recurrent activity 353 | if hp['rnn_type'] == 'LeakyRNN': 354 | n_in_rnn = self.x.get_shape().as_list()[-1] 355 | self.cell = LeakyRNNCell(n_rnn, n_in_rnn, 356 | hp['alpha'], 357 | sigma_rec=hp['sigma_rec'], 358 | activation=hp['activation'], 359 | w_rec_init=hp['w_rec_init'], 360 | rng=self.rng) 361 | elif hp['rnn_type'] == 'LeakyGRU': 362 | self.cell = LeakyGRUCell( 363 | n_rnn, hp['alpha'], 364 | sigma_rec=hp['sigma_rec'], activation=f_act) 365 | elif hp['rnn_type'] == 'LSTM': 366 | self.cell = tf.contrib.rnn.LSTMCell(n_rnn, activation=f_act) 367 | 368 | elif hp['rnn_type'] == 'GRU': 369 | self.cell = tf.contrib.rnn.GRUCell(n_rnn, activation=f_act) 370 | else: 371 | raise NotImplementedError("""rnn_type must be one of LeakyRNN, 372 | LeakyGRU, EILeakyGRU, LSTM, GRU 373 | """) 374 | 375 | # Dynamic rnn with time major 376 | self.h, states = rnn.dynamic_rnn( 377 | self.cell, self.x, dtype=tf.float32, time_major=True) 378 | 379 | # Output 380 | with tf.variable_scope("output"): 381 | # Using default initialization `glorot_uniform_initializer` 382 | w_out = tf.get_variable( 383 | 'weights', 384 | [n_rnn, n_output], 385 | dtype=tf.float32 386 | ) 387 | b_out = tf.get_variable( 388 | 'biases', 389 | [n_output], 390 | dtype=tf.float32, 391 | initializer=tf.constant_initializer(0.0, dtype=tf.float32) 392 | ) 393 | 394 | h_shaped = tf.reshape(self.h, (-1, n_rnn)) 395 | y_shaped = tf.reshape(self.y, (-1, n_output)) 396 | # y_hat_ shape (n_time*n_batch, n_unit) 397 | y_hat = tf.matmul(h_shaped, w_out) + b_out 398 | 399 | if hp['loss_type'] == 'lsq': 400 | # Least-square loss 401 | # y_hat = tf.sigmoid(y_hat_) #removed sigmoid Jan 24, 2019 402 | self.cost_lsq = tf.reduce_mean( 403 | tf.square((y_shaped - y_hat) * self.c_mask)) 404 | else: 405 | y_hat 
= tf.nn.softmax(y_hat_) 406 | # Cross-entropy loss 407 | self.cost_lsq = tf.reduce_mean( 408 | self.c_mask * tf.nn.softmax_cross_entropy_with_logits( 409 | labels=y_shaped, logits=y_hat_)) 410 | 411 | self.y_hat = tf.reshape(y_hat, 412 | (-1, tf.shape(self.h)[1], n_output)) 413 | y_hat_fix, y_hat_ring = tf.split( 414 | self.y_hat, [1, n_output - 1], axis=-1) 415 | self.y_hat_loc = tf_popvec(y_hat_ring) 416 | 417 | def _set_weights_fused(self, hp): 418 | """Set model attributes for several weight variables.""" 419 | n_input = hp['n_input'] 420 | n_rnn = hp['n_rnn'] 421 | n_output = hp['n_output'] 422 | 423 | for v in self.var_list: 424 | if 'rnn' in v.name: 425 | if 'kernel' in v.name or 'weight' in v.name: 426 | # TODO(gryang): For GRU, fix 427 | self.w_rec = v[n_input:, :] 428 | self.w_in = v[:n_input, :] 429 | else: 430 | self.b_rec = v 431 | else: 432 | assert 'output' in v.name 433 | if 'kernel' in v.name or 'weight' in v.name: 434 | self.w_out = v 435 | else: 436 | self.b_out = v 437 | 438 | # check if the recurrent and output connection has the correct shape 439 | if self.w_out.shape != (n_rnn, n_output): 440 | raise ValueError('Shape of w_out should be ' + 441 | str((n_rnn, n_output)) + ', but found ' + 442 | str(self.w_out.shape)) 443 | if self.w_rec.shape != (n_rnn, n_rnn): 444 | raise ValueError('Shape of w_rec should be ' + 445 | str((n_rnn, n_rnn)) + ', but found ' + 446 | str(self.w_rec.shape)) 447 | if self.w_in.shape != (n_input, n_rnn): 448 | raise ValueError('Shape of w_in should be ' + 449 | str((n_input, n_rnn)) + ', but found ' + 450 | str(self.w_in.shape)) 451 | 452 | def _build_seperate(self, hp): 453 | # Input, target output, and cost mask 454 | # Shape: [Time, Batch, Num_units] 455 | n_input = hp['n_input'] 456 | n_rnn = hp['n_rnn'] 457 | n_output = hp['n_output'] 458 | 459 | self.x = self.datasetTensors[0] # tf.placeholder("float", [None, None, n_input]) #add January 11 2019 # 460 | self.y = self.datasetTensors[1] # tf.placeholder("float", [None, None, n_output]) #add January 11 2019 # 461 | self.c_mask = self.datasetTensors[2] # tf.placeholder("float", [None, n_output]) #add January 11 2019 # 462 | 463 | # self.x = tf.placeholder("float", [None, None, n_input]) #add January 11 2019 # 464 | # self.y = tf.placeholder("float", [None, None, n_output]) #add January 11 2019 # 465 | # self.c_mask = tf.placeholder("float", [None, n_output]) #add January 11 2019 # 466 | 467 | sensory_inputs, rule_inputs = tf.split( 468 | self.x, [hp['rule_start'], hp['n_rule']], axis=-1) 469 | 470 | sensory_rnn_inputs = tf.layers.dense(sensory_inputs, n_rnn, name='sen_input') 471 | 472 | if 'mix_rule' in hp and hp['mix_rule'] is True: 473 | # rotate rule matrix 474 | kernel_initializer = tf.orthogonal_initializer() 475 | rule_inputs = tf.layers.dense( 476 | rule_inputs, hp['n_rule'], name='mix_rule', 477 | use_bias=False, trainable=False, 478 | kernel_initializer=kernel_initializer) 479 | 480 | rule_rnn_inputs = tf.layers.dense(rule_inputs, n_rnn, name='rule_input', use_bias=False) 481 | 482 | rnn_inputs = sensory_rnn_inputs + rule_rnn_inputs 483 | 484 | # Recurrent activity 485 | self.cell = LeakyRNNCellSeparateInput( 486 | n_rnn, hp['alpha'], 487 | sigma_rec=hp['sigma_rec'], 488 | activation=hp['activation'], 489 | w_rec_init=hp['w_rec_init'], 490 | rng=self.rng) 491 | 492 | # Dynamic rnn with time major 493 | self.h, states = rnn.dynamic_rnn( 494 | self.cell, rnn_inputs, dtype=tf.float32, time_major=True) 495 | 496 | # Output 497 | h_shaped = tf.reshape(self.h, (-1, n_rnn)) 498 | 
y_shaped = tf.reshape(self.y, (-1, n_output)) 499 | # y_hat shape (n_time*n_batch, n_unit) 500 | y_hat = tf.layers.dense( 501 | h_shaped, n_output, activation=tf.nn.sigmoid, name='output') 502 | # Least-square loss 503 | 504 | self.cost_lsq = tf.reduce_mean( 505 | tf.square((y_shaped - y_hat) * self.c_mask)) 506 | 507 | self.y_hat = tf.reshape(y_hat, 508 | (-1, tf.shape(self.h)[1], n_output)) 509 | y_hat_fix, y_hat_ring = tf.split( 510 | self.y_hat, [1, n_output - 1], axis=-1) 511 | self.y_hat_loc = tf_popvec(y_hat_ring) 512 | 513 | def _set_weights_separate(self, hp): 514 | """Set model attributes for several weight variables.""" 515 | n_input = hp['n_input'] 516 | n_rnn = hp['n_rnn'] 517 | n_output = hp['n_output'] 518 | 519 | for v in self.var_list: 520 | if 'rnn' in v.name: 521 | if 'kernel' in v.name or 'weight' in v.name: 522 | self.w_rec = v 523 | else: 524 | self.b_rec = v 525 | elif 'sen_input' in v.name: 526 | if 'kernel' in v.name or 'weight' in v.name: 527 | self.w_sen_in = v 528 | else: 529 | self.b_in = v 530 | elif 'rule_input' in v.name: 531 | self.w_rule = v 532 | else: 533 | assert 'output' in v.name 534 | if 'kernel' in v.name or 'weight' in v.name: 535 | self.w_out = v 536 | else: 537 | self.b_out = v 538 | 539 | # check if the recurrent and output connection has the correct shape 540 | if self.w_out.shape != (n_rnn, n_output): 541 | raise ValueError('Shape of w_out should be ' + 542 | str((n_rnn, n_output)) + ', but found ' + 543 | str(self.w_out.shape)) 544 | if self.w_rec.shape != (n_rnn, n_rnn): 545 | raise ValueError('Shape of w_rec should be ' + 546 | str((n_rnn, n_rnn)) + ', but found ' + 547 | str(self.w_rec.shape)) 548 | if self.w_sen_in.shape != (hp['rule_start'], n_rnn): 549 | raise ValueError('Shape of w_sen_in should be ' + 550 | str((hp['rule_start'], n_rnn)) + ', but found ' + 551 | str(self.w_sen_in.shape)) 552 | if self.w_rule.shape != (hp['n_rule'], n_rnn): 553 | raise ValueError('Shape of w_in should be ' + 554 | str((hp['n_rule'], n_rnn)) + ', but found ' + 555 | str(self.w_rule.shape)) 556 | 557 | def initialize(self): 558 | """Initialize the model for training.""" 559 | sess = tf.get_default_session() 560 | sess.run(tf.global_variables_initializer()) 561 | 562 | def restore(self, load_dir=None): 563 | """restore the model""" 564 | sess = tf.get_default_session() 565 | if load_dir is None: 566 | load_dir = self.model_dir 567 | save_path = os.path.join(load_dir, 'model.ckpt') 568 | try: 569 | self.saver.restore(sess, save_path) 570 | except: 571 | # Some earlier checkpoints only stored trainable variables 572 | self.saver = tf.train.Saver(self.var_list) 573 | self.saver.restore(sess, save_path) 574 | print("Model restored from file: %s" % save_path) 575 | 576 | def save(self): 577 | """Save the model.""" 578 | sess = tf.get_default_session() 579 | save_path = os.path.join(self.model_dir, 'model.ckpt') 580 | self.saver.save(sess, save_path) 581 | print("Model saved in file: %s" % save_path) 582 | 583 | def save_ckpt(self, step): # added Jan 9 2019 584 | """Save the model.""" 585 | sess = tf.get_default_session() 586 | save_path = os.path.join(self.model_dir, 'ckpts', 'model.ckpt' + '-' + str(step)) 587 | self.saver.save(sess, save_path) 588 | print("Model saved in file: %s" % save_path) 589 | 590 | def set_optimizer(self, extra_cost=None, var_list=None): 591 | """Recompute the optimizer to reflect the latest cost function. 
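        Gradients of cost_lsq + cost_reg (+ extra_cost, if given) are computed for
        var_list (all trainable variables by default), passed through clip_grad with
        clip_max_norm = 10, applied with opt.apply_gradients, and also kept on
        self.clipped_gs so their global norm can be logged during training. Note
        that clip_grad, as currently written, hard-codes do_clip = 0 and therefore
        returns every gradient unchanged.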
592 | 593 | This is useful when the cost function is modified throughout training 594 | 595 | Args: 596 | extra_cost : tensorflow variable, 597 | added to the lsq and regularization cost 598 | """ 599 | cost = self.cost_lsq + self.cost_reg 600 | if extra_cost is not None: 601 | cost += extra_cost 602 | 603 | if var_list is None: 604 | var_list = self.var_list 605 | 606 | print('Variables being optimized:') 607 | for v in var_list: 608 | print(v) 609 | 610 | self.grads_and_vars = self.opt.compute_gradients(cost, var_list) 611 | 612 | # gradient clipping 613 | self.clip_max_norm = 10 614 | clipped_gvs = [(clip_grad(grad, self.clip_max_norm), var) for grad, var in self.grads_and_vars] 615 | clipped_gs = [(clip_grad(grad, self.clip_max_norm)) for grad, var in self.grads_and_vars] 616 | 617 | self.train_step = self.opt.apply_gradients(clipped_gvs) 618 | self.clipped_gs = clipped_gs # trying to save gradients in log feb 8th 619 | 620 | def save_after_task(self, taskNumber): 621 | """Save the model.""" 622 | sess = tf.get_default_session() 623 | save_path = os.path.join(self.model_dir, 'task_ckpts', 'model.ckpt' + '-' + str(taskNumber)) 624 | self.saver_task.save(sess, save_path) 625 | print("Model saved in file: %s" % save_path) 626 | 627 | 628 | class Sequential_Model(Model): 629 | """The sequential model.""" 630 | 631 | def __init__(self, model_dir, 632 | projGrad=True, 633 | applyProj='both', 634 | hp=None, 635 | sigma_rec=None, 636 | dt=None): 637 | 638 | self.projGrad = projGrad # whether or not to project out interfering directions 639 | self.applyProj = applyProj # how to apply update: both, left, right (for testing double-sided update rule) 640 | Model.__init__(self, model_dir, hp, sigma_rec, dt) 641 | 642 | def _build(self, hp): 643 | if 'use_separate_input' in hp and hp['use_separate_input']: 644 | self._build_seperate(hp) 645 | else: 646 | self._build_fused(hp) 647 | 648 | self.var_list = tf.trainable_variables() 649 | self.weight_list = [v for v in self.var_list if is_weight(v)] 650 | 651 | if 'use_separate_input' in hp and hp['use_separate_input']: 652 | self._set_weights_separate(hp) 653 | else: 654 | self._set_weights_fused(hp) 655 | 656 | # Regularization terms 657 | self.cost_reg = tf.constant(0.) 658 | if hp['l1_h'] > 0: 659 | self.cost_reg += tf.reduce_mean(tf.abs(self.h)) * hp['l1_h'] 660 | if hp['l2_h'] > 0: 661 | self.cost_reg += tf.nn.l2_loss(self.h) * hp['l2_h'] 662 | 663 | if hp['l1_weight'] > 0: 664 | self.cost_reg += hp['l1_weight'] * tf.add_n( 665 | [tf.reduce_mean(tf.abs(v)) for v in self.weight_list]) 666 | if hp['l2_weight'] > 0: 667 | self.cost_reg += hp['l2_weight'] * tf.add_n( 668 | [tf.nn.l2_loss(v) for v in self.weight_list]) 669 | 670 | # Create an optimizer. 
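        # Same selection logic as Model._build above, except that when projGrad is
        # True the projection-aware variants from opt_tools are used
        # (AdamOptimizer_withProjection, GradientDescentOptimizer_withProjection,
        # MomentumOptimizer_withProjection); the Adam variant is built with
        # beta2 = 1 - 9e-8.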
671 | if 'optimizer' not in hp or hp['optimizer'] == 'adam': 672 | if self.projGrad is True: 673 | self.opt = AdamOptimizer_withProjection( 674 | learning_rate=hp['learning_rate'], beta2=(1 - 9e-8)) 675 | else: 676 | self.opt = tf.train.AdamOptimizer( 677 | learning_rate=hp['learning_rate']) 678 | elif hp['optimizer'] == 'sgd': 679 | if self.projGrad is True: 680 | self.opt = GradientDescentOptimizer_withProjection( 681 | learning_rate=hp['learning_rate']) 682 | else: 683 | self.opt = tf.train.GradientDescentOptimizer( 684 | learning_rate=hp['learning_rate']) 685 | elif hp['optimizer'] == 'sgd_mom': 686 | if self.projGrad is True: 687 | self.opt = MomentumOptimizer_withProjection( 688 | learning_rate=hp['learning_rate'], momentum=hp['momentum']) 689 | 690 | else: 691 | self.opt = tf.train.MomentumOptimizer( 692 | learning_rate=hp['learning_rate'], momentum=hp['momentum']) 693 | # Set cost 694 | self.set_optimizer() 695 | 696 | # Variable saver 697 | # self.saver = tf.train.Saver(self.var_list) 698 | self.saver = tf.train.Saver(max_to_keep=100) 699 | self.saver_task = tf.train.Saver(max_to_keep=100) 700 | 701 | def set_optimizer(self, activity_proj=None, input_proj=None, output_proj=None, recurrent_proj=None, taskNumber=0, extra_cost=None, var_list=None, alpha=1e-3): 702 | """Recompute the optimizer to reflect the latest cost function. 703 | 704 | This is useful when the cost function is modified throughout training 705 | 706 | Args: 707 | extra_cost : tensorflow variable, 708 | added to the lsq and regularization cost 709 | """ 710 | cost = self.cost_lsq + self.cost_reg 711 | if extra_cost is not None: 712 | cost += extra_cost 713 | 714 | if var_list is None: 715 | var_list = self.var_list 716 | 717 | print('Variables being optimized:') 718 | for v in var_list: 719 | print(v) 720 | 721 | self.grads_and_vars = self.opt.compute_gradients(cost, var_list) 722 | 723 | # gradient clipping 724 | self.clip_max_norm = 10 725 | clipped_gvs = [(clip_grad(grad, self.clip_max_norm), var) for grad, var in self.grads_and_vars] 726 | clipped_gs = [(clip_grad(grad, self.clip_max_norm)) for grad, var in self.grads_and_vars] 727 | 728 | if self.projGrad is True: 729 | self.train_step = self.opt.apply_gradients(clipped_gvs, activity_proj, input_proj, output_proj, recurrent_proj, taskNumber) 730 | else: 731 | self.train_step = self.opt.apply_gradients(clipped_gvs) 732 | 733 | self.clipped_gs = clipped_gs # trying to save gradients in log feb 8th 734 | 735 | def save_after_task(self, taskNumber): 736 | """Save the model.""" 737 | sess = tf.get_default_session() 738 | save_path = os.path.join(self.model_dir, 'task_ckpts', 'model.ckpt' + '-' + str(taskNumber)) 739 | self.saver_task.save(sess, save_path) 740 | print("Model saved in file: %s" % save_path) 741 | 742 | 743 | class FixedPoint_Model(Model): 744 | """For finding fixed points.""" 745 | 746 | def __init__(self, model_dir, 747 | projGrad=True, 748 | hp=None, 749 | sigma_rec=0, 750 | dt=None): 751 | 752 | Model.__init__(self, model_dir, hp, sigma_rec, dt) 753 | 754 | def _build(self, hp): 755 | if 'use_separate_input' in hp and hp['use_separate_input']: 756 | self._build_seperate(hp) 757 | else: 758 | self._build_fused(hp) 759 | 760 | self.var_list = tf.trainable_variables() 761 | self.weight_list = [v for v in self.var_list if is_weight(v)] 762 | 763 | if 'use_separate_input' in hp and hp['use_separate_input']: 764 | self._set_weights_separate(hp) 765 | else: 766 | self._set_weights_fused(hp) 767 | 768 | # Regularization terms 769 | self.cost_reg = 
tf.constant(0.) 770 | if hp['l1_h'] > 0: 771 | self.cost_reg += tf.reduce_mean(tf.abs(self.h)) * hp['l1_h'] 772 | if hp['l2_h'] > 0: 773 | self.cost_reg += tf.nn.l2_loss(self.h) * hp['l2_h'] 774 | 775 | if hp['l1_weight'] > 0: 776 | self.cost_reg += hp['l1_weight'] * tf.add_n( 777 | [tf.reduce_mean(tf.abs(v)) for v in self.weight_list]) 778 | if hp['l2_weight'] > 0: 779 | self.cost_reg += hp['l2_weight'] * tf.add_n( 780 | [tf.nn.l2_loss(v) for v in self.weight_list]) 781 | 782 | # self.saver = tf.train.Saver(self.var_list) 783 | -------------------------------------------------------------------------------- /tools_lnd.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import os 4 | import sys 5 | import time 6 | from collections import defaultdict 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import tensorflow as tf 10 | import re 11 | import json 12 | from datetime import datetime as datetime 13 | from tensorflow.python.ops import parallel_for as pfor 14 | from scipy.linalg import orthogonal_procrustes 15 | from sklearn.decomposition import PCA 16 | from sklearn.neighbors import DistanceMetric 17 | from sklearn.manifold import MDS 18 | from sklearn.linear_model import LinearRegression 19 | from sklearn import linear_model 20 | from numpy import linalg as LA 21 | import numpy.random as npr 22 | from scipy import stats 23 | 24 | import task 25 | from task import generate_trials, rules_dict 26 | from network import Model, get_perf, FixedPoint_Model 27 | import tools 28 | import train 29 | 30 | def gen_trials_from_model_dir(model_dir,rule,mode='test',noise_on = True,batch_size = 500): 31 | model = Model(model_dir) 32 | with tf.Session() as sess: 33 | model.restore() 34 | # model._sigma=0 35 | # get all connection weights and biases as tensorflow variables 36 | var_list = model.var_list 37 | # evaluate the parameters after training 38 | # params = [sess.run(var) for var in var_list] 39 | # get hparams 40 | hparams = model.hp 41 | # create a trial 42 | trial = generate_trials(rule, hparams, mode=mode, noise_on=noise_on, batch_size =batch_size, delay_fac =1) 43 | return trial 44 | 45 | def gen_X_from_model_dir(model_dir,trial,d = []): 46 | model = Model(model_dir) 47 | with tf.Session() as sess: 48 | 49 | if len(d)==0: 50 | model.restore() 51 | else: 52 | model.saver.restore(sess,d) 53 | 54 | # model._sigma=0 55 | # get all connection weights and biases as tensorflow variables 56 | var_list = model.var_list 57 | # evaluate the parameters after training 58 | hparams = model.hp 59 | feed_dict = tools.gen_feed_dict(model, trial, hparams) 60 | # run model 61 | h_tf, y_hat_tf = sess.run([model.h, model.y_hat], feed_dict=feed_dict) #(n_time, n_condition, n_neuron) 62 | x = np.transpose(h_tf,(2,1,0)) # h_tf[:,range(1,n_trials),:],(2,1,0)) 63 | X = np.reshape(x,(x.shape[0],-1)) 64 | return X, x #return orthogonal complement of hidden unit activity to ouput projection matrix 65 | 66 | def gen_X_from_model_dir_epoch(model_dir,trial,epoch,d = []): 67 | model = Model(model_dir) 68 | with tf.Session() as sess: 69 | 70 | if len(d)==0: 71 | model.restore() 72 | else: 73 | model.saver.restore(sess,d) 74 | 75 | model._sigma=0 76 | # get all connection weights and biases as tensorflow variables 77 | var_list = model.var_list 78 | # evaluate the parameters after training 79 | params = [sess.run(var) for var in var_list] 80 | # get hparams 81 | hparams = model.hp 82 | # create a trial 83 | feed_dict = 
tools.gen_feed_dict(model, trial, hparams) 84 | # run model 85 | h_tf, y_hat_tf = sess.run([model.h, model.y_hat], feed_dict=feed_dict) #(n_time, n_condition, n_neuron) 86 | 87 | if trial.epochs[epoch][1] is None: 88 | epoch_range = range(trial.epochs[epoch][0],np.shape(h_tf)[0]) 89 | elif trial.epochs[epoch][0] is None: 90 | epoch_range = range(0,trial.epochs[epoch][1]) 91 | else: 92 | epoch_range = range(trial.epochs[epoch][0],trial.epochs[epoch][1]) 93 | 94 | x = np.transpose(h_tf[epoch_range,:,:],(2,1,0)) #h_tf[:,range(1,n_trials),:],(2,1,0)) 95 | X = np.reshape(x,(x.shape[0],-1)) 96 | return X, x #return hidden unit activity 97 | 98 | def restore_ckpt(model_dir, ckpt_n): 99 | ckpt_n_dir = os.path.join(model_dir,'ckpts/model.ckpt-' + str(int(ckpt_n)) + '.meta') 100 | model = Model(model_dir) 101 | with tf.Session() as sess: 102 | model.saver.restore(sess,ckpt_n_dir) 103 | return model 104 | 105 | def find_ckpts(model_dir): 106 | s_all = [] 107 | ckpt_n_dir = os.path.join(model_dir,'ckpts/') 108 | for file in os.listdir(ckpt_n_dir): 109 | if file.endswith('.meta'): 110 | m = re.search('model.ckpt(.+?).meta', file) 111 | if m: 112 | found = m.group(1) 113 | s_all = np.concatenate((s_all,np.expand_dims(abs(int(found)),axis=0)),axis = 0) 114 | return s_all.astype(int) 115 | 116 | def name_best_ckpt(model_dir,rule): 117 | s_all = find_ckpts(model_dir) 118 | s_all_inds = np.sort(s_all) 119 | s_all_inds = s_all_inds.astype(int) 120 | fname = os.path.join(model_dir, 'log.json') 121 | 122 | with open(fname, 'r') as f: 123 | log_all = json.load(f) 124 | x = log_all['cost_'+rule] 125 | 126 | y = [x[int(j/1000)] for j in s_all_inds[:-1]] 127 | ind = int(s_all_inds[np.argmin(y)]) 128 | return ind 129 | 130 | def get_model_params(model_dir,ckpt_n_dir = []): 131 | 132 | model = Model(model_dir) 133 | with tf.Session() as sess: 134 | if len(ckpt_n_dir)==0: 135 | model.restore() 136 | else: 137 | model.saver.restore(sess,ckpt_n_dir) 138 | # get all connection weights and biases as tensorflow variables 139 | var_list = model.var_list 140 | # evaluate the parameters after training 141 | params = [sess.run(var) for var in var_list] 142 | 143 | w_in = params[0] 144 | b_in = params[1] 145 | w_out = params[2] 146 | b_out = params[3] 147 | 148 | return w_in, b_in, w_out, b_out 149 | 150 | def get_path_names(): 151 | import getpass 152 | ui = getpass.getuser() 153 | if ui == 'laura': 154 | p = '/home/laura' 155 | elif ui == 'lauradriscoll': 156 | p = '/Users/lauradriscoll/Documents' 157 | return p 158 | 159 | def take_names(epoch,rule,epoch_axes = [],h_epoch = []): 160 | epochs = ['stim1','delay1','go1'] 161 | epoch_names = ['stimulus','memory','go'] 162 | ei = [i for i,e in enumerate(epochs) if e==epoch] 163 | epoch_name = epoch_names[ei[0]] 164 | 165 | rules = ['fdgo','fdanti','delaygo','delayanti'] 166 | rule_names = ['DelayPro','DelayAnti','MemoryPro','MemoryAnti'] 167 | ri = [i for i,e in enumerate(rules) if e==rule] 168 | rule_name = rule_names[ri[0]] 169 | 170 | if len(epoch_axes)<1: 171 | epoch_axes_name = epoch_names[ei[0]] 172 | else: 173 | ei = [i for i,e in enumerate(epochs) if e==epoch_axes] 174 | epoch_axes_name = epoch_names[ei[0]] 175 | 176 | if len(h_epoch)==0: 177 | h_epoch = epoch 178 | 179 | return epoch_name, rule_name, epoch_axes_name, h_epoch 180 | 181 | def plot_N(X, D, clist, linewidth = 1): 182 | """Plot activity is some 2D space. 
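    Each trial of X is projected onto the two plotting dimensions in D
    (X_trial @ D.T) and drawn as a trajectory coloured by the matching entry of
    clist (via the 'rainbow' colormap); the start point is marked with a dot and
    the end point with a triangle.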
183 | 184 | Args: 185 | X: neural activity in Trials x Time x Neurons 186 | D: Neurons x 2 plotting dims 187 | """ 188 | cmap=plt.get_cmap('rainbow') 189 | S = np.shape(X)[0] 190 | 191 | for s in range(S): 192 | c = cmap(clist[s]/max(clist)) 193 | X_trial = np.dot(X[s,:,:],D.T) 194 | plt.plot(X_trial[-1,0],X_trial[-1,1],'^',c = c, linewidth = linewidth) 195 | plt.plot(X_trial[:,0],X_trial[:,1],'-',c = c, linewidth = linewidth) 196 | plt.plot(X_trial[0,0],X_trial[0,1],'.',c = c, linewidth = linewidth) 197 | 198 | def plot_FP(X, D, eig_decomps, c='k'): 199 | """Plot activity is some 2D space. 200 | 201 | Args: 202 | X: Fixed points in #Fps x Neurons 203 | D: Neurons x 2 plotting dims 204 | 205 | """ 206 | S = np.shape(X)[0] 207 | lf = 7 208 | rf = 7 209 | 210 | for s in range(S): 211 | 212 | X_trial = np.dot(X[s,:],D.T) 213 | 214 | n_arg = np.argwhere(eig_decomps[s]['evals']>1)+1 215 | if len(n_arg)>0: 216 | for arg in range(np.max(n_arg)): 217 | rdots = np.dot(np.real(eig_decomps[s]['R'][:, arg]).T,D.T) 218 | ldots = np.dot(np.real(eig_decomps[s]['L'][:, arg]).T,D.T) 219 | overlap = np.dot(rdots,ldots.T) 220 | r = np.concatenate((X_trial - rf*overlap*rdots, X_trial + rf*overlap*rdots),0) 221 | plt.plot(r[0:4:2],r[1:4:2], c = c ,alpha = .2,linewidth = .5) 222 | 223 | n_arg = np.argwhere(eig_decomps[s]['evals']<.3) 224 | if len(n_arg)>0: 225 | for arg in range(np.min(n_arg),len(eig_decomps[s]['evals'])): 226 | rdots = np.dot(np.real(eig_decomps[s]['R'][:, arg]).T,D.T) 227 | ldots = np.dot(np.real(eig_decomps[s]['L'][:, arg]).T,D.T) 228 | overlap = np.dot(rdots,ldots.T) 229 | r = np.concatenate((X_trial - rf*overlap*rdots, X_trial + rf*overlap*rdots),0) 230 | plt.plot(r[0:4:2],r[1:4:2],'b',alpha = .2,linewidth = .5) 231 | 232 | plt.plot(X_trial[0], X_trial[1], 'o', markerfacecolor = c, markeredgecolor = 'k', 233 | markersize = 6, alpha = .5) 234 | 235 | def comp_eig_decomp(Ms, sort_by='real', 236 | do_compute_lefts=True): 237 | """Compute the eigenvalues of the matrix M. No assumptions are made on M. 238 | 239 | Arguments: 240 | M: 3D np.array nmatrices x dim x dim matrix 241 | do_compute_lefts: Compute the left eigenvectors? Requires a pseudo-inverse 242 | call. 243 | 244 | Returns: 245 | list of dictionaries with eigenvalues components: sorted 246 | eigenvalues, sorted right eigenvectors, and sored left eigenvectors 247 | (as column vectors). 248 | """ 249 | if sort_by == 'magnitude': 250 | sort_fun = np.abs 251 | elif sort_by == 'real': 252 | sort_fun = np.real 253 | else: 254 | assert False, "Not implemented yet." 
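    # The loop below eigendecomposes each matrix in Ms, sorts eigenvalues (and the
    # columns of R) by sort_fun, and optionally forms left eigenvectors as the
    # columns of pinv(R).T, so that L[:, i].T @ R[:, j] ~ delta_ij when R is
    # invertible. plot_FP above uses the overlap of the projected left and right
    # eigenvectors of each fixed point's Jacobian to set the length of the plotted
    # eigen-directions.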
255 | 256 | decomps = [] 257 | L = None 258 | for M in Ms: 259 | evals, R = LA.eig(M) 260 | indices = np.flipud(np.argsort(sort_fun(evals))) 261 | if do_compute_lefts: 262 | L = LA.pinv(R).T # as columns 263 | L = L[:, indices] 264 | decomps.append({'evals' : evals[indices], 'R' : R[:, indices], 'L' : L}) 265 | 266 | return decomps 267 | 268 | def angle_between(v1, v2): 269 | """ Returns the angle in radians between vectors 'v1' and 'v2':: 270 | 271 | >>> angle_between((1, 0, 0), (0, 1, 0)) 272 | 1.5707963267948966 273 | >>> angle_between((1, 0, 0), (1, 0, 0)) 274 | 0.0 275 | >>> angle_between((1, 0, 0), (-1, 0, 0)) 276 | 3.141592653589793 277 | """ 278 | v1_u = unit_vector(v1) 279 | v2_u = unit_vector(v2) 280 | return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) 281 | 282 | def rot_mat(theta): 283 | R = np.array(((np.cos(theta), -np.sin(theta)), (np.sin(theta), np.cos(theta)))) 284 | return R 285 | 286 | def calc_R_angle(R): 287 | return np.arccos((np.trace(R)-1)/2) 288 | 289 | def tranform_in_rPC(X,R,X_ss): 290 | Xr_ss = np.dot(R,X_ss.T).T 291 | Xr = np.dot(R,X.T).T 292 | if Xr_ss[1,1]>0: 293 | Xr = np.dot(Xr,np.array(((1,0),(0,-1)))) 294 | return Xr 295 | 296 | def unit_vector(vector): 297 | """ Returns the unit vector of the vector. """ 298 | return vector / np.linalg.norm(vector) 299 | 300 | 301 | def make_Jac_u_dot_delu(model_dir_all,ckpt_n_dir,rule,task_set,time_set,trial_set): 302 | n_tasks = len(task_set) 303 | 304 | model = Model(model_dir_all) 305 | with tf.Session() as sess: 306 | 307 | model.saver.restore(sess,ckpt_n_dir) 308 | # get all connection weights and biases as tensorflow variables 309 | var_list = model.var_list 310 | # evaluate the parameters after training 311 | params = [sess.run(var) for var in var_list] 312 | # get hparams 313 | hparams = model.hp 314 | trial = generate_trials(rule, hparams, mode='test', noise_on=False, delay_fac =1) 315 | 316 | #get size of relevant variables to init mats 317 | n_inputs = np.shape(trial.x)[2] 318 | N = np.shape(params[0])[1] 319 | n_stim_dims = n_inputs - 20 320 | #change this depending on when in the trial you're looking [must be a transition btwn epochs] 321 | 322 | #init mats 323 | J_np_u = np.zeros((n_tasks,len(trial_set),len(time_set),N,n_inputs)) 324 | J_np_u_dot_delu = np.zeros((n_tasks,len(trial_set),len(time_set),N)) 325 | 326 | for r in range(n_tasks): 327 | r_all_tasks_ind = task_set[r] 328 | 329 | trial.x[:,:,n_stim_dims:] = 0 #set all tasks to 0 #(n_time, n_trials, n_inputs) 330 | trial.x[:,:,n_stim_dims+r_all_tasks_ind] = 1 #except for this task 331 | 332 | feed_dict = tools.gen_feed_dict(model, trial, hparams) 333 | h_tf = sess.run(model.h, feed_dict=feed_dict) #(n_time, n_trials, n_neuron) 334 | 335 | for trial_i in range(len(trial_set)): #depending on the analysis I was including one or many trials 336 | for time_i in range(len(time_set)): #also including one or many time pts 337 | 338 | inputs = np.squeeze(trial.x[time_set[time_i],trial_set[trial_i],:]) #(n_time, n_condition, n_inputs) 339 | inputs = inputs[np.newaxis,:] 340 | 341 | states = h_tf[time_set[time_i],trial_set[trial_i],:] 342 | states = states[np.newaxis,:] 343 | 344 | #calc Jac wrt inputs 345 | inputs_context = np.squeeze(trial.x[time_set[time_i]-1,trial_set[trial_i],:]) #(n_time, n_condition, n_inputs) 346 | inputs_context = inputs_context[np.newaxis,:] 347 | delta_inputs = inputs - inputs_context 348 | 349 | inputs_tf_context = tf.constant(inputs_context, dtype=tf.float32) 350 | states_tf = tf.constant(states, dtype=tf.float32) 351 | 
output, new_states = model.cell(inputs_tf_context, states_tf) 352 | F_context = new_states 353 | 354 | J_tf_u = pfor.batch_jacobian(F_context, inputs_tf_context, use_pfor=False) 355 | J_np_u[r,trial_i,time_i,:,:] = sess.run(J_tf_u) 356 | J_np_u_dot_delu[r,trial_i,time_i,:] = np.squeeze(np.dot(J_np_u[r,trial_i,time_i,:,:],delta_inputs.T)) 357 | 358 | return J_np_u_dot_delu 359 | 360 | def make_Jac_x(model_dir_all,ckpt_n_dir,rule,task_set,time_set,trial_set): 361 | n_tasks = len(task_set) 362 | 363 | model = Model(model_dir_all) 364 | with tf.Session() as sess: 365 | 366 | model.saver.restore(sess,ckpt_n_dir) 367 | # get all connection weights and biases as tensorflow variables 368 | var_list = model.var_list 369 | # evaluate the parameters after training 370 | params = [sess.run(var) for var in var_list] 371 | # get hparams 372 | hparams = model.hp 373 | trial = generate_trials(rule, hparams, mode='test', noise_on=False, delay_fac =1) 374 | 375 | #get size of relevant variables to init mats 376 | n_inputs = np.shape(trial.x)[2] 377 | N = np.shape(params[0])[1] 378 | n_stim_dims = n_inputs - 20 379 | #change this depending on when in the trial you're looking [must be a transition btwn epochs] 380 | 381 | #init mats 382 | J_np_x = np.zeros((n_tasks,len(trial_set),len(time_set),N,N)) 383 | 384 | for r in range(n_tasks): 385 | r_all_tasks_ind = task_set[r] 386 | 387 | trial.x[:,:,n_stim_dims:] = 0 #set all tasks to 0 #(n_time, n_trials, n_inputs) 388 | trial.x[:,:,n_stim_dims+r_all_tasks_ind] = 1 #except for this task 389 | 390 | feed_dict = tools.gen_feed_dict(model, trial, hparams) 391 | h_tf = sess.run(model.h, feed_dict=feed_dict) #(n_time, n_trials, n_neuron) 392 | 393 | for trial_i in range(len(trial_set)): #depending on the analysis I was including one or many trials 394 | for time_i in range(len(time_set)): #also including one or many time pts 395 | 396 | inputs = np.squeeze(trial.x[time_set[time_i],trial_set[trial_i],:]) #(n_time, n_condition, n_inputs) 397 | inputs = inputs[np.newaxis,:] 398 | 399 | states = h_tf[time_set[time_i],trial_set[trial_i],:] 400 | states = states[np.newaxis,:] 401 | 402 | #calc Jac wrt inputs 403 | inputs_context = np.squeeze(trial.x[time_set[time_i]-1,trial_set[trial_i],:]) #(n_time, n_condition, n_inputs) 404 | inputs_context = inputs_context[np.newaxis,:] 405 | delta_inputs = inputs - inputs_context 406 | 407 | inputs_tf_context = tf.constant(inputs_context, dtype=tf.float32) 408 | states_tf = tf.constant(states, dtype=tf.float32) 409 | output, new_states = model.cell(inputs_tf_context, states_tf) 410 | F_context = new_states 411 | 412 | J_tf_x = pfor.batch_jacobian(F_context, states_tf, use_pfor=False) 413 | J_np_x[r,trial_i,time_i,:,:] = sess.run(J_tf_x) 414 | 415 | return J_np_x 416 | 417 | def make_h_and_Jac(model_dir_all,ckpt_n_dir,rule,task_set,time_set,trial_set): 418 | 419 | h_context_combined = [] 420 | h_stim_early_combined = [] 421 | h_stim_late_combined = [] 422 | 423 | model = Model(model_dir_all) 424 | with tf.Session() as sess: 425 | 426 | model.saver.restore(sess,ckpt_n_dir) 427 | # get all connection weights and biases as tensorflow variables 428 | var_list = model.var_list 429 | # evaluate the parameters after training 430 | params = [sess.run(var) for var in var_list] 431 | # get hparams 432 | hparams = model.hp 433 | trial = generate_trials('delaygo', hparams, mode='test', noise_on=False, delay_fac =1) 434 | 435 | #get size of relevant variables to init mats 436 | n_inputs = np.shape(trial.x)[2] 437 | N = np.shape(params[0])[1] 438 | 
n_stim_dims = n_inputs - 20 439 | #change this depending on when in the trial you're looking [must be a transition btwn epochs] 440 | time_set = [trial.epochs['stim1'][0]] #beginning of stim period 441 | 442 | #init mats 443 | J_np_u = np.zeros((n_tasks,len(trial_set),len(time_set),N,n_inputs)) 444 | J_np_u_dot_delu = np.zeros((n_tasks,len(trial_set),len(time_set),N)) 445 | 446 | for r in range(n_tasks): 447 | r_all_tasks_ind = task_set[r] 448 | 449 | trial.x[:,:,n_stim_dims:] = 0 #set all tasks to 0 #(n_time, n_trials, n_inputs) 450 | trial.x[:,:,n_stim_dims+r_all_tasks_ind] = 1 #except for this task 451 | 452 | feed_dict = tools.gen_feed_dict(model, trial, hparams) 453 | h_tf = sess.run(model.h, feed_dict=feed_dict) #(n_time, n_trials, n_neuron) 454 | 455 | # comparing Jacobians to proximity of hidden state across tasks 456 | # we focus on end of the context period, early, and late in the stim period 457 | h_context = np.reshape(h_tf[trial.epochs['stim1'][0]-1,trial_set,:],(1,-1)) # h @ end of context period 458 | h_stim_early = np.reshape(h_tf[trial.epochs['stim1'][0]+n_steps_early,trial_set,:],(1,-1)) # h @ 5 steps into stim 459 | h_stim_late = np.reshape(h_tf[trial.epochs['stim1'][1],trial_set,:],(1,-1)) # h @ end of stim period 460 | 461 | #concatenate activity states across tasks 462 | if h_context_combined == []: 463 | h_context_combined = h_context[np.newaxis,:] 464 | h_stim_late_combined = h_stim_late[np.newaxis,:] 465 | h_stim_early_combined = h_stim_early[np.newaxis,:] 466 | else: 467 | h_context_combined = np.concatenate((h_context_combined, h_context[np.newaxis,:]), axis=0) 468 | h_stim_late_combined = np.concatenate((h_stim_late_combined, h_stim_late[np.newaxis,:]), axis=0) 469 | h_stim_early_combined = np.concatenate((h_stim_early_combined, h_stim_early[np.newaxis,:]), axis=0) 470 | 471 | for trial_i in range(len(trial_set)): #depending on the analysis I was including one or many trials 472 | for time_i in range(len(time_set)): #also including one or many time pts 473 | 474 | inputs = np.squeeze(trial.x[time_set[time_i],trial_set[trial_i],:]) #(n_time, n_condition, n_inputs) 475 | inputs = inputs[np.newaxis,:] 476 | 477 | states = h_tf[time_set[time_i],trial_set[trial_i],:] 478 | states = states[np.newaxis,:] 479 | 480 | #calc Jac wrt inputs 481 | inputs_context = np.squeeze(trial.x[time_set[time_i]-1,trial_set[trial_i],:]) #(n_time, n_condition, n_inputs) 482 | inputs_context = inputs_context[np.newaxis,:] 483 | delta_inputs = inputs - inputs_context 484 | 485 | inputs_tf_context = tf.constant(inputs_context, dtype=tf.float32) 486 | states_tf = tf.constant(states, dtype=tf.float32) 487 | output, new_states = model.cell(inputs_tf_context, states_tf) 488 | F_context = new_states 489 | 490 | J_tf_u = pfor.batch_jacobian(F_context, inputs_tf_context, use_pfor=False) 491 | J_np_u[r,trial_i,time_i,:,:] = sess.run(J_tf_u) 492 | J_np_u_dot_delu[r,trial_i,time_i,:] = np.squeeze(np.dot(J_np_u[r,trial_i,time_i,:,:],delta_inputs.T)) 493 | 494 | return J_np_u_dot_delu, h_context_combined, h_stim_late_combined, h_stim_early_combined 495 | 496 | def prep_procrustes(data1, data2): 497 | r"""Procrustes analysis, a similarity test for two data sets. 498 | 499 | Parameters 500 | ---------- 501 | data1 : array_like 502 | Matrix, n rows represent points in k (columns) space `data1` is the 503 | reference data, after it is standardised, the data from `data2` will be 504 | transformed to fit the pattern in `data1` (must have >1 unique points). 
505 | data2 : array_like 506 | n rows of data in k space to be fit to `data1`. Must be the same 507 | shape ``(numrows, numcols)`` as data1 (must have >1 unique points). 508 | Returns 509 | ------- 510 | mtx1 : array_like 511 | A standardized version of `data1`. 512 | mtx2 : array_like 513 | The orientation of `data2` that best fits `data1`. Centered, but not 514 | necessarily :math:`tr(AA^{T}) = 1`. 515 | disparity : float 516 | :math:`M^{2}` as defined above. 517 | 518 | """ 519 | mtx1 = np.array(data1, dtype=np.double, copy=True) 520 | mtx2 = np.array(data2, dtype=np.double, copy=True) 521 | 522 | if mtx1.ndim != 2 or mtx2.ndim != 2: 523 | raise ValueError("Input matrices must be two-dimensional") 524 | if mtx1.shape != mtx2.shape: 525 | raise ValueError("Input matrices must be of same shape") 526 | if mtx1.size == 0: 527 | raise ValueError("Input matrices must be >0 rows and >0 cols") 528 | 529 | # translate all the data to the origin 530 | mtx1 -= np.mean(mtx1, 0) 531 | mtx2 -= np.mean(mtx2, 0) 532 | 533 | norm1 = np.linalg.norm(mtx1) 534 | norm2 = np.linalg.norm(mtx2) 535 | 536 | if norm1 == 0 or norm2 == 0: 537 | raise ValueError("Input matrices must contain >1 unique points") 538 | 539 | # change scaling of data (in rows) such that trace(mtx*mtx') = 1 540 | mtx1 /= norm1 541 | mtx2 /= norm2 542 | 543 | return mtx1,mtx2 544 | 545 | # def procrustes(mtx1, mtx2): 546 | # # transform mtx2 to minimize disparity 547 | # R, s = orthogonal_procrustes(mtx1, mtx2) 548 | # mtx2 = np.dot(mtx2, R.T) * s 549 | 550 | # # measure the dissimilarity between the two datasets 551 | # disparity = np.sum(np.square(mtx1 - mtx2)) 552 | 553 | # return mtx1, mtx2, disparity, R, s 554 | 555 | def same_stim_trial(trial_master, task_num): 556 | n_stim_per_ring = int(np.shape(trial_master.y)[2]-1) 557 | stim_rep_size = int(2*n_stim_per_ring+1) 558 | trial_task_num = trial_master 559 | trial_task_num.x[:,:,stim_rep_size:] = 0 560 | trial_task_num.x[:,:,stim_rep_size+task_num] = 1 561 | return trial_task_num 562 | 563 | def pca_denoise(X1,X2,nD): 564 | pca = PCA(n_components = nD) 565 | X12 = np.concatenate((X1,X2),axis=1) 566 | _ = pca.fit_transform(X12.T) 567 | X1_pca = pca.transform(X1.T) 568 | X2_pca = pca.transform(X2.T) 569 | return X1_pca, X2_pca 570 | 571 | def procrustes_fit(mtx1, mtx2): 572 | # transform mtx2 to minimize disparity 573 | R, s = orthogonal_procrustes(mtx1, mtx2) 574 | mtx2 = np.dot(mtx2, R.T) * s 575 | 576 | # measure the dissimilarity between the two datasets 577 | disparity = np.sum(np.square(mtx1 - mtx2)) 578 | 579 | return mtx1, mtx2, disparity, R, s 580 | 581 | def procrustes_test(mtx1, mtx2, R, s): 582 | # transform mtx2 to minimize disparity 583 | mtx2 = np.dot(mtx2, R.T) * s 584 | 585 | # measure the dissimilarity between the two datasets 586 | disparity = np.sum(np.square(mtx1 - mtx2)) 587 | 588 | return mtx1, mtx2, disparity 589 | 590 | def make_procrustes_mat_stim(model_dir_all,epoch,tasks,nD = 10, batch_size = 1000): 591 | 592 | procrust = {} 593 | procrust['Disparity'] = np.zeros((len(tasks),len(tasks))) 594 | procrust['Scaling'] = np.zeros((len(tasks),len(tasks))) 595 | procrust['R']= np.zeros((len(tasks),len(tasks))) 596 | 597 | 598 | rule = 'delaygo' 599 | trial_all = gen_trials_from_model_dir(model_dir_all,rule, mode = 'random', batch_size = batch_size) 600 | trial_all_test = gen_trials_from_model_dir(model_dir_all,rule, mode = 'random', batch_size = batch_size) 601 | 602 | for t1_ind in range(len(tasks)): 603 | t1 = tasks[t1_ind] 604 | 605 | trial1 = 
same_stim_trial(trial_all, t1) 606 | X1,_ = gen_X_from_model_dir_epoch(model_dir_all,trial1,epoch) 607 | 608 | trial1_test = same_stim_trial(trial_all_test, t1) 609 | X1_test,_ = gen_X_from_model_dir_epoch(model_dir_all,trial1_test,epoch) 610 | 611 | for t2_ind in range(len(tasks)): 612 | if t1_ind !=t2_ind: 613 | t2 = tasks[t2_ind] 614 | 615 | trial2 = same_stim_trial(trial_all, t2) 616 | X2,_ = gen_X_from_model_dir_epoch(model_dir_all,trial2,epoch) 617 | X1_pca,X2_pca = pca_denoise(X1,X2,nD) 618 | prep_mtx1, prep_mtx2 = prep_procrustes(X1_pca,X2_pca) 619 | _, _, disparity_train, R, s = procrustes_fit(prep_mtx1, prep_mtx2) 620 | 621 | trial2_test = same_stim_trial(trial_all_test, t2) 622 | X2_test,_ = gen_X_from_model_dir_epoch(model_dir_all,trial2_test,epoch) 623 | X1_pca_test,X2_pca_test = pca_denoise(X1_test,X2_test,nD) 624 | prep_mtx1_test, prep_mtx2_test = prep_procrustes(X1_pca_test,X2_pca_test) 625 | mtx1, mtx2, disparity_test = procrustes_test(prep_mtx1_test, prep_mtx2_test, R, s) 626 | 627 | procrust['Disparity'][t1_ind,t2_ind] = disparity_test 628 | procrust['Scaling'][t1_ind,t2_ind] = s 629 | procrust['R'][t1_ind,t2_ind] = calc_R_angle(R) 630 | return procrust 631 | 632 | def align_output_inds(trial_master, trial_temp): 633 | 634 | indices = range(np.shape(trial_master.y_loc)[1]) 635 | n_out = np.shape(trial_master.y)[2]-1 636 | 637 | for ii in range(np.shape(trial_master.y_loc)[1]): 638 | if np.max(np.sum(abs(trial_master.x[:,ii,1:(1+n_out)]),axis = 1),axis = 0)>0: 639 | ind_use = np.max(np.sum(abs(trial_temp.x[:,:,1:(1+n_out)]),axis = 2),axis = 0)>0 640 | else: 641 | ind_use = np.max(np.sum(abs(trial_temp.x[:,:,(1+n_out):(1+2*n_out)]),axis = 2),axis = 0)>0 642 | 643 | loc_diff = abs(trial_temp.y_loc[-1,:]-trial_master.y_loc[-1,ii])%(2*np.pi) 644 | align_ind = [int(i) for i, x in enumerate(loc_diff) if x == min(loc_diff)] 645 | align_ind_choosey = [x for i, x in enumerate(align_ind) if ind_use[x]] 646 | if len(align_ind_choosey)==0: 647 | align_ind_choosey = align_ind 648 | indices[ii] = align_ind_choosey[npr.randint(len(align_ind_choosey))] 649 | 650 | trial_temp_new = trial_temp 651 | trial_temp_new.x = trial_temp_new.x[:,indices,:] 652 | trial_temp_new.y = trial_temp_new.y[:,indices,:] 653 | trial_temp_new.y_loc = trial_temp_new.y_loc[:,indices] 654 | return trial_temp_new 655 | 656 | def project_to_output(model_dir_all,X): 657 | w_in, b_in, w_out, b_out = get_model_params(model_dir_all) 658 | y = np.dot(X.T, w_out) + b_out 659 | return y 660 | 661 | def gen_mov_x(model_dir_all,rule,trial_master, batch_size = 2000, ckpt_n_dir = []): 662 | trial = gen_trials_from_model_dir(model_dir_all,rule,mode = 'random', batch_size = batch_size) 663 | trial = align_output_inds(trial_master, trial) 664 | _,x = gen_X_from_model_dir_epoch(model_dir_all,trial,'go1') 665 | x_out = project_to_output(model_dir_all,x[:,:,-1]) 666 | err = np.sum(np.square(x_out[:,1:] - trial.y[-1,:,1:]),axis=1) 667 | return err, x 668 | 669 | def make_fp_struct(m,fp_file,rule,epoch,ind_stim_loc,trial_set = range(0,360,36)): 670 | 671 | fps = [] 672 | J_xstar = [] 673 | 674 | if (rule[:2]=='fd') & (epoch=='delay1'): 675 | epoch_temp = 'stim1' 676 | 677 | for ti in trial_set: 678 | filename = os.path.join(m,fp_file,rule,epoch_temp+'_'+str(round(ti,2))+'.npz') 679 | fp_struct = np.load(filename) 680 | fp_num = np.argmin(np.log10(fp_struct['qstar'])) 681 | 682 | fps_temp = fp_struct['xstar'][fp_num,:] 683 | J_xstar_temp = fp_struct['J_xstar'][fp_num,:,:] 684 | 685 | if len(np.shape(fps_temp))==1: 686 | fps = 
fps_temp[np.newaxis,:] 687 | J_xstar = J_xstar_temp[np.newaxis,:,:] 688 | else: 689 | fps = np.concatenate((fps,fps_temp[np.newaxis,:]),axis = 0) 690 | J_xstar = np.concatenate((J_xstar,J_xstar_temp[np.newaxis,:,:]),axis = 0) 691 | 692 | else: 693 | filename = os.path.join(m,fp_file,rule,epoch+'_'+str(round(ind_stim_loc,2))+'.npz') 694 | fp_struct = np.load(filename) 695 | print(filename) 696 | if (epoch=='delay1') or ((rule[:2]!='fd') & (epoch=='go1')): 697 | fp_num = np.squeeze(np.argwhere(np.log10(fp_struct['qstar'])<-0)) 698 | else: 699 | fp_num = np.argmin(np.log10(fp_struct['qstar'])) 700 | 701 | if len(np.shape(fp_struct['xstar'][fp_num,:]))==1: 702 | fps = fp_struct['xstar'][fp_num,:][np.newaxis,:] 703 | J_xstar = fp_struct['J_xstar'][fp_num,:,:][np.newaxis,:,:] 704 | else: 705 | fps = fp_struct['xstar'][fp_num,:] 706 | J_xstar = fp_struct['J_xstar'][fp_num,:,:] 707 | 708 | return fps, J_xstar 709 | 710 | def load_fps_J(m,fp_file,rule,epoch,ind_stim_loc,trial_set): 711 | 712 | ind_stim_loc_anti = (ind_stim_loc+180)%360 # ind_stim_loc is the input angle angle, anti is in the opposite direction (relevant for file names) 713 | 714 | if rule[-4:]=='anti': # anti task 715 | if (rule == 'delayanti') & (epoch!='stim1'): # if outside of stim epoch, inputs are the same across trials (and therefore only one set of FPs) 716 | ind_stim_loc_anti=180 # this is the output angle that we identified the set of fixed points on (could use any trial on this epoch) 717 | fps, J_xstar = make_fp_struct(m,fp_file,rule,epoch,ind_stim_loc_anti,trial_set = trial_set) # load fixed points and Jacobian 718 | else: # pro task 719 | if (rule == 'delaygo') & (epoch!='stim1'): #again, if outside of stim epoch, inputs are the same across trials (and therefore only one set of FPs) 720 | ind_stim_loc=0 # this is the output angle that we identified the set of fixed points on (could use any trial on this epoch) 721 | fps, J_xstar = make_fp_struct(m,fp_file,rule,epoch,ind_stim_loc,trial_set = trial_set) # load fixed points and Jacobian 722 | 723 | return fps, J_xstar 724 | 725 | 726 | def make_fp_tdr_fig(m,fp_file,rule1,rule2,epoch,ind_stim_loc,tit,trial_set = range(0,360,36),dims = 'tdr'): 727 | 728 | nr = 1 # number of rows in subplots 729 | nc = 1 # number of columns in subplots 730 | ms = 10 # marker size 731 | 732 | h,trial,tasks = make_h_trial_rule(m) 733 | D = get_D(dims,h,trial,[rule1,],epoch,ind = -1) #identify subspace through either PCA or TDR 734 | 735 | fig = plt.figure(figsize=(5.5*nc,4.5*nr),tight_layout=True,facecolor='white') 736 | ax = plt.subplot(nr,nc,1) 737 | cmap=plt.get_cmap('hsv') 738 | 739 | for ind_stim_loc in trial_set: 740 | 741 | # load fixed points for rule 1 and plot in rule 1 axes 742 | fps, J_xstar = load_fps_J(m,fp_file,rule1,epoch,ind_stim_loc,trial_set) 743 | fp_tdr = np.dot(fps,D[rule1].T) # project FP into subspace 744 | if (epoch=='delay1') or (epoch=='go1'): 745 | plt.plot(fp_tdr[:,0],fp_tdr[:,1],'o',c = 'dodgerblue',markersize = ms) 746 | else: 747 | plt.plot(fp_tdr[:,0],fp_tdr[:,1],'o',c = cmap(ind_stim_loc/360),markersize = ms) # if FP diff on different trials, color by input 748 | 749 | # load fixed points for rule 2 and plot in rule 1 axes 750 | fps, J_xstar = load_fps_J(m,fp_file,rule2,epoch,ind_stim_loc,trial_set) 751 | fp_tdr = np.dot(fps,D[rule1].T) 752 | if (epoch=='delay1') or (epoch=='go1'): 753 | plt.plot(fp_tdr[:,0],fp_tdr[:,1],'o',c = 'orangered',markersize = ms) 754 | else: 755 | print(epoch) 756 | plt.plot(fp_tdr[:,0],fp_tdr[:,1],'o',c = 
cmap(ind_stim_loc/360),markerfacecolor = 'w',markersize = ms) 757 | 758 | if dims == 'tdr': 759 | plt.xlabel(rule1 + ' TDR input 1') 760 | plt.ylabel(rule1 + ' TDR input 2') 761 | ax.spines['top'].set_visible(False) 762 | ax.spines['right'].set_visible(False) 763 | plt.title('Fixed Points : ' + tit) 764 | plt.legend((rule1,rule2)) 765 | return ax 766 | 767 | def make_h_combined(model_dir_all,ckpt_n_dir,tasks,trial_set,n_steps_early = 5): 768 | 769 | h_context_combined = [] 770 | h_stim_early_combined = [] 771 | h_stim_late_combined = [] 772 | 773 | model = Model(model_dir_all) 774 | with tf.Session() as sess: 775 | 776 | rule = 'delaygo' 777 | model.saver.restore(sess,ckpt_n_dir) 778 | # get all connection weights and biases as tensorflow variables 779 | var_list = model.var_list 780 | # evaluate the parameters after training 781 | params = [sess.run(var) for var in var_list] 782 | # get hparams 783 | hparams = model.hp 784 | trial = generate_trials(rule, hparams, mode='test', noise_on=False, delay_fac =1) 785 | 786 | #get size of relevant variables to init mats 787 | n_inputs = np.shape(trial.x)[2] 788 | N = np.shape(params[0])[1] 789 | #change this depending on when in the trial you're looking [must be a transition btwn epochs] 790 | time_set = [trial.epochs['stim1'][0]] #beginning of stim period 791 | n_stim_dims = np.shape(trial.x)[2]-20 792 | 793 | 794 | for r in range(len(tasks)): 795 | r_all_tasks_ind = tasks[r] 796 | 797 | trial.x[:,:,n_stim_dims:] = 0 #set all tasks to 0 #(n_time, n_trials, n_inputs) 798 | trial.x[:,:,n_stim_dims+r_all_tasks_ind] = 1 #except for this task 799 | 800 | feed_dict = tools.gen_feed_dict(model, trial, hparams) 801 | h_tf = sess.run(model.h, feed_dict=feed_dict) #(n_time, n_trials, n_neuron) 802 | 803 | # comparing Jacobians to proximity of hidden state across tasks 804 | # we focus on end of the context period, early, and late in the stim period 805 | h_context = np.reshape(h_tf[trial.epochs['stim1'][0]-1,trial_set,:],(1,-1)) # h @ end of context period 806 | h_stim_early = np.reshape(h_tf[trial.epochs['stim1'][0]+n_steps_early,trial_set,:],(1,-1)) # h @ 5 steps into stim 807 | h_stim_late = np.reshape(h_tf[trial.epochs['stim1'][1],trial_set,:],(1,-1)) # h @ end of stim period 808 | 809 | #concatenate activity states across tasks 810 | if h_context_combined == []: 811 | h_context_combined = h_context[np.newaxis,:] 812 | h_stim_late_combined = h_stim_late[np.newaxis,:] 813 | h_stim_early_combined = h_stim_early[np.newaxis,:] 814 | else: 815 | h_context_combined = np.concatenate((h_context_combined, h_context[np.newaxis,:]), axis=0) 816 | h_stim_late_combined = np.concatenate((h_stim_late_combined, h_stim_late[np.newaxis,:]), axis=0) 817 | h_stim_early_combined = np.concatenate((h_stim_early_combined, h_stim_early[np.newaxis,:]), axis=0) 818 | 819 | return h_context_combined, h_stim_late_combined, h_stim_early_combined 820 | 821 | def generate_Beta_epoch(h_tf,trial,ind = -1,mod = 'either', ind_adjust = 0): 822 | Beta_epoch = {} 823 | 824 | for epoch in trial.epochs.keys(): 825 | 826 | T_inds = get_T_inds(trial,epoch) 827 | T_use = T_inds[ind] 828 | 829 | inds_use = np.min(trial.stim_strength,axis=1)>.5 830 | # X = h_tf[T_use,inds_use,:].T 831 | # X_zscore = stats.zscore(X, axis=1) 832 | # X_zscore_nonan = X_zscore 833 | # X_zscore_nonan[np.isnan(X_zscore)] = 0 834 | # r = X_zscore_nonan 835 | 836 | r = h_tf[T_use,inds_use,:].T 837 | 838 | if mod is 'either': 839 | stim1_locs = np.min(trial.stim_locs[:,[0,2]],axis=1) 840 | stim2_locs = 
np.min(trial.stim_locs[:,[1,3]],axis=1) 841 | elif mod==1: 842 | stim1_locs = trial.stim_locs[:,0] 843 | stim2_locs = trial.stim_locs[:,1] 844 | elif mod==2: 845 | stim1_locs = trial.stim_locs[:,2] 846 | stim2_locs = trial.stim_locs[:,3] 847 | 848 | y_loc = trial.y_loc[-1,:] 849 | 850 | if epoch == 'stim1' or epoch == 'delay1': 851 | angle_var = stim1_locs[inds_use] 852 | elif epoch =='stim2' or epoch == 'delay2': 853 | angle_var = stim2_locs[inds_use] 854 | elif epoch =='go1' or epoch == 'fix1': 855 | angle_var = stim1_locs[inds_use] 856 | 857 | y1 = np.expand_dims(np.sin(angle_var),axis = 1) 858 | y2 = np.expand_dims(np.cos(angle_var),axis = 1) 859 | y = np.concatenate((y1,y2),axis=1) 860 | 861 | lm = linear_model.LinearRegression() 862 | model = lm.fit(y,r.T) 863 | Beta = model.coef_ 864 | Beta_epoch[epoch],_ = LA.qr(Beta) 865 | 866 | #Make sure vectors are oriented appropriately 867 | #first identify a trial that should be in quadrant 1 868 | quad1_arg = np.argmin((angle_var - np.pi/4)%(2*np.pi)) 869 | quad1_x = h_tf[T_use,quad1_arg,:] 870 | dr_loc = np.dot(quad1_x,Beta_epoch[epoch]) 871 | 872 | #flip vectors so that point is actually in quadrant 1 873 | if dr_loc[0]<0: 874 | Beta_epoch[epoch][:,0] = -Beta_epoch[epoch][:,0] 875 | 876 | if dr_loc[1]<0: 877 | Beta_epoch[epoch][:,1] = -Beta_epoch[epoch][:,1] 878 | 879 | return Beta_epoch 880 | 881 | # def make_axes(model_dir_all,ckpt_n_dir,rule_master,epoch,ind = -1,mod = 'either'): 882 | 883 | # model = Model(model_dir_all) 884 | # with tf.Session() as sess: 885 | 886 | # model.saver.restore(sess,ckpt_n_dir) 887 | # # get all connection weights and biases as tensorflow variables 888 | # var_list = model.var_list 889 | # # evaluate the parameters after training 890 | # params = [sess.run(var) for var in var_list] 891 | # # get hparams 892 | # hparams = model.hp 893 | # trial_master = generate_trials(rule_master, hparams, mode = 'test', batch_size = 400, noise_on=False, delay_fac =1) 894 | # feed_dict = tools.gen_feed_dict(model, trial_master, hparams) 895 | # h_tf = sess.run(model.h, feed_dict=feed_dict) #(n_time, n_trials, n_neuron) 896 | 897 | # Beta_epoch = generate_Beta_epoch(h_tf,trial_master,ind,mod = mod) 898 | # X_pca = Beta_epoch[epoch] 899 | # D = np.concatenate((np.expand_dims(X_pca[:,0],axis=1),np.expand_dims(X_pca[:,1],axis=1)),axis = 1) 900 | # return D 901 | 902 | def get_D(dims,h,trial,tasks,epoch,ind = -1): 903 | D = {} 904 | 905 | if dims=='pca': 906 | for ri in range(len(tasks)): 907 | rule = tasks[ri] 908 | pca = PCA(n_components = 100) 909 | X = np.reshape(h[rule],(-1,N)) 910 | _ = pca.fit_transform(X) 911 | D[rule] = pca.components_ 912 | elif dims=='tdr': 913 | for ri in range(len(tasks)): 914 | rule = tasks[ri] 915 | Beta_temp = generate_Beta_epoch(h[rule],trial[rule],ind = ind) 916 | if (rule[:2] == 'fd') & (epoch == 'delay1'): 917 | D[rule] = Beta_temp['stim1'].T 918 | else: 919 | D[rule] = Beta_temp[epoch].T 920 | return D 921 | 922 | def get_T_inds(trial,epoch): 923 | 924 | T_end = trial.epochs[epoch][1] 925 | if T_end is None: 926 | T_end = np.shape(trial.x)[0] 927 | 928 | T_start = trial.epochs[epoch][0] 929 | if T_start is None: 930 | T_start = 1 931 | 932 | T_inds = range(T_start-1,T_end) 933 | 934 | return T_inds 935 | 936 | def generate_Beta_timeseries(h_tf,trial,T_inds,align_group): 937 | T,S,N = np.shape(h_tf) 938 | Beta_timeseries = np.empty((N,2,len(T_inds))) 939 | 940 | for t in T_inds: 941 | 942 | inds_use = np.min(trial.stim_strength,axis=1)>.5 943 | # X = h_tf[t,inds_use,:].T 944 | # X_zscore = 
stats.zscore(X, axis=1) 945 | # X_zscore_nonan = X_zscore 946 | # X_zscore_nonan[np.isnan(X_zscore)] = 0 947 | # r = X_zscore_nonan 948 | r = h_tf[t,inds_use,:].T 949 | 950 | stim1_locs = np.min(trial.stim_locs[:,[0,2]],axis=1) 951 | stim2_locs = np.min(trial.stim_locs[:,[1,3]],axis=1) 952 | y_loc = trial.y_loc[-1,:] 953 | 954 | if align_group == 'stim1': 955 | angle_var = stim1_locs[inds_use] 956 | elif align_group =='stim2': 957 | angle_var = stim2_locs[inds_use] 958 | elif align_group =='go1': 959 | angle_var = y_loc[inds_use] 960 | 961 | y1 = np.expand_dims(np.sin(angle_var),axis = 1) 962 | y2 = np.expand_dims(np.cos(angle_var),axis = 1) 963 | y = np.concatenate((y1,y2),axis=1) 964 | 965 | lm = linear_model.LinearRegression() 966 | model = lm.fit(y,r.T) 967 | Beta = model.coef_ 968 | Beta_timeseries[:,:,t],_ = LA.qr(Beta) 969 | 970 | return Beta_timeseries 971 | 972 | def get_stim_cats(trial): 973 | #stim locations and category ids 974 | stim1_locs = np.min(trial.stim_locs[:,[0,2]],axis=1) 975 | stim2_locs = np.min(trial.stim_locs[:,[1,3]],axis=1) 976 | 977 | stim1_cats = stim1_locs10: 118 | self.stim_locs[i,2*mods[i]-2] = locs[i] 119 | self.stim_strength[i,2*mods[i]-2] = strengths[i] 120 | else: 121 | self.stim_locs[i,2*mods[i]-1] = locs[i] 122 | self.stim_strength[i,2*mods[i]-1] = strengths[i] 123 | 124 | elif loc_type == 'fix_out': 125 | # Notice this shouldn't be set at 1, because the output is logistic and saturates at 1 126 | if self.config['loss_type'] == 'lsq': 127 | self.y[ons[i]: offs[i], i, 0] = 0.8 128 | else: 129 | self.y[ons[i]: offs[i], i, 0] = 1.0 130 | elif loc_type == 'out': 131 | if self.config['loss_type'] == 'lsq': 132 | self.y[ons[i]: offs[i], i, 1:] += self.add_y_loc(locs[i])*strengths[i] 133 | else: 134 | y_tmp = self.add_y_loc(locs[i]) 135 | y_tmp /= np.sum(y_tmp) 136 | self.y[ons[i]: offs[i], i, 1:] += .8*y_tmp # shrink target to prevent saturation of some act fxns 20190522 137 | self.y_loc[ons[i]: offs[i], i] = locs[i] 138 | else: 139 | raise ValueError('Unknown loc_type') 140 | 141 | def add_x_noise(self): 142 | """Add input noise.""" 143 | self.x += self.config['rng'].randn(*self.x.shape)*self._sigma_x 144 | 145 | def add_c_mask(self, pre_offs, post_ons): 146 | """Add a cost mask. 147 | 148 | Usually there are two periods, pre and post response 149 | Scale the mask weight for the post period so in total it's as important 150 | as the pre period 151 | """ 152 | 153 | pre_on = int(100/self.dt) # never check the first 100ms 154 | pre_offs = self.expand(pre_offs) 155 | post_ons = self.expand(post_ons) 156 | 157 | if self.config['loss_type'] == 'lsq': 158 | c_mask = np.zeros((self.tdim, self.batch_size, self.n_output), dtype=self.float_type) 159 | for i in range(self.batch_size): 160 | # Post response periods usually have the same length across tasks 161 | c_mask[post_ons[i]:, i, :] = 5. 162 | # Pre-response periods usually have different lengths across tasks 163 | # To keep cost comparable across tasks 164 | # Scale the cost mask of the pre-response period by a factor 165 | c_mask[pre_on:pre_offs[i], i, :] = 1. 166 | 167 | # self.c_mask[:, :, 0] *= self.n_eachring # Fixation is important 168 | c_mask[:, :, 0] *= 2. 
# Fixation is important 169 | 170 | self.c_mask = c_mask.reshape((self.tdim*self.batch_size, self.n_output)) 171 | else: 172 | c_mask = np.zeros((self.tdim, self.batch_size), dtype=self.float_type) 173 | for i in range(self.batch_size): 174 | # Post response periods usually have the same length across tasks 175 | # Having it larger than 1 encourages the network to achieve higher performance 176 | c_mask[post_ons[i]:, i] = 5. 177 | # Pre-response periods usually have different lengths across tasks 178 | # To keep cost comparable across tasks 179 | # Scale the cost mask of the pre-response period by a factor 180 | c_mask[pre_on:pre_offs[i], i] = 1. 181 | 182 | self.c_mask = c_mask.reshape((self.tdim*self.batch_size,)) 183 | self.c_mask /= self.c_mask.mean() 184 | 185 | def add_rule(self, rule, on=None, off=None, strength=1.): 186 | """Add rule input.""" 187 | if isinstance(rule, int): 188 | self.x[on:off, :, self.config['rule_start']+rule] = strength 189 | else: 190 | ind_rule = get_rule_index(rule, self.config) 191 | self.x[on:off, :, ind_rule] = strength 192 | 193 | def add_x_loc(self, x_loc): 194 | """Input activity given location.""" 195 | return np.array((np.sin(x_loc), np.cos(x_loc))) 196 | 197 | def add_y_loc(self, y_loc): 198 | """Target response given location.""" 199 | dist = get_dist(y_loc-self.pref) # periodic boundary 200 | if self.config['loss_type'] == 'lsq': 201 | y = np.array((np.sin(y_loc), np.cos(y_loc))) 202 | # y = 1. / (1. + np.exp(-y)) # Remove sigmoid 203 | else: 204 | # One-hot output 205 | y = np.zeros_like(dist) 206 | ind = np.argmin(dist) 207 | y[ind] = 1. 208 | return y 209 | 210 | 211 | def test_init(config, mode, **kwargs): 212 | ''' 213 | Test initialization of model. mode is not actually used 214 | Fixation is on then off. 215 | ''' 216 | dt = config['dt'] 217 | tdim = int(10000/dt) 218 | fix_offs = [int(800/dt)] 219 | batch_size = 1 220 | 221 | trial = Trial(config, tdim, batch_size) 222 | trial.add('fix_in', offs=fix_offs) 223 | 224 | return trial 225 | 226 | 227 | def delaygo_(config, mode, anti_response, **kwargs): 228 | ''' 229 | Fixate whenever fixation point is shown, 230 | saccade to the location of the previously shown stimulus 231 | whenever the fixation point is off 232 | Generate one batch of trials 233 | 234 | The fixation is shown between (0, fix_off) 235 | The stimulus is shown between (stim_on, stim_off) 236 | 237 | The output should be fixation location for (0, fix_off) 238 | and the stimulus location for (fix_off, T) 239 | 240 | :param mode: the mode of generating. Options: 'random', 'explicit'... 
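    A minimal, hypothetical usage sketch for mode=='random' (here `hp` is assumed
    to be a valid hyperparameter/config dict of the kind loaded from a trained
    model, e.g. via `Model(model_dir).hp`); it only illustrates the calling
    convention and the epoch bookkeeping set up at the end of this function:

        trial = delaygo(hp, 'random', batch_size=64, delay_fac=1)
        x, y = trial.x, trial.y                     # (n_time, batch_size, n_units)
        stim_on, stim_off = trial.epochs['stim1']   # stimulus period (time-step indices)
        fix_off = trial.epochs['go1'][0]            # go epoch starts when fixation turns off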
241 | Optional parameters: 242 | :param batch_size: Batch size (required for mode=='random') 243 | :param tdim: dimension of time (required for mode=='sample') 244 | :param param: a dictionary of parameters (required for mode=='explicit') 245 | :return: 2 Tensor3 data array (Time, Batchsize, Units) 246 | ''' 247 | dt = config['dt'] 248 | rng = config['rng'] 249 | if mode == 'random': # Randomly generate parameters 250 | batch_size = kwargs['batch_size'] 251 | delay_fac = kwargs['delay_fac'] 252 | 253 | # A list of locations of stimuluss and on/off time 254 | stim_locs = rng.rand(batch_size)*2*np.pi 255 | # stim_ons = int(500/dt) 256 | stim_ons = int(rng.uniform(300,700)/dt) # int(rng.choice([300, 500, 700])/dt) #dec 19th 2018 257 | # stim_offs = stim_ons + int(200/dt) 258 | stim_offs = stim_ons + int(rng.uniform(200,600)/dt) #int(rng.choice([200, 400, 600])/dt) # dec 14 2018 259 | fix_offs = stim_offs + int(delay_fac*rng.uniform(200,1600)/dt) #int(rng.choice([200, 400, 800, 1600])/dt) # dec 14 2018 260 | # fix_offs = stim_offs + int(rng.choice([1600])/dt) 261 | tdim = fix_offs + int(rng.uniform(300,700)/dt) # 20190510 262 | stim_mod = rng.choice([1,2]) 263 | 264 | elif mode == 'test': 265 | tdim = int(2500/dt) 266 | n_stim_loc, n_stim_mod = batch_shape = 100, 2 267 | 268 | batch_size = np.prod(batch_shape) 269 | ind_stim_loc, ind_stim_mod = np.unravel_index(range(batch_size),batch_shape) 270 | 271 | fix_offs = int(2000/dt) 272 | stim_locs = 2*np.pi*ind_stim_loc/n_stim_loc 273 | stim_ons = int(500/dt) 274 | stim_mod = ind_stim_mod + 1 275 | stim_offs = int(1000/dt) 276 | 277 | elif mode == 'psychometric': 278 | p = kwargs['params'] 279 | stim_locs = p['stim_locs'] 280 | # Time of stimuluss on/off 281 | stim_ons = int(p['stim_ons']/dt) 282 | stim_offs = int(p['stim_offs']/dt) 283 | delay_time = int(p['delay_time']/dt) 284 | fix_offs = stim_offs + delay_time 285 | tdim = int(400/dt) + fix_offs 286 | stim_mod = 1 287 | 288 | batch_size = len(stim_locs) 289 | 290 | else: 291 | raise ValueError('Unknown mode: ' + str(mode)) 292 | 293 | check_ons= fix_offs + int(100/dt) 294 | 295 | # Response locations 296 | stim_locs = np.array(stim_locs) 297 | if not anti_response: 298 | response_locs = stim_locs 299 | else: 300 | response_locs = (stim_locs+np.pi)%(2*np.pi) 301 | 302 | trial = Trial(config, tdim, batch_size) 303 | trial.add('fix_in', offs=fix_offs) 304 | trial.add('stim', stim_locs, ons=stim_ons, offs=stim_offs, mods=stim_mod) 305 | trial.add('fix_out', offs=fix_offs) 306 | trial.add('out', response_locs, ons=fix_offs) 307 | trial.add_c_mask(pre_offs=fix_offs, post_ons=check_ons) 308 | 309 | trial.epochs = {'fix1' : (None, stim_ons), 310 | 'stim1' : (stim_ons, stim_offs), 311 | 'delay1' : (stim_offs, fix_offs), 312 | 'go1' : (fix_offs, None)} 313 | 314 | return trial 315 | 316 | 317 | def delaygo(config, mode, **kwargs): 318 | return delaygo_(config, mode, False, **kwargs) 319 | 320 | 321 | def contextdm_genstim(batch_size, rng, stim_coh_range=None): 322 | stim_mean = rng.uniform(0.8, 1.2, (batch_size,)) 323 | if stim_coh_range is None: 324 | stim_coh_range = np.random.uniform(0, 0.8, (100,)) #110220 change lower bound to zero 325 | stim_coh = rng.choice(stim_coh_range, (batch_size,)) 326 | stim_sign = rng.choice([+1, -1], (batch_size,)) 327 | stim1_strengths = stim_mean + stim_coh*stim_sign 328 | stim2_strengths = stim_mean - stim_coh*stim_sign 329 | return stim1_strengths, stim2_strengths 330 | 331 | 332 | def _contextdm(config, mode, attend_mod, **kwargs): 333 | ''' 334 | Fixate whenever 
fixation point is shown. 335 | Two stimuluss are shown in each ring, 336 | Saccade to the one with higher intensity for the attended ring 337 | Generate one batch of trials 338 | 339 | The fixation is shown between (0, fix_off) 340 | The two stimuluss is shown between (0,T) 341 | 342 | The output should be fixation location for (0, fix_off) 343 | Otherwise the location of the stronger stimulus 344 | 345 | In this task, if the model's strategy is to ignore context, and integrate both, 346 | then the maximum performance is 75%. So we need to make the highest correct performance 347 | much higher than that. 348 | 349 | :param mode: the mode of generating. Options: 'random', 'explicit'... 350 | Optional parameters: 351 | :param batch_size: Batch size (required for mode=='random') 352 | :param tdim: dimension of time (required for mode=='sample') 353 | :param param: a dictionary of parameters (required for mode=='explicit') 354 | :return: 2 Tensor3 data array (Time, Batchsize, Units) 355 | ''' 356 | dt = config['dt'] 357 | rng = config['rng'] 358 | if mode == 'random': # Randomly generate parameters 359 | batch_size = kwargs['batch_size'] 360 | 361 | # A list of locations of stimuluss, same locations for both modalities 362 | stim_dist = rng.uniform(0.5*np.pi, 1.5*np.pi,(batch_size,))*rng.choice([-1,1],(batch_size,)) 363 | stim1_locs = rng.uniform(0, 2*np.pi, (batch_size,)) 364 | stim2_locs = (stim1_locs+stim_dist)%(2*np.pi) 365 | 366 | stim_coh_range = np.random.uniform(0.005, 0.8, (100,)) 367 | 368 | if ('easy_task' in config) and config['easy_task']: 369 | # stim_coh_range = np.array([0.1, 0.2, 0.4, 0.8]) 370 | stim_coh_range *= 10 371 | 372 | if (attend_mod == 1) or (attend_mod == 2): 373 | stim1_mod1_strengths, stim2_mod1_strengths = contextdm_genstim(batch_size, rng, stim_coh_range) 374 | stim1_mod2_strengths, stim2_mod2_strengths = contextdm_genstim(batch_size, rng, stim_coh_range) 375 | if attend_mod == 1: 376 | stim1_strengths, stim2_strengths = stim1_mod1_strengths, stim2_mod1_strengths 377 | else: 378 | stim1_strengths, stim2_strengths = stim1_mod2_strengths, stim2_mod2_strengths 379 | else: 380 | stim1_strengths, stim2_strengths = contextdm_genstim(batch_size, rng, stim_coh_range) 381 | 382 | stim1_mod12_diff = stim1_strengths * \ 383 | np.random.uniform(0.005, 0.8, (batch_size,)) * \ 384 | np.random.choice([+1, -1], (batch_size,)) #Min diff changed from .2 to .005 20190805 385 | stim1_mod1_strengths = stim1_strengths + stim1_mod12_diff/2 386 | stim1_mod2_strengths = stim1_strengths - stim1_mod12_diff/2 387 | 388 | stim2_mod12_diff = stim2_strengths * \ 389 | np.random.uniform(0.005, 0.8, (batch_size,)) * \ 390 | np.random.choice([+1, -1], (batch_size,)) 391 | stim2_mod1_strengths = stim2_strengths + stim2_mod12_diff/2 392 | stim2_mod2_strengths = stim2_strengths - stim2_mod12_diff/2 393 | 394 | # Time of stimuluss on/off 395 | stim_on = int(rng.uniform(100,400)/dt) 396 | stim_ons = (np.ones(batch_size)*stim_on).astype(int) 397 | stim_dur = int(rng.uniform(400,1600)/dt) #changed from discrete 20190805 398 | # stim_dur = rng.choice((np.array([200, 400, 800, 1600])/dt).astype(int)) # Current setting 399 | # stim_dur = int(rng.uniform(500, 1000)/dt) # Current setting 400 | # stim_dur = int(800/dt) 401 | stim_offs = stim_ons+stim_dur 402 | 403 | # delay_dur = rng.choice((np.array([200, 400, 800])/dt).astype(int)) # Current setting 404 | delay_dur = 0 405 | fix_offs = stim_offs + delay_dur 406 | 407 | # each batch consists of sequences of equal length 408 | tdim = 
stim_on+stim_dur+delay_dur + int(rng.uniform(300,700)/dt) # 20190510 409 | 410 | elif mode == 'test': 411 | tdim = int(2000/dt) 412 | n_stim_loc, n_stim_mod1_strength, n_stim_mod2_strength = batch_shape = 100, 4, 4 413 | batch_size = np.prod(batch_shape) 414 | ind_stim_loc, ind_stim_mod1_strength, ind_stim_mod2_strength = np.unravel_index(range(batch_size),batch_shape) 415 | fix_offs = int(1500/dt) 416 | 417 | stim1_locs = 2*np.pi*ind_stim_loc/n_stim_loc 418 | stim2_locs = (stim1_locs+np.pi)%(2*np.pi) 419 | stim1_mod1_strengths = 0.4*ind_stim_mod1_strength/n_stim_mod1_strength+0.8 420 | stim2_mod1_strengths = 2 - stim1_mod1_strengths 421 | stim1_mod2_strengths = 0.4*ind_stim_mod2_strength/n_stim_mod2_strength+0.8 422 | stim2_mod2_strengths = 2 - stim1_mod2_strengths 423 | stim_ons = int(500/dt) 424 | stim_offs = int(1500/dt) 425 | 426 | elif mode == 'psychometric': 427 | p = kwargs['params'] 428 | stim1_locs = p['stim1_locs'] 429 | stim2_locs = p['stim2_locs'] 430 | stim1_mod1_strengths = p['stim1_mod1_strengths'] 431 | stim2_mod1_strengths = p['stim2_mod1_strengths'] 432 | stim1_mod2_strengths = p['stim1_mod2_strengths'] 433 | stim2_mod2_strengths = p['stim2_mod2_strengths'] 434 | stim_time = int(p['stim_time']/dt) 435 | batch_size = len(stim1_locs) 436 | 437 | # Time of stimuluss on/off 438 | stim_ons = int(500/dt) 439 | stim_offs = stim_ons + stim_time 440 | fix_offs = stim_offs 441 | tdim = int(500/dt) + fix_offs 442 | 443 | else: 444 | raise ValueError('Unknown mode: ' + str(mode)) 445 | 446 | # time to check the saccade location 447 | check_ons = fix_offs + int(100/dt) 448 | 449 | if attend_mod == 1: 450 | stim1_strengths, stim2_strengths = stim1_mod1_strengths, stim2_mod1_strengths 451 | elif attend_mod == 2: 452 | stim1_strengths, stim2_strengths = stim1_mod2_strengths, stim2_mod2_strengths 453 | elif attend_mod == 'both': 454 | stim1_strengths = stim1_mod1_strengths + stim1_mod2_strengths 455 | stim2_strengths = stim2_mod1_strengths + stim2_mod2_strengths 456 | 457 | trial = Trial(config, tdim, batch_size) 458 | trial.add('fix_in', offs=fix_offs) 459 | trial.add('stim', stim1_locs, ons=stim_ons, offs=stim_offs, strengths=stim1_mod1_strengths, mods=1) 460 | trial.add('stim', stim2_locs, ons=stim_ons, offs=stim_offs, strengths=stim2_mod1_strengths, mods=1) 461 | trial.add('stim', stim1_locs, ons=stim_ons, offs=stim_offs, strengths=stim1_mod2_strengths, mods=2) 462 | trial.add('stim', stim2_locs, ons=stim_ons, offs=stim_offs, strengths=stim2_mod2_strengths, mods=2) 463 | trial.add('fix_out', offs=fix_offs) 464 | stim_locs = [stim1_locs[i] if (stim1_strengths[i]>stim2_strengths[i]) 465 | else stim2_locs[i] for i in range(batch_size)] 466 | trial.add('out', stim_locs, ons=fix_offs) 467 | 468 | trial.add_c_mask(pre_offs=fix_offs, post_ons=check_ons) 469 | 470 | trial.epochs = {'fix1' : (None, stim_ons), 471 | 'stim1' : (stim_ons, stim_offs), 472 | # 'delay1' : (stim_offs, fix_offs), 473 | 'go1' : (fix_offs, None)} 474 | 475 | return trial 476 | 477 | 478 | def contextdm1(config, mode, **kwargs): 479 | return _contextdm(config, mode, 1, **kwargs) 480 | 481 | 482 | def contextdm2(config, mode, **kwargs): 483 | return _contextdm(config, mode, 2, **kwargs) 484 | 485 | 486 | def multidm(config, mode, **kwargs): 487 | return _contextdm(config, mode, 'both', **kwargs) 488 | 489 | 490 | def reactgo_(config, mode, anti_response, **kwargs): 491 | ''' 492 | Fixate when fixation point is shown, 493 | A stimulus will be shown, and the output should saccade to the stimulus location 494 | Generate one 
batch of trials 495 | 496 | The fixation is shown between (0, T) 497 | The stimulus is shown between (fix_off,T) 498 | 499 | The output should be fixation location for (0, fix_off) 500 | Otherwise should be the stimulus location 501 | 502 | :param mode: the mode of generating. Options: 'random', 'explicit'... 503 | Optional parameters: 504 | :param batch_size: Batch size (required for mode=='random') 505 | :param tdim: dimension of time (required for mode=='sample') 506 | :param param: a dictionary of parameters (required for mode=='explicit') 507 | :return: 2 Tensor3 data array (Time, Batchsize, Units) 508 | ''' 509 | dt = config['dt'] 510 | rng = config['rng'] 511 | if mode == 'random': # Randomly generate parameters 512 | batch_size = kwargs['batch_size'] 513 | # each batch consists of sequences of equal length 514 | # A list of locations of fixation points and fixation off time 515 | stim_ons = int(rng.uniform(500,2500)/dt) 516 | tdim = stim_ons + int(rng.uniform(300,700)/dt) # 20190510 517 | 518 | # A list of locations of stimuluss (they are always on) 519 | stim_locs = rng.uniform(0, 2*np.pi, (batch_size,)) 520 | 521 | stim_mod = rng.choice([1,2]) 522 | 523 | elif mode == 'test': 524 | tdim = int(2500/dt) 525 | n_stim_loc, n_stim_mod = batch_shape = 100, 2 526 | batch_size = np.prod(batch_shape) 527 | ind_stim_loc, ind_stim_mod = np.unravel_index(range(batch_size),batch_shape) 528 | 529 | stim_ons = int(2000/dt) 530 | stim_locs = 2*np.pi*ind_stim_loc/n_stim_loc 531 | stim_mod = ind_stim_mod + 1 532 | 533 | elif mode == 'psychometric': 534 | p = kwargs['params'] 535 | stim_locs = p['stim_locs'] 536 | batch_size = len(stim_locs) 537 | 538 | # Time of stimuluss on/off 539 | stim_ons = int(1000/dt) 540 | tdim = int(400/dt) + stim_ons 541 | stim_mod = 1 542 | 543 | else: 544 | raise ValueError('Unknown mode: ' + str(mode)) 545 | 546 | # time to check the saccade location 547 | check_ons = stim_ons + int(100/dt) 548 | 549 | # Response locations 550 | stim_locs = np.array(stim_locs) 551 | if not anti_response: 552 | response_locs = stim_locs 553 | else: 554 | response_locs = (stim_locs+np.pi)%(2*np.pi) 555 | 556 | trial = Trial(config, tdim, batch_size) 557 | trial.add('fix_in') 558 | trial.add('stim', stim_locs, ons=stim_ons, mods=stim_mod) 559 | trial.add('fix_out', offs=stim_ons) 560 | trial.add('out', response_locs, ons=stim_ons) 561 | trial.add_c_mask(pre_offs=stim_ons, post_ons=check_ons) 562 | 563 | trial.epochs = {'fix1' : (None, stim_ons), 564 | 'go1' : (stim_ons, None)} 565 | 566 | return trial 567 | 568 | 569 | def reactgo(config, mode, **kwargs): 570 | return reactgo_(config, mode, False, **kwargs) 571 | 572 | 573 | def reactanti(config, mode, **kwargs): 574 | return reactgo_(config, mode, True, **kwargs) 575 | 576 | 577 | def fdgo_(config, mode, anti_response, **kwargs): 578 | ''' 579 | Go with inhibitory control. Important difference with Go task is that 580 | the stimulus is presented from the beginning. 581 | 582 | Fixate whenever fixation point is shown, 583 | A stimulus will be shown from the beginning 584 | And output should saccade to the stimulus location 585 | Generate one batch of trials 586 | 587 | The fixation is shown between (0, fix_off) 588 | The stimulus is shown between (0,T) 589 | 590 | The output should be fixation location for (0, fix_off) 591 | Otherwise should be the stimulus location 592 | 593 | :param mode: the mode of generating. Options: 'random', 'explicit'... 
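    The only difference between the pro and anti variants (`fdgo`/`fdanti`, and
    likewise `reactgo`/`reactanti`, `delaygo`/`delayanti` above) is a 180-degree
    shift of the response location; a small numeric sketch of that mapping:

        import numpy as np
        stim_locs = np.array([0.0, np.pi/2])           # stimulus angles in radians
        pro_locs  = stim_locs                          # anti_response == False
        anti_locs = (stim_locs + np.pi) % (2*np.pi)    # anti_response == True
        # pro_locs  -> [0.000, 1.571]; anti_locs -> [3.142, 4.712]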
594 | Optional parameters: 595 | :param batch_size: Batch size (required for mode=='random') 596 | :param tdim: dimension of time (required for mode=='sample') 597 | :param param: a dictionary of parameters (required for mode=='explicit') 598 | :return: 2 Tensor3 data array (Time, Batchsize, Units) 599 | ''' 600 | dt = config['dt'] 601 | rng = config['rng'] 602 | if mode == 'random': # Randomly generate parameters 603 | batch_size = kwargs['batch_size'] 604 | delay_fac = kwargs['delay_fac'] 605 | # each batch consists of sequences of equal length 606 | # A list of locations of fixation points and fixation off time 607 | 608 | # A list of locations of stimulus (they are always on) 609 | stim_locs = rng.rand(batch_size)*2*np.pi 610 | stim_mod = rng.choice([1,2]) 611 | stim_ons = int(rng.uniform(300,700)/dt) 612 | 613 | fix_offs = stim_ons + int(delay_fac*rng.uniform(500,1500)/dt) 614 | tdim = fix_offs + int(rng.uniform(300,700)/dt) # 20190510 615 | 616 | elif mode == 'test': 617 | tdim = int(2500/dt) #int(2500/dt) #changed may 30 2020 618 | n_stim_loc, n_stim_mod = batch_shape = 100, 2 619 | batch_size = np.prod(batch_shape) 620 | ind_stim_loc, ind_stim_mod = np.unravel_index(range(batch_size),batch_shape) 621 | 622 | # stim_ons = int(500/dt) 623 | # fix_offs = int(1500/dt) 624 | # stim_locs = 2*np.pi*ind_stim_loc/n_stim_loc 625 | # stim_mod = ind_stim_mod + 1 626 | 627 | fix_offs = int(2000/dt) 628 | stim_locs = 2*np.pi*ind_stim_loc/n_stim_loc 629 | stim_ons = int(500/dt) 630 | stim_mod = ind_stim_mod + 1 631 | 632 | elif mode == 'psychometric': 633 | p = kwargs['params'] 634 | stim_locs = p['stim_locs'] 635 | stim_time = int(p['stim_time']/dt) 636 | batch_size = len(stim_locs) 637 | 638 | # Time of stimuluss on/off 639 | stim_ons = int(300/dt) 640 | fix_offs = stim_ons + stim_time 641 | tdim = int(400/dt) + fix_offs 642 | stim_mod = 1 643 | 644 | else: 645 | raise ValueError('Unknown mode: ' + str(mode)) 646 | 647 | # time to check the saccade location 648 | check_ons = fix_offs + int(100/dt) 649 | 650 | # Response locations 651 | stim_locs = np.array(stim_locs) 652 | if not anti_response: 653 | response_locs = stim_locs 654 | else: 655 | response_locs = (stim_locs+np.pi)%(2*np.pi) 656 | 657 | trial = Trial(config, tdim, batch_size) 658 | trial.add('fix_in', offs=fix_offs) 659 | trial.add('stim', stim_locs, ons=stim_ons, mods=stim_mod) 660 | trial.add('fix_out', offs=fix_offs) 661 | trial.add('out', response_locs, ons=fix_offs) 662 | trial.add_c_mask(pre_offs=fix_offs, post_ons=check_ons) 663 | 664 | trial.epochs = {'fix1' : (None, stim_ons), 665 | 'stim1' : (stim_ons, fix_offs), 666 | 'go1' : (fix_offs, None)} 667 | 668 | return trial 669 | 670 | 671 | def fdgo(config, mode, **kwargs): 672 | return fdgo_(config, mode, False, **kwargs) 673 | 674 | 675 | def fdanti(config, mode, **kwargs): 676 | return fdgo_(config, mode, True, **kwargs) 677 | 678 | 679 | def delayanti(config, mode, **kwargs): 680 | return delaygo_(config, mode, True, **kwargs) 681 | 682 | 683 | def _dm(config, mode, stim_mod, **kwargs): 684 | ''' 685 | Fixate whenever fixation point is shown. 686 | Two stimuluss are shown, saccade to the one with higher intensity 687 | Generate one batch of trials 688 | 689 | The fixation is shown between (0, fix_off) 690 | The two stimuluss is shown between (0,T) 691 | 692 | The output should be fixation location for (0, fix_off) 693 | Otherwise the location of the stronger stimulus 694 | 695 | :param mode: the mode of generating. Options: 'random', 'explicit'... 
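    In mode=='random' the two strengths are built from a shared mean plus a
    signed coherence (see the body below); a short sketch of that arithmetic and
    of the resulting correct choice, with an illustrative coherence value:

        import numpy as np
        rng = np.random.RandomState(0)
        stim_mean = rng.uniform(0.8, 1.2)                # shared mean strength
        stim_coh, stim_sign = 0.32, rng.choice([1, -1])  # coherence magnitude (illustrative) and sign
        s1 = stim_mean + stim_coh*stim_sign
        s2 = stim_mean - stim_coh*stim_sign
        correct = 'stim1' if s1 > s2 else 'stim2'        # saccade target is the stronger stimulus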
696 | Optional parameters: 697 | :param batch_size: Batch size (required for mode=='random') 698 | :param tdim: dimension of time (required for mode=='sample') 699 | :param param: a dictionary of parameters (required for mode=='explicit') 700 | :return: 2 Tensor3 data array (Time, Batchsize, Units) 701 | ''' 702 | dt = config['dt'] 703 | rng = config['rng'] 704 | if mode == 'random': # Randomly generate parameters 705 | batch_size = kwargs['batch_size'] 706 | 707 | # A list of locations of stimuluss (they are always on) 708 | stim_dist = rng.uniform(0.5*np.pi,1.5*np.pi,(batch_size,))*rng.choice([-1,1],(batch_size,)) 709 | stim1_locs = rng.uniform(0, 2*np.pi, (batch_size,)) 710 | stim2_locs = (stim1_locs+stim_dist)%(2*np.pi) 711 | 712 | # Target strengths 713 | stims_mean = rng.uniform(0.8,1.2,(batch_size,)) 714 | # stims_diff = rng.uniform(0.01,0.2,(batch_size,)) 715 | # stims_diff = rng.choice([0.02, 0.04, 0.08], (batch_size,)) # Encourage integration 716 | # stims_coh = rng.choice([0.16, 0.32, 0.64], (batch_size,)) 717 | 718 | stim_coh_range = np.random.uniform(0.005, 0.8, (100,)) #20190805 719 | 720 | if ('easy_task' in config) and config['easy_task']: 721 | # stim_coh_range = np.array([0.1, 0.2, 0.4, 0.8]) 722 | stim_coh_range *= 10 723 | 724 | stims_coh = rng.choice(stim_coh_range, (batch_size,)) 725 | stims_sign = rng.choice([1,-1], (batch_size,)) 726 | 727 | stim1_strengths = stims_mean + stims_coh*stims_sign 728 | stim2_strengths = stims_mean - stims_coh*stims_sign 729 | 730 | # Time of stimuluss on/off 731 | stim_on = int(rng.uniform(100,400)/dt) 732 | stim_ons = (np.ones(batch_size)*stim_on).astype(int) 733 | # stim_dur = int(rng.uniform(300,1500)/dt) 734 | stim_dur = int(rng.uniform(400,1600)/dt) #int(rng.choice([400, 800, 1600])/dt) dec 19th, 2018 735 | fix_offs = (stim_ons+stim_dur).astype(int) 736 | # each batch consists of sequences of equal length 737 | tdim = stim_on+stim_dur + int(rng.uniform(300,700)/dt) # 20190510 738 | 739 | elif mode == 'test': 740 | # Dense coverage of the stimulus space 741 | tdim = int(2500/dt) 742 | n_stim_loc, n_stim1_strength = batch_shape = 100, 4 743 | batch_size = np.prod(batch_shape) 744 | ind_stim_loc, ind_stim1_strength = np.unravel_index(range(batch_size),batch_shape) 745 | fix_offs = int(2000/dt) 746 | 747 | stim1_locs = 2*np.pi*ind_stim_loc/n_stim_loc 748 | stim2_locs = (stim1_locs+np.pi)%(2*np.pi) 749 | stim1_strengths = 0.4*ind_stim1_strength/n_stim1_strength+0.8 750 | stim2_strengths = 2 - stim1_strengths 751 | stim_ons = int(500/dt) 752 | 753 | elif mode == 'psychometric': 754 | p = kwargs['params'] 755 | stim1_locs = p['stim1_locs'] 756 | stim2_locs = p['stim2_locs'] 757 | stim1_strengths = p['stim1_strengths'] 758 | stim2_strengths = p['stim2_strengths'] 759 | stim_time = int(p['stim_time']/dt) 760 | batch_size = len(stim1_locs) 761 | 762 | # Time of stimuluss on/off 763 | stim_ons = int(300/dt) 764 | fix_offs = int(300/dt) + stim_time 765 | tdim = int(400/dt) + fix_offs 766 | 767 | else: 768 | raise ValueError('Unknown mode: ' + str(mode)) 769 | 770 | # time to check the saccade location 771 | check_ons = fix_offs + int(100/dt) 772 | 773 | 774 | trial = Trial(config, tdim, batch_size) 775 | trial.add('fix_in', offs=fix_offs) 776 | trial.add('stim', stim1_locs, ons=stim_ons, offs=fix_offs, strengths=stim1_strengths, mods=stim_mod) 777 | trial.add('stim', stim2_locs, ons=stim_ons, offs=fix_offs, strengths=stim2_strengths, mods=stim_mod) 778 | trial.add('fix_out', offs=fix_offs) 779 | stim_locs = [stim1_locs[i] if 
(stim1_strengths[i]>stim2_strengths[i]) 780 | else stim2_locs[i] for i in range(batch_size)] 781 | trial.add('out', stim_locs, ons=fix_offs) 782 | 783 | trial.add_c_mask(pre_offs=fix_offs, post_ons=check_ons) 784 | 785 | trial.epochs = {'fix1' : (None, stim_ons), 786 | 'stim1' : (stim_ons, fix_offs), 787 | 'go1' : (fix_offs, None)} 788 | 789 | return trial 790 | 791 | 792 | def dm1(config, mode, **kwargs): 793 | return _dm(config, mode, 1, **kwargs) 794 | 795 | 796 | def dm2(config, mode, **kwargs): 797 | return _dm(config, mode, 2, **kwargs) 798 | 799 | 800 | def _delaydm(config, mode, stim_mod, **kwargs): 801 | ''' 802 | Fixate whenever fixation point is shown. 803 | Two stimuluss are shown at different time, with different intensities 804 | 805 | The fixation is shown between (0, fix_off) 806 | The two stimuluss is shown between (0,T) 807 | 808 | The output should be fixation location for (0, fix_off) 809 | Otherwise the location of the stronger stimulus 810 | 811 | :param mode: the mode of generating. Options: 'random', 'explicit'... 812 | Optional parameters: 813 | :param batch_size: Batch size (required for mode=='random') 814 | :param tdim: dimension of time (required for mode=='sample') 815 | :param param: a dictionary of parameters (required for mode=='explicit') 816 | :return: 2 Tensor3 data array (Time, Batchsize, Units) 817 | ''' 818 | dt = config['dt'] 819 | rng = config['rng'] 820 | if mode == 'random': # Randomly generate parameters 821 | batch_size = kwargs['batch_size'] 822 | delay_fac = kwargs['delay_fac'] 823 | 824 | # A list of locations of stimuluss (they are always on) 825 | stim_dist = rng.uniform(0.5*np.pi, 1.5*np.pi,(batch_size,))*rng.choice([-1,1],(batch_size,)) 826 | stim1_locs = rng.uniform(0, 2*np.pi, (batch_size,)) 827 | stim2_locs = (stim1_locs+stim_dist)%(2*np.pi) 828 | 829 | stims_mean = rng.uniform(0.8,1.2,(batch_size,)) 830 | # stims_diff = rng.choice([0.32,0.64,1.28],(batch_size,)) 831 | 832 | stim_coh_range = np.random.uniform(0.005, 0.8, (100,)) #20190805 833 | 834 | if ('easy_task' in config) and config['easy_task']: 835 | # stim_coh_range = np.array([0.16,0.32,0.64]) 836 | stim_coh_range *= 2 837 | 838 | stims_coh = rng.choice(stim_coh_range,(batch_size,)) 839 | stims_sign = rng.choice([1,-1], (batch_size,)) 840 | 841 | stim1_strengths = stims_mean + stims_coh*stims_sign 842 | stim2_strengths = stims_mean - stims_coh*stims_sign 843 | 844 | # stim1_strengths = rng.uniform(0.25,1.75,(batch_size,)) 845 | # stim2_strengths = rng.uniform(0.25,1.75,(batch_size,)) 846 | 847 | # Time of stimuluss on/off 848 | stim1_ons = int(rng.uniform(200,600)/dt) #int(rng.choice([200, 400, 600])/dt) #dec 19th 2018 849 | stim1_offs = stim1_ons + int(rng.uniform(200,600)/dt) # int(rng.choice([200, 400, 600])/dt) #dec 19th 2018 850 | stim2_ons = stim1_offs + int(delay_fac*rng.uniform(200,1600)/dt) #int(rng.choice([200, 400, 800, 1600])/dt) # dec 17 2018 851 | stim2_offs = stim2_ons + int(rng.uniform(200,600)/dt) # int(rng.choice([200, 400, 600])/dt) #dec 19th 2018 852 | fix_offs = stim2_offs + int(rng.uniform(100,300)/dt) 853 | 854 | # stim2_ons = (np.ones(batch_size)*rng.choice([400,500,600,700,1400])/dt).astype(int) 855 | # stim2_ons = (np.ones(batch_size)*rng.choice([400,600,1000,1400,2000])/dt).astype(int) 856 | # stim2_ons = (np.ones(batch_size)*rng.uniform(2800,3200)/dt).astype(int) 857 | 858 | # each batch consists of sequences of equal length 859 | tdim = fix_offs + int(rng.uniform(300,700)/dt) # 20190510 860 | 861 | elif mode == 'test': 862 | tdim = int(3000/dt) 863 
| n_stim_loc, n_stim1_strength = batch_shape = 100, 4 864 | batch_size = np.prod(batch_shape) 865 | ind_stim_loc, ind_stim1_strength = np.unravel_index(range(batch_size),batch_shape) 866 | 867 | fix_offs = int(2700/dt) 868 | stim1_locs = 2*np.pi*ind_stim_loc/n_stim_loc 869 | stim2_locs = (stim1_locs+np.pi)%(2*np.pi) 870 | stim1_strengths = 1.0*ind_stim1_strength/n_stim1_strength+0.5 871 | stim2_strengths = 2 - stim1_strengths 872 | stim1_ons = int(500/dt) 873 | stim1_offs = int(1000/dt) 874 | stim2_ons = int(2000/dt) 875 | stim2_offs = int(2500/dt) 876 | 877 | elif mode == 'psychometric': 878 | p = kwargs['params'] 879 | stim1_locs = p['stim1_locs'] 880 | stim2_locs = p['stim2_locs'] 881 | stim1_strengths = p['stim1_strengths'] 882 | stim2_strengths = p['stim2_strengths'] 883 | stim1_ons = int(p['stim1_ons']/dt) 884 | stim1_offs = int(p['stim1_offs']/dt) 885 | stim2_ons = int(p['stim2_ons']/dt) 886 | stim2_offs = int(p['stim2_offs']/dt) 887 | batch_size = len(stim1_locs) 888 | 889 | fix_offs = int(200/dt) + stim2_offs 890 | tdim = int(300/dt) + fix_offs 891 | 892 | else: 893 | raise ValueError('Unknown mode: ' + str(mode)) 894 | 895 | # time to check the saccade location 896 | check_ons = fix_offs + int(100/dt) 897 | 898 | trial = Trial(config, tdim, batch_size) 899 | trial.add('fix_in', offs=fix_offs) 900 | trial.add('stim', stim1_locs, ons=stim1_ons, offs=stim1_offs, strengths=stim1_strengths, mods=stim_mod) 901 | trial.add('stim', stim2_locs, ons=stim2_ons, offs=stim2_offs, strengths=stim2_strengths, mods=stim_mod) 902 | trial.add('fix_out', offs=fix_offs) 903 | stim_locs = [stim1_locs[i] if (stim1_strengths[i]>stim2_strengths[i]) 904 | else stim2_locs[i] for i in range(batch_size)] 905 | trial.add('out', stim_locs, ons=fix_offs) 906 | 907 | trial.add_c_mask(pre_offs=fix_offs, post_ons=check_ons) 908 | 909 | trial.epochs = {'fix1' : (None, stim1_ons), 910 | 'stim1' : (stim1_ons, stim1_offs), 911 | 'delay1' : (stim1_offs, stim2_ons), 912 | 'stim2' : (stim2_ons, stim2_offs), 913 | 'delay2' : (stim2_offs, fix_offs), 914 | 'go1' : (fix_offs, None)} 915 | 916 | return trial 917 | 918 | 919 | def delaydm1(config, mode, **kwargs): 920 | return _delaydm(config, mode, 1, **kwargs) 921 | 922 | 923 | def delaydm2(config, mode, **kwargs): 924 | return _delaydm(config, mode, 2, **kwargs) 925 | 926 | 927 | def _contextdelaydm(config, mode, attend_mod, **kwargs): 928 | ''' 929 | Fixate whenever fixation point is shown. 930 | Two stimuluss are shown in each ring, 931 | Saccade to the one with higher intensity for the attended ring 932 | Generate one batch of trials 933 | 934 | The fixation is shown between (0, fix_off) 935 | The two stimuluss is shown between (0,T) 936 | 937 | The output should be fixation location for (0, fix_off) 938 | Otherwise the location of the stronger stimulus 939 | 940 | In this task, if the model's strategy is to ignore context, and integrate both, 941 | then the maximum performance is 75%. So we need to make the highest correct performance 942 | much higher than that. 943 | 944 | :param mode: the mode of generating. Options: 'random', 'explicit'... 
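    The 75% ceiling mentioned above can be checked with a quick Monte Carlo
    sketch (simplified: coherences are drawn directly from U(0.005, 0.8) for both
    modalities rather than through `contextdm_genstim`):

        import numpy as np
        rng = np.random.RandomState(0)
        n = 100000
        coh_att, coh_ign = rng.uniform(0.005, 0.8, n), rng.uniform(0.005, 0.8, n)
        sgn_att, sgn_ign = rng.choice([1, -1], n), rng.choice([1, -1], n)
        # the attended modality alone defines the correct answer (sign of sgn_att);
        # a context-ignoring strategy follows the sign of the summed evidence
        choice = np.sign(coh_att*sgn_att + coh_ign*sgn_ign)
        print((choice == sgn_att).mean())                # ~0.75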
945 | Optional parameters: 946 | :param batch_size: Batch size (required for mode=='random') 947 | :param tdim: dimension of time (required for mode=='sample') 948 | :param param: a dictionary of parameters (required for mode=='explicit') 949 | :return: 2 Tensor3 data array (Time, Batchsize, Units) 950 | ''' 951 | dt = config['dt'] 952 | rng = config['rng'] 953 | if mode == 'random': # Randomly generate parameters 954 | batch_size = kwargs['batch_size'] 955 | delay_fac = kwargs['delay_fac'] 956 | 957 | # A list of locations of stimuluss, same locations for both modalities 958 | stim_dist = rng.uniform(0.5*np.pi,1.5*np.pi,(batch_size,))*rng.choice([-1,1],(batch_size,)) 959 | stim1_locs = rng.uniform(0, 2*np.pi, (batch_size,)) 960 | stim2_locs = (stim1_locs+stim_dist)%(2*np.pi) 961 | 962 | # stim_coh_range = np.array([0.08,0.16,0.32]) 963 | stim_coh_range = np.random.uniform(0.005, 0.8, (100,)) #20190805 964 | 965 | if ('easy_task' in config) and config['easy_task']: 966 | # stim_coh_range = np.array([0.16, 0.32, 0.64]) 967 | stim_coh_range *= 2 968 | 969 | if (attend_mod == 1) or (attend_mod == 2): 970 | stim1_mod1_strengths, stim2_mod1_strengths = \ 971 | contextdm_genstim(batch_size, rng, stim_coh_range) 972 | stim1_mod2_strengths, stim2_mod2_strengths = \ 973 | contextdm_genstim(batch_size, rng, stim_coh_range) 974 | if attend_mod == 1: 975 | stim1_strengths, stim2_strengths = stim1_mod1_strengths, stim2_mod1_strengths 976 | else: 977 | stim1_strengths, stim2_strengths = stim1_mod2_strengths, stim2_mod2_strengths 978 | else: 979 | stim1_strengths, stim2_strengths = \ 980 | contextdm_genstim(batch_size, rng, stim_coh_range) 981 | 982 | stim1_mod12_diff = stim1_strengths * \ 983 | np.random.uniform(0.005, 0.8, (batch_size,)) * \ 984 | np.random.choice([+1, -1], (batch_size,)) 985 | stim1_mod1_strengths = stim1_strengths + stim1_mod12_diff/2 986 | stim1_mod2_strengths = stim1_strengths - stim1_mod12_diff/2 987 | 988 | stim2_mod12_diff = stim2_strengths * \ 989 | np.random.uniform(0.005, 0.8, (batch_size,)) * \ 990 | np.random.choice([+1, -1], (batch_size,)) 991 | stim2_mod1_strengths = stim2_strengths + stim2_mod12_diff/2 992 | stim2_mod2_strengths = stim2_strengths - stim2_mod12_diff/2 993 | 994 | # Time of stimuluss on/off 995 | stim1_ons = int(rng.uniform(200,600)/dt) 996 | stim1_offs = stim1_ons + int(rng.uniform(200,600)/dt) 997 | stim2_ons = stim1_offs + int(delay_fac*rng.uniform(200,1600)/dt) 998 | stim2_offs = stim2_ons + int(rng.uniform(200,600)/dt) 999 | fix_offs = stim2_offs + int(rng.uniform(100,300)/dt) 1000 | 1001 | # each batch consists of sequences of equal length 1002 | tdim = fix_offs + int(rng.uniform(300,700)/dt) # 20190510 1003 | 1004 | elif mode == 'test': 1005 | n_stim_loc, n_stim_mod1_strength, n_stim_mod2_strength = batch_shape = 100, 4, 4 1006 | batch_size = np.prod(batch_shape) 1007 | ind_stim_loc, ind_stim_mod1_strength, ind_stim_mod2_strength = np.unravel_index(range(batch_size),batch_shape) 1008 | 1009 | stim1_locs = 2*np.pi*ind_stim_loc/n_stim_loc 1010 | stim2_locs = (stim1_locs+np.pi)%(2*np.pi) 1011 | stim1_mod1_strengths = 0.4*ind_stim_mod1_strength/n_stim_mod1_strength+0.8 1012 | stim2_mod1_strengths = 2 - stim1_mod1_strengths 1013 | stim1_mod2_strengths = 0.4*ind_stim_mod2_strength/n_stim_mod2_strength+0.8 1014 | stim2_mod2_strengths = 2 - stim1_mod2_strengths 1015 | 1016 | stim1_ons = int(500/dt) 1017 | stim1_offs = int(1000/dt) 1018 | stim2_ons = int(2000/dt) 1019 | stim2_offs = int(2500/dt) 1020 | fix_offs = int(3000/dt) 1021 | tdim = int(3500/dt) 1022 
| 1023 | elif mode == 'psychometric': 1024 | p = kwargs['params'] 1025 | stim1_locs = p['stim1_locs'] 1026 | stim2_locs = p['stim2_locs'] 1027 | stim1_mod1_strengths = p['stim1_mod1_strengths'] 1028 | stim2_mod1_strengths = p['stim2_mod1_strengths'] 1029 | stim1_mod2_strengths = p['stim1_mod2_strengths'] 1030 | stim2_mod2_strengths = p['stim2_mod2_strengths'] 1031 | # stim1_ons = int(500/dt) 1032 | # stim1_offs = int(1000/dt) 1033 | # stim2_ons = int(p['stim_time']/dt) + stim1_offs 1034 | # stim2_offs = int(500/dt) + stim2_ons 1035 | stim1_ons = int(300/dt) 1036 | stim1_offs = int(600/dt) 1037 | stim2_ons = int(p['stim_time']/dt) + stim1_offs 1038 | stim2_offs = int(300/dt) + stim2_ons 1039 | batch_size = len(stim1_locs) 1040 | 1041 | # Time of stimuluss on/off 1042 | fix_offs = int(200/dt) + stim2_offs 1043 | tdim = int(300/dt) + fix_offs 1044 | 1045 | else: 1046 | raise ValueError('Unknown mode: ' + str(mode)) 1047 | 1048 | # time to check the saccade location 1049 | check_ons = fix_offs + int(100/dt) 1050 | 1051 | if attend_mod == 1: 1052 | stim1_strengths, stim2_strengths = stim1_mod1_strengths, stim2_mod1_strengths 1053 | elif attend_mod == 2: 1054 | stim1_strengths, stim2_strengths = stim1_mod2_strengths, stim2_mod2_strengths 1055 | elif attend_mod == 'both': 1056 | stim1_strengths = stim1_mod1_strengths + stim1_mod2_strengths 1057 | stim2_strengths = stim2_mod1_strengths + stim2_mod2_strengths 1058 | 1059 | trial = Trial(config, tdim, batch_size) 1060 | trial.add('fix_in', offs=fix_offs) 1061 | trial.add('stim', stim1_locs, ons=stim1_ons, offs=stim1_offs, strengths=stim1_mod1_strengths, mods=1) 1062 | trial.add('stim', stim2_locs, ons=stim2_ons, offs=stim2_offs, strengths=stim2_mod1_strengths, mods=1) 1063 | trial.add('stim', stim1_locs, ons=stim1_ons, offs=stim1_offs, strengths=stim1_mod2_strengths, mods=2) 1064 | trial.add('stim', stim2_locs, ons=stim2_ons, offs=stim2_offs, strengths=stim2_mod2_strengths, mods=2) 1065 | trial.add('fix_out', offs=fix_offs) 1066 | stim_locs = [stim1_locs[i] if (stim1_strengths[i]>stim2_strengths[i]) 1067 | else stim2_locs[i] for i in range(batch_size)] 1068 | trial.add('out', stim_locs, ons=fix_offs) 1069 | 1070 | trial.add_c_mask(pre_offs=fix_offs, post_ons=check_ons) 1071 | 1072 | trial.epochs = {'fix1' : (None, stim1_ons), 1073 | 'stim1' : (stim1_ons, stim1_offs), 1074 | 'delay1' : (stim1_offs, stim2_ons), 1075 | 'stim2' : (stim2_ons, stim2_offs), 1076 | 'delay2' : (stim2_offs, fix_offs), 1077 | 'go1' : (fix_offs, None)} 1078 | 1079 | return trial 1080 | 1081 | 1082 | def contextdelaydm1(config, mode, **kwargs): 1083 | return _contextdelaydm(config, mode, 1, **kwargs) 1084 | 1085 | 1086 | def contextdelaydm2(config, mode, **kwargs): 1087 | return _contextdelaydm(config, mode, 2, **kwargs) 1088 | 1089 | 1090 | def multidelaydm(config, mode, **kwargs): 1091 | return _contextdelaydm(config, mode, 'both', **kwargs) 1092 | 1093 | 1094 | def dms_(config, mode, matchnogo, **kwargs): 1095 | ''' 1096 | Delay-match-to-sample 1097 | 1098 | Two stimuli are shown, separated in time, either at the same location or not 1099 | Fixate before the second stimulus is shown 1100 | 1101 | If matchnogo is one, then: 1102 | If the two stimuli are the same, then keep fixation. 1103 | If the two stimuli are different, then saccade to the location of the stimulus 1104 | 1105 | If matchnogo is zero, then: 1106 | If the two stimuli are different, then keep fixation. 
1107 | If the two stimuli are the same, then saccade to the location of the stimulus 1108 | 1109 | The first stimulus is shown between (stim1_on, stim1_off) 1110 | The second stimulus is shown between (stim2_on, T) 1111 | 1112 | The output should be fixation location for (0, stim2_on) 1113 | If two stimuli the different location, then for (stim2_on, T) go to stim2_loc 1114 | Otherwise keep fixation 1115 | 1116 | :param mode: the mode of generating. Options: 'random', 'explicit'... 1117 | Optional parameters: 1118 | :param batch_size: Batch size (required for mode=='random') 1119 | :param tdim: dimension of time (required for mode=='sample') 1120 | :param param: a dictionary of parameters (required for mode=='explicit') 1121 | :return: 2 Tensor3 data array (Time, Batchsize, Units) 1122 | ''' 1123 | dt = config['dt'] 1124 | rng = config['rng'] 1125 | if mode == 'random': # Randomly generate parameters 1126 | batch_size = kwargs['batch_size'] 1127 | delay_fac = kwargs['delay_fac'] 1128 | 1129 | stim1_mod = rng.choice([1,2]) 1130 | stim2_mod = rng.choice([1,2]) 1131 | # A list of locations of stimuluss 1132 | # Since stim1 is always shown first, it's important that we completely randomize their relative positions 1133 | matchs = rng.choice([0,1],(batch_size,)) # match or not? 1134 | # stim_dist range between 1/18*pi and (2-1/18*pi), corresponding to 10 degree to 350 degree 1135 | stim_dist = rng.uniform(np.pi/9,np.pi*17./9.,(batch_size,))*rng.choice([-1,1],(batch_size,)) 1136 | stim1_locs = rng.uniform(0, 2*np.pi, (batch_size,)) 1137 | stim2_locs = (stim1_locs+stim_dist*(1-matchs))%(2*np.pi) 1138 | 1139 | # Time of stimuluss on/off 1140 | stim1_ons = int(rng.uniform(200,600)/dt) # int(rng.choice([200, 400, 600])/dt) #dec 19th 2018 1141 | stim1_offs = stim1_ons + int(rng.uniform(200,600)/dt) # int(rng.choice([200, 400, 600])/dt) #dec 19th 2018 1142 | stim2_ons = stim1_offs + int(delay_fac*rng.uniform(200,1600)/dt) #int(rng.choice([200, 400, 800, 1600])/dt) #dec 17 2018 1143 | tdim = stim2_ons + int(rng.uniform(300,700)/dt) # 20190510 1144 | 1145 | elif mode == 'test': 1146 | # Set this test so the model always respond 1147 | n_stim_loc, n_mod1, n_mod2 = batch_shape = 100, 2, 2 1148 | batch_size = np.prod(batch_shape) 1149 | ind_stim_loc, ind_mod1, ind_mod2 = np.unravel_index(range(batch_size),batch_shape) 1150 | 1151 | stim1_mod = ind_mod1 + 1 1152 | stim2_mod = ind_mod2 + 1 1153 | 1154 | stim1_locs = 2*np.pi*ind_stim_loc/n_stim_loc 1155 | matchs = (1 - matchnogo)*np.ones(batch_size) # make sure the response is Go 1156 | stim2_locs = (stim1_locs+np.pi*(1-matchs))%(2*np.pi) 1157 | 1158 | stim1_ons = int(500/dt) 1159 | stim1_offs = stim1_ons + int(500/dt) 1160 | stim2_ons = stim1_offs + int(1200/dt) 1161 | tdim = stim2_ons + int(500/dt) 1162 | 1163 | elif mode == 'psychometric': 1164 | p = kwargs['params'] 1165 | stim1_locs = p['stim1_locs'] 1166 | stim2_locs = p['stim2_locs'] 1167 | matchs = get_dist(stim1_locs-stim2_locs)
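# A compact restatement of the go/no-go rule described in the dms_ docstring
# above, as a hypothetical helper (not part of this file): the network should
# saccade to the second stimulus only when the match condition and the
# matchnogo flag disagree, and hold fixation otherwise.
def dms_should_respond(is_match, matchnogo):
    # matchnogo == 0: respond on match trials; matchnogo == 1: respond on non-match trials
    return bool(is_match) != bool(matchnogo)

# dms_should_respond(True, 0)  -> True   (match,     match-go   => saccade)
# dms_should_respond(True, 1)  -> False  (match,     match-nogo => keep fixation)
# dms_should_respond(False, 1) -> True   (non-match, match-nogo => saccade)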