├── README.assets
│   ├── method.png
│   ├── sample_00.png
│   └── filestruct_00.png
├── examples
│   ├── LSTM
│   │   ├── meta.py
│   │   ├── model.py
│   │   ├── learner.py
│   │   └── main.py
│   ├── CSNN
│   │   ├── meta.py
│   │   ├── main.py
│   │   └── learner.py
│   ├── NearestNeighbor
│   │   └── main.py
│   ├── SSiamese
│   │   ├── model.py
│   │   └── main.py
│   └── SMAML
│       ├── main.py
│       ├── meta.py
│       └── learner.py
├── data_loader
│   ├── nomniglot_full.py
│   ├── NOmniglot.py
│   ├── nomniglot_nw_ks.py
│   └── nomniglot_pair.py
├── README.md
└── utils.py
/README.assets/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brain-Cog-Lab/N-Omniglot/HEAD/README.assets/method.png
--------------------------------------------------------------------------------
/README.assets/sample_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brain-Cog-Lab/N-Omniglot/HEAD/README.assets/sample_00.png
--------------------------------------------------------------------------------
/README.assets/filestruct_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brain-Cog-Lab/N-Omniglot/HEAD/README.assets/filestruct_00.png
--------------------------------------------------------------------------------
/examples/LSTM/meta.py:
--------------------------------------------------------------------------------
1 | from torch import optim
2 | from examples.LSTM.learner import *
3 |
4 |
class Meta(nn.Module):
    """Supervised trainer around the conv + LSTM classifier ``net``.

    ``forward`` performs one SGD step on a batch and returns the batch
    accuracy; ``test`` evaluates a batch without touching the weights.
    """

    def __init__(self, args):
        """Build the network, its SGD optimizer and the loss.

        :param args: parsed command-line arguments (not used here).
        """
        super(Meta, self).__init__()

        self.net = net()
        self.optim = optim.SGD(self.net.parameters(), lr=0.1)

        # Identity matrix over the 1623 classes. Unused in this class but kept
        # for callers that may read it. Allocated on ``device`` rather than via
        # ``.cuda()`` so that construction also works on CPU-only machines.
        self.onehot = torch.eye(1623, device=device)
        self.lossfunction = nn.CrossEntropyLoss()

    @staticmethod
    def _accuracy(logits, y):
        """Return the fraction of rows whose arg-max logit equals ``y``."""
        # softmax is monotonic, so the arg-max of the raw logits is the prediction.
        pred_q = logits.argmax(dim=1)
        return torch.eq(pred_q, y).sum().item() / pred_q.shape[0]

    def forward(self, x, y):
        """Take one optimisation step on ``(x, y)``.

        :param x: input batch, passed unchanged to ``self.net``.
        :param y: integer class labels, one per logit row.
        :return: accuracy (float) of the logits computed before the step.
        """
        self.net.train()
        logits = self.net(x)
        loss = self.lossfunction(logits, y)

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

        return self._accuracy(logits, y)

    def test(self, x, y):
        """Return the accuracy (float) of ``self.net`` on ``(x, y)`` in eval mode."""
        self.net.eval()
        # Evaluation needs no autograd graph; skipping it saves memory and time.
        with torch.no_grad():
            logits = self.net(x)
        return self._accuracy(logits, y)
40 |
--------------------------------------------------------------------------------
/examples/CSNN/meta.py:
--------------------------------------------------------------------------------
1 | from torch import optim
2 | from examples.CSNN.learner import *
# Default compute device: the GPU when one is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
4 |
5 |
class Meta(nn.Module):
    """Supervised trainer around the spiking convolutional network ``SCNN``.

    ``forward`` performs one Adam step on a batch and returns the batch
    accuracy; ``test`` evaluates a batch without touching the weights.
    """

    def __init__(self, args):
        """Build the network, its Adam optimizer and the loss.

        :param args: parsed command-line arguments (not used here).
        """
        super(Meta, self).__init__()

        self.net = SCNN()
        self.optim = optim.Adam(self.net.parameters(), lr=0.002)

        # Identity matrix over the 1623 classes. Unused in this class but kept
        # for callers that may read it. Allocated on ``device`` rather than via
        # ``.cuda()`` so that construction also works on CPU-only machines.
        self.onehot = torch.eye(1623, device=device)
        self.lossfunction = nn.CrossEntropyLoss()

    @staticmethod
    def _accuracy(logits, y):
        """Return the fraction of rows whose arg-max logit equals ``y``."""
        # softmax is monotonic, so the arg-max of the raw logits is the prediction.
        pred_q = logits.argmax(dim=1)
        return torch.eq(pred_q, y).sum().item() / pred_q.shape[0]

    def forward(self, x, y):
        """Take one optimisation step on ``(x, y)``.

        :param x: input batch, passed unchanged to ``self.net``.
        :param y: integer class labels, one per logit row.
        :return: accuracy (float) of the logits computed before the step.
        """
        self.net.train()
        logits = self.net(x)
        loss = self.lossfunction(logits, y)

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

        return self._accuracy(logits, y)

    def test(self, x, y):
        """Return the accuracy (float) of ``self.net`` on ``(x, y)`` in eval mode."""
        self.net.eval()
        # Evaluation needs no autograd graph; skipping it saves memory and time.
        with torch.no_grad():
            logits = self.net(x)
        return self._accuracy(logits, y)
41 |
--------------------------------------------------------------------------------
/examples/LSTM/model.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | import torch
4 |
5 |
class LSTM(torch.nn.Module):
    """Single-layer LSTM written with explicit gate arithmetic.

    The packed weight columns are laid out as ``[f | i | o | g]`` blocks of
    ``hnum`` columns: forget, input and output gates (sigmoid) followed by the
    candidate cell value (tanh).
    """

    def __init__(self, inputnum, hnum):
        """
        :param inputnum: size of each input vector.
        :param hnum: size of the hidden and cell states.
        """
        super().__init__()
        self.hnum = hnum
        # Gaussian weights scaled by 1/sqrt(fan_in); bias drawn from U[0, 1).
        self.wx = torch.nn.Parameter(torch.randn(inputnum, 4 * hnum) / inputnum ** 0.5)
        self.wh = torch.nn.Parameter(torch.randn(hnum, 4 * hnum) / hnum ** 0.5)
        self.b = torch.nn.Parameter(torch.rand(1, 4 * hnum))

    def forward(self, fullx, h=None, c=None):
        """Run the LSTM over every time step of ``fullx``.

        Intermediates live in local variables rather than module attributes,
        so the autograd graph is released as soon as the caller drops the
        outputs instead of being pinned until the next call.

        :param fullx: input of shape ``[batch, L, inputnum]``.
        :param h: initial hidden state ``[batch, hnum]``; zeros when ``None``.
        :param c: initial cell state ``[batch, hnum]``; zeros when ``None``.
        :return: ``(fullh, h, c)`` with ``fullh`` of shape ``[batch, L, hnum]``
            and the final hidden/cell states (the initial ones when ``L == 0``).
        """
        batch, seq_len = fullx.shape[0], fullx.shape[1]
        hs = self.hnum
        if h is None:
            h = fullx.new_zeros(batch, hs)
        if c is None:
            c = fullx.new_zeros(batch, hs)

        outputs = []
        for step in range(seq_len):
            t = torch.mm(fullx[:, step, :], self.wx) + torch.mm(h, self.wh) + self.b
            gates = torch.sigmoid(t[:, :3 * hs])
            f_gate = gates[:, :hs]
            i_gate = gates[:, hs:2 * hs]
            o_gate = gates[:, 2 * hs:3 * hs]
            g = torch.tanh(t[:, 3 * hs:4 * hs])
            c = c * f_gate + g * i_gate
            h = torch.tanh(c) * o_gate
            outputs.append(h.unsqueeze(1))

        if not outputs:
            # torch.cat rejects an empty list; return an empty sequence instead.
            return fullx.new_zeros(batch, 0, hs), h, c
        return torch.cat(outputs, dim=1), h, c
39 |
40 |
if __name__ == "__main__":
    # Smoke test: run a 100-step sequence and back-propagate a gradient that
    # only touches the final output step, then inspect the recurrent weights.
    seq_len, in_dim, hid_dim = 100, 10, 3
    rnn = LSTM(in_dim, hid_dim)

    outputs, h_last, c_last = rnn(torch.ones([1, seq_len, in_dim]),
                                  torch.ones([1, hid_dim]),
                                  torch.ones([1, hid_dim]))
    print(h_last)

    grad_out = torch.zeros([1, seq_len, hid_dim])
    grad_out[0, -1, :] = 1
    print(grad_out)

    outputs.backward(grad_out)
    print(rnn.wh.grad)
51 |
--------------------------------------------------------------------------------
/data_loader/nomniglot_full.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | from data_loader.NOmniglot import NOmniglot
4 |
5 |
class OffsetLabel:
    """Picklable label transform: add ``offset``, then apply ``transform`` if given.

    Used instead of a ``lambda`` so the dataset can be pickled into
    ``DataLoader`` worker processes (spawn start method), and so that a
    user-supplied ``target_transform`` is composed rather than discarded.
    """

    def __init__(self, offset, transform=None):
        self.offset = offset
        self.transform = transform

    def __call__(self, label):
        label = label + self.offset
        if self.transform is not None:
            label = self.transform(label)
        return label


