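├── README.md ├── augmentations.py ├── cityscapes_loader.py ├── config ├── icnet-cityscapes.yml └── spnet-cityscapes.yml ├── demo.py ├── inputs ├── aachen_000001_000019_leftImg8bit.png └── frankfurt_000000_000294_leftImg8bit.png ├── loss.py ├── metrics.py ├── models ├── icnet.py ├── spnet.py └── trl.py ├── schedulers.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Scene Parsing Network 3 | ## 1. Project Overview 4 | This project parses the driving environment in front of a vehicle in real time by performing semantic segmentation and depth estimation on images and videos from a dashcam. The intended result is shown below: 5 | 6 | ![](https://raw.githubusercontent.com/EEEGUI/ImageBed/master/img/fig_045.png) 7 | 8 | ## 2. Models 9 | While building real-time semantic segmentation and depth estimation for driving scenes, this project tried several models: a reproduction of TRL, the network proposed in [this paper](http://openaccess.thecvf.com/content_ECCV_2018/papers/Zhenyu_Zhang_Joint_Task-Recursive_Learning_ECCV_2018_paper.pdf); the semantic segmentation network [ICNet](https://arxiv.org/pdf/1704.08545.pdf) extended with a depth branch; and finally a lightweight model of my own design. 10 | 11 | ### TRL 12 | TRL ([Joint Task-Recursive Learning for Semantic Segmentation and Depth Estimation](http://openaccess.thecvf.com/content_ECCV_2018/papers/Zhenyu_Zhang_Joint_Task-Recursive_Learning_ECCV_2018_paper.pdf)) is an ECCV 2018 network that performs semantic segmentation and depth estimation jointly. Its architecture is shown below: 13 | 14 | ![](https://raw.githubusercontent.com/EEEGUI/ImageBed/master/img/fig_046.png) 15 | 16 | Overall, the TRL network is an encoder-decoder. A ResNet turns the input RGB image into feature maps at several scales, which are then fed to the decoder to produce semantic and depth predictions. The decoder contains 4 semantic-prediction branches and 4 depth-estimation branches, which alternate. Each branch combines the semantic and depth features extracted by the preceding branches: because semantics and depth are correlated, fusing the two kinds of features improves accuracy. 17 | 18 | After reproducing the paper, however, I found that the network had as many as (150)341M parameters. In several places the original network applies multi-scale convolutions ($1\times1$, $3\times3$, $5\times5$, $7\times7$) to feature maps with 2048 channels. Since the parameter count and computational cost of a convolution scale with the kernel area ($k^2$ for a $k\times k$ kernel) and with the numbers of input and output channels, the large $5\times5$ and $7\times7$ kernels increase both considerably. I therefore first reduce the channel dimension with a $1\times1$ convolution and then replace the $5\times5$ and $7\times7$ convolutions with $3\times3$ dilated convolutions, which reduces the parameter count and also speeds up computation. 19 | 20 | ### ICNet 21 | ICNet is a semantic segmentation network derived from PSPNet and designed for speed. It has three branches that differ in network depth and feature-map resolution: semantic information is extracted thoroughly on the small feature maps and then fused with the features of the high-resolution branch, which supply detail. This project adds a depth-estimation branch in parallel with ICNet's semantic prediction module. 22 | 23 | ![](https://raw.githubusercontent.com/EEEGUI/ImageBed/master/img/fig_047.png) 24 | >Network architecture 25 | 26 |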
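![](https://raw.githubusercontent.com/EEEGUI/ImageBed/master/img/2019-06-21_15-26-22.png) 27 | 28 | >Hand-drawn diagram of the network details 29 | 30 | ### SPNet 31 | The SPNet architecture is shown below: 32 | 33 | ![](https://raw.githubusercontent.com/EEEGUI/ImageBed/master/img/fig_048.png) 34 | 35 | Overall, this network is also an encoder-decoder. 36 | 37 | The encoder consists of downsampling units and modified residual units. 38 | 39 | ![](https://raw.githubusercontent.com/EEEGUI/ImageBed/master/img/fig_050.png) 40 | 41 | The residual unit is modified in three ways: 42 | - First, the input is split in half along the channel dimension and each half enters its own convolution branch, so every convolution in a branch has only half the input channels $N_{in}$ and half the output channels $N_{out}$. 43 | - Second, each $3\times3$ convolution in the branches is factorized into a $3\times1$ and a $1\times3$ convolution, which shrinks the kernels (6 weights per input/output channel pair instead of 9). 44 | - Finally, the outputs of the two branches are concatenated to restore the channel count and added directly to the original input, preserving the residual structure. Because the split prevents features from being combined across the two branches, a channel shuffle is added at the end of the unit to reorder the channels so that information is exchanged between them. 45 | 46 | The decoder has two parts. The first part is the two central branches, which capture what the semantic and depth information have in common: a Multi-scale Convolution Module branch and a plain convolution branch. Both output $C+1$ channels, of which $C$ are semantic channels and $1$ is the depth channel. The second part is the two side branches, which capture the information specific to semantics and to depth, respectively. The multi-scale convolution module is shown below: 47 | 48 | ![](https://raw.githubusercontent.com/EEEGUI/ImageBed/master/img/fig_051.png) 49 | 50 | The model's results are shown in the figure at the top of this README. 51 | 52 | ## 3. Usage 53 | 54 | ### Training 55 | - **Environment**: the environment I used is 56 | - Ubuntu 16.04 57 | - PyTorch 1.0 58 | - CUDA 10 59 | - GPU: 2080 Ti 60 | 61 | - **Data preparation**: download the dataset from [Cityscapes](https://www.cityscapes-dataset.com/). The depth (disparity) data must be requested separately by email and cannot be downloaded directly. 62 | 63 | - **Configuration**: in `config/spnet-cityscapes.yml`, change the dataset location to the path of your own copy of the dataset. 64 | 65 | - Run `train.py`. 66 | 67 | ### Testing 68 | - In `config/spnet-cityscapes.yml`, set the location of the saved model in the `testing` section. That file currently has no `testing` section; `demo.py` reads `model_path`, `img_fold`, `output_fold`, `img_rows`, `img_cols` and `pred_depth` from it (the `testing` section of `config/icnet-cityscapes.yml` shows the layout). 69 | - Put the images in the `inputs` folder. 70 | - Run `demo.py`. 71 | 72 | -------------------------------------------------------------------------------- /augmentations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision.transforms import functional as F 3 | 4 | 5 | class Rescale(object): 6 | def __init__(self, output_size): 7 | assert isinstance(output_size, (int, tuple)) 8 | self.output_size = output_size 9 | 10 | def __call__(self, sample): 11 | image, label, depth = sample['image'], sample['label'], sample['depth']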
12 | 13 | w, h = image.size # PIL reports (width, height) 14 | if isinstance(self.output_size, int): 15 | if h > w: 16 | new_h, new_w = self.output_size * h / w, self.output_size 17 | else: 18 | new_h, new_w = self.output_size, self.output_size * w / h 19 | else: 20 | new_h, new_w = self.output_size 21 | 22 | new_h, new_w = int(new_h), int(new_w) 23 | 24 | image = F.resize(image, (new_h, new_w)) 25 | label = F.resize(label, (new_h, new_w), interpolation=0) # 0 == PIL.Image.NEAREST: keeps class ids intact 26 | depth = F.resize(depth, (new_h, new_w), interpolation=0) # nearest: avoid blending valid and invalid (0) disparities 27 | 28 | return {'image': image, 'label': label, 'depth': depth} 29 | 30 | 31 | class RandomHorizonFlip(object): 32 | def __init__(self, probability): 33 | self.probability = probability 34 | 35 | def __call__(self, sample): 36 | image, label, depth = sample['image'], sample['label'], sample['depth'] 37 | 38 | p = np.random.random() 39 | if p < self.probability: 40 | 41 | image = F.hflip(image) 42 | label = F.hflip(label) 43 | depth = F.hflip(depth) 44 | 45 | return {'image': image, 'label': label, 'depth': depth} 46 | 47 | 48 | class RandomRotate(object): 49 | def __init__(self, angle): 50 | self.angle = angle 51 | 52 | def __call__(self, sample): 53 | image, label, depth = sample['image'], sample['label'], sample['depth'] 54 | 55 | angle = np.random.randint(-self.angle, self.angle + 1) # symmetric range [-angle, angle] degrees 56 | 57 | image = F.rotate(image, angle) 58 | label = F.rotate(label, angle) 59 | depth = F.rotate(depth, angle) 60 | 61 | return {'image': image, 'label': label, 'depth': depth} 62 | 63 | 64 | class RandomCrop(object): 65 | def __init__(self, input_size): 66 | self.img_size = input_size 67 | 68 | def __call__(self, sample): 69 | image, label, depth = sample['image'], sample['label'], sample['depth'] 70 | h = np.random.randint(self.img_size[0], 1024) # Cityscapes frames are 1024 x 2048 (h x w) 71 | w = np.random.randint(self.img_size[1], 2048) 72 | i = np.random.randint(0, 1024 - h) 73 | j = np.random.randint(0, 2048 - w) 74 | 75 | image = F.crop(image, i, j, h, w) 76 | label = F.crop(label, i, j, h, w) 77 | depth = F.crop(depth, i, j, h, w) 78 | 79 | return {'image': image, 'label':
label, 'depth': depth} -------------------------------------------------------------------------------- /cityscapes_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import scipy.misc as m 5 | from PIL import Image 6 | 7 | from torch.utils import data 8 | from utils import recursive_glob 9 | from torchvision.transforms import Compose 10 | from torchvision.transforms import functional as F 11 | from augmentations import RandomHorizonFlip, RandomRotate, RandomCrop 12 | 13 | 14 | class cityscapesLoader(data.Dataset): 15 | 16 | colors = [ # [ 0, 0, 0], 17 | [128, 64, 128], 18 | [244, 35, 232], 19 | [70, 70, 70], 20 | [102, 102, 156], 21 | [190, 153, 153], 22 | [153, 153, 153], 23 | [250, 170, 30], 24 | [220, 220, 0], 25 | [107, 142, 35], 26 | [152, 251, 152], 27 | [0, 130, 180], 28 | [220, 20, 60], 29 | [255, 0, 0], 30 | [0, 0, 142], 31 | [0, 0, 70], 32 | [0, 60, 100], 33 | [0, 80, 100], 34 | [0, 0, 230], 35 | [119, 11, 32], 36 | ] 37 | 38 | label_colours = dict(zip(range(19), colors)) 39 | 40 | def __init__( 41 | self, 42 | root, 43 | split="train", 44 | is_transform=False, 45 | img_size=(512, 1024), 46 | max_depth=250, 47 | augmentations=None, 48 | img_norm=True, 49 | mean=[0, 0, 0], 50 | test_mode=False, 51 | ): 52 | """__init__ 53 | :param root: 54 | :param split: 55 | :param is_transform: 56 | :param img_size: 57 | :param augmentations 58 | """ 59 | self.root = root 60 | self.split = split 61 | self.is_transform = is_transform 62 | self.augmentations = augmentations 63 | self.img_norm = img_norm 64 | self.n_classes = 19 65 | self.img_size = img_size 66 | self.max_depth = max_depth 67 | self.mean = np.array(mean) 68 | self.files = {} 69 | 70 | self.images_base = os.path.join(self.root, "leftImg8bit", self.split) 71 | self.annotations_base = os.path.join(self.root, "gtFine", self.split) 72 | self.disparity_base = os.path.join(self.root, "disparity", self.split) 73 | 
74 | self.files[split] = recursive_glob(rootdir=self.images_base, suffix=".png") 75 | 76 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 77 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] 78 | 79 | self.class_names = [ 80 | "unlabelled", 81 | "road", 82 | "sidewalk", 83 | "building", 84 | "wall", 85 | "fence", 86 | "pole", 87 | "traffic_light", 88 | "traffic_sign", 89 | "vegetation", 90 | "terrain", 91 | "sky", 92 | "person", 93 | "rider", 94 | "car", 95 | "truck", 96 | "bus", 97 | "train", 98 | "motorcycle", 99 | "bicycle", 100 | ] 101 | 102 | self.ignore_index = 250 103 | self.class_map = dict(zip(self.valid_classes, range(19))) 104 | 105 | if not self.files[split]: 106 | raise Exception("No files for split=[%s] found in %s" % (split, self.images_base)) 107 | 108 | print("Found %d %s images" % (len(self.files[split]), split)) 109 | 110 | def __len__(self): 111 | """__len__""" 112 | return len(self.files[self.split]) 113 | 114 | def __getitem__(self, index): 115 | """__getitem__ 116 | 117 | :param index: 118 | """ 119 | img_path = self.files[self.split][index].rstrip() 120 | lbl_path = os.path.join( 121 | self.annotations_base, 122 | img_path.split(os.sep)[-2], 123 | os.path.basename(img_path)[:-15] + "gtFine_labelIds.png", 124 | ) 125 | depth_path = os.path.join( 126 | self.disparity_base, 127 | img_path.split(os.sep)[-2], 128 | os.path.basename(img_path)[:-15] + "disparity.png", 129 | ) 130 | 131 | img = Image.open(img_path) 132 | lbl = Image.open(lbl_path) 133 | depth = Image.open(depth_path) 134 | 135 | sample = {'image': img, 'label': lbl, 'depth': depth} 136 | 137 | if self.augmentations is not None: 138 | sample = self.augmentations(sample) 139 | 140 | if self.is_transform: 141 | sample = self.transform(sample) 142 | 143 | return sample 144 | 145 | def transform(self, sample): 146 | """transform 147 | 148 | :param img: 149 | :param lbl: 150 | """ 151 | img, lbl, depth = 
sample['image'], sample['label'], sample['depth'] 152 | 153 | # image 154 | img = F.resize(img, (self.img_size[0], self.img_size[1])) 155 | img = np.array(img, dtype=np.uint8) 156 | img = img[:, :, ::-1] # RGB -> BGR 157 | img = img.astype(np.float64) 158 | img -= self.mean 159 | if self.img_norm: 160 | # Pixel values span 0..255, so divide by 255.0 161 | # to rescale them to roughly [0, 1] 162 | img = img.astype(float) / 255.0 163 | # HWC -> CHW 164 | img = img.transpose(2, 0, 1) 165 | img = torch.from_numpy(img).float() 166 | 167 | 168 | # label 169 | lbl = F.resize(lbl, (self.img_size[0], self.img_size[1]), interpolation=Image.NEAREST) 170 | lbl = self.encode_segmap(np.array(lbl, dtype=np.uint8)) 171 | classes = np.unique(lbl) 172 | lbl = lbl.astype(int) 173 | 174 | if not np.all(classes == np.unique(lbl)): 175 | print("WARN: resizing labels yielded fewer classes") 176 | 177 | if not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes): 178 | print("after det", classes, np.unique(lbl)) 179 | raise ValueError("Segmentation map contained invalid class values") 180 | lbl = torch.from_numpy(lbl).long() 181 | 182 | 183 | # depth 184 | depth = F.resize(depth, (self.img_size[0], self.img_size[1]), interpolation=Image.NEAREST) # nearest: avoid blending valid and invalid (0) disparities 185 | depth = self.decode_depthmap(np.array(depth, dtype=np.float32)) 186 | depth = torch.from_numpy(depth).float() 187 | depth = torch.unsqueeze(depth, 0) 188 | 189 | return {'image': img, 'label': lbl, 'depth': depth} 190 | 191 | def img_recover(self, tensor_img): 192 | img = tensor_img.cpu().numpy() 193 | # CHW -> HWC 194 | img = img.transpose(1, 2, 0) 195 | if self.img_norm: 196 | img = (img * 255.0) 197 | img += self.mean 198 | img = img[:, :, ::-1] # BGR -> RGB 199 | img = np.array(img, dtype=np.uint8) 200 | return img 201 | 202 | def decode_segmap(self, temp): 203 | r = temp.copy() 204 | g = temp.copy() 205 | b = temp.copy() 206 | for l in range(0, self.n_classes): 207 | r[temp == l] = self.label_colours[l][0] 208 | g[temp == l] = self.label_colours[l][1] 209 | b[temp == l] =
self.label_colours[l][2] 210 | 211 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 212 | rgb[:, :, 0] = r / 255.0 213 | rgb[:, :, 1] = g / 255.0 214 | rgb[:, :, 2] = b / 255.0 215 | return rgb 216 | 217 | def encode_segmap(self, mask): 218 | # Map all void classes to ignore_index, then valid label ids to train ids 0..18 219 | for _voidc in self.void_classes: 220 | mask[mask == _voidc] = self.ignore_index 221 | for _validc in self.valid_classes: 222 | mask[mask == _validc] = self.class_map[_validc] 223 | return mask 224 | 225 | def decode_depthmap(self, disparity): 226 | disparity[disparity > 0] = (disparity[disparity > 0] - 1) / 256 # stored value p = 256 * d + 1 -> disparity d in pixels 227 | # depth [m] = baseline [m] * focal length [px] / disparity [px] 228 | disparity[disparity > 0] = (0.209313 * 2262.52) / disparity[disparity > 0] 229 | # invalid pixels and depths beyond max_depth are set to 0 230 | disparity[disparity <= 0] = 0 231 | disparity[disparity > self.max_depth] = 0 232 | 233 | depth = disparity 234 | 235 | return depth 236 | 237 | def encode_depthmap(self, depth): 238 | depth = (depth * 256 + 1).astype('uint16') # same 256 * x + 1 encoding as the disparity PNGs; uint16 avoids int16 overflow above ~128 239 | return depth 240 | 241 | 242 | if __name__ == "__main__": 243 | import matplotlib.pyplot as plt 244 | import seaborn as sns 245 | augmentations = Compose([RandomRotate(10), RandomCrop((512, 1024)), RandomHorizonFlip(0.5)]) 246 | # augmentations = None 247 | 248 | local_path = "/home/lin/Documents/dataset/Cityscapes/" 249 | dst = cityscapesLoader(local_path, is_transform=True, augmentations=augmentations) 250 | bs = 4 251 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0) 252 | for i, data_samples in enumerate(trainloader): 253 | imgs, labels, depth = data_samples['image'], data_samples['label'], data_samples['depth'] 254 | imgs = imgs.numpy()[:, ::-1, :, :] 255 | imgs = np.transpose(imgs, [0, 2, 3, 1]) 256 | # depthhh = depth.view(-1).numpy() 257 | # sns.distplot(depthhh) 258 | f, axarr = plt.subplots(bs, 3) 259 | for j in range(bs): 260 | axarr[j][0].imshow(imgs[j]) 261 | # axarr[j][0].imshow(dst.img_recover(imgs[j])) 262 | axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j])) 263 | axarr[j][2].imshow(dst.encode_depthmap(depth.numpy()[j][0]),
cmap='gray') 264 | plt.show() 265 | 266 | -------------------------------------------------------------------------------- /config/icnet-cityscapes.yml: -------------------------------------------------------------------------------- 1 | arch: icnet 2 | data: 3 | dataset: cityscapes 4 | path: /home/lin/Documents/dataset/Cityscapes/ 5 | train_split: train 6 | val_split: train 7 | model: 8 | multi_results: False 9 | n_classes: 19 10 | input_size: [513, 1025] 11 | block_config: [3, 4, 6, 3] 12 | training: 13 | train_iters: 500 14 | batch_size: 1 15 | accu_steps: 1 16 | val_interval: 10 17 | print_interval: 1 18 | img_size: [513, 1025] 19 | argumentation: 20 | random_hflip: 0.5 21 | random_rotate: 8 22 | optimizer: 23 | lr: 0.01 24 | momentum: 0.9 25 | optimizer_loss: 26 | lr: 0.001 27 | schedule: 28 | gamma: 2 29 | loss: 30 | loss_weights: 0 31 | delta1: 0 32 | delta2: 1 33 | resume: 34 | visdom: False 35 | 36 | testing: 37 | model_path: runs/icnet-cityscapes/431/SPNet_cityscapes_best_model.pth 38 | config_path: configs 39 | img_fold: inputs 40 | output_fold: outputs 41 | img_rows: 513 42 | img_cols: 1025 43 | downsample: 3 44 | bs: 1 45 | 46 | device: cuda 47 | -------------------------------------------------------------------------------- /config/spnet-cityscapes.yml: -------------------------------------------------------------------------------- 1 | arch: spnet 2 | data: 3 | dataset: cityscapes 4 | path: /home/lin/Documents/dataset/Cityscapes/ 5 | train_split: train 6 | val_split: val 7 | model: 8 | num_classes: 19 9 | 10 | training: 11 | train_iters: 297500 12 | batch_size: 4 13 | accu_steps: 1 14 | val_interval: 2975 15 | print_interval: 100 16 | img_size: [512, 1024] 17 | argumentation: 18 | random_hflip: 0.5 19 | random_rotate: 8 20 | optimizer: 21 | lr: 0.01 22 | momentum: 0.9 23 | schedule: 24 | gamma: 2 25 | loss: 26 | loss_weights: 0.8 27 | resume: 28 | visdom: False 29 | 30 | device: cuda 31 | 
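`demo.py` reads its test settings from `cfg['testing']`, but `config/spnet-cityscapes.yml` above has no `testing` section and the one in `config/icnet-cityscapes.yml` has no `pred_depth` entry. The sketch below is a hypothetical helper, `missing_testing_keys`, which is not part of this repository. It lists which of the keys that `test_img()` looks up are absent from a parsed config. The two dicts mirror the YAML files above, keeping only the relevant keys; in practice you would pass the dict returned by `yaml.safe_load`.

```python
# Keys that demo.py's test_img() reads from cfg['testing'] (taken from demo.py).
REQUIRED_TESTING_KEYS = (
    "model_path", "img_fold", "output_fold", "img_rows", "img_cols", "pred_depth",
)


def missing_testing_keys(cfg):
    """Return the demo.py testing keys that a parsed config dict lacks.

    Hypothetical helper: `cfg` is assumed to be the dict produced by
    yaml.safe_load() on one of the files in config/.
    """
    testing = cfg.get("testing") or {}  # a bare `testing:` line parses to None
    return [key for key in REQUIRED_TESTING_KEYS if key not in testing]


# Trimmed-down dicts mirroring the two YAML files above.
spnet_cfg = {"arch": "spnet", "model": {"num_classes": 19}, "device": "cuda"}
icnet_cfg = {
    "arch": "icnet",
    "testing": {
        "model_path": "runs/icnet-cityscapes/431/SPNet_cityscapes_best_model.pth",
        "img_fold": "inputs",
        "output_fold": "outputs",
        "img_rows": 513,
        "img_cols": 1025,
    },
}

print(missing_testing_keys(spnet_cfg))  # all six keys
print(missing_testing_keys(icnet_cfg))  # ['pred_depth']
```

Running the check before `demo.py` makes a missing section show up as a readable list instead of a failure partway through inference.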
-------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import yaml 4 | import numpy as np 5 | import scipy.misc as misc 6 | from models.icnet import icnet 7 | from models.spnet import spnet 8 | from cityscapes_loader import cityscapesLoader 9 | import cv2 10 | 11 | models = {'spnet': spnet, 'icnet': icnet} 12 | 13 | 14 | def test_img(cfg): 15 | device = torch.device(cfg['device']) 16 | data_loader = cityscapesLoader 17 | loader = data_loader(root=cfg['data']['path'], is_transform=True, test_mode=True) 18 | n_classes = loader.n_classes 19 | # Setup Model 20 | if 'multi_results' in cfg['model']: 21 | cfg['model']['multi_results'] = False 22 | 23 | model = models[cfg['arch']](**cfg['model']) 24 | model.load_state_dict(torch.load(cfg['testing']['model_path'])["model_state"]) 25 | model.eval() 26 | model.to(device) 27 | 28 | for img_name in os.listdir(cfg['testing']['img_fold']): 29 | seg_output_path = os.path.join(cfg['testing']['output_fold'], 'seg_%s.png' % img_name.split('.')[0]) 30 | depth_output_path = os.path.join(cfg['testing']['output_fold'], 'depth_%s.png' % img_name.split('.')[0]) 31 | if not os.path.exists(seg_output_path): 32 | img_path = os.path.join(cfg['testing']['img_fold'], img_name) 33 | img = misc.imread(img_path) 34 | orig_size = img.shape[:-1] 35 | 36 | # uint8 with RGB mode, resize width and height which are odd numbers 37 | # img = misc.imresize(img, (orig_size[0] // 2 * 2 + 1, orig_size[1] // 2 * 2 + 1)) 38 | img = misc.imresize(img, (cfg['testing']['img_rows'], cfg['testing']['img_cols'])) 39 | img = img.astype(np.float64) 40 | img = img[:, :, ::-1] # RGB -> BGR 41 | img = img.astype(float) / 255.0 42 | # HWC -> CHW 43 | img = img.transpose(2, 0, 1) 44 | img = np.expand_dims(img, 0) 45 | img = torch.from_numpy(img).float() 46 | 47 | img = img.to(device) 48 | depth_result, seg_result = 
model(img)[0] 49 | 50 | # save segmentation result 51 | seg_result = np.squeeze(seg_result.data.max(1)[1].cpu().numpy(), axis=0) 52 | seg_result = seg_result.astype(np.float32) 53 | # float32 with F mode, resize back to orig_size 54 | seg_result = misc.imresize(seg_result, orig_size, "nearest", mode="F") 55 | 56 | decoded = loader.decode_segmap(seg_result) 57 | misc.imsave(seg_output_path, decoded) 58 | 59 | # save depth map 60 | if cfg['testing'].get('pred_depth', True): 61 | depth_result = np.squeeze(depth_result.cpu().detach().numpy(), axis=0) 62 | depth_result = np.squeeze(depth_result, axis=0) 63 | depth_result = depth_result.astype(np.float32) 64 | # float32 with F mode, resize back to orig_size 65 | depth_result = misc.imresize(depth_result, orig_size, "nearest", mode='F') 66 | 67 | depth_color = cv2.applyColorMap(cv2.convertScaleAbs(depth_result, alpha=15), cv2.COLORMAP_JET) 68 | # applyColorMap returns BGR; convert to RGB before saving with misc.imsave 69 | misc.imsave(depth_output_path, cv2.cvtColor(depth_color, cv2.COLOR_BGR2RGB)) 70 | 71 | 72 | if __name__ == "__main__": 73 | with open('config/spnet-cityscapes.yml') as fp: 74 | cfg = yaml.safe_load(fp) 75 | test_img(cfg) 76 | -------------------------------------------------------------------------------- /inputs/aachen_000001_000019_leftImg8bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EEEGUI/SceneParsingNetwork/b515c7af06a886ee3ae5458f4e4189262df42086/inputs/aachen_000001_000019_leftImg8bit.png -------------------------------------------------------------------------------- /inputs/frankfurt_000000_000294_leftImg8bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EEEGUI/SceneParsingNetwork/b515c7af06a886ee3ae5458f4e4189262df42086/inputs/frankfurt_000000_000294_leftImg8bit.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 |
import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Loss(nn.Module): 7 | def __init__(self, loss_weights=0.8, delta1=0.1, delta2=10): 8 | super(Loss, self).__init__() 9 | self.loss_weights = [loss_weights ** (3-i) for i in range(4)] 10 | self.delta1 = nn.Parameter(torch.Tensor([delta1])) 11 | self.delta2 = nn.Parameter(torch.Tensor([delta2])) 12 | 13 | def forward(self, outputs, depths, labels): 14 | # loss = self.delta1 + self.delta2 15 | loss = 0 16 | for i, pair in enumerate(outputs): 17 | # loss = loss + torch.exp(-self.delta1) * self.loss_weights[i] * sim_depth_loss(depths, pair[0]) + \ 18 | # torch.exp(-self.delta2) * self.loss_weights[i] * cross_entropy2d(pair[1], labels) 19 | # NOTE: delta1/delta2 are nn.Parameters used as plain weights; if an optimizer updates them, nothing keeps them from shrinking toward zero. The commented-out exp(-delta) form, together with the `loss = self.delta1 + self.delta2` start, adds that regularization. 20 | loss = loss + self.delta1 * self.loss_weights[i] * sim_depth_loss(depths, pair[0]) + \ 21 | self.delta2 * self.loss_weights[i] * cross_entropy2d(pair[1], labels) 22 | return loss 23 | 24 | # berHu (reverse Huber) depth loss; pixels without ground truth (y_true == 0) contribute zero error 25 | def sim_depth_loss(y_true, y_pred): 26 | mask = (y_true > 0).float() 27 | y_pred = mask * y_pred 28 | c = 0.2 * torch.max(torch.abs(y_pred-y_true)).clamp(min=1e-6) # berHu threshold; clamp avoids division by zero when pred == gt everywhere 29 | loss = torch.abs(y_true-y_pred) * (torch.abs(y_pred - y_true) <= c).float() + (torch.pow(y_true-y_pred, 2) + c**2)/(2*c) * (torch.abs(y_pred-y_true) > c).float() 30 | 31 | loss = torch.mean(loss) 32 | 33 | return loss 34 | 35 | 36 | def cross_entropy2d(input, target, weight=None, size_average=True): 37 | n, c, h, w = input.size() 38 | nt, ht, wt = target.size() 39 | 40 | # Handle inconsistent size between input and target 41 | if h != ht or w != wt: # upsample predictions to the label size 42 | input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True) 43 | 44 | input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 45 | target = target.view(-1) 46 | loss = F.cross_entropy( 47 | input, target, weight=weight, size_average=size_average, ignore_index=250 48 | ) 49 | return loss 50 | 51 | 52 | def bootstrapped_cross_entropy2d(input, target, K, weight=None, size_average=True): 53 | 54 |
batch_size = input.size()[0] 55 | 56 | def _bootstrap_xentropy_single(input, target, K, weight=None, size_average=True): 57 | 58 | n, c, h, w = input.size() 59 | input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 60 | target = target.view(-1) 61 | loss = F.cross_entropy( 62 | input, target, weight=weight, reduction="none", ignore_index=250 63 | ) 64 | 65 | topk_loss, _ = loss.topk(K) 66 | reduced_topk_loss = topk_loss.sum() / K 67 | 68 | return reduced_topk_loss 69 | 70 | loss = 0.0 71 | # Bootstrap per image rather than over the entire batch 72 | for i in range(batch_size): 73 | loss += _bootstrap_xentropy_single( 74 | input=torch.unsqueeze(input[i], 0), 75 | target=torch.unsqueeze(target[i], 0), 76 | K=K, 77 | weight=weight, 78 | size_average=size_average, 79 | ) 80 | return loss / float(batch_size) 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class SegmentationScore(object): 5 | """ 6 | Semantic segmentation accuracy (computed from a confusion matrix) 7 | """ 8 | def __init__(self, n_classes=19): 9 | self.n_classes = n_classes 10 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 11 | 12 | def _fast_hist(self, label_true, label_pred, n_class): 13 | mask = (label_true >= 0) & (label_true < n_class) 14 | hist = np.bincount( 15 | n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2 16 | ).reshape(n_class, n_class) 17 | return hist 18 | 19 | def update(self, label_trues, label_preds): 20 | for lt, lp in zip(label_trues, label_preds): 21 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes) 22 | 23 | def get_scores(self): 24 | """Returns accuracy score evaluation result.
25 | - overall accuracy 26 | - mean accuracy 27 | - mean IU 28 | - fwavacc 29 | """ 30 | hist = self.confusion_matrix 31 | acc = np.diag(hist).sum() / hist.sum() 32 | acc_cls = np.diag(hist) / hist.sum(axis=1) 33 | acc_cls = np.nanmean(acc_cls) 34 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 35 | mean_iu = np.nanmean(iu) 36 | freq = hist.sum(axis=1) / hist.sum() 37 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 38 | cls_iu = dict(zip(range(self.n_classes), iu)) 39 | 40 | return ( 41 | { 42 | "Overall Acc: \t": acc, 43 | "Mean Acc : \t": acc_cls, 44 | "FreqW Acc : \t": fwavacc, 45 | "Mean IoU : \t": mean_iu, 46 | }, 47 | cls_iu, 48 | ) 49 | 50 | def reset(self): 51 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 52 | 53 | 54 | class DepthEstimateScore(object): 55 | """ 56 | Depth estimation accuracy (threshold accuracies a1-a3, AbsRel, RMSE, log10 error) 57 | """ 58 | def __init__(self): 59 | self.score_dict = {'a1': [], 60 | 'a2': [], 61 | 'a3': [], 62 | 'abs_rel': [], 63 | 'rmse': [], 64 | 'log_10': []} 65 | 66 | def reset(self): 67 | self.score_dict = {'a1': [], 68 | 'a2': [], 69 | 'a3': [], 70 | 'abs_rel': [], 71 | 'rmse': [], 72 | 'log_10': []} 73 | 74 | def update(self, label_true, label_pred): 75 | errors = self._compute_errors(label_true, label_pred) 76 | for i, key in enumerate(self.score_dict.keys()): 77 | self.score_dict[key].append(errors[i]) 78 | 79 | def get_scores(self): 80 | scores = {} 81 | for key in self.score_dict.keys(): 82 | scores[key] = np.mean(self.score_dict[key]) 83 | return scores 84 | 85 | def _compute_errors(self, gt, pred): 86 | pred = np.where(gt > 0, pred, np.nan) # new arrays: the caller's gt/pred are not modified 87 | gt = np.where(gt > 0, gt, np.nan) 88 | thresh = np.maximum((gt / pred), (pred / gt)) 89 | a1 = (thresh < 1.25).sum()/(thresh.size - np.isnan(thresh).sum()) 90 | a2 = (thresh < 1.25 ** 2).sum()/(thresh.size - np.isnan(thresh).sum()) 91 | a3 = (thresh < 1.25 ** 3).sum()/(thresh.size - np.isnan(thresh).sum()) 92 | 93 | abs_rel = np.nanmean(np.abs(gt - pred) / gt) 94 | 95 | rmse = np.power(gt - pred, 2) 96 | rmse =
np.sqrt(np.nanmean(rmse)) 97 | 98 | log_10 = np.nanmean(np.abs(np.log10(gt) - np.log10(pred))) 99 | 100 | return [a1, a2, a3, abs_rel, rmse, log_10] 101 | 102 | 103 | class averageMeter(object): 104 | """Computes and stores the average and current value""" 105 | 106 | def __init__(self): 107 | self.reset() 108 | 109 | def reset(self): 110 | self.val = 0 111 | self.avg = 0 112 | self.sum = 0 113 | self.count = 0 114 | 115 | def update(self, val, n=1): 116 | self.val = val 117 | self.sum += val * n 118 | self.count += n 119 | self.avg = self.sum / self.count 120 | -------------------------------------------------------------------------------- /models/icnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from utils import ( 6 | get_interp_size, 7 | cascadeFeatureFusion, 8 | conv2DBatchNormRelu, 9 | residualBlockPSP, 10 | pyramidPooling, 11 | ) 12 | 13 | 14 | class icnet(nn.Module): 15 | """ 16 | Image Cascade Network 17 | URL: https://arxiv.org/abs/1704.08545 18 | 19 | References: 20 | 1) Original Author's code: https://github.com/hszhao/ICNet 21 | 2) Chainer implementation by @mitmul: https://github.com/mitmul/chainer-pspnet 22 | 3) TensorFlow implementation by @hellochick: https://github.com/hellochick/ICNet-tensorflow 23 | 24 | """ 25 | 26 | def __init__( 27 | self, 28 | n_classes=19, 29 | block_config=[3, 4, 6, 3], 30 | input_size=(1025, 2049), 31 | multi_results=True, 32 | is_batchnorm=True, 33 | ): 34 | 35 | super(icnet, self).__init__() 36 | 37 | bias = not is_batchnorm 38 | 39 | self.block_config = block_config 40 | self.n_classes = n_classes 41 | self.input_size = input_size 42 | self.multi_results = multi_results 43 | # Encoder 44 | self.convbnrelu1_1 = conv2DBatchNormRelu( 45 | in_channels=3, 46 | k_size=3, 47 | n_filters=32, 48 | padding=1, 49 | stride=2, 50 | bias=bias, 51 | is_batchnorm=is_batchnorm, 52 | ) 53 | self.convbnrelu1_2 
= conv2DBatchNormRelu( 54 | in_channels=32, 55 | k_size=3, 56 | n_filters=32, 57 | padding=1, 58 | stride=1, 59 | bias=bias, 60 | is_batchnorm=is_batchnorm, 61 | ) 62 | self.convbnrelu1_3 = conv2DBatchNormRelu( 63 | in_channels=32, 64 | k_size=3, 65 | n_filters=64, 66 | padding=1, 67 | stride=1, 68 | bias=bias, 69 | is_batchnorm=is_batchnorm, 70 | ) 71 | 72 | # Vanilla Residual Blocks 73 | self.res_block2 = residualBlockPSP( 74 | self.block_config[0], 64, 32, 128, 1, 1, is_batchnorm=is_batchnorm) 75 | 76 | self.res_block3_conv = residualBlockPSP( 77 | self.block_config[1], 128, 64, 256, 2, 1, include_range="conv",is_batchnorm=is_batchnorm) 78 | 79 | self.res_block3_identity = residualBlockPSP( 80 | self.block_config[1], 128, 64, 256, 2, 1, include_range="identity",is_batchnorm=is_batchnorm) 81 | 82 | # Dilated Residual Blocks 83 | self.res_block4 = residualBlockPSP( 84 | self.block_config[2], 256, 128, 512, 1, 2, is_batchnorm=is_batchnorm) 85 | 86 | self.res_block5 = residualBlockPSP( 87 | self.block_config[3], 512, 256, 1024, 1, 4, is_batchnorm=is_batchnorm) 88 | 89 | # Pyramid Pooling Module 90 | self.pyramid_pooling = pyramidPooling( 91 | 1024, [6, 3, 2, 1], model_name="icnet", fusion_mode="sum", is_batchnorm=is_batchnorm) 92 | 93 | # Final conv layer with kernel 1 in sub4 branch 94 | self.conv5_4_k1 = conv2DBatchNormRelu( 95 | in_channels=1024, 96 | k_size=1, 97 | n_filters=256, 98 | padding=0, 99 | stride=1, 100 | bias=bias, 101 | is_batchnorm=is_batchnorm, 102 | ) 103 | 104 | # High-resolution (sub1) branch 105 | self.convbnrelu1_sub1 = conv2DBatchNormRelu( 106 | in_channels=3, 107 | k_size=3, 108 | n_filters=32, 109 | padding=1, 110 | stride=2, 111 | bias=bias, 112 | is_batchnorm=is_batchnorm, 113 | ) 114 | self.convbnrelu2_sub1 = conv2DBatchNormRelu( 115 | in_channels=32, 116 | k_size=3, 117 | n_filters=32, 118 | padding=1, 119 | stride=2, 120 | bias=bias, 121 | is_batchnorm=is_batchnorm, 122 | ) 123 | self.convbnrelu3_sub1 = conv2DBatchNormRelu( 124 | 
in_channels=32, 125 | k_size=3, 126 | n_filters=64, 127 | padding=1, 128 | stride=2, 129 | bias=bias, 130 | is_batchnorm=is_batchnorm, 131 | ) 132 | self.classification = nn.Conv2d(128, self.n_classes, 1, 1, 0) 133 | self.depth_estimate = nn.Conv2d(128, 1, 1, 1, 0) 134 | 135 | # Cascade Feature Fusion Units 136 | self.cff_sub24 = cascadeFeatureFusion( 137 | self.n_classes, 256, 256, 128, is_batchnorm=is_batchnorm 138 | ) 139 | self.cff_sub12 = cascadeFeatureFusion( 140 | self.n_classes, 128, 64, 128, is_batchnorm=is_batchnorm 141 | ) 142 | 143 | # Define auxiliary loss function 144 | # self.loss = multi_scale_cross_entropy2d 145 | 146 | def forward(self, x): 147 | # H, W -> H/2, W/2 148 | x_sub2 = F.interpolate( 149 | x, size=get_interp_size(x, s_factor=2), mode="bilinear", align_corners=True 150 | ) 151 | 152 | # H/2, W/2 -> H/4, W/4 153 | x_sub2 = self.convbnrelu1_1(x_sub2) 154 | x_sub2 = self.convbnrelu1_2(x_sub2) 155 | x_sub2 = self.convbnrelu1_3(x_sub2) 156 | 157 | # H/4, W/4 -> H/8, W/8 158 | x_sub2 = F.max_pool2d(x_sub2, 3, 2, 1) 159 | 160 | # H/8, W/8 -> H/16, W/16 161 | x_sub2 = self.res_block2(x_sub2) 162 | x_sub2 = self.res_block3_conv(x_sub2) 163 | # H/16, W/16 -> H/32, W/32 164 | x_sub4 = F.interpolate( 165 | x_sub2, size=get_interp_size(x_sub2, s_factor=2), mode="bilinear", align_corners=True 166 | ) 167 | x_sub4 = self.res_block3_identity(x_sub4) 168 | 169 | x_sub4 = self.res_block4(x_sub4) 170 | x_sub4 = self.res_block5(x_sub4) 171 | 172 | x_sub4 = self.pyramid_pooling(x_sub4) 173 | x_sub4 = self.conv5_4_k1(x_sub4) 174 | 175 | x_sub1 = self.convbnrelu1_sub1(x) 176 | x_sub1 = self.convbnrelu2_sub1(x_sub1) 177 | x_sub1 = self.convbnrelu3_sub1(x_sub1) 178 | 179 | x_sub24, sub4_cls, sub4_depth = self.cff_sub24(x_sub4, x_sub2) 180 | x_sub12, sub24_cls, sub24_depth = self.cff_sub12(x_sub24, x_sub1) 181 | 182 | x_sub12 = F.interpolate( 183 | x_sub12, size=get_interp_size(x_sub12, z_factor=2), mode="bilinear", align_corners=True 184 | ) 185 | 186 | 
sub124_cls = self.classification(x_sub12) 187 | sub124_depth = self.depth_estimate(x_sub12) 188 | 189 | sub124_depth = F.interpolate(sub124_depth, size=get_interp_size(sub124_depth, z_factor=4), 190 | mode="bilinear", align_corners=True) 191 | sub124_cls = F.interpolate(sub124_cls, size=get_interp_size(sub124_cls, z_factor=4), 192 | mode="bilinear", align_corners=True) 193 | 194 | if self.multi_results: 195 | sub4_depth = F.interpolate(sub4_depth, size=get_interp_size(sub4_depth, z_factor=16), 196 | mode="bilinear", align_corners=True) 197 | sub4_cls = F.interpolate(sub4_cls, size=get_interp_size(sub4_cls, z_factor=16), 198 | mode="bilinear", align_corners=True) 199 | 200 | sub24_depth = F.interpolate(sub24_depth, size=get_interp_size(sub24_depth, z_factor=8), 201 | mode="bilinear", align_corners=True) 202 | sub24_cls = F.interpolate(sub24_cls, size=get_interp_size(sub24_cls, z_factor=8), 203 | mode="bilinear", align_corners=True) 204 | 205 | return [(sub4_depth, sub4_cls), (sub24_depth, sub24_cls), (sub124_depth, sub124_cls)] 206 | 207 | else: 208 | return sub124_depth, sub124_cls 209 | 210 | 211 | if __name__ == "__main__": 212 | net = icnet(is_batchnorm=True).cuda() 213 | from torchsummary import summary 214 | summary(net, (3, 513, 1025)) 215 | 216 | 217 | -------------------------------------------------------------------------------- /models/spnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.nn.functional import interpolate as interpolate 6 | 7 | 8 | def channel_shuffle(x, groups): 9 | batchsize, num_channels, height, width = x.data.size() 10 | 11 | channels_per_group = num_channels // groups 12 | 13 | # reshape 14 | x = x.view(batchsize, groups, 15 | channels_per_group, height, width) 16 | 17 | x = torch.transpose(x, 1, 2).contiguous() 18 | 19 | # flatten 20 | x = x.view(batchsize, -1, height, 
width) 21 | 22 | return x 23 | 24 | 25 | class Conv2dBnRelu(nn.Module): 26 | def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=0, dilation=1, bias=True): 27 | super(Conv2dBnRelu, self).__init__() 28 | 29 | self.conv = nn.Sequential( 30 | nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, dilation=dilation, bias=bias), 31 | nn.BatchNorm2d(out_ch, eps=1e-3), 32 | nn.ReLU(inplace=True) 33 | ) 34 | 35 | def forward(self, x): 36 | return self.conv(x) 37 | 38 | 39 | class DownsamplerBlock(nn.Module): 40 | def __init__(self, in_channel, out_channel): 41 | super().__init__() 42 | 43 | self.conv = nn.Conv2d(in_channel, out_channel - in_channel, (3, 3), stride=2, padding=1, bias=True) 44 | self.pool = nn.MaxPool2d(2, stride=2) 45 | self.bn = nn.BatchNorm2d(out_channel, eps=1e-3) 46 | self.relu = nn.ReLU(inplace=True) 47 | 48 | def forward(self, input): 49 | output = torch.cat([self.conv(input), self.pool(input)], 1) 50 | output = self.bn(output) 51 | output = self.relu(output) 52 | 53 | return output 54 | 55 | 56 | class ResModule(nn.Module): 57 | def __init__(self, chann, dropprob, dilated): 58 | super().__init__() 59 | 60 | oup_inc = chann // 2 61 | 62 | # left branch: factorized 3x1 and 1x3 convolutions (regular, not depthwise) 63 | self.conv3x1_1_l = nn.Conv2d(oup_inc, oup_inc, (3, 1), stride=1, padding=(1, 0), bias=True) 64 | 65 | self.conv1x3_1_l = nn.Conv2d(oup_inc, oup_inc, (1, 3), stride=1, padding=(0, 1), bias=True) 66 | 67 | self.bn1_l = nn.BatchNorm2d(oup_inc, eps=1e-03) 68 | 69 | self.conv3x1_2_l = nn.Conv2d(oup_inc, oup_inc, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True, 70 | dilation=(dilated, 1)) 71 | 72 | self.conv1x3_2_l = nn.Conv2d(oup_inc, oup_inc, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True, 73 | dilation=(1, dilated)) 74 | 75 | self.bn2_l = nn.BatchNorm2d(oup_inc, eps=1e-03) 76 | 77 | # right branch: factorized 3x1 and 1x3 convolutions (regular, not depthwise) 78 | self.conv3x1_1_r = nn.Conv2d(oup_inc, oup_inc, (3, 1), stride=1, padding=(1, 0), 
bias=True) 79 | 80 | self.conv1x3_1_r = nn.Conv2d(oup_inc, oup_inc, (1, 3), stride=1, padding=(0, 1), bias=True)
81 | 82 | self.bn1_r = nn.BatchNorm2d(oup_inc, eps=1e-03) 83 | 84 | self.conv3x1_2_r = nn.Conv2d(oup_inc, oup_inc, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True, 85 | dilation=(dilated, 1)) 86 | 87 | self.conv1x3_2_r = nn.Conv2d(oup_inc, oup_inc, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True, 88 | dilation=(1, dilated)) 89 | 90 | self.bn2_r = nn.BatchNorm2d(oup_inc, eps=1e-03) 91 | 92 | self.relu = nn.ReLU(inplace=True) 93 | self.dropout = nn.Dropout2d(dropprob) 94 | 95 | @staticmethod 96 | def _concat(x, out): 97 | return torch.cat((x, out), 1) 98 | 99 | def forward(self, input): 100 | x1 = input[:, :(input.shape[1] // 2), :, :] 101 | x2 = input[:, (input.shape[1] // 2):, :, :] 102 | 103 | output1 = self.conv3x1_1_l(x1) 104 | output1 = self.relu(output1) 105 | output1 = self.conv1x3_1_l(output1) 106 | output1 = self.bn1_l(output1) 107 | output1 = self.relu(output1) 108 | 109 | output1 = self.conv3x1_2_l(output1) 110 | output1 = self.relu(output1) 111 | output1 = self.conv1x3_2_l(output1) 112 | output1 = self.bn2_l(output1) 113 | 114 | output2 = self.conv1x3_1_r(x2) 115 | output2 = self.relu(output2) 116 | output2 = self.conv3x1_1_r(output2) 117 | output2 = self.bn1_r(output2) 118 | output2 = self.relu(output2) 119 | 120 | output2 = self.conv1x3_2_r(output2) 121 | output2 = self.relu(output2) 122 | output2 = self.conv3x1_2_r(output2) 123 | output2 = self.bn2_r(output2) 124 | 125 | if (self.dropout.p != 0): 126 | output1 = self.dropout(output1) 127 | output2 = self.dropout(output2) 128 | 129 | out = self._concat(output1, output2) 130 | out = F.relu(input + out, inplace=True) 131 | return channel_shuffle(out, 2) 132 | 133 | 134 | class Encoder(nn.Module): 135 | def __init__(self, num_classes): 136 | super().__init__() 137 | self.initial_block = DownsamplerBlock(3, 32) 138 | 139 | self.layers = nn.ModuleList() 140 | 141 | for x in range(0, 3): 142 | self.layers.append(ResModule(32, 0.03, 1)) 143 | 144 | self.layers.append(DownsamplerBlock(32, 64)) 145 
| 146 | for x in range(0, 2): 147 | self.layers.append(ResModule(64, 0.03, 1)) 148 | 149 | self.layers.append(DownsamplerBlock(64, 128)) 150 | 151 | for x in range(0, 1): 152 | self.layers.append(ResModule(128, 0.3, 1)) 153 | self.layers.append(ResModule(128, 0.3, 2)) 154 | self.layers.append(ResModule(128, 0.3, 5)) 155 | self.layers.append(ResModule(128, 0.3, 9)) 156 | 157 | for x in range(0, 1): 158 | self.layers.append(ResModule(128, 0.3, 2)) 159 | self.layers.append(ResModule(128, 0.3, 5)) 160 | self.layers.append(ResModule(128, 0.3, 9)) 161 | self.layers.append(ResModule(128, 0.3, 17)) 162 | 163 | # Only in encoder mode: 164 | self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True) 165 | 166 | def forward(self, input, predict=False): 167 | 168 | output = self.initial_block(input) 169 | 170 | for layer in self.layers: 171 | output = layer(output) 172 | 173 | if predict: 174 | output = self.output_conv(output) 175 | 176 | return output 177 | 178 | 179 | class Interpolate(nn.Module): 180 | def __init__(self, size, mode): 181 | super(Interpolate, self).__init__() 182 | 183 | self.interp = nn.functional.interpolate 184 | self.size = size 185 | self.mode = mode 186 | 187 | def forward(self, x): 188 | x = self.interp(x, size=self.size, mode=self.mode, align_corners=True) 189 | return x 190 | 191 | 192 | class FPAModule(nn.Module): 193 | def __init__(self, in_ch, out_ch): 194 | super(FPAModule, self).__init__() 195 | # global pooling branch 196 | self.branch1 = nn.Sequential( 197 | nn.AdaptiveAvgPool2d(1), 198 | Conv2dBnRelu(in_ch, out_ch, kernel_size=1, stride=1, padding=0) 199 | ) 200 | 201 | # middle branch 202 | self.mid = nn.Sequential( 203 | Conv2dBnRelu(in_ch, out_ch, kernel_size=1, stride=1, padding=0) 204 | ) 205 | 206 | self.down1 = Conv2dBnRelu(in_ch, 1, kernel_size=7, stride=2, padding=3) 207 | 208 | self.down2 = Conv2dBnRelu(1, 1, kernel_size=5, stride=2, padding=2) 209 | 210 | self.down3 = nn.Sequential( 211 | Conv2dBnRelu(1,
1, kernel_size=3, stride=2, padding=1), 212 | Conv2dBnRelu(1, 1, kernel_size=3, stride=1, padding=1) 213 | ) 214 | 215 | self.conv2 = Conv2dBnRelu(1, 1, kernel_size=5, stride=1, padding=2) 216 | self.conv1 = Conv2dBnRelu(1, 1, kernel_size=7, stride=1, padding=3) 217 | 218 | def forward(self, x): 219 | h = x.size()[2] 220 | w = x.size()[3] 221 | 222 | b1 = self.branch1(x) 223 | # b1 = Interpolate(size=(h, w), mode="bilinear")(b1) 224 | b1 = interpolate(b1, size=(h, w), mode="bilinear", align_corners=True) 225 | 226 | mid = self.mid(x) 227 | 228 | x1 = self.down1(x) 229 | x2 = self.down2(x1) 230 | x3 = self.down3(x2) 231 | # x3 = Interpolate(size=(h // 4, w // 4), mode="bilinear")(x3) 232 | x3 = interpolate(x3, size=(h // 4, w // 4), mode="bilinear", align_corners=True) 233 | x2 = self.conv2(x2) 234 | x = x2 + x3 235 | # x = Interpolate(size=(h // 2, w // 2), mode="bilinear")(x) 236 | x = interpolate(x, size=(h // 2, w // 2), mode="bilinear", align_corners=True) 237 | 238 | x1 = self.conv1(x1) 239 | x = x + x1 240 | # x = Interpolate(size=(h, w), mode="bilinear")(x) 241 | x = interpolate(x, size=(h, w), mode="bilinear", align_corners=True) 242 | 243 | x = torch.mul(x, mid) 244 | 245 | x = x + b1 246 | 247 | return x 248 | 249 | 250 | class Decoder(nn.Module): 251 | def __init__(self, num_classes): 252 | super().__init__() 253 | 254 | self.fpa = FPAModule(in_ch=128, out_ch=num_classes) 255 | # self.upsample = Interpolate(size=(512, 1024), mode="bilinear") 256 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True) 257 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True) 258 | # self.output_conv = nn.ConvTranspose2d(16, num_classes, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True) 259 | 260 | def forward(self, input): 261 | output = self.fpa(input) 262 | out = interpolate(output, size=(512, 1024), mode="bilinear", 
align_corners=True) 263 | # out = self.upsample(output) 264 | return out 265 | 266 | 267 | # spnet 268 | class Net(nn.Module): 269 | def __init__(self, num_classes, encoder=None): 270 | super().__init__() 271 | 272 | if encoder is None: 273 | self.encoder = Encoder(num_classes) 274 | else: 275 | self.encoder = encoder 276 | self.decoder = Decoder(num_classes) 277 | 278 | def forward(self, input, only_encode=False): 279 | if only_encode: 280 | return self.encoder(input, predict=True) 281 | else: 282 | output = self.encoder(input) 283 | return self.decoder(output) 284 | 285 | 286 | class spnet(nn.Module): 287 | def __init__(self, num_classes): 288 | super(spnet, self).__init__() 289 | self.encoder = Encoder(num_classes) 290 | self.decoder_seg = Decoder(num_classes) 291 | self.decoder_depth = Decoder(1) 292 | 293 | def forward(self, x): 294 | x = self.encoder(x) 295 | seg_logits = self.decoder_seg(x) 296 | depth = self.decoder_depth(x) 297 | return [[depth, seg_logits]] 298 | 299 | 300 | if __name__ == '__main__': 301 | from utils import params_size 302 | batch = 4 303 | x = torch.rand(batch, 3, 512, 1024).cuda() 304 | net = spnet(20).cuda() 305 | # params_size(net) 306 | with torch.no_grad(): 307 | import time 308 | t1 = time.time() 309 | for i in range(100//batch): 310 | print(i) 311 | s = net(x) 312 | t2 = time.time() 313 | print(100/(t2 - t1)) 314 | 315 | -------------------------------------------------------------------------------- /models/trl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | 6 | class conv2DGroupNormRelu(nn.Module): 7 | def __init__( 8 | self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, n_groups=16 9 | ): 10 | super(conv2DGroupNormRelu, self).__init__() 11 | 12 | conv_mod = nn.Conv2d( 13 | int(in_channels), 14 | int(n_filters), 15 | kernel_size=k_size, 16 | padding=padding,
17 | stride=stride, 18 | bias=bias, 19 | dilation=dilation, 20 | ) 21 | 22 | self.cgr_unit = nn.Sequential( 23 | conv_mod, nn.GroupNorm(n_groups, int(n_filters)), nn.ReLU(inplace=True) 24 | ) 25 | 26 | def forward(self, inputs): 27 | outputs = self.cgr_unit(inputs) 28 | return outputs 29 | 30 | 31 | class conv2DBatchNormRelu(nn.Module): 32 | def __init__( 33 | self, 34 | in_channels, 35 | n_filters, 36 | k_size, 37 | stride, 38 | padding, 39 | bias=True, 40 | dilation=1, 41 | is_batchnorm=True, 42 | ): 43 | super(conv2DBatchNormRelu, self).__init__() 44 | 45 | conv_mod = nn.Conv2d( 46 | int(in_channels), 47 | int(n_filters), 48 | kernel_size=k_size, 49 | padding=padding, 50 | stride=stride, 51 | bias=bias, 52 | dilation=dilation, 53 | ) 54 | 55 | if is_batchnorm: 56 | self.cbr_unit = nn.Sequential( 57 | conv_mod, nn.BatchNorm2d(int(n_filters)), nn.ReLU(inplace=True) 58 | ) 59 | else: 60 | self.cbr_unit = nn.Sequential(conv_mod, nn.ReLU(inplace=True)) 61 | 62 | def forward(self, inputs): 63 | outputs = self.cbr_unit(inputs) 64 | return outputs 65 | 66 | 67 | class ResidualBlock(nn.Module): 68 | expansion = 4 69 | 70 | def __init__(self, inplanes, planes, stride=1): 71 | super(ResidualBlock, self).__init__() 72 | self.inplanes = inplanes 73 | self.planes = planes 74 | self.conv1 = nn.Conv2d(inplanes, int(planes/4), kernel_size=1, stride=1, padding=0) 75 | self.bn1 = nn.BatchNorm2d(int(planes/4)) 76 | self.conv2 = nn.Conv2d(int(planes/4), int(planes/4), kernel_size=3, stride=stride, padding=1) 77 | self.bn2 = nn.BatchNorm2d(int(planes/4)) 78 | self.conv3 = nn.Conv2d(int(planes/4), planes, kernel_size=1, stride=1, padding=0) 79 | self.bn3 = nn.BatchNorm2d(planes) 80 | self.relu = nn.ReLU(inplace=True) 81 | self.downsample = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, padding=0), 82 | nn.BatchNorm2d(planes)) 83 | self.stride = stride 84 | 85 | def forward(self, x): 86 | identity = x 87 | 88 | out = self.conv1(x) 89 | out = self.bn1(out) 90 | 
out = self.relu(out) 91 | 92 | out = self.conv2(out) 93 | out = self.bn2(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv3(out) 97 | out = self.bn3(out) 98 | # project the shortcut only when the channel count or the stride changes 99 | if self.inplanes != self.planes or self.stride != 1: 100 | identity = self.downsample(x) 101 | 102 | out += identity 103 | out = self.relu(out) 104 | 105 | return out 106 | 107 | 108 | class UpSampleBlock(nn.Module): 109 | def __init__(self, in_channels, upscale_factor=2): 110 | super(UpSampleBlock, self).__init__() 111 | 112 | self.conv_1 = conv2DBatchNormRelu(in_channels=in_channels, 113 | n_filters=in_channels/2, 114 | k_size=1, 115 | stride=1, 116 | padding=0) 117 | 118 | self.conv_2 = conv2DBatchNormRelu(in_channels=in_channels, 119 | n_filters=in_channels/2, 120 | k_size=3, 121 | stride=1, 122 | dilation=2, 123 | padding=2) 124 | 125 | self.conv_3 = conv2DBatchNormRelu(in_channels=in_channels, 126 | n_filters=in_channels/2, 127 | k_size=5, 128 | stride=1, 129 | dilation=1, 130 | padding=2) 131 | 132 | self.conv_4 = conv2DBatchNormRelu(in_channels=in_channels, 133 | n_filters=in_channels/2, 134 | k_size=7, 135 | stride=1, 136 | dilation=1, 137 | padding=3) 138 | 139 | self.sub_pixel = nn.PixelShuffle(upscale_factor) 140 | 141 | def forward(self, x): 142 | # concat the four multi-scale branches (4 x in_channels/2 channels); PixelShuffle(2) then yields in_channels/2 channels at 2x resolution, which is what the decoder's torch.cat calls expect 143 | x = torch.cat([self.conv_1(x), self.conv_2(x), self.conv_3(x), self.conv_4(x)], 1) 144 | x = self.sub_pixel(x) 145 | return x 146 | 147 | 148 | class TAM(nn.Module): 149 | def __init__(self, in_channels): 150 | super(TAM, self).__init__() 151 | self.BU_conv1 = nn.Conv2d(in_channels*2, in_channels, kernel_size=1, stride=1, padding=0) 152 | self.BU_conv2 = nn.Conv2d(in_channels*2, in_channels, kernel_size=1, stride=1, padding=0) 153 | self.BU_sigmoid = nn.Sigmoid() 
154 | 155 | self.downsample1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 156 | self.downsample2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 157 | 158 | self.rb1 = ResidualBlock(in_channels, in_channels) 159 | self.rb2 =
ResidualBlock(in_channels, in_channels) 160 | self.rb3 = ResidualBlock(in_channels, in_channels) 161 | self.rb4 = ResidualBlock(in_channels, in_channels) 162 | 163 | self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=2) 164 | self.upsample2 = nn.UpsamplingBilinear2d(scale_factor=2) 165 | 166 | self.M_sigmoid = nn.Sigmoid() 167 | 168 | self.conv = nn.Conv2d(in_channels*2, in_channels, kernel_size=1, stride=1, padding=0) 169 | 170 | def forward(self, fs, fd): 171 | # temp->B (balance gate between depth and segmentation features) 172 | temp = self.BU_sigmoid(self.BU_conv1(torch.cat([fs, fd], 1))) 173 | temp = self.BU_conv2(torch.cat([temp * fd, (1 - temp) * fs], 1)) 174 | 175 | temp = self.downsample1(temp) 176 | temp = self.rb1(temp) 177 | temp = self.downsample2(temp) 178 | temp = self.rb2(temp) 179 | temp = self.upsample1(temp) 180 | temp = self.rb3(temp) 181 | temp = self.upsample2(temp) 182 | temp = self.rb4(temp) 183 | 184 | # temp -> M 185 | temp = self.M_sigmoid(temp) 186 | 187 | temp = self.conv(torch.cat([(1+temp)*fd, (1+temp)*fs], 1)) 188 | 189 | return temp 190 | 191 | 192 | class ResidualBlockDecode(nn.Module): 193 | def __init__(self, in_channels, out_channels): 194 | super(ResidualBlockDecode, self).__init__() 195 | self.bottleneck1 = ResidualBlock(in_channels, out_channels) 196 | self.bottleneck2 = ResidualBlock(out_channels, out_channels) 197 | 198 | def forward(self, x): 199 | x = self.bottleneck1(x) 200 | x = self.bottleneck2(x) 201 | return x 202 | 203 | 204 | class TRL(nn.Module): 205 | def __init__(self, n_classes): 206 | super(TRL, self).__init__() 207 | resnet50 = models.resnet50() 208 | self.conv1 = nn.Sequential(*list(resnet50.children())[:4]) 209 | self.res_2 = resnet50.layer1 210 | self.res_3 = resnet50.layer2 211 | self.res_4 = resnet50.layer3 212 | self.res_5 = resnet50.layer4 213 | 214 | self.upsample_res_5 = UpSampleBlock(2048) 215 | self.upsample_res_d1 = UpSampleBlock(1024) 216 | self.upsample_res_d2 = UpSampleBlock(1024) 217 | 
self.upsample_res_d3 = UpSampleBlock(512) 218 |
self.upsample_res_d4 = UpSampleBlock(512) 219 | self.upsample_res_d5 = UpSampleBlock(256) 220 | self.upsample_res_d6 = UpSampleBlock(256) 221 | 222 | self.TAM_res_d3 = TAM(512) 223 | self.TAM_res_d4 = TAM(512) 224 | self.TAM_res_d5 = TAM(256) 225 | self.TAM_res_d6 = TAM(256) 226 | self.TAM_res_d7 = TAM(128) 227 | self.TAM_res_d8 = TAM(128) 228 | 229 | self.res_d1 = ResidualBlockDecode(2048, 1024) 230 | self.res_d2 = ResidualBlockDecode(3072, 1024) 231 | self.res_d3 = ResidualBlockDecode(1024, 512) 232 | self.res_d4 = ResidualBlockDecode(1024, 512) 233 | self.res_d5 = ResidualBlockDecode(512, 256) 234 | self.res_d6 = ResidualBlockDecode(512, 256) 235 | self.res_d7 = ResidualBlockDecode(128, 128) 236 | self.res_d8 = ResidualBlockDecode(128, 128) 237 | 238 | self.conv_d1 = nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=1) 239 | self.conv_d2 = nn.Conv2d(in_channels=1024, out_channels=n_classes, kernel_size=1) 240 | self.conv_d3 = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1) 241 | self.conv_d4 = nn.Conv2d(in_channels=512, out_channels=n_classes, kernel_size=1) 242 | self.conv_d5 = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1) 243 | self.conv_d6 = nn.Conv2d(in_channels=256, out_channels=n_classes, kernel_size=1) 244 | self.conv_d7 = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1) 245 | self.conv_d8 = nn.Conv2d(in_channels=128, out_channels=n_classes, kernel_size=1) 246 | 247 | def forward(self, x): 248 | x = self.conv1(x) 249 | x2 = self.res_2(x) 250 | x3 = self.res_3(x2) 251 | x4 = self.res_4(x3) 252 | x = self.res_5(x4) 253 | 254 | x = self.upsample_res_5(x) 255 | r1 = self.res_d1(torch.cat([x, x4], 1)) 256 | depth_1 = self.conv_d1(r1) 257 | r2 = self.res_d2(torch.cat([x, r1, x4], 1)) 258 | segmentation_1 = self.conv_d2(r2) 259 | 260 | x = self.upsample_res_d2(r2) 261 | r1 = self.res_d3(torch.cat([self.TAM_res_d3(x, self.upsample_res_d1(r1)), x3], 1)) 262 | depth_2 = self.conv_d3(r1) 263 | r2 = 
self.res_d4(torch.cat([self.TAM_res_d4(x, r1), x3], 1)) 264 | segmentation_2 = self.conv_d4(r2) 265 | 266 | x = self.upsample_res_d4(r2) 267 | r1 = self.res_d5(torch.cat([self.TAM_res_d5(x, self.upsample_res_d3(r1)), x2], 1)) 268 | depth_3 = self.conv_d5(r1) 269 | r2 = self.res_d6(torch.cat([self.TAM_res_d6(x, r1), x2], 1)) 270 | segmentation_3 = self.conv_d6(r2) 271 | 272 | x = self.upsample_res_d6(r2) 273 | r1 = self.res_d7(self.TAM_res_d7(x, self.upsample_res_d5(r1))) 274 | depth_4 = self.conv_d7(r1) 275 | r2 = self.res_d8(self.TAM_res_d8(x, r1)) 276 | segmentation_4 = self.conv_d8(r2) 277 | 278 | return [(depth_1, segmentation_1), 279 | (depth_2, segmentation_2), 280 | (depth_3, segmentation_3), 281 | (depth_4, segmentation_4)] 282 | 283 | 284 | if __name__ == '__main__': 285 | import time 286 | from utils import params_size 287 | x = torch.rand((1, 3, 512, 1024)) 288 | model = TRL(19) 289 | params_size(model) 290 | t1 = time.time() 291 | x = model(x) 292 | for each in x: 293 | for i in each: 294 | print(i.size()) 295 | print(time.time()-t1) 296 | # t1 = time.time() 297 | # resnet50 = models.resnet50() 298 | # params_size(resnet50) 299 | # print(time.time()-t1) 300 | 301 | 302 | 303 | 304 | 305 | -------------------------------------------------------------------------------- /schedulers.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | 4 | class ConstantLR(_LRScheduler): 5 | def __init__(self, optimizer, last_epoch=-1): 6 | super(ConstantLR, self).__init__(optimizer, last_epoch) 7 | 8 | def get_lr(self): 9 | return [base_lr for base_lr in self.base_lrs] 10 | 11 | 12 | class PolynomialLR(_LRScheduler): 13 | def __init__(self, optimizer, max_iter, decay_iter=1, gamma=0.9, last_epoch=-1): 14 | self.decay_iter = decay_iter 15 | self.max_iter = max_iter 16 | self.gamma = gamma 17 | super(PolynomialLR, self).__init__(optimizer, last_epoch) 18 | 19 | def get_lr(self): 20 
| # if self.last_epoch % self.decay_iter or self.last_epoch % self.max_iter: 21 | # print('hhh') 22 | # return [base_lr for base_lr in self.base_lrs] 23 | # else: 24 | factor = (1 - self.last_epoch / float(self.max_iter)) ** self.gamma 25 | return [base_lr * factor for base_lr in self.base_lrs] 26 | 27 | 28 | class WarmUpLR(_LRScheduler): 29 | def __init__( 30 | self, optimizer, scheduler, mode="linear", warmup_iters=100, gamma=0.2, last_epoch=-1 31 | ): 32 | self.mode = mode 33 | self.scheduler = scheduler 34 | self.warmup_iters = warmup_iters 35 | self.gamma = gamma 36 | super(WarmUpLR, self).__init__(optimizer, last_epoch) 37 | 38 | def get_lr(self): 39 | cold_lrs = self.scheduler.get_lr() 40 | 41 | if self.last_epoch < self.warmup_iters: 42 | if self.mode == "linear": 43 | alpha = self.last_epoch / float(self.warmup_iters) 44 | factor = self.gamma * (1 - alpha) + alpha 45 | 46 | elif self.mode == "constant": 47 | factor = self.gamma 48 | else: 49 | raise KeyError("WarmUp type {} not implemented".format(self.mode)) 50 | 51 | return [factor * base_lr for base_lr in cold_lrs] 52 | 53 | return cold_lrs 54 | 55 | 56 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import time 4 | import shutil 5 | import torch 6 | import random 7 | 8 | from tqdm import tqdm 9 | from models import icnet, spnet 10 | from torch import optim 11 | from loss import Loss 12 | from utils import get_logger 13 | from schedulers import PolynomialLR 14 | from torch.utils import data 15 | from torchvision.transforms import Compose 16 | from augmentations import * 17 | from cityscapes_loader import cityscapesLoader 18 | from tensorboardX import SummaryWriter 19 | from metrics import SegmentationScore, DepthEstimateScore, averageMeter 20 | 21 | networks = {'icnet': icnet.icnet, 'spnet': spnet.spnet} 22 | 23 | 24 | def train(cfg, writer, 
logger): 25 | 26 | # Setup seeds 27 | torch.manual_seed(cfg.get("seed", 1337)) 28 | torch.cuda.manual_seed(cfg.get("seed", 1337)) 29 | np.random.seed(cfg.get("seed", 1337)) 30 | random.seed(cfg.get("seed", 1337)) 31 | 32 | # Setup device 33 | device = torch.device(cfg['device']) 34 | 35 | # Setup Metrics 36 | seg_scores = SegmentationScore() 37 | depth_scores = DepthEstimateScore() 38 | 39 | augmentations = Compose([ 40 | RandomRotate(cfg['training']['argumentation']['random_rotate']), 41 | RandomCrop(cfg['training']['img_size']), 42 | RandomHorizonFlip(cfg['training']['argumentation']['random_hflip']), 43 | ]) 44 | 45 | traindata = cityscapesLoader(cfg['data']['path'], 46 | img_size=cfg['training']['img_size'], 47 | split=cfg['data']['train_split'], 48 | is_transform=True, 49 | augmentations=augmentations) 50 | 51 | valdata = cityscapesLoader(cfg['data']['path'], 52 | img_size=cfg['training']['img_size'], 53 | split=cfg['data']['val_split'], 54 | is_transform=True) 55 | 56 | trainloader = data.DataLoader(traindata, batch_size=cfg['training']['batch_size'], shuffle=True) 57 | valloader = data.DataLoader(valdata, batch_size=cfg['training']['batch_size']) 58 | 59 | # Setup Model 60 | model = networks[cfg['arch']](**cfg['model']) 61 | 62 | model.to(device) 63 | loss_fn = Loss(**cfg['training']['loss']).to(device) 64 | # Setup optimizer and lr_scheduler 65 | optimizer = optim.SGD(model.parameters(), **cfg['training']['optimizer']) 66 | # TODO 67 | # optimizer_loss = optim.SGD(loss_fn.parameters(), **cfg['training']['optimizer_loss']) 68 | 69 | scheduler = PolynomialLR(optimizer, max_iter=cfg['training']['train_iters'], **cfg['training']['schedule']) 70 | # TODO 71 | # scheduler_loss = ConstantLR(optimizer_loss) 72 | 73 | start_iter = 0 74 | if cfg["training"]["resume"] is not None: 75 | if os.path.isfile(cfg["training"]["resume"]): 76 | logger.info( 77 | "Loading model and optimizer from checkpoint '{}'".format(cfg["training"]["resume"]) 78 | ) 79 | checkpoint =
torch.load(cfg["training"]["resume"]) 80 | model.load_state_dict(checkpoint["model_state"]) 81 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 82 | scheduler.load_state_dict(checkpoint["scheduler_state"]) 83 | start_iter = checkpoint["epoch"] 84 | logger.info( 85 | "Loaded checkpoint '{}' (iter {})".format( 86 | cfg["training"]["resume"], checkpoint["epoch"] 87 | ) 88 | ) 89 | else: 90 | logger.info("No checkpoint found at '{}'".format(cfg["training"]["resume"])) 91 | 92 | val_loss_meter = averageMeter() 93 | time_meter = averageMeter() 94 | 95 | best_miou = -100.0 96 | best_abs_rel = float('inf') 97 | i = start_iter 98 | flag = True 99 | optimizer.zero_grad() 100 | while i <= cfg["training"]["train_iters"] and flag: 101 | for sample in trainloader: 102 | i += 1 103 | start_ts = time.time() 104 | scheduler.step() 105 | # TODO 106 | # scheduler_loss.step() 107 | model.train() 108 | images = sample['image'].to(device) 109 | labels = sample['label'].to(device) 110 | depths = sample['depth'].to(device) 111 | 112 | # TODO 113 | # optimizer_loss.zero_grad() 114 | outputs = model(images) 115 | loss = loss_fn(outputs, depths, labels) / cfg['training']['accu_steps'] 116 | loss.backward() 117 | if i % cfg['training']['accu_steps'] == 0: 118 | optimizer.step() 119 | optimizer.zero_grad() 120 | # TODO 121 | # optimizer_loss.step() 122 | 123 | time_meter.update(time.time() - start_ts) 124 | 125 | if (i) % cfg["training"]["print_interval"] == 0: 126 | fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}" 127 | print_str = fmt_str.format( 128 | i, 129 | cfg["training"]["train_iters"], 130 | loss.item()*cfg['training']['accu_steps'], 131 | time_meter.avg / cfg["training"]["batch_size"], 132 | ) 133 | 134 | logger.info(print_str) 135 | writer.add_scalar("loss/train_loss", loss.item()*cfg['training']['accu_steps'], i) 136 | writer.add_scalar('param/delta1', loss_fn.delta1, i) 137 | writer.add_scalar('param/delta2', loss_fn.delta2, i) 138 | 
writer.add_scalar('param/learning-rate', scheduler.get_lr()[0], i) 139 | time_meter.reset() 140 | 141 | if i % cfg["training"]["val_interval"] == 0 or i == cfg["training"]["train_iters"]: 142 | model.eval() 143 | with torch.no_grad(): 144 | for i_val, sample in tqdm(enumerate(valloader)): 145 | images_val = sample['image'].to(device) 146 | labels_val = sample['label'].to(device) 147 | depths_val = sample['depth'].to(device) 148 | outputs = model(images_val) 149 | val_loss = loss_fn(outputs, depths_val, labels_val) 150 | 151 | depth_scores.update(depths_val.cpu().numpy(), outputs[-1][0].data.cpu().numpy()) 152 | seg_scores.update(labels_val.cpu().numpy(), outputs[-1][1].data.max(1)[1].cpu().numpy()) 153 | 154 | val_loss_meter.update(val_loss.item()) 155 | 156 | writer.add_scalar("loss/val_loss", val_loss_meter.avg, i) 157 | logger.info("Iter %d Loss: %.4f" % (i, val_loss_meter.avg)) 158 | 159 | seg_score, class_iou = seg_scores.get_scores() 160 | for k, v in seg_score.items(): 161 | logger.info("{}: {}".format(k, v)) 162 | writer.add_scalar("seg_val_metrics/{}".format(k), v, i) 163 | 164 | for k, v in class_iou.items(): 165 | # logger.info("{}: {}".format(k, v)) 166 | writer.add_scalar("seg_val_metrics/cls_{}".format(k), v, i) 167 | 168 | depth_score = depth_scores.get_scores() 169 | for k, v in depth_score.items(): 170 | logger.info("{}: {}".format(k, v)) 171 | writer.add_scalar("depth_val_metrics/{}".format(k), v, i) 172 | 173 | val_loss_meter.reset() 174 | seg_scores.reset() 175 | depth_scores.reset() 176 | 177 | # if seg_score["Mean IoU : \t"] >= best_miou and depth_score['abs_rel'] <= best_abs_rel: 178 | if seg_score["Mean IoU : \t"] >= best_miou: 179 | best_miou = seg_score["Mean IoU : \t"] 180 | best_abs_rel = depth_score['abs_rel'] 181 | state = { 182 | "epoch": i + 1, 183 | "model_state": model.state_dict(), 184 | "optimizer_state": optimizer.state_dict(), 185 | "scheduler_state": scheduler.state_dict(), 186 | "best_iou": best_miou, 187 | 
"best_abs_rel":
best_abs_rel 188 | } 189 | save_path = os.path.join( 190 | writer.file_writer.get_logdir(), 191 | "{}_{}_best_model.pth".format(cfg['arch'], cfg["data"]["dataset"]), 192 | ) 193 | torch.save(state, save_path) 194 | 195 | state = { 196 | "epoch": i + 1, 197 | "model_state": model.state_dict(), 198 | "optimizer_state": optimizer.state_dict(), 199 | "scheduler_state": scheduler.state_dict(), 200 | } 201 | save_path = os.path.join( 202 | writer.file_writer.get_logdir(), 203 | "{}_{}_{}_model.pth".format(i, cfg['arch'], cfg["data"]["dataset"]), 204 | ) 205 | torch.save(state, save_path) 206 | 207 | if i == cfg["training"]["train_iters"]: 208 | flag = False 209 | break 210 | 211 | 212 | if __name__ == "__main__": 213 | config_file = "config/spnet-cityscapes.yml" 214 | with open(config_file) as fp: 215 | cfg = yaml.safe_load(fp) 216 | 217 | # run_id = random.randint(1, 100000) 218 | run_id = 607 219 | logdir = os.path.join("runs", os.path.basename(config_file)[:-4], str(run_id)) 220 | writer = SummaryWriter(log_dir=logdir) 221 | print("RUNDIR: {}".format(logdir)) 222 | shutil.copy(config_file, logdir) 223 | 224 | logger = get_logger(logdir, config_file.split('/')[-1].split('-')[0]) 225 | logger.info("Let the games begin") 226 | 227 | train(cfg, writer, logger) 228 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | import datetime 5 | import torch.nn as nn 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | 11 | from torch.autograd import Variable 12 | 13 | class conv2DBatchNorm(nn.Module): 14 | def __init__( 15 | self, 16 | in_channels, 17 | n_filters, 18 | k_size, 19 | stride, 20 | padding, 21 | bias=True, 22 | dilation=1, 23 | is_batchnorm=True, 24 | ): 25 | super(conv2DBatchNorm, self).__init__() 26 | 27 | 
conv_mod = nn.Conv2d( 28 | int(in_channels), 29 | int(n_filters), 30 | kernel_size=k_size, 31 | padding=padding, 32 | stride=stride, 33 | bias=bias, 34 | dilation=dilation, 35 | ) 36 | 37 | if is_batchnorm: 38 | self.cb_unit = nn.Sequential(conv_mod, nn.BatchNorm2d(int(n_filters))) 39 | else: 40 | self.cb_unit = nn.Sequential(conv_mod) 41 | 42 | def forward(self, inputs): 43 | outputs = self.cb_unit(inputs) 44 | return outputs 45 | 46 | 47 | class conv2DGroupNorm(nn.Module): 48 | def __init__( 49 | self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, n_groups=16 50 | ): 51 | super(conv2DGroupNorm, self).__init__() 52 | 53 | conv_mod = nn.Conv2d( 54 | int(in_channels), 55 | int(n_filters), 56 | kernel_size=k_size, 57 | padding=padding, 58 | stride=stride, 59 | bias=bias, 60 | dilation=dilation, 61 | ) 62 | 63 | self.cg_unit = nn.Sequential(conv_mod, nn.GroupNorm(n_groups, int(n_filters))) 64 | 65 | def forward(self, inputs): 66 | outputs = self.cg_unit(inputs) 67 | return outputs 68 | 69 | 70 | class deconv2DBatchNorm(nn.Module): 71 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): 72 | super(deconv2DBatchNorm, self).__init__() 73 | 74 | self.dcb_unit = nn.Sequential( 75 | nn.ConvTranspose2d( 76 | int(in_channels), 77 | int(n_filters), 78 | kernel_size=k_size, 79 | padding=padding, 80 | stride=stride, 81 | bias=bias, 82 | ), 83 | nn.BatchNorm2d(int(n_filters)), 84 | ) 85 | 86 | def forward(self, inputs): 87 | outputs = self.dcb_unit(inputs) 88 | return outputs 89 | 90 | 91 | class conv2DBatchNormRelu(nn.Module): 92 | def __init__( 93 | self, 94 | in_channels, 95 | n_filters, 96 | k_size, 97 | stride, 98 | padding, 99 | bias=True, 100 | dilation=1, 101 | is_batchnorm=True, 102 | ): 103 | super(conv2DBatchNormRelu, self).__init__() 104 | 105 | conv_mod = nn.Conv2d( 106 | int(in_channels), 107 | int(n_filters), 108 | kernel_size=k_size, 109 | padding=padding, 110 | stride=stride, 111 | bias=bias, 112 | 
dilation=dilation, 113 | ) 114 | 115 | if is_batchnorm: 116 | self.cbr_unit = nn.Sequential( 117 | conv_mod, nn.BatchNorm2d(int(n_filters)), nn.ReLU(inplace=True) 118 | ) 119 | else: 120 | self.cbr_unit = nn.Sequential(conv_mod, nn.ReLU(inplace=True)) 121 | 122 | def forward(self, inputs): 123 | outputs = self.cbr_unit(inputs) 124 | return outputs 125 | 126 | 127 | class conv2DGroupNormRelu(nn.Module): 128 | def __init__( 129 | self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, n_groups=16 130 | ): 131 | super(conv2DGroupNormRelu, self).__init__() 132 | 133 | conv_mod = nn.Conv2d( 134 | int(in_channels), 135 | int(n_filters), 136 | kernel_size=k_size, 137 | padding=padding, 138 | stride=stride, 139 | bias=bias, 140 | dilation=dilation, 141 | ) 142 | 143 | self.cgr_unit = nn.Sequential( 144 | conv_mod, nn.GroupNorm(n_groups, int(n_filters)), nn.ReLU(inplace=True) 145 | ) 146 | 147 | def forward(self, inputs): 148 | outputs = self.cgr_unit(inputs) 149 | return outputs 150 | 151 | 152 | class deconv2DBatchNormRelu(nn.Module): 153 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): 154 | super(deconv2DBatchNormRelu, self).__init__() 155 | 156 | self.dcbr_unit = nn.Sequential( 157 | nn.ConvTranspose2d( 158 | int(in_channels), 159 | int(n_filters), 160 | kernel_size=k_size, 161 | padding=padding, 162 | stride=stride, 163 | bias=bias, 164 | ), 165 | nn.BatchNorm2d(int(n_filters)), 166 | nn.ReLU(inplace=True), 167 | ) 168 | 169 | def forward(self, inputs): 170 | outputs = self.dcbr_unit(inputs) 171 | return outputs 172 | 173 | 174 | class unetConv2(nn.Module): 175 | def __init__(self, in_size, out_size, is_batchnorm): 176 | super(unetConv2, self).__init__() 177 | 178 | if is_batchnorm: 179 | self.conv1 = nn.Sequential( 180 | nn.Conv2d(in_size, out_size, 3, 1, 0), nn.BatchNorm2d(out_size), nn.ReLU() 181 | ) 182 | self.conv2 = nn.Sequential( 183 | nn.Conv2d(out_size, out_size, 3, 1, 0), nn.BatchNorm2d(out_size), 
                nn.ReLU()
            )
        else:
            self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 0), nn.ReLU())
            self.conv2 = nn.Sequential(nn.Conv2d(out_size, out_size, 3, 1, 0), nn.ReLU())

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs


class unetUp(nn.Module):
    def __init__(self, in_size, out_size, is_deconv):
        super(unetUp, self).__init__()
        self.conv = unetConv2(in_size, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, inputs1, inputs2):
        outputs2 = self.up(inputs2)
        offset = outputs2.size()[2] - inputs1.size()[2]
        padding = 2 * [offset // 2, offset // 2]
        outputs1 = F.pad(inputs1, padding)
        return self.conv(torch.cat([outputs1, outputs2], 1))


class segnetDown2(nn.Module):
    def __init__(self, in_size, out_size):
        super(segnetDown2, self).__init__()
        self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
        self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        unpooled_shape = outputs.size()
        outputs, indices = self.maxpool_with_argmax(outputs)
        return outputs, indices, unpooled_shape


class segnetDown3(nn.Module):
    def __init__(self, in_size, out_size):
        super(segnetDown3, self).__init__()
        self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
        self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
        self.conv3 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        outputs = self.conv3(outputs)
        unpooled_shape = outputs.size()
        outputs, indices = self.maxpool_with_argmax(outputs)
        return outputs, indices, unpooled_shape


class segnetUp2(nn.Module):
    def __init__(self, in_size, out_size):
        super(segnetUp2, self).__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
        self.conv2 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)

    def forward(self, inputs, indices, output_shape):
        outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
        outputs = self.conv1(outputs)
        outputs = self.conv2(outputs)
        return outputs


class segnetUp3(nn.Module):
    def __init__(self, in_size, out_size):
        super(segnetUp3, self).__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
        self.conv2 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
        self.conv3 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)

    def forward(self, inputs, indices, output_shape):
        outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
        outputs = self.conv1(outputs)
        outputs = self.conv2(outputs)
        outputs = self.conv3(outputs)
        return outputs


class residualBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, n_filters, stride=1, downsample=None):
        super(residualBlock, self).__init__()

        self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, 1, bias=False)
        self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, bias=False)
        self.downsample = downsample
        self.stride = stride
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x

        out = self.convbnrelu1(x)
        out = self.convbn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class residualBottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, n_filters, stride=1, downsample=None):
        super(residualBottleneck, self).__init__()
        # conv2DBatchNorm is the module defined in this file (torch.nn has no Conv2DBatchNorm);
        # stride and padding are required arguments of its constructor
        self.convbn1 = conv2DBatchNorm(
            in_channels, n_filters, k_size=1, stride=1, padding=0, bias=False
        )
        self.convbn2 = conv2DBatchNorm(
            n_filters, n_filters, k_size=3, stride=stride, padding=1, bias=False
        )
        self.convbn3 = conv2DBatchNorm(
            n_filters, n_filters * 4, k_size=1, stride=1, padding=0, bias=False
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.convbn1(x)
        out = self.convbn2(out)
        out = self.convbn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class linknetUp(nn.Module):
    def __init__(self, in_channels, n_filters):
        super(linknetUp, self).__init__()

        # B, 2C, H, W -> B, C/2, H, W (1x1 conv, no padding, so H and W are unchanged)
        self.convbnrelu1 = conv2DBatchNormRelu(
            in_channels, n_filters / 2, k_size=1, stride=1, padding=0
        )

        # B, C/2, H, W -> B, C/2, 2H+1, 2W+1 (3x3 transposed conv, stride 2, no padding)
        self.deconvbnrelu2 = deconv2DBatchNormRelu(
            n_filters / 2, n_filters / 2, k_size=3, stride=2, padding=0
        )

        # B, C/2, H, W -> B, C, H, W
        self.convbnrelu3 = conv2DBatchNormRelu(
            n_filters / 2, n_filters, k_size=1, stride=1, padding=0
        )

    def forward(self, x):
        x = self.convbnrelu1(x)
        x = self.deconvbnrelu2(x)
        x = self.convbnrelu3(x)
        return x


class FRRU(nn.Module):
    """
    Full Resolution Residual Unit for FRRN
    """

    def __init__(self, prev_channels, out_channels, scale, group_norm=False, n_groups=None):
        super(FRRU, self).__init__()
        self.scale = scale
        self.prev_channels = prev_channels
        self.out_channels = out_channels
        self.group_norm = group_norm
        self.n_groups = n_groups

        if self.group_norm:
            conv_unit = conv2DGroupNormRelu
            self.conv1 = conv_unit(
                prev_channels + 32,
                out_channels,
                k_size=3,
                stride=1,
                padding=1,
                bias=False,
                n_groups=self.n_groups,
            )
            self.conv2 = conv_unit(
                out_channels,
                out_channels,
                k_size=3,
                stride=1,
                padding=1,
                bias=False,
                n_groups=self.n_groups,
            )

        else:
            conv_unit = conv2DBatchNormRelu
            self.conv1 = conv_unit(
                prev_channels + 32, out_channels, k_size=3, stride=1, padding=1, bias=False
            )
            self.conv2 = conv_unit(
                out_channels, out_channels, k_size=3, stride=1, padding=1, bias=False
            )

        self.conv_res = nn.Conv2d(out_channels, 32, kernel_size=1, stride=1, padding=0)

    def forward(self, y, z):
        x = torch.cat([y, nn.MaxPool2d(self.scale, self.scale)(z)], dim=1)
        y_prime = self.conv1(x)
        y_prime = self.conv2(y_prime)

        x = self.conv_res(y_prime)
        upsample_size = torch.Size([_s * self.scale for _s in y_prime.shape[-2:]])
        x = F.interpolate(x, size=upsample_size, mode="nearest")
        z_prime = z + x

        return y_prime, z_prime


class RU(nn.Module):
    """
    Residual Unit for FRRN
    """

    def __init__(self, channels, kernel_size=3, strides=1, group_norm=False, n_groups=None):
        super(RU, self).__init__()
        self.group_norm = group_norm
        self.n_groups = n_groups

        if self.group_norm:
            self.conv1 = conv2DGroupNormRelu(
                channels,
                channels,
                k_size=kernel_size,
                stride=strides,
                padding=1,
                bias=False,
                n_groups=self.n_groups,
            )
            self.conv2 = conv2DGroupNorm(
                channels,
                channels,
                k_size=kernel_size,
                stride=strides,
                padding=1,
                bias=False,
                n_groups=self.n_groups,
            )

        else:
            self.conv1 = conv2DBatchNormRelu(
                channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False
            )
            self.conv2 = conv2DBatchNorm(
                channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False
            )

    def forward(self, x):
        incoming = x
        x = self.conv1(x)
        x = self.conv2(x)
        return x + incoming


class residualConvUnit(nn.Module):
    def __init__(self, channels, kernel_size=3):
        super(residualConvUnit, self).__init__()

        # "same" padding keeps H and W so the residual sum in forward() is shape-compatible;
        # the first ReLU is not in-place so the residual keeps the un-activated input
        self.residual_conv_unit = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size // 2),
        )

    def forward(self, x):
        input = x
        x = self.residual_conv_unit(x)
        return x + input


class multiResolutionFusion(nn.Module):
    def __init__(self, channels, up_scale_high, up_scale_low, high_shape, low_shape):
        super(multiResolutionFusion, self).__init__()

        self.up_scale_high = up_scale_high
        self.up_scale_low = up_scale_low

        self.conv_high = nn.Conv2d(high_shape[1], channels, kernel_size=3)

        if low_shape is not None:
            self.conv_low = nn.Conv2d(low_shape[1], channels, kernel_size=3)

    def forward(self, x_high, x_low):
        high_upsampled = F.interpolate(
            self.conv_high(x_high),
            scale_factor=self.up_scale_high,
            mode="bilinear",
            align_corners=False,
        )

        if x_low is None:
            return high_upsampled

        low_upsampled = F.interpolate(
            self.conv_low(x_low),
            scale_factor=self.up_scale_low,
            mode="bilinear",
            align_corners=False,
        )

        return low_upsampled + high_upsampled


class chainedResidualPooling(nn.Module):
    def __init__(self, channels, input_shape):
        super(chainedResidualPooling, self).__init__()

        # padding=1 keeps H and W so the residual sum in forward() is shape-compatible
        self.chained_residual_pooling = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.MaxPool2d(5, 1, 2),
            nn.Conv2d(input_shape[1], channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        input = x
        x = self.chained_residual_pooling(x)
        return x + input


class pyramidPooling(nn.Module):
    def __init__(
        self, in_channels, pool_sizes, model_name="pspnet", fusion_mode="cat", is_batchnorm=True
    ):
        super(pyramidPooling, self).__init__()

        bias = not is_batchnorm

        self.paths = []
        for i in range(len(pool_sizes)):
            self.paths.append(
                conv2DBatchNormRelu(
                    in_channels,
                    int(in_channels / len(pool_sizes)),
                    1,
                    1,
                    0,
                    bias=bias,
                    is_batchnorm=is_batchnorm,
                )
            )

        self.path_module_list = nn.ModuleList(self.paths)
        self.pool_sizes = pool_sizes
        self.model_name = model_name
        self.fusion_mode = fusion_mode

    def forward(self, x):
        h, w = x.shape[2:]

        if self.training or self.model_name != "icnet":  # general settings or pspnet
            k_sizes = []
            strides = []
            for pool_size in self.pool_sizes:
                k_sizes.append((int(h / pool_size), int(w / pool_size)))
                strides.append((int(h / pool_size), int(w / pool_size)))
        else:  # eval mode and icnet: pre-trained for 1025 x 2049
            k_sizes = [(8, 15), (13, 25), (17, 33), (33, 65)]
            strides = [(5, 10), (10, 20), (16, 32), (33, 65)]

        if self.fusion_mode == "cat":  # pspnet: concat (including x)
            output_slices = [x]

            for i, (module, pool_size) in enumerate(zip(self.path_module_list, self.pool_sizes)):
                out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
                # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
                if self.model_name != "icnet":
                    out = module(out)
                out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
                output_slices.append(out)

            return torch.cat(output_slices, dim=1)
        else:  # icnet: element-wise sum (including x)
            pp_sum = x

            for i, (module, pool_size) in enumerate(zip(self.path_module_list, self.pool_sizes)):
                out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
                # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
                if self.model_name != "icnet":
                    out = module(out)
                out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
                pp_sum = pp_sum + out

            return pp_sum


class bottleNeckPSP(nn.Module):
    def __init__(
        self, in_channels, mid_channels, out_channels, stride, dilation=1, is_batchnorm=True
    ):
        super(bottleNeckPSP, self).__init__()

        bias = not is_batchnorm

        self.cbr1 = conv2DBatchNormRelu(
            in_channels, mid_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
        )
        if dilation > 1:
            self.cbr2 = conv2DBatchNormRelu(
                mid_channels,
                mid_channels,
                3,
                stride=stride,
                padding=dilation,
                bias=bias,
                dilation=dilation,
                is_batchnorm=is_batchnorm,
            )
        else:
            self.cbr2 = conv2DBatchNormRelu(
                mid_channels,
                mid_channels,
                3,
                stride=stride,
                padding=1,
                bias=bias,
                dilation=1,
                is_batchnorm=is_batchnorm,
            )
        self.cb3 = conv2DBatchNorm(
            mid_channels, out_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
        )
        self.cb4 = conv2DBatchNorm(
            in_channels,
            out_channels,
            1,
            stride=stride,
            padding=0,
            bias=bias,
            is_batchnorm=is_batchnorm,
        )

    def forward(self, x):
        conv = self.cb3(self.cbr2(self.cbr1(x)))
        residual = self.cb4(x)
        return F.relu(conv + residual, inplace=True)


class bottleNeckIdentifyPSP(nn.Module):
    def __init__(self, in_channels, mid_channels, stride, dilation=1, is_batchnorm=True):
        super(bottleNeckIdentifyPSP, self).__init__()

        bias = not is_batchnorm

        self.cbr1 = conv2DBatchNormRelu(
            in_channels, mid_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
        )
        if dilation > 1:
            self.cbr2 = conv2DBatchNormRelu(
                mid_channels,
                mid_channels,
                3,
                stride=1,
                padding=dilation,
                bias=bias,
                dilation=dilation,
                is_batchnorm=is_batchnorm,
            )
        else:
            self.cbr2 = conv2DBatchNormRelu(
                mid_channels,
                mid_channels,
                3,
                stride=1,
                padding=1,
                bias=bias,
                dilation=1,
                is_batchnorm=is_batchnorm,
            )
        self.cb3 = conv2DBatchNorm(
            mid_channels, in_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
        )

    def forward(self, x):
        residual = x
        x = self.cb3(self.cbr2(self.cbr1(x)))
        return F.relu(x + residual, inplace=True)


class residualBlockPSP(nn.Module):
    def __init__(
        self,
        n_blocks,
        in_channels,
        mid_channels,
        out_channels,
        stride,
        dilation=1,
        include_range="all",
        is_batchnorm=True,
    ):
        super(residualBlockPSP, self).__init__()

        if dilation > 1:
            stride = 1

        # residualBlockPSP = convBlockPSP + identityBlockPSPs
        layers = []
        if include_range in ["all", "conv"]:
            layers.append(
                bottleNeckPSP(
                    in_channels,
                    mid_channels,
                    out_channels,
                    stride,
                    dilation,
                    is_batchnorm=is_batchnorm,
                )
            )
        if include_range in ["all", "identity"]:
            for i in range(n_blocks - 1):
                layers.append(
                    bottleNeckIdentifyPSP(
                        out_channels, mid_channels, stride, dilation, is_batchnorm=is_batchnorm
                    )
                )

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class cascadeFeatureFusion(nn.Module):
    def __init__(
        self, n_classes, low_in_channels, high_in_channels, out_channels, is_batchnorm=True
    ):
        super(cascadeFeatureFusion, self).__init__()

        bias = not is_batchnorm

        self.low_dilated_conv_bn = conv2DBatchNorm(
            low_in_channels,
            out_channels,
            3,
            stride=1,
            padding=2,
            bias=bias,
            dilation=2,
            is_batchnorm=is_batchnorm,
        )
        self.low_classifier_conv = nn.Conv2d(
            int(low_in_channels),
            int(n_classes),
            kernel_size=1,
            padding=0,
            stride=1,
            bias=True,
            dilation=1,
        )  # Train only

        self.low_depth_conv = nn.Conv2d(
            int(low_in_channels),
            1,
            kernel_size=1,
            padding=0,
            stride=1,
            bias=True,
            dilation=1,
        )
        self.high_proj_conv_bn = conv2DBatchNorm(
            high_in_channels,
            out_channels,
            1,
            stride=1,
            padding=0,
            bias=bias,
            is_batchnorm=is_batchnorm,
        )

    def forward(self, x_low, x_high):
        x_low_upsampled = F.interpolate(
            x_low, size=get_interp_size(x_low, z_factor=2), mode="bilinear", align_corners=True
        )

        low_cls = self.low_classifier_conv(x_low_upsampled)
        low_depth = self.low_depth_conv(x_low_upsampled)
        low_fm = self.low_dilated_conv_bn(x_low_upsampled)
        high_fm = self.high_proj_conv_bn(x_high)
        high_fused_fm = F.relu(low_fm + high_fm, inplace=True)

        return high_fused_fm, low_cls, low_depth


def get_interp_size(input, s_factor=1, z_factor=1):  # for caffe
    ori_h, ori_w = input.shape[2:]

    # shrink (s_factor >= 1)
    ori_h = (ori_h - 1) / s_factor + 1
    ori_w = (ori_w - 1) / s_factor + 1

    # zoom (z_factor >= 1)
    ori_h = ori_h + (ori_h - 1) * (z_factor - 1)
    ori_w = ori_w + (ori_w - 1) * (z_factor - 1)

    resize_shape = (int(ori_h), int(ori_w))
    return resize_shape


def get_multiply_scale_inputs(input, scales=(4, 8, 16)):
    inputs = []
    for s_scale in scales:
        inputs.append(
            F.interpolate(
                input,
                size=get_interp_size(input, s_factor=s_scale),
                mode="bilinear",
                align_corners=True,
            )
        )
    return tuple(inputs)


def interp(input, output_size, mode="bilinear"):
    n, c, ih, iw = input.shape
    oh, ow = output_size

    # normalize to [-1, 1]
    h = torch.arange(0, oh, dtype=torch.float, device=input.device) / (oh - 1) * 2 - 1
    w = torch.arange(0, ow, dtype=torch.float, device=input.device) / (ow - 1) * 2 - 1

    grid = torch.zeros(oh, ow, 2, dtype=torch.float, device=input.device)
    grid[:, :, 0] = w.unsqueeze(0).repeat(oh, 1)
    grid[:, :, 1] = h.unsqueeze(0).repeat(ow, 1).transpose(0, 1)
    grid = grid.unsqueeze(0).repeat(n, 1, 1, 1)  # grid.shape: [n, oh, ow, 2]
    # grid is created on input.device, so no Variable wrapper or .cuda() copy is needed

    # -1 and +1 map to the corner pixels; on PyTorch >= 1.3 pass align_corners=True
    # to keep this behavior (it was the implicit default before 1.3)
    return F.grid_sample(input, grid, mode=mode)


def get_upsampling_weight(in_channels, out_channels, kernel_size):
    """Make a 2D bilinear kernel suitable for upsampling"""
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64)
    # the two ranges are paired element-wise, so in_channels must equal out_channels
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.from_numpy(weight).float()


def scale_up(images, size):
    transform_scale = transforms.Resize(size)
    images = transform_scale(images)
    return images


def img2depth(filename):
    depth_png = np.array(Image.open(filename), dtype=int)
    # make sure we have a proper 16-bit depth map here, not 8-bit
    assert np.max(depth_png) > 255, "expected a 16-bit depth PNG"

    depth = depth_png.astype(np.float64) / 256.0
    # depth[depth_png == 0] = -1.
    return depth


def depth2img(depth, filename):
    # a 16-bit PNG holds values up to 65535, so the cast must be unsigned
    # (int16 overflows above 32767); out-of-range values are clipped first
    depth = np.clip(depth * 256, 0, 65535).astype(np.uint16)
    depth_png = Image.fromarray(depth)
    depth_png.save(filename)


def get_logger(logdir):
    logger = logging.getLogger("icnet")
    ts = str(datetime.datetime.now()).split(".")[0].replace(" ", "_")
    ts = ts.replace(":", "_").replace("-", "_")
    file_path = os.path.join(logdir, "run_{}.log".format(ts))
    hdlr = logging.FileHandler(file_path)
    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.setLevel(logging.INFO)
    return logger


def recursive_glob(rootdir=".", suffix=""):
    """Performs recursive glob with given suffix and rootdir
    :param rootdir is the root directory
    :param suffix is the suffix to be searched
    """
    return [
        os.path.join(looproot, filename)
        for looproot, _, filenames in os.walk(rootdir)
        for filename in filenames
        if filename.endswith(suffix)
    ]


def alpha_blend(input_image, segmentation_mask, alpha=0.5):
    """Alpha Blending utility to overlay RGB masks on RGB images
    :param input_image is a np.ndarray with 3 channels
    :param segmentation_mask is a np.ndarray with 3 channels
    :param alpha is a float value
    """
    blended = input_image * alpha + segmentation_mask * (1 - alpha)
    return blended


class BatchNorm1d:
    def __init__(self, train=True, momentum=0.1, eps=1e-5):
        self.train = train
        self.momentum = momentum
        self.eps = eps

        # gamma and beta must be assigned by the caller before the first forward pass
        self.gamma = None
        self.beta = None

        self.dw = None
        self.db = None

        self.sqrt = None
        self.std = None

        self.running_mean = None
        self.running_var = None

    def __call__(self, x):
        if self.train is True:
            mean = np.mean(x, axis=0, keepdims=True)
            var = np.var(x, axis=0, keepdims=True)
            sqrt = np.sqrt(var + self.eps)
            std = (x - mean) / sqrt
            self.sqrt = sqrt
            self.std = std

            if self.running_mean is None:
                self.running_mean = np.zeros_like(mean)
                self.running_var = np.ones_like(var)

            num = np.shape(x)[0]
            self.running_mean = (1 - self.momentum) * self.running_mean
            self.running_mean += self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var
            # unbiased (n - 1) variance estimate for the running statistics
            self.running_var += self.momentum * var * num / (num - 1)
        else:
            mean = self.running_mean
            var = self.running_var
            sqrt = np.sqrt(var + self.eps)
            std = (x - mean) / sqrt

        out = std * self.gamma + self.beta
        return out

    def backward(self, d_loss):
        std_t = self.std.T
        shape_t = np.shape(std_t)
        r = np.zeros([shape_t[0], shape_t[1], shape_t[1]])
        shift_eye = np.eye(shape_t[1]) * shape_t[1] - 1
        for i in range(shape_t[0]):
            r[i] = std_t[i][:, np.newaxis] * std_t[i][np.newaxis, :]
            r[i] = shift_eye - r[i]

        u = self.gamma / shape_t[1] / self.sqrt
        u = u.T
        y = r * u[:, np.newaxis]

        dx = np.zeros(shape_t)
        for i in range(shape_t[0]):
            dx[i] = np.dot(d_loss.T[i], y[i])
        dx = dx.T

        self.dw = np.sum(self.std * d_loss, axis=0)
        self.db = np.sum(d_loss, axis=0)

        return dx


def params_size(net):
    # total number of parameters, printed in millions
    k = sum(p.numel() for p in net.parameters())
    print("Total parameters: %.4fM" % (k / 10 ** 6))


def convert_param_dict(model_path):
    # strip the "module." prefix that nn.DataParallel adds to parameter names
    param_dict = torch.load(model_path)["model_state"]
    param_dict = {k.replace("module.", ""): v for k, v in param_dict.items()}
    return param_dict

--------------------------------------------------------------------------------
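`get_interp_size` computes resize targets following the `# for caffe` convention noted in the code: shrinking keeps the first and last pixels aligned and divides the `h - 1` gaps by `s_factor`, and zooming inserts `z_factor - 1` pixels into every gap. A minimal sketch of the same arithmetic on plain integers, so it can be checked without a tensor; the helper name `caffe_interp_size` is hypothetical, and the example sizes assume the 1025 x 2049 input mentioned in `pyramidPooling`.

```python
def caffe_interp_size(h, w, s_factor=1, z_factor=1):
    """Hypothetical helper: the get_interp_size arithmetic on plain ints."""
    # shrink: keep the corner pixels and divide the (h - 1) gaps by s_factor
    h = (h - 1) / s_factor + 1
    w = (w - 1) / s_factor + 1
    # zoom: insert (z_factor - 1) extra pixels into every gap
    h = h + (h - 1) * (z_factor - 1)
    w = w + (w - 1) * (z_factor - 1)
    # int() truncates, as in get_interp_size
    return int(h), int(w)


if __name__ == "__main__":
    print(caffe_interp_size(1025, 2049, s_factor=2))  # (513, 1025)
    print(caffe_interp_size(1025, 2049, s_factor=4))  # (257, 513)
    print(caffe_interp_size(513, 1025, z_factor=2))   # (1025, 2049)
```

When `h - 1` is divisible by `s_factor` the shrink is exact and a later zoom by the same factor restores the original size; otherwise `int()` truncates (for example, 1024 with `s_factor=2` gives 512, not 512.5).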
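`get_upsampling_weight` builds the bilinear filter commonly used to initialize an `nn.ConvTranspose2d` for upsampling: a separable tent `1 - |i - center| / factor`, placed only on the channel diagonal so channels are not mixed. The sketch below reproduces that formula with NumPy and checks the property that makes it an interpolator: with stride 2, the taps that land on the same output pixel (every second tap per axis) sum to 1. The function names are hypothetical, and the sketch assumes equal input and output channel counts, which the original `range(in_channels), range(out_channels)` indexing also requires.

```python
import numpy as np


def bilinear_kernel(kernel_size):
    """Hypothetical helper: the 2D tent filter from get_upsampling_weight."""
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)


def diagonal_upsampling_weight(channels, kernel_size):
    """ConvTranspose2d weight layout (in, out, kH, kW); channel c only feeds channel c."""
    weight = np.zeros((channels, channels, kernel_size, kernel_size), dtype=np.float64)
    weight[range(channels), range(channels), :, :] = bilinear_kernel(kernel_size)
    return weight


k = bilinear_kernel(4)  # kernel_size 4 pairs with stride 2 for 2x upsampling
print(k[0])  # [0.0625 0.1875 0.1875 0.0625]
# taps that hit the same output pixel (same parity per axis) sum to 1
print(k[0::2, 0::2].sum(), k[1::2, 1::2].sum())  # 1.0 1.0
```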
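`img2depth` and `depth2img` store depth in a 16-bit PNG scaled by 256, so the stored resolution is 1/256 and the largest representable value is 65535 / 256, just under 256. A minimal NumPy round-trip sketch of that scaling without any file I/O; `encode_depth` and `decode_depth` are hypothetical names, and clipping to the `uint16` range before the cast is an assumption of this sketch.

```python
import numpy as np


def encode_depth(depth):
    """Hypothetical helper: x256 scaling as in depth2img, clipped to the uint16 range."""
    return np.clip(depth * 256, 0, 65535).astype(np.uint16)


def decode_depth(depth_png):
    """Hypothetical helper: /256 scaling back to float depth, as in img2depth."""
    return depth_png.astype(np.float64) / 256.0


depth = np.array([0.0, 1.5, 10.3, 300.0])
png = encode_depth(depth)
print(png.tolist())                # [0, 384, 2636, 65535]
print(decode_depth(png).tolist())  # [0.0, 1.5, 10.296875, 255.99609375]
```

The cast truncates, so in-range depths lose less than 1/256 (about 0.004) per pixel, while values of about 256 and above saturate at 65535.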