├── .gitignore ├── fiery ├── configs │ ├── literature │ │ ├── static_lss_setting.yml │ │ ├── lift_splat_setting.yml │ │ ├── static_pon_setting.yml │ │ ├── pon_setting.yml │ │ └── fishing_setting.yml │ ├── lyft │ │ ├── debug_lyft.yml │ │ ├── single_timeframe.yml │ │ └── baseline.yml │ ├── debug_baseline.yml │ ├── temporal_single_timeframe.yml │ ├── single_timeframe.yml │ └── baseline.yml ├── utils │ ├── lyft_splits.py │ ├── network.py │ ├── geometry.py │ ├── visualisation.py │ └── instance.py ├── models │ ├── future_prediction.py │ ├── distributions.py │ ├── temporal_model.py │ ├── decoder.py │ ├── encoder.py │ └── fiery.py ├── losses.py ├── config.py ├── layers │ ├── convolutions.py │ └── temporal.py ├── metrics.py ├── trainer.py └── data.py ├── environment.yml ├── DATASET.md ├── LICENSE ├── train.py ├── evaluate.py ├── README.md └── visualise.py /.gitignore: -------------------------------------------------------------------------------- 1 | tensorboard_logs/ 2 | example_data/ 3 | output_vis/ 4 | .idea/ 5 | *.pyc 6 | *.ipynb_checkpoints 7 | *.ipynb 8 | *.ckpt -------------------------------------------------------------------------------- /fiery/configs/literature/static_lss_setting.yml: -------------------------------------------------------------------------------- 1 | _BASE_: '../single_timeframe.yml' 2 | 3 | TAG: 'lift_splat_setting' 4 | 5 | DATASET: 6 | FILTER_INVISIBLE_VEHICLES: False 7 | -------------------------------------------------------------------------------- /fiery/configs/literature/lift_splat_setting.yml: -------------------------------------------------------------------------------- 1 | _BASE_: '../temporal_single_timeframe.yml' 2 | 3 | TAG: 'temporal_lift_splat_setting' 4 | 5 | DATASET: 6 | FILTER_INVISIBLE_VEHICLES: False 7 | -------------------------------------------------------------------------------- /fiery/configs/literature/static_pon_setting.yml: -------------------------------------------------------------------------------- 1 | _BASE_: 'static_lss_setting.yml' 2 | 3 | TAG: 'pyramid_occupancy_network_setting' 4 | 5 | LIFT: 6 | X_BOUND: [-50.0, 50.0, 0.25] 7 | Y_BOUND: [-25.0, 25.0, 0.25] 8 | -------------------------------------------------------------------------------- /fiery/configs/lyft/debug_lyft.yml: -------------------------------------------------------------------------------- 1 | _BASE_: 'baseline.yml' 2 | 3 | TAG: 'debug' 4 | 5 | BATCHSIZE: 1 6 | GPUS: [0] 7 | 8 | DATASET: 9 | VERSION: 'mini' 10 | 11 | VIS_INTERVAL: 4 12 | 13 | N_WORKERS: 0 14 | -------------------------------------------------------------------------------- /fiery/configs/lyft/single_timeframe.yml: -------------------------------------------------------------------------------- 1 | _BASE_: '../single_timeframe.yml' 2 | 3 | TAG: 'lyft_single_frame' 4 | 5 | DATASET: 6 | NAME: 'lyft' 7 | 8 | IMAGE: 9 | H: 1080 10 | W: 1920 11 | RESIZE_SCALE: 0.25 12 | -------------------------------------------------------------------------------- /fiery/configs/debug_baseline.yml: -------------------------------------------------------------------------------- 1 | _BASE_: 'baseline.yml' 2 | 3 | TAG: 'debug' 4 | 5 | EPOCHS: 2 6 | BATCHSIZE: 1 7 | GPUS: [0] 8 | LOGGING_INTERVAL: 10 9 | 10 | DATASET: 11 | VERSION: 'mini' 12 | 13 | VIS_INTERVAL: 8 14 | -------------------------------------------------------------------------------- /fiery/configs/literature/pon_setting.yml: -------------------------------------------------------------------------------- 1 | _BASE_: '../temporal_single_timeframe.yml' 2 
| 3 | TAG: 'temporal_pon_setting' 4 | 5 | DATASET: 6 | FILTER_INVISIBLE_VEHICLES: False 7 | 8 | LIFT: 9 | X_BOUND: [-50.0, 50.0, 0.25] 10 | Y_BOUND: [-25.0, 25.0, 0.25] 11 | --------------------------------------------------------------------------------
/fiery/configs/literature/fishing_setting.yml: -------------------------------------------------------------------------------- 1 | _BASE_: '../baseline.yml' 2 | 3 | TAG: 'fishing_setting' 4 | 5 | BATCHSIZE: 3 6 | 7 | DATASET: 8 | FILTER_INVISIBLE_VEHICLES: False 9 | 10 | LIFT: 11 | D_BOUND: [2.0, 16.0, 0.5] 12 | X_BOUND: [-16.0, 16.0, 0.1] 13 | Y_BOUND: [-9.6, 9.6, 0.1] 14 | --------------------------------------------------------------------------------
/fiery/configs/temporal_single_timeframe.yml: -------------------------------------------------------------------------------- 1 | _BASE_: 'single_timeframe.yml' 2 | 3 | TAG: 'temporal_single_timeframe' 4 | 5 | BATCHSIZE: 4 6 | PRECISION: 16 7 | 8 | TIME_RECEPTIVE_FIELD: 3 9 | 10 | MODEL: 11 | BN_MOMENTUM: 0.05 12 | TEMPORAL_MODEL: 13 | NAME: 'temporal_block' 14 | INPUT_EGOPOSE: True 15 | --------------------------------------------------------------------------------
/fiery/configs/lyft/baseline.yml: -------------------------------------------------------------------------------- 1 | _BASE_: '../baseline.yml' 2 | 3 | TAG: 'lyft_baseline' 4 | 5 | GPUS: [0, 1, 2, 3] 6 | 7 | BATCHSIZE: 3 8 | TIME_RECEPTIVE_FIELD: 5 9 | N_FUTURE_FRAMES: 10 10 | 11 | DATASET: 12 | NAME: 'lyft' 13 | 14 | IMAGE: 15 | H: 1080 16 | W: 1920 17 | RESIZE_SCALE: 0.25 18 | 19 | MODEL: 20 | SUBSAMPLE: True 21 | --------------------------------------------------------------------------------
/fiery/configs/single_timeframe.yml: -------------------------------------------------------------------------------- 1 | TAG: 'single_timeframe_model' 2 | 3 | GPUS: [0, 1] 4 | 5 | BATCHSIZE: 8 6 | 7 | TIME_RECEPTIVE_FIELD: 1 8 | N_FUTURE_FRAMES: 0 9 | 10 | PROBABILISTIC: 11 | ENABLED: False 12 | 13 | MODEL: 14 | TEMPORAL_MODEL: 15 | NAME: 'identity' 16 | INPUT_EGOPOSE: False 17 | 18 | INSTANCE_FLOW: 19 | ENABLED: False 20 | 21 | OPTIMIZER: 22 | LR: 1e-3 23 | 24 | N_WORKERS: 10 25 | --------------------------------------------------------------------------------
/fiery/configs/baseline.yml: -------------------------------------------------------------------------------- 1 | TAG: 'baseline' 2 | 3 | GPUS: [0, 1, 2, 3] 4 | 5 | BATCHSIZE: 3 6 | PRECISION: 16 7 | 8 | TIME_RECEPTIVE_FIELD: 3 9 | N_FUTURE_FRAMES: 4 10 | 11 | PROBABILISTIC: 12 | ENABLED: True 13 | 14 | MODEL: 15 | BN_MOMENTUM: 0.05 16 | TEMPORAL_MODEL: 17 | NAME: 'temporal_block' 18 | INPUT_EGOPOSE: True 19 | 20 | INSTANCE_FLOW: 21 | ENABLED: True 22 | 23 | OPTIMIZER: 24 | LR: 3e-4 25 | 26 | N_WORKERS: 10 27 | --------------------------------------------------------------------------------
/environment.yml: -------------------------------------------------------------------------------- 1 | name: fiery 2 | channels: 3 | - defaults 4 | - conda-forge 5 | - pytorch 6 | dependencies: 7 | - python=3.7.10 8 | - pytorch=1.7.0 9 | - torchvision=0.8.1 10 | - cudatoolkit=10.1 11 | - numpy=1.19.2 12 | - scipy=1.5.2 13 | - pillow=8.0.1 14 | - tqdm=4.50.2 15 | - pytorch-lightning=1.2.5 16 | - efficientnet-pytorch=0.7.0 17 | - fvcore=0.1.2.post20201122 18 | - pip=21.0.1 19 | - pip: 20 | - nuscenes-devkit==1.1.0 21 | - lyft-dataset-sdk==0.0.8 22 | - opencv-python==4.5.1.48 23 | - moviepy==1.0.3 24 | --------------------------------------------------------------------------------
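The `LIFT` bounds in the configs above (for example `pon_setting.yml` and `fishing_setting.yml`) are `[min, max, step]` triplets in metres, and the resulting bird's-eye view grid follows from them in the same way as `calculate_birds_eye_view_parameters` in `fiery/utils/geometry.py` further down this listing. The snippet below is a minimal illustrative sketch of that relationship, not a file from the repository; the `Z_BOUND` value is taken from the defaults in `fiery/config.py`.

```python
# Illustrative sketch (not a repository file): derive the BEV grid implied by
# LIFT bounds of the form [min, max, step] in metres, mirroring the logic of
# calculate_birds_eye_view_parameters in fiery/utils/geometry.py.
def bev_grid(x_bound, y_bound, z_bound):
    bounds = (x_bound, y_bound, z_bound)
    resolution = [b[2] for b in bounds]                      # cell size in metres
    start = [b[0] + b[2] / 2.0 for b in bounds]              # centre of the first cell
    dimension = [int((b[1] - b[0]) / b[2]) for b in bounds]  # number of cells per axis
    return resolution, start, dimension

# pon_setting.yml overrides X_BOUND and Y_BOUND; Z_BOUND keeps the fiery/config.py default.
print(bev_grid([-50.0, 50.0, 0.25], [-25.0, 25.0, 0.25], [-10.0, 10.0, 20.0]))
# ([0.25, 0.25, 20.0], [-49.875, -24.875, 0.0], [400, 200, 1])
# i.e. a 100m x 50m grid at 25cm resolution, matching the README table below.
```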
/DATASET.md: -------------------------------------------------------------------------------- 1 | ## 📖 Dataset 2 | 3 | - Go to the official website of the [NuScenes dataset](https://www.nuscenes.org/download). You will need to create an 4 | account. 5 | - Download the _Full dataset (v1.0)_. This includes the _Mini_ dataset (Metadata and sensor file blobs) and the 6 | _Trainval_ dataset (Metadata, File blobs part 1-10). 7 | - Extract the tar files in order to obtain the following folder structure. The `nuscenes` dataset folder will be 8 | designated as `${NUSCENES_DATAROOT}`. 9 | ``` 10 | nuscenes 11 | │ 12 | └───trainval 13 | │ maps 14 | │ samples 15 | │ sweeps 16 | │ v1.0-trainval 17 | │ 18 | └───mini 19 | maps 20 | samples 21 | sweeps 22 | v1.0-mini 23 | ``` 24 | - The full dataset is around 400GB. It is possible to reduce the dataset size to ~60GB by downloading only the 25 | keyframe blobs part 1-10 instead of all the file blobs part 1-10. The keyframe blobs contain the data we 26 | need (RGB images and 3D bounding boxes of dynamic objects at 2Hz). The remaining file blobs also include RGB 27 | images and LiDAR sweeps at a higher frequency, but they are not used during training. --------------------------------------------------------------------------------
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Wayve Technologies Limited 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
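Before training, it can help to confirm that the extracted NuScenes folders match the layout listed in `DATASET.md` above. The snippet below is a minimal, hypothetical sketch and not part of the repository; the hard-coded `./nuscenes` path simply mirrors the `DATASET.DATAROOT` default in `fiery/config.py`.

```python
import os

# Hypothetical helper (not a repository file): verify that ${NUSCENES_DATAROOT}
# contains the trainval/mini sub-folders described in DATASET.md.
def check_nuscenes_layout(dataroot):
    expected = {
        'trainval': ['maps', 'samples', 'sweeps', 'v1.0-trainval'],
        'mini': ['maps', 'samples', 'sweeps', 'v1.0-mini'],
    }
    missing = []
    for split, subdirs in expected.items():
        for subdir in subdirs:
            path = os.path.join(dataroot, split, subdir)
            if not os.path.isdir(path):
                missing.append(path)
    return missing


if __name__ == '__main__':
    missing = check_nuscenes_layout('./nuscenes')
    if missing:
        print('Missing folders:')
        print('\n'.join(missing))
    else:
        print('Folder layout matches DATASET.md.')
```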
-------------------------------------------------------------------------------- /fiery/utils/lyft_splits.py: -------------------------------------------------------------------------------- 1 | TRAIN_LYFT_INDICES = [1, 3, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 2 | 17, 18, 19, 20, 21, 23, 24, 27, 28, 29, 30, 31, 32, 3 | 33, 35, 36, 37, 39, 41, 43, 44, 45, 46, 47, 48, 49, 4 | 50, 51, 52, 53, 55, 56, 59, 60, 62, 63, 65, 68, 69, 5 | 70, 71, 72, 73, 74, 75, 76, 78, 79, 81, 82, 83, 84, 6 | 86, 87, 88, 89, 93, 95, 97, 98, 99, 103, 104, 107, 108, 7 | 109, 110, 111, 113, 114, 115, 116, 117, 118, 119, 121, 122, 124, 8 | 127, 128, 130, 131, 132, 134, 135, 136, 137, 138, 139, 143, 144, 9 | 146, 147, 148, 149, 150, 151, 152, 153, 154, 156, 157, 158, 159, 10 | 161, 162, 165, 166, 167, 171, 172, 173, 174, 175, 176, 177, 178, 11 | 179] 12 | 13 | VAL_LYFT_INDICES = [0, 2, 4, 13, 22, 25, 26, 34, 38, 40, 42, 54, 57, 14 | 58, 61, 64, 66, 67, 77, 80, 85, 90, 91, 92, 94, 96, 15 | 100, 101, 102, 105, 106, 112, 120, 123, 125, 126, 129, 133, 140, 16 | 141, 142, 145, 155, 160, 163, 164, 168, 169, 170] 17 | -------------------------------------------------------------------------------- /fiery/utils/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | def pack_sequence_dim(x): 6 | b, s = x.shape[:2] 7 | return x.view(b * s, *x.shape[2:]) 8 | 9 | 10 | def unpack_sequence_dim(x, b, s): 11 | return x.view(b, s, *x.shape[1:]) 12 | 13 | 14 | def preprocess_batch(batch, device, unsqueeze=False): 15 | for key, value in batch.items(): 16 | if key != 'sample_token': 17 | batch[key] = value.to(device) 18 | if unsqueeze: 19 | batch[key] = batch[key].unsqueeze(0) 20 | 21 | 22 | def set_module_grad(module, requires_grad=False): 23 | for p in module.parameters(): 24 | p.requires_grad = requires_grad 25 | 26 | 27 | def set_bn_momentum(model, momentum=0.1): 28 | for m in model.modules(): 29 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 30 | m.momentum = momentum 31 | 32 | 33 | class NormalizeInverse(torchvision.transforms.Normalize): 34 | # https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/8 35 | def __init__(self, mean, std): 36 | mean = torch.as_tensor(mean) 37 | std = torch.as_tensor(std) 38 | std_inv = 1 / (std + 1e-7) 39 | mean_inv = -mean * std_inv 40 | super().__init__(mean=mean_inv, std=std_inv) 41 | 42 | def __call__(self, tensor): 43 | return super().__call__(tensor.clone()) 44 | -------------------------------------------------------------------------------- /fiery/models/future_prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from fiery.layers.convolutions import Bottleneck 4 | from fiery.layers.temporal import SpatialGRU 5 | 6 | 7 | class FuturePrediction(torch.nn.Module): 8 | def __init__(self, in_channels, latent_dim, n_gru_blocks=3, n_res_layers=3): 9 | super().__init__() 10 | self.n_gru_blocks = n_gru_blocks 11 | 12 | # Convolutional recurrent model with z_t as an initial hidden state and inputs the sample 13 | # from the probabilistic model. 
The architecture of the model is: 14 | # [Spatial GRU - [Bottleneck] x n_res_layers] x n_gru_blocks 15 | self.spatial_grus = [] 16 | self.res_blocks = [] 17 | 18 | for i in range(self.n_gru_blocks): 19 | gru_in_channels = latent_dim if i == 0 else in_channels 20 | self.spatial_grus.append(SpatialGRU(gru_in_channels, in_channels)) 21 | self.res_blocks.append(torch.nn.Sequential(*[Bottleneck(in_channels) 22 | for _ in range(n_res_layers)])) 23 | 24 | self.spatial_grus = torch.nn.ModuleList(self.spatial_grus) 25 | self.res_blocks = torch.nn.ModuleList(self.res_blocks) 26 | 27 | def forward(self, x, hidden_state): 28 | # x has shape (b, n_future, c, h, w), hidden_state (b, c, h, w) 29 | for i in range(self.n_gru_blocks): 30 | x = self.spatial_grus[i](x, hidden_state, flow=None) 31 | b, n_future, c, h, w = x.shape 32 | 33 | x = self.res_blocks[i](x.view(b * n_future, c, h, w)) 34 | x = x.view(b, n_future, c, h, w) 35 | 36 | return x 37 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import socket 4 | import torch 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.plugins import DDPPlugin 7 | 8 | from fiery.config import get_parser, get_cfg 9 | from fiery.data import prepare_dataloaders 10 | from fiery.trainer import TrainingModule 11 | 12 | 13 | def main(): 14 | args = get_parser().parse_args() 15 | cfg = get_cfg(args) 16 | 17 | trainloader, valloader = prepare_dataloaders(cfg) 18 | model = TrainingModule(cfg.convert_to_dict()) 19 | 20 | if cfg.PRETRAINED.LOAD_WEIGHTS: 21 | # Load single-image instance segmentation model. 22 | pretrained_model_weights = torch.load( 23 | os.path.join(cfg.DATASET.DATAROOT, cfg.PRETRAINED.PATH), map_location='cpu' 24 | )['state_dict'] 25 | 26 | model.load_state_dict(pretrained_model_weights, strict=False) 27 | print(f'Loaded single-image model weights from {cfg.PRETRAINED.PATH}') 28 | 29 | save_dir = os.path.join( 30 | cfg.LOG_DIR, time.strftime('%d%B%Yat%H:%M:%S%Z') + '_' + socket.gethostname() + '_' + cfg.TAG 31 | ) 32 | tb_logger = pl.loggers.TensorBoardLogger(save_dir=save_dir) 33 | trainer = pl.Trainer( 34 | gpus=cfg.GPUS, 35 | accelerator='ddp', 36 | precision=cfg.PRECISION, 37 | sync_batchnorm=True, 38 | gradient_clip_val=cfg.GRAD_NORM_CLIP, 39 | max_epochs=cfg.EPOCHS, 40 | weights_summary='full', 41 | logger=tb_logger, 42 | log_every_n_steps=cfg.LOGGING_INTERVAL, 43 | plugins=DDPPlugin(find_unused_parameters=True), 44 | profiler='simple', 45 | ) 46 | trainer.fit(model, trainloader, valloader) 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /fiery/models/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from fiery.layers.convolutions import Bottleneck 5 | 6 | 7 | class DistributionModule(nn.Module): 8 | """ 9 | A convolutional net that parametrises a diagonal Gaussian distribution. 
10 | """ 11 | 12 | def __init__( 13 | self, in_channels, latent_dim, min_log_sigma, max_log_sigma): 14 | super().__init__() 15 | self.compress_dim = in_channels // 2 16 | self.latent_dim = latent_dim 17 | self.min_log_sigma = min_log_sigma 18 | self.max_log_sigma = max_log_sigma 19 | 20 | self.encoder = DistributionEncoder( 21 | in_channels, 22 | self.compress_dim, 23 | ) 24 | self.last_conv = nn.Sequential( 25 | nn.AdaptiveAvgPool2d(1), nn.Conv2d(self.compress_dim, out_channels=2 * self.latent_dim, kernel_size=1) 26 | ) 27 | 28 | def forward(self, s_t): 29 | b, s = s_t.shape[:2] 30 | assert s == 1 31 | encoding = self.encoder(s_t[:, 0]) 32 | 33 | mu_log_sigma = self.last_conv(encoding).view(b, 1, 2 * self.latent_dim) 34 | mu = mu_log_sigma[:, :, :self.latent_dim] 35 | log_sigma = mu_log_sigma[:, :, self.latent_dim:] 36 | 37 | # clip the log_sigma value for numerical stability 38 | log_sigma = torch.clamp(log_sigma, self.min_log_sigma, self.max_log_sigma) 39 | return mu, log_sigma 40 | 41 | 42 | class DistributionEncoder(nn.Module): 43 | """Encodes s_t or (s_t, y_{t+1}, ..., y_{t+H}). 44 | """ 45 | def __init__(self, in_channels, out_channels): 46 | super().__init__() 47 | 48 | self.model = nn.Sequential( 49 | Bottleneck(in_channels, out_channels=out_channels, downsample=True), 50 | Bottleneck(out_channels, out_channels=out_channels, downsample=True), 51 | Bottleneck(out_channels, out_channels=out_channels, downsample=True), 52 | Bottleneck(out_channels, out_channels=out_channels, downsample=True), 53 | ) 54 | 55 | def forward(self, s_t): 56 | return self.model(s_t) 57 | -------------------------------------------------------------------------------- /fiery/models/temporal_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from fiery.layers.temporal import Bottleneck3D, TemporalBlock 4 | 5 | 6 | class TemporalModel(nn.Module): 7 | def __init__( 8 | self, in_channels, receptive_field, input_shape, start_out_channels=64, extra_in_channels=0, 9 | n_spatial_layers_between_temporal_layers=0, use_pyramid_pooling=True): 10 | super().__init__() 11 | self.receptive_field = receptive_field 12 | n_temporal_layers = receptive_field - 1 13 | 14 | h, w = input_shape 15 | modules = [] 16 | 17 | block_in_channels = in_channels 18 | block_out_channels = start_out_channels 19 | 20 | for _ in range(n_temporal_layers): 21 | if use_pyramid_pooling: 22 | use_pyramid_pooling = True 23 | pool_sizes = [(2, h, w)] 24 | else: 25 | use_pyramid_pooling = False 26 | pool_sizes = None 27 | temporal = TemporalBlock( 28 | block_in_channels, 29 | block_out_channels, 30 | use_pyramid_pooling=use_pyramid_pooling, 31 | pool_sizes=pool_sizes, 32 | ) 33 | spatial = [ 34 | Bottleneck3D(block_out_channels, block_out_channels, kernel_size=(1, 3, 3)) 35 | for _ in range(n_spatial_layers_between_temporal_layers) 36 | ] 37 | temporal_spatial_layers = nn.Sequential(temporal, *spatial) 38 | modules.extend(temporal_spatial_layers) 39 | 40 | block_in_channels = block_out_channels 41 | block_out_channels += extra_in_channels 42 | 43 | self.out_channels = block_in_channels 44 | 45 | self.model = nn.Sequential(*modules) 46 | 47 | def forward(self, x): 48 | # Reshape input tensor to (batch, C, time, H, W) 49 | x = x.permute(0, 2, 1, 3, 4) 50 | x = self.model(x) 51 | x = x.permute(0, 2, 1, 3, 4).contiguous() 52 | return x[:, (self.receptive_field - 1):] 53 | 54 | 55 | class TemporalModelIdentity(nn.Module): 56 | def __init__(self, in_channels, receptive_field): 57 | 
super().__init__() 58 | self.receptive_field = receptive_field 59 | self.out_channels = in_channels 60 | 61 | def forward(self, x): 62 | return x[:, (self.receptive_field - 1):] 63 | -------------------------------------------------------------------------------- /fiery/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SpatialRegressionLoss(nn.Module): 7 | def __init__(self, norm, ignore_index=255, future_discount=1.0): 8 | super(SpatialRegressionLoss, self).__init__() 9 | self.norm = norm 10 | self.ignore_index = ignore_index 11 | self.future_discount = future_discount 12 | 13 | if norm == 1: 14 | self.loss_fn = F.l1_loss 15 | elif norm == 2: 16 | self.loss_fn = F.mse_loss 17 | else: 18 | raise ValueError(f'Expected norm 1 or 2, but got norm={norm}') 19 | 20 | def forward(self, prediction, target): 21 | assert len(prediction.shape) == 5, 'Must be a 5D tensor' 22 | # ignore_index is the same across all channels 23 | mask = target[:, :, :1] != self.ignore_index 24 | if mask.sum() == 0: 25 | return prediction.new_zeros(1)[0].float() 26 | 27 | loss = self.loss_fn(prediction, target, reduction='none') 28 | 29 | # Sum channel dimension 30 | loss = torch.sum(loss, dim=-3, keepdims=True) 31 | 32 | seq_len = loss.shape[1] 33 | future_discounts = self.future_discount ** torch.arange(seq_len, device=loss.device, dtype=loss.dtype) 34 | future_discounts = future_discounts.view(1, seq_len, 1, 1, 1) 35 | loss = loss * future_discounts 36 | 37 | return loss[mask].mean() 38 | 39 | 40 | class SegmentationLoss(nn.Module): 41 | def __init__(self, class_weights, ignore_index=255, use_top_k=False, top_k_ratio=1.0, future_discount=1.0): 42 | super().__init__() 43 | self.class_weights = class_weights 44 | self.ignore_index = ignore_index 45 | self.use_top_k = use_top_k 46 | self.top_k_ratio = top_k_ratio 47 | self.future_discount = future_discount 48 | 49 | def forward(self, prediction, target): 50 | if target.shape[-3] != 1: 51 | raise ValueError('segmentation label must be an index-label with channel dimension = 1.') 52 | b, s, c, h, w = prediction.shape 53 | 54 | prediction = prediction.view(b * s, c, h, w) 55 | target = target.view(b * s, h, w) 56 | loss = F.cross_entropy( 57 | prediction, 58 | target, 59 | ignore_index=self.ignore_index, 60 | reduction='none', 61 | weight=self.class_weights.to(target.device), 62 | ) 63 | 64 | loss = loss.view(b, s, h, w) 65 | 66 | future_discounts = self.future_discount ** torch.arange(s, device=loss.device, dtype=loss.dtype) 67 | future_discounts = future_discounts.view(1, s, 1, 1) 68 | loss = loss * future_discounts 69 | 70 | loss = loss.view(b, s, -1) 71 | if self.use_top_k: 72 | # Penalises the top-k hardest pixels 73 | k = int(self.top_k_ratio * loss.shape[2]) 74 | loss, _ = torch.sort(loss, dim=2, descending=True) 75 | loss = loss[:, :, :k] 76 | 77 | return torch.mean(loss) 78 | 79 | 80 | class ProbabilisticLoss(nn.Module): 81 | def forward(self, output): 82 | present_mu = output['present_mu'] 83 | present_log_sigma = output['present_log_sigma'] 84 | future_mu = output['future_mu'] 85 | future_log_sigma = output['future_log_sigma'] 86 | 87 | var_future = torch.exp(2 * future_log_sigma) 88 | var_present = torch.exp(2 * present_log_sigma) 89 | kl_div = ( 90 | present_log_sigma - future_log_sigma - 0.5 + (var_future + (future_mu - present_mu) ** 2) / ( 91 | 2 * var_present) 92 | ) 93 | 94 | kl_loss = torch.mean(torch.sum(kl_div, 
dim=-1)) 95 | 96 | return kl_loss 97 | -------------------------------------------------------------------------------- /fiery/models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision.models.resnet import resnet18 3 | 4 | from fiery.layers.convolutions import UpsamplingAdd 5 | 6 | 7 | class Decoder(nn.Module): 8 | def __init__(self, in_channels, n_classes, predict_future_flow): 9 | super().__init__() 10 | backbone = resnet18(pretrained=False, zero_init_residual=True) 11 | self.first_conv = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 12 | self.bn1 = backbone.bn1 13 | self.relu = backbone.relu 14 | 15 | self.layer1 = backbone.layer1 16 | self.layer2 = backbone.layer2 17 | self.layer3 = backbone.layer3 18 | self.predict_future_flow = predict_future_flow 19 | 20 | shared_out_channels = in_channels 21 | self.up3_skip = UpsamplingAdd(256, 128, scale_factor=2) 22 | self.up2_skip = UpsamplingAdd(128, 64, scale_factor=2) 23 | self.up1_skip = UpsamplingAdd(64, shared_out_channels, scale_factor=2) 24 | 25 | self.segmentation_head = nn.Sequential( 26 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False), 27 | nn.BatchNorm2d(shared_out_channels), 28 | nn.ReLU(inplace=True), 29 | nn.Conv2d(shared_out_channels, n_classes, kernel_size=1, padding=0), 30 | ) 31 | self.instance_offset_head = nn.Sequential( 32 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False), 33 | nn.BatchNorm2d(shared_out_channels), 34 | nn.ReLU(inplace=True), 35 | nn.Conv2d(shared_out_channels, 2, kernel_size=1, padding=0), 36 | ) 37 | self.instance_center_head = nn.Sequential( 38 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False), 39 | nn.BatchNorm2d(shared_out_channels), 40 | nn.ReLU(inplace=True), 41 | nn.Conv2d(shared_out_channels, 1, kernel_size=1, padding=0), 42 | nn.Sigmoid(), 43 | ) 44 | 45 | if self.predict_future_flow: 46 | self.instance_future_head = nn.Sequential( 47 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False), 48 | nn.BatchNorm2d(shared_out_channels), 49 | nn.ReLU(inplace=True), 50 | nn.Conv2d(shared_out_channels, 2, kernel_size=1, padding=0), 51 | ) 52 | 53 | def forward(self, x): 54 | b, s, c, h, w = x.shape 55 | x = x.view(b * s, c, h, w) 56 | 57 | # (H, W) 58 | skip_x = {'1': x} 59 | x = self.first_conv(x) 60 | x = self.bn1(x) 61 | x = self.relu(x) 62 | 63 | # (H/4, W/4) 64 | x = self.layer1(x) 65 | skip_x['2'] = x 66 | x = self.layer2(x) 67 | skip_x['3'] = x 68 | 69 | # (H/8, W/8) 70 | x = self.layer3(x) 71 | 72 | # First upsample to (H/4, W/4) 73 | x = self.up3_skip(x, skip_x['3']) 74 | 75 | # Second upsample to (H/2, W/2) 76 | x = self.up2_skip(x, skip_x['2']) 77 | 78 | # Third upsample to (H, W) 79 | x = self.up1_skip(x, skip_x['1']) 80 | 81 | segmentation_output = self.segmentation_head(x) 82 | instance_center_output = self.instance_center_head(x) 83 | instance_offset_output = self.instance_offset_head(x) 84 | instance_future_output = self.instance_future_head(x) if self.predict_future_flow else None 85 | return { 86 | 'segmentation': segmentation_output.view(b, s, *segmentation_output.shape[1:]), 87 | 'instance_center': instance_center_output.view(b, s, *instance_center_output.shape[1:]), 88 | 'instance_offset': instance_offset_output.view(b, s, *instance_offset_output.shape[1:]), 89 | 'instance_flow': instance_future_output.view(b, 
s, *instance_future_output.shape[1:]) 90 | if instance_future_output is not None else None, 91 | } 92 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from fiery.data import prepare_dataloaders 7 | from fiery.trainer import TrainingModule 8 | from fiery.metrics import IntersectionOverUnion, PanopticMetric 9 | from fiery.utils.network import preprocess_batch 10 | from fiery.utils.instance import predict_instance_segmentation_and_trajectories 11 | 12 | # 30mx30m, 100mx100m 13 | EVALUATION_RANGES = {'30x30': (70, 130), 14 | '100x100': (0, 200) 15 | } 16 | 17 | 18 | def eval(checkpoint_path, dataroot, version): 19 | trainer = TrainingModule.load_from_checkpoint(checkpoint_path, strict=True) 20 | print(f'Loaded weights from \n {checkpoint_path}') 21 | trainer.eval() 22 | 23 | device = torch.device('cuda:0') 24 | trainer.to(device) 25 | model = trainer.model 26 | 27 | cfg = model.cfg 28 | cfg.GPUS = "[0]" 29 | cfg.BATCHSIZE = 1 30 | 31 | cfg.DATASET.DATAROOT = dataroot 32 | cfg.DATASET.VERSION = version 33 | 34 | _, valloader = prepare_dataloaders(cfg) 35 | 36 | panoptic_metrics = {} 37 | iou_metrics = {} 38 | n_classes = len(cfg.SEMANTIC_SEG.WEIGHTS) 39 | for key in EVALUATION_RANGES.keys(): 40 | panoptic_metrics[key] = PanopticMetric(n_classes=n_classes, temporally_consistent=True).to( 41 | device) 42 | iou_metrics[key] = IntersectionOverUnion(n_classes).to(device) 43 | 44 | for i, batch in enumerate(tqdm(valloader)): 45 | preprocess_batch(batch, device) 46 | image = batch['image'] 47 | intrinsics = batch['intrinsics'] 48 | extrinsics = batch['extrinsics'] 49 | future_egomotion = batch['future_egomotion'] 50 | 51 | batch_size = image.shape[0] 52 | 53 | labels, future_distribution_inputs = trainer.prepare_future_labels(batch) 54 | 55 | with torch.no_grad(): 56 | # Evaluate with mean prediction 57 | noise = torch.zeros((batch_size, 1, model.latent_dim), device=device) 58 | output = model(image, intrinsics, extrinsics, future_egomotion, 59 | future_distribution_inputs, noise=noise) 60 | 61 | # Consistent instance seg 62 | pred_consistent_instance_seg = predict_instance_segmentation_and_trajectories( 63 | output, compute_matched_centers=False, make_consistent=True 64 | ) 65 | 66 | segmentation_pred = output['segmentation'].detach() 67 | segmentation_pred = torch.argmax(segmentation_pred, dim=2, keepdims=True) 68 | 69 | for key, grid in EVALUATION_RANGES.items(): 70 | limits = slice(grid[0], grid[1]) 71 | panoptic_metrics[key](pred_consistent_instance_seg[..., limits, limits].contiguous().detach(), 72 | labels['instance'][..., limits, limits].contiguous() 73 | ) 74 | 75 | iou_metrics[key](segmentation_pred[..., limits, limits].contiguous(), 76 | labels['segmentation'][..., limits, limits].contiguous() 77 | ) 78 | 79 | results = {} 80 | for key, grid in EVALUATION_RANGES.items(): 81 | panoptic_scores = panoptic_metrics[key].compute() 82 | for panoptic_key, value in panoptic_scores.items(): 83 | results[f'{panoptic_key}'] = results.get(f'{panoptic_key}', []) + [100 * value[1].item()] 84 | 85 | iou_scores = iou_metrics[key].compute() 86 | results['iou'] = results.get('iou', []) + [100 * iou_scores[1].item()] 87 | 88 | for panoptic_key in ['iou', 'pq', 'sq', 'rq']: 89 | print(panoptic_key) 90 | print(' & '.join([f'{x:.1f}' for x in results[panoptic_key]])) 91 | 92 | 93 | if __name__ 
== '__main__': 94 | parser = ArgumentParser(description='Fiery evaluation') 95 | parser.add_argument('--checkpoint', default='./fiery.ckpt', type=str, help='path to checkpoint') 96 | parser.add_argument('--dataroot', default='./nuscenes', type=str, help='path to the dataset') 97 | parser.add_argument('--version', default='trainval', type=str, choices=['mini', 'trainval'], 98 | help='dataset version') 99 | 100 | args = parser.parse_args() 101 | 102 | eval(args.checkpoint, args.dataroot, args.version) 103 | -------------------------------------------------------------------------------- /fiery/models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from efficientnet_pytorch import EfficientNet 3 | 4 | from fiery.layers.convolutions import UpsamplingConcat 5 | 6 | 7 | class Encoder(nn.Module): 8 | def __init__(self, cfg, D): 9 | super().__init__() 10 | self.D = D 11 | self.C = cfg.OUT_CHANNELS 12 | self.use_depth_distribution = cfg.USE_DEPTH_DISTRIBUTION 13 | self.downsample = cfg.DOWNSAMPLE 14 | self.version = cfg.NAME.split('-')[1] 15 | 16 | self.backbone = EfficientNet.from_pretrained(cfg.NAME) 17 | self.delete_unused_layers() 18 | 19 | if self.downsample == 16: 20 | if self.version == 'b0': 21 | upsampling_in_channels = 320 + 112 22 | elif self.version == 'b4': 23 | upsampling_in_channels = 448 + 160 24 | upsampling_out_channels = 512 25 | elif self.downsample == 8: 26 | if self.version == 'b0': 27 | upsampling_in_channels = 112 + 40 28 | elif self.version == 'b4': 29 | upsampling_in_channels = 160 + 56 30 | upsampling_out_channels = 128 31 | else: 32 | raise ValueError(f'Downsample factor {self.downsample} not handled.') 33 | 34 | self.upsampling_layer = UpsamplingConcat(upsampling_in_channels, upsampling_out_channels) 35 | if self.use_depth_distribution: 36 | self.depth_layer = nn.Conv2d(upsampling_out_channels, self.C + self.D, kernel_size=1, padding=0) 37 | else: 38 | self.depth_layer = nn.Conv2d(upsampling_out_channels, self.C, kernel_size=1, padding=0) 39 | 40 | def delete_unused_layers(self): 41 | indices_to_delete = [] 42 | for idx in range(len(self.backbone._blocks)): 43 | if self.downsample == 8: 44 | if self.version == 'b0' and idx > 10: 45 | indices_to_delete.append(idx) 46 | if self.version == 'b4' and idx > 21: 47 | indices_to_delete.append(idx) 48 | 49 | for idx in reversed(indices_to_delete): 50 | del self.backbone._blocks[idx] 51 | 52 | del self.backbone._conv_head 53 | del self.backbone._bn1 54 | del self.backbone._avg_pooling 55 | del self.backbone._dropout 56 | del self.backbone._fc 57 | 58 | def get_features(self, x): 59 | # Adapted from https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py#L231 60 | endpoints = dict() 61 | 62 | # Stem 63 | x = self.backbone._swish(self.backbone._bn0(self.backbone._conv_stem(x))) 64 | prev_x = x 65 | 66 | # Blocks 67 | for idx, block in enumerate(self.backbone._blocks): 68 | drop_connect_rate = self.backbone._global_params.drop_connect_rate 69 | if drop_connect_rate: 70 | drop_connect_rate *= float(idx) / len(self.backbone._blocks) 71 | x = block(x, drop_connect_rate=drop_connect_rate) 72 | if prev_x.size(2) > x.size(2): 73 | endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x 74 | prev_x = x 75 | 76 | if self.downsample == 8: 77 | if self.version == 'b0' and idx == 10: 78 | break 79 | if self.version == 'b4' and idx == 21: 80 | break 81 | 82 | # Head 83 | endpoints['reduction_{}'.format(len(endpoints) + 1)] = x 84 
| 85 | if self.downsample == 16: 86 | input_1, input_2 = endpoints['reduction_5'], endpoints['reduction_4'] 87 | elif self.downsample == 8: 88 | input_1, input_2 = endpoints['reduction_4'], endpoints['reduction_3'] 89 | 90 | x = self.upsampling_layer(input_1, input_2) 91 | return x 92 | 93 | def forward(self, x): 94 | x = self.get_features(x) # get feature vector 95 | 96 | x = self.depth_layer(x) # feature and depth head 97 | 98 | if self.use_depth_distribution: 99 | depth = x[:, : self.D].softmax(dim=1) 100 | x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2) # outer product depth and features 101 | else: 102 | x = x.unsqueeze(2).repeat(1, 1, self.D, 1, 1) 103 | 104 | return x 105 | -------------------------------------------------------------------------------- /fiery/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from fvcore.common.config import CfgNode as _CfgNode 3 | 4 | 5 | def convert_to_dict(cfg_node, key_list=[]): 6 | """Convert a config node to dictionary.""" 7 | _VALID_TYPES = {tuple, list, str, int, float, bool} 8 | if not isinstance(cfg_node, _CfgNode): 9 | if type(cfg_node) not in _VALID_TYPES: 10 | print( 11 | 'Key {} with value {} is not a valid type; valid types: {}'.format( 12 | '.'.join(key_list), type(cfg_node), _VALID_TYPES 13 | ), 14 | ) 15 | return cfg_node 16 | else: 17 | cfg_dict = dict(cfg_node) 18 | for k, v in cfg_dict.items(): 19 | cfg_dict[k] = convert_to_dict(v, key_list + [k]) 20 | return cfg_dict 21 | 22 | 23 | class CfgNode(_CfgNode): 24 | """Remove once https://github.com/rbgirshick/yacs/issues/19 is merged.""" 25 | 26 | def convert_to_dict(self): 27 | return convert_to_dict(self) 28 | 29 | 30 | CN = CfgNode 31 | 32 | _C = CN() 33 | _C.LOG_DIR = 'tensorboard_logs' 34 | _C.TAG = 'default' 35 | 36 | _C.GPUS = [0] # gpus to use 37 | _C.PRECISION = 32 # 16bit or 32bit 38 | _C.BATCHSIZE = 3 39 | _C.EPOCHS = 20 40 | 41 | _C.N_WORKERS = 5 42 | _C.VIS_INTERVAL = 5000 43 | _C.LOGGING_INTERVAL = 500 44 | 45 | _C.PRETRAINED = CN() 46 | _C.PRETRAINED.LOAD_WEIGHTS = False 47 | _C.PRETRAINED.PATH = '' 48 | 49 | _C.DATASET = CN() 50 | _C.DATASET.DATAROOT = './nuscenes/' 51 | _C.DATASET.VERSION = 'trainval' 52 | _C.DATASET.NAME = 'nuscenes' 53 | _C.DATASET.IGNORE_INDEX = 255 # Ignore index when creating flow/offset labels 54 | _C.DATASET.FILTER_INVISIBLE_VEHICLES = True # Filter vehicles that are not visible from the cameras 55 | 56 | _C.TIME_RECEPTIVE_FIELD = 3 # how many frames of temporal context (1 for single timeframe) 57 | _C.N_FUTURE_FRAMES = 4 # how many time steps into the future to predict 58 | 59 | _C.IMAGE = CN() 60 | _C.IMAGE.FINAL_DIM = (224, 480) 61 | _C.IMAGE.RESIZE_SCALE = 0.3 62 | _C.IMAGE.TOP_CROP = 46 63 | _C.IMAGE.ORIGINAL_HEIGHT = 900 # Original input RGB camera height 64 | _C.IMAGE.ORIGINAL_WIDTH = 1600 # Original input RGB camera width 65 | _C.IMAGE.NAMES = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT'] 66 | 67 | _C.LIFT = CN() # image to BEV lifting 68 | _C.LIFT.X_BOUND = [-50.0, 50.0, 0.5] # Forward 69 | _C.LIFT.Y_BOUND = [-50.0, 50.0, 0.5] # Sides 70 | _C.LIFT.Z_BOUND = [-10.0, 10.0, 20.0] # Height 71 | _C.LIFT.D_BOUND = [2.0, 50.0, 1.0] 72 | 73 | _C.MODEL = CN() 74 | 75 | _C.MODEL.ENCODER = CN() 76 | _C.MODEL.ENCODER.DOWNSAMPLE = 8 77 | _C.MODEL.ENCODER.NAME = 'efficientnet-b4' 78 | _C.MODEL.ENCODER.OUT_CHANNELS = 64 79 | _C.MODEL.ENCODER.USE_DEPTH_DISTRIBUTION = True 80 | 81 | _C.MODEL.TEMPORAL_MODEL = 
CN() 82 | _C.MODEL.TEMPORAL_MODEL.NAME = 'temporal_block' # type of temporal model 83 | _C.MODEL.TEMPORAL_MODEL.START_OUT_CHANNELS = 64 84 | _C.MODEL.TEMPORAL_MODEL.EXTRA_IN_CHANNELS = 0 85 | _C.MODEL.TEMPORAL_MODEL.INBETWEEN_LAYERS = 0 86 | _C.MODEL.TEMPORAL_MODEL.PYRAMID_POOLING = True 87 | _C.MODEL.TEMPORAL_MODEL.INPUT_EGOPOSE = True 88 | 89 | _C.MODEL.DISTRIBUTION = CN() 90 | _C.MODEL.DISTRIBUTION.LATENT_DIM = 32 91 | _C.MODEL.DISTRIBUTION.MIN_LOG_SIGMA = -5.0 92 | _C.MODEL.DISTRIBUTION.MAX_LOG_SIGMA = 5.0 93 | 94 | _C.MODEL.FUTURE_PRED = CN() 95 | _C.MODEL.FUTURE_PRED.N_GRU_BLOCKS = 3 96 | _C.MODEL.FUTURE_PRED.N_RES_LAYERS = 3 97 | 98 | _C.MODEL.DECODER = CN() 99 | 100 | _C.MODEL.BN_MOMENTUM = 0.1 101 | _C.MODEL.SUBSAMPLE = False # Subsample frames for Lyft 102 | 103 | _C.SEMANTIC_SEG = CN() 104 | _C.SEMANTIC_SEG.WEIGHTS = [1.0, 2.0] # per class cross entropy weights (bg, dynamic) 105 | _C.SEMANTIC_SEG.USE_TOP_K = True # backprop only top-k hardest pixels 106 | _C.SEMANTIC_SEG.TOP_K_RATIO = 0.25 107 | 108 | _C.INSTANCE_SEG = CN() 109 | 110 | _C.INSTANCE_FLOW = CN() 111 | _C.INSTANCE_FLOW.ENABLED = True 112 | 113 | _C.PROBABILISTIC = CN() 114 | _C.PROBABILISTIC.ENABLED = True # learn a distribution over futures 115 | _C.PROBABILISTIC.WEIGHT = 100.0 116 | _C.PROBABILISTIC.FUTURE_DIM = 6 # number of dimension added (future flow, future centerness, offset, seg) 117 | 118 | _C.FUTURE_DISCOUNT = 0.95 119 | 120 | _C.OPTIMIZER = CN() 121 | _C.OPTIMIZER.LR = 3e-4 122 | _C.OPTIMIZER.WEIGHT_DECAY = 1e-7 123 | _C.GRAD_NORM_CLIP = 5 124 | 125 | 126 | def get_parser(): 127 | parser = argparse.ArgumentParser(description='Fiery training') 128 | # TODO: remove below? 129 | parser.add_argument('--config-file', default='', metavar='FILE', help='path to config file') 130 | parser.add_argument( 131 | 'opts', help='Modify config options using the command-line', default=None, nargs=argparse.REMAINDER, 132 | ) 133 | return parser 134 | 135 | 136 | def get_cfg(args=None, cfg_dict=None): 137 | """ First get default config. Then merge cfg_dict. Then merge according to args. """ 138 | 139 | cfg = _C.clone() 140 | 141 | if cfg_dict is not None: 142 | cfg.merge_from_other_cfg(CfgNode(cfg_dict)) 143 | 144 | if args is not None: 145 | if args.config_file: 146 | cfg.merge_from_file(args.config_file) 147 | cfg.merge_from_list(args.opts) 148 | cfg.freeze() 149 | return cfg 150 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FIERY 2 | This is the PyTorch implementation for inference and training of the future prediction bird's-eye view network as 3 | described in: 4 | 5 | > **FIERY: Future Instance Segmentation in Bird's-Eye view from Surround Monocular Cameras** 6 | > 7 | > [Anthony Hu](https://anthonyhu.github.io/), [Zak Murez](http://zak.murez.com/), 8 | [Nikhil Mohan](https://uk.linkedin.com/in/nikhilmohan33), 9 | [Sofía Dudas](https://uk.linkedin.com/in/sof%C3%ADa-josefina-lago-dudas-2b0737132), 10 | [Jeffrey Hawke](https://uk.linkedin.com/in/jeffrey-hawke), 11 | [‪Vijay Badrinarayanan](https://sites.google.com/site/vijaybacademichomepage/home), 12 | [Roberto Cipolla](https://mi.eng.cam.ac.uk/~cipolla/index.htm) and [Alex Kendall](https://alexgkendall.com/) 13 | > 14 | > [ICCV 2021 (Oral)](https://arxiv.org/abs/2104.10490)
15 | > [Blog post](https://wayve.ai/blog/fiery-future-instance-prediction-birds-eye-view) 16 | 17 |

18 | *Figure: FIERY future prediction.* 19 |
20 | Multimodal future predictions by our bird’s-eye view network.
21 | Top two rows: RGB camera inputs. The predicted future trajectories and segmentations are projected to the ground plane in the images.
22 | Bottom row: future instance prediction in bird’s-eye view in a 100m×100m capture size around the ego-vehicle, which is indicated by a black rectangle in the center. 23 |
24 |

25 | 26 | If you find our work useful, please consider citing: 27 | ```bibtex 28 | @inproceedings{fiery2021, 29 | title = {{FIERY}: Future Instance Segmentation in Bird's-Eye view from Surround Monocular Cameras}, 30 | author = {Anthony Hu and Zak Murez and Nikhil Mohan and Sofía Dudas and 31 | Jeffrey Hawke and Vijay Badrinarayanan and Roberto Cipolla and Alex Kendall}, 32 | booktitle = {Proceedings of the International Conference on Computer Vision ({ICCV})}, 33 | year = {2021} 34 | } 35 | ``` 36 | 37 | ## ⚙ Setup 38 | - Create the [conda](https://docs.conda.io/en/latest/miniconda.html) environment by running `conda env create`. 39 | 40 | ## 🏄 Prediction 41 | ### Visualisation 42 | 43 | In a colab notebook: 44 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12ahc3whI1RQZIVDi53grMWHzdA7WqIuo?usp=sharing) 45 | 46 | Or locally: 47 | - Download [pre-trained weights](https://github.com/wayveai/fiery/releases/download/v1.0/fiery.ckpt). 48 | - Run `python visualise.py --checkpoint ${CHECKPOINT_PATH}`. This will render predictions from the network and save 49 | them to an `output_vis` folder. 50 | 51 | ### Evaluation 52 | - Download the [NuScenes dataset](https://www.nuscenes.org/download). For detailed instructions, see [DATASET.md](DATASET.md). 53 | - Download [pre-trained weights](https://github.com/wayveai/fiery/releases/download/v1.0/fiery.ckpt). 54 | - Run `python evaluate.py --checkpoint ${CHECKPOINT_PATH} --dataroot ${NUSCENES_DATAROOT}`. 55 | 56 | ## 🔥 Pre-trained models 57 | 58 | All the configs are in the folder `fiery/configs` 59 | 60 | | Config and weights | Dataset | Past context | Future horizon | BEV size | IoU | VPQ| 61 | |--------------|---------|-----------------------|----------------|----------|------|----| 62 | | [`baseline.yml`](https://github.com/wayveai/fiery/releases/download/v1.0/fiery.ckpt) | NuScenes | 1.0s | 2.0s | 100mx100m (50cm res.) | 36.7 | 29.9 | 63 | | [`lyft/baseline.yml`](https://github.com/wayveai/fiery/releases/download/v1.0/lyft_fiery.ckpt) | Lyft | 0.8s | 2.0s| 100mx100m (50cm res.) | 36.3 | 29.2 | 64 | | [`literature/static_pon_setting.yml`](https://github.com/wayveai/fiery/releases/download/v1.0/static_pon_setting.ckpt) | NuScenes| 0.0s | 0.0s | 100mx50m (25cm res.) | 37.7| - | 65 | | [`literature/pon_setting.yml`](https://github.com/wayveai/fiery/releases/download/v1.0/pon_setting.ckpt) | NuScenes| 1.0s | 0.0s | 100mx50m (25cm res.) |39.9 | - | 66 | | [`literature/static_lss_setting.yml`](https://github.com/wayveai/fiery/releases/download/v1.0/static_lift_splat_setting.ckpt) | NuScenes | 0.0s | 0.0s | 100mx100m (50cm res.) | 35.8 | - | 67 | | [`literature/lift_splat_setting.yml`](https://github.com/wayveai/fiery/releases/download/v1.0/lift_splat_setting.ckpt) | NuScenes | 1.0s | 0.0s | 100mx100m (50cm res.) | 38.2 | - | 68 | | [`literature/fishing_setting.yml`](https://github.com/wayveai/fiery/releases/download/v1.0/fishing_setting.ckpt) | NuScenes | 1.0s | 2.0s | 32.0mx19.2m (10cm res.) | 57.6 | - | 69 | 70 | 71 | ## 🏊 Training 72 | To train the model from scratch on NuScenes: 73 | - Download the [NuScenes dataset](https://www.nuscenes.org/download). For detailed instructions, see [DATASET.md](DATASET.md). 74 | - Run `python train.py --config fiery/configs/baseline.yml DATASET.DATAROOT ${NUSCENES_DATAROOT}`. 75 | 76 | This will train the model on 4 GPUs, each with a batch of size 3. 
To train on a single GPU add the flag `GPUS [0]`, and to change the batch 77 | size use the flag `BATCHSIZE ${DESIRED_BATCHSIZE}`. 78 | 79 | ## 🙌 Credits 80 | Big thanks to Giulio D'Ippolito ([@gdippolito](https://github.com/gdippolito)) for the technical help on the GPU 81 | servers, Piotr Sokólski ([@pyetras](https://github.com/pyetras)) for the panoptic metric implementation, and to Hannes Liik ([@hannesliik](https://github.com/hannesliik)) 82 | for the awesome future trajectory visualisation on the ground plane. 83 | --------------------------------------------------------------------------------
/visualise.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | from glob import glob 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torchvision 9 | import matplotlib as mpl 10 | import matplotlib.pyplot as plt 11 | from PIL import Image 12 | 13 | from fiery.trainer import TrainingModule 14 | from fiery.utils.network import NormalizeInverse 15 | from fiery.utils.instance import predict_instance_segmentation_and_trajectories 16 | from fiery.utils.visualisation import plot_instance_map, generate_instance_colours, make_contour, convert_figure_numpy 17 | 18 | EXAMPLE_DATA_PATH = 'example_data' 19 | 20 | 21 | def plot_prediction(image, output, cfg): 22 | # Process predictions 23 | consistent_instance_seg, matched_centers = predict_instance_segmentation_and_trajectories( 24 | output, compute_matched_centers=True 25 | ) 26 |
27 | # Plot future trajectories 28 | unique_ids = torch.unique(consistent_instance_seg[0, 0]).cpu().long().numpy()[1:] 29 | instance_map = dict(zip(unique_ids, unique_ids)) 30 | instance_colours = generate_instance_colours(instance_map) 31 | vis_image = plot_instance_map(consistent_instance_seg[0, 0].cpu().numpy(), instance_map) 32 | trajectory_img = np.zeros(vis_image.shape, dtype=np.uint8) 33 | for instance_id in unique_ids: 34 | path = matched_centers[instance_id] 35 | for t in range(len(path) - 1): 36 | color = instance_colours[instance_id].tolist() 37 | cv2.line(trajectory_img, tuple(path[t]), tuple(path[t + 1]), 38 | color, 4) 39 | 40 | # Overlay arrows 41 | temp_img = cv2.addWeighted(vis_image, 0.7, trajectory_img, 0.3, 1.0) 42 | mask = ~ np.all(trajectory_img == 0, axis=2) 43 | vis_image[mask] = temp_img[mask] 44 |
45 | # Plot present RGB frames and predictions 46 | val_w = 2.99 47 | cameras = cfg.IMAGE.NAMES 48 | image_ratio = cfg.IMAGE.FINAL_DIM[0] / cfg.IMAGE.FINAL_DIM[1] 49 | val_h = val_w * image_ratio 50 | fig = plt.figure(figsize=(4 * val_w, 2 * val_h)) 51 | width_ratios = (val_w, val_w, val_w, val_w) 52 | gs = mpl.gridspec.GridSpec(2, 4, width_ratios=width_ratios) 53 | gs.update(wspace=0.0, hspace=0.0, left=0.0, right=1.0, top=1.0, bottom=0.0) 54 | 55 | denormalise_img = torchvision.transforms.Compose( 56 | (NormalizeInverse(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 57 | torchvision.transforms.ToPILImage(),) 58 | ) 59 | for imgi, img in enumerate(image[0, -1]): 60 | ax = plt.subplot(gs[imgi // 3, imgi % 3]) 61 | showimg = denormalise_img(img.cpu()) 62 | if imgi > 2: 63 | showimg = showimg.transpose(Image.FLIP_LEFT_RIGHT) 64 | 65 | plt.annotate(cameras[imgi].replace('_', ' ').replace('CAM ', ''), (0.01, 0.87), c='white', 66 | xycoords='axes fraction', fontsize=14) 67 | plt.imshow(showimg) 68 | plt.axis('off') 69 | 70 | ax = plt.subplot(gs[:, 3]) 71 | plt.imshow(make_contour(vis_image[::-1, ::-1])) 72 | plt.axis('off') 73 | 74 | plt.draw() 75 |
figure_numpy = convert_figure_numpy(fig) 76 | plt.close() 77 | return figure_numpy 78 | 79 | 80 | def download_example_data(): 81 | from requests import get 82 | 83 | def download(url, file_name): 84 | # open in binary mode 85 | with open(file_name, "wb") as file: 86 | # get request 87 | response = get(url) 88 | # write to file 89 | file.write(response.content) 90 | 91 | os.makedirs(EXAMPLE_DATA_PATH, exist_ok=True) 92 | url_list = ['https://github.com/wayveai/fiery/releases/download/v1.0/example_1.npz', 93 | 'https://github.com/wayveai/fiery/releases/download/v1.0/example_2.npz', 94 | 'https://github.com/wayveai/fiery/releases/download/v1.0/example_3.npz', 95 | 'https://github.com/wayveai/fiery/releases/download/v1.0/example_4.npz' 96 | ] 97 | for url in url_list: 98 | download(url, os.path.join(EXAMPLE_DATA_PATH, os.path.basename(url))) 99 | 100 | 101 | def visualise(checkpoint_path): 102 | trainer = TrainingModule.load_from_checkpoint(checkpoint_path, strict=True) 103 | 104 | device = torch.device('cuda:0') 105 | trainer = trainer.to(device) 106 | trainer.eval() 107 | 108 | # Download example data 109 | download_example_data() 110 | # Load data 111 | for data_path in sorted(glob(os.path.join(EXAMPLE_DATA_PATH, '*.npz'))): 112 | data = np.load(data_path) 113 | image = torch.from_numpy(data['image']).to(device) 114 | intrinsics = torch.from_numpy(data['intrinsics']).to(device) 115 | extrinsics = torch.from_numpy(data['extrinsics']).to(device) 116 | future_egomotions = torch.from_numpy(data['future_egomotion']).to(device) 117 | 118 | # Forward pass 119 | with torch.no_grad(): 120 | output = trainer.model(image, intrinsics, extrinsics, future_egomotions) 121 | 122 | figure_numpy = plot_prediction(image, output, trainer.cfg) 123 | os.makedirs('./output_vis', exist_ok=True) 124 | output_filename = os.path.join('./output_vis', os.path.basename(data_path).split('.')[0]) + '.png' 125 | Image.fromarray(figure_numpy).save(output_filename) 126 | print(f'Saved output in {output_filename}') 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = ArgumentParser(description='Fiery visualisation') 131 | parser.add_argument('--checkpoint', default='./fiery.ckpt', type=str, help='path to checkpoint') 132 | 133 | args = parser.parse_args() 134 | 135 | visualise(args.checkpoint) 136 | -------------------------------------------------------------------------------- /fiery/layers/convolutions.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from functools import partial 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class ConvBlock(nn.Module): 10 | """2D convolution followed by 11 | - an optional normalisation (batch norm or instance norm) 12 | - an optional activation (ReLU, LeakyReLU, or tanh) 13 | """ 14 | 15 | def __init__( 16 | self, 17 | in_channels, 18 | out_channels=None, 19 | kernel_size=3, 20 | stride=1, 21 | norm='bn', 22 | activation='relu', 23 | bias=False, 24 | transpose=False, 25 | ): 26 | super().__init__() 27 | out_channels = out_channels or in_channels 28 | padding = int((kernel_size - 1) / 2) 29 | self.conv = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d, output_padding=1) 30 | self.conv = self.conv(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias) 31 | 32 | if norm == 'bn': 33 | self.norm = nn.BatchNorm2d(out_channels) 34 | elif norm == 'in': 35 | self.norm = nn.InstanceNorm2d(out_channels) 36 | elif norm == 'none': 37 | 
self.norm = None 38 | else: 39 | raise ValueError('Invalid norm {}'.format(norm)) 40 | 41 | if activation == 'relu': 42 | self.activation = nn.ReLU(inplace=True) 43 | elif activation == 'lrelu': 44 | self.activation = nn.LeakyReLU(0.1, inplace=True) 45 | elif activation == 'elu': 46 | self.activation = nn.ELU(inplace=True) 47 | elif activation == 'tanh': 48 | self.activation = nn.Tanh(inplace=True) 49 | elif activation == 'none': 50 | self.activation = None 51 | else: 52 | raise ValueError('Invalid activation {}'.format(activation)) 53 | 54 | def forward(self, x): 55 | x = self.conv(x) 56 | 57 | if self.norm: 58 | x = self.norm(x) 59 | if self.activation: 60 | x = self.activation(x) 61 | return x 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | """ 66 | Defines a bottleneck module with a residual connection 67 | """ 68 | 69 | def __init__( 70 | self, 71 | in_channels, 72 | out_channels=None, 73 | kernel_size=3, 74 | dilation=1, 75 | groups=1, 76 | upsample=False, 77 | downsample=False, 78 | dropout=0.0, 79 | ): 80 | super().__init__() 81 | self._downsample = downsample 82 | bottleneck_channels = int(in_channels / 2) 83 | out_channels = out_channels or in_channels 84 | padding_size = ((kernel_size - 1) * dilation + 1) // 2 85 | 86 | # Define the main conv operation 87 | assert dilation == 1 88 | if upsample: 89 | assert not downsample, 'downsample and upsample not possible simultaneously.' 90 | bottleneck_conv = nn.ConvTranspose2d( 91 | bottleneck_channels, 92 | bottleneck_channels, 93 | kernel_size=kernel_size, 94 | bias=False, 95 | dilation=1, 96 | stride=2, 97 | output_padding=padding_size, 98 | padding=padding_size, 99 | groups=groups, 100 | ) 101 | elif downsample: 102 | bottleneck_conv = nn.Conv2d( 103 | bottleneck_channels, 104 | bottleneck_channels, 105 | kernel_size=kernel_size, 106 | bias=False, 107 | dilation=dilation, 108 | stride=2, 109 | padding=padding_size, 110 | groups=groups, 111 | ) 112 | else: 113 | bottleneck_conv = nn.Conv2d( 114 | bottleneck_channels, 115 | bottleneck_channels, 116 | kernel_size=kernel_size, 117 | bias=False, 118 | dilation=dilation, 119 | padding=padding_size, 120 | groups=groups, 121 | ) 122 | 123 | self.layers = nn.Sequential( 124 | OrderedDict( 125 | [ 126 | # First projection with 1x1 kernel 127 | ('conv_down_project', nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False)), 128 | ('abn_down_project', nn.Sequential(nn.BatchNorm2d(bottleneck_channels), 129 | nn.ReLU(inplace=True))), 130 | # Second conv block 131 | ('conv', bottleneck_conv), 132 | ('abn', nn.Sequential(nn.BatchNorm2d(bottleneck_channels), nn.ReLU(inplace=True))), 133 | # Final projection with 1x1 kernel 134 | ('conv_up_project', nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, bias=False)), 135 | ('abn_up_project', nn.Sequential(nn.BatchNorm2d(out_channels), 136 | nn.ReLU(inplace=True))), 137 | # Regulariser 138 | ('dropout', nn.Dropout2d(p=dropout)), 139 | ] 140 | ) 141 | ) 142 | 143 | if out_channels == in_channels and not downsample and not upsample: 144 | self.projection = None 145 | else: 146 | projection = OrderedDict() 147 | if upsample: 148 | projection.update({'upsample_skip_proj': Interpolate(scale_factor=2)}) 149 | elif downsample: 150 | projection.update({'upsample_skip_proj': nn.MaxPool2d(kernel_size=2, stride=2)}) 151 | projection.update( 152 | { 153 | 'conv_skip_proj': nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 154 | 'bn_skip_proj': nn.BatchNorm2d(out_channels), 155 | } 156 | ) 157 | self.projection = 
nn.Sequential(projection) 158 | 159 | # pylint: disable=arguments-differ 160 | def forward(self, *args): 161 | (x,) = args 162 | x_residual = self.layers(x) 163 | if self.projection is not None: 164 | if self._downsample: 165 | # pad h/w dimensions if they are odd to prevent shape mismatch with residual layer 166 | x = nn.functional.pad(x, (0, x.shape[-1] % 2, 0, x.shape[-2] % 2), value=0) 167 | return x_residual + self.projection(x) 168 | return x_residual + x 169 | 170 | 171 | class Interpolate(nn.Module): 172 | def __init__(self, scale_factor: int = 2): 173 | super().__init__() 174 | self._interpolate = nn.functional.interpolate 175 | self._scale_factor = scale_factor 176 | 177 | # pylint: disable=arguments-differ 178 | def forward(self, x): 179 | return self._interpolate(x, scale_factor=self._scale_factor, mode='bilinear', align_corners=False) 180 | 181 | 182 | class UpsamplingConcat(nn.Module): 183 | def __init__(self, in_channels, out_channels, scale_factor=2): 184 | super().__init__() 185 | 186 | self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False) 187 | 188 | self.conv = nn.Sequential( 189 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), 190 | nn.BatchNorm2d(out_channels), 191 | nn.ReLU(inplace=True), 192 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), 193 | nn.BatchNorm2d(out_channels), 194 | nn.ReLU(inplace=True), 195 | ) 196 | 197 | def forward(self, x_to_upsample, x): 198 | x_to_upsample = self.upsample(x_to_upsample) 199 | x_to_upsample = torch.cat([x, x_to_upsample], dim=1) 200 | return self.conv(x_to_upsample) 201 | 202 | 203 | class UpsamplingAdd(nn.Module): 204 | def __init__(self, in_channels, out_channels, scale_factor=2): 205 | super().__init__() 206 | self.upsample_layer = nn.Sequential( 207 | nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False), 208 | nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False), 209 | nn.BatchNorm2d(out_channels), 210 | ) 211 | 212 | def forward(self, x, x_skip): 213 | x = self.upsample_layer(x) 214 | return x + x_skip 215 | -------------------------------------------------------------------------------- /fiery/utils/geometry.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import numpy as np 3 | import torch 4 | 5 | from pyquaternion import Quaternion 6 | 7 | 8 | def resize_and_crop_image(img, resize_dims, crop): 9 | # Bilinear resizing followed by cropping 10 | img = img.resize(resize_dims, resample=PIL.Image.BILINEAR) 11 | img = img.crop(crop) 12 | return img 13 | 14 | 15 | def update_intrinsics(intrinsics, top_crop=0.0, left_crop=0.0, scale_width=1.0, scale_height=1.0): 16 | """ 17 | Parameters 18 | ---------- 19 | intrinsics: torch.Tensor (3, 3) 20 | top_crop: float 21 | left_crop: float 22 | scale_width: float 23 | scale_height: float 24 | """ 25 | updated_intrinsics = intrinsics.clone() 26 | # Adjust intrinsics scale due to resizing 27 | updated_intrinsics[0, 0] *= scale_width 28 | updated_intrinsics[0, 2] *= scale_width 29 | updated_intrinsics[1, 1] *= scale_height 30 | updated_intrinsics[1, 2] *= scale_height 31 | 32 | # Adjust principal point due to cropping 33 | updated_intrinsics[0, 2] -= left_crop 34 | updated_intrinsics[1, 2] -= top_crop 35 | 36 | return updated_intrinsics 37 | 38 | 39 | def calculate_birds_eye_view_parameters(x_bounds, y_bounds, z_bounds): 40 | """ 41 | Parameters 42 | ---------- 43 | x_bounds: Forward direction 
in the ego-car. 44 | y_bounds: Sides 45 | z_bounds: Height 46 | 47 | Returns 48 | ------- 49 | bev_resolution: Bird's-eye view bev_resolution 50 | bev_start_position Bird's-eye view first element 51 | bev_dimension Bird's-eye view tensor spatial dimension 52 | """ 53 | bev_resolution = torch.tensor([row[2] for row in [x_bounds, y_bounds, z_bounds]]) 54 | bev_start_position = torch.tensor([row[0] + row[2] / 2.0 for row in [x_bounds, y_bounds, z_bounds]]) 55 | bev_dimension = torch.tensor([(row[1] - row[0]) / row[2] for row in [x_bounds, y_bounds, z_bounds]], 56 | dtype=torch.long) 57 | 58 | return bev_resolution, bev_start_position, bev_dimension 59 | 60 | 61 | def convert_egopose_to_matrix_numpy(egopose): 62 | transformation_matrix = np.zeros((4, 4), dtype=np.float32) 63 | rotation = Quaternion(egopose['rotation']).rotation_matrix 64 | translation = np.array(egopose['translation']) 65 | transformation_matrix[:3, :3] = rotation 66 | transformation_matrix[:3, 3] = translation 67 | transformation_matrix[3, 3] = 1.0 68 | return transformation_matrix 69 | 70 | 71 | def invert_matrix_egopose_numpy(egopose): 72 | """ Compute the inverse transformation of a 4x4 egopose numpy matrix.""" 73 | inverse_matrix = np.zeros((4, 4), dtype=np.float32) 74 | rotation = egopose[:3, :3] 75 | translation = egopose[:3, 3] 76 | inverse_matrix[:3, :3] = rotation.T 77 | inverse_matrix[:3, 3] = -np.dot(rotation.T, translation) 78 | inverse_matrix[3, 3] = 1.0 79 | return inverse_matrix 80 | 81 | 82 | def mat2pose_vec(matrix: torch.Tensor): 83 | """ 84 | Converts a 4x4 pose matrix into a 6-dof pose vector 85 | Args: 86 | matrix (ndarray): 4x4 pose matrix 87 | Returns: 88 | vector (ndarray): 6-dof pose vector comprising translation components (tx, ty, tz) and 89 | rotation components (rx, ry, rz) 90 | """ 91 | 92 | # M[1, 2] = -sinx*cosy, M[2, 2] = +cosx*cosy 93 | rotx = torch.atan2(-matrix[..., 1, 2], matrix[..., 2, 2]) 94 | 95 | # M[0, 2] = +siny, M[1, 2] = -sinx*cosy, M[2, 2] = +cosx*cosy 96 | cosy = torch.sqrt(matrix[..., 1, 2] ** 2 + matrix[..., 2, 2] ** 2) 97 | roty = torch.atan2(matrix[..., 0, 2], cosy) 98 | 99 | # M[0, 0] = +cosy*cosz, M[0, 1] = -cosy*sinz 100 | rotz = torch.atan2(-matrix[..., 0, 1], matrix[..., 0, 0]) 101 | 102 | rotation = torch.stack((rotx, roty, rotz), dim=-1) 103 | 104 | # Extract translation params 105 | translation = matrix[..., :3, 3] 106 | return torch.cat((translation, rotation), dim=-1) 107 | 108 | 109 | def euler2mat(angle: torch.Tensor): 110 | """Convert euler angles to rotation matrix. 
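    The rotation is composed as R = R_x @ R_y @ R_z, matching the convention read off by `mat2pose_vec` above,
    so (up to angle wrapping) `euler2mat(mat2pose_vec(T)[..., 3:])` recovers the rotation block of `T`.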
111 | Reference: https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174 112 | Args: 113 | angle: rotation angle along 3 axis (in radians) [Bx3] 114 | Returns: 115 | Rotation matrix corresponding to the euler angles [Bx3x3] 116 | """ 117 | shape = angle.shape 118 | angle = angle.view(-1, 3) 119 | x, y, z = angle[:, 0], angle[:, 1], angle[:, 2] 120 | 121 | cosz = torch.cos(z) 122 | sinz = torch.sin(z) 123 | 124 | zeros = torch.zeros_like(z) 125 | ones = torch.ones_like(z) 126 | zmat = torch.stack([cosz, -sinz, zeros, sinz, cosz, zeros, zeros, zeros, ones], dim=1).view(-1, 3, 3) 127 | 128 | cosy = torch.cos(y) 129 | siny = torch.sin(y) 130 | 131 | ymat = torch.stack([cosy, zeros, siny, zeros, ones, zeros, -siny, zeros, cosy], dim=1).view(-1, 3, 3) 132 | 133 | cosx = torch.cos(x) 134 | sinx = torch.sin(x) 135 | 136 | xmat = torch.stack([ones, zeros, zeros, zeros, cosx, -sinx, zeros, sinx, cosx], dim=1).view(-1, 3, 3) 137 | 138 | rot_mat = xmat.bmm(ymat).bmm(zmat) 139 | rot_mat = rot_mat.view(*shape[:-1], 3, 3) 140 | return rot_mat 141 | 142 | 143 | def pose_vec2mat(vec: torch.Tensor): 144 | """ 145 | Convert 6DoF parameters to transformation matrix. 146 | Args: 147 | vec: 6DoF parameters in the order of tx, ty, tz, rx, ry, rz [B,6] 148 | Returns: 149 | A transformation matrix [B,4,4] 150 | """ 151 | translation = vec[..., :3].unsqueeze(-1) # [...x3x1] 152 | rot = vec[..., 3:].contiguous() # [...x3] 153 | rot_mat = euler2mat(rot) # [...,3,3] 154 | transform_mat = torch.cat([rot_mat, translation], dim=-1) # [...,3,4] 155 | transform_mat = torch.nn.functional.pad(transform_mat, [0, 0, 0, 1], value=0) # [...,4,4] 156 | transform_mat[..., 3, 3] = 1.0 157 | return transform_mat 158 | 159 | 160 | def invert_pose_matrix(x): 161 | """ 162 | Parameters 163 | ---------- 164 | x: [B, 4, 4] batch of pose matrices 165 | 166 | Returns 167 | ------- 168 | out: [B, 4, 4] batch of inverse pose matrices 169 | """ 170 | assert len(x.shape) == 3 and x.shape[1:] == (4, 4), 'Only works for batch of pose matrices.' 171 | 172 | transposed_rotation = torch.transpose(x[:, :3, :3], 1, 2) 173 | translation = x[:, :3, 3:] 174 | 175 | inverse_mat = torch.cat([transposed_rotation, -torch.bmm(transposed_rotation, translation)], dim=-1) # [B,3,4] 176 | inverse_mat = torch.nn.functional.pad(inverse_mat, [0, 0, 0, 1], value=0) # [B,4,4] 177 | inverse_mat[..., 3, 3] = 1.0 178 | return inverse_mat 179 | 180 | 181 | def warp_features(x, flow, mode='nearest', spatial_extent=None): 182 | """ Applies a rotation and translation to feature map x. 183 | Args: 184 | x: (b, c, h, w) feature map 185 | flow: (b, 6) 6DoF vector (only uses the xy poriton) 186 | mode: use 'nearest' when dealing with categorical inputs 187 | Returns: 188 | in plane transformed feature map 189 | """ 190 | if flow is None: 191 | return x 192 | b, c, h, w = x.shape 193 | # z-rotation 194 | angle = flow[:, 5].clone() # torch.atan2(flow[:, 1, 0], flow[:, 0, 0]) 195 | # x-y translation 196 | translation = flow[:, :2].clone() # flow[:, :2, 3] 197 | 198 | # Normalise translation. Need to divide by how many meters is half of the image. 199 | # because translation of 1.0 correspond to translation of half of the image. 
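# Illustrative numbers (not from a specific config): if the BEV x-range spans [-50m, 50m] then
# spatial_extent[0] == 50.0, so a 5m forward ego-motion has magnitude 0.1 in the normalised [-1, 1]
# coordinates expected by affine_grid below (before the axis flips applied next).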
200 | translation[:, 0] /= spatial_extent[0] 201 | translation[:, 1] /= spatial_extent[1] 202 | # forward axis is inverted 203 | translation[:, 0] *= -1 204 | 205 | cos_theta = torch.cos(angle) 206 | sin_theta = torch.sin(angle) 207 | 208 | # output = Rot.input + translation 209 | # tx and ty are inverted as is the case when going from real coordinates to numpy coordinates 210 | # translation_pos_0 -> positive value makes the image move to the left 211 | # translation_pos_1 -> positive value makes the image move to the top 212 | # Angle -> positive value in rad makes the image move in the trigonometric way 213 | transformation = torch.stack([cos_theta, -sin_theta, translation[:, 1], 214 | sin_theta, cos_theta, translation[:, 0]], dim=-1).view(b, 2, 3) 215 | 216 | # Note that a rotation will preserve distances only if height = width. Otherwise there's 217 | # resizing going on. e.g. rotation of pi/2 of a 100x200 image will make what's in the center of the image 218 | # elongated. 219 | grid = torch.nn.functional.affine_grid(transformation, size=x.shape, align_corners=False) 220 | warped_x = torch.nn.functional.grid_sample(x, grid.float(), mode=mode, padding_mode='zeros', align_corners=False) 221 | 222 | return warped_x 223 | 224 | 225 | def cumulative_warp_features(x, flow, mode='nearest', spatial_extent=None): 226 | """ Warps a sequence of feature maps by accumulating incremental 2d flow. 227 | 228 | x[:, -1] remains unchanged 229 | x[:, -2] is warped using flow[:, -2] 230 | x[:, -3] is warped using flow[:, -3] @ flow[:, -2] 231 | ... 232 | x[:, 0] is warped using flow[:, 0] @ ... @ flow[:, -3] @ flow[:, -2] 233 | 234 | Args: 235 | x: (b, t, c, h, w) sequence of feature maps 236 | flow: (b, t, 6) sequence of 6 DoF pose 237 | from t to t+1 (only uses the xy poriton) 238 | 239 | """ 240 | sequence_length = x.shape[1] 241 | if sequence_length == 1: 242 | return x 243 | 244 | flow = pose_vec2mat(flow) 245 | 246 | out = [x[:, -1]] 247 | cum_flow = flow[:, -2] 248 | for t in reversed(range(sequence_length - 1)): 249 | out.append(warp_features(x[:, t], mat2pose_vec(cum_flow), mode=mode, spatial_extent=spatial_extent)) 250 | # @ is the equivalent of torch.bmm 251 | cum_flow = flow[:, t - 1] @ cum_flow 252 | 253 | return torch.stack(out[::-1], 1) 254 | 255 | 256 | def cumulative_warp_features_reverse(x, flow, mode='nearest', spatial_extent=None): 257 | """ Warps a sequence of feature maps by accumulating incremental 2d flow. 258 | 259 | x[:, 0] remains unchanged 260 | x[:, 1] is warped using flow[:, 0].inverse() 261 | x[:, 2] is warped using flow[:, 0].inverse() @ flow[:, 1].inverse() 262 | ... 263 | 264 | Args: 265 | x: (b, t, c, h, w) sequence of feature maps 266 | flow: (b, t, 6) sequence of 6 DoF pose 267 | from t to t+1 (only uses the xy poriton) 268 | 269 | """ 270 | flow = pose_vec2mat(flow) 271 | 272 | out = [x[:,0]] 273 | 274 | for i in range(1, x.shape[1]): 275 | if i==1: 276 | cum_flow = invert_pose_matrix(flow[:, 0]) 277 | else: 278 | cum_flow = cum_flow @ invert_pose_matrix(flow[:,i-1]) 279 | out.append( warp_features(x[:,i], mat2pose_vec(cum_flow), mode, spatial_extent=spatial_extent)) 280 | return torch.stack(out, 1) 281 | 282 | 283 | class VoxelsSumming(torch.autograd.Function): 284 | """Adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/tools.py#L193""" 285 | @staticmethod 286 | def forward(ctx, x, geometry, ranks): 287 | """The features `x` and `geometry` are ranked by voxel positions.""" 288 | # Cumulative sum of all features. 
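# Illustrative example of the cumsum trick (hypothetical values): with sorted ranks = [0, 0, 1, 1, 2] and
# scalar features x = [1, 2, 3, 4, 5], the cumulative sum is [1, 3, 6, 10, 15]; keeping the last entry of
# each run ([3, 10, 15]) and differencing adjacent kept values gives the per-voxel sums [3, 7, 5].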
289 | x = x.cumsum(0) 290 | 291 | # Indicates the change of voxel. 292 | mask = torch.ones(x.shape[0], device=x.device, dtype=torch.bool) 293 | mask[:-1] = ranks[1:] != ranks[:-1] 294 | 295 | x, geometry = x[mask], geometry[mask] 296 | # Calculate sum of features within a voxel. 297 | x = torch.cat((x[:1], x[1:] - x[:-1])) 298 | 299 | ctx.save_for_backward(mask) 300 | ctx.mark_non_differentiable(geometry) 301 | 302 | return x, geometry 303 | 304 | @staticmethod 305 | def backward(ctx, grad_x, grad_geometry): 306 | (mask,) = ctx.saved_tensors 307 | # Since the operation is summing, we simply need to send gradient 308 | # to all elements that were part of the summation process. 309 | indices = torch.cumsum(mask, 0) 310 | indices[mask] -= 1 311 | 312 | output_grad = grad_x[indices] 313 | 314 | return output_grad, None, None 315 | -------------------------------------------------------------------------------- /fiery/layers/temporal.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from fiery.layers.convolutions import ConvBlock 7 | from fiery.utils.geometry import warp_features 8 | 9 | 10 | class SpatialGRU(nn.Module): 11 | """A GRU cell that takes an input tensor [BxTxCxHxW] and an optional previous state and passes a 12 | convolutional gated recurrent unit over the data""" 13 | 14 | def __init__(self, input_size, hidden_size, gru_bias_init=0.0, norm='bn', activation='relu'): 15 | super().__init__() 16 | self.input_size = input_size 17 | self.hidden_size = hidden_size 18 | self.gru_bias_init = gru_bias_init 19 | 20 | self.conv_update = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1) 21 | self.conv_reset = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1) 22 | 23 | self.conv_state_tilde = ConvBlock( 24 | input_size + hidden_size, hidden_size, kernel_size=3, bias=False, norm=norm, activation=activation 25 | ) 26 | 27 | def forward(self, x, state=None, flow=None, mode='bilinear'): 28 | # pylint: disable=unused-argument, arguments-differ 29 | # Check size 30 | assert len(x.size()) == 5, 'Input tensor must be BxTxCxHxW.' 
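# The loop below rolls a convolutional GRU over the time axis; when `flow` is given, the hidden state is
# first warped by that step's ego-motion so it stays aligned with the current frame. For example, an input
# of shape (b, t, input_size, h, w) produces an output of shape (b, t, hidden_size, h, w).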
31 | b, timesteps, c, h, w = x.size() 32 | assert c == self.input_size, f'feature sizes must match, got input {c} for layer with size {self.input_size}' 33 | 34 | # recurrent layers 35 | rnn_output = [] 36 | rnn_state = torch.zeros(b, self.hidden_size, h, w, device=x.device) if state is None else state 37 | for t in range(timesteps): 38 | x_t = x[:, t] 39 | if flow is not None: 40 | rnn_state = warp_features(rnn_state, flow[:, t], mode=mode) 41 | 42 | # propagate rnn state 43 | rnn_state = self.gru_cell(x_t, rnn_state) 44 | rnn_output.append(rnn_state) 45 | 46 | # reshape rnn output to batch tensor 47 | return torch.stack(rnn_output, dim=1) 48 | 49 | def gru_cell(self, x, state): 50 | # Compute gates 51 | x_and_state = torch.cat([x, state], dim=1) 52 | update_gate = self.conv_update(x_and_state) 53 | reset_gate = self.conv_reset(x_and_state) 54 | # Add bias to initialise gate as close to identity function 55 | update_gate = torch.sigmoid(update_gate + self.gru_bias_init) 56 | reset_gate = torch.sigmoid(reset_gate + self.gru_bias_init) 57 | 58 | # Compute proposal state, activation is defined in norm_act_config (can be tanh, ReLU etc) 59 | state_tilde = self.conv_state_tilde(torch.cat([x, (1.0 - reset_gate) * state], dim=1)) 60 | 61 | output = (1.0 - update_gate) * state + update_gate * state_tilde 62 | return output 63 | 64 | 65 | class CausalConv3d(nn.Module): 66 | def __init__(self, in_channels, out_channels, kernel_size=(2, 3, 3), dilation=(1, 1, 1), bias=False): 67 | super().__init__() 68 | assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.' 69 | time_pad = (kernel_size[0] - 1) * dilation[0] 70 | height_pad = ((kernel_size[1] - 1) * dilation[1]) // 2 71 | width_pad = ((kernel_size[2] - 1) * dilation[2]) // 2 72 | 73 | # Pad temporally on the left 74 | self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0) 75 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, dilation=dilation, stride=1, padding=0, bias=bias) 76 | self.norm = nn.BatchNorm3d(out_channels) 77 | self.activation = nn.ReLU(inplace=True) 78 | 79 | def forward(self, *inputs): 80 | (x,) = inputs 81 | x = self.pad(x) 82 | x = self.conv(x) 83 | x = self.norm(x) 84 | x = self.activation(x) 85 | return x 86 | 87 | 88 | class CausalMaxPool3d(nn.Module): 89 | def __init__(self, kernel_size=(2, 3, 3)): 90 | super().__init__() 91 | assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.' 
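# As in CausalConv3d above, the time axis is padded only on the left (the past), so the pooled value at
# time t never looks at frames after t, and with stride 1 the temporal length is preserved.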
92 | time_pad = kernel_size[0] - 1 93 | height_pad = (kernel_size[1] - 1) // 2 94 | width_pad = (kernel_size[2] - 1) // 2 95 | 96 | # Pad temporally on the left 97 | self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0) 98 | self.max_pool = nn.MaxPool3d(kernel_size, stride=1) 99 | 100 | def forward(self, *inputs): 101 | (x,) = inputs 102 | x = self.pad(x) 103 | x = self.max_pool(x) 104 | return x 105 | 106 | 107 | def conv_1x1x1_norm_activated(in_channels, out_channels): 108 | """1x1x1 3D convolution, normalization and activation layer.""" 109 | return nn.Sequential( 110 | OrderedDict( 111 | [ 112 | ('conv', nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False)), 113 | ('norm', nn.BatchNorm3d(out_channels)), 114 | ('activation', nn.ReLU(inplace=True)), 115 | ] 116 | ) 117 | ) 118 | 119 | 120 | class Bottleneck3D(nn.Module): 121 | """ 122 | Defines a bottleneck module with a residual connection 123 | """ 124 | 125 | def __init__(self, in_channels, out_channels=None, kernel_size=(2, 3, 3), dilation=(1, 1, 1)): 126 | super().__init__() 127 | bottleneck_channels = in_channels // 2 128 | out_channels = out_channels or in_channels 129 | 130 | self.layers = nn.Sequential( 131 | OrderedDict( 132 | [ 133 | # First projection with 1x1 kernel 134 | ('conv_down_project', conv_1x1x1_norm_activated(in_channels, bottleneck_channels)), 135 | # Second conv block 136 | ( 137 | 'conv', 138 | CausalConv3d( 139 | bottleneck_channels, 140 | bottleneck_channels, 141 | kernel_size=kernel_size, 142 | dilation=dilation, 143 | bias=False, 144 | ), 145 | ), 146 | # Final projection with 1x1 kernel 147 | ('conv_up_project', conv_1x1x1_norm_activated(bottleneck_channels, out_channels)), 148 | ] 149 | ) 150 | ) 151 | 152 | if out_channels != in_channels: 153 | self.projection = nn.Sequential( 154 | nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False), 155 | nn.BatchNorm3d(out_channels), 156 | ) 157 | else: 158 | self.projection = None 159 | 160 | def forward(self, *args): 161 | (x,) = args 162 | x_residual = self.layers(x) 163 | x_features = self.projection(x) if self.projection is not None else x 164 | return x_residual + x_features 165 | 166 | 167 | class PyramidSpatioTemporalPooling(nn.Module): 168 | """ Spatio-temporal pyramid pooling. 169 | Performs 3D average pooling followed by 1x1x1 convolution to reduce the number of channels and upsampling. 170 | Setting contains a list of kernel_size: usually it is [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)] 171 | """ 172 | 173 | def __init__(self, in_channels, reduction_channels, pool_sizes): 174 | super().__init__() 175 | self.features = [] 176 | for pool_size in pool_sizes: 177 | assert pool_size[0] == 2, ( 178 | "Time kernel should be 2 as PyTorch raises an error when" "padding with more than half the kernel size" 179 | ) 180 | stride = (1, *pool_size[1:]) 181 | padding = (pool_size[0] - 1, 0, 0) 182 | self.features.append( 183 | nn.Sequential( 184 | OrderedDict( 185 | [ 186 | # Pad the input tensor but do not take into account zero padding into the average. 
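# With a time kernel of 2, stride 1 and symmetric time padding of pool_size[0] - 1 = 1, the pooled output
# has one extra time step; forward() drops it with [:, :, :-1] so each remaining step only averages the
# current and previous frame, and count_include_pad=False keeps the padded zeros out of the average.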
187 | ( 188 | 'avgpool', 189 | torch.nn.AvgPool3d( 190 | kernel_size=pool_size, stride=stride, padding=padding, count_include_pad=False 191 | ), 192 | ), 193 | ('conv_bn_relu', conv_1x1x1_norm_activated(in_channels, reduction_channels)), 194 | ] 195 | ) 196 | ) 197 | ) 198 | self.features = nn.ModuleList(self.features) 199 | 200 | def forward(self, *inputs): 201 | (x,) = inputs 202 | b, _, t, h, w = x.shape 203 | # Do not include current tensor when concatenating 204 | out = [] 205 | for f in self.features: 206 | # Remove unnecessary padded values (time dimension) on the right 207 | x_pool = f(x)[:, :, :-1].contiguous() 208 | c = x_pool.shape[1] 209 | x_pool = nn.functional.interpolate( 210 | x_pool.view(b * t, c, *x_pool.shape[-2:]), (h, w), mode='bilinear', align_corners=False 211 | ) 212 | x_pool = x_pool.view(b, c, t, h, w) 213 | out.append(x_pool) 214 | out = torch.cat(out, 1) 215 | return out 216 | 217 | 218 | class TemporalBlock(nn.Module): 219 | """ Temporal block with the following layers: 220 | - 2x3x3, 1x3x3, spatio-temporal pyramid pooling 221 | - dropout 222 | - skip connection. 223 | """ 224 | 225 | def __init__(self, in_channels, out_channels=None, use_pyramid_pooling=False, pool_sizes=None): 226 | super().__init__() 227 | self.in_channels = in_channels 228 | self.half_channels = in_channels // 2 229 | self.out_channels = out_channels or self.in_channels 230 | self.kernels = [(2, 3, 3), (1, 3, 3)] 231 | 232 | # Flag for spatio-temporal pyramid pooling 233 | self.use_pyramid_pooling = use_pyramid_pooling 234 | 235 | # 3 convolution paths: 2x3x3, 1x3x3, 1x1x1 236 | self.convolution_paths = [] 237 | for kernel_size in self.kernels: 238 | self.convolution_paths.append( 239 | nn.Sequential( 240 | conv_1x1x1_norm_activated(self.in_channels, self.half_channels), 241 | CausalConv3d(self.half_channels, self.half_channels, kernel_size=kernel_size), 242 | ) 243 | ) 244 | self.convolution_paths.append(conv_1x1x1_norm_activated(self.in_channels, self.half_channels)) 245 | self.convolution_paths = nn.ModuleList(self.convolution_paths) 246 | 247 | agg_in_channels = len(self.convolution_paths) * self.half_channels 248 | 249 | if self.use_pyramid_pooling: 250 | assert pool_sizes is not None, "setting must contain the list of kernel_size, but is None." 
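# Channel bookkeeping, e.g. for in_channels = 64: the three convolution paths each output
# half_channels = 32, so agg_in_channels starts at 96; with, say, three pyramid pool sizes,
# reduction_channels = 64 // 3 = 21 and agg_in_channels grows to 96 + 3 * 21 = 159 before the
# final 1x1x1 aggregation back to out_channels.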
251 | reduction_channels = self.in_channels // 3 252 | self.pyramid_pooling = PyramidSpatioTemporalPooling(self.in_channels, reduction_channels, pool_sizes) 253 | agg_in_channels += len(pool_sizes) * reduction_channels 254 | 255 | # Feature aggregation 256 | self.aggregation = nn.Sequential( 257 | conv_1x1x1_norm_activated(agg_in_channels, self.out_channels),) 258 | 259 | if self.out_channels != self.in_channels: 260 | self.projection = nn.Sequential( 261 | nn.Conv3d(self.in_channels, self.out_channels, kernel_size=1, bias=False), 262 | nn.BatchNorm3d(self.out_channels), 263 | ) 264 | else: 265 | self.projection = None 266 | 267 | def forward(self, *inputs): 268 | (x,) = inputs 269 | x_paths = [] 270 | for conv in self.convolution_paths: 271 | x_paths.append(conv(x)) 272 | x_residual = torch.cat(x_paths, dim=1) 273 | if self.use_pyramid_pooling: 274 | x_pool = self.pyramid_pooling(x) 275 | x_residual = torch.cat([x_residual, x_pool], dim=1) 276 | x_residual = self.aggregation(x_residual) 277 | 278 | if self.out_channels != self.in_channels: 279 | x = self.projection(x) 280 | x = x + x_residual 281 | return x 282 | -------------------------------------------------------------------------------- /fiery/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from pytorch_lightning.metrics.metric import Metric 5 | from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes 6 | from pytorch_lightning.metrics.functional.reduction import reduce 7 | 8 | 9 | class IntersectionOverUnion(Metric): 10 | """Computes intersection-over-union.""" 11 | def __init__( 12 | self, 13 | n_classes: int, 14 | ignore_index: Optional[int] = None, 15 | absent_score: float = 0.0, 16 | reduction: str = 'none', 17 | compute_on_step: bool = False, 18 | ): 19 | super().__init__(compute_on_step=compute_on_step) 20 | 21 | self.n_classes = n_classes 22 | self.ignore_index = ignore_index 23 | self.absent_score = absent_score 24 | self.reduction = reduction 25 | 26 | self.add_state('true_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum') 27 | self.add_state('false_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum') 28 | self.add_state('false_negative', default=torch.zeros(n_classes), dist_reduce_fx='sum') 29 | self.add_state('support', default=torch.zeros(n_classes), dist_reduce_fx='sum') 30 | 31 | def update(self, prediction: torch.Tensor, target: torch.Tensor): 32 | tps, fps, _, fns, sups = stat_scores_multiple_classes(prediction, target, self.n_classes) 33 | 34 | self.true_positive += tps 35 | self.false_positive += fps 36 | self.false_negative += fns 37 | self.support += sups 38 | 39 | def compute(self): 40 | scores = torch.zeros(self.n_classes, device=self.true_positive.device, dtype=torch.float32) 41 | 42 | for class_idx in range(self.n_classes): 43 | if class_idx == self.ignore_index: 44 | continue 45 | 46 | tp = self.true_positive[class_idx] 47 | fp = self.false_positive[class_idx] 48 | fn = self.false_negative[class_idx] 49 | sup = self.support[class_idx] 50 | 51 | # If this class is absent in the target (no support) AND absent in the pred (no true or false 52 | # positives), then use the absent_score for this class. 
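# The per-class score computed below is IoU = tp / (tp + fp + fn); e.g. tp=80, fp=10, fn=20 gives 80 / 110 ≈ 0.73.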
53 | if sup + tp + fp == 0: 54 | scores[class_idx] = self.absent_score 55 | continue 56 | 57 | denominator = tp + fp + fn 58 | score = tp.to(torch.float) / denominator 59 | scores[class_idx] = score 60 | 61 | # Remove the ignored class index from the scores. 62 | if (self.ignore_index is not None) and (0 <= self.ignore_index < self.n_classes): 63 | scores = torch.cat([scores[:self.ignore_index], scores[self.ignore_index+1:]]) 64 | 65 | return reduce(scores, reduction=self.reduction) 66 | 67 | 68 | class PanopticMetric(Metric): 69 | def __init__( 70 | self, 71 | n_classes: int, 72 | temporally_consistent: bool = True, 73 | vehicles_id: int = 1, 74 | compute_on_step: bool = False, 75 | ): 76 | super().__init__(compute_on_step=compute_on_step) 77 | 78 | self.n_classes = n_classes 79 | self.temporally_consistent = temporally_consistent 80 | self.vehicles_id = vehicles_id 81 | self.keys = ['iou', 'true_positive', 'false_positive', 'false_negative'] 82 | 83 | self.add_state('iou', default=torch.zeros(n_classes), dist_reduce_fx='sum') 84 | self.add_state('true_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum') 85 | self.add_state('false_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum') 86 | self.add_state('false_negative', default=torch.zeros(n_classes), dist_reduce_fx='sum') 87 | 88 | def update(self, pred_instance, gt_instance): 89 | """ 90 | Update state with predictions and targets. 91 | 92 | Parameters 93 | ---------- 94 | pred_instance: (b, s, h, w) 95 | Temporally consistent instance segmentation prediction. 96 | gt_instance: (b, s, h, w) 97 | Ground truth instance segmentation. 98 | """ 99 | batch_size, sequence_length = gt_instance.shape[:2] 100 | # Process labels 101 | assert gt_instance.min() == 0, 'ID 0 of gt_instance must be background' 102 | pred_segmentation = (pred_instance > 0).long() 103 | gt_segmentation = (gt_instance > 0).long() 104 | 105 | for b in range(batch_size): 106 | unique_id_mapping = {} 107 | for t in range(sequence_length): 108 | result = self.panoptic_metrics( 109 | pred_segmentation[b, t].detach(), 110 | pred_instance[b, t].detach(), 111 | gt_segmentation[b, t], 112 | gt_instance[b, t], 113 | unique_id_mapping, 114 | ) 115 | 116 | self.iou += result['iou'] 117 | self.true_positive += result['true_positive'] 118 | self.false_positive += result['false_positive'] 119 | self.false_negative += result['false_negative'] 120 | 121 | def compute(self): 122 | denominator = torch.maximum( 123 | (self.true_positive + self.false_positive / 2 + self.false_negative / 2), 124 | torch.ones_like(self.true_positive) 125 | ) 126 | pq = self.iou / denominator 127 | sq = self.iou / torch.maximum(self.true_positive, torch.ones_like(self.true_positive)) 128 | rq = self.true_positive / denominator 129 | 130 | return {'pq': pq, 131 | 'sq': sq, 132 | 'rq': rq, 133 | # If 0, it means there wasn't any detection. 134 | 'denominator': (self.true_positive + self.false_positive / 2 + self.false_negative / 2), 135 | } 136 | 137 | def panoptic_metrics(self, pred_segmentation, pred_instance, gt_segmentation, gt_instance, unique_id_mapping): 138 | """ 139 | Computes panoptic quality metric components. 
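        Accumulated over update() calls, these counts yield in compute(): PQ = sum(IoU of TPs) / (TP + FP/2 + FN/2),
        SQ = sum(IoU of TPs) / TP and RQ = TP / (TP + FP/2 + FN/2), with the denominators clamped to at least 1.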
140 | 141 | Parameters 142 | ---------- 143 | pred_segmentation: [H, W] range {0, ..., n_classes-1} (>= n_classes is void) 144 | pred_instance: [H, W] range {0, ..., n_instances} (zero means background) 145 | gt_segmentation: [H, W] range {0, ..., n_classes-1} (>= n_classes is void) 146 | gt_instance: [H, W] range {0, ..., n_instances} (zero means background) 147 | unique_id_mapping: instance id mapping to check consistency 148 | """ 149 | n_classes = self.n_classes 150 | 151 | result = {key: torch.zeros(n_classes, dtype=torch.float32, device=gt_instance.device) for key in self.keys} 152 | 153 | assert pred_segmentation.dim() == 2 154 | assert pred_segmentation.shape == pred_instance.shape == gt_segmentation.shape == gt_instance.shape 155 | 156 | n_instances = int(torch.cat([pred_instance, gt_instance]).max().item()) 157 | n_all_things = n_instances + n_classes # Classes + instances. 158 | n_things_and_void = n_all_things + 1 159 | 160 | # Now 1 is background; 0 is void (not used). 2 is vehicle semantic class but since it overlaps with 161 | # instances, it is not present. 162 | # and the rest are instance ids starting from 3 163 | prediction, pred_to_cls = self.combine_mask(pred_segmentation, pred_instance, n_classes, n_all_things) 164 | target, target_to_cls = self.combine_mask(gt_segmentation, gt_instance, n_classes, n_all_things) 165 | 166 | # Compute ious between all stuff and things 167 | # hack for bincounting 2 arrays together 168 | x = prediction + n_things_and_void * target 169 | bincount_2d = torch.bincount(x.long(), minlength=n_things_and_void ** 2) 170 | if bincount_2d.shape[0] != n_things_and_void ** 2: 171 | raise ValueError('Incorrect bincount size.') 172 | conf = bincount_2d.reshape((n_things_and_void, n_things_and_void)) 173 | # Drop void class 174 | conf = conf[1:, 1:] 175 | 176 | # Confusion matrix contains intersections between all combinations of classes 177 | union = conf.sum(0).unsqueeze(0) + conf.sum(1).unsqueeze(1) - conf 178 | iou = torch.where(union > 0, (conf.float() + 1e-9) / (union.float() + 1e-9), torch.zeros_like(union).float()) 179 | 180 | # In the iou matrix, first dimension is target idx, second dimension is pred idx. 181 | # Mapping will contain a tuple that maps prediction idx to target idx for segments matched by iou. 182 | mapping = (iou > 0.5).nonzero(as_tuple=False) 183 | 184 | # Check that classes match. 185 | is_matching = pred_to_cls[mapping[:, 1]] == target_to_cls[mapping[:, 0]] 186 | mapping = mapping[is_matching] 187 | tp_mask = torch.zeros_like(conf, dtype=torch.bool) 188 | tp_mask[mapping[:, 0], mapping[:, 1]] = True 189 | 190 | # First ids correspond to "stuff" i.e. semantic seg. 191 | # Instance ids are offset accordingly 192 | for target_id, pred_id in mapping: 193 | cls_id = pred_to_cls[pred_id] 194 | 195 | if self.temporally_consistent and cls_id == self.vehicles_id: 196 | if target_id.item() in unique_id_mapping and unique_id_mapping[target_id.item()] != pred_id.item(): 197 | # Not temporally consistent 198 | result['false_negative'][target_to_cls[target_id]] += 1 199 | result['false_positive'][pred_to_cls[pred_id]] += 1 200 | unique_id_mapping[target_id.item()] = pred_id.item() 201 | continue 202 | 203 | result['true_positive'][cls_id] += 1 204 | result['iou'][cls_id] += iou[target_id][pred_id] 205 | unique_id_mapping[target_id.item()] = pred_id.item() 206 | 207 | for target_id in range(n_classes, n_all_things): 208 | # If this is a true positive do nothing. 
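# In the combined id space used here, indices 0..n_classes-1 are the stuff/semantic classes and indices
# n_classes..n_all_things-1 are individual thing instances (instance ids are offset by n_classes in combine_mask).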
209 | if tp_mask[target_id, n_classes:].any(): 210 | continue 211 | # If this target instance didn't match any prediction and was present, count it as a false negative. 212 | if target_to_cls[target_id] != -1: 213 | result['false_negative'][target_to_cls[target_id]] += 1 214 | 215 | for pred_id in range(n_classes, n_all_things): 216 | # If this is a true positive do nothing. 217 | if tp_mask[n_classes:, pred_id].any(): 218 | continue 219 | # If this predicted instance is present but didn't match any ground-truth instance, count it as a false positive. 220 | if pred_to_cls[pred_id] != -1 and (conf[:, pred_id] > 0).any(): 221 | result['false_positive'][pred_to_cls[pred_id]] += 1 222 | 223 | return result 224 | 225 | def combine_mask(self, segmentation: torch.Tensor, instance: torch.Tensor, n_classes: int, n_all_things: int): 226 | """Shifts all thing ids by num_classes and combines things and stuff into a single mask. 227 | 228 | Returns a combined mask + a mapping from id to segmentation class. 229 | """ 230 | instance = instance.view(-1) 231 | instance_mask = instance > 0 232 | instance = instance - 1 + n_classes 233 | 234 | segmentation = segmentation.clone().view(-1) 235 | segmentation_mask = segmentation < n_classes # Remove void pixels. 236 | 237 | # Build an index from instance id to class id. 238 | instance_id_to_class_tuples = torch.cat( 239 | ( 240 | instance[instance_mask & segmentation_mask].unsqueeze(1), 241 | segmentation[instance_mask & segmentation_mask].unsqueeze(1), 242 | ), 243 | dim=1, 244 | ) 245 | instance_id_to_class = -instance_id_to_class_tuples.new_ones((n_all_things,)) 246 | instance_id_to_class[instance_id_to_class_tuples[:, 0]] = instance_id_to_class_tuples[:, 1] 247 | instance_id_to_class[torch.arange(n_classes, device=segmentation.device)] = torch.arange( 248 | n_classes, device=segmentation.device 249 | ) 250 | 251 | segmentation[instance_mask] = instance[instance_mask] 252 | segmentation += 1 # Shift all legit classes by 1. 253 | segmentation[~segmentation_mask] = 0 # Shift void class to zero.
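# Resulting ids: 0 = void, 1..n_classes = stuff/semantic classes, and n_classes+1 onwards = thing instances.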
254 | 255 | return segmentation, instance_id_to_class 256 | -------------------------------------------------------------------------------- /fiery/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_lightning as pl 4 | 5 | from fiery.config import get_cfg 6 | from fiery.models.fiery import Fiery 7 | from fiery.losses import ProbabilisticLoss, SpatialRegressionLoss, SegmentationLoss 8 | from fiery.metrics import IntersectionOverUnion, PanopticMetric 9 | from fiery.utils.geometry import cumulative_warp_features_reverse 10 | from fiery.utils.instance import predict_instance_segmentation_and_trajectories 11 | from fiery.utils.visualisation import visualise_output 12 | 13 | 14 | class TrainingModule(pl.LightningModule): 15 | def __init__(self, hparams): 16 | super().__init__() 17 | 18 | # see config.py for details 19 | self.hparams = hparams 20 | # pytorch lightning does not support saving YACS CfgNone 21 | cfg = get_cfg(cfg_dict=self.hparams) 22 | self.cfg = cfg 23 | self.n_classes = len(self.cfg.SEMANTIC_SEG.WEIGHTS) 24 | 25 | # Bird's-eye view extent in meters 26 | assert self.cfg.LIFT.X_BOUND[1] > 0 and self.cfg.LIFT.Y_BOUND[1] > 0 27 | self.spatial_extent = (self.cfg.LIFT.X_BOUND[1], self.cfg.LIFT.Y_BOUND[1]) 28 | 29 | # Model 30 | self.model = Fiery(cfg) 31 | 32 | # Losses 33 | self.losses_fn = nn.ModuleDict() 34 | self.losses_fn['segmentation'] = SegmentationLoss( 35 | class_weights=torch.Tensor(self.cfg.SEMANTIC_SEG.WEIGHTS), 36 | use_top_k=self.cfg.SEMANTIC_SEG.USE_TOP_K, 37 | top_k_ratio=self.cfg.SEMANTIC_SEG.TOP_K_RATIO, 38 | future_discount=self.cfg.FUTURE_DISCOUNT, 39 | ) 40 | 41 | # Uncertainty weighting 42 | self.model.segmentation_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True) 43 | 44 | self.metric_iou_val = IntersectionOverUnion(self.n_classes) 45 | 46 | self.losses_fn['instance_center'] = SpatialRegressionLoss( 47 | norm=2, future_discount=self.cfg.FUTURE_DISCOUNT 48 | ) 49 | self.losses_fn['instance_offset'] = SpatialRegressionLoss( 50 | norm=1, future_discount=self.cfg.FUTURE_DISCOUNT, ignore_index=self.cfg.DATASET.IGNORE_INDEX 51 | ) 52 | 53 | # Uncertainty weighting 54 | self.model.centerness_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True) 55 | self.model.offset_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True) 56 | 57 | self.metric_panoptic_val = PanopticMetric(n_classes=self.n_classes) 58 | 59 | if self.cfg.INSTANCE_FLOW.ENABLED: 60 | self.losses_fn['instance_flow'] = SpatialRegressionLoss( 61 | norm=1, future_discount=self.cfg.FUTURE_DISCOUNT, ignore_index=self.cfg.DATASET.IGNORE_INDEX 62 | ) 63 | # Uncertainty weighting 64 | self.model.flow_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True) 65 | 66 | if self.cfg.PROBABILISTIC.ENABLED: 67 | self.losses_fn['probabilistic'] = ProbabilisticLoss() 68 | 69 | self.training_step_count = 0 70 | 71 | def shared_step(self, batch, is_train): 72 | image = batch['image'] 73 | intrinsics = batch['intrinsics'] 74 | extrinsics = batch['extrinsics'] 75 | future_egomotion = batch['future_egomotion'] 76 | 77 | # Warp labels 78 | labels, future_distribution_inputs = self.prepare_future_labels(batch) 79 | 80 | # Forward pass 81 | output = self.model( 82 | image, intrinsics, extrinsics, future_egomotion, future_distribution_inputs 83 | ) 84 | 85 | ##### 86 | # Loss computation 87 | ##### 88 | loss = {} 89 | segmentation_factor = 1 / torch.exp(self.model.segmentation_weight) 90 | 
loss['segmentation'] = segmentation_factor * self.losses_fn['segmentation']( 91 | output['segmentation'], labels['segmentation'] 92 | ) 93 | loss['segmentation_uncertainty'] = 0.5 * self.model.segmentation_weight 94 | 95 | centerness_factor = 1 / (2*torch.exp(self.model.centerness_weight)) 96 | loss['instance_center'] = centerness_factor * self.losses_fn['instance_center']( 97 | output['instance_center'], labels['centerness'] 98 | ) 99 | 100 | offset_factor = 1 / (2*torch.exp(self.model.offset_weight)) 101 | loss['instance_offset'] = offset_factor * self.losses_fn['instance_offset']( 102 | output['instance_offset'], labels['offset'] 103 | ) 104 | 105 | loss['centerness_uncertainty'] = 0.5 * self.model.centerness_weight 106 | loss['offset_uncertainty'] = 0.5 * self.model.offset_weight 107 | 108 | if self.cfg.INSTANCE_FLOW.ENABLED: 109 | flow_factor = 1 / (2*torch.exp(self.model.flow_weight)) 110 | loss['instance_flow'] = flow_factor * self.losses_fn['instance_flow']( 111 | output['instance_flow'], labels['flow'] 112 | ) 113 | 114 | loss['flow_uncertainty'] = 0.5 * self.model.flow_weight 115 | 116 | if self.cfg.PROBABILISTIC.ENABLED: 117 | loss['probabilistic'] = self.cfg.PROBABILISTIC.WEIGHT * self.losses_fn['probabilistic'](output) 118 | 119 | # Metrics 120 | if not is_train: 121 | seg_prediction = output['segmentation'].detach() 122 | seg_prediction = torch.argmax(seg_prediction, dim=2, keepdims=True) 123 | self.metric_iou_val(seg_prediction, labels['segmentation']) 124 | 125 | pred_consistent_instance_seg = predict_instance_segmentation_and_trajectories( 126 | output, compute_matched_centers=False 127 | ) 128 | 129 | self.metric_panoptic_val(pred_consistent_instance_seg, labels['instance']) 130 | 131 | return output, labels, loss 132 | 133 | def prepare_future_labels(self, batch): 134 | labels = {} 135 | future_distribution_inputs = [] 136 | 137 | segmentation_labels = batch['segmentation'] 138 | instance_center_labels = batch['centerness'] 139 | instance_offset_labels = batch['offset'] 140 | instance_flow_labels = batch['flow'] 141 | gt_instance = batch['instance'] 142 | future_egomotion = batch['future_egomotion'] 143 | 144 | # Warp labels to present's reference frame 145 | segmentation_labels = cumulative_warp_features_reverse( 146 | segmentation_labels[:, (self.model.receptive_field - 1):].float(), 147 | future_egomotion[:, (self.model.receptive_field - 1):], 148 | mode='nearest', spatial_extent=self.spatial_extent, 149 | ).long().contiguous() 150 | labels['segmentation'] = segmentation_labels 151 | future_distribution_inputs.append(segmentation_labels) 152 | 153 | # Warp instance labels to present's reference frame 154 | gt_instance = cumulative_warp_features_reverse( 155 | gt_instance[:, (self.model.receptive_field - 1):].float().unsqueeze(2), 156 | future_egomotion[:, (self.model.receptive_field - 1):], 157 | mode='nearest', spatial_extent=self.spatial_extent, 158 | ).long().contiguous()[:, :, 0] 159 | labels['instance'] = gt_instance 160 | 161 | instance_center_labels = cumulative_warp_features_reverse( 162 | instance_center_labels[:, (self.model.receptive_field - 1):], 163 | future_egomotion[:, (self.model.receptive_field - 1):], 164 | mode='nearest', spatial_extent=self.spatial_extent, 165 | ).contiguous() 166 | labels['centerness'] = instance_center_labels 167 | 168 | instance_offset_labels = cumulative_warp_features_reverse( 169 | instance_offset_labels[:, (self.model.receptive_field- 1):], 170 | future_egomotion[:, (self.model.receptive_field - 1):], 171 | mode='nearest', 
spatial_extent=self.spatial_extent, 172 | ).contiguous() 173 | labels['offset'] = instance_offset_labels 174 | 175 | future_distribution_inputs.append(instance_center_labels) 176 | future_distribution_inputs.append(instance_offset_labels) 177 | 178 | if self.cfg.INSTANCE_FLOW.ENABLED: 179 | instance_flow_labels = cumulative_warp_features_reverse( 180 | instance_flow_labels[:, (self.model.receptive_field - 1):], 181 | future_egomotion[:, (self.model.receptive_field - 1):], 182 | mode='nearest', spatial_extent=self.spatial_extent, 183 | ).contiguous() 184 | labels['flow'] = instance_flow_labels 185 | 186 | future_distribution_inputs.append(instance_flow_labels) 187 | 188 | if len(future_distribution_inputs) > 0: 189 | future_distribution_inputs = torch.cat(future_distribution_inputs, dim=2) 190 | 191 | return labels, future_distribution_inputs 192 | 193 | def visualise(self, labels, output, batch_idx, prefix='train'): 194 | visualisation_video = visualise_output(labels, output, self.cfg) 195 | name = f'{prefix}_outputs' 196 | if prefix == 'val': 197 | name = name + f'_{batch_idx}' 198 | self.logger.experiment.add_video(name, visualisation_video, global_step=self.training_step_count, fps=2) 199 | 200 | def training_step(self, batch, batch_idx): 201 | output, labels, loss = self.shared_step(batch, True) 202 | self.training_step_count += 1 203 | for key, value in loss.items(): 204 | self.logger.experiment.add_scalar(key, value, global_step=self.training_step_count) 205 | 206 | if self.training_step_count % self.cfg.VIS_INTERVAL == 0: 207 | self.visualise(labels, output, batch_idx, prefix='train') 208 | return sum(loss.values()) 209 | 210 | def validation_step(self, batch, batch_idx): 211 | output, labels, loss = self.shared_step(batch, False) 212 | for key, value in loss.items(): 213 | self.log('val_' + key, value) 214 | 215 | if batch_idx == 0: 216 | self.visualise(labels, output, batch_idx, prefix='val') 217 | 218 | def shared_epoch_end(self, step_outputs, is_train): 219 | # log per class iou metrics 220 | class_names = ['background', 'dynamic'] 221 | if not is_train: 222 | scores = self.metric_iou_val.compute() 223 | for key, value in zip(class_names, scores): 224 | self.logger.experiment.add_scalar('val_iou_' + key, value, global_step=self.training_step_count) 225 | self.metric_iou_val.reset() 226 | 227 | if not is_train: 228 | scores = self.metric_panoptic_val.compute() 229 | for key, value in scores.items(): 230 | for instance_name, score in zip(['background', 'vehicles'], value): 231 | if instance_name != 'background': 232 | self.logger.experiment.add_scalar(f'val_{key}_{instance_name}', score.item(), 233 | global_step=self.training_step_count) 234 | self.metric_panoptic_val.reset() 235 | 236 | self.logger.experiment.add_scalar('segmentation_weight', 237 | 1 / (torch.exp(self.model.segmentation_weight)), 238 | global_step=self.training_step_count) 239 | self.logger.experiment.add_scalar('centerness_weight', 240 | 1 / (2 * torch.exp(self.model.centerness_weight)), 241 | global_step=self.training_step_count) 242 | self.logger.experiment.add_scalar('offset_weight', 1 / (2 * torch.exp(self.model.offset_weight)), 243 | global_step=self.training_step_count) 244 | if self.cfg.INSTANCE_FLOW.ENABLED: 245 | self.logger.experiment.add_scalar('flow_weight', 1 / (2 * torch.exp(self.model.flow_weight)), 246 | global_step=self.training_step_count) 247 | 248 | def training_epoch_end(self, step_outputs): 249 | self.shared_epoch_end(step_outputs, True) 250 | 251 | def validation_epoch_end(self, 
step_outputs): 252 | self.shared_epoch_end(step_outputs, False) 253 | 254 | def configure_optimizers(self): 255 | params = self.model.parameters() 256 | optimizer = torch.optim.Adam( 257 | params, lr=self.cfg.OPTIMIZER.LR, weight_decay=self.cfg.OPTIMIZER.WEIGHT_DECAY 258 | ) 259 | 260 | return optimizer 261 | -------------------------------------------------------------------------------- /fiery/utils/visualisation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pylab 4 | 5 | from fiery.utils.instance import predict_instance_segmentation_and_trajectories 6 | 7 | DEFAULT_COLORMAP = matplotlib.pylab.cm.jet 8 | 9 | 10 | def flow_to_image(flow: np.ndarray, autoscale: bool = False) -> np.ndarray: 11 | """ 12 | Applies colour map to flow which should be a 2 channel image tensor HxWx2. Returns a HxWx3 numpy image 13 | Code adapted from: https://github.com/liruoteng/FlowNet/blob/master/models/flownet/scripts/flowlib.py 14 | """ 15 | u = flow[0, :, :] 16 | v = flow[1, :, :] 17 | 18 | # Convert to polar coordinates 19 | rad = np.sqrt(u ** 2 + v ** 2) 20 | maxrad = np.max(rad) 21 | 22 | # Normalise flow maps 23 | if autoscale: 24 | u /= maxrad + np.finfo(float).eps 25 | v /= maxrad + np.finfo(float).eps 26 | 27 | # visualise flow with cmap 28 | return np.uint8(compute_color(u, v) * 255) 29 | 30 | 31 | def _normalise(image: np.ndarray) -> np.ndarray: 32 | lower = np.min(image) 33 | delta = np.max(image) - lower 34 | if delta == 0: 35 | delta = 1 36 | image = (image.astype(np.float32) - lower) / delta 37 | return image 38 | 39 | 40 | def apply_colour_map( 41 | image: np.ndarray, cmap: matplotlib.colors.LinearSegmentedColormap = DEFAULT_COLORMAP, autoscale: bool = False 42 | ) -> np.ndarray: 43 | """ 44 | Applies a colour map to the given 1 or 2 channel numpy image. if 2 channel, must be 2xHxW. 
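    A single channel is treated as a scalar heatmap (colour-mapped), 2 channels as a UV flow field rendered
    via flow_to_image, and 3 channels as an RGB image that is optionally normalised and transposed to HxWx3.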
45 | Returns a HxWx3 numpy image 46 | """ 47 | if image.ndim == 2 or (image.ndim == 3 and image.shape[0] == 1): 48 | if image.ndim == 3: 49 | image = image[0] 50 | # grayscale scalar image 51 | if autoscale: 52 | image = _normalise(image) 53 | return cmap(image)[:, :, :3] 54 | if image.shape[0] == 2: 55 | # 2 dimensional UV 56 | return flow_to_image(image, autoscale=autoscale) 57 | if image.shape[0] == 3: 58 | # normalise rgb channels 59 | if autoscale: 60 | image = _normalise(image) 61 | return np.transpose(image, axes=[1, 2, 0]) 62 | raise Exception('Image must be 1, 2 or 3 channel to convert to colour_map (CxHxW)') 63 | 64 | 65 | def heatmap_image( 66 | image: np.ndarray, cmap: matplotlib.colors.LinearSegmentedColormap = DEFAULT_COLORMAP, autoscale: bool = True 67 | ) -> np.ndarray: 68 | """Colorize an 1 or 2 channel image with a colourmap.""" 69 | if not issubclass(image.dtype.type, np.floating): 70 | raise ValueError(f"Expected a ndarray of float type, but got dtype {image.dtype}") 71 | if not (image.ndim == 2 or (image.ndim == 3 and image.shape[0] in [1, 2])): 72 | raise ValueError(f"Expected a ndarray of shape [H, W] or [1, H, W] or [2, H, W], but got shape {image.shape}") 73 | heatmap_np = apply_colour_map(image, cmap=cmap, autoscale=autoscale) 74 | heatmap_np = np.uint8(heatmap_np * 255) 75 | return heatmap_np 76 | 77 | 78 | def compute_color(u: np.ndarray, v: np.ndarray) -> np.ndarray: 79 | assert u.shape == v.shape 80 | [h, w] = u.shape 81 | img = np.zeros([h, w, 3]) 82 | nan_mask = np.isnan(u) | np.isnan(v) 83 | u[nan_mask] = 0 84 | v[nan_mask] = 0 85 | 86 | colorwheel = make_color_wheel() 87 | ncols = np.size(colorwheel, 0) 88 | 89 | rad = np.sqrt(u ** 2 + v ** 2) 90 | a = np.arctan2(-v, -u) / np.pi 91 | f_k = (a + 1) / 2 * (ncols - 1) + 1 92 | k_0 = np.floor(f_k).astype(int) 93 | k_1 = k_0 + 1 94 | k_1[k_1 == ncols + 1] = 1 95 | f = f_k - k_0 96 | 97 | for i in range(0, np.size(colorwheel, 1)): 98 | tmp = colorwheel[:, i] 99 | col0 = tmp[k_0 - 1] / 255 100 | col1 = tmp[k_1 - 1] / 255 101 | col = (1 - f) * col0 + f * col1 102 | 103 | idx = rad <= 1 104 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 105 | notidx = np.logical_not(idx) 106 | 107 | col[notidx] *= 0.75 108 | img[:, :, i] = col * (1 - nan_mask) 109 | 110 | return img 111 | 112 | 113 | def make_color_wheel() -> np.ndarray: 114 | """ 115 | Create colour wheel. 
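    The wheel holds 55 RGB entries interpolating red -> yellow -> green -> cyan -> blue -> magenta -> red;
    compute_color above indexes into it by flow angle and fades towards white as the flow magnitude decreases.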
116 | Code adapted from https://github.com/liruoteng/FlowNet/blob/master/models/flownet/scripts/flowlib.py 117 | """ 118 | red_yellow = 15 119 | yellow_green = 6 120 | green_cyan = 4 121 | cyan_blue = 11 122 | blue_magenta = 13 123 | magenta_red = 6 124 | 125 | ncols = red_yellow + yellow_green + green_cyan + cyan_blue + blue_magenta + magenta_red 126 | colorwheel = np.zeros([ncols, 3]) 127 | 128 | col = 0 129 | 130 | # red_yellow 131 | colorwheel[0:red_yellow, 0] = 255 132 | colorwheel[0:red_yellow, 1] = np.transpose(np.floor(255 * np.arange(0, red_yellow) / red_yellow)) 133 | col += red_yellow 134 | 135 | # yellow_green 136 | colorwheel[col : col + yellow_green, 0] = 255 - np.transpose( 137 | np.floor(255 * np.arange(0, yellow_green) / yellow_green) 138 | ) 139 | colorwheel[col : col + yellow_green, 1] = 255 140 | col += yellow_green 141 | 142 | # green_cyan 143 | colorwheel[col : col + green_cyan, 1] = 255 144 | colorwheel[col : col + green_cyan, 2] = np.transpose(np.floor(255 * np.arange(0, green_cyan) / green_cyan)) 145 | col += green_cyan 146 | 147 | # cyan_blue 148 | colorwheel[col : col + cyan_blue, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, cyan_blue) / cyan_blue)) 149 | colorwheel[col : col + cyan_blue, 2] = 255 150 | col += cyan_blue 151 | 152 | # blue_magenta 153 | colorwheel[col : col + blue_magenta, 2] = 255 154 | colorwheel[col : col + blue_magenta, 0] = np.transpose(np.floor(255 * np.arange(0, blue_magenta) / blue_magenta)) 155 | col += +blue_magenta 156 | 157 | # magenta_red 158 | colorwheel[col : col + magenta_red, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, magenta_red) / magenta_red)) 159 | colorwheel[col : col + magenta_red, 0] = 255 160 | 161 | return colorwheel 162 | 163 | 164 | def make_contour(img, colour=[0, 0, 0], double_line=False): 165 | h, w = img.shape[:2] 166 | out = img.copy() 167 | # Vertical lines 168 | out[np.arange(h), np.repeat(0, h)] = colour 169 | out[np.arange(h), np.repeat(w - 1, h)] = colour 170 | 171 | # Horizontal lines 172 | out[np.repeat(0, w), np.arange(w)] = colour 173 | out[np.repeat(h - 1, w), np.arange(w)] = colour 174 | 175 | if double_line: 176 | out[np.arange(h), np.repeat(1, h)] = colour 177 | out[np.arange(h), np.repeat(w - 2, h)] = colour 178 | 179 | # Horizontal lines 180 | out[np.repeat(1, w), np.arange(w)] = colour 181 | out[np.repeat(h - 2, w), np.arange(w)] = colour 182 | return out 183 | 184 | 185 | def plot_instance_map(instance_image, instance_map, instance_colours=None, bg_image=None): 186 | if isinstance(instance_image, torch.Tensor): 187 | instance_image = instance_image.cpu().numpy() 188 | assert isinstance(instance_image, np.ndarray) 189 | if instance_colours is None: 190 | instance_colours = generate_instance_colours(instance_map) 191 | if len(instance_image.shape) > 2: 192 | instance_image = instance_image.reshape((instance_image.shape[-2], instance_image.shape[-1])) 193 | 194 | if bg_image is None: 195 | plot_image = 255 * np.ones((instance_image.shape[0], instance_image.shape[1], 3), dtype=np.uint8) 196 | else: 197 | plot_image = bg_image 198 | 199 | for key, value in instance_colours.items(): 200 | plot_image[instance_image == key] = value 201 | 202 | return plot_image 203 | 204 | 205 | def visualise_output(labels, output, cfg): 206 | semantic_colours = np.array([[255, 255, 255], [0, 0, 0]], dtype=np.uint8) 207 | 208 | consistent_instance_seg = predict_instance_segmentation_and_trajectories( 209 | output, compute_matched_centers=False 210 | ) 211 | 212 | sequence_length = 
consistent_instance_seg.shape[1] 213 | b = 0 214 | video = [] 215 | for t in range(sequence_length): 216 | out_t = [] 217 | # Ground truth 218 | unique_ids = torch.unique(labels['instance'][b, t]).cpu().numpy()[1:] 219 | instance_map = dict(zip(unique_ids, unique_ids)) 220 | instance_plot = plot_instance_map(labels['instance'][b, t].cpu(), instance_map)[::-1, ::-1] 221 | instance_plot = make_contour(instance_plot) 222 | 223 | semantic_seg = labels['segmentation'].squeeze(2).cpu().numpy() 224 | semantic_plot = semantic_colours[semantic_seg[b, t][::-1, ::-1]] 225 | semantic_plot = make_contour(semantic_plot) 226 | 227 | if cfg.INSTANCE_FLOW.ENABLED: 228 | future_flow_plot = labels['flow'][b, t].cpu().numpy() 229 | future_flow_plot[:, semantic_seg[b, t] != 1] = 0 230 | future_flow_plot = flow_to_image(future_flow_plot)[::-1, ::-1] 231 | future_flow_plot = make_contour(future_flow_plot) 232 | else: 233 | future_flow_plot = np.zeros_like(semantic_plot) 234 | 235 | center_plot = heatmap_image(labels['centerness'][b, t, 0].cpu().numpy())[::-1, ::-1] 236 | center_plot = make_contour(center_plot) 237 | 238 | offset_plot = labels['offset'][b, t].cpu().numpy() 239 | offset_plot[:, semantic_seg[b, t] != 1] = 0 240 | offset_plot = flow_to_image(offset_plot)[::-1, ::-1] 241 | offset_plot = make_contour(offset_plot) 242 | 243 | out_t.append(np.concatenate([instance_plot, future_flow_plot, 244 | semantic_plot, center_plot, offset_plot], axis=0)) 245 | 246 | # Predictions 247 | unique_ids = torch.unique(consistent_instance_seg[b, t]).cpu().numpy()[1:] 248 | instance_map = dict(zip(unique_ids, unique_ids)) 249 | instance_plot = plot_instance_map(consistent_instance_seg[b, t].cpu(), instance_map)[::-1, ::-1] 250 | instance_plot = make_contour(instance_plot) 251 | 252 | semantic_seg = output['segmentation'].argmax(dim=2).detach().cpu().numpy() 253 | semantic_plot = semantic_colours[semantic_seg[b, t][::-1, ::-1]] 254 | semantic_plot = make_contour(semantic_plot) 255 | 256 | if cfg.INSTANCE_FLOW.ENABLED: 257 | future_flow_plot = output['instance_flow'][b, t].detach().cpu().numpy() 258 | future_flow_plot[:, semantic_seg[b, t] != 1] = 0 259 | future_flow_plot = flow_to_image(future_flow_plot)[::-1, ::-1] 260 | future_flow_plot = make_contour(future_flow_plot) 261 | else: 262 | future_flow_plot = np.zeros_like(semantic_plot) 263 | 264 | center_plot = heatmap_image(output['instance_center'][b, t, 0].detach().cpu().numpy())[::-1, ::-1] 265 | center_plot = make_contour(center_plot) 266 | 267 | offset_plot = output['instance_offset'][b, t].detach().cpu().numpy() 268 | offset_plot[:, semantic_seg[b, t] != 1] = 0 269 | offset_plot = flow_to_image(offset_plot)[::-1, ::-1] 270 | offset_plot = make_contour(offset_plot) 271 | 272 | out_t.append(np.concatenate([instance_plot, future_flow_plot, 273 | semantic_plot, center_plot, offset_plot], axis=0)) 274 | out_t = np.concatenate(out_t, axis=1) 275 | # Shape (C, H, W) 276 | out_t = out_t.transpose((2, 0, 1)) 277 | 278 | video.append(out_t) 279 | 280 | # Shape (B, T, C, H, W) 281 | video = np.stack(video)[None] 282 | return video 283 | 284 | 285 | def convert_figure_numpy(figure): 286 | """ Convert figure to numpy image """ 287 | figure_np = np.frombuffer(figure.canvas.tostring_rgb(), dtype=np.uint8) 288 | figure_np = figure_np.reshape(figure.canvas.get_width_height()[::-1] + (3,)) 289 | return figure_np 290 | 291 | 292 | def generate_instance_colours(instance_map): 293 | # Most distinct 22 colors (kelly colors from 
https://stackoverflow.com/questions/470690/how-to-automatically-generate 294 | # -n-distinct-colors) 295 | # plus some colours from AD40k 296 | INSTANCE_COLOURS = np.asarray([ 297 | [0, 0, 0], 298 | [255, 179, 0], 299 | [128, 62, 117], 300 | [255, 104, 0], 301 | [166, 189, 215], 302 | [193, 0, 32], 303 | [206, 162, 98], 304 | [129, 112, 102], 305 | [0, 125, 52], 306 | [246, 118, 142], 307 | [0, 83, 138], 308 | [255, 122, 92], 309 | [83, 55, 122], 310 | [255, 142, 0], 311 | [179, 40, 81], 312 | [244, 200, 0], 313 | [127, 24, 13], 314 | [147, 170, 0], 315 | [89, 51, 21], 316 | [241, 58, 19], 317 | [35, 44, 22], 318 | [112, 224, 255], 319 | [70, 184, 160], 320 | [153, 0, 255], 321 | [71, 255, 0], 322 | [255, 0, 163], 323 | [255, 204, 0], 324 | [0, 255, 235], 325 | [255, 0, 235], 326 | [255, 0, 122], 327 | [255, 245, 0], 328 | [10, 190, 212], 329 | [214, 255, 0], 330 | [0, 204, 255], 331 | [20, 0, 255], 332 | [255, 255, 0], 333 | [0, 153, 255], 334 | [0, 255, 204], 335 | [41, 255, 0], 336 | [173, 0, 255], 337 | [0, 245, 255], 338 | [71, 0, 255], 339 | [0, 255, 184], 340 | [0, 92, 255], 341 | [184, 255, 0], 342 | [255, 214, 0], 343 | [25, 194, 194], 344 | [92, 0, 255], 345 | [220, 220, 220], 346 | [255, 9, 92], 347 | [112, 9, 255], 348 | [8, 255, 214], 349 | [255, 184, 6], 350 | [10, 255, 71], 351 | [255, 41, 10], 352 | [7, 255, 255], 353 | [224, 255, 8], 354 | [102, 8, 255], 355 | [255, 61, 6], 356 | [255, 194, 7], 357 | [0, 255, 20], 358 | [255, 8, 41], 359 | [255, 5, 153], 360 | [6, 51, 255], 361 | [235, 12, 255], 362 | [160, 150, 20], 363 | [0, 163, 255], 364 | [140, 140, 140], 365 | [250, 10, 15], 366 | [20, 255, 0], 367 | ]) 368 | 369 | return {instance_id: INSTANCE_COLOURS[global_instance_id % len(INSTANCE_COLOURS)] for 370 | instance_id, global_instance_id in instance_map.items() 371 | } 372 | -------------------------------------------------------------------------------- /fiery/utils/instance.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from scipy.optimize import linear_sum_assignment 7 | 8 | from fiery.utils.geometry import mat2pose_vec, pose_vec2mat, warp_features 9 | 10 | 11 | # set ignore index to 0 for vis 12 | def convert_instance_mask_to_center_and_offset_label(instance_img, future_egomotion, num_instances, ignore_index=255, 13 | subtract_egomotion=True, sigma=3, spatial_extent=None): 14 | seq_len, h, w = instance_img.shape 15 | center_label = torch.zeros(seq_len, 1, h, w) 16 | offset_label = ignore_index * torch.ones(seq_len, 2, h, w) 17 | future_displacement_label = ignore_index * torch.ones(seq_len, 2, h, w) 18 | # x is vertical displacement, y is horizontal displacement 19 | x, y = torch.meshgrid(torch.arange(h, dtype=torch.float), torch.arange(w, dtype=torch.float)) 20 | 21 | if subtract_egomotion: 22 | future_egomotion_inv = mat2pose_vec(pose_vec2mat(future_egomotion).inverse()) 23 | 24 | # Compute warped instance segmentation 25 | warped_instance_seg = {} 26 | for t in range(1, seq_len): 27 | warped_inst_t = warp_features(instance_img[t].unsqueeze(0).unsqueeze(1).float(), 28 | future_egomotion_inv[t - 1].unsqueeze(0), mode='nearest', 29 | spatial_extent=spatial_extent) 30 | warped_instance_seg[t] = warped_inst_t[0, 0] 31 | 32 | # Ignore id 0 which is the background 33 | for instance_id in range(1, num_instances+1): 34 | prev_xc = None 35 | prev_yc = None 36 | prev_mask = None 37 | for t in range(seq_len): 38 | 
instance_mask = (instance_img[t] == instance_id) 39 | if instance_mask.sum() == 0: 40 | # this instance is not in this frame 41 | prev_xc = None 42 | prev_yc = None 43 | prev_mask = None 44 | continue 45 | 46 | xc = x[instance_mask].mean().round().long() 47 | yc = y[instance_mask].mean().round().long() 48 | 49 | off_x = xc - x 50 | off_y = yc - y 51 | g = torch.exp(-(off_x ** 2 + off_y ** 2) / sigma ** 2) 52 | center_label[t, 0] = torch.maximum(center_label[t, 0], g) 53 | offset_label[t, 0, instance_mask] = off_x[instance_mask] 54 | offset_label[t, 1, instance_mask] = off_y[instance_mask] 55 | 56 | if prev_xc is not None: 57 | # old method 58 | # cur_pt = torch.stack((xc, yc)).unsqueeze(0).float() 59 | # if subtract_egomotion: 60 | # cur_pt = warp_points(cur_pt, future_egomotion_inv[t - 1]) 61 | # cur_pt = cur_pt.squeeze(0) 62 | 63 | warped_instance_mask = warped_instance_seg[t] == instance_id 64 | if warped_instance_mask.sum() > 0: 65 | warped_xc = x[warped_instance_mask].mean().round() 66 | warped_yc = y[warped_instance_mask].mean().round() 67 | 68 | delta_x = warped_xc - prev_xc 69 | delta_y = warped_yc - prev_yc 70 | future_displacement_label[t - 1, 0, prev_mask] = delta_x 71 | future_displacement_label[t - 1, 1, prev_mask] = delta_y 72 | 73 | prev_xc = xc 74 | prev_yc = yc 75 | prev_mask = instance_mask 76 | 77 | return center_label, offset_label, future_displacement_label 78 | 79 | 80 | def find_instance_centers(center_prediction: torch.Tensor, conf_threshold: float = 0.1, nms_kernel_size: float = 3): 81 | assert len(center_prediction.shape) == 3 82 | center_prediction = F.threshold(center_prediction, threshold=conf_threshold, value=-1) 83 | 84 | nms_padding = (nms_kernel_size - 1) // 2 85 | maxpooled_center_prediction = F.max_pool2d( 86 | center_prediction, kernel_size=nms_kernel_size, stride=1, padding=nms_padding 87 | ) 88 | 89 | # Filter all elements that are not the maximum (i.e. 
the center of the heatmap instance) 90 | center_prediction[center_prediction != maxpooled_center_prediction] = -1 91 | return torch.nonzero(center_prediction > 0)[:, 1:] 92 | 93 | 94 | def group_pixels(centers: torch.Tensor, offset_predictions: torch.Tensor) -> torch.Tensor: 95 | width, height = offset_predictions.shape[-2:] 96 | x_grid = ( 97 | torch.arange(width, dtype=offset_predictions.dtype, device=offset_predictions.device) 98 | .view(1, width, 1) 99 | .repeat(1, 1, height) 100 | ) 101 | y_grid = ( 102 | torch.arange(height, dtype=offset_predictions.dtype, device=offset_predictions.device) 103 | .view(1, 1, height) 104 | .repeat(1, width, 1) 105 | ) 106 | pixel_grid = torch.cat((x_grid, y_grid), dim=0) 107 | center_locations = (pixel_grid + offset_predictions).view(2, width * height, 1).permute(2, 1, 0) 108 | centers = centers.view(-1, 1, 2) 109 | 110 | distances = torch.norm(centers - center_locations, dim=-1) 111 | 112 | instance_id = torch.argmin(distances, dim=0).reshape(1, width, height) + 1 113 | return instance_id 114 | 115 | 116 | def get_instance_segmentation_and_centers( 117 | center_predictions: torch.Tensor, 118 | offset_predictions: torch.Tensor, 119 | foreground_mask: torch.Tensor, 120 | conf_threshold: float = 0.1, 121 | nms_kernel_size: float = 3, 122 | max_n_instance_centers: int = 100, 123 | ) -> Tuple[torch.Tensor, torch.Tensor]: 124 | width, height = center_predictions.shape[-2:] 125 | center_predictions = center_predictions.view(1, width, height) 126 | offset_predictions = offset_predictions.view(2, width, height) 127 | foreground_mask = foreground_mask.view(1, width, height) 128 | 129 | centers = find_instance_centers(center_predictions, conf_threshold=conf_threshold, nms_kernel_size=nms_kernel_size) 130 | if not len(centers): 131 | return torch.zeros(center_predictions.shape, dtype=torch.int64, device=center_predictions.device), \ 132 | torch.zeros((0, 2), device=centers.device) 133 | 134 | if len(centers) > max_n_instance_centers: 135 | print(f'There are a lot of detected instance centers: {centers.shape}') 136 | centers = centers[:max_n_instance_centers].clone() 137 | 138 | instance_ids = group_pixels(centers, offset_predictions) 139 | instance_seg = (instance_ids * foreground_mask.float()).long() 140 | 141 | # Make the indices of instance_seg consecutive 142 | instance_seg = make_instance_seg_consecutive(instance_seg) 143 | 144 | return instance_seg.long(), centers 145 | 146 | 147 | def update_instance_ids(instance_seg, old_ids, new_ids): 148 | """ 149 | Parameters 150 | ---------- 151 | instance_seg: torch.Tensor arbitrary shape 152 | old_ids: 1D tensor containing the list of old ids, must be all present in instance_seg. 
153 | new_ids: 1D tensor with the new ids, aligned with old_ids 154 | 155 | Returns 156 | new_instance_seg: torch.Tensor same shape as instance_seg with new ids 157 | """ 158 | indices = torch.arange(old_ids.max() + 1, device=instance_seg.device) 159 | for old_id, new_id in zip(old_ids, new_ids): 160 | indices[old_id] = new_id 161 | 162 | return indices[instance_seg].long() 163 | 164 | 165 | def make_instance_seg_consecutive(instance_seg): 166 | # Make the indices of instance_seg consecutive 167 | unique_ids = torch.unique(instance_seg) 168 | new_ids = torch.arange(len(unique_ids), device=instance_seg.device) 169 | instance_seg = update_instance_ids(instance_seg, unique_ids, new_ids) 170 | return instance_seg 171 | 172 | 173 | def make_instance_id_temporally_consistent(pred_inst, future_flow, matching_threshold=3.0): 174 | """ 175 | Parameters 176 | ---------- 177 | pred_inst: torch.Tensor (1, seq_len, h, w) 178 | future_flow: torch.Tensor(1, seq_len, 2, h, w) 179 | matching_threshold: distance threshold for a match to be valid. 180 | 181 | Returns 182 | ------- 183 | consistent_instance_seg: torch.Tensor(1, seq_len, h, w) 184 | 185 | 1. time t. Loop over all detected instances. Use flow to compute new centers at time t+1. 186 | 2. Store those centers 187 | 3. time t+1. Re-identify instances by comparing position of actual centers, and flow-warped centers. 188 | Make the labels at t+1 consistent with the matching 189 | 4. Repeat 190 | """ 191 | assert pred_inst.shape[0] == 1, 'Assumes batch size = 1' 192 | 193 | # Initialise instance segmentations with prediction corresponding to the present 194 | consistent_instance_seg = [pred_inst[0, 0]] 195 | largest_instance_id = consistent_instance_seg[0].max().item() 196 | 197 | _, seq_len, h, w = pred_inst.shape 198 | device = pred_inst.device 199 | for t in range(seq_len - 1): 200 | # Compute predicted future instance means 201 | grid = torch.stack(torch.meshgrid( 202 | torch.arange(h, dtype=torch.float, device=device), torch.arange(w, dtype=torch.float, device=device) 203 | )) 204 | 205 | # Add future flow 206 | grid = grid + future_flow[0, t] 207 | warped_centers = [] 208 | # Go through all ids, except the background 209 | t_instance_ids = torch.unique(consistent_instance_seg[-1])[1:].cpu().numpy() 210 | 211 | if len(t_instance_ids) == 0: 212 | # No instance so nothing to update 213 | consistent_instance_seg.append(pred_inst[0, t + 1]) 214 | continue 215 | 216 | for instance_id in t_instance_ids: 217 | instance_mask = (consistent_instance_seg[-1] == instance_id) 218 | warped_centers.append(grid[:, instance_mask].mean(dim=1)) 219 | warped_centers = torch.stack(warped_centers) 220 | 221 | # Compute actual future instance means 222 | centers = [] 223 | grid = torch.stack(torch.meshgrid( 224 | torch.arange(h, dtype=torch.float, device=device), torch.arange(w, dtype=torch.float, device=device) 225 | )) 226 | n_instances = int(pred_inst[0, t + 1].max().item()) 227 | 228 | if n_instances == 0: 229 | # No instance, so nothing to update. 
230 | consistent_instance_seg.append(pred_inst[0, t + 1]) 231 | continue 232 | 233 | for instance_id in range(1, n_instances + 1): 234 | instance_mask = (pred_inst[0, t + 1] == instance_id) 235 | centers.append(grid[:, instance_mask].mean(dim=1)) 236 | centers = torch.stack(centers) 237 | 238 | # Compute distance matrix between warped centers and actual centers 239 | distances = torch.norm(centers.unsqueeze(0) - warped_centers.unsqueeze(1), dim=-1).cpu().numpy() 240 | # outputs (row, col) with row: index in frame t, col: index in frame t+1 241 | # the missing ids in col must be added (correspond to new instances) 242 | ids_t, ids_t_one = linear_sum_assignment(distances) 243 | matching_distances = distances[ids_t, ids_t_one] 244 | # Offset by one as id=0 is the background 245 | ids_t += 1 246 | ids_t_one += 1 247 | 248 | # swap ids_t with real ids. as those ids correspond to the position in the distance matrix. 249 | id_mapping = dict(zip(np.arange(1, len(t_instance_ids) + 1), t_instance_ids)) 250 | ids_t = np.vectorize(id_mapping.__getitem__, otypes=[np.int64])(ids_t) 251 | 252 | # Filter low quality match 253 | ids_t = ids_t[matching_distances < matching_threshold] 254 | ids_t_one = ids_t_one[matching_distances < matching_threshold] 255 | 256 | # Elements that are in t+1, but weren't matched 257 | remaining_ids = set(torch.unique(pred_inst[0, t + 1]).cpu().numpy()).difference(set(ids_t_one)) 258 | # remove background 259 | remaining_ids.remove(0) 260 | #  Set remaining_ids to a new unique id 261 | for remaining_id in list(remaining_ids): 262 | largest_instance_id += 1 263 | ids_t = np.append(ids_t, largest_instance_id) 264 | ids_t_one = np.append(ids_t_one, remaining_id) 265 | 266 | consistent_instance_seg.append(update_instance_ids(pred_inst[0, t + 1], old_ids=ids_t_one, new_ids=ids_t)) 267 | 268 | consistent_instance_seg = torch.stack(consistent_instance_seg).unsqueeze(0) 269 | return consistent_instance_seg 270 | 271 | 272 | def predict_instance_segmentation_and_trajectories( 273 | output, compute_matched_centers=False, make_consistent=True, vehicles_id=1, 274 | ): 275 | preds = output['segmentation'].detach() 276 | preds = torch.argmax(preds, dim=2, keepdims=True) 277 | foreground_masks = preds.squeeze(2) == vehicles_id 278 | 279 | batch_size, seq_len = preds.shape[:2] 280 | pred_inst = [] 281 | for b in range(batch_size): 282 | pred_inst_batch = [] 283 | for t in range(seq_len): 284 | pred_instance_t, _ = get_instance_segmentation_and_centers( 285 | output['instance_center'][b, t].detach(), 286 | output['instance_offset'][b, t].detach(), 287 | foreground_masks[b, t].detach() 288 | ) 289 | pred_inst_batch.append(pred_instance_t) 290 | pred_inst.append(torch.stack(pred_inst_batch, dim=0)) 291 | 292 | pred_inst = torch.stack(pred_inst).squeeze(2) 293 | 294 | if make_consistent: 295 | if output['instance_flow'] is None: 296 | print('Using zero flow because instance_future_output is None') 297 | output['instance_flow'] = torch.zeros_like(output['instance_offset']) 298 | consistent_instance_seg = [] 299 | for b in range(batch_size): 300 | consistent_instance_seg.append( 301 | make_instance_id_temporally_consistent(pred_inst[b:b+1], 302 | output['instance_flow'][b:b+1].detach()) 303 | ) 304 | consistent_instance_seg = torch.cat(consistent_instance_seg, dim=0) 305 | else: 306 | consistent_instance_seg = pred_inst 307 | 308 | if compute_matched_centers: 309 | assert batch_size == 1 310 | # Generate trajectories 311 | matched_centers = {} 312 | _, seq_len, h, w = consistent_instance_seg.shape 
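        # The pixel-coordinate grid below is shared across timesteps: for every instance
        # id present at t=0, the mean grid position of its mask at time t is that
        # instance's centre, and the per-timestep centres form its trajectory.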
313 | grid = torch.stack(torch.meshgrid( 314 | torch.arange(h, dtype=torch.float, device=preds.device), 315 | torch.arange(w, dtype=torch.float, device=preds.device) 316 | )) 317 | 318 | for instance_id in torch.unique(consistent_instance_seg[0, 0])[1:].cpu().numpy(): 319 | for t in range(seq_len): 320 | instance_mask = consistent_instance_seg[0, t] == instance_id 321 | if instance_mask.sum() > 0: 322 | matched_centers[instance_id] = matched_centers.get(instance_id, []) + [ 323 | grid[:, instance_mask].mean(dim=-1)] 324 | 325 | for key, value in matched_centers.items(): 326 | matched_centers[key] = torch.stack(value).cpu().numpy()[:, ::-1] 327 | 328 | return consistent_instance_seg, matched_centers 329 | 330 | return consistent_instance_seg 331 | 332 | 333 | -------------------------------------------------------------------------------- /fiery/models/fiery.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from fiery.models.encoder import Encoder 5 | from fiery.models.temporal_model import TemporalModelIdentity, TemporalModel 6 | from fiery.models.distributions import DistributionModule 7 | from fiery.models.future_prediction import FuturePrediction 8 | from fiery.models.decoder import Decoder 9 | from fiery.utils.network import pack_sequence_dim, unpack_sequence_dim, set_bn_momentum 10 | from fiery.utils.geometry import cumulative_warp_features, calculate_birds_eye_view_parameters, VoxelsSumming 11 | 12 | 13 | class Fiery(nn.Module): 14 | def __init__(self, cfg): 15 | super().__init__() 16 | self.cfg = cfg 17 | 18 | bev_resolution, bev_start_position, bev_dimension = calculate_birds_eye_view_parameters( 19 | self.cfg.LIFT.X_BOUND, self.cfg.LIFT.Y_BOUND, self.cfg.LIFT.Z_BOUND 20 | ) 21 | self.bev_resolution = nn.Parameter(bev_resolution, requires_grad=False) 22 | self.bev_start_position = nn.Parameter(bev_start_position, requires_grad=False) 23 | self.bev_dimension = nn.Parameter(bev_dimension, requires_grad=False) 24 | 25 | self.encoder_downsample = self.cfg.MODEL.ENCODER.DOWNSAMPLE 26 | self.encoder_out_channels = self.cfg.MODEL.ENCODER.OUT_CHANNELS 27 | 28 | self.frustum = self.create_frustum() 29 | self.depth_channels, _, _, _ = self.frustum.shape 30 | 31 | if self.cfg.TIME_RECEPTIVE_FIELD == 1: 32 | assert self.cfg.MODEL.TEMPORAL_MODEL.NAME == 'identity' 33 | 34 | # temporal block 35 | self.receptive_field = self.cfg.TIME_RECEPTIVE_FIELD 36 | self.n_future = self.cfg.N_FUTURE_FRAMES 37 | self.latent_dim = self.cfg.MODEL.DISTRIBUTION.LATENT_DIM 38 | 39 | if self.cfg.MODEL.SUBSAMPLE: 40 | assert self.cfg.DATASET.NAME == 'lyft' 41 | self.receptive_field = 3 42 | self.n_future = 5 43 | 44 | # Spatial extent in bird's-eye view, in meters 45 | self.spatial_extent = (self.cfg.LIFT.X_BOUND[1], self.cfg.LIFT.Y_BOUND[1]) 46 | self.bev_size = (self.bev_dimension[0].item(), self.bev_dimension[1].item()) 47 | 48 | # Encoder 49 | self.encoder = Encoder(cfg=self.cfg.MODEL.ENCODER, D=self.depth_channels) 50 | 51 | # Temporal model 52 | temporal_in_channels = self.encoder_out_channels 53 | if self.cfg.MODEL.TEMPORAL_MODEL.INPUT_EGOPOSE: 54 | temporal_in_channels += 6 55 | if self.cfg.MODEL.TEMPORAL_MODEL.NAME == 'identity': 56 | self.temporal_model = TemporalModelIdentity(temporal_in_channels, self.receptive_field) 57 | elif cfg.MODEL.TEMPORAL_MODEL.NAME == 'temporal_block': 58 | self.temporal_model = TemporalModel( 59 | temporal_in_channels, 60 | self.receptive_field, 61 | input_shape=self.bev_size, 62 | 
start_out_channels=self.cfg.MODEL.TEMPORAL_MODEL.START_OUT_CHANNELS, 63 | extra_in_channels=self.cfg.MODEL.TEMPORAL_MODEL.EXTRA_IN_CHANNELS, 64 | n_spatial_layers_between_temporal_layers=self.cfg.MODEL.TEMPORAL_MODEL.INBETWEEN_LAYERS, 65 | use_pyramid_pooling=self.cfg.MODEL.TEMPORAL_MODEL.PYRAMID_POOLING, 66 | ) 67 | else: 68 | raise NotImplementedError(f'Temporal module {self.cfg.MODEL.TEMPORAL_MODEL.NAME}.') 69 | 70 | self.future_pred_in_channels = self.temporal_model.out_channels 71 | if self.n_future > 0: 72 | # probabilistic sampling 73 | if self.cfg.PROBABILISTIC.ENABLED: 74 | # Distribution networks 75 | self.present_distribution = DistributionModule( 76 | self.future_pred_in_channels, 77 | self.latent_dim, 78 | min_log_sigma=self.cfg.MODEL.DISTRIBUTION.MIN_LOG_SIGMA, 79 | max_log_sigma=self.cfg.MODEL.DISTRIBUTION.MAX_LOG_SIGMA, 80 | ) 81 | 82 | future_distribution_in_channels = (self.future_pred_in_channels 83 | + self.n_future * self.cfg.PROBABILISTIC.FUTURE_DIM 84 | ) 85 | self.future_distribution = DistributionModule( 86 | future_distribution_in_channels, 87 | self.latent_dim, 88 | min_log_sigma=self.cfg.MODEL.DISTRIBUTION.MIN_LOG_SIGMA, 89 | max_log_sigma=self.cfg.MODEL.DISTRIBUTION.MAX_LOG_SIGMA, 90 | ) 91 | 92 | # Future prediction 93 | self.future_prediction = FuturePrediction( 94 | in_channels=self.future_pred_in_channels, 95 | latent_dim=self.latent_dim, 96 | n_gru_blocks=self.cfg.MODEL.FUTURE_PRED.N_GRU_BLOCKS, 97 | n_res_layers=self.cfg.MODEL.FUTURE_PRED.N_RES_LAYERS, 98 | ) 99 | 100 | # Decoder 101 | self.decoder = Decoder( 102 | in_channels=self.future_pred_in_channels, 103 | n_classes=len(self.cfg.SEMANTIC_SEG.WEIGHTS), 104 | predict_future_flow=self.cfg.INSTANCE_FLOW.ENABLED, 105 | ) 106 | 107 | set_bn_momentum(self, self.cfg.MODEL.BN_MOMENTUM) 108 | 109 | def create_frustum(self): 110 | # Create grid in image plane 111 | h, w = self.cfg.IMAGE.FINAL_DIM 112 | downsampled_h, downsampled_w = h // self.encoder_downsample, w // self.encoder_downsample 113 | 114 | # Depth grid 115 | depth_grid = torch.arange(*self.cfg.LIFT.D_BOUND, dtype=torch.float) 116 | depth_grid = depth_grid.view(-1, 1, 1).expand(-1, downsampled_h, downsampled_w) 117 | n_depth_slices = depth_grid.shape[0] 118 | 119 | # x and y grids 120 | x_grid = torch.linspace(0, w - 1, downsampled_w, dtype=torch.float) 121 | x_grid = x_grid.view(1, 1, downsampled_w).expand(n_depth_slices, downsampled_h, downsampled_w) 122 | y_grid = torch.linspace(0, h - 1, downsampled_h, dtype=torch.float) 123 | y_grid = y_grid.view(1, downsampled_h, 1).expand(n_depth_slices, downsampled_h, downsampled_w) 124 | 125 | # Dimension (n_depth_slices, downsampled_h, downsampled_w, 3) 126 | # containing data points in the image: left-right, top-bottom, depth 127 | frustum = torch.stack((x_grid, y_grid, depth_grid), -1) 128 | return nn.Parameter(frustum, requires_grad=False) 129 | 130 | def forward(self, image, intrinsics, extrinsics, future_egomotion, future_distribution_inputs=None, noise=None): 131 | output = {} 132 | 133 | # Only process features from the past and present 134 | image = image[:, :self.receptive_field].contiguous() 135 | intrinsics = intrinsics[:, :self.receptive_field].contiguous() 136 | extrinsics = extrinsics[:, :self.receptive_field].contiguous() 137 | future_egomotion = future_egomotion[:, :self.receptive_field].contiguous() 138 | 139 | # Lifting features and project to bird's-eye view 140 | x = self.calculate_birds_eye_view_features(image, intrinsics, extrinsics) 141 | 142 | # Warp past features to the present's 
reference frame 143 | x = cumulative_warp_features( 144 | x.clone(), future_egomotion, 145 | mode='bilinear', spatial_extent=self.spatial_extent, 146 | ) 147 | 148 | if self.cfg.MODEL.TEMPORAL_MODEL.INPUT_EGOPOSE: 149 | b, s, c = future_egomotion.shape 150 | h, w = x.shape[-2:] 151 | future_egomotions_spatial = future_egomotion.view(b, s, c, 1, 1).expand(b, s, c, h, w) 152 | # at time 0, no egomotion so feed zero vector 153 | future_egomotions_spatial = torch.cat([torch.zeros_like(future_egomotions_spatial[:, :1]), 154 | future_egomotions_spatial[:, :(self.receptive_field-1)]], dim=1) 155 | x = torch.cat([x, future_egomotions_spatial], dim=-3) 156 | 157 | #  Temporal model 158 | states = self.temporal_model(x) 159 | 160 | if self.n_future > 0: 161 | present_state = states[:, :1].contiguous() 162 | if self.cfg.PROBABILISTIC.ENABLED: 163 | # Do probabilistic computation 164 | sample, output_distribution = self.distribution_forward( 165 | present_state, future_distribution_inputs, noise 166 | ) 167 | output = {**output, **output_distribution} 168 | 169 | # Prepare future prediction input 170 | b, _, _, h, w = present_state.shape 171 | hidden_state = present_state[:, 0] 172 | 173 | if self.cfg.PROBABILISTIC.ENABLED: 174 | future_prediction_input = sample.expand(-1, self.n_future, -1, -1, -1) 175 | else: 176 | future_prediction_input = hidden_state.new_zeros(b, self.n_future, self.latent_dim, h, w) 177 | 178 | # Recursively predict future states 179 | future_states = self.future_prediction(future_prediction_input, hidden_state) 180 | 181 | # Concatenate present state 182 | future_states = torch.cat([present_state, future_states], dim=1) 183 | 184 | # Predict bird's-eye view outputs 185 | if self.n_future > 0: 186 | bev_output = self.decoder(future_states) 187 | else: 188 | bev_output = self.decoder(states[:, -1:]) 189 | output = {**output, **bev_output} 190 | 191 | return output 192 | 193 | def get_geometry(self, intrinsics, extrinsics): 194 | """Calculate the (x, y, z) 3D position of the features. 
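        The image-plane frustum points (x, y, depth) are unprojected with the inverse
        camera intrinsics, then rotated and translated into the ego (lidar) reference
        frame using the extrinsics. Output shape: (batch, n_cameras, depth, height,
        width, 3).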
195 | """ 196 | rotation, translation = extrinsics[..., :3, :3], extrinsics[..., :3, 3] 197 | B, N, _ = translation.shape 198 | # Add batch, camera dimension, and a dummy dimension at the end 199 | points = self.frustum.unsqueeze(0).unsqueeze(0).unsqueeze(-1) 200 | 201 | # Camera to ego reference frame 202 | points = torch.cat((points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3], points[:, :, :, :, :, 2:3]), 5) 203 | combined_transformation = rotation.matmul(torch.inverse(intrinsics)) 204 | points = combined_transformation.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1) 205 | points += translation.view(B, N, 1, 1, 1, 3) 206 | 207 | # The 3 dimensions in the ego reference frame are: (forward, sides, height) 208 | return points 209 | 210 | def encoder_forward(self, x): 211 | # batch, n_cameras, channels, height, width 212 | b, n, c, h, w = x.shape 213 | 214 | x = x.view(b * n, c, h, w) 215 | x = self.encoder(x) 216 | x = x.view(b, n, *x.shape[1:]) 217 | x = x.permute(0, 1, 3, 4, 5, 2) 218 | 219 | return x 220 | 221 | def projection_to_birds_eye_view(self, x, geometry): 222 | """ Adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/models.py#L200""" 223 | # batch, n_cameras, depth, height, width, channels 224 | batch, n, d, h, w, c = x.shape 225 | output = torch.zeros( 226 | (batch, c, self.bev_dimension[0], self.bev_dimension[1]), dtype=torch.float, device=x.device 227 | ) 228 | 229 | # Number of 3D points 230 | N = n * d * h * w 231 | for b in range(batch): 232 | # flatten x 233 | x_b = x[b].reshape(N, c) 234 | 235 | # Convert positions to integer indices 236 | geometry_b = ((geometry[b] - (self.bev_start_position - self.bev_resolution / 2.0)) / self.bev_resolution) 237 | geometry_b = geometry_b.view(N, 3).long() 238 | 239 | # Mask out points that are outside the considered spatial extent. 240 | mask = ( 241 | (geometry_b[:, 0] >= 0) 242 | & (geometry_b[:, 0] < self.bev_dimension[0]) 243 | & (geometry_b[:, 1] >= 0) 244 | & (geometry_b[:, 1] < self.bev_dimension[1]) 245 | & (geometry_b[:, 2] >= 0) 246 | & (geometry_b[:, 2] < self.bev_dimension[2]) 247 | ) 248 | x_b = x_b[mask] 249 | geometry_b = geometry_b[mask] 250 | 251 | # Sort tensors so that those within the same voxel are consecutives. 252 | ranks = ( 253 | geometry_b[:, 0] * (self.bev_dimension[1] * self.bev_dimension[2]) 254 | + geometry_b[:, 1] * (self.bev_dimension[2]) 255 | + geometry_b[:, 2] 256 | ) 257 | ranks_indices = ranks.argsort() 258 | x_b, geometry_b, ranks = x_b[ranks_indices], geometry_b[ranks_indices], ranks[ranks_indices] 259 | 260 | # Project to bird's-eye view by summing voxels. 
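            # VoxelsSumming.apply (a custom autograd op defined in fiery/utils/geometry.py)
            # sums the features of consecutive points that share the same rank, i.e. points
            # falling into the same voxel, which is the sum-pooling step of the lift-splat
            # projection.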
261 | x_b, geometry_b = VoxelsSumming.apply(x_b, geometry_b, ranks) 262 | 263 | bev_feature = torch.zeros((self.bev_dimension[2], self.bev_dimension[0], self.bev_dimension[1], c), 264 | device=x_b.device) 265 | bev_feature[geometry_b[:, 2], geometry_b[:, 0], geometry_b[:, 1]] = x_b 266 | 267 | # Put channel in second position and remove z dimension 268 | bev_feature = bev_feature.permute((0, 3, 1, 2)) 269 | bev_feature = bev_feature.squeeze(0) 270 | 271 | output[b] = bev_feature 272 | 273 | return output 274 | 275 | def calculate_birds_eye_view_features(self, x, intrinsics, extrinsics): 276 | b, s, n, c, h, w = x.shape 277 | # Reshape 278 | x = pack_sequence_dim(x) 279 | intrinsics = pack_sequence_dim(intrinsics) 280 | extrinsics = pack_sequence_dim(extrinsics) 281 | 282 | geometry = self.get_geometry(intrinsics, extrinsics) 283 | x = self.encoder_forward(x) 284 | x = self.projection_to_birds_eye_view(x, geometry) 285 | x = unpack_sequence_dim(x, b, s) 286 | return x 287 | 288 | def distribution_forward(self, present_features, future_distribution_inputs=None, noise=None): 289 | """ 290 | Parameters 291 | ---------- 292 | present_features: 5-D output from dynamics module with shape (b, 1, c, h, w) 293 | future_distribution_inputs: 5-D tensor containing labels shape (b, s, cfg.PROB_FUTURE_DIM, h, w) 294 | noise: a sample from a (0, 1) gaussian with shape (b, s, latent_dim). If None, will sample in function 295 | 296 | Returns 297 | ------- 298 | sample: sample taken from present/future distribution, broadcast to shape (b, s, latent_dim, h, w) 299 | present_distribution_mu: shape (b, s, latent_dim) 300 | present_distribution_log_sigma: shape (b, s, latent_dim) 301 | future_distribution_mu: shape (b, s, latent_dim) 302 | future_distribution_log_sigma: shape (b, s, latent_dim) 303 | """ 304 | b, s, _, h, w = present_features.size() 305 | assert s == 1 306 | 307 | present_mu, present_log_sigma = self.present_distribution(present_features) 308 | 309 | future_mu, future_log_sigma = None, None 310 | if future_distribution_inputs is not None: 311 | # Concatenate future labels to z_t 312 | future_features = future_distribution_inputs[:, 1:].contiguous().view(b, 1, -1, h, w) 313 | future_features = torch.cat([present_features, future_features], dim=2) 314 | future_mu, future_log_sigma = self.future_distribution(future_features) 315 | 316 | if noise is None: 317 | if self.training: 318 | noise = torch.randn_like(present_mu) 319 | else: 320 | noise = torch.zeros_like(present_mu) 321 | if self.training: 322 | mu = future_mu 323 | sigma = torch.exp(future_log_sigma) 324 | else: 325 | mu = present_mu 326 | sigma = torch.exp(present_log_sigma) 327 | sample = mu + sigma * noise 328 | 329 | # Spatially broadcast sample to the dimensions of present_features 330 | sample = sample.view(b, s, self.latent_dim, 1, 1).expand(b, s, self.latent_dim, h, w) 331 | 332 | output_distribution = { 333 | 'present_mu': present_mu, 334 | 'present_log_sigma': present_log_sigma, 335 | 'future_mu': future_mu, 336 | 'future_log_sigma': future_log_sigma, 337 | } 338 | 339 | return sample, output_distribution 340 | -------------------------------------------------------------------------------- /fiery/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | 4 | import numpy as np 5 | import cv2 6 | import torch 7 | import torchvision 8 | 9 | from pyquaternion import Quaternion 10 | from nuscenes.nuscenes import NuScenes 11 | from nuscenes.utils.splits import 
create_splits_scenes 12 | from nuscenes.utils.data_classes import Box 13 | from lyft_dataset_sdk.lyftdataset import LyftDataset 14 | 15 | from fiery.utils.geometry import ( 16 | resize_and_crop_image, 17 | update_intrinsics, 18 | calculate_birds_eye_view_parameters, 19 | convert_egopose_to_matrix_numpy, 20 | pose_vec2mat, 21 | mat2pose_vec, 22 | invert_matrix_egopose_numpy, 23 | ) 24 | from fiery.utils.instance import convert_instance_mask_to_center_and_offset_label 25 | from fiery.utils.lyft_splits import TRAIN_LYFT_INDICES, VAL_LYFT_INDICES 26 | 27 | 28 | class FuturePredictionDataset(torch.utils.data.Dataset): 29 | def __init__(self, nusc, is_train, cfg): 30 | self.nusc = nusc 31 | self.is_train = is_train 32 | self.cfg = cfg 33 | 34 | self.is_lyft = isinstance(nusc, LyftDataset) 35 | 36 | if self.is_lyft: 37 | self.dataroot = self.nusc.data_path 38 | else: 39 | self.dataroot = self.nusc.dataroot 40 | 41 | self.mode = 'train' if self.is_train else 'val' 42 | 43 | self.sequence_length = cfg.TIME_RECEPTIVE_FIELD + cfg.N_FUTURE_FRAMES 44 | 45 | self.scenes = self.get_scenes() 46 | self.ixes = self.prepro() 47 | self.indices = self.get_indices() 48 | 49 | # Image resizing and cropping 50 | self.augmentation_parameters = self.get_resizing_and_cropping_parameters() 51 | 52 | # Normalising input images 53 | self.normalise_image = torchvision.transforms.Compose( 54 | [torchvision.transforms.ToTensor(), 55 | torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 56 | ] 57 | ) 58 | 59 | # Bird's-eye view parameters 60 | bev_resolution, bev_start_position, bev_dimension = calculate_birds_eye_view_parameters( 61 | cfg.LIFT.X_BOUND, cfg.LIFT.Y_BOUND, cfg.LIFT.Z_BOUND 62 | ) 63 | self.bev_resolution, self.bev_start_position, self.bev_dimension = ( 64 | bev_resolution.numpy(), bev_start_position.numpy(), bev_dimension.numpy() 65 | ) 66 | 67 | # Spatial extent in bird's-eye view, in meters 68 | self.spatial_extent = (self.cfg.LIFT.X_BOUND[1], self.cfg.LIFT.Y_BOUND[1]) 69 | 70 | def get_scenes(self): 71 | 72 | if self.is_lyft: 73 | scenes = [row['name'] for row in self.nusc.scene] 74 | 75 | # Split in train/val 76 | indices = TRAIN_LYFT_INDICES if self.is_train else VAL_LYFT_INDICES 77 | scenes = [scenes[i] for i in indices] 78 | else: 79 | # filter by scene split 80 | split = {'v1.0-trainval': {True: 'train', False: 'val'}, 81 | 'v1.0-mini': {True: 'mini_train', False: 'mini_val'},}[ 82 | self.nusc.version 83 | ][self.is_train] 84 | 85 | scenes = create_splits_scenes()[split] 86 | 87 | return scenes 88 | 89 | def prepro(self): 90 | samples = [samp for samp in self.nusc.sample] 91 | 92 | # remove samples that aren't in this split 93 | samples = [samp for samp in samples if self.nusc.get('scene', samp['scene_token'])['name'] in self.scenes] 94 | 95 | # sort by scene, timestamp (only to make chronological viz easier) 96 | samples.sort(key=lambda x: (x['scene_token'], x['timestamp'])) 97 | 98 | return samples 99 | 100 | def get_indices(self): 101 | indices = [] 102 | for index in range(len(self.ixes)): 103 | is_valid_data = True 104 | previous_rec = None 105 | current_indices = [] 106 | for t in range(self.sequence_length): 107 | index_t = index + t 108 | # Going over the dataset size limit. 
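                # (if so, the candidate sequence is dropped, as is any sequence that crosses
                # a scene boundary below, so every row of the returned indices array holds
                # sequence_length consecutive samples from a single scene)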
109 | if index_t >= len(self.ixes): 110 | is_valid_data = False 111 | break 112 | rec = self.ixes[index_t] 113 | # Check if scene is the same 114 | if (previous_rec is not None) and (rec['scene_token'] != previous_rec['scene_token']): 115 | is_valid_data = False 116 | break 117 | 118 | current_indices.append(index_t) 119 | previous_rec = rec 120 | 121 | if is_valid_data: 122 | indices.append(current_indices) 123 | 124 | return np.asarray(indices) 125 | 126 | def get_resizing_and_cropping_parameters(self): 127 | original_height, original_width = self.cfg.IMAGE.ORIGINAL_HEIGHT, self.cfg.IMAGE.ORIGINAL_WIDTH 128 | final_height, final_width = self.cfg.IMAGE.FINAL_DIM 129 | 130 | resize_scale = self.cfg.IMAGE.RESIZE_SCALE 131 | resize_dims = (int(original_width * resize_scale), int(original_height * resize_scale)) 132 | resized_width, resized_height = resize_dims 133 | 134 | crop_h = self.cfg.IMAGE.TOP_CROP 135 | crop_w = int(max(0, (resized_width - final_width) / 2)) 136 | # Left, top, right, bottom crops. 137 | crop = (crop_w, crop_h, crop_w + final_width, crop_h + final_height) 138 | 139 | if resized_width != final_width: 140 | print('Zero padding left and right parts of the image.') 141 | if crop_h + final_height != resized_height: 142 | print('Zero padding bottom part of the image.') 143 | 144 | return {'scale_width': resize_scale, 145 | 'scale_height': resize_scale, 146 | 'resize_dims': resize_dims, 147 | 'crop': crop, 148 | } 149 | 150 | def get_input_data(self, rec): 151 | """ 152 | Parameters 153 | ---------- 154 | rec: nuscenes identifier for a given timestamp 155 | 156 | Returns 157 | ------- 158 | images: torch.Tensor (N, 3, H, W) 159 | intrinsics: torch.Tensor (3, 3) 160 | extrinsics: torch.Tensor(N, 4, 4) 161 | """ 162 | images = [] 163 | intrinsics = [] 164 | extrinsics = [] 165 | cameras = self.cfg.IMAGE.NAMES 166 | 167 | # The extrinsics we want are from the camera sensor to "flat egopose" as defined 168 | # https://github.com/nutonomy/nuscenes-devkit/blob/9b492f76df22943daf1dc991358d3d606314af27/python-sdk/nuscenes/nuscenes.py#L279 169 | # which corresponds to the position of the lidar. 170 | # This is because the labels are generated by projecting the 3D bounding box in this lidar's reference frame. 171 | 172 | # From lidar egopose to world. 
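        # Only the yaw component of the ego rotation is kept, so lidar_to_world is the
        # "flat" ego pose described above: a rotation about the z-axis plus the full 3D
        # translation of the ego vehicle.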
173 | lidar_sample = self.nusc.get('sample_data', rec['data']['LIDAR_TOP']) 174 | lidar_pose = self.nusc.get('ego_pose', lidar_sample['ego_pose_token']) 175 | yaw = Quaternion(lidar_pose['rotation']).yaw_pitch_roll[0] 176 | lidar_rotation = Quaternion(scalar=np.cos(yaw / 2), vector=[0, 0, np.sin(yaw / 2)]) 177 | lidar_translation = np.array(lidar_pose['translation'])[:, None] 178 | lidar_to_world = np.vstack([ 179 | np.hstack((lidar_rotation.rotation_matrix, lidar_translation)), 180 | np.array([0, 0, 0, 1]) 181 | ]) 182 | 183 | for cam in cameras: 184 | camera_sample = self.nusc.get('sample_data', rec['data'][cam]) 185 | 186 | # Transformation from world to egopose 187 | car_egopose = self.nusc.get('ego_pose', camera_sample['ego_pose_token']) 188 | egopose_rotation = Quaternion(car_egopose['rotation']).inverse 189 | egopose_translation = -np.array(car_egopose['translation'])[:, None] 190 | world_to_car_egopose = np.vstack([ 191 | np.hstack((egopose_rotation.rotation_matrix, egopose_rotation.rotation_matrix @ egopose_translation)), 192 | np.array([0, 0, 0, 1]) 193 | ]) 194 | 195 | # From egopose to sensor 196 | sensor_sample = self.nusc.get('calibrated_sensor', camera_sample['calibrated_sensor_token']) 197 | intrinsic = torch.Tensor(sensor_sample['camera_intrinsic']) 198 | sensor_rotation = Quaternion(sensor_sample['rotation']) 199 | sensor_translation = np.array(sensor_sample['translation'])[:, None] 200 | car_egopose_to_sensor = np.vstack([ 201 | np.hstack((sensor_rotation.rotation_matrix, sensor_translation)), 202 | np.array([0, 0, 0, 1]) 203 | ]) 204 | car_egopose_to_sensor = np.linalg.inv(car_egopose_to_sensor) 205 | 206 | # Combine all the transformation. 207 | # From sensor to lidar. 208 | lidar_to_sensor = car_egopose_to_sensor @ world_to_car_egopose @ lidar_to_world 209 | sensor_to_lidar = torch.from_numpy(np.linalg.inv(lidar_to_sensor)).float() 210 | 211 | # Load image 212 | image_filename = os.path.join(self.dataroot, camera_sample['filename']) 213 | img = Image.open(image_filename) 214 | # Resize and crop 215 | img = resize_and_crop_image( 216 | img, resize_dims=self.augmentation_parameters['resize_dims'], crop=self.augmentation_parameters['crop'] 217 | ) 218 | # Normalise image 219 | normalised_img = self.normalise_image(img) 220 | 221 | # Combine resize/cropping in the intrinsics 222 | top_crop = self.augmentation_parameters['crop'][1] 223 | left_crop = self.augmentation_parameters['crop'][0] 224 | intrinsic = update_intrinsics( 225 | intrinsic, top_crop, left_crop, 226 | scale_width=self.augmentation_parameters['scale_width'], 227 | scale_height=self.augmentation_parameters['scale_height'] 228 | ) 229 | 230 | images.append(normalised_img.unsqueeze(0).unsqueeze(0)) 231 | intrinsics.append(intrinsic.unsqueeze(0).unsqueeze(0)) 232 | extrinsics.append(sensor_to_lidar.unsqueeze(0).unsqueeze(0)) 233 | 234 | images, intrinsics, extrinsics = (torch.cat(images, dim=1), 235 | torch.cat(intrinsics, dim=1), 236 | torch.cat(extrinsics, dim=1) 237 | ) 238 | 239 | return images, intrinsics, extrinsics 240 | 241 | def _get_top_lidar_pose(self, rec): 242 | egopose = self.nusc.get('ego_pose', self.nusc.get('sample_data', rec['data']['LIDAR_TOP'])['ego_pose_token']) 243 | trans = -np.array(egopose['translation']) 244 | yaw = Quaternion(egopose['rotation']).yaw_pitch_roll[0] 245 | rot = Quaternion(scalar=np.cos(yaw / 2), vector=[0, 0, np.sin(yaw / 2)]).inverse 246 | return trans, rot 247 | 248 | def get_birds_eye_view_label(self, rec, instance_map): 249 | translation, rotation = 
self._get_top_lidar_pose(rec) 250 | segmentation = np.zeros((self.bev_dimension[0], self.bev_dimension[1])) 251 | # Background is ID 0 252 | instance = np.zeros((self.bev_dimension[0], self.bev_dimension[1])) 253 | z_position = np.zeros((self.bev_dimension[0], self.bev_dimension[1])) 254 | attribute_label = np.zeros((self.bev_dimension[0], self.bev_dimension[1])) 255 | 256 | for annotation_token in rec['anns']: 257 | # Filter out all non vehicle instances 258 | annotation = self.nusc.get('sample_annotation', annotation_token) 259 | 260 | if not self.is_lyft: 261 | # NuScenes filter 262 | if 'vehicle' not in annotation['category_name']: 263 | continue 264 | if self.cfg.DATASET.FILTER_INVISIBLE_VEHICLES and int(annotation['visibility_token']) == 1: 265 | continue 266 | else: 267 | # Lyft filter 268 | if annotation['category_name'] not in ['bus', 'car', 'construction_vehicle', 'trailer', 'truck']: 269 | continue 270 | 271 | if annotation['instance_token'] not in instance_map: 272 | instance_map[annotation['instance_token']] = len(instance_map) + 1 273 | instance_id = instance_map[annotation['instance_token']] 274 | 275 | if not self.is_lyft: 276 | instance_attribute = int(annotation['visibility_token']) 277 | else: 278 | instance_attribute = 0 279 | 280 | poly_region, z = self._get_poly_region_in_image(annotation, translation, rotation) 281 | cv2.fillPoly(instance, [poly_region], instance_id) 282 | cv2.fillPoly(segmentation, [poly_region], 1.0) 283 | cv2.fillPoly(z_position, [poly_region], z) 284 | cv2.fillPoly(attribute_label, [poly_region], instance_attribute) 285 | 286 | return segmentation, instance, z_position, instance_map, attribute_label 287 | 288 | def _get_poly_region_in_image(self, instance_annotation, ego_translation, ego_rotation): 289 | box = Box( 290 | instance_annotation['translation'], instance_annotation['size'], Quaternion(instance_annotation['rotation']) 291 | ) 292 | box.translate(ego_translation) 293 | box.rotate(ego_rotation) 294 | 295 | pts = box.bottom_corners()[:2].T 296 | pts = np.round((pts - self.bev_start_position[:2] + self.bev_resolution[:2] / 2.0) / self.bev_resolution[:2]).astype(np.int32) 297 | pts[:, [1, 0]] = pts[:, [0, 1]] 298 | 299 | z = box.bottom_corners()[2, 0] 300 | return pts, z 301 | 302 | def get_label(self, rec, instance_map): 303 | segmentation_np, instance_np, z_position_np, instance_map, attribute_label_np = \ 304 | self.get_birds_eye_view_label(rec, instance_map) 305 | segmentation = torch.from_numpy(segmentation_np).long().unsqueeze(0).unsqueeze(0) 306 | instance = torch.from_numpy(instance_np).long().unsqueeze(0) 307 | z_position = torch.from_numpy(z_position_np).float().unsqueeze(0).unsqueeze(0) 308 | attribute_label = torch.from_numpy(attribute_label_np).long().unsqueeze(0).unsqueeze(0) 309 | 310 | return segmentation, instance, z_position, instance_map, attribute_label 311 | 312 | def get_future_egomotion(self, rec, index): 313 | rec_t0 = rec 314 | 315 | # Identity 316 | future_egomotion = np.eye(4, dtype=np.float32) 317 | 318 | if index < len(self.ixes) - 1: 319 | rec_t1 = self.ixes[index + 1] 320 | 321 | if rec_t0['scene_token'] == rec_t1['scene_token']: 322 | egopose_t0 = self.nusc.get( 323 | 'ego_pose', self.nusc.get('sample_data', rec_t0['data']['LIDAR_TOP'])['ego_pose_token'] 324 | ) 325 | egopose_t1 = self.nusc.get( 326 | 'ego_pose', self.nusc.get('sample_data', rec_t1['data']['LIDAR_TOP'])['ego_pose_token'] 327 | ) 328 | 329 | egopose_t0 = convert_egopose_to_matrix_numpy(egopose_t0) 330 | egopose_t1 = 
convert_egopose_to_matrix_numpy(egopose_t1) 331 | 332 | future_egomotion = invert_matrix_egopose_numpy(egopose_t1).dot(egopose_t0) 333 | future_egomotion[3, :3] = 0.0 334 | future_egomotion[3, 3] = 1.0 335 | 336 | future_egomotion = torch.Tensor(future_egomotion).float() 337 | 338 | # Convert to 6DoF vector 339 | future_egomotion = mat2pose_vec(future_egomotion) 340 | return future_egomotion.unsqueeze(0) 341 | 342 | def __len__(self): 343 | return len(self.indices) 344 | 345 | def __getitem__(self, index): 346 | """ 347 | Returns 348 | ------- 349 | data: dict with the following keys: 350 | image: torch.Tensor (T, N, 3, H, W) 351 | normalised cameras images with T the sequence length, and N the number of cameras. 352 | intrinsics: torch.Tensor (T, N, 3, 3) 353 | intrinsics containing resizing and cropping parameters. 354 | extrinsics: torch.Tensor (T, N, 4, 4) 355 | 6 DoF pose from world coordinates to camera coordinates. 356 | segmentation: torch.Tensor (T, 1, H_bev, W_bev) 357 | (H_bev, W_bev) are the pixel dimensions in bird's-eye view. 358 | instance: torch.Tensor (T, 1, H_bev, W_bev) 359 | centerness: torch.Tensor (T, 1, H_bev, W_bev) 360 | offset: torch.Tensor (T, 2, H_bev, W_bev) 361 | flow: torch.Tensor (T, 2, H_bev, W_bev) 362 | future_egomotion: torch.Tensor (T, 6) 363 | 6 DoF egomotion t -> t+1 364 | sample_token: List (T,) 365 | 'z_position': list_z_position, 366 | 'attribute': list_attribute_label, 367 | 368 | """ 369 | data = {} 370 | keys = ['image', 'intrinsics', 'extrinsics', 371 | 'segmentation', 'instance', 'centerness', 'offset', 'flow', 'future_egomotion', 372 | 'sample_token', 373 | 'z_position', 'attribute' 374 | ] 375 | for key in keys: 376 | data[key] = [] 377 | 378 | instance_map = {} 379 | # Loop over all the frames in the sequence. 
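        # self.indices[index] holds TIME_RECEPTIVE_FIELD + N_FUTURE_FRAMES consecutive
        # sample indices from a single scene (built in get_indices), so each list in
        # data gains one entry per timestep.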
380 | for index_t in self.indices[index]: 381 | rec = self.ixes[index_t] 382 | 383 | images, intrinsics, extrinsics = self.get_input_data(rec) 384 | segmentation, instance, z_position, instance_map, attribute_label = self.get_label(rec, instance_map) 385 | 386 | future_egomotion = self.get_future_egomotion(rec, index_t) 387 | 388 | data['image'].append(images) 389 | data['intrinsics'].append(intrinsics) 390 | data['extrinsics'].append(extrinsics) 391 | data['segmentation'].append(segmentation) 392 | data['instance'].append(instance) 393 | data['future_egomotion'].append(future_egomotion) 394 | data['sample_token'].append(rec['token']) 395 | data['z_position'].append(z_position) 396 | data['attribute'].append(attribute_label) 397 | 398 | for key, value in data.items(): 399 | if key in ['sample_token', 'centerness', 'offset', 'flow']: 400 | continue 401 | data[key] = torch.cat(value, dim=0) 402 | 403 | # If lyft need to subsample, and update future_egomotions 404 | if self.cfg.MODEL.SUBSAMPLE: 405 | for key, value in data.items(): 406 | if key in ['future_egomotion', 'sample_token', 'centerness', 'offset', 'flow']: 407 | continue 408 | data[key] = data[key][::2].clone() 409 | data['sample_token'] = data['sample_token'][::2] 410 | 411 | # Update future egomotions 412 | future_egomotions_matrix = pose_vec2mat(data['future_egomotion']) 413 | future_egomotion_accum = torch.zeros_like(future_egomotions_matrix) 414 | future_egomotion_accum[:-1] = future_egomotions_matrix[:-1] @ future_egomotions_matrix[1:] 415 | future_egomotion_accum = mat2pose_vec(future_egomotion_accum) 416 | data['future_egomotion'] = future_egomotion_accum[::2].clone() 417 | 418 | instance_centerness, instance_offset, instance_flow = convert_instance_mask_to_center_and_offset_label( 419 | data['instance'], data['future_egomotion'], 420 | num_instances=len(instance_map), ignore_index=self.cfg.DATASET.IGNORE_INDEX, subtract_egomotion=True, 421 | spatial_extent=self.spatial_extent, 422 | ) 423 | data['centerness'] = instance_centerness 424 | data['offset'] = instance_offset 425 | data['flow'] = instance_flow 426 | return data 427 | 428 | 429 | def prepare_dataloaders(cfg, return_dataset=False): 430 | version = cfg.DATASET.VERSION 431 | train_on_training_data = True 432 | 433 | if cfg.DATASET.NAME == 'nuscenes': 434 | # 28130 train and 6019 val 435 | dataroot = os.path.join(cfg.DATASET.DATAROOT, version) 436 | nusc = NuScenes(version='v1.0-{}'.format(cfg.DATASET.VERSION), dataroot=dataroot, verbose=False) 437 | elif cfg.DATASET.NAME == 'lyft': 438 | # train contains 22680 samples 439 | # we split in 16506 6174 440 | dataroot = os.path.join(cfg.DATASET.DATAROOT, 'trainval') 441 | nusc = LyftDataset(data_path=dataroot, 442 | json_path=os.path.join(dataroot, 'train_data'), 443 | verbose=True) 444 | 445 | traindata = FuturePredictionDataset(nusc, train_on_training_data, cfg) 446 | valdata = FuturePredictionDataset(nusc, False, cfg) 447 | 448 | if cfg.DATASET.VERSION == 'mini': 449 | traindata.indices = traindata.indices[:10] 450 | valdata.indices = valdata.indices[:10] 451 | 452 | nworkers = cfg.N_WORKERS 453 | trainloader = torch.utils.data.DataLoader( 454 | traindata, batch_size=cfg.BATCHSIZE, shuffle=True, num_workers=nworkers, pin_memory=True, drop_last=True 455 | ) 456 | valloader = torch.utils.data.DataLoader( 457 | valdata, batch_size=cfg.BATCHSIZE, shuffle=False, num_workers=nworkers, pin_memory=True, drop_last=False) 458 | 459 | if return_dataset: 460 | return trainloader, valloader, traindata, valdata 461 | else: 462 | 
return trainloader, valloader 463 | --------------------------------------------------------------------------------
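For orientation, the sketch below shows one way the components in this repository could be wired together at inference time: build the dataloaders with prepare_dataloaders, run a Fiery forward pass, and post-process the outputs with predict_instance_segmentation_and_trajectories. This is an illustrative sketch rather than code from the repository: get_cfg stands in for whatever config loading fiery/config.py provides (that file is not shown in this section), checkpoint loading is omitted, and a batch size of 1 is assumed because compute_matched_centers requires it.

import torch

from fiery.config import get_cfg  # assumed helper; fiery/config.py is not included in this section
from fiery.data import prepare_dataloaders
from fiery.models.fiery import Fiery
from fiery.utils.instance import predict_instance_segmentation_and_trajectories

cfg = get_cfg()  # assumption: returns the project's YACS-style config with BATCHSIZE == 1
_, valloader = prepare_dataloaders(cfg)

model = Fiery(cfg).eval()  # loading trained weights from a checkpoint is omitted here

with torch.no_grad():
    batch = next(iter(valloader))
    output = model(batch['image'], batch['intrinsics'],
                   batch['extrinsics'], batch['future_egomotion'])
    # Decode the per-pixel outputs into a temporally consistent instance segmentation
    # and per-instance centre trajectories (requires batch size 1).
    consistent_instance_seg, matched_centers = predict_instance_segmentation_and_trajectories(
        output, compute_matched_centers=True
    )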