class NOmniglotfull(Dataset):
    '''
    Solve few-shot learning as a general classification problem.

    The original background (train) and evaluation (test) sets are combined;
    evaluation labels are shifted by ``EVAL_LABEL_OFFSET`` so both label
    ranges are disjoint. Inside every consecutive run of ``SAMPLES_PER_CLASS``
    items, the first ``TRAIN_PER_CLASS`` go to the training split and the rest
    to the test split (3/4 vs 1/4). This assumes items are stored class by
    class with 20 samples each, as ``NOmniglot`` lists them.
    '''

    EVAL_LABEL_OFFSET = 964  # presumably the number of background classes
    SAMPLES_PER_CLASS = 20
    TRAIN_PER_CLASS = 15

    def __init__(self, root='data/', train=True, frames_num=4, data_type='event',
                 transform=None, target_transform=None, use_npz=True, crop=True, create=True):
        """
        :param train: select the training (3/4) split if True, else the test (1/4) split.
        :param target_transform: optional label transform; for evaluation
            samples it is applied after the label offset.
        Other parameters are forwarded to both ``NOmniglot`` sets.
        """
        super().__init__()

        trainSet = NOmniglot(root=root, train=True, frames_num=frames_num, data_type=data_type,
                             transform=transform, target_transform=target_transform,
                             use_npz=use_npz, crop=crop, create=create)
        testSet = NOmniglot(root=root, train=False, frames_num=frames_num, data_type=data_type,
                            transform=transform,
                            target_transform=OffsetLabel(self.EVAL_LABEL_OFFSET, target_transform),
                            use_npz=use_npz, crop=crop, create=create)
        self.data = torch.utils.data.ConcatDataset([trainSet, testSet])
        self.id = self._split_indices(len(self.data), train)

    @classmethod
    def _split_indices(cls, n, train):
        """Return the indices in ``range(n)`` that belong to the requested split."""
        if train:
            return [j for j in range(n) if j % cls.SAMPLES_PER_CLASS < cls.TRAIN_PER_CLASS]
        return [j for j in range(n) if j % cls.SAMPLES_PER_CLASS >= cls.TRAIN_PER_CLASS]

    def __len__(self):
        return len(self.id)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair of the ``index``-th sample in this split."""
        image, label = self.data[self.id[index]]
        return image, label
34 |
35 |
if __name__ == '__main__':
    db_train = NOmniglotfull('../../data/', train=True, frames_num=4, data_type='event')
    dataloadertrain = DataLoader(db_train, batch_size=16, shuffle=True, num_workers=16, pin_memory=True)
    # Each item is one (image, label) pair, so a batch unpacks into exactly two
    # tensors; unpacking four values raised ValueError on the first batch.
    for x, y in dataloadertrain:
        print(x.shape, y.shape)
        break
--------------------------------------------------------------------------------
/examples/LSTM/learner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from examples.LSTM.model import *
5 |
# Default compute device: the GPU when one is visible, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
7 |
8 |
class net(torch.nn.Module):
    """Convolutional feature extractor followed by the hand-written ``LSTM``.

    Each time step passes through two conv/BN/ReLU/max-pool stages, is
    flattened and fed to the LSTM as a length-1 sequence; the LSTM output of
    the last step is mapped to 1623 class logits.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = torch.nn.Conv2d(2, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(32, eps=1e-5, momentum=0.1, affine=True)
        self.bn2 = torch.nn.BatchNorm2d(32, eps=1e-5, momentum=0.1, affine=True)
        self.relu1 = torch.nn.ReLU(True)
        self.relu2 = torch.nn.ReLU(True)
        self.maxpool1 = torch.nn.MaxPool2d(2, stride=2, padding=0)
        self.maxpool2 = torch.nn.MaxPool2d(2, stride=2, padding=0)
        # 1568 = 32 * 7 * 7: flattened features of a 28x28 frame after two 2x2 poolings.
        self.LSTM = LSTM(1568, 128)
        self.fc = torch.nn.Linear(128, 1623)

    def forward(self, input):
        """Classify a batch of frame sequences.

        :param input: tensor indexed as ``input[:, step]``; each slice is fed
            to ``conv1`` (2 input channels) after casting to float.
        :return: logits of shape ``[batch, 1623]`` from the final time step.
        :raises ValueError: if ``input`` has no time steps.
        """
        batch, steps = input.shape[0], input.shape[1]
        if steps == 0:
            raise ValueError("input must contain at least one time step")

        # Allocate the recurrent state on the input's device instead of the
        # module-level default, so CPU-only and multi-GPU runs do not mismatch.
        h = torch.zeros([batch, 128], device=input.device)
        c = torch.zeros([batch, 128], device=input.device)

        result = None
        for step in range(steps):  # simulation time steps
            result = self.conv1(input[:, step].float())
            result = self.bn1(result)
            result = self.relu1(result)
            result = self.maxpool1(result)

            result = self.conv2(result)
            result = self.bn2(result)
            result = self.relu2(result)
            result = self.maxpool2(result)

            result = result.reshape(batch, 1, -1)
            result, h, c = self.LSTM(result, h, c)
            result = result.reshape(batch, -1)

        return self.fc(result)
51 |
--------------------------------------------------------------------------------
/examples/NearestNeighbor/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from data_loader.nomniglot_pair import NOmniglotTestSet
3 | from torch.utils.data import DataLoader
4 | import numpy as np
5 | import gflags
6 | import sys
7 | import os
8 | from tqdm import tqdm
9 |
def nearest_neighbor_correct(test1, test2, shot):
    """Return True when the L2-nearest candidate is among the first ``shot`` rows.

    ``test1[k]`` is compared with ``test2[k]`` over all non-batch dimensions.
    The evaluation counts an index below ``shot`` as correct, i.e. the loader
    is presumably laying out the same-class candidates first — TODO confirm
    against ``NOmniglotTestSet``.
    """
    with torch.no_grad():
        dist = (test1 - test2).pow(2).flatten(1).sum(1).sqrt()
    return dist.argmin().item() < shot


if __name__ == '__main__':
    Flags = gflags.FLAGS
    gflags.DEFINE_bool("cuda", True, "use cuda")
    gflags.DEFINE_integer("way", 5, "how much way one-shot learning")
    gflags.DEFINE_integer("shot", 1, "how much shot few-shot learning")
    # Integer flag: a string value passed on the command line would break the test set.
    gflags.DEFINE_integer("time", 10000, "number of samples to test accuracy")
    gflags.DEFINE_integer("workers", 4, "number of dataLoader workers")
    gflags.DEFINE_string("gpu_ids", "0", "gpu ids used to train")
    Flags(sys.argv)
    T = 4
    data_type = 'event'

    os.environ["CUDA_VISIBLE_DEVICES"] = Flags.gpu_ids
    # Only move data to the GPU when one is actually available.
    use_cuda = Flags.cuda and torch.cuda.is_available()
    print("use gpu:", Flags.gpu_ids, "to train.")
    print("%d way %d shot" % (Flags.way, Flags.shot))

    testSet = NOmniglotTestSet(root='../../data/', time=Flags.time, way=Flags.way, shot=Flags.shot, use_frame=True,
                               frames_num=T,
                               data_type=data_type, use_npz=True, resize=105)
    testLoader = DataLoader(testSet, batch_size=Flags.way * Flags.shot, shuffle=False, num_workers=Flags.workers)

    right, error = 0, 0
    for test1, test2 in tqdm(testLoader):
        if use_cuda:
            test1, test2 = test1.cuda(), test2.cuda()
        if nearest_neighbor_correct(test1, test2, Flags.shot):
            right += 1
        else:
            error += 1
    print("#" * 70)
    total = right + error
    if total:
        print("final accuracy: {:.4f}".format(right / total * 100))
    else:
        print("no test episodes were evaluated")
45 |
--------------------------------------------------------------------------------
/examples/CSNN/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from data_loader.nomniglot_full import NOmniglotfull
4 | import argparse
5 | import torchvision
6 | from examples.CSNN.meta import Meta
7 |
8 |
def main(args):
    """Train the CSNN ``Meta`` wrapper on the combined N-Omniglot split.

    For ``args.epoch`` epochs: one training pass (batch size 1), then an
    evaluation pass; prints each pass's mean accuracy and the best so far.

    :param args: parsed CLI arguments providing ``seed`` and ``epoch``.
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    print(args)

    # Fall back to the CPU instead of failing when no GPU is visible.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Meta(args).to(device)
    TrainSet = NOmniglotfull(root=r'../../data/', train=True, frames_num=4, data_type='event',
                             transform=torchvision.transforms.Resize((28, 28)))
    TestSet = NOmniglotfull(root=r'../../data/', train=False, frames_num=4, data_type='event',
                            transform=torchvision.transforms.Resize((28, 28)))

    trainLoader = torch.utils.data.DataLoader(TrainSet, batch_size=1, shuffle=True, num_workers=0, pin_memory=True)
    # The order of test samples does not affect the mean accuracy; keep it deterministic.
    testLoader = torch.utils.data.DataLoader(TestSet, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)

    best = 0
    besttest = 0
    for step in range(args.epoch):
        accs = []
        for x, y in trainLoader:
            x, y = x.float().to(device), y.long().to(device)
            accs.append(model(x, y))

        accs = np.array(accs).mean(axis=0).astype(np.float16)
        if best < accs:
            best = accs
        print('\ttraining acc:', accs, "best", best)

        accstest = []
        # Evaluation needs no autograd graph.
        with torch.no_grad():
            for x, y in testLoader:
                x, y = x.float().to(device), y.long().to(device)
                accstest.append(model.test(x, y))

        accstest = np.array(accstest).mean(axis=0).astype(np.float16)
        if besttest < accstest:
            besttest = accstest
        print('Test acc:', accstest, "best", besttest)
49 |
50 |
if __name__ == '__main__':
    # Command-line entry point: parse the options and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=400000, help='epoch number')
    parser.add_argument('--seed', type=int, default=0, help='seed number')
    main(parser.parse_args())
59 |
--------------------------------------------------------------------------------
/examples/LSTM/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from data_loader.nomniglot_full import NOmniglotfull
4 | import argparse
5 | import torchvision
6 |
7 | from examples.LSTM.meta import Meta
8 |
9 |
def main(args):
    """Train the conv + LSTM ``Meta`` wrapper on the combined N-Omniglot split.

    For ``args.epoch`` epochs: one training pass (batch size 64), then an
    evaluation pass; prints each pass's mean accuracy and the best so far.

    :param args: parsed CLI arguments providing ``seed`` and ``epoch``.
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    print(args)

    # Fall back to the CPU instead of failing when no GPU is visible.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Meta(args).to(device)
    TrainSet = NOmniglotfull(root=r'../../data/', train=True, frames_num=4, data_type='event',
                             transform=torchvision.transforms.Resize((28, 28)))
    TestSet = NOmniglotfull(root=r'../../data/', train=False, frames_num=4, data_type='event',
                            transform=torchvision.transforms.Resize((28, 28)))

    trainLoader = torch.utils.data.DataLoader(TrainSet, batch_size=64, shuffle=True, num_workers=0, pin_memory=True)
    testLoader = torch.utils.data.DataLoader(TestSet, batch_size=64, shuffle=False, num_workers=0, pin_memory=True)

    best = 0
    besttest = 0
    for step in range(args.epoch):
        accs = []
        for x, y in trainLoader:
            x, y = x.float().to(device), y.long().to(device)
            accs.append(model(x, y))

        accs = np.array(accs).mean(axis=0).astype(np.float16)
        if best < accs:
            best = accs
        print('training acc:', accs, "best", best)

        accstest = []
        # Evaluation needs no autograd graph.
        with torch.no_grad():
            for x, y in testLoader:
                x, y = x.float().to(device), y.long().to(device)
                accstest.append(model.test(x, y))

        accstest = np.array(accstest).mean(axis=0).astype(np.float16)
        if besttest < accstest:
            besttest = accstest
        print('Test acc:', accstest, "best", besttest)
51 |
52 |
if __name__ == '__main__':
    # Command-line entry point: parse the options and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=400000, help='epoch number')
    parser.add_argument('--seed', type=int, default=0, help='seed number')
    main(parser.parse_args())
61 |
--------------------------------------------------------------------------------
/data_loader/NOmniglot.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from utils import *
3 |
4 |
class NOmniglot(Dataset):
    """Basic N-Omniglot dataset backed by pre-computed frame files.

    When the frame directory for the requested parameters is missing and
    ``create`` is True, the raw recordings are first converted with
    ``convert_aedat4_dir_to_events_dir`` (if ``events_npy`` is missing too)
    and then with ``convert_events_dir_to_frames_dir``.
    """

    def __init__(self, root='data/', frames_num=12, train=True, data_type='event',
                 transform=None, target_transform=None, use_npz=True, crop=True, create=True, thread_num=16):
        """
        :param root: dataset root directory.
        :param frames_num: number of frames each sample is encoded into.
        :param train: use the ``background`` split if True, else ``evaluation``.
        :param data_type: frame encoding passed to the converter (e.g. ``'event'``).
        :param transform: optional callable applied to each image tensor.
        :param target_transform: optional callable applied to each label.
        :param use_npz: frames are stored as compressed ``.npz`` instead of ``.npy``.
        :param crop: crop the last two dims to ``[4:254, 54:304]``.
        :param create: generate missing event/frame directories.
        :param thread_num: worker count forwarded to the frame converter.
        """
        super().__init__()
        self.crop = crop
        self.data_type = data_type
        self.use_npz = use_npz
        self.transform = transform
        self.target_transform = target_transform
        split = 'background' if train else "evaluation"
        events_npy_root = os.path.join(root, 'events_npy', split)
        frames_root = os.path.join(root, f'fnum_{frames_num}_dtype_{data_type}_npz_{use_npz}', split)

        self._prepare_frames(root, train, events_npy_root, frames_root, frames_num, data_type,
                             use_npz, create, thread_num)

        self.datadict, self.num_classes = list_class_files(events_npy_root, frames_root, True, use_npz=use_npz)

        # Flatten {label: [file, ...]} into a list of (file, label) pairs.
        self.datalist = [(file, label) for label in self.datadict for file in self.datadict[label]]

    @staticmethod
    def _prepare_frames(root, train, events_npy_root, frames_root, frames_num, data_type,
                        use_npz, create, thread_num):
        """Create the event and frame directories when they are missing and ``create`` is set."""
        if os.path.exists(frames_root):
            print(f'frames data root {frames_root} already exists.')
            return
        if not create:
            # Reporting "already exists" here would be false; say what actually happened.
            print(f'frames data root {frames_root} does not exist and create=False; not generating it.')
            return

        if not os.path.exists(events_npy_root):
            os.makedirs(events_npy_root)
            print('creating event data..')
            convert_aedat4_dir_to_events_dir(root, train)
        else:
            print(f'npy format events data root {events_npy_root}, already exists')

        os.makedirs(frames_root)
        print('creating frames data..')
        convert_events_dir_to_frames_dir(events_npy_root, frames_root, '.npy', frames_num, data_type,
                                         thread_num=thread_num, compress=use_npz)

    def __len__(self):
        return len(self.datalist)

    def __getitem__(self, index):
        image, label = self.datalist[index]
        image, label = self.readimage(image, label)
        return image, label

    def readimage(self, image, label):
        """Load one frame file and apply cropping and transforms.

        :param image: path to a ``.npz`` (when ``use_npz``) or ``.npy`` frame file.
        :param label: the sample's label.
        :return: ``(float32 image tensor, label)`` after the optional crop and transforms.
        """
        if self.use_npz:
            # NpzFile keeps the archive open until closed; close it deterministically.
            with np.load(image) as archive:
                array = archive['arr_0']
        else:
            array = np.load(image)
        image = torch.tensor(array).float()
        if self.crop:
            image = image[:, :, 4:254, 54:304]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/examples/SSiamese/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
# Spiking-neuron hyper-parameters.
thresh, lens = 0.5, 0.5  # firing threshold; half-width of the surrogate-gradient window
decay = 0.2  # leak factor applied to the membrane potential each step


class ActFun(torch.autograd.Function):
    """Heaviside spike with a rectangular surrogate gradient.

    Forward emits 1.0 where ``input > thresh`` and 0.0 elsewhere. Backward
    passes ``grad_output`` through only where ``|input - thresh| < lens``.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return (input > thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        (membrane,) = ctx.saved_tensors
        window = (membrane - thresh).abs() < lens
        return grad_output * window.float()


act_fun = ActFun.apply


def mem_update(ops, x, mem, spike):
    """Advance one integrate-and-fire step.

    The previous potential leaks by ``decay`` and is zeroed where the neuron
    just spiked; ``ops(x)`` is then integrated and the new spikes are emitted.

    :return: ``(new_mem, new_spike)``.
    """
    leaked = mem * decay * (1. - spike)
    new_mem = leaked + ops(x)
    return new_mem, act_fun(new_mem)
34 |
35 |
class SSiamese(nn.Module):
    """Spiking Siamese network that scores a pair of inputs.

    Both inputs go through the same spiking conv tower (``conv1``/``conv2``
    with 2x2 average pooling). The absolute difference of the two flattened
    spike maps drives a spiking hidden layer (``liner``), whose spikes are read
    out by ``out`` and summed over ``time_window`` steps.
    """

    def __init__(self, device=None):
        """
        :param device: device on which the membrane/spike state is allocated;
            ``None`` means "use the device of ``input1``" at call time.
        """
        super(SSiamese, self).__init__()
        self.device = device
        self.conv1 = nn.Conv2d(1, 64, 10)
        self.conv2 = nn.Conv2d(64, 128, 7)

        # 56448 = 128 * 21 * 21: pooled conv2 spikes (the 42x42 state halved).
        self.liner = nn.Linear(56448, 4096)
        self.out = nn.Linear(4096, 1)

    def _tower_step(self, x, c1_mem, c1_spike, c2_mem, c2_spike, batch_size):
        """Run one time step of the shared conv tower for one branch.

        :return: ``(features, c1_mem, c1_spike, c2_mem, c2_spike)`` where
            ``features`` is the pooled conv2 spike map flattened to ``[batch_size, -1]``.
        """
        c1_mem, c1_spike = mem_update(self.conv1, x, c1_mem, c1_spike)
        x = F.avg_pool2d(c1_spike, 2)
        c2_mem, c2_spike = mem_update(self.conv2, x, c2_mem, c2_spike)
        x = F.avg_pool2d(c2_spike, 2)
        return x.view(batch_size, -1), c1_mem, c1_spike, c2_mem, c2_spike

    def forward(self, input1, input2, batch_size=None, time_window=10):
        """Score a batch of input pairs.

        :param input1: first input, indexed by time first (``input1[step]`` is
            fed to ``conv1``). The fixed 96x96 conv1 state implies 105x105
            single-channel frames.
        :param input2: second input with the same layout as ``input1``.
        :param batch_size: batch size of each time slice; inferred from
            ``input1.shape[1]`` when ``None`` (previously fixed to 128, which
            crashed for any other batch size).
        :param time_window: number of leading time steps to simulate.
        :return: tensor ``[batch_size, 1]`` of read-out values summed over time.
        """
        device = self.device if self.device is not None else input1.device
        if batch_size is None:
            batch_size = input1.shape[1]

        def zeros(*shape):
            return torch.zeros(batch_size, *shape, device=device)

        c1_mem1 = c1_spike1 = zeros(64, 96, 96)
        c2_mem1 = c2_spike1 = zeros(128, 42, 42)

        c1_mem2 = c1_spike2 = zeros(64, 96, 96)
        c2_mem2 = c2_spike2 = zeros(128, 42, 42)
        h1_mem2 = h1_spike2 = zeros(4096)

        outputs = zeros(1)

        for step in range(time_window):  # simulation time steps
            x1, c1_mem1, c1_spike1, c2_mem1, c2_spike1 = self._tower_step(
                input1[step], c1_mem1, c1_spike1, c2_mem1, c2_spike1, batch_size)
            x2, c1_mem2, c1_spike2, c2_mem2, c2_spike2 = self._tower_step(
                input2[step], c1_mem2, c1_spike2, c2_mem2, c2_spike2, batch_size)

            h1_mem2, h1_spike2 = mem_update(self.liner, torch.abs(x2 - x1), h1_mem2, h1_spike2)
            outputs += self.out(h1_spike2)

        return outputs
76 |
77 |
# Quick inspection of the model when this file is run directly.
if __name__ == '__main__':
    # SSiamese.__init__ takes a device for its state buffers; calling it with
    # no argument raised TypeError. The CPU is enough to build and inspect it.
    net = SSiamese(torch.device('cpu'))
    print(net)
    # Dumping every weight tensor of this very large model is not useful;
    # report the total parameter count instead.
    print(sum(p.numel() for p in net.parameters()))
83 |
--------------------------------------------------------------------------------
/data_loader/nomniglot_nw_ks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import numpy as np
4 | from torch.utils.data import Dataset, DataLoader
5 | from data_loader.NOmniglot import NOmniglot
6 |
7 |
# Sentinel meaning "no transform argument given": the default Resize is then
# built per instance at construction time instead of once at import time.
_DEFAULT_TRANSFORM = object()


class NOmniglotNWayKShot(Dataset):
    '''
    Serve N-way K-shot episodes for meta-learning.

    ``length`` episodes (default 256) are pre-sampled into ``data_cache``;
    call ``reset()`` to draw a fresh set, e.g. once per epoch. Each episode
    holds ``n_way * k_shot`` support and ``n_way * k_query`` query samples
    with episode-local labels ``0 .. n_way - 1``.
    '''

    def __init__(self, root, n_way, k_shot, k_query, train=True, frames_num=12, data_type='event',
                 transform=_DEFAULT_TRANSFORM, length=256):
        """
        :param transform: image transform; defaults to ``Resize((28, 28))``.
        :param length: number of episodes sampled per cache (per epoch).
        Other parameters select the underlying ``NOmniglot`` split and episode shape.
        """
        if transform is _DEFAULT_TRANSFORM:
            transform = torchvision.transforms.Resize((28, 28))
        self.dataSet = NOmniglot(root=root, train=train,
                                 frames_num=frames_num, data_type=data_type, transform=transform)
        self.n_way = n_way  # classes per episode
        self.k_shot = k_shot  # support samples per class
        self.k_query = k_query  # query samples per class
        assert (k_shot + k_query) <= 20
        self.length = length
        self.data_cache = self.load_data_cache(self.dataSet.datadict, self.length)

    def load_data_cache(self, data_dict, length):
        '''
        Randomly sample ``length`` episodes and cache their file paths and labels.

        :param data_dict: mapping from class index to its list of sample files.
        :return: list of ``[x_spt, y_spt, x_qry, y_qry]`` numpy arrays, one per episode.
        '''
        data_cache = []
        for _ in range(length):
            selected_cls = np.random.choice(len(data_dict), self.n_way, False)

            x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
            for j, cur_class in enumerate(selected_cls):
                files = np.array(data_dict[cur_class])
                # Sample from the files the class actually has (20 in N-Omniglot)
                # instead of assuming 20, which failed for smaller classes.
                selected_img = np.random.choice(len(files), self.k_shot + self.k_query, False)

                x_spts.append(files[selected_img[:self.k_shot]])
                x_qrys.append(files[selected_img[self.k_shot:]])
                y_spts.append([j] * self.k_shot)
                y_qrys.append([j] * self.k_query)

            # Shuffle so that samples are not grouped by class inside an episode.
            shufflespt = np.random.permutation(self.n_way * self.k_shot)
            shuffleqry = np.random.permutation(self.n_way * self.k_query)

            data_cache.append([np.array(x_spts).reshape(-1)[shufflespt], np.array(y_spts).reshape(-1)[shufflespt],
                               np.array(x_qrys).reshape(-1)[shuffleqry], np.array(y_qrys).reshape(-1)[shuffleqry]])
        return data_cache

    def __getitem__(self, index):
        """Load episode ``index``: ``(x_spt tensor, y_spt array, x_qry tensor, y_qry array)``."""
        x_spts, y_spts, x_qrys, y_qrys = self.data_cache[index]
        x_sptst, y_sptst, x_qryst, y_qryst = [], [], [], []

        for i, j in zip(x_spts, y_spts):
            i, j = self.dataSet.readimage(i, j)
            x_sptst.append(i.unsqueeze(0))
            y_sptst.append(j)
        for i, j in zip(x_qrys, y_qrys):
            i, j = self.dataSet.readimage(i, j)
            x_qryst.append(i.unsqueeze(0))
            y_qryst.append(j)
        return torch.cat(x_sptst, dim=0), np.array(y_sptst), torch.cat(x_qryst, dim=0), np.array(y_qryst)

    def reset(self):
        """Draw a fresh set of ``self.length`` episodes."""
        self.data_cache = self.load_data_cache(self.dataSet.datadict, self.length)

    def __len__(self):
        return len(self.data_cache)
69 |
70 |
if __name__ == "__main__":
    db_train = NOmniglotNWayKShot('./data/', n_way=5, k_shot=1, k_query=15,
                                  frames_num=4, data_type='frequency', train=True)
    dataloadertrain = DataLoader(db_train, batch_size=16, shuffle=True, num_workers=16, pin_memory=True)
    for x_spt, y_spt, x_qry, y_qry in dataloadertrain:
        print(x_spt.shape)
    # Draw fresh episodes for the next epoch. The method is ``reset`` (there is
    # no ``resampling``), and it belongs after the pass: worker processes hold
    # their own copy of the dataset while an epoch is running.
    db_train.reset()
--------------------------------------------------------------------------------
/examples/SMAML/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from data_loader.nomniglot_nw_ks import NOmniglotNWayKShot
4 | import argparse
5 | from torch.utils.data import Dataset, DataLoader
6 | from examples.SMAML.meta import Meta
7 |
8 |
# Decay learning_rate
def lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=100, factor=0.5, min_lr=0.005):
    """Multiply the learning rate by ``factor`` every ``lr_decay_epoch`` epochs, above a floor.

    When ``epoch > 1`` and ``epoch % lr_decay_epoch == 0``, every param group
    whose ``lr`` is still greater than ``min_lr`` is multiplied by ``factor``
    (halved by default).

    :param optimizer: optimizer exposing ``param_groups`` (e.g. ``torch.optim``).
    :param epoch: current epoch number.
    :param init_lr: unused; kept for backward compatibility.
    :param lr_decay_epoch: decay period in epochs.
    :param factor: multiplicative decay factor.
    :param min_lr: groups at or below this learning rate are no longer decayed.
    :return: the same optimizer, modified in place.
    """
    if epoch > 1 and epoch % lr_decay_epoch == 0:
        for param_group in optimizer.param_groups:
            if param_group['lr'] > min_lr:
                param_group['lr'] *= factor
    return optimizer
17 |
18 |
def main(args):
    """Meta-train ``Meta`` on N-way K-shot N-Omniglot episodes.

    Every epoch makes one pass over the training episodes, then fine-tunes and
    evaluates on each test episode, printing running best accuracies. The
    training episodes are re-sampled at the end of each epoch.

    :param args: parsed CLI arguments (episode shape, data options, seed, epochs, task_num).
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    device = torch.device('cuda')
    maml = Meta(args).to(device)

    best_train = 0
    best_test = 0

    episode_kwargs = dict(n_way=args.n_way,
                          k_shot=args.k_spt,
                          k_query=args.k_qry,
                          frames_num=args.frames_num,
                          data_type=args.result_type)
    db_train = NOmniglotNWayKShot(r'../../data/', train=True, **episode_kwargs)
    db_test = NOmniglotNWayKShot(r'../../data/', train=False, **episode_kwargs)
    print(len(db_train))

    train_loader = DataLoader(db_train, batch_size=args.task_num, shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(db_test, batch_size=args.task_num, shuffle=False, num_workers=4, pin_memory=True)

    for step in range(args.epoch):
        train_accs = []
        for x_spt, y_spt, x_qry, y_qry in train_loader:
            train_accs.append(maml(x_spt.to(device), y_spt.long().to(device),
                                   x_qry.to(device), y_qry.long().to(device))[0])

        mean_train = np.array(train_accs).mean(axis=0).astype(np.float16)
        if best_train < mean_train:
            best_train = mean_train

        print('step:', step, '\ttraining acc:', mean_train, "best", best_train, args.n_way, args.k_spt,
              args.frames_num, args.result_type, args.seed, "3layer")

        test_accs = []
        for x_spt, y_spt, x_qry, y_qry in test_loader:
            x_spt, y_spt = x_spt.to(device), y_spt.long().to(device)
            x_qry, y_qry = x_qry.to(device), y_qry.long().to(device)

            # Fine-tune and evaluate each task of the meta-batch separately.
            for task in zip(x_spt, y_spt, x_qry, y_qry):
                test_accs.append(maml.finetunning(*task)[-1])

        mean_test = np.array(test_accs).mean(axis=0).astype(np.float16)
        if best_test < mean_test:
            best_test = mean_test

        print('Test acc:', mean_test, "best", best_test)
        db_train.reset()
74 |
75 |
if __name__ == '__main__':
    # (flag, type, help text, default) for every command-line option.
    cli_options = [
        ('--epoch', int, 'epoch number', 400000),
        ('--n_way', int, 'n way', 5),
        ('--seed', int, 'seed', 0),
        ('--k_spt', int, 'k shot for support set', 1),
        ('--k_qry', int, 'k shot for query set', 15),
        ('--frames_num', int, 'frames_num', 4),
        ('--result_type', str, 'result_type', "event"),
        ('--task_num', int, 'meta batch size, namely task num', 16),
        ('--meta_lr', float, 'meta-level outer learning rate', 0.2),
        ('--update_lr', float, 'task-level inner update learning rate', 0.2),
        ('--update_step', int, 'task-level inner update steps', 3),
        ('--update_step_test', int, 'update steps for finetunning', 8),
    ]
    argparser = argparse.ArgumentParser()
    for flag, value_type, help_text, default in cli_options:
        argparser.add_argument(flag, type=value_type, help=help_text, default=default)

    main(argparser.parse_args())
95 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # N-Omniglot
2 |
3 | [[Paper]](https://www.nature.com/articles/s41597-022-01851-z) || [[Dataset]](https://figshare.com/articles/dataset/N-Omniglot/16821427)
4 |
5 | N-Omniglot is a large neuromorphic few-shot learning dataset. It reconstructs strokes of Omniglot as videos and uses Davis346 to capture the writing of the characters. The recordings can be displayed using DV software's playback function (https://inivation.gitlab.io/dv/dv-docs/docs/getting-started.html). N-Omniglot is sparse and has little similarity between frames. It can be used for event-driven pattern recognition, few-shot learning and stroke generation.
6 |
7 | It is a neuromorphic event dataset composed of 1623 handwritten characters obtained by the neuromorphic camera Davis346. Each type of character contains handwritten samples of 20 different participants. The file structure and a sample can be found in the corresponding PNG files in `README.assets`.
8 |
9 | **The raw data can be found on the https://doi.org/10.6084/m9.figshare.16821427.**
10 |
11 |
12 |
13 | ## Structure
14 |
15 | 
16 |
17 |
18 |
19 | ## How to use N-Omniglot
20 |
21 | We also provide an interface to this dataset in `data_loader` so that users can easily use it in their own PyTorch applications. Python 3 is recommended.
22 |
23 | - NOmniglot.py: basic dataset
24 | - nomniglot_full.py: get the full train and test loaders, for direct classification with SCNN
25 | - nomniglot_pair.py: paired train and test loaders, for Siamese Net
26 | - nomniglot_nw_ks.py: change into n-way k-shot, for MAML
27 | - utils.py: some functions
28 |
29 |
30 |
31 | As with `DVS-Gesture`, each N-Omniglot raw file contains 20 samples of event information. The `NOmniglot` class first splits the N-Omniglot dataset into single samples and stores them in the `events_npy` folder for long-term use (reference [SpikingJelly](https://github.com/fangwei123456/spikingjelly)). Later, the event data are encoded into event frames according to the chosen parameters, mainly the frame number and the data type. The `event` type outputs event frames combined with the `OR` operation, and the `frequency` type outputs the firing rate of each pixel.
32 |
33 | Before you run this code, some packages need to be ready:
34 |
35 | - `pip install dv`
36 | - `pip install pandas`
37 | - torch
38 | - torchvision >= 0.8.1
39 |
40 |
41 |
42 | - #### use `nomniglot_full`:
43 |
44 | ```python
45 | db_train = NOmniglotfull('./data/', train=True, frames_num=4, data_type='frequency')
46 | dataloadertrain = DataLoader(db_train, batch_size=16, shuffle=True, num_workers=16, pin_memory=True)
47 | for x, y in dataloadertrain:
48 | print(x.shape)
49 | ```
50 |
51 |
52 |
53 | - #### use `nomniglot_pair`:
54 |
55 | ```python
56 | data_type = 'frequency'
57 | T = 4
58 | trainSet = NOmniglotTrain(root='data/', use_frame=True, frames_num=T, data_type=data_type, use_npz=True, resize=105)
59 | testSet = NOmniglotTestSet(root='data/', time=1000, way=5, shot=1, use_frame=True, frames_num=T, data_type=data_type, use_npz=True, resize=105)
60 | trainLoader = DataLoader(trainSet, batch_size=48, shuffle=False, num_workers=4)
61 | testLoader = DataLoader(testSet, batch_size=5 * 1, shuffle=False, num_workers=4)
62 | for batch_id, (img1, img2) in enumerate(testLoader, 1):
63 | # img1.shape [batch, T, 2, H, W]
64 | print(batch_id)
65 | break
66 |
67 | for batch_id, (img1, img2, label) in enumerate(trainLoader, 1):
68 | # img1.shape [batch, T, 2, H, W]
69 | print(batch_id)
70 | break
71 | ```
72 |
73 |
74 |
75 | - #### use `nomniglot_nw_ks`:
76 |
77 | ```python
78 | db_train = NOmniglotNWayKShot('./data/', n_way=5, k_shot=1, k_query=15,
79 | frames_num=4, data_type='frequency', train=True)
80 | dataloadertrain = DataLoader(db_train, batch_size=16, shuffle=True, num_workers=16, pin_memory=True)
81 | for x_spt, y_spt, x_qry, y_qry in dataloadertrain:
82 | print(x_spt.shape)
83 | db_train.resampling()
84 | ```
85 |
86 |
87 |
88 |
89 |
90 | ## Experiment
91 |
92 |
93 |
94 | We provide four modified SNN-appropriate few-shot learning methods in `examples` as a benchmark for the N-Omniglot dataset. Different `way`, `shot`, `data_type` and `frames_num` settings can be chosen to run the experiments. You can run a method directly in the `PyCharm` environment.
95 |
96 |
97 |
98 | ## Reference
99 |
100 | [1] Yang Li, Yiting Dong, Dongcheng Zhao, Yi Zeng. N-Omniglot: a Large-scale Neuromorphic Dataset for Spatio-temporal Sparse Few-shot Learning. figshare https://doi.org/10.6084/m9.figshare.16821427.v2 (2021).
101 |
102 | [2] Li Y, Dong Y, Zhao D, et al. N-Omniglot, a large-scale neuromorphic dataset for spatio-temporal sparse few-shot learning[J]. Scientific Data, 2022, 9(1): 746.
103 |
--------------------------------------------------------------------------------
/examples/SMAML/meta.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch import optim
4 | from torch.nn import functional as F
5 | from torch.utils.data import TensorDataset, DataLoader
6 | from torch import optim
7 | import numpy as np
8 | from examples.SMAML.learner import SCNN2
9 | from copy import deepcopy
10 |
# Default compute device: the first visible GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
12 |
13 |
class Meta(nn.Module):
    """
    First-order MAML meta-learner around the spiking network ``SCNN2``.

    For every task in a batch, a copy of ``self.net`` is adapted to the task's
    support set with plain SGD (learning rate ``update_lr``) and evaluated on
    the task's query set. The query-loss gradients taken with respect to the
    adapted weights (first-order approximation) are summed over the tasks and
    applied to ``self.net`` by ``self.optimmeta``.
    """

    def __init__(self, args):
        """
        :param args: namespace providing ``update_lr``, ``meta_lr``, ``n_way``,
            ``k_spt``, ``k_qry``, ``task_num``, ``update_step`` and
            ``update_step_test``.
        """
        super(Meta, self).__init__()

        self.update_lr = args.update_lr
        self.meta_lr = args.meta_lr
        self.n_way = args.n_way
        self.k_spt = args.k_spt
        self.k_qry = args.k_qry
        self.task_num = args.task_num
        self.update_step = args.update_step
        self.update_step_test = args.update_step_test

        self.net = SCNN2(device)

        self.optimmeta = optim.SGD(self.net.parameters(), lr=self.meta_lr)

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        """
        Run one meta-training step on a batch of tasks.

        :param x_spt: [b, setsz, c_, h, w]
        :param y_spt: [b, setsz]
        :param x_qry: [b, querysz, c_, h, w]
        :param y_qry: [b, querysz]
        :return: numpy array of length ``update_step``. Entry 0 is the number
            of correct query predictions of the adapted networks divided by
            ``k_qry * n_way * b``; the remaining entries are always 0.
        """
        criterion = nn.CrossEntropyLoss()
        meta_params = list(self.net.parameters())
        meta_grads = [torch.zeros_like(p) for p in meta_params]
        corrects = [0 for _ in range(self.update_step)]

        for i in range(x_spt.shape[0]):
            # First inner step: support-loss gradient at the meta-parameters.
            # The forward runs on self.net itself (not a copy), as before, so
            # its BatchNorm running statistics are still updated here.
            loss = criterion(self.net(x_spt[i]), y_spt[i])
            grads = torch.autograd.grad(loss, meta_params, allow_unused=True)

            # Fast weights live in a private copy; self.net's weights stay untouched.
            nettemp = deepcopy(self.net)
            with torch.no_grad():
                for p, g in zip(nettemp.parameters(), grads):
                    if g is not None:
                        p.sub_(self.update_lr * g)  # plain SGD step (no momentum)

            optimtemp = optim.SGD(nettemp.parameters(), lr=self.update_lr)
            for _ in range(self.update_step):
                loss = criterion(nettemp(x_spt[i]), y_spt[i])
                optimtemp.zero_grad()
                loss.backward()
                optimtemp.step()

            # Query loss of the adapted copy. Its gradient w.r.t. the fast
            # weights is accumulated explicitly as the first-order
            # meta-gradient. (The earlier version relied on self.net and the
            # copies sharing ``.grad`` tensors, which silently stops working
            # once ``zero_grad()`` resets gradients to None, the PyTorch >= 2.0
            # default.) Computing it per task also frees each task's graph early.
            logits = nettemp(x_qry[i])
            loss = criterion(logits, y_qry[i]) / self.task_num
            query_grads = torch.autograd.grad(loss, list(nettemp.parameters()), allow_unused=True)
            for acc, g in zip(meta_grads, query_grads):
                if g is not None:
                    acc.add_(g)

            with torch.no_grad():
                pred_q = F.softmax(logits, dim=1).argmax(dim=1)
                corrects[0] += torch.eq(pred_q, y_qry[i]).sum().item()

        # Meta update: install the accumulated gradients and step the meta optimiser.
        for p, g in zip(meta_params, meta_grads):
            p.grad = g
        self.optimmeta.step()

        return np.array(corrects) / (self.k_qry * self.n_way * x_spt.shape[0])

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        """
        Adapt a copy of ``self.net`` to a single task and track query accuracy.

        The copy receives one initial SGD step on the support set, then
        ``update_step_test`` further steps; query accuracy is recorded after
        each of the latter. ``self.net`` (including its BatchNorm running
        statistics) is never modified because all work happens on a deep copy.

        :param x_spt: [setsz, c_, h, w]
        :param y_spt: [setsz]
        :param x_qry: [querysz, c_, h, w]
        :param y_qry: [querysz]
        :return: numpy array of length ``update_step_test`` with the query
            accuracy after each step
        """
        querysz = x_qry.size(0)
        criterion = nn.CrossEntropyLoss()
        corrects = [0 for _ in range(self.update_step_test)]

        nettemp = deepcopy(self.net)
        optimtemp = optim.SGD(nettemp.parameters(), lr=self.update_lr)

        def support_step():
            # One SGD step of the copy on the support set.
            loss = criterion(nettemp(x_spt), y_spt)
            optimtemp.zero_grad()
            loss.backward()
            optimtemp.step()

        support_step()  # initial step; no accuracy is recorded for it

        for j in range(self.update_step_test):
            support_step()
            # Evaluation only: no autograd graph is needed for the query pass.
            with torch.no_grad():
                pred_q = F.softmax(nettemp(x_qry), dim=1).argmax(dim=1)
                corrects[j] += torch.eq(pred_q, y_qry).sum().item()

        return np.array(corrects) / querysz
138 |
--------------------------------------------------------------------------------
/examples/SSiamese/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 | from data_loader.nomniglot_pair import NOmniglotTrainSet, NOmniglotTestSet
4 | from torch.utils.data import DataLoader
5 | from torch.autograd import Variable
6 | from examples.SSiamese.model import SSiamese
7 | import time
8 | import numpy as np
9 | import gflags
10 | from collections import deque
11 | import os
12 |
13 |
if __name__ == '__main__':
    # Train the spiking Siamese network on N-Omniglot pairs and periodically
    # measure way-shot accuracy on the test split.
    Flags = gflags.FLAGS
    gflags.DEFINE_bool("cuda", True, "use cuda")
    gflags.DEFINE_integer("way", 5, "how much way one-shot learning")
    gflags.DEFINE_integer("shot", 1, "how much shot few-shot learning")
    # Must be an integer flag: the test set multiplies it by way * shot in
    # __len__, which breaks when a string arrives from the command line.
    gflags.DEFINE_integer("time", 2000, "number of samples to test accuracy")
    gflags.DEFINE_integer("workers", 8, "number of dataLoader workers")
    gflags.DEFINE_integer("batch_size", 64, "number of batch size")
    gflags.DEFINE_float("lr", 0.0001, "learning rate")
    gflags.DEFINE_integer("show_every", 100, "show result after each show_every iter.")
    gflags.DEFINE_integer("save_every", 100, "save model after each save_every iter.")
    gflags.DEFINE_integer("test_every", 10000, "test model after each test_every iter.")
    gflags.DEFINE_integer("max_iter", 40000, "number of iterations before stopping")
    gflags.DEFINE_string("model_path", "./", "path to store model")
    gflags.DEFINE_string("gpu_ids", "0", "gpu ids used to train")

    Flags(sys.argv)
    T = 4  # number of frames per sample
    data_type = 'event'  # or 'frequency'

    os.environ["CUDA_VISIBLE_DEVICES"] = Flags.gpu_ids
    # Honour --cuda and fall back to the CPU when no GPU is present, so the
    # model and the batches always live on the same device.
    device = torch.device('cuda' if Flags.cuda and torch.cuda.is_available() else 'cpu')
    print("use gpu:", Flags.gpu_ids, "to train.")
    print("way:%d, shot: %d" % (Flags.way, Flags.shot))

    seed = 346
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    trainSet = NOmniglotTrainSet(root='../../data/', use_frame=True, frames_num=T, data_type=data_type,
                                 use_npz=True, resize=105)
    testSet = NOmniglotTestSet(root='../../data/', time=Flags.time, way=Flags.way, shot=Flags.shot, use_frame=True,
                               frames_num=T, data_type=data_type, use_npz=True, resize=105)
    trainLoader = DataLoader(trainSet, batch_size=Flags.batch_size, shuffle=False, num_workers=Flags.workers)
    testLoader = DataLoader(testSet, batch_size=Flags.way * Flags.shot, shuffle=False, num_workers=Flags.workers)

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction='mean')
    net = SSiamese(device=device)

    # multi gpu
    if len(Flags.gpu_ids.split(",")) > 1:
        net = torch.nn.DataParallel(net)
    net.to(device)
    net.train()

    optimizer = torch.optim.Adam(net.parameters(), lr=Flags.lr)
    optimizer.zero_grad()

    def first_channel(batch):
        """[batch, T, C, H, W] -> [T, batch, 1, H, W], keeping only channel 0."""
        return batch.permute(1, 0, 2, 3, 4)[:, :, 0, :, :].unsqueeze(2)

    loss_val = 0
    time_start = time.time()
    queue = deque(maxlen=20)  # accuracies of the most recent evaluations
    best = 0

    for batch_id, (img1, img2, label) in enumerate(trainLoader, 1):
        if batch_id > Flags.max_iter:
            break
        img1, img2 = first_channel(img1), first_channel(img2)
        torch.cuda.empty_cache()

        optimizer.zero_grad()
        output = net(img1.to(device), img2.to(device), batch_size=Flags.batch_size, time_window=T)
        loss = loss_fn(output, label.to(device))
        loss_val += loss.item()
        loss.backward()
        optimizer.step()

        if batch_id % Flags.show_every == 0:
            print('[%d]\tloss:\t%.5f\ttime lapsed:\t%.2f s' % (batch_id, loss_val / Flags.show_every,
                                                             time.time() - time_start))
            loss_val = 0
            time_start = time.time()
        # if batch_id % Flags.save_every == 0:
        #     torch.save(net.state_dict(), Flags.model_path + '/model' + ".pt")

        if batch_id % Flags.test_every == 0 or (batch_id > 30000 and batch_id % 200 == 0):
            right, error = 0, 0
            with torch.no_grad():
                for test1, test2 in testLoader:
                    test1 = first_channel(test1).to(device)
                    test2 = first_channel(test2).to(device)
                    output = net(test1, test2, batch_size=Flags.way * Flags.shot, time_window=T).cpu().numpy()
                    # The first `shot` pairs of a task share the query's class.
                    if np.argmax(output) < Flags.shot:
                        right += 1
                    else:
                        error += 1
            accuracy = right * 1.0 / (right + error)
            print('*' * 70)
            print('[%d]\tTest set\tcorrect:\t%d\terror:\t%d\tprecision:\t%f' % (batch_id, right, error, accuracy))
            print('*' * 70)
            queue.append(accuracy)

            if accuracy > best and accuracy != 1.0:
                best = accuracy
                # torch.save(net.state_dict(), Flags.model_path + '/model' + ".pt")
            print("data_type: %s, %d-way %d-shot, gpu %s, best:%.4f" % (
                data_type, Flags.way, Flags.shot, Flags.gpu_ids, best))

    # Average over the evaluations actually recorded (at most 20).
    acc = sum(queue) / len(queue) if queue else 0.0
    print("#" * 70)
    print("final accuracy: ", acc)
    print("best accuracy:", best)
123 |
124 |
--------------------------------------------------------------------------------
/examples/CSNN/learner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
# Module-wide simulation settings.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
thresh = 0.5  # firing threshold of the membrane potential
lens = 0.5  # half-width of the rectangular surrogate-gradient window used by ActFun
decay = 0.2  # membrane decay constant


class SurrGradSpike(torch.autograd.Function):
    """
    Heaviside spike nonlinearity with a fast-sigmoid surrogate gradient.

    Forward emits 1 where the input is strictly positive and 0 elsewhere.
    Backward replaces the step's derivative with
    ``1 / (scale * |input| + 1) ** 2``, the normalized negative part of a
    fast sigmoid as used by Zenke & Ganguli (2018).
    """

    scale = 100.0  # steepness of the surrogate gradient

    @staticmethod
    def forward(ctx, input):
        """Return the step function of ``input``; ``input`` is saved for backward."""
        ctx.save_for_backward(input)
        return (input > 0).to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Scale the incoming gradient by the fast-sigmoid surrogate derivative."""
        (saved,) = ctx.saved_tensors
        denom = (SurrGradSpike.scale * saved.abs() + 1.0) ** 2
        return grad_output / denom


# Functional alias for the surrogate-gradient spike.
spike_fn = SurrGradSpike.apply


class ActFun(torch.autograd.Function):
    """
    Threshold firing function with a rectangular surrogate gradient.

    Forward fires (1.0) where ``input > thresh``; backward passes the gradient
    unchanged where ``|input - thresh| < lens`` and zeroes it elsewhere.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        fired = input > thresh
        return fired.float()

    @staticmethod
    def backward(ctx, grad_output):
        (saved,) = ctx.saved_tensors
        window = (saved - thresh).abs() < lens
        return grad_output * window.float()


# Functional alias for the threshold firing function.
act_fun = ActFun.apply
68 | # membrane potential update
69 |
# Step-decay learning-rate schedule.
def lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=50):
    """
    Multiply every param group's learning rate by 0.1 when ``epoch`` is a
    multiple of ``lr_decay_epoch`` and ``epoch > 1``.

    ``init_lr`` is accepted for API compatibility but not used; the decay is
    applied in place to each group's current ``'lr'``.

    :return: the same ``optimizer`` object.
    """
    due = epoch % lr_decay_epoch == 0 and epoch > 1
    if not due:
        return optimizer
    for group in optimizer.param_groups:
        group['lr'] = group['lr'] * 0.1
    return optimizer
77 |
78 |
class LIFConv(nn.Module):
    """
    Conv2d followed by a leaky integrate-and-fire (LIF) neuron layer.

    The membrane potential and the last spikes are stored on the module, so
    successive ``forward`` calls integrate over time steps; call
    :meth:`reset` before starting a new sequence.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding, decay=0.2):
        super().__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding)
        self.mem = self.spike = None  # created lazily from the first input
        self.decay = decay

    def mem_update(self, x):
        """
        Integrate the input current ``x`` and emit spikes via ``act_fun``.

        The previous potential decays by ``self.decay`` and is discarded
        wherever the neuron spiked on the previous step.
        """
        if self.mem is None:
            # Follow x's device/dtype instead of the module-level ``device``,
            # so the layer also works when the model is not on that device.
            self.mem = torch.zeros_like(x)
            self.spike = torch.zeros_like(x)

        self.mem = self.mem * self.decay * (1. - self.spike) + x
        self.spike = act_fun(self.mem)  # act_fun: approximate firing function
        return self.spike

    def forward(self, x):
        return self.mem_update(self.conv(x))

    def reset(self):
        """Forget the membrane state so the next call starts a new sequence."""
        self.mem = self.spike = None
102 |
103 |
class LIFLinear(nn.Module):
    """
    Linear layer followed by a leaky integrate-and-fire (LIF) neuron layer.

    Membrane state persists across ``forward`` calls until :meth:`reset`.
    With ``last_layer=True`` the membrane integrates its input without decay
    or spike reset (used as a readout whose potential is read directly).
    """

    def __init__(self, in_planes, out_planes, decay=0.2, last_layer=False):
        super().__init__()
        self.fc = nn.Linear(in_planes, out_planes)
        self.mem = self.spike = None  # created lazily from the first input
        self.decay = decay
        self.last_layer = last_layer

    def mem_update(self, x):
        """Integrate the input current ``x`` and emit spikes via ``act_fun``."""
        if self.mem is None:
            # Follow x's device/dtype instead of the module-level ``device``.
            self.mem = torch.zeros_like(x)
            self.spike = torch.zeros_like(x)
        if self.last_layer:
            self.mem = self.mem + x
        else:
            self.mem = self.mem * self.decay * (1. - self.spike) + x
        self.spike = act_fun(self.mem)  # act_fun: approximate firing function
        return self.spike

    def forward(self, x):
        return self.mem_update(self.fc(x))

    def reset(self):
        """Forget the membrane state so the next call starts a new sequence."""
        self.mem = self.spike = None
130 |
131 |
class SCNN(nn.Module):
    """
    Spiking CNN classifier with 1623 output units.

    The input is iterated along dim 1 (time). Each step runs
    conv1 -> avg-pool -> conv2 -> avg-pool -> fc1 -> fc2, all LIF layers.
    The output is the final membrane potential of the non-leaky readout
    ``fc2`` divided by the number of time steps; the state is then cleared.
    """

    def __init__(self):
        super(SCNN, self).__init__()
        self.conv1 = LIFConv(2, 15, kernel_size=5, stride=1, padding=0)
        self.conv2 = LIFConv(15, 40, kernel_size=5, stride=1, padding=0)

        self.fc1 = LIFLinear(640, 300)
        self.fc2 = LIFLinear(300, 1623, 0.2, True)

    def _step(self, frame, batch):
        """Advance every layer by one simulation step on a single frame."""
        h = F.avg_pool2d(self.conv1(frame.float()), 2)
        h = F.avg_pool2d(self.conv2(h), 2)
        h = h.view(batch, -1)
        return self.fc2(self.fc1(h))

    def forward(self, input):
        batch, steps = input.shape[0], input.shape[1]
        for t in range(steps):
            self._step(input[:, t], batch)

        outputs = self.fc2.mem / steps
        self.reset()
        return outputs

    def reset(self):
        """Clear the membrane state of every child layer."""
        for layer in self.children():
            layer.reset()
160 |
161 |
--------------------------------------------------------------------------------
/data_loader/nomniglot_pair.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import numpy as np
4 | from numpy.random import choice as npc
5 | import random
6 | import torch.nn.functional as F
7 | from data_loader.NOmniglot import NOmniglot
8 |
9 |
class NOmniglotTrainSet(Dataset):
    '''
    Dataloader for Siamese Net
    The pairs of similar samples are labeled as 1, and those of different samples are labeled as 0

    Odd indices yield a pair drawn from one random class (label 1.0); even
    indices yield a pair from two distinct random classes (label 0.0). The
    index only selects the pair type; the samples themselves come from
    Python's ``random`` module.
    '''

    def __init__(self, root='data/', use_frame=True, frames_num=10, data_type='event', use_npz=True, resize=None):
        super(NOmniglotTrainSet, self).__init__()
        self.resize = resize
        self.data_type = data_type
        self.use_frame = use_frame
        self.dataSet = NOmniglot(root=root, train=True, frames_num=frames_num, data_type=data_type, use_npz=use_npz)
        self.datas, self.num_classes = self.dataSet.datadict, self.dataSet.num_classes

        # NOTE(review): this seeds NumPy only; pair sampling in __getitem__
        # uses Python's ``random`` module and is not affected by it.
        np.random.seed(0)

    def __len__(self):
        '''
        Sampling upper limit, you can set the maximum sampling times when using to terminate
        '''
        return 21000000

    @staticmethod
    def _load_frames(path):
        """Load ``arr_0`` from an ``.npz`` file as a float tensor and close the file."""
        with np.load(path) as npz:
            return torch.tensor(npz['arr_0']).float()

    def _crop_resize(self, frames):
        """Crop the 250x250 window [4:254, 54:304] of the last two dims and resize it to ``resize`` x ``resize``."""
        frames = frames[:, :, 4:254, 54:304]
        return F.interpolate(frames, size=(self.resize, self.resize))

    def __getitem__(self, index):
        if index % 2 == 1:
            # positive pair: both samples from the same class
            label = 1.0
            idx1 = random.randint(0, self.num_classes - 1)
            image1 = random.choice(self.datas[idx1])
            image2 = random.choice(self.datas[idx1])
        else:
            # negative pair: samples from two different classes
            if self.num_classes < 2:
                # otherwise the rejection loop below would never terminate
                raise ValueError("at least two classes are needed to sample a negative pair")
            label = 0.0
            idx1 = random.randint(0, self.num_classes - 1)
            idx2 = random.randint(0, self.num_classes - 1)
            while idx1 == idx2:
                idx2 = random.randint(0, self.num_classes - 1)
            image1 = random.choice(self.datas[idx1])
            image2 = random.choice(self.datas[idx2])

        if self.use_frame:
            # Both supported data types are stored the same way on disk.
            if self.data_type not in ('event', 'frequency'):
                raise NotImplementedError
            image1 = self._load_frames(image1)
            image2 = self._load_frames(image2)

            if self.resize is not None:
                image1 = self._crop_resize(image1)
                image2 = self._crop_resize(image2)

        return image1, image2, torch.from_numpy(np.array([label], dtype=np.float32))
65 |
66 |
class NOmniglotTestSet(Dataset):
    '''
    Dataloader for Siamese Net

    Produces ``time`` test tasks of ``way * shot`` consecutive items each.
    Every item is a pair ``(query, support)``; the query is fixed for the
    whole task. The first ``shot`` supports come from the query's class (and
    differ from the query sample); the remaining ones come from ``way - 1``
    other classes, ``shot`` distinct samples each.

    NOTE: sampling state is kept on the instance between ``__getitem__``
    calls, so the items of one task must be requested in order within a
    single process (e.g. ``shuffle=False`` and ``batch_size=way*shot``).
    '''

    def __init__(self, root='data/', time=1000, way=20, shot=1, query=1, use_frame=True, frames_num=10, data_type='event', use_npz=True, resize=None):
        super(NOmniglotTestSet, self).__init__()
        self.resize = resize
        self.use_frame = use_frame
        self.time = time  # Sampling times
        self.way = way
        self.shot = shot
        self.query = query
        self.img1 = None  # Fix test sample while sampling support set
        self.c1 = None  # Fixed categories when sampling multiple samples
        self.c2 = None
        self.select_class = []  # selected classes
        self.select_sample = []  # selected samples

        self.data_type = data_type
        np.random.seed(0)  # NOTE(review): sampling uses Python's ``random``, not NumPy
        self.dataSet = NOmniglot(root=root, train=False, frames_num=frames_num, data_type=data_type, use_npz=use_npz)
        self.datas, self.num_classes = self.dataSet.datadict, self.dataSet.num_classes

    def __len__(self):
        '''
        In general, the total number of test tasks is 1000.
        Since one test sample is collected at a time, way * shot support samples are used for each test
        '''
        return self.time * self.way * self.shot

    @staticmethod
    def _load_frames(path):
        """Load ``arr_0`` from an ``.npz`` file as a float tensor and close the file."""
        with np.load(path) as npz:
            return torch.tensor(npz['arr_0']).float()

    def _crop_resize(self, frames):
        """Crop the 250x250 window [4:254, 54:304] of the last two dims and resize it to ``resize`` x ``resize``."""
        frames = frames[:, :, 4:254, 54:304]
        return F.interpolate(frames, size=(self.resize, self.resize))

    def __getitem__(self, index):
        '''
        The 0th sample of each way*shot is used for query and recorded in the selected sample
        to achieve the effect of selecting K +1
        '''
        idx = index % (self.way * self.shot)
        if idx == 0:
            # New task: pick the query class and query sample, then the first
            # support sample from the same class (distinct from the query).
            self.select_class = []
            self.select_sample = []  # drop sample indices left over from the previous task
            self.c1 = random.randint(0, self.num_classes - 1)
            self.c2 = self.c1
            sind = random.randint(0, len(self.datas[self.c1]) - 1)
            self.select_sample.append(sind)
            self.img1 = self.datas[self.c1][sind]

            sind = random.randint(0, len(self.datas[self.c2]) - 1)
            while sind in self.select_sample:
                sind = random.randint(0, len(self.datas[self.c2]) - 1)
            img2 = self.datas[self.c1][sind]
            self.select_sample.append(sind)
            self.select_class.append(self.c1)
        else:
            if index % self.shot == 0:
                # start a new, not yet used support class
                self.c2 = random.randint(0, self.num_classes - 1)
                while self.c2 in self.select_class:
                    self.c2 = random.randint(0, self.num_classes - 1)
                self.select_class.append(self.c2)
                self.select_sample = []
            sind = random.randint(0, len(self.datas[self.c2]) - 1)
            while sind in self.select_sample:
                sind = random.randint(0, len(self.datas[self.c2]) - 1)
            img2 = self.datas[self.c2][sind]
            self.select_sample.append(sind)

        if not self.use_frame:
            # return the raw sample paths (previously an UnboundLocalError)
            return self.img1, img2

        # Both supported data types are stored the same way on disk.
        if self.data_type not in ('event', 'frequency'):
            raise NotImplementedError
        img1 = self._load_frames(self.img1)
        img2 = self._load_frames(img2)

        if self.resize is not None:
            img1 = self._crop_resize(img1)
            img2 = self._crop_resize(img2)
        return img1, img2
149 |
150 |
if __name__ == '__main__':
    # Smoke test: build the train and test loaders and fetch one batch from each.
    data_type = 'frequency'
    T = 4
    common = dict(use_frame=True, frames_num=T, data_type=data_type, use_npz=True, resize=105)
    trainSet = NOmniglotTrainSet(root='data/', **common)
    testSet = NOmniglotTestSet(root='data/', time=1000, way=5, shot=1, **common)
    trainLoader = DataLoader(trainSet, batch_size=48, shuffle=False, num_workers=4)
    testLoader = DataLoader(testSet, batch_size=5 * 1, shuffle=False, num_workers=4)

    # image batches have shape [batch, T, 2, H, W]
    for batch_id, (img1, img2) in enumerate(testLoader, start=1):
        print(batch_id)
        break

    for batch_id, (img1, img2, label) in enumerate(trainLoader, start=1):
        print(batch_id)
        break
--------------------------------------------------------------------------------
/examples/SMAML/learner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
# Module-wide simulation settings.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
thresh = 0.5  # firing threshold of the membrane potential
lens = 0.5  # half-width of the rectangular surrogate-gradient window used by ActFun
decay = 0.2  # membrane decay constant


class SurrGradSpike(torch.autograd.Function):
    """
    Heaviside spike nonlinearity with a fast-sigmoid surrogate gradient.

    Forward emits 1 where the input is strictly positive and 0 elsewhere.
    Backward replaces the step's derivative with
    ``1 / (scale * |input| + 1) ** 2``, the normalized negative part of a
    fast sigmoid as used by Zenke & Ganguli (2018).
    """

    scale = 100.0  # steepness of the surrogate gradient

    @staticmethod
    def forward(ctx, input):
        """Return the step function of ``input``; ``input`` is saved for backward."""
        ctx.save_for_backward(input)
        return (input > 0).to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Scale the incoming gradient by the fast-sigmoid surrogate derivative."""
        (saved,) = ctx.saved_tensors
        denom = (SurrGradSpike.scale * saved.abs() + 1.0) ** 2
        return grad_output / denom


# Functional alias for the surrogate-gradient spike.
spike_fn = SurrGradSpike.apply


class ActFun(torch.autograd.Function):
    """
    Threshold firing function with a rectangular surrogate gradient.

    Forward fires (1.0) where ``input > thresh``; backward passes the gradient
    unchanged where ``|input - thresh| < lens`` and zeroes it elsewhere.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        fired = input > thresh
        return fired.float()

    @staticmethod
    def backward(ctx, grad_output):
        (saved,) = ctx.saved_tensors
        window = (saved - thresh).abs() < lens
        return grad_output * window.float()


# Functional alias for the threshold firing function.
act_fun = ActFun.apply
70 | # membrane potential update
71 |
72 |
# Step-decay learning-rate schedule.
def lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=50):
    """
    Multiply every param group's learning rate by 0.1 when ``epoch`` is a
    multiple of ``lr_decay_epoch`` and ``epoch > 1``.

    ``init_lr`` is accepted for API compatibility but not used; the decay is
    applied in place to each group's current ``'lr'``.

    :return: the same ``optimizer`` object.
    """
    due = epoch % lr_decay_epoch == 0 and epoch > 1
    if not due:
        return optimizer
    for group in optimizer.param_groups:
        group['lr'] = group['lr'] * 0.1
    return optimizer
80 |
class LIFConv(nn.Module):
    """
    Conv2d followed by a leaky integrate-and-fire (LIF) neuron layer.

    Membrane potential and last spikes persist on the module across calls, so
    consecutive ``forward`` calls integrate over time; call :meth:`reset`
    between sequences. With ``last_layer=True`` the membrane integrates its
    input without decay or spike reset.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding, decay=0.2, last_layer=False):
        super().__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding)
        self.mem = self.spike = None  # created lazily from the first input
        self.decay = decay
        self.last_layer = last_layer

    def mem_update(self, x):
        """Integrate the input current ``x`` and emit spikes via ``act_fun``."""
        if self.mem is None:
            # Follow x's device/dtype instead of the module-level ``device``,
            # so the layer also works when the model is not on that device.
            self.mem = torch.zeros_like(x)
            self.spike = torch.zeros_like(x)
        if self.last_layer:
            self.mem = self.mem + x
        else:
            self.mem = self.mem * self.decay * (1. - self.spike) + x
        self.spike = act_fun(self.mem)  # act_fun: approximate firing function
        return self.spike

    def forward(self, x):
        return self.mem_update(self.conv(x))

    def reset(self):
        """Forget the membrane state so the next call starts a new sequence."""
        self.mem = self.spike = None
106 |
107 |
class LIFLinear(nn.Module):
    """
    Linear layer followed by a leaky integrate-and-fire (LIF) neuron layer.

    Membrane state persists across ``forward`` calls until :meth:`reset`.
    With ``last_layer=True`` the membrane integrates its input without decay
    or spike reset (used as a readout whose potential is read directly).
    """

    def __init__(self, in_planes, out_planes, decay=0.2, last_layer=False):
        super().__init__()
        self.fc = nn.Linear(in_planes, out_planes)
        self.mem = self.spike = None  # created lazily from the first input
        self.decay = decay
        self.last_layer = last_layer

    def mem_update(self, x):
        """Integrate the input current ``x`` and emit spikes via ``act_fun``."""
        if self.mem is None:
            # Follow x's device/dtype instead of the module-level ``device``.
            self.mem = torch.zeros_like(x)
            self.spike = torch.zeros_like(x)
        if self.last_layer:
            self.mem = self.mem + x
        else:
            self.mem = self.mem * self.decay * (1. - self.spike) + x
        self.spike = act_fun(self.mem)  # act_fun: approximate firing function
        return self.spike

    def forward(self, x):
        return self.mem_update(self.fc(x))

    def reset(self):
        """Forget the membrane state so the next call starts a new sequence."""
        self.mem = self.spike = None
134 |
135 | # cnn_layer(in_planes, out_planes, stride, padding, kernel_size)
136 |
137 |
def mem_update(ops, x, mem, spike, last_layer=False):
    """
    Functional LIF update: feed ``ops(x)`` into the membrane potential.

    :param ops: callable (e.g. a layer) mapping ``x`` to the input current
    :param x: input of this time step
    :param mem: membrane potential from the previous step
    :param spike: spikes from the previous step
    :param last_layer: if True, integrate without decay or spike reset
    :return: ``(mem, spike)`` after this step; spikes come from ``act_fun``
    """
    current = ops(x)
    if not last_layer:
        # leak by the module-level ``decay`` and drop the potential where the neuron just fired
        mem = mem * decay * (1. - spike)
    mem = mem + current
    return mem, act_fun(mem)
145 |
146 |
147 |
class SCNN(nn.Module):
    """
    Spiking CNN with stochastic rate-coded input and a 20-unit LIF readout.

    At each time step, every input element fires where it exceeds a fresh
    uniform random sample, and the binary frame runs through
    conv1 -> avg-pool -> conv2 -> avg-pool -> fc1 -> fc2 (LIF layers). The
    output is the final fc2 membrane potential divided by the number of steps.
    """

    def __init__(self, device):
        super(SCNN, self).__init__()
        # Kept for reference; forward samples on ``input.device`` instead.
        self.device = device
        self.conv1 = LIFConv(1, 15, kernel_size=5, stride=1, padding=0)
        self.conv2 = LIFConv(15, 40, kernel_size=5, stride=1, padding=0)
        self.fc1 = LIFLinear(640, 300)
        self.fc2 = LIFLinear(300, 20, 0.2, True)

    def forward(self, input, time_window=12):
        """
        :param input: firing intensities, presumably in [0, 1] — TODO confirm
        :param time_window: number of simulation steps (must be >= 1)
        :return: final fc2 membrane potential divided by ``time_window``
        :raises ValueError: if ``time_window < 1``
        """
        if time_window < 1:
            raise ValueError("time_window must be at least 1")
        try:
            for _ in range(time_window):
                # Sample the firing thresholds on the input's own device; the
                # module-level ``device`` may differ from where the model runs.
                x = input > torch.rand(input.size(), device=input.device)  # prob. firing
                x = F.avg_pool2d(self.conv1(x.float()), 2)
                x = F.avg_pool2d(self.conv2(x), 2)
                x = x.view(input.shape[0], -1)
                self.fc2(self.fc1(x))
            outputs = self.fc2.mem / time_window
        finally:
            # Always clear membrane state, even when a step raised, so stale
            # state never leaks into the next call.
            self.reset()
        return outputs

    def reset(self):
        """Clear the membrane state of every child layer."""
        for layer in self.children():
            layer.reset()
180 |
181 |
class SCNN2(nn.Module):
    """
    Spiking CNN: two Conv2d+BatchNorm2d+ReLU stages and two Linear layers.

    Each stage is driven through the functional LIF update ``mem_update``; the
    readout layer integrates without decay, and its final membrane potential
    divided by the number of time steps is returned as the logits.
    """

    def __init__(self, device):
        super().__init__()
        self.cfg_fc = (300, 5)  # (hidden units, readout units)
        # (in_planes, out_planes, kernel_size, stride, padding) for each conv stage
        self.cfg_cnn = ((2, 15, 5, 1, 0), (15, 40, 5, 1, 0))
        # Spatial size after conv1, after conv2 and after the final pooling.
        # With 5x5 unpadded convs and 2x2 pooling these values imply 28x28
        # input frames (28 -> 24 -> 12 -> 8 -> 4).
        self.cfg_kernel = (24, 8, 4)
        in_planes1, out_planes1, kernel_size1, stride1, padding1 = self.cfg_cnn[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_planes1, out_planes1, kernel_size=kernel_size1, stride=stride1, padding=padding1),
            nn.BatchNorm2d(out_planes1),
            nn.ReLU(inplace=True)
        )
        in_planes2, out_planes2, kernel_size2, stride2, padding2 = self.cfg_cnn[1]
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_planes2, out_planes2, kernel_size=kernel_size2, stride=stride2, padding=padding2),
            nn.BatchNorm2d(out_planes2),
            nn.ReLU(inplace=True)
        )
        self.fc1 = nn.Linear(self.cfg_kernel[-1] * self.cfg_kernel[-1] * self.cfg_cnn[-1][1], self.cfg_fc[0])

        self.fc2 = nn.Linear(self.cfg_fc[0], self.cfg_fc[1])
        # Kept for backward compatibility; forward allocates state on input.device.
        self.device = device

    def forward(self, input, time_window=20):
        """
        Simulate the network over ``input.shape[1]`` time steps.

        :param input: tensor indexed as ``input[:, step]``; each slice must fit
            conv1 (2 channels) and the 28x28 size implied by ``cfg_kernel``.
        :param time_window: unused, kept for API compatibility; the number of
            steps is ``input.shape[1]``.
        :return: ``[batch, cfg_fc[1]]`` logits
        """
        self.batch_size = input.shape[0]
        steps = input.shape[1]
        dev = input.device
        k1, k2 = self.cfg_kernel[0], self.cfg_kernel[1]

        # mem_update never writes in place, so mem/spike may share one zero tensor.
        c1_mem = c1_spike = torch.zeros(self.batch_size, self.cfg_cnn[0][1], k1, k1, device=dev)
        c2_mem = c2_spike = torch.zeros(self.batch_size, self.cfg_cnn[1][1], k2, k2, device=dev)
        h1_mem = h1_spike = torch.zeros(self.batch_size, self.cfg_fc[0], device=dev)
        h2_mem = h2_spike = torch.zeros(self.batch_size, self.cfg_fc[1], device=dev)

        for step in range(steps):  # simulation time steps
            # Frames are inverted (1 - x) before entering the network.
            x = 1 - input[:, step]
            c1_mem, c1_spike = mem_update(self.conv1, x.float(), c1_mem, c1_spike)
            x = F.avg_pool2d(c1_spike, 2)

            c2_mem, c2_spike = mem_update(self.conv2, x, c2_mem, c2_spike)
            x = F.avg_pool2d(c2_spike, 2).view(self.batch_size, -1)

            h1_mem, h1_spike = mem_update(self.fc1, x, h1_mem, h1_spike)
            h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem, h2_spike, last_layer=True)

        return h2_mem / steps
