├── semseg ├── utils │ ├── __init__.py │ ├── visualize.py │ └── utils.py ├── models │ ├── layers │ │ ├── __init__.py │ │ ├── common.py │ │ └── initialize.py │ ├── __init__.py │ ├── heads │ │ ├── __init__.py │ │ └── segformer.py │ ├── backbones │ │ ├── __init__.py │ │ └── mmsformer.py │ ├── mmsformer.py │ └── base.py ├── datasets │ ├── __init__.py │ ├── unzip.py │ ├── pst.py │ ├── fmb.py │ └── mcubes.py ├── optimizers.py ├── __init__.py ├── metrics.py ├── losses.py ├── schedulers.py ├── augmentations.py └── augmentations_mm.py ├── figs ├── MMSFormer-V2.png └── MMSFormer-v2.png ├── tools ├── train.sh ├── infer_mm.py ├── val_mm.py └── train_mm.py ├── configs ├── pst_rgbt.yaml ├── fmb_rgbt.yaml └── mcubes_rgbadn.yaml ├── requirements.txt ├── .gitignore ├── environment.yaml ├── README.md └── LICENSE /semseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /semseg/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | from .initialize import * -------------------------------------------------------------------------------- /figs/MMSFormer-V2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSIPlab/MMSFormer/HEAD/figs/MMSFormer-V2.png -------------------------------------------------------------------------------- /figs/MMSFormer-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSIPlab/MMSFormer/HEAD/figs/MMSFormer-v2.png -------------------------------------------------------------------------------- /semseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mmsformer import MMSFormer 2 | 3 | __all__ = [ 4 | 'MMSFormer', 5 | ] -------------------------------------------------------------------------------- /semseg/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .segformer import SegFormerHead 2 | 3 | __all__ = ['SegFormerHead'] -------------------------------------------------------------------------------- /semseg/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .mmsformer import MMSFormer 2 | 3 | __all__ = [ 4 | 'MMSFormer', 5 | ] -------------------------------------------------------------------------------- /semseg/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .mcubes import MCubeS 2 | from .fmb import FMB 3 | from .pst import PST 4 | 5 | __all__ = [ 6 | 'MCubeS', 7 | 'FMB', 8 | 'PST', 9 | ] -------------------------------------------------------------------------------- /semseg/datasets/unzip.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | 3 | with zipfile.ZipFile("data/MCubeS/multimodal_dataset.zip", "r") as zip_ref: 4 | for name in zip_ref.namelist(): 5 | try: 6 | zip_ref.extract(name, "multimodal_dataset_extracted/") 7 | except zipfile.BadZipFile as e: 8 | print(e) -------------------------------------------------------------------------------- /tools/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | 3 | #SBATCH --mem=16G 4 | #SBATCH 
--time=120:00:00 5 | #SBATCH --mail-user=emailaddr@domain.com 6 | #SBATCH --mail-type=ALL 7 | #SBATCH --cpus-per-task=16 8 | #SBATCH -p batch 9 | #SBATCH --output=output_%j-%N.txt # logging per job and per host in the current directory. Both stdout and stderr are logged 10 | #SBATCH --gres=gpu:2 11 | 12 | conda activate mmsformer 13 | 14 | python -m tools.train_mm --cfg configs/mcubes_rgbadn.yaml -------------------------------------------------------------------------------- /semseg/optimizers.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.optim import AdamW, SGD 3 | 4 | 5 | def get_optimizer(model: nn.Module, optimizer: str, lr: float, weight_decay: float = 0.01): 6 | wd_params, nwd_params = [], [] 7 | for p in model.parameters(): 8 | if p.requires_grad: 9 | if p.dim() == 1: 10 | nwd_params.append(p) 11 | else: 12 | wd_params.append(p) 13 | 14 | params = [ 15 | {"params": wd_params}, 16 | {"params": nwd_params, "weight_decay": 0} 17 | ] 18 | 19 | if optimizer == 'adamw': 20 | return AdamW(params, lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=weight_decay) 21 | else: 22 | return SGD(params, lr, momentum=0.9, weight_decay=weight_decay) -------------------------------------------------------------------------------- /semseg/__init__.py: -------------------------------------------------------------------------------- 1 | from tabulate import tabulate 2 | from semseg import models 3 | from semseg import datasets 4 | from semseg.models import backbones, heads 5 | 6 | 7 | def show_models(): 8 | model_names = models.__all__ 9 | numbers = list(range(1, len(model_names)+1)) 10 | print(tabulate({'No.': numbers, 'Model Names': model_names}, headers='keys')) 11 | 12 | 13 | def show_backbones(): 14 | backbone_names = backbones.__all__ 15 | variants = [] 16 | for name in backbone_names: 17 | try: 18 | variants.append(list(eval(f"backbones.{name.lower()}_settings").keys())) 19 | except: 20 | variants.append('-') 21 | print(tabulate({'Backbone Names': backbone_names, 'Variants': variants}, headers='keys')) 22 | 23 | 24 | def show_heads(): 25 | head_names = heads.__all__ 26 | numbers = list(range(1, len(head_names)+1)) 27 | print(tabulate({'No.': numbers, 'Heads': head_names}, headers='keys')) 28 | 29 | 30 | def show_datasets(): 31 | dataset_names = datasets.__all__ 32 | numbers = list(range(1, len(dataset_names)+1)) 33 | print(tabulate({'No.': numbers, 'Datasets': dataset_names}, headers='keys')) 34 | -------------------------------------------------------------------------------- /semseg/models/layers/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | 5 | class ConvModule(nn.Sequential): 6 | def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1): 7 | super().__init__( 8 | nn.Conv2d(c1, c2, k, s, p, d, g, bias=False), 9 | nn.BatchNorm2d(c2), 10 | nn.ReLU(True) 11 | ) 12 | 13 | 14 | class DropPath(nn.Module): 15 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 16 | Copied from timm 17 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 18 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 19 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
I've opted for 20 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 21 | 'survival rate' as the argument. 22 | """ 23 | def __init__(self, p: float = None): 24 | super().__init__() 25 | self.p = p 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | if self.p == 0. or not self.training: 29 | return x 30 | kp = 1 - self.p 31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 32 | random_tensor = kp + torch.rand(shape, dtype=x.dtype, device=x.device) 33 | random_tensor.floor_() # binarize 34 | return x.div(kp) * random_tensor -------------------------------------------------------------------------------- /semseg/models/mmsformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from semseg.models.base import BaseModel 4 | from semseg.models.heads import SegFormerHead 5 | 6 | 7 | class MMSFormer(BaseModel): 8 | def __init__(self, backbone: str = 'MMSFormer-B0', num_classes: int = 20, modals: list = ['img', 'aolp', 'dolp', 'nir']) -> None: 9 | super().__init__(backbone, num_classes, modals) 10 | self.decode_head = SegFormerHead(self.backbone.channels, 256 if 'B0' in backbone or 'B1' in backbone else 512, num_classes) 11 | self.apply(self._init_weights) 12 | 13 | def forward(self, x: list) -> list: 14 | y = self.backbone(x) 15 | y = self.decode_head(y) 16 | y = F.interpolate(y, size=x[0].shape[2:], mode='bilinear', align_corners=False) 17 | return y 18 | 19 | def init_pretrained(self, pretrained: str = None) -> None: 20 | checkpoint = torch.load(pretrained, map_location='cpu') 21 | if 'state_dict' in checkpoint.keys(): 22 | checkpoint = checkpoint['state_dict'] 23 | if 'model' in checkpoint.keys(): 24 | checkpoint = checkpoint['model'] 25 | msg = self.backbone.load_state_dict(checkpoint, strict=False) 26 | del checkpoint 27 | 28 | 29 | if __name__ == '__main__': 30 | modals = ['img', 'aolp', 'dolp', 'nir'] 31 | model = MMSFormer('MMSFormer-B2', 25, modals) 32 | model.init_pretrained('checkpoints/pretrained/segformer/mit_b2.pth') 33 | x = [torch.zeros(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024)*2, torch.ones(1, 3, 1024, 1024) *3] 34 | y = model(x) 35 | print(y.shape) 36 | -------------------------------------------------------------------------------- /semseg/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import Tuple 4 | 5 | 6 | class Metrics: 7 | def __init__(self, num_classes: int, ignore_label: int, device) -> None: 8 | self.ignore_label = ignore_label 9 | self.num_classes = num_classes 10 | self.hist = torch.zeros(num_classes, num_classes).to(device) 11 | 12 | def update(self, pred: Tensor, target: Tensor) -> None: 13 | pred = pred.argmax(dim=1) 14 | keep = target != self.ignore_label 15 | self.hist += torch.bincount(target[keep] * self.num_classes + pred[keep], minlength=self.num_classes**2).view(self.num_classes, self.num_classes) 16 | 17 | def compute_iou(self) -> Tuple[Tensor, Tensor]: 18 | ious = self.hist.diag() / (self.hist.sum(0) + self.hist.sum(1) - self.hist.diag()) 19 | ious[ious.isnan()]=0. 
20 | miou = ious.mean().item() 21 | # miou = ious[~ious.isnan()].mean().item() 22 | ious *= 100 23 | miou *= 100 24 | return ious.cpu().numpy().round(2).tolist(), round(miou, 2) 25 | 26 | def compute_f1(self) -> Tuple[Tensor, Tensor]: 27 | f1 = 2 * self.hist.diag() / (self.hist.sum(0) + self.hist.sum(1)) 28 | f1[f1.isnan()]=0. 29 | mf1 = f1.mean().item() 30 | # mf1 = f1[~f1.isnan()].mean().item() 31 | f1 *= 100 32 | mf1 *= 100 33 | return f1.cpu().numpy().round(2).tolist(), round(mf1, 2) 34 | 35 | def compute_pixel_acc(self) -> Tuple[Tensor, Tensor]: 36 | acc = self.hist.diag() / self.hist.sum(1) 37 | acc[acc.isnan()]=0. 38 | macc = acc.mean().item() 39 | # macc = acc[~acc.isnan()].mean().item() 40 | acc *= 100 41 | macc *= 100 42 | return acc.cpu().numpy().round(2).tolist(), round(macc, 2) 43 | 44 | -------------------------------------------------------------------------------- /semseg/models/heads/segformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from typing import Tuple 4 | from torch.nn import functional as F 5 | 6 | 7 | class MLP(nn.Module): 8 | def __init__(self, dim, embed_dim): 9 | super().__init__() 10 | self.proj = nn.Linear(dim, embed_dim) 11 | 12 | def forward(self, x: Tensor) -> Tensor: 13 | x = x.flatten(2).transpose(1, 2) 14 | x = self.proj(x) 15 | return x 16 | 17 | 18 | class ConvModule(nn.Module): 19 | def __init__(self, c1, c2): 20 | super().__init__() 21 | self.conv = nn.Conv2d(c1, c2, 1, bias=False) 22 | self.bn = nn.BatchNorm2d(c2) # use SyncBN in original 23 | self.activate = nn.ReLU(True) 24 | 25 | def forward(self, x: Tensor) -> Tensor: 26 | return self.activate(self.bn(self.conv(x))) 27 | 28 | 29 | class SegFormerHead(nn.Module): 30 | def __init__(self, dims: list, embed_dim: int = 256, num_classes: int = 19): 31 | super().__init__() 32 | for i, dim in enumerate(dims): 33 | self.add_module(f"linear_c{i+1}", MLP(dim, embed_dim)) 34 | 35 | self.linear_fuse = ConvModule(embed_dim*4, embed_dim) 36 | self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1) 37 | self.dropout = nn.Dropout2d(0.1) 38 | 39 | def forward(self, features: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tensor: 40 | B, _, H, W = features[0].shape 41 | outs = [self.linear_c1(features[0]).permute(0, 2, 1).reshape(B, -1, *features[0].shape[-2:])] 42 | 43 | for i, feature in enumerate(features[1:]): 44 | cf = eval(f"self.linear_c{i+2}")(feature).permute(0, 2, 1).reshape(B, -1, *feature.shape[-2:]) 45 | outs.append(F.interpolate(cf, size=(H, W), mode='bilinear', align_corners=False)) 46 | 47 | seg = self.linear_fuse(torch.cat(outs[::-1], dim=1)) 48 | seg = self.linear_pred(self.dropout(seg)) 49 | return seg -------------------------------------------------------------------------------- /semseg/models/layers/initialize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | from torch import nn, Tensor 5 | 6 | 7 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 8 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 9 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 10 | def norm_cdf(x): 11 | # Computes standard normal cumulative distribution function 12 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 13 | 14 | if (mean < a - 2 * std) or (mean > b + 2 * std): 15 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" 16 | "The distribution of values may be incorrect.", 17 | stacklevel=2) 18 | 19 | with torch.no_grad(): 20 | # Values are generated by using a truncated uniform distribution and 21 | # then using the inverse CDF for the normal distribution. 22 | # Get upper and lower cdf values 23 | l = norm_cdf((a - mean) / std) 24 | u = norm_cdf((b - mean) / std) 25 | 26 | # Uniformly fill tensor with values from [l, u], then translate to 27 | # [2l-1, 2u-1]. 28 | tensor.uniform_(2 * l - 1, 2 * u - 1) 29 | 30 | # Use inverse cdf transform for normal distribution to get truncated 31 | # standard normal 32 | tensor.erfinv_() 33 | 34 | # Transform to proper mean, std 35 | tensor.mul_(std * math.sqrt(2.)) 36 | tensor.add_(mean) 37 | 38 | # Clamp to ensure it's in the proper range 39 | tensor.clamp_(min=a, max=b) 40 | return tensor 41 | 42 | 43 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 44 | # type: (Tensor, float, float, float, float) -> Tensor 45 | r"""Fills the input Tensor with values drawn from a truncated 46 | normal distribution. The values are effectively drawn from the 47 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 48 | with values outside :math:`[a, b]` redrawn until they are within 49 | the bounds. The method used for generating the random values works 50 | best when :math:`a \leq \text{mean} \leq b`. 51 | Args: 52 | tensor: an n-dimensional `torch.Tensor` 53 | mean: the mean of the normal distribution 54 | std: the standard deviation of the normal distribution 55 | a: the minimum cutoff value 56 | b: the maximum cutoff value 57 | Examples: 58 | >>> w = torch.empty(3, 5) 59 | >>> nn.init.trunc_normal_(w) 60 | """ 61 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 62 | -------------------------------------------------------------------------------- /configs/pst_rgbt.yaml: -------------------------------------------------------------------------------- 1 | DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...) 
2 | SAVE_DIR : 'output/MMSFormer' # output folder name used for saving the model, logs and inference results 3 | GPUs : 2 4 | GPU_IDs : [0, 1] 5 | USE_WANDB : False 6 | WANDB_NAME : 'MMSF-FMB-PST' # name for the run 7 | 8 | MODEL: 9 | NAME : MMSFormer # name of the model you are using 10 | BACKBONE : MMSFormer-B4 # model variant 11 | PRETRAINED : 'checkpoints/pretrained/segformer/mit_b4.pth' # backbone model's weight 12 | RESUME : '' 13 | 14 | DATASET: 15 | NAME : PST # dataset name to be trained with (camvid, cityscapes, ade20k) 16 | ROOT : 'PATH/TO/DATASET/ROOT' # dataset root path 17 | IGNORE_LABEL : 255 18 | # MODALS : ['img'] 19 | MODALS : ['img', 'thermal'] 20 | 21 | TRAIN: 22 | IMAGE_SIZE : [1280, 720] # training image size in (h, w) === Fixed in dataloader, following MCubeSNet 23 | BATCH_SIZE : 2 # batch size used to train 24 | EPOCHS : 200 # number of epochs to train 25 | EVAL_START : 0 # evaluation interval during training 26 | EVAL_INTERVAL : 1 # evaluation interval during training 27 | AMP : true # use AMP in training 28 | DDP : false # use DDP training 29 | 30 | LOSS: 31 | NAME : OhemCrossEntropy # loss function name 32 | CLS_WEIGHTS : false # use class weights in loss calculation 33 | 34 | OPTIMIZER: 35 | NAME : adamw # optimizer name 36 | LR : 0.00006 # initial learning rate used in optimizer 37 | WEIGHT_DECAY : 0.01 # decay rate used in optimizer 38 | 39 | SCHEDULER: 40 | NAME : warmuppolylr # scheduler name 41 | POWER : 0.9 # scheduler power 42 | WARMUP : 10 # warmup epochs used in scheduler 43 | WARMUP_RATIO : 0.1 # warmup ratio 44 | 45 | 46 | EVAL: 47 | MODEL_PATH : 'PATH/TO/MODEL/WEIGHT' # Path to your saved model 48 | IMAGE_SIZE : [1280, 720] # evaluation image size in (h, w) 49 | BATCH_SIZE : 2 # batch size 50 | VIS_SAVE_DIR : 'PATH/TO/TARGET/DIRECTORY' # Where to save visualization 51 | MSF: 52 | ENABLE : false # multi-scale and flip evaluation 53 | FLIP : true # use flip in evaluation 54 | SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation 55 | -------------------------------------------------------------------------------- /configs/fmb_rgbt.yaml: -------------------------------------------------------------------------------- 1 | DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...) 
2 | SAVE_DIR : 'output/MMSFormer' # output folder name used for saving the model, logs and inference results 3 | GPUs : 2 4 | GPU_IDs : [0, 1] 5 | USE_WANDB : False 6 | WANDB_NAME : 'MMSF-FMB-RGBT' # name for the run 7 | 8 | MODEL: 9 | NAME : MMSFormer # name of the model you are using 10 | BACKBONE : MMSFormer-B3 # model variant 11 | PRETRAINED : 'checkpoints/pretrained/segformer/mit_b3.pth' # backbone model's weight 12 | RESUME : '' # checkpoint file 13 | 14 | DATASET: 15 | NAME : FMB # dataset name to be trained with (camvid, cityscapes, ade20k) 16 | ROOT : 'PATH/TO/DATASET/ROOT' # dataset root path 17 | IGNORE_LABEL : 255 18 | # MODALS : ['img'] 19 | MODALS : ['img', 'thermal'] 20 | 21 | TRAIN: 22 | IMAGE_SIZE : [800, 600] # training image size in (h, w) === Fixed in dataloader, following MCubeSNet 23 | BATCH_SIZE : 2 # batch size used to train 24 | EPOCHS : 120 # number of epochs to train 25 | EVAL_START : 0 # evaluation interval during training 26 | EVAL_INTERVAL : 1 # evaluation interval during training 27 | AMP : true # use AMP in training 28 | DDP : false # use DDP training 29 | 30 | LOSS: 31 | NAME : OhemCrossEntropy # loss function name 32 | CLS_WEIGHTS : false # use class weights in loss calculation 33 | 34 | OPTIMIZER: 35 | NAME : adamw # optimizer name 36 | LR : 0.0001 # initial learning rate used in optimizer 37 | WEIGHT_DECAY : 0.01 # decay rate used in optimizer 38 | 39 | SCHEDULER: 40 | NAME : warmuppolylr # scheduler name 41 | POWER : 0.9 # scheduler power 42 | WARMUP : 10 # warmup epochs used in scheduler 43 | WARMUP_RATIO : 0.1 # warmup ratio 44 | 45 | 46 | EVAL: 47 | MODEL_PATH : 'PATH/TO/MODEL/WEIGHT' # Path to your saved model 48 | IMAGE_SIZE : [800, 600] # evaluation image size in (h, w) 49 | BATCH_SIZE : 2 # batch size 50 | VIS_SAVE_DIR : 'PATH/TO/TARGET/DIRECTORY' # Where to save visualization 51 | MSF: 52 | ENABLE : false # multi-scale and flip evaluation 53 | FLIP : true # use flip in evaluation 54 | SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.2.0 2 | addict==2.4.0 3 | appdirs==1.4.4 4 | argon2-cffi==21.3.0 5 | argon2-cffi-bindings==21.2.0 6 | asttokens==2.0.5 7 | attrs==21.4.0 8 | backcall==0.2.0 9 | bleach==4.1.0 10 | cachetools==5.0.0 11 | certifi==2021.10.8 12 | cffi 13 | charset-normalizer==2.1.1 14 | click==8.1.6 15 | cycler==0.11.0 16 | dataclasses==0.6 17 | debugpy==1.5.1 18 | decorator==5.1.1 19 | defusedxml==0.7.1 20 | descartes==1.1.0 21 | docker-pycreds==0.4.0 22 | easydict==1.9 23 | einops==0.4.1 24 | entrypoints==0.4 25 | executing==0.8.3 26 | fire==0.4.0 27 | future 28 | fvcore==0.1.5.post20220512 29 | gitdb==4.0.10 30 | GitPython==3.1.41 31 | google-auth==2.11.0 32 | google-auth-oauthlib==0.4.6 33 | grpcio==1.48.1 34 | idna==3.3 35 | importlib-metadata==4.12.0 36 | importlib-resources==5.4.0 37 | iopath==0.1.10 38 | ipykernel==6.9.1 39 | ipython==8.1.0 40 | ipython-genutils==0.2.0 41 | ipywidgets==7.6.5 42 | jedi==0.18.1 43 | Jinja2==3.0.3 44 | joblib==1.1.0 45 | jsonschema==4.4.0 46 | jupyter==1.0.0 47 | jupyter-client==7.1.2 48 | jupyter-console==6.4.0 49 | jupyter-core==4.9.2 50 | jupyterlab-pygments==0.1.2 51 | jupyterlab-widgets==1.0.2 52 | kiwisolver==1.3.2 53 | Markdown==3.4.1 54 | MarkupSafe==2.1.1 55 | matplotlib==3.4.3 56 | matplotlib-inline==0.1.3 57 | mistune==0.8.4 58 | mkl-fft==1.3.0 59 | mkl-random 60 | 
mkl-service==2.4.0 61 | mmcv==0.6.2 62 | nbclient==0.5.11 63 | nbconvert==6.4.2 64 | nbformat==5.1.3 65 | nest-asyncio==1.5.4 66 | notebook==6.4.8 67 | numpy 68 | nuscenes-devkit==1.1.9 69 | oauthlib==3.2.1 70 | olefile 71 | opencv-python==4.5.3.56 72 | packaging==21.3 73 | pandocfilters==1.5.0 74 | parso==0.8.3 75 | pathtools==0.1.2 76 | pexpect==4.8.0 77 | pickleshare==0.7.5 78 | Pillow 79 | plyfile==0.7.4 80 | portalocker==2.5.1 81 | prometheus-client==0.13.1 82 | prompt-toolkit==3.0.28 83 | protobuf==3.18.1 84 | psutil==5.9.5 85 | ptyprocess==0.7.0 86 | pure-eval==0.2.2 87 | pyasn1==0.4.8 88 | pyasn1-modules==0.2.8 89 | pycocotools==2.0.4 90 | pycparser 91 | Pygments==2.11.2 92 | pyparsing==3.0.6 93 | pyquaternion==0.9.9 94 | pyrsistent==0.18.1 95 | python-dateutil==2.8.2 96 | PyYAML==6.0 97 | pyzmq==22.3.0 98 | qtconsole==5.2.2 99 | QtPy==2.0.1 100 | requests==2.28.1 101 | requests-oauthlib==1.3.1 102 | rsa==4.9 103 | scikit-learn==1.0.2 104 | scipy==1.7.1 105 | Send2Trash==1.8.0 106 | sentry-sdk==1.28.1 107 | setproctitle==1.3.2 108 | Shapely==1.8.1.post1 109 | six 110 | smmap==5.0.0 111 | stack-data==0.2.0 112 | tabulate==0.8.10 113 | tensorboard==2.10.0 114 | tensorboard-data-server==0.6.1 115 | tensorboard-plugin-wit==1.8.1 116 | tensorboardX==2.4 117 | termcolor==1.1.0 118 | terminado==0.13.1 119 | testpath==0.6.0 120 | threadpoolctl==3.1.0 121 | timm==0.4.12 122 | torch 123 | torchaudio==0.9.0a0+33b2469 124 | torchvision 125 | tornado==6.1 126 | tqdm==4.62.3 127 | traitlets==5.1.1 128 | typing-extensions 129 | urllib3==1.26.12 130 | wandb==0.15.7 131 | wcwidth==0.2.5 132 | webencodings==0.5.1 133 | Werkzeug==2.2.2 134 | widgetsnbextension==3.5.2 135 | yacs==0.1.8 136 | yapf==0.32.0 137 | zipp==3.7.0 138 | -------------------------------------------------------------------------------- /configs/mcubes_rgbadn.yaml: -------------------------------------------------------------------------------- 1 | DEVICE : cuda # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...) 
2 | SAVE_DIR : 'output/MMSFormer' # output folder name used for saving the model, logs and inference results 3 | GPUs : 2 4 | GPU_IDs : [0, 1] 5 | USE_WANDB : False # Whether you want to use wandb 6 | WANDB_NAME : 'MMSF-MCubeS-RGBNAD' # name for the run 7 | 8 | MODEL: 9 | NAME : MMSFormer # name of the model you are using 10 | BACKBONE : MMSFormer-B4 # model variant 11 | PRETRAINED : 'checkpoints/pretrained/segformer/mit_b4.pth' # backbone model's weight 12 | RESUME : '' # checkpoint file 13 | 14 | DATASET: 15 | NAME : MCubeS # dataset name to be trained with (camvid, cityscapes, ade20k) 16 | ROOT : 'PATH/TO/DATASET/ROOT' # dataset root path 17 | IGNORE_LABEL : 255 18 | # MODALS : ['image'] 19 | # MODALS : ['image', 'nir'] 20 | # MODALS : ['image', 'aolp'] 21 | # MODALS : ['image', 'dolp'] 22 | # MODALS : ['image', 'aolp', 'nir'] 23 | # MODALS : ['image', 'aolp', 'dolp'] 24 | MODALS : ['image', 'nir', 'aolp', 'dolp'] 25 | 26 | TRAIN: 27 | IMAGE_SIZE : [512, 512] # training image size in (h, w) === Fixed in dataloader, following MCubeSNet 28 | BATCH_SIZE : 4 # batch size used to train 29 | EPOCHS : 500 # number of epochs to train 30 | EVAL_START : 0 # evaluation interval during training 31 | EVAL_INTERVAL : 1 # evaluation interval during training 32 | AMP : true # use AMP in training 33 | DDP : false # use DDP training 34 | 35 | LOSS: 36 | NAME : OhemCrossEntropy # loss function name 37 | CLS_WEIGHTS : false # use class weights in loss calculation 38 | 39 | OPTIMIZER: 40 | NAME : adamw # optimizer name 41 | LR : 0.00006 # initial learning rate used in optimizer 42 | WEIGHT_DECAY : 0.01 # decay rate used in optimizer 43 | 44 | SCHEDULER: 45 | NAME : warmuppolylr # scheduler name 46 | POWER : 0.9 # scheduler power 47 | WARMUP : 10 # warmup epochs used in scheduler 48 | WARMUP_RATIO : 0.1 # warmup ratio 49 | 50 | 51 | EVAL: 52 | MODEL_PATH : 'PATH/TO/MODEL/WEIGHT' # Path to your saved model 53 | IMAGE_SIZE : [1024, 1024] # evaluation image size in (h, w) 54 | BATCH_SIZE : 2 # batch size 55 | VIS_SAVE_DIR : 'PATH/TO/TARGET/DIRECTORY' # Where to save visualization 56 | MSF: 57 | ENABLE : false # multi-scale and flip evaluation 58 | FLIP : true # use flip in evaluation 59 | SCALES : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] # scales used in MSF evaluation 60 | -------------------------------------------------------------------------------- /semseg/models/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch import nn 4 | from semseg.models.backbones import * 5 | from semseg.models.layers import trunc_normal_ 6 | from collections import OrderedDict 7 | 8 | def load_dualpath_model(model, model_file): 9 | # load raw state_dict 10 | if isinstance(model_file, str): 11 | raw_state_dict = torch.load(model_file, map_location=torch.device('cpu')) 12 | #raw_state_dict = torch.load(model_file) 13 | if 'model' in raw_state_dict.keys(): 14 | raw_state_dict = raw_state_dict['model'] 15 | else: 16 | raw_state_dict = model_file 17 | 18 | state_dict = {} 19 | for k, v in raw_state_dict.items(): 20 | if k.find('patch_embed') >= 0: 21 | state_dict[k] = v 22 | # patch_embedx, proj, weight = k.split('.') 23 | # state_dict[k.replace('patch_embed', 'extra_patch_embed')] = v 24 | # state_dict[new_k] = v 25 | elif k.find('block') >= 0: 26 | state_dict[k] = v 27 | # state_dict[k.replace('block', 'extra_block')] = v 28 | elif k.find('norm') >= 0: 29 | state_dict[k] = v 30 | # state_dict[k.replace('norm', 'extra_norm')] = v 31 | 32 | msg = 
model.load_state_dict(state_dict, strict=False) 33 | print(msg) 34 | del state_dict 35 | 36 | 37 | class BaseModel(nn.Module): 38 | def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19, modals: list = ['rgb', 'depth', 'event', 'lidar']) -> None: 39 | super().__init__() 40 | backbone, variant = backbone.split('-') 41 | self.backbone = eval(backbone)(variant, modals) 42 | # self.backbone = eval(backbone)(variant) 43 | self.modals = modals 44 | 45 | def _init_weights(self, m: nn.Module) -> None: 46 | if isinstance(m, nn.Linear): 47 | trunc_normal_(m.weight, std=.02) 48 | if m.bias is not None: 49 | nn.init.zeros_(m.bias) 50 | elif isinstance(m, nn.Conv2d): 51 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 52 | fan_out //= m.groups 53 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 54 | if m.bias is not None: 55 | nn.init.zeros_(m.bias) 56 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): 57 | nn.init.ones_(m.weight) 58 | nn.init.zeros_(m.bias) 59 | 60 | def init_pretrained(self, pretrained: str = None) -> None: 61 | if pretrained: 62 | if len(self.modals)>1: 63 | load_dualpath_model(self.backbone, pretrained) 64 | else: 65 | checkpoint = torch.load(pretrained, map_location='cpu') 66 | if 'state_dict' in checkpoint.keys(): 67 | checkpoint = checkpoint['state_dict'] 68 | # if 'PoolFormer' in self.__class__.__name__: 69 | # new_dict = OrderedDict() 70 | # for k, v in checkpoint.items(): 71 | # if not 'backbone.' in k: 72 | # new_dict['backbone.'+k] = v 73 | # else: 74 | # new_dict[k] = v 75 | # checkpoint = new_dict 76 | if 'model' in checkpoint.keys(): # --- for HorNet 77 | checkpoint = checkpoint['model'] 78 | msg = self.backbone.load_state_dict(checkpoint, strict=False) 79 | print(msg) 80 | -------------------------------------------------------------------------------- /semseg/datasets/pst.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch import Tensor 5 | from torch.utils.data import Dataset 6 | import torchvision.transforms.functional as TF 7 | from torchvision import io 8 | from pathlib import Path 9 | from typing import Tuple 10 | import glob 11 | import einops 12 | from torch.utils.data import DataLoader 13 | from torch.utils.data import DistributedSampler, RandomSampler 14 | from semseg.augmentations_mm import get_train_augmentation 15 | 16 | class PST(Dataset): 17 | """ 18 | num_classes: 5 19 | """ 20 | CLASSES = ["Background", "Fire-Extinguisher", "Backpack", "Hand-Drill", "Survivor"] 21 | 22 | PALETTE = torch.tensor([[0, 0, 0], 23 | [100, 40, 40], 24 | [55, 90, 80], 25 | [220, 20, 60], 26 | [153, 153, 153]]) 27 | 28 | def __init__(self, root: str = 'data/PST900', split: str = 'train', transform = None, modals = ['img', 'thermal'], case = None) -> None: 29 | super().__init__() 30 | assert split in ['train', 'val'] 31 | self.transform = transform 32 | self.n_classes = len(self.CLASSES) 33 | self.ignore_label = 255 34 | self.modals = modals 35 | if split == 'val': 36 | split = 'test' 37 | self.files = sorted(glob.glob(os.path.join(*[root, split, 'rgb', '*.png']))) 38 | # --- debug 39 | # self.files = sorted(glob.glob(os.path.join(*[root, 'img', '*', split, '*', '*.png'])))[:100] 40 | # --- split as case 41 | if not self.files: 42 | raise Exception(f"No images found in {os.path.join(root, split, 'rgb')}") 43 | print(f"Found {len(self.files)} {split} images.") 44 | 45 | def __len__(self) -> int: 46 | return len(self.files) 47 | 48 | def __getitem__(self, index: int) -> 
Tuple[Tensor, Tensor]: 49 | item_name = self.files[index].split("/")[-1].split(".")[0] 50 | rgb = str(self.files[index]) 51 | thermal = rgb.replace('/rgb', '/thermal') 52 | lbl_path = rgb.replace('/rgb', '/labels') 53 | 54 | sample = {} 55 | sample['img'] = io.read_image(rgb)[:3, ...] 56 | H, W = sample['img'].shape[1:] 57 | if 'thermal' in self.modals: 58 | sample['thermal'] = self._open_img(thermal) 59 | label = io.read_image(lbl_path)[0,...].unsqueeze(0) 60 | # label[label==255] = 0 61 | # label -= 1 62 | sample['mask'] = label 63 | 64 | if self.transform: 65 | sample = self.transform(sample) 66 | label = sample['mask'] 67 | del sample['mask'] 68 | label = self.encode(label.squeeze().numpy()).long() 69 | sample = [sample[k] for k in self.modals] 70 | # return sample, label, item_name 71 | return sample, label 72 | 73 | def _open_img(self, file): 74 | img = io.read_image(file) 75 | C, H, W = img.shape 76 | if C == 4: 77 | img = img[:3, ...] 78 | if C == 1: 79 | img = img.repeat(3, 1, 1) 80 | return img 81 | 82 | def encode(self, label: Tensor) -> Tensor: 83 | return torch.from_numpy(label) 84 | 85 | 86 | if __name__ == '__main__': 87 | cases = ['cloud', 'fog', 'night', 'rain', 'sun', 'motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres'] 88 | traintransform = get_train_augmentation((1024, 1024), seg_fill=255) 89 | for case in cases: 90 | 91 | trainset = DELIVER(transform=traintransform, split='val', case=case) 92 | trainloader = DataLoader(trainset, batch_size=2, num_workers=2, drop_last=False, pin_memory=False) 93 | 94 | for i, (sample, lbl) in enumerate(trainloader): 95 | print(torch.unique(lbl)) -------------------------------------------------------------------------------- /semseg/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | 5 | 6 | class CrossEntropy(nn.Module): 7 | def __init__(self, ignore_label: int = 255, weight: Tensor = None, aux_weights: list = [1, 0.4, 0.4]) -> None: 8 | super().__init__() 9 | self.aux_weights = aux_weights 10 | self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label) 11 | 12 | def _forward(self, preds: Tensor, labels: Tensor) -> Tensor: 13 | # preds in shape [B, C, H, W] and labels in shape [B, H, W] 14 | return self.criterion(preds, labels) 15 | 16 | def forward(self, preds, labels: Tensor) -> Tensor: 17 | if isinstance(preds, tuple): 18 | return sum([w * self._forward(pred, labels) for (pred, w) in zip(preds, self.aux_weights)]) 19 | return self._forward(preds, labels) 20 | 21 | 22 | class OhemCrossEntropy(nn.Module): 23 | def __init__(self, ignore_label: int = 255, weight: Tensor = None, thresh: float = 0.7, aux_weights: list = [1, 1]) -> None: 24 | super().__init__() 25 | self.ignore_label = ignore_label 26 | self.aux_weights = aux_weights 27 | self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)) 28 | self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label, reduction='none') 29 | 30 | def _forward(self, preds: Tensor, labels: Tensor) -> Tensor: 31 | # preds in shape [B, C, H, W] and labels in shape [B, H, W] 32 | n_min = labels[labels != self.ignore_label].numel() // 16 33 | loss = self.criterion(preds, labels).view(-1) 34 | loss_hard = loss[loss > self.thresh] 35 | 36 | if loss_hard.numel() < n_min: 37 | loss_hard, _ = loss.topk(n_min) 38 | 39 | return torch.mean(loss_hard) 40 | 41 | def forward(self, preds, labels: Tensor) -> Tensor: 
42 | if isinstance(preds, tuple): 43 | return sum([w * self._forward(pred, labels) for (pred, w) in zip(preds, self.aux_weights)]) 44 | return self._forward(preds, labels) 45 | 46 | 47 | class Dice(nn.Module): 48 | def __init__(self, delta: float = 0.5, aux_weights: list = [1, 0.4, 0.4]): 49 | """ 50 | delta: Controls weight given to FP and FN. This equals to dice score when delta=0.5 51 | """ 52 | super().__init__() 53 | self.delta = delta 54 | self.aux_weights = aux_weights 55 | 56 | def _forward(self, preds: Tensor, labels: Tensor) -> Tensor: 57 | # preds in shape [B, C, H, W] and labels in shape [B, H, W] 58 | num_classes = preds.shape[1] 59 | labels = F.one_hot(labels, num_classes).permute(0, 3, 1, 2) 60 | tp = torch.sum(labels*preds, dim=(2, 3)) 61 | fn = torch.sum(labels*(1-preds), dim=(2, 3)) 62 | fp = torch.sum((1-labels)*preds, dim=(2, 3)) 63 | 64 | dice_score = (tp + 1e-6) / (tp + self.delta * fn + (1 - self.delta) * fp + 1e-6) 65 | dice_score = torch.sum(1 - dice_score, dim=-1) 66 | 67 | dice_score = dice_score / num_classes 68 | return dice_score.mean() 69 | 70 | def forward(self, preds, targets: Tensor) -> Tensor: 71 | if isinstance(preds, tuple): 72 | return sum([w * self._forward(pred, targets) for (pred, w) in zip(preds, self.aux_weights)]) 73 | return self._forward(preds, targets) 74 | 75 | 76 | __all__ = ['CrossEntropy', 'OhemCrossEntropy', 'Dice'] 77 | 78 | 79 | def get_loss(loss_fn_name: str = 'CrossEntropy', ignore_label: int = 255, cls_weights: Tensor = None): 80 | assert loss_fn_name in __all__, f"Unavailable loss function name >> {loss_fn_name}.\nAvailable loss functions: {__all__}" 81 | if loss_fn_name == 'Dice': 82 | return Dice() 83 | return eval(loss_fn_name)(ignore_label, cls_weights) 84 | 85 | 86 | if __name__ == '__main__': 87 | pred = torch.randint(0, 19, (2, 19, 480, 640), dtype=torch.float) 88 | label = torch.randint(0, 19, (2, 480, 640), dtype=torch.long) 89 | loss_fn = Dice() 90 | y = loss_fn(pred, label) 91 | print(y) -------------------------------------------------------------------------------- /semseg/datasets/fmb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch import Tensor 5 | from torch.utils.data import Dataset 6 | import torchvision.transforms.functional as TF 7 | from torchvision import io 8 | from pathlib import Path 9 | from typing import Tuple 10 | import glob 11 | import einops 12 | from torch.utils.data import DataLoader 13 | from torch.utils.data import DistributedSampler, RandomSampler 14 | from semseg.augmentations_mm import get_train_augmentation 15 | 16 | class FMB(Dataset): 17 | """ 18 | num_classes: 14 19 | """ 20 | CLASSES = ["Road", "Sidewalk", "Building", "Traffic Light", "Traffic Sign", "Vegetation", "Sky", "Person", "Car", "Truck", "Bus", "Motorcycle", "Bicycle", "Pole"] 21 | 22 | PALETTE = torch.tensor([[70, 70, 70], 23 | [100, 40, 40], 24 | [55, 90, 80], 25 | [220, 20, 60], 26 | [153, 153, 153], 27 | [157, 234, 50], 28 | [128, 64, 128], 29 | [244, 35, 232], 30 | [107, 142, 35], 31 | [0, 0, 142], 32 | [102, 102, 156], 33 | [220, 220, 0], 34 | [70, 130, 180], 35 | [81, 0, 81], 36 | [150, 100, 100], 37 | ]) 38 | 39 | def __init__(self, root: str = 'data/FMB', split: str = 'train', transform = None, modals = ['img', 'thermal'], case = None) -> None: 40 | super().__init__() 41 | assert split in ['train', 'val'] 42 | self.transform = transform 43 | self.n_classes = len(self.CLASSES) 44 | self.ignore_label = 255 45 | 
self.modals = modals 46 | if split == 'val': 47 | split = 'test' 48 | self.files = sorted(glob.glob(os.path.join(*[root, split, 'Visible', '*.png']))) 49 | # --- debug 50 | # self.files = sorted(glob.glob(os.path.join(*[root, 'img', '*', split, '*', '*.png'])))[:100] 51 | # --- split as case 52 | if not self.files: 53 | raise Exception(f"No images found in {os.path.join(root, split, 'Visible')}") 54 | print(f"Found {len(self.files)} {split} images.") 55 | 56 | def __len__(self) -> int: 57 | return len(self.files) 58 | 59 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 60 | item_name = self.files[index].split("/")[-1].split(".")[0] 61 | rgb = str(self.files[index]) 62 | thermal = rgb.replace('/Visible', '/Infrared') 63 | lbl_path = rgb.replace('/Visible', '/Label') 64 | 65 | sample = {} 66 | sample['img'] = io.read_image(rgb)[:3, ...] 67 | H, W = sample['img'].shape[1:] 68 | if 'thermal' in self.modals: 69 | sample['thermal'] = self._open_img(thermal) 70 | label = io.read_image(lbl_path)[0,...].unsqueeze(0) 71 | label[label==255] = 0 72 | label -= 1 73 | sample['mask'] = label 74 | 75 | if self.transform: 76 | sample = self.transform(sample) 77 | label = sample['mask'] 78 | del sample['mask'] 79 | label = self.encode(label.squeeze().numpy()).long() 80 | sample = [sample[k] for k in self.modals] 81 | # return sample, label, item_name 82 | return sample, label 83 | 84 | def _open_img(self, file): 85 | img = io.read_image(file) 86 | C, H, W = img.shape 87 | if C == 4: 88 | img = img[:3, ...] 89 | if C == 1: 90 | img = img.repeat(3, 1, 1) 91 | return img 92 | 93 | def encode(self, label: Tensor) -> Tensor: 94 | return torch.from_numpy(label) 95 | 96 | 97 | if __name__ == '__main__': 98 | cases = ['cloud', 'fog', 'night', 'rain', 'sun', 'motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres'] 99 | traintransform = get_train_augmentation((1024, 1024), seg_fill=255) 100 | for case in cases: 101 | 102 | trainset = FMB(transform=traintransform, split='val', case=case) 103 | trainloader = DataLoader(trainset, batch_size=2, num_workers=2, drop_last=False, pin_memory=False) 104 | 105 | for i, (sample, lbl) in enumerate(trainloader): 106 | print(torch.unique(lbl)) -------------------------------------------------------------------------------- /semseg/schedulers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class PolyLR(_LRScheduler): 7 | def __init__(self, optimizer, max_iter, decay_iter=1, power=0.9, last_epoch=-1) -> None: 8 | self.decay_iter = decay_iter 9 | self.max_iter = max_iter 10 | self.power = power 11 | super().__init__(optimizer, last_epoch=last_epoch) 12 | 13 | def get_lr(self): 14 | if self.last_epoch % self.decay_iter or self.last_epoch % self.max_iter: 15 | return self.base_lrs 16 | else: 17 | factor = (1 - self.last_epoch / float(self.max_iter)) ** self.power 18 | return [factor*lr for lr in self.base_lrs] 19 | 20 | 21 | class WarmupLR(_LRScheduler): 22 | def __init__(self, optimizer, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None: 23 | self.warmup_iter = warmup_iter 24 | self.warmup_ratio = warmup_ratio 25 | self.warmup = warmup 26 | super().__init__(optimizer, last_epoch) 27 | 28 | def get_lr(self): 29 | ratio = self.get_lr_ratio() 30 | return [ratio * lr for lr in self.base_lrs] 31 | 32 | def get_lr_ratio(self): 33 | return self.get_warmup_ratio() if self.last_epoch < self.warmup_iter else self.get_main_ratio() 
34 | 35 | def get_main_ratio(self): 36 | raise NotImplementedError 37 | 38 | def get_warmup_ratio(self): 39 | assert self.warmup in ['linear', 'exp'] 40 | alpha = self.last_epoch / self.warmup_iter 41 | 42 | return self.warmup_ratio + (1. - self.warmup_ratio) * alpha if self.warmup == 'linear' else self.warmup_ratio ** (1. - alpha) 43 | 44 | 45 | class WarmupPolyLR(WarmupLR): 46 | def __init__(self, optimizer, power, max_iter, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None: 47 | self.power = power 48 | self.max_iter = max_iter 49 | super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 50 | 51 | def get_main_ratio(self): 52 | real_iter = self.last_epoch - self.warmup_iter 53 | real_max_iter = self.max_iter - self.warmup_iter 54 | alpha = real_iter / real_max_iter 55 | 56 | return (1 - alpha) ** self.power 57 | 58 | 59 | class WarmupExpLR(WarmupLR): 60 | def __init__(self, optimizer, gamma, interval=1, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None: 61 | self.gamma = gamma 62 | self.interval = interval 63 | super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 64 | 65 | def get_main_ratio(self): 66 | real_iter = self.last_epoch - self.warmup_iter 67 | return self.gamma ** (real_iter // self.interval) 68 | 69 | 70 | class WarmupCosineLR(WarmupLR): 71 | def __init__(self, optimizer, max_iter, eta_ratio=0, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None: 72 | self.eta_ratio = eta_ratio 73 | self.max_iter = max_iter 74 | super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch) 75 | 76 | def get_main_ratio(self): 77 | real_iter = self.last_epoch - self.warmup_iter 78 | real_max_iter = self.max_iter - self.warmup_iter 79 | 80 | return self.eta_ratio + (1 - self.eta_ratio) * (1 + math.cos(math.pi * self.last_epoch / real_max_iter)) / 2 81 | 82 | 83 | 84 | __all__ = ['polylr', 'warmuppolylr', 'warmupcosinelr', 'warmupsteplr'] 85 | 86 | 87 | def get_scheduler(scheduler_name: str, optimizer, max_iter: int, power: int, warmup_iter: int, warmup_ratio: float): 88 | assert scheduler_name in __all__, f"Unavailable scheduler name >> {scheduler_name}.\nAvailable schedulers: {__all__}" 89 | if scheduler_name == 'warmuppolylr': 90 | return WarmupPolyLR(optimizer, power, max_iter, warmup_iter, warmup_ratio, warmup='linear') 91 | elif scheduler_name == 'warmupcosinelr': 92 | return WarmupCosineLR(optimizer, max_iter, warmup_iter=warmup_iter, warmup_ratio=warmup_ratio) 93 | return PolyLR(optimizer, max_iter) 94 | 95 | 96 | if __name__ == '__main__': 97 | model = torch.nn.Conv2d(3, 16, 3, 1, 1) 98 | optim = torch.optim.SGD(model.parameters(), lr=1e-3) 99 | 100 | max_iter = 20000 101 | sched = WarmupPolyLR(optim, power=0.9, max_iter=max_iter, warmup_iter=200, warmup_ratio=0.1, warmup='exp', last_epoch=-1) 102 | 103 | lrs = [] 104 | 105 | for _ in range(max_iter): 106 | lr = sched.get_lr()[0] 107 | lrs.append(lr) 108 | optim.step() 109 | sched.step() 110 | 111 | import matplotlib.pyplot as plt 112 | import numpy as np 113 | 114 | plt.plot(np.arange(len(lrs)), np.array(lrs)) 115 | plt.grid() 116 | plt.show() -------------------------------------------------------------------------------- /semseg/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | from torch.utils.data import DataLoader 7 | from torchvision import 
transforms as T 8 | from torchvision.utils import make_grid 9 | from semseg.augmentations import Compose, Normalize, RandomResizedCrop 10 | from PIL import Image, ImageDraw, ImageFont 11 | 12 | 13 | def visualize_dataset_sample(dataset, root, split='val', batch_size=4): 14 | transform = Compose([ 15 | RandomResizedCrop((512, 512), scale=(1.0, 1.0)), 16 | Normalize() 17 | ]) 18 | 19 | dataset = dataset(root, split=split, transform=transform) 20 | dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size) 21 | image, label = next(iter(dataloader)) 22 | 23 | print(f"Image Shape\t: {image.shape}") 24 | print(f"Label Shape\t: {label.shape}") 25 | print(f"Classes\t\t: {label.unique().tolist()}") 26 | 27 | label[label == -1] = 0 28 | label[label == 255] = 0 29 | labels = [dataset.PALETTE[lbl.to(int)].permute(2, 0, 1) for lbl in label] 30 | labels = torch.stack(labels) 31 | 32 | inv_normalize = T.Normalize( 33 | mean=(-0.485/0.229, -0.456/0.224, -0.406/0.225), 34 | std=(1/0.229, 1/0.224, 1/0.225) 35 | ) 36 | image = inv_normalize(image) 37 | image *= 255 38 | images = torch.vstack([image, labels]) 39 | 40 | plt.imshow(make_grid(images, nrow=4).to(torch.uint8).numpy().transpose((1, 2, 0))) 41 | plt.show() 42 | 43 | 44 | colors = [ 45 | [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 46 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 47 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 48 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 49 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 50 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 51 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 52 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], 53 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], 54 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], 55 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], 56 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], 57 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], 58 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], 59 | [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], 60 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], 61 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], 62 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], 63 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 
0], [92, 0, 255] 64 | ] 65 | 66 | 67 | def generate_palette(num_classes, background: bool = False): 68 | random.shuffle(colors) 69 | if background: 70 | palette = [[0, 0, 0]] 71 | palette += colors[:num_classes-1] 72 | else: 73 | palette = colors[:num_classes] 74 | return np.array(palette) 75 | 76 | 77 | def draw_text(image: torch.Tensor, seg_map: torch.Tensor, labels: list, fontsize: int = 15): 78 | image = image.to(torch.uint8) 79 | font = ImageFont.truetype("Helvetica.ttf", fontsize) 80 | pil_image = Image.fromarray(image.numpy()) 81 | draw = ImageDraw.Draw(pil_image) 82 | 83 | indices = seg_map.unique().tolist() 84 | classes = [labels[index] for index in indices] 85 | 86 | for idx, cls in zip(indices, classes): 87 | mask = seg_map == idx 88 | mask = mask.squeeze().numpy() 89 | center = np.median((mask == 1).nonzero(), axis=1)[::-1] 90 | bbox = draw.textbbox(center, cls, font=font) 91 | bbox = (bbox[0]-3, bbox[1]-3, bbox[2]+3, bbox[3]+3) 92 | draw.rectangle(bbox, fill=(255, 255, 255), width=1) 93 | draw.text(center, cls, fill=(0, 0, 0), font=font) 94 | return pil_image -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Repo-specific GitIgnore ---------------------------------------------------------------------------------------------- 2 | *.jpg 3 | *.jpeg 4 | *.png 5 | *.bmp 6 | *.tif 7 | *.tiff 8 | *.heic 9 | *.JPG 10 | *.JPEG 11 | *.PNG 12 | *.BMP 13 | *.TIF 14 | *.TIFF 15 | *.HEIC 16 | *.mp4 17 | *.mov 18 | *.MOV 19 | *.avi 20 | *.data 21 | *.json 22 | *.pth 23 | *.cfg 24 | !cfg/yolov3*.cfg 25 | 26 | storage.googleapis.com 27 | runs/* 28 | data/* 29 | !data/images/zidane.jpg 30 | !data/images/bus.jpg 31 | !data/coco.names 32 | !data/coco_paper.names 33 | !data/coco.data 34 | !data/coco_*.data 35 | !data/coco_*.txt 36 | !data/trainvalno5k.shapes 37 | !data/*.sh 38 | 39 | test.py 40 | test_imgs/ 41 | 42 | pycocotools/* 43 | results*.txt 44 | gcp_test*.sh 45 | 46 | checkpoints/ 47 | # output/ 48 | # output*/ 49 | *events* 50 | assests/*/ 51 | 52 | # Datasets ------------------------------------------------------------------------------------------------------------- 53 | coco/ 54 | coco128/ 55 | VOC/ 56 | 57 | # MATLAB GitIgnore ----------------------------------------------------------------------------------------------------- 58 | *.m~ 59 | *.mat 60 | !targets*.mat 61 | 62 | # Neural Network weights ----------------------------------------------------------------------------------------------- 63 | *.weights 64 | *.pt 65 | *.onnx 66 | *.mlmodel 67 | *.torchscript 68 | darknet53.conv.74 69 | yolov3-tiny.conv.15 70 | 71 | # GitHub Python GitIgnore ---------------------------------------------------------------------------------------------- 72 | # Byte-compiled / optimized / DLL files 73 | __pycache__/ 74 | *.py[cod] 75 | *$py.class 76 | 77 | # C extensions 78 | *.so 79 | 80 | # Distribution / packaging 81 | .Python 82 | env/ 83 | build/ 84 | develop-eggs/ 85 | dist/ 86 | downloads/ 87 | eggs/ 88 | .eggs/ 89 | lib/ 90 | lib64/ 91 | parts/ 92 | sdist/ 93 | var/ 94 | wheels/ 95 | *.egg-info/ 96 | wandb/ 97 | .installed.cfg 98 | *.egg 99 | 100 | 101 | # PyInstaller 102 | # Usually these files are written by a python script from a template 103 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
104 | *.manifest 105 | *.spec 106 | 107 | # Installer logs 108 | pip-log.txt 109 | pip-delete-this-directory.txt 110 | 111 | # Unit test / coverage reports 112 | htmlcov/ 113 | .tox/ 114 | .coverage 115 | .coverage.* 116 | .cache 117 | nosetests.xml 118 | coverage.xml 119 | *.cover 120 | .hypothesis/ 121 | 122 | # Translations 123 | *.mo 124 | *.pot 125 | 126 | # Django stuff: 127 | # *.log 128 | local_settings.py 129 | 130 | # Flask stuff: 131 | instance/ 132 | .webassets-cache 133 | 134 | # Scrapy stuff: 135 | .scrapy 136 | 137 | # Sphinx documentation 138 | docs/_build/ 139 | 140 | # PyBuilder 141 | target/ 142 | 143 | # Jupyter Notebook 144 | .ipynb_checkpoints 145 | 146 | # pyenv 147 | .python-version 148 | 149 | # celery beat schedule file 150 | celerybeat-schedule 151 | 152 | # SageMath parsed files 153 | *.sage.py 154 | 155 | # dotenv 156 | .env 157 | 158 | # virtualenv 159 | .venv* 160 | venv*/ 161 | ENV*/ 162 | 163 | # Spyder project settings 164 | .spyderproject 165 | .spyproject 166 | 167 | # Rope project settings 168 | .ropeproject 169 | 170 | # mkdocs documentation 171 | /site 172 | 173 | # mypy 174 | .mypy_cache/ 175 | 176 | 177 | # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore ----------------------------------------------- 178 | 179 | # General 180 | .DS_Store 181 | .AppleDouble 182 | .LSOverride 183 | 184 | # Icon must end with two \r 185 | Icon 186 | Icon? 187 | 188 | # Thumbnails 189 | ._* 190 | 191 | # Files that might appear in the root of a volume 192 | .DocumentRevisions-V100 193 | .fseventsd 194 | .Spotlight-V100 195 | .TemporaryItems 196 | .Trashes 197 | .VolumeIcon.icns 198 | .com.apple.timemachine.donotpresent 199 | 200 | # Directories potentially created on remote AFP share 201 | .AppleDB 202 | .AppleDesktop 203 | Network Trash Folder 204 | Temporary Items 205 | .apdisk 206 | 207 | 208 | # https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore 209 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 210 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 211 | 212 | # User-specific stuff: 213 | .idea/* 214 | .idea/**/workspace.xml 215 | .idea/**/tasks.xml 216 | .idea/dictionaries 217 | .html # Bokeh Plots 218 | .pg # TensorFlow Frozen Graphs 219 | .avi # videos 220 | 221 | # Sensitive or high-churn files: 222 | .idea/**/dataSources/ 223 | .idea/**/dataSources.ids 224 | .idea/**/dataSources.local.xml 225 | .idea/**/sqlDataSources.xml 226 | .idea/**/dynamic.xml 227 | .idea/**/uiDesigner.xml 228 | 229 | # Gradle: 230 | .idea/**/gradle.xml 231 | .idea/**/libraries 232 | 233 | # CMake 234 | cmake-build-debug/ 235 | cmake-build-release/ 236 | 237 | # Mongo Explorer plugin: 238 | .idea/**/mongoSettings.xml 239 | 240 | ## File-based project format: 241 | *.iws 242 | 243 | ## Plugin-specific files: 244 | 245 | # IntelliJ 246 | out/ 247 | 248 | # mpeltonen/sbt-idea plugin 249 | .idea_modules/ 250 | 251 | # JIRA plugin 252 | atlassian-ide-plugin.xml 253 | 254 | # Cursive Clojure plugin 255 | .idea/replstate.xml 256 | 257 | # Crashlytics plugin (for Android Studio and IntelliJ) 258 | com_crashlytics_export_strings.xml 259 | crashlytics.properties 260 | crashlytics-build.properties 261 | fabric.properties 262 | 263 | output/ 264 | data/ 265 | wandb/ 266 | output*.txt -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 
1 | name: mmsformer 2 | channels: 3 | - pytorch 4 | - defaults 5 | - conda-forge 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - _pytorch_select=0.1=cpu_0 10 | - blas=1.0=mkl 11 | - bzip2=1.0.8=h7f98852_4 12 | - ca-certificates=2021.10.26=h06a4308_2 13 | - certifi=2021.10.8=py38h06a4308_2 14 | - cffi=1.14.6=py38ha65f79e_0 15 | - cudatoolkit=11.3.1=h2bc3f7f_2 16 | - cudnn=8.2.1.32=h86fa8c9_0 17 | - ffmpeg=4.3=hf484d3e_0 18 | - freetype=2.10.4=h5ab3b9f_0 19 | - future=0.18.2=py38h578d9bd_4 20 | - gmp=6.2.1=h58526e2_0 21 | - gnutls=3.6.13=h85f3911_1 22 | - intel-openmp=2021.3.0=h06a4308_3350 23 | - jpeg=9d=h7f8727e_0 24 | - lame=3.100=h7f98852_1001 25 | - lcms2=2.12=h3be6417_0 26 | - ld_impl_linux-64=2.35.1=h7274673_9 27 | - libblas=3.9.0=11_linux64_mkl 28 | - libffi=3.3=he6710b0_2 29 | - libgcc-ng=9.3.0=h5101ec6_17 30 | - libgomp=9.3.0=h5101ec6_17 31 | - libiconv=1.16=h516909a_0 32 | - liblapack=3.9.0=11_linux64_mkl 33 | - libpng=1.6.37=hbc83047_0 34 | - libprotobuf=3.16.0=h780b84a_0 35 | - libstdcxx-ng=9.3.0=hd4cf53a_17 36 | - libtiff=4.2.0=h85742a9_0 37 | - libuv=1.40.0=h7b6447c_0 38 | - libwebp-base=1.2.0=h27cfd23_0 39 | - lz4-c=1.9.3=h295c915_1 40 | - magma=2.5.4=h6103c52_2 41 | - mkl=2021.3.0=h06a4308_520 42 | - mkl-service=2.4.0=py38h7f8727e_0 43 | - mkl_fft=1.3.0=py38h42c9631_2 44 | - mkl_random=1.2.2=py38h51133e4_0 45 | - nccl=2.11.4.1=hdc17891_0 46 | - ncurses=6.2=he6710b0_1 47 | - nettle=3.6=he412f7d_0 48 | - ninja=1.10.2=hff7bd54_1 49 | - numpy=1.21.2=py38h20f2e39_0 50 | - numpy-base=1.21.2=py38h79a1101_0 51 | - olefile=0.46=pyhd3eb1b0_0 52 | - openh264=2.1.1=h780b84a_0 53 | - openjpeg=2.4.0=h3ad879b_0 54 | - openssl=1.1.1m=h7f8727e_0 55 | - pillow=8.3.1=py38h2c7a002_0 56 | - pycparser=2.21=pyhd8ed1ab_0 57 | - python=3.8.12=h12debd9_0 58 | - python_abi=3.8=2_cp38 59 | - pytorch=1.9.0=cuda112py38h3d13190_1 60 | - pytorch-gpu=1.9.0=cuda112py38h0bbbad9_1 61 | - pytorch-mutex=1.0=cuda 62 | - readline=8.1=h27cfd23_0 63 | - six=1.16.0=pyhd3eb1b0_0 64 | - sleef=3.5.1=h7f98852_1 65 | - sqlite=3.36.0=hc218d9a_0 66 | - tk=8.6.11=h1ccaba5_0 67 | - torchaudio=0.9.0=py38 68 | - torchvision=0.10.0=py38cuda112h04b465a_0_cuda 69 | - typing_extensions=3.10.0.2=pyh06a4308_0 70 | - xz=5.2.5=h7b6447c_0 71 | - yaml=0.2.5=h7b6447c_0 72 | - zlib=1.2.11=h7b6447c_3 73 | - zstd=1.4.9=haebb681_0 74 | - pip: 75 | - absl-py==1.2.0 76 | - addict==2.4.0 77 | - appdirs==1.4.4 78 | - argon2-cffi==21.3.0 79 | - argon2-cffi-bindings==21.2.0 80 | - asttokens==2.0.5 81 | - attrs==21.4.0 82 | - backcall==0.2.0 83 | - bleach==4.1.0 84 | - cachetools==5.0.0 85 | - charset-normalizer==2.1.1 86 | - click==8.1.6 87 | - cycler==0.11.0 88 | - dataclasses==0.6 89 | - debugpy==1.5.1 90 | - decorator==5.1.1 91 | - defusedxml==0.7.1 92 | - descartes==1.1.0 93 | - docker-pycreds==0.4.0 94 | - easydict==1.9 95 | - einops==0.4.1 96 | - entrypoints==0.4 97 | - executing==0.8.3 98 | - fire==0.4.0 99 | - fvcore==0.1.5.post20220512 100 | - gitdb==4.0.10 101 | - gitpython==3.1.32 102 | - google-auth==2.11.0 103 | - google-auth-oauthlib==0.4.6 104 | - grpcio==1.48.1 105 | - idna==3.3 106 | - importlib-metadata==4.12.0 107 | - importlib-resources==5.4.0 108 | - iopath==0.1.10 109 | - ipykernel==6.9.1 110 | - ipython==8.1.0 111 | - ipython-genutils==0.2.0 112 | - ipywidgets==7.6.5 113 | - jedi==0.18.1 114 | - jinja2==3.0.3 115 | - joblib==1.1.0 116 | - jsonschema==4.4.0 117 | - jupyter==1.0.0 118 | - jupyter-client==7.1.2 119 | - jupyter-console==6.4.0 120 | - jupyter-core==4.9.2 121 | - 
jupyterlab-pygments==0.1.2 122 | - jupyterlab-widgets==1.0.2 123 | - kiwisolver==1.3.2 124 | - markdown==3.4.1 125 | - markupsafe==2.1.1 126 | - matplotlib==3.4.3 127 | - matplotlib-inline==0.1.3 128 | - mistune==0.8.4 129 | - mmcv==0.6.2 130 | - nbclient==0.5.11 131 | - nbconvert==6.4.2 132 | - nbformat==5.1.3 133 | - nest-asyncio==1.5.4 134 | - notebook==6.4.8 135 | - nuscenes-devkit==1.1.9 136 | - oauthlib==3.2.1 137 | - opencv-python==4.5.3.56 138 | - packaging==21.3 139 | - pandocfilters==1.5.0 140 | - parso==0.8.3 141 | - pathtools==0.1.2 142 | - pexpect==4.8.0 143 | - pickleshare==0.7.5 144 | - pip==22.0.3 145 | - plyfile==0.7.4 146 | - portalocker==2.5.1 147 | - prometheus-client==0.13.1 148 | - prompt-toolkit==3.0.28 149 | - protobuf==3.18.1 150 | - psutil==5.9.5 151 | - ptyprocess==0.7.0 152 | - pure-eval==0.2.2 153 | - pyasn1==0.4.8 154 | - pyasn1-modules==0.2.8 155 | - pycocotools==2.0.4 156 | - pygments==2.11.2 157 | - pyparsing==3.0.6 158 | - pyquaternion==0.9.9 159 | - pyrsistent==0.18.1 160 | - python-dateutil==2.8.2 161 | - pyyaml==6.0 162 | - pyzmq==22.3.0 163 | - qtconsole==5.2.2 164 | - qtpy==2.0.1 165 | - requests==2.28.1 166 | - requests-oauthlib==1.3.1 167 | - rsa==4.9 168 | - scikit-learn==1.0.2 169 | - scipy==1.7.1 170 | - send2trash==1.8.0 171 | - sentry-sdk==1.28.1 172 | - setproctitle==1.3.2 173 | - setuptools==59.5.0 174 | - shapely==1.8.1.post1 175 | - smmap==5.0.0 176 | - stack-data==0.2.0 177 | - tabulate==0.8.10 178 | - tensorboard==2.10.0 179 | - tensorboard-data-server==0.6.1 180 | - tensorboard-plugin-wit==1.8.1 181 | - tensorboardx==2.4 182 | - termcolor==1.1.0 183 | - terminado==0.13.1 184 | - testpath==0.6.0 185 | - threadpoolctl==3.1.0 186 | - timm==0.4.12 187 | - tornado==6.1 188 | - tqdm==4.62.3 189 | - traitlets==5.1.1 190 | - urllib3==1.26.12 191 | - wandb==0.15.7 192 | - wcwidth==0.2.5 193 | - webencodings==0.5.1 194 | - werkzeug==2.2.2 195 | - wheel==0.37.1 196 | - widgetsnbextension==3.5.2 197 | - yacs==0.1.8 198 | - yapf==0.32.0 199 | - zipp==3.7.0 -------------------------------------------------------------------------------- /tools/infer_mm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import yaml 4 | import math 5 | from torch import Tensor 6 | from torch.nn import functional as F 7 | from pathlib import Path 8 | from torchvision import io 9 | from torchvision import transforms as T 10 | import torchvision.transforms.functional as TF 11 | from semseg.models import * 12 | from semseg.datasets import * 13 | from semseg.utils.utils import timer 14 | from semseg.utils.visualize import draw_text 15 | import glob 16 | import os 17 | from PIL import Image, ImageDraw, ImageFont 18 | 19 | 20 | class SemSeg: 21 | def __init__(self, cfg) -> None: 22 | # inference device cuda or cpu 23 | self.device = torch.device(cfg['DEVICE']) 24 | 25 | # get dataset classes' colors and labels 26 | self.palette = eval(cfg['DATASET']['NAME']).PALETTE 27 | self.labels = eval(cfg['DATASET']['NAME']).CLASSES 28 | 29 | # initialize the model and load weights and send to device 30 | self.model = eval(cfg['MODEL']['NAME'])(cfg['MODEL']['BACKBONE'], len(self.palette), cfg['DATASET']['MODALS']) 31 | msg = self.model.load_state_dict(torch.load(cfg['EVAL']['MODEL_PATH'], map_location='cpu')) 32 | print(msg) 33 | self.model = self.model.to(self.device) 34 | self.model.eval() 35 | 36 | # preprocess parameters and transformation pipeline 37 | self.size = cfg['TEST']['IMAGE_SIZE'] 38 | 
self.tf_pipeline_img = T.Compose([ 39 | T.Resize(self.size), 40 | T.Lambda(lambda x: x / 255), 41 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 42 | T.Lambda(lambda x: x.unsqueeze(0)) 43 | ]) 44 | self.tf_pipeline_modal = T.Compose([ 45 | T.Resize(self.size), 46 | T.Lambda(lambda x: x / 255), 47 | T.Lambda(lambda x: x.unsqueeze(0)) 48 | ]) 49 | 50 | def postprocess(self, orig_img: Tensor, seg_map: Tensor, overlay: bool) -> Tensor: 51 | seg_map = seg_map.softmax(dim=1).argmax(dim=1).cpu().to(int) 52 | 53 | seg_image = self.palette[seg_map].squeeze() 54 | if overlay: 55 | seg_image = (orig_img.permute(1, 2, 0) * 0.4) + (seg_image * 0.6) 56 | 57 | image = seg_image.to(torch.uint8) 58 | pil_image = Image.fromarray(image.numpy()) 59 | return pil_image 60 | 61 | @torch.inference_mode() 62 | @timer 63 | def model_forward(self, img: Tensor) -> Tensor: 64 | return self.model(img) 65 | 66 | def _open_img(self, file): 67 | img = io.read_image(file) 68 | C, H, W = img.shape 69 | if C == 4: 70 | img = img[:3, ...] 71 | if C == 1: 72 | img = img.repeat(3, 1, 1) 73 | return img 74 | 75 | def predict(self, img_fname: str, overlay: bool) -> Tensor: 76 | if cfg['DATASET']['NAME'] == 'DELIVER': 77 | x1 = img_fname.replace('/img', '/hha').replace('_rgb', '_depth') 78 | x2 = img_fname.replace('/img', '/lidar').replace('_rgb', '_lidar') 79 | x3 = img_fname.replace('/img', '/event').replace('_rgb', '_event') 80 | lbl_path = img_fname.replace('/img', '/semantic').replace('_rgb', '_semantic') 81 | elif cfg['DATASET']['NAME'] == 'KITTI360': 82 | x1 = os.path.join(img_fname.replace('data_2d_raw', 'data_2d_hha')) 83 | x2 = os.path.join(img_fname.replace('data_2d_raw', 'data_2d_lidar')) 84 | x2 = x2.replace('.png', '_color.png') 85 | x3 = os.path.join(img_fname.replace('data_2d_raw', 'data_2d_event')) 86 | x3 = x3.replace('/image_00/data_rect/', '/').replace('.png', '_event_image.png') 87 | lbl_path = os.path.join(*[img_fname.replace('data_2d_raw', 'data_2d_semantics/train').replace('data_rect', 'semantic')]) 88 | 89 | image = io.read_image(img_fname)[:3, ...] 
90 | img = self.tf_pipeline_img(image).to(self.device) 91 | # --- modals 92 | x1 = self._open_img(x1) 93 | x1 = self.tf_pipeline_modal(x1).to(self.device) 94 | x2 = self._open_img(x2) 95 | x2 = self.tf_pipeline_modal(x2).to(self.device) 96 | x3 = self._open_img(x3) 97 | x3 = self.tf_pipeline_modal(x3).to(self.device) 98 | label = io.read_image(lbl_path)[0,...].unsqueeze(0) 99 | label[label==255] = 0 100 | label -= 1 101 | 102 | sample = [img, x1, x2, x3][:len(modals)] 103 | 104 | seg_map = self.model_forward(sample) 105 | seg_map = self.postprocess(image, seg_map, overlay) 106 | return seg_map 107 | 108 | 109 | if __name__ == '__main__': 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument('--cfg', type=str, default='configs/DELIVER.yaml') 112 | args = parser.parse_args() 113 | with open(args.cfg) as f: 114 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 115 | 116 | # cases = ['cloud', 'fog', 'night', 'rain', 'sun', 'motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres', None] 117 | cases = ['lidarjitter'] 118 | 119 | modals = cfg['DATASET']['MODALS'] 120 | 121 | test_file = Path(cfg['TEST']['FILE']) 122 | if not test_file.exists(): 123 | raise FileNotFoundError(test_file) 124 | 125 | # print(f"Model {cfg['MODEL']['NAME']} {cfg['MODEL']['BACKBONE']}") 126 | # print(f"Model {cfg['DATASET']['NAME']}") 127 | 128 | modals_name = ''.join([m[0] for m in cfg['DATASET']['MODALS']]) 129 | save_dir = Path(cfg['SAVE_DIR']) / 'test_results' / (cfg['DATASET']['NAME']+'_'+cfg['MODEL']['BACKBONE']+'_'+modals_name) 130 | 131 | semseg = SemSeg(cfg) 132 | 133 | if test_file.is_file(): 134 | segmap = semseg.predict(str(test_file), cfg['TEST']['OVERLAY']) 135 | segmap.save(save_dir / f"{str(test_file.stem)}.png") 136 | else: 137 | if cfg['DATASET']['NAME'] == 'DELIVER': 138 | files = sorted(glob.glob(os.path.join(*[str(test_file), 'img', '*', 'val', '*', '*.png']))) # --- Deliver 139 | elif cfg['DATASET']['NAME'] == 'KITTI360': 140 | source = os.path.join(test_file, 'val.txt') 141 | files = [] 142 | with open(source) as f: 143 | files_ = f.readlines() 144 | for item in files_: 145 | file_name = item.strip() 146 | if ' ' in file_name: 147 | # --- KITTI-360 148 | file_name = os.path.join(*[str(test_file), file_name.split(' ')[0]]) 149 | files.append(file_name) 150 | else: 151 | raise NotImplementedError() 152 | 153 | for file in files: 154 | print(file) 155 | if not '2013_05_28_drive_0000_sync' in file: 156 | continue 157 | segmap = semseg.predict(file, cfg['TEST']['OVERLAY']) 158 | save_path = os.path.join(str(save_dir),file) 159 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 160 | segmap.save(save_path) 161 | -------------------------------------------------------------------------------- /tools/val_mm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import yaml 4 | import math 5 | import os 6 | import time 7 | from pathlib import Path 8 | from tqdm import tqdm 9 | from tabulate import tabulate 10 | from torch.utils.data import DataLoader 11 | from torch.nn import functional as F 12 | from semseg.models import * 13 | from semseg.datasets import * 14 | from semseg.augmentations_mm import get_val_augmentation 15 | from semseg.metrics import Metrics 16 | from semseg.utils.utils import setup_cudnn 17 | from math import ceil 18 | import numpy as np 19 | from torch.utils.data import DistributedSampler, RandomSampler 20 | from torch import distributed as dist 21 | from torch.nn.parallel import 
DistributedDataParallel as DDP 22 | from semseg.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp, get_logger, cal_flops, print_iou 23 | 24 | def pad_image(img, target_size): 25 | rows_to_pad = max(target_size[0] - img.shape[2], 0) 26 | cols_to_pad = max(target_size[1] - img.shape[3], 0) 27 | padded_img = F.pad(img, (0, cols_to_pad, 0, rows_to_pad), "constant", 0) 28 | return padded_img 29 | 30 | @torch.no_grad() 31 | def sliding_predict(model, image, num_classes, flip=True): 32 | image_size = image[0].shape 33 | tile_size = (int(ceil(image_size[2]*1)), int(ceil(image_size[3]*1))) 34 | overlap = 1/3 35 | 36 | stride = ceil(tile_size[0] * (1 - overlap)) 37 | 38 | num_rows = int(ceil((image_size[2] - tile_size[0]) / stride) + 1) 39 | num_cols = int(ceil((image_size[3] - tile_size[1]) / stride) + 1) 40 | total_predictions = torch.zeros((num_classes, image_size[2], image_size[3]), device=torch.device('cuda')) 41 | count_predictions = torch.zeros((image_size[2], image_size[3]), device=torch.device('cuda')) 42 | tile_counter = 0 43 | 44 | for row in range(num_rows): 45 | for col in range(num_cols): 46 | x_min, y_min = int(col * stride), int(row * stride) 47 | x_max = min(x_min + tile_size[1], image_size[3]) 48 | y_max = min(y_min + tile_size[0], image_size[2]) 49 | 50 | img = [modal[:, :, y_min:y_max, x_min:x_max] for modal in image] 51 | padded_img = [pad_image(modal, tile_size) for modal in img] 52 | tile_counter += 1 53 | padded_prediction = model(padded_img) 54 | if flip: 55 | fliped_img = [padded_modal.flip(-1) for padded_modal in padded_img] 56 | fliped_predictions = model(fliped_img) 57 | padded_prediction += fliped_predictions.flip(-1) 58 | predictions = padded_prediction[:, :, :img[0].shape[2], :img[0].shape[3]] 59 | count_predictions[y_min:y_max, x_min:x_max] += 1 60 | total_predictions[:, y_min:y_max, x_min:x_max] += predictions.squeeze(0) 61 | 62 | return total_predictions.unsqueeze(0) 63 | 64 | @torch.no_grad() 65 | def evaluate(model, dataloader, device, loss_fn=None): 66 | print('Evaluating...') 67 | model.eval() 68 | n_classes = dataloader.dataset.n_classes 69 | metrics = Metrics(n_classes, dataloader.dataset.ignore_label, device) 70 | sliding = False 71 | test_loss = 0.0 72 | iter = 0 73 | for images, labels in tqdm(dataloader): 74 | images = [x.to(device) for x in images] 75 | labels = labels.to(device) 76 | if sliding: 77 | # preds = sliding_predict(model, images, num_classes=n_classes).softmax(dim=1) 78 | preds = sliding_predict(model, images, num_classes=n_classes) 79 | else: 80 | # preds = model(images).softmax(dim=1) 81 | preds = model(images) 82 | 83 | metrics.update(preds.softmax(dim=1), labels) 84 | 85 | if loss_fn is not None: 86 | loss = loss_fn(preds, labels) 87 | test_loss += loss.item() 88 | iter += 1 89 | 90 | test_loss /= iter 91 | ious, miou = metrics.compute_iou() 92 | acc, macc = metrics.compute_pixel_acc() 93 | f1, mf1 = metrics.compute_f1() 94 | 95 | return acc, macc, f1, mf1, ious, miou, test_loss 96 | 97 | 98 | @torch.no_grad() 99 | def evaluate_msf(model, dataloader, device, scales, flip): 100 | model.eval() 101 | 102 | n_classes = dataloader.dataset.n_classes 103 | metrics = Metrics(n_classes, dataloader.dataset.ignore_label, device) 104 | 105 | for images, labels in tqdm(dataloader): 106 | labels = labels.to(device) 107 | B, H, W = labels.shape 108 | scaled_logits = torch.zeros(B, n_classes, H, W).to(device) 109 | 110 | for scale in scales: 111 | new_H, new_W = int(scale * H), int(scale * W) 112 | new_H, new_W = int(math.ceil(new_H / 
32)) * 32, int(math.ceil(new_W / 32)) * 32 113 | scaled_images = [F.interpolate(img, size=(new_H, new_W), mode='bilinear', align_corners=True) for img in images] 114 | scaled_images = [scaled_img.to(device) for scaled_img in scaled_images] 115 | logits = model(scaled_images) 116 | logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True) 117 | scaled_logits += logits.softmax(dim=1) 118 | 119 | if flip: 120 | scaled_images = [torch.flip(scaled_img, dims=(3,)) for scaled_img in scaled_images] 121 | logits = model(scaled_images) 122 | logits = torch.flip(logits, dims=(3,)) 123 | logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True) 124 | scaled_logits += logits.softmax(dim=1) 125 | 126 | metrics.update(scaled_logits, labels) 127 | 128 | acc, macc = metrics.compute_pixel_acc() 129 | f1, mf1 = metrics.compute_f1() 130 | ious, miou = metrics.compute_iou() 131 | return acc, macc, f1, mf1, ious, miou 132 | 133 | 134 | def main(cfg): 135 | device = torch.device(cfg['DEVICE']) 136 | 137 | eval_cfg = cfg['EVAL'] 138 | transform = get_val_augmentation(eval_cfg['IMAGE_SIZE']) 139 | # cases = ['cloud', 'fog', 'night', 'rain', 'sun'] 140 | # cases = ['motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres'] 141 | cases = [None] # all 142 | 143 | model_path = Path(eval_cfg['MODEL_PATH']) 144 | if not model_path.exists(): 145 | raise FileNotFoundError 146 | print(f"Evaluating {model_path}...") 147 | 148 | exp_time = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 149 | eval_path = os.path.join(os.path.dirname(eval_cfg['MODEL_PATH']), 'eval_{}.txt'.format(exp_time)) 150 | 151 | for case in cases: 152 | dataset = eval(cfg['DATASET']['NAME'])(cfg['DATASET']['ROOT'], 'val', transform, cfg['DATASET']['MODALS'], case) 153 | # --- test set 154 | # dataset = eval(cfg['DATASET']['NAME'])(cfg['DATASET']['ROOT'], 'test', transform, cfg['DATASET']['MODALS'], case) 155 | 156 | model = eval(cfg['MODEL']['NAME'])(cfg['MODEL']['BACKBONE'], dataset.n_classes, cfg['DATASET']['MODALS']) 157 | msg = model.load_state_dict(torch.load(str(model_path), map_location='cpu')) 158 | print(msg) 159 | model = model.to(device) 160 | sampler_val = None 161 | dataloader = DataLoader(dataset, batch_size=eval_cfg['BATCH_SIZE'], num_workers=eval_cfg['BATCH_SIZE'], pin_memory=False, sampler=sampler_val) 162 | if True: 163 | if eval_cfg['MSF']['ENABLE']: 164 | acc, macc, f1, mf1, ious, miou = evaluate_msf(model, dataloader, device, eval_cfg['MSF']['SCALES'], eval_cfg['MSF']['FLIP']) 165 | else: 166 | acc, macc, f1, mf1, ious, miou, _ = evaluate(model, dataloader, device) 167 | 168 | table = { 169 | 'Class': list(dataset.CLASSES) + ['Mean'], 170 | 'IoU': ious + [miou], 171 | 'F1': f1 + [mf1], 172 | 'Acc': acc + [macc] 173 | } 174 | print("mIoU : {}".format(miou)) 175 | print("Results saved in {}".format(eval_cfg['MODEL_PATH'])) 176 | 177 | with open(eval_path, 'a+') as f: 178 | f.writelines(eval_cfg['MODEL_PATH']) 179 | f.write("\n============== Eval on {} {} images =================\n".format(case, len(dataset))) 180 | f.write("\n") 181 | print(tabulate(table, headers='keys'), file=f) 182 | 183 | 184 | 185 | if __name__ == '__main__': 186 | parser = argparse.ArgumentParser() 187 | parser.add_argument('--cfg', type=str, default='configs/mcubes_rgbadn.yaml') 188 | args = parser.parse_args() 189 | 190 | with open(args.cfg) as f: 191 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 192 | 193 | setup_cudnn() 194 | # gpu = setup_ddp() 195 | # main(cfg, gpu) 196 | main(cfg) 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | ## MMSFormer: Multimodal Transformer for Material and Semantic Segmentation 4 | 5 |
6 | 7 |
8 | 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multimodal-transformer-for-material/semantic-segmentation-on-mcubes)](https://paperswithcode.com/sota/semantic-segmentation-on-mcubes?p=multimodal-transformer-for-material) 10 | 11 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multimodal-transformer-for-material/semantic-segmentation-on-fmb-dataset)](https://paperswithcode.com/sota/semantic-segmentation-on-fmb-dataset?p=multimodal-transformer-for-material) 12 | 13 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multimodal-transformer-for-material/thermal-image-segmentation-on-pst900)](https://paperswithcode.com/sota/thermal-image-segmentation-on-pst900?p=multimodal-transformer-for-material) 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 |
28 | 29 | ## Introduction 30 | 31 | Leveraging information across diverse modalities is known to enhance performance on multimodal segmentation tasks. However, effectively fusing information from different modalities remains challenging due to the unique characteristics of each modality. In this paper, we propose a novel fusion strategy that can effectively fuse information from different modality combinations. We also propose a new model named **M**ulti-**M**odal **S**egmentation Trans**Former** (**MMSFormer**) that incorporates the proposed fusion strategy to perform multimodal material and semantic segmentation tasks. MMSFormer outperforms current state-of-the-art models on three different datasets. Starting from a single input modality, performance improves progressively as additional modalities are incorporated, showcasing the effectiveness of the fusion block in combining useful information from diverse input modalities. Ablation studies show that different modules in the fusion block are crucial for overall model performance. Furthermore, the ablation studies also highlight how different input modalities help the model identify different types of materials. 32 | 33 | For more details, please check our [arXiv](https://arxiv.org/abs/2309.04001) paper. 34 | 35 | ## Updates 36 | - [x] 09/2023: Init repository. 37 | - [x] 09/2023: Release the code for MMSFormer. 38 | - [x] 09/2023: Release MMSFormer model weights. Download from [**GoogleDrive**](https://drive.google.com/drive/folders/1OPr7PUrL7hkBXogmHFzHuTJweHuJmlP-?usp=sharing). 39 | - [x] 01/2024: Update code, description and pretrained weights. 40 | - [x] 04/2024: Accepted by [IEEE Open Journal of Signal Processing](https://ieeexplore.ieee.org/document/10502124). 41 | 42 | ## MMSFormer model 43 | 44 |
45 | 46 | ![MMSFormer](figs/MMSFormer-V2.png) 47 | **Figure:** Overall architecture of MMSFormer model and proposed fusion block. 48 | 49 |
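For intuition only, the sketch below illustrates the general idea the figure conveys: per-modality features coming out of the backbone are concatenated along the channel dimension, mixed, and combined through a residual connection. The module, names, and shapes here are illustrative assumptions and not the actual MMSFormer fusion block; see `semseg/models/` for the real implementation.

```python
import torch
from torch import nn


class ToyFusionBlock(nn.Module):
    """Illustrative channel-concatenation fusion (not the actual MMSFormer block)."""

    def __init__(self, channels: int, num_modalities: int):
        super().__init__()
        # Project the concatenated per-modality features back to a single feature map.
        self.reduce = nn.Conv2d(channels * num_modalities, channels, kernel_size=1)
        # Simple mixing stage: depth-wise conv followed by point-wise conv.
        self.mix = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=1),
        )

    def forward(self, feats):
        # feats: list of [B, C, H, W] tensors, one per input modality (e.g. RGB, A, D, N).
        fused = self.reduce(torch.cat(feats, dim=1))
        return feats[0] + self.mix(fused)  # residual from the first (RGB) stream


feats = [torch.randn(2, 64, 32, 32) for _ in range(4)]
print(ToyFusionBlock(channels=64, num_modalities=4)(feats).shape)  # torch.Size([2, 64, 32, 32])
```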
50 | 51 | ## Environment 52 | 53 | First, create and activate the environment using the following commands: 54 | ```bash 55 | conda env create -f environment.yaml 56 | conda activate mmsformer 57 | ``` 58 | 59 | ## Data preparation 60 | Download the dataset: 61 | - [MCubeS](https://github.com/kyotovision-public/multimodal-material-segmentation), for multimodal material segmentation with RGB-A-D-N modalities. 62 | - [FMB](https://github.com/JinyuanLiu-CV/SegMiF), for FMB dataset with RGB-Infrared modalities. 63 | - [PST](https://github.com/ShreyasSkandanS/pst900_thermal_rgb), for PST900 dataset with RGB-Thermal modalities. 64 | 65 | Then, put the dataset under `data` directory as follows: 66 | 67 | ``` 68 | data/ 69 | ├── MCubeS 70 | │   ├── polL_color 71 | │   ├── polL_aolp_sin 72 | │   ├── polL_aolp_cos 73 | │   ├── polL_dolp 74 | │   ├── NIR_warped 75 | │   ├── NIR_warped_mask 76 | │   ├── GT 77 | │   ├── SSGT4MS 78 | │   ├── list_folder 79 | │   └── SS 80 | ├── FMB 81 | │   ├── test 82 | │   │   ├── color 83 | │   │   ├── Infrared 84 | │   │   ├── Label 85 | │   │   └── Visible 86 | │   ├── train 87 | │   │   ├── color 88 | │   │   ├── Infrared 89 | │   │   ├── Label 90 | │   │   └── Visible 91 | ├── PST 92 | │   ├── test 93 | │   │   ├── rgb 94 | │   │   ├── thermal 95 | │   │   └── labels 96 | │   ├── train 97 | │   │   ├── rgb 98 | │   │   ├── thermal 99 | │   │   └── labels 100 | ``` 101 | 102 | ## Model Zoo 103 | 104 | ### MCubeS 105 | | Model-Modal | mIoU | weight | 106 | | :--------------- | :----- | :----- | 107 | | MCubeS-RGB | 50.44 | [GoogleDrive](https://drive.google.com/drive/folders/1TiC4spUgMGo8zO2iChpuuRo8cmZC2yeh?usp=sharing) | 108 | | MCubeS-RGB-A | 51.30 | [GoogleDrive](https://drive.google.com/drive/folders/1TiC4spUgMGo8zO2iChpuuRo8cmZC2yeh?usp=sharing) | 109 | | MCubeS-RGB-A-D | 52.03 | [GoogleDrive](https://drive.google.com/drive/folders/1TiC4spUgMGo8zO2iChpuuRo8cmZC2yeh?usp=sharing) | 110 | | MCubeS-RGB-A-D-N | 53.11 | [GoogleDrive](https://drive.google.com/drive/folders/1TiC4spUgMGo8zO2iChpuuRo8cmZC2yeh?usp=sharing) | 111 | 112 | ### FMB 113 | | Model-Modal | mIoU | weight | 114 | | :--------------- | :----- | :----- | 115 | | FMB-RGB | 57.17 | [GoogleDrive](https://drive.google.com/drive/folders/15kuBWiEHOxxOLMxvASYzPhdSgG8ZWfgm?usp=sharing) | 116 | | FMB-RGB-Infrared | 61.68 | [GoogleDrive](https://drive.google.com/drive/folders/15kuBWiEHOxxOLMxvASYzPhdSgG8ZWfgm?usp=sharing) | 117 | 118 | ### PST900 119 | | Model-Modal | mIoU | weight | 120 | | :--------------- | :----- | :----- | 121 | | PST-RGB-T | 87.45 | [GoogleDrive](https://drive.google.com/drive/folders/1yv7wfGVLrxBYQ3teDg3eL-zYJsie56Ll?usp=sharing) | 122 | 123 | 124 | ## Training 125 | 126 | Before training, please download [pre-trained SegFormer](https://drive.google.com/drive/folders/10XgSW8f7ghRs9fJ0dE-EV8G2E_guVsT5), and put it in the correct directory following this structure: 127 | 128 | ```text 129 | checkpoints/pretrained/segformer 130 | ├── mit_b0.pth 131 | ├── mit_b1.pth 132 | ├── mit_b2.pth 133 | ├── mit_b3.pth 134 | └── mit_b4.pth 135 | ``` 136 | 137 | To train MMSFormer model, please update the appropriate configuration file in `configs/` with appropriate paths and hyper-parameters. 
Then run as follows: 138 | 139 | ```bash 140 | cd path/to/MMSFormer 141 | conda activate mmsformer 142 | 143 | python -m tools.train_mm --cfg configs/mcubes_rgbadn.yaml 144 | 145 | python -m tools.train_mm --cfg configs/fmb_rgbt.yaml 146 | 147 | python -m tools.train_mm --cfg configs/pst_rgbt.yaml 148 | ``` 149 | 150 | 151 | ## Evaluation 152 | To evaluate MMSFormer models, please download the respective model weights ([**GoogleDrive**](https://drive.google.com/drive/folders/1OPr7PUrL7hkBXogmHFzHuTJweHuJmlP-?usp=sharing)) and save them under any folder you like. 153 | 154 | 163 | 164 | Then, update the `EVAL` section of the appropriate configuration file in `configs/` and run: 165 | 166 | ```bash 167 | cd path/to/MMSFormer 168 | conda activate mmsformer 169 | 170 | python -m tools.val_mm --cfg configs/mcubes_rgbadn.yaml 171 | 172 | python -m tools.val_mm --cfg configs/fmb_rgbt.yaml 173 | 174 | python -m tools.val_mm --cfg configs/pst_rgbt.yaml 175 | ``` 176 | 177 | ## License 178 | 179 | This repository is released under the Apache-2.0 license. For commercial use, please contact the authors. 180 | 181 | 182 | ## Citations 183 | 184 | If you use the MMSFormer model, please cite the following work: 185 | 186 | - **MMSFormer** [[**arXiv**](https://arxiv.org/abs/2309.04001)] 187 | ``` 188 | @ARTICLE{Reza2024MMSFormer, 189 | author={Reza, Md Kaykobad and Prater-Bennette, Ashley and Asif, M. Salman}, 190 | journal={IEEE Open Journal of Signal Processing}, 191 | title={MMSFormer: Multimodal Transformer for Material and Semantic Segmentation}, 192 | year={2024}, 193 | volume={}, 194 | number={}, 195 | pages={1-12}, 196 | keywords={Image segmentation;Feature extraction;Transformers;Task analysis;Fuses;Semantic segmentation;Decoding;multimodal image segmentation;material segmentation;semantic segmentation;multimodal fusion;transformer}, 197 | doi={10.1109/OJSP.2024.3389812} 198 | } 199 | ``` 200 | 201 | ## Acknowledgements 202 | Our codebase is based on the following public GitHub repositories. Thanks to their authors: 203 | - [DELIVER](https://github.com/jamycheung/DELIVER) 204 | - [RGBX-semantic-segmentation](https://github.com/huaaaliu/RGBX_Semantic_Segmentation) 205 | - [Semantic-segmentation](https://github.com/sithu31296/semantic-segmentation) 206 | 207 | **Note:** This is a research-level repository and might contain issues/bugs. Please contact the authors with any queries.
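After downloading a checkpoint, a quick sanity check is to load it directly in Python; the snippet below is a minimal sketch that mirrors what `tools/val_mm.py` does internally. The backbone name, class count, modality list, and checkpoint path are placeholder assumptions; take the actual values from the corresponding file in `configs/`.

```python
import torch
from semseg.models import MMSFormer

# Placeholder values: read MODEL.BACKBONE, DATASET.MODALS and the number of classes from your config.
backbone, num_classes, modals = 'MiT-B4', 20, ['img', 'aolp', 'dolp', 'nir']

model = MMSFormer(backbone, num_classes, modals)
state_dict = torch.load('path/to/downloaded_checkpoint.pth', map_location='cpu')
print(model.load_state_dict(state_dict))  # should report no missing or unexpected keys
model.eval()
```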
208 | -------------------------------------------------------------------------------- /semseg/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import time 5 | import os 6 | import sys 7 | import functools 8 | from pathlib import Path 9 | from torch.backends import cudnn 10 | from torch import nn, Tensor 11 | from torch.autograd import profiler 12 | from typing import Union 13 | from torch import distributed as dist 14 | from tabulate import tabulate 15 | from semseg import models 16 | import logging 17 | from fvcore.nn import flop_count_table, FlopCountAnalysis 18 | import datetime 19 | 20 | def fix_seeds(seed: int = 3407) -> None: 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed(seed) 23 | np.random.seed(seed) 24 | random.seed(seed) 25 | 26 | def setup_cudnn() -> None: 27 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 28 | cudnn.benchmark = True 29 | cudnn.deterministic = False 30 | 31 | def time_sync() -> float: 32 | if torch.cuda.is_available(): 33 | torch.cuda.synchronize() 34 | return time.time() 35 | 36 | def get_model_size(model: Union[nn.Module, torch.jit.ScriptModule]): 37 | tmp_model_path = Path('temp.p') 38 | if isinstance(model, torch.jit.ScriptModule): 39 | torch.jit.save(model, tmp_model_path) 40 | else: 41 | torch.save(model.state_dict(), tmp_model_path) 42 | size = tmp_model_path.stat().st_size 43 | os.remove(tmp_model_path) 44 | return size / 1e6 # in MB 45 | 46 | @torch.no_grad() 47 | def test_model_latency(model: nn.Module, inputs: torch.Tensor, use_cuda: bool = False) -> float: 48 | with profiler.profile(use_cuda=use_cuda) as prof: 49 | _ = model(inputs) 50 | return prof.self_cpu_time_total / 1000 # ms 51 | 52 | def count_parameters(model: nn.Module) -> float: 53 | return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 # in M 54 | 55 | def setup_ddp(): 56 | # print(os.environ.keys()) 57 | if 'SLURM_PROCID' in os.environ and not 'RANK' in os.environ: 58 | # --- multi nodes 59 | world_size = int(os.environ['WORLD_SIZE']) 60 | rank = int(os.environ["SLURM_PROCID"]) 61 | gpus_per_node = int(os.environ["SLURM_GPUS_ON_NODE"]) 62 | gpu = rank - gpus_per_node * (rank // gpus_per_node) 63 | torch.cuda.set_device(gpu) 64 | dist.init_process_group(backend="nccl", world_size=world_size, rank=rank, timeout=datetime.timedelta(seconds=7200)) 65 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 66 | rank = int(os.environ['RANK']) 67 | world_size = int(os.environ['WORLD_SIZE']) 68 | # gpu = int(os.environ(['LOCAL_RANK'])) 69 | # --- 70 | gpu = int(os.environ['LOCAL_RANK']) 71 | torch.cuda.set_device(gpu) 72 | dist.init_process_group('nccl', init_method="env://",world_size=world_size, rank=rank, timeout=datetime.timedelta(seconds=7200)) 73 | dist.barrier() 74 | else: 75 | gpu = 0 76 | return gpu 77 | 78 | def cleanup_ddp(): 79 | if dist.is_initialized(): 80 | dist.destroy_process_group() 81 | 82 | def reduce_tensor(tensor: Tensor) -> Tensor: 83 | rt = tensor.clone() 84 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 85 | rt /= dist.get_world_size() 86 | return rt 87 | 88 | @torch.no_grad() 89 | def throughput(dataloader, model: nn.Module, times: int = 30): 90 | model.eval() 91 | images, _ = next(iter(dataloader)) 92 | images = images.cuda(non_blocking=True) 93 | B = images.shape[0] 94 | print(f"Throughput averaged with {times} times") 95 | start = time_sync() 96 | for _ in range(times): 97 | 
model(images) 98 | end = time_sync() 99 | 100 | print(f"Batch Size {B} throughput {times * B / (end - start)} images/s") 101 | 102 | 103 | def show_models(): 104 | model_names = models.__all__ 105 | model_variants = [list(eval(f'models.{name.lower()}_settings').keys()) for name in model_names] 106 | 107 | print(tabulate({'Model Names': model_names, 'Model Variants': model_variants}, headers='keys')) 108 | 109 | 110 | def timer(func): 111 | @functools.wraps(func) 112 | def wrapper_timer(*args, **kwargs): 113 | tic = time.perf_counter() 114 | value = func(*args, **kwargs) 115 | toc = time.perf_counter() 116 | elapsed_time = toc - tic 117 | print(f"Elapsed time: {elapsed_time * 1000:.2f}ms") 118 | return value 119 | return wrapper_timer 120 | 121 | 122 | # _default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO') 123 | # _default_level = logging.getLevelName(_default_level_name.upper()) 124 | 125 | def get_logger(log_file=None): 126 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: - %(message)s',datefmt='%Y%m%d %H:%M:%S') 127 | logger = logging.getLogger() 128 | # logger.setLevel(logging.DEBUG) 129 | logger.setLevel(logging.INFO) 130 | del logger.handlers[:] 131 | 132 | if log_file: 133 | file_handler = logging.FileHandler(log_file, mode='w') 134 | # file_handler.setLevel(logging.DEBUG) 135 | file_handler.setLevel(logging.INFO) 136 | file_handler.setFormatter(formatter) 137 | logger.addHandler(file_handler) 138 | 139 | stream_handler = logging.StreamHandler() 140 | stream_handler.setFormatter(formatter) 141 | # stream_handler.setLevel(logging.DEBUG) 142 | stream_handler.setLevel(logging.INFO) 143 | logger.addHandler(stream_handler) 144 | return logger 145 | 146 | 147 | def cal_flops(model, modals, logger): 148 | x = [torch.zeros(1, 3, 512, 512) for _ in range(len(modals))] 149 | # x = [torch.zeros(2, 3, 512, 512) for _ in range(len(modals))] #--- PGSNet 150 | # x = [torch.zeros(1, 3, 512, 512) for _ in range(len(modals))] # --- for HRFuser 151 | if torch.distributed.is_initialized(): 152 | if 'HR' in model.module.__class__.__name__: 153 | x = [torch.zeros(1, 3, 512, 512) for _ in range(len(modals))] # --- for HorNet 154 | else: 155 | if 'HR' in model.__class__.__name__: 156 | x = [torch.zeros(1, 3, 512, 512) for _ in range(len(modals))] # --- for HorNet 157 | 158 | if torch.cuda.is_available(): 159 | x = [xi.cuda() for xi in x] 160 | model = model.cuda() 161 | logger.info(flop_count_table(FlopCountAnalysis(model, x))) 162 | 163 | def print_iou(epoch, iou, miou, acc, macc, class_names): 164 | assert len(iou) == len(class_names) 165 | assert len(acc) == len(class_names) 166 | lines = ['\n%-8s\t%-8s\t%-8s' % ('Class', 'IoU', 'Acc')] 167 | for i in range(len(iou)): 168 | if class_names is None: 169 | cls = 'Class %d:' % (i+1) 170 | else: 171 | cls = '%d %s' % (i+1, class_names[i]) 172 | lines.append('%-8s\t%.2f\t%.2f' % (cls, iou[i], acc[i])) 173 | lines.append('== %-8s\t%d\t%-8s\t%.2f\t%-8s\t%.2f' % ('Epoch:', epoch, 'mean_IoU', miou, 'mean_Acc',macc)) 174 | line = "\n".join(lines) 175 | return line 176 | 177 | 178 | def nchw_to_nlc(x): 179 | """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. 180 | 181 | Args: 182 | x (Tensor): The input tensor of shape [N, C, H, W] before conversion. 183 | 184 | Returns: 185 | Tensor: The output tensor of shape [N, L, C] after conversion.
186 | """ 187 | assert len(x.shape) == 4 188 | return x.flatten(2).transpose(1, 2).contiguous() 189 | 190 | def nlc_to_nchw(x, hw_shape): 191 | """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. 192 | 193 | Args: 194 | x (Tensor): The input tensor of shape [N, L, C] before conversion. 195 | hw_shape (Sequence[int]): The height and width of output feature map. 196 | 197 | Returns: 198 | Tensor: The output tensor of shape [N, C, H, W] after conversion. 199 | """ 200 | H, W = hw_shape 201 | assert len(x.shape) == 3 202 | B, L, C = x.shape 203 | assert L == H * W, 'The seq_len does not match H, W' 204 | return x.transpose(1, 2).reshape(B, C, H, W).contiguous() 205 | 206 | def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs): 207 | """Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. Use the 208 | reshaped tensor as the input of `module`, and convert the output of 209 | `module`, whose shape is. 210 | [N, C, H, W], to [N, L, C]. 211 | Args: 212 | module (Callable): A callable object the takes a tensor 213 | with shape [N, C, H, W] as input. 214 | x (Tensor): The input tensor of shape [N, L, C]. 215 | hw_shape: (Sequence[int]): The height and width of the 216 | feature map with shape [N, C, H, W]. 217 | contiguous (Bool): Whether to make the tensor contiguous 218 | after each shape transform. 219 | Returns: 220 | Tensor: The output tensor of shape [N, L, C]. 221 | Example: 222 | >>> import torch 223 | >>> import torch.nn as nn 224 | >>> conv = nn.Conv2d(16, 16, 3, 1, 1) 225 | >>> feature_map = torch.rand(4, 25, 16) 226 | >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5)) 227 | """ 228 | H, W = hw_shape 229 | assert len(x.shape) == 3 230 | B, L, C = x.shape 231 | assert L == H * W, 'The seq_len doesn\'t match H, W' 232 | if not contiguous: 233 | x = x.transpose(1, 2).reshape(B, C, H, W) 234 | x = module(x, **kwargs) 235 | x = x.flatten(2).transpose(1, 2) 236 | else: 237 | x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() 238 | x = module(x, **kwargs) 239 | x = x.flatten(2).transpose(1, 2).contiguous() 240 | return x 241 | -------------------------------------------------------------------------------- /tools/train_mm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import yaml 5 | import time 6 | import multiprocessing as mp 7 | from tabulate import tabulate 8 | from tqdm import tqdm 9 | from torch.utils.data import DataLoader 10 | from pathlib import Path 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.cuda.amp import GradScaler, autocast 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torch.utils.data import DistributedSampler, RandomSampler 15 | from torch import distributed as dist 16 | from semseg.models import * 17 | from semseg.datasets import * 18 | from semseg.augmentations_mm import get_train_augmentation, get_val_augmentation 19 | from semseg.losses import get_loss 20 | from semseg.schedulers import get_scheduler 21 | from semseg.optimizers import get_optimizer 22 | from semseg.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp, get_logger, cal_flops, print_iou 23 | from tools.val_mm import evaluate 24 | import wandb 25 | from semseg.metrics import Metrics 26 | import gc 27 | 28 | 29 | def main(cfg, save_dir): 30 | start = time.time() 31 | best_mIoU = 0.0 32 | best_epoch = 0 33 | num_workers = 2 34 | device = torch.device(cfg['DEVICE']) 35 | train_cfg, eval_cfg = cfg['TRAIN'], 
cfg['EVAL'] 36 | dataset_cfg, model_cfg = cfg['DATASET'], cfg['MODEL'] 37 | loss_cfg, optim_cfg, sched_cfg = cfg['LOSS'], cfg['OPTIMIZER'], cfg['SCHEDULER'] 38 | epochs, lr = train_cfg['EPOCHS'], optim_cfg['LR'] 39 | resume_path = cfg['MODEL']['RESUME'] 40 | gpus = cfg['GPUs'] 41 | use_wandb = cfg['USE_WANDB'] 42 | wandb_name = cfg['WANDB_NAME'] 43 | # gpus = int(os.environ['WORLD_SIZE']) 44 | 45 | traintransform = get_train_augmentation(train_cfg['IMAGE_SIZE'], seg_fill=dataset_cfg['IGNORE_LABEL']) 46 | valtransform = get_val_augmentation(eval_cfg['IMAGE_SIZE']) 47 | 48 | trainset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'train', traintransform, dataset_cfg['MODALS']) 49 | valset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'val', valtransform, dataset_cfg['MODALS']) 50 | class_names = trainset.CLASSES 51 | 52 | model = eval(model_cfg['NAME'])(model_cfg['BACKBONE'], trainset.n_classes, dataset_cfg['MODALS']) 53 | resume_checkpoint = None 54 | if os.path.isfile(resume_path): 55 | resume_checkpoint = torch.load(resume_path, map_location=torch.device('cpu')) 56 | msg = model.load_state_dict(resume_checkpoint['model_state_dict']) 57 | # print(msg) 58 | logger.info(msg) 59 | else: 60 | model.init_pretrained(model_cfg['PRETRAINED']) 61 | 62 | model = torch.nn.DataParallel(model, device_ids=cfg['GPU_IDs']) 63 | model = model.to(device) 64 | 65 | iters_per_epoch = len(trainset) // train_cfg['BATCH_SIZE'] 66 | loss_fn = get_loss(loss_cfg['NAME'], trainset.ignore_label, None) 67 | start_epoch = 0 68 | optimizer = get_optimizer(model, optim_cfg['NAME'], lr, optim_cfg['WEIGHT_DECAY']) 69 | scheduler = get_scheduler(sched_cfg['NAME'], optimizer, int((epochs+1)*iters_per_epoch), sched_cfg['POWER'], iters_per_epoch * sched_cfg['WARMUP'], sched_cfg['WARMUP_RATIO']) 70 | 71 | if train_cfg['DDP']: 72 | sampler = DistributedSampler(trainset, dist.get_world_size(), dist.get_rank(), shuffle=True) 73 | sampler_val = None 74 | model = DDP(model, device_ids=[gpu], output_device=0, find_unused_parameters=True) 75 | else: 76 | sampler = RandomSampler(trainset) 77 | sampler_val = None 78 | 79 | if resume_checkpoint: 80 | start_epoch = resume_checkpoint['epoch'] - 1 81 | optimizer.load_state_dict(resume_checkpoint['optimizer_state_dict']) 82 | scheduler.load_state_dict(resume_checkpoint['scheduler_state_dict']) 83 | loss = resume_checkpoint['loss'] 84 | best_mIoU = resume_checkpoint['best_miou'] 85 | del resume_checkpoint 86 | 87 | trainloader = DataLoader(trainset, batch_size=train_cfg['BATCH_SIZE'], num_workers=num_workers, drop_last=True, pin_memory=False, sampler=sampler) 88 | valloader = DataLoader(valset, batch_size=eval_cfg['BATCH_SIZE'], num_workers=num_workers, pin_memory=False, sampler=sampler_val) 89 | 90 | scaler = GradScaler(enabled=train_cfg['AMP']) 91 | if (train_cfg['DDP'] and torch.distributed.get_rank() == 0) or (not train_cfg['DDP']): 92 | writer = SummaryWriter(str(save_dir)) 93 | logger.info('================== model complexity =====================') 94 | cal_flops(model, dataset_cfg['MODALS'], logger) 95 | logger.info('================== model structure =====================') 96 | logger.info(model) 97 | logger.info('================== training config =====================') 98 | logger.info(cfg) 99 | logger.info('================== parameter count =====================') 100 | logger.info(sum(p.numel() for p in model.parameters() if p.requires_grad)) 101 | 102 | for epoch in range(start_epoch, epochs): 103 | # Clean Memory 104 | torch.cuda.empty_cache() 105 | gc.collect() 106 | 
107 | model.train() 108 | if train_cfg['DDP']: sampler.set_epoch(epoch) 109 | 110 | train_loss = 0.0 111 | lr = scheduler.get_lr() 112 | lr = sum(lr) / len(lr) 113 | pbar = tqdm(enumerate(trainloader), total=iters_per_epoch, desc=f"Epoch: [{epoch+1}/{epochs}] Iter: [{0}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss:.8f}") 114 | metrics = Metrics(trainset.n_classes, trainloader.dataset.ignore_label, device) 115 | 116 | for iter, (sample, lbl) in pbar: 117 | optimizer.zero_grad(set_to_none=True) 118 | sample = [x.to(device) for x in sample] 119 | lbl = lbl.to(device) 120 | 121 | with autocast(enabled=train_cfg['AMP']): 122 | logits = model(sample) 123 | loss = loss_fn(logits, lbl) 124 | 125 | metrics.update(logits.softmax(dim=1), lbl) 126 | 127 | scaler.scale(loss).backward() 128 | scaler.step(optimizer) 129 | scaler.update() 130 | scheduler.step() 131 | torch.cuda.synchronize() 132 | 133 | lr = scheduler.get_lr() 134 | lr = sum(lr) / len(lr) 135 | if lr <= 1e-8: 136 | lr = 1e-8 # minimum of lr 137 | train_loss += loss.item() 138 | 139 | # Clean Memory 140 | torch.cuda.empty_cache() 141 | gc.collect() 142 | 143 | pbar.set_description(f"Epoch: [{epoch+1}/{epochs}] Iter: [{iter+1}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss / (iter+1):.8f}") 144 | 145 | train_loss /= iter+1 146 | if (train_cfg['DDP'] and torch.distributed.get_rank() == 0) or (not train_cfg['DDP']): 147 | writer.add_scalar('train/loss', train_loss, epoch) 148 | 149 | ious, miou = metrics.compute_iou() 150 | acc, macc = metrics.compute_pixel_acc() 151 | f1, mf1 = metrics.compute_f1() 152 | 153 | # if use_wandb: 154 | train_log_data = { 155 | "Epoch": epoch+1, 156 | "Train Loss": train_loss, 157 | "Train mIoU": miou, 158 | "Train Pixel Acc": macc, 159 | "Train F1": mf1, 160 | } 161 | 162 | if ((epoch+1) % train_cfg['EVAL_INTERVAL'] == 0 and (epoch+1)>train_cfg['EVAL_START']) or (epoch+1) == epochs: 163 | if (train_cfg['DDP'] and torch.distributed.get_rank() == 0) or (not train_cfg['DDP']): 164 | acc, macc, f1, mf1, ious, miou, test_loss = evaluate(model, valloader, device, loss_fn=loss_fn) 165 | writer.add_scalar('val/mIoU', miou, epoch) 166 | 167 | # if use wandb 168 | log_data = { 169 | "Test Loss": test_loss, 170 | "Test mIoU": miou, 171 | "Test Pixel Acc": macc, 172 | "Test F1": mf1, 173 | } 174 | log_data.update(train_log_data) 175 | print(log_data) 176 | if use_wandb: 177 | wandb.log(log_data) 178 | 179 | if miou > best_mIoU: 180 | prev_best_ckp = save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}_checkpoint.pth" 181 | prev_best = save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}.pth" 182 | if os.path.isfile(prev_best): os.remove(prev_best) 183 | if os.path.isfile(prev_best_ckp): os.remove(prev_best_ckp) 184 | best_mIoU = miou 185 | best_epoch = epoch+1 186 | cur_best_ckp = save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}_checkpoint.pth" 187 | cur_best = save_dir / f"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}.pth" 188 | # torch.save(model.module.state_dict() if train_cfg['DDP'] else model.state_dict(), cur_best) 189 | torch.save(model.module.state_dict(), cur_best) 190 | # --- 191 | torch.save({'epoch': best_epoch, 192 | 'model_state_dict': model.module.state_dict() if train_cfg['DDP'] else model.state_dict(), 193 | 'optimizer_state_dict': optimizer.state_dict(), 194 | 'loss': train_loss, 
195 | 'scheduler_state_dict': scheduler.state_dict(), 196 | 'best_miou': best_mIoU, 197 | }, cur_best_ckp) 198 | logger.info(print_iou(epoch, ious, miou, acc, macc, class_names)) 199 | logger.info(f"Current epoch:{epoch} mIoU: {miou} Best mIoU: {best_mIoU}") 200 | 201 | if (train_cfg['DDP'] and torch.distributed.get_rank() == 0) or (not train_cfg['DDP']): 202 | writer.close() 203 | pbar.close() 204 | end = time.gmtime(time.time() - start) 205 | 206 | table = [ 207 | ['Best mIoU', f"{best_mIoU:.2f}"], 208 | ['Total Training Time', time.strftime("%H:%M:%S", end)] 209 | ] 210 | logger.info(tabulate(table, numalign='right')) 211 | 212 | 213 | if __name__ == '__main__': 214 | parser = argparse.ArgumentParser() 215 | parser.add_argument('--cfg', type=str, default='configs/mcubes_rgbadn.yaml', help='Configuration file to use') 216 | args = parser.parse_args() 217 | 218 | with open(args.cfg) as f: 219 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 220 | 221 | fix_seeds(3407) 222 | setup_cudnn() 223 | # gpu = setup_ddp() 224 | modals = ''.join([m[0] for m in cfg['DATASET']['MODALS']]) 225 | model = cfg['MODEL']['BACKBONE'] 226 | # exp_name = '_'.join([cfg['DATASET']['NAME'], model, modals]) 227 | exp_name = cfg['WANDB_NAME'] 228 | if cfg['USE_WANDB']: 229 | wandb.init(project="ProjcetName", entity="EntityName", name=exp_name) 230 | 231 | save_dir = Path(cfg['SAVE_DIR'], exp_name) 232 | if os.path.isfile(cfg['MODEL']['RESUME']): 233 | save_dir = Path(os.path.dirname(cfg['MODEL']['RESUME'])) 234 | os.makedirs(save_dir, exist_ok=True) 235 | logger = get_logger(save_dir / 'train.log') 236 | main(cfg, save_dir) 237 | cleanup_ddp() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2023] [Jiaming Zhang] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /semseg/augmentations.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms.functional as TF 2 | import random 3 | import math 4 | import torch 5 | from torch import Tensor 6 | from typing import Tuple, List, Union, Tuple, Optional 7 | 8 | 9 | class Compose: 10 | def __init__(self, transforms: list) -> None: 11 | self.transforms = transforms 12 | 13 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 14 | if mask.ndim == 2: 15 | assert img.shape[1:] == mask.shape 16 | else: 17 | assert img.shape[1:] == mask.shape[1:] 18 | 19 | for transform in self.transforms: 20 | img, mask = transform(img, mask) 21 | 22 | return img, mask 23 | 24 | 25 | class Normalize: 26 | def __init__(self, mean: list = (0.485, 0.456, 0.406), std: list = (0.229, 0.224, 0.225)): 27 | self.mean = mean 28 | self.std = std 29 | 30 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 31 | img = img.float() 32 | img /= 255 33 | img = TF.normalize(img, self.mean, self.std) 34 | return img, mask 35 | 36 | 37 | class ColorJitter: 38 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0) -> None: 39 | self.brightness = brightness 40 | self.contrast = contrast 41 | self.saturation = saturation 42 | self.hue = hue 43 | 44 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 45 | if self.brightness > 0: 46 | img = TF.adjust_brightness(img, self.brightness) 47 | if self.contrast > 0: 48 | img = TF.adjust_contrast(img, self.contrast) 49 | if self.saturation > 0: 50 | img = TF.adjust_saturation(img, self.saturation) 51 | if self.hue > 0: 52 | img = TF.adjust_hue(img, self.hue) 53 | return img, mask 54 | 55 | 56 | class AdjustGamma: 57 | def __init__(self, gamma: float, gain: float = 1) -> None: 58 | """ 59 | Args: 60 | gamma: Non-negative real number. gamma larger than 1 make the shadows darker, while gamma smaller than 1 make dark regions lighter. 
61 | gain: constant multiplier 62 | """ 63 | self.gamma = gamma 64 | self.gain = gain 65 | 66 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 67 | return TF.adjust_gamma(img, self.gamma, self.gain), mask 68 | 69 | 70 | class RandomAdjustSharpness: 71 | def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: 72 | self.sharpness = sharpness_factor 73 | self.p = p 74 | 75 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 76 | if random.random() < self.p: 77 | img = TF.adjust_sharpness(img, self.sharpness) 78 | return img, mask 79 | 80 | 81 | class RandomAutoContrast: 82 | def __init__(self, p: float = 0.5) -> None: 83 | self.p = p 84 | 85 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 86 | if random.random() < self.p: 87 | img = TF.autocontrast(img) 88 | return img, mask 89 | 90 | 91 | class RandomGaussianBlur: 92 | def __init__(self, kernel_size: int = 3, p: float = 0.5) -> None: 93 | self.kernel_size = kernel_size 94 | self.p = p 95 | 96 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 97 | if random.random() < self.p: 98 | img = TF.gaussian_blur(img, self.kernel_size) 99 | return img, mask 100 | 101 | 102 | class RandomHorizontalFlip: 103 | def __init__(self, p: float = 0.5) -> None: 104 | self.p = p 105 | 106 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 107 | if random.random() < self.p: 108 | return TF.hflip(img), TF.hflip(mask) 109 | return img, mask 110 | 111 | 112 | class RandomVerticalFlip: 113 | def __init__(self, p: float = 0.5) -> None: 114 | self.p = p 115 | 116 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 117 | if random.random() < self.p: 118 | return TF.vflip(img), TF.vflip(mask) 119 | return img, mask 120 | 121 | 122 | class RandomGrayscale: 123 | def __init__(self, p: float = 0.5) -> None: 124 | self.p = p 125 | 126 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 127 | if random.random() < self.p: 128 | img = TF.rgb_to_grayscale(img, 3) 129 | return img, mask 130 | 131 | 132 | class Equalize: 133 | def __call__(self, image, label): 134 | return TF.equalize(image), label 135 | 136 | 137 | class Posterize: 138 | def __init__(self, bits=2): 139 | self.bits = bits # 0-8 140 | 141 | def __call__(self, image, label): 142 | return TF.posterize(image, self.bits), label 143 | 144 | 145 | class Affine: 146 | def __init__(self, angle=0, translate=[0, 0], scale=1.0, shear=[0, 0], seg_fill=0): 147 | self.angle = angle 148 | self.translate = translate 149 | self.scale = scale 150 | self.shear = shear 151 | self.seg_fill = seg_fill 152 | 153 | def __call__(self, img, label): 154 | return TF.affine(img, self.angle, self.translate, self.scale, self.shear, TF.InterpolationMode.BILINEAR, 0), TF.affine(label, self.angle, self.translate, self.scale, self.shear, TF.InterpolationMode.NEAREST, self.seg_fill) 155 | 156 | 157 | class RandomRotation: 158 | def __init__(self, degrees: float = 10.0, p: float = 0.2, seg_fill: int = 0, expand: bool = False) -> None: 159 | """Rotate the image by a random angle between -angle and angle with probability p 160 | 161 | Args: 162 | p: probability 163 | angle: rotation angle value in degrees, counter-clockwise. 164 | expand: Optional expansion flag. 165 | If true, expands the output image to make it large enough to hold the entire rotated image. 166 | If false or omitted, make the output image the same size as the input image. 
167 | Note that the expand flag assumes rotation around the center and no translation. 168 | """ 169 | self.p = p 170 | self.angle = degrees 171 | self.expand = expand 172 | self.seg_fill = seg_fill 173 | 174 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 175 | random_angle = random.random() * 2 * self.angle - self.angle 176 | if random.random() < self.p: 177 | img = TF.rotate(img, random_angle, TF.InterpolationMode.BILINEAR, self.expand, fill=0) 178 | mask = TF.rotate(mask, random_angle, TF.InterpolationMode.NEAREST, self.expand, fill=self.seg_fill) 179 | return img, mask 180 | 181 | 182 | class CenterCrop: 183 | def __init__(self, size: Union[int, List[int], Tuple[int]]) -> None: 184 | """Crops the image at the center 185 | 186 | Args: 187 | output_size: height and width of the crop box. If int, this size is used for both directions. 188 | """ 189 | self.size = (size, size) if isinstance(size, int) else size 190 | 191 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 192 | return TF.center_crop(img, self.size), TF.center_crop(mask, self.size) 193 | 194 | 195 | class RandomCrop: 196 | def __init__(self, size: Union[int, List[int], Tuple[int]], p: float = 0.5) -> None: 197 | """Randomly Crops the image. 198 | 199 | Args: 200 | output_size: height and width of the crop box. If int, this size is used for both directions. 201 | """ 202 | self.size = (size, size) if isinstance(size, int) else size 203 | self.p = p 204 | 205 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 206 | H, W = img.shape[1:] 207 | tH, tW = self.size 208 | 209 | if random.random() < self.p: 210 | margin_h = max(H - tH, 0) 211 | margin_w = max(W - tW, 0) 212 | y1 = random.randint(0, margin_h+1) 213 | x1 = random.randint(0, margin_w+1) 214 | y2 = y1 + tH 215 | x2 = x1 + tW 216 | img = img[:, y1:y2, x1:x2] 217 | mask = mask[:, y1:y2, x1:x2] 218 | return img, mask 219 | 220 | 221 | class Pad: 222 | def __init__(self, size: Union[List[int], Tuple[int], int], seg_fill: int = 0) -> None: 223 | """Pad the given image on all sides with the given "pad" value. 224 | Args: 225 | size: expected output image size (h, w) 226 | fill: Pixel fill value for constant fill. Default is 0. This value is only used when the padding mode is constant. 227 | """ 228 | self.size = size 229 | self.seg_fill = seg_fill 230 | 231 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 232 | padding = (0, 0, self.size[1]-img.shape[2], self.size[0]-img.shape[1]) 233 | return TF.pad(img, padding), TF.pad(mask, padding, self.seg_fill) 234 | 235 | 236 | class ResizePad: 237 | def __init__(self, size: Union[int, Tuple[int], List[int]], seg_fill: int = 0) -> None: 238 | """Resize the input image to the given size. 239 | Args: 240 | size: Desired output size. 241 | If size is a sequence, the output size will be matched to this. 242 | If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio. 
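            seg_fill: fill value used for the segmentation mask when padding up to the target size.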
243 | """ 244 | self.size = size 245 | self.seg_fill = seg_fill 246 | 247 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 248 | H, W = img.shape[1:] 249 | tH, tW = self.size 250 | 251 | # scale the image 252 | scale_factor = min(tH/H, tW/W) if W > H else max(tH/H, tW/W) 253 | # nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5) 254 | nH, nW = round(H*scale_factor), round(W*scale_factor) 255 | img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR) 256 | mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST) 257 | 258 | # pad the image 259 | padding = [0, 0, tW - nW, tH - nH] 260 | img = TF.pad(img, padding, fill=0) 261 | mask = TF.pad(mask, padding, fill=self.seg_fill) 262 | return img, mask 263 | 264 | 265 | class Resize: 266 | def __init__(self, size: Union[int, Tuple[int], List[int]]) -> None: 267 | """Resize the input image to the given size. 268 | Args: 269 | size: Desired output size. 270 | If size is a sequence, the output size will be matched to this. 271 | If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio. 272 | """ 273 | self.size = size 274 | 275 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 276 | H, W = img.shape[1:] 277 | 278 | # scale the image 279 | scale_factor = self.size[0] / min(H, W) 280 | nH, nW = round(H*scale_factor), round(W*scale_factor) 281 | img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR) 282 | mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST) 283 | 284 | # make the image divisible by stride 285 | alignH, alignW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32 286 | img = TF.resize(img, (alignH, alignW), TF.InterpolationMode.BILINEAR) 287 | mask = TF.resize(mask, (alignH, alignW), TF.InterpolationMode.NEAREST) 288 | return img, mask 289 | 290 | 291 | class RandomResizedCrop: 292 | def __init__(self, size: Union[int, Tuple[int], List[int]], scale: Tuple[float, float] = (0.5, 2.0), seg_fill: int = 0) -> None: 293 | """Resize the input image to the given size. 
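        Args:
            size: target (height, width) of the output crop.
            scale: range from which the random resize ratio is sampled.
            seg_fill: fill value used for the mask if padding is needed after cropping.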
294 | """ 295 | self.size = size 296 | self.scale = scale 297 | self.seg_fill = seg_fill 298 | 299 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 300 | H, W = img.shape[1:] 301 | tH, tW = self.size 302 | 303 | # get the scale 304 | ratio = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0] 305 | # ratio = random.uniform(min(self.scale), max(self.scale)) 306 | scale = int(tH*ratio), int(tW*4*ratio) 307 | 308 | # scale the image 309 | scale_factor = min(max(scale)/max(H, W), min(scale)/min(H, W)) 310 | nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5) 311 | # nH, nW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32 312 | img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR) 313 | mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST) 314 | 315 | # random crop 316 | margin_h = max(img.shape[1] - tH, 0) 317 | margin_w = max(img.shape[2] - tW, 0) 318 | y1 = random.randint(0, margin_h+1) 319 | x1 = random.randint(0, margin_w+1) 320 | y2 = y1 + tH 321 | x2 = x1 + tW 322 | img = img[:, y1:y2, x1:x2] 323 | mask = mask[:, y1:y2, x1:x2] 324 | 325 | # pad the image 326 | if img.shape[1:] != self.size: 327 | padding = [0, 0, tW - img.shape[2], tH - img.shape[1]] 328 | img = TF.pad(img, padding, fill=0) 329 | mask = TF.pad(mask, padding, fill=self.seg_fill) 330 | return img, mask 331 | 332 | 333 | 334 | def get_train_augmentation(size: Union[int, Tuple[int], List[int]], seg_fill: int = 0): 335 | return Compose([ 336 | # ColorJitter(brightness=0.0, contrast=0.5, saturation=0.5, hue=0.5), 337 | # RandomAdjustSharpness(sharpness_factor=0.1, p=0.5), 338 | # RandomAutoContrast(p=0.2), 339 | RandomHorizontalFlip(p=0.5), 340 | # RandomVerticalFlip(p=0.5), 341 | # RandomGaussianBlur((3, 3), p=0.5), 342 | # RandomGrayscale(p=0.5), 343 | # RandomRotation(degrees=10, p=0.3, seg_fill=seg_fill), 344 | RandomResizedCrop(size, scale=(0.5, 2.0), seg_fill=seg_fill), 345 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 346 | ]) 347 | 348 | def get_val_augmentation(size: Union[int, Tuple[int], List[int]]): 349 | return Compose([ 350 | Resize(size), 351 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 352 | ]) 353 | 354 | 355 | if __name__ == '__main__': 356 | h = 230 357 | w = 420 358 | img = torch.randn(3, h, w) 359 | mask = torch.randn(1, h, w) 360 | aug = Compose([ 361 | RandomResizedCrop((512, 512)), 362 | # RandomCrop((512, 512), p=1.0), 363 | # Pad((512, 512)) 364 | ]) 365 | img, mask = aug(img, mask) 366 | print(img.shape, mask.shape) -------------------------------------------------------------------------------- /semseg/augmentations_mm.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms.functional as TF 2 | import random 3 | import math 4 | import torch 5 | from torch import Tensor 6 | from typing import Tuple, List, Union, Tuple, Optional 7 | 8 | 9 | class Compose: 10 | def __init__(self, transforms: list) -> None: 11 | self.transforms = transforms 12 | 13 | def __call__(self, sample: list) -> list: 14 | img, mask = sample['img'], sample['mask'] 15 | if mask.ndim == 2: 16 | assert img.shape[1:] == mask.shape 17 | else: 18 | assert img.shape[1:] == mask.shape[1:] 19 | 20 | for transform in self.transforms: 21 | sample = transform(sample) 22 | 23 | return sample 24 | 25 | 26 | class Normalize: 27 | def __init__(self, mean: list = (0.485, 0.456, 0.406), std: list = (0.229, 0.224, 0.225)): 28 | self.mean = mean 29 | self.std = std 30 | 31 | def 
__call__(self, sample: list) -> list: 32 | for k, v in sample.items(): 33 | if k == 'mask': 34 | continue 35 | elif k == 'img': 36 | sample[k] = sample[k].float() 37 | sample[k] /= 255 38 | sample[k] = TF.normalize(sample[k], self.mean, self.std) 39 | else: 40 | sample[k] = sample[k].float() 41 | sample[k] /= 255 42 | 43 | return sample 44 | 45 | 46 | class RandomColorJitter: 47 | def __init__(self, p=0.5) -> None: 48 | self.p = p 49 | 50 | def __call__(self, sample: list) -> list: 51 | if random.random() < self.p: 52 | self.brightness = random.uniform(0.5, 1.5) 53 | sample['img'] = TF.adjust_brightness(sample['img'], self.brightness) 54 | self.contrast = random.uniform(0.5, 1.5) 55 | sample['img'] = TF.adjust_contrast(sample['img'], self.contrast) 56 | self.saturation = random.uniform(0.5, 1.5) 57 | sample['img'] = TF.adjust_saturation(sample['img'], self.saturation) 58 | return sample 59 | 60 | 61 | class AdjustGamma: 62 | def __init__(self, gamma: float, gain: float = 1) -> None: 63 | """ 64 | Args: 65 | gamma: Non-negative real number. gamma larger than 1 make the shadows darker, while gamma smaller than 1 make dark regions lighter. 66 | gain: constant multiplier 67 | """ 68 | self.gamma = gamma 69 | self.gain = gain 70 | 71 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 72 | return TF.adjust_gamma(img, self.gamma, self.gain), mask 73 | 74 | 75 | class RandomAdjustSharpness: 76 | def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: 77 | self.sharpness = sharpness_factor 78 | self.p = p 79 | 80 | def __call__(self, sample: list) -> list: 81 | if random.random() < self.p: 82 | sample['img'] = TF.adjust_sharpness(sample['img'], self.sharpness) 83 | return sample 84 | 85 | 86 | class RandomAutoContrast: 87 | def __init__(self, p: float = 0.5) -> None: 88 | self.p = p 89 | 90 | def __call__(self, sample: list) -> list: 91 | if random.random() < self.p: 92 | sample['img'] = TF.autocontrast(sample['img']) 93 | return sample 94 | 95 | 96 | class RandomGaussianBlur: 97 | def __init__(self, kernel_size: int = 3, p: float = 0.5) -> None: 98 | self.kernel_size = kernel_size 99 | self.p = p 100 | 101 | def __call__(self, sample: list) -> list: 102 | if random.random() < self.p: 103 | sample['img'] = TF.gaussian_blur(sample['img'], self.kernel_size) 104 | # img = TF.gaussian_blur(img, self.kernel_size) 105 | return sample 106 | 107 | 108 | class RandomHorizontalFlip: 109 | def __init__(self, p: float = 0.5) -> None: 110 | self.p = p 111 | 112 | def __call__(self, sample: list) -> list: 113 | if random.random() < self.p: 114 | for k, v in sample.items(): 115 | sample[k] = TF.hflip(v) 116 | return sample 117 | return sample 118 | 119 | 120 | class RandomVerticalFlip: 121 | def __init__(self, p: float = 0.5) -> None: 122 | self.p = p 123 | 124 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 125 | if random.random() < self.p: 126 | return TF.vflip(img), TF.vflip(mask) 127 | return img, mask 128 | 129 | 130 | class RandomGrayscale: 131 | def __init__(self, p: float = 0.5) -> None: 132 | self.p = p 133 | 134 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 135 | if random.random() < self.p: 136 | img = TF.rgb_to_grayscale(img, 3) 137 | return img, mask 138 | 139 | 140 | class Equalize: 141 | def __call__(self, image, label): 142 | return TF.equalize(image), label 143 | 144 | 145 | class Posterize: 146 | def __init__(self, bits=2): 147 | self.bits = bits # 0-8 148 | 149 | def __call__(self, image, label): 150 | 
return TF.posterize(image, self.bits), label 151 | 152 | 153 | class Affine: 154 | def __init__(self, angle=0, translate=[0, 0], scale=1.0, shear=[0, 0], seg_fill=0): 155 | self.angle = angle 156 | self.translate = translate 157 | self.scale = scale 158 | self.shear = shear 159 | self.seg_fill = seg_fill 160 | 161 | def __call__(self, img, label): 162 | return TF.affine(img, self.angle, self.translate, self.scale, self.shear, TF.InterpolationMode.BILINEAR, 0), TF.affine(label, self.angle, self.translate, self.scale, self.shear, TF.InterpolationMode.NEAREST, self.seg_fill) 163 | 164 | 165 | class RandomRotation: 166 | def __init__(self, degrees: float = 10.0, p: float = 0.2, seg_fill: int = 0, expand: bool = False) -> None: 167 | """Rotate the image by a random angle between -angle and angle with probability p 168 | 169 | Args: 170 | p: probability 171 | angle: rotation angle value in degrees, counter-clockwise. 172 | expand: Optional expansion flag. 173 | If true, expands the output image to make it large enough to hold the entire rotated image. 174 | If false or omitted, make the output image the same size as the input image. 175 | Note that the expand flag assumes rotation around the center and no translation. 176 | """ 177 | self.p = p 178 | self.angle = degrees 179 | self.expand = expand 180 | self.seg_fill = seg_fill 181 | 182 | def __call__(self, sample: list) -> list: 183 | random_angle = random.random() * 2 * self.angle - self.angle 184 | if random.random() < self.p: 185 | for k, v in sample.items(): 186 | if k == 'mask': 187 | sample[k] = TF.rotate(v, random_angle, TF.InterpolationMode.NEAREST, self.expand, fill=self.seg_fill) 188 | else: 189 | sample[k] = TF.rotate(v, random_angle, TF.InterpolationMode.BILINEAR, self.expand, fill=0) 190 | # img = TF.rotate(img, random_angle, TF.InterpolationMode.BILINEAR, self.expand, fill=0) 191 | # mask = TF.rotate(mask, random_angle, TF.InterpolationMode.NEAREST, self.expand, fill=self.seg_fill) 192 | return sample 193 | 194 | 195 | class CenterCrop: 196 | def __init__(self, size: Union[int, List[int], Tuple[int]]) -> None: 197 | """Crops the image at the center 198 | 199 | Args: 200 | output_size: height and width of the crop box. If int, this size is used for both directions. 201 | """ 202 | self.size = (size, size) if isinstance(size, int) else size 203 | 204 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 205 | return TF.center_crop(img, self.size), TF.center_crop(mask, self.size) 206 | 207 | 208 | class RandomCrop: 209 | def __init__(self, size: Union[int, List[int], Tuple[int]], p: float = 0.5) -> None: 210 | """Randomly Crops the image. 211 | 212 | Args: 213 | output_size: height and width of the crop box. If int, this size is used for both directions. 214 | """ 215 | self.size = (size, size) if isinstance(size, int) else size 216 | self.p = p 217 | 218 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 219 | H, W = img.shape[1:] 220 | tH, tW = self.size 221 | 222 | if random.random() < self.p: 223 | margin_h = max(H - tH, 0) 224 | margin_w = max(W - tW, 0) 225 | y1 = random.randint(0, margin_h+1) 226 | x1 = random.randint(0, margin_w+1) 227 | y2 = y1 + tH 228 | x2 = x1 + tW 229 | img = img[:, y1:y2, x1:x2] 230 | mask = mask[:, y1:y2, x1:x2] 231 | return img, mask 232 | 233 | 234 | class Pad: 235 | def __init__(self, size: Union[List[int], Tuple[int], int], seg_fill: int = 0) -> None: 236 | """Pad the given image on all sides with the given "pad" value. 
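        Padding is added only to the right and bottom edges so that the output matches `size`.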
237 | Args: 238 | size: expected output image size (h, w) 239 | fill: Pixel fill value for constant fill. Default is 0. This value is only used when the padding mode is constant. 240 | """ 241 | self.size = size 242 | self.seg_fill = seg_fill 243 | 244 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 245 | padding = (0, 0, self.size[1]-img.shape[2], self.size[0]-img.shape[1]) 246 | return TF.pad(img, padding), TF.pad(mask, padding, self.seg_fill) 247 | 248 | 249 | class ResizePad: 250 | def __init__(self, size: Union[int, Tuple[int], List[int]], seg_fill: int = 0) -> None: 251 | """Resize the input image to the given size. 252 | Args: 253 | size: Desired output size. 254 | If size is a sequence, the output size will be matched to this. 255 | If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio. 256 | """ 257 | self.size = size 258 | self.seg_fill = seg_fill 259 | 260 | def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: 261 | H, W = img.shape[1:] 262 | tH, tW = self.size 263 | 264 | # scale the image 265 | scale_factor = min(tH/H, tW/W) if W > H else max(tH/H, tW/W) 266 | # nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5) 267 | nH, nW = round(H*scale_factor), round(W*scale_factor) 268 | img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR) 269 | mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST) 270 | 271 | # pad the image 272 | padding = [0, 0, tW - nW, tH - nH] 273 | img = TF.pad(img, padding, fill=0) 274 | mask = TF.pad(mask, padding, fill=self.seg_fill) 275 | return img, mask 276 | 277 | 278 | class Resize: 279 | def __init__(self, size: Union[int, Tuple[int], List[int]]) -> None: 280 | """Resize the input image to the given size. 281 | Args: 282 | size: Desired output size. 283 | If size is a sequence, the output size will be matched to this. 284 | If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio. 285 | """ 286 | self.size = size 287 | 288 | def __call__(self, sample:list) -> list: 289 | H, W = sample['img'].shape[1:] 290 | 291 | # scale the image 292 | scale_factor = self.size[0] / min(H, W) 293 | nH, nW = round(H*scale_factor), round(W*scale_factor) 294 | for k, v in sample.items(): 295 | if k == 'mask': 296 | sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.NEAREST) 297 | else: 298 | sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.BILINEAR) 299 | # img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR) 300 | # mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST) 301 | 302 | # make the image divisible by stride 303 | alignH, alignW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32 304 | 305 | for k, v in sample.items(): 306 | if k == 'mask': 307 | sample[k] = TF.resize(v, (alignH, alignW), TF.InterpolationMode.NEAREST) 308 | else: 309 | sample[k] = TF.resize(v, (alignH, alignW), TF.InterpolationMode.BILINEAR) 310 | # img = TF.resize(img, (alignH, alignW), TF.InterpolationMode.BILINEAR) 311 | # mask = TF.resize(mask, (alignH, alignW), TF.InterpolationMode.NEAREST) 312 | return sample 313 | 314 | 315 | class RandomResizedCrop: 316 | def __init__(self, size: Union[int, Tuple[int], List[int]], scale: Tuple[float, float] = (0.5, 2.0), seg_fill: int = 0) -> None: 317 | """Resize the input image to the given size. 
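        Applied to every entry in the sample dict (image, extra modalities and mask).
        Args:
            size: target (height, width) of the output crop.
            scale: range from which the random resize ratio is sampled.
            seg_fill: fill value used for the mask if padding is needed after cropping.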
318 | """ 319 | self.size = size 320 | self.scale = scale 321 | self.seg_fill = seg_fill 322 | 323 | def __call__(self, sample: list) -> list: 324 | # img, mask = sample['img'], sample['mask'] 325 | H, W = sample['img'].shape[1:] 326 | tH, tW = self.size 327 | 328 | # get the scale 329 | ratio = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0] 330 | # ratio = random.uniform(min(self.scale), max(self.scale)) 331 | scale = int(tH*ratio), int(tW*4*ratio) 332 | # scale the image 333 | scale_factor = min(max(scale)/max(H, W), min(scale)/min(H, W)) 334 | nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5) 335 | # nH, nW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32 336 | for k, v in sample.items(): 337 | if k == 'mask': 338 | sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.NEAREST) 339 | else: 340 | sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.BILINEAR) 341 | 342 | # random crop 343 | margin_h = max(sample['img'].shape[1] - tH, 0) 344 | margin_w = max(sample['img'].shape[2] - tW, 0) 345 | y1 = random.randint(0, margin_h+1) 346 | x1 = random.randint(0, margin_w+1) 347 | y2 = y1 + tH 348 | x2 = x1 + tW 349 | for k, v in sample.items(): 350 | sample[k] = v[:, y1:y2, x1:x2] 351 | 352 | # pad the image 353 | if sample['img'].shape[1:] != self.size: 354 | padding = [0, 0, tW - sample['img'].shape[2], tH - sample['img'].shape[1]] 355 | for k, v in sample.items(): 356 | if k == 'mask': 357 | sample[k] = TF.pad(v, padding, fill=self.seg_fill) 358 | else: 359 | sample[k] = TF.pad(v, padding, fill=0) 360 | 361 | return sample 362 | 363 | 364 | 365 | def get_train_augmentation(size: Union[int, Tuple[int], List[int]], seg_fill: int = 0): 366 | return Compose([ 367 | RandomColorJitter(p=0.2), # 368 | RandomHorizontalFlip(p=0.5), # 369 | RandomGaussianBlur((3, 3), p=0.2), # 370 | RandomResizedCrop(size, scale=(0.5, 2.0), seg_fill=seg_fill), # 371 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 372 | ]) 373 | 374 | def get_val_augmentation(size: Union[int, Tuple[int], List[int]]): 375 | return Compose([ 376 | Resize(size), 377 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 378 | ]) 379 | 380 | 381 | if __name__ == '__main__': 382 | h = 230 383 | w = 420 384 | sample = {} 385 | sample['img'] = torch.randn(3, h, w) 386 | sample['depth'] = torch.randn(3, h, w) 387 | sample['lidar'] = torch.randn(3, h, w) 388 | sample['event'] = torch.randn(3, h, w) 389 | sample['mask'] = torch.randn(1, h, w) 390 | aug = Compose([ 391 | RandomHorizontalFlip(p=0.5), 392 | RandomResizedCrop((512, 512)), 393 | Resize((224, 224)), 394 | Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 395 | ]) 396 | sample = aug(sample) 397 | for k, v in sample.items(): 398 | print(k, v.shape) -------------------------------------------------------------------------------- /semseg/models/backbones/mmsformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | from semseg.models.layers import DropPath 5 | import torch.nn.init as init 6 | 7 | 8 | class ChannelAttentionBlock(nn.Module): 9 | def __init__(self, channel, reduction=16): 10 | super(ChannelAttentionBlock, self).__init__() 11 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 12 | self.fc = nn.Sequential( 13 | nn.Linear(channel, channel // reduction, bias=False), 14 | nn.ReLU(inplace=True), 15 | nn.Linear(channel // reduction, channel, bias=False), 16 | nn.Sigmoid() 17 | ) 18 | 19 | # 
Initialize linear layers with Kaiming initialization 20 | for m in self.fc: 21 | if isinstance(m, nn.Linear): 22 | init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 23 | 24 | def forward(self, x, H, W): 25 | B, _, C = x.shape 26 | x = x.transpose(1, 2).view(B, C, H, W) 27 | b, c, _, _ = x.size() 28 | y = self.avg_pool(x).view(b, c) 29 | y = self.fc(y).view(b, c, 1, 1) 30 | return (x * y.expand_as(x)).flatten(2).transpose(1, 2) 31 | 32 | 33 | class DWConv(nn.Module): 34 | def __init__(self, dim): 35 | super().__init__() 36 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim) 37 | 38 | def forward(self, x: Tensor, H, W) -> Tensor: 39 | B, _, C = x.shape 40 | x = x.transpose(1, 2).view(B, C, H, W) 41 | x = self.dwconv(x) 42 | return x.flatten(2).transpose(1, 2) 43 | 44 | 45 | class CustomDWConv(nn.Module): 46 | def __init__(self, dim, kernel): 47 | super().__init__() 48 | self.dwconv = nn.Conv2d(dim, dim, kernel, 1, padding='same', groups=dim) 49 | 50 | # Apply Kaiming initialization with fan-in to the dwconv layer 51 | init.kaiming_normal_(self.dwconv.weight, mode='fan_in', nonlinearity='relu') 52 | 53 | def forward(self, x: Tensor, H, W) -> Tensor: 54 | B, _, C = x.shape 55 | x = x.transpose(1, 2).view(B, C, H, W) 56 | x = self.dwconv(x) 57 | return x.flatten(2).transpose(1, 2) 58 | 59 | 60 | class CustomPWConv(nn.Module): 61 | def __init__(self, dim): 62 | super().__init__() 63 | self.pwconv = nn.Conv2d(dim, dim, 1) 64 | self.bn = nn.BatchNorm2d(dim) 65 | 66 | # Initialize pwconv layer with Kaiming initialization 67 | init.kaiming_normal_(self.pwconv.weight, mode='fan_in', nonlinearity='relu') 68 | 69 | def forward(self, x: Tensor, H, W) -> Tensor: 70 | B, _, C = x.shape 71 | x = x.transpose(1, 2).view(B, C, H, W) 72 | x = self.bn(self.pwconv(x)) 73 | return x.flatten(2).transpose(1, 2) 74 | 75 | 76 | class MLP(nn.Module): 77 | def __init__(self, c1, c2): 78 | super().__init__() 79 | self.fc1 = nn.Linear(c1, c2) 80 | self.dwconv = DWConv(c2) 81 | self.fc2 = nn.Linear(c2, c1) 82 | 83 | def forward(self, x: Tensor, H, W) -> Tensor: 84 | return self.fc2(F.gelu(self.dwconv(self.fc1(x), H, W))) 85 | 86 | 87 | class MixFFN(nn.Module): 88 | def __init__(self, c1, c2): 89 | super().__init__() 90 | self.fc1 = nn.Linear(c1, c2) 91 | self.pwconv1 = CustomPWConv(c2) 92 | self.dwconv3 = CustomDWConv(c2, 3) 93 | self.dwconv5 = CustomDWConv(c2, 5) 94 | self.dwconv7 = CustomDWConv(c2, 7) 95 | self.pwconv2 = CustomPWConv(c2) 96 | self.fc2 = nn.Linear(c2, c1) 97 | 98 | # Initialize fc1 layer with Kaiming initialization 99 | init.kaiming_normal_(self.fc1.weight, mode='fan_in', nonlinearity='relu') 100 | init.kaiming_normal_(self.fc2.weight, mode='fan_in', nonlinearity='relu') 101 | 102 | def forward(self, x: Tensor, H, W) -> Tensor: 103 | x = self.fc1(x) 104 | x = self.pwconv1(x, H, W) 105 | x1 = self.dwconv3(x, H, W) 106 | x2 = self.dwconv5(x, H, W) 107 | x3 = self.dwconv7(x, H, W) 108 | return self.fc2(F.gelu(self.pwconv2(x + x1 + x2 + x3, H, W))) 109 | 110 | 111 | class FusionBlock(nn.Module): 112 | def __init__(self, channels, reduction=16, num_modals=2): 113 | super(FusionBlock, self).__init__() 114 | self.channels = channels 115 | self.reduction = reduction 116 | self.num_modals = num_modals 117 | 118 | self.liner_fusion_layers = nn.ModuleList([ 119 | nn.Linear(self.channels[0]*self.num_modals, self.channels[0]), 120 | nn.Linear(self.channels[1]*self.num_modals, self.channels[1]), 121 | nn.Linear(self.channels[2]*self.num_modals, self.channels[2]), 122 | 
nn.Linear(self.channels[3]*self.num_modals, self.channels[3]), 123 | ]) 124 | 125 | self.mix_ffn = nn.ModuleList([ 126 | MixFFN(self.channels[0], self.channels[0]), 127 | MixFFN(self.channels[1], self.channels[1]), 128 | MixFFN(self.channels[2], self.channels[2]), 129 | MixFFN(self.channels[3], self.channels[3]), 130 | ]) 131 | 132 | self.channel_attns = nn.ModuleList([ 133 | ChannelAttentionBlock(self.channels[0]), 134 | ChannelAttentionBlock(self.channels[1]), 135 | ChannelAttentionBlock(self.channels[2]), 136 | ChannelAttentionBlock(self.channels[3]), 137 | ]) 138 | 139 | # Initialize linear fusion layers with Kaiming initialization 140 | for linear_layer in self.liner_fusion_layers: 141 | init.kaiming_normal_(linear_layer.weight, mode='fan_in', nonlinearity='relu') 142 | 143 | def forward(self, x, layer_idx): 144 | B, C, H, W = x[0].shape 145 | x = torch.cat(x, dim=1) 146 | x = x.flatten(2).transpose(1, 2) 147 | x_sum = self.liner_fusion_layers[layer_idx](x) 148 | x_sum = self.mix_ffn[layer_idx](x_sum, H, W) + self.channel_attns[layer_idx](x_sum, H, W) 149 | return x_sum.reshape(B, H, W, -1).permute(0, 3, 1, 2) 150 | 151 | 152 | class Attention(nn.Module): 153 | def __init__(self, dim, head, sr_ratio): 154 | super().__init__() 155 | self.head = head 156 | self.sr_ratio = sr_ratio 157 | self.scale = (dim // head) ** -0.5 158 | self.q = nn.Linear(dim, dim) 159 | self.kv = nn.Linear(dim, dim*2) 160 | self.proj = nn.Linear(dim, dim) 161 | 162 | if sr_ratio > 1: 163 | self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio) 164 | self.norm = nn.LayerNorm(dim) 165 | 166 | def forward(self, x: Tensor, H, W) -> Tensor: 167 | B, N, C = x.shape 168 | q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3) 169 | 170 | if self.sr_ratio > 1: 171 | x = x.permute(0, 2, 1).reshape(B, C, H, W) 172 | x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) 173 | x = self.norm(x) 174 | 175 | k, v = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4) 176 | 177 | attn = (q @ k.transpose(-2, -1)) * self.scale 178 | attn = attn.softmax(dim=-1) 179 | 180 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 181 | x = self.proj(x) 182 | return x 183 | 184 | 185 | class PatchEmbed(nn.Module): 186 | def __init__(self, c1=3, c2=32, patch_size=7, stride=4, padding=0): 187 | super().__init__() 188 | self.proj = nn.Conv2d(c1, c2, patch_size, stride, padding) # padding=(ps[0]//2, ps[1]//2) 189 | self.norm = nn.LayerNorm(c2) 190 | 191 | def forward(self, x: Tensor) -> Tensor: 192 | x = self.proj(x) 193 | _, _, H, W = x.shape 194 | x = x.flatten(2).transpose(1, 2) 195 | x = self.norm(x) 196 | return x, H, W 197 | 198 | 199 | class Block(nn.Module): 200 | def __init__(self, dim, head, sr_ratio=1, dpr=0., is_fan=False): 201 | super().__init__() 202 | self.norm1 = nn.LayerNorm(dim) 203 | self.attn = Attention(dim, head, sr_ratio) 204 | self.drop_path = DropPath(dpr) if dpr > 0. 
else nn.Identity() 205 | self.norm2 = nn.LayerNorm(dim) 206 | self.mlp = MLP(dim, int(dim*4)) if not is_fan else ChannelProcessing(dim, mlp_hidden_dim=int(dim*4)) 207 | 208 | def forward(self, x: Tensor, H, W) -> Tensor: 209 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 210 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 211 | return x 212 | 213 | 214 | class ChannelProcessing(nn.Module): 215 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., drop_path=0., mlp_hidden_dim=None, norm_layer=nn.LayerNorm): 216 | super().__init__() 217 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 218 | self.dim = dim 219 | self.num_heads = num_heads 220 | 221 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 222 | self.mlp_v = MLP(dim, mlp_hidden_dim) 223 | self.norm_v = norm_layer(dim) 224 | 225 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 226 | self.pool = nn.AdaptiveAvgPool2d((None, 1)) 227 | self.sigmoid = nn.Sigmoid() 228 | 229 | def forward(self, x, H, W, atten=None): 230 | B, N, C = x.shape 231 | 232 | v = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 233 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 234 | k = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 235 | 236 | q = q.softmax(-2).transpose(-1,-2) 237 | _, _, Nk, Ck = k.shape 238 | k = k.softmax(-2) 239 | k = torch.nn.functional.avg_pool2d(k, (1, Ck)) 240 | 241 | attn = self.sigmoid(q @ k) 242 | 243 | Bv, Hd, Nv, Cv = v.shape 244 | v = self.norm_v(self.mlp_v(v.transpose(1, 2).reshape(Bv, Nv, Hd*Cv), H, W)).reshape(Bv, Nv, Hd, Cv).transpose(1, 2) 245 | x = (attn * v.transpose(-1, -2)).permute(0, 3, 1, 2).reshape(B, N, C) 246 | return x 247 | 248 | 249 | mit_settings = { 250 | 'B0': [[32, 64, 160, 256], [2, 2, 2, 2]], 251 | 'B1': [[64, 128, 320, 512], [2, 2, 2, 2]], 252 | 'B2': [[64, 128, 320, 512], [3, 4, 6, 3]], 253 | 'B3': [[64, 128, 320, 512], [3, 4, 18, 3]], 254 | 'B4': [[64, 128, 320, 512], [3, 8, 27, 3]], 255 | 'B5': [[64, 128, 320, 512], [3, 6, 40, 3]] 256 | } 257 | 258 | 259 | class MixTransformer(nn.Module): 260 | def __init__(self, model_name: str = 'B0', modality: str = 'depth'): 261 | super().__init__() 262 | assert model_name in mit_settings.keys(), f"Model name should be in {list(cmnext_settings.keys())}" 263 | # self.model_name = 'B2' 264 | self.model_name = model_name 265 | # TODO: Must comment the following line later 266 | # self.model_name = 'B2' if modality == 'depth' else model_name 267 | embed_dims, depths = mit_settings[self.model_name] 268 | self.modality = modality 269 | drop_path_rate = 0.1 270 | self.channels = embed_dims 271 | 272 | # patch_embed 273 | self.patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4, 7//2) 274 | self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2, 3//2) 275 | self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2, 3//2) 276 | self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2, 3//2) 277 | 278 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 279 | 280 | cur = 0 281 | self.block1 = nn.ModuleList([Block(embed_dims[0], 1, 8, dpr[cur+i]) for i in range(depths[0])]) 282 | self.norm1 = nn.LayerNorm(embed_dims[0]) 283 | 284 | cur += depths[0] 285 | self.block2 = nn.ModuleList([Block(embed_dims[1], 2, 4, dpr[cur+i]) for i in range(depths[1])]) 286 | self.norm2 = nn.LayerNorm(embed_dims[1]) 287 | 288 | cur += depths[1] 289 | self.block3 = 
nn.ModuleList([Block(embed_dims[2], 5, 2, dpr[cur+i]) for i in range(depths[2])]) 290 | self.norm3 = nn.LayerNorm(embed_dims[2]) 291 | 292 | cur += depths[2] 293 | self.block4 = nn.ModuleList([Block(embed_dims[3], 8, 1, dpr[cur+i]) for i in range(depths[3])]) 294 | self.norm4 = nn.LayerNorm(embed_dims[3]) 295 | 296 | # Initialize with pretrained weights 297 | self.init_weights() 298 | 299 | def init_weights(self): 300 | print(f"Initializing weight for {self.modality}...") 301 | checkpoint = torch.load(f'checkpoints/pretrained/segformer/mit_{self.model_name.lower()}.pth', map_location=torch.device('cpu')) 302 | if 'state_dict' in checkpoint.keys(): 303 | checkpoint = checkpoint['state_dict'] 304 | msg = self.load_state_dict(checkpoint, strict=False) 305 | del checkpoint 306 | print(f"Weight init complete with message: {msg}") 307 | 308 | def forward(self, x: Tensor) -> list: 309 | x_cam = x 310 | 311 | B = x_cam.shape[0] 312 | outs = [] 313 | # stage 1 314 | x_cam, H, W = self.patch_embed1(x_cam) 315 | for blk in self.block1: 316 | x_cam = blk(x_cam, H, W) 317 | x1_cam = self.norm1(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2) 318 | outs.append(x1_cam) 319 | 320 | # stage 2 321 | x_cam, H, W = self.patch_embed2(x1_cam) 322 | for blk in self.block2: 323 | x_cam = blk(x_cam, H, W) 324 | x2_cam = self.norm2(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2) 325 | outs.append(x2_cam) 326 | 327 | # stage 3 328 | x_cam, H, W = self.patch_embed3(x2_cam) 329 | for blk in self.block3: 330 | x_cam = blk(x_cam, H, W) 331 | x3_cam = self.norm3(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2) 332 | outs.append(x3_cam) 333 | 334 | # stage 4 335 | x_cam, H, W = self.patch_embed4(x3_cam) 336 | for blk in self.block4: 337 | x_cam = blk(x_cam, H, W) 338 | x4_cam = self.norm4(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2) 339 | outs.append(x4_cam) 340 | 341 | return outs 342 | 343 | 344 | class MMSFormer(nn.Module): 345 | def __init__(self, model_name: str = 'B0', modals: list = ['rgb', 'depth', 'event', 'lidar']): 346 | super().__init__() 347 | assert model_name in mit_settings.keys(), f"Model name should be in {list(cmnext_settings.keys())}" 348 | embed_dims, depths = mit_settings[model_name] 349 | self.modals = modals[1:] if len(modals)>1 else [] 350 | self.num_modals = len(self.modals) 351 | drop_path_rate = 0.1 352 | self.channels = embed_dims 353 | 354 | # patch_embed 355 | self.patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4, 7//2) 356 | self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2, 3//2) 357 | self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2, 3//2) 358 | self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2, 3//2) 359 | 360 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 361 | 362 | cur = 0 363 | self.block1 = nn.ModuleList([Block(embed_dims[0], 1, 8, dpr[cur+i]) for i in range(depths[0])]) 364 | self.norm1 = nn.LayerNorm(embed_dims[0]) 365 | 366 | cur += depths[0] 367 | self.block2 = nn.ModuleList([Block(embed_dims[1], 2, 4, dpr[cur+i]) for i in range(depths[1])]) 368 | self.norm2 = nn.LayerNorm(embed_dims[1]) 369 | 370 | cur += depths[1] 371 | self.block3 = nn.ModuleList([Block(embed_dims[2], 5, 2, dpr[cur+i]) for i in range(depths[2])]) 372 | self.norm3 = nn.LayerNorm(embed_dims[2]) 373 | 374 | cur += depths[2] 375 | self.block4 = nn.ModuleList([Block(embed_dims[3], 8, 1, dpr[cur+i]) for i in range(depths[3])]) 376 | self.norm4 = nn.LayerNorm(embed_dims[3]) 377 | 378 | # Have extra modality 379 | if self.num_modals > 0: 
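            # Each extra modality gets its own MiT-B1 encoder; per-stage features from all
            # modalities are combined by a shared FusionBlock in forward().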
380 | # Backbones and Fusion Block for extra modalities 381 | self.extra_mit = nn.ModuleList([MixTransformer('B1', self.modals[i]) for i in range(self.num_modals)]) 382 | self.fusion_block = FusionBlock(self.channels, reduction=16, num_modals=self.num_modals+1) 383 | 384 | def forward(self, x: list) -> list: 385 | x_cam = x[0] 386 | if self.num_modals > 0: 387 | x_ext = x[1:] 388 | B = x_cam.shape[0] 389 | outs = [] 390 | 391 | # stage 1 392 | x_cam, H, W = self.patch_embed1(x_cam) 393 | for blk in self.block1: 394 | x_cam = blk(x_cam, H, W) 395 | x1_cam = self.norm1(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2) 396 | # Extra Modalities 397 | if self.num_modals > 0: 398 | for i in range(self.num_modals): 399 | x_ext[i], _, _ = self.extra_mit[i].patch_embed1(x_ext[i]) 400 | for blk in self.extra_mit[i].block1: 401 | x_ext[i] = blk(x_ext[i], H, W) 402 | x_ext[i] = self.extra_mit[i].norm1(x_ext[i]).reshape(B, H, W, -1).permute(0, 3, 1, 2) 403 | x_fused = self.fusion_block([x1_cam, *x_ext], layer_idx=0) 404 | outs.append(x_fused) 405 | else: 406 | outs.append(x1_cam) 407 | 408 | # stage 2 409 | x_cam, H, W = self.patch_embed2(x1_cam) 410 | for blk in self.block2: 411 | x_cam = blk(x_cam, H, W) 412 | x2_cam = self.norm2(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2) 413 | # Extra Modalities 414 | if self.num_modals > 0: 415 | for i in range(self.num_modals): 416 | x_ext[i], _, _ = self.extra_mit[i].patch_embed2(x_ext[i]) 417 | for blk in self.extra_mit[i].block2: 418 | x_ext[i] = blk(x_ext[i], H, W) 419 | x_ext[i] = self.extra_mit[i].norm2(x_ext[i]).reshape(B, H, W, -1).permute(0, 3, 1, 2) 420 | x_fused = self.fusion_block([x2_cam, *x_ext], layer_idx=1) 421 | outs.append(x_fused) 422 | else: 423 | outs.append(x2_cam) 424 | 425 | # stage 3 426 | x_cam, H, W = self.patch_embed3(x2_cam) 427 | for blk in self.block3: 428 | x_cam = blk(x_cam, H, W) 429 | x3_cam = self.norm3(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2) 430 | # Extra Modalities 431 | if self.num_modals > 0: 432 | for i in range(self.num_modals): 433 | x_ext[i], _, _ = self.extra_mit[i].patch_embed3(x_ext[i]) 434 | for blk in self.extra_mit[i].block3: 435 | x_ext[i] = blk(x_ext[i], H, W) 436 | x_ext[i] = self.extra_mit[i].norm3(x_ext[i]).reshape(B, H, W, -1).permute(0, 3, 1, 2) 437 | x_fused = self.fusion_block([x3_cam, *x_ext], layer_idx=2) 438 | outs.append(x_fused) 439 | else: 440 | outs.append(x3_cam) 441 | 442 | # stage 4 443 | x_cam, H, W = self.patch_embed4(x3_cam) 444 | for blk in self.block4: 445 | x_cam = blk(x_cam, H, W) 446 | x4_cam = self.norm4(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2) 447 | # Extra Modalities 448 | if self.num_modals > 0: 449 | for i in range(self.num_modals): 450 | x_ext[i], _, _ = self.extra_mit[i].patch_embed4(x_ext[i]) 451 | for blk in self.extra_mit[i].block4: 452 | x_ext[i] = blk(x_ext[i], H, W) 453 | x_ext[i] = self.extra_mit[i].norm4(x_ext[i]).reshape(B, H, W, -1).permute(0, 3, 1, 2) 454 | x_fused = self.fusion_block([x4_cam, *x_ext], layer_idx=3) 455 | outs.append(x_fused) 456 | else: 457 | outs.append(x4_cam) 458 | 459 | return outs 460 | 461 | 462 | if __name__ == '__main__': 463 | modals = ['img', 'aolp', 'dolp', 'nir'] 464 | x = [torch.zeros(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024)*2, torch.ones(1, 3, 1024, 1024) *3] 465 | model = MMSFormer('B2', modals) 466 | outs = model(x) 467 | for y in outs: 468 | print(y.shape) 469 | 470 | -------------------------------------------------------------------------------- /semseg/datasets/mcubes.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch import Tensor 5 | from torch.utils.data import Dataset 6 | from torchvision import io 7 | from torchvision import transforms 8 | from pathlib import Path 9 | from typing import Tuple 10 | import glob 11 | import einops 12 | from torch.utils.data import DataLoader 13 | from torch.utils.data import DistributedSampler, RandomSampler 14 | from semseg.augmentations_mm import get_train_augmentation 15 | import cv2 16 | import random 17 | from PIL import Image, ImageOps, ImageFilter 18 | 19 | class MCubeS(Dataset): 20 | """ 21 | num_classes: 20 22 | """ 23 | CLASSES = ['asphalt','concrete','metal','road_marking','fabric','glass','plaster','plastic','rubber','sand', 24 | 'gravel','ceramic','cobblestone','brick','grass','wood','leaf','water','human','sky',] 25 | 26 | PALETTE = torch.tensor([[ 44, 160, 44], 27 | [ 31, 119, 180], 28 | [255, 127, 14], 29 | [214, 39, 40], 30 | [140, 86, 75], 31 | [127, 127, 127], 32 | [188, 189, 34], 33 | [255, 152, 150], 34 | [ 23, 190, 207], 35 | [174, 199, 232], 36 | [196, 156, 148], 37 | [197, 176, 213], 38 | [247, 182, 210], 39 | [199, 199, 199], 40 | [219, 219, 141], 41 | [158, 218, 229], 42 | [ 57, 59, 121], 43 | [107, 110, 207], 44 | [156, 158, 222], 45 | [ 99, 121, 57]]) 46 | 47 | def __init__(self, root: str = 'data/MCubeS/multimodal_dataset', split: str = 'train', transform = None, modals = ['image', 'aolp', 'dolp', 'nir'], case = None) -> None: 48 | super().__init__() 49 | assert split in ['train', 'val'] 50 | self.split = split 51 | self.root = root 52 | self.transform = transform 53 | self.n_classes = len(self.CLASSES) 54 | self.ignore_label = 255 55 | self.modals = modals 56 | self._left_offset = 192 57 | 58 | self.img_h = 1024 59 | self.img_w = 1224 60 | max_dim = max(self.img_h, self.img_w) 61 | u_vec = (np.arange(self.img_w)-self.img_w/2)/max_dim*2 62 | v_vec = (np.arange(self.img_h)-self.img_h/2)/max_dim*2 63 | self.u_map, self.v_map = np.meshgrid(u_vec, v_vec) 64 | self.u_map = self.u_map[:,:self._left_offset] 65 | 66 | self.base_size = 512 67 | self.crop_size = 512 68 | self.files = self._get_file_names(split) 69 | 70 | if not self.files: 71 | raise Exception(f"No images found in {img_path}") 72 | print(f"Found {len(self.files)} {split} images.") 73 | 74 | def __len__(self) -> int: 75 | return len(self.files) 76 | 77 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 78 | item_name = str(self.files[index]) 79 | rgb = os.path.join(*[self.root, 'polL_color', item_name+'.png']) 80 | x1 = os.path.join(*[self.root, 'polL_aolp_sin', item_name+'.npy']) 81 | x1_1 = os.path.join(*[self.root, 'polL_aolp_cos', item_name+'.npy']) 82 | x2 = os.path.join(*[self.root, 'polL_dolp', item_name+'.npy']) 83 | x3 = os.path.join(*[self.root, 'NIR_warped', item_name+'.png']) 84 | lbl_path = os.path.join(*[self.root, 'GT', item_name+'.png']) 85 | nir_mask = os.path.join(*[self.root, 'NIR_warped_mask', item_name+'.png']) 86 | _mask = os.path.join(*[self.root, 'SS', item_name+'.png']) 87 | 88 | _img = cv2.imread(rgb,-1)[:,:,::-1] 89 | _img = _img.astype(np.float32)/65535 if _img.dtype==np.uint16 else _img.astype(np.float32)/255 90 | _target = cv2.imread(lbl_path,-1) 91 | _mask = cv2.imread(_mask,-1) 92 | # _aolp_sin = (np.load(x1) + 2.66) / (2.66 + 2.40) # See below of this script for min-max values 93 | # _aolp_cos = (np.load(x1_1) + 0.57 ) / (0.57 + 1.64) 94 | _aolp_sin = np.load(x1) 95 | _aolp_cos = np.load(x1_1) 
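        # AoLP is stored as separate sine and cosine maps; the commented lines above show the
        # optional min-max rescaling based on the dataset-wide statistics listed at the end of this file.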
96 | _aolp = np.stack([_aolp_sin, _aolp_cos, _aolp_sin], axis=2) # H x W x 3 97 | dolp = (np.load(x2) + 0.55) / 2.0 98 | _dolp = np.stack([dolp, dolp, dolp], axis=2) # H x W x 3 99 | nir = cv2.imread(x3,-1) 100 | nir = nir.astype(np.float32)/65535 if nir.dtype==np.uint16 else nir.astype(np.float32)/255 101 | _nir = np.stack([nir, nir, nir], axis=2) # H x W x 3 102 | 103 | _nir_mask = cv2.imread(nir_mask,0) 104 | 105 | _img, _target, _aolp, _dolp, _nir, _nir_mask, _mask = _img[:,self._left_offset:], _target[:,self._left_offset:], \ 106 | _aolp[:,self._left_offset:], _dolp[:,self._left_offset:], \ 107 | _nir[:,self._left_offset:], _nir_mask[:,self._left_offset:], _mask[:,self._left_offset:] 108 | sample = {'image': _img, 'label': _target, 'aolp': _aolp, 'dolp': _dolp, 'nir': _nir, 'nir_mask': _nir_mask, 'u_map': self.u_map, 'v_map': self.v_map, 'mask':_mask} 109 | 110 | if self.split == "train": 111 | sample = self.transform_tr(sample) 112 | elif self.split == 'val': 113 | sample = self.transform_val(sample) 114 | elif self.split == 'test': 115 | sample = self.transform_val(sample) 116 | else: 117 | raise NotImplementedError() 118 | label = sample['label'].long() 119 | sample = [sample[k] for k in self.modals] 120 | # del _img, _target, _aolp, _dolp, _nir, _nir_mask, _mask 121 | return sample, label 122 | 123 | def transform_tr(self, sample): 124 | composed_transforms = transforms.Compose([ 125 | RandomHorizontalFlip(), 126 | RandomScaleCrop(base_size=self.base_size, crop_size=self.crop_size, fill=255), 127 | RandomGaussianBlur(), 128 | Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 129 | ToTensor()]) 130 | 131 | return composed_transforms(sample) 132 | 133 | def transform_val(self, sample): 134 | composed_transforms = transforms.Compose([ 135 | FixScaleCrop(crop_size=1024), 136 | Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 137 | ToTensor()]) 138 | 139 | return composed_transforms(sample) 140 | 141 | def _get_file_names(self, split_name): 142 | assert split_name in ['train', 'val'] 143 | source = os.path.join(self.root, 'list_folder/test.txt') if split_name == 'val' else os.path.join(self.root, 'list_folder/train.txt') 144 | file_names = [] 145 | with open(source) as f: 146 | files = f.readlines() 147 | for item in files: 148 | file_name = item.strip() 149 | if ' ' in file_name: 150 | # --- KITTI-360 151 | file_name = file_name.split(' ')[0] 152 | file_names.append(file_name) 153 | return file_names 154 | 155 | class Normalize(object): 156 | """Normalize a tensor image with mean and standard deviation. 157 | Args: 158 | mean (tuple): means for each channel. 159 | std (tuple): standard deviations for each channel. 
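        Only the RGB image is mean/std-normalized; AoLP, DoLP, NIR and the auxiliary maps are passed through unchanged.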
160 | """ 161 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 162 | self.mean = mean 163 | self.std = std 164 | 165 | def __call__(self, sample): 166 | img = sample['image'] 167 | mask = sample['label'] 168 | img = np.array(img).astype(np.float32) 169 | mask = np.array(mask).astype(np.float32) 170 | img -= self.mean 171 | img /= self.std 172 | 173 | nir = sample['nir'] 174 | nir = np.array(nir).astype(np.float32) 175 | # nir /= 255 176 | 177 | return {'image': img, 178 | 'label': mask, 179 | 'aolp' : sample['aolp'], 180 | 'dolp' : sample['dolp'], 181 | 'nir' : nir, 182 | 'nir_mask': sample['nir_mask'], 183 | 'u_map': sample['u_map'], 184 | 'v_map': sample['v_map'], 185 | 'mask':sample['mask']} 186 | 187 | 188 | class ToTensor(object): 189 | """Convert ndarrays in sample to Tensors.""" 190 | 191 | def __call__(self, sample): 192 | # swap color axis because 193 | # numpy image: H x W x C 194 | # torch image: C X H X W 195 | img = sample['image'] 196 | mask = sample['label'] 197 | aolp = sample['aolp'] 198 | dolp = sample['dolp'] 199 | nir = sample['nir'] 200 | nir_mask = sample['nir_mask'] 201 | SS=sample['mask'] 202 | 203 | img = np.array(img).astype(np.float32).transpose((2, 0, 1)) 204 | mask = np.array(mask).astype(np.float32) 205 | aolp = np.array(aolp).astype(np.float32).transpose((2, 0, 1)) 206 | dolp = np.array(dolp).astype(np.float32).transpose((2, 0, 1)) 207 | SS = np.array(SS).astype(np.float32) 208 | nir = np.array(nir).astype(np.float32).transpose((2, 0, 1)) 209 | nir_mask = np.array(nir_mask).astype(np.float32) 210 | 211 | img = torch.from_numpy(img).float() 212 | mask = torch.from_numpy(mask).float() 213 | aolp = torch.from_numpy(aolp).float() 214 | dolp = torch.from_numpy(dolp).float() 215 | SS = torch.from_numpy(SS).float() 216 | nir = torch.from_numpy(nir).float() 217 | nir_mask = torch.from_numpy(nir_mask).float() 218 | 219 | u_map = sample['u_map'] 220 | v_map = sample['v_map'] 221 | u_map = torch.from_numpy(u_map.astype(np.float32)).float() 222 | v_map = torch.from_numpy(v_map.astype(np.float32)).float() 223 | 224 | return {'image': img, 225 | 'label': mask, 226 | 'aolp' : aolp, 227 | 'dolp' : dolp, 228 | 'nir' : nir, 229 | 'nir_mask' : nir_mask, 230 | 'u_map': u_map, 231 | 'v_map': v_map, 232 | 'mask':SS} 233 | 234 | 235 | class RandomHorizontalFlip(object): 236 | def __call__(self, sample): 237 | img = sample['image'] 238 | mask = sample['label'] 239 | aolp = sample['aolp'] 240 | dolp = sample['dolp'] 241 | nir = sample['nir'] 242 | nir_mask = sample['nir_mask'] 243 | u_map = sample['u_map'] 244 | v_map = sample['v_map'] 245 | SS=sample['mask'] 246 | if random.random() < 0.5: 247 | # img = img.transpose(Image.FLIP_LEFT_RIGHT) 248 | # mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 249 | # nir = nir.transpose(Image.FLIP_LEFT_RIGHT) 250 | 251 | img = img[:,::-1] 252 | mask = mask[:,::-1] 253 | nir = nir[:,::-1] 254 | nir_mask = nir_mask[:,::-1] 255 | aolp = aolp[:,::-1] 256 | dolp = dolp[:,::-1] 257 | SS = SS[:,::-1] 258 | u_map = u_map[:,::-1] 259 | 260 | return {'image': img, 261 | 'label': mask, 262 | 'aolp' : aolp, 263 | 'dolp' : dolp, 264 | 'nir' : nir, 265 | 'nir_mask' : nir_mask, 266 | 'u_map': u_map, 267 | 'v_map': v_map, 268 | 'mask':SS} 269 | 270 | class RandomGaussianBlur(object): 271 | def __call__(self, sample): 272 | img = sample['image'] 273 | mask = sample['label'] 274 | nir = sample['nir'] 275 | if random.random() < 0.5: 276 | radius = random.random() 277 | # img = img.filter(ImageFilter.GaussianBlur(radius=radius)) 278 | # nir = 
nir.filter(ImageFilter.GaussianBlur(radius=radius)) 279 | img = cv2.GaussianBlur(img, (0,0), radius) 280 | nir = cv2.GaussianBlur(nir, (0,0), radius) 281 | 282 | return {'image': img, 283 | 'label': mask, 284 | 'aolp' : sample['aolp'], 285 | 'dolp' : sample['dolp'], 286 | 'nir' : nir, 287 | 'nir_mask': sample['nir_mask'], 288 | 'u_map': sample['u_map'], 289 | 'v_map': sample['v_map'], 290 | 'mask':sample['mask']} 291 | 292 | class RandomScaleCrop(object): 293 | def __init__(self, base_size, crop_size, fill=255): 294 | self.base_size = base_size 295 | self.crop_size = crop_size 296 | self.fill = fill 297 | 298 | def __call__(self, sample): 299 | img = sample['image'] 300 | mask = sample['label'] 301 | aolp = sample['aolp'] 302 | dolp = sample['dolp'] 303 | nir = sample['nir'] 304 | nir_mask = sample['nir_mask'] 305 | SS=sample['mask'] 306 | # random scale (short edge) 307 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 308 | # w, h = img.size 309 | h, w = img.shape[:2] 310 | if h > w: 311 | ow = short_size 312 | oh = int(1.0 * h * ow / w) 313 | else: 314 | oh = short_size 315 | ow = int(1.0 * w * oh / h) 316 | 317 | # pad crop 318 | if short_size < self.crop_size: 319 | padh = self.crop_size - oh if oh < self.crop_size else 0 320 | padw = self.crop_size - ow if ow < self.crop_size else 0 321 | 322 | # random crop crop_size 323 | # w, h = img.size 324 | h, w = img.shape[:2] 325 | 326 | # x1 = random.randint(0, w - self.crop_size) 327 | # y1 = random.randint(0, h - self.crop_size) 328 | x1 = random.randint(0, max(0, ow - self.crop_size)) 329 | y1 = random.randint(0, max(0, oh - self.crop_size)) 330 | 331 | u_map = sample['u_map'] 332 | v_map = sample['v_map'] 333 | u_map = cv2.resize(u_map,(ow,oh)) 334 | v_map = cv2.resize(v_map,(ow,oh)) 335 | aolp = cv2.resize(aolp ,(ow,oh)) 336 | dolp = cv2.resize(dolp ,(ow,oh)) 337 | SS = cv2.resize(SS ,(ow,oh)) 338 | img = cv2.resize(img ,(ow,oh), interpolation=cv2.INTER_LINEAR) 339 | mask = cv2.resize(mask ,(ow,oh), interpolation=cv2.INTER_NEAREST) 340 | nir = cv2.resize(nir ,(ow,oh), interpolation=cv2.INTER_LINEAR) 341 | nir_mask = cv2.resize(nir_mask ,(ow,oh), interpolation=cv2.INTER_NEAREST) 342 | if short_size < self.crop_size: 343 | u_map_ = np.zeros((oh+padh,ow+padw)) 344 | u_map_[:oh,:ow] = u_map 345 | u_map = u_map_ 346 | v_map_ = np.zeros((oh+padh,ow+padw)) 347 | v_map_[:oh,:ow] = v_map 348 | v_map = v_map_ 349 | aolp_ = np.zeros((oh+padh,ow+padw,3)) 350 | aolp_[:oh,:ow] = aolp 351 | aolp = aolp_ 352 | dolp_ = np.zeros((oh+padh,ow+padw,3)) 353 | dolp_[:oh,:ow] = dolp 354 | dolp = dolp_ 355 | 356 | img_ = np.zeros((oh+padh,ow+padw,3)) 357 | img_[:oh,:ow] = img 358 | img = img_ 359 | SS_ = np.zeros((oh+padh,ow+padw)) 360 | SS_[:oh,:ow] = SS 361 | SS = SS_ 362 | mask_ = np.full((oh+padh,ow+padw),self.fill) 363 | mask_[:oh,:ow] = mask 364 | mask = mask_ 365 | nir_ = np.zeros((oh+padh,ow+padw,3)) 366 | nir_[:oh,:ow] = nir 367 | nir = nir_ 368 | nir_mask_ = np.zeros((oh+padh,ow+padw)) 369 | nir_mask_[:oh,:ow] = nir_mask 370 | nir_mask = nir_mask_ 371 | 372 | u_map = u_map[y1:y1+self.crop_size, x1:x1+self.crop_size] 373 | v_map = v_map[y1:y1+self.crop_size, x1:x1+self.crop_size] 374 | aolp = aolp[y1:y1+self.crop_size, x1:x1+self.crop_size] 375 | dolp = dolp[y1:y1+self.crop_size, x1:x1+self.crop_size] 376 | img = img[y1:y1+self.crop_size, x1:x1+self.crop_size] 377 | mask = mask[y1:y1+self.crop_size, x1:x1+self.crop_size] 378 | nir = nir[y1:y1+self.crop_size, x1:x1+self.crop_size] 379 | SS = SS[y1:y1+self.crop_size, 
x1:x1+self.crop_size] 380 | nir_mask = nir_mask[y1:y1+self.crop_size, x1:x1+self.crop_size] 381 | return {'image': img, 382 | 'label': mask, 383 | 'aolp' : aolp, 384 | 'dolp' : dolp, 385 | 'nir' : nir, 386 | 'nir_mask' : nir_mask, 387 | 'u_map': u_map, 388 | 'v_map': v_map, 389 | 'mask':SS} 390 | 391 | class FixScaleCrop(object): 392 | def __init__(self, crop_size): 393 | self.crop_size = crop_size 394 | 395 | def __call__(self, sample): 396 | img = sample['image'] 397 | mask = sample['label'] 398 | aolp = sample['aolp'] 399 | dolp = sample['dolp'] 400 | nir = sample['nir'] 401 | nir_mask = sample['nir_mask'] 402 | SS = sample['mask'] 403 | 404 | # w, h = img.size 405 | h, w = img.shape[:2] 406 | 407 | if w > h: 408 | oh = self.crop_size 409 | ow = int(1.0 * w * oh / h) 410 | else: 411 | ow = self.crop_size 412 | oh = int(1.0 * h * ow / w) 413 | # img = img.resize((ow, oh), Image.BILINEAR) 414 | # mask = mask.resize((ow, oh), Image.NEAREST) 415 | # nir = nir.resize((ow, oh), Image.BILINEAR) 416 | 417 | # center crop 418 | # w, h = img.size 419 | # h, w = img.shape[:2] 420 | x1 = int(round((ow - self.crop_size) / 2.)) 421 | y1 = int(round((oh - self.crop_size) / 2.)) 422 | # img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 423 | # mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 424 | # nir = nir.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 425 | 426 | u_map = sample['u_map'] 427 | v_map = sample['v_map'] 428 | u_map = cv2.resize(u_map,(ow,oh)) 429 | v_map = cv2.resize(v_map,(ow,oh)) 430 | aolp = cv2.resize(aolp ,(ow,oh)) 431 | dolp = cv2.resize(dolp ,(ow,oh)) 432 | SS = cv2.resize(SS ,(ow,oh)) 433 | img = cv2.resize(img ,(ow,oh), interpolation=cv2.INTER_LINEAR) 434 | mask = cv2.resize(mask ,(ow,oh), interpolation=cv2.INTER_NEAREST) 435 | nir = cv2.resize(nir ,(ow,oh), interpolation=cv2.INTER_LINEAR) 436 | nir_mask = cv2.resize(nir_mask,(ow,oh), interpolation=cv2.INTER_NEAREST) 437 | u_map = u_map[y1:y1+self.crop_size, x1:x1+self.crop_size] 438 | v_map = v_map[y1:y1+self.crop_size, x1:x1+self.crop_size] 439 | aolp = aolp[y1:y1+self.crop_size, x1:x1+self.crop_size] 440 | dolp = dolp[y1:y1+self.crop_size, x1:x1+self.crop_size] 441 | img = img[y1:y1+self.crop_size, x1:x1+self.crop_size] 442 | mask = mask[y1:y1+self.crop_size, x1:x1+self.crop_size] 443 | SS = SS[y1:y1+self.crop_size, x1:x1+self.crop_size] 444 | nir = nir[y1:y1+self.crop_size, x1:x1+self.crop_size] 445 | nir_mask = nir_mask[y1:y1+self.crop_size, x1:x1+self.crop_size] 446 | return {'image': img, 447 | 'label': mask, 448 | 'aolp' : aolp, 449 | 'dolp' : dolp, 450 | 'nir' : nir, 451 | 'nir_mask' : nir_mask, 452 | 'u_map': u_map, 453 | 'v_map': v_map, 454 | 'mask':SS} 455 | 456 | 457 | if __name__ == '__main__': 458 | traintransform = get_train_augmentation((1024, 1224), seg_fill=255) 459 | 460 | trainset = MCubeS(transform=traintransform, split='val') 461 | trainloader = DataLoader(trainset, batch_size=1, num_workers=0, drop_last=False, pin_memory=False) 462 | 463 | for i, (sample, lbl) in enumerate(trainloader): 464 | print(torch.unique(lbl)) 465 | 466 | 467 | # Reading DoLP.... 468 | # DoLP: Global Min: -0.5281060934066772, Global Max: 1.4464844465255737, Negative Count: 11528302, Negative Percent: 1.839560036254085% 469 | # Reading AoLP (Sin).... 470 | # AoLP (Sin): Global Min: -2.6559877395629883, Global Max: 2.3968138694763184, Negative Count: 279532583, Negative Percent: 44.604744785283906% 471 | # Reading AoLP (Sin).... 
472 | # AoLP (Cos): Global Min: -0.5681185722351074, Global Max: 1.6386538743972778, Negative Count: 3569347, Negative Percent: 0.5695572597528594% --------------------------------------------------------------------------------
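Note: the commented-out rescaling in MCubeS.__getitem__ maps AoLP/DoLP to roughly [0, 1] using the dataset-wide statistics printed above. A minimal sketch of that min-max mapping follows; the helper name and the random stand-in arrays are illustrative only, not part of the repository:

import numpy as np

def min_max_scale(x: np.ndarray, lo: float, hi: float) -> np.ndarray:
    # Map values observed in [lo, hi] to approximately [0, 1],
    # e.g. (aolp_sin + 2.66) / (2.66 + 2.40) in the commented-out code.
    return (x - lo) / (hi - lo)

# Random stand-ins for the loaded .npy maps (MCubeS frames are 1024 x 1224).
aolp_sin = np.random.uniform(-2.66, 2.40, size=(1024, 1224)).astype(np.float32)
dolp = np.random.uniform(-0.53, 1.45, size=(1024, 1224)).astype(np.float32)

aolp_sin_01 = min_max_scale(aolp_sin, -2.66, 2.40)  # AoLP (sin) global min/max from the stats above
dolp_01 = min_max_scale(dolp, -0.53, 1.45)          # DoLP global min/max; the code approximates this as (x + 0.55) / 2.0
print(aolp_sin_01.min(), aolp_sin_01.max(), dolp_01.min(), dolp_01.max())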