├── stretchbev
│   ├── utils
│   │   ├── lyft_splits.py
│   │   ├── network.py
│   │   ├── geometry.py
│   │   └── visualisation.py
│   ├── layers
│   │   ├── convolutions.py
│   │   └── temporal.py
│   ├── models
│   │   ├── future_prediction.py
│   │   ├── distributions.py
│   │   ├── temporal_model.py
│   │   ├── decoder.py
│   │   ├── encoder.py
│   │   ├── model_utils.py
│   │   ├── res_models.py
│   │   └── srvp_models.py
│   ├── configs
│   │   └── stretchbev.yml
│   ├── losses.py
│   ├── config.py
│   └── metrics.py
├── environment.yml
├── LICENSE
├── train.py
├── README.md
├── evaluate.py
└── visualise.py
--------------------------------------------------------------------------------
/stretchbev/configs/stretchbev.yml:
--------------------------------------------------------------------------------
TAG: 'baseline'

GPUS: [0, 1, 2, 3]

BATCHSIZE: 2
PRECISION: 16

TIME_RECEPTIVE_FIELD: 3
N_FUTURE_FRAMES: 4

EPOCHS: 20

MODEL:
  BN_MOMENTUM: 0.05


PRETRAINED:
  LOAD_WEIGHTS: True
  PATH: './static_lift_splat_setting.ckpt'


INSTANCE_FLOW:
  ENABLED: True

OPTIMIZER:
  LR: 1e-4

N_WORKERS: 4
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: fiery
channels:
  - defaults
  - conda-forge
  - pytorch
dependencies:
  - python=3.7.10
  - pytorch=1.7.0
  - torchvision=0.8.1
  - cudatoolkit=10.1
  - numpy=1.19.2
  - scipy=1.5.2
  - pillow=8.0.1
  - tqdm=4.50.2
  - pytorch-lightning=1.2.5
  - efficientnet-pytorch=0.7.0
  - fvcore=0.1.2.post20201122
  - pip=21.0.1
  - pip:
    - nuscenes-devkit==1.1.0
    - lyft-dataset-sdk==0.0.8
    - opencv-python==4.5.1.48
    - moviepy==1.0.3
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Wayve Technologies Limited

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/stretchbev/utils/lyft_splits.py:
--------------------------------------------------------------------------------
TRAIN_LYFT_INDICES = [1, 3, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16,
                      17, 18, 19, 20, 21, 23, 24, 27, 28, 29, 30, 31, 32,
                      33, 35, 36, 37, 39, 41, 43, 44, 45, 46, 47, 48, 49,
                      50, 51, 52, 53, 55, 56, 59, 60, 62, 63, 65, 68, 69,
                      70, 71, 72, 73, 74, 75, 76, 78, 79, 81, 82, 83, 84,
                      86, 87, 88, 89, 93, 95, 97, 98, 99, 103, 104, 107, 108,
                      109, 110, 111, 113, 114, 115, 116, 117, 118, 119, 121, 122, 124,
                      127, 128, 130, 131, 132, 134, 135, 136, 137, 138, 139, 143, 144,
                      146, 147, 148, 149, 150, 151, 152, 153, 154, 156, 157, 158, 159,
                      161, 162, 165, 166, 167, 171, 172, 173, 174, 175, 176, 177, 178,
                      179]

VAL_LYFT_INDICES = [0, 2, 4, 13, 22, 25, 26, 34, 38, 40, 42, 54, 57,
                    58, 61, 64, 66, 67, 77, 80, 85, 90, 91, 92, 94, 96,
                    100, 101, 102, 105, 106, 112, 120, 123, 125, 126, 129, 133, 140,
                    141, 142, 145, 155, 160, 163, 164, 168, 169, 170]
--------------------------------------------------------------------------------
/stretchbev/utils/network.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision


def pack_sequence_dim(x):
    b, s = x.shape[:2]
    return x.view(b * s, *x.shape[2:])


def unpack_sequence_dim(x, b, s):
    return x.view(b, s, *x.shape[1:])


def preprocess_batch(batch, device, unsqueeze=False):
    for key, value in batch.items():
        if key != 'sample_token':
            batch[key] = value.to(device)
            if unsqueeze:
                batch[key] = batch[key].unsqueeze(0)


def set_module_grad(module, requires_grad=False):
    for p in module.parameters():
        p.requires_grad = requires_grad


def set_bn_momentum(model, momentum=0.1):
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.momentum = momentum


class NormalizeInverse(torchvision.transforms.Normalize):
    # https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/8
    def __init__(self, mean, std):
        mean = torch.as_tensor(mean)
        std = torch.as_tensor(std)
        std_inv = 1 / (std + 1e-7)
        mean_inv = -mean * std_inv
        super().__init__(mean=mean_inv, std=std_inv)

    def __call__(self, tensor):
        return super().__call__(tensor.clone())
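As a quick illustration of the `pack_sequence_dim` / `unpack_sequence_dim` helpers in `network.py` above (a minimal sketch; the tensor sizes are assumed for illustration, not taken from this repository):

```python
import torch

from stretchbev.utils.network import pack_sequence_dim, unpack_sequence_dim

x = torch.randn(2, 5, 64, 200, 200)         # (batch, seq, C, H, W)
flat = pack_sequence_dim(x)                 # (10, 64, 200, 200): time folded into batch
restored = unpack_sequence_dim(flat, 2, 5)  # back to (2, 5, 64, 200, 200)
assert torch.equal(x, restored)
```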
--------------------------------------------------------------------------------
/stretchbev/models/future_prediction.py:
--------------------------------------------------------------------------------
import torch

from stretchbev.layers.convolutions import Bottleneck
from stretchbev.layers.temporal import SpatialGRU


class FuturePrediction(torch.nn.Module):
    def __init__(self, in_channels, latent_dim, n_gru_blocks=3, n_res_layers=3):
        super().__init__()
        self.n_gru_blocks = n_gru_blocks

        # Convolutional recurrent model with z_t as an initial hidden state and inputs the sample
        # from the probabilistic model. The architecture of the model is:
        # [Spatial GRU - [Bottleneck] x n_res_layers] x n_gru_blocks
        self.spatial_grus = []
        self.res_blocks = []

        for i in range(self.n_gru_blocks):
            gru_in_channels = latent_dim if i == 0 else in_channels
            self.spatial_grus.append(SpatialGRU(gru_in_channels, in_channels))
            self.res_blocks.append(torch.nn.Sequential(*[Bottleneck(in_channels)
                                                         for _ in range(n_res_layers)]))

        self.spatial_grus = torch.nn.ModuleList(self.spatial_grus)
        self.res_blocks = torch.nn.ModuleList(self.res_blocks)

    def forward(self, x, hidden_state=None):
        # x has shape (b, n_future, c, h, w), hidden_state (b, c, h, w)
        for i in range(self.n_gru_blocks):
            x = self.spatial_grus[i](x, hidden_state, flow=None)
            b, n_future, c, h, w = x.shape

            x = self.res_blocks[i](x.view(b * n_future, c, h, w))
            x = x.view(b, n_future, c, h, w)

        return x
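The comment in `__init__` describes the layout: the first GRU consumes the latent sample (`latent_dim` channels) and every later block works at `in_channels`. A minimal shape sketch (all sizes are assumptions for illustration, and the expected output shape relies on `SpatialGRU` returning one hidden state per time step):

```python
import torch

from stretchbev.models.future_prediction import FuturePrediction

model = FuturePrediction(in_channels=64, latent_dim=32)
z = torch.randn(1, 4, 32, 50, 50)    # one latent sample per future step
s_t = torch.randn(1, 64, 50, 50)     # present state used as the initial hidden state
out = model(z, hidden_state=s_t)
print(out.shape)                     # expected: torch.Size([1, 4, 64, 50, 50])
```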
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import socket
import time

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import DDPPlugin

from stretchbev.config import get_parser, get_cfg
from stretchbev.data import prepare_dataloaders
from stretchbev.trainer import TrainingModule


def main():
    args = get_parser().parse_args()
    cfg = get_cfg(args)

    trainloader, valloader = prepare_dataloaders(cfg)
    model = TrainingModule(cfg.convert_to_dict())
    model.len_loader = len(trainloader)

    if cfg.PRETRAINED.LOAD_WEIGHTS:
        # Load single-image instance segmentation model.
        pretrained_model_weights = torch.load(
            os.path.join('.', cfg.PRETRAINED.PATH), map_location='cpu'
        )['state_dict']

        new_dict = {key: val for (key, val) in pretrained_model_weights.items() if 'decoder' not in key}
        model.load_state_dict(new_dict, strict=False)
        print(f'Loaded single-image model weights from {cfg.PRETRAINED.PATH}')

    save_dir = os.path.join(
        cfg.LOG_DIR, time.strftime('%d%B%Yat%H:%M:%S%Z') + '_' + socket.gethostname() + '_' + cfg.TAG
    )
    checkpoint_callback = ModelCheckpoint(dirpath='weights', filename='stretchbev-{epoch:02d}', save_top_k=-1)
    tb_logger = pl.loggers.TensorBoardLogger(save_dir=save_dir)
    trainer = pl.Trainer(
        gpus=cfg.GPUS,
        accelerator='ddp',
        precision=cfg.PRECISION,
        sync_batchnorm=True,
        gradient_clip_val=cfg.GRAD_NORM_CLIP,
        max_epochs=cfg.EPOCHS,
        weights_summary='full',
        logger=tb_logger,
        log_every_n_steps=cfg.LOGGING_INTERVAL,
        plugins=DDPPlugin(find_unused_parameters=True),
        profiler='simple',
        callbacks=[checkpoint_callback]
    )
    trainer.fit(model, trainloader, valloader)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/stretchbev/models/distributions.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from stretchbev.layers.convolutions import Bottleneck


class DistributionModule(nn.Module):
    """
    A convolutional net that parametrises a diagonal Gaussian distribution.
    """

    def __init__(self, in_channels, latent_dim, min_log_sigma, max_log_sigma):
        super().__init__()
        self.compress_dim = in_channels // 2
        self.latent_dim = latent_dim
        self.min_log_sigma = min_log_sigma
        self.max_log_sigma = max_log_sigma

        self.encoder = DistributionEncoder(
            in_channels,
            self.compress_dim,
        )
        self.last_conv = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Conv2d(self.compress_dim, out_channels=2 * self.latent_dim, kernel_size=1)
        )

    def forward(self, s_t):
        b, s = s_t.shape[:2]
        assert s == 1
        encoding = self.encoder(s_t[:, 0])

        mu_log_sigma = self.last_conv(encoding).view(b, 1, 2 * self.latent_dim)
        mu = mu_log_sigma[:, :, :self.latent_dim]
        log_sigma = mu_log_sigma[:, :, self.latent_dim:]

        # clip the log_sigma value for numerical stability
        log_sigma = torch.clamp(log_sigma, self.min_log_sigma, self.max_log_sigma)
        return mu, log_sigma


class DistributionEncoder(nn.Module):
    """Encodes s_t or (s_t, y_{t+1}, ..., y_{t+H}).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.model = nn.Sequential(
            Bottleneck(in_channels, out_channels=out_channels, downsample=True),
            Bottleneck(out_channels, out_channels=out_channels, downsample=True),
            Bottleneck(out_channels, out_channels=out_channels, downsample=True),
            Bottleneck(out_channels, out_channels=out_channels, downsample=True),
        )

    def forward(self, s_t):
        return self.model(s_t)
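`DistributionModule` returns the Gaussian parameters rather than a sample. Code that draws a latent from these parameters would typically use the reparameterisation trick; the helper below is a hedged sketch for illustration, not a function defined in this repository:

```python
import torch


def sample_latent(mu, log_sigma):
    # z = mu + sigma * eps with eps ~ N(0, I), keeping the sample
    # differentiable with respect to mu and log_sigma.
    return mu + torch.exp(log_sigma) * torch.randn_like(mu)
```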
--------------------------------------------------------------------------------
/stretchbev/models/temporal_model.py:
--------------------------------------------------------------------------------
import torch.nn as nn

from stretchbev.layers.temporal import Bottleneck3D, TemporalBlock


class TemporalModel(nn.Module):
    def __init__(
            self, in_channels, receptive_field, input_shape, start_out_channels=64, extra_in_channels=0,
            n_spatial_layers_between_temporal_layers=0, use_pyramid_pooling=True):
        super().__init__()
        self.receptive_field = receptive_field
        n_temporal_layers = receptive_field - 1

        h, w = input_shape
        modules = []

        block_in_channels = in_channels
        block_out_channels = start_out_channels

        for _ in range(n_temporal_layers):
            if use_pyramid_pooling:
                use_pyramid_pooling = True
                pool_sizes = [(2, h, w)]
            else:
                use_pyramid_pooling = False
                pool_sizes = None
            temporal = TemporalBlock(
                block_in_channels,
                block_out_channels,
                use_pyramid_pooling=use_pyramid_pooling,
                pool_sizes=pool_sizes,
            )
            spatial = [
                Bottleneck3D(block_out_channels, block_out_channels, kernel_size=(1, 3, 3))
                for _ in range(n_spatial_layers_between_temporal_layers)
            ]
            temporal_spatial_layers = nn.Sequential(temporal, *spatial)
            modules.extend(temporal_spatial_layers)

            block_in_channels = block_out_channels
            block_out_channels += extra_in_channels

        self.out_channels = block_in_channels

        self.model = nn.Sequential(*modules)

    def forward(self, x):
        # Reshape input tensor to (batch, C, time, H, W)
        x = x.permute(0, 2, 1, 3, 4)
        x = self.model(x)
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        return x[:, (self.receptive_field - 1):]


class TemporalModelIdentity(nn.Module):
    def __init__(self, in_channels, receptive_field):
        super().__init__()
        self.receptive_field = receptive_field
        self.out_channels = in_channels

    def forward(self, x):
        return x[:, (self.receptive_field - 1):]
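Both temporal models trim the first `receptive_field - 1` frames, so a sequence of `T` input frames yields `T - (receptive_field - 1)` output frames. A worked shape check with the identity variant (tensor sizes are assumed for illustration; `TIME_RECEPTIVE_FIELD: 3` matches the config above):

```python
import torch

from stretchbev.models.temporal_model import TemporalModelIdentity

model = TemporalModelIdentity(in_channels=64, receptive_field=3)
x = torch.randn(1, 3, 64, 200, 200)  # 3 context frames
out = model(x)
print(out.shape)                     # torch.Size([1, 1, 64, 200, 200]): only the present frame remains
```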
--------------------------------------------------------------------------------
/stretchbev/losses.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialRegressionLoss(nn.Module):
    def __init__(self, norm, ignore_index=255, future_discount=1.0):
        super(SpatialRegressionLoss, self).__init__()
        self.norm = norm
        self.ignore_index = ignore_index
        self.future_discount = future_discount

        if norm == 1:
            self.loss_fn = F.l1_loss
        elif norm == 2:
            self.loss_fn = F.mse_loss
        else:
            raise ValueError(f'Expected norm 1 or 2, but got norm={norm}')

    def forward(self, prediction, target):
        assert len(prediction.shape) == 5, 'Must be a 5D tensor'
        # ignore_index is the same across all channels
        mask = target[:, :, :1] != self.ignore_index
        if mask.sum() == 0:
            return prediction.new_zeros(1)[0].float()

        loss = self.loss_fn(prediction, target, reduction='none')

        # Sum channel dimension
        loss = torch.sum(loss, dim=-3, keepdims=True)

        seq_len = loss.shape[1]
        future_discounts = self.future_discount ** torch.arange(seq_len, device=loss.device, dtype=loss.dtype)
        future_discounts = future_discounts.view(1, seq_len, 1, 1, 1)
        loss = loss * future_discounts

        return loss[mask].mean()


class SegmentationLoss(nn.Module):
    def __init__(self, class_weights, ignore_index=255, use_top_k=False, top_k_ratio=1.0, future_discount=1.0):
        super().__init__()
        self.class_weights = class_weights
        self.ignore_index = ignore_index
        self.use_top_k = use_top_k
        self.top_k_ratio = top_k_ratio
        self.future_discount = future_discount

    def forward(self, prediction, target):
        if target.shape[-3] != 1:
            raise ValueError('segmentation label must be an index-label with channel dimension = 1.')
        b, s, c, h, w = prediction.shape

        prediction = prediction.view(b * s, c, h, w)
        target = target.view(b * s, h, w)
        loss = F.cross_entropy(
            prediction,
            target,
            ignore_index=self.ignore_index,
            reduction='none',
            weight=self.class_weights.to(target.device),
        )

        loss = loss.view(b, s, h, w)

        future_discounts = self.future_discount ** torch.arange(s, device=loss.device, dtype=loss.dtype)
        future_discounts = future_discounts.view(1, s, 1, 1)
        loss = loss * future_discounts

        loss = loss.view(b, s, -1)
        if self.use_top_k:
            # Penalises the top-k hardest pixels
            k = int(self.top_k_ratio * loss.shape[2])
            loss, _ = torch.sort(loss, dim=2, descending=True)
            loss = loss[:, :, :k]

        return torch.mean(loss)


class ProbabilisticLoss(nn.Module):
    def forward(self, output):
        present_mu = output['present_mu']
        present_log_sigma = output['present_log_sigma']
        future_mu = output['future_mu']
        future_log_sigma = output['future_log_sigma']

        var_future = torch.exp(2 * future_log_sigma)
        var_present = torch.exp(2 * present_log_sigma)
        kl_div = (
            present_log_sigma - future_log_sigma - 0.5
            + (var_future + (future_mu - present_mu) ** 2) / (2 * var_present)
        )

        kl_loss = torch.mean(torch.sum(kl_div, dim=-1))

        return kl_loss
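For reference, the `kl_div` expression in `ProbabilisticLoss` is the closed-form KL divergence between two diagonal Gaussians, computed per latent dimension with the future (posterior) parameters first. Writing the future parameters as $\mu_f, \sigma_f$ and the present (prior) parameters as $\mu_p, \sigma_p$:

$$
D_{\mathrm{KL}}\left(\mathcal{N}(\mu_f,\sigma_f^2)\,\middle\|\,\mathcal{N}(\mu_p,\sigma_p^2)\right)
= \log\frac{\sigma_p}{\sigma_f} - \frac{1}{2} + \frac{\sigma_f^2 + (\mu_f-\mu_p)^2}{2\sigma_p^2}
$$

Since the network predicts $\log\sigma$, the log-ratio term appears in the code as `present_log_sigma - future_log_sigma`, and each variance as $\exp(2\log\sigma)$.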
--------------------------------------------------------------------------------
/stretchbev/models/decoder.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from torchvision.models.resnet import resnet18

from stretchbev.layers.convolutions import UpsamplingAdd


class Decoder(nn.Module):
    def __init__(self, in_channels, n_classes, predict_future_flow):
        super().__init__()
        backbone = resnet18(pretrained=False, zero_init_residual=True)
        self.first_conv = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = backbone.bn1
        self.relu = backbone.relu

        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.predict_future_flow = predict_future_flow

        shared_out_channels = in_channels
        self.up3_skip = UpsamplingAdd(256, 128, scale_factor=2)
        self.up2_skip = UpsamplingAdd(128, 64, scale_factor=2)
        self.up1_skip = UpsamplingAdd(64, shared_out_channels, scale_factor=2)

        self.segmentation_head = nn.Sequential(
            nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(shared_out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(shared_out_channels, n_classes, kernel_size=1, padding=0),
        )
        self.instance_offset_head = nn.Sequential(
            nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(shared_out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(shared_out_channels, 2, kernel_size=1, padding=0),
        )
        self.instance_center_head = nn.Sequential(
            nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(shared_out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(shared_out_channels, 1, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

        if self.predict_future_flow:
            self.instance_future_head = nn.Sequential(
                nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(shared_out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(shared_out_channels, 2, kernel_size=1, padding=0),
            )

    def forward(self, x):
        b, s, c, h, w = x.shape
        x = x.view(b * s, c, h, w)

        # (H, W)
        skip_x = {'1': x}
        x = self.first_conv(x)
        x = self.bn1(x)
        x = self.relu(x)

        # (H/4, W/4)
        x = self.layer1(x)
        skip_x['2'] = x
        x = self.layer2(x)
        skip_x['3'] = x

        # (H/8, W/8)
        x = self.layer3(x)

        # First upsample to (H/4, W/4)
        x = self.up3_skip(x, skip_x['3'])

        # Second upsample to (H/2, W/2)
        x = self.up2_skip(x, skip_x['2'])

        # Third upsample to (H, W)
        x = self.up1_skip(x, skip_x['1'])

        segmentation_output = self.segmentation_head(x)
        instance_center_output = self.instance_center_head(x)
        instance_offset_output = self.instance_offset_head(x)
        instance_future_output = self.instance_future_head(x) if self.predict_future_flow else None
        return {
            'segmentation': segmentation_output.view(b, s, *segmentation_output.shape[1:]),
            'instance_center': instance_center_output.view(b, s, *instance_center_output.shape[1:]),
            'instance_offset': instance_offset_output.view(b, s, *instance_offset_output.shape[1:]),
            'instance_flow': instance_future_output.view(b, s, *instance_future_output.shape[1:])
            if instance_future_output is not None else None,
        }
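A minimal shape sketch of the decoder heads (all sizes are assumed for illustration; the 200x200 grid matches the LIFT bounds in `stretchbev/config.py`):

```python
import torch

from stretchbev.models.decoder import Decoder

decoder = Decoder(in_channels=64, n_classes=2, predict_future_flow=True)
x = torch.randn(1, 7, 64, 200, 200)    # (b, s, c, h, w) BEV features
out = decoder(x)
print(out['segmentation'].shape)       # (1, 7, 2, 200, 200): per-class logits
print(out['instance_center'].shape)    # (1, 7, 1, 200, 200): centerness heatmap
print(out['instance_offset'].shape)    # (1, 7, 2, 200, 200): offset to instance center
print(out['instance_flow'].shape)      # (1, 7, 2, 200, 200): future flow
```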
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# StretchBEV: Stretching Future Instance Prediction Spatially and Temporally (ECCV 2022)


[![report](https://img.shields.io/badge/ArXiv-Paper-red)](https://arxiv.org/abs/2203.13641)
[![report](https://img.shields.io/badge/Project-Page-blue)](https://kuis-ai.github.io/stretchbev/)
[![report](https://img.shields.io/badge/Pretrained-Models-yellow)](https://github.com/kaanakan/stretchbev/releases/tag/v1.0)
[![report](https://img.shields.io/badge/Presentation-Video-brightgreen)](https://www.youtube.com/watch?v=2SiUNs6BMVk)



> [**StretchBEV: Stretching Future Instance Prediction Spatially and Temporally**](https://arxiv.org/abs/2203.13641),
> [Adil Kaan Akan](https://kaanakan.github.io),
> [Fatma Guney](https://mysite.ku.edu.tr/fguney/),
> *European Conference on Computer Vision (ECCV), 2022*


## Features

StretchBEV is a future instance prediction network in Bird's-eye view representation. It learns temporal dynamics in a latent space through stochastic residual updates at each time step. By sampling from a learned distribution at each time step, we obtain more diverse future predictions that are also more accurate than previous work, stretching both spatially to farther regions of the scene and temporally over longer time horizons.


## Requirements

All models were trained with Python 3.7.10 and PyTorch 1.7.0.

A list of required Python packages is available in the `environment.yml` file (e.g., to be installed with `conda env create -f environment.yml`).



## Datasets

For preparation of the datasets, we followed [FIERY](https://github.com/wayveai/fiery). Please follow [this link](https://github.com/wayveai/fiery/blob/master/DATASET.md) if you want to construct the datasets.


## Training

To train the model on NuScenes:

- First, you need to download [`static_lift_splat_setting.ckpt`](https://github.com/wayveai/fiery/releases/download/v1.0/static_lift_splat_setting.ckpt) and copy it to this directory.
- Run `python train.py --config-file stretchbev/configs/stretchbev.yml DATASET.DATAROOT ${NUSCENES_DATAROOT}`.

This will train the model on 4 GPUs, each with a batch of size 2. To train on a single GPU add the flag `GPUS [0]`, and to change the batch size use the flag `BATCHSIZE ${DESIRED_BATCHSIZE}`.


## Evaluation

To evaluate a trained model on NuScenes:

- Download [pre-trained weights](https://github.com/wayveai/fiery/releases/download/v1.0/stretchbev.ckpt).
- Run `python evaluate.py --checkpoint ${CHECKPOINT_PATH} --dataroot ${NUSCENES_DATAROOT}`.

### Pretrained weights

You can download the pretrained weights from the releases of this repository or the links below.

[Normal setting weights](https://github.com/wayveai/fiery/releases/download/v1.0/stretchbev.ckpt)

[Fishing setting weights](https://github.com/wayveai/fiery/releases/download/v1.0/stretchbev_fishing.ckpt)



## How to Cite

Please cite the paper if you benefit from our paper or the repository:

```
@InProceedings{Akan2022ECCV,
    author = {Akan, Adil Kaan and G\"uney, Fatma},
    title = {StretchBEV: Stretching Future Instance Prediction Spatially and Temporally},
    journal = {European Conference on Computer Vision (ECCV)},
    year = {2022},
}
```

## Acknowledgments

We would like to thank the FIERY and SRVP authors for making their repositories public. This repository contains several code segments from [FIERY's repository](https://github.com/wayveai/fiery) and [SRVP's repository](https://github.com/edouardelasalles/srvp). We appreciate the efforts of [Berkay Ugur Senocak](https://github.com/4turkuaz) for cleaning the code before release.
--------------------------------------------------------------------------------
/stretchbev/models/encoder.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from efficientnet_pytorch import EfficientNet

from stretchbev.layers.convolutions import UpsamplingConcat


class Encoder(nn.Module):
    def __init__(self, cfg, D):
        super().__init__()
        self.D = D
        self.C = cfg.OUT_CHANNELS
        self.use_depth_distribution = cfg.USE_DEPTH_DISTRIBUTION
        self.downsample = cfg.DOWNSAMPLE
        self.version = cfg.NAME.split('-')[1]

        self.backbone = EfficientNet.from_pretrained(cfg.NAME)
        self.delete_unused_layers()

        if self.downsample == 16:
            if self.version == 'b0':
                upsampling_in_channels = 320 + 112
            elif self.version == 'b4':
                upsampling_in_channels = 448 + 160
            upsampling_out_channels = 512
        elif self.downsample == 8:
            if self.version == 'b0':
                upsampling_in_channels = 112 + 40
            elif self.version == 'b4':
                upsampling_in_channels = 160 + 56
            upsampling_out_channels = 128
        else:
            raise ValueError(f'Downsample factor {self.downsample} not handled.')

        self.upsampling_layer = UpsamplingConcat(upsampling_in_channels, upsampling_out_channels)
        if self.use_depth_distribution:
            self.depth_layer = nn.Conv2d(upsampling_out_channels, self.C + self.D, kernel_size=1, padding=0)
        else:
            self.depth_layer = nn.Conv2d(upsampling_out_channels, self.C, kernel_size=1, padding=0)

    def delete_unused_layers(self):
        indices_to_delete = []
        for idx in range(len(self.backbone._blocks)):
            if self.downsample == 8:
                if self.version == 'b0' and idx > 10:
                    indices_to_delete.append(idx)
                if self.version == 'b4' and idx > 21:
                    indices_to_delete.append(idx)

        for idx in reversed(indices_to_delete):
            del self.backbone._blocks[idx]

        del self.backbone._conv_head
        del self.backbone._bn1
        del self.backbone._avg_pooling
        del self.backbone._dropout
        del self.backbone._fc

    def get_features(self, x):
        # Adapted from https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py#L231
        endpoints = dict()

        # Stem
        x = self.backbone._swish(self.backbone._bn0(self.backbone._conv_stem(x)))
        prev_x = x

        # Blocks
        for idx, block in enumerate(self.backbone._blocks):
            drop_connect_rate = self.backbone._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self.backbone._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
            if prev_x.size(2) > x.size(2):
                endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
            prev_x = x

            if self.downsample == 8:
                if self.version == 'b0' and idx == 10:
                    break
                if self.version == 'b4' and idx == 21:
                    break

        # Head
        endpoints['reduction_{}'.format(len(endpoints) + 1)] = x

        if self.downsample == 16:
            input_1, input_2 = endpoints['reduction_5'], endpoints['reduction_4']
        elif self.downsample == 8:
            input_1, input_2 = endpoints['reduction_4'], endpoints['reduction_3']

        x = self.upsampling_layer(input_1, input_2)
        return x

    def forward(self, x):
        x = self.get_features(x)  # get feature vector

        x = self.depth_layer(x)  # feature and depth head

        if self.use_depth_distribution:
            depth = x[:, :self.D].softmax(dim=1)
            x = depth.unsqueeze(1) * x[:, self.D:(self.D + self.C)].unsqueeze(2)  # outer product depth and features
        else:
            x = x.unsqueeze(2).repeat(1, 1, self.D, 1, 1)

        return x
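The depth-weighted outer product in `forward` is the lifting step of Lift-Splat-style encoders: the first `D` channels of the 1x1 head are treated as depth logits and the remaining `C` channels as features, and each pixel's feature vector is scaled by its depth distribution. Per pixel,

$$
o_{d,c} = \alpha_d \, f_c \quad\text{with}\quad \alpha = \operatorname{softmax}(\ell), \quad \ell \in \mathbb{R}^{D}, \; f \in \mathbb{R}^{C},
$$

giving a `(b, C, D, h, w)` tensor, i.e. one `C`-dimensional feature per depth bin.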
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
import random
from argparse import ArgumentParser

import torch
from tqdm import tqdm

from stretchbev.data import prepare_dataloaders
from stretchbev.metrics import IntersectionOverUnion, PanopticMetric
from stretchbev.trainer import TrainingModule
from stretchbev.utils.instance import predict_instance_segmentation_and_trajectories
from stretchbev.utils.network import preprocess_batch

# 30mx30m, 100mx100m
EVALUATION_RANGES = {
    '30x30': (70, 130),
    '100x100': (0, 200)
}


def eval(checkpoint_path, dataroot, version):
    trainer = TrainingModule.load_from_checkpoint(checkpoint_path, strict=True)
    print(f'Loaded weights from \n {checkpoint_path}')
    trainer.eval()

    device = torch.device('cuda:0')
    trainer.to(device)
    model = trainer.model
    model.eval()

    cfg = model.cfg
    cfg.GPUS = "[0]"
    cfg.BATCHSIZE = 1

    cfg.DATASET.DATAROOT = dataroot
    cfg.DATASET.VERSION = version

    _, valloader = prepare_dataloaders(cfg)

    panoptic_metrics = {}
    iou_metrics = {}
    n_classes = len(cfg.SEMANTIC_SEG.WEIGHTS)
    for key in EVALUATION_RANGES.keys():
        panoptic_metrics[key] = PanopticMetric(n_classes=n_classes, temporally_consistent=True).to(device)
        iou_metrics[key] = IntersectionOverUnion(n_classes).to(device)

    for i, batch in enumerate(tqdm(valloader)):
        preprocess_batch(batch, device)
        image = batch['image']  # [:, 3:]
        intrinsics = batch['intrinsics']  # [:, 3:]
        extrinsics = batch['extrinsics']  # [:, 3:]
        future_egomotion = batch['future_egomotion']  # [:, 3:]

        batch_size = image.shape[0]

        labels, future_distribution_inputs = trainer.prepare_future_labels(batch)

        with torch.no_grad():
            # Evaluate with mean prediction
            output = model.multi_sample_inference(image, intrinsics, extrinsics, future_egomotion, num_samples=1,
                                                  future_distribution_inputs=future_distribution_inputs)[0]
        labels = {k: v[:, 3:] for k, v in labels.items()}

        # Consistent instance seg
        pred_consistent_instance_seg = predict_instance_segmentation_and_trajectories(
            output, compute_matched_centers=False, make_consistent=True
        )

        segmentation_pred = output['segmentation'].detach()
        segmentation_pred = torch.argmax(segmentation_pred, dim=2, keepdims=True)

        for key, grid in EVALUATION_RANGES.items():
            limits = slice(grid[0], grid[1])
            panoptic_metrics[key](pred_consistent_instance_seg[..., limits, limits].contiguous().detach(),
                                  labels['instance'][..., limits, limits].contiguous()
                                  )

            iou_metrics[key](segmentation_pred[..., limits, limits].contiguous(),
                             labels['segmentation'][..., limits, limits].contiguous()
                             )
            # iou_scores = iou_metrics[key].compute()
            # print(iou_scores[1].item())
            # iou_metrics[key].reset()

    results = {}
    for key, grid in EVALUATION_RANGES.items():
        panoptic_scores = panoptic_metrics[key].compute()
        for panoptic_key, value in panoptic_scores.items():
            results[f'{panoptic_key}'] = results.get(f'{panoptic_key}', []) + [100 * value[1].item()]

        iou_scores = iou_metrics[key].compute()
        results['iou'] = results.get('iou', []) + [100 * iou_scores[1].item()]

    for panoptic_key in ['iou', 'pq', 'sq', 'rq']:
        print(panoptic_key)
        print(' & '.join([f'{x:.1f}' for x in results[panoptic_key]]))


if __name__ == '__main__':
    parser = ArgumentParser(description='Fiery evaluation')
    parser.add_argument('--checkpoint', default='./fiery.ckpt', type=str, help='path to checkpoint')
    parser.add_argument('--dataroot', default='./nuscenes', type=str, help='path to the dataset')
    parser.add_argument('--version', default='trainval', type=str, choices=['mini', 'trainval'],
                        help='dataset version')

    args = parser.parse_args()
    torch.manual_seed(0)
    random.seed(0)

    eval(args.checkpoint, args.dataroot, args.version)
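The `EVALUATION_RANGES` indices follow from the BEV grid geometry: with the LIFT bounds in `stretchbev/config.py` (`[-50.0, 50.0, 0.5]`), the 100 m x 100 m grid is 100 / 0.5 = 200 x 200 cells, so `'100x100'` is the full slice (0, 200). A 30 m x 30 m region covers 60 x 60 cells, and centering it in the grid gives (200 - 60) / 2 = 70, hence the slice (70, 130).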
results[f'{panoptic_key}'] = results.get(f'{panoptic_key}', []) + [100 * value[1].item()] 90 | 91 | iou_scores = iou_metrics[key].compute() 92 | results['iou'] = results.get('iou', []) + [100 * iou_scores[1].item()] 93 | 94 | for panoptic_key in ['iou', 'pq', 'sq', 'rq']: 95 | print(panoptic_key) 96 | print(' & '.join([f'{x:.1f}' for x in results[panoptic_key]])) 97 | 98 | 99 | if __name__ == '__main__': 100 | parser = ArgumentParser(description='Fiery evaluation') 101 | parser.add_argument('--checkpoint', default='./fiery.ckpt', type=str, help='path to checkpoint') 102 | parser.add_argument('--dataroot', default='./nuscenes', type=str, help='path to the dataset') 103 | parser.add_argument('--version', default='trainval', type=str, choices=['mini', 'trainval'], 104 | help='dataset version') 105 | 106 | args = parser.parse_args() 107 | torch.manual_seed(0) 108 | random.seed(0) 109 | 110 | eval(args.checkpoint, args.dataroot, args.version) 111 | -------------------------------------------------------------------------------- /stretchbev/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from fvcore.common.config import CfgNode as _CfgNode 3 | 4 | 5 | def convert_to_dict(cfg_node, key_list=[]): 6 | """Convert a config node to dictionary.""" 7 | _VALID_TYPES = {tuple, list, str, int, float, bool} 8 | if not isinstance(cfg_node, _CfgNode): 9 | if type(cfg_node) not in _VALID_TYPES: 10 | print( 11 | 'Key {} with value {} is not a valid type; valid types: {}'.format( 12 | '.'.join(key_list), type(cfg_node), _VALID_TYPES 13 | ), 14 | ) 15 | return cfg_node 16 | else: 17 | cfg_dict = dict(cfg_node) 18 | for k, v in cfg_dict.items(): 19 | cfg_dict[k] = convert_to_dict(v, key_list + [k]) 20 | return cfg_dict 21 | 22 | 23 | class CfgNode(_CfgNode): 24 | """Remove once https://github.com/rbgirshick/yacs/issues/19 is merged.""" 25 | 26 | def convert_to_dict(self): 27 | return convert_to_dict(self) 28 | 29 | 30 | CN = CfgNode 31 | 32 | _C = CN() 33 | _C.LOG_DIR = 'tensorboard_logs' 34 | _C.TAG = 'default' 35 | 36 | _C.GPUS = [0] # which gpus to use 37 | _C.PRECISION = 32 # 16bit or 32bit 38 | _C.BATCHSIZE = 3 39 | _C.EPOCHS = 20 40 | 41 | _C.N_WORKERS = 5 42 | _C.VIS_INTERVAL = 5000 43 | _C.LOGGING_INTERVAL = 500 44 | 45 | _C.PRETRAINED = CN() 46 | _C.PRETRAINED.LOAD_WEIGHTS = False 47 | _C.PRETRAINED.PATH = '' 48 | 49 | _C.DATASET = CN() 50 | _C.DATASET.DATAROOT = './nuscenes/' 51 | _C.DATASET.VERSION = 'trainval' 52 | _C.DATASET.NAME = 'nuscenes' 53 | _C.DATASET.IGNORE_INDEX = 255 # Ignore index when creating flow/offset labels 54 | _C.DATASET.FILTER_INVISIBLE_VEHICLES = True # Filter vehicles that are not visible from the cameras 55 | 56 | _C.TIME_RECEPTIVE_FIELD = 3 # how many frames of temporal context (1 for single timeframe) 57 | _C.N_FUTURE_FRAMES = 4 # how many time steps into the future to predict 58 | 59 | _C.IMAGE = CN() 60 | _C.IMAGE.FINAL_DIM = (224, 480) 61 | _C.IMAGE.RESIZE_SCALE = 0.3 62 | _C.IMAGE.TOP_CROP = 46 63 | _C.IMAGE.ORIGINAL_HEIGHT = 900 # Original input RGB camera height 64 | _C.IMAGE.ORIGINAL_WIDTH = 1600 # Original input RGB camera width 65 | _C.IMAGE.NAMES = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT'] 66 | 67 | _C.LIFT = CN() # image to BEV lifting 68 | _C.LIFT.X_BOUND = [-50.0, 50.0, 0.5] # Forward 69 | _C.LIFT.Y_BOUND = [-50.0, 50.0, 0.5] # Sides 70 | _C.LIFT.Z_BOUND = [-10.0, 10.0, 20.0] # Height 71 | _C.LIFT.D_BOUND = [2.0, 50.0, 1.0] 72 | 73 | _C.MODEL 
= CN()
74 | 
75 | _C.MODEL.ENCODER = CN()
76 | _C.MODEL.ENCODER.DOWNSAMPLE = 8
77 | _C.MODEL.ENCODER.NAME = 'efficientnet-b4'
78 | _C.MODEL.ENCODER.OUT_CHANNELS = 64
79 | _C.MODEL.ENCODER.USE_DEPTH_DISTRIBUTION = True
80 | 
81 | _C.MODEL.SMALL_ENCODER = CN()
82 | _C.MODEL.SMALL_ENCODER.FILTER_SIZE = 64
83 | _C.MODEL.SMALL_ENCODER.SKIPCO = True
84 | 
85 | 
86 | _C.MODEL.FIRST_STATE = CN()
87 | _C.MODEL.FIRST_STATE.NUM_LAYERS = 3
88 | 
89 | _C.MODEL.DYNAMICS = CN()
90 | _C.MODEL.DYNAMICS.NUM_LAYERS = 3
91 | 
92 | 
93 | 
94 | 
95 | _C.MODEL.DISTRIBUTION = CN()
96 | _C.MODEL.DISTRIBUTION.LATENT_DIM = 32
97 | _C.MODEL.DISTRIBUTION.MIN_LOG_SIGMA = -5.0
98 | _C.MODEL.DISTRIBUTION.MAX_LOG_SIGMA = 5.0
99 | _C.MODEL.DISTRIBUTION.POSTERIOR_LAYERS = 2
100 | _C.MODEL.DISTRIBUTION.PRIOR_LAYERS = 2
101 | 
102 | _C.MODEL.FUTURE_PRED = CN()
103 | _C.MODEL.FUTURE_PRED.N_GRU_BLOCKS = 3
104 | _C.MODEL.FUTURE_PRED.N_RES_LAYERS = 3
105 | 
106 | _C.MODEL.DECODER = CN()
107 | 
108 | _C.MODEL.BN_MOMENTUM = 0.1
109 | _C.MODEL.SUBSAMPLE = False  # Subsample frames for Lyft
110 | 
111 | _C.SEMANTIC_SEG = CN()
112 | _C.SEMANTIC_SEG.WEIGHTS = [1.0, 2.0]  # per-class cross-entropy weights (background, dynamic)
113 | _C.SEMANTIC_SEG.USE_TOP_K = True  # backprop only top-k hardest pixels
114 | _C.SEMANTIC_SEG.TOP_K_RATIO = 0.25
115 | 
116 | _C.INSTANCE_SEG = CN()
117 | 
118 | _C.INSTANCE_FLOW = CN()
119 | _C.INSTANCE_FLOW.ENABLED = True
120 | 
121 | _C.PROBABILISTIC = CN()
122 | _C.PROBABILISTIC.ENABLED = True  # learn a distribution over futures
123 | _C.PROBABILISTIC.WEIGHT = 100.0
124 | _C.PROBABILISTIC.FUTURE_DIM = 6  # number of dimensions added (future flow, future centerness, offset, seg)
125 | 
126 | _C.FUTURE_DISCOUNT = 0.95
127 | 
128 | _C.OPTIMIZER = CN()
129 | _C.OPTIMIZER.LR = 3e-4
130 | _C.OPTIMIZER.WEIGHT_DECAY = 1e-7
131 | _C.GRAD_NORM_CLIP = 5
132 | 
133 | 
134 | def get_parser():
135 |     parser = argparse.ArgumentParser(description='Stretchbev training')
136 |     # TODO: remove below?
137 |     parser.add_argument('--config-file', default='', metavar='FILE', help='path to config file')
138 |     parser.add_argument(
139 |         'opts', help='Modify config options using the command-line', default=None, nargs=argparse.REMAINDER,
140 |     )
141 |     return parser
142 | 
143 | 
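# Usage sketch (illustrative, not part of the original file): the 'opts' REMAINDER
# argument collects trailing KEY VALUE pairs, which get_cfg() below merges last, so
# command-line overrides win over both the YAML file and the defaults above.
# Example invocation (paths and values are assumptions for illustration):
#
#     python train.py --config-file configs/stretchbev.yml BATCHSIZE 2 GPUS "[0, 1]"
#
# which yields args.config_file = 'configs/stretchbev.yml' and
# args.opts = ['BATCHSIZE', '2', 'GPUS', '[0, 1]'].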
""" 146 | 147 | cfg = _C.clone() 148 | 149 | if cfg_dict is not None: 150 | cfg.merge_from_other_cfg(CfgNode(cfg_dict)) 151 | 152 | if args is not None: 153 | if args.config_file: 154 | cfg.merge_from_file(args.config_file) 155 | cfg.merge_from_list(args.opts) 156 | cfg.freeze() 157 | return cfg 158 | -------------------------------------------------------------------------------- /stretchbev/.ipynb_checkpoints/config-checkpoint.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from fvcore.common.config import CfgNode as _CfgNode 3 | 4 | 5 | def convert_to_dict(cfg_node, key_list=[]): 6 | """Convert a config node to dictionary.""" 7 | _VALID_TYPES = {tuple, list, str, int, float, bool} 8 | if not isinstance(cfg_node, _CfgNode): 9 | if type(cfg_node) not in _VALID_TYPES: 10 | print( 11 | 'Key {} with value {} is not a valid type; valid types: {}'.format( 12 | '.'.join(key_list), type(cfg_node), _VALID_TYPES 13 | ), 14 | ) 15 | return cfg_node 16 | else: 17 | cfg_dict = dict(cfg_node) 18 | for k, v in cfg_dict.items(): 19 | cfg_dict[k] = convert_to_dict(v, key_list + [k]) 20 | return cfg_dict 21 | 22 | 23 | class CfgNode(_CfgNode): 24 | """Remove once https://github.com/rbgirshick/yacs/issues/19 is merged.""" 25 | 26 | def convert_to_dict(self): 27 | return convert_to_dict(self) 28 | 29 | 30 | CN = CfgNode 31 | 32 | _C = CN() 33 | _C.LOG_DIR = 'tensorboard_logs' 34 | _C.TAG = 'default' 35 | 36 | _C.GPUS = [0] # which gpus to use 37 | _C.PRECISION = 32 # 16bit or 32bit 38 | _C.BATCHSIZE = 3 39 | _C.EPOCHS = 20 40 | 41 | _C.N_WORKERS = 5 42 | _C.VIS_INTERVAL = 5000 43 | _C.LOGGING_INTERVAL = 500 44 | 45 | _C.PRETRAINED = CN() 46 | _C.PRETRAINED.LOAD_WEIGHTS = False 47 | _C.PRETRAINED.PATH = '' 48 | 49 | _C.DATASET = CN() 50 | _C.DATASET.DATAROOT = './nuscenes/' 51 | _C.DATASET.VERSION = 'trainval' 52 | _C.DATASET.NAME = 'nuscenes' 53 | _C.DATASET.IGNORE_INDEX = 255 # Ignore index when creating flow/offset labels 54 | _C.DATASET.FILTER_INVISIBLE_VEHICLES = True # Filter vehicles that are not visible from the cameras 55 | 56 | _C.TIME_RECEPTIVE_FIELD = 3 # how many frames of temporal context (1 for single timeframe) 57 | _C.N_FUTURE_FRAMES = 4 # how many time steps into the future to predict 58 | 59 | _C.IMAGE = CN() 60 | _C.IMAGE.FINAL_DIM = (224, 480) 61 | _C.IMAGE.RESIZE_SCALE = 0.3 62 | _C.IMAGE.TOP_CROP = 46 63 | _C.IMAGE.ORIGINAL_HEIGHT = 900 # Original input RGB camera height 64 | _C.IMAGE.ORIGINAL_WIDTH = 1600 # Original input RGB camera width 65 | _C.IMAGE.NAMES = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT'] 66 | 67 | _C.LIFT = CN() # image to BEV lifting 68 | _C.LIFT.X_BOUND = [-50.0, 50.0, 0.5] # Forward 69 | _C.LIFT.Y_BOUND = [-50.0, 50.0, 0.5] # Sides 70 | _C.LIFT.Z_BOUND = [-10.0, 10.0, 20.0] # Height 71 | _C.LIFT.D_BOUND = [2.0, 50.0, 1.0] 72 | 73 | _C.MODEL = CN() 74 | 75 | _C.MODEL.ENCODER = CN() 76 | _C.MODEL.ENCODER.DOWNSAMPLE = 8 77 | _C.MODEL.ENCODER.NAME = 'efficientnet-b4' 78 | _C.MODEL.ENCODER.OUT_CHANNELS = 64 79 | _C.MODEL.ENCODER.USE_DEPTH_DISTRIBUTION = True 80 | 81 | _C.MODEL.SMALL_ENCODER = CN() 82 | _C.MODEL.SMALL_ENCODER.FILTER_SIZE = 64 83 | _C.MODEL.SMALL_ENCODER.SKIPCO = True 84 | 85 | 86 | _C.MODEL.FIRST_STATE = CN() 87 | _C.MODEL.FIRST_STATE.NUM_LAYERS = 3 88 | 89 | _C.MODEL.DYNAMICS = CN() 90 | _C.MODEL.DYNAMICS.NUM_LAYERS = 3 91 | 92 | 93 | 94 | 95 | _C.MODEL.DISTRIBUTION = CN() 96 | _C.MODEL.DISTRIBUTION.LATENT_DIM = 32 97 | 
_C.MODEL.DISTRIBUTION.MIN_LOG_SIGMA = -5.0 98 | _C.MODEL.DISTRIBUTION.MAX_LOG_SIGMA = 5.0 99 | _C.MODEL.DISTRIBUTION.POSTERIOR_LAYERS = 2 100 | _C.MODEL.DISTRIBUTION.PRIOR_LAYERS = 2 101 | 102 | _C.MODEL.FUTURE_PRED = CN() 103 | _C.MODEL.FUTURE_PRED.N_GRU_BLOCKS = 3 104 | _C.MODEL.FUTURE_PRED.N_RES_LAYERS = 3 105 | 106 | _C.MODEL.DECODER = CN() 107 | 108 | _C.MODEL.BN_MOMENTUM = 0.1 109 | _C.MODEL.SUBSAMPLE = False # Subsample frames for Lyft 110 | 111 | _C.SEMANTIC_SEG = CN() 112 | _C.SEMANTIC_SEG.WEIGHTS = [1.0, 2.0] # per class cross entropy weights (bg, dynamic, drivable, lane) 113 | _C.SEMANTIC_SEG.USE_TOP_K = True # backprop only top-k hardest pixels 114 | _C.SEMANTIC_SEG.TOP_K_RATIO = 0.25 115 | 116 | _C.INSTANCE_SEG = CN() 117 | 118 | _C.INSTANCE_FLOW = CN() 119 | _C.INSTANCE_FLOW.ENABLED = True 120 | 121 | _C.PROBABILISTIC = CN() 122 | _C.PROBABILISTIC.ENABLED = True # learn a distribution over futures 123 | _C.PROBABILISTIC.WEIGHT = 100.0 124 | _C.PROBABILISTIC.FUTURE_DIM = 6 # number of dimension added (future flow, future centerness, offset, seg) 125 | 126 | _C.FUTURE_DISCOUNT = 0.95 127 | 128 | _C.OPTIMIZER = CN() 129 | _C.OPTIMIZER.LR = 3e-4 130 | _C.OPTIMIZER.WEIGHT_DECAY = 1e-7 131 | _C.GRAD_NORM_CLIP = 5 132 | 133 | 134 | def get_parser(): 135 | parser = argparse.ArgumentParser(description='Fiery training') 136 | # TODO: remove below? 137 | parser.add_argument('--config-file', default='', metavar='FILE', help='path to config file') 138 | parser.add_argument( 139 | 'opts', help='Modify config options using the command-line', default=None, nargs=argparse.REMAINDER, 140 | ) 141 | return parser 142 | 143 | 144 | def get_cfg(args=None, cfg_dict=None): 145 | """ First get default config. Then merge cfg_dict. Then merge according to args. """ 146 | 147 | cfg = _C.clone() 148 | 149 | if cfg_dict is not None: 150 | cfg.merge_from_other_cfg(CfgNode(cfg_dict)) 151 | 152 | if args is not None: 153 | if args.config_file: 154 | cfg.merge_from_file(args.config_file) 155 | cfg.merge_from_list(args.opts) 156 | cfg.freeze() 157 | return cfg 158 | -------------------------------------------------------------------------------- /stretchbev/models/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Mickael Chen, Edouard Delasalles, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | 18 | import torch.distributions as distrib 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | 23 | def init_weight(m, init_type='normal', init_gain=0.02): 24 | """ 25 | Initializes the input module with the given parameters. 26 | 27 | Only deals with `Conv2d`, `ConvTranspose2d`, `Linear` and `BatchNorm2d` layers. 28 | 29 | Parameters 30 | ---------- 31 | m : torch.nn.Module 32 | Module to initialize. 33 | init_type : str 34 | 'normal', 'xavier', 'kaiming', or 'orthogonal'. 
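# Minimal programmatic sketch (added for illustration; uses only names defined above):
# get_cfg can also be driven without argparse by passing a plain dict, which is merged
# on top of the defaults before the config is frozen.
if __name__ == '__main__':
    cfg = get_cfg(cfg_dict={'BATCHSIZE': 1, 'N_WORKERS': 0})
    print(cfg.BATCHSIZE)  # 1, overridden by cfg_dict
    print(cfg.TIME_RECEPTIVE_FIELD)  # 3, the default
    print(type(cfg.convert_to_dict()))  # plain nested dict, e.g. for logging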
-------------------------------------------------------------------------------- /stretchbev/models/model_utils.py: --------------------------------------------------------------------------------
1 | # Copyright 2020 Mickael Chen, Edouard Delasalles, Jean-Yves Franceschi, Patrick Gallinari, Sylvain Lamprier
2 | 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | 
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | import torch
17 | 
18 | import torch.distributions as distrib
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | 
22 | 
23 | def init_weight(m, init_type='normal', init_gain=0.02):
24 |     """
25 |     Initializes the input module with the given parameters.
26 | 
27 |     Only deals with `Conv2d`, `ConvTranspose2d`, `Linear` and `BatchNorm2d` layers.
28 | 
29 |     Parameters
30 |     ----------
31 |     m : torch.nn.Module
32 |         Module to initialize.
33 |     init_type : str
34 |         'normal', 'xavier', 'kaiming', or 'orthogonal'. Initialization type used for convolution and linear
35 |         layers. Ignored for batch normalization, which uses a normal initialization.
36 |     init_gain : float
37 |         Gain to use for the initialization.
38 |     """
39 |     classname = m.__class__.__name__
40 |     if classname in ('Conv2d', 'ConvTranspose2d', 'Linear'):
41 |         if init_type == 'normal':
42 |             nn.init.normal_(m.weight.data, 0.0, init_gain)
43 |         elif init_type == 'xavier':
44 |             nn.init.xavier_normal_(m.weight.data, gain=init_gain)
45 |         elif init_type == 'kaiming':
46 |             nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
47 |         elif init_type == 'orthogonal':
48 |             nn.init.orthogonal_(m.weight.data, gain=init_gain)
49 |         else:
50 |             raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
51 |         if hasattr(m, 'bias') and m.bias is not None:
52 |             nn.init.constant_(m.bias.data, 0.0)
53 |     elif classname == 'BatchNorm2d':
54 |         if m.weight is not None:
55 |             nn.init.normal_(m.weight.data, 1.0, init_gain)
56 |         if m.bias is not None:
57 |             nn.init.constant_(m.bias.data, 0.0)
58 | 
59 | 
60 | def make_normal_from_raw_params(raw_params, scale_stddev=1, dim=2, eps=1e-8, min_log_sigma=-10000, max_log_sigma=10000):
61 |     """
62 |     Creates a normal distribution from the given parameters.
63 | 
64 |     Parameters
65 |     ----------
66 |     raw_params : torch.*.Tensor
67 |         Tensor containing the Gaussian mean and a raw scale parameter on a given dimension.
68 |     scale_stddev : float
69 |         Multiplier of the final scale parameter of the Gaussian.
70 |     dim : int
71 |         Dimension of raw_params whose first half corresponds to the mean and whose second half corresponds
72 |         to the scale. Currently overridden below based on the rank of raw_params.
73 |     eps : float
74 |         Minimum possible value of the final scale parameter.
75 | 
76 |     Returns
77 |     -------
78 |     torch.distributions.Normal
79 |         Normal distribution with the input mean and eps + softplus(raw scale) * scale_stddev as standard deviation.
80 |     """
81 |     # Infer the split dimension from the tensor rank: dim 2 for 5D inputs, dim 1 otherwise.
82 |     dim = 2 if len(raw_params.shape) == 5 else 1
83 |     loc, raw_scale = torch.chunk(raw_params, 2, dim)
84 |     assert loc.shape[dim] == raw_scale.shape[dim], f'{loc.shape[dim]}, {raw_scale.shape[dim]}'
85 |     # Log-sigma clamping is disabled, so min_log_sigma and max_log_sigma are currently unused.
86 |     # raw_scale = torch.clamp(raw_scale, min_log_sigma, max_log_sigma)
87 |     scale = F.softplus(raw_scale) + eps
88 |     normal = distrib.Normal(loc, scale * scale_stddev)
89 |     return normal
90 | 
91 | 
92 | def rsample_normal(raw_params, scale_stddev=1, min_log_sigma=-10000, max_log_sigma=10000):
93 |     """
94 |     Samples from a normal distribution with given parameters.
95 | 
96 |     Parameters
97 |     ----------
98 |     raw_params : torch.*.Tensor
99 |         Tensor containing a Gaussian mean and a raw scale parameter on its last dimension.
100 |     scale_stddev : float
101 |         Multiplier of the final scale parameter of the Gaussian.
102 | 
103 |     Returns
104 |     -------
105 |     torch.*.Tensor
106 |         Sample from the normal distribution with the input mean and eps + softplus(raw scale) * scale_stddev as
107 |         standard deviation.
108 |     """
109 | 
110 |     normal = make_normal_from_raw_params(raw_params, scale_stddev=scale_stddev)
111 |     sample = normal.rsample()
112 |     return sample
113 | 
114 | 
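# Usage sketch (illustrative, not from the original file): the two functions above expect
# the mean and the raw scale stacked on one dimension, e.g. a (b, 2 * c, h, w) tensor is
# chunked on dim 1 into a (b, c, h, w) mean and raw scale; softplus keeps the final
# standard deviation positive.
if __name__ == '__main__':
    raw = torch.randn(4, 2 * 8, 16, 16)  # mean and raw scale for 8 channels
    dist = make_normal_from_raw_params(raw)  # Normal with loc and scale of shape (4, 8, 16, 16)
    sample = rsample_normal(raw)  # reparameterised sample of shape (4, 8, 16, 16)
    print(dist.loc.shape, sample.shape)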
115 | def neg_logprob(loc, data, scale=1):
116 |     """
117 |     Computes the negative log density of a given input with respect to a normal distribution created from the
118 |     given parameters.
119 | 
120 |     Parameters
121 |     ----------
122 |     loc : torch.*.Tensor
123 |         Tensor containing the mean of the Gaussian on its last dimension.
124 |     data : torch.*.Tensor
125 |         Tensor whose log density is computed under the Gaussian with the input mean and standard deviation.
126 |     scale : float
127 |         Standard deviation of the Gaussian.
128 | 
129 |     Returns
130 |     -------
131 |     torch.*.Tensor
132 |         Element-wise negative log density of data under the Gaussian.
133 |     """
134 |     obs_distrib = distrib.Normal(loc, scale)
135 |     return -obs_distrib.log_prob(data)
-------------------------------------------------------------------------------- /visualise.py: --------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | from glob import glob
4 | 
5 | import cv2
6 | import matplotlib as mpl
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import torch
10 | import torchvision
11 | from PIL import Image
12 | from fiery.trainer import TrainingModule
13 | from fiery.utils.instance import predict_instance_segmentation_and_trajectories
14 | from fiery.utils.network import NormalizeInverse
15 | from fiery.utils.visualisation import plot_instance_map, generate_instance_colours, make_contour, convert_figure_numpy
16 | 
17 | EXAMPLE_DATA_PATH = 'example_data'
18 | 
19 | 
20 | def plot_prediction(image, output, cfg):
21 |     # Process predictions
22 |     consistent_instance_seg, matched_centers = predict_instance_segmentation_and_trajectories(
23 |         output, compute_matched_centers=True
24 |     )
25 | 
26 |     # Plot future trajectories
27 |     unique_ids = torch.unique(consistent_instance_seg[0, 0]).cpu().long().numpy()[1:]
28 |     instance_map = dict(zip(unique_ids, unique_ids))
29 |     instance_colours = generate_instance_colours(instance_map)
30 |     vis_image = plot_instance_map(consistent_instance_seg[0, 0].cpu().numpy(), instance_map)
31 |     trajectory_img = np.zeros(vis_image.shape, dtype=np.uint8)
32 |     for instance_id in unique_ids:
33 |         path = matched_centers[instance_id]
34 |         for t in range(len(path) - 1):
35 |             color = instance_colours[instance_id].tolist()
36 |             cv2.line(trajectory_img, tuple(path[t]), tuple(path[t + 1]),
37 |                      color, 4)
38 | 
39 |     # Overlay arrows
40 |     temp_img = cv2.addWeighted(vis_image, 0.7, trajectory_img, 0.3, 1.0)
41 |     mask = ~np.all(trajectory_img == 0, axis=2)
42 |     vis_image[mask] = temp_img[mask]
43 | 
44 |     # Plot present RGB frames and predictions
45 |     val_w = 2.99
46 |     cameras = cfg.IMAGE.NAMES
47 |     image_ratio = cfg.IMAGE.FINAL_DIM[0] / cfg.IMAGE.FINAL_DIM[1]
48 |     val_h = val_w * image_ratio
49 |     fig = plt.figure(figsize=(4 * val_w, 2 * val_h))
50 |     width_ratios = (val_w, val_w, val_w, val_w)
51 |     gs = mpl.gridspec.GridSpec(2, 4, width_ratios=width_ratios)
52 |     gs.update(wspace=0.0, hspace=0.0, left=0.0, right=1.0, top=1.0, bottom=0.0)
53 | 
54 |     denormalise_img = torchvision.transforms.Compose(
55 |         (NormalizeInverse(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
56 |          torchvision.transforms.ToPILImage(),)
57 |     )
58 |     for imgi, img in enumerate(image[0, -1]):
59 |         ax = plt.subplot(gs[imgi // 3, imgi % 3])
60 |         showimg = denormalise_img(img.cpu())
61 |         if imgi > 2:
62 |             showimg = showimg.transpose(Image.FLIP_LEFT_RIGHT)
63 | 
64 |         plt.annotate(cameras[imgi].replace('_', ' ').replace('CAM ', ''), (0.01, 0.87), c='white',
65 |                      xycoords='axes fraction', fontsize=14)
66 |         plt.imshow(showimg)
67 |         plt.axis('off')
68 | 
69 |     ax = plt.subplot(gs[:, 3])
70 |     plt.imshow(make_contour(vis_image[::-1, ::-1]))
71 |     plt.axis('off')
72 | 
73 |     plt.draw()
74 |     figure_numpy = convert_figure_numpy(fig)
75 |     plt.close()
76 |     return figure_numpy
77 | 
78 | 
79 | def download_example_data():
80 |     from requests import get
81 | 
82 |     def download(url, file_name):
83 |         # open in binary mode
84 |         with open(file_name, "wb") as file:
85 |             # get request
86 |             response = get(url)
87 |             # write to file
88 |             file.write(response.content)
89 | 
90 |     os.makedirs(EXAMPLE_DATA_PATH, exist_ok=True)
91 | 
url_list = ['https://github.com/wayveai/fiery/releases/download/v1.0/example_1.npz', 92 | 'https://github.com/wayveai/fiery/releases/download/v1.0/example_2.npz', 93 | 'https://github.com/wayveai/fiery/releases/download/v1.0/example_3.npz', 94 | 'https://github.com/wayveai/fiery/releases/download/v1.0/example_4.npz' 95 | ] 96 | for url in url_list: 97 | download(url, os.path.join(EXAMPLE_DATA_PATH, os.path.basename(url))) 98 | 99 | 100 | def visualise(checkpoint_path): 101 | trainer = TrainingModule.load_from_checkpoint(checkpoint_path, strict=True) 102 | 103 | device = torch.device('cuda:0') 104 | trainer = trainer.to(device) 105 | trainer.eval() 106 | 107 | # Download example data 108 | download_example_data() 109 | # Load data 110 | for data_path in sorted(glob(os.path.join(EXAMPLE_DATA_PATH, '*.npz'))): 111 | data = np.load(data_path) 112 | image = torch.from_numpy(data['image']).to(device) 113 | intrinsics = torch.from_numpy(data['intrinsics']).to(device) 114 | extrinsics = torch.from_numpy(data['extrinsics']).to(device) 115 | future_egomotions = torch.from_numpy(data['future_egomotion']).to(device) 116 | 117 | # Forward pass 118 | with torch.no_grad(): 119 | output = trainer.model(image, intrinsics, extrinsics, future_egomotions) 120 | 121 | figure_numpy = plot_prediction(image, output, trainer.cfg) 122 | os.makedirs('./output_vis', exist_ok=True) 123 | output_filename = os.path.join('./output_vis', os.path.basename(data_path).split('.')[0]) + '.png' 124 | Image.fromarray(figure_numpy).save(output_filename) 125 | print(f'Saved output in {output_filename}') 126 | 127 | 128 | if __name__ == '__main__': 129 | parser = ArgumentParser(description='Fiery visualisation') 130 | parser.add_argument('--checkpoint', default='./fiery.ckpt', type=str, help='path to checkpoint') 131 | 132 | args = parser.parse_args() 133 | 134 | visualise(args.checkpoint) 135 | -------------------------------------------------------------------------------- /stretchbev/models/res_models.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from functools import partial 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class ConvBlock(nn.Module): 9 | """2D convolution followed by 10 | - an optional normalisation (batch norm or instance norm) 11 | - an optional activation (ReLU, LeakyReLU, or tanh) 12 | """ 13 | 14 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, norm='bn', activation='lrelu', 15 | bias=False, transpose=False): 16 | super().__init__() 17 | out_channels = out_channels or in_channels 18 | padding = int((kernel_size - 1) / 2) 19 | self.conv = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d) 20 | self.conv = self.conv(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias) 21 | 22 | if norm == 'bn': 23 | self.norm = nn.BatchNorm2d(out_channels) 24 | elif norm == 'in': 25 | self.norm = nn.InstanceNorm2d(out_channels) 26 | elif norm == 'none': 27 | self.norm = None 28 | else: 29 | raise ValueError('Invalid norm {}'.format(norm)) 30 | 31 | if activation == 'relu': 32 | self.activation = nn.ReLU() 33 | elif activation == 'lrelu': 34 | self.activation = nn.LeakyReLU(0.1) 35 | elif activation == 'tanh': 36 | self.activation = nn.Tanh() 37 | elif activation == 'none': 38 | self.activation = None 39 | else: 40 | raise ValueError('Invalid activation {}'.format(activation)) 41 | 42 | def forward(self, x): 43 | x = self.conv(x) 44 | 45 | if self.norm: 46 | x = self.norm(x) 
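        # the (optional) activation runs after the (optional) normalisation,
        # giving the conv -> norm -> activation ordering used throughout this file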
47 | if self.activation: 48 | x = self.activation(x) 49 | return x 50 | 51 | 52 | class ResBlock(nn.Module): 53 | """Residual block: 54 | x -> Conv -> norm -> act. -> Conv -> norm -> act. -> ADD -> out 55 | | | 56 | --------------------------------------------------- 57 | """ 58 | 59 | def __init__(self, in_channels, out_channels=None, norm='bn', activation='lrelu', bias=False): 60 | super().__init__() 61 | out_channels = out_channels or in_channels 62 | 63 | self.layers = nn.Sequential(OrderedDict([ 64 | ('conv_1', ConvBlock(in_channels, in_channels, 3, stride=1, norm=norm, activation=activation, bias=bias)), 65 | ('conv_2', ConvBlock(in_channels, out_channels, 3, stride=1, norm=norm, activation=activation, bias=bias)), 66 | ('dropout', nn.Dropout2d(0.25)), 67 | ])) 68 | 69 | if out_channels != in_channels: 70 | self.projection = nn.Conv2d(in_channels, out_channels, 1) 71 | else: 72 | self.projection = None 73 | 74 | def forward(self, x): 75 | x_residual = self.layers(x) 76 | 77 | if self.projection: 78 | x = self.projection(x) 79 | return x + x_residual 80 | 81 | 82 | class SmallEncoder(nn.Module): 83 | def __init__(self, nc, nh, nf): 84 | super(SmallEncoder, self).__init__() 85 | 86 | self.blocks = nn.ModuleList([ 87 | ResBlock(nc, nf), 88 | ResBlock(nf, nf * 2), 89 | ResBlock(nf * 2, nf * 2), 90 | ResBlock(nf * 2, nf * 2), 91 | ResBlock(nf * 2, nf * 4) 92 | ]) 93 | self.last_conv = nn.Sequential( 94 | ConvBlock(nf * 4, nh, 3, stride=1, activation='tanh') 95 | ) 96 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 97 | 98 | def forward(self, x, return_skip=False): 99 | h = x 100 | skips = [] 101 | for i, layer in enumerate(self.blocks): 102 | if i in [1, 2]: 103 | h = self.maxpool(h) 104 | h = layer(h) 105 | skips.append(h) 106 | h = self.last_conv(h) 107 | if return_skip: 108 | return h, skips[::-1] 109 | return h 110 | 111 | 112 | class SmallDecoder(nn.Module): 113 | def __init__(self, nc, nh, nf, skip): 114 | super(SmallDecoder, self).__init__() 115 | coef = 2 if skip else 1 116 | self.skip = skip 117 | 118 | self.first_upconv = ConvBlock(nc, nf * 4, stride=1, transpose=True) 119 | 120 | self.blocks = nn.ModuleList([ 121 | ResBlock(nf * 4 * coef, nf * 2), 122 | ResBlock(nf * 2 * coef, nf * 2), 123 | ResBlock(nf * 2 * coef, nf * 2), 124 | ResBlock(nf * 2 * coef, nf), 125 | ResBlock(nf * coef, nf) 126 | ]) 127 | self.last_conv = nn.Sequential( 128 | ConvBlock(nf * coef, nf, 3, stride=1), 129 | ConvBlock(nf, nh, 3, stride=1, transpose=True, bias=True, norm='none'), 130 | 131 | ) 132 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest') 133 | 134 | def forward(self, z, skip=None, sigmoid=False): 135 | assert skip is None and not self.skip or self.skip and skip is not None 136 | h = self.first_upconv(z) 137 | for i, layer in enumerate(self.blocks): 138 | # print(i, h.shape, skip[i].shape) 139 | if skip is not None: 140 | h = torch.cat([h, skip[i]], 1) 141 | h = layer(h) 142 | if i in [2, 3]: 143 | h = self.upsample(h) 144 | x_ = h 145 | if sigmoid: 146 | x_ = torch.sigmoid(x_) 147 | return x_ 148 | 149 | 150 | class SELayer(nn.Module): 151 | def __init__(self, channel, reduction=8): 152 | super(SELayer, self).__init__() 153 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 154 | self.fc = nn.Sequential( 155 | nn.Linear(channel, channel // reduction, bias=False), 156 | nn.ReLU(inplace=True), 157 | nn.Linear(channel // reduction, channel, bias=False), 158 | nn.Sigmoid() 159 | ) 160 | 161 | def forward(self, x): 162 | b, c, _, _ = x.size() 163 | y = self.avg_pool(x).view(b, 
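                                  # squeeze step: global average pooling collapses each
                                  # (h, w) map to one scalar per channel; the resulting
                                  # (b, c) vector is re-weighted by the FC bottleneck below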
c)
164 |         y = self.fc(y).view(b, c, 1, 1)
165 |         return x * y.expand_as(x)
166 | 
167 | 
168 | class ConvNet(nn.Module):
169 |     def __init__(self, in_c, out_c, nlayers):
170 |         super(ConvNet, self).__init__()
171 |         # nlayers is accepted for interface compatibility but unused: the architecture is fixed below.
172 |         self.model = nn.Sequential(
173 |             ResBlock(in_c, out_c),
174 |             SELayer(out_c),
175 |             ResBlock(out_c, out_c),
176 |             SELayer(out_c),
177 |             ConvBlock(out_c, out_c, 3, stride=1, bias=True, norm='none'),
178 |         )
179 | 
180 |     def forward(self, x):
181 |         return self.model(x)
-------------------------------------------------------------------------------- /stretchbev/layers/convolutions.py: --------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from functools import partial
3 | 
4 | import torch
5 | import torch.nn as nn
6 | 
7 | 
8 | class ConvBlock(nn.Module):
9 |     """2D convolution followed by
10 |     - an optional normalisation (batch norm or instance norm)
11 |     - an optional activation (ReLU, LeakyReLU, ELU, or tanh)
12 |     """
13 | 
14 |     def __init__(
15 |         self,
16 |         in_channels,
17 |         out_channels=None,
18 |         kernel_size=3,
19 |         stride=1,
20 |         norm='bn',
21 |         activation='relu',
22 |         bias=False,
23 |         transpose=False,
24 |     ):
25 |         super().__init__()
26 |         out_channels = out_channels or in_channels
27 |         padding = int((kernel_size - 1) / 2)
28 |         self.conv = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d, output_padding=1)
29 |         self.conv = self.conv(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias)
30 | 
31 |         if norm == 'bn':
32 |             self.norm = nn.BatchNorm2d(out_channels)
33 |         elif norm == 'in':
34 |             self.norm = nn.InstanceNorm2d(out_channels)
35 |         elif norm == 'none':
36 |             self.norm = None
37 |         else:
38 |             raise ValueError('Invalid norm {}'.format(norm))
39 | 
40 |         if activation == 'relu':
41 |             self.activation = nn.ReLU(inplace=True)
42 |         elif activation == 'lrelu':
43 |             self.activation = nn.LeakyReLU(0.1, inplace=True)
44 |         elif activation == 'elu':
45 |             self.activation = nn.ELU(inplace=True)
46 |         elif activation == 'tanh':
47 |             self.activation = nn.Tanh()  # nn.Tanh takes no inplace argument
48 |         elif activation == 'none':
49 |             self.activation = None
50 |         else:
51 |             raise ValueError('Invalid activation {}'.format(activation))
52 | 
53 |     def forward(self, x):
54 |         x = self.conv(x)
55 | 
56 |         if self.norm:
57 |             x = self.norm(x)
58 |         if self.activation:
59 |             x = self.activation(x)
60 |         return x
61 | 
62 | 
63 | class Bottleneck(nn.Module):
64 |     """
65 |     Defines a bottleneck module with a residual connection
66 |     """
67 | 
68 |     def __init__(
69 |         self,
70 |         in_channels,
71 |         out_channels=None,
72 |         kernel_size=3,
73 |         dilation=1,
74 |         groups=1,
75 |         upsample=False,
76 |         downsample=False,
77 |         dropout=0.0,
78 |     ):
79 |         super().__init__()
80 |         self._downsample = downsample
81 |         bottleneck_channels = int(in_channels / 2)
82 |         out_channels = out_channels or in_channels
83 |         padding_size = ((kernel_size - 1) * dilation + 1) // 2
84 | 
85 |         # Define the main conv operation
86 |         assert dilation == 1
87 |         if upsample:
88 |             assert not downsample, 'downsample and upsample not possible simultaneously.'
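            # upsampling branch: a stride-2 transposed convolution doubles the spatial
            # resolution; padding and output_padding are both derived from the kernel size
            # (e.g. 1 for a 3x3 kernel), so the output is exactly twice the input size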
89 | bottleneck_conv = nn.ConvTranspose2d( 90 | bottleneck_channels, 91 | bottleneck_channels, 92 | kernel_size=kernel_size, 93 | bias=False, 94 | dilation=1, 95 | stride=2, 96 | output_padding=padding_size, 97 | padding=padding_size, 98 | groups=groups, 99 | ) 100 | elif downsample: 101 | bottleneck_conv = nn.Conv2d( 102 | bottleneck_channels, 103 | bottleneck_channels, 104 | kernel_size=kernel_size, 105 | bias=False, 106 | dilation=dilation, 107 | stride=2, 108 | padding=padding_size, 109 | groups=groups, 110 | ) 111 | else: 112 | bottleneck_conv = nn.Conv2d( 113 | bottleneck_channels, 114 | bottleneck_channels, 115 | kernel_size=kernel_size, 116 | bias=False, 117 | dilation=dilation, 118 | padding=padding_size, 119 | groups=groups, 120 | ) 121 | 122 | self.layers = nn.Sequential( 123 | OrderedDict( 124 | [ 125 | # First projection with 1x1 kernel 126 | ('conv_down_project', nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False)), 127 | ('abn_down_project', nn.Sequential(nn.BatchNorm2d(bottleneck_channels), 128 | nn.ReLU(inplace=True))), 129 | # Second conv block 130 | ('conv', bottleneck_conv), 131 | ('abn', nn.Sequential(nn.BatchNorm2d(bottleneck_channels), nn.ReLU(inplace=True))), 132 | # Final projection with 1x1 kernel 133 | ('conv_up_project', nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, bias=False)), 134 | ('abn_up_project', nn.Sequential(nn.BatchNorm2d(out_channels), 135 | nn.ReLU(inplace=True))), 136 | # Regulariser 137 | ('dropout', nn.Dropout2d(p=dropout)), 138 | ] 139 | ) 140 | ) 141 | 142 | if out_channels == in_channels and not downsample and not upsample: 143 | self.projection = None 144 | else: 145 | projection = OrderedDict() 146 | if upsample: 147 | projection.update({'upsample_skip_proj': Interpolate(scale_factor=2)}) 148 | elif downsample: 149 | projection.update({'upsample_skip_proj': nn.MaxPool2d(kernel_size=2, stride=2)}) 150 | projection.update( 151 | { 152 | 'conv_skip_proj': nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 153 | 'bn_skip_proj': nn.BatchNorm2d(out_channels), 154 | } 155 | ) 156 | self.projection = nn.Sequential(projection) 157 | 158 | # pylint: disable=arguments-differ 159 | def forward(self, *args): 160 | (x,) = args 161 | x_residual = self.layers(x) 162 | if self.projection is not None: 163 | if self._downsample: 164 | # pad h/w dimensions if they are odd to prevent shape mismatch with residual layer 165 | x = nn.functional.pad(x, (0, x.shape[-1] % 2, 0, x.shape[-2] % 2), value=0) 166 | return x_residual + self.projection(x) 167 | return x_residual + x 168 | 169 | 170 | class Interpolate(nn.Module): 171 | def __init__(self, scale_factor: int = 2): 172 | super().__init__() 173 | self._interpolate = nn.functional.interpolate 174 | self._scale_factor = scale_factor 175 | 176 | # pylint: disable=arguments-differ 177 | def forward(self, x): 178 | return self._interpolate(x, scale_factor=self._scale_factor, mode='bilinear', align_corners=False) 179 | 180 | 181 | class UpsamplingConcat(nn.Module): 182 | def __init__(self, in_channels, out_channels, scale_factor=2): 183 | super().__init__() 184 | 185 | self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False) 186 | 187 | self.conv = nn.Sequential( 188 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), 189 | nn.BatchNorm2d(out_channels), 190 | nn.ReLU(inplace=True), 191 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), 192 | nn.BatchNorm2d(out_channels), 193 | 
nn.ReLU(inplace=True), 194 | ) 195 | 196 | def forward(self, x_to_upsample, x): 197 | x_to_upsample = self.upsample(x_to_upsample) 198 | x_to_upsample = torch.cat([x, x_to_upsample], dim=1) 199 | return self.conv(x_to_upsample) 200 | 201 | 202 | class UpsamplingAdd(nn.Module): 203 | def __init__(self, in_channels, out_channels, scale_factor=2): 204 | super().__init__() 205 | self.upsample_layer = nn.Sequential( 206 | nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False), 207 | nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False), 208 | nn.BatchNorm2d(out_channels), 209 | ) 210 | 211 | def forward(self, x, x_skip): 212 | x = self.upsample_layer(x) 213 | return x + x_skip 214 | -------------------------------------------------------------------------------- /stretchbev/models/srvp_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def activation_factory(name): 7 | """ 8 | Returns the activation layer corresponding to the input activation name. 9 | Parameters 10 | ---------- 11 | name : str 12 | 'relu', 'leaky_relu', 'elu', 'sigmoid', or 'tanh'. Adds the corresponding activation function after the 13 | convolution. 14 | Returns 15 | ------- 16 | torch.nn.Module 17 | Element-wise activation layer. 18 | """ 19 | if name == 'relu': 20 | return nn.ReLU(inplace=True) 21 | if name == 'leaky_relu': 22 | return nn.LeakyReLU(0.2, inplace=True) 23 | if name == 'elu': 24 | return nn.ELU(inplace=True) 25 | if name == 'sigmoid': 26 | return nn.Sigmoid() 27 | if name == 'tanh': 28 | return nn.Tanh() 29 | raise ValueError(f'Activation function \'{name}\' not yet implemented') 30 | 31 | 32 | def make_conv_block(conv, activation, bn=True): 33 | """ 34 | Supplements a convolutional block with activation functions and batch normalization. 35 | Parameters 36 | ---------- 37 | conv : torch.nn.Module 38 | Convolutional block. 39 | activation : str 40 | 'relu', 'leaky_relu', 'elu', 'sigmoid', 'tanh', or 'none'. Adds the corresponding activation function, or no 41 | activation if 'none' is chosen, after the convolution. 42 | bn : bool 43 | Whether to add batch normalization after the activation. 44 | Returns 45 | ------- 46 | torch.nn.Sequential 47 | Sequence of the input convolutional block, the potentially chosen activation function, and the potential batch 48 | normalization. 49 | """ 50 | out_channels = conv.out_channels 51 | modules = [conv] 52 | if bn: 53 | modules.append(nn.BatchNorm2d(out_channels)) 54 | if activation != 'none': 55 | modules.append(activation_factory(activation)) 56 | return nn.Sequential(*modules) 57 | 58 | 59 | class VGG64Encoder(nn.Module): 60 | # a note: we are downsampling to 1/8 size 61 | # hard coded for now 62 | """ 63 | Module implementing the VGG encoder. 64 | """ 65 | 66 | def __init__(self, nc, nh, nf): 67 | """ 68 | Parameters 69 | ---------- 70 | nc : int 71 | Number of channels in the input data. 72 | nh : int 73 | Number of dimensions of the output flat vector. 74 | nf : int 75 | Number of filters per channel of the first convolution. 
76 | """ 77 | super(VGG64Encoder, self).__init__() 78 | self.conv = nn.ModuleList([ 79 | nn.Sequential( 80 | make_conv_block(nn.Conv2d(nc, nf, 3, 1, 1, bias=False), activation='leaky_relu'), 81 | make_conv_block(nn.Conv2d(nf, nf, 3, 1, 1, bias=False), activation='leaky_relu'), 82 | ), 83 | nn.Sequential( 84 | # nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 85 | make_conv_block(nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 86 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 87 | ), 88 | nn.Sequential( 89 | # nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 90 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 91 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 92 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 93 | ), 94 | nn.Sequential( 95 | # nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 96 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 97 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 98 | make_conv_block(nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False), activation='leaky_relu'), 99 | ) 100 | ]) 101 | self.last_conv = nn.Sequential( 102 | make_conv_block(nn.Conv2d(nf * 4, nh, 3, 1, 1, bias=False), activation='tanh') 103 | ) 104 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 105 | 106 | def forward(self, x, return_skip=False): 107 | """ 108 | Parameters 109 | ---------- 110 | x : torch.*.Tensor 111 | Encoder input. 112 | return_skip : bool 113 | Whether to extract and return, besides the network output, skip connections. 114 | Returns 115 | ------- 116 | torch.*.Tensor 117 | Encoder output as a tensor of shape (batch, size). 118 | list 119 | Only if return_skip is True. List of skip connections represented as torch.*.Tensor corresponding to each 120 | convolutional block in reverse order (from the deepest to the shallowest convolutional block). 121 | """ 122 | skips = [] 123 | h = x 124 | for i, layer in enumerate(self.conv): 125 | 126 | if i in [1, 2]: 127 | h = self.maxpool(h) 128 | h_res = layer(h) 129 | print(i, h.shape, h_res.shape) 130 | h = h + h_res 131 | skips.append(h) 132 | h = self.last_conv(h) 133 | if return_skip: 134 | return h, skips[::-1] 135 | return h 136 | 137 | 138 | class VGG64Decoder(nn.Module): 139 | # a note: we are upsampling to 1/8 size 140 | # hard coded for now 141 | """ 142 | Module implementing the VGG decoder. 143 | """ 144 | 145 | def __init__(self, nc, ny, nf, skip): 146 | """ 147 | Parameters 148 | ---------- 149 | nc : int 150 | Number of channels in the output shape. 151 | ny : int 152 | Number of dimensions of the input flat vector. 153 | nf : int 154 | Number of filters per channel of the first convolution of the mirror encoder architecture. 155 | skip : list 156 | List of torch.*.Tensor representing skip connections in the same order as the decoder convolutional 157 | blocks. Must be None when skip connections are not allowed. 
158 | """ 159 | super(VGG64Decoder, self).__init__() 160 | # decoder 161 | coef = 2 if skip else 1 162 | self.skip = skip 163 | self.first_upconv = nn.Sequential( 164 | make_conv_block(nn.ConvTranspose2d(ny, nf * 4, 3, 1, 1, bias=False), activation='leaky_relu'), 165 | ) 166 | self.conv = nn.ModuleList([ 167 | nn.Sequential( 168 | make_conv_block(nn.Conv2d(nf * 4 * coef, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 169 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 170 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 171 | # nn.Upsample(scale_factor=2, mode='nearest'), 172 | ), 173 | nn.Sequential( 174 | make_conv_block(nn.Conv2d(nf * 2 * coef, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 175 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 176 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 177 | # nn.Upsample(scale_factor=2, mode='nearest'), 178 | ), 179 | nn.Sequential( 180 | make_conv_block(nn.Conv2d(nf * 2 * coef, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 181 | make_conv_block(nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=False), activation='leaky_relu'), 182 | # nn.Upsample(scale_factor=2, mode='nearest'), 183 | ), 184 | nn.Sequential( 185 | make_conv_block(nn.Conv2d(nf * coef, nf, 3, 1, 1, bias=False), activation='leaky_relu'), 186 | nn.ConvTranspose2d(nf, nc, 3, 1, 1, bias=False), 187 | ), 188 | ]) 189 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest') 190 | 191 | def forward(self, z, skip=None, sigmoid=False): 192 | """ 193 | Parameters 194 | ---------- 195 | z : torch.*.Tensor 196 | Decoder input. 197 | skip : list 198 | List of torch.*.Tensor representing skip connections in the same order as the decoder convolutional 199 | blocks. Must be None when skip connections are not allowed. 200 | sigmoid : bool 201 | Whether to apply a sigmoid at the end of the decoder. 202 | Returns 203 | ------- 204 | torch.*.Tensor 205 | Decoder output as a frame of shape (batch, channels, width, height). 
206 | """ 207 | assert skip is None and not self.skip or self.skip and skip is not None 208 | h = self.first_upconv(z) 209 | for i, layer in enumerate(self.conv): 210 | if skip is not None: 211 | h = torch.cat([h, skip[i]], 1) 212 | h_res = layer(h) 213 | h = h + h_res 214 | if i in [1, 2]: 215 | h = self.upsample(h) 216 | x_ = h 217 | if sigmoid: 218 | x_ = torch.sigmoid(x_) 219 | return x_ 220 | 221 | 222 | class SELayer(nn.Module): 223 | def __init__(self, channel, reduction=8): 224 | super(SELayer, self).__init__() 225 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 226 | self.fc = nn.Sequential( 227 | nn.Linear(channel, channel // reduction, bias=False), 228 | nn.ReLU(inplace=True), 229 | nn.Linear(channel // reduction, channel, bias=False), 230 | nn.Sigmoid() 231 | ) 232 | 233 | def forward(self, x): 234 | b, c, _, _ = x.size() 235 | y = self.avg_pool(x).view(b, c) 236 | y = self.fc(y).view(b, c, 1, 1) 237 | return x * y.expand_as(x) 238 | 239 | 240 | class ConvNet(nn.Module): 241 | def __init__(self, in_channels, out_channels, nlayers): 242 | super(ConvNet, self).__init__() 243 | 244 | layers = [] 245 | in_c = in_channels 246 | for _ in range(nlayers - 1): 247 | layers += [ 248 | make_conv_block(nn.Conv2d(in_c, out_channels, 3, 1, 1, bias=False), activation='leaky_relu') 249 | ] 250 | in_c = out_channels 251 | layers += [SELayer(in_c)] 252 | layers += [make_conv_block(nn.Conv2d(in_c, out_channels, 3, 1, 1, bias=True), activation='none', bn=False)] 253 | self.model = nn.Sequential(*layers) 254 | 255 | def forward(self, x): 256 | return self.model(x) 257 | -------------------------------------------------------------------------------- /stretchbev/models/.ipynb_checkpoints/srvp_models-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.distributions as distrib 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def activation_factory(name): 9 | """ 10 | Returns the activation layer corresponding to the input activation name. 11 | Parameters 12 | ---------- 13 | name : str 14 | 'relu', 'leaky_relu', 'elu', 'sigmoid', or 'tanh'. Adds the corresponding activation function after the 15 | convolution. 16 | Returns 17 | ------- 18 | torch.nn.Module 19 | Element-wise activation layer. 20 | """ 21 | if name == 'relu': 22 | return nn.ReLU(inplace=True) 23 | if name == 'leaky_relu': 24 | return nn.LeakyReLU(0.2, inplace=True) 25 | if name == 'elu': 26 | return nn.ELU(inplace=True) 27 | if name == 'sigmoid': 28 | return nn.Sigmoid() 29 | if name == 'tanh': 30 | return nn.Tanh() 31 | raise ValueError(f'Activation function \'{name}\' not yet implemented') 32 | 33 | 34 | 35 | def make_conv_block(conv, activation, bn=True): 36 | """ 37 | Supplements a convolutional block with activation functions and batch normalization. 38 | Parameters 39 | ---------- 40 | conv : torch.nn.Module 41 | Convolutional block. 42 | activation : str 43 | 'relu', 'leaky_relu', 'elu', 'sigmoid', 'tanh', or 'none'. Adds the corresponding activation function, or no 44 | activation if 'none' is chosen, after the convolution. 45 | bn : bool 46 | Whether to add batch normalization after the activation. 47 | Returns 48 | ------- 49 | torch.nn.Sequential 50 | Sequence of the input convolutional block, the potentially chosen activation function, and the potential batch 51 | normalization. 
52 | """ 53 | out_channels = conv.out_channels 54 | modules = [conv] 55 | if bn: 56 | modules.append(nn.BatchNorm2d(out_channels)) 57 | if activation != 'none': 58 | modules.append(activation_factory(activation)) 59 | return nn.Sequential(*modules) 60 | 61 | 62 | 63 | 64 | class VGG64Encoder(nn.Module): 65 | # a note: we are downsampling to 1/8 size 66 | # hard coded for now 67 | """ 68 | Module implementing the VGG encoder. 69 | """ 70 | def __init__(self, nc, nh, nf): 71 | """ 72 | Parameters 73 | ---------- 74 | nc : int 75 | Number of channels in the input data. 76 | nh : int 77 | Number of dimensions of the output flat vector. 78 | nf : int 79 | Number of filters per channel of the first convolution. 80 | """ 81 | super(VGG64Encoder, self).__init__() 82 | self.conv = nn.ModuleList([ 83 | nn.Sequential( 84 | make_conv_block(nn.Conv2d(nc, nf, 3, 1, 1, bias=False), activation='leaky_relu'), 85 | make_conv_block(nn.Conv2d(nf, nf, 3, 1, 1, bias=False), activation='leaky_relu'), 86 | ), 87 | nn.Sequential( 88 | # nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 89 | make_conv_block(nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 90 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 91 | ), 92 | nn.Sequential( 93 | # nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 94 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 95 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 96 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 97 | ), 98 | nn.Sequential( 99 | # nn.MaxPool2d(kernel_size=2, stride=2, padding=0), 100 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 101 | make_conv_block(nn.Conv2d(nf * 2, nf * 2, 3, 1, 1, bias=False), activation='leaky_relu'), 102 | make_conv_block(nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False), activation='leaky_relu'), 103 | ) 104 | ]) 105 | self.last_conv = nn.Sequential( 106 | make_conv_block(nn.Conv2d(nf * 4, nh, 3, 1, 1, bias=False), activation='tanh') 107 | ) 108 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 109 | 110 | def forward(self, x, return_skip=False): 111 | """ 112 | Parameters 113 | ---------- 114 | x : torch.*.Tensor 115 | Encoder input. 116 | return_skip : bool 117 | Whether to extract and return, besides the network output, skip connections. 118 | Returns 119 | ------- 120 | torch.*.Tensor 121 | Encoder output as a tensor of shape (batch, size). 122 | list 123 | Only if return_skip is True. List of skip connections represented as torch.*.Tensor corresponding to each 124 | convolutional block in reverse order (from the deepest to the shallowest convolutional block). 125 | """ 126 | skips = [] 127 | h = x 128 | for i,layer in enumerate(self.conv): 129 | 130 | if i in [1,2]: 131 | h = self.maxpool(h) 132 | h_res = layer(h) 133 | print(i, h.shape, h_res.shape) 134 | h = h + h_res 135 | skips.append(h) 136 | h = self.last_conv(h) 137 | if return_skip: 138 | return h, skips[::-1] 139 | return h 140 | 141 | 142 | 143 | class VGG64Decoder(nn.Module): 144 | # a note: we are upsampling to 1/8 size 145 | # hard coded for now 146 | """ 147 | Module implementing the VGG decoder. 148 | """ 149 | def __init__(self, nc, ny, nf, skip): 150 | """ 151 | Parameters 152 | ---------- 153 | nc : int 154 | Number of channels in the output shape. 155 | ny : int 156 | Number of dimensions of the input flat vector. 
210 | """ 211 | assert skip is None and not self.skip or self.skip and skip is not None 212 | h = self.first_upconv(z) 213 | for i, layer in enumerate(self.conv): 214 | if skip is not None: 215 | h = torch.cat([h, skip[i]], 1) 216 | h_res = layer(h) 217 | h = h + h_res 218 | if i in [1,2]: 219 | h = self.upsample(h) 220 | x_ = h 221 | if sigmoid: 222 | x_ = torch.sigmoid(x_) 223 | return x_ 224 | 225 | 226 | class SELayer(nn.Module): 227 | def __init__(self, channel, reduction=8): 228 | super(SELayer, self).__init__() 229 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 230 | self.fc = nn.Sequential( 231 | nn.Linear(channel, channel // reduction, bias=False), 232 | nn.ReLU(inplace=True), 233 | nn.Linear(channel // reduction, channel, bias=False), 234 | nn.Sigmoid() 235 | ) 236 | 237 | def forward(self, x): 238 | b, c, _, _ = x.size() 239 | y = self.avg_pool(x).view(b, c) 240 | y = self.fc(y).view(b, c, 1, 1) 241 | return x * y.expand_as(x) 242 | 243 | 244 | class ConvNet(nn.Module): 245 | def __init__(self, in_channels, out_channels, nlayers): 246 | super(ConvNet, self).__init__() 247 | 248 | layers = [] 249 | in_c = in_channels 250 | for _ in range(nlayers - 1): 251 | layers += [ 252 | make_conv_block(nn.Conv2d(in_c, out_channels, 3, 1, 1, bias=False), activation='leaky_relu') 253 | ] 254 | in_c = out_channels 255 | layers += [SELayer(in_c)] 256 | layers += [make_conv_block(nn.Conv2d(in_c, out_channels, 3, 1, 1, bias=True), activation='none', bn=False)] 257 | self.model = nn.Sequential(*layers) 258 | 259 | def forward(self, x): 260 | return self.model(x) -------------------------------------------------------------------------------- /stretchbev/utils/geometry.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import numpy as np 3 | import torch 4 | 5 | from pyquaternion import Quaternion 6 | 7 | 8 | def resize_and_crop_image(img, resize_dims, crop): 9 | # Bilinear resizing followed by cropping 10 | img = img.resize(resize_dims, resample=PIL.Image.BILINEAR) 11 | img = img.crop(crop) 12 | return img 13 | 14 | 15 | def update_intrinsics(intrinsics, top_crop=0.0, left_crop=0.0, scale_width=1.0, scale_height=1.0): 16 | """ 17 | Parameters 18 | ---------- 19 | intrinsics: torch.Tensor (3, 3) 20 | top_crop: float 21 | left_crop: float 22 | scale_width: float 23 | scale_height: float 24 | """ 25 | updated_intrinsics = intrinsics.clone() 26 | # Adjust intrinsics scale due to resizing 27 | updated_intrinsics[0, 0] *= scale_width 28 | updated_intrinsics[0, 2] *= scale_width 29 | updated_intrinsics[1, 1] *= scale_height 30 | updated_intrinsics[1, 2] *= scale_height 31 | 32 | # Adjust principal point due to cropping 33 | updated_intrinsics[0, 2] -= left_crop 34 | updated_intrinsics[1, 2] -= top_crop 35 | 36 | return updated_intrinsics 37 | 38 | 39 | def calculate_birds_eye_view_parameters(x_bounds, y_bounds, z_bounds): 40 | """ 41 | Parameters 42 | ---------- 43 | x_bounds: Forward direction in the ego-car. 
44 | y_bounds: Sides 45 | z_bounds: Height 46 | 47 | Returns 48 | ------- 49 | bev_resolution: Bird's-eye view bev_resolution 50 | bev_start_position Bird's-eye view first element 51 | bev_dimension Bird's-eye view tensor spatial dimension 52 | """ 53 | bev_resolution = torch.tensor([row[2] for row in [x_bounds, y_bounds, z_bounds]]) 54 | bev_start_position = torch.tensor([row[0] + row[2] / 2.0 for row in [x_bounds, y_bounds, z_bounds]]) 55 | bev_dimension = torch.tensor([(row[1] - row[0]) / row[2] for row in [x_bounds, y_bounds, z_bounds]], 56 | dtype=torch.long) 57 | 58 | return bev_resolution, bev_start_position, bev_dimension 59 | 60 | 61 | def convert_egopose_to_matrix_numpy(egopose): 62 | transformation_matrix = np.zeros((4, 4), dtype=np.float32) 63 | rotation = Quaternion(egopose['rotation']).rotation_matrix 64 | translation = np.array(egopose['translation']) 65 | transformation_matrix[:3, :3] = rotation 66 | transformation_matrix[:3, 3] = translation 67 | transformation_matrix[3, 3] = 1.0 68 | return transformation_matrix 69 | 70 | 71 | def invert_matrix_egopose_numpy(egopose): 72 | """ Compute the inverse transformation of a 4x4 egopose numpy matrix.""" 73 | inverse_matrix = np.zeros((4, 4), dtype=np.float32) 74 | rotation = egopose[:3, :3] 75 | translation = egopose[:3, 3] 76 | inverse_matrix[:3, :3] = rotation.T 77 | inverse_matrix[:3, 3] = -np.dot(rotation.T, translation) 78 | inverse_matrix[3, 3] = 1.0 79 | return inverse_matrix 80 | 81 | 82 | def mat2pose_vec(matrix: torch.Tensor): 83 | """ 84 | Converts a 4x4 pose matrix into a 6-dof pose vector 85 | Args: 86 | matrix (ndarray): 4x4 pose matrix 87 | Returns: 88 | vector (ndarray): 6-dof pose vector comprising translation components (tx, ty, tz) and 89 | rotation components (rx, ry, rz) 90 | """ 91 | 92 | # M[1, 2] = -sinx*cosy, M[2, 2] = +cosx*cosy 93 | rotx = torch.atan2(-matrix[..., 1, 2], matrix[..., 2, 2]) 94 | 95 | # M[0, 2] = +siny, M[1, 2] = -sinx*cosy, M[2, 2] = +cosx*cosy 96 | cosy = torch.sqrt(matrix[..., 1, 2] ** 2 + matrix[..., 2, 2] ** 2) 97 | roty = torch.atan2(matrix[..., 0, 2], cosy) 98 | 99 | # M[0, 0] = +cosy*cosz, M[0, 1] = -cosy*sinz 100 | rotz = torch.atan2(-matrix[..., 0, 1], matrix[..., 0, 0]) 101 | 102 | rotation = torch.stack((rotx, roty, rotz), dim=-1) 103 | 104 | # Extract translation params 105 | translation = matrix[..., :3, 3] 106 | return torch.cat((translation, rotation), dim=-1) 107 | 108 | 109 | def euler2mat(angle: torch.Tensor): 110 | """Convert euler angles to rotation matrix. 
111 | Reference: https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174
112 | Args:
113 | angle: rotation angle along 3 axis (in radians) [Bx3]
114 | Returns:
115 | Rotation matrix corresponding to the euler angles [Bx3x3]
116 | """
117 | shape = angle.shape
118 | angle = angle.view(-1, 3)
119 | x, y, z = angle[:, 0], angle[:, 1], angle[:, 2]
120 |
121 | cosz = torch.cos(z)
122 | sinz = torch.sin(z)
123 |
124 | zeros = torch.zeros_like(z)
125 | ones = torch.ones_like(z)
126 | zmat = torch.stack([cosz, -sinz, zeros, sinz, cosz, zeros, zeros, zeros, ones], dim=1).view(-1, 3, 3)
127 |
128 | cosy = torch.cos(y)
129 | siny = torch.sin(y)
130 |
131 | ymat = torch.stack([cosy, zeros, siny, zeros, ones, zeros, -siny, zeros, cosy], dim=1).view(-1, 3, 3)
132 |
133 | cosx = torch.cos(x)
134 | sinx = torch.sin(x)
135 |
136 | xmat = torch.stack([ones, zeros, zeros, zeros, cosx, -sinx, zeros, sinx, cosx], dim=1).view(-1, 3, 3)
137 |
138 | rot_mat = xmat.bmm(ymat).bmm(zmat)
139 | rot_mat = rot_mat.view(*shape[:-1], 3, 3)
140 | return rot_mat
141 |
142 |
143 | def pose_vec2mat(vec: torch.Tensor):
144 | """
145 | Convert 6DoF parameters to transformation matrix.
146 | Args:
147 | vec: 6DoF parameters in the order of tx, ty, tz, rx, ry, rz [B,6]
148 | Returns:
149 | A transformation matrix [B,4,4]
150 | """
151 | translation = vec[..., :3].unsqueeze(-1) # [...x3x1]
152 | rot = vec[..., 3:].contiguous() # [...x3]
153 | rot_mat = euler2mat(rot) # [...,3,3]
154 | transform_mat = torch.cat([rot_mat, translation], dim=-1) # [...,3,4]
155 | transform_mat = torch.nn.functional.pad(transform_mat, [0, 0, 0, 1], value=0) # [...,4,4]
156 | transform_mat[..., 3, 3] = 1.0
157 | return transform_mat
158 |
159 |
160 | def invert_pose_matrix(x):
161 | """
162 | Parameters
163 | ----------
164 | x: [B, 4, 4] batch of pose matrices
165 |
166 | Returns
167 | -------
168 | out: [B, 4, 4] batch of inverse pose matrices
169 | """
170 | assert len(x.shape) == 3 and x.shape[1:] == (4, 4), 'Only works for batch of pose matrices.'
171 |
172 | transposed_rotation = torch.transpose(x[:, :3, :3], 1, 2)
173 | translation = x[:, :3, 3:]
174 |
175 | inverse_mat = torch.cat([transposed_rotation, -torch.bmm(transposed_rotation, translation)], dim=-1) # [B,3,4]
176 | inverse_mat = torch.nn.functional.pad(inverse_mat, [0, 0, 0, 1], value=0) # [B,4,4]
177 | inverse_mat[..., 3, 3] = 1.0
178 | return inverse_mat
179 |
180 |
181 | def warp_features(x, flow, mode='nearest', spatial_extent=None):
182 | """ Applies a rotation and translation to feature map x.
183 | Args:
184 | x: (b, c, h, w) feature map
185 | flow: (b, 6) 6DoF vector (only uses the xy portion)
186 | mode: use 'nearest' when dealing with categorical inputs
187 | Returns:
188 | in-plane transformed feature map
189 | """
190 | if flow is None:
191 | return x
192 | b, c, h, w = x.shape
193 | # z-rotation
194 | angle = flow[:, 5].clone() # torch.atan2(flow[:, 1, 0], flow[:, 0, 0])
195 | # x-y translation
196 | translation = flow[:, :2].clone() # flow[:, :2, 3]
197 |
198 | # Normalise the translation: divide by the number of metres in half of the image,
199 | # because a translation of 1.0 corresponds to half of the image.
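    # Editor's worked example (added note): with spatial_extent = (50.0, 50.0),
    # i.e. half of the BEV extent in metres along x and y, a 5 m forward motion
    # becomes 5 / 50 = 0.1 in normalised grid units before the sign flip below.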
200 | translation[:, 0] /= spatial_extent[0]
201 | translation[:, 1] /= spatial_extent[1]
202 | # forward axis is inverted
203 | translation[:, 0] *= -1
204 |
205 | cos_theta = torch.cos(angle)
206 | sin_theta = torch.sin(angle)
207 |
208 | # output = Rot.input + translation
209 | # tx and ty are inverted as is the case when going from real coordinates to numpy coordinates
210 | # translation_pos_0 -> positive value makes the image move to the left
211 | # translation_pos_1 -> positive value makes the image move to the top
212 | # Angle -> positive value in rad makes the image move in the trigonometric way
213 | transformation = torch.stack([cos_theta, -sin_theta, translation[:, 1],
214 | sin_theta, cos_theta, translation[:, 0]], dim=-1).view(b, 2, 3)
215 |
216 | # Note that a rotation will preserve distances only if height = width. Otherwise there's
217 | # resizing going on. e.g. rotation of pi/2 of a 100x200 image will make what's in the center of the image
218 | # elongated.
219 | grid = torch.nn.functional.affine_grid(transformation, size=x.shape, align_corners=False)
220 | warped_x = torch.nn.functional.grid_sample(x, grid.float(), mode=mode, padding_mode='zeros', align_corners=False)
221 |
222 | return warped_x
223 |
224 |
225 | def cumulative_warp_features(x, flow, mode='nearest', spatial_extent=None):
226 | """ Warps a sequence of feature maps by accumulating incremental 2d flow.
227 |
228 | x[:, -1] remains unchanged
229 | x[:, -2] is warped using flow[:, -2]
230 | x[:, -3] is warped using flow[:, -3] @ flow[:, -2]
231 | ...
232 | x[:, 0] is warped using flow[:, 0] @ ... @ flow[:, -3] @ flow[:, -2]
233 |
234 | Args:
235 | x: (b, t, c, h, w) sequence of feature maps
236 | flow: (b, t, 6) sequence of 6 DoF pose
237 | from t to t+1 (only uses the xy portion)
238 |
239 | """
240 | sequence_length = x.shape[1]
241 | if sequence_length == 1:
242 | return x
243 |
244 | flow = pose_vec2mat(flow)
245 |
246 | out = [x[:, -1]]
247 | cum_flow = flow[:, -2]
248 | for t in reversed(range(sequence_length - 1)):
249 | out.append(warp_features(x[:, t], mat2pose_vec(cum_flow), mode=mode, spatial_extent=spatial_extent))
250 | # @ is the equivalent of torch.bmm
251 | cum_flow = flow[:, t - 1] @ cum_flow
252 |
253 | return torch.stack(out[::-1], 1)
254 |
255 |
256 | def cumulative_warp_features_reverse(x, flow, mode='nearest', spatial_extent=None):
257 | """ Warps a sequence of feature maps by accumulating incremental 2d flow.
258 |
259 | x[:, 0] remains unchanged
260 | x[:, 1] is warped using flow[:, 0].inverse()
261 | x[:, 2] is warped using flow[:, 0].inverse() @ flow[:, 1].inverse()
262 | ...
263 |
264 | Args:
265 | x: (b, t, c, h, w) sequence of feature maps
266 | flow: (b, t, 6) sequence of 6 DoF pose
267 | from t to t+1 (only uses the xy portion)
268 |
269 | """
270 | flow = pose_vec2mat(flow)
271 |
272 | out = [x[:, 0]]
273 |
274 | for i in range(1, x.shape[1]):
275 | if i == 1:
276 | cum_flow = invert_pose_matrix(flow[:, 0])
277 | else:
278 | cum_flow = cum_flow @ invert_pose_matrix(flow[:, i - 1])
279 | out.append(warp_features(x[:, i], mat2pose_vec(cum_flow), mode, spatial_extent=spatial_extent))
280 | return torch.stack(out, 1)
281 |
282 |
283 | class VoxelsSumming(torch.autograd.Function):
284 | """Adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/tools.py#L193"""
285 |
286 | @staticmethod
287 | def forward(ctx, x, geometry, ranks):
288 | """The features `x` and `geometry` are ranked by voxel positions."""
289 | # Cumulative sum of all features.
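        # Editor's note: this is the standard cumulative-sum trick for segment
        # sums. Keeping only the last cumsum entry of each rank and differencing
        # consecutive kept entries yields per-voxel sums, e.g. ranks [0, 0, 1]
        # with x = [1, 2, 5] -> cumsum [1, 3, 8] -> kept [3, 8] -> sums [3, 5].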
290 | x = x.cumsum(0)
291 |
292 | # Indicates the change of voxel.
293 | mask = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
294 | mask[:-1] = ranks[1:] != ranks[:-1]
295 |
296 | x, geometry = x[mask], geometry[mask]
297 | # Calculate sum of features within a voxel.
298 | x = torch.cat((x[:1], x[1:] - x[:-1]))
299 |
300 | ctx.save_for_backward(mask)
301 | ctx.mark_non_differentiable(geometry)
302 |
303 | return x, geometry
304 |
305 | @staticmethod
306 | def backward(ctx, grad_x, grad_geometry):
307 | (mask,) = ctx.saved_tensors
308 | # Since the operation is summing, we simply need to send gradient
309 | # to all elements that were part of the summation process.
310 | indices = torch.cumsum(mask, 0)
311 | indices[mask] -= 1
312 |
313 | output_grad = grad_x[indices]
314 |
315 | return output_grad, None, None
316 |
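
# --- Editor's note: illustrative round trip, not part of the original file. ---
# pose_vec2mat and mat2pose_vec invert each other (for pitch within +/- pi/2),
# and composing a pose with invert_pose_matrix recovers the identity.
if __name__ == '__main__':
    vec = torch.tensor([[1.0, 0.5, 0.0, 0.0, 0.0, 0.1]])  # tx, ty, tz, rx, ry, rz
    mat = pose_vec2mat(vec)  # [1, 4, 4]
    assert torch.allclose(mat2pose_vec(mat), vec, atol=1e-5)
    identity = invert_pose_matrix(mat) @ mat
    assert torch.allclose(identity, torch.eye(4).unsqueeze(0), atol=1e-5)
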
--------------------------------------------------------------------------------
/stretchbev/layers/temporal.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from stretchbev.layers.convolutions import ConvBlock
7 | from stretchbev.utils.geometry import warp_features
8 |
9 |
10 | class SpatialGRU(nn.Module):
11 | """A GRU cell that takes an input tensor [BxTxCxHxW] and an optional previous state and passes a
12 | convolutional gated recurrent unit over the data"""
13 |
14 | def __init__(self, input_size, hidden_size, gru_bias_init=0.0, norm='bn', activation='relu'):
15 | super().__init__()
16 | self.input_size = input_size
17 | self.hidden_size = hidden_size
18 | self.gru_bias_init = gru_bias_init
19 |
20 | self.conv_update = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)
21 | self.conv_reset = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)
22 |
23 | self.conv_state_tilde = ConvBlock(
24 | input_size + hidden_size, hidden_size, kernel_size=3, bias=False, norm=norm, activation=activation
25 | )
26 |
27 | def forward(self, x, state=None, flow=None, mode='bilinear'):
28 | # pylint: disable=unused-argument, arguments-differ
29 | # Check size
30 | assert len(x.size()) == 5, 'Input tensor must be BxTxCxHxW.'
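        # Editor's note: the loop below unrolls the GRU along the time axis;
        # when an ego-motion flow is given, the hidden state is first warped
        # into the current frame so the recurrence runs in a shared reference.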
31 | b, timesteps, c, h, w = x.size() 32 | assert c == self.input_size, f'feature sizes must match, got input {c} for layer with size {self.input_size}' 33 | 34 | # recurrent layers 35 | rnn_output = [] 36 | rnn_state = torch.zeros(b, self.hidden_size, h, w, device=x.device) if state is None else state 37 | for t in range(timesteps): 38 | x_t = x[:, t] 39 | if flow is not None: 40 | rnn_state = warp_features(rnn_state, flow[:, t], mode=mode) 41 | 42 | # propagate rnn state 43 | rnn_state = self.gru_cell(x_t, rnn_state) 44 | rnn_output.append(rnn_state) 45 | 46 | # reshape rnn output to batch tensor 47 | return torch.stack(rnn_output, dim=1) 48 | 49 | def gru_cell(self, x, state): 50 | # Compute gates 51 | x_and_state = torch.cat([x, state], dim=1) 52 | update_gate = self.conv_update(x_and_state) 53 | reset_gate = self.conv_reset(x_and_state) 54 | # Add bias to initialise gate as close to identity function 55 | update_gate = torch.sigmoid(update_gate + self.gru_bias_init) 56 | reset_gate = torch.sigmoid(reset_gate + self.gru_bias_init) 57 | 58 | # Compute proposal state, activation is defined in norm_act_config (can be tanh, ReLU etc) 59 | state_tilde = self.conv_state_tilde(torch.cat([x, (1.0 - reset_gate) * state], dim=1)) 60 | 61 | output = (1.0 - update_gate) * state + update_gate * state_tilde 62 | return output 63 | 64 | 65 | class CausalConv3d(nn.Module): 66 | def __init__(self, in_channels, out_channels, kernel_size=(2, 3, 3), dilation=(1, 1, 1), bias=False): 67 | super().__init__() 68 | assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.' 69 | time_pad = (kernel_size[0] - 1) * dilation[0] 70 | height_pad = ((kernel_size[1] - 1) * dilation[1]) // 2 71 | width_pad = ((kernel_size[2] - 1) * dilation[2]) // 2 72 | 73 | # Pad temporally on the left 74 | self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0) 75 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, dilation=dilation, stride=1, padding=0, bias=bias) 76 | self.norm = nn.BatchNorm3d(out_channels) 77 | self.activation = nn.ReLU(inplace=True) 78 | 79 | def forward(self, *inputs): 80 | (x,) = inputs 81 | x = self.pad(x) 82 | x = self.conv(x) 83 | x = self.norm(x) 84 | x = self.activation(x) 85 | return x 86 | 87 | 88 | class CausalMaxPool3d(nn.Module): 89 | def __init__(self, kernel_size=(2, 3, 3)): 90 | super().__init__() 91 | assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.' 
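        # Editor's note: padding is applied only on the left of the time axis,
        # keeping the pooling causal: with kernel_size[0] == 2, time_pad == 1,
        # so the output at step t sees only steps t-1 and t.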
92 | time_pad = kernel_size[0] - 1
93 | height_pad = (kernel_size[1] - 1) // 2
94 | width_pad = (kernel_size[2] - 1) // 2
95 |
96 | # Pad temporally on the left
97 | self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0)
98 | self.max_pool = nn.MaxPool3d(kernel_size, stride=1)
99 |
100 | def forward(self, *inputs):
101 | (x,) = inputs
102 | x = self.pad(x)
103 | x = self.max_pool(x)
104 | return x
105 |
106 |
107 | def conv_1x1x1_norm_activated(in_channels, out_channels):
108 | """1x1x1 3D convolution, normalization and activation layer."""
109 | return nn.Sequential(
110 | OrderedDict(
111 | [
112 | ('conv', nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False)),
113 | ('norm', nn.BatchNorm3d(out_channels)),
114 | ('activation', nn.ReLU(inplace=True)),
115 | ]
116 | )
117 | )
118 |
119 |
120 | class Bottleneck3D(nn.Module):
121 | """
122 | Defines a bottleneck module with a residual connection
123 | """
124 |
125 | def __init__(self, in_channels, out_channels=None, kernel_size=(2, 3, 3), dilation=(1, 1, 1)):
126 | super().__init__()
127 | bottleneck_channels = in_channels // 2
128 | out_channels = out_channels or in_channels
129 |
130 | self.layers = nn.Sequential(
131 | OrderedDict(
132 | [
133 | # First projection with 1x1 kernel
134 | (
135 | 'conv_down_project', conv_1x1x1_norm_activated(in_channels, bottleneck_channels)
136 | ),
137 | # Second conv block
138 | (
139 | 'conv',
140 | CausalConv3d(
141 | bottleneck_channels,
142 | bottleneck_channels,
143 | kernel_size=kernel_size,
144 | dilation=dilation,
145 | bias=False,
146 | ),
147 | ),
148 | # Final projection with 1x1 kernel
149 | (
150 | 'conv_up_project', conv_1x1x1_norm_activated(bottleneck_channels, out_channels)
151 | ),
152 | ]
153 | )
154 | )
155 |
156 | if out_channels != in_channels:
157 | self.projection = nn.Sequential(
158 | nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False),
159 | nn.BatchNorm3d(out_channels),
160 | )
161 | else:
162 | self.projection = None
163 |
164 | def forward(self, *args):
165 | (x,) = args
166 | x_residual = self.layers(x)
167 | x_features = self.projection(x) if self.projection is not None else x
168 | return x_residual + x_features
169 |
170 |
171 | class PyramidSpatioTemporalPooling(nn.Module):
172 | """ Spatio-temporal pyramid pooling.
173 | Performs 3D average pooling followed by 1x1x1 convolution to reduce the number of channels and upsampling.
174 | Setting contains a list of kernel_size: usually it is [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)]
175 | """
176 |
177 | def __init__(self, in_channels, reduction_channels, pool_sizes):
178 | super().__init__()
179 | self.features = []
180 | for pool_size in pool_sizes:
181 | assert pool_size[0] == 2, (
182 | "Time kernel should be 2 as PyTorch raises an error when " "padding with more than half the kernel size"
183 | )
184 | stride = (1, *pool_size[1:])
185 | padding = (pool_size[0] - 1, 0, 0)
186 | self.features.append(
187 | nn.Sequential(
188 | OrderedDict(
189 | [
190 | # Pad the input tensor, but do not include the zero padding in the average.
191 | ( 192 | 'avgpool', 193 | torch.nn.AvgPool3d( 194 | kernel_size=pool_size, stride=stride, padding=padding, count_include_pad=False 195 | ), 196 | ), 197 | ( 198 | 'conv_bn_relu', conv_1x1x1_norm_activated(in_channels, reduction_channels) 199 | ), 200 | ] 201 | ) 202 | ) 203 | ) 204 | self.features = nn.ModuleList(self.features) 205 | 206 | def forward(self, *inputs): 207 | (x,) = inputs 208 | b, _, t, h, w = x.shape 209 | # Do not include current tensor when concatenating 210 | out = [] 211 | for f in self.features: 212 | # Remove unnecessary padded values (time dimension) on the right 213 | x_pool = f(x)[:, :, :-1].contiguous() 214 | c = x_pool.shape[1] 215 | x_pool = nn.functional.interpolate( 216 | x_pool.view(b * t, c, *x_pool.shape[-2:]), (h, w), mode='bilinear', align_corners=False 217 | ) 218 | x_pool = x_pool.view(b, c, t, h, w) 219 | out.append(x_pool) 220 | out = torch.cat(out, 1) 221 | return out 222 | 223 | 224 | class TemporalBlock(nn.Module): 225 | """ Temporal block with the following layers: 226 | - 2x3x3, 1x3x3, spatio-temporal pyramid pooling 227 | - dropout 228 | - skip connection. 229 | """ 230 | 231 | def __init__(self, in_channels, out_channels=None, use_pyramid_pooling=False, pool_sizes=None): 232 | super().__init__() 233 | self.in_channels = in_channels 234 | self.half_channels = in_channels // 2 235 | self.out_channels = out_channels or self.in_channels 236 | self.kernels = [(2, 3, 3), (1, 3, 3)] 237 | 238 | # Flag for spatio-temporal pyramid pooling 239 | self.use_pyramid_pooling = use_pyramid_pooling 240 | 241 | # 3 convolution paths: 2x3x3, 1x3x3, 1x1x1 242 | self.convolution_paths = [] 243 | for kernel_size in self.kernels: 244 | self.convolution_paths.append( 245 | nn.Sequential( 246 | conv_1x1x1_norm_activated(self.in_channels, self.half_channels), 247 | CausalConv3d(self.half_channels, self.half_channels, kernel_size=kernel_size), 248 | ) 249 | ) 250 | self.convolution_paths.append(conv_1x1x1_norm_activated(self.in_channels, self.half_channels)) 251 | self.convolution_paths = nn.ModuleList(self.convolution_paths) 252 | 253 | agg_in_channels = len(self.convolution_paths) * self.half_channels 254 | 255 | if self.use_pyramid_pooling: 256 | assert pool_sizes is not None, "setting must contain the list of kernel_size, but is None." 
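            # Editor's note: each pyramid scale adds `reduction_channels`
            # channels on top of the 3 * half_channels from the convolution
            # paths, so agg_in_channels grows by len(pool_sizes) * reduction_channels.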
257 | reduction_channels = self.in_channels // 3 258 | self.pyramid_pooling = PyramidSpatioTemporalPooling(self.in_channels, reduction_channels, pool_sizes) 259 | agg_in_channels += len(pool_sizes) * reduction_channels 260 | 261 | # Feature aggregation 262 | self.aggregation = nn.Sequential( 263 | conv_1x1x1_norm_activated(agg_in_channels, self.out_channels), ) 264 | 265 | if self.out_channels != self.in_channels: 266 | self.projection = nn.Sequential( 267 | nn.Conv3d(self.in_channels, self.out_channels, kernel_size=1, bias=False), 268 | nn.BatchNorm3d(self.out_channels), 269 | ) 270 | else: 271 | self.projection = None 272 | 273 | def forward(self, *inputs): 274 | (x,) = inputs 275 | x_paths = [] 276 | for conv in self.convolution_paths: 277 | x_paths.append(conv(x)) 278 | x_residual = torch.cat(x_paths, dim=1) 279 | if self.use_pyramid_pooling: 280 | x_pool = self.pyramid_pooling(x) 281 | x_residual = torch.cat([x_residual, x_pool], dim=1) 282 | x_residual = self.aggregation(x_residual) 283 | 284 | if self.out_channels != self.in_channels: 285 | x = self.projection(x) 286 | x = x + x_residual 287 | return x 288 | -------------------------------------------------------------------------------- /stretchbev/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes 5 | from pytorch_lightning.metrics.functional.reduction import reduce 6 | from pytorch_lightning.metrics.metric import Metric 7 | 8 | 9 | class IntersectionOverUnion(Metric): 10 | """Computes intersection-over-union.""" 11 | 12 | def __init__( 13 | self, 14 | n_classes: int, 15 | ignore_index: Optional[int] = None, 16 | absent_score: float = 0.0, 17 | reduction: str = 'none', 18 | compute_on_step: bool = False, 19 | ): 20 | super().__init__(compute_on_step=compute_on_step) 21 | 22 | self.n_classes = n_classes 23 | self.ignore_index = ignore_index 24 | self.absent_score = absent_score 25 | self.reduction = reduction 26 | 27 | self.add_state('true_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum') 28 | self.add_state('false_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum') 29 | self.add_state('false_negative', default=torch.zeros(n_classes), dist_reduce_fx='sum') 30 | self.add_state('support', default=torch.zeros(n_classes), dist_reduce_fx='sum') 31 | 32 | def update(self, prediction: torch.Tensor, target: torch.Tensor): 33 | tps, fps, _, fns, sups = stat_scores_multiple_classes(prediction, target, self.n_classes) 34 | 35 | # print('update', tps, fps, fns, sups) 36 | self.true_positive += tps 37 | self.false_positive += fps 38 | self.false_negative += fns 39 | self.support += sups 40 | 41 | # return tps, fps, fns, sups 42 | 43 | def compute(self): 44 | scores = torch.zeros(self.n_classes, device=self.true_positive.device, dtype=torch.float32) 45 | 46 | # print('compute', self.true_positive, self.false_positive, self.false_negative, self.support) 47 | for class_idx in range(self.n_classes): 48 | if class_idx == self.ignore_index: 49 | continue 50 | 51 | tp = self.true_positive[class_idx] 52 | fp = self.false_positive[class_idx] 53 | fn = self.false_negative[class_idx] 54 | sup = self.support[class_idx] 55 | 56 | # If this class is absent in the target (no support) AND absent in the pred (no true or false 57 | # positives), then use the absent_score for this class. 
58 | if sup + tp + fp == 0: 59 | scores[class_idx] = self.absent_score 60 | continue 61 | 62 | denominator = tp + fp + fn 63 | score = tp.to(torch.float) / denominator 64 | scores[class_idx] = score 65 | 66 | # Remove the ignored class index from the scores. 67 | if (self.ignore_index is not None) and (0 <= self.ignore_index < self.n_classes): 68 | scores = torch.cat([scores[:self.ignore_index], scores[self.ignore_index + 1:]]) 69 | 70 | return reduce(scores, reduction=self.reduction) 71 | 72 | def calculate_batch(self, prediction: torch.Tensor, target: torch.Tensor): 73 | self.reset() 74 | self.update(prediction, target) 75 | 76 | return self.compute() 77 | 78 | 79 | class PanopticMetric(Metric): 80 | def __init__( 81 | self, 82 | n_classes: int, 83 | temporally_consistent: bool = True, 84 | vehicles_id: int = 1, 85 | compute_on_step: bool = False, 86 | ): 87 | super().__init__(compute_on_step=compute_on_step) 88 | 89 | self.n_classes = n_classes 90 | self.temporally_consistent = temporally_consistent 91 | self.vehicles_id = vehicles_id 92 | self.keys = ['iou', 'true_positive', 'false_positive', 'false_negative'] 93 | 94 | self.add_state('iou', default=torch.zeros(n_classes), dist_reduce_fx='sum') 95 | self.add_state('true_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum') 96 | self.add_state('false_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum') 97 | self.add_state('false_negative', default=torch.zeros(n_classes), dist_reduce_fx='sum') 98 | 99 | def update(self, pred_instance, gt_instance): 100 | """ 101 | Update state with predictions and targets. 102 | 103 | Parameters 104 | ---------- 105 | pred_instance: (b, s, h, w) 106 | Temporally consistent instance segmentation prediction. 107 | gt_instance: (b, s, h, w) 108 | Ground truth instance segmentation. 109 | """ 110 | batch_size, sequence_length = gt_instance.shape[:2] 111 | # Process labels 112 | assert gt_instance.min() == 0, 'ID 0 of gt_instance must be background' 113 | pred_segmentation = (pred_instance > 0).long() 114 | gt_segmentation = (gt_instance > 0).long() 115 | 116 | for b in range(batch_size): 117 | unique_id_mapping = {} 118 | for t in range(sequence_length): 119 | result = self.panoptic_metrics( 120 | pred_segmentation[b, t].detach(), 121 | pred_instance[b, t].detach(), 122 | gt_segmentation[b, t], 123 | gt_instance[b, t], 124 | unique_id_mapping, 125 | ) 126 | 127 | self.iou += result['iou'] 128 | self.true_positive += result['true_positive'] 129 | self.false_positive += result['false_positive'] 130 | self.false_negative += result['false_negative'] 131 | 132 | def compute(self): 133 | denominator = torch.maximum( 134 | (self.true_positive + self.false_positive / 2 + self.false_negative / 2), 135 | torch.ones_like(self.true_positive) 136 | ) 137 | pq = self.iou / denominator 138 | sq = self.iou / torch.maximum(self.true_positive, torch.ones_like(self.true_positive)) 139 | rq = self.true_positive / denominator 140 | 141 | return {'pq': pq, 142 | 'sq': sq, 143 | 'rq': rq, 144 | #  If 0, it means there wasn't any detection. 145 | 'denominator': (self.true_positive + self.false_positive / 2 + self.false_negative / 2), 146 | } 147 | 148 | def calculate_batch(self, pred_instance, gt_instance): 149 | self.reset() 150 | self.update(pred_instance, gt_instance) 151 | return self.compute() 152 | 153 | def panoptic_metrics(self, pred_segmentation, pred_instance, gt_segmentation, gt_instance, unique_id_mapping): 154 | """ 155 | Computes panoptic quality metric components. 
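        Per class, PQ = (sum of matched IoUs) / (TP + FP/2 + FN/2), which
        factors into SQ = (sum of matched IoUs) / TP and
        RQ = TP / (TP + FP/2 + FN/2), as assembled in compute().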
156 | 157 | Parameters 158 | ---------- 159 | pred_segmentation: [H, W] range {0, ..., n_classes-1} (>= n_classes is void) 160 | pred_instance: [H, W] range {0, ..., n_instances} (zero means background) 161 | gt_segmentation: [H, W] range {0, ..., n_classes-1} (>= n_classes is void) 162 | gt_instance: [H, W] range {0, ..., n_instances} (zero means background) 163 | unique_id_mapping: instance id mapping to check consistency 164 | """ 165 | n_classes = self.n_classes 166 | 167 | result = {key: torch.zeros(n_classes, dtype=torch.float32, device=gt_instance.device) for key in self.keys} 168 | 169 | assert pred_segmentation.dim() == 2 170 | assert pred_segmentation.shape == pred_instance.shape == gt_segmentation.shape == gt_instance.shape 171 | 172 | n_instances = int(torch.cat([pred_instance, gt_instance]).max().item()) 173 | n_all_things = n_instances + n_classes # Classes + instances. 174 | n_things_and_void = n_all_things + 1 175 | 176 | #  Now 1 is background; 0 is void (not used). 2 is vehicle semantic class but since it overlaps with 177 | # instances, it is not present. 178 | # and the rest are instance ids starting from 3 179 | prediction, pred_to_cls = self.combine_mask(pred_segmentation, pred_instance, n_classes, n_all_things) 180 | target, target_to_cls = self.combine_mask(gt_segmentation, gt_instance, n_classes, n_all_things) 181 | 182 | # Compute ious between all stuff and things 183 | # hack for bincounting 2 arrays together 184 | x = prediction + n_things_and_void * target 185 | bincount_2d = torch.bincount(x.long(), minlength=n_things_and_void ** 2) 186 | if bincount_2d.shape[0] != n_things_and_void ** 2: 187 | raise ValueError('Incorrect bincount size.') 188 | conf = bincount_2d.reshape((n_things_and_void, n_things_and_void)) 189 | # Drop void class 190 | conf = conf[1:, 1:] 191 | 192 | # Confusion matrix contains intersections between all combinations of classes 193 | union = conf.sum(0).unsqueeze(0) + conf.sum(1).unsqueeze(1) - conf 194 | iou = torch.where(union > 0, (conf.float() + 1e-9) / (union.float() + 1e-9), torch.zeros_like(union).float()) 195 | 196 | # In the iou matrix, first dimension is target idx, second dimension is pred idx. 197 | # Mapping will contain a tuple that maps prediction idx to target idx for segments matched by iou. 198 | mapping = (iou > 0.5).nonzero(as_tuple=False) 199 | 200 | # Check that classes match. 201 | is_matching = pred_to_cls[mapping[:, 1]] == target_to_cls[mapping[:, 0]] 202 | mapping = mapping[is_matching] 203 | tp_mask = torch.zeros_like(conf, dtype=torch.bool) 204 | tp_mask[mapping[:, 0], mapping[:, 1]] = True 205 | 206 | # First ids correspond to "stuff" i.e. semantic seg. 207 | # Instance ids are offset accordingly 208 | for target_id, pred_id in mapping: 209 | cls_id = pred_to_cls[pred_id] 210 | 211 | if self.temporally_consistent and cls_id == self.vehicles_id: 212 | if target_id.item() in unique_id_mapping and unique_id_mapping[target_id.item()] != pred_id.item(): 213 | # Not temporally consistent 214 | result['false_negative'][target_to_cls[target_id]] += 1 215 | result['false_positive'][pred_to_cls[pred_id]] += 1 216 | unique_id_mapping[target_id.item()] = pred_id.item() 217 | continue 218 | 219 | result['true_positive'][cls_id] += 1 220 | result['iou'][cls_id] += iou[target_id][pred_id] 221 | unique_id_mapping[target_id.item()] = pred_id.item() 222 | 223 | for target_id in range(n_classes, n_all_things): 224 | # If this is a true positive do nothing. 
225 | if tp_mask[target_id, n_classes:].any():
226 | continue
227 | # If this target instance didn't match any prediction and was present, set it as a false negative.
228 | if target_to_cls[target_id] != -1:
229 | result['false_negative'][target_to_cls[target_id]] += 1
230 |
231 | for pred_id in range(n_classes, n_all_things):
232 | # If this is a true positive do nothing.
233 | if tp_mask[n_classes:, pred_id].any():
234 | continue
235 | # If this predicted instance didn't match any target, set it as a false positive.
236 | if pred_to_cls[pred_id] != -1 and (conf[:, pred_id] > 0).any():
237 | result['false_positive'][pred_to_cls[pred_id]] += 1
238 |
239 | return result
240 |
241 | def combine_mask(self, segmentation: torch.Tensor, instance: torch.Tensor, n_classes: int, n_all_things: int):
242 | """Shifts all things ids by num_classes and combines things and stuff into a single mask
243 |
244 | Returns a combined mask + a mapping from id to segmentation class.
245 | """
246 | instance = instance.view(-1)
247 | instance_mask = instance > 0
248 | instance = instance - 1 + n_classes
249 |
250 | segmentation = segmentation.clone().view(-1)
251 | segmentation_mask = segmentation < n_classes # Remove void pixels.
252 |
253 | # Build an index from instance id to class id.
254 | instance_id_to_class_tuples = torch.cat(
255 | (
256 | instance[instance_mask & segmentation_mask].unsqueeze(1),
257 | segmentation[instance_mask & segmentation_mask].unsqueeze(1),
258 | ),
259 | dim=1,
260 | )
261 | instance_id_to_class = -instance_id_to_class_tuples.new_ones((n_all_things,))
262 | instance_id_to_class[instance_id_to_class_tuples[:, 0]] = instance_id_to_class_tuples[:, 1]
263 | instance_id_to_class[torch.arange(n_classes, device=segmentation.device)] = torch.arange(
264 | n_classes, device=segmentation.device
265 | )
266 |
267 | segmentation[instance_mask] = instance[instance_mask]
268 | segmentation += 1 # Shift all legit classes by 1.
269 | segmentation[~segmentation_mask] = 0 # Shift void class to zero.
270 |
271 | return segmentation, instance_id_to_class
--------------------------------------------------------------------------------
/stretchbev/utils/visualisation.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pylab
2 | import numpy as np
3 | import torch
4 |
5 | from stretchbev.utils.instance import predict_instance_segmentation_and_trajectories
6 |
7 | DEFAULT_COLORMAP = matplotlib.pylab.cm.jet
8 |
9 |
10 | def flow_to_image(flow: np.ndarray, autoscale: bool = False) -> np.ndarray:
11 | """
12 | Applies a colour map to flow, which should be a 2-channel image tensor of shape 2xHxW.
/stretchbev/utils/visualisation.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pylab
2 | import numpy as np
3 | import torch
4 | 
5 | from stretchbev.utils.instance import predict_instance_segmentation_and_trajectories
6 | 
7 | DEFAULT_COLORMAP = matplotlib.pylab.cm.jet
8 | 
9 | 
10 | def flow_to_image(flow: np.ndarray, autoscale: bool = False) -> np.ndarray:
11 |     """
12 |     Applies a colour map to flow, which should be a 2 channel image tensor of shape 2xHxW. Returns a HxWx3 numpy image.
13 |     Code adapted from: https://github.com/liruoteng/FlowNet/blob/master/models/flownet/scripts/flowlib.py
14 |     """
15 |     u = flow[0, :, :]
16 |     v = flow[1, :, :]
17 | 
18 |     # Convert to polar coordinates
19 |     rad = np.sqrt(u ** 2 + v ** 2)
20 |     maxrad = np.max(rad)
21 | 
22 |     # Normalise flow maps
23 |     if autoscale:
24 |         u /= maxrad + np.finfo(float).eps
25 |         v /= maxrad + np.finfo(float).eps
26 | 
27 |     # visualise flow with cmap
28 |     return np.uint8(compute_color(u, v) * 255)
29 | 
30 | 
31 | def _normalise(image: np.ndarray) -> np.ndarray:
32 |     lower = np.min(image)
33 |     delta = np.max(image) - lower
34 |     if delta == 0:
35 |         delta = 1
36 |     image = (image.astype(np.float32) - lower) / delta
37 |     return image
38 | 
39 | 
40 | def apply_colour_map(
41 |     image: np.ndarray, cmap: matplotlib.colors.LinearSegmentedColormap = DEFAULT_COLORMAP, autoscale: bool = False
42 | ) -> np.ndarray:
43 |     """
44 |     Applies a colour map to the given 1 or 2 channel numpy image. If 2 channel, it must be 2xHxW.
45 |     Returns a HxWx3 numpy image.
46 |     """
47 |     if image.ndim == 2 or (image.ndim == 3 and image.shape[0] == 1):
48 |         if image.ndim == 3:
49 |             image = image[0]
50 |         # grayscale scalar image
51 |         if autoscale:
52 |             image = _normalise(image)
53 |         return cmap(image)[:, :, :3]
54 |     if image.shape[0] == 2:
55 |         # 2 dimensional UV
56 |         return flow_to_image(image, autoscale=autoscale)
57 |     if image.shape[0] == 3:
58 |         # normalise rgb channels
59 |         if autoscale:
60 |             image = _normalise(image)
61 |         return np.transpose(image, axes=[1, 2, 0])
62 |     raise Exception('Image must be 1, 2 or 3 channel to convert to colour_map (CxHxW)')
63 | 
64 | 
65 | def heatmap_image(
66 |     image: np.ndarray, cmap: matplotlib.colors.LinearSegmentedColormap = DEFAULT_COLORMAP, autoscale: bool = True
67 | ) -> np.ndarray:
68 |     """Colorize a 1 or 2 channel image with a colourmap."""
69 |     if not issubclass(image.dtype.type, np.floating):
70 |         raise ValueError(f"Expected a ndarray of float type, but got dtype {image.dtype}")
71 |     if not (image.ndim == 2 or (image.ndim == 3 and image.shape[0] in [1, 2])):
72 |         raise ValueError(f"Expected a ndarray of shape [H, W] or [1, H, W] or [2, H, W], but got shape {image.shape}")
73 |     heatmap_np = apply_colour_map(image, cmap=cmap, autoscale=autoscale)
74 |     heatmap_np = np.uint8(heatmap_np * 255)
75 |     return heatmap_np
76 | 
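A short usage sketch of the helper above (synthetic input, assuming this module is in scope; not part of the original file):

import numpy as np

centerness = np.random.rand(200, 200).astype(np.float32)  # [H, W] scalar map in [0, 1]
rgb = heatmap_image(centerness)                           # jet-coloured uint8 image, [H, W, 3]
assert rgb.shape == (200, 200, 3) and rgb.dtype == np.uint8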
77 | 
78 | def compute_color(u: np.ndarray, v: np.ndarray) -> np.ndarray:
79 |     assert u.shape == v.shape
80 |     [h, w] = u.shape
81 |     img = np.zeros([h, w, 3])
82 |     nan_mask = np.isnan(u) | np.isnan(v)
83 |     u[nan_mask] = 0
84 |     v[nan_mask] = 0
85 | 
86 |     colorwheel = make_color_wheel()
87 |     ncols = np.size(colorwheel, 0)
88 | 
89 |     rad = np.sqrt(u ** 2 + v ** 2)
90 |     a = np.arctan2(-v, -u) / np.pi
91 |     f_k = (a + 1) / 2 * (ncols - 1) + 1
92 |     k_0 = np.floor(f_k).astype(int)
93 |     k_1 = k_0 + 1
94 |     k_1[k_1 == ncols + 1] = 1
95 |     f = f_k - k_0
96 | 
97 |     for i in range(0, np.size(colorwheel, 1)):
98 |         tmp = colorwheel[:, i]
99 |         col0 = tmp[k_0 - 1] / 255
100 |         col1 = tmp[k_1 - 1] / 255
101 |         col = (1 - f) * col0 + f * col1
102 | 
103 |         idx = rad <= 1
104 |         col[idx] = 1 - rad[idx] * (1 - col[idx])
105 |         notidx = np.logical_not(idx)
106 | 
107 |         col[notidx] *= 0.75
108 |         img[:, :, i] = col * (1 - nan_mask)
109 | 
110 |     return img
111 | 
112 | 
113 | def make_color_wheel() -> np.ndarray:
114 |     """
115 |     Create colour wheel.
116 |     Code adapted from https://github.com/liruoteng/FlowNet/blob/master/models/flownet/scripts/flowlib.py
117 |     """
118 |     red_yellow = 15
119 |     yellow_green = 6
120 |     green_cyan = 4
121 |     cyan_blue = 11
122 |     blue_magenta = 13
123 |     magenta_red = 6
124 | 
125 |     ncols = red_yellow + yellow_green + green_cyan + cyan_blue + blue_magenta + magenta_red
126 |     colorwheel = np.zeros([ncols, 3])
127 | 
128 |     col = 0
129 | 
130 |     # red_yellow
131 |     colorwheel[0:red_yellow, 0] = 255
132 |     colorwheel[0:red_yellow, 1] = np.transpose(np.floor(255 * np.arange(0, red_yellow) / red_yellow))
133 |     col += red_yellow
134 | 
135 |     # yellow_green
136 |     colorwheel[col: col + yellow_green, 0] = 255 - np.transpose(
137 |         np.floor(255 * np.arange(0, yellow_green) / yellow_green)
138 |     )
139 |     colorwheel[col: col + yellow_green, 1] = 255
140 |     col += yellow_green
141 | 
142 |     # green_cyan
143 |     colorwheel[col: col + green_cyan, 1] = 255
144 |     colorwheel[col: col + green_cyan, 2] = np.transpose(np.floor(255 * np.arange(0, green_cyan) / green_cyan))
145 |     col += green_cyan
146 | 
147 |     # cyan_blue
148 |     colorwheel[col: col + cyan_blue, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, cyan_blue) / cyan_blue))
149 |     colorwheel[col: col + cyan_blue, 2] = 255
150 |     col += cyan_blue
151 | 
152 |     # blue_magenta
153 |     colorwheel[col: col + blue_magenta, 2] = 255
154 |     colorwheel[col: col + blue_magenta, 0] = np.transpose(np.floor(255 * np.arange(0, blue_magenta) / blue_magenta))
155 |     col += blue_magenta
156 | 
157 |     # magenta_red
158 |     colorwheel[col: col + magenta_red, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, magenta_red) / magenta_red))
159 |     colorwheel[col: col + magenta_red, 0] = 255
160 | 
161 |     return colorwheel
162 | 
163 | 
164 | def make_contour(img, colour=[0, 0, 0], double_line=False):
165 |     h, w = img.shape[:2]
166 |     out = img.copy()
167 |     # Vertical lines
168 |     out[np.arange(h), np.repeat(0, h)] = colour
169 |     out[np.arange(h), np.repeat(w - 1, h)] = colour
170 | 
171 |     # Horizontal lines
172 |     out[np.repeat(0, w), np.arange(w)] = colour
173 |     out[np.repeat(h - 1, w), np.arange(w)] = colour
174 | 
175 |     if double_line:
176 |         out[np.arange(h), np.repeat(1, h)] = colour
177 |         out[np.arange(h), np.repeat(w - 2, h)] = colour
178 | 
179 |         # Horizontal lines
180 |         out[np.repeat(1, w), np.arange(w)] = colour
181 |         out[np.repeat(h - 2, w), np.arange(w)] = colour
182 |     return out
183 | 
184 | 
185 | def plot_instance_map(instance_image, instance_map, instance_colours=None, bg_image=None):
186 |     if isinstance(instance_image, torch.Tensor):
187 |         instance_image = instance_image.cpu().numpy()
188 |     assert isinstance(instance_image, np.ndarray)
189 |     if instance_colours is None:
190 |         instance_colours = generate_instance_colours(instance_map)
191 |     if len(instance_image.shape) > 2:
192 |         instance_image = instance_image.reshape((instance_image.shape[-2], instance_image.shape[-1]))
193 | 
194 |     if bg_image is None:
195 |         plot_image = 255 * np.ones((instance_image.shape[0], instance_image.shape[1], 3), dtype=np.uint8)
196 |     else:
197 |         plot_image = bg_image
198 | 
199 |     for key, value in instance_colours.items():
200 |         plot_image[instance_image == key] = value
201 | 
202 |     return plot_image
203 | 
204 | 
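A small sanity check for the flow-colouring path above (synthetic 2xHxW field; illustrative only, not part of the original file):

import numpy as np

# Hypothetical flow field: rightward motion on the left half, upward on the right.
flow = np.zeros((2, 64, 64), dtype=np.float32)
flow[0, :, :32] = 1.0   # u component
flow[1, :, 32:] = -1.0  # v component

rgb = flow_to_image(flow, autoscale=True)  # colour wheel encodes direction, saturation encodes magnitude
rgb = make_contour(rgb)                    # add a black border, as visualise_output does
assert rgb.shape == (64, 64, 3) and rgb.dtype == np.uint8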
205 | def visualise_output(labels, output, cfg):
206 |     semantic_colours = np.array([[255, 255, 255], [0, 0, 0]], dtype=np.uint8)
207 | 
208 |     consistent_instance_seg = predict_instance_segmentation_and_trajectories(
209 |         output, compute_matched_centers=False
210 |     )
211 | 
212 |     sequence_length = consistent_instance_seg.shape[1]
213 |     b = 0
214 |     video = []
215 |     for t in range(sequence_length):
216 |         out_t = []
217 |         # Ground truth
218 |         unique_ids = torch.unique(labels['instance'][b, t]).cpu().numpy()[1:]
219 |         instance_map = dict(zip(unique_ids, unique_ids))
220 |         instance_plot = plot_instance_map(labels['instance'][b, t].cpu(), instance_map)[::-1, ::-1]
221 |         instance_plot = make_contour(instance_plot)
222 | 
223 |         semantic_seg = labels['segmentation'].squeeze(2).cpu().numpy()
224 |         semantic_plot = semantic_colours[semantic_seg[b, t][::-1, ::-1]]
225 |         semantic_plot = make_contour(semantic_plot)
226 | 
227 |         if cfg.INSTANCE_FLOW.ENABLED:
228 |             future_flow_plot = labels['flow'][b, t].cpu().numpy()
229 |             future_flow_plot[:, semantic_seg[b, t] != 1] = 0
230 |             future_flow_plot = flow_to_image(future_flow_plot)[::-1, ::-1]
231 |             future_flow_plot = make_contour(future_flow_plot)
232 |         else:
233 |             future_flow_plot = np.zeros_like(semantic_plot)
234 | 
235 |         center_plot = heatmap_image(labels['centerness'][b, t, 0].cpu().numpy())[::-1, ::-1]
236 |         center_plot = make_contour(center_plot)
237 | 
238 |         offset_plot = labels['offset'][b, t].cpu().numpy()
239 |         offset_plot[:, semantic_seg[b, t] != 1] = 0
240 |         offset_plot = flow_to_image(offset_plot)[::-1, ::-1]
241 |         offset_plot = make_contour(offset_plot)
242 | 
243 |         out_t.append(np.concatenate([instance_plot, future_flow_plot,
244 |                                      semantic_plot, center_plot, offset_plot], axis=0))
245 | 
246 |         # Predictions
247 |         unique_ids = torch.unique(consistent_instance_seg[b, t]).cpu().numpy()[1:]
248 |         instance_map = dict(zip(unique_ids, unique_ids))
249 |         instance_plot = plot_instance_map(consistent_instance_seg[b, t].cpu(), instance_map)[::-1, ::-1]
250 |         instance_plot = make_contour(instance_plot)
251 | 
252 |         semantic_seg = output['segmentation'].argmax(dim=2).detach().cpu().numpy()
253 |         semantic_plot = semantic_colours[semantic_seg[b, t][::-1, ::-1]]
254 |         semantic_plot = make_contour(semantic_plot)
255 | 
256 |         if cfg.INSTANCE_FLOW.ENABLED:
257 |             future_flow_plot = output['instance_flow'][b, t].detach().cpu().numpy()
258 |             future_flow_plot[:, semantic_seg[b, t] != 1] = 0
259 |             future_flow_plot = flow_to_image(future_flow_plot)[::-1, ::-1]
260 |             future_flow_plot = make_contour(future_flow_plot)
261 |         else:
262 |             future_flow_plot = np.zeros_like(semantic_plot)
263 | 
264 |         center_plot = heatmap_image(output['instance_center'][b, t, 0].detach().cpu().numpy())[::-1, ::-1]
265 |         center_plot = make_contour(center_plot)
266 | 
267 |         offset_plot = output['instance_offset'][b, t].detach().cpu().numpy()
268 |         offset_plot[:, semantic_seg[b, t] != 1] = 0
269 |         offset_plot = flow_to_image(offset_plot)[::-1, ::-1]
270 |         offset_plot = make_contour(offset_plot)
271 | 
272 |         out_t.append(np.concatenate([instance_plot, future_flow_plot,
273 |                                      semantic_plot, center_plot, offset_plot], axis=0))
274 |         out_t = np.concatenate(out_t, axis=1)
275 |         # Shape (C, H, W)
276 |         out_t = out_t.transpose((2, 0, 1))
277 | 
278 |         video.append(out_t)
279 | 
280 |     # Shape (B, T, C, H, W)
281 |     video = np.stack(video)[None]
282 |     return video
283 | 
284 | 
285 | def convert_figure_numpy(figure):
286 |     """ Convert figure to numpy image """
287 |     figure_np = np.frombuffer(figure.canvas.tostring_rgb(), dtype=np.uint8)
288 |     figure_np = figure_np.reshape(figure.canvas.get_width_height()[::-1] + (3,))
289 |     return figure_np
290 | 
291 | 
292 | def generate_instance_colours(instance_map):
293 |     # Most distinct 22 colors (Kelly colors, from https://stackoverflow.com/questions/470690/how-to-automatically-generate
294 |     # -n-distinct-colors)
295 |     # plus some colours from ADE20k
296 |     INSTANCE_COLOURS = np.asarray([
297 |         [0, 0, 0],
298 |         [255, 179, 0],
299 |         [128, 62, 117],
300 |         [255, 104, 0],
301 |         [166, 189, 215],
302 |         [193, 0, 32],
303 |         [206, 162, 98],
304 |         [129, 112, 102],
305 |         [0, 125, 52],
306 |         [246, 118, 142],
307 |         [0, 83, 138],
308 |         [255, 122, 92],
309 |         [83, 55, 122],
310 |         [255, 142, 0],
311 |         [179, 40, 81],
312 |         [244, 200, 0],
313 |         [127, 24, 13],
314 |         [147, 170, 0],
315 |         [89, 51, 21],
316 |         [241, 58, 19],
317 |         [35, 44, 22],
318 |         [112, 224, 255],
319 |         [70, 184, 160],
320 |         [153, 0, 255],
321 |         [71, 255, 0],
322 |         [255, 0, 163],
323 |         [255, 204, 0],
324 |         [0, 255, 235],
325 |         [255, 0, 235],
326 |         [255, 0, 122],
327 |         [255, 245, 0],
328 |         [10, 190, 212],
329 |         [214, 255, 0],
330 |         [0, 204, 255],
331 |         [20, 0, 255],
332 |         [255, 255, 0],
333 |         [0, 153, 255],
334 |         [0, 255, 204],
335 |         [41, 255, 0],
336 |         [173, 0, 255],
337 |         [0, 245, 255],
338 |         [71, 0, 255],
339 |         [0, 255, 184],
340 |         [0, 92, 255],
341 |         [184, 255, 0],
342 |         [255, 214, 0],
343 |         [25, 194, 194],
344 |         [92, 0, 255],
345 |         [220, 220, 220],
346 |         [255, 9, 92],
347 |         [112, 9, 255],
348 |         [8, 255, 214],
349 |         [255, 184, 6],
350 |         [10, 255, 71],
351 |         [255, 41, 10],
352 |         [7, 255, 255],
353 |         [224, 255, 8],
354 |         [102, 8, 255],
355 |         [255, 61, 6],
356 |         [255, 194, 7],
357 |         [0, 255, 20],
358 |         [255, 8, 41],
359 |         [255, 5, 153],
360 |         [6, 51, 255],
361 |         [235, 12, 255],
362 |         [160, 150, 20],
363 |         [0, 163, 255],
364 |         [140, 140, 140],
365 |         [250, 10, 15],
366 |         [20, 255, 0],
367 |     ])
368 | 
369 |     return {instance_id: INSTANCE_COLOURS[global_instance_id % len(INSTANCE_COLOURS)] for
370 |             instance_id, global_instance_id in instance_map.items()
371 |             }
372 | 
--------------------------------------------------------------------------------
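`generate_instance_colours` keys each local instance id to a palette colour via its globally tracked id, so a tracked instance keeps its colour across frames. A tiny illustration (ids are made up):

# Hypothetical instance_map: local instance id -> globally tracked id.
instance_map = {1: 17, 2: 93}
colours = generate_instance_colours(instance_map)
# Instance 1 is drawn with INSTANCE_COLOURS[17 % 70]; instance 2 with INSTANCE_COLOURS[93 % 70].
assert set(colours) == {1, 2}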
/stretchbev/.ipynb_checkpoints/trainer-checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pytorch_lightning as pl
4 | import torch.distributions as distrib
5 | import torch.nn.functional as F
6 | 
7 | 
8 | 
9 | from fiery.config import get_cfg
10 | from fiery.models.fiery import Fiery
11 | from fiery.losses import ProbabilisticLoss, SpatialRegressionLoss, SegmentationLoss
12 | from fiery.metrics import IntersectionOverUnion, PanopticMetric
13 | from fiery.utils.geometry import cumulative_warp_features_reverse
14 | from fiery.utils.instance import predict_instance_segmentation_and_trajectories
15 | from fiery.utils.visualisation import visualise_output
16 | from fiery.models import model_utils
17 | 
18 | 
19 | class TrainingModule(pl.LightningModule):
20 |     def __init__(self, hparams):
21 |         super().__init__()
22 | 
23 |         # see config.py for details
24 |         self.hparams = hparams
25 |         # pytorch lightning does not support saving YACS CfgNode
26 |         cfg = get_cfg(cfg_dict=self.hparams)
27 |         self.cfg = cfg
28 |         self.n_classes = len(self.cfg.SEMANTIC_SEG.WEIGHTS)
29 | 
30 |         # Bird's-eye view extent in meters
31 |         assert self.cfg.LIFT.X_BOUND[1] > 0 and self.cfg.LIFT.Y_BOUND[1] > 0
32 |         self.spatial_extent = (self.cfg.LIFT.X_BOUND[1], self.cfg.LIFT.Y_BOUND[1])
33 | 
34 |         # Model
35 |         self.model = Fiery(cfg)
36 | 
37 |         # Losses
38 |         self.losses_fn = nn.ModuleDict()
39 |         self.losses_fn['segmentation'] = SegmentationLoss(
40 |             class_weights=torch.Tensor(self.cfg.SEMANTIC_SEG.WEIGHTS),
41 |             use_top_k=self.cfg.SEMANTIC_SEG.USE_TOP_K,
42 |             top_k_ratio=self.cfg.SEMANTIC_SEG.TOP_K_RATIO,
43 |             future_discount=self.cfg.FUTURE_DISCOUNT,
44 |         )
45 | 
46 |         # Uncertainty weighting
47 |         self.model.segmentation_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
48 | 
49 |         self.metric_iou_val = IntersectionOverUnion(self.n_classes)
50 | 
51 |         self.losses_fn['instance_center'] = SpatialRegressionLoss(
52 |             norm=2, future_discount=self.cfg.FUTURE_DISCOUNT
53 |         )
54 |         self.losses_fn['instance_offset'] = SpatialRegressionLoss(
55 |             norm=1, future_discount=self.cfg.FUTURE_DISCOUNT, ignore_index=self.cfg.DATASET.IGNORE_INDEX
56 |         )
57 | 
58 |         # Uncertainty weighting
59 |         # self.model.centerness_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
60 |         # self.model.offset_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
61 | 
62 |         self.metric_panoptic_val = PanopticMetric(n_classes=self.n_classes)
63 | 
64 |         if self.cfg.INSTANCE_FLOW.ENABLED:
65 |             self.losses_fn['instance_flow'] = SpatialRegressionLoss(
66 |                 norm=1, future_discount=self.cfg.FUTURE_DISCOUNT, ignore_index=self.cfg.DATASET.IGNORE_INDEX
67 |             )
68 |             # Uncertainty weighting
69 |             # self.model.flow_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
70 | 
71 |         # self.model.kl_y_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
72 |         # self.model.kl_z_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
73 |         # self.model.reconstruction_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
74 | 
75 |         self.training_step_count = 0
76 | 
77 |     def shared_step(self, batch, is_train):
78 |         image = batch['image']
79 |         intrinsics = batch['intrinsics']
80 |         extrinsics = batch['extrinsics']
81 |         future_egomotion = batch['future_egomotion']
82 | 
83 |         # Warp labels
84 |         labels, future_distribution_inputs = self.prepare_future_labels(batch)
85 | 
86 |         # Forward pass
87 |         all_output = self.model(
88 |             image, intrinsics, extrinsics, future_egomotion, future_distribution_inputs
89 |         )
90 | 
91 |         #####
92 |         # Loss computation
93 |         #####
94 |         output = all_output['bev_output']
95 |         loss = {}
96 | 
97 |         segmentation_factor = 5.44  # 1 / torch.exp(self.model.segmentation_weight)
98 |         loss['segmentation'] = segmentation_factor * self.losses_fn['segmentation'](
99 |             output['segmentation'], labels['segmentation']
100 |         )
101 |         # loss['segmentation_uncertainty'] = 0.5 * self.model.segmentation_weight
102 | 
103 |         centerness_factor = 591  # 1 / (2 * torch.exp(self.model.centerness_weight))
104 |         loss['instance_center'] = centerness_factor * self.losses_fn['instance_center'](
105 |             output['instance_center'], labels['centerness']
106 |         )
107 | 
108 |         offset_factor = 1.56  # 1 / (2 * torch.exp(self.model.offset_weight))
109 |         loss['instance_offset'] = offset_factor * self.losses_fn['instance_offset'](
110 |             output['instance_offset'], labels['offset']
111 |         )
112 | 
113 |         # loss['centerness_uncertainty'] = 0.5 * self.model.centerness_weight
114 |         # loss['offset_uncertainty'] = 0.5 * self.model.offset_weight
115 | 
116 |         if self.cfg.INSTANCE_FLOW.ENABLED:
117 |             flow_factor = 3.90  # 1 / (2 * torch.exp(self.model.flow_weight))
118 |             loss['instance_flow'] = flow_factor * self.losses_fn['instance_flow'](
119 |                 output['instance_flow'], labels['flow']
120 |             )
121 | 
122 |         reconstruction_factor = 10  # 1 / (2 * torch.exp(self.model.reconstruction_weight))
123 |         loss['state_reconstruction'] = reconstruction_factor * F.mse_loss(all_output['generated_srvp_x'], all_output['lss_outs'])
124 |         # loss['reconstruction_uncertainty'] = 0.5 * self.model.reconstruction_weight
125 | 
126 |         kl_y_factor = 1e-2  # 1 / (2 * torch.exp(self.model.kl_y_weight))
127 |         q_y_0 = model_utils.make_normal_from_raw_params(all_output['q_y0_params'])
128 |         kl_y_0 = distrib.kl_divergence(q_y_0, distrib.Normal(0, 1)).mean()
129 | 
130 |         loss['kl_y0'] = kl_y_factor * kl_y_0
131 |         # loss['kl_y0_uncertainty'] = 0.5 * self.model.kl_y_weight
132 | 
133 |         kl_z_factor = 1e-2  # 1 / (2 * torch.exp(self.model.kl_z_weight))
134 |         q_z, p_z = model_utils.make_normal_from_raw_params(all_output['q_z_params']), model_utils.make_normal_from_raw_params(all_output['p_z_params'])
135 |         kl_z = distrib.kl_divergence(q_z, p_z).mean()
136 | 
137 |         loss['kl_z'] = kl_z_factor * kl_z
138 |         # loss['kl_z_uncertainty'] = 0.5 * self.model.kl_z_weight
139 | 
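The hard-coded factors above stand in for the learned homoscedastic uncertainty weighting that the commented-out `*_weight` parameters hint at. A minimal sketch of that scheme, reconstructed from the comments and not this checkpoint's active code:

import torch
import torch.nn as nn

class UncertaintyWeighting(nn.Module):
    """Learned log-variance s_i per task: loss_i / (2 * exp(s_i)) + s_i / 2."""

    def __init__(self, n_tasks: int):
        super().__init__()
        self.s = nn.Parameter(torch.zeros(n_tasks))

    def forward(self, losses):
        losses = torch.stack(losses)
        # The s_i / 2 term regularises the weights so they cannot grow unboundedly.
        return (losses / (2 * torch.exp(self.s)) + 0.5 * self.s).sum()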
140 |         # Metrics
141 |         if not is_train:
142 |             seg_prediction = output['segmentation'].detach()
143 |             seg_prediction = torch.argmax(seg_prediction, dim=2, keepdims=True)
144 |             self.metric_iou_val(seg_prediction, labels['segmentation'])
145 | 
146 |             pred_consistent_instance_seg = predict_instance_segmentation_and_trajectories(
147 |                 output, compute_matched_centers=False
148 |             )
149 | 
150 |             self.metric_panoptic_val(pred_consistent_instance_seg, labels['instance'])
151 | 
152 |         return all_output, labels, loss
153 | 
154 |     def prepare_future_labels(self, batch, receptive_field=1):
155 |         labels = {}
156 |         future_distribution_inputs = []
157 | 
158 |         segmentation_labels = batch['segmentation']
159 |         instance_center_labels = batch['centerness']
160 |         instance_offset_labels = batch['offset']
161 |         instance_flow_labels = batch['flow']
162 |         gt_instance = batch['instance']
163 |         future_egomotion = batch['future_egomotion']
164 | 
165 |         # Warp labels to the present's reference frame
166 |         segmentation_labels = cumulative_warp_features_reverse(
167 |             segmentation_labels[:, (receptive_field - 1):].float(),
168 |             future_egomotion[:, (receptive_field - 1):],
169 |             mode='nearest', spatial_extent=self.spatial_extent,
170 |         ).long().contiguous()
171 |         labels['segmentation'] = segmentation_labels
172 |         future_distribution_inputs.append(segmentation_labels)
173 | 
174 |         # Warp instance labels to the present's reference frame
175 |         gt_instance = cumulative_warp_features_reverse(
176 |             gt_instance[:, (receptive_field - 1):].float().unsqueeze(2),
177 |             future_egomotion[:, (receptive_field - 1):],
178 |             mode='nearest', spatial_extent=self.spatial_extent,
179 |         ).long().contiguous()[:, :, 0]
180 |         labels['instance'] = gt_instance
181 | 
182 |         instance_center_labels = cumulative_warp_features_reverse(
183 |             instance_center_labels[:, (receptive_field - 1):],
184 |             future_egomotion[:, (receptive_field - 1):],
185 |             mode='nearest', spatial_extent=self.spatial_extent,
186 |         ).contiguous()
187 |         labels['centerness'] = instance_center_labels
188 | 
189 |         instance_offset_labels = cumulative_warp_features_reverse(
190 |             instance_offset_labels[:, (receptive_field - 1):],
191 |             future_egomotion[:, (receptive_field - 1):],
192 |             mode='nearest', spatial_extent=self.spatial_extent,
193 |         ).contiguous()
194 |         labels['offset'] = instance_offset_labels
195 | 
196 |         future_distribution_inputs.append(instance_center_labels)
197 |         future_distribution_inputs.append(instance_offset_labels)
198 | 
199 |         if self.cfg.INSTANCE_FLOW.ENABLED:
200 |             instance_flow_labels = cumulative_warp_features_reverse(
201 |                 instance_flow_labels[:, (receptive_field - 1):],
202 |                 future_egomotion[:, (receptive_field - 1):],
203 |                 mode='nearest', spatial_extent=self.spatial_extent,
204 |             ).contiguous()
205 |             labels['flow'] = instance_flow_labels
206 | 
207 |             future_distribution_inputs.append(instance_flow_labels)
208 | 
209 | 
210 |         if len(future_distribution_inputs) > 0:
211 |             future_distribution_inputs = torch.cat(future_distribution_inputs, dim=2)
212 | 
213 |             b, t, n, h, w = future_distribution_inputs.shape
214 |             future_distribution_inputs = F.adaptive_max_pool2d(future_distribution_inputs.reshape(b, t * n, h, w), 50).reshape(b, t, n, 50, 50)
215 | 
216 |         return labels, future_distribution_inputs
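The pooling at the end of `prepare_future_labels` folds time into the channel axis so a single 2D pool downsamples every frame at once. A shape-only sketch with arbitrary sizes:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 6, 200, 200)  # (b, t, n, h, w) stack of warped labels
b, t, n, h, w = x.shape
pooled = F.adaptive_max_pool2d(x.reshape(b, t * n, h, w), 50).reshape(b, t, n, 50, 50)
assert pooled.shape == (2, 3, 6, 50, 50)  # every frame's labels downsampled to 50x50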
217 | 
218 |     def visualise(self, labels, output, batch_idx, prefix='train'):
219 |         visualisation_video = visualise_output(labels, output['bev_output'], self.cfg)
220 |         name = f'{prefix}_outputs'
221 |         if prefix == 'val':
222 |             name = name + f'_{batch_idx}'
223 |         self.logger.experiment.add_video(name, visualisation_video, global_step=self.training_step_count, fps=2)
224 | 
225 |     def training_step(self, batch, batch_idx):
226 |         output, labels, loss = self.shared_step(batch, True)
227 |         self.training_step_count += 1
228 |         for key, value in loss.items():
229 |             self.logger.experiment.add_scalar(key, value, global_step=self.training_step_count)
230 |         if self.training_step_count % self.cfg.VIS_INTERVAL == 0:
231 |             self.visualise(labels, output, batch_idx, prefix='train')
232 |         return sum(loss.values())
233 | 
234 |     def validation_step(self, batch, batch_idx):
235 |         output, labels, loss = self.shared_step(batch, False)
236 |         for key, value in loss.items():
237 |             self.log('val_' + key, value)
238 | 
239 |         if batch_idx == 0:
240 |             self.visualise(labels, output, batch_idx, prefix='val')
241 | 
242 |     def shared_epoch_end(self, step_outputs, is_train):
243 |         # log per class iou metrics
244 |         class_names = ['background', 'dynamic']
245 |         if not is_train:
246 |             scores = self.metric_iou_val.compute()
247 |             for key, value in zip(class_names, scores):
248 |                 self.logger.experiment.add_scalar('val_iou_' + key, value, global_step=self.training_step_count)
249 |             self.metric_iou_val.reset()
250 | 
251 |         if not is_train:
252 |             scores = self.metric_panoptic_val.compute()
253 |             for key, value in scores.items():
254 |                 for instance_name, score in zip(['background', 'vehicles'], value):
255 |                     if instance_name != 'background':
256 |                         self.logger.experiment.add_scalar(f'val_{key}_{instance_name}', score.item(),
257 |                                                           global_step=self.training_step_count)
258 |             self.metric_panoptic_val.reset()
259 | 
260 |         # self.logger.experiment.add_scalar('segmentation_weight',
261 |         #                                   1 / (torch.exp(self.model.segmentation_weight)),
262 |         #                                   global_step=self.training_step_count)
263 |         # self.logger.experiment.add_scalar('centerness_weight',
264 |         #                                   1 / (2 * torch.exp(self.model.centerness_weight)),
265 |         #                                   global_step=self.training_step_count)
266 |         # self.logger.experiment.add_scalar('offset_weight', 1 / (2 * torch.exp(self.model.offset_weight)),
267 |         #                                   global_step=self.training_step_count)
268 |         # if self.cfg.INSTANCE_FLOW.ENABLED:
269 |         #     self.logger.experiment.add_scalar('flow_weight', 1 / (2 * torch.exp(self.model.flow_weight)),
270 |         #                                       global_step=self.training_step_count)
271 | 
272 |         # self.logger.experiment.add_scalar('reconstruction_weight', 1 / (2 * torch.exp(self.model.reconstruction_weight)),
273 |         #                                   global_step=self.training_step_count)
274 |         # self.logger.experiment.add_scalar('kl_y_weight', 1 / (2 * torch.exp(self.model.kl_y_weight)),
275 |         #                                   global_step=self.training_step_count)
276 |         # self.logger.experiment.add_scalar('kl_z_weight', 1 / (2 * torch.exp(self.model.kl_z_weight)),
277 |         #                                   global_step=self.training_step_count)
278 | 
279 | 
280 |     def training_epoch_end(self, step_outputs):
281 |         self.shared_epoch_end(step_outputs, True)
282 | 
283 |     def validation_epoch_end(self, step_outputs):
284 |         self.shared_epoch_end(step_outputs, False)
285 | 
286 |     def configure_optimizers(self):
287 |         params = self.model.parameters()
288 |         optimizer = torch.optim.Adam(
289 |             params, lr=self.cfg.OPTIMIZER.LR, weight_decay=self.cfg.OPTIMIZER.WEIGHT_DECAY
290 |         )
291 |         # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60*self.len_loader, 80*self.len_loader], gamma=0.1)
292 | 
293 |         return optimizer
294 | 
--------------------------------------------------------------------------------
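For reference, the two KL terms in `shared_step` are closed-form Gaussian KLs. A self-contained sketch with `torch.distributions`, independent of this repo's `make_normal_from_raw_params` (shapes arbitrary):

import torch
import torch.distributions as distrib

mean, std = torch.zeros(4, 8), torch.ones(4, 8) * 0.5
q = distrib.Normal(mean, std)                                   # posterior over the latent
kl_y0 = distrib.kl_divergence(q, distrib.Normal(0, 1)).mean()   # against a standard normal prior, as for y_0

p = distrib.Normal(torch.zeros(4, 8), torch.ones(4, 8))
kl_z = distrib.kl_divergence(q, p).mean()                       # posterior vs learned prior, as for the z variables
print(kl_y0.item(), kl_z.item())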