├── scripts
│   ├── .train_perceive_load.sh.swp
│   ├── eval_plan.sh
│   ├── train_perceive.sh
│   ├── train_plan.sh
│   ├── train_prediction.sh
│   └── train_perceive_load.sh
├── maps
│   ├── Town01.h5
│   ├── Town02.h5
│   ├── Town03.h5
│   ├── Town04.h5
│   ├── Town05.h5
│   ├── Town06.h5
│   ├── Town07.h5
│   ├── Town10HD.h5
│   └── hdmap_generate.py
├── __assets__
│   ├── logo.png
│   └── overview.png
├── sad
│   ├── models
│   │   ├── module
│   │   │   ├── __init__.py
│   │   │   ├── .ipynb_checkpoints
│   │   │   │   ├── __init__-checkpoint.py
│   │   │   │   ├── downsample-checkpoint.py
│   │   │   │   ├── sps-checkpoint.py
│   │   │   │   ├── MS_ResNet-checkpoint.py
│   │   │   │   └── ms_conv-checkpoint.py
│   │   │   ├── downsample.py
│   │   │   ├── sps.py
│   │   │   ├── ms_conv.py
│   │   │   ├── MS_ResNet.py
│   │   │   └── Real_MS_ResNet.py
│   │   ├── temporal_model.py
│   │   ├── distributions_snn.py
│   │   ├── future_prediction_snn.py
│   │   ├── encoder_snn_merge.py
│   │   ├── planning_model.py
│   │   └── decoder_snn_ms.py
│   ├── configs
│   │   └── nuscenes
│   │       ├── Perception.yml
│   │       ├── Prediction.yml
│   │       └── Planning.yml
│   ├── utils
│   │   ├── network.py
│   │   ├── .ipynb_checkpoints
│   │   │   └── network-checkpoint.py
│   │   ├── sampler.py
│   │   ├── geometry.py
│   │   └── instance.py
│   ├── datas
│   │   └── dataloaders.py
│   ├── config.py
│   └── losses.py
├── environment.yml
├── train.py
├── README.md
├── evaluate.py
└── LICENSE

/scripts/.train_perceive_load.sh.swp:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/maps/Town01.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town01.h5
--------------------------------------------------------------------------------
/maps/Town02.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town02.h5
--------------------------------------------------------------------------------
/maps/Town03.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town03.h5
--------------------------------------------------------------------------------
/maps/Town04.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town04.h5
--------------------------------------------------------------------------------
/maps/Town05.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town05.h5
--------------------------------------------------------------------------------
/maps/Town06.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town06.h5
--------------------------------------------------------------------------------
/maps/Town07.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town07.h5
--------------------------------------------------------------------------------
/maps/Town10HD.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town10HD.h5
--------------------------------------------------------------------------------
/__assets__/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/__assets__/logo.png
--------------------------------------------------------------------------------
/__assets__/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/__assets__/overview.png
--------------------------------------------------------------------------------
/scripts/eval_plan.sh:
--------------------------------------------------------------------------------
 1 | #! /bin/bash
 2 | echo "checkpoint: $1"
 3 | echo "dataroot: $2"
 4 | python evaluate.py --checkpoint $1 --dataroot $2
--------------------------------------------------------------------------------
/scripts/train_perceive.sh:
--------------------------------------------------------------------------------
 1 | #! /bin/bash
 2 | echo "configs: $1"
 3 | echo "DATASET.DATAROOT: $2"
 4 | python train.py --config $1 DATASET.DATAROOT $2 DATASET.MAP_FOLDER $2
--------------------------------------------------------------------------------
/scripts/train_plan.sh:
--------------------------------------------------------------------------------
 1 | #! /bin/bash
 2 | echo "configs: $1"
 3 | echo "DATASET.DATAROOT: $2"
 4 | echo "PRETRAINED.PATH: $3"
 5 | python train.py --config $1 DATASET.DATAROOT $2 DATASET.MAP_FOLDER $2 PRETRAINED.PATH $3
--------------------------------------------------------------------------------
/scripts/train_prediction.sh:
--------------------------------------------------------------------------------
 1 | #! /bin/bash
 2 | echo "configs: $1"
 3 | echo "DATASET.DATAROOT: $2"
 4 | echo "PRETRAINED.PATH: $3"
 5 | python train.py --config $1 DATASET.DATAROOT $2 DATASET.MAP_FOLDER $2 PRETRAINED.PATH $3
--------------------------------------------------------------------------------
/scripts/train_perceive_load.sh:
--------------------------------------------------------------------------------
 1 | #!
/bin/bash 2 | echo "configs: $1" 3 | echo "DATASET.DATAROOT: $2" 4 | echo "PRETRAINED.PATH: $3" 5 | python train.py --config $1 DATASET.DATAROOT $2 DATASET.MAP_FOLDER $2 PRETRAINED.PATH $3 -------------------------------------------------------------------------------- /sad/models/module/__init__.py: -------------------------------------------------------------------------------- 1 | from .ms_conv import MS_Block_Conv 2 | from .sps import MS_SPS 3 | from .downsample import Sps_downsampling 4 | 5 | 6 | __all__ = [ 7 | "MS_SPS", 8 | "MS_Block_Conv", 9 | "Sps_downsampling" 10 | ] 11 | -------------------------------------------------------------------------------- /sad/models/module/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | from .ms_conv import MS_Block_Conv 2 | from .sps import MS_SPS 3 | from .downsample import Sps_downsampling 4 | 5 | 6 | __all__ = [ 7 | "MS_SPS", 8 | "MS_Block_Conv", 9 | "Sps_downsampling" 10 | ] 11 | -------------------------------------------------------------------------------- /sad/models/temporal_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from sad.layers.temporal import Bottleneck3D, TemporalBlock 5 | from sad.layers.convolutions import ConvBlock, Bottleneck, SpikingDeepLabHead 6 | 7 | 8 | class TemporalModelIdentity(nn.Module): 9 | def __init__(self, in_channels, receptive_field): 10 | super().__init__() 11 | self.receptive_field = receptive_field 12 | self.out_channels = in_channels 13 | 14 | def forward(self, x): 15 | return x -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: sad 2 | channels: 3 | - defaults 4 | - conda-forge 5 | - pytorch 6 | dependencies: 7 | - python=3.7.10 8 | - pytorch=1.10.2 9 | - torchvision=0.11.3 10 | - cudatoolkit=11.3 11 | - numpy=1.19.2 12 | - scipy=1.5.2 13 | - pillow=8.0.1 14 | - tqdm=4.50.2 15 | - scikit-image=0.18.1 16 | - pytorch-lightning=1.2.5 17 | - efficientnet-pytorch=0.7.0 18 | - fvcore=0.1.2.post20201122 19 | - pip=21.0.1 20 | - timm 21 | - pip: 22 | - nuscenes-devkit==1.1.0 23 | - lyft-dataset-sdk==0.0.8 24 | - opencv-python==4.5.1.48 25 | - moviepy==1.0.3 26 | - shapely==1.8.0 27 | - spikingjelly==0.0.0.0.12 28 | - timm==0.6.12 29 | - torchinfo 30 | - wandb 31 | -------------------------------------------------------------------------------- /sad/configs/nuscenes/Perception.yml: -------------------------------------------------------------------------------- 1 | TAG: 'Perception' 2 | 3 | GPUS: [0] 4 | 5 | BATCHSIZE: 3 6 | PRECISION: 16 7 | EPOCHS: 20 8 | 9 | N_WORKERS: 3 10 | 11 | DATASET: 12 | VERSION: 'trainval' 13 | 14 | TIME_RECEPTIVE_FIELD: 3 15 | N_FUTURE_FRAMES: 0 16 | 17 | LIFT: 18 | GT_DEPTH: False 19 | 20 | MODEL: 21 | ENCODER: 22 | NAME: "/vol4/Coin-MLP-back/Pure-MLP-SNN/18M-with-downsample-conv/checkpoint-308.pth.tar" 23 | TEMPORAL_MODEL: 24 | NAME: 'identity' 25 | INPUT_EGOPOSE: True 26 | BN_MOMENTUM: 0.05 27 | 28 | SEMANTIC_SEG: 29 | PEDESTRIAN: 30 | ENABLED: True 31 | HDMAP: 32 | ENABLED: True 33 | 34 | INSTANCE_SEG: 35 | ENABLED: False 36 | 37 | INSTANCE_FLOW: 38 | ENABLED: False 39 | 40 | PROBABILISTIC: 41 | ENABLED: False 42 | 43 | PLANNING: 44 | ENABLED: False 45 | 46 | OPTIMIZER: 47 | LR: 1e-3 48 | 49 | 50 | 
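# NOTE: MODEL.ENCODER.NAME above is a machine-specific checkpoint path. Per the
# README, replace it with the path of the pretrained encoder you downloaded
# (e.g. the model.pth.tar provided on the project's Hugging Face page) before
# running scripts/train_perceive.sh.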
-------------------------------------------------------------------------------- /sad/configs/nuscenes/Prediction.yml: -------------------------------------------------------------------------------- 1 | TAG: 'Prediction' 2 | 3 | GPUS: [0, 1, 2, 3] 4 | 5 | BATCHSIZE: 3 6 | PRECISION: 16 7 | EPOCHS: 20 8 | 9 | N_WORKERS: 3 10 | 11 | DATASET: 12 | VERSION: 'trainval' 13 | 14 | TIME_RECEPTIVE_FIELD: 3 15 | N_FUTURE_FRAMES: 4 16 | 17 | LIFT: 18 | GT_DEPTH: False 19 | 20 | MODEL: 21 | ENCODER: 22 | NAME: "/vol4/Coin-MLP-back/Pure-MLP-SNN/18M-with-downsample-conv/checkpoint-308.pth.tar" 23 | TEMPORAL_MODEL: 24 | NAME: 'identity' 25 | INPUT_EGOPOSE: True 26 | BN_MOMENTUM: 0.05 27 | 28 | SEMANTIC_SEG: 29 | PEDESTRIAN: 30 | ENABLED: False 31 | HDMAP: 32 | ENABLED: False 33 | 34 | INSTANCE_FLOW: 35 | ENABLED: True 36 | 37 | PROBABILISTIC: 38 | ENABLED: True 39 | METHOD: 'GAUSSIAN' 40 | 41 | PLANNING: 42 | ENABLED: False 43 | 44 | FUTURE_DISCOUNT: 0.95 45 | 46 | OPTIMIZER: 47 | LR: 2e-4 48 | 49 | 50 | PRETRAINED: 51 | LOAD_WEIGHTS: True 52 | -------------------------------------------------------------------------------- /sad/configs/nuscenes/Planning.yml: -------------------------------------------------------------------------------- 1 | TAG: 'Planning' 2 | 3 | GPUS: [0] 4 | 5 | BATCHSIZE: 1 6 | PRECISION: 16 7 | EPOCHS: 20 8 | 9 | N_WORKERS: 4 10 | 11 | DATASET: 12 | VERSION: 'trainval' 13 | 14 | TIME_RECEPTIVE_FIELD: 3 15 | N_FUTURE_FRAMES: 6 16 | 17 | LIFT: 18 | GT_DEPTH: False 19 | 20 | MODEL: 21 | ENCODER: 22 | NAME: "/vol4/Coin-MLP-back/Pure-MLP-SNN/18M-with-downsample-conv/checkpoint-308.pth.tar" 23 | TEMPORAL_MODEL: 24 | NAME: 'identity' 25 | INPUT_EGOPOSE: True 26 | BN_MOMENTUM: 0.05 27 | 28 | SEMANTIC_SEG: 29 | PEDESTRIAN: 30 | ENABLED: True 31 | HDMAP: 32 | ENABLED: True 33 | 34 | INSTANCE_SEG: 35 | ENABLED: False 36 | 37 | INSTANCE_FLOW: 38 | ENABLED: False 39 | 40 | PROBABILISTIC: 41 | ENABLED: True 42 | METHOD: 'GAUSSIAN' 43 | 44 | PLANNING: 45 | ENABLED: True 46 | SAMPLE_NUM: 1800 47 | 48 | FUTURE_DISCOUNT: 0.95 49 | 50 | OPTIMIZER: 51 | LR: 1e-4 52 | 53 | COST_FUNCTION: 54 | SAFETY: 1. 55 | HEADWAY: 1. 56 | LRDIVIDER: 10. 57 | COMFORT: 0.1 58 | PROGRESS: 0.5 59 | VOLUME: 100. 
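# Relative weights for the planner's trajectory-scoring cost terms; the full
# set of keys and their defaults is defined under COST_FUNCTION in sad/config.py.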
60 | 61 | PRETRAINED: 62 | LOAD_WEIGHTS: True 63 | -------------------------------------------------------------------------------- /sad/utils/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | def pack_sequence_dim(x): 6 | b, s = x.shape[:2] 7 | return x.view(b * s, *x.shape[2:]) 8 | 9 | 10 | def unpack_sequence_dim(x, b, s): 11 | return x.view(b, s, *x.shape[1:]) 12 | 13 | 14 | def preprocess_batch(batch, device, unsqueeze=False): 15 | for key, value in batch.items(): 16 | if torch.is_tensor(value): 17 | batch[key] = value.to(device) 18 | if unsqueeze: 19 | batch[key] = batch[key].unsqueeze(0) 20 | 21 | 22 | def set_module_grad(module, requires_grad=False): 23 | for p in module.parameters(): 24 | p.requires_grad = requires_grad 25 | 26 | 27 | def set_bn_momentum(model, momentum=0.1): 28 | for m in model.modules(): 29 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 30 | m.momentum = momentum 31 | 32 | 33 | class NormalizeInverse(torchvision.transforms.Normalize): 34 | # https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/8 35 | def __init__(self, mean, std): 36 | mean = torch.as_tensor(mean) 37 | std = torch.as_tensor(std) 38 | std_inv = 1 / (std + 1e-7) 39 | mean_inv = -mean * std_inv 40 | super().__init__(mean=mean_inv, std=std_inv) 41 | 42 | def __call__(self, tensor): 43 | return super().__call__(tensor.clone()) 44 | -------------------------------------------------------------------------------- /sad/utils/.ipynb_checkpoints/network-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | def pack_sequence_dim(x): 6 | b, s = x.shape[:2] 7 | return x.view(b * s, *x.shape[2:]) 8 | 9 | 10 | def unpack_sequence_dim(x, b, s): 11 | return x.view(b, s, *x.shape[1:]) 12 | 13 | 14 | def preprocess_batch(batch, device, unsqueeze=False): 15 | for key, value in batch.items(): 16 | if torch.is_tensor(value): 17 | batch[key] = value.to(device) 18 | if unsqueeze: 19 | batch[key] = batch[key].unsqueeze(0) 20 | 21 | 22 | def set_module_grad(module, requires_grad=False): 23 | for p in module.parameters(): 24 | p.requires_grad = requires_grad 25 | 26 | 27 | def set_bn_momentum(model, momentum=0.1): 28 | for m in model.modules(): 29 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 30 | m.momentum = momentum 31 | 32 | 33 | class NormalizeInverse(torchvision.transforms.Normalize): 34 | # https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/8 35 | def __init__(self, mean, std): 36 | mean = torch.as_tensor(mean) 37 | std = torch.as_tensor(std) 38 | std_inv = 1 / (std + 1e-7) 39 | mean_inv = -mean * std_inv 40 | super().__init__(mean=mean_inv, std=std_inv) 41 | 42 | def __call__(self, tensor): 43 | return super().__call__(tensor.clone()) 44 | -------------------------------------------------------------------------------- /sad/datas/dataloaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from nuscenes.nuscenes import NuScenes 4 | from sad.datas.NuscenesData import FuturePredictionDataset 5 | from sad.datas.CarlaData import CarlaDataset 6 | 7 | 8 | def prepare_dataloaders(cfg, return_dataset=False): 9 | if cfg.DATASET.NAME == 'nuscenes': 10 | # 28130 train and 6019 val 11 | dataroot = 
cfg.DATASET.DATAROOT 12 | nusc = NuScenes(version='v1.0-{}'.format(cfg.DATASET.VERSION), dataroot=dataroot, verbose=False) 13 | traindata = FuturePredictionDataset(nusc, 0, cfg) 14 | valdata = FuturePredictionDataset(nusc, 1, cfg) 15 | 16 | if cfg.DATASET.VERSION == 'mini': 17 | traindata.indices = traindata.indices[:10] 18 | # valdata.indices = valdata.indices[:10] 19 | 20 | nworkers = cfg.N_WORKERS 21 | trainloader = torch.utils.data.DataLoader( 22 | traindata, batch_size=cfg.BATCHSIZE, shuffle=True, num_workers=nworkers, pin_memory=True, drop_last=True 23 | ) 24 | valloader = torch.utils.data.DataLoader( 25 | valdata, batch_size=cfg.BATCHSIZE, shuffle=False, num_workers=nworkers, pin_memory=True, drop_last=False) 26 | elif cfg.DATASET.NAME == 'carla': 27 | dataroot = cfg.DATASET.DATAROOT 28 | traindata = CarlaDataset(dataroot, True, cfg) 29 | valdata = CarlaDataset(dataroot, False, cfg) 30 | nworkers = cfg.N_WORKERS 31 | trainloader = torch.utils.data.DataLoader( 32 | traindata, batch_size=cfg.BATCHSIZE, shuffle=True, num_workers=nworkers, pin_memory=True, drop_last=True 33 | ) 34 | valloader = torch.utils.data.DataLoader( 35 | valdata, batch_size=cfg.BATCHSIZE, shuffle=False, num_workers=nworkers, pin_memory=True, drop_last=False) 36 | else: 37 | raise NotImplementedError 38 | 39 | if return_dataset: 40 | return trainloader, valloader, traindata, valdata 41 | else: 42 | return trainloader, valloader -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import socket 4 | import torch 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.plugins import DDPPlugin 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger 9 | import wandb 10 | 11 | from sad.config import get_parser, get_cfg 12 | from sad.datas.dataloaders import prepare_dataloaders 13 | from sad.trainer import TrainingModule 14 | 15 | 16 | 17 | def main(): 18 | args = get_parser().parse_args() 19 | cfg = get_cfg(args) 20 | # print(cfg) 21 | trainloader, valloader = prepare_dataloaders(cfg) 22 | print("load data!!!") 23 | model = TrainingModule(cfg.convert_to_dict()) 24 | 25 | if cfg.PRETRAINED.LOAD_WEIGHTS: 26 | # Load single-image instance segmentation model. 
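        # Keep only checkpoint keys that exist in the current model and do not
        # belong to the decoder, then load them non-strictly, so a checkpoint
        # from an earlier stage (e.g. perception) can initialise later
        # prediction/planning training without every parameter matching.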
27 | pretrained_model_weights = torch.load( 28 | cfg.PRETRAINED.PATH, map_location='cpu' 29 | )['state_dict'] 30 | state = model.state_dict() 31 | pretrained_model_weights = {k: v for k, v in pretrained_model_weights.items() if k in state and 'decoder' not in k} 32 | model.load_state_dict(pretrained_model_weights, strict=False) 33 | print(f'Loaded single-image model weights from {cfg.PRETRAINED.PATH}') 34 | 35 | save_dir = os.path.join( 36 | cfg.LOG_DIR, time.strftime('%d%B%Yat%H_%M_%S%Z') + '_' + socket.gethostname() + '_' + cfg.TAG 37 | ) 38 | # tb_logger = pl.loggers.TensorBoardLogger(save_dir=save_dir) 39 | wandb_logger = WandbLogger(name='sad-preception', project='sad', entity="ncg_ucsc", config=cfg, log_model="all") 40 | 41 | # if os.environ.get('LOCAL_RANK', '0') == '0': 42 | # wandb.init(name='sad-preception', project='sad', entity="snn_ad", config=cfg) 43 | 44 | checkpoint_callback = ModelCheckpoint( 45 | monitor='step_val_seg_iou_dynamic', 46 | save_top_k=-1, 47 | save_last=True, 48 | period=1, 49 | mode='min' 50 | ) 51 | trainer = pl.Trainer( 52 | gpus=cfg.GPUS, 53 | accelerator='ddp', 54 | precision=cfg.PRECISION, 55 | sync_batchnorm=True, 56 | gradient_clip_val=cfg.GRAD_NORM_CLIP, 57 | max_epochs=cfg.EPOCHS, 58 | weights_summary='full', 59 | logger=wandb_logger, 60 | log_every_n_steps=cfg.LOGGING_INTERVAL, 61 | plugins=DDPPlugin(find_unused_parameters=False), 62 | profiler='simple', 63 | callbacks=[checkpoint_callback] 64 | ) 65 | 66 | trainer.fit(model, trainloader, valloader) 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /sad/models/distributions_snn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from sad.layers.convolutions import SpikingBottleneck 5 | 6 | 7 | class DistributionModule(nn.Module): 8 | """ 9 | A convolutional net that parametrises a diagonal Gaussian distribution. 10 | """ 11 | 12 | def __init__( 13 | self, in_channels, latent_dim, method="GAUSSIAN"): 14 | super().__init__() 15 | self.compress_dim = in_channels // 2 16 | self.latent_dim = latent_dim 17 | self.method = method 18 | 19 | if method == 'GAUSSIAN': 20 | self.encoder = DistributionEncoder(in_channels, self.compress_dim) 21 | self.decoder = nn.Sequential( 22 | nn.AdaptiveAvgPool2d(1), nn.Conv2d(self.compress_dim, out_channels=2 * self.latent_dim, kernel_size=1) 23 | ) 24 | elif method == 'MIXGAUSSIAN': 25 | self.encoder = DistributionEncoder(in_channels, self.compress_dim) 26 | self.decoder = nn.Sequential( 27 | nn.AdaptiveAvgPool2d(1), nn.Conv2d(self.compress_dim, out_channels=6 * self.latent_dim + 3, kernel_size=1) 28 | ) 29 | elif method == 'BERNOULLI': 30 | self.encoder = nn.Sequential( 31 | Bottleneck(in_channels, self.latent_dim) 32 | ) 33 | self.decoder = nn.LogSigmoid() 34 | else: 35 | raise NotImplementedError 36 | 37 | def forward(self, s_t): 38 | b, s = s_t.shape[:2] 39 | assert s == 1 40 | encoding = self.encoder(s_t[:, 0]) 41 | 42 | if self.method == 'GAUSSIAN': 43 | decoder = self.decoder(encoding).view(b, 1, 2 * self.latent_dim) 44 | elif self.method == 'MIXGAUSSIAN': 45 | decoder = self.decoder(encoding).view(b, 1, 6 * self.latent_dim + 3) 46 | elif self.method == 'BERNOULLI': 47 | decoder = self.decoder(encoding) 48 | else: 49 | raise NotImplementedError 50 | 51 | return decoder 52 | 53 | 54 | class DistributionEncoder(nn.Module): 55 | """Encodes s_t or (s_t, y_{t+1}, ..., y_{t+H}). 
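    Four strided SpikingBottleneck blocks progressively downsample the input
    before DistributionModule's pooling head predicts the distribution parameters.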
56 | """ 57 | def __init__(self, in_channels, out_channels): 58 | super().__init__() 59 | 60 | self.model = nn.Sequential( 61 | SpikingBottleneck(in_channels, out_channels=out_channels, downsample=True), 62 | SpikingBottleneck(out_channels, out_channels=out_channels, downsample=True), 63 | SpikingBottleneck(out_channels, out_channels=out_channels, downsample=True), 64 | SpikingBottleneck(out_channels, out_channels=out_channels, downsample=True), 65 | ) 66 | 67 | def forward(self, s_t): 68 | return self.model(s_t) 69 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | 5 |
6 | 7 |

[NeurIPS 2024] Spiking Autonomous Driving (SAD): End-to-End Autonomous Driving with Spiking Neural Networks

8 | 9 |
If you find our project useful, please give us a star ⭐ on GitHub!
10 | 11 | ## Introduction 12 | 13 | Spiking Autonomous Driving (SAD) is the first end-to-end autonomous driving system built entirely with Spiking Neural Networks (SNNs). It integrates perception, prediction, and planning modules into a unified neuromorphic framework. 14 | 15 | ### Key Features 16 | 17 | - End-to-end SNN architecture for autonomous driving, integrating perception, prediction, and planning 18 | - Perception module constructs spatio-temporal Bird's Eye View (BEV) representation from multi-view cameras using SNNs 19 | - Prediction module forecasts future states using a novel dual-pathway SNN 20 | - Planning module generates safe trajectories considering occupancy prediction, traffic rules, and ride comfort 21 | 22 | ## System Overview 23 | 24 | 25 | 26 | 27 | ## Modules 28 | 29 | ### Perception 30 | 31 | The perception module constructs a spatio-temporal BEV representation from multi-camera inputs. The encoder uses sequence repetition, while the decoder employs sequence alignment. 32 | 33 | ### Prediction 34 | 35 | The prediction module utilizes a dual-pathway SNN, where one pathway encodes past information and the other predicts future distributions. The outputs from both pathways are fused. 36 | 37 | ### Planning 38 | 39 | The planning module optimizes trajectories using Spiking Gated Recurrent Units (SGRUs), taking into account static occupancy, future predictions, comfort, and other factors. 40 | 41 | ## Get Started 42 | 43 | ### Setup 44 | 45 | ``` 46 | conda env create -f environment.yml 47 | ``` 48 | 49 | ### Training 50 | 51 | First, go to `/sad/configs` and modify the configs. Change the NAME in MODEL/ENCODER to the model we provided. The link is as follows: https://huggingface.co/ridger/MLP-SNN/blob/main/model.pth.tar 52 | 53 | ``` 54 | # Perception module pretraining 55 | bash scripts/train_perceive.sh ${configs} ${dataroot} 56 | 57 | # Prediction module pretraining 58 | bash scripts/train_prediction.sh ${configs} ${dataroot} ${pretrained} 59 | 60 | # Entire model end-to-end training 61 | bash scripts/train_plan.sh ${configs} ${dataroot} ${pretrained} 62 | ``` 63 | ## Citation 64 | 65 | If you find SAD useful in your work, please cite the following source: 66 | 67 | ``` 68 | @inproceedings{ 69 | zhu2024autonomous, 70 | title={Autonomous Driving with Spiking Neural Networks}, 71 | author={Rui-Jie Zhu and Ziqing Wang and Leilani H. 
Gilpin and Jason Eshraghian}, 72 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, 73 | year={2024}, 74 | url={https://openreview.net/forum?id=95VyH4VxN9} 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /sad/models/future_prediction_snn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | 5 | from sad.layers.convolutions import SpikingDeepLabHead 6 | from sad.layers.temporal_snn import SpatialGRU, Dual_LIF_temporal_mixer, BiGRU 7 | 8 | 9 | from sad.models.module.MS_ResNet import multi_step_sew_resnet18 10 | from spikingjelly.clock_driven.neuron import ( 11 | MultiStepParametricLIFNode, 12 | MultiStepLIFNode, 13 | ) 14 | 15 | 16 | class FuturePrediction(nn.Module): 17 | def __init__(self, in_channels, latent_dim, n_future, mixture=True): 18 | super(FuturePrediction, self).__init__() 19 | # self.n_spatial_gru = n_gru_blocks 20 | 21 | backbone = multi_step_sew_resnet18(pretrained=False, multi_step_neuron=MultiStepLIFNode) 22 | 23 | gru_in_channels = latent_dim 24 | self.layer1 = backbone.layer1 25 | self.layer1[0].conv1 = torch.nn.Conv2d(in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 26 | bias=False) 27 | self.layer1[0].bn1 = torch.nn.BatchNorm2d(in_channels) 28 | self.layer1[0].conv2 = torch.nn.Conv2d(in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 29 | bias=False) 30 | self.layer1[0].bn2 = torch.nn.BatchNorm2d(in_channels) 31 | 32 | self.layer1[1].conv1 = torch.nn.Conv2d(in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 33 | bias=False) 34 | self.layer1[1].bn1 = torch.nn.BatchNorm2d(in_channels) 35 | self.layer1[1].conv2 = torch.nn.Conv2d(in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 36 | bias=False) 37 | self.layer1[1].bn2 = torch.nn.BatchNorm2d(in_channels) 38 | 39 | # self.layer2 = backbone.layer2 40 | # self.layer3 = backbone.layer3 41 | # self.layer3[1].conv2 = torch.nn.Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 42 | # self.layer3[1].bn2 = torch.nn.BatchNorm2d(64) 43 | 44 | self.layer2 = copy.deepcopy(backbone.layer1) 45 | self.layer2[0].conv1 = torch.nn.Conv2d(in_channels, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 46 | self.layer2[0].bn1 = torch.nn.BatchNorm2d(384) 47 | self.layer2[0].conv2 = torch.nn.Conv2d(384, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 48 | self.layer2[0].bn2 = torch.nn.BatchNorm2d(in_channels) 49 | 50 | self.layer2[1].conv1 = torch.nn.Conv2d(in_channels, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 51 | self.layer2[1].bn1 = torch.nn.BatchNorm2d(384) 52 | self.layer2[1].conv2 = torch.nn.Conv2d(384, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 53 | self.layer2[1].bn2 = torch.nn.BatchNorm2d(in_channels) 54 | 55 | self.layer3 = copy.deepcopy(self.layer2) 56 | 57 | # 修改第1个MultiStepBasicBlock中的BatchNorm2d层的特征数 58 | 59 | 60 | 61 | self.dual_lif = Dual_LIF_temporal_mixer(gru_in_channels, in_channels, n_future=n_future, mixture=mixture) 62 | # self.res_blocks1 = nn.Sequential(*[Block(in_channels) for _ in range(n_res_layers)]) 63 | # 64 | # self.spatial_grus = [] 65 | # self.res_blocks = [] 66 | # for i in range(self.n_spatial_gru): 67 | # self.spatial_grus.append(SpatialGRU(in_channels, in_channels)) 68 | # if i < 
self.n_spatial_gru - 1: 69 | # self.res_blocks.append(nn.Sequential(*[Block(in_channels) for _ in range(n_res_layers)])) 70 | # else: 71 | # self.res_blocks.append(DeepLabHead(in_channels, in_channels, 128)) 72 | # 73 | # self.spatial_grus = torch.nn.ModuleList(self.spatial_grus) 74 | # self.res_blocks = torch.nn.ModuleList(self.res_blocks) 75 | 76 | def forward(self, x, state): 77 | # x has shape (b, 1, c, h, w), state: torch.Tensor [b, n_present, hidden_size, h, w] 78 | x = self.dual_lif(x, state) 79 | 80 | 81 | # b, n_future, c, h, w = x.shape # 预测未来情况,这时候就已经有未来的feature了 82 | # x = self.res_blocks1(x.view(b * n_future, c, h, w)) # 过ResNet Block,此时没有未来feature 83 | # x = x.view(b, n_future, c, h, w) 84 | 85 | x = torch.cat([state, x], dim=1) # 回正后把未来feature和当前的状况做一个融合 86 | b, s, c, h, w = x.shape 87 | x = x.reshape(s, b, c, h, w) 88 | x = self.layer1(x) 89 | x = self.layer2(x) 90 | x = self.layer3(x) 91 | x = x.reshape(b, s, c, h, w) 92 | 93 | # hidden_state = x[:, 0] 94 | # for i in range(self.n_spatial_gru): 95 | # x = self.spatial_grus[i](x, hidden_state) # 使用Spatial GRU,正常计算 96 | # 97 | # b, s, c, h, w = x.shape 98 | # x = self.res_blocks[i](x.view(b * s, c, h, w)) # 过Res Blocks,特征提取,这一块可以被整合进入RWKV 99 | # x = x.view(b, s, c, h, w) 100 | 101 | return x -------------------------------------------------------------------------------- /sad/models/module/.ipynb_checkpoints/downsample-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode, MultiStepLIFNode 4 | from spikingjelly.clock_driven import layer 5 | 6 | def MS_conv_unit(in_channels, out_channels,kernel_size=1,padding=0,groups=1): 7 | return nn.Sequential( 8 | layer.SeqToANNContainer( 9 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, groups=groups,bias=True), 10 | nn.BatchNorm2d(out_channels) # 这里可以进行改进 ? 
11 | ) 12 | ) 13 | class MS_ConvBlock(nn.Module): 14 | def __init__(self, dim, 15 | mlp_ratio=4.0): 16 | super().__init__() 17 | self.neuron1 = MultiStepLIFNode(tau=2.0, detach_reset=True,backend='torch') 18 | self.conv1 = MS_conv_unit(dim, dim, 3, 1) 19 | 20 | self.neuron2 = MultiStepLIFNode(tau=2.0, detach_reset=True,backend='torch') 21 | self.conv2 = MS_conv_unit(dim, dim * mlp_ratio, 3, 1) 22 | 23 | self.neuron3 = MultiStepLIFNode(tau=2.0, detach_reset=True,backend='torch') 24 | 25 | self.conv3 = MS_conv_unit(dim*mlp_ratio, dim, 3, 1) 26 | 27 | 28 | def forward(self, x, hook=None, num=None): 29 | short_cut1 = x 30 | x = self.neuron1(x) 31 | if hook is not None: 32 | hook[self._get_name() + num + "_lif1"] = x.detach() 33 | x = self.conv1(x)+short_cut1 34 | short_cut2 = x 35 | x = self.neuron2(x) 36 | if hook is not None: 37 | hook[self._get_name() + num + "_lif2"] = x.detach() 38 | 39 | x = self.conv2(x) 40 | x = self.neuron3(x) 41 | if hook is not None: 42 | hook[self._get_name() + num + "_lif3"] = x.detach() 43 | 44 | x = self.conv3(x) 45 | x = x + short_cut2 46 | return x 47 | class MS_DownSampling(nn.Module): 48 | def __init__( 49 | self, 50 | in_channels=2, 51 | embed_dims=256, 52 | kernel_size=3, 53 | stride=2, 54 | padding=1, 55 | first_layer=True, 56 | ): 57 | super().__init__() 58 | 59 | self.encode_conv = nn.Conv2d( 60 | in_channels, 61 | embed_dims, 62 | kernel_size=kernel_size, 63 | stride=stride, 64 | padding=padding, 65 | ) 66 | 67 | self.encode_bn = nn.BatchNorm2d(embed_dims) 68 | if not first_layer: 69 | self.encode_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') 70 | 71 | def forward(self, x, hook=None, num=None): 72 | 73 | T, B, _, _, _ = x.shape 74 | 75 | if hasattr(self, "encode_lif"): 76 | x = self.encode_lif(x) 77 | if hook is not None: 78 | hook[self._get_name() + num + "_lif"] = x.detach() 79 | x = self.encode_conv(x.flatten(0, 1)) 80 | _, _, H, W = x.shape 81 | x = self.encode_bn(x).reshape(T, B, -1, H, W).contiguous() 82 | return x 83 | class Sps_downsampling(nn.Module): 84 | def __init__( 85 | self, 86 | in_channels=3, 87 | mlp_ratios=4, 88 | embed_dim=[64,128,256], 89 | pooling_stat="1111", 90 | spike_mode="lif", 91 | ): 92 | super().__init__() 93 | self.downsample1_1 = MS_DownSampling( 94 | in_channels=in_channels, 95 | embed_dims=embed_dim[0] // 2, 96 | kernel_size=7, 97 | stride=2, 98 | padding=3, 99 | first_layer=True, 100 | ) 101 | 102 | self.ConvBlock1_1 = nn.ModuleList( 103 | [MS_ConvBlock(dim=embed_dim[0] // 2, mlp_ratio=mlp_ratios)] 104 | ) 105 | 106 | self.downsample1_2 = MS_DownSampling( 107 | in_channels=embed_dim[0] // 2, 108 | embed_dims=embed_dim[0], 109 | kernel_size=3, 110 | stride=2, 111 | padding=1, 112 | first_layer=False, 113 | ) 114 | 115 | self.ConvBlock1_2 = nn.ModuleList( 116 | [MS_ConvBlock(dim=embed_dim[0], mlp_ratio=mlp_ratios)] 117 | ) 118 | 119 | self.downsample2 = MS_DownSampling( 120 | in_channels=embed_dim[0], 121 | embed_dims=embed_dim[1], 122 | kernel_size=3, 123 | stride=2, 124 | padding=1, 125 | first_layer=False, 126 | ) 127 | 128 | self.ConvBlock2_1 = nn.ModuleList( 129 | [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)] 130 | ) 131 | 132 | self.ConvBlock2_2 = nn.ModuleList( 133 | [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)] 134 | ) 135 | 136 | self.downsample3 = MS_DownSampling( 137 | in_channels=embed_dim[1], 138 | embed_dims=embed_dim[2], 139 | kernel_size=3, 140 | stride=2, 141 | padding=1, 142 | first_layer=False, 143 | ) 144 | def forward(self,x, hook=None): 145 | x = 
self.downsample1_1(x,hook,"1") 146 | for blk in self.ConvBlock1_1: 147 | x = blk(x,hook,"1_1") # 112 148 | x = self.downsample1_2(x) 149 | for blk in self.ConvBlock1_2: 150 | x = blk(x,hook,"1_2") #56 151 | 152 | 153 | x = self.downsample2(x,hook,"2") 154 | 155 | for blk in self.ConvBlock2_1: 156 | x = blk(x,hook,"2_1") #28 157 | 158 | for blk in self.ConvBlock2_2: 159 | output_1 = blk(x,hook,"2_2") #28 160 | 161 | 162 | output_2 = self.downsample3(x,hook,"3") 163 | 164 | return output_1, output_2 165 | # model=Sps_downsampling() 166 | # x=torch.randn(1,1,3,224,224) 167 | # print(model(x).shape) 168 | -------------------------------------------------------------------------------- /sad/models/module/downsample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode, MultiStepLIFNode 4 | from spikingjelly.clock_driven import layer 5 | 6 | def MS_conv_unit(in_channels, out_channels,kernel_size=1,padding=0,groups=1): 7 | return nn.Sequential( 8 | layer.SeqToANNContainer( 9 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, groups=groups,bias=True), 10 | nn.BatchNorm2d(out_channels) # 这里可以进行改进 ? 11 | ) 12 | ) 13 | class MS_ConvBlock(nn.Module): 14 | def __init__(self, dim, 15 | mlp_ratio=4.0): 16 | super().__init__() 17 | self.neuron1 = MultiStepLIFNode(tau=2.0, detach_reset=True,backend='torch') 18 | self.conv1 = MS_conv_unit(dim, dim, 3, 1) 19 | 20 | self.neuron2 = MultiStepLIFNode(tau=2.0, detach_reset=True,backend='torch') 21 | self.conv2 = MS_conv_unit(dim, dim * mlp_ratio, 3, 1) 22 | 23 | self.neuron3 = MultiStepLIFNode(tau=2.0, detach_reset=True,backend='torch') 24 | 25 | self.conv3 = MS_conv_unit(dim*mlp_ratio, dim, 3, 1) 26 | 27 | 28 | def forward(self, x, hook=None, num=None): 29 | with torch.backends.cudnn.flags(enabled=False): 30 | short_cut1 = x 31 | x = self.neuron1(x) 32 | if hook is not None: 33 | hook[self._get_name() + num + "_lif1"] = x.detach() 34 | x = self.conv1(x) + short_cut1 35 | short_cut2 = x 36 | x = self.neuron2(x) 37 | if hook is not None: 38 | hook[self._get_name() + num + "_lif2"] = x.detach() 39 | 40 | x = self.conv2(x) 41 | x = self.neuron3(x) 42 | if hook is not None: 43 | hook[self._get_name() + num + "_lif3"] = x.detach() 44 | 45 | x = self.conv3(x) 46 | x = x + short_cut2 47 | return x 48 | class MS_DownSampling(nn.Module): 49 | def __init__( 50 | self, 51 | in_channels=2, 52 | embed_dims=256, 53 | kernel_size=3, 54 | stride=2, 55 | padding=1, 56 | first_layer=True, 57 | ): 58 | super().__init__() 59 | 60 | self.encode_conv = nn.Conv2d( 61 | in_channels, 62 | embed_dims, 63 | kernel_size=kernel_size, 64 | stride=stride, 65 | padding=padding, 66 | ) 67 | 68 | self.encode_bn = nn.BatchNorm2d(embed_dims) 69 | if not first_layer: 70 | self.encode_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') 71 | 72 | def forward(self, x, hook=None, num=None): 73 | 74 | T, B, _, _, _ = x.shape 75 | 76 | if hasattr(self, "encode_lif"): 77 | x = self.encode_lif(x) 78 | if hook is not None: 79 | hook[self._get_name() + num + "_lif"] = x.detach() 80 | x = self.encode_conv(x.flatten(0, 1)) 81 | _, _, H, W = x.shape 82 | x = self.encode_bn(x).reshape(T, B, -1, H, W).contiguous() 83 | return x 84 | class Sps_downsampling(nn.Module): 85 | def __init__( 86 | self, 87 | in_channels=3, 88 | mlp_ratios=4, 89 | embed_dim=[64,128,256], 90 | pooling_stat="1111", 91 | spike_mode="lif", 92 | ): 93 
| super().__init__() 94 | self.downsample1_1 = MS_DownSampling( 95 | in_channels=in_channels, 96 | embed_dims=embed_dim[0] // 2, 97 | kernel_size=7, 98 | stride=2, 99 | padding=3, 100 | first_layer=True, 101 | ) 102 | 103 | self.ConvBlock1_1 = nn.ModuleList( 104 | [MS_ConvBlock(dim=embed_dim[0] // 2, mlp_ratio=mlp_ratios)] 105 | ) 106 | 107 | self.downsample1_2 = MS_DownSampling( 108 | in_channels=embed_dim[0] // 2, 109 | embed_dims=embed_dim[0], 110 | kernel_size=3, 111 | stride=2, 112 | padding=1, 113 | first_layer=False, 114 | ) 115 | 116 | self.ConvBlock1_2 = nn.ModuleList( 117 | [MS_ConvBlock(dim=embed_dim[0], mlp_ratio=mlp_ratios)] 118 | ) 119 | 120 | self.downsample2 = MS_DownSampling( 121 | in_channels=embed_dim[0], 122 | embed_dims=embed_dim[1], 123 | kernel_size=3, 124 | stride=2, 125 | padding=1, 126 | first_layer=False, 127 | ) 128 | 129 | self.ConvBlock2_1 = nn.ModuleList( 130 | [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)] 131 | ) 132 | 133 | self.ConvBlock2_2 = nn.ModuleList( 134 | [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)] 135 | ) 136 | 137 | self.downsample3 = MS_DownSampling( 138 | in_channels=embed_dim[1], 139 | embed_dims=embed_dim[2], 140 | kernel_size=3, 141 | stride=2, 142 | padding=1, 143 | first_layer=False, 144 | ) 145 | def forward(self,x, hook=None): 146 | x = self.downsample1_1(x,hook,"1") 147 | for blk in self.ConvBlock1_1: 148 | x = blk(x,hook,"1_1") # 112 149 | x = self.downsample1_2(x) 150 | for blk in self.ConvBlock1_2: 151 | x = blk(x,hook,"1_2") #56 152 | 153 | 154 | x = self.downsample2(x,hook,"2") 155 | 156 | for blk in self.ConvBlock2_1: 157 | x = blk(x,hook,"2_1") #28 158 | 159 | for blk in self.ConvBlock2_2: 160 | output_1 = blk(x,hook,"2_2") #28 161 | 162 | 163 | output_2 = self.downsample3(x,hook,"3") 164 | 165 | return output_1, output_2 166 | # model=Sps_downsampling() 167 | # x=torch.randn(1,1,3,224,224) 168 | # print(model(x).shape) 169 | -------------------------------------------------------------------------------- /sad/models/encoder_snn_merge.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from timm.models.helpers import clean_state_dict 6 | from timm.models.layers import trunc_normal_ 7 | from timm.models.registry import register_model 8 | from timm.models.vision_transformer import _cfg 9 | from spikingjelly.clock_driven.neuron import ( 10 | MultiStepLIFNode, 11 | MultiStepParametricLIFNode, 12 | ) 13 | from sad.models.module import * 14 | from sad.layers.convolutions import SpikingUpsamplingConcat, SpikingDeepLabHead, DeepLabHead, MergeUpsamplingConcat 15 | from timm.models import ( 16 | create_model, 17 | safe_model_name, 18 | resume_checkpoint, 19 | load_checkpoint, 20 | convert_splitbn_model, 21 | model_parameters, 22 | ) 23 | 24 | class CoinMLP(nn.Module): 25 | def __init__( 26 | self, 27 | img_size_h=128, 28 | img_size_w=128, 29 | patch_size=16, 30 | in_channels=2, 31 | num_classes=11, 32 | embed_dims=512, 33 | num_heads=8, 34 | mlp_ratios=4, 35 | qkv_bias=False, 36 | qk_scale=None, 37 | drop_rate=0.0, 38 | attn_drop_rate=0.0, 39 | drop_path_rate=0.0, 40 | norm_layer=nn.LayerNorm, 41 | depths=[6, 8, 6], 42 | sr_ratios=[8, 4, 2], 43 | T=4, 44 | pooling_stat="1111", 45 | attn_mode="direct_xor", 46 | spike_mode="lif", 47 | get_embed=False, 48 | dvs_mode=False, 49 | TET=False, 50 | cml=False, 51 | pretrained=False, 52 | pretrained_cfg=None, 53 | ): 54 | 
super().__init__() 55 | self.num_classes = num_classes 56 | self.depths = depths 57 | 58 | self.T = T 59 | self.TET = TET 60 | self.dvs = dvs_mode 61 | self.D = 48 62 | self.C = 64 63 | self.depth_layer_1 = SpikingDeepLabHead(384, 384, hidden_channel=64) 64 | self.depth_layer_2 = MergeUpsamplingConcat(192 + 384, self.D) 65 | 66 | self.feature_layer_1 = SpikingDeepLabHead(384, 384, hidden_channel=64) 67 | self.feature_layer_2 = MergeUpsamplingConcat(192 + 384, self.C) 68 | 69 | 70 | 71 | dpr = [ 72 | x.item() for x in torch.linspace(0, drop_path_rate, depths) 73 | ] # stochastic depth decay rule 74 | 75 | patch_embed = Sps_downsampling( 76 | embed_dim=[int(embed_dims/4),int(embed_dims/2),embed_dims], 77 | pooling_stat=pooling_stat, 78 | spike_mode=spike_mode, 79 | ) 80 | 81 | 82 | setattr(self, f"patch_embed", patch_embed) 83 | 84 | # classification head 85 | self.head_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch") 86 | self.head_lif_2 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch") 87 | self.apply(self._init_weights) 88 | 89 | def _init_weights(self, m): 90 | if isinstance(m, nn.Conv2d): 91 | trunc_normal_(m.weight, std=0.02) 92 | if m.bias is not None: 93 | nn.init.constant_(m.bias, 0) 94 | elif isinstance(m, nn.BatchNorm2d): 95 | nn.init.constant_(m.bias, 0) 96 | nn.init.constant_(m.weight, 1.0) 97 | 98 | def forward_features(self, x, hook=None): 99 | patch_embed = getattr(self, f"patch_embed") 100 | 101 | x_1, x_2 = patch_embed(x, hook=hook) 102 | # import pdb 103 | # pdb.set_trace() 104 | 105 | return x_1, x_2 106 | 107 | def forward(self, x, hook=None): 108 | x_1, x_2 = self.forward_features(x, hook=hook) 109 | 110 | x_1 = self.head_lif(x_1) 111 | x_2 = self.head_lif_2(x_2) 112 | 113 | 114 | # x_2 = torch.mean(x_2, dim=0) 115 | 116 | x_1 = torch.mean(x_1, dim=0) 117 | feature = self.feature_layer_1(x_2) 118 | feature = torch.mean(feature, dim=0) 119 | feature = self.feature_layer_2(feature, x_1) 120 | 121 | 122 | depth = self.depth_layer_1(x_2) 123 | depth = torch.mean(depth, dim=0) 124 | depth = self.depth_layer_2(depth, x_1) 125 | 126 | if hook is not None: 127 | hook["head_lif"] = x.detach() 128 | 129 | # feature = torch.mean(feature, dim=0) 130 | # depth = torch.mean(depth, dim=0) 131 | 132 | return feature, depth 133 | 134 | 135 | @register_model 136 | def sdt(**kwargs): 137 | model = CoinMLP( 138 | **kwargs, 139 | ) 140 | model.default_cfg = _cfg() 141 | return model 142 | 143 | def main(): 144 | encoder = create_model( 145 | "sdt", 146 | T = 3, # 默认时间步长 147 | pretrained = False, # 默认情况下不使用预训练模型 148 | drop_rate = 0.0, # 默认dropout率 149 | drop_path_rate = 0.2, # 默认drop path率 150 | drop_block_rate = None, # 默认drop block率,未指定 151 | num_heads = 8, # 默认头数 152 | num_classes = 1000, # 默认类别数 153 | pooling_stat = "1111", # 默认池化状态 154 | img_size_h = 480, # 默认图像高度,未指定 155 | img_size_w = 224, # 默认图像宽度,未指定 156 | patch_size = None, # 默认patch大小,未指定 157 | embed_dims = 384, # `args.dim`没有在您提供的参数解析器中直接列出,请指定一个默认值或确认是否有误 158 | mlp_ratios = 4, # 默认MLP比率 159 | in_channels = 3, # 默认输入通道数 160 | qkv_bias = False, # qkv偏置,默认未指定,这里设为False 161 | depths = 6, # 默认层数 162 | sr_ratios = 1, # 默认sr比率,未在参数解析器直接列出,这里设为1 163 | spike_mode = "lif", # 默认脉冲模式 164 | dvs_mode = False, # `args.dvs_mode`没有在您提供的参数解析器中直接列出,请指定一个默认值或确认是否有误 165 | TET = False, # 默认TET设置 166 | 167 | ) 168 | checkpoint = torch.load("/vol5/Coin-MLP-back/Pure-MLP-SNN/18M-with-downsample-conv/checkpoint-308.pth.tar", map_location="cpu") 169 | state_dict = clean_state_dict(checkpoint["state_dict"]) 170 | 
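    # strict=False: the checkpoint presumably comes from classification
    # pre-training, so keys belonging to the BEV-specific heads defined above
    # (depth_layer_*, feature_layer_*) are skipped instead of raising an error.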
encoder.load_state_dict(state_dict, strict=False) 171 | x = torch.randn(3, 3, 3, 224, 480) 172 | encoder = encoder.cuda() 173 | y_1,y_2 = encoder(x.cuda()) 174 | print(y_1.shape) 175 | print(y_2.shape) 176 | 177 | 178 | if __name__ == "__main__": 179 | main() -------------------------------------------------------------------------------- /sad/models/module/sps.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from spikingjelly.clock_driven.neuron import ( 4 | MultiStepLIFNode, 5 | MultiStepParametricLIFNode, 6 | ) 7 | from timm.models.layers import to_2tuple 8 | 9 | 10 | class MS_SPS(nn.Module): 11 | def __init__( 12 | self, 13 | img_size_h=128, 14 | img_size_w=128, 15 | patch_size=4, 16 | in_channels=2, 17 | embed_dims=256, 18 | pooling_stat="1111", 19 | spike_mode="lif", 20 | ): 21 | super().__init__() 22 | self.image_size = [img_size_h, img_size_w] 23 | patch_size = to_2tuple(patch_size) 24 | self.patch_size = patch_size 25 | self.pooling_stat = pooling_stat 26 | 27 | self.C = in_channels 28 | self.H, self.W = ( 29 | self.image_size[0] // patch_size[0], 30 | self.image_size[1] // patch_size[1], 31 | ) 32 | self.num_patches = self.H * self.W 33 | self.proj_conv = nn.Conv2d( 34 | in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False 35 | ) 36 | self.proj_bn = nn.BatchNorm2d(embed_dims // 8) 37 | if spike_mode == "lif": 38 | self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") 39 | elif spike_mode == "plif": 40 | self.proj_lif = MultiStepParametricLIFNode( 41 | init_tau=2.0, detach_reset=True, backend="cupy" 42 | ) 43 | self.maxpool = nn.MaxPool2d( 44 | kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False 45 | ) 46 | 47 | self.proj_conv1 = nn.Conv2d( 48 | embed_dims // 8, 49 | embed_dims // 4, 50 | kernel_size=3, 51 | stride=1, 52 | padding=1, 53 | bias=False, 54 | ) 55 | self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4) 56 | if spike_mode == "lif": 57 | self.proj_lif1 = MultiStepLIFNode( 58 | tau=2.0, detach_reset=True, backend="cupy" 59 | ) 60 | elif spike_mode == "plif": 61 | self.proj_lif1 = MultiStepParametricLIFNode( 62 | init_tau=2.0, detach_reset=True, backend="cupy" 63 | ) 64 | self.maxpool1 = nn.MaxPool2d( 65 | kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False 66 | ) 67 | 68 | self.proj_conv2 = nn.Conv2d( 69 | embed_dims // 4, 70 | embed_dims // 2, 71 | kernel_size=3, 72 | stride=1, 73 | padding=1, 74 | bias=False, 75 | ) 76 | self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2) 77 | if spike_mode == "lif": 78 | self.proj_lif2 = MultiStepLIFNode( 79 | tau=2.0, detach_reset=True, backend="cupy" 80 | ) 81 | elif spike_mode == "plif": 82 | self.proj_lif2 = MultiStepParametricLIFNode( 83 | init_tau=2.0, detach_reset=True, backend="cupy" 84 | ) 85 | self.maxpool2 = nn.MaxPool2d( 86 | kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False 87 | ) 88 | 89 | self.proj_conv3 = nn.Conv2d( 90 | embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False 91 | ) 92 | self.proj_bn3 = nn.BatchNorm2d(embed_dims) 93 | if spike_mode == "lif": 94 | self.proj_lif3 = MultiStepLIFNode( 95 | tau=2.0, detach_reset=True, backend="cupy" 96 | ) 97 | elif spike_mode == "plif": 98 | self.proj_lif3 = MultiStepParametricLIFNode( 99 | init_tau=2.0, detach_reset=True, backend="cupy" 100 | ) 101 | self.maxpool3 = nn.MaxPool2d( 102 | kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False 103 | ) 104 | 105 | self.rpe_conv = nn.Conv2d( 
106 | embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False 107 | ) 108 | self.rpe_bn = nn.BatchNorm2d(embed_dims) 109 | if spike_mode == "lif": 110 | self.rpe_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") 111 | elif spike_mode == "plif": 112 | self.rpe_lif = MultiStepParametricLIFNode( 113 | init_tau=2.0, detach_reset=True, backend="cupy" 114 | ) 115 | 116 | def forward(self, x, hook=None): 117 | T, B, _, H, W = x.shape 118 | ratio = 1 119 | x = self.proj_conv(x.flatten(0, 1)) # have some fire value 120 | x = self.proj_bn(x).reshape(T, B, -1, H // ratio, W // ratio).contiguous() 121 | x = self.proj_lif(x) 122 | if hook is not None: 123 | hook[self._get_name() + "_lif"] = x.detach() 124 | x = x.flatten(0, 1).contiguous() 125 | if self.pooling_stat[0] == "1": 126 | x = self.maxpool(x) 127 | ratio *= 2 128 | 129 | x = self.proj_conv1(x) 130 | x = self.proj_bn1(x).reshape(T, B, -1, H // ratio, W // ratio).contiguous() 131 | x = self.proj_lif1(x) 132 | if hook is not None: 133 | hook[self._get_name() + "_lif1"] = x.detach() 134 | x = x.flatten(0, 1).contiguous() 135 | if self.pooling_stat[1] == "1": 136 | x = self.maxpool1(x) 137 | ratio *= 2 138 | 139 | x = self.proj_conv2(x) 140 | x = self.proj_bn2(x).reshape(T, B, -1, H // ratio, W // ratio).contiguous() 141 | x = self.proj_lif2(x) 142 | if hook is not None: 143 | hook[self._get_name() + "_lif2"] = x.detach() 144 | x = x.flatten(0, 1).contiguous() 145 | if self.pooling_stat[2] == "1": 146 | x = self.maxpool2(x) 147 | ratio *= 2 148 | 149 | x = self.proj_conv3(x) 150 | x = self.proj_bn3(x) 151 | if self.pooling_stat[3] == "1": 152 | x = self.maxpool3(x) 153 | ratio *= 2 154 | 155 | x_feat = x 156 | x = self.proj_lif3(x.reshape(T, B, -1, H // ratio, W // ratio).contiguous()) 157 | if hook is not None: 158 | hook[self._get_name() + "_lif3"] = x.detach() 159 | x = x.flatten(0, 1).contiguous() 160 | x = self.rpe_conv(x) 161 | x = self.rpe_bn(x) 162 | x = (x + x_feat).reshape(T, B, -1, H // ratio, W // ratio).contiguous() 163 | 164 | H, W = H // self.patch_size[0], W // self.patch_size[1] 165 | return x, (H, W), hook 166 | -------------------------------------------------------------------------------- /sad/models/module/.ipynb_checkpoints/sps-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from spikingjelly.clock_driven.neuron import ( 4 | MultiStepLIFNode, 5 | MultiStepParametricLIFNode, 6 | ) 7 | from timm.models.layers import to_2tuple 8 | 9 | 10 | class MS_SPS(nn.Module): 11 | def __init__( 12 | self, 13 | img_size_h=128, 14 | img_size_w=128, 15 | patch_size=4, 16 | in_channels=2, 17 | embed_dims=256, 18 | pooling_stat="1111", 19 | spike_mode="lif", 20 | ): 21 | super().__init__() 22 | self.image_size = [img_size_h, img_size_w] 23 | patch_size = to_2tuple(patch_size) 24 | self.patch_size = patch_size 25 | self.pooling_stat = pooling_stat 26 | 27 | self.C = in_channels 28 | self.H, self.W = ( 29 | self.image_size[0] // patch_size[0], 30 | self.image_size[1] // patch_size[1], 31 | ) 32 | self.num_patches = self.H * self.W 33 | self.proj_conv = nn.Conv2d( 34 | in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False 35 | ) 36 | self.proj_bn = nn.BatchNorm2d(embed_dims // 8) 37 | if spike_mode == "lif": 38 | self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") 39 | elif spike_mode == "plif": 40 | self.proj_lif = MultiStepParametricLIFNode( 41 | init_tau=2.0, 
detach_reset=True, backend="cupy" 42 | ) 43 | self.maxpool = nn.MaxPool2d( 44 | kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False 45 | ) 46 | 47 | self.proj_conv1 = nn.Conv2d( 48 | embed_dims // 8, 49 | embed_dims // 4, 50 | kernel_size=3, 51 | stride=1, 52 | padding=1, 53 | bias=False, 54 | ) 55 | self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4) 56 | if spike_mode == "lif": 57 | self.proj_lif1 = MultiStepLIFNode( 58 | tau=2.0, detach_reset=True, backend="cupy" 59 | ) 60 | elif spike_mode == "plif": 61 | self.proj_lif1 = MultiStepParametricLIFNode( 62 | init_tau=2.0, detach_reset=True, backend="cupy" 63 | ) 64 | self.maxpool1 = nn.MaxPool2d( 65 | kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False 66 | ) 67 | 68 | self.proj_conv2 = nn.Conv2d( 69 | embed_dims // 4, 70 | embed_dims // 2, 71 | kernel_size=3, 72 | stride=1, 73 | padding=1, 74 | bias=False, 75 | ) 76 | self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2) 77 | if spike_mode == "lif": 78 | self.proj_lif2 = MultiStepLIFNode( 79 | tau=2.0, detach_reset=True, backend="cupy" 80 | ) 81 | elif spike_mode == "plif": 82 | self.proj_lif2 = MultiStepParametricLIFNode( 83 | init_tau=2.0, detach_reset=True, backend="cupy" 84 | ) 85 | self.maxpool2 = nn.MaxPool2d( 86 | kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False 87 | ) 88 | 89 | self.proj_conv3 = nn.Conv2d( 90 | embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False 91 | ) 92 | self.proj_bn3 = nn.BatchNorm2d(embed_dims) 93 | if spike_mode == "lif": 94 | self.proj_lif3 = MultiStepLIFNode( 95 | tau=2.0, detach_reset=True, backend="cupy" 96 | ) 97 | elif spike_mode == "plif": 98 | self.proj_lif3 = MultiStepParametricLIFNode( 99 | init_tau=2.0, detach_reset=True, backend="cupy" 100 | ) 101 | self.maxpool3 = nn.MaxPool2d( 102 | kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False 103 | ) 104 | 105 | self.rpe_conv = nn.Conv2d( 106 | embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False 107 | ) 108 | self.rpe_bn = nn.BatchNorm2d(embed_dims) 109 | if spike_mode == "lif": 110 | self.rpe_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") 111 | elif spike_mode == "plif": 112 | self.rpe_lif = MultiStepParametricLIFNode( 113 | init_tau=2.0, detach_reset=True, backend="cupy" 114 | ) 115 | 116 | def forward(self, x, hook=None): 117 | T, B, _, H, W = x.shape 118 | ratio = 1 119 | x = self.proj_conv(x.flatten(0, 1)) # have some fire value 120 | x = self.proj_bn(x).reshape(T, B, -1, H // ratio, W // ratio).contiguous() 121 | x = self.proj_lif(x) 122 | if hook is not None: 123 | hook[self._get_name() + "_lif"] = x.detach() 124 | x = x.flatten(0, 1).contiguous() 125 | if self.pooling_stat[0] == "1": 126 | x = self.maxpool(x) 127 | ratio *= 2 128 | 129 | x = self.proj_conv1(x) 130 | x = self.proj_bn1(x).reshape(T, B, -1, H // ratio, W // ratio).contiguous() 131 | x = self.proj_lif1(x) 132 | if hook is not None: 133 | hook[self._get_name() + "_lif1"] = x.detach() 134 | x = x.flatten(0, 1).contiguous() 135 | if self.pooling_stat[1] == "1": 136 | x = self.maxpool1(x) 137 | ratio *= 2 138 | 139 | x = self.proj_conv2(x) 140 | x = self.proj_bn2(x).reshape(T, B, -1, H // ratio, W // ratio).contiguous() 141 | x = self.proj_lif2(x) 142 | if hook is not None: 143 | hook[self._get_name() + "_lif2"] = x.detach() 144 | x = x.flatten(0, 1).contiguous() 145 | if self.pooling_stat[2] == "1": 146 | x = self.maxpool2(x) 147 | ratio *= 2 148 | 149 | x = self.proj_conv3(x) 150 | x = self.proj_bn3(x) 151 | if self.pooling_stat[3] 
== "1": 152 | x = self.maxpool3(x) 153 | ratio *= 2 154 | 155 | x_feat = x 156 | x = self.proj_lif3(x.reshape(T, B, -1, H // ratio, W // ratio).contiguous()) 157 | if hook is not None: 158 | hook[self._get_name() + "_lif3"] = x.detach() 159 | x = x.flatten(0, 1).contiguous() 160 | x = self.rpe_conv(x) 161 | x = self.rpe_bn(x) 162 | x = (x + x_feat).reshape(T, B, -1, H // ratio, W // ratio).contiguous() 163 | 164 | H, W = H // self.patch_size[0], W // self.patch_size[1] 165 | return x, (H, W), hook 166 | -------------------------------------------------------------------------------- /sad/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from fvcore.common.config import CfgNode as _CfgNode 3 | 4 | 5 | def convert_to_dict(cfg_node, key_list=[]): 6 | """Convert a config node to dictionary.""" 7 | _VALID_TYPES = {tuple, list, str, int, float, bool} 8 | if not isinstance(cfg_node, _CfgNode): 9 | if type(cfg_node) not in _VALID_TYPES: 10 | print( 11 | 'Key {} with value {} is not a valid type; valid types: {}'.format( 12 | '.'.join(key_list), type(cfg_node), _VALID_TYPES 13 | ), 14 | ) 15 | return cfg_node 16 | else: 17 | cfg_dict = dict(cfg_node) 18 | for k, v in cfg_dict.items(): 19 | cfg_dict[k] = convert_to_dict(v, key_list + [k]) 20 | return cfg_dict 21 | 22 | 23 | class CfgNode(_CfgNode): 24 | """Remove once https://github.com/rbgirshick/yacs/issues/19 is merged.""" 25 | 26 | def convert_to_dict(self): 27 | return convert_to_dict(self) 28 | 29 | 30 | CN = CfgNode 31 | 32 | _C = CN() 33 | _C.LOG_DIR = 'tensorboard_logs' 34 | _C.TAG = 'default' 35 | 36 | _C.GPUS = [0] # which gpus to use 37 | _C.PRECISION = 32 # 16bit or 32bit 38 | _C.BATCHSIZE = 3 39 | _C.EPOCHS = 20 40 | 41 | _C.N_WORKERS = 5 42 | _C.VIS_INTERVAL = 5000 43 | _C.LOGGING_INTERVAL = 500 44 | 45 | _C.PRETRAINED = CN() 46 | _C.PRETRAINED.LOAD_WEIGHTS = False 47 | _C.PRETRAINED.PATH = '' 48 | 49 | _C.DATASET = CN() 50 | _C.DATASET.DATAROOT = '/data/Nuscenes' 51 | _C.DATASET.VERSION = 'trainval' 52 | _C.DATASET.NAME = 'nuscenes' 53 | _C.DATASET.MAP_FOLDER = '/data/Nuscenes' 54 | _C.DATASET.IGNORE_INDEX = 255 # Ignore index when creating flow/offset labels 55 | _C.DATASET.FILTER_INVISIBLE_VEHICLES = True # Filter vehicles that are not visible from the cameras 56 | _C.DATASET.SAVE_DIR = 'datas' 57 | 58 | _C.TIME_RECEPTIVE_FIELD = 3 # how many frames of temporal context (1 for single timeframe) 59 | _C.N_FUTURE_FRAMES = 4 # how many time steps into the future to predict 60 | 61 | _C.IMAGE = CN() 62 | _C.IMAGE.FINAL_DIM = (224, 480) 63 | _C.IMAGE.RESIZE_SCALE = 0.3 64 | _C.IMAGE.TOP_CROP = 46 65 | _C.IMAGE.ORIGINAL_HEIGHT = 900 # Original input RGB camera height 66 | _C.IMAGE.ORIGINAL_WIDTH = 1600 # Original input RGB camera width 67 | _C.IMAGE.NAMES = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT'] 68 | 69 | _C.LIFT = CN() # image to BEV lifting 70 | _C.LIFT.X_BOUND = [-50.0, 50.0, 0.5] # Forward 71 | _C.LIFT.Y_BOUND = [-50.0, 50.0, 0.5] # Sides 72 | _C.LIFT.Z_BOUND = [-10.0, 10.0, 20.0] # Height 73 | _C.LIFT.D_BOUND = [2.0, 50.0, 1.0] 74 | _C.LIFT.GT_DEPTH = False 75 | _C.LIFT.DISCOUNT = 0.5 76 | 77 | _C.EGO = CN() 78 | _C.EGO.WIDTH = 1.85 79 | _C.EGO.HEIGHT = 4.084 80 | 81 | _C.MODEL = CN() 82 | 83 | _C.MODEL.ENCODER = CN() 84 | _C.MODEL.ENCODER.DOWNSAMPLE = 8 85 | _C.MODEL.ENCODER.NAME = 'efficientnet-b4' 86 | _C.MODEL.ENCODER.OUT_CHANNELS = 64 87 | _C.MODEL.ENCODER.USE_DEPTH_DISTRIBUTION = True 88 | 89 | 
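# NOTE: the YAML files under sad/configs/nuscenes override ENCODER.NAME with a
# path to a pretrained SNN encoder checkpoint in place of the 'efficientnet-b4'
# default above.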
_C.MODEL.TEMPORAL_MODEL = CN() 90 | _C.MODEL.TEMPORAL_MODEL.NAME = 'temporal_block' # type of temporal model 91 | _C.MODEL.TEMPORAL_MODEL.START_OUT_CHANNELS = 64 92 | _C.MODEL.TEMPORAL_MODEL.EXTRA_IN_CHANNELS = 0 93 | _C.MODEL.TEMPORAL_MODEL.INBETWEEN_LAYERS = 0 94 | _C.MODEL.TEMPORAL_MODEL.PYRAMID_POOLING = True 95 | _C.MODEL.TEMPORAL_MODEL.INPUT_EGOPOSE = True 96 | 97 | _C.MODEL.DISTRIBUTION = CN() 98 | _C.MODEL.DISTRIBUTION.LATENT_DIM = 32 99 | _C.MODEL.DISTRIBUTION.MIN_LOG_SIGMA = -5.0 100 | _C.MODEL.DISTRIBUTION.MAX_LOG_SIGMA = 5.0 101 | 102 | _C.MODEL.FUTURE_PRED = CN() 103 | _C.MODEL.FUTURE_PRED.N_GRU_BLOCKS = 2 104 | _C.MODEL.FUTURE_PRED.N_RES_LAYERS = 1 105 | _C.MODEL.FUTURE_PRED.MIXTURE = True 106 | 107 | _C.MODEL.DECODER = CN() 108 | 109 | _C.MODEL.BN_MOMENTUM = 0.1 110 | 111 | _C.SEMANTIC_SEG = CN() 112 | 113 | _C.SEMANTIC_SEG.VEHICLE = CN() 114 | _C.SEMANTIC_SEG.VEHICLE.WEIGHTS = [1.0, 2.0] 115 | _C.SEMANTIC_SEG.VEHICLE.USE_TOP_K = True # backprop only top-k hardest pixels 116 | _C.SEMANTIC_SEG.VEHICLE.TOP_K_RATIO = 0.25 117 | 118 | _C.SEMANTIC_SEG.PEDESTRIAN = CN() 119 | _C.SEMANTIC_SEG.PEDESTRIAN.ENABLED = True 120 | _C.SEMANTIC_SEG.PEDESTRIAN.WEIGHTS = [1.0, 10.0] 121 | _C.SEMANTIC_SEG.PEDESTRIAN.USE_TOP_K = True 122 | _C.SEMANTIC_SEG.PEDESTRIAN.TOP_K_RATIO = 0.25 123 | 124 | _C.SEMANTIC_SEG.HDMAP = CN() 125 | _C.SEMANTIC_SEG.HDMAP.ENABLED = True 126 | _C.SEMANTIC_SEG.HDMAP.ELEMENTS = ['lane_divider', 'drivable_area'] 127 | _C.SEMANTIC_SEG.HDMAP.WEIGHTS = [[1.0, 5.0], [1.0, 1.0]] 128 | _C.SEMANTIC_SEG.HDMAP.TRAIN_WEIGHT = [1, 1] 129 | _C.SEMANTIC_SEG.HDMAP.USE_TOP_K = [True, False] 130 | _C.SEMANTIC_SEG.HDMAP.TOP_K_RATIO = [0.25, 0.25] 131 | 132 | _C.INSTANCE_SEG = CN() 133 | _C.INSTANCE_SEG.ENABLED = True 134 | 135 | _C.INSTANCE_FLOW = CN() 136 | _C.INSTANCE_FLOW.ENABLED = True 137 | 138 | _C.PROBABILISTIC = CN() 139 | _C.PROBABILISTIC.ENABLED = True # learn a distribution over futures 140 | _C.PROBABILISTIC.METHOD = 'GAUSSIAN' # [BERNOULLI, GAUSSIAN, MIXGAUSSIAN] 141 | 142 | _C.PLANNING = CN() 143 | _C.PLANNING.ENABLED = True 144 | _C.PLANNING.GRU_STATE_SIZE = 256 145 | _C.PLANNING.SAMPLE_NUM = 600 146 | _C.PLANNING.COMMAND = ['LEFT', 'FORWARD', 'RIGHT'] 147 | 148 | _C.FUTURE_DISCOUNT = 0.95 149 | 150 | _C.OPTIMIZER = CN() 151 | _C.OPTIMIZER.LR = 3e-4 152 | _C.OPTIMIZER.WEIGHT_DECAY = 1e-7 153 | _C.GRAD_NORM_CLIP = 5 154 | 155 | _C.COST_FUNCTION = CN() 156 | _C.COST_FUNCTION.SAFETY = 0.1 157 | _C.COST_FUNCTION.LAMBDA = 1. 158 | _C.COST_FUNCTION.HEADWAY = 1. 159 | _C.COST_FUNCTION.LRDIVIDER = 10. 160 | _C.COST_FUNCTION.COMFORT = 0.1 161 | _C.COST_FUNCTION.PROGRESS = 0.5 162 | _C.COST_FUNCTION.VOLUME = 100. 163 | 164 | def get_parser(): 165 | parser = argparse.ArgumentParser(description='Fiery training') 166 | parser.add_argument('--config-file', default='', metavar='FILE', help='path to config file') 167 | parser.add_argument( 168 | 'opts', help='Modify config options using the command-line', default=None, nargs=argparse.REMAINDER, 169 | ) 170 | return parser 171 | 172 | 173 | def get_cfg(args=None, cfg_dict=None): 174 | """ First get default config. Then merge cfg_dict. Then merge according to args. 
""" 175 | 176 | cfg = _C.clone() 177 | 178 | if cfg_dict is not None: 179 | tmp = CfgNode(cfg_dict) 180 | for i in tmp.COST_FUNCTION: 181 | tmp.COST_FUNCTION.update({i: float(tmp.COST_FUNCTION.get(i))}) 182 | cfg.merge_from_other_cfg(tmp) 183 | 184 | if args is not None: 185 | if args.config_file: 186 | cfg.merge_from_file(args.config_file) 187 | cfg.merge_from_list(args.opts) 188 | # cfg.freeze() 189 | return cfg 190 | -------------------------------------------------------------------------------- /maps/hdmap_generate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 as cv 3 | import json 4 | from collections import deque 5 | from pathlib import Path 6 | import h5py 7 | import os 8 | import tqdm 9 | 10 | classes = { 11 | 0: [0, 0, 0], # unlabeled 12 | 1: [0, 0, 0], # building 13 | 2: [0, 0, 0], # fence 14 | 3: [0, 0, 0], # other 15 | 4: [0, 255, 0], # pedestrian 16 | 5: [0, 0, 0], # pole 17 | 6: [157, 234, 50], # road line 18 | 7: [128, 64, 128], # road 19 | 8: [255, 255, 255], # sidewalk 20 | 9: [0, 0, 0], # vegetation 21 | 10: [0, 0, 255], # vehicle 22 | 11: [0, 0, 0], # wall 23 | 12: [0, 0, 0], # traffic sign 24 | 13: [0, 0, 0], # sky 25 | 14: [0, 0, 0], # ground 26 | 15: [0, 0, 0], # bridge 27 | 16: [0, 0, 0], # rail track 28 | 17: [0, 0, 0], # guard rail 29 | 18: [0, 0, 0], # traffic light 30 | 19: [0, 0, 0], # static 31 | 20: [0, 0, 0], # dynamic 32 | 21: [0, 0, 0], # water 33 | 22: [0, 0, 0], # terrain 34 | 23: [255, 0, 0], # red light 35 | 24: [0, 0, 0], # yellow light #TODO should be red 36 | 25: [0, 0, 0], # green light 37 | 26: [157, 234, 50], # stop sign 38 | 27: [157, 234, 50], # stop line marking 39 | 40 | } 41 | 42 | 43 | COLOR_BLACK = (0, 0, 0) 44 | COLOR_RED = (255, 0, 0) 45 | COLOR_GREEN = (0, 255, 0) 46 | COLOR_BLUE = (0, 0, 255) 47 | COLOR_CYAN = (0, 255, 255) 48 | COLOR_MAGENTA = (255, 0, 255) 49 | COLOR_MAGENTA_2 = (255, 140, 255) 50 | COLOR_YELLOW = (255, 255, 0) 51 | COLOR_YELLOW_2 = (160, 160, 0) 52 | COLOR_WHITE = (255, 255, 255) 53 | COLOR_ALUMINIUM_0 = (238, 238, 236) 54 | COLOR_ALUMINIUM_3 = (136, 138, 133) 55 | COLOR_ALUMINIUM_5 = (46, 52, 54) 56 | 57 | pixels_per_meter = 5 58 | width = 512 59 | pixels_ev_to_bottom = 256 60 | 61 | def tint(color, factor): 62 | r, g, b = color 63 | r = int(r + (255-r) * factor) 64 | g = int(g + (255-g) * factor) 65 | b = int(b + (255-b) * factor) 66 | r = min(r, 255) 67 | g = min(g, 255) 68 | b = min(b, 255) 69 | return (r, g, b) 70 | 71 | def get_warp_transform(ev_loc, ev_rot, world_offset): 72 | ev_loc_in_px = world_to_pixel(ev_loc, world_offset) 73 | yaw = np.deg2rad(ev_rot) 74 | 75 | forward_vec = np.array([np.cos(yaw), np.sin(yaw)]) 76 | right_vec = np.array([np.cos(yaw + 0.5*np.pi), np.sin(yaw + 0.5*np.pi)]) 77 | 78 | bottom_left = ev_loc_in_px - pixels_ev_to_bottom * forward_vec - (0.5*width) * right_vec 79 | top_left = ev_loc_in_px + (width-pixels_ev_to_bottom) * forward_vec - (0.5*width) * right_vec 80 | top_right = ev_loc_in_px + (width-pixels_ev_to_bottom) * forward_vec + (0.5*width) * right_vec 81 | 82 | src_pts = np.stack((bottom_left, top_left, top_right), axis=0).astype(np.float32) 83 | dst_pts = np.array([[0, width-1], 84 | [0, 0], 85 | [width-1, 0]], dtype=np.float32) 86 | return cv.getAffineTransform(src_pts, dst_pts) 87 | 88 | def world_to_pixel(location, world_offset, projective=False): 89 | """Converts the world coordinates to pixel coordinates""" 90 | x = pixels_per_meter * (location[0] - world_offset[0]) 91 | y = pixels_per_meter * 
(location[1] - world_offset[1]) 92 | 93 | if projective: 94 | p = np.array([x, y, 1], dtype=np.float32) 95 | else: 96 | p = np.array([x, y], dtype=np.float32) 97 | return p 98 | 99 | Town2map = { 100 | "town01": "Town01.h5", 101 | "town02": "Town02.h5", 102 | "town03": "Town03.h5", 103 | "town04": "Town04.h5", 104 | "town05": "Town05.h5", 105 | "town06": "Town06.h5", 106 | "town07": "Town07.h5", 107 | "town10": "Town10HD.h5", 108 | 109 | } 110 | 111 | root = "" 112 | map_path = "" 113 | town_list = list(os.listdir(root)) 114 | # town_list = town_list[18:] 115 | # town_list = ["town02_short","town02_tiny"] 116 | for town in tqdm.tqdm(town_list): 117 | 118 | maps_h5_path = os.path.join(map_path, Town2map[town[:6]]) 119 | 120 | with h5py.File(maps_h5_path, 'r', libver='latest', swmr=True) as hf: 121 | road = np.array(hf['road'], dtype=np.uint8) 122 | lane_marking_yellow_broken = np.array(hf['lane_marking_yellow_broken'], dtype=np.uint8) 123 | lane_marking_yellow_solid = np.array(hf['lane_marking_yellow_solid'], dtype=np.uint8) 124 | lane_marking_white_broken = np.array(hf['lane_marking_white_broken'], dtype=np.uint8) 125 | lane_marking_white_solid = np.array(hf['lane_marking_white_solid'], dtype=np.uint8) 126 | world_offset = np.array(hf.attrs['world_offset_in_meters'], dtype=np.float32) 127 | 128 | 129 | town_folder = os.path.join(root, town) 130 | route_list = [route for route in os.listdir(town_folder) if os.path.isdir(os.path.join(town_folder, route)) ] 131 | 132 | for route in tqdm.tqdm(route_list): 133 | route_folder = os.path.join(town_folder, route) 134 | os.makedirs(os.path.join(route_folder, "hdmap"), exist_ok=True) 135 | measurement_folder = os.path.join(route_folder, "meta") 136 | measurement_files = os.listdir(measurement_folder) 137 | for measurement in measurement_files: 138 | with open(os.path.join(measurement_folder, measurement), "r") as read_file: 139 | measurement_data = json.load(read_file) 140 | x = measurement_data['x'] 141 | y = measurement_data['y'] 142 | 143 | theta = measurement_data['theta'] 144 | if np.isnan(theta): 145 | theta = 0 146 | 147 | 148 | ev_loc = [y , -x] 149 | ev_rot = np.rad2deg(theta) - 90 150 | 151 | M_warp = get_warp_transform(ev_loc, ev_rot, world_offset) 152 | road_mask = cv.warpAffine(road, M_warp, (width, width)).astype(bool) 153 | lane_mask_white_broken = cv.warpAffine(lane_marking_white_broken, M_warp, 154 | (width, width)).astype(bool) 155 | lane_mask_white_solid = cv.warpAffine(lane_marking_white_solid, M_warp, 156 | (width, width)).astype(bool) 157 | lane_mask_yellow_broken = cv.warpAffine(lane_marking_yellow_broken, M_warp, 158 | (width, width)).astype(bool) 159 | lane_mask_yellow_solid = cv.warpAffine(lane_marking_yellow_solid, M_warp, 160 | (width, width)).astype(bool) 161 | 162 | image = np.zeros([width, width, 3], dtype=np.uint8) 163 | image[road_mask] = COLOR_ALUMINIUM_5 164 | image[lane_mask_white_broken] = COLOR_MAGENTA 165 | image[lane_mask_white_solid] = COLOR_MAGENTA 166 | image[lane_mask_yellow_broken] = COLOR_MAGENTA 167 | image[lane_mask_yellow_solid] = COLOR_MAGENTA 168 | cv.imwrite(os.path.join(route_folder, "hdmap", measurement.replace('json', 'png')), image) 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | # image[lane_mask_broken] = COLOR_MAGENTA_2 179 | 180 | # cv.imwrite("hdmap.png", image) 181 | 182 | 183 | # seg = cv.imread("0009.png")[:, :, 2] 184 | # result = np.zeros((seg.shape[0], seg.shape[1], 3)) 185 | # for key, value in classes.items(): 186 | # result[np.where(seg == key)] = value
187 | 188 | # cv.imwrite("seg_vis.png", result) -------------------------------------------------------------------------------- /sad/models/planning_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from sad.layers.convolutions import Bottleneck, SpikingBottleneck 7 | from sad.layers.temporal import SpatialGRU, Dual_GRU, BiGRU 8 | from sad.cost import Cost_Function 9 | 10 | from spikingjelly.clock_driven.rnn import SpikingGRUCell 11 | from spikingjelly.clock_driven import surrogate 12 | 13 | class Planning(nn.Module): 14 | def __init__(self, cfg, feature_channel, gru_input_size=6, gru_state_size=256): 15 | super(Planning, self).__init__() 16 | self.cost_function = Cost_Function(cfg) 17 | 18 | self.sample_num = cfg.PLANNING.SAMPLE_NUM 19 | self.commands = cfg.PLANNING.COMMAND 20 | assert self.sample_num % 3 == 0 21 | self.num = int(self.sample_num / 3) 22 | 23 | self.reduce_channel = nn.Sequential( 24 | SpikingBottleneck(feature_channel, feature_channel, downsample=True), 25 | SpikingBottleneck(feature_channel, int(feature_channel/2), downsample=True), 26 | SpikingBottleneck(int(feature_channel/2), int(feature_channel/2), downsample=True), 27 | SpikingBottleneck(int(feature_channel/2), int(feature_channel/8)) 28 | ) 29 | 30 | # self.GRU = nn.GRUCell(gru_input_size, gru_state_size) 31 | self.GRU = SpikingGRUCell(gru_input_size, gru_state_size, surrogate_function1=surrogate.ATan(), surrogate_function2=surrogate.ATan()) 32 | self.decoder = nn.Sequential( 33 | nn.Linear(gru_state_size, gru_state_size), 34 | nn.Linear(gru_state_size, 2) 35 | ) 36 | 37 | 38 | def compute_L2(self, trajs, gt_traj): 39 | ''' 40 | trajs: torch.Tensor (B, N, n_future, 3) 41 | gt_traj: torch.Tensor (B,1, n_future, 3) 42 | ''' 43 | if trajs.ndim == 4 and gt_traj.ndim == 4: 44 | return ((trajs[:,:,:,:2] - gt_traj[:,:,:,:2]) ** 2).sum(dim=-1) 45 | if trajs.ndim == 3 and gt_traj.ndim == 3: 46 | return ((trajs[:, :, :2] - gt_traj[:, :, :2]) ** 2).sum(dim=-1) 47 | 48 | raise ValueError('trajs ndim != gt_traj ndim') 49 | 50 | def select(self, trajs, cost_volume, semantic_pred, lane_divider, drivable_area, target_points, k=1): 51 | ''' 52 | trajs: torch.Tensor (B, N, n_future, 3) 53 | cost_volume: torch.Tensor (B, n_future, 200, 200) 54 | semantic_pred: torch.Tensor(B, n_future, 200, 200) 55 | lane_divider: torch.Tensor(B, 1/2, 200, 200) 56 | drivable_area: torch.Tensor(B, 1/2, 200, 200) 57 | target_points: torch.Tensor (B, 2) 58 | ''' 59 | sm_cost_fc, sm_cost_fo = self.cost_function(cost_volume, trajs[:,:,:,:2], semantic_pred, lane_divider, drivable_area, target_points) 60 | 61 | CS = sm_cost_fc + sm_cost_fo.sum(dim=-1) 62 | CC, KK = torch.topk(CS, k, dim=-1, largest=False) 63 | 64 | ii = torch.arange(len(trajs)) 65 | select_traj = trajs[ii[:,None], KK].squeeze(1) # (B, n_future, 3) 66 | 67 | return select_traj 68 | 69 | def loss(self, trajs, gt_trajs, cost_volume, semantic_pred, lane_divider, drivable_area, target_points): 70 | ''' 71 | trajs: torch.Tensor (B, N, n_future, 3) 72 | gt_trajs: torch.Tensor (B, n_future, 3) 73 | cost_volume: torch.Tensor (B, n_future, 200, 200) 74 | semantic_pred: torch.Tensor(B, n_future, 200, 200) 75 | lane_divider: torch.Tensor(B, 1/2, 200, 200) 76 | drivable_area: torch.Tensor(B, 1/2, 200, 200) 77 | target_points: torch.Tensor (B, 2) 78 | ''' 79 | sm_cost_fc, sm_cost_fo = self.cost_function(cost_volume, trajs[:, :, :, :2], semantic_pred, 
lane_divider, drivable_area, target_points) 80 | 81 | if gt_trajs.ndim == 3: 82 | gt_trajs = gt_trajs[:, None] 83 | 84 | gt_cost_fc, gt_cost_fo = self.cost_function(cost_volume, gt_trajs[:, :, :, :2], semantic_pred, lane_divider, drivable_area, target_points) 85 | 86 | L, _ = F.relu( 87 | F.relu(gt_cost_fo - sm_cost_fo).sum(-1) + (gt_cost_fc - sm_cost_fc) + self.compute_L2(trajs, gt_trajs).mean( 88 | dim=-1)).max(dim=-1) 89 | 90 | return torch.mean(L) 91 | 92 | def forward(self,cam_front, trajs, gt_trajs, cost_volume, semantic_pred, hd_map, commands, target_points): 93 | ''' 94 | cam_front: torch.Tensor (B, 64, 60, 28) 95 | trajs: torch.Tensor (B, N, n_future, 3) 96 | gt_trajs: torch.Tensor (B, n_future, 3) 97 | cost_volume: torch.Tensor (B, n_future, 200, 200) 98 | semantic_pred: torch.Tensor(B, n_future, 200, 200) 99 | hd_map: torch.Tensor (B, 2/4, 200, 200) 100 | commands: List (B) 101 | target_points: (B, 2) 102 | ''' 103 | 104 | cur_trajs = [] 105 | for i in range(len(commands)): 106 | command = commands[i] 107 | traj = trajs[i] 108 | if command == 'LEFT': 109 | cur_trajs.append(traj[:self.num].repeat(3, 1, 1)) 110 | elif command == 'FORWARD': 111 | cur_trajs.append(traj[self.num:self.num * 2].repeat(3, 1, 1)) 112 | elif command == 'RIGHT': 113 | cur_trajs.append(traj[self.num * 2:].repeat(3, 1, 1)) 114 | else: 115 | cur_trajs.append(traj) 116 | cur_trajs = torch.stack(cur_trajs) 117 | 118 | if hd_map.shape[1] == 2: 119 | lane_divider = hd_map[:, 0:1] 120 | drivable_area = hd_map[:, 1:2] 121 | elif hd_map.shape[1] == 4: 122 | lane_divider = hd_map[:, 0:2] 123 | drivable_area = hd_map[:, 2:4] 124 | else: 125 | raise NotImplementedError 126 | 127 | if self.training: 128 | loss = self.loss(cur_trajs, gt_trajs, cost_volume, semantic_pred, lane_divider, drivable_area, target_points) 129 | else: 130 | loss = 0 131 | 132 | cam_front = self.reduce_channel(cam_front) 133 | h0 = cam_front.flatten(start_dim=1) # (B, 256/128) 134 | final_traj = self.select(cur_trajs, cost_volume, semantic_pred, lane_divider, drivable_area, target_points) # (B, n_future, 3) 135 | target_points = target_points.to(dtype=h0.dtype) 136 | b, s, _ = final_traj.shape 137 | x = torch.zeros((b, 2), device=h0.device) 138 | output_traj = [] 139 | for i in range(s): 140 | x = torch.cat([x, final_traj[:,i,:2], target_points], dim=-1) # (B, 6) 141 | h0 = self.GRU(x, h0) 142 | x = self.decoder(h0) # (B, 2) 143 | output_traj.append(x) 144 | output_traj = torch.stack(output_traj, dim=1) # (B, 4, 2) 145 | 146 | output_traj = torch.cat( 147 | [output_traj, torch.zeros((*output_traj.shape[:-1],1), device=output_traj.device)], dim=-1 148 | ) 149 | 150 | if self.training: 151 | loss = loss*0.5 + (F.smooth_l1_loss(output_traj[:,:,:2], gt_trajs[:,:,:2], reduction='none')*torch.tensor([10., 1.], device=loss.device)).mean() 152 | 153 | return loss, output_traj 154 | -------------------------------------------------------------------------------- /sad/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SpatialRegressionLoss(nn.Module): 7 | def __init__(self, norm, ignore_index=255, future_discount=1.0): 8 | super(SpatialRegressionLoss, self).__init__() 9 | self.norm = norm 10 | self.ignore_index = ignore_index 11 | self.future_discount = future_discount 12 | 13 | if norm == 1: 14 | self.loss_fn = F.l1_loss 15 | elif norm == 2: 16 | self.loss_fn = F.mse_loss 17 | else: 18 | raise ValueError(f'Expected norm 
1 or 2, but got norm={norm}') 19 | 20 | def forward(self, prediction, target, n_present=3): 21 | assert len(prediction.shape) == 5, 'Must be a 5D tensor' 22 | # ignore_index is the same across all channels 23 | mask = target[:, :, :1] != self.ignore_index 24 | if mask.sum() == 0: 25 | return prediction.new_zeros(1)[0].float() 26 | 27 | loss = self.loss_fn(prediction, target, reduction='none') 28 | 29 | # Sum channel dimension 30 | loss = torch.sum(loss, dim=-3, keepdim=True) 31 | 32 | seq_len = loss.shape[1] 33 | assert seq_len >= n_present 34 | future_len = seq_len - n_present 35 | future_discounts = self.future_discount ** torch.arange(1, future_len+1, device=loss.device, dtype=loss.dtype) 36 | discounts = torch.cat([torch.ones(n_present, device=loss.device, dtype=loss.dtype), future_discounts], dim=0) 37 | discounts = discounts.view(1, seq_len, 1, 1, 1) 38 | loss = loss * discounts 39 | 40 | return loss[mask].mean() 41 | 42 | 43 | class SegmentationLoss(nn.Module): 44 | def __init__(self, class_weights, ignore_index=255, use_top_k=False, top_k_ratio=1.0, future_discount=1.0): 45 | super().__init__() 46 | self.class_weights = class_weights 47 | self.ignore_index = ignore_index 48 | self.use_top_k = use_top_k 49 | self.top_k_ratio = top_k_ratio 50 | self.future_discount = future_discount 51 | 52 | def forward(self, prediction, target, n_present=3): 53 | if target.shape[-3] != 1: 54 | raise ValueError('segmentation label must be an index-label with channel dimension = 1.') 55 | b, s, c, h, w = prediction.shape 56 | 57 | prediction = prediction.view(b * s, c, h, w) 58 | target = target.view(b * s, h, w) 59 | loss = F.cross_entropy( 60 | prediction, 61 | target, 62 | ignore_index=self.ignore_index, 63 | reduction='none', 64 | weight=self.class_weights.to(target.device), 65 | ) 66 | 67 | loss = loss.view(b, s, h, w) 68 | 69 | assert s >= n_present 70 | future_len = s - n_present 71 | future_discounts = self.future_discount ** torch.arange(1, future_len+1, device=loss.device, dtype=loss.dtype) 72 | discounts = torch.cat([torch.ones(n_present, device=loss.device, dtype=loss.dtype), future_discounts], dim=0) 73 | discounts = discounts.view(1, s, 1, 1) 74 | loss = loss * discounts 75 | 76 | loss = loss.view(b, s, -1) 77 | if self.use_top_k: 78 | # Penalises the top-k hardest pixels 79 | k = int(self.top_k_ratio * loss.shape[2]) 80 | loss, _ = torch.sort(loss, dim=2, descending=True) 81 | loss = loss[:, :, :k] 82 | 83 | return torch.mean(loss) 84 | 85 | class HDmapLoss(nn.Module): 86 | def __init__(self, class_weights, training_weights, use_top_k, top_k_ratio, ignore_index=255): 87 | super(HDmapLoss, self).__init__() 88 | self.class_weights = class_weights 89 | self.training_weights = training_weights 90 | self.ignore_index = ignore_index 91 | self.use_top_k = use_top_k 92 | self.top_k_ratio = top_k_ratio 93 | 94 | def forward(self, prediction, target): 95 | loss = 0 96 | for i in range(target.shape[-3]): 97 | cur_target = target[:, i] 98 | b, h, w = cur_target.shape 99 | cur_prediction = prediction[:, 2*i:2*(i+1)] 100 | cur_loss = F.cross_entropy( 101 | cur_prediction, 102 | cur_target, 103 | ignore_index=self.ignore_index, 104 | reduction='none', 105 | weight=self.class_weights[i].to(target.device), 106 | ) 107 | 108 | cur_loss = cur_loss.view(b, -1) 109 | if self.use_top_k[i]: 110 | k = int(self.top_k_ratio[i] * cur_loss.shape[1]) 111 | cur_loss, _ = torch.sort(cur_loss, dim=1, descending=True) 112 | cur_loss = cur_loss[:, :k] 113 | loss += torch.mean(cur_loss) * self.training_weights[i] 114 | 
return loss 115 | 116 | class DepthLoss(nn.Module): 117 | def __init__(self, class_weights=None, ignore_index=255): 118 | super(DepthLoss, self).__init__() 119 | self.class_weights = class_weights 120 | self.ignore_index = ignore_index 121 | 122 | def forward(self, prediction, target): 123 | b, s, n, d, h, w = prediction.shape 124 | 125 | prediction = prediction.view(b*s*n, d, h, w) 126 | target = target.view(b*s*n, h, w) 127 | loss = F.cross_entropy( 128 | prediction, 129 | target, 130 | ignore_index=self.ignore_index, 131 | reduction='none', 132 | weight=self.class_weights 133 | ) 134 | return torch.mean(loss) 135 | 136 | 137 | class ProbabilisticLoss(nn.Module): 138 | def __init__(self, method): 139 | super(ProbabilisticLoss, self).__init__() 140 | self.method = method 141 | 142 | def kl_div(self, present_mu, present_log_sigma, future_mu, future_log_sigma): 143 | var_future = torch.exp(2 * future_log_sigma) 144 | var_present = torch.exp(2 * present_log_sigma) 145 | kl_div = ( 146 | present_log_sigma - future_log_sigma - 0.5 + (var_future + (future_mu - present_mu) ** 2) / ( 147 | 2 * var_present) 148 | ) 149 | 150 | kl_loss = torch.mean(torch.sum(kl_div, dim=-1)) 151 | return kl_loss 152 | 153 | def forward(self, output): 154 | if self.method == 'GAUSSIAN': 155 | present_mu = output['present_mu'] 156 | present_log_sigma = output['present_log_sigma'] 157 | future_mu = output['future_mu'] 158 | future_log_sigma = output['future_log_sigma'] 159 | 160 | kl_loss = self.kl_div(present_mu, present_log_sigma, future_mu, future_log_sigma) 161 | elif self.method == 'MIXGAUSSIAN': 162 | present_mu = output['present_mu'] 163 | present_log_sigma = output['present_log_sigma'] 164 | future_mu = output['future_mu'] 165 | future_log_sigma = output['future_log_sigma'] 166 | 167 | kl_loss = 0 168 | for i in range(len(present_mu)): 169 | kl_loss += self.kl_div(present_mu[i], present_log_sigma[i], future_mu[i], future_log_sigma[i]) 170 | elif self.method == 'BERNOULLI': 171 | present_log_prob = output['present_log_prob'] 172 | future_log_prob = output['future_log_prob'] 173 | 174 | kl_loss = F.kl_div(present_log_prob, future_log_prob, reduction='batchmean', log_target=True) 175 | else: 176 | raise NotImplementedError 177 | 178 | 179 | return kl_loss -------------------------------------------------------------------------------- /sad/utils/sampler.py: -------------------------------------------------------------------------------- 1 | # sampler.py 2 | # trajectory sampler 3 | # including a mix of clothoids, straight lines, and circles 4 | import numpy as np 5 | import torch 6 | from scipy.special import fresnel 7 | 8 | def sample(v0, Kappa, T0, N0, tt, M, possibility = None): 9 | ''' 10 | :param v0: initial velocity 11 | :param Kappa: curvature 12 | :param T0: initial tangent vector 13 | :param N0: initial normal vector 14 | :param tt: time stamp 15 | :param M: the number of sample 16 | :param possibility: torch.Tensor [3] 17 | :param debug: whether in debug mode 18 | :return: the nparray of trajectory 19 | ''' 20 | # sample accelerations 21 | if possibility is None: 22 | possibility = [0.4, 0.2, 0.4] 23 | 24 | straight_num = int(M * possibility[1]) 25 | left_num = int(M * possibility[0]) 26 | right_num = int(M * possibility[2]) 27 | 28 | accelerations = 10*(np.random.rand(M)-0.5) + 2 # -3m/s^2 to 7m/s^2 29 | 30 | # sample velocities 31 | # randomly sample a velocity <=15m/s at 80% of time 32 | v_options = np.stack((np.full(M, v0), 15*np.random.rand(M))) 33 | v_selections = (np.random.rand(M) >= 
0.2).astype(int) 34 | velocities = v_options[v_selections, np.arange(M)] 35 | 36 | # generate longitudinal distances 37 | L = velocities[:, None] * tt[None, :] + accelerations[:, None] * (tt[None, :]**2) / 2 38 | L_straight = L[:straight_num] 39 | L = L[straight_num:] 40 | # print("L:", L) 41 | 42 | # scaling factor which determine the Clothiod curve 6 ~ 80 43 | alphas = (80 - 6) * np.random.rand(left_num + right_num) + 6 44 | 45 | ############################################################################ 46 | # sample M straight lines 47 | line_points = L_straight[:, :, None] * T0[None, None, :] 48 | line_thetas = np.zeros_like(L_straight) 49 | lines = np.concatenate((line_points, line_thetas[:, :, None]), axis=-1) 50 | 51 | ############################################################################ 52 | # sample M circles 53 | Krappa = min(-0.01, Kappa) if Kappa <= 0 else max(0.01, Kappa) 54 | radius = np.abs(1 / Krappa) 55 | center = np.array([-1 / Krappa, 0]) 56 | circle_phis = L / radius if Krappa >= 0 else np.pi - L/radius 57 | 58 | circle_points = np.dstack([ 59 | center[0] + radius * np.cos(circle_phis), 60 | center[1] + radius * np.sin(circle_phis), 61 | ]) 62 | 63 | # rotate thetas, wrap 64 | circle_thetas = L/radius if Krappa >= 0 else -L/radius 65 | circle_thetas = (circle_thetas + np.pi) % (2 * np.pi) - np.pi 66 | circles = np.concatenate((circle_points, circle_thetas[:, :, None]), axis=-1) 67 | 68 | ############################################################################ 69 | # sample M clothoids 70 | # Xi0 = Kappa / np.pi 71 | # Xis = Xi0 + L 72 | Xi0 = np.abs(Kappa) / np.pi 73 | Xis = Xi0 + L 74 | 75 | # 76 | # Ss, Cs = fresnel((Xis - Xi0) / alphas[:, None]) 77 | Ss, Cs = fresnel(Xis / alphas[:, None]) 78 | 79 | clothoid_points = alphas[:, None, None] * (Cs[:, :, None]*T0[None, None, :] + Ss[:, :, None]*N0[None, None, :]) 80 | 81 | # 82 | Xs = clothoid_points[:, :, 0] - clothoid_points[:, 0, 0, None] 83 | Ys = clothoid_points[:, :, 1] - clothoid_points[:, 0, 1, None] 84 | clothoid_theta0s = 0.5 * np.pi * ((Kappa / np.pi / alphas) ** 2) 85 | clothoid_theta0s = clothoid_theta0s[:, None] 86 | signed_clothoid_theta0s = clothoid_theta0s * np.sign(Kappa) 87 | # when kappa is positive, the clothoid curves left, theta is positive 88 | # we will rotate it clockwise by theta 89 | # when kappa is negative, the clothoid curves right, theta is negative 90 | # we will rotate it counterclockwise by theta 91 | clothoid_points[:, :, 0] = np.cos(signed_clothoid_theta0s) * Xs + np.sin(signed_clothoid_theta0s) * Ys 92 | clothoid_points[:, :, 1] = - np.sin(signed_clothoid_theta0s) * Xs + np.cos(signed_clothoid_theta0s) * Ys 93 | 94 | # tangent vector: http://mathworld.wolfram.com/CornuSpiral.html 95 | clothoid_thetas = 0.5 * np.pi * ((Xis / alphas[:, None])**2) 96 | clothoid_thetas = clothoid_thetas - clothoid_theta0s 97 | signed_clothoid_thetas = clothoid_thetas * np.sign(Kappa) 98 | # clothoid_thetas = clothoid_thetas if Krappa >= 0 else -clothoid_thetas 99 | # wrap 100 | # clothoid_thetas = (clothoid_thetas + np.pi) % (2 * np.pi) - np.pi 101 | wrapped_signed_clothoid_thetas = (signed_clothoid_thetas + np.pi) % (2 * np.pi) - np.pi 102 | # wrapped_signed_clothoid_thetas = (signed_clothoid_thetas) % (2 * np.pi) 103 | # 104 | clothoids = np.concatenate((clothoid_points, wrapped_signed_clothoid_thetas[:, :, None]), axis=-1) 105 | 106 | ############################################################################ 107 | # pick M in total 108 | t_options = np.stack((circles, clothoids)) 109 | 
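# for each curved sample, draw either a circle (index 0, p=0.2) or a clothoid (index 1, p=0.8)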
t_selections = np.random.choice([0, 1], size=left_num + right_num, p=(0.2, 0.8)) 110 | ############################################################################ 111 | 112 | trajs = t_options[t_selections, np.arange(left_num + right_num)] 113 | 114 | # toss a coin for vertical flipping 115 | # left_possibility = possibility[1] / (possibility[1] + possibility[2]) 116 | # if Kappa > 0: 117 | # heads = (np.random.rand(M) <= left_possibility.item()) 118 | # else: 119 | # heads = (np.random.rand(M) <= (1- left_possibility).item()) 120 | # tails = np.logical_not(heads) 121 | # 122 | # # NOTE theta means what here 123 | # conditions = [heads[:, None, None], tails[:, None, None]] 124 | # choices = [trajs, np.dstack(( 125 | # -trajs[:, :, 0], trajs[:, :, 1], -trajs[:, :, 2] 126 | # ))] 127 | # 128 | # trajectories = np.select(conditions, choices) 129 | if Kappa > 0: 130 | left_curve = trajs[: left_num] 131 | right_curve = trajs[left_num: left_num + right_num] 132 | right_curve = np.dstack(( 133 | -right_curve[:, :, 0], right_curve[:, :, 1], -right_curve[:, :, 2] 134 | )) 135 | else: 136 | right_curve = trajs[: left_num] 137 | left_curve = trajs[left_num: left_num + right_num] 138 | left_curve = np.dstack(( 139 | -left_curve[:, :, 0], left_curve[:, :, 1], -left_curve[:, :, 2] 140 | )) 141 | 142 | trajectories = np.concatenate([left_curve, lines, right_curve], axis=0) 143 | mask = np.argsort(trajectories[:, -1, 0]) 144 | trajectories = trajectories[mask] 145 | 146 | return trajectories 147 | 148 | if __name__ == "__main__": 149 | from nuscenes.nuscenes import NuScenes 150 | from nuscenes.can_bus.can_bus_api import NuScenesCanBus 151 | import matplotlib.pyplot as plt 152 | nusc = NuScenes("v1.0-mini", '/home/hsc/data/Nuscenes') 153 | nusc_can = NuScenesCanBus(dataroot='/home/hsc/data/Nuscenes') 154 | 155 | for scene in nusc.scene: 156 | scene_name = scene["name"] 157 | scene_id = int(scene_name[-4:]) 158 | if scene_id in nusc_can.can_blacklist: 159 | print(f"skipping {scene_name}") 160 | continue 161 | pose = nusc_can.get_messages(scene_name, "pose") # The current pose of the ego vehicle, sampled at 50 HZ 162 | saf = nusc_can.get_messages(scene_name, "steeranglefeedback") # Steering angle feedback in radians at 100 HZ 163 | vm = nusc_can.get_messages(scene_name, "vehicle_monitor") # information most, but sample at 2 HZ 164 | # NOTE: I tried to verify if the relevant measurements are consistent 165 | # across multiple tables that contain redundant information 166 | # NOTE: verified pose's velocity matches vehicle monitor's 167 | # but the pose table offers at a much higher frequency 168 | # NOTE: same that steeranglefeedback's steering angle matches vehicle monitor's 169 | # but the steeranglefeedback table offers at a much higher frequency 170 | print(pose[23]) 171 | print(saf[45]) 172 | print(vm[0]) 173 | 174 | # initial velocity (m/s) 175 | v0 = pose[23]["vel"][0] 176 | # curvature 177 | #Kappa = 2 * saf[45]["value"] / 2.588 # 2 x \phi / distance between front and rear 178 | Kappa = 0 179 | # T0: longitudinal axis Tangent vector 180 | T0 = np.array([0.0, 1.0]) 181 | # N0: normal directional vector Normal vector 182 | N0 = np.array([1.0, 0.0]) if Kappa <= 0 else np.array([-1.0, 0.0]) 183 | # tt: time stamps 184 | tt = np.arange(0.0, 3.01, 0.01) 185 | # M: number of samples 186 | M = 1800 187 | # 188 | debug = False 189 | # 190 | trajectories = sample(v0, Kappa, T0, N0, tt, M) 191 | 192 | trajectories = trajectories[:,::100] 193 | 194 | # 195 | for i in range(len(trajectories)): 196 | trajectory = 
trajectories[i] 197 | plt.plot(trajectory[:, 0], trajectory[:, 1]) 198 | plt.grid(False) 199 | plt.axis("equal") 200 | 201 | plt.show() 202 | 203 | break 204 | -------------------------------------------------------------------------------- /sad/models/decoder_snn_ms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models.resnet import resnet18 4 | 5 | from spikingjelly.clock_driven.layer import SeqToANNContainer 6 | from sad.models.module.Real_MS_ResNet import multi_step_sew_resnet18 7 | from sad.layers.convolutions import UpsamplingAdd, SpikingUpsamplingAdd_MS, SpikingDeepLabHead 8 | from spikingjelly.clock_driven.neuron import ( 9 | MultiStepLIFNode, 10 | MultiStepParametricLIFNode, 11 | ) 12 | from spikingjelly.clock_driven import surrogate, layer 13 | 14 | class MeanDim(nn.Module): 15 | def __init__(self, dim=0): 16 | super(MeanDim, self).__init__() 17 | self.dim = dim 18 | 19 | def forward(self, x): 20 | return x.mean(self.dim) 21 | class Decoder(nn.Module): 22 | def __init__(self, in_channels, n_classes, n_present, n_hdmap, predict_gate): 23 | super().__init__() 24 | self.perceive_hdmap = predict_gate['perceive_hdmap'] 25 | self.predict_pedestrian = predict_gate['predict_pedestrian'] 26 | self.predict_instance = predict_gate['predict_instance'] 27 | self.predict_future_flow = predict_gate['predict_future_flow'] 28 | self.planning = predict_gate['planning'] 29 | 30 | self.n_classes = n_classes 31 | self.n_present = n_present 32 | if self.predict_instance is False and self.predict_future_flow is True: 33 | raise ValueError('flow cannot be True when not predicting instance') 34 | 35 | backbone = multi_step_sew_resnet18(pretrained=False, multi_step_neuron=MultiStepLIFNode) 36 | 37 | self.first_conv = SeqToANNContainer(nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)) 38 | self.bn1 = SeqToANNContainer(nn.BatchNorm2d(64)) 39 | self.lif1 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch", surrogate_function=surrogate.ATan()) 40 | 41 | self.layer1 = backbone.layer1 42 | self.layer2 = backbone.layer2 43 | self.layer3 = backbone.layer3 44 | 45 | shared_out_channels = in_channels 46 | self.up3_skip = SpikingUpsamplingAdd_MS(256, 128, scale_factor=2) 47 | #self.up3_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch", surrogate_function=surrogate.ATan()) 48 | self.up2_skip = SpikingUpsamplingAdd_MS(128, 64, scale_factor=2) 49 | #self.up2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch", surrogate_function=surrogate.ATan()) 50 | self.up1_skip = SpikingUpsamplingAdd_MS(64, shared_out_channels, scale_factor=2) 51 | self.up1_lif = MultiStepLIFNode(tau=2.0, v_threshold=1.0, detach_reset=True, backend="torch", surrogate_function=surrogate.ATan()) 52 | 53 | self.segmentation_head = nn.Sequential( 54 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False), 55 | nn.BatchNorm2d(shared_out_channels), 56 | nn.Conv2d(shared_out_channels, self.n_classes, kernel_size=1, padding=0), 57 | ) 58 | 59 | if self.predict_pedestrian: 60 | self.pedestrian_head = nn.Sequential( 61 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False), 62 | nn.BatchNorm2d(shared_out_channels), 63 | nn.Conv2d(shared_out_channels, self.n_classes, kernel_size=1, padding=0), 64 | ) 65 | 66 | if self.perceive_hdmap: 67 | self.hdmap_head = nn.Sequential( 68 | nn.Conv2d(shared_out_channels, 
shared_out_channels, kernel_size=3, padding=1, bias=False), 69 | nn.Conv2d(shared_out_channels, 2 * n_hdmap, kernel_size=1, padding=0), 70 | ) 71 | 72 | if self.predict_instance: 73 | self.instance_offset_head = nn.Sequential( 74 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False), 75 | nn.BatchNorm2d(shared_out_channels), 76 | nn.ReLU(inplace=True), 77 | nn.Conv2d(shared_out_channels, 2, kernel_size=1, padding=0), 78 | ) 79 | self.instance_center_head = nn.Sequential( 80 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False), 81 | nn.BatchNorm2d(shared_out_channels), 82 | nn.ReLU(inplace=True), 83 | nn.Conv2d(shared_out_channels, 1, kernel_size=1, padding=0), 84 | nn.Sigmoid(), 85 | ) 86 | 87 | if self.predict_future_flow: 88 | self.instance_future_head = nn.Sequential( 89 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False), 90 | nn.BatchNorm2d(shared_out_channels), 91 | nn.ReLU(inplace=True), 92 | nn.Conv2d(shared_out_channels, 2, kernel_size=1, padding=0), 93 | ) 94 | 95 | if self.planning: 96 | self.costvolume_head = nn.Sequential( 97 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False), 98 | nn.BatchNorm2d(shared_out_channels), 99 | nn.Conv2d(shared_out_channels, 1, kernel_size=1, padding=0), 100 | ) 101 | 102 | 103 | def forward(self, x): 104 | b, s, c, h, w = x.shape 105 | x = x.reshape(s, b, c, h, w) 106 | #x = x.view(b * s, c, h, w) 107 | # (H, W) 108 | skip_x = {'1': x} 109 | x = self.up1_lif(x) 110 | 111 | 112 | # (H/2, W/2) 113 | x = self.first_conv(x) 114 | x = self.bn1(x) 115 | # _, c_1, h_1, w_1 = x.shape 116 | # x = x.reshape(s, b, c_1, h_1, w_1) 117 | x = self.layer1(x) 118 | skip_x['2'] = x 119 | 120 | # (H/4 , W/4) 121 | x = self.layer2(x) 122 | skip_x['3'] = x 123 | 124 | # (H/8, W/8) 125 | x = self.layer3(x) # (b*s, 256, 25, 25) 126 | 127 | # First upsample to (H/4, W/4) 128 | #x = x.view(b*s, *x.shape[2:]) 129 | x = self.up3_skip(x, skip_x['3']) 130 | #x = x.reshape(s, b, *x.shape[1:]) 131 | #x = self.up3_lif(x) 132 | 133 | # Second upsample to (H/2, W/2) 134 | #x = x.view(b*s, *x.shape[2:]) 135 | x = self.up2_skip(x, skip_x['2']) 136 | # x = x.reshape(s, b, *x.shape[1:]) 137 | # x = self.up2_lif(x) 138 | 139 | # Third upsample to (H, W) 140 | # x = x.view(b*s, *x.shape[2:]) 141 | x = self.up1_skip(x, skip_x['1']) 142 | # x = x.reshape(s, b, *x.shape[1:]) 143 | x = self.up1_lif(x) 144 | x = x.view(b*s, *x.shape[2:]) 145 | 146 | segmentation_output = self.segmentation_head(x) 147 | pedestrian_output = self.pedestrian_head(x) if self.predict_pedestrian else None 148 | hdmap_output = self.hdmap_head( 149 | x.view(b, s, *x.shape[1:])[:, self.n_present - 1]) if self.perceive_hdmap else None 150 | instance_center_output = self.instance_center_head(x) if self.predict_instance else None 151 | instance_offset_output = self.instance_offset_head(x) if self.predict_instance else None 152 | instance_future_output = self.instance_future_head(x) if self.predict_future_flow else None 153 | costvolume = self.costvolume_head(x).squeeze(1) if self.planning else None 154 | return { 155 | 'segmentation': segmentation_output.view(b, s, *segmentation_output.shape[1:]), 156 | 'pedestrian': pedestrian_output.view(b, s, *pedestrian_output.shape[1:]) 157 | if pedestrian_output is not None else None, 158 | 'hdmap': hdmap_output, 159 | 'instance_center': instance_center_output.view(b, s, *instance_center_output.shape[1:]) 160 | if 
instance_center_output is not None else None, 161 | 'instance_offset': instance_offset_output.view(b, s, *instance_offset_output.shape[1:]) 162 | if instance_offset_output is not None else None, 163 | 'instance_flow': instance_future_output.view(b, s, *instance_future_output.shape[1:]) 164 | if instance_future_output is not None else None, 165 | 'costvolume': costvolume.view(b, s, *costvolume.shape[1:]) 166 | if costvolume is not None else None, 167 | } 168 | 169 | 170 | # Define the Decoder class 171 | # Note: the Decoder class definition is not repeated here, as it is already given above. 172 | # Make sure the complete Decoder class definition is included in this script, or imported appropriately. 173 | 174 | def main(): 175 | # Assumed configuration parameters 176 | in_channels = 3 # number of input image channels 177 | n_classes = 10 # number of classes 178 | n_present = 5 # assumed number of present frames 179 | n_hdmap = 2 # number of HD map output channels 180 | predict_gate = { 181 | 'perceive_hdmap': True, 182 | 'predict_pedestrian': True, 183 | 'predict_instance': True, 184 | 'predict_future_flow': True, 185 | 'planning': True 186 | } 187 | 188 | # Create a Decoder instance 189 | decoder = Decoder(in_channels, n_classes, n_present, n_hdmap, predict_gate) 190 | 191 | # Generate a dummy input tensor, assuming batch_size=1, sequence_length=5, height=256, width=256 192 | x = torch.randn(1, 5, in_channels, 256, 256) 193 | 194 | # Forward pass 195 | output = decoder(x) 196 | 197 | # Print some information about the outputs 198 | print("Output keys:", output.keys()) 199 | for key, value in output.items(): 200 | if value is not None: 201 | print(f"{key}: shape {value.shape}") 202 | else: 203 | print(f"{key}: None") 204 | 205 | if __name__ == "__main__": 206 | main() 207 | -------------------------------------------------------------------------------- /sad/models/module/.ipynb_checkpoints/MS_ResNet-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from spikingjelly.clock_driven import functional 4 | try: 5 | from torchvision.models.utils import load_state_dict_from_url 6 | except ImportError: 7 | from torchvision._internally_replaced_utils import load_state_dict_from_url 8 | 9 | __all__ = ['MultiStepMSResNet', 'multi_step_sew_resnet18'] 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 14 | } 15 | 16 | # modified from https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py 17 | 18 | def sew_function(x: torch.Tensor, y: torch.Tensor, cnf:str): 19 | if cnf == 'ADD': 20 | return x + y 21 | elif cnf == 'AND': 22 | return x * y 23 | elif cnf == 'IAND': 24 | return x * (1.
- y) 25 | else: 26 | raise NotImplementedError 27 | 28 | 29 | 30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 31 | """3x3 convolution with padding""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 33 | padding=dilation, groups=groups, bias=False, dilation=dilation) 34 | 35 | 36 | def conv1x1(in_planes, out_planes, stride=1): 37 | """1x1 convolution""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 39 | class BasicBlock(nn.Module): 40 | expansion = 1 41 | 42 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 43 | base_width=64, dilation=1, norm_layer=None, cnf: str = None, single_step_neuron: callable = None, **kwargs): 44 | super(BasicBlock, self).__init__() 45 | if norm_layer is None: 46 | norm_layer = nn.BatchNorm2d 47 | if groups != 1 or base_width != 64: 48 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 49 | if dilation > 1: 50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 51 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 52 | self.conv1 = conv3x3(inplanes, planes, stride) 53 | self.bn1 = norm_layer(planes) 54 | self.sn1 = single_step_neuron(**kwargs) 55 | self.conv2 = conv3x3(planes, planes) 56 | self.bn2 = norm_layer(planes) 57 | self.sn2 = single_step_neuron(**kwargs) 58 | self.downsample = downsample 59 | if downsample is not None: 60 | self.downsample_sn = single_step_neuron(**kwargs) 61 | self.stride = stride 62 | self.cnf = cnf 63 | 64 | def forward(self, x): 65 | identity = x 66 | 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | out = self.sn1(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | out = self.sn2(out) 74 | 75 | if self.downsample is not None: 76 | identity = self.downsample_sn(self.downsample(x)) 77 | 78 | out = sew_function(identity, out, self.cnf) 79 | 80 | return out 81 | 82 | def extra_repr(self) -> str: 83 | return super().extra_repr() + f'cnf={self.cnf}' 84 | 85 | 86 | 87 | class MultiStepBasicBlock(BasicBlock): 88 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 89 | base_width=64, dilation=1, norm_layer=None, cnf: str = None, multi_step_neuron: callable = None, **kwargs): 90 | super().__init__(inplanes, planes, stride, downsample, groups, 91 | base_width, dilation, norm_layer, cnf, multi_step_neuron, **kwargs) 92 | 93 | def forward(self, x_seq): 94 | identity = x_seq 95 | 96 | out = self.sn1(x_seq) 97 | out = functional.seq_to_ann_forward(x_seq, [self.conv1, self.bn1]) 98 | 99 | out = self.sn2(out) 100 | out = functional.seq_to_ann_forward(out, [self.conv2, self.bn2]) 101 | 102 | 103 | if self.downsample is not None: 104 | identity = functional.seq_to_ann_forward(x_seq, self.downsample) 105 | 106 | out = identity + out 107 | 108 | return out 109 | 110 | 111 | class MultiStepMSResNet(nn.Module): 112 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 113 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 114 | norm_layer=None, T:int=None, cnf: str=None, multi_step_neuron: callable = None, **kwargs): 115 | super().__init__() 116 | self.T = T 117 | if norm_layer is None: 118 | norm_layer = nn.BatchNorm2d 119 | self._norm_layer = norm_layer 120 | 121 | self.inplanes = 64 122 | self.dilation = 1 123 | if replace_stride_with_dilation is None: 124 | # each element in the tuple indicates if we should replace 125 | # the 2x2 stride with a dilated convolution instead 126 
| replace_stride_with_dilation = [False, False, False] 127 | if len(replace_stride_with_dilation) != 3: 128 | raise ValueError("replace_stride_with_dilation should be None " 129 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 130 | self.groups = groups 131 | self.base_width = width_per_group 132 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 133 | bias=False) 134 | self.bn1 = norm_layer(self.inplanes) 135 | self.sn1 = multi_step_neuron(**kwargs) 136 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 137 | self.layer1 = self._make_layer(block, 64, layers[0], cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs) 138 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 139 | dilate=replace_stride_with_dilation[0], cnf=cnf, multi_step_neuron=multi_step_neuron, 140 | **kwargs) 141 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 142 | dilate=replace_stride_with_dilation[1], cnf=cnf, multi_step_neuron=multi_step_neuron, 143 | **kwargs) 144 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 145 | dilate=replace_stride_with_dilation[2], cnf=cnf, multi_step_neuron=multi_step_neuron, 146 | **kwargs) 147 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 148 | self.fc = nn.Linear(512 * block.expansion, num_classes) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 154 | nn.init.constant_(m.weight, 1) 155 | nn.init.constant_(m.bias, 0) 156 | 157 | # Zero-initialize the last BN in each residual branch, 158 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 159 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 160 | if zero_init_residual: 161 | for m in self.modules(): 162 | if isinstance(m, Bottleneck): 163 | nn.init.constant_(m.bn3.weight, 0) 164 | elif isinstance(m, BasicBlock): 165 | nn.init.constant_(m.bn2.weight, 0) 166 | 167 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False, cnf: str = None, multi_step_neuron: callable = None, **kwargs): 168 | norm_layer = self._norm_layer 169 | downsample = None 170 | previous_dilation = self.dilation 171 | if dilate: 172 | self.dilation *= stride 173 | stride = 1 174 | if stride != 1 or self.inplanes != planes * block.expansion: 175 | downsample = nn.Sequential( 176 | conv1x1(self.inplanes, planes * block.expansion, stride), 177 | norm_layer(planes * block.expansion), 178 | ) 179 | 180 | layers = [] 181 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 182 | self.base_width, previous_dilation, norm_layer, cnf, multi_step_neuron, **kwargs)) 183 | self.inplanes = planes * block.expansion 184 | for _ in range(1, blocks): 185 | layers.append(block(self.inplanes, planes, groups=self.groups, 186 | base_width=self.base_width, dilation=self.dilation, 187 | norm_layer=norm_layer, cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs)) 188 | 189 | return nn.Sequential(*layers) 190 | 191 | def _forward_impl(self, x: torch.Tensor): 192 | # See note [TorchScript super()] 193 | x_seq = None 194 | if x.dim() == 5: 195 | # x.shape = [T, N, C, H, W] 196 | x_seq = functional.seq_to_ann_forward(x, [self.conv1, self.bn1]) 197 | else: 198 | assert self.T is not None, 'When x.shape is [N, C, H, W], self.T can not be None.' 
199 | # x.shape = [N, C, H, W] 200 | x = self.conv1(x) 201 | x = self.bn1(x) 202 | x.unsqueeze_(0) 203 | x_seq = x.repeat(self.T, 1, 1, 1, 1) 204 | 205 | x_seq = functional.seq_to_ann_forward(x_seq, self.maxpool) 206 | 207 | x_seq = self.layer1(x_seq) 208 | x_seq = self.layer2(x_seq) 209 | x_seq = self.layer3(x_seq) 210 | x_seq = self.layer4(x_seq) 211 | 212 | x_seq = functional.seq_to_ann_forward(x_seq, self.avgpool) 213 | x_seq = self.sn1(x_seq) 214 | x_seq = torch.flatten(x_seq, 2) 215 | # x_seq = self.fc(x_seq.mean(0)) 216 | x_seq = functional.seq_to_ann_forward(x_seq, self.fc) 217 | 218 | return x_seq 219 | 220 | def forward(self, x): 221 | """ 222 | :param x: the input with `shape=[N, C, H, W]` or `[*, N, C, H, W]` 223 | :type x: torch.Tensor 224 | :return: output 225 | :rtype: torch.Tensor 226 | """ 227 | return self._forward_impl(x) 228 | 229 | 230 | 231 | 232 | def _multi_step_sew_resnet(arch, block, layers, pretrained, progress, T, cnf, multi_step_neuron, **kwargs): 233 | model = MultiStepMSResNet(block, layers, T=T, cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs) 234 | if pretrained: 235 | state_dict = load_state_dict_from_url(model_urls[arch], 236 | progress=progress) 237 | model.load_state_dict(state_dict) 238 | return model 239 | 240 | def multi_step_sew_resnet18(pretrained=False, progress=True, T: int = None, cnf: str = None, multi_step_neuron: callable=None, **kwargs): 241 | """ 242 | :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet 243 | :type pretrained: bool 244 | :param progress: If True, displays a progress bar of the download to stderr 245 | :type progress: bool 246 | :param T: total time-steps 247 | :type T: int 248 | :param cnf: the name of spike-element-wise function 249 | :type cnf: str 250 | :param multi_step_neuron: a multi-step neuron 251 | :type multi_step_neuron: callable 252 | :param kwargs: kwargs for `multi_step_neuron` 253 | :type kwargs: dict 254 | :return: Spiking ResNet-18 255 | :rtype: torch.nn.Module 256 | 257 | The multi-step spike-element-wise ResNet-18 `"Deep Residual Learning in Spiking Neural Networks" `_ 258 | modified by the ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_ 259 | """ 260 | 261 | return _multi_step_sew_resnet('resnet18', MultiStepBasicBlock, [2, 2, 2, 2], pretrained, progress, T, cnf, multi_step_neuron, **kwargs) 262 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from PIL import Image 3 | import torch 4 | import torch.utils.data 5 | import numpy as np 6 | import torchvision 7 | from tqdm import tqdm 8 | from nuscenes.nuscenes import NuScenes 9 | import matplotlib 10 | from matplotlib import pyplot as plt 11 | import pathlib 12 | import datetime 13 | 14 | from sad.datas.NuscenesData import FuturePredictionDataset 15 | from sad.trainer import TrainingModule 16 | from sad.metrics import IntersectionOverUnion, PanopticMetric, PlanningMetric 17 | from sad.utils.network import preprocess_batch, NormalizeInverse 18 | from sad.utils.instance import predict_instance_segmentation_and_trajectories 19 | from sad.utils.visualisation import make_contour 20 | 21 | def mk_save_dir(): 22 | now = datetime.datetime.now() 23 | string = '_'.join(map(lambda x: '%02d' % x, (now.month, now.day, now.hour, now.minute, now.second))) 24 | save_path = pathlib.Path('imgs') / string 25 | 
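# timestamped output directory, e.g. imgs/MM_DD_HH_MM_SS; mkdir raises if it already exists (exist_ok=False)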
save_path.mkdir(parents=True, exist_ok=False) 26 | return save_path 27 | 28 | def eval(checkpoint_path, dataroot): 29 | save_path = mk_save_dir() 30 | 31 | trainer = TrainingModule.load_from_checkpoint(checkpoint_path, strict=True) 32 | print(f'Loaded weights from \n {checkpoint_path}') 33 | trainer.eval() 34 | 35 | device = torch.device('cuda:0') 36 | trainer.to(device) 37 | model = trainer.model 38 | 39 | cfg = model.cfg 40 | cfg.GPUS = "[0]" 41 | cfg.BATCHSIZE = 1 42 | cfg.LIFT.GT_DEPTH = False 43 | cfg.DATASET.DATAROOT = dataroot 44 | cfg.DATASET.MAP_FOLDER = dataroot 45 | 46 | dataroot = cfg.DATASET.DATAROOT 47 | nworkers = cfg.N_WORKERS 48 | nusc = NuScenes(version='v1.0-{}'.format(cfg.DATASET.VERSION), dataroot=dataroot, verbose=False) 49 | valdata = FuturePredictionDataset(nusc, 1, cfg) 50 | valloader = torch.utils.data.DataLoader( 51 | valdata, batch_size=cfg.BATCHSIZE, shuffle=False, num_workers=nworkers, pin_memory=True, drop_last=False 52 | ) 53 | 54 | n_classes = len(cfg.SEMANTIC_SEG.VEHICLE.WEIGHTS) 55 | hdmap_class = cfg.SEMANTIC_SEG.HDMAP.ELEMENTS 56 | metric_vehicle_val = IntersectionOverUnion(n_classes).to(device) 57 | future_second = int(cfg.N_FUTURE_FRAMES / 2) 58 | 59 | if cfg.SEMANTIC_SEG.PEDESTRIAN.ENABLED: 60 | metric_pedestrian_val = IntersectionOverUnion(n_classes).to(device) 61 | 62 | if cfg.SEMANTIC_SEG.HDMAP.ENABLED: 63 | metric_hdmap_val = [] 64 | for i in range(len(hdmap_class)): 65 | metric_hdmap_val.append(IntersectionOverUnion(2, absent_score=1).to(device)) 66 | 67 | if cfg.INSTANCE_SEG.ENABLED: 68 | metric_panoptic_val = PanopticMetric(n_classes=n_classes).to(device) 69 | 70 | if cfg.PLANNING.ENABLED: 71 | metric_planning_val = [] 72 | for i in range(future_second): 73 | metric_planning_val.append(PlanningMetric(cfg, 2*(i+1)).to(device)) 74 | 75 | 76 | for index, batch in enumerate(tqdm(valloader)): 77 | preprocess_batch(batch, device) 78 | image = batch['image'] 79 | intrinsics = batch['intrinsics'] 80 | extrinsics = batch['extrinsics'] 81 | future_egomotion = batch['future_egomotion'] 82 | command = batch['command'] 83 | trajs = batch['sample_trajectory'] 84 | target_points = batch['target_point'] 85 | B = len(image) 86 | labels = trainer.prepare_future_labels(batch) 87 | 88 | with torch.no_grad(): 89 | output = model( 90 | image, intrinsics, extrinsics, future_egomotion 91 | ) 92 | 93 | n_present = model.receptive_field 94 | 95 | # semantic segmentation metric 96 | seg_prediction = output['segmentation'].detach() 97 | seg_prediction = torch.argmax(seg_prediction, dim=2, keepdim=True) 98 | metric_vehicle_val(seg_prediction[:, n_present - 1:], labels['segmentation'][:, n_present - 1:]) 99 | 100 | if cfg.SEMANTIC_SEG.PEDESTRIAN.ENABLED: 101 | pedestrian_prediction = output['pedestrian'].detach() 102 | pedestrian_prediction = torch.argmax(pedestrian_prediction, dim=2, keepdim=True) 103 | metric_pedestrian_val(pedestrian_prediction[:, n_present - 1:], 104 | labels['pedestrian'][:, n_present - 1:]) 105 | else: 106 | pedestrian_prediction = torch.zeros_like(seg_prediction) 107 | 108 | if cfg.SEMANTIC_SEG.HDMAP.ENABLED: 109 | for i in range(len(hdmap_class)): 110 | hdmap_prediction = output['hdmap'][:, 2 * i:2 * (i + 1)].detach() 111 | hdmap_prediction = torch.argmax(hdmap_prediction, dim=1, keepdim=True) 112 | metric_hdmap_val[i](hdmap_prediction, labels['hdmap'][:, i:i + 1]) 113 | 114 | if cfg.INSTANCE_SEG.ENABLED: 115 | pred_consistent_instance_seg = predict_instance_segmentation_and_trajectories( 116 | output, compute_matched_centers=False, 
make_consistent=True 117 | ) 118 | metric_panoptic_val(pred_consistent_instance_seg[:, n_present - 1:], 119 | labels['instance'][:, n_present - 1:]) 120 | 121 | if cfg.PLANNING.ENABLED: 122 | occupancy = torch.logical_or(seg_prediction, pedestrian_prediction) 123 | _, final_traj = model.planning( 124 | cam_front=output['cam_front'].detach(), 125 | trajs=trajs[:, :, 1:], 126 | gt_trajs=labels['gt_trajectory'][:, 1:], 127 | cost_volume=output['costvolume'][:, n_present:].detach(), 128 | semantic_pred=occupancy[:, n_present:].squeeze(2), 129 | hd_map=output['hdmap'].detach(), 130 | commands=command, 131 | target_points=target_points 132 | ) 133 | occupancy = torch.logical_or(labels['segmentation'][:, n_present:].squeeze(2), 134 | labels['pedestrian'][:, n_present:].squeeze(2)) 135 | for i in range(future_second): 136 | cur_time = (i+1)*2 137 | metric_planning_val[i](final_traj[:,:cur_time].detach(), labels['gt_trajectory'][:,1:cur_time+1], occupancy[:,:cur_time]) 138 | 139 | if index % 100 == 0: 140 | save(output, labels, batch, n_present, index, save_path) 141 | 142 | 143 | results = {} 144 | 145 | scores = metric_vehicle_val.compute() 146 | results['vehicle_iou'] = scores[1] 147 | 148 | if cfg.SEMANTIC_SEG.PEDESTRIAN.ENABLED: 149 | scores = metric_pedestrian_val.compute() 150 | results['pedestrian_iou'] = scores[1] 151 | 152 | if cfg.SEMANTIC_SEG.HDMAP.ENABLED: 153 | for i, name in enumerate(hdmap_class): 154 | scores = metric_hdmap_val[i].compute() 155 | results[name + '_iou'] = scores[1] 156 | 157 | if cfg.INSTANCE_SEG.ENABLED: 158 | scores = metric_panoptic_val.compute() 159 | for key, value in scores.items(): 160 | results['vehicle_'+key] = value[1] 161 | 162 | if cfg.PLANNING.ENABLED: 163 | for i in range(future_second): 164 | scores = metric_planning_val[i].compute() 165 | for key, value in scores.items(): 166 | results['plan_'+key+'_{}s'.format(i+1)]=value.mean() 167 | 168 | for key, value in results.items(): 169 | print(f'{key} : {value.item()}') 170 | 171 | def save(output, labels, batch, n_present, frame, save_path): 172 | hdmap = output['hdmap'].detach() 173 | segmentation = output['segmentation'][:, n_present - 1].detach() 174 | pedestrian = output['pedestrian'][:, n_present - 1].detach() 175 | gt_trajs = labels['gt_trajectory'] 176 | images = batch['image'] 177 | 178 | denormalise_img = torchvision.transforms.Compose( 179 | (NormalizeInverse(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 180 | torchvision.transforms.ToPILImage(),) 181 | ) 182 | 183 | val_w = 2.99 184 | val_h = 2.99 * (224. / 480.) 
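# figure layout: 2x4 grid, six surround-view camera images in the left three columns (back views flipped left-right), BEV visualisation spanning the right column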
185 | plt.figure(1, figsize=(4*val_w,2*val_h)) 186 | width_ratios = (val_w,val_w,val_w,val_w) 187 | gs = matplotlib.gridspec.GridSpec(2, 4, width_ratios=width_ratios) 188 | gs.update(wspace=0.0, hspace=0.0, left=0.0, right=1.0, top=1.0, bottom=0.0) 189 | 190 | plt.subplot(gs[0, 0]) 191 | #plt.annotate('FRONT LEFT', (0.01, 0.87), c='white', xycoords='axes fraction', fontsize=14) 192 | plt.imshow(denormalise_img(images[0,n_present-1,0].cpu())) 193 | plt.axis('off') 194 | 195 | plt.subplot(gs[0, 1]) 196 | #plt.annotate('FRONT', (0.01, 0.87), c='white', xycoords='axes fraction', fontsize=14) 197 | plt.imshow(denormalise_img(images[0,n_present-1,1].cpu())) 198 | plt.axis('off') 199 | 200 | plt.subplot(gs[0, 2]) 201 | #plt.annotate('FRONT RIGHT', (0.01, 0.87), c='white', xycoords='axes fraction', fontsize=14) 202 | plt.imshow(denormalise_img(images[0,n_present-1,2].cpu())) 203 | plt.axis('off') 204 | 205 | plt.subplot(gs[1, 0]) 206 | #plt.annotate('BACK LEFT', (0.01, 0.87), c='white', xycoords='axes fraction', fontsize=14) 207 | showing = denormalise_img(images[0,n_present-1,3].cpu()) 208 | showing = showing.transpose(Image.FLIP_LEFT_RIGHT) 209 | plt.imshow(showing) 210 | plt.axis('off') 211 | 212 | plt.subplot(gs[1, 1]) 213 | #plt.annotate('BACK', (0.01, 0.87), c='white', xycoords='axes fraction', fontsize=14) 214 | showing = denormalise_img(images[0, n_present - 1, 4].cpu()) 215 | showing = showing.transpose(Image.FLIP_LEFT_RIGHT) 216 | plt.imshow(showing) 217 | plt.axis('off') 218 | 219 | plt.subplot(gs[1, 2]) 220 | #plt.annotate('BACK_RIGHT', (0.01, 0.87), c='white', xycoords='axes fraction', fontsize=14) 221 | showing = denormalise_img(images[0, n_present - 1, 5].cpu()) 222 | showing = showing.transpose(Image.FLIP_LEFT_RIGHT) 223 | plt.imshow(showing) 224 | plt.axis('off') 225 | 226 | plt.subplot(gs[:, 3]) 227 | showing = torch.zeros((200, 200, 3)).numpy() 228 | showing[:, :] = np.array([219 / 255, 215 / 255, 215 / 255]) 229 | 230 | # drivable 231 | area = torch.argmax(hdmap[0, 2:4], dim=0).cpu().numpy() 232 | hdmap_index = area > 0 233 | showing[hdmap_index] = np.array([161 / 255, 158 / 255, 158 / 255]) 234 | 235 | # lane 236 | area = torch.argmax(hdmap[0, 0:2], dim=0).cpu().numpy() 237 | hdmap_index = area > 0 238 | showing[hdmap_index] = np.array([84 / 255, 70 / 255, 70 / 255]) 239 | 240 | # semantic 241 | semantic_seg = torch.argmax(segmentation[0], dim=0).cpu().numpy() 242 | semantic_index = semantic_seg > 0 243 | showing[semantic_index] = np.array([255 / 255, 128 / 255, 0 / 255]) 244 | 245 | pedestrian_seg = torch.argmax(pedestrian[0], dim=0).cpu().numpy() 246 | pedestrian_index = pedestrian_seg > 0 247 | showing[pedestrian_index] = np.array([28 / 255, 81 / 255, 227 / 255]) 248 | 249 | plt.imshow(make_contour(showing)) 250 | plt.axis('off') 251 | 252 | bx = np.array([-50.0 + 0.5/2.0, -50.0 + 0.5/2.0]) 253 | dx = np.array([0.5, 0.5]) 254 | w, h = 1.85, 4.084 255 | pts = np.array([ 256 | [-h / 2. + 0.5, w / 2.], 257 | [h / 2. + 0.5, w / 2.], 258 | [h / 2. + 0.5, -w / 2.], 259 | [-h / 2. 
+ 0.5, -w / 2.], 260 | ]) 261 | pts = (pts - bx) / dx 262 | pts[:, [0, 1]] = pts[:, [1, 0]] 263 | plt.fill(pts[:, 0], pts[:, 1], '#76b900') 264 | 265 | plt.xlim((200, 0)) 266 | plt.ylim((0, 200)) 267 | gt_trajs[0, :, :1] = gt_trajs[0, :, :1] * -1 268 | gt_trajs = (gt_trajs[0, :, :2].cpu().numpy() - bx) / dx 269 | plt.plot(gt_trajs[:, 0], gt_trajs[:, 1], linewidth=3.0) 270 | 271 | plt.savefig(save_path / ('%04d.png' % frame)) 272 | plt.close() 273 | 274 | if __name__ == '__main__': 275 | parser = ArgumentParser(description='sad evaluation') 276 | parser.add_argument('--checkpoint', default='last.ckpt', type=str, help='path to checkpoint') 277 | parser.add_argument('--dataroot', default=None, type=str) 278 | 279 | args = parser.parse_args() 280 | 281 | eval(args.checkpoint, args.dataroot) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /sad/models/module/ms_conv.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from timm.models.layers import DropPath 4 | from spikingjelly.clock_driven.neuron import ( 5 | MultiStepLIFNode, 6 | MultiStepParametricLIFNode, 7 | ) 8 | 9 | 10 | class Erode(nn.Module): 11 | def __init__(self) -> None: 12 | super().__init__() 13 | self.pool = nn.MaxPool3d( 14 | kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1) 15 | ) 16 | 17 | def forward(self, x): 18 | return self.pool(x) 19 | 20 | class SpatialGatingUnit(nn.Module): 21 | def __init__(self, d_ffn, seq_len, spike_mode): 22 | super().__init__() 23 | d_ffn = int(d_ffn/2) 24 | self.norm = nn.BatchNorm1d(seq_len) 25 | self.spatial_proj = nn.Conv1d(seq_len, seq_len, kernel_size=1) 26 | self.fc1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch") 27 | #self.local_proj = nn.Conv1d(d_ffn, d_ffn, kernel_size=4, groups=d_ffn, padding='same') 28 | nn.init.constant_(self.spatial_proj.bias, 1.0) 29 | #nn.init.constant_(self.local_proj.bias, 1.0) 30 | 31 | 32 | 33 | def forward(self, x): 34 | T, B, C, L = x.shape 35 | 36 | 37 | L = int(L/2) 38 | u, v = x.chunk(2, dim=-1) 39 | #u = self.local_proj(u.reshape(T*B, L, C)) 40 | 41 | #u = self.spatial_proj(u.flatten(0, 1)) 42 | #u = self.norm(u.flatten(0, 1)) 43 | #u = self.fc1_lif(u.reshape(T, B, C, L)) 44 | out = u * v 45 | return out 46 | class MS_MLP_Conv(nn.Module): 47 | def __init__( 48 | self, 49 | in_features, 50 | hidden_features=None, 51 | out_features=None, 52 | drop=0.0, 53 | spike_mode="lif", 54 | layer=0, 55 | ): 56 | super().__init__() 57 | out_features = out_features or in_features 58 | hidden_features = hidden_features or in_features 59 | self.res = in_features == hidden_features 60 | self.fc1_conv = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1) 61 | self.fc1_bn = nn.BatchNorm2d(hidden_features) 62 | self.sgu = SpatialGatingUnit(hidden_features, 64, spike_mode) 63 | if spike_mode == "lif": 64 | self.fc1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch") 65 | elif spike_mode == "plif": 66 | self.fc1_lif = MultiStepParametricLIFNode( 67 | init_tau=2.0, detach_reset=True, backend="torch" 68 | ) 69 | 70 | self.fc2_conv = nn.Conv2d( 71 | int(hidden_features/2), out_features, kernel_size=1, stride=1 72 | ) 73 | self.fc2_bn = nn.BatchNorm2d(out_features) 74 | if spike_mode == "lif": 75 | self.fc2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch") 76 | elif spike_mode == "plif": 77 | self.fc2_lif = MultiStepParametricLIFNode( 78 | init_tau=2.0, detach_reset=True, backend="torch" 79 | ) 80 | 81 | self.c_hidden = hidden_features 82 | self.c_output = out_features 83 | self.layer = layer 84 | 85 | def forward(self, x, hook=None): 86 | T, B, C, H, W = x.shape 87 | identity = x 88 | 89 | x = self.fc1_lif(x) 90 | if hook is not None: 91 | hook[self._get_name() + str(self.layer) + "_fc1_lif"] = x.detach() 92 | x = self.fc1_conv(x.flatten(0, 1)) 93 | x = self.fc1_bn(x).reshape(T, B, self.c_hidden, H, W).contiguous() 94 | if self.res: 95 | x = identity + x 96 | identity = x 97 | x = x.reshape(T, B, self.c_hidden, H*W).permute(0, 1, 3, 2) 98 | 99 | x = self.fc2_lif(x) 100 | x = self.sgu(x).permute(0, 1, 3, 2).reshape(T, B, int(self.c_hidden/2), H, W) 101 | if hook is not None: 102 | hook[self._get_name() + str(self.layer) + "_fc2_lif"] = x.detach() 103 | x = 
self.fc2_conv(x.flatten(0, 1)) 104 | x = self.fc2_bn(x).reshape(T, B, C, H, W).contiguous() 105 | 106 | x = x + identity 107 | return x, hook 108 | 109 | 110 | class MS_SSA_Conv(nn.Module): 111 | def __init__( 112 | self, 113 | dim, 114 | num_heads=8, 115 | qkv_bias=False, 116 | qk_scale=None, 117 | attn_drop=0.0, 118 | proj_drop=0.0, 119 | sr_ratio=1, 120 | mode="direct_xor", 121 | spike_mode="lif", 122 | dvs=False, 123 | layer=0, 124 | ): 125 | super().__init__() 126 | assert ( 127 | dim % num_heads == 0 128 | ), f"dim {dim} should be divided by num_heads {num_heads}." 129 | self.dim = dim 130 | self.dvs = dvs 131 | self.num_heads = num_heads 132 | if dvs: 133 | self.pool = Erode() 134 | self.scale = 0.125 135 | self.q_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=False) 136 | self.q_bn = nn.BatchNorm2d(dim) 137 | if spike_mode == "lif": 138 | self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch") 139 | elif spike_mode == "plif": 140 | self.q_lif = MultiStepParametricLIFNode( 141 | init_tau=2.0, detach_reset=True, backend="torch" 142 | ) 143 | 144 | self.k_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=False) 145 | self.k_bn = nn.BatchNorm2d(dim) 146 | if spike_mode == "lif": 147 | self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch") 148 | elif spike_mode == "plif": 149 | self.k_lif = MultiStepParametricLIFNode( 150 | init_tau=2.0, detach_reset=True, backend="torch" 151 | ) 152 | 153 | self.v_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=False) 154 | self.v_bn = nn.BatchNorm2d(dim) 155 | if spike_mode == "lif": 156 | self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch") 157 | elif spike_mode == "plif": 158 | self.v_lif = MultiStepParametricLIFNode( 159 | init_tau=2.0, detach_reset=True, backend="torch" 160 | ) 161 | 162 | if spike_mode == "lif": 163 | self.attn_lif = MultiStepLIFNode( 164 | tau=2.0, v_threshold=0.5, detach_reset=True, backend="torch" 165 | ) 166 | elif spike_mode == "plif": 167 | self.attn_lif = MultiStepParametricLIFNode( 168 | init_tau=2.0, v_threshold=0.5, detach_reset=True, backend="torch" 169 | ) 170 | 171 | self.talking_heads = nn.Conv1d( 172 | num_heads, num_heads, kernel_size=1, stride=1, bias=False 173 | ) 174 | if spike_mode == "lif": 175 | self.talking_heads_lif = MultiStepLIFNode( 176 | tau=2.0, v_threshold=0.5, detach_reset=True, backend="torch" 177 | ) 178 | elif spike_mode == "plif": 179 | self.talking_heads_lif = MultiStepParametricLIFNode( 180 | init_tau=2.0, v_threshold=0.5, detach_reset=True, backend="torch" 181 | ) 182 | 183 | self.proj_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1) 184 | self.proj_bn = nn.BatchNorm2d(dim) 185 | 186 | if spike_mode == "lif": 187 | self.shortcut_lif = MultiStepLIFNode( 188 | tau=2.0, detach_reset=True, backend="torch" 189 | ) 190 | elif spike_mode == "plif": 191 | self.shortcut_lif = MultiStepParametricLIFNode( 192 | init_tau=2.0, detach_reset=True, backend="torch" 193 | ) 194 | 195 | self.mode = mode 196 | self.layer = layer 197 | 198 | def forward(self, x, hook=None): 199 | T, B, C, H, W = x.shape 200 | identity = x 201 | N = H * W 202 | x = self.shortcut_lif(x) 203 | if hook is not None: 204 | hook[self._get_name() + str(self.layer) + "_first_lif"] = x.detach() 205 | 206 | x_for_qkv = x.flatten(0, 1) 207 | q_conv_out = self.q_conv(x_for_qkv) 208 | q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, H, W).contiguous() 209 | q_conv_out = self.q_lif(q_conv_out) 210 | 211 | if hook is not None: 212 | hook[self._get_name() 
+ str(self.layer) + "_q_lif"] = q_conv_out.detach() 213 | q = ( 214 | q_conv_out.flatten(3) 215 | .transpose(-1, -2) 216 | .reshape(T, B, N, self.num_heads, C // self.num_heads) 217 | .permute(0, 1, 3, 2, 4) 218 | .contiguous() 219 | ) 220 | 221 | k_conv_out = self.k_conv(x_for_qkv) 222 | k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, H, W).contiguous() 223 | k_conv_out = self.k_lif(k_conv_out) 224 | if self.dvs: 225 | k_conv_out = self.pool(k_conv_out) 226 | if hook is not None: 227 | hook[self._get_name() + str(self.layer) + "_k_lif"] = k_conv_out.detach() 228 | k = ( 229 | k_conv_out.flatten(3) 230 | .transpose(-1, -2) 231 | .reshape(T, B, N, self.num_heads, C // self.num_heads) 232 | .permute(0, 1, 3, 2, 4) 233 | .contiguous() 234 | ) 235 | 236 | v_conv_out = self.v_conv(x_for_qkv) 237 | v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, H, W).contiguous() 238 | v_conv_out = self.v_lif(v_conv_out) 239 | if self.dvs: 240 | v_conv_out = self.pool(v_conv_out) 241 | if hook is not None: 242 | hook[self._get_name() + str(self.layer) + "_v_lif"] = v_conv_out.detach() 243 | v = ( 244 | v_conv_out.flatten(3) 245 | .transpose(-1, -2) 246 | .reshape(T, B, N, self.num_heads, C // self.num_heads) 247 | .permute(0, 1, 3, 2, 4) 248 | .contiguous() 249 | ) # T B head N C//h 250 | 251 | kv = k.mul(v) 252 | if hook is not None: 253 | hook[self._get_name() + str(self.layer) + "_kv_before"] = kv 254 | if self.dvs: 255 | kv = self.pool(kv) 256 | kv = kv.sum(dim=-2, keepdim=True) 257 | kv = self.talking_heads_lif(kv) 258 | if hook is not None: 259 | hook[self._get_name() + str(self.layer) + "_kv"] = kv.detach() 260 | x = q.mul(kv) 261 | if self.dvs: 262 | x = self.pool(x) 263 | if hook is not None: 264 | hook[self._get_name() + str(self.layer) + "_x_after_qkv"] = x.detach() 265 | 266 | x = x.transpose(3, 4).reshape(T, B, C, H, W).contiguous() 267 | x = ( 268 | self.proj_bn(self.proj_conv(x.flatten(0, 1))) 269 | .reshape(T, B, C, H, W) 270 | .contiguous() 271 | ) 272 | 273 | x = x + identity 274 | return x, v, hook 275 | 276 | 277 | class MS_Block_Conv(nn.Module): 278 | def __init__( 279 | self, 280 | dim, 281 | num_heads, 282 | mlp_ratio=4.0, 283 | qkv_bias=False, 284 | qk_scale=None, 285 | drop=0.0, 286 | attn_drop=0.0, 287 | drop_path=0.0, 288 | norm_layer=nn.LayerNorm, 289 | sr_ratio=1, 290 | attn_mode="direct_xor", 291 | spike_mode="lif", 292 | dvs=False, 293 | layer=0, 294 | ): 295 | super().__init__() 296 | # self.attn = MS_SSA_Conv( 297 | # dim, 298 | # num_heads=num_heads, 299 | # qkv_bias=qkv_bias, 300 | # qk_scale=qk_scale, 301 | # attn_drop=attn_drop, 302 | # proj_drop=drop, 303 | # sr_ratio=sr_ratio, 304 | # mode=attn_mode, 305 | # spike_mode=spike_mode, 306 | # dvs=dvs, 307 | # layer=layer, 308 | # ) 309 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 310 | mlp_hidden_dim = int(dim * mlp_ratio) 311 | self.mlp = MS_MLP_Conv( 312 | in_features=dim, 313 | hidden_features=mlp_hidden_dim, 314 | drop=drop, 315 | spike_mode=spike_mode, 316 | layer=layer, 317 | ) 318 | 319 | def forward(self, x, hook=None): 320 | #x_attn, attn, hook = self.attn(x, hook=hook) 321 | x, hook = self.mlp(x, hook=hook) 322 | return x, hook 323 | -------------------------------------------------------------------------------- /sad/models/module/.ipynb_checkpoints/ms_conv-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from timm.models.layers import DropPath 4 | from 
spikingjelly.clock_driven.neuron import ( 5 | MultiStepLIFNode, 6 | MultiStepParametricLIFNode, 7 | ) 8 | 9 | 10 | class Erode(nn.Module): 11 | def __init__(self) -> None: 12 | super().__init__() 13 | self.pool = nn.MaxPool3d( 14 | kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1) 15 | ) 16 | 17 | def forward(self, x): 18 | return self.pool(x) 19 | 20 | class SpatialGatingUnit(nn.Module): 21 | def __init__(self, d_ffn, seq_len, spike_mode): 22 | super().__init__() 23 | d_ffn = int(d_ffn/2) 24 | self.norm = nn.BatchNorm1d(seq_len) 25 | self.spatial_proj = nn.Conv1d(seq_len, seq_len, kernel_size=1) 26 | self.fc1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch") 27 | #self.local_proj = nn.Conv1d(d_ffn, d_ffn, kernel_size=4, groups=d_ffn, padding='same') 28 | nn.init.constant_(self.spatial_proj.bias, 1.0) 29 | #nn.init.constant_(self.local_proj.bias, 1.0) 30 | 31 | 32 | 33 | def forward(self, x): 34 | T, B, C, L = x.shape 35 | 36 | 37 | L = int(L/2) 38 | u, v = x.chunk(2, dim=-1) 39 | #u = self.local_proj(u.reshape(T*B, L, C)) 40 | 41 | #u = self.spatial_proj(u.flatten(0, 1)) 42 | #u = self.norm(u.flatten(0, 1)) 43 | #u = self.fc1_lif(u.reshape(T, B, C, L)) 44 | out = u * v 45 | return out 46 | class MS_MLP_Conv(nn.Module): 47 | def __init__( 48 | self, 49 | in_features, 50 | hidden_features=None, 51 | out_features=None, 52 | drop=0.0, 53 | spike_mode="lif", 54 | layer=0, 55 | ): 56 | super().__init__() 57 | out_features = out_features or in_features 58 | hidden_features = hidden_features or in_features 59 | self.res = in_features == hidden_features 60 | self.fc1_conv = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1) 61 | self.fc1_bn = nn.BatchNorm2d(hidden_features) 62 | self.sgu = SpatialGatingUnit(hidden_features, 64, spike_mode) 63 | if spike_mode == "lif": 64 | self.fc1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch") 65 | elif spike_mode == "plif": 66 | self.fc1_lif = MultiStepParametricLIFNode( 67 | init_tau=2.0, detach_reset=True, backend="torch" 68 | ) 69 | 70 | self.fc2_conv = nn.Conv2d( 71 | int(hidden_features/2), out_features, kernel_size=1, stride=1 72 | ) 73 | self.fc2_bn = nn.BatchNorm2d(out_features) 74 | if spike_mode == "lif": 75 | self.fc2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch") 76 | elif spike_mode == "plif": 77 | self.fc2_lif = MultiStepParametricLIFNode( 78 | init_tau=2.0, detach_reset=True, backend="torch" 79 | ) 80 | 81 | self.c_hidden = hidden_features 82 | self.c_output = out_features 83 | self.layer = layer 84 | 85 | def forward(self, x, hook=None): 86 | T, B, C, H, W = x.shape 87 | identity = x 88 | 89 | x = self.fc1_lif(x) 90 | if hook is not None: 91 | hook[self._get_name() + str(self.layer) + "_fc1_lif"] = x.detach() 92 | x = self.fc1_conv(x.flatten(0, 1)) 93 | x = self.fc1_bn(x).reshape(T, B, self.c_hidden, H, W).contiguous() 94 | if self.res: 95 | x = identity + x 96 | identity = x 97 | x = x.reshape(T, B, self.c_hidden, H*W).permute(0, 1, 3, 2) 98 | 99 | x = self.fc2_lif(x) 100 | x = self.sgu(x).permute(0, 1, 3, 2).reshape(T, B, int(self.c_hidden/2), H, W) 101 | if hook is not None: 102 | hook[self._get_name() + str(self.layer) + "_fc2_lif"] = x.detach() 103 | x = self.fc2_conv(x.flatten(0, 1)) 104 | x = self.fc2_bn(x).reshape(T, B, C, H, W).contiguous() 105 | 106 | x = x + identity 107 | return x, hook 108 | 109 | 110 | class MS_SSA_Conv(nn.Module): 111 | def __init__( 112 | self, 113 | dim, 114 | num_heads=8, 115 | qkv_bias=False, 116 | qk_scale=None, 117 | 
attn_drop=0.0, 118 | proj_drop=0.0, 119 | sr_ratio=1, 120 | mode="direct_xor", 121 | spike_mode="lif", 122 | dvs=False, 123 | layer=0, 124 | ): 125 | super().__init__() 126 | assert ( 127 | dim % num_heads == 0 128 | ), f"dim {dim} should be divided by num_heads {num_heads}." 129 | self.dim = dim 130 | self.dvs = dvs 131 | self.num_heads = num_heads 132 | if dvs: 133 | self.pool = Erode() 134 | self.scale = 0.125 135 | self.q_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=False) 136 | self.q_bn = nn.BatchNorm2d(dim) 137 | if spike_mode == "lif": 138 | self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch") 139 | elif spike_mode == "plif": 140 | self.q_lif = MultiStepParametricLIFNode( 141 | init_tau=2.0, detach_reset=True, backend="torch" 142 | ) 143 | 144 | self.k_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=False) 145 | self.k_bn = nn.BatchNorm2d(dim) 146 | if spike_mode == "lif": 147 | self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch") 148 | elif spike_mode == "plif": 149 | self.k_lif = MultiStepParametricLIFNode( 150 | init_tau=2.0, detach_reset=True, backend="torch" 151 | ) 152 | 153 | self.v_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=False) 154 | self.v_bn = nn.BatchNorm2d(dim) 155 | if spike_mode == "lif": 156 | self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch") 157 | elif spike_mode == "plif": 158 | self.v_lif = MultiStepParametricLIFNode( 159 | init_tau=2.0, detach_reset=True, backend="torch" 160 | ) 161 | 162 | if spike_mode == "lif": 163 | self.attn_lif = MultiStepLIFNode( 164 | tau=2.0, v_threshold=0.5, detach_reset=True, backend="torch" 165 | ) 166 | elif spike_mode == "plif": 167 | self.attn_lif = MultiStepParametricLIFNode( 168 | init_tau=2.0, v_threshold=0.5, detach_reset=True, backend="torch" 169 | ) 170 | 171 | self.talking_heads = nn.Conv1d( 172 | num_heads, num_heads, kernel_size=1, stride=1, bias=False 173 | ) 174 | if spike_mode == "lif": 175 | self.talking_heads_lif = MultiStepLIFNode( 176 | tau=2.0, v_threshold=0.5, detach_reset=True, backend="torch" 177 | ) 178 | elif spike_mode == "plif": 179 | self.talking_heads_lif = MultiStepParametricLIFNode( 180 | init_tau=2.0, v_threshold=0.5, detach_reset=True, backend="torch" 181 | ) 182 | 183 | self.proj_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1) 184 | self.proj_bn = nn.BatchNorm2d(dim) 185 | 186 | if spike_mode == "lif": 187 | self.shortcut_lif = MultiStepLIFNode( 188 | tau=2.0, detach_reset=True, backend="torch" 189 | ) 190 | elif spike_mode == "plif": 191 | self.shortcut_lif = MultiStepParametricLIFNode( 192 | init_tau=2.0, detach_reset=True, backend="torch" 193 | ) 194 | 195 | self.mode = mode 196 | self.layer = layer 197 | 198 | def forward(self, x, hook=None): 199 | T, B, C, H, W = x.shape 200 | identity = x 201 | N = H * W 202 | x = self.shortcut_lif(x) 203 | if hook is not None: 204 | hook[self._get_name() + str(self.layer) + "_first_lif"] = x.detach() 205 | 206 | x_for_qkv = x.flatten(0, 1) 207 | q_conv_out = self.q_conv(x_for_qkv) 208 | q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, H, W).contiguous() 209 | q_conv_out = self.q_lif(q_conv_out) 210 | 211 | if hook is not None: 212 | hook[self._get_name() + str(self.layer) + "_q_lif"] = q_conv_out.detach() 213 | q = ( 214 | q_conv_out.flatten(3) 215 | .transpose(-1, -2) 216 | .reshape(T, B, N, self.num_heads, C // self.num_heads) 217 | .permute(0, 1, 3, 2, 4) 218 | .contiguous() 219 | ) 220 | 221 | k_conv_out = self.k_conv(x_for_qkv) 222 | k_conv_out 
= self.k_bn(k_conv_out).reshape(T, B, C, H, W).contiguous() 223 | k_conv_out = self.k_lif(k_conv_out) 224 | if self.dvs: 225 | k_conv_out = self.pool(k_conv_out) 226 | if hook is not None: 227 | hook[self._get_name() + str(self.layer) + "_k_lif"] = k_conv_out.detach() 228 | k = ( 229 | k_conv_out.flatten(3) 230 | .transpose(-1, -2) 231 | .reshape(T, B, N, self.num_heads, C // self.num_heads) 232 | .permute(0, 1, 3, 2, 4) 233 | .contiguous() 234 | ) 235 | 236 | v_conv_out = self.v_conv(x_for_qkv) 237 | v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, H, W).contiguous() 238 | v_conv_out = self.v_lif(v_conv_out) 239 | if self.dvs: 240 | v_conv_out = self.pool(v_conv_out) 241 | if hook is not None: 242 | hook[self._get_name() + str(self.layer) + "_v_lif"] = v_conv_out.detach() 243 | v = ( 244 | v_conv_out.flatten(3) 245 | .transpose(-1, -2) 246 | .reshape(T, B, N, self.num_heads, C // self.num_heads) 247 | .permute(0, 1, 3, 2, 4) 248 | .contiguous() 249 | ) # T B head N C//h 250 | 251 | kv = k.mul(v) 252 | if hook is not None: 253 | hook[self._get_name() + str(self.layer) + "_kv_before"] = kv 254 | if self.dvs: 255 | kv = self.pool(kv) 256 | kv = kv.sum(dim=-2, keepdim=True) 257 | kv = self.talking_heads_lif(kv) 258 | if hook is not None: 259 | hook[self._get_name() + str(self.layer) + "_kv"] = kv.detach() 260 | x = q.mul(kv) 261 | if self.dvs: 262 | x = self.pool(x) 263 | if hook is not None: 264 | hook[self._get_name() + str(self.layer) + "_x_after_qkv"] = x.detach() 265 | 266 | x = x.transpose(3, 4).reshape(T, B, C, H, W).contiguous() 267 | x = ( 268 | self.proj_bn(self.proj_conv(x.flatten(0, 1))) 269 | .reshape(T, B, C, H, W) 270 | .contiguous() 271 | ) 272 | 273 | x = x + identity 274 | return x, v, hook 275 | 276 | 277 | class MS_Block_Conv(nn.Module): 278 | def __init__( 279 | self, 280 | dim, 281 | num_heads, 282 | mlp_ratio=4.0, 283 | qkv_bias=False, 284 | qk_scale=None, 285 | drop=0.0, 286 | attn_drop=0.0, 287 | drop_path=0.0, 288 | norm_layer=nn.LayerNorm, 289 | sr_ratio=1, 290 | attn_mode="direct_xor", 291 | spike_mode="lif", 292 | dvs=False, 293 | layer=0, 294 | ): 295 | super().__init__() 296 | # self.attn = MS_SSA_Conv( 297 | # dim, 298 | # num_heads=num_heads, 299 | # qkv_bias=qkv_bias, 300 | # qk_scale=qk_scale, 301 | # attn_drop=attn_drop, 302 | # proj_drop=drop, 303 | # sr_ratio=sr_ratio, 304 | # mode=attn_mode, 305 | # spike_mode=spike_mode, 306 | # dvs=dvs, 307 | # layer=layer, 308 | # ) 309 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 310 | mlp_hidden_dim = int(dim * mlp_ratio) 311 | self.mlp = MS_MLP_Conv( 312 | in_features=dim, 313 | hidden_features=mlp_hidden_dim, 314 | drop=drop, 315 | spike_mode=spike_mode, 316 | layer=layer, 317 | ) 318 | 319 | def forward(self, x, hook=None): 320 | #x_attn, attn, hook = self.attn(x, hook=hook) 321 | x, hook = self.mlp(x, hook=hook) 322 | return x, hook 323 | -------------------------------------------------------------------------------- /sad/models/module/MS_ResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from spikingjelly.clock_driven import functional 4 | try: 5 | from torchvision.models.utils import load_state_dict_from_url 6 | except ImportError: 7 | from torchvision._internally_replaced_utils import load_state_dict_from_url 8 | 9 | __all__ = ['MultiStepMSResNet', 'multi_step_sew_resnet18'] 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 14 | } 
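# Editorial note: sew_function below implements the spike-element-wise (SEW)
# residual connection of "Deep Residual Learning in Spiking Neural Networks".
# With binary spike maps x (identity branch) and y (block output), the three
# connect functions read element-wise as:
#   ADD  -> x + y          # spike counts accumulate (output may exceed 1)
#   AND  -> x * y          # spike only where both branches spike
#   IAND -> x * (1 - y)    # spike where the identity spikes but the block does not
# A minimal, self-contained check (illustrative only, not part of the original file):
#
#   import torch
#   x = torch.tensor([0., 1., 1.])
#   y = torch.tensor([1., 1., 0.])
#   assert torch.equal(sew_function(x, y, 'ADD'),  torch.tensor([1., 2., 1.]))
#   assert torch.equal(sew_function(x, y, 'AND'),  torch.tensor([0., 1., 0.]))
#   assert torch.equal(sew_function(x, y, 'IAND'), torch.tensor([0., 0., 1.]))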
15 | 16 | # modified by https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py 17 | 18 | def sew_function(x: torch.Tensor, y: torch.Tensor, cnf:str): 19 | if cnf == 'ADD': 20 | return x + y 21 | elif cnf == 'AND': 22 | return x * y 23 | elif cnf == 'IAND': 24 | return x * (1. - y) 25 | else: 26 | raise NotImplementedError 27 | 28 | 29 | 30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 31 | """3x3 convolution with padding""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 33 | padding=dilation, groups=groups, bias=False, dilation=dilation) 34 | 35 | 36 | def conv1x1(in_planes, out_planes, stride=1): 37 | """1x1 convolution""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 39 | class BasicBlock(nn.Module): 40 | expansion = 1 41 | 42 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 43 | base_width=64, dilation=1, norm_layer=None, cnf: str = None, single_step_neuron: callable = None, **kwargs): 44 | super(BasicBlock, self).__init__() 45 | if norm_layer is None: 46 | norm_layer = nn.BatchNorm2d 47 | if groups != 1 or base_width != 64: 48 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 49 | if dilation > 1: 50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 51 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 52 | self.conv1 = conv3x3(inplanes, planes, stride) 53 | self.bn1 = norm_layer(planes) 54 | self.sn1 = single_step_neuron(**kwargs) 55 | self.conv2 = conv3x3(planes, planes) 56 | self.bn2 = norm_layer(planes) 57 | self.sn2 = single_step_neuron(**kwargs) 58 | self.downsample = downsample 59 | if downsample is not None: 60 | self.downsample_sn = single_step_neuron(**kwargs) 61 | self.stride = stride 62 | self.cnf = cnf 63 | 64 | def forward(self, x): 65 | identity = x 66 | 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | out = self.sn1(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | out = self.sn2(out) 74 | 75 | if self.downsample is not None: 76 | identity = self.downsample_sn(self.downsample(x)) 77 | 78 | out = sew_function(identity, out, self.cnf) 79 | 80 | return out 81 | 82 | def extra_repr(self) -> str: 83 | return super().extra_repr() + f'cnf={self.cnf}' 84 | 85 | 86 | 87 | # class MultiStepBasicBlock(BasicBlock): 88 | # def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 89 | # base_width=64, dilation=1, norm_layer=None, cnf: str = None, multi_step_neuron: callable = None, **kwargs): 90 | # super().__init__(inplanes, planes, stride, downsample, groups, 91 | # base_width, dilation, norm_layer, cnf, multi_step_neuron, **kwargs) 92 | # 93 | # def forward(self, x_seq): 94 | # identity = x_seq 95 | # 96 | # out = self.sn1(x_seq) 97 | # out = functional.seq_to_ann_forward(x_seq, [self.conv1, self.bn1]) 98 | # 99 | # out = self.sn2(out) 100 | # out = functional.seq_to_ann_forward(out, [self.conv2, self.bn2]) 101 | # 102 | # 103 | # if self.downsample is not None: 104 | # identity = functional.seq_to_ann_forward(x_seq, self.downsample) 105 | # 106 | # out = identity + out 107 | # 108 | # return out 109 | 110 | class MultiStepBasicBlock(BasicBlock): 111 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 112 | base_width=64, dilation=1, norm_layer=None, cnf: str = None, multi_step_neuron: callable = None, **kwargs): 113 | super().__init__(inplanes, planes, stride, downsample, groups, 114 | base_width, 
dilation, norm_layer, cnf, multi_step_neuron, **kwargs) 115 | 116 | def forward(self, x_seq): 117 | identity = x_seq 118 | 119 | out = functional.seq_to_ann_forward(x_seq, [self.conv1, self.bn1]) 120 | out = self.sn1(out) 121 | 122 | out = functional.seq_to_ann_forward(out, [self.conv2, self.bn2]) 123 | out = self.sn2(out) 124 | 125 | if self.downsample is not None: 126 | identity = self.downsample_sn(functional.seq_to_ann_forward(x_seq, self.downsample)) 127 | 128 | out = sew_function(identity, out, self.cnf) 129 | 130 | return out 131 | 132 | class MultiStepMSResNet(nn.Module): 133 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 134 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 135 | norm_layer=None, T:int=None, cnf: str=None, multi_step_neuron: callable = None, **kwargs): 136 | super().__init__() 137 | self.T = T 138 | if norm_layer is None: 139 | norm_layer = nn.BatchNorm2d 140 | self._norm_layer = norm_layer 141 | 142 | self.inplanes = 64 143 | self.dilation = 1 144 | if replace_stride_with_dilation is None: 145 | # each element in the tuple indicates if we should replace 146 | # the 2x2 stride with a dilated convolution instead 147 | replace_stride_with_dilation = [False, False, False] 148 | if len(replace_stride_with_dilation) != 3: 149 | raise ValueError("replace_stride_with_dilation should be None " 150 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 151 | self.groups = groups 152 | self.base_width = width_per_group 153 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 154 | bias=False) 155 | self.bn1 = norm_layer(self.inplanes) 156 | self.sn1 = multi_step_neuron(**kwargs) 157 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 158 | self.layer1 = self.make_layer(block, 64, layers[0], cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs) 159 | self.layer2 = self.make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0], cnf=cnf, 160 | multi_step_neuron=multi_step_neuron, **kwargs) 161 | self.layer3 = self.make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1], cnf=cnf, 162 | multi_step_neuron=multi_step_neuron, **kwargs) 163 | self.layer4 = self.make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2], cnf=cnf, 164 | multi_step_neuron=multi_step_neuron, **kwargs) 165 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 166 | self.fc = nn.Linear(512 * block.expansion, num_classes) 167 | 168 | for m in self.modules(): 169 | if isinstance(m, nn.Conv2d): 170 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 171 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 172 | nn.init.constant_(m.weight, 1) 173 | nn.init.constant_(m.bias, 0) 174 | 175 | # Zero-initialize the last BN in each residual branch, 176 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
177 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 178 | if zero_init_residual: 179 | for m in self.modules(): 180 | if isinstance(m, Bottleneck): 181 | nn.init.constant_(m.bn3.weight, 0) 182 | elif isinstance(m, BasicBlock): 183 | nn.init.constant_(m.bn2.weight, 0) 184 | 185 | def make_layer(self, block, planes, blocks, stride=1, dilate=False, cnf: str = None, multi_step_neuron: callable = None, **kwargs): 186 | norm_layer = self._norm_layer 187 | downsample = None 188 | previous_dilation = self.dilation 189 | if dilate: 190 | self.dilation *= stride 191 | stride = 1 192 | if stride != 1 or self.inplanes != planes * block.expansion: 193 | downsample = nn.Sequential( 194 | conv1x1(self.inplanes, planes * block.expansion, stride), 195 | norm_layer(planes * block.expansion), 196 | ) 197 | 198 | layers = [] 199 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 200 | self.base_width, previous_dilation, norm_layer, cnf, multi_step_neuron, **kwargs)) 201 | self.inplanes = planes * block.expansion 202 | for _ in range(1, blocks): 203 | layers.append(block(self.inplanes, planes, groups=self.groups, 204 | base_width=self.base_width, dilation=self.dilation, 205 | norm_layer=norm_layer, cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs)) 206 | 207 | return nn.Sequential(*layers) 208 | 209 | def _forward_impl(self, x: torch.Tensor): 210 | # See note [TorchScript super()] 211 | x_seq = None 212 | if x.dim() == 5: 213 | # x.shape = [T, N, C, H, W] 214 | x_seq = functional.seq_to_ann_forward(x, [self.conv1, self.bn1]) 215 | else: 216 | assert self.T is not None, 'When x.shape is [N, C, H, W], self.T can not be None.' 217 | # x.shape = [N, C, H, W] 218 | x = self.conv1(x) 219 | x = self.bn1(x) 220 | x.unsqueeze_(0) 221 | x_seq = x.repeat(self.T, 1, 1, 1, 1) 222 | 223 | x_seq = functional.seq_to_ann_forward(x_seq, self.maxpool) 224 | 225 | x_seq = self.layer1(x_seq) 226 | x_seq = self.layer2(x_seq) 227 | x_seq = self.layer3(x_seq) 228 | x_seq = self.layer4(x_seq) 229 | 230 | x_seq = functional.seq_to_ann_forward(x_seq, self.avgpool) 231 | x_seq = self.sn1(x_seq) 232 | x_seq = torch.flatten(x_seq, 2) 233 | # x_seq = self.fc(x_seq.mean(0)) 234 | x_seq = functional.seq_to_ann_forward(x_seq, self.fc) 235 | 236 | return x_seq 237 | 238 | def forward(self, x): 239 | """ 240 | :param x: the input with `shape=[N, C, H, W]` or `[*, N, C, H, W]` 241 | :type x: torch.Tensor 242 | :return: output 243 | :rtype: torch.Tensor 244 | """ 245 | return self._forward_impl(x) 246 | 247 | 248 | 249 | 250 | def _multi_step_sew_resnet(arch, block, layers, pretrained, progress, T, cnf, multi_step_neuron, **kwargs): 251 | model = MultiStepMSResNet(block, layers, T=T, cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs) 252 | if pretrained: 253 | state_dict = load_state_dict_from_url(model_urls[arch], 254 | progress=progress) 255 | model.load_state_dict(state_dict) 256 | return model 257 | 258 | def multi_step_sew_resnet18(pretrained=False, progress=True, T: int = None, cnf: str = 'ADD', multi_step_neuron: callable=None, **kwargs): 259 | """ 260 | :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet 261 | :type pretrained: bool 262 | :param progress: If True, displays a progress bar of the download to stderr 263 | :type progress: bool 264 | :param T: total time-steps 265 | :type T: int 266 | :param cnf: the name of spike-element-wise function 267 | :type cnf: str 268 | :param multi_step_neuron: a multi-step neuron 269 
| :type multi_step_neuron: callable 270 | :param kwargs: kwargs for `multi_step_neuron` 271 | :type kwargs: dict 272 | :return: Spiking ResNet-18 273 | :rtype: torch.nn.Module 274 | 275 | The multi-step spike-element-wise ResNet-18 `"Deep Residual Learning in Spiking Neural Networks" `_ 276 | modified by the ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_ 277 | """ 278 | 279 | return _multi_step_sew_resnet('resnet18', MultiStepBasicBlock, [2, 2, 2, 2], pretrained, progress, T, cnf, multi_step_neuron, **kwargs) 280 | -------------------------------------------------------------------------------- /sad/models/module/Real_MS_ResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from spikingjelly.clock_driven import functional 4 | 5 | try: 6 | from torchvision.models.utils import load_state_dict_from_url 7 | except ImportError: 8 | from torchvision._internally_replaced_utils import load_state_dict_from_url 9 | 10 | __all__ = ['MultiStepMSResNet', 'multi_step_sew_resnet18'] 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 15 | } 16 | 17 | 18 | # modified by https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py 19 | 20 | def sew_function(x: torch.Tensor, y: torch.Tensor, cnf: str): 21 | if cnf == 'ADD': 22 | return x + y 23 | elif cnf == 'AND': 24 | return x * y 25 | elif cnf == 'IAND': 26 | return x * (1. - y) 27 | else: 28 | raise NotImplementedError 29 | 30 | 31 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 32 | """3x3 convolution with padding""" 33 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 34 | padding=dilation, groups=groups, bias=False, dilation=dilation) 35 | 36 | 37 | def conv1x1(in_planes, out_planes, stride=1): 38 | """1x1 convolution""" 39 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 40 | 41 | 42 | class BasicBlock(nn.Module): 43 | expansion = 1 44 | 45 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 46 | base_width=64, dilation=1, norm_layer=None, cnf: str = None, single_step_neuron: callable = None, 47 | **kwargs): 48 | super(BasicBlock, self).__init__() 49 | if norm_layer is None: 50 | norm_layer = nn.BatchNorm2d 51 | if groups != 1 or base_width != 64: 52 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 53 | if dilation > 1: 54 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 55 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 56 | self.conv1 = conv3x3(inplanes, planes, stride) 57 | self.bn1 = norm_layer(planes) 58 | self.sn1 = single_step_neuron(**kwargs) 59 | self.conv2 = conv3x3(planes, planes) 60 | self.bn2 = norm_layer(planes) 61 | self.sn2 = single_step_neuron(**kwargs) 62 | self.downsample = downsample 63 | if downsample is not None: 64 | self.downsample_sn = single_step_neuron(**kwargs) 65 | self.stride = stride 66 | self.cnf = cnf 67 | 68 | def forward(self, x): 69 | identity = x 70 | 71 | out = self.conv1(x) 72 | out = self.bn1(out) 73 | out = self.sn1(out) 74 | 75 | out = self.conv2(out) 76 | out = self.bn2(out) 77 | out = self.sn2(out) 78 | 79 | if self.downsample is not None: 80 | identity = self.downsample_sn(self.downsample(x)) 81 | 82 | out = sew_function(identity, out, self.cnf) 83 | 84 | return out 85 | 86 | def extra_repr(self) -> str: 87 | return super().extra_repr() + 
f'cnf={self.cnf}' 88 | 89 | 90 | class MultiStepBasicBlock(BasicBlock): 91 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 92 | base_width=64, dilation=1, norm_layer=None, cnf: str = None, multi_step_neuron: callable = None, **kwargs): 93 | super().__init__(inplanes, planes, stride, downsample, groups, 94 | base_width, dilation, norm_layer, cnf, multi_step_neuron, **kwargs) 95 | 96 | def forward(self, x_seq): 97 | identity = x_seq 98 | 99 | out = self.sn1(x_seq) 100 | out = functional.seq_to_ann_forward(out, [self.conv1, self.bn1]) 101 | 102 | out = self.sn2(out) 103 | out = functional.seq_to_ann_forward(out, [self.conv2, self.bn2]) 104 | 105 | 106 | if self.downsample is not None: 107 | x_seq = self.downsample_sn(x_seq) 108 | identity = functional.seq_to_ann_forward(x_seq, self.downsample) 109 | 110 | out = identity + out 111 | 112 | return out 113 | 114 | # class MultiStepBasicBlock(BasicBlock): 115 | # def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 116 | # base_width=64, dilation=1, norm_layer=None, cnf: str = None, multi_step_neuron: callable = None, 117 | # **kwargs): 118 | # super().__init__(inplanes, planes, stride, downsample, groups, 119 | # base_width, dilation, norm_layer, cnf, multi_step_neuron, **kwargs) 120 | # 121 | # def forward(self, x_seq): 122 | # identity = x_seq 123 | # 124 | # x_seq = self.sn1(x_seq) 125 | # out = functional.seq_to_ann_forward(x_seq, [self.conv1, self.bn1]) 126 | # 127 | # out = self.sn2(out) 128 | # out = functional.seq_to_ann_forward(out, [self.conv2, self.bn2]) 129 | # 130 | # 131 | # if self.downsample is not None: 132 | # identity = self.downsample_sn(functional.seq_to_ann_forward(x_seq, self.downsample)) 133 | # 134 | # out = sew_function(identity, out, self.cnf) 135 | # 136 | # return out 137 | 138 | 139 | class MultiStepMSResNet(nn.Module): 140 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 141 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 142 | norm_layer=None, T: int = None, cnf: str = None, multi_step_neuron: callable = None, **kwargs): 143 | super().__init__() 144 | self.T = T 145 | if norm_layer is None: 146 | norm_layer = nn.BatchNorm2d 147 | self._norm_layer = norm_layer 148 | 149 | self.inplanes = 64 150 | self.dilation = 1 151 | if replace_stride_with_dilation is None: 152 | # each element in the tuple indicates if we should replace 153 | # the 2x2 stride with a dilated convolution instead 154 | replace_stride_with_dilation = [False, False, False] 155 | if len(replace_stride_with_dilation) != 3: 156 | raise ValueError("replace_stride_with_dilation should be None " 157 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 158 | self.groups = groups 159 | self.base_width = width_per_group 160 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 161 | bias=False) 162 | self.bn1 = norm_layer(self.inplanes) 163 | self.sn1 = multi_step_neuron(**kwargs) 164 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 165 | self.layer1 = self.make_layer(block, 64, layers[0], cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs) 166 | self.layer2 = self.make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0], cnf=cnf, 167 | multi_step_neuron=multi_step_neuron, **kwargs) 168 | self.layer3 = self.make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1], cnf=cnf, 169 | multi_step_neuron=multi_step_neuron, **kwargs) 170 | self.layer4 = 
self.make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2], cnf=cnf, 171 | multi_step_neuron=multi_step_neuron, **kwargs) 172 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 173 | self.fc = nn.Linear(512 * block.expansion, num_classes) 174 | 175 | for m in self.modules(): 176 | if isinstance(m, nn.Conv2d): 177 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 178 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 179 | nn.init.constant_(m.weight, 1) 180 | nn.init.constant_(m.bias, 0) 181 | 182 | # Zero-initialize the last BN in each residual branch, 183 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 184 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 185 | if zero_init_residual: 186 | for m in self.modules(): 187 | if isinstance(m, Bottleneck): 188 | nn.init.constant_(m.bn3.weight, 0) 189 | elif isinstance(m, BasicBlock): 190 | nn.init.constant_(m.bn2.weight, 0) 191 | 192 | def make_layer(self, block, planes, blocks, stride=1, dilate=False, cnf: str = None, 193 | multi_step_neuron: callable = None, **kwargs): 194 | norm_layer = self._norm_layer 195 | downsample = None 196 | previous_dilation = self.dilation 197 | if dilate: 198 | self.dilation *= stride 199 | stride = 1 200 | if stride != 1 or self.inplanes != planes * block.expansion: 201 | downsample = nn.Sequential( 202 | conv1x1(self.inplanes, planes * block.expansion, stride), 203 | norm_layer(planes * block.expansion), 204 | ) 205 | 206 | layers = [] 207 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 208 | self.base_width, previous_dilation, norm_layer, cnf, multi_step_neuron, **kwargs)) 209 | self.inplanes = planes * block.expansion 210 | for _ in range(1, blocks): 211 | layers.append(block(self.inplanes, planes, groups=self.groups, 212 | base_width=self.base_width, dilation=self.dilation, 213 | norm_layer=norm_layer, cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs)) 214 | 215 | return nn.Sequential(*layers) 216 | 217 | def _forward_impl(self, x: torch.Tensor): 218 | # See note [TorchScript super()] 219 | x_seq = None 220 | if x.dim() == 5: 221 | # x.shape = [T, N, C, H, W] 222 | x_seq = functional.seq_to_ann_forward(x, [self.conv1, self.bn1]) 223 | else: 224 | assert self.T is not None, 'When x.shape is [N, C, H, W], self.T can not be None.' 
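# Editorial note: the static-image branch below encodes the input once with
# conv1/bn1 and then repeats the result along a new leading time dimension, so
# every downstream multi-step spiking layer always receives a [T, N, C, H, W]
# sequence; functional.seq_to_ann_forward simply folds T into the batch axis for
# stateless layers. Illustrative shapes, assuming T = 4 and an 8 x 3 x 224 x 224
# input batch (these numbers are not taken from the original file):
#   x                                     : [8, 3, 224, 224]
#   conv1 + bn1 (stride-2, 7x7)           : [8, 64, 112, 112]
#   unsqueeze(0) + repeat(T, 1, 1, 1, 1)  : [4, 8, 64, 112, 112]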
225 | # x.shape = [N, C, H, W] 226 | x = self.conv1(x) 227 | x = self.bn1(x) 228 | x.unsqueeze_(0) 229 | x_seq = x.repeat(self.T, 1, 1, 1, 1) 230 | 231 | x_seq = functional.seq_to_ann_forward(x_seq, self.maxpool) 232 | 233 | x_seq = self.layer1(x_seq) 234 | x_seq = self.layer2(x_seq) 235 | x_seq = self.layer3(x_seq) 236 | x_seq = self.layer4(x_seq) 237 | 238 | x_seq = functional.seq_to_ann_forward(x_seq, self.avgpool) 239 | x_seq = self.sn1(x_seq) 240 | x_seq = torch.flatten(x_seq, 2) 241 | # x_seq = self.fc(x_seq.mean(0)) 242 | x_seq = functional.seq_to_ann_forward(x_seq, self.fc) 243 | 244 | return x_seq 245 | 246 | def forward(self, x): 247 | """ 248 | :param x: the input with `shape=[N, C, H, W]` or `[*, N, C, H, W]` 249 | :type x: torch.Tensor 250 | :return: output 251 | :rtype: torch.Tensor 252 | """ 253 | return self._forward_impl(x) 254 | 255 | 256 | def _multi_step_sew_resnet(arch, block, layers, pretrained, progress, T, cnf, multi_step_neuron, **kwargs): 257 | model = MultiStepMSResNet(block, layers, T=T, cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs) 258 | if pretrained: 259 | state_dict = load_state_dict_from_url(model_urls[arch], 260 | progress=progress) 261 | model.load_state_dict(state_dict) 262 | return model 263 | 264 | 265 | def multi_step_sew_resnet18(pretrained=False, progress=True, T: int = None, cnf: str = 'ADD', 266 | multi_step_neuron: callable = None, **kwargs): 267 | """ 268 | :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet 269 | :type pretrained: bool 270 | :param progress: If True, displays a progress bar of the download to stderr 271 | :type progress: bool 272 | :param T: total time-steps 273 | :type T: int 274 | :param cnf: the name of spike-element-wise function 275 | :type cnf: str 276 | :param multi_step_neuron: a multi-step neuron 277 | :type multi_step_neuron: callable 278 | :param kwargs: kwargs for `multi_step_neuron` 279 | :type kwargs: dict 280 | :return: Spiking ResNet-18 281 | :rtype: torch.nn.Module 282 | 283 | The multi-step spike-element-wise ResNet-18 `"Deep Residual Learning in Spiking Neural Networks" `_ 284 | modified by the ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_ 285 | """ 286 | 287 | return _multi_step_sew_resnet('resnet18', MultiStepBasicBlock, [2, 2, 2, 2], pretrained, progress, T, cnf, 288 | multi_step_neuron, **kwargs) 289 | -------------------------------------------------------------------------------- /sad/utils/geometry.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from pyquaternion import Quaternion 7 | from nuscenes.utils.geometry_utils import transform_matrix 8 | 9 | def resize_and_crop_image(img, resize_dims, crop): 10 | # Bilinear resizing followed by cropping 11 | img = img.resize(resize_dims, resample=PIL.Image.BILINEAR) 12 | img = img.crop(crop) 13 | return img 14 | 15 | 16 | def update_intrinsics(intrinsics, top_crop=0.0, left_crop=0.0, scale_width=1.0, scale_height=1.0): 17 | """ 18 | Parameters 19 | ---------- 20 | intrinsics: torch.Tensor (3, 3) 21 | top_crop: float 22 | left_crop: float 23 | scale_width: float 24 | scale_height: float 25 | """ 26 | updated_intrinsics = intrinsics.clone() 27 | # Adjust intrinsics scale due to resizing 28 | updated_intrinsics[0, 0] *= scale_width 29 | updated_intrinsics[0, 2] *= scale_width 30 | updated_intrinsics[1, 1] *= scale_height 31 | updated_intrinsics[1, 2] *= 
scale_height 32 | 33 | # Adjust principal point due to cropping 34 | updated_intrinsics[0, 2] -= left_crop 35 | updated_intrinsics[1, 2] -= top_crop 36 | 37 | return updated_intrinsics 38 | 39 | 40 | def calculate_birds_eye_view_parameters(x_bounds, y_bounds, z_bounds): 41 | """ 42 | Parameters 43 | ---------- 44 | x_bounds: Forward direction in the ego-car. 45 | y_bounds: Sides 46 | z_bounds: Height 47 | 48 | Returns 49 | ------- 50 | bev_resolution: Bird's-eye view bev_resolution 51 | bev_start_position Bird's-eye view first element 52 | bev_dimension Bird's-eye view tensor spatial dimension 53 | """ 54 | bev_resolution = torch.tensor([row[2] for row in [x_bounds, y_bounds, z_bounds]]) 55 | bev_start_position = torch.tensor([row[0] + row[2] / 2.0 for row in [x_bounds, y_bounds, z_bounds]]) 56 | bev_dimension = torch.tensor([(row[1] - row[0]) / row[2] for row in [x_bounds, y_bounds, z_bounds]], 57 | dtype=torch.long) 58 | 59 | return bev_resolution, bev_start_position, bev_dimension 60 | 61 | 62 | def convert_egopose_to_matrix_numpy(egopose): 63 | transformation_matrix = np.zeros((4, 4), dtype=np.float32) 64 | rotation = Quaternion(egopose['rotation']).rotation_matrix 65 | translation = np.array(egopose['translation']) 66 | transformation_matrix[:3, :3] = rotation 67 | transformation_matrix[:3, 3] = translation 68 | transformation_matrix[3, 3] = 1.0 69 | return transformation_matrix 70 | 71 | def get_global_pose(rec, nusc, inverse=False): 72 | lidar_sample_data = nusc.get('sample_data', rec['data']['LIDAR_TOP']) 73 | 74 | sd_ep = nusc.get("ego_pose", lidar_sample_data["ego_pose_token"]) 75 | sd_cs = nusc.get("calibrated_sensor", lidar_sample_data["calibrated_sensor_token"]) 76 | if inverse is False: 77 | global_from_ego = transform_matrix(sd_ep["translation"], Quaternion(sd_ep["rotation"]), inverse=False) 78 | ego_from_sensor = transform_matrix(sd_cs["translation"], Quaternion(sd_cs["rotation"]), inverse=False) 79 | pose = global_from_ego.dot(ego_from_sensor) 80 | else: 81 | sensor_from_ego = transform_matrix(sd_cs["translation"], Quaternion(sd_cs["rotation"]), inverse=True) 82 | ego_from_global = transform_matrix(sd_ep["translation"], Quaternion(sd_ep["rotation"]), inverse=True) 83 | pose = sensor_from_ego.dot(ego_from_global) 84 | return pose 85 | 86 | def invert_matrix_egopose_numpy(egopose): 87 | """ Compute the inverse transformation of a 4x4 egopose numpy matrix.""" 88 | inverse_matrix = np.zeros((4, 4), dtype=np.float32) 89 | rotation = egopose[:3, :3] 90 | translation = egopose[:3, 3] 91 | inverse_matrix[:3, :3] = rotation.T 92 | inverse_matrix[:3, 3] = -np.dot(rotation.T, translation) 93 | inverse_matrix[3, 3] = 1.0 94 | return inverse_matrix 95 | 96 | 97 | def mat2pose_vec(matrix: torch.Tensor): 98 | """ 99 | Converts a 4x4 pose matrix into a 6-dof pose vector 100 | Args: 101 | matrix (ndarray): 4x4 pose matrix 102 | Returns: 103 | vector (ndarray): 6-dof pose vector comprising translation components (tx, ty, tz) and 104 | rotation components (rx, ry, rz) 105 | """ 106 | 107 | # M[1, 2] = -sinx*cosy, M[2, 2] = +cosx*cosy 108 | rotx = torch.atan2(-matrix[..., 1, 2], matrix[..., 2, 2]) 109 | 110 | # M[0, 2] = +siny, M[1, 2] = -sinx*cosy, M[2, 2] = +cosx*cosy 111 | cosy = torch.sqrt(matrix[..., 1, 2] ** 2 + matrix[..., 2, 2] ** 2) 112 | roty = torch.atan2(matrix[..., 0, 2], cosy) 113 | 114 | # M[0, 0] = +cosy*cosz, M[0, 1] = -cosy*sinz 115 | rotz = torch.atan2(-matrix[..., 0, 1], matrix[..., 0, 0]) 116 | 117 | rotation = torch.stack((rotx, roty, rotz), dim=-1) 118 | 119 | # Extract 
translation params 120 | translation = matrix[..., :3, 3] 121 | return torch.cat((translation, rotation), dim=-1) 122 | 123 | 124 | def euler2mat(angle: torch.Tensor): 125 | """Convert Euler angles to a rotation matrix. 126 | Reference: https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174 127 | Args: 128 | angle: rotation angles along the 3 axes (in radians) [Bx3] 129 | Returns: 130 | Rotation matrix corresponding to the Euler angles [Bx3x3] 131 | """ 132 | shape = angle.shape 133 | angle = angle.view(-1, 3) 134 | x, y, z = angle[:, 0], angle[:, 1], angle[:, 2] 135 | 136 | cosz = torch.cos(z) 137 | sinz = torch.sin(z) 138 | 139 | zeros = torch.zeros_like(z) 140 | ones = torch.ones_like(z) 141 | zmat = torch.stack([cosz, -sinz, zeros, sinz, cosz, zeros, zeros, zeros, ones], dim=1).view(-1, 3, 3) 142 | 143 | cosy = torch.cos(y) 144 | siny = torch.sin(y) 145 | 146 | ymat = torch.stack([cosy, zeros, siny, zeros, ones, zeros, -siny, zeros, cosy], dim=1).view(-1, 3, 3) 147 | 148 | cosx = torch.cos(x) 149 | sinx = torch.sin(x) 150 | 151 | xmat = torch.stack([ones, zeros, zeros, zeros, cosx, -sinx, zeros, sinx, cosx], dim=1).view(-1, 3, 3) 152 | 153 | rot_mat = xmat.bmm(ymat).bmm(zmat) 154 | rot_mat = rot_mat.view(*shape[:-1], 3, 3) 155 | return rot_mat 156 | 157 | 158 | def pose_vec2mat(vec: torch.Tensor): 159 | """ 160 | Convert 6DoF parameters to transformation matrix. 161 | Args: 162 | vec: 6DoF parameters in the order of tx, ty, tz, rx, ry, rz [B,6] 163 | Returns: 164 | A transformation matrix [B,4,4] 165 | """ 166 | translation = vec[..., :3].unsqueeze(-1) # [...x3x1] 167 | rot = vec[..., 3:].contiguous() # [...x3] 168 | rot_mat = euler2mat(rot) # [...,3,3] 169 | transform_mat = torch.cat([rot_mat, translation], dim=-1) # [...,3,4] 170 | transform_mat = torch.nn.functional.pad(transform_mat, [0, 0, 0, 1], value=0) # [...,4,4] 171 | transform_mat[..., 3, 3] = 1.0 172 | return transform_mat 173 | 174 | 175 | def invert_pose_matrix(x): 176 | """ 177 | Parameters 178 | ---------- 179 | x: [B, 4, 4] batch of pose matrices 180 | 181 | Returns 182 | ------- 183 | out: [B, 4, 4] batch of inverse pose matrices 184 | """ 185 | assert len(x.shape) == 3 and x.shape[1:] == (4, 4), 'Only works for batch of pose matrices.' 186 | 187 | transposed_rotation = torch.transpose(x[:, :3, :3], 1, 2) 188 | translation = x[:, :3, 3:] 189 | 190 | inverse_mat = torch.cat([transposed_rotation, -torch.bmm(transposed_rotation, translation)], dim=-1) # [B,3,4] 191 | inverse_mat = torch.nn.functional.pad(inverse_mat, [0, 0, 0, 1], value=0) # [B,4,4] 192 | inverse_mat[..., 3, 3] = 1.0 193 | return inverse_mat 194 | 195 | 196 | def warp_features(x, flow, mode='nearest', spatial_extent=None): 197 | """ Applies a rotation and translation to feature map x. 198 | Args: 199 | x: (b, c, h, w) feature map 200 | flow: (b, 6) 6DoF vector (only uses the xy portion) 201 | mode: use 'nearest' when dealing with categorical inputs 202 | Returns: 203 | in-plane transformed feature map 204 | """ 205 | if flow is None: 206 | return x 207 | b, c, h, w = x.shape 208 | # z-rotation 209 | angle = flow[:, 5].clone() # torch.atan2(flow[:, 1, 0], flow[:, 0, 0]) 210 | # x-y translation 211 | translation = flow[:, :2].clone() # flow[:, :2, 3] 212 | 213 | # Normalise the translation by the number of metres spanned by half of the image, 214 | # because a translation of 1.0 in normalised grid coordinates corresponds to half of the image.
215 | translation[:, 0] /= spatial_extent[0] 216 | translation[:, 1] /= spatial_extent[1] 217 | # forward axis is inverted 218 | translation[:, 0] *= -1 219 | 220 | cos_theta = torch.cos(angle) 221 | sin_theta = torch.sin(angle) 222 | 223 | # output = Rot.input + translation 224 | # tx and ty are inverted as is the case when going from real coordinates to numpy coordinates 225 | # translation_pos_0 -> positive value makes the image move to the left 226 | # translation_pos_1 -> positive value makes the image move to the top 227 | # Angle -> positive value in radians rotates the image in the trigonometric (counter-clockwise) direction 228 | transformation = torch.stack([cos_theta, -sin_theta, translation[:, 1], 229 | sin_theta, cos_theta, translation[:, 0]], dim=-1).view(b, 2, 3) 230 | 231 | # Note that a rotation will preserve distances only if height = width. Otherwise there's 232 | # resizing going on. e.g. rotation of pi/2 of a 100x200 image will make what's in the center of the image 233 | # elongated. 234 | grid = torch.nn.functional.affine_grid(transformation, size=x.shape, align_corners=False) 235 | grid = grid.to(dtype=x.dtype) 236 | warped_x = torch.nn.functional.grid_sample(x, grid, mode=mode, padding_mode='zeros', align_corners=False) 237 | 238 | return warped_x 239 | 240 | 241 | def cumulative_warp_features(x, flow, mode='nearest', spatial_extent=None): 242 | """ Warps a sequence of feature maps by accumulating incremental 2d flow. 243 | 244 | x[:, -1] remains unchanged 245 | x[:, -2] is warped using flow[:, -2] 246 | x[:, -3] is warped using flow[:, -3] @ flow[:, -2] 247 | ... 248 | x[:, 0] is warped using flow[:, 0] @ ... @ flow[:, -3] @ flow[:, -2] 249 | 250 | Args: 251 | x: (b, t, c, h, w) sequence of feature maps 252 | flow: (b, t, 6) sequence of 6 DoF pose 253 | from t to t+1 (only uses the xy portion) 254 | 255 | """ 256 | sequence_length = x.shape[1] 257 | if sequence_length == 1: 258 | return x 259 | 260 | flow = pose_vec2mat(flow) 261 | 262 | out = [x[:, -1]] 263 | cum_flow = flow[:, -2] 264 | for t in reversed(range(sequence_length - 1)): 265 | out.append(warp_features(x[:, t], mat2pose_vec(cum_flow), mode=mode, spatial_extent=spatial_extent)) 266 | # @ is the equivalent of torch.bmm 267 | cum_flow = flow[:, t - 1] @ cum_flow 268 | 269 | return torch.stack(out[::-1], 1) 270 | 271 | 272 | def cumulative_warp_features_reverse(x, flow, mode='nearest', spatial_extent=None): 273 | """ Warps a sequence of feature maps by accumulating incremental 2d flow. 274 | 275 | x[:, 0] remains unchanged 276 | x[:, 1] is warped using flow[:, 0].inverse() 277 | x[:, 2] is warped using flow[:, 0].inverse() @ flow[:, 1].inverse() 278 | ... 279 | 280 | Args: 281 | x: (b, t, c, h, w) sequence of feature maps 282 | flow: (b, t, 6) sequence of 6 DoF pose 283 | from t to t+1 (only uses the xy portion) 284 | 285 | """ 286 | flow = pose_vec2mat(flow) 287 | 288 | out = [x[:,0]] 289 | 290 | for i in range(1, x.shape[1]): 291 | if i==1: 292 | cum_flow = invert_pose_matrix(flow[:, 0]) 293 | else: 294 | cum_flow = cum_flow @ invert_pose_matrix(flow[:,i-1]) 295 | out.append( warp_features(x[:,i], mat2pose_vec(cum_flow), mode, spatial_extent=spatial_extent)) 296 | return torch.stack(out, 1) 297 | 298 | 299 | class VoxelsSumming(torch.autograd.Function): 300 | """Adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/tools.py#L193""" 301 | @staticmethod 302 | def forward(ctx, x, geometry, ranks): 303 | """The features `x` and `geometry` are ranked by voxel positions.""" 304 | # Cumulative sum of all features.
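# Worked example of the cumulative-sum pooling below: with per-point features
# [a, b, c, d, e] and ranks [0, 0, 1, 1, 1] (points of the same voxel are contiguous),
#   cumsum                      -> [a, a+b, a+b+c, a+b+c+d, a+b+c+d+e]
#   mask (last point per voxel) -> [False, True, False, False, True]
#   keep masked rows            -> [a+b, a+b+c+d+e]
#   difference of neighbours    -> [a+b, c+d+e], i.e. the per-voxel sums.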
305 | x = x.cumsum(0) 306 | 307 | # Indicates the change of voxel. 308 | mask = torch.ones(x.shape[0], device=x.device, dtype=torch.bool) 309 | mask[:-1] = ranks[1:] != ranks[:-1] 310 | 311 | x, geometry = x[mask], geometry[mask] 312 | # Calculate sum of features within a voxel. 313 | x = torch.cat((x[:1], x[1:] - x[:-1])) 314 | 315 | ctx.save_for_backward(mask) 316 | ctx.mark_non_differentiable(geometry) 317 | 318 | return x, geometry 319 | 320 | @staticmethod 321 | def backward(ctx, grad_x, grad_geometry): 322 | (mask,) = ctx.saved_tensors 323 | # Since the operation is summing, we simply need to send gradient 324 | # to all elements that were part of the summation process. 325 | indices = torch.cumsum(mask, 0) 326 | indices[mask] -= 1 327 | 328 | output_grad = grad_x[indices] 329 | 330 | return output_grad, None, None 331 | -------------------------------------------------------------------------------- /sad/utils/instance.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from scipy.optimize import linear_sum_assignment 7 | 8 | from sad.utils.geometry import mat2pose_vec, pose_vec2mat, warp_features 9 | 10 | 11 | # set ignore index to 0 for vis 12 | def convert_instance_mask_to_center_and_offset_label(instance_img, future_egomotion, num_instances, ignore_index=255, 13 | subtract_egomotion=True, sigma=3, spatial_extent=None): 14 | seq_len, h, w = instance_img.shape 15 | center_label = torch.zeros(seq_len, 1, h, w) 16 | offset_label = ignore_index * torch.ones(seq_len, 2, h, w) 17 | future_displacement_label = ignore_index * torch.ones(seq_len, 2, h, w) 18 | # x is vertical displacement, y is horizontal displacement 19 | x, y = torch.meshgrid(torch.arange(h, dtype=torch.float), torch.arange(w, dtype=torch.float)) 20 | 21 | if subtract_egomotion: 22 | future_egomotion_inv = mat2pose_vec(pose_vec2mat(future_egomotion).inverse()) 23 | 24 | # Compute warped instance segmentation 25 | warped_instance_seg = {} 26 | for t in range(1, seq_len): 27 | warped_inst_t = warp_features(instance_img[t].unsqueeze(0).unsqueeze(1).float(), 28 | future_egomotion_inv[t - 1].unsqueeze(0), mode='nearest', 29 | spatial_extent=spatial_extent) 30 | warped_instance_seg[t] = warped_inst_t[0, 0] 31 | 32 | # Ignore id 0 which is the background 33 | for instance_id in range(1, num_instances+1): 34 | prev_xc = None 35 | prev_yc = None 36 | prev_mask = None 37 | for t in range(seq_len): 38 | instance_mask = (instance_img[t] == instance_id) 39 | if instance_mask.sum() == 0: 40 | # this instance is not in this frame 41 | prev_xc = None 42 | prev_yc = None 43 | prev_mask = None 44 | continue 45 | 46 | xc = x[instance_mask].mean().round().long() 47 | yc = y[instance_mask].mean().round().long() 48 | 49 | off_x = xc - x 50 | off_y = yc - y 51 | g = torch.exp(-(off_x ** 2 + off_y ** 2) / sigma ** 2) 52 | center_label[t, 0] = torch.maximum(center_label[t, 0], g) 53 | offset_label[t, 0, instance_mask] = off_x[instance_mask] 54 | offset_label[t, 1, instance_mask] = off_y[instance_mask] 55 | 56 | if prev_xc is not None: 57 | # old method 58 | # cur_pt = torch.stack((xc, yc)).unsqueeze(0).float() 59 | # if subtract_egomotion: 60 | # cur_pt = warp_points(cur_pt, future_egomotion_inv[t - 1]) 61 | # cur_pt = cur_pt.squeeze(0) 62 | 63 | warped_instance_mask = warped_instance_seg[t] == instance_id 64 | if warped_instance_mask.sum() > 0: 65 | warped_xc = x[warped_instance_mask].mean().round() 
66 | warped_yc = y[warped_instance_mask].mean().round() 67 | 68 | delta_x = warped_xc - prev_xc 69 | delta_y = warped_yc - prev_yc 70 | future_displacement_label[t - 1, 0, prev_mask] = delta_x 71 | future_displacement_label[t - 1, 1, prev_mask] = delta_y 72 | 73 | prev_xc = xc 74 | prev_yc = yc 75 | prev_mask = instance_mask 76 | 77 | return center_label, offset_label, future_displacement_label 78 | 79 | 80 | def find_instance_centers(center_prediction: torch.Tensor, conf_threshold: float = 0.1, nms_kernel_size: float = 3): 81 | assert len(center_prediction.shape) == 3 82 | center_prediction = F.threshold(center_prediction, threshold=conf_threshold, value=-1) 83 | 84 | nms_padding = (nms_kernel_size - 1) // 2 85 | maxpooled_center_prediction = F.max_pool2d( 86 | center_prediction, kernel_size=nms_kernel_size, stride=1, padding=nms_padding 87 | ) 88 | 89 | # Filter all elements that are not the maximum (i.e. the center of the heatmap instance) 90 | center_prediction[center_prediction != maxpooled_center_prediction] = -1 91 | return torch.nonzero(center_prediction > 0)[:, 1:] 92 | 93 | 94 | def group_pixels(centers: torch.Tensor, offset_predictions: torch.Tensor) -> torch.Tensor: 95 | width, height = offset_predictions.shape[-2:] 96 | x_grid = ( 97 | torch.arange(width, dtype=offset_predictions.dtype, device=offset_predictions.device) 98 | .view(1, width, 1) 99 | .repeat(1, 1, height) 100 | ) 101 | y_grid = ( 102 | torch.arange(height, dtype=offset_predictions.dtype, device=offset_predictions.device) 103 | .view(1, 1, height) 104 | .repeat(1, width, 1) 105 | ) 106 | pixel_grid = torch.cat((x_grid, y_grid), dim=0) 107 | center_locations = (pixel_grid + offset_predictions).view(2, width * height, 1).permute(2, 1, 0) 108 | centers = centers.view(-1, 1, 2) 109 | 110 | distances = torch.norm(centers - center_locations, dim=-1) 111 | 112 | instance_id = torch.argmin(distances, dim=0).reshape(1, width, height) + 1 113 | return instance_id 114 | 115 | 116 | def get_instance_segmentation_and_centers( 117 | center_predictions: torch.Tensor, 118 | offset_predictions: torch.Tensor, 119 | foreground_mask: torch.Tensor, 120 | conf_threshold: float = 0.1, 121 | nms_kernel_size: float = 3, 122 | max_n_instance_centers: int = 100, 123 | ) -> Tuple[torch.Tensor, torch.Tensor]: 124 | width, height = center_predictions.shape[-2:] 125 | center_predictions = center_predictions.view(1, width, height) 126 | offset_predictions = offset_predictions.view(2, width, height) 127 | foreground_mask = foreground_mask.view(1, width, height) 128 | 129 | centers = find_instance_centers(center_predictions, conf_threshold=conf_threshold, nms_kernel_size=nms_kernel_size) 130 | if not len(centers): 131 | return torch.zeros(center_predictions.shape, dtype=torch.int64, device=center_predictions.device), \ 132 | torch.zeros((0, 2), device=centers.device) 133 | 134 | if len(centers) > max_n_instance_centers: 135 | # print(f'There are a lot of detected instance centers: {centers.shape}') 136 | centers = centers[:max_n_instance_centers].clone() 137 | 138 | instance_ids = group_pixels(centers, offset_predictions) 139 | instance_seg = (instance_ids * foreground_mask.float()).long() 140 | 141 | # Make the indices of instance_seg consecutive 142 | instance_seg = make_instance_seg_consecutive(instance_seg) 143 | 144 | return instance_seg.long(), centers 145 | 146 | 147 | def update_instance_ids(instance_seg, old_ids, new_ids): 148 | """ 149 | Parameters 150 | ---------- 151 | instance_seg: torch.Tensor arbitrary shape 152 | old_ids: 1D 
tensor containing the list of old ids, must be all present in instance_seg. 153 | new_ids: 1D tensor with the new ids, aligned with old_ids 154 | 155 | Returns 156 | new_instance_seg: torch.Tensor same shape as instance_seg with new ids 157 | """ 158 | indices = torch.arange(old_ids.max() + 1, device=instance_seg.device) 159 | for old_id, new_id in zip(old_ids, new_ids): 160 | indices[old_id] = new_id 161 | 162 | return indices[instance_seg].long() 163 | 164 | 165 | def make_instance_seg_consecutive(instance_seg): 166 | # Make the indices of instance_seg consecutive 167 | unique_ids = torch.unique(instance_seg) 168 | new_ids = torch.arange(len(unique_ids), device=instance_seg.device) 169 | instance_seg = update_instance_ids(instance_seg, unique_ids, new_ids) 170 | return instance_seg 171 | 172 | 173 | def make_instance_id_temporally_consistent(pred_inst, future_flow, matching_threshold=3.0): 174 | """ 175 | Parameters 176 | ---------- 177 | pred_inst: torch.Tensor (1, seq_len, h, w) 178 | future_flow: torch.Tensor(1, seq_len, 2, h, w) 179 | matching_threshold: distance threshold for a match to be valid. 180 | 181 | Returns 182 | ------- 183 | consistent_instance_seg: torch.Tensor(1, seq_len, h, w) 184 | 185 | 1. time t. Loop over all detected instances. Use flow to compute new centers at time t+1. 186 | 2. Store those centers 187 | 3. time t+1. Re-identify instances by comparing position of actual centers, and flow-warped centers. 188 | Make the labels at t+1 consistent with the matching 189 | 4. Repeat 190 | """ 191 | assert pred_inst.shape[0] == 1, 'Assumes batch size = 1' 192 | 193 | # Initialise instance segmentations with prediction corresponding to the present 194 | consistent_instance_seg = [pred_inst[0, 0]] 195 | largest_instance_id = consistent_instance_seg[0].max().item() 196 | 197 | _, seq_len, h, w = pred_inst.shape 198 | device = pred_inst.device 199 | for t in range(seq_len - 1): 200 | # Compute predicted future instance means 201 | grid = torch.stack(torch.meshgrid( 202 | torch.arange(h, dtype=torch.float, device=device), torch.arange(w, dtype=torch.float, device=device) 203 | )) 204 | 205 | # Add future flow 206 | grid = grid + future_flow[0, t] 207 | warped_centers = [] 208 | # Go through all ids, except the background 209 | t_instance_ids = torch.unique(consistent_instance_seg[-1])[1:].cpu().numpy() 210 | 211 | if len(t_instance_ids) == 0: 212 | # No instance so nothing to update 213 | consistent_instance_seg.append(pred_inst[0, t + 1]) 214 | continue 215 | 216 | for instance_id in t_instance_ids: 217 | instance_mask = (consistent_instance_seg[-1] == instance_id) 218 | warped_centers.append(grid[:, instance_mask].mean(dim=1)) 219 | warped_centers = torch.stack(warped_centers) 220 | 221 | # Compute actual future instance means 222 | centers = [] 223 | grid = torch.stack(torch.meshgrid( 224 | torch.arange(h, dtype=torch.float, device=device), torch.arange(w, dtype=torch.float, device=device) 225 | )) 226 | n_instances = int(pred_inst[0, t + 1].max().item()) 227 | 228 | if n_instances == 0: 229 | # No instance, so nothing to update. 
230 | consistent_instance_seg.append(pred_inst[0, t + 1]) 231 | continue 232 | 233 | for instance_id in range(1, n_instances + 1): 234 | instance_mask = (pred_inst[0, t + 1] == instance_id) 235 | centers.append(grid[:, instance_mask].mean(dim=1)) 236 | centers = torch.stack(centers) 237 | 238 | # Compute distance matrix between warped centers and actual centers 239 | distances = torch.norm(centers.unsqueeze(0) - warped_centers.unsqueeze(1), dim=-1).cpu().numpy() 240 | # outputs (row, col) with row: index in frame t, col: index in frame t+1 241 | # the missing ids in col must be added (correspond to new instances) 242 | ids_t, ids_t_one = linear_sum_assignment(distances) 243 | matching_distances = distances[ids_t, ids_t_one] 244 | # Offset by one as id=0 is the background 245 | ids_t += 1 246 | ids_t_one += 1 247 | 248 | # swap ids_t with real ids. as those ids correspond to the position in the distance matrix. 249 | id_mapping = dict(zip(np.arange(1, len(t_instance_ids) + 1), t_instance_ids)) 250 | ids_t = np.vectorize(id_mapping.__getitem__, otypes=[np.int64])(ids_t) 251 | 252 | # Filter low quality match 253 | ids_t = ids_t[matching_distances < matching_threshold] 254 | ids_t_one = ids_t_one[matching_distances < matching_threshold] 255 | 256 | # Elements that are in t+1, but weren't matched 257 | remaining_ids = set(torch.unique(pred_inst[0, t + 1]).cpu().numpy()).difference(set(ids_t_one)) 258 | # remove background 259 | remaining_ids.remove(0) 260 | #  Set remaining_ids to a new unique id 261 | for remaining_id in list(remaining_ids): 262 | largest_instance_id += 1 263 | ids_t = np.append(ids_t, largest_instance_id) 264 | ids_t_one = np.append(ids_t_one, remaining_id) 265 | 266 | consistent_instance_seg.append(update_instance_ids(pred_inst[0, t + 1], old_ids=ids_t_one, new_ids=ids_t)) 267 | 268 | consistent_instance_seg = torch.stack(consistent_instance_seg).unsqueeze(0) 269 | return consistent_instance_seg 270 | 271 | 272 | def predict_instance_segmentation_and_trajectories( 273 | output, compute_matched_centers=False, make_consistent=True, vehicles_id=1, 274 | ): 275 | preds = output['segmentation'].detach() 276 | preds = torch.argmax(preds, dim=2, keepdim=True) 277 | foreground_masks = preds.squeeze(2) == vehicles_id 278 | 279 | batch_size, seq_len = preds.shape[:2] 280 | pred_inst = [] 281 | for b in range(batch_size): 282 | pred_inst_batch = [] 283 | for t in range(seq_len): 284 | pred_instance_t, _ = get_instance_segmentation_and_centers( 285 | output['instance_center'][b, t].detach(), 286 | output['instance_offset'][b, t].detach(), 287 | foreground_masks[b, t].detach() 288 | ) 289 | pred_inst_batch.append(pred_instance_t) 290 | pred_inst.append(torch.stack(pred_inst_batch, dim=0)) 291 | 292 | pred_inst = torch.stack(pred_inst).squeeze(2) 293 | 294 | if make_consistent: 295 | if output['instance_flow'] is None: 296 | print('Using zero flow because instance_future_output is None') 297 | output['instance_flow'] = torch.zeros_like(output['instance_offset']) 298 | consistent_instance_seg = [] 299 | for b in range(batch_size): 300 | consistent_instance_seg.append( 301 | make_instance_id_temporally_consistent(pred_inst[b:b+1], 302 | output['instance_flow'][b:b+1].detach()) 303 | ) 304 | consistent_instance_seg = torch.cat(consistent_instance_seg, dim=0) 305 | else: 306 | consistent_instance_seg = pred_inst 307 | 308 | if compute_matched_centers: 309 | assert batch_size == 1 310 | # Generate trajectories 311 | matched_centers = {} 312 | _, seq_len, h, w = consistent_instance_seg.shape 
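# The meshgrid below yields grid[0] = row (h) indices and grid[1] = column (w) indices;
# averaging them over each instance mask gives per-frame instance centres in (row, col),
# which are flipped to (x, y) by the trailing [:, ::-1] before matched_centers is returned.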
313 | grid = torch.stack(torch.meshgrid( 314 | torch.arange(h, dtype=torch.float, device=preds.device), 315 | torch.arange(w, dtype=torch.float, device=preds.device) 316 | )) 317 | 318 | for instance_id in torch.unique(consistent_instance_seg[0, 0])[1:].cpu().numpy(): 319 | for t in range(seq_len): 320 | instance_mask = consistent_instance_seg[0, t] == instance_id 321 | if instance_mask.sum() > 0: 322 | matched_centers[instance_id] = matched_centers.get(instance_id, []) + [ 323 | grid[:, instance_mask].mean(dim=-1)] 324 | 325 | for key, value in matched_centers.items(): 326 | matched_centers[key] = torch.stack(value).cpu().numpy()[:, ::-1] 327 | 328 | return consistent_instance_seg, matched_centers 329 | 330 | return consistent_instance_seg 331 | 332 | 333 | --------------------------------------------------------------------------------
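A minimal usage sketch of the pose utilities in sad/utils/geometry.py. The pose values, tensor sizes and the 50 m spatial extent are illustrative assumptions chosen only to show how pose_vec2mat, mat2pose_vec, invert_pose_matrix and warp_features fit together; they are not values taken from the training configs.

import torch
from sad.utils.geometry import pose_vec2mat, mat2pose_vec, invert_pose_matrix, warp_features

# One 6-DoF pose (tx, ty, tz, rx, ry, rz): 0.5 m forward motion and a 0.05 rad yaw change.
pose_vec = torch.tensor([[0.5, 0.0, 0.0, 0.0, 0.0, 0.05]])

# 6-DoF vector -> 4x4 matrix and back again (round trip).
pose_mat = pose_vec2mat(pose_vec)   # [1, 4, 4]
recovered = mat2pose_vec(pose_mat)  # [1, 6]
assert torch.allclose(recovered, pose_vec, atol=1e-5)

# Composing a pose with its inverse gives the identity.
identity = pose_mat @ invert_pose_matrix(pose_mat)
assert torch.allclose(identity, torch.eye(4).expand(1, 4, 4), atol=1e-5)

# Warp a dummy bird's-eye-view feature map by the x-y translation and yaw of the pose.
# spatial_extent is half the metric size of the BEV grid along x and y (50 m is an arbitrary choice here).
bev_features = torch.randn(1, 64, 200, 200)
warped = warp_features(bev_features, pose_vec, mode='bilinear', spatial_extent=(50.0, 50.0))
print(warped.shape)  # torch.Size([1, 64, 200, 200])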