├── stretchbev
│   ├── .DS_Store
│   ├── utils
│   │   ├── .DS_Store
│   │   ├── __pycache__
│   │   │   ├── network.cpython-37.pyc
│   │   │   ├── geometry.cpython-37.pyc
│   │   │   ├── instance.cpython-37.pyc
│   │   │   ├── lyft_splits.cpython-37.pyc
│   │   │   └── visualisation.cpython-37.pyc
│   │   ├── lyft_splits.py
│   │   ├── network.py
│   │   ├── .ipynb_checkpoints
│   │   │   ├── network-checkpoint.py
│   │   │   └── visualisation-checkpoint.py
│   │   ├── geometry.py
│   │   └── visualisation.py
│   ├── layers
│   │   ├── .DS_Store
│   │   ├── __pycache__
│   │   │   ├── temporal.cpython-37.pyc
│   │   │   └── convolutions.cpython-37.pyc
│   │   ├── convolutions.py
│   │   ├── .ipynb_checkpoints
│   │   │   └── temporal-checkpoint.py
│   │   └── temporal.py
│   ├── models
│   │   ├── .DS_Store
│   │   ├── __pycache__
│   │   │   ├── decoder.cpython-37.pyc
│   │   │   ├── encoder.cpython-37.pyc
│   │   │   ├── model_utils.cpython-37.pyc
│   │   │   ├── res_models.cpython-37.pyc
│   │   │   ├── srvp_models.cpython-37.pyc
│   │   │   ├── stretchbev.cpython-37.pyc
│   │   │   ├── distributions.cpython-37.pyc
│   │   │   ├── temporal_model.cpython-37.pyc
│   │   │   └── future_prediction.cpython-37.pyc
│   │   ├── future_prediction.py
│   │   ├── .ipynb_checkpoints
│   │   │   ├── future_prediction-checkpoint.py
│   │   │   ├── temporal_model-checkpoint.py
│   │   │   ├── decoder-checkpoint.py
│   │   │   ├── model_utils-checkpoint.py
│   │   │   ├── res_models-checkpoint.py
│   │   │   └── srvp_models-checkpoint.py
│   │   ├── distributions.py
│   │   ├── temporal_model.py
│   │   ├── decoder.py
│   │   ├── encoder.py
│   │   ├── model_utils.py
│   │   ├── res_models.py
│   │   └── srvp_models.py
│   ├── __pycache__
│   │   ├── config.cpython-37.pyc
│   │   ├── data.cpython-37.pyc
│   │   ├── losses.cpython-37.pyc
│   │   ├── metrics.cpython-37.pyc
│   │   └── trainer.cpython-37.pyc
│   ├── configs
│   │   └── stretchbev.yml
│   ├── losses.py
│   ├── .ipynb_checkpoints
│   │   ├── losses-checkpoint.py
│   │   ├── config-checkpoint.py
│   │   ├── metrics-checkpoint.py
│   │   └── trainer-checkpoint.py
│   ├── config.py
│   └── metrics.py
├── environment.yml
├── LICENSE
├── train.py
├── README.md
├── evaluate.py
└── visualise.py
/stretchbev/configs/stretchbev.yml:
--------------------------------------------------------------------------------
1 | TAG: 'baseline'
2 |
3 | GPUS: [0, 1, 2, 3]
4 |
5 | BATCHSIZE: 2
6 | PRECISION: 16
7 |
8 | TIME_RECEPTIVE_FIELD: 3
9 | N_FUTURE_FRAMES: 4
10 |
11 | EPOCHS: 20
12 |
13 | MODEL:
14 | BN_MOMENTUM: 0.05
15 |
16 |
17 | PRETRAINED:
18 | LOAD_WEIGHTS: True
19 | PATH: './static_lift_splat_setting.ckpt'
20 |
21 |
22 | INSTANCE_FLOW:
23 | ENABLED: True
24 |
25 | OPTIMIZER:
26 | LR: 1e-4
27 |
28 | N_WORKERS: 4
29 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: fiery
2 | channels:
3 | - defaults
4 | - conda-forge
5 | - pytorch
6 | dependencies:
7 | - python=3.7.10
8 | - pytorch=1.7.0
9 | - torchvision=0.8.1
10 | - cudatoolkit=10.1
11 | - numpy=1.19.2
12 | - scipy=1.5.2
13 | - pillow=8.0.1
14 | - tqdm=4.50.2
15 | - pytorch-lightning=1.2.5
16 | - efficientnet-pytorch=0.7.0
17 | - fvcore=0.1.2.post20201122
18 | - pip=21.0.1
19 | - pip:
20 | - nuscenes-devkit==1.1.0
21 | - lyft-dataset-sdk==0.0.8
22 | - opencv-python==4.5.1.48
23 | - moviepy==1.0.3
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Wayve Technologies Limited
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/stretchbev/utils/lyft_splits.py:
--------------------------------------------------------------------------------
1 | TRAIN_LYFT_INDICES = [1, 3, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16,
2 | 17, 18, 19, 20, 21, 23, 24, 27, 28, 29, 30, 31, 32,
3 | 33, 35, 36, 37, 39, 41, 43, 44, 45, 46, 47, 48, 49,
4 | 50, 51, 52, 53, 55, 56, 59, 60, 62, 63, 65, 68, 69,
5 | 70, 71, 72, 73, 74, 75, 76, 78, 79, 81, 82, 83, 84,
6 | 86, 87, 88, 89, 93, 95, 97, 98, 99, 103, 104, 107, 108,
7 | 109, 110, 111, 113, 114, 115, 116, 117, 118, 119, 121, 122, 124,
8 | 127, 128, 130, 131, 132, 134, 135, 136, 137, 138, 139, 143, 144,
9 | 146, 147, 148, 149, 150, 151, 152, 153, 154, 156, 157, 158, 159,
10 | 161, 162, 165, 166, 167, 171, 172, 173, 174, 175, 176, 177, 178,
11 | 179]
12 |
13 | VAL_LYFT_INDICES = [0, 2, 4, 13, 22, 25, 26, 34, 38, 40, 42, 54, 57,
14 | 58, 61, 64, 66, 67, 77, 80, 85, 90, 91, 92, 94, 96,
15 | 100, 101, 102, 105, 106, 112, 120, 123, 125, 126, 129, 133, 140,
16 | 141, 142, 145, 155, 160, 163, 164, 168, 169, 170]
17 |
--------------------------------------------------------------------------------
/stretchbev/utils/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 |
5 |
6 | def pack_sequence_dim(x):
7 | b, s = x.shape[:2]
8 | return x.view(b * s, *x.shape[2:])
9 |
10 |
11 | def unpack_sequence_dim(x, b, s):
12 | return x.view(b, s, *x.shape[1:])
13 |
14 |
15 | def preprocess_batch(batch, device, unsqueeze=False):
16 | for key, value in batch.items():
17 | if key != 'sample_token':
18 | batch[key] = value.to(device)
19 | if unsqueeze:
20 | batch[key] = batch[key].unsqueeze(0)
21 |
22 |
23 | def set_module_grad(module, requires_grad=False):
24 | for p in module.parameters():
25 | p.requires_grad = requires_grad
26 |
27 |
28 | def set_bn_momentum(model, momentum=0.1):
29 | for m in model.modules():
30 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
31 | m.momentum = momentum
32 |
33 |
34 | class NormalizeInverse(torchvision.transforms.Normalize):
35 | # https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/8
36 | def __init__(self, mean, std):
37 | mean = torch.as_tensor(mean)
38 | std = torch.as_tensor(std)
39 | std_inv = 1 / (std + 1e-7)
40 | mean_inv = -mean * std_inv
41 | super().__init__(mean=mean_inv, std=std_inv)
42 |
43 | def __call__(self, tensor):
44 | return super().__call__(tensor.clone())
45 |
--------------------------------------------------------------------------------
/stretchbev/models/future_prediction.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from stretchbev.layers.convolutions import Bottleneck
4 | from stretchbev.layers.temporal import SpatialGRU
5 |
6 |
7 | class FuturePrediction(torch.nn.Module):
8 | def __init__(self, in_channels, latent_dim, n_gru_blocks=3, n_res_layers=3):
9 | super().__init__()
10 | self.n_gru_blocks = n_gru_blocks
11 |
12 | # Convolutional recurrent model with z_t as an initial hidden state and inputs the sample
13 | # from the probabilistic model. The architecture of the model is:
14 | # [Spatial GRU - [Bottleneck] x n_res_layers] x n_gru_blocks
15 | self.spatial_grus = []
16 | self.res_blocks = []
17 |
18 | for i in range(self.n_gru_blocks):
19 | gru_in_channels = latent_dim if i == 0 else in_channels
20 | self.spatial_grus.append(SpatialGRU(gru_in_channels, in_channels))
21 | self.res_blocks.append(torch.nn.Sequential(*[Bottleneck(in_channels)
22 | for _ in range(n_res_layers)]))
23 |
24 | self.spatial_grus = torch.nn.ModuleList(self.spatial_grus)
25 | self.res_blocks = torch.nn.ModuleList(self.res_blocks)
26 |
27 | def forward(self, x, hidden_state=None):
28 | # x has shape (b, n_future, c, h, w), hidden_state (b, c, h, w)
29 | for i in range(self.n_gru_blocks):
30 | x = self.spatial_grus[i](x, hidden_state, flow=None)
31 | b, n_future, c, h, w = x.shape
32 |
33 | x = self.res_blocks[i](x.view(b * n_future, c, h, w))
34 | x = x.view(b, n_future, c, h, w)
35 |
36 | return x
37 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import socket
3 | import time
4 |
5 | import pytorch_lightning as pl
6 | import torch
7 | from pytorch_lightning.callbacks import ModelCheckpoint
8 | from pytorch_lightning.plugins import DDPPlugin
9 |
10 | from stretchbev.config import get_parser, get_cfg
11 | from stretchbev.data import prepare_dataloaders
12 | from stretchbev.trainer import TrainingModule
13 |
14 |
15 | def main():
16 | args = get_parser().parse_args()
17 | cfg = get_cfg(args)
18 |
19 | trainloader, valloader = prepare_dataloaders(cfg)
20 | model = TrainingModule(cfg.convert_to_dict())
21 | model.len_loader = len(trainloader)
22 |
23 | if cfg.PRETRAINED.LOAD_WEIGHTS:
24 | # Load single-image instance segmentation model.
25 | pretrained_model_weights = torch.load(
26 | os.path.join('.', cfg.PRETRAINED.PATH), map_location='cpu'
27 | )['state_dict']
28 |
29 | new_dict = {key: val for (key, val) in pretrained_model_weights.items() if 'decoder' not in key}
30 | model.load_state_dict(new_dict, strict=False)
31 | print(f'Loaded single-image model weights from {cfg.PRETRAINED.PATH}')
32 |
33 | save_dir = os.path.join(
34 | cfg.LOG_DIR, time.strftime('%d%B%Yat%H:%M:%S%Z') + '_' + socket.gethostname() + '_' + cfg.TAG
35 | )
36 | checkpoint_callback = ModelCheckpoint(dirpath='weights', filename='stretchbev-{epoch:02d}', save_top_k=-1)
37 | tb_logger = pl.loggers.TensorBoardLogger(save_dir=save_dir)
38 | trainer = pl.Trainer(
39 | gpus=cfg.GPUS,
40 | accelerator='ddp',
41 | precision=cfg.PRECISION,
42 | sync_batchnorm=True,
43 | gradient_clip_val=cfg.GRAD_NORM_CLIP,
44 | max_epochs=cfg.EPOCHS,
45 | weights_summary='full',
46 | logger=tb_logger,
47 | log_every_n_steps=cfg.LOGGING_INTERVAL,
48 | plugins=DDPPlugin(find_unused_parameters=True),
49 | profiler='simple',
50 | callbacks=[checkpoint_callback]
51 | )
52 | trainer.fit(model, trainloader, valloader)
53 |
54 |
55 | if __name__ == "__main__":
56 | main()
57 |
--------------------------------------------------------------------------------
/stretchbev/models/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from stretchbev.layers.convolutions import Bottleneck
5 |
6 |
7 | class DistributionModule(nn.Module):
8 | """
9 | A convolutional net that parametrises a diagonal Gaussian distribution.
10 | """
11 |
12 | def __init__(self, in_channels, latent_dim, min_log_sigma, max_log_sigma):
13 | super().__init__()
14 | self.compress_dim = in_channels // 2
15 | self.latent_dim = latent_dim
16 | self.min_log_sigma = min_log_sigma
17 | self.max_log_sigma = max_log_sigma
18 |
19 | self.encoder = DistributionEncoder(
20 | in_channels,
21 | self.compress_dim,
22 | )
23 | self.last_conv = nn.Sequential(
24 | nn.AdaptiveAvgPool2d(1), nn.Conv2d(self.compress_dim, out_channels=2 * self.latent_dim, kernel_size=1)
25 | )
26 |
27 | def forward(self, s_t):
28 | b, s = s_t.shape[:2]
29 | assert s == 1
30 | encoding = self.encoder(s_t[:, 0])
31 |
32 | mu_log_sigma = self.last_conv(encoding).view(b, 1, 2 * self.latent_dim)
33 | mu = mu_log_sigma[:, :, :self.latent_dim]
34 | log_sigma = mu_log_sigma[:, :, self.latent_dim:]
35 |
36 | # clip the log_sigma value for numerical stability
37 | log_sigma = torch.clamp(log_sigma, self.min_log_sigma, self.max_log_sigma)
38 | return mu, log_sigma
39 |
40 |
41 | class DistributionEncoder(nn.Module):
42 | """Encodes s_t or (s_t, y_{t+1}, ..., y_{t+H}).
43 | """
44 |
45 | def __init__(self, in_channels, out_channels):
46 | super().__init__()
47 |
48 | self.model = nn.Sequential(
49 | Bottleneck(in_channels, out_channels=out_channels, downsample=True),
50 | Bottleneck(out_channels, out_channels=out_channels, downsample=True),
51 | Bottleneck(out_channels, out_channels=out_channels, downsample=True),
52 | Bottleneck(out_channels, out_channels=out_channels, downsample=True),
53 | )
54 |
55 | def forward(self, s_t):
56 | return self.model(s_t)
57 |
--------------------------------------------------------------------------------
/stretchbev/models/temporal_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from stretchbev.layers.temporal import Bottleneck3D, TemporalBlock
4 |
5 |
6 | class TemporalModel(nn.Module):
7 | def __init__(
8 | self, in_channels, receptive_field, input_shape, start_out_channels=64, extra_in_channels=0,
9 | n_spatial_layers_between_temporal_layers=0, use_pyramid_pooling=True):
10 | super().__init__()
11 | self.receptive_field = receptive_field
12 | n_temporal_layers = receptive_field - 1
13 |
14 | h, w = input_shape
15 | modules = []
16 |
17 | block_in_channels = in_channels
18 | block_out_channels = start_out_channels
19 |
20 | for _ in range(n_temporal_layers):
21 | if use_pyramid_pooling:
22 | use_pyramid_pooling = True
23 | pool_sizes = [(2, h, w)]
24 | else:
25 | use_pyramid_pooling = False
26 | pool_sizes = None
27 | temporal = TemporalBlock(
28 | block_in_channels,
29 | block_out_channels,
30 | use_pyramid_pooling=use_pyramid_pooling,
31 | pool_sizes=pool_sizes,
32 | )
33 | spatial = [
34 | Bottleneck3D(block_out_channels, block_out_channels, kernel_size=(1, 3, 3))
35 | for _ in range(n_spatial_layers_between_temporal_layers)
36 | ]
37 | temporal_spatial_layers = nn.Sequential(temporal, *spatial)
38 | modules.extend(temporal_spatial_layers)
39 |
40 | block_in_channels = block_out_channels
41 | block_out_channels += extra_in_channels
42 |
43 | self.out_channels = block_in_channels
44 |
45 | self.model = nn.Sequential(*modules)
46 |
47 | def forward(self, x):
48 | # Reshape input tensor to (batch, C, time, H, W)
49 | x = x.permute(0, 2, 1, 3, 4)
50 | x = self.model(x)
51 | x = x.permute(0, 2, 1, 3, 4).contiguous()
52 | return x[:, (self.receptive_field - 1):]
53 |
54 |
55 | class TemporalModelIdentity(nn.Module):
56 | def __init__(self, in_channels, receptive_field):
57 | super().__init__()
58 | self.receptive_field = receptive_field
59 | self.out_channels = in_channels
60 |
61 | def forward(self, x):
62 | return x[:, (self.receptive_field - 1):]
63 |
--------------------------------------------------------------------------------
/stretchbev/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class SpatialRegressionLoss(nn.Module):
7 | def __init__(self, norm, ignore_index=255, future_discount=1.0):
8 | super(SpatialRegressionLoss, self).__init__()
9 | self.norm = norm
10 | self.ignore_index = ignore_index
11 | self.future_discount = future_discount
12 |
13 | if norm == 1:
14 | self.loss_fn = F.l1_loss
15 | elif norm == 2:
16 | self.loss_fn = F.mse_loss
17 | else:
18 | raise ValueError(f'Expected norm 1 or 2, but got norm={norm}')
19 |
20 | def forward(self, prediction, target):
21 | assert len(prediction.shape) == 5, 'Must be a 5D tensor'
22 | # ignore_index is the same across all channels
23 | mask = target[:, :, :1] != self.ignore_index
24 | if mask.sum() == 0:
25 | return prediction.new_zeros(1)[0].float()
26 |
27 | loss = self.loss_fn(prediction, target, reduction='none')
28 |
29 | # Sum channel dimension
30 | loss = torch.sum(loss, dim=-3, keepdims=True)
31 |
32 | seq_len = loss.shape[1]
33 | future_discounts = self.future_discount ** torch.arange(seq_len, device=loss.device, dtype=loss.dtype)
34 | future_discounts = future_discounts.view(1, seq_len, 1, 1, 1)
35 | loss = loss * future_discounts
36 |
37 | return loss[mask].mean()
38 |
39 |
40 | class SegmentationLoss(nn.Module):
41 | def __init__(self, class_weights, ignore_index=255, use_top_k=False, top_k_ratio=1.0, future_discount=1.0):
42 | super().__init__()
43 | self.class_weights = class_weights
44 | self.ignore_index = ignore_index
45 | self.use_top_k = use_top_k
46 | self.top_k_ratio = top_k_ratio
47 | self.future_discount = future_discount
48 |
49 | def forward(self, prediction, target):
50 | if target.shape[-3] != 1:
51 | raise ValueError('segmentation label must be an index-label with channel dimension = 1.')
52 | b, s, c, h, w = prediction.shape
53 |
54 | prediction = prediction.view(b * s, c, h, w)
55 | target = target.view(b * s, h, w)
56 | loss = F.cross_entropy(
57 | prediction,
58 | target,
59 | ignore_index=self.ignore_index,
60 | reduction='none',
61 | weight=self.class_weights.to(target.device),
62 | )
63 |
64 | loss = loss.view(b, s, h, w)
65 |
66 | future_discounts = self.future_discount ** torch.arange(s, device=loss.device, dtype=loss.dtype)
67 | future_discounts = future_discounts.view(1, s, 1, 1)
68 | loss = loss * future_discounts
69 |
70 | loss = loss.view(b, s, -1)
71 | if self.use_top_k:
72 | # Penalises the top-k hardest pixels
73 | k = int(self.top_k_ratio * loss.shape[2])
74 | loss, _ = torch.sort(loss, dim=2, descending=True)
75 | loss = loss[:, :, :k]
76 |
77 | return torch.mean(loss)
78 |
79 |
80 | class ProbabilisticLoss(nn.Module):
81 | def forward(self, output):
82 | present_mu = output['present_mu']
83 | present_log_sigma = output['present_log_sigma']
84 | future_mu = output['future_mu']
85 | future_log_sigma = output['future_log_sigma']
86 |
87 | var_future = torch.exp(2 * future_log_sigma)
88 | var_present = torch.exp(2 * present_log_sigma)
89 | kl_div = (
90 | present_log_sigma - future_log_sigma - 0.5 + (var_future + (future_mu - present_mu) ** 2) / (
91 | 2 * var_present)
92 | )
93 |
94 | kl_loss = torch.mean(torch.sum(kl_div, dim=-1))
95 |
96 | return kl_loss
97 |
--------------------------------------------------------------------------------
/stretchbev/models/decoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torchvision.models.resnet import resnet18
3 |
4 | from stretchbev.layers.convolutions import UpsamplingAdd
5 |
6 |
7 | class Decoder(nn.Module):
8 | def __init__(self, in_channels, n_classes, predict_future_flow):
9 | super().__init__()
10 | backbone = resnet18(pretrained=False, zero_init_residual=True)
11 | self.first_conv = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
12 | self.bn1 = backbone.bn1
13 | self.relu = backbone.relu
14 |
15 | self.layer1 = backbone.layer1
16 | self.layer2 = backbone.layer2
17 | self.layer3 = backbone.layer3
18 | self.predict_future_flow = predict_future_flow
19 |
20 | shared_out_channels = in_channels
21 | self.up3_skip = UpsamplingAdd(256, 128, scale_factor=2)
22 | self.up2_skip = UpsamplingAdd(128, 64, scale_factor=2)
23 | self.up1_skip = UpsamplingAdd(64, shared_out_channels, scale_factor=2)
24 |
25 | self.segmentation_head = nn.Sequential(
26 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
27 | nn.BatchNorm2d(shared_out_channels),
28 | nn.ReLU(inplace=True),
29 | nn.Conv2d(shared_out_channels, n_classes, kernel_size=1, padding=0),
30 | )
31 | self.instance_offset_head = nn.Sequential(
32 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
33 | nn.BatchNorm2d(shared_out_channels),
34 | nn.ReLU(inplace=True),
35 | nn.Conv2d(shared_out_channels, 2, kernel_size=1, padding=0),
36 | )
37 | self.instance_center_head = nn.Sequential(
38 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
39 | nn.BatchNorm2d(shared_out_channels),
40 | nn.ReLU(inplace=True),
41 | nn.Conv2d(shared_out_channels, 1, kernel_size=1, padding=0),
42 | nn.Sigmoid(),
43 | )
44 |
45 | if self.predict_future_flow:
46 | self.instance_future_head = nn.Sequential(
47 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
48 | nn.BatchNorm2d(shared_out_channels),
49 | nn.ReLU(inplace=True),
50 | nn.Conv2d(shared_out_channels, 2, kernel_size=1, padding=0),
51 | )
52 |
53 | def forward(self, x):
54 | b, s, c, h, w = x.shape
55 | x = x.view(b * s, c, h, w)
56 |
57 | # (H, W)
58 | skip_x = {'1': x}
59 | x = self.first_conv(x)
60 | x = self.bn1(x)
61 | x = self.relu(x)
62 |
63 | # (H/4, W/4)
64 | x = self.layer1(x)
65 | skip_x['2'] = x
66 | x = self.layer2(x)
67 | skip_x['3'] = x
68 |
69 | # (H/8, W/8)
70 | x = self.layer3(x)
71 |
72 | # First upsample to (H/4, W/4)
73 | x = self.up3_skip(x, skip_x['3'])
74 |
75 | # Second upsample to (H/2, W/2)
76 | x = self.up2_skip(x, skip_x['2'])
77 |
78 | # Third upsample to (H, W)
79 | x = self.up1_skip(x, skip_x['1'])
80 |
81 | segmentation_output = self.segmentation_head(x)
82 | instance_center_output = self.instance_center_head(x)
83 | instance_offset_output = self.instance_offset_head(x)
84 | instance_future_output = self.instance_future_head(x) if self.predict_future_flow else None
85 | return {
86 | 'segmentation': segmentation_output.view(b, s, *segmentation_output.shape[1:]),
87 | 'instance_center': instance_center_output.view(b, s, *instance_center_output.shape[1:]),
88 | 'instance_offset': instance_offset_output.view(b, s, *instance_offset_output.shape[1:]),
89 | 'instance_flow': instance_future_output.view(b, s, *instance_future_output.shape[1:])
90 | if instance_future_output is not None else None,
91 | }
92 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # StretchBEV: Stretching Future Instance Prediction Spatially and Temporally (ECCV 2022)
2 |
3 |
4 | [arXiv](https://arxiv.org/abs/2203.13641)
5 | [Project Page](https://kuis-ai.github.io/stretchbev/)
6 | [Releases](https://github.com/kaanakan/stretchbev/releases/tag/v1.0)
7 | [Video](https://www.youtube.com/watch?v=2SiUNs6BMVk)
8 |
9 |
10 |
11 | > [**StretchBEV: Stretching Future Instance Prediction Spatially and Temporally**](https://arxiv.org/abs/2203.13641),
12 | > [Adil Kaan Akan](https://kaanakan.github.io),
13 | > [Fatma Guney](https://mysite.ku.edu.tr/fguney/),
14 | > *European Conference on Computer Vision (ECCV), 2022*
15 |
16 |
17 |
18 |
19 |
20 | ## Features
21 |
22 | StretchBEV is a future instance prediction network operating on a bird's-eye view representation. It learns temporal dynamics in a latent space through stochastic residual updates at each time step. By sampling from a learned distribution at each time step, we obtain future predictions that are more diverse and more accurate than previous work, stretching both spatially to farther regions in the scene and temporally over longer time horizons.
23 |
24 |
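The toy sketch below only illustrates this idea; it is not the actual model code (see `stretchbev/models/` for that), and every module and tensor name in it is made up. At each future step a latent code is sampled from a learned Gaussian and the BEV state is advanced by a residual update, so repeated rollouts from the same past produce different futures.

```python
import torch
import torch.nn as nn


class ToyResidualDynamics(nn.Module):
    """Illustrative stand-in for stochastic residual updates on a BEV state (not the StretchBEV modules)."""

    def __init__(self, state_dim=64, latent_dim=32):
        super().__init__()
        # Hypothetical 1x1-conv heads standing in for the learned prior and residual networks.
        self.prior = nn.Conv2d(state_dim, 2 * latent_dim, kernel_size=1)
        self.residual = nn.Conv2d(state_dim + latent_dim, state_dim, kernel_size=1)

    def forward(self, y_t, n_future=4):
        futures = []
        for _ in range(n_future):
            mu, log_sigma = self.prior(y_t).chunk(2, dim=1)        # learned Gaussian over latent codes
            z = mu + log_sigma.exp() * torch.randn_like(mu)        # sampling makes each rollout different
            y_t = y_t + self.residual(torch.cat([y_t, z], dim=1))  # stochastic residual state update
            futures.append(y_t)
        return torch.stack(futures, dim=1)                         # (b, n_future, c, h, w)


# One sampled rollout over a 200x200 BEV grid (100m x 100m at 0.5m resolution, as in the config).
dynamics = ToyResidualDynamics()
y0 = torch.randn(1, 64, 200, 200)
print(dynamics(y0).shape)  # torch.Size([1, 4, 64, 200, 200])
```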
25 | ## Requirements
26 |
27 | All models were trained with Python 3.7.10 and PyTorch 1.7.0.
28 |
29 | A list of required Python packages is available in the `environment.yml` file.
30 |
31 |
32 |
33 | ## Datasets
34 |
35 | For dataset preparation, we follow [FIERY](https://github.com/wayveai/fiery). Please see [this guide](https://github.com/wayveai/fiery/blob/master/DATASET.md) if you want to construct the datasets.
36 |
37 |
38 | ## Training
39 |
40 | To train the model on NuScenes:
41 |
42 | - First, you need to download [`static_lift_splat_setting.ckpt`](https://github.com/wayveai/fiery/releases/download/v1.0/static_lift_splat_setting.ckpt) and copy it to this directory.
43 | - Run `python train.py --config-file stretchbev/configs/stretchbev.yml DATASET.DATAROOT ${NUSCENES_DATAROOT}`.
44 |
45 | This will train the model on 4 GPUs, each with a batch size of 2. To train on a single GPU, add the flag `GPUS [0]`, and to change the batch size, use the flag `BATCHSIZE ${DESIRED_BATCHSIZE}`. A combined example is given below.
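As a concrete example (option names follow `stretchbev/config.py`; the data path is a placeholder), a single-GPU debug run could look like `python train.py --config-file stretchbev/configs/stretchbev.yml DATASET.DATAROOT /path/to/nuscenes GPUS [0] BATCHSIZE 1`.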
46 |
47 |
48 | ## Evaluation
49 |
50 | To evaluate a trained model on NuScenes:
51 |
52 | - Download [pre-trained weights](https://github.com/wayveai/fiery/releases/download/v1.0/stretchbev.ckpt).
53 | - Run `python evaluate.py --checkpoint ${CHECKPOINT_PATH} --dataroot ${NUSCENES_DATAROOT}`.
54 |
55 | ### Pretrained weights
56 |
57 | You can download the pretrained weights from the releases of this repository or the links below.
58 |
59 | [Normal setting weight](https://github.com/wayveai/fiery/releases/download/v1.0/stretchbev.ckpt)
60 |
61 | [Fishing setting weight](https://github.com/wayveai/fiery/releases/download/v1.0/stretchbev_fishing.ckpt)
62 |
63 |
64 |
65 | ## How to Cite
66 |
67 | Please cite our paper if you benefit from it or from this repository:
68 |
69 | ```
70 | @InProceedings{Akan2022ECCV,
71 | author = {Akan, Adil Kaan and G\"uney, Fatma},
72 | title = {StretchBEV: Stretching Future Instance Prediction Spatially and Temporally},
73 |   booktitle = {European Conference on Computer Vision (ECCV)},
74 | year = {2022},
75 | }
76 | ```
77 |
78 | ## Acknowledgments
79 |
80 | We would like to thank the FIERY and SRVP authors for making their repositories public. This repository contains several code segments from [FIERY's repository](https://github.com/wayveai/fiery) and [SRVP's repository](https://github.com/edouardelasalles/srvp). We appreciate the efforts of [Berkay Ugur Senocak](https://github.com/4turkuaz) in cleaning the code before release.
81 |
--------------------------------------------------------------------------------
/stretchbev/models/encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from efficientnet_pytorch import EfficientNet
3 |
4 | from stretchbev.layers.convolutions import UpsamplingConcat
5 |
6 |
7 | class Encoder(nn.Module):
8 | def __init__(self, cfg, D):
9 | super().__init__()
10 | self.D = D
11 | self.C = cfg.OUT_CHANNELS
12 | self.use_depth_distribution = cfg.USE_DEPTH_DISTRIBUTION
13 | self.downsample = cfg.DOWNSAMPLE
14 | self.version = cfg.NAME.split('-')[1]
15 |
16 | self.backbone = EfficientNet.from_pretrained(cfg.NAME)
17 | self.delete_unused_layers()
18 |
19 | if self.downsample == 16:
20 | if self.version == 'b0':
21 | upsampling_in_channels = 320 + 112
22 | elif self.version == 'b4':
23 | upsampling_in_channels = 448 + 160
24 | upsampling_out_channels = 512
25 | elif self.downsample == 8:
26 | if self.version == 'b0':
27 | upsampling_in_channels = 112 + 40
28 | elif self.version == 'b4':
29 | upsampling_in_channels = 160 + 56
30 | upsampling_out_channels = 128
31 | else:
32 | raise ValueError(f'Downsample factor {self.downsample} not handled.')
33 |
34 | self.upsampling_layer = UpsamplingConcat(upsampling_in_channels, upsampling_out_channels)
35 | if self.use_depth_distribution:
36 | self.depth_layer = nn.Conv2d(upsampling_out_channels, self.C + self.D, kernel_size=1, padding=0)
37 | else:
38 | self.depth_layer = nn.Conv2d(upsampling_out_channels, self.C, kernel_size=1, padding=0)
39 |
40 | def delete_unused_layers(self):
41 | indices_to_delete = []
42 | for idx in range(len(self.backbone._blocks)):
43 | if self.downsample == 8:
44 | if self.version == 'b0' and idx > 10:
45 | indices_to_delete.append(idx)
46 | if self.version == 'b4' and idx > 21:
47 | indices_to_delete.append(idx)
48 |
49 | for idx in reversed(indices_to_delete):
50 | del self.backbone._blocks[idx]
51 |
52 | del self.backbone._conv_head
53 | del self.backbone._bn1
54 | del self.backbone._avg_pooling
55 | del self.backbone._dropout
56 | del self.backbone._fc
57 |
58 | def get_features(self, x):
59 | # Adapted from https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py#L231
60 | endpoints = dict()
61 |
62 | # Stem
63 | x = self.backbone._swish(self.backbone._bn0(self.backbone._conv_stem(x)))
64 | prev_x = x
65 |
66 | # Blocks
67 | for idx, block in enumerate(self.backbone._blocks):
68 | drop_connect_rate = self.backbone._global_params.drop_connect_rate
69 | if drop_connect_rate:
70 | drop_connect_rate *= float(idx) / len(self.backbone._blocks)
71 | x = block(x, drop_connect_rate=drop_connect_rate)
72 | if prev_x.size(2) > x.size(2):
73 | endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
74 | prev_x = x
75 |
76 | if self.downsample == 8:
77 | if self.version == 'b0' and idx == 10:
78 | break
79 | if self.version == 'b4' and idx == 21:
80 | break
81 |
82 | # Head
83 | endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
84 |
85 | if self.downsample == 16:
86 | input_1, input_2 = endpoints['reduction_5'], endpoints['reduction_4']
87 | elif self.downsample == 8:
88 | input_1, input_2 = endpoints['reduction_4'], endpoints['reduction_3']
89 |
90 | x = self.upsampling_layer(input_1, input_2)
91 | return x
92 |
93 | def forward(self, x):
94 | x = self.get_features(x) # get feature vector
95 |
96 | x = self.depth_layer(x) # feature and depth head
97 |
98 | if self.use_depth_distribution:
99 | depth = x[:, : self.D].softmax(dim=1)
100 | x = depth.unsqueeze(1) * x[:, self.D: (self.D + self.C)].unsqueeze(2) # outer product depth and features
101 | else:
102 | x = x.unsqueeze(2).repeat(1, 1, self.D, 1, 1)
103 |
104 | return x
105 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import random
2 | from argparse import ArgumentParser
3 |
4 | import torch
5 | from tqdm import tqdm
6 |
7 | from stretchbev.data import prepare_dataloaders
8 | from stretchbev.metrics import IntersectionOverUnion, PanopticMetric
9 | from stretchbev.trainer import TrainingModule
10 | from stretchbev.utils.instance import predict_instance_segmentation_and_trajectories
11 | from stretchbev.utils.network import preprocess_batch
12 |
13 | # 30mx30m, 100mx100m
14 | EVALUATION_RANGES = {
15 | '30x30': (70, 130),
16 | '100x100': (0, 200)
17 | }
18 |
19 |
20 | def eval(checkpoint_path, dataroot, version):
21 | trainer = TrainingModule.load_from_checkpoint(checkpoint_path, strict=True)
22 | print(f'Loaded weights from \n {checkpoint_path}')
23 | trainer.eval()
24 |
25 | device = torch.device('cuda:0')
26 | trainer.to(device)
27 | model = trainer.model
28 | model.eval()
29 |
30 | cfg = model.cfg
31 | cfg.GPUS = "[0]"
32 | cfg.BATCHSIZE = 1
33 |
34 | cfg.DATASET.DATAROOT = dataroot
35 | cfg.DATASET.VERSION = version
36 |
37 | _, valloader = prepare_dataloaders(cfg)
38 |
39 | panoptic_metrics = {}
40 | iou_metrics = {}
41 | n_classes = len(cfg.SEMANTIC_SEG.WEIGHTS)
42 | for key in EVALUATION_RANGES.keys():
43 | panoptic_metrics[key] = PanopticMetric(n_classes=n_classes, temporally_consistent=True).to(
44 | device)
45 | iou_metrics[key] = IntersectionOverUnion(n_classes).to(device)
46 |
47 | for i, batch in enumerate(tqdm(valloader)):
48 | preprocess_batch(batch, device)
49 | image = batch['image'] # [:, 3:]
50 | intrinsics = batch['intrinsics'] # [:, 3:]
51 | extrinsics = batch['extrinsics'] # [:, 3:]
52 | future_egomotion = batch['future_egomotion'] # [:, 3:]
53 |
54 | batch_size = image.shape[0]
55 |
56 | labels, future_distribution_inputs = trainer.prepare_future_labels(batch)
57 |
58 | with torch.no_grad():
59 | # Evaluate with mean prediction
60 | output = model.multi_sample_inference(image, intrinsics, extrinsics, future_egomotion, num_samples=1,
61 | future_distribution_inputs=future_distribution_inputs)[0]
62 | labels = {k: v[:, 3:] for k, v in labels.items()}
63 |
64 | # Consistent instance seg
65 | pred_consistent_instance_seg = predict_instance_segmentation_and_trajectories(
66 | output, compute_matched_centers=False, make_consistent=True
67 | )
68 |
69 | segmentation_pred = output['segmentation'].detach()
70 | segmentation_pred = torch.argmax(segmentation_pred, dim=2, keepdims=True)
71 |
72 | for key, grid in EVALUATION_RANGES.items():
73 | limits = slice(grid[0], grid[1])
74 | panoptic_metrics[key](pred_consistent_instance_seg[..., limits, limits].contiguous().detach(),
75 | labels['instance'][..., limits, limits].contiguous()
76 | )
77 |
78 | iou_metrics[key](segmentation_pred[..., limits, limits].contiguous(),
79 | labels['segmentation'][..., limits, limits].contiguous()
80 | )
81 | # iou_scores = iou_metrics[key].compute()
82 | # print(iou_scores[1].item())
83 | # iou_metrics[key].reset()
84 |
85 | results = {}
86 | for key, grid in EVALUATION_RANGES.items():
87 | panoptic_scores = panoptic_metrics[key].compute()
88 | for panoptic_key, value in panoptic_scores.items():
89 | results[f'{panoptic_key}'] = results.get(f'{panoptic_key}', []) + [100 * value[1].item()]
90 |
91 | iou_scores = iou_metrics[key].compute()
92 | results['iou'] = results.get('iou', []) + [100 * iou_scores[1].item()]
93 |
94 | for panoptic_key in ['iou', 'pq', 'sq', 'rq']:
95 | print(panoptic_key)
96 | print(' & '.join([f'{x:.1f}' for x in results[panoptic_key]]))
97 |
98 |
99 | if __name__ == '__main__':
100 | parser = ArgumentParser(description='Fiery evaluation')
101 | parser.add_argument('--checkpoint', default='./fiery.ckpt', type=str, help='path to checkpoint')
102 | parser.add_argument('--dataroot', default='./nuscenes', type=str, help='path to the dataset')
103 | parser.add_argument('--version', default='trainval', type=str, choices=['mini', 'trainval'],
104 | help='dataset version')
105 |
106 | args = parser.parse_args()
107 | torch.manual_seed(0)
108 | random.seed(0)
109 |
110 | eval(args.checkpoint, args.dataroot, args.version)
111 |
--------------------------------------------------------------------------------
/stretchbev/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from fvcore.common.config import CfgNode as _CfgNode
3 |
4 |
5 | def convert_to_dict(cfg_node, key_list=[]):
6 | """Convert a config node to dictionary."""
7 | _VALID_TYPES = {tuple, list, str, int, float, bool}
8 | if not isinstance(cfg_node, _CfgNode):
9 | if type(cfg_node) not in _VALID_TYPES:
10 | print(
11 | 'Key {} with value {} is not a valid type; valid types: {}'.format(
12 | '.'.join(key_list), type(cfg_node), _VALID_TYPES
13 | ),
14 | )
15 | return cfg_node
16 | else:
17 | cfg_dict = dict(cfg_node)
18 | for k, v in cfg_dict.items():
19 | cfg_dict[k] = convert_to_dict(v, key_list + [k])
20 | return cfg_dict
21 |
22 |
23 | class CfgNode(_CfgNode):
24 | """Remove once https://github.com/rbgirshick/yacs/issues/19 is merged."""
25 |
26 | def convert_to_dict(self):
27 | return convert_to_dict(self)
28 |
29 |
30 | CN = CfgNode
31 |
32 | _C = CN()
33 | _C.LOG_DIR = 'tensorboard_logs'
34 | _C.TAG = 'default'
35 |
36 | _C.GPUS = [0] # which gpus to use
37 | _C.PRECISION = 32 # 16bit or 32bit
38 | _C.BATCHSIZE = 3
39 | _C.EPOCHS = 20
40 |
41 | _C.N_WORKERS = 5
42 | _C.VIS_INTERVAL = 5000
43 | _C.LOGGING_INTERVAL = 500
44 |
45 | _C.PRETRAINED = CN()
46 | _C.PRETRAINED.LOAD_WEIGHTS = False
47 | _C.PRETRAINED.PATH = ''
48 |
49 | _C.DATASET = CN()
50 | _C.DATASET.DATAROOT = './nuscenes/'
51 | _C.DATASET.VERSION = 'trainval'
52 | _C.DATASET.NAME = 'nuscenes'
53 | _C.DATASET.IGNORE_INDEX = 255 # Ignore index when creating flow/offset labels
54 | _C.DATASET.FILTER_INVISIBLE_VEHICLES = True # Filter vehicles that are not visible from the cameras
55 |
56 | _C.TIME_RECEPTIVE_FIELD = 3 # how many frames of temporal context (1 for single timeframe)
57 | _C.N_FUTURE_FRAMES = 4 # how many time steps into the future to predict
58 |
59 | _C.IMAGE = CN()
60 | _C.IMAGE.FINAL_DIM = (224, 480)
61 | _C.IMAGE.RESIZE_SCALE = 0.3
62 | _C.IMAGE.TOP_CROP = 46
63 | _C.IMAGE.ORIGINAL_HEIGHT = 900 # Original input RGB camera height
64 | _C.IMAGE.ORIGINAL_WIDTH = 1600 # Original input RGB camera width
65 | _C.IMAGE.NAMES = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT']
66 |
67 | _C.LIFT = CN() # image to BEV lifting
68 | _C.LIFT.X_BOUND = [-50.0, 50.0, 0.5] # Forward
69 | _C.LIFT.Y_BOUND = [-50.0, 50.0, 0.5] # Sides
70 | _C.LIFT.Z_BOUND = [-10.0, 10.0, 20.0] # Height
71 | _C.LIFT.D_BOUND = [2.0, 50.0, 1.0]
72 |
73 | _C.MODEL = CN()
74 |
75 | _C.MODEL.ENCODER = CN()
76 | _C.MODEL.ENCODER.DOWNSAMPLE = 8
77 | _C.MODEL.ENCODER.NAME = 'efficientnet-b4'
78 | _C.MODEL.ENCODER.OUT_CHANNELS = 64
79 | _C.MODEL.ENCODER.USE_DEPTH_DISTRIBUTION = True
80 |
81 | _C.MODEL.SMALL_ENCODER = CN()
82 | _C.MODEL.SMALL_ENCODER.FILTER_SIZE = 64
83 | _C.MODEL.SMALL_ENCODER.SKIPCO = True
84 |
85 |
86 | _C.MODEL.FIRST_STATE = CN()
87 | _C.MODEL.FIRST_STATE.NUM_LAYERS = 3
88 |
89 | _C.MODEL.DYNAMICS = CN()
90 | _C.MODEL.DYNAMICS.NUM_LAYERS = 3
91 |
92 |
93 |
94 |
95 | _C.MODEL.DISTRIBUTION = CN()
96 | _C.MODEL.DISTRIBUTION.LATENT_DIM = 32
97 | _C.MODEL.DISTRIBUTION.MIN_LOG_SIGMA = -5.0
98 | _C.MODEL.DISTRIBUTION.MAX_LOG_SIGMA = 5.0
99 | _C.MODEL.DISTRIBUTION.POSTERIOR_LAYERS = 2
100 | _C.MODEL.DISTRIBUTION.PRIOR_LAYERS = 2
101 |
102 | _C.MODEL.FUTURE_PRED = CN()
103 | _C.MODEL.FUTURE_PRED.N_GRU_BLOCKS = 3
104 | _C.MODEL.FUTURE_PRED.N_RES_LAYERS = 3
105 |
106 | _C.MODEL.DECODER = CN()
107 |
108 | _C.MODEL.BN_MOMENTUM = 0.1
109 | _C.MODEL.SUBSAMPLE = False # Subsample frames for Lyft
110 |
111 | _C.SEMANTIC_SEG = CN()
112 | _C.SEMANTIC_SEG.WEIGHTS = [1.0, 2.0]  # per-class cross-entropy weights (background, vehicle)
113 | _C.SEMANTIC_SEG.USE_TOP_K = True # backprop only top-k hardest pixels
114 | _C.SEMANTIC_SEG.TOP_K_RATIO = 0.25
115 |
116 | _C.INSTANCE_SEG = CN()
117 |
118 | _C.INSTANCE_FLOW = CN()
119 | _C.INSTANCE_FLOW.ENABLED = True
120 |
121 | _C.PROBABILISTIC = CN()
122 | _C.PROBABILISTIC.ENABLED = True # learn a distribution over futures
123 | _C.PROBABILISTIC.WEIGHT = 100.0
124 | _C.PROBABILISTIC.FUTURE_DIM = 6  # number of dimensions added (future flow, future centerness, offset, seg)
125 |
126 | _C.FUTURE_DISCOUNT = 0.95
127 |
128 | _C.OPTIMIZER = CN()
129 | _C.OPTIMIZER.LR = 3e-4
130 | _C.OPTIMIZER.WEIGHT_DECAY = 1e-7
131 | _C.GRAD_NORM_CLIP = 5
132 |
133 |
134 | def get_parser():
135 | parser = argparse.ArgumentParser(description='Stretchbev training')
136 | # TODO: remove below?
137 | parser.add_argument('--config-file', default='', metavar='FILE', help='path to config file')
138 | parser.add_argument(
139 | 'opts', help='Modify config options using the command-line', default=None, nargs=argparse.REMAINDER,
140 | )
141 | return parser
142 |
143 |
144 | def get_cfg(args=None, cfg_dict=None):
145 | """ First get default config. Then merge cfg_dict. Then merge according to args. """
146 |
147 | cfg = _C.clone()
148 |
149 | if cfg_dict is not None:
150 | cfg.merge_from_other_cfg(CfgNode(cfg_dict))
151 |
152 | if args is not None:
153 | if args.config_file:
154 | cfg.merge_from_file(args.config_file)
155 | cfg.merge_from_list(args.opts)
156 | cfg.freeze()
157 | return cfg
158 |
--------------------------------------------------------------------------------
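
A minimal usage sketch for the helpers above, assuming the repository root is on PYTHONPATH so that stretchbev.config is importable; the override values are illustrative, and a YAML file such as stretchbev/configs/stretchbev.yml could be supplied through --config-file:

from stretchbev.config import get_parser, get_cfg, convert_to_dict

# command-line style overrides; --config-file is left at its default here
args = get_parser().parse_args(['BATCHSIZE', '2', 'GPUS', '[0, 1]'])
cfg = get_cfg(args)  # defaults -> optional YAML -> command-line overrides, then frozen
print(cfg.BATCHSIZE, cfg.N_FUTURE_FRAMES)
print(convert_to_dict(cfg)['MODEL']['ENCODER'])
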
/stretchbev/.ipynb_checkpoints/config-checkpoint.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from fvcore.common.config import CfgNode as _CfgNode
3 |
4 |
5 | def convert_to_dict(cfg_node, key_list=[]):
6 | """Convert a config node to dictionary."""
7 | _VALID_TYPES = {tuple, list, str, int, float, bool}
8 | if not isinstance(cfg_node, _CfgNode):
9 | if type(cfg_node) not in _VALID_TYPES:
10 | print(
11 | 'Key {} with value {} is not a valid type; valid types: {}'.format(
12 | '.'.join(key_list), type(cfg_node), _VALID_TYPES
13 | ),
14 | )
15 | return cfg_node
16 | else:
17 | cfg_dict = dict(cfg_node)
18 | for k, v in cfg_dict.items():
19 | cfg_dict[k] = convert_to_dict(v, key_list + [k])
20 | return cfg_dict
21 |
22 |
23 | class CfgNode(_CfgNode):
24 | """Remove once https://github.com/rbgirshick/yacs/issues/19 is merged."""
25 |
26 | def convert_to_dict(self):
27 | return convert_to_dict(self)
28 |
29 |
30 | CN = CfgNode
31 |
32 | _C = CN()
33 | _C.LOG_DIR = 'tensorboard_logs'
34 | _C.TAG = 'default'
35 |
36 | _C.GPUS = [0] # which gpus to use
37 | _C.PRECISION = 32 # 16bit or 32bit
38 | _C.BATCHSIZE = 3
39 | _C.EPOCHS = 20
40 |
41 | _C.N_WORKERS = 5
42 | _C.VIS_INTERVAL = 5000
43 | _C.LOGGING_INTERVAL = 500
44 |
45 | _C.PRETRAINED = CN()
46 | _C.PRETRAINED.LOAD_WEIGHTS = False
47 | _C.PRETRAINED.PATH = ''
48 |
49 | _C.DATASET = CN()
50 | _C.DATASET.DATAROOT = './nuscenes/'
51 | _C.DATASET.VERSION = 'trainval'
52 | _C.DATASET.NAME = 'nuscenes'
53 | _C.DATASET.IGNORE_INDEX = 255 # Ignore index when creating flow/offset labels
54 | _C.DATASET.FILTER_INVISIBLE_VEHICLES = True # Filter vehicles that are not visible from the cameras
55 |
56 | _C.TIME_RECEPTIVE_FIELD = 3 # how many frames of temporal context (1 for single timeframe)
57 | _C.N_FUTURE_FRAMES = 4 # how many time steps into the future to predict
58 |
59 | _C.IMAGE = CN()
60 | _C.IMAGE.FINAL_DIM = (224, 480)
61 | _C.IMAGE.RESIZE_SCALE = 0.3
62 | _C.IMAGE.TOP_CROP = 46
63 | _C.IMAGE.ORIGINAL_HEIGHT = 900 # Original input RGB camera height
64 | _C.IMAGE.ORIGINAL_WIDTH = 1600 # Original input RGB camera width
65 | _C.IMAGE.NAMES = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT']
66 |
67 | _C.LIFT = CN() # image to BEV lifting
68 | _C.LIFT.X_BOUND = [-50.0, 50.0, 0.5] # Forward
69 | _C.LIFT.Y_BOUND = [-50.0, 50.0, 0.5] # Sides
70 | _C.LIFT.Z_BOUND = [-10.0, 10.0, 20.0] # Height
71 | _C.LIFT.D_BOUND = [2.0, 50.0, 1.0]
72 |
73 | _C.MODEL = CN()
74 |
75 | _C.MODEL.ENCODER = CN()
76 | _C.MODEL.ENCODER.DOWNSAMPLE = 8
77 | _C.MODEL.ENCODER.NAME = 'efficientnet-b4'
78 | _C.MODEL.ENCODER.OUT_CHANNELS = 64
79 | _C.MODEL.ENCODER.USE_DEPTH_DISTRIBUTION = True
80 |
81 | _C.MODEL.SMALL_ENCODER = CN()
82 | _C.MODEL.SMALL_ENCODER.FILTER_SIZE = 64
83 | _C.MODEL.SMALL_ENCODER.SKIPCO = True
84 |
85 |
86 | _C.MODEL.FIRST_STATE = CN()
87 | _C.MODEL.FIRST_STATE.NUM_LAYERS = 3
88 |
89 | _C.MODEL.DYNAMICS = CN()
90 | _C.MODEL.DYNAMICS.NUM_LAYERS = 3
91 |
92 |
93 |
94 |
95 | _C.MODEL.DISTRIBUTION = CN()
96 | _C.MODEL.DISTRIBUTION.LATENT_DIM = 32
97 | _C.MODEL.DISTRIBUTION.MIN_LOG_SIGMA = -5.0
98 | _C.MODEL.DISTRIBUTION.MAX_LOG_SIGMA = 5.0
99 | _C.MODEL.DISTRIBUTION.POSTERIOR_LAYERS = 2
100 | _C.MODEL.DISTRIBUTION.PRIOR_LAYERS = 2
101 |
102 | _C.MODEL.FUTURE_PRED = CN()
103 | _C.MODEL.FUTURE_PRED.N_GRU_BLOCKS = 3
104 | _C.MODEL.FUTURE_PRED.N_RES_LAYERS = 3
105 |
106 | _C.MODEL.DECODER = CN()
107 |
108 | _C.MODEL.BN_MOMENTUM = 0.1
109 | _C.MODEL.SUBSAMPLE = False # Subsample frames for Lyft
110 |
111 | _C.SEMANTIC_SEG = CN()
112 | _C.SEMANTIC_SEG.WEIGHTS = [1.0, 2.0] # per class cross entropy weights (bg, dynamic, drivable, lane)
113 | _C.SEMANTIC_SEG.USE_TOP_K = True # backprop only top-k hardest pixels
114 | _C.SEMANTIC_SEG.TOP_K_RATIO = 0.25
115 |
116 | _C.INSTANCE_SEG = CN()
117 |
118 | _C.INSTANCE_FLOW = CN()
119 | _C.INSTANCE_FLOW.ENABLED = True
120 |
121 | _C.PROBABILISTIC = CN()
122 | _C.PROBABILISTIC.ENABLED = True # learn a distribution over futures
123 | _C.PROBABILISTIC.WEIGHT = 100.0
124 | _C.PROBABILISTIC.FUTURE_DIM = 6 # number of dimension added (future flow, future centerness, offset, seg)
125 |
126 | _C.FUTURE_DISCOUNT = 0.95
127 |
128 | _C.OPTIMIZER = CN()
129 | _C.OPTIMIZER.LR = 3e-4
130 | _C.OPTIMIZER.WEIGHT_DECAY = 1e-7
131 | _C.GRAD_NORM_CLIP = 5
132 |
133 |
134 | def get_parser():
135 | parser = argparse.ArgumentParser(description='Fiery training')
136 | # TODO: remove below?
137 | parser.add_argument('--config-file', default='', metavar='FILE', help='path to config file')
138 | parser.add_argument(
139 | 'opts', help='Modify config options using the command-line', default=None, nargs=argparse.REMAINDER,
140 | )
141 | return parser
142 |
143 |
144 | def get_cfg(args=None, cfg_dict=None):
145 | """ First get default config. Then merge cfg_dict. Then merge according to args. """
146 |
147 | cfg = _C.clone()
148 |
149 | if cfg_dict is not None:
150 | cfg.merge_from_other_cfg(CfgNode(cfg_dict))
151 |
152 | if args is not None:
153 | if args.config_file:
154 | cfg.merge_from_file(args.config_file)
155 | cfg.merge_from_list(args.opts)
156 | cfg.freeze()
157 | return cfg
158 |
--------------------------------------------------------------------------------
/stretchbev/models/model_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Mickael Chen, Edouard Delasalles, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import torch
17 |
18 | import torch.distributions as distrib
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 |
22 |
23 | def init_weight(m, init_type='normal', init_gain=0.02):
24 | """
25 | Initializes the input module with the given parameters.
26 |
27 | Only deals with `Conv2d`, `ConvTranspose2d`, `Linear` and `BatchNorm2d` layers.
28 |
29 | Parameters
30 | ----------
31 | m : torch.nn.Module
32 | Module to initialize.
33 | init_type : str
34 | 'normal', 'xavier', 'kaiming', or 'orthogonal'. Initialization type used for convolution and linear
35 | operations. Ignored for batch normalization, which always uses a normal initialization.
36 | init_gain : float
37 | Gain to use for the initialization.
38 | """
39 | classname = m.__class__.__name__
40 | if classname in ('Conv2d', 'ConvTranspose2d', 'Linear'):
41 | if init_type == 'normal':
42 | nn.init.normal_(m.weight.data, 0.0, init_gain)
43 | elif init_type == 'xavier':
44 | nn.init.xavier_normal_(m.weight.data, gain=init_gain)
45 | elif init_type == 'kaiming':
46 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
47 | elif init_type == 'orthogonal':
48 | nn.init.orthogonal_(m.weight.data, gain=init_gain)
49 | else:
50 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
51 | if hasattr(m, 'bias') and m.bias is not None:
52 | nn.init.constant_(m.bias.data, 0.0)
53 | elif classname == 'BatchNorm2d':
54 | if m.weight is not None:
55 | nn.init.normal_(m.weight.data, 1.0, init_gain)
56 | if m.bias is not None:
57 | nn.init.constant_(m.bias.data, 0.0)
58 |
59 |
60 | def make_normal_from_raw_params(raw_params, scale_stddev=1, dim=2, eps=1e-8, max_log_sigma=-10000, min_log_sigma=10000):
61 | """
62 | Creates a normal distribution from the given parameters.
63 |
64 | Parameters
65 | ----------
66 | raw_params : torch.*.Tensor
67 | Tensor containing the Gaussian mean and a raw scale parameter on a given dimension.
68 | scale_stddev : float
69 | Multiplier of the final scale parameter of the Gaussian.
70 | dim : int
71 | Dimension of raw_params along which the first half corresponds to the mean and the second half to the raw scale.
72 | eps : float
73 | Minimum possible value of the final scale parameter.
74 |
75 | Returns
76 | -------
77 | torch.distributions.Normal
78 | Normal distribution with the input mean and eps + softplus(raw scale) * scale_stddev as standard deviation.
79 | """
80 | dim = 2 if len(raw_params.shape) == 5 else 1  # infer the split dimension from the input, overriding the dim argument
81 | loc, raw_scale = torch.chunk(raw_params, 2, dim)
82 | assert loc.shape[dim] == raw_scale.shape[dim], f'{loc.shape[dim]}, {raw_scale.shape[dim]}'
83 | # raw_scale = torch.clamp(raw_scale, min_log_sigma, max_log_sigma)
84 | scale = F.softplus(raw_scale) + eps
85 | normal = distrib.Normal(loc, scale * scale_stddev)
86 | return normal
87 |
88 |
89 | def rsample_normal(raw_params, scale_stddev=1, max_log_sigma=-10000, min_log_sigma=10000):
90 | """
91 | Samples from a normal distribution with given parameters.
92 |
93 | Parameters
94 | ----------
95 | raw_params : torch.*.Tensor
96 | Tensor containing a Gaussian mean and a raw scale parameter on its last dimension.
97 | scale_stddev : float
98 | Multiplier of the final scale parameter of the Gaussian.
99 |
100 | Returns
101 | -------
102 | torch.*.Tensor
103 | Sample from the normal distribution with the input mean and eps + softplus(raw scale) * scale_stddev as
104 | standard deviation.
105 | """
106 |
107 | normal = make_normal_from_raw_params(raw_params, scale_stddev=scale_stddev)
108 | sample = normal.rsample()
109 | return sample
110 |
111 |
112 | def neg_logprob(loc, data, scale=1):
113 | """
114 | Computes the negative log density function of a given input with respect to a normal distribution created from
115 | given parameters.
116 |
117 | Parameters
118 | ----------
119 | loc : torch.*.Tensor
120 | Tensor containing the mean of the Gaussian on its last dimension.
121 | data : torch.*.Tensor
122 | Computes the log density function of this tensor with respect to the Gaussian distribution of input mean and
123 | standard deviation.
124 | scale : float
125 | Standard deviation of the Gaussian.
126 |
127 | Returns
128 | -------
129 | torch.*.Tensor
130 | Element-wise negative log density of the input data under a normal distribution with mean loc and
131 | standard deviation scale.
132 | """
133 | obs_distrib = distrib.Normal(loc, scale)
134 | return -obs_distrib.log_prob(data)
135 |
--------------------------------------------------------------------------------
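
A small self-contained check of the distribution helpers above, assuming the repository root is on PYTHONPATH so that stretchbev.models.model_utils is importable. For 4-D inputs the mean and raw scale are packed along dim=1 (dim=2 is used only for 5-D inputs):

import torch
from stretchbev.models.model_utils import make_normal_from_raw_params, rsample_normal, neg_logprob

raw_params = torch.randn(2, 2 * 32, 50, 50)      # 32 mean channels followed by 32 raw-scale channels
dist = make_normal_from_raw_params(raw_params)   # Normal with std = softplus(raw scale) + eps
sample = rsample_normal(raw_params)              # reparameterised sample of shape (2, 32, 50, 50)
nll = neg_logprob(dist.mean, sample)             # element-wise negative log-likelihood, same shape
print(sample.shape, nll.shape)
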
/stretchbev/models/.ipynb_checkpoints/model_utils-checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Mickael Chen, Edouard Delasalles, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import torch
17 |
18 | import torch.distributions as distrib
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 |
22 |
23 | def init_weight(m, init_type='normal', init_gain=0.02):
24 | """
25 | Initializes the input module with the given parameters.
26 |
27 | Only deals with `Conv2d`, `ConvTranspose2d`, `Linear` and `BatchNorm2d` layers.
28 |
29 | Parameters
30 | ----------
31 | m : torch.nn.Module
32 | Module to initialize.
33 | init_type : str
34 | 'normal', 'xavier', 'kaiming', or 'orthogonal'. Orthogonal initialization types for convolutions and linear
35 | operations. Ignored for batch normalization which uses a normal initialization.
36 | init_gain : float
37 | Gain to use for the initialization.
38 | """
39 | classname = m.__class__.__name__
40 | if classname in ('Conv2d', 'ConvTranspose2d', 'Linear'):
41 | if init_type == 'normal':
42 | nn.init.normal_(m.weight.data, 0.0, init_gain)
43 | elif init_type == 'xavier':
44 | nn.init.xavier_normal_(m.weight.data, gain=init_gain)
45 | elif init_type == 'kaiming':
46 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
47 | elif init_type == 'orthogonal':
48 | nn.init.orthogonal_(m.weight.data, gain=init_gain)
49 | else:
50 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
51 | if hasattr(m, 'bias') and m.bias is not None:
52 | nn.init.constant_(m.bias.data, 0.0)
53 | elif classname == 'BatchNorm2d':
54 | if m.weight is not None:
55 | nn.init.normal_(m.weight.data, 1.0, init_gain)
56 | if m.bias is not None:
57 | nn.init.constant_(m.bias.data, 0.0)
58 |
59 |
60 | def make_normal_from_raw_params(raw_params, scale_stddev=1, dim=2, eps=1e-8, max_log_sigma=-10000, min_log_sigma=10000):
61 | """
62 | Creates a normal distribution from the given parameters.
63 |
64 | Parameters
65 | ----------
66 | raw_params : torch.*.Tensor
67 | Tensor containing the Gaussian mean and a raw scale parameter on a given dimension.
68 | scale_stddev : float
69 | Multiplier of the final scale parameter of the Gaussian.
70 | dim : int
71 | Dimensions of raw_params so that the first half corresponds to the mean, and the second half to the scale.
72 | eps : float
73 | Minimum possible value of the final scale parameter.
74 |
75 | Returns
76 | -------
77 | torch.distributions.Normal
78 | Normal distribution with the input mean and eps + softplus(raw scale) * scale_stddev as standard deviation.
79 | """
80 | dim = 2 if len(raw_params.shape) == 5 else 1
81 | loc, raw_scale = torch.chunk(raw_params, 2, dim)
82 | assert loc.shape[dim] == raw_scale.shape[dim], f'{loc.shape[dim]}, {raw_scale.shape[dim]}'
83 | # raw_scale = torch.clamp(raw_scale, min_log_sigma, max_log_sigma)
84 | scale = F.softplus(raw_scale) + eps
85 | normal = distrib.Normal(loc, scale * scale_stddev)
86 | return normal
87 |
88 |
89 | def rsample_normal(raw_params, scale_stddev=1, max_log_sigma=-10000, min_log_sigma=10000):
90 | """
91 | Samples from a normal distribution with given parameters.
92 |
93 | Parameters
94 | ----------
95 | raw_params : torch.*.Tensor
96 | Tensor containing a Gaussian mean and a raw scale parameter on its last dimension.
97 | scale_stddev : float
98 | Multiplier of the final scale parameter of the Gaussian.
99 |
100 | Returns
101 | -------
102 | torch.*.Tensor
103 | Sample from the normal distribution with the input mean and eps + softplus(raw scale) * scale_stddev as
104 | standard deviation.
105 | """
106 |
107 | normal = make_normal_from_raw_params(raw_params, scale_stddev=scale_stddev)
108 | sample = normal.rsample()
109 | return sample
110 |
111 |
112 | def neg_logprob(loc, data, scale=1):
113 | """
114 | Computes the negative log density function of a given input with respect to a normal distribution created from
115 | given parameters.
116 |
117 | Parameters
118 | ----------
119 | loc : torch.*.Tensor
120 | Tensor containing the mean of the Gaussian on its last dimension.
121 | data : torch.*.Tensor
122 | Computes the log density function of this tensor with respect to the Gaussian distribution of input mean and
123 | standard deviation.
124 | scale : float
125 | Standard deviation of the Gaussian.
126 |
127 | Returns
128 | -------
129 | torch.*.Tensor
130 | Sample from the normal distribution with the input mean and eps + softplus(raw scale) * scale_stddev as
131 | standard deviation.
132 | """
133 | obs_distrib = distrib.Normal(loc, scale)
134 | return -obs_distrib.log_prob(data)
--------------------------------------------------------------------------------
/visualise.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | from glob import glob
4 |
5 | import cv2
6 | import matplotlib as mpl
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import torch
10 | import torchvision
11 | from PIL import Image
12 | from fiery.trainer import TrainingModule
13 | from fiery.utils.instance import predict_instance_segmentation_and_trajectories
14 | from fiery.utils.network import NormalizeInverse
15 | from fiery.utils.visualisation import plot_instance_map, generate_instance_colours, make_contour, convert_figure_numpy
16 |
17 | EXAMPLE_DATA_PATH = 'example_data'
18 |
19 |
20 | def plot_prediction(image, output, cfg):
21 | # Process predictions
22 | consistent_instance_seg, matched_centers = predict_instance_segmentation_and_trajectories(
23 | output, compute_matched_centers=True
24 | )
25 |
26 | # Plot future trajectories
27 | unique_ids = torch.unique(consistent_instance_seg[0, 0]).cpu().long().numpy()[1:]
28 | instance_map = dict(zip(unique_ids, unique_ids))
29 | instance_colours = generate_instance_colours(instance_map)
30 | vis_image = plot_instance_map(consistent_instance_seg[0, 0].cpu().numpy(), instance_map)
31 | trajectory_img = np.zeros(vis_image.shape, dtype=np.uint8)
32 | for instance_id in unique_ids:
33 | path = matched_centers[instance_id]
34 | for t in range(len(path) - 1):
35 | color = instance_colours[instance_id].tolist()
36 | cv2.line(trajectory_img, tuple(path[t]), tuple(path[t + 1]),
37 | color, 4)
38 |
39 | # Overlay the predicted trajectories on the instance map
40 | temp_img = cv2.addWeighted(vis_image, 0.7, trajectory_img, 0.3, 1.0)
41 | mask = ~ np.all(trajectory_img == 0, axis=2)
42 | vis_image[mask] = temp_img[mask]
43 |
44 | # Plot present RGB frames and predictions
45 | val_w = 2.99
46 | cameras = cfg.IMAGE.NAMES
47 | image_ratio = cfg.IMAGE.FINAL_DIM[0] / cfg.IMAGE.FINAL_DIM[1]
48 | val_h = val_w * image_ratio
49 | fig = plt.figure(figsize=(4 * val_w, 2 * val_h))
50 | width_ratios = (val_w, val_w, val_w, val_w)
51 | gs = mpl.gridspec.GridSpec(2, 4, width_ratios=width_ratios)
52 | gs.update(wspace=0.0, hspace=0.0, left=0.0, right=1.0, top=1.0, bottom=0.0)
53 |
54 | denormalise_img = torchvision.transforms.Compose(
55 | (NormalizeInverse(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
56 | torchvision.transforms.ToPILImage(),)
57 | )
58 | for imgi, img in enumerate(image[0, -1]):
59 | ax = plt.subplot(gs[imgi // 3, imgi % 3])
60 | showimg = denormalise_img(img.cpu())
61 | if imgi > 2:
62 | showimg = showimg.transpose(Image.FLIP_LEFT_RIGHT)
63 |
64 | plt.annotate(cameras[imgi].replace('_', ' ').replace('CAM ', ''), (0.01, 0.87), c='white',
65 | xycoords='axes fraction', fontsize=14)
66 | plt.imshow(showimg)
67 | plt.axis('off')
68 |
69 | ax = plt.subplot(gs[:, 3])
70 | plt.imshow(make_contour(vis_image[::-1, ::-1]))
71 | plt.axis('off')
72 |
73 | plt.draw()
74 | figure_numpy = convert_figure_numpy(fig)
75 | plt.close()
76 | return figure_numpy
77 |
78 |
79 | def download_example_data():
80 | from requests import get
81 |
82 | def download(url, file_name):
83 | # open in binary mode
84 | with open(file_name, "wb") as file:
85 | # get request
86 | response = get(url)
87 | # write to file
88 | file.write(response.content)
89 |
90 | os.makedirs(EXAMPLE_DATA_PATH, exist_ok=True)
91 | url_list = ['https://github.com/wayveai/fiery/releases/download/v1.0/example_1.npz',
92 | 'https://github.com/wayveai/fiery/releases/download/v1.0/example_2.npz',
93 | 'https://github.com/wayveai/fiery/releases/download/v1.0/example_3.npz',
94 | 'https://github.com/wayveai/fiery/releases/download/v1.0/example_4.npz'
95 | ]
96 | for url in url_list:
97 | download(url, os.path.join(EXAMPLE_DATA_PATH, os.path.basename(url)))
98 |
99 |
100 | def visualise(checkpoint_path):
101 | trainer = TrainingModule.load_from_checkpoint(checkpoint_path, strict=True)
102 |
103 | device = torch.device('cuda:0')
104 | trainer = trainer.to(device)
105 | trainer.eval()
106 |
107 | # Download example data
108 | download_example_data()
109 | # Load data
110 | for data_path in sorted(glob(os.path.join(EXAMPLE_DATA_PATH, '*.npz'))):
111 | data = np.load(data_path)
112 | image = torch.from_numpy(data['image']).to(device)
113 | intrinsics = torch.from_numpy(data['intrinsics']).to(device)
114 | extrinsics = torch.from_numpy(data['extrinsics']).to(device)
115 | future_egomotions = torch.from_numpy(data['future_egomotion']).to(device)
116 |
117 | # Forward pass
118 | with torch.no_grad():
119 | output = trainer.model(image, intrinsics, extrinsics, future_egomotions)
120 |
121 | figure_numpy = plot_prediction(image, output, trainer.cfg)
122 | os.makedirs('./output_vis', exist_ok=True)
123 | output_filename = os.path.join('./output_vis', os.path.basename(data_path).split('.')[0]) + '.png'
124 | Image.fromarray(figure_numpy).save(output_filename)
125 | print(f'Saved output in {output_filename}')
126 |
127 |
128 | if __name__ == '__main__':
129 | parser = ArgumentParser(description='Fiery visualisation')
130 | parser.add_argument('--checkpoint', default='./fiery.ckpt', type=str, help='path to checkpoint')
131 |
132 | args = parser.parse_args()
133 |
134 | visualise(args.checkpoint)
135 |
--------------------------------------------------------------------------------
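
The trajectory overlay in plot_prediction boils down to drawing each instance's matched centres onto a blank canvas, alpha-blending that canvas with the instance map, and copying back only the pixels that were drawn on. A standalone sketch of the same pattern with synthetic data (no checkpoint or nuScenes data needed):

import cv2
import numpy as np

vis_image = np.full((200, 200, 3), 255, dtype=np.uint8)  # stand-in for the rendered instance map
trajectory_img = np.zeros_like(vis_image)
path = [(20, 20), (60, 80), (120, 110), (180, 160)]      # synthetic trajectory of one instance
for t in range(len(path) - 1):
    cv2.line(trajectory_img, path[t], path[t + 1], (255, 0, 0), 4)

blended = cv2.addWeighted(vis_image, 0.7, trajectory_img, 0.3, 1.0)
mask = ~np.all(trajectory_img == 0, axis=2)              # only overwrite pixels that were drawn on
vis_image[mask] = blended[mask]
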
/stretchbev/models/res_models.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class ConvBlock(nn.Module):
9 | """2D convolution followed by
10 | - an optional normalisation (batch norm or instance norm)
11 | - an optional activation (ReLU, LeakyReLU, or tanh)
12 | """
13 |
14 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, norm='bn', activation='lrelu',
15 | bias=False, transpose=False):
16 | super().__init__()
17 | out_channels = out_channels or in_channels
18 | padding = int((kernel_size - 1) / 2)
19 | self.conv = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d)
20 | self.conv = self.conv(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias)
21 |
22 | if norm == 'bn':
23 | self.norm = nn.BatchNorm2d(out_channels)
24 | elif norm == 'in':
25 | self.norm = nn.InstanceNorm2d(out_channels)
26 | elif norm == 'none':
27 | self.norm = None
28 | else:
29 | raise ValueError('Invalid norm {}'.format(norm))
30 |
31 | if activation == 'relu':
32 | self.activation = nn.ReLU()
33 | elif activation == 'lrelu':
34 | self.activation = nn.LeakyReLU(0.1)
35 | elif activation == 'tanh':
36 | self.activation = nn.Tanh()
37 | elif activation == 'none':
38 | self.activation = None
39 | else:
40 | raise ValueError('Invalid activation {}'.format(activation))
41 |
42 | def forward(self, x):
43 | x = self.conv(x)
44 |
45 | if self.norm:
46 | x = self.norm(x)
47 | if self.activation:
48 | x = self.activation(x)
49 | return x
50 |
51 |
52 | class ResBlock(nn.Module):
53 | """Residual block:
54 | x -> Conv -> norm -> act. -> Conv -> norm -> act. -> ADD -> out
55 | | |
56 | ---------------------------------------------------
57 | """
58 |
59 | def __init__(self, in_channels, out_channels=None, norm='bn', activation='lrelu', bias=False):
60 | super().__init__()
61 | out_channels = out_channels or in_channels
62 |
63 | self.layers = nn.Sequential(OrderedDict([
64 | ('conv_1', ConvBlock(in_channels, in_channels, 3, stride=1, norm=norm, activation=activation, bias=bias)),
65 | ('conv_2', ConvBlock(in_channels, out_channels, 3, stride=1, norm=norm, activation=activation, bias=bias)),
66 | ('dropout', nn.Dropout2d(0.25)),
67 | ]))
68 |
69 | if out_channels != in_channels:
70 | self.projection = nn.Conv2d(in_channels, out_channels, 1)
71 | else:
72 | self.projection = None
73 |
74 | def forward(self, x):
75 | x_residual = self.layers(x)
76 |
77 | if self.projection:
78 | x = self.projection(x)
79 | return x + x_residual
80 |
81 |
82 | class SmallEncoder(nn.Module):
83 | def __init__(self, nc, nh, nf):
84 | super(SmallEncoder, self).__init__()
85 |
86 | self.blocks = nn.ModuleList([
87 | ResBlock(nc, nf),
88 | ResBlock(nf, nf * 2),
89 | ResBlock(nf * 2, nf * 2),
90 | ResBlock(nf * 2, nf * 2),
91 | ResBlock(nf * 2, nf * 4)
92 | ])
93 | self.last_conv = nn.Sequential(
94 | ConvBlock(nf * 4, nh, 3, stride=1, activation='tanh')
95 | )
96 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
97 |
98 | def forward(self, x, return_skip=False):
99 | h = x
100 | skips = []
101 | for i, layer in enumerate(self.blocks):
102 | if i in [1, 2]:
103 | h = self.maxpool(h)
104 | h = layer(h)
105 | skips.append(h)
106 | h = self.last_conv(h)
107 | if return_skip:
108 | return h, skips[::-1]
109 | return h
110 |
111 |
112 | class SmallDecoder(nn.Module):
113 | def __init__(self, nc, nh, nf, skip):
114 | super(SmallDecoder, self).__init__()
115 | coef = 2 if skip else 1
116 | self.skip = skip
117 |
118 | self.first_upconv = ConvBlock(nc, nf * 4, stride=1, transpose=True)
119 |
120 | self.blocks = nn.ModuleList([
121 | ResBlock(nf * 4 * coef, nf * 2),
122 | ResBlock(nf * 2 * coef, nf * 2),
123 | ResBlock(nf * 2 * coef, nf * 2),
124 | ResBlock(nf * 2 * coef, nf),
125 | ResBlock(nf * coef, nf)
126 | ])
127 | self.last_conv = nn.Sequential(
128 | ConvBlock(nf * coef, nf, 3, stride=1),
129 | ConvBlock(nf, nh, 3, stride=1, transpose=True, bias=True, norm='none'),
130 |
131 | )
132 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
133 |
134 | def forward(self, z, skip=None, sigmoid=False):
135 | assert skip is None and not self.skip or self.skip and skip is not None
136 | h = self.first_upconv(z)
137 | for i, layer in enumerate(self.blocks):
138 | # print(i, h.shape, skip[i].shape)
139 | if skip is not None:
140 | h = torch.cat([h, skip[i]], 1)
141 | h = layer(h)
142 | if i in [2, 3]:
143 | h = self.upsample(h)
144 | x_ = h  # note: self.last_conv defined above is not applied in this forward pass
145 | if sigmoid:
146 | x_ = torch.sigmoid(x_)
147 | return x_
148 |
149 |
150 | class SELayer(nn.Module):
151 | def __init__(self, channel, reduction=8):
152 | super(SELayer, self).__init__()
153 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
154 | self.fc = nn.Sequential(
155 | nn.Linear(channel, channel // reduction, bias=False),
156 | nn.ReLU(inplace=True),
157 | nn.Linear(channel // reduction, channel, bias=False),
158 | nn.Sigmoid()
159 | )
160 |
161 | def forward(self, x):
162 | b, c, _, _ = x.size()
163 | y = self.avg_pool(x).view(b, c)
164 | y = self.fc(y).view(b, c, 1, 1)
165 | return x * y.expand_as(x)
166 |
167 |
168 | class ConvNet(nn.Module):
169 | def __init__(self, in_c, out_c, nlayers):  # note: nlayers is not used in this implementation
170 | super(ConvNet, self).__init__()
171 | self.model = nn.Sequential(
172 | ResBlock(in_c, out_c),
173 | SELayer(out_c),
174 | ResBlock(out_c, out_c),
175 | SELayer(out_c),
176 | ConvBlock(out_c, out_c, 3, stride=1, bias=True, norm='none'),
177 | )
178 |
179 | def forward(self, x):
180 | return self.model(x)
181 |
--------------------------------------------------------------------------------
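
A shape sketch for SmallEncoder/SmallDecoder above, with illustrative channel and spatial sizes and assuming the repository root is on PYTHONPATH. The encoder max-pools twice (1/4 resolution) and the decoder's two nearest-neighbour upsamples bring the resolution back:

import torch
from stretchbev.models.res_models import SmallEncoder, SmallDecoder

enc = SmallEncoder(nc=64, nh=32, nf=64)
dec = SmallDecoder(nc=32, nh=2, nf=64, skip=True)

x = torch.randn(1, 64, 200, 200)
h, skips = enc(x, return_skip=True)
print(h.shape)          # torch.Size([1, 32, 50, 50])
y = dec(h, skip=skips)  # forward() never applies self.last_conv, so the output keeps nf channels
print(y.shape)          # torch.Size([1, 64, 200, 200])
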
/stretchbev/models/.ipynb_checkpoints/res_models-checkpoint.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from functools import partial
3 | import torch.nn as nn
4 | import torch
5 |
6 |
7 | class ConvBlock(nn.Module):
8 | """2D convolution followed by
9 | - an optional normalisation (batch norm or instance norm)
10 | - an optional activation (ReLU, LeakyReLU, or tanh)
11 | """
12 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, norm='bn', activation='lrelu',
13 | bias=False, transpose=False):
14 | super().__init__()
15 | out_channels = out_channels or in_channels
16 | padding = int((kernel_size - 1) / 2)
17 | self.conv = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d)
18 | self.conv = self.conv(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias)
19 |
20 | if norm == 'bn':
21 | self.norm = nn.BatchNorm2d(out_channels)
22 | elif norm == 'in':
23 | self.norm = nn.InstanceNorm2d(out_channels)
24 | elif norm == 'none':
25 | self.norm = None
26 | else:
27 | raise ValueError('Invalid norm {}'.format(norm))
28 |
29 | if activation == 'relu':
30 | self.activation = nn.ReLU()
31 | elif activation == 'lrelu':
32 | self.activation = nn.LeakyReLU(0.1)
33 | elif activation == 'tanh':
34 | self.activation = nn.Tanh()
35 | elif activation == 'none':
36 | self.activation = None
37 | else:
38 | raise ValueError('Invalid activation {}'.format(activation))
39 |
40 | def forward(self, x):
41 | x = self.conv(x)
42 |
43 | if self.norm:
44 | x = self.norm(x)
45 | if self.activation:
46 | x = self.activation(x)
47 | return x
48 |
49 |
50 | class ResBlock(nn.Module):
51 | """Residual block:
52 | x -> Conv -> norm -> act. -> Conv -> norm -> act. -> ADD -> out
53 | | |
54 | ---------------------------------------------------
55 | """
56 | def __init__(self, in_channels, out_channels=None, norm='bn', activation='lrelu', bias=False):
57 | super().__init__()
58 | out_channels = out_channels or in_channels
59 |
60 | self.layers = nn.Sequential(OrderedDict([
61 | ('conv_1', ConvBlock(in_channels, in_channels, 3, stride=1, norm=norm, activation=activation, bias=bias)),
62 | ('conv_2', ConvBlock(in_channels, out_channels, 3, stride=1, norm=norm, activation=activation, bias=bias)),
63 | # ('dropout', nn.Dropout2d(0.25)),
64 | ]))
65 |
66 | if out_channels != in_channels:
67 | self.projection = nn.Conv2d(in_channels, out_channels, 1)
68 | else:
69 | self.projection = None
70 |
71 | def forward(self, x):
72 | x_residual = self.layers(x)
73 |
74 | if self.projection:
75 | x = self.projection(x)
76 | return x + x_residual
77 |
78 |
79 |
80 | class SmallEncoder(nn.Module):
81 | def __init__(self, nc, nh, nf):
82 | super(SmallEncoder, self).__init__()
83 |
84 | self.blocks = nn.ModuleList([
85 | ResBlock(nc, nf),
86 | ResBlock(nf, nf*2),
87 | ResBlock(nf*2, nf*2),
88 | ResBlock(nf*2, nf*2),
89 | ResBlock(nf*2, nf*4)
90 | ])
91 | self.last_conv = nn.Sequential(
92 | ConvBlock(nf*4, nh, 3, stride=1, activation='tanh')
93 | )
94 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
95 |
96 | def forward(self, x, return_skip=False):
97 | h = x
98 | skips = []
99 | for i, layer in enumerate(self.blocks):
100 | if i in [1,2]:
101 | h = self.maxpool(h)
102 | h = layer(h)
103 | skips.append(h)
104 | h = self.last_conv(h)
105 | if return_skip:
106 | return h, skips[::-1]
107 | return h
108 |
109 | class SmallDecoder(nn.Module):
110 | def __init__(self, nc, nh, nf, skip):
111 | super(SmallDecoder, self).__init__()
112 | coef = 2 if skip else 1
113 | self.skip = skip
114 |
115 | self.first_upconv = ConvBlock(nc,nf*4, stride=1, transpose=True)
116 |
117 | self.blocks = nn.ModuleList([
118 | ResBlock(nf*4*coef, nf*2),
119 | ResBlock(nf*2*coef, nf*2),
120 | ResBlock(nf*2*coef, nf*2),
121 | ResBlock(nf*2*coef, nf),
122 | ResBlock(nf*coef, nf)
123 | ])
124 | self.last_conv = nn.Sequential(
125 | ConvBlock(nf*coef, nf, 3, stride=1),
126 | ConvBlock(nf, nh, 3, stride=1, transpose=True, bias=True, norm='none'),
127 |
128 | )
129 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
130 |
131 | def forward(self, z, skip=None, sigmoid=False):
132 | assert skip is None and not self.skip or self.skip and skip is not None
133 | h = self.first_upconv(z)
134 | for i, layer in enumerate(self.blocks):
135 | # print(i, h.shape, skip[i].shape)
136 | if skip is not None:
137 | h = torch.cat([h, skip[i]], 1)
138 | h = layer(h)
139 | if i in [2, 3]:
140 | h = self.upsample(h)
141 | x_ = h
142 | if sigmoid:
143 | x_ = torch.sigmoid(x_)
144 | return x_
145 |
146 | class SELayer(nn.Module):
147 | def __init__(self, channel, reduction=8):
148 | super(SELayer, self).__init__()
149 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
150 | self.fc = nn.Sequential(
151 | nn.Linear(channel, channel // reduction, bias=False),
152 | nn.ReLU(inplace=True),
153 | nn.Linear(channel // reduction, channel, bias=False),
154 | nn.Sigmoid()
155 | )
156 |
157 | def forward(self, x):
158 | b, c, _, _ = x.size()
159 | y = self.avg_pool(x).view(b, c)
160 | y = self.fc(y).view(b, c, 1, 1)
161 | return x * y.expand_as(x)
162 |
163 | class ConvNet(nn.Module):
164 | def __init__(self, in_c, out_c, nlayers):
165 | super(ConvNet, self).__init__()
166 | self.model = nn.Sequential(
167 | ResBlock(in_c, out_c),
168 | SELayer(out_c),
169 | ResBlock(out_c, out_c),
170 | SELayer(out_c),
171 | ConvBlock(out_c, out_c, 3, stride=1, bias=True, norm='none'),
172 | )
173 |
174 | def forward(self, x):
175 | return self.model(x)
--------------------------------------------------------------------------------
/stretchbev/layers/convolutions.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class ConvBlock(nn.Module):
9 | """2D convolution followed by
10 | - an optional normalisation (batch norm or instance norm)
11 | - an optional activation (ReLU, LeakyReLU, or tanh)
12 | """
13 |
14 | def __init__(
15 | self,
16 | in_channels,
17 | out_channels=None,
18 | kernel_size=3,
19 | stride=1,
20 | norm='bn',
21 | activation='relu',
22 | bias=False,
23 | transpose=False,
24 | ):
25 | super().__init__()
26 | out_channels = out_channels or in_channels
27 | padding = int((kernel_size - 1) / 2)
28 | self.conv = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d, output_padding=1)
29 | self.conv = self.conv(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias)
30 |
31 | if norm == 'bn':
32 | self.norm = nn.BatchNorm2d(out_channels)
33 | elif norm == 'in':
34 | self.norm = nn.InstanceNorm2d(out_channels)
35 | elif norm == 'none':
36 | self.norm = None
37 | else:
38 | raise ValueError('Invalid norm {}'.format(norm))
39 |
40 | if activation == 'relu':
41 | self.activation = nn.ReLU(inplace=True)
42 | elif activation == 'lrelu':
43 | self.activation = nn.LeakyReLU(0.1, inplace=True)
44 | elif activation == 'elu':
45 | self.activation = nn.ELU(inplace=True)
46 | elif activation == 'tanh':
47 | self.activation = nn.Tanh()  # nn.Tanh takes no inplace argument
48 | elif activation == 'none':
49 | self.activation = None
50 | else:
51 | raise ValueError('Invalid activation {}'.format(activation))
52 |
53 | def forward(self, x):
54 | x = self.conv(x)
55 |
56 | if self.norm:
57 | x = self.norm(x)
58 | if self.activation:
59 | x = self.activation(x)
60 | return x
61 |
62 |
63 | class Bottleneck(nn.Module):
64 | """
65 | Defines a bottleneck module with a residual connection
66 | """
67 |
68 | def __init__(
69 | self,
70 | in_channels,
71 | out_channels=None,
72 | kernel_size=3,
73 | dilation=1,
74 | groups=1,
75 | upsample=False,
76 | downsample=False,
77 | dropout=0.0,
78 | ):
79 | super().__init__()
80 | self._downsample = downsample
81 | bottleneck_channels = int(in_channels / 2)
82 | out_channels = out_channels or in_channels
83 | padding_size = ((kernel_size - 1) * dilation + 1) // 2
84 |
85 | # Define the main conv operation
86 | assert dilation == 1
87 | if upsample:
88 | assert not downsample, 'downsample and upsample not possible simultaneously.'
89 | bottleneck_conv = nn.ConvTranspose2d(
90 | bottleneck_channels,
91 | bottleneck_channels,
92 | kernel_size=kernel_size,
93 | bias=False,
94 | dilation=1,
95 | stride=2,
96 | output_padding=padding_size,
97 | padding=padding_size,
98 | groups=groups,
99 | )
100 | elif downsample:
101 | bottleneck_conv = nn.Conv2d(
102 | bottleneck_channels,
103 | bottleneck_channels,
104 | kernel_size=kernel_size,
105 | bias=False,
106 | dilation=dilation,
107 | stride=2,
108 | padding=padding_size,
109 | groups=groups,
110 | )
111 | else:
112 | bottleneck_conv = nn.Conv2d(
113 | bottleneck_channels,
114 | bottleneck_channels,
115 | kernel_size=kernel_size,
116 | bias=False,
117 | dilation=dilation,
118 | padding=padding_size,
119 | groups=groups,
120 | )
121 |
122 | self.layers = nn.Sequential(
123 | OrderedDict(
124 | [
125 | # First projection with 1x1 kernel
126 | ('conv_down_project', nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False)),
127 | ('abn_down_project', nn.Sequential(nn.BatchNorm2d(bottleneck_channels),
128 | nn.ReLU(inplace=True))),
129 | # Second conv block
130 | ('conv', bottleneck_conv),
131 | ('abn', nn.Sequential(nn.BatchNorm2d(bottleneck_channels), nn.ReLU(inplace=True))),
132 | # Final projection with 1x1 kernel
133 | ('conv_up_project', nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, bias=False)),
134 | ('abn_up_project', nn.Sequential(nn.BatchNorm2d(out_channels),
135 | nn.ReLU(inplace=True))),
136 | # Regulariser
137 | ('dropout', nn.Dropout2d(p=dropout)),
138 | ]
139 | )
140 | )
141 |
142 | if out_channels == in_channels and not downsample and not upsample:
143 | self.projection = None
144 | else:
145 | projection = OrderedDict()
146 | if upsample:
147 | projection.update({'upsample_skip_proj': Interpolate(scale_factor=2)})
148 | elif downsample:
149 | projection.update({'upsample_skip_proj': nn.MaxPool2d(kernel_size=2, stride=2)})
150 | projection.update(
151 | {
152 | 'conv_skip_proj': nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
153 | 'bn_skip_proj': nn.BatchNorm2d(out_channels),
154 | }
155 | )
156 | self.projection = nn.Sequential(projection)
157 |
158 | # pylint: disable=arguments-differ
159 | def forward(self, *args):
160 | (x,) = args
161 | x_residual = self.layers(x)
162 | if self.projection is not None:
163 | if self._downsample:
164 | # pad h/w dimensions if they are odd to prevent shape mismatch with residual layer
165 | x = nn.functional.pad(x, (0, x.shape[-1] % 2, 0, x.shape[-2] % 2), value=0)
166 | return x_residual + self.projection(x)
167 | return x_residual + x
168 |
169 |
170 | class Interpolate(nn.Module):
171 | def __init__(self, scale_factor: int = 2):
172 | super().__init__()
173 | self._interpolate = nn.functional.interpolate
174 | self._scale_factor = scale_factor
175 |
176 | # pylint: disable=arguments-differ
177 | def forward(self, x):
178 | return self._interpolate(x, scale_factor=self._scale_factor, mode='bilinear', align_corners=False)
179 |
180 |
181 | class UpsamplingConcat(nn.Module):
182 | def __init__(self, in_channels, out_channels, scale_factor=2):
183 | super().__init__()
184 |
185 | self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)
186 |
187 | self.conv = nn.Sequential(
188 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
189 | nn.BatchNorm2d(out_channels),
190 | nn.ReLU(inplace=True),
191 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
192 | nn.BatchNorm2d(out_channels),
193 | nn.ReLU(inplace=True),
194 | )
195 |
196 | def forward(self, x_to_upsample, x):
197 | x_to_upsample = self.upsample(x_to_upsample)
198 | x_to_upsample = torch.cat([x, x_to_upsample], dim=1)
199 | return self.conv(x_to_upsample)
200 |
201 |
202 | class UpsamplingAdd(nn.Module):
203 | def __init__(self, in_channels, out_channels, scale_factor=2):
204 | super().__init__()
205 | self.upsample_layer = nn.Sequential(
206 | nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
207 | nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False),
208 | nn.BatchNorm2d(out_channels),
209 | )
210 |
211 | def forward(self, x, x_skip):
212 | x = self.upsample_layer(x)
213 | return x + x_skip
214 |
--------------------------------------------------------------------------------
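
A quick shape check for the Bottleneck block above, with synthetic sizes and assuming the repository root is on PYTHONPATH. With downsample=True the main branch strides by 2 while the skip path is max-pooled; with upsample=True a transposed convolution (and a bilinear skip interpolation) doubles the resolution:

import torch
from stretchbev.layers.convolutions import Bottleneck

x = torch.randn(1, 64, 50, 50)
down = Bottleneck(64, out_channels=128, downsample=True)
up = Bottleneck(64, out_channels=32, upsample=True)
print(down(x).shape)  # torch.Size([1, 128, 25, 25])
print(up(x).shape)    # torch.Size([1, 32, 100, 100])
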
/stretchbev/models/srvp_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import torch.nn as nn
4 |
5 |
6 | def activation_factory(name):
7 | """
8 | Returns the activation layer corresponding to the input activation name.
9 | Parameters
10 | ----------
11 | name : str
12 | 'relu', 'leaky_relu', 'elu', 'sigmoid', or 'tanh'. Adds the corresponding activation function after the
13 | convolution.
14 | Returns
15 | -------
16 | torch.nn.Module
17 | Element-wise activation layer.
18 | """
19 | if name == 'relu':
20 | return nn.ReLU(inplace=True)
21 | if name == 'leaky_relu':
22 | return nn.LeakyReLU(0.2, inplace=True)
23 | if name == 'elu':
24 | return nn.ELU(inplace=True)
25 | if name == 'sigmoid':
26 | return nn.Sigmoid()
27 | if name == 'tanh':
28 | return nn.Tanh()
29 | raise ValueError(f'Activation function \'{name}\' not yet implemented')
30 |
31 |
32 | def make_conv_block(conv, activation, bn=True):
33 | """
34 | Supplements a convolutional block with activation functions and batch normalization.
35 | Parameters
36 | ----------
37 | conv : torch.nn.Module
38 | Convolutional block.
39 | activation : str
40 | 'relu', 'leaky_relu', 'elu', 'sigmoid', 'tanh', or 'none'. Adds the corresponding activation function, or no
41 | activation if 'none' is chosen, after the convolution.
42 | bn : bool
43 | Whether to add batch normalization after the activation.
44 | Returns
45 | -------
46 | torch.nn.Sequential
47 | Sequence of the input convolutional block, the potentially chosen activation function, and the potential batch
48 | normalization.
49 | """
50 | out_channels = conv.out_channels
51 | modules = [conv]
52 | if bn:
53 | modules.append(nn.BatchNorm2d(out_channels))
54 | if activation != 'none':
55 | modules.append(activation_factory(activation))
56 | return nn.Sequential(*modules)
57 |
58 |
59 | class VGG64Encoder(nn.Module):
60 | # a note: we are downsampling to 1/8 size
61 | # hard coded for now
62 | """
63 | Module implementing the VGG encoder.
64 | """
65 |
66 | def __init__(self, nc, nh, nf):
67 | """
68 | Parameters
69 | ----------
70 | nc : int
71 | Number of channels in the input data.
72 | nh : int
73 | Number of dimensions of the output flat vector.
74 | nf : int
75 | Number of filters per channel of the first convolution.
76 | """
77 | super(VGG64Encoder, self).__init__()
78 | self.conv = nn.ModuleList([
79 | nn.Sequential(
80 | make_conv_block(nn.Conv2d(nc, nf, 3, 1, 1, bias=False), activation='leaky_relu'),
81 | make_conv_block(nn.Conv2d(nf, nf, 3, 1, 1, bias=False), activation='leaky_relu'),
82 | ),
83 | nn.Sequential(
84 | # nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
85 | make_conv_block(nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
86 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
87 | ),
88 | nn.Sequential(
89 | # nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
90 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
91 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
92 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
93 | ),
94 | nn.Sequential(
95 | # nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
96 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
97 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
98 | make_conv_block(nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False), activation='leaky_relu'),
99 | )
100 | ])
101 | self.last_conv = nn.Sequential(
102 | make_conv_block(nn.Conv2d(nf * 4, nh, 3, 1, 1, bias=False), activation='tanh')
103 | )
104 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
105 |
106 | def forward(self, x, return_skip=False):
107 | """
108 | Parameters
109 | ----------
110 | x : torch.*.Tensor
111 | Encoder input.
112 | return_skip : bool
113 | Whether to extract and return, besides the network output, skip connections.
114 | Returns
115 | -------
116 | torch.*.Tensor
117 | Encoder output as a tensor of shape (batch, size).
118 | list
119 | Only if return_skip is True. List of skip connections represented as torch.*.Tensor corresponding to each
120 | convolutional block in reverse order (from the deepest to the shallowest convolutional block).
121 | """
122 | skips = []
123 | h = x
124 | for i, layer in enumerate(self.conv):
125 |
126 | if i in [1, 2]:
127 | h = self.maxpool(h)
128 | h_res = layer(h)
129 | # print(i, h.shape, h_res.shape)  # debug print disabled
130 | h = h + h_res
131 | skips.append(h)
132 | h = self.last_conv(h)
133 | if return_skip:
134 | return h, skips[::-1]
135 | return h
136 |
137 |
138 | class VGG64Decoder(nn.Module):
139 | # a note: we are upsampling to 1/8 size
140 | # hard coded for now
141 | """
142 | Module implementing the VGG decoder.
143 | """
144 |
145 | def __init__(self, nc, ny, nf, skip):
146 | """
147 | Parameters
148 | ----------
149 | nc : int
150 | Number of channels in the output shape.
151 | ny : int
152 | Number of dimensions of the input flat vector.
153 | nf : int
154 | Number of filters per channel of the first convolution of the mirror encoder architecture.
155 | skip : list
156 | List of torch.*.Tensor representing skip connections in the same order as the decoder convolutional
157 | blocks. Must be None when skip connections are not allowed.
158 | """
159 | super(VGG64Decoder, self).__init__()
160 | # decoder
161 | coef = 2 if skip else 1
162 | self.skip = skip
163 | self.first_upconv = nn.Sequential(
164 | make_conv_block(nn.ConvTranspose2d(ny, nf * 4, 3, 1, 1, bias=False), activation='leaky_relu'),
165 | )
166 | self.conv = nn.ModuleList([
167 | nn.Sequential(
168 | make_conv_block(nn.Conv2d(nf * 4 * coef, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
169 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
170 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
171 | # nn.Upsample(scale_factor=2, mode='nearest'),
172 | ),
173 | nn.Sequential(
174 | make_conv_block(nn.Conv2d(nf * 2 * coef, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
175 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
176 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
177 | # nn.Upsample(scale_factor=2, mode='nearest'),
178 | ),
179 | nn.Sequential(
180 | make_conv_block(nn.Conv2d(nf * 2 * coef, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
181 | make_conv_block(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=False), activation='leaky_relu'),
182 | # nn.Upsample(scale_factor=2, mode='nearest'),
183 | ),
184 | nn.Sequential(
185 | make_conv_block(nn.Conv2d(nf * coef, nf, 3, 1, 1, bias=False), activation='leaky_relu'),
186 | nn.ConvTranspose2d(nf, nc, 3, 1, 1, bias=False),
187 | ),
188 | ])
189 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
190 |
191 | def forward(self, z, skip=None, sigmoid=False):
192 | """
193 | Parameters
194 | ----------
195 | z : torch.*.Tensor
196 | Decoder input.
197 | skip : list
198 | List of torch.*.Tensor representing skip connections in the same order as the decoder convolutional
199 | blocks. Must be None when skip connections are not allowed.
200 | sigmoid : bool
201 | Whether to apply a sigmoid at the end of the decoder.
202 | Returns
203 | -------
204 | torch.*.Tensor
205 | Decoder output as a frame of shape (batch, channels, width, height).
206 | """
207 | assert skip is None and not self.skip or self.skip and skip is not None
208 | h = self.first_upconv(z)
209 | for i, layer in enumerate(self.conv):
210 | if skip is not None:
211 | h = torch.cat([h, skip[i]], 1)
212 | h_res = layer(h)
213 | h = h + h_res
214 | if i in [1, 2]:
215 | h = self.upsample(h)
216 | x_ = h
217 | if sigmoid:
218 | x_ = torch.sigmoid(x_)
219 | return x_
220 |
221 |
222 | class SELayer(nn.Module):
223 | def __init__(self, channel, reduction=8):
224 | super(SELayer, self).__init__()
225 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
226 | self.fc = nn.Sequential(
227 | nn.Linear(channel, channel // reduction, bias=False),
228 | nn.ReLU(inplace=True),
229 | nn.Linear(channel // reduction, channel, bias=False),
230 | nn.Sigmoid()
231 | )
232 |
233 | def forward(self, x):
234 | b, c, _, _ = x.size()
235 | y = self.avg_pool(x).view(b, c)
236 | y = self.fc(y).view(b, c, 1, 1)
237 | return x * y.expand_as(x)
238 |
239 |
240 | class ConvNet(nn.Module):
241 | def __init__(self, in_channels, out_channels, nlayers):
242 | super(ConvNet, self).__init__()
243 |
244 | layers = []
245 | in_c = in_channels
246 | for _ in range(nlayers - 1):
247 | layers += [
248 | make_conv_block(nn.Conv2d(in_c, out_channels, 3, 1, 1, bias=False), activation='leaky_relu')
249 | ]
250 | in_c = out_channels
251 | layers += [SELayer(in_c)]
252 | layers += [make_conv_block(nn.Conv2d(in_c, out_channels, 3, 1, 1, bias=True), activation='none', bn=False)]
253 | self.model = nn.Sequential(*layers)
254 |
255 | def forward(self, x):
256 | return self.model(x)
257 |
--------------------------------------------------------------------------------
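
A tiny sketch of the building blocks above, with illustrative sizes and assuming the repository root is on PYTHONPATH: make_conv_block wraps a convolution with batch normalisation and an activation, and ConvNet stacks nlayers - 1 such blocks followed by an SE layer and a plain output convolution:

import torch
import torch.nn as nn
from stretchbev.models.srvp_models import make_conv_block, ConvNet

block = make_conv_block(nn.Conv2d(16, 32, 3, 1, 1, bias=False), activation='leaky_relu')
net = ConvNet(in_channels=16, out_channels=32, nlayers=3)

x = torch.randn(2, 16, 32, 32)
print(block(x).shape, net(x).shape)  # both torch.Size([2, 32, 32, 32])
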
/stretchbev/models/.ipynb_checkpoints/srvp_models-checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import torch.distributions as distrib
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | def activation_factory(name):
9 | """
10 | Returns the activation layer corresponding to the input activation name.
11 | Parameters
12 | ----------
13 | name : str
14 | 'relu', 'leaky_relu', 'elu', 'sigmoid', or 'tanh'. Adds the corresponding activation function after the
15 | convolution.
16 | Returns
17 | -------
18 | torch.nn.Module
19 | Element-wise activation layer.
20 | """
21 | if name == 'relu':
22 | return nn.ReLU(inplace=True)
23 | if name == 'leaky_relu':
24 | return nn.LeakyReLU(0.2, inplace=True)
25 | if name == 'elu':
26 | return nn.ELU(inplace=True)
27 | if name == 'sigmoid':
28 | return nn.Sigmoid()
29 | if name == 'tanh':
30 | return nn.Tanh()
31 | raise ValueError(f'Activation function \'{name}\' not yet implemented')
32 |
33 |
34 |
35 | def make_conv_block(conv, activation, bn=True):
36 | """
37 | Supplements a convolutional block with activation functions and batch normalization.
38 | Parameters
39 | ----------
40 | conv : torch.nn.Module
41 | Convolutional block.
42 | activation : str
43 | 'relu', 'leaky_relu', 'elu', 'sigmoid', 'tanh', or 'none'. Adds the corresponding activation function, or no
44 | activation if 'none' is chosen, after the convolution.
45 | bn : bool
46 | Whether to add batch normalization after the activation.
47 | Returns
48 | -------
49 | torch.nn.Sequential
50 | Sequence of the input convolutional block, the potentially chosen activation function, and the potential batch
51 | normalization.
52 | """
53 | out_channels = conv.out_channels
54 | modules = [conv]
55 | if bn:
56 | modules.append(nn.BatchNorm2d(out_channels))
57 | if activation != 'none':
58 | modules.append(activation_factory(activation))
59 | return nn.Sequential(*modules)
60 |
61 |
62 |
63 |
64 | class VGG64Encoder(nn.Module):
65 | # a note: we are downsampling to 1/8 size
66 | # hard coded for now
67 | """
68 | Module implementing the VGG encoder.
69 | """
70 | def __init__(self, nc, nh, nf):
71 | """
72 | Parameters
73 | ----------
74 | nc : int
75 | Number of channels in the input data.
76 | nh : int
77 | Number of dimensions of the output flat vector.
78 | nf : int
79 | Number of filters per channel of the first convolution.
80 | """
81 | super(VGG64Encoder, self).__init__()
82 | self.conv = nn.ModuleList([
83 | nn.Sequential(
84 | make_conv_block(nn.Conv2d(nc, nf, 3, 1, 1, bias=False), activation='leaky_relu'),
85 | make_conv_block(nn.Conv2d(nf, nf, 3, 1, 1, bias=False), activation='leaky_relu'),
86 | ),
87 | nn.Sequential(
88 | # nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
89 | make_conv_block(nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
90 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
91 | ),
92 | nn.Sequential(
93 | # nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
94 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
95 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
96 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
97 | ),
98 | nn.Sequential(
99 | # nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
100 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
101 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
102 | make_conv_block(nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False), activation='leaky_relu'),
103 | )
104 | ])
105 | self.last_conv = nn.Sequential(
106 | make_conv_block(nn.Conv2d(nf * 4, nh, 3, 1, 1, bias=False), activation='tanh')
107 | )
108 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
109 |
110 | def forward(self, x, return_skip=False):
111 | """
112 | Parameters
113 | ----------
114 | x : torch.*.Tensor
115 | Encoder input.
116 | return_skip : bool
117 | Whether to extract and return, besides the network output, skip connections.
118 | Returns
119 | -------
120 | torch.*.Tensor
121 | Encoder output as a tensor of shape (batch, size).
122 | list
123 | Only if return_skip is True. List of skip connections represented as torch.*.Tensor corresponding to each
124 | convolutional block in reverse order (from the deepest to the shallowest convolutional block).
125 | """
126 | skips = []
127 | h = x
128 | for i,layer in enumerate(self.conv):
129 |
130 | if i in [1,2]:
131 | h = self.maxpool(h)
132 | h_res = layer(h)
133 | print(i, h.shape, h_res.shape)
134 | h = h + h_res
135 | skips.append(h)
136 | h = self.last_conv(h)
137 | if return_skip:
138 | return h, skips[::-1]
139 | return h
140 |
141 |
142 |
143 | class VGG64Decoder(nn.Module):
144 | # a note: we are upsampling to 1/8 size
145 | # hard coded for now
146 | """
147 | Module implementing the VGG decoder.
148 | """
149 | def __init__(self, nc, ny, nf, skip):
150 | """
151 | Parameters
152 | ----------
153 |         nc : int
154 |             Number of channels in the output.
155 |         ny : int
156 |             Number of channels of the input feature map.
157 |         nf : int
158 |             Number of filters per channel of the first convolution of the mirror encoder architecture.
159 |         skip : bool
160 |             Whether the decoder expects skip connections from a mirror encoder (the input channels of each
161 |             convolutional block are doubled when enabled).
162 | """
163 | super(VGG64Decoder, self).__init__()
164 | # decoder
165 | coef = 2 if skip else 1
166 | self.skip = skip
167 | self.first_upconv = nn.Sequential(
168 | make_conv_block(nn.ConvTranspose2d(ny, nf * 4, 3, 1, 1, bias=False), activation='leaky_relu'),
169 | )
170 | self.conv = nn.ModuleList([
171 | nn.Sequential(
172 | make_conv_block(nn.Conv2d(nf * 4 * coef, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
173 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
174 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
175 | # nn.Upsample(scale_factor=2, mode='nearest'),
176 | ),
177 | nn.Sequential(
178 | make_conv_block(nn.Conv2d(nf * 2 * coef, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
179 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
180 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
181 | # nn.Upsample(scale_factor=2, mode='nearest'),
182 | ),
183 | nn.Sequential(
184 | make_conv_block(nn.Conv2d(nf * 2 * coef, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'),
185 | make_conv_block(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=False), activation='leaky_relu'),
186 | # nn.Upsample(scale_factor=2, mode='nearest'),
187 | ),
188 | nn.Sequential(
189 | make_conv_block(nn.Conv2d(nf * coef, nf, 3, 1, 1, bias=False), activation='leaky_relu'),
190 | nn.ConvTranspose2d(nf, nc, 3, 1, 1, bias=False),
191 | ),
192 | ])
193 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
194 |
195 | def forward(self, z, skip=None, sigmoid=False):
196 | """
197 | Parameters
198 | ----------
199 | z : torch.*.Tensor
200 | Decoder input.
201 | skip : list
202 | List of torch.*.Tensor representing skip connections in the same order as the decoder convolutional
203 | blocks. Must be None when skip connections are not allowed.
204 | sigmoid : bool
205 | Whether to apply a sigmoid at the end of the decoder.
206 | Returns
207 | -------
208 | torch.*.Tensor
209 |             Decoder output as a frame of shape (batch, channels, height, width).
210 | """
211 |         assert (skip is None and not self.skip) or (self.skip and skip is not None)
212 | h = self.first_upconv(z)
213 | for i, layer in enumerate(self.conv):
214 | if skip is not None:
215 | h = torch.cat([h, skip[i]], 1)
216 | h_res = layer(h)
217 | h = h + h_res
218 | if i in [1,2]:
219 | h = self.upsample(h)
220 | x_ = h
221 | if sigmoid:
222 | x_ = torch.sigmoid(x_)
223 | return x_
224 |
225 |
226 | class SELayer(nn.Module):
227 | def __init__(self, channel, reduction=8):
228 | super(SELayer, self).__init__()
229 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
230 | self.fc = nn.Sequential(
231 | nn.Linear(channel, channel // reduction, bias=False),
232 | nn.ReLU(inplace=True),
233 | nn.Linear(channel // reduction, channel, bias=False),
234 | nn.Sigmoid()
235 | )
236 |
237 | def forward(self, x):
238 | b, c, _, _ = x.size()
239 | y = self.avg_pool(x).view(b, c)
240 | y = self.fc(y).view(b, c, 1, 1)
241 | return x * y.expand_as(x)
242 |
243 |
244 | class ConvNet(nn.Module):
245 | def __init__(self, in_channels, out_channels, nlayers):
246 | super(ConvNet, self).__init__()
247 |
248 | layers = []
249 | in_c = in_channels
250 | for _ in range(nlayers - 1):
251 | layers += [
252 | make_conv_block(nn.Conv2d(in_c, out_channels, 3, 1, 1, bias=False), activation='leaky_relu')
253 | ]
254 | in_c = out_channels
255 | layers += [SELayer(in_c)]
256 | layers += [make_conv_block(nn.Conv2d(in_c, out_channels, 3, 1, 1, bias=True), activation='none', bn=False)]
257 | self.model = nn.Sequential(*layers)
258 |
259 | def forward(self, x):
260 | return self.model(x)
--------------------------------------------------------------------------------
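A quick note on the blocks above: ConvNet stacks nlayers - 1 conv/BN/LeakyReLU blocks built with make_conv_block, inserts a squeeze-and-excitation gate (SELayer), and finishes with a plain 3x3 convolution without normalisation, so the spatial resolution is preserved throughout. A minimal shape check, not part of the repository, with illustrative channel sizes and assuming the module above is importable:

    import torch

    # Illustrative sizes only: 3x3 convs with padding=1 keep H and W unchanged.
    net = ConvNet(in_channels=64, out_channels=128, nlayers=3)
    x = torch.randn(2, 64, 25, 25)   # (batch, channels, H, W)
    y = net(x)                       # -> (2, 128, 25, 25)
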
/stretchbev/utils/geometry.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | import numpy as np
3 | import torch
4 |
5 | from pyquaternion import Quaternion
6 |
7 |
8 | def resize_and_crop_image(img, resize_dims, crop):
9 | # Bilinear resizing followed by cropping
10 | img = img.resize(resize_dims, resample=PIL.Image.BILINEAR)
11 | img = img.crop(crop)
12 | return img
13 |
14 |
15 | def update_intrinsics(intrinsics, top_crop=0.0, left_crop=0.0, scale_width=1.0, scale_height=1.0):
16 | """
17 | Parameters
18 | ----------
19 | intrinsics: torch.Tensor (3, 3)
20 | top_crop: float
21 | left_crop: float
22 | scale_width: float
23 | scale_height: float
24 | """
25 | updated_intrinsics = intrinsics.clone()
26 | # Adjust intrinsics scale due to resizing
27 | updated_intrinsics[0, 0] *= scale_width
28 | updated_intrinsics[0, 2] *= scale_width
29 | updated_intrinsics[1, 1] *= scale_height
30 | updated_intrinsics[1, 2] *= scale_height
31 |
32 | # Adjust principal point due to cropping
33 | updated_intrinsics[0, 2] -= left_crop
34 | updated_intrinsics[1, 2] -= top_crop
35 |
36 | return updated_intrinsics
37 |
38 |
39 | def calculate_birds_eye_view_parameters(x_bounds, y_bounds, z_bounds):
40 | """
41 | Parameters
42 | ----------
43 |     x_bounds: [min, max, step] along the forward direction of the ego-car.
44 |     y_bounds: [min, max, step] along the sides of the ego-car.
45 |     z_bounds: [min, max, step] along the height axis.
46 | 
47 |     Returns
48 |     -------
49 |     bev_resolution: Bird's-eye view grid resolution (cell size along each axis).
50 |     bev_start_position: Bird's-eye view first element (centre of the first cell along each axis).
51 |     bev_dimension: Bird's-eye view tensor spatial dimensions (number of cells along each axis).
52 | """
53 | bev_resolution = torch.tensor([row[2] for row in [x_bounds, y_bounds, z_bounds]])
54 | bev_start_position = torch.tensor([row[0] + row[2] / 2.0 for row in [x_bounds, y_bounds, z_bounds]])
55 | bev_dimension = torch.tensor([(row[1] - row[0]) / row[2] for row in [x_bounds, y_bounds, z_bounds]],
56 | dtype=torch.long)
57 |
58 | return bev_resolution, bev_start_position, bev_dimension
59 |
60 |
61 | def convert_egopose_to_matrix_numpy(egopose):
62 | transformation_matrix = np.zeros((4, 4), dtype=np.float32)
63 | rotation = Quaternion(egopose['rotation']).rotation_matrix
64 | translation = np.array(egopose['translation'])
65 | transformation_matrix[:3, :3] = rotation
66 | transformation_matrix[:3, 3] = translation
67 | transformation_matrix[3, 3] = 1.0
68 | return transformation_matrix
69 |
70 |
71 | def invert_matrix_egopose_numpy(egopose):
72 | """ Compute the inverse transformation of a 4x4 egopose numpy matrix."""
73 | inverse_matrix = np.zeros((4, 4), dtype=np.float32)
74 | rotation = egopose[:3, :3]
75 | translation = egopose[:3, 3]
76 | inverse_matrix[:3, :3] = rotation.T
77 | inverse_matrix[:3, 3] = -np.dot(rotation.T, translation)
78 | inverse_matrix[3, 3] = 1.0
79 | return inverse_matrix
80 |
81 |
82 | def mat2pose_vec(matrix: torch.Tensor):
83 | """
84 | Converts a 4x4 pose matrix into a 6-dof pose vector
85 | Args:
86 | matrix (ndarray): 4x4 pose matrix
87 | Returns:
88 | vector (ndarray): 6-dof pose vector comprising translation components (tx, ty, tz) and
89 | rotation components (rx, ry, rz)
90 | """
91 |
92 | # M[1, 2] = -sinx*cosy, M[2, 2] = +cosx*cosy
93 | rotx = torch.atan2(-matrix[..., 1, 2], matrix[..., 2, 2])
94 |
95 | # M[0, 2] = +siny, M[1, 2] = -sinx*cosy, M[2, 2] = +cosx*cosy
96 | cosy = torch.sqrt(matrix[..., 1, 2] ** 2 + matrix[..., 2, 2] ** 2)
97 | roty = torch.atan2(matrix[..., 0, 2], cosy)
98 |
99 | # M[0, 0] = +cosy*cosz, M[0, 1] = -cosy*sinz
100 | rotz = torch.atan2(-matrix[..., 0, 1], matrix[..., 0, 0])
101 |
102 | rotation = torch.stack((rotx, roty, rotz), dim=-1)
103 |
104 | # Extract translation params
105 | translation = matrix[..., :3, 3]
106 | return torch.cat((translation, rotation), dim=-1)
107 |
108 |
109 | def euler2mat(angle: torch.Tensor):
110 | """Convert euler angles to rotation matrix.
111 | Reference: https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174
112 | Args:
113 | angle: rotation angle along 3 axis (in radians) [Bx3]
114 | Returns:
115 | Rotation matrix corresponding to the euler angles [Bx3x3]
116 | """
117 | shape = angle.shape
118 | angle = angle.view(-1, 3)
119 | x, y, z = angle[:, 0], angle[:, 1], angle[:, 2]
120 |
121 | cosz = torch.cos(z)
122 | sinz = torch.sin(z)
123 |
124 | zeros = torch.zeros_like(z)
125 | ones = torch.ones_like(z)
126 | zmat = torch.stack([cosz, -sinz, zeros, sinz, cosz, zeros, zeros, zeros, ones], dim=1).view(-1, 3, 3)
127 |
128 | cosy = torch.cos(y)
129 | siny = torch.sin(y)
130 |
131 | ymat = torch.stack([cosy, zeros, siny, zeros, ones, zeros, -siny, zeros, cosy], dim=1).view(-1, 3, 3)
132 |
133 | cosx = torch.cos(x)
134 | sinx = torch.sin(x)
135 |
136 | xmat = torch.stack([ones, zeros, zeros, zeros, cosx, -sinx, zeros, sinx, cosx], dim=1).view(-1, 3, 3)
137 |
138 | rot_mat = xmat.bmm(ymat).bmm(zmat)
139 | rot_mat = rot_mat.view(*shape[:-1], 3, 3)
140 | return rot_mat
141 |
142 |
143 | def pose_vec2mat(vec: torch.Tensor):
144 | """
145 | Convert 6DoF parameters to transformation matrix.
146 | Args:
147 | vec: 6DoF parameters in the order of tx, ty, tz, rx, ry, rz [B,6]
148 | Returns:
149 | A transformation matrix [B,4,4]
150 | """
151 | translation = vec[..., :3].unsqueeze(-1) # [...x3x1]
152 | rot = vec[..., 3:].contiguous() # [...x3]
153 | rot_mat = euler2mat(rot) # [...,3,3]
154 | transform_mat = torch.cat([rot_mat, translation], dim=-1) # [...,3,4]
155 | transform_mat = torch.nn.functional.pad(transform_mat, [0, 0, 0, 1], value=0) # [...,4,4]
156 | transform_mat[..., 3, 3] = 1.0
157 | return transform_mat
158 |
159 |
160 | def invert_pose_matrix(x):
161 | """
162 | Parameters
163 | ----------
164 | x: [B, 4, 4] batch of pose matrices
165 |
166 | Returns
167 | -------
168 | out: [B, 4, 4] batch of inverse pose matrices
169 | """
170 | assert len(x.shape) == 3 and x.shape[1:] == (4, 4), 'Only works for batch of pose matrices.'
171 |
172 | transposed_rotation = torch.transpose(x[:, :3, :3], 1, 2)
173 | translation = x[:, :3, 3:]
174 |
175 | inverse_mat = torch.cat([transposed_rotation, -torch.bmm(transposed_rotation, translation)], dim=-1) # [B,3,4]
176 | inverse_mat = torch.nn.functional.pad(inverse_mat, [0, 0, 0, 1], value=0) # [B,4,4]
177 | inverse_mat[..., 3, 3] = 1.0
178 | return inverse_mat
179 |
180 |
181 | def warp_features(x, flow, mode='nearest', spatial_extent=None):
182 | """ Applies a rotation and translation to feature map x.
183 | Args:
184 | x: (b, c, h, w) feature map
185 |         flow: (b, 6) 6DoF pose vector (only the x-y translation and z-rotation components are used)
186 | mode: use 'nearest' when dealing with categorical inputs
187 | Returns:
188 | in plane transformed feature map
189 | """
190 | if flow is None:
191 | return x
192 | b, c, h, w = x.shape
193 | # z-rotation
194 | angle = flow[:, 5].clone() # torch.atan2(flow[:, 1, 0], flow[:, 0, 0])
195 | # x-y translation
196 | translation = flow[:, :2].clone() # flow[:, :2, 3]
197 |
198 |     # Normalise translation: we need to divide by the number of meters spanned by half of the image,
199 |     # because a translation of 1.0 corresponds to a translation of half of the image.
200 | translation[:, 0] /= spatial_extent[0]
201 | translation[:, 1] /= spatial_extent[1]
202 | # forward axis is inverted
203 | translation[:, 0] *= -1
204 |
205 | cos_theta = torch.cos(angle)
206 | sin_theta = torch.sin(angle)
207 |
208 | # output = Rot.input + translation
209 | # tx and ty are inverted as is the case when going from real coordinates to numpy coordinates
210 | # translation_pos_0 -> positive value makes the image move to the left
211 | # translation_pos_1 -> positive value makes the image move to the top
212 | # Angle -> positive value in rad makes the image move in the trigonometric way
213 | transformation = torch.stack([cos_theta, -sin_theta, translation[:, 1],
214 | sin_theta, cos_theta, translation[:, 0]], dim=-1).view(b, 2, 3)
215 |
216 |     # Note that a rotation will preserve distances only if height = width. Otherwise there's
217 |     # resizing going on, e.g. a rotation of pi/2 applied to a 100x200 image will elongate
218 |     # whatever is in the center of the image.
219 | grid = torch.nn.functional.affine_grid(transformation, size=x.shape, align_corners=False)
220 | warped_x = torch.nn.functional.grid_sample(x, grid.float(), mode=mode, padding_mode='zeros', align_corners=False)
221 |
222 | return warped_x
223 |
224 |
225 | def cumulative_warp_features(x, flow, mode='nearest', spatial_extent=None):
226 | """ Warps a sequence of feature maps by accumulating incremental 2d flow.
227 |
228 | x[:, -1] remains unchanged
229 | x[:, -2] is warped using flow[:, -2]
230 | x[:, -3] is warped using flow[:, -3] @ flow[:, -2]
231 | ...
232 | x[:, 0] is warped using flow[:, 0] @ ... @ flow[:, -3] @ flow[:, -2]
233 |
234 | Args:
235 | x: (b, t, c, h, w) sequence of feature maps
236 | flow: (b, t, 6) sequence of 6 DoF pose
237 |             from t to t+1 (only uses the xy portion)
238 |
239 | """
240 | sequence_length = x.shape[1]
241 | if sequence_length == 1:
242 | return x
243 |
244 | flow = pose_vec2mat(flow)
245 |
246 | out = [x[:, -1]]
247 | cum_flow = flow[:, -2]
248 | for t in reversed(range(sequence_length - 1)):
249 | out.append(warp_features(x[:, t], mat2pose_vec(cum_flow), mode=mode, spatial_extent=spatial_extent))
250 | # @ is the equivalent of torch.bmm
251 | cum_flow = flow[:, t - 1] @ cum_flow
252 |
253 | return torch.stack(out[::-1], 1)
254 |
255 |
256 | def cumulative_warp_features_reverse(x, flow, mode='nearest', spatial_extent=None):
257 | """ Warps a sequence of feature maps by accumulating incremental 2d flow.
258 |
259 | x[:, 0] remains unchanged
260 | x[:, 1] is warped using flow[:, 0].inverse()
261 | x[:, 2] is warped using flow[:, 0].inverse() @ flow[:, 1].inverse()
262 | ...
263 |
264 | Args:
265 | x: (b, t, c, h, w) sequence of feature maps
266 | flow: (b, t, 6) sequence of 6 DoF pose
267 |             from t to t+1 (only uses the xy portion)
268 |
269 | """
270 | flow = pose_vec2mat(flow)
271 |
272 | out = [x[:, 0]]
273 |
274 | for i in range(1, x.shape[1]):
275 | if i == 1:
276 | cum_flow = invert_pose_matrix(flow[:, 0])
277 | else:
278 | cum_flow = cum_flow @ invert_pose_matrix(flow[:, i - 1])
279 | out.append(warp_features(x[:, i], mat2pose_vec(cum_flow), mode, spatial_extent=spatial_extent))
280 | return torch.stack(out, 1)
281 |
282 |
283 | class VoxelsSumming(torch.autograd.Function):
284 | """Adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/tools.py#L193"""
285 |
286 | @staticmethod
287 | def forward(ctx, x, geometry, ranks):
288 | """The features `x` and `geometry` are ranked by voxel positions."""
289 | # Cumulative sum of all features.
290 | x = x.cumsum(0)
291 |
292 | # Indicates the change of voxel.
293 | mask = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
294 | mask[:-1] = ranks[1:] != ranks[:-1]
295 |
296 | x, geometry = x[mask], geometry[mask]
297 | # Calculate sum of features within a voxel.
298 | x = torch.cat((x[:1], x[1:] - x[:-1]))
299 |
300 | ctx.save_for_backward(mask)
301 | ctx.mark_non_differentiable(geometry)
302 |
303 | return x, geometry
304 |
305 | @staticmethod
306 | def backward(ctx, grad_x, grad_geometry):
307 | (mask,) = ctx.saved_tensors
308 | # Since the operation is summing, we simply need to send gradient
309 | # to all elements that were part of the summation process.
310 | indices = torch.cumsum(mask, 0)
311 | indices[mask] -= 1
312 |
313 | output_grad = grad_x[indices]
314 |
315 | return output_grad, None, None
316 |
--------------------------------------------------------------------------------
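The geometry helpers above are easiest to read with concrete numbers: calculate_birds_eye_view_parameters turns per-axis [min, max, step] bounds into grid resolution, start position and size, while pose_vec2mat / mat2pose_vec convert between 6-DoF vectors and the 4x4 matrices that warp_features and the cumulative warping functions consume. A small sketch with made-up bounds and poses (illustrative values, not the repository's configs):

    import torch

    # 100 m x 100 m grid at 0.5 m resolution, a single 20 m height slab (illustrative bounds).
    res, start, dim = calculate_birds_eye_view_parameters(
        x_bounds=[-50.0, 50.0, 0.5], y_bounds=[-50.0, 50.0, 0.5], z_bounds=[-10.0, 10.0, 20.0]
    )
    # res = [0.5, 0.5, 20.0], start = [-49.75, -49.75, 0.0], dim = [200, 200, 1]

    # 6-DoF round trip: (tx, ty, tz, rx, ry, rz) -> 4x4 matrix -> back to the vector.
    vec = torch.tensor([[1.0, 2.0, 0.0, 0.0, 0.0, 0.1]])
    mat = pose_vec2mat(vec)                                   # (1, 4, 4)
    assert torch.allclose(mat2pose_vec(mat), vec, atol=1e-5)

    # Warp a BEV feature map by that pose; spatial_extent is half the grid size in metres.
    bev = torch.randn(1, 8, 200, 200)
    warped = warp_features(bev, vec, mode='bilinear', spatial_extent=(50.0, 50.0))
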
/stretchbev/layers/.ipynb_checkpoints/temporal-checkpoint.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from fiery.layers.convolutions import ConvBlock
7 | from fiery.utils.geometry import warp_features
8 |
9 |
10 | class SpatialGRU(nn.Module):
11 | """A GRU cell that takes an input tensor [BxTxCxHxW] and an optional previous state and passes a
12 | convolutional gated recurrent unit over the data"""
13 |
14 | def __init__(self, input_size, hidden_size, gru_bias_init=0.0, norm='bn', activation='relu'):
15 | super().__init__()
16 | self.input_size = input_size
17 | self.hidden_size = hidden_size
18 | self.gru_bias_init = gru_bias_init
19 |
20 | self.conv_update = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)
21 | self.conv_reset = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)
22 |
23 | self.conv_state_tilde = ConvBlock(
24 | input_size + hidden_size, hidden_size, kernel_size=3, bias=False, norm=norm, activation=activation
25 | )
26 |
27 | def forward(self, x, state=None, flow=None, mode='bilinear'):
28 | # pylint: disable=unused-argument, arguments-differ
29 | # Check size
30 | assert len(x.size()) == 5, 'Input tensor must be BxTxCxHxW.'
31 | b, timesteps, c, h, w = x.size()
32 | assert c == self.input_size, f'feature sizes must match, got input {c} for layer with size {self.input_size}'
33 |
34 | # recurrent layers
35 | rnn_output = []
36 | rnn_state = torch.zeros(b, self.hidden_size, h, w, device=x.device) if state is None else state
37 | for t in range(timesteps):
38 | x_t = x[:, t]
39 | if flow is not None:
40 | rnn_state = warp_features(rnn_state, flow[:, t], mode=mode)
41 |
42 | # propagate rnn state
43 | rnn_state = self.gru_cell(x_t, rnn_state)
44 | rnn_output.append(rnn_state)
45 |
46 | # reshape rnn output to batch tensor
47 | return torch.stack(rnn_output, dim=1)
48 |
49 | def gru_cell(self, x, state):
50 | # Compute gates
51 | x_and_state = torch.cat([x, state], dim=1)
52 | update_gate = self.conv_update(x_and_state)
53 | reset_gate = self.conv_reset(x_and_state)
54 | # Add bias to initialise gate as close to identity function
55 | update_gate = torch.sigmoid(update_gate + self.gru_bias_init)
56 | reset_gate = torch.sigmoid(reset_gate + self.gru_bias_init)
57 |
58 | # Compute proposal state, activation is defined in norm_act_config (can be tanh, ReLU etc)
59 | state_tilde = self.conv_state_tilde(torch.cat([x, (1.0 - reset_gate) * state], dim=1))
60 |
61 | output = (1.0 - update_gate) * state + update_gate * state_tilde
62 | return output
63 |
64 |
65 | class CausalConv3d(nn.Module):
66 | def __init__(self, in_channels, out_channels, kernel_size=(2, 3, 3), dilation=(1, 1, 1), bias=False):
67 | super().__init__()
68 | assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.'
69 | time_pad = (kernel_size[0] - 1) * dilation[0]
70 | height_pad = ((kernel_size[1] - 1) * dilation[1]) // 2
71 | width_pad = ((kernel_size[2] - 1) * dilation[2]) // 2
72 |
73 | # Pad temporally on the left
74 | self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0)
75 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, dilation=dilation, stride=1, padding=0, bias=bias)
76 | self.norm = nn.BatchNorm3d(out_channels)
77 | self.activation = nn.ReLU(inplace=True)
78 |
79 | def forward(self, *inputs):
80 | (x,) = inputs
81 | x = self.pad(x)
82 | x = self.conv(x)
83 | x = self.norm(x)
84 | x = self.activation(x)
85 | return x
86 |
87 |
88 | class CausalMaxPool3d(nn.Module):
89 | def __init__(self, kernel_size=(2, 3, 3)):
90 | super().__init__()
91 | assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.'
92 | time_pad = kernel_size[0] - 1
93 | height_pad = (kernel_size[1] - 1) // 2
94 | width_pad = (kernel_size[2] - 1) // 2
95 |
96 | # Pad temporally on the left
97 | self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0)
98 | self.max_pool = nn.MaxPool3d(kernel_size, stride=1)
99 |
100 | def forward(self, *inputs):
101 | (x,) = inputs
102 | x = self.pad(x)
103 | x = self.max_pool(x)
104 | return x
105 |
106 |
107 | def conv_1x1x1_norm_activated(in_channels, out_channels):
108 | """1x1x1 3D convolution, normalization and activation layer."""
109 | return nn.Sequential(
110 | OrderedDict(
111 | [
112 | ('conv', nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False)),
113 | ('norm', nn.BatchNorm3d(out_channels)),
114 | ('activation', nn.ReLU(inplace=True)),
115 | ]
116 | )
117 | )
118 |
119 |
120 | class Bottleneck3D(nn.Module):
121 | """
122 | Defines a bottleneck module with a residual connection
123 | """
124 |
125 | def __init__(self, in_channels, out_channels=None, kernel_size=(2, 3, 3), dilation=(1, 1, 1)):
126 | super().__init__()
127 | bottleneck_channels = in_channels // 2
128 | out_channels = out_channels or in_channels
129 |
130 | self.layers = nn.Sequential(
131 | OrderedDict(
132 | [
133 | # First projection with 1x1 kernel
134 | ('conv_down_project', conv_1x1x1_norm_activated(in_channels, bottleneck_channels)),
135 | # Second conv block
136 | (
137 | 'conv',
138 | CausalConv3d(
139 | bottleneck_channels,
140 | bottleneck_channels,
141 | kernel_size=kernel_size,
142 | dilation=dilation,
143 | bias=False,
144 | ),
145 | ),
146 | # Final projection with 1x1 kernel
147 | ('conv_up_project', conv_1x1x1_norm_activated(bottleneck_channels, out_channels)),
148 | ]
149 | )
150 | )
151 |
152 | if out_channels != in_channels:
153 | self.projection = nn.Sequential(
154 | nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False),
155 | nn.BatchNorm3d(out_channels),
156 | )
157 | else:
158 | self.projection = None
159 |
160 | def forward(self, *args):
161 | (x,) = args
162 | x_residual = self.layers(x)
163 | x_features = self.projection(x) if self.projection is not None else x
164 | return x_residual + x_features
165 |
166 |
167 | class PyramidSpatioTemporalPooling(nn.Module):
168 | """ Spatio-temporal pyramid pooling.
169 | Performs 3D average pooling followed by 1x1x1 convolution to reduce the number of channels and upsampling.
170 | Setting contains a list of kernel_size: usually it is [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)]
171 | """
172 |
173 | def __init__(self, in_channels, reduction_channels, pool_sizes):
174 | super().__init__()
175 | self.features = []
176 | for pool_size in pool_sizes:
177 | assert pool_size[0] == 2, (
178 |                 "Time kernel should be 2 as PyTorch raises an error when " "padding with more than half the kernel size"
179 | )
180 | stride = (1, *pool_size[1:])
181 | padding = (pool_size[0] - 1, 0, 0)
182 | self.features.append(
183 | nn.Sequential(
184 | OrderedDict(
185 | [
186 | # Pad the input tensor but do not take into account zero padding into the average.
187 | (
188 | 'avgpool',
189 | torch.nn.AvgPool3d(
190 | kernel_size=pool_size, stride=stride, padding=padding, count_include_pad=False
191 | ),
192 | ),
193 | ('conv_bn_relu', conv_1x1x1_norm_activated(in_channels, reduction_channels)),
194 | ]
195 | )
196 | )
197 | )
198 | self.features = nn.ModuleList(self.features)
199 |
200 | def forward(self, *inputs):
201 | (x,) = inputs
202 | b, _, t, h, w = x.shape
203 | # Do not include current tensor when concatenating
204 | out = []
205 | for f in self.features:
206 | # Remove unnecessary padded values (time dimension) on the right
207 | x_pool = f(x)[:, :, :-1].contiguous()
208 | c = x_pool.shape[1]
209 | x_pool = nn.functional.interpolate(
210 | x_pool.view(b * t, c, *x_pool.shape[-2:]), (h, w), mode='bilinear', align_corners=False
211 | )
212 | x_pool = x_pool.view(b, c, t, h, w)
213 | out.append(x_pool)
214 | out = torch.cat(out, 1)
215 | return out
216 |
217 |
218 | class TemporalBlock(nn.Module):
219 | """ Temporal block with the following layers:
220 | - 2x3x3, 1x3x3, spatio-temporal pyramid pooling
221 | - dropout
222 | - skip connection.
223 | """
224 |
225 | def __init__(self, in_channels, out_channels=None, use_pyramid_pooling=False, pool_sizes=None):
226 | super().__init__()
227 | self.in_channels = in_channels
228 | self.half_channels = in_channels // 2
229 | self.out_channels = out_channels or self.in_channels
230 | self.kernels = [(2, 3, 3), (1, 3, 3)]
231 |
232 | # Flag for spatio-temporal pyramid pooling
233 | self.use_pyramid_pooling = use_pyramid_pooling
234 |
235 | # 3 convolution paths: 2x3x3, 1x3x3, 1x1x1
236 | self.convolution_paths = []
237 | for kernel_size in self.kernels:
238 | self.convolution_paths.append(
239 | nn.Sequential(
240 | conv_1x1x1_norm_activated(self.in_channels, self.half_channels),
241 | CausalConv3d(self.half_channels, self.half_channels, kernel_size=kernel_size),
242 | )
243 | )
244 | self.convolution_paths.append(conv_1x1x1_norm_activated(self.in_channels, self.half_channels))
245 | self.convolution_paths = nn.ModuleList(self.convolution_paths)
246 |
247 | agg_in_channels = len(self.convolution_paths) * self.half_channels
248 |
249 | if self.use_pyramid_pooling:
250 | assert pool_sizes is not None, "setting must contain the list of kernel_size, but is None."
251 | reduction_channels = self.in_channels // 3
252 | self.pyramid_pooling = PyramidSpatioTemporalPooling(self.in_channels, reduction_channels, pool_sizes)
253 | agg_in_channels += len(pool_sizes) * reduction_channels
254 |
255 | # Feature aggregation
256 | self.aggregation = nn.Sequential(
257 | conv_1x1x1_norm_activated(agg_in_channels, self.out_channels),)
258 |
259 | if self.out_channels != self.in_channels:
260 | self.projection = nn.Sequential(
261 | nn.Conv3d(self.in_channels, self.out_channels, kernel_size=1, bias=False),
262 | nn.BatchNorm3d(self.out_channels),
263 | )
264 | else:
265 | self.projection = None
266 |
267 | def forward(self, *inputs):
268 | (x,) = inputs
269 | x_paths = []
270 | for conv in self.convolution_paths:
271 | x_paths.append(conv(x))
272 | x_residual = torch.cat(x_paths, dim=1)
273 | if self.use_pyramid_pooling:
274 | x_pool = self.pyramid_pooling(x)
275 | x_residual = torch.cat([x_residual, x_pool], dim=1)
276 | x_residual = self.aggregation(x_residual)
277 |
278 | if self.out_channels != self.in_channels:
279 | x = self.projection(x)
280 | x = x + x_residual
281 | return x
282 |
--------------------------------------------------------------------------------
/stretchbev/.ipynb_checkpoints/metrics-checkpoint.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from pytorch_lightning.metrics.metric import Metric
5 | from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
6 | from pytorch_lightning.metrics.functional.reduction import reduce
7 |
8 |
9 | class IntersectionOverUnion(Metric):
10 | """Computes intersection-over-union."""
11 | def __init__(
12 | self,
13 | n_classes: int,
14 | ignore_index: Optional[int] = None,
15 | absent_score: float = 0.0,
16 | reduction: str = 'none',
17 | compute_on_step: bool = False,
18 | ):
19 | super().__init__(compute_on_step=compute_on_step)
20 |
21 | self.n_classes = n_classes
22 | self.ignore_index = ignore_index
23 | self.absent_score = absent_score
24 | self.reduction = reduction
25 |
26 | self.add_state('true_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
27 | self.add_state('false_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
28 | self.add_state('false_negative', default=torch.zeros(n_classes), dist_reduce_fx='sum')
29 | self.add_state('support', default=torch.zeros(n_classes), dist_reduce_fx='sum')
30 |
31 | def update(self, prediction: torch.Tensor, target: torch.Tensor):
32 | tps, fps, _, fns, sups = stat_scores_multiple_classes(prediction, target, self.n_classes)
33 |
34 | self.true_positive += tps
35 | self.false_positive += fps
36 | self.false_negative += fns
37 | self.support += sups
38 |
39 | def compute(self):
40 | scores = torch.zeros(self.n_classes, device=self.true_positive.device, dtype=torch.float32)
41 |
42 | for class_idx in range(self.n_classes):
43 | if class_idx == self.ignore_index:
44 | continue
45 |
46 | tp = self.true_positive[class_idx]
47 | fp = self.false_positive[class_idx]
48 | fn = self.false_negative[class_idx]
49 | sup = self.support[class_idx]
50 |
51 | # If this class is absent in the target (no support) AND absent in the pred (no true or false
52 | # positives), then use the absent_score for this class.
53 | if sup + tp + fp == 0:
54 | scores[class_idx] = self.absent_score
55 | continue
56 |
57 | denominator = tp + fp + fn
58 | score = tp.to(torch.float) / denominator
59 | scores[class_idx] = score
60 |
61 | # Remove the ignored class index from the scores.
62 | if (self.ignore_index is not None) and (0 <= self.ignore_index < self.n_classes):
63 | scores = torch.cat([scores[:self.ignore_index], scores[self.ignore_index+1:]])
64 |
65 | return reduce(scores, reduction=self.reduction)
66 |
67 |
68 | class PanopticMetric(Metric):
69 | def __init__(
70 | self,
71 | n_classes: int,
72 | temporally_consistent: bool = True,
73 | vehicles_id: int = 1,
74 | compute_on_step: bool = False,
75 | ):
76 | super().__init__(compute_on_step=compute_on_step)
77 |
78 | self.n_classes = n_classes
79 | self.temporally_consistent = temporally_consistent
80 | self.vehicles_id = vehicles_id
81 | self.keys = ['iou', 'true_positive', 'false_positive', 'false_negative']
82 |
83 | self.add_state('iou', default=torch.zeros(n_classes), dist_reduce_fx='sum')
84 | self.add_state('true_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
85 | self.add_state('false_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
86 | self.add_state('false_negative', default=torch.zeros(n_classes), dist_reduce_fx='sum')
87 |
88 | def update(self, pred_instance, gt_instance):
89 | """
90 | Update state with predictions and targets.
91 |
92 | Parameters
93 | ----------
94 | pred_instance: (b, s, h, w)
95 | Temporally consistent instance segmentation prediction.
96 | gt_instance: (b, s, h, w)
97 | Ground truth instance segmentation.
98 | """
99 | batch_size, sequence_length = gt_instance.shape[:2]
100 | # Process labels
101 | assert gt_instance.min() == 0, 'ID 0 of gt_instance must be background'
102 | pred_segmentation = (pred_instance > 0).long()
103 | gt_segmentation = (gt_instance > 0).long()
104 |
105 | for b in range(batch_size):
106 | unique_id_mapping = {}
107 | for t in range(sequence_length):
108 | result = self.panoptic_metrics(
109 | pred_segmentation[b, t].detach(),
110 | pred_instance[b, t].detach(),
111 | gt_segmentation[b, t],
112 | gt_instance[b, t],
113 | unique_id_mapping,
114 | )
115 |
116 | self.iou += result['iou']
117 | self.true_positive += result['true_positive']
118 | self.false_positive += result['false_positive']
119 | self.false_negative += result['false_negative']
120 |
121 | def compute(self):
122 | denominator = torch.maximum(
123 | (self.true_positive + self.false_positive / 2 + self.false_negative / 2),
124 | torch.ones_like(self.true_positive)
125 | )
126 | pq = self.iou / denominator
127 | sq = self.iou / torch.maximum(self.true_positive, torch.ones_like(self.true_positive))
128 | rq = self.true_positive / denominator
129 |
130 | return {'pq': pq,
131 | 'sq': sq,
132 | 'rq': rq,
133 | # If 0, it means there wasn't any detection.
134 | 'denominator': (self.true_positive + self.false_positive / 2 + self.false_negative / 2),
135 | }
136 |
137 | def panoptic_metrics(self, pred_segmentation, pred_instance, gt_segmentation, gt_instance, unique_id_mapping):
138 | """
139 | Computes panoptic quality metric components.
140 |
141 | Parameters
142 | ----------
143 | pred_segmentation: [H, W] range {0, ..., n_classes-1} (>= n_classes is void)
144 | pred_instance: [H, W] range {0, ..., n_instances} (zero means background)
145 | gt_segmentation: [H, W] range {0, ..., n_classes-1} (>= n_classes is void)
146 | gt_instance: [H, W] range {0, ..., n_instances} (zero means background)
147 | unique_id_mapping: instance id mapping to check consistency
148 | """
149 | n_classes = self.n_classes
150 |
151 | result = {key: torch.zeros(n_classes, dtype=torch.float32, device=gt_instance.device) for key in self.keys}
152 |
153 | assert pred_segmentation.dim() == 2
154 | assert pred_segmentation.shape == pred_instance.shape == gt_segmentation.shape == gt_instance.shape
155 |
156 | n_instances = int(torch.cat([pred_instance, gt_instance]).max().item())
157 | n_all_things = n_instances + n_classes # Classes + instances.
158 | n_things_and_void = n_all_things + 1
159 |
160 | # Now 1 is background; 0 is void (not used). 2 is vehicle semantic class but since it overlaps with
161 | # instances, it is not present.
162 | # and the rest are instance ids starting from 3
163 | prediction, pred_to_cls = self.combine_mask(pred_segmentation, pred_instance, n_classes, n_all_things)
164 | target, target_to_cls = self.combine_mask(gt_segmentation, gt_instance, n_classes, n_all_things)
165 |
166 | # Compute ious between all stuff and things
167 | # hack for bincounting 2 arrays together
168 | x = prediction + n_things_and_void * target
169 | bincount_2d = torch.bincount(x.long(), minlength=n_things_and_void ** 2)
170 | if bincount_2d.shape[0] != n_things_and_void ** 2:
171 | raise ValueError('Incorrect bincount size.')
172 | conf = bincount_2d.reshape((n_things_and_void, n_things_and_void))
173 | # Drop void class
174 | conf = conf[1:, 1:]
175 |
176 | # Confusion matrix contains intersections between all combinations of classes
177 | union = conf.sum(0).unsqueeze(0) + conf.sum(1).unsqueeze(1) - conf
178 | iou = torch.where(union > 0, (conf.float() + 1e-9) / (union.float() + 1e-9), torch.zeros_like(union).float())
179 |
180 | # In the iou matrix, first dimension is target idx, second dimension is pred idx.
181 |         # Mapping will contain (target idx, pred idx) pairs for segments matched by IoU.
182 | mapping = (iou > 0.5).nonzero(as_tuple=False)
183 |
184 | # Check that classes match.
185 | is_matching = pred_to_cls[mapping[:, 1]] == target_to_cls[mapping[:, 0]]
186 | mapping = mapping[is_matching]
187 | tp_mask = torch.zeros_like(conf, dtype=torch.bool)
188 | tp_mask[mapping[:, 0], mapping[:, 1]] = True
189 |
190 | # First ids correspond to "stuff" i.e. semantic seg.
191 | # Instance ids are offset accordingly
192 | for target_id, pred_id in mapping:
193 | cls_id = pred_to_cls[pred_id]
194 |
195 | if self.temporally_consistent and cls_id == self.vehicles_id:
196 | if target_id.item() in unique_id_mapping and unique_id_mapping[target_id.item()] != pred_id.item():
197 | # Not temporally consistent
198 | result['false_negative'][target_to_cls[target_id]] += 1
199 | result['false_positive'][pred_to_cls[pred_id]] += 1
200 | unique_id_mapping[target_id.item()] = pred_id.item()
201 | continue
202 |
203 | result['true_positive'][cls_id] += 1
204 | result['iou'][cls_id] += iou[target_id][pred_id]
205 | unique_id_mapping[target_id.item()] = pred_id.item()
206 |
207 | for target_id in range(n_classes, n_all_things):
208 | # If this is a true positive do nothing.
209 | if tp_mask[target_id, n_classes:].any():
210 | continue
211 |             # If this target instance didn't match any prediction and was present, count it as a false negative.
212 | if target_to_cls[target_id] != -1:
213 | result['false_negative'][target_to_cls[target_id]] += 1
214 |
215 | for pred_id in range(n_classes, n_all_things):
216 | # If this is a true positive do nothing.
217 | if tp_mask[n_classes:, pred_id].any():
218 | continue
219 |             # If this predicted instance didn't match any target and was actually predicted, count it as a false positive.
220 | if pred_to_cls[pred_id] != -1 and (conf[:, pred_id] > 0).any():
221 | result['false_positive'][pred_to_cls[pred_id]] += 1
222 |
223 | return result
224 |
225 | def combine_mask(self, segmentation: torch.Tensor, instance: torch.Tensor, n_classes: int, n_all_things: int):
226 | """Shifts all things ids by num_classes and combines things and stuff into a single mask
227 |
228 | Returns a combined mask + a mapping from id to segmentation class.
229 | """
230 | instance = instance.view(-1)
231 | instance_mask = instance > 0
232 | instance = instance - 1 + n_classes
233 |
234 | segmentation = segmentation.clone().view(-1)
235 | segmentation_mask = segmentation < n_classes # Remove void pixels.
236 |
237 | # Build an index from instance id to class id.
238 | instance_id_to_class_tuples = torch.cat(
239 | (
240 | instance[instance_mask & segmentation_mask].unsqueeze(1),
241 | segmentation[instance_mask & segmentation_mask].unsqueeze(1),
242 | ),
243 | dim=1,
244 | )
245 | instance_id_to_class = -instance_id_to_class_tuples.new_ones((n_all_things,))
246 | instance_id_to_class[instance_id_to_class_tuples[:, 0]] = instance_id_to_class_tuples[:, 1]
247 | instance_id_to_class[torch.arange(n_classes, device=segmentation.device)] = torch.arange(
248 | n_classes, device=segmentation.device
249 | )
250 |
251 | segmentation[instance_mask] = instance[instance_mask]
252 | segmentation += 1 # Shift all legit classes by 1.
253 | segmentation[~segmentation_mask] = 0 # Shift void class to zero.
254 |
255 | return segmentation, instance_id_to_class
256 |
--------------------------------------------------------------------------------
/stretchbev/layers/temporal.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from stretchbev.layers.convolutions import ConvBlock
7 | from stretchbev.utils.geometry import warp_features
8 |
9 |
10 | class SpatialGRU(nn.Module):
11 | """A GRU cell that takes an input tensor [BxTxCxHxW] and an optional previous state and passes a
12 | convolutional gated recurrent unit over the data"""
13 |
14 | def __init__(self, input_size, hidden_size, gru_bias_init=0.0, norm='bn', activation='relu'):
15 | super().__init__()
16 | self.input_size = input_size
17 | self.hidden_size = hidden_size
18 | self.gru_bias_init = gru_bias_init
19 |
20 | self.conv_update = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)
21 | self.conv_reset = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)
22 |
23 | self.conv_state_tilde = ConvBlock(
24 | input_size + hidden_size, hidden_size, kernel_size=3, bias=False, norm=norm, activation=activation
25 | )
26 |
27 | def forward(self, x, state=None, flow=None, mode='bilinear'):
28 | # pylint: disable=unused-argument, arguments-differ
29 | # Check size
30 | assert len(x.size()) == 5, 'Input tensor must be BxTxCxHxW.'
31 | b, timesteps, c, h, w = x.size()
32 | assert c == self.input_size, f'feature sizes must match, got input {c} for layer with size {self.input_size}'
33 |
34 | # recurrent layers
35 | rnn_output = []
36 | rnn_state = torch.zeros(b, self.hidden_size, h, w, device=x.device) if state is None else state
37 | for t in range(timesteps):
38 | x_t = x[:, t]
39 | if flow is not None:
40 | rnn_state = warp_features(rnn_state, flow[:, t], mode=mode)
41 |
42 | # propagate rnn state
43 | rnn_state = self.gru_cell(x_t, rnn_state)
44 | rnn_output.append(rnn_state)
45 |
46 | # reshape rnn output to batch tensor
47 | return torch.stack(rnn_output, dim=1)
48 |
49 | def gru_cell(self, x, state):
50 | # Compute gates
51 | x_and_state = torch.cat([x, state], dim=1)
52 | update_gate = self.conv_update(x_and_state)
53 | reset_gate = self.conv_reset(x_and_state)
54 | # Add bias to initialise gate as close to identity function
55 | update_gate = torch.sigmoid(update_gate + self.gru_bias_init)
56 | reset_gate = torch.sigmoid(reset_gate + self.gru_bias_init)
57 |
58 | # Compute proposal state, activation is defined in norm_act_config (can be tanh, ReLU etc)
59 | state_tilde = self.conv_state_tilde(torch.cat([x, (1.0 - reset_gate) * state], dim=1))
60 |
61 | output = (1.0 - update_gate) * state + update_gate * state_tilde
62 | return output
63 |
64 |
65 | class CausalConv3d(nn.Module):
66 | def __init__(self, in_channels, out_channels, kernel_size=(2, 3, 3), dilation=(1, 1, 1), bias=False):
67 | super().__init__()
68 | assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.'
69 | time_pad = (kernel_size[0] - 1) * dilation[0]
70 | height_pad = ((kernel_size[1] - 1) * dilation[1]) // 2
71 | width_pad = ((kernel_size[2] - 1) * dilation[2]) // 2
72 |
73 | # Pad temporally on the left
74 | self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0)
75 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, dilation=dilation, stride=1, padding=0, bias=bias)
76 | self.norm = nn.BatchNorm3d(out_channels)
77 | self.activation = nn.ReLU(inplace=True)
78 |
79 | def forward(self, *inputs):
80 | (x,) = inputs
81 | x = self.pad(x)
82 | x = self.conv(x)
83 | x = self.norm(x)
84 | x = self.activation(x)
85 | return x
86 |
87 |
88 | class CausalMaxPool3d(nn.Module):
89 | def __init__(self, kernel_size=(2, 3, 3)):
90 | super().__init__()
91 | assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.'
92 | time_pad = kernel_size[0] - 1
93 | height_pad = (kernel_size[1] - 1) // 2
94 | width_pad = (kernel_size[2] - 1) // 2
95 |
96 | # Pad temporally on the left
97 | self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0)
98 | self.max_pool = nn.MaxPool3d(kernel_size, stride=1)
99 |
100 | def forward(self, *inputs):
101 | (x,) = inputs
102 | x = self.pad(x)
103 | x = self.max_pool(x)
104 | return x
105 |
106 |
107 | def conv_1x1x1_norm_activated(in_channels, out_channels):
108 | """1x1x1 3D convolution, normalization and activation layer."""
109 | return nn.Sequential(
110 | OrderedDict(
111 | [
112 | ('conv', nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False)),
113 | ('norm', nn.BatchNorm3d(out_channels)),
114 | ('activation', nn.ReLU(inplace=True)),
115 | ]
116 | )
117 | )
118 |
119 |
120 | class Bottleneck3D(nn.Module):
121 | """
122 | Defines a bottleneck module with a residual connection
123 | """
124 |
125 | def __init__(self, in_channels, out_channels=None, kernel_size=(2, 3, 3), dilation=(1, 1, 1)):
126 | super().__init__()
127 | bottleneck_channels = in_channels // 2
128 | out_channels = out_channels or in_channels
129 |
130 | self.layers = nn.Sequential(
131 | OrderedDict(
132 | [
133 | # First projection with 1x1 kernel
134 | (
135 | 'conv_down_project', conv_1x1x1_norm_activated(in_channels, bottleneck_channels)
136 | ),
137 | # Second conv block
138 | (
139 | 'conv',
140 | CausalConv3d(
141 | bottleneck_channels,
142 | bottleneck_channels,
143 | kernel_size=kernel_size,
144 | dilation=dilation,
145 | bias=False,
146 | ),
147 | ),
148 | # Final projection with 1x1 kernel
149 | (
150 | 'conv_up_project', conv_1x1x1_norm_activated(bottleneck_channels, out_channels)
151 | ),
152 | ]
153 | )
154 | )
155 |
156 | if out_channels != in_channels:
157 | self.projection = nn.Sequential(
158 | nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False),
159 | nn.BatchNorm3d(out_channels),
160 | )
161 | else:
162 | self.projection = None
163 |
164 | def forward(self, *args):
165 | (x,) = args
166 | x_residual = self.layers(x)
167 | x_features = self.projection(x) if self.projection is not None else x
168 | return x_residual + x_features
169 |
170 |
171 | class PyramidSpatioTemporalPooling(nn.Module):
172 | """ Spatio-temporal pyramid pooling.
173 | Performs 3D average pooling followed by 1x1x1 convolution to reduce the number of channels and upsampling.
174 | Setting contains a list of kernel_size: usually it is [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)]
175 | """
176 |
177 | def __init__(self, in_channels, reduction_channels, pool_sizes):
178 | super().__init__()
179 | self.features = []
180 | for pool_size in pool_sizes:
181 | assert pool_size[0] == 2, (
182 |                 "Time kernel should be 2 as PyTorch raises an error when " "padding with more than half the kernel size"
183 | )
184 | stride = (1, *pool_size[1:])
185 | padding = (pool_size[0] - 1, 0, 0)
186 | self.features.append(
187 | nn.Sequential(
188 | OrderedDict(
189 | [
190 | # Pad the input tensor but do not take into account zero padding into the average.
191 | (
192 | 'avgpool',
193 | torch.nn.AvgPool3d(
194 | kernel_size=pool_size, stride=stride, padding=padding, count_include_pad=False
195 | ),
196 | ),
197 | (
198 | 'conv_bn_relu', conv_1x1x1_norm_activated(in_channels, reduction_channels)
199 | ),
200 | ]
201 | )
202 | )
203 | )
204 | self.features = nn.ModuleList(self.features)
205 |
206 | def forward(self, *inputs):
207 | (x,) = inputs
208 | b, _, t, h, w = x.shape
209 | # Do not include current tensor when concatenating
210 | out = []
211 | for f in self.features:
212 | # Remove unnecessary padded values (time dimension) on the right
213 | x_pool = f(x)[:, :, :-1].contiguous()
214 | c = x_pool.shape[1]
215 | x_pool = nn.functional.interpolate(
216 | x_pool.view(b * t, c, *x_pool.shape[-2:]), (h, w), mode='bilinear', align_corners=False
217 | )
218 | x_pool = x_pool.view(b, c, t, h, w)
219 | out.append(x_pool)
220 | out = torch.cat(out, 1)
221 | return out
222 |
223 |
224 | class TemporalBlock(nn.Module):
225 | """ Temporal block with the following layers:
226 | - 2x3x3, 1x3x3, spatio-temporal pyramid pooling
227 | - dropout
228 | - skip connection.
229 | """
230 |
231 | def __init__(self, in_channels, out_channels=None, use_pyramid_pooling=False, pool_sizes=None):
232 | super().__init__()
233 | self.in_channels = in_channels
234 | self.half_channels = in_channels // 2
235 | self.out_channels = out_channels or self.in_channels
236 | self.kernels = [(2, 3, 3), (1, 3, 3)]
237 |
238 | # Flag for spatio-temporal pyramid pooling
239 | self.use_pyramid_pooling = use_pyramid_pooling
240 |
241 | # 3 convolution paths: 2x3x3, 1x3x3, 1x1x1
242 | self.convolution_paths = []
243 | for kernel_size in self.kernels:
244 | self.convolution_paths.append(
245 | nn.Sequential(
246 | conv_1x1x1_norm_activated(self.in_channels, self.half_channels),
247 | CausalConv3d(self.half_channels, self.half_channels, kernel_size=kernel_size),
248 | )
249 | )
250 | self.convolution_paths.append(conv_1x1x1_norm_activated(self.in_channels, self.half_channels))
251 | self.convolution_paths = nn.ModuleList(self.convolution_paths)
252 |
253 | agg_in_channels = len(self.convolution_paths) * self.half_channels
254 |
255 | if self.use_pyramid_pooling:
256 | assert pool_sizes is not None, "setting must contain the list of kernel_size, but is None."
257 | reduction_channels = self.in_channels // 3
258 | self.pyramid_pooling = PyramidSpatioTemporalPooling(self.in_channels, reduction_channels, pool_sizes)
259 | agg_in_channels += len(pool_sizes) * reduction_channels
260 |
261 | # Feature aggregation
262 | self.aggregation = nn.Sequential(
263 | conv_1x1x1_norm_activated(agg_in_channels, self.out_channels), )
264 |
265 | if self.out_channels != self.in_channels:
266 | self.projection = nn.Sequential(
267 | nn.Conv3d(self.in_channels, self.out_channels, kernel_size=1, bias=False),
268 | nn.BatchNorm3d(self.out_channels),
269 | )
270 | else:
271 | self.projection = None
272 |
273 | def forward(self, *inputs):
274 | (x,) = inputs
275 | x_paths = []
276 | for conv in self.convolution_paths:
277 | x_paths.append(conv(x))
278 | x_residual = torch.cat(x_paths, dim=1)
279 | if self.use_pyramid_pooling:
280 | x_pool = self.pyramid_pooling(x)
281 | x_residual = torch.cat([x_residual, x_pool], dim=1)
282 | x_residual = self.aggregation(x_residual)
283 |
284 | if self.out_channels != self.in_channels:
285 | x = self.projection(x)
286 | x = x + x_residual
287 | return x
288 |
--------------------------------------------------------------------------------
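The two temporal building blocks above expect different layouts: SpatialGRU consumes B x T x C x H x W tensors and returns one hidden state per timestep, while the causal 3D-convolution blocks (CausalConv3d, Bottleneck3D, TemporalBlock) operate on B x C x T x H x W with left-only temporal padding, so each output step only sees past frames. A quick sketch with made-up sizes, assuming the package imports resolve:

    import torch
    from stretchbev.layers.temporal import SpatialGRU, TemporalBlock

    b, t, c, h, w = 2, 3, 64, 50, 50

    gru = SpatialGRU(input_size=c, hidden_size=c)
    states = gru(torch.randn(b, t, c, h, w))    # (b, t, c, h, w): one hidden state per timestep

    block = TemporalBlock(in_channels=c)        # causal 3D block, time is dimension 2
    out = block(torch.randn(b, c, t, h, w))     # (b, c, t, h, w): left-padded in time
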
/stretchbev/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
5 | from pytorch_lightning.metrics.functional.reduction import reduce
6 | from pytorch_lightning.metrics.metric import Metric
7 |
8 |
9 | class IntersectionOverUnion(Metric):
10 | """Computes intersection-over-union."""
11 |
12 | def __init__(
13 | self,
14 | n_classes: int,
15 | ignore_index: Optional[int] = None,
16 | absent_score: float = 0.0,
17 | reduction: str = 'none',
18 | compute_on_step: bool = False,
19 | ):
20 | super().__init__(compute_on_step=compute_on_step)
21 |
22 | self.n_classes = n_classes
23 | self.ignore_index = ignore_index
24 | self.absent_score = absent_score
25 | self.reduction = reduction
26 |
27 | self.add_state('true_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
28 | self.add_state('false_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
29 | self.add_state('false_negative', default=torch.zeros(n_classes), dist_reduce_fx='sum')
30 | self.add_state('support', default=torch.zeros(n_classes), dist_reduce_fx='sum')
31 |
32 | def update(self, prediction: torch.Tensor, target: torch.Tensor):
33 | tps, fps, _, fns, sups = stat_scores_multiple_classes(prediction, target, self.n_classes)
34 |
35 | # print('update', tps, fps, fns, sups)
36 | self.true_positive += tps
37 | self.false_positive += fps
38 | self.false_negative += fns
39 | self.support += sups
40 |
41 | # return tps, fps, fns, sups
42 |
43 | def compute(self):
44 | scores = torch.zeros(self.n_classes, device=self.true_positive.device, dtype=torch.float32)
45 |
46 | # print('compute', self.true_positive, self.false_positive, self.false_negative, self.support)
47 | for class_idx in range(self.n_classes):
48 | if class_idx == self.ignore_index:
49 | continue
50 |
51 | tp = self.true_positive[class_idx]
52 | fp = self.false_positive[class_idx]
53 | fn = self.false_negative[class_idx]
54 | sup = self.support[class_idx]
55 |
56 | # If this class is absent in the target (no support) AND absent in the pred (no true or false
57 | # positives), then use the absent_score for this class.
58 | if sup + tp + fp == 0:
59 | scores[class_idx] = self.absent_score
60 | continue
61 |
62 | denominator = tp + fp + fn
63 | score = tp.to(torch.float) / denominator
64 | scores[class_idx] = score
65 |
66 | # Remove the ignored class index from the scores.
67 | if (self.ignore_index is not None) and (0 <= self.ignore_index < self.n_classes):
68 | scores = torch.cat([scores[:self.ignore_index], scores[self.ignore_index + 1:]])
69 |
70 | return reduce(scores, reduction=self.reduction)
71 |
72 | def calculate_batch(self, prediction: torch.Tensor, target: torch.Tensor):
73 | self.reset()
74 | self.update(prediction, target)
75 |
76 | return self.compute()
77 |
78 |
79 | class PanopticMetric(Metric):
80 | def __init__(
81 | self,
82 | n_classes: int,
83 | temporally_consistent: bool = True,
84 | vehicles_id: int = 1,
85 | compute_on_step: bool = False,
86 | ):
87 | super().__init__(compute_on_step=compute_on_step)
88 |
89 | self.n_classes = n_classes
90 | self.temporally_consistent = temporally_consistent
91 | self.vehicles_id = vehicles_id
92 | self.keys = ['iou', 'true_positive', 'false_positive', 'false_negative']
93 |
94 | self.add_state('iou', default=torch.zeros(n_classes), dist_reduce_fx='sum')
95 | self.add_state('true_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
96 | self.add_state('false_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
97 | self.add_state('false_negative', default=torch.zeros(n_classes), dist_reduce_fx='sum')
98 |
99 | def update(self, pred_instance, gt_instance):
100 | """
101 | Update state with predictions and targets.
102 |
103 | Parameters
104 | ----------
105 | pred_instance: (b, s, h, w)
106 | Temporally consistent instance segmentation prediction.
107 | gt_instance: (b, s, h, w)
108 | Ground truth instance segmentation.
109 | """
110 | batch_size, sequence_length = gt_instance.shape[:2]
111 | # Process labels
112 | assert gt_instance.min() == 0, 'ID 0 of gt_instance must be background'
113 | pred_segmentation = (pred_instance > 0).long()
114 | gt_segmentation = (gt_instance > 0).long()
115 |
116 | for b in range(batch_size):
117 | unique_id_mapping = {}
118 | for t in range(sequence_length):
119 | result = self.panoptic_metrics(
120 | pred_segmentation[b, t].detach(),
121 | pred_instance[b, t].detach(),
122 | gt_segmentation[b, t],
123 | gt_instance[b, t],
124 | unique_id_mapping,
125 | )
126 |
127 | self.iou += result['iou']
128 | self.true_positive += result['true_positive']
129 | self.false_positive += result['false_positive']
130 | self.false_negative += result['false_negative']
131 |
132 | def compute(self):
133 | denominator = torch.maximum(
134 | (self.true_positive + self.false_positive / 2 + self.false_negative / 2),
135 | torch.ones_like(self.true_positive)
136 | )
137 | pq = self.iou / denominator
138 | sq = self.iou / torch.maximum(self.true_positive, torch.ones_like(self.true_positive))
139 | rq = self.true_positive / denominator
140 |
141 | return {'pq': pq,
142 | 'sq': sq,
143 | 'rq': rq,
144 | # If 0, it means there wasn't any detection.
145 | 'denominator': (self.true_positive + self.false_positive / 2 + self.false_negative / 2),
146 | }
147 |
148 | def calculate_batch(self, pred_instance, gt_instance):
149 | self.reset()
150 | self.update(pred_instance, gt_instance)
151 | return self.compute()
152 |
153 | def panoptic_metrics(self, pred_segmentation, pred_instance, gt_segmentation, gt_instance, unique_id_mapping):
154 | """
155 | Computes panoptic quality metric components.
156 |
157 | Parameters
158 | ----------
159 | pred_segmentation: [H, W] range {0, ..., n_classes-1} (>= n_classes is void)
160 | pred_instance: [H, W] range {0, ..., n_instances} (zero means background)
161 | gt_segmentation: [H, W] range {0, ..., n_classes-1} (>= n_classes is void)
162 | gt_instance: [H, W] range {0, ..., n_instances} (zero means background)
163 | unique_id_mapping: instance id mapping to check consistency
164 | """
165 | n_classes = self.n_classes
166 |
167 | result = {key: torch.zeros(n_classes, dtype=torch.float32, device=gt_instance.device) for key in self.keys}
168 |
169 | assert pred_segmentation.dim() == 2
170 | assert pred_segmentation.shape == pred_instance.shape == gt_segmentation.shape == gt_instance.shape
171 |
172 | n_instances = int(torch.cat([pred_instance, gt_instance]).max().item())
173 | n_all_things = n_instances + n_classes # Classes + instances.
174 | n_things_and_void = n_all_things + 1
175 |
176 | # Now 1 is background; 0 is void (not used). 2 is vehicle semantic class but since it overlaps with
177 | # instances, it is not present.
178 | # and the rest are instance ids starting from 3
179 | prediction, pred_to_cls = self.combine_mask(pred_segmentation, pred_instance, n_classes, n_all_things)
180 | target, target_to_cls = self.combine_mask(gt_segmentation, gt_instance, n_classes, n_all_things)
181 |
182 | # Compute ious between all stuff and things
183 | # hack for bincounting 2 arrays together
184 | x = prediction + n_things_and_void * target
185 | bincount_2d = torch.bincount(x.long(), minlength=n_things_and_void ** 2)
186 | if bincount_2d.shape[0] != n_things_and_void ** 2:
187 | raise ValueError('Incorrect bincount size.')
188 | conf = bincount_2d.reshape((n_things_and_void, n_things_and_void))
189 | # Drop void class
190 | conf = conf[1:, 1:]
191 |
192 | # Confusion matrix contains intersections between all combinations of classes
193 | union = conf.sum(0).unsqueeze(0) + conf.sum(1).unsqueeze(1) - conf
194 | iou = torch.where(union > 0, (conf.float() + 1e-9) / (union.float() + 1e-9), torch.zeros_like(union).float())
195 |
196 | # In the iou matrix, first dimension is target idx, second dimension is pred idx.
197 |         # Mapping will contain (target idx, pred idx) pairs for segments matched by IoU.
198 | mapping = (iou > 0.5).nonzero(as_tuple=False)
199 |
200 | # Check that classes match.
201 | is_matching = pred_to_cls[mapping[:, 1]] == target_to_cls[mapping[:, 0]]
202 | mapping = mapping[is_matching]
203 | tp_mask = torch.zeros_like(conf, dtype=torch.bool)
204 | tp_mask[mapping[:, 0], mapping[:, 1]] = True
205 |
206 | # First ids correspond to "stuff" i.e. semantic seg.
207 | # Instance ids are offset accordingly
208 | for target_id, pred_id in mapping:
209 | cls_id = pred_to_cls[pred_id]
210 |
211 | if self.temporally_consistent and cls_id == self.vehicles_id:
212 | if target_id.item() in unique_id_mapping and unique_id_mapping[target_id.item()] != pred_id.item():
213 | # Not temporally consistent
214 | result['false_negative'][target_to_cls[target_id]] += 1
215 | result['false_positive'][pred_to_cls[pred_id]] += 1
216 | unique_id_mapping[target_id.item()] = pred_id.item()
217 | continue
218 |
219 | result['true_positive'][cls_id] += 1
220 | result['iou'][cls_id] += iou[target_id][pred_id]
221 | unique_id_mapping[target_id.item()] = pred_id.item()
222 |
223 | for target_id in range(n_classes, n_all_things):
224 | # If this is a true positive do nothing.
225 | if tp_mask[target_id, n_classes:].any():
226 | continue
227 |             # If this target instance didn't match any prediction and was present, count it as a false negative.
228 | if target_to_cls[target_id] != -1:
229 | result['false_negative'][target_to_cls[target_id]] += 1
230 |
231 | for pred_id in range(n_classes, n_all_things):
232 | # If this is a true positive do nothing.
233 | if tp_mask[n_classes:, pred_id].any():
234 | continue
235 | # If this predicted instance didn't match any target and is actually present, count it as a false positive.
236 | if pred_to_cls[pred_id] != -1 and (conf[:, pred_id] > 0).any():
237 | result['false_positive'][pred_to_cls[pred_id]] += 1
238 |
239 | return result
240 |
241 | def combine_mask(self, segmentation: torch.Tensor, instance: torch.Tensor, n_classes: int, n_all_things: int):
242 | """Shifts all things ids by num_classes and combines things and stuff into a single mask
243 |
244 | Returns a combined mask + a mapping from id to segmentation class.
245 | """
246 | instance = instance.view(-1)
247 | instance_mask = instance > 0
248 | instance = instance - 1 + n_classes
249 |
250 | segmentation = segmentation.clone().view(-1)
251 | segmentation_mask = segmentation < n_classes # Remove void pixels.
252 |
253 | # Build an index from instance id to class id.
254 | instance_id_to_class_tuples = torch.cat(
255 | (
256 | instance[instance_mask & segmentation_mask].unsqueeze(1),
257 | segmentation[instance_mask & segmentation_mask].unsqueeze(1),
258 | ),
259 | dim=1,
260 | )
261 | instance_id_to_class = -instance_id_to_class_tuples.new_ones((n_all_things,))
262 | instance_id_to_class[instance_id_to_class_tuples[:, 0]] = instance_id_to_class_tuples[:, 1]
263 | instance_id_to_class[torch.arange(n_classes, device=segmentation.device)] = torch.arange(
264 | n_classes, device=segmentation.device
265 | )
266 |
267 | segmentation[instance_mask] = instance[instance_mask]
268 | segmentation += 1 # Shift all legit classes by 1.
269 | segmentation[~segmentation_mask] = 0 # Shift void class to zero.
270 |
271 | return segmentation, instance_id_to_class
272 |
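# A hypothetical helper (illustrative values only, not used by the metric above)
# showing the bincount trick from panoptic_metrics: with n combined ids, encoding
# each pixel as prediction + n * target maps every (target, prediction) pair to a
# unique integer, so one 1D bincount reshaped to (n, n) yields
# conf[t, p] = number of pixels whose target id is t and predicted id is p.
def _bincount_confusion_demo():
    n = 3  # hypothetical number of combined ids (void + stuff + things)
    target = torch.tensor([0, 1, 1, 2, 2, 2])
    prediction = torch.tensor([0, 1, 2, 2, 2, 1])
    conf = torch.bincount(prediction + n * target, minlength=n ** 2).reshape(n, n)
    # conf[2, 2] == 2: two pixels with target id 2 were also predicted as id 2.
    return conf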
--------------------------------------------------------------------------------
/stretchbev/utils/visualisation.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pylab
2 | import numpy as np
3 | import torch
4 |
5 | from stretchbev.utils.instance import predict_instance_segmentation_and_trajectories
6 |
7 | DEFAULT_COLORMAP = matplotlib.pylab.cm.jet
8 |
9 |
10 | def flow_to_image(flow: np.ndarray, autoscale: bool = False) -> np.ndarray:
11 | """
12 | Applies a colour map to the flow field, which should be a 2-channel array of shape 2xHxW. Returns an HxWx3 numpy image.
13 | Code adapted from: https://github.com/liruoteng/FlowNet/blob/master/models/flownet/scripts/flowlib.py
14 | """
15 | u = flow[0, :, :]
16 | v = flow[1, :, :]
17 |
18 | # Convert to polar coordinates
19 | rad = np.sqrt(u ** 2 + v ** 2)
20 | maxrad = np.max(rad)
21 |
22 | # Normalise flow maps
23 | if autoscale:
24 | u /= maxrad + np.finfo(float).eps
25 | v /= maxrad + np.finfo(float).eps
26 |
27 | # visualise flow with cmap
28 | return np.uint8(compute_color(u, v) * 255)
29 |
30 |
31 | def _normalise(image: np.ndarray) -> np.ndarray:
32 | lower = np.min(image)
33 | delta = np.max(image) - lower
34 | if delta == 0:
35 | delta = 1
36 | image = (image.astype(np.float32) - lower) / delta
37 | return image
38 |
39 |
40 | def apply_colour_map(
41 | image: np.ndarray, cmap: matplotlib.colors.LinearSegmentedColormap = DEFAULT_COLORMAP, autoscale: bool = False
42 | ) -> np.ndarray:
43 | """
44 | Applies a colour map to the given 1 or 2 channel numpy image (a 3 channel image is passed through as RGB). If more than 1 channel, must be CxHxW.
45 | Returns a HxWx3 numpy image
46 | """
47 | if image.ndim == 2 or (image.ndim == 3 and image.shape[0] == 1):
48 | if image.ndim == 3:
49 | image = image[0]
50 | # grayscale scalar image
51 | if autoscale:
52 | image = _normalise(image)
53 | return cmap(image)[:, :, :3]
54 | if image.shape[0] == 2:
55 | # 2 dimensional UV
56 | return flow_to_image(image, autoscale=autoscale)
57 | if image.shape[0] == 3:
58 | # normalise rgb channels
59 | if autoscale:
60 | image = _normalise(image)
61 | return np.transpose(image, axes=[1, 2, 0])
62 | raise Exception('Image must be 1, 2 or 3 channel to convert to colour_map (CxHxW)')
63 |
64 |
65 | def heatmap_image(
66 | image: np.ndarray, cmap: matplotlib.colors.LinearSegmentedColormap = DEFAULT_COLORMAP, autoscale: bool = True
67 | ) -> np.ndarray:
68 | """Colorize an 1 or 2 channel image with a colourmap."""
69 | if not issubclass(image.dtype.type, np.floating):
70 | raise ValueError(f"Expected a ndarray of float type, but got dtype {image.dtype}")
71 | if not (image.ndim == 2 or (image.ndim == 3 and image.shape[0] in [1, 2])):
72 | raise ValueError(f"Expected a ndarray of shape [H, W] or [1, H, W] or [2, H, W], but got shape {image.shape}")
73 | heatmap_np = apply_colour_map(image, cmap=cmap, autoscale=autoscale)
74 | heatmap_np = np.uint8(heatmap_np * 255)
75 | return heatmap_np
76 |
77 |
78 | def compute_color(u: np.ndarray, v: np.ndarray) -> np.ndarray:
79 | assert u.shape == v.shape
80 | [h, w] = u.shape
81 | img = np.zeros([h, w, 3])
82 | nan_mask = np.isnan(u) | np.isnan(v)
83 | u[nan_mask] = 0
84 | v[nan_mask] = 0
85 |
86 | colorwheel = make_color_wheel()
87 | ncols = np.size(colorwheel, 0)
88 |
89 | rad = np.sqrt(u ** 2 + v ** 2)
90 | a = np.arctan2(-v, -u) / np.pi
91 | f_k = (a + 1) / 2 * (ncols - 1) + 1
92 | k_0 = np.floor(f_k).astype(int)
93 | k_1 = k_0 + 1
94 | k_1[k_1 == ncols + 1] = 1
95 | f = f_k - k_0
96 |
97 | for i in range(0, np.size(colorwheel, 1)):
98 | tmp = colorwheel[:, i]
99 | col0 = tmp[k_0 - 1] / 255
100 | col1 = tmp[k_1 - 1] / 255
101 | col = (1 - f) * col0 + f * col1
102 |
103 | idx = rad <= 1
104 | col[idx] = 1 - rad[idx] * (1 - col[idx])
105 | notidx = np.logical_not(idx)
106 |
107 | col[notidx] *= 0.75
108 | img[:, :, i] = col * (1 - nan_mask)
109 |
110 | return img
111 |
112 |
113 | def make_color_wheel() -> np.ndarray:
114 | """
115 | Create colour wheel.
116 | Code adapted from https://github.com/liruoteng/FlowNet/blob/master/models/flownet/scripts/flowlib.py
117 | """
118 | red_yellow = 15
119 | yellow_green = 6
120 | green_cyan = 4
121 | cyan_blue = 11
122 | blue_magenta = 13
123 | magenta_red = 6
124 |
125 | ncols = red_yellow + yellow_green + green_cyan + cyan_blue + blue_magenta + magenta_red
126 | colorwheel = np.zeros([ncols, 3])
127 |
128 | col = 0
129 |
130 | # red_yellow
131 | colorwheel[0:red_yellow, 0] = 255
132 | colorwheel[0:red_yellow, 1] = np.transpose(np.floor(255 * np.arange(0, red_yellow) / red_yellow))
133 | col += red_yellow
134 |
135 | # yellow_green
136 | colorwheel[col: col + yellow_green, 0] = 255 - np.transpose(
137 | np.floor(255 * np.arange(0, yellow_green) / yellow_green)
138 | )
139 | colorwheel[col: col + yellow_green, 1] = 255
140 | col += yellow_green
141 |
142 | # green_cyan
143 | colorwheel[col: col + green_cyan, 1] = 255
144 | colorwheel[col: col + green_cyan, 2] = np.transpose(np.floor(255 * np.arange(0, green_cyan) / green_cyan))
145 | col += green_cyan
146 |
147 | # cyan_blue
148 | colorwheel[col: col + cyan_blue, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, cyan_blue) / cyan_blue))
149 | colorwheel[col: col + cyan_blue, 2] = 255
150 | col += cyan_blue
151 |
152 | # blue_magenta
153 | colorwheel[col: col + blue_magenta, 2] = 255
154 | colorwheel[col: col + blue_magenta, 0] = np.transpose(np.floor(255 * np.arange(0, blue_magenta) / blue_magenta))
155 | col += blue_magenta
156 |
157 | # magenta_red
158 | colorwheel[col: col + magenta_red, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, magenta_red) / magenta_red))
159 | colorwheel[col: col + magenta_red, 0] = 255
160 |
161 | return colorwheel
162 |
163 |
164 | def make_contour(img, colour=[0, 0, 0], double_line=False):
165 | h, w = img.shape[:2]
166 | out = img.copy()
167 | # Vertical lines
168 | out[np.arange(h), np.repeat(0, h)] = colour
169 | out[np.arange(h), np.repeat(w - 1, h)] = colour
170 |
171 | # Horizontal lines
172 | out[np.repeat(0, w), np.arange(w)] = colour
173 | out[np.repeat(h - 1, w), np.arange(w)] = colour
174 |
175 | if double_line:
176 | out[np.arange(h), np.repeat(1, h)] = colour
177 | out[np.arange(h), np.repeat(w - 2, h)] = colour
178 |
179 | # Horizontal lines
180 | out[np.repeat(1, w), np.arange(w)] = colour
181 | out[np.repeat(h - 2, w), np.arange(w)] = colour
182 | return out
183 |
184 |
185 | def plot_instance_map(instance_image, instance_map, instance_colours=None, bg_image=None):
186 | if isinstance(instance_image, torch.Tensor):
187 | instance_image = instance_image.cpu().numpy()
188 | assert isinstance(instance_image, np.ndarray)
189 | if instance_colours is None:
190 | instance_colours = generate_instance_colours(instance_map)
191 | if len(instance_image.shape) > 2:
192 | instance_image = instance_image.reshape((instance_image.shape[-2], instance_image.shape[-1]))
193 |
194 | if bg_image is None:
195 | plot_image = 255 * np.ones((instance_image.shape[0], instance_image.shape[1], 3), dtype=np.uint8)
196 | else:
197 | plot_image = bg_image
198 |
199 | for key, value in instance_colours.items():
200 | plot_image[instance_image == key] = value
201 |
202 | return plot_image
203 |
204 |
205 | def visualise_output(labels, output, cfg):
206 | semantic_colours = np.array([[255, 255, 255], [0, 0, 0]], dtype=np.uint8)
207 |
208 | consistent_instance_seg = predict_instance_segmentation_and_trajectories(
209 | output, compute_matched_centers=False
210 | )
211 |
212 | sequence_length = consistent_instance_seg.shape[1]
213 | b = 0
214 | video = []
215 | for t in range(sequence_length):
216 | out_t = []
217 | # Ground truth
218 | unique_ids = torch.unique(labels['instance'][b, t]).cpu().numpy()[1:]
219 | instance_map = dict(zip(unique_ids, unique_ids))
220 | instance_plot = plot_instance_map(labels['instance'][b, t].cpu(), instance_map)[::-1, ::-1]
221 | instance_plot = make_contour(instance_plot)
222 |
223 | semantic_seg = labels['segmentation'].squeeze(2).cpu().numpy()
224 | semantic_plot = semantic_colours[semantic_seg[b, t][::-1, ::-1]]
225 | semantic_plot = make_contour(semantic_plot)
226 |
227 | if cfg.INSTANCE_FLOW.ENABLED:
228 | future_flow_plot = labels['flow'][b, t].cpu().numpy()
229 | future_flow_plot[:, semantic_seg[b, t] != 1] = 0
230 | future_flow_plot = flow_to_image(future_flow_plot)[::-1, ::-1]
231 | future_flow_plot = make_contour(future_flow_plot)
232 | else:
233 | future_flow_plot = np.zeros_like(semantic_plot)
234 |
235 | center_plot = heatmap_image(labels['centerness'][b, t, 0].cpu().numpy())[::-1, ::-1]
236 | center_plot = make_contour(center_plot)
237 |
238 | offset_plot = labels['offset'][b, t].cpu().numpy()
239 | offset_plot[:, semantic_seg[b, t] != 1] = 0
240 | offset_plot = flow_to_image(offset_plot)[::-1, ::-1]
241 | offset_plot = make_contour(offset_plot)
242 |
243 | out_t.append(np.concatenate([instance_plot, future_flow_plot,
244 | semantic_plot, center_plot, offset_plot], axis=0))
245 |
246 | # Predictions
247 | unique_ids = torch.unique(consistent_instance_seg[b, t]).cpu().numpy()[1:]
248 | instance_map = dict(zip(unique_ids, unique_ids))
249 | instance_plot = plot_instance_map(consistent_instance_seg[b, t].cpu(), instance_map)[::-1, ::-1]
250 | instance_plot = make_contour(instance_plot)
251 |
252 | semantic_seg = output['segmentation'].argmax(dim=2).detach().cpu().numpy()
253 | semantic_plot = semantic_colours[semantic_seg[b, t][::-1, ::-1]]
254 | semantic_plot = make_contour(semantic_plot)
255 |
256 | if cfg.INSTANCE_FLOW.ENABLED:
257 | future_flow_plot = output['instance_flow'][b, t].detach().cpu().numpy()
258 | future_flow_plot[:, semantic_seg[b, t] != 1] = 0
259 | future_flow_plot = flow_to_image(future_flow_plot)[::-1, ::-1]
260 | future_flow_plot = make_contour(future_flow_plot)
261 | else:
262 | future_flow_plot = np.zeros_like(semantic_plot)
263 |
264 | center_plot = heatmap_image(output['instance_center'][b, t, 0].detach().cpu().numpy())[::-1, ::-1]
265 | center_plot = make_contour(center_plot)
266 |
267 | offset_plot = output['instance_offset'][b, t].detach().cpu().numpy()
268 | offset_plot[:, semantic_seg[b, t] != 1] = 0
269 | offset_plot = flow_to_image(offset_plot)[::-1, ::-1]
270 | offset_plot = make_contour(offset_plot)
271 |
272 | out_t.append(np.concatenate([instance_plot, future_flow_plot,
273 | semantic_plot, center_plot, offset_plot], axis=0))
274 | out_t = np.concatenate(out_t, axis=1)
275 | # Shape (C, H, W)
276 | out_t = out_t.transpose((2, 0, 1))
277 |
278 | video.append(out_t)
279 |
280 | # Shape (B, T, C, H, W)
281 | video = np.stack(video)[None]
282 | return video
283 |
284 |
285 | def convert_figure_numpy(figure):
286 | """ Convert figure to numpy image """
287 | figure_np = np.frombuffer(figure.canvas.tostring_rgb(), dtype=np.uint8)
288 | figure_np = figure_np.reshape(figure.canvas.get_width_height()[::-1] + (3,))
289 | return figure_np
290 |
291 |
292 | def generate_instance_colours(instance_map):
293 | # Most distinct 22 colors (kelly colors from https://stackoverflow.com/questions/470690/how-to-automatically-generate
294 | # -n-distinct-colors)
295 | # plus some colours from AD40k
296 | INSTANCE_COLOURS = np.asarray([
297 | [0, 0, 0],
298 | [255, 179, 0],
299 | [128, 62, 117],
300 | [255, 104, 0],
301 | [166, 189, 215],
302 | [193, 0, 32],
303 | [206, 162, 98],
304 | [129, 112, 102],
305 | [0, 125, 52],
306 | [246, 118, 142],
307 | [0, 83, 138],
308 | [255, 122, 92],
309 | [83, 55, 122],
310 | [255, 142, 0],
311 | [179, 40, 81],
312 | [244, 200, 0],
313 | [127, 24, 13],
314 | [147, 170, 0],
315 | [89, 51, 21],
316 | [241, 58, 19],
317 | [35, 44, 22],
318 | [112, 224, 255],
319 | [70, 184, 160],
320 | [153, 0, 255],
321 | [71, 255, 0],
322 | [255, 0, 163],
323 | [255, 204, 0],
324 | [0, 255, 235],
325 | [255, 0, 235],
326 | [255, 0, 122],
327 | [255, 245, 0],
328 | [10, 190, 212],
329 | [214, 255, 0],
330 | [0, 204, 255],
331 | [20, 0, 255],
332 | [255, 255, 0],
333 | [0, 153, 255],
334 | [0, 255, 204],
335 | [41, 255, 0],
336 | [173, 0, 255],
337 | [0, 245, 255],
338 | [71, 0, 255],
339 | [0, 255, 184],
340 | [0, 92, 255],
341 | [184, 255, 0],
342 | [255, 214, 0],
343 | [25, 194, 194],
344 | [92, 0, 255],
345 | [220, 220, 220],
346 | [255, 9, 92],
347 | [112, 9, 255],
348 | [8, 255, 214],
349 | [255, 184, 6],
350 | [10, 255, 71],
351 | [255, 41, 10],
352 | [7, 255, 255],
353 | [224, 255, 8],
354 | [102, 8, 255],
355 | [255, 61, 6],
356 | [255, 194, 7],
357 | [0, 255, 20],
358 | [255, 8, 41],
359 | [255, 5, 153],
360 | [6, 51, 255],
361 | [235, 12, 255],
362 | [160, 150, 20],
363 | [0, 163, 255],
364 | [140, 140, 140],
365 | [250, 10, 15],
366 | [20, 255, 0],
367 | ])
368 |
369 | return {instance_id: INSTANCE_COLOURS[global_instance_id % len(INSTANCE_COLOURS)] for
370 | instance_id, global_instance_id in instance_map.items()
371 | }
372 |
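# A minimal usage sketch of the helpers above (the random inputs are purely
# illustrative; shapes follow the function signatures: flow-style inputs are
# channel-first 2xHxW float arrays, heatmap inputs are HxW floats).
if __name__ == '__main__':
    dummy_flow = np.random.randn(2, 64, 64).astype(np.float32)
    flow_rgb = flow_to_image(dummy_flow, autoscale=True)  # HxWx3 uint8
    heat_rgb = heatmap_image(np.random.rand(64, 64))  # HxWx3 uint8
    framed = make_contour(flow_rgb)  # adds a 1-pixel black border
    print(flow_rgb.shape, heat_rgb.shape, framed.shape)  # (64, 64, 3) each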
--------------------------------------------------------------------------------
/stretchbev/.ipynb_checkpoints/trainer-checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pytorch_lightning as pl
4 | import torch.distributions as distrib
5 | import torch.nn.functional as F
6 |
7 |
8 |
9 | from fiery.config import get_cfg
10 | from fiery.models.fiery import Fiery
11 | from fiery.losses import ProbabilisticLoss, SpatialRegressionLoss, SegmentationLoss
12 | from fiery.metrics import IntersectionOverUnion, PanopticMetric
13 | from fiery.utils.geometry import cumulative_warp_features_reverse
14 | from fiery.utils.instance import predict_instance_segmentation_and_trajectories
15 | from fiery.utils.visualisation import visualise_output
16 | from fiery.models import model_utils
17 |
18 |
19 | class TrainingModule(pl.LightningModule):
20 | def __init__(self, hparams):
21 | super().__init__()
22 |
23 | # see config.py for details
24 | self.hparams = hparams
25 | # pytorch lightning does not support saving YACS CfgNode objects
26 | cfg = get_cfg(cfg_dict=self.hparams)
27 | self.cfg = cfg
28 | self.n_classes = len(self.cfg.SEMANTIC_SEG.WEIGHTS)
29 |
30 | # Bird's-eye view extent in meters
31 | assert self.cfg.LIFT.X_BOUND[1] > 0 and self.cfg.LIFT.Y_BOUND[1] > 0
32 | self.spatial_extent = (self.cfg.LIFT.X_BOUND[1], self.cfg.LIFT.Y_BOUND[1])
33 |
34 | # Model
35 | self.model = Fiery(cfg)
36 |
37 | # Losses
38 | self.losses_fn = nn.ModuleDict()
39 | self.losses_fn['segmentation'] = SegmentationLoss(
40 | class_weights=torch.Tensor(self.cfg.SEMANTIC_SEG.WEIGHTS),
41 | use_top_k=self.cfg.SEMANTIC_SEG.USE_TOP_K,
42 | top_k_ratio=self.cfg.SEMANTIC_SEG.TOP_K_RATIO,
43 | future_discount=self.cfg.FUTURE_DISCOUNT,
44 | )
45 |
46 | # Uncertainty weighting
47 | self.model.segmentation_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
48 |
49 | self.metric_iou_val = IntersectionOverUnion(self.n_classes)
50 |
51 | self.losses_fn['instance_center'] = SpatialRegressionLoss(
52 | norm=2, future_discount=self.cfg.FUTURE_DISCOUNT
53 | )
54 | self.losses_fn['instance_offset'] = SpatialRegressionLoss(
55 | norm=1, future_discount=self.cfg.FUTURE_DISCOUNT, ignore_index=self.cfg.DATASET.IGNORE_INDEX
56 | )
57 |
58 | # Uncertainty weighting
59 | # self.model.centerness_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
60 | # self.model.offset_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
61 |
62 | self.metric_panoptic_val = PanopticMetric(n_classes=self.n_classes)
63 |
64 | if self.cfg.INSTANCE_FLOW.ENABLED:
65 | self.losses_fn['instance_flow'] = SpatialRegressionLoss(
66 | norm=1, future_discount=self.cfg.FUTURE_DISCOUNT, ignore_index=self.cfg.DATASET.IGNORE_INDEX
67 | )
68 | # Uncertainty weighting
69 | # self.model.flow_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
70 |
71 | # self.model.kl_y_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
72 | # self.model.kl_z_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
73 | # self.model.reconstruction_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
74 |
75 | self.training_step_count = 0
76 |
77 | def shared_step(self, batch, is_train):
78 | image = batch['image']
79 | intrinsics = batch['intrinsics']
80 | extrinsics = batch['extrinsics']
81 | future_egomotion = batch['future_egomotion']
82 |
83 | # Warp labels
84 | labels, future_distribution_inputs = self.prepare_future_labels(batch)
85 |
86 | # Forward pass
87 | all_output = self.model(
88 | image, intrinsics, extrinsics, future_egomotion, future_distribution_inputs
89 | )
90 |
91 | #####
92 | # Loss computation
93 | #####
94 | output = all_output['bev_output']
95 | loss = {}
96 | # The per-task loss weights below are fixed constants; the commented-out expressions show the learned-uncertainty alternative.
97 | segmentation_factor = 5.44 #1 / torch.exp(model.segmentation_weight)
98 | loss['segmentation'] = segmentation_factor * self.losses_fn['segmentation'](
99 | output['segmentation'], labels['segmentation']
100 | )
101 | # loss['segmentation_uncertainty'] = #0.5 * model.segmentation_weight
102 |
103 | centerness_factor = 591 # 1 / (2*torch.exp(model.centerness_weight))
104 | loss['instance_center'] = centerness_factor * self.losses_fn['instance_center'](
105 | output['instance_center'], labels['centerness']
106 | )
107 |
108 | offset_factor = 1.56 # / (2*torch.exp(model.offset_weight))
109 | loss['instance_offset'] = offset_factor * self.losses_fn['instance_offset'](
110 | output['instance_offset'], labels['offset']
111 | )
112 |
113 | # loss['centerness_uncertainty'] = 0.5 * model.centerness_weight
114 | # loss['offset_uncertainty'] = 0.5 * model.offset_weight
115 |
116 |
117 | flow_factor = 3.90 #/ (2*torch.exp(model.flow_weight))
118 | loss['instance_flow'] = flow_factor * self.losses_fn['instance_flow'](
119 | output['instance_flow'], labels['flow']
120 | )
121 |
122 | reconstruction_factor = 10# / (2*torch.exp(self.model.reconstruction_weight))
123 | loss['state_reconstruction'] = reconstruction_factor * F.mse_loss(all_output['generated_srvp_x'], all_output['lss_outs'])
124 | # loss['reconstruction_uncertainty'] = 0.5 * self.model.reconstruction_weight
125 |
126 | kl_y_factor = 1e-2# / (2*torch.exp(self.model.kl_y_weight))
127 | q_y_0 = model_utils.make_normal_from_raw_params(all_output['q_y0_params'])
128 | kl_y_0 = distrib.kl_divergence(q_y_0, distrib.Normal(0, 1)).mean()
129 |
130 | loss['kl_y0'] = kl_y_factor * kl_y_0
131 | # loss['kl_y0_uncertainty'] = 0.5 * self.model.kl_y_weight
132 |
133 | kl_z_factor = 1e-2# / (2*torch.exp(self.model.kl_z_weight))
134 | q_z, p_z = model_utils.make_normal_from_raw_params(all_output['q_z_params']), model_utils.make_normal_from_raw_params(all_output['p_z_params'])
135 | kl_z = distrib.kl_divergence(q_z, p_z).mean()
136 |
137 | loss['kl_z'] = kl_z_factor * kl_z
138 | # loss['kl_z_uncertainty'] = 0.5 * self.model.kl_z_weight
139 |
140 | # Metrics
141 | if not is_train:
142 | seg_prediction = output['segmentation'].detach()
143 | seg_prediction = torch.argmax(seg_prediction, dim=2, keepdims=True)
144 | self.metric_iou_val(seg_prediction, labels['segmentation'])
145 |
146 | pred_consistent_instance_seg = predict_instance_segmentation_and_trajectories(
147 | output, compute_matched_centers=False
148 | )
149 |
150 | self.metric_panoptic_val(pred_consistent_instance_seg, labels['instance'])
151 |
152 | return all_output, labels, loss
153 |
154 | def prepare_future_labels(self, batch, receptive_field=1):
155 | labels = {}
156 | future_distribution_inputs = []
157 |
158 | segmentation_labels = batch['segmentation']
159 | instance_center_labels = batch['centerness']
160 | instance_offset_labels = batch['offset']
161 | instance_flow_labels = batch['flow']
162 | gt_instance = batch['instance']
163 | future_egomotion = batch['future_egomotion']
164 |
165 | # Warp labels to present's reference frame
166 | segmentation_labels = cumulative_warp_features_reverse(
167 | segmentation_labels[:, (receptive_field - 1):].float(),
168 | future_egomotion[:, (receptive_field - 1):],
169 | mode='nearest', spatial_extent=self.spatial_extent,
170 | ).long().contiguous()
171 | labels['segmentation'] = segmentation_labels
172 | future_distribution_inputs.append(segmentation_labels)
173 |
174 | # Warp instance labels to present's reference frame
175 | gt_instance = cumulative_warp_features_reverse(
176 | gt_instance[:, (receptive_field - 1):].float().unsqueeze(2),
177 | future_egomotion[:, (receptive_field - 1):],
178 | mode='nearest', spatial_extent=self.spatial_extent,
179 | ).long().contiguous()[:, :, 0]
180 | labels['instance'] = gt_instance
181 |
182 | instance_center_labels = cumulative_warp_features_reverse(
183 | instance_center_labels[:, (receptive_field - 1):],
184 | future_egomotion[:, (receptive_field - 1):],
185 | mode='nearest', spatial_extent=self.spatial_extent,
186 | ).contiguous()
187 | labels['centerness'] = instance_center_labels
188 |
189 | instance_offset_labels = cumulative_warp_features_reverse(
190 | instance_offset_labels[:, (receptive_field - 1):],
191 | future_egomotion[:, (receptive_field - 1):],
192 | mode='nearest', spatial_extent=self.spatial_extent,
193 | ).contiguous()
194 | labels['offset'] = instance_offset_labels
195 |
196 | future_distribution_inputs.append(instance_center_labels)
197 | future_distribution_inputs.append(instance_offset_labels)
198 |
199 | if self.cfg.INSTANCE_FLOW.ENABLED:
200 | instance_flow_labels = cumulative_warp_features_reverse(
201 | instance_flow_labels[:, (receptive_field - 1):],
202 | future_egomotion[:, (receptive_field - 1):],
203 | mode='nearest', spatial_extent=self.spatial_extent,
204 | ).contiguous()
205 | labels['flow'] = instance_flow_labels
206 |
207 | future_distribution_inputs.append(instance_flow_labels)
208 |
209 |
210 | if len(future_distribution_inputs) > 0:
211 | future_distribution_inputs = torch.cat(future_distribution_inputs, dim=2)
212 |
213 | b, t, n, h, w = future_distribution_inputs.shape
214 | future_distribution_inputs = F.adaptive_max_pool2d(future_distribution_inputs.reshape(b, t*n, h, w), 50).reshape(b, t, n, 50, 50)
215 |
216 | return labels, future_distribution_inputs
217 |
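# Note on prepare_future_labels (shapes inferred from the calls above): every label
# tensor is warped into the present reference frame with
# cumulative_warp_features_reverse, and the warped segmentation / centerness /
# offset (and, if enabled, flow) channels are concatenated and max-pooled to a
# 50x50 grid before being passed to the model as future_distribution_inputs.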
218 | def visualise(self, labels, output, batch_idx, prefix='train'):
219 | visualisation_video = visualise_output(labels, output['bev_output'], self.cfg)
220 | name = f'{prefix}_outputs'
221 | if prefix == 'val':
222 | name = name + f'_{batch_idx}'
223 | self.logger.experiment.add_video(name, visualisation_video, global_step=self.training_step_count, fps=2)
224 |
225 | def training_step(self, batch, batch_idx):
226 | output, labels, loss = self.shared_step(batch, True)
227 | self.training_step_count += 1
228 | for key, value in loss.items():
229 | self.logger.experiment.add_scalar(key, value, global_step=self.training_step_count)
230 | if self.training_step_count % self.cfg.VIS_INTERVAL == 0:
231 | self.visualise(labels, output, batch_idx, prefix='train')
232 | return sum(loss.values())
233 |
234 | def validation_step(self, batch, batch_idx):
235 | output, labels, loss = self.shared_step(batch, False)
236 | for key, value in loss.items():
237 | self.log('val_' + key, value)
238 |
239 | if batch_idx == 0:
240 | self.visualise(labels, output, batch_idx, prefix='val')
241 |
242 | def shared_epoch_end(self, step_outputs, is_train):
243 | # log per class iou metrics
244 | class_names = ['background', 'dynamic']
245 | if not is_train:
246 | scores = self.metric_iou_val.compute()
247 | for key, value in zip(class_names, scores):
248 | self.logger.experiment.add_scalar('val_iou_' + key, value, global_step=self.training_step_count)
249 | self.metric_iou_val.reset()
250 |
251 | if not is_train:
252 | scores = self.metric_panoptic_val.compute()
253 | for key, value in scores.items():
254 | for instance_name, score in zip(['background', 'vehicles'], value):
255 | if instance_name != 'background':
256 | self.logger.experiment.add_scalar(f'val_{key}_{instance_name}', score.item(),
257 | global_step=self.training_step_count)
258 | self.metric_panoptic_val.reset()
259 |
260 | # self.logger.experiment.add_scalar('segmentation_weight',
261 | # 1 / (torch.exp(self.model.segmentation_weight)),
262 | # global_step=self.training_step_count)
263 | # self.logger.experiment.add_scalar('centerness_weight',
264 | # 1 / (2 * torch.exp(self.model.centerness_weight)),
265 | # global_step=self.training_step_count)
266 | # self.logger.experiment.add_scalar('offset_weight', 1 / (2 * torch.exp(self.model.offset_weight)),
267 | # global_step=self.training_step_count)
268 | # if self.cfg.INSTANCE_FLOW.ENABLED:
269 | # self.logger.experiment.add_scalar('flow_weight', 1 / (2 * torch.exp(self.model.flow_weight)),
270 | # global_step=self.training_step_count)
271 |
272 | # self.logger.experiment.add_scalar('reconstruction_weight', 1 / (2 * torch.exp(self.model.reconstruction_weight)),
273 | # global_step=self.training_step_count)
274 | # self.logger.experiment.add_scalar('kl_y_weight', 1 / (2 * torch.exp(self.model.kl_y_weight)),
275 | # global_step=self.training_step_count)
276 | # self.logger.experiment.add_scalar('kl_z_weight', 1 / (2 * torch.exp(self.model.kl_z_weight)),
277 | # global_step=self.training_step_count)
278 |
279 |
280 | def training_epoch_end(self, step_outputs):
281 | self.shared_epoch_end(step_outputs, True)
282 |
283 | def validation_epoch_end(self, step_outputs):
284 | self.shared_epoch_end(step_outputs, False)
285 |
286 | def configure_optimizers(self):
287 | params = self.model.parameters()
288 | optimizer = torch.optim.Adam(
289 | params, lr=self.cfg.OPTIMIZER.LR, weight_decay=self.cfg.OPTIMIZER.WEIGHT_DECAY
290 | )
291 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60*self.len_loader,80*self.len_loader], gamma=0.1)
292 |
293 | return optimizer
294 |
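# A minimal sketch (a hypothetical module, not used by TrainingModule above) of the
# learned uncertainty weighting that the commented-out expressions in shared_step
# refer to: each task keeps a log-variance parameter s, its loss is scaled by
# 1 / (2 * exp(s)), and an extra 0.5 * s term keeps s from growing without bound.
class _UncertaintyWeighting(nn.Module):
    def __init__(self, task_names):
        super().__init__()
        self.log_vars = nn.ParameterDict(
            {name: nn.Parameter(torch.tensor(0.0)) for name in task_names}
        )

    def forward(self, losses):
        # losses: dict mapping task name -> scalar loss tensor
        total = 0.0
        for name, value in losses.items():
            s = self.log_vars[name]
            total = total + value / (2 * torch.exp(s)) + 0.5 * s
        return total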
--------------------------------------------------------------------------------