├── README.md ├── dvs_cifar10_ASF.py ├── dvscifar10_dataloader.py ├── lenet5.py └── vgg7.py /README.md: -------------------------------------------------------------------------------- 1 | # ASF-BP 2 | 3 | The repo contains the code associated with the SNN training method ASF-BP. The code has been tested with Pytorch 1.1.0 and Python 3.7.4. 4 | 5 | ## Testing and Training 6 | To train a new model from scratch, the basic syntax is like: ```python vgg7.py``` 7 | 8 | To test a pre-trained model, the basic syntax is like```python vgg7.py --resume model_bestT1_cifar10_v7.pth.tar --evaluate``` 9 | 10 | ## Reference 11 | Chankyu Lee, Syed Shakib Sarwar, Priyadarshini Panda, Gopalakrishnan Srinivasan, and Kaushik Roy. Enabling spike-based backpropagation for training deep neural network architectures. Frontiers in Neuroscience, 14, 2020. 12 | 13 | -------------------------------------------------------------------------------- /dvs_cifar10_ASF.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | from torch.autograd import Variable 11 | import torch.optim 12 | import torch.utils.data 13 | import torchvision 14 | import torchvision.transforms as transforms 15 | import torchvision.models as models 16 | import copy 17 | import numpy as np 18 | import random 19 | import pdb 20 | from dvscifar10_dataloader import DVSCifar10 21 | 22 | os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2" 23 | 24 | model_names = sorted(name for name in models.__dict__ 25 | if name.islower() and not name.startswith("__") 26 | and callable(models.__dict__[name])) 27 | parser = argparse.ArgumentParser(description='PyTorch DVSCIFAR10 Training') 28 | parser.add_argument('--dataset', default='DVSCIFAR10', type=str, help='dataset = [MNIST]') 29 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', 30 | choices=model_names, 31 | help='model architecture: ' + 32 | ' | '.join(model_names) + 33 | ' (default: resnet18)') 34 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 35 | help='manual epoch number (useful on restarts)') 36 | parser.add_argument('-b', '--batch-size', default=40, type=int, 37 | metavar='N', help='mini-batch size (default: 100)') 38 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 39 | help='momentum') 40 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 41 | metavar='W', help='weight decay (default: 1e-4)') 42 | parser.add_argument('--print-freq', '-p', default=500, type=int, 43 | metavar='N', help='print frequency (default: 10)') 44 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 45 | help='path to latest checkpoint (default: none)') 46 | parser.add_argument('-load', default='', type=str, metavar='PATH', 47 | help='path to training mask (default: none)') 48 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 49 | help='evaluate model on validation set') 50 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 51 | help='use pre-trained model') 52 | parser.add_argument('--lr', '--learning-rate', default=0.0005, type=float, 53 | metavar='LR', help='initial learning rate') 54 | parser.add_argument('-j', '--workers', default=1, type=int, metavar='N', 55 | help='number of data loading workers (default: 4)') 56 | parser.add_argument('--epochs', default=50, type=int, metavar='N', 57 | help='number of total epochs to run') 58 | parser.add_argument('--steps', default=100, type=int, 59 | metavar='N', help='steps (default: 100)') 60 | parser.add_argument('--repeat', default=1, type=int, 61 | metavar='N', help='repeat (default: 1)') 62 | parser.add_argument('--count', default=1, type=int, 63 | help='count (default: 1)') 64 | parser.add_argument('--per', default=0.25, type=float, 65 | metavar='per', help='per (default: 0.25)') 66 | parser.add_argument('--vth', default=1, type=float, 67 | metavar='vth', help='vth (default: 1)') 68 | parser.add_argument('--avgtimes', default=9, type=float, 69 | metavar='avgtimes', help='avgtimes (default: 9)') 70 | parser.add_argument('--leak', default=1, type=float, 71 | metavar='leak', help='leaky parameter (default: 1)') 72 | 73 | best_prec1 = 0 74 | change = 10#50 75 | change2 = 20#75 76 | change3 = 30#100 77 | change4 = 100 78 | change5 = 150 79 | 80 | tp1 = [] 81 | tp5 = [] 82 | ep = [] 83 | lRate = [] 84 | device_num = 1 85 | device = torch.device("cuda:0") 86 | 87 | tp1_tr = [] 88 | tp5_tr = [] 89 | losses_tr = [] 90 | losses_eval = [] 91 | 92 | rep = 1 #*************************** repeat times 93 | 94 | dvscifar10_path = '/data/diospada/dvs-cifar10/dvs-cifar10' 95 | sign = 1 96 | 97 | scale11 = 1 98 | scale12 = 1 99 | scalep1 = 1 100 | scale21 = 1 101 | scale22 = 1 102 | scale23 = 1 103 | scalef0 = 1 104 | args = parser.parse_args() 105 | 106 | 107 | def main(): 108 | global args, best_prec1, batch_size, device_num,sign 109 | seed1 = 44 110 | seed2 = 85 111 | seed3 = 63 112 | batch_size = args.batch_size 113 | steps = args.steps 114 | print('\n'+'='*15+'settings'+'='*15) 115 | print('lr: ', args.lr) 116 | print('steps: ', steps) 117 | print('change lr point:%d %d %d'%(change,change2,change3)) 118 | print('batchsize:',batch_size) 119 | print('vth:{}'.format(args.vth)) 120 | print('repeat:{}'.format(args.repeat)) 121 | print('count={}'.format(args.count)) 122 | print('per={}'.format(args.per)) 123 | print("rand num: %d %d %d"%(seed1,seed2,seed3)) 124 | print('='*15+'settings'+'='*15+'\n') 125 | 126 | 127 | torch.manual_seed(seed1) 128 | torch.cuda.manual_seed(seed2) 129 | torch.cuda.manual_seed_all(seed3) 130 | np.random.seed(seed1) 131 | random.seed(seed2) 132 | cudnn.benchmark = False 133 | cudnn.deterministic = True 134 | 135 | 136 | model = CNNModel() 137 | 138 | print(model) 139 | 140 | model.to(device) 141 | 142 | criterion = torch.nn.MSELoss(reduction='sum') 143 | criterion_en = torch.nn.CrossEntropyLoss() 144 | 145 | learning_rate = args.lr 146 | steps = args.steps 147 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 148 | 149 | # optionally resume from a checkpoint 150 | if args.resume: 151 | if os.path.isfile(args.resume): 152 | print("=> loading checkpoint '{}'".format(args.resume)) 153 | checkpoint = torch.load(args.resume) 154 | args.start_epoch = checkpoint['epoch'] 155 | best_prec1 = checkpoint['best_prec1'] 156 | model.load_state_dict(checkpoint['state_dict']) 157 | optimizer.load_state_dict(checkpoint['optimizer']) 158 | print("=> loaded checkpoint '{}' (epoch {})" 159 | .format(args.resume, checkpoint['epoch'])) 160 | else: 161 | print("=> no checkpoint found at '{}'".format(args.resume)) 162 | 163 | 164 | 165 | # Data loading code different with the mnist dataset 166 | time_read1 = time.time() 167 | train_data = DVSCifar10(dvscifar10_path, train=True, transform=transforms.ToTensor(), target_transform=None,steps=steps, count=args.count, per=args.per) 168 | train_loader = torch.utils.data.DataLoader(train_data, 169 | batch_size=args.batch_size, shuffle=True, 170 | num_workers=args.workers, 171 | pin_memory=True) 172 | time_read1 = time.time() - time_read1 173 | print('read data of testing takes %dh %dmin'%(time_read1/3600, (time_read1%3600)/60)) 174 | 175 | time_read1 = time.time() 176 | val_data = DVSCifar10(dvscifar10_path, train=False,transform=transforms.ToTensor(),steps=steps, count=args.count, per=args.per) 177 | val_loader = torch.utils.data.DataLoader(val_data, 178 | batch_size=int(args.batch_size/2), shuffle=False, 179 | num_workers=args.workers, 180 | pin_memory=False) 181 | time_read1 = time.time() - time_read1 182 | print('read data of testing takes %dh %dmin'%(time_read1/3600, (time_read1%3600)/60)) 183 | if args.evaluate: 184 | validate(val_loader, model, criterion, criterion_en, time_steps=args.steps, leak=args.leak) 185 | return 186 | 187 | for epoch in range(args.start_epoch, args.epochs): 188 | if epoch % 5 == 0 and args.hz < args.epochs: 189 | sign = 1 190 | else: 191 | sign = 0 192 | 193 | start = time.time() 194 | adjust_learning_rate(optimizer, epoch) 195 | 196 | ep.append(epoch) 197 | 198 | # train for one epoch 199 | train(train_loader, model, criterion, criterion_en, optimizer, epoch, time_steps=steps, leak=args.leak) 200 | 201 | # evaluate on validation set 202 | modeltest = model 203 | prec1 = validate(val_loader, modeltest, criterion, criterion_en, time_steps=steps, leak=args.leak) 204 | 205 | # remember best prec@1 and save checkpoint 206 | is_best = prec1 > best_prec1 207 | best_prec1 = max(prec1, best_prec1) 208 | # save_checkpoint({ 209 | # 'epoch': epoch + 1, 210 | # 'arch': args.arch, 211 | # 'state_dict': model.state_dict(), 212 | # 'best_prec1': best_prec1, 213 | # 'optimizer': optimizer.state_dict(), 214 | # }, is_best) 215 | time_use = time.time() - start 216 | print('time used this epoch: %d h%dmin%ds' %(time_use//3600,(time_use%3600)//60,time_use%60)) 217 | 218 | if sign == 1: 219 | print('\n'+'='*15+'scale'+'='*15) 220 | print('scale11: ', scale11) 221 | print('scale12: ', scale12) 222 | print('scalep1: ', scalep1) 223 | print('scale21: ', scale21) 224 | print('scale22: ', scale22) 225 | print('scale23: ', scale23) 226 | print('scalef0: ', scalef0) 227 | print('='*15+'scale'+'='*15+'\n') 228 | 229 | for k in range(0, args.epochs - args.start_epoch): 230 | print('Epoch: [{0}/{1}]\t' 231 | 'LR:{2}\t' 232 | 'Prec@1 {top1:.3f} \t' 233 | 'Prec@5 {top5:.3f} \t' 234 | 'En_Loss_Eval {losses_en_eval: .4f} \t' 235 | 'Prec@1_tr {top1_tr:.3f} \t' 236 | 'Prec@5_tr {top5_tr:.3f} \t' 237 | 'En_Loss_train {losses_en: .4f}'.format( 238 | ep[k], args.epochs, lRate[k], top1=tp1[k], top5=tp5[k], losses_en_eval=losses_eval[k], top1_tr=tp1_tr[k], 239 | top5_tr=tp5_tr[k], losses_en=losses_tr[k])) 240 | print('best_acc{}'.format(best_prec1)) 241 | 242 | 243 | def print_view(v): 244 | v = v.view(v.size(0), -1) 245 | j = 0 246 | for i in v[0]: 247 | print(i) 248 | j = j + 1 249 | print(j) 250 | 251 | def grad_cal(scale, IF_in): 252 | out = scale * IF_in.gt(0).type(torch.cuda.FloatTensor) 253 | return out 254 | 255 | def ave(output, input): 256 | c = input >= output 257 | if input[c].sum() < 1e-3: 258 | return 1 259 | return output[c].sum()/input[c].sum() 260 | 261 | def ave_p(output, input): 262 | if input.sum() < 1e-3: 263 | return 1 264 | return output.sum()/input.sum() 265 | 266 | 267 | def train(train_loader, model, criterion, criterion_en, optimizer, epoch, time_steps, leak): 268 | # print('train start') 269 | batch_time = AverageMeter() 270 | data_time = AverageMeter() 271 | losses = AverageMeter() 272 | top1 = AverageMeter() 273 | top5 = AverageMeter() 274 | 275 | top1_tr = AverageMeter() 276 | top5_tr = AverageMeter() 277 | losses_en = AverageMeter() 278 | 279 | # switch to train mode 280 | model.train() 281 | 282 | end = time.time() 283 | start_end = end 284 | # print ('mark1',train_loader.sampler) 285 | for i, (input, target) in enumerate(train_loader): 286 | # measure data loading time 287 | data_time.update(time.time() - end) 288 | input, target = input.to(device), target.to(device) 289 | labels = target.clone() 290 | optimizer.zero_grad() # Clear gradients w.r.t. parameters 291 | 292 | output = model(input, steps=time_steps, l=leak) 293 | 294 | targetN = output.data.clone().zero_().to(device) 295 | targetN.scatter_(1, target.unsqueeze(1), 1) 296 | targetN = Variable(targetN.type(torch.cuda.FloatTensor)) 297 | 298 | loss = criterion(output.cpu(), targetN.cpu()) 299 | loss_en = criterion_en(output.cpu(), labels.cpu()) 300 | 301 | # measure accuracy and record loss 302 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 303 | losses.update(loss.item(), input.size(0)) 304 | top1.update(prec1.item(), input.size(0)) 305 | top5.update(prec5.item(), input.size(0)) 306 | 307 | prec1_tr, prec5_tr = accuracy(output.data, target, topk=(1, 5)) 308 | losses_en.update(loss_en.item(), input.size(0)) 309 | top1_tr.update(prec1_tr.item(), input.size(0)) 310 | top5_tr.update(prec5_tr.item(), input.size(0)) 311 | 312 | # compute gradient and do SGD step 313 | loss.backward(retain_graph=False) 314 | optimizer.step() 315 | 316 | # measure elapsed time 317 | batch_time.update(time.time() - end) 318 | end = time.time() 319 | 320 | 321 | print('Epoch: [{0}] Prec@1 {top1_tr.avg:.3f} Prec@5 {top5_tr.avg:.3f} Entropy_Loss {loss_en.avg:.4f}' 322 | .format(epoch, top1_tr=top1_tr, top5_tr=top5_tr, loss_en=losses_en)) 323 | time_use = end - start_end 324 | print('train time: %d h%dmin%ds' %(time_use//3600,(time_use%3600)//60,time_use%60)) 325 | 326 | losses_tr.append(losses_en.avg) 327 | tp1_tr.append(top1_tr.avg) 328 | tp5_tr.append(top5_tr.avg) 329 | 330 | 331 | def validate(val_loader, model, criterion, criterion_en, time_steps, leak): 332 | # validate start 333 | batch_time = AverageMeter() 334 | data_time = AverageMeter() 335 | losses = AverageMeter() 336 | top1 = AverageMeter() 337 | top5 = AverageMeter() 338 | losses_en_eval = AverageMeter() 339 | 340 | # switch to evaluate mode 341 | model.eval() 342 | 343 | end = time.time() 344 | with torch.no_grad(): 345 | for i, (input, target) in enumerate(val_loader): 346 | # measure data loading time 347 | data_time.update(time.time() - end) 348 | input_var = input.to(device) 349 | labels = Variable(target.to(device)) 350 | target = target.to(device) 351 | output = model.tst(input=input_var, steps=time_steps, l=leak) 352 | 353 | 354 | targetN = output.data.clone().zero_().to(device) 355 | targetN.scatter_(1, target.unsqueeze(1), 1) 356 | targetN = Variable(targetN.type(torch.cuda.FloatTensor)) 357 | loss = criterion(output.cpu(), targetN.cpu()) 358 | loss_en = criterion_en(output.cpu(), labels.cpu()) 359 | 360 | # measure accuracy and record loss 361 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 362 | losses.update(loss.item(), input.size(0)) 363 | top1.update(prec1.item(), input.size(0)) 364 | top5.update(prec5.item(), input.size(0)) 365 | losses_en_eval.update(loss_en.item(), input.size(0)) 366 | 367 | # measure elapsed time 368 | batch_time.update(time.time() - end) 369 | end = time.time() 370 | 371 | print('Test: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Entropy_Loss {losses_en_eval.avg:.4f}' 372 | .format(top1=top1, top5=top5, losses_en_eval=losses_en_eval)) 373 | 374 | tp1.append(top1.avg) 375 | tp5.append(top5.avg) 376 | losses_eval.append(losses_en_eval.avg) 377 | 378 | return top1.avg 379 | 380 | 381 | def save_checkpoint(state, is_best, filename='checkpointT1_dvscifar10_v7.pth.tar'): 382 | torch.save(state, filename) 383 | if is_best: 384 | shutil.copyfile(filename, 'model_bestT1_dvscifar10_v7.pth.tar') 385 | 386 | 387 | class AverageMeter(object): 388 | """Computes and stores the average and current value""" 389 | 390 | def __init__(self): 391 | self.reset() 392 | 393 | def reset(self): 394 | self.val = 0 395 | self.avg = 0 396 | self.sum = 0 397 | self.count = 0 398 | 399 | def update(self, val, n=1): 400 | self.val = val 401 | self.sum += val * n 402 | self.count += n 403 | self.avg = self.sum / self.count 404 | 405 | 406 | def adjust_learning_rate(optimizer, epoch): 407 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 408 | lr = args.lr * (1 ** (epoch // change)) 409 | 410 | for param_group in optimizer.param_groups: 411 | if epoch >= change5: 412 | param_group['lr'] = 0.5 * 0.5 * 0.2 * 0.2 * 0.2 * lr 413 | 414 | elif epoch >= change4: 415 | param_group['lr'] = 0.5 * 0.2 * 0.2 * 0.2 * lr 416 | 417 | elif epoch >= change3: 418 | param_group['lr'] = 0.2 * 0.2 * 0.2 * lr 419 | 420 | elif epoch >= change2: 421 | param_group['lr'] = 0.2 * 0.2 * lr 422 | 423 | elif epoch >= change: 424 | param_group['lr'] = 0.2 * lr 425 | 426 | else: 427 | param_group['lr'] = lr 428 | 429 | lRate.append(param_group['lr']) 430 | 431 | 432 | def accuracy(output, target, topk=(1,)): 433 | """Computes the precision@k for the specified values of k""" 434 | maxk = max(topk) 435 | batch_size = target.size(0) 436 | 437 | _, pred = output.topk(maxk, 1, True, True) 438 | pred = pred.t() 439 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 440 | 441 | res = [] 442 | for k in topk: 443 | correct_k = correct[:k].view(-1).float().sum(0) 444 | res.append(correct_k.mul_(100.0 / batch_size)) 445 | return res 446 | 447 | 448 | class SpikingNN(torch.autograd.Function): 449 | def forward(self, input): 450 | self.save_for_backward(input) 451 | return input.gt(0).type(torch.cuda.FloatTensor) 452 | 453 | def backward(self, grad_output): 454 | input, = self.saved_tensors 455 | grad_input = grad_output.clone() 456 | grad_input[input <= 0.0] = 0 457 | return grad_input 458 | 459 | 460 | def LIF_sNeuron(membrane_potential, threshold, l, i): 461 | # check exceed membrane potential and reset 462 | ex_membrane = nn.functional.threshold(membrane_potential, threshold, 0) 463 | membrane_potential = membrane_potential - ex_membrane # hard reset 464 | # generate spike 465 | out = SpikingNN()(ex_membrane) 466 | # decay 467 | membrane_potential = l * membrane_potential.detach() + membrane_potential - membrane_potential.detach() 468 | # out = out.detach() + torch.div(out, threshold) - torch.div(out, threshold).detach() 469 | 470 | return membrane_potential, out 471 | 472 | 473 | def Pooling_sNeuron(membrane_potential, threshold, i): 474 | # check exceed membrane potential and reset 475 | ex_membrane = nn.functional.threshold(membrane_potential, threshold, 0) 476 | membrane_potential = membrane_potential - ex_membrane 477 | # generate spike 478 | out = SpikingNN()(ex_membrane) 479 | 480 | return membrane_potential, out 481 | 482 | 483 | class CNNModel(nn.Module): 484 | def __init__(self): 485 | super(CNNModel, self).__init__() 486 | self.maxpool4 = nn.AvgPool2d(kernel_size=3) 487 | # self.maxpool4 = nn.MaxPool2d(kernel_size=3) 488 | 489 | self.cnn11 = nn.Conv2d(in_channels=2, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 490 | self.cnn12 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 491 | #self.BN1 = nn.BatchNorm2d(64) 492 | 493 | self.avgpool1 = nn.AvgPool2d(kernel_size=2) 494 | 495 | self.cnn21 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False) 496 | self.cnn22 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False) 497 | self.cnn23 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) 498 | #self.BN2 = nn.BatchNorm2d(128) 499 | 500 | self.avgpool2 = nn.MaxPool2d(kernel_size=2) 501 | 502 | self.fc0 = nn.Linear(256 * 10 * 10, 1024, bias=False) 503 | self.fc1 = nn.Linear(1024, 10, bias=False) 504 | 505 | for m in self.modules(): 506 | if isinstance(m, nn.Conv2d): 507 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 508 | variance1 = math.sqrt(2.0 / n) 509 | m.weight.data.normal_(0, variance1) 510 | # define threshold 511 | m.threshold = args.vth 512 | 513 | elif isinstance(m, nn.Linear): 514 | size = m.weight.size() 515 | fan_in = size[1] # number of columns 516 | variance2 = math.sqrt(2.0 / fan_in) 517 | m.weight.data.normal_(0.0, variance2) 518 | # define threshold 519 | m.threshold = args.vth 520 | 521 | def forward(self, input, steps=100, l=1): 522 | global scale11,scale12,scale21,scale22,scale23,scalef0,scalep1,sign 523 | # mask used for let some element change to 0 524 | mem_11 = torch.zeros(input.size(0), 64, 42, 42, device = input.device) 525 | 526 | mem_12 = torch.zeros(input.size(0), 64, 42, 42, device = input.device) 527 | 528 | mem_1s = torch.zeros(input.size(0), 64, 21, 21, device = input.device) 529 | 530 | mem_21 = torch.zeros(input.size(0), 128, 21, 21, device = input.device) 531 | mem_22 = torch.zeros(input.size(0), 128, 21, 21, device = input.device) 532 | mem_23 = torch.zeros(input.size(0), 128*2, 21, 21, device = input.device) 533 | 534 | membrane_f0 = torch.zeros(input.size(0), 1024, device = input.device) 535 | 536 | Total_input = torch.zeros(input.size(0), 2, 42, 42, device = input.device) 537 | 538 | Total_11_output = torch.zeros(input.size(0), 64, 42, 42, device = input.device) 539 | IF_11_in = torch.zeros(input.size(0), 64, 42, 42, device = input.device) 540 | 541 | Total_12_output = torch.zeros(input.size(0), 64, 42, 42, device = input.device) 542 | IF_12_in = torch.zeros(input.size(0), 64, 42, 42, device = input.device) 543 | 544 | Total_p1_output = torch.zeros(input.size(0), 64, 21, 21,device = input.device) 545 | IF_p1_in = torch.zeros(input.size(0), 64, 21, 21, device = input.device) 546 | 547 | Total_21_output = torch.zeros(input.size(0), 128, 21, 21, device = input.device) 548 | IF_21_in = torch.zeros(input.size(0), 128, 21, 21, device = input.device) 549 | 550 | Total_22_output = torch.zeros(input.size(0), 128, 21, 21, device = input.device) 551 | IF_22_in = torch.zeros(input.size(0), 128, 21, 21, device = input.device) 552 | 553 | Total_23_output = torch.zeros(input.size(0), 128*2, 21, 21, device = input.device) 554 | IF_23_in = torch.zeros(input.size(0), 128*2, 21, 21, device = input.device) 555 | 556 | 557 | Total_f0_output = torch.zeros(input.size(0), 1024, device = input.device) 558 | IF_f0_in = torch.zeros(input.size(0), 1024, device = input.device) 559 | 560 | with torch.no_grad(): 561 | for i in range(steps * args.repeat): 562 | # Get input using frames 563 | eventframe_input = input[:, :, :, :, i % steps].float() 564 | eventframe_input = args.avgtimes * self.maxpool4(eventframe_input) 565 | Total_input = Total_input + eventframe_input 566 | 567 | # convolutional Layer 568 | in_layer = self.cnn11(eventframe_input) 569 | mem_11 = mem_11 + in_layer 570 | mem_11, out = LIF_sNeuron(mem_11, self.cnn11.threshold, l, i) 571 | IF_11_in = IF_11_in + in_layer 572 | Total_11_output = Total_11_output + out 573 | 574 | in_layer = self.cnn12(out) 575 | mem_12 = mem_12 + in_layer 576 | mem_12, out = LIF_sNeuron(mem_12, self.cnn12.threshold, l, i) 577 | IF_12_in = IF_12_in + in_layer 578 | Total_12_output = Total_12_output + out 579 | 580 | # pooling Layer 581 | in_layer = self.avgpool1(out) 582 | mem_1s = mem_1s + in_layer 583 | mem_1s, out = Pooling_sNeuron(mem_1s, 0.75, i) 584 | IF_p1_in = IF_p1_in + in_layer 585 | Total_p1_output = Total_p1_output + out 586 | 587 | # convolutional Layer 588 | in_layer = self.cnn21(out) 589 | mem_21 = mem_21 + in_layer 590 | mem_21, out = LIF_sNeuron(mem_21, self.cnn21.threshold, l, i) 591 | IF_21_in = IF_21_in + in_layer 592 | Total_21_output = Total_21_output + out 593 | 594 | in_layer = self.cnn22(out) 595 | mem_22 = mem_22 + in_layer 596 | mem_22, out = LIF_sNeuron(mem_22, self.cnn22.threshold, l, i) 597 | IF_22_in = IF_22_in + in_layer 598 | Total_22_output = Total_22_output + out 599 | 600 | in_layer = self.cnn23(out) 601 | mem_23 = mem_23 + in_layer 602 | mem_23, out = LIF_sNeuron(mem_23, self.cnn23.threshold, l, i) 603 | IF_23_in = IF_23_in + in_layer 604 | Total_23_output = Total_23_output + out 605 | 606 | # Maxpooling Layer 607 | out = self.avgpool2(out) 608 | 609 | out = out.view(out.size(0), -1) 610 | 611 | # fully-connected Layer 612 | in_layer = self.fc0(out) 613 | membrane_f0 = membrane_f0 + in_layer 614 | membrane_f0, out = LIF_sNeuron(membrane_f0, self.fc0.threshold, l, i) 615 | IF_f0_in = IF_f0_in + in_layer 616 | Total_f0_output = Total_f0_output + out 617 | 618 | if sign == 1: 619 | scalef0 = 0.6 * ave(Total_f0_output, IF_f0_in) + 0.4 * scalef0 620 | scale11 = 0.6 * ave(Total_11_output, IF_11_in) + 0.4 * scale11 621 | scale12 = 0.6 * ave(Total_12_output, IF_12_in) + 0.4 * scale12 622 | scalep1 = 0.6 * ave_p(Total_p1_output, IF_p1_in) + 0.4 * scalep1 623 | scale21 = 0.6 * ave(Total_21_output, IF_21_in) + 0.4 * scale21 624 | scale22 = 0.6 * ave(Total_22_output, IF_22_in) + 0.4 * scale22 625 | scale23 = 0.6 * ave(Total_23_output, IF_23_in) + 0.4 * scale23 626 | 627 | scale_f0 = grad_cal(scalef0, IF_f0_in) 628 | scale_11 = grad_cal(scale11, IF_11_in) 629 | scale_12 = grad_cal(scale12, IF_12_in) 630 | scale_p1 = grad_cal(scalep1, IF_p1_in) 631 | scale_21 = grad_cal(scale21, IF_21_in) 632 | scale_22 = grad_cal(scale22, IF_22_in) 633 | scale_23 = grad_cal(scale23, IF_23_in) 634 | 635 | with torch.enable_grad(): 636 | cnn11_in = self.cnn11(Total_input.detach()) 637 | tem = Total_11_output.detach() 638 | out = torch.mul(cnn11_in, scale_11) 639 | Total_11_output = out - out.detach() + tem 640 | 641 | cnn12_in = self.cnn12(Total_11_output) 642 | tem = Total_12_output.detach() 643 | out = torch.mul(cnn12_in, scale_12) 644 | Total_12_output = out - out.detach() + tem 645 | 646 | pool1_in = self.avgpool1(Total_12_output) 647 | tem = Total_p1_output.detach() 648 | out = torch.mul(pool1_in, scale_p1) 649 | Total_p1_output = out - out.detach() + tem 650 | 651 | cnn21_in = self.cnn21(Total_p1_output) 652 | tem = Total_21_output.detach() 653 | out = torch.mul(cnn21_in, scale_21) 654 | Total_21_output = out - out.detach() + tem 655 | 656 | cnn22_in = self.cnn22(Total_21_output) 657 | tem = Total_22_output.detach() 658 | out = torch.mul(cnn22_in, scale_22) 659 | Total_22_output = out - out.detach() + tem 660 | 661 | cnn23_in = self.cnn23(Total_22_output) 662 | tem = Total_23_output.detach() 663 | out = torch.mul(cnn23_in, scale_23) 664 | Total_23_output = out - out.detach() + tem 665 | 666 | Total_p2_output = self.avgpool2(Total_23_output) 667 | 668 | fc0_in = self.fc0(Total_p2_output.view(Total_p2_output.size(0),-1)) 669 | tem = Total_f0_output.detach() 670 | out = torch.mul(fc0_in, scale_f0) 671 | Total_f0_output = out - out.detach() + tem 672 | 673 | fc1_in = self.fc1(Total_f0_output) 674 | 675 | return fc1_in/self.fc1.threshold/steps 676 | 677 | 678 | def tst(self, input, steps=100, l=1): 679 | mem_11 = torch.zeros(input.size(0), 64, 42, 42,device = input.device) 680 | mem_12 = torch.zeros(input.size(0), 64, 42, 42,device = input.device) 681 | mem_1s = torch.zeros(input.size(0), 64, 21, 21,device = input.device) 682 | 683 | mem_21 = torch.zeros(input.size(0), 128, 21, 21,device = input.device) 684 | mem_22 = torch.zeros(input.size(0), 128, 21, 21,device = input.device) 685 | mem_23 = torch.zeros(input.size(0), 128*2, 21, 21,device = input.device) 686 | 687 | membrane_f0 = torch.zeros(input.size(0), 1024,device = input.device) 688 | membrane_f1 = torch.zeros(input.size(0), 10,device = input.device) 689 | 690 | for i in range(steps * args.repeat): 691 | eventframe_input = (input[:,:,:,:, (i%steps)]).float() 692 | eventframe_input = args.avgtimes * self.maxpool4(eventframe_input) 693 | 694 | # convolutional Layer 695 | mem_11 = mem_11 + self.cnn11(eventframe_input) 696 | mem_11, out = LIF_sNeuron(mem_11, self.cnn11.threshold, l, i) 697 | 698 | mem_12 = mem_12 + self.cnn12(out) 699 | mem_12, out = LIF_sNeuron(mem_12, self.cnn12.threshold, l, i) 700 | 701 | # pooling Layer 702 | mem_1s = mem_1s + self.avgpool1(out) 703 | mem_1s, out = Pooling_sNeuron(mem_1s, 0.75, i) 704 | 705 | # convolutional Layer 706 | mem_21 = mem_21 + self.cnn21(out) 707 | mem_21, out = LIF_sNeuron(mem_21, self.cnn21.threshold, l, i) 708 | 709 | mem_22 = mem_22 + self.cnn22(out) 710 | mem_22, out = LIF_sNeuron(mem_22, self.cnn22.threshold, l, i) 711 | 712 | mem_23 = mem_23 + self.cnn23(out) 713 | mem_23, out = LIF_sNeuron(mem_23, self.cnn23.threshold, l, i) 714 | 715 | # pooling Layer 716 | out = self.avgpool2(out) 717 | 718 | 719 | out = out.view(out.size(0), -1) 720 | 721 | # fully-connected Layer 722 | membrane_f0 = membrane_f0 + self.fc0(out) 723 | membrane_f0, out = LIF_sNeuron(membrane_f0, self.fc0.threshold, l, i) 724 | 725 | membrane_f1 = membrane_f1 + self.fc1(out) 726 | 727 | return membrane_f1 / self.fc1.threshold / steps 728 | 729 | if __name__ == '__main__': 730 | main() 731 | -------------------------------------------------------------------------------- /dvscifar10_dataloader.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import numpy as np 3 | import scipy.misc 4 | import h5py 5 | import glob 6 | import tqdm 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader 9 | import os 10 | import scipy.io as scio 11 | import sys 12 | import pdb 13 | import argparse 14 | import time 15 | import gc 16 | 17 | mapping = { 0 :'airplane' , 18 | 1 :'automobile', 19 | 2 :'bird' , 20 | 3 :'cat' , 21 | 4 :'deer' , 22 | 5 :'dog' , 23 | 6 :'frog' , 24 | 7 :'horse' , 25 | 8 :'ship' , 26 | 9 :'truck' } 27 | 28 | def connt_to_binary(tem_path1, per): 29 | data = np.load(tem_path1, allow_pickle=True) 30 | data_plus = data[data > 0] 31 | data_plus = np.sort(data_plus) 32 | lower_q = np.quantile(data_plus, per, interpolation='lower') 33 | del data_plus 34 | gc.collect() 35 | return (data >= lower_q).astype(np.int8) 36 | 37 | 38 | class DVSCifar10(Dataset): 39 | def __init__(self, root, train=True, transform=None, target_transform=None, steps=100, count=None,per=0.25): 40 | self.root = os.path.expanduser(root) 41 | self.transform = transform 42 | self.target_transform = target_transform 43 | self.train = train 44 | self.steps = steps 45 | self.count = count 46 | self.per = per 47 | path_converted = os.path.join(root, 'converted') 48 | if not os.path.exists(path_converted): 49 | os.mkdir(path_converted) 50 | if count: 51 | tem_path = os.path.join(path_converted, 'steps{}_count'.format(steps)) 52 | else: 53 | tem_path = os.path.join(path_converted, 'steps{}_binary_per{}'.format(steps,per)) 54 | print(tem_path) 55 | 56 | if not os.path.exists(tem_path): 57 | os.mkdir(tem_path) 58 | 59 | 60 | if self.train: 61 | tem_path1 = os.path.join(tem_path, 'train.npy') 62 | tem_path2 = os.path.join(tem_path, 'train_label.npy') 63 | 64 | if (not os.path.exists(tem_path1)) | (not os.path.exists(tem_path2)): 65 | print('dataset not found => creating...') 66 | if count: 67 | time2 = time.time() 68 | pre_process(raw_data_path=root,steps=steps,count=count,threshold=per) 69 | time2 = time.time() - time2 70 | print('create frame image takes %dh %dmin'%(time2/3600, (time2%3600)/60)) 71 | self.data = np.load(tem_path1, allow_pickle=True) 72 | self.targets = np.load(tem_path2, allow_pickle=True) 73 | print('load frame image for train successfully') 74 | else: 75 | count_path = os.path.join(path_converted, 'steps{}_count'.format(steps)) 76 | count_path1 = os.path.join(count_path, 'train.npy') 77 | count_path2 = os.path.join(count_path, 'train_label.npy') 78 | 79 | if (not os.path.exists(count_path1)) | (not os.path.exists(count_path2)): 80 | pre_process(raw_data_path=root,steps=steps,count=count,threshold=per) 81 | 82 | self.data = connt_to_binary(count_path1, per) 83 | self.targets = np.load(count_path2, allow_pickle=True) 84 | np.save(os.path.join(root, 'converted', 'steps{}_binary_per{}'.format(steps,per), 'train.npy'), self.data) 85 | np.save(os.path.join(root, 'converted', 'steps{}_binary_per{}'.format(steps,per), 'train_label.npy'), self.targets) 86 | print('load frame image for train successfully') 87 | 88 | else: 89 | self.data = np.load(tem_path1, allow_pickle=True) 90 | self.targets = np.load(tem_path2, allow_pickle=True) 91 | print('load frame image for train successfully') 92 | else: 93 | tem_path1 = os.path.join(tem_path, 'test.npy') 94 | tem_path2 = os.path.join(tem_path, 'test_label.npy') 95 | 96 | if (not os.path.exists(tem_path1)) | (not os.path.exists(tem_path2)): 97 | print('dataset not found => creating...') 98 | if count: 99 | time2 = time.time() 100 | pre_process(raw_data_path=root,steps=steps,count=count,threshold=per) 101 | time2 = time.time() - time2 102 | print('create frame image takes %dh %dmin'%(time2/3600, (time2%3600)/60)) 103 | self.data = np.load(tem_path1, allow_pickle=True) 104 | self.targets = np.load(tem_path2, allow_pickle=True) 105 | print('load frame image for test successfully') 106 | else: 107 | count_path = os.path.join(path_converted, 'steps{}_count'.format(steps)) 108 | count_path1 = os.path.join(count_path, 'test.npy') 109 | count_path2 = os.path.join(count_path, 'test_label.npy') 110 | 111 | if (not os.path.exists(count_path1)) | (not os.path.exists(count_path2)): 112 | pre_process(raw_data_path=root,steps=steps,count=count,threshold=per) 113 | 114 | self.data = connt_to_binary(count_path1, per) 115 | 116 | self.targets = np.load(count_path2, allow_pickle=True) 117 | np.save(os.path.join(root, 'converted', 'steps{}_binary_per{}'.format(steps,per), 'test.npy'), self.data) 118 | np.save(os.path.join(root, 'converted', 'steps{}_binary_per{}'.format(steps,per), 'test_label.npy'), self.targets) 119 | print('load binary frame image for test successfully') 120 | 121 | else: 122 | self.data = np.load(tem_path1, allow_pickle=True) 123 | self.targets = np.load(tem_path2, allow_pickle=True) 124 | print('load frame image for test successfully') 125 | 126 | 127 | 128 | def __getitem__(self, index): 129 | """ 130 | Args: 131 | index (int): Index 132 | Returns: 133 | tuple: (image, target) where target is index of the target class. 134 | """ 135 | img = self.data[index] 136 | target = (self.targets[index]) 137 | img = torch.from_numpy(img) 138 | target = torch.tensor(target) 139 | 140 | return (img, target) 141 | 142 | def __len__(self): 143 | return len(self.data) 144 | 145 | 146 | 147 | def pre_process(raw_data_path, steps=100, count=True, threshold=None): 148 | print('loading event data') 149 | train_data,test_data,train_label,test_label = import_dvscifar10(raw_data_path) 150 | print('start pre-processing') 151 | num_train = len(train_data) 152 | num_test = len(test_data) 153 | 154 | # Init the frame data 155 | train_frame_data = np.zeros([num_train, 2, 128, 128, steps], dtype=np.int8) 156 | for index,events in enumerate(train_data): 157 | if (index + 1) % 100 == 0: 158 | print("\r\tProcessing train data: {:.2f}% complete\t\t".format((index+1) / 90), end='') 159 | p = events[:, 3] 160 | x = events[:, 1] 161 | y = events[:, 2] 162 | ts = events[:, 0] 163 | step_len = ts[-1] // steps 164 | 165 | p_on = (p==1) 166 | p_off = (p==0) 167 | 168 | x_on = x[p_on] 169 | y_on = y[p_on] 170 | ts_on = ts[p_on] 171 | 172 | x_off = x[p_off] 173 | y_off = y[p_off] 174 | ts_off = ts[p_off] 175 | 176 | for j in range(steps): 177 | ts_range = np.where((ts_on >= step_len * j) & (ts_on < step_len * (j + 1))) 178 | for x1, y1 in zip(x_on[ts_range], y_on[ts_range]): 179 | train_frame_data[index, 1, x1, y1, j] += 1 180 | 181 | ts_range = np.where((ts_off >= step_len * j) & (ts_off < step_len * (j + 1))) 182 | for x1, y1 in zip(x_off[ts_range], y_off[ts_range]): 183 | train_frame_data[index, 0, x1, y1, j] += 1 184 | del events, p, x, y, ts, p_off, p_on, x_off, x_on, y_off, y_on, ts_off, ts_on 185 | gc.collect() 186 | if count: 187 | np.save(os.path.join(raw_data_path, 'converted', 'steps{}_count'.format(steps), 'train.npy'), train_frame_data) 188 | np.save(os.path.join(raw_data_path, 'converted', 'steps{}_count'.format(steps), 'train_label.npy'), train_label) 189 | else: 190 | np.save(os.path.join(raw_data_path, 'converted', 'steps{}_binary_th{}'.format(steps,per), 'train.npy'), train_frame_data) 191 | np.save(os.path.join(raw_data_path, 'converted', 'steps{}_binary_th{}'.format(steps,per), 'train_label.npy'), train_label) 192 | del train_frame_data, train_label, train_data 193 | gc.collect() 194 | 195 | 196 | test_frame_data = np.zeros([num_test, 2, 128, 128, steps], dtype=np.int8) 197 | for index,events in enumerate(test_data): 198 | if (index + 1) % 100 == 0: 199 | print("\r\tProcessing test data: {:.2f}% complete\t\t".format((index+1) / 10), end='') 200 | p = events[:, 3] 201 | x = events[:, 1] 202 | y = events[:, 2] 203 | ts = events[:, 0] 204 | step_len = ts[-1] // steps 205 | 206 | p_on = (p==1) 207 | p_off = (p==0) 208 | 209 | x_on = x[p_on] 210 | y_on = y[p_on] 211 | ts_on = ts[p_on] 212 | 213 | x_off = x[p_off] 214 | y_off = y[p_off] 215 | ts_off = ts[p_off] 216 | 217 | for j in range(steps): 218 | ts_range = np.where((ts_on >= step_len * j) & (ts_on < step_len * (j + 1))) 219 | for x1, y1 in zip(x_on[ts_range], y_on[ts_range]): 220 | test_frame_data[index, 1, x1, y1, j] += 1 221 | 222 | ts_range = np.where((ts_off >= step_len * j) & (ts_off < step_len * (j + 1))) 223 | for x1, y1 in zip(x_off[ts_range], y_off[ts_range]): 224 | test_frame_data[index, 0, x1, y1, j] += 1 225 | del events, p, x, y, ts, p_off, p_on, x_off, x_on, y_off, y_on, ts_off, ts_on 226 | gc.collect() 227 | if count: 228 | np.save(os.path.join(raw_data_path, 'converted', 'steps{}_count'.format(steps), 'test.npy'), test_frame_data) 229 | np.save(os.path.join(raw_data_path, 'converted', 'steps{}_count'.format(steps), 'test_label.npy'), test_label) 230 | else: 231 | np.save(os.path.join(raw_data_path, 'converted', 'steps{}_binary_th{}'.format(steps,per), 'test.npy'), test_frame_data) 232 | np.save(os.path.join(raw_data_path, 'converted', 'steps{}_binary_th{}'.format(steps,per), 'test_label.npy'), test_label) 233 | del test_frame_data, test_label,test_data 234 | gc.collect() 235 | 236 | return 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | def import_dvscifar10(raw_data_path): 245 | events_path = os.path.join(raw_data_path, 'events') 246 | if not os.path.exists(events_path): 247 | os.mkdir(events_path) 248 | path1 = os.path.join(events_path, 'train_data.npy') 249 | path2 = os.path.join(events_path, 'test_data.npy') 250 | path3 = os.path.join(events_path, 'train_label.npy') 251 | path4 = os.path.join(events_path, 'test_label.npy') 252 | 253 | if os.path.exists(path1) & os.path.exists(path2) & os.path.exists(path3) & os.path.exists(path4): 254 | train_data = np.load(path1, allow_pickle=True) 255 | test_data = np.load(path2, allow_pickle=True) 256 | train_label = np.load(path3, allow_pickle=True) 257 | test_label = np.load(path4, allow_pickle=True) 258 | else: 259 | print('event data not found => creating...') 260 | time1 = time.time() 261 | train_data, test_data, train_label, test_label = create_events(raw_data_path) 262 | time1 = time.time() - time1 263 | print('create event data takes %dh %dmin'%(time1/3600, (time1%3600)/60)) 264 | return train_data, test_data, train_label, test_label 265 | 266 | 267 | 268 | 269 | def create_events(raw_data_path): 270 | train_data = [] 271 | test_data = [] 272 | train_label = [] 273 | test_label = [] 274 | index = np.arange(1000) 275 | test_index = np.random.choice(index.shape[0],100,replace=False) 276 | train_index = np.delete(index, test_index) 277 | 278 | print("processing raw training data...") 279 | key = 1 280 | for i in range(10): 281 | current_path = os.path.join(raw_data_path, mapping[i]) 282 | for fn in train_index: 283 | filename = os.path.join(current_path, "{}.mat".format(fn)) 284 | events = scio.loadmat(filename, verify_compressed_data_integrity=False)['out1'].astype(np.int64)#astype 285 | train_data.append(events) 286 | train_label.append(i) 287 | if key % 100 == 0: 288 | print("\r\tProcessing train data: {:.2f}% complete\t\t".format(key / 90), end='') 289 | key += 1 290 | 291 | print("\nprocessing testing data...") 292 | key = 1 293 | for i in range(10): 294 | current_path = os.path.join(raw_data_path, mapping[i]) 295 | for fn in test_index: 296 | filename = os.path.join(current_path, "{}".format(fn) + '.mat') 297 | events = scio.loadmat(filename, verify_compressed_data_integrity=False)['out1'].astype(np.int64)#astype 298 | test_data.append(events) 299 | test_label.append(i) 300 | if key % 100 == 0: 301 | print("\r\tTest data {:.2f}% complete\t\t".format(key / 10), end='') 302 | key += 1 303 | train_data = np.array(train_data) 304 | test_data = np.array(test_data) 305 | train_label = np.array(train_label) 306 | test_label = np.array(test_label) 307 | events_path = os.path.join(raw_data_path, 'events') 308 | if not os.path.exists(events_path): 309 | os.mkdir(events_path) 310 | np.save(os.path.join(events_path, 'train_data.npy'), train_data) 311 | np.save(os.path.join(events_path, 'test_data.npy'), test_data) 312 | np.save(os.path.join(events_path, 'train_label.npy'), train_label) 313 | np.save(os.path.join(events_path, 'test_label.npy'), test_label) 314 | return train_data, test_data, train_label, test_label 315 | 316 | -------------------------------------------------------------------------------- /lenet5.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | from torch.autograd import Variable 11 | import torch.optim 12 | import torch.utils.data 13 | import torchvision.transforms as transforms 14 | import torchvision.models as models 15 | import torchvision.datasets as dsets 16 | import numpy as np 17 | import random 18 | 19 | import torch._utils 20 | try: 21 | torch._utils._rebuild_tensor_v2 22 | except AttributeError: 23 | def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): 24 | tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) 25 | tensor.requires_grad = requires_grad 26 | tensor._backward_hooks = backward_hooks 27 | return tensor 28 | torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2 29 | 30 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 31 | 32 | model_names = sorted(name for name in models.__dict__ 33 | if name.islower() and not name.startswith("__") 34 | and callable(models.__dict__[name])) 35 | parser = argparse.ArgumentParser(description='PyTorch MNIST Training') 36 | parser.add_argument('--dataset', default='MNIST', type=str, help='dataset = [MNIST]') 37 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', 38 | choices=model_names, 39 | help='model architecture: ' + 40 | ' | '.join(model_names) + 41 | ' (default: resnet18)') 42 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 43 | help='manual epoch number (useful on restarts)') 44 | parser.add_argument('-b', '--batch-size', default=64, type=int, 45 | metavar='N', help='mini-batch size (default: 100)') 46 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 47 | help='momentum') 48 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 49 | metavar='W', help='weight decay (default: 1e-4)') 50 | parser.add_argument('--print-freq', '-p', default=500, type=int, 51 | metavar='N', help='print frequency (default: 10)') 52 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 53 | help='path to latest checkpoint (default: none)') 54 | parser.add_argument('-load', default='', type=str, metavar='PATH', 55 | help='path to training mask (default: none)') 56 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 57 | help='evaluate model on validation set') 58 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 59 | help='use pre-trained model') 60 | parser.add_argument('--lr', '--learning-rate', default=0.00085, type=float, 61 | metavar='LR', help='initial learning rate') 62 | parser.add_argument('-j', '--workers', default=1, type=int, metavar='N', 63 | help='number of data loading workers (default: 4)') 64 | parser.add_argument('--epochs', default=125, type=int, metavar='N', 65 | help='number of total epochs to run') 66 | parser.add_argument('--steps', default=100, type=int, metavar='N', 67 | help='number of total epochs to run') 68 | parser.add_argument('--vth', default=1, type=float, metavar='Vth', 69 | help='voltage threshold') 70 | parser.add_argument('--leak', default=1, type=float, metavar='Leak', 71 | help='leaky parameter') 72 | parser.add_argument('--hz', default=5, type=int, metavar='hz', 73 | help='scale update hz') 74 | parser.add_argument('--seed', default=0, type=int, metavar='seed', 75 | help='whether change the seed') 76 | 77 | best_prec1 = 0 78 | change = 25 79 | tp1 = []; 80 | tp5 = []; 81 | ep = []; 82 | lRate = []; 83 | device_num = 1 84 | device = torch.device("cuda:0") 85 | 86 | tp1_tr = []; 87 | tp5_tr = []; 88 | losses_tr = []; 89 | losses_eval = []; 90 | 91 | sign = 1 92 | 93 | scale1 = 1 94 | scale2 = 1 95 | scale3 = 1 96 | scale4 = 1 97 | scale5 = 1 98 | args = parser.parse_args() 99 | 100 | def main(): 101 | global args, best_prec1, device_num, sign 102 | if args.seed: 103 | seed1 = random.randint(1,100) 104 | seed2 = random.randint(1,100) 105 | seed3 = random.randint(1,100) 106 | else: 107 | seed1 = 30 108 | seed2 = 22 109 | seed3 = 66 110 | batch_size = args.batch_size 111 | print('\n'+'='*15+'settings'+'='*15) 112 | print('lr: ', args.lr) 113 | print('change lr point:%d'%change) 114 | print('batchsize:',batch_size) 115 | print('lenet adapt version') 116 | print('random-seed = %d %d %d'%(seed1,seed2,seed3)) 117 | print('steps:{}'.format(args.steps)) 118 | print('vth:{}'.format(args.vth)) 119 | print('leak:{}'.format(args.leak)) 120 | print('scale hz:{}'.format(args.hz)) 121 | # print('rand seed: %d'%seed) 122 | print('='*15+'settings'+'='*15+'\n') 123 | 124 | torch.manual_seed(seed1) 125 | torch.cuda.manual_seed(seed2) 126 | torch.cuda.manual_seed_all(seed3) 127 | np.random.seed(seed1) 128 | random.seed(seed2) 129 | 130 | 131 | model = CNNModel() 132 | print(model) 133 | model = torch.nn.DataParallel(model) 134 | model.to(device) 135 | 136 | criterion = torch.nn.MSELoss(reduction='sum') 137 | criterion_en = torch.nn.CrossEntropyLoss() 138 | 139 | learning_rate = args.lr 140 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 141 | 142 | cudnn.benchmark = False 143 | cudnn.deterministic = True 144 | 145 | 146 | # optionally resume from a checkpoint 147 | if args.resume: 148 | if os.path.isfile(args.resume): 149 | print("=> loading checkpoint '{}'".format(args.resume)) 150 | checkpoint = torch.load(args.resume) 151 | args.start_epoch = checkpoint['epoch'] 152 | best_prec1 = checkpoint['best_prec1'] 153 | model.load_state_dict(checkpoint['state_dict']) 154 | optimizer.load_state_dict(checkpoint['optimizer']) 155 | print("=> loaded checkpoint '{}' (epoch {})" 156 | .format(args.resume, checkpoint['epoch'])) 157 | else: 158 | print("=> no checkpoint found at '{}'".format(args.resume)) 159 | 160 | 161 | 162 | '''STEP 1: LOADING DATASET''' 163 | dataset_path = '/data/diospada/mnist-python/data' 164 | train_data = dsets.MNIST(root=dataset_path, train=True, transform=transforms.ToTensor(), download=True) 165 | val_data = dsets.MNIST(root=dataset_path, train=False, transform=transforms.ToTensor()) 166 | 167 | train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True) 168 | val_loader = torch.utils.data.DataLoader(dataset=val_data, batch_size=int(args.batch_size), shuffle=False) 169 | print('read dataset succeed') 170 | if args.evaluate: 171 | validate(val_loader, model, criterion, criterion_en, time_steps=args.steps, leak=args.leak) 172 | return 173 | 174 | prec1_tr = 0 175 | for epoch in range(args.start_epoch, args.epochs): 176 | if epoch % args.hz == 0 and args.hz < args.epochs: 177 | sign = 1 178 | else: 179 | sign = 0 180 | adjust_learning_rate(optimizer, epoch) 181 | ep.append(epoch) 182 | start_end = time.time() 183 | # train for one epoch 184 | prec1_tr = train(train_loader, model, criterion, criterion_en, optimizer, epoch, time_steps=args.steps, leak=args.leak) 185 | 186 | # evaluate on validation set 187 | modeltest = model.module 188 | prec1 = validate(val_loader, modeltest, criterion, criterion_en, time_steps=args.steps, leak=args.leak) 189 | 190 | 191 | # remember best prec@1 and save checkpoint 192 | is_best = prec1 > best_prec1 193 | best_prec1 = max(prec1, best_prec1) 194 | save_checkpoint({ 195 | 'epoch': epoch + 1, 196 | 'arch': args.arch, 197 | 'state_dict': model.state_dict(), 198 | 'best_prec1': best_prec1, 199 | 'optimizer': optimizer.state_dict(), 200 | }, is_best) 201 | 202 | time_used = time.time() - start_end 203 | print('time used this epoch: %dmin %ds'%(time_used//60,time_used%60)) 204 | for k in range(0, args.epochs - args.start_epoch): 205 | print('Epoch: [{0}/{1}]\t' 206 | 'LR:{2}\t' 207 | 'Prec@1 {top1:.3f} \t' 208 | 'Prec@5 {top5:.3f} '.format( 209 | ep[k], args.epochs, lRate[k], top1=tp1[k], top5=tp5[k])) 210 | print('best:',best_prec1) 211 | 212 | 213 | def grad_cal(scale, IF_in): 214 | out = scale * IF_in.gt(0).type(torch.cuda.FloatTensor) 215 | return out 216 | 217 | def ave(output, input): 218 | c = input >= output 219 | if input[c].sum() < 1e-3: 220 | return 1 221 | return output[c].sum()/input[c].sum() 222 | 223 | def ave_p(output, input): 224 | if input.sum() < 1e-3: 225 | return 1 226 | return output.sum()/input.sum() 227 | 228 | def train(train_loader, model, criterion, criterion_en, optimizer, epoch, time_steps, leak): 229 | batch_time = AverageMeter() 230 | data_time = AverageMeter() 231 | losses = AverageMeter() 232 | top1 = AverageMeter() 233 | top5 = AverageMeter() 234 | 235 | top1_tr = AverageMeter() 236 | top5_tr = AverageMeter() 237 | losses_en = AverageMeter() 238 | 239 | # switch to train mode 240 | model.train() 241 | 242 | end = time.time() 243 | start_end = end 244 | for i, (inputdata, target) in enumerate(train_loader): 245 | # measure data loading time 246 | data_time.update(time.time() - end) 247 | inputdata, target = inputdata.to(device), target.to(device) 248 | labels = target.clone() 249 | 250 | optimizer.zero_grad() # Clear gradients w.r.t. parameters 251 | 252 | output = model(inputdata, steps=time_steps, l=leak) 253 | 254 | targetN = output.data.clone().zero_().to(device) 255 | targetN.scatter_(1, target.unsqueeze(1), 1) 256 | targetN = Variable(targetN.type(torch.cuda.FloatTensor)) 257 | 258 | loss = criterion(output.cpu(), targetN.cpu()) 259 | loss_en = criterion_en(output.cpu(), labels.cpu()) 260 | 261 | # measure accuracy and record loss 262 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 263 | losses.update(loss.item(), inputdata.size(0)) 264 | top1.update(prec1.item(), inputdata.size(0)) 265 | top5.update(prec5.item(), inputdata.size(0)) 266 | 267 | prec1_tr, prec5_tr = accuracy(output.data, target, topk=(1, 5)) 268 | losses_en.update(loss_en.item(), inputdata.size(0)) 269 | top1_tr.update(prec1_tr.item(), inputdata.size(0)) 270 | top5_tr.update(prec5_tr.item(), inputdata.size(0)) 271 | 272 | loss.backward(retain_graph=False) 273 | 274 | 275 | optimizer.step() 276 | 277 | 278 | # measure elapsed time 279 | batch_time.update(time.time() - end) 280 | end = time.time() 281 | time_used = end - start_end 282 | print('train time: %dmin %ds'%(time_used//60,time_used%60)) 283 | 284 | print('Epoch: [{0}] Prec@1 {top1_tr.avg:.3f} Prec@5 {top5_tr.avg:.3f} Entropy_Loss {loss_en.avg:.4f}' 285 | .format(epoch, top1_tr=top1_tr, top5_tr=top5_tr, loss_en=losses_en)) 286 | 287 | losses_tr.append(losses_en.avg) 288 | tp1_tr.append(top1_tr.avg) 289 | tp5_tr.append(top5_tr.avg) 290 | 291 | return top1_tr.avg 292 | 293 | 294 | def validate(val_loader, model, criterion, criterion_en, time_steps, leak): 295 | batch_time = AverageMeter() 296 | data_time = AverageMeter() 297 | losses = AverageMeter() 298 | top1 = AverageMeter() 299 | top5 = AverageMeter() 300 | losses_en_eval = AverageMeter() 301 | 302 | # switch to evaluate mode 303 | model.eval() 304 | 305 | end = time.time() 306 | with torch.no_grad(): 307 | for i, (inputdata, target) in enumerate(val_loader): 308 | # measure data loading time 309 | data_time.update(time.time() - end) 310 | input_var = inputdata.to(device) 311 | target = target.to(device) 312 | 313 | labels = Variable(target.to(device)) 314 | target = target.to(device) 315 | 316 | output = model.tst(input=input_var, steps=time_steps, l=leak) 317 | targetN = output.data.clone().zero_().to(device) 318 | targetN.scatter_(1, target.unsqueeze(1), 1) 319 | targetN = Variable(targetN.type(torch.cuda.FloatTensor)) 320 | loss = criterion(output.cpu(), targetN.cpu()) 321 | loss_en = criterion_en(output.cpu(), labels.cpu()) 322 | 323 | # measure accuracy and record loss 324 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 325 | losses.update(loss.item(), inputdata.size(0)) 326 | top1.update(prec1.item(), inputdata.size(0)) 327 | top5.update(prec5.item(), inputdata.size(0)) 328 | losses_en_eval.update(loss_en.item(), inputdata.size(0)) 329 | 330 | # measure elapsed time 331 | batch_time.update(time.time() - end) 332 | end = time.time() 333 | 334 | print('Test: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Entropy_Loss {losses_en_eval.avg:.4f}' 335 | .format(top1=top1, top5=top5, losses_en_eval=losses_en_eval)) 336 | 337 | tp1.append(top1.avg) 338 | tp5.append(top5.avg) 339 | losses_eval.append(losses_en_eval.avg) 340 | 341 | return top1.avg 342 | 343 | 344 | def save_checkpoint(state, is_best, filename='checkpointT1_mnist1.pth.tar'): 345 | torch.save(state, filename) 346 | if is_best: 347 | shutil.copyfile(filename, 'model_bestT1_mnist1.pth.tar') 348 | 349 | 350 | class AverageMeter(object): 351 | """Computes and stores the average and current value""" 352 | 353 | def __init__(self): 354 | self.reset() 355 | 356 | def reset(self): 357 | self.val = 0 358 | self.avg = 0 359 | self.sum = 0 360 | self.count = 0 361 | 362 | def update(self, val, n=1): 363 | self.val = val 364 | self.sum += val * n 365 | self.count += n 366 | self.avg = self.sum / self.count 367 | 368 | 369 | def adjust_learning_rate(optimizer, epoch): 370 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 371 | lr = args.lr 372 | 373 | for param_group in optimizer.param_groups: 374 | if epoch >= change: 375 | param_group['lr'] = 0.2 * lr 376 | 377 | elif epoch < change: 378 | param_group['lr'] = lr 379 | 380 | lRate.append(param_group['lr']) 381 | 382 | 383 | def accuracy(output, target, topk=(1,)): 384 | """Computes the precision@k for the specified values of k""" 385 | maxk = max(topk) 386 | batch_size = target.size(0) 387 | 388 | _, pred = output.topk(maxk, 1, True, True) 389 | pred = pred.t() 390 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 391 | 392 | res = [] 393 | for k in topk: 394 | correct_k = correct[:k].view(-1).float().sum(0) 395 | res.append(correct_k.mul_(100.0 / batch_size)) 396 | return res 397 | 398 | 399 | class SpikingNN(torch.autograd.Function): 400 | def forward(self, input): 401 | self.save_for_backward(input) 402 | return input.gt(0).type(torch.cuda.FloatTensor) 403 | 404 | def backward(self, grad_output): 405 | input, = self.saved_tensors 406 | grad_input = grad_output.clone() 407 | grad_input[input <= 0.0] = 0 408 | return grad_input 409 | 410 | 411 | def LIF_sNeuron(membrane_potential, threshold, l, i): 412 | # check exceed membrane potential and reset 413 | ex_membrane = nn.functional.threshold(membrane_potential, threshold, 0) 414 | membrane_potential = membrane_potential - ex_membrane 415 | # generate spike 416 | out = SpikingNN()(ex_membrane) 417 | membrane_potential = l * membrane_potential.detach() + membrane_potential - membrane_potential.detach() 418 | 419 | return membrane_potential, out 420 | 421 | 422 | def Pooling_sNeuron(membrane_potential, threshold, i): 423 | # check exceed membrane potential and reset 424 | ex_membrane = nn.functional.threshold(membrane_potential, threshold, 0) 425 | membrane_potential = membrane_potential - ex_membrane # hard reset 426 | # generate spike 427 | out = SpikingNN()(ex_membrane) 428 | 429 | return membrane_potential, out 430 | 431 | 432 | class CNNModel(nn.Module): 433 | def __init__(self): 434 | super(CNNModel, self).__init__() 435 | self.cnn1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2, bias=False) 436 | self.avgpool1 = nn.AvgPool2d(kernel_size=2) 437 | 438 | self.cnn2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1, padding=2, bias=False) 439 | self.avgpool2 = nn.AvgPool2d(kernel_size=2) 440 | 441 | self.fc0 = nn.Linear(50*7*7, 200, bias=False) 442 | self.fc1 = nn.Linear(200, 10, bias=False) 443 | 444 | for m in self.modules(): 445 | if isinstance(m, nn.Conv2d): 446 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 447 | variance1 = math.sqrt(2. / n) 448 | m.weight.data.normal_(0, variance1) 449 | m.threshold = args.vth 450 | 451 | elif isinstance(m, nn.Linear): 452 | size = m.weight.size() 453 | fan_in = size[1] 454 | variance2 = math.sqrt(2.0 / fan_in) 455 | m.weight.data.normal_(0.0, variance2) 456 | m.threshold = args.vth 457 | 458 | def forward(self, inputdata, steps=100, l=1): 459 | 460 | global scale1, scale2, scale3, scale4, scale5, sign 461 | 462 | mem_1 = torch.zeros(inputdata.size(0), 20, 28, 28, device = inputdata.device) 463 | mem_1s = torch.zeros(inputdata.size(0), 20, 14, 14, device =inputdata.device) 464 | 465 | mem_2 = torch.zeros(inputdata.size(0), 50, 14, 14, device = inputdata.device) 466 | mem_2s = torch.zeros(inputdata.size(0), 50, 7, 7, device = inputdata.device) 467 | 468 | membrane_f0 = torch.zeros(inputdata.size(0), 200, device = inputdata.device) 469 | 470 | Total_input = torch.zeros(inputdata.size(0), 1, 28, 28, device = inputdata.device) 471 | 472 | Total_1_output = torch.zeros(inputdata.size(0), 20, 28, 28, device = inputdata.device) 473 | IF_in_c1 = torch.zeros(inputdata.size(0), 20, 28, 28, device = inputdata.device) 474 | 475 | Total_2_output = torch.zeros(inputdata.size(0), 50, 14, 14, device = inputdata.device) 476 | IF_in_c2 = torch.zeros(inputdata.size(0), 50, 14, 14, device = inputdata.device) 477 | 478 | Total_p1_output = torch.zeros(inputdata.size(0), 20, 14, 14, device = inputdata.device) 479 | IF_in_p1 = torch.zeros(inputdata.size(0), 20, 14, 14, device = inputdata.device) 480 | 481 | Total_p2_output = torch.zeros(inputdata.size(0), 50, 7, 7, device = inputdata.device) 482 | IF_in_p2 = torch.zeros(inputdata.size(0), 50, 7, 7, device = inputdata.device) 483 | 484 | Total_f0_output = torch.zeros(inputdata.size(0), 200, device = inputdata.device) 485 | IF_in_f0 = torch.zeros(inputdata.size(0), 200, device = inputdata.device) 486 | 487 | with torch.no_grad(): 488 | for i in range(steps): 489 | # Poisson input spike generation 490 | rand_num = torch.rand(inputdata.size(0), inputdata.size(1), inputdata.size(2), inputdata.size(3), device = inputdata.device) 491 | Poisson_d_input = (torch.abs(inputdata)/2) > rand_num 492 | Poisson_d_input = torch.mul(Poisson_d_input.float(), torch.sign(inputdata)) 493 | Total_input = Total_input + Poisson_d_input 494 | 495 | # convolutional Layer 496 | in_layer = self.cnn1(Poisson_d_input) 497 | mem_1 = mem_1 + in_layer 498 | mem_1, out = LIF_sNeuron(mem_1, self.cnn1.threshold, l, i) 499 | IF_in_c1 = IF_in_c1 + in_layer 500 | Total_1_output = Total_1_output + out 501 | 502 | # pooling Layer 503 | in_layer = self.avgpool1(out) 504 | mem_1s = mem_1s + in_layer 505 | mem_1s, out = Pooling_sNeuron(mem_1s, 0.75, i) 506 | IF_in_p1 = IF_in_p1 + in_layer 507 | Total_p1_output = Total_p1_output + out 508 | 509 | # convolutional Layer 510 | in_layer = self.cnn2(out) 511 | mem_2 = mem_2 + in_layer 512 | mem_2, out = LIF_sNeuron(mem_2, self.cnn2.threshold, l, i) 513 | IF_in_c2 = IF_in_c2 + in_layer 514 | Total_2_output = Total_2_output + out 515 | 516 | # pooling Layer 517 | in_layer = self.avgpool2(out) 518 | mem_2s = mem_2s + in_layer 519 | mem_2s, out = Pooling_sNeuron(mem_2s, 0.75, i) 520 | IF_in_p2 = IF_in_p2 + in_layer 521 | Total_p2_output = Total_p2_output + out 522 | 523 | out = out.view(out.size(0), -1) 524 | 525 | # fully-connected Layer 526 | in_layer = self.fc0(out) 527 | membrane_f0 = membrane_f0 + in_layer 528 | membrane_f0, out = LIF_sNeuron(membrane_f0, self.fc0.threshold, l, i) 529 | IF_in_f0 = IF_in_f0 + in_layer 530 | Total_f0_output = Total_f0_output + out 531 | 532 | if sign == 1: 533 | scale1 = 0.6 * ave(Total_1_output, IF_in_c1) + 0.4 * scale1 534 | scale2 = 0.6 * ave_p(Total_p1_output, IF_in_p1) + 0.4 * scale2 535 | scale3 = 0.6 * ave(Total_2_output, IF_in_c2) + 0.4 * scale3 536 | scale4 = 0.6 * ave_p(Total_p2_output, IF_in_p2) + 0.4 * scale4 537 | scale5 = 0.6 * ave(Total_f0_output, IF_in_f0) + 0.4 * scale5 538 | 539 | 540 | scale_1 = grad_cal(scale1, IF_in_c1) 541 | scale_2 = grad_cal(scale2, IF_in_p1) 542 | scale_3 = grad_cal(scale3, IF_in_c2) 543 | scale_4 = grad_cal(scale4, IF_in_p2) 544 | scale_5 = grad_cal(scale5, IF_in_f0) 545 | 546 | with torch.enable_grad(): 547 | cnn1_in = self.cnn1(Total_input.detach()) 548 | tem = Total_1_output.detach() 549 | out = torch.mul(cnn1_in,scale_1) 550 | Total_1_output = out - out.detach() + tem 551 | 552 | 553 | pool1_in = self.avgpool1(Total_1_output) 554 | tem = Total_p1_output.detach() 555 | out = torch.mul(pool1_in,scale_2) 556 | Total_p1_output = out - out.detach() + tem 557 | 558 | cnn2_in = self.cnn2(Total_p1_output) 559 | tem = Total_2_output.detach() 560 | out = torch.mul(cnn2_in, scale_3) 561 | Total_2_output = out - out.detach() + tem 562 | 563 | pool2_in = self.avgpool2(Total_2_output) 564 | tem = Total_p2_output.detach() 565 | out = torch.mul(pool2_in, scale_4) 566 | Total_p2_output = out - out.detach() + tem 567 | 568 | fc0_in = self.fc0(Total_p2_output.view(Total_p2_output.size(0),-1)) 569 | tem = Total_f0_output.detach() 570 | out = torch.mul(fc0_in, scale_5) 571 | Total_f0_output = out - out.detach() + tem 572 | 573 | fc1_in = self.fc1(Total_f0_output) 574 | 575 | 576 | return fc1_in/self.fc1.threshold/steps 577 | 578 | 579 | def tst(self, input, steps=100, l=1): 580 | mem_1 = torch.zeros(input.size(0), 20, 28, 28, device = input.device) 581 | mem_1s = torch.zeros(input.size(0), 20, 14, 14, device = input.device) 582 | mem_2 = torch.zeros(input.size(0), 50, 14, 14, device = input.device) 583 | mem_2s = torch.zeros(input.size(0), 50, 7, 7, device = input.device) 584 | 585 | membrane_f0 = torch.zeros(input.size(0), 200, device = input.device) 586 | membrane_f1 = torch.zeros(input.size(0), 10, device = input.device) 587 | 588 | for i in range(steps): 589 | # Poisson input spike generation 590 | rand_num = torch.rand(input.size(0), input.size(1), input.size(2), input.size(3), device =input.device) 591 | Poisson_d_input = ((torch.abs(input)/2) > rand_num).type(torch.cuda.FloatTensor) 592 | Poisson_d_input = torch.mul(Poisson_d_input, torch.sign(input)) 593 | 594 | # convolutional Layer 595 | mem_1 = mem_1 + self.cnn1(Poisson_d_input) 596 | mem_1, out = LIF_sNeuron(mem_1, self.cnn1.threshold, l, i) 597 | 598 | # pooling Layer 599 | mem_1s = mem_1s + self.avgpool1(out) 600 | mem_1s, out = Pooling_sNeuron(mem_1s, 0.75, i) 601 | 602 | # convolutional Layer 603 | mem_2 = mem_2 + self.cnn2(out) 604 | mem_2, out = LIF_sNeuron(mem_2, self.cnn1.threshold, l, i) 605 | 606 | # pooling Layer 607 | mem_2s = mem_2s + self.avgpool2(out) 608 | mem_2s, out = Pooling_sNeuron(mem_2s, 0.75, i) 609 | 610 | out = out.view(out.size(0), -1) 611 | 612 | # fully-connected Layer 613 | membrane_f0 = membrane_f0 + self.fc0(out) 614 | membrane_f0, out = LIF_sNeuron(membrane_f0, self.fc0.threshold, l, i) 615 | 616 | membrane_f1 = membrane_f1 + self.fc1(out) 617 | 618 | return membrane_f1 / self.fc1.threshold / steps 619 | 620 | if __name__ == '__main__': 621 | main() 622 | -------------------------------------------------------------------------------- /vgg7.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | from torch.autograd import Variable 11 | import torch.optim 12 | import torch.utils.data 13 | import torchvision 14 | import torchvision.transforms as transforms 15 | import torchvision.models as models 16 | 17 | import numpy as np 18 | import random 19 | 20 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 21 | 22 | model_names = sorted(name for name in models.__dict__ 23 | if name.islower() and not name.startswith("__") 24 | and callable(models.__dict__[name])) 25 | parser = argparse.ArgumentParser(description='PyTorch MNIST Training') 26 | parser.add_argument('--dataset', default='MNIST', type=str, help='dataset = [MNIST]') 27 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', 28 | choices=model_names, 29 | help='model architecture: ' + 30 | ' | '.join(model_names) + 31 | ' (default: resnet18)') 32 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 33 | help='manual epoch number (useful on restarts)') 34 | parser.add_argument('-b', '--batch-size', default=100, type=int, 35 | metavar='N', help='mini-batch size (default: 100)') 36 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 37 | help='momentum') 38 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 39 | metavar='W', help='weight decay (default: 1e-4)') 40 | parser.add_argument('--print-freq', '-p', default=500, type=int, 41 | metavar='N', help='print frequency (default: 10)') 42 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 43 | help='path to latest checkpoint (default: none)') 44 | parser.add_argument('-load', default='', type=str, metavar='PATH', 45 | help='path to training mask (default: none)') 46 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 47 | help='evaluate model on validation set') 48 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 49 | help='use pre-trained model') 50 | parser.add_argument('--lr', '--learning-rate', default=0.0005, type=float, 51 | metavar='LR', help='initial learning rate') 52 | parser.add_argument('-j', '--workers', default=1, type=int, metavar='N', 53 | help='number of data loading workers (default: 4)') 54 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 55 | help='number of total epochs to run') 56 | parser.add_argument('--steps', default=100, type=int, metavar='N', 57 | help='number of time steps to run') 58 | parser.add_argument('--vth', default=2, type=float, metavar='vth', 59 | help='threshold') 60 | parser.add_argument('--leak', default=1, type=float, metavar='leak', 61 | help='leaky parameter') 62 | parser.add_argument('--hz', default=5, type=int, metavar='hz', 63 | help='scale update hz') 64 | 65 | 66 | best_prec1 = 0 67 | change = 50 68 | change2 = 75 69 | change3 = 100 70 | 71 | tp1 = []; 72 | tp5 = []; 73 | ep = []; 74 | lRate = []; 75 | device_num = 1 76 | device = torch.device("cuda:0") 77 | 78 | tp1_tr = []; 79 | tp5_tr = []; 80 | losses_tr = []; 81 | losses_eval = []; 82 | 83 | args = parser.parse_args() 84 | 85 | 86 | sign = 1 87 | 88 | scale11 = 1 89 | scale12 = 1 90 | scalep1 = 1 91 | scale21 = 1 92 | scale22 = 1 93 | scale23 = 1 94 | scalef0 = 1 95 | 96 | def main(): 97 | global args, best_prec1, device_num,sign 98 | 99 | batch_size = args.batch_size 100 | 101 | seed1 = 44 102 | seed2 = 56 103 | seed3 = 78 104 | torch.manual_seed(seed1) 105 | torch.cuda.manual_seed(seed2) 106 | torch.cuda.manual_seed_all(seed3) 107 | np.random.seed(seed1) 108 | random.seed(seed2) 109 | cudnn.benchmark = False 110 | cudnn.deterministic = True 111 | 112 | print('\n'+'='*15+'settings'+'='*15) 113 | print('lr: ', args.lr) 114 | print('change lr point:%d %d %d'%(change,change2,change3)) 115 | print('batchsize:',batch_size) 116 | print('steps:', args.steps) 117 | print('vth:', args.vth) 118 | print('leak:{}'.format(args.leak)) 119 | print('hz:{}'.format(args.hz)) 120 | print('seed:%d %d %d'%(seed1,seed2,seed3)) 121 | print('='*15+'settings'+'='*15+'\n') 122 | 123 | model = CNNModel() 124 | 125 | print(model) 126 | 127 | model = torch.nn.DataParallel(model) 128 | model.to(device) 129 | 130 | 131 | criterion = torch.nn.MSELoss(reduction='sum') 132 | criterion_en = torch.nn.CrossEntropyLoss() 133 | 134 | learning_rate = args.lr 135 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4) 136 | 137 | # optionally resume from a checkpoint 138 | if args.resume: 139 | if os.path.isfile(args.resume): 140 | print("=> loading checkpoint '{}'".format(args.resume)) 141 | checkpoint = torch.load(args.resume) 142 | args.start_epoch = checkpoint['epoch'] 143 | best_prec1 = checkpoint['best_prec1'] 144 | model.load_state_dict(checkpoint['state_dict']) 145 | optimizer.load_state_dict(checkpoint['optimizer']) 146 | print("=> loaded checkpoint '{}' (epoch {})" 147 | .format(args.resume, checkpoint['epoch'])) 148 | else: 149 | print("=> no checkpoint found at '{}'".format(args.resume)) 150 | 151 | dataset_path = '/data/Zzzxd/cifar10-py' 152 | 153 | # Dataloader 154 | normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.557, 0.549, 0.5534]) 155 | transform_train = transforms.Compose([ 156 | transforms.RandomCrop(32, padding=4), 157 | transforms.RandomHorizontalFlip(), 158 | transforms.ToTensor(), 159 | normalize, 160 | ]) 161 | train_data = torchvision.datasets.CIFAR10(dataset_path, train=True, download=True, transform=transform_train) 162 | train_loader = torch.utils.data.DataLoader(train_data, 163 | batch_size=args.batch_size, shuffle=True, 164 | num_workers=args.workers, 165 | pin_memory=True) 166 | 167 | transform_test = transforms.Compose([ 168 | transforms.ToTensor(), 169 | normalize, 170 | ]) 171 | val_data = torchvision.datasets.CIFAR10(dataset_path, train=False, download=True, transform=transform_test) 172 | val_loader = torch.utils.data.DataLoader(val_data, # val_data for testing 173 | batch_size=int(args.batch_size/2), shuffle=False, 174 | num_workers=args.workers, 175 | pin_memory=False) 176 | 177 | print('read dataset done') 178 | if args.evaluate: 179 | validate(val_loader, model, criterion, criterion_en, time_steps=args.steps, leak=args.leak) 180 | return 181 | 182 | for epoch in range(args.start_epoch, args.epochs): 183 | if epoch % args.hz == 0 and args.hz < args.epochs: 184 | sign = 1 185 | else: 186 | sign = 0 187 | 188 | start = time.time() 189 | adjust_learning_rate(optimizer, epoch) 190 | 191 | ep.append(epoch) 192 | 193 | # train for one epoch 194 | train(train_loader, model, criterion, criterion_en, optimizer, epoch, time_steps=args.steps, leak=args.leak) 195 | 196 | # evaluate on validation set 197 | modeltest = model.module 198 | prec1 = validate(val_loader, modeltest, criterion, criterion_en, time_steps=args.steps, leak=args.leak) 199 | 200 | # remember best prec@1 and save checkpoint 201 | is_best = prec1 > best_prec1 202 | best_prec1 = max(prec1, best_prec1) 203 | save_checkpoint({ 204 | 'epoch': epoch + 1, 205 | 'arch': args.arch, 206 | 'state_dict': model.state_dict(), 207 | 'best_prec1': best_prec1, 208 | 'optimizer': optimizer.state_dict(), 209 | }, is_best) 210 | time_use = time.time() - start 211 | print('time used this epoch: %d h%dmin%ds' %(time_use//3600,(time_use%3600)//60,time_use%60)) 212 | 213 | if sign == 1: 214 | print('\n'+'='*15+'scale'+'='*15) 215 | print('scale11: ', scale11) 216 | print('scale12: ', scale12) 217 | print('scalep1: ', scalep1) 218 | print('scale21: ', scale21) 219 | print('scale22: ', scale22) 220 | print('scale23: ', scale23) 221 | print('scalef0: ', scalef0) 222 | print('='*15+'scale'+'='*15+'\n') 223 | 224 | for k in range(0, args.epochs - args.start_epoch): 225 | print('Epoch: [{0}/{1}]\t' 226 | 'LR:{2}\t' 227 | 'Prec@1 {top1:.3f} \t' 228 | 'Prec@5 {top5:.3f} \t' 229 | 'En_Loss_Eval {losses_en_eval: .4f} \t' 230 | 'Prec@1_tr {top1_tr:.3f} \t' 231 | 'Prec@5_tr {top5_tr:.3f} \t' 232 | 'En_Loss_train {losses_en: .4f}'.format( 233 | ep[k], args.epochs, lRate[k], top1=tp1[k], top5=tp5[k], losses_en_eval=losses_eval[k], top1_tr=tp1_tr[k], 234 | top5_tr=tp5_tr[k], losses_en=losses_tr[k])) 235 | print('best_acc={}'.format(best_prec1)) 236 | 237 | 238 | def print_view(v): 239 | v = v.view(v.size(0), -1) 240 | j = 0 241 | for i in v[0]: 242 | print(i) 243 | j = j + 1 244 | print(j) 245 | 246 | def grad_cal(scale, IF_in): 247 | out = scale * IF_in.gt(0).type(torch.cuda.FloatTensor) 248 | return out 249 | 250 | def ave(output, input): 251 | c = input >= output 252 | if input[c].sum() < 1e-3: 253 | return 1 254 | return output[c].sum()/input[c].sum() 255 | 256 | def ave_p(output, input): 257 | if input.sum() < 1e-3: 258 | return 1 259 | return output.sum()/input.sum() 260 | 261 | 262 | def train(train_loader, model, criterion, criterion_en, optimizer, epoch, time_steps, leak): 263 | batch_time = AverageMeter() 264 | data_time = AverageMeter() 265 | losses = AverageMeter() 266 | top1 = AverageMeter() 267 | top5 = AverageMeter() 268 | 269 | top1_tr = AverageMeter() 270 | top5_tr = AverageMeter() 271 | losses_en = AverageMeter() 272 | 273 | # switch to train mode 274 | model.train() 275 | 276 | end = time.time() 277 | start_end = end 278 | for i, (input, target) in enumerate(train_loader): 279 | # measure data loading time 280 | data_time.update(time.time() - end) 281 | input, target = input.to(device), target.to(device) 282 | labels = target.clone() 283 | 284 | optimizer.zero_grad() # Clear gradients w.r.t. parameters 285 | 286 | output = model(input, steps=time_steps, l=leak) 287 | 288 | targetN = output.data.clone().zero_().to(device) 289 | targetN.scatter_(1, target.unsqueeze(1), 1) 290 | targetN = Variable(targetN.type(torch.cuda.FloatTensor)) 291 | 292 | loss = criterion(output.cpu(), targetN.cpu()) 293 | loss_en = criterion_en(output.cpu(), labels.cpu()) 294 | 295 | # measure accuracy and record loss 296 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 297 | losses.update(loss.item(), input.size(0)) 298 | top1.update(prec1.item(), input.size(0)) 299 | top5.update(prec5.item(), input.size(0)) 300 | 301 | prec1_tr, prec5_tr = accuracy(output.data, target, topk=(1, 5)) 302 | losses_en.update(loss_en.item(), input.size(0)) 303 | top1_tr.update(prec1_tr.item(), input.size(0)) 304 | top5_tr.update(prec5_tr.item(), input.size(0)) 305 | 306 | # compute gradient and do SGD step 307 | loss.backward(retain_graph=False) 308 | 309 | 310 | 311 | 312 | optimizer.step() 313 | 314 | 315 | # measure elapsed time 316 | batch_time.update(time.time() - end) 317 | end = time.time() 318 | 319 | 320 | print('Epoch: [{0}] Prec@1 {top1_tr.avg:.3f} Prec@5 {top5_tr.avg:.3f} Entropy_Loss {loss_en.avg:.4f}' 321 | .format(epoch, top1_tr=top1_tr, top5_tr=top5_tr, loss_en=losses_en)) 322 | time_use = end - start_end 323 | print('train time: %d h%dmin%ds' %(time_use//3600,(time_use%3600)//60,time_use%60)) 324 | 325 | losses_tr.append(losses_en.avg) 326 | tp1_tr.append(top1_tr.avg) 327 | tp5_tr.append(top5_tr.avg) 328 | 329 | 330 | def validate(val_loader, model, criterion, criterion_en, time_steps, leak): 331 | batch_time = AverageMeter() 332 | data_time = AverageMeter() 333 | losses = AverageMeter() 334 | top1 = AverageMeter() 335 | top5 = AverageMeter() 336 | losses_en_eval = AverageMeter() 337 | 338 | # switch to evaluate mode 339 | model.eval() 340 | 341 | end = time.time() 342 | with torch.no_grad(): 343 | for i, (input, target) in enumerate(val_loader): 344 | # measure data loading time 345 | data_time.update(time.time() - end) 346 | input_var = input.to(device) 347 | labels = Variable(target.to(device)) 348 | target = target.to(device) 349 | output = model.tst(input=input_var, steps=time_steps, l=leak) 350 | 351 | targetN = output.data.clone().zero_().to(device) 352 | targetN.scatter_(1, target.unsqueeze(1), 1) 353 | targetN = Variable(targetN.type(torch.cuda.FloatTensor)) 354 | loss = criterion(output.cpu(), targetN.cpu()) 355 | loss_en = criterion_en(output.cpu(), labels.cpu()) 356 | # measure accuracy and record loss 357 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 358 | losses.update(loss.item(), input.size(0)) 359 | top1.update(prec1.item(), input.size(0)) 360 | top5.update(prec5.item(), input.size(0)) 361 | losses_en_eval.update(loss_en.item(), input.size(0)) 362 | 363 | # measure elapsed time 364 | batch_time.update(time.time() - end) 365 | end = time.time() 366 | 367 | print('Test: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Entropy_Loss {losses_en_eval.avg:.4f}' 368 | .format(top1=top1, top5=top5, losses_en_eval=losses_en_eval)) 369 | 370 | tp1.append(top1.avg) 371 | tp5.append(top5.avg) 372 | losses_eval.append(losses_en_eval.avg) 373 | 374 | return top1.avg 375 | 376 | 377 | def save_checkpoint(state, is_best, filename='checkpointT1_cifar10_v7.pth.tar'): 378 | torch.save(state, filename) 379 | if is_best: 380 | shutil.copyfile(filename, 'model_bestT1_cifar10_v7.pth.tar') 381 | 382 | 383 | class AverageMeter(object): 384 | """Computes and stores the average and current value""" 385 | 386 | def __init__(self): 387 | self.reset() 388 | 389 | def reset(self): 390 | self.val = 0 391 | self.avg = 0 392 | self.sum = 0 393 | self.count = 0 394 | 395 | def update(self, val, n=1): 396 | self.val = val 397 | self.sum += val * n 398 | self.count += n 399 | self.avg = self.sum / self.count 400 | 401 | 402 | def adjust_learning_rate(optimizer, epoch): 403 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 404 | lr = args.lr * (1 ** (epoch // change)) 405 | 406 | for param_group in optimizer.param_groups: 407 | if epoch >= change3: 408 | param_group['lr'] = 0.2 * 0.2 * 0.2 * lr 409 | 410 | elif epoch >= change2: 411 | param_group['lr'] = 0.2 * 0.2 * lr 412 | 413 | elif epoch >= change: 414 | param_group['lr'] = 0.2 * lr 415 | 416 | else: 417 | param_group['lr'] = lr 418 | 419 | lRate.append(param_group['lr']) 420 | 421 | 422 | def accuracy(output, target, topk=(1,)): 423 | """Computes the precision@k for the specified values of k""" 424 | maxk = max(topk) 425 | batch_size = target.size(0) 426 | 427 | _, pred = output.topk(maxk, 1, True, True) 428 | pred = pred.t() 429 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 430 | 431 | res = [] 432 | for k in topk: 433 | correct_k = correct[:k].view(-1).float().sum(0) 434 | res.append(correct_k.mul_(100.0 / batch_size)) 435 | return res 436 | 437 | 438 | class SpikingNN(torch.autograd.Function): 439 | def forward(self, input): 440 | self.save_for_backward(input) 441 | return input.gt(0).type(torch.cuda.FloatTensor) 442 | 443 | def backward(self, grad_output): 444 | input, = self.saved_tensors 445 | grad_input = grad_output.clone() 446 | grad_input[input <= 0.0] = 0 447 | return grad_input 448 | 449 | 450 | def LIF_sNeuron(membrane_potential, threshold, l, i): 451 | # check exceed membrane potential and reset 452 | ex_membrane = nn.functional.threshold(membrane_potential, threshold, 0) 453 | membrane_potential = membrane_potential - ex_membrane # hard reset 454 | # generate spike 455 | out = SpikingNN()(ex_membrane) 456 | # decay 457 | # note: the detach has no effects now 458 | membrane_potential = l * membrane_potential.detach() + membrane_potential - membrane_potential.detach() 459 | 460 | return membrane_potential, out 461 | 462 | 463 | def Pooling_sNeuron(membrane_potential, threshold, i): 464 | # check exceed membrane potential and reset 465 | ex_membrane = nn.functional.threshold(membrane_potential, threshold, 0) 466 | membrane_potential = membrane_potential - ex_membrane 467 | # generate spike 468 | out = SpikingNN()(ex_membrane) 469 | 470 | return membrane_potential, out 471 | 472 | 473 | class CNNModel(nn.Module): 474 | def __init__(self): 475 | super(CNNModel, self).__init__() 476 | self.cnn11 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 477 | self.cnn12 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 478 | self.avgpool1 = nn.AvgPool2d(kernel_size=2) 479 | 480 | self.cnn21 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False) 481 | self.cnn22 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False) 482 | self.cnn23 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False) 483 | 484 | self.avgpool2 = nn.MaxPool2d(kernel_size=2) 485 | 486 | self.fc0 = nn.Linear(128 * 8 * 8, 1024, bias=False) 487 | self.fc1 = nn.Linear(1024, 10, bias=False) 488 | 489 | for m in self.modules(): 490 | if isinstance(m, nn.Conv2d): 491 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 492 | variance1 = math.sqrt(2.0 / n) 493 | m.weight.data.normal_(0, variance1) 494 | # define threshold 495 | m.threshold = args.vth 496 | 497 | elif isinstance(m, nn.Linear): 498 | size = m.weight.size() 499 | fan_in = size[1] # number of columns 500 | variance2 = math.sqrt(2.0 / fan_in) 501 | m.weight.data.normal_(0.0, variance2) 502 | # define threshold 503 | m.threshold = args.vth 504 | 505 | def forward(self, input, steps=100, l=1): 506 | global scale11,scale12,scale21,scale22,scale23,scalef0,scalep1,sign 507 | 508 | mem_11 = torch.zeros(input.size(0), 64, 32, 32, device = input.device) 509 | mem_12 = torch.zeros(input.size(0), 64, 32, 32, device = input.device) 510 | 511 | mem_1s = torch.zeros(input.size(0), 64, 16, 16, device = input.device) 512 | 513 | mem_21 = torch.zeros(input.size(0), 128, 16, 16, device = input.device) 514 | mem_22 = torch.zeros(input.size(0), 128, 16, 16, device = input.device) 515 | mem_23 = torch.zeros(input.size(0), 128, 16, 16, device = input.device) 516 | 517 | membrane_f0 = torch.zeros(input.size(0), 1024, device = input.device) 518 | 519 | Total_input = torch.zeros(input.size(0), 3, 32, 32, device = input.device) 520 | 521 | Total_11_output = torch.zeros(input.size(0), 64, 32, 32, device = input.device) 522 | IF_11_in = torch.zeros(input.size(0), 64, 32, 32, device = input.device) 523 | 524 | Total_12_output = torch.zeros(input.size(0), 64, 32, 32, device = input.device) 525 | IF_12_in = torch.zeros(input.size(0), 64, 32, 32, device = input.device) 526 | 527 | Total_p1_output = torch.zeros(input.size(0), 64, 16, 16,device = input.device) 528 | IF_p1_in = torch.zeros(input.size(0), 64, 16, 16, device = input.device) 529 | 530 | Total_21_output = torch.zeros(input.size(0), 128, 16, 16, device = input.device) 531 | IF_21_in = torch.zeros(input.size(0), 128, 16, 16, device = input.device) 532 | 533 | Total_22_output = torch.zeros(input.size(0), 128, 16, 16, device = input.device) 534 | IF_22_in = torch.zeros(input.size(0), 128, 16, 16, device = input.device) 535 | 536 | Total_23_output = torch.zeros(input.size(0), 128, 16, 16, device = input.device) 537 | IF_23_in = torch.zeros(input.size(0), 128, 16, 16, device = input.device) 538 | 539 | Total_f0_output = torch.zeros(input.size(0), 1024, device = input.device) 540 | IF_f0_in = torch.zeros(input.size(0), 1024, device = input.device) 541 | 542 | with torch.no_grad(): 543 | for i in range(steps): 544 | # Poisson input spike generation 545 | rand_num = torch.rand(input.size(0), input.size(1), input.size(2), input.size(3), device = input.device) 546 | Poisson_d_input = (torch.abs(input) > rand_num) 547 | Poisson_d_input = torch.mul(Poisson_d_input.float(), torch.sign(input)) 548 | Total_input = Total_input + Poisson_d_input 549 | 550 | # convolutional Layer 551 | in_layer = self.cnn11(Poisson_d_input) 552 | mem_11 = mem_11 + in_layer 553 | mem_11, out = LIF_sNeuron(mem_11, self.cnn11.threshold, l, i) 554 | IF_11_in = IF_11_in + in_layer 555 | Total_11_output = Total_11_output + out 556 | 557 | 558 | in_layer = self.cnn12(out) 559 | mem_12 = mem_12 + in_layer 560 | mem_12, out = LIF_sNeuron(mem_12, self.cnn12.threshold, l, i) 561 | IF_12_in = IF_12_in + in_layer 562 | Total_12_output = Total_12_output + out 563 | 564 | 565 | # pooling Layer 566 | in_layer = self.avgpool1(out) 567 | mem_1s = mem_1s + in_layer 568 | mem_1s, out = Pooling_sNeuron(mem_1s, 0.75, i) 569 | IF_p1_in = IF_p1_in + in_layer 570 | Total_p1_output = Total_p1_output + out 571 | 572 | # convolutional Layer 573 | in_layer = self.cnn21(out) 574 | mem_21 = mem_21 + in_layer 575 | mem_21, out = LIF_sNeuron(mem_21, self.cnn21.threshold, l, i) 576 | IF_21_in = IF_21_in + in_layer 577 | Total_21_output = Total_21_output + out 578 | 579 | in_layer = self.cnn22(out) 580 | mem_22 = mem_22 + in_layer 581 | mem_22, out = LIF_sNeuron(mem_22, self.cnn22.threshold, l, i) 582 | IF_22_in = IF_22_in + in_layer 583 | Total_22_output = Total_22_output + out 584 | 585 | in_layer = self.cnn23(out) 586 | mem_23 = mem_23 + in_layer 587 | mem_23, out = LIF_sNeuron(mem_23, self.cnn23.threshold, l, i) 588 | IF_23_in = IF_23_in + in_layer 589 | Total_23_output = Total_23_output + out 590 | 591 | out = self.avgpool2(out) 592 | out = out.view(out.size(0), -1) 593 | 594 | # fully-connected Layer 595 | in_layer = self.fc0(out) 596 | membrane_f0 = membrane_f0 + in_layer 597 | membrane_f0, out = LIF_sNeuron(membrane_f0, self.fc0.threshold, l, i) 598 | IF_f0_in = IF_f0_in + in_layer 599 | Total_f0_output = Total_f0_output + out 600 | 601 | 602 | if sign == 1: 603 | scalef0 = 0.6 * ave(Total_f0_output, IF_f0_in) + 0.4 * scalef0 604 | scale11 = 0.6 * ave(Total_11_output, IF_11_in) + 0.4 * scale11 605 | scale12 = 0.6 * ave(Total_12_output, IF_12_in) + 0.4 * scale12 606 | scalep1 = 0.6 * ave_p(Total_p1_output, IF_p1_in) + 0.4 * scalep1 607 | scale21 = 0.6 * ave(Total_21_output, IF_21_in) + 0.4 * scale21 608 | scale22 = 0.6 * ave(Total_22_output, IF_22_in) + 0.4 * scale22 609 | scale23 = 0.6 * ave(Total_23_output, IF_23_in) + 0.4 * scale23 610 | 611 | 612 | scale_f0 = grad_cal(scalef0, IF_f0_in) 613 | scale_11 = grad_cal(scale11, IF_11_in) 614 | scale_12 = grad_cal(scale12, IF_12_in) 615 | scale_p1 = grad_cal(scalep1, IF_p1_in) 616 | scale_21 = grad_cal(scale21, IF_21_in) 617 | scale_22 = grad_cal(scale22, IF_22_in) 618 | scale_23 = grad_cal(scale23, IF_23_in) 619 | 620 | with torch.enable_grad(): 621 | cnn11_in = self.cnn11(Total_input.detach()) 622 | tem = Total_11_output.detach() 623 | out = torch.mul(cnn11_in, scale_11) 624 | Total_11_output = out - out.detach() + tem 625 | 626 | cnn12_in = self.cnn12(Total_11_output) 627 | tem = Total_12_output.detach() 628 | out = torch.mul(cnn12_in, scale_12) 629 | Total_12_output = out - out.detach() + tem 630 | 631 | pool1_in = self.avgpool1(Total_12_output) 632 | tem = Total_p1_output.detach() 633 | out = torch.mul(pool1_in, scale_p1) 634 | Total_p1_output = out - out.detach() + tem 635 | 636 | cnn21_in = self.cnn21(Total_p1_output) 637 | tem = Total_21_output.detach() 638 | out = torch.mul(cnn21_in, scale_21) 639 | Total_21_output = out - out.detach() + tem 640 | 641 | cnn22_in = self.cnn22(Total_21_output) 642 | tem = Total_22_output.detach() 643 | out = torch.mul(cnn22_in, scale_22) 644 | Total_22_output = out - out.detach() + tem 645 | 646 | cnn23_in = self.cnn23(Total_22_output) 647 | tem = Total_23_output.detach() 648 | out = torch.mul(cnn23_in, scale_23) 649 | Total_23_output = out - out.detach() + tem 650 | 651 | Total_p2_output = self.avgpool2(Total_23_output) 652 | 653 | fc0_in = self.fc0(Total_p2_output.view(Total_p2_output.size(0),-1)) 654 | tem = Total_f0_output.detach() 655 | out = torch.mul(fc0_in, scale_f0) 656 | Total_f0_output = out - out.detach() + tem 657 | 658 | fc1_in = self.fc1(Total_f0_output) 659 | 660 | return fc1_in/self.fc1.threshold/steps 661 | 662 | 663 | def tst(self, input, steps=100, l=1): 664 | mem_11 = torch.zeros(input.size(0), 64, 32, 32,device = input.device) 665 | mem_12 = torch.zeros(input.size(0), 64, 32, 32,device = input.device) 666 | mem_1s = torch.zeros(input.size(0), 64, 16, 16,device = input.device) 667 | 668 | mem_21 = torch.zeros(input.size(0), 128, 16, 16,device = input.device) 669 | mem_22 = torch.zeros(input.size(0), 128, 16, 16,device = input.device) 670 | mem_23 = torch.zeros(input.size(0), 128, 16, 16,device = input.device) 671 | 672 | membrane_f0 = torch.zeros(input.size(0), 1024,device = input.device) 673 | membrane_f1 = torch.zeros(input.size(0), 10,device = input.device) 674 | 675 | for i in range(steps): 676 | # Poisson input spike generation 677 | rand_num = torch.rand(input.size(0), input.size(1), input.size(2), input.size(3), device = input.device) 678 | Poisson_d_input = (torch.abs(input) > rand_num).type(torch.cuda.FloatTensor) 679 | Poisson_d_input = torch.mul(Poisson_d_input, torch.sign(input)) 680 | 681 | # convolutional Layer 682 | mem_11 = mem_11 + self.cnn11(Poisson_d_input) 683 | mem_11, out = LIF_sNeuron(mem_11, self.cnn11.threshold, l, i) 684 | 685 | mem_12 = mem_12 + self.cnn12(out) 686 | mem_12, out = LIF_sNeuron(mem_12, self.cnn12.threshold, l, i) 687 | 688 | # pooling Layer 689 | mem_1s = mem_1s + self.avgpool1(out) 690 | mem_1s, out = Pooling_sNeuron(mem_1s, 0.75, i) 691 | 692 | # convolutional Layer 693 | mem_21 = mem_21 + self.cnn21(out) 694 | mem_21, out = LIF_sNeuron(mem_21, self.cnn21.threshold, l, i) 695 | 696 | mem_22 = mem_22 + self.cnn22(out) 697 | mem_22, out = LIF_sNeuron(mem_22, self.cnn22.threshold, l, i) 698 | 699 | mem_23 = mem_23 + self.cnn23(out) 700 | mem_23, out = LIF_sNeuron(mem_23, self.cnn23.threshold, l, i) 701 | 702 | # pooling Layer 703 | out = self.avgpool2(out) 704 | out = out.view(out.size(0), -1) 705 | 706 | # fully-connected Layer 707 | membrane_f0 = membrane_f0 + self.fc0(out) 708 | membrane_f0, out = LIF_sNeuron(membrane_f0, self.fc0.threshold, l, i) 709 | 710 | membrane_f1 = membrane_f1 + self.fc1(out) 711 | return membrane_f1 / self.fc1.threshold / steps 712 | 713 | if __name__ == '__main__': 714 | main() 715 | --------------------------------------------------------------------------------