├── run_a_trial.py ├── README.md └── FlexibleWM.py /run_a_trial.py: -------------------------------------------------------------------------------- 1 | # Bouchacourt and Buschman 2019 2 | # A Flexible Model of Working Memory 3 | 4 | from FlexibleWM import * 5 | 6 | dictionnary = {} # Add here any parameter you want to change from its default. Default values are listed at the beginning of FlexibleWM.py 7 | MyModel = FlexibleWM(dictionnary) 8 | MyModel.run_a_trial() 9 | gcPython.collect() 10 | 11 | 12 | # If you want to create a specific network first, you can call MyModel.initialize_weights() with a dictionary including 'create_a_specific_network':True ; then call MyModel.run_a_trial() using a dictionary including 'same_network_to_use':True 13 | 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FlexibleWorkingMemory 2 | Simplified code for the paper "A Flexible Model of Working Memory" (Bouchacourt and Buschman, 2019). 3 | This code is written in Python and uses the Brian2 spiking network simulator (https://brian2.readthedocs.io/en/stable/), so you will need brian2 and brian2tools to run it. 4 | 5 | run_a_trial.py calls FlexibleWM.py. 6 | 7 | In run_a_trial.py, you can add to the dictionary (the variable named dictionnary) the parameters you want to change from their defaults. 8 | To see the parameter names and their default values, go to the beginning of FlexibleWM.py. 9 | 10 | Example: in run_a_trial.py, change dictionnary={} to dictionnary={'value_of_specific_load':6, 'specific_load':True} to reproduce Figure 1; otherwise a random initial load is chosen at each simulation, as described. 11 | 12 | It should create a folder with 2 compressed files (.npz) containing the simulation results (simulation_results.npz, written at line 471 of FlexibleWM.py) and the tuning curves of the neurons (Matrix_tuning.npz). 
It should also create a .png image called rasterplot_trial0.png showing a raster plot similar to the one in Figure 1 of the paper. Note that the number of forgotten memories varies from trial to trial. 13 | 14 | To run multiple simulations over multiple networks on a cluster, you can use the Python multiprocessing package. If you plan to use it, we can push the script. In general, it is better to use Slurm job arrays instead. 15 | If you are interested in seeing the scripts for specific paper analyses, or for the 2D surface or Hopfield-like sensory networks, please email us and we will push them. 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /FlexibleWM.py: -------------------------------------------------------------------------------- 1 | # Bouchacourt and Buschman 2019 2 | # A Flexible Model of Working Memory 3 | # This is a simplified version of the code. Specific analyses of the paper can be pushed on demand. 4 | 5 | 6 | import pickle, numpy, scipy, pylab, os 7 | import pdb 8 | import scipy.stats 9 | import math 10 | import sys 11 | import os.path 12 | import logging 13 | from brian2 import * 14 | BrianLogger.log_level_error() 15 | from brian2tools import * 16 | import cython 17 | prefs.codegen.target = 'cython' 18 | import gc as gcPython 19 | 20 | class FlexibleWM: 21 | """Class running one trial of the network. 
def __init__(self, spec):
    """Build ``self.specF``, the complete parameter dictionary of the model.

    Every entry is read from ``spec`` with ``spec.get(key, default)``, so
    ``spec`` only needs to contain the parameters that differ from their
    default value. A few entries (``InputWidth``, ``path_to_save``,
    ``path_sim``) are always derived and cannot be overridden.

    Side effect: the results folder ``<name_simu>/`` is created through
    ``define_path_to_save_results_from_the_trial``.
    """
    self.specF = {}
    self.specF['name_simu'] = spec.get('name_simu', 'FlexibleWM')  # name of the folder used to save results and raster plots
    self.specF['Number_of_trials'] = spec.get('Number_of_trials', 1)  # increase to run several trials with the same network

    # Timing
    self.specF['clock'] = spec.get('clock', 0.1)  # integration time step, in msec
    self.specF['simtime'] = spec.get('simtime', 1.1)  # simulation time, in sec
    self.specF['RecSynapseTau'] = spec.get('RecSynapseTau', 0.010)  # synaptic time constant
    self.specF['window_save_data'] = spec.get('window_save_data', 0.1)  # time step for saving data such as rates; also the default time step for saving spikes
    # The ratio itself is not checked: the reminder is printed as soon as
    # either value differs from its default.
    if self.specF['window_save_data'] != 0.1 or self.specF['simtime'] != 1.1:
        print(" ----------- WARNING : Make sure simtime divided by window_save_data is an integer ----------- ")

    # Network architecture
    self.specF['N_sensory'] = spec.get('N_sensory', 512)  # number of recurrent neurons per sensory network (SN)
    self.specF['N_sensory_pools'] = spec.get('N_sensory_pools', 8)  # number of ring-like SN
    self.specF['N_random'] = spec.get('N_random', 1024)  # number of neurons in the random network (RN)
    self.specF['self_excitation'] = spec.get('self_excitation', False)  # must be True to run the SN by itself
    self.specF['with_random_network'] = spec.get('with_random_network', True)  # with the RN, self excitation is False, so a SN alone cannot maintain a memory
    self.specF['fI_slope'] = spec.get('fI_slope', 0.4)  # slope of the non-linear f-I function

    # Parameters describing recurrence within the SN
    self.specF['RecWNegativeWidth'] = spec.get('RecWNegativeWidth', 0.25)  # width of negative surround suppression
    self.specF['RecWPositiveWidth'] = spec.get('RecWPositiveWidth', 1)  # width of positive amplification
    self.specF['RecWAmp_exc'] = spec.get('RecWAmp_exc', 2)  # amplitude of weight function
    self.specF['RecWAmp_inh'] = spec.get('RecWAmp_inh', 2)  # amplitude of weight function
    self.specF['RecWBaseline'] = spec.get('RecWBaseline', 0.28)  # baseline of the weight matrix of the recurrent network

    # Parameters describing the weights between SN and RN
    self.specF['eibalance_ff'] = spec.get('eibalance_ff', -1.)  # if -1, perfect feed-forward balance from SN to RN
    self.specF['eibalance_fb'] = spec.get('eibalance_fb', -1.)  # if -1, perfect feedback balance from RN to SN
    self.specF['RecToRndW_TargetFR'] = spec.get('RecToRndW_TargetFR', 2.1)  # alpha in the paper: used to compute the feed-forward weight, before balancing
    self.specF['RndToRecW_TargetFR'] = spec.get('RndToRecW_TargetFR', 0.2)  # beta in the paper: used to compute the feedback weight, before balancing
    self.specF['RndRec_f'] = spec.get('RndRec_f', 0.35)  # connectivity (gamma in the paper)
    self.specF['factor'] = spec.get('factor', 1000)  # factor for computing weight values (see Methods of the paper)

    # Saving / reusing a pre-saved network
    self.specF['same_network_to_use'] = spec.get('same_network_to_use', False)  # initialise the network with previously saved weights
    self.specF['create_a_specific_network'] = spec.get('create_a_specific_network', False)  # save the weights so the same network can be reused later
    self.specF['path_for_same_network'] = spec.get('path_for_same_network', self.specF['name_simu']+'/network')  # path of the saved weights

    # Stimulation
    self.specF['specific_load'] = spec.get('specific_load', False)  # use the same load for all trials instead of a random one
    self.specF['value_of_specific_load'] = spec.get('value_of_specific_load', 1)  # value of the load when specific_load is True
    self.specF['start_stimulation'] = spec.get('start_stimulation', 0.1)
    self.specF['end_stimulation'] = spec.get('end_stimulation', 0.2)
    self.specF['input_strength'] = spec.get('input_strength', 10)  # strength of the stimulation
    self.specF['N_sensory_inputwidth'] = spec.get('N_sensory_inputwidth', 32)
    self.specF['InputWidthFactor'] = spec.get('InputWidthFactor', 3)
    # Standard deviation (in neurons) of the gaussian input profile; always derived.
    self.specF['InputWidth'] = round(self.specF['N_sensory']/float(self.specF['N_sensory_inputwidth']))

    # ML decoding
    self.specF['decode_spikes_timestep'] = spec.get('decode_spikes_timestep', 0.1)  # decoding window
    self.specF['path_load_matrix_tuning_sensory'] = spec.get('path_load_matrix_tuning_sensory', self.specF['name_simu']+'/Matrix_tuning.npz')  # path of the tuning curve matrix
    self.specF['compute_tuning_curve'] = spec.get('compute_tuning_curve', True)

    # Output folder and optional raster plot
    self.specF['plot_raster'] = spec.get('plot_raster', True)  # plot a raster
    self.specF['path_to_save'] = self.define_path_to_save_results_from_the_trial()
    self.specF['path_sim'] = self.specF['path_to_save']+'simulation_results.npz'


