├── .gitignore
├── fiery
│   ├── configs
│   │   ├── literature
│   │   │   ├── static_lss_setting.yml
│   │   │   ├── lift_splat_setting.yml
│   │   │   ├── static_pon_setting.yml
│   │   │   ├── pon_setting.yml
│   │   │   └── fishing_setting.yml
│   │   ├── lyft
│   │   │   ├── debug_lyft.yml
│   │   │   ├── single_timeframe.yml
│   │   │   └── baseline.yml
│   │   ├── debug_baseline.yml
│   │   ├── temporal_single_timeframe.yml
│   │   ├── single_timeframe.yml
│   │   └── baseline.yml
│   ├── utils
│   │   ├── lyft_splits.py
│   │   ├── network.py
│   │   ├── geometry.py
│   │   ├── visualisation.py
│   │   └── instance.py
│   ├── models
│   │   ├── future_prediction.py
│   │   ├── distributions.py
│   │   ├── temporal_model.py
│   │   ├── decoder.py
│   │   ├── encoder.py
│   │   └── fiery.py
│   ├── losses.py
│   ├── config.py
│   ├── layers
│   │   ├── convolutions.py
│   │   └── temporal.py
│   ├── metrics.py
│   ├── trainer.py
│   └── data.py
├── environment.yml
├── DATASET.md
├── LICENSE
├── train.py
├── evaluate.py
├── README.md
└── visualise.py
/.gitignore:
--------------------------------------------------------------------------------
1 | tensorboard_logs/
2 | example_data/
3 | output_vis/
4 | .idea/
5 | *.pyc
6 | *.ipynb_checkpoints
7 | *.ipynb
8 | *.ckpt
--------------------------------------------------------------------------------
/fiery/configs/literature/static_lss_setting.yml:
--------------------------------------------------------------------------------
1 | _BASE_: '../single_timeframe.yml'
2 |
3 | TAG: 'lift_splat_setting'
4 |
5 | DATASET:
6 | FILTER_INVISIBLE_VEHICLES: False
7 |
--------------------------------------------------------------------------------
/fiery/configs/literature/lift_splat_setting.yml:
--------------------------------------------------------------------------------
1 | _BASE_: '../temporal_single_timeframe.yml'
2 |
3 | TAG: 'temporal_lift_splat_setting'
4 |
5 | DATASET:
6 | FILTER_INVISIBLE_VEHICLES: False
7 |
--------------------------------------------------------------------------------
/fiery/configs/literature/static_pon_setting.yml:
--------------------------------------------------------------------------------
1 | _BASE_: 'static_lss_setting.yml'
2 |
3 | TAG: 'pyramid_occupancy_network_setting'
4 |
5 | LIFT:
6 | X_BOUND: [-50.0, 50.0, 0.25]
7 | Y_BOUND: [-25.0, 25.0, 0.25]
8 |
--------------------------------------------------------------------------------
/fiery/configs/lyft/debug_lyft.yml:
--------------------------------------------------------------------------------
1 | _BASE_: 'baseline.yml'
2 |
3 | TAG: 'debug'
4 |
5 | BATCHSIZE: 1
6 | GPUS: [0]
7 |
8 | DATASET:
9 | VERSION: 'mini'
10 |
11 | VIS_INTERVAL: 4
12 |
13 | N_WORKERS: 0
14 |
--------------------------------------------------------------------------------
/fiery/configs/lyft/single_timeframe.yml:
--------------------------------------------------------------------------------
1 | _BASE_: '../single_timeframe.yml'
2 |
3 | TAG: 'lyft_single_frame'
4 |
5 | DATASET:
6 | NAME: 'lyft'
7 |
8 | IMAGE:
9 | H: 1080
10 | W: 1920
11 | RESIZE_SCALE: 0.25
12 |
--------------------------------------------------------------------------------
/fiery/configs/debug_baseline.yml:
--------------------------------------------------------------------------------
1 | _BASE_: 'baseline.yml'
2 |
3 | TAG: 'debug'
4 |
5 | EPOCHS: 2
6 | BATCHSIZE: 1
7 | GPUS: [0]
8 | LOGGING_INTERVAL: 10
9 |
10 | DATASET:
11 | VERSION: 'mini'
12 |
13 | VIS_INTERVAL: 8
14 |
--------------------------------------------------------------------------------
/fiery/configs/literature/pon_setting.yml:
--------------------------------------------------------------------------------
1 | _BASE_: '../temporal_single_timeframe.yml'
2 |
3 | TAG: 'temporal_pon_setting'
4 |
5 | DATASET:
6 | FILTER_INVISIBLE_VEHICLES: False
7 |
8 | LIFT:
9 | X_BOUND: [-50.0, 50.0, 0.25]
10 | Y_BOUND: [-25.0, 25.0, 0.25]
11 |
--------------------------------------------------------------------------------
/fiery/configs/literature/fishing_setting.yml:
--------------------------------------------------------------------------------
1 | _BASE_: '../baseline.yml'
2 |
3 | TAG: 'fishing_setting'
4 |
5 | BATCHSIZE: 3
6 |
7 | DATASET:
8 | FILTER_INVISIBLE_VEHICLES: False
9 |
10 | LIFT:
11 | D_BOUND: [2.0, 16.0, 0.5]
12 | X_BOUND: [-16.0, 16.0, 0.1]
13 |   Y_BOUND: [-9.6, 9.6, 0.1]
14 |
--------------------------------------------------------------------------------
/fiery/configs/temporal_single_timeframe.yml:
--------------------------------------------------------------------------------
1 | _BASE_: 'single_timeframe.yml'
2 |
3 | TAG: 'temporal_single_timeframe'
4 |
5 | BATCHSIZE: 4
6 | PRECISION: 16
7 |
8 | TIME_RECEPTIVE_FIELD: 3
9 |
10 | MODEL:
11 | BN_MOMENTUM: 0.05
12 | TEMPORAL_MODEL:
13 | NAME: 'temporal_block'
14 | INPUT_EGOPOSE: True
15 |
--------------------------------------------------------------------------------
/fiery/configs/lyft/baseline.yml:
--------------------------------------------------------------------------------
1 | _BASE_: '../baseline.yml'
2 |
3 | TAG: 'lyft_baseline'
4 |
5 | GPUS: [0, 1, 2, 3]
6 |
7 | BATCHSIZE: 3
8 | TIME_RECEPTIVE_FIELD: 5
9 | N_FUTURE_FRAMES: 10
10 |
11 | DATASET:
12 | NAME: 'lyft'
13 |
14 | IMAGE:
15 | H: 1080
16 | W: 1920
17 | RESIZE_SCALE: 0.25
18 |
19 | MODEL:
20 | SUBSAMPLE: True
21 |
--------------------------------------------------------------------------------
/fiery/configs/single_timeframe.yml:
--------------------------------------------------------------------------------
1 | TAG: 'single_timeframe_model'
2 |
3 | GPUS: [0, 1]
4 |
5 | BATCHSIZE: 8
6 |
7 | TIME_RECEPTIVE_FIELD: 1
8 | N_FUTURE_FRAMES: 0
9 |
10 | PROBABILISTIC:
11 | ENABLED: False
12 |
13 | MODEL:
14 | TEMPORAL_MODEL:
15 | NAME: 'identity'
16 | INPUT_EGOPOSE: False
17 |
18 | INSTANCE_FLOW:
19 | ENABLED: False
20 |
21 | OPTIMIZER:
22 | LR: 1e-3
23 |
24 | N_WORKERS: 10
25 |
--------------------------------------------------------------------------------
/fiery/configs/baseline.yml:
--------------------------------------------------------------------------------
1 | TAG: 'baseline'
2 |
3 | GPUS: [0, 1, 2, 3]
4 |
5 | BATCHSIZE: 3
6 | PRECISION: 16
7 |
8 | TIME_RECEPTIVE_FIELD: 3
9 | N_FUTURE_FRAMES: 4
10 |
11 | PROBABILISTIC:
12 | ENABLED: True
13 |
14 | MODEL:
15 | BN_MOMENTUM: 0.05
16 | TEMPORAL_MODEL:
17 | NAME: 'temporal_block'
18 | INPUT_EGOPOSE: True
19 |
20 | INSTANCE_FLOW:
21 | ENABLED: True
22 |
23 | OPTIMIZER:
24 | LR: 3e-4
25 |
26 | N_WORKERS: 10
27 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: fiery
2 | channels:
3 | - defaults
4 | - conda-forge
5 | - pytorch
6 | dependencies:
7 | - python=3.7.10
8 | - pytorch=1.7.0
9 | - torchvision=0.8.1
10 | - cudatoolkit=10.1
11 | - numpy=1.19.2
12 | - scipy=1.5.2
13 | - pillow=8.0.1
14 | - tqdm=4.50.2
15 | - pytorch-lightning=1.2.5
16 | - efficientnet-pytorch=0.7.0
17 | - fvcore=0.1.2.post20201122
18 | - pip=21.0.1
19 | - pip:
20 | - nuscenes-devkit==1.1.0
21 | - lyft-dataset-sdk==0.0.8
22 | - opencv-python==4.5.1.48
23 | - moviepy==1.0.3
24 |
--------------------------------------------------------------------------------
/DATASET.md:
--------------------------------------------------------------------------------
1 | ## 📖 Dataset
2 |
3 | - Go to the official website of the [NuScenes dataset](https://www.nuscenes.org/download). You will need to create an
4 | account.
5 | - Download the _Full dataset (v1.0)_. This includes the _Mini_ dataset (Metadata and sensor file blobs) and the
6 | _Trainval_ dataset (Metadata, File blobs part 1-10).
7 | - Extract the tar files in order to obtain the following folder structure. The `nuscenes` dataset folder will be
8 | designated as `${NUSCENES_DATAROOT}`.
9 | ```
10 | nuscenes
11 | │
12 | └───trainval
13 | │ maps
14 | │ samples
15 | │ sweeps
16 | │ v1.0-trainval
17 | │
18 | └───mini
19 | maps
20 | samples
21 | sweeps
22 | v1.0-mini
23 | ```
24 | - The full dataset is around 400GB. It is possible to reduce this to ~60GB by downloading only the
25 |   keyframe blobs (part 1-10) instead of all the file blobs (part 1-10). The keyframe blobs contain the data we
26 |   need (RGB images and 3D bounding boxes of dynamic objects at 2Hz). The remaining file blobs additionally include RGB
27 |   images and LiDAR sweeps at a higher frequency, but these are not used during training.
--------------------------------------------------------------------------------
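
A quick way to confirm the layout above is to load the metadata with the `nuscenes-devkit` pinned in `environment.yml`. The snippet below is a minimal sketch (not part of the repository), assuming `${NUSCENES_DATAROOT}` points at the `nuscenes` folder described in DATASET.md:

```python
# Minimal sanity check of the dataset layout (illustrative only, not repo code).
import os
from nuscenes.nuscenes import NuScenes

dataroot = os.environ.get('NUSCENES_DATAROOT', './nuscenes')

# The mini split lives in `<dataroot>/mini`, the trainval split in `<dataroot>/trainval`.
nusc = NuScenes(version='v1.0-mini', dataroot=os.path.join(dataroot, 'mini'), verbose=True)
print(f'{len(nusc.scene)} scenes, {len(nusc.sample)} keyframe samples at 2Hz')
```
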
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Wayve Technologies Limited
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/fiery/utils/lyft_splits.py:
--------------------------------------------------------------------------------
1 | TRAIN_LYFT_INDICES = [1, 3, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16,
2 | 17, 18, 19, 20, 21, 23, 24, 27, 28, 29, 30, 31, 32,
3 | 33, 35, 36, 37, 39, 41, 43, 44, 45, 46, 47, 48, 49,
4 | 50, 51, 52, 53, 55, 56, 59, 60, 62, 63, 65, 68, 69,
5 | 70, 71, 72, 73, 74, 75, 76, 78, 79, 81, 82, 83, 84,
6 | 86, 87, 88, 89, 93, 95, 97, 98, 99, 103, 104, 107, 108,
7 | 109, 110, 111, 113, 114, 115, 116, 117, 118, 119, 121, 122, 124,
8 | 127, 128, 130, 131, 132, 134, 135, 136, 137, 138, 139, 143, 144,
9 | 146, 147, 148, 149, 150, 151, 152, 153, 154, 156, 157, 158, 159,
10 | 161, 162, 165, 166, 167, 171, 172, 173, 174, 175, 176, 177, 178,
11 | 179]
12 |
13 | VAL_LYFT_INDICES = [0, 2, 4, 13, 22, 25, 26, 34, 38, 40, 42, 54, 57,
14 | 58, 61, 64, 66, 67, 77, 80, 85, 90, 91, 92, 94, 96,
15 | 100, 101, 102, 105, 106, 112, 120, 123, 125, 126, 129, 133, 140,
16 | 141, 142, 145, 155, 160, 163, 164, 168, 169, 170]
17 |
--------------------------------------------------------------------------------
/fiery/utils/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 |
5 | def pack_sequence_dim(x):
6 | b, s = x.shape[:2]
7 | return x.view(b * s, *x.shape[2:])
8 |
9 |
10 | def unpack_sequence_dim(x, b, s):
11 | return x.view(b, s, *x.shape[1:])
12 |
13 |
14 | def preprocess_batch(batch, device, unsqueeze=False):
15 | for key, value in batch.items():
16 | if key != 'sample_token':
17 | batch[key] = value.to(device)
18 | if unsqueeze:
19 | batch[key] = batch[key].unsqueeze(0)
20 |
21 |
22 | def set_module_grad(module, requires_grad=False):
23 | for p in module.parameters():
24 | p.requires_grad = requires_grad
25 |
26 |
27 | def set_bn_momentum(model, momentum=0.1):
28 | for m in model.modules():
29 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
30 | m.momentum = momentum
31 |
32 |
33 | class NormalizeInverse(torchvision.transforms.Normalize):
34 | # https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/8
35 | def __init__(self, mean, std):
36 | mean = torch.as_tensor(mean)
37 | std = torch.as_tensor(std)
38 | std_inv = 1 / (std + 1e-7)
39 | mean_inv = -mean * std_inv
40 | super().__init__(mean=mean_inv, std=std_inv)
41 |
42 | def __call__(self, tensor):
43 | return super().__call__(tensor.clone())
44 |
--------------------------------------------------------------------------------
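
The pack/unpack helpers in `fiery/utils/network.py` above fold the time axis into the batch axis and back, so per-frame modules can be applied to temporal tensors. A minimal round-trip sketch (illustrative only, assuming the `fiery` package is importable):

```python
# Round-trip illustration of pack_sequence_dim / unpack_sequence_dim (not repo code).
import torch
from fiery.utils.network import pack_sequence_dim, unpack_sequence_dim

x = torch.randn(2, 3, 64, 50, 50)      # (batch, time, channels, height, width)
packed = pack_sequence_dim(x)          # (2 * 3, 64, 50, 50)
restored = unpack_sequence_dim(packed, b=2, s=3)
assert torch.equal(x, restored)
```
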
/fiery/models/future_prediction.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from fiery.layers.convolutions import Bottleneck
4 | from fiery.layers.temporal import SpatialGRU
5 |
6 |
7 | class FuturePrediction(torch.nn.Module):
8 | def __init__(self, in_channels, latent_dim, n_gru_blocks=3, n_res_layers=3):
9 | super().__init__()
10 | self.n_gru_blocks = n_gru_blocks
11 |
12 | # Convolutional recurrent model with z_t as an initial hidden state and inputs the sample
13 | # from the probabilistic model. The architecture of the model is:
14 | # [Spatial GRU - [Bottleneck] x n_res_layers] x n_gru_blocks
15 | self.spatial_grus = []
16 | self.res_blocks = []
17 |
18 | for i in range(self.n_gru_blocks):
19 | gru_in_channels = latent_dim if i == 0 else in_channels
20 | self.spatial_grus.append(SpatialGRU(gru_in_channels, in_channels))
21 | self.res_blocks.append(torch.nn.Sequential(*[Bottleneck(in_channels)
22 | for _ in range(n_res_layers)]))
23 |
24 | self.spatial_grus = torch.nn.ModuleList(self.spatial_grus)
25 | self.res_blocks = torch.nn.ModuleList(self.res_blocks)
26 |
27 | def forward(self, x, hidden_state):
28 | # x has shape (b, n_future, c, h, w), hidden_state (b, c, h, w)
29 | for i in range(self.n_gru_blocks):
30 | x = self.spatial_grus[i](x, hidden_state, flow=None)
31 | b, n_future, c, h, w = x.shape
32 |
33 | x = self.res_blocks[i](x.view(b * n_future, c, h, w))
34 | x = x.view(b, n_future, c, h, w)
35 |
36 | return x
37 |
--------------------------------------------------------------------------------
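
As the comment in `FuturePrediction.__init__` describes, the module stacks `n_gru_blocks` of [SpatialGRU followed by `n_res_layers` Bottleneck blocks], with the present state `z_t` as the initial hidden state and the latent sample as input. A shape-only sketch, assuming the `fiery` package is importable and that SpatialGRU preserves spatial resolution:

```python
# Shape sketch for FuturePrediction (illustrative only, not repo code).
import torch
from fiery.models.future_prediction import FuturePrediction

b, n_future, c, h, w = 1, 4, 64, 50, 50
latent_dim = 32

future_prediction = FuturePrediction(in_channels=c, latent_dim=latent_dim)

latent_sample = torch.randn(b, n_future, latent_dim, h, w)  # sample from the present/future distribution
present_state = torch.randn(b, c, h, w)                     # z_t, used as the initial hidden state

future_states = future_prediction(latent_sample, present_state)
print(future_states.shape)  # expected: (b, n_future, c, h, w)
```
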
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import socket
4 | import torch
5 | import pytorch_lightning as pl
6 | from pytorch_lightning.plugins import DDPPlugin
7 |
8 | from fiery.config import get_parser, get_cfg
9 | from fiery.data import prepare_dataloaders
10 | from fiery.trainer import TrainingModule
11 |
12 |
13 | def main():
14 | args = get_parser().parse_args()
15 | cfg = get_cfg(args)
16 |
17 | trainloader, valloader = prepare_dataloaders(cfg)
18 | model = TrainingModule(cfg.convert_to_dict())
19 |
20 | if cfg.PRETRAINED.LOAD_WEIGHTS:
21 | # Load single-image instance segmentation model.
22 | pretrained_model_weights = torch.load(
23 | os.path.join(cfg.DATASET.DATAROOT, cfg.PRETRAINED.PATH), map_location='cpu'
24 | )['state_dict']
25 |
26 | model.load_state_dict(pretrained_model_weights, strict=False)
27 | print(f'Loaded single-image model weights from {cfg.PRETRAINED.PATH}')
28 |
29 | save_dir = os.path.join(
30 | cfg.LOG_DIR, time.strftime('%d%B%Yat%H:%M:%S%Z') + '_' + socket.gethostname() + '_' + cfg.TAG
31 | )
32 | tb_logger = pl.loggers.TensorBoardLogger(save_dir=save_dir)
33 | trainer = pl.Trainer(
34 | gpus=cfg.GPUS,
35 | accelerator='ddp',
36 | precision=cfg.PRECISION,
37 | sync_batchnorm=True,
38 | gradient_clip_val=cfg.GRAD_NORM_CLIP,
39 | max_epochs=cfg.EPOCHS,
40 | weights_summary='full',
41 | logger=tb_logger,
42 | log_every_n_steps=cfg.LOGGING_INTERVAL,
43 | plugins=DDPPlugin(find_unused_parameters=True),
44 | profiler='simple',
45 | )
46 | trainer.fit(model, trainloader, valloader)
47 |
48 |
49 | if __name__ == "__main__":
50 | main()
51 |
--------------------------------------------------------------------------------
/fiery/models/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from fiery.layers.convolutions import Bottleneck
5 |
6 |
7 | class DistributionModule(nn.Module):
8 | """
9 | A convolutional net that parametrises a diagonal Gaussian distribution.
10 | """
11 |
12 | def __init__(
13 | self, in_channels, latent_dim, min_log_sigma, max_log_sigma):
14 | super().__init__()
15 | self.compress_dim = in_channels // 2
16 | self.latent_dim = latent_dim
17 | self.min_log_sigma = min_log_sigma
18 | self.max_log_sigma = max_log_sigma
19 |
20 | self.encoder = DistributionEncoder(
21 | in_channels,
22 | self.compress_dim,
23 | )
24 | self.last_conv = nn.Sequential(
25 | nn.AdaptiveAvgPool2d(1), nn.Conv2d(self.compress_dim, out_channels=2 * self.latent_dim, kernel_size=1)
26 | )
27 |
28 | def forward(self, s_t):
29 | b, s = s_t.shape[:2]
30 | assert s == 1
31 | encoding = self.encoder(s_t[:, 0])
32 |
33 | mu_log_sigma = self.last_conv(encoding).view(b, 1, 2 * self.latent_dim)
34 | mu = mu_log_sigma[:, :, :self.latent_dim]
35 | log_sigma = mu_log_sigma[:, :, self.latent_dim:]
36 |
37 | # clip the log_sigma value for numerical stability
38 | log_sigma = torch.clamp(log_sigma, self.min_log_sigma, self.max_log_sigma)
39 | return mu, log_sigma
40 |
41 |
42 | class DistributionEncoder(nn.Module):
43 | """Encodes s_t or (s_t, y_{t+1}, ..., y_{t+H}).
44 | """
45 | def __init__(self, in_channels, out_channels):
46 | super().__init__()
47 |
48 | self.model = nn.Sequential(
49 | Bottleneck(in_channels, out_channels=out_channels, downsample=True),
50 | Bottleneck(out_channels, out_channels=out_channels, downsample=True),
51 | Bottleneck(out_channels, out_channels=out_channels, downsample=True),
52 | Bottleneck(out_channels, out_channels=out_channels, downsample=True),
53 | )
54 |
55 | def forward(self, s_t):
56 | return self.model(s_t)
57 |
--------------------------------------------------------------------------------
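
`DistributionModule` only outputs `(mu, log_sigma)`; drawing a latent sample happens in the full model (not shown in this section) via the usual reparameterisation trick. A hedged sketch of that step, consistent with `evaluate.py` passing a zero `noise` tensor to obtain the mean prediction:

```python
# Reparameterised sampling from the diagonal Gaussian (illustrative only;
# the exact call site lives in the full model, which is not shown here).
import torch

def sample_from_distribution(mu, log_sigma, noise=None):
    # mu, log_sigma: (b, 1, latent_dim). Passing noise = zeros yields the mean
    # prediction, which is what evaluate.py does at test time.
    if noise is None:
        noise = torch.randn_like(mu)
    return mu + torch.exp(log_sigma) * noise
```
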
/fiery/models/temporal_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from fiery.layers.temporal import Bottleneck3D, TemporalBlock
4 |
5 |
6 | class TemporalModel(nn.Module):
7 | def __init__(
8 | self, in_channels, receptive_field, input_shape, start_out_channels=64, extra_in_channels=0,
9 | n_spatial_layers_between_temporal_layers=0, use_pyramid_pooling=True):
10 | super().__init__()
11 | self.receptive_field = receptive_field
12 | n_temporal_layers = receptive_field - 1
13 |
14 | h, w = input_shape
15 | modules = []
16 |
17 | block_in_channels = in_channels
18 | block_out_channels = start_out_channels
19 |
20 | for _ in range(n_temporal_layers):
21 | if use_pyramid_pooling:
22 | use_pyramid_pooling = True
23 | pool_sizes = [(2, h, w)]
24 | else:
25 | use_pyramid_pooling = False
26 | pool_sizes = None
27 | temporal = TemporalBlock(
28 | block_in_channels,
29 | block_out_channels,
30 | use_pyramid_pooling=use_pyramid_pooling,
31 | pool_sizes=pool_sizes,
32 | )
33 | spatial = [
34 | Bottleneck3D(block_out_channels, block_out_channels, kernel_size=(1, 3, 3))
35 | for _ in range(n_spatial_layers_between_temporal_layers)
36 | ]
37 | temporal_spatial_layers = nn.Sequential(temporal, *spatial)
38 | modules.extend(temporal_spatial_layers)
39 |
40 | block_in_channels = block_out_channels
41 | block_out_channels += extra_in_channels
42 |
43 | self.out_channels = block_in_channels
44 |
45 | self.model = nn.Sequential(*modules)
46 |
47 | def forward(self, x):
48 | # Reshape input tensor to (batch, C, time, H, W)
49 | x = x.permute(0, 2, 1, 3, 4)
50 | x = self.model(x)
51 | x = x.permute(0, 2, 1, 3, 4).contiguous()
52 | return x[:, (self.receptive_field - 1):]
53 |
54 |
55 | class TemporalModelIdentity(nn.Module):
56 | def __init__(self, in_channels, receptive_field):
57 | super().__init__()
58 | self.receptive_field = receptive_field
59 | self.out_channels = in_channels
60 |
61 | def forward(self, x):
62 | return x[:, (self.receptive_field - 1):]
63 |
--------------------------------------------------------------------------------
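
`TemporalModel` (and its identity variant) returns only the timesteps whose full temporal context has been observed, i.e. it drops the first `receptive_field - 1` frames. A shape sketch, assuming the `fiery` package is importable and that `TemporalBlock` preserves the temporal and spatial dimensions:

```python
# Shape sketch for the temporal model (illustrative only, not repo code).
import torch
from fiery.models.temporal_model import TemporalModel, TemporalModelIdentity

b, s, c, h, w = 1, 3, 64, 50, 50   # s = TIME_RECEPTIVE_FIELD past frames
x = torch.randn(b, s, c, h, w)

temporal_model = TemporalModel(in_channels=c, receptive_field=s, input_shape=(h, w))
out = temporal_model(x)
print(out.shape)  # expected: (b, s - (receptive_field - 1), out_channels, h, w) = (1, 1, 64, 50, 50)

identity = TemporalModelIdentity(in_channels=c, receptive_field=1)
print(identity(x).shape)  # (1, 3, 64, 50, 50): nothing is dropped for a receptive field of 1
```
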
/fiery/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class SpatialRegressionLoss(nn.Module):
7 | def __init__(self, norm, ignore_index=255, future_discount=1.0):
8 | super(SpatialRegressionLoss, self).__init__()
9 | self.norm = norm
10 | self.ignore_index = ignore_index
11 | self.future_discount = future_discount
12 |
13 | if norm == 1:
14 | self.loss_fn = F.l1_loss
15 | elif norm == 2:
16 | self.loss_fn = F.mse_loss
17 | else:
18 | raise ValueError(f'Expected norm 1 or 2, but got norm={norm}')
19 |
20 | def forward(self, prediction, target):
21 | assert len(prediction.shape) == 5, 'Must be a 5D tensor'
22 | # ignore_index is the same across all channels
23 | mask = target[:, :, :1] != self.ignore_index
24 | if mask.sum() == 0:
25 | return prediction.new_zeros(1)[0].float()
26 |
27 | loss = self.loss_fn(prediction, target, reduction='none')
28 |
29 | # Sum channel dimension
30 | loss = torch.sum(loss, dim=-3, keepdims=True)
31 |
32 | seq_len = loss.shape[1]
33 | future_discounts = self.future_discount ** torch.arange(seq_len, device=loss.device, dtype=loss.dtype)
34 | future_discounts = future_discounts.view(1, seq_len, 1, 1, 1)
35 | loss = loss * future_discounts
36 |
37 | return loss[mask].mean()
38 |
39 |
40 | class SegmentationLoss(nn.Module):
41 | def __init__(self, class_weights, ignore_index=255, use_top_k=False, top_k_ratio=1.0, future_discount=1.0):
42 | super().__init__()
43 | self.class_weights = class_weights
44 | self.ignore_index = ignore_index
45 | self.use_top_k = use_top_k
46 | self.top_k_ratio = top_k_ratio
47 | self.future_discount = future_discount
48 |
49 | def forward(self, prediction, target):
50 | if target.shape[-3] != 1:
51 | raise ValueError('segmentation label must be an index-label with channel dimension = 1.')
52 | b, s, c, h, w = prediction.shape
53 |
54 | prediction = prediction.view(b * s, c, h, w)
55 | target = target.view(b * s, h, w)
56 | loss = F.cross_entropy(
57 | prediction,
58 | target,
59 | ignore_index=self.ignore_index,
60 | reduction='none',
61 | weight=self.class_weights.to(target.device),
62 | )
63 |
64 | loss = loss.view(b, s, h, w)
65 |
66 | future_discounts = self.future_discount ** torch.arange(s, device=loss.device, dtype=loss.dtype)
67 | future_discounts = future_discounts.view(1, s, 1, 1)
68 | loss = loss * future_discounts
69 |
70 | loss = loss.view(b, s, -1)
71 | if self.use_top_k:
72 | # Penalises the top-k hardest pixels
73 | k = int(self.top_k_ratio * loss.shape[2])
74 | loss, _ = torch.sort(loss, dim=2, descending=True)
75 | loss = loss[:, :, :k]
76 |
77 | return torch.mean(loss)
78 |
79 |
80 | class ProbabilisticLoss(nn.Module):
81 | def forward(self, output):
82 | present_mu = output['present_mu']
83 | present_log_sigma = output['present_log_sigma']
84 | future_mu = output['future_mu']
85 | future_log_sigma = output['future_log_sigma']
86 |
87 | var_future = torch.exp(2 * future_log_sigma)
88 | var_present = torch.exp(2 * present_log_sigma)
89 | kl_div = (
90 | present_log_sigma - future_log_sigma - 0.5 + (var_future + (future_mu - present_mu) ** 2) / (
91 | 2 * var_present)
92 | )
93 |
94 | kl_loss = torch.mean(torch.sum(kl_div, dim=-1))
95 |
96 | return kl_loss
97 |
--------------------------------------------------------------------------------
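
`ProbabilisticLoss` is the closed-form KL divergence KL( N(future_mu, future_sigma²) ‖ N(present_mu, present_sigma²) ), summed over the latent dimension. A small, purely illustrative check of the hand-written expression against `torch.distributions`:

```python
# Sanity check (illustrative only): the per-dimension KL term used in
# ProbabilisticLoss matches torch.distributions' closed-form Gaussian KL.
import torch
from torch.distributions import Normal, kl_divergence

mu_p, log_sigma_p = torch.randn(2, 32), torch.randn(2, 32)   # present distribution
mu_f, log_sigma_f = torch.randn(2, 32), torch.randn(2, 32)   # future distribution

var_f, var_p = torch.exp(2 * log_sigma_f), torch.exp(2 * log_sigma_p)
manual = log_sigma_p - log_sigma_f - 0.5 + (var_f + (mu_f - mu_p) ** 2) / (2 * var_p)

reference = kl_divergence(Normal(mu_f, torch.exp(log_sigma_f)),
                          Normal(mu_p, torch.exp(log_sigma_p)))
assert torch.allclose(manual, reference, atol=1e-5)
```
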
/fiery/models/decoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torchvision.models.resnet import resnet18
3 |
4 | from fiery.layers.convolutions import UpsamplingAdd
5 |
6 |
7 | class Decoder(nn.Module):
8 | def __init__(self, in_channels, n_classes, predict_future_flow):
9 | super().__init__()
10 | backbone = resnet18(pretrained=False, zero_init_residual=True)
11 | self.first_conv = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
12 | self.bn1 = backbone.bn1
13 | self.relu = backbone.relu
14 |
15 | self.layer1 = backbone.layer1
16 | self.layer2 = backbone.layer2
17 | self.layer3 = backbone.layer3
18 | self.predict_future_flow = predict_future_flow
19 |
20 | shared_out_channels = in_channels
21 | self.up3_skip = UpsamplingAdd(256, 128, scale_factor=2)
22 | self.up2_skip = UpsamplingAdd(128, 64, scale_factor=2)
23 | self.up1_skip = UpsamplingAdd(64, shared_out_channels, scale_factor=2)
24 |
25 | self.segmentation_head = nn.Sequential(
26 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
27 | nn.BatchNorm2d(shared_out_channels),
28 | nn.ReLU(inplace=True),
29 | nn.Conv2d(shared_out_channels, n_classes, kernel_size=1, padding=0),
30 | )
31 | self.instance_offset_head = nn.Sequential(
32 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
33 | nn.BatchNorm2d(shared_out_channels),
34 | nn.ReLU(inplace=True),
35 | nn.Conv2d(shared_out_channels, 2, kernel_size=1, padding=0),
36 | )
37 | self.instance_center_head = nn.Sequential(
38 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
39 | nn.BatchNorm2d(shared_out_channels),
40 | nn.ReLU(inplace=True),
41 | nn.Conv2d(shared_out_channels, 1, kernel_size=1, padding=0),
42 | nn.Sigmoid(),
43 | )
44 |
45 | if self.predict_future_flow:
46 | self.instance_future_head = nn.Sequential(
47 | nn.Conv2d(shared_out_channels, shared_out_channels, kernel_size=3, padding=1, bias=False),
48 | nn.BatchNorm2d(shared_out_channels),
49 | nn.ReLU(inplace=True),
50 | nn.Conv2d(shared_out_channels, 2, kernel_size=1, padding=0),
51 | )
52 |
53 | def forward(self, x):
54 | b, s, c, h, w = x.shape
55 | x = x.view(b * s, c, h, w)
56 |
57 | # (H, W)
58 | skip_x = {'1': x}
59 | x = self.first_conv(x)
60 | x = self.bn1(x)
61 | x = self.relu(x)
62 |
63 | # (H/4, W/4)
64 | x = self.layer1(x)
65 | skip_x['2'] = x
66 | x = self.layer2(x)
67 | skip_x['3'] = x
68 |
69 | # (H/8, W/8)
70 | x = self.layer3(x)
71 |
72 | # First upsample to (H/4, W/4)
73 | x = self.up3_skip(x, skip_x['3'])
74 |
75 | # Second upsample to (H/2, W/2)
76 | x = self.up2_skip(x, skip_x['2'])
77 |
78 | # Third upsample to (H, W)
79 | x = self.up1_skip(x, skip_x['1'])
80 |
81 | segmentation_output = self.segmentation_head(x)
82 | instance_center_output = self.instance_center_head(x)
83 | instance_offset_output = self.instance_offset_head(x)
84 | instance_future_output = self.instance_future_head(x) if self.predict_future_flow else None
85 | return {
86 | 'segmentation': segmentation_output.view(b, s, *segmentation_output.shape[1:]),
87 | 'instance_center': instance_center_output.view(b, s, *instance_center_output.shape[1:]),
88 | 'instance_offset': instance_offset_output.view(b, s, *instance_offset_output.shape[1:]),
89 | 'instance_flow': instance_future_output.view(b, s, *instance_future_output.shape[1:])
90 | if instance_future_output is not None else None,
91 | }
92 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import torch
4 | from tqdm import tqdm
5 |
6 | from fiery.data import prepare_dataloaders
7 | from fiery.trainer import TrainingModule
8 | from fiery.metrics import IntersectionOverUnion, PanopticMetric
9 | from fiery.utils.network import preprocess_batch
10 | from fiery.utils.instance import predict_instance_segmentation_and_trajectories
11 |
12 | # 30mx30m, 100mx100m
13 | EVALUATION_RANGES = {'30x30': (70, 130),
14 | '100x100': (0, 200)
15 | }
16 |
17 |
18 | def eval(checkpoint_path, dataroot, version):
19 | trainer = TrainingModule.load_from_checkpoint(checkpoint_path, strict=True)
20 | print(f'Loaded weights from \n {checkpoint_path}')
21 | trainer.eval()
22 |
23 | device = torch.device('cuda:0')
24 | trainer.to(device)
25 | model = trainer.model
26 |
27 | cfg = model.cfg
28 | cfg.GPUS = "[0]"
29 | cfg.BATCHSIZE = 1
30 |
31 | cfg.DATASET.DATAROOT = dataroot
32 | cfg.DATASET.VERSION = version
33 |
34 | _, valloader = prepare_dataloaders(cfg)
35 |
36 | panoptic_metrics = {}
37 | iou_metrics = {}
38 | n_classes = len(cfg.SEMANTIC_SEG.WEIGHTS)
39 | for key in EVALUATION_RANGES.keys():
40 | panoptic_metrics[key] = PanopticMetric(n_classes=n_classes, temporally_consistent=True).to(
41 | device)
42 | iou_metrics[key] = IntersectionOverUnion(n_classes).to(device)
43 |
44 | for i, batch in enumerate(tqdm(valloader)):
45 | preprocess_batch(batch, device)
46 | image = batch['image']
47 | intrinsics = batch['intrinsics']
48 | extrinsics = batch['extrinsics']
49 | future_egomotion = batch['future_egomotion']
50 |
51 | batch_size = image.shape[0]
52 |
53 | labels, future_distribution_inputs = trainer.prepare_future_labels(batch)
54 |
55 | with torch.no_grad():
56 | # Evaluate with mean prediction
57 | noise = torch.zeros((batch_size, 1, model.latent_dim), device=device)
58 | output = model(image, intrinsics, extrinsics, future_egomotion,
59 | future_distribution_inputs, noise=noise)
60 |
61 | # Consistent instance seg
62 | pred_consistent_instance_seg = predict_instance_segmentation_and_trajectories(
63 | output, compute_matched_centers=False, make_consistent=True
64 | )
65 |
66 | segmentation_pred = output['segmentation'].detach()
67 | segmentation_pred = torch.argmax(segmentation_pred, dim=2, keepdims=True)
68 |
69 | for key, grid in EVALUATION_RANGES.items():
70 | limits = slice(grid[0], grid[1])
71 | panoptic_metrics[key](pred_consistent_instance_seg[..., limits, limits].contiguous().detach(),
72 | labels['instance'][..., limits, limits].contiguous()
73 | )
74 |
75 | iou_metrics[key](segmentation_pred[..., limits, limits].contiguous(),
76 | labels['segmentation'][..., limits, limits].contiguous()
77 | )
78 |
79 | results = {}
80 | for key, grid in EVALUATION_RANGES.items():
81 | panoptic_scores = panoptic_metrics[key].compute()
82 | for panoptic_key, value in panoptic_scores.items():
83 | results[f'{panoptic_key}'] = results.get(f'{panoptic_key}', []) + [100 * value[1].item()]
84 |
85 | iou_scores = iou_metrics[key].compute()
86 | results['iou'] = results.get('iou', []) + [100 * iou_scores[1].item()]
87 |
88 | for panoptic_key in ['iou', 'pq', 'sq', 'rq']:
89 | print(panoptic_key)
90 | print(' & '.join([f'{x:.1f}' for x in results[panoptic_key]]))
91 |
92 |
93 | if __name__ == '__main__':
94 | parser = ArgumentParser(description='Fiery evaluation')
95 | parser.add_argument('--checkpoint', default='./fiery.ckpt', type=str, help='path to checkpoint')
96 | parser.add_argument('--dataroot', default='./nuscenes', type=str, help='path to the dataset')
97 | parser.add_argument('--version', default='trainval', type=str, choices=['mini', 'trainval'],
98 | help='dataset version')
99 |
100 | args = parser.parse_args()
101 |
102 | eval(args.checkpoint, args.dataroot, args.version)
103 |
--------------------------------------------------------------------------------
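
The `EVALUATION_RANGES` slices in `evaluate.py` above index the BEV grid, which spans `X_BOUND`/`Y_BOUND` of ±50 m at 0.5 m resolution (see `fiery/config.py`), i.e. 200×200 cells. A small arithmetic check (illustrative only) of how the slices map to metres:

```python
# How EVALUATION_RANGES maps to metres (illustrative only, not repo code).
resolution = 0.5                                   # metres per BEV cell (X_BOUND/Y_BOUND step)
grid_cells = int((50.0 - (-50.0)) / resolution)    # 200 cells per side

assert grid_cells == 200
assert (130 - 70) * resolution == 30.0             # slice(70, 130) -> central 30m x 30m
assert (200 - 0) * resolution == 100.0             # slice(0, 200)  -> full 100m x 100m
```
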
/fiery/models/encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from efficientnet_pytorch import EfficientNet
3 |
4 | from fiery.layers.convolutions import UpsamplingConcat
5 |
6 |
7 | class Encoder(nn.Module):
8 | def __init__(self, cfg, D):
9 | super().__init__()
10 | self.D = D
11 | self.C = cfg.OUT_CHANNELS
12 | self.use_depth_distribution = cfg.USE_DEPTH_DISTRIBUTION
13 | self.downsample = cfg.DOWNSAMPLE
14 | self.version = cfg.NAME.split('-')[1]
15 |
16 | self.backbone = EfficientNet.from_pretrained(cfg.NAME)
17 | self.delete_unused_layers()
18 |
19 | if self.downsample == 16:
20 | if self.version == 'b0':
21 | upsampling_in_channels = 320 + 112
22 | elif self.version == 'b4':
23 | upsampling_in_channels = 448 + 160
24 | upsampling_out_channels = 512
25 | elif self.downsample == 8:
26 | if self.version == 'b0':
27 | upsampling_in_channels = 112 + 40
28 | elif self.version == 'b4':
29 | upsampling_in_channels = 160 + 56
30 | upsampling_out_channels = 128
31 | else:
32 | raise ValueError(f'Downsample factor {self.downsample} not handled.')
33 |
34 | self.upsampling_layer = UpsamplingConcat(upsampling_in_channels, upsampling_out_channels)
35 | if self.use_depth_distribution:
36 | self.depth_layer = nn.Conv2d(upsampling_out_channels, self.C + self.D, kernel_size=1, padding=0)
37 | else:
38 | self.depth_layer = nn.Conv2d(upsampling_out_channels, self.C, kernel_size=1, padding=0)
39 |
40 | def delete_unused_layers(self):
41 | indices_to_delete = []
42 | for idx in range(len(self.backbone._blocks)):
43 | if self.downsample == 8:
44 | if self.version == 'b0' and idx > 10:
45 | indices_to_delete.append(idx)
46 | if self.version == 'b4' and idx > 21:
47 | indices_to_delete.append(idx)
48 |
49 | for idx in reversed(indices_to_delete):
50 | del self.backbone._blocks[idx]
51 |
52 | del self.backbone._conv_head
53 | del self.backbone._bn1
54 | del self.backbone._avg_pooling
55 | del self.backbone._dropout
56 | del self.backbone._fc
57 |
58 | def get_features(self, x):
59 | # Adapted from https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py#L231
60 | endpoints = dict()
61 |
62 | # Stem
63 | x = self.backbone._swish(self.backbone._bn0(self.backbone._conv_stem(x)))
64 | prev_x = x
65 |
66 | # Blocks
67 | for idx, block in enumerate(self.backbone._blocks):
68 | drop_connect_rate = self.backbone._global_params.drop_connect_rate
69 | if drop_connect_rate:
70 | drop_connect_rate *= float(idx) / len(self.backbone._blocks)
71 | x = block(x, drop_connect_rate=drop_connect_rate)
72 | if prev_x.size(2) > x.size(2):
73 | endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
74 | prev_x = x
75 |
76 | if self.downsample == 8:
77 | if self.version == 'b0' and idx == 10:
78 | break
79 | if self.version == 'b4' and idx == 21:
80 | break
81 |
82 | # Head
83 | endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
84 |
85 | if self.downsample == 16:
86 | input_1, input_2 = endpoints['reduction_5'], endpoints['reduction_4']
87 | elif self.downsample == 8:
88 | input_1, input_2 = endpoints['reduction_4'], endpoints['reduction_3']
89 |
90 | x = self.upsampling_layer(input_1, input_2)
91 | return x
92 |
93 | def forward(self, x):
94 | x = self.get_features(x) # get feature vector
95 |
96 | x = self.depth_layer(x) # feature and depth head
97 |
98 | if self.use_depth_distribution:
99 | depth = x[:, : self.D].softmax(dim=1)
100 | x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2) # outer product depth and features
101 | else:
102 | x = x.unsqueeze(2).repeat(1, 1, self.D, 1, 1)
103 |
104 | return x
105 |
--------------------------------------------------------------------------------
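
When `USE_DEPTH_DISTRIBUTION` is enabled, the encoder lifts image features into a camera frustum by taking the outer product of a softmaxed per-pixel depth distribution with the per-pixel features (Lift-Splat style lifting). A shape-only sketch in plain PyTorch, where D = 48 depth bins follows from `D_BOUND = [2.0, 50.0, 1.0]` and the other sizes are illustrative:

```python
# Shape-only illustration of the outer product in Encoder.forward (not repo code).
import torch

B, C, D, H, W = 1, 64, 48, 28, 60     # C = OUT_CHANNELS, D = depth bins, (H, W) = feature map size
x = torch.randn(B, C + D, H, W)       # output of self.depth_layer

depth = x[:, :D].softmax(dim=1)       # (B, D, H, W) categorical depth per pixel
features = x[:, D:D + C]              # (B, C, H, W)

lifted = depth.unsqueeze(1) * features.unsqueeze(2)
print(lifted.shape)                   # (B, C, D, H, W)
```
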
/fiery/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from fvcore.common.config import CfgNode as _CfgNode
3 |
4 |
5 | def convert_to_dict(cfg_node, key_list=[]):
6 | """Convert a config node to dictionary."""
7 | _VALID_TYPES = {tuple, list, str, int, float, bool}
8 | if not isinstance(cfg_node, _CfgNode):
9 | if type(cfg_node) not in _VALID_TYPES:
10 | print(
11 | 'Key {} with value {} is not a valid type; valid types: {}'.format(
12 | '.'.join(key_list), type(cfg_node), _VALID_TYPES
13 | ),
14 | )
15 | return cfg_node
16 | else:
17 | cfg_dict = dict(cfg_node)
18 | for k, v in cfg_dict.items():
19 | cfg_dict[k] = convert_to_dict(v, key_list + [k])
20 | return cfg_dict
21 |
22 |
23 | class CfgNode(_CfgNode):
24 | """Remove once https://github.com/rbgirshick/yacs/issues/19 is merged."""
25 |
26 | def convert_to_dict(self):
27 | return convert_to_dict(self)
28 |
29 |
30 | CN = CfgNode
31 |
32 | _C = CN()
33 | _C.LOG_DIR = 'tensorboard_logs'
34 | _C.TAG = 'default'
35 |
36 | _C.GPUS = [0] # gpus to use
37 | _C.PRECISION = 32 # 16bit or 32bit
38 | _C.BATCHSIZE = 3
39 | _C.EPOCHS = 20
40 |
41 | _C.N_WORKERS = 5
42 | _C.VIS_INTERVAL = 5000
43 | _C.LOGGING_INTERVAL = 500
44 |
45 | _C.PRETRAINED = CN()
46 | _C.PRETRAINED.LOAD_WEIGHTS = False
47 | _C.PRETRAINED.PATH = ''
48 |
49 | _C.DATASET = CN()
50 | _C.DATASET.DATAROOT = './nuscenes/'
51 | _C.DATASET.VERSION = 'trainval'
52 | _C.DATASET.NAME = 'nuscenes'
53 | _C.DATASET.IGNORE_INDEX = 255 # Ignore index when creating flow/offset labels
54 | _C.DATASET.FILTER_INVISIBLE_VEHICLES = True # Filter vehicles that are not visible from the cameras
55 |
56 | _C.TIME_RECEPTIVE_FIELD = 3 # how many frames of temporal context (1 for single timeframe)
57 | _C.N_FUTURE_FRAMES = 4 # how many time steps into the future to predict
58 |
59 | _C.IMAGE = CN()
60 | _C.IMAGE.FINAL_DIM = (224, 480)
61 | _C.IMAGE.RESIZE_SCALE = 0.3
62 | _C.IMAGE.TOP_CROP = 46
63 | _C.IMAGE.ORIGINAL_HEIGHT = 900 # Original input RGB camera height
64 | _C.IMAGE.ORIGINAL_WIDTH = 1600 # Original input RGB camera width
65 | _C.IMAGE.NAMES = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT']
66 |
67 | _C.LIFT = CN() # image to BEV lifting
68 | _C.LIFT.X_BOUND = [-50.0, 50.0, 0.5] # Forward
69 | _C.LIFT.Y_BOUND = [-50.0, 50.0, 0.5] # Sides
70 | _C.LIFT.Z_BOUND = [-10.0, 10.0, 20.0] # Height
71 | _C.LIFT.D_BOUND = [2.0, 50.0, 1.0]
72 |
73 | _C.MODEL = CN()
74 |
75 | _C.MODEL.ENCODER = CN()
76 | _C.MODEL.ENCODER.DOWNSAMPLE = 8
77 | _C.MODEL.ENCODER.NAME = 'efficientnet-b4'
78 | _C.MODEL.ENCODER.OUT_CHANNELS = 64
79 | _C.MODEL.ENCODER.USE_DEPTH_DISTRIBUTION = True
80 |
81 | _C.MODEL.TEMPORAL_MODEL = CN()
82 | _C.MODEL.TEMPORAL_MODEL.NAME = 'temporal_block' # type of temporal model
83 | _C.MODEL.TEMPORAL_MODEL.START_OUT_CHANNELS = 64
84 | _C.MODEL.TEMPORAL_MODEL.EXTRA_IN_CHANNELS = 0
85 | _C.MODEL.TEMPORAL_MODEL.INBETWEEN_LAYERS = 0
86 | _C.MODEL.TEMPORAL_MODEL.PYRAMID_POOLING = True
87 | _C.MODEL.TEMPORAL_MODEL.INPUT_EGOPOSE = True
88 |
89 | _C.MODEL.DISTRIBUTION = CN()
90 | _C.MODEL.DISTRIBUTION.LATENT_DIM = 32
91 | _C.MODEL.DISTRIBUTION.MIN_LOG_SIGMA = -5.0
92 | _C.MODEL.DISTRIBUTION.MAX_LOG_SIGMA = 5.0
93 |
94 | _C.MODEL.FUTURE_PRED = CN()
95 | _C.MODEL.FUTURE_PRED.N_GRU_BLOCKS = 3
96 | _C.MODEL.FUTURE_PRED.N_RES_LAYERS = 3
97 |
98 | _C.MODEL.DECODER = CN()
99 |
100 | _C.MODEL.BN_MOMENTUM = 0.1
101 | _C.MODEL.SUBSAMPLE = False # Subsample frames for Lyft
102 |
103 | _C.SEMANTIC_SEG = CN()
104 | _C.SEMANTIC_SEG.WEIGHTS = [1.0, 2.0] # per class cross entropy weights (bg, dynamic)
105 | _C.SEMANTIC_SEG.USE_TOP_K = True # backprop only top-k hardest pixels
106 | _C.SEMANTIC_SEG.TOP_K_RATIO = 0.25
107 |
108 | _C.INSTANCE_SEG = CN()
109 |
110 | _C.INSTANCE_FLOW = CN()
111 | _C.INSTANCE_FLOW.ENABLED = True
112 |
113 | _C.PROBABILISTIC = CN()
114 | _C.PROBABILISTIC.ENABLED = True # learn a distribution over futures
115 | _C.PROBABILISTIC.WEIGHT = 100.0
116 | _C.PROBABILISTIC.FUTURE_DIM = 6 # number of dimension added (future flow, future centerness, offset, seg)
117 |
118 | _C.FUTURE_DISCOUNT = 0.95
119 |
120 | _C.OPTIMIZER = CN()
121 | _C.OPTIMIZER.LR = 3e-4
122 | _C.OPTIMIZER.WEIGHT_DECAY = 1e-7
123 | _C.GRAD_NORM_CLIP = 5
124 |
125 |
126 | def get_parser():
127 | parser = argparse.ArgumentParser(description='Fiery training')
128 | # TODO: remove below?
129 | parser.add_argument('--config-file', default='', metavar='FILE', help='path to config file')
130 | parser.add_argument(
131 | 'opts', help='Modify config options using the command-line', default=None, nargs=argparse.REMAINDER,
132 | )
133 | return parser
134 |
135 |
136 | def get_cfg(args=None, cfg_dict=None):
137 | """ First get default config. Then merge cfg_dict. Then merge according to args. """
138 |
139 | cfg = _C.clone()
140 |
141 | if cfg_dict is not None:
142 | cfg.merge_from_other_cfg(CfgNode(cfg_dict))
143 |
144 | if args is not None:
145 | if args.config_file:
146 | cfg.merge_from_file(args.config_file)
147 | cfg.merge_from_list(args.opts)
148 | cfg.freeze()
149 | return cfg
150 |
--------------------------------------------------------------------------------
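
Configs compose through fvcore's `_BASE_` key (as used by e.g. `fiery/configs/debug_baseline.yml`) and are overridden from the command line via `merge_from_list`, which is how `train.py` accepts flags such as `DATASET.DATAROOT ${NUSCENES_DATAROOT}`. A minimal usage sketch (illustrative only), assuming it is run from the repository root with the root on `PYTHONPATH`:

```python
# Minimal usage sketch of the config system (illustrative only, not repo code).
from fiery.config import get_cfg, get_parser

# Programmatic: start from the defaults and merge a dictionary of overrides.
cfg = get_cfg(cfg_dict={'BATCHSIZE': 1, 'GPUS': [0], 'DATASET': {'VERSION': 'mini'}})
print(cfg.TAG, cfg.BATCHSIZE, cfg.DATASET.VERSION)

# Command-line style: equivalent to
#   python train.py --config-file fiery/configs/baseline.yml DATASET.DATAROOT ./nuscenes BATCHSIZE 2
args = get_parser().parse_args(
    ['--config-file', 'fiery/configs/baseline.yml', 'DATASET.DATAROOT', './nuscenes', 'BATCHSIZE', '2']
)
cfg = get_cfg(args)
print(cfg.TAG, cfg.DATASET.DATAROOT, cfg.BATCHSIZE)
```
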
/README.md:
--------------------------------------------------------------------------------
1 | # FIERY
2 | This is the PyTorch implementation for inference and training of the future prediction bird's-eye view network as
3 | described in:
4 |
5 | > **FIERY: Future Instance Prediction in Bird's-Eye View from Surround Monocular Cameras**
6 | >
7 | > [Anthony Hu](https://anthonyhu.github.io/), [Zak Murez](http://zak.murez.com/),
8 | [Nikhil Mohan](https://uk.linkedin.com/in/nikhilmohan33),
9 | [Sofía Dudas](https://uk.linkedin.com/in/sof%C3%ADa-josefina-lago-dudas-2b0737132),
10 | [Jeffrey Hawke](https://uk.linkedin.com/in/jeffrey-hawke),
11 | [Vijay Badrinarayanan](https://sites.google.com/site/vijaybacademichomepage/home),
12 | [Roberto Cipolla](https://mi.eng.cam.ac.uk/~cipolla/index.htm) and [Alex Kendall](https://alexgkendall.com/)
13 | >
14 | > [ICCV 2021 (Oral)](https://arxiv.org/abs/2104.10490)
15 | > [Blog post](https://wayve.ai/blog/fiery-future-instance-prediction-birds-eye-view)
16 |
17 |
18 |
19 |
20 | Multimodal future predictions by our bird’s-eye view network.
21 | Top two rows: RGB camera inputs. The predicted future trajectories and segmentations are projected to the ground plane in the images.
22 | Bottom row: future instance prediction in bird’s-eye view in a 100m×100m capture size around the ego-vehicle, which is indicated by a black rectangle in the center.
23 |
24 |
25 |
26 | If you find our work useful, please consider citing:
27 | ```bibtex
28 | @inproceedings{fiery2021,
29 |   title = {{FIERY}: Future Instance Prediction in Bird's-Eye View from Surround Monocular Cameras},
30 | author = {Anthony Hu and Zak Murez and Nikhil Mohan and Sofía Dudas and
31 | Jeffrey Hawke and Vijay Badrinarayanan and Roberto Cipolla and Alex Kendall},
32 | booktitle = {Proceedings of the International Conference on Computer Vision ({ICCV})},
33 | year = {2021}
34 | }
35 | ```
36 |
37 | ## ⚙ Setup
38 | - Create the [conda](https://docs.conda.io/en/latest/miniconda.html) environment by running `conda env create`.
39 |
40 | ## 🏄 Prediction
41 | ### Visualisation
42 |
43 | In a Colab notebook:
44 | [Open in Colab](https://colab.research.google.com/drive/12ahc3whI1RQZIVDi53grMWHzdA7WqIuo?usp=sharing)
45 |
46 | Or locally:
47 | - Download [pre-trained weights](https://github.com/wayveai/fiery/releases/download/v1.0/fiery.ckpt).
48 | - Run `python visualise.py --checkpoint ${CHECKPOINT_PATH}`. This will render predictions from the network and save
49 | them to an `output_vis` folder.
50 |
51 | ### Evaluation
52 | - Download the [NuScenes dataset](https://www.nuscenes.org/download). For detailed instructions, see [DATASET.md](DATASET.md).
53 | - Download [pre-trained weights](https://github.com/wayveai/fiery/releases/download/v1.0/fiery.ckpt).
54 | - Run `python evaluate.py --checkpoint ${CHECKPOINT_PATH} --dataroot ${NUSCENES_DATAROOT}`.
55 |
56 | ## 🔥 Pre-trained models
57 |
58 | All the configs are in the folder `fiery/configs`.
59 |
60 | | Config and weights | Dataset | Past context | Future horizon | BEV size | IoU | VPQ|
61 | |--------------|---------|-----------------------|----------------|----------|------|----|
62 | | [`baseline.yml`](https://github.com/wayveai/fiery/releases/download/v1.0/fiery.ckpt) | NuScenes | 1.0s | 2.0s | 100mx100m (50cm res.) | 36.7 | 29.9 |
63 | | [`lyft/baseline.yml`](https://github.com/wayveai/fiery/releases/download/v1.0/lyft_fiery.ckpt) | Lyft | 0.8s | 2.0s| 100mx100m (50cm res.) | 36.3 | 29.2 |
64 | | [`literature/static_pon_setting.yml`](https://github.com/wayveai/fiery/releases/download/v1.0/static_pon_setting.ckpt) | NuScenes| 0.0s | 0.0s | 100mx50m (25cm res.) | 37.7| - |
65 | | [`literature/pon_setting.yml`](https://github.com/wayveai/fiery/releases/download/v1.0/pon_setting.ckpt) | NuScenes| 1.0s | 0.0s | 100mx50m (25cm res.) |39.9 | - |
66 | | [`literature/static_lss_setting.yml`](https://github.com/wayveai/fiery/releases/download/v1.0/static_lift_splat_setting.ckpt) | NuScenes | 0.0s | 0.0s | 100mx100m (50cm res.) | 35.8 | - |
67 | | [`literature/lift_splat_setting.yml`](https://github.com/wayveai/fiery/releases/download/v1.0/lift_splat_setting.ckpt) | NuScenes | 1.0s | 0.0s | 100mx100m (50cm res.) | 38.2 | - |
68 | | [`literature/fishing_setting.yml`](https://github.com/wayveai/fiery/releases/download/v1.0/fishing_setting.ckpt) | NuScenes | 1.0s | 2.0s | 32.0mx19.2m (10cm res.) | 57.6 | - |
69 |
70 |
71 | ## 🏊 Training
72 | To train the model from scratch on NuScenes:
73 | - Download the [NuScenes dataset](https://www.nuscenes.org/download). For detailed instructions, see [DATASET.md](DATASET.md).
74 | - Run `python train.py --config fiery/configs/baseline.yml DATASET.DATAROOT ${NUSCENES_DATAROOT}`.
75 |
76 | This will train the model on 4 GPUs, each with a batch of size 3. To train on a single GPU add the flag `GPUS [0]`, and to change the batch
77 | size use the flag `BATCHSIZE ${DESIRED_BATCHSIZE}`.
78 |
79 | ## 🙌 Credits
80 | Big thanks to Giulio D'Ippolito ([@gdippolito](https://github.com/gdippolito)) for the technical help on the GPU
81 | servers, Piotr Sokólski ([@pyetras](https://github.com/pyetras)) for the panoptic metric implementation, and to Hannes Liik ([@hannesliik](https://github.com/hannesliik))
82 | for the awesome future trajectory visualisation on the ground plane.
83 |
--------------------------------------------------------------------------------
/visualise.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | from glob import glob
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import torchvision
9 | import matplotlib as mpl
10 | import matplotlib.pyplot as plt
11 | from PIL import Image
12 |
13 | from fiery.trainer import TrainingModule
14 | from fiery.utils.network import NormalizeInverse
15 | from fiery.utils.instance import predict_instance_segmentation_and_trajectories
16 | from fiery.utils.visualisation import plot_instance_map, generate_instance_colours, make_contour, convert_figure_numpy
17 |
18 | EXAMPLE_DATA_PATH = 'example_data'
19 |
20 |
21 | def plot_prediction(image, output, cfg):
22 | # Process predictions
23 | consistent_instance_seg, matched_centers = predict_instance_segmentation_and_trajectories(
24 | output, compute_matched_centers=True
25 | )
26 |
27 | # Plot future trajectories
28 | unique_ids = torch.unique(consistent_instance_seg[0, 0]).cpu().long().numpy()[1:]
29 | instance_map = dict(zip(unique_ids, unique_ids))
30 | instance_colours = generate_instance_colours(instance_map)
31 | vis_image = plot_instance_map(consistent_instance_seg[0, 0].cpu().numpy(), instance_map)
32 | trajectory_img = np.zeros(vis_image.shape, dtype=np.uint8)
33 | for instance_id in unique_ids:
34 | path = matched_centers[instance_id]
35 | for t in range(len(path) - 1):
36 | color = instance_colours[instance_id].tolist()
37 | cv2.line(trajectory_img, tuple(path[t]), tuple(path[t + 1]),
38 | color, 4)
39 |
40 | # Overlay arrows
41 | temp_img = cv2.addWeighted(vis_image, 0.7, trajectory_img, 0.3, 1.0)
42 | mask = ~ np.all(trajectory_img == 0, axis=2)
43 | vis_image[mask] = temp_img[mask]
44 |
45 | # Plot present RGB frames and predictions
46 | val_w = 2.99
47 | cameras = cfg.IMAGE.NAMES
48 | image_ratio = cfg.IMAGE.FINAL_DIM[0] / cfg.IMAGE.FINAL_DIM[1]
49 | val_h = val_w * image_ratio
50 | fig = plt.figure(figsize=(4 * val_w, 2 * val_h))
51 | width_ratios = (val_w, val_w, val_w, val_w)
52 | gs = mpl.gridspec.GridSpec(2, 4, width_ratios=width_ratios)
53 | gs.update(wspace=0.0, hspace=0.0, left=0.0, right=1.0, top=1.0, bottom=0.0)
54 |
55 | denormalise_img = torchvision.transforms.Compose(
56 | (NormalizeInverse(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
57 | torchvision.transforms.ToPILImage(),)
58 | )
59 | for imgi, img in enumerate(image[0, -1]):
60 | ax = plt.subplot(gs[imgi // 3, imgi % 3])
61 | showimg = denormalise_img(img.cpu())
62 | if imgi > 2:
63 | showimg = showimg.transpose(Image.FLIP_LEFT_RIGHT)
64 |
65 | plt.annotate(cameras[imgi].replace('_', ' ').replace('CAM ', ''), (0.01, 0.87), c='white',
66 | xycoords='axes fraction', fontsize=14)
67 | plt.imshow(showimg)
68 | plt.axis('off')
69 |
70 | ax = plt.subplot(gs[:, 3])
71 | plt.imshow(make_contour(vis_image[::-1, ::-1]))
72 | plt.axis('off')
73 |
74 | plt.draw()
75 | figure_numpy = convert_figure_numpy(fig)
76 | plt.close()
77 | return figure_numpy
78 |
79 |
80 | def download_example_data():
81 | from requests import get
82 |
83 | def download(url, file_name):
84 | # open in binary mode
85 | with open(file_name, "wb") as file:
86 | # get request
87 | response = get(url)
88 | # write to file
89 | file.write(response.content)
90 |
91 | os.makedirs(EXAMPLE_DATA_PATH, exist_ok=True)
92 | url_list = ['https://github.com/wayveai/fiery/releases/download/v1.0/example_1.npz',
93 | 'https://github.com/wayveai/fiery/releases/download/v1.0/example_2.npz',
94 | 'https://github.com/wayveai/fiery/releases/download/v1.0/example_3.npz',
95 | 'https://github.com/wayveai/fiery/releases/download/v1.0/example_4.npz'
96 | ]
97 | for url in url_list:
98 | download(url, os.path.join(EXAMPLE_DATA_PATH, os.path.basename(url)))
99 |
100 |
101 | def visualise(checkpoint_path):
102 | trainer = TrainingModule.load_from_checkpoint(checkpoint_path, strict=True)
103 |
104 | device = torch.device('cuda:0')
105 | trainer = trainer.to(device)
106 | trainer.eval()
107 |
108 | # Download example data
109 | download_example_data()
110 | # Load data
111 | for data_path in sorted(glob(os.path.join(EXAMPLE_DATA_PATH, '*.npz'))):
112 | data = np.load(data_path)
113 | image = torch.from_numpy(data['image']).to(device)
114 | intrinsics = torch.from_numpy(data['intrinsics']).to(device)
115 | extrinsics = torch.from_numpy(data['extrinsics']).to(device)
116 | future_egomotions = torch.from_numpy(data['future_egomotion']).to(device)
117 |
118 | # Forward pass
119 | with torch.no_grad():
120 | output = trainer.model(image, intrinsics, extrinsics, future_egomotions)
121 |
122 | figure_numpy = plot_prediction(image, output, trainer.cfg)
123 | os.makedirs('./output_vis', exist_ok=True)
124 | output_filename = os.path.join('./output_vis', os.path.basename(data_path).split('.')[0]) + '.png'
125 | Image.fromarray(figure_numpy).save(output_filename)
126 | print(f'Saved output in {output_filename}')
127 |
128 |
129 | if __name__ == '__main__':
130 | parser = ArgumentParser(description='Fiery visualisation')
131 | parser.add_argument('--checkpoint', default='./fiery.ckpt', type=str, help='path to checkpoint')
132 |
133 | args = parser.parse_args()
134 |
135 | visualise(args.checkpoint)
136 |
--------------------------------------------------------------------------------
/fiery/layers/convolutions.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class ConvBlock(nn.Module):
10 | """2D convolution followed by
11 | - an optional normalisation (batch norm or instance norm)
12 | - an optional activation (ReLU, LeakyReLU, or tanh)
13 | """
14 |
15 | def __init__(
16 | self,
17 | in_channels,
18 | out_channels=None,
19 | kernel_size=3,
20 | stride=1,
21 | norm='bn',
22 | activation='relu',
23 | bias=False,
24 | transpose=False,
25 | ):
26 | super().__init__()
27 | out_channels = out_channels or in_channels
28 | padding = int((kernel_size - 1) / 2)
29 | self.conv = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d, output_padding=1)
30 | self.conv = self.conv(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias)
31 |
32 | if norm == 'bn':
33 | self.norm = nn.BatchNorm2d(out_channels)
34 | elif norm == 'in':
35 | self.norm = nn.InstanceNorm2d(out_channels)
36 | elif norm == 'none':
37 | self.norm = None
38 | else:
39 | raise ValueError('Invalid norm {}'.format(norm))
40 |
41 | if activation == 'relu':
42 | self.activation = nn.ReLU(inplace=True)
43 | elif activation == 'lrelu':
44 | self.activation = nn.LeakyReLU(0.1, inplace=True)
45 | elif activation == 'elu':
46 | self.activation = nn.ELU(inplace=True)
47 | elif activation == 'tanh':
48 |             self.activation = nn.Tanh()
49 | elif activation == 'none':
50 | self.activation = None
51 | else:
52 | raise ValueError('Invalid activation {}'.format(activation))
53 |
54 | def forward(self, x):
55 | x = self.conv(x)
56 |
57 | if self.norm:
58 | x = self.norm(x)
59 | if self.activation:
60 | x = self.activation(x)
61 | return x
62 |
63 |
64 | class Bottleneck(nn.Module):
65 | """
66 | Defines a bottleneck module with a residual connection
67 | """
68 |
69 | def __init__(
70 | self,
71 | in_channels,
72 | out_channels=None,
73 | kernel_size=3,
74 | dilation=1,
75 | groups=1,
76 | upsample=False,
77 | downsample=False,
78 | dropout=0.0,
79 | ):
80 | super().__init__()
81 | self._downsample = downsample
82 | bottleneck_channels = int(in_channels / 2)
83 | out_channels = out_channels or in_channels
84 | padding_size = ((kernel_size - 1) * dilation + 1) // 2
85 |
86 | # Define the main conv operation
87 | assert dilation == 1
88 | if upsample:
89 | assert not downsample, 'downsample and upsample not possible simultaneously.'
90 | bottleneck_conv = nn.ConvTranspose2d(
91 | bottleneck_channels,
92 | bottleneck_channels,
93 | kernel_size=kernel_size,
94 | bias=False,
95 | dilation=1,
96 | stride=2,
97 | output_padding=padding_size,
98 | padding=padding_size,
99 | groups=groups,
100 | )
101 | elif downsample:
102 | bottleneck_conv = nn.Conv2d(
103 | bottleneck_channels,
104 | bottleneck_channels,
105 | kernel_size=kernel_size,
106 | bias=False,
107 | dilation=dilation,
108 | stride=2,
109 | padding=padding_size,
110 | groups=groups,
111 | )
112 | else:
113 | bottleneck_conv = nn.Conv2d(
114 | bottleneck_channels,
115 | bottleneck_channels,
116 | kernel_size=kernel_size,
117 | bias=False,
118 | dilation=dilation,
119 | padding=padding_size,
120 | groups=groups,
121 | )
122 |
123 | self.layers = nn.Sequential(
124 | OrderedDict(
125 | [
126 | # First projection with 1x1 kernel
127 | ('conv_down_project', nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False)),
128 | ('abn_down_project', nn.Sequential(nn.BatchNorm2d(bottleneck_channels),
129 | nn.ReLU(inplace=True))),
130 | # Second conv block
131 | ('conv', bottleneck_conv),
132 | ('abn', nn.Sequential(nn.BatchNorm2d(bottleneck_channels), nn.ReLU(inplace=True))),
133 | # Final projection with 1x1 kernel
134 | ('conv_up_project', nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, bias=False)),
135 | ('abn_up_project', nn.Sequential(nn.BatchNorm2d(out_channels),
136 | nn.ReLU(inplace=True))),
137 | # Regulariser
138 | ('dropout', nn.Dropout2d(p=dropout)),
139 | ]
140 | )
141 | )
142 |
143 | if out_channels == in_channels and not downsample and not upsample:
144 | self.projection = None
145 | else:
146 | projection = OrderedDict()
147 | if upsample:
148 | projection.update({'upsample_skip_proj': Interpolate(scale_factor=2)})
149 | elif downsample:
150 | projection.update({'upsample_skip_proj': nn.MaxPool2d(kernel_size=2, stride=2)})
151 | projection.update(
152 | {
153 | 'conv_skip_proj': nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
154 | 'bn_skip_proj': nn.BatchNorm2d(out_channels),
155 | }
156 | )
157 | self.projection = nn.Sequential(projection)
158 |
159 | # pylint: disable=arguments-differ
160 | def forward(self, *args):
161 | (x,) = args
162 | x_residual = self.layers(x)
163 | if self.projection is not None:
164 | if self._downsample:
165 | # pad h/w dimensions if they are odd to prevent shape mismatch with residual layer
166 | x = nn.functional.pad(x, (0, x.shape[-1] % 2, 0, x.shape[-2] % 2), value=0)
167 | return x_residual + self.projection(x)
168 | return x_residual + x
169 |
170 |
171 | class Interpolate(nn.Module):
172 | def __init__(self, scale_factor: int = 2):
173 | super().__init__()
174 | self._interpolate = nn.functional.interpolate
175 | self._scale_factor = scale_factor
176 |
177 | # pylint: disable=arguments-differ
178 | def forward(self, x):
179 | return self._interpolate(x, scale_factor=self._scale_factor, mode='bilinear', align_corners=False)
180 |
181 |
182 | class UpsamplingConcat(nn.Module):
183 | def __init__(self, in_channels, out_channels, scale_factor=2):
184 | super().__init__()
185 |
186 | self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)
187 |
188 | self.conv = nn.Sequential(
189 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
190 | nn.BatchNorm2d(out_channels),
191 | nn.ReLU(inplace=True),
192 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
193 | nn.BatchNorm2d(out_channels),
194 | nn.ReLU(inplace=True),
195 | )
196 |
197 | def forward(self, x_to_upsample, x):
198 | x_to_upsample = self.upsample(x_to_upsample)
199 | x_to_upsample = torch.cat([x, x_to_upsample], dim=1)
200 | return self.conv(x_to_upsample)
201 |
202 |
203 | class UpsamplingAdd(nn.Module):
204 | def __init__(self, in_channels, out_channels, scale_factor=2):
205 | super().__init__()
206 | self.upsample_layer = nn.Sequential(
207 | nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
208 | nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False),
209 | nn.BatchNorm2d(out_channels),
210 | )
211 |
212 | def forward(self, x, x_skip):
213 | x = self.upsample_layer(x)
214 | return x + x_skip
215 |
--------------------------------------------------------------------------------
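
A minimal shape-check sketch for the two decoder upsampling blocks above (illustrative only: the tensor sizes are arbitrary and the snippet assumes the repository is importable as the fiery package):

import torch

from fiery.layers.convolutions import UpsamplingAdd, UpsamplingConcat

x = torch.randn(2, 128, 25, 25)      # low-resolution feature map
x_skip = torch.randn(2, 64, 50, 50)  # skip connection at twice the resolution

# UpsamplingAdd: bilinear upsample + 1x1 conv + BN, then add the skip connection.
up_add = UpsamplingAdd(in_channels=128, out_channels=64)
print(up_add(x, x_skip).shape)       # torch.Size([2, 64, 50, 50])

# UpsamplingConcat: bilinear upsample, concatenate with the skip connection, then two
# 3x3 conv + BN + ReLU blocks; in_channels must be the concatenated channel count.
up_cat = UpsamplingConcat(in_channels=128 + 64, out_channels=64)
print(up_cat(x, x_skip).shape)       # torch.Size([2, 64, 50, 50])
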
/fiery/utils/geometry.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | import numpy as np
3 | import torch
4 |
5 | from pyquaternion import Quaternion
6 |
7 |
8 | def resize_and_crop_image(img, resize_dims, crop):
9 | # Bilinear resizing followed by cropping
10 | img = img.resize(resize_dims, resample=PIL.Image.BILINEAR)
11 | img = img.crop(crop)
12 | return img
13 |
14 |
15 | def update_intrinsics(intrinsics, top_crop=0.0, left_crop=0.0, scale_width=1.0, scale_height=1.0):
16 | """
17 | Parameters
18 | ----------
19 | intrinsics: torch.Tensor (3, 3)
20 | top_crop: float
21 | left_crop: float
22 | scale_width: float
23 | scale_height: float
24 | """
25 | updated_intrinsics = intrinsics.clone()
26 | # Adjust intrinsics scale due to resizing
27 | updated_intrinsics[0, 0] *= scale_width
28 | updated_intrinsics[0, 2] *= scale_width
29 | updated_intrinsics[1, 1] *= scale_height
30 | updated_intrinsics[1, 2] *= scale_height
31 |
32 | # Adjust principal point due to cropping
33 | updated_intrinsics[0, 2] -= left_crop
34 | updated_intrinsics[1, 2] -= top_crop
35 |
36 | return updated_intrinsics
37 |
38 |
39 | def calculate_birds_eye_view_parameters(x_bounds, y_bounds, z_bounds):
40 | """
41 | Parameters
42 | ----------
43 |     x_bounds: [min, max, step] along the forward direction of the ego-car.
44 |     y_bounds: [min, max, step] along the sides.
45 |     z_bounds: [min, max, step] along the height.
46 |
47 | Returns
48 | -------
49 |     bev_resolution: Bird's-eye view resolution in meters (one value per axis)
50 |     bev_start_position: Bird's-eye view first element (centre of the first cell)
51 |     bev_dimension: Bird's-eye view tensor spatial dimension
52 | """
53 | bev_resolution = torch.tensor([row[2] for row in [x_bounds, y_bounds, z_bounds]])
54 | bev_start_position = torch.tensor([row[0] + row[2] / 2.0 for row in [x_bounds, y_bounds, z_bounds]])
55 | bev_dimension = torch.tensor([(row[1] - row[0]) / row[2] for row in [x_bounds, y_bounds, z_bounds]],
56 | dtype=torch.long)
57 |
58 | return bev_resolution, bev_start_position, bev_dimension
59 |
60 |
61 | def convert_egopose_to_matrix_numpy(egopose):
62 | transformation_matrix = np.zeros((4, 4), dtype=np.float32)
63 | rotation = Quaternion(egopose['rotation']).rotation_matrix
64 | translation = np.array(egopose['translation'])
65 | transformation_matrix[:3, :3] = rotation
66 | transformation_matrix[:3, 3] = translation
67 | transformation_matrix[3, 3] = 1.0
68 | return transformation_matrix
69 |
70 |
71 | def invert_matrix_egopose_numpy(egopose):
72 | """ Compute the inverse transformation of a 4x4 egopose numpy matrix."""
73 | inverse_matrix = np.zeros((4, 4), dtype=np.float32)
74 | rotation = egopose[:3, :3]
75 | translation = egopose[:3, 3]
76 | inverse_matrix[:3, :3] = rotation.T
77 | inverse_matrix[:3, 3] = -np.dot(rotation.T, translation)
78 | inverse_matrix[3, 3] = 1.0
79 | return inverse_matrix
80 |
81 |
82 | def mat2pose_vec(matrix: torch.Tensor):
83 | """
84 | Converts a 4x4 pose matrix into a 6-dof pose vector
85 | Args:
86 |         matrix (torch.Tensor): 4x4 pose matrix
87 |     Returns:
88 |         vector (torch.Tensor): 6-dof pose vector comprising translation components (tx, ty, tz) and
89 | rotation components (rx, ry, rz)
90 | """
91 |
92 | # M[1, 2] = -sinx*cosy, M[2, 2] = +cosx*cosy
93 | rotx = torch.atan2(-matrix[..., 1, 2], matrix[..., 2, 2])
94 |
95 | # M[0, 2] = +siny, M[1, 2] = -sinx*cosy, M[2, 2] = +cosx*cosy
96 | cosy = torch.sqrt(matrix[..., 1, 2] ** 2 + matrix[..., 2, 2] ** 2)
97 | roty = torch.atan2(matrix[..., 0, 2], cosy)
98 |
99 | # M[0, 0] = +cosy*cosz, M[0, 1] = -cosy*sinz
100 | rotz = torch.atan2(-matrix[..., 0, 1], matrix[..., 0, 0])
101 |
102 | rotation = torch.stack((rotx, roty, rotz), dim=-1)
103 |
104 | # Extract translation params
105 | translation = matrix[..., :3, 3]
106 | return torch.cat((translation, rotation), dim=-1)
107 |
108 |
109 | def euler2mat(angle: torch.Tensor):
110 | """Convert euler angles to rotation matrix.
111 | Reference: https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174
112 | Args:
113 | angle: rotation angle along 3 axis (in radians) [Bx3]
114 | Returns:
115 | Rotation matrix corresponding to the euler angles [Bx3x3]
116 | """
117 | shape = angle.shape
118 | angle = angle.view(-1, 3)
119 | x, y, z = angle[:, 0], angle[:, 1], angle[:, 2]
120 |
121 | cosz = torch.cos(z)
122 | sinz = torch.sin(z)
123 |
124 | zeros = torch.zeros_like(z)
125 | ones = torch.ones_like(z)
126 | zmat = torch.stack([cosz, -sinz, zeros, sinz, cosz, zeros, zeros, zeros, ones], dim=1).view(-1, 3, 3)
127 |
128 | cosy = torch.cos(y)
129 | siny = torch.sin(y)
130 |
131 | ymat = torch.stack([cosy, zeros, siny, zeros, ones, zeros, -siny, zeros, cosy], dim=1).view(-1, 3, 3)
132 |
133 | cosx = torch.cos(x)
134 | sinx = torch.sin(x)
135 |
136 | xmat = torch.stack([ones, zeros, zeros, zeros, cosx, -sinx, zeros, sinx, cosx], dim=1).view(-1, 3, 3)
137 |
138 | rot_mat = xmat.bmm(ymat).bmm(zmat)
139 | rot_mat = rot_mat.view(*shape[:-1], 3, 3)
140 | return rot_mat
141 |
142 |
143 | def pose_vec2mat(vec: torch.Tensor):
144 | """
145 | Convert 6DoF parameters to transformation matrix.
146 | Args:
147 | vec: 6DoF parameters in the order of tx, ty, tz, rx, ry, rz [B,6]
148 | Returns:
149 | A transformation matrix [B,4,4]
150 | """
151 | translation = vec[..., :3].unsqueeze(-1) # [...x3x1]
152 | rot = vec[..., 3:].contiguous() # [...x3]
153 | rot_mat = euler2mat(rot) # [...,3,3]
154 | transform_mat = torch.cat([rot_mat, translation], dim=-1) # [...,3,4]
155 | transform_mat = torch.nn.functional.pad(transform_mat, [0, 0, 0, 1], value=0) # [...,4,4]
156 | transform_mat[..., 3, 3] = 1.0
157 | return transform_mat
158 |
159 |
160 | def invert_pose_matrix(x):
161 | """
162 | Parameters
163 | ----------
164 | x: [B, 4, 4] batch of pose matrices
165 |
166 | Returns
167 | -------
168 | out: [B, 4, 4] batch of inverse pose matrices
169 | """
170 | assert len(x.shape) == 3 and x.shape[1:] == (4, 4), 'Only works for batch of pose matrices.'
171 |
172 | transposed_rotation = torch.transpose(x[:, :3, :3], 1, 2)
173 | translation = x[:, :3, 3:]
174 |
175 | inverse_mat = torch.cat([transposed_rotation, -torch.bmm(transposed_rotation, translation)], dim=-1) # [B,3,4]
176 | inverse_mat = torch.nn.functional.pad(inverse_mat, [0, 0, 0, 1], value=0) # [B,4,4]
177 | inverse_mat[..., 3, 3] = 1.0
178 | return inverse_mat
179 |
180 |
181 | def warp_features(x, flow, mode='nearest', spatial_extent=None):
182 | """ Applies a rotation and translation to feature map x.
183 | Args:
184 | x: (b, c, h, w) feature map
185 |         flow: (b, 6) 6DoF pose vector (only the x-y translation and yaw components are used)
186 |         mode: use 'nearest' when dealing with categorical inputs
187 |     Returns:
188 |         in-plane transformed feature map
189 | """
190 | if flow is None:
191 | return x
192 | b, c, h, w = x.shape
193 | # z-rotation
194 | angle = flow[:, 5].clone() # torch.atan2(flow[:, 1, 0], flow[:, 0, 0])
195 | # x-y translation
196 | translation = flow[:, :2].clone() # flow[:, :2, 3]
197 |
198 |     # Normalise the translation by the number of meters spanned by half of the image,
199 |     # because a translation of 1.0 in the affine grid corresponds to half of the image.
200 | translation[:, 0] /= spatial_extent[0]
201 | translation[:, 1] /= spatial_extent[1]
202 | # forward axis is inverted
203 | translation[:, 0] *= -1
204 |
205 | cos_theta = torch.cos(angle)
206 | sin_theta = torch.sin(angle)
207 |
208 | # output = Rot.input + translation
209 | # tx and ty are inverted as is the case when going from real coordinates to numpy coordinates
210 | # translation_pos_0 -> positive value makes the image move to the left
211 | # translation_pos_1 -> positive value makes the image move to the top
212 | # Angle -> positive value in rad makes the image move in the trigonometric way
213 | transformation = torch.stack([cos_theta, -sin_theta, translation[:, 1],
214 | sin_theta, cos_theta, translation[:, 0]], dim=-1).view(b, 2, 3)
215 |
216 |     # Note that a rotation will preserve distances only if height = width. Otherwise there is
217 |     # resizing going on, e.g. a rotation of pi/2 of a 100x200 image will elongate whatever is in
218 |     # the centre of the image.
219 | grid = torch.nn.functional.affine_grid(transformation, size=x.shape, align_corners=False)
220 | warped_x = torch.nn.functional.grid_sample(x, grid.float(), mode=mode, padding_mode='zeros', align_corners=False)
221 |
222 | return warped_x
223 |
224 |
225 | def cumulative_warp_features(x, flow, mode='nearest', spatial_extent=None):
226 | """ Warps a sequence of feature maps by accumulating incremental 2d flow.
227 |
228 | x[:, -1] remains unchanged
229 | x[:, -2] is warped using flow[:, -2]
230 | x[:, -3] is warped using flow[:, -3] @ flow[:, -2]
231 | ...
232 | x[:, 0] is warped using flow[:, 0] @ ... @ flow[:, -3] @ flow[:, -2]
233 |
234 | Args:
235 | x: (b, t, c, h, w) sequence of feature maps
236 | flow: (b, t, 6) sequence of 6 DoF pose
237 |               from t to t+1 (only uses the xy portion)
238 |
239 | """
240 | sequence_length = x.shape[1]
241 | if sequence_length == 1:
242 | return x
243 |
244 | flow = pose_vec2mat(flow)
245 |
246 | out = [x[:, -1]]
247 | cum_flow = flow[:, -2]
248 | for t in reversed(range(sequence_length - 1)):
249 | out.append(warp_features(x[:, t], mat2pose_vec(cum_flow), mode=mode, spatial_extent=spatial_extent))
250 | # @ is the equivalent of torch.bmm
251 | cum_flow = flow[:, t - 1] @ cum_flow
252 |
253 | return torch.stack(out[::-1], 1)
254 |
255 |
256 | def cumulative_warp_features_reverse(x, flow, mode='nearest', spatial_extent=None):
257 | """ Warps a sequence of feature maps by accumulating incremental 2d flow.
258 |
259 | x[:, 0] remains unchanged
260 | x[:, 1] is warped using flow[:, 0].inverse()
261 | x[:, 2] is warped using flow[:, 0].inverse() @ flow[:, 1].inverse()
262 | ...
263 |
264 | Args:
265 | x: (b, t, c, h, w) sequence of feature maps
266 | flow: (b, t, 6) sequence of 6 DoF pose
267 |               from t to t+1 (only uses the xy portion)
268 |
269 | """
270 | flow = pose_vec2mat(flow)
271 |
272 |     out = [x[:, 0]]
273 |
274 |     for i in range(1, x.shape[1]):
275 |         if i == 1:
276 |             cum_flow = invert_pose_matrix(flow[:, 0])
277 |         else:
278 |             cum_flow = cum_flow @ invert_pose_matrix(flow[:, i - 1])
279 |         out.append(warp_features(x[:, i], mat2pose_vec(cum_flow), mode=mode, spatial_extent=spatial_extent))
280 |     return torch.stack(out, 1)
281 |
282 |
283 | class VoxelsSumming(torch.autograd.Function):
284 | """Adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/tools.py#L193"""
285 | @staticmethod
286 | def forward(ctx, x, geometry, ranks):
287 | """The features `x` and `geometry` are ranked by voxel positions."""
288 | # Cumulative sum of all features.
289 | x = x.cumsum(0)
290 |
291 | # Indicates the change of voxel.
292 | mask = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
293 | mask[:-1] = ranks[1:] != ranks[:-1]
294 |
295 | x, geometry = x[mask], geometry[mask]
296 | # Calculate sum of features within a voxel.
297 | x = torch.cat((x[:1], x[1:] - x[:-1]))
298 |
299 | ctx.save_for_backward(mask)
300 | ctx.mark_non_differentiable(geometry)
301 |
302 | return x, geometry
303 |
304 | @staticmethod
305 | def backward(ctx, grad_x, grad_geometry):
306 | (mask,) = ctx.saved_tensors
307 | # Since the operation is summing, we simply need to send gradient
308 | # to all elements that were part of the summation process.
309 | indices = torch.cumsum(mask, 0)
310 | indices[mask] -= 1
311 |
312 | output_grad = grad_x[indices]
313 |
314 | return output_grad, None, None
315 |
--------------------------------------------------------------------------------
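
A quick sanity-check sketch for the pose helpers in geometry.py (illustrative only; the pose values are arbitrary). Composing a pose matrix with its inverse should recover the identity, and mat2pose_vec should undo pose_vec2mat:

import torch

from fiery.utils.geometry import invert_pose_matrix, mat2pose_vec, pose_vec2mat

pose_vec = torch.tensor([[1.0, 0.5, 0.0, 0.0, 0.0, 0.1]])  # tx, ty, tz, rx, ry, rz
pose_mat = pose_vec2mat(pose_vec)                           # [1, 4, 4]

# Composing with the inverse recovers the identity transform.
identity = pose_mat @ invert_pose_matrix(pose_mat)
print(torch.allclose(identity, torch.eye(4).unsqueeze(0), atol=1e-5))  # expected True

# Round trip back to the 6-DoF vector.
print(torch.allclose(mat2pose_vec(pose_mat), pose_vec, atol=1e-5))     # expected True
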
/fiery/layers/temporal.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from fiery.layers.convolutions import ConvBlock
7 | from fiery.utils.geometry import warp_features
8 |
9 |
10 | class SpatialGRU(nn.Module):
11 | """A GRU cell that takes an input tensor [BxTxCxHxW] and an optional previous state and passes a
12 | convolutional gated recurrent unit over the data"""
13 |
14 | def __init__(self, input_size, hidden_size, gru_bias_init=0.0, norm='bn', activation='relu'):
15 | super().__init__()
16 | self.input_size = input_size
17 | self.hidden_size = hidden_size
18 | self.gru_bias_init = gru_bias_init
19 |
20 | self.conv_update = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)
21 | self.conv_reset = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)
22 |
23 | self.conv_state_tilde = ConvBlock(
24 | input_size + hidden_size, hidden_size, kernel_size=3, bias=False, norm=norm, activation=activation
25 | )
26 |
27 | def forward(self, x, state=None, flow=None, mode='bilinear'):
28 | # pylint: disable=unused-argument, arguments-differ
29 | # Check size
30 | assert len(x.size()) == 5, 'Input tensor must be BxTxCxHxW.'
31 | b, timesteps, c, h, w = x.size()
32 | assert c == self.input_size, f'feature sizes must match, got input {c} for layer with size {self.input_size}'
33 |
34 | # recurrent layers
35 | rnn_output = []
36 | rnn_state = torch.zeros(b, self.hidden_size, h, w, device=x.device) if state is None else state
37 | for t in range(timesteps):
38 | x_t = x[:, t]
39 | if flow is not None:
40 | rnn_state = warp_features(rnn_state, flow[:, t], mode=mode)
41 |
42 | # propagate rnn state
43 | rnn_state = self.gru_cell(x_t, rnn_state)
44 | rnn_output.append(rnn_state)
45 |
46 | # reshape rnn output to batch tensor
47 | return torch.stack(rnn_output, dim=1)
48 |
49 | def gru_cell(self, x, state):
50 | # Compute gates
51 | x_and_state = torch.cat([x, state], dim=1)
52 | update_gate = self.conv_update(x_and_state)
53 | reset_gate = self.conv_reset(x_and_state)
54 |         # Add bias to initialise the gates close to the identity function
55 | update_gate = torch.sigmoid(update_gate + self.gru_bias_init)
56 | reset_gate = torch.sigmoid(reset_gate + self.gru_bias_init)
57 |
58 |         # Compute the proposal state; the activation is set by the `activation` argument (can be tanh, ReLU, etc.)
59 | state_tilde = self.conv_state_tilde(torch.cat([x, (1.0 - reset_gate) * state], dim=1))
60 |
61 | output = (1.0 - update_gate) * state + update_gate * state_tilde
62 | return output
63 |
64 |
65 | class CausalConv3d(nn.Module):
66 | def __init__(self, in_channels, out_channels, kernel_size=(2, 3, 3), dilation=(1, 1, 1), bias=False):
67 | super().__init__()
68 | assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.'
69 | time_pad = (kernel_size[0] - 1) * dilation[0]
70 | height_pad = ((kernel_size[1] - 1) * dilation[1]) // 2
71 | width_pad = ((kernel_size[2] - 1) * dilation[2]) // 2
72 |
73 | # Pad temporally on the left
74 | self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0)
75 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, dilation=dilation, stride=1, padding=0, bias=bias)
76 | self.norm = nn.BatchNorm3d(out_channels)
77 | self.activation = nn.ReLU(inplace=True)
78 |
79 | def forward(self, *inputs):
80 | (x,) = inputs
81 | x = self.pad(x)
82 | x = self.conv(x)
83 | x = self.norm(x)
84 | x = self.activation(x)
85 | return x
86 |
87 |
88 | class CausalMaxPool3d(nn.Module):
89 | def __init__(self, kernel_size=(2, 3, 3)):
90 | super().__init__()
91 | assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.'
92 | time_pad = kernel_size[0] - 1
93 | height_pad = (kernel_size[1] - 1) // 2
94 | width_pad = (kernel_size[2] - 1) // 2
95 |
96 | # Pad temporally on the left
97 | self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0)
98 | self.max_pool = nn.MaxPool3d(kernel_size, stride=1)
99 |
100 | def forward(self, *inputs):
101 | (x,) = inputs
102 | x = self.pad(x)
103 | x = self.max_pool(x)
104 | return x
105 |
106 |
107 | def conv_1x1x1_norm_activated(in_channels, out_channels):
108 | """1x1x1 3D convolution, normalization and activation layer."""
109 | return nn.Sequential(
110 | OrderedDict(
111 | [
112 | ('conv', nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False)),
113 | ('norm', nn.BatchNorm3d(out_channels)),
114 | ('activation', nn.ReLU(inplace=True)),
115 | ]
116 | )
117 | )
118 |
119 |
120 | class Bottleneck3D(nn.Module):
121 | """
122 | Defines a bottleneck module with a residual connection
123 | """
124 |
125 | def __init__(self, in_channels, out_channels=None, kernel_size=(2, 3, 3), dilation=(1, 1, 1)):
126 | super().__init__()
127 | bottleneck_channels = in_channels // 2
128 | out_channels = out_channels or in_channels
129 |
130 | self.layers = nn.Sequential(
131 | OrderedDict(
132 | [
133 | # First projection with 1x1 kernel
134 | ('conv_down_project', conv_1x1x1_norm_activated(in_channels, bottleneck_channels)),
135 | # Second conv block
136 | (
137 | 'conv',
138 | CausalConv3d(
139 | bottleneck_channels,
140 | bottleneck_channels,
141 | kernel_size=kernel_size,
142 | dilation=dilation,
143 | bias=False,
144 | ),
145 | ),
146 | # Final projection with 1x1 kernel
147 | ('conv_up_project', conv_1x1x1_norm_activated(bottleneck_channels, out_channels)),
148 | ]
149 | )
150 | )
151 |
152 | if out_channels != in_channels:
153 | self.projection = nn.Sequential(
154 | nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False),
155 | nn.BatchNorm3d(out_channels),
156 | )
157 | else:
158 | self.projection = None
159 |
160 | def forward(self, *args):
161 | (x,) = args
162 | x_residual = self.layers(x)
163 | x_features = self.projection(x) if self.projection is not None else x
164 | return x_residual + x_features
165 |
166 |
167 | class PyramidSpatioTemporalPooling(nn.Module):
168 | """ Spatio-temporal pyramid pooling.
169 |     Performs 3D average pooling followed by a 1x1x1 convolution to reduce the number of channels, then upsampling.
170 |     pool_sizes contains a list of kernel sizes: usually [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)]
171 | """
172 |
173 | def __init__(self, in_channels, reduction_channels, pool_sizes):
174 | super().__init__()
175 | self.features = []
176 | for pool_size in pool_sizes:
177 | assert pool_size[0] == 2, (
178 |                 "Time kernel should be 2 as PyTorch raises an error when " "padding with more than half the kernel size"
179 | )
180 | stride = (1, *pool_size[1:])
181 | padding = (pool_size[0] - 1, 0, 0)
182 | self.features.append(
183 | nn.Sequential(
184 | OrderedDict(
185 | [
186 | # Pad the input tensor but do not take into account zero padding into the average.
187 | (
188 | 'avgpool',
189 | torch.nn.AvgPool3d(
190 | kernel_size=pool_size, stride=stride, padding=padding, count_include_pad=False
191 | ),
192 | ),
193 | ('conv_bn_relu', conv_1x1x1_norm_activated(in_channels, reduction_channels)),
194 | ]
195 | )
196 | )
197 | )
198 | self.features = nn.ModuleList(self.features)
199 |
200 | def forward(self, *inputs):
201 | (x,) = inputs
202 | b, _, t, h, w = x.shape
203 | # Do not include current tensor when concatenating
204 | out = []
205 | for f in self.features:
206 | # Remove unnecessary padded values (time dimension) on the right
207 | x_pool = f(x)[:, :, :-1].contiguous()
208 | c = x_pool.shape[1]
209 | x_pool = nn.functional.interpolate(
210 | x_pool.view(b * t, c, *x_pool.shape[-2:]), (h, w), mode='bilinear', align_corners=False
211 | )
212 | x_pool = x_pool.view(b, c, t, h, w)
213 | out.append(x_pool)
214 | out = torch.cat(out, 1)
215 | return out
216 |
217 |
218 | class TemporalBlock(nn.Module):
219 | """ Temporal block with the following layers:
220 | - 2x3x3, 1x3x3, spatio-temporal pyramid pooling
221 | - dropout
222 | - skip connection.
223 | """
224 |
225 | def __init__(self, in_channels, out_channels=None, use_pyramid_pooling=False, pool_sizes=None):
226 | super().__init__()
227 | self.in_channels = in_channels
228 | self.half_channels = in_channels // 2
229 | self.out_channels = out_channels or self.in_channels
230 | self.kernels = [(2, 3, 3), (1, 3, 3)]
231 |
232 | # Flag for spatio-temporal pyramid pooling
233 | self.use_pyramid_pooling = use_pyramid_pooling
234 |
235 | # 3 convolution paths: 2x3x3, 1x3x3, 1x1x1
236 | self.convolution_paths = []
237 | for kernel_size in self.kernels:
238 | self.convolution_paths.append(
239 | nn.Sequential(
240 | conv_1x1x1_norm_activated(self.in_channels, self.half_channels),
241 | CausalConv3d(self.half_channels, self.half_channels, kernel_size=kernel_size),
242 | )
243 | )
244 | self.convolution_paths.append(conv_1x1x1_norm_activated(self.in_channels, self.half_channels))
245 | self.convolution_paths = nn.ModuleList(self.convolution_paths)
246 |
247 | agg_in_channels = len(self.convolution_paths) * self.half_channels
248 |
249 | if self.use_pyramid_pooling:
250 |             assert pool_sizes is not None, "pool_sizes must contain a list of kernel sizes, but is None."
251 | reduction_channels = self.in_channels // 3
252 | self.pyramid_pooling = PyramidSpatioTemporalPooling(self.in_channels, reduction_channels, pool_sizes)
253 | agg_in_channels += len(pool_sizes) * reduction_channels
254 |
255 | # Feature aggregation
256 | self.aggregation = nn.Sequential(
257 | conv_1x1x1_norm_activated(agg_in_channels, self.out_channels),)
258 |
259 | if self.out_channels != self.in_channels:
260 | self.projection = nn.Sequential(
261 | nn.Conv3d(self.in_channels, self.out_channels, kernel_size=1, bias=False),
262 | nn.BatchNorm3d(self.out_channels),
263 | )
264 | else:
265 | self.projection = None
266 |
267 | def forward(self, *inputs):
268 | (x,) = inputs
269 | x_paths = []
270 | for conv in self.convolution_paths:
271 | x_paths.append(conv(x))
272 | x_residual = torch.cat(x_paths, dim=1)
273 | if self.use_pyramid_pooling:
274 | x_pool = self.pyramid_pooling(x)
275 | x_residual = torch.cat([x_residual, x_pool], dim=1)
276 | x_residual = self.aggregation(x_residual)
277 |
278 | if self.out_channels != self.in_channels:
279 | x = self.projection(x)
280 | x = x + x_residual
281 | return x
282 |
--------------------------------------------------------------------------------
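
An illustrative shape check for the temporal layers above (arbitrary sizes, not repository code): SpatialGRU consumes a BxTxCxHxW tensor and returns one hidden state per timestep, while TemporalBlock keeps the BxCxTxHxW layout unchanged when out_channels equals in_channels:

import torch

from fiery.layers.temporal import SpatialGRU, TemporalBlock

x_gru = torch.randn(2, 3, 8, 32, 32)              # B, T, C, H, W
gru = SpatialGRU(input_size=8, hidden_size=16)
print(gru(x_gru).shape)                           # torch.Size([2, 3, 16, 32, 32])

x_3d = torch.randn(2, 8, 3, 32, 32)               # B, C, T, H, W
block = TemporalBlock(in_channels=8, out_channels=8)
print(block(x_3d).shape)                          # torch.Size([2, 8, 3, 32, 32])
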
/fiery/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from pytorch_lightning.metrics.metric import Metric
5 | from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
6 | from pytorch_lightning.metrics.functional.reduction import reduce
7 |
8 |
9 | class IntersectionOverUnion(Metric):
10 | """Computes intersection-over-union."""
11 | def __init__(
12 | self,
13 | n_classes: int,
14 | ignore_index: Optional[int] = None,
15 | absent_score: float = 0.0,
16 | reduction: str = 'none',
17 | compute_on_step: bool = False,
18 | ):
19 | super().__init__(compute_on_step=compute_on_step)
20 |
21 | self.n_classes = n_classes
22 | self.ignore_index = ignore_index
23 | self.absent_score = absent_score
24 | self.reduction = reduction
25 |
26 | self.add_state('true_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
27 | self.add_state('false_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
28 | self.add_state('false_negative', default=torch.zeros(n_classes), dist_reduce_fx='sum')
29 | self.add_state('support', default=torch.zeros(n_classes), dist_reduce_fx='sum')
30 |
31 | def update(self, prediction: torch.Tensor, target: torch.Tensor):
32 | tps, fps, _, fns, sups = stat_scores_multiple_classes(prediction, target, self.n_classes)
33 |
34 | self.true_positive += tps
35 | self.false_positive += fps
36 | self.false_negative += fns
37 | self.support += sups
38 |
39 | def compute(self):
40 | scores = torch.zeros(self.n_classes, device=self.true_positive.device, dtype=torch.float32)
41 |
42 | for class_idx in range(self.n_classes):
43 | if class_idx == self.ignore_index:
44 | continue
45 |
46 | tp = self.true_positive[class_idx]
47 | fp = self.false_positive[class_idx]
48 | fn = self.false_negative[class_idx]
49 | sup = self.support[class_idx]
50 |
51 | # If this class is absent in the target (no support) AND absent in the pred (no true or false
52 | # positives), then use the absent_score for this class.
53 | if sup + tp + fp == 0:
54 | scores[class_idx] = self.absent_score
55 | continue
56 |
57 | denominator = tp + fp + fn
58 | score = tp.to(torch.float) / denominator
59 | scores[class_idx] = score
60 |
61 | # Remove the ignored class index from the scores.
62 | if (self.ignore_index is not None) and (0 <= self.ignore_index < self.n_classes):
63 | scores = torch.cat([scores[:self.ignore_index], scores[self.ignore_index+1:]])
64 |
65 | return reduce(scores, reduction=self.reduction)
66 |
67 |
68 | class PanopticMetric(Metric):
69 | def __init__(
70 | self,
71 | n_classes: int,
72 | temporally_consistent: bool = True,
73 | vehicles_id: int = 1,
74 | compute_on_step: bool = False,
75 | ):
76 | super().__init__(compute_on_step=compute_on_step)
77 |
78 | self.n_classes = n_classes
79 | self.temporally_consistent = temporally_consistent
80 | self.vehicles_id = vehicles_id
81 | self.keys = ['iou', 'true_positive', 'false_positive', 'false_negative']
82 |
83 | self.add_state('iou', default=torch.zeros(n_classes), dist_reduce_fx='sum')
84 | self.add_state('true_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
85 | self.add_state('false_positive', default=torch.zeros(n_classes), dist_reduce_fx='sum')
86 | self.add_state('false_negative', default=torch.zeros(n_classes), dist_reduce_fx='sum')
87 |
88 | def update(self, pred_instance, gt_instance):
89 | """
90 | Update state with predictions and targets.
91 |
92 | Parameters
93 | ----------
94 | pred_instance: (b, s, h, w)
95 | Temporally consistent instance segmentation prediction.
96 | gt_instance: (b, s, h, w)
97 | Ground truth instance segmentation.
98 | """
99 | batch_size, sequence_length = gt_instance.shape[:2]
100 | # Process labels
101 | assert gt_instance.min() == 0, 'ID 0 of gt_instance must be background'
102 | pred_segmentation = (pred_instance > 0).long()
103 | gt_segmentation = (gt_instance > 0).long()
104 |
105 | for b in range(batch_size):
106 | unique_id_mapping = {}
107 | for t in range(sequence_length):
108 | result = self.panoptic_metrics(
109 | pred_segmentation[b, t].detach(),
110 | pred_instance[b, t].detach(),
111 | gt_segmentation[b, t],
112 | gt_instance[b, t],
113 | unique_id_mapping,
114 | )
115 |
116 | self.iou += result['iou']
117 | self.true_positive += result['true_positive']
118 | self.false_positive += result['false_positive']
119 | self.false_negative += result['false_negative']
120 |
121 | def compute(self):
122 | denominator = torch.maximum(
123 | (self.true_positive + self.false_positive / 2 + self.false_negative / 2),
124 | torch.ones_like(self.true_positive)
125 | )
126 | pq = self.iou / denominator
127 | sq = self.iou / torch.maximum(self.true_positive, torch.ones_like(self.true_positive))
128 | rq = self.true_positive / denominator
129 |
130 | return {'pq': pq,
131 | 'sq': sq,
132 | 'rq': rq,
133 | # If 0, it means there wasn't any detection.
134 | 'denominator': (self.true_positive + self.false_positive / 2 + self.false_negative / 2),
135 | }
136 |
137 | def panoptic_metrics(self, pred_segmentation, pred_instance, gt_segmentation, gt_instance, unique_id_mapping):
138 | """
139 | Computes panoptic quality metric components.
140 |
141 | Parameters
142 | ----------
143 | pred_segmentation: [H, W] range {0, ..., n_classes-1} (>= n_classes is void)
144 | pred_instance: [H, W] range {0, ..., n_instances} (zero means background)
145 | gt_segmentation: [H, W] range {0, ..., n_classes-1} (>= n_classes is void)
146 | gt_instance: [H, W] range {0, ..., n_instances} (zero means background)
147 | unique_id_mapping: instance id mapping to check consistency
148 | """
149 | n_classes = self.n_classes
150 |
151 | result = {key: torch.zeros(n_classes, dtype=torch.float32, device=gt_instance.device) for key in self.keys}
152 |
153 | assert pred_segmentation.dim() == 2
154 | assert pred_segmentation.shape == pred_instance.shape == gt_segmentation.shape == gt_instance.shape
155 |
156 | n_instances = int(torch.cat([pred_instance, gt_instance]).max().item())
157 | n_all_things = n_instances + n_classes # Classes + instances.
158 | n_things_and_void = n_all_things + 1
159 |
160 |         # Now 1 is background; 0 is void (not used). 2 is the vehicle semantic class, but since it overlaps
161 |         # with instances it is not present.
162 |         # The remaining ids are instance ids starting from 3.
163 | prediction, pred_to_cls = self.combine_mask(pred_segmentation, pred_instance, n_classes, n_all_things)
164 | target, target_to_cls = self.combine_mask(gt_segmentation, gt_instance, n_classes, n_all_things)
165 |
166 | # Compute ious between all stuff and things
167 | # hack for bincounting 2 arrays together
168 | x = prediction + n_things_and_void * target
169 | bincount_2d = torch.bincount(x.long(), minlength=n_things_and_void ** 2)
170 | if bincount_2d.shape[0] != n_things_and_void ** 2:
171 | raise ValueError('Incorrect bincount size.')
172 | conf = bincount_2d.reshape((n_things_and_void, n_things_and_void))
173 | # Drop void class
174 | conf = conf[1:, 1:]
175 |
176 | # Confusion matrix contains intersections between all combinations of classes
177 | union = conf.sum(0).unsqueeze(0) + conf.sum(1).unsqueeze(1) - conf
178 | iou = torch.where(union > 0, (conf.float() + 1e-9) / (union.float() + 1e-9), torch.zeros_like(union).float())
179 |
180 | # In the iou matrix, first dimension is target idx, second dimension is pred idx.
181 | # Mapping will contain a tuple that maps prediction idx to target idx for segments matched by iou.
182 | mapping = (iou > 0.5).nonzero(as_tuple=False)
183 |
184 | # Check that classes match.
185 | is_matching = pred_to_cls[mapping[:, 1]] == target_to_cls[mapping[:, 0]]
186 | mapping = mapping[is_matching]
187 | tp_mask = torch.zeros_like(conf, dtype=torch.bool)
188 | tp_mask[mapping[:, 0], mapping[:, 1]] = True
189 |
190 | # First ids correspond to "stuff" i.e. semantic seg.
191 | # Instance ids are offset accordingly
192 | for target_id, pred_id in mapping:
193 | cls_id = pred_to_cls[pred_id]
194 |
195 | if self.temporally_consistent and cls_id == self.vehicles_id:
196 | if target_id.item() in unique_id_mapping and unique_id_mapping[target_id.item()] != pred_id.item():
197 | # Not temporally consistent
198 | result['false_negative'][target_to_cls[target_id]] += 1
199 | result['false_positive'][pred_to_cls[pred_id]] += 1
200 | unique_id_mapping[target_id.item()] = pred_id.item()
201 | continue
202 |
203 | result['true_positive'][cls_id] += 1
204 | result['iou'][cls_id] += iou[target_id][pred_id]
205 | unique_id_mapping[target_id.item()] = pred_id.item()
206 |
207 | for target_id in range(n_classes, n_all_things):
208 | # If this is a true positive do nothing.
209 | if tp_mask[target_id, n_classes:].any():
210 | continue
211 | # If this target instance didn't match with any predictions and was present set it as false negative.
212 | if target_to_cls[target_id] != -1:
213 | result['false_negative'][target_to_cls[target_id]] += 1
214 |
215 | for pred_id in range(n_classes, n_all_things):
216 | # If this is a true positive do nothing.
217 | if tp_mask[n_classes:, pred_id].any():
218 | continue
219 | # If this predicted instance didn't match with any prediction, set that predictions as false positive.
220 | if pred_to_cls[pred_id] != -1 and (conf[:, pred_id] > 0).any():
221 | result['false_positive'][pred_to_cls[pred_id]] += 1
222 |
223 | return result
224 |
225 | def combine_mask(self, segmentation: torch.Tensor, instance: torch.Tensor, n_classes: int, n_all_things: int):
226 |         """Shifts all thing ids by n_classes and combines things and stuff into a single mask
227 |
228 | Returns a combined mask + a mapping from id to segmentation class.
229 | """
230 | instance = instance.view(-1)
231 | instance_mask = instance > 0
232 | instance = instance - 1 + n_classes
233 |
234 | segmentation = segmentation.clone().view(-1)
235 | segmentation_mask = segmentation < n_classes # Remove void pixels.
236 |
237 | # Build an index from instance id to class id.
238 | instance_id_to_class_tuples = torch.cat(
239 | (
240 | instance[instance_mask & segmentation_mask].unsqueeze(1),
241 | segmentation[instance_mask & segmentation_mask].unsqueeze(1),
242 | ),
243 | dim=1,
244 | )
245 | instance_id_to_class = -instance_id_to_class_tuples.new_ones((n_all_things,))
246 | instance_id_to_class[instance_id_to_class_tuples[:, 0]] = instance_id_to_class_tuples[:, 1]
247 | instance_id_to_class[torch.arange(n_classes, device=segmentation.device)] = torch.arange(
248 | n_classes, device=segmentation.device
249 | )
250 |
251 | segmentation[instance_mask] = instance[instance_mask]
252 | segmentation += 1 # Shift all legit classes by 1.
253 | segmentation[~segmentation_mask] = 0 # Shift void class to zero.
254 |
255 | return segmentation, instance_id_to_class
256 |
--------------------------------------------------------------------------------
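
A minimal usage sketch for the IntersectionOverUnion metric above, with a two-class toy example (illustrative only, and assuming the pinned pytorch_lightning version that still provides pytorch_lightning.metrics):

import torch

from fiery.metrics import IntersectionOverUnion

iou_metric = IntersectionOverUnion(n_classes=2)

prediction = torch.tensor([[0, 1],
                           [1, 1]])
target = torch.tensor([[0, 1],
                       [0, 1]])

iou_metric.update(prediction, target)
# Class 0: tp=1, fp=0, fn=1 -> 0.5; class 1: tp=2, fp=1, fn=0 -> 2/3.
print(iou_metric.compute())  # tensor([0.5000, 0.6667])
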
/fiery/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pytorch_lightning as pl
4 |
5 | from fiery.config import get_cfg
6 | from fiery.models.fiery import Fiery
7 | from fiery.losses import ProbabilisticLoss, SpatialRegressionLoss, SegmentationLoss
8 | from fiery.metrics import IntersectionOverUnion, PanopticMetric
9 | from fiery.utils.geometry import cumulative_warp_features_reverse
10 | from fiery.utils.instance import predict_instance_segmentation_and_trajectories
11 | from fiery.utils.visualisation import visualise_output
12 |
13 |
14 | class TrainingModule(pl.LightningModule):
15 | def __init__(self, hparams):
16 | super().__init__()
17 |
18 | # see config.py for details
19 | self.hparams = hparams
20 |         # PyTorch Lightning does not support saving a YACS CfgNode
21 | cfg = get_cfg(cfg_dict=self.hparams)
22 | self.cfg = cfg
23 | self.n_classes = len(self.cfg.SEMANTIC_SEG.WEIGHTS)
24 |
25 | # Bird's-eye view extent in meters
26 | assert self.cfg.LIFT.X_BOUND[1] > 0 and self.cfg.LIFT.Y_BOUND[1] > 0
27 | self.spatial_extent = (self.cfg.LIFT.X_BOUND[1], self.cfg.LIFT.Y_BOUND[1])
28 |
29 | # Model
30 | self.model = Fiery(cfg)
31 |
32 | # Losses
33 | self.losses_fn = nn.ModuleDict()
34 | self.losses_fn['segmentation'] = SegmentationLoss(
35 | class_weights=torch.Tensor(self.cfg.SEMANTIC_SEG.WEIGHTS),
36 | use_top_k=self.cfg.SEMANTIC_SEG.USE_TOP_K,
37 | top_k_ratio=self.cfg.SEMANTIC_SEG.TOP_K_RATIO,
38 | future_discount=self.cfg.FUTURE_DISCOUNT,
39 | )
40 |
41 | # Uncertainty weighting
42 | self.model.segmentation_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
43 |
44 | self.metric_iou_val = IntersectionOverUnion(self.n_classes)
45 |
46 | self.losses_fn['instance_center'] = SpatialRegressionLoss(
47 | norm=2, future_discount=self.cfg.FUTURE_DISCOUNT
48 | )
49 | self.losses_fn['instance_offset'] = SpatialRegressionLoss(
50 | norm=1, future_discount=self.cfg.FUTURE_DISCOUNT, ignore_index=self.cfg.DATASET.IGNORE_INDEX
51 | )
52 |
53 | # Uncertainty weighting
54 | self.model.centerness_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
55 | self.model.offset_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
56 |
57 | self.metric_panoptic_val = PanopticMetric(n_classes=self.n_classes)
58 |
59 | if self.cfg.INSTANCE_FLOW.ENABLED:
60 | self.losses_fn['instance_flow'] = SpatialRegressionLoss(
61 | norm=1, future_discount=self.cfg.FUTURE_DISCOUNT, ignore_index=self.cfg.DATASET.IGNORE_INDEX
62 | )
63 | # Uncertainty weighting
64 | self.model.flow_weight = nn.Parameter(torch.tensor(0.0), requires_grad=True)
65 |
66 | if self.cfg.PROBABILISTIC.ENABLED:
67 | self.losses_fn['probabilistic'] = ProbabilisticLoss()
68 |
69 | self.training_step_count = 0
70 |
71 | def shared_step(self, batch, is_train):
72 | image = batch['image']
73 | intrinsics = batch['intrinsics']
74 | extrinsics = batch['extrinsics']
75 | future_egomotion = batch['future_egomotion']
76 |
77 | # Warp labels
78 | labels, future_distribution_inputs = self.prepare_future_labels(batch)
79 |
80 | # Forward pass
81 | output = self.model(
82 | image, intrinsics, extrinsics, future_egomotion, future_distribution_inputs
83 | )
84 |
85 | #####
86 | # Loss computation
87 | #####
88 | loss = {}
89 | segmentation_factor = 1 / torch.exp(self.model.segmentation_weight)
90 | loss['segmentation'] = segmentation_factor * self.losses_fn['segmentation'](
91 | output['segmentation'], labels['segmentation']
92 | )
93 | loss['segmentation_uncertainty'] = 0.5 * self.model.segmentation_weight
94 |
95 | centerness_factor = 1 / (2*torch.exp(self.model.centerness_weight))
96 | loss['instance_center'] = centerness_factor * self.losses_fn['instance_center'](
97 | output['instance_center'], labels['centerness']
98 | )
99 |
100 | offset_factor = 1 / (2*torch.exp(self.model.offset_weight))
101 | loss['instance_offset'] = offset_factor * self.losses_fn['instance_offset'](
102 | output['instance_offset'], labels['offset']
103 | )
104 |
105 | loss['centerness_uncertainty'] = 0.5 * self.model.centerness_weight
106 | loss['offset_uncertainty'] = 0.5 * self.model.offset_weight
107 |
108 | if self.cfg.INSTANCE_FLOW.ENABLED:
109 | flow_factor = 1 / (2*torch.exp(self.model.flow_weight))
110 | loss['instance_flow'] = flow_factor * self.losses_fn['instance_flow'](
111 | output['instance_flow'], labels['flow']
112 | )
113 |
114 | loss['flow_uncertainty'] = 0.5 * self.model.flow_weight
115 |
116 | if self.cfg.PROBABILISTIC.ENABLED:
117 | loss['probabilistic'] = self.cfg.PROBABILISTIC.WEIGHT * self.losses_fn['probabilistic'](output)
118 |
119 | # Metrics
120 | if not is_train:
121 | seg_prediction = output['segmentation'].detach()
122 | seg_prediction = torch.argmax(seg_prediction, dim=2, keepdims=True)
123 | self.metric_iou_val(seg_prediction, labels['segmentation'])
124 |
125 | pred_consistent_instance_seg = predict_instance_segmentation_and_trajectories(
126 | output, compute_matched_centers=False
127 | )
128 |
129 | self.metric_panoptic_val(pred_consistent_instance_seg, labels['instance'])
130 |
131 | return output, labels, loss
132 |
133 | def prepare_future_labels(self, batch):
134 | labels = {}
135 | future_distribution_inputs = []
136 |
137 | segmentation_labels = batch['segmentation']
138 | instance_center_labels = batch['centerness']
139 | instance_offset_labels = batch['offset']
140 | instance_flow_labels = batch['flow']
141 | gt_instance = batch['instance']
142 | future_egomotion = batch['future_egomotion']
143 |
144 | # Warp labels to present's reference frame
145 | segmentation_labels = cumulative_warp_features_reverse(
146 | segmentation_labels[:, (self.model.receptive_field - 1):].float(),
147 | future_egomotion[:, (self.model.receptive_field - 1):],
148 | mode='nearest', spatial_extent=self.spatial_extent,
149 | ).long().contiguous()
150 | labels['segmentation'] = segmentation_labels
151 | future_distribution_inputs.append(segmentation_labels)
152 |
153 | # Warp instance labels to present's reference frame
154 | gt_instance = cumulative_warp_features_reverse(
155 | gt_instance[:, (self.model.receptive_field - 1):].float().unsqueeze(2),
156 | future_egomotion[:, (self.model.receptive_field - 1):],
157 | mode='nearest', spatial_extent=self.spatial_extent,
158 | ).long().contiguous()[:, :, 0]
159 | labels['instance'] = gt_instance
160 |
161 | instance_center_labels = cumulative_warp_features_reverse(
162 | instance_center_labels[:, (self.model.receptive_field - 1):],
163 | future_egomotion[:, (self.model.receptive_field - 1):],
164 | mode='nearest', spatial_extent=self.spatial_extent,
165 | ).contiguous()
166 | labels['centerness'] = instance_center_labels
167 |
168 | instance_offset_labels = cumulative_warp_features_reverse(
169 |             instance_offset_labels[:, (self.model.receptive_field - 1):],
170 | future_egomotion[:, (self.model.receptive_field - 1):],
171 | mode='nearest', spatial_extent=self.spatial_extent,
172 | ).contiguous()
173 | labels['offset'] = instance_offset_labels
174 |
175 | future_distribution_inputs.append(instance_center_labels)
176 | future_distribution_inputs.append(instance_offset_labels)
177 |
178 | if self.cfg.INSTANCE_FLOW.ENABLED:
179 | instance_flow_labels = cumulative_warp_features_reverse(
180 | instance_flow_labels[:, (self.model.receptive_field - 1):],
181 | future_egomotion[:, (self.model.receptive_field - 1):],
182 | mode='nearest', spatial_extent=self.spatial_extent,
183 | ).contiguous()
184 | labels['flow'] = instance_flow_labels
185 |
186 | future_distribution_inputs.append(instance_flow_labels)
187 |
188 | if len(future_distribution_inputs) > 0:
189 | future_distribution_inputs = torch.cat(future_distribution_inputs, dim=2)
190 |
191 | return labels, future_distribution_inputs
192 |
193 | def visualise(self, labels, output, batch_idx, prefix='train'):
194 | visualisation_video = visualise_output(labels, output, self.cfg)
195 | name = f'{prefix}_outputs'
196 | if prefix == 'val':
197 | name = name + f'_{batch_idx}'
198 | self.logger.experiment.add_video(name, visualisation_video, global_step=self.training_step_count, fps=2)
199 |
200 | def training_step(self, batch, batch_idx):
201 | output, labels, loss = self.shared_step(batch, True)
202 | self.training_step_count += 1
203 | for key, value in loss.items():
204 | self.logger.experiment.add_scalar(key, value, global_step=self.training_step_count)
205 |
206 | if self.training_step_count % self.cfg.VIS_INTERVAL == 0:
207 | self.visualise(labels, output, batch_idx, prefix='train')
208 | return sum(loss.values())
209 |
210 | def validation_step(self, batch, batch_idx):
211 | output, labels, loss = self.shared_step(batch, False)
212 | for key, value in loss.items():
213 | self.log('val_' + key, value)
214 |
215 | if batch_idx == 0:
216 | self.visualise(labels, output, batch_idx, prefix='val')
217 |
218 | def shared_epoch_end(self, step_outputs, is_train):
219 | # log per class iou metrics
220 | class_names = ['background', 'dynamic']
221 | if not is_train:
222 | scores = self.metric_iou_val.compute()
223 | for key, value in zip(class_names, scores):
224 | self.logger.experiment.add_scalar('val_iou_' + key, value, global_step=self.training_step_count)
225 | self.metric_iou_val.reset()
226 |
227 | if not is_train:
228 | scores = self.metric_panoptic_val.compute()
229 | for key, value in scores.items():
230 | for instance_name, score in zip(['background', 'vehicles'], value):
231 | if instance_name != 'background':
232 | self.logger.experiment.add_scalar(f'val_{key}_{instance_name}', score.item(),
233 | global_step=self.training_step_count)
234 | self.metric_panoptic_val.reset()
235 |
236 | self.logger.experiment.add_scalar('segmentation_weight',
237 | 1 / (torch.exp(self.model.segmentation_weight)),
238 | global_step=self.training_step_count)
239 | self.logger.experiment.add_scalar('centerness_weight',
240 | 1 / (2 * torch.exp(self.model.centerness_weight)),
241 | global_step=self.training_step_count)
242 | self.logger.experiment.add_scalar('offset_weight', 1 / (2 * torch.exp(self.model.offset_weight)),
243 | global_step=self.training_step_count)
244 | if self.cfg.INSTANCE_FLOW.ENABLED:
245 | self.logger.experiment.add_scalar('flow_weight', 1 / (2 * torch.exp(self.model.flow_weight)),
246 | global_step=self.training_step_count)
247 |
248 | def training_epoch_end(self, step_outputs):
249 | self.shared_epoch_end(step_outputs, True)
250 |
251 | def validation_epoch_end(self, step_outputs):
252 | self.shared_epoch_end(step_outputs, False)
253 |
254 | def configure_optimizers(self):
255 | params = self.model.parameters()
256 | optimizer = torch.optim.Adam(
257 | params, lr=self.cfg.OPTIMIZER.LR, weight_decay=self.cfg.OPTIMIZER.WEIGHT_DECAY
258 | )
259 |
260 | return optimizer
261 |
--------------------------------------------------------------------------------
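
The per-task scaling in shared_step above follows a learnt homoscedastic-uncertainty scheme: each loss is divided by a factor derived from a learnt log-variance, and a 0.5 * log-variance regulariser is added. A standalone sketch of how the terms combine (toy loss values, not repository code):

import torch

# Learnt log-variances, initialised to zero as in TrainingModule.__init__.
segmentation_weight = torch.nn.Parameter(torch.tensor(0.0))
centerness_weight = torch.nn.Parameter(torch.tensor(0.0))

# Toy task losses standing in for the segmentation and centerness loss outputs.
segmentation_loss = torch.tensor(2.0)
centerness_loss = torch.tensor(0.5)

total = (
    segmentation_loss / torch.exp(segmentation_weight) + 0.5 * segmentation_weight
    + centerness_loss / (2 * torch.exp(centerness_weight)) + 0.5 * centerness_weight
)
print(total)  # tensor(2.2500, grad_fn=<AddBackward0>) while both log-variances are zero
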
/fiery/utils/visualisation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pylab
4 |
5 | from fiery.utils.instance import predict_instance_segmentation_and_trajectories
6 |
7 | DEFAULT_COLORMAP = matplotlib.pylab.cm.jet
8 |
9 |
10 | def flow_to_image(flow: np.ndarray, autoscale: bool = False) -> np.ndarray:
11 | """
12 |     Applies a colour map to flow, which should be a 2 channel image tensor of shape 2xHxW. Returns a HxWx3 numpy image
13 | Code adapted from: https://github.com/liruoteng/FlowNet/blob/master/models/flownet/scripts/flowlib.py
14 | """
15 | u = flow[0, :, :]
16 | v = flow[1, :, :]
17 |
18 | # Convert to polar coordinates
19 | rad = np.sqrt(u ** 2 + v ** 2)
20 | maxrad = np.max(rad)
21 |
22 | # Normalise flow maps
23 | if autoscale:
24 | u /= maxrad + np.finfo(float).eps
25 | v /= maxrad + np.finfo(float).eps
26 |
27 | # visualise flow with cmap
28 | return np.uint8(compute_color(u, v) * 255)
29 |
30 |
31 | def _normalise(image: np.ndarray) -> np.ndarray:
32 | lower = np.min(image)
33 | delta = np.max(image) - lower
34 | if delta == 0:
35 | delta = 1
36 | image = (image.astype(np.float32) - lower) / delta
37 | return image
38 |
39 |
40 | def apply_colour_map(
41 | image: np.ndarray, cmap: matplotlib.colors.LinearSegmentedColormap = DEFAULT_COLORMAP, autoscale: bool = False
42 | ) -> np.ndarray:
43 | """
44 |     Applies a colour map to the given 1, 2 or 3 channel numpy image. If multi-channel, the layout must be CxHxW.
45 | Returns a HxWx3 numpy image
46 | """
47 | if image.ndim == 2 or (image.ndim == 3 and image.shape[0] == 1):
48 | if image.ndim == 3:
49 | image = image[0]
50 | # grayscale scalar image
51 | if autoscale:
52 | image = _normalise(image)
53 | return cmap(image)[:, :, :3]
54 | if image.shape[0] == 2:
55 | # 2 dimensional UV
56 | return flow_to_image(image, autoscale=autoscale)
57 | if image.shape[0] == 3:
58 | # normalise rgb channels
59 | if autoscale:
60 | image = _normalise(image)
61 | return np.transpose(image, axes=[1, 2, 0])
62 | raise Exception('Image must be 1, 2 or 3 channel to convert to colour_map (CxHxW)')
63 |
64 |
65 | def heatmap_image(
66 | image: np.ndarray, cmap: matplotlib.colors.LinearSegmentedColormap = DEFAULT_COLORMAP, autoscale: bool = True
67 | ) -> np.ndarray:
68 |     """Colourise a 1 or 2 channel image with a colourmap."""
69 | if not issubclass(image.dtype.type, np.floating):
70 | raise ValueError(f"Expected a ndarray of float type, but got dtype {image.dtype}")
71 | if not (image.ndim == 2 or (image.ndim == 3 and image.shape[0] in [1, 2])):
72 | raise ValueError(f"Expected a ndarray of shape [H, W] or [1, H, W] or [2, H, W], but got shape {image.shape}")
73 | heatmap_np = apply_colour_map(image, cmap=cmap, autoscale=autoscale)
74 | heatmap_np = np.uint8(heatmap_np * 255)
75 | return heatmap_np
76 |
77 |
78 | def compute_color(u: np.ndarray, v: np.ndarray) -> np.ndarray:
79 | assert u.shape == v.shape
80 | [h, w] = u.shape
81 | img = np.zeros([h, w, 3])
82 | nan_mask = np.isnan(u) | np.isnan(v)
83 | u[nan_mask] = 0
84 | v[nan_mask] = 0
85 |
86 | colorwheel = make_color_wheel()
87 | ncols = np.size(colorwheel, 0)
88 |
89 | rad = np.sqrt(u ** 2 + v ** 2)
90 | a = np.arctan2(-v, -u) / np.pi
91 | f_k = (a + 1) / 2 * (ncols - 1) + 1
92 | k_0 = np.floor(f_k).astype(int)
93 | k_1 = k_0 + 1
94 | k_1[k_1 == ncols + 1] = 1
95 | f = f_k - k_0
96 |
97 | for i in range(0, np.size(colorwheel, 1)):
98 | tmp = colorwheel[:, i]
99 | col0 = tmp[k_0 - 1] / 255
100 | col1 = tmp[k_1 - 1] / 255
101 | col = (1 - f) * col0 + f * col1
102 |
103 | idx = rad <= 1
104 | col[idx] = 1 - rad[idx] * (1 - col[idx])
105 | notidx = np.logical_not(idx)
106 |
107 | col[notidx] *= 0.75
108 | img[:, :, i] = col * (1 - nan_mask)
109 |
110 | return img
111 |
112 |
113 | def make_color_wheel() -> np.ndarray:
114 | """
115 | Create colour wheel.
116 | Code adapted from https://github.com/liruoteng/FlowNet/blob/master/models/flownet/scripts/flowlib.py
117 | """
118 | red_yellow = 15
119 | yellow_green = 6
120 | green_cyan = 4
121 | cyan_blue = 11
122 | blue_magenta = 13
123 | magenta_red = 6
124 |
125 | ncols = red_yellow + yellow_green + green_cyan + cyan_blue + blue_magenta + magenta_red
126 | colorwheel = np.zeros([ncols, 3])
127 |
128 | col = 0
129 |
130 | # red_yellow
131 | colorwheel[0:red_yellow, 0] = 255
132 | colorwheel[0:red_yellow, 1] = np.transpose(np.floor(255 * np.arange(0, red_yellow) / red_yellow))
133 | col += red_yellow
134 |
135 | # yellow_green
136 | colorwheel[col : col + yellow_green, 0] = 255 - np.transpose(
137 | np.floor(255 * np.arange(0, yellow_green) / yellow_green)
138 | )
139 | colorwheel[col : col + yellow_green, 1] = 255
140 | col += yellow_green
141 |
142 | # green_cyan
143 | colorwheel[col : col + green_cyan, 1] = 255
144 | colorwheel[col : col + green_cyan, 2] = np.transpose(np.floor(255 * np.arange(0, green_cyan) / green_cyan))
145 | col += green_cyan
146 |
147 | # cyan_blue
148 | colorwheel[col : col + cyan_blue, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, cyan_blue) / cyan_blue))
149 | colorwheel[col : col + cyan_blue, 2] = 255
150 | col += cyan_blue
151 |
152 | # blue_magenta
153 | colorwheel[col : col + blue_magenta, 2] = 255
154 | colorwheel[col : col + blue_magenta, 0] = np.transpose(np.floor(255 * np.arange(0, blue_magenta) / blue_magenta))
155 |     col += blue_magenta
156 |
157 | # magenta_red
158 | colorwheel[col : col + magenta_red, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, magenta_red) / magenta_red))
159 | colorwheel[col : col + magenta_red, 0] = 255
160 |
161 | return colorwheel
162 |
163 |
164 | def make_contour(img, colour=[0, 0, 0], double_line=False):
165 | h, w = img.shape[:2]
166 | out = img.copy()
167 | # Vertical lines
168 | out[np.arange(h), np.repeat(0, h)] = colour
169 | out[np.arange(h), np.repeat(w - 1, h)] = colour
170 |
171 | # Horizontal lines
172 | out[np.repeat(0, w), np.arange(w)] = colour
173 | out[np.repeat(h - 1, w), np.arange(w)] = colour
174 |
175 | if double_line:
176 | out[np.arange(h), np.repeat(1, h)] = colour
177 | out[np.arange(h), np.repeat(w - 2, h)] = colour
178 |
179 | # Horizontal lines
180 | out[np.repeat(1, w), np.arange(w)] = colour
181 | out[np.repeat(h - 2, w), np.arange(w)] = colour
182 | return out
183 |
184 |
185 | def plot_instance_map(instance_image, instance_map, instance_colours=None, bg_image=None):
186 | if isinstance(instance_image, torch.Tensor):
187 | instance_image = instance_image.cpu().numpy()
188 | assert isinstance(instance_image, np.ndarray)
189 | if instance_colours is None:
190 | instance_colours = generate_instance_colours(instance_map)
191 | if len(instance_image.shape) > 2:
192 | instance_image = instance_image.reshape((instance_image.shape[-2], instance_image.shape[-1]))
193 |
194 | if bg_image is None:
195 | plot_image = 255 * np.ones((instance_image.shape[0], instance_image.shape[1], 3), dtype=np.uint8)
196 | else:
197 | plot_image = bg_image
198 |
199 | for key, value in instance_colours.items():
200 | plot_image[instance_image == key] = value
201 |
202 | return plot_image
203 |
204 |
205 | def visualise_output(labels, output, cfg):
206 | semantic_colours = np.array([[255, 255, 255], [0, 0, 0]], dtype=np.uint8)
207 |
208 | consistent_instance_seg = predict_instance_segmentation_and_trajectories(
209 | output, compute_matched_centers=False
210 | )
211 |
212 | sequence_length = consistent_instance_seg.shape[1]
213 | b = 0
214 | video = []
215 | for t in range(sequence_length):
216 | out_t = []
217 | # Ground truth
218 | unique_ids = torch.unique(labels['instance'][b, t]).cpu().numpy()[1:]
219 | instance_map = dict(zip(unique_ids, unique_ids))
220 | instance_plot = plot_instance_map(labels['instance'][b, t].cpu(), instance_map)[::-1, ::-1]
221 | instance_plot = make_contour(instance_plot)
222 |
223 | semantic_seg = labels['segmentation'].squeeze(2).cpu().numpy()
224 | semantic_plot = semantic_colours[semantic_seg[b, t][::-1, ::-1]]
225 | semantic_plot = make_contour(semantic_plot)
226 |
227 | if cfg.INSTANCE_FLOW.ENABLED:
228 | future_flow_plot = labels['flow'][b, t].cpu().numpy()
229 | future_flow_plot[:, semantic_seg[b, t] != 1] = 0
230 | future_flow_plot = flow_to_image(future_flow_plot)[::-1, ::-1]
231 | future_flow_plot = make_contour(future_flow_plot)
232 | else:
233 | future_flow_plot = np.zeros_like(semantic_plot)
234 |
235 | center_plot = heatmap_image(labels['centerness'][b, t, 0].cpu().numpy())[::-1, ::-1]
236 | center_plot = make_contour(center_plot)
237 |
238 | offset_plot = labels['offset'][b, t].cpu().numpy()
239 | offset_plot[:, semantic_seg[b, t] != 1] = 0
240 | offset_plot = flow_to_image(offset_plot)[::-1, ::-1]
241 | offset_plot = make_contour(offset_plot)
242 |
243 | out_t.append(np.concatenate([instance_plot, future_flow_plot,
244 | semantic_plot, center_plot, offset_plot], axis=0))
245 |
246 | # Predictions
247 | unique_ids = torch.unique(consistent_instance_seg[b, t]).cpu().numpy()[1:]
248 | instance_map = dict(zip(unique_ids, unique_ids))
249 | instance_plot = plot_instance_map(consistent_instance_seg[b, t].cpu(), instance_map)[::-1, ::-1]
250 | instance_plot = make_contour(instance_plot)
251 |
252 | semantic_seg = output['segmentation'].argmax(dim=2).detach().cpu().numpy()
253 | semantic_plot = semantic_colours[semantic_seg[b, t][::-1, ::-1]]
254 | semantic_plot = make_contour(semantic_plot)
255 |
256 | if cfg.INSTANCE_FLOW.ENABLED:
257 | future_flow_plot = output['instance_flow'][b, t].detach().cpu().numpy()
258 | future_flow_plot[:, semantic_seg[b, t] != 1] = 0
259 | future_flow_plot = flow_to_image(future_flow_plot)[::-1, ::-1]
260 | future_flow_plot = make_contour(future_flow_plot)
261 | else:
262 | future_flow_plot = np.zeros_like(semantic_plot)
263 |
264 | center_plot = heatmap_image(output['instance_center'][b, t, 0].detach().cpu().numpy())[::-1, ::-1]
265 | center_plot = make_contour(center_plot)
266 |
267 | offset_plot = output['instance_offset'][b, t].detach().cpu().numpy()
268 | offset_plot[:, semantic_seg[b, t] != 1] = 0
269 | offset_plot = flow_to_image(offset_plot)[::-1, ::-1]
270 | offset_plot = make_contour(offset_plot)
271 |
272 | out_t.append(np.concatenate([instance_plot, future_flow_plot,
273 | semantic_plot, center_plot, offset_plot], axis=0))
274 | out_t = np.concatenate(out_t, axis=1)
275 | # Shape (C, H, W)
276 | out_t = out_t.transpose((2, 0, 1))
277 |
278 | video.append(out_t)
279 |
280 | # Shape (B, T, C, H, W)
281 | video = np.stack(video)[None]
282 | return video
283 |
284 |
285 | def convert_figure_numpy(figure):
286 | """ Convert figure to numpy image """
287 | figure_np = np.frombuffer(figure.canvas.tostring_rgb(), dtype=np.uint8)
288 | figure_np = figure_np.reshape(figure.canvas.get_width_height()[::-1] + (3,))
289 | return figure_np
290 |
291 |
292 | def generate_instance_colours(instance_map):
293 | # Most distinct 22 colors (kelly colors from https://stackoverflow.com/questions/470690/how-to-automatically-generate
294 | # -n-distinct-colors)
295 |     # plus some colours from ADE20k
296 | INSTANCE_COLOURS = np.asarray([
297 | [0, 0, 0],
298 | [255, 179, 0],
299 | [128, 62, 117],
300 | [255, 104, 0],
301 | [166, 189, 215],
302 | [193, 0, 32],
303 | [206, 162, 98],
304 | [129, 112, 102],
305 | [0, 125, 52],
306 | [246, 118, 142],
307 | [0, 83, 138],
308 | [255, 122, 92],
309 | [83, 55, 122],
310 | [255, 142, 0],
311 | [179, 40, 81],
312 | [244, 200, 0],
313 | [127, 24, 13],
314 | [147, 170, 0],
315 | [89, 51, 21],
316 | [241, 58, 19],
317 | [35, 44, 22],
318 | [112, 224, 255],
319 | [70, 184, 160],
320 | [153, 0, 255],
321 | [71, 255, 0],
322 | [255, 0, 163],
323 | [255, 204, 0],
324 | [0, 255, 235],
325 | [255, 0, 235],
326 | [255, 0, 122],
327 | [255, 245, 0],
328 | [10, 190, 212],
329 | [214, 255, 0],
330 | [0, 204, 255],
331 | [20, 0, 255],
332 | [255, 255, 0],
333 | [0, 153, 255],
334 | [0, 255, 204],
335 | [41, 255, 0],
336 | [173, 0, 255],
337 | [0, 245, 255],
338 | [71, 0, 255],
339 | [0, 255, 184],
340 | [0, 92, 255],
341 | [184, 255, 0],
342 | [255, 214, 0],
343 | [25, 194, 194],
344 | [92, 0, 255],
345 | [220, 220, 220],
346 | [255, 9, 92],
347 | [112, 9, 255],
348 | [8, 255, 214],
349 | [255, 184, 6],
350 | [10, 255, 71],
351 | [255, 41, 10],
352 | [7, 255, 255],
353 | [224, 255, 8],
354 | [102, 8, 255],
355 | [255, 61, 6],
356 | [255, 194, 7],
357 | [0, 255, 20],
358 | [255, 8, 41],
359 | [255, 5, 153],
360 | [6, 51, 255],
361 | [235, 12, 255],
362 | [160, 150, 20],
363 | [0, 163, 255],
364 | [140, 140, 140],
365 | [250, 10, 15],
366 | [20, 255, 0],
367 | ])
368 |
369 | return {instance_id: INSTANCE_COLOURS[global_instance_id % len(INSTANCE_COLOURS)] for
370 | instance_id, global_instance_id in instance_map.items()
371 | }
372 |
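Note: generate_instance_colours above keys each instance id to a palette colour indexed by its global id, so a tracked instance keeps the same colour across frames. A minimal usage sketch (illustrative only; the toy instance_map values below are hypothetical, and the import path assumes this file is fiery/utils/visualisation.py):

    import numpy as np
    from fiery.utils.visualisation import generate_instance_colours

    # instance_map: per-frame instance id -> globally consistent id
    instance_map = {1: 1, 2: 2, 3: 7}
    colours = generate_instance_colours(instance_map)
    # colours[3] == INSTANCE_COLOURS[7 % len(INSTANCE_COLOURS)] == [129, 112, 102]
    assert np.array_equal(colours[3], np.asarray([129, 112, 102]))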
--------------------------------------------------------------------------------
/fiery/utils/instance.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from scipy.optimize import linear_sum_assignment
7 |
8 | from fiery.utils.geometry import mat2pose_vec, pose_vec2mat, warp_features
9 |
10 |
11 | # set ignore index to 0 for vis
12 | def convert_instance_mask_to_center_and_offset_label(instance_img, future_egomotion, num_instances, ignore_index=255,
13 | subtract_egomotion=True, sigma=3, spatial_extent=None):
14 | seq_len, h, w = instance_img.shape
15 | center_label = torch.zeros(seq_len, 1, h, w)
16 | offset_label = ignore_index * torch.ones(seq_len, 2, h, w)
17 | future_displacement_label = ignore_index * torch.ones(seq_len, 2, h, w)
18 | # x is vertical displacement, y is horizontal displacement
19 | x, y = torch.meshgrid(torch.arange(h, dtype=torch.float), torch.arange(w, dtype=torch.float))
20 |
21 | if subtract_egomotion:
22 | future_egomotion_inv = mat2pose_vec(pose_vec2mat(future_egomotion).inverse())
23 |
24 | # Compute warped instance segmentation
25 | warped_instance_seg = {}
26 | for t in range(1, seq_len):
27 | warped_inst_t = warp_features(instance_img[t].unsqueeze(0).unsqueeze(1).float(),
28 | future_egomotion_inv[t - 1].unsqueeze(0), mode='nearest',
29 | spatial_extent=spatial_extent)
30 | warped_instance_seg[t] = warped_inst_t[0, 0]
31 |
32 | # Ignore id 0 which is the background
33 | for instance_id in range(1, num_instances+1):
34 | prev_xc = None
35 | prev_yc = None
36 | prev_mask = None
37 | for t in range(seq_len):
38 | instance_mask = (instance_img[t] == instance_id)
39 | if instance_mask.sum() == 0:
40 | # this instance is not in this frame
41 | prev_xc = None
42 | prev_yc = None
43 | prev_mask = None
44 | continue
45 |
46 | xc = x[instance_mask].mean().round().long()
47 | yc = y[instance_mask].mean().round().long()
48 |
49 | off_x = xc - x
50 | off_y = yc - y
51 | g = torch.exp(-(off_x ** 2 + off_y ** 2) / sigma ** 2)
52 | center_label[t, 0] = torch.maximum(center_label[t, 0], g)
53 | offset_label[t, 0, instance_mask] = off_x[instance_mask]
54 | offset_label[t, 1, instance_mask] = off_y[instance_mask]
55 |
56 | if prev_xc is not None:
57 | # old method
58 | # cur_pt = torch.stack((xc, yc)).unsqueeze(0).float()
59 | # if subtract_egomotion:
60 | # cur_pt = warp_points(cur_pt, future_egomotion_inv[t - 1])
61 | # cur_pt = cur_pt.squeeze(0)
62 |
63 | warped_instance_mask = warped_instance_seg[t] == instance_id
64 | if warped_instance_mask.sum() > 0:
65 | warped_xc = x[warped_instance_mask].mean().round()
66 | warped_yc = y[warped_instance_mask].mean().round()
67 |
68 | delta_x = warped_xc - prev_xc
69 | delta_y = warped_yc - prev_yc
70 | future_displacement_label[t - 1, 0, prev_mask] = delta_x
71 | future_displacement_label[t - 1, 1, prev_mask] = delta_y
72 |
73 | prev_xc = xc
74 | prev_yc = yc
75 | prev_mask = instance_mask
76 |
77 | return center_label, offset_label, future_displacement_label
78 |
79 |
80 | def find_instance_centers(center_prediction: torch.Tensor, conf_threshold: float = 0.1, nms_kernel_size: int = 3):
81 | assert len(center_prediction.shape) == 3
82 | center_prediction = F.threshold(center_prediction, threshold=conf_threshold, value=-1)
83 |
84 | nms_padding = (nms_kernel_size - 1) // 2
85 | maxpooled_center_prediction = F.max_pool2d(
86 | center_prediction, kernel_size=nms_kernel_size, stride=1, padding=nms_padding
87 | )
88 |
89 |     # Filter out all elements that are not a local maximum (i.e. the center of an instance in the heatmap)
90 | center_prediction[center_prediction != maxpooled_center_prediction] = -1
91 | return torch.nonzero(center_prediction > 0)[:, 1:]
92 |
93 |
94 | def group_pixels(centers: torch.Tensor, offset_predictions: torch.Tensor) -> torch.Tensor:
95 | width, height = offset_predictions.shape[-2:]
96 | x_grid = (
97 | torch.arange(width, dtype=offset_predictions.dtype, device=offset_predictions.device)
98 | .view(1, width, 1)
99 | .repeat(1, 1, height)
100 | )
101 | y_grid = (
102 | torch.arange(height, dtype=offset_predictions.dtype, device=offset_predictions.device)
103 | .view(1, 1, height)
104 | .repeat(1, width, 1)
105 | )
106 | pixel_grid = torch.cat((x_grid, y_grid), dim=0)
107 | center_locations = (pixel_grid + offset_predictions).view(2, width * height, 1).permute(2, 1, 0)
108 | centers = centers.view(-1, 1, 2)
109 |
110 | distances = torch.norm(centers - center_locations, dim=-1)
111 |
112 | instance_id = torch.argmin(distances, dim=0).reshape(1, width, height) + 1
113 | return instance_id
114 |
115 |
116 | def get_instance_segmentation_and_centers(
117 | center_predictions: torch.Tensor,
118 | offset_predictions: torch.Tensor,
119 | foreground_mask: torch.Tensor,
120 | conf_threshold: float = 0.1,
121 |     nms_kernel_size: int = 3,
122 | max_n_instance_centers: int = 100,
123 | ) -> Tuple[torch.Tensor, torch.Tensor]:
124 | width, height = center_predictions.shape[-2:]
125 | center_predictions = center_predictions.view(1, width, height)
126 | offset_predictions = offset_predictions.view(2, width, height)
127 | foreground_mask = foreground_mask.view(1, width, height)
128 |
129 | centers = find_instance_centers(center_predictions, conf_threshold=conf_threshold, nms_kernel_size=nms_kernel_size)
130 | if not len(centers):
131 | return torch.zeros(center_predictions.shape, dtype=torch.int64, device=center_predictions.device), \
132 | torch.zeros((0, 2), device=centers.device)
133 |
134 | if len(centers) > max_n_instance_centers:
135 | print(f'There are a lot of detected instance centers: {centers.shape}')
136 | centers = centers[:max_n_instance_centers].clone()
137 |
138 | instance_ids = group_pixels(centers, offset_predictions)
139 | instance_seg = (instance_ids * foreground_mask.float()).long()
140 |
141 | # Make the indices of instance_seg consecutive
142 | instance_seg = make_instance_seg_consecutive(instance_seg)
143 |
144 | return instance_seg.long(), centers
145 |
146 |
147 | def update_instance_ids(instance_seg, old_ids, new_ids):
148 | """
149 | Parameters
150 | ----------
151 | instance_seg: torch.Tensor arbitrary shape
152 |     old_ids: 1D tensor containing the old ids, all of which must be present in instance_seg.
153 | new_ids: 1D tensor with the new ids, aligned with old_ids
154 |
155 | Returns
156 | new_instance_seg: torch.Tensor same shape as instance_seg with new ids
157 | """
158 | indices = torch.arange(old_ids.max() + 1, device=instance_seg.device)
159 | for old_id, new_id in zip(old_ids, new_ids):
160 | indices[old_id] = new_id
161 |
162 | return indices[instance_seg].long()
163 |
164 |
165 | def make_instance_seg_consecutive(instance_seg):
166 | # Make the indices of instance_seg consecutive
167 | unique_ids = torch.unique(instance_seg)
168 | new_ids = torch.arange(len(unique_ids), device=instance_seg.device)
169 | instance_seg = update_instance_ids(instance_seg, unique_ids, new_ids)
170 | return instance_seg
171 |
172 |
173 | def make_instance_id_temporally_consistent(pred_inst, future_flow, matching_threshold=3.0):
174 | """
175 | Parameters
176 | ----------
177 | pred_inst: torch.Tensor (1, seq_len, h, w)
178 | future_flow: torch.Tensor(1, seq_len, 2, h, w)
179 | matching_threshold: distance threshold for a match to be valid.
180 |
181 | Returns
182 | -------
183 | consistent_instance_seg: torch.Tensor(1, seq_len, h, w)
184 |
185 |     1. At time t, loop over all detected instances. Use the predicted flow to compute their centers at time t+1.
186 |     2. Store these flow-warped centers.
187 |     3. At time t+1, re-identify instances by comparing the positions of the actual centers with the
188 |        flow-warped centers, and relabel the instances at t+1 so they are consistent with the matching.
189 |     4. Repeat for the next time step.
190 | """
191 | assert pred_inst.shape[0] == 1, 'Assumes batch size = 1'
192 |
193 | # Initialise instance segmentations with prediction corresponding to the present
194 | consistent_instance_seg = [pred_inst[0, 0]]
195 | largest_instance_id = consistent_instance_seg[0].max().item()
196 |
197 | _, seq_len, h, w = pred_inst.shape
198 | device = pred_inst.device
199 | for t in range(seq_len - 1):
200 | # Compute predicted future instance means
201 | grid = torch.stack(torch.meshgrid(
202 | torch.arange(h, dtype=torch.float, device=device), torch.arange(w, dtype=torch.float, device=device)
203 | ))
204 |
205 | # Add future flow
206 | grid = grid + future_flow[0, t]
207 | warped_centers = []
208 | # Go through all ids, except the background
209 | t_instance_ids = torch.unique(consistent_instance_seg[-1])[1:].cpu().numpy()
210 |
211 | if len(t_instance_ids) == 0:
212 | # No instance so nothing to update
213 | consistent_instance_seg.append(pred_inst[0, t + 1])
214 | continue
215 |
216 | for instance_id in t_instance_ids:
217 | instance_mask = (consistent_instance_seg[-1] == instance_id)
218 | warped_centers.append(grid[:, instance_mask].mean(dim=1))
219 | warped_centers = torch.stack(warped_centers)
220 |
221 | # Compute actual future instance means
222 | centers = []
223 | grid = torch.stack(torch.meshgrid(
224 | torch.arange(h, dtype=torch.float, device=device), torch.arange(w, dtype=torch.float, device=device)
225 | ))
226 | n_instances = int(pred_inst[0, t + 1].max().item())
227 |
228 | if n_instances == 0:
229 | # No instance, so nothing to update.
230 | consistent_instance_seg.append(pred_inst[0, t + 1])
231 | continue
232 |
233 | for instance_id in range(1, n_instances + 1):
234 | instance_mask = (pred_inst[0, t + 1] == instance_id)
235 | centers.append(grid[:, instance_mask].mean(dim=1))
236 | centers = torch.stack(centers)
237 |
238 | # Compute distance matrix between warped centers and actual centers
239 | distances = torch.norm(centers.unsqueeze(0) - warped_centers.unsqueeze(1), dim=-1).cpu().numpy()
240 | # outputs (row, col) with row: index in frame t, col: index in frame t+1
241 | # the missing ids in col must be added (correspond to new instances)
242 | ids_t, ids_t_one = linear_sum_assignment(distances)
243 | matching_distances = distances[ids_t, ids_t_one]
244 | # Offset by one as id=0 is the background
245 | ids_t += 1
246 | ids_t_one += 1
247 |
248 |         # Map ids_t back to the real instance ids, since so far they index rows of the distance matrix.
249 | id_mapping = dict(zip(np.arange(1, len(t_instance_ids) + 1), t_instance_ids))
250 | ids_t = np.vectorize(id_mapping.__getitem__, otypes=[np.int64])(ids_t)
251 |
252 |         # Filter out low-quality matches
253 | ids_t = ids_t[matching_distances < matching_threshold]
254 | ids_t_one = ids_t_one[matching_distances < matching_threshold]
255 |
256 | # Elements that are in t+1, but weren't matched
257 | remaining_ids = set(torch.unique(pred_inst[0, t + 1]).cpu().numpy()).difference(set(ids_t_one))
258 | # remove background
259 | remaining_ids.remove(0)
260 | # Set remaining_ids to a new unique id
261 | for remaining_id in list(remaining_ids):
262 | largest_instance_id += 1
263 | ids_t = np.append(ids_t, largest_instance_id)
264 | ids_t_one = np.append(ids_t_one, remaining_id)
265 |
266 | consistent_instance_seg.append(update_instance_ids(pred_inst[0, t + 1], old_ids=ids_t_one, new_ids=ids_t))
267 |
268 | consistent_instance_seg = torch.stack(consistent_instance_seg).unsqueeze(0)
269 | return consistent_instance_seg
270 |
271 |
272 | def predict_instance_segmentation_and_trajectories(
273 | output, compute_matched_centers=False, make_consistent=True, vehicles_id=1,
274 | ):
275 | preds = output['segmentation'].detach()
276 |     preds = torch.argmax(preds, dim=2, keepdim=True)
277 | foreground_masks = preds.squeeze(2) == vehicles_id
278 |
279 | batch_size, seq_len = preds.shape[:2]
280 | pred_inst = []
281 | for b in range(batch_size):
282 | pred_inst_batch = []
283 | for t in range(seq_len):
284 | pred_instance_t, _ = get_instance_segmentation_and_centers(
285 | output['instance_center'][b, t].detach(),
286 | output['instance_offset'][b, t].detach(),
287 | foreground_masks[b, t].detach()
288 | )
289 | pred_inst_batch.append(pred_instance_t)
290 | pred_inst.append(torch.stack(pred_inst_batch, dim=0))
291 |
292 | pred_inst = torch.stack(pred_inst).squeeze(2)
293 |
294 | if make_consistent:
295 | if output['instance_flow'] is None:
296 |             print('Using zero flow because instance_flow is None')
297 | output['instance_flow'] = torch.zeros_like(output['instance_offset'])
298 | consistent_instance_seg = []
299 | for b in range(batch_size):
300 | consistent_instance_seg.append(
301 | make_instance_id_temporally_consistent(pred_inst[b:b+1],
302 | output['instance_flow'][b:b+1].detach())
303 | )
304 | consistent_instance_seg = torch.cat(consistent_instance_seg, dim=0)
305 | else:
306 | consistent_instance_seg = pred_inst
307 |
308 | if compute_matched_centers:
309 | assert batch_size == 1
310 | # Generate trajectories
311 | matched_centers = {}
312 | _, seq_len, h, w = consistent_instance_seg.shape
313 | grid = torch.stack(torch.meshgrid(
314 | torch.arange(h, dtype=torch.float, device=preds.device),
315 | torch.arange(w, dtype=torch.float, device=preds.device)
316 | ))
317 |
318 | for instance_id in torch.unique(consistent_instance_seg[0, 0])[1:].cpu().numpy():
319 | for t in range(seq_len):
320 | instance_mask = consistent_instance_seg[0, t] == instance_id
321 | if instance_mask.sum() > 0:
322 | matched_centers[instance_id] = matched_centers.get(instance_id, []) + [
323 | grid[:, instance_mask].mean(dim=-1)]
324 |
325 | for key, value in matched_centers.items():
326 | matched_centers[key] = torch.stack(value).cpu().numpy()[:, ::-1]
327 |
328 | return consistent_instance_seg, matched_centers
329 |
330 | return consistent_instance_seg
331 |
332 |
333 |
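Note: a shape-level sketch of how the decoding pipeline above (centre NMS, pixel grouping, then temporal matching) is driven from a network output dict. The random tensors are placeholders, so the resulting ids are meaningless; this only illustrates the expected shapes and keys:

    import torch
    from fiery.utils.instance import predict_instance_segmentation_and_trajectories

    b, s, h, w = 1, 3, 16, 16
    output = {
        'segmentation': torch.randn(b, s, 2, h, w),     # background / vehicle logits
        'instance_center': torch.rand(b, s, 1, h, w),   # centerness heatmap
        'instance_offset': torch.randn(b, s, 2, h, w),  # offset to the instance centre
        'instance_flow': torch.randn(b, s, 2, h, w),    # future displacement
    }
    consistent_instance_seg, matched_centers = predict_instance_segmentation_and_trajectories(
        output, compute_matched_centers=True
    )
    # consistent_instance_seg: (1, 3, 16, 16) integer instance ids, consistent over time
    # matched_centers: dict instance_id -> np.ndarray of (x, y) centres over the sequence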
--------------------------------------------------------------------------------
/fiery/models/fiery.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from fiery.models.encoder import Encoder
5 | from fiery.models.temporal_model import TemporalModelIdentity, TemporalModel
6 | from fiery.models.distributions import DistributionModule
7 | from fiery.models.future_prediction import FuturePrediction
8 | from fiery.models.decoder import Decoder
9 | from fiery.utils.network import pack_sequence_dim, unpack_sequence_dim, set_bn_momentum
10 | from fiery.utils.geometry import cumulative_warp_features, calculate_birds_eye_view_parameters, VoxelsSumming
11 |
12 |
13 | class Fiery(nn.Module):
14 | def __init__(self, cfg):
15 | super().__init__()
16 | self.cfg = cfg
17 |
18 | bev_resolution, bev_start_position, bev_dimension = calculate_birds_eye_view_parameters(
19 | self.cfg.LIFT.X_BOUND, self.cfg.LIFT.Y_BOUND, self.cfg.LIFT.Z_BOUND
20 | )
21 | self.bev_resolution = nn.Parameter(bev_resolution, requires_grad=False)
22 | self.bev_start_position = nn.Parameter(bev_start_position, requires_grad=False)
23 | self.bev_dimension = nn.Parameter(bev_dimension, requires_grad=False)
24 |
25 | self.encoder_downsample = self.cfg.MODEL.ENCODER.DOWNSAMPLE
26 | self.encoder_out_channels = self.cfg.MODEL.ENCODER.OUT_CHANNELS
27 |
28 | self.frustum = self.create_frustum()
29 | self.depth_channels, _, _, _ = self.frustum.shape
30 |
31 | if self.cfg.TIME_RECEPTIVE_FIELD == 1:
32 | assert self.cfg.MODEL.TEMPORAL_MODEL.NAME == 'identity'
33 |
34 | # temporal block
35 | self.receptive_field = self.cfg.TIME_RECEPTIVE_FIELD
36 | self.n_future = self.cfg.N_FUTURE_FRAMES
37 | self.latent_dim = self.cfg.MODEL.DISTRIBUTION.LATENT_DIM
38 |
39 | if self.cfg.MODEL.SUBSAMPLE:
40 | assert self.cfg.DATASET.NAME == 'lyft'
41 | self.receptive_field = 3
42 | self.n_future = 5
43 |
44 | # Spatial extent in bird's-eye view, in meters
45 | self.spatial_extent = (self.cfg.LIFT.X_BOUND[1], self.cfg.LIFT.Y_BOUND[1])
46 | self.bev_size = (self.bev_dimension[0].item(), self.bev_dimension[1].item())
47 |
48 | # Encoder
49 | self.encoder = Encoder(cfg=self.cfg.MODEL.ENCODER, D=self.depth_channels)
50 |
51 | # Temporal model
52 | temporal_in_channels = self.encoder_out_channels
53 | if self.cfg.MODEL.TEMPORAL_MODEL.INPUT_EGOPOSE:
54 | temporal_in_channels += 6
55 | if self.cfg.MODEL.TEMPORAL_MODEL.NAME == 'identity':
56 | self.temporal_model = TemporalModelIdentity(temporal_in_channels, self.receptive_field)
57 | elif cfg.MODEL.TEMPORAL_MODEL.NAME == 'temporal_block':
58 | self.temporal_model = TemporalModel(
59 | temporal_in_channels,
60 | self.receptive_field,
61 | input_shape=self.bev_size,
62 | start_out_channels=self.cfg.MODEL.TEMPORAL_MODEL.START_OUT_CHANNELS,
63 | extra_in_channels=self.cfg.MODEL.TEMPORAL_MODEL.EXTRA_IN_CHANNELS,
64 | n_spatial_layers_between_temporal_layers=self.cfg.MODEL.TEMPORAL_MODEL.INBETWEEN_LAYERS,
65 | use_pyramid_pooling=self.cfg.MODEL.TEMPORAL_MODEL.PYRAMID_POOLING,
66 | )
67 | else:
68 | raise NotImplementedError(f'Temporal module {self.cfg.MODEL.TEMPORAL_MODEL.NAME}.')
69 |
70 | self.future_pred_in_channels = self.temporal_model.out_channels
71 | if self.n_future > 0:
72 | # probabilistic sampling
73 | if self.cfg.PROBABILISTIC.ENABLED:
74 | # Distribution networks
75 | self.present_distribution = DistributionModule(
76 | self.future_pred_in_channels,
77 | self.latent_dim,
78 | min_log_sigma=self.cfg.MODEL.DISTRIBUTION.MIN_LOG_SIGMA,
79 | max_log_sigma=self.cfg.MODEL.DISTRIBUTION.MAX_LOG_SIGMA,
80 | )
81 |
82 | future_distribution_in_channels = (self.future_pred_in_channels
83 | + self.n_future * self.cfg.PROBABILISTIC.FUTURE_DIM
84 | )
85 | self.future_distribution = DistributionModule(
86 | future_distribution_in_channels,
87 | self.latent_dim,
88 | min_log_sigma=self.cfg.MODEL.DISTRIBUTION.MIN_LOG_SIGMA,
89 | max_log_sigma=self.cfg.MODEL.DISTRIBUTION.MAX_LOG_SIGMA,
90 | )
91 |
92 | # Future prediction
93 | self.future_prediction = FuturePrediction(
94 | in_channels=self.future_pred_in_channels,
95 | latent_dim=self.latent_dim,
96 | n_gru_blocks=self.cfg.MODEL.FUTURE_PRED.N_GRU_BLOCKS,
97 | n_res_layers=self.cfg.MODEL.FUTURE_PRED.N_RES_LAYERS,
98 | )
99 |
100 | # Decoder
101 | self.decoder = Decoder(
102 | in_channels=self.future_pred_in_channels,
103 | n_classes=len(self.cfg.SEMANTIC_SEG.WEIGHTS),
104 | predict_future_flow=self.cfg.INSTANCE_FLOW.ENABLED,
105 | )
106 |
107 | set_bn_momentum(self, self.cfg.MODEL.BN_MOMENTUM)
108 |
109 | def create_frustum(self):
110 | # Create grid in image plane
111 | h, w = self.cfg.IMAGE.FINAL_DIM
112 | downsampled_h, downsampled_w = h // self.encoder_downsample, w // self.encoder_downsample
113 |
114 | # Depth grid
115 | depth_grid = torch.arange(*self.cfg.LIFT.D_BOUND, dtype=torch.float)
116 | depth_grid = depth_grid.view(-1, 1, 1).expand(-1, downsampled_h, downsampled_w)
117 | n_depth_slices = depth_grid.shape[0]
118 |
119 | # x and y grids
120 | x_grid = torch.linspace(0, w - 1, downsampled_w, dtype=torch.float)
121 | x_grid = x_grid.view(1, 1, downsampled_w).expand(n_depth_slices, downsampled_h, downsampled_w)
122 | y_grid = torch.linspace(0, h - 1, downsampled_h, dtype=torch.float)
123 | y_grid = y_grid.view(1, downsampled_h, 1).expand(n_depth_slices, downsampled_h, downsampled_w)
124 |
125 | # Dimension (n_depth_slices, downsampled_h, downsampled_w, 3)
126 | # containing data points in the image: left-right, top-bottom, depth
127 | frustum = torch.stack((x_grid, y_grid, depth_grid), -1)
128 | return nn.Parameter(frustum, requires_grad=False)
129 |
130 | def forward(self, image, intrinsics, extrinsics, future_egomotion, future_distribution_inputs=None, noise=None):
131 | output = {}
132 |
133 | # Only process features from the past and present
134 | image = image[:, :self.receptive_field].contiguous()
135 | intrinsics = intrinsics[:, :self.receptive_field].contiguous()
136 | extrinsics = extrinsics[:, :self.receptive_field].contiguous()
137 | future_egomotion = future_egomotion[:, :self.receptive_field].contiguous()
138 |
139 |         # Lift features and project them to bird's-eye view
140 | x = self.calculate_birds_eye_view_features(image, intrinsics, extrinsics)
141 |
142 | # Warp past features to the present's reference frame
143 | x = cumulative_warp_features(
144 | x.clone(), future_egomotion,
145 | mode='bilinear', spatial_extent=self.spatial_extent,
146 | )
147 |
148 | if self.cfg.MODEL.TEMPORAL_MODEL.INPUT_EGOPOSE:
149 | b, s, c = future_egomotion.shape
150 | h, w = x.shape[-2:]
151 | future_egomotions_spatial = future_egomotion.view(b, s, c, 1, 1).expand(b, s, c, h, w)
152 |             # At time 0 there is no past egomotion, so feed a zero vector
153 | future_egomotions_spatial = torch.cat([torch.zeros_like(future_egomotions_spatial[:, :1]),
154 | future_egomotions_spatial[:, :(self.receptive_field-1)]], dim=1)
155 | x = torch.cat([x, future_egomotions_spatial], dim=-3)
156 |
157 | # Temporal model
158 | states = self.temporal_model(x)
159 |
160 | if self.n_future > 0:
161 | present_state = states[:, :1].contiguous()
162 | if self.cfg.PROBABILISTIC.ENABLED:
163 | # Do probabilistic computation
164 | sample, output_distribution = self.distribution_forward(
165 | present_state, future_distribution_inputs, noise
166 | )
167 | output = {**output, **output_distribution}
168 |
169 | # Prepare future prediction input
170 | b, _, _, h, w = present_state.shape
171 | hidden_state = present_state[:, 0]
172 |
173 | if self.cfg.PROBABILISTIC.ENABLED:
174 | future_prediction_input = sample.expand(-1, self.n_future, -1, -1, -1)
175 | else:
176 | future_prediction_input = hidden_state.new_zeros(b, self.n_future, self.latent_dim, h, w)
177 |
178 | # Recursively predict future states
179 | future_states = self.future_prediction(future_prediction_input, hidden_state)
180 |
181 | # Concatenate present state
182 | future_states = torch.cat([present_state, future_states], dim=1)
183 |
184 | # Predict bird's-eye view outputs
185 | if self.n_future > 0:
186 | bev_output = self.decoder(future_states)
187 | else:
188 | bev_output = self.decoder(states[:, -1:])
189 | output = {**output, **bev_output}
190 |
191 | return output
192 |
193 | def get_geometry(self, intrinsics, extrinsics):
194 | """Calculate the (x, y, z) 3D position of the features.
195 | """
196 | rotation, translation = extrinsics[..., :3, :3], extrinsics[..., :3, 3]
197 | B, N, _ = translation.shape
198 | # Add batch, camera dimension, and a dummy dimension at the end
199 | points = self.frustum.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
200 |
201 | # Camera to ego reference frame
202 | points = torch.cat((points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3], points[:, :, :, :, :, 2:3]), 5)
203 | combined_transformation = rotation.matmul(torch.inverse(intrinsics))
204 | points = combined_transformation.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
205 | points += translation.view(B, N, 1, 1, 1, 3)
206 |
207 |         # The 3 dimensions in the ego reference frame are: (forward, sideways, height)
208 | return points
209 |
210 | def encoder_forward(self, x):
211 | # batch, n_cameras, channels, height, width
212 | b, n, c, h, w = x.shape
213 |
214 | x = x.view(b * n, c, h, w)
215 | x = self.encoder(x)
216 | x = x.view(b, n, *x.shape[1:])
217 | x = x.permute(0, 1, 3, 4, 5, 2)
218 |
219 | return x
220 |
221 | def projection_to_birds_eye_view(self, x, geometry):
222 | """ Adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/models.py#L200"""
223 | # batch, n_cameras, depth, height, width, channels
224 | batch, n, d, h, w, c = x.shape
225 | output = torch.zeros(
226 | (batch, c, self.bev_dimension[0], self.bev_dimension[1]), dtype=torch.float, device=x.device
227 | )
228 |
229 | # Number of 3D points
230 | N = n * d * h * w
231 | for b in range(batch):
232 | # flatten x
233 | x_b = x[b].reshape(N, c)
234 |
235 | # Convert positions to integer indices
236 | geometry_b = ((geometry[b] - (self.bev_start_position - self.bev_resolution / 2.0)) / self.bev_resolution)
237 | geometry_b = geometry_b.view(N, 3).long()
238 |
239 | # Mask out points that are outside the considered spatial extent.
240 | mask = (
241 | (geometry_b[:, 0] >= 0)
242 | & (geometry_b[:, 0] < self.bev_dimension[0])
243 | & (geometry_b[:, 1] >= 0)
244 | & (geometry_b[:, 1] < self.bev_dimension[1])
245 | & (geometry_b[:, 2] >= 0)
246 | & (geometry_b[:, 2] < self.bev_dimension[2])
247 | )
248 | x_b = x_b[mask]
249 | geometry_b = geometry_b[mask]
250 |
251 |             # Sort tensors so that those within the same voxel are consecutive.
252 | ranks = (
253 | geometry_b[:, 0] * (self.bev_dimension[1] * self.bev_dimension[2])
254 | + geometry_b[:, 1] * (self.bev_dimension[2])
255 | + geometry_b[:, 2]
256 | )
257 | ranks_indices = ranks.argsort()
258 | x_b, geometry_b, ranks = x_b[ranks_indices], geometry_b[ranks_indices], ranks[ranks_indices]
259 |
260 | # Project to bird's-eye view by summing voxels.
261 | x_b, geometry_b = VoxelsSumming.apply(x_b, geometry_b, ranks)
262 |
263 | bev_feature = torch.zeros((self.bev_dimension[2], self.bev_dimension[0], self.bev_dimension[1], c),
264 | device=x_b.device)
265 | bev_feature[geometry_b[:, 2], geometry_b[:, 0], geometry_b[:, 1]] = x_b
266 |
267 | # Put channel in second position and remove z dimension
268 | bev_feature = bev_feature.permute((0, 3, 1, 2))
269 | bev_feature = bev_feature.squeeze(0)
270 |
271 | output[b] = bev_feature
272 |
273 | return output
274 |
275 | def calculate_birds_eye_view_features(self, x, intrinsics, extrinsics):
276 | b, s, n, c, h, w = x.shape
277 | # Reshape
278 | x = pack_sequence_dim(x)
279 | intrinsics = pack_sequence_dim(intrinsics)
280 | extrinsics = pack_sequence_dim(extrinsics)
281 |
282 | geometry = self.get_geometry(intrinsics, extrinsics)
283 | x = self.encoder_forward(x)
284 | x = self.projection_to_birds_eye_view(x, geometry)
285 | x = unpack_sequence_dim(x, b, s)
286 | return x
287 |
288 | def distribution_forward(self, present_features, future_distribution_inputs=None, noise=None):
289 | """
290 | Parameters
291 | ----------
292 | present_features: 5-D output from dynamics module with shape (b, 1, c, h, w)
293 |         future_distribution_inputs: 5-D tensor containing the labels, shape (b, s, cfg.PROBABILISTIC.FUTURE_DIM, h, w)
294 |         noise: a sample from a standard Gaussian with shape (b, s, latent_dim). If None, it is sampled inside this function.
295 |
296 | Returns
297 | -------
298 | sample: sample taken from present/future distribution, broadcast to shape (b, s, latent_dim, h, w)
299 | present_distribution_mu: shape (b, s, latent_dim)
300 | present_distribution_log_sigma: shape (b, s, latent_dim)
301 | future_distribution_mu: shape (b, s, latent_dim)
302 | future_distribution_log_sigma: shape (b, s, latent_dim)
303 | """
304 | b, s, _, h, w = present_features.size()
305 | assert s == 1
306 |
307 | present_mu, present_log_sigma = self.present_distribution(present_features)
308 |
309 | future_mu, future_log_sigma = None, None
310 | if future_distribution_inputs is not None:
311 | # Concatenate future labels to z_t
312 | future_features = future_distribution_inputs[:, 1:].contiguous().view(b, 1, -1, h, w)
313 | future_features = torch.cat([present_features, future_features], dim=2)
314 | future_mu, future_log_sigma = self.future_distribution(future_features)
315 |
316 | if noise is None:
317 | if self.training:
318 | noise = torch.randn_like(present_mu)
319 | else:
320 | noise = torch.zeros_like(present_mu)
321 | if self.training:
322 | mu = future_mu
323 | sigma = torch.exp(future_log_sigma)
324 | else:
325 | mu = present_mu
326 | sigma = torch.exp(present_log_sigma)
327 | sample = mu + sigma * noise
328 |
329 | # Spatially broadcast sample to the dimensions of present_features
330 | sample = sample.view(b, s, self.latent_dim, 1, 1).expand(b, s, self.latent_dim, h, w)
331 |
332 | output_distribution = {
333 | 'present_mu': present_mu,
334 | 'present_log_sigma': present_log_sigma,
335 | 'future_mu': future_mu,
336 | 'future_log_sigma': future_log_sigma,
337 | }
338 |
339 | return sample, output_distribution
340 |
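Note: projection_to_birds_eye_view uses the lift-splat pooling trick: every 3D point is given a scalar rank that is unique per BEV voxel, so a single argsort brings all points belonging to the same voxel next to each other before VoxelsSumming accumulates them. A small standalone illustration with toy dimensions (not the repository's code path):

    import torch

    bev_dimension = torch.tensor([4, 4, 2])  # toy (X, Y, Z) grid size
    geometry_b = torch.tensor([[1, 2, 0],    # point A
                               [3, 0, 1],    # point B
                               [1, 2, 0]])   # point C, same voxel as A
    ranks = (geometry_b[:, 0] * (bev_dimension[1] * bev_dimension[2])
             + geometry_b[:, 1] * bev_dimension[2]
             + geometry_b[:, 2])
    order = ranks.argsort()
    print(ranks[order])  # tensor([12, 12, 25]) -> equal ranks are adjacent and can be summed per voxel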
--------------------------------------------------------------------------------
/fiery/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 |
4 | import numpy as np
5 | import cv2
6 | import torch
7 | import torchvision
8 |
9 | from pyquaternion import Quaternion
10 | from nuscenes.nuscenes import NuScenes
11 | from nuscenes.utils.splits import create_splits_scenes
12 | from nuscenes.utils.data_classes import Box
13 | from lyft_dataset_sdk.lyftdataset import LyftDataset
14 |
15 | from fiery.utils.geometry import (
16 | resize_and_crop_image,
17 | update_intrinsics,
18 | calculate_birds_eye_view_parameters,
19 | convert_egopose_to_matrix_numpy,
20 | pose_vec2mat,
21 | mat2pose_vec,
22 | invert_matrix_egopose_numpy,
23 | )
24 | from fiery.utils.instance import convert_instance_mask_to_center_and_offset_label
25 | from fiery.utils.lyft_splits import TRAIN_LYFT_INDICES, VAL_LYFT_INDICES
26 |
27 |
28 | class FuturePredictionDataset(torch.utils.data.Dataset):
29 | def __init__(self, nusc, is_train, cfg):
30 | self.nusc = nusc
31 | self.is_train = is_train
32 | self.cfg = cfg
33 |
34 | self.is_lyft = isinstance(nusc, LyftDataset)
35 |
36 | if self.is_lyft:
37 | self.dataroot = self.nusc.data_path
38 | else:
39 | self.dataroot = self.nusc.dataroot
40 |
41 | self.mode = 'train' if self.is_train else 'val'
42 |
43 | self.sequence_length = cfg.TIME_RECEPTIVE_FIELD + cfg.N_FUTURE_FRAMES
44 |
45 | self.scenes = self.get_scenes()
46 | self.ixes = self.prepro()
47 | self.indices = self.get_indices()
48 |
49 | # Image resizing and cropping
50 | self.augmentation_parameters = self.get_resizing_and_cropping_parameters()
51 |
52 | # Normalising input images
53 | self.normalise_image = torchvision.transforms.Compose(
54 | [torchvision.transforms.ToTensor(),
55 | torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
56 | ]
57 | )
58 |
59 | # Bird's-eye view parameters
60 | bev_resolution, bev_start_position, bev_dimension = calculate_birds_eye_view_parameters(
61 | cfg.LIFT.X_BOUND, cfg.LIFT.Y_BOUND, cfg.LIFT.Z_BOUND
62 | )
63 | self.bev_resolution, self.bev_start_position, self.bev_dimension = (
64 | bev_resolution.numpy(), bev_start_position.numpy(), bev_dimension.numpy()
65 | )
66 |
67 | # Spatial extent in bird's-eye view, in meters
68 | self.spatial_extent = (self.cfg.LIFT.X_BOUND[1], self.cfg.LIFT.Y_BOUND[1])
69 |
70 | def get_scenes(self):
71 |
72 | if self.is_lyft:
73 | scenes = [row['name'] for row in self.nusc.scene]
74 |
75 | # Split in train/val
76 | indices = TRAIN_LYFT_INDICES if self.is_train else VAL_LYFT_INDICES
77 | scenes = [scenes[i] for i in indices]
78 | else:
79 | # filter by scene split
80 | split = {'v1.0-trainval': {True: 'train', False: 'val'},
81 | 'v1.0-mini': {True: 'mini_train', False: 'mini_val'},}[
82 | self.nusc.version
83 | ][self.is_train]
84 |
85 | scenes = create_splits_scenes()[split]
86 |
87 | return scenes
88 |
89 | def prepro(self):
90 | samples = [samp for samp in self.nusc.sample]
91 |
92 | # remove samples that aren't in this split
93 | samples = [samp for samp in samples if self.nusc.get('scene', samp['scene_token'])['name'] in self.scenes]
94 |
95 | # sort by scene, timestamp (only to make chronological viz easier)
96 | samples.sort(key=lambda x: (x['scene_token'], x['timestamp']))
97 |
98 | return samples
99 |
100 | def get_indices(self):
101 | indices = []
102 | for index in range(len(self.ixes)):
103 | is_valid_data = True
104 | previous_rec = None
105 | current_indices = []
106 | for t in range(self.sequence_length):
107 | index_t = index + t
108 | # Going over the dataset size limit.
109 | if index_t >= len(self.ixes):
110 | is_valid_data = False
111 | break
112 | rec = self.ixes[index_t]
113 | # Check if scene is the same
114 | if (previous_rec is not None) and (rec['scene_token'] != previous_rec['scene_token']):
115 | is_valid_data = False
116 | break
117 |
118 | current_indices.append(index_t)
119 | previous_rec = rec
120 |
121 | if is_valid_data:
122 | indices.append(current_indices)
123 |
124 | return np.asarray(indices)
125 |
126 | def get_resizing_and_cropping_parameters(self):
127 | original_height, original_width = self.cfg.IMAGE.ORIGINAL_HEIGHT, self.cfg.IMAGE.ORIGINAL_WIDTH
128 | final_height, final_width = self.cfg.IMAGE.FINAL_DIM
129 |
130 | resize_scale = self.cfg.IMAGE.RESIZE_SCALE
131 | resize_dims = (int(original_width * resize_scale), int(original_height * resize_scale))
132 | resized_width, resized_height = resize_dims
133 |
134 | crop_h = self.cfg.IMAGE.TOP_CROP
135 | crop_w = int(max(0, (resized_width - final_width) / 2))
136 | # Left, top, right, bottom crops.
137 | crop = (crop_w, crop_h, crop_w + final_width, crop_h + final_height)
138 |
139 | if resized_width != final_width:
140 | print('Zero padding left and right parts of the image.')
141 | if crop_h + final_height != resized_height:
142 | print('Zero padding bottom part of the image.')
143 |
144 | return {'scale_width': resize_scale,
145 | 'scale_height': resize_scale,
146 | 'resize_dims': resize_dims,
147 | 'crop': crop,
148 | }
149 |
150 | def get_input_data(self, rec):
151 | """
152 | Parameters
153 | ----------
154 | rec: nuscenes identifier for a given timestamp
155 |
156 | Returns
157 | -------
158 | images: torch.Tensor (N, 3, H, W)
159 | intrinsics: torch.Tensor (3, 3)
160 | extrinsics: torch.Tensor(N, 4, 4)
161 | """
162 | images = []
163 | intrinsics = []
164 | extrinsics = []
165 | cameras = self.cfg.IMAGE.NAMES
166 |
167 | # The extrinsics we want are from the camera sensor to "flat egopose" as defined
168 | # https://github.com/nutonomy/nuscenes-devkit/blob/9b492f76df22943daf1dc991358d3d606314af27/python-sdk/nuscenes/nuscenes.py#L279
169 | # which corresponds to the position of the lidar.
170 | # This is because the labels are generated by projecting the 3D bounding box in this lidar's reference frame.
171 |
172 | # From lidar egopose to world.
173 | lidar_sample = self.nusc.get('sample_data', rec['data']['LIDAR_TOP'])
174 | lidar_pose = self.nusc.get('ego_pose', lidar_sample['ego_pose_token'])
175 | yaw = Quaternion(lidar_pose['rotation']).yaw_pitch_roll[0]
176 | lidar_rotation = Quaternion(scalar=np.cos(yaw / 2), vector=[0, 0, np.sin(yaw / 2)])
177 | lidar_translation = np.array(lidar_pose['translation'])[:, None]
178 | lidar_to_world = np.vstack([
179 | np.hstack((lidar_rotation.rotation_matrix, lidar_translation)),
180 | np.array([0, 0, 0, 1])
181 | ])
182 |
183 | for cam in cameras:
184 | camera_sample = self.nusc.get('sample_data', rec['data'][cam])
185 |
186 | # Transformation from world to egopose
187 | car_egopose = self.nusc.get('ego_pose', camera_sample['ego_pose_token'])
188 | egopose_rotation = Quaternion(car_egopose['rotation']).inverse
189 | egopose_translation = -np.array(car_egopose['translation'])[:, None]
190 | world_to_car_egopose = np.vstack([
191 | np.hstack((egopose_rotation.rotation_matrix, egopose_rotation.rotation_matrix @ egopose_translation)),
192 | np.array([0, 0, 0, 1])
193 | ])
194 |
195 | # From egopose to sensor
196 | sensor_sample = self.nusc.get('calibrated_sensor', camera_sample['calibrated_sensor_token'])
197 | intrinsic = torch.Tensor(sensor_sample['camera_intrinsic'])
198 | sensor_rotation = Quaternion(sensor_sample['rotation'])
199 | sensor_translation = np.array(sensor_sample['translation'])[:, None]
200 | car_egopose_to_sensor = np.vstack([
201 | np.hstack((sensor_rotation.rotation_matrix, sensor_translation)),
202 | np.array([0, 0, 0, 1])
203 | ])
204 | car_egopose_to_sensor = np.linalg.inv(car_egopose_to_sensor)
205 |
206 |             # Combine all the transformations.
207 | # From sensor to lidar.
208 | lidar_to_sensor = car_egopose_to_sensor @ world_to_car_egopose @ lidar_to_world
209 | sensor_to_lidar = torch.from_numpy(np.linalg.inv(lidar_to_sensor)).float()
210 |
211 | # Load image
212 | image_filename = os.path.join(self.dataroot, camera_sample['filename'])
213 | img = Image.open(image_filename)
214 | # Resize and crop
215 | img = resize_and_crop_image(
216 | img, resize_dims=self.augmentation_parameters['resize_dims'], crop=self.augmentation_parameters['crop']
217 | )
218 | # Normalise image
219 | normalised_img = self.normalise_image(img)
220 |
221 | # Combine resize/cropping in the intrinsics
222 | top_crop = self.augmentation_parameters['crop'][1]
223 | left_crop = self.augmentation_parameters['crop'][0]
224 | intrinsic = update_intrinsics(
225 | intrinsic, top_crop, left_crop,
226 | scale_width=self.augmentation_parameters['scale_width'],
227 | scale_height=self.augmentation_parameters['scale_height']
228 | )
229 |
230 | images.append(normalised_img.unsqueeze(0).unsqueeze(0))
231 | intrinsics.append(intrinsic.unsqueeze(0).unsqueeze(0))
232 | extrinsics.append(sensor_to_lidar.unsqueeze(0).unsqueeze(0))
233 |
234 | images, intrinsics, extrinsics = (torch.cat(images, dim=1),
235 | torch.cat(intrinsics, dim=1),
236 | torch.cat(extrinsics, dim=1)
237 | )
238 |
239 | return images, intrinsics, extrinsics
240 |
241 | def _get_top_lidar_pose(self, rec):
242 | egopose = self.nusc.get('ego_pose', self.nusc.get('sample_data', rec['data']['LIDAR_TOP'])['ego_pose_token'])
243 | trans = -np.array(egopose['translation'])
244 | yaw = Quaternion(egopose['rotation']).yaw_pitch_roll[0]
245 | rot = Quaternion(scalar=np.cos(yaw / 2), vector=[0, 0, np.sin(yaw / 2)]).inverse
246 | return trans, rot
247 |
248 | def get_birds_eye_view_label(self, rec, instance_map):
249 | translation, rotation = self._get_top_lidar_pose(rec)
250 | segmentation = np.zeros((self.bev_dimension[0], self.bev_dimension[1]))
251 | # Background is ID 0
252 | instance = np.zeros((self.bev_dimension[0], self.bev_dimension[1]))
253 | z_position = np.zeros((self.bev_dimension[0], self.bev_dimension[1]))
254 | attribute_label = np.zeros((self.bev_dimension[0], self.bev_dimension[1]))
255 |
256 | for annotation_token in rec['anns']:
257 | # Filter out all non vehicle instances
258 | annotation = self.nusc.get('sample_annotation', annotation_token)
259 |
260 | if not self.is_lyft:
261 | # NuScenes filter
262 | if 'vehicle' not in annotation['category_name']:
263 | continue
264 | if self.cfg.DATASET.FILTER_INVISIBLE_VEHICLES and int(annotation['visibility_token']) == 1:
265 | continue
266 | else:
267 | # Lyft filter
268 | if annotation['category_name'] not in ['bus', 'car', 'construction_vehicle', 'trailer', 'truck']:
269 | continue
270 |
271 | if annotation['instance_token'] not in instance_map:
272 | instance_map[annotation['instance_token']] = len(instance_map) + 1
273 | instance_id = instance_map[annotation['instance_token']]
274 |
275 | if not self.is_lyft:
276 | instance_attribute = int(annotation['visibility_token'])
277 | else:
278 | instance_attribute = 0
279 |
280 | poly_region, z = self._get_poly_region_in_image(annotation, translation, rotation)
281 | cv2.fillPoly(instance, [poly_region], instance_id)
282 | cv2.fillPoly(segmentation, [poly_region], 1.0)
283 | cv2.fillPoly(z_position, [poly_region], z)
284 | cv2.fillPoly(attribute_label, [poly_region], instance_attribute)
285 |
286 | return segmentation, instance, z_position, instance_map, attribute_label
287 |
288 | def _get_poly_region_in_image(self, instance_annotation, ego_translation, ego_rotation):
289 | box = Box(
290 | instance_annotation['translation'], instance_annotation['size'], Quaternion(instance_annotation['rotation'])
291 | )
292 | box.translate(ego_translation)
293 | box.rotate(ego_rotation)
294 |
295 | pts = box.bottom_corners()[:2].T
296 | pts = np.round((pts - self.bev_start_position[:2] + self.bev_resolution[:2] / 2.0) / self.bev_resolution[:2]).astype(np.int32)
297 | pts[:, [1, 0]] = pts[:, [0, 1]]
298 |
299 | z = box.bottom_corners()[2, 0]
300 | return pts, z
301 |
302 | def get_label(self, rec, instance_map):
303 | segmentation_np, instance_np, z_position_np, instance_map, attribute_label_np = \
304 | self.get_birds_eye_view_label(rec, instance_map)
305 | segmentation = torch.from_numpy(segmentation_np).long().unsqueeze(0).unsqueeze(0)
306 | instance = torch.from_numpy(instance_np).long().unsqueeze(0)
307 | z_position = torch.from_numpy(z_position_np).float().unsqueeze(0).unsqueeze(0)
308 | attribute_label = torch.from_numpy(attribute_label_np).long().unsqueeze(0).unsqueeze(0)
309 |
310 | return segmentation, instance, z_position, instance_map, attribute_label
311 |
312 | def get_future_egomotion(self, rec, index):
313 | rec_t0 = rec
314 |
315 | # Identity
316 | future_egomotion = np.eye(4, dtype=np.float32)
317 |
318 | if index < len(self.ixes) - 1:
319 | rec_t1 = self.ixes[index + 1]
320 |
321 | if rec_t0['scene_token'] == rec_t1['scene_token']:
322 | egopose_t0 = self.nusc.get(
323 | 'ego_pose', self.nusc.get('sample_data', rec_t0['data']['LIDAR_TOP'])['ego_pose_token']
324 | )
325 | egopose_t1 = self.nusc.get(
326 | 'ego_pose', self.nusc.get('sample_data', rec_t1['data']['LIDAR_TOP'])['ego_pose_token']
327 | )
328 |
329 | egopose_t0 = convert_egopose_to_matrix_numpy(egopose_t0)
330 | egopose_t1 = convert_egopose_to_matrix_numpy(egopose_t1)
331 |
332 | future_egomotion = invert_matrix_egopose_numpy(egopose_t1).dot(egopose_t0)
333 | future_egomotion[3, :3] = 0.0
334 | future_egomotion[3, 3] = 1.0
335 |
336 | future_egomotion = torch.Tensor(future_egomotion).float()
337 |
338 | # Convert to 6DoF vector
339 | future_egomotion = mat2pose_vec(future_egomotion)
340 | return future_egomotion.unsqueeze(0)
341 |
342 | def __len__(self):
343 | return len(self.indices)
344 |
345 | def __getitem__(self, index):
346 | """
347 | Returns
348 | -------
349 | data: dict with the following keys:
350 | image: torch.Tensor (T, N, 3, H, W)
351 | normalised cameras images with T the sequence length, and N the number of cameras.
352 | intrinsics: torch.Tensor (T, N, 3, 3)
353 | intrinsics containing resizing and cropping parameters.
354 | extrinsics: torch.Tensor (T, N, 4, 4)
355 | 6 DoF pose from world coordinates to camera coordinates.
356 | segmentation: torch.Tensor (T, 1, H_bev, W_bev)
357 | (H_bev, W_bev) are the pixel dimensions in bird's-eye view.
358 | instance: torch.Tensor (T, 1, H_bev, W_bev)
359 | centerness: torch.Tensor (T, 1, H_bev, W_bev)
360 | offset: torch.Tensor (T, 2, H_bev, W_bev)
361 | flow: torch.Tensor (T, 2, H_bev, W_bev)
362 | future_egomotion: torch.Tensor (T, 6)
363 | 6 DoF egomotion t -> t+1
364 | sample_token: List (T,)
365 |             z_position: torch.Tensor (T, 1, H_bev, W_bev)
366 |             attribute: torch.Tensor (T, 1, H_bev, W_bev)
367 |
368 | """
369 | data = {}
370 | keys = ['image', 'intrinsics', 'extrinsics',
371 | 'segmentation', 'instance', 'centerness', 'offset', 'flow', 'future_egomotion',
372 | 'sample_token',
373 | 'z_position', 'attribute'
374 | ]
375 | for key in keys:
376 | data[key] = []
377 |
378 | instance_map = {}
379 | # Loop over all the frames in the sequence.
380 | for index_t in self.indices[index]:
381 | rec = self.ixes[index_t]
382 |
383 | images, intrinsics, extrinsics = self.get_input_data(rec)
384 | segmentation, instance, z_position, instance_map, attribute_label = self.get_label(rec, instance_map)
385 |
386 | future_egomotion = self.get_future_egomotion(rec, index_t)
387 |
388 | data['image'].append(images)
389 | data['intrinsics'].append(intrinsics)
390 | data['extrinsics'].append(extrinsics)
391 | data['segmentation'].append(segmentation)
392 | data['instance'].append(instance)
393 | data['future_egomotion'].append(future_egomotion)
394 | data['sample_token'].append(rec['token'])
395 | data['z_position'].append(z_position)
396 | data['attribute'].append(attribute_label)
397 |
398 | for key, value in data.items():
399 | if key in ['sample_token', 'centerness', 'offset', 'flow']:
400 | continue
401 | data[key] = torch.cat(value, dim=0)
402 |
403 |         # If Lyft, subsample in time and update the future egomotions accordingly
404 | if self.cfg.MODEL.SUBSAMPLE:
405 | for key, value in data.items():
406 | if key in ['future_egomotion', 'sample_token', 'centerness', 'offset', 'flow']:
407 | continue
408 | data[key] = data[key][::2].clone()
409 | data['sample_token'] = data['sample_token'][::2]
410 |
411 | # Update future egomotions
412 | future_egomotions_matrix = pose_vec2mat(data['future_egomotion'])
413 | future_egomotion_accum = torch.zeros_like(future_egomotions_matrix)
414 | future_egomotion_accum[:-1] = future_egomotions_matrix[:-1] @ future_egomotions_matrix[1:]
415 | future_egomotion_accum = mat2pose_vec(future_egomotion_accum)
416 | data['future_egomotion'] = future_egomotion_accum[::2].clone()
417 |
418 | instance_centerness, instance_offset, instance_flow = convert_instance_mask_to_center_and_offset_label(
419 | data['instance'], data['future_egomotion'],
420 | num_instances=len(instance_map), ignore_index=self.cfg.DATASET.IGNORE_INDEX, subtract_egomotion=True,
421 | spatial_extent=self.spatial_extent,
422 | )
423 | data['centerness'] = instance_centerness
424 | data['offset'] = instance_offset
425 | data['flow'] = instance_flow
426 | return data
427 |
428 |
429 | def prepare_dataloaders(cfg, return_dataset=False):
430 | version = cfg.DATASET.VERSION
431 | train_on_training_data = True
432 |
433 | if cfg.DATASET.NAME == 'nuscenes':
434 | # 28130 train and 6019 val
435 | dataroot = os.path.join(cfg.DATASET.DATAROOT, version)
436 | nusc = NuScenes(version='v1.0-{}'.format(cfg.DATASET.VERSION), dataroot=dataroot, verbose=False)
437 | elif cfg.DATASET.NAME == 'lyft':
438 |         # The train split contains 22680 samples,
439 |         # which we split into 16506 train / 6174 val samples
440 | dataroot = os.path.join(cfg.DATASET.DATAROOT, 'trainval')
441 | nusc = LyftDataset(data_path=dataroot,
442 | json_path=os.path.join(dataroot, 'train_data'),
443 | verbose=True)
444 |
445 | traindata = FuturePredictionDataset(nusc, train_on_training_data, cfg)
446 | valdata = FuturePredictionDataset(nusc, False, cfg)
447 |
448 | if cfg.DATASET.VERSION == 'mini':
449 | traindata.indices = traindata.indices[:10]
450 | valdata.indices = valdata.indices[:10]
451 |
452 | nworkers = cfg.N_WORKERS
453 | trainloader = torch.utils.data.DataLoader(
454 | traindata, batch_size=cfg.BATCHSIZE, shuffle=True, num_workers=nworkers, pin_memory=True, drop_last=True
455 | )
456 | valloader = torch.utils.data.DataLoader(
457 | valdata, batch_size=cfg.BATCHSIZE, shuffle=False, num_workers=nworkers, pin_memory=True, drop_last=False)
458 |
459 | if return_dataset:
460 | return trainloader, valloader, traindata, valdata
461 | else:
462 | return trainloader, valloader
463 |
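Note: a minimal loading sketch, under assumptions: it presumes a local nuScenes copy at cfg.DATASET.DATAROOT and that fiery/config.py exposes a get_cfg() helper returning the default configuration (the exact accessor name may differ):

    from fiery.config import get_cfg           # assumed helper; see fiery/config.py
    from fiery.data import prepare_dataloaders

    cfg = get_cfg()                             # default config; set DATASET.DATAROOT to your nuScenes path
    trainloader, valloader = prepare_dataloaders(cfg)

    batch = next(iter(trainloader))
    # batch['image']:            (B, T, N, 3, H, W) normalised camera images
    # batch['segmentation']:     (B, T, 1, H_bev, W_bev) bird's-eye-view vehicle segmentation
    # batch['future_egomotion']: (B, T, 6) 6-DoF pose vectors t -> t+1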
--------------------------------------------------------------------------------