├── README.md
├── Submission.py
├── data-agu.jpg
├── data.jpg
├── data_agu.py
├── data_tree.jpg
├── loss.py
├── model.py
├── net.jpg
├── result.jpg
├── seg_iou.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # Winning Solution of the 2020 "Huawei Cloud Cup" AI Contest
2 | ### We are the Virtual Geographic Environments team led by Prof. Qing Zhu at Southwest Jiaotong University; visit our homepage at https://vrlab.org.cn/ to learn more. Our solution reached an online accuracy of 0.8411 in the preliminary round, ranking 2nd out of 377.
3 | 
4 | ## Overall approach:
5 | ##### Roads in remote sensing imagery vary greatly in scale, road pixels are heavily outnumbered by background, and differences in sensors, environments and construction materials make road appearance highly diverse. Building on an encoder-decoder (E-D) architecture, we propose a channel-attention-enhanced adaptive feature fusion method and design a gradient-based edge constraint module. Together they enhance spatial detail and semantic features while strengthening the feature response along road edges, enabling accurate extraction of multi-scale roads.
6 | ![Alt text](https://github.com/liaochengcsu/road_segmentation_pytorch/blob/main/net.jpg)
7 | 
8 | ## Summary of strategies:
9 | * Sample image tiles with random balanced sampling;
10 | * Apply random flipping, rotation, scaling and color-space transforms as data augmentation during training;
11 | * Use BCE + Dice as the loss function, which helps optimization on class-imbalanced samples;
12 | * Introduce a gradient-based edge feature extraction module for finer road edges;
13 | * Introduce spatial and channel attention mechanisms to improve road completeness and adaptive feature fusion, respectively;
14 | * Use an ImageNet-pretrained ResNeSt-200 backbone (see model.py);
15 | * Apply class weighting in the BCE loss;
16 | * At inference, average the predictions of the original image, its horizontal flip and its 180° rotation;
17 | * Fill holes and remove noise in the predicted results (a minimal inference sketch illustrating these test-time steps follows the model.py listing below).
18 | ![Alt text](https://github.com/liaochengcsu/road_segmentation_pytorch/blob/main/data-agu.jpg)
19 | ## Training on custom data:
20 | ##### 1. Dataset preparation. Randomly sample tiles from the original images and labels; each cropped label file keeps the same name as its image, and images and labels are saved in two separate folders. Write the file names into a csv file, as shown below:
21 | ![Alt text](https://github.com/liaochengcsu/road_segmentation_pytorch/blob/main/data_tree.jpg)
22 | ##### 2. In train.py, modify the csv file paths and the number of validation images in lines 25-27:
23 | ```
24 | train_path = r'C:\Data\Road_Seg\data\data\train/train.csv'
25 | val_path = r'C:\Data\Road_Seg\data\data\val/.csv'
26 | num_test_img = 4396
27 | ```
28 | ##### 3. Lines 122-123 and 144-145 of data_agu.py read the csv and build the absolute file paths:
29 | ```
30 | fn = os.path.join(self.file_path, "images/" + fn)
31 | label = os.path.join(self.file_path, "labels/" + lab)
32 | ```
33 | ##### 4. The ImageNet-pretrained backbone weights are downloaded automatically during training, and validation accuracy is computed after each epoch. After training, submit the model in the preliminary-round format. Our validation results are shown below:
34 | ![Alt text](https://github.com/liaochengcsu/road_segmentation_pytorch/blob/main/result.jpg)
35 | 
36 | # Help
37 | Any questions?
Please contact me with: liaocheng@my.swjtu.edu.cn 38 | -------------------------------------------------------------------------------- /Submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 3 | 4 | import cv2 5 | import glob 6 | import json 7 | import numpy as np 8 | import pandas as pd 9 | from tqdm import tqdm 10 | from omegaconf import OmegaConf 11 | 12 | import torch 13 | import pycocotools.mask as mutils 14 | 15 | from src.models import * 16 | 17 | data_dir = "./data/test" 18 | train_images_A = sorted(glob.glob(os.path.join(data_dir, "A/*"))) 19 | train_images_B = sorted(glob.glob(os.path.join(data_dir, "B/*"))) 20 | df = pd.DataFrame({"image_file_A": train_images_A, "image_file_B": train_images_B}) 21 | df["uid"] = df.image_file_A.apply(lambda x: int(os.path.basename(x).split(".")[0])) 22 | 23 | 24 | def get_model(cfg): 25 | cfg = cfg.copy() 26 | model = eval(cfg.pop("type"))(**cfg) 27 | return model 28 | 29 | def get_models(names, folds): 30 | model_infos = [ 31 | dict( 32 | ckpt = f"./logs/{name}/f{fold}/last.ckpt", 33 | ) for name in names for fold in folds 34 | ] 35 | models = [] 36 | for model_info in model_infos: 37 | if not os.path.exists(model_info["ckpt"]): 38 | model_info['ckpt'] = sorted(glob.glob(model_info['ckpt']))[-1] 39 | stt = torch.load(model_info["ckpt"], map_location = "cpu") 40 | cfg = OmegaConf.create(eval(str(stt["hyper_parameters"]))).model 41 | stt = {k[6:]: v for k, v in stt["state_dict"].items()} 42 | 43 | model = get_model(cfg) 44 | model.load_state_dict(stt, strict = True) 45 | model.eval() 46 | model.cuda() 47 | models.append(model) 48 | return models 49 | 50 | mean = np.array([0.485, 0.456, 0.406]) 51 | std = np.array([0.229, 0.224, 0.225]) 52 | def load(row): 53 | imgA = cv2.imread(row.image_file_A) 54 | imgB = cv2.imread(row.image_file_B) 55 | imgA = cv2.cvtColor(imgA, cv2.COLOR_BGR2RGB) 56 | imgB = cv2.cvtColor(imgB, cv2.COLOR_BGR2RGB) 57 | imgA = (imgA / 255. - mean) / std 58 | imgB = (imgB / 255. 
- mean) / std 59 | img = np.concatenate([imgA, imgB], -1).astype(np.float32) 60 | return img, None 61 | 62 | 63 | def predict(row, models, img): 64 | img = torch.tensor(img.transpose(2, 0, 1)).unsqueeze(0).cuda() 65 | with torch.no_grad(): 66 | preds = [] 67 | for model in models: 68 | pred = model(img).sigmoid() 69 | pred = pred.squeeze().detach().cpu().numpy() 70 | preds.append(pred) 71 | pred = sum(preds) / len(preds) 72 | return pred 73 | 74 | 75 | def get_dt(row, pred, img_id, dts): 76 | mask = pred.round().astype(np.uint8) 77 | nc, label = cv2.connectedComponents(mask, connectivity = 8) 78 | for c in range(nc): 79 | if np.all(mask[label == c] == 0): 80 | continue 81 | else: 82 | ann = np.asfortranarray((label == c).astype(np.uint8)) 83 | rle = mutils.encode(ann) 84 | bbox = [int(_) for _ in mutils.toBbox(rle)] 85 | area = int(mutils.area(rle)) 86 | score = float(pred[label == c].mean()) 87 | dts.append({ 88 | "segmentation": { 89 | "size": [int(_) for _ in rle["size"]], 90 | "counts": rle["counts"].decode()}, 91 | "bbox": [int(_) for _ in bbox], "area": int(area), "iscrowd": 0, "category_id": 1, 92 | "image_id": int(img_id), "id": len(dts), 93 | "score": float(score) 94 | }) 95 | 96 | names = [ 97 | "base" 98 | ] 99 | folds = [0] 100 | 101 | os.system("mkdir -p results") 102 | sub = df 103 | models = get_models(names, folds) 104 | dts = [] 105 | for idx in tqdm(range(len(sub))): 106 | row = sub.loc[idx] 107 | img, mask = load(row) 108 | pred = predict(row, models, img) 109 | get_dt(row, pred, row.uid, dts) 110 | with open("./results/test.segm.json", "w") as f: 111 | json.dump(dts, f) 112 | os.system("zip -9 -r results.zip results") -------------------------------------------------------------------------------- /data-agu.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liaochengcsu/road_segmentation_pytorch/bfcc2c9c83eb78c514ce77474e3345b47c810969/data-agu.jpg -------------------------------------------------------------------------------- /data.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liaochengcsu/road_segmentation_pytorch/bfcc2c9c83eb78c514ce77474e3345b47c810969/data.jpg -------------------------------------------------------------------------------- /data_agu.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | from PIL import Image,ImageEnhance 5 | from torch.utils.data import Dataset 6 | import pandas as pd 7 | import random 8 | 9 | 10 | def randomHueSaturationValue(image, hue_shift_limit=(-180, 180), 11 | sat_shift_limit=(-255, 255), 12 | val_shift_limit=(-255, 255), u=0.5): 13 | if np.random.random() < u: 14 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 15 | h, s, v = cv2.split(image) 16 | hue_shift = np.random.randint(hue_shift_limit[0], hue_shift_limit[1]+1) 17 | hue_shift = np.uint8(hue_shift) 18 | h += hue_shift 19 | sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1]) 20 | s = cv2.add(s, sat_shift) 21 | val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1]) 22 | v = cv2.add(v, val_shift) 23 | image = cv2.merge((h, s, v)) 24 | #image = cv2.merge((s, v)) 25 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 26 | return image 27 | 28 | 29 | def randomShiftScaleRotate(image, mask, 30 | shift_limit=(-0.0, 0.0), 31 | scale_limit=(-0.0, 0.0), 32 | rotate_limit=(-0.0, 0.0), 33 | aspect_limit=(-0.0, 0.0), 34 | 
borderMode=cv2.BORDER_CONSTANT, u=0.5): 35 | if np.random.random() < u: 36 | height, width, channel = image.shape 37 | 38 | angle = np.random.uniform(rotate_limit[0], rotate_limit[1]) 39 | scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1]) 40 | aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1]) 41 | sx = scale * aspect / (aspect ** 0.5) 42 | sy = scale / (aspect ** 0.5) 43 | dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width) 44 | dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height) 45 | 46 | cc = np.math.cos(angle / 180 * np.math.pi) * sx 47 | ss = np.math.sin(angle / 180 * np.math.pi) * sy 48 | rotate_matrix = np.array([[cc, -ss], [ss, cc]]) 49 | 50 | box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ]) 51 | box1 = box0 - np.array([width / 2, height / 2]) 52 | box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy]) 53 | 54 | box0 = box0.astype(np.float32) 55 | box1 = box1.astype(np.float32) 56 | mat = cv2.getPerspectiveTransform(box0, box1) 57 | image = cv2.warpPerspective(image, mat, (width, height), flags=cv2.INTER_LINEAR, borderMode=borderMode, 58 | borderValue=( 59 | 0, 0, 60 | 0,)) 61 | mask = cv2.warpPerspective(mask, mat, (width, height), flags=cv2.INTER_LINEAR, borderMode=borderMode, 62 | borderValue=( 63 | 0, 0, 64 | 0,)) 65 | 66 | return image, mask 67 | 68 | 69 | def randomHorizontalFlip(image, mask, u=0.5): 70 | if np.random.random() < u: 71 | image = cv2.flip(image, 1) 72 | mask = cv2.flip(mask, 1) 73 | 74 | return image, mask 75 | 76 | 77 | def randomVerticleFlip(image, mask, u=0.5): 78 | if np.random.random() < u: 79 | image = cv2.flip(image, 0) 80 | mask = cv2.flip(mask, 0) 81 | 82 | return image, mask 83 | 84 | 85 | def randomRotate90(image, mask, u=0.5): 86 | if np.random.random() < u: 87 | image=np.rot90(image) 88 | mask=np.rot90(mask) 89 | 90 | return image, mask 91 | 92 | 93 | def grade(img): 94 | x = cv2.Sobel(img, cv2.CV_32F, 1, 0, ksize=1) 95 | y = cv2.Sobel(img, cv2.CV_32F, 0, 1, ksize=1) 96 | absX = cv2.convertScaleAbs(x) 97 | absY = cv2.convertScaleAbs(y) 98 | dst = cv2.addWeighted(absX, 0.5, absY, 0.5, 0) 99 | mi=np.min(dst) 100 | ma=np.max(dst) 101 | res=(dst-mi)/(0.000000001+(ma-mi)) 102 | res[np.isnan(res)]=0 103 | return res 104 | 105 | 106 | class Mydataset(Dataset): 107 | def __init__(self, path,augment=False,transform=None, target_transform=None): 108 | 109 | self.aug=augment 110 | self.file_path=os.path.dirname(path) 111 | data = pd.read_csv(path) # 获取csv表中的数据 112 | imgs = [] 113 | for i in range(len(data)): 114 | imgs.append((data.iloc[i,0], data.iloc[i,1])) 115 | self.imgs = imgs 116 | self.transform = transform 117 | self.target_transform = target_transform 118 | 119 | def __getitem__(self, item): 120 | if self.aug==False: 121 | fn, lab = self.imgs[item] 122 | # fn = os.path.join(self.file_path, "image_A/" + fn) 123 | # label = os.path.join(self.file_path, "image_A/" + lab) 124 | fn = os.path.join(self.file_path, "images/"+ fn) 125 | label = os.path.join(self.file_path, "labels/"+ lab) 126 | 127 | bgr_img = cv2.imread(fn, -1) 128 | rgb_img = bgr_img[..., ::-1] # bgr2rgb 129 | gray = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY) 130 | grad = (255 * grade(gray)).astype(np.uint8) 131 | 132 | # img = Image.open(fn).convert('RGB') 133 | img = cv2.merge([rgb_img, grad]) 134 | img = Image.fromarray(img, mode="CMYK") 135 | if self.transform is not None: 136 | img = self.transform(img) 137 | 138 | gt = cv2.imread(label, -1) 139 | return img, gt, lab 
140 | 141 | 142 | else: 143 | # 进行数据增强 144 | fn, lab = self.imgs[item] 145 | # train with data.cvs 146 | fn = os.path.join(self.file_path, "images/"+ fn) 147 | label = os.path.join(self.file_path, "labels/"+ lab) 148 | 149 | gt = cv2.imread(label, -1) 150 | image = cv2.imread(fn,-1) 151 | 152 | image = randomHueSaturationValue(image, 153 | hue_shift_limit=(-30, 30), 154 | sat_shift_limit=(-5, 5), 155 | val_shift_limit=(-15, 15)) 156 | 157 | image, gt = randomShiftScaleRotate(image, gt, 158 | shift_limit=(-0.1, 0.1), 159 | scale_limit=(-0.1, 0.1), 160 | aspect_limit=(-0.1, 0.1), 161 | rotate_limit=(-0, 0)) 162 | 163 | image, gt = randomHorizontalFlip(image, gt) 164 | image, gt = randomVerticleFlip(image, gt) 165 | image, gt = randomRotate90(image, gt) 166 | 167 | rgb_img = image[..., ::-1] # bgr2rgb 168 | gray = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY) 169 | grad = (255 * grade(gray)).astype(np.uint8) 170 | img = cv2.merge([rgb_img, grad]) 171 | img = Image.fromarray(img, mode="CMYK") 172 | if self.transform is not None: 173 | img = self.transform(img.copy()) 174 | return img, gt.copy(), lab 175 | 176 | def __len__(self): 177 | return len(self.imgs) 178 | 179 | -------------------------------------------------------------------------------- /data_tree.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liaochengcsu/road_segmentation_pytorch/bfcc2c9c83eb78c514ce77474e3345b47c810969/data_tree.jpg -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 4 | https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py 5 | """ 6 | 7 | import torch.nn as nn 8 | 9 | # from __future__ import print_function, division 10 | 11 | import torch 12 | from torch.autograd import Variable 13 | import torch.nn.functional as F 14 | import numpy as np 15 | 16 | try: 17 | from itertools import ifilterfalse 18 | except ImportError: # py3k 19 | from itertools import filterfalse as ifilterfalse 20 | 21 | import torch 22 | import torch.nn as nn 23 | from torch.autograd import Variable as V 24 | 25 | import cv2 26 | import numpy as np 27 | 28 | 29 | class dice_bce_loss(nn.Module): 30 | def __init__(self, batch=True): 31 | super(dice_bce_loss, self).__init__() 32 | self.batch = batch 33 | self.bce_loss = nn.BCELoss() 34 | 35 | def soft_dice_coeff(self, y_true, y_pred): 36 | smooth = 0.0 # may change 37 | if self.batch: 38 | i = torch.sum(y_true) 39 | j = torch.sum(y_pred) 40 | intersection = torch.sum(y_true * y_pred) 41 | else: 42 | i = y_true.sum(1).sum(1).sum(1) 43 | j = y_pred.sum(1).sum(1).sum(1) 44 | intersection = (y_true * y_pred).sum(1).sum(1).sum(1) 45 | # score = (2. 
* intersection + smooth) / (i + j + smooth) 46 | score = (intersection + smooth) / (i + j - intersection + smooth)#iou 47 | return score.mean() 48 | 49 | def soft_dice_loss(self, y_true, y_pred): 50 | loss = 1 - self.soft_dice_coeff(y_true, y_pred) 51 | return loss 52 | 53 | def __call__(self, y_true, y_pred): 54 | a = self.bce_loss(y_pred, y_true) 55 | b = self.soft_dice_loss(y_true, y_pred) 56 | return a + b 57 | 58 | 59 | 60 | class dice_bce_loss_with_logits(nn.Module): 61 | def __init__(self, batch=True): 62 | super(dice_bce_loss_with_logits, self).__init__() 63 | self.batch = batch 64 | # self.bce_loss = nn.BCELoss() 65 | # self.bce_loss = F.binary_cross_entropy_with_logits() 66 | 67 | def soft_dice_coeff(self, y_true, y_pred): 68 | y_pred = torch.sigmoid(y_pred) 69 | # smooth = 0.0 # may change 70 | smooth = 1.0 # may change 71 | if self.batch: 72 | i = torch.sum(y_true) 73 | j = torch.sum(y_pred) 74 | intersection = torch.sum(y_true * y_pred) 75 | else: 76 | i = y_true.sum(1).sum(1).sum(1) 77 | j = y_pred.sum(1).sum(1).sum(1) 78 | intersection = (y_true * y_pred).sum(1).sum(1).sum(1) 79 | # score = (2. * intersection + smooth) / (i + j + smooth) 80 | score = (intersection + smooth) / (i + j - intersection + smooth) #iou 81 | return score.mean() 82 | 83 | def soft_dice_loss(self, y_true, y_pred): 84 | loss = 1 - self.soft_dice_coeff(y_true, y_pred) 85 | return loss 86 | 87 | def __call__(self, y_true, y_pred): 88 | a = F.binary_cross_entropy_with_logits(y_pred, y_true, pos_weight=torch.Tensor([1.5]).cuda()) 89 | # a = nn.BCEWithLogitsLoss(y_pred, y_true, pos_weight=torch.Tensor([1.5]).cuda()) 90 | # a = self.bce_loss(y_pred, y_true) 91 | b = self.soft_dice_loss(y_true, y_pred) 92 | return a + b 93 | 94 | 95 | def make_one_hot(input, num_classes): 96 | """Convert class index tensor to one hot encoding tensor. 
97 | Args: 98 | input: A tensor of shape [N, 1, *] 99 | num_classes: An int of number of class 100 | Returns: 101 | A tensor of shape [N, num_classes, *] 102 | """ 103 | shape = np.array(input.shape) 104 | shape[1] = num_classes 105 | shape = tuple(shape) 106 | result = torch.zeros(shape) 107 | result = result.scatter_(1, input.cpu(), 1) 108 | 109 | return result 110 | 111 | 112 | 113 | class BinaryDiceLoss(nn.Module): 114 | """Dice loss of binary class 115 | Args: 116 | smooth: A float number to smooth loss, and avoid NaN error, default: 1 117 | p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2 118 | predict: A tensor of shape [N, *] 119 | target: A tensor of shape same with predict 120 | reduction: Reduction method to apply, return mean over batch if 'mean', 121 | return sum if 'sum', return a tensor of shape [N,] if 'none' 122 | Returns: 123 | Loss tensor according to arg reduction 124 | Raise: 125 | Exception if unexpected reduction 126 | """ 127 | def __init__(self, smooth=1, p=2, reduction='mean'): 128 | super(BinaryDiceLoss, self).__init__() 129 | self.smooth = smooth 130 | self.p = p 131 | self.reduction = reduction 132 | 133 | def forward(self, predict, target): 134 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match" 135 | predict = predict.contiguous().view(predict.shape[0], -1) 136 | target = target.contiguous().view(target.shape[0], -1) 137 | 138 | num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth 139 | den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth 140 | 141 | loss = 1 - num / den 142 | 143 | if self.reduction == 'mean': 144 | return loss.mean() 145 | elif self.reduction == 'sum': 146 | return loss.sum() 147 | elif self.reduction == 'none': 148 | return loss 149 | else: 150 | raise Exception('Unexpected reduction {}'.format(self.reduction)) 151 | 152 | 153 | class DiceLoss(nn.Module): 154 | """Dice loss, need one hot encode input 155 | Args: 156 | weight: An array of shape [num_classes,] 157 | ignore_index: class index to ignore 158 | predict: A tensor of shape [N, C, *] 159 | target: A tensor of same shape with predict 160 | other args pass to BinaryDiceLoss 161 | Return: 162 | same as BinaryDiceLoss 163 | """ 164 | def __init__(self, weight=None, ignore_index=None, **kwargs): 165 | super(DiceLoss, self).__init__() 166 | self.kwargs = kwargs 167 | self.weight = weight 168 | self.ignore_index = ignore_index 169 | 170 | def forward(self, predict, target): 171 | assert predict.shape == target.shape, 'predict & target shape do not match' 172 | dice = BinaryDiceLoss(**self.kwargs) 173 | total_loss = 0 174 | predict = F.softmax(predict, dim=1) 175 | 176 | for i in range(target.shape[1]): 177 | if i != self.ignore_index: 178 | dice_loss = dice(predict[:, i], target[:, i]) 179 | if self.weight is not None: 180 | assert self.weight.shape[0] == target.shape[1], \ 181 | 'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0]) 182 | dice_loss *= self.weights[i] 183 | total_loss += dice_loss 184 | 185 | return total_loss/target.shape[1] 186 | 187 | 188 | def lovasz_grad(gt_sorted): 189 | """ 190 | Computes gradient of the Lovasz extension w.r.t sorted errors 191 | See Alg. 1 in paper 192 | """ 193 | p = len(gt_sorted) 194 | gts = gt_sorted.sum() 195 | intersection = gts - gt_sorted.float().cumsum(0) 196 | union = gts + (1 - gt_sorted).float().cumsum(0) 197 | jaccard = 1. 
- intersection / union 198 | if p > 1: # cover 1-pixel case 199 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 200 | return jaccard 201 | 202 | 203 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): 204 | """ 205 | IoU for foreground class 206 | binary: 1 foreground, 0 background 207 | """ 208 | if not per_image: 209 | preds, labels = (preds,), (labels,) 210 | ious = [] 211 | for pred, label in zip(preds, labels): 212 | intersection = ((label == 1) & (pred == 1)).sum() 213 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() 214 | if not union: 215 | iou = EMPTY 216 | else: 217 | iou = float(intersection) / float(union) 218 | ious.append(iou) 219 | iou = mean(ious) # mean accross images if per_image 220 | return 100 * iou 221 | 222 | 223 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): 224 | """ 225 | Array of IoU for each (non ignored) class 226 | """ 227 | if not per_image: 228 | preds, labels = (preds,), (labels,) 229 | ious = [] 230 | for pred, label in zip(preds, labels): 231 | iou = [] 232 | for i in range(C): 233 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) 234 | intersection = ((label == i) & (pred == i)).sum() 235 | union = ((label == i) | ((pred == i) & (label != ignore))).sum() 236 | if not union: 237 | iou.append(EMPTY) 238 | else: 239 | iou.append(float(intersection) / float(union)) 240 | ious.append(iou) 241 | ious = [mean(iou) for iou in zip(*ious)] # mean accross images if per_image 242 | return 100 * np.array(ious) 243 | 244 | 245 | # --------------------------- BINARY LOSSES --------------------------- 246 | 247 | def lovasz_hinge(logits, labels, per_image=True, ignore=None): 248 | """ 249 | Binary Lovasz hinge loss 250 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 251 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 252 | per_image: compute the loss per image instead of per batch 253 | ignore: void class id 254 | """ 255 | if per_image: 256 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) 257 | for log, lab in zip(logits, labels)) 258 | else: 259 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) 260 | return loss 261 | 262 | 263 | def lovasz_hinge_flat(logits, labels): 264 | """ 265 | Binary Lovasz hinge loss 266 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 267 | labels: [P] Tensor, binary ground truth labels (0 or 1) 268 | ignore: label to ignore 269 | """ 270 | if len(labels) == 0: 271 | # only void pixels, the gradients should be 0 272 | return logits.sum() * 0. 273 | signs = 2. * labels.float() - 1. 274 | errors = (1. 
- logits * Variable(signs)) 275 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 276 | perm = perm.data 277 | gt_sorted = labels[perm] 278 | grad = lovasz_grad(gt_sorted) 279 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 280 | return loss 281 | 282 | 283 | def flatten_binary_scores(scores, labels, ignore=None): 284 | """ 285 | Flattens predictions in the batch (binary case) 286 | Remove labels equal to 'ignore' 287 | """ 288 | scores = scores.view(-1) 289 | labels = labels.view(-1) 290 | if ignore is None: 291 | return scores, labels 292 | valid = (labels != ignore) 293 | vscores = scores[valid] 294 | vlabels = labels[valid] 295 | return vscores, vlabels 296 | 297 | 298 | class StableBCELoss(torch.nn.modules.Module): 299 | def __init__(self): 300 | super(StableBCELoss, self).__init__() 301 | 302 | def forward(self, input, target): 303 | neg_abs = - input.abs() 304 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 305 | return loss.mean() 306 | 307 | 308 | def binary_xloss(logits, labels, ignore=None): 309 | """ 310 | Binary Cross entropy loss 311 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 312 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 313 | ignore: void class id 314 | """ 315 | logits, labels = flatten_binary_scores(logits, labels, ignore) 316 | loss = StableBCELoss()(logits, Variable(labels.float())) 317 | return loss 318 | 319 | 320 | # --------------------------- MULTICLASS LOSSES --------------------------- 321 | 322 | 323 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): 324 | """ 325 | Multi-class Lovasz-Softmax loss 326 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 327 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 328 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 329 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 330 | per_image: compute the loss per image instead of per batch 331 | ignore: void class labels 332 | """ 333 | if per_image: 334 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) 335 | for prob, lab in zip(probas, labels)) 336 | else: 337 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) 338 | return loss 339 | 340 | 341 | def lovasz_softmax_flat(probas, labels, classes='present'): 342 | """ 343 | Multi-class Lovasz-Softmax loss 344 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 345 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 346 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 347 | """ 348 | if probas.numel() == 0: 349 | # only void pixels, the gradients should be 0 350 | return probas * 0. 
351 | C = probas.size(1) 352 | losses = [] 353 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes 354 | for c in class_to_sum: 355 | fg = (labels == c).float() # foreground for class c 356 | if (classes is 'present' and fg.sum() == 0): 357 | continue 358 | if C == 1: 359 | if len(classes) > 1: 360 | raise ValueError('Sigmoid output possible only with 1 class') 361 | class_pred = probas[:, 0] 362 | else: 363 | class_pred = probas[:, c] 364 | errors = (Variable(fg) - class_pred).abs() 365 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 366 | perm = perm.data 367 | fg_sorted = fg[perm] 368 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 369 | return mean(losses) 370 | 371 | 372 | def flatten_probas(probas, labels, ignore=None): 373 | """ 374 | Flattens predictions in the batch 375 | """ 376 | if probas.dim() == 3: 377 | # assumes output of a sigmoid layer 378 | B, H, W = probas.size() 379 | probas = probas.view(B, 1, H, W) 380 | B, C, H, W = probas.size() 381 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 382 | labels = labels.view(-1) 383 | if ignore is None: 384 | return probas, labels 385 | valid = (labels != ignore) 386 | vprobas = probas[valid.nonzero().squeeze()] 387 | vlabels = labels[valid] 388 | return vprobas, vlabels 389 | 390 | 391 | def xloss(logits, labels, ignore=None): 392 | """ 393 | Cross entropy loss 394 | """ 395 | return F.cross_entropy(logits, Variable(labels), ignore_index=255) 396 | 397 | 398 | # --------------------------- HELPER FUNCTIONS --------------------------- 399 | def isnan(x): 400 | return x != x 401 | 402 | 403 | def mean(l, ignore_nan=False, empty=0): 404 | """ 405 | nanmean compatible with generators. 406 | """ 407 | l = iter(l) 408 | if ignore_nan: 409 | l = ifilterfalse(isnan, l) 410 | try: 411 | n = 1 412 | acc = next(l) 413 | except StopIteration: 414 | if empty == 'raise': 415 | raise ValueError('Empty mean') 416 | return empty 417 | for n, v in enumerate(l, 2): 418 | acc += v 419 | if n == 1: 420 | return acc 421 | return acc / n 422 | 423 | 424 | class LovaszSoftmax(nn.Module): 425 | def __init__(self, classes='present', per_image=False, ignore_index=255): 426 | super(LovaszSoftmax, self).__init__() 427 | self.smooth = classes 428 | self.per_image = per_image 429 | self.ignore_index = ignore_index 430 | 431 | def forward(self, output, target): 432 | logits = F.softmax(output, dim=1) 433 | loss = lovasz_softmax(logits, target, ignore=self.ignore_index) 434 | return loss -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.nn.modules.conv import _ConvNd 7 | from torch.nn.modules.utils import _pair 8 | from torch.nn import Conv2d, Module, ReLU 9 | from functools import partial 10 | nonlinearity = partial(F.relu, inplace=True) 11 | 12 | 13 | __all__ = ['ResNet', 'Bottleneck','SplAtConv2d','resnest50', 'resnest101', 'resnest200', 'resnest269'] 14 | 15 | _url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' 16 | 17 | _model_sha256 = {name: checksum for checksum, name in [ 18 | ('528c19ca', 'resnest50'), 19 | ('22405ba7', 'resnest101'), 20 | ('75117900', 'resnest200'), 21 | ('0cc87c48', 'resnest269'), 22 | ]} 23 | 24 | 25 | def short_hash(name): 26 | if name not in _model_sha256: 27 | raise 
ValueError('Pretrained model for {name} is not available.'.format(name=name)) 28 | return _model_sha256[name][:8] 29 | 30 | 31 | resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for 32 | name in _model_sha256.keys() 33 | } 34 | 35 | 36 | def conv3x3(in_planes, out_planes, stride=1): 37 | """3x3 convolution with padding""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 39 | padding=1, bias=False) 40 | 41 | class SplAtConv2d(Module): 42 | """Split-Attention Conv2d 43 | """ 44 | def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0), 45 | dilation=(1, 1), groups=1, bias=True, 46 | radix=2, reduction_factor=4, 47 | rectify=False, rectify_avg=False, norm_layer=None, 48 | dropblock_prob=0.0, **kwargs): 49 | super(SplAtConv2d, self).__init__() 50 | padding = _pair(padding) 51 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) 52 | self.rectify_avg = rectify_avg 53 | inter_channels = max(in_channels*radix//reduction_factor, 32) 54 | self.radix = radix 55 | self.cardinality = groups 56 | self.channels = channels 57 | self.dropblock_prob = dropblock_prob 58 | if self.rectify: 59 | from rfconv import RFConv2d 60 | self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation, 61 | groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs) 62 | else: 63 | self.conv = Conv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation, 64 | groups=groups*radix, bias=bias, **kwargs) 65 | self.use_bn = norm_layer is not None 66 | if self.use_bn: 67 | self.bn0 = norm_layer(channels*radix) 68 | self.relu = ReLU(inplace=True) 69 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 70 | if self.use_bn: 71 | self.bn1 = norm_layer(inter_channels) 72 | self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality) 73 | if dropblock_prob > 0.0: 74 | self.dropblock = DropBlock2D(dropblock_prob, 3) 75 | self.rsoftmax = rSoftMax(radix, groups) 76 | 77 | def forward(self, x): 78 | x = self.conv(x) 79 | if self.use_bn: 80 | x = self.bn0(x) 81 | if self.dropblock_prob > 0.0: 82 | x = self.dropblock(x) 83 | x = self.relu(x) 84 | 85 | batch, rchannel = x.shape[:2] 86 | if self.radix > 1: 87 | if torch.__version__ < '1.5': 88 | splited = torch.split(x, int(rchannel//self.radix), dim=1) 89 | else: 90 | splited = torch.split(x, rchannel//self.radix, dim=1) 91 | gap = sum(splited) 92 | else: 93 | gap = x 94 | gap = F.adaptive_avg_pool2d(gap, 1) 95 | gap = self.fc1(gap) 96 | 97 | if self.use_bn: 98 | gap = self.bn1(gap) 99 | gap = self.relu(gap) 100 | 101 | atten = self.fc2(gap) 102 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1) 103 | 104 | if self.radix > 1: 105 | if torch.__version__ < '1.5': 106 | attens = torch.split(atten, int(rchannel//self.radix), dim=1) 107 | else: 108 | attens = torch.split(atten, rchannel//self.radix, dim=1) 109 | out = sum([att*split for (att, split) in zip(attens, splited)]) 110 | else: 111 | out = atten * x 112 | return out.contiguous() 113 | 114 | class rSoftMax(nn.Module): 115 | def __init__(self, radix, cardinality): 116 | super().__init__() 117 | self.radix = radix 118 | self.cardinality = cardinality 119 | 120 | def forward(self, x): 121 | batch = x.size(0) 122 | if self.radix > 1: 123 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 124 | x = F.softmax(x, dim=1) 125 | x = x.reshape(batch, -1) 126 | else: 127 | x = torch.sigmoid(x) 128 | return x 129 | 130 | class DropBlock2D(object): 131 | def 
__init__(self, *args, **kwargs): 132 | raise NotImplementedError 133 | 134 | class GlobalAvgPool2d(nn.Module): 135 | def __init__(self): 136 | """Global average pooling over the input's spatial dimensions""" 137 | super(GlobalAvgPool2d, self).__init__() 138 | 139 | def forward(self, inputs): 140 | return nn.functional.adaptive_avg_pool2d(inputs, 1).view(inputs.size(0), -1) 141 | 142 | 143 | class Bottleneck(nn.Module): 144 | """ResNet Bottleneck 145 | """ 146 | # pylint: disable=unused-argument 147 | expansion = 4 148 | def __init__(self, inplanes, planes, stride=1, downsample=None, 149 | radix=1, cardinality=1, bottleneck_width=64, 150 | avd=False, avd_first=False, dilation=1, is_first=False, 151 | rectified_conv=False, rectify_avg=False, 152 | norm_layer=None, dropblock_prob=0.0, last_gamma=False): 153 | super(Bottleneck, self).__init__() 154 | group_width = int(planes * (bottleneck_width / 64.)) * cardinality 155 | self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) 156 | self.bn1 = norm_layer(group_width) 157 | self.dropblock_prob = dropblock_prob 158 | self.radix = radix 159 | self.avd = avd and (stride > 1 or is_first) 160 | self.avd_first = avd_first 161 | 162 | if self.avd: 163 | self.avd_layer = nn.AvgPool2d(3, stride, padding=1) 164 | stride = 1 165 | 166 | if dropblock_prob > 0.0: 167 | self.dropblock1 = DropBlock2D(dropblock_prob, 3) 168 | if radix == 1: 169 | self.dropblock2 = DropBlock2D(dropblock_prob, 3) 170 | self.dropblock3 = DropBlock2D(dropblock_prob, 3) 171 | 172 | if radix >= 1: 173 | self.conv2 = SplAtConv2d( 174 | group_width, group_width, kernel_size=3, 175 | stride=stride, padding=dilation, 176 | dilation=dilation, groups=cardinality, bias=False, 177 | radix=radix, rectify=rectified_conv, 178 | rectify_avg=rectify_avg, 179 | norm_layer=norm_layer, 180 | dropblock_prob=dropblock_prob) 181 | elif rectified_conv: 182 | from rfconv import RFConv2d 183 | self.conv2 = RFConv2d( 184 | group_width, group_width, kernel_size=3, stride=stride, 185 | padding=dilation, dilation=dilation, 186 | groups=cardinality, bias=False, 187 | average_mode=rectify_avg) 188 | self.bn2 = norm_layer(group_width) 189 | else: 190 | self.conv2 = nn.Conv2d( 191 | group_width, group_width, kernel_size=3, stride=stride, 192 | padding=dilation, dilation=dilation, 193 | groups=cardinality, bias=False) 194 | self.bn2 = norm_layer(group_width) 195 | 196 | self.conv3 = nn.Conv2d( 197 | group_width, planes * 4, kernel_size=1, bias=False) 198 | self.bn3 = norm_layer(planes*4) 199 | 200 | if last_gamma: 201 | from torch.nn.init import zeros_ 202 | zeros_(self.bn3.weight) 203 | self.relu = nn.ReLU(inplace=True) 204 | self.downsample = downsample 205 | self.dilation = dilation 206 | self.stride = stride 207 | 208 | def forward(self, x): 209 | residual = x 210 | 211 | out = self.conv1(x) 212 | out = self.bn1(out) 213 | if self.dropblock_prob > 0.0: 214 | out = self.dropblock1(out) 215 | out = self.relu(out) 216 | if self.avd and self.avd_first: 217 | out = self.avd_layer(out) 218 | out = self.conv2(out) 219 | if self.radix == 0: 220 | out = self.bn2(out) 221 | if self.dropblock_prob > 0.0: 222 | out = self.dropblock2(out) 223 | out = self.relu(out) 224 | if self.avd and not self.avd_first: 225 | out = self.avd_layer(out) 226 | out = self.conv3(out) 227 | out = self.bn3(out) 228 | if self.dropblock_prob > 0.0: 229 | out = self.dropblock3(out) 230 | if self.downsample is not None: 231 | residual = self.downsample(x) 232 | out += residual 233 | out = self.relu(out) 234 | 235 | return out 236 
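# --- Editor's sketch (illustration only, not part of the original model.py): a quick
# --- shape check for SplAtConv2d above, under assumed settings. With radix=2 the layer
# --- expands the input to channels*radix feature maps, weights the radix splits with
# --- rSoftMax attention, and returns `channels` maps at the same spatial resolution.
if __name__ == "__main__":
    x = torch.randn(2, 64, 32, 32)
    splat = SplAtConv2d(64, 64, kernel_size=3, padding=1, radix=2,
                        norm_layer=nn.BatchNorm2d, bias=False)
    print(splat(x).shape)  # expected: torch.Size([2, 64, 32, 32])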
| 237 | 238 | class ResNet(nn.Module): 239 | # pylint: disable=unused-variable 240 | def __init__(self, block, layers, radix=1, groups=1, bottleneck_width=64, 241 | num_classes=1000, dilated=False, dilation=1, 242 | deep_stem=False, stem_width=64, avg_down=False, 243 | rectified_conv=False, rectify_avg=False, 244 | avd=False, avd_first=False, 245 | final_drop=0.0, dropblock_prob=0, 246 | last_gamma=False, norm_layer=nn.BatchNorm2d,pretrained=False): 247 | self.cardinality = groups 248 | self.bottleneck_width = bottleneck_width 249 | # ResNet-D params 250 | self.inplanes = stem_width*2 if deep_stem else 64 251 | self.avg_down = avg_down 252 | self.last_gamma = last_gamma 253 | # ResNeSt params 254 | self.radix = radix 255 | self.avd = avd 256 | self.avd_first = avd_first 257 | 258 | super(ResNet, self).__init__() 259 | self.rectified_conv = rectified_conv 260 | self.rectify_avg = rectify_avg 261 | if rectified_conv: 262 | from rfconv import RFConv2d 263 | conv_layer = RFConv2d 264 | else: 265 | conv_layer = nn.Conv2d 266 | conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {} 267 | if deep_stem: 268 | self.conv1 = nn.Sequential( 269 | conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs), 270 | norm_layer(stem_width), 271 | nn.ReLU(inplace=True), 272 | conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs), 273 | norm_layer(stem_width), 274 | nn.ReLU(inplace=True), 275 | conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs), 276 | ) 277 | else: 278 | self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3, 279 | bias=False, **conv_kwargs) 280 | self.bn1 = norm_layer(self.inplanes) 281 | self.relu = nn.ReLU(inplace=True) 282 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 283 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False) 284 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 285 | if dilated or dilation == 4: 286 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 287 | dilation=2, norm_layer=norm_layer, 288 | dropblock_prob=dropblock_prob) 289 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 290 | dilation=4, norm_layer=norm_layer, 291 | dropblock_prob=dropblock_prob) 292 | elif dilation==2: 293 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 294 | dilation=1, norm_layer=norm_layer, 295 | dropblock_prob=dropblock_prob) 296 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 297 | dilation=2, norm_layer=norm_layer, 298 | dropblock_prob=dropblock_prob) 299 | else: 300 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 301 | norm_layer=norm_layer, 302 | dropblock_prob=dropblock_prob) 303 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 304 | norm_layer=norm_layer, 305 | dropblock_prob=dropblock_prob) 306 | 307 | self._init_weight() 308 | 309 | if pretrained: 310 | self._load_pretrained_model() 311 | self.avgpool = GlobalAvgPool2d() 312 | self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None 313 | self.fc = nn.Linear(512 * block.expansion, num_classes) 314 | 315 | for m in self.modules(): 316 | if isinstance(m, nn.Conv2d): 317 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 318 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 319 | elif isinstance(m, norm_layer): 320 | m.weight.data.fill_(1) 321 | m.bias.data.zero_() 322 | 323 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None, 324 | dropblock_prob=0.0, is_first=True): 325 | downsample = None 326 | if stride != 1 or self.inplanes != planes * block.expansion: 327 | down_layers = [] 328 | if self.avg_down: 329 | if dilation == 1: 330 | down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride, 331 | ceil_mode=True, count_include_pad=False)) 332 | else: 333 | down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1, 334 | ceil_mode=True, count_include_pad=False)) 335 | down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, 336 | kernel_size=1, stride=1, bias=False)) 337 | else: 338 | down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, 339 | kernel_size=1, stride=stride, bias=False)) 340 | down_layers.append(norm_layer(planes * block.expansion)) 341 | downsample = nn.Sequential(*down_layers) 342 | 343 | layers = [] 344 | if dilation == 1 or dilation == 2: 345 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 346 | radix=self.radix, cardinality=self.cardinality, 347 | bottleneck_width=self.bottleneck_width, 348 | avd=self.avd, avd_first=self.avd_first, 349 | dilation=1, is_first=is_first, rectified_conv=self.rectified_conv, 350 | rectify_avg=self.rectify_avg, 351 | norm_layer=norm_layer, dropblock_prob=dropblock_prob, 352 | last_gamma=self.last_gamma)) 353 | elif dilation == 4: 354 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 355 | radix=self.radix, cardinality=self.cardinality, 356 | bottleneck_width=self.bottleneck_width, 357 | avd=self.avd, avd_first=self.avd_first, 358 | dilation=2, is_first=is_first, rectified_conv=self.rectified_conv, 359 | rectify_avg=self.rectify_avg, 360 | norm_layer=norm_layer, dropblock_prob=dropblock_prob, 361 | last_gamma=self.last_gamma)) 362 | else: 363 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 364 | 365 | self.inplanes = planes * block.expansion 366 | for i in range(1, blocks): 367 | layers.append(block(self.inplanes, planes, 368 | radix=self.radix, cardinality=self.cardinality, 369 | bottleneck_width=self.bottleneck_width, 370 | avd=self.avd, avd_first=self.avd_first, 371 | dilation=dilation, rectified_conv=self.rectified_conv, 372 | rectify_avg=self.rectify_avg, 373 | norm_layer=norm_layer, dropblock_prob=dropblock_prob, 374 | last_gamma=self.last_gamma)) 375 | 376 | return nn.Sequential(*layers) 377 | 378 | def forward(self, x): 379 | x = self.conv1(x) 380 | x = self.bn1(x) 381 | x = self.relu(x) 382 | x = self.maxpool(x) 383 | 384 | lay1 = self.layer1(x) 385 | lay2 = self.layer2(lay1) 386 | lay3 = self.layer3(lay2) 387 | lay4 = self.layer4(lay3) 388 | return x, lay1, lay2, lay3, lay4 389 | 390 | def _init_weight(self): 391 | for m in self.modules(): 392 | if isinstance(m, nn.Conv2d): 393 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 394 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 395 | torch.nn.init.kaiming_normal_(m.weight) 396 | elif isinstance(m, nn.BatchNorm2d): 397 | m.weight.data.fill_(1) 398 | m.bias.data.zero_() 399 | 400 | def _load_pretrained_model(self): 401 | pretrain_dict = torch.load("utils/resnet101-5d3b4d8f.pth", map_location=torch.device('cpu')) 402 | model_dict = {} 403 | state_dict = self.state_dict() 404 | for k, v in pretrain_dict.items(): 405 | if k in state_dict: 406 | model_dict[k] = v 407 | state_dict.update(model_dict) 408 | self.load_state_dict(state_dict) 409 | 410 | 411 | def short_hash(name): 412 | if name not in _model_sha256: 413 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 414 | return _model_sha256[name][:8] 415 | 416 | 417 | resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for 418 | name in _model_sha256.keys() 419 | } 420 | 421 | 422 | def resnest50(pretrained=False, **kwargs): 423 | model = ResNet(Bottleneck, [3, 4, 6, 3], 424 | radix=2, groups=1, bottleneck_width=64, 425 | deep_stem=True, stem_width=32, avg_down=True, 426 | avd=True, avd_first=False, **kwargs) 427 | if pretrained: 428 | model.load_state_dict(torch.hub.load_state_dict_from_url( 429 | resnest_model_urls['resnest50'], progress=True, check_hash=True)) 430 | return model 431 | 432 | 433 | def resnest101(pretrained=False, **kwargs): 434 | model = ResNet(Bottleneck, [3, 4, 23, 3], 435 | radix=2, groups=1, bottleneck_width=64,num_classes=1000, 436 | deep_stem=True, stem_width=64, avg_down=True, 437 | avd=True, avd_first=False, **kwargs) 438 | if pretrained: 439 | model.load_state_dict(torch.hub.load_state_dict_from_url( 440 | resnest_model_urls['resnest101'], progress=True, check_hash=True)) 441 | return model 442 | 443 | 444 | def resnest200(pretrained=False, **kwargs): 445 | model = ResNet(Bottleneck, [3, 24, 36, 3], 446 | radix=2, groups=1, bottleneck_width=64, 447 | deep_stem=True, stem_width=64, avg_down=True, 448 | avd=True, avd_first=False, **kwargs) 449 | if pretrained: 450 | model.load_state_dict(torch.hub.load_state_dict_from_url( 451 | resnest_model_urls['resnest200'], progress=True, check_hash=True)) 452 | return model 453 | 454 | 455 | def resnest269(pretrained=False, **kwargs): 456 | model = ResNet(Bottleneck, [3, 30, 48, 8], 457 | radix=2, groups=1, bottleneck_width=64, 458 | deep_stem=True, stem_width=64, avg_down=True, 459 | avd=True, avd_first=False, **kwargs) 460 | if pretrained: 461 | model.load_state_dict(torch.hub.load_state_dict_from_url( 462 | resnest_model_urls['resnest269'], progress=True, check_hash=True)) 463 | return model 464 | 465 | 466 | class PositionAttentionModule(nn.Module): 467 | """ Position attention module""" 468 | def __init__(self, in_channels, **kwargs): 469 | super(PositionAttentionModule, self).__init__() 470 | self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1) 471 | self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1) 472 | self.conv_d = nn.Conv2d(in_channels, in_channels, 1) 473 | self.alpha = nn.Parameter(torch.zeros(1)) 474 | self.softmax = nn.Softmax(dim=-1) 475 | 476 | def forward(self, x): 477 | batch_size, _, height, width = x.size() 478 | feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1) 479 | feat_c = self.conv_c(x).view(batch_size, -1, height * width) 480 | attention_s = self.softmax(torch.bmm(feat_b, feat_c)) 481 | feat_d = self.conv_d(x).view(batch_size, -1, height * width) 482 | feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width) 483 | out = self.alpha * feat_e + 
x 484 | return out 485 | 486 | 487 | class ChannelAttentionModule(nn.Module): 488 | """Channel attention module""" 489 | def __init__(self, **kwargs): 490 | super(ChannelAttentionModule, self).__init__() 491 | self.beta = nn.Parameter(torch.zeros(1)) 492 | self.softmax = nn.Softmax(dim=-1) 493 | 494 | def forward(self, x): 495 | batch_size, _, height, width = x.size() 496 | feat_a = x.view(batch_size, -1, height * width) 497 | feat_a_transpose = x.view(batch_size, -1, height * width).permute(0, 2, 1) 498 | attention = torch.bmm(feat_a, feat_a_transpose) 499 | attention_new = torch.max(attention, dim=-1, keepdim=True)[0].expand_as(attention) - attention 500 | attention = self.softmax(attention_new) 501 | feat_e = torch.bmm(attention, feat_a).view(batch_size, -1, height, width) 502 | out = self.beta * feat_e + x 503 | return out 504 | 505 | 506 | class SEModule(nn.Module): 507 | def __init__(self, channels, reduction=1): 508 | super(SEModule, self).__init__() 509 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 510 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, 511 | padding=0) 512 | self.relu = nn.ReLU(inplace=True) 513 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, 514 | padding=0) 515 | self.sigmoid = nn.Sigmoid() 516 | 517 | def forward(self, x): 518 | residual = x 519 | module_input = x 520 | x = self.avg_pool(x) 521 | x = self.fc1(x) 522 | x = self.relu(x) 523 | x = self.fc2(x) 524 | x = self.sigmoid(x) 525 | return module_input * x + residual 526 | 527 | 528 | class Dblock(nn.Module): 529 | def __init__(self, channel): 530 | super(Dblock, self).__init__() 531 | self.conv = nn.Conv2d(5120, 1024, 1, bias=False) 532 | self.bn = nn.BatchNorm2d(1024) 533 | self.cam5 = ChannelAttentionModule() 534 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1) 535 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2) 536 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4) 537 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8) 538 | # self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16) 539 | for m in self.modules(): 540 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 541 | if m.bias is not None: 542 | m.bias.data.zero_() 543 | 544 | # def forward(self, x): 545 | # dilate1_out = nonlinearity(self.dilate1(x)) 546 | # dilate2_out = nonlinearity(self.dilate2(dilate1_out)) 547 | # dilate3_out = nonlinearity(self.dilate3(dilate2_out)) 548 | # dilate4_out = nonlinearity(self.dilate4(dilate3_out)) 549 | # # dilate5_out = nonlinearity(self.dilate5(dilate4_out)) 550 | # out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out # + dilate5_out 551 | # return out 552 | def forward(self, x): 553 | dilate1_out = nonlinearity(self.dilate1(x)) 554 | dilate2_out = nonlinearity(self.dilate2(x)) 555 | dilate3_out = nonlinearity(self.dilate3(x)) 556 | dilate4_out = nonlinearity(self.dilate4(x)) 557 | # dilate5_out = nonlinearity(self.dilate5(dilate4_out)) 558 | # out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out 559 | # out = (x + dilate1_out + dilate2_out + dilate3_out + dilate4_out)/5.0 # + dilate5_out 560 | out = torch.cat((x, dilate1_out, dilate2_out, dilate3_out, dilate4_out), dim=1) 561 | out = self.cam5(out) 562 | out = self.conv(out) 563 | out = self.bn(out) 564 | return out 565 | 566 | class DecoderBlock(nn.Module): 567 | def __init__(self, in_channels, n_filters): 568 | 
super(DecoderBlock, self).__init__() 569 | 570 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1) 571 | self.norm1 = nn.BatchNorm2d(in_channels // 4) 572 | self.relu1 = nonlinearity 573 | 574 | self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1) 575 | self.norm2 = nn.BatchNorm2d(in_channels // 4) 576 | self.relu2 = nonlinearity 577 | 578 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1) 579 | self.norm3 = nn.BatchNorm2d(n_filters) 580 | self.relu3 = nonlinearity 581 | 582 | def forward(self, x): 583 | x = self.conv1(x) 584 | x = self.norm1(x) 585 | x = self.relu1(x) 586 | x = self.deconv2(x) 587 | x = self.norm2(x) 588 | x = self.relu2(x) 589 | x = self.conv3(x) 590 | x = self.norm3(x) 591 | x = self.relu3(x) 592 | return x 593 | 594 | 595 | class GatedSpatialConv2d(_ConvNd): 596 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 597 | padding=0, dilation=1, groups=1, bias=False): 598 | kernel_size = _pair(kernel_size) 599 | stride = _pair(stride) 600 | padding = _pair(padding) 601 | dilation = _pair(dilation) 602 | super(GatedSpatialConv2d, self).__init__( 603 | in_channels, out_channels, kernel_size, stride, padding, dilation, 604 | False, _pair(0), groups, bias,'zeros') 605 | 606 | self._gate_conv = nn.Sequential( 607 | nn.BatchNorm2d(in_channels + 1), 608 | nn.Conv2d(in_channels + 1, in_channels + 1, 1), 609 | nn.ReLU(), 610 | nn.Conv2d(in_channels + 1, 1, 1), 611 | nn.BatchNorm2d(1), 612 | nn.Sigmoid() 613 | ) 614 | 615 | def forward(self, input_features, gating_features): 616 | """ 617 | :param input_features: [NxCxHxW] featuers comming from the shape branch (canny branch). 618 | :param gating_features: [Nx1xHxW] features comming from the texture branch (resnet). Only one channel feature map. 
619 | :return: 620 | """ 621 | alphas = self._gate_conv(torch.cat([input_features, gating_features], dim=1)) 622 | 623 | input_features = (input_features * (alphas + 1)) 624 | return F.conv2d(input_features, self.weight, self.bias, self.stride, 625 | self.padding, self.dilation, self.groups) 626 | 627 | def reset_parameters(self): 628 | nn.init.xavier_normal_(self.weight) 629 | if self.bias is not None: 630 | nn.init.zeros_(self.bias) 631 | 632 | 633 | def conv3x3(in_planes, out_planes, stride=1): 634 | """3x3 convolution with padding""" 635 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 636 | padding=1, bias=False) 637 | 638 | 639 | class BasicBlock(nn.Module): 640 | expansion = 1 641 | def __init__(self, inplanes, planes, stride=1, downsample=None): 642 | super(BasicBlock, self).__init__() 643 | self.conv1 = conv3x3(inplanes, planes, stride) 644 | self.bn1 = nn.BatchNorm2d(planes) 645 | self.relu = nn.ReLU(inplace=True) 646 | self.conv2 = conv3x3(planes, planes) 647 | self.bn2 = nn.BatchNorm2d(planes) 648 | self.downsample = downsample 649 | self.stride = stride 650 | for m in self.modules(): 651 | if isinstance(m, nn.Conv2d): 652 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 653 | elif isinstance(m, nn.BatchNorm2d): 654 | nn.init.constant_(m.weight, 1) 655 | nn.init.constant_(m.bias, 0) 656 | 657 | def forward(self, x): 658 | residual = x 659 | 660 | out = self.conv1(x) 661 | out = self.bn1(out) 662 | out = self.relu(out) 663 | 664 | out = self.conv2(out) 665 | out = self.bn2(out) 666 | 667 | if self.downsample is not None: 668 | residual = self.downsample(x) 669 | 670 | out += residual 671 | out = self.relu(out) 672 | 673 | return out 674 | 675 | 676 | class AttDinkNet34(nn.Module): 677 | def __init__(self, num_classes=1, num_channels=3, pretrained=False): 678 | super(AttDinkNet34, self).__init__() 679 | 680 | # filters = [64, 128, 256, 512] 681 | filters = [128, 256, 512, 1024] 682 | 683 | self.resnet_features = resnest200(pretrained=pretrained) 684 | 685 | self.resx = BasicBlock(64, 64, stride=1, downsample=None) 686 | self.conv4 = nn.Conv2d(2048, 1024, 1, bias=False) 687 | self.bn4 = nn.BatchNorm2d(1024) 688 | self.rl4 = nn.ReLU() 689 | self.conv3 = nn.Conv2d(1024, 512, 1, bias=False) 690 | self.bn3 = nn.BatchNorm2d(512) 691 | self.rl3 = nn.ReLU() 692 | self.conv2 = nn.Conv2d(512, 256, 1, bias=False) 693 | self.bn2 = nn.BatchNorm2d(256) 694 | self.rl2 = nn.ReLU() 695 | self.conv1 = nn.Conv2d(256, 128, 1, bias=False) 696 | self.bn1 = nn.BatchNorm2d(128) 697 | self.rl1 = nn.ReLU() 698 | 699 | self.dblock = Dblock(1024) 700 | self.pam1 = PositionAttentionModule(256) 701 | self.pam2 = PositionAttentionModule(512) 702 | self.pam3 = PositionAttentionModule(1024) 703 | self.cam = ChannelAttentionModule() 704 | self.sel = SEModule(channels=128,reduction=1) 705 | 706 | self.gate1 = GatedSpatialConv2d(64, 64) 707 | self.gate2 = GatedSpatialConv2d(32, 32) 708 | self.gate3 = GatedSpatialConv2d(16, 16) 709 | 710 | self.down1 = nn.Conv2d(1024, 128, kernel_size=1, padding=0, bias=False) 711 | self.down2 = nn.Conv2d(512, 1, kernel_size=1, padding=0, bias=False) 712 | self.down3 = nn.Conv2d(256, 1, kernel_size=1, padding=0, bias=False) 713 | self.down4 = nn.Conv2d(128, 1, kernel_size=1, padding=0, bias=False) 714 | self.res1 = BasicBlock(128, 128, stride=1, downsample=None) 715 | self.d1 = nn.Conv2d(128, 64, 1) 716 | self.res2 = BasicBlock(64, 64, stride=1, downsample=None) 717 | self.d2 = nn.Conv2d(64, 32, 1) 718 | self.res3 = BasicBlock(32, 32, 
stride=1, downsample=None) 719 | self.d3 = nn.Conv2d(32, 16, 1) 720 | self.fuse = nn.Conv2d(16, 1, kernel_size=1, padding=0, bias=False) 721 | self.cw = nn.Conv2d(2, 1, kernel_size=1, padding=0, bias=False) 722 | 723 | self.fuse3 = nn.Conv2d(64, 1, kernel_size=1, padding=0, bias=False) 724 | self.fuse2 = nn.Conv2d(32, 1, kernel_size=1, padding=0, bias=False) 725 | 726 | self.decoder4 = DecoderBlock(filters[3], filters[2]) 727 | self.decoder3 = DecoderBlock(filters[2], filters[1]) 728 | self.decoder2 = DecoderBlock(filters[1], filters[0]) 729 | self.decoder1 = DecoderBlock(filters[0], filters[0]) 730 | 731 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 64, 4, 2, 1) 732 | self.finalrelu1 = nonlinearity 733 | self.finalconv2 = nn.Conv2d(64, 64, 3, padding=1) 734 | self.finalrelu2 = nonlinearity 735 | self.finalconv3 = nn.Conv2d(64, num_classes, 3, padding=1) 736 | 737 | def forward(self, input): 738 | # Encoder 739 | x, e1, e2, e3, e4 = self.resnet_features(input[:, 0:3, :, :]) 740 | x_size = input.size() 741 | # x = self.resx(x) 742 | 743 | e1=self.conv1(e1) #128 128 128 744 | e1=self.bn1(e1) 745 | # 128,128,128 746 | # e1 = self.rl1(e1) 747 | 748 | e2 = self.conv2(e2) # 256 64 64 749 | e2 = self.bn2(e2) 750 | # 256,64,64 751 | e2 = self.rl2(e2) 752 | e2 = self.pam1(e2) 753 | 754 | e3 = self.conv3(e3) # 512 32 32 755 | e3 = self.bn3(e3) 756 | # 512,32,32 757 | e3 = self.rl3(e3) 758 | e3 = self.pam2(e3) 759 | 760 | e4 = self.conv4(e4) # 1024 16 16 761 | e4 = self.bn4(e4) 762 | # 1024,16,16 763 | e4 = self.rl4(e4) 764 | e4 = self.pam3(e4) 765 | # Center 766 | e4 = self.dblock(e4) 767 | # e4 = self.pam3(e4) 768 | # e4 = self.cam(e4) 769 | 770 | canny = input[:, -1, :, :] 771 | canny = torch.unsqueeze(canny, dim=1) 772 | 773 | m1f = self.down1(e4) # 1024-128 [16,16] 774 | cs = self.res1(m1f) 775 | cs = F.interpolate(cs, e3.size()[2:], mode='bilinear', align_corners=True) 776 | cs = self.d1(cs) # 128-64 [32,32] 777 | m2f = self.down2(e3) # 256-1 [32,32] 778 | cs = self.gate1(cs, m2f) 779 | cs = self.res2(cs) # 32 780 | cs = F.interpolate(cs, e2.size()[2:], mode='bilinear', align_corners=True) 781 | cs = self.d2(cs) # 32-16 [64,64] 782 | m3f = self.down3(e2) # 128-1 [64,64] 783 | cs = self.gate2(cs, m3f) 784 | cs = self.res3(cs) # 16 785 | cs = F.interpolate(cs, e1.size()[2:], mode='bilinear', align_corners=True) 786 | cs = self.d3(cs) # 16-8 [128,128] 787 | m4f = self.down4(e1) # 64-1 [128,128] 788 | cs = self.gate3(cs, m4f) # 8 789 | cs = self.fuse(cs) # 8-》1 790 | cs = F.interpolate(cs, x_size[2:], mode='bilinear', align_corners=True) 791 | edge_out = torch.sigmoid(cs) 792 | cat = torch.cat((edge_out, canny), dim=1) 793 | acts = self.cw(cat) 794 | acts = torch.sigmoid(acts) 795 | 796 | # Decoder 797 | d4 = self.decoder4(e4) + e3 798 | # d4 = self.cam(d4) 799 | d3 = self.decoder3(d4) + e2 800 | d2 = self.decoder2(d3) + e1 801 | d1 = self.decoder1(d2) 802 | # d1 = self.decoder1(d2 + x) 803 | d1 = self.cam(d1) 804 | # d1 = self.sel(d1) 805 | 806 | out = self.finaldeconv1(d1) 807 | out = self.finalrelu1(out) 808 | out = self.finalconv2(out) 809 | out = self.finalrelu2(out) 810 | out = self.finalconv3(out) 811 | 812 | # return torch.sigmoid(out + acts) 813 | return out + acts -------------------------------------------------------------------------------- /net.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liaochengcsu/road_segmentation_pytorch/bfcc2c9c83eb78c514ce77474e3345b47c810969/net.jpg 
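The README describes the test-time recipe: a 4-channel input (RGB plus a Sobel-gradient band), averaging of the original, horizontally flipped and 180°-rotated predictions, and hole filling / noise removal on the result. Below is a minimal sketch of that recipe built on `data_agu.grade` and `model.AttDinkNet34`; the tile path, tile size, input normalization and the morphological kernel size are assumptions, not the competition code.
```
import cv2
import numpy as np
import torch

from data_agu import grade          # Sobel-gradient band used as the 4th channel
from model import AttDinkNet34      # encoder-decoder with attention and edge branch


def make_input(bgr_img):
    """Stack RGB with the normalized gradient band (scaling by 255 is an assumption;
    the training transform may normalize differently)."""
    rgb = np.ascontiguousarray(bgr_img[..., ::-1])                # BGR -> RGB, as in data_agu.py
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    grad = (255 * grade(gray)).astype(np.uint8)
    x = np.dstack([rgb, grad]).astype(np.float32) / 255.0         # H x W x 4
    return torch.from_numpy(x.transpose(2, 0, 1)).unsqueeze(0)    # 1 x 4 x H x W


@torch.no_grad()
def predict_tta(model, x):
    """Average original, horizontal-flip and 180-degree-rotation predictions."""
    preds = []
    for dims in ((), (3,), (2, 3)):                 # none / h-flip / rot180
        xi = torch.flip(x, dims) if dims else x
        p = torch.sigmoid(model(xi))                # sigmoid on output, as in Submission.py
        preds.append(torch.flip(p, dims) if dims else p)
    return torch.stack(preds).mean(0)


model = AttDinkNet34(num_classes=1, num_channels=3, pretrained=False).eval()
x = make_input(cv2.imread("example_tile.png"))      # assumed 512x512 test tile
prob = predict_tta(model, x)[0, 0].numpy()

# Post-processing: fill small holes, then remove small noise blobs (kernel size assumed).
mask = (prob > 0.5).astype(np.uint8)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
```
In practice the trained checkpoint would be loaded before inference; the sketch only shows the data flow and the averaging/post-processing steps listed in the README.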
-------------------------------------------------------------------------------- /result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liaochengcsu/road_segmentation_pytorch/bfcc2c9c83eb78c514ce77474e3345b47c810969/result.jpg -------------------------------------------------------------------------------- /seg_iou.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import copyreg 4 | import types 5 | 6 | 7 | def pixel_accuracy(eval_segm, gt_segm): 8 | ''' 9 | sum_i(n_ii) / sum_i(t_i) 10 | ''' 11 | 12 | check_size(eval_segm, gt_segm) 13 | 14 | cl, n_cl = extract_classes(gt_segm) 15 | eval_mask, gt_mask = extract_both_masks(eval_segm, gt_segm, cl, n_cl) 16 | 17 | sum_n_ii = 0 18 | sum_t_i = 0 19 | 20 | for i, c in enumerate(cl): 21 | curr_eval_mask = eval_mask[i, :, :] 22 | curr_gt_mask = gt_mask[i, :, :] 23 | 24 | sum_n_ii += np.sum(np.logical_and(curr_eval_mask, curr_gt_mask)) 25 | sum_t_i += np.sum(curr_gt_mask) 26 | 27 | if (sum_t_i == 0): 28 | pixel_accuracy_ = 0 29 | else: 30 | pixel_accuracy_ = sum_n_ii / sum_t_i 31 | 32 | return pixel_accuracy_ 33 | 34 | 35 | def mean_accuracy(eval_segm, gt_segm): 36 | ''' 37 | (1/n_cl) sum_i(n_ii/t_i) 38 | ''' 39 | 40 | check_size(eval_segm, gt_segm) 41 | 42 | cl, n_cl = extract_classes(gt_segm) 43 | eval_mask, gt_mask = extract_both_masks(eval_segm, gt_segm, cl, n_cl) 44 | 45 | accuracy = list([0]) * n_cl 46 | 47 | for i, c in enumerate(cl): 48 | curr_eval_mask = eval_mask[i, :, :] 49 | curr_gt_mask = gt_mask[i, :, :] 50 | 51 | n_ii = np.sum(np.logical_and(curr_eval_mask, curr_gt_mask)) 52 | t_i = np.sum(curr_gt_mask) 53 | 54 | if (t_i != 0): 55 | accuracy[i] = n_ii / t_i 56 | 57 | mean_accuracy_ = np.mean(accuracy) 58 | return mean_accuracy_ 59 | 60 | 61 | def mean_IU(eval_segm, gt_segm): 62 | ''' 63 | (1/n_cl) * sum_i(n_ii / (t_i + sum_j(n_ji) - n_ii)) 64 | ''' 65 | 66 | check_size(eval_segm, gt_segm) 67 | 68 | cl, n_cl = union_classes(eval_segm, gt_segm) 69 | _, n_cl_gt = extract_classes(gt_segm) 70 | eval_mask, gt_mask = extract_both_masks(eval_segm, gt_segm, cl, n_cl) 71 | 72 | IU = list([0]) * n_cl 73 | 74 | for i, c in enumerate(cl): 75 | curr_eval_mask = eval_mask[i, :, :] 76 | curr_gt_mask = gt_mask[i, :, :] 77 | 78 | if (np.sum(curr_eval_mask) == 0) or (np.sum(curr_gt_mask) == 0): 79 | continue 80 | 81 | n_ii = np.sum(np.logical_and(curr_eval_mask, curr_gt_mask)) 82 | t_i = np.sum(curr_gt_mask) 83 | n_ij = np.sum(curr_eval_mask) 84 | 85 | IU[i] = n_ii / (t_i + n_ij - n_ii) 86 | 87 | mean_IU_ = np.sum(IU) / n_cl_gt 88 | return mean_IU_ 89 | 90 | 91 | def frequency_weighted_IU(eval_segm, gt_segm): 92 | ''' 93 | sum_k(t_k)^(-1) * sum_i((t_i*n_ii)/(t_i + sum_j(n_ji) - n_ii)) 94 | ''' 95 | 96 | check_size(eval_segm, gt_segm) 97 | 98 | cl, n_cl = union_classes(eval_segm, gt_segm) 99 | eval_mask, gt_mask = extract_both_masks(eval_segm, gt_segm, cl, n_cl) 100 | 101 | frequency_weighted_IU_ = list([0]) * n_cl 102 | 103 | for i, c in enumerate(cl): 104 | curr_eval_mask = eval_mask[i, :, :] 105 | curr_gt_mask = gt_mask[i, :, :] 106 | 107 | if (np.sum(curr_eval_mask) == 0) or (np.sum(curr_gt_mask) == 0): 108 | continue 109 | 110 | n_ii = np.sum(np.logical_and(curr_eval_mask, curr_gt_mask)) 111 | t_i = np.sum(curr_gt_mask) 112 | n_ij = np.sum(curr_eval_mask) 113 | 114 | frequency_weighted_IU_[i] = (t_i * n_ii) / (t_i + n_ij - n_ii) 115 | 116 | sum_k_t_k = get_pixel_area(eval_segm) 117 | 118 | 
frequency_weighted_IU_ = np.sum(frequency_weighted_IU_) / sum_k_t_k
119 |     return frequency_weighted_IU_
120 | 
121 | 
122 | '''
123 | Auxiliary functions used during evaluation.
124 | '''
125 | 
126 | 
127 | def get_pixel_area(segm):
128 |     return segm.shape[0] * segm.shape[1]
129 | 
130 | 
131 | def extract_both_masks(eval_segm, gt_segm, cl, n_cl):
132 |     eval_mask = extract_masks(eval_segm, cl, n_cl)
133 |     gt_mask = extract_masks(gt_segm, cl, n_cl)
134 | 
135 |     return eval_mask, gt_mask
136 | 
137 | 
138 | def extract_classes(segm):
139 |     cl = np.unique(segm)
140 |     n_cl = len(cl)
141 | 
142 |     return cl, n_cl
143 | 
144 | 
145 | def union_classes(eval_segm, gt_segm):
146 |     eval_cl, _ = extract_classes(eval_segm)
147 |     gt_cl, _ = extract_classes(gt_segm)
148 | 
149 |     cl = np.union1d(eval_cl, gt_cl)
150 |     n_cl = len(cl)
151 | 
152 |     return cl, n_cl
153 | 
154 | 
155 | def extract_masks(segm, cl, n_cl):
156 |     h, w = segm_size(segm)
157 |     masks = np.zeros((n_cl, h, w))
158 | 
159 |     for i, c in enumerate(cl):
160 |         masks[i, :, :] = segm == c
161 | 
162 |     return masks
163 | 
164 | 
165 | def segm_size(segm):
166 |     try:
167 |         height = segm.shape[0]
168 |         width = segm.shape[1]
169 |     except IndexError:
170 |         raise
171 | 
172 |     return height, width
173 | 
174 | 
175 | def check_size(eval_segm, gt_segm):
176 |     h_e, w_e = segm_size(eval_segm)
177 |     h_g, w_g = segm_size(gt_segm)
178 | 
179 |     if (h_e != h_g) or (w_e != w_g):
180 |         print("DiffDim: Different dimensions of matrices!")
181 | 
182 | 
183 | def _pickle_method(m):
184 |     # Python 3 replacement for the old im_self/im_func reducer: makes bound
185 |     # methods picklable so ConfusionMatrix.generateM can be mapped over a
186 |     # multiprocessing Pool.
187 |     return getattr, (m.__self__, m.__func__.__name__)
188 | 
189 | 
190 | copyreg.pickle(types.MethodType, _pickle_method)
191 | 
192 | class ConfusionMatrix(object):
193 | 
194 |     def __init__(self, nclass, classes=None, ignore_label=255):
195 |         self.nclass = nclass
196 |         self.classes = classes
197 |         self.M = np.zeros((nclass, nclass))
198 |         self.ignore_label = ignore_label
199 | 
200 |     def add(self, gt, pred):
201 |         assert (np.max(pred) <= self.nclass)
202 |         assert (len(gt) == len(pred))
203 |         for i in range(len(gt)):
204 |             if not gt[i] == self.ignore_label:
205 |                 self.M[gt[i], pred[i]] += 1.0
206 | 
207 |     def addM(self, matrix):
208 |         assert (matrix.shape == self.M.shape)
209 |         self.M += matrix
210 | 
211 |     def __str__(self):
212 |         pass
213 | 
214 |     # P_ii is the number of correctly predicted pixels; P_ij and P_ji are read here as false positives and false negatives, although each is really a sum of both.
215 |     def recall(self):  # among the pixels predicted as a class, the count confirmed to belong to it
216 |         recall = 0.0
217 |         for i in range(self.nclass):
218 |             recall += self.M[i, i] / np.sum(self.M[:, i])
219 | 
220 |         return recall / self.nclass
221 | 
222 |     def accuracy(self):  # correctly segmented pixels divided by the total pixels of that class, averaged over classes
223 |         accuracy = 0.0
224 |         for i in range(self.nclass):
225 |             accuracy += self.M[i, i] / np.sum(self.M[i, :])
226 | 
227 |         return accuracy / self.nclass
228 | 
229 |     # Jaccard index, also known as intersection over union (IoU)
230 |     def jaccard(self):
231 |         jaccard = 0.0
232 |         jaccard_perclass = []
233 |         for i in range(self.nclass):
234 |             if not self.M[i, i] == 0:
235 |                 jaccard_perclass.append(self.M[i, i] / (np.sum(self.M[i, :]) + np.sum(self.M[:, i]) - self.M[i, i]))
236 | 
237 |         return np.sum(jaccard_perclass) / len(jaccard_perclass), jaccard_perclass, self.M
238 | 
239 |     def generateM(self, item):
240 |         gt, pred = item
241 |         m = np.zeros((self.nclass, self.nclass))
242 |         assert (len(gt) == len(pred))
243 |         for i in range(len(gt)):
244 |             if gt[i] < self.nclass:  # and pred[i] < self.nclass:
245 |                 m[gt[i], pred[i]] += 1.0
246 |         return m
247 | 
248 | 
249 | def get_iou(data_list, class_num, save_path=None):
250 |     """
251 |     Args:
252 |       data_list: a list, its elements [gt, output]
253 |       class_num: the number of classes
254 |     """
255 |     from multiprocessing import Pool
256 | 
257 |     ConfM = ConfusionMatrix(class_num)
258 |     f = ConfM.generateM
259 |     pool = Pool()
260 |     m_list = pool.map(f, data_list)
261 |     pool.close()
262 |     pool.join()
263 | 
264 |     for m in m_list:
265 |         ConfM.addM(m)
266 | 
267 |     aveJ, j_list, M = ConfM.jaccard()
268 |     # print(j_list)
269 |     # print(M)
270 |     # print('meanIOU: ' + str(aveJ) + '\n')
271 | 
272 |     if save_path:
273 |         with open(save_path, 'w') as f:
274 |             f.write('meanIOU: ' + str(aveJ) + '\n')
275 |             f.write(str(j_list) + '\n')
276 |             f.write(str(M) + '\n')
277 |     return aveJ, j_list
278 | 
279 | 
280 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import tensorflow as tf
  3 | import torch.nn as nn
  4 | import torch.optim as optim
  5 | from torch.autograd import Variable
  6 | from torch.optim.lr_scheduler import ReduceLROnPlateau
  7 | 
  8 | import os
  9 | import numpy as np
 10 | from tensorboardX import SummaryWriter
 11 | from torch.utils.data import DataLoader
 12 | import torchvision.transforms as transforms
 13 | from datetime import datetime
 14 | 
 15 | from data_agu import Mydataset
 16 | from model import AttDinkNet34
 17 | from seg_iou import mean_IU
 18 | from loss import dice_bce_loss_with_logits
 19 | 
 20 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 21 | tf.logging.set_verbosity(tf.logging.INFO)  # tf.logging requires TensorFlow 1.x
 22 | pretrained_path = r'/home/vge/PycharmProjects/seg8/road/Result/11-29_08-55-28/24_checkpoint-best8358.pth'
 23 | 
 24 | class args:
 25 |     train_path = r'C:\Data\Road_Seg\data\data\train/train.csv'
 26 |     val_path = r'C:\Data\Road_Seg\data\data\val/test.csv'
 27 |     num_test_img = 4396
 28 | 
 29 |     result_dir = 'Result/'
 30 |     batch_size = 6
 31 |     learning_rate = 0.01
 32 |     max_epoch = 350
 33 | 
 34 | best_train_acc = 0.6
 35 | now_time = datetime.now()
 36 | time_str = datetime.strftime(now_time, '%m-%d_%H-%M-%S')
 37 | # directory where model checkpoints are saved
 38 | log_dir = os.path.join(args.result_dir, time_str)
 39 | if not os.path.exists(log_dir):
 40 |     os.makedirs(log_dir)
 41 | 
 42 | writer = SummaryWriter(log_dir)
 43 | normMean = [0.4758, 0.4873, 0.5098, 0]
 44 | normStd = [0.1670, 0.1496, 0.1477, 1]
 45 | normTransform = transforms.Normalize(normMean, normStd)
 46 | transform = transforms.Compose([
 47 |     transforms.ToTensor(),
 48 |     normTransform
 49 | ])
 50 | # data loading; see data_agu.py for details
 51 | train_data = Mydataset(path=args.train_path, transform=transform, augment=True)
 52 | val_data = Mydataset(path=args.val_path, transform=transform, augment=False)
 53 | train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=4)
 54 | val_loader = DataLoader(dataset=val_data, batch_size=args.batch_size * 3, shuffle=False, drop_last=True, num_workers=2)
 55 | 
 56 | print("train data set:", len(train_loader) * args.batch_size)
 57 | print("valid data set:", len(val_loader) * args.batch_size * 3)
 58 | 
 59 | net = AttDinkNet34(pretrained=True)
 60 | net.cuda()
 61 | 
 62 | if torch.cuda.is_available():
 63 |     # for continue training
 64 |     w = torch.Tensor([1.5, 1]).cuda()
 65 |     # continue training...
 66 |     # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 67 |     # checkpoint = torch.load(pretrained_path)
 68 |     # net = net.to(device)
 69 |     # net.load_state_dict(checkpoint['state_dict'])
 70 | else:
 71 |     w = torch.Tensor([1.5, 1])
 72 | 
 73 | # loss function and optimizer definition
 74 | criterion4 = dice_bce_loss_with_logits().cuda()
 75 | optimizer = optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, dampening=0.1)
 76 | scheduler = ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3, verbose=True, min_lr=0.0000001)
 77 | 
 78 | # --------------------------- 4. Train the network ---------------------------
 79 | for epoch in range(args.max_epoch):
 80 |     loss_sigma = 0.0
 81 |     loss_val_sigma = 0.0
 82 |     acc_val_sigma = 0.0
 83 |     net.train()
 84 | 
 85 |     for i, data in enumerate(train_loader):
 86 |         inputs, labels, lab_name = data
 87 |         inputs = Variable(inputs.cuda())
 88 |         labels = Variable(labels.cuda())
 89 |         labels = labels.float().cuda()
 90 |         optimizer.zero_grad()
 91 |         outputs = net.forward(inputs)
 92 |         # outputs = torch.sigmoid(outputs)
 93 |         outputs = torch.squeeze(outputs, dim=1)
 94 | 
 95 |         loss = criterion4(labels, outputs)
 96 |         loss.backward()
 97 |         optimizer.step()
 98 | 
 99 |         loss_sigma += loss.item()
100 |         if i % 200 == 0 and i > 0:
101 |             loss_avg = loss_sigma / 200
102 |             loss_sigma = 0.0
103 |             tf.logging.info("Training:Epoch[{:0>3}/{:0>3}] Iter[{:0>3}/{:0>3}] Loss:{:.4f}".format(
104 |                 epoch + 1, args.max_epoch, i + 1, len(train_loader), loss_avg))
105 |             writer.add_scalar("LOSS", loss_avg, epoch)
106 | 
107 |     # --------------------------- validate the network after every epoch ---------------------------
108 |     if epoch % 1 == 0:
109 |         net.eval()
110 |         acc_val_sigma = 0
111 |         acc_val = 0
112 |         data_list = []
113 |         for i, data in enumerate(val_loader):
114 |             inputs, labels, img_name = data
115 |             inputs = Variable(inputs.cuda())
116 |             labels = Variable(labels.cuda())
117 |             labels = labels.float().cuda()
118 |             with torch.no_grad():
119 |                 predicts = net.forward(inputs)
120 | 
121 |             predicts = torch.sigmoid(predicts)
122 |             predicts[predicts < 0.5] = 0
123 |             predicts[predicts >= 0.5] = 1
124 |             result = np.squeeze(predicts)
125 |             # outputs = torch.squeeze(outputs, dim=1)
126 | 
127 |             cc = labels.shape[0]
128 |             for index in range(cc):
129 |                 # the evaluation metric is mean IoU
130 |                 acc_val_sigma += mean_IU(labels[index].cpu().detach().numpy(), result[index].cpu().detach().numpy())
131 | 
132 |         # save the model whenever the validation accuracy improves
133 |         val_acc = acc_val_sigma / args.num_test_img
134 |         print("valid acc:", val_acc)
135 |         print("lr:", optimizer.param_groups[0]['lr'])
136 |         print("best acc:", best_train_acc)
137 |         scheduler.step(val_acc)
138 |         if val_acc > best_train_acc:
139 |             # state = {'epoch': epoch, 'state_dict': net.state_dict(), 'optimizer': optimizer.state_dict()}
140 |             state = {'state_dict': net.state_dict()}
141 |             filename = os.path.join(log_dir, str(epoch) + '_checkpoint-best.pth')
142 |             torch.save(state, filename)
143 |             best_train_acc = val_acc
144 |             tf.logging.info('Saved model successfully to "%s"!' % filename)
145 |         tf.logging.info("After 1 epoch:acc_val:{:.4f},loss_val:{:.4f}".format(acc_val_sigma / (len(val_loader)),
146 |                                                                               loss_val_sigma / (len(val_loader))))
147 | 
148 | writer.close()
149 | net_save_path = os.path.join(log_dir, 'net_params_end.pkl')
150 | torch.save(net.state_dict(), net_save_path)
--------------------------------------------------------------------------------
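For reference, a minimal inference sketch, assuming a checkpoint written by the training loop above (a dict with a 'state_dict' key) and an input tensor prepared the same way as during training. The checkpoint path is a placeholder and the random tensor is only a stand-in for a normalized RGB+Canny tile; this is not code from the repository.
```
import torch
from model import AttDinkNet34

net = AttDinkNet34(pretrained=False)
checkpoint = torch.load('Result/<time_str>/<epoch>_checkpoint-best.pth', map_location='cpu')
net.load_state_dict(checkpoint['state_dict'])
net.eval()

dummy_input = torch.randn(1, 4, 512, 512)  # stand-in for a normalized RGB+Canny tile
with torch.no_grad():
    logits = net(dummy_input)              # 1 x 1 x 512 x 512
    prob = torch.sigmoid(logits)
    mask = (prob >= 0.5).float()           # binary road mask, same 0.5 cut-off as at validation
```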