def define_path_to_save_results_from_the_trial(self):
    """Create the results folder ``<name_simu>/`` if needed and return its path.

    The returned string always ends with ``'/'`` so file names can simply be
    appended to it.
    """
    path_to_save = self.specF['name_simu']+'/'
    try:
        # exist_ok avoids the check-then-create race when several processes
        # start with the same name_simu.
        os.makedirs(path_to_save, exist_ok=True)
    except OSError as error:
        # Stay best-effort (no crash here), but do not hide the problem:
        # saving the results later will fail if the folder is missing.
        print(" ----------- WARNING : could not create folder "+path_to_save+" ("+str(error)+") ----------- ")
    return path_to_save


def apply_matFuncMask(self, m, target, mask, axis_ziou):
    """Fill the masked entries of ``m`` with normalised weights, in place.

    For every ``(i, j)`` where ``mask[i, j]`` is true, ``m[i, j]`` becomes
    ``factor * target / count`` where ``count`` is ``sum(mask[:, j])`` when
    ``axis_ziou`` is true and ``sum(mask[i, :])`` otherwise. Entries outside
    the mask are left untouched.

    Parameters:
        m: 2-D weight matrix, modified in place.
        target: value divided among the connections of each column/row.
        mask: 2-D connectivity mask indexed like ``m``.
        axis_ziou: normalise per column (true) or per row (false).

    Returns:
        ``m`` itself.
    """
    n_rows, n_cols = m.shape[0], m.shape[1]
    mask = numpy.asarray(mask)
    # Only positions that exist in m are written, as in an index loop over m.shape.
    rows, cols = numpy.nonzero(mask[:n_rows, :n_cols])
    scaled_target = float(self.specF['factor'])*target
    # The counts are computed once instead of once per matrix element.
    if axis_ziou:
        counts = numpy.sum(mask, axis=0)[cols]
    else:
        counts = numpy.sum(mask, axis=1)[rows]
    m[rows, cols] = scaled_target/counts
    return m