237 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import threading
3 | import numpy as np
4 | import pandas
5 | import os
6 | from dv import AedatFile
7 |
8 |
class FunctionThread(threading.Thread):
    """Thread that calls ``f(*args, **kwargs)`` when started; the return value is discarded."""

    def __init__(self, f, *args, **kwargs):
        super().__init__()
        self.f, self.args, self.kwargs = f, args, kwargs

    def run(self):
        target = self.f
        target(*self.args, **self.kwargs)
18 |
19 |
def integrate_events_to_frames(events, height, width, frames_num=10, data_type='event'):
    """
    Integrate a stream of DVS events into ``frames_num`` frames of equal duration.

    :param events: mapping with equally long 1-D arrays ``'t'`` (timestamps),
        ``'x'`` and ``'y'`` (pixel coordinates) and ``'p'`` (polarity). The
        events of each frame are taken as a contiguous slice, so timestamps are
        expected in ascending order. The mapping is not modified.
    :param height: frame height in pixels
    :param width: frame width in pixels
    :param frames_num: number of frames; the last frame also absorbs the
        remainder ``t[-1] % frames_num`` and any events after the regular windows
    :param data_type: ``'event'`` -> boolean frames (pixel fired at least once);
        otherwise float16 event counts (rates when ``'frequency'``) normalised by
        ``normalize_frame(..., 'max')``
    :return: array of shape ``(frames_num, 2, height, width)``; channel 0 holds
        polarity-0 events, channel 1 all other events
    :raises ValueError: if the stream does not span more than ``frames_num`` time units
    """
    frames = np.zeros(shape=[frames_num, 2, height * width])

    # Zero-based copy of the timestamps; the caller's array is left untouched.
    t = np.asarray(events['t'])
    t = t - t[0]
    if t[-1] <= frames_num:
        raise ValueError("event stream must span more than frames_num time units")
    dt = t[-1] // frames_num  # length of each regular frame
    last_len = dt + t[-1] % frames_num  # length of the last frame

    idx = np.arange(t.size)
    for i in range(frames_num):
        t_l = dt * i
        is_last = i == frames_num - 1
        if is_last:
            # no upper bound: events at t >= dt * frames_num belong to the last frame
            mask = t >= t_l
        else:
            mask = np.logical_and(t >= t_l, t < t_l + dt)
        idx_masked = idx[mask]
        if idx_masked.size == 0:
            continue
        start = idx_masked[0]
        stop = t.size if is_last else idx_masked[-1] + 1

        # Widen before y * width: narrow integer coordinates (e.g. int16) would overflow.
        x = np.asarray(events['x'][start:stop], dtype=np.int64)
        y = np.asarray(events['y'][start:stop], dtype=np.int64)
        p = events['p'][start:stop]
        for channel, sel in enumerate((p == 0, p != 0)):
            position = y[sel] * width + x[sel]
            counts = np.bincount(position)
            frames[i, channel, :counts.size] += counts

        if data_type == 'frequency':
            frames[i] /= last_len if is_last else dt
    frames = frames.astype(np.float16)

    if data_type == 'event':
        # builtin bool: the ``np.bool`` alias was removed in NumPy 1.24
        frames = (frames > 0).astype(bool)
    else:
        frames = normalize_frame(frames, 'max')
    return frames.reshape((frames_num, 2, height, width))
