├── .DS_Store ├── QTDB_test.mat ├── QTDB_train.mat ├── SHD ├── real_SHD.png └── intro.py ├── model_0.853338_36_New.pth ├── generate_ps_dataset.py ├── srnn_2layer-finalized (1).py ├── s_mnist-gpu.py └── vis_result copy 2.ipynb /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byin-cwi/SRNN-ICONs2020/HEAD/.DS_Store -------------------------------------------------------------------------------- /QTDB_test.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byin-cwi/SRNN-ICONs2020/HEAD/QTDB_test.mat -------------------------------------------------------------------------------- /QTDB_train.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byin-cwi/SRNN-ICONs2020/HEAD/QTDB_train.mat -------------------------------------------------------------------------------- /SHD/real_SHD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byin-cwi/SRNN-ICONs2020/HEAD/SHD/real_SHD.png -------------------------------------------------------------------------------- /model_0.853338_36_New.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byin-cwi/SRNN-ICONs2020/HEAD/model_0.853338_36_New.pth -------------------------------------------------------------------------------- /generate_ps_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | The function used to generate permuted sequencial mnist dataset. 3 | permute = np.random.permutation(784) 4 | """ 5 | 6 | import keras 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | def apply_permutation(data,permuteMatrix): 11 | b,c,r = data.shape 12 | sdata = data.reshape(b,c*r) 13 | new_data = np.zeros_like(sdata) 14 | for i in range(len(sdata)): 15 | tmp = sdata[i] 16 | new_data[i,:] = tmp[permuteMatrix] 17 | return new_data 18 | 19 | (X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data() 20 | permute = np.random.permutation(784) 21 | X_train_ps = apply_permutation(X_train,permute) 22 | X_test_ps = apply_permutation(X_test,permute) 23 | 24 | np.save("ps_X_train.npy",X_train_ps) 25 | np.save("ps_X_test.npy",X_test_ps) 26 | np.save('Y_train.npy',y_train) 27 | np.save('Y_test.npy',y_test) 28 | 29 | print("X_train shape: ",X_train_ps.shape) 30 | print("X_test shape: ",X_test_ps.shape) 31 | print("Y_train shape: ",y_train.shape) 32 | print("Y_test shape: ",y_test.shape) 33 | plt.subplot(131) 34 | plt.imshow(X_train[1].reshape(28,28)) 35 | plt.gca().set_title('original') 36 | 37 | plt.subplot(132) 38 | plt.imshow(X_train_ps[1].reshape(28,28)) 39 | plt.gca().set_title('permuted') 40 | 41 | plt.subplot(133) 42 | plt.imshow(permute.reshape(28,28)) 43 | plt.gca().set_title('permute Matrix') 44 | plt.show() 45 | -------------------------------------------------------------------------------- /SHD/intro.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib.request 3 | import gzip, shutil 4 | # from keras.utils import get_file 5 | import matplotlib.pyplot as plt 6 | """ 7 | The dataset is 48kHZ with 24bits precision 8 | * 700 channels 9 | * longest 1.17s 10 | * shortest 0.316s 11 | """ 12 | 13 | # cache_dir=os.path.expanduser("~/data") 14 | # cache_subdir="hdspikes" 15 | # print("Using 
cache dir: %s"%cache_dir) 16 | # 17 | # # The remote directory with the data files 18 | # base_url = "https://compneuro.net/datasets" 19 | 20 | # Retrieve MD5 hashes from remote 21 | 22 | 23 | 24 | # file_hashes = { line.split()[1]:line.split()[0] for line in lines if len(line.split())==2 } 25 | 26 | files = ['data/shd_test.h5','data/shd_train.h5'] 27 | 28 | import tables 29 | import numpy as np 30 | fileh = tables.open_file(files[0], mode='r') 31 | units = fileh.root.spikes.units 32 | times = fileh.root.spikes.times 33 | labels = fileh.root.labels 34 | 35 | # This is how we access spikes and labels 36 | index = 0 37 | print("Times (ms):", times[index],max(times[index])) 38 | print("Unit IDs:", units[index]) 39 | print("Label:", labels[index]) 40 | 41 | 42 | def binary_image_readout(times,units,dt = 1e-3): 43 | img = [] 44 | N = int(1/dt) 45 | for i in range(N): 46 | idxs = np.argwhere(times<=i*dt).flatten() 47 | vals = units[idxs] 48 | vals = vals[vals > 0] 49 | vector = np.zeros(700) 50 | vector[700-vals] = 1 51 | times = np.delete(times,idxs) 52 | units = np.delete(units,idxs) 53 | img.append(vector) 54 | return np.array(img) 55 | 56 | def binary_image_spatical(times,units,dt = 1e-3,dc = 10): 57 | img = [] 58 | N = int(1/dt) 59 | C = int(700/dc) 60 | for i in range(N): 61 | idxs = np.argwhere(times<=i*dt).flatten() 62 | vals = units[idxs] 63 | vals = vals[vals > 0] 64 | vector = np.zeros(C)# add spacial count 65 | vector[700-vals] = 1 66 | times = np.delete(times,idxs) 67 | units = np.delete(units,idxs) 68 | img.append(vector) 69 | return np.array(img) 70 | 71 | 72 | def generate_dataset(file_name,dt=1e-3): 73 | fileh = tables.open_file(file_name, mode='r') 74 | units = fileh.root.spikes.units 75 | times = fileh.root.spikes.times 76 | labels = fileh.root.labels 77 | 78 | # This is how we access spikes and labels 79 | index = 0 80 | print("Number of samples: ",len(times)) 81 | X = [] 82 | y = [] 83 | for i in range(len(times)): 84 | tmp = binary_image_readout(times[i], units[i],dt=dt) 85 | X.append(tmp) 86 | y.append(labels[i]) 87 | return np.array(X),np.array(y) 88 | 89 | k = 1143 90 | plt.figure(figsize=(10,5)) 91 | plt.scatter(times[k],700-units[k], color="k", alpha=0.33, s=2) 92 | plt.plot(np.ones(700),np.arange(0,700),'r--',linewidth=2) 93 | plt.title("Label %i"%labels[k]) 94 | plt.xlabel('time [s]',fontsize=14) 95 | plt.ylabel('Channel',fontsize=14) 96 | plt.xlim([0,max(times[k])]) 97 | plt.ylim([0,700]) 98 | # plt.axis("off") 99 | plt.show() 100 | 101 | # test_X,testy = generate_dataset(files[0],dt=4e-3) 102 | # np.save('/Volumes/Data/b_yin/SHD/data_coount/testX_4ms.npy',test_X) 103 | # np.save('/Volumes/Data/b_yin/SHD/data_coount/testY_4ms.npy',testy) 104 | 105 | # train_X,trainy = generate_dataset(files[1],dt=4e-3) 106 | # np.save('/Volumes/Data/b_yin/SHD/data_coount/trainX_4ms.npy',train_X) 107 | # np.save('/Volumes/Data/b_yin/SHD/data_coount/trainY_4ms.npy',trainy) 108 | 109 | 110 | # # how many time steps on each sample 111 | l = [] 112 | for i in range(len(times)): 113 | l.append(len(set(times[i]))) 114 | print(max(l),np.argmax(l)) 115 | print(min(l),np.argmin(l)) 116 | # plt.hist(l,bins=20) 117 | # plt.show() 118 | # # the sampling frequence of spoken digits 119 | # l = [] 120 | for i in range(len(times)): 121 | a = np.array(sorted(list(set(times[i])))) 122 | n = len(a) 123 | l.append(min(a[1:]-a[:n-1])) 124 | print(max(l),np.argmax(l)) 125 | print(min(l),np.argmin(l)) 126 | # plt.hist(l) 127 | # plt.show() 128 | # 129 | # # how many spoken digits longer than 1s 130 | # l = 
[] 131 | # ll = [] 132 | # for i in range(len(times)): 133 | # l.append(max(times[i])) 134 | # if max(times[i])>1.: ll.append(i) 135 | # print(max(l),np.argmax(l)) 136 | # print(min(l),np.argmin(l)) 137 | # plt.hist(l,bins=20) 138 | # plt.show() 139 | # 140 | def binary_image_readout(times,units,dt = 1e-3): 141 | img = [] 142 | N = int(1/dt) 143 | for i in range(N): 144 | idxs = np.argwhere(times<=i*dt).flatten() 145 | vals = units[idxs] 146 | vals = vals[vals>0] 147 | vector = np.zeros(700) 148 | vector[700-vals] = 1 149 | times = np.delete(times,idxs) 150 | units = np.delete(units,idxs) 151 | img.append(vector) 152 | return np.array(img) 153 | idx = k 154 | tmp = binary_image_readout(times[idx],units[idx],dt=4e-3) 155 | plt.figure(figsize=(10,5)) 156 | plt.imshow(tmp.T,interpolation='nearest', aspect='auto',origin='lower',extent=[0,1.,0,700]) 157 | plt.show() 158 | 159 | plt.figure(figsize=(10,3)) 160 | plt.plot(np.arange(0,1,4e-3),np.sum(tmp,axis=1)) 161 | plt.xlim([0.,1.]) 162 | plt.ylabel('Spike Count',fontsize=16) 163 | plt.xlabel('Time [s]',fontsize=16) 164 | 165 | # # A quick raster plot for one of the samples 166 | # 167 | # 168 | # fig = plt.figure(figsize=(16,4)) 169 | # idx = ll[:3]#[1979,1358,626]#np.random.randint(len(times),size=3) 170 | # for i,k in enumerate(idx): 171 | # ax = plt.subplot(1,3,i+1) 172 | # ax.scatter(times[k],700-units[k], color="k", alpha=0.33, s=2) 173 | # ax.set_title("Label %i"%labels[k]) 174 | # # ax.axis("off") 175 | # plt.show() 176 | # 177 | # fig = plt.figure(figsize=(16,8)) 178 | # idx = ll[16:22]#[1979,1358,626]#np.random.randint(len(times),size=3) 179 | # for i,k in enumerate(idx): 180 | # ax = plt.subplot(2,3,i+1) 181 | # ax.scatter(times[k],700-units[k], color="k", alpha=0.33, s=2) 182 | # ax.set_title("Label %i"%labels[k]) 183 | # # ax.axis("off") 184 | # 185 | # plt.show() 186 | -------------------------------------------------------------------------------- /srnn_2layer-finalized (1).py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.nn as nn 5 | from torch.optim.lr_scheduler import StepLR 6 | import math 7 | import torch.nn.functional as F 8 | from torch.utils import data 9 | 10 | torch.manual_seed(0) 11 | 12 | train_X = np.load('data/trainX_10ms.npy') 13 | train_y = np.load('data/trainY_10ms.npy').astype(np.float) 14 | 15 | test_X = np.load('data/testX_10ms.npy') 16 | test_y = np.load('data/testY_10ms.npy').astype(np.float) 17 | 18 | print('dataset shape: ', train_X.shape) 19 | print('dataset shape: ', test_X.shape) 20 | 21 | batch_size = 128 22 | 23 | tensor_trainX = torch.Tensor(train_X) # transform to torch tensor 24 | tensor_trainY = torch.Tensor(train_y) 25 | train_dataset = data.TensorDataset(tensor_trainX, tensor_trainY) 26 | train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 27 | tensor_testX = torch.Tensor(test_X) # transform to torch tensor 28 | tensor_testY = torch.Tensor(test_y) 29 | test_dataset = data.TensorDataset(tensor_testX, tensor_testY) 30 | test_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False) 31 | 32 | 33 | 34 | lens = 0.5 # hyper-parameters of approximate function 35 | num_epochs = 50 # 150 # n_iters / (len(train_dataset) / batch_size) 36 | num_epochs = int(num_epochs) 37 | 38 | 39 | 40 | b_j0 = 0.01 # neural threshold baseline 41 | R_m = 1 # membrane resistance 42 | dt = 1 # 43 | gamma = .5 # gradient scale 44 | 45 | 46 | 47 | 
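# --- Added note (illustrative, not part of the original file) ---------------
# The class below implements the spiking non-linearity: a Heaviside step in the
# forward pass and a Gaussian surrogate derivative in the backward pass, so the
# network stays trainable with ordinary backprop. A minimal sketch of the same
# idea, reusing the `lens` and `gamma` constants defined above:
#
#   import math, torch
#   u = torch.tensor([-0.3, 0.0, 0.4])   # membrane potential minus threshold
#   spikes = (u > 0).float()             # forward pass: binary spikes
#   surrogate = gamma * torch.exp(-u**2 / (2 * lens**2)) / (math.sqrt(2 * math.pi) * lens)
#   # backward() scales the incoming gradient by `surrogate` instead of the
#   # true (zero-almost-everywhere) derivative of the step function.
# ----------------------------------------------------------------------------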
class ActFun_adp(torch.autograd.Function): 48 | @staticmethod 49 | def forward(ctx, input): # input = membrane potential- threshold 50 | ctx.save_for_backward(input) 51 | return input.gt(0).float() # is firing ??? 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): # approximate the gradients 55 | input, = ctx.saved_tensors 56 | grad_input = grad_output.clone() 57 | 58 | temp = torch.exp(-(input**2)/(2*lens**2))/torch.sqrt(2*torch.tensor(math.pi))/lens 59 | return grad_input * temp.float() * gamma 60 | 61 | 62 | act_fun_adp = ActFun_adp.apply 63 | 64 | 65 | 66 | def mem_update_adp(inputs, mem, spike, tau_adp, b, tau_m, dt=1, isAdapt=1): 67 | alpha = torch.exp(-1. * dt / tau_m).cuda() 68 | ro = torch.exp(-1. * dt / tau_adp).cuda() 69 | if isAdapt: 70 | beta = 1.8 71 | else: 72 | beta = 0. 73 | 74 | b = ro * b + (1 - ro) * spike 75 | B = b_j0 + beta * b 76 | 77 | mem = mem * alpha + (1 - alpha) * R_m * inputs - B * spike * dt 78 | inputs_ = mem - B 79 | spike = act_fun_adp(inputs_) # act_fun : approximation firing function 80 | return mem, spike, B, b 81 | 82 | 83 | def output_Neuron(inputs, mem, tau_m, dt=1): 84 | """ 85 | The read out neuron is leaky integrator without spike 86 | """ 87 | alpha = torch.exp(-1. * dt / tau_m)#.cuda() 88 | mem = mem * alpha + (1. - alpha) * R_m * inputs 89 | return mem 90 | 91 | 92 | class RNN_custom(nn.Module): 93 | def __init__(self, input_size, hidden_size, output_size): 94 | super(RNN_custom, self).__init__() 95 | 96 | self.hidden_size = hidden_size 97 | # self.hidden_size = input_size 98 | self.i_2_h1 = nn.Linear(input_size, hidden_size[0]) 99 | self.h1_2_h1 = nn.Linear(hidden_size[0], hidden_size[0]) 100 | self.h1_2_h2 = nn.Linear(hidden_size[0], hidden_size[1]) 101 | self.h2_2_h2 = nn.Linear(hidden_size[1], hidden_size[1]) 102 | 103 | self.h2o = nn.Linear(hidden_size[1], output_size) 104 | 105 | self.tau_adp_h1 = nn.Parameter(torch.Tensor(hidden_size[0])) 106 | self.tau_adp_h2 = nn.Parameter(torch.Tensor(hidden_size[1])) 107 | self.tau_adp_o = nn.Parameter(torch.Tensor(output_size)) 108 | self.tau_m_h1 = nn.Parameter(torch.Tensor(hidden_size[0])) 109 | self.tau_m_h2 = nn.Parameter(torch.Tensor(hidden_size[1])) 110 | self.tau_m_o = nn.Parameter(torch.Tensor(output_size)) 111 | 112 | nn.init.orthogonal_(self.h1_2_h1.weight) 113 | nn.init.orthogonal_(self.h2_2_h2.weight) 114 | nn.init.xavier_uniform_(self.i_2_h1.weight) 115 | nn.init.xavier_uniform_(self.h1_2_h2.weight) 116 | nn.init.xavier_uniform_(self.h2_2_h2.weight) 117 | nn.init.xavier_uniform_(self.h2o.weight) 118 | 119 | nn.init.constant_(self.i_2_h1.bias, 0) 120 | nn.init.constant_(self.h1_2_h2.bias, 0) 121 | nn.init.constant_(self.h2_2_h2.bias, 0) 122 | nn.init.constant_(self.h1_2_h1.bias, 0) 123 | 124 | nn.init.constant_(self.tau_adp_h1, 50) 125 | nn.init.constant_(self.tau_adp_h2, 100) 126 | nn.init.constant_(self.tau_adp_o, 100) 127 | nn.init.constant_(self.tau_m_h1, 10.) 128 | nn.init.constant_(self.tau_m_h2, 10.) 129 | nn.init.constant_(self.tau_m_o, 15.) 
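        # Added comment (not in the original source): tau_m_* are per-neuron
        # membrane time constants and tau_adp_* are threshold-adaptation time
        # constants; both are trained. In mem_update_adp they act as decay
        # factors, e.g. with dt = 1, tau_m = 10 gives alpha = exp(-1/10) ≈ 0.905
        # (fast membrane leak), while tau_adp = 100 gives ro = exp(-1/100) ≈ 0.990
        # (slow decay of the adaptive-threshold variable b).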
130 | 131 | 132 | self.b_h1 = self.b_h2 = self.b_o = 0 133 | 134 | def forward(self, input): 135 | batch_size, seq_num, input_dim = input.shape 136 | self.b_h1 = self.b_h2 = self.b_o = b_j0 137 | 138 | mem_layer1 = spike_layer1 = torch.rand(batch_size, self.hidden_size[0]).cuda() 139 | mem_layer2 = spike_layer2 = torch.rand(batch_size, self.hidden_size[1]).cuda() 140 | mem_output = torch.rand(batch_size, output_dim).cuda() 141 | output = torch.zeros(batch_size, output_dim).cuda() 142 | 143 | hidden_spike_ = [] 144 | hidden_mem_ = [] 145 | h2o_mem_ = [] 146 | 147 | for i in range(seq_num): 148 | input_x = input[:, i, :] 149 | 150 | h_input = self.i_2_h1(input_x.float()) + self.h1_2_h1(spike_layer1) 151 | mem_layer1, spike_layer1, theta_h1, self.b_h1 = mem_update_adp(h_input, mem_layer1, spike_layer1, 152 | self.tau_adp_h1, self.b_h1,self.tau_m_h1) 153 | h2_input = self.h1_2_h2(spike_layer1) + self.h2_2_h2(spike_layer2) 154 | mem_layer2, spike_layer2, theta_h2, self.b_h2 = mem_update_adp(h2_input, mem_layer2, spike_layer2, 155 | self.tau_adp_h2, self.b_h2, self.tau_m_h2) 156 | mem_output = output_Neuron(self.h2o(spike_layer2), mem_output, self.tau_m_o) 157 | if i > 0: 158 | output= output + F.softmax(mem_output, dim=1) 159 | 160 | hidden_spike_.append(spike_layer1.data.cpu().numpy()) 161 | hidden_mem_.append(mem_layer1.data.cpu().numpy()) 162 | h2o_mem_.append(mem_output.data.cpu().numpy()) 163 | 164 | return output, hidden_spike_, hidden_mem_, h2o_mem_ 165 | 166 | 167 | ''' 168 | STEP 4: INSTANTIATE MODEL CLASS 169 | ''' 170 | input_dim = 700 171 | hidden_dim = [128,128] # 128 172 | output_dim = 20 173 | seq_dim = 100 # Number of steps to unroll 174 | num_encode = 700 175 | total_steps = seq_dim 176 | 177 | model = RNN_custom(input_dim, hidden_dim, output_dim) 178 | 179 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 180 | print("device:", device) 181 | model.to(device) 182 | criterion = nn.CrossEntropyLoss() 183 | learning_rate = 1e-2 # 1e-2 184 | 185 | 186 | base_params = [model.i_2_h1.weight, model.i_2_h1.bias, 187 | model.h1_2_h1.weight, model.h1_2_h1.bias, 188 | model.h1_2_h2.weight, model.h1_2_h2.bias, 189 | model.h2_2_h2.weight, model.h2_2_h2.bias, 190 | model.h2o.weight, model.h2o.bias] 191 | optimizer = torch.optim.Adam([ 192 | {'params': base_params}, 193 | {'params': model.tau_adp_h1, 'lr': learning_rate * 5}, 194 | {'params': model.tau_adp_h2, 'lr': learning_rate * 5}, 195 | {'params': model.tau_adp_o, 'lr': learning_rate * 5}, 196 | {'params': model.tau_m_h1, 'lr': learning_rate * 2}, 197 | {'params': model.tau_m_h2, 'lr': learning_rate * 2}, 198 | {'params': model.tau_m_o, 'lr': learning_rate * 2}], 199 | lr=learning_rate) 200 | scheduler = StepLR(optimizer, step_size=10, gamma=.5) 201 | 202 | 203 | def train(model, num_epochs=150): 204 | acc = [] 205 | best_accuracy = 80 206 | for epoch in range(num_epochs): 207 | for i, (images, labels) in enumerate(train_loader): 208 | images = images.view(-1, seq_dim, input_dim).requires_grad_().to(device) 209 | labels = labels.long().to(device) 210 | # Clear gradients w.r.t. parameters 211 | optimizer.zero_grad() 212 | # Forward pass to get output/logits 213 | outputs, _,_,_ = model(images) 214 | # Calculate Loss: softmax --> cross entropy loss 215 | loss = criterion(outputs, labels) 216 | # Getting gradients w.r.t. 
parameters 217 | loss.backward() 218 | # Updating parameters 219 | optimizer.step() 220 | scheduler.step() 221 | accuracy = test(model, train_loader) 222 | ts_acc = test(model) 223 | if ts_acc > best_accuracy and accuracy > 80: 224 | torch.save(model, './model/model_' + str(ts_acc) + '-readout-2layer-v1-12Feb[128,128].pth') 225 | best_accuracy = ts_acc 226 | acc.append(accuracy) 227 | print('epoch: ', epoch, '. Loss: ', loss.item(), '. Tr Accuracy: ', accuracy, '. Ts Accuracy: ', ts_acc) 228 | return acc 229 | 230 | 231 | def test(model, dataloader=test_loader): 232 | correct = 0 233 | total = 0 234 | # Iterate through test dataset 235 | for images, labels in dataloader: 236 | images = images.view(-1, seq_dim, input_dim).to(device) 237 | 238 | outputs, _,_,_ = model(images) 239 | _, predicted = torch.max(outputs.data, 1) 240 | total += labels.size(0) 241 | if torch.cuda.is_available(): 242 | correct += (predicted.cpu() == labels.long().cpu()).sum() 243 | else: 244 | correct += (predicted == labels).sum() 245 | 246 | accuracy = 100. * correct.numpy() / total 247 | return accuracy 248 | 249 | 250 | 251 | 252 | ############################### 253 | acc = train(model, num_epochs) 254 | accuracy = test(model) 255 | print(' Accuracy: ', accuracy) 256 | 257 | -------------------------------------------------------------------------------- /s_mnist-gpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | import torchvision.datasets as dsets 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from torch.autograd import Variable 8 | from torch.optim.lr_scheduler import StepLR,MultiStepLR 9 | import math 10 | import keras 11 | from torch.utils import data 12 | import matplotlib.pyplot as plt 13 | from datetime import datetime 14 | 15 | import argparse 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--task", help="choose the task: smnist and psmnist", type=str,default="smnist") 18 | parser.add_argument("--ec_f", help="choose the encode function: rbf, rbf-lc, poisson", type=str,default='rbf') 19 | parser.add_argument("--dc_f", help="choose the decode function: adp-mem, adp-spike, integrator", type=str,default='adp-mem') 20 | parser.add_argument("--batch_size", help="set the batch_size", type=int,default=200) 21 | parser.add_argument("--encoder", help="set the number of encoder", type=int,default=80) 22 | parser.add_argument("--num_epochs", help="set the number of epoch", type=int,default=200) 23 | parser.add_argument("--learning_rate", help="set the learning rate", type=float,default=1e-2) 24 | parser.add_argument("--len", help="set the length of the gaussian", type=float,default=0.5) 25 | parser.add_argument('--network', nargs='+', type=int,default=[256,128]) 26 | 27 | 28 | def load_dataset(task='smnist'): 29 | if task == 'smnist': 30 | (X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data() 31 | elif task == 'psmnist': 32 | X_train = np.load('./ps_data/ps_X_train.npy') 33 | X_test = np.load('./ps_data/ps_X_test.npy') 34 | y_train = np.load('./ps_data/Y_train.npy') 35 | y_test = np.load('./ps_data/Y_test.npy') 36 | else: 37 | print('only two task, -- smnist and psmnist') 38 | return 0 39 | X_train = torch.from_numpy(X_train).float() 40 | X_test = torch.from_numpy(X_test).float() 41 | y_train = torch.from_numpy(y_train).long() 42 | y_test = torch.from_numpy(y_test).long() 43 | train_dataset = data.TensorDataset(X_train,y_train) # 
create train datset 44 | test_dataset = data.TensorDataset(X_test,y_test) # create test datset 45 | 46 | return train_dataset,test_dataset 47 | 48 | ''' 49 | STEP 3a_v2: CREATE Adaptative spike MODEL CLASS 50 | ''' 51 | b_j0 = .1#0.01 # neural threshold baseline 52 | tau_m = 20 # ms membrane potential constant 53 | R_m = 1 # membrane resistance 54 | dt = 1 # 55 | gamma = .5 # gradient scale 56 | lens = 0.5 57 | 58 | def gaussian(x, mu=0., sigma=.5): 59 | return torch.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / torch.sqrt(2 * torch.tensor(math.pi)) / sigma 60 | 61 | def RBF_encode(x,num_neurons=5,eta=1.): 62 | if num_neurons<3: 63 | print('neurons number should be larger than 2') 64 | assert Exception 65 | return 0 66 | else: 67 | if len(x.shape) == 1: 68 | res = torch.zeros([x.shape[0],num_neurons]).cuda() 69 | if len(x.shape) == 2: 70 | res = torch.zeros([x.shape[0],x.shape[1],num_neurons]).cuda() 71 | 72 | # scale = 1./(num_neurons-2) 73 | # mus = [(2*i-2)/2*scale for i in range(num_neurons)] 74 | scale = 1./(num_neurons-2) 75 | mus = [(2*i-2)/2*scale for i in range(num_neurons)] 76 | 77 | sigmas = scale/eta 78 | for i in range(num_neurons): 79 | if len(x.shape) == 1: 80 | res[:,i] = gaussian(x,mu=mus[i],sigma=sigmas) 81 | if len(x.shape) == 2: 82 | res[:,:,i] = gaussian(x,mu=mus[i],sigma=sigmas) 83 | return res 84 | 85 | 86 | class ActFun_adp(torch.autograd.Function): 87 | @staticmethod 88 | def forward(ctx, input): # input = membrane potential- threshold 89 | ctx.save_for_backward(input) 90 | return input.gt(0).float() # is firing ??? 91 | 92 | @staticmethod 93 | def backward(ctx, grad_output): # approximate the gradients 94 | input, = ctx.saved_tensors 95 | grad_input = grad_output.clone() 96 | temp = gaussian(input, mu=0., sigma=lens) 97 | return grad_input * temp.float() * gamma 98 | 99 | 100 | act_fun_adp = ActFun_adp.apply 101 | 102 | 103 | 104 | def mem_update_adp(inputs, mem, spike, tau_adp,tau_m, b, dt=1, isAdapt=1): 105 | # tau_adp = torch.FloatTensor([tau_adp]) 106 | alpha = torch.exp(-1. * dt / tau_m).cuda() 107 | ro = torch.exp(-1. * dt / tau_adp).cuda() 108 | # tau_adp is tau_adaptative which is learnable # add requiregredients 109 | if isAdapt: 110 | beta = 1.8 111 | else: 112 | beta = 0. 113 | 114 | b = ro * b + (1 - ro) * spike 115 | B = b_j0 + beta * b 116 | 117 | mem = mem * alpha + (1 - alpha) * R_m * inputs - B * spike * dt 118 | inputs_ = mem - B 119 | spike = act_fun_adp(inputs_) # act_fun : approximation firing function 120 | return mem, spike, B, b 121 | 122 | def output_Neuron(inputs, mem, tau_m, dt=1): 123 | """ 124 | The read out neuron is leaky integrator without spike 125 | """ 126 | # alpha = torch.exp(-1. * dt / torch.FloatTensor([30.])).cuda() 127 | alpha = torch.exp(-1. * dt / tau_m).cuda() 128 | mem = mem * alpha + (1. 
- alpha) * R_m * inputs 129 | return mem 130 | 131 | ''' 132 | STEP 3b: CREATE MODEL CLASS 133 | ''' 134 | 135 | 136 | class RNN_custom(nn.Module): 137 | def __init__(self, input_size, hidden_dims, output_size, num_encode=30,EC_f='rbf',DC_f='mem'): 138 | super(RNN_custom, self).__init__() 139 | 140 | self.EC_f = EC_f 141 | self.DC_f = DC_f 142 | 143 | self.num_encoder = num_encode 144 | self.hidden_size = hidden_dims[0] 145 | self.num_decoder = hidden_dims[1] 146 | self.i2h = nn.Linear(self.num_encoder, self.hidden_size) 147 | self.h2h = nn.Linear(self.hidden_size, self.hidden_size) 148 | self.h2d = nn.Linear(self.hidden_size, self.num_decoder) 149 | self.d2d = nn.Linear(self.num_decoder, self.num_decoder) 150 | self.d2o = nn.Linear(self.num_decoder, output_size) 151 | 152 | self.tau_adp_h = nn.Parameter(torch.Tensor(self.hidden_size)) 153 | self.tau_adp_d = nn.Parameter(torch.Tensor(self.num_decoder)) 154 | self.tau_adp_o = nn.Parameter(torch.Tensor(output_size)) 155 | self.tau_m_h = nn.Parameter(torch.Tensor(self.hidden_size)) 156 | self.tau_m_d = nn.Parameter(torch.Tensor(self.num_decoder)) 157 | self.tau_m_o = nn.Parameter(torch.Tensor(output_size)) 158 | 159 | if self.EC_f == 'rbf-lc': 160 | self.threshold_event = nn.Parameter(torch.tensor(0.2,requires_grad=True)) 161 | 162 | 163 | nn.init.orthogonal_(self.h2h.weight) 164 | nn.init.xavier_uniform_(self.i2h.weight) 165 | nn.init.xavier_uniform_(self.h2d.weight) 166 | nn.init.xavier_uniform_(self.d2d.weight) 167 | nn.init.xavier_uniform_(self.d2o.weight) 168 | 169 | nn.init.constant_(self.i2h.bias, 0) 170 | nn.init.constant_(self.h2h.bias, 0) 171 | nn.init.constant_(self.h2d.bias, 0) 172 | nn.init.constant_(self.d2d.bias, 0) 173 | nn.init.constant_(self.d2o.bias, 0) 174 | 175 | nn.init.normal_(self.tau_adp_h, 700,25) 176 | nn.init.normal_(self.tau_adp_o, 700,25) 177 | nn.init.normal_(self.tau_adp_d, 700,25) 178 | 179 | #nn.init.normal_(self.tau_m_h, 20,5) 180 | #nn.init.normal_(self.tau_m_o, 100,5) 181 | #nn.init.normal_(self.tau_m_d, 15,5) 182 | 183 | #nn.init.normal_(self.tau_adp_h, 100,25) 184 | #nn.init.normal_(self.tau_adp_o, 300,25) 185 | #nn.init.normal_(self.tau_adp_d, 200,25) 186 | 187 | nn.init.normal_(self.tau_m_h, 20,5) 188 | nn.init.normal_(self.tau_m_o, 100,5) 189 | nn.init.normal_(self.tau_m_d, 15,5) 190 | self.b_h = self.b_o = self.b_d = 0 191 | 192 | def forward(self, input): 193 | batch_size, seq_num, input_dim = input.shape 194 | self.b_h = self.b_o = self.b_d = b_j0 195 | 196 | hidden_mem = hidden_spike = torch.rand(batch_size, self.hidden_size).cuda() 197 | d2o_spike = output_sumspike = d2o_mem = torch.rand(batch_size, output_dim).cuda() 198 | h2d_mem = h2d_spike = torch.rand(batch_size, self.num_decoder).cuda() 199 | 200 | input = input/255. 
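        # Added comment (not in the original source): after scaling the pixel
        # stream to [0, 1], RBF_encode expands every scalar intensity into
        # self.num_encoder Gaussian receptive-field responses with centres
        # mus[i] = (i - 1) / (num_neurons - 2), spread over roughly [0, 1], so a
        # single grey value becomes a small population vector. For example, an
        # intensity of 1.0 drives the unit centred at 1.0 most strongly and the
        # units centred near 0 only weakly. 'rbf-lc' then thresholds these
        # responses into binary events, while 'Poisson' draws Bernoulli spikes
        # with probability equal to the intensity.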
201 | if self.EC_f[:3]=='rbf': 202 | input_RBF = RBF_encode(input.view(batch_size,seq_num).float(),self.num_encoder) 203 | 204 | for i in range(seq_num): 205 | if self.EC_f == 'rbf': 206 | input_x = input_RBF[:,i,:] 207 | elif self.EC_f == 'rbf-lc': 208 | input_x = input_RBF[:,i,:].gt(self.threshold_event).float().to(device) 209 | elif self.EC_f == 'Poisson': 210 | input_pixel_intensity = input[:, i, :] 211 | input_x = torch.rand(self.num_encoder, device='cuda') < input_pixel_intensity 212 | 213 | #################################################################### 214 | h_input = self.i2h(input_x.float()) + self.h2h(hidden_spike) 215 | 216 | hidden_mem, hidden_spike, theta_h, self.b_h = mem_update_adp(h_input,hidden_mem, hidden_spike, self.tau_adp_h, self.tau_m_h,self.b_h) 217 | d_input = self.h2d(hidden_spike) + self.d2d(h2d_spike) 218 | h2d_mem, h2d_spike, theta_d, self.b_d = mem_update_adp(d_input, h2d_mem, h2d_spike, self.tau_adp_d,self.tau_m_d, self.b_d) 219 | 220 | if self.DC_f[:3]=='adp': 221 | d2o_mem, d2o_spike, theta_o, self.b_o = mem_update_adp(self.d2o(h2d_spike),d2o_mem, d2o_spike, self.tau_adp_o, self.tau_m_o, self.b_o) 222 | elif self.DC_f == 'integrator': 223 | d2o_mem = output_Neuron(self.d2o(h2d_spike),d2o_mem, self.tau_m_o) 224 | if i >= 0: 225 | if self.DC_f == 'adp-mem': 226 | output_sumspike = output_sumspike + F.softmax(d2o_mem,dim=1) 227 | elif self.DC_f =='adp-spike': 228 | output_sumspike = output_sumspike + d2o_spike 229 | elif self.DC_f =='integrator': 230 | output_sumspike =output_sumspike+ F.softmax(d2o_mem,dim=1) 231 | 232 | return output_sumspike, hidden_spike 233 | 234 | 235 | 236 | def train(model, num_epochs,train_loader,test_loader,file_name,MyFile): 237 | acc = [] 238 | 239 | best_accuracy = 80 240 | for epoch in range(num_epochs): 241 | for i, (images, labels) in enumerate(train_loader): 242 | images = images.view(-1, seq_dim, input_dim).requires_grad_().to(device) 243 | labels = labels.long().to(device) 244 | # Clear gradients w.r.t. parameters 245 | optimizer.zero_grad() 246 | # Forward pass to get output/logits 247 | outputs, _ = model(images) 248 | # Calculate Loss: softmax --> cross entropy loss 249 | loss = criterion(outputs, labels) 250 | # Getting gradients w.r.t. parameters 251 | loss.backward() 252 | # Updating parameters 253 | optimizer.step() 254 | scheduler.step() 255 | accuracy = test(model, train_loader) 256 | ts_acc = test(model,test_loader) 257 | if ts_acc > best_accuracy and accuracy > 80: 258 | torch.save(model, './model/model_' + str(ts_acc) + '_'+file_name+'-tau_adp.pth') 259 | best_accuracy = ts_acc 260 | acc.append(accuracy) 261 | res_str = 'epoch: '+str(epoch)+' Loss: '+ str(loss.item())+'. Tr Accuracy: '+ str(accuracy)+ '. Ts Accuracy: '+str(ts_acc) 262 | print(res_str) 263 | MyFile.write(res_str) 264 | MyFile.write('\n') 265 | return acc 266 | 267 | 268 | def test(model, dataloader): 269 | correct = 0 270 | total = 0 271 | # Iterate through test dataset 272 | for images, labels in dataloader: 273 | images = images.view(-1, seq_dim, input_dim).to(device) 274 | 275 | outputs, _ = model(images) 276 | _, predicted = torch.max(outputs.data, 1) 277 | total += labels.size(0) 278 | if torch.cuda.is_available(): 279 | correct += (predicted.cpu() == labels.long().cpu()).sum() 280 | else: 281 | correct += (predicted == labels).sum() 282 | 283 | accuracy = 100. 
* correct.numpy() / total 284 | return accuracy 285 | 286 | 287 | def predict(model,test_loader): 288 | # Iterate through test dataset 289 | result = np.zeros(1) 290 | for images, labels in test_loader: 291 | images = images.view(-1, seq_dim, input_dim).to(device) 292 | 293 | outputs, _,_,_ = model(images) 294 | # _, Predicted = torch.max(outputs.data, 1) 295 | # result.append(Predicted.data.cpu().numpy()) 296 | predicted_vec = outputs.data.cpu().numpy() 297 | Predicted = predicted_vec.argmax(axis=1) 298 | result = np.append(result,Predicted) 299 | return np.array(result[1:]).flatten() 300 | 301 | if __name__ == '__main__': 302 | args = parser.parse_args() 303 | 304 | batch_size = args.batch_size 305 | num_epochs = args.num_epochs 306 | task = args.task 307 | EC_f = args.ec_f 308 | DC_f = args.dc_f 309 | num_encode=args.encoder 310 | 311 | train_dataset,test_dataset = load_dataset(task) 312 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True) 313 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False) 314 | 315 | 316 | input_dim = 1 317 | hidden_dims = args.network#[256,128] 318 | output_dim = 10 319 | seq_dim = int(784 / input_dim) # Number of steps to unroll 320 | 321 | model = RNN_custom(input_dim, hidden_dims, output_dim,num_encode=num_encode,EC_f=EC_f,DC_f=DC_f) 322 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 323 | print("device:",device) 324 | model.to(device) 325 | 326 | criterion = nn.CrossEntropyLoss() 327 | learning_rate = args.learning_rate 328 | 329 | #optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 330 | if EC_f == 'rbf-lc': 331 | base_params = [model.i2h.weight, model.i2h.bias, 332 | model.h2h.weight, model.h2h.bias, 333 | model.h2d.weight, model.h2d.bias, 334 | model.d2d.weight, model.d2d.bias, 335 | model.d2o.weight, model.d2o.bias,model.threshold_event] 336 | else: 337 | base_params = [model.i2h.weight, model.i2h.bias, 338 | model.h2h.weight, model.h2h.bias, 339 | model.h2d.weight, model.h2d.bias, 340 | model.d2d.weight, model.d2d.bias, 341 | model.d2o.weight, model.d2o.bias] 342 | 343 | optimizer = torch.optim.Adam([ 344 | {'params': base_params}, 345 | {'params': model.tau_adp_h, 'lr': learning_rate * 2}, 346 | {'params': model.tau_adp_d, 'lr': learning_rate * 3}, 347 | {'params': model.tau_adp_o, 'lr': learning_rate * 2}, 348 | {'params': model.tau_m_h, 'lr': learning_rate * 2}, 349 | {'params': model.tau_m_d, 'lr': learning_rate * 2}, 350 | {'params': model.tau_m_o, 'lr': learning_rate * 2},], 351 | lr=learning_rate) 352 | 353 | 354 | scheduler = StepLR(optimizer, step_size=25, gamma=.75) 355 | scheduler = MultiStepLR(optimizer, milestones=[25,50,100,150],gamma=0.5) 356 | now = datetime.now() 357 | dt_string = now.strftime("%d-%m-%Y %H:%M:%S") 358 | print('Time: ',dt_string) 359 | file_name = 'Task-'+task+'||Time-'+ dt_string+'||EC_f--'+EC_f+'||DC_f--'+DC_f+'||advanced' 360 | MyFile=open('./result_file/'+file_name+'.txt','w') 361 | MyFile.write(file_name) 362 | MyFile.write('\nnetwork: ['+str(hidden_dims[0])+' '+str(hidden_dims[1])+']') 363 | MyFile.write('\nlearning_rate: '+str(learning_rate)) 364 | MyFile.write('\nbatch_size: '+str(batch_size)) 365 | MyFile.write('\n\n =========== Result ======== \n') 366 | acc = train(model, num_epochs,train_loader,test_loader,file_name,MyFile) 367 | accuracy = test(model,test_loader) 368 | print('test Accuracy: ', accuracy) 369 | MyFile.write('test Accuracy: '+ str(accuracy)) 370 | 
MyFile.close() 371 | 372 | ################### 373 | ## Accuracy curve 374 | ################### 375 | if num_epochs > 10: 376 | plt.plot(acc) 377 | plt.title('Learning Curve -- Accuracy') 378 | plt.xlabel('Epoch') 379 | plt.ylabel('Accuracy: %') 380 | plt.show() 381 | 382 | # python s_mnist-gpu.py --task smnist --ec_f rbf --dc_f adp-spike 383 | -------------------------------------------------------------------------------- /vis_result copy 2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 2, 4 | "metadata": { 5 | "language_info": { 6 | "name": "python", 7 | "codemirror_mode": { 8 | "name": "ipython", 9 | "version": 3 10 | }, 11 | "version": "3.6.9-final" 12 | }, 13 | "orig_nbformat": 2, 14 | "file_extension": ".py", 15 | "mimetype": "text/x-python", 16 | "name": "python", 17 | "npconvert_exporter": "python", 18 | "pygments_lexer": "ipython3", 19 | "version": 3, 20 | "kernelspec": { 21 | "name": "python3", 22 | "display_name": "Python 3" 23 | } 24 | }, 25 | "cells": [ 26 | { 27 | "cell_type": "code", 28 | "execution_count": 9, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "%matplotlib inline\n", 33 | "%config InlineBackend.figure_format = 'svg'" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 11, 39 | "metadata": { 40 | "tags": [] 41 | }, 42 | "outputs": [ 43 | { 44 | "output_type": "stream", 45 | "name": "stdout", 46 | "text": [ 47 | "Populating the interactive namespace from numpy and matplotlib\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "%pylab inline" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 13, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "\n", 62 | "import torch\n", 63 | "import torch.nn as nn\n", 64 | "import torchvision.transforms as transforms\n", 65 | "import torchvision.datasets as dsets\n", 66 | "import torch.nn.functional as F\n", 67 | "import numpy as np\n", 68 | "from torch.autograd import Variable\n", 69 | "from torch.optim.lr_scheduler import StepLR,MultiStepLR\n", 70 | "import matplotlib.pyplot as plt\n", 71 | "from sklearn.utils import shuffle\n", 72 | "from torch.utils.data import TensorDataset, DataLoader\n", 73 | "from datetime import datetime\n", 74 | "from sklearn.metrics import confusion_matrix\n", 75 | "import scipy.io\n", 76 | "\n", 77 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 78 | "device = torch.device(\"cpu\")\n", 79 | "\n", 80 | "'''\n", 81 | "STEP 3a_v2: CREATE Adaptative spike MODEL CLASS\n", 82 | "'''\n", 83 | "b_j0 = 0.01 # neural threshold baseline\n", 84 | "tau_m = 20 # ms membrane potential constant\n", 85 | "R_m = 1 # membrane resistance\n", 86 | "dt = 1 #\n", 87 | "gamma = .5 # gradient scale\n", 88 | "lens = 0.5\n", 89 | "\n", 90 | "\n", 91 | "# define approximate firing function\n", 92 | "class ActFun_adp(torch.autograd.Function):\n", 93 | " @staticmethod\n", 94 | " def forward(ctx, input): # input = membrane potential- threshold\n", 95 | " ctx.save_for_backward(input)\n", 96 | " return input.gt(0).float() # is firing\n", 97 | "\n", 98 | " @staticmethod\n", 99 | " def backward(ctx, grad_output): # approximate the gradients\n", 100 | " input, = ctx.saved_tensors\n", 101 | " grad_input = grad_output.clone()\n", 102 | " #temp = abs(input) < lens\n", 103 | " temp = np.exp((-(input) ** 2 / (2 * lens * lens))) / (np.sqrt(2 * np.pi) * lens)\n", 104 | " return gamma * grad_input * temp.float()\n", 105 | "\n", 106 | "\n", 107 | 
"act_fun_adp = ActFun_adp.apply\n", 108 | "# membrane potential update\n", 109 | "\n", 110 | "tau_m = torch.FloatTensor([tau_m])\n", 111 | "\n", 112 | "\n", 113 | "def mem_update_adp(ops, x, mem, spike, tau_adp, b, dt=1, isAdapt=1):\n", 114 | " alpha = torch.exp(-1. * dt / tau_m)\n", 115 | " ro = torch.exp(-1. * dt / tau_adp)\n", 116 | " # tau_adp is tau_adaptative which is learnable # add requiregredients\n", 117 | " if isAdapt:\n", 118 | " beta = 1.8\n", 119 | " else:\n", 120 | " beta = 0.\n", 121 | "\n", 122 | " b = ro * b + (1 - ro) * spike\n", 123 | " B = b_j0 + beta * b\n", 124 | "\n", 125 | " mem = mem * alpha + (1 - alpha) * R_m * ops(x) - B * spike * dt\n", 126 | " inputs_ = mem - B\n", 127 | " spike = act_fun_adp(inputs_) # act_fun : approximation firing function\n", 128 | " return mem, spike, B, b\n", 129 | "\n", 130 | "\n", 131 | "def mem_update_NU_adp(inputs, mem, spike, tau_adp, b, isAdapt=1, dt=1):\n", 132 | " # tau_adp = torch.FloatTensor([tau_adp])\n", 133 | " alpha = torch.exp(-1. * dt / tau_m)\n", 134 | " ro = torch.exp(-1. * dt / tau_adp)\n", 135 | " # tau_adp is tau_adaptative which is learnable # add requiregredients\n", 136 | " if isAdapt:\n", 137 | " beta = 1.8\n", 138 | " else:\n", 139 | " beta = 0.\n", 140 | "\n", 141 | " b = ro * b + (1 - ro) * spike\n", 142 | " B = b_j0 + beta * b\n", 143 | "\n", 144 | " mem = mem * alpha + (1 - alpha) * R_m * inputs - B * spike * dt\n", 145 | " inputs_ = mem - B\n", 146 | " spike = act_fun_adp(inputs_) # act_fun : approximation firing function\n", 147 | " return mem, spike, B, b\n", 148 | "\n", 149 | "class RNN_s(nn.Module):\n", 150 | " def __init__(self, input_size, hidden_size, output_size, sub_seq_length,criterion):\n", 151 | " super(RNN_s, self).__init__()\n", 152 | " self.criterion = criterion\n", 153 | "\n", 154 | " self.hidden_size = hidden_size\n", 155 | " self.output_size = output_size\n", 156 | " self.sub_seq_length = sub_seq_length\n", 157 | " self.i2h = nn.Linear(input_size, hidden_size)\n", 158 | " self.h2h = nn.Linear(hidden_size, hidden_size)\n", 159 | " self.h2o = nn.Linear(hidden_size, output_size)\n", 160 | "\n", 161 | " self.tau_adp_h = nn.Parameter(torch.Tensor(hidden_size))\n", 162 | " self.tau_adp_o = nn.Parameter(torch.Tensor(output_size))\n", 163 | " nn.init.orthogonal_(self.h2h.weight)\n", 164 | " nn.init.xavier_uniform_(self.i2h.weight)\n", 165 | " nn.init.xavier_uniform_(self.h2o.weight)\n", 166 | " nn.init.constant_(self.i2h.bias, 0)\n", 167 | " nn.init.constant_(self.h2h.bias, 0)\n", 168 | " nn.init.constant_(self.h2o.bias, 0)\n", 169 | "\n", 170 | " nn.init.constant_(self.tau_adp_h, 7)\n", 171 | " nn.init.constant_(self.tau_adp_o, 100)\n", 172 | " self.b_h = self.b_o = 0\n", 173 | "\n", 174 | " def forward(self, input,labels):\n", 175 | " self.b_h = self.b_o = 0\n", 176 | " total_spikes = 0\n", 177 | " # Feed in the whole sequence\n", 178 | " batch_size, seq_num, input_dim = input.shape\n", 179 | " # hidden_mem = hidden_spike = torch.zeros(batch_size, self.hidden_size)\n", 180 | " # output_mem = output_spike = out_spike = torch.zeros(batch_size, self.output_size)\n", 181 | " hidden_mem = hidden_spike = torch.rand(batch_size, self.hidden_size)\n", 182 | " output_mem = output_spike = out_spike = torch.rand(batch_size, self.output_size)\n", 183 | " output_spike_sum = torch.zeros(batch_size,seq_num, self.output_size)\n", 184 | " self.b_h = self.b_o = 0.01\n", 185 | "\n", 186 | " max_iters = 1301\n", 187 | " loss = 0\n", 188 | "\n", 189 | " output_ = []\n", 190 | " I_h = []\n", 191 | " spike_train = 
[]\n", 192 | " predictions = []\n", 193 | " for i in range(max_iters): # Go through the sequence\n", 194 | " if i < seq_num:\n", 195 | " input_x = input[:, i, :]\n", 196 | " else:\n", 197 | " input_x = torch.zeros(batch_size,input_dim)\n", 198 | "\n", 199 | " ################# update states #########################\n", 200 | " h_input = self.i2h(input_x.float()) + self.h2h(hidden_spike)\n", 201 | " hidden_mem, hidden_spike, theta_h, self.b_h = mem_update_NU_adp(h_input,hidden_mem, hidden_spike,\n", 202 | " self.tau_adp_h, self.b_h,isAdapt=0)#, dt=input_dt_h)\n", 203 | "\n", 204 | " I_h.append(h_input.data.cpu().numpy())\n", 205 | " spike_train.append(hidden_spike.data.cpu().numpy())\n", 206 | " output_mem, output_spike, theta_o, self.b_o = mem_update_adp(self.h2o, hidden_spike,output_mem,\n", 207 | " output_spike, self.tau_adp_o, self.b_o)#, dt=input_dt_o)\n", 208 | " output_spike_sum[:,i,:] = output_spike\n", 209 | " total_spikes = total_spikes + int(hidden_spike.sum() + output_spike.sum())\n", 210 | " ################# classification #########################\n", 211 | " if i >= self.sub_seq_length:\n", 212 | " output_sumspike = output_mem #output_spike_sum[:, i-1:i, :].sum(axis=1)\n", 213 | " output_sumspike = F.log_softmax(output_sumspike,dim=1)\n", 214 | "\n", 215 | " predictions.append(output_sumspike.data.cpu().numpy())\n", 216 | " output_.append(output_sumspike.data.cpu().numpy())\n", 217 | " loss += self.criterion(output_sumspike, labels[:, i])\n", 218 | "\n", 219 | " predictions = torch.tensor(predictions)\n", 220 | " return predictions, loss , total_spikes, np.array(spike_train)\n", 221 | "\n", 222 | " def predict(self,input, lablel):\n", 223 | " prediction = self.forward(input, lablel)\n", 224 | " # prediction, _, total_spikes = self.forward(dt_h, dt_o, max_i, input, lablel)\n", 225 | " return prediction\n", 226 | "\n", 227 | "\n", 228 | "\n", 229 | "import scipy.signal as ssg\n", 230 | "def convert_seq(x,threshold=0.03):\n", 231 | " l = len(x)\n", 232 | " x= ssg.savgol_filter(x, 5, 3)\n", 233 | " X = np.zeros((l,2))\n", 234 | " for i in range(len(x)-1):\n", 235 | " if x[i+1] - x[i] >= threshold:\n", 236 | " X[i,0] = 1\n", 237 | " elif x[i] - x[i+1] >= threshold:\n", 238 | " X[i,1] = 1\n", 239 | " return X\n", 240 | "\n", 241 | "\n", 242 | "def expand_dim(x, N):\n", 243 | " y = np.zeros((x.shape[0], x.shape[1], N))\n", 244 | " for i in range(x.shape[0]):\n", 245 | " y[i, :, :] = np.tile(x[i,:], (N,1)).transpose()\n", 246 | "\n", 247 | " return y\n", 248 | "\n", 249 | "def lbl_to_spike(prediction):\n", 250 | " N = len(prediction)\n", 251 | " detections = np.zeros(N)\n", 252 | " for i in range(1, N):\n", 253 | " if (prediction[i] != prediction[i-1]):\n", 254 | " detections[i] = prediction[i]+1\n", 255 | " return detections\n", 256 | "\n", 257 | "\n", 258 | "def calculate_stats(prediction, lbl, tol):\n", 259 | " decisions = lbl_to_spike(prediction)\n", 260 | " labs = lbl_to_spike(lbl)\n", 261 | "\n", 262 | " lbl_indices = np.nonzero(labs)\n", 263 | " lbl_indices = np.array(lbl_indices).flatten()\n", 264 | "\n", 265 | " dist = np.zeros((len(lbl_indices), 6))\n", 266 | " for i in range(len(lbl_indices)):\n", 267 | " index = lbl_indices[i]\n", 268 | " lab = int(labs[index])\n", 269 | " dec_indices = np.array(np.nonzero((decisions-lab) == 0)).flatten() #indices where decisions == lab\n", 270 | " if len(dec_indices) == 0:\n", 271 | " dist[i, lab - 1] = 250\n", 272 | " continue\n", 273 | " j = np.argmin(np.abs(dec_indices - index)) # j is closest val in dec_indices to index\n", 274 | " 
dist[i, lab-1] = abs(dec_indices[j]-index)\n", 275 | " if (dist[i, lab-1] <= tol):\n", 276 | " decisions[dec_indices[j]] = 0 # mark as handled\n", 277 | "\n", 278 | " mean_error = np.mean(dist, axis=0)\n", 279 | " TP = np.sum(dist <= tol, axis=0)\n", 280 | " FN = np.sum(dist > tol, axis=0)\n", 281 | "\n", 282 | " FP = np.zeros(6)\n", 283 | " for i in decisions[(decisions > 0)]:\n", 284 | " FP[int(i-1)] += 1\n", 285 | "\n", 286 | " return mean_error, TP, FN, FP\n", 287 | "\n", 288 | "def accuracy_with_window(pred,target,before_window=10,after_window=10):\n", 289 | " # this function will used to replace the function (x-y)==0\n", 290 | " # to \n", 291 | " acc = 0\n", 292 | " n = len(pred)\n", 293 | "\n", 294 | " for i,p in enumerate(pred):\n", 295 | " if i< before_window:\n", 296 | " window_label = target[:i+after_window]\n", 297 | " elif n-i tol, axis=0)\n", 329 | "\n", 330 | " FP = np.zeros(6)\n", 331 | " for i in decisions[(decisions > 0)]:\n", 332 | " FP[int(i-1)] += 1\n", 333 | "\n", 334 | " return mean_error, TP, FN, FP\n", 335 | "\n", 336 | "\n", 337 | "\n", 338 | "def convert_dataset_wtime(mat_data):\n", 339 | " X = mat_data[\"x\"]\n", 340 | " Y = mat_data[\"y\"]\n", 341 | " t = mat_data[\"t\"]\n", 342 | " Y = np.argmax(Y[:, :, :], axis=-1)\n", 343 | " d1,d2 = t.shape\n", 344 | "\n", 345 | " # dt = np.zeros((size(t[:, 1]), size(t[1, :])))\n", 346 | " dt = np.zeros((d1,d2))\n", 347 | " for trace in range(d1):\n", 348 | " dt[trace, 0] = 1\n", 349 | " dt[trace, 1:] = t[trace, 1:] - t[trace, :-1]\n", 350 | "\n", 351 | " return dt, X, Y\n", 352 | "\n", 353 | "\n", 354 | "def load_max_i(mat_data):\n", 355 | " max_i = mat_data[\"max_i\"]\n", 356 | " return np.array(max_i.squeeze(),dtype=np.float16)\n", 357 | "def plot_confusion_matrix(cm,\n", 358 | " target_names,\n", 359 | " title='Confusion matrix',\n", 360 | " cmap=None,\n", 361 | " normalize=True):\n", 362 | " import matplotlib.pyplot as plt\n", 363 | " import numpy as np\n", 364 | " import itertools\n", 365 | "\n", 366 | " accuracy = np.trace(cm) / float(np.sum(cm))\n", 367 | " misclass = 1 - accuracy\n", 368 | "\n", 369 | " if cmap is None:\n", 370 | " cmap = plt.get_cmap('Blues')\n", 371 | "\n", 372 | " plt.figure(figsize=(8, 6))\n", 373 | " plt.imshow(cm, interpolation='nearest', cmap=cmap)\n", 374 | " plt.title(title)\n", 375 | " plt.colorbar()\n", 376 | "\n", 377 | " if target_names is not None:\n", 378 | " tick_marks = np.arange(len(target_names))\n", 379 | " plt.xticks(tick_marks, target_names, rotation=45)\n", 380 | " plt.yticks(tick_marks, target_names)\n", 381 | "\n", 382 | " if normalize:\n", 383 | " cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n", 384 | "\n", 385 | "\n", 386 | " thresh = cm.max() / 1.5 if normalize else cm.max() / 2\n", 387 | " for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n", 388 | " if normalize:\n", 389 | " plt.text(j, i, \"{:0.2f}\".format(cm[i, j]),\n", 390 | " horizontalalignment=\"center\",\n", 391 | " color=\"white\" if cm[i, j] > thresh else \"black\")\n", 392 | " else:\n", 393 | " plt.text(j, i, \"{:,}\".format(cm[i, j]),\n", 394 | " horizontalalignment=\"center\",\n", 395 | " color=\"white\" if cm[i, j] > thresh else \"black\")\n", 396 | "\n", 397 | "\n", 398 | " plt.tight_layout()\n", 399 | " plt.ylabel('True label')\n", 400 | " plt.xlabel('Predicted label')\n", 401 | " # plt.savefig('cm_SRNN_846.png')\n", 402 | " # plt.xlabel('Predicted label\\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))\n", 403 | " plt.show()\n", 404 | " " 405 | ] 406 | }, 
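  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*(Added note, not part of the original notebook.)* The cells below load the QTDB ECG data (`QTDB_train.mat` / `QTDB_test.mat`), rebuild the small `RNN_s` network (4 inputs, 36 hidden units, 6 output classes), restore the weights from `model_0.853338_36_New.pth`, and evaluate per-timestep accuracy together with TP/FP/FN statistics using a tolerance of `0.150 * 250` samples around each labelled transition."
   ]
  },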
407 | { 408 | "cell_type": "code", 409 | "execution_count": 15, 410 | "metadata": { 411 | "tags": [] 412 | }, 413 | "outputs": [ 414 | { 415 | "output_type": "stream", 416 | "name": "stdout", 417 | "text": [ 418 | "sequence length: 1301 , input dimension: 4\ntraining dataset distribution: (618, 1301)\ntest dataset distribution: (141, 1301)\n" 419 | ] 420 | } 421 | ], 422 | "source": [ 423 | "train_mat = scipy.io.loadmat('../QTDB_train.mat')\n", 424 | "test_mat = scipy.io.loadmat('../QTDB_test.mat')\n", 425 | "\n", 426 | "# # # old dataset\n", 427 | "# xxt = np.load('../dataset/test_y.npy')\n", 428 | "# print(xxt.shape)\n", 429 | "# xxt = np.load('../dataset1/xxt.npy')\n", 430 | "# yyt = np.load('../dataset1/yyt.npy')\n", 431 | "# xxv = np.load('../dataset1/xxv.npy')\n", 432 | "# yyv = np.load('../dataset1/yyv.npy')\n", 433 | "\n", 434 | "train_dt, train_x, train_y = convert_dataset_wtime(train_mat)\n", 435 | "train_max_i = load_max_i(train_mat)\n", 436 | "\n", 437 | "test_dt, test_x, test_y = convert_dataset_wtime(test_mat)\n", 438 | "test_max_i = load_max_i(test_mat)\n", 439 | "\n", 440 | "nb_of_sample, seq_dim, input_dim = np.shape(train_x)\n", 441 | "print('sequence length: {} , input dimension: {}'.format(seq_dim, input_dim))\n", 442 | "print('training dataset distribution: ',train_y.shape)\n", 443 | "print('test dataset distribution: ',test_y.shape)" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": null, 449 | "metadata": {}, 450 | "outputs": [], 451 | "source": [ 452 | "np.mean(train_x)" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": 6, 458 | "metadata": { 459 | "tags": [] 460 | }, 461 | "outputs": [ 462 | { 463 | "output_type": "execute_result", 464 | "data": { 465 | "text/plain": [ 466 | "RNN_s(\n", 467 | " (criterion): NLLLoss()\n", 468 | " (i2h): Linear(in_features=4, out_features=36, bias=True)\n", 469 | " (h2h): Linear(in_features=36, out_features=36, bias=True)\n", 470 | " (h2o): Linear(in_features=36, out_features=6, bias=True)\n", 471 | ")" 472 | ] 473 | }, 474 | "metadata": {}, 475 | "execution_count": 6 476 | } 477 | ], 478 | "source": [ 479 | "# STEP 2: MAKING DATASET ITERABLE\n", 480 | "batch_size = 64\n", 481 | "n_iters = 300000\n", 482 | "lens = 0.5 # hyper-parameters of approximate function\n", 483 | "num_epochs = 0\n", 484 | "nb_of_batch = nb_of_sample // batch_size\n", 485 | "\n", 486 | "sub_seq_length = 10\n", 487 | "#L = seq_dim - sub_seq_length\n", 488 | "hidden_dim = 36\n", 489 | "output_dim = 6\n", 490 | "\n", 491 | "train_data = TensorDataset(torch.from_numpy(train_x*1.),torch.from_numpy(train_y))\n", 492 | "train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size, drop_last=True)\n", 493 | "test_data = TensorDataset(torch.from_numpy(test_x*1.),torch.from_numpy(test_y))\n", 494 | "test_loader = DataLoader(test_data, shuffle=True, batch_size=batch_size, drop_last=True)\n", 495 | "\n", 496 | "criterion = nn.NLLLoss()#nn.CrossEntropyLoss()\n", 497 | "model = RNN_s(input_size=input_dim, hidden_size=hidden_dim,\n", 498 | " output_size=output_dim, sub_seq_length=sub_seq_length,criterion=criterion)\n", 499 | "model.to(device)\n", 500 | "\n", 501 | "\n", 502 | "# model.load_state_dict(torch.load('./model_0.853338_36_New.pth'))" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 7, 508 | "metadata": {}, 509 | "outputs": [ 510 | { 511 | "output_type": "execute_result", 512 | "data": { 513 | "text/plain": [ 514 | "" 515 | ] 516 | }, 517 | "metadata": {}, 518 | 
"execution_count": 7 519 | } 520 | ], 521 | "source": [ 522 | "model.load_state_dict(torch.load('./model_0.853338_36_New.pth'))" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": 8, 528 | "metadata": { 529 | "tags": [] 530 | }, 531 | "outputs": [ 532 | { 533 | "output_type": "stream", 534 | "name": "stdout", 535 | "text": [ 536 | "0 0.9465530596436871\n", 537 | "10 0.9385254559538061\n", 538 | "20 0.9238316550477668\n", 539 | "30 0.9258889083231304\n", 540 | "40 0.9186676994577846\n", 541 | "50 0.912060873923543\n", 542 | "60 0.9113662048735889\n", 543 | "70 0.918765887345763\n", 544 | "80 0.9184955676047852\n", 545 | "90 0.9190933001932228\n", 546 | "100 0.9188747689641156\n", 547 | "110 0.9175930384295993\n", 548 | "120 0.9202360909282956\n", 549 | "130 0.9202287119872755\n", 550 | "140 0.9175689855024695\n", 551 | "dict_items([('err', array([3.76748186, 1.01513057, 1.04319487, 0.9086288 , 0.97002002,\n", 552 | " 1.73929765])), ('TP', array([4253, 4299, 4297, 4294, 4299, 4280])), ('FP', array([738., 93., 27., 31., 131., 768.])), ('FN', array([50, 4, 6, 9, 4, 23]))])\n", 553 | "fr: 0.2907288732513192\n", 554 | "test accuracy: 0.9175689855024695\n" 555 | ] 556 | } 557 | ], 558 | "source": [ 559 | "test_seq_dim = test_x.shape[1]\n", 560 | "\n", 561 | "window_size_front = 5\n", 562 | "window_size_beind = 5 \n", 563 | "acc = []\n", 564 | "predicted_list = []\n", 565 | "labels_list = []\n", 566 | "fr_list = []\n", 567 | "stats = {\"err\": 0, \"TP\": 0, \"FP\": 0, \"FN\": 0}\n", 568 | "for i in range(len(test_x)):\n", 569 | " x_emp = test_x[i:i+1]\n", 570 | " y_emp = test_y[i:i+1]\n", 571 | "\n", 572 | " images = torch.tensor(x_emp*1.).view((-1, test_seq_dim, input_dim)).requires_grad_().to(device)\n", 573 | " labels = torch.tensor(y_emp).view((-1, test_seq_dim)).long().to(device)\n", 574 | "\n", 575 | " pred,_,spikesum,_ = model.forward(images,labels) #model.predict(images,labels)\n", 576 | " # print(pred.shape, pred.dtype)\n", 577 | " # print(pred.numpy().shape )\n", 578 | " a_len = 1301-sub_seq_length\n", 579 | " a_np = pred.data.cpu().numpy().reshape(a_len, 6)\n", 580 | " a_np_pred = np.argmax(a_np, axis=1)\n", 581 | " labels = y_emp[0, sub_seq_length:sub_seq_length+a_len].reshape(a_len)\n", 582 | " # acc_ = (a_np_pred == labels).sum() / int(a_len)\n", 583 | " acc_ = accuracy_with_window(a_np_pred.flatten(),labels.flatten(),before_window=window_size_front ,after_window=window_size_beind)\n", 584 | " acc.append(acc_)\n", 585 | " labels_list.extend(labels)\n", 586 | " predicted_list.extend(a_np_pred)\n", 587 | " fr_list.append(spikesum)\n", 588 | " if i%10 == 0:\n", 589 | " print(i,np.mean(acc))\n", 590 | "\n", 591 | " err, TP, FN, FP = calculate_stats(a_np_pred[30:-30], labels[30:-30], 0.150 * 250)\n", 592 | " stats[\"err\"] = stats[\"err\"] + err / len(test_x)\n", 593 | " stats[\"TP\"] = stats[\"TP\"] + TP\n", 594 | " stats[\"FP\"] = stats[\"FP\"] + FP\n", 595 | " stats[\"FN\"] = stats[\"FN\"] + FN\n", 596 | " \n", 597 | "print(stats.items())\n", 598 | "print('fr:',np.mean(fr_list)/1301./42.)\n", 599 | "test_acc = np.mean(acc)\n", 600 | "print('test accuracy: ',np.mean(acc))" 601 | ] 602 | }, 603 | { 604 | "cell_type": "markdown", 605 | "metadata": {}, 606 | "source": [ 607 | "## Accuracy with time shift" 608 | ] 609 | }, 610 | { 611 | "cell_type": "code", 612 | "execution_count": 10, 613 | "metadata": { 614 | "tags": [] 615 | }, 616 | "outputs": [ 617 | { 618 | "output_type": "stream", 619 | "name": "stdout", 620 | "text": [ 621 | "0 0.9031758326878389\n", 622 | 
"10 0.8874022956129849\n", 623 | "20 0.8765445760023608\n", 624 | "30 0.8764898428325129\n", 625 | "40 0.8647295535697418\n", 626 | "50 0.854999164654243\n", 627 | "60 0.8548970806719915\n", 628 | "70 0.8600386205692716\n", 629 | "80 0.8585745569995507\n", 630 | "90 0.8594240770848053\n", 631 | "100 0.8596835671173623\n", 632 | "110 0.8555139182559786\n", 633 | "120 0.8571163362375248\n", 634 | "130 0.8578946434801119\n", 635 | "140 0.8544588559091584\n", 636 | "dict_items([('err', array([4.93194504, 0.76258907, 1.18140922, 1.14728654, 0.92581813,\n", 637 | " 1.77659498])), ('TP', array([4245, 4300, 4297, 4295, 4299, 4281])), ('FP', array([758., 69., 35., 30., 120., 796.])), ('FN', array([58, 3, 6, 8, 4, 22]))])\n", 638 | "fr: 0.2906338641125303\n", 639 | "test accuracy: 0.8544588559091584\n" 640 | ] 641 | } 642 | ], 643 | "source": [ 644 | "test_seq_dim = test_x.shape[1]\n", 645 | "acc = []\n", 646 | "predicted_list = []\n", 647 | "labels_list = []\n", 648 | "fr_list = []\n", 649 | "stats = {\"err\": 0, \"TP\": 0, \"FP\": 0, \"FN\": 0}\n", 650 | "t_shift = 0\n", 651 | "for i in range(len(test_x)):\n", 652 | " x_emp = test_x[i:i+1]\n", 653 | " y_emp = test_y[i:i+1]\n", 654 | "\n", 655 | " images = torch.tensor(x_emp*1.).view((-1, test_seq_dim, input_dim)).requires_grad_().to(device)\n", 656 | " labels = torch.tensor(y_emp).view((-1, test_seq_dim)).long().to(device)\n", 657 | "\n", 658 | " pred,_,spikesum,_ = model.forward(images,labels) #model.predict(images,labels)\n", 659 | " # print(pred.shape, pred.dtype)\n", 660 | " # print(pred.numpy().shape )\n", 661 | " a_len = 1301-sub_seq_length\n", 662 | " a_np = pred.data.cpu().numpy().reshape(a_len, 6)\n", 663 | " a_np_pred = np.argmax(a_np, axis=1)\n", 664 | " # print(a_np_pred.shape)\n", 665 | " a_np_pred_shift = a_np_pred[:a_len-t_shift]\n", 666 | " labels = y_emp[0, sub_seq_length:sub_seq_length+a_len].reshape(a_len)\n", 667 | " labels_shift = labels[t_shift:]\n", 668 | " acc_ = (a_np_pred_shift == labels_shift).sum() / int(a_len)\n", 669 | " #acc_ = accuracy_with_window(a_np_pred.flatten(),labels.flatten(),before_window=3,after_window=3)\n", 670 | " acc.append(acc_)\n", 671 | " labels_list.extend(labels)\n", 672 | " predicted_list.extend(a_np_pred)\n", 673 | " fr_list.append(spikesum)\n", 674 | " if i%10 == 0:\n", 675 | " print(i,np.mean(acc))\n", 676 | "\n", 677 | " err, TP, FN, FP = calculate_stats(a_np_pred[30:-30], labels[30:-30], 0.150 * 250)\n", 678 | " stats[\"err\"] = stats[\"err\"] + err / len(test_x)\n", 679 | " stats[\"TP\"] = stats[\"TP\"] + TP\n", 680 | " stats[\"FP\"] = stats[\"FP\"] + FP\n", 681 | " stats[\"FN\"] = stats[\"FN\"] + FN\n", 682 | " \n", 683 | "print(stats.items())\n", 684 | "print('fr:',np.mean(fr_list)/1301./42.)\n", 685 | "test_acc = np.mean(acc)\n", 686 | "print('test accuracy: ',np.mean(acc))" 687 | ] 688 | }, 689 | { 690 | "cell_type": "code", 691 | "execution_count": 12, 692 | "metadata": { 693 | "tags": [] 694 | }, 695 | "outputs": [ 696 | { 697 | "output_type": "stream", 698 | "name": "stdout", 699 | "text": [ 700 | "[[34954 752 16 71 1576 4430]\n [ 1524 7147 970 0 38 466]\n [ 58 643 8656 451 9 23]\n [ 94 0 687 7131 1150 3]\n [ 703 35 35 1301 52021 1790]\n [ 4777 310 248 141 4192 45629]]\n[[8.30064118e-01 8.46179813e-02 1.50772710e-03 7.80648708e-03\n 2.67182043e-02 8.46372824e-02]\n [3.61909285e-02 8.04208394e-01 9.14059555e-02 0.00000000e+00\n 6.44220663e-04 8.90315431e-03]\n [1.37734505e-03 7.23528750e-02 8.15680362e-01 4.95876855e-02\n 1.52578578e-04 4.39426071e-04]\n [2.23224887e-03 0.00000000e+00 
6.47380324e-02 7.84057174e-01\n 1.94961516e-02 5.73164441e-05]\n [1.66943719e-02 3.93833690e-03 3.29815303e-03 1.43045629e-01\n 8.81921134e-01 3.41988116e-02]\n [1.13440988e-01 3.48824125e-02 2.33697701e-02 1.55030236e-02\n 7.10677110e-02 8.71764009e-01]]\n" 701 | ] 702 | }, 703 | { 704 | "output_type": "display_data", 705 | "data": { 706 | "text/plain": "
", 707 | "image/svg+xml": "\n\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" 708 | }, 709 | "metadata": { 710 | "needs_background": "light" 711 | } 712 | } 713 | ], 714 | "source": [ 715 | "# p_t == y_t+1\n", 716 | "predicted_list = np.array(predicted_list).reshape((-1,))\n", 717 | "labels_list = np.array(labels_list).reshape(((-1,)))\n", 718 | "cm = confusion_matrix(labels_list, predicted_list)\n", 719 | "print(cm)\n", 720 | "cm_r = cm/cm.sum(axis=0)\n", 721 | "print(cm_r)\n", 722 | "# import seaborn as sn\n", 723 | "plot_confusion_matrix(np.array(cm_r),\n", 724 | " normalize = True,\n", 725 | " target_names = [i for i in ['P','PQ','QR','RS','ST','TP']],\n", 726 | " title = \"SRNN Confusion Matrix\")" 727 | ] 728 | }, 729 | { 730 | "cell_type": "markdown", 731 | "metadata": {}, 732 | "source": [ 733 | "## Confusion Matrix" 734 | ] 735 | }, 736 | { 737 | "cell_type": "code", 738 | "execution_count": 14, 739 | "metadata": { 740 | "tags": [] 741 | }, 742 | "outputs": [ 743 | { 744 | "output_type": "stream", 745 | "name": "stdout", 746 | "text": [ 747 | "[[34954 752 16 71 1576 4430]\n [ 1524 7147 970 0 38 466]\n [ 58 643 8656 451 
9 23]\n [ 94 0 687 7131 1150 3]\n [ 703 35 35 1301 52021 1790]\n [ 4777 310 248 141 4192 45629]]\n[[8.30064118e-01 8.46179813e-02 1.50772710e-03 7.80648708e-03\n 2.67182043e-02 8.46372824e-02]\n [3.61909285e-02 8.04208394e-01 9.14059555e-02 0.00000000e+00\n 6.44220663e-04 8.90315431e-03]\n [1.37734505e-03 7.23528750e-02 8.15680362e-01 4.95876855e-02\n 1.52578578e-04 4.39426071e-04]\n [2.23224887e-03 0.00000000e+00 6.47380324e-02 7.84057174e-01\n 1.94961516e-02 5.73164441e-05]\n [1.66943719e-02 3.93833690e-03 3.29815303e-03 1.43045629e-01\n 8.81921134e-01 3.41988116e-02]\n [1.13440988e-01 3.48824125e-02 2.33697701e-02 1.55030236e-02\n 7.10677110e-02 8.71764009e-01]]\n" 748 | ] 749 | }, 750 | { 751 | "output_type": "display_data", 752 | "data": { 753 | "text/plain": "
", 754 | "image/svg+xml": "\n\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" 755 | }, 756 | "metadata": { 757 | "needs_background": "light" 758 | } 759 | } 760 | ], 761 | "source": [ 762 | "predicted_list = np.array(predicted_list).reshape((-1,))\n", 763 | "labels_list = np.array(labels_list).reshape(((-1,)))\n", 764 | "cm = confusion_matrix(labels_list, predicted_list)\n", 765 | "print(cm)\n", 766 | "cm_r = cm/cm.sum(axis=0)\n", 767 | "print(cm_r)\n", 768 | "# import seaborn as sn\n", 769 | "plot_confusion_matrix(np.array(cm_r),\n", 770 | " normalize = True,\n", 771 | " target_names = [i for i in ['P','PQ','QR','RS','ST','TP']],\n", 772 | " title = \"SRNN Confusion Matrix\")\n", 773 | "" 774 | ] 775 | }, 776 | { 777 | "cell_type": "code", 778 | "execution_count": null, 779 | "metadata": {}, 780 | "outputs": [], 781 | "source": [] 782 | } 783 | ] 784 | } --------------------------------------------------------------------------------