├── scripts
│   ├── eval_plan.sh
│   ├── train_perceive.sh
│   ├── train_plan.sh
│   ├── train_prediction.sh
│   └── train_perceive_load.sh
├── maps
│   ├── Town01.h5
│   ├── Town02.h5
│   ├── Town03.h5
│   ├── Town04.h5
│   ├── Town05.h5
│   ├── Town06.h5
│   ├── Town07.h5
│   ├── Town10HD.h5
│   └── hdmap_generate.py
├── __assets__
│   ├── logo.png
│   └── overview.png
├── sad
│   ├── models
│   │   ├── module
│   │   │   ├── __init__.py
│   │   │   ├── downsample.py
│   │   │   ├── sps.py
│   │   │   ├── ms_conv.py
│   │   │   ├── MS_ResNet.py
│   │   │   └── Real_MS_ResNet.py
│   │   ├── temporal_model.py
│   │   ├── distributions_snn.py
│   │   ├── future_prediction_snn.py
│   │   ├── encoder_snn_merge.py
│   │   ├── planning_model.py
│   │   └── decoder_snn_ms.py
│   ├── configs
│   │   └── nuscenes
│   │       ├── Perception.yml
│   │       ├── Prediction.yml
│   │       └── Planning.yml
│   ├── utils
│   │   ├── network.py
│   │   ├── sampler.py
│   │   ├── geometry.py
│   │   └── instance.py
│   ├── datas
│   │   └── dataloaders.py
│   ├── config.py
│   └── losses.py
├── environment.yml
├── train.py
├── README.md
├── evaluate.py
└── LICENSE
/maps/Town01.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town01.h5
--------------------------------------------------------------------------------
/maps/Town02.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town02.h5
--------------------------------------------------------------------------------
/maps/Town03.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town03.h5
--------------------------------------------------------------------------------
/maps/Town04.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town04.h5
--------------------------------------------------------------------------------
/maps/Town05.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town05.h5
--------------------------------------------------------------------------------
/maps/Town06.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town06.h5
--------------------------------------------------------------------------------
/maps/Town07.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town07.h5
--------------------------------------------------------------------------------
/maps/Town10HD.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/maps/Town10HD.h5
--------------------------------------------------------------------------------
/__assets__/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/__assets__/logo.png
--------------------------------------------------------------------------------
/__assets__/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ridgerchu/SAD/HEAD/__assets__/overview.png
--------------------------------------------------------------------------------
/scripts/eval_plan.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo "checkpoint: $1"
3 | echo "dataroot: $2"
4 | python evaluate.py --checkpoint "$1" --dataroot "$2"
--------------------------------------------------------------------------------
/scripts/train_perceive.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo "configs: $1"
3 | echo "DATASET.DATAROOT: $2"
4 | python train.py --config "$1" DATASET.DATAROOT "$2" DATASET.MAP_FOLDER "$2"
--------------------------------------------------------------------------------
/scripts/train_plan.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo "configs: $1"
3 | echo "DATASET.DATAROOT: $2"
4 | echo "PRETRAINED.PATH: $3"
5 | python train.py --config "$1" DATASET.DATAROOT "$2" DATASET.MAP_FOLDER "$2" PRETRAINED.PATH "$3"
--------------------------------------------------------------------------------
/scripts/train_prediction.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo "configs: $1"
3 | echo "DATASET.DATAROOT: $2"
4 | echo "PRETRAINED.PATH: $3"
5 | python train.py --config "$1" DATASET.DATAROOT "$2" DATASET.MAP_FOLDER "$2" PRETRAINED.PATH "$3"
--------------------------------------------------------------------------------
/scripts/train_perceive_load.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo "configs: $1"
3 | echo "DATASET.DATAROOT: $2"
4 | echo "PRETRAINED.PATH: $3"
5 | python train.py --config "$1" DATASET.DATAROOT "$2" DATASET.MAP_FOLDER "$2" PRETRAINED.PATH "$3"
--------------------------------------------------------------------------------
/sad/models/module/__init__.py:
--------------------------------------------------------------------------------
1 | from .ms_conv import MS_Block_Conv
2 | from .sps import MS_SPS
3 | from .downsample import Sps_downsampling
4 |
5 |
6 | __all__ = [
7 | "MS_SPS",
8 | "MS_Block_Conv",
9 | "Sps_downsampling"
10 | ]
11 |
--------------------------------------------------------------------------------
/sad/models/temporal_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from sad.layers.temporal import Bottleneck3D, TemporalBlock
5 | from sad.layers.convolutions import ConvBlock, Bottleneck, SpikingDeepLabHead
6 |
7 |
8 | class TemporalModelIdentity(nn.Module):
9 | def __init__(self, in_channels, receptive_field):
10 | super().__init__()
11 | self.receptive_field = receptive_field
12 | self.out_channels = in_channels
13 |
14 | def forward(self, x):
15 | return x
--------------------------------------------------------------------------------
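For reference, `TemporalModelIdentity` is a pass-through placeholder for the temporal model: it records the receptive field and keeps `out_channels` equal to `in_channels`. A minimal sketch (shapes are illustrative):

```python
import torch
from sad.models.temporal_model import TemporalModelIdentity

# Batch of 2, temporal receptive field of 3, 64-channel 50x50 BEV features.
tm = TemporalModelIdentity(in_channels=64, receptive_field=3)
x = torch.randn(2, 3, 64, 50, 50)    # (b, t, c, h, w)
out = tm(x)                          # identity: input is returned unchanged
assert out.shape == x.shape and tm.out_channels == 64
```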
/environment.yml:
--------------------------------------------------------------------------------
1 | name: sad
2 | channels:
3 | - defaults
4 | - conda-forge
5 | - pytorch
6 | dependencies:
7 | - python=3.7.10
8 | - pytorch=1.10.2
9 | - torchvision=0.11.3
10 | - cudatoolkit=11.3
11 | - numpy=1.19.2
12 | - scipy=1.5.2
13 | - pillow=8.0.1
14 | - tqdm=4.50.2
15 | - scikit-image=0.18.1
16 | - pytorch-lightning=1.2.5
17 | - efficientnet-pytorch=0.7.0
18 | - fvcore=0.1.2.post20201122
19 | - pip=21.0.1
20 | - timm
21 | - pip:
22 | - nuscenes-devkit==1.1.0
23 | - lyft-dataset-sdk==0.0.8
24 | - opencv-python==4.5.1.48
25 | - moviepy==1.0.3
26 | - shapely==1.8.0
27 | - spikingjelly==0.0.0.0.12
28 | - timm==0.6.12
29 | - torchinfo
30 | - wandb
31 |
--------------------------------------------------------------------------------
/sad/configs/nuscenes/Perception.yml:
--------------------------------------------------------------------------------
1 | TAG: 'Perception'
2 |
3 | GPUS: [0]
4 |
5 | BATCHSIZE: 3
6 | PRECISION: 16
7 | EPOCHS: 20
8 |
9 | N_WORKERS: 3
10 |
11 | DATASET:
12 | VERSION: 'trainval'
13 |
14 | TIME_RECEPTIVE_FIELD: 3
15 | N_FUTURE_FRAMES: 0
16 |
17 | LIFT:
18 | GT_DEPTH: False
19 |
20 | MODEL:
21 | ENCODER:
22 | NAME: "/vol4/Coin-MLP-back/Pure-MLP-SNN/18M-with-downsample-conv/checkpoint-308.pth.tar"
23 | TEMPORAL_MODEL:
24 | NAME: 'identity'
25 | INPUT_EGOPOSE: True
26 | BN_MOMENTUM: 0.05
27 |
28 | SEMANTIC_SEG:
29 | PEDESTRIAN:
30 | ENABLED: True
31 | HDMAP:
32 | ENABLED: True
33 |
34 | INSTANCE_SEG:
35 | ENABLED: False
36 |
37 | INSTANCE_FLOW:
38 | ENABLED: False
39 |
40 | PROBABILISTIC:
41 | ENABLED: False
42 |
43 | PLANNING:
44 | ENABLED: False
45 |
46 | OPTIMIZER:
47 | LR: 1e-3
48 |
49 |
50 |
--------------------------------------------------------------------------------
/sad/configs/nuscenes/Prediction.yml:
--------------------------------------------------------------------------------
1 | TAG: 'Prediction'
2 |
3 | GPUS: [0, 1, 2, 3]
4 |
5 | BATCHSIZE: 3
6 | PRECISION: 16
7 | EPOCHS: 20
8 |
9 | N_WORKERS: 3
10 |
11 | DATASET:
12 | VERSION: 'trainval'
13 |
14 | TIME_RECEPTIVE_FIELD: 3
15 | N_FUTURE_FRAMES: 4
16 |
17 | LIFT:
18 | GT_DEPTH: False
19 |
20 | MODEL:
21 | ENCODER:
22 | NAME: "/vol4/Coin-MLP-back/Pure-MLP-SNN/18M-with-downsample-conv/checkpoint-308.pth.tar"
23 | TEMPORAL_MODEL:
24 | NAME: 'identity'
25 | INPUT_EGOPOSE: True
26 | BN_MOMENTUM: 0.05
27 |
28 | SEMANTIC_SEG:
29 | PEDESTRIAN:
30 | ENABLED: False
31 | HDMAP:
32 | ENABLED: False
33 |
34 | INSTANCE_FLOW:
35 | ENABLED: True
36 |
37 | PROBABILISTIC:
38 | ENABLED: True
39 | METHOD: 'GAUSSIAN'
40 |
41 | PLANNING:
42 | ENABLED: False
43 |
44 | FUTURE_DISCOUNT: 0.95
45 |
46 | OPTIMIZER:
47 | LR: 2e-4
48 |
49 |
50 | PRETRAINED:
51 | LOAD_WEIGHTS: True
52 |
--------------------------------------------------------------------------------
/sad/configs/nuscenes/Planning.yml:
--------------------------------------------------------------------------------
1 | TAG: 'Planning'
2 |
3 | GPUS: [0]
4 |
5 | BATCHSIZE: 1
6 | PRECISION: 16
7 | EPOCHS: 20
8 |
9 | N_WORKERS: 4
10 |
11 | DATASET:
12 | VERSION: 'trainval'
13 |
14 | TIME_RECEPTIVE_FIELD: 3
15 | N_FUTURE_FRAMES: 6
16 |
17 | LIFT:
18 | GT_DEPTH: False
19 |
20 | MODEL:
21 | ENCODER:
22 | NAME: "/vol4/Coin-MLP-back/Pure-MLP-SNN/18M-with-downsample-conv/checkpoint-308.pth.tar"
23 | TEMPORAL_MODEL:
24 | NAME: 'identity'
25 | INPUT_EGOPOSE: True
26 | BN_MOMENTUM: 0.05
27 |
28 | SEMANTIC_SEG:
29 | PEDESTRIAN:
30 | ENABLED: True
31 | HDMAP:
32 | ENABLED: True
33 |
34 | INSTANCE_SEG:
35 | ENABLED: False
36 |
37 | INSTANCE_FLOW:
38 | ENABLED: False
39 |
40 | PROBABILISTIC:
41 | ENABLED: True
42 | METHOD: 'GAUSSIAN'
43 |
44 | PLANNING:
45 | ENABLED: True
46 | SAMPLE_NUM: 1800
47 |
48 | FUTURE_DISCOUNT: 0.95
49 |
50 | OPTIMIZER:
51 | LR: 1e-4
52 |
53 | COST_FUNCTION:
54 | SAFETY: 1.
55 | HEADWAY: 1.
56 | LRDIVIDER: 10.
57 | COMFORT: 0.1
58 | PROGRESS: 0.5
59 | VOLUME: 100.
60 |
61 | PRETRAINED:
62 | LOAD_WEIGHTS: True
63 |
--------------------------------------------------------------------------------
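These YAML files only override the defaults defined in `sad/config.py` (an fvcore `CfgNode`); the shell scripts additionally override keys such as `DATASET.DATAROOT` from the command line. A minimal sketch of that merge, using the fvcore API directly since the repository's `get_parser`/`get_cfg` helpers are not shown here:

```python
from sad.config import _C  # default configuration tree (CfgNode)

cfg = _C.clone()
cfg.merge_from_file('sad/configs/nuscenes/Perception.yml')
# Key/value overrides, mirroring what scripts/train_perceive.sh passes to train.py:
cfg.merge_from_list(['DATASET.DATAROOT', '/data/Nuscenes',
                     'DATASET.MAP_FOLDER', '/data/Nuscenes'])
print(cfg.TAG, cfg.BATCHSIZE, cfg.MODEL.ENCODER.NAME)
```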
/sad/utils/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 |
5 | def pack_sequence_dim(x):
6 | b, s = x.shape[:2]
7 | return x.view(b * s, *x.shape[2:])
8 |
9 |
10 | def unpack_sequence_dim(x, b, s):
11 | return x.view(b, s, *x.shape[1:])
12 |
13 |
14 | def preprocess_batch(batch, device, unsqueeze=False):
15 | for key, value in batch.items():
16 | if torch.is_tensor(value):
17 | batch[key] = value.to(device)
18 | if unsqueeze:
19 | batch[key] = batch[key].unsqueeze(0)
20 |
21 |
22 | def set_module_grad(module, requires_grad=False):
23 | for p in module.parameters():
24 | p.requires_grad = requires_grad
25 |
26 |
27 | def set_bn_momentum(model, momentum=0.1):
28 | for m in model.modules():
29 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
30 | m.momentum = momentum
31 |
32 |
33 | class NormalizeInverse(torchvision.transforms.Normalize):
34 | # https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/8
35 | def __init__(self, mean, std):
36 | mean = torch.as_tensor(mean)
37 | std = torch.as_tensor(std)
38 | std_inv = 1 / (std + 1e-7)
39 | mean_inv = -mean * std_inv
40 | super().__init__(mean=mean_inv, std=std_inv)
41 |
42 | def __call__(self, tensor):
43 | return super().__call__(tensor.clone())
44 |
--------------------------------------------------------------------------------
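For reference, a minimal sketch of how `pack_sequence_dim`/`unpack_sequence_dim` are typically used: fold the time dimension into the batch dimension, apply a per-frame module, then restore the sequence shape (shapes are illustrative):

```python
import torch
import torch.nn as nn
from sad.utils.network import pack_sequence_dim, unpack_sequence_dim

b, s, c, h, w = 2, 3, 64, 32, 32
x = torch.randn(b, s, c, h, w)        # (batch, sequence, channels, height, width)

conv = nn.Conv2d(c, c, kernel_size=3, padding=1)
y = pack_sequence_dim(x)              # (b * s, c, h, w)
y = conv(y)                           # any per-frame 2D module
y = unpack_sequence_dim(y, b, s)      # back to (b, s, c, h, w)
assert y.shape == x.shape
```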
/sad/datas/dataloaders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | from nuscenes.nuscenes import NuScenes
4 | from sad.datas.NuscenesData import FuturePredictionDataset
5 | from sad.datas.CarlaData import CarlaDataset
6 |
7 |
8 | def prepare_dataloaders(cfg, return_dataset=False):
9 | if cfg.DATASET.NAME == 'nuscenes':
10 | # 28130 train and 6019 val
11 | dataroot = cfg.DATASET.DATAROOT
12 | nusc = NuScenes(version='v1.0-{}'.format(cfg.DATASET.VERSION), dataroot=dataroot, verbose=False)
13 | traindata = FuturePredictionDataset(nusc, 0, cfg)
14 | valdata = FuturePredictionDataset(nusc, 1, cfg)
15 |
16 | if cfg.DATASET.VERSION == 'mini':
17 | traindata.indices = traindata.indices[:10]
18 | # valdata.indices = valdata.indices[:10]
19 |
20 | nworkers = cfg.N_WORKERS
21 | trainloader = torch.utils.data.DataLoader(
22 | traindata, batch_size=cfg.BATCHSIZE, shuffle=True, num_workers=nworkers, pin_memory=True, drop_last=True
23 | )
24 | valloader = torch.utils.data.DataLoader(
25 | valdata, batch_size=cfg.BATCHSIZE, shuffle=False, num_workers=nworkers, pin_memory=True, drop_last=False)
26 | elif cfg.DATASET.NAME == 'carla':
27 | dataroot = cfg.DATASET.DATAROOT
28 | traindata = CarlaDataset(dataroot, True, cfg)
29 | valdata = CarlaDataset(dataroot, False, cfg)
30 | nworkers = cfg.N_WORKERS
31 | trainloader = torch.utils.data.DataLoader(
32 | traindata, batch_size=cfg.BATCHSIZE, shuffle=True, num_workers=nworkers, pin_memory=True, drop_last=True
33 | )
34 | valloader = torch.utils.data.DataLoader(
35 | valdata, batch_size=cfg.BATCHSIZE, shuffle=False, num_workers=nworkers, pin_memory=True, drop_last=False)
36 | else:
37 | raise NotImplementedError
38 |
39 | if return_dataset:
40 | return trainloader, valloader, traindata, valdata
41 | else:
42 | return trainloader, valloader
--------------------------------------------------------------------------------
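A minimal usage sketch, assuming the nuScenes devkit and a local copy of the `v1.0-mini` split (the dataroot path is a placeholder):

```python
from sad.config import _C
from sad.datas.dataloaders import prepare_dataloaders

cfg = _C.clone()
cfg.DATASET.NAME = 'nuscenes'
cfg.DATASET.DATAROOT = '/data/Nuscenes'   # placeholder path to the nuScenes root
cfg.DATASET.VERSION = 'mini'              # small split for a quick check
trainloader, valloader = prepare_dataloaders(cfg)
batch = next(iter(trainloader))           # one batch of training tensors
```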
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import socket
4 | import torch
5 | import pytorch_lightning as pl
6 | from pytorch_lightning.plugins import DDPPlugin
7 | from pytorch_lightning.callbacks import ModelCheckpoint
8 | from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
9 | import wandb
10 |
11 | from sad.config import get_parser, get_cfg
12 | from sad.datas.dataloaders import prepare_dataloaders
13 | from sad.trainer import TrainingModule
14 |
15 |
16 |
17 | def main():
18 | args = get_parser().parse_args()
19 | cfg = get_cfg(args)
20 | # print(cfg)
21 | trainloader, valloader = prepare_dataloaders(cfg)
22 | print("load data!!!")
23 | model = TrainingModule(cfg.convert_to_dict())
24 |
25 | if cfg.PRETRAINED.LOAD_WEIGHTS:
26 | # Load single-image instance segmentation model.
27 | pretrained_model_weights = torch.load(
28 | cfg.PRETRAINED.PATH, map_location='cpu'
29 | )['state_dict']
30 | state = model.state_dict()
31 | pretrained_model_weights = {k: v for k, v in pretrained_model_weights.items() if k in state and 'decoder' not in k}
32 | model.load_state_dict(pretrained_model_weights, strict=False)
33 | print(f'Loaded single-image model weights from {cfg.PRETRAINED.PATH}')
34 |
35 | save_dir = os.path.join(
36 | cfg.LOG_DIR, time.strftime('%d%B%Yat%H_%M_%S%Z') + '_' + socket.gethostname() + '_' + cfg.TAG
37 | )
38 | # tb_logger = pl.loggers.TensorBoardLogger(save_dir=save_dir)
39 |     wandb_logger = WandbLogger(name='sad-perception', project='sad', entity="ncg_ucsc", config=cfg, log_model="all")
40 |
41 | # if os.environ.get('LOCAL_RANK', '0') == '0':
42 | # wandb.init(name='sad-preception', project='sad', entity="snn_ad", config=cfg)
43 |
44 | checkpoint_callback = ModelCheckpoint(
45 | monitor='step_val_seg_iou_dynamic',
46 | save_top_k=-1,
47 | save_last=True,
48 | period=1,
49 | mode='min'
50 | )
51 | trainer = pl.Trainer(
52 | gpus=cfg.GPUS,
53 | accelerator='ddp',
54 | precision=cfg.PRECISION,
55 | sync_batchnorm=True,
56 | gradient_clip_val=cfg.GRAD_NORM_CLIP,
57 | max_epochs=cfg.EPOCHS,
58 | weights_summary='full',
59 | logger=wandb_logger,
60 | log_every_n_steps=cfg.LOGGING_INTERVAL,
61 | plugins=DDPPlugin(find_unused_parameters=False),
62 | profiler='simple',
63 | callbacks=[checkpoint_callback]
64 | )
65 |
66 | trainer.fit(model, trainloader, valloader)
67 |
68 |
69 | if __name__ == "__main__":
70 | main()
71 |
--------------------------------------------------------------------------------
/sad/models/distributions_snn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from sad.layers.convolutions import SpikingBottleneck, Bottleneck
5 |
6 |
7 | class DistributionModule(nn.Module):
8 | """
9 | A convolutional net that parametrises a diagonal Gaussian distribution.
10 | """
11 |
12 | def __init__(
13 | self, in_channels, latent_dim, method="GAUSSIAN"):
14 | super().__init__()
15 | self.compress_dim = in_channels // 2
16 | self.latent_dim = latent_dim
17 | self.method = method
18 |
19 | if method == 'GAUSSIAN':
20 | self.encoder = DistributionEncoder(in_channels, self.compress_dim)
21 | self.decoder = nn.Sequential(
22 | nn.AdaptiveAvgPool2d(1), nn.Conv2d(self.compress_dim, out_channels=2 * self.latent_dim, kernel_size=1)
23 | )
24 | elif method == 'MIXGAUSSIAN':
25 | self.encoder = DistributionEncoder(in_channels, self.compress_dim)
26 | self.decoder = nn.Sequential(
27 | nn.AdaptiveAvgPool2d(1), nn.Conv2d(self.compress_dim, out_channels=6 * self.latent_dim + 3, kernel_size=1)
28 | )
29 | elif method == 'BERNOULLI':
30 | self.encoder = nn.Sequential(
31 | Bottleneck(in_channels, self.latent_dim)
32 | )
33 | self.decoder = nn.LogSigmoid()
34 | else:
35 | raise NotImplementedError
36 |
37 | def forward(self, s_t):
38 | b, s = s_t.shape[:2]
39 | assert s == 1
40 | encoding = self.encoder(s_t[:, 0])
41 |
42 | if self.method == 'GAUSSIAN':
43 | decoder = self.decoder(encoding).view(b, 1, 2 * self.latent_dim)
44 | elif self.method == 'MIXGAUSSIAN':
45 | decoder = self.decoder(encoding).view(b, 1, 6 * self.latent_dim + 3)
46 | elif self.method == 'BERNOULLI':
47 | decoder = self.decoder(encoding)
48 | else:
49 | raise NotImplementedError
50 |
51 | return decoder
52 |
53 |
54 | class DistributionEncoder(nn.Module):
55 | """Encodes s_t or (s_t, y_{t+1}, ..., y_{t+H}).
56 | """
57 | def __init__(self, in_channels, out_channels):
58 | super().__init__()
59 |
60 | self.model = nn.Sequential(
61 | SpikingBottleneck(in_channels, out_channels=out_channels, downsample=True),
62 | SpikingBottleneck(out_channels, out_channels=out_channels, downsample=True),
63 | SpikingBottleneck(out_channels, out_channels=out_channels, downsample=True),
64 | SpikingBottleneck(out_channels, out_channels=out_channels, downsample=True),
65 | )
66 |
67 | def forward(self, s_t):
68 | return self.model(s_t)
69 |
--------------------------------------------------------------------------------
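For reference, a minimal sketch of the module's input/output contract in the Gaussian case (channel and latent sizes are illustrative; `SpikingBottleneck` comes from `sad.layers.convolutions`, which is not shown above):

```python
import torch
from sad.models.distributions_snn import DistributionModule

dist = DistributionModule(in_channels=64, latent_dim=32, method='GAUSSIAN')
s_t = torch.randn(2, 1, 64, 48, 48)   # (b, s, c, h, w); the module asserts s == 1
out = dist(s_t)                       # (b, 1, 2 * latent_dim): Gaussian parameters (mean, log-sigma)
print(out.shape)                      # torch.Size([2, 1, 64])
```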
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ![SAD logo](__assets__/logo.png)
4 |
5 |
6 |
7 | # [NeurIPS 2024] Spiking Autonomous Driving (SAD): End-to-End Autonomous Driving with Spiking Neural Networks
8 |
9 | If you find our project useful, please give us a star ⭐ on GitHub!
10 |
11 | ## Introduction
12 |
13 | Spiking Autonomous Driving (SAD) is the first end-to-end autonomous driving system built entirely with Spiking Neural Networks (SNNs). It integrates perception, prediction, and planning modules into a unified neuromorphic framework.
14 |
15 | ### Key Features
16 |
17 | - End-to-end SNN architecture for autonomous driving, integrating perception, prediction, and planning
18 | - Perception module constructs spatio-temporal Bird's Eye View (BEV) representation from multi-view cameras using SNNs
19 | - Prediction module forecasts future states using a novel dual-pathway SNN
20 | - Planning module generates safe trajectories considering occupancy prediction, traffic rules, and ride comfort
21 |
22 | ## System Overview
23 |
24 | ![System overview](__assets__/overview.png)
25 |
26 |
27 | ## Modules
28 |
29 | ### Perception
30 |
31 | The perception module constructs a spatio-temporal BEV representation from multi-camera inputs. The encoder uses sequence repetition, while the decoder employs sequence alignment.
32 |
33 | ### Prediction
34 |
35 | The prediction module utilizes a dual-pathway SNN, where one pathway encodes past information and the other predicts future distributions. The outputs from both pathways are fused.
36 |
37 | ### Planning
38 |
39 | The planning module optimizes trajectories using Spiking Gated Recurrent Units (SGRUs), taking into account static occupancy, future predictions, comfort, and other factors.
40 |
41 | ## Get Started
42 |
43 | ### Setup
44 |
45 | ```
46 | conda env create -f environment.yml
47 | ```
48 |
49 | ### Training
50 |
51 | First, go to `/sad/configs` and modify the config files: set `MODEL.ENCODER.NAME` to the path of the pretrained encoder checkpoint we provide, available at https://huggingface.co/ridger/MLP-SNN/blob/main/model.pth.tar.
52 |
53 | ```
54 | # Perception module pretraining
55 | bash scripts/train_perceive.sh ${configs} ${dataroot}
56 |
57 | # Prediction module pretraining
58 | bash scripts/train_prediction.sh ${configs} ${dataroot} ${pretrained}
59 |
60 | # Entire model end-to-end training
61 | bash scripts/train_plan.sh ${configs} ${dataroot} ${pretrained}
62 | ```
63 | ## Citation
64 |
65 | If you find SAD useful in your work, please cite the following source:
66 |
67 | ```
68 | @inproceedings{
69 | zhu2024autonomous,
70 | title={Autonomous Driving with Spiking Neural Networks},
71 | author={Rui-Jie Zhu and Ziqing Wang and Leilani H. Gilpin and Jason Eshraghian},
72 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
73 | year={2024},
74 | url={https://openreview.net/forum?id=95VyH4VxN9}
75 | }
76 | ```
77 |
--------------------------------------------------------------------------------
/sad/models/future_prediction_snn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import copy
4 |
5 | from sad.layers.convolutions import SpikingDeepLabHead
6 | from sad.layers.temporal_snn import SpatialGRU, Dual_LIF_temporal_mixer, BiGRU
7 |
8 |
9 | from sad.models.module.MS_ResNet import multi_step_sew_resnet18
10 | from spikingjelly.clock_driven.neuron import (
11 | MultiStepParametricLIFNode,
12 | MultiStepLIFNode,
13 | )
14 |
15 |
16 | class FuturePrediction(nn.Module):
17 | def __init__(self, in_channels, latent_dim, n_future, mixture=True):
18 | super(FuturePrediction, self).__init__()
19 | # self.n_spatial_gru = n_gru_blocks
20 |
21 | backbone = multi_step_sew_resnet18(pretrained=False, multi_step_neuron=MultiStepLIFNode)
22 |
23 | gru_in_channels = latent_dim
24 | self.layer1 = backbone.layer1
25 | self.layer1[0].conv1 = torch.nn.Conv2d(in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
26 | bias=False)
27 | self.layer1[0].bn1 = torch.nn.BatchNorm2d(in_channels)
28 | self.layer1[0].conv2 = torch.nn.Conv2d(in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
29 | bias=False)
30 | self.layer1[0].bn2 = torch.nn.BatchNorm2d(in_channels)
31 |
32 | self.layer1[1].conv1 = torch.nn.Conv2d(in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
33 | bias=False)
34 | self.layer1[1].bn1 = torch.nn.BatchNorm2d(in_channels)
35 | self.layer1[1].conv2 = torch.nn.Conv2d(in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
36 | bias=False)
37 | self.layer1[1].bn2 = torch.nn.BatchNorm2d(in_channels)
38 |
39 | # self.layer2 = backbone.layer2
40 | # self.layer3 = backbone.layer3
41 | # self.layer3[1].conv2 = torch.nn.Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
42 | # self.layer3[1].bn2 = torch.nn.BatchNorm2d(64)
43 |
44 | self.layer2 = copy.deepcopy(backbone.layer1)
45 | self.layer2[0].conv1 = torch.nn.Conv2d(in_channels, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
46 | self.layer2[0].bn1 = torch.nn.BatchNorm2d(384)
47 | self.layer2[0].conv2 = torch.nn.Conv2d(384, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
48 | self.layer2[0].bn2 = torch.nn.BatchNorm2d(in_channels)
49 |
50 | self.layer2[1].conv1 = torch.nn.Conv2d(in_channels, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
51 | self.layer2[1].bn1 = torch.nn.BatchNorm2d(384)
52 | self.layer2[1].conv2 = torch.nn.Conv2d(384, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
53 | self.layer2[1].bn2 = torch.nn.BatchNorm2d(in_channels)
54 |
55 | self.layer3 = copy.deepcopy(self.layer2)
56 |
57 |         # adjust the number of output features of the BatchNorm2d layers in the first MultiStepBasicBlock
58 |
59 |
60 |
61 | self.dual_lif = Dual_LIF_temporal_mixer(gru_in_channels, in_channels, n_future=n_future, mixture=mixture)
62 | # self.res_blocks1 = nn.Sequential(*[Block(in_channels) for _ in range(n_res_layers)])
63 | #
64 | # self.spatial_grus = []
65 | # self.res_blocks = []
66 | # for i in range(self.n_spatial_gru):
67 | # self.spatial_grus.append(SpatialGRU(in_channels, in_channels))
68 | # if i < self.n_spatial_gru - 1:
69 | # self.res_blocks.append(nn.Sequential(*[Block(in_channels) for _ in range(n_res_layers)]))
70 | # else:
71 | # self.res_blocks.append(DeepLabHead(in_channels, in_channels, 128))
72 | #
73 | # self.spatial_grus = torch.nn.ModuleList(self.spatial_grus)
74 | # self.res_blocks = torch.nn.ModuleList(self.res_blocks)
75 |
76 | def forward(self, x, state):
77 | # x has shape (b, 1, c, h, w), state: torch.Tensor [b, n_present, hidden_size, h, w]
78 | x = self.dual_lif(x, state)
79 |
80 |
81 |         # b, n_future, c, h, w = x.shape # predict the future; the future features are already available here
82 |         # x = self.res_blocks1(x.view(b * n_future, c, h, w)) # pass through the ResNet blocks
83 | # x = x.view(b, n_future, c, h, w)
84 |
85 |         x = torch.cat([state, x], dim=1)  # fuse the predicted future features with the current state
86 | b, s, c, h, w = x.shape
87 | x = x.reshape(s, b, c, h, w)
88 | x = self.layer1(x)
89 | x = self.layer2(x)
90 | x = self.layer3(x)
91 | x = x.reshape(b, s, c, h, w)
92 |
93 | # hidden_state = x[:, 0]
94 | # for i in range(self.n_spatial_gru):
95 |         #     x = self.spatial_grus[i](x, hidden_state) # apply the Spatial GRU
96 | #
97 | # b, s, c, h, w = x.shape
98 |         #     x = self.res_blocks[i](x.view(b * s, c, h, w)) # pass through the Res Blocks for feature extraction; this part could be folded into RWKV
99 | # x = x.view(b, s, c, h, w)
100 |
101 | return x
--------------------------------------------------------------------------------
/sad/models/module/downsample.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode, MultiStepLIFNode
4 | from spikingjelly.clock_driven import layer
5 |
6 | def MS_conv_unit(in_channels, out_channels,kernel_size=1,padding=0,groups=1):
7 | return nn.Sequential(
8 | layer.SeqToANNContainer(
9 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, groups=groups,bias=True),
10 |             nn.BatchNorm2d(out_channels)  # this could potentially be improved
11 | )
12 | )
13 | class MS_ConvBlock(nn.Module):
14 | def __init__(self, dim,
15 | mlp_ratio=4.0):
16 | super().__init__()
17 | self.neuron1 = MultiStepLIFNode(tau=2.0, detach_reset=True,backend='torch')
18 | self.conv1 = MS_conv_unit(dim, dim, 3, 1)
19 |
20 | self.neuron2 = MultiStepLIFNode(tau=2.0, detach_reset=True,backend='torch')
21 | self.conv2 = MS_conv_unit(dim, dim * mlp_ratio, 3, 1)
22 |
23 | self.neuron3 = MultiStepLIFNode(tau=2.0, detach_reset=True,backend='torch')
24 |
25 | self.conv3 = MS_conv_unit(dim*mlp_ratio, dim, 3, 1)
26 |
27 |
28 | def forward(self, x, hook=None, num=None):
29 | with torch.backends.cudnn.flags(enabled=False):
30 | short_cut1 = x
31 | x = self.neuron1(x)
32 | if hook is not None:
33 | hook[self._get_name() + num + "_lif1"] = x.detach()
34 | x = self.conv1(x) + short_cut1
35 | short_cut2 = x
36 | x = self.neuron2(x)
37 | if hook is not None:
38 | hook[self._get_name() + num + "_lif2"] = x.detach()
39 |
40 | x = self.conv2(x)
41 | x = self.neuron3(x)
42 | if hook is not None:
43 | hook[self._get_name() + num + "_lif3"] = x.detach()
44 |
45 | x = self.conv3(x)
46 | x = x + short_cut2
47 | return x
48 | class MS_DownSampling(nn.Module):
49 | def __init__(
50 | self,
51 | in_channels=2,
52 | embed_dims=256,
53 | kernel_size=3,
54 | stride=2,
55 | padding=1,
56 | first_layer=True,
57 | ):
58 | super().__init__()
59 |
60 | self.encode_conv = nn.Conv2d(
61 | in_channels,
62 | embed_dims,
63 | kernel_size=kernel_size,
64 | stride=stride,
65 | padding=padding,
66 | )
67 |
68 | self.encode_bn = nn.BatchNorm2d(embed_dims)
69 | if not first_layer:
70 | self.encode_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')
71 |
72 | def forward(self, x, hook=None, num=None):
73 |
74 | T, B, _, _, _ = x.shape
75 |
76 | if hasattr(self, "encode_lif"):
77 | x = self.encode_lif(x)
78 | if hook is not None:
79 | hook[self._get_name() + num + "_lif"] = x.detach()
80 | x = self.encode_conv(x.flatten(0, 1))
81 | _, _, H, W = x.shape
82 | x = self.encode_bn(x).reshape(T, B, -1, H, W).contiguous()
83 | return x
84 | class Sps_downsampling(nn.Module):
85 | def __init__(
86 | self,
87 | in_channels=3,
88 | mlp_ratios=4,
89 | embed_dim=[64,128,256],
90 | pooling_stat="1111",
91 | spike_mode="lif",
92 | ):
93 | super().__init__()
94 | self.downsample1_1 = MS_DownSampling(
95 | in_channels=in_channels,
96 | embed_dims=embed_dim[0] // 2,
97 | kernel_size=7,
98 | stride=2,
99 | padding=3,
100 | first_layer=True,
101 | )
102 |
103 | self.ConvBlock1_1 = nn.ModuleList(
104 | [MS_ConvBlock(dim=embed_dim[0] // 2, mlp_ratio=mlp_ratios)]
105 | )
106 |
107 | self.downsample1_2 = MS_DownSampling(
108 | in_channels=embed_dim[0] // 2,
109 | embed_dims=embed_dim[0],
110 | kernel_size=3,
111 | stride=2,
112 | padding=1,
113 | first_layer=False,
114 | )
115 |
116 | self.ConvBlock1_2 = nn.ModuleList(
117 | [MS_ConvBlock(dim=embed_dim[0], mlp_ratio=mlp_ratios)]
118 | )
119 |
120 | self.downsample2 = MS_DownSampling(
121 | in_channels=embed_dim[0],
122 | embed_dims=embed_dim[1],
123 | kernel_size=3,
124 | stride=2,
125 | padding=1,
126 | first_layer=False,
127 | )
128 |
129 | self.ConvBlock2_1 = nn.ModuleList(
130 | [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)]
131 | )
132 |
133 | self.ConvBlock2_2 = nn.ModuleList(
134 | [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)]
135 | )
136 |
137 | self.downsample3 = MS_DownSampling(
138 | in_channels=embed_dim[1],
139 | embed_dims=embed_dim[2],
140 | kernel_size=3,
141 | stride=2,
142 | padding=1,
143 | first_layer=False,
144 | )
145 | def forward(self,x, hook=None):
146 | x = self.downsample1_1(x,hook,"1")
147 | for blk in self.ConvBlock1_1:
148 | x = blk(x,hook,"1_1") # 112
149 | x = self.downsample1_2(x)
150 | for blk in self.ConvBlock1_2:
151 | x = blk(x,hook,"1_2") #56
152 |
153 |
154 | x = self.downsample2(x,hook,"2")
155 |
156 | for blk in self.ConvBlock2_1:
157 | x = blk(x,hook,"2_1") #28
158 |
159 | for blk in self.ConvBlock2_2:
160 | output_1 = blk(x,hook,"2_2") #28
161 |
162 |
163 | output_2 = self.downsample3(x,hook,"3")
164 |
165 | return output_1, output_2
166 | # model = Sps_downsampling()
167 | # x = torch.randn(1, 1, 3, 224, 224)  # (T, B, C, H, W)
168 | # out_1, out_2 = model(x); print(out_1.shape, out_2.shape)
169 |
--------------------------------------------------------------------------------
/sad/models/encoder_snn_merge.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from timm.models.helpers import clean_state_dict
6 | from timm.models.layers import trunc_normal_
7 | from timm.models.registry import register_model
8 | from timm.models.vision_transformer import _cfg
9 | from spikingjelly.clock_driven.neuron import (
10 | MultiStepLIFNode,
11 | MultiStepParametricLIFNode,
12 | )
13 | from sad.models.module import *
14 | from sad.layers.convolutions import SpikingUpsamplingConcat, SpikingDeepLabHead, DeepLabHead, MergeUpsamplingConcat
15 | from timm.models import (
16 | create_model,
17 | safe_model_name,
18 | resume_checkpoint,
19 | load_checkpoint,
20 | convert_splitbn_model,
21 | model_parameters,
22 | )
23 |
24 | class CoinMLP(nn.Module):
25 | def __init__(
26 | self,
27 | img_size_h=128,
28 | img_size_w=128,
29 | patch_size=16,
30 | in_channels=2,
31 | num_classes=11,
32 | embed_dims=512,
33 | num_heads=8,
34 | mlp_ratios=4,
35 | qkv_bias=False,
36 | qk_scale=None,
37 | drop_rate=0.0,
38 | attn_drop_rate=0.0,
39 | drop_path_rate=0.0,
40 | norm_layer=nn.LayerNorm,
41 | depths=[6, 8, 6],
42 | sr_ratios=[8, 4, 2],
43 | T=4,
44 | pooling_stat="1111",
45 | attn_mode="direct_xor",
46 | spike_mode="lif",
47 | get_embed=False,
48 | dvs_mode=False,
49 | TET=False,
50 | cml=False,
51 | pretrained=False,
52 | pretrained_cfg=None,
53 | ):
54 | super().__init__()
55 | self.num_classes = num_classes
56 | self.depths = depths
57 |
58 | self.T = T
59 | self.TET = TET
60 | self.dvs = dvs_mode
61 | self.D = 48
62 | self.C = 64
63 | self.depth_layer_1 = SpikingDeepLabHead(384, 384, hidden_channel=64)
64 | self.depth_layer_2 = MergeUpsamplingConcat(192 + 384, self.D)
65 |
66 | self.feature_layer_1 = SpikingDeepLabHead(384, 384, hidden_channel=64)
67 | self.feature_layer_2 = MergeUpsamplingConcat(192 + 384, self.C)
68 |
69 |
70 |
71 | dpr = [
72 | x.item() for x in torch.linspace(0, drop_path_rate, depths)
73 | ] # stochastic depth decay rule
74 |
75 | patch_embed = Sps_downsampling(
76 | embed_dim=[int(embed_dims/4),int(embed_dims/2),embed_dims],
77 | pooling_stat=pooling_stat,
78 | spike_mode=spike_mode,
79 | )
80 |
81 |
82 | setattr(self, f"patch_embed", patch_embed)
83 |
84 | # classification head
85 | self.head_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch")
86 | self.head_lif_2 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch")
87 | self.apply(self._init_weights)
88 |
89 | def _init_weights(self, m):
90 | if isinstance(m, nn.Conv2d):
91 | trunc_normal_(m.weight, std=0.02)
92 | if m.bias is not None:
93 | nn.init.constant_(m.bias, 0)
94 | elif isinstance(m, nn.BatchNorm2d):
95 | nn.init.constant_(m.bias, 0)
96 | nn.init.constant_(m.weight, 1.0)
97 |
98 | def forward_features(self, x, hook=None):
99 | patch_embed = getattr(self, f"patch_embed")
100 |
101 | x_1, x_2 = patch_embed(x, hook=hook)
102 | # import pdb
103 | # pdb.set_trace()
104 |
105 | return x_1, x_2
106 |
107 | def forward(self, x, hook=None):
108 | x_1, x_2 = self.forward_features(x, hook=hook)
109 |
110 | x_1 = self.head_lif(x_1)
111 | x_2 = self.head_lif_2(x_2)
112 |
113 |
114 | # x_2 = torch.mean(x_2, dim=0)
115 |
116 | x_1 = torch.mean(x_1, dim=0)
117 | feature = self.feature_layer_1(x_2)
118 | feature = torch.mean(feature, dim=0)
119 | feature = self.feature_layer_2(feature, x_1)
120 |
121 |
122 | depth = self.depth_layer_1(x_2)
123 | depth = torch.mean(depth, dim=0)
124 | depth = self.depth_layer_2(depth, x_1)
125 |
126 | if hook is not None:
127 | hook["head_lif"] = x.detach()
128 |
129 | # feature = torch.mean(feature, dim=0)
130 | # depth = torch.mean(depth, dim=0)
131 |
132 | return feature, depth
133 |
134 |
135 | @register_model
136 | def sdt(**kwargs):
137 | model = CoinMLP(
138 | **kwargs,
139 | )
140 | model.default_cfg = _cfg()
141 | return model
142 |
143 | def main():
144 | encoder = create_model(
145 | "sdt",
146 |         T = 3, # default number of time steps
147 |         pretrained = False, # do not load pretrained weights by default
148 |         drop_rate = 0.0, # default dropout rate
149 |         drop_path_rate = 0.2, # default drop-path rate
150 |         drop_block_rate = None, # drop-block rate, unspecified
151 |         num_heads = 8, # default number of heads
152 |         num_classes = 1000, # default number of classes
153 |         pooling_stat = "1111", # default pooling configuration
154 |         img_size_h = 480, # image height
155 |         img_size_w = 224, # image width
156 |         patch_size = None, # patch size, unspecified
157 |         embed_dims = 384, # embedding dimension (`args.dim` is not exposed by the argument parser; set explicitly here)
158 |         mlp_ratios = 4, # default MLP ratio
159 |         in_channels = 3, # default number of input channels
160 |         qkv_bias = False, # qkv bias, unspecified by default; set to False here
161 |         depths = 6, # default number of layers
162 |         sr_ratios = 1, # default sr ratio (not exposed by the argument parser; set to 1 here)
163 |         spike_mode = "lif", # default spiking neuron mode
164 |         dvs_mode = False, # DVS mode (`args.dvs_mode` is not exposed by the argument parser; disabled here)
165 |         TET = False, # default TET setting
166 |
167 | )
168 | checkpoint = torch.load("/vol5/Coin-MLP-back/Pure-MLP-SNN/18M-with-downsample-conv/checkpoint-308.pth.tar", map_location="cpu")
169 | state_dict = clean_state_dict(checkpoint["state_dict"])
170 | encoder.load_state_dict(state_dict, strict=False)
171 | x = torch.randn(3, 3, 3, 224, 480)
172 | encoder = encoder.cuda()
173 | y_1,y_2 = encoder(x.cuda())
174 | print(y_1.shape)
175 | print(y_2.shape)
176 |
177 |
178 | if __name__ == "__main__":
179 | main()
--------------------------------------------------------------------------------
/sad/models/module/sps.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from spikingjelly.clock_driven.neuron import (
4 | MultiStepLIFNode,
5 | MultiStepParametricLIFNode,
6 | )
7 | from timm.models.layers import to_2tuple
8 |
9 |
10 | class MS_SPS(nn.Module):
11 | def __init__(
12 | self,
13 | img_size_h=128,
14 | img_size_w=128,
15 | patch_size=4,
16 | in_channels=2,
17 | embed_dims=256,
18 | pooling_stat="1111",
19 | spike_mode="lif",
20 | ):
21 | super().__init__()
22 | self.image_size = [img_size_h, img_size_w]
23 | patch_size = to_2tuple(patch_size)
24 | self.patch_size = patch_size
25 | self.pooling_stat = pooling_stat
26 |
27 | self.C = in_channels
28 | self.H, self.W = (
29 | self.image_size[0] // patch_size[0],
30 | self.image_size[1] // patch_size[1],
31 | )
32 | self.num_patches = self.H * self.W
33 | self.proj_conv = nn.Conv2d(
34 | in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False
35 | )
36 | self.proj_bn = nn.BatchNorm2d(embed_dims // 8)
37 | if spike_mode == "lif":
38 | self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy")
39 | elif spike_mode == "plif":
40 | self.proj_lif = MultiStepParametricLIFNode(
41 | init_tau=2.0, detach_reset=True, backend="cupy"
42 | )
43 | self.maxpool = nn.MaxPool2d(
44 | kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
45 | )
46 |
47 | self.proj_conv1 = nn.Conv2d(
48 | embed_dims // 8,
49 | embed_dims // 4,
50 | kernel_size=3,
51 | stride=1,
52 | padding=1,
53 | bias=False,
54 | )
55 | self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4)
56 | if spike_mode == "lif":
57 | self.proj_lif1 = MultiStepLIFNode(
58 | tau=2.0, detach_reset=True, backend="cupy"
59 | )
60 | elif spike_mode == "plif":
61 | self.proj_lif1 = MultiStepParametricLIFNode(
62 | init_tau=2.0, detach_reset=True, backend="cupy"
63 | )
64 | self.maxpool1 = nn.MaxPool2d(
65 | kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
66 | )
67 |
68 | self.proj_conv2 = nn.Conv2d(
69 | embed_dims // 4,
70 | embed_dims // 2,
71 | kernel_size=3,
72 | stride=1,
73 | padding=1,
74 | bias=False,
75 | )
76 | self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2)
77 | if spike_mode == "lif":
78 | self.proj_lif2 = MultiStepLIFNode(
79 | tau=2.0, detach_reset=True, backend="cupy"
80 | )
81 | elif spike_mode == "plif":
82 | self.proj_lif2 = MultiStepParametricLIFNode(
83 | init_tau=2.0, detach_reset=True, backend="cupy"
84 | )
85 | self.maxpool2 = nn.MaxPool2d(
86 | kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
87 | )
88 |
89 | self.proj_conv3 = nn.Conv2d(
90 | embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False
91 | )
92 | self.proj_bn3 = nn.BatchNorm2d(embed_dims)
93 | if spike_mode == "lif":
94 | self.proj_lif3 = MultiStepLIFNode(
95 | tau=2.0, detach_reset=True, backend="cupy"
96 | )
97 | elif spike_mode == "plif":
98 | self.proj_lif3 = MultiStepParametricLIFNode(
99 | init_tau=2.0, detach_reset=True, backend="cupy"
100 | )
101 | self.maxpool3 = nn.MaxPool2d(
102 | kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
103 | )
104 |
105 | self.rpe_conv = nn.Conv2d(
106 | embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False
107 | )
108 | self.rpe_bn = nn.BatchNorm2d(embed_dims)
109 | if spike_mode == "lif":
110 | self.rpe_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy")
111 | elif spike_mode == "plif":
112 | self.rpe_lif = MultiStepParametricLIFNode(
113 | init_tau=2.0, detach_reset=True, backend="cupy"
114 | )
115 |
116 | def forward(self, x, hook=None):
117 | T, B, _, H, W = x.shape
118 | ratio = 1
119 | x = self.proj_conv(x.flatten(0, 1)) # have some fire value
120 | x = self.proj_bn(x).reshape(T, B, -1, H // ratio, W // ratio).contiguous()
121 | x = self.proj_lif(x)
122 | if hook is not None:
123 | hook[self._get_name() + "_lif"] = x.detach()
124 | x = x.flatten(0, 1).contiguous()
125 | if self.pooling_stat[0] == "1":
126 | x = self.maxpool(x)
127 | ratio *= 2
128 |
129 | x = self.proj_conv1(x)
130 | x = self.proj_bn1(x).reshape(T, B, -1, H // ratio, W // ratio).contiguous()
131 | x = self.proj_lif1(x)
132 | if hook is not None:
133 | hook[self._get_name() + "_lif1"] = x.detach()
134 | x = x.flatten(0, 1).contiguous()
135 | if self.pooling_stat[1] == "1":
136 | x = self.maxpool1(x)
137 | ratio *= 2
138 |
139 | x = self.proj_conv2(x)
140 | x = self.proj_bn2(x).reshape(T, B, -1, H // ratio, W // ratio).contiguous()
141 | x = self.proj_lif2(x)
142 | if hook is not None:
143 | hook[self._get_name() + "_lif2"] = x.detach()
144 | x = x.flatten(0, 1).contiguous()
145 | if self.pooling_stat[2] == "1":
146 | x = self.maxpool2(x)
147 | ratio *= 2
148 |
149 | x = self.proj_conv3(x)
150 | x = self.proj_bn3(x)
151 | if self.pooling_stat[3] == "1":
152 | x = self.maxpool3(x)
153 | ratio *= 2
154 |
155 | x_feat = x
156 | x = self.proj_lif3(x.reshape(T, B, -1, H // ratio, W // ratio).contiguous())
157 | if hook is not None:
158 | hook[self._get_name() + "_lif3"] = x.detach()
159 | x = x.flatten(0, 1).contiguous()
160 | x = self.rpe_conv(x)
161 | x = self.rpe_bn(x)
162 | x = (x + x_feat).reshape(T, B, -1, H // ratio, W // ratio).contiguous()
163 |
164 | H, W = H // self.patch_size[0], W // self.patch_size[1]
165 | return x, (H, W), hook
166 |
--------------------------------------------------------------------------------
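A minimal usage sketch of `MS_SPS` (shapes are illustrative; the LIF neurons here are built with `backend="cupy"`, so this assumes a CUDA device with CuPy installed):

```python
import torch
from sad.models.module.sps import MS_SPS

sps = MS_SPS(img_size_h=128, img_size_w=128, patch_size=4,
             in_channels=2, embed_dims=256, pooling_stat="1111").cuda()
x = torch.randn(4, 1, 2, 128, 128).cuda()   # (T, B, C, H, W)
out, (H, W), _ = sps(x)
# With all four max-pools enabled the spatial size is reduced 16x:
# out: (4, 1, 256, 8, 8); the returned (H, W) follows patch_size, i.e. (32, 32).
```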
/sad/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from fvcore.common.config import CfgNode as _CfgNode
3 |
4 |
5 | def convert_to_dict(cfg_node, key_list=[]):
6 | """Convert a config node to dictionary."""
7 | _VALID_TYPES = {tuple, list, str, int, float, bool}
8 | if not isinstance(cfg_node, _CfgNode):
9 | if type(cfg_node) not in _VALID_TYPES:
10 | print(
11 | 'Key {} with value {} is not a valid type; valid types: {}'.format(
12 | '.'.join(key_list), type(cfg_node), _VALID_TYPES
13 | ),
14 | )
15 | return cfg_node
16 | else:
17 | cfg_dict = dict(cfg_node)
18 | for k, v in cfg_dict.items():
19 | cfg_dict[k] = convert_to_dict(v, key_list + [k])
20 | return cfg_dict
21 |
22 |
23 | class CfgNode(_CfgNode):
24 | """Remove once https://github.com/rbgirshick/yacs/issues/19 is merged."""
25 |
26 | def convert_to_dict(self):
27 | return convert_to_dict(self)
28 |
29 |
30 | CN = CfgNode
31 |
32 | _C = CN()
33 | _C.LOG_DIR = 'tensorboard_logs'
34 | _C.TAG = 'default'
35 |
36 | _C.GPUS = [0] # which gpus to use
37 | _C.PRECISION = 32 # 16bit or 32bit
38 | _C.BATCHSIZE = 3
39 | _C.EPOCHS = 20
40 |
41 | _C.N_WORKERS = 5
42 | _C.VIS_INTERVAL = 5000
43 | _C.LOGGING_INTERVAL = 500
44 |
45 | _C.PRETRAINED = CN()
46 | _C.PRETRAINED.LOAD_WEIGHTS = False
47 | _C.PRETRAINED.PATH = ''
48 |
49 | _C.DATASET = CN()
50 | _C.DATASET.DATAROOT = '/data/Nuscenes'
51 | _C.DATASET.VERSION = 'trainval'
52 | _C.DATASET.NAME = 'nuscenes'
53 | _C.DATASET.MAP_FOLDER = '/data/Nuscenes'
54 | _C.DATASET.IGNORE_INDEX = 255 # Ignore index when creating flow/offset labels
55 | _C.DATASET.FILTER_INVISIBLE_VEHICLES = True # Filter vehicles that are not visible from the cameras
56 | _C.DATASET.SAVE_DIR = 'datas'
57 |
58 | _C.TIME_RECEPTIVE_FIELD = 3 # how many frames of temporal context (1 for single timeframe)
59 | _C.N_FUTURE_FRAMES = 4 # how many time steps into the future to predict
60 |
61 | _C.IMAGE = CN()
62 | _C.IMAGE.FINAL_DIM = (224, 480)
63 | _C.IMAGE.RESIZE_SCALE = 0.3
64 | _C.IMAGE.TOP_CROP = 46
65 | _C.IMAGE.ORIGINAL_HEIGHT = 900 # Original input RGB camera height
66 | _C.IMAGE.ORIGINAL_WIDTH = 1600 # Original input RGB camera width
67 | _C.IMAGE.NAMES = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT']
68 |
69 | _C.LIFT = CN() # image to BEV lifting
70 | _C.LIFT.X_BOUND = [-50.0, 50.0, 0.5] # Forward
71 | _C.LIFT.Y_BOUND = [-50.0, 50.0, 0.5] # Sides
72 | _C.LIFT.Z_BOUND = [-10.0, 10.0, 20.0] # Height
73 | _C.LIFT.D_BOUND = [2.0, 50.0, 1.0]
74 | _C.LIFT.GT_DEPTH = False
75 | _C.LIFT.DISCOUNT = 0.5
76 |
77 | _C.EGO = CN()
78 | _C.EGO.WIDTH = 1.85
79 | _C.EGO.HEIGHT = 4.084
80 |
81 | _C.MODEL = CN()
82 |
83 | _C.MODEL.ENCODER = CN()
84 | _C.MODEL.ENCODER.DOWNSAMPLE = 8
85 | _C.MODEL.ENCODER.NAME = 'efficientnet-b4'
86 | _C.MODEL.ENCODER.OUT_CHANNELS = 64
87 | _C.MODEL.ENCODER.USE_DEPTH_DISTRIBUTION = True
88 |
89 | _C.MODEL.TEMPORAL_MODEL = CN()
90 | _C.MODEL.TEMPORAL_MODEL.NAME = 'temporal_block' # type of temporal model
91 | _C.MODEL.TEMPORAL_MODEL.START_OUT_CHANNELS = 64
92 | _C.MODEL.TEMPORAL_MODEL.EXTRA_IN_CHANNELS = 0
93 | _C.MODEL.TEMPORAL_MODEL.INBETWEEN_LAYERS = 0
94 | _C.MODEL.TEMPORAL_MODEL.PYRAMID_POOLING = True
95 | _C.MODEL.TEMPORAL_MODEL.INPUT_EGOPOSE = True
96 |
97 | _C.MODEL.DISTRIBUTION = CN()
98 | _C.MODEL.DISTRIBUTION.LATENT_DIM = 32
99 | _C.MODEL.DISTRIBUTION.MIN_LOG_SIGMA = -5.0
100 | _C.MODEL.DISTRIBUTION.MAX_LOG_SIGMA = 5.0
101 |
102 | _C.MODEL.FUTURE_PRED = CN()
103 | _C.MODEL.FUTURE_PRED.N_GRU_BLOCKS = 2
104 | _C.MODEL.FUTURE_PRED.N_RES_LAYERS = 1
105 | _C.MODEL.FUTURE_PRED.MIXTURE = True
106 |
107 | _C.MODEL.DECODER = CN()
108 |
109 | _C.MODEL.BN_MOMENTUM = 0.1
110 |
111 | _C.SEMANTIC_SEG = CN()
112 |
113 | _C.SEMANTIC_SEG.VEHICLE = CN()
114 | _C.SEMANTIC_SEG.VEHICLE.WEIGHTS = [1.0, 2.0]
115 | _C.SEMANTIC_SEG.VEHICLE.USE_TOP_K = True # backprop only top-k hardest pixels
116 | _C.SEMANTIC_SEG.VEHICLE.TOP_K_RATIO = 0.25
117 |
118 | _C.SEMANTIC_SEG.PEDESTRIAN = CN()
119 | _C.SEMANTIC_SEG.PEDESTRIAN.ENABLED = True
120 | _C.SEMANTIC_SEG.PEDESTRIAN.WEIGHTS = [1.0, 10.0]
121 | _C.SEMANTIC_SEG.PEDESTRIAN.USE_TOP_K = True
122 | _C.SEMANTIC_SEG.PEDESTRIAN.TOP_K_RATIO = 0.25
123 |
124 | _C.SEMANTIC_SEG.HDMAP = CN()
125 | _C.SEMANTIC_SEG.HDMAP.ENABLED = True
126 | _C.SEMANTIC_SEG.HDMAP.ELEMENTS = ['lane_divider', 'drivable_area']
127 | _C.SEMANTIC_SEG.HDMAP.WEIGHTS = [[1.0, 5.0], [1.0, 1.0]]
128 | _C.SEMANTIC_SEG.HDMAP.TRAIN_WEIGHT = [1, 1]
129 | _C.SEMANTIC_SEG.HDMAP.USE_TOP_K = [True, False]
130 | _C.SEMANTIC_SEG.HDMAP.TOP_K_RATIO = [0.25, 0.25]
131 |
132 | _C.INSTANCE_SEG = CN()
133 | _C.INSTANCE_SEG.ENABLED = True
134 |
135 | _C.INSTANCE_FLOW = CN()
136 | _C.INSTANCE_FLOW.ENABLED = True
137 |
138 | _C.PROBABILISTIC = CN()
139 | _C.PROBABILISTIC.ENABLED = True # learn a distribution over futures
140 | _C.PROBABILISTIC.METHOD = 'GAUSSIAN' # [BERNOULLI, GAUSSIAN, MIXGAUSSIAN]
141 |
142 | _C.PLANNING = CN()
143 | _C.PLANNING.ENABLED = True
144 | _C.PLANNING.GRU_STATE_SIZE = 256
145 | _C.PLANNING.SAMPLE_NUM = 600
146 | _C.PLANNING.COMMAND = ['LEFT', 'FORWARD', 'RIGHT']
147 |
148 | _C.FUTURE_DISCOUNT = 0.95
149 |
150 | _C.OPTIMIZER = CN()
151 | _C.OPTIMIZER.LR = 3e-4
152 | _C.OPTIMIZER.WEIGHT_DECAY = 1e-7
153 | _C.GRAD_NORM_CLIP = 5
154 |
155 | _C.COST_FUNCTION = CN()
156 | _C.COST_FUNCTION.SAFETY = 0.1
157 | _C.COST_FUNCTION.LAMBDA = 1.
158 | _C.COST_FUNCTION.HEADWAY = 1.
159 | _C.COST_FUNCTION.LRDIVIDER = 10.
160 | _C.COST_FUNCTION.COMFORT = 0.1
161 | _C.COST_FUNCTION.PROGRESS = 0.5
162 | _C.COST_FUNCTION.VOLUME = 100.
163 |
164 | def get_parser():
165 | parser = argparse.ArgumentParser(description='Fiery training')
166 | parser.add_argument('--config-file', default='', metavar='FILE', help='path to config file')
167 | parser.add_argument(
168 | 'opts', help='Modify config options using the command-line', default=None, nargs=argparse.REMAINDER,
169 | )
170 | return parser
171 |
172 |
173 | def get_cfg(args=None, cfg_dict=None):
174 | """ First get default config. Then merge cfg_dict. Then merge according to args. """
175 |
176 | cfg = _C.clone()
177 |
178 | if cfg_dict is not None:
179 | tmp = CfgNode(cfg_dict)
180 | for i in tmp.COST_FUNCTION:
181 | tmp.COST_FUNCTION.update({i: float(tmp.COST_FUNCTION.get(i))})
182 | cfg.merge_from_other_cfg(tmp)
183 |
184 | if args is not None:
185 | if args.config_file:
186 | cfg.merge_from_file(args.config_file)
187 | cfg.merge_from_list(args.opts)
188 | # cfg.freeze()
189 | return cfg
190 |
--------------------------------------------------------------------------------
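
A minimal usage sketch for the config module above (not part of the repository; the
config file path and override values are only illustrative):

    from sad.config import get_parser, get_cfg

    # Parse a YAML config file plus free-form KEY VALUE overrides (the `opts` remainder).
    args = get_parser().parse_args([
        '--config-file', 'sad/configs/nuscenes/Perception.yml',
        'BATCHSIZE', '2', 'GPUS', '[0, 1]',
    ])
    cfg = get_cfg(args)               # defaults -> YAML file -> command-line opts
    print(cfg.BATCHSIZE, cfg.GPUS)    # 2 [0, 1]
    cfg_dict = cfg.convert_to_dict()  # plain nested dict, e.g. for logging
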
/maps/hdmap_generate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2 as cv
3 | import json
4 | from collections import deque
5 | from pathlib import Path
6 | import h5py
7 | import os
8 | import tqdm
9 |
10 | classes = {
11 | 0: [0, 0, 0], # unlabeled
12 | 1: [0, 0, 0], # building
13 | 2: [0, 0, 0], # fence
14 | 3: [0, 0, 0], # other
15 | 4: [0, 255, 0], # pedestrian
16 | 5: [0, 0, 0], # pole
17 | 6: [157, 234, 50], # road line
18 | 7: [128, 64, 128], # road
19 | 8: [255, 255, 255], # sidewalk
20 | 9: [0, 0, 0], # vegetation
21 | 10: [0, 0, 255], # vehicle
22 | 11: [0, 0, 0], # wall
23 | 12: [0, 0, 0], # traffic sign
24 | 13: [0, 0, 0], # sky
25 | 14: [0, 0, 0], # ground
26 | 15: [0, 0, 0], # bridge
27 | 16: [0, 0, 0], # rail track
28 | 17: [0, 0, 0], # guard rail
29 | 18: [0, 0, 0], # traffic light
30 | 19: [0, 0, 0], # static
31 | 20: [0, 0, 0], # dynamic
32 | 21: [0, 0, 0], # water
33 | 22: [0, 0, 0], # terrain
34 | 23: [255, 0, 0], # red light
35 | 24: [0, 0, 0], # yellow light #TODO should be red
36 | 25: [0, 0, 0], # green light
37 | 26: [157, 234, 50], # stop sign
38 | 27: [157, 234, 50], # stop line marking
39 |
40 | }
41 |
42 |
43 | COLOR_BLACK = (0, 0, 0)
44 | COLOR_RED = (255, 0, 0)
45 | COLOR_GREEN = (0, 255, 0)
46 | COLOR_BLUE = (0, 0, 255)
47 | COLOR_CYAN = (0, 255, 255)
48 | COLOR_MAGENTA = (255, 0, 255)
49 | COLOR_MAGENTA_2 = (255, 140, 255)
50 | COLOR_YELLOW = (255, 255, 0)
51 | COLOR_YELLOW_2 = (160, 160, 0)
52 | COLOR_WHITE = (255, 255, 255)
53 | COLOR_ALUMINIUM_0 = (238, 238, 236)
54 | COLOR_ALUMINIUM_3 = (136, 138, 133)
55 | COLOR_ALUMINIUM_5 = (46, 52, 54)
56 |
57 | pixels_per_meter = 5
58 | width = 512
59 | pixels_ev_to_bottom = 256
60 |
61 | def tint(color, factor):
62 | r, g, b = color
63 | r = int(r + (255-r) * factor)
64 | g = int(g + (255-g) * factor)
65 | b = int(b + (255-b) * factor)
66 | r = min(r, 255)
67 | g = min(g, 255)
68 | b = min(b, 255)
69 | return (r, g, b)
70 |
71 | def get_warp_transform(ev_loc, ev_rot, world_offset):
72 | ev_loc_in_px = world_to_pixel(ev_loc, world_offset)
73 | yaw = np.deg2rad(ev_rot)
74 |
75 | forward_vec = np.array([np.cos(yaw), np.sin(yaw)])
76 | right_vec = np.array([np.cos(yaw + 0.5*np.pi), np.sin(yaw + 0.5*np.pi)])
77 |
78 | bottom_left = ev_loc_in_px - pixels_ev_to_bottom * forward_vec - (0.5*width) * right_vec
79 | top_left = ev_loc_in_px + (width-pixels_ev_to_bottom) * forward_vec - (0.5*width) * right_vec
80 | top_right = ev_loc_in_px + (width-pixels_ev_to_bottom) * forward_vec + (0.5*width) * right_vec
81 |
82 | src_pts = np.stack((bottom_left, top_left, top_right), axis=0).astype(np.float32)
83 | dst_pts = np.array([[0, width-1],
84 | [0, 0],
85 | [width-1, 0]], dtype=np.float32)
86 | return cv.getAffineTransform(src_pts, dst_pts)
87 |
88 | def world_to_pixel(location, world_offset, projective=False):
89 | """Converts the world coordinates to pixel coordinates"""
90 | x = pixels_per_meter * (location[0] - world_offset[0])
91 | y = pixels_per_meter * (location[1] - world_offset[1])
92 |
93 | if projective:
94 | p = np.array([x, y, 1], dtype=np.float32)
95 | else:
96 | p = np.array([x, y], dtype=np.float32)
97 | return p
98 |
99 | Town2map = {
100 | "town01": "Town01.h5",
101 | "town02": "Town02.h5",
102 | "town03": "Town03.h5",
103 | "town04": "Town04.h5",
104 | "town05": "Town05.h5",
105 | "town06": "Town06.h5",
106 | "town07": "Town07.h5",
107 | "town10": "Town10HD.h5",
108 |
109 | }
110 |
111 | root = ""
112 | map_path = ""
113 | town_list = list(os.listdir(root))
114 | # town_list = town_list[18:]
115 | # town_list = ["town02_short","town02_tiny"]
116 | for town in tqdm.tqdm(town_list):
117 |
118 | maps_h5_path = os.path.join(map_path, Town2map[town[:6]])
119 |
120 | with h5py.File(maps_h5_path, 'r', libver='latest', swmr=True) as hf:
121 | road = np.array(hf['road'], dtype=np.uint8)
122 | lane_marking_yellow_broken = np.array(hf['lane_marking_yellow_broken'], dtype=np.uint8)
123 | lane_marking_yellow_solid = np.array(hf['lane_marking_yellow_solid'], dtype=np.uint8)
124 | lane_marking_white_broken = np.array(hf['lane_marking_white_broken'], dtype=np.uint8)
125 | lane_marking_white_solid = np.array(hf['lane_marking_white_solid'], dtype=np.uint8)
126 | world_offset = np.array(hf.attrs['world_offset_in_meters'], dtype=np.float32)
127 |
128 |
129 | town_folder = os.path.join(root, town)
130 | route_list = [route for route in os.listdir(town_folder) if os.path.isdir(os.path.join(town_folder, route)) ]
131 |
132 | for route in tqdm.tqdm(route_list):
133 | route_folder = os.path.join(town_folder, route)
134 | os.makedirs(os.path.join(route_folder, "hdmap"), exist_ok=True)
135 | measurement_folder = os.path.join(route_folder, "meta")
136 | measurement_files = os.listdir(measurement_folder)
137 | for measurement in measurement_files:
138 | with open(os.path.join(measurement_folder, measurement), "r") as read_file:
139 | measurement_data = json.load(read_file)
140 | x = measurement_data['x']
141 | y = measurement_data['y']
142 |
143 | theta = measurement_data['theta']
144 | if np.isnan(theta):
145 | theta = 0
146 |
147 |
148 | ev_loc = [y , -x]
149 | ev_rot = np.rad2deg(theta) - 90
150 |
151 | M_warp = get_warp_transform(ev_loc, ev_rot, world_offset)
152 |             road_mask = cv.warpAffine(road, M_warp, (width, width)).astype(bool)
153 |             lane_mask_white_broken = cv.warpAffine(lane_marking_white_broken, M_warp,
154 |                                                    (width, width)).astype(bool)
155 |             lane_mask_white_solid = cv.warpAffine(lane_marking_white_solid, M_warp,
156 |                                                   (width, width)).astype(bool)
157 |             lane_mask_yellow_broken = cv.warpAffine(lane_marking_yellow_broken, M_warp,
158 |                                                     (width, width)).astype(bool)
159 |             lane_mask_yellow_solid = cv.warpAffine(lane_marking_yellow_solid, M_warp,
160 |                                                    (width, width)).astype(bool)
161 |
162 | image = np.zeros([width, width, 3], dtype=np.uint8)
163 | image[road_mask] = COLOR_ALUMINIUM_5
164 | image[lane_mask_white_broken] = COLOR_MAGENTA
165 | image[lane_mask_white_solid] = COLOR_MAGENTA
166 | image[lane_mask_yellow_broken] = COLOR_MAGENTA
167 | image[lane_mask_yellow_solid] = COLOR_MAGENTA
168 | cv.imwrite(os.path.join(route_folder, "hdmap", measurement.replace('json', 'png')), image)
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 | # image[lane_mask_broken] = COLOR_MAGENTA_2
179 |
180 | # cv.imwrite("hdmap.png", image)
181 |
182 |
183 | # seg = cv.imread("0009.png")[:, :, 2]
184 | # result = np.zeros((seg.shape[0], seg.shape[1], 3))
185 | # for key, value in classes.items():
186 | # result[np.where(seg == key)] = value
187 |
188 | # cv.imwrite("seg_vis.png", result)
--------------------------------------------------------------------------------
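
A minimal sketch of the ego-centred BEV crop computed by get_warp_transform above (not
part of the repository; the two helpers are inlined because importing the script would
run its processing loop, and the map and ego pose below are dummies):

    import numpy as np
    import cv2 as cv

    pixels_per_meter, width, pixels_ev_to_bottom = 5, 512, 256
    world_offset = np.array([0.0, 0.0], dtype=np.float32)

    # Dummy global road mask with a horizontal band of "road" pixels.
    road = np.zeros((2048, 2048), dtype=np.uint8)
    road[900:1100, :] = 1

    ev_loc, ev_rot = np.array([200.0, 200.0]), 35.0  # hypothetical ego position (m) and yaw (deg)
    ev_px = pixels_per_meter * (ev_loc - world_offset)
    yaw = np.deg2rad(ev_rot)
    forward = np.array([np.cos(yaw), np.sin(yaw)])
    right = np.array([np.cos(yaw + 0.5 * np.pi), np.sin(yaw + 0.5 * np.pi)])
    src = np.stack([
        ev_px - pixels_ev_to_bottom * forward - 0.5 * width * right,            # bottom left
        ev_px + (width - pixels_ev_to_bottom) * forward - 0.5 * width * right,  # top left
        ev_px + (width - pixels_ev_to_bottom) * forward + 0.5 * width * right,  # top right
    ]).astype(np.float32)
    dst = np.array([[0, width - 1], [0, 0], [width - 1, 0]], dtype=np.float32)

    M_warp = cv.getAffineTransform(src, dst)
    road_mask = cv.warpAffine(road, M_warp, (width, width)).astype(bool)
    print(road_mask.shape, road_mask.any())  # (512, 512) True
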
/sad/models/planning_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from sad.layers.convolutions import Bottleneck, SpikingBottleneck
7 | from sad.layers.temporal import SpatialGRU, Dual_GRU, BiGRU
8 | from sad.cost import Cost_Function
9 |
10 | from spikingjelly.clock_driven.rnn import SpikingGRUCell
11 | from spikingjelly.clock_driven import surrogate
12 |
13 | class Planning(nn.Module):
14 | def __init__(self, cfg, feature_channel, gru_input_size=6, gru_state_size=256):
15 | super(Planning, self).__init__()
16 | self.cost_function = Cost_Function(cfg)
17 |
18 | self.sample_num = cfg.PLANNING.SAMPLE_NUM
19 | self.commands = cfg.PLANNING.COMMAND
20 | assert self.sample_num % 3 == 0
21 | self.num = int(self.sample_num / 3)
22 |
23 | self.reduce_channel = nn.Sequential(
24 | SpikingBottleneck(feature_channel, feature_channel, downsample=True),
25 | SpikingBottleneck(feature_channel, int(feature_channel/2), downsample=True),
26 | SpikingBottleneck(int(feature_channel/2), int(feature_channel/2), downsample=True),
27 | SpikingBottleneck(int(feature_channel/2), int(feature_channel/8))
28 | )
29 |
30 | # self.GRU = nn.GRUCell(gru_input_size, gru_state_size)
31 | self.GRU = SpikingGRUCell(gru_input_size, gru_state_size, surrogate_function1=surrogate.ATan(), surrogate_function2=surrogate.ATan())
32 | self.decoder = nn.Sequential(
33 | nn.Linear(gru_state_size, gru_state_size),
34 | nn.Linear(gru_state_size, 2)
35 | )
36 |
37 |
38 | def compute_L2(self, trajs, gt_traj):
39 | '''
40 | trajs: torch.Tensor (B, N, n_future, 3)
41 | gt_traj: torch.Tensor (B,1, n_future, 3)
42 | '''
43 | if trajs.ndim == 4 and gt_traj.ndim == 4:
44 | return ((trajs[:,:,:,:2] - gt_traj[:,:,:,:2]) ** 2).sum(dim=-1)
45 | if trajs.ndim == 3 and gt_traj.ndim == 3:
46 | return ((trajs[:, :, :2] - gt_traj[:, :, :2]) ** 2).sum(dim=-1)
47 |
48 | raise ValueError('trajs ndim != gt_traj ndim')
49 |
50 | def select(self, trajs, cost_volume, semantic_pred, lane_divider, drivable_area, target_points, k=1):
51 | '''
52 | trajs: torch.Tensor (B, N, n_future, 3)
53 | cost_volume: torch.Tensor (B, n_future, 200, 200)
54 | semantic_pred: torch.Tensor(B, n_future, 200, 200)
55 | lane_divider: torch.Tensor(B, 1/2, 200, 200)
56 | drivable_area: torch.Tensor(B, 1/2, 200, 200)
57 | target_points: torch.Tensor (B, 2)
58 | '''
59 | sm_cost_fc, sm_cost_fo = self.cost_function(cost_volume, trajs[:,:,:,:2], semantic_pred, lane_divider, drivable_area, target_points)
60 |
61 | CS = sm_cost_fc + sm_cost_fo.sum(dim=-1)
62 | CC, KK = torch.topk(CS, k, dim=-1, largest=False)
63 |
64 | ii = torch.arange(len(trajs))
65 | select_traj = trajs[ii[:,None], KK].squeeze(1) # (B, n_future, 3)
66 |
67 | return select_traj
68 |
69 | def loss(self, trajs, gt_trajs, cost_volume, semantic_pred, lane_divider, drivable_area, target_points):
70 | '''
71 | trajs: torch.Tensor (B, N, n_future, 3)
72 | gt_trajs: torch.Tensor (B, n_future, 3)
73 | cost_volume: torch.Tensor (B, n_future, 200, 200)
74 | semantic_pred: torch.Tensor(B, n_future, 200, 200)
75 | lane_divider: torch.Tensor(B, 1/2, 200, 200)
76 | drivable_area: torch.Tensor(B, 1/2, 200, 200)
77 | target_points: torch.Tensor (B, 2)
78 | '''
79 | sm_cost_fc, sm_cost_fo = self.cost_function(cost_volume, trajs[:, :, :, :2], semantic_pred, lane_divider, drivable_area, target_points)
80 |
81 | if gt_trajs.ndim == 3:
82 | gt_trajs = gt_trajs[:, None]
83 |
84 | gt_cost_fc, gt_cost_fo = self.cost_function(cost_volume, gt_trajs[:, :, :, :2], semantic_pred, lane_divider, drivable_area, target_points)
85 |
86 | L, _ = F.relu(
87 | F.relu(gt_cost_fo - sm_cost_fo).sum(-1) + (gt_cost_fc - sm_cost_fc) + self.compute_L2(trajs, gt_trajs).mean(
88 | dim=-1)).max(dim=-1)
89 |
90 | return torch.mean(L)
91 |
92 | def forward(self,cam_front, trajs, gt_trajs, cost_volume, semantic_pred, hd_map, commands, target_points):
93 | '''
94 | cam_front: torch.Tensor (B, 64, 60, 28)
95 | trajs: torch.Tensor (B, N, n_future, 3)
96 | gt_trajs: torch.Tensor (B, n_future, 3)
97 | cost_volume: torch.Tensor (B, n_future, 200, 200)
98 | semantic_pred: torch.Tensor(B, n_future, 200, 200)
99 | hd_map: torch.Tensor (B, 2/4, 200, 200)
100 | commands: List (B)
101 | target_points: (B, 2)
102 | '''
103 |
104 | cur_trajs = []
105 | for i in range(len(commands)):
106 | command = commands[i]
107 | traj = trajs[i]
108 | if command == 'LEFT':
109 | cur_trajs.append(traj[:self.num].repeat(3, 1, 1))
110 | elif command == 'FORWARD':
111 | cur_trajs.append(traj[self.num:self.num * 2].repeat(3, 1, 1))
112 | elif command == 'RIGHT':
113 | cur_trajs.append(traj[self.num * 2:].repeat(3, 1, 1))
114 | else:
115 | cur_trajs.append(traj)
116 | cur_trajs = torch.stack(cur_trajs)
117 |
118 | if hd_map.shape[1] == 2:
119 | lane_divider = hd_map[:, 0:1]
120 | drivable_area = hd_map[:, 1:2]
121 | elif hd_map.shape[1] == 4:
122 | lane_divider = hd_map[:, 0:2]
123 | drivable_area = hd_map[:, 2:4]
124 | else:
125 | raise NotImplementedError
126 |
127 | if self.training:
128 | loss = self.loss(cur_trajs, gt_trajs, cost_volume, semantic_pred, lane_divider, drivable_area, target_points)
129 | else:
130 | loss = 0
131 |
132 | cam_front = self.reduce_channel(cam_front)
133 | h0 = cam_front.flatten(start_dim=1) # (B, 256/128)
134 | final_traj = self.select(cur_trajs, cost_volume, semantic_pred, lane_divider, drivable_area, target_points) # (B, n_future, 3)
135 | target_points = target_points.to(dtype=h0.dtype)
136 | b, s, _ = final_traj.shape
137 | x = torch.zeros((b, 2), device=h0.device)
138 | output_traj = []
139 | for i in range(s):
140 | x = torch.cat([x, final_traj[:,i,:2], target_points], dim=-1) # (B, 6)
141 | h0 = self.GRU(x, h0)
142 | x = self.decoder(h0) # (B, 2)
143 | output_traj.append(x)
144 | output_traj = torch.stack(output_traj, dim=1) # (B, 4, 2)
145 |
146 | output_traj = torch.cat(
147 | [output_traj, torch.zeros((*output_traj.shape[:-1],1), device=output_traj.device)], dim=-1
148 | )
149 |
150 | if self.training:
151 | loss = loss*0.5 + (F.smooth_l1_loss(output_traj[:,:,:2], gt_trajs[:,:,:2], reduction='none')*torch.tensor([10., 1.], device=loss.device)).mean()
152 |
153 | return loss, output_traj
154 |
--------------------------------------------------------------------------------
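
A minimal sketch of the command-conditioned candidate selection at the top of
Planning.forward above (not part of the repository; tensor contents are random dummies):

    import torch

    sample_num, n_future = 600, 4                    # cfg.PLANNING.SAMPLE_NUM and planning horizon
    num = sample_num // 3
    trajs = torch.randn(2, sample_num, n_future, 3)  # (B, N, n_future, 3), ordered [left | forward | right]
    commands = ['LEFT', 'RIGHT']                     # one command string per batch element

    cur_trajs = []
    for traj, command in zip(trajs, commands):
        if command == 'LEFT':
            cur_trajs.append(traj[:num].repeat(3, 1, 1))
        elif command == 'FORWARD':
            cur_trajs.append(traj[num:2 * num].repeat(3, 1, 1))
        elif command == 'RIGHT':
            cur_trajs.append(traj[2 * num:].repeat(3, 1, 1))
        else:
            cur_trajs.append(traj)
    cur_trajs = torch.stack(cur_trajs)
    print(cur_trajs.shape)                           # torch.Size([2, 600, 4, 3])
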
/sad/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class SpatialRegressionLoss(nn.Module):
7 | def __init__(self, norm, ignore_index=255, future_discount=1.0):
8 | super(SpatialRegressionLoss, self).__init__()
9 | self.norm = norm
10 | self.ignore_index = ignore_index
11 | self.future_discount = future_discount
12 |
13 | if norm == 1:
14 | self.loss_fn = F.l1_loss
15 | elif norm == 2:
16 | self.loss_fn = F.mse_loss
17 | else:
18 | raise ValueError(f'Expected norm 1 or 2, but got norm={norm}')
19 |
20 | def forward(self, prediction, target, n_present=3):
21 | assert len(prediction.shape) == 5, 'Must be a 5D tensor'
22 | # ignore_index is the same across all channels
23 | mask = target[:, :, :1] != self.ignore_index
24 | if mask.sum() == 0:
25 | return prediction.new_zeros(1)[0].float()
26 |
27 | loss = self.loss_fn(prediction, target, reduction='none')
28 |
29 | # Sum channel dimension
30 | loss = torch.sum(loss, dim=-3, keepdim=True)
31 |
32 | seq_len = loss.shape[1]
33 | assert seq_len >= n_present
34 | future_len = seq_len - n_present
35 | future_discounts = self.future_discount ** torch.arange(1, future_len+1, device=loss.device, dtype=loss.dtype)
36 | discounts = torch.cat([torch.ones(n_present, device=loss.device, dtype=loss.dtype), future_discounts], dim=0)
37 | discounts = discounts.view(1, seq_len, 1, 1, 1)
38 | loss = loss * discounts
39 |
40 | return loss[mask].mean()
41 |
42 |
43 | class SegmentationLoss(nn.Module):
44 | def __init__(self, class_weights, ignore_index=255, use_top_k=False, top_k_ratio=1.0, future_discount=1.0):
45 | super().__init__()
46 | self.class_weights = class_weights
47 | self.ignore_index = ignore_index
48 | self.use_top_k = use_top_k
49 | self.top_k_ratio = top_k_ratio
50 | self.future_discount = future_discount
51 |
52 | def forward(self, prediction, target, n_present=3):
53 | if target.shape[-3] != 1:
54 | raise ValueError('segmentation label must be an index-label with channel dimension = 1.')
55 | b, s, c, h, w = prediction.shape
56 |
57 | prediction = prediction.view(b * s, c, h, w)
58 | target = target.view(b * s, h, w)
59 | loss = F.cross_entropy(
60 | prediction,
61 | target,
62 | ignore_index=self.ignore_index,
63 | reduction='none',
64 | weight=self.class_weights.to(target.device),
65 | )
66 |
67 | loss = loss.view(b, s, h, w)
68 |
69 | assert s >= n_present
70 | future_len = s - n_present
71 | future_discounts = self.future_discount ** torch.arange(1, future_len+1, device=loss.device, dtype=loss.dtype)
72 | discounts = torch.cat([torch.ones(n_present, device=loss.device, dtype=loss.dtype), future_discounts], dim=0)
73 | discounts = discounts.view(1, s, 1, 1)
74 | loss = loss * discounts
75 |
76 | loss = loss.view(b, s, -1)
77 | if self.use_top_k:
78 | # Penalises the top-k hardest pixels
79 | k = int(self.top_k_ratio * loss.shape[2])
80 | loss, _ = torch.sort(loss, dim=2, descending=True)
81 | loss = loss[:, :, :k]
82 |
83 | return torch.mean(loss)
84 |
85 | class HDmapLoss(nn.Module):
86 | def __init__(self, class_weights, training_weights, use_top_k, top_k_ratio, ignore_index=255):
87 | super(HDmapLoss, self).__init__()
88 | self.class_weights = class_weights
89 | self.training_weights = training_weights
90 | self.ignore_index = ignore_index
91 | self.use_top_k = use_top_k
92 | self.top_k_ratio = top_k_ratio
93 |
94 | def forward(self, prediction, target):
95 | loss = 0
96 | for i in range(target.shape[-3]):
97 | cur_target = target[:, i]
98 | b, h, w = cur_target.shape
99 | cur_prediction = prediction[:, 2*i:2*(i+1)]
100 | cur_loss = F.cross_entropy(
101 | cur_prediction,
102 | cur_target,
103 | ignore_index=self.ignore_index,
104 | reduction='none',
105 | weight=self.class_weights[i].to(target.device),
106 | )
107 |
108 | cur_loss = cur_loss.view(b, -1)
109 | if self.use_top_k[i]:
110 | k = int(self.top_k_ratio[i] * cur_loss.shape[1])
111 | cur_loss, _ = torch.sort(cur_loss, dim=1, descending=True)
112 | cur_loss = cur_loss[:, :k]
113 | loss += torch.mean(cur_loss) * self.training_weights[i]
114 | return loss
115 |
116 | class DepthLoss(nn.Module):
117 | def __init__(self, class_weights=None, ignore_index=255):
118 | super(DepthLoss, self).__init__()
119 | self.class_weights = class_weights
120 | self.ignore_index = ignore_index
121 |
122 | def forward(self, prediction, target):
123 | b, s, n, d, h, w = prediction.shape
124 |
125 | prediction = prediction.view(b*s*n, d, h, w)
126 | target = target.view(b*s*n, h, w)
127 | loss = F.cross_entropy(
128 | prediction,
129 | target,
130 | ignore_index=self.ignore_index,
131 | reduction='none',
132 | weight=self.class_weights
133 | )
134 | return torch.mean(loss)
135 |
136 |
137 | class ProbabilisticLoss(nn.Module):
138 | def __init__(self, method):
139 | super(ProbabilisticLoss, self).__init__()
140 | self.method = method
141 |
142 | def kl_div(self, present_mu, present_log_sigma, future_mu, future_log_sigma):
143 | var_future = torch.exp(2 * future_log_sigma)
144 | var_present = torch.exp(2 * present_log_sigma)
145 | kl_div = (
146 | present_log_sigma - future_log_sigma - 0.5 + (var_future + (future_mu - present_mu) ** 2) / (
147 | 2 * var_present)
148 | )
149 |
150 | kl_loss = torch.mean(torch.sum(kl_div, dim=-1))
151 | return kl_loss
152 |
153 | def forward(self, output):
154 | if self.method == 'GAUSSIAN':
155 | present_mu = output['present_mu']
156 | present_log_sigma = output['present_log_sigma']
157 | future_mu = output['future_mu']
158 | future_log_sigma = output['future_log_sigma']
159 |
160 | kl_loss = self.kl_div(present_mu, present_log_sigma, future_mu, future_log_sigma)
161 | elif self.method == 'MIXGAUSSIAN':
162 | present_mu = output['present_mu']
163 | present_log_sigma = output['present_log_sigma']
164 | future_mu = output['future_mu']
165 | future_log_sigma = output['future_log_sigma']
166 |
167 | kl_loss = 0
168 | for i in range(len(present_mu)):
169 | kl_loss += self.kl_div(present_mu[i], present_log_sigma[i], future_mu[i], future_log_sigma[i])
170 | elif self.method == 'BERNOULLI':
171 | present_log_prob = output['present_log_prob']
172 | future_log_prob = output['future_log_prob']
173 |
174 | kl_loss = F.kl_div(present_log_prob, future_log_prob, reduction='batchmean', log_target=True)
175 | else:
176 | raise NotImplementedError
177 |
178 |
179 | return kl_loss
--------------------------------------------------------------------------------
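
A minimal worked example of the temporal discounting shared by SpatialRegressionLoss and
SegmentationLoss above (not part of the repository; the numbers follow the defaults
TIME_RECEPTIVE_FIELD=3, N_FUTURE_FRAMES=4 and FUTURE_DISCOUNT=0.95 in sad/config.py):

    import torch

    n_present, future_len, future_discount = 3, 4, 0.95
    seq_len = n_present + future_len

    # Present frames keep weight 1.0; future frame k is weighted by future_discount ** k.
    future_discounts = future_discount ** torch.arange(1, future_len + 1, dtype=torch.float32)
    discounts = torch.cat([torch.ones(n_present), future_discounts])
    print(discounts)  # tensor([1.0000, 1.0000, 1.0000, 0.9500, 0.9025, 0.8574, 0.8145])

    # Applied to a per-pixel loss of shape (b, s, h, w):
    loss = torch.rand(2, seq_len, 200, 200)
    loss = loss * discounts.view(1, seq_len, 1, 1)
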
/sad/utils/sampler.py:
--------------------------------------------------------------------------------
1 | # sampler.py
2 | # trajectory sampler
3 | # including a mix of clothoids, straight lines, and circles
4 | import numpy as np
5 | import torch
6 | from scipy.special import fresnel
7 |
8 | def sample(v0, Kappa, T0, N0, tt, M, possibility = None):
9 | '''
10 | :param v0: initial velocity
11 | :param Kappa: curvature
12 | :param T0: initial tangent vector
13 | :param N0: initial normal vector
14 |     :param tt: time stamps of the trajectory horizon
15 |     :param M: the number of samples
16 |     :param possibility: probabilities of [left, straight, right], list or torch.Tensor of length 3
17 |     :param debug: whether in debug mode (currently unused)
18 |     :return: np.ndarray of sampled trajectories, shape (M, len(tt), 3)
19 | '''
20 | # sample accelerations
21 | if possibility is None:
22 | possibility = [0.4, 0.2, 0.4]
23 |
24 | straight_num = int(M * possibility[1])
25 | left_num = int(M * possibility[0])
26 | right_num = int(M * possibility[2])
27 |
28 | accelerations = 10*(np.random.rand(M)-0.5) + 2 # -3m/s^2 to 7m/s^2
29 |
30 | # sample velocities
31 | # randomly sample a velocity <=15m/s at 80% of time
32 | v_options = np.stack((np.full(M, v0), 15*np.random.rand(M)))
33 | v_selections = (np.random.rand(M) >= 0.2).astype(int)
34 | velocities = v_options[v_selections, np.arange(M)]
35 |
36 | # generate longitudinal distances
37 | L = velocities[:, None] * tt[None, :] + accelerations[:, None] * (tt[None, :]**2) / 2
38 | L_straight = L[:straight_num]
39 | L = L[straight_num:]
40 | # print("L:", L)
41 |
42 |     # scaling factor which determines the clothoid curve, sampled uniformly in [6, 80]
43 | alphas = (80 - 6) * np.random.rand(left_num + right_num) + 6
44 |
45 | ############################################################################
46 | # sample M straight lines
47 | line_points = L_straight[:, :, None] * T0[None, None, :]
48 | line_thetas = np.zeros_like(L_straight)
49 | lines = np.concatenate((line_points, line_thetas[:, :, None]), axis=-1)
50 |
51 | ############################################################################
52 | # sample M circles
53 | Krappa = min(-0.01, Kappa) if Kappa <= 0 else max(0.01, Kappa)
54 | radius = np.abs(1 / Krappa)
55 | center = np.array([-1 / Krappa, 0])
56 | circle_phis = L / radius if Krappa >= 0 else np.pi - L/radius
57 |
58 | circle_points = np.dstack([
59 | center[0] + radius * np.cos(circle_phis),
60 | center[1] + radius * np.sin(circle_phis),
61 | ])
62 |
63 | # rotate thetas, wrap
64 | circle_thetas = L/radius if Krappa >= 0 else -L/radius
65 | circle_thetas = (circle_thetas + np.pi) % (2 * np.pi) - np.pi
66 | circles = np.concatenate((circle_points, circle_thetas[:, :, None]), axis=-1)
67 |
68 | ############################################################################
69 | # sample M clothoids
70 | # Xi0 = Kappa / np.pi
71 | # Xis = Xi0 + L
72 | Xi0 = np.abs(Kappa) / np.pi
73 | Xis = Xi0 + L
74 |
75 | #
76 | # Ss, Cs = fresnel((Xis - Xi0) / alphas[:, None])
77 | Ss, Cs = fresnel(Xis / alphas[:, None])
78 |
79 | clothoid_points = alphas[:, None, None] * (Cs[:, :, None]*T0[None, None, :] + Ss[:, :, None]*N0[None, None, :])
80 |
81 | #
82 | Xs = clothoid_points[:, :, 0] - clothoid_points[:, 0, 0, None]
83 | Ys = clothoid_points[:, :, 1] - clothoid_points[:, 0, 1, None]
84 | clothoid_theta0s = 0.5 * np.pi * ((Kappa / np.pi / alphas) ** 2)
85 | clothoid_theta0s = clothoid_theta0s[:, None]
86 | signed_clothoid_theta0s = clothoid_theta0s * np.sign(Kappa)
87 | # when kappa is positive, the clothoid curves left, theta is positive
88 | # we will rotate it clockwise by theta
89 | # when kappa is negative, the clothoid curves right, theta is negative
90 | # we will rotate it counterclockwise by theta
91 | clothoid_points[:, :, 0] = np.cos(signed_clothoid_theta0s) * Xs + np.sin(signed_clothoid_theta0s) * Ys
92 | clothoid_points[:, :, 1] = - np.sin(signed_clothoid_theta0s) * Xs + np.cos(signed_clothoid_theta0s) * Ys
93 |
94 | # tangent vector: http://mathworld.wolfram.com/CornuSpiral.html
95 | clothoid_thetas = 0.5 * np.pi * ((Xis / alphas[:, None])**2)
96 | clothoid_thetas = clothoid_thetas - clothoid_theta0s
97 | signed_clothoid_thetas = clothoid_thetas * np.sign(Kappa)
98 | # clothoid_thetas = clothoid_thetas if Krappa >= 0 else -clothoid_thetas
99 | # wrap
100 | # clothoid_thetas = (clothoid_thetas + np.pi) % (2 * np.pi) - np.pi
101 | wrapped_signed_clothoid_thetas = (signed_clothoid_thetas + np.pi) % (2 * np.pi) - np.pi
102 | # wrapped_signed_clothoid_thetas = (signed_clothoid_thetas) % (2 * np.pi)
103 | #
104 | clothoids = np.concatenate((clothoid_points, wrapped_signed_clothoid_thetas[:, :, None]), axis=-1)
105 |
106 | ############################################################################
107 | # pick M in total
108 | t_options = np.stack((circles, clothoids))
109 | t_selections = np.random.choice([0, 1], size=left_num + right_num, p=(0.2, 0.8))
110 | ############################################################################
111 |
112 | trajs = t_options[t_selections, np.arange(left_num + right_num)]
113 |
114 | # toss a coin for vertical flipping
115 | # left_possibility = possibility[1] / (possibility[1] + possibility[2])
116 | # if Kappa > 0:
117 | # heads = (np.random.rand(M) <= left_possibility.item())
118 | # else:
119 | # heads = (np.random.rand(M) <= (1- left_possibility).item())
120 | # tails = np.logical_not(heads)
121 | #
122 | # # NOTE theta means what here
123 | # conditions = [heads[:, None, None], tails[:, None, None]]
124 | # choices = [trajs, np.dstack((
125 | # -trajs[:, :, 0], trajs[:, :, 1], -trajs[:, :, 2]
126 | # ))]
127 | #
128 | # trajectories = np.select(conditions, choices)
129 | if Kappa > 0:
130 | left_curve = trajs[: left_num]
131 | right_curve = trajs[left_num: left_num + right_num]
132 | right_curve = np.dstack((
133 | -right_curve[:, :, 0], right_curve[:, :, 1], -right_curve[:, :, 2]
134 | ))
135 | else:
136 | right_curve = trajs[: left_num]
137 | left_curve = trajs[left_num: left_num + right_num]
138 | left_curve = np.dstack((
139 | -left_curve[:, :, 0], left_curve[:, :, 1], -left_curve[:, :, 2]
140 | ))
141 |
142 | trajectories = np.concatenate([left_curve, lines, right_curve], axis=0)
143 | mask = np.argsort(trajectories[:, -1, 0])
144 | trajectories = trajectories[mask]
145 |
146 | return trajectories
147 |
148 | if __name__ == "__main__":
149 | from nuscenes.nuscenes import NuScenes
150 | from nuscenes.can_bus.can_bus_api import NuScenesCanBus
151 | import matplotlib.pyplot as plt
152 | nusc = NuScenes("v1.0-mini", '/home/hsc/data/Nuscenes')
153 | nusc_can = NuScenesCanBus(dataroot='/home/hsc/data/Nuscenes')
154 |
155 | for scene in nusc.scene:
156 | scene_name = scene["name"]
157 | scene_id = int(scene_name[-4:])
158 | if scene_id in nusc_can.can_blacklist:
159 | print(f"skipping {scene_name}")
160 | continue
161 |         pose = nusc_can.get_messages(scene_name, "pose")  # current pose of the ego vehicle, sampled at 50 Hz
162 |         saf = nusc_can.get_messages(scene_name, "steeranglefeedback")  # steering angle feedback in radians, at 100 Hz
163 |         vm = nusc_can.get_messages(scene_name, "vehicle_monitor")  # most vehicle information, but sampled at 2 Hz
164 |         # NOTE: I tried to verify that the relevant measurements are consistent
165 |         # across the multiple tables that contain redundant information.
166 |         # NOTE: verified that pose's velocity matches vehicle monitor's,
167 |         # but the pose table provides it at a much higher frequency.
168 |         # NOTE: likewise, steeranglefeedback's steering angle matches vehicle monitor's,
169 |         # but the steeranglefeedback table provides it at a much higher frequency.
170 | print(pose[23])
171 | print(saf[45])
172 | print(vm[0])
173 |
174 | # initial velocity (m/s)
175 | v0 = pose[23]["vel"][0]
176 | # curvature
177 | #Kappa = 2 * saf[45]["value"] / 2.588 # 2 x \phi / distance between front and rear
178 | Kappa = 0
179 | # T0: longitudinal axis Tangent vector
180 | T0 = np.array([0.0, 1.0])
181 | # N0: normal directional vector Normal vector
182 | N0 = np.array([1.0, 0.0]) if Kappa <= 0 else np.array([-1.0, 0.0])
183 | # tt: time stamps
184 | tt = np.arange(0.0, 3.01, 0.01)
185 | # M: number of samples
186 | M = 1800
187 | #
188 | debug = False
189 | #
190 | trajectories = sample(v0, Kappa, T0, N0, tt, M)
191 |
192 | trajectories = trajectories[:,::100]
193 |
194 | #
195 | for i in range(len(trajectories)):
196 | trajectory = trajectories[i]
197 | plt.plot(trajectory[:, 0], trajectory[:, 1])
198 | plt.grid(False)
199 | plt.axis("equal")
200 |
201 | plt.show()
202 |
203 | break
204 |
--------------------------------------------------------------------------------
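
A minimal standalone sketch of the sampler above that does not require the NuScenes CAN
bus (not part of the repository; the initial state below is made up):

    import numpy as np
    from sad.utils.sampler import sample

    v0, Kappa = 5.0, 0.0             # initial speed (m/s) and curvature (straight driving)
    T0 = np.array([0.0, 1.0])        # longitudinal tangent vector
    N0 = np.array([1.0, 0.0])        # normal vector (Kappa <= 0 convention from the demo above)
    tt = np.arange(0.0, 3.01, 0.01)  # 3 s horizon sampled at 100 Hz

    trajectories = sample(v0, Kappa, T0, N0, tt, M=600)
    print(trajectories.shape)        # (600, 301, 3): x, y, heading per timestep
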
/sad/models/decoder_snn_ms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision.models.resnet import resnet18
4 |
5 | from spikingjelly.clock_driven.layer import SeqToANNContainer
6 | from sad.models.module.Real_MS_ResNet import multi_step_sew_resnet18
7 | from sad.layers.convolutions import UpsamplingAdd, SpikingUpsamplingAdd_MS, SpikingDeepLabHead
8 | from spikingjelly.clock_driven.neuron import (
9 | MultiStepLIFNode,
10 | MultiStepParametricLIFNode,
11 | )
12 | from spikingjelly.clock_driven import surrogate, layer
13 |
14 | class MeanDim(nn.Module):
15 | def __init__(self, dim=0):
16 | super(MeanDim, self).__init__()
17 | self.dim = dim
18 |
19 | def forward(self, x):
20 | return x.mean(self.dim)
21 | class Decoder(nn.Module):
22 | def __init__(self, in_channels, n_classes, n_present, n_hdmap, predict_gate):
23 | super().__init__()
24 | self.perceive_hdmap = predict_gate['perceive_hdmap']
25 | self.predict_pedestrian = predict_gate['predict_pedestrian']
26 | self.predict_instance = predict_gate['predict_instance']
27 | self.predict_future_flow = predict_gate['predict_future_flow']
28 | self.planning = predict_gate['planning']
29 |
30 | self.n_classes = n_classes
31 | self.n_present = n_present
32 | if self.predict_instance is False and self.predict_future_flow is True:
33 | raise ValueError('flow cannot be True when not predicting instance')
34 |
35 | backbone = multi_step_sew_resnet18(pretrained=False, multi_step_neuron=MultiStepLIFNode)
36 |
37 | self.first_conv = SeqToANNContainer(nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False))
38 | self.bn1 = SeqToANNContainer(nn.BatchNorm2d(64))
39 | self.lif1 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch", surrogate_function=surrogate.ATan())
40 |
41 | self.layer1 = backbone.layer1
42 | self.layer2 = backbone.layer2
43 | self.layer3 = backbone.layer3
44 |
45 | shared_out_channels = in_channels
46 | self.up3_skip = SpikingUpsamplingAdd_MS(256, 128, scale_factor=2)
47 | #self.up3_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch", surrogate_function=surrogate.ATan())
48 | self.up2_skip = SpikingUpsamplingAdd_MS(128, 64, scale_factor=2)
49 | #self.up2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch", surrogate_function=surrogate.ATan())
50 | self.up1_skip = SpikingUpsamplingAdd_MS(64, shared_out_channels, scale_factor=2)
51 | self.up1_lif = MultiStepLIFNode(tau=2.0, v_threshold=1.0, detach_reset=True, backend="torch", surrogate_function=surrogate.ATan())
52 |
53 | self.segmentation_head = nn.Sequential(
54 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
55 | nn.BatchNorm2d(shared_out_channels),
56 | nn.Conv2d(shared_out_channels, self.n_classes, kernel_size=1, padding=0),
57 | )
58 |
59 | if self.predict_pedestrian:
60 | self.pedestrian_head = nn.Sequential(
61 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
62 | nn.BatchNorm2d(shared_out_channels),
63 | nn.Conv2d(shared_out_channels, self.n_classes, kernel_size=1, padding=0),
64 | )
65 |
66 | if self.perceive_hdmap:
67 | self.hdmap_head = nn.Sequential(
68 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
69 | nn.Conv2d(shared_out_channels, 2 * n_hdmap, kernel_size=1, padding=0),
70 | )
71 |
72 | if self.predict_instance:
73 | self.instance_offset_head = nn.Sequential(
74 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
75 | nn.BatchNorm2d(shared_out_channels),
76 | nn.ReLU(inplace=True),
77 | nn.Conv2d(shared_out_channels, 2, kernel_size=1, padding=0),
78 | )
79 | self.instance_center_head = nn.Sequential(
80 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
81 | nn.BatchNorm2d(shared_out_channels),
82 | nn.ReLU(inplace=True),
83 | nn.Conv2d(shared_out_channels, 1, kernel_size=1, padding=0),
84 | nn.Sigmoid(),
85 | )
86 |
87 | if self.predict_future_flow:
88 | self.instance_future_head = nn.Sequential(
89 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
90 | nn.BatchNorm2d(shared_out_channels),
91 | nn.ReLU(inplace=True),
92 | nn.Conv2d(shared_out_channels, 2, kernel_size=1, padding=0),
93 | )
94 |
95 | if self.planning:
96 | self.costvolume_head = nn.Sequential(
97 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
98 | nn.BatchNorm2d(shared_out_channels),
99 | nn.Conv2d(shared_out_channels, 1, kernel_size=1, padding=0),
100 | )
101 |
102 |
103 | def forward(self, x):
104 | b, s, c, h, w = x.shape
105 | x = x.reshape(s, b, c, h, w)
106 | #x = x.view(b * s, c, h, w)
107 | # (H, W)
108 | skip_x = {'1': x}
109 | x = self.up1_lif(x)
110 |
111 |
112 | # (H/2, W/2)
113 | x = self.first_conv(x)
114 | x = self.bn1(x)
115 | # _, c_1, h_1, w_1 = x.shape
116 | # x = x.reshape(s, b, c_1, h_1, w_1)
117 | x = self.layer1(x)
118 | skip_x['2'] = x
119 |
120 | # (H/4 , W/4)
121 | x = self.layer2(x)
122 | skip_x['3'] = x
123 |
124 | # (H/8, W/8)
125 | x = self.layer3(x) # (b*s, 256, 25, 25)
126 |
127 | # First upsample to (H/4, W/4)
128 | #x = x.view(b*s, *x.shape[2:])
129 | x = self.up3_skip(x, skip_x['3'])
130 | #x = x.reshape(s, b, *x.shape[1:])
131 | #x = self.up3_lif(x)
132 |
133 | # Second upsample to (H/2, W/2)
134 | #x = x.view(b*s, *x.shape[2:])
135 | x = self.up2_skip(x, skip_x['2'])
136 | # x = x.reshape(s, b, *x.shape[1:])
137 | # x = self.up2_lif(x)
138 |
139 | # Third upsample to (H, W)
140 | # x = x.view(b*s, *x.shape[2:])
141 | x = self.up1_skip(x, skip_x['1'])
142 | # x = x.reshape(s, b, *x.shape[1:])
143 | x = self.up1_lif(x)
144 | x = x.view(b*s, *x.shape[2:])
145 |
146 | segmentation_output = self.segmentation_head(x)
147 | pedestrian_output = self.pedestrian_head(x) if self.predict_pedestrian else None
148 | hdmap_output = self.hdmap_head(
149 | x.view(b, s, *x.shape[1:])[:, self.n_present - 1]) if self.perceive_hdmap else None
150 | instance_center_output = self.instance_center_head(x) if self.predict_instance else None
151 | instance_offset_output = self.instance_offset_head(x) if self.predict_instance else None
152 | instance_future_output = self.instance_future_head(x) if self.predict_future_flow else None
153 | costvolume = self.costvolume_head(x).squeeze(1) if self.planning else None
154 | return {
155 | 'segmentation': segmentation_output.view(b, s, *segmentation_output.shape[1:]),
156 | 'pedestrian': pedestrian_output.view(b, s, *pedestrian_output.shape[1:])
157 | if pedestrian_output is not None else None,
158 | 'hdmap': hdmap_output,
159 | 'instance_center': instance_center_output.view(b, s, *instance_center_output.shape[1:])
160 | if instance_center_output is not None else None,
161 | 'instance_offset': instance_offset_output.view(b, s, *instance_offset_output.shape[1:])
162 | if instance_offset_output is not None else None,
163 | 'instance_flow': instance_future_output.view(b, s, *instance_future_output.shape[1:])
164 | if instance_future_output is not None else None,
165 | 'costvolume': costvolume.view(b, s, *costvolume.shape[1:])
166 | if costvolume is not None else None,
167 | }
168 |
169 |
170 | # Decoder class
171 | # Note: the Decoder class is not repeated here, since it is already defined above in this file.
172 | # Make sure the complete Decoder class definition is included in this script, or imported appropriately.
173 |
174 | def main():
175 |     # Example configuration parameters
176 |     in_channels = 3  # number of input image channels
177 |     n_classes = 10  # number of segmentation classes
178 |     n_present = 5  # assumed number of present frames
179 |     n_hdmap = 2  # number of HD-map output channels
180 | predict_gate = {
181 | 'perceive_hdmap': True,
182 | 'predict_pedestrian': True,
183 | 'predict_instance': True,
184 | 'predict_future_flow': True,
185 | 'planning': True
186 | }
187 |
188 |     # Create a Decoder instance
189 | decoder = Decoder(in_channels, n_classes, n_present, n_hdmap, predict_gate)
190 |
191 |     # Build a dummy input tensor: batch_size=1, sequence_length=5, height=256, width=256
192 | x = torch.randn(1, 5, in_channels, 256, 256)
193 |
194 |     # Forward pass
195 | output = decoder(x)
196 |
197 |     # Print some information about the outputs
198 | print("Output keys:", output.keys())
199 | for key, value in output.items():
200 | if value is not None:
201 | print(f"{key}: shape {value.shape}")
202 | else:
203 | print(f"{key}: None")
204 |
205 | if __name__ == "__main__":
206 | main()
207 |
--------------------------------------------------------------------------------
/sad/models/module/.ipynb_checkpoints/MS_ResNet-checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from spikingjelly.clock_driven import functional
4 | try:
5 | from torchvision.models.utils import load_state_dict_from_url
6 | except ImportError:
7 | from torchvision._internally_replaced_utils import load_state_dict_from_url
8 |
9 | __all__ = ['MultiStepMSResNet', 'multi_step_sew_resnet18']
10 |
11 | model_urls = {
12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 |
14 | }
15 |
16 | # modified from https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
17 |
18 | def sew_function(x: torch.Tensor, y: torch.Tensor, cnf:str):
19 | if cnf == 'ADD':
20 | return x + y
21 | elif cnf == 'AND':
22 | return x * y
23 | elif cnf == 'IAND':
24 | return x * (1. - y)
25 | else:
26 | raise NotImplementedError
27 |
28 |
29 |
30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
31 | """3x3 convolution with padding"""
32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
33 | padding=dilation, groups=groups, bias=False, dilation=dilation)
34 |
35 |
36 | def conv1x1(in_planes, out_planes, stride=1):
37 | """1x1 convolution"""
38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
39 | class BasicBlock(nn.Module):
40 | expansion = 1
41 |
42 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
43 | base_width=64, dilation=1, norm_layer=None, cnf: str = None, single_step_neuron: callable = None, **kwargs):
44 | super(BasicBlock, self).__init__()
45 | if norm_layer is None:
46 | norm_layer = nn.BatchNorm2d
47 | if groups != 1 or base_width != 64:
48 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
49 | if dilation > 1:
50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
51 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
52 | self.conv1 = conv3x3(inplanes, planes, stride)
53 | self.bn1 = norm_layer(planes)
54 | self.sn1 = single_step_neuron(**kwargs)
55 | self.conv2 = conv3x3(planes, planes)
56 | self.bn2 = norm_layer(planes)
57 | self.sn2 = single_step_neuron(**kwargs)
58 | self.downsample = downsample
59 | if downsample is not None:
60 | self.downsample_sn = single_step_neuron(**kwargs)
61 | self.stride = stride
62 | self.cnf = cnf
63 |
64 | def forward(self, x):
65 | identity = x
66 |
67 | out = self.conv1(x)
68 | out = self.bn1(out)
69 | out = self.sn1(out)
70 |
71 | out = self.conv2(out)
72 | out = self.bn2(out)
73 | out = self.sn2(out)
74 |
75 | if self.downsample is not None:
76 | identity = self.downsample_sn(self.downsample(x))
77 |
78 | out = sew_function(identity, out, self.cnf)
79 |
80 | return out
81 |
82 | def extra_repr(self) -> str:
83 | return super().extra_repr() + f'cnf={self.cnf}'
84 |
85 |
86 |
87 | class MultiStepBasicBlock(BasicBlock):
88 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
89 | base_width=64, dilation=1, norm_layer=None, cnf: str = None, multi_step_neuron: callable = None, **kwargs):
90 | super().__init__(inplanes, planes, stride, downsample, groups,
91 | base_width, dilation, norm_layer, cnf, multi_step_neuron, **kwargs)
92 |
93 | def forward(self, x_seq):
94 | identity = x_seq
95 |
96 | out = self.sn1(x_seq)
97 |         out = functional.seq_to_ann_forward(out, [self.conv1, self.bn1])  # feed sn1's spikes, not the raw input
98 |
99 | out = self.sn2(out)
100 | out = functional.seq_to_ann_forward(out, [self.conv2, self.bn2])
101 |
102 |
103 | if self.downsample is not None:
104 | identity = functional.seq_to_ann_forward(x_seq, self.downsample)
105 |
106 | out = identity + out
107 |
108 | return out
109 |
110 |
111 | class MultiStepMSResNet(nn.Module):
112 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
113 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
114 | norm_layer=None, T:int=None, cnf: str=None, multi_step_neuron: callable = None, **kwargs):
115 | super().__init__()
116 | self.T = T
117 | if norm_layer is None:
118 | norm_layer = nn.BatchNorm2d
119 | self._norm_layer = norm_layer
120 |
121 | self.inplanes = 64
122 | self.dilation = 1
123 | if replace_stride_with_dilation is None:
124 | # each element in the tuple indicates if we should replace
125 | # the 2x2 stride with a dilated convolution instead
126 | replace_stride_with_dilation = [False, False, False]
127 | if len(replace_stride_with_dilation) != 3:
128 | raise ValueError("replace_stride_with_dilation should be None "
129 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
130 | self.groups = groups
131 | self.base_width = width_per_group
132 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
133 | bias=False)
134 | self.bn1 = norm_layer(self.inplanes)
135 | self.sn1 = multi_step_neuron(**kwargs)
136 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
137 | self.layer1 = self._make_layer(block, 64, layers[0], cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs)
138 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
139 | dilate=replace_stride_with_dilation[0], cnf=cnf, multi_step_neuron=multi_step_neuron,
140 | **kwargs)
141 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
142 | dilate=replace_stride_with_dilation[1], cnf=cnf, multi_step_neuron=multi_step_neuron,
143 | **kwargs)
144 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
145 | dilate=replace_stride_with_dilation[2], cnf=cnf, multi_step_neuron=multi_step_neuron,
146 | **kwargs)
147 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
148 | self.fc = nn.Linear(512 * block.expansion, num_classes)
149 |
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
154 | nn.init.constant_(m.weight, 1)
155 | nn.init.constant_(m.bias, 0)
156 |
157 | # Zero-initialize the last BN in each residual branch,
158 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
159 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
160 | if zero_init_residual:
161 | for m in self.modules():
162 |                 # Only BasicBlock is defined in this ResNet-18 variant
163 |                 # (a Bottleneck branch here would raise a NameError).
164 |                 if isinstance(m, BasicBlock):
165 |                     nn.init.constant_(m.bn2.weight, 0)
166 |
167 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False, cnf: str = None, multi_step_neuron: callable = None, **kwargs):
168 | norm_layer = self._norm_layer
169 | downsample = None
170 | previous_dilation = self.dilation
171 | if dilate:
172 | self.dilation *= stride
173 | stride = 1
174 | if stride != 1 or self.inplanes != planes * block.expansion:
175 | downsample = nn.Sequential(
176 | conv1x1(self.inplanes, planes * block.expansion, stride),
177 | norm_layer(planes * block.expansion),
178 | )
179 |
180 | layers = []
181 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
182 | self.base_width, previous_dilation, norm_layer, cnf, multi_step_neuron, **kwargs))
183 | self.inplanes = planes * block.expansion
184 | for _ in range(1, blocks):
185 | layers.append(block(self.inplanes, planes, groups=self.groups,
186 | base_width=self.base_width, dilation=self.dilation,
187 | norm_layer=norm_layer, cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs))
188 |
189 | return nn.Sequential(*layers)
190 |
191 | def _forward_impl(self, x: torch.Tensor):
192 | # See note [TorchScript super()]
193 | x_seq = None
194 | if x.dim() == 5:
195 | # x.shape = [T, N, C, H, W]
196 | x_seq = functional.seq_to_ann_forward(x, [self.conv1, self.bn1])
197 | else:
198 | assert self.T is not None, 'When x.shape is [N, C, H, W], self.T can not be None.'
199 | # x.shape = [N, C, H, W]
200 | x = self.conv1(x)
201 | x = self.bn1(x)
202 | x.unsqueeze_(0)
203 | x_seq = x.repeat(self.T, 1, 1, 1, 1)
204 |
205 | x_seq = functional.seq_to_ann_forward(x_seq, self.maxpool)
206 |
207 | x_seq = self.layer1(x_seq)
208 | x_seq = self.layer2(x_seq)
209 | x_seq = self.layer3(x_seq)
210 | x_seq = self.layer4(x_seq)
211 |
212 | x_seq = functional.seq_to_ann_forward(x_seq, self.avgpool)
213 | x_seq = self.sn1(x_seq)
214 | x_seq = torch.flatten(x_seq, 2)
215 | # x_seq = self.fc(x_seq.mean(0))
216 | x_seq = functional.seq_to_ann_forward(x_seq, self.fc)
217 |
218 | return x_seq
219 |
220 | def forward(self, x):
221 | """
222 | :param x: the input with `shape=[N, C, H, W]` or `[*, N, C, H, W]`
223 | :type x: torch.Tensor
224 | :return: output
225 | :rtype: torch.Tensor
226 | """
227 | return self._forward_impl(x)
228 |
229 |
230 |
231 |
232 | def _multi_step_sew_resnet(arch, block, layers, pretrained, progress, T, cnf, multi_step_neuron, **kwargs):
233 | model = MultiStepMSResNet(block, layers, T=T, cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs)
234 | if pretrained:
235 | state_dict = load_state_dict_from_url(model_urls[arch],
236 | progress=progress)
237 | model.load_state_dict(state_dict)
238 | return model
239 |
240 | def multi_step_sew_resnet18(pretrained=False, progress=True, T: int = None, cnf: str = None, multi_step_neuron: callable=None, **kwargs):
241 | """
242 | :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet
243 | :type pretrained: bool
244 | :param progress: If True, displays a progress bar of the download to stderr
245 | :type progress: bool
246 | :param T: total time-steps
247 | :type T: int
248 | :param cnf: the name of spike-element-wise function
249 | :type cnf: str
250 | :param multi_step_neuron: a multi-step neuron
251 | :type multi_step_neuron: callable
252 | :param kwargs: kwargs for `multi_step_neuron`
253 | :type kwargs: dict
254 | :return: Spiking ResNet-18
255 | :rtype: torch.nn.Module
256 |
257 | The multi-step spike-element-wise ResNet-18 `"Deep Residual Learning in Spiking Neural Networks" `_
258 |     adapted from the ResNet-18 model in `"Deep Residual Learning for Image Recognition" `_
259 | """
260 |
261 | return _multi_step_sew_resnet('resnet18', MultiStepBasicBlock, [2, 2, 2, 2], pretrained, progress, T, cnf, multi_step_neuron, **kwargs)
262 |
--------------------------------------------------------------------------------
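
A minimal sketch of how decoder_snn_ms.py above instantiates this backbone (not part of
the repository; it assumes Real_MS_ResNet.py exports the same multi_step_sew_resnet18 as
this checkpoint copy):

    import torch
    from spikingjelly.clock_driven.neuron import MultiStepLIFNode
    from spikingjelly.clock_driven import functional
    from sad.models.module.Real_MS_ResNet import multi_step_sew_resnet18

    # Random weights; pretrained=True would instead load the ImageNet ANN checkpoint.
    net = multi_step_sew_resnet18(pretrained=False, multi_step_neuron=MultiStepLIFNode)

    x_seq = torch.randn(4, 2, 3, 224, 224)  # [T, N, C, H, W]: 4 timesteps, batch of 2 RGB frames
    out = net(x_seq)
    print(out.shape)                        # torch.Size([4, 2, 1000]): per-timestep logits

    functional.reset_net(net)               # reset neuron states before the next sequence
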
/evaluate.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from PIL import Image
3 | import torch
4 | import torch.utils.data
5 | import numpy as np
6 | import torchvision
7 | from tqdm import tqdm
8 | from nuscenes.nuscenes import NuScenes
9 | import matplotlib
10 | from matplotlib import pyplot as plt
11 | import pathlib
12 | import datetime
13 |
14 | from sad.datas.NuscenesData import FuturePredictionDataset
15 | from sad.trainer import TrainingModule
16 | from sad.metrics import IntersectionOverUnion, PanopticMetric, PlanningMetric
17 | from sad.utils.network import preprocess_batch, NormalizeInverse
18 | from sad.utils.instance import predict_instance_segmentation_and_trajectories
19 | from sad.utils.visualisation import make_contour
20 |
21 | def mk_save_dir():
22 | now = datetime.datetime.now()
23 | string = '_'.join(map(lambda x: '%02d' % x, (now.month, now.day, now.hour, now.minute, now.second)))
24 | save_path = pathlib.Path('imgs') / string
25 | save_path.mkdir(parents=True, exist_ok=False)
26 | return save_path
27 |
28 | def eval(checkpoint_path, dataroot):
29 | save_path = mk_save_dir()
30 |
31 | trainer = TrainingModule.load_from_checkpoint(checkpoint_path, strict=True)
32 | print(f'Loaded weights from \n {checkpoint_path}')
33 | trainer.eval()
34 |
35 | device = torch.device('cuda:0')
36 | trainer.to(device)
37 | model = trainer.model
38 |
39 | cfg = model.cfg
40 | cfg.GPUS = "[0]"
41 | cfg.BATCHSIZE = 1
42 | cfg.LIFT.GT_DEPTH = False
43 | cfg.DATASET.DATAROOT = dataroot
44 | cfg.DATASET.MAP_FOLDER = dataroot
45 |
46 | dataroot = cfg.DATASET.DATAROOT
47 | nworkers = cfg.N_WORKERS
48 | nusc = NuScenes(version='v1.0-{}'.format(cfg.DATASET.VERSION), dataroot=dataroot, verbose=False)
49 | valdata = FuturePredictionDataset(nusc, 1, cfg)
50 | valloader = torch.utils.data.DataLoader(
51 | valdata, batch_size=cfg.BATCHSIZE, shuffle=False, num_workers=nworkers, pin_memory=True, drop_last=False
52 | )
53 |
54 | n_classes = len(cfg.SEMANTIC_SEG.VEHICLE.WEIGHTS)
55 | hdmap_class = cfg.SEMANTIC_SEG.HDMAP.ELEMENTS
56 | metric_vehicle_val = IntersectionOverUnion(n_classes).to(device)
57 | future_second = int(cfg.N_FUTURE_FRAMES / 2)
58 |
59 | if cfg.SEMANTIC_SEG.PEDESTRIAN.ENABLED:
60 | metric_pedestrian_val = IntersectionOverUnion(n_classes).to(device)
61 |
62 | if cfg.SEMANTIC_SEG.HDMAP.ENABLED:
63 | metric_hdmap_val = []
64 | for i in range(len(hdmap_class)):
65 | metric_hdmap_val.append(IntersectionOverUnion(2, absent_score=1).to(device))
66 |
67 | if cfg.INSTANCE_SEG.ENABLED:
68 | metric_panoptic_val = PanopticMetric(n_classes=n_classes).to(device)
69 |
70 | if cfg.PLANNING.ENABLED:
71 | metric_planning_val = []
72 | for i in range(future_second):
73 | metric_planning_val.append(PlanningMetric(cfg, 2*(i+1)).to(device))
74 |
75 |
76 | for index, batch in enumerate(tqdm(valloader)):
77 | preprocess_batch(batch, device)
78 | image = batch['image']
79 | intrinsics = batch['intrinsics']
80 | extrinsics = batch['extrinsics']
81 | future_egomotion = batch['future_egomotion']
82 | command = batch['command']
83 | trajs = batch['sample_trajectory']
84 | target_points = batch['target_point']
85 | B = len(image)
86 | labels = trainer.prepare_future_labels(batch)
87 |
88 | with torch.no_grad():
89 | output = model(
90 | image, intrinsics, extrinsics, future_egomotion
91 | )
92 |
93 | n_present = model.receptive_field
94 |
95 | # semantic segmentation metric
96 | seg_prediction = output['segmentation'].detach()
97 | seg_prediction = torch.argmax(seg_prediction, dim=2, keepdim=True)
98 | metric_vehicle_val(seg_prediction[:, n_present - 1:], labels['segmentation'][:, n_present - 1:])
99 |
100 | if cfg.SEMANTIC_SEG.PEDESTRIAN.ENABLED:
101 | pedestrian_prediction = output['pedestrian'].detach()
102 | pedestrian_prediction = torch.argmax(pedestrian_prediction, dim=2, keepdim=True)
103 | metric_pedestrian_val(pedestrian_prediction[:, n_present - 1:],
104 | labels['pedestrian'][:, n_present - 1:])
105 | else:
106 | pedestrian_prediction = torch.zeros_like(seg_prediction)
107 |
108 | if cfg.SEMANTIC_SEG.HDMAP.ENABLED:
109 | for i in range(len(hdmap_class)):
110 | hdmap_prediction = output['hdmap'][:, 2 * i:2 * (i + 1)].detach()
111 | hdmap_prediction = torch.argmax(hdmap_prediction, dim=1, keepdim=True)
112 | metric_hdmap_val[i](hdmap_prediction, labels['hdmap'][:, i:i + 1])
113 |
114 | if cfg.INSTANCE_SEG.ENABLED:
115 | pred_consistent_instance_seg = predict_instance_segmentation_and_trajectories(
116 | output, compute_matched_centers=False, make_consistent=True
117 | )
118 | metric_panoptic_val(pred_consistent_instance_seg[:, n_present - 1:],
119 | labels['instance'][:, n_present - 1:])
120 |
121 | if cfg.PLANNING.ENABLED:
122 | occupancy = torch.logical_or(seg_prediction, pedestrian_prediction)
123 | _, final_traj = model.planning(
124 | cam_front=output['cam_front'].detach(),
125 | trajs=trajs[:, :, 1:],
126 | gt_trajs=labels['gt_trajectory'][:, 1:],
127 | cost_volume=output['costvolume'][:, n_present:].detach(),
128 | semantic_pred=occupancy[:, n_present:].squeeze(2),
129 | hd_map=output['hdmap'].detach(),
130 | commands=command,
131 | target_points=target_points
132 | )
133 | occupancy = torch.logical_or(labels['segmentation'][:, n_present:].squeeze(2),
134 | labels['pedestrian'][:, n_present:].squeeze(2))
135 | for i in range(future_second):
136 | cur_time = (i+1)*2
137 | metric_planning_val[i](final_traj[:,:cur_time].detach(), labels['gt_trajectory'][:,1:cur_time+1], occupancy[:,:cur_time])
138 |
139 | if index % 100 == 0:
140 | save(output, labels, batch, n_present, index, save_path)
141 |
142 |
143 | results = {}
144 |
145 | scores = metric_vehicle_val.compute()
146 | results['vehicle_iou'] = scores[1]
147 |
148 | if cfg.SEMANTIC_SEG.PEDESTRIAN.ENABLED:
149 | scores = metric_pedestrian_val.compute()
150 | results['pedestrian_iou'] = scores[1]
151 |
152 | if cfg.SEMANTIC_SEG.HDMAP.ENABLED:
153 | for i, name in enumerate(hdmap_class):
154 | scores = metric_hdmap_val[i].compute()
155 | results[name + '_iou'] = scores[1]
156 |
157 | if cfg.INSTANCE_SEG.ENABLED:
158 | scores = metric_panoptic_val.compute()
159 | for key, value in scores.items():
160 | results['vehicle_'+key] = value[1]
161 |
162 | if cfg.PLANNING.ENABLED:
163 | for i in range(future_second):
164 | scores = metric_planning_val[i].compute()
165 | for key, value in scores.items():
166 |             results['plan_' + key + '_{}s'.format(i + 1)] = value.mean()
167 |
168 | for key, value in results.items():
169 | print(f'{key} : {value.item()}')
170 |
171 | def save(output, labels, batch, n_present, frame, save_path):
172 | hdmap = output['hdmap'].detach()
173 | segmentation = output['segmentation'][:, n_present - 1].detach()
174 | pedestrian = output['pedestrian'][:, n_present - 1].detach()
175 | gt_trajs = labels['gt_trajectory']
176 | images = batch['image']
177 |
178 | denormalise_img = torchvision.transforms.Compose(
179 | (NormalizeInverse(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
180 | torchvision.transforms.ToPILImage(),)
181 | )
182 |
183 | val_w = 2.99
184 | val_h = 2.99 * (224. / 480.)
185 | plt.figure(1, figsize=(4*val_w,2*val_h))
186 | width_ratios = (val_w,val_w,val_w,val_w)
187 | gs = matplotlib.gridspec.GridSpec(2, 4, width_ratios=width_ratios)
188 | gs.update(wspace=0.0, hspace=0.0, left=0.0, right=1.0, top=1.0, bottom=0.0)
189 |
190 | plt.subplot(gs[0, 0])
191 | #plt.annotate('FRONT LEFT', (0.01, 0.87), c='white', xycoords='axes fraction', fontsize=14)
192 | plt.imshow(denormalise_img(images[0,n_present-1,0].cpu()))
193 | plt.axis('off')
194 |
195 | plt.subplot(gs[0, 1])
196 | #plt.annotate('FRONT', (0.01, 0.87), c='white', xycoords='axes fraction', fontsize=14)
197 | plt.imshow(denormalise_img(images[0,n_present-1,1].cpu()))
198 | plt.axis('off')
199 |
200 | plt.subplot(gs[0, 2])
201 | #plt.annotate('FRONT RIGHT', (0.01, 0.87), c='white', xycoords='axes fraction', fontsize=14)
202 | plt.imshow(denormalise_img(images[0,n_present-1,2].cpu()))
203 | plt.axis('off')
204 |
205 | plt.subplot(gs[1, 0])
206 | #plt.annotate('BACK LEFT', (0.01, 0.87), c='white', xycoords='axes fraction', fontsize=14)
207 | showing = denormalise_img(images[0,n_present-1,3].cpu())
208 | showing = showing.transpose(Image.FLIP_LEFT_RIGHT)
209 | plt.imshow(showing)
210 | plt.axis('off')
211 |
212 | plt.subplot(gs[1, 1])
213 | #plt.annotate('BACK', (0.01, 0.87), c='white', xycoords='axes fraction', fontsize=14)
214 | showing = denormalise_img(images[0, n_present - 1, 4].cpu())
215 | showing = showing.transpose(Image.FLIP_LEFT_RIGHT)
216 | plt.imshow(showing)
217 | plt.axis('off')
218 |
219 | plt.subplot(gs[1, 2])
220 | #plt.annotate('BACK_RIGHT', (0.01, 0.87), c='white', xycoords='axes fraction', fontsize=14)
221 | showing = denormalise_img(images[0, n_present - 1, 5].cpu())
222 | showing = showing.transpose(Image.FLIP_LEFT_RIGHT)
223 | plt.imshow(showing)
224 | plt.axis('off')
225 |
226 | plt.subplot(gs[:, 3])
227 | showing = torch.zeros((200, 200, 3)).numpy()
228 | showing[:, :] = np.array([219 / 255, 215 / 255, 215 / 255])
229 |
230 | # drivable
231 | area = torch.argmax(hdmap[0, 2:4], dim=0).cpu().numpy()
232 | hdmap_index = area > 0
233 | showing[hdmap_index] = np.array([161 / 255, 158 / 255, 158 / 255])
234 |
235 | # lane
236 | area = torch.argmax(hdmap[0, 0:2], dim=0).cpu().numpy()
237 | hdmap_index = area > 0
238 | showing[hdmap_index] = np.array([84 / 255, 70 / 255, 70 / 255])
239 |
240 | # semantic
241 | semantic_seg = torch.argmax(segmentation[0], dim=0).cpu().numpy()
242 | semantic_index = semantic_seg > 0
243 | showing[semantic_index] = np.array([255 / 255, 128 / 255, 0 / 255])
244 |
245 | pedestrian_seg = torch.argmax(pedestrian[0], dim=0).cpu().numpy()
246 | pedestrian_index = pedestrian_seg > 0
247 | showing[pedestrian_index] = np.array([28 / 255, 81 / 255, 227 / 255])
248 |
249 | plt.imshow(make_contour(showing))
250 | plt.axis('off')
251 |
252 | bx = np.array([-50.0 + 0.5/2.0, -50.0 + 0.5/2.0])
253 | dx = np.array([0.5, 0.5])
254 | w, h = 1.85, 4.084
255 | pts = np.array([
256 | [-h / 2. + 0.5, w / 2.],
257 | [h / 2. + 0.5, w / 2.],
258 | [h / 2. + 0.5, -w / 2.],
259 | [-h / 2. + 0.5, -w / 2.],
260 | ])
261 | pts = (pts - bx) / dx
262 | pts[:, [0, 1]] = pts[:, [1, 0]]
263 | plt.fill(pts[:, 0], pts[:, 1], '#76b900')
264 |
265 | plt.xlim((200, 0))
266 | plt.ylim((0, 200))
267 | gt_trajs[0, :, :1] = gt_trajs[0, :, :1] * -1
268 | gt_trajs = (gt_trajs[0, :, :2].cpu().numpy() - bx) / dx
269 | plt.plot(gt_trajs[:, 0], gt_trajs[:, 1], linewidth=3.0)
270 |
271 | plt.savefig(save_path / ('%04d.png' % frame))
272 | plt.close()
273 |
274 | if __name__ == '__main__':
275 | parser = ArgumentParser(description='sad evaluation')
276 | parser.add_argument('--checkpoint', default='last.ckpt', type=str, help='path to checkpoint')
277 | parser.add_argument('--dataroot', default=None, type=str)
278 |
279 | args = parser.parse_args()
280 |
281 | eval(args.checkpoint, args.dataroot)
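282 |
283 |     # Example invocation (editorial note; assumes this file is the repository's
284 |     # evaluate.py and that --dataroot points to a standard nuScenes directory):
285 |     #   python evaluate.py --checkpoint /path/to/last.ckpt --dataroot /path/to/nuscenes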
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/sad/models/module/ms_conv.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from timm.models.layers import DropPath
4 | from spikingjelly.clock_driven.neuron import (
5 | MultiStepLIFNode,
6 | MultiStepParametricLIFNode,
7 | )
8 |
9 |
10 | class Erode(nn.Module):
11 | def __init__(self) -> None:
12 | super().__init__()
13 | self.pool = nn.MaxPool3d(
14 | kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1)
15 | )
16 |
17 | def forward(self, x):
18 | return self.pool(x)
19 |
20 | class SpatialGatingUnit(nn.Module):
21 | def __init__(self, d_ffn, seq_len, spike_mode):
22 | super().__init__()
23 | d_ffn = int(d_ffn/2)
24 | self.norm = nn.BatchNorm1d(seq_len)
25 | self.spatial_proj = nn.Conv1d(seq_len, seq_len, kernel_size=1)
26 | self.fc1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch")
27 | #self.local_proj = nn.Conv1d(d_ffn, d_ffn, kernel_size=4, groups=d_ffn, padding='same')
28 | nn.init.constant_(self.spatial_proj.bias, 1.0)
29 | #nn.init.constant_(self.local_proj.bias, 1.0)
30 |
31 |
32 |
33 | def forward(self, x):
34 | T, B, C, L = x.shape
35 |
36 |
37 | L = int(L/2)
38 | u, v = x.chunk(2, dim=-1)
39 | #u = self.local_proj(u.reshape(T*B, L, C))
40 |
41 | #u = self.spatial_proj(u.flatten(0, 1))
42 | #u = self.norm(u.flatten(0, 1))
43 | #u = self.fc1_lif(u.reshape(T, B, C, L))
44 | out = u * v
45 | return out
46 | class MS_MLP_Conv(nn.Module):
47 | def __init__(
48 | self,
49 | in_features,
50 | hidden_features=None,
51 | out_features=None,
52 | drop=0.0,
53 | spike_mode="lif",
54 | layer=0,
55 | ):
56 | super().__init__()
57 | out_features = out_features or in_features
58 | hidden_features = hidden_features or in_features
59 | self.res = in_features == hidden_features
60 | self.fc1_conv = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1)
61 | self.fc1_bn = nn.BatchNorm2d(hidden_features)
62 | self.sgu = SpatialGatingUnit(hidden_features, 64, spike_mode)
63 | if spike_mode == "lif":
64 | self.fc1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch")
65 | elif spike_mode == "plif":
66 | self.fc1_lif = MultiStepParametricLIFNode(
67 | init_tau=2.0, detach_reset=True, backend="torch"
68 | )
69 |
70 | self.fc2_conv = nn.Conv2d(
71 | int(hidden_features/2), out_features, kernel_size=1, stride=1
72 | )
73 | self.fc2_bn = nn.BatchNorm2d(out_features)
74 | if spike_mode == "lif":
75 | self.fc2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch")
76 | elif spike_mode == "plif":
77 | self.fc2_lif = MultiStepParametricLIFNode(
78 | init_tau=2.0, detach_reset=True, backend="torch"
79 | )
80 |
81 | self.c_hidden = hidden_features
82 | self.c_output = out_features
83 | self.layer = layer
84 |
85 | def forward(self, x, hook=None):
86 | T, B, C, H, W = x.shape
87 | identity = x
88 |
89 | x = self.fc1_lif(x)
90 | if hook is not None:
91 | hook[self._get_name() + str(self.layer) + "_fc1_lif"] = x.detach()
92 | x = self.fc1_conv(x.flatten(0, 1))
93 | x = self.fc1_bn(x).reshape(T, B, self.c_hidden, H, W).contiguous()
94 | if self.res:
95 | x = identity + x
96 | identity = x
97 | x = x.reshape(T, B, self.c_hidden, H*W).permute(0, 1, 3, 2)
98 |
99 | x = self.fc2_lif(x)
100 | x = self.sgu(x).permute(0, 1, 3, 2).reshape(T, B, int(self.c_hidden/2), H, W)
101 | if hook is not None:
102 | hook[self._get_name() + str(self.layer) + "_fc2_lif"] = x.detach()
103 | x = self.fc2_conv(x.flatten(0, 1))
104 | x = self.fc2_bn(x).reshape(T, B, C, H, W).contiguous()
105 |
106 | x = x + identity
107 | return x, hook
108 |
109 |
110 | class MS_SSA_Conv(nn.Module):
111 | def __init__(
112 | self,
113 | dim,
114 | num_heads=8,
115 | qkv_bias=False,
116 | qk_scale=None,
117 | attn_drop=0.0,
118 | proj_drop=0.0,
119 | sr_ratio=1,
120 | mode="direct_xor",
121 | spike_mode="lif",
122 | dvs=False,
123 | layer=0,
124 | ):
125 | super().__init__()
126 | assert (
127 | dim % num_heads == 0
128 | ), f"dim {dim} should be divided by num_heads {num_heads}."
129 | self.dim = dim
130 | self.dvs = dvs
131 | self.num_heads = num_heads
132 | if dvs:
133 | self.pool = Erode()
134 | self.scale = 0.125
135 | self.q_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=False)
136 | self.q_bn = nn.BatchNorm2d(dim)
137 | if spike_mode == "lif":
138 | self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch")
139 | elif spike_mode == "plif":
140 | self.q_lif = MultiStepParametricLIFNode(
141 | init_tau=2.0, detach_reset=True, backend="torch"
142 | )
143 |
144 | self.k_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=False)
145 | self.k_bn = nn.BatchNorm2d(dim)
146 | if spike_mode == "lif":
147 | self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch")
148 | elif spike_mode == "plif":
149 | self.k_lif = MultiStepParametricLIFNode(
150 | init_tau=2.0, detach_reset=True, backend="torch"
151 | )
152 |
153 | self.v_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=False)
154 | self.v_bn = nn.BatchNorm2d(dim)
155 | if spike_mode == "lif":
156 | self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="torch")
157 | elif spike_mode == "plif":
158 | self.v_lif = MultiStepParametricLIFNode(
159 | init_tau=2.0, detach_reset=True, backend="torch"
160 | )
161 |
162 | if spike_mode == "lif":
163 | self.attn_lif = MultiStepLIFNode(
164 | tau=2.0, v_threshold=0.5, detach_reset=True, backend="torch"
165 | )
166 | elif spike_mode == "plif":
167 | self.attn_lif = MultiStepParametricLIFNode(
168 | init_tau=2.0, v_threshold=0.5, detach_reset=True, backend="torch"
169 | )
170 |
171 | self.talking_heads = nn.Conv1d(
172 | num_heads, num_heads, kernel_size=1, stride=1, bias=False
173 | )
174 | if spike_mode == "lif":
175 | self.talking_heads_lif = MultiStepLIFNode(
176 | tau=2.0, v_threshold=0.5, detach_reset=True, backend="torch"
177 | )
178 | elif spike_mode == "plif":
179 | self.talking_heads_lif = MultiStepParametricLIFNode(
180 | init_tau=2.0, v_threshold=0.5, detach_reset=True, backend="torch"
181 | )
182 |
183 | self.proj_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
184 | self.proj_bn = nn.BatchNorm2d(dim)
185 |
186 | if spike_mode == "lif":
187 | self.shortcut_lif = MultiStepLIFNode(
188 | tau=2.0, detach_reset=True, backend="torch"
189 | )
190 | elif spike_mode == "plif":
191 | self.shortcut_lif = MultiStepParametricLIFNode(
192 | init_tau=2.0, detach_reset=True, backend="torch"
193 | )
194 |
195 | self.mode = mode
196 | self.layer = layer
197 |
198 | def forward(self, x, hook=None):
199 | T, B, C, H, W = x.shape
200 | identity = x
201 | N = H * W
202 | x = self.shortcut_lif(x)
203 | if hook is not None:
204 | hook[self._get_name() + str(self.layer) + "_first_lif"] = x.detach()
205 |
206 | x_for_qkv = x.flatten(0, 1)
207 | q_conv_out = self.q_conv(x_for_qkv)
208 | q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, H, W).contiguous()
209 | q_conv_out = self.q_lif(q_conv_out)
210 |
211 | if hook is not None:
212 | hook[self._get_name() + str(self.layer) + "_q_lif"] = q_conv_out.detach()
213 | q = (
214 | q_conv_out.flatten(3)
215 | .transpose(-1, -2)
216 | .reshape(T, B, N, self.num_heads, C // self.num_heads)
217 | .permute(0, 1, 3, 2, 4)
218 | .contiguous()
219 | )
220 |
221 | k_conv_out = self.k_conv(x_for_qkv)
222 | k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, H, W).contiguous()
223 | k_conv_out = self.k_lif(k_conv_out)
224 | if self.dvs:
225 | k_conv_out = self.pool(k_conv_out)
226 | if hook is not None:
227 | hook[self._get_name() + str(self.layer) + "_k_lif"] = k_conv_out.detach()
228 | k = (
229 | k_conv_out.flatten(3)
230 | .transpose(-1, -2)
231 | .reshape(T, B, N, self.num_heads, C // self.num_heads)
232 | .permute(0, 1, 3, 2, 4)
233 | .contiguous()
234 | )
235 |
236 | v_conv_out = self.v_conv(x_for_qkv)
237 | v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, H, W).contiguous()
238 | v_conv_out = self.v_lif(v_conv_out)
239 | if self.dvs:
240 | v_conv_out = self.pool(v_conv_out)
241 | if hook is not None:
242 | hook[self._get_name() + str(self.layer) + "_v_lif"] = v_conv_out.detach()
243 | v = (
244 | v_conv_out.flatten(3)
245 | .transpose(-1, -2)
246 | .reshape(T, B, N, self.num_heads, C // self.num_heads)
247 | .permute(0, 1, 3, 2, 4)
248 | .contiguous()
249 | ) # T B head N C//h
250 |
251 | kv = k.mul(v)
252 | if hook is not None:
253 | hook[self._get_name() + str(self.layer) + "_kv_before"] = kv
254 | if self.dvs:
255 | kv = self.pool(kv)
256 | kv = kv.sum(dim=-2, keepdim=True)
257 | kv = self.talking_heads_lif(kv)
258 | if hook is not None:
259 | hook[self._get_name() + str(self.layer) + "_kv"] = kv.detach()
260 | x = q.mul(kv)
261 | if self.dvs:
262 | x = self.pool(x)
263 | if hook is not None:
264 | hook[self._get_name() + str(self.layer) + "_x_after_qkv"] = x.detach()
265 |
266 | x = x.transpose(3, 4).reshape(T, B, C, H, W).contiguous()
267 | x = (
268 | self.proj_bn(self.proj_conv(x.flatten(0, 1)))
269 | .reshape(T, B, C, H, W)
270 | .contiguous()
271 | )
272 |
273 | x = x + identity
274 | return x, v, hook
275 |
276 |
277 | class MS_Block_Conv(nn.Module):
278 | def __init__(
279 | self,
280 | dim,
281 | num_heads,
282 | mlp_ratio=4.0,
283 | qkv_bias=False,
284 | qk_scale=None,
285 | drop=0.0,
286 | attn_drop=0.0,
287 | drop_path=0.0,
288 | norm_layer=nn.LayerNorm,
289 | sr_ratio=1,
290 | attn_mode="direct_xor",
291 | spike_mode="lif",
292 | dvs=False,
293 | layer=0,
294 | ):
295 | super().__init__()
296 | # self.attn = MS_SSA_Conv(
297 | # dim,
298 | # num_heads=num_heads,
299 | # qkv_bias=qkv_bias,
300 | # qk_scale=qk_scale,
301 | # attn_drop=attn_drop,
302 | # proj_drop=drop,
303 | # sr_ratio=sr_ratio,
304 | # mode=attn_mode,
305 | # spike_mode=spike_mode,
306 | # dvs=dvs,
307 | # layer=layer,
308 | # )
309 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
310 | mlp_hidden_dim = int(dim * mlp_ratio)
311 | self.mlp = MS_MLP_Conv(
312 | in_features=dim,
313 | hidden_features=mlp_hidden_dim,
314 | drop=drop,
315 | spike_mode=spike_mode,
316 | layer=layer,
317 | )
318 |
319 | def forward(self, x, hook=None):
320 | #x_attn, attn, hook = self.attn(x, hook=hook)
321 | x, hook = self.mlp(x, hook=hook)
322 | return x, hook
323 |
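324 |
325 | if __name__ == "__main__":
326 |     # Editorial usage sketch (assumption, not part of the original module): run one
327 |     # MS_Block_Conv over a dummy [T, B, C, H, W] tensor to sanity-check the shapes.
328 |     block = MS_Block_Conv(dim=256, num_heads=8, mlp_ratio=4.0, spike_mode="lif")
329 |     dummy = torch.rand(4, 2, 256, 8, 8)  # [T=4, B=2, C=256, H=8, W=8]
330 |     out, _ = block(dummy)
331 |     print(out.shape)  # expected: torch.Size([4, 2, 256, 8, 8])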
--------------------------------------------------------------------------------
/sad/models/module/MS_ResNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from spikingjelly.clock_driven import functional
4 | try:
5 | from torchvision.models.utils import load_state_dict_from_url
6 | except ImportError:
7 | from torchvision._internally_replaced_utils import load_state_dict_from_url
8 |
9 | __all__ = ['MultiStepMSResNet', 'multi_step_sew_resnet18']
10 |
11 | model_urls = {
12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 |
14 | }
15 |
16 | # adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
17 |
18 | def sew_function(x: torch.Tensor, y: torch.Tensor, cnf:str):
19 | if cnf == 'ADD':
20 | return x + y
21 | elif cnf == 'AND':
22 | return x * y
23 | elif cnf == 'IAND':
24 | return x * (1. - y)
25 | else:
26 | raise NotImplementedError
27 |
28 |
29 |
30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
31 | """3x3 convolution with padding"""
32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
33 | padding=dilation, groups=groups, bias=False, dilation=dilation)
34 |
35 |
36 | def conv1x1(in_planes, out_planes, stride=1):
37 | """1x1 convolution"""
38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
39 | class BasicBlock(nn.Module):
40 | expansion = 1
41 |
42 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
43 | base_width=64, dilation=1, norm_layer=None, cnf: str = None, single_step_neuron: callable = None, **kwargs):
44 | super(BasicBlock, self).__init__()
45 | if norm_layer is None:
46 | norm_layer = nn.BatchNorm2d
47 | if groups != 1 or base_width != 64:
48 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
49 | if dilation > 1:
50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
51 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
52 | self.conv1 = conv3x3(inplanes, planes, stride)
53 | self.bn1 = norm_layer(planes)
54 | self.sn1 = single_step_neuron(**kwargs)
55 | self.conv2 = conv3x3(planes, planes)
56 | self.bn2 = norm_layer(planes)
57 | self.sn2 = single_step_neuron(**kwargs)
58 | self.downsample = downsample
59 | if downsample is not None:
60 | self.downsample_sn = single_step_neuron(**kwargs)
61 | self.stride = stride
62 | self.cnf = cnf
63 |
64 | def forward(self, x):
65 | identity = x
66 |
67 | out = self.conv1(x)
68 | out = self.bn1(out)
69 | out = self.sn1(out)
70 |
71 | out = self.conv2(out)
72 | out = self.bn2(out)
73 | out = self.sn2(out)
74 |
75 | if self.downsample is not None:
76 | identity = self.downsample_sn(self.downsample(x))
77 |
78 | out = sew_function(identity, out, self.cnf)
79 |
80 | return out
81 |
82 | def extra_repr(self) -> str:
83 | return super().extra_repr() + f'cnf={self.cnf}'
84 |
85 |
86 |
87 | # class MultiStepBasicBlock(BasicBlock):
88 | # def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
89 | # base_width=64, dilation=1, norm_layer=None, cnf: str = None, multi_step_neuron: callable = None, **kwargs):
90 | # super().__init__(inplanes, planes, stride, downsample, groups,
91 | # base_width, dilation, norm_layer, cnf, multi_step_neuron, **kwargs)
92 | #
93 | # def forward(self, x_seq):
94 | # identity = x_seq
95 | #
96 | # out = self.sn1(x_seq)
97 | # out = functional.seq_to_ann_forward(x_seq, [self.conv1, self.bn1])
98 | #
99 | # out = self.sn2(out)
100 | # out = functional.seq_to_ann_forward(out, [self.conv2, self.bn2])
101 | #
102 | #
103 | # if self.downsample is not None:
104 | # identity = functional.seq_to_ann_forward(x_seq, self.downsample)
105 | #
106 | # out = identity + out
107 | #
108 | # return out
109 |
110 | class MultiStepBasicBlock(BasicBlock):
111 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
112 | base_width=64, dilation=1, norm_layer=None, cnf: str = None, multi_step_neuron: callable = None, **kwargs):
113 | super().__init__(inplanes, planes, stride, downsample, groups,
114 | base_width, dilation, norm_layer, cnf, multi_step_neuron, **kwargs)
115 |
116 | def forward(self, x_seq):
117 | identity = x_seq
118 |
119 | out = functional.seq_to_ann_forward(x_seq, [self.conv1, self.bn1])
120 | out = self.sn1(out)
121 |
122 | out = functional.seq_to_ann_forward(out, [self.conv2, self.bn2])
123 | out = self.sn2(out)
124 |
125 | if self.downsample is not None:
126 | identity = self.downsample_sn(functional.seq_to_ann_forward(x_seq, self.downsample))
127 |
128 | out = sew_function(identity, out, self.cnf)
129 |
130 | return out
131 |
132 | class MultiStepMSResNet(nn.Module):
133 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
134 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
135 | norm_layer=None, T:int=None, cnf: str=None, multi_step_neuron: callable = None, **kwargs):
136 | super().__init__()
137 | self.T = T
138 | if norm_layer is None:
139 | norm_layer = nn.BatchNorm2d
140 | self._norm_layer = norm_layer
141 |
142 | self.inplanes = 64
143 | self.dilation = 1
144 | if replace_stride_with_dilation is None:
145 | # each element in the tuple indicates if we should replace
146 | # the 2x2 stride with a dilated convolution instead
147 | replace_stride_with_dilation = [False, False, False]
148 | if len(replace_stride_with_dilation) != 3:
149 | raise ValueError("replace_stride_with_dilation should be None "
150 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
151 | self.groups = groups
152 | self.base_width = width_per_group
153 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
154 | bias=False)
155 | self.bn1 = norm_layer(self.inplanes)
156 | self.sn1 = multi_step_neuron(**kwargs)
157 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
158 | self.layer1 = self.make_layer(block, 64, layers[0], cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs)
159 | self.layer2 = self.make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0], cnf=cnf,
160 | multi_step_neuron=multi_step_neuron, **kwargs)
161 | self.layer3 = self.make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1], cnf=cnf,
162 | multi_step_neuron=multi_step_neuron, **kwargs)
163 | self.layer4 = self.make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2], cnf=cnf,
164 | multi_step_neuron=multi_step_neuron, **kwargs)
165 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
166 | self.fc = nn.Linear(512 * block.expansion, num_classes)
167 |
168 | for m in self.modules():
169 | if isinstance(m, nn.Conv2d):
170 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
171 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
172 | nn.init.constant_(m.weight, 1)
173 | nn.init.constant_(m.bias, 0)
174 |
175 | # Zero-initialize the last BN in each residual branch,
176 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
177 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
178 | if zero_init_residual:
179 | for m in self.modules():
180 |                 # Only BasicBlock is defined in this module (there is no Bottleneck
181 |                 # variant), so zero-init the last BN of each basic residual branch.
182 |                 if isinstance(m, BasicBlock):
183 |                     nn.init.constant_(m.bn2.weight, 0)
184 |
185 | def make_layer(self, block, planes, blocks, stride=1, dilate=False, cnf: str = None, multi_step_neuron: callable = None, **kwargs):
186 | norm_layer = self._norm_layer
187 | downsample = None
188 | previous_dilation = self.dilation
189 | if dilate:
190 | self.dilation *= stride
191 | stride = 1
192 | if stride != 1 or self.inplanes != planes * block.expansion:
193 | downsample = nn.Sequential(
194 | conv1x1(self.inplanes, planes * block.expansion, stride),
195 | norm_layer(planes * block.expansion),
196 | )
197 |
198 | layers = []
199 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
200 | self.base_width, previous_dilation, norm_layer, cnf, multi_step_neuron, **kwargs))
201 | self.inplanes = planes * block.expansion
202 | for _ in range(1, blocks):
203 | layers.append(block(self.inplanes, planes, groups=self.groups,
204 | base_width=self.base_width, dilation=self.dilation,
205 | norm_layer=norm_layer, cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs))
206 |
207 | return nn.Sequential(*layers)
208 |
209 | def _forward_impl(self, x: torch.Tensor):
210 | # See note [TorchScript super()]
211 | x_seq = None
212 | if x.dim() == 5:
213 | # x.shape = [T, N, C, H, W]
214 | x_seq = functional.seq_to_ann_forward(x, [self.conv1, self.bn1])
215 | else:
216 | assert self.T is not None, 'When x.shape is [N, C, H, W], self.T can not be None.'
217 | # x.shape = [N, C, H, W]
218 | x = self.conv1(x)
219 | x = self.bn1(x)
220 | x.unsqueeze_(0)
221 | x_seq = x.repeat(self.T, 1, 1, 1, 1)
222 |
223 | x_seq = functional.seq_to_ann_forward(x_seq, self.maxpool)
224 |
225 | x_seq = self.layer1(x_seq)
226 | x_seq = self.layer2(x_seq)
227 | x_seq = self.layer3(x_seq)
228 | x_seq = self.layer4(x_seq)
229 |
230 | x_seq = functional.seq_to_ann_forward(x_seq, self.avgpool)
231 | x_seq = self.sn1(x_seq)
232 | x_seq = torch.flatten(x_seq, 2)
233 | # x_seq = self.fc(x_seq.mean(0))
234 | x_seq = functional.seq_to_ann_forward(x_seq, self.fc)
235 |
236 | return x_seq
237 |
238 | def forward(self, x):
239 | """
240 | :param x: the input with `shape=[N, C, H, W]` or `[*, N, C, H, W]`
241 | :type x: torch.Tensor
242 | :return: output
243 | :rtype: torch.Tensor
244 | """
245 | return self._forward_impl(x)
246 |
247 |
248 |
249 |
250 | def _multi_step_sew_resnet(arch, block, layers, pretrained, progress, T, cnf, multi_step_neuron, **kwargs):
251 | model = MultiStepMSResNet(block, layers, T=T, cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs)
252 | if pretrained:
253 | state_dict = load_state_dict_from_url(model_urls[arch],
254 | progress=progress)
255 | model.load_state_dict(state_dict)
256 | return model
257 |
258 | def multi_step_sew_resnet18(pretrained=False, progress=True, T: int = None, cnf: str = 'ADD', multi_step_neuron: callable=None, **kwargs):
259 | """
260 | :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet
261 | :type pretrained: bool
262 | :param progress: If True, displays a progress bar of the download to stderr
263 | :type progress: bool
264 | :param T: total time-steps
265 | :type T: int
266 | :param cnf: the name of spike-element-wise function
267 | :type cnf: str
268 | :param multi_step_neuron: a multi-step neuron
269 | :type multi_step_neuron: callable
270 | :param kwargs: kwargs for `multi_step_neuron`
271 | :type kwargs: dict
272 | :return: Spiking ResNet-18
273 | :rtype: torch.nn.Module
274 |
275 |     The multi-step spike-element-wise ResNet-18 from `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_,
276 |     adapted from the ResNet-18 model of `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
277 | """
278 |
279 | return _multi_step_sew_resnet('resnet18', MultiStepBasicBlock, [2, 2, 2, 2], pretrained, progress, T, cnf, multi_step_neuron, **kwargs)
280 |
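281 |
282 | if __name__ == "__main__":
283 |     # Editorial usage sketch (assumption, not part of the original module): build the
284 |     # spiking ResNet-18 with a SpikingJelly multi-step LIF neuron and feed one frame
285 |     # batch, which is replicated over T time-steps internally.
286 |     from spikingjelly.clock_driven.neuron import MultiStepLIFNode
287 |     net = multi_step_sew_resnet18(T=4, cnf='ADD', multi_step_neuron=MultiStepLIFNode,
288 |                                   tau=2.0, detach_reset=True, backend='torch')
289 |     logits = net(torch.randn(2, 3, 224, 224))  # [N, C, H, W] -> [T, N, num_classes]
290 |     print(logits.shape)  # expected: torch.Size([4, 2, 1000])
291 |     functional.reset_net(net)  # clear neuron states before the next sequence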
--------------------------------------------------------------------------------
/sad/models/module/Real_MS_ResNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from spikingjelly.clock_driven import functional
4 |
5 | try:
6 | from torchvision.models.utils import load_state_dict_from_url
7 | except ImportError:
8 | from torchvision._internally_replaced_utils import load_state_dict_from_url
9 |
10 | __all__ = ['MultiStepMSResNet', 'multi_step_sew_resnet18']
11 |
12 | model_urls = {
13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14 |
15 | }
16 |
17 |
18 | # adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
19 |
20 | def sew_function(x: torch.Tensor, y: torch.Tensor, cnf: str):
21 | if cnf == 'ADD':
22 | return x + y
23 | elif cnf == 'AND':
24 | return x * y
25 | elif cnf == 'IAND':
26 | return x * (1. - y)
27 | else:
28 | raise NotImplementedError
29 |
30 |
31 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
32 | """3x3 convolution with padding"""
33 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
34 | padding=dilation, groups=groups, bias=False, dilation=dilation)
35 |
36 |
37 | def conv1x1(in_planes, out_planes, stride=1):
38 | """1x1 convolution"""
39 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
40 |
41 |
42 | class BasicBlock(nn.Module):
43 | expansion = 1
44 |
45 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
46 | base_width=64, dilation=1, norm_layer=None, cnf: str = None, single_step_neuron: callable = None,
47 | **kwargs):
48 | super(BasicBlock, self).__init__()
49 | if norm_layer is None:
50 | norm_layer = nn.BatchNorm2d
51 | if groups != 1 or base_width != 64:
52 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
53 | if dilation > 1:
54 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
55 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
56 | self.conv1 = conv3x3(inplanes, planes, stride)
57 | self.bn1 = norm_layer(planes)
58 | self.sn1 = single_step_neuron(**kwargs)
59 | self.conv2 = conv3x3(planes, planes)
60 | self.bn2 = norm_layer(planes)
61 | self.sn2 = single_step_neuron(**kwargs)
62 | self.downsample = downsample
63 | if downsample is not None:
64 | self.downsample_sn = single_step_neuron(**kwargs)
65 | self.stride = stride
66 | self.cnf = cnf
67 |
68 | def forward(self, x):
69 | identity = x
70 |
71 | out = self.conv1(x)
72 | out = self.bn1(out)
73 | out = self.sn1(out)
74 |
75 | out = self.conv2(out)
76 | out = self.bn2(out)
77 | out = self.sn2(out)
78 |
79 | if self.downsample is not None:
80 | identity = self.downsample_sn(self.downsample(x))
81 |
82 | out = sew_function(identity, out, self.cnf)
83 |
84 | return out
85 |
86 | def extra_repr(self) -> str:
87 | return super().extra_repr() + f'cnf={self.cnf}'
88 |
89 |
90 | class MultiStepBasicBlock(BasicBlock):
91 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
92 | base_width=64, dilation=1, norm_layer=None, cnf: str = None, multi_step_neuron: callable = None, **kwargs):
93 | super().__init__(inplanes, planes, stride, downsample, groups,
94 | base_width, dilation, norm_layer, cnf, multi_step_neuron, **kwargs)
95 |
96 | def forward(self, x_seq):
97 | identity = x_seq
98 |
99 | out = self.sn1(x_seq)
100 | out = functional.seq_to_ann_forward(out, [self.conv1, self.bn1])
101 |
102 | out = self.sn2(out)
103 | out = functional.seq_to_ann_forward(out, [self.conv2, self.bn2])
104 |
105 |
106 | if self.downsample is not None:
107 | x_seq = self.downsample_sn(x_seq)
108 | identity = functional.seq_to_ann_forward(x_seq, self.downsample)
109 |
110 | out = identity + out
111 |
112 | return out
113 |
114 | # class MultiStepBasicBlock(BasicBlock):
115 | # def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
116 | # base_width=64, dilation=1, norm_layer=None, cnf: str = None, multi_step_neuron: callable = None,
117 | # **kwargs):
118 | # super().__init__(inplanes, planes, stride, downsample, groups,
119 | # base_width, dilation, norm_layer, cnf, multi_step_neuron, **kwargs)
120 | #
121 | # def forward(self, x_seq):
122 | # identity = x_seq
123 | #
124 | # x_seq = self.sn1(x_seq)
125 | # out = functional.seq_to_ann_forward(x_seq, [self.conv1, self.bn1])
126 | #
127 | # out = self.sn2(out)
128 | # out = functional.seq_to_ann_forward(out, [self.conv2, self.bn2])
129 | #
130 | #
131 | # if self.downsample is not None:
132 | # identity = self.downsample_sn(functional.seq_to_ann_forward(x_seq, self.downsample))
133 | #
134 | # out = sew_function(identity, out, self.cnf)
135 | #
136 | # return out
137 |
138 |
139 | class MultiStepMSResNet(nn.Module):
140 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
141 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
142 | norm_layer=None, T: int = None, cnf: str = None, multi_step_neuron: callable = None, **kwargs):
143 | super().__init__()
144 | self.T = T
145 | if norm_layer is None:
146 | norm_layer = nn.BatchNorm2d
147 | self._norm_layer = norm_layer
148 |
149 | self.inplanes = 64
150 | self.dilation = 1
151 | if replace_stride_with_dilation is None:
152 | # each element in the tuple indicates if we should replace
153 | # the 2x2 stride with a dilated convolution instead
154 | replace_stride_with_dilation = [False, False, False]
155 | if len(replace_stride_with_dilation) != 3:
156 | raise ValueError("replace_stride_with_dilation should be None "
157 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
158 | self.groups = groups
159 | self.base_width = width_per_group
160 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
161 | bias=False)
162 | self.bn1 = norm_layer(self.inplanes)
163 | self.sn1 = multi_step_neuron(**kwargs)
164 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
165 | self.layer1 = self.make_layer(block, 64, layers[0], cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs)
166 | self.layer2 = self.make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0], cnf=cnf,
167 | multi_step_neuron=multi_step_neuron, **kwargs)
168 | self.layer3 = self.make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1], cnf=cnf,
169 | multi_step_neuron=multi_step_neuron, **kwargs)
170 | self.layer4 = self.make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2], cnf=cnf,
171 | multi_step_neuron=multi_step_neuron, **kwargs)
172 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
173 | self.fc = nn.Linear(512 * block.expansion, num_classes)
174 |
175 | for m in self.modules():
176 | if isinstance(m, nn.Conv2d):
177 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
178 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
179 | nn.init.constant_(m.weight, 1)
180 | nn.init.constant_(m.bias, 0)
181 |
182 | # Zero-initialize the last BN in each residual branch,
183 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
184 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
185 | if zero_init_residual:
186 | for m in self.modules():
187 |                 # Only BasicBlock is defined in this module (there is no Bottleneck
188 |                 # variant), so zero-init the last BN of each basic residual branch.
189 |                 if isinstance(m, BasicBlock):
190 |                     nn.init.constant_(m.bn2.weight, 0)
191 |
192 | def make_layer(self, block, planes, blocks, stride=1, dilate=False, cnf: str = None,
193 | multi_step_neuron: callable = None, **kwargs):
194 | norm_layer = self._norm_layer
195 | downsample = None
196 | previous_dilation = self.dilation
197 | if dilate:
198 | self.dilation *= stride
199 | stride = 1
200 | if stride != 1 or self.inplanes != planes * block.expansion:
201 | downsample = nn.Sequential(
202 | conv1x1(self.inplanes, planes * block.expansion, stride),
203 | norm_layer(planes * block.expansion),
204 | )
205 |
206 | layers = []
207 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
208 | self.base_width, previous_dilation, norm_layer, cnf, multi_step_neuron, **kwargs))
209 | self.inplanes = planes * block.expansion
210 | for _ in range(1, blocks):
211 | layers.append(block(self.inplanes, planes, groups=self.groups,
212 | base_width=self.base_width, dilation=self.dilation,
213 | norm_layer=norm_layer, cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs))
214 |
215 | return nn.Sequential(*layers)
216 |
217 | def _forward_impl(self, x: torch.Tensor):
218 | # See note [TorchScript super()]
219 | x_seq = None
220 | if x.dim() == 5:
221 | # x.shape = [T, N, C, H, W]
222 | x_seq = functional.seq_to_ann_forward(x, [self.conv1, self.bn1])
223 | else:
224 |             assert self.T is not None, 'When x.shape is [N, C, H, W], self.T cannot be None.'
225 | # x.shape = [N, C, H, W]
226 | x = self.conv1(x)
227 | x = self.bn1(x)
228 | x.unsqueeze_(0)
229 | x_seq = x.repeat(self.T, 1, 1, 1, 1)
230 |
231 | x_seq = functional.seq_to_ann_forward(x_seq, self.maxpool)
232 |
233 | x_seq = self.layer1(x_seq)
234 | x_seq = self.layer2(x_seq)
235 | x_seq = self.layer3(x_seq)
236 | x_seq = self.layer4(x_seq)
237 |
238 | x_seq = functional.seq_to_ann_forward(x_seq, self.avgpool)
239 | x_seq = self.sn1(x_seq)
240 | x_seq = torch.flatten(x_seq, 2)
241 | # x_seq = self.fc(x_seq.mean(0))
242 | x_seq = functional.seq_to_ann_forward(x_seq, self.fc)
243 |
244 | return x_seq
245 |
246 | def forward(self, x):
247 | """
248 |         :param x: the input with `shape=[N, C, H, W]` or `[T, N, C, H, W]`
249 | :type x: torch.Tensor
250 | :return: output
251 | :rtype: torch.Tensor
252 | """
253 | return self._forward_impl(x)
254 |
255 |
256 | def _multi_step_sew_resnet(arch, block, layers, pretrained, progress, T, cnf, multi_step_neuron, **kwargs):
257 | model = MultiStepMSResNet(block, layers, T=T, cnf=cnf, multi_step_neuron=multi_step_neuron, **kwargs)
258 | if pretrained:
259 | state_dict = load_state_dict_from_url(model_urls[arch],
260 | progress=progress)
261 | model.load_state_dict(state_dict)
262 | return model
263 |
264 |
265 | def multi_step_sew_resnet18(pretrained=False, progress=True, T: int = None, cnf: str = 'ADD',
266 | multi_step_neuron: callable = None, **kwargs):
267 | """
268 | :param pretrained: If True, the SNN will load parameters from the ANN pre-trained on ImageNet
269 | :type pretrained: bool
270 | :param progress: If True, displays a progress bar of the download to stderr
271 | :type progress: bool
272 | :param T: total time-steps
273 | :type T: int
274 |     :param cnf: the name of the spike-element-wise function
275 | :type cnf: str
276 | :param multi_step_neuron: a multi-step neuron
277 | :type multi_step_neuron: callable
278 | :param kwargs: kwargs for `multi_step_neuron`
279 | :type kwargs: dict
280 | :return: Spiking ResNet-18
281 | :rtype: torch.nn.Module
282 |
283 |     The multi-step spike-element-wise ResNet-18 from `"Deep Residual Learning in Spiking Neural Networks" <https://arxiv.org/abs/2102.04159>`_,
284 |     modified from the ResNet-18 model in `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
285 | """
286 |
287 | return _multi_step_sew_resnet('resnet18', MultiStepBasicBlock, [2, 2, 2, 2], pretrained, progress, T, cnf,
288 | multi_step_neuron, **kwargs)
289 |
--------------------------------------------------------------------------------
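A minimal usage sketch for the factory function above (not part of the repository): it assumes spikingjelly's clock_driven API (neuron.MultiStepLIFNode, functional.reset_net) and that this module is importable as sad.models.module.MS_ResNet; the neuron choice and hyper-parameters are illustrative.

import torch
from spikingjelly.clock_driven import neuron, functional
from sad.models.module.MS_ResNet import multi_step_sew_resnet18  # NOTE: assumed import path

# Illustrative sketch, not the repository's training setup.
# Build the multi-step spike-element-wise ResNet-18 with T = 4 time-steps and a
# multi-step LIF neuron; the extra kwargs (tau, detach_reset) are forwarded to the neuron.
net = multi_step_sew_resnet18(pretrained=False, T=4, cnf='ADD',
                              multi_step_neuron=neuron.MultiStepLIFNode,
                              tau=2.0, detach_reset=True)

x = torch.rand(2, 3, 224, 224)   # [N, C, H, W]; the stem output is repeated T times internally
out = net(x)                     # [T, N, num_classes]
logits = out.mean(0)             # rate-average over the time dimension
functional.reset_net(net)        # reset membrane potentials before the next batch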
/sad/utils/geometry.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from pyquaternion import Quaternion
7 | from nuscenes.utils.geometry_utils import transform_matrix
8 |
9 | def resize_and_crop_image(img, resize_dims, crop):
10 | # Bilinear resizing followed by cropping
11 | img = img.resize(resize_dims, resample=PIL.Image.BILINEAR)
12 | img = img.crop(crop)
13 | return img
14 |
15 |
16 | def update_intrinsics(intrinsics, top_crop=0.0, left_crop=0.0, scale_width=1.0, scale_height=1.0):
17 | """
18 | Parameters
19 | ----------
20 | intrinsics: torch.Tensor (3, 3)
21 | top_crop: float
22 | left_crop: float
23 | scale_width: float
24 | scale_height: float
25 | """
26 | updated_intrinsics = intrinsics.clone()
27 | # Adjust intrinsics scale due to resizing
28 | updated_intrinsics[0, 0] *= scale_width
29 | updated_intrinsics[0, 2] *= scale_width
30 | updated_intrinsics[1, 1] *= scale_height
31 | updated_intrinsics[1, 2] *= scale_height
32 |
33 | # Adjust principal point due to cropping
34 | updated_intrinsics[0, 2] -= left_crop
35 | updated_intrinsics[1, 2] -= top_crop
36 |
37 | return updated_intrinsics
38 |
39 |
40 | def calculate_birds_eye_view_parameters(x_bounds, y_bounds, z_bounds):
41 | """
42 | Parameters
43 | ----------
44 |     x_bounds: (min, max, step) along the forward direction of the ego-car.
45 |     y_bounds: (min, max, step) along the sides.
46 |     z_bounds: (min, max, step) along the height.
47 |
48 | Returns
49 | -------
50 |     bev_resolution: Bird's-eye view grid resolution (metres per cell)
51 |     bev_start_position: Bird's-eye view centre of the first grid cell
52 |     bev_dimension: Bird's-eye view tensor spatial dimension (number of cells)
53 | """
54 | bev_resolution = torch.tensor([row[2] for row in [x_bounds, y_bounds, z_bounds]])
55 | bev_start_position = torch.tensor([row[0] + row[2] / 2.0 for row in [x_bounds, y_bounds, z_bounds]])
56 | bev_dimension = torch.tensor([(row[1] - row[0]) / row[2] for row in [x_bounds, y_bounds, z_bounds]],
57 | dtype=torch.long)
58 |
59 | return bev_resolution, bev_start_position, bev_dimension
60 |
61 |
62 | def convert_egopose_to_matrix_numpy(egopose):
63 | transformation_matrix = np.zeros((4, 4), dtype=np.float32)
64 | rotation = Quaternion(egopose['rotation']).rotation_matrix
65 | translation = np.array(egopose['translation'])
66 | transformation_matrix[:3, :3] = rotation
67 | transformation_matrix[:3, 3] = translation
68 | transformation_matrix[3, 3] = 1.0
69 | return transformation_matrix
70 |
71 | def get_global_pose(rec, nusc, inverse=False):
72 | lidar_sample_data = nusc.get('sample_data', rec['data']['LIDAR_TOP'])
73 |
74 | sd_ep = nusc.get("ego_pose", lidar_sample_data["ego_pose_token"])
75 | sd_cs = nusc.get("calibrated_sensor", lidar_sample_data["calibrated_sensor_token"])
76 | if inverse is False:
77 | global_from_ego = transform_matrix(sd_ep["translation"], Quaternion(sd_ep["rotation"]), inverse=False)
78 | ego_from_sensor = transform_matrix(sd_cs["translation"], Quaternion(sd_cs["rotation"]), inverse=False)
79 | pose = global_from_ego.dot(ego_from_sensor)
80 | else:
81 | sensor_from_ego = transform_matrix(sd_cs["translation"], Quaternion(sd_cs["rotation"]), inverse=True)
82 | ego_from_global = transform_matrix(sd_ep["translation"], Quaternion(sd_ep["rotation"]), inverse=True)
83 | pose = sensor_from_ego.dot(ego_from_global)
84 | return pose
85 |
86 | def invert_matrix_egopose_numpy(egopose):
87 | """ Compute the inverse transformation of a 4x4 egopose numpy matrix."""
88 | inverse_matrix = np.zeros((4, 4), dtype=np.float32)
89 | rotation = egopose[:3, :3]
90 | translation = egopose[:3, 3]
91 | inverse_matrix[:3, :3] = rotation.T
92 | inverse_matrix[:3, 3] = -np.dot(rotation.T, translation)
93 | inverse_matrix[3, 3] = 1.0
94 | return inverse_matrix
95 |
96 |
97 | def mat2pose_vec(matrix: torch.Tensor):
98 | """
99 |     Converts a 4x4 pose matrix into a 6-DoF pose vector.
100 |     Args:
101 |         matrix (torch.Tensor): 4x4 pose matrix
102 |     Returns:
103 |         vector (torch.Tensor): 6-DoF pose vector comprising translation components (tx, ty, tz) and
104 |         rotation components (rx, ry, rz)
105 | """
106 |
107 | # M[1, 2] = -sinx*cosy, M[2, 2] = +cosx*cosy
108 | rotx = torch.atan2(-matrix[..., 1, 2], matrix[..., 2, 2])
109 |
110 | # M[0, 2] = +siny, M[1, 2] = -sinx*cosy, M[2, 2] = +cosx*cosy
111 | cosy = torch.sqrt(matrix[..., 1, 2] ** 2 + matrix[..., 2, 2] ** 2)
112 | roty = torch.atan2(matrix[..., 0, 2], cosy)
113 |
114 | # M[0, 0] = +cosy*cosz, M[0, 1] = -cosy*sinz
115 | rotz = torch.atan2(-matrix[..., 0, 1], matrix[..., 0, 0])
116 |
117 | rotation = torch.stack((rotx, roty, rotz), dim=-1)
118 |
119 | # Extract translation params
120 | translation = matrix[..., :3, 3]
121 | return torch.cat((translation, rotation), dim=-1)
122 |
123 |
124 | def euler2mat(angle: torch.Tensor):
125 | """Convert euler angles to rotation matrix.
126 | Reference: https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174
127 | Args:
128 | angle: rotation angle along 3 axis (in radians) [Bx3]
129 | Returns:
130 | Rotation matrix corresponding to the euler angles [Bx3x3]
131 | """
132 | shape = angle.shape
133 | angle = angle.view(-1, 3)
134 | x, y, z = angle[:, 0], angle[:, 1], angle[:, 2]
135 |
136 | cosz = torch.cos(z)
137 | sinz = torch.sin(z)
138 |
139 | zeros = torch.zeros_like(z)
140 | ones = torch.ones_like(z)
141 | zmat = torch.stack([cosz, -sinz, zeros, sinz, cosz, zeros, zeros, zeros, ones], dim=1).view(-1, 3, 3)
142 |
143 | cosy = torch.cos(y)
144 | siny = torch.sin(y)
145 |
146 | ymat = torch.stack([cosy, zeros, siny, zeros, ones, zeros, -siny, zeros, cosy], dim=1).view(-1, 3, 3)
147 |
148 | cosx = torch.cos(x)
149 | sinx = torch.sin(x)
150 |
151 | xmat = torch.stack([ones, zeros, zeros, zeros, cosx, -sinx, zeros, sinx, cosx], dim=1).view(-1, 3, 3)
152 |
153 | rot_mat = xmat.bmm(ymat).bmm(zmat)
154 | rot_mat = rot_mat.view(*shape[:-1], 3, 3)
155 | return rot_mat
156 |
157 |
158 | def pose_vec2mat(vec: torch.Tensor):
159 | """
160 | Convert 6DoF parameters to transformation matrix.
161 | Args:
162 | vec: 6DoF parameters in the order of tx, ty, tz, rx, ry, rz [B,6]
163 | Returns:
164 | A transformation matrix [B,4,4]
165 | """
166 | translation = vec[..., :3].unsqueeze(-1) # [...x3x1]
167 | rot = vec[..., 3:].contiguous() # [...x3]
168 | rot_mat = euler2mat(rot) # [...,3,3]
169 | transform_mat = torch.cat([rot_mat, translation], dim=-1) # [...,3,4]
170 | transform_mat = torch.nn.functional.pad(transform_mat, [0, 0, 0, 1], value=0) # [...,4,4]
171 | transform_mat[..., 3, 3] = 1.0
172 | return transform_mat
173 |
174 |
175 | def invert_pose_matrix(x):
176 | """
177 | Parameters
178 | ----------
179 | x: [B, 4, 4] batch of pose matrices
180 |
181 | Returns
182 | -------
183 | out: [B, 4, 4] batch of inverse pose matrices
184 | """
185 | assert len(x.shape) == 3 and x.shape[1:] == (4, 4), 'Only works for batch of pose matrices.'
186 |
187 | transposed_rotation = torch.transpose(x[:, :3, :3], 1, 2)
188 | translation = x[:, :3, 3:]
189 |
190 | inverse_mat = torch.cat([transposed_rotation, -torch.bmm(transposed_rotation, translation)], dim=-1) # [B,3,4]
191 | inverse_mat = torch.nn.functional.pad(inverse_mat, [0, 0, 0, 1], value=0) # [B,4,4]
192 | inverse_mat[..., 3, 3] = 1.0
193 | return inverse_mat
194 |
195 |
196 | def warp_features(x, flow, mode='nearest', spatial_extent=None):
197 | """ Applies a rotation and translation to feature map x.
198 | Args:
199 | x: (b, c, h, w) feature map
200 |         flow: (b, 6) 6DoF pose vector (only the x-y translation and z-rotation are used)
201 | mode: use 'nearest' when dealing with categorical inputs
202 | Returns:
203 |         In-plane transformed feature map
204 | """
205 | if flow is None:
206 | return x
207 | b, c, h, w = x.shape
208 | # z-rotation
209 | angle = flow[:, 5].clone() # torch.atan2(flow[:, 1, 0], flow[:, 0, 0])
210 | # x-y translation
211 | translation = flow[:, :2].clone() # flow[:, :2, 3]
212 |
213 |     # Normalise the translation: divide by the number of metres covered by half of the image,
214 |     # because a translation of 1.0 in grid coordinates corresponds to half of the image.
215 | translation[:, 0] /= spatial_extent[0]
216 | translation[:, 1] /= spatial_extent[1]
217 | # forward axis is inverted
218 | translation[:, 0] *= -1
219 |
220 | cos_theta = torch.cos(angle)
221 | sin_theta = torch.sin(angle)
222 |
223 | # output = Rot.input + translation
224 |     # tx and ty are inverted, as is the case when going from real coordinates to numpy coordinates
225 | # translation_pos_0 -> positive value makes the image move to the left
226 | # translation_pos_1 -> positive value makes the image move to the top
227 | # Angle -> positive value in rad makes the image move in the trigonometric way
228 | transformation = torch.stack([cos_theta, -sin_theta, translation[:, 1],
229 | sin_theta, cos_theta, translation[:, 0]], dim=-1).view(b, 2, 3)
230 |
231 |     # Note that a rotation preserves distances only if height = width; otherwise the warp also
232 |     # rescales, e.g. a rotation of pi/2 applied to a 100x200 image elongates whatever is in the
233 |     # centre of the image.
234 | grid = torch.nn.functional.affine_grid(transformation, size=x.shape, align_corners=False)
235 | grid = grid.to(dtype=x.dtype)
236 | warped_x = torch.nn.functional.grid_sample(x, grid, mode=mode, padding_mode='zeros', align_corners=False)
237 |
238 | return warped_x
239 |
240 |
241 | def cumulative_warp_features(x, flow, mode='nearest', spatial_extent=None):
242 | """ Warps a sequence of feature maps by accumulating incremental 2d flow.
243 |
244 | x[:, -1] remains unchanged
245 | x[:, -2] is warped using flow[:, -2]
246 | x[:, -3] is warped using flow[:, -3] @ flow[:, -2]
247 | ...
248 | x[:, 0] is warped using flow[:, 0] @ ... @ flow[:, -3] @ flow[:, -2]
249 |
250 | Args:
251 | x: (b, t, c, h, w) sequence of feature maps
252 | flow: (b, t, 6) sequence of 6 DoF pose
253 |             from t to t+1 (only the x-y translation and z-rotation are used)
254 |
255 | """
256 | sequence_length = x.shape[1]
257 | if sequence_length == 1:
258 | return x
259 |
260 | flow = pose_vec2mat(flow)
261 |
262 | out = [x[:, -1]]
263 | cum_flow = flow[:, -2]
264 | for t in reversed(range(sequence_length - 1)):
265 | out.append(warp_features(x[:, t], mat2pose_vec(cum_flow), mode=mode, spatial_extent=spatial_extent))
266 | # @ is the equivalent of torch.bmm
267 | cum_flow = flow[:, t - 1] @ cum_flow
268 |
269 | return torch.stack(out[::-1], 1)
270 |
271 |
272 | def cumulative_warp_features_reverse(x, flow, mode='nearest', spatial_extent=None):
273 | """ Warps a sequence of feature maps by accumulating incremental 2d flow.
274 |
275 | x[:, 0] remains unchanged
276 | x[:, 1] is warped using flow[:, 0].inverse()
277 | x[:, 2] is warped using flow[:, 0].inverse() @ flow[:, 1].inverse()
278 | ...
279 |
280 | Args:
281 | x: (b, t, c, h, w) sequence of feature maps
282 | flow: (b, t, 6) sequence of 6 DoF pose
283 |             from t to t+1 (only the x-y translation and z-rotation are used)
284 |
285 | """
286 | flow = pose_vec2mat(flow)
287 |
288 | out = [x[:,0]]
289 |
290 | for i in range(1, x.shape[1]):
291 | if i==1:
292 | cum_flow = invert_pose_matrix(flow[:, 0])
293 | else:
294 | cum_flow = cum_flow @ invert_pose_matrix(flow[:,i-1])
295 | out.append( warp_features(x[:,i], mat2pose_vec(cum_flow), mode, spatial_extent=spatial_extent))
296 | return torch.stack(out, 1)
297 |
298 |
299 | class VoxelsSumming(torch.autograd.Function):
300 | """Adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/tools.py#L193"""
301 | @staticmethod
302 | def forward(ctx, x, geometry, ranks):
303 | """The features `x` and `geometry` are ranked by voxel positions."""
304 | # Cumulative sum of all features.
305 | x = x.cumsum(0)
306 |
307 | # Indicates the change of voxel.
308 | mask = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
309 | mask[:-1] = ranks[1:] != ranks[:-1]
310 |
311 | x, geometry = x[mask], geometry[mask]
312 | # Calculate sum of features within a voxel.
313 | x = torch.cat((x[:1], x[1:] - x[:-1]))
314 |
315 | ctx.save_for_backward(mask)
316 | ctx.mark_non_differentiable(geometry)
317 |
318 | return x, geometry
319 |
320 | @staticmethod
321 | def backward(ctx, grad_x, grad_geometry):
322 | (mask,) = ctx.saved_tensors
323 | # Since the operation is summing, we simply need to send gradient
324 | # to all elements that were part of the summation process.
325 | indices = torch.cumsum(mask, 0)
326 | indices[mask] -= 1
327 |
328 | output_grad = grad_x[indices]
329 |
330 | return output_grad, None, None
331 |
--------------------------------------------------------------------------------
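A small, self-contained sanity check for the pose utilities above (illustrative only; the tensors and import path are assumptions, not part of the repository): a 6-DoF vector converted with pose_vec2mat and back with mat2pose_vec should be unchanged, and composing a pose with its invert_pose_matrix inverse should give the identity.

import torch
from sad.utils.geometry import pose_vec2mat, mat2pose_vec, invert_pose_matrix, warp_features  # assumed path

# Illustrative sketch: round-trip and inverse checks on a made-up pose.
vec = torch.tensor([[1.0, 0.5, 0.0, 0.0, 0.0, 0.1]])   # tx, ty, tz, rx, ry, rz
mat = pose_vec2mat(vec)                                 # [1, 4, 4]
assert torch.allclose(mat2pose_vec(mat), vec, atol=1e-6)
assert torch.allclose(invert_pose_matrix(mat) @ mat, torch.eye(4).unsqueeze(0), atol=1e-6)

# warp_features applies only the x-y translation and z-rotation of the pose to a BEV feature map.
bev = torch.rand(1, 8, 200, 200)                        # (b, c, h, w) feature map
warped = warp_features(bev, vec, mode='bilinear', spatial_extent=(50.0, 50.0))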
/sad/utils/instance.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from scipy.optimize import linear_sum_assignment
7 |
8 | from sad.utils.geometry import mat2pose_vec, pose_vec2mat, warp_features
9 |
10 |
11 | # set ignore index to 0 for vis
12 | def convert_instance_mask_to_center_and_offset_label(instance_img, future_egomotion, num_instances, ignore_index=255,
13 | subtract_egomotion=True, sigma=3, spatial_extent=None):
14 | seq_len, h, w = instance_img.shape
15 | center_label = torch.zeros(seq_len, 1, h, w)
16 | offset_label = ignore_index * torch.ones(seq_len, 2, h, w)
17 | future_displacement_label = ignore_index * torch.ones(seq_len, 2, h, w)
18 | # x is vertical displacement, y is horizontal displacement
19 | x, y = torch.meshgrid(torch.arange(h, dtype=torch.float), torch.arange(w, dtype=torch.float))
20 |
21 | if subtract_egomotion:
22 | future_egomotion_inv = mat2pose_vec(pose_vec2mat(future_egomotion).inverse())
23 |
24 | # Compute warped instance segmentation
25 | warped_instance_seg = {}
26 | for t in range(1, seq_len):
27 | warped_inst_t = warp_features(instance_img[t].unsqueeze(0).unsqueeze(1).float(),
28 | future_egomotion_inv[t - 1].unsqueeze(0), mode='nearest',
29 | spatial_extent=spatial_extent)
30 | warped_instance_seg[t] = warped_inst_t[0, 0]
31 |
32 | # Ignore id 0 which is the background
33 | for instance_id in range(1, num_instances+1):
34 | prev_xc = None
35 | prev_yc = None
36 | prev_mask = None
37 | for t in range(seq_len):
38 | instance_mask = (instance_img[t] == instance_id)
39 | if instance_mask.sum() == 0:
40 | # this instance is not in this frame
41 | prev_xc = None
42 | prev_yc = None
43 | prev_mask = None
44 | continue
45 |
46 | xc = x[instance_mask].mean().round().long()
47 | yc = y[instance_mask].mean().round().long()
48 |
49 | off_x = xc - x
50 | off_y = yc - y
51 | g = torch.exp(-(off_x ** 2 + off_y ** 2) / sigma ** 2)
52 | center_label[t, 0] = torch.maximum(center_label[t, 0], g)
53 | offset_label[t, 0, instance_mask] = off_x[instance_mask]
54 | offset_label[t, 1, instance_mask] = off_y[instance_mask]
55 |
56 | if prev_xc is not None:
57 | # old method
58 | # cur_pt = torch.stack((xc, yc)).unsqueeze(0).float()
59 | # if subtract_egomotion:
60 | # cur_pt = warp_points(cur_pt, future_egomotion_inv[t - 1])
61 | # cur_pt = cur_pt.squeeze(0)
62 |
63 | warped_instance_mask = warped_instance_seg[t] == instance_id
64 | if warped_instance_mask.sum() > 0:
65 | warped_xc = x[warped_instance_mask].mean().round()
66 | warped_yc = y[warped_instance_mask].mean().round()
67 |
68 | delta_x = warped_xc - prev_xc
69 | delta_y = warped_yc - prev_yc
70 | future_displacement_label[t - 1, 0, prev_mask] = delta_x
71 | future_displacement_label[t - 1, 1, prev_mask] = delta_y
72 |
73 | prev_xc = xc
74 | prev_yc = yc
75 | prev_mask = instance_mask
76 |
77 | return center_label, offset_label, future_displacement_label
78 |
79 |
80 | def find_instance_centers(center_prediction: torch.Tensor, conf_threshold: float = 0.1, nms_kernel_size: int = 3):
81 | assert len(center_prediction.shape) == 3
82 | center_prediction = F.threshold(center_prediction, threshold=conf_threshold, value=-1)
83 |
84 | nms_padding = (nms_kernel_size - 1) // 2
85 | maxpooled_center_prediction = F.max_pool2d(
86 | center_prediction, kernel_size=nms_kernel_size, stride=1, padding=nms_padding
87 | )
88 |
89 | # Filter all elements that are not the maximum (i.e. the center of the heatmap instance)
90 | center_prediction[center_prediction != maxpooled_center_prediction] = -1
91 | return torch.nonzero(center_prediction > 0)[:, 1:]
92 |
93 |
94 | def group_pixels(centers: torch.Tensor, offset_predictions: torch.Tensor) -> torch.Tensor:
95 | width, height = offset_predictions.shape[-2:]
96 | x_grid = (
97 | torch.arange(width, dtype=offset_predictions.dtype, device=offset_predictions.device)
98 | .view(1, width, 1)
99 | .repeat(1, 1, height)
100 | )
101 | y_grid = (
102 | torch.arange(height, dtype=offset_predictions.dtype, device=offset_predictions.device)
103 | .view(1, 1, height)
104 | .repeat(1, width, 1)
105 | )
106 | pixel_grid = torch.cat((x_grid, y_grid), dim=0)
107 | center_locations = (pixel_grid + offset_predictions).view(2, width * height, 1).permute(2, 1, 0)
108 | centers = centers.view(-1, 1, 2)
109 |
110 | distances = torch.norm(centers - center_locations, dim=-1)
111 |
112 | instance_id = torch.argmin(distances, dim=0).reshape(1, width, height) + 1
113 | return instance_id
114 |
115 |
116 | def get_instance_segmentation_and_centers(
117 | center_predictions: torch.Tensor,
118 | offset_predictions: torch.Tensor,
119 | foreground_mask: torch.Tensor,
120 | conf_threshold: float = 0.1,
121 |     nms_kernel_size: int = 3,
122 | max_n_instance_centers: int = 100,
123 | ) -> Tuple[torch.Tensor, torch.Tensor]:
124 | width, height = center_predictions.shape[-2:]
125 | center_predictions = center_predictions.view(1, width, height)
126 | offset_predictions = offset_predictions.view(2, width, height)
127 | foreground_mask = foreground_mask.view(1, width, height)
128 |
129 | centers = find_instance_centers(center_predictions, conf_threshold=conf_threshold, nms_kernel_size=nms_kernel_size)
130 | if not len(centers):
131 | return torch.zeros(center_predictions.shape, dtype=torch.int64, device=center_predictions.device), \
132 | torch.zeros((0, 2), device=centers.device)
133 |
134 | if len(centers) > max_n_instance_centers:
135 | # print(f'There are a lot of detected instance centers: {centers.shape}')
136 | centers = centers[:max_n_instance_centers].clone()
137 |
138 | instance_ids = group_pixels(centers, offset_predictions)
139 | instance_seg = (instance_ids * foreground_mask.float()).long()
140 |
141 | # Make the indices of instance_seg consecutive
142 | instance_seg = make_instance_seg_consecutive(instance_seg)
143 |
144 | return instance_seg.long(), centers
145 |
146 |
147 | def update_instance_ids(instance_seg, old_ids, new_ids):
148 | """
149 | Parameters
150 | ----------
151 | instance_seg: torch.Tensor arbitrary shape
152 |     old_ids: 1D tensor containing the list of old ids; all of them must be present in instance_seg.
153 | new_ids: 1D tensor with the new ids, aligned with old_ids
154 |
155 | Returns
156 | new_instance_seg: torch.Tensor same shape as instance_seg with new ids
157 | """
158 | indices = torch.arange(old_ids.max() + 1, device=instance_seg.device)
159 | for old_id, new_id in zip(old_ids, new_ids):
160 | indices[old_id] = new_id
161 |
162 | return indices[instance_seg].long()
163 |
164 |
165 | def make_instance_seg_consecutive(instance_seg):
166 | # Make the indices of instance_seg consecutive
167 | unique_ids = torch.unique(instance_seg)
168 | new_ids = torch.arange(len(unique_ids), device=instance_seg.device)
169 | instance_seg = update_instance_ids(instance_seg, unique_ids, new_ids)
170 | return instance_seg
171 |
172 |
173 | def make_instance_id_temporally_consistent(pred_inst, future_flow, matching_threshold=3.0):
174 | """
175 | Parameters
176 | ----------
177 | pred_inst: torch.Tensor (1, seq_len, h, w)
178 | future_flow: torch.Tensor(1, seq_len, 2, h, w)
179 | matching_threshold: distance threshold for a match to be valid.
180 |
181 | Returns
182 | -------
183 | consistent_instance_seg: torch.Tensor(1, seq_len, h, w)
184 |
185 |     1. At time t, loop over all detected instances and use the flow to compute their new centers at time t+1.
186 |     2. Store those warped centers.
187 |     3. At time t+1, re-identify instances by comparing the positions of the actual centers with the flow-warped centers.
188 |        Make the labels at t+1 consistent with the matching.
189 |     4. Repeat.
190 | """
191 | assert pred_inst.shape[0] == 1, 'Assumes batch size = 1'
192 |
193 | # Initialise instance segmentations with prediction corresponding to the present
194 | consistent_instance_seg = [pred_inst[0, 0]]
195 | largest_instance_id = consistent_instance_seg[0].max().item()
196 |
197 | _, seq_len, h, w = pred_inst.shape
198 | device = pred_inst.device
199 | for t in range(seq_len - 1):
200 | # Compute predicted future instance means
201 | grid = torch.stack(torch.meshgrid(
202 | torch.arange(h, dtype=torch.float, device=device), torch.arange(w, dtype=torch.float, device=device)
203 | ))
204 |
205 | # Add future flow
206 | grid = grid + future_flow[0, t]
207 | warped_centers = []
208 | # Go through all ids, except the background
209 | t_instance_ids = torch.unique(consistent_instance_seg[-1])[1:].cpu().numpy()
210 |
211 | if len(t_instance_ids) == 0:
212 | # No instance so nothing to update
213 | consistent_instance_seg.append(pred_inst[0, t + 1])
214 | continue
215 |
216 | for instance_id in t_instance_ids:
217 | instance_mask = (consistent_instance_seg[-1] == instance_id)
218 | warped_centers.append(grid[:, instance_mask].mean(dim=1))
219 | warped_centers = torch.stack(warped_centers)
220 |
221 | # Compute actual future instance means
222 | centers = []
223 | grid = torch.stack(torch.meshgrid(
224 | torch.arange(h, dtype=torch.float, device=device), torch.arange(w, dtype=torch.float, device=device)
225 | ))
226 | n_instances = int(pred_inst[0, t + 1].max().item())
227 |
228 | if n_instances == 0:
229 | # No instance, so nothing to update.
230 | consistent_instance_seg.append(pred_inst[0, t + 1])
231 | continue
232 |
233 | for instance_id in range(1, n_instances + 1):
234 | instance_mask = (pred_inst[0, t + 1] == instance_id)
235 | centers.append(grid[:, instance_mask].mean(dim=1))
236 | centers = torch.stack(centers)
237 |
238 | # Compute distance matrix between warped centers and actual centers
239 | distances = torch.norm(centers.unsqueeze(0) - warped_centers.unsqueeze(1), dim=-1).cpu().numpy()
240 | # outputs (row, col) with row: index in frame t, col: index in frame t+1
241 | # the missing ids in col must be added (correspond to new instances)
242 | ids_t, ids_t_one = linear_sum_assignment(distances)
243 | matching_distances = distances[ids_t, ids_t_one]
244 | # Offset by one as id=0 is the background
245 | ids_t += 1
246 | ids_t_one += 1
247 |
248 |         # Swap ids_t with the real ids, as those ids correspond to positions in the distance matrix.
249 | id_mapping = dict(zip(np.arange(1, len(t_instance_ids) + 1), t_instance_ids))
250 | ids_t = np.vectorize(id_mapping.__getitem__, otypes=[np.int64])(ids_t)
251 |
252 | # Filter low quality match
253 | ids_t = ids_t[matching_distances < matching_threshold]
254 | ids_t_one = ids_t_one[matching_distances < matching_threshold]
255 |
256 | # Elements that are in t+1, but weren't matched
257 | remaining_ids = set(torch.unique(pred_inst[0, t + 1]).cpu().numpy()).difference(set(ids_t_one))
258 | # remove background
259 | remaining_ids.remove(0)
260 | # Set remaining_ids to a new unique id
261 | for remaining_id in list(remaining_ids):
262 | largest_instance_id += 1
263 | ids_t = np.append(ids_t, largest_instance_id)
264 | ids_t_one = np.append(ids_t_one, remaining_id)
265 |
266 | consistent_instance_seg.append(update_instance_ids(pred_inst[0, t + 1], old_ids=ids_t_one, new_ids=ids_t))
267 |
268 | consistent_instance_seg = torch.stack(consistent_instance_seg).unsqueeze(0)
269 | return consistent_instance_seg
270 |
271 |
272 | def predict_instance_segmentation_and_trajectories(
273 | output, compute_matched_centers=False, make_consistent=True, vehicles_id=1,
274 | ):
275 | preds = output['segmentation'].detach()
276 | preds = torch.argmax(preds, dim=2, keepdim=True)
277 | foreground_masks = preds.squeeze(2) == vehicles_id
278 |
279 | batch_size, seq_len = preds.shape[:2]
280 | pred_inst = []
281 | for b in range(batch_size):
282 | pred_inst_batch = []
283 | for t in range(seq_len):
284 | pred_instance_t, _ = get_instance_segmentation_and_centers(
285 | output['instance_center'][b, t].detach(),
286 | output['instance_offset'][b, t].detach(),
287 | foreground_masks[b, t].detach()
288 | )
289 | pred_inst_batch.append(pred_instance_t)
290 | pred_inst.append(torch.stack(pred_inst_batch, dim=0))
291 |
292 | pred_inst = torch.stack(pred_inst).squeeze(2)
293 |
294 | if make_consistent:
295 | if output['instance_flow'] is None:
296 |             print('Using zero flow because instance_flow is None')
297 | output['instance_flow'] = torch.zeros_like(output['instance_offset'])
298 | consistent_instance_seg = []
299 | for b in range(batch_size):
300 | consistent_instance_seg.append(
301 | make_instance_id_temporally_consistent(pred_inst[b:b+1],
302 | output['instance_flow'][b:b+1].detach())
303 | )
304 | consistent_instance_seg = torch.cat(consistent_instance_seg, dim=0)
305 | else:
306 | consistent_instance_seg = pred_inst
307 |
308 | if compute_matched_centers:
309 | assert batch_size == 1
310 | # Generate trajectories
311 | matched_centers = {}
312 | _, seq_len, h, w = consistent_instance_seg.shape
313 | grid = torch.stack(torch.meshgrid(
314 | torch.arange(h, dtype=torch.float, device=preds.device),
315 | torch.arange(w, dtype=torch.float, device=preds.device)
316 | ))
317 |
318 | for instance_id in torch.unique(consistent_instance_seg[0, 0])[1:].cpu().numpy():
319 | for t in range(seq_len):
320 | instance_mask = consistent_instance_seg[0, t] == instance_id
321 | if instance_mask.sum() > 0:
322 | matched_centers[instance_id] = matched_centers.get(instance_id, []) + [
323 | grid[:, instance_mask].mean(dim=-1)]
324 |
325 | for key, value in matched_centers.items():
326 | matched_centers[key] = torch.stack(value).cpu().numpy()[:, ::-1]
327 |
328 | return consistent_instance_seg, matched_centers
329 |
330 | return consistent_instance_seg
331 |
332 |
333 |
--------------------------------------------------------------------------------
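A minimal end-to-end sketch of the instance decoding path above, run on random tensors (the shapes, channel layout, and import path are assumptions for illustration, not the repository's configuration).

import torch
from sad.utils.instance import predict_instance_segmentation_and_trajectories  # assumed path

# Illustrative sketch with random inputs; a real model would produce this dictionary.
b, t, h, w = 1, 3, 100, 100
output = {
    'segmentation': torch.randn(b, t, 2, h, w),     # per-pixel background / vehicle logits
    'instance_center': torch.rand(b, t, 1, h, w),   # instance centerness heatmap
    'instance_offset': torch.randn(b, t, 2, h, w),  # per-pixel offset to the instance center
    'instance_flow': torch.randn(b, t, 2, h, w),    # future flow used for temporal matching
}

consistent_seg, matched_centers = predict_instance_segmentation_and_trajectories(
    output, compute_matched_centers=True)
# consistent_seg: (b, t, h, w) instance ids that are temporally consistent across the sequence
# matched_centers: {instance_id: (num_frames, 2) numpy array of instance-center trajectories}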