def compute_activity_vector(self, R_rn_timed):
    """Compute the population (activity) vector of every sensory network.

    Neuron ``k`` (0-based) of a pool is given the angle ``2*pi*(k+1)/N_sensory``;
    the rates of the pool weight the corresponding unit vectors and the mean
    of the result is the population vector of that pool.

    Parameters:
        R_rn_timed: 1-D array of rates at one time point, pools concatenated
            (pool ``p`` occupies ``[p*N_sensory, (p+1)*N_sensory)``).

    Returns:
        ``(Matrix_abs_timed, Matrix_angle_timed)``, two arrays of length
        ``N_sensory_pools``: the modulus of each population vector and its
        angle converted to neuron-index units and wrapped to be non-negative.
    """
    n_neurons = self.specF['N_sensory']
    n_pools = self.specF['N_sensory_pools']
    neuron_angs = numpy.arange(1, n_neurons+1, 1)/float(n_neurons)*2*math.pi
    exp_neuron_angs = numpy.exp(1j*neuron_angs)
    # One row per pool; any trailing entry beyond the pools is ignored.
    rates_per_pool = numpy.asarray(R_rn_timed)[:n_pools*n_neurons].reshape(n_pools, n_neurons)
    population_vectors = numpy.mean(rates_per_pool*exp_neuron_angs, axis=1)
    Matrix_abs_timed = numpy.absolute(population_vectors)
    Matrix_angle_timed = numpy.angle(population_vectors)*n_neurons/(2*math.pi)
    # numpy.angle returns values in (-pi, pi]; shift negative ones onto the ring.
    Matrix_angle_timed[Matrix_angle_timed < 0] += n_neurons
    return Matrix_abs_timed, Matrix_angle_timed


def compute_drift(self, Angle_output, Angle_input):
    """Return the signed circular difference output - input, in radians.

    Both arguments are positions on the ring in neuron-index units
    (``N_sensory`` units correspond to a full turn); scalars and numpy arrays
    are accepted. The result is wrapped around the circle, so its magnitude
    never exceeds pi.
    """
    Angle_output_rad = Angle_output*2*math.pi/float(self.specF['N_sensory'])
    Angle_input_rad = Angle_input*2*math.pi/float(self.specF['N_sensory'])
    difference_angle = ((Angle_output_rad-Angle_input_rad+math.pi) % (2*math.pi))-math.pi
    # Second wrap through the complex plane: numerically redundant with the
    # modulo above, kept so results match previously saved runs exactly.
    difference_complex = numpy.exp(1j*difference_angle, dtype=complex)
    return numpy.angle(difference_complex)