69 |
70 |
def normalize_frame(frames: 'np.ndarray | torch.Tensor', normalization: str):
    """Normalize polarity channels 0 and 1 of every frame, in place.

    Args:
        frames: array or tensor indexed as ``frames[i][c]``, with ``i`` over
            frames. Only channels ``c`` in {0, 1} are touched. The data is
            overwritten in place, so pass a float array: integer arrays would
            truncate the result on assignment.
        normalization: selects the per-channel operation.
            - ``'max'``: divide by the channel maximum.
            - ``'norm'``: subtract the mean and divide by the standard deviation.
            - ``'sum'``: divide by the channel sum.
            Each divisor (the variance, for ``'norm'``) is floored at 1e-5 so
            all-zero channels do not divide by zero.

    Returns:
        ``frames`` itself (the same object that was passed in).

    Raises:
        NotImplementedError: for any other ``normalization``, even when
            ``frames`` is empty.
    """
    eps = 1e-5
    # Validate once up front so a bad mode is rejected regardless of input size.
    if normalization not in ('max', 'norm', 'sum'):
        raise NotImplementedError(f'unknown normalization: {normalization!r}')

    for i in range(frames.shape[0]):
        for c in (0, 1):
            channel = frames[i][c]
            if normalization == 'max':
                frames[i][c] = channel / max(channel.max(), eps)
            elif normalization == 'norm':
                frames[i][c] = (channel - channel.mean()) / np.sqrt(max(channel.var(), eps))
            else:  # 'sum'
                frames[i][c] = channel / max(channel.sum(), eps)
    return frames
