├── README.assets ├── method.png ├── sample_00.png └── filestruct_00.png ├── examples ├── LSTM │ ├── meta.py │ ├── model.py │ ├── learner.py │ └── main.py ├── CSNN │ ├── meta.py │ ├── main.py │ └── learner.py ├── NearestNeighbor │ └── main.py ├── SSiamese │ ├── model.py │ └── main.py └── SMAML │ ├── main.py │ ├── meta.py │ └── learner.py ├── data_loader ├── nomniglot_full.py ├── NOmniglot.py ├── nomniglot_nw_ks.py └── nomniglot_pair.py ├── README.md └── utils.py /README.assets/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brain-Cog-Lab/N-Omniglot/HEAD/README.assets/method.png -------------------------------------------------------------------------------- /README.assets/sample_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brain-Cog-Lab/N-Omniglot/HEAD/README.assets/sample_00.png -------------------------------------------------------------------------------- /README.assets/filestruct_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brain-Cog-Lab/N-Omniglot/HEAD/README.assets/filestruct_00.png -------------------------------------------------------------------------------- /examples/LSTM/meta.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | from examples.LSTM.learner import * 3 | 4 | 5 | class Meta(nn.Module): 6 | """ 7 | Meta Learner 8 | """ 9 | def __init__(self, args): 10 | super(Meta, self).__init__() 11 | 12 | self.net = net() 13 | self.optim = optim.SGD(self.net.parameters(), lr=0.1) 14 | 15 | self.onehot = torch.eye(1623).cuda() 16 | self.lossfunction = nn.CrossEntropyLoss() 17 | 18 | def forward(self, x, y): 19 | self.net.train() 20 | logits = self.net(x) 21 | loss = self.lossfunction(logits, y) 22 | 23 | self.optim.zero_grad() 24 | loss.backward() 25 | 
self.optim.step() 26 | pred_q = F.softmax(logits, dim=1).argmax(dim=1) 27 | correct = torch.eq(pred_q, y).sum().item() / pred_q.shape[0] # convert to numpy 28 | 29 | return correct 30 | 31 | def test(self, x, y): 32 | self.net.eval() 33 | 34 | logits = self.net(x) 35 | 36 | pred_q = F.softmax(logits, dim=1).argmax(dim=1) 37 | correct = torch.eq(pred_q, y).sum().item() / pred_q.shape[0] # convert to numpy 38 | 39 | return correct 40 | -------------------------------------------------------------------------------- /examples/CSNN/meta.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | from examples.CSNN.learner import * 3 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 4 | 5 | 6 | class Meta(nn.Module): 7 | """ 8 | Meta Learner 9 | """ 10 | def __init__(self, args): 11 | super(Meta, self).__init__() 12 | 13 | self.net = SCNN() 14 | self.optim = optim.Adam(self.net.parameters(), lr=0.002) 15 | 16 | self.onehot = torch.eye(1623).cuda() 17 | self.lossfunction = nn.CrossEntropyLoss() 18 | 19 | def forward(self, x, y): 20 | self.net.train() 21 | logits = self.net(x) 22 | loss = self.lossfunction(logits, y) 23 | 24 | self.optim.zero_grad() 25 | loss.backward() 26 | self.optim.step() 27 | 28 | pred_q = F.softmax(logits, dim=1).argmax(dim=1) 29 | correct = torch.eq(pred_q, y).sum().item() / pred_q.shape[0] 30 | 31 | return correct 32 | 33 | def test(self, x, y): 34 | self.net.eval() 35 | 36 | logits = self.net(x) 37 | pred_q = F.softmax(logits, dim=1).argmax(dim=1) 38 | correct = torch.eq(pred_q, y).sum().item() / pred_q.shape[0] 39 | 40 | return correct 41 | -------------------------------------------------------------------------------- /examples/LSTM/model.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import torch 4 | 5 | 6 | class LSTM(torch.nn.Module): 7 | def __init__(self, inputnum, hnum): 8 | 
super().__init__() 9 | self.hnum = hnum 10 | self.wx = torch.randn((inputnum, 4 * hnum), requires_grad=True) / inputnum ** (1 / 2) 11 | self.wh = torch.randn((hnum, 4 * hnum), requires_grad=True) / hnum ** (1 / 2) 12 | self.b = torch.rand([1, 4 * hnum], requires_grad=True) 13 | 14 | self.wx = torch.nn.Parameter(self.wx) 15 | self.wh = torch.nn.Parameter(self.wh) 16 | self.b = torch.nn.Parameter(self.b) 17 | 18 | def forward(self, fullx, h, c): # x =[batch,L,in] 19 | self.fullh = [] 20 | self.h = h 21 | self.c = c 22 | for i in range(fullx.shape[1]): 23 | self.x = fullx[:, i, :] 24 | self.t = torch.mm(self.x, self.wx) + torch.mm(self.h, self.wh) + self.b # t=[batch,4h]) 25 | 26 | self.gate = torch.sigmoid(self.t[:, :3 * self.hnum]) # gate=[batch,3h]) 27 | f, i, o = [self.gate[:, n * self.hnum:(n + 1) * self.hnum] for n in range(3)] 28 | g = torch.tanh(self.t[:, 3 * self.hnum:4 * self.hnum]) 29 | self.c = self.c * f 30 | self.c = self.c + g * i 31 | self.h = torch.tanh(self.c) * o 32 | 33 | self.fullh.append(self.h.unsqueeze(1)) 34 | self.fullh = torch.cat(self.fullh, dim=1) 35 | return self.fullh, self.h, self.c # fullh=[batch,L,h] #h=[batch,] 36 | 37 | def __call__(self, x, h, c): 38 | return self.forward(x, h, c) 39 | 40 | 41 | if __name__ == "__main__": 42 | rnn = LSTM(10, 3) 43 | 44 | y, h, c = rnn(torch.ones([1, 100, 10]), torch.ones([1, 3]), torch.ones([1, 3])) 45 | print(h) 46 | a = torch.zeros([1, 100, 3]) 47 | a[0, -1, :] = 1 48 | print(a) 49 | y.backward(a) 50 | print(rnn.wh.grad) 51 | -------------------------------------------------------------------------------- /data_loader/nomniglot_full.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | from data_loader.NOmniglot import NOmniglot 4 | 5 | 6 | class NOmniglotfull(Dataset): 7 | ''' 8 | solve few-shot learning as general classification problem, 9 | We combine the original training set with the test 
set and take 3/4 as the training set 10 | ''' 11 | def __init__(self, root='data/', train=True, frames_num=4, data_type='event', 12 | transform=None, target_transform=None, use_npz=True, crop=True, create=True): 13 | super().__init__() 14 | 15 | trainSet = NOmniglot(root=root, train=True, frames_num=frames_num, data_type=data_type, 16 | transform=transform, target_transform=target_transform, 17 | use_npz=use_npz, crop=crop, create=create) 18 | testSet = NOmniglot(root=root, train=False, frames_num=frames_num, data_type=data_type, 19 | transform=transform, target_transform=lambda x: x + 964, 20 | use_npz=use_npz, crop=crop, create=create) 21 | self.data = torch.utils.data.ConcatDataset([trainSet, testSet]) 22 | if train: 23 | self.id = [j for j in range(len(self.data)) if j % 20 in [i for i in range(15)]] 24 | 25 | else: 26 | self.id = [j for j in range(len(self.data)) if j % 20 in [i for i in range(15, 20)]] 27 | 28 | def __len__(self): 29 | return len(self.id) 30 | 31 | def __getitem__(self, index): 32 | image, label = self.data[self.id[index]] 33 | return image, label 34 | 35 | 36 | if __name__ == '__main__': 37 | db_train = NOmniglotfull('../../data/', train=True, frames_num=4, data_type='event') 38 | dataloadertrain = DataLoader(db_train, batch_size=16, shuffle=True, num_workers=16, pin_memory=True) 39 | for x_spt, y_spt, x_qry, y_qry in dataloadertrain: 40 | print(x_spt.shape) -------------------------------------------------------------------------------- /examples/LSTM/learner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from examples.LSTM.model import * 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | 9 | class net(torch.nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | self.conv1 = torch.nn.Conv2d(2, 32, kernel_size=3, stride=1, padding=1) 14 | self.conv2 = torch.nn.Conv2d(32, 
32, kernel_size=3, stride=1, padding=1) 15 | self.bn1 = torch.nn.BatchNorm2d(32, eps=1e-5, momentum=0.1, affine=True) 16 | self.bn2 = torch.nn.BatchNorm2d(32, eps=1e-5, momentum=0.1, affine=True) 17 | self.relu1 = torch.nn.ReLU(True) 18 | self.relu2 = torch.nn.ReLU(True) 19 | self.maxpool1 = torch.nn.MaxPool2d(2, stride=2, padding=0) 20 | self.maxpool2 = torch.nn.MaxPool2d(2, stride=2, padding=0) 21 | self.LSTM = LSTM(1568, 128) 22 | self.fc = torch.nn.Linear(128, 1623) 23 | 24 | def forward(self, input): 25 | h = torch.zeros([input.shape[0], 128], device=device) 26 | c = torch.zeros([input.shape[0], 128], device=device) 27 | 28 | result = None 29 | for step in range(input.shape[1]): # simulation time steps 30 | 31 | result = input[:, step] 32 | result = self.conv1(result.float()) 33 | result = self.bn1(result) 34 | result = self.relu1(result) 35 | result = self.maxpool1(result) 36 | 37 | result = self.conv2(result) # +result 38 | result = self.bn2(result) 39 | result = self.relu2(result) 40 | result = self.maxpool2(result) 41 | result = result.reshape(input.shape[0], 1, -1) 42 | result, h, c = self.LSTM(result, h, c) 43 | result = result.reshape(input.shape[0], -1) 44 | 45 | result = self.fc(result) 46 | 47 | return result 48 | 49 | def __call__(self, x): 50 | return self.forward(x) 51 | -------------------------------------------------------------------------------- /examples/NearestNeighbor/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data_loader.nomniglot_pair import NOmniglotTestSet 3 | from torch.utils.data import DataLoader 4 | import numpy as np 5 | import gflags 6 | import sys 7 | import os 8 | from tqdm import tqdm 9 | 10 | if __name__ == '__main__': 11 | Flags = gflags.FLAGS 12 | gflags.DEFINE_bool("cuda", True, "use cuda") 13 | gflags.DEFINE_integer("way", 5, "how much way one-shot learning") 14 | gflags.DEFINE_integer("shot", 1, "how much shot few-shot learning") 15 | 
gflags.DEFINE_string("time", 10000, "number of samples to test accuracy") 16 | gflags.DEFINE_integer("workers", 4, "number of dataLoader workers") 17 | gflags.DEFINE_string("gpu_ids", "0", "gpu ids used to train") 18 | Flags(sys.argv) 19 | T = 4 20 | data_type = 'event' 21 | 22 | os.environ["CUDA_VISIBLE_DEVICES"] = Flags.gpu_ids 23 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 24 | print("use gpu:", Flags.gpu_ids, "to train.") 25 | print("%d way %d shot" % (Flags.way, Flags.shot)) 26 | 27 | testSet = NOmniglotTestSet(root='../../data/', time=Flags.time, way=Flags.way, shot=Flags.shot, use_frame=True, 28 | frames_num=T, 29 | data_type=data_type, use_npz=True, resize=105) 30 | testLoader = DataLoader(testSet, batch_size=Flags.way * Flags.shot, shuffle=False, num_workers=Flags.workers) 31 | 32 | right, error = 0, 0 33 | for _, (test1, test2) in tqdm(enumerate(testLoader, 1)): 34 | if Flags.cuda: 35 | test1, test2 = test1.cuda(), test2.cuda() 36 | with torch.no_grad(): 37 | L2_dis = ((test1 - test2) ** 2).sum(1).sum(1).sum(1).sum(1).sqrt() 38 | pred = L2_dis.argmin().cpu().item() 39 | if pred < Flags.shot: 40 | right += 1 41 | else: 42 | error += 1 43 | print("#" * 70) 44 | print("final accuracy: {:.4f}".format(right / (right + error) * 100)) 45 | -------------------------------------------------------------------------------- /examples/CSNN/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from data_loader.nomniglot_full import NOmniglotfull 4 | import argparse 5 | import torchvision 6 | from examples.CSNN.meta import Meta 7 | 8 | 9 | def main(args): 10 | torch.manual_seed(args.seed) 11 | torch.cuda.manual_seed_all(args.seed) 12 | np.random.seed(args.seed) 13 | 14 | print(args) 15 | 16 | device = torch.device('cuda') 17 | model = Meta(args).to(device) 18 | TrainSet = NOmniglotfull(root=r'../../data/', train=True, frames_num=4, data_type='event', 19 | 
transform=torchvision.transforms.Resize((28, 28))) 20 | TestSet = NOmniglotfull(root=r'../../data/', train=False, frames_num=4, data_type='event', 21 | transform=torchvision.transforms.Resize((28, 28))) 22 | 23 | trainLoader = torch.utils.data.DataLoader(TrainSet, batch_size=1, shuffle=True, num_workers=0, pin_memory=True) 24 | testLoader = torch.utils.data.DataLoader(TestSet, batch_size=1, shuffle=True, num_workers=0, pin_memory=True) 25 | 26 | best = 0 27 | besttest = 0 28 | for step in range(args.epoch): 29 | accs = [] 30 | for x, y in trainLoader: 31 | x, y = x .float().to(device), y .long().to(device) 32 | accs.append(model(x, y)) 33 | 34 | accs = np.array(accs).mean(axis=0).astype(np.float16) 35 | if best < accs: best = accs 36 | print('\ttraining acc:', accs, "best", best) 37 | 38 | accstest = [] 39 | for x, y in testLoader: 40 | 41 | x, y = x .float().to(device), (y).long().to(device) 42 | 43 | test_acc = model.test(x, y) 44 | accstest.append(test_acc) 45 | 46 | accstest = np.array(accstest).mean(axis=0).astype(np.float16) 47 | if besttest < accstest: besttest = accstest 48 | print('Test acc:', accstest, "best", besttest) 49 | 50 | 51 | if __name__ == '__main__': 52 | argparser = argparse.ArgumentParser() 53 | argparser.add_argument('--epoch', type=int, help='epoch number', default=400000) 54 | argparser.add_argument('--seed', type=int, help='seed number', default=0) 55 | 56 | args = argparser.parse_args() 57 | 58 | main(args) 59 | -------------------------------------------------------------------------------- /examples/LSTM/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from data_loader.nomniglot_full import NOmniglotfull 4 | import argparse 5 | import torchvision 6 | 7 | from examples.LSTM.meta import Meta 8 | 9 | 10 | def main(args): 11 | torch.manual_seed(args.seed) 12 | torch.cuda.manual_seed_all(args.seed) 13 | np.random.seed(args.seed) 14 | 15 | print(args) 16 | 
17 | device = torch.device('cuda') 18 | maml = Meta(args).to(device) 19 | TrainSet = NOmniglotfull(root=r'../../data/', train=True, frames_num=4, data_type='event', 20 | transform=torchvision.transforms.Resize((28, 28))) 21 | TestSet = NOmniglotfull(root=r'../../data/', train=False, frames_num=4, data_type='event', 22 | transform=torchvision.transforms.Resize((28, 28))) 23 | 24 | trainLoader = torch.utils.data.DataLoader(TrainSet, batch_size=64, shuffle=True, num_workers=0, pin_memory=True) 25 | testLoader = torch.utils.data.DataLoader(TestSet, batch_size=64, shuffle=False, num_workers=0, pin_memory=True) 26 | 27 | best = 0 28 | besttest = 0 29 | for step in range(args.epoch): 30 | accs = [] 31 | 32 | for x, y in trainLoader: 33 | x, y = x.float().to(device), y.long().to(device) 34 | accs.append(maml(x, y)) 35 | 36 | accs = np.array(accs).mean(axis=0).astype(np.float16) 37 | if best < accs: best = accs 38 | print('training acc:', accs, "best", best) 39 | 40 | accstest = [] 41 | 42 | for x, y in testLoader: 43 | x, y = x.float().to(device), y.long().to(device) 44 | 45 | test_acc = maml.test(x, y) 46 | accstest.append(test_acc) 47 | 48 | accstest = np.array(accstest).mean(axis=0).astype(np.float16) 49 | if besttest < accstest: besttest = accstest 50 | print('Test acc:', accstest, "best", besttest) 51 | 52 | 53 | if __name__ == '__main__': 54 | argparser = argparse.ArgumentParser() 55 | argparser.add_argument('--epoch', type=int, help='epoch number', default=400000) 56 | argparser.add_argument('--seed', type=int, help='seed number', default=0) 57 | 58 | args = argparser.parse_args() 59 | 60 | main(args) 61 | -------------------------------------------------------------------------------- /data_loader/NOmniglot.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from utils import * 3 | 4 | 5 | class NOmniglot(Dataset): 6 | def __init__(self, root='data/', frames_num=12, train=True, 
data_type='event', 7 | transform=None, target_transform=None, use_npz=True, crop=True, create=True, thread_num=16): 8 | super().__init__() 9 | self.crop = crop 10 | self.data_type = data_type 11 | self.use_npz = use_npz 12 | self.transform = transform 13 | self.target_transform = target_transform 14 | events_npy_root = os.path.join(root, 'events_npy', 'background' if train else "evaluation") 15 | 16 | frames_root = os.path.join(root, f'fnum_{frames_num}_dtype_{data_type}_npz_{use_npz}', 17 | 'background' if train else "evaluation") 18 | 19 | if not os.path.exists(frames_root) and create: 20 | if not os.path.exists(events_npy_root) and create: 21 | os.makedirs(events_npy_root) 22 | print('creating event data..') 23 | convert_aedat4_dir_to_events_dir(root, train) 24 | else: 25 | print(f'npy format events data root {events_npy_root}, already exists') 26 | 27 | os.makedirs(frames_root) 28 | print('creating frames data..') 29 | convert_events_dir_to_frames_dir(events_npy_root, frames_root, '.npy', frames_num, data_type, 30 | thread_num=thread_num, compress=use_npz) 31 | else: 32 | print(f'frames data root {frames_root} already exists.') 33 | 34 | self.datadict, self.num_classes = list_class_files(events_npy_root, frames_root, True, use_npz=use_npz) 35 | 36 | self.datalist = [] 37 | for i in self.datadict: 38 | self.datalist.extend([(j, i) for j in self.datadict[i]]) 39 | 40 | def __len__(self): 41 | return len(self.datalist) 42 | 43 | def __getitem__(self, index): 44 | image, label = self.datalist[index] 45 | image, label = self.readimage(image, label) 46 | return image, label 47 | 48 | def readimage(self, image, label): 49 | if self.use_npz: 50 | image = torch.tensor(np.load(image)['arr_0']).float() 51 | else: 52 | image = torch.tensor(np.load(image)).float() 53 | if self.crop: 54 | image = image[:, :, 4:254, 54:304] 55 | if self.transform is not None: image = self.transform(image) 56 | if self.target_transform is not None: label = self.target_transform(label) 57 | 
return image, label 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /examples/SSiamese/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # define approximate firing function 7 | thresh, lens = 0.5, 0.5 8 | decay = 0.2 9 | 10 | 11 | class ActFun(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, input): 14 | ctx.save_for_backward(input) 15 | return input.gt(thresh).float() 16 | 17 | @staticmethod 18 | def backward(ctx, grad_output): 19 | input, = ctx.saved_tensors 20 | grad_input = grad_output.clone() 21 | temp = abs(input - thresh) < lens 22 | return grad_input * temp.float() 23 | 24 | 25 | act_fun = ActFun.apply 26 | 27 | 28 | # membrane potential update 29 | def mem_update(ops, x, mem, spike): 30 | mem = mem * decay * (1. - spike) # + ops(x) 31 | mem = mem + ops(x) 32 | spike = act_fun(mem) # act_fun : approximation firing function 33 | return mem, spike 34 | 35 | 36 | class SSiamese(nn.Module): 37 | def __init__(self, device): 38 | super(SSiamese, self).__init__() 39 | self.device = device 40 | self.conv1 = nn.Conv2d(1, 64, 10) 41 | self.conv2 = nn.Conv2d(64, 128, 7) 42 | 43 | self.liner = nn.Linear(56448, 4096) 44 | self.out = nn.Linear(4096, 1) 45 | 46 | def forward(self, input1, input2, batch_size=128, time_window=10): 47 | device = self.device 48 | c1_mem1 = c1_spike1 = torch.zeros(batch_size, 64, 96, 96, device=device) 49 | c2_mem1 = c2_spike1 = torch.zeros(batch_size, 128, 42, 42, device=device) 50 | 51 | c1_mem2 = c1_spike2 = torch.zeros(batch_size, 64, 96, 96, device=device) 52 | c2_mem2 = c2_spike2 = torch.zeros(batch_size, 128, 42, 42, device=device) 53 | h1_mem2 = h1_spike2 = torch.zeros(batch_size, 4096, device=device) 54 | 55 | outputs = torch.zeros(batch_size, 1, device=device) 56 | 57 | for step in range(time_window): # simulation time steps 58 | x = 
input1[step] 59 | c1_mem1, c1_spike1 = mem_update(self.conv1, x, c1_mem1, c1_spike1) 60 | x = F.avg_pool2d(c1_spike1, 2) 61 | c2_mem1, c2_spike1 = mem_update(self.conv2, x, c2_mem1, c2_spike1) 62 | x = F.avg_pool2d(c2_spike1, 2) 63 | x1 = x.view(batch_size, -1) 64 | 65 | x = input2[step] 66 | c1_mem2, c1_spike2 = mem_update(self.conv1, x, c1_mem2, c1_spike2) 67 | x = F.avg_pool2d(c1_spike2, 2) 68 | c2_mem2, c2_spike2 = mem_update(self.conv2, x, c2_mem2, c2_spike2) 69 | x = F.avg_pool2d(c2_spike2, 2) 70 | x = x.view(batch_size, -1) 71 | 72 | h1_mem2, h1_spike2 = mem_update(self.liner, torch.abs(x-x1), h1_mem2, h1_spike2) 73 | outputs += self.out(h1_spike2) 74 | 75 | return outputs 76 | 77 | 78 | # for test 79 | if __name__ == '__main__': 80 | net = SSiamese() 81 | print(net) 82 | print(list(net.parameters())) 83 | -------------------------------------------------------------------------------- /data_loader/nomniglot_nw_ks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | from torch.utils.data import Dataset, DataLoader 5 | from data_loader.NOmniglot import NOmniglot 6 | 7 | 8 | class NOmniglotNWayKShot(Dataset): 9 | ''' 10 | get n-wway k-shot data as meta learning 11 | We set the sampling times of each epoch as "len(self.dataSet) // (self.n_way * (self.k_shot + self.k_query))" 12 | you can increase or decrease the number of epochs to determine the total training times 13 | ''' 14 | def __init__(self, root, n_way, k_shot, k_query, train=True, frames_num=12, data_type='event', 15 | transform=torchvision.transforms.Resize((28, 28))): 16 | self.dataSet = NOmniglot(root=root, train=train, 17 | frames_num=frames_num, data_type=data_type, transform=transform) 18 | self.n_way = n_way # n way 19 | self.k_shot = k_shot # k shot 20 | self.k_query = k_query # k query 21 | assert (k_shot + k_query) <= 20 22 | self.length = 256 23 | self.data_cache = 
self.load_data_cache(self.dataSet.datadict, self.length) 24 | 25 | def load_data_cache(self, data_dict, length): 26 | ''' 27 | The dataset is sampled randomly length times, and the address is saved to obtain 28 | ''' 29 | data_cache = [] 30 | for i in range(length): 31 | selected_cls = np.random.choice(len(data_dict), self.n_way, False) 32 | 33 | x_spts, y_spts, x_qrys, y_qrys = [], [], [], [] 34 | for j, cur_class in enumerate(selected_cls): 35 | selected_img = np.random.choice(20, self.k_shot + self.k_query, False) 36 | 37 | x_spts.append(np.array(data_dict[cur_class])[selected_img[:self.k_shot]]) 38 | x_qrys.append(np.array(data_dict[cur_class])[selected_img[self.k_shot:]]) 39 | y_spts.append([j for _ in range(self.k_shot)]) 40 | y_qrys.append([j for _ in range(self.k_query)]) 41 | 42 | shufflespt = np.random.choice(self.n_way * self.k_shot, self.n_way * self.k_shot, False) 43 | shuffleqry = np.random.choice(self.n_way * self.k_query, self.n_way * self.k_query, False) 44 | 45 | temp = [np.array(x_spts).reshape(-1)[shufflespt], np.array(y_spts).reshape(-1)[shufflespt], 46 | np.array(x_qrys).reshape(-1)[shuffleqry], np.array(y_qrys).reshape(-1)[shuffleqry]] 47 | data_cache.append(temp) 48 | return data_cache 49 | 50 | def __getitem__(self, index): 51 | x_spts, y_spts, x_qrys, y_qrys = self.data_cache[index] 52 | x_sptst, y_sptst, x_qryst, y_qryst = [], [], [], [] 53 | 54 | for i, j in zip(x_spts, y_spts): 55 | i, j = self.dataSet.readimage(i, j) 56 | x_sptst.append(i.unsqueeze(0)) 57 | y_sptst.append(j) 58 | for i, j in zip(x_qrys, y_qrys): 59 | i, j = self.dataSet.readimage(i, j) 60 | x_qryst.append(i.unsqueeze(0)) 61 | y_qryst.append(j) 62 | return torch.cat(x_sptst, dim=0), np.array(y_sptst), torch.cat(x_qryst, dim=0), np.array(y_qryst) 63 | 64 | def reset(self): 65 | self.data_cache = self.load_data_cache(self.dataSet.datadict, self.length) 66 | 67 | def __len__(self): 68 | return len(self.data_cache) 69 | 70 | 71 | if __name__ == "__main__": 72 | db_train = 
NOmniglotNWayKShot('./data/', n_way=5, k_shot=1, k_query=15, 73 | frames_num=4, data_type='frequency', train=True) 74 | dataloadertrain = DataLoader(db_train, batch_size=16, shuffle=True, num_workers=16, pin_memory=True) 75 | for x_spt, y_spt, x_qry, y_qry in dataloadertrain: 76 | print(x_spt.shape) 77 | db_train.resampling() -------------------------------------------------------------------------------- /examples/SMAML/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from data_loader.nomniglot_nw_ks import NOmniglotNWayKShot 4 | import argparse 5 | from torch.utils.data import Dataset, DataLoader 6 | from examples.SMAML.meta import Meta 7 | 8 | 9 | # Dacay learning_rate 10 | def lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=100): 11 | """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs.""" 12 | if epoch % lr_decay_epoch == 0 and epoch > 1: 13 | for param_group in optimizer.param_groups: 14 | if param_group['lr']>0.005: 15 | param_group['lr']*=0.5 16 | return optimizer 17 | 18 | 19 | def main(args): 20 | torch.manual_seed(args.seed) 21 | torch.cuda.manual_seed_all(args.seed) 22 | np.random.seed(args.seed) 23 | 24 | device = torch.device('cuda') 25 | maml = Meta(args).to(device) 26 | 27 | best = 0 28 | besttest = 0 29 | 30 | db_train = NOmniglotNWayKShot(r'../../data/', 31 | n_way=args.n_way, 32 | k_shot=args.k_spt, 33 | k_query=args.k_qry, 34 | train=True, 35 | frames_num=args.frames_num, 36 | data_type=args.result_type) 37 | db_test = NOmniglotNWayKShot(r'../../data/', 38 | n_way=args.n_way, 39 | k_shot=args.k_spt, 40 | k_query=args.k_qry, 41 | train=False, 42 | frames_num=args.frames_num, 43 | data_type=args.result_type) 44 | print(len(db_train)) 45 | dataloadertrain = DataLoader(db_train, batch_size=args.task_num, shuffle=True, num_workers=4, pin_memory=True) 46 | dataloadertest = DataLoader(db_test, batch_size=args.task_num, shuffle=False, 
num_workers=4, pin_memory=True) 47 | for step in range(args.epoch): 48 | acctrains = [] 49 | for x_spt, y_spt, x_qry, y_qry in dataloadertrain: 50 | x_spt, y_spt, x_qry, y_qry = (x_spt).to(device), y_spt.long().to(device), \ 51 | (x_qry).to(device), y_qry.long().to(device) 52 | 53 | acctrains.append(maml(x_spt, y_spt, x_qry, y_qry)[0]) 54 | 55 | acctrains = np.array(acctrains).mean(axis=0).astype(np.float16) 56 | if best < acctrains: best = acctrains 57 | 58 | print('step:', step, '\ttraining acc:', acctrains, "best", best,args.n_way,args.k_spt,args.frames_num,args.result_type,args.seed,"3layer") 59 | 60 | accstest = [] 61 | for x_spt, y_spt, x_qry, y_qry in dataloadertest: 62 | x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.long().to(device), \ 63 | x_qry.to(device), y_qry.long().to(device) 64 | 65 | for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry): 66 | test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one) 67 | accstest.append(test_acc[-1]) 68 | 69 | accstest = np.array(accstest).mean(axis=0).astype(np.float16) 70 | if besttest < accstest: besttest = accstest 71 | 72 | print('Test acc:', accstest, "best", besttest) 73 | db_train.reset() 74 | 75 | 76 | if __name__ == '__main__': 77 | argparser = argparse.ArgumentParser() 78 | argparser.add_argument('--epoch', type=int, help='epoch number', default=400000) 79 | argparser.add_argument('--n_way', type=int, help='n way', default=5) 80 | argparser.add_argument('--seed', type=int, help='seed', default=0) 81 | argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=1) 82 | argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15) 83 | argparser.add_argument('--frames_num', type=int, help='frames_num', default=4) 84 | argparser.add_argument('--result_type', type=str, help='result_type', default="event") 85 | 86 | argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=16) 
87 | argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=0.2) 88 | argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.2) 89 | argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=3) 90 | argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=8) 91 | 92 | args = argparser.parse_args() 93 | 94 | main(args) 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # N-Omniglot 2 | 3 | [[Paper]](https://www.nature.com/articles/s41597-022-01851-z) || [[Dataset]](https://figshare.com/articles/dataset/N-Omniglot/16821427) 4 | 5 | N-Omniglot is a large neuromorphic few-shot learning dataset. It reconstructs the strokes of Omniglot as videos and uses a Davis346 camera to capture the writing of the characters. The recordings can be displayed using the DV software's playback function (https://inivation.gitlab.io/dv/dv-docs/docs/getting-started.html). N-Omniglot is sparse and has little similarity between frames. It can be used for event-driven pattern recognition, few-shot learning and stroke generation. 6 | 7 | It is a neuromorphic event dataset composed of 1623 handwritten characters recorded with the neuromorphic camera Davis346. Each character has handwritten samples from 20 different participants. The file structure and a sample are shown in the corresponding PNG files in `README.assets`. 
8 | 9 | **The raw data can be found at https://doi.org/10.6084/m9.figshare.16821427.** 10 | 11 | 12 | 13 | ## Structure 14 | 15 | ![filestruct_00.png](README.assets/filestruct_00.png) ![sample_00.png](README.assets/sample_00.png) 16 | 17 | 18 | 19 | ## How to use N-Omniglot 20 | 21 | We also provide an interface to this dataset in `data_loader` so that users can easily use it in their own PyTorch applications. Python 3 is recommended. 22 | 23 | - NOmniglot.py: basic dataset 24 | - nomniglot_full.py: full train and test loaders, for training the SCNN directly 25 | - nomniglot_pair.py: paired train and test loaders, for the Siamese Net 26 | - nomniglot_nw_ks.py: n-way k-shot sampling, for MAML 27 | - utils.py: helper functions 28 | 29 | 30 | 31 | As with `DVS-Gesture`, each N-Omniglot raw file contains 20 samples of event information. The `NOmniglot` class first splits the N-Omniglot dataset into single samples and stores them in the `events_npy` folder for long-term use (see [SpikingJelly](https://github.com/fangwei123456/spikingjelly)). The event data is then encoded into event frames according to the chosen parameters. The main parameters are the frame number and the data type. The `event` type outputs event frames built with the `OR` operation, and the `frequency` type outputs the firing rate of each pixel. 
32 | 33 | Before running this code, install the following packages: 34 | 35 | pip install dv 36 | pip install pandas 37 | torch 38 | torchvision >= 0.8.1 39 | 40 | 41 | 42 | - #### use `nomniglot_full`: 43 | 44 | ```python 45 | db_train = NOmniglotfull('./data/', train=True, frames_num=4, data_type='frequency') 46 | dataloadertrain = DataLoader(db_train, batch_size=16, shuffle=True, num_workers=16, pin_memory=True) 47 | for x, y in dataloadertrain: 48 | print(x.shape) 49 | ``` 50 | 51 | 52 | 53 | - #### use `nomniglot_pair`: 54 | 55 | ```python 56 | data_type = 'frequency' 57 | T = 4 58 | trainSet = NOmniglotTrain(root='data/', use_frame=True, frames_num=T, data_type=data_type, use_npz=True, resize=105) 59 | testSet = NOmniglotTestSet(root='data/', time=1000, way=5, shot=1, use_frame=True, frames_num=T, data_type=data_type, use_npz=True, resize=105) 60 | trainLoader = DataLoader(trainSet, batch_size=48, shuffle=False, num_workers=4) 61 | testLoader = DataLoader(testSet, batch_size=5 * 1, shuffle=False, num_workers=4) 62 | for batch_id, (img1, img2) in enumerate(testLoader, 1): 63 | # img1.shape [batch, T, 2, H, W] 64 | print(batch_id) 65 | break 66 | 67 | for batch_id, (img1, img2, label) in enumerate(trainLoader, 1): 68 | # img1.shape [batch, T, 2, H, W] 69 | print(batch_id) 70 | break 71 | ``` 72 | 73 | 74 | 75 | - #### use `nomniglot_nw_ks`: 76 | 77 | ```python 78 | db_train = NOmniglotNWayKShot('./data/', n_way=5, k_shot=1, k_query=15, 79 | frames_num=4, data_type='frequency', train=True) 80 | dataloadertrain = DataLoader(db_train, batch_size=16, shuffle=True, num_workers=16, pin_memory=True) 81 | for x_spt, y_spt, x_qry, y_qry in dataloadertrain: 82 | print(x_spt.shape) 83 | db_train.reset() 84 | ``` 85 | 86 | 87 | 88 | 89 | 90 | ## Experiment 91 | 92 | ![method.png](README.assets/method.png) 93 | 94 | We provide four modified SNN-appropriate few-shot learning methods in `examples` as a benchmark for the N-Omniglot dataset. 
def forward(self, x_spt, y_spt, x_qry, y_qry):
    """
    Run one meta-training step over a batch of tasks.

    For every task a deep copy of ``self.net`` is adapted on the support set
    with plain SGD, then evaluated on the query set. The query losses are
    averaged over ``self.task_num`` and back-propagated before
    ``self.optimmeta.step()`` is called.

    :param x_spt:   [b, setsz, c_, h, w]
    :param y_spt:   [b, setsz]
    :param x_qry:   [b, querysz, c_, h, w]
    :param y_qry:   [b, querysz]
    :return: numpy array of length ``self.update_step``. Only entry 0 is ever
        written: it holds the number of correct query predictions summed over
        all tasks, divided by ``k_qry * n_way * b``. The other entries stay 0.
    """

    criterion = nn.CrossEntropyLoss()
    grad_q = None  # never read anywhere below

    corrects = [0 for _ in range(self.update_step)]
    lossfull = None
    for i in range(x_spt.shape[0]):
        # Support-set gradient of the current meta-parameters.
        logits = self.net(x_spt[i])
        loss = criterion(logits, y_spt[i])
        self.optimmeta.zero_grad()
        loss.backward()

        # Task-specific copy of the network. The state dict built from cloned
        # parameters is loaded on top of the deep copy (strict=False).
        nettemp = deepcopy(self.net)
        paratemp = {i: torch.clone(j) for i, j in self.net.named_parameters()}
        nettemp.load_state_dict(paratemp, strict=False)

        optimtemp = optim.SGD(nettemp.parameters(), lr=self.update_lr)
        # Point the copy's .grad at the very same tensors as self.net's .grad,
        # so the step below applies the support gradient computed above.
        for g, p in zip(nettemp.parameters(), self.net.parameters()):
            g.grad = p.grad
        optimtemp.step()  # fast weights updated
        # Further inner-loop SGD steps on the support set.
        for j in range(self.update_step):
            logits = nettemp(x_spt[i])
            loss = criterion(logits, y_spt[i])
            optimtemp.zero_grad()
            loss.backward()
            optimtemp.step()

        # Query loss of the adapted copy, accumulated as a mean over tasks.
        logits = nettemp(x_qry[i])
        loss = criterion(logits, y_qry[i])
        optimtemp.zero_grad()
        if lossfull is None:
            lossfull = loss / self.task_num
        else:
            lossfull += loss / self.task_num

        with torch.no_grad():
            pred_q = F.softmax(logits, dim=1).argmax(dim=1)
            # Labels are compared as given (an earlier variant took an argmax
            # over them first).
            q = y_qry[i]
            correct = torch.eq(pred_q, q).sum().item()  # Python number of hits
            corrects[0] = corrects[0] + correct

    # NOTE(review): lossfull is built from forward passes of the nettemp
    # copies, so backward() differentiates with respect to their parameters.
    # Whether any gradient reaches self.net's parameters (the ones optimmeta
    # updates) seems to depend on the .grad tensors aliased above surviving
    # the zero_grad() calls, which differs between PyTorch versions.
    # TODO confirm the meta-update is effective on the installed version.
    self.optimmeta.zero_grad()
    lossfull.backward()
    self.optimmeta.step()

    accs = np.array(corrects) / (self.k_qry * self.n_way * x_spt.shape[0])

    return accs
if __name__ == '__main__':
    # Training / evaluation script for the spiking Siamese network.
    # All settings come from gflags; T and data_type are fixed below.
    Flags = gflags.FLAGS
    gflags.DEFINE_bool("cuda", True, "use cuda")
    gflags.DEFINE_integer("way", 5, "how much way one-shot learning")
    gflags.DEFINE_integer("shot", 1, "how much shot few-shot learning")
    # "time" is a count: NOmniglotTestSet multiplies it by way * shot to get
    # its length, so it has to be parsed as an integer. Declared as a string
    # flag, a value given on the command line would arrive as str.
    gflags.DEFINE_integer("time", 2000, "number of samples to test accuracy")
    gflags.DEFINE_integer("workers", 8, "number of dataLoader workers")
    gflags.DEFINE_integer("batch_size", 64, "number of batch size")
    gflags.DEFINE_float("lr", 0.0001, "learning rate")
    gflags.DEFINE_integer("show_every", 100, "show result after each show_every iter.")
    gflags.DEFINE_integer("save_every", 100, "save model after each save_every iter.")
    gflags.DEFINE_integer("test_every", 10000, "test model after each test_every iter.")
    gflags.DEFINE_integer("max_iter", 40000, "number of iterations before stopping")
    gflags.DEFINE_string("model_path", "./", "path to store model")
    gflags.DEFINE_string("gpu_ids", "0", "gpu ids used to train")

    Flags(sys.argv)
    T = 4  # number of event frames per sample
    data_type = 'event'  # or 'frequency'

    os.environ["CUDA_VISIBLE_DEVICES"] = Flags.gpu_ids
    device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
    print("use gpu:", Flags.gpu_ids, "to train.")
    print("way:%d, shot: %d" % (Flags.way, Flags.shot))

    # Fixed seeds for torch; the data loaders draw from Python's `random`.
    seed = 346
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    trainSet = NOmniglotTrainSet(root='../../data/', use_frame=True, frames_num=T, data_type=data_type,
                                 use_npz=True, resize=105)
    testSet = NOmniglotTestSet(root='../../data/', time=Flags.time, way=Flags.way, shot=Flags.shot, use_frame=True,
                               frames_num=T, data_type=data_type, use_npz=True, resize=105)
    trainLoader = DataLoader(trainSet, batch_size=Flags.batch_size, shuffle=False, num_workers=Flags.workers)
    # One test batch holds the way * shot support pairs of a single query.
    testLoader = DataLoader(testSet, batch_size=Flags.way * Flags.shot, shuffle=False, num_workers=Flags.workers)

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction='mean')
    net = SSiamese(device=device)

    # multi gpu
    if len(Flags.gpu_ids.split(",")) > 1:
        net = torch.nn.DataParallel(net)

    if Flags.cuda:
        net.cuda()

    net.train()

    optimizer = torch.optim.Adam(net.parameters(), lr=Flags.lr)
    optimizer.zero_grad()

    loss_val = 0
    time_start = time.time()
    queue = deque(maxlen=20)  # accuracies of the most recent evaluations

    best = 0
    for batch_id, (img1, img2, label) in enumerate(trainLoader, 1):
        # Move the frame axis first and keep only polarity channel 0.
        img1 = img1.permute(1, 0, 2, 3, 4)[:, :, 0, :, :].unsqueeze(2)
        img2 = img2.permute(1, 0, 2, 3, 4)[:, :, 0, :, :].unsqueeze(2)
        torch.cuda.empty_cache()
        if batch_id > Flags.max_iter:
            break
        optimizer.zero_grad()
        output = net.forward(img1.to(device), img2.to(device), batch_size=Flags.batch_size, time_window=T)
        loss = loss_fn(output, label.to(device))
        loss_val += loss.cpu().item()
        loss.backward()
        optimizer.step()
        if batch_id % Flags.show_every == 0:
            print('[%d]\tloss:\t%.5f\ttime lapsed:\t%.2f s' % (
                batch_id, loss_val / Flags.show_every, time.time() - time_start))
            loss_val = 0
            time_start = time.time()
        # if batch_id % Flags.save_every == 0:
        #     torch.save(net.state_dict(), Flags.model_path + '/model' + ".pt")
        if batch_id % Flags.test_every == 0 or (batch_id > 30000 and batch_id % 200 == 0):
            right, error = 0, 0
            for test1, test2 in testLoader:
                test1 = test1.permute(1, 0, 2, 3, 4)[:, :, 0, :, :].unsqueeze(2)
                test2 = test2.permute(1, 0, 2, 3, 4)[:, :, 0, :, :].unsqueeze(2)
                if Flags.cuda:
                    test1, test2 = test1.cuda(), test2.cuda()
                with torch.no_grad():
                    output = net.forward(test1, test2, batch_size=Flags.way * Flags.shot,
                                         time_window=T).data.cpu().numpy()
                # The first `shot` pairs of a batch come from the query's own
                # class, so the prediction is right when the best score is
                # among them.
                pred = np.argmax(output)
                if pred < Flags.shot:
                    right += 1
                else:
                    error += 1
            accuracy = right * 1.0 / (right + error)
            print('*' * 70)
            print('[%d]\tTest set\tcorrect:\t%d\terror:\t%d\tprecision:\t%f' % (batch_id, right, error, accuracy))
            print('*' * 70)
            queue.append(accuracy)

            # A perfect score is deliberately not recorded as the best.
            if queue[-1] > best and queue[-1] != 1.0:
                best = queue[-1]
                # torch.save(net.state_dict(), Flags.model_path + '/model' + ".pt")
                print("data_type: %s, %d-way %d-shot, gpu %s, best:%.4f" % (
                    data_type, Flags.way, Flags.shot, Flags.gpu_ids, best))

    # Average over the evaluations actually recorded (at most the last 20)
    # rather than always dividing by 20.
    acc = sum(queue) / len(queue) if queue else 0.0
    print("#" * 70)
    print("final accuracy: ", acc)
    print("best accuracy:", best)
def lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=50):
    """Scale every param group's learning rate by 0.1 on decay epochs.

    A decay happens when ``epoch`` is a multiple of ``lr_decay_epoch`` and is
    greater than 1. ``init_lr`` is accepted but not used. The optimizer is
    modified in place and returned.
    """
    on_decay_epoch = epoch % lr_decay_epoch == 0
    if not on_decay_epoch or epoch <= 1:
        return optimizer

    for group in optimizer.param_groups:
        group['lr'] = 0.1 * group['lr']
    return optimizer
class LIFLinear(nn.Module):
    """Linear layer followed by a stateful membrane/spike update.

    The membrane potential ``mem`` and the last spike tensor ``spike`` are
    kept on the module between calls, so ``reset()`` must be called between
    independent sequences.

    :param in_planes: input feature size of the linear layer
    :param out_planes: output feature size of the linear layer
    :param decay: factor applied to the previous membrane potential
    :param last_layer: when True the membrane simply accumulates the input
        (no decay, no reset by the previous spike)
    """

    def __init__(self, in_planes, out_planes, decay=0.2, last_layer=False):
        super().__init__()
        self.fc = nn.Linear(in_planes, out_planes)
        self.mem = self.spike = None
        self.decay = decay
        self.last_layer = last_layer

    def mem_update(self, x):
        """Integrate ``x`` into the membrane and return the new spikes."""
        if self.mem is None:
            # Allocate the state next to the input (same device and dtype).
            # Forcing the module-level `device` here breaks the arithmetic
            # below whenever the data lives on a different device.
            self.mem = torch.zeros_like(x)
            self.spike = torch.zeros_like(x)
        if self.last_layer:
            self.mem = self.mem + x
        else:
            # A neuron that spiked on the previous step restarts from zero.
            self.mem = self.mem * self.decay * (1. - self.spike) + x
        self.spike = act_fun(self.mem)  # act_fun : approximation firing function
        return self.spike

    def forward(self, x):
        """Apply the linear layer, then the membrane update."""
        x = self.fc(x)
        x = self.mem_update(x)
        return x

    def reset(self):
        """Forget membrane and spike state."""
        self.mem = self.spike = None
class NOmniglotTrainSet(Dataset):
    '''
    Pair sampler used to train the Siamese net.

    Each item is (sample_a, sample_b, label). The label is 1 when both samples
    were drawn from the same class and 0 when they come from two different
    classes.
    '''

    def __init__(self, root='data/', use_frame=True, frames_num=10, data_type='event', use_npz=True, resize=None):
        super(NOmniglotTrainSet, self).__init__()
        self.use_frame = use_frame
        self.data_type = data_type
        self.resize = resize
        self.dataSet = NOmniglot(root=root, train=True, frames_num=frames_num, data_type=data_type, use_npz=use_npz)
        self.datas = self.dataSet.datadict
        self.num_classes = self.dataSet.num_classes

        # NOTE: this seeds NumPy's global RNG; the sampling in __getitem__
        # draws from Python's `random` module.
        np.random.seed(0)

    def __len__(self):
        '''
        Nominal number of pairs. Pairs are drawn at random on every access, so
        this is only an upper bound on how often the set can be sampled.
        '''
        return 21000000

    def __getitem__(self, index):
        top = self.num_classes - 1
        same_class = index % 2 == 1

        if same_class:
            # odd index: two draws from one class
            label = 1.0
            idx1 = random.randint(0, top)
            image1 = random.choice(self.datas[idx1])
            image2 = random.choice(self.datas[idx1])
        else:
            # even index: one draw from each of two distinct classes
            label = 0.0
            idx1 = random.randint(0, top)
            idx2 = random.randint(0, top)
            while idx1 == idx2:
                idx2 = random.randint(0, top)
            image1 = random.choice(self.datas[idx1])
            image2 = random.choice(self.datas[idx2])

        if self.use_frame:
            # Both supported data types are stored the same way on disk.
            if self.data_type not in ('event', 'frequency'):
                raise NotImplementedError
            image1 = torch.tensor(np.load(image1)['arr_0']).float()
            image2 = torch.tensor(np.load(image2)['arr_0']).float()

        if self.resize is not None:
            # Crop a 250 x 250 window, then rescale it to resize x resize.
            target = (self.resize, self.resize)
            image1 = F.interpolate(image1[:, :, 4:254, 54:304], size=target)
            image2 = F.interpolate(image2[:, :, 4:254, 54:304], size=target)

        label_tensor = torch.from_numpy(np.array([label], dtype=np.float32))
        return image1, image2, label_tensor
94 | Since one test sample is collected at a time, way * shot support samples are used for each test 95 | ''' 96 | return self.time * self.way * self.shot 97 | 98 | def __getitem__(self, index): 99 | ''' 100 | The 0th sample of each way*shot is used for query and recorded in the selected sample 101 | to achieve the effect of selecting K +1 102 | ''' 103 | idx = index % (self.way * self.shot) 104 | # generate image pair from same class 105 | if idx == 0: # 106 | self.select_class = [] 107 | self.c1 = random.randint(0, self.num_classes - 1) 108 | self.c2 = self.c1 109 | sind = random.randint(0, len(self.datas[self.c1]) - 1) 110 | self.select_sample.append(sind) 111 | self.img1 = self.datas[self.c1][sind] 112 | 113 | sind = random.randint(0, len(self.datas[self.c2]) - 1) 114 | while sind in self.select_sample: 115 | sind = random.randint(0, len(self.datas[self.c2]) - 1) 116 | img2 = self.datas[self.c1][sind] 117 | self.select_sample.append(sind) 118 | self.select_class.append(self.c1) 119 | # generate image pair from different class 120 | else: 121 | if index % self.shot == 0: 122 | self.c2 = random.randint(0, self.num_classes - 1) 123 | while self.c2 in self.select_class: # self.c1 == c2: 124 | self.c2 = random.randint(0, self.num_classes - 1) 125 | self.select_class.append(self.c2) 126 | self.select_sample = [] 127 | sind = random.randint(0, len(self.datas[self.c2]) - 1) 128 | while sind in self.select_sample: 129 | sind = random.randint(0, len(self.datas[self.c2]) - 1) 130 | img2 = self.datas[self.c2][sind] 131 | self.select_sample.append(sind) 132 | 133 | if self.use_frame: 134 | if self.data_type == 'event': 135 | img1 = torch.tensor(np.load(self.img1)['arr_0']).float() 136 | img2 = torch.tensor(np.load(img2)['arr_0']).float() 137 | elif self.data_type == 'frequency': 138 | img1 = torch.tensor(np.load(self.img1)['arr_0']).float() 139 | img2 = torch.tensor(np.load(img2)['arr_0']).float() 140 | else: 141 | raise NotImplementedError 142 | 143 | if self.resize is not 
class SurrGradSpike(torch.autograd.Function):
    """
    Spiking nonlinearity with a surrogate gradient.

    The forward pass is a hard step at zero. The backward pass replaces the
    step's derivative with 1 / (scale * |input| + 1) ** 2, i.e. the
    normalized negative part of a fast sigmoid, as in Zenke & Ganguli (2018).
    """

    scale = 100.0  # controls steepness of surrogate gradient

    @staticmethod
    def forward(ctx, input):
        """
        Return 1 where ``input`` is strictly positive and 0 elsewhere.

        The input is stashed on ``ctx`` because the backward pass needs it.
        """
        ctx.save_for_backward(input)
        return torch.where(input > 0, torch.ones_like(input), torch.zeros_like(input))

    @staticmethod
    def backward(ctx, grad_output):
        """
        Scale the incoming gradient by the surrogate derivative evaluated at
        the saved forward input.
        """
        (saved_input,) = ctx.saved_tensors
        steepness = SurrGradSpike.scale * saved_input.abs() + 1.0
        return grad_output / steepness ** 2