def ml_decode(self, number_of_spikes_rn, Matrix_tc):
    """Decode the stimulus that best explains the spike counts of one pool.

    The score of stimulus ``s`` is ``sum_i n_i * log(Matrix_tc[i, s])`` and the
    index of the highest score is returned (first one in case of ties).

    Parameters:
        number_of_spikes_rn: spike counts, one per neuron of the pool.
        Matrix_tc: tuning matrix indexed ``[neuron, stimulus]``; only the first
            ``N_sensory`` stimulus columns are considered.

    Returns:
        Index of the decoded stimulus (numpy integer).
    """
    n_stimuli = self.specF['N_sensory']
    log_tuning = numpy.log(numpy.asarray(Matrix_tc)[:, :n_stimuli])
    # One matrix product replaces the per-stimulus loop of dot products.
    Matrix_likelihood_per_stim = numpy.dot(number_of_spikes_rn, log_tuning)
    return numpy.argmax(Matrix_likelihood_per_stim)
def give_input_stimulus(self, current_InputCenter):
    """Return the external input vector of one sensory network for one stimulus.

    The input is a gaussian bump of standard deviation ``InputWidth`` centred
    on ``current_InputCenter``, truncated at ``InputWidthFactor`` standard
    deviations, scaled so its peak equals ``input_strength`` and wrapped
    around the ring. The mean over all neurons is then subtracted, so the
    returned vector sums to zero.

    Stimulus positions are 1-based: centre ``c`` peaks at array index
    ``(c - 1) % N_sensory``.

    Parameters:
        current_InputCenter: stimulus position in neuron-index units.

    Returns:
        1-D array of length ``N_sensory``.
    """
    n_neurons = self.specF['N_sensory']
    sigma = self.specF['InputWidth']
    half_window = self.specF['InputWidthFactor']*sigma
    first = int(round(current_InputCenter-half_window))
    last = int(round(current_InputCenter+half_window+1))
    positions = numpy.arange(first, last, 1)

    # Gaussian profile normalised to a peak of 1 before scaling.
    profile = scipy.stats.norm.pdf(positions-current_InputCenter, 0, sigma)
    profile /= float(numpy.amax(profile))

    inp_vect = numpy.zeros(n_neurons)
    # The -1 converts 1-based stimulus positions to 0-based array indices;
    # the remainder wraps the bump around the ring.
    inp_vect[numpy.remainder(positions-1, n_neurons)] = self.specF['input_strength']*profile
    # Remove the mean so the stimulus does not change the total drive.
    inp_vect -= numpy.sum(inp_vect)/float(inp_vect.shape[0])
    return inp_vect


def compute_matrix_tuning_sensory(self):
    """Compute, save and return the tuning matrix used for ML decoding.

    ``Matrix_tuning[i, s]`` is the response of neuron ``i`` to stimulus ``s``,
    obtained by passing the stimulus input ``give_input_stimulus(s)`` through
    ``0.4 * (1 + tanh(fI_slope * input - 3)) / RecSynapseTau``.

    Side effect: the matrix is written with ``numpy.savez_compressed`` to
    ``path_load_matrix_tuning_sensory`` under the key ``Matrix_tuning``.

    Returns:
        Array of shape ``(N_sensory, N_sensory)`` indexed ``[neuron, stimulus]``.
    """
    n_neurons = self.specF['N_sensory']
    Matrix_tuning = numpy.zeros((n_neurons, n_neurons))  # neurons of a pool x possible stimuli
    for index_stimulus in range(n_neurons):
        Vector_stim = self.give_input_stimulus(index_stimulus)
        # Whole column at once instead of one math.tanh call per neuron.
        Matrix_tuning[:, index_stimulus] = 0.4*(1+numpy.tanh(self.specF['fI_slope']*Vector_stim-3))/self.specF['RecSynapseTau']
    numpy.savez_compressed(self.specF['path_load_matrix_tuning_sensory'], Matrix_tuning=Matrix_tuning)
    return Matrix_tuning