89 |
90 |
def convert_events_dir_to_frames_dir(events_data_dir, frames_data_dir, suffix,
                                     frames_num=12, result_type='event', thread_num=1,
                                     compress=True):
    """
    Convert every event file under ``events_data_dir`` into a frame file.

    Event files are the ``.npy`` files returned by ``list_all_files``; each is
    loaded with ``np.load(...).item()`` as a dict of event arrays. Each file is
    integrated into ``frames_num`` frames of height 260 and width 346 by
    ``integrate_events_to_frames``, with ``result_type`` passed as its
    ``data_type``.

    The result is written directly into ``frames_data_dir`` under the input's
    base name, with its last ``len(suffix)`` characters replaced:
    - ``.npz`` when ``compress`` is true (``np.savez_compressed``);
    - ``.npy`` otherwise (``np.save``).

    ``frames_data_dir`` is created if it does not exist.

    Args:
        thread_num: number of worker threads. Files are split into contiguous
            chunks of ``len(files) // thread_num``; the last thread also takes
            the remainder.

    Raises:
        Whatever a conversion raises. With several threads, all threads are
        joined first and then the first captured worker exception is re-raised.
    """
    os.makedirs(frames_data_dir, exist_ok=True)
    extension = '.npz' if compress else '.npy'

    def read_function(file_name):
        # NOTE: allow_pickle=True unpickles arbitrary objects -- only use on
        # trusted dataset files.
        return np.load(file_name, allow_pickle=True).item()

    def cvt_fun(events_file_list):
        for events_file in events_file_list:
            print(events_file)
            frames = integrate_events_to_frames(read_function(events_file), 260, 346, frames_num, result_type)
            base = os.path.basename(events_file)
            # base[:len(base) - len(suffix)] also behaves for an empty suffix,
            # where base[0:-0] would produce an empty stem.
            frames_file = os.path.join(frames_data_dir, base[:len(base) - len(suffix)] + extension)
            if compress:
                np.savez_compressed(frames_file, frames)
            else:
                np.save(frames_file, frames)

    # Obtain the path of the all files
    events_file_list = list_all_files(events_data_dir, '.npy')

    if thread_num == 1:
        cvt_fun(events_file_list)
        return

    # Multithreading acceleration
    errors = []

    def guarded_cvt_fun(files):
        try:
            cvt_fun(files)
        except Exception as err:
            errors.append(err)  # surfaced in the calling thread after join
            raise  # keep the default per-thread traceback report

    total = len(events_file_list)
    block = total // thread_num
    thread_list = []
    for i in range(thread_num):
        start = i * block
        end = (i + 1) * block if i < thread_num - 1 else total
        thread_list.append(FunctionThread(guarded_cvt_fun, events_file_list[start:end]))
        thread_list[-1].start()
        print(f'thread {i} start, processing files index: {start} : {end}.')
    for i, thread in enumerate(thread_list):
        thread.join()
        print(f'thread {i} finished.')
    if errors:
        raise errors[0]
