├── assets
│   ├── arch.png
│   ├── poster.pdf
│   └── slides.pdf
├── src
│   ├── models
│   │   └── spike
│   │       ├── mem_pred_d8k8_c3(relu(c2(c1(x))+c1(x)))_125.pkl
│   │       ├── surrogate.py
│   │       └── neuron.py
│   └── utils
│       └── registry.py
├── configs
│   └── experiment
│       └── spikingssm
│           ├── minst.yaml
│           ├── pminst.yaml
│           ├── pathx.yaml
│           ├── aan.yaml
│           ├── imdb.yaml
│           ├── listops.yaml
│           ├── cifar.yaml
│           ├── pathfinder.yaml
│           └── wt103.yaml
├── dataset.py
├── LICENSE
├── utils.py
├── generate.py
├── .gitignore
├── model.py
├── readme.md
├── models
│   └── spike
│       └── ss4d.py
├── train.py
└── convert.ipynb

/assets/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shenshuaijie/SDN/HEAD/assets/arch.png
--------------------------------------------------------------------------------
/assets/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shenshuaijie/SDN/HEAD/assets/poster.pdf
--------------------------------------------------------------------------------
/assets/slides.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shenshuaijie/SDN/HEAD/assets/slides.pdf
--------------------------------------------------------------------------------
/src/models/spike/mem_pred_d8k8_c3(relu(c2(c1(x))+c1(x)))_125.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shenshuaijie/SDN/HEAD/src/models/spike/mem_pred_d8k8_c3(relu(c2(c1(x))+c1(x)))_125.pkl
--------------------------------------------------------------------------------
/src/models/spike/surrogate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | class piecewise_quadratic(torch.autograd.Function):
5 |     @staticmethod
6 |     def forward(ctx, x):
7 |         if x.requires_grad:
8 |             ctx.save_for_backward(x)
9 |         return (x >= 0).to(x)
10 | 
11 |     @staticmethod
12 |     def backward(ctx, grad_output):
13 |         x = ctx.saved_tensors[0]
14 |         x_abs = x.abs()
15 |         mask = x_abs > 1
16 |         grad_x = (grad_output * (-x_abs + 1.0)).masked_fill_(mask, 0)
17 |         return grad_x, None
18 | 
19 | 
20 | def piecewise_quadratic_surrogate():
21 |     return piecewise_quadratic.apply
22 | 
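23 | # Usage sketch: the surrogate acts as a Heaviside step in the forward pass and
24 | # uses the triangular pseudo-derivative max(0, 1 - |x|) of the piecewise
25 | # quadratic function in the backward pass, e.g.:
26 | #
27 | #   spike_fn = piecewise_quadratic_surrogate()
28 | #   x = torch.randn(8, requires_grad=True)
29 | #   spike_fn(x).sum().backward()  # x.grad holds the surrogate gradient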
--------------------------------------------------------------------------------
/configs/experiment/spikingssm/minst.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 |   - /pipeline: mnist
4 |   - override /scheduler: cosine_warmup
5 | 
6 | model:
7 |   _name_: spikingssm
8 |   dropout: 0.1
9 |   # tie_dropout: true
10 |   n_layers: 2
11 |   d_model: 400
12 |   prenorm: false
13 |   layer:
14 |     d_state: 64
15 |     bidirectional: false
16 |     lr: 0.001
17 | 
18 | dataset:
19 |   permute: false
20 | 
21 | loader:
22 |   batch_size: 50
23 | 
24 | optimizer:
25 |   lr: 0.01
26 |   weight_decay: 0.01
27 | 
28 | trainer:
29 |   max_epochs: 100
30 | 
31 | scheduler:
32 |   num_training_steps: 90000 # 200 epochs
33 | 
34 | train:
35 |   seed: 1111
36 | 
--------------------------------------------------------------------------------
/configs/experiment/spikingssm/pminst.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 |   - /pipeline: mnist
4 |   - override /scheduler: cosine_warmup
5 | 
6 | model:
7 |   _name_: spikingssm
8 |   dropout: 0.1
9 |   # tie_dropout: true
10 |   n_layers: 4
11 |   d_model: 256
12 |   prenorm: false
13 |   layer:
14 |     d_state: 64
15 |     bidirectional: false
16 |     lr: 0.001
17 | 
18 | dataset:
19 |   permute: true
20 | 
21 | loader:
22 |   batch_size: 50
23 | 
24 | optimizer:
25 |   lr: 0.01
26 |   weight_decay: 0.01
27 | 
28 | trainer:
29 |   max_epochs: 100
30 | 
31 | scheduler:
32 |   num_training_steps: 90000 # 200 epochs
33 | 
34 | train:
35 |   seed: 1111
36 | 
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | 
4 | 
5 | class MemDataset(Dataset):
6 |     def __init__(self, filename, transform=None) -> None:
7 |         super().__init__()
8 |         data = torch.load(filename)
9 |         self.transform = transform
10 |         self.input = data["input"]
11 |         self.mem = data["mem"]
12 |         self.spike = data["spike"]
13 | 
14 |     def __len__(self):
15 |         return self.input.size(0)
16 | 
17 |     def __getitem__(self, index):
18 |         x = self.input[index]
19 |         m = self.mem[index]
20 |         s = self.spike[index]
21 |         if self.transform:
22 |             x = self.transform(x)
23 |         return x, m, s
24 | 
--------------------------------------------------------------------------------
/configs/experiment/spikingssm/pathx.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 |   - /pipeline: pathx
4 |   - override /scheduler: cosine_warmup
5 | 
6 | scheduler:
7 |   num_training_steps: 500000 # 50 epochs
8 |   num_warmup_steps: 50000
9 | 
10 | model:
11 |   _name_: spikingssm
12 |   dropout: 0.
13 |   n_layers: 6
14 |   prenorm: true
15 |   d_model: 256
16 |   norm: batch
17 |   layer:
18 |     d_state: 64
19 |     lr: 0.001
20 |     dt_min: 0.0001
21 |     dt_max: 0.1
22 |     bidirectional: true
23 |     learnable_vth: true
24 |     shared_vth: false
25 |     trainable_B: true
26 | 
27 | loader:
28 |   batch_size: 16
29 | 
30 | optimizer:
31 |   lr: 0.001
32 |   weight_decay: 0.01
33 | 
34 | trainer:
35 |   max_epochs: 50
36 | 
37 | train:
38 |   seed: 3333
39 |   interval: step # For cosine scheduler
--------------------------------------------------------------------------------
/configs/experiment/spikingssm/aan.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 |   - /pipeline: aan
4 |   - override /scheduler: cosine_warmup
5 | 
6 | decoder:
7 |   mode: pool
8 | 
9 | model:
10 |   _name_: spikingssm
11 |   dropout: 0
12 |   n_layers: 6
13 |   d_model: 256
14 |   prenorm: true
15 |   norm: batch
16 |   layer:
17 |     d_state: 64
18 |     lr: 0.001
19 |     dt_min: 0.001
20 |     dt_max: 0.1
21 |     bidirectional: true
22 |     learnable_vth: true
23 |     shared_vth: false
24 |     trainable_B: false
25 | 
26 | loader:
27 |   batch_size: 64
28 | 
29 | optimizer:
30 |   lr: 0.01
31 |   weight_decay: 0.01
32 | 
33 | scheduler:
34 |   num_training_steps: 50000 # 20 epochs
35 |   num_warmup_steps: 5000
36 | 
37 | trainer:
38 |   max_epochs: 20
39 | 
40 | train:
41 |   seed: 3333
42 |   interval: step
43 | 
--------------------------------------------------------------------------------
/configs/experiment/spikingssm/imdb.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 |   - /pipeline: imdb
4 |   - override /scheduler: cosine_warmup
5 | 
6 | decoder:
7 |   mode: pool
8 | 
9 | model:
10 |   dropout: 0.0
11 |   n_layers: 6
12 |   d_model: 256
13 |   prenorm: true
14 |   norm: batch
15 |   layer:
16 |     d_state: 4
17 |     lr: 0.001
18 |     dt_min: 0.001
19 |     dt_max: 0.1
20 |     bidirectional: true
21 |     learnable_vth: true
22 |     shared_vth: false
23 |     trainable_B: true
24 | 
25 | dataset:
26 |   l_max: 4096
27 |   level: char
28 | 
29 | loader:
30 |   batch_size: 16
31 | 
32 | optimizer:
33 |   lr: 0.01
34 |   weight_decay: 0.05
35 | 
36 | scheduler:
37 |   num_training_steps: 50000
38 |   num_warmup_steps: 5000
39 | 
40 | trainer:
41 |   max_epochs: 32
42 | 
43 | train:
44 |   seed: 3333
45 | 
--------------------------------------------------------------------------------
/configs/experiment/spikingssm/listops.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 |   - /pipeline: listops
4 |   - override /scheduler: cosine_warmup
5 | 
6 | model:
7 |   _name_: spikingssm
8 |   dropout: 0
9 |   # tie_dropout: true
10 |   n_layers: 6
11 |   d_model: 256
12 |   prenorm: false
13 |   norm: batch
14 |   layer:
15 |     d_state: 4
16 |     lr: 0.001
17 |     dt_min: 0.001
18 |     dt_max: 0.1
19 |     bidirectional: true
20 |     learnable_vth: true
21 |     shared_vth: false
22 |     trainable_B: true
23 | 
24 | decoder:
25 |   mode: pool
26 | 
27 | loader:
28 |   batch_size: 32
29 | 
30 | optimizer:
31 |   lr: 0.01
32 |   weight_decay: 0.01
33 | 
34 | scheduler:
35 |   num_training_steps: 120000
36 |   num_warmup_steps: 12000
37 | 
38 | trainer:
39 |   max_epochs: 40
40 | 
41 | train:
42 |   seed: 1234
43 | 
--------------------------------------------------------------------------------
/configs/experiment/spikingssm/cifar.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 |   - /pipeline: cifar
4 |   - override /scheduler: cosine_warmup
5 | 
6 | model:
7 |   _name_: spikingssm
8 |   dropout: 0.1
9 |   # tie_dropout: true
10 |   n_layers: 6
11 |   d_model: 512
12 |   prenorm: false
13 |   norm: layer
14 |   layer:
15 |     d_state: 64
16 |     bidirectional: true
17 |     learnable_vth: true
18 |     lr: 0.001
19 |     dt_min: 0.001
20 |     dt_max: 0.1
21 |     shared_vth: false
22 |     trainable_B: false
23 | 
24 | dataset:
25 |   grayscale: true
26 | 
27 | loader:
28 |   batch_size: 50
29 | 
30 | optimizer:
31 |   lr: 0.01
32 |   weight_decay: 0.01
33 | 
34 | trainer:
35 |   max_epochs: 200
36 | 
37 | scheduler:
38 |   num_training_steps: 180000 # 200 epochs
39 |   num_warmup_steps: 18000
40 | 
41 | train:
42 |   seed: 2222
43 | 
--------------------------------------------------------------------------------
/configs/experiment/spikingssm/pathfinder.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 |   - /pipeline: pathfinder
4 |   - override /scheduler: cosine_warmup
5 | 
6 | scheduler:
7 |   num_training_steps: 500000 # 200 epochs
8 |   num_warmup_steps: 50000
9 | 
10 | model:
11 |   _name_: spikingssm
12 |   dropout: 0.0
13 |   n_layers: 6
14 |   prenorm: true
15 |   d_model: 256
16 |   norm: batch
17 |   layer:
18 |     d_state: 64
19 |     lr: 0.001
20 |     dt_min: 0.001
21 |     dt_max: 0.1
22 |     bidirectional: true
23 |     learnable_vth: true
24 |     shared_vth: false
25 |     trainable_B: false
26 | 
27 | decoder:
28 |   mode: pool
29 | 
30 | loader:
31 |   batch_size: 64
32 | 
33 | optimizer:
34 |   lr: 0.004
35 |   weight_decay: 0.01
36 | 
37 | trainer:
38 |   max_epochs: 200
39 | 
40 | train:
41 |   seed: 3333
42 |   interval: step
43 | 
--------------------------------------------------------------------------------
/configs/experiment/spikingssm/wt103.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 |   - /pipeline: wt103
4 | 
5 | # Dataset
6 | dataset:
7 |   test_split: True
8 | 
9 | loader:
10 |   batch_size: 1
11 |   l_max: 8192
12 |   n_context: 1
13 |   eval:
14 |     batch_size: null
15 |     l_max: null
16 | 
17 | task:
18 |   div_val: 4
19 |   dropemb: 0.25
20 |   dropsoft: 0.25
21 | 
22 | # Model
23 | model:
24 |   _name_: spikingssm
25 |   dropout: 0.1
26 |   prenorm: True
27 |   n_layers: 16
28 |   d_model: 1024
29 |   transposed: True
30 |   layer:
31 |     d_state: 64
32 |     lr: 0.001
33 | 
34 | # Optimizer (adamw)
35 | optimizer:
36 |   lr: 5e-4
37 |   weight_decay: 0.1
38 | 
39 | trainer:
40 |   max_epochs: 1000
41 | 
42 | # Scheduler (cosine)
43 | scheduler:
44 |   num_warmup_steps: 1000
45 |   num_training_steps: 800000
46 | 
47 | train:
48 |   seed: 1111
49 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Shuaijie Shen
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | import torch
4 | 
5 | 
6 | def parameters_count(model: torch.nn.Module):
7 |     return sum(p.nelement() for p in model.parameters())
8 | 
9 | 
10 | class Metrics:
11 |     def __init__(
12 |         self, name: str, scale: float = 1.0, format=".4f", suffix: str = ""
13 |     ) -> None:
14 |         self.name = name
15 |         self.scale = scale
16 |         self.format = format
17 |         self.suffix = suffix
18 |         self.reset()
19 | 
20 |     def reset(self):
21 |         self.val = 0
22 |         self.num = 0
23 | 
24 |     def update(self, val: float, num: int = 1):
25 |         self.val += val * num
26 |         self.num += num
27 | 
28 |     def __lt__(self, other: Metrics):
29 |         return self.avg < other.avg
30 | 
31 |     @property
32 |     def avg(self):
33 |         if self.num == 0:
34 |             return 0.0
35 |         return self.val / self.num
36 | 
37 |     def __str__(self):
38 |         return f"{self.name}: {self.avg * self.scale:{self.format}}{self.suffix}"
39 | 
40 | 
41 | class MetricsCheckpoint:
42 |     def __init__(self, key: str = None, **kwargs) -> None:
43 |         self.key = key
44 |         self.data = kwargs
45 | 
46 |     def __lt__(self, other: MetricsCheckpoint):
47 |         if self.key is not None and self.key in self.data:
48 |             return self.data[self.key] < other.data[self.key]
49 |         for key in self.data.keys():  # fall back to comparing on the first key only
50 |             return self.data[key] < other.data[key]
51 | 
52 |     def __str__(self):
53 |         return " | ".join(f"{k}: {v:.6f}" for k, v in self.data.items())
54 | 
55 | 
56 | if __name__ == "__main__":
57 |     model = torch.nn.Linear(1024, 1024)
58 |     print(f"{parameters_count(model):,}")
59 | 
60 |     m1 = MetricsCheckpoint(
61 |         accuracy=0.18,
62 |         epoch=1,
63 |     )
64 |     m2 = MetricsCheckpoint(accuracy=0.32, epoch=2)
65 |     print(m2)
66 | 
67 |     m = Metrics("accuracy", 100, suffix="%")
68 |     m.update(0.9)
69 |     print(m)
70 |     m.update(0.8)
71 |     print(m)
72 | 
--------------------------------------------------------------------------------
/src/models/spike/neuron.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import os
4 | 
5 | 
6 | class SDNNeuron(nn.Module):
7 |     ckt_path = "mem_pred_d8k8_c3(relu(c2(c1(x))+c1(x)))_125.pkl"
8 | 
9 |     def __init__(self, surrogate_function):
10 |         super().__init__()
11 |         self.surrogate_function = surrogate_function
12 |         ckt_path = os.path.join(os.path.dirname(__file__), self.ckt_path)
13 |         self.model = torch.jit.load(ckt_path).eval()
14 | 
15 |     def forward(self, x):
16 |         mem = self.pred(x)
17 |         s = self.surrogate_function(mem + x - 1.0)
18 |         return s
19 | 
20 |     @torch.no_grad()
21 |     def pred(self, x):
22 |         shape = x.shape
23 |         L = x.size(-1)
24 |         return self.model(x.detach().view(-1, 1, L)).view(shape)
25 | 
26 | 
27 | class BPTTNeuron(nn.Module):
28 |     def __init__(self, surrogate_function, tau=0.125, vth=1.0, v_r=0):
29 |         super().__init__()
30 |         self.surrogate_function = surrogate_function
31 |         self.tau = tau
32 |         self.vth = vth
33 |         self.v_r = v_r
34 | 
35 |     def forward(self, x):
36 |         u = torch.zeros_like(x[..., 0])
37 |         out = []
38 |         for i in range(x.size(-1)):
39 |             u = u * self.tau + x[..., i]
40 |             s = self.surrogate_function(u - self.vth)
41 |             out.append(s)
42 |             u = (1 - s.detach()) * u + s.detach() * self.v_r
43 |         return torch.stack(out, -1)
44 | 
45 | 
46 | class SLTTNeuron(nn.Module):
47 |     def __init__(self, surrogate_function, tau=0.125, vth=1.0, v_r=0):
48 |         super().__init__()
49 |         self.surrogate_function = surrogate_function
50 |         self.tau = tau
51 |         self.vth = vth
52 |         self.v_r = v_r
53 | 
54 |     def forward(self, x):
55 |         u = torch.zeros_like(x[..., 0])
56 |         out = []
57 |         for i in range(x.size(-1)):
58 |             u = u.detach() * self.tau + x[..., i]
59 |             s = self.surrogate_function(u - self.vth)
60 |             out.append(s)
61 |             u = (1 - s.detach()) * u + s.detach() * self.v_r
62 |         return torch.stack(out, -1)
63 | 
64 | 
65 | registry = {
66 |     "sdn": SDNNeuron,
67 |     "bptt": BPTTNeuron,
68 |     "sltt": SLTTNeuron,
69 | }
70 | 
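71 | # Usage sketch: neurons are constructed with a surrogate spike function, as is
72 | # done in models/spike/ss4d.py, e.g.:
73 | #
74 | #   from src.models.spike.surrogate import piecewise_quadratic_surrogate
75 | #   neuron = registry["sdn"](piecewise_quadratic_surrogate())
76 | #   spikes = neuron(torch.randn(4, 8, 1024))  # input currents of shape (B, H, L)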
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import warnings
4 | from torch import Tensor
5 | import argparse
6 | 
7 | 
8 | def heaviside(x: Tensor):
9 |     """heaviside function
10 | 
11 |     Args:
12 |         x (Tensor): u - vth
13 | 
14 |     Returns:
15 |         Tensor: spike
16 |     """
17 |     return (x >= 0).int()
18 | 
19 | 
20 | @torch.no_grad()
21 | def hardreset(x: Tensor, tau: float = 0.2, v_th: float = 1.0):
22 |     """perform LIF evolution with the hard-reset mechanism
23 | 
24 |     Args:
25 |         x (Tensor): input currents with (T, N)
26 |         tau (float, optional): attenuation coefficient. Defaults to 0.2.
27 |         v_th (float, optional): threshold when to spike. Defaults to 1.0.
28 | 
29 |     Returns:
30 |         (Tensor, Tensor): spikes, attenuated membrane potential
31 |     """
32 |     x = x.cuda()
33 |     if len(x.shape) == 1:
34 |         x = x.view(-1, 1)
35 |     y = []
36 |     mem = []
37 |     u = torch.zeros_like(x[0])
38 |     for i in x:
39 |         u = tau * u
40 |         mem.append(u)
41 |         u = u + i
42 |         s = heaviside(u - v_th)
43 |         y.append(s)
44 |         u = u * (1 - s)
45 |     y = torch.stack(y).int()
46 |     mem = torch.stack(mem)
47 |     return y, mem
48 | 
49 | 
50 | @torch.no_grad()
51 | def generate_dataset(
52 |     root, name="training", number=5000, timestp=1024, m=0, std=1.0, tau=0.2
53 | ):
54 |     """Generate and save a dataset of LIF input currents, membrane potentials and spikes.
55 | 
56 |     Args:
57 |         root (path): the path to save dataset
58 |         name (str, optional): the name of dataset. Defaults to "training".
59 |         number (int, optional): sample number. Defaults to 5000.
60 |         timestp (int, optional): steps. Defaults to 1024.
61 |         m (int, optional): mean of input current. Defaults to 0.
62 |         std (float, optional): std of input current. Defaults to 1.0.
63 |         tau (float, optional): attenuation coefficient. Defaults to 0.2.
64 |         Note: the spiking threshold is fixed at the `hardreset` default of v_th = 1.0.
65 |     """
66 |     filename = os.path.join(
67 |         root, f"{name}-mem-T{timestp}-N({m},{std})-{number}-tau_{tau}.pt"
68 |     )
69 |     if os.path.exists(filename):
70 |         warnings.warn(f"File `{filename}` exists, program terminated!")
71 |         return
72 | 
73 |     os.makedirs(root, exist_ok=True)
74 |     print("Sampling input currents ==>")
75 |     dset_x = torch.randn(timestp, number) * std + m
76 |     print("Generating label ==>")
77 |     s, mem = hardreset(dset_x, tau=tau)
78 |     print(f"spiking rate: {s.float().mean().item()}")
79 |     data = {
80 |         "input": dset_x.transpose(0, 1).cpu(),
81 |         "mem": mem.transpose(0, 1).cpu(),
82 |         "spike": s.transpose(0, 1).cpu(),
83 |     }
84 |     print("Dataset generated.")
85 |     torch.save(data, filename)
86 |     print(f"Dataset saved! File: {filename}")
87 |     print(
88 |         'File format: {"input": %s, "mem": %s, "spike": %s}'
89 |         % (data["input"].size(), data["mem"].size(), data["spike"].size())
90 |     )
91 | 
92 | 
93 | if __name__ == "__main__":
94 |     parser = argparse.ArgumentParser(description="Dataset generation script.")
95 |     parser.add_argument("root", help="path to save dataset.")
96 |     parser.add_argument("-n", "--name", default="training", type=str, help="name of dataset")
97 |     parser.add_argument("-N", "--number", default=50000, type=int, help="size of dataset")
98 |     parser.add_argument("-T", "--timestep", default=1024, type=int, help="number of steps")
99 |     parser.add_argument("-m", "--mean", default=0.0, type=float, help="mean of input current")
100 |     parser.add_argument("-s", "--std", default=1.0, type=float, help="std of input current")
101 |     parser.add_argument("-t", "--tau", default=0.2, type=float, help="attenuation coefficient")
102 |     args = parser.parse_args()
103 |     generate_dataset(args.root, args.name, args.number, args.timestep, args.mean, args.std, args.tau)
104 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | #  Usually these files are written by a python script from a template
31 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | #   For a library or package, you might want to ignore these files since the code is
87 | #   intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | #   install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # poetry
98 | #   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | #   This is especially recommended for binary packages to ensure reproducibility, and is more
100 | #   commonly ignored for libraries.
101 | #   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 | 
104 | # pdm
105 | #   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | #   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | #   in version control.
109 | #   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 | 
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 | 
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 | 
121 | # SageMath parsed files
122 | *.sage.py
123 | 
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 | 
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 | 
137 | # Rope project settings
138 | .ropeproject
139 | 
140 | # mkdocs documentation
141 | /site
142 | 
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 | 
148 | # Pyre type checker
149 | .pyre/
150 | 
151 | # pytype static type analyzer
152 | .pytype/
153 | 
154 | # Cython debug symbols
155 | cython_debug/
156 | 
157 | # PyCharm
158 | #  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | #  and can be added to the global gitignore or merged into this file.  For a more nuclear
161 | #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | 
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import Tensor, nn
6 | 
7 | 
8 | def get_conv1d_kernel_bn(
9 |     in_channels: int, out_channels: int, kernel_size: int, padding: int
10 | ):
11 |     return nn.Sequential(
12 |         nn.Conv1d(
13 |             in_channels, out_channels, kernel_size, padding=padding, groups=in_channels
14 |         ),
15 |         nn.BatchNorm1d(out_channels),
16 |     )
17 | 
18 | 
19 | def get_conv1d_k1_bn(
20 |     in_channels: int,
21 |     out_channels: int,
22 | ):
23 |     return nn.Sequential(
24 |         nn.Conv1d(in_channels, out_channels, 1),
25 |         nn.BatchNorm1d(out_channels),
26 |     )
27 | 
28 | 
29 | def trunc(x: Tensor, keep: int):
30 |     """Truncate a Tensor along the last dimension
31 | 
32 |     Args:
33 |         x (Tensor): tensor to be truncated
34 |         keep (int): number of elements kept
35 |     Returns:
36 |         Tensor: Truncated tensor
37 |     """
38 |     return x[..., :keep]
39 | 
40 | 
41 | class SDN(nn.Module):
42 |     def __init__(
43 |         self,
44 |         d_model=8,
45 |         kernel_size=8,
46 |         n_layers=1,
47 |     ):
48 |         super().__init__()
49 |         # In order to fuse this module into the next layer, the bias is set to False and no batch norm follows.
50 |         self.encoder = nn.Conv1d(1, d_model, kernel_size=1, bias=False)
51 |         self.spatial_layers = nn.ModuleList(
52 |             [
53 |                 get_conv1d_kernel_bn(
54 |                     d_model, d_model, kernel_size=kernel_size, padding=kernel_size
55 |                 )
56 |             ]
57 |         )
58 |         self.feature_layers = nn.ModuleList(
59 |             get_conv1d_k1_bn(d_model, d_model) for _ in range(n_layers)
60 |         )
61 |         for _ in range(n_layers - 1):
62 |             self.spatial_layers.append(
63 |                 get_conv1d_kernel_bn(
64 |                     d_model, d_model, kernel_size=kernel_size, padding=kernel_size - 1
65 |                 )
66 |             )
67 | 
68 |         self.decoder = nn.Conv1d(d_model, 1, 1)
69 | 
70 |     def forward(self, x):
71 |         """
72 |         Input x is shape (B, 1, L)
73 |         """
74 |         L = x.size(-1)
75 |         truncL = partial(trunc, keep=L)
76 |         x = self.encoder(x)
77 |         for spatial, features in zip(self.spatial_layers, self.feature_layers):
78 |             x = F.relu(truncL(spatial(x)))
79 |             x = F.relu(features(x) + x)
80 |         return self.decoder(x).squeeze()
81 | 
82 | 
83 | class FusedSDN(nn.Module):
84 |     def __init__(
85 |         self,
86 |         d_model=8,
87 |         kernel_size=8,
88 |         n_layers=1,
89 |     ):
90 |         super().__init__()
91 |         # The encoder is fused into the first spatial layer; this conv is kept but unused in forward.
92 |         self.encoder = nn.Conv1d(1, d_model, kernel_size=1, padding=1, bias=False)
93 |         self.spatial_layers = nn.ModuleList(
94 |             [nn.Conv1d(1, d_model, kernel_size=kernel_size, padding=kernel_size)]
95 |         )
96 |         self.feature_layers = nn.ModuleList(
97 |             nn.Conv1d(d_model, d_model, 1) for _ in range(n_layers)
98 |         )
99 |         for _ in range(n_layers - 1):
100 |             self.spatial_layers.append(
101 |                 nn.Conv1d(
102 |                     d_model,
103 |                     d_model,
104 |                     kernel_size=kernel_size,
105 |                     padding=kernel_size - 1,
106 |                     groups=d_model,
107 |                 )
108 |             )
109 | 
110 |         self.decoder = nn.Conv1d(d_model, 1, 1)
111 | 
112 |     def forward(self, x):
113 |         """
114 |         Input x is shape (B, 1, L)
115 |         """
116 |         L = x.size(-1)
117 |         truncL = partial(trunc, keep=L)
118 |         for spatial, features in zip(self.spatial_layers, self.feature_layers):
119 |             x = F.relu(truncL(spatial(x)))
120 |             x = F.relu(features(x) + x)
121 |         return self.decoder(x).squeeze()
122 | 
123 | 
124 | if __name__ == "__main__":
125 |     model = SDN()
126 |     x = torch.randn(64, 1, 1024)
127 |     model(x)
128 | 
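129 | # Note: SDN takes input currents of shape (B, 1, L) and regresses the attenuated
130 | # membrane potential at every step (see train.py). convert.ipynb fuses a trained
131 | # SDN into FusedSDN and exports it with torch.jit.trace so it can serve as the
132 | # surrogate LIF dynamic in src/models/spike/neuron.py.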
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Surrogate Dynamic Network (SDN) and SpikingSSMs
2 | 
3 | ![SpikingSSMs Architecture](assets/arch.png "Architecture and Computation Graph of SpikingSSMs")
4 | 
5 | > **SpikingSSMs: Learning Long Sequences with Sparse and Parallel Spiking State Space Models**
6 | > Shuaijie Shen*, Chao Wang*, Renzhuo Huang, Yan Zhong, Qinghai Guo, Zhichao Lu, Jianguo Zhang, Luziwei Leng
7 | > Paper: [https://arxiv.org/abs/2408.14909](https://arxiv.org/abs/2408.14909)
8 | 
9 | This repository provides the official implementations and experiments for **SDN** (Surrogate Dynamic Network) and **SpikingSSMs** (Spiking State Space Models).
10 | 
11 | ---
12 | 
13 | ## Table of Contents
14 | 
15 | - [Surrogate Dynamic Network (SDN) and SpikingSSMs](#surrogate-dynamic-network-sdn-and-spikingssms)
16 |   - [Table of Contents](#table-of-contents)
17 |   - [Overview](#overview)
18 |   - [Installation](#installation)
19 |     - [Dependencies](#dependencies)
20 |       - [SDN Requirements](#sdn-requirements)
21 |       - [SpikingSSMs Requirements](#spikingssms-requirements)
22 |   - [Quick Start](#quick-start)
23 |     - [Generating Datasets](#generating-datasets)
24 |     - [Training SDN](#training-sdn)
25 |     - [Model Reduction](#model-reduction)
26 |   - [Reproducing SpikingSSMs Experiments](#reproducing-spikingssms-experiments)
27 |   - [Citation](#citation)
28 |   - [License](#license)
29 | 
30 | ---
31 | 
32 | ## Overview
33 | 
34 | This repository contains two core components:
35 | 1. **SDN**: A lightweight module for simulating spiking neuron dynamics.
36 | 2. **SpikingSSMs**: A novel architecture combining spiking neural networks with state space models for long-sequence tasks.
37 | 
38 | ---
39 | 
40 | ## Installation
41 | 
42 | ### Dependencies
43 | 
44 | #### SDN Requirements
45 | - Python 3.8+
46 | - PyTorch ≥1.10
47 | - loguru
48 | 
49 | Install via conda/pip:
50 | ```bash
51 | # PyTorch with CUDA 11.8
52 | conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia
53 | 
54 | # Loguru
55 | pip install loguru
56 | ```
57 | 
58 | #### SpikingSSMs Requirements
59 | Clone the official S4 repository and install dependencies:
60 | ```bash
61 | git clone https://github.com/state-spaces/s4.git
62 | cd s4
63 | # Follow S4's installation instructions
64 | ```
65 | 
66 | ---
67 | 
68 | ## Quick Start
69 | 
70 | ### Generating Datasets
71 | 
72 | 1. Generate training data:
73 | ```bash
74 | python generate.py dataset
75 | ```
76 | 
77 | 2. Generate test data:
78 | ```bash
79 | python generate.py dataset -n test
80 | ```
81 | 
82 | **Dataset Structure** ([dataset/](dataset/)):
83 | ```
84 | training-mem-T1024-N(0.0,1.0)-50000-tau_0.2.pt
85 | test-mem-T1024-N(0.0,1.0)-50000-tau_0.2.pt
86 | ```
87 | 
88 | Each file contains:
89 | - `input`: Input current tensor (shape: [50000, 1024])
90 | - `mem`: Attenuated membrane potential (shape: [50000, 1024])
91 | - `spike`: Spike train (shape: [50000, 1024])
92 | 
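93 | A generated file can be loaded directly with `torch.load`, or wrapped with the
94 | `MemDataset` class from [dataset.py](dataset.py) for training. A minimal sketch
95 | (the exact filename depends on your generation arguments):
96 | ```python
97 | import torch
98 | 
99 | from dataset import MemDataset
100 | 
101 | data = torch.load("dataset/training-mem-T1024-N(0.0,1.0)-50000-tau_0.2.pt")
102 | print(data["input"].shape, data["mem"].shape, data["spike"].shape)
103 | 
104 | ds = MemDataset("dataset/training-mem-T1024-N(0.0,1.0)-50000-tau_0.2.pt")
105 | x, mem, spike = ds[0]  # one input-current trace and its targets
106 | ```
107 | 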
108 | ### Training SDN
109 | 
110 | ```bash
111 | python train.py \
112 |     --training 'dataset/training-mem-T1024-N(0.0,1.0)-50000-tau_0.2.pt' \
113 |     --test 'dataset/test-mem-T1024-N(0.0,1.0)-50000-tau_0.2.pt' \
114 |     --save exp1
115 | ```
116 | 
117 | Logs and checkpoints will be saved in [exp1/](exp1/).
118 | 
119 | ### Model Reduction
120 | 
121 | Optimize SDN for inference:
122 | ```bash
123 | jupyter notebook convert.ipynb  # Follow interactive instructions
124 | ```
125 | 
126 | ---
127 | 
128 | ## Reproducing SpikingSSMs Experiments
129 | 
130 | 1. Clone and set up S4:
131 | ```bash
132 | git clone https://github.com/state-spaces/s4.git
133 | cd s4
134 | # Install S4 dependencies (refer to their documentation)
135 | ```
136 | 
137 | 2. Integrate our components:
138 | ```bash
139 | cp -r /path/to/this/repo/src ./src
140 | cp -r /path/to/this/repo/models ./models
141 | cp -r /path/to/this/repo/configs ./configs
142 | ```
143 | 
144 | 3. Run the CIFAR-10 experiment:
145 | ```bash
146 | python -m train experiment=spikingssm/cifar
147 | ```
148 | 
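149 | The other experiments follow the same pattern; experiment names mirror the config
150 | files under [configs/experiment/spikingssm/](configs/experiment/spikingssm/), e.g.:
151 | ```bash
152 | python -m train experiment=spikingssm/listops
153 | python -m train experiment=spikingssm/pathfinder
154 | ```
155 | 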
"dcgru": "src.models.baselines.dcgru.DCRNNModel_classification", 51 | "dcgru_ss": "src.models.baselines.dcgru.DCRNNModel_nextTimePred", 52 | # Baseline CNNs 53 | "ckconv": "src.models.baselines.ckconv.ClassificationCKCNN", 54 | "wavegan": "src.models.baselines.wavegan.WaveGANDiscriminator", # DEPRECATED 55 | "denseinception": "src.models.baselines.dense_inception.DenseInception", 56 | "wavenet": "src.models.baselines.wavenet.WaveNetModel", 57 | "torch/resnet2d": "src.models.baselines.resnet.TorchVisionResnet", # 2D ResNet 58 | # Nonaka 1D CNN baselines 59 | "nonaka/resnet18": "src.models.baselines.nonaka.resnet.resnet1d18", 60 | "nonaka/inception": "src.models.baselines.nonaka.inception.inception1d", 61 | "nonaka/xresnet50": "src.models.baselines.nonaka.xresnet.xresnet1d50", 62 | # ViT Variants (note: small variant is taken from Tri, differs from original) 63 | "vit": "models.baselines.vit.ViT", 64 | "vit_s_16": "src.models.baselines.vit_all.vit_small_patch16_224", 65 | "vit_b_16": "src.models.baselines.vit_all.vit_base_patch16_224", 66 | # Timm models 67 | "timm/convnext_base": "src.models.baselines.convnext_timm.convnext_base", 68 | "timm/convnext_small": "src.models.baselines.convnext_timm.convnext_small", 69 | "timm/convnext_tiny": "src.models.baselines.convnext_timm.convnext_tiny", 70 | "timm/convnext_micro": "src.models.baselines.convnext_timm.convnext_micro", 71 | "timm/resnet50": "src.models.baselines.resnet_timm.resnet50", # Can also register many other variants in resnet_timm 72 | "timm/convnext_tiny_3d": "src.models.baselines.convnext_timm.convnext3d_tiny", 73 | # Segmentation models 74 | "convnext_unet_tiny": "src.models.segmentation.convnext_unet.convnext_tiny_unet", 75 | } 76 | 77 | layer = { 78 | "id": "src.models.sequence.base.SequenceIdentity", 79 | "lstm": "src.models.baselines.lstm.TorchLSTM", 80 | "standalone": "models.s4.s4.S4Block", 81 | "s4d": "models.s4.s4d.S4D", 82 | "ffn": "src.models.sequence.modules.ffn.FFN", 83 | "sru": "src.models.sequence.rnns.sru.SRURNN", 84 | "rnn": "src.models.sequence.rnns.rnn.RNN", # General RNN wrapper 85 | "conv1d": "src.models.sequence.convs.conv1d.Conv1d", 86 | "conv2d": "src.models.sequence.convs.conv2d.Conv2d", 87 | "mha": "src.models.sequence.attention.mha.MultiheadAttention", 88 | "vit": "src.models.sequence.attention.mha.VitAttention", 89 | "performer": "src.models.sequence.attention.linear.Performer", 90 | "lssl": "src.models.sequence.modules.lssl.LSSL", 91 | "s4": "src.models.sequence.modules.s4block.S4Block", 92 | "fftconv": "src.models.sequence.kernels.fftconv.FFTConv", 93 | "s4nd": "src.models.sequence.modules.s4nd.S4ND", 94 | "mega": "src.models.sequence.modules.mega.MegaBlock", 95 | "h3": "src.models.sequence.experimental.h3.H3", 96 | "h4": "src.models.sequence.experimental.h4.H4", 97 | # 'packedrnn': 'models.sequence.rnns.packedrnn.PackedRNN', 98 | "ss4d": "models.spike.ss4d.SS4D", 99 | } 100 | 101 | layer_decay = { 102 | "convnext_timm_tiny": "src.models.baselines.convnext_timm.get_num_layer_for_convnext_tiny", 103 | } 104 | 105 | model_state_hook = { 106 | "convnext_timm_tiny_2d_to_3d": "src.models.baselines.convnext_timm.convnext_timm_tiny_2d_to_3d", 107 | "convnext_timm_tiny_s4nd_2d_to_3d": "src.models.baselines.convnext_timm.convnext_timm_tiny_s4nd_2d_to_3d", 108 | } 109 | -------------------------------------------------------------------------------- /models/spike/ss4d.py: -------------------------------------------------------------------------------- 1 | """Minimal version of S4D with extra options and features 
stripped out, for pedagogical purposes.""" 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange, repeat 8 | 9 | from src.models.nn import DropoutNd 10 | from src.models.spike.neuron import registry 11 | from src.models.spike.surrogate import piecewise_quadratic_surrogate 12 | from src.models.sequence.kernels.ssm import SSMKernelDiag 13 | 14 | 15 | class S4DKernel(nn.Module): 16 | """Generate convolution kernel from diagonal SSM parameters.""" 17 | 18 | def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, channels=1, lr=None): 19 | super().__init__() 20 | # Generate dt 21 | lr = min(lr, 0.001) 22 | H = d_model 23 | log_dt = torch.rand(H) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) 24 | 25 | C = torch.randn(channels, H, N // 2, dtype=torch.cfloat) 26 | self.C = nn.Parameter(torch.view_as_real(C)) 27 | self.register("log_dt", log_dt, lr) 28 | 29 | log_A_real = torch.log(0.5 * torch.ones(H, N // 2)) 30 | A_imag = math.pi * repeat(torch.arange(N // 2), "n -> h n", h=H) 31 | self.register("log_A_real", log_A_real, lr) 32 | self.register("A_imag", A_imag, lr) 33 | 34 | def forward(self, L): 35 | """ 36 | returns: (..., c, L) where c is number of channels (default 1) 37 | """ 38 | 39 | # Materialize parameters 40 | dt = torch.exp(self.log_dt) # (H) 41 | C = torch.view_as_complex(self.C) # (C H N) 42 | A = -torch.exp(self.log_A_real) + 1j * self.A_imag # (H N) 43 | 44 | # Vandermonde multiplication 45 | dtA = A * dt.unsqueeze(-1) # (H N) 46 | K = dtA.unsqueeze(-1) * torch.arange(L, device=A.device) # (H N L) 47 | C = C * (torch.exp(dtA) - 1.0) / A 48 | K = 2 * torch.einsum("chn, hnl -> chl", C, torch.exp(K)).real 49 | 50 | return K, None 51 | 52 | def register(self, name, tensor, lr=None): 53 | """Register a tensor with a configurable learning rate and 0 weight decay""" 54 | 55 | if lr == 0.0: 56 | self.register_buffer(name, tensor) 57 | else: 58 | self.register_parameter(name, nn.Parameter(tensor)) 59 | 60 | optim = {"weight_decay": 0.0} 61 | if lr is not None: 62 | optim["lr"] = lr 63 | setattr(getattr(self, name), "_optim", optim) 64 | 65 | 66 | class SS4D(nn.Module): 67 | def __init__( 68 | self, 69 | d_model, 70 | neuron="sdn", 71 | learnable_vth=True, 72 | shared_vth=False, 73 | d_state=64, 74 | dropout=0.0, 75 | transposed=True, 76 | bidirectional=False, 77 | channels=1, 78 | trainable_B=False, 79 | **kernel_args, 80 | ): 81 | super().__init__() 82 | 83 | self.h = d_model 84 | self.n = d_state 85 | # self.d_output = self.h 86 | self.transposed = transposed 87 | self.learnable_vth = learnable_vth 88 | self.bidirectional = bidirectional 89 | self.D = nn.Parameter(torch.randn(channels, self.h)) 90 | 91 | if learnable_vth: 92 | if shared_vth: 93 | self.ln_vth = nn.Parameter(torch.zeros(1)) 94 | else: 95 | self.ln_vth = nn.Parameter(torch.zeros(d_model, 1)) 96 | 97 | if bidirectional: 98 | channels *= 2 99 | # SSM Kernel 100 | if trainable_B: 101 | self.kernel = SSMKernelDiag(d_model=self.h, d_state=self.n, channels=channels, init="diag-lin", **kernel_args) 102 | else: 103 | self.kernel = S4DKernel(self.h, N=self.n, channels=channels, **kernel_args) 104 | 105 | self.neuron = registry[neuron](piecewise_quadratic_surrogate()) 106 | 107 | # # Pointwise 108 | # self.activation = nn.GELU() no use 109 | 110 | # dropout_fn = nn.Dropout2d # NOTE: bugged in PyTorch 1.11 111 | dropout_fn = DropoutNd 112 | self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() 113 | 114 | # position-wise 
output transform to mix features 115 | self.output_linear = nn.Sequential( 116 | nn.Conv1d(self.h, 2 * self.h, kernel_size=1), 117 | nn.GLU(dim=-2), 118 | ) 119 | 120 | def forward(self, u, **kwargs): # absorbs return_output and transformer src mask 121 | """Input and output shape (B, H, L)""" 122 | if not self.transposed: 123 | u = u.transpose(-1, -2) 124 | L = u.size(-1) 125 | 126 | # Compute SSM Kernel 127 | k, _ = self.kernel(L=L) # (H L) 128 | if self.bidirectional: 129 | k0, k1 = rearrange(k, "(s c) h l -> s c h l", s=2) 130 | k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0)) 131 | 132 | # Convolution 133 | k_f = torch.fft.rfft(k, n=2 * L) # (C H L) 134 | u_f = torch.fft.rfft(u, n=2 * L) # (B H L) 135 | 136 | y = torch.einsum("bhl,chl->bchl", u_f, k_f) 137 | 138 | y = torch.fft.irfft(y, n=2 * L)[..., :L] # (B C H L) 139 | 140 | # Compute D term in state space equation - essentially a skip connection 141 | y = y + torch.einsum("bhl,ch->bchl", u, self.D) 142 | 143 | y = rearrange(y, "b c h l -> b (c h) l") 144 | 145 | if self.learnable_vth: 146 | y = y / torch.exp(self.ln_vth) 147 | 148 | y = self.dropout(self.neuron(y)) 149 | y = self.output_linear(y) 150 | if not self.transposed: 151 | y = y.transpose(-1, -2) 152 | return ( 153 | y, 154 | None, 155 | ) # Return a dummy state to satisfy this repo's interface, but this can be modified 156 | 157 | 158 | class SpikingSSM(nn.Module): 159 | def __init__( 160 | self, 161 | d_model=256, 162 | n_layers=4, 163 | dropout=0.2, 164 | prenorm=False, 165 | norm="layer", 166 | layer=None, # layer config 167 | **kwargs, 168 | ): 169 | super().__init__() 170 | 171 | self.prenorm = prenorm 172 | self.norm = norm 173 | 174 | # for dataset adaptability 175 | self.d_model = self.d_output = d_model 176 | 177 | # Stack S4 layers as residual blocks 178 | self.s4_layers = nn.ModuleList() 179 | self.norms = nn.ModuleList() 180 | self.dropouts = nn.ModuleList() 181 | for _ in range(n_layers): 182 | self.s4_layers.append(SS4D(d_model, dropout=dropout, transposed=True, **layer)) 183 | if norm == "batch": 184 | self.norms.append(nn.BatchNorm1d(d_model)) 185 | elif norm == "layer": 186 | self.norms.append(nn.LayerNorm(d_model)) 187 | self.dropouts.append(DropoutNd(dropout)) 188 | 189 | def forward(self, x, **kwargs): 190 | """ 191 | Input x is shape (B, L, d_input) 192 | """ 193 | # x = self.encoder(x) # (B, L, d_input) -> (B, L, d_model) 194 | x = x.transpose(-1, -2) # (B, L, d_model) -> (B, d_model, L) 195 | for layer, norm, dropout in zip(self.s4_layers, self.norms, self.dropouts): 196 | # Each iteration of this loop will map (B, d_model, L) -> (B, d_model, L) 197 | 198 | z = x 199 | # Prenorm 200 | if self.prenorm: 201 | z = norm(z) if self.norm == "batch" else norm(z.transpose(-1, -2)).transpose(-1, -2) 202 | 203 | # Apply S4 block: we ignore the state input and output 204 | z, _ = layer(z) 205 | 206 | # Dropout on the output of the S4 block 207 | z = dropout(z) 208 | 209 | # Residual connection 210 | x = z + x 211 | 212 | if not self.prenorm: 213 | # Postnorm 214 | z = norm(z) if self.norm == "batch" else norm(z.transpose(-1, -2)).transpose(-1, -2) 215 | 216 | x = x.transpose(-1, -2) 217 | 218 | return x, None 219 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | 
import torchvision.transforms as transforms 10 | from loguru import logger 11 | 12 | from dataset import MemDataset 13 | from model import SDN 14 | from utils import Metrics, MetricsCheckpoint, parameters_count 15 | 16 | parser = argparse.ArgumentParser(description="SDN Training") 17 | parser.add_argument("--lr", default=0.01, type=float, help="Learning rate") 18 | parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay") 19 | parser.add_argument("--epochs", default=100, type=int, help="Training epochs") 20 | parser.add_argument( 21 | "--num_workers", default=4, type=int, help="Number of workers to use for dataloader" 22 | ) 23 | parser.add_argument("--batch_size", default=64, type=int, help="Batch size") 24 | # Model 25 | parser.add_argument("--n_layers", default=1, type=int, help="Number of layers") 26 | parser.add_argument("--d_model", default=8, type=int, help="Model dimension") 27 | parser.add_argument("--k", "--kernel_size", default=8, type=int, help="Kernel size") 28 | 29 | # Dataset 30 | parser.add_argument( 31 | "--training", type=str, required=True, help="Path to training dataset" 32 | ) 33 | parser.add_argument("--test", type=str, required=True, help="Path to test dataset") 34 | 35 | # General 36 | items_group = parser.add_mutually_exclusive_group() 37 | items_group.add_argument( 38 | "--resume", "-r", default=None, type=str, help="Path where checkpoint to resume" 39 | ) 40 | items_group.add_argument( 41 | "--save", 42 | "-s", 43 | default="exp", 44 | type=str, 45 | help="Path where checkpoint to save and will be inactive given `resume`", 46 | ) 47 | 48 | args = parser.parse_args() 49 | 50 | device = "cuda" if torch.cuda.is_available() else "cpu" 51 | logger.info(f"Model will be trained on device: {device}") 52 | 53 | if args.resume is not None: 54 | args.save = args.resume 55 | 56 | logger.add(os.path.join(args.save, "exp.log")) 57 | 58 | # Data 59 | logger.info("==> Preparing data.") 60 | 61 | 62 | def split_train_val(train, val_split): 63 | train_len = int(len(train) * (1.0 - val_split)) 64 | train, val = torch.utils.data.random_split( 65 | train, 66 | (train_len, len(train) - train_len), 67 | generator=torch.Generator().manual_seed(42), 68 | ) 69 | return train, val 70 | 71 | 72 | transform = transforms.Lambda(lambda x: x.view(1, -1)) 73 | 74 | transform_train = transform_test = transform 75 | 76 | trainset = MemDataset( 77 | filename=args.training, 78 | transform=transform_train, 79 | ) 80 | trainset, valset = split_train_val(trainset, val_split=0.1) 81 | 82 | testset = MemDataset( 83 | filename=args.test, 84 | transform=transform_test, 85 | ) 86 | 87 | # Dataloaders 88 | trainloader = torch.utils.data.DataLoader( 89 | trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers 90 | ) 91 | valloader = torch.utils.data.DataLoader( 92 | valset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers 93 | ) 94 | testloader = torch.utils.data.DataLoader( 95 | testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers 96 | ) 97 | 98 | # Model 99 | logger.info("==> Building model.") 100 | model = SDN( 101 | d_model=args.d_model, 102 | kernel_size=args.k, 103 | n_layers=args.n_layers, 104 | ) 105 | 106 | logger.info(model) 107 | logger.info(f"Params: {parameters_count(model):,}") 108 | 109 | model = model.to(device) 110 | if device == "cuda": 111 | cudnn.benchmark = True 112 | 113 | criterion = nn.SmoothL1Loss() 114 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, 
weight_decay=args.weight_decay) 115 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs) 116 | 117 | 118 | best_metrics_checkpoint = MetricsCheckpoint(loss=float("+inf"), epoch=-1) 119 | start_epoch = 0 120 | 121 | last_checkpoint_filename = os.path.join(args.save, "last_checkpoint.pth") 122 | best_checkpoint_filename = os.path.join(args.save, "best_checkpoint.pth") 123 | 124 | if args.resume: 125 | # Load checkpoint. 126 | logger.info("==> Resuming from checkpoint..") 127 | assert os.path.exists(last_checkpoint_filename), "Error: no checkpoint found!" 128 | checkpoint = torch.load(last_checkpoint_filename) 129 | model.load_state_dict(checkpoint["model"]) 130 | optimizer.load_state_dict(checkpoint["optimizer"]) 131 | scheduler.load_state_dict(checkpoint["scheduler"]) 132 | start_epoch = checkpoint["epoch"] + 1 133 | if os.path.exists(last_checkpoint_filename): 134 | best_metrics_checkpoint = torch.load(best_checkpoint_filename)["metrics"] 135 | 136 | 137 | # Training 138 | def train(trainloader): 139 | model.train() 140 | train_loss = Metrics("Loss") 141 | abs_error = [] 142 | acc1 = Metrics("Acc@1", scale=100, format=".2f", suffix="%") 143 | for inputs, targets, spikes in trainloader: 144 | inputs, targets, spikes = ( 145 | inputs.to(device), 146 | targets.to(device), 147 | spikes.to(device), 148 | ) 149 | optimizer.zero_grad() 150 | outputs = model(inputs) 151 | loss = criterion(outputs, targets) 152 | 153 | pred_s = (outputs.detach() + inputs.squeeze() >= 1).float() 154 | acc = (pred_s == spikes).float().mean() 155 | acc1.update(acc.item(), inputs.size(0)) 156 | 157 | loss.backward() 158 | optimizer.step() 159 | 160 | abs_error.append((outputs - targets).abs()) 161 | train_loss.update(loss.item()) 162 | abs_error = torch.cat(abs_error) 163 | std, mean = torch.std_mean(abs_error) 164 | return MetricsCheckpoint( 165 | loss=train_loss.avg, 166 | mae_max=abs_error.max().item(), 167 | mae_mean=mean.item(), 168 | mae_std=std.item(), 169 | acc1=acc1.avg, 170 | ) 171 | 172 | 173 | @torch.no_grad() 174 | def eval(dataloader): 175 | model.eval() 176 | eval_loss = Metrics("Loss") 177 | abs_error = [] 178 | acc1 = Metrics("Acc@1", scale=100, format=".2f", suffix="%") 179 | for inputs, targets, spikes in dataloader: 180 | inputs, targets, spikes = ( 181 | inputs.to(device), 182 | targets.to(device), 183 | spikes.to(device), 184 | ) 185 | outputs = model(inputs) 186 | 187 | loss = criterion(outputs, targets) 188 | 189 | pred_s = (outputs.detach() + inputs.squeeze() >= 1).float() 190 | 191 | acc = (pred_s == spikes).float().mean() 192 | acc1.update(acc.item(), inputs.size(0)) 193 | 194 | abs_error.append((outputs - targets).abs()) 195 | eval_loss.update(loss.item()) 196 | 197 | abs_error = torch.cat(abs_error) 198 | std, mean = torch.std_mean(abs_error) 199 | 200 | return MetricsCheckpoint( 201 | loss=eval_loss.avg, 202 | mae_max=abs_error.max().item(), 203 | mae_mean=mean.item(), 204 | mae_std=std.item(), 205 | acc1=acc1.avg, 206 | ) 207 | 208 | 209 | logger.info("==> Training") 210 | for epoch in range(start_epoch, args.epochs): 211 | logger.info(f"==> Epoch {epoch}:") 212 | train_metrics_checkpoint = train(trainloader) 213 | logger.info(f"training: {train_metrics_checkpoint}") 214 | val_metrics_checkpoint = eval(valloader) 215 | logger.info(f"validation: {val_metrics_checkpoint}") 216 | test_metrics_checkpoint = eval(testloader) 217 | logger.info(f"test: {test_metrics_checkpoint}") 218 | 219 | scheduler.step() 220 | checkpoint = { 221 | "model": model.state_dict(), 222 | 
"optimizer": optimizer.state_dict(), 223 | "scheduler": scheduler.state_dict(), 224 | "metrics": val_metrics_checkpoint, 225 | "epoch": epoch, 226 | } 227 | torch.save(checkpoint, last_checkpoint_filename) 228 | if val_metrics_checkpoint < best_metrics_checkpoint: 229 | best_metrics_checkpoint = val_metrics_checkpoint 230 | checkpoint["test_metrics"] = test_metrics_checkpoint 231 | checkpoint["training_metrics"] = test_metrics_checkpoint 232 | torch.save(checkpoint, best_checkpoint_filename) 233 | 234 | best_checkpoint = torch.load(best_checkpoint_filename) 235 | logger.info("=================================") 236 | logger.info("Best Performance:") 237 | logger.info(f"Training: {best_checkpoint['training_metrics']}") 238 | logger.info(f"Validation: {best_checkpoint['metrics']}") 239 | logger.info(f"Test: {best_checkpoint['test_metrics']}") 240 | logger.info("=================================") 241 | logger.info("==> Finished.") 242 | -------------------------------------------------------------------------------- /convert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from torch.nn.utils import fuse_conv_bn_eval\n", 11 | "\n", 12 | "from model import SDN, FusedSDN" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "#### 0. Architecture hyper-parametersm, checkpoint and helper function" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "d_model = 8\n", 29 | "kernel_size = 8\n", 30 | "n_layers = 1\n", 31 | "save_file = \"sdn.pkl\"" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "model_checkpoint = \"exp1/best_checkpoint.pth\" # specific the checkpoint you want to convert\n", 41 | "checkpoint = torch.load(model_checkpoint)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 4, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "def copy_parameter(target, source):\n", 51 | " target.data.copy_(source)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "#### 1. Build models." 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "##### a. Build SDN models" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 5, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "text/plain": [ 76 | "" 77 | ] 78 | }, 79 | "execution_count": 5, 80 | "metadata": {}, 81 | "output_type": "execute_result" 82 | } 83 | ], 84 | "source": [ 85 | "model = SDN(d_model, kernel_size, n_layers)\n", 86 | "model.load_state_dict(checkpoint[\"model\"])" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "##### b. Build fused SDN model." 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 6, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "fused_model = FusedSDN(d_model, kernel_size, n_layers)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "#### 2. 
Conversion" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "Convert the `dtype` of models to `double` in order to avoid precision errors and switch them to eval mode" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 7, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "data": { 126 | "text/plain": [ 127 | "FusedSDN(\n", 128 | " (encoder): Conv1d(1, 8, kernel_size=(1,), stride=(1,), padding=(1,), bias=False)\n", 129 | " (spatial_layers): ModuleList(\n", 130 | " (0): Conv1d(1, 8, kernel_size=(8,), stride=(1,), padding=(8,))\n", 131 | " )\n", 132 | " (feature_layers): ModuleList(\n", 133 | " (0): Conv1d(8, 8, kernel_size=(1,), stride=(1,))\n", 134 | " )\n", 135 | " (decoder): Conv1d(8, 1, kernel_size=(1,), stride=(1,))\n", 136 | ")" 137 | ] 138 | }, 139 | "execution_count": 7, 140 | "metadata": {}, 141 | "output_type": "execute_result" 142 | } 143 | ], 144 | "source": [ 145 | "model.double().eval()\n", 146 | "fused_model.double().eval()" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "#### 3. Fusion" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "Fuse the encoder into the first spatial layer." 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 8, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "copy_parameter(\n", 170 | " fused_model.spatial_layers[0].weight,\n", 171 | " model.encoder.weight * model.spatial_layers[0][0].weight,\n", 172 | ")\n", 173 | "copy_parameter(fused_model.spatial_layers[0].bias, model.spatial_layers[0][0].bias)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "Fuse all bn layers into its previous `conv1d` layer" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 9, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "fused_model.spatial_layers[0] = fuse_conv_bn_eval(\n", 190 | " fused_model.spatial_layers[0], model.spatial_layers[0][1]\n", 191 | ")\n", 192 | "\n", 193 | "\n", 194 | "for i in range(1, n_layers):\n", 195 | " fused_model.spatial_layers[i] = fuse_conv_bn_eval(\n", 196 | " model.spatial_layers[i][0], model.spatial_layers[i][1]\n", 197 | " )\n", 198 | "\n", 199 | "for i in range(n_layers):\n", 200 | " fused_model.feature_layers[i] = fuse_conv_bn_eval(\n", 201 | " model.feature_layers[i][0], model.feature_layers[i][1]\n", 202 | " )" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "Copy the decoder" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 10, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "fused_model.decoder = model.decoder" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "#### 4. 
Test" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 11, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "x = torch.randn(64, 1, 1024).double()" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 12, 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "data": { 244 | "text/plain": [ 245 | "True" 246 | ] 247 | }, 248 | "execution_count": 12, 249 | "metadata": {}, 250 | "output_type": "execute_result" 251 | } 252 | ], 253 | "source": [ 254 | "torch.allclose(model(x), fused_model(x))" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "#### 5. Save" 262 | ] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "metadata": {}, 267 | "source": [ 268 | "Here we use `torch.jit.trace` to turn the fused model into a `TorchScript`.\n", 269 | "\n", 270 | "Note that we convert the `dtype` of model and input to `float`." 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 13, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "fused_model.float() # in-place operation\n", 280 | "x = x.float()\n", 281 | "with torch.no_grad():\n", 282 | " traced_fused_model = torch.jit.trace(fused_model, x)" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 14, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "traced_fused_model.save(save_file)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 15, 297 | "metadata": {}, 298 | "outputs": [ 299 | { 300 | "data": { 301 | "text/plain": [ 302 | "True" 303 | ] 304 | }, 305 | "execution_count": 15, 306 | "metadata": {}, 307 | "output_type": "execute_result" 308 | } 309 | ], 310 | "source": [ 311 | "reload_model = torch.jit.load(save_file)\n", 312 | "torch.allclose(reload_model(x), fused_model(x))" 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": {}, 318 | "source": [ 319 | "Now, we can use the fused model without source code in our training of SNNs.\n", 320 | "\n", 321 | "Here is an example:" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 16, 327 | "metadata": {}, 328 | "outputs": [ 329 | { 330 | "data": { 331 | "text/plain": [ 332 | "tensor([[0., 0., 0., ..., 0., 1., 0.],\n", 333 | " [0., 0., 0., ..., 0., 0., 0.],\n", 334 | " [1., 0., 0., ..., 0., 0., 0.],\n", 335 | " ...,\n", 336 | " [1., 0., 0., ..., 1., 0., 0.],\n", 337 | " [1., 0., 0., ..., 1., 0., 1.],\n", 338 | " [0., 1., 0., ..., 0., 1., 0.]])" 339 | ] 340 | }, 341 | "execution_count": 16, 342 | "metadata": {}, 343 | "output_type": "execute_result" 344 | } 345 | ], 346 | "source": [ 347 | "from torch import nn\n", 348 | "\n", 349 | "\n", 350 | "class StraightThroughEstimator(torch.autograd.Function):\n", 351 | " @staticmethod\n", 352 | " def forward(ctx, x):\n", 353 | " return (x >= 0).to(x)\n", 354 | "\n", 355 | " @staticmethod\n", 356 | " def backward(ctx, grad_out):\n", 357 | " return grad_out\n", 358 | "\n", 359 | "\n", 360 | "class SDNLIF(nn.Module):\n", 361 | " model_path = save_file\n", 362 | "\n", 363 | " def __init__(self, surrogate_func):\n", 364 | " super().__init__()\n", 365 | " self.model = torch.jit.load(self.model_path).eval()\n", 366 | " self.surrogate_func = surrogate_func\n", 367 | "\n", 368 | " def forward(self, x: torch.Tensor):\n", 369 | " m = self.pred(x)\n", 370 | " s = self.surrogate_func(m + x - 1.0)\n", 371 | " return s\n", 372 | "\n", 373 | " @torch.no_grad()\n", 374 | " def pred(self, x):\n", 375 | " 
shape = x.shape\n", 376 | " L = x.size(-1)\n", 377 | " return self.model(x.detach().view(-1, 1, L)).view(shape)\n", 378 | "\n", 379 | "\n", 380 | "test_model = SDNLIF(StraightThroughEstimator.apply)\n", 381 | "test_model(torch.randn(10, 1024))" 382 | ] 383 | } 384 | ], 385 | "metadata": { 386 | "kernelspec": { 387 | "display_name": "torch2.2", 388 | "language": "python", 389 | "name": "python3" 390 | }, 391 | "language_info": { 392 | "codemirror_mode": { 393 | "name": "ipython", 394 | "version": 3 395 | }, 396 | "file_extension": ".py", 397 | "mimetype": "text/x-python", 398 | "name": "python", 399 | "nbconvert_exporter": "python", 400 | "pygments_lexer": "ipython3", 401 | "version": "3.12.2" 402 | } 403 | }, 404 | "nbformat": 4, 405 | "nbformat_minor": 2 406 | } 407 | --------------------------------------------------------------------------------