class LIFConv(nn.Module):
    """2-D convolution followed by a stateful membrane/spike update.

    The membrane potential ``mem`` and the last spike tensor ``spike`` are
    kept on the module between calls, so ``reset()`` must be called between
    independent sequences.

    :param in_planes: input channels of the convolution
    :param out_planes: output channels of the convolution
    :param kernel_size: convolution kernel size
    :param stride: convolution stride
    :param padding: convolution padding
    :param decay: factor applied to the previous membrane potential
    :param last_layer: when True the membrane simply accumulates the input
        (no decay, no reset by the previous spike)
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding, decay=0.2, last_layer=False):
        super().__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding)
        self.mem = self.spike = None
        self.decay = decay
        self.last_layer = last_layer

    def mem_update(self, x):
        """Integrate ``x`` into the membrane and return the new spikes."""
        if self.mem is None:
            # Allocate the state next to the input (same device and dtype).
            # Forcing the module-level `device` here breaks the arithmetic
            # below whenever the data lives on a different device.
            self.mem = torch.zeros_like(x)
            self.spike = torch.zeros_like(x)
        if self.last_layer:
            self.mem = self.mem + x
        else:
            # A neuron that spiked on the previous step restarts from zero.
            self.mem = self.mem * self.decay * (1. - self.spike) + x
        self.spike = act_fun(self.mem)  # act_fun : approximation firing function
        return self.spike

    def forward(self, x):
        """Apply the convolution, then the membrane update."""
        x = self.conv(x)
        x = self.mem_update(x)
        return x

    def reset(self):
        """Forget membrane and spike state."""
        self.mem = self.spike = None
class SCNN(nn.Module):
    """Two LIFConv layers and two LIFLinear layers, run for several time steps.

    At each step the input is compared against fresh uniform noise to produce
    a binary spike pattern, which is pushed through the network. The result is
    the membrane potential accumulated by the last layer, divided by the
    number of steps.

    :param device: kept for interface compatibility and stored on the module.
        The noise itself is sampled on the input's device.
    """

    def __init__(self, device):
        super(SCNN, self).__init__()
        self.device = device
        self.conv1 = LIFConv(1, 15, kernel_size=5, stride=1, padding=0)
        self.conv2 = LIFConv(15, 40, kernel_size=5, stride=1, padding=0)
        self.fc1 = LIFLinear(640, 300)
        self.fc2 = LIFLinear(300, 20, 0.2, True)

    def forward(self, input, time_window=12):
        """
        :param input: batch of inputs; each value is used as the chance of
            emitting a spike at a step (compared against uniform noise in [0, 1))
        :param time_window: number of simulation steps
        :return: last-layer membrane potential averaged over the steps
        """
        for step in range(time_window):  # simulation time steps
            # Sample the noise where the input lives. The module-level
            # `device` used before ignored both the constructor argument and
            # the actual placement of the data.
            noise = torch.rand(input.size(), device=input.device)
            x = input > noise  # prob. firing

            x = self.conv1(x.float())
            x = F.avg_pool2d(x, 2)
            x = self.conv2(x)
            x = F.avg_pool2d(x, 2)

            x = x.view(input.shape[0], -1)
            x = self.fc1(x)
            x = self.fc2(x)

        outputs = self.fc2.mem / time_window
        # Clear the per-layer state so the next call starts fresh.
        self.reset()
        return outputs

    def reset(self):
        """Reset membrane/spike state of every child layer."""
        for layer in self.children():
            layer.reset()
def integrate_events_to_frames(events, height, width, frames_num=10, data_type='event'):
    """Bin an event stream into ``frames_num`` two-channel frames.

    :param events: mapping with equally long arrays under ``'t'`` (timestamps,
        ascending), ``'x'``, ``'y'`` (pixel coordinates) and ``'p'`` (polarity).
        The arrays are not modified.
    :param height: frame height in pixels
    :param width: frame width in pixels
    :param frames_num: number of frames the recording is split into by time
    :param data_type: ``'event'`` returns boolean frames (True where at least
        one event fell on the pixel); any other value returns the output of
        ``normalize_frame(frames, 'max')``, and for ``'frequency'`` the counts
        are first divided by the frame duration
    :return: array of shape ``(frames_num, 2, height, width)``; channel 0
        holds events with ``p == 0``, channel 1 all others
    :raises AssertionError: if the recording is not longer than ``frames_num``
        time units
    """
    frames = np.zeros(shape=[frames_num, 2, height * width])

    # First / one-past-last event index of every frame (-1 marks "no events").
    j_l = np.zeros(shape=[frames_num], dtype=int)
    j_r = np.zeros(shape=[frames_num], dtype=int)

    # Split by time. Rebase the timestamps on a new array so the caller's
    # data is not modified.
    t = events['t'] - events['t'][0]  # start with 0 timestamp
    assert t[-1] > frames_num
    dt = t[-1] // frames_num  # get length of each frame
    idx = np.arange(t.size)
    for i in range(frames_num):
        t_l = dt * i
        t_r = t_l + dt
        mask = np.logical_and(t >= t_l, t < t_r)
        idx_masked = idx[mask]
        if len(idx_masked) == 0:
            j_l[i] = -1
            j_r[i] = -1
        else:
            j_l[i] = idx_masked[0]
            # The last frame also takes every event after its nominal end.
            j_r[i] = idx_masked[-1] + 1 if i < frames_num - 1 else t.size

    for i in range(frames_num):
        if j_l[i] >= 0:
            x = events['x'][j_l[i]:j_r[i]]
            y = events['y'][j_l[i]:j_r[i]]
            p = events['p'][j_l[i]:j_r[i]]
            off = p == 0
            for j, polarity_mask in enumerate((off, np.logical_not(off))):
                # Count the events per flattened pixel position.
                position = y[polarity_mask] * width + x[polarity_mask]
                events_number_per_pos = np.bincount(position)
                frames[i][j][np.arange(events_number_per_pos.size)] += events_number_per_pos

        if data_type == 'frequency':
            # Counts -> rate. The last frame is longer by the remainder that
            # the integer split left over.
            if i < frames_num - 1:
                frames[i] /= dt
            else:
                frames[i] /= (dt + t[-1] % frames_num)
    frames = frames.astype(np.float16)

    if data_type == 'event':
        # `np.bool` was removed from NumPy; the builtin `bool` gives the same dtype.
        frames = (frames > 0).astype(bool)
    else:
        frames = normalize_frame(frames, 'max')
    return frames.reshape((frames_num, 2, height, width))
95 | Iterate through all event data in eventS_date_DIR and generate frame data files in frames_data_DIR 96 | """ 97 | def read_function(file_name): 98 | return np.load(file_name, allow_pickle=True).item() 99 | 100 | def cvt_fun(events_file_list): 101 | for events_file in events_file_list: 102 | print(events_file) 103 | frames = integrate_events_to_frames(read_function(events_file), 260, 346, frames_num, result_type ) 104 | if compress: 105 | frames_file = os.path.join(frames_data_dir, 106 | os.path.basename(events_file)[0: -suffix.__len__()] + '.npz') 107 | np.savez_compressed(frames_file, frames) 108 | else: 109 | frames_file = os.path.join(frames_data_dir, 110 | os.path.basename(events_file)[0: -suffix.__len__()] + '.npy') 111 | np.save(frames_file, frames) 112 | 113 | # Obtain the path of the all files 114 | events_file_list = list_all_files(events_data_dir, '.npy') 115 | 116 | if thread_num == 1: 117 | cvt_fun(events_file_list) 118 | else: 119 | # Multithreading acceleration 120 | thread_list = [] 121 | block = events_file_list.__len__() // thread_num 122 | for i in range(thread_num - 1): 123 | thread_list.append(FunctionThread(cvt_fun, events_file_list[i * block: (i + 1) * block])) 124 | thread_list[-1].start() 125 | print(f'thread {i} start, processing files index: {i * block} : {(i + 1) * block}.') 126 | thread_list.append(FunctionThread(cvt_fun, events_file_list[(thread_num - 1) * block:])) 127 | thread_list[-1].start() 128 | print( 129 | f'thread {thread_num} start, processing files index: {(thread_num - 1) * block} : {events_file_list.__len__()}.') 130 | for i in range(thread_num): 131 | thread_list[i].join() 132 | print(f'thread {i} finished.') 133 | 134 | 135 | def convert_aedat4_dir_to_events_dir(root, train): 136 | kind = 'background' if train else "evaluation" 137 | originroot = root 138 | root = root + '/dvs_' + kind + '/' 139 | alphabet_names = [a for a in os.listdir(root) if a[0] != '.'] # get folder names 140 | 141 | for a in 
range(len(alphabet_names)): 142 | alpha_name = alphabet_names[a] 143 | 144 | for b in range(len(os.listdir(os.path.join(root, alpha_name)))): 145 | character_id = b + 1 146 | character_path = alpha_name + '/character' + num2str(character_id) 147 | print('Parsing %s \\ character%s ...' % (alpha_name, num2str(character_id))) 148 | 149 | file_path = os.path.join(root, character_path) 150 | aedat4_name = [a for a in os.listdir(file_path) if a[-4:] == 'dat4' and len(a) == 11][0] 151 | csv_name = [a for a in os.listdir(file_path) if a[-4:] == '.csv' and len(a) == 8][0] 152 | number = csv_name[:4] 153 | new_path = originroot + '/events_npy/' + kind + '/' + alpha_name + '/character' + num2str(character_id) 154 | if not os.path.exists(new_path): 155 | os.makedirs(new_path) 156 | 157 | start_end_timestamp = pandas.read_csv(os.path.join(file_path, csv_name)).values 158 | 159 | a_timestamp, a_polarity, a_x, a_y = [], [], [], [] 160 | with AedatFile(os.path.join(file_path, aedat4_name)) as f: # read aedat4 161 | for e in f['events']: 162 | a_timestamp.append(e.timestamp) 163 | a_polarity.append(e.polarity) 164 | a_x.append(e.x) 165 | a_y.append(e.y) 166 | 167 | for ii in range(20): # each file has 20 samples 168 | name = str(number) + '_' + num2str(ii + 1) + '.npy' 169 | start_index = a_timestamp.index(start_end_timestamp[ii][1]) 170 | end_index = a_timestamp.index(start_end_timestamp[ii][2]) 171 | tmp = {'t': np.array(a_timestamp[start_index:end_index]), 172 | 'x': np.array(a_x[start_index:end_index]), 173 | 'y': np.array(a_y[start_index:end_index]), 174 | 'p': np.array(a_polarity[start_index:end_index])} 175 | np.save(os.path.join(new_path, name), tmp) 176 | 177 | 178 | def num2str(idx): 179 | if idx < 10: 180 | return '0' + str(idx) 181 | return str(idx) 182 | 183 | 184 | def list_all_files(root, suffix, getlen=False): 185 | ''' 186 | List the path of all files under root, output a list 187 | ''' 188 | file_list = [] 189 | alphabet_names = [a for a in os.listdir(root) if 
a[0] != '.'] # get folder names 190 | idx = 0 191 | for a in range(len(alphabet_names)): 192 | alpha_name = alphabet_names[a] 193 | for b in range(len(os.listdir(os.path.join(root, alpha_name)))): 194 | character_id = b + 1 195 | character_path = os.path.join(root, alpha_name, 'character' + num2str(character_id)) 196 | idx += 1 197 | for c in range(len(os.listdir(character_path))): 198 | fn_example = os.listdir(character_path)[c] 199 | if fn_example[-4:] == suffix: 200 | file_list.append(os.path.join(character_path, fn_example)) 201 | if getlen: 202 | return file_list, idx 203 | else: 204 | return file_list 205 | 206 | 207 | def list_class_files(root, frames_kind_root, getlen=False, use_npz=False): 208 | ''' 209 | index the generated samples, 210 | get dictionaries according to categories, each corresponding to a list, 211 | the list contain the address of the new file in fnum_x_dtype_x_npz_True 212 | ''' 213 | file_list = {} 214 | alphabet_names = [a for a in os.listdir(root) if a[0] != '.'] # get folder names 215 | idx = 0 216 | for a in range(len(alphabet_names)): 217 | alpha_name = alphabet_names[a] 218 | for b in range(len(os.listdir(os.path.join(root, alpha_name)))): 219 | character_id = b + 1 220 | character_path = os.path.join(root, alpha_name, 'character' + num2str(character_id)) 221 | file_list[idx] = [] 222 | for c in range(len(os.listdir(character_path))): 223 | fn_example = os.listdir(character_path)[c] 224 | if use_npz: 225 | fn_example = fn_example[:-1] + 'z' 226 | file_list[idx].append(os.path.join(frames_kind_root, fn_example)) 227 | idx += 1 228 | if getlen: 229 | return file_list, idx 230 | else: 231 | return file_list 232 | --------------------------------------------------------------------------------