133 |
134 |
def convert_aedat4_dir_to_events_dir(root, train):
    """Split raw ``.aedat4`` recordings into one ``.npy`` event file per sample.

    Reads ``root + '/dvs_background/'`` when ``train`` is true, otherwise
    ``root + '/dvs_evaluation/'``. Every non-hidden alphabet folder inside is
    walked as ``character01``, ``character02``, ... (one per entry in the
    folder). Each character folder must hold:
    - an 11-character file name ending in ``dat4`` (the recording);
    - an 8-character ``.csv`` file, whose first four characters become the
      output file prefix.

    For each of the first 20 csv rows, the saved events run from the first
    event whose timestamp equals column 1 up to, but excluding, the first
    event whose timestamp equals column 2. Each slice is saved as a dict of
    numpy arrays ``{'t', 'x', 'y', 'p'}`` to
    ``root + '/events_npy/<kind>/<alphabet>/characterNN/<prefix>_<NN>.npy'``.
    NOTE(review): columns 1/2 are presumably start/end timestamps, and pandas
    consumes the first csv line as a header -- confirm against the csv format.

    Raises:
        IndexError: if the recording or csv is missing, or the csv has fewer
            than 20 rows.
        ValueError: if a csv timestamp does not occur in the recording.
    """
    kind = 'background' if train else "evaluation"
    originroot = root
    root = root + '/dvs_' + kind + '/'
    alphabet_names = [name for name in os.listdir(root) if name[0] != '.']  # skip hidden entries

    for alpha_name in alphabet_names:
        n_characters = len(os.listdir(os.path.join(root, alpha_name)))
        for character_id in range(1, n_characters + 1):
            character_path = alpha_name + '/character' + num2str(character_id)
            print('Parsing %s \\ character%s ...' % (alpha_name, num2str(character_id)))

            file_path = os.path.join(root, character_path)
            entries = os.listdir(file_path)  # list once, filter twice
            aedat4_name = [n for n in entries if n[-4:] == 'dat4' and len(n) == 11][0]
            csv_name = [n for n in entries if n[-4:] == '.csv' and len(n) == 8][0]
            number = csv_name[:4]
            new_path = originroot + '/events_npy/' + kind + '/' + alpha_name + '/character' + num2str(character_id)
            # exist_ok avoids the check-then-create race of exists() + makedirs()
            os.makedirs(new_path, exist_ok=True)

            start_end_timestamp = pandas.read_csv(os.path.join(file_path, csv_name)).values

            a_timestamp, a_polarity, a_x, a_y = [], [], [], []
            with AedatFile(os.path.join(file_path, aedat4_name)) as f:  # read aedat4
                for e in f['events']:
                    a_timestamp.append(e.timestamp)
                    a_polarity.append(e.polarity)
                    a_x.append(e.x)
                    a_y.append(e.y)

            # Convert once per recording; every sample below is a slice of these.
            columns = {'t': np.array(a_timestamp), 'x': np.array(a_x),
                       'y': np.array(a_y), 'p': np.array(a_polarity)}

            for ii in range(20):  # each file has 20 samples
                name = str(number) + '_' + num2str(ii + 1) + '.npy'
                start_index = a_timestamp.index(start_end_timestamp[ii][1])
                end_index = a_timestamp.index(start_end_timestamp[ii][2])
                tmp = {key: values[start_index:end_index] for key, values in columns.items()}
                np.save(os.path.join(new_path, name), tmp)
