├── semseg
│ ├── utils
│ │ ├── __init__.py
│ │ ├── visualize.py
│ │ └── utils.py
│ ├── models
│ │ ├── layers
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ └── initialize.py
│ │ ├── __init__.py
│ │ ├── heads
│ │ │ ├── __init__.py
│ │ │ └── segformer.py
│ │ ├── backbones
│ │ │ ├── __init__.py
│ │ │ └── mmsformer.py
│ │ ├── mmsformer.py
│ │ └── base.py
│ ├── datasets
│ │ ├── __init__.py
│ │ ├── unzip.py
│ │ ├── pst.py
│ │ ├── fmb.py
│ │ └── mcubes.py
│ ├── optimizers.py
│ ├── __init__.py
│ ├── metrics.py
│ ├── losses.py
│ ├── schedulers.py
│ ├── augmentations.py
│ └── augmentations_mm.py
├── figs
│ ├── MMSFormer-V2.png
│ └── MMSFormer-v2.png
├── tools
│ ├── train.sh
│ ├── infer_mm.py
│ ├── val_mm.py
│ └── train_mm.py
├── configs
│ ├── pst_rgbt.yaml
│ ├── fmb_rgbt.yaml
│ └── mcubes_rgbadn.yaml
├── requirements.txt
├── .gitignore
├── environment.yaml
├── README.md
└── LICENSE
/semseg/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/semseg/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import *
2 | from .initialize import *
--------------------------------------------------------------------------------
/figs/MMSFormer-V2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSIPlab/MMSFormer/HEAD/figs/MMSFormer-V2.png
--------------------------------------------------------------------------------
/figs/MMSFormer-v2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSIPlab/MMSFormer/HEAD/figs/MMSFormer-v2.png
--------------------------------------------------------------------------------
/semseg/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .mmsformer import MMSFormer
2 |
3 | __all__ = [
4 | 'MMSFormer',
5 | ]
--------------------------------------------------------------------------------
/semseg/models/heads/__init__.py:
--------------------------------------------------------------------------------
1 | from .segformer import SegFormerHead
2 |
3 | __all__ = ['SegFormerHead']
--------------------------------------------------------------------------------
/semseg/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .mmsformer import MMSFormer
2 |
3 | __all__ = [
4 | 'MMSFormer',
5 | ]
--------------------------------------------------------------------------------
/semseg/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .mcubes import MCubeS
2 | from .fmb import FMB
3 | from .pst import PST
4 |
5 | __all__ = [
6 | 'MCubeS',
7 | 'FMB',
8 | 'PST',
9 | ]
--------------------------------------------------------------------------------
/semseg/datasets/unzip.py:
--------------------------------------------------------------------------------
1 | import zipfile
2 |
3 | with zipfile.ZipFile("data/MCubeS/multimodal_dataset.zip", "r") as zip_ref:
4 | for name in zip_ref.namelist():
5 | try:
6 | zip_ref.extract(name, "multimodal_dataset_extracted/")
7 | except zipfile.BadZipFile as e:
8 | print(e)
--------------------------------------------------------------------------------
/tools/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -l
2 |
3 | #SBATCH --mem=16G
4 | #SBATCH --time=120:00:00
5 | #SBATCH --mail-user=emailaddr@domain.com
6 | #SBATCH --mail-type=ALL
7 | #SBATCH --cpus-per-task=16
8 | #SBATCH -p batch
9 | #SBATCH --output=output_%j-%N.txt # logging per job and per host in the current directory. Both stdout and stderr are logged
10 | #SBATCH --gres=gpu:2
11 |
12 | conda activate mmsformer
13 |
14 | python -m tools.train_mm --cfg configs/mcubes_rgbadn.yaml
--------------------------------------------------------------------------------
/semseg/optimizers.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.optim import AdamW, SGD
3 |
4 |
5 | def get_optimizer(model: nn.Module, optimizer: str, lr: float, weight_decay: float = 0.01):
6 | wd_params, nwd_params = [], []
7 | for p in model.parameters():
8 | if p.requires_grad:
9 | if p.dim() == 1:
10 | nwd_params.append(p)
11 | else:
12 | wd_params.append(p)
13 |
14 | params = [
15 | {"params": wd_params},
16 | {"params": nwd_params, "weight_decay": 0}
17 | ]
18 |
19 | if optimizer == 'adamw':
20 | return AdamW(params, lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=weight_decay)
21 | else:
22 | return SGD(params, lr, momentum=0.9, weight_decay=weight_decay)
--------------------------------------------------------------------------------
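get_optimizer above builds two parameter groups: 1-D parameters (biases, norm weights) are kept out of weight decay, everything else is decayed. A minimal usage sketch, assuming the package is importable from the repo root; the toy model below is only a stand-in, not part of the repo:

from torch import nn
from semseg.optimizers import get_optimizer

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16))  # stand-in model
optimizer = get_optimizer(model, 'adamw', lr=6e-5, weight_decay=0.01)

# Two param groups: conv weights (decayed) and 1-D params such as BN weight/bias (weight_decay=0).
for group in optimizer.param_groups:
    print(len(group['params']), group['weight_decay'])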
/semseg/__init__.py:
--------------------------------------------------------------------------------
1 | from tabulate import tabulate
2 | from semseg import models
3 | from semseg import datasets
4 | from semseg.models import backbones, heads
5 |
6 |
7 | def show_models():
8 | model_names = models.__all__
9 | numbers = list(range(1, len(model_names)+1))
10 | print(tabulate({'No.': numbers, 'Model Names': model_names}, headers='keys'))
11 |
12 |
13 | def show_backbones():
14 | backbone_names = backbones.__all__
15 | variants = []
16 | for name in backbone_names:
17 | try:
18 | variants.append(list(eval(f"backbones.{name.lower()}_settings").keys()))
19 | except:
20 | variants.append('-')
21 | print(tabulate({'Backbone Names': backbone_names, 'Variants': variants}, headers='keys'))
22 |
23 |
24 | def show_heads():
25 | head_names = heads.__all__
26 | numbers = list(range(1, len(head_names)+1))
27 | print(tabulate({'No.': numbers, 'Heads': head_names}, headers='keys'))
28 |
29 |
30 | def show_datasets():
31 | dataset_names = datasets.__all__
32 | numbers = list(range(1, len(dataset_names)+1))
33 | print(tabulate({'No.': numbers, 'Datasets': dataset_names}, headers='keys'))
34 |
--------------------------------------------------------------------------------
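The helpers above print tabulated registries of what the package exposes. A quick sketch of calling them, assuming the repo root is on the Python path:

import semseg

semseg.show_models()     # MMSFormer
semseg.show_heads()      # SegFormerHead
semseg.show_datasets()   # MCubeS, FMB, PST
semseg.show_backbones()  # backbone names with their variant keys, if matching *_settings dicts exist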
/semseg/models/layers/common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 |
4 |
5 | class ConvModule(nn.Sequential):
6 | def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1):
7 | super().__init__(
8 | nn.Conv2d(c1, c2, k, s, p, d, g, bias=False),
9 | nn.BatchNorm2d(c2),
10 | nn.ReLU(True)
11 | )
12 |
13 |
14 | class DropPath(nn.Module):
15 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
16 | Copied from timm
17 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
18 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
19 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
20 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
21 | 'survival rate' as the argument.
22 | """
23 | def __init__(self, p: float = None):
24 | super().__init__()
25 | self.p = p
26 |
27 | def forward(self, x: Tensor) -> Tensor:
28 | if self.p == 0. or not self.training:
29 | return x
30 | kp = 1 - self.p
31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1)
32 | random_tensor = kp + torch.rand(shape, dtype=x.dtype, device=x.device)
33 | random_tensor.floor_() # binarize
34 | return x.div(kp) * random_tensor
--------------------------------------------------------------------------------
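DropPath implements stochastic depth: during training each sample's residual branch is zeroed with probability p and the surviving samples are rescaled by 1/(1-p); at inference it is the identity. An illustrative sketch of that behaviour:

import torch
from semseg.models.layers.common import DropPath

drop = DropPath(p=0.1)
x = torch.ones(4, 8, 16, 16)

drop.train()
y = drop(x)                       # roughly 10% of the 4 samples zeroed, the rest scaled by 1/0.9

drop.eval()
assert torch.equal(drop(x), x)    # identity at inference time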
/semseg/models/mmsformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | from semseg.models.base import BaseModel
4 | from semseg.models.heads import SegFormerHead
5 |
6 |
7 | class MMSFormer(BaseModel):
8 | def __init__(self, backbone: str = 'MMSFormer-B0', num_classes: int = 20, modals: list = ['img', 'aolp', 'dolp', 'nir']) -> None:
9 | super().__init__(backbone, num_classes, modals)
10 | self.decode_head = SegFormerHead(self.backbone.channels, 256 if 'B0' in backbone or 'B1' in backbone else 512, num_classes)
11 | self.apply(self._init_weights)
12 |
13 | def forward(self, x: list) -> list:
14 | y = self.backbone(x)
15 | y = self.decode_head(y)
16 | y = F.interpolate(y, size=x[0].shape[2:], mode='bilinear', align_corners=False)
17 | return y
18 |
19 | def init_pretrained(self, pretrained: str = None) -> None:
20 | checkpoint = torch.load(pretrained, map_location='cpu')
21 | if 'state_dict' in checkpoint.keys():
22 | checkpoint = checkpoint['state_dict']
23 | if 'model' in checkpoint.keys():
24 | checkpoint = checkpoint['model']
25 | msg = self.backbone.load_state_dict(checkpoint, strict=False)
26 | del checkpoint
27 |
28 |
29 | if __name__ == '__main__':
30 | modals = ['img', 'aolp', 'dolp', 'nir']
31 | model = MMSFormer('MMSFormer-B2', 25, modals)
32 | model.init_pretrained('checkpoints/pretrained/segformer/mit_b2.pth')
33 | x = [torch.zeros(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024)*2, torch.ones(1, 3, 1024, 1024) *3]
34 | y = model(x)
35 | print(y.shape)
36 |
--------------------------------------------------------------------------------
/semseg/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from typing import Tuple
4 |
5 |
6 | class Metrics:
7 | def __init__(self, num_classes: int, ignore_label: int, device) -> None:
8 | self.ignore_label = ignore_label
9 | self.num_classes = num_classes
10 | self.hist = torch.zeros(num_classes, num_classes).to(device)
11 |
12 | def update(self, pred: Tensor, target: Tensor) -> None:
13 | pred = pred.argmax(dim=1)
14 | keep = target != self.ignore_label
15 | self.hist += torch.bincount(target[keep] * self.num_classes + pred[keep], minlength=self.num_classes**2).view(self.num_classes, self.num_classes)
16 |
17 | def compute_iou(self) -> Tuple[Tensor, Tensor]:
18 | ious = self.hist.diag() / (self.hist.sum(0) + self.hist.sum(1) - self.hist.diag())
19 | ious[ious.isnan()]=0.
20 | miou = ious.mean().item()
21 | # miou = ious[~ious.isnan()].mean().item()
22 | ious *= 100
23 | miou *= 100
24 | return ious.cpu().numpy().round(2).tolist(), round(miou, 2)
25 |
26 | def compute_f1(self) -> Tuple[Tensor, Tensor]:
27 | f1 = 2 * self.hist.diag() / (self.hist.sum(0) + self.hist.sum(1))
28 | f1[f1.isnan()]=0.
29 | mf1 = f1.mean().item()
30 | # mf1 = f1[~f1.isnan()].mean().item()
31 | f1 *= 100
32 | mf1 *= 100
33 | return f1.cpu().numpy().round(2).tolist(), round(mf1, 2)
34 |
35 | def compute_pixel_acc(self) -> Tuple[Tensor, Tensor]:
36 | acc = self.hist.diag() / self.hist.sum(1)
37 | acc[acc.isnan()]=0.
38 | macc = acc.mean().item()
39 | # macc = acc[~acc.isnan()].mean().item()
40 | acc *= 100
41 | macc *= 100
42 | return acc.cpu().numpy().round(2).tolist(), round(macc, 2)
43 |
44 |
--------------------------------------------------------------------------------
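Metrics accumulates a num_classes x num_classes confusion matrix; update() expects raw logits of shape [B, C, H, W] and integer targets of shape [B, H, W], and each compute_* method returns (per-class list, mean) in percent. A minimal sketch with random inputs (illustrative only):

import torch
from semseg.metrics import Metrics

num_classes = 5
metrics = Metrics(num_classes, ignore_label=255, device='cpu')

logits = torch.randn(2, num_classes, 64, 64)           # model output (logits)
target = torch.randint(0, num_classes, (2, 64, 64))    # ground-truth labels

metrics.update(logits, target)
ious, miou = metrics.compute_iou()
acc, macc = metrics.compute_pixel_acc()
f1, mf1 = metrics.compute_f1()
print(miou, macc, mf1)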
/semseg/models/heads/segformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from typing import Tuple
4 | from torch.nn import functional as F
5 |
6 |
7 | class MLP(nn.Module):
8 | def __init__(self, dim, embed_dim):
9 | super().__init__()
10 | self.proj = nn.Linear(dim, embed_dim)
11 |
12 | def forward(self, x: Tensor) -> Tensor:
13 | x = x.flatten(2).transpose(1, 2)
14 | x = self.proj(x)
15 | return x
16 |
17 |
18 | class ConvModule(nn.Module):
19 | def __init__(self, c1, c2):
20 | super().__init__()
21 | self.conv = nn.Conv2d(c1, c2, 1, bias=False)
22 | self.bn = nn.BatchNorm2d(c2) # use SyncBN in original
23 | self.activate = nn.ReLU(True)
24 |
25 | def forward(self, x: Tensor) -> Tensor:
26 | return self.activate(self.bn(self.conv(x)))
27 |
28 |
29 | class SegFormerHead(nn.Module):
30 | def __init__(self, dims: list, embed_dim: int = 256, num_classes: int = 19):
31 | super().__init__()
32 | for i, dim in enumerate(dims):
33 | self.add_module(f"linear_c{i+1}", MLP(dim, embed_dim))
34 |
35 | self.linear_fuse = ConvModule(embed_dim*4, embed_dim)
36 | self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1)
37 | self.dropout = nn.Dropout2d(0.1)
38 |
39 | def forward(self, features: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tensor:
40 | B, _, H, W = features[0].shape
41 | outs = [self.linear_c1(features[0]).permute(0, 2, 1).reshape(B, -1, *features[0].shape[-2:])]
42 |
43 | for i, feature in enumerate(features[1:]):
44 | cf = getattr(self, f"linear_c{i+2}")(feature).permute(0, 2, 1).reshape(B, -1, *feature.shape[-2:])
45 | outs.append(F.interpolate(cf, size=(H, W), mode='bilinear', align_corners=False))
46 |
47 | seg = self.linear_fuse(torch.cat(outs[::-1], dim=1))
48 | seg = self.linear_pred(self.dropout(seg))
49 | return seg
--------------------------------------------------------------------------------
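SegFormerHead projects each backbone stage to a common embed_dim with a linear layer, upsamples everything to the resolution of the first (1/4-scale) feature map, fuses the concatenation with a 1x1 ConvModule, and predicts per-class logits. A shape-flow sketch with MiT-B0-like channel dims; the dims below are illustrative, not taken from this repo's backbone:

import torch
from semseg.models.heads.segformer import SegFormerHead

dims = [32, 64, 160, 256]            # example stage channels; set these to your backbone's channels
head = SegFormerHead(dims, embed_dim=256, num_classes=19)

feats = [
    torch.randn(1, 32, 128, 128),    # 1/4 scale
    torch.randn(1, 64, 64, 64),      # 1/8 scale
    torch.randn(1, 160, 32, 32),     # 1/16 scale
    torch.randn(1, 256, 16, 16),     # 1/32 scale
]
print(head(feats).shape)             # torch.Size([1, 19, 128, 128]): logits at 1/4 input resolution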
/semseg/models/layers/initialize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import warnings
4 | from torch import nn, Tensor
5 |
6 |
7 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
8 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
9 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
10 | def norm_cdf(x):
11 | # Computes standard normal cumulative distribution function
12 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
13 |
14 | if (mean < a - 2 * std) or (mean > b + 2 * std):
15 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
16 | "The distribution of values may be incorrect.",
17 | stacklevel=2)
18 |
19 | with torch.no_grad():
20 | # Values are generated by using a truncated uniform distribution and
21 | # then using the inverse CDF for the normal distribution.
22 | # Get upper and lower cdf values
23 | l = norm_cdf((a - mean) / std)
24 | u = norm_cdf((b - mean) / std)
25 |
26 | # Uniformly fill tensor with values from [l, u], then translate to
27 | # [2l-1, 2u-1].
28 | tensor.uniform_(2 * l - 1, 2 * u - 1)
29 |
30 | # Use inverse cdf transform for normal distribution to get truncated
31 | # standard normal
32 | tensor.erfinv_()
33 |
34 | # Transform to proper mean, std
35 | tensor.mul_(std * math.sqrt(2.))
36 | tensor.add_(mean)
37 |
38 | # Clamp to ensure it's in the proper range
39 | tensor.clamp_(min=a, max=b)
40 | return tensor
41 |
42 |
43 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
44 | # type: (Tensor, float, float, float, float) -> Tensor
45 | r"""Fills the input Tensor with values drawn from a truncated
46 | normal distribution. The values are effectively drawn from the
47 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
48 | with values outside :math:`[a, b]` redrawn until they are within
49 | the bounds. The method used for generating the random values works
50 | best when :math:`a \leq \text{mean} \leq b`.
51 | Args:
52 | tensor: an n-dimensional `torch.Tensor`
53 | mean: the mean of the normal distribution
54 | std: the standard deviation of the normal distribution
55 | a: the minimum cutoff value
56 | b: the maximum cutoff value
57 | Examples:
58 | >>> w = torch.empty(3, 5)
59 | >>> nn.init.trunc_normal_(w)
60 | """
61 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
62 |
--------------------------------------------------------------------------------
/configs/pst_rgbt.yaml:
--------------------------------------------------------------------------------
1 | DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)
2 | SAVE_DIR : 'output/MMSFormer' # output folder name used for saving the model, logs and inference results
3 | GPUs : 2
4 | GPU_IDs : [0, 1]
5 | USE_WANDB : False
6 | WANDB_NAME : 'MMSF-FMB-PST' # name for the run
7 |
8 | MODEL:
9 | NAME : MMSFormer # name of the model you are using
10 | BACKBONE : MMSFormer-B4 # model variant
11 | PRETRAINED : 'checkpoints/pretrained/segformer/mit_b4.pth' # backbone model's weight
12 | RESUME : ''
13 |
14 | DATASET:
15 | NAME : PST # dataset name to be trained with (MCubeS, FMB, PST)
16 | ROOT : 'PATH/TO/DATASET/ROOT' # dataset root path
17 | IGNORE_LABEL : 255
18 | # MODALS : ['img']
19 | MODALS : ['img', 'thermal']
20 |
21 | TRAIN:
22 | IMAGE_SIZE : [1280, 720] # training image size in (h, w) === Fixed in dataloader, following MCubeSNet
23 | BATCH_SIZE : 2 # batch size used to train
24 | EPOCHS : 200 # number of epochs to train
25 | EVAL_START : 0 # epoch at which to start evaluation during training
26 | EVAL_INTERVAL : 1 # evaluation interval during training
27 | AMP : true # use AMP in training
28 | DDP : false # use DDP training
29 |
30 | LOSS:
31 | NAME : OhemCrossEntropy # loss function name
32 | CLS_WEIGHTS : false # use class weights in loss calculation
33 |
34 | OPTIMIZER:
35 | NAME : adamw # optimizer name
36 | LR : 0.00006 # initial learning rate used in optimizer
37 | WEIGHT_DECAY : 0.01 # decay rate used in optimizer
38 |
39 | SCHEDULER:
40 | NAME : warmuppolylr # scheduler name
41 | POWER : 0.9 # scheduler power
42 | WARMUP : 10 # warmup epochs used in scheduler
43 | WARMUP_RATIO : 0.1 # warmup ratio
44 |
45 |
46 | EVAL:
47 | MODEL_PATH : 'PATH/TO/MODEL/WEIGHT' # Path to your saved model
48 | IMAGE_SIZE : [1280, 720] # evaluation image size in (h, w)
49 | BATCH_SIZE : 2 # batch size
50 | VIS_SAVE_DIR : 'PATH/TO/TARGET/DIRECTORY' # Where to save visualization
51 | MSF:
52 | ENABLE : false # multi-scale and flip evaluation
53 | FLIP : true # use flip in evaluation
54 | SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation
55 |
--------------------------------------------------------------------------------
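The configs are plain YAML consumed by the scripts under tools/. A standalone sketch of loading one with PyYAML (which requirements.txt pins); the actual train_mm.py / val_mm.py may parse configs differently:

import yaml

with open('configs/pst_rgbt.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['MODEL']['BACKBONE'])    # MMSFormer-B4
print(cfg['DATASET']['MODALS'])    # ['img', 'thermal']
print(cfg['TRAIN']['IMAGE_SIZE'])  # [1280, 720]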
/configs/fmb_rgbt.yaml:
--------------------------------------------------------------------------------
1 | DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)
2 | SAVE_DIR : 'output/MMSFormer' # output folder name used for saving the model, logs and inference results
3 | GPUs : 2
4 | GPU_IDs : [0, 1]
5 | USE_WANDB : False
6 | WANDB_NAME : 'MMSF-FMB-RGBT' # name for the run
7 |
8 | MODEL:
9 | NAME : MMSFormer # name of the model you are using
10 | BACKBONE : MMSFormer-B3 # model variant
11 | PRETRAINED : 'checkpoints/pretrained/segformer/mit_b3.pth' # backbone model's weight
12 | RESUME : '' # checkpoint file
13 |
14 | DATASET:
15 | NAME : FMB # dataset name to be trained with (MCubeS, FMB, PST)
16 | ROOT : 'PATH/TO/DATASET/ROOT' # dataset root path
17 | IGNORE_LABEL : 255
18 | # MODALS : ['img']
19 | MODALS : ['img', 'thermal']
20 |
21 | TRAIN:
22 | IMAGE_SIZE : [800, 600] # training image size in (h, w) === Fixed in dataloader, following MCubeSNet
23 | BATCH_SIZE : 2 # batch size used to train
24 | EPOCHS : 120 # number of epochs to train
25 | EVAL_START : 0 # epoch at which to start evaluation during training
26 | EVAL_INTERVAL : 1 # evaluation interval during training
27 | AMP : true # use AMP in training
28 | DDP : false # use DDP training
29 |
30 | LOSS:
31 | NAME : OhemCrossEntropy # loss function name
32 | CLS_WEIGHTS : false # use class weights in loss calculation
33 |
34 | OPTIMIZER:
35 | NAME : adamw # optimizer name
36 | LR : 0.0001 # initial learning rate used in optimizer
37 | WEIGHT_DECAY : 0.01 # decay rate used in optimizer
38 |
39 | SCHEDULER:
40 | NAME : warmuppolylr # scheduler name
41 | POWER : 0.9 # scheduler power
42 | WARMUP : 10 # warmup epochs used in scheduler
43 | WARMUP_RATIO : 0.1 # warmup ratio
44 |
45 |
46 | EVAL:
47 | MODEL_PATH : 'PATH/TO/MODEL/WEIGHT' # Path to your saved model
48 | IMAGE_SIZE : [800, 600] # evaluation image size in (h, w)
49 | BATCH_SIZE : 2 # batch size
50 | VIS_SAVE_DIR : 'PATH/TO/TARGET/DIRECTORY' # Where to save visualization
51 | MSF:
52 | ENABLE : false # multi-scale and flip evaluation
53 | FLIP : true # use flip in evaluation
54 | SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation
55 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.2.0
2 | addict==2.4.0
3 | appdirs==1.4.4
4 | argon2-cffi==21.3.0
5 | argon2-cffi-bindings==21.2.0
6 | asttokens==2.0.5
7 | attrs==21.4.0
8 | backcall==0.2.0
9 | bleach==4.1.0
10 | cachetools==5.0.0
11 | certifi==2021.10.8
12 | cffi
13 | charset-normalizer==2.1.1
14 | click==8.1.6
15 | cycler==0.11.0
16 | dataclasses==0.6
17 | debugpy==1.5.1
18 | decorator==5.1.1
19 | defusedxml==0.7.1
20 | descartes==1.1.0
21 | docker-pycreds==0.4.0
22 | easydict==1.9
23 | einops==0.4.1
24 | entrypoints==0.4
25 | executing==0.8.3
26 | fire==0.4.0
27 | future
28 | fvcore==0.1.5.post20220512
29 | gitdb==4.0.10
30 | GitPython==3.1.41
31 | google-auth==2.11.0
32 | google-auth-oauthlib==0.4.6
33 | grpcio==1.48.1
34 | idna==3.3
35 | importlib-metadata==4.12.0
36 | importlib-resources==5.4.0
37 | iopath==0.1.10
38 | ipykernel==6.9.1
39 | ipython==8.1.0
40 | ipython-genutils==0.2.0
41 | ipywidgets==7.6.5
42 | jedi==0.18.1
43 | Jinja2==3.0.3
44 | joblib==1.1.0
45 | jsonschema==4.4.0
46 | jupyter==1.0.0
47 | jupyter-client==7.1.2
48 | jupyter-console==6.4.0
49 | jupyter-core==4.9.2
50 | jupyterlab-pygments==0.1.2
51 | jupyterlab-widgets==1.0.2
52 | kiwisolver==1.3.2
53 | Markdown==3.4.1
54 | MarkupSafe==2.1.1
55 | matplotlib==3.4.3
56 | matplotlib-inline==0.1.3
57 | mistune==0.8.4
58 | mkl-fft==1.3.0
59 | mkl-random
60 | mkl-service==2.4.0
61 | mmcv==0.6.2
62 | nbclient==0.5.11
63 | nbconvert==6.4.2
64 | nbformat==5.1.3
65 | nest-asyncio==1.5.4
66 | notebook==6.4.8
67 | numpy
68 | nuscenes-devkit==1.1.9
69 | oauthlib==3.2.1
70 | olefile
71 | opencv-python==4.5.3.56
72 | packaging==21.3
73 | pandocfilters==1.5.0
74 | parso==0.8.3
75 | pathtools==0.1.2
76 | pexpect==4.8.0
77 | pickleshare==0.7.5
78 | Pillow
79 | plyfile==0.7.4
80 | portalocker==2.5.1
81 | prometheus-client==0.13.1
82 | prompt-toolkit==3.0.28
83 | protobuf==3.18.1
84 | psutil==5.9.5
85 | ptyprocess==0.7.0
86 | pure-eval==0.2.2
87 | pyasn1==0.4.8
88 | pyasn1-modules==0.2.8
89 | pycocotools==2.0.4
90 | pycparser
91 | Pygments==2.11.2
92 | pyparsing==3.0.6
93 | pyquaternion==0.9.9
94 | pyrsistent==0.18.1
95 | python-dateutil==2.8.2
96 | PyYAML==6.0
97 | pyzmq==22.3.0
98 | qtconsole==5.2.2
99 | QtPy==2.0.1
100 | requests==2.28.1
101 | requests-oauthlib==1.3.1
102 | rsa==4.9
103 | scikit-learn==1.0.2
104 | scipy==1.7.1
105 | Send2Trash==1.8.0
106 | sentry-sdk==1.28.1
107 | setproctitle==1.3.2
108 | Shapely==1.8.1.post1
109 | six
110 | smmap==5.0.0
111 | stack-data==0.2.0
112 | tabulate==0.8.10
113 | tensorboard==2.10.0
114 | tensorboard-data-server==0.6.1
115 | tensorboard-plugin-wit==1.8.1
116 | tensorboardX==2.4
117 | termcolor==1.1.0
118 | terminado==0.13.1
119 | testpath==0.6.0
120 | threadpoolctl==3.1.0
121 | timm==0.4.12
122 | torch
123 | torchaudio==0.9.0a0+33b2469
124 | torchvision
125 | tornado==6.1
126 | tqdm==4.62.3
127 | traitlets==5.1.1
128 | typing-extensions
129 | urllib3==1.26.12
130 | wandb==0.15.7
131 | wcwidth==0.2.5
132 | webencodings==0.5.1
133 | Werkzeug==2.2.2
134 | widgetsnbextension==3.5.2
135 | yacs==0.1.8
136 | yapf==0.32.0
137 | zipp==3.7.0
138 |
--------------------------------------------------------------------------------
/configs/mcubes_rgbadn.yaml:
--------------------------------------------------------------------------------
1 | DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)
2 | SAVE_DIR : 'output/MMSFormer' # output folder name used for saving the model, logs and inference results
3 | GPUs : 2
4 | GPU_IDs : [0, 1]
5 | USE_WANDB : False # Whether you want to use wandb
6 | WANDB_NAME : 'MMSF-MCubeS-RGBNAD' # name for the run
7 |
8 | MODEL:
9 | NAME : MMSFormer # name of the model you are using
10 | BACKBONE : MMSFormer-B4 # model variant
11 | PRETRAINED : 'checkpoints/pretrained/segformer/mit_b4.pth' # backbone model's weight
12 | RESUME : '' # checkpoint file
13 |
14 | DATASET:
15 | NAME : MCubeS # dataset name to be trained with (MCubeS, FMB, PST)
16 | ROOT : 'PATH/TO/DATASET/ROOT' # dataset root path
17 | IGNORE_LABEL : 255
18 | # MODALS : ['image']
19 | # MODALS : ['image', 'nir']
20 | # MODALS : ['image', 'aolp']
21 | # MODALS : ['image', 'dolp']
22 | # MODALS : ['image', 'aolp', 'nir']
23 | # MODALS : ['image', 'aolp', 'dolp']
24 | MODALS : ['image', 'nir', 'aolp', 'dolp']
25 |
26 | TRAIN:
27 | IMAGE_SIZE : [512, 512] # training image size in (h, w) === Fixed in dataloader, following MCubeSNet
28 | BATCH_SIZE : 4 # batch size used to train
29 | EPOCHS : 500 # number of epochs to train
30 | EVAL_START : 0 # epoch at which to start evaluation during training
31 | EVAL_INTERVAL : 1 # evaluation interval during training
32 | AMP : true # use AMP in training
33 | DDP : false # use DDP training
34 |
35 | LOSS:
36 | NAME : OhemCrossEntropy # loss function name
37 | CLS_WEIGHTS : false # use class weights in loss calculation
38 |
39 | OPTIMIZER:
40 | NAME : adamw # optimizer name
41 | LR : 0.00006 # initial learning rate used in optimizer
42 | WEIGHT_DECAY : 0.01 # decay rate used in optimizer
43 |
44 | SCHEDULER:
45 | NAME : warmuppolylr # scheduler name
46 | POWER : 0.9 # scheduler power
47 | WARMUP : 10 # warmup epochs used in scheduler
48 | WARMUP_RATIO : 0.1 # warmup ratio
49 |
50 |
51 | EVAL:
52 | MODEL_PATH : 'PATH/TO/MODEL/WEIGHT' # Path to your saved model
53 | IMAGE_SIZE : [1024, 1024] # evaluation image size in (h, w)
54 | BATCH_SIZE : 2 # batch size
55 | VIS_SAVE_DIR : 'PATH/TO/TARGET/DIRECTORY' # Where to save visualization
56 | MSF:
57 | ENABLE : false # multi-scale and flip evaluation
58 | FLIP : true # use flip in evaluation
59 | SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation
60 |
--------------------------------------------------------------------------------
/semseg/models/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from torch import nn
4 | from semseg.models.backbones import *
5 | from semseg.models.layers import trunc_normal_
6 | from collections import OrderedDict
7 |
8 | def load_dualpath_model(model, model_file):
9 | # load raw state_dict
10 | if isinstance(model_file, str):
11 | raw_state_dict = torch.load(model_file, map_location=torch.device('cpu'))
12 | #raw_state_dict = torch.load(model_file)
13 | if 'model' in raw_state_dict.keys():
14 | raw_state_dict = raw_state_dict['model']
15 | else:
16 | raw_state_dict = model_file
17 |
18 | state_dict = {}
19 | for k, v in raw_state_dict.items():
20 | if k.find('patch_embed') >= 0:
21 | state_dict[k] = v
22 | # patch_embedx, proj, weight = k.split('.')
23 | # state_dict[k.replace('patch_embed', 'extra_patch_embed')] = v
24 | # state_dict[new_k] = v
25 | elif k.find('block') >= 0:
26 | state_dict[k] = v
27 | # state_dict[k.replace('block', 'extra_block')] = v
28 | elif k.find('norm') >= 0:
29 | state_dict[k] = v
30 | # state_dict[k.replace('norm', 'extra_norm')] = v
31 |
32 | msg = model.load_state_dict(state_dict, strict=False)
33 | print(msg)
34 | del state_dict
35 |
36 |
37 | class BaseModel(nn.Module):
38 | def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19, modals: list = ['rgb', 'depth', 'event', 'lidar']) -> None:
39 | super().__init__()
40 | backbone, variant = backbone.split('-')
41 | self.backbone = eval(backbone)(variant, modals)
42 | # self.backbone = eval(backbone)(variant)
43 | self.modals = modals
44 |
45 | def _init_weights(self, m: nn.Module) -> None:
46 | if isinstance(m, nn.Linear):
47 | trunc_normal_(m.weight, std=.02)
48 | if m.bias is not None:
49 | nn.init.zeros_(m.bias)
50 | elif isinstance(m, nn.Conv2d):
51 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
52 | fan_out //= m.groups
53 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
54 | if m.bias is not None:
55 | nn.init.zeros_(m.bias)
56 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
57 | nn.init.ones_(m.weight)
58 | nn.init.zeros_(m.bias)
59 |
60 | def init_pretrained(self, pretrained: str = None) -> None:
61 | if pretrained:
62 | if len(self.modals)>1:
63 | load_dualpath_model(self.backbone, pretrained)
64 | else:
65 | checkpoint = torch.load(pretrained, map_location='cpu')
66 | if 'state_dict' in checkpoint.keys():
67 | checkpoint = checkpoint['state_dict']
68 | # if 'PoolFormer' in self.__class__.__name__:
69 | # new_dict = OrderedDict()
70 | # for k, v in checkpoint.items():
71 | # if not 'backbone.' in k:
72 | # new_dict['backbone.'+k] = v
73 | # else:
74 | # new_dict[k] = v
75 | # checkpoint = new_dict
76 | if 'model' in checkpoint.keys(): # --- for HorNet
77 | checkpoint = checkpoint['model']
78 | msg = self.backbone.load_state_dict(checkpoint, strict=False)
79 | print(msg)
80 |
--------------------------------------------------------------------------------
/semseg/datasets/pst.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from torch import Tensor
5 | from torch.utils.data import Dataset
6 | import torchvision.transforms.functional as TF
7 | from torchvision import io
8 | from pathlib import Path
9 | from typing import Tuple
10 | import glob
11 | import einops
12 | from torch.utils.data import DataLoader
13 | from torch.utils.data import DistributedSampler, RandomSampler
14 | from semseg.augmentations_mm import get_train_augmentation
15 |
16 | class PST(Dataset):
17 | """
18 | num_classes: 5
19 | """
20 | CLASSES = ["Background", "Fire-Extinguisher", "Backpack", "Hand-Drill", "Survivor"]
21 |
22 | PALETTE = torch.tensor([[0, 0, 0],
23 | [100, 40, 40],
24 | [55, 90, 80],
25 | [220, 20, 60],
26 | [153, 153, 153]])
27 |
28 | def __init__(self, root: str = 'data/PST900', split: str = 'train', transform = None, modals = ['img', 'thermal'], case = None) -> None:
29 | super().__init__()
30 | assert split in ['train', 'val']
31 | self.transform = transform
32 | self.n_classes = len(self.CLASSES)
33 | self.ignore_label = 255
34 | self.modals = modals
35 | if split == 'val':
36 | split = 'test'
37 | self.files = sorted(glob.glob(os.path.join(*[root, split, 'rgb', '*.png'])))
38 | # --- debug
39 | # self.files = sorted(glob.glob(os.path.join(*[root, 'img', '*', split, '*', '*.png'])))[:100]
40 | # --- split as case
41 | if not self.files:
42 | raise Exception(f"No images found in {os.path.join(root, split, 'rgb')}")
43 | print(f"Found {len(self.files)} {split} images.")
44 |
45 | def __len__(self) -> int:
46 | return len(self.files)
47 |
48 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
49 | item_name = self.files[index].split("/")[-1].split(".")[0]
50 | rgb = str(self.files[index])
51 | thermal = rgb.replace('/rgb', '/thermal')
52 | lbl_path = rgb.replace('/rgb', '/labels')
53 |
54 | sample = {}
55 | sample['img'] = io.read_image(rgb)[:3, ...]
56 | H, W = sample['img'].shape[1:]
57 | if 'thermal' in self.modals:
58 | sample['thermal'] = self._open_img(thermal)
59 | label = io.read_image(lbl_path)[0,...].unsqueeze(0)
60 | # label[label==255] = 0
61 | # label -= 1
62 | sample['mask'] = label
63 |
64 | if self.transform:
65 | sample = self.transform(sample)
66 | label = sample['mask']
67 | del sample['mask']
68 | label = self.encode(label.squeeze().numpy()).long()
69 | sample = [sample[k] for k in self.modals]
70 | # return sample, label, item_name
71 | return sample, label
72 |
73 | def _open_img(self, file):
74 | img = io.read_image(file)
75 | C, H, W = img.shape
76 | if C == 4:
77 | img = img[:3, ...]
78 | if C == 1:
79 | img = img.repeat(3, 1, 1)
80 | return img
81 |
82 | def encode(self, label: Tensor) -> Tensor:
83 | return torch.from_numpy(label)
84 |
85 |
86 | if __name__ == '__main__':
87 | cases = ['cloud', 'fog', 'night', 'rain', 'sun', 'motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres']
88 | traintransform = get_train_augmentation((1024, 1024), seg_fill=255)
89 | for case in cases:
90 |
91 | trainset = PST(transform=traintransform, split='val', case=case)
92 | trainloader = DataLoader(trainset, batch_size=2, num_workers=2, drop_last=False, pin_memory=False)
93 |
94 | for i, (sample, lbl) in enumerate(trainloader):
95 | print(torch.unique(lbl))
--------------------------------------------------------------------------------
/semseg/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 |
5 |
6 | class CrossEntropy(nn.Module):
7 | def __init__(self, ignore_label: int = 255, weight: Tensor = None, aux_weights: list = [1, 0.4, 0.4]) -> None:
8 | super().__init__()
9 | self.aux_weights = aux_weights
10 | self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label)
11 |
12 | def _forward(self, preds: Tensor, labels: Tensor) -> Tensor:
13 | # preds in shape [B, C, H, W] and labels in shape [B, H, W]
14 | return self.criterion(preds, labels)
15 |
16 | def forward(self, preds, labels: Tensor) -> Tensor:
17 | if isinstance(preds, tuple):
18 | return sum([w * self._forward(pred, labels) for (pred, w) in zip(preds, self.aux_weights)])
19 | return self._forward(preds, labels)
20 |
21 |
22 | class OhemCrossEntropy(nn.Module):
23 | def __init__(self, ignore_label: int = 255, weight: Tensor = None, thresh: float = 0.7, aux_weights: list = [1, 1]) -> None:
24 | super().__init__()
25 | self.ignore_label = ignore_label
26 | self.aux_weights = aux_weights
27 | self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float))
28 | self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label, reduction='none')
29 |
30 | def _forward(self, preds: Tensor, labels: Tensor) -> Tensor:
31 | # preds in shape [B, C, H, W] and labels in shape [B, H, W]
32 | n_min = labels[labels != self.ignore_label].numel() // 16
33 | loss = self.criterion(preds, labels).view(-1)
34 | loss_hard = loss[loss > self.thresh]
35 |
36 | if loss_hard.numel() < n_min:
37 | loss_hard, _ = loss.topk(n_min)
38 |
39 | return torch.mean(loss_hard)
40 |
41 | def forward(self, preds, labels: Tensor) -> Tensor:
42 | if isinstance(preds, tuple):
43 | return sum([w * self._forward(pred, labels) for (pred, w) in zip(preds, self.aux_weights)])
44 | return self._forward(preds, labels)
45 |
46 |
47 | class Dice(nn.Module):
48 | def __init__(self, delta: float = 0.5, aux_weights: list = [1, 0.4, 0.4]):
49 | """
50 | delta: Controls weight given to FP and FN. This equals to dice score when delta=0.5
51 | """
52 | super().__init__()
53 | self.delta = delta
54 | self.aux_weights = aux_weights
55 |
56 | def _forward(self, preds: Tensor, labels: Tensor) -> Tensor:
57 | # preds in shape [B, C, H, W] and labels in shape [B, H, W]
58 | num_classes = preds.shape[1]
59 | labels = F.one_hot(labels, num_classes).permute(0, 3, 1, 2)
60 | tp = torch.sum(labels*preds, dim=(2, 3))
61 | fn = torch.sum(labels*(1-preds), dim=(2, 3))
62 | fp = torch.sum((1-labels)*preds, dim=(2, 3))
63 |
64 | dice_score = (tp + 1e-6) / (tp + self.delta * fn + (1 - self.delta) * fp + 1e-6)
65 | dice_score = torch.sum(1 - dice_score, dim=-1)
66 |
67 | dice_score = dice_score / num_classes
68 | return dice_score.mean()
69 |
70 | def forward(self, preds, targets: Tensor) -> Tensor:
71 | if isinstance(preds, tuple):
72 | return sum([w * self._forward(pred, targets) for (pred, w) in zip(preds, self.aux_weights)])
73 | return self._forward(preds, targets)
74 |
75 |
76 | __all__ = ['CrossEntropy', 'OhemCrossEntropy', 'Dice']
77 |
78 |
79 | def get_loss(loss_fn_name: str = 'CrossEntropy', ignore_label: int = 255, cls_weights: Tensor = None):
80 | assert loss_fn_name in __all__, f"Unavailable loss function name >> {loss_fn_name}.\nAvailable loss functions: {__all__}"
81 | if loss_fn_name == 'Dice':
82 | return Dice()
83 | return eval(loss_fn_name)(ignore_label, cls_weights)
84 |
85 |
86 | if __name__ == '__main__':
87 | pred = torch.randint(0, 19, (2, 19, 480, 640), dtype=torch.float)
88 | label = torch.randint(0, 19, (2, 480, 640), dtype=torch.long)
89 | loss_fn = Dice()
90 | y = loss_fn(pred, label)
91 | print(y)
--------------------------------------------------------------------------------
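The configs above select OhemCrossEntropy, which keeps only the hardest pixels (loss above -log(0.7), but never fewer than 1/16 of the valid pixels). A small sketch of building it through get_loss with the same values used in the configs:

import torch
from semseg.losses import get_loss

loss_fn = get_loss('OhemCrossEntropy', ignore_label=255, cls_weights=None)

logits = torch.randn(2, 19, 64, 64)         # [B, C, H, W] logits
labels = torch.randint(0, 19, (2, 64, 64))  # [B, H, W] integer labels
print(loss_fn(logits, labels))              # scalar OHEM loss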
/semseg/datasets/fmb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from torch import Tensor
5 | from torch.utils.data import Dataset
6 | import torchvision.transforms.functional as TF
7 | from torchvision import io
8 | from pathlib import Path
9 | from typing import Tuple
10 | import glob
11 | import einops
12 | from torch.utils.data import DataLoader
13 | from torch.utils.data import DistributedSampler, RandomSampler
14 | from semseg.augmentations_mm import get_train_augmentation
15 |
16 | class FMB(Dataset):
17 | """
18 | num_classes: 14
19 | """
20 | CLASSES = ["Road", "Sidewalk", "Building", "Traffic Light", "Traffic Sign", "Vegetation", "Sky", "Person", "Car", "Truck", "Bus", "Motorcycle", "Bicycle", "Pole"]
21 |
22 | PALETTE = torch.tensor([[70, 70, 70],
23 | [100, 40, 40],
24 | [55, 90, 80],
25 | [220, 20, 60],
26 | [153, 153, 153],
27 | [157, 234, 50],
28 | [128, 64, 128],
29 | [244, 35, 232],
30 | [107, 142, 35],
31 | [0, 0, 142],
32 | [102, 102, 156],
33 | [220, 220, 0],
34 | [70, 130, 180],
35 | [81, 0, 81],
36 | [150, 100, 100],
37 | ])
38 |
39 | def __init__(self, root: str = 'data/FMB', split: str = 'train', transform = None, modals = ['img', 'thermal'], case = None) -> None:
40 | super().__init__()
41 | assert split in ['train', 'val']
42 | self.transform = transform
43 | self.n_classes = len(self.CLASSES)
44 | self.ignore_label = 255
45 | self.modals = modals
46 | if split == 'val':
47 | split = 'test'
48 | self.files = sorted(glob.glob(os.path.join(*[root, split, 'Visible', '*.png'])))
49 | # --- debug
50 | # self.files = sorted(glob.glob(os.path.join(*[root, 'img', '*', split, '*', '*.png'])))[:100]
51 | # --- split as case
52 | if not self.files:
53 | raise Exception(f"No images found in {os.path.join(root, split, 'Visible')}")
54 | print(f"Found {len(self.files)} {split} images.")
55 |
56 | def __len__(self) -> int:
57 | return len(self.files)
58 |
59 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
60 | item_name = self.files[index].split("/")[-1].split(".")[0]
61 | rgb = str(self.files[index])
62 | thermal = rgb.replace('/Visible', '/Infrared')
63 | lbl_path = rgb.replace('/Visible', '/Label')
64 |
65 | sample = {}
66 | sample['img'] = io.read_image(rgb)[:3, ...]
67 | H, W = sample['img'].shape[1:]
68 | if 'thermal' in self.modals:
69 | sample['thermal'] = self._open_img(thermal)
70 | label = io.read_image(lbl_path)[0,...].unsqueeze(0)
71 | label[label==255] = 0
72 | label -= 1
73 | sample['mask'] = label
74 |
75 | if self.transform:
76 | sample = self.transform(sample)
77 | label = sample['mask']
78 | del sample['mask']
79 | label = self.encode(label.squeeze().numpy()).long()
80 | sample = [sample[k] for k in self.modals]
81 | # return sample, label, item_name
82 | return sample, label
83 |
84 | def _open_img(self, file):
85 | img = io.read_image(file)
86 | C, H, W = img.shape
87 | if C == 4:
88 | img = img[:3, ...]
89 | if C == 1:
90 | img = img.repeat(3, 1, 1)
91 | return img
92 |
93 | def encode(self, label: Tensor) -> Tensor:
94 | return torch.from_numpy(label)
95 |
96 |
97 | if __name__ == '__main__':
98 | cases = ['cloud', 'fog', 'night', 'rain', 'sun', 'motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres']
99 | traintransform = get_train_augmentation((1024, 1024), seg_fill=255)
100 | for case in cases:
101 |
102 | trainset = FMB(transform=traintransform, split='val', case=case)
103 | trainloader = DataLoader(trainset, batch_size=2, num_workers=2, drop_last=False, pin_memory=False)
104 |
105 | for i, (sample, lbl) in enumerate(trainloader):
106 | print(torch.unique(lbl))
--------------------------------------------------------------------------------
/semseg/schedulers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from torch.optim.lr_scheduler import _LRScheduler
4 |
5 |
6 | class PolyLR(_LRScheduler):
7 | def __init__(self, optimizer, max_iter, decay_iter=1, power=0.9, last_epoch=-1) -> None:
8 | self.decay_iter = decay_iter
9 | self.max_iter = max_iter
10 | self.power = power
11 | super().__init__(optimizer, last_epoch=last_epoch)
12 |
13 | def get_lr(self):
14 | if self.last_epoch % self.decay_iter or self.last_epoch % self.max_iter:
15 | return self.base_lrs
16 | else:
17 | factor = (1 - self.last_epoch / float(self.max_iter)) ** self.power
18 | return [factor*lr for lr in self.base_lrs]
19 |
20 |
21 | class WarmupLR(_LRScheduler):
22 | def __init__(self, optimizer, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:
23 | self.warmup_iter = warmup_iter
24 | self.warmup_ratio = warmup_ratio
25 | self.warmup = warmup
26 | super().__init__(optimizer, last_epoch)
27 |
28 | def get_lr(self):
29 | ratio = self.get_lr_ratio()
30 | return [ratio * lr for lr in self.base_lrs]
31 |
32 | def get_lr_ratio(self):
33 | return self.get_warmup_ratio() if self.last_epoch < self.warmup_iter else self.get_main_ratio()
34 |
35 | def get_main_ratio(self):
36 | raise NotImplementedError
37 |
38 | def get_warmup_ratio(self):
39 | assert self.warmup in ['linear', 'exp']
40 | alpha = self.last_epoch / self.warmup_iter
41 |
42 | return self.warmup_ratio + (1. - self.warmup_ratio) * alpha if self.warmup == 'linear' else self.warmup_ratio ** (1. - alpha)
43 |
44 |
45 | class WarmupPolyLR(WarmupLR):
46 | def __init__(self, optimizer, power, max_iter, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:
47 | self.power = power
48 | self.max_iter = max_iter
49 | super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)
50 |
51 | def get_main_ratio(self):
52 | real_iter = self.last_epoch - self.warmup_iter
53 | real_max_iter = self.max_iter - self.warmup_iter
54 | alpha = real_iter / real_max_iter
55 |
56 | return (1 - alpha) ** self.power
57 |
58 |
59 | class WarmupExpLR(WarmupLR):
60 | def __init__(self, optimizer, gamma, interval=1, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:
61 | self.gamma = gamma
62 | self.interval = interval
63 | super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)
64 |
65 | def get_main_ratio(self):
66 | real_iter = self.last_epoch - self.warmup_iter
67 | return self.gamma ** (real_iter // self.interval)
68 |
69 |
70 | class WarmupCosineLR(WarmupLR):
71 | def __init__(self, optimizer, max_iter, eta_ratio=0, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:
72 | self.eta_ratio = eta_ratio
73 | self.max_iter = max_iter
74 | super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)
75 |
76 | def get_main_ratio(self):
77 | real_iter = self.last_epoch - self.warmup_iter
78 | real_max_iter = self.max_iter - self.warmup_iter
79 |
80 | return self.eta_ratio + (1 - self.eta_ratio) * (1 + math.cos(math.pi * real_iter / real_max_iter)) / 2
81 |
82 |
83 |
84 | __all__ = ['polylr', 'warmuppolylr', 'warmupcosinelr', 'warmupsteplr']
85 |
86 |
87 | def get_scheduler(scheduler_name: str, optimizer, max_iter: int, power: int, warmup_iter: int, warmup_ratio: float):
88 | assert scheduler_name in __all__, f"Unavailable scheduler name >> {scheduler_name}.\nAvailable schedulers: {__all__}"
89 | if scheduler_name == 'warmuppolylr':
90 | return WarmupPolyLR(optimizer, power, max_iter, warmup_iter, warmup_ratio, warmup='linear')
91 | elif scheduler_name == 'warmupcosinelr':
92 | return WarmupCosineLR(optimizer, max_iter, warmup_iter=warmup_iter, warmup_ratio=warmup_ratio)
93 | return PolyLR(optimizer, max_iter)
94 |
95 |
96 | if __name__ == '__main__':
97 | model = torch.nn.Conv2d(3, 16, 3, 1, 1)
98 | optim = torch.optim.SGD(model.parameters(), lr=1e-3)
99 |
100 | max_iter = 20000
101 | sched = WarmupPolyLR(optim, power=0.9, max_iter=max_iter, warmup_iter=200, warmup_ratio=0.1, warmup='exp', last_epoch=-1)
102 |
103 | lrs = []
104 |
105 | for _ in range(max_iter):
106 | lr = sched.get_lr()[0]
107 | lrs.append(lr)
108 | optim.step()
109 | sched.step()
110 |
111 | import matplotlib.pyplot as plt
112 | import numpy as np
113 |
114 | plt.plot(np.arange(len(lrs)), np.array(lrs))
115 | plt.grid()
116 | plt.show()
--------------------------------------------------------------------------------
/semseg/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | import torch
5 | import matplotlib.pyplot as plt
6 | from torch.utils.data import DataLoader
7 | from torchvision import transforms as T
8 | from torchvision.utils import make_grid
9 | from semseg.augmentations import Compose, Normalize, RandomResizedCrop
10 | from PIL import Image, ImageDraw, ImageFont
11 |
12 |
13 | def visualize_dataset_sample(dataset, root, split='val', batch_size=4):
14 | transform = Compose([
15 | RandomResizedCrop((512, 512), scale=(1.0, 1.0)),
16 | Normalize()
17 | ])
18 |
19 | dataset = dataset(root, split=split, transform=transform)
20 | dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
21 | image, label = next(iter(dataloader))
22 |
23 | print(f"Image Shape\t: {image.shape}")
24 | print(f"Label Shape\t: {label.shape}")
25 | print(f"Classes\t\t: {label.unique().tolist()}")
26 |
27 | label[label == -1] = 0
28 | label[label == 255] = 0
29 | labels = [dataset.PALETTE[lbl.to(int)].permute(2, 0, 1) for lbl in label]
30 | labels = torch.stack(labels)
31 |
32 | inv_normalize = T.Normalize(
33 | mean=(-0.485/0.229, -0.456/0.224, -0.406/0.225),
34 | std=(1/0.229, 1/0.224, 1/0.225)
35 | )
36 | image = inv_normalize(image)
37 | image *= 255
38 | images = torch.vstack([image, labels])
39 |
40 | plt.imshow(make_grid(images, nrow=4).to(torch.uint8).numpy().transpose((1, 2, 0)))
41 | plt.show()
42 |
43 |
44 | colors = [
45 | [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
46 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
47 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
48 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
49 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
50 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
51 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
52 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
53 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
54 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
55 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
56 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
57 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
58 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
59 | [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
60 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
61 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
62 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
63 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 0], [92, 0, 255]
64 | ]
65 |
66 |
67 | def generate_palette(num_classes, background: bool = False):
68 | random.shuffle(colors)
69 | if background:
70 | palette = [[0, 0, 0]]
71 | palette += colors[:num_classes-1]
72 | else:
73 | palette = colors[:num_classes]
74 | return np.array(palette)
75 |
76 |
77 | def draw_text(image: torch.Tensor, seg_map: torch.Tensor, labels: list, fontsize: int = 15):
78 | image = image.to(torch.uint8)
79 | font = ImageFont.truetype("Helvetica.ttf", fontsize)
80 | pil_image = Image.fromarray(image.numpy())
81 | draw = ImageDraw.Draw(pil_image)
82 |
83 | indices = seg_map.unique().tolist()
84 | classes = [labels[index] for index in indices]
85 |
86 | for idx, cls in zip(indices, classes):
87 | mask = seg_map == idx
88 | mask = mask.squeeze().numpy()
89 | center = np.median((mask == 1).nonzero(), axis=1)[::-1]
90 | bbox = draw.textbbox(center, cls, font=font)
91 | bbox = (bbox[0]-3, bbox[1]-3, bbox[2]+3, bbox[3]+3)
92 | draw.rectangle(bbox, fill=(255, 255, 255), width=1)
93 | draw.text(center, cls, fill=(0, 0, 0), font=font)
94 | return pil_image
--------------------------------------------------------------------------------
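generate_palette shuffles the color table and returns one RGB row per class, prepending black when background=True, while draw_text overlays class names at the median location of each mask. A tiny sketch of the palette helper:

from semseg.utils.visualize import generate_palette

palette = generate_palette(5, background=True)   # shape (5, 3); first row is [0, 0, 0]
print(palette.shape, palette[0])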
/.gitignore:
--------------------------------------------------------------------------------
1 | # Repo-specific GitIgnore ----------------------------------------------------------------------------------------------
2 | *.jpg
3 | *.jpeg
4 | *.png
5 | *.bmp
6 | *.tif
7 | *.tiff
8 | *.heic
9 | *.JPG
10 | *.JPEG
11 | *.PNG
12 | *.BMP
13 | *.TIF
14 | *.TIFF
15 | *.HEIC
16 | *.mp4
17 | *.mov
18 | *.MOV
19 | *.avi
20 | *.data
21 | *.json
22 | *.pth
23 | *.cfg
24 | !cfg/yolov3*.cfg
25 |
26 | storage.googleapis.com
27 | runs/*
28 | data/*
29 | !data/images/zidane.jpg
30 | !data/images/bus.jpg
31 | !data/coco.names
32 | !data/coco_paper.names
33 | !data/coco.data
34 | !data/coco_*.data
35 | !data/coco_*.txt
36 | !data/trainvalno5k.shapes
37 | !data/*.sh
38 |
39 | test.py
40 | test_imgs/
41 |
42 | pycocotools/*
43 | results*.txt
44 | gcp_test*.sh
45 |
46 | checkpoints/
47 | # output/
48 | # output*/
49 | *events*
50 | assests/*/
51 |
52 | # Datasets -------------------------------------------------------------------------------------------------------------
53 | coco/
54 | coco128/
55 | VOC/
56 |
57 | # MATLAB GitIgnore -----------------------------------------------------------------------------------------------------
58 | *.m~
59 | *.mat
60 | !targets*.mat
61 |
62 | # Neural Network weights -----------------------------------------------------------------------------------------------
63 | *.weights
64 | *.pt
65 | *.onnx
66 | *.mlmodel
67 | *.torchscript
68 | darknet53.conv.74
69 | yolov3-tiny.conv.15
70 |
71 | # GitHub Python GitIgnore ----------------------------------------------------------------------------------------------
72 | # Byte-compiled / optimized / DLL files
73 | __pycache__/
74 | *.py[cod]
75 | *$py.class
76 |
77 | # C extensions
78 | *.so
79 |
80 | # Distribution / packaging
81 | .Python
82 | env/
83 | build/
84 | develop-eggs/
85 | dist/
86 | downloads/
87 | eggs/
88 | .eggs/
89 | lib/
90 | lib64/
91 | parts/
92 | sdist/
93 | var/
94 | wheels/
95 | *.egg-info/
96 | wandb/
97 | .installed.cfg
98 | *.egg
99 |
100 |
101 | # PyInstaller
102 | # Usually these files are written by a python script from a template
103 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
104 | *.manifest
105 | *.spec
106 |
107 | # Installer logs
108 | pip-log.txt
109 | pip-delete-this-directory.txt
110 |
111 | # Unit test / coverage reports
112 | htmlcov/
113 | .tox/
114 | .coverage
115 | .coverage.*
116 | .cache
117 | nosetests.xml
118 | coverage.xml
119 | *.cover
120 | .hypothesis/
121 |
122 | # Translations
123 | *.mo
124 | *.pot
125 |
126 | # Django stuff:
127 | # *.log
128 | local_settings.py
129 |
130 | # Flask stuff:
131 | instance/
132 | .webassets-cache
133 |
134 | # Scrapy stuff:
135 | .scrapy
136 |
137 | # Sphinx documentation
138 | docs/_build/
139 |
140 | # PyBuilder
141 | target/
142 |
143 | # Jupyter Notebook
144 | .ipynb_checkpoints
145 |
146 | # pyenv
147 | .python-version
148 |
149 | # celery beat schedule file
150 | celerybeat-schedule
151 |
152 | # SageMath parsed files
153 | *.sage.py
154 |
155 | # dotenv
156 | .env
157 |
158 | # virtualenv
159 | .venv*
160 | venv*/
161 | ENV*/
162 |
163 | # Spyder project settings
164 | .spyderproject
165 | .spyproject
166 |
167 | # Rope project settings
168 | .ropeproject
169 |
170 | # mkdocs documentation
171 | /site
172 |
173 | # mypy
174 | .mypy_cache/
175 |
176 |
177 | # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore -----------------------------------------------
178 |
179 | # General
180 | .DS_Store
181 | .AppleDouble
182 | .LSOverride
183 |
184 | # Icon must end with two \r
185 | Icon
186 | Icon?
187 |
188 | # Thumbnails
189 | ._*
190 |
191 | # Files that might appear in the root of a volume
192 | .DocumentRevisions-V100
193 | .fseventsd
194 | .Spotlight-V100
195 | .TemporaryItems
196 | .Trashes
197 | .VolumeIcon.icns
198 | .com.apple.timemachine.donotpresent
199 |
200 | # Directories potentially created on remote AFP share
201 | .AppleDB
202 | .AppleDesktop
203 | Network Trash Folder
204 | Temporary Items
205 | .apdisk
206 |
207 |
208 | # https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore
209 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
210 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
211 |
212 | # User-specific stuff:
213 | .idea/*
214 | .idea/**/workspace.xml
215 | .idea/**/tasks.xml
216 | .idea/dictionaries
217 | .html # Bokeh Plots
218 | .pg # TensorFlow Frozen Graphs
219 | .avi # videos
220 |
221 | # Sensitive or high-churn files:
222 | .idea/**/dataSources/
223 | .idea/**/dataSources.ids
224 | .idea/**/dataSources.local.xml
225 | .idea/**/sqlDataSources.xml
226 | .idea/**/dynamic.xml
227 | .idea/**/uiDesigner.xml
228 |
229 | # Gradle:
230 | .idea/**/gradle.xml
231 | .idea/**/libraries
232 |
233 | # CMake
234 | cmake-build-debug/
235 | cmake-build-release/
236 |
237 | # Mongo Explorer plugin:
238 | .idea/**/mongoSettings.xml
239 |
240 | ## File-based project format:
241 | *.iws
242 |
243 | ## Plugin-specific files:
244 |
245 | # IntelliJ
246 | out/
247 |
248 | # mpeltonen/sbt-idea plugin
249 | .idea_modules/
250 |
251 | # JIRA plugin
252 | atlassian-ide-plugin.xml
253 |
254 | # Cursive Clojure plugin
255 | .idea/replstate.xml
256 |
257 | # Crashlytics plugin (for Android Studio and IntelliJ)
258 | com_crashlytics_export_strings.xml
259 | crashlytics.properties
260 | crashlytics-build.properties
261 | fabric.properties
262 |
263 | output/
264 | data/
265 | wandb/
266 | output*.txt
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: mmsformer
2 | channels:
3 | - pytorch
4 | - defaults
5 | - conda-forge
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=4.5=1_gnu
9 | - _pytorch_select=0.1=cpu_0
10 | - blas=1.0=mkl
11 | - bzip2=1.0.8=h7f98852_4
12 | - ca-certificates=2021.10.26=h06a4308_2
13 | - certifi=2021.10.8=py38h06a4308_2
14 | - cffi=1.14.6=py38ha65f79e_0
15 | - cudatoolkit=11.3.1=h2bc3f7f_2
16 | - cudnn=8.2.1.32=h86fa8c9_0
17 | - ffmpeg=4.3=hf484d3e_0
18 | - freetype=2.10.4=h5ab3b9f_0
19 | - future=0.18.2=py38h578d9bd_4
20 | - gmp=6.2.1=h58526e2_0
21 | - gnutls=3.6.13=h85f3911_1
22 | - intel-openmp=2021.3.0=h06a4308_3350
23 | - jpeg=9d=h7f8727e_0
24 | - lame=3.100=h7f98852_1001
25 | - lcms2=2.12=h3be6417_0
26 | - ld_impl_linux-64=2.35.1=h7274673_9
27 | - libblas=3.9.0=11_linux64_mkl
28 | - libffi=3.3=he6710b0_2
29 | - libgcc-ng=9.3.0=h5101ec6_17
30 | - libgomp=9.3.0=h5101ec6_17
31 | - libiconv=1.16=h516909a_0
32 | - liblapack=3.9.0=11_linux64_mkl
33 | - libpng=1.6.37=hbc83047_0
34 | - libprotobuf=3.16.0=h780b84a_0
35 | - libstdcxx-ng=9.3.0=hd4cf53a_17
36 | - libtiff=4.2.0=h85742a9_0
37 | - libuv=1.40.0=h7b6447c_0
38 | - libwebp-base=1.2.0=h27cfd23_0
39 | - lz4-c=1.9.3=h295c915_1
40 | - magma=2.5.4=h6103c52_2
41 | - mkl=2021.3.0=h06a4308_520
42 | - mkl-service=2.4.0=py38h7f8727e_0
43 | - mkl_fft=1.3.0=py38h42c9631_2
44 | - mkl_random=1.2.2=py38h51133e4_0
45 | - nccl=2.11.4.1=hdc17891_0
46 | - ncurses=6.2=he6710b0_1
47 | - nettle=3.6=he412f7d_0
48 | - ninja=1.10.2=hff7bd54_1
49 | - numpy=1.21.2=py38h20f2e39_0
50 | - numpy-base=1.21.2=py38h79a1101_0
51 | - olefile=0.46=pyhd3eb1b0_0
52 | - openh264=2.1.1=h780b84a_0
53 | - openjpeg=2.4.0=h3ad879b_0
54 | - openssl=1.1.1m=h7f8727e_0
55 | - pillow=8.3.1=py38h2c7a002_0
56 | - pycparser=2.21=pyhd8ed1ab_0
57 | - python=3.8.12=h12debd9_0
58 | - python_abi=3.8=2_cp38
59 | - pytorch=1.9.0=cuda112py38h3d13190_1
60 | - pytorch-gpu=1.9.0=cuda112py38h0bbbad9_1
61 | - pytorch-mutex=1.0=cuda
62 | - readline=8.1=h27cfd23_0
63 | - six=1.16.0=pyhd3eb1b0_0
64 | - sleef=3.5.1=h7f98852_1
65 | - sqlite=3.36.0=hc218d9a_0
66 | - tk=8.6.11=h1ccaba5_0
67 | - torchaudio=0.9.0=py38
68 | - torchvision=0.10.0=py38cuda112h04b465a_0_cuda
69 | - typing_extensions=3.10.0.2=pyh06a4308_0
70 | - xz=5.2.5=h7b6447c_0
71 | - yaml=0.2.5=h7b6447c_0
72 | - zlib=1.2.11=h7b6447c_3
73 | - zstd=1.4.9=haebb681_0
74 | - pip:
75 | - absl-py==1.2.0
76 | - addict==2.4.0
77 | - appdirs==1.4.4
78 | - argon2-cffi==21.3.0
79 | - argon2-cffi-bindings==21.2.0
80 | - asttokens==2.0.5
81 | - attrs==21.4.0
82 | - backcall==0.2.0
83 | - bleach==4.1.0
84 | - cachetools==5.0.0
85 | - charset-normalizer==2.1.1
86 | - click==8.1.6
87 | - cycler==0.11.0
88 | - dataclasses==0.6
89 | - debugpy==1.5.1
90 | - decorator==5.1.1
91 | - defusedxml==0.7.1
92 | - descartes==1.1.0
93 | - docker-pycreds==0.4.0
94 | - easydict==1.9
95 | - einops==0.4.1
96 | - entrypoints==0.4
97 | - executing==0.8.3
98 | - fire==0.4.0
99 | - fvcore==0.1.5.post20220512
100 | - gitdb==4.0.10
101 | - gitpython==3.1.32
102 | - google-auth==2.11.0
103 | - google-auth-oauthlib==0.4.6
104 | - grpcio==1.48.1
105 | - idna==3.3
106 | - importlib-metadata==4.12.0
107 | - importlib-resources==5.4.0
108 | - iopath==0.1.10
109 | - ipykernel==6.9.1
110 | - ipython==8.1.0
111 | - ipython-genutils==0.2.0
112 | - ipywidgets==7.6.5
113 | - jedi==0.18.1
114 | - jinja2==3.0.3
115 | - joblib==1.1.0
116 | - jsonschema==4.4.0
117 | - jupyter==1.0.0
118 | - jupyter-client==7.1.2
119 | - jupyter-console==6.4.0
120 | - jupyter-core==4.9.2
121 | - jupyterlab-pygments==0.1.2
122 | - jupyterlab-widgets==1.0.2
123 | - kiwisolver==1.3.2
124 | - markdown==3.4.1
125 | - markupsafe==2.1.1
126 | - matplotlib==3.4.3
127 | - matplotlib-inline==0.1.3
128 | - mistune==0.8.4
129 | - mmcv==0.6.2
130 | - nbclient==0.5.11
131 | - nbconvert==6.4.2
132 | - nbformat==5.1.3
133 | - nest-asyncio==1.5.4
134 | - notebook==6.4.8
135 | - nuscenes-devkit==1.1.9
136 | - oauthlib==3.2.1
137 | - opencv-python==4.5.3.56
138 | - packaging==21.3
139 | - pandocfilters==1.5.0
140 | - parso==0.8.3
141 | - pathtools==0.1.2
142 | - pexpect==4.8.0
143 | - pickleshare==0.7.5
144 | - pip==22.0.3
145 | - plyfile==0.7.4
146 | - portalocker==2.5.1
147 | - prometheus-client==0.13.1
148 | - prompt-toolkit==3.0.28
149 | - protobuf==3.18.1
150 | - psutil==5.9.5
151 | - ptyprocess==0.7.0
152 | - pure-eval==0.2.2
153 | - pyasn1==0.4.8
154 | - pyasn1-modules==0.2.8
155 | - pycocotools==2.0.4
156 | - pygments==2.11.2
157 | - pyparsing==3.0.6
158 | - pyquaternion==0.9.9
159 | - pyrsistent==0.18.1
160 | - python-dateutil==2.8.2
161 | - pyyaml==6.0
162 | - pyzmq==22.3.0
163 | - qtconsole==5.2.2
164 | - qtpy==2.0.1
165 | - requests==2.28.1
166 | - requests-oauthlib==1.3.1
167 | - rsa==4.9
168 | - scikit-learn==1.0.2
169 | - scipy==1.7.1
170 | - send2trash==1.8.0
171 | - sentry-sdk==1.28.1
172 | - setproctitle==1.3.2
173 | - setuptools==59.5.0
174 | - shapely==1.8.1.post1
175 | - smmap==5.0.0
176 | - stack-data==0.2.0
177 | - tabulate==0.8.10
178 | - tensorboard==2.10.0
179 | - tensorboard-data-server==0.6.1
180 | - tensorboard-plugin-wit==1.8.1
181 | - tensorboardx==2.4
182 | - termcolor==1.1.0
183 | - terminado==0.13.1
184 | - testpath==0.6.0
185 | - threadpoolctl==3.1.0
186 | - timm==0.4.12
187 | - tornado==6.1
188 | - tqdm==4.62.3
189 | - traitlets==5.1.1
190 | - urllib3==1.26.12
191 | - wandb==0.15.7
192 | - wcwidth==0.2.5
193 | - webencodings==0.5.1
194 | - werkzeug==2.2.2
195 | - wheel==0.37.1
196 | - widgetsnbextension==3.5.2
197 | - yacs==0.1.8
198 | - yapf==0.32.0
199 | - zipp==3.7.0
--------------------------------------------------------------------------------
/tools/infer_mm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import yaml
4 | import math
5 | from torch import Tensor
6 | from torch.nn import functional as F
7 | from pathlib import Path
8 | from torchvision import io
9 | from torchvision import transforms as T
10 | import torchvision.transforms.functional as TF
11 | from semseg.models import *
12 | from semseg.datasets import *
13 | from semseg.utils.utils import timer
14 | from semseg.utils.visualize import draw_text
15 | import glob
16 | import os
17 | from PIL import Image, ImageDraw, ImageFont
18 |
19 |
20 | class SemSeg:
21 | def __init__(self, cfg) -> None:
22 | # inference device cuda or cpu
23 | self.device = torch.device(cfg['DEVICE'])
24 |
25 | # get dataset classes' colors and labels
26 | self.palette = eval(cfg['DATASET']['NAME']).PALETTE
27 | self.labels = eval(cfg['DATASET']['NAME']).CLASSES
28 |
29 | # initialize the model and load weights and send to device
30 | self.model = eval(cfg['MODEL']['NAME'])(cfg['MODEL']['BACKBONE'], len(self.palette), cfg['DATASET']['MODALS'])
31 | msg = self.model.load_state_dict(torch.load(cfg['EVAL']['MODEL_PATH'], map_location='cpu'))
32 | print(msg)
33 | self.model = self.model.to(self.device)
34 | self.model.eval()
35 |
36 | # preprocess parameters and transformation pipeline
37 | self.size = cfg['TEST']['IMAGE_SIZE']
38 | self.tf_pipeline_img = T.Compose([
39 | T.Resize(self.size),
40 | T.Lambda(lambda x: x / 255),
41 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
42 | T.Lambda(lambda x: x.unsqueeze(0))
43 | ])
44 | self.tf_pipeline_modal = T.Compose([
45 | T.Resize(self.size),
46 | T.Lambda(lambda x: x / 255),
47 | T.Lambda(lambda x: x.unsqueeze(0))
48 | ])
49 |
50 | def postprocess(self, orig_img: Tensor, seg_map: Tensor, overlay: bool) -> Tensor:
51 | seg_map = seg_map.softmax(dim=1).argmax(dim=1).cpu().to(int)
52 |
53 | seg_image = self.palette[seg_map].squeeze()
54 | if overlay:
55 | seg_image = (orig_img.permute(1, 2, 0) * 0.4) + (seg_image * 0.6)
56 |
57 | image = seg_image.to(torch.uint8)
58 | pil_image = Image.fromarray(image.numpy())
59 | return pil_image
60 |
61 | @torch.inference_mode()
62 | @timer
63 | def model_forward(self, img: Tensor) -> Tensor:
64 | return self.model(img)
65 |
66 | def _open_img(self, file):
67 | img = io.read_image(file)
68 | C, H, W = img.shape
69 | if C == 4:
70 | img = img[:3, ...]
71 | if C == 1:
72 | img = img.repeat(3, 1, 1)
73 | return img
74 |
75 | def predict(self, img_fname: str, overlay: bool) -> Tensor:
76 |         if cfg['DATASET']['NAME'] == 'DELIVER':  # NOTE: the path rules in predict() cover the DELIVER and KITTI360 layouts only
77 | x1 = img_fname.replace('/img', '/hha').replace('_rgb', '_depth')
78 | x2 = img_fname.replace('/img', '/lidar').replace('_rgb', '_lidar')
79 | x3 = img_fname.replace('/img', '/event').replace('_rgb', '_event')
80 | lbl_path = img_fname.replace('/img', '/semantic').replace('_rgb', '_semantic')
81 | elif cfg['DATASET']['NAME'] == 'KITTI360':
82 | x1 = os.path.join(img_fname.replace('data_2d_raw', 'data_2d_hha'))
83 | x2 = os.path.join(img_fname.replace('data_2d_raw', 'data_2d_lidar'))
84 | x2 = x2.replace('.png', '_color.png')
85 | x3 = os.path.join(img_fname.replace('data_2d_raw', 'data_2d_event'))
86 | x3 = x3.replace('/image_00/data_rect/', '/').replace('.png', '_event_image.png')
87 | lbl_path = os.path.join(*[img_fname.replace('data_2d_raw', 'data_2d_semantics/train').replace('data_rect', 'semantic')])
88 |
89 | image = io.read_image(img_fname)[:3, ...]
90 | img = self.tf_pipeline_img(image).to(self.device)
91 | # --- modals
92 | x1 = self._open_img(x1)
93 | x1 = self.tf_pipeline_modal(x1).to(self.device)
94 | x2 = self._open_img(x2)
95 | x2 = self.tf_pipeline_modal(x2).to(self.device)
96 | x3 = self._open_img(x3)
97 | x3 = self.tf_pipeline_modal(x3).to(self.device)
98 | label = io.read_image(lbl_path)[0,...].unsqueeze(0)
99 | label[label==255] = 0
100 | label -= 1
101 |
102 | sample = [img, x1, x2, x3][:len(modals)]
103 |
104 | seg_map = self.model_forward(sample)
105 | seg_map = self.postprocess(image, seg_map, overlay)
106 | return seg_map
107 |
108 |
109 | if __name__ == '__main__':
110 | parser = argparse.ArgumentParser()
111 | parser.add_argument('--cfg', type=str, default='configs/DELIVER.yaml')
112 | args = parser.parse_args()
113 | with open(args.cfg) as f:
114 | cfg = yaml.load(f, Loader=yaml.SafeLoader)
115 |
116 | # cases = ['cloud', 'fog', 'night', 'rain', 'sun', 'motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres', None]
117 | cases = ['lidarjitter']
118 |
119 | modals = cfg['DATASET']['MODALS']
120 |
121 | test_file = Path(cfg['TEST']['FILE'])
122 | if not test_file.exists():
123 | raise FileNotFoundError(test_file)
124 |
125 | # print(f"Model {cfg['MODEL']['NAME']} {cfg['MODEL']['BACKBONE']}")
126 | # print(f"Model {cfg['DATASET']['NAME']}")
127 |
128 | modals_name = ''.join([m[0] for m in cfg['DATASET']['MODALS']])
129 | save_dir = Path(cfg['SAVE_DIR']) / 'test_results' / (cfg['DATASET']['NAME']+'_'+cfg['MODEL']['BACKBONE']+'_'+modals_name)
130 |
131 | semseg = SemSeg(cfg)
132 |
133 | if test_file.is_file():
134 | segmap = semseg.predict(str(test_file), cfg['TEST']['OVERLAY'])
135 | segmap.save(save_dir / f"{str(test_file.stem)}.png")
136 | else:
137 | if cfg['DATASET']['NAME'] == 'DELIVER':
138 | files = sorted(glob.glob(os.path.join(*[str(test_file), 'img', '*', 'val', '*', '*.png']))) # --- Deliver
139 | elif cfg['DATASET']['NAME'] == 'KITTI360':
140 | source = os.path.join(test_file, 'val.txt')
141 | files = []
142 | with open(source) as f:
143 | files_ = f.readlines()
144 | for item in files_:
145 | file_name = item.strip()
146 | if ' ' in file_name:
147 | # --- KITTI-360
148 | file_name = os.path.join(*[str(test_file), file_name.split(' ')[0]])
149 | files.append(file_name)
150 | else:
151 | raise NotImplementedError()
152 |
153 | for file in files:
154 | print(file)
155 |         if '2013_05_28_drive_0000_sync' not in file:
156 | continue
157 | segmap = semseg.predict(file, cfg['TEST']['OVERLAY'])
158 | save_path = os.path.join(str(save_dir),file)
159 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
160 | segmap.save(save_path)
161 |
--------------------------------------------------------------------------------
/tools/val_mm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import yaml
4 | import math
5 | import os
6 | import time
7 | from pathlib import Path
8 | from tqdm import tqdm
9 | from tabulate import tabulate
10 | from torch.utils.data import DataLoader
11 | from torch.nn import functional as F
12 | from semseg.models import *
13 | from semseg.datasets import *
14 | from semseg.augmentations_mm import get_val_augmentation
15 | from semseg.metrics import Metrics
16 | from semseg.utils.utils import setup_cudnn
17 | from math import ceil
18 | import numpy as np
19 | from torch.utils.data import DistributedSampler, RandomSampler
20 | from torch import distributed as dist
21 | from torch.nn.parallel import DistributedDataParallel as DDP
22 | from semseg.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp, get_logger, cal_flops, print_iou
23 |
24 | def pad_image(img, target_size):
25 | rows_to_pad = max(target_size[0] - img.shape[2], 0)
26 | cols_to_pad = max(target_size[1] - img.shape[3], 0)
27 | padded_img = F.pad(img, (0, cols_to_pad, 0, rows_to_pad), "constant", 0)
28 | return padded_img
29 |
30 | @torch.no_grad()
31 | def sliding_predict(model, image, num_classes, flip=True):
32 | image_size = image[0].shape
33 | tile_size = (int(ceil(image_size[2]*1)), int(ceil(image_size[3]*1)))
34 | overlap = 1/3
35 |
36 | stride = ceil(tile_size[0] * (1 - overlap))
37 |
38 | num_rows = int(ceil((image_size[2] - tile_size[0]) / stride) + 1)
39 | num_cols = int(ceil((image_size[3] - tile_size[1]) / stride) + 1)
40 | total_predictions = torch.zeros((num_classes, image_size[2], image_size[3]), device=torch.device('cuda'))
41 | count_predictions = torch.zeros((image_size[2], image_size[3]), device=torch.device('cuda'))
42 | tile_counter = 0
43 |
44 | for row in range(num_rows):
45 | for col in range(num_cols):
46 | x_min, y_min = int(col * stride), int(row * stride)
47 | x_max = min(x_min + tile_size[1], image_size[3])
48 | y_max = min(y_min + tile_size[0], image_size[2])
49 |
50 | img = [modal[:, :, y_min:y_max, x_min:x_max] for modal in image]
51 | padded_img = [pad_image(modal, tile_size) for modal in img]
52 | tile_counter += 1
53 | padded_prediction = model(padded_img)
54 | if flip:
55 | fliped_img = [padded_modal.flip(-1) for padded_modal in padded_img]
56 | fliped_predictions = model(fliped_img)
57 | padded_prediction += fliped_predictions.flip(-1)
58 | predictions = padded_prediction[:, :, :img[0].shape[2], :img[0].shape[3]]
59 | count_predictions[y_min:y_max, x_min:x_max] += 1
60 | total_predictions[:, y_min:y_max, x_min:x_max] += predictions.squeeze(0)
61 |
62 | return total_predictions.unsqueeze(0)
63 |
64 | @torch.no_grad()
65 | def evaluate(model, dataloader, device, loss_fn=None):
66 | print('Evaluating...')
67 | model.eval()
68 | n_classes = dataloader.dataset.n_classes
69 | metrics = Metrics(n_classes, dataloader.dataset.ignore_label, device)
70 | sliding = False
71 | test_loss = 0.0
72 | iter = 0
73 | for images, labels in tqdm(dataloader):
74 | images = [x.to(device) for x in images]
75 | labels = labels.to(device)
76 | if sliding:
77 | # preds = sliding_predict(model, images, num_classes=n_classes).softmax(dim=1)
78 | preds = sliding_predict(model, images, num_classes=n_classes)
79 | else:
80 | # preds = model(images).softmax(dim=1)
81 | preds = model(images)
82 |
83 | metrics.update(preds.softmax(dim=1), labels)
84 |
85 | if loss_fn is not None:
86 | loss = loss_fn(preds, labels)
87 | test_loss += loss.item()
88 | iter += 1
89 |
90 | test_loss /= iter
91 | ious, miou = metrics.compute_iou()
92 | acc, macc = metrics.compute_pixel_acc()
93 | f1, mf1 = metrics.compute_f1()
94 |
95 | return acc, macc, f1, mf1, ious, miou, test_loss
96 |
97 |
98 | @torch.no_grad()
99 | def evaluate_msf(model, dataloader, device, scales, flip):
100 | model.eval()
101 |
102 | n_classes = dataloader.dataset.n_classes
103 | metrics = Metrics(n_classes, dataloader.dataset.ignore_label, device)
104 |
105 | for images, labels in tqdm(dataloader):
106 | labels = labels.to(device)
107 | B, H, W = labels.shape
108 | scaled_logits = torch.zeros(B, n_classes, H, W).to(device)
109 |
110 | for scale in scales:
111 | new_H, new_W = int(scale * H), int(scale * W)
112 | new_H, new_W = int(math.ceil(new_H / 32)) * 32, int(math.ceil(new_W / 32)) * 32
113 | scaled_images = [F.interpolate(img, size=(new_H, new_W), mode='bilinear', align_corners=True) for img in images]
114 | scaled_images = [scaled_img.to(device) for scaled_img in scaled_images]
115 | logits = model(scaled_images)
116 | logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True)
117 | scaled_logits += logits.softmax(dim=1)
118 |
119 | if flip:
120 | scaled_images = [torch.flip(scaled_img, dims=(3,)) for scaled_img in scaled_images]
121 | logits = model(scaled_images)
122 | logits = torch.flip(logits, dims=(3,))
123 | logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True)
124 | scaled_logits += logits.softmax(dim=1)
125 |
126 | metrics.update(scaled_logits, labels)
127 |
128 | acc, macc = metrics.compute_pixel_acc()
129 | f1, mf1 = metrics.compute_f1()
130 | ious, miou = metrics.compute_iou()
131 | return acc, macc, f1, mf1, ious, miou
132 |
133 |
134 | def main(cfg):
135 | device = torch.device(cfg['DEVICE'])
136 |
137 | eval_cfg = cfg['EVAL']
138 | transform = get_val_augmentation(eval_cfg['IMAGE_SIZE'])
139 | # cases = ['cloud', 'fog', 'night', 'rain', 'sun']
140 | # cases = ['motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres']
141 | cases = [None] # all
142 |
143 | model_path = Path(eval_cfg['MODEL_PATH'])
144 | if not model_path.exists():
145 |         raise FileNotFoundError(model_path)
146 | print(f"Evaluating {model_path}...")
147 |
148 | exp_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
149 | eval_path = os.path.join(os.path.dirname(eval_cfg['MODEL_PATH']), 'eval_{}.txt'.format(exp_time))
150 |
151 | for case in cases:
152 | dataset = eval(cfg['DATASET']['NAME'])(cfg['DATASET']['ROOT'], 'val', transform, cfg['DATASET']['MODALS'], case)
153 | # --- test set
154 | # dataset = eval(cfg['DATASET']['NAME'])(cfg['DATASET']['ROOT'], 'test', transform, cfg['DATASET']['MODALS'], case)
155 |
156 | model = eval(cfg['MODEL']['NAME'])(cfg['MODEL']['BACKBONE'], dataset.n_classes, cfg['DATASET']['MODALS'])
157 | msg = model.load_state_dict(torch.load(str(model_path), map_location='cpu'))
158 | print(msg)
159 | model = model.to(device)
160 | sampler_val = None
161 | dataloader = DataLoader(dataset, batch_size=eval_cfg['BATCH_SIZE'], num_workers=eval_cfg['BATCH_SIZE'], pin_memory=False, sampler=sampler_val)
162 | if True:
163 | if eval_cfg['MSF']['ENABLE']:
164 | acc, macc, f1, mf1, ious, miou = evaluate_msf(model, dataloader, device, eval_cfg['MSF']['SCALES'], eval_cfg['MSF']['FLIP'])
165 | else:
166 | acc, macc, f1, mf1, ious, miou, _ = evaluate(model, dataloader, device)
167 |
168 | table = {
169 | 'Class': list(dataset.CLASSES) + ['Mean'],
170 | 'IoU': ious + [miou],
171 | 'F1': f1 + [mf1],
172 | 'Acc': acc + [macc]
173 | }
174 | print("mIoU : {}".format(miou))
175 | print("Results saved in {}".format(eval_cfg['MODEL_PATH']))
176 |
177 | with open(eval_path, 'a+') as f:
178 | f.writelines(eval_cfg['MODEL_PATH'])
179 | f.write("\n============== Eval on {} {} images =================\n".format(case, len(dataset)))
180 | f.write("\n")
181 | print(tabulate(table, headers='keys'), file=f)
182 |
183 |
184 |
185 | if __name__ == '__main__':
186 | parser = argparse.ArgumentParser()
187 | parser.add_argument('--cfg', type=str, default='configs/mcubes_rgbadn.yaml')
188 | args = parser.parse_args()
189 |
190 | with open(args.cfg) as f:
191 | cfg = yaml.load(f, Loader=yaml.SafeLoader)
192 |
193 | setup_cudnn()
194 | # gpu = setup_ddp()
195 | # main(cfg, gpu)
196 | main(cfg)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## MMSFormer: Multimodal Transformer for Material and Semantic Segmentation
4 |
5 |
6 |
7 |
8 |
9 | [](https://paperswithcode.com/sota/semantic-segmentation-on-mcubes?p=multimodal-transformer-for-material)
10 |
11 | [](https://paperswithcode.com/sota/semantic-segmentation-on-fmb-dataset?p=multimodal-transformer-for-material)
12 |
13 | [](https://paperswithcode.com/sota/thermal-image-segmentation-on-pst900?p=multimodal-transformer-for-material)
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 | ## Introduction
30 |
31 | Leveraging information across diverse modalities is known to enhance performance on multimodal segmentation tasks. However, effectively fusing information from different modalities remains challenging due to the unique characteristics of each modality. In this paper, we propose a novel fusion strategy that can effectively fuse information from different modality combinations. We also propose a new model named **M**ulti-**M**odal **S**egmentation Trans**Former** (**MMSFormer**) that incorporates the proposed fusion strategy to perform multimodal material and semantic segmentation tasks. MMSFormer outperforms current state-of-the-art models on three different datasets. Starting from a single input modality, performance improves progressively as additional modalities are incorporated, showcasing the effectiveness of the fusion block in combining useful information from diverse input modalities. Ablation studies show that different modules in the fusion block are crucial for overall model performance. Furthermore, our ablation studies highlight the capacity of different input modalities to improve performance in the identification of different types of materials.
32 |
33 | For more details, please check our [arXiv](https://arxiv.org/abs/2309.04001) paper.
34 |
35 | ## Updates
36 | - [x] 09/2023: Init repository.
37 | - [x] 09/2023: Release the code for MMSFormer.
38 | - [x] 09/2023: Release MMSFormer model weights. Download from [**GoogleDrive**](https://drive.google.com/drive/folders/1OPr7PUrL7hkBXogmHFzHuTJweHuJmlP-?usp=sharing).
39 | - [x] 01/2024: Update code, description and pretrained weights.
40 | - [x] 04/2024: Accepted by [IEEE Open Journal of Signal Processing](https://ieeexplore.ieee.org/document/10502124).
41 |
42 | ## MMSFormer model
43 |
44 |
45 |
46 | 
47 | **Figure:** Overall architecture of MMSFormer model and proposed fusion block.
48 |
49 |
50 |
51 | ## Environment
52 |
53 | First, create and activate the environment using the following commands:
54 | ```bash
55 | conda env create -f environment.yaml
56 | conda activate mmsformer
57 | ```
58 |
59 | ## Data preparation
60 | Download the datasets:
61 | - [MCubeS](https://github.com/kyotovision-public/multimodal-material-segmentation), for multimodal material segmentation with RGB-A-D-N modalities.
62 | - [FMB](https://github.com/JinyuanLiu-CV/SegMiF), the FMB dataset with RGB-Infrared modalities.
63 | - [PST](https://github.com/ShreyasSkandanS/pst900_thermal_rgb), the PST900 dataset with RGB-Thermal modalities.
64 |
65 | Then, put the datasets under the `data` directory as follows:
66 |
67 | ```
68 | data/
69 | ├── MCubeS
70 | │ ├── polL_color
71 | │ ├── polL_aolp_sin
72 | │ ├── polL_aolp_cos
73 | │ ├── polL_dolp
74 | │ ├── NIR_warped
75 | │ ├── NIR_warped_mask
76 | │ ├── GT
77 | │ ├── SSGT4MS
78 | │ ├── list_folder
79 | │ └── SS
80 | ├── FMB
81 | │ ├── test
82 | │ │ ├── color
83 | │ │ ├── Infrared
84 | │ │ ├── Label
85 | │ │ └── Visible
86 | │ ├── train
87 | │ │ ├── color
88 | │ │ ├── Infrared
89 | │ │ ├── Label
90 | │ │ └── Visible
91 | ├── PST
92 | │ ├── test
93 | │ │ ├── rgb
94 | │ │ ├── thermal
95 | │ │ └── labels
96 | │ ├── train
97 | │ │ ├── rgb
98 | │ │ ├── thermal
99 | │ │ └── labels
100 | ```
101 |
102 | ## Model Zoo
103 |
104 | ### MCubeS
105 | | Model-Modal | mIoU | weight |
106 | | :--------------- | :----- | :----- |
107 | | MCubeS-RGB | 50.44 | [GoogleDrive](https://drive.google.com/drive/folders/1TiC4spUgMGo8zO2iChpuuRo8cmZC2yeh?usp=sharing) |
108 | | MCubeS-RGB-A | 51.30 | [GoogleDrive](https://drive.google.com/drive/folders/1TiC4spUgMGo8zO2iChpuuRo8cmZC2yeh?usp=sharing) |
109 | | MCubeS-RGB-A-D | 52.03 | [GoogleDrive](https://drive.google.com/drive/folders/1TiC4spUgMGo8zO2iChpuuRo8cmZC2yeh?usp=sharing) |
110 | | MCubeS-RGB-A-D-N | 53.11 | [GoogleDrive](https://drive.google.com/drive/folders/1TiC4spUgMGo8zO2iChpuuRo8cmZC2yeh?usp=sharing) |
111 |
112 | ### FMB
113 | | Model-Modal | mIoU | weight |
114 | | :--------------- | :----- | :----- |
115 | | FMB-RGB | 57.17 | [GoogleDrive](https://drive.google.com/drive/folders/15kuBWiEHOxxOLMxvASYzPhdSgG8ZWfgm?usp=sharing) |
116 | | FMB-RGB-Infrared | 61.68 | [GoogleDrive](https://drive.google.com/drive/folders/15kuBWiEHOxxOLMxvASYzPhdSgG8ZWfgm?usp=sharing) |
117 |
118 | ### PST900
119 | | Model-Modal | mIoU | weight |
120 | | :--------------- | :----- | :----- |
121 | | PST-RGB-T | 87.45 | [GoogleDrive](https://drive.google.com/drive/folders/1yv7wfGVLrxBYQ3teDg3eL-zYJsie56Ll?usp=sharing) |
122 |
123 |
124 | ## Training
125 |
126 | Before training, please download the [pre-trained SegFormer weights](https://drive.google.com/drive/folders/10XgSW8f7ghRs9fJ0dE-EV8G2E_guVsT5) and put them in the correct directory following this structure:
127 |
128 | ```text
129 | checkpoints/pretrained/segformer
130 | ├── mit_b0.pth
131 | ├── mit_b1.pth
132 | ├── mit_b2.pth
133 | ├── mit_b3.pth
134 | └── mit_b4.pth
135 | ```
136 |
137 | To train the MMSFormer model, update the corresponding configuration file in `configs/` with the desired paths and hyper-parameters (a sketch of the main config fields follows the commands below). Then run:
138 |
139 | ```bash
140 | cd path/to/MMSFormer
141 | conda activate mmsformer
142 |
143 | python -m tools.train_mm --cfg configs/mcubes_rgbadn.yaml
144 |
145 | python -m tools.train_mm --cfg configs/fmb_rgbt.yaml
146 |
147 | python -m tools.train_mm --cfg configs/pst_rgbt.yaml
148 | ```
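
For reference, `tools/train_mm.py` reads roughly the fields sketched below from the YAML config. This is only a sketch with placeholder values (the backbone string, modality names, loss/optimizer/scheduler names, and hyper-parameters shown here are assumptions); please copy the actual settings from the shipped files in `configs/`.

```yaml
DEVICE: cuda
SAVE_DIR: output
GPUs: 1
GPU_IDs: [0]
USE_WANDB: false
WANDB_NAME: mcubes_rgbadn              # experiment / run name (placeholder)

MODEL:
  NAME: MMSFormer
  BACKBONE: MMSFormer-B4               # placeholder; use the exact backbone string from the shipped configs
  PRETRAINED: checkpoints/pretrained/segformer/mit_b4.pth
  RESUME: ''                           # path to a *_checkpoint.pth to resume training from

DATASET:
  NAME: MCubeS                         # MCubeS | FMB | PST
  ROOT: data/MCubeS
  MODALS: [image, aolp, dolp, nir]     # placeholder; use the modality names from the shipped configs
  IGNORE_LABEL: 255                    # placeholder ignore index

TRAIN:
  IMAGE_SIZE: [512, 512]
  BATCH_SIZE: 4
  EPOCHS: 500
  EVAL_START: 0                        # validation starts after this epoch
  EVAL_INTERVAL: 5                     # validate every N epochs
  AMP: false
  DDP: false

LOSS:
  NAME: OhemCrossEntropy               # placeholder; any loss registered in semseg/losses.py

OPTIMIZER:
  NAME: adamw                          # placeholder; any optimizer registered in semseg/optimizers.py
  LR: 6.0e-5
  WEIGHT_DECAY: 0.01

SCHEDULER:
  NAME: warmuppolylr                   # placeholder; any scheduler registered in semseg/schedulers.py
  POWER: 0.9
  WARMUP: 10
  WARMUP_RATIO: 0.1

# EVAL: ...                            # also read during training for the validation loader; see the Evaluation section below
```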
149 |
150 |
151 | ## Evaluation
152 | To evaluate MMSFormer models, please download the respective model weights ([**GoogleDrive**](https://drive.google.com/drive/folders/1OPr7PUrL7hkBXogmHFzHuTJweHuJmlP-?usp=sharing)) and save them in any folder you like.
153 |
154 |
163 |
164 | Then, update the `EVAL` section of the appropriate configuration file in `configs/` and run:
165 |
166 | ```bash
167 | cd path/to/MMSFormer
168 | conda activate mmsformer
169 |
170 | python -m tools.val_mm --cfg configs/mcubes_rgbadn.yaml
171 |
172 | python -m tools.val_mm --cfg configs/fmb_rgbt.yaml
173 |
174 | python -m tools.val_mm --cfg configs/pst_rgbt.yaml
175 | ```
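
The keys that `tools/val_mm.py` reads from the `EVAL` block are sketched below with placeholder values; keep the rest of the config (model, dataset, modalities) identical to the one used for training.

```yaml
EVAL:
  MODEL_PATH: path/to/downloaded/weights.pth   # checkpoint to evaluate
  IMAGE_SIZE: [512, 512]                       # placeholder evaluation resolution
  BATCH_SIZE: 4
  MSF:                                         # multi-scale and flip evaluation
    ENABLE: false
    SCALES: [0.5, 0.75, 1.0, 1.25, 1.5]
    FLIP: true
```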
176 |
177 | ## License
178 |
179 | This repository is released under the Apache-2.0 license. For commercial use, please contact the authors.
180 |
181 |
182 | ## Citations
183 |
184 | If you use the MMSFormer model, please cite the following work:
185 |
186 | - **MMSFormer** [[**arXiv**](https://arxiv.org/abs/2309.04001)]
187 | ```
188 | @ARTICLE{Reza2024MMSFormer,
189 | author={Reza, Md Kaykobad and Prater-Bennette, Ashley and Asif, M. Salman},
190 | journal={IEEE Open Journal of Signal Processing},
191 | title={MMSFormer: Multimodal Transformer for Material and Semantic Segmentation},
192 | year={2024},
193 | volume={},
194 | number={},
195 | pages={1-12},
196 | keywords={Image segmentation;Feature extraction;Transformers;Task analysis;Fuses;Semantic segmentation;Decoding;multimodal image segmentation;material segmentation;semantic segmentation;multimodal fusion;transformer},
197 | doi={10.1109/OJSP.2024.3389812}
198 | }
199 | ```
200 |
201 | ## Acknowledgements
202 | Our codebase builds on the following public GitHub repositories; many thanks to their authors:
203 | - [DELIVER](https://github.com/jamycheung/DELIVER)
204 | - [RGBX-semantic-segmentation](https://github.com/huaaaliu/RGBX_Semantic_Segmentation)
205 | - [Semantic-segmentation](https://github.com/sithu31296/semantic-segmentation)
206 |
207 | **Note:** This is a research-level repository and might contain issues/bugs. Please contact the authors with any queries.
208 |
--------------------------------------------------------------------------------
/semseg/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 | import time
5 | import os
6 | import sys
7 | import functools
8 | from pathlib import Path
9 | from torch.backends import cudnn
10 | from torch import nn, Tensor
11 | from torch.autograd import profiler
12 | from typing import Union
13 | from torch import distributed as dist
14 | from tabulate import tabulate
15 | from semseg import models
16 | import logging
17 | from fvcore.nn import flop_count_table, FlopCountAnalysis
18 | import datetime
19 |
20 | def fix_seeds(seed: int = 3407) -> None:
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed(seed)
23 | np.random.seed(seed)
24 | random.seed(seed)
25 |
26 | def setup_cudnn() -> None:
27 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
28 | cudnn.benchmark = True
29 | cudnn.deterministic = False
30 |
31 | def time_sync() -> float:
32 | if torch.cuda.is_available():
33 | torch.cuda.synchronize()
34 | return time.time()
35 |
36 | def get_model_size(model: Union[nn.Module, torch.jit.ScriptModule]):
37 | tmp_model_path = Path('temp.p')
38 | if isinstance(model, torch.jit.ScriptModule):
39 | torch.jit.save(model, tmp_model_path)
40 | else:
41 | torch.save(model.state_dict(), tmp_model_path)
42 | size = tmp_model_path.stat().st_size
43 | os.remove(tmp_model_path)
44 | return size / 1e6 # in MB
45 |
46 | @torch.no_grad()
47 | def test_model_latency(model: nn.Module, inputs: torch.Tensor, use_cuda: bool = False) -> float:
48 | with profiler.profile(use_cuda=use_cuda) as prof:
49 | _ = model(inputs)
50 | return prof.self_cpu_time_total / 1000 # ms
51 |
52 | def count_parameters(model: nn.Module) -> float:
53 | return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 # in M
54 |
55 | def setup_ddp():
56 | # print(os.environ.keys())
57 | if 'SLURM_PROCID' in os.environ and not 'RANK' in os.environ:
58 | # --- multi nodes
59 | world_size = int(os.environ['WORLD_SIZE'])
60 | rank = int(os.environ["SLURM_PROCID"])
61 | gpus_per_node = int(os.environ["SLURM_GPUS_ON_NODE"])
62 | gpu = rank - gpus_per_node * (rank // gpus_per_node)
63 | torch.cuda.set_device(gpu)
64 | dist.init_process_group(backend="nccl", world_size=world_size, rank=rank, timeout=datetime.timedelta(seconds=7200))
65 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
66 | rank = int(os.environ['RANK'])
67 | world_size = int(os.environ['WORLD_SIZE'])
68 | # gpu = int(os.environ(['LOCAL_RANK']))
69 | # ---
70 | gpu = int(os.environ['LOCAL_RANK'])
71 | torch.cuda.set_device(gpu)
72 | dist.init_process_group('nccl', init_method="env://",world_size=world_size, rank=rank, timeout=datetime.timedelta(seconds=7200))
73 | dist.barrier()
74 | else:
75 | gpu = 0
76 | return gpu
77 |
78 | def cleanup_ddp():
79 | if dist.is_initialized():
80 | dist.destroy_process_group()
81 |
82 | def reduce_tensor(tensor: Tensor) -> Tensor:
83 | rt = tensor.clone()
84 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
85 | rt /= dist.get_world_size()
86 | return rt
87 |
88 | @torch.no_grad()
89 | def throughput(dataloader, model: nn.Module, times: int = 30):
90 | model.eval()
91 | images, _ = next(iter(dataloader))
92 | images = images.cuda(non_blocking=True)
93 | B = images.shape[0]
94 | print(f"Throughput averaged with {times} times")
95 | start = time_sync()
96 | for _ in range(times):
97 | model(images)
98 | end = time_sync()
99 |
100 | print(f"Batch Size {B} throughput {times * B / (end - start)} images/s")
101 |
102 |
103 | def show_models():
104 | model_names = models.__all__
105 | model_variants = [list(eval(f'models.{name.lower()}_settings').keys()) for name in model_names]
106 |
107 | print(tabulate({'Model Names': model_names, 'Model Variants': model_variants}, headers='keys'))
108 |
109 |
110 | def timer(func):
111 | @functools.wraps(func)
112 | def wrapper_timer(*args, **kwargs):
113 | tic = time.perf_counter()
114 | value = func(*args, **kwargs)
115 | toc = time.perf_counter()
116 | elapsed_time = toc - tic
117 | print(f"Elapsed time: {elapsed_time * 1000:.2f}ms")
118 | return value
119 | return wrapper_timer
120 |
121 |
122 | # _default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO')
123 | # _default_level = logging.getLevelName(_default_level_name.upper())
124 |
125 | def get_logger(log_file=None):
126 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: - %(message)s',datefmt='%Y%m%d %H:%M:%S')
127 | logger = logging.getLogger()
128 | # logger.setLevel(logging.DEBUG)
129 | logger.setLevel(logging.INFO)
130 | del logger.handlers[:]
131 |
132 | if log_file:
133 | file_handler = logging.FileHandler(log_file, mode='w')
134 | # file_handler.setLevel(logging.DEBUG)
135 | file_handler.setLevel(logging.INFO)
136 | file_handler.setFormatter(formatter)
137 | logger.addHandler(file_handler)
138 |
139 | stream_handler = logging.StreamHandler()
140 | stream_handler.setFormatter(formatter)
141 | # stream_handler.setLevel(logging.DEBUG)
142 | stream_handler.setLevel(logging.INFO)
143 | logger.addHandler(stream_handler)
144 | return logger
145 |
146 |
147 | def cal_flops(model, modals, logger):
148 | x = [torch.zeros(1, 3, 512, 512) for _ in range(len(modals))]
149 | # x = [torch.zeros(2, 3, 512, 512) for _ in range(len(modals))] #--- PGSNet
150 | # x = [torch.zeros(1, 3, 512, 512) for _ in range(len(modals))] # --- for HRFuser
151 | if torch.distributed.is_initialized():
152 | if 'HR' in model.module.__class__.__name__:
153 | x = [torch.zeros(1, 3, 512, 512) for _ in range(len(modals))] # --- for HorNet
154 | else:
155 | if 'HR' in model.__class__.__name__:
156 | x = [torch.zeros(1, 3, 512, 512) for _ in range(len(modals))] # --- for HorNet
157 |
158 |     if torch.cuda.is_available():
159 | x = [xi.cuda() for xi in x]
160 | model = model.cuda()
161 | logger.info(flop_count_table(FlopCountAnalysis(model, x)))
162 |
163 | def print_iou(epoch, iou, miou, acc, macc, class_names):
164 | assert len(iou) == len(class_names)
165 | assert len(acc) == len(class_names)
166 | lines = ['\n%-8s\t%-8s\t%-8s' % ('Class', 'IoU', 'Acc')]
167 | for i in range(len(iou)):
168 | if class_names is None:
169 | cls = 'Class %d:' % (i+1)
170 | else:
171 | cls = '%d %s' % (i+1, class_names[i])
172 | lines.append('%-8s\t%.2f\t%.2f' % (cls, iou[i], acc[i]))
173 | lines.append('== %-8s\t%d\t%-8s\t%.2f\t%-8s\t%.2f' % ('Epoch:', epoch, 'mean_IoU', miou, 'mean_Acc',macc))
174 | line = "\n".join(lines)
175 | return line
176 |
177 |
178 | def nchw_to_nlc(x):
179 | """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.
180 |
181 | Args:
182 | x (Tensor): The input tensor of shape [N, C, H, W] before conversion.
183 |
184 | Returns:
185 | Tensor: The output tensor of shape [N, L, C] after conversion.
186 | """
187 | assert len(x.shape) == 4
188 | return x.flatten(2).transpose(1, 2).contiguous()
189 |
190 | def nlc_to_nchw(x, hw_shape):
191 | """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.
192 |
193 | Args:
194 | x (Tensor): The input tensor of shape [N, L, C] before conversion.
195 | hw_shape (Sequence[int]): The height and width of output feature map.
196 |
197 | Returns:
198 | Tensor: The output tensor of shape [N, C, H, W] after conversion.
199 | """
200 | H, W = hw_shape
201 | assert len(x.shape) == 3
202 | B, L, C = x.shape
203 | assert L == H * W, 'The seq_len does not match H, W'
204 | return x.transpose(1, 2).reshape(B, C, H, W).contiguous()
205 |
206 | def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs):
207 | """Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. Use the
208 | reshaped tensor as the input of `module`, and convert the output of
209 |     `module`, whose shape is [N, C, H, W],
210 |     back to [N, L, C].
211 |     Args:
212 |         module (Callable): A callable object that takes a tensor
213 |             with shape [N, C, H, W] as input.
214 |         x (Tensor): The input tensor of shape [N, L, C].
215 |         hw_shape (Sequence[int]): The height and width of the
216 | feature map with shape [N, C, H, W].
217 | contiguous (Bool): Whether to make the tensor contiguous
218 | after each shape transform.
219 | Returns:
220 | Tensor: The output tensor of shape [N, L, C].
221 | Example:
222 | >>> import torch
223 | >>> import torch.nn as nn
224 | >>> conv = nn.Conv2d(16, 16, 3, 1, 1)
225 | >>> feature_map = torch.rand(4, 25, 16)
226 | >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5))
227 | """
228 | H, W = hw_shape
229 | assert len(x.shape) == 3
230 | B, L, C = x.shape
231 | assert L == H * W, 'The seq_len doesn\'t match H, W'
232 | if not contiguous:
233 | x = x.transpose(1, 2).reshape(B, C, H, W)
234 | x = module(x, **kwargs)
235 | x = x.flatten(2).transpose(1, 2)
236 | else:
237 | x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()
238 | x = module(x, **kwargs)
239 | x = x.flatten(2).transpose(1, 2).contiguous()
240 | return x
241 |
--------------------------------------------------------------------------------
/tools/train_mm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import yaml
5 | import time
6 | import multiprocessing as mp
7 | from tabulate import tabulate
8 | from tqdm import tqdm
9 | from torch.utils.data import DataLoader
10 | from pathlib import Path
11 | from torch.utils.tensorboard import SummaryWriter
12 | from torch.cuda.amp import GradScaler, autocast
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 | from torch.utils.data import DistributedSampler, RandomSampler
15 | from torch import distributed as dist
16 | from semseg.models import *
17 | from semseg.datasets import *
18 | from semseg.augmentations_mm import get_train_augmentation, get_val_augmentation
19 | from semseg.losses import get_loss
20 | from semseg.schedulers import get_scheduler
21 | from semseg.optimizers import get_optimizer
22 | from semseg.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp, get_logger, cal_flops, print_iou
23 | from tools.val_mm import evaluate
24 | import wandb
25 | from semseg.metrics import Metrics
26 | import gc
27 |
28 |
29 | def main(cfg, save_dir):
30 | start = time.time()
31 | best_mIoU = 0.0
32 | best_epoch = 0
33 | num_workers = 2
34 | device = torch.device(cfg['DEVICE'])
35 | train_cfg, eval_cfg = cfg['TRAIN'], cfg['EVAL']
36 | dataset_cfg, model_cfg = cfg['DATASET'], cfg['MODEL']
37 | loss_cfg, optim_cfg, sched_cfg = cfg['LOSS'], cfg['OPTIMIZER'], cfg['SCHEDULER']
38 | epochs, lr = train_cfg['EPOCHS'], optim_cfg['LR']
39 | resume_path = cfg['MODEL']['RESUME']
40 | gpus = cfg['GPUs']
41 | use_wandb = cfg['USE_WANDB']
42 | wandb_name = cfg['WANDB_NAME']
43 | # gpus = int(os.environ['WORLD_SIZE'])
44 |
45 | traintransform = get_train_augmentation(train_cfg['IMAGE_SIZE'], seg_fill=dataset_cfg['IGNORE_LABEL'])
46 | valtransform = get_val_augmentation(eval_cfg['IMAGE_SIZE'])
47 |
48 | trainset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'train', traintransform, dataset_cfg['MODALS'])
49 | valset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'val', valtransform, dataset_cfg['MODALS'])
50 | class_names = trainset.CLASSES
51 |
52 | model = eval(model_cfg['NAME'])(model_cfg['BACKBONE'], trainset.n_classes, dataset_cfg['MODALS'])
53 | resume_checkpoint = None
54 | if os.path.isfile(resume_path):
55 | resume_checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
56 | msg = model.load_state_dict(resume_checkpoint['model_state_dict'])
57 | # print(msg)
58 | logger.info(msg)
59 | else:
60 | model.init_pretrained(model_cfg['PRETRAINED'])
61 |
62 | model = torch.nn.DataParallel(model, device_ids=cfg['GPU_IDs'])
63 | model = model.to(device)
64 |
65 | iters_per_epoch = len(trainset) // train_cfg['BATCH_SIZE']
66 | loss_fn = get_loss(loss_cfg['NAME'], trainset.ignore_label, None)
67 | start_epoch = 0
68 | optimizer = get_optimizer(model, optim_cfg['NAME'], lr, optim_cfg['WEIGHT_DECAY'])
69 | scheduler = get_scheduler(sched_cfg['NAME'], optimizer, int((epochs+1)*iters_per_epoch), sched_cfg['POWER'], iters_per_epoch * sched_cfg['WARMUP'], sched_cfg['WARMUP_RATIO'])
70 |
71 | if train_cfg['DDP']:
72 | sampler = DistributedSampler(trainset, dist.get_world_size(), dist.get_rank(), shuffle=True)
73 | sampler_val = None
74 |         model = DDP(model, device_ids=[gpu], output_device=0, find_unused_parameters=True)  # NOTE: gpu comes from setup_ddp(), which is commented out in __main__; enable it before using DDP
75 | else:
76 | sampler = RandomSampler(trainset)
77 | sampler_val = None
78 |
79 | if resume_checkpoint:
80 | start_epoch = resume_checkpoint['epoch'] - 1
81 | optimizer.load_state_dict(resume_checkpoint['optimizer_state_dict'])
82 | scheduler.load_state_dict(resume_checkpoint['scheduler_state_dict'])
83 | loss = resume_checkpoint['loss']
84 | best_mIoU = resume_checkpoint['best_miou']
85 | del resume_checkpoint
86 |
87 | trainloader = DataLoader(trainset, batch_size=train_cfg['BATCH_SIZE'], num_workers=num_workers, drop_last=True, pin_memory=False, sampler=sampler)
88 | valloader = DataLoader(valset, batch_size=eval_cfg['BATCH_SIZE'], num_workers=num_workers, pin_memory=False, sampler=sampler_val)
89 |
90 | scaler = GradScaler(enabled=train_cfg['AMP'])
91 | if (train_cfg['DDP'] and torch.distributed.get_rank() == 0) or (not train_cfg['DDP']):
92 | writer = SummaryWriter(str(save_dir))
93 | logger.info('================== model complexity =====================')
94 | cal_flops(model, dataset_cfg['MODALS'], logger)
95 | logger.info('================== model structure =====================')
96 | logger.info(model)
97 | logger.info('================== training config =====================')
98 | logger.info(cfg)
99 | logger.info('================== parameter count =====================')
100 | logger.info(sum(p.numel() for p in model.parameters() if p.requires_grad))
101 |
102 | for epoch in range(start_epoch, epochs):
103 | # Clean Memory
104 | torch.cuda.empty_cache()
105 | gc.collect()
106 |
107 | model.train()
108 | if train_cfg['DDP']: sampler.set_epoch(epoch)
109 |
110 | train_loss = 0.0
111 | lr = scheduler.get_lr()
112 | lr = sum(lr) / len(lr)
113 | pbar = tqdm(enumerate(trainloader), total=iters_per_epoch, desc=f"Epoch: [{epoch+1}/{epochs}] Iter: [{0}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss:.8f}")
114 | metrics = Metrics(trainset.n_classes, trainloader.dataset.ignore_label, device)
115 |
116 | for iter, (sample, lbl) in pbar:
117 | optimizer.zero_grad(set_to_none=True)
118 | sample = [x.to(device) for x in sample]
119 | lbl = lbl.to(device)
120 |
121 | with autocast(enabled=train_cfg['AMP']):
122 | logits = model(sample)
123 | loss = loss_fn(logits, lbl)
124 |
125 | metrics.update(logits.softmax(dim=1), lbl)
126 |
127 | scaler.scale(loss).backward()
128 | scaler.step(optimizer)
129 | scaler.update()
130 | scheduler.step()
131 | torch.cuda.synchronize()
132 |
133 | lr = scheduler.get_lr()
134 | lr = sum(lr) / len(lr)
135 | if lr <= 1e-8:
136 | lr = 1e-8 # minimum of lr
137 | train_loss += loss.item()
138 |
139 | # Clean Memory
140 | torch.cuda.empty_cache()
141 | gc.collect()
142 |
143 | pbar.set_description(f"Epoch: [{epoch+1}/{epochs}] Iter: [{iter+1}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss / (iter+1):.8f}")
144 |
145 | train_loss /= iter+1
146 | if (train_cfg['DDP'] and torch.distributed.get_rank() == 0) or (not train_cfg['DDP']):
147 | writer.add_scalar('train/loss', train_loss, epoch)
148 |
149 | ious, miou = metrics.compute_iou()
150 | acc, macc = metrics.compute_pixel_acc()
151 | f1, mf1 = metrics.compute_f1()
152 |
153 | # if use_wandb:
154 | train_log_data = {
155 | "Epoch": epoch+1,
156 | "Train Loss": train_loss,
157 | "Train mIoU": miou,
158 | "Train Pixel Acc": macc,
159 | "Train F1": mf1,
160 | }
161 |
162 | if ((epoch+1) % train_cfg['EVAL_INTERVAL'] == 0 and (epoch+1)>train_cfg['EVAL_START']) or (epoch+1) == epochs:
163 | if (train_cfg['DDP'] and torch.distributed.get_rank() == 0) or (not train_cfg['DDP']):
164 | acc, macc, f1, mf1, ious, miou, test_loss = evaluate(model, valloader, device, loss_fn=loss_fn)
165 | writer.add_scalar('val/mIoU', miou, epoch)
166 |
167 | # if use wandb
168 | log_data = {
169 | "Test Loss": test_loss,
170 | "Test mIoU": miou,
171 | "Test Pixel Acc": macc,
172 | "Test F1": mf1,
173 | }
174 | log_data.update(train_log_data)
175 | print(log_data)
176 | if use_wandb:
177 | wandb.log(log_data)
178 |
179 | if miou > best_mIoU:
180 | prev_best_ckp = save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}_checkpoint.pth"
181 | prev_best = save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}.pth"
182 | if os.path.isfile(prev_best): os.remove(prev_best)
183 | if os.path.isfile(prev_best_ckp): os.remove(prev_best_ckp)
184 | best_mIoU = miou
185 | best_epoch = epoch+1
186 | cur_best_ckp = save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}_checkpoint.pth"
187 | cur_best = save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}.pth"
188 | # torch.save(model.module.state_dict() if train_cfg['DDP'] else model.state_dict(), cur_best)
189 | torch.save(model.module.state_dict(), cur_best)
190 | # ---
191 | torch.save({'epoch': best_epoch,
192 | 'model_state_dict': model.module.state_dict() if train_cfg['DDP'] else model.state_dict(),
193 | 'optimizer_state_dict': optimizer.state_dict(),
194 | 'loss': train_loss,
195 | 'scheduler_state_dict': scheduler.state_dict(),
196 | 'best_miou': best_mIoU,
197 | }, cur_best_ckp)
198 | logger.info(print_iou(epoch, ious, miou, acc, macc, class_names))
199 | logger.info(f"Current epoch:{epoch} mIoU: {miou} Best mIoU: {best_mIoU}")
200 |
201 | if (train_cfg['DDP'] and torch.distributed.get_rank() == 0) or (not train_cfg['DDP']):
202 | writer.close()
203 | pbar.close()
204 | end = time.gmtime(time.time() - start)
205 |
206 | table = [
207 | ['Best mIoU', f"{best_mIoU:.2f}"],
208 | ['Total Training Time', time.strftime("%H:%M:%S", end)]
209 | ]
210 | logger.info(tabulate(table, numalign='right'))
211 |
212 |
213 | if __name__ == '__main__':
214 | parser = argparse.ArgumentParser()
215 | parser.add_argument('--cfg', type=str, default='configs/mcubes_rgbadn.yaml', help='Configuration file to use')
216 | args = parser.parse_args()
217 |
218 | with open(args.cfg) as f:
219 | cfg = yaml.load(f, Loader=yaml.SafeLoader)
220 |
221 | fix_seeds(3407)
222 | setup_cudnn()
223 | # gpu = setup_ddp()
224 | modals = ''.join([m[0] for m in cfg['DATASET']['MODALS']])
225 | model = cfg['MODEL']['BACKBONE']
226 | # exp_name = '_'.join([cfg['DATASET']['NAME'], model, modals])
227 | exp_name = cfg['WANDB_NAME']
228 | if cfg['USE_WANDB']:
229 |         wandb.init(project="ProjectName", entity="EntityName", name=exp_name)  # replace project and entity with your own W&B settings
230 |
231 | save_dir = Path(cfg['SAVE_DIR'], exp_name)
232 | if os.path.isfile(cfg['MODEL']['RESUME']):
233 | save_dir = Path(os.path.dirname(cfg['MODEL']['RESUME']))
234 | os.makedirs(save_dir, exist_ok=True)
235 | logger = get_logger(save_dir / 'train.log')
236 | main(cfg, save_dir)
237 | cleanup_ddp()
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [2023] [Jiaming Zhang]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/semseg/augmentations.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms.functional as TF
2 | import random
3 | import math
4 | import torch
5 | from torch import Tensor
6 | from typing import Tuple, List, Union, Optional
7 |
8 |
9 | class Compose:
10 | def __init__(self, transforms: list) -> None:
11 | self.transforms = transforms
12 |
13 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
14 | if mask.ndim == 2:
15 | assert img.shape[1:] == mask.shape
16 | else:
17 | assert img.shape[1:] == mask.shape[1:]
18 |
19 | for transform in self.transforms:
20 | img, mask = transform(img, mask)
21 |
22 | return img, mask
23 |
24 |
25 | class Normalize:
26 | def __init__(self, mean: list = (0.485, 0.456, 0.406), std: list = (0.229, 0.224, 0.225)):
27 | self.mean = mean
28 | self.std = std
29 |
30 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
31 | img = img.float()
32 | img /= 255
33 | img = TF.normalize(img, self.mean, self.std)
34 | return img, mask
35 |
36 |
37 | class ColorJitter:
38 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0) -> None:
39 | self.brightness = brightness
40 | self.contrast = contrast
41 | self.saturation = saturation
42 | self.hue = hue
43 |
44 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
45 | if self.brightness > 0:
46 | img = TF.adjust_brightness(img, self.brightness)
47 | if self.contrast > 0:
48 | img = TF.adjust_contrast(img, self.contrast)
49 | if self.saturation > 0:
50 | img = TF.adjust_saturation(img, self.saturation)
51 | if self.hue > 0:
52 | img = TF.adjust_hue(img, self.hue)
53 | return img, mask
54 |
55 |
56 | class AdjustGamma:
57 | def __init__(self, gamma: float, gain: float = 1) -> None:
58 | """
59 | Args:
60 |             gamma: Non-negative real number. A gamma larger than 1 makes the shadows darker, while a gamma smaller than 1 makes dark regions lighter.
61 | gain: constant multiplier
62 | """
63 | self.gamma = gamma
64 | self.gain = gain
65 |
66 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
67 | return TF.adjust_gamma(img, self.gamma, self.gain), mask
68 |
69 |
70 | class RandomAdjustSharpness:
71 | def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
72 | self.sharpness = sharpness_factor
73 | self.p = p
74 |
75 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
76 | if random.random() < self.p:
77 | img = TF.adjust_sharpness(img, self.sharpness)
78 | return img, mask
79 |
80 |
81 | class RandomAutoContrast:
82 | def __init__(self, p: float = 0.5) -> None:
83 | self.p = p
84 |
85 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
86 | if random.random() < self.p:
87 | img = TF.autocontrast(img)
88 | return img, mask
89 |
90 |
91 | class RandomGaussianBlur:
92 | def __init__(self, kernel_size: int = 3, p: float = 0.5) -> None:
93 | self.kernel_size = kernel_size
94 | self.p = p
95 |
96 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
97 | if random.random() < self.p:
98 | img = TF.gaussian_blur(img, self.kernel_size)
99 | return img, mask
100 |
101 |
102 | class RandomHorizontalFlip:
103 | def __init__(self, p: float = 0.5) -> None:
104 | self.p = p
105 |
106 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
107 | if random.random() < self.p:
108 | return TF.hflip(img), TF.hflip(mask)
109 | return img, mask
110 |
111 |
112 | class RandomVerticalFlip:
113 | def __init__(self, p: float = 0.5) -> None:
114 | self.p = p
115 |
116 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
117 | if random.random() < self.p:
118 | return TF.vflip(img), TF.vflip(mask)
119 | return img, mask
120 |
121 |
122 | class RandomGrayscale:
123 | def __init__(self, p: float = 0.5) -> None:
124 | self.p = p
125 |
126 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
127 | if random.random() < self.p:
128 | img = TF.rgb_to_grayscale(img, 3)
129 | return img, mask
130 |
131 |
132 | class Equalize:
133 | def __call__(self, image, label):
134 | return TF.equalize(image), label
135 |
136 |
137 | class Posterize:
138 | def __init__(self, bits=2):
139 | self.bits = bits # 0-8
140 |
141 | def __call__(self, image, label):
142 | return TF.posterize(image, self.bits), label
143 |
144 |
145 | class Affine:
146 | def __init__(self, angle=0, translate=[0, 0], scale=1.0, shear=[0, 0], seg_fill=0):
147 | self.angle = angle
148 | self.translate = translate
149 | self.scale = scale
150 | self.shear = shear
151 | self.seg_fill = seg_fill
152 |
153 | def __call__(self, img, label):
154 | return TF.affine(img, self.angle, self.translate, self.scale, self.shear, TF.InterpolationMode.BILINEAR, 0), TF.affine(label, self.angle, self.translate, self.scale, self.shear, TF.InterpolationMode.NEAREST, self.seg_fill)
155 |
156 |
157 | class RandomRotation:
158 | def __init__(self, degrees: float = 10.0, p: float = 0.2, seg_fill: int = 0, expand: bool = False) -> None:
159 |         """Rotate the image by a random angle between -degrees and degrees with probability p
160 | 
161 |         Args:
162 |             p: probability of applying the rotation
163 |             degrees: maximum rotation angle in degrees, counter-clockwise.
164 | expand: Optional expansion flag.
165 | If true, expands the output image to make it large enough to hold the entire rotated image.
166 | If false or omitted, make the output image the same size as the input image.
167 | Note that the expand flag assumes rotation around the center and no translation.
168 | """
169 | self.p = p
170 | self.angle = degrees
171 | self.expand = expand
172 | self.seg_fill = seg_fill
173 |
174 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
175 | random_angle = random.random() * 2 * self.angle - self.angle
176 | if random.random() < self.p:
177 | img = TF.rotate(img, random_angle, TF.InterpolationMode.BILINEAR, self.expand, fill=0)
178 | mask = TF.rotate(mask, random_angle, TF.InterpolationMode.NEAREST, self.expand, fill=self.seg_fill)
179 | return img, mask
180 |
181 |
182 | class CenterCrop:
183 | def __init__(self, size: Union[int, List[int], Tuple[int]]) -> None:
184 | """Crops the image at the center
185 |
186 | Args:
187 | output_size: height and width of the crop box. If int, this size is used for both directions.
188 | """
189 | self.size = (size, size) if isinstance(size, int) else size
190 |
191 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
192 | return TF.center_crop(img, self.size), TF.center_crop(mask, self.size)
193 |
194 |
195 | class RandomCrop:
196 | def __init__(self, size: Union[int, List[int], Tuple[int]], p: float = 0.5) -> None:
197 | """Randomly Crops the image.
198 |
199 | Args:
200 | output_size: height and width of the crop box. If int, this size is used for both directions.
201 | """
202 | self.size = (size, size) if isinstance(size, int) else size
203 | self.p = p
204 |
205 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
206 | H, W = img.shape[1:]
207 | tH, tW = self.size
208 |
209 | if random.random() < self.p:
210 | margin_h = max(H - tH, 0)
211 | margin_w = max(W - tW, 0)
212 | y1 = random.randint(0, margin_h+1)
213 | x1 = random.randint(0, margin_w+1)
214 | y2 = y1 + tH
215 | x2 = x1 + tW
216 | img = img[:, y1:y2, x1:x2]
217 | mask = mask[:, y1:y2, x1:x2]
218 | return img, mask
219 |
220 |
221 | class Pad:
222 | def __init__(self, size: Union[List[int], Tuple[int], int], seg_fill: int = 0) -> None:
223 | """Pad the given image on all sides with the given "pad" value.
224 | Args:
225 | size: expected output image size (h, w)
226 | fill: Pixel fill value for constant fill. Default is 0. This value is only used when the padding mode is constant.
227 | """
228 | self.size = size
229 | self.seg_fill = seg_fill
230 |
231 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
232 | padding = (0, 0, self.size[1]-img.shape[2], self.size[0]-img.shape[1])
233 | return TF.pad(img, padding), TF.pad(mask, padding, self.seg_fill)
234 |
235 |
236 | class ResizePad:
237 | def __init__(self, size: Union[int, Tuple[int], List[int]], seg_fill: int = 0) -> None:
238 | """Resize the input image to the given size.
239 | Args:
240 | size: Desired output size.
241 | If size is a sequence, the output size will be matched to this.
242 | If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio.
243 | """
244 | self.size = size
245 | self.seg_fill = seg_fill
246 |
247 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
248 | H, W = img.shape[1:]
249 | tH, tW = self.size
250 |
251 | # scale the image
252 | scale_factor = min(tH/H, tW/W) if W > H else max(tH/H, tW/W)
253 | # nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5)
254 | nH, nW = round(H*scale_factor), round(W*scale_factor)
255 | img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR)
256 | mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST)
257 |
258 | # pad the image
259 | padding = [0, 0, tW - nW, tH - nH]
260 | img = TF.pad(img, padding, fill=0)
261 | mask = TF.pad(mask, padding, fill=self.seg_fill)
262 | return img, mask
263 |
264 |
265 | class Resize:
266 | def __init__(self, size: Union[int, Tuple[int], List[int]]) -> None:
267 | """Resize the input image to the given size.
268 | Args:
269 | size: Desired output size.
270 | If size is a sequence, the output size will be matched to this.
271 | If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio.
272 | """
273 | self.size = size
274 |
275 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
276 | H, W = img.shape[1:]
277 |
278 | # scale the image
279 | scale_factor = self.size[0] / min(H, W)
280 | nH, nW = round(H*scale_factor), round(W*scale_factor)
281 | img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR)
282 | mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST)
283 |
284 | # make the image divisible by stride
285 | alignH, alignW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32
286 | img = TF.resize(img, (alignH, alignW), TF.InterpolationMode.BILINEAR)
287 | mask = TF.resize(mask, (alignH, alignW), TF.InterpolationMode.NEAREST)
288 | return img, mask
289 |
290 |
291 | class RandomResizedCrop:
292 | def __init__(self, size: Union[int, Tuple[int], List[int]], scale: Tuple[float, float] = (0.5, 2.0), seg_fill: int = 0) -> None:
293 |         """Randomly scale the image within the given scale range, then crop (and pad if needed) to the given size.
294 |         """
295 | self.size = size
296 | self.scale = scale
297 | self.seg_fill = seg_fill
298 |
299 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
300 | H, W = img.shape[1:]
301 | tH, tW = self.size
302 |
303 | # get the scale
304 | ratio = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0]
305 | # ratio = random.uniform(min(self.scale), max(self.scale))
306 | scale = int(tH*ratio), int(tW*4*ratio)
307 |
308 | # scale the image
309 | scale_factor = min(max(scale)/max(H, W), min(scale)/min(H, W))
310 | nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5)
311 | # nH, nW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32
312 | img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR)
313 | mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST)
314 |
315 | # random crop
316 | margin_h = max(img.shape[1] - tH, 0)
317 | margin_w = max(img.shape[2] - tW, 0)
318 | y1 = random.randint(0, margin_h+1)
319 | x1 = random.randint(0, margin_w+1)
320 | y2 = y1 + tH
321 | x2 = x1 + tW
322 | img = img[:, y1:y2, x1:x2]
323 | mask = mask[:, y1:y2, x1:x2]
324 |
325 | # pad the image
326 | if img.shape[1:] != self.size:
327 | padding = [0, 0, tW - img.shape[2], tH - img.shape[1]]
328 | img = TF.pad(img, padding, fill=0)
329 | mask = TF.pad(mask, padding, fill=self.seg_fill)
330 | return img, mask
331 |
332 |
333 |
334 | def get_train_augmentation(size: Union[int, Tuple[int], List[int]], seg_fill: int = 0):
335 | return Compose([
336 | # ColorJitter(brightness=0.0, contrast=0.5, saturation=0.5, hue=0.5),
337 | # RandomAdjustSharpness(sharpness_factor=0.1, p=0.5),
338 | # RandomAutoContrast(p=0.2),
339 | RandomHorizontalFlip(p=0.5),
340 | # RandomVerticalFlip(p=0.5),
341 | # RandomGaussianBlur((3, 3), p=0.5),
342 | # RandomGrayscale(p=0.5),
343 | # RandomRotation(degrees=10, p=0.3, seg_fill=seg_fill),
344 | RandomResizedCrop(size, scale=(0.5, 2.0), seg_fill=seg_fill),
345 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
346 | ])
347 |
348 | def get_val_augmentation(size: Union[int, Tuple[int], List[int]]):
349 | return Compose([
350 | Resize(size),
351 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
352 | ])
353 |
354 |
355 | if __name__ == '__main__':
356 | h = 230
357 | w = 420
358 | img = torch.randn(3, h, w)
359 | mask = torch.randn(1, h, w)
360 | aug = Compose([
361 | RandomResizedCrop((512, 512)),
362 | # RandomCrop((512, 512), p=1.0),
363 | # Pad((512, 512))
364 | ])
365 | img, mask = aug(img, mask)
366 | print(img.shape, mask.shape)
--------------------------------------------------------------------------------
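Note: the transforms in semseg/augmentations.py above operate on a single (img, mask) pair of CHW tensors, and get_train_augmentation / get_val_augmentation assemble the default pipelines. A minimal usage sketch with dummy tensors (shapes and class count are illustrative only, and the repo root is assumed to be on PYTHONPATH):

    import torch
    from semseg.augmentations import get_train_augmentation, get_val_augmentation

    train_aug = get_train_augmentation((512, 512), seg_fill=255)  # flip + random resized crop + normalize
    val_aug = get_val_augmentation((512, 512))                    # resize (aligned to multiples of 32) + normalize

    img = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)  # dummy RGB image in [0, 255]
    mask = torch.randint(0, 20, (1, 480, 640)).float()             # dummy label map

    img_t, mask_t = train_aug(img, mask)   # cropped/padded to 512 x 512
    img_v, mask_v = val_aug(img, mask)     # resized so both sides are multiples of 32
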
/semseg/augmentations_mm.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms.functional as TF
2 | import random
3 | import math
4 | import torch
5 | from torch import Tensor
6 | from typing import Tuple, List, Union, Optional
7 |
8 |
9 | class Compose:
10 | def __init__(self, transforms: list) -> None:
11 | self.transforms = transforms
12 |
13 |     def __call__(self, sample: dict) -> dict:
14 | img, mask = sample['img'], sample['mask']
15 | if mask.ndim == 2:
16 | assert img.shape[1:] == mask.shape
17 | else:
18 | assert img.shape[1:] == mask.shape[1:]
19 |
20 | for transform in self.transforms:
21 | sample = transform(sample)
22 |
23 | return sample
24 |
25 |
26 | class Normalize:
27 | def __init__(self, mean: list = (0.485, 0.456, 0.406), std: list = (0.229, 0.224, 0.225)):
28 | self.mean = mean
29 | self.std = std
30 |
31 |     def __call__(self, sample: dict) -> dict:
32 | for k, v in sample.items():
33 | if k == 'mask':
34 | continue
35 | elif k == 'img':
36 | sample[k] = sample[k].float()
37 | sample[k] /= 255
38 | sample[k] = TF.normalize(sample[k], self.mean, self.std)
39 | else:
40 | sample[k] = sample[k].float()
41 | sample[k] /= 255
42 |
43 | return sample
44 |
45 |
46 | class RandomColorJitter:
47 | def __init__(self, p=0.5) -> None:
48 | self.p = p
49 |
50 |     def __call__(self, sample: dict) -> dict:
51 | if random.random() < self.p:
52 | self.brightness = random.uniform(0.5, 1.5)
53 | sample['img'] = TF.adjust_brightness(sample['img'], self.brightness)
54 | self.contrast = random.uniform(0.5, 1.5)
55 | sample['img'] = TF.adjust_contrast(sample['img'], self.contrast)
56 | self.saturation = random.uniform(0.5, 1.5)
57 | sample['img'] = TF.adjust_saturation(sample['img'], self.saturation)
58 | return sample
59 |
60 |
61 | class AdjustGamma:
62 | def __init__(self, gamma: float, gain: float = 1) -> None:
63 | """
64 | Args:
65 |             gamma: Non-negative real number. A gamma larger than 1 makes the shadows darker, while a gamma smaller than 1 makes dark regions lighter.
66 | gain: constant multiplier
67 | """
68 | self.gamma = gamma
69 | self.gain = gain
70 |
71 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
72 | return TF.adjust_gamma(img, self.gamma, self.gain), mask
73 |
74 |
75 | class RandomAdjustSharpness:
76 | def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
77 | self.sharpness = sharpness_factor
78 | self.p = p
79 |
80 |     def __call__(self, sample: dict) -> dict:
81 | if random.random() < self.p:
82 | sample['img'] = TF.adjust_sharpness(sample['img'], self.sharpness)
83 | return sample
84 |
85 |
86 | class RandomAutoContrast:
87 | def __init__(self, p: float = 0.5) -> None:
88 | self.p = p
89 |
90 |     def __call__(self, sample: dict) -> dict:
91 | if random.random() < self.p:
92 | sample['img'] = TF.autocontrast(sample['img'])
93 | return sample
94 |
95 |
96 | class RandomGaussianBlur:
97 | def __init__(self, kernel_size: int = 3, p: float = 0.5) -> None:
98 | self.kernel_size = kernel_size
99 | self.p = p
100 |
101 |     def __call__(self, sample: dict) -> dict:
102 | if random.random() < self.p:
103 | sample['img'] = TF.gaussian_blur(sample['img'], self.kernel_size)
104 | # img = TF.gaussian_blur(img, self.kernel_size)
105 | return sample
106 |
107 |
108 | class RandomHorizontalFlip:
109 | def __init__(self, p: float = 0.5) -> None:
110 | self.p = p
111 |
112 |     def __call__(self, sample: dict) -> dict:
113 | if random.random() < self.p:
114 | for k, v in sample.items():
115 | sample[k] = TF.hflip(v)
116 | return sample
117 | return sample
118 |
119 |
120 | class RandomVerticalFlip:
121 | def __init__(self, p: float = 0.5) -> None:
122 | self.p = p
123 |
124 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
125 | if random.random() < self.p:
126 | return TF.vflip(img), TF.vflip(mask)
127 | return img, mask
128 |
129 |
130 | class RandomGrayscale:
131 | def __init__(self, p: float = 0.5) -> None:
132 | self.p = p
133 |
134 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
135 | if random.random() < self.p:
136 | img = TF.rgb_to_grayscale(img, 3)
137 | return img, mask
138 |
139 |
140 | class Equalize:
141 | def __call__(self, image, label):
142 | return TF.equalize(image), label
143 |
144 |
145 | class Posterize:
146 | def __init__(self, bits=2):
147 | self.bits = bits # 0-8
148 |
149 | def __call__(self, image, label):
150 | return TF.posterize(image, self.bits), label
151 |
152 |
153 | class Affine:
154 | def __init__(self, angle=0, translate=[0, 0], scale=1.0, shear=[0, 0], seg_fill=0):
155 | self.angle = angle
156 | self.translate = translate
157 | self.scale = scale
158 | self.shear = shear
159 | self.seg_fill = seg_fill
160 |
161 | def __call__(self, img, label):
162 | return TF.affine(img, self.angle, self.translate, self.scale, self.shear, TF.InterpolationMode.BILINEAR, 0), TF.affine(label, self.angle, self.translate, self.scale, self.shear, TF.InterpolationMode.NEAREST, self.seg_fill)
163 |
164 |
165 | class RandomRotation:
166 | def __init__(self, degrees: float = 10.0, p: float = 0.2, seg_fill: int = 0, expand: bool = False) -> None:
167 |         """Rotate the image by a random angle between -degrees and degrees with probability p
168 | 
169 |         Args:
170 |             p: probability of applying the rotation
171 |             degrees: maximum rotation angle in degrees, counter-clockwise.
172 | expand: Optional expansion flag.
173 | If true, expands the output image to make it large enough to hold the entire rotated image.
174 | If false or omitted, make the output image the same size as the input image.
175 | Note that the expand flag assumes rotation around the center and no translation.
176 | """
177 | self.p = p
178 | self.angle = degrees
179 | self.expand = expand
180 | self.seg_fill = seg_fill
181 |
182 |     def __call__(self, sample: dict) -> dict:
183 | random_angle = random.random() * 2 * self.angle - self.angle
184 | if random.random() < self.p:
185 | for k, v in sample.items():
186 | if k == 'mask':
187 | sample[k] = TF.rotate(v, random_angle, TF.InterpolationMode.NEAREST, self.expand, fill=self.seg_fill)
188 | else:
189 | sample[k] = TF.rotate(v, random_angle, TF.InterpolationMode.BILINEAR, self.expand, fill=0)
190 | # img = TF.rotate(img, random_angle, TF.InterpolationMode.BILINEAR, self.expand, fill=0)
191 | # mask = TF.rotate(mask, random_angle, TF.InterpolationMode.NEAREST, self.expand, fill=self.seg_fill)
192 | return sample
193 |
194 |
195 | class CenterCrop:
196 | def __init__(self, size: Union[int, List[int], Tuple[int]]) -> None:
197 | """Crops the image at the center
198 |
199 | Args:
200 | output_size: height and width of the crop box. If int, this size is used for both directions.
201 | """
202 | self.size = (size, size) if isinstance(size, int) else size
203 |
204 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
205 | return TF.center_crop(img, self.size), TF.center_crop(mask, self.size)
206 |
207 |
208 | class RandomCrop:
209 | def __init__(self, size: Union[int, List[int], Tuple[int]], p: float = 0.5) -> None:
210 | """Randomly Crops the image.
211 |
212 | Args:
213 | output_size: height and width of the crop box. If int, this size is used for both directions.
214 | """
215 | self.size = (size, size) if isinstance(size, int) else size
216 | self.p = p
217 |
218 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
219 | H, W = img.shape[1:]
220 | tH, tW = self.size
221 |
222 | if random.random() < self.p:
223 | margin_h = max(H - tH, 0)
224 | margin_w = max(W - tW, 0)
225 | y1 = random.randint(0, margin_h+1)
226 | x1 = random.randint(0, margin_w+1)
227 | y2 = y1 + tH
228 | x2 = x1 + tW
229 | img = img[:, y1:y2, x1:x2]
230 | mask = mask[:, y1:y2, x1:x2]
231 | return img, mask
232 |
233 |
234 | class Pad:
235 | def __init__(self, size: Union[List[int], Tuple[int], int], seg_fill: int = 0) -> None:
236 | """Pad the given image on all sides with the given "pad" value.
237 | Args:
238 | size: expected output image size (h, w)
239 | fill: Pixel fill value for constant fill. Default is 0. This value is only used when the padding mode is constant.
240 | """
241 | self.size = size
242 | self.seg_fill = seg_fill
243 |
244 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
245 | padding = (0, 0, self.size[1]-img.shape[2], self.size[0]-img.shape[1])
246 | return TF.pad(img, padding), TF.pad(mask, padding, self.seg_fill)
247 |
248 |
249 | class ResizePad:
250 | def __init__(self, size: Union[int, Tuple[int], List[int]], seg_fill: int = 0) -> None:
251 | """Resize the input image to the given size.
252 | Args:
253 | size: Desired output size.
254 | If size is a sequence, the output size will be matched to this.
255 | If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio.
256 | """
257 | self.size = size
258 | self.seg_fill = seg_fill
259 |
260 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
261 | H, W = img.shape[1:]
262 | tH, tW = self.size
263 |
264 | # scale the image
265 | scale_factor = min(tH/H, tW/W) if W > H else max(tH/H, tW/W)
266 | # nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5)
267 | nH, nW = round(H*scale_factor), round(W*scale_factor)
268 | img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR)
269 | mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST)
270 |
271 | # pad the image
272 | padding = [0, 0, tW - nW, tH - nH]
273 | img = TF.pad(img, padding, fill=0)
274 | mask = TF.pad(mask, padding, fill=self.seg_fill)
275 | return img, mask
276 |
277 |
278 | class Resize:
279 | def __init__(self, size: Union[int, Tuple[int], List[int]]) -> None:
280 | """Resize the input image to the given size.
281 | Args:
282 | size: Desired output size.
283 | If size is a sequence, the output size will be matched to this.
284 | If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio.
285 | """
286 | self.size = size
287 |
288 |     def __call__(self, sample: dict) -> dict:
289 | H, W = sample['img'].shape[1:]
290 |
291 | # scale the image
292 | scale_factor = self.size[0] / min(H, W)
293 | nH, nW = round(H*scale_factor), round(W*scale_factor)
294 | for k, v in sample.items():
295 | if k == 'mask':
296 | sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.NEAREST)
297 | else:
298 | sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.BILINEAR)
299 | # img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR)
300 | # mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST)
301 |
302 | # make the image divisible by stride
303 | alignH, alignW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32
304 |
305 | for k, v in sample.items():
306 | if k == 'mask':
307 | sample[k] = TF.resize(v, (alignH, alignW), TF.InterpolationMode.NEAREST)
308 | else:
309 | sample[k] = TF.resize(v, (alignH, alignW), TF.InterpolationMode.BILINEAR)
310 | # img = TF.resize(img, (alignH, alignW), TF.InterpolationMode.BILINEAR)
311 | # mask = TF.resize(mask, (alignH, alignW), TF.InterpolationMode.NEAREST)
312 | return sample
313 |
314 |
315 | class RandomResizedCrop:
316 | def __init__(self, size: Union[int, Tuple[int], List[int]], scale: Tuple[float, float] = (0.5, 2.0), seg_fill: int = 0) -> None:
317 |         """Randomly scale the image within the given scale range, then crop (and pad if needed) to the given size.
318 |         """
319 | self.size = size
320 | self.scale = scale
321 | self.seg_fill = seg_fill
322 |
323 |     def __call__(self, sample: dict) -> dict:
324 | # img, mask = sample['img'], sample['mask']
325 | H, W = sample['img'].shape[1:]
326 | tH, tW = self.size
327 |
328 | # get the scale
329 | ratio = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0]
330 | # ratio = random.uniform(min(self.scale), max(self.scale))
331 | scale = int(tH*ratio), int(tW*4*ratio)
332 | # scale the image
333 | scale_factor = min(max(scale)/max(H, W), min(scale)/min(H, W))
334 | nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5)
335 | # nH, nW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32
336 | for k, v in sample.items():
337 | if k == 'mask':
338 | sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.NEAREST)
339 | else:
340 | sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.BILINEAR)
341 |
342 | # random crop
343 | margin_h = max(sample['img'].shape[1] - tH, 0)
344 | margin_w = max(sample['img'].shape[2] - tW, 0)
345 | y1 = random.randint(0, margin_h+1)
346 | x1 = random.randint(0, margin_w+1)
347 | y2 = y1 + tH
348 | x2 = x1 + tW
349 | for k, v in sample.items():
350 | sample[k] = v[:, y1:y2, x1:x2]
351 |
352 | # pad the image
353 | if sample['img'].shape[1:] != self.size:
354 | padding = [0, 0, tW - sample['img'].shape[2], tH - sample['img'].shape[1]]
355 | for k, v in sample.items():
356 | if k == 'mask':
357 | sample[k] = TF.pad(v, padding, fill=self.seg_fill)
358 | else:
359 | sample[k] = TF.pad(v, padding, fill=0)
360 |
361 | return sample
362 |
363 |
364 |
365 | def get_train_augmentation(size: Union[int, Tuple[int], List[int]], seg_fill: int = 0):
366 | return Compose([
367 | RandomColorJitter(p=0.2), #
368 | RandomHorizontalFlip(p=0.5), #
369 | RandomGaussianBlur((3, 3), p=0.2), #
370 | RandomResizedCrop(size, scale=(0.5, 2.0), seg_fill=seg_fill), #
371 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
372 | ])
373 |
374 | def get_val_augmentation(size: Union[int, Tuple[int], List[int]]):
375 | return Compose([
376 | Resize(size),
377 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
378 | ])
379 |
380 |
381 | if __name__ == '__main__':
382 | h = 230
383 | w = 420
384 | sample = {}
385 | sample['img'] = torch.randn(3, h, w)
386 | sample['depth'] = torch.randn(3, h, w)
387 | sample['lidar'] = torch.randn(3, h, w)
388 | sample['event'] = torch.randn(3, h, w)
389 | sample['mask'] = torch.randn(1, h, w)
390 | aug = Compose([
391 | RandomHorizontalFlip(p=0.5),
392 | RandomResizedCrop((512, 512)),
393 | Resize((224, 224)),
394 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
395 | ])
396 | sample = aug(sample)
397 | for k, v in sample.items():
398 | print(k, v.shape)
--------------------------------------------------------------------------------
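Note: unlike semseg/augmentations.py, the multimodal pipeline in semseg/augmentations_mm.py above works on a dict sample: 'mask' holds the label, and every other key ('img' plus any extra modality) is treated as an image tensor, so flips, crops, and resizes stay aligned across modalities; Normalize applies ImageNet mean/std only to 'img' and rescales the remaining modalities to [0, 1]. A minimal sketch of this contract with dummy data (the key names besides 'img' and 'mask', and all shapes, are illustrative only):

    import torch
    from semseg.augmentations_mm import get_train_augmentation

    train_aug = get_train_augmentation((512, 512), seg_fill=255)

    # A dataset __getitem__ would typically assemble a dict like this per example.
    sample = {
        'img':   torch.randint(0, 256, (3, 600, 800), dtype=torch.uint8),  # RGB (mean/std normalized)
        'depth': torch.randint(0, 256, (3, 600, 800), dtype=torch.uint8),  # extra modality (only divided by 255)
        'mask':  torch.randint(0, 20, (1, 600, 800)).float(),
    }
    sample = train_aug(sample)                 # every key is flipped/cropped consistently
    inputs = [sample['img'], sample['depth']]  # list of modality tensors for the model
    label = sample['mask'].squeeze(0).long()   # H x W class indices
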
/semseg/models/backbones/mmsformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import functional as F
4 | from semseg.models.layers import DropPath
5 | import torch.nn.init as init
6 |
7 |
8 | class ChannelAttentionBlock(nn.Module):
9 | def __init__(self, channel, reduction=16):
10 | super(ChannelAttentionBlock, self).__init__()
11 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
12 | self.fc = nn.Sequential(
13 | nn.Linear(channel, channel // reduction, bias=False),
14 | nn.ReLU(inplace=True),
15 | nn.Linear(channel // reduction, channel, bias=False),
16 | nn.Sigmoid()
17 | )
18 |
19 | # Initialize linear layers with Kaiming initialization
20 | for m in self.fc:
21 | if isinstance(m, nn.Linear):
22 | init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
23 |
24 | def forward(self, x, H, W):
25 | B, _, C = x.shape
26 | x = x.transpose(1, 2).view(B, C, H, W)
27 | b, c, _, _ = x.size()
28 | y = self.avg_pool(x).view(b, c)
29 | y = self.fc(y).view(b, c, 1, 1)
30 | return (x * y.expand_as(x)).flatten(2).transpose(1, 2)
31 |
32 |
33 | class DWConv(nn.Module):
34 | def __init__(self, dim):
35 | super().__init__()
36 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
37 |
38 | def forward(self, x: Tensor, H, W) -> Tensor:
39 | B, _, C = x.shape
40 | x = x.transpose(1, 2).view(B, C, H, W)
41 | x = self.dwconv(x)
42 | return x.flatten(2).transpose(1, 2)
43 |
44 |
45 | class CustomDWConv(nn.Module):
46 | def __init__(self, dim, kernel):
47 | super().__init__()
48 | self.dwconv = nn.Conv2d(dim, dim, kernel, 1, padding='same', groups=dim)
49 |
50 | # Apply Kaiming initialization with fan-in to the dwconv layer
51 | init.kaiming_normal_(self.dwconv.weight, mode='fan_in', nonlinearity='relu')
52 |
53 | def forward(self, x: Tensor, H, W) -> Tensor:
54 | B, _, C = x.shape
55 | x = x.transpose(1, 2).view(B, C, H, W)
56 | x = self.dwconv(x)
57 | return x.flatten(2).transpose(1, 2)
58 |
59 |
60 | class CustomPWConv(nn.Module):
61 | def __init__(self, dim):
62 | super().__init__()
63 | self.pwconv = nn.Conv2d(dim, dim, 1)
64 | self.bn = nn.BatchNorm2d(dim)
65 |
66 | # Initialize pwconv layer with Kaiming initialization
67 | init.kaiming_normal_(self.pwconv.weight, mode='fan_in', nonlinearity='relu')
68 |
69 | def forward(self, x: Tensor, H, W) -> Tensor:
70 | B, _, C = x.shape
71 | x = x.transpose(1, 2).view(B, C, H, W)
72 | x = self.bn(self.pwconv(x))
73 | return x.flatten(2).transpose(1, 2)
74 |
75 |
76 | class MLP(nn.Module):
77 | def __init__(self, c1, c2):
78 | super().__init__()
79 | self.fc1 = nn.Linear(c1, c2)
80 | self.dwconv = DWConv(c2)
81 | self.fc2 = nn.Linear(c2, c1)
82 |
83 | def forward(self, x: Tensor, H, W) -> Tensor:
84 | return self.fc2(F.gelu(self.dwconv(self.fc1(x), H, W)))
85 |
86 |
87 | class MixFFN(nn.Module):
88 | def __init__(self, c1, c2):
89 | super().__init__()
90 | self.fc1 = nn.Linear(c1, c2)
91 | self.pwconv1 = CustomPWConv(c2)
92 | self.dwconv3 = CustomDWConv(c2, 3)
93 | self.dwconv5 = CustomDWConv(c2, 5)
94 | self.dwconv7 = CustomDWConv(c2, 7)
95 | self.pwconv2 = CustomPWConv(c2)
96 | self.fc2 = nn.Linear(c2, c1)
97 |
98 | # Initialize fc1 layer with Kaiming initialization
99 | init.kaiming_normal_(self.fc1.weight, mode='fan_in', nonlinearity='relu')
100 | init.kaiming_normal_(self.fc2.weight, mode='fan_in', nonlinearity='relu')
101 |
102 | def forward(self, x: Tensor, H, W) -> Tensor:
103 | x = self.fc1(x)
104 | x = self.pwconv1(x, H, W)
105 | x1 = self.dwconv3(x, H, W)
106 | x2 = self.dwconv5(x, H, W)
107 | x3 = self.dwconv7(x, H, W)
108 | return self.fc2(F.gelu(self.pwconv2(x + x1 + x2 + x3, H, W)))
109 |
110 |
111 | class FusionBlock(nn.Module):
112 | def __init__(self, channels, reduction=16, num_modals=2):
113 | super(FusionBlock, self).__init__()
114 | self.channels = channels
115 | self.reduction = reduction
116 | self.num_modals = num_modals
117 |
118 | self.liner_fusion_layers = nn.ModuleList([
119 | nn.Linear(self.channels[0]*self.num_modals, self.channels[0]),
120 | nn.Linear(self.channels[1]*self.num_modals, self.channels[1]),
121 | nn.Linear(self.channels[2]*self.num_modals, self.channels[2]),
122 | nn.Linear(self.channels[3]*self.num_modals, self.channels[3]),
123 | ])
124 |
125 | self.mix_ffn = nn.ModuleList([
126 | MixFFN(self.channels[0], self.channels[0]),
127 | MixFFN(self.channels[1], self.channels[1]),
128 | MixFFN(self.channels[2], self.channels[2]),
129 | MixFFN(self.channels[3], self.channels[3]),
130 | ])
131 |
132 | self.channel_attns = nn.ModuleList([
133 | ChannelAttentionBlock(self.channels[0]),
134 | ChannelAttentionBlock(self.channels[1]),
135 | ChannelAttentionBlock(self.channels[2]),
136 | ChannelAttentionBlock(self.channels[3]),
137 | ])
138 |
139 | # Initialize linear fusion layers with Kaiming initialization
140 | for linear_layer in self.liner_fusion_layers:
141 | init.kaiming_normal_(linear_layer.weight, mode='fan_in', nonlinearity='relu')
142 |
143 | def forward(self, x, layer_idx):
144 | B, C, H, W = x[0].shape
145 | x = torch.cat(x, dim=1)
146 | x = x.flatten(2).transpose(1, 2)
147 | x_sum = self.liner_fusion_layers[layer_idx](x)
148 | x_sum = self.mix_ffn[layer_idx](x_sum, H, W) + self.channel_attns[layer_idx](x_sum, H, W)
149 | return x_sum.reshape(B, H, W, -1).permute(0, 3, 1, 2)
150 |
151 |
152 | class Attention(nn.Module):
153 | def __init__(self, dim, head, sr_ratio):
154 | super().__init__()
155 | self.head = head
156 | self.sr_ratio = sr_ratio
157 | self.scale = (dim // head) ** -0.5
158 | self.q = nn.Linear(dim, dim)
159 | self.kv = nn.Linear(dim, dim*2)
160 | self.proj = nn.Linear(dim, dim)
161 |
162 | if sr_ratio > 1:
163 | self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio)
164 | self.norm = nn.LayerNorm(dim)
165 |
166 | def forward(self, x: Tensor, H, W) -> Tensor:
167 | B, N, C = x.shape
168 | q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)
169 |
170 | if self.sr_ratio > 1:
171 | x = x.permute(0, 2, 1).reshape(B, C, H, W)
172 | x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
173 | x = self.norm(x)
174 |
175 | k, v = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4)
176 |
177 | attn = (q @ k.transpose(-2, -1)) * self.scale
178 | attn = attn.softmax(dim=-1)
179 |
180 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
181 | x = self.proj(x)
182 | return x
183 |
184 |
185 | class PatchEmbed(nn.Module):
186 | def __init__(self, c1=3, c2=32, patch_size=7, stride=4, padding=0):
187 | super().__init__()
188 | self.proj = nn.Conv2d(c1, c2, patch_size, stride, padding) # padding=(ps[0]//2, ps[1]//2)
189 | self.norm = nn.LayerNorm(c2)
190 |
191 | def forward(self, x: Tensor) -> Tensor:
192 | x = self.proj(x)
193 | _, _, H, W = x.shape
194 | x = x.flatten(2).transpose(1, 2)
195 | x = self.norm(x)
196 | return x, H, W
197 |
198 |
199 | class Block(nn.Module):
200 | def __init__(self, dim, head, sr_ratio=1, dpr=0., is_fan=False):
201 | super().__init__()
202 | self.norm1 = nn.LayerNorm(dim)
203 | self.attn = Attention(dim, head, sr_ratio)
204 | self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity()
205 | self.norm2 = nn.LayerNorm(dim)
206 | self.mlp = MLP(dim, int(dim*4)) if not is_fan else ChannelProcessing(dim, mlp_hidden_dim=int(dim*4))
207 |
208 | def forward(self, x: Tensor, H, W) -> Tensor:
209 | x = x + self.drop_path(self.attn(self.norm1(x), H, W))
210 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
211 | return x
212 |
213 |
214 | class ChannelProcessing(nn.Module):
215 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., drop_path=0., mlp_hidden_dim=None, norm_layer=nn.LayerNorm):
216 | super().__init__()
217 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
218 | self.dim = dim
219 | self.num_heads = num_heads
220 |
221 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
222 | self.mlp_v = MLP(dim, mlp_hidden_dim)
223 | self.norm_v = norm_layer(dim)
224 |
225 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
226 | self.pool = nn.AdaptiveAvgPool2d((None, 1))
227 | self.sigmoid = nn.Sigmoid()
228 |
229 | def forward(self, x, H, W, atten=None):
230 | B, N, C = x.shape
231 |
232 | v = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
233 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
234 | k = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
235 |
236 | q = q.softmax(-2).transpose(-1,-2)
237 | _, _, Nk, Ck = k.shape
238 | k = k.softmax(-2)
239 | k = torch.nn.functional.avg_pool2d(k, (1, Ck))
240 |
241 | attn = self.sigmoid(q @ k)
242 |
243 | Bv, Hd, Nv, Cv = v.shape
244 | v = self.norm_v(self.mlp_v(v.transpose(1, 2).reshape(Bv, Nv, Hd*Cv), H, W)).reshape(Bv, Nv, Hd, Cv).transpose(1, 2)
245 | x = (attn * v.transpose(-1, -2)).permute(0, 3, 1, 2).reshape(B, N, C)
246 | return x
247 |
248 |
249 | mit_settings = {
250 | 'B0': [[32, 64, 160, 256], [2, 2, 2, 2]],
251 | 'B1': [[64, 128, 320, 512], [2, 2, 2, 2]],
252 | 'B2': [[64, 128, 320, 512], [3, 4, 6, 3]],
253 | 'B3': [[64, 128, 320, 512], [3, 4, 18, 3]],
254 | 'B4': [[64, 128, 320, 512], [3, 8, 27, 3]],
255 | 'B5': [[64, 128, 320, 512], [3, 6, 40, 3]]
256 | }
257 |
258 |
259 | class MixTransformer(nn.Module):
260 | def __init__(self, model_name: str = 'B0', modality: str = 'depth'):
261 | super().__init__()
262 |         assert model_name in mit_settings.keys(), f"Model name should be in {list(mit_settings.keys())}"
263 | # self.model_name = 'B2'
264 | self.model_name = model_name
265 | # TODO: Must comment the following line later
266 | # self.model_name = 'B2' if modality == 'depth' else model_name
267 | embed_dims, depths = mit_settings[self.model_name]
268 | self.modality = modality
269 | drop_path_rate = 0.1
270 | self.channels = embed_dims
271 |
272 | # patch_embed
273 | self.patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4, 7//2)
274 | self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2, 3//2)
275 | self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2, 3//2)
276 | self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2, 3//2)
277 |
278 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
279 |
280 | cur = 0
281 | self.block1 = nn.ModuleList([Block(embed_dims[0], 1, 8, dpr[cur+i]) for i in range(depths[0])])
282 | self.norm1 = nn.LayerNorm(embed_dims[0])
283 |
284 | cur += depths[0]
285 | self.block2 = nn.ModuleList([Block(embed_dims[1], 2, 4, dpr[cur+i]) for i in range(depths[1])])
286 | self.norm2 = nn.LayerNorm(embed_dims[1])
287 |
288 | cur += depths[1]
289 | self.block3 = nn.ModuleList([Block(embed_dims[2], 5, 2, dpr[cur+i]) for i in range(depths[2])])
290 | self.norm3 = nn.LayerNorm(embed_dims[2])
291 |
292 | cur += depths[2]
293 | self.block4 = nn.ModuleList([Block(embed_dims[3], 8, 1, dpr[cur+i]) for i in range(depths[3])])
294 | self.norm4 = nn.LayerNorm(embed_dims[3])
295 |
296 | # Initialize with pretrained weights
297 | self.init_weights()
298 |
299 | def init_weights(self):
300 | print(f"Initializing weight for {self.modality}...")
301 | checkpoint = torch.load(f'checkpoints/pretrained/segformer/mit_{self.model_name.lower()}.pth', map_location=torch.device('cpu'))
302 | if 'state_dict' in checkpoint.keys():
303 | checkpoint = checkpoint['state_dict']
304 | msg = self.load_state_dict(checkpoint, strict=False)
305 | del checkpoint
306 | print(f"Weight init complete with message: {msg}")
307 |
308 | def forward(self, x: Tensor) -> list:
309 | x_cam = x
310 |
311 | B = x_cam.shape[0]
312 | outs = []
313 | # stage 1
314 | x_cam, H, W = self.patch_embed1(x_cam)
315 | for blk in self.block1:
316 | x_cam = blk(x_cam, H, W)
317 | x1_cam = self.norm1(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)
318 | outs.append(x1_cam)
319 |
320 | # stage 2
321 | x_cam, H, W = self.patch_embed2(x1_cam)
322 | for blk in self.block2:
323 | x_cam = blk(x_cam, H, W)
324 | x2_cam = self.norm2(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)
325 | outs.append(x2_cam)
326 |
327 | # stage 3
328 | x_cam, H, W = self.patch_embed3(x2_cam)
329 | for blk in self.block3:
330 | x_cam = blk(x_cam, H, W)
331 | x3_cam = self.norm3(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)
332 | outs.append(x3_cam)
333 |
334 | # stage 4
335 | x_cam, H, W = self.patch_embed4(x3_cam)
336 | for blk in self.block4:
337 | x_cam = blk(x_cam, H, W)
338 | x4_cam = self.norm4(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)
339 | outs.append(x4_cam)
340 |
341 | return outs
342 |
343 |
344 | class MMSFormer(nn.Module):
345 | def __init__(self, model_name: str = 'B0', modals: list = ['rgb', 'depth', 'event', 'lidar']):
346 | super().__init__()
347 |         assert model_name in mit_settings.keys(), f"Model name should be in {list(mit_settings.keys())}"
348 | embed_dims, depths = mit_settings[model_name]
349 | self.modals = modals[1:] if len(modals)>1 else []
350 | self.num_modals = len(self.modals)
351 | drop_path_rate = 0.1
352 | self.channels = embed_dims
353 |
354 | # patch_embed
355 | self.patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4, 7//2)
356 | self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2, 3//2)
357 | self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2, 3//2)
358 | self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2, 3//2)
359 |
360 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
361 |
362 | cur = 0
363 | self.block1 = nn.ModuleList([Block(embed_dims[0], 1, 8, dpr[cur+i]) for i in range(depths[0])])
364 | self.norm1 = nn.LayerNorm(embed_dims[0])
365 |
366 | cur += depths[0]
367 | self.block2 = nn.ModuleList([Block(embed_dims[1], 2, 4, dpr[cur+i]) for i in range(depths[1])])
368 | self.norm2 = nn.LayerNorm(embed_dims[1])
369 |
370 | cur += depths[1]
371 | self.block3 = nn.ModuleList([Block(embed_dims[2], 5, 2, dpr[cur+i]) for i in range(depths[2])])
372 | self.norm3 = nn.LayerNorm(embed_dims[2])
373 |
374 | cur += depths[2]
375 | self.block4 = nn.ModuleList([Block(embed_dims[3], 8, 1, dpr[cur+i]) for i in range(depths[3])])
376 | self.norm4 = nn.LayerNorm(embed_dims[3])
377 |
378 | # Have extra modality
379 | if self.num_modals > 0:
380 | # Backbones and Fusion Block for extra modalities
381 | self.extra_mit = nn.ModuleList([MixTransformer('B1', self.modals[i]) for i in range(self.num_modals)])
382 | self.fusion_block = FusionBlock(self.channels, reduction=16, num_modals=self.num_modals+1)
383 |
384 | def forward(self, x: list) -> list:
385 | x_cam = x[0]
386 | if self.num_modals > 0:
387 | x_ext = x[1:]
388 | B = x_cam.shape[0]
389 | outs = []
390 |
391 | # stage 1
392 | x_cam, H, W = self.patch_embed1(x_cam)
393 | for blk in self.block1:
394 | x_cam = blk(x_cam, H, W)
395 | x1_cam = self.norm1(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)
396 | # Extra Modalities
397 | if self.num_modals > 0:
398 | for i in range(self.num_modals):
399 | x_ext[i], _, _ = self.extra_mit[i].patch_embed1(x_ext[i])
400 | for blk in self.extra_mit[i].block1:
401 | x_ext[i] = blk(x_ext[i], H, W)
402 | x_ext[i] = self.extra_mit[i].norm1(x_ext[i]).reshape(B, H, W, -1).permute(0, 3, 1, 2)
403 | x_fused = self.fusion_block([x1_cam, *x_ext], layer_idx=0)
404 | outs.append(x_fused)
405 | else:
406 | outs.append(x1_cam)
407 |
408 | # stage 2
409 | x_cam, H, W = self.patch_embed2(x1_cam)
410 | for blk in self.block2:
411 | x_cam = blk(x_cam, H, W)
412 | x2_cam = self.norm2(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)
413 | # Extra Modalities
414 | if self.num_modals > 0:
415 | for i in range(self.num_modals):
416 | x_ext[i], _, _ = self.extra_mit[i].patch_embed2(x_ext[i])
417 | for blk in self.extra_mit[i].block2:
418 | x_ext[i] = blk(x_ext[i], H, W)
419 | x_ext[i] = self.extra_mit[i].norm2(x_ext[i]).reshape(B, H, W, -1).permute(0, 3, 1, 2)
420 | x_fused = self.fusion_block([x2_cam, *x_ext], layer_idx=1)
421 | outs.append(x_fused)
422 | else:
423 | outs.append(x2_cam)
424 |
425 | # stage 3
426 | x_cam, H, W = self.patch_embed3(x2_cam)
427 | for blk in self.block3:
428 | x_cam = blk(x_cam, H, W)
429 | x3_cam = self.norm3(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)
430 | # Extra Modalities
431 | if self.num_modals > 0:
432 | for i in range(self.num_modals):
433 | x_ext[i], _, _ = self.extra_mit[i].patch_embed3(x_ext[i])
434 | for blk in self.extra_mit[i].block3:
435 | x_ext[i] = blk(x_ext[i], H, W)
436 | x_ext[i] = self.extra_mit[i].norm3(x_ext[i]).reshape(B, H, W, -1).permute(0, 3, 1, 2)
437 | x_fused = self.fusion_block([x3_cam, *x_ext], layer_idx=2)
438 | outs.append(x_fused)
439 | else:
440 | outs.append(x3_cam)
441 |
442 | # stage 4
443 | x_cam, H, W = self.patch_embed4(x3_cam)
444 | for blk in self.block4:
445 | x_cam = blk(x_cam, H, W)
446 | x4_cam = self.norm4(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)
447 | # Extra Modalities
448 | if self.num_modals > 0:
449 | for i in range(self.num_modals):
450 | x_ext[i], _, _ = self.extra_mit[i].patch_embed4(x_ext[i])
451 | for blk in self.extra_mit[i].block4:
452 | x_ext[i] = blk(x_ext[i], H, W)
453 | x_ext[i] = self.extra_mit[i].norm4(x_ext[i]).reshape(B, H, W, -1).permute(0, 3, 1, 2)
454 | x_fused = self.fusion_block([x4_cam, *x_ext], layer_idx=3)
455 | outs.append(x_fused)
456 | else:
457 | outs.append(x4_cam)
458 |
459 | return outs
460 |
461 |
462 | if __name__ == '__main__':
463 | modals = ['img', 'aolp', 'dolp', 'nir']
464 | x = [torch.zeros(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024)*2, torch.ones(1, 3, 1024, 1024) *3]
465 | model = MMSFormer('B2', modals)
466 | outs = model(x)
467 | for y in outs:
468 | print(y.shape)
469 |
470 |
--------------------------------------------------------------------------------
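Note on the fusion step in the backbone above: at each stage, FusionBlock takes a list of per-modality feature maps of shape (B, C, H, W), concatenates them along channels, flattens to (B, H*W, C*num_modals), projects back to C with a per-stage linear layer, and refines the result with MixFFN plus channel attention before reshaping to (B, C, H, W). A minimal shape-flow sketch (channel widths follow the B2 entry of mit_settings; batch and spatial sizes are arbitrary, and the repo root is assumed to be on PYTHONPATH):

    import torch
    from semseg.models.backbones.mmsformer import FusionBlock

    channels = [64, 128, 320, 512]                               # B2-style per-stage widths
    fusion = FusionBlock(channels, reduction=16, num_modals=2)   # e.g. RGB + one extra modality

    B, H, W = 1, 32, 32
    rgb_feat   = torch.randn(B, channels[0], H, W)   # stage-1 RGB features
    extra_feat = torch.randn(B, channels[0], H, W)   # stage-1 extra-modality features

    fused = fusion([rgb_feat, extra_feat], layer_idx=0)
    print(fused.shape)   # torch.Size([1, 64, 32, 32]) -- same shape as a single-modality map
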
/semseg/datasets/mcubes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from torch import Tensor
5 | from torch.utils.data import Dataset
6 | from torchvision import io
7 | from torchvision import transforms
8 | from pathlib import Path
9 | from typing import Tuple
10 | import glob
11 | import einops
12 | from torch.utils.data import DataLoader
13 | from torch.utils.data import DistributedSampler, RandomSampler
14 | from semseg.augmentations_mm import get_train_augmentation
15 | import cv2
16 | import random
17 | from PIL import Image, ImageOps, ImageFilter
18 |
19 | class MCubeS(Dataset):
20 | """
21 | num_classes: 20
22 | """
23 | CLASSES = ['asphalt','concrete','metal','road_marking','fabric','glass','plaster','plastic','rubber','sand',
24 | 'gravel','ceramic','cobblestone','brick','grass','wood','leaf','water','human','sky',]
25 |
26 | PALETTE = torch.tensor([[ 44, 160, 44],
27 | [ 31, 119, 180],
28 | [255, 127, 14],
29 | [214, 39, 40],
30 | [140, 86, 75],
31 | [127, 127, 127],
32 | [188, 189, 34],
33 | [255, 152, 150],
34 | [ 23, 190, 207],
35 | [174, 199, 232],
36 | [196, 156, 148],
37 | [197, 176, 213],
38 | [247, 182, 210],
39 | [199, 199, 199],
40 | [219, 219, 141],
41 | [158, 218, 229],
42 | [ 57, 59, 121],
43 | [107, 110, 207],
44 | [156, 158, 222],
45 | [ 99, 121, 57]])
46 |
47 | def __init__(self, root: str = 'data/MCubeS/multimodal_dataset', split: str = 'train', transform = None, modals = ['image', 'aolp', 'dolp', 'nir'], case = None) -> None:
48 | super().__init__()
49 | assert split in ['train', 'val']
50 | self.split = split
51 | self.root = root
52 | self.transform = transform
53 | self.n_classes = len(self.CLASSES)
54 | self.ignore_label = 255
55 | self.modals = modals
56 | self._left_offset = 192
57 |
58 | self.img_h = 1024
59 | self.img_w = 1224
60 | max_dim = max(self.img_h, self.img_w)
61 | u_vec = (np.arange(self.img_w)-self.img_w/2)/max_dim*2
62 | v_vec = (np.arange(self.img_h)-self.img_h/2)/max_dim*2
63 | self.u_map, self.v_map = np.meshgrid(u_vec, v_vec)
64 | self.u_map = self.u_map[:,:self._left_offset]
65 |
66 | self.base_size = 512
67 | self.crop_size = 512
68 | self.files = self._get_file_names(split)
69 |
70 | if not self.files:
71 |             raise Exception(f"No images found for split '{split}' in {self.root}")
72 | print(f"Found {len(self.files)} {split} images.")
73 |
74 | def __len__(self) -> int:
75 | return len(self.files)
76 |
77 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
78 | item_name = str(self.files[index])
79 | rgb = os.path.join(*[self.root, 'polL_color', item_name+'.png'])
80 | x1 = os.path.join(*[self.root, 'polL_aolp_sin', item_name+'.npy'])
81 | x1_1 = os.path.join(*[self.root, 'polL_aolp_cos', item_name+'.npy'])
82 | x2 = os.path.join(*[self.root, 'polL_dolp', item_name+'.npy'])
83 | x3 = os.path.join(*[self.root, 'NIR_warped', item_name+'.png'])
84 | lbl_path = os.path.join(*[self.root, 'GT', item_name+'.png'])
85 | nir_mask = os.path.join(*[self.root, 'NIR_warped_mask', item_name+'.png'])
86 | _mask = os.path.join(*[self.root, 'SS', item_name+'.png'])
87 |
88 | _img = cv2.imread(rgb,-1)[:,:,::-1]
89 | _img = _img.astype(np.float32)/65535 if _img.dtype==np.uint16 else _img.astype(np.float32)/255
90 | _target = cv2.imread(lbl_path,-1)
91 | _mask = cv2.imread(_mask,-1)
92 | # _aolp_sin = (np.load(x1) + 2.66) / (2.66 + 2.40) # See below of this script for min-max values
93 | # _aolp_cos = (np.load(x1_1) + 0.57 ) / (0.57 + 1.64)
94 | _aolp_sin = np.load(x1)
95 | _aolp_cos = np.load(x1_1)
96 | _aolp = np.stack([_aolp_sin, _aolp_cos, _aolp_sin], axis=2) # H x W x 3
97 | dolp = (np.load(x2) + 0.55) / 2.0
98 | _dolp = np.stack([dolp, dolp, dolp], axis=2) # H x W x 3
99 | nir = cv2.imread(x3,-1)
100 | nir = nir.astype(np.float32)/65535 if nir.dtype==np.uint16 else nir.astype(np.float32)/255
101 | _nir = np.stack([nir, nir, nir], axis=2) # H x W x 3
102 |
103 | _nir_mask = cv2.imread(nir_mask,0)
104 |
105 | _img, _target, _aolp, _dolp, _nir, _nir_mask, _mask = _img[:,self._left_offset:], _target[:,self._left_offset:], \
106 | _aolp[:,self._left_offset:], _dolp[:,self._left_offset:], \
107 | _nir[:,self._left_offset:], _nir_mask[:,self._left_offset:], _mask[:,self._left_offset:]
108 | sample = {'image': _img, 'label': _target, 'aolp': _aolp, 'dolp': _dolp, 'nir': _nir, 'nir_mask': _nir_mask, 'u_map': self.u_map, 'v_map': self.v_map, 'mask':_mask}
109 |
110 | if self.split == "train":
111 | sample = self.transform_tr(sample)
112 | elif self.split == 'val':
113 | sample = self.transform_val(sample)
114 | elif self.split == 'test':
115 | sample = self.transform_val(sample)
116 | else:
117 | raise NotImplementedError()
118 | label = sample['label'].long()
119 | sample = [sample[k] for k in self.modals]
120 | # del _img, _target, _aolp, _dolp, _nir, _nir_mask, _mask
121 | return sample, label
122 |
123 | def transform_tr(self, sample):
124 | composed_transforms = transforms.Compose([
125 | RandomHorizontalFlip(),
126 | RandomScaleCrop(base_size=self.base_size, crop_size=self.crop_size, fill=255),
127 | RandomGaussianBlur(),
128 | Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
129 | ToTensor()])
130 |
131 | return composed_transforms(sample)
132 |
133 | def transform_val(self, sample):
134 | composed_transforms = transforms.Compose([
135 | FixScaleCrop(crop_size=1024),
136 | Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
137 | ToTensor()])
138 |
139 | return composed_transforms(sample)
140 |
141 | def _get_file_names(self, split_name):
142 | assert split_name in ['train', 'val']
143 | source = os.path.join(self.root, 'list_folder/test.txt') if split_name == 'val' else os.path.join(self.root, 'list_folder/train.txt')
144 | file_names = []
145 | with open(source) as f:
146 | files = f.readlines()
147 | for item in files:
148 | file_name = item.strip()
149 | if ' ' in file_name:
150 | # --- KITTI-360
151 | file_name = file_name.split(' ')[0]
152 | file_names.append(file_name)
153 | return file_names
154 |
155 | class Normalize(object):
156 | """Normalize a tensor image with mean and standard deviation.
157 | Args:
158 | mean (tuple): means for each channel.
159 | std (tuple): standard deviations for each channel.
160 | """
161 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
162 | self.mean = mean
163 | self.std = std
164 |
165 | def __call__(self, sample):
166 | img = sample['image']
167 | mask = sample['label']
168 | img = np.array(img).astype(np.float32)
169 | mask = np.array(mask).astype(np.float32)
170 | img -= self.mean
171 | img /= self.std
172 |
173 | nir = sample['nir']
174 | nir = np.array(nir).astype(np.float32)
175 | # nir /= 255
176 |
177 | return {'image': img,
178 | 'label': mask,
179 | 'aolp' : sample['aolp'],
180 | 'dolp' : sample['dolp'],
181 | 'nir' : nir,
182 | 'nir_mask': sample['nir_mask'],
183 | 'u_map': sample['u_map'],
184 | 'v_map': sample['v_map'],
185 | 'mask':sample['mask']}
186 |
187 |
188 | class ToTensor(object):
189 | """Convert ndarrays in sample to Tensors."""
190 |
191 | def __call__(self, sample):
192 | # swap color axis because
193 | # numpy image: H x W x C
194 | # torch image: C X H X W
195 | img = sample['image']
196 | mask = sample['label']
197 | aolp = sample['aolp']
198 | dolp = sample['dolp']
199 | nir = sample['nir']
200 | nir_mask = sample['nir_mask']
201 | SS=sample['mask']
202 |
203 | img = np.array(img).astype(np.float32).transpose((2, 0, 1))
204 | mask = np.array(mask).astype(np.float32)
205 | aolp = np.array(aolp).astype(np.float32).transpose((2, 0, 1))
206 | dolp = np.array(dolp).astype(np.float32).transpose((2, 0, 1))
207 | SS = np.array(SS).astype(np.float32)
208 | nir = np.array(nir).astype(np.float32).transpose((2, 0, 1))
209 | nir_mask = np.array(nir_mask).astype(np.float32)
210 |
211 | img = torch.from_numpy(img).float()
212 | mask = torch.from_numpy(mask).float()
213 | aolp = torch.from_numpy(aolp).float()
214 | dolp = torch.from_numpy(dolp).float()
215 | SS = torch.from_numpy(SS).float()
216 | nir = torch.from_numpy(nir).float()
217 | nir_mask = torch.from_numpy(nir_mask).float()
218 |
219 | u_map = sample['u_map']
220 | v_map = sample['v_map']
221 | u_map = torch.from_numpy(u_map.astype(np.float32)).float()
222 | v_map = torch.from_numpy(v_map.astype(np.float32)).float()
223 |
224 | return {'image': img,
225 | 'label': mask,
226 | 'aolp' : aolp,
227 | 'dolp' : dolp,
228 | 'nir' : nir,
229 | 'nir_mask' : nir_mask,
230 | 'u_map': u_map,
231 | 'v_map': v_map,
232 | 'mask':SS}
233 |
234 |
235 | class RandomHorizontalFlip(object):
236 | def __call__(self, sample):
237 | img = sample['image']
238 | mask = sample['label']
239 | aolp = sample['aolp']
240 | dolp = sample['dolp']
241 | nir = sample['nir']
242 | nir_mask = sample['nir_mask']
243 | u_map = sample['u_map']
244 | v_map = sample['v_map']
245 | SS=sample['mask']
246 | if random.random() < 0.5:
247 | # img = img.transpose(Image.FLIP_LEFT_RIGHT)
248 | # mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
249 | # nir = nir.transpose(Image.FLIP_LEFT_RIGHT)
250 |
251 | img = img[:,::-1]
252 | mask = mask[:,::-1]
253 | nir = nir[:,::-1]
254 | nir_mask = nir_mask[:,::-1]
255 | aolp = aolp[:,::-1]
256 | dolp = dolp[:,::-1]
257 | SS = SS[:,::-1]
258 | u_map = u_map[:,::-1]
259 |
260 | return {'image': img,
261 | 'label': mask,
262 | 'aolp' : aolp,
263 | 'dolp' : dolp,
264 | 'nir' : nir,
265 | 'nir_mask' : nir_mask,
266 | 'u_map': u_map,
267 | 'v_map': v_map,
268 | 'mask':SS}
269 |
270 | class RandomGaussianBlur(object):
271 | def __call__(self, sample):
272 | img = sample['image']
273 | mask = sample['label']
274 | nir = sample['nir']
275 | if random.random() < 0.5:
276 | radius = random.random()
277 | # img = img.filter(ImageFilter.GaussianBlur(radius=radius))
278 | # nir = nir.filter(ImageFilter.GaussianBlur(radius=radius))
279 | img = cv2.GaussianBlur(img, (0,0), radius)
280 | nir = cv2.GaussianBlur(nir, (0,0), radius)
281 |
282 | return {'image': img,
283 | 'label': mask,
284 | 'aolp' : sample['aolp'],
285 | 'dolp' : sample['dolp'],
286 | 'nir' : nir,
287 | 'nir_mask': sample['nir_mask'],
288 | 'u_map': sample['u_map'],
289 | 'v_map': sample['v_map'],
290 | 'mask':sample['mask']}
291 |
292 | class RandomScaleCrop(object):
293 | def __init__(self, base_size, crop_size, fill=255):
294 | self.base_size = base_size
295 | self.crop_size = crop_size
296 | self.fill = fill
297 |
298 | def __call__(self, sample):
299 | img = sample['image']
300 | mask = sample['label']
301 | aolp = sample['aolp']
302 | dolp = sample['dolp']
303 | nir = sample['nir']
304 | nir_mask = sample['nir_mask']
305 | SS=sample['mask']
306 | # random scale (short edge)
307 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
308 | # w, h = img.size
309 | h, w = img.shape[:2]
310 | if h > w:
311 | ow = short_size
312 | oh = int(1.0 * h * ow / w)
313 | else:
314 | oh = short_size
315 | ow = int(1.0 * w * oh / h)
316 |
317 | # pad crop
318 | if short_size < self.crop_size:
319 | padh = self.crop_size - oh if oh < self.crop_size else 0
320 | padw = self.crop_size - ow if ow < self.crop_size else 0
321 |
322 | # random crop crop_size
323 | # w, h = img.size
324 | h, w = img.shape[:2]
325 |
326 | # x1 = random.randint(0, w - self.crop_size)
327 | # y1 = random.randint(0, h - self.crop_size)
328 | x1 = random.randint(0, max(0, ow - self.crop_size))
329 | y1 = random.randint(0, max(0, oh - self.crop_size))
330 |
331 | u_map = sample['u_map']
332 | v_map = sample['v_map']
333 |         u_map = cv2.resize(u_map, (ow, oh))
334 |         v_map = cv2.resize(v_map, (ow, oh))
335 |         aolp = cv2.resize(aolp, (ow, oh))
336 |         dolp = cv2.resize(dolp, (ow, oh))
337 |         SS = cv2.resize(SS, (ow, oh))
338 |         img = cv2.resize(img, (ow, oh), interpolation=cv2.INTER_LINEAR)
339 |         mask = cv2.resize(mask, (ow, oh), interpolation=cv2.INTER_NEAREST)
340 |         nir = cv2.resize(nir, (ow, oh), interpolation=cv2.INTER_LINEAR)
341 |         nir_mask = cv2.resize(nir_mask, (ow, oh), interpolation=cv2.INTER_NEAREST)
342 | if short_size < self.crop_size:
343 | u_map_ = np.zeros((oh+padh,ow+padw))
344 | u_map_[:oh,:ow] = u_map
345 | u_map = u_map_
346 | v_map_ = np.zeros((oh+padh,ow+padw))
347 | v_map_[:oh,:ow] = v_map
348 | v_map = v_map_
349 | aolp_ = np.zeros((oh+padh,ow+padw,3))
350 | aolp_[:oh,:ow] = aolp
351 | aolp = aolp_
352 | dolp_ = np.zeros((oh+padh,ow+padw,3))
353 | dolp_[:oh,:ow] = dolp
354 | dolp = dolp_
355 |
356 | img_ = np.zeros((oh+padh,ow+padw,3))
357 | img_[:oh,:ow] = img
358 | img = img_
359 | SS_ = np.zeros((oh+padh,ow+padw))
360 | SS_[:oh,:ow] = SS
361 | SS = SS_
362 | mask_ = np.full((oh+padh,ow+padw),self.fill)
363 | mask_[:oh,:ow] = mask
364 | mask = mask_
365 | nir_ = np.zeros((oh+padh,ow+padw,3))
366 | nir_[:oh,:ow] = nir
367 | nir = nir_
368 | nir_mask_ = np.zeros((oh+padh,ow+padw))
369 | nir_mask_[:oh,:ow] = nir_mask
370 | nir_mask = nir_mask_
371 |
372 | u_map = u_map[y1:y1+self.crop_size, x1:x1+self.crop_size]
373 | v_map = v_map[y1:y1+self.crop_size, x1:x1+self.crop_size]
374 | aolp = aolp[y1:y1+self.crop_size, x1:x1+self.crop_size]
375 | dolp = dolp[y1:y1+self.crop_size, x1:x1+self.crop_size]
376 | img = img[y1:y1+self.crop_size, x1:x1+self.crop_size]
377 | mask = mask[y1:y1+self.crop_size, x1:x1+self.crop_size]
378 | nir = nir[y1:y1+self.crop_size, x1:x1+self.crop_size]
379 | SS = SS[y1:y1+self.crop_size, x1:x1+self.crop_size]
380 | nir_mask = nir_mask[y1:y1+self.crop_size, x1:x1+self.crop_size]
381 |         return {'image': img,
382 |                 'label': mask,
383 |                 'aolp': aolp,
384 |                 'dolp': dolp,
385 |                 'nir': nir,
386 |                 'nir_mask': nir_mask,
387 |                 'u_map': u_map,
388 |                 'v_map': v_map,
389 |                 'mask': SS}
390 |
391 | class FixScaleCrop(object):
392 | def __init__(self, crop_size):
393 | self.crop_size = crop_size
394 |
395 | def __call__(self, sample):
396 | img = sample['image']
397 | mask = sample['label']
398 | aolp = sample['aolp']
399 | dolp = sample['dolp']
400 | nir = sample['nir']
401 | nir_mask = sample['nir_mask']
402 | SS = sample['mask']
403 |
404 | # w, h = img.size
405 | h, w = img.shape[:2]
406 |
407 | if w > h:
408 | oh = self.crop_size
409 | ow = int(1.0 * w * oh / h)
410 | else:
411 | ow = self.crop_size
412 | oh = int(1.0 * h * ow / w)
413 | # img = img.resize((ow, oh), Image.BILINEAR)
414 | # mask = mask.resize((ow, oh), Image.NEAREST)
415 | # nir = nir.resize((ow, oh), Image.BILINEAR)
416 |
417 | # center crop
418 | # w, h = img.size
419 | # h, w = img.shape[:2]
420 | x1 = int(round((ow - self.crop_size) / 2.))
421 | y1 = int(round((oh - self.crop_size) / 2.))
422 | # img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
423 | # mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
424 | # nir = nir.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
425 |
426 | u_map = sample['u_map']
427 | v_map = sample['v_map']
428 |         u_map = cv2.resize(u_map, (ow, oh))
429 |         v_map = cv2.resize(v_map, (ow, oh))
430 |         aolp = cv2.resize(aolp, (ow, oh))
431 |         dolp = cv2.resize(dolp, (ow, oh))
432 |         SS = cv2.resize(SS, (ow, oh))
433 |         img = cv2.resize(img, (ow, oh), interpolation=cv2.INTER_LINEAR)
434 |         mask = cv2.resize(mask, (ow, oh), interpolation=cv2.INTER_NEAREST)
435 |         nir = cv2.resize(nir, (ow, oh), interpolation=cv2.INTER_LINEAR)
436 |         nir_mask = cv2.resize(nir_mask, (ow, oh), interpolation=cv2.INTER_NEAREST)
437 | u_map = u_map[y1:y1+self.crop_size, x1:x1+self.crop_size]
438 | v_map = v_map[y1:y1+self.crop_size, x1:x1+self.crop_size]
439 | aolp = aolp[y1:y1+self.crop_size, x1:x1+self.crop_size]
440 | dolp = dolp[y1:y1+self.crop_size, x1:x1+self.crop_size]
441 | img = img[y1:y1+self.crop_size, x1:x1+self.crop_size]
442 | mask = mask[y1:y1+self.crop_size, x1:x1+self.crop_size]
443 | SS = SS[y1:y1+self.crop_size, x1:x1+self.crop_size]
444 | nir = nir[y1:y1+self.crop_size, x1:x1+self.crop_size]
445 | nir_mask = nir_mask[y1:y1+self.crop_size, x1:x1+self.crop_size]
446 |         return {'image': img,
447 |                 'label': mask,
448 |                 'aolp': aolp,
449 |                 'dolp': dolp,
450 |                 'nir': nir,
451 |                 'nir_mask': nir_mask,
452 |                 'u_map': u_map,
453 |                 'v_map': v_map,
454 |                 'mask': SS}
455 |
456 |
457 | if __name__ == '__main__':
458 | traintransform = get_train_augmentation((1024, 1224), seg_fill=255)
459 |
460 | trainset = MCubeS(transform=traintransform, split='val')
461 | trainloader = DataLoader(trainset, batch_size=1, num_workers=0, drop_last=False, pin_memory=False)
462 |
463 | for i, (sample, lbl) in enumerate(trainloader):
464 | print(torch.unique(lbl))
465 |
466 |
467 | # Reading DoLP....
468 | # DoLP: Global Min: -0.5281060934066772, Global Max: 1.4464844465255737, Negative Count: 11528302, Negative Percent: 1.839560036254085%
469 | # Reading AoLP (Sin)....
470 | # AoLP (Sin): Global Min: -2.6559877395629883, Global Max: 2.3968138694763184, Negative Count: 279532583, Negative Percent: 44.604744785283906%
471 | # Reading AoLP (Cos)....
472 | # AoLP (Cos): Global Min: -0.5681185722351074, Global Max: 1.6386538743972778, Negative Count: 3569347, Negative Percent: 0.5695572597528594%
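473 | 
474 | # Illustrative sketch only (not part of the original pipeline): one way the transforms
475 | # above could be chained with torchvision's Compose, if it is available. The
476 | # base_size/crop_size values below are placeholders, not the repo's configured settings.
477 | #
478 | #   from torchvision import transforms
479 | #
480 | #   train_tf = transforms.Compose([
481 | #       RandomScaleCrop(base_size=1024, crop_size=512, fill=255),
482 | #       RandomHorizontalFlip(),
483 | #       RandomGaussianBlur(),
484 | #   ])
485 | #   augmented = train_tf(sample)  # `sample` is the dict of modalities used throughout this file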
--------------------------------------------------------------------------------