├── .gitignore ├── README.md ├── __init__.py ├── analysis ├── __init__.py ├── antitask.py ├── clustering.py ├── contextdm_analysis.py ├── contlearn_schematic.py ├── data_analysis.py ├── dimensionality.py ├── performance.py ├── posttrain_analysis.py ├── standard_analysis.py ├── taskset.py ├── variance.py └── varyhp.py ├── datasets ├── contextdm_data_analysis.py ├── mante_dataset_preprocess.py └── siegel_dataset_preprocess.py ├── experiment.py ├── network.py ├── paper.py ├── submit_jobs.py ├── task.py ├── tools.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .idea/* 3 | data/* 4 | figure/* 5 | debug/* 6 | datasets/ 7 | previous_versions/* 8 | sbatch/* 9 | *test.py 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiTask Network 2 | 3 | ## Dependencies 4 | The code has been tested with TensorFlow 1.8.0 under Python 2.7 and Python 3.6, on macOS 10.13 and Ubuntu 16.04. 5 | 6 | Scikit-learn (http://scikit-learn.org/stable/) is required for many of the analyses. 7 | 8 | The seaborn package (https://seaborn.pydata.org/) is needed to correctly 9 | plot a few analysis results. 10 | 11 | ## Reproducing results from the paper 12 | All analysis results from the paper can be reproduced from paper.py. 13 | 14 | Open paper.py, set model_dir to the directory containing your 15 | model files, uncomment the analyses you want to run, and run the file. 16 | 17 | ## Pretrained models 18 | We provide 20 pretrained models and their auxiliary data files for 19 | analyses. 20 | https://drive.google.com/drive/folders/1L8v-OZgYHVcKh1UKtCJl5QVlz8mkaRxr?usp=sharing 21 | 22 | ## Get started with training 23 | Train a default network with: 24 | 25 | import train 26 | train.train(model_dir='debug', hp={'learning_rate': 0.001}, ruleset='mante') 27 | 28 | These lines will train a default network for the Mante task ruleset and store the 29 | results in your_working_directory/debug/. 30 | 31 | ## Get started with some simple analyses 32 | After training (you can interrupt at any time), you can visualize the neural 33 | activity using 34 | 35 | from analysis import standard_analysis 36 | model_dir = 'debug' 37 | rule = 'contextdm1' 38 | standard_analysis.easy_activity_plot(model_dir, rule) 39 | 40 | This will plot some neural activity. To load hyperparameters, restore the 41 | model, and run it yourself, see the source code or the minimal sketch below. 
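The sketch below is adapted from standard_analysis.easy_activity_plot; model_dir and rule are placeholders for your own model directory and task, and details may differ slightly across versions:

    import tensorflow as tf
    from task import generate_trials
    from network import Model
    import tools

    model_dir = 'debug'   # directory containing the saved model
    rule = 'contextdm1'   # task to analyze

    model = Model(model_dir)   # builds the graph and loads hyperparameters
    hp = model.hp              # hyperparameter dictionary
    with tf.Session() as sess:
        model.restore()        # restore trained weights into the current session
        trial = generate_trials(rule, hp, mode='test')
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        # h and y_hat have shape (n_time, n_condition, n_neuron)
        h, y_hat = sess.run([model.h, model.y_hat], feed_dict=feed_dict)

From here, h and y_hat can be sliced by trial and time for custom plots, as done throughout the analysis modules.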
42 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyyang/multitask/817e23d8e418197b875b8448195f963ab51dfe9a/__init__.py -------------------------------------------------------------------------------- /analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gyyang/multitask/817e23d8e418197b875b8448195f963ab51dfe9a/analysis/__init__.py -------------------------------------------------------------------------------- /analysis/antitask.py: -------------------------------------------------------------------------------- 1 | """ 2 | Analysis of anti units 3 | """ 4 | 5 | from __future__ import division 6 | 7 | import numpy as np 8 | import pickle 9 | import matplotlib.pyplot as plt 10 | 11 | import tensorflow as tf 12 | from network import Model, get_perf 13 | from task import get_dist, generate_trials 14 | import tools 15 | 16 | save = True 17 | 18 | 19 | class Analysis(object): 20 | """Analyze the Anti tasks.""" 21 | def __init__(self, model_dir): 22 | self.model_dir = model_dir 23 | 24 | # Run model 25 | model = Model(model_dir) 26 | self.hp = model.hp 27 | with tf.Session() as sess: 28 | model.restore() 29 | for name in ['w_in', 'w_rec', 'w_out', 'b_rec', 'b_out']: 30 | setattr(self, name, sess.run(getattr(model, name))) 31 | if 'w_' in name: 32 | setattr(self, name, getattr(self, name).T) 33 | 34 | data_type = 'rule' 35 | with open(model_dir + '/variance_'+data_type+'.pkl', 'rb') as f: 36 | res = pickle.load(f) 37 | h_var_all = res['h_var_all'] 38 | self.rules = res['keys'] 39 | 40 | # First only get active units. 
Total variance across tasks larger than 1e-3 41 | ind_active = np.where(h_var_all.sum(axis=1) > 1e-3)[0] 42 | # ind_active = np.where(h_var_all.sum(axis=1) > 0.)[0] 43 | h_var_all = h_var_all[ind_active, :] 44 | 45 | # Normalize by the total variance across tasks 46 | h_normvar_all = (h_var_all.T/np.sum(h_var_all, axis=1)).T 47 | 48 | ########################## Get Anti Units #################################### 49 | # Directly search 50 | # This will be a stricter subset of the anti modules found in clustering results 51 | self.rules_anti = np.array(['fdanti', 'reactanti', 'delayanti']) 52 | self.rules_nonanti = np.array([r for r in self.rules if r not in self.rules_anti]) 53 | 54 | # Rule index used only for the rules 55 | self.ind_rules_anti = [self.rules.index(r) for r in self.rules_anti] 56 | self.ind_rules_nonanti = [self.rules.index(r) for r in self.rules_nonanti] 57 | 58 | self.h_normvar_all_anti = h_normvar_all[:, self.ind_rules_anti].sum(axis=1) 59 | self.h_normvar_all_nonanti = h_normvar_all[:, self.ind_rules_nonanti].sum(axis=1) 60 | 61 | # plt.figure() 62 | # _ = plt.hist(h_normvar_all_anti, bins=50) 63 | # plt.xlabel('Proportion of variance in anti tasks') 64 | # plt.show() 65 | 66 | ind_anti = np.where(self.h_normvar_all_anti>0.5)[0] 67 | ind_nonanti = np.where(self.h_normvar_all_anti<=0.5)[0] 68 | self.ind_anti_orig = ind_active[ind_anti] # Indices of anti units in the original matrix 69 | self.ind_nonanti_orig = ind_active[ind_nonanti] 70 | 71 | # Use clustering results (tend to be loose) 72 | # label_anti = np.where(label_prefs==FDANTI)[0][0] 73 | # ind_anti = np.where(labels==label_anti)[0] 74 | # ind_anti_orig = ind_orig[ind_anti] # Indices of anti units in the original matrix 75 | 76 | def plot_example_unit(self): 77 | """Plot activity of an example unit.""" 78 | from standard_analysis import pretty_singleneuron_plot 79 | pretty_singleneuron_plot( 80 | self.model_dir, ['fdanti', 'fdgo'], self.ind_anti_orig[2], 81 | save=save, ylabel_firstonly = True) 82 | 83 | def plot_inout_connections(self): 84 | """Plot the input and output connections.""" 85 | n_eachring = self.hp['n_eachring'] 86 | w_in, w_out = self.w_in, self.w_out 87 | 88 | w_in_ = (w_in[:, 1:n_eachring+1]+w_in[:, 1+n_eachring:2*n_eachring+1])/2. 
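        # Added note: the input columns follow the layout used throughout this
        # repo (see the connectivity plot in clustering.py): column 0 is the
        # fixation input, columns 1..n_eachring are stimulus mod 1, the next
        # n_eachring columns are stimulus mod 2, and the remaining columns are
        # the rule inputs. w_in_ above therefore averages the two stimulus
        # rings per unit; w_out_ below drops the fixation output unit and
        # transposes to shape (n_unit, n_eachring).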
89 | w_out_ = w_out[1:, :].T 90 | 91 | for ind_group, unit_type in zip([self.ind_anti_orig, self.ind_nonanti_orig], 92 | ['Anti units', 'Non-Anti Units']): 93 | # ind_group = ind_anti_orig 94 | n_group = len(ind_group) 95 | w_in_group = np.zeros((n_group, n_eachring)) 96 | w_out_group = np.zeros((n_group, n_eachring)) 97 | 98 | ind_pref_ins = list() 99 | ind_pref_outs = list() 100 | 101 | for i, ind in enumerate(ind_group): 102 | tmp_in = w_in_[ind, :] 103 | tmp_out = w_out_[ind, :] 104 | 105 | # Get preferred input and output directions 106 | ind_pref_in = np.argmax(tmp_in) 107 | ind_pref_out = np.argmax(tmp_out) 108 | 109 | ind_pref_ins.append(ind_pref_in) 110 | ind_pref_outs.append(ind_pref_out) 111 | 112 | # Sort by preferred input direction 113 | w_in_group[i, :] = np.roll(tmp_in, int(n_eachring/2)-ind_pref_in) 114 | w_out_group[i, :] = np.roll(tmp_out, int(n_eachring/2)-ind_pref_in) 115 | 116 | w_in_ave = w_in_group.mean(axis=0) 117 | w_out_ave = w_out_group.mean(axis=0) 118 | 119 | fs = 6 120 | fig = plt.figure(figsize=(1.5, 1.0)) 121 | ax = fig.add_axes([.3, .3, .6, .6]) 122 | ax.plot(w_in_ave, color='black', label='In') 123 | ax.plot(w_out_ave, color='red', label='Out') 124 | ax.set_xticks([int(n_eachring/2)]) 125 | ax.set_xticklabels(['Preferred input dir.']) 126 | # ax.set_xlabel(xlabel, fontsize=fs, labelpad=3) 127 | ax.set_ylabel('Conn. weight', fontsize=fs) 128 | lg = ax.legend(fontsize=fs, bbox_to_anchor=(1.1,1.1), 129 | labelspacing=0.2, loc=1, frameon=False) 130 | # plt.setp(lg.get_title(),fontsize=fs) 131 | ax.tick_params(axis='both', which='major', labelsize=fs) 132 | ax.set_title(unit_type, fontsize=fs, y=0.9) 133 | plt.locator_params(axis='y',nbins=3) 134 | ax.spines["right"].set_visible(False) 135 | ax.spines["top"].set_visible(False) 136 | ax.xaxis.set_ticks_position('bottom') 137 | ax.yaxis.set_ticks_position('left') 138 | if save: 139 | plt.savefig('figure/conn_'+unit_type+'.pdf', transparent=True) 140 | 141 | def plot_rule_connections(self): 142 | """Plot connectivity from the rule input units""" 143 | 144 | # Rule index for the connectivity 145 | from task import get_rule_index 146 | indconn_rules_anti = [get_rule_index(r, self.hp) for r in self.rules_anti] 147 | indconn_rules_nonanti = [get_rule_index(r, self.hp) for r in self.rules_nonanti] 148 | 149 | for ind, unit_type in zip([self.ind_anti_orig, self.ind_nonanti_orig], 150 | ['Anti units', 'Non-Anti units']): 151 | b1 = self.w_in[:, indconn_rules_anti][ind, :].flatten() 152 | b2 = self.w_in[:, indconn_rules_nonanti][ind, :].flatten() 153 | 154 | fs = 6 155 | fig = plt.figure(figsize=(1.5,1.2)) 156 | ax = fig.add_axes([0.3,0.3,0.6,0.4]) 157 | ax.boxplot([b1, b2], showfliers=False) 158 | ax.set_xticklabels(['Anti', 'Non-Anti']) 159 | ax.set_xlabel('Input from rule units', fontsize=fs, labelpad=3) 160 | ax.set_ylabel('Conn. 
weight', fontsize=fs) 161 | ax.set_title('To '+unit_type, fontsize=fs, y=0.9) 162 | ax.tick_params(axis='both', which='major', labelsize=fs) 163 | plt.locator_params(axis='y',nbins=2) 164 | ax.spines["right"].set_visible(False) 165 | ax.spines["top"].set_visible(False) 166 | ax.xaxis.set_ticks_position('bottom') 167 | ax.yaxis.set_ticks_position('left') 168 | if save: 169 | plt.savefig('figure/connweightrule_'+unit_type+'.pdf', transparent=True) 170 | plt.show() 171 | 172 | def plot_rec_connections(self): 173 | """Plot connectivity between recurrent units""" 174 | 175 | n_eachring = self.hp['n_eachring'] 176 | w_in, w_rec, w_out = self.w_in, self.w_rec, self.w_out 177 | 178 | w_in_ = (w_in[:, 1:n_eachring+1]+w_in[:, 1+n_eachring:2*n_eachring+1])/2. 179 | w_out_ = w_out[1:, :].T 180 | 181 | inds = [self.ind_nonanti_orig, self.ind_anti_orig] 182 | names = ['Non-Anti', 'Anti'] 183 | 184 | i_pairs = [(0,0), (0,1), (1,0), (1,1)] 185 | 186 | pref_diffs_list = list() 187 | w_recs_list = list() 188 | 189 | w_rec_bygroup = np.zeros((2,2)) 190 | 191 | for i_pair in i_pairs: 192 | ind1, ind2 = inds[i_pair[0]], inds[i_pair[1]] 193 | # For each neuron get the preference based on input weight 194 | # sort by weights 195 | w_sortby = w_in_ 196 | # w_sortby = w_out_ 197 | prefs1 = np.argmax(w_sortby[ind1, :], axis=1)*2.*np.pi/n_eachring 198 | prefs2 = np.argmax(w_sortby[ind2, :], axis=1)*2.*np.pi/n_eachring 199 | 200 | # Compute the pairwise distance based on preference 201 | # Then get the connection weight between pairs 202 | pref_diffs = list() 203 | w_recs = list() 204 | for i, ind_i in enumerate(ind1): 205 | for j, ind_j in enumerate(ind2): 206 | if ind_j == ind_i: 207 | # Excluding self connections, which tend to be positive 208 | continue 209 | pref_diffs.append(get_dist(prefs1[i]-prefs2[j])) 210 | # pref_diffs.append(prefs1[i]-prefs2[j]) 211 | w_recs.append(w_rec[ind_j, ind_i]) 212 | pref_diffs, w_recs = np.array(pref_diffs), np.array(w_recs) 213 | pref_diffs_list.append(pref_diffs) 214 | w_recs_list.append(w_recs) 215 | 216 | w_rec_bygroup[i_pair[1], i_pair[0]] = np.mean(w_recs[pref_diffs 1e-2)[0] 60 | ind_active = np.where(h_var_all_.sum(axis=1) > 1e-3)[0] 61 | h_var_all = h_var_all_[ind_active, :] 62 | 63 | # Normalize by the total variance across tasks 64 | if normalization_method == 'sum': 65 | h_normvar_all = (h_var_all.T/np.sum(h_var_all, axis=1)).T 66 | elif normalization_method == 'max': 67 | h_normvar_all = (h_var_all.T/np.max(h_var_all, axis=1)).T 68 | elif normalization_method == 'none': 69 | h_normvar_all = h_var_all 70 | else: 71 | raise NotImplementedError() 72 | 73 | ################################## Clustering ################################ 74 | from sklearn import metrics 75 | X = h_normvar_all 76 | 77 | # Clustering 78 | from sklearn.cluster import AgglomerativeClustering, KMeans 79 | 80 | # Choose number of clusters that maximize silhouette score 81 | n_clusters = range(2, 30) 82 | scores = list() 83 | labels_list = list() 84 | for n_cluster in n_clusters: 85 | # clustering = AgglomerativeClustering(n_cluster, affinity='cosine', linkage='average') 86 | clustering = KMeans(n_cluster, algorithm='full', n_init=20, random_state=0) 87 | clustering.fit(X) # n_samples, n_features = n_units, n_rules/n_epochs 88 | labels = clustering.labels_ # cluster labels 89 | 90 | score = metrics.silhouette_score(X, labels) 91 | 92 | scores.append(score) 93 | labels_list.append(labels) 94 | 95 | scores = np.array(scores) 96 | 97 | # Heuristic elbow method 98 | # Choose the number of cluster when 
Silhouette score first falls 99 | # Choose the number of cluster when Silhouette score is maximum 100 | if data_type == 'rule': 101 | #i = np.where((scores[1:]-scores[:-1])<0)[0][0] 102 | i = np.argmax(scores) 103 | else: 104 | # The more rigorous method doesn't work well in this case 105 | i = n_clusters.index(10) 106 | 107 | labels = labels_list[i] 108 | n_cluster = n_clusters[i] 109 | print('Choosing {:d} clusters'.format(n_cluster)) 110 | 111 | # Sort clusters by its task preference (important for consistency across nets) 112 | if data_type == 'rule': 113 | label_prefs = [np.argmax(h_normvar_all[labels==l].sum(axis=0)) for l in set(labels)] 114 | elif data_type == 'epoch': 115 | ## TODO: this may no longer work! 116 | label_prefs = [self.keys[np.argmax(h_normvar_all[labels==l].sum(axis=0))][0] for l in set(labels)] 117 | 118 | ind_label_sort = np.argsort(label_prefs) 119 | label_prefs = np.array(label_prefs)[ind_label_sort] 120 | # Relabel 121 | labels2 = np.zeros_like(labels) 122 | for i, ind in enumerate(ind_label_sort): 123 | labels2[labels==ind] = i 124 | labels = labels2 125 | 126 | # # Sort data by labels and by input connectivity 127 | # model = Model(save_name) 128 | # hp = model.hp 129 | # with tf.Session() as sess: 130 | # model.restore(sess) 131 | # var_list = sess.run(model.var_list) 132 | # 133 | # # Get connectivity 134 | # w_out = var_list[0].T 135 | # b_out = var_list[1] 136 | # w_in = var_list[2][:n_input, :].T 137 | # w_rec = var_list[2][n_input:, :].T 138 | # b_rec = var_list[3] 139 | # 140 | # # nx, nh, ny = hp['shape'] 141 | # nr = hp['n_eachring'] 142 | # 143 | # sort_by = 'w_in' 144 | # if sort_by == 'w_in': 145 | # w_in_mod1 = w_in[ind_active, :][:, 1:nr+1] 146 | # w_in_mod2 = w_in[ind_active, :][:, nr+1:2*nr+1] 147 | # w_in_modboth = w_in_mod1 + w_in_mod2 148 | # w_prefs = np.argmax(w_in_modboth, axis=1) 149 | # elif sort_by == 'w_out': 150 | # w_prefs = np.argmax(w_out[1:, ind_active], axis=0) 151 | # 152 | # ind_sort = np.lexsort((w_prefs, labels)) # sort by labels then by prefs 153 | 154 | ind_sort = np.argsort(labels) 155 | 156 | labels = labels[ind_sort] 157 | self.h_normvar_all = h_normvar_all[ind_sort, :] 158 | self.ind_active = ind_active[ind_sort] 159 | 160 | self.n_clusters = n_clusters 161 | self.scores = scores 162 | self.n_cluster = n_cluster 163 | 164 | self.h_var_all = h_var_all 165 | self.normalization_method = normalization_method 166 | self.labels = labels 167 | self.unique_labels = np.unique(labels) 168 | 169 | self.model_dir = model_dir 170 | self.hp = hp 171 | self.data_type = data_type 172 | self.rules = hp['rules'] 173 | 174 | def plot_cluster_score(self, save_name=None): 175 | """Plot the score by the number of clusters.""" 176 | fig = plt.figure(figsize=(2, 2)) 177 | ax = fig.add_axes([0.3, 0.3, 0.55, 0.55]) 178 | ax.plot(self.n_clusters, self.scores, 'o-', ms=3) 179 | ax.set_xlabel('Number of clusters', fontsize=7) 180 | ax.set_ylabel('Silhouette score', fontsize=7) 181 | ax.set_title('Chosen number of clusters: {:d}'.format(self.n_cluster), 182 | fontsize=7) 183 | ax.spines["right"].set_visible(False) 184 | ax.spines["top"].set_visible(False) 185 | ax.xaxis.set_ticks_position('bottom') 186 | ax.yaxis.set_ticks_position('left') 187 | ax.set_ylim([0, 0.32]) 188 | if save: 189 | fig_name = 'cluster_score' 190 | if save_name is None: 191 | save_name = self.hp['activation'] 192 | fig_name = fig_name + save_name 193 | plt.savefig('figure/'+fig_name+'.pdf', transparent=True) 194 | plt.show() 195 | 196 | def plot_variance(self, 
save_name=None): 197 | labels = self.labels 198 | ######################### Plotting Variance ################################### 199 | # Plot Normalized Variance 200 | if self.data_type == 'rule': 201 | figsize = (3.5,2.5) 202 | rect = [0.25, 0.2, 0.6, 0.7] 203 | rect_color = [0.25, 0.15, 0.6, 0.05] 204 | rect_cb = [0.87, 0.2, 0.03, 0.7] 205 | tick_names = [rule_name[r] for r in self.rules] 206 | fs = 6 207 | labelpad = 13 208 | elif self.data_type == 'epoch': 209 | figsize = (3.5,4.5) 210 | rect = [0.25, 0.1, 0.6, 0.85] 211 | rect_color = [0.25, 0.05, 0.6, 0.05] 212 | rect_cb = [0.87, 0.1, 0.03, 0.85] 213 | tick_names = [rule_name[key[0]]+' '+key[1] for key in self.keys] 214 | fs = 5 215 | labelpad = 20 216 | else: 217 | raise ValueError 218 | 219 | h_plot = self.h_normvar_all.T 220 | vmin, vmax = 0, 1 221 | fig = plt.figure(figsize=figsize) 222 | ax = fig.add_axes(rect) 223 | im = ax.imshow(h_plot, cmap='hot', 224 | aspect='auto', interpolation='nearest', vmin=vmin, vmax=vmax) 225 | 226 | plt.yticks(range(len(tick_names)), tick_names, 227 | rotation=0, va='center', fontsize=fs) 228 | plt.xticks([]) 229 | plt.title('Units', fontsize=7, y=0.97) 230 | plt.xlabel('Clusters', fontsize=7, labelpad=labelpad) 231 | ax.tick_params('both', length=0) 232 | for loc in ['bottom','top','left','right']: 233 | ax.spines[loc].set_visible(False) 234 | ax = fig.add_axes(rect_cb) 235 | cb = plt.colorbar(im, cax=ax, ticks=[vmin,vmax]) 236 | cb.outline.set_linewidth(0.5) 237 | if self.normalization_method == 'sum': 238 | clabel = 'Normalized Task Variance' 239 | elif self.normalization_method == 'max': 240 | clabel = 'Normalized Task Variance' 241 | elif self.normalization_method == 'none': 242 | clabel = 'Variance' 243 | 244 | cb.set_label(clabel, fontsize=7, labelpad=0) 245 | plt.tick_params(axis='both', which='major', labelsize=7) 246 | 247 | 248 | # Plot color bars indicating clustering 249 | if True: 250 | ax = fig.add_axes(rect_color) 251 | for il, l in enumerate(self.unique_labels): 252 | ind_l = np.where(labels==l)[0][[0, -1]]+np.array([0,1]) 253 | ax.plot(ind_l, [0,0], linewidth=4, solid_capstyle='butt', 254 | color=kelly_colors[il+1]) 255 | ax.text(np.mean(ind_l), -0.5, str(il+1), fontsize=6, 256 | ha='center', va='top', color=kelly_colors[il+1]) 257 | ax.set_xlim([0, len(labels)]) 258 | ax.set_ylim([-1, 1]) 259 | ax.axis('off') 260 | 261 | if save: 262 | fig_name = ('feature_map_by' + self.data_type + 263 | '_norm' + self.normalization_method) 264 | if save_name is not None: 265 | fig_name = fig_name + save_name 266 | plt.savefig('figure/'+fig_name+'.pdf', transparent=True) 267 | plt.show() 268 | 269 | def plot_similarity_matrix(self): 270 | labels = self.labels 271 | ######################### Plotting Similarity Matrix ########################## 272 | 273 | from sklearn.metrics.pairwise import cosine_similarity 274 | similarity = cosine_similarity(self.h_normvar_all) # TODO: check 275 | fig = plt.figure(figsize=(3.5, 3.5)) 276 | ax = fig.add_axes([0.25, 0.25, 0.6, 0.6]) 277 | im = ax.imshow(similarity, cmap='hot', interpolation='nearest', vmin=0, vmax=1) 278 | ax.axis('off') 279 | 280 | ax = fig.add_axes([0.87, 0.25, 0.03, 0.6]) 281 | cb = plt.colorbar(im, cax=ax, ticks=[0,1]) 282 | cb.outline.set_linewidth(0.5) 283 | cb.set_label('Similarity', fontsize=7, labelpad=0) 284 | plt.tick_params(axis='both', which='major', labelsize=7) 285 | 286 | ax1 = fig.add_axes([0.25, 0.85, 0.6, 0.05]) 287 | ax2 = fig.add_axes([0.2, 0.25, 0.05, 0.6]) 288 | for il, l in enumerate(self.unique_labels): 289 | ind_l = 
np.where(labels==l)[0][[0, -1]]+np.array([0,1]) 290 | ax1.plot(ind_l, [0,0], linewidth=2, solid_capstyle='butt', 291 | color=kelly_colors[il+1]) 292 | ax2.plot([0,0], len(labels)-ind_l, linewidth=2, solid_capstyle='butt', 293 | color=kelly_colors[il+1]) 294 | ax1.set_xlim([0, len(labels)]) 295 | ax2.set_ylim([0, len(labels)]) 296 | ax1.axis('off') 297 | ax2.axis('off') 298 | if save: 299 | plt.savefig('figure/feature_similarity_by'+self.data_type+'.pdf', transparent=True) 300 | plt.show() 301 | 302 | def plot_2Dvisualization(self, method='TSNE'): 303 | labels = self.labels 304 | ######################## Plotting 2-D visualization of variance map ########### 305 | from sklearn.manifold import TSNE, MDS, LocallyLinearEmbedding 306 | from sklearn.decomposition import PCA 307 | 308 | # model = LocallyLinearEmbedding() 309 | if method == 'PCA': 310 | model = PCA(n_components=2, whiten=False) 311 | elif method == 'MDS': 312 | model = MDS(metric=True, n_components=2, n_init=10, max_iter=1000) 313 | elif method == 'tSNE': 314 | model = TSNE(n_components=2, random_state=0, init='pca', 315 | verbose=1, method='exact', 316 | learning_rate=100, perplexity=30) 317 | else: 318 | raise NotImplementedError 319 | 320 | Y = model.fit_transform(self.h_normvar_all) 321 | 322 | fig = plt.figure(figsize=(2, 2)) 323 | ax = fig.add_axes([0.1, 0.1, 0.8, 0.8]) 324 | for il, l in enumerate(self.unique_labels): 325 | ind_l = np.where(labels==l)[0] 326 | ax.scatter(Y[ind_l, 0], Y[ind_l, 1], color=kelly_colors[il+1], s=10) 327 | ax.axis('off') 328 | plt.title(method, fontsize=7) 329 | figname = 'figure/taskvar_visual_by'+method+self.data_type+'.pdf' 330 | if save: 331 | plt.savefig(figname, transparent=True) 332 | plt.show() 333 | 334 | fig = plt.figure(figsize=(3.5, 3.5)) 335 | ax = fig.add_axes([0.1, 0.1, 0.8, 0.8]) 336 | ax.scatter(Y[:,0], Y[:,1], color='black') 337 | ax.axis('off') 338 | 339 | def plot_example_unit(self): 340 | ######################## Plotting Variance for example unit ################### 341 | if self.data_type == 'rule': 342 | tick_names = [rule_name[r] for r in self.rules] 343 | 344 | ind = 2 # example unit 345 | fig = plt.figure(figsize=(1.2,1.0)) 346 | ax = fig.add_axes([0.4,0.4,0.5,0.45]) 347 | ax.plot(range(self.h_var_all.shape[1]), self.h_var_all[ind, :], 'o-', color='black', lw=1, ms=2) 348 | plt.xticks(range(len(tick_names)), [tick_names[0]] + ['.']*(len(tick_names)-2) + [tick_names[-1]], 349 | rotation=90, fontsize=6, horizontalalignment='center') 350 | plt.xlabel('Task', fontsize=7, labelpad=-10) 351 | plt.ylabel('Task Variance', fontsize=7) 352 | plt.title('Unit {:d}'.format(self.ind_active[ind]), fontsize=7, y=0.85) 353 | plt.locator_params(axis='y', nbins=3) 354 | ax.tick_params(axis='both', which='major', labelsize=6, length=2) 355 | ax.spines["right"].set_visible(False) 356 | ax.spines["top"].set_visible(False) 357 | ax.xaxis.set_ticks_position('bottom') 358 | ax.yaxis.set_ticks_position('left') 359 | if save: 360 | plt.savefig('figure/exampleunit_variance.pdf', transparent=True) 361 | plt.show() 362 | 363 | from analysis.standard_analysis import pretty_singleneuron_plot 364 | # Plot single example neuron in time 365 | pretty_singleneuron_plot(self.model_dir, ['fdgo'], [self.ind_active[ind]], 366 | epoch=None, save=save, ylabel_firstonly=True) 367 | 368 | def plot_connectivity_byclusters(self): 369 | """Plot connectivity of the model""" 370 | 371 | ind_active = self.ind_active 372 | 373 | # Sort data by labels and by input connectivity 374 | model = Model(self.model_dir) 375 | hp 
= model.hp 376 | with tf.Session() as sess: 377 | model.restore() 378 | w_in = sess.run(model.w_in).T 379 | w_rec = sess.run(model.w_rec).T 380 | w_out = sess.run(model.w_out).T 381 | b_rec = sess.run(model.b_rec) 382 | b_out = sess.run(model.b_out) 383 | 384 | w_rec = w_rec[ind_active, :][:, ind_active] 385 | w_in = w_in[ind_active, :] 386 | w_out = w_out[:, ind_active] 387 | b_rec = b_rec[ind_active] 388 | 389 | # nx, nh, ny = hp['shape'] 390 | nr = hp['n_eachring'] 391 | 392 | sort_by = 'w_in' 393 | if sort_by == 'w_in': 394 | w_in_mod1 = w_in[:, 1:nr+1] 395 | w_in_mod2 = w_in[:, nr+1:2*nr+1] 396 | w_in_modboth = w_in_mod1 + w_in_mod2 397 | w_prefs = np.argmax(w_in_modboth, axis=1) 398 | elif sort_by == 'w_out': 399 | w_prefs = np.argmax(w_out[1:], axis=0) 400 | 401 | # sort by labels then by prefs 402 | ind_sort = np.lexsort((w_prefs, self.labels)) 403 | 404 | ######################### Plotting Connectivity ############################### 405 | nx = self.hp['n_input'] 406 | ny = self.hp['n_output'] 407 | nh = len(self.ind_active) 408 | nr = self.hp['n_eachring'] 409 | nrule = len(self.hp['rules']) 410 | 411 | # Plot active units 412 | _w_rec = w_rec[ind_sort,:][:,ind_sort] 413 | _w_in = w_in[ind_sort,:] 414 | _w_out = w_out[:,ind_sort] 415 | _b_rec = b_rec[ind_sort, np.newaxis] 416 | _b_out = b_out[:, np.newaxis] 417 | labels = self.labels[ind_sort] 418 | 419 | l = 0.3 420 | l0 = (1-1.5*l)/nh 421 | 422 | plot_infos = [(_w_rec , [l ,l ,nh*l0 ,nh*l0]), 423 | (_w_in[:,[0]] , [l-(nx+15)*l0 ,l ,1*l0 ,nh*l0]), # Fixation input 424 | (_w_in[:,1:nr+1] , [l-(nx+11)*l0 ,l ,nr*l0 ,nh*l0]), # Mod 1 stimulus 425 | (_w_in[:,nr+1:2*nr+1], [l-(nx-nr+8)*l0 ,l ,nr*l0 ,nh*l0]), # Mod 2 stimulus 426 | (_w_in[:,2*nr+1:] , [l-(nx-2*nr+5)*l0,l ,nrule*l0 ,nh*l0]), # Rule inputs 427 | (_w_out[[0],:] , [l ,l-(4)*l0 ,nh*l0 ,1*l0]), 428 | (_w_out[1:,:] , [l ,l-(ny+6)*l0,nh*l0 ,(ny-1)*l0]), 429 | (_b_rec , [l+(nh+6)*l0 ,l ,l0 ,nh*l0]), 430 | (_b_out , [l+(nh+6)*l0 ,l-(ny+6)*l0,l0 ,ny*l0])] 431 | 432 | # cmap = sns.diverging_palette(220, 10, sep=80, as_cmap=True) 433 | cmap = 'coolwarm' 434 | fig = plt.figure(figsize=(6, 6)) 435 | for plot_info in plot_infos: 436 | ax = fig.add_axes(plot_info[1]) 437 | vmin, vmid, vmax = np.percentile(plot_info[0].flatten(), [5,50,95]) 438 | _ = ax.imshow(plot_info[0], interpolation='nearest', cmap=cmap, aspect='auto', 439 | vmin=vmid-(vmax-vmin)/2, vmax=vmid+(vmax-vmin)/2) 440 | ax.axis('off') 441 | 442 | ax1 = fig.add_axes([l , l+nh*l0, nh*l0, 6*l0]) 443 | ax2 = fig.add_axes([l-6*l0, l , 6*l0 , nh*l0]) 444 | for il, l in enumerate(self.unique_labels): 445 | ind_l = np.where(labels==l)[0][[0, -1]]+np.array([0,1]) 446 | ax1.plot(ind_l, [0,0], linewidth=2, solid_capstyle='butt', 447 | color=kelly_colors[il+1]) 448 | ax2.plot([0,0], len(labels)-ind_l, linewidth=2, solid_capstyle='butt', 449 | color=kelly_colors[il+1]) 450 | ax1.set_xlim([0, len(labels)]) 451 | ax2.set_ylim([0, len(labels)]) 452 | ax1.axis('off') 453 | ax2.axis('off') 454 | if save: 455 | plt.savefig('figure/connectivity_by'+self.data_type+'.pdf', transparent=True) 456 | plt.show() 457 | 458 | def lesions(self): 459 | labels = self.labels 460 | 461 | from network import get_perf 462 | from task import generate_trials 463 | 464 | # The first will be the intact network 465 | lesion_units_list = [None] 466 | for il, l in enumerate(self.unique_labels): 467 | ind_l = np.where(labels == l)[0] 468 | # In original indices 469 | lesion_units_list += [self.ind_active[ind_l]] 470 | 471 | perfs_store_list = list() 472 | 
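        # Added note: lesion_units_list starts with None (the intact network),
        # followed by the active units of one cluster at a time. For each entry
        # the loop below restores a fresh model, lesions those units via
        # model.lesion_units, and evaluates performance (get_perf) and lsq cost
        # on every rule; perfs_changes and cost_changes then store the
        # difference relative to the intact network (entry 0).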
perfs_changes = list() 473 | cost_store_list = list() 474 | cost_changes = list() 475 | 476 | for i, lesion_units in enumerate(lesion_units_list): 477 | model = Model(self.model_dir) 478 | hp = model.hp 479 | with tf.Session() as sess: 480 | model.restore() 481 | model.lesion_units(sess, lesion_units) 482 | 483 | perfs_store = list() 484 | cost_store = list() 485 | for rule in self.rules: 486 | n_rep = 16 487 | batch_size_test = 256 488 | batch_size_test_rep = int(batch_size_test / n_rep) 489 | clsq_tmp = list() 490 | perf_tmp = list() 491 | for i_rep in range(n_rep): 492 | trial = generate_trials(rule, hp, 'random', 493 | batch_size=batch_size_test_rep) 494 | feed_dict = tools.gen_feed_dict(model, trial, hp) 495 | y_hat_test, c_lsq = sess.run( 496 | [model.y_hat, model.cost_lsq], feed_dict=feed_dict) 497 | 498 | # Cost is first summed over time, and averaged across batch and units 499 | # We did the averaging over time through c_mask 500 | 501 | # IMPORTANT CHANGES: take overall mean 502 | perf_test = np.mean(get_perf(y_hat_test, trial.y_loc)) 503 | clsq_tmp.append(c_lsq) 504 | perf_tmp.append(perf_test) 505 | 506 | perfs_store.append(np.mean(perf_tmp)) 507 | cost_store.append(np.mean(clsq_tmp)) 508 | 509 | perfs_store = np.array(perfs_store) 510 | cost_store = np.array(cost_store) 511 | 512 | perfs_store_list.append(perfs_store) 513 | cost_store_list.append(cost_store) 514 | 515 | if i > 0: 516 | perfs_changes.append(perfs_store - perfs_store_list[0]) 517 | cost_changes.append(cost_store - cost_store_list[0]) 518 | 519 | perfs_changes = np.array(perfs_changes) 520 | cost_changes = np.array(cost_changes) 521 | 522 | return perfs_changes, cost_changes 523 | 524 | def plot_lesions(self): 525 | """Lesion individual cluster and show performance.""" 526 | 527 | perfs_changes, cost_changes = self.lesions() 528 | 529 | cb_labels = ['Performance change after lesioning', 530 | 'Cost change after lesioning'] 531 | vmins = [-0.5, -0.5] 532 | vmaxs = [+0.5, +0.5] 533 | ticks = [[-0.5,0.5], [-0.5, 0.5]] 534 | changes_plot = [perfs_changes, cost_changes] 535 | 536 | fs = 6 537 | figsize = (2.5,2.5) 538 | rect = [0.3, 0.2, 0.5, 0.7] 539 | rect_cb = [0.82, 0.2, 0.03, 0.7] 540 | rect_color = [0.3, 0.15, 0.5, 0.05] 541 | for i in range(2): 542 | fig = plt.figure(figsize=figsize) 543 | ax = fig.add_axes(rect) 544 | im = ax.imshow(changes_plot[i].T, cmap='coolwarm', aspect='auto', 545 | interpolation='nearest', vmin=vmins[i], vmax=vmaxs[i]) 546 | 547 | tick_names = [rule_name[r] for r in self.rules] 548 | _ = plt.yticks(range(len(tick_names)), tick_names, 549 | rotation=0, va='center', fontsize=fs) 550 | plt.xticks([]) 551 | plt.xlabel('Clusters', fontsize=7, labelpad=13) 552 | ax.tick_params('both', length=0) 553 | for loc in ['bottom','top','left','right']: 554 | ax.spines[loc].set_visible(False) 555 | 556 | ax = fig.add_axes(rect_cb) 557 | cb = plt.colorbar(im, cax=ax, ticks=ticks[i]) 558 | cb.outline.set_linewidth(0.5) 559 | cb.set_label(cb_labels[i], fontsize=7, labelpad=-10) 560 | plt.tick_params(axis='both', which='major', labelsize=7) 561 | 562 | ax = fig.add_axes(rect_color) 563 | for il, l in enumerate(self.unique_labels): 564 | ax.plot([il, il+1], [0,0], linewidth=4, solid_capstyle='butt', 565 | color=kelly_colors[il+1]) 566 | ax.text(np.mean(il+0.5), -0.5, str(il+1), fontsize=6, 567 | ha='center', va='top', color=kelly_colors[il+1]) 568 | ax.set_xlim([0, len(self.unique_labels)]) 569 | ax.set_ylim([-1, 1]) 570 | ax.axis('off') 571 | 572 | if save: 573 | 
plt.savefig('figure/lesion_cluster_by'+self.data_type+'_{:d}.pdf'.format(i), transparent=True) 574 | 575 | if __name__ == '__main__': 576 | root_dir = './data/train_all' 577 | model_dir = root_dir + '/1' 578 | # CA = Analysis(model_dir, data_type='rule') 579 | -------------------------------------------------------------------------------- /analysis/contlearn_schematic.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Plot schematics for continual learning 4 | @author: guangyuyang 5 | """ 6 | 7 | from __future__ import division 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | f_loss = lambda x, u : (np.abs(x-u))**4 12 | 13 | # colors = sns.xkcd_palette(['red', 'blue', 'green']) 14 | colors = np.array([[27, 158, 119], [117, 112, 179]])/255. 15 | d = 4 16 | u_ = [-2, 2] 17 | lines = [] 18 | fig = plt.figure(figsize=(2.5,1.5)) 19 | ax = fig.add_axes([0.15,0.25,0.8,0.7]) 20 | for i in range(3): 21 | if i < 2: 22 | color = colors[i] 23 | u = u_[i] 24 | x_plot = np.linspace(u-d, u+d, 1000) 25 | loss_plot = f_loss(x_plot, u) 26 | lw = 4 27 | else: 28 | color = (colors[0]+colors[1])/2 29 | x_plot = np.linspace(np.min(u_)-0.5, np.max(u_)+.5, 1000) 30 | loss_plot = 0 31 | for u in u_: 32 | loss_plot = loss_plot + f_loss(x_plot, u) 33 | lw = 2 34 | # color = (np.array(colors[0])+np.array(colors[1]))/2 35 | line = ax.plot(x_plot, loss_plot, lw=lw, color=color) 36 | lines.append(line[0]) 37 | i_min = np.argmin(loss_plot) 38 | ax.plot(x_plot[i_min], loss_plot[i_min], markersize=lw*2, marker='o', 39 | color='white', markeredgecolor=color, markeredgewidth=lw*0.7) 40 | 41 | lg = ax.legend(lines, ('Task 1', 'Task 2', 'Task 1 + 2'), 42 | fontsize=7, ncol=1, bbox_to_anchor=(1.1,1.1), 43 | labelspacing=0.2, loc=1, frameon=False) 44 | 45 | ax.set_xticks([]) 46 | ax.set_yticks([]) 47 | 48 | ax.spines["right"].set_visible(False) 49 | ax.spines["top"].set_visible(False) 50 | ax.spines['left'].set_position(('outward', 10)) # outward by 10 points 51 | ax.spines['bottom'].set_position(('outward', 10)) # outward by 10 points 52 | ax.xaxis.set_ticks_position('bottom') 53 | ax.yaxis.set_ticks_position('left') 54 | 55 | ax.set_ylim(bottom=-50) 56 | ax.set_xlim([np.min(u_)-d-0.3, np.max(u_)+d+1]) 57 | 58 | ax.set_xlabel(r'Parameter $\theta$', fontsize=7) 59 | ax.set_ylabel(r'Loss $L$', fontsize=7) 60 | plt.savefig('figure/schematic_contlearn.pdf', transparent=True) -------------------------------------------------------------------------------- /analysis/dimensionality.py: -------------------------------------------------------------------------------- 1 | """ 2 | Analyze dimennsionality by counting implementable linear classifiers 3 | Rigotti et al 2013 Nature 4 | 5 | @ Robert Yang 2017 6 | """ 7 | 8 | from __future__ import division 9 | 10 | import os 11 | import time 12 | import numpy as np 13 | import pickle 14 | import itertools 15 | from collections import OrderedDict 16 | import matplotlib as mpl 17 | import matplotlib.pyplot as plt 18 | from sklearn.svm import SVC 19 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 20 | # import seaborn.apionly as sns 21 | from contextdm_data_analysis import condition_averaging_split_trte 22 | from contextdm_data_analysis import get_cond_16_dim 23 | from contextdm_data_analysis import get_condavg_simu_16_dim 24 | from contextdm_data_analysis import run_simulation 25 | 26 | 27 | def compute_implementable_classifier(data_train, data_test): 28 | # data_train & 
data_test should be (n_condition, n_unit) 29 | 30 | # number of conditions 31 | n_condition = data_train.shape[0] 32 | 33 | # classification 34 | classifier = SVC(kernel='linear', C=1.0) 35 | # classifier = LinearDiscriminantAnalysis() # much slower 36 | 37 | n_coloration = 2**(n_condition-1)-1 38 | if n_coloration > 10**6: 39 | raise ValueError('too many colorations') 40 | 41 | performance_train = list() 42 | performance_test = list() 43 | colors_list = list() 44 | for i, colors_ in enumerate(itertools.product([0,1], repeat=n_condition-1)): 45 | if i == 0: 46 | # Don't use [0, 0, ..., 0] 47 | continue 48 | 49 | colors = np.array([0]+list(colors_)) # the first one is always zero to break symmetry 50 | 51 | # Fit 52 | classifier.fit(data_train, colors) 53 | 54 | color_train_predict = classifier.predict(data_train) 55 | color_test_predict = classifier.predict(data_test) 56 | 57 | performance_train.append(np.mean(colors==color_train_predict)) 58 | performance_test.append(np.mean(colors==color_test_predict)) 59 | colors_list.append(colors) 60 | 61 | performance_train = np.array(performance_train) 62 | performance_test = np.array(performance_test) 63 | 64 | # finish = time.time() 65 | # print('Time taken {:0.5f}s'.format(finish-start)) 66 | 67 | threshold = 0.8 68 | 69 | threshold = 0.9 #maddy changed. 70 | 71 | n_implementable_train = np.sum(performance_train>threshold) 72 | n_implementable_test = np.sum(performance_test >threshold) 73 | 74 | # Estimate total number of implementable classifications 75 | n_total_classification = 2**(n_condition-1)-1 76 | 77 | N_implementable_train = n_total_classification * (n_implementable_train/n_coloration) 78 | N_implementable_test = n_total_classification * (n_implementable_test /n_coloration) 79 | print N_implementable_train, N_implementable_test 80 | 81 | return N_implementable_train, N_implementable_test 82 | 83 | # print(np.log2(N_implementable_train), np.log2(N_implementable_test)) 84 | 85 | 86 | def generate_test_data(): 87 | # generate some data (n_batch can be the same or much larger than n_condition) 88 | n_time, n_batch, n_unit = 10, 100, 100 89 | 90 | data = np.random.rand(n_time, n_batch, n_unit) 91 | 92 | # trial condition 93 | n_condition = 72 94 | conditions = np.array(range(n_condition)*int(n_batch/n_condition)) 95 | 96 | # TODO: Splitting training and testing data 97 | # For now assume the same, and assume batch=condition 98 | data_train = np.random.rand(n_time, n_condition, n_unit) 99 | data_test = np.random.randn(n_time, n_condition, n_unit)*0.1 + data_train 100 | 101 | # pick one data point 102 | i_t = 0 103 | data_train_t = data_train[i_t] # (n_batch, n_unit) 104 | data_test_t = data_test[i_t] # (n_batch, n_unit) 105 | 106 | return data_train_t, data_test_t 107 | 108 | 109 | def _get_dimension(data_train, data_test, n_unit_used=None): 110 | # Temporarily excluding neurons because there are not enough trials 111 | n_unit = data_train.shape[1] 112 | ind_units = range(n_unit) 113 | excluding = np.where(np.isnan(np.sum(data_train, axis=0)+np.sum(data_test, axis=0)))[0] 114 | for exc in excluding: 115 | ind_units.pop(ind_units.index(exc)) 116 | data_train, data_test = data_train[:, ind_units], data_test[:, ind_units] 117 | n_unit = data_train.shape[1] 118 | 119 | if n_unit_used is not None: 120 | ind_used = np.arange(n_unit) 121 | np.random.shuffle(ind_used) 122 | ind_used = ind_used[:n_unit_used] 123 | data_train, data_test = data_train[:, ind_used], data_test[:, ind_used] 124 | 125 | N_implementable_train, N_implementable_test = \ 126 | 
compute_implementable_classifier(data_train, data_test) 127 | 128 | # print(np.log2(N_implementable_train), np.log2(N_implementable_test)) 129 | 130 | return N_implementable_train, N_implementable_test 131 | 132 | 133 | def get_dimension(data_train, data_test, n_unit_used=None, n_rep=1): 134 | N_implementable_trains = np.zeros(n_rep) 135 | N_implementable_tests = np.zeros(n_rep) 136 | for i_rep in range(n_rep): 137 | N_implementable_train, N_implementable_test = \ 138 | _get_dimension(data_train, data_test, n_unit_used=n_unit_used) 139 | N_implementable_tests[i_rep] = N_implementable_test 140 | N_implementable_trains[i_rep]= N_implementable_train 141 | 142 | return np.log2(np.mean(N_implementable_trains)), np.log2(np.mean(N_implementable_tests)) 143 | 144 | 145 | def _get_dimension_16_dim(data_train, data_test, n_unit_used=None): 146 | # Temporarily excluding neurons because there are not enough trials 147 | n_unit = data_train.shape[1] 148 | ind_units = range(n_unit) 149 | exc_train_notactive = np.where(0==np.sum(data_train, axis=0))[0]#maddy added below. 150 | exc_test_notactive = np.where(0==np.sum(data_test, axis=0))[0] 151 | excluding = np.concatenate([exc_train_notactive, exc_test_notactive]) 152 | for exc in excluding: 153 | ind_units.pop(ind_units.index(exc)) 154 | 155 | """ 156 | excluding = np.where(np.isnan(np.sum(data_train, axis=0)+np.sum(data_test, axis=0)))[0] 157 | for exc in excluding: 158 | ind_units.pop(ind_units.index(exc)) 159 | data_train, data_test = data_train[:, ind_units], data_test[:, ind_units] 160 | """ 161 | 162 | if n_unit_used == None: 163 | raise ValueError("specify no. of units used.") 164 | else: 165 | ind_used = np.arange(n_unit) 166 | np.random.shuffle(ind_used) 167 | ind_used = ind_used[:n_unit_used] 168 | data_train, data_test = data_train[:, ind_used], data_test[:, ind_used] 169 | 170 | N_implementable_train, N_implementable_test = \ 171 | compute_implementable_classifier(data_train, data_test) 172 | 173 | # print(np.log2(N_implementable_train), np.log2(N_implementable_test)) 174 | 175 | return N_implementable_train, N_implementable_test 176 | 177 | 178 | def get_dimension_16_dim(data_train, data_test, n_unit_used=None, n_rep=1): 179 | N_implementable_trains = np.zeros(n_rep) 180 | N_implementable_tests = np.zeros(n_rep) 181 | for i_rep in range(n_rep): 182 | N_implementable_train, N_implementable_test = \ 183 | _get_dimension_16_dim(data_train, data_test, n_unit_used=n_unit_used) 184 | N_implementable_tests[i_rep] = N_implementable_test 185 | N_implementable_trains[i_rep]= N_implementable_train 186 | 187 | print n_unit_used, np.log2(np.mean(N_implementable_trains)), np.log2(np.mean(N_implementable_tests)) 188 | return np.log2(np.mean(N_implementable_trains)), np.log2(np.mean(N_implementable_tests)) 189 | 190 | 191 | def _get_dimension_varyusedunit(analyze_data=False, n_unit_used_list=None, **kwargs): 192 | if analyze_data: 193 | from mante_data_analysis import load_mante_data 194 | data = load_mante_data() 195 | data_train, data_test = get_trial_avg(data=data) 196 | else: 197 | from choiceattend_analysis import StateSpaceAnalysis 198 | # save_addon = 'allrule_softplus_400largeinput' 199 | save_addon = kwargs['save_addon'] 200 | sigma_rec = 5 # chosen to reproduce the behavioral level performance 201 | ssa = StateSpaceAnalysis(save_addon, lesion_units=None, 202 | z_score=False, n_rep=30, sigma_rec=sigma_rec) 203 | 204 | data_train, data_test = get_trial_avg( 205 | analyze_data=False, task_var=ssa.task_var, data=ssa.H_original.mean(axis=0)) 206 | 207 
| if n_unit_used_list is None: 208 | n_unit_used_list = range(1, 400, 40) 209 | 210 | N_implementable_test_list = list() 211 | N_implementable_train_list = list() 212 | for n_unit_used in n_unit_used_list: 213 | start = time.time() 214 | N_implementable_train, N_implementable_test = \ 215 | get_dimension(data_train, data_test, n_unit_used) 216 | 217 | N_implementable_test_list.append(N_implementable_test) 218 | N_implementable_train_list.append(N_implementable_train) 219 | 220 | print('Time taken {:0.3f}s'.format(time.time()-start)) 221 | print(n_unit_used, np.log2(N_implementable_test)) 222 | 223 | return n_unit_used_list, N_implementable_test_list 224 | 225 | # plt.plot(n_unit_used_list, np.log2(N_implementable_test_list)) 226 | 227 | 228 | def get_dimension_varyusedunit(analyze_data=False, n_rep=3, **kwargs): 229 | results = dict() 230 | if analyze_data: 231 | n_unit_used_list = range(1, 701, 50) 232 | save_name = 'dimension_data' 233 | else: 234 | n_unit_used_list = range(1, 300, 50) 235 | save_name = 'dimension_'+kwargs['save_addon'] 236 | 237 | results['n_unit_used_list'] = n_unit_used_list 238 | N_implementable_test_matrix = list() 239 | for i_rep in range(n_rep): 240 | n_unit_used_list, N_implementable_test_list = \ 241 | _get_dimension_varyusedunit(analyze_data, n_unit_used_list, **kwargs) 242 | N_implementable_test_matrix.append(N_implementable_test_list) 243 | 244 | results['N_implementable_test_matrix'] = N_implementable_test_matrix 245 | 246 | with open(os.path.join('data', save_name+'.pkl'), 'wb') as f: 247 | pickle.dump(results, f) 248 | 249 | 250 | def _get_dimension_varyusedunit_16_dim(data_train, data_test, n_unit_used_list, n_rep): 251 | 252 | N_implementable_test_list = list() 253 | N_implementable_train_list = list() 254 | 255 | for n_unit_used in n_unit_used_list: 256 | 257 | start = time.time() 258 | N_implementable_train, N_implementable_test = \ 259 | get_dimension_16_dim(data_train, data_test, n_unit_used,n_rep=n_rep) 260 | 261 | N_implementable_test_list.append(N_implementable_test) 262 | N_implementable_train_list.append(N_implementable_train) 263 | 264 | print('Time taken {:0.3f}s'.format(time.time()-start)) 265 | 266 | return n_unit_used_list, N_implementable_test_list 267 | 268 | #def get_dimension_varyusedunit_16_dim(data_train, data_test, n_rep=3):#maddy 269 | def get_dimension_varyusedunit_16_dim(save_name, data_train, data_test, n_rep):#maddy 270 | results_new = dict() #maddy 271 | n_unit_used_list = range(1, 352, 50)#range(1, 702, 50) 272 | save_name = 'dimension_'+save_name 273 | 274 | results_new['n_unit_used_list'] = n_unit_used_list 275 | N_implementable_test_matrix = list() 276 | 277 | #for i_rep in range(n_rep): 278 | 279 | n_unit_used_list, N_implementable_test_list = \ 280 | _get_dimension_varyusedunit_16_dim(data_train, data_test, n_unit_used_list, n_rep=n_rep) 281 | 282 | N_implementable_test_matrix.append(N_implementable_test_list) 283 | 284 | results_new['N_implementable_test_matrix'] = N_implementable_test_matrix 285 | 286 | filename = os.path.join('data', save_name+'.pkl')#maddy added. 
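    # Added note: results are accumulated across calls. If a pickle for this
    # save_name already exists, it is loaded and the new test matrix is appended
    # to its 'N_implementable_test_matrix' list; otherwise results_new is written
    # out as a fresh file.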
287 | if os.path.isfile(filename): 288 | with open(filename, 'rb') as resold: 289 | results = pickle.load(resold) 290 | results['N_implementable_test_matrix'].append(N_implementable_test_matrix) 291 | with open(os.path.join('data', save_name+'.pkl'), 'wb') as f: 292 | pickle.dump(results, f) 293 | print("results", results) 294 | else: 295 | with open(os.path.join('data', save_name+'.pkl'), 'wb') as f: 296 | pickle.dump(results_new, f) 297 | print("results_new", results_new) 298 | #with open(os.path.join('data', save_name+'.pkl'), 'wb') as f: 299 | # pickle.dump(results, f) 300 | 301 | def call_get_dimension_varyusedunit_16_dim(save_name = 'debug', n_rep=10):#10. #maddy 302 | 303 | #fname = os.path.join('data', 'config_' + save_name + '.pkl') 304 | 305 | #if save_name == 'Data': 306 | # fname = os.path.join('data', 'ManteData.pkl') #'ManteDataCond.pkl' 307 | 308 | #if os.path.isfile(fname): 309 | # with open(fname, 'rb') as f: 310 | # Data = pickle.load(f) 311 | 312 | train_set, test_set = get_condavg_simu_16_dim(save_name) 313 | 314 | #ind_time = 14##maddy changed. 315 | for ind_time in np.arange(0,15):#15 316 | print("ind_time", ind_time) 317 | train_timept, test_timept = train_set[ind_time], test_set[ind_time] 318 | get_dimension_varyusedunit_16_dim(save_name, train_timept, test_timept, n_rep=n_rep) 319 | 320 | 321 | def plot_dimension_estimation(save_addon_list): 322 | import scipy as sp 323 | import scipy.stats 324 | 325 | def mean_confidence_interval(data, confidence=0.95, **kwargs): 326 | a = 1.0*np.array(data) 327 | n = len(a) 328 | m, se = np.mean(a, **kwargs), scipy.stats.sem(a, **kwargs) 329 | h = se * sp.stats.t.ppf((1+confidence)/2., n-1) 330 | return m, m-h, m+h 331 | 332 | # colors = dict(zip(['data','model'], sns.xkcd_palette(['black', 'red']))) 333 | plt.figure() 334 | for i_plt, save_addon in enumerate(save_addon_list): 335 | # save_name = 'dimension_data' 336 | with open(os.path.join('data', 'dimension'+save_addon+'.pkl'), 'rb') as f: 337 | results = pickle.load(f) 338 | N_implementable_test_matrix = np.array(results['N_implementable_test_matrix']) 339 | n_unit_used_list = results['n_unit_used_list'] 340 | 341 | mean_val, lower_bnds, higher_bnds = mean_confidence_interval(N_implementable_test_matrix, axis=0) 342 | print(mean_val[1]) 343 | plt.plot(n_unit_used_list, np.log2(mean_val), color=plt.cm.cool(1.0*i_plt/len(save_addon_list))) 344 | # plt.fill_between(n_unit_used_list, np.log2(lower_bnds), np.log2(higher_bnds), 345 | # alpha=0.3, color=color) 346 | 347 | plt.plot(n_unit_used_list, [15]*len(n_unit_used_list), '--', color='black') 348 | 349 | # plt.legend(loc=4) 350 | plt.xlabel('number of neurons or units') 351 | plt.ylabel('number of dimensions') 352 | # plt.savefig(os.path.join('figure','dimension_estimation.pdf'), transparent=True) 353 | 354 | 355 | if __name__ == '__main__': 356 | pass 357 | 358 | save_name = 'hidden_64_seed_2_softplus_LeakyRNN_diag__regwt_L1_1e_min_4_regact_L1_1e_min_4' 359 | #save_name = 'Data' 360 | call_get_dimension_varyusedunit_16_dim(save_name=save_name) 361 | 362 | print("here") 363 | 364 | -------------------------------------------------------------------------------- /analysis/posttrain_analysis.py: -------------------------------------------------------------------------------- 1 | """Analyze the results of the post-training experiments.""" 2 | 3 | from __future__ import division 4 | 5 | from collections import defaultdict 6 | from collections import OrderedDict 7 | import os 8 | import numpy as np 9 | import matplotlib as mpl 10 | import
matplotlib.pyplot as plt 11 | 12 | import tools 13 | 14 | mpl.rcParams.update({'font.size': 7}) 15 | 16 | DATAPATH = os.path.join(os.getcwd(), 'data', 'posttrain') 17 | FIGPATH = os.path.join(os.getcwd(), 'figure') 18 | 19 | 20 | def get_avg_performance(model_dirs, rule): 21 | """Get average performance across trials for model_dirs. 22 | 23 | Some networks converge earlier than others. For those converged early, 24 | choose the last performance for later performance 25 | """ 26 | perfs = defaultdict(list) 27 | 28 | trials = [] 29 | for model_dir in model_dirs: 30 | log = tools.load_log(model_dir) 31 | trials += list(log['trials']) 32 | trials = np.sort(np.unique(trials)) 33 | 34 | for model_dir in model_dirs: 35 | log = tools.load_log(model_dir) 36 | for t in trials: 37 | if t in log['trials']: 38 | ind = log['trials'].index(t) 39 | else: 40 | ind = -1 41 | perfs[t].append(log['perf_' + rule][ind]) 42 | # for t, perf in zip(log['trials'], log['perf_'+rule]): 43 | # perfs[t].append(perf) 44 | 45 | # average performances 46 | trials = list(perfs.keys()) 47 | trials = np.sort(trials) 48 | avg_perfs = [np.mean(perfs[t]) for t in trials] 49 | return avg_perfs, trials 50 | 51 | 52 | def plot_posttrain_performance(posttrain_setup, trainables): 53 | from task import rule_name 54 | hp_target = {'posttrain_setup': posttrain_setup, 55 | 'trainables': trainables} 56 | fs = 7 57 | fig = plt.figure(figsize=(1.5, 1.2)) 58 | ax = fig.add_axes([0.25, 0.3, 0.7, 0.65]) 59 | 60 | colors = ['xkcd:blue', 'xkcd:red'] 61 | for pretrain_setup in [1, 0]: 62 | c = colors[pretrain_setup] 63 | l = ['B', 'A'][pretrain_setup] 64 | hp_target['pretrain_setup'] = pretrain_setup 65 | model_dirs = tools.find_all_models(DATAPATH, hp_target) 66 | hp = tools.load_hp(model_dirs[0]) 67 | rule = hp['rule_trains'][0] # depends on posttrain setup 68 | for model_dir in model_dirs: 69 | log = tools.load_log(model_dir) 70 | ax.plot(np.array(log['trials']) / 1000., 71 | log['perf_' + rule], color=c, alpha=0.1) 72 | avg_perfs, trials = get_avg_performance(model_dirs, rule) 73 | l0 = ax.plot(trials / 1000., avg_perfs, color=c, label=l) 74 | 75 | ax.set_ylim([0, 1]) 76 | ax.set_xlabel('Total trials (1,000)', fontsize=fs, labelpad=2) 77 | ax.set_yticks([0, 1]) 78 | ax.spines["right"].set_visible(False) 79 | ax.spines["top"].set_visible(False) 80 | 81 | # lg = ax.legend(title='Pretrained set', ncol=2, loc=4, 82 | # frameon=False) 83 | 84 | plt.ylabel('Perf. of ' + rule_name[rule]) 85 | # plt.title('Training ' + hp_target['trainables']) 86 | plt.savefig('figure/Posttrain_post{:d}train{:s}.pdf'.format( 87 | posttrain_setup, trainables), transparent=True) 88 | # plt.show() 89 | 90 | 91 | if __name__ == '__main__': 92 | pass 93 | 94 | -------------------------------------------------------------------------------- /analysis/standard_analysis.py: -------------------------------------------------------------------------------- 1 | """Standard analyses that can be performed on any task""" 2 | 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import tensorflow as tf 8 | 9 | from task import generate_trials, rule_name 10 | from network import Model 11 | import tools 12 | 13 | 14 | def easy_activity_plot(model_dir, rule): 15 | """A simple plot of neural activity from one task. 
16 | 17 | Args: 18 | model_dir: directory where model file is saved 19 | rule: string, the rule to plot 20 | """ 21 | 22 | model = Model(model_dir) 23 | hp = model.hp 24 | 25 | with tf.Session() as sess: 26 | model.restore() 27 | 28 | trial = generate_trials(rule, hp, mode='test') 29 | feed_dict = tools.gen_feed_dict(model, trial, hp) 30 | h, y_hat = sess.run([model.h, model.y_hat], feed_dict=feed_dict) 31 | # All matrices have shape (n_time, n_condition, n_neuron) 32 | 33 | # Take only the one example trial 34 | i_trial = 0 35 | 36 | for activity, title in zip([trial.x, h, y_hat], 37 | ['input', 'recurrent', 'output']): 38 | plt.figure() 39 | plt.imshow(activity[:,i_trial,:].T, aspect='auto', cmap='hot', 40 | interpolation='none', origin='lower') 41 | plt.title(title) 42 | plt.colorbar() 43 | plt.show() 44 | 45 | 46 | def easy_connectivity_plot(model_dir): 47 | """A simple plot of network connectivity.""" 48 | 49 | model = Model(model_dir) 50 | with tf.Session() as sess: 51 | model.restore() 52 | # get all connection weights and biases as tensorflow variables 53 | var_list = model.var_list 54 | # evaluate the parameters after training 55 | params = [sess.run(var) for var in var_list] 56 | # get name of each variable 57 | names = [var.name for var in var_list] 58 | 59 | # Plot weights 60 | for param, name in zip(params, names): 61 | if len(param.shape) != 2: 62 | continue 63 | 64 | vmax = np.max(abs(param))*0.7 65 | plt.figure() 66 | # notice the transpose 67 | plt.imshow(param.T, aspect='auto', cmap='bwr', vmin=-vmax, vmax=vmax, 68 | interpolation='none', origin='lower') 69 | plt.title(name) 70 | plt.colorbar() 71 | plt.xlabel('From') 72 | plt.ylabel('To') 73 | plt.show() 74 | 75 | 76 | def pretty_inputoutput_plot(model_dir, rule, save=False, plot_ylabel=False): 77 | """Plot the input and output activity for a sample trial from one task. 78 | 79 | Args: 80 | model_dir: model directory 81 | rule: string, the rule 82 | save: bool, whether to save plots 83 | plot_ylabel: bool, whether to plot ylable 84 | """ 85 | 86 | 87 | fs = 7 88 | 89 | model = Model(model_dir) 90 | hp = model.hp 91 | 92 | with tf.Session() as sess: 93 | model.restore() 94 | 95 | trial = generate_trials(rule, hp, mode='test') 96 | x, y = trial.x, trial.y 97 | feed_dict = tools.gen_feed_dict(model, trial, hp) 98 | h, y_hat = sess.run([model.h, model.y_hat], feed_dict=feed_dict) 99 | 100 | t_plot = np.arange(x.shape[0])*hp['dt']/1000 101 | 102 | assert hp['num_ring'] == 2 103 | 104 | n_eachring = hp['n_eachring'] 105 | 106 | fig = plt.figure(figsize=(1.3,2)) 107 | ylabels = ['fix. in', 'stim. mod1', 'stim. mod2','fix. 
out', 'out'] 108 | heights = np.array([0.03,0.2,0.2,0.03,0.2])+0.01 109 | for i in range(5): 110 | ax = fig.add_axes([0.15,sum(heights[i+1:]+0.02)+0.1,0.8,heights[i]]) 111 | cmap = 'Purples' 112 | plt.xticks([]) 113 | ax.tick_params(axis='both', which='major', labelsize=fs, 114 | width=0.5, length=2, pad=3) 115 | 116 | if plot_ylabel: 117 | ax.spines["right"].set_visible(False) 118 | ax.spines["bottom"].set_visible(False) 119 | ax.spines["top"].set_visible(False) 120 | ax.xaxis.set_ticks_position('bottom') 121 | ax.yaxis.set_ticks_position('left') 122 | 123 | else: 124 | ax.spines["left"].set_visible(False) 125 | ax.spines["right"].set_visible(False) 126 | ax.spines["bottom"].set_visible(False) 127 | ax.spines["top"].set_visible(False) 128 | ax.xaxis.set_ticks_position('none') 129 | 130 | if i == 0: 131 | plt.plot(t_plot, x[:,0,0], color='xkcd:blue') 132 | if plot_ylabel: 133 | plt.yticks([0,1],['',''],rotation='vertical') 134 | plt.ylim([-0.1,1.5]) 135 | plt.title(rule_name[rule],fontsize=fs) 136 | elif i == 1: 137 | plt.imshow(x[:,0,1:1+n_eachring].T, aspect='auto', cmap=cmap, 138 | vmin=0, vmax=1, interpolation='none',origin='lower') 139 | if plot_ylabel: 140 | plt.yticks([0, (n_eachring-1)/2, n_eachring-1], 141 | [r'0$\degree$',r'180$\degree$',r'360$\degree$'], 142 | rotation='vertical') 143 | elif i == 2: 144 | plt.imshow(x[:, 0, 1+n_eachring:1+2*n_eachring].T, 145 | aspect='auto', cmap=cmap, vmin=0, vmax=1, 146 | interpolation='none',origin='lower') 147 | 148 | if plot_ylabel: 149 | plt.yticks( 150 | [0, (n_eachring-1)/2, n_eachring-1], 151 | [r'0$\degree$', r'180$\degree$', r'360$\degree$'], 152 | rotation='vertical') 153 | elif i == 3: 154 | plt.plot(t_plot, y[:,0,0],color='xkcd:green') 155 | plt.plot(t_plot, y_hat[:,0,0],color='xkcd:blue') 156 | if plot_ylabel: 157 | plt.yticks([0.05,0.8],['',''],rotation='vertical') 158 | plt.ylim([-0.1,1.1]) 159 | elif i == 4: 160 | plt.imshow(y_hat[:, 0, 1:].T, aspect='auto', cmap=cmap, 161 | vmin=0, vmax=1, interpolation='none', origin='lower') 162 | if plot_ylabel: 163 | plt.yticks( 164 | [0, (n_eachring-1)/2, n_eachring-1], 165 | [r'0$\degree$', r'180$\degree$', r'360$\degree$'], 166 | rotation='vertical') 167 | plt.xticks([0,y_hat.shape[0]], ['0', '2']) 168 | plt.xlabel('Time (s)',fontsize=fs, labelpad=-3) 169 | ax.spines["bottom"].set_visible(True) 170 | 171 | if plot_ylabel: 172 | plt.ylabel(ylabels[i],fontsize=fs) 173 | else: 174 | plt.yticks([]) 175 | ax.get_yaxis().set_label_coords(-0.12,0.5) 176 | 177 | if save: 178 | save_name = 'figure/sample_'+rule_name[rule].replace(' ','')+'.pdf' 179 | plt.savefig(save_name, transparent=True) 180 | plt.show() 181 | 182 | # plt.figure() 183 | # _ = plt.plot(h_sample[:,0,:20]) 184 | # plt.show() 185 | # 186 | # plt.figure() 187 | # _ = plt.plot(y_sample[:,0,:]) 188 | # plt.show() 189 | 190 | 191 | def pretty_singleneuron_plot(model_dir, 192 | rules, 193 | neurons, 194 | epoch=None, 195 | save=False, 196 | ylabel_firstonly=True, 197 | trace_only=False, 198 | plot_stim_avg=False, 199 | save_name=''): 200 | """Plot the activity of a single neuron in time across many trials 201 | 202 | Args: 203 | model_dir: 204 | rules: rules to plot 205 | neurons: indices of neurons to plot 206 | epoch: epoch to plot 207 | save: save figure? 
208 | ylabel_firstonly: if True, only plot ylabel for the first rule in rules 209 | """ 210 | 211 | if isinstance(rules, str): 212 | rules = [rules] 213 | 214 | try: 215 | _ = iter(neurons) 216 | except TypeError: 217 | neurons = [neurons] 218 | 219 | h_tests = dict() 220 | model = Model(model_dir) 221 | hp = model.hp 222 | with tf.Session() as sess: 223 | model.restore() 224 | 225 | t_start = int(500/hp['dt']) 226 | 227 | for rule in rules: 228 | # Generate a batch of trial from the test mode 229 | trial = generate_trials(rule, hp, mode='test') 230 | feed_dict = tools.gen_feed_dict(model, trial, hp) 231 | h = sess.run(model.h, feed_dict=feed_dict) 232 | h_tests[rule] = h 233 | 234 | for neuron in neurons: 235 | h_max = np.max([h_tests[r][t_start:,:,neuron].max() for r in rules]) 236 | for j, rule in enumerate(rules): 237 | fs = 6 238 | fig = plt.figure(figsize=(1.0,0.8)) 239 | ax = fig.add_axes([0.35,0.25,0.55,0.55]) 240 | t_plot = np.arange(h_tests[rule][t_start:].shape[0])*hp['dt']/1000 241 | _ = ax.plot(t_plot, 242 | h_tests[rule][t_start:,:,neuron], lw=0.5, color='gray') 243 | 244 | if plot_stim_avg: 245 | # Plot stimulus averaged trace 246 | _ = ax.plot(np.arange(h_tests[rule][t_start:].shape[0])*hp['dt']/1000, 247 | h_tests[rule][t_start:,:,neuron].mean(axis=1), lw=1, color='black') 248 | 249 | if epoch is not None: 250 | e0, e1 = trial.epochs[epoch] 251 | e0 = e0 if e0 is not None else 0 252 | e1 = e1 if e1 is not None else h_tests[rule].shape[0] 253 | ax.plot([e0, e1], [h_max*1.15]*2, 254 | color='black',linewidth=1.5) 255 | figname = 'figure/trace_'+rule_name[rule]+epoch+save_name+'.pdf' 256 | else: 257 | figname = 'figure/trace_unit'+str(neuron)+rule_name[rule]+save_name+'.pdf' 258 | 259 | plt.ylim(np.array([-0.1, 1.2])*h_max) 260 | plt.xticks([0, 1.5]) 261 | plt.xlabel('Time (s)', fontsize=fs, labelpad=-5) 262 | plt.locator_params(axis='y', nbins=4) 263 | if j>0 and ylabel_firstonly: 264 | ax.set_yticklabels([]) 265 | else: 266 | plt.ylabel('Activity (a.u.)', fontsize=fs, labelpad=2) 267 | plt.title('Unit {:d} '.format(neuron) + rule_name[rule], fontsize=5) 268 | ax.tick_params(axis='both', which='major', labelsize=fs) 269 | ax.spines["right"].set_visible(False) 270 | ax.spines["top"].set_visible(False) 271 | ax.xaxis.set_ticks_position('bottom') 272 | ax.yaxis.set_ticks_position('left') 273 | if trace_only: 274 | ax.spines["left"].set_visible(False) 275 | ax.spines["bottom"].set_visible(False) 276 | ax.xaxis.set_ticks_position('none') 277 | ax.set_xlabel('') 278 | ax.set_ylabel('') 279 | ax.set_xticks([]) 280 | ax.set_yticks([]) 281 | ax.set_title('') 282 | 283 | if save: 284 | plt.savefig(figname, transparent=True) 285 | plt.show() 286 | 287 | 288 | def activity_histogram(model_dir, 289 | rules, 290 | title=None, 291 | save_name=None): 292 | """Plot the activity histogram.""" 293 | 294 | if isinstance(rules, str): 295 | rules = [rules] 296 | 297 | h_all = None 298 | model = Model(model_dir) 299 | hp = model.hp 300 | with tf.Session() as sess: 301 | model.restore() 302 | 303 | t_start = int(500/hp['dt']) 304 | 305 | for rule in rules: 306 | # Generate a batch of trial from the test mode 307 | trial = generate_trials(rule, hp, mode='test') 308 | feed_dict = tools.gen_feed_dict(model, trial, hp) 309 | h = sess.run(model.h, feed_dict=feed_dict) 310 | h = h[t_start:, :, :] 311 | if h_all is None: 312 | h_all = h 313 | else: 314 | h_all = np.concatenate((h_all, h), axis=1) 315 | 316 | # var = h_all.var(axis=0).mean(axis=0) 317 | # ind = var > 1e-2 318 | # h_plot = h_all[:, :,
ind].flatten() 319 | h_plot = h_all.flatten() 320 | 321 | fig = plt.figure(figsize=(1.5, 1.2)) 322 | ax = fig.add_axes([0.2, 0.2, 0.7, 0.6]) 323 | ax.hist(h_plot, bins=20, density=True) 324 | ax.set_xlabel('Activity', fontsize=7) 325 | [ax.spines[s].set_visible(False) for s in ['left', 'top', 'right']] 326 | ax.set_yticks([]) 327 | 328 | 329 | def schematic_plot(model_dir, rule=None): 330 | fontsize = 6 331 | 332 | rule = rule or 'dm1' 333 | 334 | model = Model(model_dir, dt=1) 335 | hp = model.hp 336 | 337 | with tf.Session() as sess: 338 | model.restore() 339 | trial = generate_trials(rule, hp, mode='test') 340 | feed_dict = tools.gen_feed_dict(model, trial, hp) 341 | x = trial.x 342 | h, y_hat = sess.run([model.h, model.y_hat], feed_dict=feed_dict) 343 | 344 | 345 | n_eachring = hp['n_eachring'] 346 | n_hidden = hp['n_rnn'] 347 | 348 | # Plot Stimulus 349 | fig = plt.figure(figsize=(1.0,1.2)) 350 | heights = np.array([0.06,0.25,0.25]) 351 | for i in range(3): 352 | ax = fig.add_axes([0.2,sum(heights[i+1:]+0.1)+0.05,0.7,heights[i]]) 353 | cmap = 'Purples' 354 | plt.xticks([]) 355 | 356 | # Fixed style for these plots 357 | ax.tick_params(axis='both', which='major', labelsize=fontsize, width=0.5, length=2, pad=3) 358 | ax.spines["left"].set_linewidth(0.5) 359 | ax.spines["right"].set_visible(False) 360 | ax.spines["bottom"].set_visible(False) 361 | ax.spines["top"].set_visible(False) 362 | ax.xaxis.set_ticks_position('bottom') 363 | ax.yaxis.set_ticks_position('left') 364 | 365 | if i == 0: 366 | plt.plot(x[:,0,0], color='xkcd:blue') 367 | plt.yticks([0, 1], ['', ''],rotation='vertical') 368 | plt.ylim([-0.1, 1.5]) 369 | plt.title('Fixation input', fontsize=fontsize, y=0.9) 370 | elif i == 1: 371 | plt.imshow(x[:, 0, 1:1+n_eachring].T, aspect='auto', cmap=cmap, 372 | vmin=0, vmax=1, interpolation='none',origin='lower') 373 | plt.yticks([0, (n_eachring-1)/2, n_eachring-1], 374 | [r'0$\degree$', '', r'360$\degree$'], 375 | rotation='vertical') 376 | plt.title('Stimulus mod 1', fontsize=fontsize, y=0.9) 377 | elif i == 2: 378 | plt.imshow(x[:, 0, 1+n_eachring:1+2*n_eachring].T, aspect='auto', 379 | cmap=cmap, vmin=0, vmax=1, 380 | interpolation='none', origin='lower') 381 | plt.yticks([0, (n_eachring-1)/2, n_eachring-1], ['', '', ''], 382 | rotation='vertical') 383 | plt.title('Stimulus mod 2', fontsize=fontsize, y=0.9) 384 | ax.get_yaxis().set_label_coords(-0.12,0.5) 385 | plt.savefig('figure/schematic_input.pdf',transparent=True) 386 | plt.show() 387 | 388 | # Plot Rule Inputs 389 | fig = plt.figure(figsize=(1.0, 0.5)) 390 | ax = fig.add_axes([0.2,0.3,0.7,0.45]) 391 | cmap = 'Purples' 392 | X = x[:, 0, 1+2*n_eachring:] 393 | plt.imshow(X.T, aspect='auto', vmin=0, vmax=1, cmap=cmap, 394 | interpolation='none', origin='lower') 395 | 396 | plt.xticks([0, X.shape[0]]) 397 | ax.set_xlabel('Time (ms)', fontsize=fontsize, labelpad=-5) 398 | 399 | # Fixed style for these plots 400 | ax.tick_params(axis='both', which='major', labelsize=fontsize, 401 | width=0.5, length=2, pad=3) 402 | ax.spines["left"].set_linewidth(0.5) 403 | ax.spines["right"].set_visible(False) 404 | ax.spines["bottom"].set_linewidth(0.5) 405 | ax.spines["top"].set_visible(False) 406 | ax.xaxis.set_ticks_position('bottom') 407 | ax.yaxis.set_ticks_position('left') 408 | 409 | plt.yticks([0, X.shape[-1]-1], ['1',str(X.shape[-1])], rotation='vertical') 410 | plt.title('Rule inputs', fontsize=fontsize, y=0.9) 411 | ax.get_yaxis().set_label_coords(-0.12,0.5) 412 | 413 | plt.savefig('figure/schematic_rule.pdf',transparent=True) 414 | 
plt.show() 415 | 416 | 417 | # Plot Units 418 | fig = plt.figure(figsize=(1.0, 0.8)) 419 | ax = fig.add_axes([0.2,0.1,0.7,0.75]) 420 | cmap = 'Purples' 421 | plt.xticks([]) 422 | # Fixed style for these plots 423 | ax.tick_params(axis='both', which='major', labelsize=fontsize, 424 | width=0.5, length=2, pad=3) 425 | ax.spines["left"].set_linewidth(0.5) 426 | ax.spines["right"].set_visible(False) 427 | ax.spines["bottom"].set_visible(False) 428 | ax.spines["top"].set_visible(False) 429 | ax.xaxis.set_ticks_position('bottom') 430 | ax.yaxis.set_ticks_position('left') 431 | 432 | plt.imshow(h[:, 0, :].T, aspect='auto', cmap=cmap, vmin=0, vmax=1, 433 | interpolation='none',origin='lower') 434 | plt.yticks([0,n_hidden-1],['1',str(n_hidden)],rotation='vertical') 435 | plt.title('Recurrent units', fontsize=fontsize, y=0.95) 436 | ax.get_yaxis().set_label_coords(-0.12,0.5) 437 | plt.savefig('figure/schematic_units.pdf',transparent=True) 438 | plt.show() 439 | 440 | 441 | # Plot Outputs 442 | fig = plt.figure(figsize=(1.0,0.8)) 443 | heights = np.array([0.1,0.45])+0.01 444 | for i in range(2): 445 | ax = fig.add_axes([0.2, sum(heights[i+1:]+0.15)+0.1, 0.7, heights[i]]) 446 | cmap = 'Purples' 447 | plt.xticks([]) 448 | 449 | # Fixed style for these plots 450 | ax.tick_params(axis='both', which='major', labelsize=fontsize, 451 | width=0.5, length=2, pad=3) 452 | ax.spines["left"].set_linewidth(0.5) 453 | ax.spines["right"].set_visible(False) 454 | ax.spines["bottom"].set_visible(False) 455 | ax.spines["top"].set_visible(False) 456 | ax.xaxis.set_ticks_position('bottom') 457 | ax.yaxis.set_ticks_position('left') 458 | 459 | if i == 0: 460 | plt.plot(y_hat[:,0,0],color='xkcd:blue') 461 | plt.yticks([0.05,0.8],['',''],rotation='vertical') 462 | plt.ylim([-0.1,1.1]) 463 | plt.title('Fixation output', fontsize=fontsize, y=0.9) 464 | 465 | elif i == 1: 466 | plt.imshow(y_hat[:,0,1:].T, aspect='auto', cmap=cmap, 467 | vmin=0, vmax=1, interpolation='none', origin='lower') 468 | plt.yticks([0, (n_eachring-1)/2, n_eachring-1], 469 | [r'0$\degree$', '', r'360$\degree$'], 470 | rotation='vertical') 471 | plt.xticks([]) 472 | plt.title('Response', fontsize=fontsize, y=0.9) 473 | 474 | ax.get_yaxis().set_label_coords(-0.12,0.5) 475 | 476 | plt.savefig('figure/schematic_outputs.pdf',transparent=True) 477 | plt.show() 478 | 479 | 480 | def networkx_illustration(model_dir): 481 | import networkx as nx 482 | 483 | model = Model(model_dir) 484 | with tf.Session() as sess: 485 | model.restore() 486 | # get all connection weights and biases as tensorflow variables 487 | w_rec = sess.run(model.w_rec) 488 | 489 | w_rec_flat = w_rec.flatten() 490 | ind_sort = np.argsort(abs(w_rec_flat - np.mean(w_rec_flat))) 491 | n_show = int(0.01*len(w_rec_flat)) 492 | ind_gone = ind_sort[:-n_show] 493 | ind_keep = ind_sort[-n_show:] 494 | w_rec_flat[ind_gone] = 0 495 | w_rec2 = np.reshape(w_rec_flat, w_rec.shape) 496 | w_rec_keep = w_rec_flat[ind_keep] 497 | G=nx.from_numpy_array(abs(w_rec2), create_using=nx.DiGraph()) 498 | 499 | color = w_rec_keep 500 | fig = plt.figure(figsize=(4, 4)) 501 | ax = fig.add_axes([0.1, 0.1, 0.8, 0.8]) 502 | nx.draw(G, 503 | linewidths=0, 504 | width=0.1, 505 | alpha=1.0, 506 | edge_vmin=-3, 507 | edge_vmax=3, 508 | arrows=False, 509 | pos=nx.circular_layout(G), 510 | node_color=np.array([99./255]*3), 511 | node_size=10, 512 | edge_color=color, 513 | edge_cmap=plt.cm.RdBu_r, 514 | ax=ax) 515 | plt.savefig('figure/illustration_networkx.pdf', transparent=True) 516 | 517 | 518 | if __name__ == "__main__": 519 
| root_dir = './data/train_all' 520 | model_dir = root_dir + '/0' 521 | 522 | # Rules to analyze 523 | # rule = 'dm1' 524 | # rule = ['dmsgo','dmsnogo','dmcgo','dmcnogo'] 525 | 526 | # Easy activity plot, see this function to begin your analysis 527 | # rule = 'contextdm1' 528 | # easy_activity_plot(model_dir, rule) 529 | 530 | # Easy connectivity plot 531 | # easy_connectivity_plot(model_dir) 532 | 533 | # Plot sample activity 534 | # pretty_inputoutput_plot(model_dir, rule, save=False) 535 | 536 | # Plot a single neuron's activity in time 537 | # pretty_singleneuron_plot(model_dir, rule, [0], epoch=None, save=False, 538 | # trace_only=True, plot_stim_avg=True) 539 | 540 | # Plot activity histogram 541 | # model_dir = '/Users/guangyuyang/MyPython/RecurrentNetworkTraining/multitask/data/varyhp/33' 542 | # activity_histogram(model_dir, ['contextdm1', 'contextdm2']) 543 | 544 | # Plot schematic 545 | # schematic_plot(model_dir, rule) 546 | 547 | # Plot networkx illustration 548 | # networkx_illustration(model_dir) 549 | 550 | 551 | -------------------------------------------------------------------------------- /analysis/taskset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Task set analysis 3 | Analyze the state-space of stimulus-averaged activity 4 | """ 5 | 6 | from __future__ import division 7 | 8 | import os 9 | import numpy as np 10 | import pickle 11 | from collections import OrderedDict 12 | import matplotlib.pyplot as plt 13 | 14 | import tensorflow as tf 15 | 16 | from task import rule_name 17 | from task import generate_trials 18 | from network import Model 19 | from network import get_perf 20 | import tools 21 | 22 | 23 | class TaskSetAnalysis(object): 24 | """Analyzing the representation of tasks.""" 25 | 26 | def __init__(self, model_dir, rules=None): 27 | """Initialization.
28 | 29 | Args: 30 | model_dir: str, model directory 31 | rules: None or a list of rules 32 | """ 33 | # Stimulus-averaged traces 34 | h_stimavg_byrule = OrderedDict() 35 | h_stimavg_byepoch = OrderedDict() 36 | # Last time points of epochs 37 | h_lastt_byepoch = OrderedDict() 38 | 39 | model = Model(model_dir) 40 | hp = model.hp 41 | 42 | if rules is None: 43 | # Default value 44 | rules = hp['rules'] 45 | n_rules = len(rules) 46 | 47 | with tf.Session() as sess: 48 | model.restore() 49 | 50 | for rule in rules: 51 | trial = generate_trials(rule=rule, hp=hp, mode='test') 52 | feed_dict = tools.gen_feed_dict(model, trial, hp) 53 | h = sess.run(model.h, feed_dict=feed_dict) 54 | 55 | # Average across stimulus conditions 56 | h_stimavg = h.mean(axis=1) 57 | 58 | # dt_new = 50 59 | # every_t = int(dt_new/hp['dt']) 60 | 61 | t_start = int(500/hp['dt']) # Important: Ignore the initial transition 62 | # Average across stimulus conditions 63 | h_stimavg_byrule[rule] = h_stimavg[t_start:, :] 64 | 65 | for e_name, e_time in trial.epochs.items(): 66 | if 'fix' in e_name: 67 | continue 68 | 69 | # if ('fix' not in e_name) and ('go' not in e_name): 70 | # Take epoch 71 | e_time_start = e_time[0]-1 if e_time[0]>0 else 0 72 | h_stimavg_byepoch[(rule, e_name)] = h_stimavg[e_time_start:e_time[1],:] 73 | # Take last time point from epoch 74 | # h_all_byepoch[(rule, e_name)] = np.mean(h[e_time[0]:e_time[1],:,:][-1], axis=1) 75 | h_lastt_byepoch[(rule, e_name)] = h[e_time[1],:,:] 76 | 77 | self.rules = rules 78 | self.h_stimavg_byrule = h_stimavg_byrule 79 | self.h_stimavg_byepoch = h_stimavg_byepoch 80 | self.h_lastt_byepoch = h_lastt_byepoch 81 | self.model_dir = model_dir 82 | 83 | @staticmethod 84 | def filter(h, rules=None, epochs=None, non_rules=None, non_epochs=None, 85 | get_lasttimepoint=True, get_timeaverage=False, **kwargs): 86 | # h should be a dictionary 87 | # get a new dictionary containing keys from the list of rules and epochs 88 | # And avoid epochs from non_rules and non_epochs 89 | # h_new = OrderedDict([(key, val) for key, val in h.items() if key[1] in epochs]) 90 | 91 | if get_lasttimepoint: 92 | print('Analyzing last time points of epochs') 93 | if get_timeaverage: 94 | print('Analyzing time-averaged activities of epochs') 95 | 96 | h_new = OrderedDict() 97 | for key in h: 98 | rule, epoch = key 99 | 100 | include_key = True 101 | if rules is not None: 102 | include_key = include_key and (rule in rules) 103 | 104 | if epochs is not None: 105 | include_key = include_key and (epoch in epochs) 106 | 107 | if non_rules is not None: 108 | include_key = include_key and (rule not in non_rules) 109 | 110 | if non_epochs is not None: 111 | include_key = include_key and (epoch not in non_epochs) 112 | 113 | if include_key: 114 | if get_lasttimepoint: 115 | h_new[key] = h[key][np.newaxis, -1, :] 116 | elif get_timeaverage: 117 | h_new[key] = np.mean(h[key], axis=0, keepdims=True) 118 | else: 119 | h_new[key] = h[key] 120 | 121 | return h_new 122 | 123 | def compute_and_plot_taskspace(self, rules=None, epochs=None, **kwargs): 124 | h_trans = self.compute_taskspace(rules=rules, epochs=epochs, **kwargs) 125 | self.plot_taskspace(h_trans, **kwargs) 126 | 127 | def compute_taskspace(self, rules=None, epochs=None, dim_reduction_type='MDS', **kwargs): 128 | # Only get last time points for each epoch 129 | h = self.filter(self.h_stimavg_byepoch, epochs=epochs, rules=rules, **kwargs) 130 | 131 | # Concatenate across rules to create dataset 132 | data = np.concatenate(list(h.values()), axis=0) 133 | data = 
data.astype(dtype='float64') 134 | 135 | # First reduce dimension to dimension of data points 136 | from sklearn.decomposition import PCA 137 | n_comp = int(np.min([data.shape[0], data.shape[1]])-1) 138 | model = PCA(n_components=n_comp) 139 | data = model.fit_transform(data) 140 | 141 | if dim_reduction_type == 'PCA': 142 | model = PCA(n_components=2) 143 | 144 | elif dim_reduction_type == 'MDS': 145 | from sklearn.manifold import MDS 146 | model = MDS(n_components=2, metric=True, random_state=0) 147 | 148 | elif dim_reduction_type == 'TSNE': 149 | from sklearn.manifold import TSNE 150 | model = TSNE(n_components=2, init='pca', 151 | verbose=1, method='exact', learning_rate=100, perplexity=5) 152 | 153 | elif dim_reduction_type == 'IsoMap': 154 | from sklearn.manifold import Isomap 155 | model = Isomap(n_components=2) 156 | 157 | else: 158 | raise ValueError('Unknown dim_reduction_type') 159 | 160 | # Transform data 161 | data_trans = model.fit_transform(data) 162 | 163 | # Package back to dictionary 164 | h_trans = OrderedDict() 165 | i_start = 0 166 | for key, val in h.items(): 167 | i_end = i_start + val.shape[0] 168 | h_trans[key] = data_trans[i_start:i_end, :] 169 | i_start = i_end 170 | 171 | return h_trans 172 | 173 | def plot_taskspace(self, h_trans, epochs=None, dim_reduction_type='MDS', 174 | plot_text=True, figsize=(4,4), markersize=5, plot_label=True): 175 | # Plot tasks in space 176 | shape_mapping = {'stim1' : 'o', 177 | 'stim2' : 'o', 178 | 'delay1' : 'v', 179 | 'delay2' : 'd', 180 | 'go1' : 's', 181 | 'fix1' : 'p'} 182 | 183 | from analysis.performance import rule_color 184 | 185 | fs = 6 # fontsize 186 | dim0, dim1 = (0, 1) # plot dimensions 187 | 188 | texts = list() 189 | 190 | fig = plt.figure(figsize=figsize) 191 | ax = fig.add_axes([0.05, 0.05, 0.9, 0.9]) 192 | 193 | for key, val in h_trans.items(): 194 | rule, epoch = key 195 | 196 | # Default coloring by rule_color 197 | color = rule_color[rule] 198 | 199 | ax.plot(val[-1, dim0], val[-1, dim1], shape_mapping[epoch], 200 | color=color, mec=color, mew=1.0, ms=markersize) 201 | 202 | if plot_text: 203 | texts.append(ax.text(val[-1, dim0]+0.03, val[-1, dim1]+0.03, rule_name[rule], 204 | fontsize=6, color=color)) 205 | 206 | if 'fix' not in epoch: 207 | ax.plot(val[:, dim0], val[:, dim1], color=color, alpha=0.5) 208 | 209 | if plot_label: 210 | if dim_reduction_type == 'PCA': 211 | xlabel = 'PC {:d}'.format(dim0+1) 212 | ylabel = 'PC {:d}'.format(dim1+1) 213 | else: 214 | xlabel = dim_reduction_type + ' dim. {:d}'.format(dim0+1) 215 | ylabel = dim_reduction_type + ' dim. 
{:d}'.format(dim1+1) 216 | ax.set_xlabel(xlabel, fontsize=fs) 217 | ax.set_ylabel(ylabel, fontsize=fs) 218 | ax.tick_params(axis='both', which='major', labelsize=fs) 219 | # plt.locator_params(nbins=3) 220 | ax.spines["right"].set_visible(False) 221 | ax.spines["top"].set_visible(False) 222 | ax.set_xticks([]) 223 | ax.set_yticks([]) 224 | ax.margins(0.1) 225 | # ax.xaxis.set_ticks_position('bottom') 226 | # ax.yaxis.set_ticks_position('left') 227 | 228 | save_name = 'taskspace'+dim_reduction_type 229 | 230 | if epochs is not None: 231 | save_name = save_name + ''.join(epochs) 232 | 233 | plt.savefig(os.path.join('figure', save_name+'.pdf'), transparent=True) 234 | plt.show() 235 | 236 | 237 | def compute_taskspace(model_dir, setup, restore=False, representation='rate'): 238 | if setup == 1: 239 | rules = ['fdgo', 'fdanti', 'delaygo', 'delayanti'] 240 | elif setup == 2: 241 | rules = ['contextdelaydm1', 'contextdelaydm2', 'contextdm1', 'contextdm2'] 242 | elif setup == 3: 243 | rules = ['dmsgo', 'dmcgo', 'dmsnogo', 'dmcnogo'] 244 | elif setup == 4: 245 | rules = ['contextdelaydm1', 'contextdelaydm2', 'multidelaydm', 246 | 'contextdm1', 'contextdm2', 'multidm'] 247 | elif setup == 5: 248 | rules = ['contextdelaydm1', 'contextdelaydm2', 'multidelaydm', 249 | 'delaydm1', 'delaydm2', 'contextdm1', 'contextdm2', 250 | 'multidm', 'dm1', 'dm2',] 251 | elif setup == 6: 252 | rules = ['fdgo', 'delaygo', 'contextdm1', 'contextdelaydm1'] 253 | 254 | if representation == 'rate': 255 | fname = 'taskset{:d}_space'.format(setup)+'.pkl' 256 | fname = os.path.join(model_dir, fname) 257 | 258 | if restore and os.path.isfile(fname): 259 | print('Reloading results from '+fname) 260 | h_trans = tools.load_pickle(fname) 261 | else: 262 | tsa = TaskSetAnalysis(model_dir, rules=rules) 263 | h_trans = tsa.compute_taskspace(rules=rules, epochs=['stim1'], 264 | dim_reduction_type='PCA', setup=setup) 265 | with open(fname, 'wb') as f: 266 | pickle.dump(h_trans, f) 267 | print('Results stored at : '+fname) 268 | 269 | elif representation == 'weight': 270 | from task import get_rule_index 271 | 272 | model = Model(model_dir) 273 | hp = model.hp 274 | n_hidden = hp['n_rnn'] 275 | n_output = hp['n_output'] 276 | with tf.Session() as sess: 277 | model.restore() 278 | w_in = sess.run(model.w_in).T 279 | 280 | rule_indices = [get_rule_index(r, hp) for r in rules] 281 | w_rules = w_in[:, rule_indices] 282 | 283 | from sklearn.decomposition import PCA 284 | model = PCA(n_components=2) 285 | 286 | # Transform data 287 | data_trans = model.fit_transform(w_rules.T) 288 | 289 | # Turn into dictionary, and consistent with previous code 290 | h_trans = OrderedDict() 291 | for i, r in enumerate(rules): 292 | # shape will be (1,2), and the key is added an epoch value only for consistency 293 | h_trans[(r,'stim1')] = np.array([data_trans[i]]) 294 | 295 | else: 296 | raise ValueError() 297 | 298 | return h_trans 299 | 300 | 301 | def _plot_taskspace(h_trans, fig_name='temp', plot_example=False, lxy=None, 302 | plot_arrow=True, **kwargs): 303 | from analysis.performance import rule_color 304 | figsize = (1.7,1.7) 305 | fs = 7 # fontsize 306 | dim0, dim1 = (0, 1) # plot dimensions 307 | i_example = 0 # index of the example to plot 308 | 309 | texts = list() 310 | 311 | maxv0, maxv1 = -1, -1 312 | 313 | fig = plt.figure(figsize=figsize) 314 | ax = fig.add_axes([0.2, 0.2, 0.65, 0.65]) 315 | 316 | for key, val in h_trans.items(): 317 | rule, epoch = key 318 | # Default coloring by rule_color 319 | color = rule_color[rule] 320 | 321 | if 
plot_example: 322 | xplot, yplot = val[i_example,dim0], val[i_example,dim1] 323 | else: 324 | xplot, yplot = val[:,dim0], val[:,dim1] 325 | 326 | ax.plot(xplot, yplot, 'o', color=color, mec=color, mew=1.0, ms=2) 327 | 328 | 329 | xtext = np.mean(val[:,dim0]) 330 | if np.mean(val[:,dim1])>0: 331 | ytext = np.max(val[:,dim1]) 332 | va = 'bottom' 333 | else: 334 | ytext = np.min(val[:,dim1]) 335 | va = 'top' 336 | 337 | texts.append(ax.text(xtext*1.1, ytext*1.1, rule_name[rule], 338 | fontsize=6, color=color, 339 | horizontalalignment='center', verticalalignment=va)) 340 | 341 | maxv0 = np.max([maxv0, np.max(abs(val[:,dim0]))]) 342 | maxv1 = np.max([maxv1, np.max(abs(val[:,dim1]))]) 343 | 344 | if kwargs['setup'] == 1: 345 | arrow_starts = [h_trans[('fdgo','stim1')], h_trans[('fdanti','stim1')]] 346 | arrow_ends = [h_trans[('delaygo','stim1')], 347 | h_trans[('delayanti','stim1')]] 348 | elif kwargs['setup'] == 2: 349 | arrow_starts = [h_trans[('contextdm1','stim1')], 350 | h_trans[('contextdelaydm1','stim1')]] 351 | arrow_ends = [h_trans[('contextdm2','stim1')], 352 | h_trans[('contextdelaydm2','stim1')]] 353 | elif kwargs['setup'] == 3: 354 | arrow_starts = [h_trans[('dmsgo','stim1')], 355 | h_trans[('dmsnogo','stim1')]] 356 | arrow_ends = [h_trans[('dmcgo','stim1')], 357 | h_trans[('dmcnogo','stim1')]] 358 | else: 359 | plot_arrow = False 360 | 361 | if plot_arrow: 362 | for arrow_start, arrow_end in zip(arrow_starts, arrow_ends): 363 | if plot_example: 364 | a_start = arrow_start[i_example,[dim0, dim1]] 365 | a_end = arrow_end[i_example,[dim0, dim1]] 366 | else: 367 | a_start = arrow_start[:,[dim0, dim1]].mean(axis=0) 368 | a_end = arrow_end[:,[dim0, dim1]].mean(axis=0) 369 | ax.annotate("", xy=a_start, xytext=a_end, 370 | arrowprops=dict(arrowstyle="<-", ec='gray')) 371 | 372 | if lxy is None: 373 | lx = np.ceil(maxv0) 374 | ly = np.ceil(maxv1) 375 | else: 376 | lx, ly = lxy 377 | 378 | ax.tick_params(axis='both', which='major', labelsize=fs) 379 | # plt.locator_params(nbins=3) 380 | ax.spines["right"].set_visible(False) 381 | ax.spines["top"].set_visible(False) 382 | # ax.set_xticks([]) 383 | # ax.set_yticks([]) 384 | ax.margins(0.1) 385 | # plt.axis('equal') 386 | plt.xlim([-lx,lx]) 387 | plt.ylim([-ly,ly]) 388 | ax.plot([0,0], [-ly,ly], '--', color='gray') 389 | ax.plot([-lx,lx], [0,0], '--', color='gray') 390 | ax.set_xticks([-lx,lx]) 391 | ax.set_yticks([-ly,ly]) 392 | ax.xaxis.set_ticks_position('bottom') 393 | ax.yaxis.set_ticks_position('left') 394 | pc_name = 'rPC' 395 | ax.set_xlabel(pc_name+' {:d}'.format(dim0+1), fontsize=fs, labelpad=-5) 396 | ax.set_ylabel(pc_name+' {:d}'.format(dim1+1), fontsize=fs, labelpad=-5) 397 | 398 | plt.savefig(os.path.join('figure', fig_name+'.pdf'), transparent=True) 399 | plt.show() 400 | 401 | return (lx, ly) 402 | 403 | 404 | def plot_taskspace(model_dir, setup=1, restore=True, representation='rate'): 405 | h_trans = compute_taskspace( 406 | model_dir, setup, restore=restore, representation=representation) 407 | save_name = 'taskset{:d}_space'.format(setup) 408 | _plot_taskspace(h_trans, save_name, setup=setup) 409 | 410 | 411 | def plot_taskspace_group(root_dir, setup=1, restore=True, 412 | representation='rate', fig_name_addon=None): 413 | """Plot task space for a group of networks. 
414 | 415 | Args: 416 | root_dir : the root directory for all models to analyse 417 | setup: int, the combination of rules to use 418 | restore: bool, whether to restore results 419 | representation: 'rate' or 'weight' 420 | """ 421 | 422 | model_dirs = tools.valid_model_dirs(root_dir) 423 | print('Analyzing models : ') 424 | print(model_dirs) 425 | 426 | h_trans_all = OrderedDict() 427 | i = 0 428 | for model_dir in model_dirs: 429 | try: 430 | h_trans = compute_taskspace(model_dir, setup, 431 | restore=restore, 432 | representation=representation) 433 | except ValueError: 434 | print('Skipping model at ' + model_dir) 435 | continue 436 | 437 | h_trans_values = list(h_trans.values()) 438 | 439 | # When PC1 and PC2 capture similar variances, allow for a rotation 440 | # rotation_matrix, clock wise 441 | get_angle = lambda vec : np.arctan2(vec[1], vec[0]) 442 | theta = get_angle(h_trans_values[0][0]) 443 | # theta = 0 444 | rot_mat = np.array([[np.cos(theta), -np.sin(theta)], 445 | [np.sin(theta), np.cos(theta)]]) 446 | 447 | for key, val in h_trans.items(): 448 | h_trans[key] = np.dot(val, rot_mat) 449 | 450 | h_trans_values = list(h_trans.values()) 451 | if h_trans_values[1][0][1] < 0: 452 | for key, val in h_trans.items(): 453 | h_trans[key] = val*np.array([1, -1]) 454 | 455 | if i == 0: 456 | for key, val in h_trans.items(): 457 | h_trans_all[key] = val 458 | else: 459 | for key, val in h_trans.items(): 460 | h_trans_all[key] = np.concatenate((h_trans_all[key], val), axis=0) 461 | i += 1 462 | fig_name = 'taskset{:d}_{:s}space'.format(setup, representation) 463 | if fig_name_addon is not None: 464 | fig_name = fig_name + fig_name_addon 465 | 466 | lxy = _plot_taskspace(h_trans_all, fig_name, setup=setup) 467 | fig_name = fig_name + '_example' 468 | lxy = _plot_taskspace(h_trans_all, fig_name, setup=setup, 469 | plot_example=True, lxy=lxy) 470 | 471 | 472 | def run_network_replacerule(model_dir, rule, replace_rule, rule_strength): 473 | """Run the network but with replaced rule input weights. 
474 | 475 | Args: 476 | model_dir: model directory 477 | rule: the rule to test on 478 | replace_rule: a list of rule input units to use 479 | rule_strength: the relative strength of each replace rule unit 480 | """ 481 | model = Model(model_dir) 482 | hp = model.hp 483 | with tf.Session() as sess: 484 | model.restore() 485 | 486 | # Get performance 487 | batch_size_test = 1000 488 | n_rep = 20 489 | batch_size_test_rep = int(batch_size_test/n_rep) 490 | perf_rep = list() 491 | for i_rep in range(n_rep): 492 | trial = generate_trials(rule, hp, 'random', batch_size=batch_size_test_rep, 493 | replace_rule=replace_rule, rule_strength=rule_strength) 494 | feed_dict = tools.gen_feed_dict(model, trial, hp) 495 | y_hat_test = sess.run(model.y_hat, feed_dict=feed_dict) 496 | 497 | perf_rep.append(np.mean(get_perf(y_hat_test, trial.y_loc))) 498 | 499 | return np.mean(perf_rep), rule_strength 500 | 501 | 502 | def replace_rule_name(replace_rule, rule_strength): 503 | """Helper function to replace rule name""" 504 | # little helper function 505 | name = '' 506 | counter = 0 507 | for r, b in zip(replace_rule, rule_strength): 508 | if b != 0: 509 | 510 | if b == 1: 511 | if counter==0: 512 | prefix = '' 513 | else: 514 | prefix = '+' 515 | elif b == -1: 516 | prefix = '-' 517 | else: 518 | prefix = '{:+d}'.format(b) 519 | name += prefix + rule_name[r] + '\n' 520 | counter += 1 521 | # get rid of the last \n 522 | name = name[:-1] 523 | return name 524 | 525 | 526 | def compute_replacerule_performance(model_dir, setup, restore=False): 527 | """Compute the performance of one task given a replaced rule input.""" 528 | 529 | if setup == 1: 530 | rule = 'delayanti' 531 | replace_rule = np.array(['delayanti', 'fdanti', 'delaygo', 'fdgo']) 532 | 533 | rule_strengths = \ 534 | [[1,0,0,0], 535 | [0,1,0,0], 536 | [0,1,1,0], 537 | [0,1,1,-1]] 538 | 539 | elif setup == 2: 540 | rule = 'contextdelaydm1' 541 | replace_rule = np.array(['contextdelaydm1', 'contextdelaydm2', 542 | 'contextdm1', 'contextdm2']) 543 | 544 | rule_strengths = \ 545 | [[1,0,0,0], 546 | [0,1,0,0], 547 | [0,1,1,0], 548 | [0,0,1,0], 549 | [0,1,1,-1]] 550 | 551 | elif setup == 3: 552 | rule = 'dmsgo' 553 | replace_rule = np.array(['dmsgo', 'dmcgo', 'dmsnogo', 'dmcnogo']) 554 | rule_strengths = \ 555 | [[1,0,0,0], 556 | [0,1,0,0], 557 | [0,1,1,0], 558 | [0,1,1,-1]] 559 | 560 | else: 561 | raise ValueError('Unknown setup value') 562 | 563 | fname = 'taskset{:d}_perf'.format(setup)+'.pkl' 564 | fname = os.path.join(model_dir, fname) 565 | 566 | if restore and os.path.isfile(fname): 567 | print('Reloading results from '+fname) 568 | r = tools.load_pickle(fname) 569 | perfs, rule, names = r['perfs'], r['rule'], r['names'] 570 | 571 | else: 572 | perfs = list() 573 | names = list() 574 | for rule_strength in rule_strengths: 575 | perf, _ = run_network_replacerule(model_dir, rule, replace_rule, rule_strength) 576 | perfs.append(perf) 577 | names.append(replace_rule_name(replace_rule, rule_strength)) 578 | 579 | perfs = np.array(perfs) 580 | print(perfs) 581 | 582 | results = {'perfs':perfs, 'rule':rule, 'names':names} 583 | with open(fname, 'wb') as f: 584 | pickle.dump(results, f) 585 | print('Results stored at : '+fname) 586 | 587 | return perfs, rule, names 588 | 589 | 590 | def _plot_replacerule_performance(perfs_all, rule, names, setup, fig_name=None): 591 | perfs_all = perfs_all.T # make it (4, n_nets) 592 | from scipy.stats import mannwhitneyu 593 | print(mannwhitneyu(perfs_all[-1], perfs_all[-2])) 594 | print(mannwhitneyu(perfs_all[-1], 
perfs_all[-3])) 595 | 596 | n_condition, n_net = perfs_all.shape 597 | fs = 7 598 | fig = plt.figure(figsize=(1.6,2.2)) 599 | ax = fig.add_axes([0.55,0.05,0.35,0.7]) 600 | 601 | bp = ax.boxplot(list(perfs_all[::-1]), notch=True, vert=False, bootstrap=10000, 602 | showcaps=False, patch_artist=True, widths=0.4, 603 | flierprops={'markersize': 2}, whiskerprops={'linewidth': 1.5}) 604 | for element in ['boxes', 'whiskers', 'fliers']: 605 | plt.setp(bp[element], color='xkcd:cerulean') 606 | for patch in bp['boxes']: 607 | patch.set_facecolor('xkcd:cerulean') 608 | for element in ['means', 'medians']: 609 | plt.setp(bp[element], color='white') 610 | 611 | ax.set_yticks(np.arange(1, 1+n_condition)) 612 | ax.set_yticklabels(names[::-1], rotation=0, horizontalalignment='right') 613 | ax.set_ylabel('Rule input', fontsize=fs, labelpad=3) 614 | # ax.set_ylabel('performance', fontsize=fs) 615 | title = 'Performance on\n'+rule_name[rule] 616 | if perfs_all is not None: 617 | n_net = perfs_all.shape[1] 618 | title = title + ' (n={:d})'.format(n_net) 619 | ax.set_title(title, fontsize=fs, y=1.13) 620 | ax.tick_params(axis='both', which='major', labelsize=fs) 621 | ax.spines["right"].set_visible(False) 622 | ax.spines["bottom"].set_visible(False) 623 | ax.xaxis.set_ticks_position('top') 624 | ax.yaxis.set_ticks_position('left') 625 | ax.xaxis.grid(True) 626 | ax.set_xticks([0,0.5,1.0]) 627 | ax.set_xlim([0, 1.05]) 628 | ax.set_ylim([0.5, n_condition+0.5]) 629 | if fig_name is None: 630 | fig_name = 'taskset{:d}_perf'.format(setup) 631 | plt.savefig(os.path.join('figure', fig_name+'.pdf'), transparent=True) 632 | plt.show() 633 | 634 | 635 | def plot_replacerule_performance( 636 | model_dir,setup, perfs_all=None, fig_name=None, restore=True): 637 | perfs, rule, names = compute_replacerule_performance( 638 | model_dir, setup, restore) 639 | _plot_replacerule_performance( 640 | perfs_all, rule, names, setup, fig_name) 641 | 642 | 643 | def plot_replacerule_performance_group(model_dir, setup=1, restore=True, fig_name_addon=None): 644 | model_dirs = tools.valid_model_dirs(model_dir) 645 | print('Analyzing models : ') 646 | print(model_dirs) 647 | 648 | perfs_plot = list() 649 | for model_dir in model_dirs: 650 | perfs, rule, names = compute_replacerule_performance(model_dir, setup, restore) 651 | perfs_plot.append(perfs) 652 | 653 | perfs_plot = np.array(perfs_plot) 654 | perfs_median = np.median(perfs_plot, axis=0) 655 | 656 | fig_name = 'taskset{:d}_perf'.format(setup) 657 | if fig_name_addon is not None: 658 | fig_name = fig_name + fig_name_addon 659 | 660 | print(perfs_median) 661 | _plot_replacerule_performance(perfs_plot, rule, names, setup, fig_name=fig_name) 662 | 663 | 664 | if __name__ == '__main__': 665 | root_dir = './data/train_all' 666 | model_dir = root_dir + '/0' 667 | setups = [3] 668 | for setup in setups: 669 | pass 670 | plot_taskspace_group(root_dir, setup=setup, 671 | restore=True, representation='rate') 672 | plot_taskspace_group(root_dir, setup=setup, 673 | restore=True, representation='weight') 674 | plot_replacerule_performance_group( 675 | root_dir, setup=setup, restore=True) -------------------------------------------------------------------------------- /analysis/variance.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compute Variance 3 | """ 4 | 5 | from __future__ import division 6 | 7 | import os 8 | import time 9 | import numpy as np 10 | import pickle 11 | from collections import OrderedDict 12 | import matplotlib as mpl 13 | 
import matplotlib.pyplot as plt 14 | import tensorflow as tf 15 | 16 | from task import * 17 | from network import Model 18 | import tools 19 | 20 | save = True 21 | 22 | 23 | def _compute_variance_bymodel(model, sess, rules=None, random_rotation=False): 24 | """Compute variance for all tasks. 25 | 26 | Args: 27 | model: network.Model instance 28 | sess: tensorflow session 29 | rules: list of rules to compute variance, list of strings 30 | random_rotation: boolean. If True, rotate the neural activity. 31 | """ 32 | h_all_byrule = OrderedDict() 33 | h_all_byepoch = OrderedDict() 34 | hp = model.hp 35 | 36 | if rules is None: 37 | rules = hp['rules'] 38 | print(rules) 39 | 40 | n_hidden = hp['n_rnn'] 41 | 42 | if random_rotation: 43 | # Generate random orthogonal matrix 44 | from scipy.stats import ortho_group 45 | random_ortho_matrix = ortho_group.rvs(dim=n_hidden) 46 | 47 | for rule in rules: 48 | trial = generate_trials(rule, hp, 'test', noise_on=False) 49 | feed_dict = tools.gen_feed_dict(model, trial, hp) 50 | h = sess.run(model.h, feed_dict=feed_dict) 51 | if random_rotation: 52 | h = np.dot(h, random_ortho_matrix) # randomly rotate 53 | 54 | for e_name, e_time in trial.epochs.items(): 55 | if 'fix' not in e_name: # Ignore fixation period 56 | h_all_byepoch[(rule, e_name)] = h[e_time[0]:e_time[1], :, 57 | :] 58 | 59 | # Ignore fixation period 60 | h_all_byrule[rule] = h[trial.epochs['fix1'][1]:, :, :] 61 | 62 | # Reorder h_all_byepoch by epoch-first 63 | keys = list(h_all_byepoch.keys()) 64 | # ind_key_sort = np.lexsort(zip(*keys)) 65 | # Using mergesort because it is stable 66 | ind_key_sort = np.argsort(list(zip(*keys))[1], kind='mergesort') 67 | h_all_byepoch = OrderedDict( 68 | [(keys[i], h_all_byepoch[keys[i]]) for i in ind_key_sort]) 69 | 70 | for data_type in ['rule', 'epoch']: 71 | if data_type == 'rule': 72 | h_all = h_all_byrule 73 | elif data_type == 'epoch': 74 | h_all = h_all_byepoch 75 | else: 76 | raise ValueError 77 | 78 | h_var_all = np.zeros((n_hidden, len(h_all.keys()))) 79 | for i, val in enumerate(h_all.values()): 80 | # val is Time, Batch, Units 81 | # Variance across time and stimulus 82 | # h_var_all[:, i] = val[t_start:].reshape((-1, n_hidden)).var(axis=0) 83 | # Variance acros stimulus, then averaged across time 84 | h_var_all[:, i] = val.var(axis=1).mean(axis=0) 85 | 86 | result = {'h_var_all': h_var_all, 'keys': list(h_all.keys())} 87 | save_name = 'variance_' + data_type 88 | if random_rotation: 89 | save_name += '_rr' 90 | 91 | fname = os.path.join(model.model_dir, save_name + '.pkl') 92 | print('Variance saved at {:s}'.format(fname)) 93 | with open(fname, 'wb') as f: 94 | pickle.dump(result, f) 95 | 96 | 97 | def _compute_variance(model_dir, rules=None, random_rotation=False): 98 | """Compute variance for all tasks. 99 | 100 | Args: 101 | model_dir: str, the path of the model directory 102 | rules: list of rules to compute variance, list of strings 103 | random_rotation: boolean. If True, rotate the neural activity. 104 | """ 105 | model = Model(model_dir, sigma_rec=0) 106 | with tf.Session() as sess: 107 | model.restore() 108 | _compute_variance_bymodel(model, sess, rules, random_rotation) 109 | 110 | 111 | def compute_variance(model_dir, rules=None, random_rotation=False): 112 | """Compute variance for all tasks. 113 | 114 | Args: 115 | model_dir: str, the path of the model directory 116 | rules: list of rules to compute variance, list of strings 117 | random_rotation: boolean. If True, rotate the neural activity. 
118 | """ 119 | dirs = tools.valid_model_dirs(model_dir) 120 | for d in dirs: 121 | _compute_variance(d, rules, random_rotation) 122 | 123 | 124 | def _compute_hist_varprop(model_dir, rule_pair, random_rotation=False): 125 | data_type = 'rule' 126 | assert len(rule_pair) == 2 127 | assert data_type == 'rule' 128 | 129 | fname = os.path.join(model_dir, 'variance_'+data_type) 130 | if random_rotation: 131 | fname += '_rr' 132 | fname += '.pkl' 133 | if not os.path.isfile(fname): 134 | # If not computed, compute now 135 | compute_variance(model_dir, random_rotation=random_rotation) 136 | 137 | res = tools.load_pickle(fname) 138 | h_var_all = res['h_var_all'] 139 | keys = res['keys'] 140 | 141 | ind_rules = [keys.index(rule) for rule in rule_pair] 142 | h_var_all = h_var_all[:, ind_rules] 143 | 144 | # First only get active units. Total variance across tasks larger than 1e-3 145 | ind_active = np.where(h_var_all.sum(axis=1) > 1e-3)[0] 146 | 147 | # Temporary: Mimicking biased sampling. Notice the free parameter though. 148 | # print('Mimicking selective sampling') 149 | # ind_active = np.where((h_var_all.sum(axis=1) > 1e-3)*(h_var_all[:,0]>1*1e-2))[0] 150 | 151 | h_var_all = h_var_all[ind_active, :] 152 | 153 | # Normalize by the total variance across tasks 154 | h_normvar_all = (h_var_all.T/np.sum(h_var_all, axis=1)).T 155 | 156 | # Plot the proportion of variance for the first rule 157 | # data_plot = h_normvar_all[:, 0] 158 | data_plot = (h_var_all[:, 0]-h_var_all[:, 1])/((h_var_all[:, 0]+h_var_all[:, 1])) 159 | hist, bins_edge = np.histogram(data_plot, bins=20, range=(-1,1)) 160 | 161 | # # Plot the percentage instead of the total count 162 | # hist = hist/np.sum(hist) 163 | 164 | return hist, bins_edge 165 | 166 | 167 | def compute_hist_varprop(model_dir, rule_pair, random_rotation=False): 168 | data_type = 'rule' 169 | assert len(rule_pair) == 2 170 | assert data_type == 'rule' 171 | 172 | model_dirs = tools.valid_model_dirs(model_dir) 173 | 174 | hists = list() 175 | for model_dir in model_dirs: 176 | hist, bins_edge_ = _compute_hist_varprop(model_dir, rule_pair, random_rotation) 177 | if hist is None: 178 | continue 179 | else: 180 | bins_edge = bins_edge_ 181 | 182 | # Store 183 | hists.append(hist) 184 | 185 | # Get median of all histogram 186 | hists = np.array(hists) 187 | # hist_low, hist_med, hist_high = np.percentile(hists, [10, 50, 90], axis=0) 188 | 189 | return hists, bins_edge 190 | 191 | def _plot_hist_varprop(hist_plot, bins_edge, rule_pair, hist_example=None, 192 | plot_legend=False, figname=None, title=None): 193 | '''Plot histogram of fractional variance''' 194 | # Plot the percentage instead of the total count 195 | hist_plot = hist_plot/np.sum(hist_plot) 196 | if hist_example is not None: 197 | hist_example = hist_example/np.sum(hist_example) 198 | 199 | fs = 6 200 | fig = plt.figure(figsize=(1.5,1.2)) 201 | ax = fig.add_axes([0.2,0.3,0.6,0.5]) 202 | legends = list() 203 | labels = list() 204 | if hist_example is not None: 205 | pass 206 | bar = ax.bar(bins_edge[:-1], hist_example, width=bins_edge[1]-bins_edge[0], 207 | color='xkcd:cerulean', edgecolor='none') 208 | legends.append(bar) 209 | labels.append('Example network') 210 | pl, = ax.plot((bins_edge[:-1]+bins_edge[1:])/2, hist_plot, color='black', linewidth=1.5) 211 | legends.append(pl) 212 | labels.append('All networks') 213 | # ax.plot((bins_edge[:-1]+bins_edge[1:])/2, hist_low) 214 | # ax.plot((bins_edge[:-1]+bins_edge[1:])/2, hist_high) 215 | plt.locator_params(nbins=3) 216 | xlabel = 'FTV({:s}, 
{:s})'.format(rule_name[rule_pair[0]], rule_name[rule_pair[1]]) 217 | ax.set_xlabel(xlabel, fontsize=fs) 218 | ax.set_ylim(bottom=-0.02*hist_plot.max()) 219 | ax.set_xlim([-1.1,1.1]) 220 | ax.spines["top"].set_visible(False) 221 | ax.spines["right"].set_visible(False) 222 | ax.xaxis.set_ticks_position('bottom') 223 | ax.yaxis.set_ticks_position('left') 224 | ax.tick_params(axis='both', which='major', labelsize=fs, length=2) 225 | if title: 226 | ax.set_title(title, fontsize=7) 227 | if plot_legend: 228 | lg = plt.legend(legends, labels, ncol=1,bbox_to_anchor=(1.1,1.3), 229 | fontsize=fs,labelspacing=0.3,loc=1, frameon=False) 230 | plt.setp(lg.get_title(),fontsize=fs) 231 | if save: 232 | if figname is None: 233 | figname = 'plot_hist_varprop_tmp.pdf' 234 | plt.savefig(os.path.join('figure', figname), transparent=True) 235 | 236 | 237 | def plot_hist_varprop(model_dir, 238 | rule_pair, 239 | plot_example=False, 240 | figname_extra=None, 241 | **kwargs): 242 | """ 243 | Plot histogram of proportion of variance for some tasks across units 244 | 245 | Args: 246 | model_dir: model directory 247 | rule_pair: tuple of strings, pair of rules 248 | plot_example: bool 249 | figname_extra: string or None 250 | """ 251 | 252 | hists, bins_edge = compute_hist_varprop(model_dir, rule_pair) 253 | 254 | hist_low, hist_med, hist_high = np.percentile(hists, [10, 50, 90], axis=0) 255 | 256 | # hist_med, bins_edge = np.histogram(data_plots, bins=20, range=(0,1)) 257 | # hist_med = np.array(hist_med)/len(hdims) 258 | 259 | if plot_example: 260 | hist_example = hists[0] 261 | else: 262 | hist_example = None 263 | 264 | hist_plot = hist_med 265 | figname = 'plot_hist_varprop' + rule_pair[0] + rule_pair[1] 266 | figname = figname.replace('*','') 267 | if figname_extra: 268 | figname += figname_extra 269 | _plot_hist_varprop(hist_plot, bins_edge, rule_pair=rule_pair, 270 | hist_example=hist_example, figname=figname+'.pdf', 271 | **kwargs) 272 | 273 | 274 | def plot_hist_varprop_selection(model_dir, figname_extra=None): 275 | rule_pair_list = [('dm1', 'dm2'), 276 | ('contextdm1', 'contextdm2'), 277 | ('dm1', 'fdanti'), 278 | ('dm1', 'contextdm1'), 279 | ('fdgo', 'reactgo'), 280 | ('delaydm1', 'dm1'), 281 | ('dmcgo', 'dmcnogo'), 282 | ('contextdm1', 'contextdelaydm1')] 283 | for rule_pair in rule_pair_list: 284 | plot_hist_varprop(model_dir=model_dir, 285 | rule_pair=rule_pair, 286 | plot_legend=(rule_pair==('dm1', 'fdanti')), 287 | plot_example=True, 288 | figname_extra=figname_extra) 289 | 290 | 291 | def plot_hist_varprop_all(model_dir, plot_control=True): 292 | ''' 293 | Plot histogram of proportion of variance for some tasks across units 294 | :param save_name: 295 | :param data_type: 296 | :param rule_pair: list of rule_pair. 
Show proportion of variance for the first rule 297 | :return: 298 | ''' 299 | 300 | model_dirs = tools.valid_model_dirs(model_dir) 301 | 302 | hp = tools.load_hp(model_dirs[0]) 303 | rules = hp['rules'] 304 | 305 | figsize = (7, 7) 306 | 307 | # For testing 308 | # rules, figsize = ['fdgo','reactgo','delaygo', 'fdanti', 'reactanti'], (4, 4) 309 | 310 | fs = 6 # fontsize 311 | 312 | f, axarr = plt.subplots(len(rules), len(rules), figsize=figsize) 313 | plt.subplots_adjust(left=0.1, right=0.98, bottom=0.02, top=0.9) 314 | 315 | for i in range(len(rules)): 316 | for j in range(len(rules)): 317 | ax = axarr[i, j] 318 | if i == 0: 319 | ax.set_title(rule_name[rules[j]], fontsize=fs, rotation=45, va='bottom') 320 | if j == 0: 321 | ax.set_ylabel(rule_name[rules[i]], fontsize=fs, rotation=45, ha='right') 322 | 323 | ax.spines["right"].set_visible(False) 324 | ax.spines["left"].set_visible(False) 325 | ax.spines["top"].set_visible(False) 326 | if i == j: 327 | ax.spines["bottom"].set_visible(False) 328 | ax.set_xticks([]) 329 | ax.set_yticks([]) 330 | continue 331 | 332 | hists, bins_edge = compute_hist_varprop(model_dir, (rules[i], rules[j])) 333 | hist_low, hist_med, hist_high = np.percentile(hists, [10, 50, 90], axis=0) 334 | hist_med /= hist_med.sum() 335 | 336 | # Control case 337 | if plot_control: 338 | hists_ctrl, _ = compute_hist_varprop(model_dir, (rules[i], rules[j]), random_rotation=True) 339 | _, hist_med_ctrl, _ = np.percentile(hists_ctrl, [10, 50, 90], axis=0) 340 | hist_med_ctrl /= hist_med_ctrl.sum() 341 | ax.plot((bins_edge[:-1]+bins_edge[1:])/2, 342 | hist_med_ctrl, color='gray', lw=0.75) 343 | 344 | ax.plot((bins_edge[:-1]+bins_edge[1:])/2, hist_med, color='black') 345 | plt.locator_params(nbins=3) 346 | 347 | # ax.set_ylim(bottom=-0.02*hist_med.max()) 348 | ax.set_ylim([-0.01, 0.6]) 349 | print(hist_med.max()) 350 | ax.set_xticks([-1,1]) 351 | ax.set_xticklabels([]) 352 | if i == 0 and j == 1: 353 | ax.set_yticks([0, 0.6]) 354 | ax.spines["left"].set_visible(True) 355 | else: 356 | ax.set_yticks([]) 357 | ax.set_xlim([-1,1]) 358 | ax.xaxis.set_ticks_position('bottom') 359 | ax.tick_params(axis='both', which='major', labelsize=fs, length=2) 360 | 361 | 362 | # plt.tight_layout() 363 | plt.savefig('figure/plot_hist_varprop_all.pdf', transparent=True) 364 | 365 | def plot_hist_varprop_selection_cont(): 366 | save_type = 'cont_allrule' 367 | save_type_end = '_0_1_2intsynmain' 368 | rules_list = [(CHOICE_MOD1, CHOICE_MOD2), 369 | (CHOICEATTEND_MOD1, CHOICEATTEND_MOD2), 370 | (CHOICE_MOD1, CHOICEATTEND_MOD1), 371 | (CHOICEDELAY_MOD1, CHOICE_MOD1)] 372 | for rules in rules_list: 373 | plot_hist_varprop(save_type=save_type, save_type_end=save_type_end, rules=rules, 374 | plot_legend=(rules==(CHOICE_MOD1, REACTANTI)), hdim_example=0) 375 | 376 | def get_random_rotation_variance(save_name, data_type): ##TODO: Need more work 377 | # save_name = 'allrule_weaknoise_300' 378 | # data_type = 'rule' 379 | 380 | # If not computed, use variance.py 381 | # fname = 'data/variance'+data_type+save_name+'_rr' 382 | fname = os.path.join('data', 'variance_'+data_type+save_name) 383 | with open(fname+'.pkl','rb') as f: 384 | res = pickle.load(f) 385 | h_var_all = res['h_var_all'] 386 | keys = res['keys'] 387 | 388 | # First only get active units. 
Total variance across tasks larger than 1e-3 389 | ind_active = np.where(h_var_all.sum(axis=1) > 1e-3)[0] 390 | h_var_all = h_var_all[ind_active, :] 391 | 392 | # Normalize by the total variance across tasks 393 | h_normvar_all = (h_var_all.T/np.sum(h_var_all, axis=1)).T 394 | 395 | 396 | rule_hist = CHOICEATTEND_MOD1 397 | data_plot = h_normvar_all[:, keys.index(rule_hist)] 398 | 399 | p_low, p_high = 2.5, 97.5 400 | normvar_low, normvar_high = np.percentile(data_plot, [p_low, p_high]) 401 | 402 | fig = plt.figure(figsize=(1.5,1.2)) 403 | ax = fig.add_axes([0.3,0.3,0.6,0.5]) 404 | hist, bins_edge = np.histogram(data_plot, bins=30) 405 | ax.bar(bins_edge[:-1], hist, width=bins_edge[1]-bins_edge[0], 406 | color=sns.xkcd_palette(['cerulean'])[0], edgecolor='none') 407 | # ax.set_xlim([0,0.3]) 408 | ax.plot([normvar_low]*2, [0, hist.max()], 'black') 409 | ax.plot([normvar_high]*2, [0, hist.max()], 'black') 410 | plt.locator_params(nbins=3) 411 | ax.set_ylim(bottom=-1) 412 | 413 | print('{:0.1f} percentile: {:0.2f}'.format(p_low, normvar_low)) 414 | print('{:0.1f} percentile: {:0.2f}'.format(p_high, normvar_high)) 415 | 416 | 417 | def compute_ntasks_selective(): 418 | # Compute the number of tasks each neuron is selective for 419 | # NOT WELL DEFINED YET 420 | # DOESN'T REALLY WORK 421 | with open(os.path.join('data','variance'+data_type+save_name+'_rr'+'.pkl'),'rb') as f: 422 | res_rr = pickle.load(f) 423 | h_var_all_rr = res_rr['h_var_all'] 424 | 425 | bounds = np.percentile(h_var_all_rr, 97.5, axis=0) 426 | 427 | # bounds = 1e-2 428 | 429 | with open(os.path.join('data','variance'+data_type+save_name+'.pkl'),'rb') as f: 430 | res = pickle.load(f) 431 | h_var_all = res['h_var_all'] 432 | 433 | # First only get active units. Total variance across tasks larger than 1e-3 434 | ind_active = np.where(h_var_all.sum(axis=1) > 1e-3)[0] 435 | h_var_all = h_var_all[ind_active, :] 436 | 437 | h_selective = h_var_all > bounds 438 | n_selective = h_selective.sum(axis=1) 439 | 440 | hist, bins_edge = np.histogram(n_selective, bins=len(rules)+1, range=(-0.5,len(rules)+0.5)) 441 | 442 | fig = plt.figure(figsize=(3,2.4)) 443 | ax = fig.add_axes([0.2,0.3,0.6,0.5]) 444 | ax.bar(bins_edge[:-1], hist, width=bins_edge[1]-bins_edge[0], 445 | color=sns.xkcd_palette(['cerulean'])[0], edgecolor='none') 446 | 447 | 448 | def plot_var_random(): 449 | dist = 'beta' 450 | n = 10000 451 | if dist == 'uniform': 452 | var = np.random.rand(2 * n) 453 | elif dist == 'beta': 454 | var = np.random.beta(4, 3, size=(2 * n,)) 455 | elif dist == 'gamma': 456 | var = np.random.gamma(1, 2, size=(2 * n,)) 457 | elif dist == 'lognormal': 458 | var = np.random.randn(2 * n) * 1.9 + 0.75 459 | var = var * (var < 6) + 6.0 * (var >= 6) 460 | var = np.exp(var) 461 | 462 | frac_var = (var[:n] - var[n:]) / (var[:n] + var[n:]) 463 | 464 | plt.figure(figsize=(2, 2)) 465 | plt.hist(var) 466 | 467 | plt.figure(figsize=(2, 2)) 468 | plt.hist(frac_var) 469 | 470 | 471 | if __name__ == '__main__': 472 | pass 473 | 474 | -------------------------------------------------------------------------------- /analysis/varyhp.py: -------------------------------------------------------------------------------- 1 | """Analyze the results after varying hyperparameters.""" 2 | 3 | from __future__ import division 4 | 5 | from collections import defaultdict 6 | from collections import OrderedDict 7 | import os 8 | import numpy as np 9 | import matplotlib as mpl 10 | import matplotlib.pyplot as plt 11 | 12 | import tools 13 | from analysis import variance 14 | from
analysis import clustering 15 | from analysis import standard_analysis 16 | 17 | mpl.rcParams.update({'font.size': 7}) 18 | 19 | 20 | FIGPATH = os.path.join(os.getcwd(), 'figure') 21 | 22 | HP_NAME = {'activation': 'Activation Fun.', 23 | 'rnn_type': 'Network type', 24 | 'w_rec_init': 'Initialization', 25 | 'l1_h': 'L1 rate', 26 | 'l1_weight': 'L1 weight', 27 | 'l2_weight_init': 'L2 weight anchor', 28 | 'target_perf': 'Target perf.'} 29 | 30 | #maddy added check tanh fig 4 31 | #root_dir = './data/debug/8' #0, 33 './data/train_all' 32 | """ 33 | variance.compute_variance(root_dir) 34 | variance.plot_hist_varprop_selection(root_dir) 35 | variance.plot_hist_varprop_all(root_dir) 36 | analysis = clustering.Analysis(root_dir, 'rule') 37 | analysis.plot_variance() 38 | """ 39 | """ 40 | standard_analysis.easy_connectivity_plot(root_dir) 41 | rule = 'contextdm1' 42 | standard_analysis.easy_activity_plot(root_dir, rule) 43 | print("easy_connectivity_plot" + root_dir) 44 | """ 45 | 46 | def compute_n_cluster(model_dirs): 47 | for model_dir in model_dirs: 48 | print(model_dir) 49 | log = tools.load_log(model_dir) 50 | hp = tools.load_hp(model_dir) 51 | try: 52 | analysis = clustering.Analysis(model_dir, 'rule') 53 | 54 | log['n_cluster'] = analysis.n_cluster 55 | log['model_dir'] = model_dir 56 | tools.save_log(log) 57 | except IOError: 58 | # Training never finished 59 | assert log['perf_min'][-1] <= hp['target_perf'] 60 | 61 | # analysis.plot_example_unit() 62 | # analysis.plot_variance() 63 | # analysis.plot_2Dvisualization() 64 | 65 | print("done") 66 | 67 | 68 | def plot_histogram(model_dirs): 69 | initdict = defaultdict(list) 70 | initdictother = defaultdict(list) 71 | initdictotherother = defaultdict(list) 72 | n_clusters, hp_list = list(), list()  # collect cluster counts and hps of converged models 73 | for model_dir in model_dirs: 74 | hp = tools.load_hp(model_dir) 75 | #check if performance exceeds target 76 | log = tools.load_log(model_dir) 77 | #if log['perf_avg'][-1] > hp['target_perf']: 78 | if log['perf_min'][-1] > hp['target_perf']: 79 | print('no. of clusters', log['n_cluster']) 80 | n_clusters.append(log['n_cluster']) 81 | hp_list.append(hp) 82 | 83 | initdict[hp['w_rec_init']].append(log['n_cluster']) 84 | initdict[hp['activation']].append(log['n_cluster']) 85 | 86 | #initdict[hp['rnn_type']].append(log['n_cluster']) 87 | if hp['activation'] != 'tanh': 88 | initdict[hp['rnn_type']].append(log['n_cluster']) 89 | initdictother[hp['rnn_type']+hp['activation']].append(log['n_cluster']) 90 | initdictotherother[hp['rnn_type']+hp['activation']+hp['w_rec_init']].append(log['n_cluster']) 91 | 92 | if hp['l1_h'] == 0: 93 | initdict['l1_h_0'].append(log['n_cluster']) 94 | else: #hp['l1_h'] == 1e-3 or 1e-4 or 1e-5: 95 | keyvalstr = 'l1_h_1emin'+str(int(abs(np.log10(hp['l1_h'])))) 96 | initdict[keyvalstr].append(log['n_cluster']) 97 | 98 | if hp['l1_weight'] == 0: 99 | initdict['l1_weight_0'].append(log['n_cluster']) 100 | else: #hp['l1_weight'] == 1e-3 or 1e-4 or 1e-5: 101 | keyvalstr = 'l1_weight_1emin'+str(int(abs(np.log10(hp['l1_weight'])))) 102 | initdict[keyvalstr].append(log['n_cluster']) 103 | 104 | #initdict[hp['l1_weight']].append(log['n_cluster']) 105 | 106 | # Check no of clusters under various conditions.
107 | f, axarr = plt.subplots(7, 1, figsize=(3,12), sharex=True) 108 | u = 0 109 | for key in initdict.keys(): 110 | if 'l1_' not in key: 111 | title = (key + ' ' + str(len(initdict[key])) + 112 | ' mean: '+str(round(np.mean(initdict[key]),2))) 113 | axarr[u].set_title(title) 114 | axarr[u].hist(initdict[key]) 115 | u += 1 116 | f.subplots_adjust(wspace=.3, hspace=0.3) 117 | # plt.savefig('./figure/histforcases_96nets.png') 118 | # plt.savefig('./figure/histforcases__pt9_192nets.pdf') 119 | # plt.savefig('./figure/histforcases___leakygrunotanh_pt9_192nets.pdf') 120 | 121 | f, axarr = plt.subplots(4, 1, figsize=(3,8), sharex=True) 122 | u = 0 123 | for key in initdictother.keys(): 124 | if 'l1_' not in key: 125 | axarr[u].set_title(key + ' ' + str(len(initdictother[key]))+ ' mean: '+str(round(np.mean(initdictother[key]),2)) ) 126 | axarr[u].hist(initdictother[key]) 127 | u += 1 128 | f.subplots_adjust(wspace=.3, hspace=0.3) 129 | # plt.savefig('./figure/histforcases__leakyrnngrurelusoftplus_pt9_192nets.pdf') 130 | 131 | 132 | f, axarr = plt.subplots(4, 1, figsize=(3,6), sharex=True) 133 | u = 0 134 | for key in initdictotherother.keys(): 135 | if 'l1_' not in key and 'diag' not in key: 136 | axarr[u].set_title(key + ' ' + str(len(initdictotherother[key]))+ ' mean: '+str(round(np.mean(initdictotherother[key]),2)) ) 137 | axarr[u].hist(initdictotherother[key]) 138 | u += 1 139 | f.subplots_adjust(wspace=.3, hspace=0.3) 140 | # plt.savefig('./figure/histforcases_randortho_notanh_pt9_192nets.pdf') 141 | 142 | f, axarr = plt.subplots(4, 1, figsize=(3,6),sharex=True) 143 | u = 0 144 | for key in initdictotherother.keys(): 145 | if 'l1_' not in key and 'randortho' not in key: 146 | axarr[u].set_title(key + ' ' + str(len(initdictotherother[key]))+ ' mean: '+str(round(np.mean(initdictotherother[key]),2)) ) 147 | axarr[u].hist(initdictotherother[key]) 148 | u += 1 149 | f.subplots_adjust(wspace=.3, hspace=0.3) 150 | # plt.savefig('./figure/histforcases_diag_notanh_pt9_192nets.pdf') 151 | 152 | 153 | #regu-- 154 | f, axarr = plt.subplots(4, 1,figsize=(3,8),sharex=True) 155 | u = 0 156 | for key in initdict.keys(): 157 | if 'l1_h_' in key: 158 | axarr[u].set_title(key + ' ' + str(len(initdict[key]))+ ' mean: '+str(round(np.mean(initdict[key]),2)) ) 159 | axarr[u].hist(initdict[key]) 160 | u += 1 161 | f.subplots_adjust(wspace=.3, hspace=0.3) 162 | #plt.savefig('./figure/noofclusters_pt9_l1_h_192nets.pdf') 163 | 164 | f, axarr = plt.subplots(4, 1,figsize=(3,8),sharex=True) 165 | u = 0 166 | for key in initdict.keys(): 167 | if 'l1_weight_' in key: 168 | axarr[u].set_title(key + ' ' + str(len(initdict[key])) + ' mean: '+str(round(np.mean(initdict[key]),2)) ) 169 | axarr[u].hist(initdict[key]) 170 | u += 1 171 | f.subplots_adjust(wspace=.3, hspace=0.3) 172 | #plt.savefig('./figure/noofclusters_pt9_l1_weight_192nets.pdf') 173 | 174 | 175 | def get_n_clusters(root_dir): 176 | model_dirs = tools.valid_model_dirs(root_dir) 177 | hp_list = list() 178 | n_clusters = list() 179 | for i, model_dir in enumerate(model_dirs): 180 | if i % 50 == 0: 181 | print('Analyzing model {:d}/{:d}'.format(i, len(model_dirs))) 182 | hp = tools.load_hp(model_dir) 183 | log = tools.load_log(model_dir) 184 | # check if performance exceeds target 185 | if log['perf_min'][-1] > hp['target_perf']: 186 | n_clusters.append(log['n_cluster']) 187 | hp_list.append(hp) 188 | return n_clusters, hp_list 189 | 190 | 191 | def _get_hp_ranges(): 192 | """Get ranges of hp.""" 193 | hp_ranges = OrderedDict() 194 | hp_ranges['activation'] = 
['softplus', 'relu', 'retanh', 'tanh'] 195 | hp_ranges['rnn_type'] = ['LeakyRNN', 'LeakyGRU'] 196 | hp_ranges['w_rec_init'] = ['diag', 'randortho'] 197 | hp_ranges['l1_h'] = [0, 1e-5, 1e-4, 1e-3] 198 | # hp_ranges['l2_h'] = [0, 1e-4] 199 | hp_ranges['l1_weight'] = [0, 1e-5, 1e-4, 1e-3] 200 | return hp_ranges 201 | 202 | 203 | def plot_n_clusters(n_clusters, hp_list): 204 | """Plot the number of clusters. 205 | 206 | Args: 207 | n_clusters: list of cluster numbers 208 | hp_list: list of hp dictionary 209 | """ 210 | hp_ranges = _get_hp_ranges() 211 | 212 | # The hp to show 213 | hp_plots = hp_ranges.keys() 214 | 215 | # Sort by number of clusters 216 | ind_sort = np.argsort(n_clusters)[::-1] 217 | n_clusters_sorted = [n_clusters[i] for i in ind_sort] 218 | hp_list_sorted = [hp_list[i] for i in ind_sort] 219 | 220 | # Fill a matrix with the index of hp 221 | hp_visualize = np.zeros([len(hp_plots), len(n_clusters)]) 222 | for i, hp in enumerate(hp_list_sorted): 223 | for j, hp_plot in enumerate(hp_plots): 224 | ind = hp_ranges[hp_plot].index(hp[hp_plot]) 225 | ind /= len(hp_ranges[hp_plot]) - 1. 226 | hp_visualize[j, i] = ind 227 | 228 | # Plot results 229 | fig = plt.figure(figsize=(3, 2)) 230 | ax = fig.add_axes([0.3, 0.6, 0.65, 0.3]) 231 | ax.plot(n_clusters_sorted, '-') 232 | ax.set_xlim([0, len(n_clusters) - 1]) 233 | ax.set_xticks([0, len(n_clusters) - 1]) 234 | ax.set_xticklabels([]) 235 | ax.set_yticks([0, 10, 20, 30]) 236 | ax.set_ylabel('Num. of clusters', fontsize=7) 237 | ax.spines["right"].set_visible(False) 238 | ax.spines["top"].set_visible(False) 239 | 240 | import matplotlib as mpl 241 | import seaborn as sns 242 | colors = sns.color_palette("hls", 5) 243 | cmap = mpl.colors.ListedColormap(colors) 244 | cmap.set_over('0') 245 | cmap.set_under('1') 246 | ax = fig.add_axes([0.3, 0.15, 0.65, 0.35]) 247 | ax.imshow(hp_visualize, aspect='auto', cmap='viridis') 248 | ax.set_xticks([0, len(n_clusters) - 1]) 249 | ax.set_xticklabels([1, len(n_clusters)]) 250 | ax.set_yticks(range(len(hp_plots))) 251 | 252 | hp_plot_names = [HP_NAME[hp] for hp in hp_plots] 253 | ax.set_yticklabels(hp_plot_names, fontsize=7) 254 | ax.tick_params(length=0) 255 | [i.set_linewidth(0.1) for i in ax.spines.values()] 256 | ax.set_xlabel('Networks', labelpad=-5) 257 | # plt.title('target perf-min 0.9, total:'+str(len(n_clusters))) # 258 | plt.savefig(os.path.join(FIGPATH, 'NumClusters.pdf'), transparent=True) 259 | 260 | val = n_clusters_sorted 261 | fig = plt.figure(figsize=(1.0, 0.8)) 262 | ax = fig.add_axes([0.2, 0.4, 0.7, 0.5]) 263 | hist, bin_edges = np.histogram(val, density=True, range=(0, 30), 264 | bins=30) 265 | color = 'gray' 266 | ax.hist(val, range=(0, 30), 267 | density=True, bins=16, ec=color, facecolor=color, 268 | lw=1.5) 269 | ax.spines["left"].set_visible(False) 270 | ax.spines["right"].set_visible(False) 271 | ax.spines["top"].set_visible(False) 272 | ax.set_yticks([]) 273 | ax.set_xticks([0, 30]) 274 | ax.set_xlim([0, 30]) 275 | ax.set_xlabel('No. clusters', labelpad=-5) 276 | plt.tight_layout() 277 | figname = os.path.join(FIGPATH, 'NumClustersHist.pdf') 278 | plt.savefig(figname, transparent=True) 279 | 280 | 281 | def _plot_n_cluster_hist(hp_plot, n_clusters=None, hp_list=None): 282 | """Plot histogram for number of clusters, separating by an attribute. 
283 | 284 | Args: 285 | hp_plot: str, the attribute to separate histogram by 286 | n_clusters: list of cluster numbers 287 | hp_list: list of hp dictionary 288 | """ 289 | if hp_list is None: 290 | n_clusters, hp_list = get_n_clusters() 291 | 292 | # Compare activation, ignore tanh that can not be trained with LeakyRNN 293 | # hp_plot = 'activation' 294 | # hp_plot = 'rnn_type' 295 | # hp_plot = 'w_rec_init' 296 | 297 | n_cluster_dict = OrderedDict() 298 | hp_ranges = _get_hp_ranges() 299 | for key in hp_ranges[hp_plot]: 300 | n_cluster_dict[key] = list() 301 | 302 | for hp, n_cluster in zip(hp_list, n_clusters): 303 | # if hp_plot == 'activation' and hp['rnn_type'] != 'LeakyGRU': 304 | # For activation, only analyze LeakyGRU cells 305 | # continue 306 | if hp_plot == 'rnn_type' and hp['activation'] in ['tanh', 'retanh']: 307 | # For rnn_type, exclude tanh units 308 | continue 309 | n_cluster_dict[hp[hp_plot]].append(n_cluster) 310 | 311 | label_map = {'softplus': 'Softplus', 312 | 'relu': 'ReLU', 313 | 'retanh': 'Retanh', 314 | 'tanh': 'Tanh', 315 | 'LeakyGRU': 'GRU', 316 | 'LeakyRNN': 'RNN', 317 | 'randortho': 'Rand.\nOrtho.', 318 | 'diag': 'Diag.'} 319 | # fig = plt.figure(figsize=(1.5, 1.2)) 320 | # ax = fig.add_axes([0.2, 0.2, 0.7, 0.7]) 321 | f, axs = plt.subplots(len(n_cluster_dict), 1, 322 | sharex=True, figsize=(1.2, 1.8)) 323 | for i, (key, val) in enumerate(n_cluster_dict.items()): 324 | ax = axs[i] 325 | hist, bin_edges = np.histogram(val, density=True, range=(0, 30), 326 | bins=30) 327 | # plt.bar(bin_edges[:-1], hist, label=key) 328 | color_ind = i / (len(hp_ranges[hp_plot]) - 1.) 329 | color = mpl.cm.viridis(color_ind) 330 | if isinstance(key, float): 331 | label = '{:1.0e}'.format(key) 332 | else: 333 | label = label_map.get(key, str(key)) 334 | ax.hist(val, label=label, range=(0, 30), 335 | density=True, bins=16, ec=color, facecolor=color, 336 | lw=1.5) 337 | ax.spines["left"].set_visible(False) 338 | ax.spines["right"].set_visible(False) 339 | ax.spines["top"].set_visible(False) 340 | ax.set_yticks([]) 341 | ax.set_xticks([0, 15, 30]) 342 | ax.set_xlim([0, 30]) 343 | ax.text(0.7, 0.7, label, fontsize=7, transform=ax.transAxes) 344 | if i == 0: 345 | ax.set_title(HP_NAME[hp_plot], fontsize=7) 346 | # ax.legend(loc=3, bbox_to_anchor=(1, 0), title=HP_NAME[hp_plot], frameon=False) 347 | ax.set_xlabel('Number of clusters') 348 | plt.tight_layout() 349 | figname = os.path.join(FIGPATH, 'NumClustersHist' + hp_plot + '.pdf') 350 | plt.savefig(figname, transparent=True) 351 | 352 | return n_cluster_dict 353 | 354 | 355 | def plot_n_cluster_hist(n_clusters, hp_list): 356 | """Plot histogram of number of clusters. 
357 | 358 | Args: 359 | n_clusters: list of cluster numbers 360 | hp_list: list of hp dictionary 361 | """ 362 | hp_plots = ['activation', 'rnn_type', 'w_rec_init', 'l1_h', 'l1_weight'] 363 | # hp_plots = ['activation'] 364 | for hp_plot in hp_plots: 365 | n_cluster_dict = _plot_n_cluster_hist(hp_plot, n_clusters, hp_list) 366 | 367 | 368 | def get_model_by_activation(activation): 369 | hp_target = {'activation': activation, 370 | 'rnn_type': 'LeakyGRU', 371 | 'w_rec_init': 'diag', 372 | 'l1_h': 0, 373 | 'l1_weight': 0} 374 | 375 | return tools.find_model(DATAPATH, hp_target) 376 | 377 | 378 | def plot_hist_varprop(activation): 379 | """Plot FTV distribution.""" 380 | model_dir = get_model_by_activation(activation) 381 | variance.plot_hist_varprop_selection(model_dir, figname_extra='_tanh') 382 | 383 | 384 | def pretty_singleneuron_plot(activation='tanh'): 385 | """Plot single neuron activity.""" 386 | model_dir = get_model_by_activation(activation) 387 | standard_analysis.pretty_singleneuron_plot( 388 | model_dir, ['contextdm1', 'contextdm2'], range(2) 389 | ) 390 | 391 | 392 | def activity_histogram(activation): 393 | """Plot FTV distribution for tanh network.""" 394 | model_dir = get_model_by_activation(activation) 395 | title = activation 396 | save_name = '_' + activation 397 | standard_analysis.activity_histogram( 398 | model_dir, ['contextdm1', 'contextdm2'], title=title, 399 | save_name=save_name 400 | ) 401 | 402 | 403 | 404 | if __name__ == '__main__': 405 | pass 406 | DATAPATH = os.path.join(os.getcwd(), 'data', 'varyhp') 407 | # model_dirs = tools.valid_model_dirs(DATAPATH) 408 | 409 | # compute_n_cluster() 410 | n_clusters, hp_list = get_n_clusters(DATAPATH) 411 | # plot_n_clusters(n_clusters, hp_list) 412 | plot_n_cluster_hist(n_clusters, hp_list) 413 | # pretty_singleneuron_plot('tanh') 414 | # pretty_singleneuron_plot('relu') 415 | # [activity_histogram(a) for a in ['tanh', 'relu', 'softplus', 'retanh']] 416 | 417 | 418 | # ============================================================================= 419 | # DATAPATH = os.path.join(os.getcwd(), 'data', 'varyhp_reg') 420 | # FIGPATH = os.path.join(os.getcwd(), 'figure') 421 | # model_dirs = tools.valid_model_dirs(DATAPATH) 422 | # 423 | # hp_list = list() 424 | # n_clusters = list() 425 | # logs = list() 426 | # perfs = list() 427 | # for i, model_dir in enumerate(model_dirs): 428 | # hp = tools.load_hp(model_dir) 429 | # log = tools.load_log(model_dir) 430 | # # check if performance exceeds target 431 | # perfs.append(log['perf_min'][-1]) 432 | # if log['perf_min'][-1] > 0.8: 433 | # logs.append(log) 434 | # n_clusters.append(log['n_cluster']) 435 | # hp_list.append(hp) 436 | # ============================================================================= 437 | 438 | 439 | -------------------------------------------------------------------------------- /datasets/mante_dataset_preprocess.py: -------------------------------------------------------------------------------- 1 | """Mante dataset preprocessing. 2 | 3 | Standardize the Mante dataset. 
4 | """ 5 | 6 | from __future__ import division 7 | 8 | import os 9 | import csv 10 | from collections import defaultdict 11 | import math 12 | import numpy as np 13 | from scipy.io import loadmat 14 | import matplotlib.pyplot as plt 15 | 16 | DATASETPATH = './datasets/mante_dataset' 17 | 18 | 19 | def _expand_task_var(task_var): 20 | """Little helper function that calculate a few more things.""" 21 | task_var['stim_dir_sign'] = (task_var['stim_dir']>0).astype(int)*2-1 22 | task_var['stim_col2dir_sign'] = (task_var['stim_col2dir']>0).astype(int)*2-1 23 | return task_var 24 | 25 | 26 | def load_data(smooth=True, single_units=False, animal='ar'): 27 | """Load Mante data into raw format. 28 | 29 | Args: 30 | smooth: bool, whether to load smoothed data 31 | single_units: bool, if True, only analyze single units 32 | animal: str, 'ar' or 'fe' 33 | 34 | Returns: 35 | data: standard format, list of dict of arrays/dict 36 | list is over neurons 37 | dict is for response array and task variable dict 38 | response array has shape (n_trial, n_time) 39 | """ 40 | if smooth: 41 | fname = 'dataTsmooth_' + animal + '.mat' 42 | else: 43 | fname = 'dataT_' + animal + '.mat' 44 | 45 | fname = os.path.join(DATASETPATH, fname) 46 | 47 | mat_dict = loadmat(fname, squeeze_me=True, struct_as_record=False) 48 | 49 | dataT = mat_dict['dataT'].__dict__ # as dictionary 50 | 51 | data = dataT['unit'] 52 | # time = dataT['time'] 53 | 54 | if single_units: 55 | single_units = get_single_units(animal) 56 | ind_single_units = np.where(single_units)[0] 57 | data = [data[i] for i in ind_single_units] 58 | 59 | # Convert to standard format 60 | new_data = list() 61 | n_unit = len(data) 62 | for i in range(n_unit): 63 | task_var = data[i].task_variable.__dict__ 64 | task_var = _expand_task_var(task_var) 65 | unit_dict = { 66 | 'task_var': task_var, # turn into dictionary 67 | 'rate': data[i].response # (n_trial, n_time) 68 | } 69 | new_data.append(unit_dict) 70 | 71 | return new_data 72 | 73 | 74 | def get_single_units(animal): 75 | # get single units 76 | fname = os.path.join('datasets', 'mante_dataset', 77 | 'metadata_'+animal+'.mat') 78 | mat_dict = loadmat(fname, squeeze_me=True, struct_as_record=False) 79 | metadata = mat_dict['metadata'].unit # as dictionary 80 | discriminable = np.array([(m.unitInfo.discriminability in [3,4]) for m in metadata]) 81 | single_units = np.array([m.unitInfo.type=='s' for m in metadata]) 82 | single_units = np.logical_and(single_units, discriminable) 83 | if animal == 'ar': 84 | assert np.sum(single_units)==181 85 | else: 86 | assert np.sum(single_units)==207 87 | return single_units 88 | 89 | 90 | def get_mante_data(): 91 | """Get Mante data in standard format. 
92 | 93 | Returns: 94 | rate1s: numpy array (n_unit, n_condition, n_time) in context 1 95 | rate2s: numpy array (n_unit, n_condition, n_time) in context 2 96 | """ 97 | data = load_mante_data() 98 | 99 | n_unit = len(data) 100 | 101 | rate1s = list() 102 | rate2s = list() 103 | random_shuffle = True 104 | for i_unit in range(n_unit): 105 | # Get trial-averaged condition-based responses (n_condition, n_time) 106 | rate1s.append(get_trial_avg_rate_mante(data[i_unit], context=1, 107 | random_shuffle=random_shuffle)) 108 | rate2s.append(get_trial_avg_rate_mante(data[i_unit], context=-1, 109 | random_shuffle=random_shuffle)) 110 | # (n_unit, n_condition, n_time) 111 | rate1s, rate2s = np.array(rate1s), np.array(rate2s) 112 | return rate1s, rate2s 113 | 114 | 115 | if __name__ == '__main__': 116 | rate1s, rate2s = get_mante_data() -------------------------------------------------------------------------------- /datasets/siegel_dataset_preprocess.py: -------------------------------------------------------------------------------- 1 | """Siegel dataset preprocessing. 2 | 3 | Standardize the Siegel dataset 4 | Run siegel_preprocess.m before for correct results from the Siegel dataset. 5 | """ 6 | 7 | from __future__ import division 8 | 9 | import os 10 | import csv 11 | from collections import defaultdict 12 | import math 13 | import pickle 14 | import time 15 | import numpy as np 16 | from scipy.io import loadmat 17 | import matplotlib.pyplot as plt 18 | 19 | DATASETPATH = './datasets/siegel_dataset' 20 | 21 | 22 | def _load_spikes(fname): 23 | """Load spiking data from Siegel dataset.""" 24 | fname_full = os.path.join(DATASETPATH, 'sorted', fname) 25 | 26 | # Analyze a single file 27 | mat_dict = loadmat(fname_full, squeeze_me=True, struct_as_record=False) 28 | 29 | # Get spike times 30 | # spikes is an array (trials, neurons) of 1-D array 31 | spikes = mat_dict['spikeTimes'] 32 | return spikes 33 | 34 | 35 | def _load_tables(fname, table_name): 36 | """Load tables from Siegel dataset and convert to dictionary. 37 | 38 | Args: 39 | fname: str, file name 40 | table_name: str, can be 'trialinfo', 'unitinfo', and 'electrodeinfo' 41 | 42 | Returns: 43 | table_dict: dictionary, for each (key, val) pair, val is an array 44 | """ 45 | fname_full = os.path.join(DATASETPATH, table_name, fname[:6] + '.csv') 46 | 47 | table_dict = defaultdict(list) 48 | with open(fname_full) as csvDataFile: 49 | csvReader = csv.reader(csvDataFile) 50 | keys = csvReader.next() 51 | for row in csvReader: 52 | for k, r in zip(keys, row): 53 | # r is a string, convert to int, float, or string 54 | try: 55 | r2 = int(r) 56 | except ValueError: 57 | try: 58 | r2 = float(r) 59 | except ValueError: 60 | r2 = r 61 | 62 | table_dict[k].append(r2) 63 | return table_dict 64 | 65 | 66 | def _get_valid_trials(trial_infos): 67 | """Get valid trials. 68 | 69 | Args: 70 | trial_infos: dict of arrays. 
71 | 72 | Returns: 73 | new_trial_infos: dict of arrays, only contain valid trials 74 | valid_trials: list of bools, indicating the valid trials 75 | """ 76 | new_trial_infos = defaultdict(list) 77 | keys = trial_infos.keys() 78 | n_trial = len(trial_infos[keys[0]]) 79 | valid_trials = list() 80 | for i_trial in range(n_trial): 81 | if (trial_infos['badTrials'][i_trial] == 1 or 82 | math.isnan(trial_infos['responseTime'][i_trial]) or 83 | trial_infos['responseTime'][i_trial] < 0.2 84 | ): 85 | valid = False 86 | else: 87 | valid = True 88 | valid_trials.append(valid) 89 | if valid: 90 | for key in keys: 91 | new_trial_infos[key].append(trial_infos[key][i_trial]) 92 | return new_trial_infos, valid_trials 93 | 94 | 95 | def _expand_task_var(task_var): 96 | """Little helper function that calculate a few more things.""" 97 | tmp = list() 98 | for r in task_var['rule']: 99 | if r == 'dir': 100 | tmp.append(+1) 101 | else: 102 | tmp.append(-1) 103 | task_var['context'] = np.array(tmp) 104 | # TODO(gryang): Take care of boundary case 105 | task_var['stim_dir_sign'] = (np.array(task_var['direction'])>0).astype(int)*2-1 106 | task_var['stim_col2dir_sign'] = (np.array(task_var['color'])>90).astype(int)*2-1 107 | return task_var 108 | 109 | 110 | def _compute_data_single_file(f): 111 | """Compute data for a single file. 112 | 113 | Args: 114 | f: str, file name 115 | 116 | Returns: 117 | data: standard format 118 | """ 119 | trial_infos = _load_tables(f, 'trialinfo') 120 | trial_infos, valid_trials = _get_valid_trials(trial_infos) 121 | task_var = _expand_task_var(trial_infos) 122 | 123 | unit_infos = _load_tables(f, 'unitinfo') 124 | # electrode_infos = _load_tables(f, 'electrodeinfo') 125 | 126 | spikes = _load_spikes(f) 127 | spikes = spikes[valid_trials, :] 128 | 129 | n_trial, n_unit = spikes.shape 130 | 131 | assert len(trial_infos.values()[0]) == n_trial 132 | assert len(unit_infos.values()[0]) == n_unit 133 | 134 | bin_size = 0.05 # unit: second 135 | bins = np.arange(0, 0.21, bin_size) 136 | n_time = len(bins) - 1 137 | 138 | data = list() 139 | for i_unit in range(n_unit): 140 | rates = np.zeros((n_trial, n_time)) 141 | for i_trial in range(n_trial): 142 | spikes_unit = spikes[i_trial, i_unit] 143 | # Compute PSTH 144 | hist, bin_edges = np.histogram(spikes_unit, bins=bins) 145 | rates[i_trial, :] = hist / bin_size 146 | unit_dict = { 147 | 'task_var': task_var, 148 | 'rate': rates 149 | } 150 | for key, val in unit_infos.items(): 151 | unit_dict[key] = val[i_unit] 152 | data.append(unit_dict) 153 | return data 154 | 155 | 156 | def _compute_data(): 157 | """Compute data for all files.""" 158 | datasetpath = os.path.join(DATASETPATH, 'sorted') 159 | 160 | files = os.listdir(datasetpath) 161 | files = [f for f in files if '1' in f] 162 | 163 | start_time = time.time() 164 | 165 | for f in files: 166 | print('Analyzing file: ' + f) 167 | print('Time taken {:0.2f}s'.format(time.time()-start_time)) 168 | data_single_file = _compute_data_single_file(f) 169 | 170 | fname = os.path.join(DATASETPATH, 'standard', 'siegel'+f[:6]+'.pkl') 171 | print('File saved at: ' + fname) 172 | with open(fname, 'wb') as f2: 173 | pickle.dump(data_single_file, f2) 174 | 175 | 176 | def load_data(single_file=False): 177 | """Load Siegel data into standard format. 
178 | 179 | Returns: 180 | data: standard format, list of dict of arrays/dict 181 | list is over neurons 182 | dict is for response array and task variable dict 183 | response array has shape (n_trial, n_time) 184 | """ 185 | datasetpath = os.path.join(DATASETPATH, 'standard') 186 | 187 | files = os.listdir(datasetpath) 188 | files = [f for f in files if '1' in f] 189 | 190 | if single_file: 191 | files = files[:1] 192 | 193 | start_time = time.time() 194 | 195 | data = list() 196 | for f in files: 197 | print('Analyzing file: ' + f) 198 | print('Time taken {:0.2f}s'.format(time.time() - start_time)) 199 | fname = os.path.join(datasetpath, f) 200 | with open(fname, 'rb') as f2: 201 | data_single_file = pickle.load(f2) 202 | data.extend(data_single_file) 203 | return data 204 | 205 | 206 | # def _spike_to_rate(spikes_unit, times): 207 | # """Convert spikes to rate. 208 | # 209 | # Args: 210 | # spikes_unit: list of float, a list of spike times 211 | # times: list of float, a list of time points, default unit second 212 | # 213 | # Returns: 214 | # rates: list of float, a list of rate, default unit spike/second, Hz 215 | # rates will be the same size as times 216 | # """ 217 | 218 | 219 | if __name__ == '__main__': 220 | # _compute_data() 221 | data = load_data(single_file=True) 222 | 223 | 224 | 225 | -------------------------------------------------------------------------------- /experiment.py: -------------------------------------------------------------------------------- 1 | """Different training experiments.""" 2 | 3 | from __future__ import division 4 | 5 | import os 6 | from collections import OrderedDict 7 | import numpy as np 8 | 9 | import tools 10 | import train 11 | from analysis import variance 12 | from analysis import clustering 13 | from analysis import data_analysis 14 | from analysis import performance 15 | from analysis import taskset 16 | 17 | # TODO: make this flexible 18 | DATAPATH = os.path.join(os.getcwd(), 'data') 19 | 20 | 21 | def train_mante(seed=0, model_dir='train_mante'): 22 | """Training of only the Mante task.""" 23 | hp = {'target_perf': 0.9} 24 | model_dir = os.path.join(DATAPATH, model_dir, str(seed)) 25 | train.train(model_dir, hp=hp, ruleset='mante', seed=seed) 26 | 27 | 28 | def mante_tanh(seed=0, model_dir='mante_tanh'): 29 | """Training of only the Mante task.""" 30 | hp = {'activation': 'tanh', 31 | 'target_perf': 0.9} 32 | model_dir = os.path.join(DATAPATH, model_dir, str(seed)) 33 | train.train(model_dir, hp=hp, ruleset='mante', seed=seed) 34 | # Analyses 35 | variance.compute_variance(model_dir) 36 | 37 | log = tools.load_log(model_dir) 38 | analysis = clustering.Analysis(model_dir, 'rule') 39 | log['n_cluster'] = analysis.n_cluster 40 | tools.save_log(log) 41 | data_analysis.compute_var_all(model_dir) 42 | 43 | 44 | def train_all(seed=0, root_dir='train_all'): 45 | """Training of all tasks.""" 46 | model_dir = os.path.join(DATAPATH, root_dir, str(seed)) 47 | hp = {'activation': 'softplus', 'w_rec_init': 'diag'} # TODO: change the default back to diag 48 | rule_prob_map = {'contextdm1': 5, 'contextdm2': 5} 49 | train.train(model_dir, hp=hp, ruleset='all', 50 | rule_prob_map=rule_prob_map, seed=seed) 51 | train_all_analysis(seed=seed, root_dir=root_dir) 52 | 53 | 54 | def debug_train_all(): 55 | root_dir = 'debug_train_all' 56 | seed = 0 57 | model_dir = os.path.join(DATAPATH, root_dir, str(seed)) 58 | hp = {'activation': 'softplus', 'w_rec_init': 'diag'} 59 | rule_prob_map = {'contextdm1': 5, 'contextdm2': 5} 60 | train.train(model_dir, hp=hp, 
ruleset='all', 61 | rule_prob_map=rule_prob_map, seed=seed, max_steps=1e3) 62 | train_all_analysis(seed=seed, root_dir=root_dir) 63 | 64 | 65 | def train_all_analysis(seed=0, root_dir='train_all'): 66 | model_dir = os.path.join(DATAPATH, root_dir, str(seed)) 67 | # Analyses 68 | variance.compute_variance(model_dir) 69 | variance.compute_variance(model_dir, random_rotation=True) 70 | log = tools.load_log(model_dir) 71 | analysis = clustering.Analysis(model_dir, 'rule') 72 | log['n_cluster'] = analysis.n_cluster 73 | tools.save_log(log) 74 | data_analysis.compute_var_all(model_dir) 75 | 76 | for rule in ['dm1', 'contextdm1', 'multidm']: 77 | performance.compute_choicefamily_varytime(model_dir, rule) 78 | 79 | setups = [1, 2, 3] 80 | for setup in setups: 81 | taskset.compute_taskspace(model_dir, setup, 82 | restore=False, 83 | representation='rate') 84 | taskset.compute_replacerule_performance(model_dir, setup, False) 85 | 86 | 87 | def train_all_tanhgru(seed=0, model_dir='tanhgru'): 88 | """Training of all tasks with Tanh GRUs.""" 89 | model_dir = os.path.join(DATAPATH, model_dir, str(seed)) 90 | hp = {'activation': 'tanh', 91 | 'rnn_type': 'LeakyGRU'} 92 | rule_prob_map = {'contextdm1': 5, 'contextdm2': 5} 93 | train.train(model_dir, hp=hp, ruleset='all', 94 | rule_prob_map=rule_prob_map, seed=seed) 95 | # Analyses 96 | variance.compute_variance(model_dir) 97 | log = tools.load_log(model_dir) 98 | analysis = clustering.Analysis(model_dir, 'rule') 99 | log['n_cluster'] = analysis.n_cluster 100 | tools.save_log(log) 101 | data_analysis.compute_var_all(model_dir) 102 | 103 | setups = [1, 2, 3] 104 | for setup in setups: 105 | taskset.compute_taskspace(model_dir, setup, 106 | restore=False, 107 | representation='rate') 108 | taskset.compute_replacerule_performance(model_dir, setup, False) 109 | 110 | 111 | def train_all_mixrule(seed=0, root_dir='mixrule'): 112 | """Training of all tasks.""" 113 | model_dir = os.path.join(DATAPATH, root_dir, str(seed)) 114 | hp = {'activation': 'relu', 'w_rec_init': 'diag', 115 | 'use_separate_input': True, 'mix_rule': True} 116 | rule_prob_map = {'contextdm1': 5, 'contextdm2': 5} 117 | train.train(model_dir, hp=hp, ruleset='all', 118 | rule_prob_map=rule_prob_map, seed=seed) 119 | 120 | # Analyses 121 | variance.compute_variance(model_dir) 122 | log = tools.load_log(model_dir) 123 | analysis = clustering.Analysis(model_dir, 'rule') 124 | log['n_cluster'] = analysis.n_cluster 125 | tools.save_log(log) 126 | 127 | setups = [1, 2, 3] 128 | for setup in setups: 129 | taskset.compute_taskspace(model_dir, setup, 130 | restore=False, 131 | representation='rate') 132 | taskset.compute_replacerule_performance(model_dir, setup, False) 133 | 134 | 135 | def train_all_mixrule_softplus(seed=0, root_dir='mixrule_softplus'): 136 | """Training of all tasks.""" 137 | model_dir = os.path.join(DATAPATH, root_dir, str(seed)) 138 | hp = {'activation': 'softplus', 'w_rec_init': 'diag', 139 | 'use_separate_input': True, 'mix_rule': True} 140 | rule_prob_map = {'contextdm1': 5, 'contextdm2': 5} 141 | train.train(model_dir, hp=hp, ruleset='all', 142 | rule_prob_map=rule_prob_map, seed=seed) 143 | 144 | # Analyses 145 | variance.compute_variance(model_dir) 146 | log = tools.load_log(model_dir) 147 | analysis = clustering.Analysis(model_dir, 'rule') 148 | log['n_cluster'] = analysis.n_cluster 149 | tools.save_log(log) 150 | 151 | setups = [1, 2, 3] 152 | for setup in setups: 153 | taskset.compute_taskspace(model_dir, setup, 154 | restore=False, 155 | representation='rate') 156 | 
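# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the experiment
# functions that follow -- train_seq, train_vary_hp_seq and train_vary_hp --
# all convert a single integer job index into one point of a hyperparameter
# grid plus a random seed via np.unravel_index. A minimal, self-contained
# reproduction of that scheme; the grid values are made up and 'index_to_hp'
# is a hypothetical helper name, not used elsewhere in this repository.
from collections import OrderedDict
import numpy as np

_example_hp_ranges = OrderedDict([('activation', ['softplus', 'relu']),
                                  ('l1_h', [0, 1e-5, 1e-4])])

def index_to_hp(i, hp_ranges=_example_hp_ranges):
    """Map a job index i to a (hyperparameter dict, seed) pair over a grid."""
    keys = list(hp_ranges.keys())
    dims = [len(hp_ranges[k]) for k in keys]
    n_max = int(np.prod(dims))
    # The grid is traversed in C order; indices beyond the grid wrap around
    # and increment the seed instead.
    indices = np.unravel_index(i % n_max, dims)
    hp = {k: hp_ranges[k][ind] for k, ind in zip(keys, indices)}
    return hp, i // n_max

# For example, index_to_hp(7) -> ({'activation': 'softplus', 'l1_h': 1e-05}, 1)
# ---------------------------------------------------------------------------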
taskset.compute_replacerule_performance(model_dir, setup, False) 157 | 158 | 159 | def train_seq(i): 160 | # Ranges of hyperparameters to loop over 161 | hp_ranges = OrderedDict() 162 | hp_ranges['c_intsyn'] = [0, 1.0] 163 | 164 | # Unravel the input index 165 | keys = hp_ranges.keys() 166 | dims = [len(hp_ranges[k]) for k in keys] 167 | n_max = np.prod(dims) 168 | indices = np.unravel_index(i % n_max, dims=dims) 169 | 170 | # Set up new hyperparameter 171 | hp = dict() 172 | for key, index in zip(keys, indices): 173 | hp[key] = hp_ranges[key][index] 174 | hp['learning_rate'] = 0.001 175 | hp['w_rec_init'] = 'randortho' 176 | hp['easy_task'] = True 177 | hp['activation'] = 'relu' 178 | hp['ksi_intsyn'] = 0.01 179 | hp['max_steps'] = 4e5 180 | 181 | model_dir = os.path.join(DATAPATH, 'seq', str(i)) 182 | rule_trains = [['fdgo'], ['delaygo'], ['dm1', 'dm2'], ['multidm'], 183 | ['contextdm1', 'contextdm2']] 184 | train.train_sequential( 185 | model_dir, 186 | rule_trains, 187 | hp=hp, 188 | max_steps=hp['max_steps'], 189 | display_step=500, 190 | ruleset='all', 191 | seed=i // n_max, 192 | ) 193 | 194 | 195 | def train_vary_hp_seq(i): 196 | # Ranges of hyperparameters to loop over 197 | hp_ranges = OrderedDict() 198 | hp_ranges['activation'] = ['softplus', 'relu'] 199 | hp_ranges['w_rec_init'] = ['randortho'] 200 | hp_ranges['c_intsyn'] = [0, 0.1, 1.0, 10.] 201 | hp_ranges['ksi_intsyn'] = [0.001, 0.01, 0.1] 202 | hp_ranges['max_steps'] = [1e5, 2e5, 4e5] 203 | 204 | # Unravel the input index 205 | keys = hp_ranges.keys() 206 | dims = [len(hp_ranges[k]) for k in keys] 207 | n_max = np.prod(dims) 208 | indices = np.unravel_index(i % n_max, dims=dims) 209 | 210 | # Set up new hyperparameter 211 | hp = dict() 212 | for key, index in zip(keys, indices): 213 | hp[key] = hp_ranges[key][index] 214 | hp['learning_rate'] = 0.001 215 | hp['w_rec_init'] = 'randortho' 216 | hp['easy_task'] = True 217 | 218 | model_dir = os.path.join(DATAPATH, 'seq_varyhp', str(i)) 219 | rule_trains = [['fdgo'], ['delaygo'], ['dm1', 'dm2'], ['multidm'], 220 | ['contextdm1', 'contextdm2']] 221 | train.train_sequential( 222 | model_dir, 223 | rule_trains, 224 | hp=hp, 225 | max_steps=hp['max_steps'], 226 | display_step=500, 227 | ruleset='all', 228 | seed=i // n_max, 229 | ) 230 | 231 | 232 | def train_vary_hp(i): 233 | """Vary the hyperparameters. 234 | 235 | This experiment loops over a set of hyperparameters. 236 | 237 | Args: 238 | i: int, the index of the hyperparameters list 239 | """ 240 | # Ranges of hyperparameters to loop over 241 | hp_ranges = OrderedDict() 242 | # hp_ranges['activation'] = ['softplus', 'relu', 'tanh', 'retanh'] 243 | # hp_ranges['rnn_type'] = ['LeakyRNN', 'LeakyGRU'] 244 | # hp_ranges['w_rec_init'] = ['diag', 'randortho'] 245 | hp_ranges['activation'] = ['softplus'] 246 | hp_ranges['rnn_type'] = ['LeakyRNN'] 247 | hp_ranges['w_rec_init'] = ['randortho'] 248 | hp_ranges['l1_h'] = [0, 1e-9, 1e-8, 1e-7, 1e-6] # TODO(gryang): Change this? 
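    # Illustrative note: with the ranges active here and just below (l2_h and
    # l1_weight), the grid has 1*1*1*5*1*4 = 20 combinations, so job index i
    # selects combination i % 20 and seed i // 20. The 'all_varyhp' branch of
    # submit_jobs.py submits i = 0..19, i.e. each combination once with seed 0.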
249 | hp_ranges['l2_h'] = [0] 250 | hp_ranges['l1_weight'] = [0, 1e-7, 1e-6, 1e-5] 251 | # TODO(gryang): add the level of overtraining 252 | 253 | # Unravel the input index 254 | keys = hp_ranges.keys() 255 | dims = [len(hp_ranges[k]) for k in keys] 256 | n_max = np.prod(dims) 257 | indices = np.unravel_index(i % n_max, dims=dims) 258 | 259 | # Set up new hyperparameter 260 | hp = dict() 261 | for key, index in zip(keys, indices): 262 | hp[key] = hp_ranges[key][index] 263 | 264 | model_dir = os.path.join(DATAPATH, 'varyhp_reg2', str(i)) 265 | rule_prob_map = {'contextdm1': 5, 'contextdm2': 5} 266 | train.train(model_dir, hp, ruleset='all', 267 | rule_prob_map=rule_prob_map, seed=i // n_max) 268 | 269 | # Analyses 270 | variance.compute_variance(model_dir) 271 | log = tools.load_log(model_dir) 272 | analysis = clustering.Analysis(model_dir, 'rule') 273 | log['n_cluster'] = analysis.n_cluster 274 | tools.save_log(log) 275 | data_analysis.compute_var_all(model_dir) 276 | 277 | 278 | def _base_vary_hp_mante(i, hp_ranges, base_name): 279 | """Vary hyperparameters for mante tasks.""" 280 | # Unravel the input index 281 | keys = hp_ranges.keys() 282 | dims = [len(hp_ranges[k]) for k in keys] 283 | n_max = np.prod(dims) 284 | indices = np.unravel_index(i % n_max, dims=dims) 285 | 286 | # Set up new hyperparameter 287 | hp = dict() 288 | for key, index in zip(keys, indices): 289 | hp[key] = hp_ranges[key][index] 290 | 291 | model_dir = os.path.join(DATAPATH, base_name, str(i)) 292 | train.train(model_dir, hp, ruleset='mante', 293 | max_steps=1e7, seed=i // n_max) 294 | 295 | # Analyses 296 | variance.compute_variance(model_dir) 297 | 298 | log = tools.load_log(model_dir) 299 | analysis = clustering.Analysis(model_dir, 'rule') 300 | log['n_cluster'] = analysis.n_cluster 301 | tools.save_log(log) 302 | data_analysis.compute_var_all(model_dir) 303 | 304 | 305 | def vary_l2_init_mante(i): 306 | """Vary the hyperparameters and train on Mante tasks only. 307 | 308 | This experiment loops over a set of hyperparameters. 309 | 310 | Args: 311 | i: int, the index of the hyperparameters list 312 | """ 313 | # Ranges of hyperparameters to loop over 314 | hp_ranges = OrderedDict() 315 | hp_ranges['activation'] = ['softplus'] 316 | hp_ranges['rnn_type'] = ['LeakyRNN'] 317 | hp_ranges['w_rec_init'] = ['randortho'] 318 | hp_ranges['l2_weight_init'] = [0, 1e-4, 2*1e-4, 4*1e-4, 8*1e-4, 1.6*1e-3] 319 | hp_ranges['target_perf'] = [0.9] 320 | 321 | _base_vary_hp_mante(i, hp_ranges, base_name='vary_l2init_mante') 322 | 323 | 324 | def vary_l2_weight_mante(i): 325 | """Vary the hyperparameters and train on Mante tasks only. 326 | 327 | This experiment loops over a set of hyperparameters. 328 | 329 | Args: 330 | i: int, the index of the hyperparameters list 331 | """ 332 | # Ranges of hyperparameters to loop over 333 | hp_ranges = OrderedDict() 334 | hp_ranges['activation'] = ['softplus'] 335 | hp_ranges['rnn_type'] = ['LeakyRNN'] 336 | hp_ranges['w_rec_init'] = ['randortho'] 337 | hp_ranges['l2_weight'] = [0, 1e-4, 2*1e-4, 4*1e-4, 8*1e-4, 1.6*1e-3] 338 | hp_ranges['target_perf'] = [0.9] 339 | 340 | _base_vary_hp_mante(i, hp_ranges, base_name='vary_l2weight_mante') 341 | 342 | 343 | def vary_p_weight_train_mante(i): 344 | """Vary the hyperparameters and train on Mante tasks only. 345 | 346 | This experiment loops over a set of hyperparameters. 
347 | 348 | Args: 349 | i: int, the index of the hyperparameters list 350 | """ 351 | # Ranges of hyperparameters to loop over 352 | hp_ranges = OrderedDict() 353 | hp_ranges['activation'] = ['softplus'] 354 | hp_ranges['rnn_type'] = ['LeakyRNN'] 355 | hp_ranges['w_rec_init'] = ['randortho'] 356 | # hp_ranges['p_weight_train'] = [0, 0.2, 0.4, 0.6, 0.8, 1.0] 357 | hp_ranges['p_weight_train'] = [0.05, 0.075] 358 | hp_ranges['target_perf'] = [0.9] 359 | 360 | _base_vary_hp_mante(i, hp_ranges, base_name='vary_pweighttrain_mante') 361 | 362 | 363 | def pretrain(setup, seed): 364 | """Get pre-trained networks.""" 365 | hp = dict() 366 | hp['learning_rate'] = 0.001 367 | hp['w_rec_init'] = 'diag' 368 | hp['easy_task'] = False 369 | hp['activation'] = 'relu' 370 | hp['max_steps'] = 2*1e6 371 | hp['l1_h'] = 1e-8 372 | hp['target_perf'] = 0.97 373 | hp['n_rnn'] = 128 374 | hp['use_separate_input'] = True 375 | 376 | model_dir = os.path.join(DATAPATH, 'pretrain', 'setup'+str(setup), str(seed)) 377 | if setup == 0: 378 | rule_trains = ['contextdm1', 'contextdm2', 'contextdelaydm2'] 379 | elif setup == 1: 380 | rule_trains = ['fdgo', 'fdanti', 'delaygo'] 381 | else: 382 | raise ValueError 383 | 384 | train.train(model_dir, 385 | hp=hp, 386 | max_steps=hp['max_steps'], 387 | display_step=500, 388 | ruleset='all', 389 | rule_trains=rule_trains, 390 | rule_prob_map=None, 391 | seed=seed, 392 | ) 393 | 394 | 395 | def posttrain(pretrain_setup, posttrain_setup, trainables, seed): 396 | """Training based on pre-trained networks.""" 397 | hp = {'n_rnn': 128, 398 | 'l1_h': 1e-8, 399 | 'target_perf': 0.97, 400 | 'activation': 'relu', 401 | 'max_steps': 1e6, 402 | 'use_separate_input': True} 403 | 404 | if posttrain_setup == 0: 405 | rule_trains = ['contextdelaydm1'] 406 | elif posttrain_setup == 1: 407 | rule_trains = ['delayanti'] 408 | else: 409 | raise ValueError 410 | 411 | if trainables == 0: 412 | hp['trainables'] = 'all' 413 | elif trainables == 1: 414 | hp['trainables'] = 'rule' 415 | else: 416 | raise ValueError 417 | 418 | name = (str(pretrain_setup) + '_' + str(posttrain_setup) + 419 | '_' + str(trainables) + '_' + str(seed)) 420 | model_dir = os.path.join(DATAPATH, 'posttrain', name) 421 | load_dir = os.path.join(DATAPATH, 'pretrain', 422 | 'setup' + str(pretrain_setup), str(seed)) 423 | hp['load_dir'] = load_dir 424 | hp['pretrain_setup'] = pretrain_setup 425 | hp['posttrain_setup'] = posttrain_setup 426 | train.train(model_dir, 427 | hp=hp, 428 | max_steps=hp['max_steps'], 429 | display_step=50, 430 | ruleset='all', 431 | rule_trains=rule_trains, 432 | seed=seed, 433 | load_dir=load_dir, 434 | trainables=hp['trainables'], 435 | ) 436 | 437 | 438 | if __name__ == '__main__': 439 | debug_train_all() 440 | -------------------------------------------------------------------------------- /paper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main file for generating results in the paper: 3 | Clustering and compositionality of task representations 4 | in a neural network trained to perform many cognitive tasks 5 | Yang GR et al. 
2017 BioRxiv 6 | """ 7 | from __future__ import absolute_import 8 | 9 | import tools 10 | from analysis import performance 11 | from analysis import standard_analysis 12 | from analysis import clustering 13 | from analysis import variance 14 | from analysis import taskset 15 | from analysis import varyhp 16 | from analysis import data_analysis 17 | from analysis import contextdm_analysis 18 | from analysis import posttrain_analysis 19 | 20 | 21 | # Directories of the models and the sample model 22 | # Change these to your directories 23 | # root_dir = './data/tanhgru' 24 | root_dir = './data/train_all' 25 | model_dir = root_dir + '/1' 26 | 27 | 28 | # # Performance Analysis----------------------------------------------------- 29 | # standard_analysis.schematic_plot(model_dir=model_dir) 30 | # performance.plot_performanceprogress(model_dir) 31 | # performance.psychometric_choice(model_dir) # Psychometric for dm 32 | # performance.psychometric_choiceattend(model_dir, no_ylabel=True) 33 | # performance.psychometric_choiceint(model_dir, no_ylabel=True) 34 | # 35 | # for rule in ['dm1', 'contextdm1', 'multidm']: 36 | # performance.plot_choicefamily_varytime(model_dir, rule) 37 | # performance.psychometric_delaychoice_varytime(model_dir, 'delaydm1') 38 | # 39 | # 40 | # # Clustering Analysis------------------------------------------------------ 41 | model_dir = root_dir + '/1' 42 | # CA = clustering.Analysis(model_dir, data_type='rule') 43 | # CA.plot_example_unit() 44 | # CA.plot_cluster_score() 45 | # CA.plot_variance() 46 | # CA.plot_2Dvisualization('PCA') 47 | # CA.plot_2Dvisualization('MDS') 48 | # CA.plot_2Dvisualization('tSNE') 49 | # CA.plot_lesions() 50 | # CA.plot_connectivity_byclusters() 51 | # 52 | # 53 | CA = clustering.Analysis(model_dir, data_type='epoch') 54 | CA.plot_variance() 55 | # CA.plot_2Dvisualization('tSNE') 56 | # 57 | # 58 | # # Varying hyperparameter analysis------------------------------------------ 59 | # varyhp_root_dir = './data/varyhp' 60 | # n_clusters, hp_list = varyhp.get_n_clusters(varyhp_root_dir) 61 | # varyhp.plot_n_clusters(n_clusters, hp_list) 62 | # varyhp.plot_n_cluster_hist(n_clusters, hp_list) 63 | # 64 | # 65 | # # FTV Analysis------------------------------------------------------------- 66 | # variance.plot_hist_varprop_selection(root_dir) 67 | # variance.plot_hist_varprop_selection('./data/tanhgru') 68 | # variance.plot_hist_varprop_all(root_dir, plot_control=True) 69 | # 70 | # 71 | # # ContextDM analysis------------------------------------------------------- 72 | # ua = contextdm_analysis.UnitAnalysis(model_dir) 73 | # ua.plot_inout_connections() 74 | # ua.plot_rec_connections() 75 | # ua.plot_rule_connections() 76 | # ua.prettyplot_hist_varprop() 77 | # 78 | # contextdm_analysis.plot_performance_choicetasks(model_dir, grouping='var') 79 | # contextdm_analysis.plot_performance_2D_all(model_dir, 'contextdm1') 80 | # 81 | # 82 | # # Task Representation------------------------------------------------------ 83 | # tsa = taskset.TaskSetAnalysis(model_dir) 84 | # tsa.compute_and_plot_taskspace(epochs=['stim1'], dim_reduction_type='PCA') 85 | # 86 | # 87 | # # Compositional Representation--------------------------------------------- 88 | # setups = [1, 2, 3] 89 | # for setup in setups: 90 | # taskset.plot_taskspace_group(root_dir, setup=setup, 91 | # restore=True, representation='rate') 92 | # taskset.plot_taskspace_group(root_dir, setup=setup, 93 | # restore=True, representation='weight') 94 | # taskset.plot_replacerule_performance_group( 95 | # 
root_dir, setup=setup, restore=True) 96 | 97 | # name = 'tanhgru' 98 | # name = 'mixrule' 99 | # name = 'mixrule_softplus' 100 | # setups = [1, 2] 101 | # d = './data/' + name 102 | # for setup in setups: 103 | # taskset.plot_taskspace_group(d, setup=setup, 104 | # restore=False, representation='rate', 105 | # fig_name_addon=name) 106 | # taskset.plot_taskspace_group(d, setup=setup, 107 | # restore=True, representation='weight', 108 | # fig_name_addon=name) 109 | # taskset.plot_replacerule_performance_group( 110 | # d, setup=setup, restore=False, fig_name_addon=name) 111 | 112 | 113 | ## Continual Learning Analysis---------------------------------------------- 114 | # hp_target0 = {'c_intsyn': 0, 'ksi_intsyn': 0.01, 115 | # 'activation': 'relu', 'max_steps': 4e5} 116 | # hp_target1 = {'c_intsyn': 1, 'ksi_intsyn': 0.01, 117 | # 'activation': 'relu', 'max_steps': 4e5} 118 | # model_dirs0 = tools.find_all_models('data/seq/', hp_target0) 119 | # model_dirs1 = tools.find_all_models('data/seq/', hp_target1) 120 | # model_dirs0 = tools.select_by_perf(model_dirs0, perf_min=0.8) 121 | # model_dirs1 = tools.select_by_perf(model_dirs1, perf_min=0.8) 122 | # performance.plot_performanceprogress_cont((model_dirs0[0], model_dirs1[2])) 123 | # performance.plot_finalperformance_cont(model_dirs0, model_dirs1) 124 | # data_analysis.plot_fracvar_hist_byhp(hp_vary='c_intsyn', mode='all_var', legend=False) 125 | # data_analysis.plot_fracvar_hist_byhp(hp_vary='p_weight_train', mode='all_var') 126 | 127 | 128 | ## Data analysis------------------------------------------------------------ 129 | # Note that these wouldn't work without the data file 130 | # data_analysis.plot_all('mante_single_ar') 131 | # data_analysis.plot_all('mante_single_fe') 132 | # data_analysis.plot_all('mante_ar') 133 | # data_analysis.plot_all('mante_fe') 134 | 135 | ## Post-training of pre-trained networks------------------------------------ 136 | # for posttrain_setup in range(2): 137 | # for trainables in ['all', 'rule']: 138 | # posttrain_analysis.plot_posttrain_performance(posttrain_setup, trainables) 139 | -------------------------------------------------------------------------------- /submit_jobs.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | """ 3 | Launching jobs on the NYU cluster 4 | """ 5 | from __future__ import absolute_import 6 | 7 | import os 8 | import argparse 9 | import subprocess 10 | import numpy as np 11 | 12 | import tools 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('run') 16 | args = parser.parse_args() 17 | 18 | sbatchpath = './sbatch/' 19 | scratchpath = '/scratch/gy441/multitask/' 20 | 21 | 22 | def write_jobfile(cmd, jobname, sbatchpath, scratchpath, 23 | nodes=1, ppn=1, gpus=0, mem=16, nhours=18): 24 | """ 25 | Create a job file. 26 | 27 | Args: 28 | cmd : str, Command to execute. 29 | jobname : str, Name of the job. 30 | sbatchpath : str, Directory to store SBATCH file in. 31 | scratchpath : str, Directory to store output files in. 32 | nodes : int, optional, Number of compute nodes. 33 | ppn : int, optional, Number of cores per node. 34 | gpus : int, optional, Number of GPU cores. 35 | mem : int, optional, Amount, in GB, of memory. 36 | ndays : int, optional, Running time, in days. 37 | queue : str, optional, Queue name. 38 | 39 | Returns: 40 | jobfile : str, Path to the job file. 
41 | """ 42 | 43 | tools.mkdir_p(sbatchpath) 44 | jobfile = os.path.join(sbatchpath, jobname + '.s') 45 | logname = os.path.join('log', jobname) 46 | 47 | if gpus == 0: 48 | with open(jobfile, 'w') as f: 49 | f.write( 50 | '#! /bin/bash\n' 51 | + '\n' 52 | + '#SBATCH --nodes={}\n'.format(nodes) 53 | #+ '#SBATCH --ntasks=1\n' 54 | + '#SBATCH --ntasks-per-node=1\n' 55 | + '#SBATCH --cpus-per-task={}\n'.format(ppn) 56 | + '#SBATCH --mem={}GB\n'.format(mem) 57 | + '#SBATCH --time={}:00:00\n'.format(nhours) 58 | + '#SBATCH --job-name={}\n'.format(jobname[0:16]) 59 | + '#SBATCH --output={}log/{}.o\n'.format(scratchpath, jobname[0:16]) 60 | + '\n' 61 | + 'cd {}\n'.format(scratchpath) 62 | + 'pwd > {}.log\n'.format(logname) 63 | + 'date >> {}.log\n'.format(logname) 64 | + 'which python >> {}.log\n'.format(logname) 65 | + '{} >> {}.log 2>&1\n'.format(cmd, logname) 66 | + '\n' 67 | + 'exit 0;\n' 68 | ) 69 | else: 70 | with open(jobfile, 'w') as f: 71 | f.write( 72 | '#! /bin/bash\n' 73 | + '\n' 74 | + '#SBATCH --nodes={}\n'.format(nodes) 75 | + '#SBATCH --ntasks-per-node=1\n' 76 | + '#SBATCH --cpus-per-task={}\n'.format(ppn) 77 | + '#SBATCH --mem={}GB\n'.format(mem) 78 | + '#SBATCH --partition=xwang_gpu\n' 79 | + '#SBATCH --gres=gpu:1\n' 80 | + '#SBATCH --time={}:00:00\n'.format(nhours) 81 | + '#SBATCH --job-name={}\n'.format(jobname[0:16]) 82 | + '#SBATCH --output={}log/{}.o\n'.format(scratchpath, jobname[0:16]) 83 | + '\n' 84 | + 'cd {}\n'.format(scratchpath) 85 | + 'pwd > {}.log\n'.format(logname) 86 | + 'date >> {}.log\n'.format(logname) 87 | + 'which python >> {}.log\n'.format(logname) 88 | + '{} >> {}.log 2>&1\n'.format(cmd, logname) 89 | + '\n' 90 | + 'exit 0;\n' 91 | ) 92 | return jobfile 93 | 94 | 95 | if args.run == 'all': 96 | for seed in range(0, 40): 97 | jobname = 'train_all_{:d}'.format(seed) 98 | train_arg = 'seed={:d}'.format(seed) 99 | cmd = r'''python -c "import experiment as e;e.train_all('''+\ 100 | train_arg+''')"''' 101 | 102 | jobfile = write_jobfile( 103 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 104 | subprocess.call(['sbatch', jobfile]) 105 | 106 | elif args.run == 'analysis_all': 107 | for seed in range(0, 40): 108 | jobname = 'analysis_all_{:d}'.format(seed) 109 | train_arg = 'seed={:d}'.format(seed) 110 | cmd = r'''python -c "import experiment as e;e.train_all_analysis('''+\ 111 | train_arg+''')"''' 112 | 113 | jobfile = write_jobfile( 114 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 115 | subprocess.call(['sbatch', jobfile]) 116 | 117 | elif args.run == 'tanhgru': 118 | for seed in range(0, 20): 119 | jobname = 'tanhgru_{:d}'.format(seed) 120 | train_arg = 'seed={:d}'.format(seed) 121 | cmd = r'''python -c "import experiment as e;e.train_all_tanhgru('''+\ 122 | train_arg+''')"''' 123 | 124 | jobfile = write_jobfile( 125 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 126 | subprocess.call(['sbatch', jobfile]) 127 | 128 | elif args.run == 'mixrule': 129 | for seed in range(0, 20): 130 | jobname = 'mr_{:d}'.format(seed) 131 | train_arg = 'seed={:d}'.format(seed) 132 | cmd = r'''python -c "import experiment as e;e.train_all_mixrule('''+\ 133 | train_arg+''')"''' 134 | 135 | jobfile = write_jobfile( 136 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 137 | subprocess.call(['sbatch', jobfile]) 138 | 139 | elif args.run == 'mixrule_softplus': 140 | for seed in range(0, 20): 141 | jobname = 'mrsp_{:d}'.format(seed) 142 | train_arg = 'seed={:d}'.format(seed) 143 | cmd = r'''python -c "import experiment as 
e;e.train_all_mixrule_softplus('''+\ 144 | train_arg+''')"''' 145 | 146 | jobfile = write_jobfile( 147 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 148 | subprocess.call(['sbatch', jobfile]) 149 | 150 | elif args.run == 'all_varyhp': 151 | for i in range(0, 20): 152 | jobname = 'train_varyhp_{:d}'.format(i) 153 | train_arg = '{:d}'.format(i) 154 | cmd = r'''python -c "import experiment as e;e.train_vary_hp('''+\ 155 | train_arg+''')"''' 156 | 157 | jobfile = write_jobfile( 158 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 159 | subprocess.call(['sbatch', jobfile]) 160 | 161 | elif args.run == 'seq': 162 | for i in range(0, 40): 163 | jobname = 'seq_{:d}'.format(i) 164 | train_arg = '{:d}'.format(i) 165 | cmd = r'''python -c "import experiment as e;e.train_seq('''+\ 166 | train_arg+''')"''' 167 | 168 | jobfile = write_jobfile( 169 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 170 | subprocess.call(['sbatch', jobfile]) 171 | 172 | elif args.run == 'seq_varyhp': 173 | for i in range(0, 72): 174 | jobname = 'seq_varyhp_{:d}'.format(i) 175 | train_arg = '{:d}'.format(i) 176 | cmd = r'''python -c "import experiment as e;e.train_vary_hp_seq('''+\ 177 | train_arg+''')"''' 178 | 179 | jobfile = write_jobfile( 180 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 181 | subprocess.call(['sbatch', jobfile]) 182 | 183 | elif args.run == 'mante': 184 | for seed in range(0, 20): 185 | jobname = 'train_mante_{:d}'.format(seed) 186 | train_arg = 'seed={:d}'.format(seed) 187 | cmd = r'''python -c "import experiment as e;e.train_mante('''+\ 188 | train_arg+''')"''' 189 | 190 | jobfile = write_jobfile( 191 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 192 | subprocess.call(['sbatch', jobfile]) 193 | 194 | elif args.run == 'mante_tanh': 195 | for seed in range(0, 50): 196 | jobname = 'mantetanh_{:d}'.format(seed) 197 | train_arg = 'seed={:d}'.format(seed) 198 | cmd = r'''python -c "import experiment as e;e.mante_tanh('''+\ 199 | train_arg+''')"''' 200 | 201 | jobfile = write_jobfile( 202 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 203 | subprocess.call(['sbatch', jobfile]) 204 | 205 | elif args.run == 'mante_vary_l2init': 206 | for i in range(0, 300): 207 | jobname = 'mante_vary_l2init_{:d}'.format(i) 208 | train_arg = '{:d}'.format(i) 209 | cmd = r'''python -c "import experiment as e;e.vary_l2_init_mante('''+\ 210 | train_arg+''')"''' 211 | 212 | jobfile = write_jobfile( 213 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 214 | subprocess.call(['sbatch', jobfile]) 215 | 216 | elif args.run == 'mante_vary_l2weight': 217 | for i in range(0, 300): 218 | jobname = 'mante_vary_l2weight_{:d}'.format(i) 219 | train_arg = '{:d}'.format(i) 220 | cmd = r'''python -c "import experiment as e;e.vary_l2_weight_mante('''+\ 221 | train_arg+''')"''' 222 | 223 | jobfile = write_jobfile( 224 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 225 | subprocess.call(['sbatch', jobfile]) 226 | 227 | elif args.run == 'mante_vary_pweighttrain': 228 | for i in range(200, 260): 229 | jobname = 'mante_vary_pweighttrain_{:d}'.format(i) 230 | train_arg = '{:d}'.format(i) 231 | cmd = r'''python -c "import experiment as e;e.vary_p_weight_train_mante('''+\ 232 | train_arg+''')"''' 233 | 234 | jobfile = write_jobfile( 235 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 236 | subprocess.call(['sbatch', jobfile]) 237 | 238 | elif args.run == 'pretrain': 239 | for seed in range(0, 20): 240 | for setup in range(2): 241 | jobname = 'pt_{:d}_{:d}'.format(setup, seed) 
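            # The two lines below build the `python -c` command that the sbatch
            # job runs; e.g. for setup=1 and seed=3 the generated command is:
            #   python -c "import experiment as e;e.pretrain(setup=1,seed=3)"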
242 | train_arg = 'setup={:d},seed={:d}'.format(setup, seed) 243 | cmd = r'''python -c "import experiment as e;e.pretrain('''+\ 244 | train_arg+''')"''' 245 | 246 | jobfile = write_jobfile( 247 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 248 | subprocess.call(['sbatch', jobfile]) 249 | 250 | elif args.run == 'posttrain': 251 | for seed in range(0, 20): 252 | for pretrain_setup in range(2): 253 | for posttrain_setup in range(2): 254 | for trainables in range(2): 255 | jobname = 'pt{:d}{:d}{:d}{:d}'.format( 256 | pretrain_setup, posttrain_setup, trainables, seed) 257 | train_arg = '{:d}, {:d}, {:d}, {:d}'.format( 258 | pretrain_setup, posttrain_setup, trainables, seed) 259 | cmd = r'''python -c "import experiment as e;e.posttrain('''+\ 260 | train_arg+''')"''' 261 | 262 | jobfile = write_jobfile( 263 | cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 264 | subprocess.call(['sbatch', jobfile]) 265 | 266 | # Grid search 267 | elif args.run == 'grid': 268 | raise NotImplementedError() 269 | s = 1 270 | n_unit = 256 271 | for seed in range(5): 272 | for i_c, c_intsyn in enumerate([0.1, 1.0]): 273 | for i_ksi, ksi_intsyn in enumerate([0.01, 0.1, 1.0]): 274 | jobname = 'grid{:d}_{:d}_{:d}'.format(seed, i_ksi, i_c) 275 | train_arg = 'c={:0.6f}, ksi={:0.6f}, seed={:d}'.format( 276 | c_intsyn, ksi_intsyn, seed) 277 | train_arg+= r", save_name='"+'{:d}_{:d}_{:d}grid'.format(seed, i_ksi, i_c)+r"'" 278 | 279 | cmd = r'''python -c "import paper as p;p.cont_train('''+train_arg+''')"''' 280 | 281 | jobfile = write_jobfile(cmd, jobname, sbatchpath, scratchpath, ppn=1, gpus=0) 282 | subprocess.call(['sbatch', jobfile]) 283 | 284 | else: 285 | raise ValueError('Unknow argument run ' + str(args.run)) 286 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | 3 | import os 4 | import errno 5 | import six 6 | import json 7 | import pickle 8 | import numpy as np 9 | 10 | 11 | def gen_feed_dict(model, trial, hp): 12 | """Generate feed_dict for session run.""" 13 | if hp['in_type'] == 'normal': 14 | feed_dict = {model.x: trial.x, 15 | model.y: trial.y, 16 | model.c_mask: trial.c_mask} 17 | elif hp['in_type'] == 'multi': 18 | n_time, batch_size = trial.x.shape[:2] 19 | new_shape = [n_time, 20 | batch_size, 21 | hp['rule_start']*hp['n_rule']] 22 | 23 | x = np.zeros(new_shape, dtype=np.float32) 24 | for i in range(batch_size): 25 | ind_rule = np.argmax(trial.x[0, i, hp['rule_start']:]) 26 | i_start = ind_rule*hp['rule_start'] 27 | x[:, i, i_start:i_start+hp['rule_start']] = \ 28 | trial.x[:, i, :hp['rule_start']] 29 | 30 | feed_dict = {model.x: x, 31 | model.y: trial.y, 32 | model.c_mask: trial.c_mask} 33 | else: 34 | raise ValueError() 35 | 36 | return feed_dict 37 | 38 | 39 | def _contain_model_file(model_dir): 40 | """Check if the directory contains model files.""" 41 | for f in os.listdir(model_dir): 42 | if 'model.ckpt' in f: 43 | return True 44 | return False 45 | 46 | 47 | def _valid_model_dirs(root_dir): 48 | """Get valid model directories given a root directory.""" 49 | return [x[0] for x in os.walk(root_dir) if _contain_model_file(x[0])] 50 | 51 | 52 | def valid_model_dirs(root_dir): 53 | """Get valid model directories given a root directory(s). 
54 | 55 | Args: 56 | root_dir: str or list of strings 57 | """ 58 | if isinstance(root_dir, six.string_types): 59 | return _valid_model_dirs(root_dir) 60 | else: 61 | model_dirs = list() 62 | for d in root_dir: 63 | model_dirs.extend(_valid_model_dirs(d)) 64 | return model_dirs 65 | 66 | 67 | def load_log(model_dir): 68 | """Load the log file of model save_name""" 69 | fname = os.path.join(model_dir, 'log.json') 70 | if not os.path.isfile(fname): 71 | return None 72 | 73 | with open(fname, 'r') as f: 74 | log = json.load(f) 75 | return log 76 | 77 | 78 | def save_log(log): 79 | """Save the log file of model.""" 80 | model_dir = log['model_dir'] 81 | fname = os.path.join(model_dir, 'log.json') 82 | with open(fname, 'w') as f: 83 | json.dump(log, f) 84 | 85 | 86 | def load_hp(model_dir): 87 | """Load the hyper-parameter file of model save_name""" 88 | fname = os.path.join(model_dir, 'hp.json') 89 | if not os.path.isfile(fname): 90 | fname = os.path.join(model_dir, 'hparams.json') # backward compat 91 | if not os.path.isfile(fname): 92 | return None 93 | 94 | with open(fname, 'r') as f: 95 | hp = json.load(f) 96 | 97 | # Use a different seed aftering loading, 98 | # since loading is typically for analysis 99 | hp['rng'] = np.random.RandomState(hp['seed']+1000) 100 | return hp 101 | 102 | 103 | def save_hp(hp, model_dir): 104 | """Save the hyper-parameter file of model save_name""" 105 | hp_copy = hp.copy() 106 | hp_copy.pop('rng') # rng can not be serialized 107 | with open(os.path.join(model_dir, 'hp.json'), 'w') as f: 108 | json.dump(hp_copy, f) 109 | 110 | 111 | def load_pickle(file): 112 | try: 113 | with open(file, 'rb') as f: 114 | data = pickle.load(f) 115 | except UnicodeDecodeError as e: 116 | with open(file, 'rb') as f: 117 | data = pickle.load(f, encoding='latin1') 118 | except Exception as e: 119 | print('Unable to load data ', file, ':', e) 120 | raise 121 | return data 122 | 123 | 124 | def find_all_models(root_dir, hp_target): 125 | """Find all models that satisfy hyperparameters. 126 | 127 | Args: 128 | root_dir: root directory 129 | hp_target: dictionary of hyperparameters 130 | 131 | Returns: 132 | model_dirs: list of model directories 133 | """ 134 | dirs = valid_model_dirs(root_dir) 135 | 136 | model_dirs = list() 137 | for d in dirs: 138 | hp = load_hp(d) 139 | if all(hp[key] == val for key, val in hp_target.items()): 140 | model_dirs.append(d) 141 | 142 | return model_dirs 143 | 144 | 145 | def find_model(root_dir, hp_target, perf_min=None): 146 | """Find one model that satisfies hyperparameters. 147 | 148 | Args: 149 | root_dir: root directory 150 | hp_target: dictionary of hyperparameters 151 | perf_min: float or None. 
If not None, minimum performance to be chosen 152 | 153 | Returns: 154 | d: model directory 155 | """ 156 | model_dirs = find_all_models(root_dir, hp_target) 157 | if perf_min is not None: 158 | model_dirs = select_by_perf(model_dirs, perf_min) 159 | 160 | if not model_dirs: 161 | # If list empty 162 | print('Model not found') 163 | return None, None 164 | 165 | d = model_dirs[0] 166 | hp = load_hp(d) 167 | 168 | log = load_log(d) 169 | # check if performance exceeds target 170 | if log['perf_min'][-1] < hp['target_perf']: 171 | print("""Warning: this network perform {:0.2f}, not reaching target 172 | performance {:0.2f}.""".format( 173 | log['perf_min'][-1], hp['target_perf'])) 174 | 175 | return d 176 | 177 | 178 | def select_by_perf(model_dirs, perf_min): 179 | """Select a list of models by a performance threshold.""" 180 | new_model_dirs = list() 181 | for model_dir in model_dirs: 182 | log = load_log(model_dir) 183 | # check if performance exceeds target 184 | if log['perf_min'][-1] > perf_min: 185 | new_model_dirs.append(model_dir) 186 | return new_model_dirs 187 | 188 | 189 | def mkdir_p(path): 190 | """ 191 | Portable mkdir -p 192 | 193 | """ 194 | try: 195 | os.makedirs(path) 196 | except OSError as e: 197 | if e.errno == errno.EEXIST and os.path.isdir(path): 198 | pass 199 | else: 200 | raise 201 | 202 | 203 | def gen_ortho_matrix(dim, rng=None): 204 | """Generate random orthogonal matrix 205 | Taken from scipy.stats.ortho_group 206 | Copied here from compatibilty with older versions of scipy 207 | """ 208 | H = np.eye(dim) 209 | for n in range(1, dim): 210 | if rng is None: 211 | x = np.random.normal(size=(dim-n+1,)) 212 | else: 213 | x = rng.normal(size=(dim-n+1,)) 214 | # random sign, 50/50, but chosen carefully to avoid roundoff error 215 | D = np.sign(x[0]) 216 | x[0] += D*np.sqrt((x*x).sum()) 217 | # Householder transformation 218 | Hx = -D*(np.eye(dim-n+1) - 2.*np.outer(x, x)/(x*x).sum()) 219 | mat = np.eye(dim) 220 | mat[n-1:, n-1:] = Hx 221 | H = np.dot(H, mat) 222 | return H 223 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """Main training loop""" 2 | 3 | from __future__ import division 4 | 5 | import sys 6 | import time 7 | from collections import defaultdict 8 | 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | import tensorflow as tf 12 | 13 | import task 14 | from task import generate_trials 15 | from network import Model, get_perf 16 | from analysis import variance 17 | import tools 18 | 19 | 20 | def get_default_hp(ruleset): 21 | '''Get a default hp. 22 | 23 | Useful for debugging. 
24 | 
25 |     Returns:
26 |         hp : a dictionary containing the training configuration
27 |     '''
28 |     num_ring = task.get_num_ring(ruleset)
29 |     n_rule = task.get_num_rule(ruleset)
30 | 
31 |     n_eachring = 32
32 |     n_input, n_output = 1+num_ring*n_eachring+n_rule, n_eachring+1
33 |     hp = {
34 |         # batch size for training
35 |         'batch_size_train': 64,
36 |         # batch_size for testing
37 |         'batch_size_test': 512,
38 |         # input type: normal, multi
39 |         'in_type': 'normal',
40 |         # Type of RNNs: LeakyRNN, LeakyGRU, EILeakyGRU, GRU, LSTM
41 |         'rnn_type': 'LeakyRNN',
42 |         # whether rule and stimulus inputs are represented separately
43 |         'use_separate_input': False,
44 |         # Type of loss functions
45 |         'loss_type': 'lsq',
46 |         # Optimizer
47 |         'optimizer': 'adam',
48 |         # Type of activation functions: relu, softplus, tanh, elu
49 |         'activation': 'relu',
50 |         # Time constant (ms)
51 |         'tau': 100,
52 |         # discretization time step (ms)
53 |         'dt': 20,
54 |         # discretization time step / time constant (alpha = dt/tau = 20/100 = 0.2)
55 |         'alpha': 0.2,
56 |         # recurrent noise
57 |         'sigma_rec': 0.05,
58 |         # input noise
59 |         'sigma_x': 0.01,
60 |         # leaky_rec weight initialization, diag, randortho, randgauss
61 |         'w_rec_init': 'randortho',
62 |         # l1 regularization on activity (a weak default regularization prevents instability)
63 |         'l1_h': 0,
64 |         # l2 regularization on activity
65 |         'l2_h': 0,
66 |         # l1 regularization on weight
67 |         'l1_weight': 0,
68 |         # l2 regularization on weight
69 |         'l2_weight': 0,
70 |         # l2 regularization on deviation from initialization
71 |         'l2_weight_init': 0,
72 |         # proportion of weights to train, None or float between (0, 1)
73 |         'p_weight_train': None,
74 |         # Stopping performance
75 |         'target_perf': 1.,
76 |         # number of units each ring
77 |         'n_eachring': n_eachring,
78 |         # number of rings
79 |         'num_ring': num_ring,
80 |         # number of rules
81 |         'n_rule': n_rule,
82 |         # first input index for rule units
83 |         'rule_start': 1+num_ring*n_eachring,
84 |         # number of input units
85 |         'n_input': n_input,
86 |         # number of output units
87 |         'n_output': n_output,
88 |         # number of recurrent units
89 |         'n_rnn': 256,
90 |         # ruleset this configuration is built for
91 |         'ruleset': ruleset,
92 |         # name to save
93 |         'save_name': 'test',
94 |         # learning rate
95 |         'learning_rate': 0.001,
96 |         # intelligent synapses parameters, tuple (c, ksi)
97 |         'c_intsyn': 0,
98 |         'ksi_intsyn': 0,
99 |     }
100 | 
101 |     return hp
102 | 
103 | 
104 | def do_eval(sess, model, log, rule_train):
105 |     """Do evaluation.
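    For every rule in hp['rules'], this runs n_rep batches of freshly generated
    trials, then appends the mean least-squares cost, regularization cost and
    performance to log['cost_<rule>'], log['creg_<rule>'] and log['perf_<rule>'].
    The mean and minimum performance over the rules currently being trained are
    appended to log['perf_avg'] and log['perf_min'], and the model and log are
    saved to disk.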
106 | 107 | Args: 108 | sess: tensorflow session 109 | model: Model class instance 110 | log: dictionary that stores the log 111 | rule_train: string or list of strings, the rules being trained 112 | """ 113 | hp = model.hp 114 | if not hasattr(rule_train, '__iter__'): 115 | rule_name_print = rule_train 116 | else: 117 | rule_name_print = ' & '.join(rule_train) 118 | 119 | print('Trial {:7d}'.format(log['trials'][-1]) + 120 | ' | Time {:0.2f} s'.format(log['times'][-1]) + 121 | ' | Now training '+rule_name_print) 122 | 123 | for rule_test in hp['rules']: 124 | n_rep = 16 125 | batch_size_test_rep = int(hp['batch_size_test']/n_rep) 126 | clsq_tmp = list() 127 | creg_tmp = list() 128 | perf_tmp = list() 129 | for i_rep in range(n_rep): 130 | trial = generate_trials( 131 | rule_test, hp, 'random', batch_size=batch_size_test_rep) 132 | feed_dict = tools.gen_feed_dict(model, trial, hp) 133 | c_lsq, c_reg, y_hat_test = sess.run( 134 | [model.cost_lsq, model.cost_reg, model.y_hat], 135 | feed_dict=feed_dict) 136 | 137 | # Cost is first summed over time, 138 | # and averaged across batch and units 139 | # We did the averaging over time through c_mask 140 | perf_test = np.mean(get_perf(y_hat_test, trial.y_loc)) 141 | clsq_tmp.append(c_lsq) 142 | creg_tmp.append(c_reg) 143 | perf_tmp.append(perf_test) 144 | 145 | log['cost_'+rule_test].append(np.mean(clsq_tmp, dtype=np.float64)) 146 | log['creg_'+rule_test].append(np.mean(creg_tmp, dtype=np.float64)) 147 | log['perf_'+rule_test].append(np.mean(perf_tmp, dtype=np.float64)) 148 | print('{:15s}'.format(rule_test) + 149 | '| cost {:0.6f}'.format(np.mean(clsq_tmp)) + 150 | '| c_reg {:0.6f}'.format(np.mean(creg_tmp)) + 151 | ' | perf {:0.2f}'.format(np.mean(perf_tmp))) 152 | sys.stdout.flush() 153 | 154 | # TODO: This needs to be fixed since now rules are strings 155 | if hasattr(rule_train, '__iter__'): 156 | rule_tmp = rule_train 157 | else: 158 | rule_tmp = [rule_train] 159 | perf_tests_mean = np.mean([log['perf_'+r][-1] for r in rule_tmp]) 160 | log['perf_avg'].append(perf_tests_mean) 161 | 162 | perf_tests_min = np.min([log['perf_'+r][-1] for r in rule_tmp]) 163 | log['perf_min'].append(perf_tests_min) 164 | 165 | # Saving the model 166 | model.save() 167 | tools.save_log(log) 168 | 169 | return log 170 | 171 | 172 | def display_rich_output(model, sess, step, log, model_dir): 173 | """Display step by step outputs during training.""" 174 | variance._compute_variance_bymodel(model, sess) 175 | rule_pair = ['contextdm1', 'contextdm2'] 176 | save_name = '_atstep' + str(step) 177 | title = ('Step ' + str(step) + 178 | ' Perf. {:0.2f}'.format(log['perf_avg'][-1])) 179 | variance.plot_hist_varprop(model_dir, rule_pair, 180 | figname_extra=save_name, 181 | title=title) 182 | plt.close('all') 183 | 184 | 185 | def train(model_dir, 186 | hp=None, 187 | max_steps=1e7, 188 | display_step=500, 189 | ruleset='mante', 190 | rule_trains=None, 191 | rule_prob_map=None, 192 | seed=0, 193 | rich_output=False, 194 | load_dir=None, 195 | trainables=None, 196 | ): 197 | """Train the network. 
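    Each training step samples one rule from rule_trains according to
    rule_prob_map, which holds relative (unnormalized) weights; for example,
    rule_prob_map={'contextdm1': 5} (an arbitrary illustration) makes that rule
    five times as likely to be sampled as each unlisted rule.  The trainables
    argument can restrict optimization to all variables ('all' or None), the
    input weights ('input'), or only the rule-input weights ('rule').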
198 | 199 | Args: 200 | model_dir: str, training directory 201 | hp: dictionary of hyperparameters 202 | max_steps: int, maximum number of training steps 203 | display_step: int, display steps 204 | ruleset: the set of rules to train 205 | rule_trains: list of rules to train, if None then all rules possible 206 | rule_prob_map: None or dictionary of relative rule probability 207 | seed: int, random seed to be used 208 | 209 | Returns: 210 | model is stored at model_dir/model.ckpt 211 | training configuration is stored at model_dir/hp.json 212 | """ 213 | 214 | tools.mkdir_p(model_dir) 215 | 216 | # Network parameters 217 | default_hp = get_default_hp(ruleset) 218 | if hp is not None: 219 | default_hp.update(hp) 220 | hp = default_hp 221 | hp['seed'] = seed 222 | hp['rng'] = np.random.RandomState(seed) 223 | 224 | # Rules to train and test. Rules in a set are trained together 225 | if rule_trains is None: 226 | # By default, training all rules available to this ruleset 227 | hp['rule_trains'] = task.rules_dict[ruleset] 228 | else: 229 | hp['rule_trains'] = rule_trains 230 | hp['rules'] = hp['rule_trains'] 231 | 232 | # Assign probabilities for rule_trains. 233 | if rule_prob_map is None: 234 | rule_prob_map = dict() 235 | 236 | # Turn into rule_trains format 237 | hp['rule_probs'] = None 238 | if hasattr(hp['rule_trains'], '__iter__'): 239 | # Set default as 1. 240 | rule_prob = np.array( 241 | [rule_prob_map.get(r, 1.) for r in hp['rule_trains']]) 242 | hp['rule_probs'] = list(rule_prob/np.sum(rule_prob)) 243 | tools.save_hp(hp, model_dir) 244 | 245 | # Build the model 246 | model = Model(model_dir, hp=hp) 247 | 248 | # Display hp 249 | for key, val in hp.items(): 250 | print('{:20s} = '.format(key) + str(val)) 251 | 252 | # Store results 253 | log = defaultdict(list) 254 | log['model_dir'] = model_dir 255 | 256 | # Record time 257 | t_start = time.time() 258 | 259 | with tf.Session() as sess: 260 | if load_dir is not None: 261 | model.restore(load_dir) # complete restore 262 | else: 263 | # Assume everything is restored 264 | sess.run(tf.global_variables_initializer()) 265 | 266 | # Set trainable parameters 267 | if trainables is None or trainables == 'all': 268 | var_list = model.var_list # train everything 269 | elif trainables == 'input': 270 | # train all nputs 271 | var_list = [v for v in model.var_list 272 | if ('input' in v.name) and ('rnn' not in v.name)] 273 | elif trainables == 'rule': 274 | # train rule inputs only 275 | var_list = [v for v in model.var_list if 'rule_input' in v.name] 276 | else: 277 | raise ValueError('Unknown trainables') 278 | model.set_optimizer(var_list=var_list) 279 | 280 | # penalty on deviation from initial weight 281 | if hp['l2_weight_init'] > 0: 282 | anchor_ws = sess.run(model.weight_list) 283 | for w, w_val in zip(model.weight_list, anchor_ws): 284 | model.cost_reg += (hp['l2_weight_init'] * 285 | tf.nn.l2_loss(w - w_val)) 286 | 287 | model.set_optimizer(var_list=var_list) 288 | 289 | # partial weight training 290 | if ('p_weight_train' in hp and 291 | (hp['p_weight_train'] is not None) and 292 | hp['p_weight_train'] < 1.0): 293 | for w in model.weight_list: 294 | w_val = sess.run(w) 295 | w_size = sess.run(tf.size(w)) 296 | w_mask_tmp = np.linspace(0, 1, w_size) 297 | hp['rng'].shuffle(w_mask_tmp) 298 | ind_fix = w_mask_tmp > hp['p_weight_train'] 299 | w_mask = np.zeros(w_size, dtype=np.float32) 300 | w_mask[ind_fix] = 1e-1 # will be squared in l2_loss 301 | w_mask = tf.constant(w_mask) 302 | w_mask = tf.reshape(w_mask, w.shape) 303 | 
model.cost_reg += tf.nn.l2_loss((w - w_val) * w_mask) 304 | model.set_optimizer(var_list=var_list) 305 | 306 | step = 0 307 | while step * hp['batch_size_train'] <= max_steps: 308 | try: 309 | # Validation 310 | if step % display_step == 0: 311 | log['trials'].append(step * hp['batch_size_train']) 312 | log['times'].append(time.time()-t_start) 313 | log = do_eval(sess, model, log, hp['rule_trains']) 314 | #if log['perf_avg'][-1] > model.hp['target_perf']: 315 | #check if minimum performance is above target 316 | if log['perf_min'][-1] > model.hp['target_perf']: 317 | print('Perf reached the target: {:0.2f}'.format( 318 | hp['target_perf'])) 319 | break 320 | 321 | if rich_output: 322 | display_rich_output(model, sess, step, log, model_dir) 323 | 324 | # Training 325 | rule_train_now = hp['rng'].choice(hp['rule_trains'], 326 | p=hp['rule_probs']) 327 | # Generate a random batch of trials. 328 | # Each batch has the same trial length 329 | trial = generate_trials( 330 | rule_train_now, hp, 'random', 331 | batch_size=hp['batch_size_train']) 332 | 333 | # Generating feed_dict. 334 | feed_dict = tools.gen_feed_dict(model, trial, hp) 335 | sess.run(model.train_step, feed_dict=feed_dict) 336 | 337 | step += 1 338 | 339 | except KeyboardInterrupt: 340 | print("Optimization interrupted by user") 341 | break 342 | 343 | print("Optimization finished!") 344 | 345 | 346 | def train_sequential( 347 | model_dir, 348 | rule_trains, 349 | hp=None, 350 | max_steps=1e7, 351 | display_step=500, 352 | ruleset='mante', 353 | seed=0, 354 | ): 355 | '''Train the network sequentially. 356 | 357 | Args: 358 | model_dir: str, training directory 359 | rule_trains: a list of list of tasks to train sequentially 360 | hp: dictionary of hyperparameters 361 | max_steps: int, maximum number of training steps for each list of tasks 362 | display_step: int, display steps 363 | ruleset: the set of rules to train 364 | seed: int, random seed to be used 365 | 366 | Returns: 367 | model is stored at model_dir/model.ckpt 368 | training configuration is stored at model_dir/hp.json 369 | ''' 370 | 371 | tools.mkdir_p(model_dir) 372 | 373 | # Network parameters 374 | default_hp = get_default_hp(ruleset) 375 | if hp is not None: 376 | default_hp.update(hp) 377 | hp = default_hp 378 | hp['seed'] = seed 379 | hp['rng'] = np.random.RandomState(seed) 380 | hp['rule_trains'] = rule_trains 381 | # Get all rules by flattening the list of lists 382 | hp['rules'] = [r for rs in rule_trains for r in rs] 383 | 384 | # Number of training iterations for each rule 385 | rule_train_iters = [len(r)*max_steps for r in rule_trains] 386 | 387 | tools.save_hp(hp, model_dir) 388 | # Display hp 389 | for key, val in hp.items(): 390 | print('{:20s} = '.format(key) + str(val)) 391 | 392 | # Using continual learning or not 393 | c, ksi = hp['c_intsyn'], hp['ksi_intsyn'] 394 | 395 | # Build the model 396 | model = Model(model_dir, hp=hp) 397 | 398 | grad_unreg = tf.gradients(model.cost_lsq, model.var_list) 399 | 400 | # Store results 401 | log = defaultdict(list) 402 | log['model_dir'] = model_dir 403 | 404 | # Record time 405 | t_start = time.time() 406 | 407 | # tensorboard summaries 408 | placeholders = list() 409 | for v_name in ['Omega0', 'omega0', 'vdelta']: 410 | for v in model.var_list: 411 | placeholder = tf.placeholder(tf.float32, shape=v.shape) 412 | tf.summary.histogram(v_name + '/' + v.name, placeholder) 413 | placeholders.append(placeholder) 414 | merged = tf.summary.merge_all() 415 | test_writer = tf.summary.FileWriter(model_dir + '/tb') 416 | 
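    # Overview of the continual-learning ('intelligent synapses') bookkeeping
    # implemented in the loop below:
    #   * While training one task, omega0 accumulates, per parameter,
    #     -(w_after_step - w_before_step) * grad_unreg, i.e. how much each
    #     small parameter change contributed to lowering the unregularized cost.
    #   * At a task switch, importance is consolidated as
    #       Omega0 <- relu(Omega0 + omega0 / (v_delta**2 + ksi)),
    #     where v_delta is the total parameter change over the finished task and
    #     ksi ('ksi_intsyn') prevents division by zero for parameters that
    #     barely moved.
    #   * For the next task, the regularizer
    #       c * sum(Omega0 * (v - v_at_switch)**2),
    #     with c = 'c_intsyn' and v_at_switch the parameter values at the switch,
    #     penalizes changes to parameters that were important for earlier tasks.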
417 | def relu(x): 418 | return x * (x > 0.) 419 | 420 | # Use customized session that launches the graph as well 421 | with tf.Session() as sess: 422 | sess.run(tf.global_variables_initializer()) 423 | 424 | # penalty on deviation from initial weight 425 | if hp['l2_weight_init'] > 0: 426 | raise NotImplementedError() 427 | 428 | # Looping 429 | step_total = 0 430 | for i_rule_train, rule_train in enumerate(hp['rule_trains']): 431 | step = 0 432 | 433 | # At the beginning of new tasks 434 | # Only if using intelligent synapses 435 | v_current = sess.run(model.var_list) 436 | 437 | if i_rule_train == 0: 438 | v_anc0 = v_current 439 | Omega0 = [np.zeros(v.shape, dtype='float32') for v in v_anc0] 440 | omega0 = [np.zeros(v.shape, dtype='float32') for v in v_anc0] 441 | v_delta = [np.zeros(v.shape, dtype='float32') for v in v_anc0] 442 | elif c > 0: 443 | v_anc0_prev = v_anc0 444 | v_anc0 = v_current 445 | v_delta = [v-v_prev for v, v_prev in zip(v_anc0, v_anc0_prev)] 446 | 447 | # Make sure all elements in omega0 are non-negative 448 | # Penalty 449 | Omega0 = [relu(O + o / (v_d ** 2 + ksi)) 450 | for O, o, v_d in zip(Omega0, omega0, v_delta)] 451 | 452 | # Update cost 453 | model.cost_reg = tf.constant(0.) 454 | for v, w, v_val in zip(model.var_list, Omega0, v_current): 455 | model.cost_reg += c * tf.reduce_sum( 456 | tf.multiply(tf.constant(w), 457 | tf.square(v - tf.constant(v_val)))) 458 | model.set_optimizer() 459 | 460 | # Store Omega0 to tf summary 461 | feed_dict = dict(zip(placeholders, Omega0 + omega0 + v_delta)) 462 | summary = sess.run(merged, feed_dict=feed_dict) 463 | test_writer.add_summary(summary, i_rule_train) 464 | 465 | # Reset 466 | omega0 = [np.zeros(v.shape, dtype='float32') for v in v_anc0] 467 | 468 | # Keep training until reach max iterations 469 | while (step * hp['batch_size_train'] <= 470 | rule_train_iters[i_rule_train]): 471 | # Validation 472 | if step % display_step == 0: 473 | trial = step_total * hp['batch_size_train'] 474 | log['trials'].append(trial) 475 | log['times'].append(time.time()-t_start) 476 | log['rule_now'].append(rule_train) 477 | log = do_eval(sess, model, log, rule_train) 478 | if log['perf_avg'][-1] > model.hp['target_perf']: 479 | print('Perf reached the target: {:0.2f}'.format( 480 | hp['target_perf'])) 481 | break 482 | 483 | # Training 484 | rule_train_now = hp['rng'].choice(rule_train) 485 | # Generate a random batch of trials. 486 | # Each batch has the same trial length 487 | trial = generate_trials( 488 | rule_train_now, hp, 'random', 489 | batch_size=hp['batch_size_train']) 490 | 491 | # Generating feed_dict. 492 | feed_dict = tools.gen_feed_dict(model, trial, hp) 493 | 494 | # Continual learning with intelligent synapses 495 | v_prev = v_current 496 | 497 | # This will compute the gradient BEFORE train step 498 | _, v_grad = sess.run([model.train_step, grad_unreg], 499 | feed_dict=feed_dict) 500 | # Get the weight after train step 501 | v_current = sess.run(model.var_list) 502 | 503 | # Update synaptic importance 504 | omega0 = [ 505 | o - (v_c - v_p) * v_g for o, v_c, v_p, v_g in 506 | zip(omega0, v_current, v_prev, v_grad) 507 | ] 508 | 509 | step += 1 510 | step_total += 1 511 | 512 | print("Optimization Finished!") 513 | 514 | 515 | def train_rule_only( 516 | model_dir, 517 | rule_trains, 518 | max_steps, 519 | hp=None, 520 | ruleset='all', 521 | seed=0, 522 | ): 523 | '''Customized training function. 524 | 525 | The network sequentially but only train rule for the second set. 
526 | First train the network to perform tasks in group 1, then train on group 2. 527 | When training group 2, only rule connections are being trained. 528 | 529 | Args: 530 | model_dir: str, training directory 531 | rule_trains: a list of list of tasks to train sequentially 532 | hp: dictionary of hyperparameters 533 | max_steps: int, maximum number of training steps for each list of tasks 534 | display_step: int, display steps 535 | ruleset: the set of rules to train 536 | seed: int, random seed to be used 537 | 538 | Returns: 539 | model is stored at model_dir/model.ckpt 540 | training configuration is stored at model_dir/hp.json 541 | ''' 542 | 543 | tools.mkdir_p(model_dir) 544 | 545 | # Network parameters 546 | default_hp = get_default_hp(ruleset) 547 | if hp is not None: 548 | default_hp.update(hp) 549 | hp = default_hp 550 | hp['seed'] = seed 551 | hp['rng'] = np.random.RandomState(seed) 552 | hp['rule_trains'] = rule_trains 553 | # Get all rules by flattening the list of lists 554 | hp['rules'] = [r for rs in rule_trains for r in rs] 555 | 556 | # Number of training iterations for each rule 557 | if hasattr(max_steps, '__iter__'): 558 | rule_train_iters = max_steps 559 | else: 560 | rule_train_iters = [len(r) * max_steps for r in rule_trains] 561 | 562 | tools.save_hp(hp, model_dir) 563 | # Display hp 564 | for key, val in hp.items(): 565 | print('{:20s} = '.format(key) + str(val)) 566 | 567 | # Build the model 568 | model = Model(model_dir, hp=hp) 569 | 570 | # Store results 571 | log = defaultdict(list) 572 | log['model_dir'] = model_dir 573 | 574 | # Record time 575 | t_start = time.time() 576 | 577 | # Use customized session that launches the graph as well 578 | with tf.Session() as sess: 579 | sess.run(tf.global_variables_initializer()) 580 | 581 | # penalty on deviation from initial weight 582 | if hp['l2_weight_init'] > 0: 583 | raise NotImplementedError() 584 | 585 | # Looping 586 | step_total = 0 587 | for i_rule_train, rule_train in enumerate(hp['rule_trains']): 588 | step = 0 589 | 590 | if i_rule_train == 0: 591 | display_step = 200 592 | else: 593 | display_step = 50 594 | 595 | if i_rule_train > 0: 596 | # var_list = [v for v in model.var_list 597 | # if ('input' in v.name) and ('rnn' not in v.name)] 598 | var_list = [v for v in model.var_list if 'rule_input' in v.name] 599 | model.set_optimizer(var_list=var_list) 600 | 601 | # Keep training until reach max iterations 602 | while (step * hp['batch_size_train'] <= 603 | rule_train_iters[i_rule_train]): 604 | # Validation 605 | if step % display_step == 0: 606 | trial = step_total * hp['batch_size_train'] 607 | log['trials'].append(trial) 608 | log['times'].append(time.time() - t_start) 609 | log['rule_now'].append(rule_train) 610 | log = do_eval(sess, model, log, rule_train) 611 | if log['perf_avg'][-1] > model.hp['target_perf']: 612 | print('Perf reached the target: {:0.2f}'.format( 613 | hp['target_perf'])) 614 | break 615 | 616 | # Training 617 | rule_train_now = hp['rng'].choice(rule_train) 618 | # Generate a random batch of trials. 619 | # Each batch has the same trial length 620 | trial = generate_trials( 621 | rule_train_now, hp, 'random', 622 | batch_size=hp['batch_size_train']) 623 | 624 | # Generating feed_dict. 
625 | feed_dict = tools.gen_feed_dict(model, trial, hp) 626 | 627 | # This will compute the gradient BEFORE train step 628 | _ = sess.run(model.train_step, feed_dict=feed_dict) 629 | 630 | step += 1 631 | step_total += 1 632 | 633 | print("Optimization Finished!") 634 | 635 | 636 | if __name__ == '__main__': 637 | import argparse 638 | import os 639 | parser = argparse.ArgumentParser( 640 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 641 | 642 | parser.add_argument('--modeldir', type=str, default='data/debug') 643 | args = parser.parse_args() 644 | 645 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 646 | hp = {'activation': 'softplus', 647 | 'n_rnn': 64, 648 | 'mix_rule': True, 649 | 'l1_h': 0., 650 | 'use_separate_input': True} 651 | train(args.modeldir, 652 | seed=1, 653 | hp=hp, 654 | ruleset='all', 655 | rule_trains=['contextdelaydm1', 'contextdelaydm2', 656 | 'contextdm1', 'contextdm2'], 657 | display_step=500) --------------------------------------------------------------------------------
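
A minimal sketch of how the training entry points above can be combined with the
utilities in tools.py (the directory name, rule groups, and the c_intsyn/ksi_intsyn
values below are illustrative, not recommended settings):

    import train
    import tools

    # Train two groups of tasks one after the other; setting 'c_intsyn' > 0
    # enables the intelligent-synapses continual-learning regularizer.
    train.train_sequential(
        'data/seq_debug',
        rule_trains=[['contextdm1', 'contextdm2'],
                     ['contextdelaydm1', 'contextdelaydm2']],
        hp={'c_intsyn': 0.1, 'ksi_intsyn': 0.01},
        max_steps=4e5,
        ruleset='all')

    # Afterwards, locate trained runs by hyperparameter.
    model_dirs = tools.find_all_models('data', {'ruleset': 'all'})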