176 |
177 |
def num2str(idx):
    """Return ``str(idx)``, prefixed with a single ``'0'`` when ``idx < 10``.

    Used to build two-digit folder/file indices such as ``character05``.
    """
    prefix = '0' if idx < 10 else ''
    return prefix + str(idx)
182 |
183 |
def list_all_files(root, suffix, getlen=False):
    '''
    List the paths of all files ending in ``suffix`` under ``root``.

    Walks ``root/<alphabet>/characterNN/``. Alphabets are the non-hidden
    entries of ``root`` in ``os.listdir`` order. Characters are numbered
    1..N, where N is the number of entries in the alphabet folder.

    Args:
        suffix: file-name ending to keep (matched with ``str.endswith``, so
            suffixes of any length work).
        getlen: also return the number of character folders visited.

    Returns:
        list of file paths, or ``(list, number_of_character_folders)`` when
        ``getlen`` is true.
    '''
    file_list = []
    alphabet_names = [a for a in os.listdir(root) if a[0] != '.']  # skip hidden entries
    idx = 0
    for alpha_name in alphabet_names:
        n_characters = len(os.listdir(os.path.join(root, alpha_name)))
        for character_id in range(1, n_characters + 1):
            character_path = os.path.join(root, alpha_name, 'character' + num2str(character_id))
            idx += 1
            # List the folder once rather than once per entry.
            file_list.extend(os.path.join(character_path, fn_example)
                             for fn_example in os.listdir(character_path)
                             if fn_example.endswith(suffix))
    if getlen:
        return file_list, idx
    return file_list