numpy.floor(numpy.random.rand(self.specF['N_sensory_pools'])*self.specF['N_sensory']) 339 | 340 | # For now we implement the input at the same time for all pools 341 | IT1 = self.specF['start_stimulation']*second 342 | IT2 = self.specF['end_stimulation']*second 343 | 344 | Matrix_pools_receiving_inputs = [] 345 | if self.specF['specific_load'] : 346 | load = self.specF['value_of_specific_load'] 347 | else : 348 | load = numpy.random.randint(low=1,high=self.specF['N_sensory_pools']+1) 349 | Matrix_pools = numpy.arange(self.specF['N_sensory_pools']) 350 | numpy.random.shuffle(Matrix_pools) 351 | Matrix_pools_receiving_inputs = Matrix_pools[:load] 352 | print("Load for this trial is "+str(load)) 353 | 354 | # We build inp_vect, a matrix which gives the stimulus input to each SN 355 | inp_vect = numpy.zeros((self.specF['N_sensory_pools'],self.specF['N_sensory'])) 356 | for index_pool in Matrix_pools_receiving_inputs : 357 | inp_vect[index_pool,:] = self.give_input_stimulus(InputCenter[index_pool]) 358 | Matrix_initial_input[index_simulation,index_pool] = InputCenter[index_pool] 359 | 360 | 361 | # --------------------------------- Running and recording --------------------------------------------------------- 362 | 363 | 364 | # Running 365 | print("Running...") 366 | for index_pool in range(self.specF['N_sensory_pools']) : 367 | Recurrent_Pools[self.specF['N_sensory']*index_pool:self.specF['N_sensory']*(index_pool+1)].S_ext = inp_baseline[index_pool] 368 | if self.specF['with_random_network'] : 369 | RCN_pool.S_ext_rnd = inp_baseline_rnd 370 | run(IT1) 371 | 372 | for index_pool in range(self.specF['N_sensory_pools']) : 373 | if index_pool in Matrix_pools_receiving_inputs : 374 | Recurrent_Pools[self.specF['N_sensory']*index_pool:self.specF['N_sensory']*(index_pool+1)].S_ext = inp_baseline[index_pool] + inp_vect[index_pool] 375 | else : 376 | Recurrent_Pools[self.specF['N_sensory']*index_pool:self.specF['N_sensory']*(index_pool+1)].S_ext = inp_baseline[index_pool] 377 
| if self.specF['with_random_network'] : 378 | RCN_pool.S_ext_rnd = inp_baseline_rnd 379 | run(IT2-IT1) 380 | 381 | for index_pool in range(self.specF['N_sensory_pools']) : 382 | Recurrent_Pools[self.specF['N_sensory']*index_pool:self.specF['N_sensory']*(index_pool+1)].S_ext = inp_baseline[index_pool] 383 | if self.specF['with_random_network'] : 384 | RCN_pool.S_ext_rnd = inp_baseline_rnd 385 | run(Simtime-IT2) 386 | 387 | 388 | if self.specF['plot_raster'] : 389 | print("Plot of a raster into the folder ..") 390 | if self.specF['with_random_network'] : 391 | figure() 392 | subplot(211) 393 | plot_raster(S_rcn.i, S_rcn.t, time_unit=second, marker=',', color='k') 394 | ylabel('Random neurons') 395 | subplot(212) 396 | plot_raster(S_rn.i, S_rn.t, time_unit=second, marker=',', color='k') 397 | for index_sn in range(self.specF['N_sensory_pools']+1) : 398 | plot(numpy.arange(0,self.specF['simtime']+self.specF['window_save_data']/2.,self.specF['window_save_data']),index_sn*self.specF['N_sensory']*numpy.ones(int(round(self.specF['simtime']/self.specF['window_save_data']))+1),color='b', linewidth=0.5) 399 | ylabel('Sensory neurons') 400 | savefig(self.specF['path_to_save']+'rasterplot_trial'+str(index_simulation)+'.png') 401 | close() 402 | else : 403 | figure() 404 | plot_raster(S_rn.i, S_rn.t, time_unit=second, marker=',', color='k') 405 | for index_sn in range(self.specF['N_sensory_pools']+1) : 406 | plot(numpy.arange(0,self.specF['simtime']+self.specF['window_save_data']/2.,self.specF['window_save_data']),index_sn*self.specF['N_sensory']*numpy.ones(int(round(self.specF['simtime']/self.specF['window_save_data']))+1),color='b', linewidth=0.5) 407 | ylabel('Sensory neurons') 408 | savefig(self.specF['path_to_save']+'rasterplot_trial'+str(index_simulation)+'.png') 409 | close() 410 | 411 | 412 | # RESULTS 413 | End_of_delay = self.specF['simtime']-self.specF['window_save_data'] 414 | Time_chosen = int(round(End_of_delay/self.specF['window_save_data'])) 415 | R_rn_enddelay 
= numpy.transpose(R_rn.rate_rec[:,Time_chosen].copy()) # beware numpy.transpose is reference 416 | Matrix_abs, Matrix_angle = self.compute_activity_vector(R_rn_enddelay) 417 | 418 | Matrix_abs_all[index_simulation,:] = Matrix_abs 419 | Matrix_angle_all[index_simulation,:] = Matrix_angle 420 | if self.specF['compute_tuning_curve'] : 421 | Matrix_tuning = self.compute_matrix_tuning_sensory() 422 | print("Tuning curve is saved in the folder, next time you can include 'compute_tuning_curve':False to reuse it") 423 | 424 | 425 | print("\n---------- ML decoding at the end of trial "+str(index_simulation)+' ----------') 426 | tuning_data = numpy.load(self.specF['path_load_matrix_tuning_sensory']) 427 | Matrix_tuning = tuning_data['Matrix_tuning'] 428 | tuning_data.close() 429 | 430 | # S_rn.t is in second, we need to get rid of the unit time in order to compare it to timestep in the function compute_psth 431 | time_matrix = numpy.zeros(S_rn.t.shape[0]) 432 | for index in range(S_rn.t.shape[0]) : 433 | time_matrix[index] = S_rn.t[index] 434 | 435 | psth_rn = self.compute_psth_for_mldecoding(time_matrix,S_rn.i,End_of_delay,self.specF['decode_spikes_timestep'])[:,int(round(End_of_delay/self.specF['decode_spikes_timestep']))-1] 436 | for index_pool in range(self.specF['N_sensory_pools']) : 437 | Results_ml_spikes[index_simulation,index_pool] = self.ml_decode(psth_rn[index_pool*self.specF['N_sensory']:(index_pool+1)*self.specF['N_sensory']],Matrix_tuning) 438 | if numpy.isnan(InputCenter[index_pool])==False : # or index_pool in Matrix_pools_receiving_inputs 439 | Drift_from_ml_spikes[index_simulation,index_pool] = self.compute_drift(Results_ml_spikes[index_simulation,index_pool],InputCenter[index_pool]) 440 | 441 | print('Initial inputs were (nan means no initial input into this SN)') 442 | print(Matrix_initial_input[index_simulation,:]) 443 | print("Memory is found ") 444 | print(Matrix_abs>activity_threshold) 445 | print("ML decoding gives") 446 | 
print(Results_ml_spikes[index_simulation,:]) 447 | print("Decoding from the activity vector gives") 448 | print(Matrix_angle) 449 | print('\n') 450 | 451 | print("---------- Computing capacity, maintained and spurious memories ----------") 452 | 453 | proba_maintained = 0 454 | proba_spurious = 0 455 | for index_pool_capacity in range(self.specF['N_sensory_pools']) : 456 | if Matrix_abs[index_pool_capacity]>activity_threshold : 457 | if index_pool_capacity in Matrix_pools_receiving_inputs : 458 | proba_maintained+=1 459 | else : 460 | proba_spurious+=1 461 | print(str(proba_maintained)+' maintained memories') 462 | print(str(Matrix_pools_receiving_inputs.shape[0]-proba_maintained)+' forgotten memories') 463 | print(str(proba_spurious)+' spurious memories\n') 464 | proba_maintained/=float(load) 465 | if load!=self.specF['N_sensory_pools'] : 466 | proba_spurious/=float(self.specF['N_sensory_pools']-load) 467 | 468 | Matrix_all_results[index_simulation,:] = numpy.asarray([proba_maintained,proba_spurious]) 469 | 470 | # saving 471 | numpy.savez_compressed(self.specF['path_sim'], Matrix_all_results=Matrix_all_results,Matrix_abs_all=Matrix_abs_all, Matrix_angle_all=Matrix_angle_all, Results_ml_spikes=Results_ml_spikes, Drift_from_ml_spikes=Drift_from_ml_spikes,Matrix_initial_input = Matrix_initial_input) 472 | print("All results saved in the folder") 473 | gcPython.collect() 474 | 475 | return 476 | 477 | 478 | 479 | 480 | 481 | --------------------------------------------------------------------------------