class TestTempotron(unittest.TestCase):
    """
    Test the functionality of the Tempotron class.

    All spike patterns are plain nested lists: afferents fire a different
    number of spikes, and recent NumPy versions refuse to build an array
    from such ragged input.
    """

    def setUp(self):
        # Seed the global generator so the tests that draw random
        # efficacies are reproducible from run to run.
        np.random.seed(0)
        # Unit efficacies by default; individual tests overwrite them.
        self.tempotron = Tempotron(0, 10, 2.5, np.ones(10))

    def test_normalisation_tempotron(self):
        """
        Test if the tempotron normalisation is computed correctly.
        """
        # Floating point result: compare with a tolerance, not exactly.
        self.assertAlmostEqual(self.tempotron.V_norm, 2.116534735957599)

    def test_membrane_potential0(self):
        """
        Test for tempotron.compute_membrane_potential
        without any input spikes.
        """
        self.tempotron.efficacies = np.random.random(10) - 0.5
        spike_times = [[] for _ in range(10)]
        potential = self.tempotron.compute_membrane_potential(10, spike_times)
        self.assertEqual(potential, 0.0)

    def test_membrane_potential1(self):
        """
        Test 2 for tempotron.compute_membrane_potential
        for non empty spike_times.
        """
        self.tempotron.efficacies = np.random.random(10)
        spike_times = [[0], [0], [0], [], [], [], [], [], [], []]

        # A unitary PSP reaches its normalised amplitude of 1 exactly at
        # tau * tau_s * log(tau / tau_s) / (tau - tau_s) after the spike,
        # so sample there instead of at a rounded value of that time.
        t_peak = 10 * 2.5 * np.log(10 / 2.5) / (10 - 2.5)
        potential = self.tempotron.compute_membrane_potential(t_peak, spike_times)
        expected = self.tempotron.efficacies[0:3].sum()
        self.assertAlmostEqual(potential, expected)

    def test_spike_contributions1(self):
        """
        Test 1 for tempotron.compute_spike_contributions
        Every neuron spikes once at a different time.
        """
        spike_times = [[0], [10], [20], [30], [40]]
        contribs = self.tempotron.compute_spike_contributions(40, spike_times)
        expected = np.array([0.03877, 0.1054, 0.2857, 0.7399, 0.0])
        self.assertTrue(np.allclose(contribs, expected, atol=1e-4))

    def test_spike_contributions2(self):
        """
        Test 2 for tempotron.compute_spike_contributions
        One neuron spikes twice.
        """
        spike_times = [[0, 10], [], [], [], []]
        contribs = self.tempotron.compute_spike_contributions(20, spike_times)
        expected = np.array([1.025596, 0.0, 0.0, 0.0, 0.0])
        self.assertTrue(np.allclose(contribs, expected, atol=1e-4))

    def test_compute_tmax1(self):
        """
        Test for tempotron.compute_tmax
        Only one spike at t=10.
        """
        self.tempotron.efficacies = np.random.random(10)
        spike_times = [[10], [], [], [], [], [], [], [], [], []]

        # maximum occurs 4.6 ms after spike
        self.assertAlmostEqual(self.tempotron.compute_tmax(spike_times),
                               14.62098120373297, places=7)

    def test_compute_tmax2(self):
        """
        Test for tempotron.compute_tmax
        Multiple spikes at a single time.
        """
        self.tempotron.efficacies = np.random.random(10)
        spike_times = [[10], [10], [10], [10], [], [], [], [], [], []]

        # maximum occurs 4.6 ms after spike
        self.assertAlmostEqual(self.tempotron.compute_tmax(spike_times),
                               14.62098120373297, places=7)

    def test_compute_tmax3(self):
        """
        Test for tempotron.compute_tmax
        Boundary case with non-existing derivative.
        """
        # The expected value below was derived for tau_s = 3, which is what
        # `15/4` evaluated to under Python 2 integer division. Floor
        # division keeps that value on every Python version.
        tempotron1 = Tempotron(0, 15, 15 // 4, np.array([2, -1]))
        spike_times = [[np.log(2) * 15, 100], [np.log(3) * 15, 101]]

        self.assertAlmostEqual(tempotron1.compute_tmax(spike_times), 16.43259988)


if __name__ == '__main__':
    unittest.main()
class Tempotron(object):
    """
    A class representing a tempotron, as described in
    Gutig & Sompolinsky (2006).
    The (subthreshold) membrane voltage of the tempotron
    is a weighted sum from all incoming spikes and the
    resting potential of the neuron. The contribution of
    each spike decays exponentially with time, the speed of
    this decay is determined by two parameters tau and tau_s,
    denoting the decay time constants of membrane integration
    and synaptic currents, respectively.

    Spike patterns are passed around as a sequence with, at position i,
    the spike times of the ith afferent. Plain nested lists work and are
    preferred, because afferents generally fire a different number of
    spikes.
    """

    def __init__(self, V_rest, tau, tau_s, synaptic_efficacies, threshold=1.0):
        """
        :param V_rest: resting potential of the neuron
        :param tau: decay time constant of membrane integration
        :param tau_s: decay time constant of synaptic currents
        :param synaptic_efficacies: one weight per afferent
        :param threshold: firing threshold of the neuron
        :raises ValueError: if tau or tau_s is not positive, or if they
                            are equal (the kernel and its normalisation
                            are undefined in that case)
        """
        # set parameters as attributes
        self.V_rest = V_rest
        self.tau = float(tau)
        self.tau_s = float(tau_s)
        if self.tau <= 0 or self.tau_s <= 0 or self.tau == self.tau_s:
            raise ValueError("tau and tau_s must be positive and different")
        self.log_tts = np.log(self.tau / self.tau_s)
        self.threshold = threshold
        # Keep a private float copy: training updates the weights, which
        # must neither modify the caller's array nor fail on an integer
        # array such as np.array([2, -1]).
        self.efficacies = np.array(synaptic_efficacies, dtype=float)
        self.t_spi = 10  # spike integration time, compute this with formula

        # compute normalisation factor V_0 from the float time constants,
        # so integer arguments cannot trigger integer division
        self.V_norm = self.compute_norm_factor(self.tau, self.tau_s)

    def compute_norm_factor(self, tau, tau_s):
        """
        Compute and return the normalisation factor V_0 of the kernel

            K(t-t_i) = V_0 (exp(-(t-t_i)/tau) - exp(-(t-t_i)/tau_s))

        The unnormalised kernel peaks at

            t_peak = (tau * tau_s * log(tau/tau_s)) / (tau - tau_s)

        after the spike, and V_0 is one over the kernel value at that
        time. With this factor the kernel has amplitude 1, so the unitary
        PSP amplitudes are given by the synaptic efficacies.

        Note that the kernel itself is evaluated through self.K, which
        uses the time constants stored on the instance; the arguments
        only determine the peak time.
        """
        tau = float(tau)
        tau_s = float(tau_s)
        t_peak = (tau * tau_s * np.log(tau / tau_s)) / (tau - tau_s)
        v_peak = self.K(1, t_peak, 0)
        return 1 / v_peak

    def K(self, V_0, t, t_i):
        """
        Compute the function

            K(t-t_i) = V_0 (exp(-(t-t_i)/tau) - exp(-(t-t_i)/tau_s))

        for a spike at time t_i. A spike has no effect before it
        occurred, so the result is 0 for t < t_i.
        """
        if t < t_i:
            return 0.0
        elapsed = t - t_i
        return V_0 * (np.exp(-elapsed / self.tau) - np.exp(-elapsed / self.tau_s))

    def _sorted_spikes(self, spike_times):
        """
        Return all spikes in chronological order as two float arrays:
        the spike times and the efficacy of the afferent of each spike.
        """
        chronological = sorted((time, synapse)
                               for synapse, train in enumerate(spike_times)
                               for time in train)
        times = np.array([time for time, _ in chronological], dtype=float)
        weights = np.array([self.efficacies[synapse] for _, synapse in chronological],
                           dtype=float)
        return times, weights

    def compute_membrane_potential(self, t, spike_times):
        """
        Compute the membrane potential of the neuron given
        by the function:

            V(t) = sum_i w_i sum_{t_i} K(t-t_i) + V_rest

        Where w_i denote the synaptic efficacies and t_i denote
        the spike times of the ith afferent.

        :param t: time at which the potential is evaluated
        :param spike_times: a sequence with at position i the spike
                            times of the ith afferent
        :return: the membrane potential at time t
        """
        # contributions of the spikes for each synapse
        spike_contribs = self.compute_spike_contributions(t, spike_times)

        # weigh with the synaptic efficacies, sum and add V_rest
        total_incoming = spike_contribs * self.efficacies
        return total_incoming.sum() + self.V_rest

    def compute_derivative(self, t, spike_times):
        """
        Compute the derivative of the membrane potential
        of the neuron at time t.
        This derivative is given by:

            V'(t) = V_0 sum_i w_i sum_{t_n} (exp(-(t-t_n)/tau_s)/tau_s - exp(-(t-t_n)/tau)/tau)

        for t_n <= t. At a spike time itself this is the derivative
        from the right, as the spike is already included.
        """
        times, weights = self._sorted_spikes(spike_times)

        # only spikes that already occurred influence the potential
        occurred = times <= t
        sum_tau = (weights[occurred] * np.exp(times[occurred] / self.tau)).sum()
        sum_tau_s = (weights[occurred] * np.exp(times[occurred] / self.tau_s)).sum()

        factor_tau = np.exp(-t / self.tau) / self.tau
        factor_tau_s = np.exp(-t / self.tau_s) / self.tau_s

        return self.V_norm * (factor_tau_s * sum_tau_s - factor_tau * sum_tau)

    def compute_spike_contributions(self, t, spike_times):
        """
        Compute the decayed contribution of the incoming spikes.

        :return: an array with for every afferent the summed (unweighted)
                 kernel values of its spikes at time t
        """
        spike_contribs = np.zeros(len(spike_times))
        for afferent, train in enumerate(spike_times):
            for spike_time in train:
                spike_contribs[afferent] += self.K(self.V_norm, t, spike_time)
        return spike_contribs

    def train(self, io_pairs, steps, learning_rate):
        """
        Train the tempotron on the given input-output pairs,
        applying gradient descent to adapt the weights.

        :param steps: the maximum number of training steps
        :param io_pairs: a list with tuples of spike times and the
                         desired response on them
        :param learning_rate: the learning rate of the gradient descent
        """
        # Run until maximum number of steps is reached or
        # no weight updates occur anymore
        for _ in range(steps):
            updated = False
            # Go through the io-pairs in random order. Only the indices
            # are shuffled: the pairs hold ragged spike patterns that
            # cannot be turned into an array.
            for index in np.random.permutation(len(io_pairs)):
                spike_times, target = io_pairs[index]
                if self.adapt_weights(spike_times, target, learning_rate):
                    updated = True
            if not updated:
                # every pattern is classified correctly, so further
                # passes cannot change the weights anymore
                break

    def get_membrane_potentials(self, t_start, t_end, spike_times, interval=0.1):
        """
        Get the membrane potentials from t_start to t_end
        as a result of the inputted spike times.

        :return: a tuple with the sampled times and the membrane
                 potential at each of them
        """
        t = np.arange(t_start, t_end, interval)
        membrane_potentials = np.array(
            [self.compute_membrane_potential(time, spike_times) for time in t],
            dtype=float)
        return t, membrane_potentials

    def get_derivatives(self, t_start, t_end, spike_times, interval=0.1):
        """
        Get the derivative of the membrane potentials from
        t_start to t_end as a result of the inputted spike times.

        :return: a tuple with the sampled times and the derivative
                 of the membrane potential at each of them
        """
        t = np.arange(t_start, t_end, interval)
        derivatives = np.array(
            [self.compute_derivative(time, spike_times) for time in t],
            dtype=float)
        return t, derivatives

    def plot_membrane_potential(self, t_start, t_end, spike_times, interval=0.1):
        """
        Plot the membrane potential between t_start and t_end as
        a result of the input spike times. The figure is not shown,
        call plt.show() afterwards.

        :param t_start: start time in ms
        :param t_end: end time in ms
        :param interval: time step at which membrane potential is computed
        """
        t, membrane_potentials = self.get_membrane_potentials(
            t_start, t_end, spike_times, interval)

        # format axes
        plt.xlabel('Time (ms)')
        plt.ylabel('V(t)')

        # always keep both +threshold and -threshold in view
        ymax = max(membrane_potentials.max() + 0.1, self.threshold + 0.1)
        ymin = min(membrane_potentials.min() - 0.1, -self.threshold - 0.1)
        plt.ylim(ymin, ymax)
        plt.axhline(y=self.threshold, linestyle='--', color='k')

        # plot membrane potential
        plt.plot(t, membrane_potentials)

    def plot_potential_and_derivative(self, t_start, t_end, spike_times,
                                      interval=0.1, mark_time=16.5):
        """
        Plot the membrane potential and the derivative of the membrane
        potential as a result of the input spikes between t_start and
        t_end, and show the figure.

        :param t_start: start time in ms
        :param t_end: end time in ms
        :param interval: time step at which both curves are computed
        :param mark_time: time at which a vertical marker line is drawn,
                          or None to draw no marker
        """
        t, membrane_potentials = self.get_membrane_potentials(
            t_start, t_end, spike_times, interval)
        t, derivatives = self.get_derivatives(t_start, t_end, spike_times, interval)

        # format axes; no y label as the two curves have different units
        plt.xlabel('Time(ms)')

        ymax = max(membrane_potentials.max() + 0.1, self.threshold + 0.1)
        ymin = min(membrane_potentials.min() - 0.1, -self.threshold - 0.1)
        plt.ylim(ymin, ymax)

        plt.axhline(y=self.threshold, linestyle='--', color='k')
        plt.axhline(y=0.0, linestyle='--', color='r')
        if mark_time is not None:
            plt.axvline(x=mark_time, color='b')

        # plot
        plt.plot(t, membrane_potentials, label='Membrane potential')
        plt.plot(t, derivatives, label='Derivative')
        plt.legend()
        plt.show()

    def compute_tmax(self, spike_times):
        """
        Compute the time at which the membrane potential of the
        tempotron is maximal as a result of the input spikes.

        Between two consecutive spikes the potential is smooth and its
        stationary point follows analytically from the spikes so far:

            t = (log(tau/tau_s) + log(sum w_n exp(t_n/tau_s)) - log(sum w_n exp(t_n/tau)))*tau_s*tau/ (tau-tau_s)

        for n = 1, 2, ..., up to every spike in turn. The potential has
        a kink at each spike, so a maximum can also lie exactly on a
        spike time (an inhibitory spike arriving while the potential is
        still rising). Both kinds of candidates are collected, the real
        membrane potential is evaluated at each of them and the time of
        the highest value is returned.

        :return: the time of the maximal membrane potential, or 0.0 if
                 the pattern contains no spikes at all (the potential
                 is then equal to V_rest everywhere)
        """
        times, weights = self._sorted_spikes(spike_times)
        if times.size == 0:
            return 0.0

        # Running sums over the spikes so far. Their ratio can be zero,
        # negative, infinite or NaN; those entries have no stationary
        # point and are masked out, so the warnings are silenced here.
        with np.errstate(over='ignore', divide='ignore', invalid='ignore'):
            sum_tau = (weights * np.exp(times / self.tau)).cumsum()
            sum_tau_s = (weights * np.exp(times / self.tau_s)).cumsum()
            ratio = sum_tau_s / sum_tau
            has_stationary_point = np.isfinite(ratio) & (ratio > 0)

        # placeholder 1.0 keeps the logarithm defined for masked entries
        safe_ratio = np.where(has_stationary_point, ratio, 1.0)
        stationary = (self.tau * self.tau_s * (self.log_tts + np.log(safe_ratio))
                      / (self.tau - self.tau_s))
        # without a stationary point fall back to the spike time itself
        stationary = np.where(has_stationary_point, stationary, times)

        # stationary points first, so they win ties against spike times
        candidates = np.concatenate((stationary, times))
        potentials = np.array([self.compute_membrane_potential(candidate, spike_times)
                               for candidate in candidates])
        return candidates[potentials.argmax()]

    def adapt_weights(self, spike_times, target, learning_rate):
        """
        Modify the synaptic efficacies such that the tempotron learns
        to classify the input pattern correctly.
        Whenever an error occurs, the following update is
        computed:

            dw = lambda sum_{ti} K(t_max, ti)

        The synaptic efficacies are increased by this update
        if the tempotron did erroneously not elicit an output
        spike, and decreased if it erroneously did.

        :param spike_times: a sequence with the spike times
                            of every afferent
        :param target: the desired classification of the input pattern,
                       true if the tempotron should spike on it
        :param learning_rate: the learning rate of the update
        :return: True if the efficacies were changed, False if the
                 pattern was already classified correctly
        """
        tmax = self.compute_tmax(spike_times)
        vmax = self.compute_membrane_potential(tmax, spike_times)

        # if the output is already correct, don't adapt weights
        if bool(vmax >= self.threshold) == bool(target):
            return False

        update = self.dw(learning_rate, tmax, spike_times)

        # Rebind instead of updating in place, so an integer array that
        # was assigned to the attribute directly is handled as well.
        if target:
            self.efficacies = self.efficacies + update
        else:
            self.efficacies = self.efficacies - update
        return True

    def dw(self, learning_rate, tmax, spike_times):
        """
        Compute the update for synaptic efficacies wi,
        according to the following learning rule
        (implementing gradient descent dynamics):

            dwi = lambda sum_{ti} K(t_max, ti)

        where lambda is the learning rate and t_max denotes
        the time at which the postsynaptic potential V(t)
        reached its maximal value.
        """
        # contributions of the individual spikes at time tmax,
        # scaled with the learning rate
        spike_contribs = self.compute_spike_contributions(tmax, spike_times)
        return learning_rate * spike_contribs


if __name__ == '__main__':
    np.random.seed(0)
    efficacies = 1.8 * np.random.random(10) - 0.50
    print('synaptic efficacies: {}\n'.format(efficacies))

    tempotron = Tempotron(0, 10, 2.5, efficacies)
    # nested lists: the afferents fire a different number of spikes
    spike_times1 = [[70, 200, 400], [], [400, 420], [], [110], [230],
                    [240, 260, 340], [380], [300], [105]]
    spike_times2 = [[], [395], [50, 170], [], [70, 280], [], [290], [115],
                    [250, 320], [225, 330]]

    # membrane potential for the second pattern before training
    tempotron.plot_membrane_potential(0, 500, spike_times2)

    tempotron.train([(spike_times1, True), (spike_times2, False)], 300,
                    learning_rate=10e-3)
    print(tempotron.efficacies)

    # and after training, drawn in the same figure
    tempotron.plot_membrane_potential(0, 500, spike_times2)
    plt.show()