├── LICENSE
├── README.md
├── configs
│   ├── __init__.py
│   ├── base_config.py
│   ├── my_config.py
│   ├── optuna_config.py
│   └── parser.py
├── core
│   ├── __init__.py
│   ├── base_trainer.py
│   ├── loss.py
│   └── seg_trainer.py
├── datasets
│   ├── __init__.py
│   ├── cityscapes.py
│   ├── dataset_registry.py
│   └── test_dataset.py
├── main.py
├── models
│   ├── __init__.py
│   ├── adscnet.py
│   ├── aglnet.py
│   ├── backbone.py
│   ├── bisenetv1.py
│   ├── bisenetv2.py
│   ├── canet.py
│   ├── cfpnet.py
│   ├── cgnet.py
│   ├── contextnet.py
│   ├── dabnet.py
│   ├── ddrnet.py
│   ├── dfanet.py
│   ├── edanet.py
│   ├── enet.py
│   ├── erfnet.py
│   ├── esnet.py
│   ├── espnet.py
│   ├── espnetv2.py
│   ├── fanet.py
│   ├── farseenet.py
│   ├── fastscnn.py
│   ├── fddwnet.py
│   ├── fpenet.py
│   ├── fssnet.py
│   ├── icnet.py
│   ├── lednet.py
│   ├── linknet.py
│   ├── lite_hrnet.py
│   ├── liteseg.py
│   ├── mininet.py
│   ├── mininetv2.py
│   ├── model_registry.py
│   ├── modules.py
│   ├── pp_liteseg.py
│   ├── regseg.py
│   ├── segnet.py
│   ├── shelfnet.py
│   ├── smp_wrapper.py
│   ├── sqnet.py
│   ├── stdc.py
│   └── swiftnet.py
├── optuna_results
│   ├── bisenetv1.json
│   ├── ddrnet.json
│   └── liteseg.json
├── optuna_search.py
├── requirements.txt
├── tools
│   ├── export.py
│   ├── get_model_infos.py
│   └── test_speed.py
└── utils
    ├── __init__.py
    ├── metrics.py
    ├── model_ema.py
    ├── optimizer.py
    ├── parallel.py
    ├── scheduler.py
    ├── transforms.py
    └── utils.py

/configs/__init__.py:
--------------------------------------------------------------------------------
1 | from .my_config import MyConfig
2 | from .parser import load_parser
--------------------------------------------------------------------------------
/configs/base_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | 
4 | class BaseConfig:
5 |     def __init__(self,):
6 |         # Task
7 |         self.task = 'train'    # train, val, predict
8 | 
9 |         # Dataset
10 |         self.dataset = None
11 |         self.data_root = None
12 |         self.num_class = -1
13 |         self.ignore_index = 255
14 | 
15 |         # Model
16 |         self.model = None
17 |         self.encoder = None
18 |         self.decoder = None
19 |         self.encoder_weights = 'imagenet'
20 | 
21 |         # Detail Head (For STDC)
22 |         self.use_detail_head = False
23 |         self.detail_thrs = 0.1
24 |         self.detail_loss_coef = 1.0
25 |         self.dice_loss_coef = 1.0
26 |         self.bce_loss_coef = 1.0
27 | 
28 |         # Training
29 |         self.total_epoch = 200
30 |         self.base_lr = 0.01
31 |         self.train_bs = 16    # For each GPU
32 |         self.use_aux = False
33 |         self.aux_coef = None
34 | 
35 |         # Validating
36 |         self.val_bs = 16    # For each GPU
37 |         self.begin_val_epoch = 0    # Epoch to start validation
38 |         self.val_interval = 1    # Epoch interval between validation
39 | 
40 |         # Testing
41 |         self.test_bs = 16
42 |         self.test_data_folder = None
43 |         self.colormap = 'cityscapes'
44 |         self.save_mask = True
45 |         self.blend_prediction = True
46 |         self.blend_alpha = 0.3
47 | 
48 |         # Loss
49 |         self.loss_type = 'ohem'
50 |         self.class_weights = None
51 |         self.ohem_thrs = 0.7
52 |         self.reduction = 'mean'
53 | 
54 |         # Scheduler
55 |         self.lr_policy = 'cos_warmup'
56 |         self.warmup_epochs = 3
57 | 
58 |         # Optimizer
59 |         self.optimizer_type = 'sgd'
60 |         self.momentum = 0.9    # For SGD
61 |         self.weight_decay = 1e-4    # For SGD
62 | 
63 |         # Monitoring
64 |         self.save_ckpt = True
65 |         self.save_dir = 'save'
66 |         self.use_tb = True    # tensorboard
67 |         self.tb_log_dir = None
68 |         self.ckpt_name = None
69 |         self.logger_name = None
70 | 
71 |         # Training setting
72 |         self.amp_training = False
73 |         self.resume_training = True
74 |         self.load_ckpt = True
75 |         self.load_ckpt_path = None
76 |         self.base_workers = 8
77 |         self.random_seed = 1
78 |         self.use_ema = False
79 | 
80 |         # Augmentation
81 | 
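        # crop_size is a square-crop shorthand: init_dependent_config() copies it
        # into crop_h/crop_w whenever those are left as None, so set crop_h/crop_w
        # directly only when you need a non-square crop.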
self.crop_size = 512 82 | self.crop_h = None 83 | self.crop_w = None 84 | self.scale = 1.0 85 | self.randscale = 0.0 86 | self.brightness = 0.0 87 | self.contrast = 0.0 88 | self.saturation = 0.0 89 | self.h_flip = 0.0 90 | self.v_flip = 0.0 91 | 92 | # DDP 93 | self.synBN = True 94 | self.destroy_ddp_process = True 95 | 96 | # Knowledge Distillation 97 | self.kd_training = False 98 | self.teacher_ckpt = '' 99 | self.teacher_model = 'smp' 100 | self.teacher_encoder = None 101 | self.teacher_decoder = None 102 | self.kd_loss_type = 'kl_div' 103 | self.kd_loss_coefficient = 1.0 104 | self.kd_temperature = 4.0 105 | 106 | # Export 107 | self.export_format = 'onnx' 108 | self.export_size = (512, 1024) 109 | self.export_name = None 110 | self.onnx_opset = 11 111 | self.load_onnx_path = None 112 | 113 | def init_dependent_config(self): 114 | if self.load_ckpt_path is None and self.task == 'train': 115 | self.load_ckpt_path = f'{self.save_dir}/last.pth' 116 | 117 | if self.tb_log_dir is None: 118 | self.tb_log_dir = f'{self.save_dir}/tb_logs/' 119 | 120 | if self.crop_h is None: 121 | self.crop_h = self.crop_size 122 | 123 | if self.crop_w is None: 124 | self.crop_w = self.crop_size 125 | 126 | if self.export_name is None: 127 | suffix = os.path.basename(self.load_ckpt_path).replace('.pth', '') if self.load_ckpt_path else 'dummy' 128 | self.export_name = f'{self.model}_{suffix}' -------------------------------------------------------------------------------- /configs/my_config.py: -------------------------------------------------------------------------------- 1 | from .base_config import BaseConfig 2 | 3 | 4 | class MyConfig(BaseConfig): 5 | def __init__(self,): 6 | super().__init__() 7 | # Task 8 | self.task = 'predict' 9 | 10 | # Dataset 11 | self.dataset = 'cityscapes' 12 | self.data_root = '/path/to/your/dataset' 13 | self.num_class = 19 14 | 15 | # Model 16 | self.model = 'bisenetv2' 17 | 18 | # Training 19 | self.total_epoch = 200 20 | self.train_bs = 8 21 | self.loss_type = 'ohem' 22 | self.optimizer_type = 'adam' 23 | self.logger_name = 'seg_trainer' 24 | self.use_aux = True 25 | 26 | # Validating 27 | self.val_bs = 10 28 | 29 | # Testing 30 | self.test_bs = 8 31 | self.test_data_folder = '/path/to/your/test/folder' 32 | self.load_ckpt_path = '/path/to/your/inference/checkpoint' 33 | self.save_mask = True 34 | 35 | # Training setting 36 | self.use_ema = False 37 | 38 | # Augmentation 39 | self.crop_size = 768 40 | self.randscale = [-0.5, 1.0] 41 | self.scale = 1.0 42 | self.brightness = 0.5 43 | self.contrast = 0.5 44 | self.saturation = 0.5 45 | self.h_flip = 0.5 46 | 47 | # Knowledge Distillation 48 | self.kd_training = False 49 | self.teacher_ckpt = '/path/to/your/teacher/checkpoint' 50 | self.teacher_model = 'smp' 51 | self.teacher_encoder = 'resnet101' 52 | self.teacher_decoder = 'deeplabv3p' -------------------------------------------------------------------------------- /configs/optuna_config.py: -------------------------------------------------------------------------------- 1 | try: 2 | import optuna 3 | except: 4 | raise RuntimeError('Unable to import Optuna. 
Please check whether you have installed it correctly.\n') 5 | from .base_config import BaseConfig 6 | 7 | 8 | class OptunaConfig(BaseConfig): 9 | def __init__(self,): 10 | super().__init__() 11 | # Task 12 | self.task = 'train' 13 | 14 | # Dataset 15 | self.dataset = 'cityscapes' 16 | self.data_root = '/path/to/your/dataset' 17 | self.num_class = 19 18 | 19 | # Model 20 | self.model = 'bisenetv1' 21 | 22 | # Training 23 | self.total_epoch = 200 24 | self.train_bs = 5 25 | self.loss_type = 'ohem' 26 | self.logger_name = 'seg_trainer' 27 | 28 | # Validating 29 | self.val_bs = 4 30 | 31 | # Training setting 32 | self.load_ckpt = False 33 | 34 | # DDP 35 | self.synBN = True 36 | self.destroy_ddp_process = False 37 | 38 | # Augmentation 39 | self.scale = 1.0 40 | self.crop_size = 1024 41 | 42 | # Optuna 43 | self.study_name = 'optuna-study' 44 | self.study_direction = 'maximize' 45 | self.num_trial = 100 46 | self.save_every_trial = True 47 | 48 | def get_trial_params(self, trial): 49 | self.optimizer_type = trial.suggest_categorical('optimizer', ['sgd', 'adam', 'adamw']) 50 | self.base_lr = trial.suggest_loguniform('base_lr', 1e-3, 1e-1) 51 | self.use_ema = trial.suggest_categorical('use_ema', [True, False]) 52 | self.scale_max = trial.suggest_float('scale_max', 0.25, 1.5) 53 | self.scale_min = trial.suggest_float('scale_min', 0.1, 0.8) 54 | self.brightness = trial.suggest_float('brightness', 0.0, 0.9) 55 | self.contrast = trial.suggest_float('contrast', 0.0, 0.9) 56 | self.saturation = trial.suggest_float('saturation', 0.0, 0.9) 57 | self.h_flip = trial.suggest_float('h_flip', 0.0, 0.5) 58 | 59 | self.randscale = [-self.scale_min, self.scale_max] -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | from .seg_trainer import SegTrainer 2 | from .loss import get_loss_fn, kd_loss_fn, get_detail_loss_fn -------------------------------------------------------------------------------- /core/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class OhemCELoss(nn.Module): 7 | def __init__(self, thresh, ignore_index=255): 8 | super().__init__() 9 | self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)).cuda() 10 | self.ignore_index = ignore_index 11 | self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='none') 12 | 13 | def forward(self, logits, labels): 14 | n_min = labels[labels != self.ignore_index].numel() // 16 15 | loss = self.criteria(logits, labels).view(-1) 16 | loss_hard = loss[loss > self.thresh] 17 | if loss_hard.numel() < n_min: 18 | loss_hard, _ = loss.topk(n_min) 19 | 20 | return torch.mean(loss_hard) 21 | 22 | 23 | class DiceLoss(nn.Module): 24 | def __init__(self, smooth=1): 25 | super().__init__() 26 | self.smooth = smooth 27 | 28 | def forward(self, logits, labels): 29 | logits = torch.flatten(logits, 1) 30 | labels = torch.flatten(labels, 1) 31 | 32 | intersection = torch.sum(logits * labels, dim=1) 33 | loss = 1 - ((2 * intersection + self.smooth) / (logits.sum(1) + labels.sum(1) + self.smooth)) 34 | 35 | return torch.mean(loss) 36 | 37 | 38 | class DetailLoss(nn.Module): 39 | '''Implement detail loss used in paper 40 | `Rethinking BiSeNet For Real-time Semantic Segmentation`''' 41 | def __init__(self, dice_loss_coef=1., bce_loss_coef=1., smooth=1): 42 | super().__init__() 
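        # Weighted sum of Dice and BCE-with-logits terms, as used to supervise the
        # STDC detail head:
        # loss = dice_loss_coef * DiceLoss(logits, labels) + bce_loss_coef * BCEWithLogitsLoss(logits, labels)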
43 | self.dice_loss_coef = dice_loss_coef 44 | self.bce_loss_coef = bce_loss_coef 45 | self.dice_loss_fn = DiceLoss(smooth) 46 | self.bce_loss_fn = nn.BCEWithLogitsLoss() 47 | 48 | def forward(self, logits, labels): 49 | loss = self.dice_loss_coef * self.dice_loss_fn(logits, labels) + \ 50 | self.bce_loss_coef * self.bce_loss_fn(logits, labels) 51 | 52 | return loss 53 | 54 | 55 | def get_loss_fn(config, device): 56 | if config.class_weights is None: 57 | weights = None 58 | else: 59 | weights = torch.Tensor(config.class_weights).to(device) 60 | 61 | if config.loss_type == 'ce': 62 | criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_index, 63 | reduction=config.reduction, weight=weights) 64 | 65 | elif config.loss_type == 'ohem': 66 | criterion = OhemCELoss(thresh=config.ohem_thrs, ignore_index=config.ignore_index) 67 | 68 | else: 69 | raise NotImplementedError(f"Unsupport loss type: {config.loss_type}") 70 | 71 | return criterion 72 | 73 | 74 | def get_detail_loss_fn(config): 75 | detail_loss_fn = DetailLoss(dice_loss_coef=config.dice_loss_coef, bce_loss_coef=config.bce_loss_coef) 76 | 77 | return detail_loss_fn 78 | 79 | 80 | def kd_loss_fn(config, outputs, outputsT): 81 | if config.kd_loss_type == 'kl_div': 82 | lossT = F.kl_div(F.log_softmax(outputs/config.kd_temperature, dim=1), 83 | F.softmax(outputsT.detach()/config.kd_temperature, dim=1)) * config.kd_temperature ** 2 84 | 85 | elif config.kd_loss_type == 'mse': 86 | lossT = F.mse_loss(outputs, outputsT.detach()) 87 | 88 | return lossT -------------------------------------------------------------------------------- /core/seg_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm 6 | from torch.cuda import amp 7 | import torch.nn.functional as F 8 | 9 | from .base_trainer import BaseTrainer 10 | from .loss import kd_loss_fn 11 | from models import get_teacher_model 12 | from utils import (get_seg_metrics, sampler_set_epoch, get_colormap) 13 | 14 | 15 | class SegTrainer(BaseTrainer): 16 | def __init__(self, config): 17 | super().__init__(config) 18 | if config.task == 'predict': 19 | self.colormap = torch.tensor(get_colormap(config)).to(self.device) 20 | else: 21 | self.metrics = get_seg_metrics(config).to(self.device) 22 | 23 | if config.task == 'train': 24 | self.teacher_model = get_teacher_model(config, self.device) 25 | 26 | if config.use_detail_head: 27 | from .loss import get_detail_loss_fn 28 | from models import LaplacianConv 29 | 30 | self.laplacian_conv = LaplacianConv(self.device) 31 | self.detail_loss_fn = get_detail_loss_fn(config) 32 | 33 | def train_one_epoch(self, config): 34 | self.model.train() 35 | 36 | sampler_set_epoch(config, self.train_loader, self.cur_epoch) 37 | 38 | pbar = tqdm(self.train_loader) if self.main_rank else self.train_loader 39 | 40 | for cur_itrs, (images, masks) in enumerate(pbar): 41 | self.cur_itrs = cur_itrs 42 | self.train_itrs += 1 43 | 44 | images = images.to(self.device, dtype=torch.float32) 45 | masks = masks.to(self.device, dtype=torch.long) 46 | 47 | self.optimizer.zero_grad() 48 | 49 | # Forward path 50 | if config.use_aux: 51 | with amp.autocast(enabled=config.amp_training): 52 | preds, preds_aux = self.model(images, is_training=True) 53 | loss = self.loss_fn(preds, masks) 54 | 55 | masks_auxs = masks.unsqueeze(1).float() 56 | if config.aux_coef is None: 57 | config.aux_coef = torch.ones(len(preds_aux)) 58 | elif len(preds_aux) != 
len(config.aux_coef): 59 | raise ValueError('Auxiliary loss coefficient length does not match.') 60 | 61 | for i in range(len(preds_aux)): 62 | aux_size = preds_aux[i].size()[2:] 63 | masks_aux = F.interpolate(masks_auxs, aux_size, mode='nearest') 64 | masks_aux = masks_aux.squeeze(1).to(self.device, dtype=torch.long) 65 | 66 | with amp.autocast(enabled=config.amp_training): 67 | loss += config.aux_coef[i] * self.loss_fn(preds_aux[i], masks_aux) 68 | 69 | # Detail loss proposed in paper for model STDC 70 | elif config.use_detail_head: 71 | masks_detail = masks.unsqueeze(1).float() 72 | masks_detail = self.laplacian_conv(masks_detail) 73 | 74 | with amp.autocast(enabled=config.amp_training): 75 | # Detail ground truth 76 | masks_detail = self.model.module.detail_conv(masks_detail) 77 | masks_detail[masks_detail > config.detail_thrs] = 1 78 | masks_detail[masks_detail <= config.detail_thrs] = 0 79 | detail_size = masks_detail.size()[2:] 80 | 81 | preds, preds_detail = self.model(images, is_training=True) 82 | preds_detail = F.interpolate(preds_detail, detail_size, mode='bilinear', align_corners=True) 83 | loss_detail = self.detail_loss_fn(preds_detail, masks_detail) 84 | loss = self.loss_fn(preds, masks) + config.detail_loss_coef * loss_detail 85 | 86 | else: 87 | with amp.autocast(enabled=config.amp_training): 88 | preds = self.model(images) 89 | loss = self.loss_fn(preds, masks) 90 | 91 | if config.use_tb and self.main_rank: 92 | self.writer.add_scalar('train/loss', loss.detach(), self.train_itrs) 93 | if config.use_detail_head: 94 | self.writer.add_scalar('train/loss_detail', loss_detail.detach(), self.train_itrs) 95 | 96 | # Knowledge distillation 97 | if config.kd_training: 98 | with amp.autocast(enabled=config.amp_training): 99 | with torch.no_grad(): 100 | teacher_preds = self.teacher_model(images) # Teacher predictions 101 | 102 | loss_kd = kd_loss_fn(config, preds, teacher_preds.detach()) 103 | loss += config.kd_loss_coefficient * loss_kd 104 | 105 | if config.use_tb and self.main_rank: 106 | self.writer.add_scalar('train/loss_kd', loss_kd.detach(), self.train_itrs) 107 | self.writer.add_scalar('train/loss_total', loss.detach(), self.train_itrs) 108 | 109 | # Backward path 110 | self.scaler.scale(loss).backward() 111 | self.scaler.step(self.optimizer) 112 | self.scaler.update() 113 | self.scheduler.step() 114 | 115 | self.ema_model.update(self.model, self.train_itrs) 116 | 117 | if self.main_rank: 118 | pbar.set_description(('%s'*2) % 119 | (f'Epoch:{self.cur_epoch}/{config.total_epoch}{" "*4}|', 120 | f'Loss:{loss.detach():4.4g}{" "*4}|',) 121 | ) 122 | 123 | return 124 | 125 | @torch.no_grad() 126 | def validate(self, config, val_best=False): 127 | pbar = tqdm(self.val_loader) if self.main_rank else self.val_loader 128 | for (images, masks) in pbar: 129 | images = images.to(self.device, dtype=torch.float32) 130 | masks = masks.to(self.device, dtype=torch.long) 131 | 132 | preds = self.ema_model.ema(images) 133 | self.metrics.update(preds.detach(), masks) 134 | 135 | if self.main_rank: 136 | pbar.set_description(('%s'*1) % (f'Validating:{" "*4}|',)) 137 | 138 | iou = self.metrics.compute() 139 | score = iou.mean() # mIoU 140 | 141 | if self.main_rank: 142 | if val_best: 143 | self.logger.info(f'\n\nTrain {config.total_epoch} epochs finished.' 
+
144 |                                  f'\n\nBest mIoU is: {score:.4f}\n')
145 |             else:
146 |                 self.logger.info(f' Epoch{self.cur_epoch} mIoU: {score:.4f} | ' +
147 |                                  f'best mIoU so far: {self.best_score:.4f}\n')
148 | 
149 |             if config.use_tb and self.cur_epoch < config.total_epoch:
150 |                 self.writer.add_scalar('val/mIoU', score.cpu(), self.cur_epoch+1)
151 |                 for i in range(config.num_class):
152 |                     self.writer.add_scalar(f'val/IoU_cls{i:02d}', iou[i].cpu(), self.cur_epoch+1)
153 |         self.metrics.reset()
154 |         return score
155 | 
156 |     @torch.no_grad()
157 |     def predict(self, config):
158 |         if config.DDP:
159 |             raise ValueError('Predict mode currently does not support DDP.')
160 | 
161 |         self.logger.info('\nStart predicting...\n')
162 | 
163 |         self.model.eval()    # Put model in evaluation mode
164 | 
165 |         for (images, images_aug, img_names) in tqdm(self.test_loader):
166 |             images_aug = images_aug.to(self.device, dtype=torch.float32)
167 | 
168 |             preds = self.model(images_aug)
169 | 
170 |             preds = self.colormap[preds.max(dim=1)[1]].cpu().numpy()
171 | 
172 |             images = images.cpu().numpy()
173 | 
174 |             # Saving results
175 |             for i in range(preds.shape[0]):
176 |                 save_path = os.path.join(config.save_dir, img_names[i])
177 |                 save_suffix = img_names[i].split('.')[-1]
178 | 
179 |                 pred = Image.fromarray(preds[i].astype(np.uint8))
180 | 
181 |                 if config.save_mask:
182 |                     pred.save(save_path)
183 | 
184 |                 if config.blend_prediction:
185 |                     save_blend_path = save_path.replace(f'.{save_suffix}', f'_blend.{save_suffix}')
186 | 
187 |                     image = Image.fromarray(images[i].astype(np.uint8))
188 |                     image = Image.blend(image, pred, config.blend_alpha)
189 |                     image.save(save_blend_path)
190 | 
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | 
3 | from .cityscapes import Cityscapes
4 | from .dataset_registry import dataset_hub
5 | 
6 | 
7 | def get_dataset(config):
8 |     if config.dataset in dataset_hub.keys():
9 |         train_dataset = dataset_hub[config.dataset](config=config, mode='train')
10 |         val_dataset = dataset_hub[config.dataset](config=config, mode='val')
11 |     else:
12 |         raise NotImplementedError('Unsupported dataset!')
13 | 
14 |     return train_dataset, val_dataset
15 | 
16 | 
17 | def get_loader(config, rank, pin_memory=True):
18 |     train_dataset, val_dataset = get_dataset(config)
19 | 
20 |     # Make sure train number is divisible by train batch size
21 |     config.train_num = int(len(train_dataset) // config.train_bs * config.train_bs)
22 |     config.val_num = len(val_dataset)
23 | 
24 |     if config.DDP:
25 |         from torch.utils.data.distributed import DistributedSampler
26 |         train_sampler = DistributedSampler(train_dataset, num_replicas=config.gpu_num,
27 |                                            rank=rank, shuffle=True)
28 |         val_sampler = DistributedSampler(val_dataset, num_replicas=config.gpu_num,
29 |                                          rank=rank, shuffle=False)
30 | 
31 |         train_loader = DataLoader(train_dataset, batch_size=config.train_bs, shuffle=False,
32 |                                   num_workers=config.num_workers, pin_memory=pin_memory,
33 |                                   sampler=train_sampler, drop_last=True)
34 | 
35 |         val_loader = DataLoader(val_dataset, batch_size=config.val_bs, shuffle=False,
36 |                                 num_workers=config.num_workers, pin_memory=pin_memory,
37 |                                 sampler=val_sampler)
38 |     else:
39 |         train_loader = DataLoader(train_dataset, batch_size=config.train_bs,
40 |                                   shuffle=True, num_workers=config.num_workers, drop_last=True)
41 | 
42 |         val_loader = DataLoader(val_dataset, batch_size=config.val_bs,
43 |                                 shuffle=False,
num_workers=config.num_workers) 44 | 45 | return train_loader, val_loader 46 | 47 | 48 | def get_test_loader(config): 49 | from .test_dataset import TestDataset 50 | dataset = TestDataset(config) 51 | 52 | config.test_num = len(dataset) 53 | 54 | if config.DDP: 55 | raise NotImplementedError() 56 | 57 | else: 58 | test_loader = DataLoader(dataset, batch_size=config.test_bs, 59 | shuffle=False, num_workers=config.num_workers) 60 | 61 | return test_loader 62 | 63 | 64 | def list_available_datasets(): 65 | dataset_list = list(dataset_hub.keys()) 66 | 67 | return dataset_list -------------------------------------------------------------------------------- /datasets/dataset_registry.py: -------------------------------------------------------------------------------- 1 | dataset_hub = {} 2 | 3 | 4 | def register_dataset(dataset_class): 5 | dataset_hub[dataset_class.__name__.lower()] = dataset_class 6 | return dataset_class -------------------------------------------------------------------------------- /datasets/test_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | import albumentations as AT 6 | from albumentations.pytorch import ToTensorV2 7 | from utils import transforms 8 | 9 | 10 | class TestDataset(Dataset): 11 | def __init__(self, config): 12 | data_folder = os.path.expanduser(config.test_data_folder) 13 | 14 | if not os.path.isdir(data_folder): 15 | raise RuntimeError(f'Test image directory: {data_folder} does not exist.') 16 | 17 | self.transform = AT.Compose([ 18 | transforms.Scale(scale=config.scale, is_testing=True), 19 | AT.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 20 | ToTensorV2(), 21 | ]) 22 | 23 | self.images = [] 24 | self.img_names = [] 25 | 26 | for file_name in os.listdir(data_folder): 27 | self.images.append(os.path.join(data_folder, file_name)) 28 | self.img_names.append(file_name) 29 | 30 | def __len__(self): 31 | return len(self.images) 32 | 33 | def __getitem__(self, index): 34 | image = np.asarray(Image.open(self.images[index]).convert('RGB')) 35 | img_name = self.img_names[index] 36 | 37 | # Perform augmentation and normalization 38 | augmented = self.transform(image=image) 39 | image_aug = augmented['image'] 40 | 41 | return image, image_aug, img_name -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from core import SegTrainer 2 | from configs import MyConfig, load_parser 3 | 4 | import warnings 5 | warnings.filterwarnings("ignore") 6 | 7 | 8 | if __name__ == '__main__': 9 | config = MyConfig() 10 | 11 | config.init_dependent_config() 12 | 13 | # If you want to use command-line arguments, please uncomment the following line 14 | # config = load_parser(config) 15 | 16 | trainer = SegTrainer(config) 17 | 18 | if config.task == 'train': 19 | trainer.run(config) 20 | elif config.task == 'val': 21 | trainer.validate(config) 22 | elif config.task == 'predict': 23 | trainer.predict(config) 24 | else: 25 | raise ValueError(f'Unsupported task type: {config.task}.\n') -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import os, torch 2 | 3 | from .adscnet import ADSCNet 4 | from .aglnet import AGLNet 5 | from .bisenetv1 import BiSeNetv1 6 | from 
.bisenetv2 import BiSeNetv2 7 | from .canet import CANet 8 | from .cfpnet import CFPNet 9 | from .cgnet import CGNet 10 | from .contextnet import ContextNet 11 | from .dabnet import DABNet 12 | from .ddrnet import DDRNet 13 | from .dfanet import DFANet 14 | from .edanet import EDANet 15 | from .enet import ENet 16 | from .erfnet import ERFNet 17 | from .esnet import ESNet 18 | from .espnet import ESPNet 19 | from .espnetv2 import ESPNetv2 20 | from .fanet import FANet 21 | from .farseenet import FarSeeNet 22 | from .fastscnn import FastSCNN 23 | from .fddwnet import FDDWNet 24 | from .fpenet import FPENet 25 | from .fssnet import FSSNet 26 | from .icnet import ICNet 27 | from .lednet import LEDNet 28 | from .linknet import LinkNet 29 | from .lite_hrnet import LiteHRNet 30 | from .liteseg import LiteSeg 31 | from .mininet import MiniNet 32 | from .mininetv2 import MiniNetv2 33 | from .pp_liteseg import PPLiteSeg 34 | from .regseg import RegSeg 35 | from .segnet import SegNet 36 | from .shelfnet import ShelfNet 37 | from .sqnet import SQNet 38 | from .stdc import STDC, LaplacianConv 39 | from .swiftnet import SwiftNet 40 | from .model_registry import model_hub, aux_models, detail_head_models 41 | 42 | 43 | def get_model(config): 44 | if config.model == 'smp': # Use segmentation models pytorch 45 | from .smp_wrapper import get_smp_model 46 | 47 | model = get_smp_model(config.encoder, config.decoder, config.encoder_weights, config.num_class) 48 | 49 | elif config.model in model_hub.keys(): 50 | if config.model in aux_models: # models support auxiliary heads 51 | if config.model in detail_head_models: # models support detail heads 52 | model = model_hub[config.model](num_class=config.num_class, use_detail_head=config.use_detail_head, use_aux=config.use_aux) 53 | else: 54 | model = model_hub[config.model](num_class=config.num_class, use_aux=config.use_aux) 55 | 56 | else: 57 | if config.use_aux: 58 | raise ValueError(f'Model {config.model} does not support auxiliary heads.\n') 59 | 60 | model = model_hub[config.model](num_class=config.num_class) 61 | 62 | else: 63 | raise NotImplementedError(f"Unsupport model type: {config.model}") 64 | 65 | return model 66 | 67 | 68 | def list_available_models(): 69 | model_list = list(model_hub.keys()) 70 | 71 | try: 72 | import segmentation_models_pytorch as smp 73 | model_list.append('smp') 74 | except: 75 | pass 76 | 77 | return model_list 78 | 79 | 80 | def get_teacher_model(config, device): 81 | if config.kd_training: 82 | if not os.path.isfile(config.teacher_ckpt): 83 | raise ValueError(f'Could not find teacher checkpoint at path {config.teacher_ckpt}.') 84 | 85 | if config.teacher_model == 'smp': 86 | from .smp_wrapper import get_smp_model 87 | 88 | model = get_smp_model(config.teacher_encoder, config.teacher_decoder, None, config.num_class) 89 | 90 | else: 91 | raise NotImplementedError() 92 | 93 | teacher_ckpt = torch.load(config.teacher_ckpt, map_location=torch.device('cpu')) 94 | model.load_state_dict(teacher_ckpt['state_dict']) 95 | del teacher_ckpt 96 | 97 | model = model.to(device) 98 | model.eval() 99 | else: 100 | model = None 101 | 102 | return model -------------------------------------------------------------------------------- /models/adscnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: ADSCNet: asymmetric depthwise separable convolution for semantic 3 | segmentation in real-time 4 | Url: https://link.springer.com/article/10.1007/s10489-019-01587-1 5 | Create by: zh320 6 | Date: 
2023/09/30 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from .modules import conv1x1, ConvBNAct, DWConvBNAct, DeConvBNAct, Activation 13 | from .model_registry import register_model 14 | 15 | 16 | @register_model() 17 | class ADSCNet(nn.Module): 18 | def __init__(self, num_class=1, n_channel=3, act_type='relu6'): 19 | super().__init__() 20 | self.conv0 = ConvBNAct(n_channel, 32, 3, 2, act_type=act_type, inplace=True) 21 | self.conv1 = ADSCModule(32, 1, act_type=act_type) 22 | self.conv2_4 = nn.Sequential( 23 | ADSCModule(32, 1, act_type=act_type), 24 | ADSCModule(32, 2, act_type=act_type), 25 | ADSCModule(64, 1, act_type=act_type) 26 | ) 27 | self.conv5 = ADSCModule(64, 2, act_type=act_type) 28 | self.ddcc = DDCC(128, [3, 5, 9, 13], act_type) 29 | self.up1 = nn.Sequential( 30 | DeConvBNAct(128, 64), 31 | ADSCModule(64, 1, act_type=act_type) 32 | ) 33 | self.up2 = nn.Sequential( 34 | ADSCModule(64, 1, act_type=act_type), 35 | DeConvBNAct(64, 32) 36 | ) 37 | self.up3 = nn.Sequential( 38 | ADSCModule(32, 1, act_type=act_type), 39 | DeConvBNAct(32, num_class) 40 | ) 41 | 42 | def forward(self, x): 43 | x = self.conv0(x) 44 | x1 = self.conv1(x) 45 | x4 = self.conv2_4(x1) 46 | x = self.conv5(x4) 47 | x = self.ddcc(x) 48 | x = self.up1(x) 49 | x += x4 50 | x = self.up2(x) 51 | x += x1 52 | x = self.up3(x) 53 | 54 | return x 55 | 56 | 57 | class ADSCModule(nn.Module): 58 | def __init__(self, channels, stride, dilation=1, act_type='relu'): 59 | super().__init__() 60 | assert stride in [1, 2], 'Unsupported stride type.\n' 61 | self.use_skip = stride == 1 62 | self.conv = nn.Sequential( 63 | DWConvBNAct(channels, channels, (3, 1), stride, dilation, act_type, inplace=True), 64 | conv1x1(channels, channels), 65 | DWConvBNAct(channels, channels, (1, 3), 1, dilation, act_type, inplace=True), 66 | conv1x1(channels, channels) 67 | ) 68 | if not self.use_skip: 69 | self.pool = nn.AvgPool2d(3, 2, 1) 70 | 71 | def forward(self, x): 72 | x_conv = self.conv(x) 73 | 74 | if self.use_skip: 75 | x = x + x_conv 76 | else: 77 | x_pool = self.pool(x) 78 | x = torch.cat([x_conv, x_pool], dim=1) 79 | 80 | return x 81 | 82 | 83 | class DDCC(nn.Module): 84 | def __init__(self, channels, dilations, act_type): 85 | super().__init__() 86 | assert len(dilations)==4, 'Length of dilations should be 4.\n' 87 | self.block1 = nn.Sequential( 88 | nn.AvgPool2d(dilations[0], 1, dilations[0]//2), 89 | ADSCModule(channels, 1, dilations[0], act_type) 90 | ) 91 | 92 | self.block2 = nn.Sequential( 93 | conv1x1(2*channels, channels), 94 | nn.AvgPool2d(dilations[1], 1, dilations[1]//2), 95 | ADSCModule(channels, 1, dilations[1], act_type) 96 | ) 97 | 98 | self.block3 = nn.Sequential( 99 | conv1x1(3*channels, channels), 100 | nn.AvgPool2d(dilations[2], 1, dilations[2]//2), 101 | ADSCModule(channels, 1, dilations[2], act_type) 102 | ) 103 | 104 | self.block4 = nn.Sequential( 105 | conv1x1(4*channels, channels), 106 | nn.AvgPool2d(dilations[3], 1, dilations[3]//2), 107 | ADSCModule(channels, 1, dilations[3], act_type) 108 | ) 109 | 110 | self.conv_last = conv1x1(5*channels, channels) 111 | 112 | def forward(self, x): 113 | x1 = self.block1(x) 114 | 115 | x2 = torch.cat([x, x1], dim=1) 116 | x2 = self.block2(x2) 117 | 118 | x3 = torch.cat([x, x1, x2], dim=1) 119 | x3 = self.block3(x3) 120 | 121 | x4 = torch.cat([x, x1, x2, x3], dim=1) 122 | x4 = self.block4(x4) 123 | 124 | x = torch.cat([x, x1, x2, x3, x4], dim=1) 125 | x = self.conv_last(x) 126 | 127 | return x 
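
Usage sketch (not part of the repository): ADSCNet pairs three 2x downsamplings with three 2x upsamplings, so its logits come back at the input resolution with num_class channels. A minimal shape check, assuming the repository root is on PYTHONPATH and the packages in requirements.txt are installed:

import torch
from models import ADSCNet  # registered into model_hub by the @register_model() decorator

model = ADSCNet(num_class=19).eval()    # 19 classes, e.g. Cityscapes
with torch.no_grad():
    out = model(torch.randn(1, 3, 512, 1024))    # dummy NCHW batch; H and W should be divisible by 8
print(out.shape)    # torch.Size([1, 19, 512, 1024])

get_model() in models/__init__.py relies on every registered model accepting num_class as a keyword, which is why the sketch passes it explicitly.
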
-------------------------------------------------------------------------------- /models/aglnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: AGLNet: Towards real-time semantic segmentation of self-driving images 3 | via attention-guided lightweight network 4 | Url: https://www.sciencedirect.com/science/article/abs/pii/S1568494620306207 5 | Create by: zh320 6 | Date: 2023/08/27 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .modules import conv1x1, ConvBNAct, Activation, channel_shuffle 14 | from .enet import InitialBlock as DownsamplingUnit 15 | from .lednet import SSnbtUnit 16 | from .model_registry import register_model 17 | 18 | 19 | @register_model() 20 | class AGLNet(nn.Module): 21 | def __init__(self, num_class=1, n_channel=3, act_type='relu'): 22 | super().__init__() 23 | self.layer1 = DownsamplingUnit(n_channel, 32, act_type=act_type) 24 | self.layer2_4 = build_blocks(SSnbtUnit, 32, 3, act_type=act_type) 25 | self.layer5 = DownsamplingUnit(32, 64, act_type=act_type) 26 | self.layer6_7 = build_blocks(SSnbtUnit, 64, 2, act_type=act_type) 27 | self.layer8 = DownsamplingUnit(64, 128, act_type=act_type) 28 | self.layer9_16 = build_blocks(SSnbtUnit, 128, 8, dilations=[1,2,5,9,2,5,9,17], act_type=act_type) 29 | self.layer17 = FAPM(128, act_type=act_type) 30 | self.layer18 = GAUM(64, 128, 64, act_type=act_type) 31 | self.layer19 = GAUM(32, 64, 32, act_type=act_type) 32 | self.layer20 = conv1x1(32, num_class) 33 | 34 | def forward(self, x): 35 | size = x.size()[2:] 36 | 37 | # Stage 1 38 | x = self.layer1(x) 39 | x = self.layer2_4(x) 40 | x_s1 = x 41 | 42 | # Stage 2 43 | x = self.layer5(x) 44 | x = self.layer6_7(x) 45 | x_s2 = x 46 | 47 | # Stage 3 48 | x = self.layer8(x) 49 | x = self.layer9_16(x) 50 | 51 | x = self.layer17(x) 52 | x = self.layer18(x, x_s2) 53 | x = self.layer19(x, x_s1) 54 | 55 | x = self.layer20(x) 56 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 57 | 58 | return x 59 | 60 | 61 | def build_blocks(block, channels, num_block, dilations=[], act_type='relu'): 62 | if len(dilations) == 0: 63 | dilations = [1 for _ in range(num_block)] 64 | else: 65 | if len(dilations) != num_block: 66 | raise ValueError(f'Number of dilation should be equal to number of blocks') 67 | 68 | layers = [] 69 | for i in range(num_block): 70 | layers.append(block(channels, dilation=dilations[i], act_type=act_type)) 71 | return nn.Sequential(*layers) 72 | 73 | 74 | class FAPM(nn.Module): 75 | def __init__(self, channels, act_type): 76 | super().__init__() 77 | self.pfa = PyramidFeatureAttention(channels, act_type) 78 | self.conv = conv1x1(1, channels) 79 | self.gp = nn.Sequential( 80 | nn.AdaptiveAvgPool2d(1), 81 | conv1x1(channels, channels), 82 | ) 83 | 84 | def forward(self, x): 85 | size = x.size()[2:] 86 | x_pfa = self.pfa(x) 87 | x_pfa = self.conv(x_pfa) 88 | 89 | x_gp = self.gp(x) 90 | x_gp = F.interpolate(x_gp, size, mode='bilinear', align_corners=True) 91 | 92 | x = x * x_pfa 93 | x += x_gp 94 | 95 | return x 96 | 97 | 98 | class PyramidFeatureAttention(nn.Module): 99 | def __init__(self, channels, act_type): 100 | super().__init__() 101 | self.conv11 = ConvBNAct(channels, 1, (1,7), 2, act_type=act_type) 102 | self.conv12 = ConvBNAct(1, 1, (7,1), 1, act_type=act_type) 103 | self.conv21 = ConvBNAct(1, 1, (1,5), 2, act_type=act_type) 104 | self.conv22 = ConvBNAct(1, 1, (5,1), 1, act_type=act_type) 105 | self.conv31 = ConvBNAct(1, 1, (1,3), 2, 
act_type=act_type) 106 | self.conv32 = ConvBNAct(1, 1, (3,1), 1, act_type=act_type) 107 | 108 | def forward(self, x): 109 | size0 = x.size()[2:] 110 | 111 | x = self.conv11(x) 112 | size1 = x.size()[2:] 113 | x1 = self.conv12(x) 114 | 115 | x = self.conv21(x) 116 | size2 = x.size()[2:] 117 | x2 = self.conv22(x) 118 | 119 | x = self.conv31(x) 120 | x = self.conv32(x) 121 | x = F.interpolate(x, size2, mode='bilinear', align_corners=True) 122 | 123 | x += x2 124 | x = F.interpolate(x, size1, mode='bilinear', align_corners=True) 125 | 126 | x += x1 127 | x = F.interpolate(x, size0, mode='bilinear', align_corners=True) 128 | 129 | return x 130 | 131 | 132 | class GAUM(nn.Module): 133 | def __init__(self, low_channels, high_channels, out_channels, act_type): 134 | super().__init__() 135 | self.up_conv = nn.Sequential( 136 | nn.ConvTranspose2d(high_channels, low_channels, 3, 2, 1, 1), 137 | nn.BatchNorm2d(low_channels), 138 | Activation(act_type) 139 | ) 140 | self.sab = SpatialAttentionBlock(low_channels) 141 | self.cab = ChannelAttentionBlock(low_channels, out_channels) 142 | 143 | def forward(self, x_high, x_low): 144 | x_low = self.sab(x_low) 145 | x_high = self.up_conv(x_high) 146 | x_skip = x_high 147 | 148 | x_high = x_high * x_low 149 | x_skip2 = x_high 150 | 151 | x_high = self.cab(x_high) 152 | x_high = x_high * x_skip2 153 | 154 | x_high += x_skip 155 | return x_high 156 | 157 | 158 | class SpatialAttentionBlock(nn.Module): 159 | def __init__(self, channels): 160 | super().__init__() 161 | self.conv = conv1x1(channels, 1) 162 | 163 | def forward(self, x): 164 | x_s = self.conv(x) 165 | x_s = torch.sigmoid(x_s) 166 | x = x * x_s 167 | return x 168 | 169 | 170 | class ChannelAttentionBlock(nn.Module): 171 | def __init__(self, in_channels, out_channels): 172 | super().__init__() 173 | self.pool = nn.AdaptiveAvgPool2d(1) 174 | self.conv = conv1x1(in_channels, out_channels) 175 | 176 | def forward(self, x): 177 | x_c = self.pool(x) 178 | x_c = self.conv(x_c) 179 | x_c = torch.sigmoid(x_c) 180 | x = x * x_c 181 | return x -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ResNet(nn.Module): 5 | # Load ResNet pretrained on ImageNet from torchvision, see 6 | # https://pytorch.org/vision/stable/models/resnet.html 7 | def __init__(self, resnet_type, pretrained=True): 8 | super().__init__() 9 | from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152 10 | 11 | resnet_hub = {'resnet18':resnet18, 'resnet34':resnet34, 'resnet50':resnet50, 12 | 'resnet101':resnet101, 'resnet152':resnet152} 13 | if resnet_type not in resnet_hub: 14 | raise ValueError(f'Unsupported ResNet type: {resnet_type}.\n') 15 | 16 | resnet = resnet_hub[resnet_type](pretrained=pretrained) 17 | self.conv1 = resnet.conv1 18 | self.bn1 = resnet.bn1 19 | self.relu = resnet.relu 20 | self.maxpool = resnet.maxpool 21 | self.layer1 = resnet.layer1 22 | self.layer2 = resnet.layer2 23 | self.layer3 = resnet.layer3 24 | self.layer4 = resnet.layer4 25 | 26 | def forward(self, x): 27 | x = self.conv1(x) # 2x down 28 | x = self.bn1(x) 29 | x = self.relu(x) 30 | x = self.maxpool(x) # 4x down 31 | x1 = self.layer1(x) 32 | x2 = self.layer2(x1) # 8x down 33 | x3 = self.layer3(x2) # 16x down 34 | x4 = self.layer4(x3) # 32x down 35 | 36 | return x1, x2, x3, x4 37 | 38 | 39 | class Mobilenetv2(nn.Module): 40 | def __init__(self, pretrained=True): 41 | 
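        # Wrap torchvision's ImageNet-pretrained MobileNetV2, sliced into four feature stages (4x/8x/16x/32x strides).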
super().__init__() 42 | from torchvision.models import mobilenet_v2 43 | 44 | mobilenet = mobilenet_v2(pretrained=pretrained) 45 | 46 | self.layer1 = mobilenet.features[:4] 47 | self.layer2 = mobilenet.features[4:7] 48 | self.layer3 = mobilenet.features[7:14] 49 | self.layer4 = mobilenet.features[14:18] 50 | 51 | def forward(self, x): 52 | x1 = self.layer1(x) # 4x down 53 | x2 = self.layer2(x1) # 8x down 54 | x3 = self.layer3(x2) # 16x down 55 | x4 = self.layer4(x3) # 32x down 56 | 57 | return x1, x2, x3, x4 -------------------------------------------------------------------------------- /models/bisenetv1.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: BiSeNet: Bilateral Segmentation Network for Real-time Semantic Segmentation 3 | Url: https://arxiv.org/abs/1808.00897 4 | Create by: zh320 5 | Date: 2023/09/03 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import conv1x1, ConvBNAct, SegHead 13 | from .backbone import ResNet 14 | from .model_registry import register_model 15 | 16 | 17 | @register_model() 18 | class BiSeNetv1(nn.Module): 19 | def __init__(self, num_class=1, n_channel=3, backbone_type='resnet18', act_type='relu',): 20 | super().__init__() 21 | self.spatial_path = SpatialPath(n_channel, 128, act_type=act_type) 22 | self.context_path = ContextPath(256, backbone_type, act_type=act_type) 23 | self.ffm = FeatureFusionModule(384, 256, act_type=act_type) 24 | self.seg_head = SegHead(256, num_class, act_type=act_type) 25 | 26 | def forward(self, x): 27 | size = x.size()[2:] 28 | x_s = self.spatial_path(x) 29 | x_c = self.context_path(x) 30 | x = self.ffm(x_s, x_c) 31 | x = self.seg_head(x) 32 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 33 | 34 | return x 35 | 36 | 37 | class SpatialPath(nn.Sequential): 38 | def __init__(self, in_channels, out_channels, act_type): 39 | super().__init__( 40 | ConvBNAct(in_channels, out_channels, 3, 2, act_type=act_type), 41 | ConvBNAct(out_channels, out_channels, 3, 2, act_type=act_type), 42 | ConvBNAct(out_channels, out_channels, 3, 2, act_type=act_type), 43 | ) 44 | 45 | 46 | class ContextPath(nn.Module): 47 | def __init__(self, out_channels, backbone_type, act_type): 48 | super().__init__() 49 | if 'resnet' in backbone_type: 50 | self.backbone = ResNet(backbone_type) 51 | channels = [256, 512] if ('18' in backbone_type) or ('34' in backbone_type) else [1024, 2048] 52 | else: 53 | raise NotImplementedError() 54 | 55 | self.pool = nn.AdaptiveAvgPool2d(1) 56 | self.arm_16 = AttentionRefinementModule(channels[0]) 57 | self.arm_32 = AttentionRefinementModule(channels[1]) 58 | 59 | self.conv_16 = conv1x1(channels[0], out_channels) 60 | self.conv_32 = conv1x1(channels[1], out_channels) 61 | 62 | def forward(self, x): 63 | _, _, x_16, x_32 = self.backbone(x) 64 | x_32_avg = self.pool(x_32) 65 | x_32 = self.arm_32(x_32) 66 | x_32 += x_32_avg 67 | x_32 = self.conv_32(x_32) 68 | x_32 = F.interpolate(x_32, scale_factor=2, mode='bilinear', align_corners=True) 69 | 70 | x_16 = self.arm_16(x_16) 71 | x_16 = self.conv_16(x_16) 72 | x_16 += x_32 73 | x_16 = F.interpolate(x_16, scale_factor=2, mode='bilinear', align_corners=True) 74 | 75 | return x_16 76 | 77 | 78 | class AttentionRefinementModule(nn.Module): 79 | def __init__(self, channels): 80 | super().__init__() 81 | self.pool = nn.AdaptiveAvgPool2d(1) 82 | self.conv = ConvBNAct(channels, channels, 1, act_type='sigmoid') 83 | 84 | def forward(self, x): 85 | x_pool = 
self.pool(x) 86 | x_pool = x_pool.expand_as(x) 87 | x_pool = self.conv(x_pool) 88 | x = x * x_pool 89 | 90 | return x 91 | 92 | 93 | class FeatureFusionModule(nn.Module): 94 | def __init__(self, in_channels, out_channels, act_type): 95 | super().__init__() 96 | self.conv1 = ConvBNAct(in_channels, out_channels, 3, act_type=act_type) 97 | self.pool = nn.AdaptiveAvgPool2d(1) 98 | self.conv2 = nn.Sequential( 99 | conv1x1(out_channels, out_channels), 100 | nn.ReLU(), 101 | conv1x1(out_channels, out_channels), 102 | nn.Sigmoid(), 103 | ) 104 | 105 | def forward(self, x_low, x_high): 106 | x = torch.cat([x_low, x_high], dim=1) 107 | x = self.conv1(x) 108 | 109 | x_pool = self.pool(x) 110 | x_pool = x_pool.expand_as(x) 111 | x_pool = self.conv2(x_pool) 112 | 113 | x_pool = x * x_pool 114 | x = x + x_pool 115 | 116 | return x -------------------------------------------------------------------------------- /models/canet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: Cross Attention Network for Semantic Segmentation 3 | Url: https://arxiv.org/abs/1907.10958 4 | Create by: zh320 5 | Date: 2023/09/30 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from .modules import ConvBNAct, DeConvBNAct, Activation 12 | from .backbone import ResNet, Mobilenetv2 13 | from .model_registry import register_model 14 | 15 | 16 | @register_model() 17 | class CANet(nn.Module): 18 | def __init__(self, num_class=1, n_channel=3, backbone_type='mobilenet_v2', act_type='relu'): 19 | super().__init__() 20 | self.spatial_branch = SpatialBranch(n_channel, 64, act_type) 21 | self.context_branch = ContextBranch(64*4, backbone_type) 22 | self.fca = FeatureCrossAttentionModule(64*4, num_class, act_type) 23 | self.up = DeConvBNAct(num_class, num_class, scale_factor=8) 24 | 25 | def forward(self, x): 26 | size = x.size()[2:] 27 | x_s = self.spatial_branch(x) 28 | x_c = self.context_branch(x) 29 | x = self.fca(x_s, x_c) 30 | x = self.up(x) 31 | 32 | return x 33 | 34 | 35 | class SpatialBranch(nn.Sequential): 36 | def __init__(self, n_channel, channels, act_type): 37 | super().__init__( 38 | ConvBNAct(n_channel, channels, 3, 2, act_type=act_type, inplace=True), 39 | ConvBNAct(channels, channels*2, 3, 2, act_type=act_type, inplace=True), 40 | ConvBNAct(channels*2, channels*4, 3, 2, act_type=act_type, inplace=True), 41 | ) 42 | 43 | 44 | class ContextBranch(nn.Module): 45 | def __init__(self, out_channels, backbone_type, hid_channels=192): 46 | super().__init__() 47 | if 'mobilenet' in backbone_type: 48 | self.backbone = Mobilenetv2() 49 | channels = [320, 96] 50 | elif 'resnet' in backbone_type: 51 | self.backbone = ResNet(backbone_type) 52 | channels = [512, 256] if (('18' in backbone_type) or ('34' in backbone_type)) else [2048, 1024] 53 | else: 54 | raise NotImplementedError() 55 | 56 | self.up1 = DeConvBNAct(channels[0], hid_channels) 57 | self.up2 = DeConvBNAct(channels[1] + hid_channels, out_channels) 58 | 59 | def forward(self, x): 60 | _, _, x_d16, x = self.backbone(x) 61 | x = self.up1(x) 62 | 63 | x = torch.cat([x, x_d16], dim=1) 64 | x = self.up2(x) 65 | 66 | return x 67 | 68 | 69 | class FeatureCrossAttentionModule(nn.Module): 70 | def __init__(self, in_channels, out_channels, act_type): 71 | super().__init__() 72 | self.conv_init = ConvBNAct(2*in_channels, in_channels, act_type=act_type, inplace=True) 73 | self.sa = SpatialAttentionBlock(in_channels) 74 | self.ca = ChannelAttentionBlock(in_channels) 75 | self.conv_last = ConvBNAct(in_channels, 
out_channels, inplace=True) 76 | 77 | def forward(self, x_s, x_c): 78 | x = torch.cat([x_s, x_c], dim=1) 79 | x_s = self.sa(x_s) 80 | x_c = self.ca(x_c) 81 | 82 | x = self.conv_init(x) 83 | residual = x 84 | 85 | x = x * x_s 86 | x = x * x_c 87 | x += residual 88 | 89 | x = self.conv_last(x) 90 | 91 | return x 92 | 93 | 94 | class SpatialAttentionBlock(nn.Sequential): 95 | def __init__(self, in_channels): 96 | super().__init__( 97 | ConvBNAct(in_channels, 1, act_type='sigmoid') 98 | ) 99 | 100 | 101 | class ChannelAttentionBlock(nn.Module): 102 | def __init__(self, in_channels): 103 | super().__init__() 104 | self.in_channels = in_channels 105 | self.max_pool = nn.AdaptiveMaxPool2d(1) 106 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 107 | self.fc = nn.Linear(in_channels, in_channels) 108 | 109 | def forward(self, x): 110 | x_max = self.max_pool(x).view(-1, self.in_channels) 111 | x_avg = self.avg_pool(x).view(-1, self.in_channels) 112 | 113 | x_max = self.fc(x_max) 114 | x_avg = self.fc(x_avg) 115 | 116 | x = x_max + x_avg 117 | x = torch.sigmoid(x) 118 | 119 | return x.unsqueeze(-1).unsqueeze(-1) -------------------------------------------------------------------------------- /models/cfpnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: CFPNet: Channel-wise Feature Pyramid for Real-Time Semantic Segmentation 3 | Url: https://arxiv.org/abs/2103.12212 4 | Create by: zh320 5 | Date: 2023/09/30 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from math import ceil 12 | 13 | from .modules import ConvBNAct 14 | from .enet import InitialBlock as DownsamplingBlock 15 | from .model_registry import register_model 16 | 17 | 18 | @register_model() 19 | class CFPNet(nn.Module): 20 | def __init__(self, num_class=1, n_channel=3, n=2, m=6, dilations=[2,2,4,4,8,8,16,16], 21 | act_type='prelu'): 22 | super().__init__() 23 | assert len(dilations) == (n+m), f'Length of dilations should be equal to {n+m}.\n' 24 | self.conv_init = nn.Sequential( 25 | ConvBNAct(n_channel, 32, stride=2, act_type=act_type), 26 | ConvBNAct(32, 32, act_type=act_type), 27 | ConvBNAct(32, 32, act_type=act_type) 28 | ) 29 | self.downsample1 = DownsamplingBlock(32+3, 64, act_type) 30 | self.cfp1 = build_blocks(CFPModule, 64, n, dilations[:n], act_type) 31 | self.downsample2 = DownsamplingBlock(64+3, 128, act_type) 32 | self.cfp2 = build_blocks(CFPModule, 128, m, dilations[n:], act_type) 33 | self.seg_head = ConvBNAct(128+3, num_class, 1, act_type=act_type) 34 | 35 | def forward(self, x): 36 | size = x.size()[2:] 37 | x_d2 = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True) 38 | x_d4 = F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=True) 39 | x_d8 = F.interpolate(x, scale_factor=0.125, mode='bilinear', align_corners=True) 40 | 41 | x = self.conv_init(x) 42 | x = torch.cat([x, x_d2], dim=1) 43 | 44 | x = self.downsample1(x) 45 | x = self.cfp1(x) 46 | x = torch.cat([x, x_d4], dim=1) 47 | 48 | x = self.downsample2(x) 49 | x = self.cfp2(x) 50 | x = torch.cat([x, x_d8], dim=1) 51 | 52 | x = self.seg_head(x) 53 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 54 | 55 | return x 56 | 57 | 58 | def build_blocks(block, channels, num_block, dilations=[], act_type='relu'): 59 | if len(dilations) == 0: 60 | dilations = [1 for _ in range(num_block)] 61 | else: 62 | if len(dilations) != num_block: 63 | raise ValueError(f'Number of dilation should be equal to number of blocks') 64 | 65 | 
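    # Stack num_block CFP modules sequentially, one dilation rate per module.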
layers = [] 66 | for i in range(num_block): 67 | layers.append(block(channels, dilations[i], act_type=act_type)) 68 | return nn.Sequential(*layers) 69 | 70 | 71 | class CFPModule(nn.Module): 72 | def __init__(self, channels, rk, K=4, rk_ratio=None, act_type='prelu',): 73 | super().__init__() 74 | if rk_ratio is None: 75 | rk_ratio = [1/rk, 1/4, 1/2, 1] 76 | assert len(rk_ratio) == K, f'Length of rk_ratio should be {K}.\n' 77 | 78 | self.K = K 79 | channel_kn = channels // K 80 | 81 | self.conv_init = ConvBNAct(channels, channel_kn, 1, act_type=act_type) 82 | 83 | self.layers = nn.ModuleList() 84 | for k in range(K): 85 | dt = ceil(rk * rk_ratio[k]) # dilation 86 | self.layers.append(FeaturePyramidChannel(channel_kn, dt, act_type=act_type)) 87 | 88 | self.conv_last = ConvBNAct(channels, channels, 1, act_type=act_type) 89 | 90 | def forward(self, x): 91 | residual = x 92 | 93 | x = self.conv_init(x) # Projection 94 | 95 | transform_feats = [] # Parallel FP channels 96 | for i in range(self.K): 97 | transform_feats.append(self.layers[i](x)) 98 | 99 | for j in range(1, self.K): 100 | transform_feats[j] += transform_feats[j-1] 101 | 102 | x = torch.cat(transform_feats, dim=1) # Concatenation 103 | 104 | x = self.conv_last(x) 105 | 106 | x += residual 107 | 108 | return x 109 | 110 | 111 | class FeaturePyramidChannel(nn.Module): 112 | def __init__(self, channels, dilation, act_type, channel_split=[1,1,2]): 113 | super().__init__() 114 | split_num = sum(channel_split) 115 | assert channels % split_num == 0, f'Channel of FPC should be multiple of {split_num}.\n' 116 | ch_b1 = (channels // split_num) * channel_split[0] 117 | ch_b2 = (channels // split_num) * channel_split[1] 118 | ch_b3 = (channels // split_num) * channel_split[2] 119 | 120 | self.block1 = nn.Sequential( 121 | ConvBNAct(channels, ch_b1, (3, 1), dilation=dilation, act_type=act_type), 122 | ConvBNAct(ch_b1, ch_b1, (1, 3), dilation=dilation, act_type=act_type), 123 | ) 124 | self.block2 = nn.Sequential( 125 | ConvBNAct(ch_b1, ch_b2, (3, 1), dilation=dilation, act_type=act_type), 126 | ConvBNAct(ch_b2, ch_b2, (1, 3), dilation=dilation, act_type=act_type), 127 | ) 128 | self.block3 = nn.Sequential( 129 | ConvBNAct(ch_b2, ch_b3, (3, 1), dilation=dilation, act_type=act_type), 130 | ConvBNAct(ch_b3, ch_b3, (1, 3), dilation=dilation, act_type=act_type), 131 | ) 132 | 133 | def forward(self, x): 134 | x1 = self.block1(x) 135 | x2 = self.block2(x1) 136 | x3 = self.block3(x2) 137 | 138 | x = torch.cat([x1, x2, x3], dim=1) 139 | 140 | return x -------------------------------------------------------------------------------- /models/cgnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: CGNet: A Light-weight Context Guided Network for Semantic Segmentation 3 | Url: https://arxiv.org/abs/1811.08201 4 | Create by: zh320 5 | Date: 2023/09/24 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import conv1x1, ConvBNAct, Activation 13 | from .model_registry import register_model 14 | 15 | 16 | @register_model() 17 | class CGNet(nn.Module): 18 | def __init__(self, num_class=1, n_channel=3, M=3, N=15, act_type='prelu'): 19 | super().__init__() 20 | self.stage1 = InitBlock(n_channel, 32, act_type=act_type) 21 | self.stage2_down = CGBlock(64, 64, 2, 2, act_type=act_type) 22 | self.stage2 = build_blocks(CGBlock, 64+3, 64, 2, M-1, act_type) 23 | self.stage3_down = CGBlock(128, 128, 2, 4, act_type=act_type) 24 | self.stage3 = 
build_blocks(CGBlock, 128+3, 128, 4, N-1, act_type) 25 | self.seg_head = conv1x1(128*2, num_class) 26 | 27 | def forward(self, x): 28 | size = x.size()[2:] 29 | x_d4 = F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=True) 30 | x_d8 = F.interpolate(x, scale_factor=0.125, mode='bilinear', align_corners=True) 31 | 32 | x, x1 = self.stage1(x) 33 | 34 | x = torch.cat([x, x1], dim=1) 35 | x2 = self.stage2_down(x) 36 | x = torch.cat([x2, x_d4], dim=1) # Input injection 37 | x = self.stage2(x) 38 | 39 | x = torch.cat([x, x2], dim=1) 40 | x3 = self.stage3_down(x) 41 | x = torch.cat([x3, x_d8], dim=1) # Input injection 42 | x = self.stage3(x) 43 | 44 | x = torch.cat([x, x3], dim=1) 45 | x = self.seg_head(x) 46 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 47 | 48 | return x 49 | 50 | 51 | class InitBlock(nn.Module): 52 | def __init__(self, in_channels, out_channels, act_type): 53 | super().__init__() 54 | self.conv0 = ConvBNAct(in_channels, out_channels, stride=2, act_type=act_type) 55 | self.conv1 = ConvBNAct(out_channels, out_channels, act_type=act_type) 56 | self.conv2 = ConvBNAct(out_channels, out_channels, act_type=act_type) 57 | 58 | def forward(self, x): 59 | x0 = self.conv0(x) 60 | x = self.conv1(x0) 61 | x = self.conv2(x) 62 | return x, x0 63 | 64 | 65 | def build_blocks(block, in_channels, out_channels, dilation, num_block, act_type): 66 | layers = [] 67 | for _ in range(num_block): 68 | layers.append(block(in_channels, out_channels, 1, dilation, act_type=act_type)) 69 | in_channels = out_channels 70 | return nn.Sequential(*layers) 71 | 72 | 73 | class CGBlock(nn.Module): 74 | def __init__(self, in_channels, out_channels, stride, dilation, res_type='GRL', act_type='prelu'): 75 | super().__init__() 76 | if res_type not in ['GRL', 'LRL']: 77 | raise ValueError('Residual learning only support GRL and LRL type.\n') 78 | self.res_type = res_type 79 | self.use_skip = (stride == 1) and (in_channels == out_channels) 80 | self.conv = conv1x1(in_channels, out_channels//2) 81 | self.loc = nn.Conv2d(out_channels//2, out_channels//2, 3, stride, padding=1, 82 | groups=out_channels//2, bias=False) 83 | self.sur = nn.Conv2d(out_channels//2, out_channels//2, 3, stride, padding=dilation, 84 | dilation=dilation, groups=out_channels//2, bias=False) 85 | self.joi = nn.Sequential( 86 | nn.BatchNorm2d(out_channels), 87 | Activation(act_type) 88 | ) 89 | self.glo = nn.Sequential( 90 | nn.Linear(out_channels, out_channels//8), 91 | nn.Linear(out_channels//8, out_channels) 92 | ) 93 | 94 | def forward(self, x): 95 | residual = x 96 | x = self.conv(x) 97 | 98 | x_loc = self.loc(x) 99 | x_sur = self.sur(x) 100 | 101 | x = torch.cat([x_loc, x_sur], dim=1) 102 | x = self.joi(x) 103 | 104 | if self.use_skip and self.res_type == 'LRL': 105 | x += residual 106 | 107 | x_glo = F.adaptive_avg_pool2d(x, 1).squeeze(-1).squeeze(-1) 108 | x_glo = torch.sigmoid(self.glo(x_glo)) 109 | x_glo = x_glo.unsqueeze(-1).unsqueeze(-1).expand_as(x) 110 | x = x * x_glo 111 | 112 | if self.use_skip and self.res_type == 'GRL': 113 | x += residual 114 | 115 | return x -------------------------------------------------------------------------------- /models/contextnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: ContextNet: Exploring Context and Detail for Semantic Segmentation in Real-time 3 | Url: https://arxiv.org/abs/1805.04554 4 | Create by: zh320 5 | Date: 2023/05/13 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import 
torch.nn.functional as F 11 | 12 | from .modules import conv1x1, DSConvBNAct, DWConvBNAct, PWConvBNAct, ConvBNAct, Activation 13 | from .model_registry import register_model 14 | 15 | 16 | @register_model() 17 | class ContextNet(nn.Module): 18 | def __init__(self, num_class=1, n_channel=3, act_type='relu'): 19 | super().__init__() 20 | self.full_res_branch = Branch_1(n_channel, [32, 64, 128], 128, act_type=act_type) 21 | self.lower_res_branch = Branch_4(n_channel, 128, act_type=act_type) 22 | self.feature_fusion = FeatureFusion(128, 128, 128, act_type=act_type) 23 | self.classifier = ConvBNAct(128, num_class, 1, act_type=act_type) 24 | 25 | def forward(self, x): 26 | size = x.size()[2:] 27 | x_lower = F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=True) 28 | full_res_feat = self.full_res_branch(x) 29 | lower_res_feat = self.lower_res_branch(x_lower) 30 | x = self.feature_fusion(full_res_feat, lower_res_feat) 31 | x = self.classifier(x) 32 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 33 | 34 | return x 35 | 36 | 37 | class Branch_1(nn.Sequential): 38 | def __init__(self, in_channels, hid_channels, out_channels, act_type='relu'): 39 | assert len(hid_channels) == 3 40 | super().__init__( 41 | ConvBNAct(in_channels, hid_channels[0], 3, 2, act_type=act_type), 42 | DWConvBNAct(hid_channels[0], hid_channels[0], 3, 1, act_type='none'), 43 | PWConvBNAct(hid_channels[0], hid_channels[1], act_type=act_type), 44 | DWConvBNAct(hid_channels[1], hid_channels[1], 3, 1, act_type='none'), 45 | PWConvBNAct(hid_channels[1], hid_channels[2], act_type=act_type), 46 | DWConvBNAct(hid_channels[2], hid_channels[2], 3, 1, act_type='none'), 47 | PWConvBNAct(hid_channels[2], out_channels, act_type=act_type) 48 | ) 49 | 50 | 51 | class Branch_4(nn.Module): 52 | def __init__(self, in_channels, out_channels, act_type='relu'): 53 | super().__init__() 54 | self.conv_init = ConvBNAct(in_channels, 32, 3, 2, act_type=act_type) 55 | inverted_residual_setting = [ 56 | # t, c, n, s 57 | [1, 32, 1, 1], 58 | [6, 32, 1, 1], 59 | [6, 48, 3, 2], 60 | [6, 64, 3, 2], 61 | [6, 96, 2, 1], 62 | [6, 128, 2, 1], 63 | ] 64 | 65 | # Building inverted residual blocks, codes borrowed from 66 | # https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py 67 | features = [] 68 | in_channels = 32 69 | for t, c, n, s in inverted_residual_setting: 70 | for i in range(n): 71 | stride = s if i == 0 else 1 72 | features.append(InvertedResidual(in_channels, c, stride, t, act_type=act_type)) 73 | in_channels = c 74 | self.bottlenecks = nn.Sequential(*features) 75 | self.conv_last = ConvBNAct(128, out_channels, 3, 1, act_type=act_type) 76 | 77 | def forward(self, x): 78 | x = self.conv_init(x) 79 | x = self.bottlenecks(x) 80 | x = self.conv_last(x) 81 | 82 | return x 83 | 84 | 85 | class FeatureFusion(nn.Module): 86 | def __init__(self, branch_1_channels, branch_4_channels, out_channels, act_type='relu'): 87 | super().__init__() 88 | self.branch_1_conv = conv1x1(branch_1_channels, out_channels) 89 | self.branch_4_conv = nn.Sequential( 90 | DSConvBNAct(branch_4_channels, out_channels, 3, dilation=4, act_type='none'), 91 | conv1x1(out_channels, out_channels) 92 | ) 93 | self.act = Activation(act_type=act_type) 94 | 95 | def forward(self, branch_1_feat, branch_4_feat): 96 | size = branch_1_feat.size()[2:] 97 | 98 | branch_1_feat = self.branch_1_conv(branch_1_feat) 99 | 100 | branch_4_feat = F.interpolate(branch_4_feat, size, mode='bilinear', align_corners=True) 101 | branch_4_feat = 
self.branch_4_conv(branch_4_feat) 102 | 103 | res = branch_1_feat + branch_4_feat 104 | res = self.act(res) 105 | 106 | return res 107 | 108 | 109 | class InvertedResidual(nn.Module): 110 | def __init__(self, in_channels, out_channels, stride, expand_ratio=6, act_type='relu'): 111 | super().__init__() 112 | hid_channels = int(round(in_channels * expand_ratio)) 113 | self.use_res_connect = stride == 1 and in_channels == out_channels 114 | 115 | self.conv = nn.Sequential( 116 | PWConvBNAct(in_channels, hid_channels, act_type=act_type), 117 | DWConvBNAct(hid_channels, hid_channels, 3, stride, act_type=act_type), 118 | ConvBNAct(hid_channels, out_channels, 1, act_type='none') 119 | ) 120 | 121 | def forward(self, x): 122 | if self.use_res_connect: 123 | return x + self.conv(x) 124 | else: 125 | return self.conv(x) -------------------------------------------------------------------------------- /models/dabnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: DABNet: Depth-wise Asymmetric Bottleneck for Real-time Semantic Segmentation 3 | Url: https://arxiv.org/abs/1907.11357 4 | Create by: zh320 5 | Date: 2023/08/27 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import conv1x1, DWConvBNAct, ConvBNAct 13 | from .enet import InitialBlock 14 | from .model_registry import register_model 15 | 16 | 17 | @register_model() 18 | class DABNet(nn.Module): 19 | def __init__(self, num_class=1, n_channel=3, act_type='prelu'): 20 | super().__init__() 21 | self.layer1 = ConvBNAct(n_channel, 32, 3, 2, act_type=act_type) 22 | self.layer2 = ConvBNAct(32, 32, 3, 1, act_type=act_type) 23 | self.layer3 = ConvBNAct(32, 32, 3, 1, act_type=act_type) 24 | self.layer4 = InitialBlock(32+n_channel, 64, act_type=act_type) 25 | self.layer5_7 = build_blocks(DABModule, 64, 3, dilation=2, act_type=act_type) 26 | self.layer8 = ConvBNAct(64*2+n_channel, 128, 3, 2, act_type=act_type) 27 | self.layer9_10 = build_blocks(DABModule, 128, 2, dilation=4, act_type=act_type) 28 | self.layer11_12 = build_blocks(DABModule, 128, 2, dilation=8, act_type=act_type) 29 | self.layer13_14 = build_blocks(DABModule, 128, 2, dilation=16, act_type=act_type) 30 | self.layer15 = conv1x1(128*2+n_channel, num_class) 31 | 32 | def forward(self, x): 33 | size = x.size()[2:] 34 | x_d2 = F.avg_pool2d(x, 3, 2, 1) 35 | x_d4 = F.avg_pool2d(x_d2, 3, 2, 1) 36 | x_d8 = F.avg_pool2d(x_d4, 3, 2, 1) 37 | 38 | # Stage 1 39 | x = self.layer1(x) 40 | x = self.layer2(x) 41 | x = self.layer3(x) 42 | x = torch.cat([x, x_d2], dim=1) 43 | 44 | # Stage 2 45 | x = self.layer4(x) 46 | x_block1 = x 47 | x = self.layer5_7(x) 48 | x = torch.cat([x, x_block1], dim=1) 49 | x = torch.cat([x, x_d4], dim=1) 50 | 51 | # Stage 3 52 | x = self.layer8(x) 53 | x_block2 = x 54 | x = self.layer9_10(x) 55 | x = self.layer11_12(x) 56 | x = self.layer13_14(x) 57 | x = torch.cat([x, x_block2], dim=1) 58 | x = torch.cat([x, x_d8], dim=1) 59 | 60 | # Stage 4 61 | x = self.layer15(x) 62 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 63 | return x 64 | 65 | 66 | def build_blocks(block, channels, num_block, dilation, act_type): 67 | layers = [] 68 | for _ in range(num_block): 69 | layers.append(block(channels, dilation, act_type=act_type)) 70 | return nn.Sequential(*layers) 71 | 72 | 73 | class DABModule(nn.Module): 74 | def __init__(self, channels, dilation, act_type): 75 | super().__init__() 76 | assert channels % 2 == 0, 'Input channel of DABModule should be 
a multiple of 2.\n' 77 | hid_channels = channels // 2 78 | self.init_conv = ConvBNAct(channels, hid_channels, 3, act_type=act_type) 79 | self.left_branch = nn.Sequential( 80 | DWConvBNAct(hid_channels, hid_channels, (3,1), act_type=act_type), 81 | DWConvBNAct(hid_channels, hid_channels, (1,3), act_type=act_type) 82 | ) 83 | self.right_branch = nn.Sequential( 84 | DWConvBNAct(hid_channels, hid_channels, (3,1), dilation=dilation, act_type=act_type), 85 | DWConvBNAct(hid_channels, hid_channels, (1,3), dilation=dilation, act_type=act_type) 86 | ) 87 | self.last_conv = ConvBNAct(hid_channels, channels, 1, act_type=act_type) 88 | 89 | def forward(self, x): 90 | residual = x 91 | x = self.init_conv(x) 92 | 93 | x_left = self.left_branch(x) 94 | x_right = self.right_branch(x) 95 | x = x_left + x_right 96 | 97 | x = self.last_conv(x) 98 | x += residual 99 | 100 | return x -------------------------------------------------------------------------------- /models/dfanet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: DFANet: Deep Feature Aggregation for Real-Time Semantic Segmentation 3 | Url: https://arxiv.org/abs/1904.02216 4 | Create by: zh320 5 | Date: 2023/10/22 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import conv1x1, DSConvBNAct, DWConvBNAct, ConvBNAct, Activation, SegHead 13 | from .model_registry import register_model 14 | 15 | 16 | @register_model() 17 | class DFANet(nn.Module): 18 | def __init__(self, num_class=1, n_channel=3, backbone_type='XceptionA', expansion=4, 19 | repeat_times=[4,6,4], use_extra_backbone=True, act_type='relu'): 20 | super().__init__() 21 | assert len(repeat_times) == 3 22 | if backbone_type == 'XceptionA': 23 | channels = [48, 96, 192] 24 | elif backbone_type == 'XceptionB': 25 | channels = [32, 64, 128] 26 | else: 27 | raise NotImplementedError() 28 | self.use_extra_backbone = use_extra_backbone 29 | 30 | self.conv1 = ConvBNAct(n_channel, 8, 3, 2, act_type=act_type) 31 | 32 | in_channels = [8, channels[0], channels[1]] 33 | self.backbone1 = Encoder(in_channels, channels, expansion, repeat_times, act_type) 34 | 35 | if self.use_extra_backbone: 36 | # Rotate the channels to perform feature fusion 37 | new_channels = channels[2:] + channels[:2] 38 | in_channels = [(channels[i] + new_channels[i]) for i in range(len(channels))] 39 | self.backbone2 = Encoder(in_channels, channels, expansion, repeat_times, act_type) 40 | self.backbone3 = Encoder(in_channels, channels, expansion, repeat_times, act_type) 41 | 42 | self.decoder = Decoder(channels[0], channels[2], num_class, act_type) 43 | else: 44 | self.seg_head = SegHead(channels[2], num_class, act_type) 45 | 46 | def forward(self, x): 47 | x = self.conv1(x) 48 | 49 | x, x_enc2, x_enc3, x_enc4 = self.backbone1(x) 50 | 51 | if self.use_extra_backbone: 52 | enc_x1, fc_x1 = x_enc2, x 53 | x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=True) 54 | 55 | x, x_enc2, x_enc3, x_enc4 = self.backbone2(x, x_enc2, x_enc3, x_enc4) 56 | enc_x2, fc_x2 = x_enc2, x 57 | x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=True) 58 | 59 | fc_x3, enc_x3, _, _ = self.backbone3(x, x_enc2, x_enc3, x_enc4) 60 | 61 | x = self.decoder(enc_x1, enc_x2, enc_x3, fc_x1, fc_x2, fc_x3) 62 | else: 63 | x = self.seg_head(x) 64 | x = F.interpolate(x, scale_factor=16, mode='bilinear', align_corners=True) 65 | 66 | return x 67 | 68 | 69 | class Encoder(nn.Module): 70 | def __init__(self, in_channels,
channels, expansion, repeat_times, act_type): 71 | super().__init__() 72 | assert len(in_channels) == 3 73 | self.enc2 = EncoderBlock(in_channels[0], channels[0], expansion, repeat_times[0], act_type) 74 | self.enc3 = EncoderBlock(in_channels[1], channels[1], expansion, repeat_times[1], act_type) 75 | self.enc4 = EncoderBlock(in_channels[2], channels[2], expansion, repeat_times[2], act_type) 76 | self.fc_attention = FCAttention(channels[2], act_type) 77 | 78 | def forward(self, x, x_enc2=None, x_enc3=None, x_enc4=None): 79 | if x_enc2 is not None: 80 | x = torch.cat([x, x_enc2], dim=1) 81 | x = self.enc2(x) 82 | x_enc2 = x 83 | 84 | if x_enc3 is not None: 85 | x = torch.cat([x, x_enc3], dim=1) 86 | x = self.enc3(x) 87 | x_enc3 = x 88 | 89 | if x_enc4 is not None: 90 | x = torch.cat([x, x_enc4], dim=1) 91 | x = self.enc4(x) 92 | x_enc4 = x 93 | 94 | x = self.fc_attention(x) 95 | 96 | return x, x_enc2, x_enc3, x_enc4 97 | 98 | 99 | class Decoder(nn.Module): 100 | def __init__(self, enc_channels, fc_channels, num_class, act_type, hid_channels=48): 101 | super().__init__() 102 | self.enc_conv1 = ConvBNAct(enc_channels, hid_channels, 3, act_type=act_type, inplace=True) 103 | self.enc_conv2 = ConvBNAct(enc_channels, hid_channels, 3, act_type=act_type, inplace=True) 104 | self.enc_conv3 = ConvBNAct(enc_channels, hid_channels, 3, act_type=act_type, inplace=True) 105 | self.conv_enc = conv1x1(hid_channels, num_class) 106 | 107 | self.fc_conv1 = SegHead(fc_channels, num_class, act_type) 108 | self.fc_conv2 = SegHead(fc_channels, num_class, act_type) 109 | self.fc_conv3 = SegHead(fc_channels, num_class, act_type) 110 | 111 | def forward(self, enc_x1, enc_x2, enc_x3, fc_x1, fc_x2, fc_x3): 112 | enc_x1 = self.enc_conv1(enc_x1) 113 | enc_x2 = self.enc_conv2(enc_x2) 114 | enc_x2 = F.interpolate(enc_x2, scale_factor=2, mode='bilinear', align_corners=True) 115 | enc_x3 = self.enc_conv3(enc_x3) 116 | enc_x3 = F.interpolate(enc_x3, scale_factor=4, mode='bilinear', align_corners=True) 117 | 118 | enc_x = enc_x1 + enc_x2 + enc_x3 119 | enc_x = self.conv_enc(enc_x) 120 | 121 | fc_x1 = self.fc_conv1(fc_x1) 122 | fc_x1 = F.interpolate(fc_x1, scale_factor=4, mode='bilinear', align_corners=True) 123 | fc_x2 = self.fc_conv2(fc_x2) 124 | fc_x2 = F.interpolate(fc_x2, scale_factor=8, mode='bilinear', align_corners=True) 125 | fc_x3 = self.fc_conv3(fc_x3) 126 | fc_x3 = F.interpolate(fc_x3, scale_factor=16, mode='bilinear', align_corners=True) 127 | 128 | x = enc_x + fc_x1 + fc_x2 + fc_x3 129 | x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=True) 130 | 131 | return x 132 | 133 | 134 | class EncoderBlock(nn.Module): 135 | def __init__(self, in_channels, out_channels, expansion, repeat_times, act_type): 136 | super().__init__() 137 | layers = [XceptionBlock(in_channels, out_channels, 2, expansion, act_type)] 138 | 139 | for _ in range(repeat_times-1): 140 | layers.append(XceptionBlock(out_channels, out_channels, 1, expansion, act_type)) 141 | self.conv = nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | return self.conv(x) 145 | 146 | 147 | class FCAttention(nn.Module): 148 | def __init__(self, channels, act_type, linear_channels=1000): 149 | super().__init__() 150 | self.channels = channels 151 | self.pool = nn.AdaptiveMaxPool2d(1) 152 | self.linear = nn.Linear(channels, linear_channels) 153 | self.conv = ConvBNAct(linear_channels, channels, 1, act_type=act_type, inplace=True) 154 | 155 | def forward(self, x): 156 | attention = self.pool(x).view(-1, self.channels) 157 | attention = 
self.linear(attention) 158 | attention = attention.unsqueeze(-1).unsqueeze(-1) 159 | attention = self.conv(attention) 160 | x = x * attention 161 | 162 | return x 163 | 164 | 165 | class XceptionBlock(nn.Module): 166 | def __init__(self, in_channels, out_channels, stride, expansion, act_type): 167 | super().__init__() 168 | self.use_skip = (in_channels == out_channels) and (stride == 1) 169 | self.stride = stride 170 | hid_channels = out_channels // expansion 171 | self.conv = nn.Sequential( 172 | # Activation(act_type, inplace=True), 173 | DSConvBNAct(in_channels, hid_channels, 3, act_type=act_type), 174 | DSConvBNAct(hid_channels, hid_channels, 3, act_type=act_type), 175 | DWConvBNAct(hid_channels, out_channels, 3, stride, act_type=act_type, inplace=True), 176 | conv1x1(out_channels, out_channels), 177 | Activation(act_type), 178 | ) 179 | if stride > 1: 180 | self.conv_stride = conv1x1(in_channels, out_channels, 2) 181 | 182 | def forward(self, x): 183 | if self.use_skip: 184 | residual = x 185 | 186 | x_right = self.conv(x) 187 | 188 | if self.stride > 1: 189 | x_left = self.conv_stride(x) 190 | x_right += x_left 191 | 192 | if self.use_skip: 193 | x_right += residual 194 | 195 | return x_right -------------------------------------------------------------------------------- /models/edanet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: Efficient Dense Modules of Asymmetric Convolution for Real-Time Semantic Segmentation 3 | Url: https://arxiv.org/abs/1809.06323 4 | Create by: zh320 5 | Date: 2023/09/24 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import conv1x1, conv3x3, ConvBNAct, Activation 13 | from .model_registry import register_model 14 | 15 | 16 | @register_model() 17 | class EDANet(nn.Module): 18 | def __init__(self, num_class=1, n_channel=3, k=40, num_b1=5, num_b2=8, act_type='relu'): 19 | super().__init__() 20 | self.stage1 = DownsamplingBlock(n_channel, 15, act_type) 21 | self.stage2_d = DownsamplingBlock(15, 60, act_type) 22 | self.stage2 = EDABlock(60, k, num_b1, [1,1,1,2,2], act_type) 23 | self.stage3_d = ConvBNAct(260, 130, 3, 2, act_type=act_type) 24 | self.stage3 = EDABlock(130, k, num_b2, [2,2,4,4,8,8,16,16], act_type) 25 | self.project = conv1x1(130+k*num_b2, num_class) 26 | 27 | def forward(self, x): 28 | size = x.size()[2:] 29 | x = self.stage1(x) 30 | x = self.stage2_d(x) 31 | x = self.stage2(x) 32 | x = self.stage3_d(x) 33 | x = self.stage3(x) 34 | x = self.project(x) 35 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 36 | 37 | return x 38 | 39 | 40 | class DownsamplingBlock(nn.Module): 41 | def __init__(self, in_channels, out_channels, act_type): 42 | super().__init__() 43 | self.conv = conv3x3(in_channels, out_channels - in_channels, 2) 44 | self.pool = nn.MaxPool2d(2, 2) 45 | self.bn_act = nn.Sequential( 46 | nn.BatchNorm2d(out_channels), 47 | Activation(act_type) 48 | ) 49 | 50 | def forward(self, x): 51 | x = torch.cat([self.conv(x), self.pool(x)], dim=1) 52 | return self.bn_act(x) 53 | 54 | 55 | class EDABlock(nn.Module): 56 | def __init__(self, in_channels, k, num_block, dilations, act_type): 57 | super().__init__() 58 | assert len(dilations) == num_block, 'number of dilation rates should equal the number of blocks' 59 | 60 | layers = [] 61 | for i in range(num_block): 62 | dt = dilations[i] 63 | layers.append(EDAModule(in_channels, k, dt, act_type)) 64 | in_channels += k 65 | self.layers = nn.Sequential(*layers)
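# A minimal arithmetic sketch of the dense connectivity above: every EDAModule
# concatenates its k new feature maps onto its input, so an EDABlock widens its
# features from in_channels to in_channels + k * num_block. With the stage-3
# defaults used above (in_channels=130, k=40, num_b2=8):
#
#   channels = 130
#   for _ in range(8):
#       channels += 40              # each EDAModule appends k feature maps
#   assert channels == 130 + 40 * 8 == 450
#
# which is why the classifier is declared as conv1x1(130 + k*num_b2, num_class).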
66 | 67 | def forward(self, x): 68 | return self.layers(x) 69 | 70 | 71 | class EDAModule(nn.Module): 72 | def __init__(self, in_channels, k, dilation=1, act_type='relu'): 73 | super().__init__() 74 | self.conv = nn.Sequential( 75 | ConvBNAct(in_channels, k, 1), 76 | nn.Conv2d(k, k, (3, 1), padding=(1, 0), bias=False), 77 | ConvBNAct(k, k, (1, 3), act_type=act_type), 78 | nn.Conv2d(k, k, (3, 1), dilation=dilation, 79 | padding=(dilation, 0), bias=False), 80 | ConvBNAct(k, k, (1, 3), dilation=dilation, act_type=act_type) 81 | ) 82 | 83 | def forward(self, x): 84 | residual = x 85 | x = self.conv(x) 86 | x = torch.cat([x, residual], dim=1) 87 | return x -------------------------------------------------------------------------------- /models/erfnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: ERFNet: Efficient Residual Factorized ConvNet for Real-Time Semantic Segmentation 3 | Url: https://ieeexplore.ieee.org/document/8063438 4 | Create by: zh320 5 | Date: 2023/08/20 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from .modules import ConvBNAct, DeConvBNAct, Activation 12 | from .enet import InitialBlock as DownsamplerBlock 13 | from .model_registry import register_model 14 | 15 | 16 | @register_model() 17 | class ERFNet(nn.Module): 18 | def __init__(self, num_class=1, n_channel=3, act_type='relu'): 19 | super().__init__() 20 | self.layer1 = DownsamplerBlock(n_channel, 16, act_type=act_type) 21 | 22 | self.layer2 = DownsamplerBlock(16, 64, act_type=act_type) 23 | self.layer3_7 = build_blocks(NonBt1DBlock, 64, 5, act_type=act_type) 24 | 25 | self.layer8 = DownsamplerBlock(64, 128, act_type=act_type) 26 | self.layer9_16 = build_blocks(NonBt1DBlock, 128, 8, 27 | dilations=[2,4,8,16,2,4,8,16], act_type=act_type) 28 | 29 | self.layer17 = DeConvBNAct(128, 64, act_type=act_type) 30 | self.layer18_19 = build_blocks(NonBt1DBlock, 64, 2, act_type=act_type) 31 | 32 | self.layer20 = DeConvBNAct(64, 16, act_type=act_type) 33 | self.layer21_22 = build_blocks(NonBt1DBlock, 16, 2, act_type=act_type) 34 | 35 | self.layer23 = DeConvBNAct(16, num_class, act_type=act_type) 36 | 37 | def forward(self, x): 38 | x = self.layer1(x) 39 | x = self.layer2(x) 40 | x = self.layer3_7(x) 41 | x = self.layer8(x) 42 | x = self.layer9_16(x) 43 | x = self.layer17(x) 44 | x = self.layer18_19(x) 45 | x = self.layer20(x) 46 | x = self.layer21_22(x) 47 | x = self.layer23(x) 48 | return x 49 | 50 | 51 | def build_blocks(block, channels, num_block, dilations=[], act_type='relu'): 52 | if len(dilations) == 0: 53 | dilations = [1 for _ in range(num_block)] 54 | else: 55 | if len(dilations) != num_block: 56 | raise ValueError('Number of dilations should equal the number of blocks') 57 | 58 | layers = [] 59 | for i in range(num_block): 60 | layers.append(block(channels, dilation=dilations[i], act_type=act_type)) 61 | return nn.Sequential(*layers) 62 | 63 | 64 | class NonBt1DBlock(nn.Module): 65 | def __init__(self, channels, dilation=1, act_type='relu'): 66 | super().__init__() 67 | self.conv = nn.Sequential( 68 | ConvBNAct(channels, channels, (3, 1), inplace=True), 69 | ConvBNAct(channels, channels, (1, 3), inplace=True), 70 | ConvBNAct(channels, channels, (3, 1), dilation=dilation, inplace=True), 71 | nn.Conv2d(channels, channels, (1, 3), dilation=dilation, 72 | padding=(0, dilation), bias=False) 73 | ) 74 | self.bn_act = nn.Sequential( 75 | nn.BatchNorm2d(channels), 76 | Activation(act_type, inplace=True) 77 | ) 78 | 79 | def forward(self, x): 80 | residual
= x 81 | x = self.conv(x) 82 | x += residual 83 | x = self.bn_act(x) 84 | return x -------------------------------------------------------------------------------- /models/esnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: ESNet: An Efficient Symmetric Network for Real-time Semantic Segmentation 3 | Url: https://arxiv.org/abs/1906.09826 4 | Create by: zh320 5 | Date: 2023/09/24 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import ConvBNAct, DeConvBNAct, Activation 13 | from .enet import InitialBlock as DownsamplingUnit 14 | from .model_registry import register_model 15 | 16 | 17 | @register_model() 18 | class ESNet(nn.Module): 19 | def __init__(self, num_class=1, n_channel=3, act_type='relu'): 20 | super().__init__() 21 | self.block1_down = DownsamplingUnit(n_channel, 16, act_type) 22 | self.block1 = build_blocks('fcu', 16, 3, K=3, act_type=act_type) 23 | self.block2_down = DownsamplingUnit(16, 64, act_type) 24 | self.block2 = build_blocks('fcu', 64, 2, K=5, act_type=act_type) 25 | self.block3_down = DownsamplingUnit(64, 128, act_type) 26 | self.block3 = build_blocks('pfcu', 128, 3, r1=2, r2=5, r3=9, act_type=act_type) 27 | self.block4_up = DeConvBNAct(128, 64, act_type=act_type) 28 | self.block4 = build_blocks('fcu', 64, 2, K=5, act_type=act_type) 29 | self.block5_up = DeConvBNAct(64, 16, act_type=act_type) 30 | self.block5 = build_blocks('fcu', 16, 2, K=3, act_type=act_type) 31 | self.full_conv = DeConvBNAct(16, num_class, act_type=act_type) 32 | 33 | def forward(self, x): 34 | x = self.block1_down(x) 35 | x = self.block1(x) 36 | 37 | x = self.block2_down(x) 38 | x = self.block2(x) 39 | 40 | x = self.block3_down(x) 41 | x = self.block3(x) 42 | 43 | x = self.block4_up(x) 44 | x = self.block4(x) 45 | 46 | x = self.block5_up(x) 47 | x = self.block5(x) 48 | 49 | x = self.full_conv(x) 50 | 51 | return x 52 | 53 | 54 | def build_blocks(block_type, channels, num_block, K=None, r1=None, r2=None, r3=None, 55 | act_type='relu'): 56 | layers = [] 57 | for _ in range(num_block): 58 | if block_type == 'fcu': 59 | layers.append(FCU(channels, K, act_type)) 60 | elif block_type == 'pfcu': 61 | layers.append(PFCU(channels, r1, r2, r3, act_type)) 62 | else: 63 | raise NotImplementedError(f'Unsupported block type: {block_type}.\n') 64 | return nn.Sequential(*layers) 65 | 66 | 67 | class FCU(nn.Module): 68 | def __init__(self, channels, K, act_type): 69 | super().__init__() 70 | assert K is not None, 'K should not be None.\n' 71 | padding = (K - 1) // 2 72 | self.conv = nn.Sequential( 73 | nn.Conv2d(channels, channels, (K, 1), padding=(padding, 0), bias=False), 74 | Activation(act_type, inplace=True), 75 | ConvBNAct(channels, channels, (1, K), act_type=act_type, inplace=True), 76 | nn.Conv2d(channels, channels, (K, 1), padding=(padding, 0), bias=False), 77 | Activation(act_type, inplace=True), 78 | ConvBNAct(channels, channels, (1, K), act_type='none') 79 | ) 80 | self.act = Activation(act_type) 81 | 82 | def forward(self, x): 83 | residual = x 84 | 85 | x = self.conv(x) 86 | x += residual 87 | 88 | return self.act(x) 89 | 90 | 91 | class PFCU(nn.Module): 92 | def __init__(self, channels, r1, r2, r3, act_type): 93 | super().__init__() 94 | assert (r1 is not None) and (r2 is not None) and (r3 is not None) 95 | 96 | self.conv0 = nn.Sequential( 97 | nn.Conv2d(channels, channels, (3, 1), padding=(1, 0), bias=False), 98 | Activation(act_type, inplace=True), 99 | ConvBNAct(channels, 
channels, (1, 3), act_type=act_type, inplace=True) 100 | ) 101 | self.conv_left = nn.Sequential( 102 | nn.Conv2d(channels, channels, (3, 1), padding=(r1, 0), 103 | dilation=r1, bias=False), 104 | Activation(act_type, inplace=True), 105 | ConvBNAct(channels, channels, (1, 3), dilation=r1, act_type='none') 106 | ) 107 | self.conv_mid = nn.Sequential( 108 | nn.Conv2d(channels, channels, (3, 1), padding=(r2, 0), 109 | dilation=r2, bias=False), 110 | Activation(act_type, inplace=True), 111 | ConvBNAct(channels, channels, (1, 3), dilation=r2, act_type='none') 112 | ) 113 | self.conv_right = nn.Sequential( 114 | nn.Conv2d(channels, channels, (3, 1), padding=(r3, 0), 115 | dilation=r3, bias=False), 116 | Activation(act_type, inplace=True), 117 | ConvBNAct(channels, channels, (1, 3), dilation=r3, act_type='none') 118 | ) 119 | self.act = Activation(act_type) 120 | 121 | def forward(self, x): 122 | residual = x 123 | 124 | x = self.conv0(x) 125 | 126 | x_left = self.conv_left(x) 127 | x_mid = self.conv_mid(x) 128 | x_right = self.conv_right(x) 129 | 130 | x = x_left + x_mid + x_right + residual 131 | 132 | return self.act(x) -------------------------------------------------------------------------------- /models/espnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation 3 | Url: https://arxiv.org/abs/1803.06815 4 | Create by: zh320 5 | Date: 2023/08/06 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import conv1x1, conv3x3, ConvBNAct, DeConvBNAct, Activation 13 | from .model_registry import register_model 14 | 15 | 16 | @register_model() 17 | class ESPNet(nn.Module): 18 | def __init__(self, num_class=1, n_channel=3, arch_type='espnet', K=5, alpha2=2, 19 | alpha3=8, block_channel=[16, 64, 128], act_type='prelu'): 20 | super().__init__() 21 | arch_hub = ['espnet', 'espnet-a', 'espnet-b', 'espnet-c'] 22 | if arch_type not in arch_hub: 23 | raise ValueError(f'Unsupported architecture type: {arch_type}.\n') 24 | self.arch_type = arch_type 25 | 26 | use_skip = arch_type in ['espnet', 'espnet-b', 'espnet-c'] 27 | reinforce = arch_type in ['espnet', 'espnet-c'] 28 | use_decoder = arch_type in ['espnet'] 29 | 30 | if arch_type == 'espnet-a': 31 | block_channel = block_channel[:2] + [block_channel[1]] # rebind instead of mutating the shared default list 32 | 33 | self.use_skip = use_skip 34 | self.reinforce = reinforce 35 | self.use_decoder = use_decoder 36 | 37 | self.l1_block = ConvBNAct(n_channel, block_channel[0], 3, 2, act_type=act_type) 38 | self.l2_block = L2Block(block_channel[0], block_channel[1], arch_type, alpha2, use_skip, reinforce, act_type) 39 | self.l3_block = L3Block(block_channel[2], num_class, arch_type, alpha3, use_skip, reinforce, use_decoder, act_type) 40 | 41 | if use_decoder: 42 | self.decoder = Decoder(num_class, 19, 131, act_type) 43 | 44 | def forward(self, x): 45 | x_input = x 46 | x = self.l1_block(x) 47 | if self.reinforce: 48 | size = x.size()[2:] 49 | x_half = F.interpolate(x_input, size, mode='bilinear') 50 | x = torch.cat([x, x_half], dim=1) 51 | if self.use_decoder: 52 | x_l1 = x 53 | 54 | if self.reinforce: 55 | x = self.l2_block(x, x_input) 56 | if self.use_decoder: 57 | x_l2 = x 58 | else: 59 | x = self.l2_block(x) 60 | 61 | x = self.l3_block(x) 62 | 63 | if self.use_decoder: 64 | x = self.decoder(x, x_l1, x_l2) 65 | else: 66 | size = x_input.size()[2:] 67 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 68 | 69 | return x 70 |
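# A quick smoke test for the class above (a minimal sketch; the 19-class
# Cityscapes setting and the 512x1024 input size are illustrative assumptions,
# not fixed by this file). With arch_type='espnet' the three deconvolutions in
# the decoder upsample the 1/8-resolution features back to the input size:
#
#   model = ESPNet(num_class=19, arch_type='espnet')
#   out = model(torch.randn(1, 3, 512, 1024))
#   assert out.shape == (1, 19, 512, 1024)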
71 | 72 | class L2Block(nn.Module): 73 | def __init__(self, in_channels, hid_channels, arch_type, alpha, use_skip, 74 | reinforce, act_type='prelu'): 75 | super().__init__() 76 | self.arch_type = arch_type 77 | self.alpha = alpha 78 | self.use_skip = use_skip 79 | self.reinforce = reinforce 80 | 81 | if reinforce: 82 | in_channels += 3 83 | 84 | self.conv1 = ESPModule(in_channels, hid_channels, stride=2, act_type=act_type) 85 | 86 | layers = [] 87 | for _ in range(alpha): 88 | layers.append(ESPModule(hid_channels, hid_channels, act_type=act_type)) 89 | self.layers = nn.Sequential(*layers) 90 | 91 | def forward(self, x, x_input=None): 92 | x = self.conv1(x) 93 | if self.use_skip: 94 | skip = x 95 | 96 | x = self.layers(x) 97 | 98 | if self.use_skip: 99 | x = torch.cat([x, skip], dim=1) 100 | 101 | if self.reinforce: 102 | size = x.size()[2:] 103 | x_quarter = F.interpolate(x_input, size, mode='bilinear') 104 | x = torch.cat([x, x_quarter], dim=1) 105 | 106 | return x 107 | 108 | 109 | class L3Block(nn.Module): 110 | def __init__(self, in_channels, out_channels, arch_type, alpha, use_skip, 111 | reinforce, use_decoder, act_type='prelu'): 112 | super().__init__() 113 | self.arch_type = arch_type 114 | self.alpha = alpha 115 | self.use_skip = use_skip 116 | 117 | if reinforce: 118 | in_channels += 3 119 | 120 | self.conv1 = ESPModule(in_channels, 128, stride=2, act_type=act_type) 121 | 122 | layers = [] 123 | for _ in range(alpha): 124 | layers.append(ESPModule(128, 128, act_type=act_type)) 125 | self.layers = nn.Sequential(*layers) 126 | 127 | if use_decoder: 128 | self.conv_last = ConvBNAct(256, out_channels, 1, act_type=act_type) 129 | elif use_skip: 130 | self.conv_last = conv1x1(256, out_channels) 131 | else: 132 | self.conv_last = conv1x1(128, out_channels) 133 | 134 | def forward(self, x): 135 | x = self.conv1(x) 136 | if self.use_skip: 137 | skip = x 138 | 139 | x = self.layers(x) 140 | 141 | if self.use_skip: 142 | x = torch.cat([x, skip], dim=1) 143 | 144 | x = self.conv_last(x) 145 | 146 | return x 147 | 148 | 149 | class Decoder(nn.Module): 150 | def __init__(self, num_class, l1_channel, l2_channel, act_type='prelu'): 151 | super().__init__() 152 | self.upconv_l3 = DeConvBNAct(num_class, num_class, act_type=act_type) 153 | self.conv_cat_l2 = ConvBNAct(l2_channel, num_class, 1) 154 | self.conv_l2 = ESPModule(2*num_class, num_class) 155 | self.upconv_l2 = DeConvBNAct(num_class, num_class, act_type=act_type) 156 | self.conv_cat_l1 = ConvBNAct(l1_channel, num_class, 1) 157 | self.conv_l1 = ESPModule(2*num_class, num_class) 158 | self.upconv_l1 = DeConvBNAct(num_class, num_class) 159 | 160 | def forward(self, x, x_l1, x_l2): 161 | x = self.upconv_l3(x) 162 | x_l2 = self.conv_cat_l2(x_l2) 163 | x = torch.cat([x, x_l2], dim=1) 164 | x = self.conv_l2(x) 165 | 166 | x = self.upconv_l2(x) 167 | x_l1 = self.conv_cat_l1(x_l1) 168 | x = torch.cat([x, x_l1], dim=1) 169 | x = self.conv_l1(x) 170 | 171 | x = self.upconv_l1(x) 172 | 173 | return x 174 | 175 | 176 | class ESPModule(nn.Module): 177 | def __init__(self, in_channels, out_channels, K=5, ks=3, stride=1, act_type='prelu',): 178 | super().__init__() 179 | self.K = K 180 | self.stride = stride 181 | self.use_skip = (in_channels == out_channels) and (stride == 1) 182 | channel_kn = out_channels // K 183 | channel_k1 = out_channels - (K -1) * channel_kn 184 | self.perfect_divisor = channel_k1 == channel_kn 185 | 186 | if self.perfect_divisor: 187 | self.conv_kn = conv1x1(in_channels, channel_kn, stride) 188 | else: 189 | self.conv_kn = 
conv1x1(in_channels, channel_kn, stride) 190 | self.conv_k1 = conv1x1(in_channels, channel_k1, stride) 191 | 192 | self.layers = nn.ModuleList() 193 | for k in range(1, K+1): 194 | dt = 2**(k-1) # dilation 195 | channel = channel_k1 if k == 1 else channel_kn 196 | self.layers.append(ConvBNAct(channel, channel, ks, 1, dt, act_type=act_type)) 197 | 198 | def forward(self, x): 199 | if self.use_skip: 200 | residual = x 201 | 202 | transform_feats = [] 203 | if self.perfect_divisor: 204 | x = self.conv_kn(x) # Reduce 205 | for i in range(self.K): 206 | transform_feats.append(self.layers[i](x)) # Split --> Transform 207 | 208 | for j in range(1, self.K): 209 | transform_feats[j] += transform_feats[j-1] # Merge: Sum 210 | else: 211 | x1 = self.conv_k1(x) # Reduce 212 | xn = self.conv_kn(x) # Reduce 213 | transform_feats.append(self.layers[0](x1)) # Split --> Transform 214 | for i in range(1, self.K): 215 | transform_feats.append(self.layers[i](xn)) # Split --> Transform 216 | 217 | for j in range(2, self.K): 218 | transform_feats[j] += transform_feats[j-1] # Merge: Sum 219 | 220 | x = torch.cat(transform_feats, dim=1) # Merge: Concat 221 | 222 | if self.use_skip: 223 | x += residual 224 | 225 | return x -------------------------------------------------------------------------------- /models/espnetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: ESPNetv2: A Light-weight, Power Efficient, and General Purpose 3 | Convolutional Neural Network 4 | Url: https://arxiv.org/abs/1811.11431 5 | Create by: zh320 6 | Date: 2023/09/03 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .modules import conv1x1, DSConvBNAct, PWConvBNAct, ConvBNAct, PyramidPoolingModule, SegHead 14 | from .model_registry import register_model 15 | 16 | 17 | @register_model() 18 | class ESPNetv2(nn.Module): 19 | def __init__(self, num_class=1, n_channel=3, K=4, alpha3=3, alpha4=7, act_type='prelu'): 20 | super().__init__() 21 | self.pool = nn.AvgPool2d(3, 2, 1) 22 | self.l1_block = ConvBNAct(n_channel, 32, 3, 2, act_type=act_type) 23 | self.l2_block = EESPModule(32, stride=2, act_type=act_type) 24 | self.l3_block1 = EESPModule(64, stride=2, act_type=act_type) 25 | self.l3_block2 = build_blocks(EESPModule, 128, alpha3, act_type=act_type) 26 | self.l4_block1 = EESPModule(128, stride=2, act_type=act_type) 27 | self.l4_block2 = build_blocks(EESPModule, 256, alpha4, act_type=act_type) 28 | 29 | self.convl4_l3 = ConvBNAct(256, 128, 1) 30 | self.ppm = PyramidPoolingModule(256, 256, act_type=act_type, bias=True) 31 | self.decoder = SegHead(256, num_class, act_type=act_type) 32 | 33 | def forward(self, x): 34 | size = x.size()[2:] 35 | x_d4 = self.pool(self.pool(x)) 36 | x_d8 = self.pool(x_d4) 37 | x_d16 = self.pool(x_d8) 38 | 39 | x = self.l1_block(x) 40 | x = self.l2_block(x, x_d4) 41 | 42 | x = self.l3_block1(x, x_d8) 43 | x3 = self.l3_block2(x) 44 | size_l3 = x3.size()[2:] 45 | 46 | x = self.l4_block1(x3, x_d16) 47 | x = self.l4_block2(x) 48 | x = F.interpolate(x, size_l3, mode='bilinear', align_corners=True) 49 | x = self.convl4_l3(x) 50 | x = torch.cat([x, x3], dim=1) 51 | 52 | x = self.ppm(x) 53 | x = self.decoder(x) 54 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 55 | 56 | return x 57 | 58 | 59 | def build_blocks(block, channels, num_block, act_type='relu'): 60 | layers = [] 61 | for _ in range(num_block): 62 | layers.append(block(channels, act_type=act_type)) 63 | return nn.Sequential(*layers) 64 
| 65 | 66 | class EESPModule(nn.Module): 67 | def __init__(self, channels, K=4, ks=3, stride=1, act_type='prelu'): 68 | super().__init__() 69 | assert channels % K == 0, 'Input channels should be integer multiples of K.\n' 70 | 71 | self.K = K 72 | channel_k = channels // K 73 | self.use_skip = stride == 1 74 | 75 | self.conv_init = nn.Conv2d(channels, channel_k, 1, groups=K, bias=False) 76 | self.layers = nn.ModuleList() 77 | for k in range(1, K+1): 78 | dt = 2**(k-1) # dilation 79 | self.layers.append(DSConvBNAct(channel_k, channel_k, ks, stride, dt, act_type=act_type)) 80 | self.conv_last = nn.Conv2d(channels, channels, 1, groups=K, bias=False) 81 | 82 | if not self.use_skip: 83 | self.pool = nn.AvgPool2d(3, 2, 1) 84 | self.conv_stride = nn.Sequential( 85 | ConvBNAct(3, 3, 3), 86 | conv1x1(3, channels*2) 87 | ) 88 | 89 | def forward(self, x, img=None): 90 | if not self.use_skip and img is None: 91 | raise ValueError('Strided EESP unit needs downsampled input image.\n') 92 | 93 | residual = x 94 | transform_feats = [] 95 | 96 | x = self.conv_init(x) # Reduce 97 | for i in range(self.K): 98 | transform_feats.append(self.layers[i](x)) # Split --> Transform 99 | 100 | for j in range(1, self.K): 101 | transform_feats[j] += transform_feats[j-1] # Merge: Sum 102 | 103 | x = torch.cat(transform_feats, dim=1) # Merge: Concat 104 | x = self.conv_last(x) 105 | 106 | if self.use_skip: 107 | x += residual 108 | else: 109 | residual = self.pool(residual) 110 | x = torch.cat([x, residual], dim=1) 111 | img = self.conv_stride(img) 112 | x += img 113 | 114 | return x -------------------------------------------------------------------------------- /models/fanet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: Real-time Semantic Segmentation with Fast Attention 3 | Url: https://arxiv.org/abs/2007.03815 4 | Create by: zh320 5 | Date: 2024/04/06 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torchvision.models.resnet import BasicBlock 12 | 13 | from .modules import ConvBNAct, DeConvBNAct, SegHead, Activation 14 | from .backbone import ResNet 15 | from .model_registry import register_model 16 | 17 | 18 | @register_model() 19 | class FANet(nn.Module): 20 | def __init__(self, num_class=1, n_channel=3, att_channel=32, backbone_type='resnet18', cat_feat=True, 21 | act_type='relu'): 22 | super().__init__() 23 | if backbone_type in ['resnet18', 'resnet34']: 24 | self.backbone = ResNet(backbone_type) 25 | channels = [64, 128, 256, 512] 26 | self.num_stage = len(channels) 27 | 28 | # Reduce spatial dimension for Res-1 29 | downsample = ConvBNAct(channels[0], channels[0], 1, 2, act_type='none') 30 | self.backbone.layer1[0] = BasicBlock(channels[0], channels[0], 2, downsample) 31 | else: 32 | raise NotImplementedError() 33 | self.cat_feat = cat_feat 34 | 35 | self.fast_attention = nn.ModuleList([FastAttention(channels[i], att_channel, act_type) for i in range(self.num_stage)]) 36 | 37 | layers = [FuseUp(att_channel, att_channel, act_type=act_type) for _ in range(self.num_stage-1)] 38 | layers.append(FuseUp(att_channel, att_channel, has_up=False, act_type=act_type)) 39 | self.fuse_up = nn.ModuleList(layers) 40 | 41 | last_channel = 4*att_channel if cat_feat else att_channel 42 | self.seg_head = SegHead(last_channel, num_class, act_type) 43 | 44 | def forward(self, x): 45 | size = x.size()[2:] 46 | x1, x2, x3, x4 = self.backbone(x) 47 | 48 | x4 = self.fast_attention[3](x4) 49 | x4 = self.fuse_up[3](x4) 50 | 51 | 
x3 = self.fast_attention[2](x3) 52 | x3 = self.fuse_up[2](x3, x4) 53 | 54 | x2 = self.fast_attention[1](x2) 55 | x2 = self.fuse_up[1](x2, x3) 56 | 57 | x1 = self.fast_attention[0](x1) 58 | x1 = self.fuse_up[0](x1, x2) 59 | 60 | if self.cat_feat: 61 | size1 = x1.size()[2:] 62 | x4 = F.interpolate(x4, size1, mode='bilinear', align_corners=True) 63 | x3 = F.interpolate(x3, size1, mode='bilinear', align_corners=True) 64 | x2 = F.interpolate(x2, size1, mode='bilinear', align_corners=True) 65 | 66 | x = torch.cat([x1, x2, x3, x4], dim=1) 67 | x = self.seg_head(x) 68 | else: 69 | x = self.seg_head(x1) 70 | 71 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 72 | 73 | return x 74 | 75 | 76 | class FastAttention(nn.Module): 77 | def __init__(self, in_channels, out_channels, act_type): 78 | super().__init__() 79 | self.conv_q = ConvBNAct(in_channels, out_channels, 3, act_type='none') 80 | self.conv_k = ConvBNAct(in_channels, out_channels, 3, act_type='none') 81 | self.conv_v = ConvBNAct(in_channels, out_channels, 3, act_type='none') 82 | self.conv_fuse = ConvBNAct(out_channels, out_channels, 3, act_type=act_type) 83 | 84 | def forward(self, x): 85 | x_q = self.conv_q(x) 86 | x_k = self.conv_k(x) 87 | x_v = self.conv_v(x) 88 | residual = x_v 89 | 90 | B, C, H, W = x_q.size() 91 | n = H * W 92 | 93 | x_q = x_q.view(B, C, n) 94 | x_k = x_k.view(B, C, n) 95 | x_v = x_v.view(B, C, n) 96 | 97 | x_q = F.normalize(x_q, p=2, dim=1) 98 | x_k = F.normalize(x_k, p=2, dim=1).permute(0,2,1) 99 | 100 | y = (x_q @ (x_k @ x_v)) / n 101 | y = y.view(B, C, H, W) 102 | y = self.conv_fuse(y) 103 | y += residual 104 | 105 | return y 106 | 107 | 108 | class FuseUp(nn.Module): 109 | def __init__(self, in_channels, out_channels, has_up=True, act_type='relu'): 110 | super().__init__() 111 | self.has_up = has_up 112 | if has_up: 113 | self.up = DeConvBNAct(in_channels, in_channels, act_type=act_type, inplace=True) 114 | 115 | self.conv = ConvBNAct(in_channels, out_channels, 3, act_type=act_type, inplace=True) 116 | 117 | def forward(self, x_fa, x_up=None): 118 | if self.has_up: 119 | if x_up is None: 120 | raise RuntimeError('Missing input from Up layer.\n') 121 | else: 122 | x_up = self.up(x_up) 123 | x_fa += x_up 124 | 125 | x_fa = self.conv(x_fa) 126 | 127 | return x_fa -------------------------------------------------------------------------------- /models/farseenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: FarSee-Net: Real-Time Semantic Segmentation by Efficient Multi-scale 3 | Context Aggregation and Feature Space Super-resolution 4 | Url: https://arxiv.org/abs/2003.03913 5 | Create by: zh320 6 | Date: 2023/10/08 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .modules import conv1x1, DWConvBNAct, ConvBNAct 14 | from .backbone import ResNet 15 | from .model_registry import register_model 16 | 17 | 18 | @register_model() 19 | class FarSeeNet(nn.Module): 20 | def __init__(self, num_class=1, n_channel=3, backbone_type='resnet18', act_type='relu'): 21 | super().__init__() 22 | if 'resnet' in backbone_type: 23 | self.frontend_network = ResNet(backbone_type) 24 | high_channels = 512 if backbone_type in ['resnet18', 'resnet34'] else 2048 25 | low_channels = 256 if backbone_type in ['resnet18', 'resnet34'] else 1024 26 | else: 27 | raise NotImplementedError() 28 | 29 | self.backend_network = FASPP(high_channels, low_channels, num_class, act_type) 30 | 31 | def forward(self, x): 32 | size = 
x.size()[2:] 33 | 34 | _, _, x_low, x_high = self.frontend_network(x) 35 | 36 | x = self.backend_network(x_high, x_low) 37 | 38 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 39 | 40 | return x 41 | 42 | 43 | class FASPP(nn.Module): 44 | def __init__(self, high_channels, low_channels, num_class, act_type, 45 | dilations=[6,12,18], hid_channels=256): 46 | super().__init__() 47 | # High level convolutions 48 | self.conv_high = nn.ModuleList([ 49 | ConvBNAct(high_channels, hid_channels, 1, act_type=act_type) 50 | ]) 51 | for dt in dilations: 52 | self.conv_high.append( 53 | nn.Sequential( 54 | ConvBNAct(high_channels, hid_channels, 1, act_type=act_type), 55 | DWConvBNAct(hid_channels, hid_channels, 3, dilation=dt, act_type=act_type) 56 | ) 57 | ) 58 | 59 | self.sub_pixel_high = nn.Sequential( 60 | conv1x1(hid_channels*4, hid_channels*2*(2**2)), 61 | nn.PixelShuffle(2) 62 | ) 63 | 64 | # Low level convolutions 65 | self.conv_low_init = ConvBNAct(low_channels, 48, 1, act_type=act_type) 66 | self.conv_low = nn.ModuleList([ 67 | ConvBNAct(hid_channels*2+48, hid_channels//2, 1, act_type=act_type) 68 | ]) 69 | for dt in dilations[:-1]: 70 | self.conv_low.append( 71 | nn.Sequential( 72 | ConvBNAct(hid_channels*2+48, hid_channels//2, 1, act_type=act_type), 73 | DWConvBNAct(hid_channels//2, hid_channels//2, 3, dilation=dt, act_type=act_type) 74 | ) 75 | ) 76 | 77 | self.conv_low_last = nn.Sequential( 78 | ConvBNAct(hid_channels//2*3, hid_channels*2, 1, act_type=act_type), 79 | ConvBNAct(hid_channels*2, hid_channels*2, act_type=act_type) 80 | ) 81 | 82 | self.sub_pixel_low = nn.Sequential( 83 | conv1x1(hid_channels*2, num_class*(4**2)), 84 | nn.PixelShuffle(4) 85 | ) 86 | 87 | def forward(self, x_high, x_low): 88 | # High level features 89 | high_feats = [] 90 | for conv_high in self.conv_high: 91 | high_feats.append(conv_high(x_high)) 92 | 93 | x = torch.cat(high_feats, dim=1) 94 | x = self.sub_pixel_high(x) 95 | 96 | # Low level features 97 | x_low = self.conv_low_init(x_low) 98 | x = torch.cat([x, x_low], dim=1) 99 | 100 | low_feats = [] 101 | for conv_low in self.conv_low: 102 | low_feats.append(conv_low(x)) 103 | 104 | x = torch.cat(low_feats, dim=1) 105 | x = self.conv_low_last(x) 106 | x = self.sub_pixel_low(x) 107 | 108 | return x -------------------------------------------------------------------------------- /models/fastscnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: Fast-SCNN: Fast Semantic Segmentation Network 3 | Url: https://arxiv.org/abs/1902.04502 4 | Create by: zh320 5 | Date: 2023/04/16 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import conv1x1, DSConvBNAct, DWConvBNAct, PWConvBNAct, ConvBNAct, Activation, PyramidPoolingModule 13 | from .model_registry import register_model 14 | 15 | 16 | @register_model() 17 | class FastSCNN(nn.Module): 18 | def __init__(self, num_class=1, n_channel=3, act_type='relu'): 19 | super().__init__() 20 | self.learning_to_downsample = LearningToDownsample(n_channel, 64, act_type=act_type) 21 | self.global_feature_extractor = GlobalFeatureExtractor(64, 128, act_type=act_type) 22 | self.feature_fusion = FeatureFusionModule(64, 128, 128, act_type=act_type) 23 | self.classifier = Classifier(128, num_class, act_type=act_type) 24 | 25 | def forward(self, x): 26 | size = x.size()[2:] 27 | higher_res_feat = self.learning_to_downsample(x) 28 | lower_res_feat = self.global_feature_extractor(higher_res_feat) 29 | x = 
self.feature_fusion(higher_res_feat, lower_res_feat) 30 | x = self.classifier(x) 31 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 32 | 33 | return x 34 | 35 | 36 | class LearningToDownsample(nn.Sequential): 37 | def __init__(self, in_channels, out_channels, hid_channels=[32, 48], act_type='relu'): 38 | super().__init__( 39 | ConvBNAct(in_channels, hid_channels[0], 3, 2, act_type=act_type), 40 | DSConvBNAct(hid_channels[0], hid_channels[1], 3, 2, act_type=act_type), 41 | DSConvBNAct(hid_channels[1], out_channels, 3, 2, act_type=act_type), 42 | ) 43 | 44 | 45 | class GlobalFeatureExtractor(nn.Module): 46 | def __init__(self, in_channels, out_channels, act_type='relu'): 47 | super().__init__() 48 | inverted_residual_setting = [ 49 | # t, c, n, s 50 | [6, 64, 3, 2], 51 | [6, 96, 2, 2], 52 | [6, 128, 3, 1], 53 | ] 54 | 55 | # Building inverted residual blocks, codes borrowed from 56 | # https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py 57 | features = [] 58 | for t, c, n, s in inverted_residual_setting: 59 | for i in range(n): 60 | stride = s if i == 0 else 1 61 | features.append(InvertedResidual(in_channels, c, stride, t, act_type=act_type)) 62 | in_channels = c 63 | self.bottlenecks = nn.Sequential(*features) 64 | 65 | self.ppm = PyramidPoolingModule(in_channels, out_channels, act_type=act_type, bias=True) 66 | 67 | def forward(self, x): 68 | x = self.bottlenecks(x) 69 | x = self.ppm(x) 70 | 71 | return x 72 | 73 | 74 | class FeatureFusionModule(nn.Module): 75 | def __init__(self, higher_channels, lower_channels, out_channels, act_type='relu'): 76 | super().__init__() 77 | self.higher_res_conv = conv1x1(higher_channels, out_channels) 78 | self.lower_res_conv = nn.Sequential( 79 | DWConvBNAct(lower_channels, lower_channels, 3, 1, act_type=act_type), 80 | conv1x1(lower_channels, out_channels) 81 | ) 82 | self.non_linear = nn.Sequential( 83 | nn.BatchNorm2d(out_channels), 84 | Activation(act_type) 85 | ) 86 | 87 | def forward(self, higher_res_feat, lower_res_feat): 88 | size = higher_res_feat.size()[2:] 89 | higher_res_feat = self.higher_res_conv(higher_res_feat) 90 | lower_res_feat = F.interpolate(lower_res_feat, size, mode='bilinear', align_corners=True) 91 | lower_res_feat = self.lower_res_conv(lower_res_feat) 92 | x = self.non_linear(higher_res_feat + lower_res_feat) 93 | 94 | return x 95 | 96 | 97 | class Classifier(nn.Sequential): 98 | def __init__(self, in_channels, num_class, act_type='relu'): 99 | super().__init__( 100 | DSConvBNAct(in_channels, in_channels, 3, 1, act_type=act_type), 101 | DSConvBNAct(in_channels, in_channels, 3, 1, act_type=act_type), 102 | PWConvBNAct(in_channels, num_class, act_type=act_type), 103 | ) 104 | 105 | 106 | class InvertedResidual(nn.Module): 107 | def __init__(self, in_channels, out_channels, stride, expand_ratio=6, act_type='relu'): 108 | super().__init__() 109 | hid_channels = int(round(in_channels * expand_ratio)) 110 | self.use_res_connect = stride == 1 and in_channels == out_channels 111 | 112 | self.conv = nn.Sequential( 113 | PWConvBNAct(in_channels, hid_channels, act_type=act_type), 114 | DWConvBNAct(hid_channels, hid_channels, 3, stride, act_type=act_type), 115 | ConvBNAct(hid_channels, out_channels, 1, act_type='none') 116 | ) 117 | 118 | def forward(self, x): 119 | if self.use_res_connect: 120 | return x + self.conv(x) 121 | else: 122 | return self.conv(x) -------------------------------------------------------------------------------- /models/fddwnet.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Paper: FDDWNet: A Lightweight Convolutional Neural Network for Real-time 3 | Semantic Segmentation 4 | Url: https://arxiv.org/abs/1911.00632 5 | Create by: zh320 6 | Date: 2023/10/08 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from .modules import DWConvBNAct, ConvBNAct, DeConvBNAct, Activation 13 | from .enet import InitialBlock as DownsamplingUnit 14 | from .model_registry import register_model 15 | 16 | 17 | @register_model() 18 | class FDDWNet(nn.Module): 19 | def __init__(self, num_class=1, n_channel=3, ks=3, act_type='relu'): 20 | super().__init__() 21 | self.layer1 = DownsamplingUnit(n_channel, 16, act_type) 22 | self.layer2 = DownsamplingUnit(16, 64, act_type) 23 | self.layer3_7 = build_blocks(EERMUnit, 64, 5, ks, [1,1,1,1,1], act_type) 24 | self.layer8 = DownsamplingUnit(64, 128, act_type) 25 | self.layer9_16 = build_blocks(EERMUnit, 128, 8, ks, [1,2,5,9,1,2,5,9], act_type) 26 | self.layer17_24 = build_blocks(EERMUnit, 128, 8, ks, [2,5,9,17,2,5,9,17], act_type) 27 | self.layer25 = DeConvBNAct(128, 64, act_type=act_type) 28 | self.layer26_27 = build_blocks(EERMUnit, 64, 2, ks, [1,1], act_type) 29 | self.layer28 = DeConvBNAct(64, 16, act_type=act_type) 30 | self.layer29_30 = build_blocks(EERMUnit, 16, 2, ks, [1,1], act_type) 31 | self.layer31 = DeConvBNAct(16, num_class, act_type=act_type) 32 | 33 | def forward(self, x): 34 | x = self.layer1(x) 35 | x = self.layer2(x) 36 | residual = self.layer3_7(x) 37 | x = self.layer8(residual) 38 | x = self.layer9_16(x) 39 | x = self.layer17_24(x) 40 | x = self.layer25(x) 41 | x = self.layer26_27(x) 42 | x += residual 43 | x = self.layer28(x) 44 | x = self.layer29_30(x) 45 | x = self.layer31(x) 46 | 47 | return x 48 | 49 | 50 | def build_blocks(block, channels, num_block, kernel_size, dilations=[], act_type='relu'): 51 | if len(dilations) == 0: 52 | dilations = [1 for _ in range(num_block)] 53 | else: 54 | if len(dilations) != num_block: 55 | raise ValueError('Number of dilations should equal the number of blocks') 56 | 57 | layers = [] 58 | for i in range(num_block): 59 | layers.append(block(channels, kernel_size, dilations[i], act_type)) 60 | return nn.Sequential(*layers) 61 | 62 | 63 | class EERMUnit(nn.Module): 64 | def __init__(self, channels, ks, dt, act_type): 65 | super().__init__() 66 | self.conv = nn.Sequential( 67 | DWConvBNAct(channels, channels, (ks, 1), act_type='none'), 68 | DWConvBNAct(channels, channels, (1, ks), act_type='none'), 69 | ConvBNAct(channels, channels, 1, act_type=act_type, inplace=True), 70 | DWConvBNAct(channels, channels, (ks, 1), dilation=dt, act_type='none'), 71 | DWConvBNAct(channels, channels, (1, ks), dilation=dt, act_type='none'), 72 | ConvBNAct(channels, channels, 1, act_type='none') 73 | ) 74 | self.act = Activation(act_type) 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | x = self.conv(x) 80 | x += residual 81 | 82 | return self.act(x) -------------------------------------------------------------------------------- /models/fpenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: Feature Pyramid Encoding Network for Real-time Semantic Segmentation 3 | Url: https://arxiv.org/abs/1909.08599 4 | Create by: zh320 5 | Date: 2023/10/08 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import DWConvBNAct, ConvBNAct 13 | from .model_registry import register_model 14 |
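# Shape sketch for the model below: p and q set the number of FPEBlocks in
# stages 2 and 3, and k is the FPEBlock expansion ratio. stage1 runs at 1/2
# resolution and each strided FPEBlock halves it again, so the deepest features
# sit at 1/8 resolution before the two MEU decoders and the final interpolation
# restore the input size. A minimal smoke test (the class count and input size
# are illustrative assumptions):
#
#   model = FPENet(num_class=19)
#   out = model(torch.randn(1, 3, 256, 256))
#   assert out.shape == (1, 19, 256, 256)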
15 | 16 | @register_model() 17 | class FPENet(nn.Module): 18 | def __init__(self, num_class=1, n_channel=3, p=3, q=9, k=4, act_type='relu'): 19 | super().__init__() 20 | self.stage1 = nn.Sequential( 21 | ConvBNAct(n_channel, 16, 3, 2, act_type=act_type, inplace=True), 22 | FPEBlock(16, 16, 1, 1, act_type=act_type) 23 | ) 24 | self.stage2_0 = FPEBlock(16, 32, k, 2, act_type=act_type) 25 | self.stage2 = build_blocks(FPEBlock, 32, p-1, k, act_type) 26 | self.stage3_0 = FPEBlock(32, 64, k, 2, act_type=act_type) 27 | self.stage3 = build_blocks(FPEBlock, 64, q-1, k, act_type) 28 | self.decoder2 = MEUModule(32, 64, 64, act_type) 29 | self.decoder1 = MEUModule(16, 64, 32, act_type) 30 | self.final = ConvBNAct(32, num_class, 1, act_type=act_type, inplace=True) 31 | 32 | def forward(self, x): 33 | size = x.size()[2:] 34 | x1 = self.stage1(x) 35 | x = self.stage2_0(x1) 36 | x2 = self.stage2(x) 37 | x = self.stage3_0(x2) 38 | x = self.stage3(x) 39 | x = self.decoder2(x2, x) 40 | x = self.decoder1(x1, x) 41 | x = self.final(x) 42 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 43 | 44 | return x 45 | 46 | 47 | def build_blocks(block, channels, num_block, expansion, act_type): 48 | layers = [] 49 | for i in range(num_block): 50 | layers.append(block(channels, channels, expansion, 1, act_type=act_type)) 51 | return nn.Sequential(*layers) 52 | 53 | 54 | class FPEBlock(nn.Module): 55 | def __init__(self, in_channels, out_channels, expansion, stride, dilations=[1,2,4,8], 56 | act_type='relu'): 57 | super().__init__() 58 | assert len(dilations) > 0, 'Length of dilations should be larger than 0.\n' 59 | self.K = len(dilations) 60 | self.use_skip = (in_channels == out_channels) and (stride == 1) 61 | expand_channels = out_channels * expansion 62 | self.ch = expand_channels // self.K 63 | 64 | self.conv_init = ConvBNAct(in_channels, expand_channels, 1, act_type=act_type, inplace=True) 65 | 66 | self.layers = nn.ModuleList() 67 | for i in range(self.K): 68 | self.layers.append(DWConvBNAct(self.ch, self.ch, 3, stride, dilations[i], act_type=act_type)) 69 | 70 | self.conv_last = ConvBNAct(expand_channels, out_channels, 1, act_type=act_type) 71 | 72 | def forward(self, x): 73 | if self.use_skip: 74 | residual = x 75 | 76 | x = self.conv_init(x) 77 | 78 | transform_feats = [] 79 | for i in range(self.K): 80 | transform_feats.append(self.layers[i](x[:, i*self.ch:(i+1)*self.ch])) 81 | 82 | for j in range(1, self.K): 83 | transform_feats[j] += transform_feats[j-1] 84 | 85 | x = torch.cat(transform_feats, dim=1) 86 | 87 | x = self.conv_last(x) 88 | 89 | if self.use_skip: 90 | x += residual 91 | 92 | return x 93 | 94 | 95 | class MEUModule(nn.Module): 96 | def __init__(self, low_channels, high_channels, out_channels, act_type): 97 | super().__init__() 98 | self.conv_low = ConvBNAct(low_channels, out_channels, 1, act_type=act_type, inplace=True) 99 | self.conv_high = ConvBNAct(high_channels, out_channels, 1, act_type=act_type, inplace=True) 100 | self.sa = SpatialAttentionBlock(act_type) 101 | self.ca = ChannelAttentionBlock(out_channels, act_type) 102 | 103 | def forward(self, x_low, x_high): 104 | x_low = self.conv_low(x_low) 105 | x_high = self.conv_high(x_high) 106 | 107 | x_sa = self.sa(x_low) 108 | x_ca = self.ca(x_high) 109 | 110 | x_low = x_low * x_ca 111 | x_high = F.interpolate(x_high, scale_factor=2, mode='bilinear', align_corners=True) 112 | x_high = x_high * x_sa 113 | 114 | return x_low + x_high 115 | 116 | 117 | class SpatialAttentionBlock(nn.Module): 118 | def __init__(self, act_type): 
119 | super().__init__() 120 | self.conv = ConvBNAct(1, 1, 1, act_type=act_type, inplace=True) 121 | 122 | def forward(self, x): 123 | x = self.conv(torch.mean(x, dim=1, keepdim=True)) 124 | 125 | return x 126 | 127 | 128 | class ChannelAttentionBlock(nn.Sequential): 129 | def __init__(self, channels, act_type): 130 | super().__init__( 131 | nn.AdaptiveAvgPool2d(1), 132 | ConvBNAct(channels, channels, 1, act_type=act_type, inplace=True) 133 | ) -------------------------------------------------------------------------------- /models/fssnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: Fast Semantic Segmentation for Scene Perception 3 | Url: https://ieeexplore.ieee.org/document/8392426 4 | Create by: zh320 5 | Date: 2023/10/22 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import ConvBNAct, DeConvBNAct, Activation 13 | from .enet import InitialBlock as InitBlock 14 | from .model_registry import register_model 15 | 16 | 17 | @register_model() 18 | class FSSNet(nn.Module): 19 | def __init__(self, num_class=1, n_channel=3, act_type='prelu'): 20 | super().__init__() 21 | # Encoder 22 | self.init_block = InitBlock(n_channel, 16, act_type) 23 | self.down1 = DownsamplingBlock(16, 64, act_type) 24 | self.factorized = build_blocks(FactorizedBlock, 64, 4, act_type=act_type) 25 | self.down2 = DownsamplingBlock(64, 128, act_type) 26 | self.dilated = build_blocks(DilatedBlock, 128, 6, [2,5,9,2,5,9], act_type) 27 | # Decoder 28 | self.up2 = UpsamplingBlock(128, 64, act_type) 29 | self.bottleneck2 = build_blocks(DilatedBlock, 64, 2, act_type=act_type) 30 | self.up1 = UpsamplingBlock(64, 16, act_type) 31 | self.bottleneck1 = build_blocks(DilatedBlock, 16, 2, act_type=act_type) 32 | self.full_conv = DeConvBNAct(16, num_class, act_type=act_type) 33 | 34 | def forward(self, x): 35 | x = self.init_block(x) # 2x down 36 | x_d1 = self.down1(x) # 4x down 37 | x = self.factorized(x_d1) 38 | x_d2 = self.down2(x) # 8x down 39 | x = self.dilated(x_d2) 40 | 41 | x = self.up2(x, x_d2) # 8x up 42 | x = self.bottleneck2(x) 43 | x = self.up1(x, x_d1) # 4x up 44 | x = self.bottleneck1(x) 45 | x = self.full_conv(x) # 2x up 46 | 47 | return x 48 | 49 | 50 | def build_blocks(block, channels, num_block, dilations=[], act_type='relu'): 51 | if len(dilations) == 0: 52 | dilations = [1 for _ in range(num_block)] 53 | else: 54 | if len(dilations) != num_block: 55 | raise ValueError('Number of dilations should equal the number of blocks') 56 | 57 | layers = [] 58 | for i in range(num_block): 59 | layers.append(block(channels, dilations[i], act_type)) 60 | return nn.Sequential(*layers) 61 | 62 | 63 | class FactorizedBlock(nn.Module): 64 | def __init__(self, channels, dilation=1, act_type='relu'): 65 | super().__init__() 66 | hid_channels = channels // 4 67 | self.conv = nn.Sequential( 68 | ConvBNAct(channels, hid_channels, 1, act_type=act_type), 69 | ConvBNAct(hid_channels, hid_channels, (1,3), act_type='none'), 70 | ConvBNAct(hid_channels, hid_channels, (3,1), act_type=act_type), 71 | ConvBNAct(hid_channels, channels, 1, act_type='none') 72 | ) 73 | self.act = Activation(act_type) 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | x = self.conv(x) 79 | x += residual 80 | 81 | return self.act(x) 82 | 83 | 84 | class DilatedBlock(nn.Module): 85 | def __init__(self, channels, dilation, act_type): 86 | super().__init__() 87 | hid_channels = channels // 4 88 | self.conv = nn.Sequential( 89 |
ConvBNAct(channels, hid_channels, 1, act_type=act_type), 90 | ConvBNAct(hid_channels, hid_channels, 3, dilation=dilation, act_type=act_type), 91 | ConvBNAct(hid_channels, channels, 1, act_type='none') 92 | ) 93 | self.act = Activation(act_type) 94 | 95 | def forward(self, x): 96 | residual = x 97 | 98 | x = self.conv(x) 99 | x += residual 100 | 101 | return self.act(x) 102 | 103 | 104 | class DownsamplingBlock(nn.Module): 105 | def __init__(self, in_channels, out_channels, act_type): 106 | super().__init__() 107 | hid_channels = out_channels // 4 108 | self.conv = nn.Sequential( 109 | ConvBNAct(in_channels, hid_channels, 2, 2, act_type=act_type), 110 | ConvBNAct(hid_channels, hid_channels, 3, act_type=act_type), 111 | ConvBNAct(hid_channels, out_channels, 1, act_type='none') 112 | ) 113 | self.pool = nn.Sequential( 114 | nn.MaxPool2d(3, 2, 1), 115 | ConvBNAct(in_channels, out_channels, 1, act_type='none') 116 | ) 117 | self.act = Activation(act_type) 118 | 119 | def forward(self, x): 120 | x_pool = self.pool(x) 121 | x = self.conv(x) 122 | x += x_pool 123 | 124 | return self.act(x) 125 | 126 | 127 | class UpsamplingBlock(nn.Module): 128 | def __init__(self, in_channels, out_channels, act_type): 129 | super().__init__() 130 | hid_channels = in_channels // 4 131 | self.deconv = nn.Sequential( 132 | ConvBNAct(in_channels, hid_channels, 1, act_type=act_type), 133 | DeConvBNAct(hid_channels, hid_channels, act_type=act_type), 134 | ConvBNAct(hid_channels, out_channels, 1, act_type='none') 135 | ) 136 | self.conv = ConvBNAct(in_channels, out_channels, 1, act_type='none') 137 | self.act = Activation(act_type) 138 | 139 | def forward(self, x, pool_feat): 140 | x_deconv = self.deconv(x) 141 | 142 | x = x + pool_feat 143 | x = self.conv(x) 144 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) 145 | 146 | x += x_deconv 147 | 148 | return self.act(x) -------------------------------------------------------------------------------- /models/icnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: ICNet for Real-Time Semantic Segmentation on High-Resolution Images 3 | Url: https://arxiv.org/abs/1704.08545 4 | Create by: zh320 5 | Date: 2023/10/15 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import conv1x1, ConvBNAct, Activation, PyramidPoolingModule, SegHead 13 | from .model_registry import register_model, aux_models 14 | 15 | 16 | @register_model(aux_models) 17 | class ICNet(nn.Module): 18 | def __init__(self, num_class=1, n_channel=3, backbone_type='resnet18', act_type='relu', 19 | use_aux=True): 20 | super().__init__() 21 | if 'resnet' in backbone_type: 22 | self.backbone = ResNet(backbone_type) 23 | ch1 = 512 if backbone_type in ['resnet18', 'resnet34'] else 2048 24 | ch2 = 128 if backbone_type in ['resnet18', 'resnet34'] else 512 25 | else: 26 | raise NotImplementedError() 27 | 28 | self.use_aux = use_aux 29 | self.bottom_branch = HighResolutionBranch(n_channel, 128, act_type=act_type) 30 | self.ppm = PyramidPoolingModule(ch1, 256, act_type=act_type) 31 | self.cff42 = CascadeFeatureFusionUnit(256, ch2, 128, num_class, act_type, use_aux) 32 | self.cff21 = CascadeFeatureFusionUnit(128, 128, 128, num_class, act_type, use_aux) 33 | self.seg_head = SegHead(128, num_class, act_type) 34 | 35 | def forward(self, x, is_training=False): 36 | size = x.size()[2:] 37 | x_d2 = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True) 38 | x_d4 = 
F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=True) 39 | 40 | # Lowest resolution branch 41 | x_d4, _ = self.backbone(x_d4) # 32x down 42 | x_d4 = self.ppm(x_d4) 43 | 44 | # Medium resolution branch 45 | _, x_d2 = self.backbone(x_d2) # 16x down 46 | 47 | # High resolution branch 48 | x = self.bottom_branch(x) # 8x down 49 | 50 | # Cascade feature fusion 51 | if self.use_aux: 52 | x_d2, aux2 = self.cff42(x_d4, x_d2) # 16x down 53 | x, aux3 = self.cff21(x_d2, x) # 8x down 54 | else: 55 | x_d2 = self.cff42(x_d4, x_d2) 56 | x = self.cff21(x_d2, x) 57 | 58 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) # 4x down 59 | x = self.seg_head(x) # 4x down 60 | 61 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 62 | 63 | if self.use_aux and is_training: 64 | return x, (aux2, aux3) 65 | else: 66 | return x 67 | 68 | 69 | class CascadeFeatureFusionUnit(nn.Module): 70 | def __init__(self, channel1, channel2, out_channels, num_class, act_type, use_aux): 71 | super().__init__() 72 | self.use_aux = use_aux 73 | self.conv1 = ConvBNAct(channel1, out_channels, 3, 1, 2, act_type='none') 74 | self.conv2 = ConvBNAct(channel2, out_channels, 1, act_type='none') 75 | self.act = Activation(act_type) 76 | if use_aux: 77 | self.classifier = SegHead(channel1, num_class, act_type) 78 | 79 | def forward(self, x1, x2): 80 | x1 = F.interpolate(x1, scale_factor=2, mode='bilinear', align_corners=True) 81 | if self.use_aux: 82 | x_aux = self.classifier(x1) 83 | 84 | x1 = self.conv1(x1) 85 | x2 = self.conv2(x2) 86 | 87 | x = self.act(x1 + x2) 88 | 89 | if self.use_aux: 90 | return x, x_aux 91 | else: 92 | return x 93 | 94 | 95 | class HighResolutionBranch(nn.Sequential): 96 | def __init__(self, in_channels, out_channels, hid_channels=32, act_type='relu'): 97 | super().__init__( 98 | ConvBNAct(in_channels, hid_channels, 3, 2, act_type=act_type), 99 | ConvBNAct(hid_channels, hid_channels*2, 3, 2, act_type=act_type), 100 | ConvBNAct(hid_channels*2, out_channels, 3, 2, act_type=act_type) 101 | ) 102 | 103 | 104 | class ResNet(nn.Module): 105 | def __init__(self, resnet_type, pretrained=True): 106 | super().__init__() 107 | from torchvision.models import resnet18, resnet34, resnet50, resnet101 108 | 109 | resnet_hub = {'resnet18':resnet18, 'resnet34':resnet34, 'resnet50':resnet50, 110 | 'resnet101':resnet101,} 111 | if resnet_type not in resnet_hub.keys(): 112 | raise ValueError(f'Unsupported ResNet type: {resnet_type}.\n') 113 | 114 | use_basicblock = resnet_type in ['resnet18', 'resnet34'] 115 | 116 | resnet = resnet_hub[resnet_type](pretrained=pretrained) 117 | self.conv1 = resnet.conv1 118 | self.bn1 = resnet.bn1 119 | self.relu = resnet.relu 120 | self.maxpool = resnet.maxpool 121 | self.layer1 = resnet.layer1 122 | self.layer2 = resnet.layer2 123 | self.layer3 = resnet.layer3 124 | self.layer4 = resnet.layer4 125 | 126 | # Change stride-2 conv to dilated conv 127 | layers = [[self.layer3[0], resnet.layer3[0]], [self.layer4[0], resnet.layer4[0]]] 128 | for i in range(1,3): 129 | ch = 128 if use_basicblock else 512 130 | resnet_downsample = layers[i-1][1].downsample[0] 131 | resnet_conv = layers[i-1][1].conv1 if use_basicblock else layers[i-1][1].conv2 132 | 133 | layers[i-1][0].downsample[0] = nn.Conv2d(ch*i, ch*i*2, 1, 1, bias=False) 134 | if use_basicblock: 135 | layers[i-1][0].conv1 = nn.Conv2d(ch*i, ch*i*2, 3, 1, 2*i, 2*i, bias=False) 136 | else: 137 | layers[i-1][0].conv2 = nn.Conv2d(ch//2*i, ch//2*i, 3, 1, 2*i, 2*i, bias=False) 138 | 139 | with 
torch.no_grad(): 140 | layers[i-1][1].downsample[0].weight.copy_(resnet_downsample.weight) 141 | if use_basicblock: 142 | layers[i-1][1].conv1.weight.copy_(resnet_conv.weight) 143 | else: 144 | layers[i-1][1].conv2.weight.copy_(resnet_conv.weight) 145 | 146 | def forward(self, x): 147 | x = self.conv1(x) # 2x down 148 | x = self.bn1(x) 149 | x = self.relu(x) 150 | x = self.maxpool(x) # 4x down 151 | x = self.layer1(x) 152 | x2 = self.layer2(x) # 8x down 153 | x = self.layer3(x2) # 8x down with dilation 2 154 | x = self.layer4(x) # 8x down with dilation 4 155 | 156 | return x, x2 -------------------------------------------------------------------------------- /models/lednet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: LEDNet: A Lightweight Encoder-Decoder Network for Real-Time Semantic Segmentation 3 | Url: https://arxiv.org/abs/1905.02423 4 | Create by: zh320 5 | Date: 2023/04/23 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import conv1x1, ConvBNAct, Activation, channel_shuffle 13 | from .enet import InitialBlock as DownsampleUint 14 | from .model_registry import register_model 15 | 16 | 17 | @register_model() 18 | class LEDNet(nn.Module): 19 | def __init__(self, num_class=1, n_channel=3, act_type='relu'): 20 | super().__init__() 21 | self.encoder = Encoder(n_channel, 128, act_type) 22 | self.apn = AttentionPyramidNetwork(128, num_class, act_type) 23 | 24 | def forward(self, x): 25 | size = x.size()[2:] 26 | x = self.encoder(x) 27 | x = self.apn(x) 28 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 29 | return x 30 | 31 | 32 | class Encoder(nn.Sequential): 33 | def __init__(self, in_channels, out_channels, act_type): 34 | super().__init__( 35 | DownsampleUint(in_channels, 32, act_type), 36 | SSnbtUnit(32, 1, act_type=act_type), 37 | SSnbtUnit(32, 1, act_type=act_type), 38 | SSnbtUnit(32, 1, act_type=act_type), 39 | DownsampleUint(32, 64, act_type), 40 | SSnbtUnit(64, 1, act_type=act_type), 41 | SSnbtUnit(64, 1, act_type=act_type), 42 | DownsampleUint(64, out_channels, act_type), 43 | SSnbtUnit(out_channels, 1, act_type=act_type), 44 | SSnbtUnit(out_channels, 2, act_type=act_type), 45 | SSnbtUnit(out_channels, 5, act_type=act_type), 46 | SSnbtUnit(out_channels, 9, act_type=act_type), 47 | SSnbtUnit(out_channels, 2, act_type=act_type), 48 | SSnbtUnit(out_channels, 5, act_type=act_type), 49 | SSnbtUnit(out_channels, 9, act_type=act_type), 50 | SSnbtUnit(out_channels, 17, act_type=act_type), 51 | ) 52 | 53 | 54 | class SSnbtUnit(nn.Module): 55 | def __init__(self, channels, dilation, act_type): 56 | super().__init__() 57 | assert channels % 2 == 0, 'Input channel should be multiple of 2.\n' 58 | split_channels = channels // 2 59 | self.split_channels = split_channels 60 | self.left_branch = nn.Sequential( 61 | nn.Conv2d(split_channels, split_channels, (3, 1), padding=(1,0)), 62 | Activation(act_type), 63 | ConvBNAct(split_channels, split_channels, (1, 3), act_type=act_type), 64 | nn.Conv2d(split_channels, split_channels, (3, 1), 65 | padding=(dilation,0), dilation=dilation), 66 | Activation(act_type), 67 | ConvBNAct(split_channels, split_channels, (1, 3), dilation=dilation, act_type=act_type), 68 | ) 69 | 70 | self.right_branch = nn.Sequential( 71 | nn.Conv2d(split_channels, split_channels, (1, 3), padding=(0,1)), 72 | Activation(act_type), 73 | ConvBNAct(split_channels, split_channels, (3, 1), act_type=act_type), 74 | nn.Conv2d(split_channels, 
split_channels, (1, 3), 75 | padding=(0,dilation), dilation=dilation), 76 | Activation(act_type), 77 | ConvBNAct(split_channels, split_channels, (3, 1), dilation=dilation, act_type=act_type), 78 | ) 79 | self.act = Activation(act_type) 80 | 81 | def forward(self, x): 82 | x_left = x[:, :self.split_channels].clone() 83 | x_right = x[:, self.split_channels:].clone() 84 | x_left = self.left_branch(x_left) 85 | x_right = self.right_branch(x_right) 86 | x_cat = torch.cat([x_left, x_right], dim=1) 87 | x += x_cat 88 | x = self.act(x) 89 | x = channel_shuffle(x) 90 | return x 91 | 92 | 93 | class AttentionPyramidNetwork(nn.Module): 94 | def __init__(self, in_channels, out_channels, act_type): 95 | super().__init__() 96 | self.left_conv1_1 = ConvBNAct(in_channels, in_channels, 3, 2, act_type=act_type) 97 | self.left_conv1_2 = ConvBNAct(in_channels, out_channels, 3, act_type=act_type) 98 | self.left_conv2_1 = ConvBNAct(in_channels, in_channels, 3, 2, act_type=act_type) 99 | self.left_conv2_2 = ConvBNAct(in_channels, out_channels, 3, act_type=act_type) 100 | self.left_conv3 = nn.Sequential( 101 | ConvBNAct(in_channels, in_channels, 3, 2, act_type=act_type), 102 | ConvBNAct(in_channels, out_channels, 3, act_type=act_type) 103 | ) 104 | 105 | self.mid_branch = ConvBNAct(in_channels, out_channels, act_type=act_type) 106 | self.right_branch = nn.Sequential( 107 | nn.AdaptiveAvgPool2d(1), 108 | ConvBNAct(in_channels, out_channels, act_type=act_type), 109 | ) 110 | 111 | def forward(self, x): 112 | size0 = x.size()[2:] 113 | 114 | x_left = self.left_conv1_1(x) 115 | size1 = x_left.size()[2:] 116 | 117 | x_left2 = self.left_conv2_1(x_left) 118 | size2 = x_left2.size()[2:] 119 | 120 | x_left3 = self.left_conv3(x_left2) 121 | x_left3 = F.interpolate(x_left3, size2, mode='bilinear', align_corners=True) 122 | 123 | x_left2 = self.left_conv2_2(x_left2) 124 | x_left2 += x_left3 125 | x_left2 = F.interpolate(x_left2, size1, mode='bilinear', align_corners=True) 126 | 127 | x_left = self.left_conv1_2(x_left) 128 | x_left += x_left2 129 | x_left = F.interpolate(x_left, size0, mode='bilinear', align_corners=True) 130 | 131 | x_mid = self.mid_branch(x) 132 | x_mid = torch.mul(x_left, x_mid) 133 | 134 | x_right = self.right_branch(x) 135 | x_right = F.interpolate(x_right, size0, mode='bilinear', align_corners=True) 136 | 137 | x_mid += x_right 138 | return x_mid -------------------------------------------------------------------------------- /models/linknet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation 3 | Url: https://arxiv.org/abs/1707.03718 4 | Create by: zh320 5 | Date: 2023/04/23 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from .modules import ConvBNAct, DeConvBNAct, Activation 12 | from .backbone import ResNet 13 | from .model_registry import register_model 14 | 15 | 16 | @register_model() 17 | class LinkNet(nn.Module): 18 | def __init__(self, num_class=1, n_channel=3, backbone_type='resnet18', act_type='relu'): 19 | super().__init__() 20 | if 'resnet' in backbone_type: 21 | self.backbone = ResNet(backbone_type) 22 | channels = [64, 128, 256, 512] if backbone_type in ['resnet18', 'resnet34'] else [256, 512, 1024, 2048] 23 | else: 24 | raise NotImplementedError() 25 | 26 | self.dec_block4 = DecoderBlock(channels[3], channels[2], act_type) 27 | self.dec_block3 = DecoderBlock(channels[2], channels[1], act_type) 28 | self.dec_block2 = DecoderBlock(channels[1], 
channels[0], act_type) 29 | self.dec_block1 = DecoderBlock(channels[0], channels[0], act_type, scale_factor=1) 30 | self.seg_head = SegHead(channels[0], num_class, act_type) 31 | 32 | def forward(self, x): 33 | x_1, x_2, x_3, x_4 = self.backbone(x) 34 | x = self.dec_block4(x_4) 35 | x = self.dec_block3(x + x_3) 36 | x = self.dec_block2(x + x_2) 37 | x = self.dec_block1(x + x_1) 38 | x = self.seg_head(x) 39 | 40 | return x 41 | 42 | 43 | class DecoderBlock(nn.Module): 44 | def __init__(self, in_channels, out_channels, act_type, scale_factor=2): 45 | super().__init__() 46 | hid_channels = in_channels // 4 47 | self.conv1 = ConvBNAct(in_channels, hid_channels, 1, act_type=act_type) 48 | if scale_factor > 1: 49 | self.full_conv = DeConvBNAct(hid_channels, hid_channels, scale_factor, act_type=act_type) 50 | else: 51 | self.full_conv = ConvBNAct(hid_channels, hid_channels, 3, act_type=act_type) 52 | self.conv2 = ConvBNAct(hid_channels, out_channels, 1, act_type=act_type) 53 | 54 | def forward(self, x): 55 | x = self.conv1(x) 56 | x = self.full_conv(x) 57 | x = self.conv2(x) 58 | 59 | return x 60 | 61 | 62 | class SegHead(nn.Sequential): 63 | def __init__(self, in_channels, num_class, act_type, scale_factor=2): 64 | hid_channels = in_channels // 2 65 | super().__init__( 66 | DeConvBNAct(in_channels, hid_channels, scale_factor, act_type=act_type), 67 | ConvBNAct(hid_channels, hid_channels, 3, act_type=act_type), 68 | DeConvBNAct(hid_channels, num_class, scale_factor, act_type=act_type) 69 | ) -------------------------------------------------------------------------------- /models/liteseg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: LiteSeg: A Novel Lightweight ConvNet for Semantic Segmentation 3 | Url: https://arxiv.org/abs/1912.06683 4 | Create by: zh320 5 | Date: 2023/10/15 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import conv1x1, ConvBNAct 13 | from .backbone import ResNet, Mobilenetv2 14 | from .model_registry import register_model 15 | 16 | 17 | @register_model() 18 | class LiteSeg(nn.Module): 19 | def __init__(self, num_class=1, n_channel=3, backbone_type='mobilenet_v2', act_type='relu'): 20 | super().__init__() 21 | if backbone_type == 'mobilenet_v2': 22 | self.backbone = Mobilenetv2() 23 | channels = [320, 32] 24 | elif 'resnet' in backbone_type: 25 | self.backbone = ResNet(backbone_type) 26 | channels = [512, 128] if backbone_type in ['resnet18', 'resnet34'] else [2048, 512] 27 | else: 28 | raise NotImplementedError() 29 | 30 | self.daspp = DASPPModule(channels[0], 512, act_type) 31 | self.seg_head = SegHead(512 + channels[1], num_class, act_type) 32 | 33 | def forward(self, x): 34 | size = x.size()[2:] 35 | 36 | _, x1, _, x = self.backbone(x) 37 | size1 = x1.size()[2:] 38 | 39 | x = self.daspp(x) 40 | x = F.interpolate(x, size1, mode='bilinear', align_corners=True) 41 | x = torch.cat([x, x1], dim=1) 42 | 43 | x = self.seg_head(x) 44 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 45 | 46 | return x 47 | 48 | 49 | class DASPPModule(nn.Module): 50 | def __init__(self, in_channels, out_channels, act_type): 51 | super().__init__() 52 | hid_channels = in_channels // 5 53 | last_channels = in_channels - hid_channels * 4 54 | self.stage1 = ConvBNAct(in_channels, hid_channels, 1, act_type=act_type) 55 | self.stage2 = ConvBNAct(in_channels, hid_channels, 3, dilation=3, act_type=act_type) 56 | self.stage3 = ConvBNAct(in_channels, hid_channels, 3, 
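# parallel branches at dilations 3, 6, 9 plus global image pooling give DASPP its multi-scale context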
dilation=6, act_type=act_type) 57 | self.stage4 = ConvBNAct(in_channels, hid_channels, 3, dilation=9, act_type=act_type) 58 | self.stage5 = nn.Sequential( 59 | nn.AdaptiveAvgPool2d(1), 60 | conv1x1(in_channels, last_channels) 61 | ) 62 | self.conv = ConvBNAct(2*in_channels, out_channels, 1, act_type=act_type) 63 | 64 | def forward(self, x): 65 | size = x.size()[2:] 66 | 67 | x1 = self.stage1(x) 68 | x2 = self.stage2(x) 69 | x3 = self.stage3(x) 70 | x4 = self.stage4(x) 71 | x5 = self.stage5(x) 72 | x5 = F.interpolate(x5, size, mode='bilinear', align_corners=True) 73 | 74 | x = self.conv(torch.cat([x, x1, x2, x3, x4, x5], dim=1)) 75 | return x 76 | 77 | 78 | class SegHead(nn.Sequential): 79 | def __init__(self, in_channels, num_class, act_type, hid_channels=256): 80 | super().__init__( 81 | ConvBNAct(in_channels, hid_channels, 3, act_type=act_type), 82 | ConvBNAct(hid_channels, hid_channels//2, 3, act_type=act_type), 83 | conv1x1(hid_channels//2, num_class) 84 | ) -------------------------------------------------------------------------------- /models/mininet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: Enhancing V-SLAM Keyframe Selection with an Efficient ConvNet for Semantic Analysis 3 | Url: https://ieeexplore.ieee.org/abstract/document/8793923 4 | Create by: zh320 5 | Date: 2023/10/15 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from .modules import conv1x1, DSConvBNAct, ConvBNAct, DeConvBNAct, Activation 12 | from .model_registry import register_model 13 | 14 | 15 | @register_model() 16 | class MiniNet(nn.Module): 17 | def __init__(self, num_class=1, n_channel=3, act_type='selu'): 18 | super().__init__() 19 | # Downsample block 20 | self.down1 = DSConvBNAct(n_channel, 12, 3, 2, act_type=act_type) 21 | self.down2 = DSConvBNAct(12, 24, 3, 2, act_type=act_type) 22 | self.down3 = DSConvBNAct(24, 48, 3, 2, act_type=act_type) 23 | self.down4 = DSConvBNAct(48, 96, 3, 2, act_type=act_type) 24 | # Branch 1 25 | self.branch1 = nn.Sequential( 26 | ConvModule(96, 1, act_type), 27 | ConvModule(96, 2, act_type), 28 | ConvModule(96, 4, act_type), 29 | ConvModule(96, 8, act_type), 30 | ) 31 | # Branch 2 32 | self.branch2_down = DSConvBNAct(96, 192, 3, 2, act_type=act_type) 33 | self.branch2 = nn.Sequential( 34 | ConvModule(192, 1, act_type), 35 | DSConvBNAct(192, 386, 3, 2, act_type=act_type), 36 | ConvModule(386, 1, act_type), 37 | ConvModule(386, 1, act_type), 38 | DeConvBNAct(386, 192, act_type=act_type), 39 | ConvModule(192, 1, act_type), 40 | ) 41 | self.branch2_up = DeConvBNAct(192*2, 96, act_type=act_type) 42 | # Upsample Block 43 | self.up4 = nn.Sequential( 44 | DeConvBNAct(96*3, 96, act_type=act_type), 45 | ConvModule(96, 1, act_type), 46 | conv1x1(96, 48) 47 | ) 48 | self.up3 = DeConvBNAct(48*2, 24, act_type=act_type) 49 | self.up2 = DeConvBNAct(24*2, 12, act_type=act_type) 50 | self.up1 = DeConvBNAct(12*2, num_class, act_type=act_type) 51 | 52 | def forward(self, x): 53 | x_d1 = self.down1(x) 54 | x_d2 = self.down2(x_d1) 55 | x_d3 = self.down3(x_d2) 56 | x_d4 = self.down4(x_d3) 57 | 58 | x_b1 = self.branch1(x_d4) 59 | 60 | x_d5 = self.branch2_down(x_d4) 61 | x_b2 = self.branch2(x_d5) 62 | x_b2 = torch.cat([x_b2, x_d5], dim=1) 63 | x_b2 = self.branch2_up(x_b2) 64 | 65 | x = torch.cat([x_b1, x_b2, x_d4], dim=1) 66 | x = self.up4(x) 67 | x = torch.cat([x, x_d3], dim=1) 68 | x = self.up3(x) 69 | x = torch.cat([x, x_d2], dim=1) 70 | x = self.up2(x) 71 | x = torch.cat([x, x_d1], dim=1) 72 | x = self.up1(x) 73 | 74 | 
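# up1 has already fused the stride-2 skip (x_d1), so the logits are back at input resolution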
return x 75 | 76 | 77 | class ConvModule(nn.Module): 78 | def __init__(self, channels, dilation, act_type): 79 | super().__init__() 80 | self.conv1 = nn.Sequential( 81 | nn.Conv2d(channels, channels, (1,3), padding=(0, dilation), 82 | dilation=dilation, groups=channels, bias=False), 83 | Activation(act_type), 84 | nn.Conv2d(channels, channels, (3,1), padding=(dilation, 0), 85 | dilation=dilation, groups=channels, bias=False), 86 | Activation(act_type), 87 | ) 88 | self.conv2 = nn.Sequential( 89 | nn.Conv2d(channels, channels, (3,1), padding=(dilation, 0), 90 | dilation=dilation, groups=channels, bias=False), 91 | Activation(act_type), 92 | nn.Conv2d(channels, channels, (1,3), padding=(0, dilation), 93 | dilation=dilation, groups=channels, bias=False), 94 | ) 95 | self.dropout = nn.Dropout(p=0.25) 96 | self.act = Activation(act_type) 97 | 98 | def forward(self, x): 99 | residual = x 100 | 101 | x1 = self.conv1(x) 102 | x = self.conv2(x1) 103 | 104 | x += x1 105 | x = self.dropout(x) 106 | x += residual 107 | 108 | return self.act(x) -------------------------------------------------------------------------------- /models/mininetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: MiniNet: An Efficient Semantic Segmentation ConvNet for Real-Time Robotic Applications 3 | Url: https://ieeexplore.ieee.org/abstract/document/9023474 4 | Create by: zh320 5 | Date: 2023/10/15 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import DWConvBNAct, PWConvBNAct, ConvBNAct, DeConvBNAct, Activation 13 | from .enet import InitialBlock as DownsamplingUnit 14 | from .model_registry import register_model 15 | 16 | 17 | @register_model() 18 | class MiniNetv2(nn.Module): 19 | def __init__(self, num_class=1, n_channel=3, feat_dt=[1,2,1,4,1,8,1,16,1,1,1,2,1,4,1,8], 20 | act_type='relu'): 21 | super().__init__() 22 | self.d1_2 = nn.Sequential( 23 | DownsamplingUnit(n_channel, 16, act_type), 24 | DownsamplingUnit(16, 64, act_type), 25 | ) 26 | self.ref = nn.Sequential( 27 | DownsamplingUnit(n_channel, 16, act_type), 28 | DownsamplingUnit(16, 64, act_type) 29 | ) 30 | self.m1_10 = build_blocks(MultiDilationDSConv, 64, 10, act_type=act_type) 31 | self.d3 = DownsamplingUnit(64, 128, act_type) 32 | self.feature_extractor = build_blocks(MultiDilationDSConv, 128, len(feat_dt), feat_dt, act_type) 33 | self.up1 = DeConvBNAct(128, 64, act_type=act_type) 34 | self.m26_29 = build_blocks(MultiDilationDSConv, 64, 4, act_type=act_type) 35 | self.output = DeConvBNAct(64, num_class, act_type=act_type) 36 | 37 | def forward(self, x): 38 | size = x.size()[2:] 39 | 40 | x_ref = self.ref(x) 41 | 42 | x = self.d1_2(x) 43 | x = self.m1_10(x) 44 | x = self.d3(x) 45 | x = self.feature_extractor(x) 46 | x = self.up1(x) 47 | x += x_ref 48 | 49 | x = self.m26_29(x) 50 | x = self.output(x) 51 | 52 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 53 | 54 | return x 55 | 56 | 57 | def build_blocks(block, channels, num_block, dilations=[], act_type='relu'): 58 | if len(dilations) == 0: 59 | dilations = [1 for _ in range(num_block)] 60 | else: 61 | if len(dilations) != num_block: 62 | raise ValueError('Number of dilations should equal the number of blocks') 63 | 64 | layers = [] 65 | for i in range(num_block): 66 | layers.append(block(channels, channels, 3, 1, dilations[i], act_type)) 67 | return nn.Sequential(*layers) 68 | 69 | 70 | class MultiDilationDSConv(nn.Module): 71 | def __init__(self, in_channels,
out_channels, kernel_size=3, stride=1, dilation=1, act_type='relu'): 72 | super().__init__() 73 | self.dilated = dilation > 1 74 | self.dw_conv = DWConvBNAct(in_channels, in_channels, kernel_size, stride, 1, act_type) 75 | self.pw_conv = PWConvBNAct(in_channels, out_channels, act_type, inplace=True) 76 | if self.dilated: 77 | self.ddw_conv = DWConvBNAct(in_channels, in_channels, kernel_size, stride, dilation, act_type, inplace=True) 78 | 79 | def forward(self, x): 80 | x_dw = self.dw_conv(x) 81 | if self.dilated: 82 | x_ddw = self.ddw_conv(x) 83 | x_dw += x_ddw 84 | x = self.pw_conv(x_dw) 85 | 86 | return x -------------------------------------------------------------------------------- /models/model_registry.py: -------------------------------------------------------------------------------- 1 | model_hub = {} 2 | aux_models = [] 3 | detail_head_models = [] 4 | 5 | 6 | def register_model(*other_registries): 7 | def decorator(model_class): 8 | model_hub[model_class.__name__.lower()] = model_class 9 | 10 | for registry in other_registries: 11 | if isinstance(registry, list): 12 | registry.append(model_class.__name__.lower()) 13 | else: 14 | print(f"Model registry is not a list. Skipping registry: {registry}") 15 | 16 | return model_class 17 | return decorator -------------------------------------------------------------------------------- /models/pp_liteseg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: PP-LiteSeg: A Superior Real-Time Semantic Segmentation Model 3 | Url: https://arxiv.org/abs/2204.02681 4 | Create by: zh320 5 | Date: 2023/07/15 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import conv1x1, conv3x3, ConvBNAct 13 | from .model_registry import register_model 14 | 15 | 16 | @register_model() 17 | class PPLiteSeg(nn.Module): 18 | def __init__(self, num_class=1, n_channel=3, encoder_channels=[32, 64, 256, 512, 1024], 19 | encoder_type='stdc1', fusion_type='spatial', act_type='relu'): 20 | super().__init__() 21 | decoder_channel_hub = {'stdc1': [32, 64, 128], 'stdc2': [64, 96, 128]} 22 | decoder_channels = decoder_channel_hub[encoder_type] 23 | 24 | self.encoder = Encoder(n_channel, encoder_channels, encoder_type, act_type) 25 | self.sppm = SPPM(encoder_channels[-1], decoder_channels[0], act_type) 26 | self.decoder = FLD(encoder_channels, decoder_channels, num_class, fusion_type, act_type) 27 | 28 | def forward(self, x): 29 | size = x.size()[2:] 30 | x3, x4, x5 = self.encoder(x) 31 | x5 = self.sppm(x5) 32 | x = self.decoder(x3, x4, x5, size) 33 | 34 | return x 35 | 36 | 37 | class Encoder(nn.Module): 38 | def __init__(self, in_channels, encoder_channels, encoder_type, act_type): 39 | super().__init__() 40 | encoder_hub = {'stdc1':STDCBackbone, 'stdc2':STDCBackbone} 41 | if encoder_type not in encoder_hub.keys(): 42 | raise ValueError(f'Unsupported encoder type: {encoder_type}.\n') 43 | 44 | self.encoder = encoder_hub[encoder_type](in_channels, encoder_channels, encoder_type, act_type) 45 | 46 | def forward(self, x): 47 | x3, x4, x5 = self.encoder(x) 48 | 49 | return x3, x4, x5 50 | 51 | 52 | class SPPM(nn.Module): 53 | def __init__(self, in_channels, out_channels, act_type): 54 | super().__init__() 55 | hid_channels = int(in_channels // 4) 56 | self.act_type = act_type 57 | 58 | self.pool1 = self._make_pool_layer(in_channels, hid_channels, 1) 59 | self.pool2 = self._make_pool_layer(in_channels, hid_channels, 2) 60 | self.pool3 =
self._make_pool_layer(in_channels, hid_channels, 4) 61 | self.conv = conv3x3(hid_channels, out_channels) 62 | 63 | def _make_pool_layer(self, in_channels, out_channels, pool_size): 64 | return nn.Sequential( 65 | nn.AdaptiveAvgPool2d(pool_size), 66 | ConvBNAct(in_channels, out_channels, 1, act_type=self.act_type) 67 | ) 68 | 69 | def forward(self, x): 70 | size = x.size()[2:] 71 | x1 = F.interpolate(self.pool1(x), size, mode='bilinear', align_corners=True) 72 | x2 = F.interpolate(self.pool2(x), size, mode='bilinear', align_corners=True) 73 | x3 = F.interpolate(self.pool3(x), size, mode='bilinear', align_corners=True) 74 | x = self.conv(x1 + x2 + x3) 75 | 76 | return x 77 | 78 | 79 | class FLD(nn.Module): 80 | def __init__(self, encoder_channels, decoder_channels, num_class, fusion_type, act_type): 81 | super().__init__() 82 | self.stage6 = ConvBNAct(decoder_channels[0], decoder_channels[0]) 83 | self.fusion1 = UAFM(encoder_channels[3], decoder_channels[0], fusion_type) 84 | self.stage7 = ConvBNAct(decoder_channels[0], decoder_channels[1]) 85 | self.fusion2 = UAFM(encoder_channels[2], decoder_channels[1], fusion_type) 86 | self.stage8 = ConvBNAct(decoder_channels[1], decoder_channels[2]) 87 | self.seg_head = ConvBNAct(decoder_channels[2], num_class, 3, act_type=act_type) 88 | 89 | def forward(self, x3, x4, x5, size): 90 | x = self.stage6(x5) 91 | x = self.fusion1(x, x4) 92 | x = self.stage7(x) 93 | x = self.fusion2(x, x3) 94 | x = self.stage8(x) 95 | x = self.seg_head(x) 96 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 97 | 98 | return x 99 | 100 | 101 | class STDCBackbone(nn.Module): 102 | def __init__(self, in_channels, encoder_channels, encoder_type, act_type): 103 | super().__init__() 104 | repeat_times_hub = {'stdc1': [1,1,1], 'stdc2': [3,4,2]} 105 | repeat_times = repeat_times_hub[encoder_type] 106 | self.stage1 = ConvBNAct(in_channels, encoder_channels[0], 3, 2) 107 | self.stage2 = ConvBNAct(encoder_channels[0], encoder_channels[1], 3, 2) 108 | self.stage3 = self._make_stage(encoder_channels[1], encoder_channels[2], repeat_times[0], act_type) 109 | self.stage4 = self._make_stage(encoder_channels[2], encoder_channels[3], repeat_times[1], act_type) 110 | self.stage5 = self._make_stage(encoder_channels[3], encoder_channels[4], repeat_times[2], act_type) 111 | 112 | def _make_stage(self, in_channels, out_channels, repeat_times, act_type): 113 | layers = [STDCModule(in_channels, out_channels, 2, act_type)] 114 | 115 | for _ in range(repeat_times): 116 | layers.append(STDCModule(out_channels, out_channels, 1, act_type)) 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | x = self.stage1(x) 121 | x = self.stage2(x) 122 | x3 = self.stage3(x) 123 | x4 = self.stage4(x3) 124 | x5 = self.stage5(x4) 125 | return x3, x4, x5 126 | 127 | 128 | class STDCModule(nn.Module): 129 | def __init__(self, in_channels, out_channels, stride, act_type): 130 | super().__init__() 131 | if out_channels % 8 != 0: 132 | raise ValueError('Output channel should be evenly divided by 8.\n') 133 | self.stride = stride 134 | self.block1 = ConvBNAct(in_channels, out_channels//2, 1) 135 | self.block2 = ConvBNAct(out_channels//2, out_channels//4, 3, stride) 136 | if self.stride == 2: 137 | self.pool = nn.AvgPool2d(3, 2, 1) 138 | self.block3 = ConvBNAct(out_channels//4, out_channels//8, 3) 139 | self.block4 = ConvBNAct(out_channels//8, out_channels//8, 3) 140 | 141 | def forward(self, x): 142 | x = self.block1(x) 143 | x2 = self.block2(x) 144 | if self.stride == 2: 145 | x = 
self.pool(x) 146 | x3 = self.block3(x2) 147 | x4 = self.block4(x3) 148 | 149 | return torch.cat([x, x2, x3, x4], dim=1) 150 | 151 | 152 | class UAFM(nn.Module): 153 | def __init__(self, in_channels, out_channels, fusion_type): 154 | super().__init__() 155 | fusion_hub = {'spatial': SpatialAttentionModule, 'channel': ChannelAttentionModule} 156 | if fusion_type not in fusion_hub.keys(): 157 | raise ValueError(f'Unsupported fusion type: {fusion_type}.\n') 158 | 159 | self.conv = conv1x1(in_channels, out_channels) 160 | self.attention = fusion_hub[fusion_type](out_channels) 161 | 162 | def forward(self, x_high, x_low): 163 | size = x_low.size()[2:] 164 | x_low = self.conv(x_low) 165 | x_up = F.interpolate(x_high, size, mode='bilinear', align_corners=True) 166 | alpha = self.attention(x_up, x_low) 167 | x = alpha * x_up + (1 - alpha) * x_low 168 | 169 | return x 170 | 171 | 172 | class SpatialAttentionModule(nn.Module): 173 | def __init__(self, out_channels): 174 | super().__init__() 175 | self.conv = conv1x1(4, 1) 176 | 177 | def forward(self, x_up, x_low): 178 | mean_up = torch.mean(x_up, dim=1, keepdim=True) 179 | max_up, _ = torch.max(x_up, dim=1, keepdim=True) 180 | mean_low = torch.mean(x_low, dim=1, keepdim=True) 181 | max_low, _ = torch.max(x_low, dim=1, keepdim=True) 182 | x = self.conv(torch.cat([mean_up, max_up, mean_low, max_low], dim=1)) 183 | x = torch.sigmoid(x) # [N, 1, H, W] 184 | 185 | return x 186 | 187 | 188 | class ChannelAttentionModule(nn.Module): 189 | def __init__(self, out_channels): 190 | super().__init__() 191 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 192 | self.max_pool = nn.AdaptiveMaxPool2d(1) 193 | self.conv = conv1x1(4*out_channels, out_channels) 194 | 195 | def forward(self, x_up, x_low): 196 | avg_up = self.avg_pool(x_up) 197 | max_up = self.max_pool(x_up) 198 | avg_low = self.avg_pool(x_low) 199 | max_low = self.max_pool(x_low) 200 | x = self.conv(torch.cat([avg_up, max_up, avg_low, max_low], dim=1)) 201 | x = torch.sigmoid(x) # [N, C, 1, 1] 202 | 203 | return x -------------------------------------------------------------------------------- /models/regseg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: Rethinking Dilated Convolution for Real-time Semantic Segmentation 3 | Url: https://arxiv.org/abs/2111.09957 4 | Create by: zh320 5 | Date: 2024/01/13 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import conv1x1, ConvBNAct, Activation 13 | from .model_registry import register_model 14 | 15 | 16 | @register_model() 17 | class RegSeg(nn.Module): 18 | def __init__(self, num_class=1, n_channel=3, dilations=None, act_type='relu'): 19 | super().__init__() 20 | if dilations is None: 21 | dilations = [[1,1], [1,2], [1,2], [1,3], [2,3], [2,7], [2,3], 22 | [2,6], [2,5], [2,9], [2,11], [4,7], [5,14]] 23 | else: 24 | if len(dilations) != 13: 25 | raise ValueError("Dilation pairs' length should be 13\n") 26 | 27 | # Backbone-1 28 | self.conv_init = ConvBNAct(n_channel, 32, 3, 2, act_type=act_type) 29 | 30 | # Backbone-2 31 | self.stage_d4 = DBlock(32, 48, 2, act_type=act_type) 32 | 33 | # Backbone-3 34 | layers = [DBlock(48, 128, 2, act_type=act_type)] 35 | for _ in range(3-1): 36 | layers.append(DBlock(128, 128, 1, r1=1, r2=1, act_type=act_type)) 37 | self.stage_d8 = nn.Sequential(*layers) 38 | 39 | # Backbone-4 40 | layers = [DBlock(128, 256, 2, act_type=act_type)] 41 | for i in range(13-1): 42 | layers.append(DBlock(256, 256, 1,
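# each dilations[i] pair supplies (r1, r2), the dilation rates of the block's two split branches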
r1=dilations[i][0], r2=dilations[i][1], act_type=act_type)) 43 | 44 | # Backbone-5 45 | layers.append(DBlock(256, 320, 2, r1=dilations[-1][0], r2=dilations[-1][1], act_type=act_type)) 46 | self.stage_d16 = nn.Sequential(*layers) 47 | 48 | # Decoder 49 | self.decoder = Decoder(num_class, 48, 128, 320, act_type) 50 | 51 | def forward(self, x): 52 | size = x.size()[2:] 53 | 54 | x = self.conv_init(x) # 2x down 55 | x_d4 = self.stage_d4(x) # 4x down 56 | x_d8 = self.stage_d8(x_d4) # 8x down 57 | x_d16 = self.stage_d16(x_d8) # 16x down 58 | x = self.decoder(x_d4, x_d8, x_d16) # 4x down 59 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 60 | 61 | return x 62 | 63 | 64 | class DBlock(nn.Module): 65 | def __init__(self, in_channels, out_channels, stride=1, r1=None, r2=None, 66 | g=16, se_ratio=0.25, act_type='relu'): 67 | super().__init__() 68 | assert stride in [1, 2], f'Unsupported stride: {stride}' 69 | self.stride = stride 70 | 71 | self.conv1 = ConvBNAct(in_channels, out_channels, 1, act_type=act_type) 72 | if stride == 1: 73 | assert in_channels == out_channels, 'in_channels should equal out_channels when stride = 1' 74 | split_ch = out_channels // 2 75 | assert split_ch % g == 0, 'split_ch should be evenly divisible by group width `g`' 76 | groups = split_ch // g 77 | self.split_channels = split_ch 78 | self.conv_left = ConvBNAct(split_ch, split_ch, 3, dilation=r1, groups=groups, act_type=act_type) 79 | self.conv_right = ConvBNAct(split_ch, split_ch, 3, dilation=r2, groups=groups, act_type=act_type) 80 | else: # stride == 2 81 | assert out_channels % g == 0, 'out_channels should be evenly divisible by group width `g`' 82 | groups = out_channels // g 83 | self.conv_left = ConvBNAct(out_channels, out_channels, 3, 2, groups=groups, act_type=act_type) 84 | self.conv_skip = nn.Sequential( 85 | nn.AvgPool2d(2, 2, 0), 86 | ConvBNAct(in_channels, out_channels, 1, act_type='none') 87 | ) 88 | self.conv2 = nn.Sequential( 89 | SEBlock(out_channels, se_ratio, act_type), 90 | ConvBNAct(out_channels, out_channels, 1, act_type='none') 91 | ) 92 | self.act = Activation(act_type) 93 | 94 | def forward(self, x): 95 | residual = x 96 | x = self.conv1(x) 97 | if self.stride == 1: 98 | x_left = self.conv_left(x[:, :self.split_channels]) 99 | x_right = self.conv_right(x[:,self.split_channels:]) 100 | x = torch.cat([x_left, x_right], dim=1) 101 | else: 102 | x = self.conv_left(x) 103 | residual = self.conv_skip(residual) 104 | 105 | x = self.conv2(x) 106 | x += residual 107 | 108 | return self.act(x) 109 | 110 | 111 | class SEBlock(nn.Module): 112 | def __init__(self, channels, reduction_ratio, act_type): 113 | super().__init__() 114 | squeeze_channels = int(channels * reduction_ratio) 115 | self.pool = nn.AdaptiveAvgPool2d(1) 116 | self.se_block = nn.Sequential( 117 | nn.Linear(channels, squeeze_channels), 118 | Activation(act_type), 119 | nn.Linear(squeeze_channels, channels), 120 | Activation('sigmoid') 121 | ) 122 | 123 | def forward(self, x): 124 | residual = x 125 | x = self.pool(x).squeeze(-1).squeeze(-1) 126 | x = self.se_block(x).unsqueeze(-1).unsqueeze(-1) 127 | x = x * residual 128 | 129 | return x 130 | 131 | 132 | class Decoder(nn.Module): 133 | def __init__(self, num_class, d4_channel, d8_channel, d16_channel, act_type): 134 | super().__init__() 135 | self.conv_d16 = ConvBNAct(d16_channel, 128, 1, act_type=act_type) 136 | self.conv_d8_stage1 = ConvBNAct(d8_channel, 128, 1, act_type=act_type) 137 | self.conv_d4_stage1 = ConvBNAct(d4_channel, 8, 1, act_type=act_type) 138 |
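# stage-2 layers below refine the fused features on the way up to the 1/4-scale logits (see forward)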
self.conv_d8_stage2 = ConvBNAct(128, 64, 3, act_type=act_type) 139 | self.conv_d4_stage2 = nn.Sequential( 140 | ConvBNAct(64+8, 64, 3, act_type=act_type), 141 | conv1x1(64, num_class) 142 | ) 143 | 144 | def forward(self, x_d4, x_d8, x_d16): 145 | size_d4 = x_d4.size()[2:] 146 | size_d8 = x_d8.size()[2:] 147 | 148 | x_d16 = self.conv_d16(x_d16) 149 | x_d16 = F.interpolate(x_d16, size_d8, mode='bilinear', align_corners=True) 150 | 151 | x_d8 = self.conv_d8_stage1(x_d8) 152 | x_d8 += x_d16 153 | x_d8 = self.conv_d8_stage2(x_d8) 154 | x_d8 = F.interpolate(x_d8, size_d4, mode='bilinear', align_corners=True) 155 | 156 | x_d4 = self.conv_d4_stage1(x_d4) 157 | x_d4 = torch.cat([x_d4, x_d8], dim=1) 158 | x_d4 = self.conv_d4_stage2(x_d4) 159 | 160 | return x_d4 -------------------------------------------------------------------------------- /models/segnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation 3 | Url: https://arxiv.org/abs/1511.00561 4 | Create by: zh320 5 | Date: 2023/08/20 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from .modules import ConvBNAct 12 | from .model_registry import register_model 13 | 14 | 15 | @register_model() 16 | class SegNet(nn.Module): 17 | def __init__(self, num_class=1, n_channel=3, hid_channel=64, act_type='relu'): 18 | super().__init__() 19 | self.down_stage1 = DownsampleBlock(n_channel, hid_channel, act_type, False) 20 | self.down_stage2 = DownsampleBlock(hid_channel, hid_channel*2, act_type, False) 21 | self.down_stage3 = DownsampleBlock(hid_channel*2, hid_channel*4, act_type, True) 22 | self.down_stage4 = DownsampleBlock(hid_channel*4, hid_channel*8, act_type, True) 23 | self.down_stage5 = DownsampleBlock(hid_channel*8, hid_channel*8, act_type, True) 24 | self.up_stage5 = UpsampleBlock(hid_channel*8, hid_channel*8, act_type, True) 25 | self.up_stage4 = UpsampleBlock(hid_channel*8, hid_channel*4, act_type, True) 26 | self.up_stage3 = UpsampleBlock(hid_channel*4, hid_channel*2, act_type, True) 27 | self.up_stage2 = UpsampleBlock(hid_channel*2, hid_channel, act_type, False) 28 | self.up_stage1 = UpsampleBlock(hid_channel, hid_channel, act_type, False) 29 | self.classifier = ConvBNAct(hid_channel, num_class, act_type=act_type) 30 | 31 | def forward(self, x): 32 | x, indices1 = self.down_stage1(x) 33 | x, indices2 = self.down_stage2(x) 34 | x, indices3 = self.down_stage3(x) 35 | x, indices4 = self.down_stage4(x) 36 | x, indices5 = self.down_stage5(x) 37 | x = self.up_stage5(x, indices5) 38 | x = self.up_stage4(x, indices4) 39 | x = self.up_stage3(x, indices3) 40 | x = self.up_stage2(x, indices2) 41 | x = self.up_stage1(x, indices1) 42 | x = self.classifier(x) 43 | 44 | return x 45 | 46 | 47 | class DownsampleBlock(nn.Module): 48 | def __init__(self, in_channels, out_channels, act_type='relu', extra_conv=False): 49 | super().__init__() 50 | layers = [ConvBNAct(in_channels, out_channels, 3, act_type=act_type, inplace=True), 51 | ConvBNAct(out_channels, out_channels, 3, act_type=act_type, inplace=True)] 52 | if extra_conv: 53 | layers.append(ConvBNAct(out_channels, out_channels, 3, act_type=act_type, inplace=True)) 54 | self.conv = nn.Sequential(*layers) 55 | 56 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 57 | 58 | def forward(self, x): 59 | x = self.conv(x) 60 | x, indices = self.pool(x) 61 | return x, indices 62 | 63 | 64 | class UpsampleBlock(nn.Module): 65 | def __init__(self, 
in_channels, out_channels, act_type='relu', extra_conv=False): 66 | super().__init__() 67 | self.pool = nn.MaxUnpool2d(kernel_size=2, stride=2) 68 | 69 | hid_channel = in_channels if extra_conv else out_channels 70 | 71 | layers = [ConvBNAct(in_channels, in_channels, 3, act_type=act_type, inplace=True), 72 | ConvBNAct(in_channels, hid_channel, 3, act_type=act_type, inplace=True)] 73 | 74 | if extra_conv: 75 | layers.append(ConvBNAct(in_channels, out_channels, 3, act_type=act_type, inplace=True)) 76 | self.conv = nn.Sequential(*layers) 77 | 78 | def forward(self, x, indices): 79 | x = self.pool(x, indices) 80 | x = self.conv(x) 81 | 82 | return x -------------------------------------------------------------------------------- /models/shelfnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: ShelfNet for Fast Semantic Segmentation 3 | Url: https://arxiv.org/abs/1811.11254 4 | Create by: zh320 5 | Date: 2023/10/22 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import conv1x1, ConvBNAct, DeConvBNAct, Activation 13 | from .backbone import ResNet 14 | from .model_registry import register_model 15 | 16 | 17 | @register_model() 18 | class ShelfNet(nn.Module): 19 | def __init__(self, num_class=1, n_channel=3, backbone_type='resnet18', 20 | hid_channels=[32,64,128,256], act_type='relu'): 21 | super().__init__() 22 | if 'resnet' in backbone_type: 23 | self.backbone = ResNet(backbone_type) 24 | channels = [64, 128, 256, 512] if backbone_type in ['resnet18', 'resnet34'] else [256, 512, 1024, 2048] 25 | else: 26 | raise NotImplementedError() 27 | 28 | self.conv_A = ConvBNAct(channels[0], hid_channels[0], 1, act_type=act_type) 29 | self.conv_B = ConvBNAct(channels[1], hid_channels[1], 1, act_type=act_type) 30 | self.conv_C = ConvBNAct(channels[2], hid_channels[2], 1, act_type=act_type) 31 | self.conv_D = ConvBNAct(channels[3], hid_channels[3], 1, act_type=act_type) 32 | 33 | self.decoder2 = DecoderBlock(hid_channels, act_type) 34 | self.encoder3 = EncoderBlock(hid_channels, act_type) 35 | self.decoder4 = DecoderBlock(hid_channels, act_type) 36 | 37 | self.classifier = conv1x1(hid_channels[0], num_class) 38 | 39 | def forward(self, x): 40 | size = x.size()[2:] 41 | x_a, x_b, x_c, x_d = self.backbone(x) 42 | 43 | # Column 1 44 | x_a = self.conv_A(x_a) 45 | x_b = self.conv_B(x_b) 46 | x_c = self.conv_C(x_c) 47 | x_d = self.conv_D(x_d) 48 | 49 | # Column 2 50 | x_a, x_b, x_c = self.decoder2(x_a, x_b, x_c, x_d, return_hid_feats=True) 51 | 52 | # Column 3 53 | x_a, x_b, x_c, x_d = self.encoder3(x_a, x_b, x_c) 54 | 55 | # Column 4 56 | x = self.decoder4(x_a, x_b, x_c, x_d) 57 | 58 | x = self.classifier(x) 59 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 60 | 61 | return x 62 | 63 | 64 | class EncoderBlock(nn.Module): 65 | def __init__(self, channels, act_type): 66 | super().__init__() 67 | self.block_A = SBlock(channels[0], act_type) 68 | self.down_A = ConvBNAct(channels[0], channels[1], 3, 2, act_type=act_type) 69 | 70 | self.block_B = SBlock(channels[1], act_type) 71 | self.down_B = ConvBNAct(channels[1], channels[2], 3, 2, act_type=act_type) 72 | 73 | self.block_C = SBlock(channels[2], act_type) 74 | self.down_C = ConvBNAct(channels[2], channels[3], 3, 2, act_type=act_type) 75 | 76 | def forward(self, x_a, x_b, x_c): 77 | x_a = self.block_A(x_a) 78 | x = self.down_A(x_a) 79 | 80 | x_b = self.block_B(x_b, x) 81 | x = self.down_B(x_b) 82 | 83 | x_c = self.block_C(x_c, 
x) 84 | x_d = self.down_C(x_c) 85 | 86 | return x_a, x_b, x_c, x_d 87 | 88 | 89 | class DecoderBlock(nn.Module): 90 | def __init__(self, channels, act_type): 91 | super().__init__() 92 | self.block_D = SBlock(channels[3], act_type) 93 | self.up_D = DeConvBNAct(channels[3], channels[2], act_type=act_type) 94 | 95 | self.block_C = SBlock(channels[2], act_type) 96 | self.up_C = DeConvBNAct(channels[2], channels[1], act_type=act_type) 97 | 98 | self.block_B = SBlock(channels[1], act_type) 99 | self.up_B = DeConvBNAct(channels[1], channels[0], act_type=act_type) 100 | 101 | self.block_A = SBlock(channels[0], act_type) 102 | 103 | def forward(self, x_a, x_b, x_c, x_d, return_hid_feats=False): 104 | x_d = self.block_D(x_d) 105 | x = self.up_D(x_d) 106 | 107 | x_c = self.block_C(x_c, x) 108 | x = self.up_C(x_c) 109 | 110 | x_b = self.block_B(x_b, x) 111 | x = self.up_B(x_b) 112 | 113 | x_a = self.block_A(x_a, x) 114 | 115 | if return_hid_feats: 116 | return x_a, x_b, x_c 117 | else: 118 | return x_a 119 | 120 | 121 | class SBlock(nn.Module): 122 | def __init__(self, channels, act_type): 123 | super().__init__() 124 | self.conv1 = ConvBNAct(channels, channels, 3, act_type=act_type) 125 | self.conv2 = ConvBNAct(channels, channels, 3, act_type='none') 126 | self.act = Activation(act_type) 127 | 128 | def forward(self, x_l, x_v=0.): 129 | x = x_l + x_v 130 | residual = x 131 | 132 | x = self.conv1(x) 133 | x = self.conv2(x) 134 | 135 | x += residual 136 | 137 | return self.act(x) -------------------------------------------------------------------------------- /models/smp_wrapper.py: -------------------------------------------------------------------------------- 1 | import segmentation_models_pytorch as smp 2 | 3 | 4 | encoder_hub = smp.encoders.get_encoder_names() 5 | decoder_hub = {'deeplabv3':smp.DeepLabV3, 'deeplabv3p':smp.DeepLabV3Plus, 'fpn':smp.FPN, 6 | 'linknet':smp.Linknet, 'manet':smp.MAnet, 'pan':smp.PAN, 'pspnet':smp.PSPNet, 7 | 'unet':smp.Unet, 'unetpp':smp.UnetPlusPlus} 8 | 9 | 10 | def get_smp_model(encoder_name, decoder_name, encoder_weights, num_class): 11 | if encoder_name not in encoder_hub: 12 | raise ValueError(f'Unsupported encoder: {encoder_name} for SMP model. Available encoders are:\n {encoder_hub}.') 13 | 14 | if decoder_name not in decoder_hub: 15 | raise ValueError(f'Unsupported decoder: {decoder_name} for SMP model. 
Available decoders are:\n {decoder_hub.keys()}.') 16 | 17 | model = decoder_hub[decoder_name](encoder_name=encoder_name, encoder_weights=encoder_weights, in_channels=3, classes=num_class) 18 | 19 | return model -------------------------------------------------------------------------------- /models/sqnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: Speeding up Semantic Segmentation for Autonomous Driving 3 | Url: https://openreview.net/pdf?id=S1uHiFyyg 4 | Create by: zh320 5 | Date: 2023/10/22 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from .modules import ConvBNAct, DeConvBNAct, Activation 12 | from .model_registry import register_model 13 | 14 | 15 | @register_model() 16 | class SQNet(nn.Module): 17 | def __init__(self, num_class=1, n_channel=3, act_type='elu'): 18 | super().__init__() 19 | # Encoder, SqueezeNet-1.1 20 | self.conv = ConvBNAct(n_channel, 64, 3, 2, act_type=act_type) 21 | self.pool1 = nn.MaxPool2d(3, 2, 1) 22 | self.fire1 = nn.Sequential( 23 | FireModule(64, 16, 64, 64, act_type), 24 | FireModule(128, 16, 64, 64, act_type) 25 | ) 26 | self.pool2 = nn.MaxPool2d(3, 2, 1) 27 | self.fire2 = nn.Sequential( 28 | FireModule(128, 32, 128, 128, act_type), 29 | FireModule(256, 32, 128, 128, act_type) 30 | ) 31 | self.pool3 = nn.MaxPool2d(3, 2, 1) 32 | self.fire3 = nn.Sequential( 33 | FireModule(256, 48, 192, 192, act_type), 34 | FireModule(384, 48, 192, 192, act_type), 35 | FireModule(384, 64, 256, 256, act_type), 36 | FireModule(512, 64, 256, 256, act_type) 37 | ) 38 | # Decoder 39 | self.pdc = ParallelDilatedConv(512, 128, [1,2,4,8], act_type) 40 | self.up1 = DeConvBNAct(128, 128, act_type=act_type) 41 | self.refine1 = BypassRefinementModule(256, 128, 128, act_type) 42 | self.up2 = DeConvBNAct(128, 128, act_type=act_type) 43 | self.refine2 = BypassRefinementModule(128, 128, 64, act_type=act_type) 44 | self.up3 = DeConvBNAct(64, 64, act_type=act_type) 45 | self.refine3 = BypassRefinementModule(64, 64, num_class, act_type=act_type) 46 | self.up4 = DeConvBNAct(num_class, num_class, act_type=act_type) 47 | 48 | def forward(self, x): 49 | x1 = self.conv(x) 50 | x = self.pool1(x1) 51 | x2 = self.fire1(x) 52 | x = self.pool2(x2) 53 | x3 = self.fire2(x) 54 | x = self.pool3(x3) 55 | x = self.fire3(x) 56 | x = self.pdc(x) 57 | x = self.up1(x) 58 | x = self.refine1(x3, x) 59 | x = self.up2(x) 60 | x = self.refine2(x2, x) 61 | x = self.up3(x) 62 | x = self.refine3(x1, x) 63 | x = self.up4(x) 64 | 65 | return x 66 | 67 | 68 | class FireModule(nn.Module): 69 | def __init__(self, in_channels, sq_channels, ex1_channels, ex3_channels, act_type): 70 | super().__init__() 71 | self.conv_squeeze = ConvBNAct(in_channels, sq_channels, 1, act_type=act_type) 72 | self.conv_expand1 = ConvBNAct(sq_channels, ex1_channels, 1, act_type=act_type) 73 | self.conv_expand3 = ConvBNAct(sq_channels, ex3_channels, 3, act_type=act_type) 74 | 75 | def forward(self, x): 76 | x = self.conv_squeeze(x) 77 | x1 = self.conv_expand1(x) 78 | x3 = self.conv_expand3(x) 79 | x = torch.cat([x1, x3], dim=1) 80 | 81 | return x 82 | 83 | 84 | class ParallelDilatedConv(nn.Module): 85 | def __init__(self, in_channels, out_channels, dilations, act_type): 86 | super().__init__() 87 | assert len(dilations) == 4, 'Length of dilations should be 4.\n' 88 | self.conv0 = ConvBNAct(in_channels, out_channels, 3, dilation=dilations[0], act_type=act_type) 89 | self.conv1 = ConvBNAct(in_channels, out_channels, 3, dilation=dilations[1], act_type=act_type) 90 | self.conv2 
= ConvBNAct(in_channels, out_channels, 3, dilation=dilations[2], act_type=act_type) 91 | self.conv3 = ConvBNAct(in_channels, out_channels, 3, dilation=dilations[3], act_type=act_type) 92 | 93 | def forward(self, x): 94 | x0 = self.conv0(x) 95 | x1 = self.conv1(x) 96 | x2 = self.conv2(x) 97 | x3 = self.conv3(x) 98 | x = x0 + x1 + x2 + x3 99 | 100 | return x 101 | 102 | 103 | class BypassRefinementModule(nn.Module): 104 | def __init__(self, low_channels, high_channels, out_channels, act_type): 105 | super().__init__() 106 | self.conv_low = ConvBNAct(low_channels, low_channels, 3, act_type=act_type) 107 | self.conv_cat = ConvBNAct(low_channels + high_channels, out_channels, 3, act_type=act_type) 108 | 109 | def forward(self, x_low, x_high): 110 | x_low = self.conv_low(x_low) 111 | x = torch.cat([x_low, x_high], dim=1) 112 | x = self.conv_cat(x) 113 | 114 | return x -------------------------------------------------------------------------------- /models/stdc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: Rethinking BiSeNet For Real-time Semantic Segmentation 3 | Url: https://arxiv.org/abs/2104.13188 4 | Create by: zh320 5 | Date: 2024/01/20 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .modules import conv1x1, ConvBNAct, SegHead 13 | from .bisenetv1 import AttentionRefinementModule, FeatureFusionModule 14 | from .model_registry import register_model, aux_models, detail_head_models 15 | 16 | 17 | @register_model(aux_models, detail_head_models) 18 | class STDC(nn.Module): 19 | def __init__(self, num_class=1, n_channel=3, encoder_type='stdc1', use_detail_head=False, use_aux=False, 20 | act_type='relu'): 21 | super().__init__() 22 | repeat_times_hub = {'stdc1': [1,1,1], 'stdc2': [3,4,2]} 23 | if encoder_type not in repeat_times_hub.keys(): 24 | raise ValueError('Unsupported encoder type.\n') 25 | repeat_times = repeat_times_hub[encoder_type] 26 | assert not use_detail_head * use_aux, 'Currently only support either aux-head or detail head.\n' 27 | self.use_detail_head = use_detail_head 28 | self.use_aux = use_aux 29 | 30 | self.stage1 = ConvBNAct(n_channel, 32, 3, 2) 31 | self.stage2 = ConvBNAct(32, 64, 3, 2) 32 | self.stage3 = self._make_stage(64, 256, repeat_times[0], act_type) 33 | self.stage4 = self._make_stage(256, 512, repeat_times[1], act_type) 34 | self.stage5 = self._make_stage(512, 1024, repeat_times[2], act_type) 35 | 36 | if use_aux: 37 | self.aux_head3 = SegHead(256, num_class, act_type) 38 | self.aux_head4 = SegHead(512, num_class, act_type) 39 | self.aux_head5 = SegHead(1024, num_class, act_type) 40 | 41 | self.pool = nn.AdaptiveAvgPool2d(1) 42 | self.arm4 = AttentionRefinementModule(512) 43 | self.arm5 = AttentionRefinementModule(1024) 44 | self.conv4 = conv1x1(512, 256) 45 | self.conv5 = conv1x1(1024, 256) 46 | 47 | self.ffm = FeatureFusionModule(256+256, 128, act_type) 48 | 49 | self.seg_head = SegHead(128, num_class, act_type) 50 | if use_detail_head: 51 | self.detail_head = SegHead(256, 1, act_type) 52 | self.detail_conv = conv1x1(3, 1) 53 | 54 | def _make_stage(self, in_channels, out_channels, repeat_times, act_type): 55 | layers = [STDCModule(in_channels, out_channels, 2, act_type)] 56 | 57 | for _ in range(repeat_times): 58 | layers.append(STDCModule(out_channels, out_channels, 1, act_type)) 59 | return nn.Sequential(*layers) 60 | 61 | def forward(self, x, is_training=False): 62 | size = x.size()[2:] 63 | 64 | x = self.stage1(x) 65 | x = self.stage2(x) 66 | x3 
= self.stage3(x) 67 | if self.use_aux: 68 | aux3 = self.aux_head3(x3) 69 | 70 | x4 = self.stage4(x3) 71 | if self.use_aux: 72 | aux4 = self.aux_head4(x4) 73 | 74 | x5 = self.stage5(x4) 75 | if self.use_aux: 76 | aux5 = self.aux_head5(x5) 77 | 78 | x5_pool = self.pool(x5) 79 | x5 = x5_pool + self.arm5(x5) 80 | x5 = self.conv5(x5) 81 | x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=True) 82 | 83 | x4 = self.arm4(x4) 84 | x4 = self.conv4(x4) 85 | x4 += x5 86 | x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=True) 87 | 88 | x = self.ffm(x4, x3) 89 | x = self.seg_head(x) 90 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 91 | 92 | if self.use_detail_head and is_training: 93 | x_detail = self.detail_head(x3) 94 | return x, x_detail 95 | elif self.use_aux and is_training: 96 | return x, (aux3, aux4, aux5) 97 | else: 98 | return x 99 | 100 | 101 | class STDCModule(nn.Module): 102 | def __init__(self, in_channels, out_channels, stride, act_type): 103 | super().__init__() 104 | if out_channels % 8 != 0: 105 | raise ValueError('Output channel should be evenly divided by 8.\n') 106 | if stride not in [1, 2]: 107 | raise ValueError(f'Unsupported stride: {stride}\n') 108 | 109 | self.stride = stride 110 | self.block1 = ConvBNAct(in_channels, out_channels//2, 1) 111 | self.block2 = ConvBNAct(out_channels//2, out_channels//4, 3, stride) 112 | if self.stride == 2: 113 | self.pool = nn.AvgPool2d(3, 2, 1) 114 | self.block3 = ConvBNAct(out_channels//4, out_channels//8, 3) 115 | self.block4 = ConvBNAct(out_channels//8, out_channels//8, 3) 116 | 117 | def forward(self, x): 118 | x1 = self.block1(x) 119 | x2 = self.block2(x1) 120 | if self.stride == 2: 121 | x1 = self.pool(x1) 122 | x3 = self.block3(x2) 123 | x4 = self.block4(x3) 124 | 125 | return torch.cat([x1, x2, x3, x4], dim=1) 126 | 127 | 128 | class LaplacianConv(nn.Module): 129 | def __init__(self, device): 130 | super().__init__() 131 | self.laplacian_kernel = torch.tensor([[[[-1.,-1.,-1.],[-1.,8.,-1.],[-1.,-1.,-1.]]]]).to(device) 132 | 133 | def forward(self, lbl): 134 | size = lbl.size()[2:] 135 | lbl_1x = F.conv2d(lbl, self.laplacian_kernel, stride=1, padding=1) 136 | lbl_2x = F.conv2d(lbl, self.laplacian_kernel, stride=2, padding=1) 137 | lbl_4x = F.conv2d(lbl, self.laplacian_kernel, stride=4, padding=1) 138 | 139 | lbl_2x = F.interpolate(lbl_2x, size, mode='nearest') 140 | lbl_4x = F.interpolate(lbl_4x, size, mode='nearest') 141 | 142 | lbl = torch.cat([lbl_1x, lbl_2x, lbl_4x], dim=1) 143 | 144 | return lbl -------------------------------------------------------------------------------- /models/swiftnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: In Defense of Pre-trained ImageNet Architectures for Real-time 3 | Semantic Segmentation of Road-driving Images 4 | Url: https://arxiv.org/abs/1903.08469 5 | Create by: zh320 6 | Date: 2023/10/22 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .modules import conv1x1, PWConvBNAct, ConvBNAct, PyramidPoolingModule 14 | from .backbone import ResNet, Mobilenetv2 15 | from .model_registry import register_model 16 | 17 | 18 | @register_model() 19 | class SwiftNet(nn.Module): 20 | def __init__(self, num_class=1, n_channel=3, backbone_type='resnet18', up_channels=128, 21 | act_type='relu'): 22 | super().__init__() 23 | if 'resnet' in backbone_type: 24 | self.backbone = ResNet(backbone_type) 25 | channels = [64, 128, 256, 512] if 
backbone_type in ['resnet18', 'resnet34'] else [256, 512, 1024, 2048] 26 | elif backbone_type == 'mobilenet_v2': 27 | self.backbone = Mobilenetv2() 28 | channels = [24, 32, 96, 320] 29 | else: 30 | raise NotImplementedError() 31 | 32 | self.connection1 = ConvBNAct(channels[0], up_channels, 1, act_type=act_type) 33 | self.connection2 = ConvBNAct(channels[1], up_channels, 1, act_type=act_type) 34 | self.connection3 = ConvBNAct(channels[2], up_channels, 1, act_type=act_type) 35 | self.spp = PyramidPoolingModule(channels[3], up_channels, act_type, bias=True) 36 | self.decoder = Decoder(up_channels, num_class, act_type) 37 | 38 | def forward(self, x): 39 | size = x.size()[2:] 40 | 41 | x1, x2, x3, x4 = self.backbone(x) 42 | 43 | x1 = self.connection1(x1) 44 | x2 = self.connection2(x2) 45 | x3 = self.connection3(x3) 46 | x4 = self.spp(x4) 47 | 48 | x = self.decoder(x4, x1, x2, x3) 49 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 50 | 51 | return x 52 | 53 | 54 | class Decoder(nn.Module): 55 | def __init__(self, channels, num_class, act_type): 56 | super().__init__() 57 | self.up_stage3 = ConvBNAct(channels, channels, 3, act_type=act_type) 58 | self.up_stage2 = ConvBNAct(channels, channels, 3, act_type=act_type) 59 | self.up_stage1 = ConvBNAct(channels, num_class, 3, act_type=act_type) 60 | 61 | def forward(self, x, x1, x2, x3): 62 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) 63 | x += x3 64 | x = self.up_stage3(x) 65 | 66 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) 67 | x += x2 68 | x = self.up_stage2(x) 69 | 70 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) 71 | x += x1 72 | x = self.up_stage1(x) 73 | 74 | return x -------------------------------------------------------------------------------- /optuna_results/bisenetv1.json: -------------------------------------------------------------------------------- 1 | { 2 | "params": { 3 | "optimizer": "adamw", 4 | "base_lr": 0.0011041009325283182, 5 | "use_ema": true, 6 | "scale_max": 0.5767783923237831, 7 | "scale_min": 0.7670720406133559, 8 | "brightness": 0.4613930720243222, 9 | "contrast": 0.45946637168836507, 10 | "saturation": 0.4113661588519595, 11 | "h_flip": 0.47701530410336107 12 | }, 13 | "value": 0.7440062165260315, 14 | "finished_trials": 100, 15 | "pruned_trials": 85, 16 | "completed_trials": 15 17 | } -------------------------------------------------------------------------------- /optuna_results/ddrnet.json: -------------------------------------------------------------------------------- 1 | { 2 | "params": { 3 | "optimizer": "adam", 4 | "base_lr": 0.002790064901626871, 5 | "aux_coef": 0.19763405764797368, 6 | "use_ema": true, 7 | "scale_max": 0.6496731805629183, 8 | "scale_min": 0.32619249404087225, 9 | "brightness": 0.4860585050496376, 10 | "contrast": 0.6266853104289152, 11 | "saturation": 0.5714174223652339, 12 | "h_flip": 0.11671740659004737 13 | }, 14 | "value": 0.722217857837677, 15 | "finished_trials": 100, 16 | "pruned_trials": 91, 17 | "completed_trials": 9 18 | } -------------------------------------------------------------------------------- /optuna_results/liteseg.json: -------------------------------------------------------------------------------- 1 | { 2 | "params": { 3 | "optimizer": "sgd", 4 | "base_lr": 0.012828976260490054, 5 | "use_ema": true, 6 | "scale_max": 0.943029447189605, 7 | "scale_min": 0.5842921221652486, 8 | "brightness": 0.26756860167160557, 9 | "contrast": 0.16430400780652107, 10 | 
"saturation": 0.5078665623692797, 11 | "h_flip": 0.4059527463082462 12 | }, 13 | "value": 0.754679262638092, 14 | "finished_trials": 100, 15 | "pruned_trials": 84, 16 | "completed_trials": 16 17 | } -------------------------------------------------------------------------------- /optuna_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import optuna 4 | from optuna.trial import TrialState 5 | from optuna.storages import RetryFailedTrialCallback 6 | import torch.distributed as dist 7 | from core import SegTrainer 8 | from configs.optuna_config import OptunaConfig 9 | 10 | import warnings 11 | warnings.filterwarnings("ignore") 12 | 13 | 14 | class OptunaTrainer(SegTrainer): 15 | def __init__(self, config, trial): 16 | super().__init__(config) 17 | self.trial = trial 18 | 19 | def validate(self, config, *args, **kwargs): 20 | val_score = super().validate(config) 21 | self.after_validate(val_score) 22 | return val_score 23 | 24 | def after_validate(self, val_score): 25 | self.trial.report(val_score, self.cur_epoch) 26 | 27 | if self.trial.should_prune(): 28 | raise optuna.exceptions.TrialPruned() 29 | 30 | 31 | if __name__ == '__main__': 32 | LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) 33 | DDP = LOCAL_RANK != -1 34 | MAIN_RANK = LOCAL_RANK in [-1, 0] 35 | 36 | if DDP: 37 | dist.init_process_group(backend=dist.Backend.NCCL, init_method='env://') 38 | 39 | config = OptunaConfig() 40 | STUDY_NAME = config.study_name 41 | STUDY_DIRECTION = config.study_direction 42 | NUM_TRIAL = config.num_trial 43 | SAVE_DIR = config.save_dir 44 | del config 45 | 46 | trial_scores = {} 47 | def objective(trial): 48 | trial = optuna.integration.TorchDistributedTrial(trial) if DDP else trial 49 | 50 | config = OptunaConfig() 51 | config.init_dependent_config() 52 | 53 | if MAIN_RANK: 54 | print(f"Running trial: {trial.number}...\n") 55 | if config.save_every_trial: 56 | config.save_dir = f'{SAVE_DIR}/trial_{trial.number}' 57 | 58 | config.get_trial_params(trial) 59 | trainer = OptunaTrainer(config, trial) 60 | best_score = trainer.run(config) 61 | 62 | trial_scores[trial.number] = best_score.item() 63 | with open(f'{SAVE_DIR}/trial_scores.json', 'w') as f: 64 | json.dump(trial_scores, f, indent=1) 65 | 66 | return best_score 67 | 68 | if MAIN_RANK: 69 | storage = optuna.storages.RDBStorage("sqlite:///optuna.db", heartbeat_interval=1, failed_trial_callback=RetryFailedTrialCallback(),) 70 | study = optuna.create_study(storage=storage, study_name=STUDY_NAME, direction=STUDY_DIRECTION, load_if_exists=True) 71 | 72 | print('Using Optuna to perform hyperparameter search.\n') 73 | study.optimize(objective, n_trials=NUM_TRIAL, gc_after_trial=True) 74 | 75 | best_trial = study.best_trial 76 | pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED]) 77 | complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE]) 78 | 79 | optuna_results = {'params':best_trial.params, 80 | 'value':best_trial.value, 81 | 'finished_trials': len(study.trials), 82 | 'pruned_trials': len(pruned_trials), 83 | 'completed_trials': len(complete_trials)} 84 | 85 | with open(f'{SAVE_DIR}/optuna_results.json', 'w') as f: 86 | json.dump(optuna_results, f, indent=1) 87 | 88 | else: 89 | for _ in range(NUM_TRIAL): 90 | try: 91 | objective(None) 92 | except optuna.TrialPruned: 93 | pass -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | addict==2.4.0 3 | albumentations==1.3.0 4 | auditwheel==5.4.0 5 | cachetools==4.2.4 6 | certifi==2024.8.30 7 | charset-normalizer==3.3.2 8 | contourpy==1.1.0 9 | cycler==0.11.0 10 | distro==1.8.0 11 | efficientnet-pytorch==0.7.1 12 | filelock==3.16.0 13 | fonttools==4.41.1 14 | fsspec==2024.9.0 15 | google-auth==2.34.0 16 | google-auth-oauthlib==0.4.6 17 | grpcio==1.66.1 18 | huggingface-hub==0.24.6 19 | idna==3.8 20 | imageio==2.35.1 21 | importlib_metadata==8.4.0 22 | importlib-resources==6.0.0 23 | joblib==1.4.2 24 | kiwisolver==1.4.4 25 | lazy_loader==0.4 26 | lightning-utilities==0.11.7 27 | loguru==0.5.3 28 | Markdown==3.7 29 | MarkupSafe==2.1.5 30 | munch==4.0.0 31 | networkx==3.1 32 | numpy==1.24.4 33 | oauthlib==3.2.2 34 | opencv-python-headless==4.10.0.84 35 | packaging==24.1 36 | pathspec==0.9.0 37 | pillow==10.4.0 38 | pip==24.2 39 | pretrainedmodels==0.7.4 40 | protobuf==3.20.0 41 | pyasn1==0.4.8 42 | pyasn1-modules==0.2.8 43 | pycocotools==2.0.6 44 | pyelftools==0.29 45 | pyparsing==3.0.9 46 | python-dateutil==2.8.2 47 | PyWavelets==1.4.1 48 | PyYAML==6.0.2 49 | qudida==0.0.4 50 | requests==2.32.3 51 | requests-oauthlib==1.3.0 52 | rsa==4.7.2 53 | scikit-image==0.21.0 54 | scikit-learn==1.3.2 55 | scipy==1.10.1 56 | segmentation-models-pytorch==0.3.2 57 | setuptools==47.3.1 58 | six==1.16.0 59 | tb-nightly==2.8.0a20211117 60 | tensorboard-data-server==0.6.1 61 | tensorboard-plugin-wit==1.8.0 62 | threadpoolctl==3.5.0 63 | tifffile==2023.7.10 64 | timm==0.6.12 65 | torch==1.8.1+cu111 66 | torchaudio==0.8.1 67 | torchmetrics==1.2.0 68 | torchvision==0.9.1+cu111 69 | tqdm==4.66.5 70 | typing_extensions==4.12.2 71 | urllib3==2.2.2 72 | Werkzeug==3.0.4 73 | wheel==0.43.0 74 | zipp==3.16.2 75 | -------------------------------------------------------------------------------- /tools/export.py: -------------------------------------------------------------------------------- 1 | import os, sys, torch 2 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) ) 3 | 4 | from configs import MyConfig, load_parser 5 | from models import get_model 6 | 7 | import warnings 8 | warnings.filterwarnings("ignore") 9 | 10 | 11 | class Exporter: 12 | def __init__(self, config): 13 | config.use_aux = False 14 | config.use_detail_head = False 15 | 16 | self.load_ckpt_path = config.load_ckpt_path 17 | self.export_format = config.export_format 18 | self.export_size = config.export_size 19 | self.onnx_opset = config.onnx_opset 20 | self.export_path = config.export_name + f'.{config.export_format}' 21 | self.config = config 22 | 23 | self.model = get_model(config) 24 | self.load_ckpt() 25 | 26 | def load_ckpt(self): 27 | if not self.load_ckpt_path: # when set to None 28 | pass 29 | elif os.path.isfile(self.load_ckpt_path): 30 | checkpoint = torch.load(self.load_ckpt_path, map_location=torch.device('cpu')) 31 | self.model.load_state_dict(checkpoint['state_dict'], strict=False) 32 | self.model.eval() 33 | 34 | print(f'Loading checkpoint: {self.load_ckpt_path} successfully.\n') 35 | del checkpoint 36 | else: 37 | raise RuntimeError 38 | 39 | def export(self): 40 | print('\n=========Export=========') 41 | print(f'Model: {self.config.model}\nEncoder: {self.config.encoder}\nDecoder: {self.config.decoder}') 42 | print(f'Export Size (H, W): {self.export_size}') 43 | print(f'Export Format: {self.export_format}') 44 | 45 | if self.export_format == 'onnx': 46 | from models.modules import 
replace_adaptive_avg_pool
47 |             self.model = replace_adaptive_avg_pool(self.model)
48 | 
49 |             self.export_onnx()
50 |             print('\nExporting Finished.\n')
51 | 
52 |         else:
53 |             raise NotImplementedError
54 | 
55 |     def export_onnx(self, image=None):
56 |         # `if not image` raises on tensors ("Boolean value of Tensor ... is ambiguous"),
57 |         # so test for None explicitly.
58 |         image = torch.rand(1, 3, *self.export_size) if image is None else image
59 |         torch.onnx.export(self.model, image, self.export_path, opset_version=self.onnx_opset,
60 |                           input_names=['input'], output_names=['output'])
61 | 
62 | 
63 | if __name__ == '__main__':
64 |     config = MyConfig()
65 |     config = load_parser(config)
66 |     config.load_ckpt_path = None    # None if you do not have a ckpt to load
67 |     config.init_dependent_config()
68 | 
69 |     try:
70 |         exporter = Exporter(config)
71 |         exporter.export()
72 |     except Exception as e:
73 |         print(f'\nUnable to export PyTorch model {config.model} to {config.export_format} due to: {e}')
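Editor's note — after a successful export it is worth sanity-checking the ONNX graph against the PyTorch model. The snippet below is a hedged sketch, not part of tools/export.py: it assumes onnxruntime is installed (it is already used by tools/test_speed.py below), that `exporter` was built as in the `__main__` block above, and that the chosen tolerances suit your model.

import numpy as np
import onnxruntime as ort
import torch

# Hypothetical post-export check: compare ONNX Runtime output with PyTorch.
session = ort.InferenceSession(exporter.export_path, providers=['CPUExecutionProvider'])
dummy = torch.rand(1, 3, *exporter.export_size)

with torch.no_grad():
    torch_out = exporter.model(dummy).numpy()
onnx_out = session.run(None, {session.get_inputs()[0].name: dummy.numpy()})[0]

# rtol/atol are assumptions; tighten or loosen as needed.
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)
print('ONNX output matches PyTorch within tolerance.')

--------------------------------------------------------------------------------
/tools/get_model_infos.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from os import path
3 | sys.path.append( path.dirname( path.dirname( path.abspath(__file__) ) ) )
4 | 
5 | from configs import MyConfig, load_parser
6 | from models import get_model
7 | 
8 | 
9 | def cal_model_params(config, imgw=1024, imgh=512):
10 |     model = get_model(config)
11 |     print(f'\nModel: {config.model}\nEncoder: {config.encoder}\nDecoder: {config.decoder}')
12 | 
13 |     try:
14 |         from ptflops import get_model_complexity_info
15 |         model.eval()
16 |         '''
17 |         Notice that ptflops doesn't take into account torch.nn.functional.* operations.
18 |         If you want to get correct macs result, you need to modify the modules like
19 |         torch.nn.functional.interpolate to torch.nn.Upsample.
20 |         '''
21 |         _, params = get_model_complexity_info(model, (3, imgh, imgw), as_strings=True,
22 |                                               print_per_layer_stat=False, verbose=False)
23 |         print(f'Number of parameters: {params}\n')
24 |     except Exception:
25 |         import numpy as np
26 |         params = np.sum([p.numel() for p in model.parameters()])
27 |         print(f'Number of parameters: {params / 1e6:.2f}M\n')
28 | 
29 | 
30 | if __name__ == '__main__':
31 |     config = MyConfig()
32 |     config = load_parser(config)
33 | 
34 |     config.use_aux = False
35 |     config.use_detail_head = False
36 | 
37 |     cal_model_params(config)
--------------------------------------------------------------------------------
/tools/test_speed.py:
--------------------------------------------------------------------------------
1 | import sys, time, torch
2 | from os import path
3 | sys.path.append( path.dirname( path.dirname( path.abspath(__file__) ) ) )
4 | 
5 | from configs import MyConfig, load_parser
6 | from models import get_model, model_hub
7 | 
8 | 
9 | def test_model_speed(config, mode='cuda', ratio=0.5, imgw=2048, imgh=1024, iterations=None):
10 |     if mode == 'cuda':
11 |         # Codes are based on
12 |         # https://github.com/ydhongHIT/DDRNet/blob/main/segmentation/DDRNet_23_slim_eval_speed.py
13 | 
14 |         if ratio != 1.0:
15 |             assert ratio > 0, 'Ratio should be larger than 0.\n'
16 |             imgw = int(imgw * ratio)
17 |             imgh = int(imgh * ratio)
18 | 
19 |         device = torch.device('cuda')
20 |         # torch.backends.cudnn.enabled = True
21 |         # torch.backends.cudnn.benchmark = True
22 | 
23 |         model = get_model(config)
24 |         model.eval()
25 |         model.to(device)
26 |         print('\n=========Speed Testing=========')
27 |         print(f'Model: {config.model}\nEncoder: {config.encoder}\nDecoder: {config.decoder}')
28 |         print(f'Size (W, H): {imgw}, {imgh}')
29 | 
30 |         input = torch.randn(1, 3, imgh,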
imgw).cuda() 31 | with torch.no_grad(): 32 | for _ in range(10): 33 | model(input) 34 | 35 | if iterations is None: 36 | elapsed_time = 0 37 | iterations = 100 38 | while elapsed_time < 1: 39 | torch.cuda.synchronize() 40 | torch.cuda.synchronize() 41 | t_start = time.time() 42 | for _ in range(iterations): 43 | model(input) 44 | torch.cuda.synchronize() 45 | torch.cuda.synchronize() 46 | elapsed_time = time.time() - t_start 47 | iterations *= 2 48 | FPS = iterations / elapsed_time 49 | iterations = int(FPS * 6) 50 | 51 | torch.cuda.synchronize() 52 | torch.cuda.synchronize() 53 | t_start = time.time() 54 | for _ in range(iterations): 55 | model(input) 56 | torch.cuda.synchronize() 57 | torch.cuda.synchronize() 58 | elapsed_time = time.time() - t_start 59 | latency = elapsed_time / iterations * 1000 60 | torch.cuda.empty_cache() 61 | FPS = 1000 / latency 62 | 63 | elif mode == 'cpu': 64 | import numpy as np 65 | import onnxruntime as ort 66 | from tools.export import Exporter 67 | 68 | try: 69 | config.export_name = f'{config.model}_dummy' 70 | exporter = Exporter(config) 71 | exporter.export() 72 | except Exception as e: 73 | print(f'\nUnable to export PyTorch model {config.model} to ONNX due to: {e}') 74 | return -1 75 | 76 | load_onnx_path = f'{config.model}_dummy.onnx' if not config.load_onnx_path else config.load_onnx_path 77 | 78 | print('\n=========Speed Testing=========') 79 | print(f'Model: {config.model}\nEncoder: {config.encoder}\nDecoder: {config.decoder}') 80 | print(f'Size (H, W): {config.export_size}') 81 | 82 | session = ort.InferenceSession(load_onnx_path, providers=["CPUExecutionProvider"]) 83 | input_name = session.get_inputs()[0].name 84 | input_shape = session.get_inputs()[0].shape 85 | 86 | dummy_input = np.random.randn(*input_shape).astype(np.float32) 87 | 88 | print('\nRunning CPU warmup...') 89 | for _ in range(10): 90 | session.run(None, {input_name: dummy_input}) 91 | 92 | num_iterations = iterations if iterations else 100 93 | print('Start speed testing on CPU using ONNX runtime...') 94 | start_time = time.time() 95 | for _ in range(num_iterations): 96 | session.run(None, {input_name: dummy_input}) 97 | 98 | end_time = time.time() 99 | FPS = num_iterations / (end_time - start_time) 100 | 101 | else: 102 | raise NotImplementedError 103 | 104 | print(f'FPS: {FPS}\n') 105 | return FPS 106 | 107 | 108 | if __name__ == '__main__': 109 | mode = 'cpu' 110 | test_all_model = False 111 | 112 | config = MyConfig() 113 | config = load_parser(config) 114 | config.use_aux = False 115 | config.use_detail_head = False 116 | config.load_ckpt_path = None # None if you do not have a ckpt to load and export to ONNX 117 | config.init_dependent_config() 118 | 119 | with open(f'{mode}_perf.txt', 'w') as f: 120 | f.write('model\t\tFPS\n') 121 | 122 | if test_all_model: 123 | for model_name in sorted(model_hub.keys()): 124 | config.model = model_name 125 | 126 | fps = test_model_speed(config, mode=mode) 127 | with open(f'{mode}_perf.txt', 'a+') as f: 128 | f.write(f'{config.model}\t\t{fps:.2f}\n') 129 | 130 | elif config.model in model_hub.keys(): 131 | fps = test_model_speed(config, mode=mode) 132 | with open(f'{mode}_perf.txt', 'a+') as f: 133 | f.write(f'{config.model}\t\t{fps:.2f}\n') 134 | 135 | else: 136 | raise ValueError(f'Unsupported model: {config.model}\n') -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from 
.parallel import * 3 | from .transforms import * 4 | from .optimizer import get_optimizer 5 | from .scheduler import get_scheduler 6 | from .metrics import get_seg_metrics 7 | from .model_ema import get_ema_model -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from torchmetrics import JaccardIndex 2 | 3 | 4 | def get_seg_metrics(config, task='multiclass', reduction='none'): 5 | metrics = JaccardIndex(task=task, num_classes=config.num_class, 6 | ignore_index=config.ignore_index, average=reduction,) 7 | return metrics -------------------------------------------------------------------------------- /utils/model_ema.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Codes are based on 3 | https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/model_ema.py 4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | from copy import deepcopy 9 | from .parallel import de_parallel 10 | 11 | 12 | def get_ema_model(config, model, device): 13 | return ModelEmaV2(config, model, device=device) 14 | 15 | 16 | class ModelEmaV2(nn.Module): 17 | def __init__(self, config, model, device=None): 18 | super().__init__() 19 | # make a copy of the model for accumulating moving average of weights 20 | self.ema = deepcopy(de_parallel(model)) 21 | self.ema.eval() 22 | self.device = device # perform ema on different device from model if set 23 | if self.device is not None: 24 | self.ema.to(device=device) 25 | self.use_ema = config.use_ema 26 | self.total_itrs = config.total_itrs 27 | 28 | @torch.no_grad() 29 | def _update(self, model, update_fn): 30 | for ema_v, model_v in zip(self.ema.state_dict().values(), model.state_dict().values()): 31 | if self.device is not None: 32 | model_v = model_v.to(device=self.device) 33 | ema_v.copy_(update_fn(ema_v, model_v)) 34 | 35 | def update(self, model, cur_itrs): 36 | if self.use_ema: 37 | decay = min(max(cur_itrs / self.total_itrs, 0), 1) 38 | self._update(de_parallel(model), update_fn=lambda e, m: decay * e + (1. 
- decay) * m) 39 | else: 40 | self._update(de_parallel(model), update_fn=lambda e, m: m) -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import SGD, Adam, AdamW 2 | 3 | 4 | def get_optimizer(config, model): 5 | optimizer_hub = {'sgd':SGD, 'adam':Adam, 'adamw':AdamW} 6 | params = model.parameters() 7 | 8 | if config.optimizer_type == 'sgd': 9 | config.lr = config.base_lr * config.gpu_num 10 | optimizer = optimizer_hub[config.optimizer_type](params=params, lr=config.lr, 11 | momentum=config.momentum, 12 | weight_decay=config.weight_decay) 13 | 14 | elif config.optimizer_type in ['adam', 'adamw']: 15 | config.lr = 0.1 * config.base_lr * config.gpu_num 16 | optimizer = optimizer_hub[config.optimizer_type](params=params, lr=config.lr) 17 | 18 | else: 19 | raise NotImplementedError(f'Unsupported optimizer type: {config.optimizer_type}') 20 | 21 | return optimizer -------------------------------------------------------------------------------- /utils/parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributed as dist 4 | from torch.nn.parallel import DistributedDataParallel as DDP 5 | 6 | 7 | def is_parallel(model): 8 | # Returns True if model is of type DP or DDP 9 | return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) 10 | 11 | 12 | def de_parallel(model): 13 | # De-parallelize a model: returns single-GPU model if model is of type DP or DDP 14 | return model.module if is_parallel(model) else model 15 | 16 | 17 | def set_device(config, rank): 18 | if config.DDP: 19 | torch.cuda.set_device(rank) 20 | if not dist.is_initialized(): 21 | dist.init_process_group(backend=dist.Backend.NCCL, init_method='env://') 22 | device = torch.device('cuda', rank) 23 | config.gpu_num = dist.get_world_size() 24 | else: # DP 25 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | config.gpu_num = torch.cuda.device_count() 27 | config.train_bs *= config.gpu_num 28 | 29 | # Setup num_workers 30 | config.num_workers = config.gpu_num * config.base_workers 31 | 32 | return device 33 | 34 | 35 | def parallel_model(config, model, rank, device): 36 | if config.DDP: 37 | if config.synBN: 38 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 39 | model = DDP(model.to(rank), device_ids=[rank], output_device=rank) 40 | else: 41 | model = nn.DataParallel(model) 42 | model.to(device) 43 | 44 | return model 45 | 46 | 47 | def destroy_ddp_process(config): 48 | if config.DDP and config.destroy_ddp_process: 49 | dist.destroy_process_group() 50 | 51 | 52 | def sampler_set_epoch(config, loader, cur_epochs): 53 | if config.DDP: 54 | loader.sampler.set_epoch(cur_epochs) -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import OneCycleLR, StepLR 2 | from math import ceil 3 | 4 | 5 | def get_scheduler(config, optimizer): 6 | if config.DDP: 7 | config.iters_per_epoch = ceil(config.train_num/config.train_bs/config.gpu_num) 8 | else: 9 | config.iters_per_epoch = ceil(config.train_num/config.train_bs) 10 | config.total_itrs = int(config.total_epoch*config.iters_per_epoch) 11 | 12 | if config.lr_policy == 'cos_warmup': 13 | warmup_ratio = config.warmup_epochs / 
config.total_epoch 14 | scheduler = OneCycleLR(optimizer, max_lr=config.lr, total_steps=config.total_itrs, 15 | pct_start=warmup_ratio) 16 | 17 | elif config.lr_policy == 'linear': 18 | scheduler = OneCycleLR(optimizer, max_lr=config.lr, total_steps=config.total_itrs, 19 | pct_start=0., anneal_strategy='linear') 20 | 21 | elif config.lr_policy == 'step': 22 | scheduler = StepLR(optimizer, step_size=config.step_size, gamma=0.1) 23 | 24 | else: 25 | raise NotImplementedError(f'Unsupported scheduler type: {config.lr_policy}') 26 | return scheduler -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import albumentations as AT 3 | 4 | 5 | def to_numpy(array): 6 | if not isinstance(array, np.ndarray): 7 | array = np.asarray(array) 8 | return array 9 | 10 | 11 | class Scale: 12 | def __init__(self, scale, interpolation=1, p=1, is_testing=False): 13 | self.scale = scale 14 | self.interpolation = interpolation 15 | self.p = p 16 | self.is_testing = is_testing 17 | 18 | def __call__(self, image, mask=None): 19 | img = to_numpy(image) 20 | if not self.is_testing: 21 | msk = to_numpy(mask) 22 | 23 | imgh, imgw, _ = img.shape 24 | new_imgh, new_imgw = int(imgh * self.scale), int(imgw * self.scale) 25 | 26 | aug = AT.Resize(height=new_imgh, width=new_imgw, interpolation=self.interpolation, p=self.p) 27 | 28 | if self.is_testing: 29 | augmented = aug(image=img) 30 | else: 31 | augmented = aug(image=img, mask=msk) 32 | return augmented -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os, random, torch, json 2 | import numpy as np 3 | 4 | 5 | def mkdir(path): 6 | if not os.path.exists(path): 7 | os.mkdir(path) 8 | 9 | 10 | def set_seed(seed): 11 | random.seed(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | 16 | 17 | def get_writer(config, main_rank): 18 | if config.use_tb and main_rank: 19 | from torch.utils.tensorboard import SummaryWriter 20 | writer = SummaryWriter(config.tb_log_dir) 21 | else: 22 | writer = None 23 | return writer 24 | 25 | 26 | def get_logger(config, main_rank): 27 | if main_rank: 28 | import sys 29 | from loguru import logger 30 | logger.remove() 31 | logger.add(sys.stderr, format="[{time:YYYY-MM-DD HH:mm}] {message}", level="INFO") 32 | 33 | log_path = f'{config.save_dir}/{config.logger_name}.log' 34 | logger.add(log_path, format="[{time:YYYY-MM-DD HH:mm}] {message}", level="INFO") 35 | else: 36 | logger = None 37 | return logger 38 | 39 | 40 | def save_config(config): 41 | config_dict = vars(config) 42 | with open(f'{config.save_dir}/config.json', 'w') as f: 43 | json.dump(config_dict, f, indent=4) 44 | 45 | 46 | def log_config(config, logger): 47 | keys = ['task', 'dataset', 'num_class', 'model', 'encoder', 'decoder', 'loss_type', 48 | 'optimizer_type', 'lr_policy', 'total_epoch', 'train_bs', 'val_bs', 49 | 'train_num', 'val_num', 'gpu_num', 'num_workers', 'amp_training', 50 | 'DDP', 'kd_training', 'synBN', 'use_ema', 'use_aux'] 51 | 52 | config_dict = vars(config) 53 | infos = f"\n\n\n{'#'*25} Config Informations {'#'*25}\n" 54 | infos += '\n'.join('%s: %s' % (k, config_dict[k]) for k in keys) 55 | infos += f"\n{'#'*71}\n\n" 56 | logger.info(infos) 57 | 58 | 59 | def get_colormap(config): 60 | if config.colormap == 'cityscapes': 61 | 
colormap = {0:(128, 64,128), 1:(244, 35,232), 2:( 70, 70, 70), 3:(102,102,156),
62 |                     4:(190,153,153), 5:(153,153,153), 6:(250,170, 30), 7:(220,220, 0),
63 |                     8:(107,142, 35), 9:(152,251,152), 10:( 70,130,180), 11:(220, 20, 60),
64 |                     12:(255, 0, 0), 13:( 0, 0,142), 14:( 0, 0, 70), 15:( 0, 60,100),
65 |                     16:( 0, 80,100), 17:( 0, 0,230), 18:(119, 11, 32)}
66 | 
67 |     elif config.colormap == 'custom':
68 |         raise NotImplementedError()
69 | 
70 |     else:
71 |         raise ValueError(f'Unsupported colormap type: {config.colormap}.')
72 | 
73 |     colormap = list(colormap.values())
74 | 
75 |     if len(colormap) < config.num_class:
76 |         raise ValueError('Length of colormap is smaller than the number of classes.')
77 |     else:
78 |         return colormap[:config.num_class]
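Editor's note — to close the loop, here is a small hedged sketch of how the palette returned by get_colormap can be applied when saving predictions. It is illustrative only: the repo's actual mask saving and blending happen in the trainer, and `pred`, `image`, and `config` are assumed inputs.

import numpy as np

# Hypothetical usage: paint an (H, W) array of predicted class ids with the
# palette above. NumPy fancy indexing maps each class id to its RGB triple.
palette = np.asarray(get_colormap(config), dtype=np.uint8)   # (num_class, 3)
color_mask = palette[pred]                                   # (H, W, 3) RGB image

# Optional overlay on the input image, mirroring config.blend_alpha = 0.3:
# blended = (0.3 * image + 0.7 * color_mask).astype(np.uint8)

--------------------------------------------------------------------------------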