205 |
206 |
def list_class_files(root, frames_kind_root, getlen=False, use_npz=False):
    '''
    Index the generated frame files by character class.

    Walks ``root/<alphabet>/characterNN/``. Alphabets are the non-hidden
    entries of ``root`` in ``os.listdir`` order. Characters are numbered
    1..N, where N is the number of entries in the alphabet folder. Each
    character folder gets the next integer class id, starting at 0.

    For every file name in a character folder, the recorded path is
    ``os.path.join(frames_kind_root, file_name)``: the converted file of the
    same name in the flat frames directory (e.g. fnum_x_dtype_x_npz_True).

    Args:
        getlen: also return the number of classes.
        use_npz: replace the last character of each file name with ``'z'``
            (``.npy`` -> ``.npz``) to point at compressed frame files.

    Returns:
        dict ``{class_id: [paths]}``, or ``(dict, number_of_classes)`` when
        ``getlen`` is true.
    '''
    file_list = {}
    alphabet_names = [a for a in os.listdir(root) if a[0] != '.']  # skip hidden entries
    idx = 0
    for alpha_name in alphabet_names:
        n_characters = len(os.listdir(os.path.join(root, alpha_name)))
        for character_id in range(1, n_characters + 1):
            character_path = os.path.join(root, alpha_name, 'character' + num2str(character_id))
            names = os.listdir(character_path)  # list once, not once per entry
            if use_npz:
                names = [fn_example[:-1] + 'z' for fn_example in names]
            file_list[idx] = [os.path.join(frames_kind_root, fn_example) for fn_example in names]
            idx += 1
    if getlen:
        return file_list, idx
    return file_list
232 |
--------------------------------------------------------------------------------